[ArrayPartition] use memref layout map to represent partition type rather than ArrayOp attributes

This commit is contained in:
Hanchen Ye 2021-01-03 16:16:55 -06:00
parent 34f32ab2f2
commit ff6e7f0c4c
4 changed files with 87 additions and 35 deletions

View File

@ -104,6 +104,9 @@ hlscpp::ArrayOp getArrayOp(Operation *op);
Optional<std::pair<int64_t, int64_t>>
getBoundOfAffineBound(AffineBound bound, MLIRContext *context);
/// Derive per-dimension partition factors from a memref layout map and append
/// them to `factors`. For each dimension `dim` of `shape`, result `dim` of
/// `layoutMap` is inspected: a `... mod c` result (c constant) yields factor c,
/// a `... floordiv c` result (c constant) yields ceil(shape[dim] / c), and a
/// constant-zero result yields factor 1.
void getPartitionFactors(ArrayRef<int64_t> shape, AffineMap layoutMap,
SmallVector<int64_t, 4> &factors);
} // namespace scalehls
} // namespace mlir

View File

@ -195,9 +195,11 @@ void HLSCppEstimator::getFuncDependencies() {
/// Calculate the overall partition index.
int64_t HLSCppEstimator::getPartitionIndex(Operation *op) {
auto arrayOp = getArrayOp(op);
auto access = MemRefAccess(op);
auto memrefType = access.memref.getType().cast<MemRefType>();
AffineValueMap accessMap;
MemRefAccess(op).getAccessMap(&accessMap);
access.getAccessMap(&accessMap);
// Replace all dims in the memory access AffineMap with (step * dims). This
// will ensure the "cyclic" array partition can be correctly detected.
@ -211,11 +213,8 @@ int64_t HLSCppEstimator::getPartitionIndex(Operation *op) {
if (isForInductionVar(operand))
step = getForInductionVarOwner(operand).getStep();
if (step == 1)
dimReplacements.push_back(builder.getAffineDimExpr(operandIdx));
else
dimReplacements.push_back(builder.getAffineConstantExpr(step) *
builder.getAffineDimExpr(operandIdx));
dimReplacements.push_back(builder.getAffineConstantExpr(step) *
builder.getAffineDimExpr(operandIdx));
} else {
symReplacements.push_back(
builder.getAffineSymbolExpr(operandIdx - accessMap.getNumDims()));
@ -227,42 +226,37 @@ int64_t HLSCppEstimator::getPartitionIndex(Operation *op) {
dimReplacements, symReplacements, accessMap.getNumDims(),
accessMap.getNumSymbols());
// Check whether the memref is partitioned.
auto memrefMaps = memrefType.getAffineMaps();
if (memrefMaps.empty())
return 0;
// Compose the access map with the layout map.
auto layoutMap = memrefMaps.back();
auto composeMap = layoutMap.compose(newMap);
// Collect partition factors.
SmallVector<int64_t, 4> factors;
getPartitionFactors(memrefType.getShape(), layoutMap, factors);
// Calculate the partition index of this load/store operation honoring the
// partition strategy applied.
int64_t partitionIdx = 0;
int64_t accumFactor = 1;
unsigned dim = 0;
for (auto expr : newMap.getResults()) {
auto idxExpr = builder.getAffineConstantExpr(0);
int64_t factor = 1;
for (auto dim = 0; dim < memrefType.getRank(); ++dim) {
auto idxExpr = composeMap.getResult(dim);
if (arrayOp.partition()) {
auto type = getPartitionType(arrayOp, dim);
factor = getPartitionFactor(arrayOp, dim);
if (type == "cyclic")
idxExpr = expr % builder.getAffineConstantExpr(factor);
else if (type == "block") {
auto size = arrayOp.getShapedType().getShape()[dim];
idxExpr = expr.floorDiv(
builder.getAffineConstantExpr((size + factor - 1) / factor));
}
}
if (auto constExpr = idxExpr.dyn_cast<AffineConstantExpr>()) {
if (dim == 0)
partitionIdx = constExpr.getValue();
else
partitionIdx += constExpr.getValue() * accumFactor;
} else {
if (auto constExpr = idxExpr.dyn_cast<AffineConstantExpr>())
partitionIdx += constExpr.getValue() * accumFactor;
else {
partitionIdx = -1;
break;
}
accumFactor *= factor;
dim++;
accumFactor *= factors[dim];
}
return partitionIdx;
}

View File

@ -176,3 +176,25 @@ hlscpp::ArrayOp scalehls::getArrayOp(Value memref) {
/// Convenience overload: build a MemRefAccess from `op` and forward the memref
/// it refers to, to the Value-based getArrayOp overload.
hlscpp::ArrayOp scalehls::getArrayOp(Operation *op) {
  MemRefAccess memAccess(op);
  return getArrayOp(memAccess.memref);
}
/// Derive the partition factor of every dimension of a memref from its layout
/// map, appending exactly one entry per dimension of `shape` to `factors`.
///
/// Result `dim` of `layoutMap` is treated as the partition-index expression of
/// dimension `dim`:
///   - `... mod c` (c a positive constant): the factor is c.
///   - `... floordiv c` (c a positive constant): c is taken as the block size
///     and the factor is ceil(shape[dim] / c).
///   - anything else (constant results, unrecognized expressions, or a layout
///     map with fewer results than `shape` has dimensions): the dimension is
///     treated as unpartitioned and the factor is 1.
///
/// Always emitting one factor per dimension keeps `factors[dim]` aligned with
/// the dimension index, so callers can safely index it by dimension.
void scalehls::getPartitionFactors(ArrayRef<int64_t> shape, AffineMap layoutMap,
                                   SmallVector<int64_t, 4> &factors) {
  const unsigned numResults = layoutMap.getNumResults();

  for (unsigned dim = 0, e = shape.size(); dim < e; ++dim) {
    // Default to "not partitioned" so that every dimension gets an entry.
    int64_t dimFactor = 1;

    if (dim < numResults) {
      auto expr = layoutMap.getResult(dim);
      if (auto binaryExpr = expr.dyn_cast<AffineBinaryOpExpr>()) {
        if (auto rhs = binaryExpr.getRHS().dyn_cast<AffineConstantExpr>()) {
          const int64_t value = rhs.getValue();
          // A non-positive divisor is meaningless here and would divide by
          // zero below; fall back to the default factor in that case.
          if (value > 0) {
            if (expr.getKind() == AffineExprKind::Mod)
              dimFactor = value;
            else if (expr.getKind() == AffineExprKind::FloorDiv)
              dimFactor = (shape[dim] + value - 1) / value;
          }
        }
      }
    }

    factors.push_back(dimFactor);
  }
}

View File

@ -38,9 +38,16 @@ static void applyArrayPartition(MemAccessesMap &map, OpBuilder &builder) {
auto arrayShape = arrayOp.getShapedType().getShape();
auto arrayAccesses = pair.second;
auto memref = pair.first;
auto memrefType = memref.getType().cast<MemRefType>();
// Walk through each dimension of the targeted array.
SmallVector<Attribute, 4> partitionFactor;
SmallVector<int64_t, 4> partitionFactor;
SmallVector<StringRef, 4> partitionType;
SmallVector<AffineExpr, 4> partitionIndices;
SmallVector<AffineExpr, 4> addressIndices;
unsigned partitionNum = 1;
for (size_t dim = 0, e = arrayShape.size(); dim < e; ++dim) {
@ -84,6 +91,9 @@ static void applyArrayPartition(MemAccessesMap &map, OpBuilder &builder) {
// should not be partitioned.
partitionType.push_back("none");
partitionIndices.push_back(builder.getAffineConstantExpr(0));
addressIndices.push_back(builder.getAffineDimExpr(dim));
} else if (accessNum >= maxDistance) {
// This means some elements are accessed more than once or exactly
// once, and successive elements are accessed. In most cases,
@ -91,20 +101,43 @@ static void applyArrayPartition(MemAccessesMap &map, OpBuilder &builder) {
partitionType.push_back("cyclic");
factor = maxDistance;
partitionIndices.push_back(builder.getAffineDimExpr(dim) % factor);
addressIndices.push_back(
builder.getAffineDimExpr(dim).floorDiv(factor));
} else {
// This means discrete elements are accessed. Typically, "block"
// partition is the most beneficial in this case.
partitionType.push_back("block");
factor = accessNum;
auto blockFactor = (memrefType.getShape()[dim] + factor - 1) / factor;
partitionIndices.push_back(
builder.getAffineDimExpr(dim).floorDiv(blockFactor));
addressIndices.push_back(builder.getAffineDimExpr(dim) % blockFactor);
}
partitionFactor.push_back(builder.getI64IntegerAttr(factor));
partitionFactor.push_back(factor);
partitionNum *= factor;
}
// Construct new layout map.
partitionIndices.append(addressIndices.begin(), addressIndices.end());
auto layoutMap = AffineMap::get(memrefType.getRank(), 0, partitionIndices,
builder.getContext());
// Construct new memref type.
auto newType = MemRefType::get(memrefType.getShape(),
memrefType.getElementType(), layoutMap);
// Set new type.
memref.setType(newType);
// TODO: set function type.
arrayOp.setAttr("partition", builder.getBoolAttr(true));
arrayOp.setAttr("partition_type", builder.getStrArrayAttr(partitionType));
arrayOp.setAttr("partition_factor", builder.getArrayAttr(partitionFactor));
arrayOp.setAttr("partition_factor",
builder.getI64ArrayAttr(partitionFactor));
arrayOp.setAttr("partition_num", builder.getI64IntegerAttr(partitionNum));
}
}