[SplitFunction] support internal values and memories in sub-function
This commit is contained in:
parent
bc5818254b
commit
abfccd8052
|
@ -38,31 +38,83 @@ void SplitFunction::runOnOperation() {
|
|||
auto name = "dataflow" + std::to_string(pair.first);
|
||||
auto ops = pair.second;
|
||||
|
||||
// Collect input and output information.
|
||||
SmallVector<Type, 8> inputTypes;
|
||||
SmallVector<Value, 8> inputValues;
|
||||
|
||||
// Collect output types and values.
|
||||
SmallVector<Type, 8> outputTypes;
|
||||
SmallVector<Value, 8> outputValues;
|
||||
SmallVector<Value, 8> internalValues;
|
||||
|
||||
for (auto op : ops) {
|
||||
for (auto result : op->getResults()) {
|
||||
// Only add values that are used.
|
||||
if (result.getUses().empty())
|
||||
continue;
|
||||
|
||||
// If the result is only used by operations in the same level, it is
|
||||
// an internal value and will not be returned.
|
||||
bool isInternalResult = true;
|
||||
for (auto user : result.getUsers())
|
||||
if (std::find(ops.begin(), ops.end(), user) == ops.end()) {
|
||||
isInternalResult = false;
|
||||
break;
|
||||
}
|
||||
|
||||
if (isInternalResult) {
|
||||
internalValues.push_back(result);
|
||||
continue;
|
||||
}
|
||||
|
||||
outputTypes.push_back(result.getType());
|
||||
outputValues.push_back(result);
|
||||
}
|
||||
}
|
||||
|
||||
// Collect input types and values.
|
||||
SmallVector<Type, 8> inputTypes;
|
||||
SmallVector<Value, 8> inputValues;
|
||||
SmallVector<Operation *, 8> internalMemories;
|
||||
|
||||
unsigned opIndex = 0;
|
||||
for (auto op : ops) {
|
||||
// Push back all operands and live ins as candidates.
|
||||
SmallVector<Value, 8> candidateInputs(op->getOperands());
|
||||
SmallVector<Value, 8> inputCandidates(op->getOperands());
|
||||
if (auto loop = dyn_cast<mlir::AffineForOp>(op)) {
|
||||
auto liveIns = liveness.getLiveIn(&loop.getLoopBody().front());
|
||||
for (auto liveIn : liveIns)
|
||||
if (!isForInductionVar(liveIn))
|
||||
candidateInputs.push_back(liveIn);
|
||||
inputCandidates.push_back(liveIn);
|
||||
}
|
||||
|
||||
// Collect input types and values.
|
||||
for (auto input : candidateInputs) {
|
||||
// If the current input candidate is defined by an operation in the
|
||||
// same level, it does not need to be passed in as argument.
|
||||
if (auto defOp = input.getDefiningOp())
|
||||
if (std::find(ops.begin(), ops.end(), defOp) != ops.end())
|
||||
continue;
|
||||
for (auto input : inputCandidates) {
|
||||
// If the current input candidate is internal value, it does not need
|
||||
// to be passed in as argument.
|
||||
if (std::find(internalValues.begin(), internalValues.end(), input) !=
|
||||
internalValues.end())
|
||||
continue;
|
||||
|
||||
// Internal memory defining operation should be moved into the sub
|
||||
// function, except TensorToMemrefOp.
|
||||
if (auto defOp = input.getDefiningOp()) {
|
||||
if (input.getType().isa<MemRefType>() &&
|
||||
!isa<TensorToMemrefOp>(defOp)) {
|
||||
bool isInternalMemory = true;
|
||||
for (auto user : input.getUsers()) {
|
||||
bool hasAncestor = false;
|
||||
for (auto op : ops)
|
||||
if (op->isAncestor(user))
|
||||
hasAncestor = true;
|
||||
|
||||
if (!hasAncestor) {
|
||||
isInternalMemory = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (isInternalMemory) {
|
||||
internalMemories.push_back(defOp);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Only unique inputs will be added.
|
||||
if (std::find(inputValues.begin(), inputValues.end(), input) !=
|
||||
|
@ -72,28 +124,6 @@ void SplitFunction::runOnOperation() {
|
|||
inputTypes.push_back(input.getType());
|
||||
inputValues.push_back(input);
|
||||
}
|
||||
|
||||
// Collect output types and values.
|
||||
for (auto result : op->getResults()) {
|
||||
// Only add values that are used.
|
||||
if (result.getUses().empty())
|
||||
continue;
|
||||
|
||||
// If the result is only used by operations in the same level, it does
|
||||
// not need to be returned.
|
||||
bool isInternalResult = true;
|
||||
for (auto user : result.getUsers())
|
||||
if (std::find(ops.begin(), ops.end(), user) == ops.end())
|
||||
isInternalResult = false;
|
||||
|
||||
if (isInternalResult)
|
||||
continue;
|
||||
|
||||
outputTypes.push_back(result.getType());
|
||||
outputValues.push_back(result);
|
||||
}
|
||||
|
||||
opIndex++;
|
||||
}
|
||||
|
||||
// Create a new function for the current dataflow level.
|
||||
|
@ -115,7 +145,6 @@ void SplitFunction::runOnOperation() {
|
|||
builder.create<mlir::ReturnOp>(func.getLoc(), outputValues);
|
||||
|
||||
// Move same level operations into the new created function.
|
||||
opIndex = 0;
|
||||
for (auto op : ops) {
|
||||
op->moveBefore(returnOp);
|
||||
op->removeAttr("dataflow_level");
|
||||
|
@ -125,8 +154,11 @@ void SplitFunction::runOnOperation() {
|
|||
entry->getArgument(i), [&](mlir::OpOperand &use) {
|
||||
return func.getOperation()->isAncestor(use.getOwner());
|
||||
});
|
||||
opIndex++;
|
||||
}
|
||||
|
||||
// Move internal memory defining operation into the new created function.
|
||||
for (auto memoryDefOp : internalMemories)
|
||||
memoryDefOp->moveBefore(&func.front().front());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue