mirror of https://github.com/llvm/circt.git
[StandardToHandshake] Actually fail conversion when raising errors (#2267)
This commit is contained in:
parent
7cf4f5e365
commit
9bc200c6ef
|
@ -682,6 +682,7 @@ def LoadOp : Handshake_Op<"load", [
|
||||||
let builders = [OpBuilder<(ins "Value":$memref, "ArrayRef<Value>":$indices)>];
|
let builders = [OpBuilder<(ins "Value":$memref, "ArrayRef<Value>":$indices)>];
|
||||||
let printer = "return ::print$cppClass(p, *this);";
|
let printer = "return ::print$cppClass(p, *this);";
|
||||||
let parser = "return ::parse$cppClass(parser, result);";
|
let parser = "return ::parse$cppClass(parser, result);";
|
||||||
|
let verifier = "return ::verify$cppClass(*this);";
|
||||||
}
|
}
|
||||||
|
|
||||||
def StoreOp : Handshake_Op<"store", [
|
def StoreOp : Handshake_Op<"store", [
|
||||||
|
@ -715,6 +716,7 @@ def StoreOp : Handshake_Op<"store", [
|
||||||
[OpBuilder<(ins "Value":$valueToStore, "ArrayRef<Value>":$indices)>];
|
[OpBuilder<(ins "Value":$valueToStore, "ArrayRef<Value>":$indices)>];
|
||||||
let printer = "return ::print$cppClass(p, *this);";
|
let printer = "return ::print$cppClass(p, *this);";
|
||||||
let parser = "return ::parse$cppClass(parser, result);";
|
let parser = "return ::parse$cppClass(parser, result);";
|
||||||
|
let verifier = "return ::verify$cppClass(*this);";
|
||||||
}
|
}
|
||||||
|
|
||||||
def JoinOp : Handshake_Op<"join", [
|
def JoinOp : Handshake_Op<"join", [
|
||||||
|
|
|
@ -576,7 +576,7 @@ LogicalResult connectConstantsToControl(handshake::FuncOp f,
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
void checkUseCount(Operation *op, Value res) {
|
LogicalResult checkUseCount(Operation *op, Value res) {
|
||||||
// Checks if every result has single use
|
// Checks if every result has single use
|
||||||
if (!res.hasOneUse()) {
|
if (!res.hasOneUse()) {
|
||||||
int i = 0;
|
int i = 0;
|
||||||
|
@ -584,24 +584,25 @@ void checkUseCount(Operation *op, Value res) {
|
||||||
user->emitWarning("user here");
|
user->emitWarning("user here");
|
||||||
i++;
|
i++;
|
||||||
}
|
}
|
||||||
op->emitError("every result must have exactly one user, but had ") << i;
|
return op->emitOpError("every result must have exactly one user, but had ")
|
||||||
|
<< i;
|
||||||
}
|
}
|
||||||
return;
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Checks if op successors are in appropriate blocks
|
// Checks if op successors are in appropriate blocks
|
||||||
void checkSuccessorBlocks(Operation *op, Value res) {
|
LogicalResult checkSuccessorBlocks(Operation *op, Value res) {
|
||||||
for (auto &u : res.getUses()) {
|
for (auto &u : res.getUses()) {
|
||||||
Operation *succOp = u.getOwner();
|
Operation *succOp = u.getOwner();
|
||||||
// Non-branch ops: succesors must be in same block
|
// Non-branch ops: succesors must be in same block
|
||||||
if (!(isa<handshake::ConditionalBranchOp>(op) ||
|
if (!(isa<handshake::ConditionalBranchOp>(op) ||
|
||||||
isa<handshake::BranchOp>(op))) {
|
isa<handshake::BranchOp>(op))) {
|
||||||
if (op->getBlock() != succOp->getBlock())
|
if (op->getBlock() != succOp->getBlock())
|
||||||
op->emitError("cannot be block live-out");
|
return op->emitOpError("cannot be block live-out");
|
||||||
} else {
|
} else {
|
||||||
// Branch ops: must have successor per successor block
|
// Branch ops: must have successor per successor block
|
||||||
if (op->getBlock()->getNumSuccessors() != op->getNumResults())
|
if (op->getBlock()->getNumSuccessors() != op->getNumResults())
|
||||||
op->emitError("incorrect successor count");
|
return op->emitOpError("incorrect successor count");
|
||||||
bool found = false;
|
bool found = false;
|
||||||
for (int i = 0, e = op->getBlock()->getNumSuccessors(); i < e; ++i) {
|
for (int i = 0, e = op->getBlock()->getNumSuccessors(); i < e; ++i) {
|
||||||
Block *succ = op->getBlock()->getSuccessor(i);
|
Block *succ = op->getBlock()->getSuccessor(i);
|
||||||
|
@ -609,25 +610,26 @@ void checkSuccessorBlocks(Operation *op, Value res) {
|
||||||
found = true;
|
found = true;
|
||||||
}
|
}
|
||||||
if (!found)
|
if (!found)
|
||||||
op->emitError("branch successor in incorrect block");
|
return op->emitOpError("branch successor in incorrect block");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return;
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Checks if merge predecessors are in appropriate block
|
// Checks if merge predecessors are in appropriate block
|
||||||
void checkMergePredecessors(MergeLikeOpInterface mergeOp) {
|
LogicalResult checkMergePredecessors(MergeLikeOpInterface mergeOp) {
|
||||||
Block *block = mergeOp->getBlock();
|
Block *block = mergeOp->getBlock();
|
||||||
unsigned operand_count = mergeOp.dataOperands().size();
|
unsigned operand_count = mergeOp.dataOperands().size();
|
||||||
|
|
||||||
// Merges in entry block have single predecessor (argument)
|
// Merges in entry block have single predecessor (argument)
|
||||||
if (block->isEntryBlock()) {
|
if (block->isEntryBlock()) {
|
||||||
if (operand_count != 1)
|
if (operand_count != 1)
|
||||||
mergeOp->emitError("merge operations in entry block must have a ")
|
return mergeOp->emitOpError(
|
||||||
|
"merge operations in entry block must have a ")
|
||||||
<< "single predecessor";
|
<< "single predecessor";
|
||||||
} else {
|
} else {
|
||||||
if (operand_count > getBlockPredecessorCount(block))
|
if (operand_count > getBlockPredecessorCount(block))
|
||||||
mergeOp->emitError("merge operation has ")
|
return mergeOp->emitOpError("merge operation has ")
|
||||||
<< operand_count << " data inputs, but only "
|
<< operand_count << " data inputs, but only "
|
||||||
<< getBlockPredecessorCount(block) << " predecessor blocks";
|
<< getBlockPredecessorCount(block) << " predecessor blocks";
|
||||||
}
|
}
|
||||||
|
@ -643,19 +645,19 @@ void checkMergePredecessors(MergeLikeOpInterface mergeOp) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!found)
|
if (!found)
|
||||||
mergeOp->emitError("missing predecessor from predecessor block");
|
return mergeOp->emitOpError("missing predecessor from predecessor block");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Select operand must come from same block
|
// Select operand must come from same block
|
||||||
if (auto muxOp = dyn_cast<MuxOp>(mergeOp.getOperation())) {
|
if (auto muxOp = dyn_cast<MuxOp>(mergeOp.getOperation())) {
|
||||||
auto *operand = muxOp.selectOperand().getDefiningOp();
|
auto *operand = muxOp.selectOperand().getDefiningOp();
|
||||||
if (operand->getBlock() != block)
|
if (operand->getBlock() != block)
|
||||||
mergeOp->emitError("mux select operand must be from same block");
|
return mergeOp->emitOpError("mux select operand must be from same block");
|
||||||
}
|
}
|
||||||
return;
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
void checkDataflowConversion(handshake::FuncOp f) {
|
LogicalResult checkDataflowConversion(handshake::FuncOp f) {
|
||||||
for (Operation &op : f.getOps()) {
|
for (Operation &op : f.getOps()) {
|
||||||
if (isa<mlir::CondBranchOp, mlir::BranchOp, memref::LoadOp,
|
if (isa<mlir::CondBranchOp, mlir::BranchOp, memref::LoadOp,
|
||||||
arith::ConstantOp, mlir::AffineReadOpInterface, mlir::AffineForOp>(
|
arith::ConstantOp, mlir::AffineReadOpInterface, mlir::AffineForOp>(
|
||||||
|
@ -664,13 +666,16 @@ void checkDataflowConversion(handshake::FuncOp f) {
|
||||||
|
|
||||||
if (op.getNumResults() > 0) {
|
if (op.getNumResults() > 0) {
|
||||||
for (auto result : op.getResults()) {
|
for (auto result : op.getResults()) {
|
||||||
checkUseCount(&op, result);
|
if (checkUseCount(&op, result).failed() ||
|
||||||
checkSuccessorBlocks(&op, result);
|
checkSuccessorBlocks(&op, result).failed())
|
||||||
|
return failure();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (auto mergeOp = dyn_cast<MergeLikeOpInterface>(op); mergeOp)
|
if (auto mergeOp = dyn_cast<MergeLikeOpInterface>(op); mergeOp)
|
||||||
checkMergePredecessors(mergeOp);
|
if (checkMergePredecessors(mergeOp).failed())
|
||||||
|
return failure();
|
||||||
}
|
}
|
||||||
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
Value getBlockControlValue(Block *block) {
|
Value getBlockControlValue(Block *block) {
|
||||||
|
@ -688,17 +693,19 @@ Value getBlockControlValue(Block *block) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
Value getOpMemRef(Operation *op) {
|
LogicalResult getOpMemRef(Operation *op, Value &out) {
|
||||||
|
out = Value();
|
||||||
if (auto memOp = dyn_cast<memref::LoadOp>(op))
|
if (auto memOp = dyn_cast<memref::LoadOp>(op))
|
||||||
return memOp.getMemRef();
|
out = memOp.getMemRef();
|
||||||
if (auto memOp = dyn_cast<memref::StoreOp>(op))
|
else if (auto memOp = dyn_cast<memref::StoreOp>(op))
|
||||||
return memOp.getMemRef();
|
out = memOp.getMemRef();
|
||||||
if (isa<mlir::AffineReadOpInterface, mlir::AffineWriteOpInterface>(op)) {
|
else if (isa<mlir::AffineReadOpInterface, mlir::AffineWriteOpInterface>(op)) {
|
||||||
MemRefAccess access(op);
|
MemRefAccess access(op);
|
||||||
return access.memref;
|
out = access.memref;
|
||||||
}
|
}
|
||||||
op->emitError("Unknown Op type");
|
if (out != Value())
|
||||||
return Value();
|
return success();
|
||||||
|
return op->emitOpError("Unknown Op type");
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isMemoryOp(Operation *op) {
|
bool isMemoryOp(Operation *op) {
|
||||||
|
@ -712,21 +719,22 @@ typedef llvm::MapVector<Value, std::vector<Operation *>> MemRefToMemoryAccessOp;
|
||||||
// ops which connect to memory/LSQ). Returns a map with an ordered
|
// ops which connect to memory/LSQ). Returns a map with an ordered
|
||||||
// list of new ops corresponding to each memref. Later, we instantiate
|
// list of new ops corresponding to each memref. Later, we instantiate
|
||||||
// a memory node for each memref and connect it to its load/store ops
|
// a memory node for each memref and connect it to its load/store ops
|
||||||
MemRefToMemoryAccessOp replaceMemoryOps(handshake::FuncOp f,
|
LogicalResult replaceMemoryOps(handshake::FuncOp f,
|
||||||
ConversionPatternRewriter &rewriter) {
|
ConversionPatternRewriter &rewriter,
|
||||||
// Map from original memref to new load/store operations.
|
MemRefToMemoryAccessOp &memRefOps) {
|
||||||
MemRefToMemoryAccessOp MemRefOps;
|
|
||||||
|
|
||||||
std::vector<Operation *> opsToErase;
|
std::vector<Operation *> opsToErase;
|
||||||
|
|
||||||
// Replace load and store ops with the corresponding handshake ops
|
// Replace load and store ops with the corresponding handshake ops
|
||||||
// Need to traverse ops in blocks to store them in MemRefOps in program order
|
// Need to traverse ops in blocks to store them in memRefOps in program order
|
||||||
for (Operation &op : f.getOps()) {
|
for (Operation &op : f.getOps()) {
|
||||||
if (!isMemoryOp(&op))
|
if (!isMemoryOp(&op))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
rewriter.setInsertionPoint(&op);
|
rewriter.setInsertionPoint(&op);
|
||||||
Value memref = getOpMemRef(&op);
|
Value memref;
|
||||||
|
if (getOpMemRef(&op, memref).failed())
|
||||||
|
return failure();
|
||||||
Operation *newOp = nullptr;
|
Operation *newOp = nullptr;
|
||||||
|
|
||||||
llvm::TypeSwitch<Operation *>(&op)
|
llvm::TypeSwitch<Operation *>(&op)
|
||||||
|
@ -784,10 +792,10 @@ MemRefToMemoryAccessOp replaceMemoryOps(handshake::FuncOp f,
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.Default([&](auto) {
|
.Default([&](auto) {
|
||||||
op.emitError("Load/store operation cannot be handled.");
|
op.emitOpError("Load/store operation cannot be handled.");
|
||||||
});
|
});
|
||||||
|
|
||||||
MemRefOps[memref].push_back(newOp);
|
memRefOps[memref].push_back(newOp);
|
||||||
opsToErase.push_back(&op);
|
opsToErase.push_back(&op);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -801,7 +809,7 @@ MemRefToMemoryAccessOp replaceMemoryOps(handshake::FuncOp f,
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
return MemRefOps;
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<Block *> getOperationBlocks(ArrayRef<Operation *> ops) {
|
std::vector<Block *> getOperationBlocks(ArrayRef<Operation *> ops) {
|
||||||
|
@ -949,8 +957,9 @@ void setLoadDataInputs(ArrayRef<Operation *> memOps, Operation *memOp) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void setJoinControlInputs(ArrayRef<Operation *> memOps, Operation *memOp,
|
LogicalResult setJoinControlInputs(ArrayRef<Operation *> memOps,
|
||||||
int offset, ArrayRef<int> cntrlInd) {
|
Operation *memOp, int offset,
|
||||||
|
ArrayRef<int> cntrlInd) {
|
||||||
// Connect all memory ops to the join of that block (ensures that all mem
|
// Connect all memory ops to the join of that block (ensures that all mem
|
||||||
// ops terminate before a new block starts)
|
// ops terminate before a new block starts)
|
||||||
for (int i = 0, e = memOps.size(); i < e; ++i) {
|
for (int i = 0, e = memOps.size(); i < e; ++i) {
|
||||||
|
@ -958,10 +967,11 @@ void setJoinControlInputs(ArrayRef<Operation *> memOps, Operation *memOp,
|
||||||
Value val = getBlockControlValue(op->getBlock());
|
Value val = getBlockControlValue(op->getBlock());
|
||||||
auto srcOp = val.getDefiningOp();
|
auto srcOp = val.getDefiningOp();
|
||||||
if (!isa<JoinOp, StartOp>(srcOp)) {
|
if (!isa<JoinOp, StartOp>(srcOp)) {
|
||||||
srcOp->emitError("Op expected to be a JoinOp or StartOp");
|
return srcOp->emitOpError("Op expected to be a JoinOp or StartOp");
|
||||||
}
|
}
|
||||||
addValueToOperands(srcOp, memOp->getResult(offset + cntrlInd[i]));
|
addValueToOperands(srcOp, memOp->getResult(offset + cntrlInd[i]));
|
||||||
}
|
}
|
||||||
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
void setMemOpControlInputs(ConversionPatternRewriter &rewriter,
|
void setMemOpControlInputs(ConversionPatternRewriter &rewriter,
|
||||||
|
@ -1098,9 +1108,12 @@ LogicalResult connectToMemory(handshake::FuncOp f,
|
||||||
// user-determined)
|
// user-determined)
|
||||||
bool control = true;
|
bool control = true;
|
||||||
|
|
||||||
if (control)
|
if (control) {
|
||||||
setJoinControlInputs(memory.second, newOp, ld_count, newInd);
|
if (setJoinControlInputs(memory.second, newOp, ld_count, newInd)
|
||||||
else {
|
.failed())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
} else {
|
||||||
for (int i = 0, e = cntrl_count; i < e; ++i) {
|
for (int i = 0, e = cntrl_count; i < e; ++i) {
|
||||||
rewriter.setInsertionPointAfter(newOp);
|
rewriter.setInsertionPointAfter(newOp);
|
||||||
rewriter.create<SinkOp>(newOp->getLoc(),
|
rewriter.create<SinkOp>(newOp->getLoc(),
|
||||||
|
@ -1440,11 +1453,11 @@ LogicalResult lowerFuncOp(mlir::FuncOp funcOp, MLIRContext *ctx) {
|
||||||
return newFuncOp.emitOpError("failed to rewrite Affine loops");
|
return newFuncOp.emitOpError("failed to rewrite Affine loops");
|
||||||
|
|
||||||
// Perform dataflow conversion
|
// Perform dataflow conversion
|
||||||
MemRefToMemoryAccessOp MemOps;
|
MemRefToMemoryAccessOp memOps;
|
||||||
returnOnError(partiallyLowerFuncOp<handshake::FuncOp>(
|
returnOnError(partiallyLowerFuncOp<handshake::FuncOp>(
|
||||||
[&](handshake::FuncOp nfo, ConversionPatternRewriter &rewriter) {
|
[&](handshake::FuncOp nfo, ConversionPatternRewriter &rewriter) {
|
||||||
MemOps = replaceMemoryOps(nfo, rewriter);
|
// Map from original memref to new load/store operations.
|
||||||
return success();
|
return replaceMemoryOps(nfo, rewriter, memOps);
|
||||||
},
|
},
|
||||||
ctx, newFuncOp));
|
ctx, newFuncOp));
|
||||||
|
|
||||||
|
@ -1462,12 +1475,12 @@ LogicalResult lowerFuncOp(mlir::FuncOp funcOp, MLIRContext *ctx) {
|
||||||
connectConstantsToControl, ctx, newFuncOp));
|
connectConstantsToControl, ctx, newFuncOp));
|
||||||
returnOnError(
|
returnOnError(
|
||||||
partiallyLowerFuncOp<handshake::FuncOp>(addForkOps, ctx, newFuncOp));
|
partiallyLowerFuncOp<handshake::FuncOp>(addForkOps, ctx, newFuncOp));
|
||||||
checkDataflowConversion(newFuncOp);
|
returnOnError(checkDataflowConversion(newFuncOp));
|
||||||
|
|
||||||
bool lsq = false;
|
bool lsq = false;
|
||||||
returnOnError(partiallyLowerFuncOp<handshake::FuncOp>(
|
returnOnError(partiallyLowerFuncOp<handshake::FuncOp>(
|
||||||
[&](handshake::FuncOp nfo, ConversionPatternRewriter &rewriter) {
|
[&](handshake::FuncOp nfo, ConversionPatternRewriter &rewriter) {
|
||||||
return connectToMemory(nfo, MemOps, lsq, rewriter);
|
return connectToMemory(nfo, memOps, lsq, rewriter);
|
||||||
},
|
},
|
||||||
ctx, newFuncOp));
|
ctx, newFuncOp));
|
||||||
|
|
||||||
|
@ -1609,8 +1622,7 @@ struct HandshakeInsertBufferPass
|
||||||
else if (strategy == "all")
|
else if (strategy == "all")
|
||||||
bufferAllStrategy();
|
bufferAllStrategy();
|
||||||
else {
|
else {
|
||||||
emitError(getOperation().getLoc())
|
getOperation().emitOpError() << "Unknown buffer strategy: " << strategy;
|
||||||
<< "Unknown buffer strategy: " << strategy;
|
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -1629,8 +1641,10 @@ struct HandshakeDataflowPass
|
||||||
ModuleOp m = getOperation();
|
ModuleOp m = getOperation();
|
||||||
|
|
||||||
for (auto funcOp : llvm::make_early_inc_range(m.getOps<mlir::FuncOp>())) {
|
for (auto funcOp : llvm::make_early_inc_range(m.getOps<mlir::FuncOp>())) {
|
||||||
if (failed(lowerFuncOp(funcOp, &getContext())))
|
if (failed(lowerFuncOp(funcOp, &getContext()))) {
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Legalize the resulting regions, which can have no basic blocks.
|
// Legalize the resulting regions, which can have no basic blocks.
|
||||||
|
|
|
@ -1116,6 +1116,18 @@ std::string handshake::StoreOp::getOperandName(unsigned int idx) {
|
||||||
return opName;
|
return opName;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename TMemoryOp>
|
||||||
|
static LogicalResult verifyMemoryAccessOp(TMemoryOp op) {
|
||||||
|
if (op.addresses().size() == 0)
|
||||||
|
return op.emitOpError() << "No addresses were specified";
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
static LogicalResult verifyLoadOp(handshake::LoadOp op) {
|
||||||
|
return verifyMemoryAccessOp(op);
|
||||||
|
}
|
||||||
|
|
||||||
std::string handshake::StoreOp::getResultName(unsigned int idx) {
|
std::string handshake::StoreOp::getResultName(unsigned int idx) {
|
||||||
std::string resName;
|
std::string resName;
|
||||||
if (idx == 0)
|
if (idx == 0)
|
||||||
|
@ -1141,6 +1153,10 @@ void handshake::StoreOp::build(OpBuilder &builder, OperationState &result,
|
||||||
result.types.append(indices.size(), builder.getIndexType());
|
result.types.append(indices.size(), builder.getIndexType());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static LogicalResult verifyStoreOp(handshake::StoreOp op) {
|
||||||
|
return verifyMemoryAccessOp(op);
|
||||||
|
}
|
||||||
|
|
||||||
static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
|
static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
|
||||||
return parseMemoryAccessOp(parser, result);
|
return parseMemoryAccessOp(parser, result);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue