[StandardToHandshake] Actually fail conversion when raising errors (#2267)

This commit is contained in:
Morten Borup Petersen 2021-12-02 10:48:31 +00:00 committed by GitHub
parent 7cf4f5e365
commit 9bc200c6ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 84 additions and 52 deletions

View File

@ -682,6 +682,7 @@ def LoadOp : Handshake_Op<"load", [
let builders = [OpBuilder<(ins "Value":$memref, "ArrayRef<Value>":$indices)>];
let printer = "return ::print$cppClass(p, *this);";
let parser = "return ::parse$cppClass(parser, result);";
let verifier = "return ::verify$cppClass(*this);";
}
def StoreOp : Handshake_Op<"store", [
@ -715,6 +716,7 @@ def StoreOp : Handshake_Op<"store", [
[OpBuilder<(ins "Value":$valueToStore, "ArrayRef<Value>":$indices)>];
let printer = "return ::print$cppClass(p, *this);";
let parser = "return ::parse$cppClass(parser, result);";
let verifier = "return ::verify$cppClass(*this);";
}
def JoinOp : Handshake_Op<"join", [

View File

@ -576,7 +576,7 @@ LogicalResult connectConstantsToControl(handshake::FuncOp f,
return success();
}
void checkUseCount(Operation *op, Value res) {
LogicalResult checkUseCount(Operation *op, Value res) {
// Checks if every result has single use
if (!res.hasOneUse()) {
int i = 0;
@ -584,24 +584,25 @@ void checkUseCount(Operation *op, Value res) {
user->emitWarning("user here");
i++;
}
op->emitError("every result must have exactly one user, but had ") << i;
return op->emitOpError("every result must have exactly one user, but had ")
<< i;
}
return;
return success();
}
// Checks if op successors are in appropriate blocks
void checkSuccessorBlocks(Operation *op, Value res) {
LogicalResult checkSuccessorBlocks(Operation *op, Value res) {
for (auto &u : res.getUses()) {
Operation *succOp = u.getOwner();
// Non-branch ops: succesors must be in same block
if (!(isa<handshake::ConditionalBranchOp>(op) ||
isa<handshake::BranchOp>(op))) {
if (op->getBlock() != succOp->getBlock())
op->emitError("cannot be block live-out");
return op->emitOpError("cannot be block live-out");
} else {
// Branch ops: must have successor per successor block
if (op->getBlock()->getNumSuccessors() != op->getNumResults())
op->emitError("incorrect successor count");
return op->emitOpError("incorrect successor count");
bool found = false;
for (int i = 0, e = op->getBlock()->getNumSuccessors(); i < e; ++i) {
Block *succ = op->getBlock()->getSuccessor(i);
@ -609,27 +610,28 @@ void checkSuccessorBlocks(Operation *op, Value res) {
found = true;
}
if (!found)
op->emitError("branch successor in incorrect block");
return op->emitOpError("branch successor in incorrect block");
}
}
return;
return success();
}
// Checks if merge predecessors are in appropriate block
void checkMergePredecessors(MergeLikeOpInterface mergeOp) {
LogicalResult checkMergePredecessors(MergeLikeOpInterface mergeOp) {
Block *block = mergeOp->getBlock();
unsigned operand_count = mergeOp.dataOperands().size();
// Merges in entry block have single predecessor (argument)
if (block->isEntryBlock()) {
if (operand_count != 1)
mergeOp->emitError("merge operations in entry block must have a ")
<< "single predecessor";
return mergeOp->emitOpError(
"merge operations in entry block must have a ")
<< "single predecessor";
} else {
if (operand_count > getBlockPredecessorCount(block))
mergeOp->emitError("merge operation has ")
<< operand_count << " data inputs, but only "
<< getBlockPredecessorCount(block) << " predecessor blocks";
return mergeOp->emitOpError("merge operation has ")
<< operand_count << " data inputs, but only "
<< getBlockPredecessorCount(block) << " predecessor blocks";
}
// There must be a predecessor from each predecessor block
@ -643,19 +645,19 @@ void checkMergePredecessors(MergeLikeOpInterface mergeOp) {
}
}
if (!found)
mergeOp->emitError("missing predecessor from predecessor block");
return mergeOp->emitOpError("missing predecessor from predecessor block");
}
// Select operand must come from same block
if (auto muxOp = dyn_cast<MuxOp>(mergeOp.getOperation())) {
auto *operand = muxOp.selectOperand().getDefiningOp();
if (operand->getBlock() != block)
mergeOp->emitError("mux select operand must be from same block");
return mergeOp->emitOpError("mux select operand must be from same block");
}
return;
return success();
}
void checkDataflowConversion(handshake::FuncOp f) {
LogicalResult checkDataflowConversion(handshake::FuncOp f) {
for (Operation &op : f.getOps()) {
if (isa<mlir::CondBranchOp, mlir::BranchOp, memref::LoadOp,
arith::ConstantOp, mlir::AffineReadOpInterface, mlir::AffineForOp>(
@ -664,13 +666,16 @@ void checkDataflowConversion(handshake::FuncOp f) {
if (op.getNumResults() > 0) {
for (auto result : op.getResults()) {
checkUseCount(&op, result);
checkSuccessorBlocks(&op, result);
if (checkUseCount(&op, result).failed() ||
checkSuccessorBlocks(&op, result).failed())
return failure();
}
}
if (auto mergeOp = dyn_cast<MergeLikeOpInterface>(op); mergeOp)
checkMergePredecessors(mergeOp);
if (checkMergePredecessors(mergeOp).failed())
return failure();
}
return success();
}
Value getBlockControlValue(Block *block) {
@ -688,17 +693,19 @@ Value getBlockControlValue(Block *block) {
return nullptr;
}
Value getOpMemRef(Operation *op) {
LogicalResult getOpMemRef(Operation *op, Value &out) {
out = Value();
if (auto memOp = dyn_cast<memref::LoadOp>(op))
return memOp.getMemRef();
if (auto memOp = dyn_cast<memref::StoreOp>(op))
return memOp.getMemRef();
if (isa<mlir::AffineReadOpInterface, mlir::AffineWriteOpInterface>(op)) {
out = memOp.getMemRef();
else if (auto memOp = dyn_cast<memref::StoreOp>(op))
out = memOp.getMemRef();
else if (isa<mlir::AffineReadOpInterface, mlir::AffineWriteOpInterface>(op)) {
MemRefAccess access(op);
return access.memref;
out = access.memref;
}
op->emitError("Unknown Op type");
return Value();
if (out != Value())
return success();
return op->emitOpError("Unknown Op type");
}
bool isMemoryOp(Operation *op) {
@ -712,21 +719,22 @@ typedef llvm::MapVector<Value, std::vector<Operation *>> MemRefToMemoryAccessOp;
// ops which connect to memory/LSQ). Returns a map with an ordered
// list of new ops corresponding to each memref. Later, we instantiate
// a memory node for each memref and connect it to its load/store ops
MemRefToMemoryAccessOp replaceMemoryOps(handshake::FuncOp f,
ConversionPatternRewriter &rewriter) {
// Map from original memref to new load/store operations.
MemRefToMemoryAccessOp MemRefOps;
LogicalResult replaceMemoryOps(handshake::FuncOp f,
ConversionPatternRewriter &rewriter,
MemRefToMemoryAccessOp &memRefOps) {
std::vector<Operation *> opsToErase;
// Replace load and store ops with the corresponding handshake ops
// Need to traverse ops in blocks to store them in MemRefOps in program order
// Need to traverse ops in blocks to store them in memRefOps in program order
for (Operation &op : f.getOps()) {
if (!isMemoryOp(&op))
continue;
rewriter.setInsertionPoint(&op);
Value memref = getOpMemRef(&op);
Value memref;
if (getOpMemRef(&op, memref).failed())
return failure();
Operation *newOp = nullptr;
llvm::TypeSwitch<Operation *>(&op)
@ -784,10 +792,10 @@ MemRefToMemoryAccessOp replaceMemoryOps(handshake::FuncOp f,
}
})
.Default([&](auto) {
op.emitError("Load/store operation cannot be handled.");
op.emitOpError("Load/store operation cannot be handled.");
});
MemRefOps[memref].push_back(newOp);
memRefOps[memref].push_back(newOp);
opsToErase.push_back(&op);
}
@ -801,7 +809,7 @@ MemRefToMemoryAccessOp replaceMemoryOps(handshake::FuncOp f,
rewriter.eraseOp(op);
}
return MemRefOps;
return success();
}
std::vector<Block *> getOperationBlocks(ArrayRef<Operation *> ops) {
@ -949,8 +957,9 @@ void setLoadDataInputs(ArrayRef<Operation *> memOps, Operation *memOp) {
}
}
void setJoinControlInputs(ArrayRef<Operation *> memOps, Operation *memOp,
int offset, ArrayRef<int> cntrlInd) {
LogicalResult setJoinControlInputs(ArrayRef<Operation *> memOps,
Operation *memOp, int offset,
ArrayRef<int> cntrlInd) {
// Connect all memory ops to the join of that block (ensures that all mem
// ops terminate before a new block starts)
for (int i = 0, e = memOps.size(); i < e; ++i) {
@ -958,10 +967,11 @@ void setJoinControlInputs(ArrayRef<Operation *> memOps, Operation *memOp,
Value val = getBlockControlValue(op->getBlock());
auto srcOp = val.getDefiningOp();
if (!isa<JoinOp, StartOp>(srcOp)) {
srcOp->emitError("Op expected to be a JoinOp or StartOp");
return srcOp->emitOpError("Op expected to be a JoinOp or StartOp");
}
addValueToOperands(srcOp, memOp->getResult(offset + cntrlInd[i]));
}
return success();
}
void setMemOpControlInputs(ConversionPatternRewriter &rewriter,
@ -1098,9 +1108,12 @@ LogicalResult connectToMemory(handshake::FuncOp f,
// user-determined)
bool control = true;
if (control)
setJoinControlInputs(memory.second, newOp, ld_count, newInd);
else {
if (control) {
if (setJoinControlInputs(memory.second, newOp, ld_count, newInd)
.failed())
return failure();
} else {
for (int i = 0, e = cntrl_count; i < e; ++i) {
rewriter.setInsertionPointAfter(newOp);
rewriter.create<SinkOp>(newOp->getLoc(),
@ -1440,11 +1453,11 @@ LogicalResult lowerFuncOp(mlir::FuncOp funcOp, MLIRContext *ctx) {
return newFuncOp.emitOpError("failed to rewrite Affine loops");
// Perform dataflow conversion
MemRefToMemoryAccessOp MemOps;
MemRefToMemoryAccessOp memOps;
returnOnError(partiallyLowerFuncOp<handshake::FuncOp>(
[&](handshake::FuncOp nfo, ConversionPatternRewriter &rewriter) {
MemOps = replaceMemoryOps(nfo, rewriter);
return success();
// Map from original memref to new load/store operations.
return replaceMemoryOps(nfo, rewriter, memOps);
},
ctx, newFuncOp));
@ -1462,12 +1475,12 @@ LogicalResult lowerFuncOp(mlir::FuncOp funcOp, MLIRContext *ctx) {
connectConstantsToControl, ctx, newFuncOp));
returnOnError(
partiallyLowerFuncOp<handshake::FuncOp>(addForkOps, ctx, newFuncOp));
checkDataflowConversion(newFuncOp);
returnOnError(checkDataflowConversion(newFuncOp));
bool lsq = false;
returnOnError(partiallyLowerFuncOp<handshake::FuncOp>(
[&](handshake::FuncOp nfo, ConversionPatternRewriter &rewriter) {
return connectToMemory(nfo, MemOps, lsq, rewriter);
return connectToMemory(nfo, memOps, lsq, rewriter);
},
ctx, newFuncOp));
@ -1609,8 +1622,7 @@ struct HandshakeInsertBufferPass
else if (strategy == "all")
bufferAllStrategy();
else {
emitError(getOperation().getLoc())
<< "Unknown buffer strategy: " << strategy;
getOperation().emitOpError() << "Unknown buffer strategy: " << strategy;
signalPassFailure();
return;
}
@ -1629,8 +1641,10 @@ struct HandshakeDataflowPass
ModuleOp m = getOperation();
for (auto funcOp : llvm::make_early_inc_range(m.getOps<mlir::FuncOp>())) {
if (failed(lowerFuncOp(funcOp, &getContext())))
if (failed(lowerFuncOp(funcOp, &getContext()))) {
signalPassFailure();
return;
}
}
// Legalize the resulting regions, which can have no basic blocks.

View File

@ -1116,6 +1116,18 @@ std::string handshake::StoreOp::getOperandName(unsigned int idx) {
return opName;
}
template <typename TMemoryOp>
static LogicalResult verifyMemoryAccessOp(TMemoryOp op) {
if (op.addresses().size() == 0)
return op.emitOpError() << "No addresses were specified";
return success();
}
static LogicalResult verifyLoadOp(handshake::LoadOp op) {
return verifyMemoryAccessOp(op);
}
std::string handshake::StoreOp::getResultName(unsigned int idx) {
std::string resName;
if (idx == 0)
@ -1141,6 +1153,10 @@ void handshake::StoreOp::build(OpBuilder &builder, OperationState &result,
result.types.append(indices.size(), builder.getIndexType());
}
static LogicalResult verifyStoreOp(handshake::StoreOp op) {
return verifyMemoryAccessOp(op);
}
static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
return parseMemoryAccessOp(parser, result);
}