[FakeQuantize] Implement this pass; Add fake-quantize option to the pytorch pipeline

This commit is contained in:
Hanchen Ye 2022-02-28 18:26:37 -06:00
parent c360f7e702
commit 10252043d4
9 changed files with 125 additions and 12 deletions

View File

@ -38,6 +38,10 @@ struct ScaleHLSPyTorchPipelineOptions
Option<unsigned> vectorSize{
*this, "vector-size",
llvm::cl::desc("The size of vectorization (set 0 to disable)")};
Option<bool> fakeQuantize{
*this, "fake-quantize", llvm::cl::init(false),
llvm::cl::desc("Trigger the fake quantization (just for testing use)")};
};
/// QoR estimation and DSE passes.
@ -45,6 +49,7 @@ std::unique_ptr<Pass> createQoREstimationPass();
std::unique_ptr<Pass> createMultipleLevelDSEPass();
/// Graph optimization passes.
std::unique_ptr<Pass> createFakeQuantizePass();
std::unique_ptr<Pass> createCreateRuntimeMainPass();
std::unique_ptr<Pass>
createCreateRuntimeMainPass(const ScaleHLSPyTorchPipelineOptions &opts);

View File

@ -63,6 +63,12 @@ def MultipleLevelDSE : Pass<"dse", "ModuleOp"> {
// Graph Optimization Passes
//===----------------------------------------------------------------------===//
def FakeQuantize : Pass<"fake-quantize", "ModuleOp"> {
let summary = "Convert to 8-bits quantized model (only for testing use)";
let constructor = "mlir::scalehls::createFakeQuantizePass()";
}
def CreateRuntimeMain : Pass<"create-runtime-main", "ModuleOp"> {
let summary = "Create the main function of runtime";
let description = [{

View File

@ -5,6 +5,7 @@ add_mlir_library(MLIRScaleHLSTransforms
Directive/LoopPipelining.cpp
Graph/ConvertCopyToAffineLoops.cpp
Graph/CreateRuntimeMain.cpp
Graph/FakeQuantize.cpp
Graph/LegalizeDataflow.cpp
Graph/SimplifyTosaGraph.cpp
Graph/SplitFunction.cpp

View File

@ -0,0 +1,82 @@
//===----------------------------------------------------------------------===//
//
// Copyright 2020-2021 The ScaleHLS Authors.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "scalehls/Transforms/Passes.h"
#include "scalehls/Transforms/Utils.h"
using namespace mlir;
using namespace scalehls;
/// Map a float32 type to its fake-quantized 8-bit integer counterpart.
/// A scalar `f32` becomes `i8`; a ranked tensor of `f32` becomes a tensor of
/// `i8` with the same shape. Any other type yields a null Type, signalling
/// "do not quantize".
static Type getQuantizeType(Type type) {
  auto *context = type.getContext();

  // Scalar case: plain f32 -> i8.
  if (type.isa<Float32Type>())
    return IntegerType::get(context, 8);

  // Tensor case: tensor<...xf32> -> tensor<...xi8>, shape preserved.
  if (auto tensorType = type.dyn_cast<RankedTensorType>())
    if (tensorType.getElementType().isa<Float32Type>())
      return RankedTensorType::get(tensorType.getShape(),
                                   IntegerType::get(context, 8));

  // Not a float32 scalar or ranked float32 tensor - leave untouched.
  return nullptr;
}
namespace {
/// This pass is only for testing use!!! To really support quantized models,
/// first we need front-ends, such as Torch-MLIR, to support model
/// quantization, which unfortunately has not come true yet.
///
/// The pass walks the whole module and rewrites every float32 value type to
/// i8 (see getQuantizeType), fabricating constant contents and zero-point
/// attributes along the way. The resulting IR is numerically meaningless;
/// it only has the *shape* of an 8-bit quantized model.
struct FakeQuantize : public FakeQuantizeBase<FakeQuantize> {
  void runOnOperation() override {
    auto module = getOperation();
    auto builder = OpBuilder(module);

    // First walk: convert the type of every block argument (including
    // function entry-block arguments) from f32 to i8.
    module.walk([&](Block *block) {
      for (auto arg : block->getArguments())
        if (auto quantType = getQuantizeType(arg.getType()))
          arg.setType(quantType);
    });

    // Second walk: convert the type of operation results. Also, handle
    // function, constant, conv2d, and matmul operations specially.
    // fakeIdx gives each rewritten constant a distinct (but arbitrary)
    // fill value; it wraps around silently, which is fine for fake data.
    int8_t fakeIdx = 0;
    module.walk([&](Operation *op) {
      for (auto result : op->getResults())
        if (auto quantType = getQuantizeType(result.getType())) {
          result.setType(quantType);
          if (auto constant = dyn_cast<tosa::ConstOp>(op)) {
            // Because we are not trying to really quantize the model, here we
            // just assign a fake splat-like value to the constant operation
            // instead of converting its original float payload.
            SmallVector<int8_t, 64> list(constant.value().size(), fakeIdx++);
            // for (auto value : constant.valueAttr().getValues<float>())
            //   list.push_back(value);
            // NOTE(review): quantType is expected to be a RankedTensorType
            // here (tosa.const results are tensors) - confirm no scalar
            // constants reach this path, DenseIntElementsAttr::get would
            // assert otherwise.
            auto quantValue = DenseIntElementsAttr::get(quantType, list);
            constant->setAttr(constant.valueAttrName(), quantValue);
          }
          if (auto conv2d = dyn_cast<tosa::Conv2DOp>(op)) {
            // Attach a quantization info attribute with zero input/weight
            // zero-points, so downstream lowering treats the conv as a
            // quantized op. The values are placeholders, consistent with
            // the fake nature of this pass.
            auto quantInfoAttr = tosa::ConvOpQuantizationAttr::get(
                builder.getI32IntegerAttr(0), builder.getI32IntegerAttr(0),
                conv2d.getContext());
            conv2d->setAttr(conv2d.quantization_infoAttrName(), quantInfoAttr);
          }
        }
      // As we have updated the type of all values in the function, we can
      // safely convert the function type as well. Arguments come from the
      // entry block, results from the terminator of the last block.
      if (auto func = dyn_cast<FuncOp>(op))
        func.setType(FunctionType::get(
            func.getContext(), func.front().getArgumentTypes(),
            func.back().getTerminator()->getOperandTypes()));
    });
  }
};
} // namespace
/// Factory for the fake quantization pass; referenced by the pass registry
/// (see the FakeQuantize tablegen entry's `constructor` field).
std::unique_ptr<Pass> scalehls::createFakeQuantizePass() {
  auto pass = std::make_unique<FakeQuantize>();
  return pass;
}

View File

@ -25,6 +25,7 @@ bool scalehls::applyAffineLoopOrderOpt(AffineLoopBand &band,
ArrayRef<unsigned> permMap,
bool reverse) {
LLVM_DEBUG(llvm::dbgs() << "Loop order opt ";);
assert(!band.empty() && "no loops provided");
if (!isPerfectlyNested(band))
return false;

View File

@ -16,6 +16,8 @@ using namespace scalehls;
/// Apply loop perfection. Try to sink all operations between loop statements
/// into the innermost loop of the input loop band.
bool scalehls::applyAffineLoopPerfection(AffineLoopBand &band) {
assert(!band.empty() && "no loops provided");
auto innermostLoop = band.back();
auto builder = OpBuilder(innermostLoop);

View File

@ -83,34 +83,42 @@ static void simplifyAffineStructures(Block &block) {
/// innermost tile-space loop.
Optional<unsigned> scalehls::applyLoopTiling(AffineLoopBand &band,
TileList tileList, bool simplify) {
assert(!band.empty() && "no loops provided");
if (!isPerfectlyNested(band))
return Optional<unsigned>();
// Loop tiling.
auto bandSize = band.size();
auto originalBandSize = band.size();
AffineLoopBand tiledBand;
if (failed(tilePerfectlyNested(band, tileList, &tiledBand)))
return Optional<unsigned>();
// Simplify the tiled loop band if required.
if (simplify) {
band.clear();
unsigned simplifiedBandSize = 0;
for (unsigned i = 0, e = tiledBand.size(); i < e; ++i) {
auto loop = tiledBand[i];
(void)normalizeAffineFor(loop);
Optional<uint64_t> tripCount = getConstantTripCount(loop);
if (i < originalBandSize - 1 || simplifiedBandSize > 0 || !tripCount ||
tripCount.getValue() != 1)
(void)normalizeAffineFor(loop);
if (loop && !loop.getLoopBody().empty()) {
band.push_back(loop);
if (i < bandSize)
if (i < originalBandSize)
++simplifiedBandSize;
}
}
simplifyAffineStructures(*band.front().getBody());
return simplifiedBandSize - 1;
} else {
band = tiledBand;
return bandSize - 1;
}
// Otherwise, directly return the tiled loop band.
band = tiledBand;
return originalBandSize - 1;
}
namespace {
@ -151,13 +159,14 @@ struct AffineLoopUnrollAndPipeline
sizes.push_back(1);
}
auto tileLoc = applyLoopTiling(band, sizes).getValue();
band.resize(tileLoc + 1);
// Apply loop tiling and extract the tile loops if applicable.
if (auto tileLoc = applyLoopTiling(band, sizes))
band.resize(tileLoc.getValue() + 1);
// TODO: canonicalize here to eliminate affine.apply ops?
// Apply loop order optimization and pipelining.
if (loopOrderOpt)
applyAffineLoopOrderOpt(band);
applyLoopPipelining(band, tileLoc, (unsigned)1);
applyLoopPipelining(band, band.size() - 1, (unsigned)1);
}
}
};

View File

@ -14,6 +14,8 @@ using namespace scalehls;
/// Apply remove variable bound to all inner loops of the input loop.
bool scalehls::applyRemoveVariableBound(AffineLoopBand &band) {
assert(!band.empty() && "no loops provided");
auto innermostLoop = band.back();
auto builder = OpBuilder(innermostLoop);

View File

@ -42,6 +42,9 @@ void scalehls::registerScaleHLSPyTorchPipeline() {
if (opts.vectorSize.hasValue())
vectorSize = opts.vectorSize;
if (opts.fakeQuantize)
pm.addPass(scalehls::createFakeQuantizePass());
// Graph-level optimizations.
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(scalehls::createSimplifyTosaGraphPass());
@ -63,12 +66,14 @@ void scalehls::registerScaleHLSPyTorchPipeline() {
pm.addPass(scalehls::createConvertCopyToAffineLoopsPass());
// Loop-level optimizations.
if (vectorSize)
pm.addPass(mlir::createSuperVectorizePass({vectorSize}));
pm.addPass(memref::createFoldSubViewOpsPass());
pm.addPass(mlir::createAffineLoopNormalizePass());
pm.addPass(mlir::createSimplifyAffineStructuresPass());
pm.addPass(mlir::createCanonicalizerPass());
if (vectorSize) {
pm.addPass(mlir::createSuperVectorizePass({vectorSize}));
pm.addPass(mlir::createCanonicalizerPass());
}
pm.addPass(scalehls::createLegalizeToHLSCppPass(opts));
pm.addPass(scalehls::createMaterializeReductionPass());
if (loopUnrollSize) {