hanchenye-scalehls/lib/Transforms/Runtime/CreateAxiInterface.cpp

137 lines
5.2 KiB
C++

//===----------------------------------------------------------------------===//
//
// Copyright 2020-2021 The ScaleHLS Authors.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "scalehls/Transforms/Passes.h"
#include "scalehls/Transforms/Utils.h"
using namespace mlir;
using namespace scalehls;
using namespace hls;
/// Map a memref type onto the DRAM memory space.
///
/// The returned memref keeps the shape, element type, and layout affine map of
/// `type`; only its memory space becomes `MemoryKind::DRAM`. Types that are not
/// `MemRefType` are passed through untouched.
static Type getDramType(Type type) {
  auto memrefType = type.dyn_cast<MemRefType>();
  if (!memrefType)
    return type;

  auto dramSpace = static_cast<unsigned>(MemoryKind::DRAM);
  return MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
                         memrefType.getLayout().getAffineMap(), dramSpace);
}
/// Recursively re-derive the function type of `func` and of its callees from
/// the types currently carried by the IR.
///
/// The new type of `func` is built from:
///   - the argument types of its entry block, and
///   - the operand types of the terminator of its last block.
///
/// Then, for each `func.call` returned by `func.getOps<func::CallOp>()`:
///   - Defined callee: its entry-block argument types are overwritten with the
///     call's operand types, and the callee is processed recursively.
///   - Callee that cannot be resolved: skipped.
///   - External callee (a declaration without a body): it has no block
///     arguments to retype, so its function type is set directly from the
///     call's operand and result types. This keeps the call and the
///     declaration type-consistent.
///
/// NOTE(review): `getOps` only visits operations directly inside the body's
/// blocks. Calls nested in loops or other regions are not retyped; confirm
/// this is intended.
void updateFuncType(func::FuncOp func, OpBuilder &builder) {
  func.setType(
      builder.getFunctionType(func.front().getArgumentTypes(),
                              func.back().getTerminator()->getOperandTypes()));

  for (auto subCall : func.getOps<func::CallOp>()) {
    auto subFunc = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
        subCall, subCall.getCalleeAttr());
    // Unresolvable callee: nothing to update.
    if (!subFunc)
      continue;

    // A declaration has no entry block, so calling getArguments()/front() on
    // it would touch an empty region. Update its signature directly instead.
    if (subFunc.isExternal()) {
      subFunc.setType(builder.getFunctionType(subCall.getOperandTypes(),
                                              subCall.getResultTypes()));
      continue;
    }

    for (auto t : llvm::zip(subFunc.getArguments(), subCall.getOperandTypes()))
      std::get<0>(t).setType(std::get<1>(t));
    updateFuncType(subFunc, builder);
  }
}
namespace {
struct CreateAxiInterface
: public scalehls::CreateAxiInterfaceBase<CreateAxiInterface> {
void runOnOperation() override {
auto module = getOperation();
OpBuilder builder(module);
// Get the top and runtime function of the module.
auto func = getTopFunc(module);
auto runtime = getRuntimeFunc(module);
if (!func || !runtime || func.getNumResults() || runtime.getNumResults() ||
!llvm::hasSingleElement(runtime.getOps<func::CallOp>())) {
emitError(module.getLoc(), "fail to find legal top/runtime function");
return signalPassFailure();
}
// Get the top function call.
auto call = *runtime.getOps<func::CallOp>().begin();
if (call.getCallee() != func.getName()) {
call.emitOpError("must reference the top function");
return signalPassFailure();
};
// Move each allocs, buffer primitives, and constant primitives, allocated
// in the top function to the runtime function. As each AXI interface only
// has one read and one write channel, we need to avoid interface conflicts
// by analyzing the memroy access pattern of sub-functions.
SmallVector<Value, 32> inputs(call.getOperands());
for (auto &op : llvm::make_early_inc_range(func.front())) {
if (!isa<memref::AllocOp, PrimBufferOp, PrimConstOp>(op))
continue;
auto memref = op.getResult(0);
auto type = memref.getType().cast<MemRefType>();
op.moveBefore(call);
// Add a new AXI interface to the top function.
inputs.push_back(memref);
auto interface = func.front().addArgument(type, op.getLoc());
bool writeChannel = true, readChannel = true;
for (auto &use : llvm::make_early_inc_range(op.getUses())) {
if (auto subCall = dyn_cast<func::CallOp>(use.getOwner())) {
auto arg = module.lookupSymbol<FuncOp>(subCall.getCallee())
.getArgument(use.getOperandNumber());
auto readFlag = llvm::any_of(arg.getUsers(), [](Operation *op) {
return isa<mlir::AffineReadOpInterface, func::CallOp>(op);
});
auto writeFlag = llvm::any_of(arg.getUsers(), [](Operation *op) {
return isa<mlir::AffineWriteOpInterface, func::CallOp>(op);
});
// If the read/write is already occupied, add a new AXI interface.
if ((readFlag && !readChannel) || (writeFlag && !writeChannel)) {
inputs.push_back(memref);
interface = func.front().addArgument(type, op.getLoc());
writeChannel = true, readChannel = true;
}
// Occupy the current read/write channel.
if (readFlag)
readChannel = false;
if (writeFlag)
writeChannel = false;
} else {
use.getOwner()->emitOpError("memref must be used by call op");
return signalPassFailure();
}
// Set the current use to the current interface argument.
use.set(interface);
}
}
// Update the top function and call.
builder.setInsertionPoint(call);
auto newCall = builder.create<func::CallOp>(call.getLoc(), func.getName(),
func.getResultTypes(), inputs);
call.replaceAllUsesWith(newCall);
call.erase();
func.setType(newCall.getCalleeType());
// Convert each memory in the runtime function to DRAM type.
for (auto arg : runtime.getArguments())
arg.setType(getDramType(arg.getType()));
for (auto &op : runtime.front())
for (auto result : op.getResults())
result.setType(getDramType(result.getType()));
// Update function type recursively.
updateFuncType(runtime, builder);
}
};
} // namespace
/// Create a fresh instance of the CreateAxiInterface pass, owned by the caller.
std::unique_ptr<Pass> scalehls::createCreateAxiInterfacePass() {
  std::unique_ptr<Pass> pass = std::make_unique<CreateAxiInterface>();
  return pass;
}