From 41b5cb32a86d6590c9d78084e62f4f52bbce77a6 Mon Sep 17 00:00:00 2001 From: Hanchen Ye Date: Tue, 1 Dec 2020 23:37:28 -0600 Subject: [PATCH] [BenchmarkGen] update CNN benchmarkGen, now all kernels and biases become function arguments --- tools/benchmark-gen/benchmark-gen.cpp | 48 ++++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/tools/benchmark-gen/benchmark-gen.cpp b/tools/benchmark-gen/benchmark-gen.cpp index 6172e19..f5c7d1c 100644 --- a/tools/benchmark-gen/benchmark-gen.cpp +++ b/tools/benchmark-gen/benchmark-gen.cpp @@ -122,7 +122,7 @@ LogicalResult BenchmarkGenerator::genCNN(INIReader config) { SmallVector outputTypes; auto func = builder.create( - loc, "auto_gen_cnn", builder.getFunctionType(inputTypes, outputTypes)); + loc, "tmp", builder.getFunctionType(inputTypes, outputTypes)); func.addEntryBlock(); builder.setInsertionPointToStart(&func.front()); @@ -204,6 +204,52 @@ LogicalResult BenchmarkGenerator::genCNN(INIReader config) { biases.back(), fmaps.back()); builder.create(loc); + + // Create a new function taking all kernels and biases as arguments. This will + // eliminate all the AllocOp for kernels and biases in the generated code. + builder.setInsertionPointAfter(func); + SmallVector newInputTypes; + + // Add original types. + for (auto type : inputTypes) + newInputTypes.push_back(type); + + // Add kernel types. + for (auto kernel : kernels) + newInputTypes.push_back(kernel.getType()); + + // Add bias types. + for (auto bias : biases) + newInputTypes.push_back(bias.getType()); + + // Create function with new signature. + auto newFunc = builder.create( + loc, "auto_gen_cnn", builder.getFunctionType(newInputTypes, outputTypes)); + newFunc.addEntryBlock(); + builder.setInsertionPointToStart(&newFunc.front()); + + // Move all operations in the original function into the new created function. + auto &entryBlock = newFunc.front().getOperations(); + entryBlock.splice(entryBlock.end(), func.front().getOperations()); + + // Replace use of original arguments with new function arguments. + unsigned argIndex = 0; + for (auto arg : func.getArguments()) + arg.replaceAllUsesWith(newFunc.getArgument(argIndex++)); + func.erase(); + + // Replace use of original kernel memref with corresponding argument. + for (auto kernel : kernels) { + kernel.replaceAllUsesWith(newFunc.getArgument(argIndex++)); + kernel.getDefiningOp()->erase(); + } + + // Replace use of original bias memref with corresponding argument. + for (auto bias : biases) { + bias.replaceAllUsesWith(newFunc.getArgument(argIndex++)); + bias.getDefiningOp()->erase(); + } + os << module << "\n"; return success(); }