Adding missing file
This commit is contained in:
parent
1a1a1955b6
commit
4123d514ea
|
@ -0,0 +1,169 @@
|
||||||
|
//===- TrivialUse.cpp - Remove trivial use instruction ---------------- -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file implements a pass to lower gpu kernels in NVVM/gpu dialects into
|
||||||
|
// a generic parallel for representation
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "polygeist/Ops.h"
|
||||||
|
#include "polygeist/Passes/Passes.h"
|
||||||
|
|
||||||
|
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||||
|
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
|
||||||
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||||
|
#include "mlir/Analysis/DataLayoutAnalysis.h"
|
||||||
|
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
|
||||||
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||||
|
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
|
||||||
|
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
||||||
|
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
|
||||||
|
|
||||||
|
#define DEBUG_TYPE "convert-polygeist-to-llvm"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace polygeist;
|
||||||
|
|
||||||
|
/// Conversion pattern that transforms a subview op into:
|
||||||
|
/// 1. An `llvm.mlir.undef` operation to create a memref descriptor
|
||||||
|
/// 2. Updates to the descriptor to introduce the data ptr, offset, size
|
||||||
|
/// and stride.
|
||||||
|
/// The subview op is replaced by the descriptor.
|
||||||
|
struct SubIndexOpLowering : public ConvertOpToLLVMPattern<SubIndexOp> {
|
||||||
|
using ConvertOpToLLVMPattern<SubIndexOp>::ConvertOpToLLVMPattern;
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(SubIndexOp subViewOp, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
auto loc = subViewOp.getLoc();
|
||||||
|
|
||||||
|
auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
|
||||||
|
auto sourceElementTy =
|
||||||
|
typeConverter->convertType(sourceMemRefType.getElementType());
|
||||||
|
|
||||||
|
auto viewMemRefType = subViewOp.getType().cast<MemRefType>();
|
||||||
|
|
||||||
|
if (sourceMemRefType.getShape().size() != viewMemRefType.getShape().size())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
|
||||||
|
//MemRefDescriptor sourceMemRef(operands.front());
|
||||||
|
SubIndexOp::Adaptor transformed(operands);
|
||||||
|
MemRefDescriptor targetMemRef(transformed.source());//MemRefDescriptor::undef(rewriter, loc, targetDescTy);
|
||||||
|
|
||||||
|
// Offset.
|
||||||
|
auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
|
||||||
|
|
||||||
|
if (false) {
|
||||||
|
Value baseOffset = targetMemRef.offset(rewriter, loc);
|
||||||
|
Value stride = targetMemRef.stride(rewriter, loc, 0);
|
||||||
|
Value offset = transformed.index(); //rewriter.create<mlir::IndexCastOp>(loc, operands.back(), stride.getType());
|
||||||
|
Value mul = rewriter.create<LLVM::MulOp>(loc, offset, stride);
|
||||||
|
baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
|
||||||
|
targetMemRef.setOffset(rewriter, loc, baseOffset);
|
||||||
|
} else {
|
||||||
|
Value prev = targetMemRef.alignedPtr(rewriter, loc);
|
||||||
|
Value idxs[] = {transformed.index()};
|
||||||
|
targetMemRef.setAlignedPtr(rewriter, loc, rewriter.create<LLVM::GEPOp>(loc, prev.getType(), prev, idxs));
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(subViewOp, {targetMemRef});
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Memref2PointerOpLowering : public ConvertOpToLLVMPattern<Memref2PointerOp> {
|
||||||
|
using ConvertOpToLLVMPattern<Memref2PointerOp>::ConvertOpToLLVMPattern;
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(Memref2PointerOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
|
||||||
|
//MemRefDescriptor sourceMemRef(operands.front());
|
||||||
|
SubIndexOp::Adaptor transformed(operands);
|
||||||
|
MemRefDescriptor targetMemRef(transformed.source());//MemRefDescriptor::undef(rewriter, loc, targetDescTy);
|
||||||
|
|
||||||
|
// Offset.
|
||||||
|
auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
|
||||||
|
|
||||||
|
Value baseOffset = targetMemRef.offset(rewriter, loc);
|
||||||
|
Value ptr = targetMemRef.alignedPtr(rewriter, loc);
|
||||||
|
Value idxs[] = {baseOffset};
|
||||||
|
ptr = rewriter.create<LLVM::GEPOp>(loc, ptr.getType(), ptr, idxs);
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, {ptr});
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void populatePolygeistToLLVMConversionPatterns(LLVMTypeConverter &converter,
|
||||||
|
RewritePatternSet &patterns) {
|
||||||
|
// clang-format off
|
||||||
|
patterns.add<SubIndexOpLowering>(converter);
|
||||||
|
patterns.add<Memref2PointerOpLowering>(converter);
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
struct ConvertPolygeistToLLVMPass : public ConvertPolygeistToLLVMBase<ConvertPolygeistToLLVMPass> {
|
||||||
|
ConvertPolygeistToLLVMPass() = default;
|
||||||
|
ConvertPolygeistToLLVMPass(bool useBarePtrCallConv, bool emitCWrappers,
|
||||||
|
unsigned indexBitwidth, bool useAlignedAlloc,
|
||||||
|
const llvm::DataLayout &dataLayout) {
|
||||||
|
this->useBarePtrCallConv = useBarePtrCallConv;
|
||||||
|
this->emitCWrappers = emitCWrappers;
|
||||||
|
this->indexBitwidth = indexBitwidth;
|
||||||
|
this->dataLayout = dataLayout.getStringRepresentation();
|
||||||
|
}
|
||||||
|
|
||||||
|
void runOnOperation() override {
|
||||||
|
ModuleOp m = getOperation();
|
||||||
|
const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
|
||||||
|
|
||||||
|
LowerToLLVMOptions options(&getContext(),
|
||||||
|
dataLayoutAnalysis.getAtOrAbove(m));
|
||||||
|
options.useBarePtrCallConv = useBarePtrCallConv;
|
||||||
|
options.emitCWrappers = emitCWrappers;
|
||||||
|
if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
|
||||||
|
options.overrideIndexBitwidth(indexBitwidth);
|
||||||
|
|
||||||
|
options.dataLayout = llvm::DataLayout(this->dataLayout);
|
||||||
|
|
||||||
|
LLVMTypeConverter converter(&getContext(), options,
|
||||||
|
&dataLayoutAnalysis);
|
||||||
|
RewritePatternSet patterns(&getContext());
|
||||||
|
populatePolygeistToLLVMConversionPatterns(converter, patterns);
|
||||||
|
populateMemRefToLLVMConversionPatterns(converter, patterns);
|
||||||
|
populateStdToLLVMConversionPatterns(converter, patterns);
|
||||||
|
populateOpenMPToLLVMConversionPatterns(converter, patterns);
|
||||||
|
populateStdExpandOpsPatterns(patterns);
|
||||||
|
|
||||||
|
LLVMConversionTarget target(getContext());
|
||||||
|
target.addDynamicallyLegalOp<omp::ParallelOp, omp::WsLoopOp>(
|
||||||
|
[&](Operation *op) { return converter.isLegal(&op->getRegion(0)); });
|
||||||
|
target.addLegalOp<omp::TerminatorOp, omp::TaskyieldOp, omp::FlushOp,
|
||||||
|
omp::BarrierOp, omp::TaskwaitOp>();
|
||||||
|
if (failed(applyPartialConversion(m, target, std::move(patterns))))
|
||||||
|
signalPassFailure();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<Pass> mlir::polygeist::createConvertPolygeistToLLVMPass(const LowerToLLVMOptions &options) {
|
||||||
|
auto allocLowering = options.allocLowering;
|
||||||
|
// There is no way to provide additional patterns for pass, so
|
||||||
|
// AllocLowering::None will always fail.
|
||||||
|
assert(allocLowering != LowerToLLVMOptions::AllocLowering::None &&
|
||||||
|
"LLVMLoweringPass doesn't support AllocLowering::None");
|
||||||
|
bool useAlignedAlloc =
|
||||||
|
(allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc);
|
||||||
|
return std::make_unique<ConvertPolygeistToLLVMPass>(
|
||||||
|
options.useBarePtrCallConv, options.emitCWrappers,
|
||||||
|
options.getIndexBitwidth(), useAlignedAlloc, options.dataLayout);
|
||||||
|
}
|
|
@ -5222,7 +5222,7 @@ mlir::FuncOp MLIRASTConsumer::GetOrCreateMLIRFunction(const FunctionDecl *FD) {
|
||||||
}
|
}
|
||||||
mlir::OpBuilder builder(module.getContext());
|
mlir::OpBuilder builder(module.getContext());
|
||||||
NamedAttrList attrs(function->getAttrDictionary());
|
NamedAttrList attrs(function->getAttrDictionary());
|
||||||
attrs.set("linkage", builder.getI64IntegerAttr(static_cast<int64_t>(lnk)));
|
attrs.set("llvm.linkage", builder.getI64IntegerAttr(static_cast<int64_t>(lnk)));
|
||||||
function->setAttrs(attrs.getDictionary(builder.getContext()));
|
function->setAttrs(attrs.getDictionary(builder.getContext()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5325,7 +5325,7 @@ mlir::FuncOp MLIRASTConsumer::GetOrCreateMLIRFunction(const FunctionDecl *FD) {
|
||||||
SymbolTable::setSymbolVisibility(function, SymbolTable::Visibility::Public);
|
SymbolTable::setSymbolVisibility(function, SymbolTable::Visibility::Public);
|
||||||
}
|
}
|
||||||
NamedAttrList attrs(function->getAttrDictionary());
|
NamedAttrList attrs(function->getAttrDictionary());
|
||||||
attrs.append("linkage", builder.getI64IntegerAttr(static_cast<int64_t>(lnk)));
|
attrs.append("llvm.linkage", builder.getI64IntegerAttr(static_cast<int64_t>(lnk)));
|
||||||
function->setAttrs(attrs.getDictionary(builder.getContext()));
|
function->setAttrs(attrs.getDictionary(builder.getContext()));
|
||||||
|
|
||||||
functions[name] = function;
|
functions[name] = function;
|
||||||
|
|
Loading…
Reference in New Issue