[BenchmarkGen] update CNN benchmarkGen, now all kernels and biases become function arguments
This commit is contained in:
parent
7d9ee1b965
commit
41b5cb32a8
|
@ -122,7 +122,7 @@ LogicalResult BenchmarkGenerator::genCNN(INIReader config) {
|
|||
SmallVector<mlir::Type, 2> outputTypes;
|
||||
|
||||
auto func = builder.create<FuncOp>(
|
||||
loc, "auto_gen_cnn", builder.getFunctionType(inputTypes, outputTypes));
|
||||
loc, "tmp", builder.getFunctionType(inputTypes, outputTypes));
|
||||
func.addEntryBlock();
|
||||
builder.setInsertionPointToStart(&func.front());
|
||||
|
||||
|
@ -204,6 +204,52 @@ LogicalResult BenchmarkGenerator::genCNN(INIReader config) {
|
|||
biases.back(), fmaps.back());
|
||||
|
||||
builder.create<mlir::ReturnOp>(loc);
|
||||
|
||||
// Create a new function taking all kernels and biases as arguments. This will
|
||||
// eliminate all the AllocOp for kernels and biases in the generated code.
|
||||
builder.setInsertionPointAfter(func);
|
||||
SmallVector<mlir::Type, 32> newInputTypes;
|
||||
|
||||
// Add original types.
|
||||
for (auto type : inputTypes)
|
||||
newInputTypes.push_back(type);
|
||||
|
||||
// Add kernel types.
|
||||
for (auto kernel : kernels)
|
||||
newInputTypes.push_back(kernel.getType());
|
||||
|
||||
// Add bias types.
|
||||
for (auto bias : biases)
|
||||
newInputTypes.push_back(bias.getType());
|
||||
|
||||
// Create function with new signature.
|
||||
auto newFunc = builder.create<FuncOp>(
|
||||
loc, "auto_gen_cnn", builder.getFunctionType(newInputTypes, outputTypes));
|
||||
newFunc.addEntryBlock();
|
||||
builder.setInsertionPointToStart(&newFunc.front());
|
||||
|
||||
// Move all operations in the original function into the new created function.
|
||||
auto &entryBlock = newFunc.front().getOperations();
|
||||
entryBlock.splice(entryBlock.end(), func.front().getOperations());
|
||||
|
||||
// Replace use of original arguments with new function arguments.
|
||||
unsigned argIndex = 0;
|
||||
for (auto arg : func.getArguments())
|
||||
arg.replaceAllUsesWith(newFunc.getArgument(argIndex++));
|
||||
func.erase();
|
||||
|
||||
// Replace use of original kernel memref with corresponding argument.
|
||||
for (auto kernel : kernels) {
|
||||
kernel.replaceAllUsesWith(newFunc.getArgument(argIndex++));
|
||||
kernel.getDefiningOp()->erase();
|
||||
}
|
||||
|
||||
// Replace use of original bias memref with corresponding argument.
|
||||
for (auto bias : biases) {
|
||||
bias.replaceAllUsesWith(newFunc.getArgument(argIndex++));
|
||||
bias.getDefiningOp()->erase();
|
||||
}
|
||||
|
||||
os << module << "\n";
|
||||
return success();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue