[StandardToHandshake] Use isa instead of dyn_cast for conditions (#87)

This commit is contained in:
Ruizhe Zhao 2020-09-16 15:24:53 +01:00 committed by GitHub
parent eeec4b26fd
commit 3842d113fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 30 additions and 39 deletions

View File

@ -8,8 +8,8 @@
// =============================================================================
#include "circt/Conversion/StandardToHandshake/StandardToHandshake.h"
#include "circt/Dialect/StaticLogic/StaticLogic.h"
#include "circt/Dialect/Handshake/HandshakeOps.h"
#include "circt/Dialect/StaticLogic/StaticLogic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
@ -151,7 +151,7 @@ void setControlOnlyPath(handshake::FuncOp f,
// Replace original return ops with new returns with additional control input
for (Block &block : f) {
Operation *termOp = block.getTerminator();
if (dyn_cast<mlir::ReturnOp>(termOp)) {
if (isa<mlir::ReturnOp>(termOp)) {
rewriter.setInsertionPoint(termOp);
@ -308,7 +308,7 @@ Operation *insertMerge(Block *block, Value val,
// Control-only path originates from StartOp
if (val.getKind() != Value::Kind::BlockArgument) {
if (dyn_cast<StartOp>(val.getDefiningOp())) {
if (isa<StartOp>(val.getDefiningOp())) {
return rewriter.create<handshake::ControlMergeOp>(block->front().getLoc(),
val, numPredecessors);
}
@ -387,15 +387,14 @@ Value getMergeOperand(Operation *op, Block *predBlock, BlockOps blockMerges) {
else {
unsigned index = srcVal.cast<BlockArgument>().getArgNumber();
Operation *termOp = predBlock->getTerminator();
if (dyn_cast<mlir::CondBranchOp>(termOp)) {
mlir::CondBranchOp br = dyn_cast<mlir::CondBranchOp>(termOp);
if (mlir::CondBranchOp br = dyn_cast<mlir::CondBranchOp>(termOp)) {
if (block == br.getTrueDest())
return br.getTrueOperand(index);
else {
assert(block == br.getFalseDest());
return br.getFalseOperand(index);
}
} else if (dyn_cast<mlir::BranchOp>(termOp))
} else if (isa<mlir::BranchOp>(termOp))
return termOp->getOperand(index);
}
return nullptr;
@ -416,7 +415,7 @@ void removeBlockOperands(handshake::FuncOp f) {
Operation *getControlMerge(Block *block) {
// Returns CMerge of block
for (Operation &op : *block) {
if (dyn_cast<ControlMergeOp>(op)) {
if (isa<ControlMergeOp>(op)) {
return &op;
}
}
@ -426,7 +425,7 @@ Operation *getControlMerge(Block *block) {
Operation *getStartOp(Block *block) {
// Returns CMerge of block
for (Operation &op : *block) {
if (dyn_cast<StartOp>(op)) {
if (isa<StartOp>(op)) {
return &op;
}
}
@ -790,8 +789,8 @@ void checkMergePredecessors(Operation *op) {
void checkDataflowConversion(handshake::FuncOp f) {
for (Block &block : f) {
for (Operation &op : block) {
if (!dyn_cast<mlir::CondBranchOp>(op) && !dyn_cast<mlir::BranchOp>(op) &&
!dyn_cast<mlir::LoadOp>(op) && !dyn_cast<mlir::ConstantOp>(op)) {
if (!isa<mlir::CondBranchOp, mlir::BranchOp, mlir::LoadOp,
mlir::ConstantOp>(op)) {
if (op.getNumResults() > 0) {
for (auto result : op.getResults()) {
checkUseCount(&op, result);
@ -828,9 +827,7 @@ Value getOpMemRef(Operation *op) {
op->emitError("Unknown Op type");
}
bool isMemoryOp(Operation *op) {
return (dyn_cast<mlir::LoadOp>(op) || dyn_cast<mlir::StoreOp>(op));
}
bool isMemoryOp(Operation *op) { return isa<mlir::LoadOp, mlir::StoreOp>(op); }
typedef llvm::MapVector<Value, vector<Operation *>> MemRefToMemoryAccessOp;
@ -854,28 +851,25 @@ MemRefToMemoryAccessOp replaceMemoryOps(handshake::FuncOp f,
Value memref = getOpMemRef(&op);
Operation *newOp = nullptr;
if (dyn_cast<mlir::LoadOp>(op)) {
if (mlir::LoadOp loadOp = dyn_cast<mlir::LoadOp>(op)) {
// Get operands which correspond to address indices
// This will add all operands except alloc
SmallVector<Value, 8> operands(
dyn_cast<mlir::LoadOp>(op).getIndices());
SmallVector<Value, 8> operands(loadOp.getIndices());
newOp =
rewriter.create<handshake::LoadOp>(op.getLoc(), memref, operands);
op.getResult(0).replaceAllUsesWith(newOp->getResult(0));
} else {
assert(dyn_cast<mlir::StoreOp>(op));
assert(isa<mlir::StoreOp>(op));
// Get operands which correspond to address indices
// This will add all operands except alloc and data
SmallVector<Value, 8> operands(
dyn_cast<mlir::StoreOp>(op).getIndices());
mlir::StoreOp storeOp = dyn_cast<mlir::StoreOp>(op);
SmallVector<Value, 8> operands(storeOp.getIndices());
// Create new op where operands are store data and address indices
newOp = rewriter.create<handshake::StoreOp>(
op.getLoc(), dyn_cast<mlir::StoreOp>(op).getValueToStore(),
operands);
op.getLoc(), storeOp.getValueToStore(), operands);
}
MemRefOps[memref].push_back(newOp);
@ -912,18 +906,17 @@ vector<Block *> getOperationBlocks(vector<Operation *> ops) {
SmallVector<Value, 8> getResultsToMemory(Operation *op) {
// Get load/store results which are given as inputs to MemoryOp
if (dyn_cast<handshake::LoadOp>(op)) {
if (handshake::LoadOp loadOp = dyn_cast<handshake::LoadOp>(op)) {
// For load, get all address outputs/indices
// (load also has one data output which goes to successor operation)
SmallVector<Value, 8> results(
dyn_cast<handshake::LoadOp>(op).addressResults());
SmallVector<Value, 8> results(loadOp.addressResults());
return results;
} else {
// For store, all outputs (data and address indices) go to memory
assert(dyn_cast<handshake::StoreOp>(op));
SmallVector<Value, 8> results(
dyn_cast<handshake::StoreOp>(op).getResults());
handshake::StoreOp storeOp = dyn_cast<handshake::StoreOp>(op);
SmallVector<Value, 8> results(storeOp.getResults());
return results;
}
}
@ -942,8 +935,7 @@ void addMemOpForks(handshake::FuncOp f, ConversionPatternRewriter &rewriter) {
for (Block &block : f) {
for (Operation &op : block) {
if (dyn_cast<MemoryOp>(op) || dyn_cast<StartOp>(op) ||
dyn_cast<ControlMergeOp>(op)) {
if (isa<MemoryOp, StartOp, ControlMergeOp>(op)) {
for (auto result : op.getResults()) {
// If there is a result and it is used more than once
if (!result.use_empty() && !result.hasOneUse())
@ -959,7 +951,7 @@ void removeAllocOps(handshake::FuncOp f, ConversionPatternRewriter &rewriter) {
for (Block &block : f)
for (Operation &op : block) {
if (dyn_cast<AllocOp>(op)) {
if (isa<AllocOp>(op)) {
assert(op.getResult(0).hasOneUse());
for (auto &u : op.getResult(0).getUses()) {
Operation *useOp = u.getOwner();
@ -980,9 +972,9 @@ void removeRedundantSinks(handshake::FuncOp f,
for (Block &block : f)
for (Operation &op : block) {
if (dyn_cast<SinkOp>(op))
if (isa<SinkOp>(op))
if (!op.getOperand(0).hasOneUse() ||
dyn_cast<AllocOp>(op.getOperand(0).getDefiningOp()))
isa<AllocOp>(op.getOperand(0).getDefiningOp()))
redundantSinks.push_back(&op);
}
for (unsigned i = 0, e = redundantSinks.size(); i != e; ++i) {
@ -1014,7 +1006,7 @@ void addJoinOps(ConversionPatternRewriter &rewriter,
auto srcOp = val.getDefiningOp();
// Insert only single join per block
if (!dyn_cast<JoinOp>(srcOp)) {
if (!isa<JoinOp>(srcOp)) {
rewriter.setInsertionPointAfter(srcOp);
Operation *newOp = rewriter.create<JoinOp>(srcOp->getLoc(), val);
for (auto &u : val.getUses())
@ -1051,7 +1043,7 @@ void setLoadDataInputs(vector<Operation *> memOps, Operation *memOp) {
// Set memory outputs as load input data
int ld_count = 0;
for (auto *op : memOps) {
if (dyn_cast<handshake::LoadOp>(op))
if (isa<handshake::LoadOp>(op))
addValueToOperands(op, memOp->getResult(ld_count++));
}
}
@ -1064,7 +1056,7 @@ void setJoinControlInputs(vector<Operation *> memOps, Operation *memOp,
auto *op = memOps[i];
Value val = getBlockControlValue(op->getBlock());
auto srcOp = val.getDefiningOp();
if (!dyn_cast<JoinOp>(srcOp)) {
if (!isa<JoinOp, StartOp>(srcOp)) {
srcOp->emitError("Op expected to be a JoinOp or StartOp");
}
addValueToOperands(srcOp, memOp->getResult(offset + cntrlInd[i]));
@ -1090,8 +1082,7 @@ void setMemOpControlInputs(ConversionPatternRewriter &rewriter,
Block *predBlock = predOp->getBlock();
if (currBlock == predBlock)
// Any dependency but RARs
if (!(dyn_cast<handshake::LoadOp>(currOp) &&
dyn_cast<handshake::LoadOp>(predOp)))
if (!(isa<handshake::LoadOp>(currOp) && isa<handshake::LoadOp>(predOp)))
// cntrlInd maps memOps index to correct control output index
controlOperands.push_back(memOp->getResult(offset + cntrlInd[j]));
}
@ -1144,7 +1135,7 @@ void connectToMemory(handshake::FuncOp f, MemRefToMemoryAccessOp MemRefOps,
int ind = 0;
for (int i = 0, e = memory.second.size(); i < e; ++i) {
auto *op = memory.second[i];
if (dyn_cast<handshake::StoreOp>(op)) {
if (isa<handshake::StoreOp>(op)) {
SmallVector<Value, 8> results = getResultsToMemory(op);
operands.insert(operands.end(), results.begin(), results.end());
newInd[i] = ind++;
@ -1155,7 +1146,7 @@ void connectToMemory(handshake::FuncOp f, MemRefToMemoryAccessOp MemRefOps,
for (int i = 0, e = memory.second.size(); i < e; ++i) {
auto *op = memory.second[i];
if (dyn_cast<handshake::LoadOp>(op)) {
if (isa<handshake::LoadOp>(op)) {
SmallVector<Value, 8> results = getResultsToMemory(op);
operands.insert(operands.end(), results.begin(), results.end());