Adding missing file

This commit is contained in:
William S. Moses 2021-08-26 09:15:17 -04:00 committed by William Moses
parent 1a1a1955b6
commit 4123d514ea
2 changed files with 171 additions and 2 deletions

View File

@ -0,0 +1,169 @@
//===- TrivialUse.cpp - Remove trivial use instruction ---------------- -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to lower gpu kernels in NVVM/gpu dialects into
// a generic parallel for representation
//===----------------------------------------------------------------------===//
#include "polygeist/Ops.h"
#include "polygeist/Passes/Passes.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#define DEBUG_TYPE "convert-polygeist-to-llvm"
using namespace mlir;
using namespace polygeist;
/// Conversion pattern that transforms a subview op into:
/// 1. An `llvm.mlir.undef` operation to create a memref descriptor
/// 2. Updates to the descriptor to introduce the data ptr, offset, size
/// and stride.
/// The subview op is replaced by the descriptor.
struct SubIndexOpLowering : public ConvertOpToLLVMPattern<SubIndexOp> {
using ConvertOpToLLVMPattern<SubIndexOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(SubIndexOp subViewOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = subViewOp.getLoc();
auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
auto sourceElementTy =
typeConverter->convertType(sourceMemRefType.getElementType());
auto viewMemRefType = subViewOp.getType().cast<MemRefType>();
if (sourceMemRefType.getShape().size() != viewMemRefType.getShape().size())
return failure();
//MemRefDescriptor sourceMemRef(operands.front());
SubIndexOp::Adaptor transformed(operands);
MemRefDescriptor targetMemRef(transformed.source());//MemRefDescriptor::undef(rewriter, loc, targetDescTy);
// Offset.
auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
if (false) {
Value baseOffset = targetMemRef.offset(rewriter, loc);
Value stride = targetMemRef.stride(rewriter, loc, 0);
Value offset = transformed.index(); //rewriter.create<mlir::IndexCastOp>(loc, operands.back(), stride.getType());
Value mul = rewriter.create<LLVM::MulOp>(loc, offset, stride);
baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
targetMemRef.setOffset(rewriter, loc, baseOffset);
} else {
Value prev = targetMemRef.alignedPtr(rewriter, loc);
Value idxs[] = {transformed.index()};
targetMemRef.setAlignedPtr(rewriter, loc, rewriter.create<LLVM::GEPOp>(loc, prev.getType(), prev, idxs));
}
rewriter.replaceOp(subViewOp, {targetMemRef});
return success();
}
};
struct Memref2PointerOpLowering : public ConvertOpToLLVMPattern<Memref2PointerOp> {
using ConvertOpToLLVMPattern<Memref2PointerOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(Memref2PointerOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
//MemRefDescriptor sourceMemRef(operands.front());
SubIndexOp::Adaptor transformed(operands);
MemRefDescriptor targetMemRef(transformed.source());//MemRefDescriptor::undef(rewriter, loc, targetDescTy);
// Offset.
auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
Value baseOffset = targetMemRef.offset(rewriter, loc);
Value ptr = targetMemRef.alignedPtr(rewriter, loc);
Value idxs[] = {baseOffset};
ptr = rewriter.create<LLVM::GEPOp>(loc, ptr.getType(), ptr, idxs);
rewriter.replaceOp(op, {ptr});
return success();
}
};
void populatePolygeistToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
// clang-format off
patterns.add<SubIndexOpLowering>(converter);
patterns.add<Memref2PointerOpLowering>(converter);
}
namespace {
struct ConvertPolygeistToLLVMPass : public ConvertPolygeistToLLVMBase<ConvertPolygeistToLLVMPass> {
ConvertPolygeistToLLVMPass() = default;
ConvertPolygeistToLLVMPass(bool useBarePtrCallConv, bool emitCWrappers,
unsigned indexBitwidth, bool useAlignedAlloc,
const llvm::DataLayout &dataLayout) {
this->useBarePtrCallConv = useBarePtrCallConv;
this->emitCWrappers = emitCWrappers;
this->indexBitwidth = indexBitwidth;
this->dataLayout = dataLayout.getStringRepresentation();
}
void runOnOperation() override {
ModuleOp m = getOperation();
const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
LowerToLLVMOptions options(&getContext(),
dataLayoutAnalysis.getAtOrAbove(m));
options.useBarePtrCallConv = useBarePtrCallConv;
options.emitCWrappers = emitCWrappers;
if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
options.overrideIndexBitwidth(indexBitwidth);
options.dataLayout = llvm::DataLayout(this->dataLayout);
LLVMTypeConverter converter(&getContext(), options,
&dataLayoutAnalysis);
RewritePatternSet patterns(&getContext());
populatePolygeistToLLVMConversionPatterns(converter, patterns);
populateMemRefToLLVMConversionPatterns(converter, patterns);
populateStdToLLVMConversionPatterns(converter, patterns);
populateOpenMPToLLVMConversionPatterns(converter, patterns);
populateStdExpandOpsPatterns(patterns);
LLVMConversionTarget target(getContext());
target.addDynamicallyLegalOp<omp::ParallelOp, omp::WsLoopOp>(
[&](Operation *op) { return converter.isLegal(&op->getRegion(0)); });
target.addLegalOp<omp::TerminatorOp, omp::TaskyieldOp, omp::FlushOp,
omp::BarrierOp, omp::TaskwaitOp>();
if (failed(applyPartialConversion(m, target, std::move(patterns))))
signalPassFailure();
}
};
} // namespace
std::unique_ptr<Pass> mlir::polygeist::createConvertPolygeistToLLVMPass(const LowerToLLVMOptions &options) {
auto allocLowering = options.allocLowering;
// There is no way to provide additional patterns for pass, so
// AllocLowering::None will always fail.
assert(allocLowering != LowerToLLVMOptions::AllocLowering::None &&
"LLVMLoweringPass doesn't support AllocLowering::None");
bool useAlignedAlloc =
(allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc);
return std::make_unique<ConvertPolygeistToLLVMPass>(
options.useBarePtrCallConv, options.emitCWrappers,
options.getIndexBitwidth(), useAlignedAlloc, options.dataLayout);
}

View File

@ -5222,7 +5222,7 @@ mlir::FuncOp MLIRASTConsumer::GetOrCreateMLIRFunction(const FunctionDecl *FD) {
}
mlir::OpBuilder builder(module.getContext());
NamedAttrList attrs(function->getAttrDictionary());
attrs.set("linkage", builder.getI64IntegerAttr(static_cast<int64_t>(lnk)));
attrs.set("llvm.linkage", builder.getI64IntegerAttr(static_cast<int64_t>(lnk)));
function->setAttrs(attrs.getDictionary(builder.getContext()));
}
@ -5325,7 +5325,7 @@ mlir::FuncOp MLIRASTConsumer::GetOrCreateMLIRFunction(const FunctionDecl *FD) {
SymbolTable::setSymbolVisibility(function, SymbolTable::Visibility::Public);
}
NamedAttrList attrs(function->getAttrDictionary());
attrs.append("linkage", builder.getI64IntegerAttr(static_cast<int64_t>(lnk)));
attrs.append("llvm.linkage", builder.getI64IntegerAttr(static_cast<int64_t>(lnk)));
function->setAttrs(attrs.getDictionary(builder.getContext()));
functions[name] = function;