[ArrayPartition][LegalizeHLSCpp] Handle the memref type of global op
This commit is contained in:
parent
310d9c6981
commit
3bea951beb
|
@ -6,6 +6,7 @@
|
|||
|
||||
#include "mlir/Analysis/AffineAnalysis.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "scalehls/Transforms/Passes.h"
|
||||
#include "scalehls/Transforms/Utils.h"
|
||||
|
@ -91,6 +92,14 @@ bool scalehls::applyArrayPartition(Value array, ArrayRef<unsigned> factors,
|
|||
// Set new type.
|
||||
array.setType(newType);
|
||||
|
||||
// FIXME: This is a very very bad practice...
|
||||
// TODO: How to represent different memory resource?
|
||||
if (auto getGlobal = array.getDefiningOp<memref::GetGlobalOp>()) {
|
||||
auto module = getGlobal->getParentOfType<ModuleOp>();
|
||||
auto global = module.lookupSymbol<memref::GlobalOp>(getGlobal.nameAttr());
|
||||
global->setAttr(global.typeAttrName(), TypeAttr::get(newType));
|
||||
}
|
||||
|
||||
if (updateFuncSignature)
|
||||
if (auto func = dyn_cast<FuncOp>(array.getParentBlock()->getParentOp())) {
|
||||
// Align function type with entry block argument types only if the array
|
||||
|
|
|
@ -97,6 +97,15 @@ bool scalehls::applyLegalizeToHLSCpp(FuncOp func, bool isTopFunc) {
|
|||
MemRefType::get(type.getShape(), type.getElementType(),
|
||||
type.getLayout().getAffineMap(), (unsigned)kind);
|
||||
memref.setType(newType);
|
||||
|
||||
// FIXME: This is a very very bad practice...
|
||||
// TODO: How to represent different memory resource?
|
||||
if (auto getGlobal = memref.getDefiningOp<memref::GetGlobalOp>()) {
|
||||
auto module = getGlobal->getParentOfType<ModuleOp>();
|
||||
auto global =
|
||||
module.lookupSymbol<memref::GlobalOp>(getGlobal.nameAttr());
|
||||
global->setAttr(global.typeAttrName(), TypeAttr::get(newType));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -52,7 +52,7 @@ void scalehls::registerScaleHLSPassPipeline() {
|
|||
pm.addPass(mlir::createSimplifyAffineStructuresPass());
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
} else
|
||||
llvm_unreachable("please use support front-end: torch or onnx.");
|
||||
llvm_unreachable("please use supported front-end: torch or onnx.");
|
||||
|
||||
// Graph-level optimizations.
|
||||
if (dataflowGran) {
|
||||
|
|
|
@ -7,11 +7,11 @@ $ # Parse PyTorch model to Linalg dialect (with mlir_venv activated).
|
|||
$ python3 export_resnet18_mlir.py | torch-mlir-opt \
|
||||
-torchscript-module-to-torch-backend-pipeline="optimize=true" \
|
||||
-torch-backend-to-linalg-on-tensors-backend-pipeline="optimize=true" \
|
||||
-linalg-comprehensive-module-bufferize="allow-return-memref allow-unknown-ops create-deallocs=false" \
|
||||
-canonicalize > resnet18.mlir
|
||||
|
||||
$ # Optimize the model and emit C++ code (not working, will be fixed soon).
|
||||
$ scalehls-opt resnet18.mlir \
|
||||
-linalg-comprehensive-module-bufferize="allow-return-memref allow-unknown-ops create-deallocs=false" \
|
||||
-scalehls-pipeline="top-func=main_graph opt-level=2 frontend=torch" \
|
||||
| scalehls-translate -emit-hlscpp > resnet18.cpp
|
||||
```
|
||||
|
|
Loading…
Reference in New Issue