[BenchmarkGen] update CNN benchmarkGen, now all kernels and biases become function arguments

This commit is contained in:
Hanchen Ye 2020-12-01 23:37:28 -06:00
parent 7d9ee1b965
commit 41b5cb32a8
1 changed files with 47 additions and 1 deletions

View File

@ -122,7 +122,7 @@ LogicalResult BenchmarkGenerator::genCNN(INIReader config) {
SmallVector<mlir::Type, 2> outputTypes;
auto func = builder.create<FuncOp>(
loc, "auto_gen_cnn", builder.getFunctionType(inputTypes, outputTypes));
loc, "tmp", builder.getFunctionType(inputTypes, outputTypes));
func.addEntryBlock();
builder.setInsertionPointToStart(&func.front());
@ -204,6 +204,52 @@ LogicalResult BenchmarkGenerator::genCNN(INIReader config) {
biases.back(), fmaps.back());
builder.create<mlir::ReturnOp>(loc);
// Create a new function taking all kernels and biases as arguments. This will
// eliminate all the AllocOp for kernels and biases in the generated code.
builder.setInsertionPointAfter(func);
SmallVector<mlir::Type, 32> newInputTypes;
// Add original types.
for (auto type : inputTypes)
newInputTypes.push_back(type);
// Add kernel types.
for (auto kernel : kernels)
newInputTypes.push_back(kernel.getType());
// Add bias types.
for (auto bias : biases)
newInputTypes.push_back(bias.getType());
// Create function with new signature.
auto newFunc = builder.create<FuncOp>(
loc, "auto_gen_cnn", builder.getFunctionType(newInputTypes, outputTypes));
newFunc.addEntryBlock();
builder.setInsertionPointToStart(&newFunc.front());
// Move all operations in the original function into the new created function.
auto &entryBlock = newFunc.front().getOperations();
entryBlock.splice(entryBlock.end(), func.front().getOperations());
// Replace use of original arguments with new function arguments.
unsigned argIndex = 0;
for (auto arg : func.getArguments())
arg.replaceAllUsesWith(newFunc.getArgument(argIndex++));
func.erase();
// Replace use of original kernel memref with corresponding argument.
for (auto kernel : kernels) {
kernel.replaceAllUsesWith(newFunc.getArgument(argIndex++));
kernel.getDefiningOp()->erase();
}
// Replace use of original bias memref with corresponding argument.
for (auto bias : biases) {
bias.replaceAllUsesWith(newFunc.getArgument(argIndex++));
bias.getDefiningOp()->erase();
}
os << module << "\n";
return success();
}