[ArrayPartition] use memref layout map to represent partition type rather than ArrayOp attributes
This commit is contained in:
parent
34f32ab2f2
commit
ff6e7f0c4c
|
@ -104,6 +104,9 @@ hlscpp::ArrayOp getArrayOp(Operation *op);
|
|||
Optional<std::pair<int64_t, int64_t>>
|
||||
getBoundOfAffineBound(AffineBound bound, MLIRContext *context);
|
||||
|
||||
void getPartitionFactors(ArrayRef<int64_t> shape, AffineMap layoutMap,
|
||||
SmallVector<int64_t, 4> &factors);
|
||||
|
||||
} // namespace scalehls
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -195,9 +195,11 @@ void HLSCppEstimator::getFuncDependencies() {
|
|||
|
||||
/// Calculate the overall partition index.
|
||||
int64_t HLSCppEstimator::getPartitionIndex(Operation *op) {
|
||||
auto arrayOp = getArrayOp(op);
|
||||
auto access = MemRefAccess(op);
|
||||
auto memrefType = access.memref.getType().cast<MemRefType>();
|
||||
|
||||
AffineValueMap accessMap;
|
||||
MemRefAccess(op).getAccessMap(&accessMap);
|
||||
access.getAccessMap(&accessMap);
|
||||
|
||||
// Replace all dims in the memory access AffineMap with (step * dims). This
|
||||
// will ensure the "cyclic" array partition can be correctly detected.
|
||||
|
@ -211,11 +213,8 @@ int64_t HLSCppEstimator::getPartitionIndex(Operation *op) {
|
|||
if (isForInductionVar(operand))
|
||||
step = getForInductionVarOwner(operand).getStep();
|
||||
|
||||
if (step == 1)
|
||||
dimReplacements.push_back(builder.getAffineDimExpr(operandIdx));
|
||||
else
|
||||
dimReplacements.push_back(builder.getAffineConstantExpr(step) *
|
||||
builder.getAffineDimExpr(operandIdx));
|
||||
dimReplacements.push_back(builder.getAffineConstantExpr(step) *
|
||||
builder.getAffineDimExpr(operandIdx));
|
||||
} else {
|
||||
symReplacements.push_back(
|
||||
builder.getAffineSymbolExpr(operandIdx - accessMap.getNumDims()));
|
||||
|
@ -227,42 +226,37 @@ int64_t HLSCppEstimator::getPartitionIndex(Operation *op) {
|
|||
dimReplacements, symReplacements, accessMap.getNumDims(),
|
||||
accessMap.getNumSymbols());
|
||||
|
||||
// Check whether the memref is partitioned.
|
||||
auto memrefMaps = memrefType.getAffineMaps();
|
||||
if (memrefMaps.empty())
|
||||
return 0;
|
||||
|
||||
// Compose the access map with the layout map.
|
||||
auto layoutMap = memrefMaps.back();
|
||||
auto composeMap = layoutMap.compose(newMap);
|
||||
|
||||
// Collect partition factors.
|
||||
SmallVector<int64_t, 4> factors;
|
||||
getPartitionFactors(memrefType.getShape(), layoutMap, factors);
|
||||
|
||||
// Calculate the partition index of this load/store operation honoring the
|
||||
// partition strategy applied.
|
||||
int64_t partitionIdx = 0;
|
||||
int64_t accumFactor = 1;
|
||||
|
||||
unsigned dim = 0;
|
||||
for (auto expr : newMap.getResults()) {
|
||||
auto idxExpr = builder.getAffineConstantExpr(0);
|
||||
int64_t factor = 1;
|
||||
for (auto dim = 0; dim < memrefType.getRank(); ++dim) {
|
||||
auto idxExpr = composeMap.getResult(dim);
|
||||
|
||||
if (arrayOp.partition()) {
|
||||
auto type = getPartitionType(arrayOp, dim);
|
||||
factor = getPartitionFactor(arrayOp, dim);
|
||||
|
||||
if (type == "cyclic")
|
||||
idxExpr = expr % builder.getAffineConstantExpr(factor);
|
||||
else if (type == "block") {
|
||||
auto size = arrayOp.getShapedType().getShape()[dim];
|
||||
idxExpr = expr.floorDiv(
|
||||
builder.getAffineConstantExpr((size + factor - 1) / factor));
|
||||
}
|
||||
}
|
||||
|
||||
if (auto constExpr = idxExpr.dyn_cast<AffineConstantExpr>()) {
|
||||
if (dim == 0)
|
||||
partitionIdx = constExpr.getValue();
|
||||
else
|
||||
partitionIdx += constExpr.getValue() * accumFactor;
|
||||
} else {
|
||||
if (auto constExpr = idxExpr.dyn_cast<AffineConstantExpr>())
|
||||
partitionIdx += constExpr.getValue() * accumFactor;
|
||||
else {
|
||||
partitionIdx = -1;
|
||||
break;
|
||||
}
|
||||
|
||||
accumFactor *= factor;
|
||||
dim++;
|
||||
accumFactor *= factors[dim];
|
||||
}
|
||||
|
||||
return partitionIdx;
|
||||
}
|
||||
|
||||
|
|
|
@ -176,3 +176,25 @@ hlscpp::ArrayOp scalehls::getArrayOp(Value memref) {
|
|||
hlscpp::ArrayOp scalehls::getArrayOp(Operation *op) {
|
||||
return getArrayOp(MemRefAccess(op).memref);
|
||||
}
|
||||
|
||||
void scalehls::getPartitionFactors(ArrayRef<int64_t> shape, AffineMap layoutMap,
|
||||
SmallVector<int64_t, 4> &factors) {
|
||||
for (unsigned dim = 0, e = shape.size(); dim < e; ++dim) {
|
||||
auto expr = layoutMap.getResult(dim);
|
||||
|
||||
if (auto binaryExpr = expr.dyn_cast<AffineBinaryOpExpr>()) {
|
||||
if (auto factor = binaryExpr.getRHS().dyn_cast<AffineConstantExpr>()) {
|
||||
if (expr.getKind() == AffineExprKind::Mod)
|
||||
factors.push_back(factor.getValue());
|
||||
else if (expr.getKind() == AffineExprKind::FloorDiv) {
|
||||
auto blockFactor =
|
||||
(shape[dim] + factor.getValue() - 1) / factor.getValue();
|
||||
factors.push_back(blockFactor);
|
||||
}
|
||||
}
|
||||
} else if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
|
||||
if (constExpr.getValue() == 0)
|
||||
factors.push_back(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -38,9 +38,16 @@ static void applyArrayPartition(MemAccessesMap &map, OpBuilder &builder) {
|
|||
auto arrayShape = arrayOp.getShapedType().getShape();
|
||||
auto arrayAccesses = pair.second;
|
||||
|
||||
auto memref = pair.first;
|
||||
auto memrefType = memref.getType().cast<MemRefType>();
|
||||
|
||||
// Walk through each dimension of the targeted array.
|
||||
SmallVector<Attribute, 4> partitionFactor;
|
||||
SmallVector<int64_t, 4> partitionFactor;
|
||||
SmallVector<StringRef, 4> partitionType;
|
||||
|
||||
SmallVector<AffineExpr, 4> partitionIndices;
|
||||
SmallVector<AffineExpr, 4> addressIndices;
|
||||
|
||||
unsigned partitionNum = 1;
|
||||
|
||||
for (size_t dim = 0, e = arrayShape.size(); dim < e; ++dim) {
|
||||
|
@ -84,6 +91,9 @@ static void applyArrayPartition(MemAccessesMap &map, OpBuilder &builder) {
|
|||
// should not be partitioned.
|
||||
partitionType.push_back("none");
|
||||
|
||||
partitionIndices.push_back(builder.getAffineConstantExpr(0));
|
||||
addressIndices.push_back(builder.getAffineDimExpr(dim));
|
||||
|
||||
} else if (accessNum >= maxDistance) {
|
||||
// This means some elements are accessed more than once or exactly
|
||||
// once, and successive elements are accessed. In most cases,
|
||||
|
@ -91,20 +101,43 @@ static void applyArrayPartition(MemAccessesMap &map, OpBuilder &builder) {
|
|||
partitionType.push_back("cyclic");
|
||||
factor = maxDistance;
|
||||
|
||||
partitionIndices.push_back(builder.getAffineDimExpr(dim) % factor);
|
||||
addressIndices.push_back(
|
||||
builder.getAffineDimExpr(dim).floorDiv(factor));
|
||||
|
||||
} else {
|
||||
// This means discrete elements are accessed. Typically, "block"
|
||||
// partition will be most benefit for this occasion.
|
||||
partitionType.push_back("block");
|
||||
factor = accessNum;
|
||||
|
||||
auto blockFactor = (memrefType.getShape()[dim] + factor - 1) / factor;
|
||||
partitionIndices.push_back(
|
||||
builder.getAffineDimExpr(dim).floorDiv(blockFactor));
|
||||
addressIndices.push_back(builder.getAffineDimExpr(dim) % blockFactor);
|
||||
}
|
||||
|
||||
partitionFactor.push_back(builder.getI64IntegerAttr(factor));
|
||||
partitionFactor.push_back(factor);
|
||||
partitionNum *= factor;
|
||||
}
|
||||
|
||||
// Construct new layout map.
|
||||
partitionIndices.append(addressIndices.begin(), addressIndices.end());
|
||||
auto layoutMap = AffineMap::get(memrefType.getRank(), 0, partitionIndices,
|
||||
builder.getContext());
|
||||
|
||||
// Construct new memref type.
|
||||
auto newType = MemRefType::get(memrefType.getShape(),
|
||||
memrefType.getElementType(), layoutMap);
|
||||
|
||||
// Set new type.
|
||||
memref.setType(newType);
|
||||
// TODO: set function type.
|
||||
|
||||
arrayOp.setAttr("partition", builder.getBoolAttr(true));
|
||||
arrayOp.setAttr("partition_type", builder.getStrArrayAttr(partitionType));
|
||||
arrayOp.setAttr("partition_factor", builder.getArrayAttr(partitionFactor));
|
||||
arrayOp.setAttr("partition_factor",
|
||||
builder.getI64ArrayAttr(partitionFactor));
|
||||
arrayOp.setAttr("partition_num", builder.getI64IntegerAttr(partitionNum));
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue