[SplitFunction] support internal values and memories in sub-function

This commit is contained in:
Hanchen Ye 2020-12-26 13:38:24 -06:00
parent bc5818254b
commit abfccd8052
1 changed files with 69 additions and 37 deletions

View File

@ -38,31 +38,83 @@ void SplitFunction::runOnOperation() {
auto name = "dataflow" + std::to_string(pair.first);
auto ops = pair.second;
// Collect input and output information.
SmallVector<Type, 8> inputTypes;
SmallVector<Value, 8> inputValues;
// Collect output types and values.
SmallVector<Type, 8> outputTypes;
SmallVector<Value, 8> outputValues;
SmallVector<Value, 8> internalValues;
for (auto op : ops) {
for (auto result : op->getResults()) {
// Only add values that are used.
if (result.getUses().empty())
continue;
// If the result is only used by operations in the same level, it is
// an internal value and will not be returned.
bool isInternalResult = true;
for (auto user : result.getUsers())
if (std::find(ops.begin(), ops.end(), user) == ops.end()) {
isInternalResult = false;
break;
}
if (isInternalResult) {
internalValues.push_back(result);
continue;
}
outputTypes.push_back(result.getType());
outputValues.push_back(result);
}
}
// Collect input types and values.
SmallVector<Type, 8> inputTypes;
SmallVector<Value, 8> inputValues;
SmallVector<Operation *, 8> internalMemories;
unsigned opIndex = 0;
for (auto op : ops) {
// Push back all operands and live ins as candidates.
SmallVector<Value, 8> candidateInputs(op->getOperands());
SmallVector<Value, 8> inputCandidates(op->getOperands());
if (auto loop = dyn_cast<mlir::AffineForOp>(op)) {
auto liveIns = liveness.getLiveIn(&loop.getLoopBody().front());
for (auto liveIn : liveIns)
if (!isForInductionVar(liveIn))
candidateInputs.push_back(liveIn);
inputCandidates.push_back(liveIn);
}
// Collect input types and values.
for (auto input : candidateInputs) {
// If the current input candidate is defined by an operation in the
// same level, it does not need to be passed in as argument.
if (auto defOp = input.getDefiningOp())
if (std::find(ops.begin(), ops.end(), defOp) != ops.end())
continue;
for (auto input : inputCandidates) {
// If the current input candidate is internal value, it does not need
// to be passed in as argument.
if (std::find(internalValues.begin(), internalValues.end(), input) !=
internalValues.end())
continue;
// Internal memory defining operation should be moved into the sub
// function, except TensorToMemrefOp.
if (auto defOp = input.getDefiningOp()) {
if (input.getType().isa<MemRefType>() &&
!isa<TensorToMemrefOp>(defOp)) {
bool isInternalMemory = true;
for (auto user : input.getUsers()) {
bool hasAncestor = false;
for (auto op : ops)
if (op->isAncestor(user))
hasAncestor = true;
if (!hasAncestor) {
isInternalMemory = false;
break;
}
}
if (isInternalMemory) {
internalMemories.push_back(defOp);
continue;
}
}
}
// Only unique inputs will be added.
if (std::find(inputValues.begin(), inputValues.end(), input) !=
@ -72,28 +124,6 @@ void SplitFunction::runOnOperation() {
inputTypes.push_back(input.getType());
inputValues.push_back(input);
}
// Collect output types and values.
for (auto result : op->getResults()) {
// Only add values that are used.
if (result.getUses().empty())
continue;
// If the result is only used by operations in the same level, it does
// not need to be returned.
bool isInternalResult = true;
for (auto user : result.getUsers())
if (std::find(ops.begin(), ops.end(), user) == ops.end())
isInternalResult = false;
if (isInternalResult)
continue;
outputTypes.push_back(result.getType());
outputValues.push_back(result);
}
opIndex++;
}
// Create a new function for the current dataflow level.
@ -115,7 +145,6 @@ void SplitFunction::runOnOperation() {
builder.create<mlir::ReturnOp>(func.getLoc(), outputValues);
// Move same level operations into the new created function.
opIndex = 0;
for (auto op : ops) {
op->moveBefore(returnOp);
op->removeAttr("dataflow_level");
@ -125,8 +154,11 @@ void SplitFunction::runOnOperation() {
entry->getArgument(i), [&](mlir::OpOperand &use) {
return func.getOperation()->isAncestor(use.getOwner());
});
opIndex++;
}
// Move internal memory defining operation into the new created function.
for (auto memoryDefOp : internalMemories)
memoryDefOp->moveBefore(&func.front().front());
}
}
}