mirror of https://github.com/llvm/circt.git
[StandardToHandshake] Use isa instead of dyn_cast for conditions (#87)
This commit is contained in:
parent
eeec4b26fd
commit
3842d113fd
|
@ -8,8 +8,8 @@
|
|||
// =============================================================================
|
||||
|
||||
#include "circt/Conversion/StandardToHandshake/StandardToHandshake.h"
|
||||
#include "circt/Dialect/StaticLogic/StaticLogic.h"
|
||||
#include "circt/Dialect/Handshake/HandshakeOps.h"
|
||||
#include "circt/Dialect/StaticLogic/StaticLogic.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
|
@ -151,7 +151,7 @@ void setControlOnlyPath(handshake::FuncOp f,
|
|||
// Replace original return ops with new returns with additional control input
|
||||
for (Block &block : f) {
|
||||
Operation *termOp = block.getTerminator();
|
||||
if (dyn_cast<mlir::ReturnOp>(termOp)) {
|
||||
if (isa<mlir::ReturnOp>(termOp)) {
|
||||
|
||||
rewriter.setInsertionPoint(termOp);
|
||||
|
||||
|
@ -308,7 +308,7 @@ Operation *insertMerge(Block *block, Value val,
|
|||
|
||||
// Control-only path originates from StartOp
|
||||
if (val.getKind() != Value::Kind::BlockArgument) {
|
||||
if (dyn_cast<StartOp>(val.getDefiningOp())) {
|
||||
if (isa<StartOp>(val.getDefiningOp())) {
|
||||
return rewriter.create<handshake::ControlMergeOp>(block->front().getLoc(),
|
||||
val, numPredecessors);
|
||||
}
|
||||
|
@ -387,15 +387,14 @@ Value getMergeOperand(Operation *op, Block *predBlock, BlockOps blockMerges) {
|
|||
else {
|
||||
unsigned index = srcVal.cast<BlockArgument>().getArgNumber();
|
||||
Operation *termOp = predBlock->getTerminator();
|
||||
if (dyn_cast<mlir::CondBranchOp>(termOp)) {
|
||||
mlir::CondBranchOp br = dyn_cast<mlir::CondBranchOp>(termOp);
|
||||
if (mlir::CondBranchOp br = dyn_cast<mlir::CondBranchOp>(termOp)) {
|
||||
if (block == br.getTrueDest())
|
||||
return br.getTrueOperand(index);
|
||||
else {
|
||||
assert(block == br.getFalseDest());
|
||||
return br.getFalseOperand(index);
|
||||
}
|
||||
} else if (dyn_cast<mlir::BranchOp>(termOp))
|
||||
} else if (isa<mlir::BranchOp>(termOp))
|
||||
return termOp->getOperand(index);
|
||||
}
|
||||
return nullptr;
|
||||
|
@ -416,7 +415,7 @@ void removeBlockOperands(handshake::FuncOp f) {
|
|||
Operation *getControlMerge(Block *block) {
|
||||
// Returns CMerge of block
|
||||
for (Operation &op : *block) {
|
||||
if (dyn_cast<ControlMergeOp>(op)) {
|
||||
if (isa<ControlMergeOp>(op)) {
|
||||
return &op;
|
||||
}
|
||||
}
|
||||
|
@ -426,7 +425,7 @@ Operation *getControlMerge(Block *block) {
|
|||
Operation *getStartOp(Block *block) {
|
||||
// Returns CMerge of block
|
||||
for (Operation &op : *block) {
|
||||
if (dyn_cast<StartOp>(op)) {
|
||||
if (isa<StartOp>(op)) {
|
||||
return &op;
|
||||
}
|
||||
}
|
||||
|
@ -790,8 +789,8 @@ void checkMergePredecessors(Operation *op) {
|
|||
void checkDataflowConversion(handshake::FuncOp f) {
|
||||
for (Block &block : f) {
|
||||
for (Operation &op : block) {
|
||||
if (!dyn_cast<mlir::CondBranchOp>(op) && !dyn_cast<mlir::BranchOp>(op) &&
|
||||
!dyn_cast<mlir::LoadOp>(op) && !dyn_cast<mlir::ConstantOp>(op)) {
|
||||
if (!isa<mlir::CondBranchOp, mlir::BranchOp, mlir::LoadOp,
|
||||
mlir::ConstantOp>(op)) {
|
||||
if (op.getNumResults() > 0) {
|
||||
for (auto result : op.getResults()) {
|
||||
checkUseCount(&op, result);
|
||||
|
@ -828,9 +827,7 @@ Value getOpMemRef(Operation *op) {
|
|||
op->emitError("Unknown Op type");
|
||||
}
|
||||
|
||||
bool isMemoryOp(Operation *op) {
|
||||
return (dyn_cast<mlir::LoadOp>(op) || dyn_cast<mlir::StoreOp>(op));
|
||||
}
|
||||
bool isMemoryOp(Operation *op) { return isa<mlir::LoadOp, mlir::StoreOp>(op); }
|
||||
|
||||
typedef llvm::MapVector<Value, vector<Operation *>> MemRefToMemoryAccessOp;
|
||||
|
||||
|
@ -854,28 +851,25 @@ MemRefToMemoryAccessOp replaceMemoryOps(handshake::FuncOp f,
|
|||
Value memref = getOpMemRef(&op);
|
||||
Operation *newOp = nullptr;
|
||||
|
||||
if (dyn_cast<mlir::LoadOp>(op)) {
|
||||
if (mlir::LoadOp loadOp = dyn_cast<mlir::LoadOp>(op)) {
|
||||
// Get operands which correspond to address indices
|
||||
// This will add all operands except alloc
|
||||
SmallVector<Value, 8> operands(
|
||||
dyn_cast<mlir::LoadOp>(op).getIndices());
|
||||
SmallVector<Value, 8> operands(loadOp.getIndices());
|
||||
|
||||
newOp =
|
||||
rewriter.create<handshake::LoadOp>(op.getLoc(), memref, operands);
|
||||
op.getResult(0).replaceAllUsesWith(newOp->getResult(0));
|
||||
|
||||
} else {
|
||||
|
||||
assert(dyn_cast<mlir::StoreOp>(op));
|
||||
assert(isa<mlir::StoreOp>(op));
|
||||
// Get operands which correspond to address indices
|
||||
// This will add all operands except alloc and data
|
||||
SmallVector<Value, 8> operands(
|
||||
dyn_cast<mlir::StoreOp>(op).getIndices());
|
||||
mlir::StoreOp storeOp = dyn_cast<mlir::StoreOp>(op);
|
||||
SmallVector<Value, 8> operands(storeOp.getIndices());
|
||||
|
||||
// Create new op where operands are store data and address indices
|
||||
newOp = rewriter.create<handshake::StoreOp>(
|
||||
op.getLoc(), dyn_cast<mlir::StoreOp>(op).getValueToStore(),
|
||||
operands);
|
||||
op.getLoc(), storeOp.getValueToStore(), operands);
|
||||
}
|
||||
MemRefOps[memref].push_back(newOp);
|
||||
|
||||
|
@ -912,18 +906,17 @@ vector<Block *> getOperationBlocks(vector<Operation *> ops) {
|
|||
SmallVector<Value, 8> getResultsToMemory(Operation *op) {
|
||||
// Get load/store results which are given as inputs to MemoryOp
|
||||
|
||||
if (dyn_cast<handshake::LoadOp>(op)) {
|
||||
if (handshake::LoadOp loadOp = dyn_cast<handshake::LoadOp>(op)) {
|
||||
// For load, get all address outputs/indices
|
||||
// (load also has one data output which goes to successor operation)
|
||||
SmallVector<Value, 8> results(
|
||||
dyn_cast<handshake::LoadOp>(op).addressResults());
|
||||
SmallVector<Value, 8> results(loadOp.addressResults());
|
||||
return results;
|
||||
|
||||
} else {
|
||||
// For store, all outputs (data and address indices) go to memory
|
||||
assert(dyn_cast<handshake::StoreOp>(op));
|
||||
SmallVector<Value, 8> results(
|
||||
dyn_cast<handshake::StoreOp>(op).getResults());
|
||||
handshake::StoreOp storeOp = dyn_cast<handshake::StoreOp>(op);
|
||||
SmallVector<Value, 8> results(storeOp.getResults());
|
||||
return results;
|
||||
}
|
||||
}
|
||||
|
@ -942,8 +935,7 @@ void addMemOpForks(handshake::FuncOp f, ConversionPatternRewriter &rewriter) {
|
|||
|
||||
for (Block &block : f) {
|
||||
for (Operation &op : block) {
|
||||
if (dyn_cast<MemoryOp>(op) || dyn_cast<StartOp>(op) ||
|
||||
dyn_cast<ControlMergeOp>(op)) {
|
||||
if (isa<MemoryOp, StartOp, ControlMergeOp>(op)) {
|
||||
for (auto result : op.getResults()) {
|
||||
// If there is a result and it is used more than once
|
||||
if (!result.use_empty() && !result.hasOneUse())
|
||||
|
@ -959,7 +951,7 @@ void removeAllocOps(handshake::FuncOp f, ConversionPatternRewriter &rewriter) {
|
|||
|
||||
for (Block &block : f)
|
||||
for (Operation &op : block) {
|
||||
if (dyn_cast<AllocOp>(op)) {
|
||||
if (isa<AllocOp>(op)) {
|
||||
assert(op.getResult(0).hasOneUse());
|
||||
for (auto &u : op.getResult(0).getUses()) {
|
||||
Operation *useOp = u.getOwner();
|
||||
|
@ -980,9 +972,9 @@ void removeRedundantSinks(handshake::FuncOp f,
|
|||
|
||||
for (Block &block : f)
|
||||
for (Operation &op : block) {
|
||||
if (dyn_cast<SinkOp>(op))
|
||||
if (isa<SinkOp>(op))
|
||||
if (!op.getOperand(0).hasOneUse() ||
|
||||
dyn_cast<AllocOp>(op.getOperand(0).getDefiningOp()))
|
||||
isa<AllocOp>(op.getOperand(0).getDefiningOp()))
|
||||
redundantSinks.push_back(&op);
|
||||
}
|
||||
for (unsigned i = 0, e = redundantSinks.size(); i != e; ++i) {
|
||||
|
@ -1014,7 +1006,7 @@ void addJoinOps(ConversionPatternRewriter &rewriter,
|
|||
auto srcOp = val.getDefiningOp();
|
||||
|
||||
// Insert only single join per block
|
||||
if (!dyn_cast<JoinOp>(srcOp)) {
|
||||
if (!isa<JoinOp>(srcOp)) {
|
||||
rewriter.setInsertionPointAfter(srcOp);
|
||||
Operation *newOp = rewriter.create<JoinOp>(srcOp->getLoc(), val);
|
||||
for (auto &u : val.getUses())
|
||||
|
@ -1051,7 +1043,7 @@ void setLoadDataInputs(vector<Operation *> memOps, Operation *memOp) {
|
|||
// Set memory outputs as load input data
|
||||
int ld_count = 0;
|
||||
for (auto *op : memOps) {
|
||||
if (dyn_cast<handshake::LoadOp>(op))
|
||||
if (isa<handshake::LoadOp>(op))
|
||||
addValueToOperands(op, memOp->getResult(ld_count++));
|
||||
}
|
||||
}
|
||||
|
@ -1064,7 +1056,7 @@ void setJoinControlInputs(vector<Operation *> memOps, Operation *memOp,
|
|||
auto *op = memOps[i];
|
||||
Value val = getBlockControlValue(op->getBlock());
|
||||
auto srcOp = val.getDefiningOp();
|
||||
if (!dyn_cast<JoinOp>(srcOp)) {
|
||||
if (!isa<JoinOp, StartOp>(srcOp)) {
|
||||
srcOp->emitError("Op expected to be a JoinOp or StartOp");
|
||||
}
|
||||
addValueToOperands(srcOp, memOp->getResult(offset + cntrlInd[i]));
|
||||
|
@ -1090,8 +1082,7 @@ void setMemOpControlInputs(ConversionPatternRewriter &rewriter,
|
|||
Block *predBlock = predOp->getBlock();
|
||||
if (currBlock == predBlock)
|
||||
// Any dependency but RARs
|
||||
if (!(dyn_cast<handshake::LoadOp>(currOp) &&
|
||||
dyn_cast<handshake::LoadOp>(predOp)))
|
||||
if (!(isa<handshake::LoadOp>(currOp) && isa<handshake::LoadOp>(predOp)))
|
||||
// cntrlInd maps memOps index to correct control output index
|
||||
controlOperands.push_back(memOp->getResult(offset + cntrlInd[j]));
|
||||
}
|
||||
|
@ -1144,7 +1135,7 @@ void connectToMemory(handshake::FuncOp f, MemRefToMemoryAccessOp MemRefOps,
|
|||
int ind = 0;
|
||||
for (int i = 0, e = memory.second.size(); i < e; ++i) {
|
||||
auto *op = memory.second[i];
|
||||
if (dyn_cast<handshake::StoreOp>(op)) {
|
||||
if (isa<handshake::StoreOp>(op)) {
|
||||
SmallVector<Value, 8> results = getResultsToMemory(op);
|
||||
operands.insert(operands.end(), results.begin(), results.end());
|
||||
newInd[i] = ind++;
|
||||
|
@ -1155,7 +1146,7 @@ void connectToMemory(handshake::FuncOp f, MemRefToMemoryAccessOp MemRefOps,
|
|||
|
||||
for (int i = 0, e = memory.second.size(); i < e; ++i) {
|
||||
auto *op = memory.second[i];
|
||||
if (dyn_cast<handshake::LoadOp>(op)) {
|
||||
if (isa<handshake::LoadOp>(op)) {
|
||||
SmallVector<Value, 8> results = getResultsToMemory(op);
|
||||
operands.insert(operands.end(), results.begin(), results.end());
|
||||
|
||||
|
|
Loading…
Reference in New Issue