Bump Polygeist to version 3f0e66045f01305e2b09de186552a3c42cdfd8e5

This commit is contained in:
Hanchen Ye 2022-01-21 02:58:31 -06:00
parent 072a330cae
commit f354a8a3a6
19 changed files with 188 additions and 196 deletions

View File

@ -9,7 +9,8 @@
include "mlir/Interfaces/SideEffectInterfaces.td"
def AssignOp : HLSCppOp<"assign", [SameOperandsAndResultType, NoSideEffect]> {
def AssignOp : HLSCppOp<"assign",
[SameOperandsAndResultElementType, NoSideEffect]> {
let summary = "Assign the input value to the output";
let description = [{
This hlscpp.assign operation assigns the input value to the output, and can

View File

@ -9,6 +9,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
@ -39,11 +40,11 @@ public:
// Memref-related statements.
memref::AllocOp, memref::AllocaOp, memref::LoadOp, memref::StoreOp,
memref::DeallocOp, memref::DmaStartOp, memref::DmaWaitOp,
memref::ViewOp, memref::SubViewOp, AtomicRMWOp, GenericAtomicRMWOp,
AtomicYieldOp,
memref::ViewOp, memref::SubViewOp, memref::AtomicRMWOp,
GenericAtomicRMWOp, AtomicYieldOp,
// Tensor-related statements.
memref::TensorLoadOp, memref::TensorStoreOp, memref::BufferCastOp,
SplatOp, memref::DimOp, RankOp,
bufferization::ToMemrefOp, bufferization::ToTensorOp,
memref::TensorStoreOp, SplatOp, memref::DimOp, memref::RankOp,
// Unary expressions.
math::AbsOp, math::CeilOp, math::CosOp, math::SinOp, math::TanhOp,
math::SqrtOp, math::RsqrtOp, math::ExpOp, math::Exp2Op, math::LogOp,
@ -118,19 +119,19 @@ public:
HANDLE(memref::DeallocOp);
HANDLE(memref::DmaStartOp);
HANDLE(memref::DmaWaitOp);
HANDLE(AtomicRMWOp);
HANDLE(memref::AtomicRMWOp);
HANDLE(GenericAtomicRMWOp);
HANDLE(AtomicYieldOp);
HANDLE(memref::ViewOp);
HANDLE(memref::SubViewOp);
// Tensor-related statements.
HANDLE(memref::TensorLoadOp);
HANDLE(bufferization::ToMemrefOp);
HANDLE(bufferization::ToTensorOp);
HANDLE(memref::TensorStoreOp);
HANDLE(memref::BufferCastOp);
HANDLE(SplatOp);
HANDLE(memref::DimOp);
HANDLE(RankOp);
HANDLE(memref::RankOp);
// Unary expressions.
HANDLE(math::AbsOp);

View File

@ -8,8 +8,9 @@
#define SCALEHLS_INITALLDIALECTS_H
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
@ -32,6 +33,7 @@ inline void registerAllDialects(mlir::DialectRegistry &registry) {
mlir::math::MathDialect,
mlir::arith::ArithmeticDialect,
mlir::scf::SCFDialect,
mlir::bufferization::BufferizationDialect,
mlir::linalg::LinalgDialect,
mlir::LLVM::LLVMDialect
>();

View File

@ -60,9 +60,6 @@ unsigned getCommonSurroundingLoops(Operation *A, Operation *B,
/// Calculate the upper and lower bound of "bound" if possible.
Optional<std::pair<int64_t, int64_t>> getBoundOfAffineBound(AffineBound bound);
/// Return the layout map of "memrefType".
AffineMap getLayoutMap(MemRefType memrefType);
/// Calculate partition factors through analyzing the "memrefType" and return
/// them in "factors". Meanwhile, the overall partition number is calculated and
/// returned as well.

View File

@ -8,8 +8,6 @@ include(AddMLIRPython)
# Declare native Python extension
################################################################################
set(MLIR_PYTHON_SOURCE_DIR "${MLIR_MAIN_SRC_DIR}/lib/Bindings/Python")
declare_mlir_python_sources(ScaleHLSBindingsPythonExtension)
declare_mlir_python_extension(ScaleHLSBindingsPythonExtension.Core
@ -17,7 +15,6 @@ declare_mlir_python_extension(ScaleHLSBindingsPythonExtension.Core
ADD_TO_PARENT ScaleHLSBindingsPythonExtension
SOURCES
ScaleHLSModule.cpp
${MLIR_PYTHON_SOURCE_DIR}/PybindUtils.cpp
EMBED_CAPI_LINK_LIBS
MLIRScaleHLSCAPIHLSCpp
MLIRScaleHLSCAPIEmitHLSCpp

View File

@ -29,6 +29,17 @@ using namespace mlir;
using namespace mlir::python;
using namespace scalehls;
//===----------------------------------------------------------------------===//
// PybindUtils
//===----------------------------------------------------------------------===//
pybind11::error_already_set
mlir::python::SetPyError(PyObject *excClass, const llvm::Twine &message) {
auto messageStr = message.str();
PyErr_SetString(excClass, messageStr.c_str());
return pybind11::error_already_set();
}
//===----------------------------------------------------------------------===//
// Customized Python classes
//===----------------------------------------------------------------------===//

View File

@ -6,6 +6,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "scalehls/Conversion/Passes.h"
@ -93,7 +94,7 @@ void LegalizeOnnx::runOnOperation() {
if (auto defOp = index.getDefiningOp())
if (auto constOp = dyn_cast<arith::ConstantOp>(defOp))
if (constOp.getType().isa<IndexType>())
if (auto constAttr = constOp.value().dyn_cast<IntegerAttr>()) {
if (auto constAttr = constOp.getValue().dyn_cast<IntegerAttr>()) {
exprs.push_back(
builder.getAffineConstantExpr(constAttr.getUInt()));
continue;
@ -137,7 +138,7 @@ void LegalizeOnnx::runOnOperation() {
// constant operation to substitute it.
builder.setInsertionPoint(&op);
auto tensor = builder.create<arith::ConstantOp>(op.getLoc(), value);
auto memref = builder.create<memref::BufferCastOp>(
auto memref = builder.create<bufferization::ToMemrefOp>(
op.getLoc(), op.getResult(0).getType(), tensor);
op.getResult(0).replaceAllUsesWith(memref);

View File

@ -47,8 +47,9 @@ bool scalehls::applyLegalizeToHLSCpp(FuncOp func, bool isTopFunc) {
// TODO: determine memory kind according to data type.
MemoryKind kind = MemoryKind::BRAM_S2P;
auto newType = MemRefType::get(type.getShape(), type.getElementType(),
type.getAffineMaps(), (unsigned)kind);
auto newType =
MemRefType::get(type.getShape(), type.getElementType(),
type.getLayout().getAffineMap(), (unsigned)kind);
memref.setType(newType);
}
}

View File

@ -6,6 +6,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
@ -22,9 +23,9 @@ static Type quantizeType(Type type, OpBuilder &builder) {
// biases are quantized to 32-bits.
if (auto memrefType = type.dyn_cast<MemRefType>()) {
auto shape = memrefType.getShape();
auto maps = memrefType.getAffineMaps();
auto layout = memrefType.getLayout();
auto memorySpace = memrefType.getMemorySpace();
return MemRefType::get(shape, int8Type, maps, memorySpace);
return MemRefType::get(shape, int8Type, layout, memorySpace);
} else if (auto tensorType = type.dyn_cast<TensorType>()) {
auto shape = tensorType.getShape();
@ -48,7 +49,7 @@ static void quantizeBlock(Block &block, OpBuilder &builder,
builder.setInsertionPoint(&op);
if (auto constOp = dyn_cast<arith::ConstantOp>(op)) {
auto attr = constOp.value();
auto attr = constOp.getValue();
if (auto floatAttr = attr.dyn_cast<FloatAttr>()) {
int8_t floatValue = floatAttr.getValue().convertToFloat();
@ -79,16 +80,16 @@ static void quantizeBlock(Block &block, OpBuilder &builder,
} else if (auto castOp = dyn_cast<arith::UIToFPOp>(op)) {
newOp = builder.create<hlscpp::CastOp>(castOp.getLoc(), int8Type,
castOp.in());
castOp.getIn());
} else if (auto allocOp = dyn_cast<memref::AllocOp>(op)) {
auto newType = quantizeType(allocOp.memref().getType(), builder);
newOp = builder.create<memref::AllocOp>(allocOp.getLoc(),
newType.cast<MemRefType>());
} else if (auto bufferOp = dyn_cast<memref::BufferCastOp>(op)) {
} else if (auto bufferOp = dyn_cast<bufferization::ToMemrefOp>(op)) {
auto newType = quantizeType(bufferOp.memref().getType(), builder);
newOp = builder.create<memref::BufferCastOp>(
newOp = builder.create<bufferization::ToMemrefOp>(
bufferOp.getLoc(), newType.cast<MemRefType>(), bufferOp.tensor());
}
@ -98,23 +99,23 @@ static void quantizeBlock(Block &block, OpBuilder &builder,
else if (auto selectOp = dyn_cast<mlir::SelectOp>(op))
newOp = builder.create<mlir::SelectOp>(
selectOp.getLoc(), selectOp.condition(), selectOp.true_value(),
selectOp.false_value());
selectOp.getLoc(), selectOp.getCondition(), selectOp.getTrueValue(),
selectOp.getFalseValue());
else if (auto mulOp = dyn_cast<arith::MulFOp>(op)) {
auto lhsValue = builder.create<hlscpp::CastOp>(mulOp.getLoc(), int16Type,
mulOp.lhs());
mulOp.getLhs());
auto rhsValue = builder.create<hlscpp::CastOp>(mulOp.getLoc(), int16Type,
mulOp.rhs());
mulOp.getRhs());
newOp = builder.create<hlscpp::MulOp>(mulOp.getLoc(), int32Type, lhsValue,
rhsValue);
}
else if (auto addOp = dyn_cast<arith::AddFOp>(op)) {
auto lhsValue = builder.create<hlscpp::CastOp>(addOp.getLoc(), int32Type,
addOp.lhs());
addOp.getLhs());
auto rhsValue = builder.create<hlscpp::CastOp>(addOp.getLoc(), int32Type,
addOp.rhs());
addOp.getRhs());
auto accValue = builder.create<hlscpp::AddOp>(addOp.getLoc(), int32Type,
lhsValue, rhsValue);
newOp =
@ -122,12 +123,12 @@ static void quantizeBlock(Block &block, OpBuilder &builder,
}
else if (auto divOp = dyn_cast<arith::DivFOp>(op))
newOp = builder.create<arith::DivSIOp>(divOp.getLoc(), divOp.lhs(),
divOp.rhs());
newOp = builder.create<arith::DivSIOp>(divOp.getLoc(), divOp.getLhs(),
divOp.getRhs());
else if (auto cmpOp = dyn_cast<arith::CmpFOp>(op)) {
arith::CmpIPredicate predicate;
switch (cmpOp.predicate()) {
switch (cmpOp.getPredicate()) {
case arith::CmpFPredicate::OEQ:
predicate = arith::CmpIPredicate::eq;
break;
@ -151,7 +152,7 @@ static void quantizeBlock(Block &block, OpBuilder &builder,
break;
}
newOp = builder.create<arith::CmpIOp>(cmpOp.getLoc(), predicate,
cmpOp.lhs(), cmpOp.rhs());
cmpOp.getLhs(), cmpOp.getRhs());
} else if (!isa<mlir::CallOp, mlir::ReturnOp, memref::DeallocOp,
mlir::AffineApplyOp, mlir::AffineForOp, mlir::AffineIfOp,

View File

@ -32,32 +32,11 @@ void HLSCppDialect::initialize() {
#define GET_ATTRDEF_CLASSES
#include "scalehls/Dialect/HLSCpp/HLSCppAttributes.cpp.inc"
Attribute HLSCppDialect::parseAttribute(DialectAsmParser &p, Type type) const {
StringRef attrName;
Attribute attr;
if (p.parseKeyword(&attrName))
return Attribute();
auto parseResult = generatedAttributeParser(p, attrName, type, attr);
if (parseResult.hasValue())
return attr;
p.emitError(p.getNameLoc(), "Unexpected hlscpp attribute");
return Attribute();
}
void HLSCppDialect::printAttribute(Attribute attr, DialectAsmPrinter &p) const {
if (succeeded(generatedAttributePrinter(attr, p)))
return;
llvm_unreachable("Unexpected attribute");
}
//===----------------------------------------------------------------------===//
// ResourceAttr
//===----------------------------------------------------------------------===//
Attribute ResourceAttr::parse(DialectAsmParser &p, Type type) {
Attribute ResourceAttr::parse(AsmParser &p, Type type) {
StringRef lutKw, dspKw, bramKw, nonShareDspKw;
int64_t lut, dsp, bram, nonShareDsp;
if (p.parseLess() || p.parseKeyword(&lutKw) || p.parseEqual() ||
@ -75,7 +54,7 @@ Attribute ResourceAttr::parse(DialectAsmParser &p, Type type) {
return ResourceAttr::get(p.getContext(), lut, dsp, bram, nonShareDsp);
}
void ResourceAttr::print(DialectAsmPrinter &p) const {
void ResourceAttr::print(AsmPrinter &p) const {
p << getMnemonic() << "<lut=" << getLut() << ", dsp=" << getDsp()
<< ", bram=" << getBram() << ", nonShareDsp=" << getNonShareDsp() << ">";
}
@ -84,7 +63,7 @@ void ResourceAttr::print(DialectAsmPrinter &p) const {
// TimingAttr
//===----------------------------------------------------------------------===//
Attribute TimingAttr::parse(DialectAsmParser &p, Type type) {
Attribute TimingAttr::parse(AsmParser &p, Type type) {
int64_t begin, end, latency, interval;
if (p.parseLess() || p.parseInteger(begin) || p.parseArrow() ||
p.parseInteger(end) || p.parseComma() || p.parseInteger(latency) ||
@ -94,7 +73,7 @@ Attribute TimingAttr::parse(DialectAsmParser &p, Type type) {
return TimingAttr::get(p.getContext(), begin, end, latency, interval);
}
void TimingAttr::print(DialectAsmPrinter &p) const {
void TimingAttr::print(AsmPrinter &p) const {
p << getMnemonic() << "<" << getBegin() << " -> " << getEnd() << ", "
<< getLatency() << ", " << getInterval() << ">";
}
@ -103,7 +82,7 @@ void TimingAttr::print(DialectAsmPrinter &p) const {
// LoopInfoAttr
//===----------------------------------------------------------------------===//
Attribute LoopInfoAttr::parse(DialectAsmParser &p, Type type) {
Attribute LoopInfoAttr::parse(AsmParser &p, Type type) {
StringRef flattenTripCountKw, iterLatencyKw, minIIKw;
int64_t flattenTripCount, iterLatency, minII;
if (p.parseLess() || p.parseKeyword(&flattenTripCountKw) || p.parseEqual() ||
@ -122,7 +101,7 @@ Attribute LoopInfoAttr::parse(DialectAsmParser &p, Type type) {
minII);
}
void LoopInfoAttr::print(DialectAsmPrinter &p) const {
void LoopInfoAttr::print(AsmPrinter &p) const {
p << getMnemonic() << "<flattenTripCount=" << getFlattenTripCount()
<< ", iterLatency=" << getIterLatency() << ", minII=" << getMinII() << ">";
}
@ -131,7 +110,7 @@ void LoopInfoAttr::print(DialectAsmPrinter &p) const {
// LoopDirectiveAttr
//===----------------------------------------------------------------------===//
Attribute LoopDirectiveAttr::parse(DialectAsmParser &p, Type type) {
Attribute LoopDirectiveAttr::parse(AsmParser &p, Type type) {
StringRef pipelineKw, targetIIKw, dataflowKw, flattenKw, parallelKw;
StringRef pipeline, dataflow, flatten, parallel;
int64_t targetII;
@ -157,7 +136,7 @@ Attribute LoopDirectiveAttr::parse(DialectAsmParser &p, Type type) {
parallel == "true");
}
void LoopDirectiveAttr::print(DialectAsmPrinter &p) const {
void LoopDirectiveAttr::print(AsmPrinter &p) const {
p << getMnemonic() << "<pipeline=" << getPipeline()
<< ", targetII=" << getTargetII() << ", dataflow=" << getDataflow()
<< ", flatten=" << getFlatten() << ", parallel=" << getParallel() << ">";
@ -167,7 +146,7 @@ void LoopDirectiveAttr::print(DialectAsmPrinter &p) const {
// FuncDirectiveAttr
//===----------------------------------------------------------------------===//
Attribute FuncDirectiveAttr::parse(DialectAsmParser &p, Type type) {
Attribute FuncDirectiveAttr::parse(AsmParser &p, Type type) {
StringRef pipelineKw, targetIntervalKw, dataflowKw, topFuncKw;
StringRef pipeline, dataflow, topFunc;
int64_t targetInterval;
@ -190,7 +169,7 @@ Attribute FuncDirectiveAttr::parse(DialectAsmParser &p, Type type) {
topFunc == "true");
}
void FuncDirectiveAttr::print(DialectAsmPrinter &p) const {
void FuncDirectiveAttr::print(AsmPrinter &p) const {
p << getMnemonic() << "<pipeline=" << getPipeline()
<< ", targetInterval=" << getTargetInterval()
<< ", dataflow=" << getDataflow() << ", topFunc=" << getTopFunc() << ">";

View File

@ -4,6 +4,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
@ -35,7 +36,7 @@ void HLSKernelBufferize::runOnOperation() {
if (auto operandType = operand.getType().dyn_cast<RankedTensorType>()) {
auto memRefType = MemRefType::get(operandType.getShape(),
operandType.getElementType());
auto operandMemRef = builder.create<memref::BufferCastOp>(
auto operandMemRef = builder.create<bufferization::ToMemrefOp>(
func.getLoc(), memRefType, operand);
op->setOperand(operandIndex, operandMemRef);
}
@ -54,8 +55,8 @@ void HLSKernelBufferize::runOnOperation() {
// Create a TensorLoad operaion to replace the original returned tensor.
builder.setInsertionPointAfter(op);
auto resultTensor =
builder.create<memref::TensorLoadOp>(func.getLoc(), resultMemRef);
auto resultTensor = builder.create<bufferization::ToTensorOp>(
func.getLoc(), resultMemRef);
op->getResult(0).replaceAllUsesWith(resultTensor);
}
});

View File

@ -194,29 +194,17 @@ scalehls::getBoundOfAffineBound(AffineBound bound) {
return std::pair<int64_t, int64_t>(*minmax.first, *minmax.second);
}
/// Return the layout map of "memrefType".
AffineMap scalehls::getLayoutMap(MemRefType memrefType) {
// Check whether the memref has layout map.
auto memrefMaps = memrefType.getAffineMaps();
if (memrefMaps.empty())
return (AffineMap) nullptr;
return memrefMaps.back();
}
bool scalehls::isFullyPartitioned(MemRefType memrefType) {
if (memrefType.getRank() == 0)
return true;
bool fullyPartitioned = false;
if (auto layoutMap = getLayoutMap(memrefType)) {
SmallVector<int64_t, 8> factors;
getPartitionFactors(memrefType, &factors);
SmallVector<int64_t, 8> factors;
getPartitionFactors(memrefType, &factors);
auto shapes = memrefType.getShape();
fullyPartitioned =
factors == SmallVector<int64_t, 8>(shapes.begin(), shapes.end());
}
auto shapes = memrefType.getShape();
fullyPartitioned =
factors == SmallVector<int64_t, 8>(shapes.begin(), shapes.end());
return fullyPartitioned;
}
@ -227,7 +215,7 @@ bool scalehls::isFullyPartitioned(MemRefType memrefType) {
int64_t scalehls::getPartitionFactors(MemRefType memrefType,
SmallVector<int64_t, 8> *factors) {
auto shape = memrefType.getShape();
auto layoutMap = getLayoutMap(memrefType);
auto layoutMap = memrefType.getLayout().getAffineMap();
int64_t accumFactor = 1;
for (int64_t dim = 0; dim < memrefType.getRank(); ++dim) {

View File

@ -15,7 +15,7 @@ using namespace hlscpp;
static void updateSubFuncs(FuncOp func, Builder builder) {
func.walk([&](CallOp op) {
auto callee = SymbolTable::lookupNearestSymbolFrom(op, op.calleeAttr());
auto callee = SymbolTable::lookupNearestSymbolFrom(op, op.getCalleeAttr());
auto subFunc = dyn_cast<FuncOp>(callee);
// Set sub-function type.
@ -258,7 +258,7 @@ bool scalehls::applyAutoArrayPartition(FuncOp func) {
// Apply partition to all sub-functions and traverse all function to update
// the "partitionsMap".
func.walk([&](CallOp op) {
auto callee = SymbolTable::lookupNearestSymbolFrom(op, op.calleeAttr());
auto callee = SymbolTable::lookupNearestSymbolFrom(op, op.getCalleeAttr());
auto subFunc = dyn_cast<FuncOp>(callee);
assert(subFunc && "callable is not a function operation");
@ -268,35 +268,38 @@ bool scalehls::applyAutoArrayPartition(FuncOp func) {
auto subFuncType = subFunc.getType();
unsigned index = 0;
for (auto inputType : subFuncType.getInputs()) {
if (auto memrefType = inputType.dyn_cast<MemRefType>())
if (auto layout = getLayoutMap(memrefType)) {
auto &partitions = partitionsMap[op.getOperand(index)];
if (auto memrefType = inputType.dyn_cast<MemRefType>()) {
auto &partitions = partitionsMap[op.getOperand(index)];
// If the current partitionsMap is empty, initialize it with no
// partition and factor of 1.
if (partitions.empty()) {
for (int64_t dim = 0; dim < memrefType.getRank(); ++dim)
partitions.push_back(PartitionInfo(PartitionKind::NONE, 1));
}
// If the current partitionsMap is empty, initialize it with no
// partition and factor of 1.
if (partitions.empty()) {
for (int64_t dim = 0; dim < memrefType.getRank(); ++dim)
partitions.push_back(PartitionInfo(PartitionKind::NONE, 1));
}
// Get the partition factor collected from sub-function.
SmallVector<int64_t, 8> factors;
getPartitionFactors(memrefType, &factors);
// Get the partition factor collected from sub-function.
SmallVector<int64_t, 8> factors;
getPartitionFactors(memrefType, &factors);
// Traverse all dimension of the memref.
for (int64_t dim = 0; dim < memrefType.getRank(); ++dim) {
auto factor = factors[dim];
// Traverse all dimension of the memref.
for (int64_t dim = 0; dim < memrefType.getRank(); ++dim) {
auto factor = factors[dim];
// If the factor from the sub-function is larger than the current
// factor, replace it.
if (factor > partitions[dim].second) {
if (layout.getResult(dim).getKind() == AffineExprKind::FloorDiv)
partitions[dim] = PartitionInfo(PartitionKind::BLOCK, factor);
else
partitions[dim] = PartitionInfo(PartitionKind::CYCLIC, factor);
}
// If the factor from the sub-function is larger than the current
// factor, replace it.
if (factor > partitions[dim].second) {
if (memrefType.getLayout()
.getAffineMap()
.getResult(dim)
.getKind() == AffineExprKind::FloorDiv)
partitions[dim] = PartitionInfo(PartitionKind::BLOCK, factor);
else
partitions[dim] = PartitionInfo(PartitionKind::CYCLIC, factor);
}
}
}
++index;
}
});

View File

@ -4,7 +4,8 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "scalehls/Dialect/HLSKernel/HLSKernel.h"
#include "scalehls/Transforms/Passes.h"
@ -43,7 +44,7 @@ static void getSuccessorsMap(Block &block, SuccessorsMap &map) {
for (auto &op : block.getOperations()) {
// TODO: Some operations are dataflow source, which will not be scheduled.
if (isa<memref::AllocOp, memref::AllocaOp, arith::ConstantOp,
memref::TensorLoadOp, memref::BufferCastOp>(op))
bufferization::ToTensorOp, bufferization::ToMemrefOp>(op))
continue;
// Collect all memref results if the current operation is a loop.

View File

@ -5,7 +5,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/Liveness.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "scalehls/Dialect/HLSKernel/HLSKernel.h"
#include "scalehls/Transforms/Passes.h"
@ -75,7 +76,7 @@ static bool applySplitFunction(FuncOp func, ArrayRef<Operation *> ops,
// function, except BufferCastOp.
if (auto defOp = input.getDefiningOp()) {
if (input.getType().isa<MemRefType>() &&
!isa<memref::BufferCastOp>(defOp)) {
!isa<bufferization::ToMemrefOp>(defOp)) {
bool isInternalMemory = true;
for (auto user : input.getUsers()) {
bool hasAncestor = false;

View File

@ -29,7 +29,7 @@ static AffineMap simplify(AffineMap map) {
/// operation if necessary.
template <typename AttrT>
static void
simplifyAndUpdateAttr(Operation *op, Identifier name, AttrT attr,
simplifyAndUpdateAttr(Operation *op, StringAttr name, AttrT attr,
DenseMap<Attribute, Attribute> &simplifiedAttrs) {
auto &simplified = simplifiedAttrs[attr];
if (simplified == attr)
@ -66,10 +66,10 @@ static void simplifyAffineStructures(Block &block) {
SmallVector<Operation *> opsToSimplify;
block.walk([&](Operation *op) {
for (auto attr : op->getAttrs()) {
if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>())
simplifyAndUpdateAttr(op, attr.first, mapAttr, simplifiedAttrs);
else if (auto setAttr = attr.second.dyn_cast<IntegerSetAttr>())
simplifyAndUpdateAttr(op, attr.first, setAttr, simplifiedAttrs);
if (auto mapAttr = attr.getValue().dyn_cast<AffineMapAttr>())
simplifyAndUpdateAttr(op, attr.getName(), mapAttr, simplifiedAttrs);
else if (auto setAttr = attr.getValue().dyn_cast<IntegerSetAttr>())
simplifyAndUpdateAttr(op, attr.getName(), setAttr, simplifiedAttrs);
}
if (isa<AffineForOp, AffineIfOp, AffineApplyOp>(op))

View File

@ -26,7 +26,7 @@ void ScaleHLSEstimator::getPartitionIndices(Operation *op) {
auto memrefType = access.memref.getType().cast<MemRefType>();
// If the layout map does not exist, it means the memory is not partitioned.
auto layoutMap = getLayoutMap(memrefType);
auto layoutMap = memrefType.getLayout().getAffineMap();
if (!layoutMap) {
auto partitionIndices = SmallVector<int64_t, 8>(memrefType.getRank(), 0);
op->setAttr("partition_indices", builder.getI64ArrayAttr(partitionIndices));
@ -526,7 +526,7 @@ bool ScaleHLSEstimator::visitOp(AffineIfOp op, int64_t begin) {
}
bool ScaleHLSEstimator::visitOp(CallOp op, int64_t begin) {
auto callee = SymbolTable::lookupNearestSymbolFrom(op, op.calleeAttr());
auto callee = SymbolTable::lookupNearestSymbolFrom(op, op.getCalleeAttr());
auto subFunc = dyn_cast<FuncOp>(callee);
assert(subFunc && "callable is not a function operation");

View File

@ -160,7 +160,7 @@ SmallString<8> ScaleHLSEmitterBase::getName(Value val) {
// than the value name.
if (auto defOp = val.getDefiningOp()) {
if (auto constOp = dyn_cast<arith::ConstantOp>(defOp)) {
auto constAttr = constOp.value();
auto constAttr = constOp.getValue();
if (auto floatAttr = constAttr.dyn_cast<FloatAttr>()) {
auto value = floatAttr.getValueAsDouble();
@ -215,11 +215,11 @@ public:
void emitStore(memref::StoreOp op);
/// Tensor-related statement emitters.
void emitTensorLoad(memref::TensorLoadOp op);
void emitTensorToMemref(bufferization::ToMemrefOp op);
void emitMemrefToTensor(bufferization::ToTensorOp op);
void emitTensorStore(memref::TensorStoreOp op);
void emitTensorToMemref(memref::BufferCastOp op);
void emitDim(memref::DimOp op);
void emitRank(RankOp op);
void emitRank(memref::RankOp op);
/// Standard expression emitters.
void emitBinary(Operation *op, const char *syntax);
@ -384,17 +384,17 @@ public:
bool visitOp(memref::DeallocOp op) { return true; }
/// Tensor-related statements.
bool visitOp(memref::TensorLoadOp op) {
return emitter.emitTensorLoad(op), true;
bool visitOp(bufferization::ToMemrefOp op) {
return emitter.emitTensorToMemref(op), true;
}
bool visitOp(bufferization::ToTensorOp op) {
return emitter.emitMemrefToTensor(op), true;
}
bool visitOp(memref::TensorStoreOp op) {
return emitter.emitTensorStore(op), true;
}
bool visitOp(memref::BufferCastOp op) {
return emitter.emitTensorToMemref(op), true;
}
bool visitOp(memref::DimOp op) { return emitter.emitDim(op), true; }
bool visitOp(RankOp op) { return emitter.emitRank(op), true; }
bool visitOp(memref::RankOp op) { return emitter.emitRank(op), true; }
/// HLSCpp operations.
bool visitOp(AssignOp op) { return emitter.emitAssign(op), true; }
@ -542,19 +542,19 @@ void ModuleEmitter::emitScfFor(scf::ForOp op) {
// Emit lower bound.
emitValue(iterVar);
os << " = ";
emitValue(op.lowerBound());
emitValue(op.getLowerBound());
os << "; ";
// Emit upper bound.
emitValue(iterVar);
os << " < ";
emitValue(op.upperBound());
emitValue(op.getUpperBound());
os << "; ";
// Emit increase step.
emitValue(iterVar);
os << " += ";
emitValue(op.step());
emitValue(op.getStep());
os << ") {";
emitInfoAndNewLine(op);
@ -584,19 +584,19 @@ void ModuleEmitter::emitScfIf(scf::IfOp op) {
indent();
os << "if (";
emitValue(op.condition());
emitValue(op.getCondition());
os << ") {";
emitInfoAndNewLine(op);
addIndent();
emitBlock(op.thenRegion().front());
emitBlock(op.getThenRegion().front());
reduceIndent();
if (!op.elseRegion().empty()) {
if (!op.getElseRegion().empty()) {
indent();
os << "} else {\n";
addIndent();
emitBlock(op.elseRegion().front());
emitBlock(op.getElseRegion().front());
reduceIndent();
}
@ -912,7 +912,7 @@ void ModuleEmitter::emitAffineYield(AffineYieldOp op) {
os << "} else {\n";
// Otherwise, generated values will be accumulated/reduced to the
// current results with corresponding AtomicRMWKind operations.
// current results with corresponding arith::AtomicRMWKind operations.
addIndent();
auto RMWAttrs =
getIntArrayAttrValue(parentOp, parentOp.getReductionsAttrName());
@ -921,39 +921,47 @@ void ModuleEmitter::emitAffineYield(AffineYieldOp op) {
unsigned rank = emitNestedLoopHead(result);
indent();
emitValue(result, rank);
switch ((AtomicRMWKind)RMWAttrs[resultIdx]) {
case (AtomicRMWKind::addf):
case (AtomicRMWKind::addi):
switch ((arith::AtomicRMWKind)RMWAttrs[resultIdx]) {
case (arith::AtomicRMWKind::addf):
case (arith::AtomicRMWKind::addi):
os << " += ";
emitValue(op.getOperand(resultIdx++), rank);
break;
case (AtomicRMWKind::assign):
case (arith::AtomicRMWKind::assign):
os << " = ";
emitValue(op.getOperand(resultIdx++), rank);
break;
case (AtomicRMWKind::maxf):
case (AtomicRMWKind::maxs):
case (AtomicRMWKind::maxu):
case (arith::AtomicRMWKind::maxf):
case (arith::AtomicRMWKind::maxs):
case (arith::AtomicRMWKind::maxu):
os << " = max(";
emitValue(result, rank);
os << ", ";
emitValue(op.getOperand(resultIdx++), rank);
os << ")";
break;
case (AtomicRMWKind::minf):
case (AtomicRMWKind::mins):
case (AtomicRMWKind::minu):
case (arith::AtomicRMWKind::minf):
case (arith::AtomicRMWKind::mins):
case (arith::AtomicRMWKind::minu):
os << " = min(";
emitValue(result, rank);
os << ", ";
emitValue(op.getOperand(resultIdx++), rank);
os << ")";
break;
case (AtomicRMWKind::mulf):
case (AtomicRMWKind::muli):
case (arith::AtomicRMWKind::mulf):
case (arith::AtomicRMWKind::muli):
os << " *= ";
emitValue(op.getOperand(resultIdx++), rank);
break;
case (arith::AtomicRMWKind::ori):
os << " |= ";
emitValue(op.getOperand(resultIdx++), rank);
break;
case (arith::AtomicRMWKind::andi):
os << " &= ";
emitValue(op.getOperand(resultIdx++), rank);
break;
}
os << ";";
emitInfoAndNewLine(op);
@ -1013,7 +1021,25 @@ void ModuleEmitter::emitStore(memref::StoreOp op) {
}
/// Tensor-related statement emitters.
void ModuleEmitter::emitTensorLoad(memref::TensorLoadOp op) {
void ModuleEmitter::emitTensorToMemref(bufferization::ToMemrefOp op) {
// A declared result indicates that the memref is output of the function, and
// has been declared in the function signature.
if (isDeclared(op.getResult())) {
auto rank = emitNestedLoopHead(op.getResult());
indent();
emitValue(op.getResult(), rank);
os << " = ";
emitValue(op.getOperand(), rank);
os << ";";
emitInfoAndNewLine(op);
emitNestedLoopTail(rank);
} else {
addAlias(op.getOperand(), op.getResult());
emitArrayDirectives(op.getResult());
}
}
void ModuleEmitter::emitMemrefToTensor(bufferization::ToTensorOp op) {
auto rank = emitNestedLoopHead(op.getResult());
indent();
emitValue(op.getResult(), rank);
@ -1035,28 +1061,10 @@ void ModuleEmitter::emitTensorStore(memref::TensorStoreOp op) {
emitNestedLoopTail(rank);
}
void ModuleEmitter::emitTensorToMemref(memref::BufferCastOp op) {
// A declared result indicates that the memref is output of the function, and
// has been declared in the function signature.
if (isDeclared(op.getResult())) {
auto rank = emitNestedLoopHead(op.getResult());
indent();
emitValue(op.getResult(), rank);
os << " = ";
emitValue(op.getOperand(), rank);
os << ";";
emitInfoAndNewLine(op);
emitNestedLoopTail(rank);
} else {
addAlias(op.getOperand(), op.getResult());
emitArrayDirectives(op.getResult());
}
}
void ModuleEmitter::emitDim(memref::DimOp op) {
if (auto constOp =
dyn_cast<arith::ConstantOp>(op.getOperand(1).getDefiningOp())) {
auto constVal = constOp.value().cast<IntegerAttr>().getInt();
auto constVal = constOp.getValue().cast<IntegerAttr>().getInt();
auto type = op.getOperand(0).getType().cast<ShapedType>();
if (type.hasStaticShape()) {
@ -1074,7 +1082,7 @@ void ModuleEmitter::emitDim(memref::DimOp op) {
emitError(op, "index is not a constant.");
}
void ModuleEmitter::emitRank(RankOp op) {
void ModuleEmitter::emitRank(memref::RankOp op) {
auto type = op.getOperand().getType().cast<ShapedType>();
if (type.hasRank()) {
indent();
@ -1138,7 +1146,7 @@ void ModuleEmitter::emitConstant(arith::ConstantOp op) {
if (isDeclared(op.getResult()))
return;
if (auto denseAttr = op.value().dyn_cast<DenseElementsAttr>()) {
if (auto denseAttr = op.getValue().dyn_cast<DenseElementsAttr>()) {
indent();
emitArrayDecl(op.getResult());
os << " = {";
@ -1365,29 +1373,28 @@ void ModuleEmitter::emitArrayDirectives(Value memref) {
bool emitPragmaFlag = false;
auto type = memref.getType().cast<MemRefType>();
if (auto layoutMap = getLayoutMap(type)) {
// Emit array_partition pragma(s).
SmallVector<int64_t, 8> factors;
getPartitionFactors(type, &factors);
// Emit array_partition pragma(s).
SmallVector<int64_t, 8> factors;
getPartitionFactors(type, &factors);
for (int64_t dim = 0; dim < type.getRank(); ++dim) {
if (factors[dim] != 1) {
emitPragmaFlag = true;
for (int64_t dim = 0; dim < type.getRank(); ++dim) {
if (factors[dim] != 1) {
emitPragmaFlag = true;
indent();
os << "#pragma HLS array_partition";
os << " variable=";
emitValue(memref);
indent();
os << "#pragma HLS array_partition";
os << " variable=";
emitValue(memref);
// Emit partition type.
if (layoutMap.getResult(dim).getKind() == AffineExprKind::FloorDiv)
os << " block";
else
os << " cyclic";
// Emit partition type.
if (type.getLayout().getAffineMap().getResult(dim).getKind() ==
AffineExprKind::FloorDiv)
os << " block";
else
os << " cyclic";
os << " factor=" << factors[dim];
os << " dim=" << dim + 1 << "\n";
}
os << " factor=" << factors[dim];
os << " dim=" << dim + 1 << "\n";
}
}

@ -1 +1 @@
Subproject commit ca7fa5cf0fbb5e05146de634869b0098e2919359
Subproject commit 3f0e66045f01305e2b09de186552a3c42cdfd8e5