From ff6e7f0c4c7d3bf69640cd3899d10af35dc28435 Mon Sep 17 00:00:00 2001 From: Hanchen Ye Date: Sun, 3 Jan 2021 16:16:55 -0600 Subject: [PATCH] [ArrayPartition] use memref layout map to represent partition type rather than ArrayOp attributes --- include/Analysis/Utils.h | 3 ++ lib/Analysis/QoREstimation.cpp | 58 ++++++++++++++----------------- lib/Analysis/Utils.cpp | 22 ++++++++++++ lib/Transforms/ArrayPartition.cpp | 39 +++++++++++++++++++-- 4 files changed, 87 insertions(+), 35 deletions(-) diff --git a/include/Analysis/Utils.h b/include/Analysis/Utils.h index 642d977..37bd10e 100644 --- a/include/Analysis/Utils.h +++ b/include/Analysis/Utils.h @@ -104,6 +104,9 @@ hlscpp::ArrayOp getArrayOp(Operation *op); Optional> getBoundOfAffineBound(AffineBound bound, MLIRContext *context); +void getPartitionFactors(ArrayRef shape, AffineMap layoutMap, + SmallVector &factors); + } // namespace scalehls } // namespace mlir diff --git a/lib/Analysis/QoREstimation.cpp b/lib/Analysis/QoREstimation.cpp index 14d1cbe..df4661b 100644 --- a/lib/Analysis/QoREstimation.cpp +++ b/lib/Analysis/QoREstimation.cpp @@ -195,9 +195,11 @@ void HLSCppEstimator::getFuncDependencies() { /// Calculate the overall partition index. int64_t HLSCppEstimator::getPartitionIndex(Operation *op) { - auto arrayOp = getArrayOp(op); + auto access = MemRefAccess(op); + auto memrefType = access.memref.getType().cast(); + AffineValueMap accessMap; - MemRefAccess(op).getAccessMap(&accessMap); + access.getAccessMap(&accessMap); // Replace all dims in the memory access AffineMap with (step * dims). This // will ensure the "cyclic" array partition can be correctly detected. @@ -211,11 +213,8 @@ int64_t HLSCppEstimator::getPartitionIndex(Operation *op) { if (isForInductionVar(operand)) step = getForInductionVarOwner(operand).getStep(); - if (step == 1) - dimReplacements.push_back(builder.getAffineDimExpr(operandIdx)); - else - dimReplacements.push_back(builder.getAffineConstantExpr(step) * - builder.getAffineDimExpr(operandIdx)); + dimReplacements.push_back(builder.getAffineConstantExpr(step) * + builder.getAffineDimExpr(operandIdx)); } else { symReplacements.push_back( builder.getAffineSymbolExpr(operandIdx - accessMap.getNumDims())); @@ -227,42 +226,37 @@ int64_t HLSCppEstimator::getPartitionIndex(Operation *op) { dimReplacements, symReplacements, accessMap.getNumDims(), accessMap.getNumSymbols()); + // Check whether the memref is partitioned. + auto memrefMaps = memrefType.getAffineMaps(); + if (memrefMaps.empty()) + return 0; + + // Compose the access map with the layout map. + auto layoutMap = memrefMaps.back(); + auto composeMap = layoutMap.compose(newMap); + + // Collect partition factors. + SmallVector factors; + getPartitionFactors(memrefType.getShape(), layoutMap, factors); + // Calculate the partition index of this load/store operation honoring the // partition strategy applied. int64_t partitionIdx = 0; int64_t accumFactor = 1; - unsigned dim = 0; - for (auto expr : newMap.getResults()) { - auto idxExpr = builder.getAffineConstantExpr(0); - int64_t factor = 1; + for (auto dim = 0; dim < memrefType.getRank(); ++dim) { + auto idxExpr = composeMap.getResult(dim); - if (arrayOp.partition()) { - auto type = getPartitionType(arrayOp, dim); - factor = getPartitionFactor(arrayOp, dim); - - if (type == "cyclic") - idxExpr = expr % builder.getAffineConstantExpr(factor); - else if (type == "block") { - auto size = arrayOp.getShapedType().getShape()[dim]; - idxExpr = expr.floorDiv( - builder.getAffineConstantExpr((size + factor - 1) / factor)); - } - } - - if (auto constExpr = idxExpr.dyn_cast()) { - if (dim == 0) - partitionIdx = constExpr.getValue(); - else - partitionIdx += constExpr.getValue() * accumFactor; - } else { + if (auto constExpr = idxExpr.dyn_cast()) + partitionIdx += constExpr.getValue() * accumFactor; + else { partitionIdx = -1; break; } - accumFactor *= factor; - dim++; + accumFactor *= factors[dim]; } + return partitionIdx; } diff --git a/lib/Analysis/Utils.cpp b/lib/Analysis/Utils.cpp index db3ff68..f574938 100644 --- a/lib/Analysis/Utils.cpp +++ b/lib/Analysis/Utils.cpp @@ -176,3 +176,25 @@ hlscpp::ArrayOp scalehls::getArrayOp(Value memref) { hlscpp::ArrayOp scalehls::getArrayOp(Operation *op) { return getArrayOp(MemRefAccess(op).memref); } + +void scalehls::getPartitionFactors(ArrayRef shape, AffineMap layoutMap, + SmallVector &factors) { + for (unsigned dim = 0, e = shape.size(); dim < e; ++dim) { + auto expr = layoutMap.getResult(dim); + + if (auto binaryExpr = expr.dyn_cast()) { + if (auto factor = binaryExpr.getRHS().dyn_cast()) { + if (expr.getKind() == AffineExprKind::Mod) + factors.push_back(factor.getValue()); + else if (expr.getKind() == AffineExprKind::FloorDiv) { + auto blockFactor = + (shape[dim] + factor.getValue() - 1) / factor.getValue(); + factors.push_back(blockFactor); + } + } + } else if (auto constExpr = expr.dyn_cast()) { + if (constExpr.getValue() == 0) + factors.push_back(1); + } + } +} diff --git a/lib/Transforms/ArrayPartition.cpp b/lib/Transforms/ArrayPartition.cpp index 7d84f51..2e218a7 100644 --- a/lib/Transforms/ArrayPartition.cpp +++ b/lib/Transforms/ArrayPartition.cpp @@ -38,9 +38,16 @@ static void applyArrayPartition(MemAccessesMap &map, OpBuilder &builder) { auto arrayShape = arrayOp.getShapedType().getShape(); auto arrayAccesses = pair.second; + auto memref = pair.first; + auto memrefType = memref.getType().cast(); + // Walk through each dimension of the targeted array. - SmallVector partitionFactor; + SmallVector partitionFactor; SmallVector partitionType; + + SmallVector partitionIndices; + SmallVector addressIndices; + unsigned partitionNum = 1; for (size_t dim = 0, e = arrayShape.size(); dim < e; ++dim) { @@ -84,6 +91,9 @@ static void applyArrayPartition(MemAccessesMap &map, OpBuilder &builder) { // should not be partitioned. partitionType.push_back("none"); + partitionIndices.push_back(builder.getAffineConstantExpr(0)); + addressIndices.push_back(builder.getAffineDimExpr(dim)); + } else if (accessNum >= maxDistance) { // This means some elements are accessed more than once or exactly // once, and successive elements are accessed. In most cases, @@ -91,20 +101,43 @@ static void applyArrayPartition(MemAccessesMap &map, OpBuilder &builder) { partitionType.push_back("cyclic"); factor = maxDistance; + partitionIndices.push_back(builder.getAffineDimExpr(dim) % factor); + addressIndices.push_back( + builder.getAffineDimExpr(dim).floorDiv(factor)); + } else { // This means discrete elements are accessed. Typically, "block" // partition will be most benefit for this occasion. partitionType.push_back("block"); factor = accessNum; + + auto blockFactor = (memrefType.getShape()[dim] + factor - 1) / factor; + partitionIndices.push_back( + builder.getAffineDimExpr(dim).floorDiv(blockFactor)); + addressIndices.push_back(builder.getAffineDimExpr(dim) % blockFactor); } - partitionFactor.push_back(builder.getI64IntegerAttr(factor)); + partitionFactor.push_back(factor); partitionNum *= factor; } + // Construct new layout map. + partitionIndices.append(addressIndices.begin(), addressIndices.end()); + auto layoutMap = AffineMap::get(memrefType.getRank(), 0, partitionIndices, + builder.getContext()); + + // Construct new memref type. + auto newType = MemRefType::get(memrefType.getShape(), + memrefType.getElementType(), layoutMap); + + // Set new type. + memref.setType(newType); + // TODO: set function type. + arrayOp.setAttr("partition", builder.getBoolAttr(true)); arrayOp.setAttr("partition_type", builder.getStrArrayAttr(partitionType)); - arrayOp.setAttr("partition_factor", builder.getArrayAttr(partitionFactor)); + arrayOp.setAttr("partition_factor", + builder.getI64ArrayAttr(partitionFactor)); arrayOp.setAttr("partition_num", builder.getI64IntegerAttr(partitionNum)); } }