[StandardToHandshake] Actually fail conversion when raising errors (#2267)

This commit is contained in:
Morten Borup Petersen 2021-12-02 10:48:31 +00:00 committed by GitHub
parent 7cf4f5e365
commit 9bc200c6ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 84 additions and 52 deletions

View File

@ -682,6 +682,7 @@ def LoadOp : Handshake_Op<"load", [
let builders = [OpBuilder<(ins "Value":$memref, "ArrayRef<Value>":$indices)>]; let builders = [OpBuilder<(ins "Value":$memref, "ArrayRef<Value>":$indices)>];
let printer = "return ::print$cppClass(p, *this);"; let printer = "return ::print$cppClass(p, *this);";
let parser = "return ::parse$cppClass(parser, result);"; let parser = "return ::parse$cppClass(parser, result);";
let verifier = "return ::verify$cppClass(*this);";
} }
def StoreOp : Handshake_Op<"store", [ def StoreOp : Handshake_Op<"store", [
@ -715,6 +716,7 @@ def StoreOp : Handshake_Op<"store", [
[OpBuilder<(ins "Value":$valueToStore, "ArrayRef<Value>":$indices)>]; [OpBuilder<(ins "Value":$valueToStore, "ArrayRef<Value>":$indices)>];
let printer = "return ::print$cppClass(p, *this);"; let printer = "return ::print$cppClass(p, *this);";
let parser = "return ::parse$cppClass(parser, result);"; let parser = "return ::parse$cppClass(parser, result);";
let verifier = "return ::verify$cppClass(*this);";
} }
def JoinOp : Handshake_Op<"join", [ def JoinOp : Handshake_Op<"join", [

View File

@ -576,7 +576,7 @@ LogicalResult connectConstantsToControl(handshake::FuncOp f,
return success(); return success();
} }
void checkUseCount(Operation *op, Value res) { LogicalResult checkUseCount(Operation *op, Value res) {
// Checks if every result has single use // Checks if every result has single use
if (!res.hasOneUse()) { if (!res.hasOneUse()) {
int i = 0; int i = 0;
@ -584,24 +584,25 @@ void checkUseCount(Operation *op, Value res) {
user->emitWarning("user here"); user->emitWarning("user here");
i++; i++;
} }
op->emitError("every result must have exactly one user, but had ") << i; return op->emitOpError("every result must have exactly one user, but had ")
<< i;
} }
return; return success();
} }
// Checks if op successors are in appropriate blocks // Checks if op successors are in appropriate blocks
void checkSuccessorBlocks(Operation *op, Value res) { LogicalResult checkSuccessorBlocks(Operation *op, Value res) {
for (auto &u : res.getUses()) { for (auto &u : res.getUses()) {
Operation *succOp = u.getOwner(); Operation *succOp = u.getOwner();
// Non-branch ops: succesors must be in same block // Non-branch ops: succesors must be in same block
if (!(isa<handshake::ConditionalBranchOp>(op) || if (!(isa<handshake::ConditionalBranchOp>(op) ||
isa<handshake::BranchOp>(op))) { isa<handshake::BranchOp>(op))) {
if (op->getBlock() != succOp->getBlock()) if (op->getBlock() != succOp->getBlock())
op->emitError("cannot be block live-out"); return op->emitOpError("cannot be block live-out");
} else { } else {
// Branch ops: must have successor per successor block // Branch ops: must have successor per successor block
if (op->getBlock()->getNumSuccessors() != op->getNumResults()) if (op->getBlock()->getNumSuccessors() != op->getNumResults())
op->emitError("incorrect successor count"); return op->emitOpError("incorrect successor count");
bool found = false; bool found = false;
for (int i = 0, e = op->getBlock()->getNumSuccessors(); i < e; ++i) { for (int i = 0, e = op->getBlock()->getNumSuccessors(); i < e; ++i) {
Block *succ = op->getBlock()->getSuccessor(i); Block *succ = op->getBlock()->getSuccessor(i);
@ -609,25 +610,26 @@ void checkSuccessorBlocks(Operation *op, Value res) {
found = true; found = true;
} }
if (!found) if (!found)
op->emitError("branch successor in incorrect block"); return op->emitOpError("branch successor in incorrect block");
} }
} }
return; return success();
} }
// Checks if merge predecessors are in appropriate block // Checks if merge predecessors are in appropriate block
void checkMergePredecessors(MergeLikeOpInterface mergeOp) { LogicalResult checkMergePredecessors(MergeLikeOpInterface mergeOp) {
Block *block = mergeOp->getBlock(); Block *block = mergeOp->getBlock();
unsigned operand_count = mergeOp.dataOperands().size(); unsigned operand_count = mergeOp.dataOperands().size();
// Merges in entry block have single predecessor (argument) // Merges in entry block have single predecessor (argument)
if (block->isEntryBlock()) { if (block->isEntryBlock()) {
if (operand_count != 1) if (operand_count != 1)
mergeOp->emitError("merge operations in entry block must have a ") return mergeOp->emitOpError(
"merge operations in entry block must have a ")
<< "single predecessor"; << "single predecessor";
} else { } else {
if (operand_count > getBlockPredecessorCount(block)) if (operand_count > getBlockPredecessorCount(block))
mergeOp->emitError("merge operation has ") return mergeOp->emitOpError("merge operation has ")
<< operand_count << " data inputs, but only " << operand_count << " data inputs, but only "
<< getBlockPredecessorCount(block) << " predecessor blocks"; << getBlockPredecessorCount(block) << " predecessor blocks";
} }
@ -643,19 +645,19 @@ void checkMergePredecessors(MergeLikeOpInterface mergeOp) {
} }
} }
if (!found) if (!found)
mergeOp->emitError("missing predecessor from predecessor block"); return mergeOp->emitOpError("missing predecessor from predecessor block");
} }
// Select operand must come from same block // Select operand must come from same block
if (auto muxOp = dyn_cast<MuxOp>(mergeOp.getOperation())) { if (auto muxOp = dyn_cast<MuxOp>(mergeOp.getOperation())) {
auto *operand = muxOp.selectOperand().getDefiningOp(); auto *operand = muxOp.selectOperand().getDefiningOp();
if (operand->getBlock() != block) if (operand->getBlock() != block)
mergeOp->emitError("mux select operand must be from same block"); return mergeOp->emitOpError("mux select operand must be from same block");
} }
return; return success();
} }
void checkDataflowConversion(handshake::FuncOp f) { LogicalResult checkDataflowConversion(handshake::FuncOp f) {
for (Operation &op : f.getOps()) { for (Operation &op : f.getOps()) {
if (isa<mlir::CondBranchOp, mlir::BranchOp, memref::LoadOp, if (isa<mlir::CondBranchOp, mlir::BranchOp, memref::LoadOp,
arith::ConstantOp, mlir::AffineReadOpInterface, mlir::AffineForOp>( arith::ConstantOp, mlir::AffineReadOpInterface, mlir::AffineForOp>(
@ -664,13 +666,16 @@ void checkDataflowConversion(handshake::FuncOp f) {
if (op.getNumResults() > 0) { if (op.getNumResults() > 0) {
for (auto result : op.getResults()) { for (auto result : op.getResults()) {
checkUseCount(&op, result); if (checkUseCount(&op, result).failed() ||
checkSuccessorBlocks(&op, result); checkSuccessorBlocks(&op, result).failed())
return failure();
} }
} }
if (auto mergeOp = dyn_cast<MergeLikeOpInterface>(op); mergeOp) if (auto mergeOp = dyn_cast<MergeLikeOpInterface>(op); mergeOp)
checkMergePredecessors(mergeOp); if (checkMergePredecessors(mergeOp).failed())
return failure();
} }
return success();
} }
Value getBlockControlValue(Block *block) { Value getBlockControlValue(Block *block) {
@ -688,17 +693,19 @@ Value getBlockControlValue(Block *block) {
return nullptr; return nullptr;
} }
Value getOpMemRef(Operation *op) { LogicalResult getOpMemRef(Operation *op, Value &out) {
out = Value();
if (auto memOp = dyn_cast<memref::LoadOp>(op)) if (auto memOp = dyn_cast<memref::LoadOp>(op))
return memOp.getMemRef(); out = memOp.getMemRef();
if (auto memOp = dyn_cast<memref::StoreOp>(op)) else if (auto memOp = dyn_cast<memref::StoreOp>(op))
return memOp.getMemRef(); out = memOp.getMemRef();
if (isa<mlir::AffineReadOpInterface, mlir::AffineWriteOpInterface>(op)) { else if (isa<mlir::AffineReadOpInterface, mlir::AffineWriteOpInterface>(op)) {
MemRefAccess access(op); MemRefAccess access(op);
return access.memref; out = access.memref;
} }
op->emitError("Unknown Op type"); if (out != Value())
return Value(); return success();
return op->emitOpError("Unknown Op type");
} }
bool isMemoryOp(Operation *op) { bool isMemoryOp(Operation *op) {
@ -712,21 +719,22 @@ typedef llvm::MapVector<Value, std::vector<Operation *>> MemRefToMemoryAccessOp;
// ops which connect to memory/LSQ). Returns a map with an ordered // ops which connect to memory/LSQ). Returns a map with an ordered
// list of new ops corresponding to each memref. Later, we instantiate // list of new ops corresponding to each memref. Later, we instantiate
// a memory node for each memref and connect it to its load/store ops // a memory node for each memref and connect it to its load/store ops
MemRefToMemoryAccessOp replaceMemoryOps(handshake::FuncOp f, LogicalResult replaceMemoryOps(handshake::FuncOp f,
ConversionPatternRewriter &rewriter) { ConversionPatternRewriter &rewriter,
// Map from original memref to new load/store operations. MemRefToMemoryAccessOp &memRefOps) {
MemRefToMemoryAccessOp MemRefOps;
std::vector<Operation *> opsToErase; std::vector<Operation *> opsToErase;
// Replace load and store ops with the corresponding handshake ops // Replace load and store ops with the corresponding handshake ops
// Need to traverse ops in blocks to store them in MemRefOps in program order // Need to traverse ops in blocks to store them in memRefOps in program order
for (Operation &op : f.getOps()) { for (Operation &op : f.getOps()) {
if (!isMemoryOp(&op)) if (!isMemoryOp(&op))
continue; continue;
rewriter.setInsertionPoint(&op); rewriter.setInsertionPoint(&op);
Value memref = getOpMemRef(&op); Value memref;
if (getOpMemRef(&op, memref).failed())
return failure();
Operation *newOp = nullptr; Operation *newOp = nullptr;
llvm::TypeSwitch<Operation *>(&op) llvm::TypeSwitch<Operation *>(&op)
@ -784,10 +792,10 @@ MemRefToMemoryAccessOp replaceMemoryOps(handshake::FuncOp f,
} }
}) })
.Default([&](auto) { .Default([&](auto) {
op.emitError("Load/store operation cannot be handled."); op.emitOpError("Load/store operation cannot be handled.");
}); });
MemRefOps[memref].push_back(newOp); memRefOps[memref].push_back(newOp);
opsToErase.push_back(&op); opsToErase.push_back(&op);
} }
@ -801,7 +809,7 @@ MemRefToMemoryAccessOp replaceMemoryOps(handshake::FuncOp f,
rewriter.eraseOp(op); rewriter.eraseOp(op);
} }
return MemRefOps; return success();
} }
std::vector<Block *> getOperationBlocks(ArrayRef<Operation *> ops) { std::vector<Block *> getOperationBlocks(ArrayRef<Operation *> ops) {
@ -949,8 +957,9 @@ void setLoadDataInputs(ArrayRef<Operation *> memOps, Operation *memOp) {
} }
} }
void setJoinControlInputs(ArrayRef<Operation *> memOps, Operation *memOp, LogicalResult setJoinControlInputs(ArrayRef<Operation *> memOps,
int offset, ArrayRef<int> cntrlInd) { Operation *memOp, int offset,
ArrayRef<int> cntrlInd) {
// Connect all memory ops to the join of that block (ensures that all mem // Connect all memory ops to the join of that block (ensures that all mem
// ops terminate before a new block starts) // ops terminate before a new block starts)
for (int i = 0, e = memOps.size(); i < e; ++i) { for (int i = 0, e = memOps.size(); i < e; ++i) {
@ -958,10 +967,11 @@ void setJoinControlInputs(ArrayRef<Operation *> memOps, Operation *memOp,
Value val = getBlockControlValue(op->getBlock()); Value val = getBlockControlValue(op->getBlock());
auto srcOp = val.getDefiningOp(); auto srcOp = val.getDefiningOp();
if (!isa<JoinOp, StartOp>(srcOp)) { if (!isa<JoinOp, StartOp>(srcOp)) {
srcOp->emitError("Op expected to be a JoinOp or StartOp"); return srcOp->emitOpError("Op expected to be a JoinOp or StartOp");
} }
addValueToOperands(srcOp, memOp->getResult(offset + cntrlInd[i])); addValueToOperands(srcOp, memOp->getResult(offset + cntrlInd[i]));
} }
return success();
} }
void setMemOpControlInputs(ConversionPatternRewriter &rewriter, void setMemOpControlInputs(ConversionPatternRewriter &rewriter,
@ -1098,9 +1108,12 @@ LogicalResult connectToMemory(handshake::FuncOp f,
// user-determined) // user-determined)
bool control = true; bool control = true;
if (control) if (control) {
setJoinControlInputs(memory.second, newOp, ld_count, newInd); if (setJoinControlInputs(memory.second, newOp, ld_count, newInd)
else { .failed())
return failure();
} else {
for (int i = 0, e = cntrl_count; i < e; ++i) { for (int i = 0, e = cntrl_count; i < e; ++i) {
rewriter.setInsertionPointAfter(newOp); rewriter.setInsertionPointAfter(newOp);
rewriter.create<SinkOp>(newOp->getLoc(), rewriter.create<SinkOp>(newOp->getLoc(),
@ -1440,11 +1453,11 @@ LogicalResult lowerFuncOp(mlir::FuncOp funcOp, MLIRContext *ctx) {
return newFuncOp.emitOpError("failed to rewrite Affine loops"); return newFuncOp.emitOpError("failed to rewrite Affine loops");
// Perform dataflow conversion // Perform dataflow conversion
MemRefToMemoryAccessOp MemOps; MemRefToMemoryAccessOp memOps;
returnOnError(partiallyLowerFuncOp<handshake::FuncOp>( returnOnError(partiallyLowerFuncOp<handshake::FuncOp>(
[&](handshake::FuncOp nfo, ConversionPatternRewriter &rewriter) { [&](handshake::FuncOp nfo, ConversionPatternRewriter &rewriter) {
MemOps = replaceMemoryOps(nfo, rewriter); // Map from original memref to new load/store operations.
return success(); return replaceMemoryOps(nfo, rewriter, memOps);
}, },
ctx, newFuncOp)); ctx, newFuncOp));
@ -1462,12 +1475,12 @@ LogicalResult lowerFuncOp(mlir::FuncOp funcOp, MLIRContext *ctx) {
connectConstantsToControl, ctx, newFuncOp)); connectConstantsToControl, ctx, newFuncOp));
returnOnError( returnOnError(
partiallyLowerFuncOp<handshake::FuncOp>(addForkOps, ctx, newFuncOp)); partiallyLowerFuncOp<handshake::FuncOp>(addForkOps, ctx, newFuncOp));
checkDataflowConversion(newFuncOp); returnOnError(checkDataflowConversion(newFuncOp));
bool lsq = false; bool lsq = false;
returnOnError(partiallyLowerFuncOp<handshake::FuncOp>( returnOnError(partiallyLowerFuncOp<handshake::FuncOp>(
[&](handshake::FuncOp nfo, ConversionPatternRewriter &rewriter) { [&](handshake::FuncOp nfo, ConversionPatternRewriter &rewriter) {
return connectToMemory(nfo, MemOps, lsq, rewriter); return connectToMemory(nfo, memOps, lsq, rewriter);
}, },
ctx, newFuncOp)); ctx, newFuncOp));
@ -1609,8 +1622,7 @@ struct HandshakeInsertBufferPass
else if (strategy == "all") else if (strategy == "all")
bufferAllStrategy(); bufferAllStrategy();
else { else {
emitError(getOperation().getLoc()) getOperation().emitOpError() << "Unknown buffer strategy: " << strategy;
<< "Unknown buffer strategy: " << strategy;
signalPassFailure(); signalPassFailure();
return; return;
} }
@ -1629,8 +1641,10 @@ struct HandshakeDataflowPass
ModuleOp m = getOperation(); ModuleOp m = getOperation();
for (auto funcOp : llvm::make_early_inc_range(m.getOps<mlir::FuncOp>())) { for (auto funcOp : llvm::make_early_inc_range(m.getOps<mlir::FuncOp>())) {
if (failed(lowerFuncOp(funcOp, &getContext()))) if (failed(lowerFuncOp(funcOp, &getContext()))) {
signalPassFailure(); signalPassFailure();
return;
}
} }
// Legalize the resulting regions, which can have no basic blocks. // Legalize the resulting regions, which can have no basic blocks.

View File

@ -1116,6 +1116,18 @@ std::string handshake::StoreOp::getOperandName(unsigned int idx) {
return opName; return opName;
} }
template <typename TMemoryOp>
static LogicalResult verifyMemoryAccessOp(TMemoryOp op) {
if (op.addresses().size() == 0)
return op.emitOpError() << "No addresses were specified";
return success();
}
static LogicalResult verifyLoadOp(handshake::LoadOp op) {
return verifyMemoryAccessOp(op);
}
std::string handshake::StoreOp::getResultName(unsigned int idx) { std::string handshake::StoreOp::getResultName(unsigned int idx) {
std::string resName; std::string resName;
if (idx == 0) if (idx == 0)
@ -1141,6 +1153,10 @@ void handshake::StoreOp::build(OpBuilder &builder, OperationState &result,
result.types.append(indices.size(), builder.getIndexType()); result.types.append(indices.size(), builder.getIndexType());
} }
static LogicalResult verifyStoreOp(handshake::StoreOp op) {
return verifyMemoryAccessOp(op);
}
static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) { static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
return parseMemoryAccessOp(parser, result); return parseMemoryAccessOp(parser, result);
} }