diff --git a/include/circt/Dialect/Handshake/HandshakeOps.td b/include/circt/Dialect/Handshake/HandshakeOps.td index 91553f6f6b..4d2db67f37 100644 --- a/include/circt/Dialect/Handshake/HandshakeOps.td +++ b/include/circt/Dialect/Handshake/HandshakeOps.td @@ -107,15 +107,15 @@ def InstanceOp : Handshake_Op<"instance", [CallOpInterface]> { let summary = "module instantiate operation"; let description = [{ The `instance` operation represents the instantiation of a module. This - is similar to a function call, except that different instances of the - same module are guaranteed to have their own distinct state. - The instantiated module is encoded as a - symbol reference attribute named "module". + is similar to a function call, except that different instances of the + same module are guaranteed to have their own distinct state. + The instantiated module is encoded as a symbol reference attribute named + "module". An instance operation takes a control input as its last argument + and returns a control output as its last result. Example: - ```mlir - %2 = handshake.instance @my_add(%0, %1) : (f32, f32) -> f32 + %2:2 = handshake.instance @my_add(%0, %1, %ctrl) : (f32, f32, none) -> (f32, none) ``` }]; @@ -127,12 +127,14 @@ def InstanceOp : Handshake_Op<"instance", [CallOpInterface]> { $_state.addOperands(operands); $_state.addAttribute("module", SymbolRefAttr::get(module)); $_state.addTypes(module.getType().getResults()); + $_state.addTypes({$_builder.getType<::mlir::NoneType>()}); }]>, OpBuilder< (ins "SymbolRefAttr":$module, "TypeRange":$results, CArg<"ValueRange", "{}">:$operands), [{ $_state.addOperands(operands); $_state.addAttribute("module", module); $_state.addTypes(results); + $_state.addTypes({$_builder.getType<::mlir::NoneType>()}); }]>, OpBuilder< (ins "StringRef":$module, "TypeRange":$results, CArg<"ValueRange", "{}">:$operands), [{ diff --git a/lib/Conversion/StandardToHandshake/StandardToHandshake.cpp b/lib/Conversion/StandardToHandshake/StandardToHandshake.cpp index fc909b5866..001f4a6439 100644 --- a/lib/Conversion/StandardToHandshake/StandardToHandshake.cpp +++ b/lib/Conversion/StandardToHandshake/StandardToHandshake.cpp @@ -1647,8 +1647,13 @@ LogicalResult replaceCallOps(handshake::FuncOp f, llvm::copy(callOp.getOperands(), std::back_inserter(operands)); operands.push_back(cntrlMg->getResult(0)); rewriter.setInsertionPoint(callOp); - rewriter.replaceOpWithNewOp( - callOp, callOp.getCallee(), callOp.getResultTypes(), operands); + auto instanceOp = rewriter.create( + callOp.getLoc(), callOp.getCallee(), callOp.getResultTypes(), + operands); + // Replace all results of the source callOp. + for (auto it : llvm::zip(callOp.getResults(), instanceOp.getResults())) + std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); + rewriter.eraseOp(callOp); } } } diff --git a/test/Conversion/StandardToHandshake/test_call.mlir b/test/Conversion/StandardToHandshake/test_call.mlir index e9a7da455d..b78251a52c 100644 --- a/test/Conversion/StandardToHandshake/test_call.mlir +++ b/test/Conversion/StandardToHandshake/test_call.mlir @@ -16,9 +16,10 @@ func @foo(%0 : i32) -> i32 { // CHECK-SAME: %[[VAL_0:.*]]: i32, // CHECK-SAME: %[[VAL_1:.*]]: none, ...) -> (i32, none) { // CHECK: %[[VAL_2:.*]] = "handshake.merge"(%[[VAL_0]]) : (i32) -> i32 -// CHECK: %[[VAL_4:.*]]:2 = "handshake.fork"(%[[VAL_1]]) {control = true} : (none) -> (none, none) -// CHECK: %[[VAL_3:.*]] = handshake.instance @bar(%[[VAL_2]], %[[VAL_4]]#0) : (i32, none) -> i32 -// CHECK: handshake.return %[[VAL_3]], %[[VAL_4]]#1 : i32, none +// CHECK: %[[VAL_3:.*]]:2 = "handshake.fork"(%[[VAL_1]]) {control = true} : (none) -> (none, none) +// CHECK: %[[VAL_4:.*]]:2 = handshake.instance @bar(%[[VAL_2]], %[[VAL_3]]#0) : (i32, none) -> (i32, none) +// CHECK: "handshake.sink"(%[[VAL_4]]#1) : (none) -> () +// CHECK: handshake.return %[[VAL_4]]#0, %[[VAL_3]]#1 : i32, none // CHECK: } %a1 = call @bar(%0) : (i32) -> i32 @@ -54,17 +55,19 @@ func @sub(%arg0 : i32, %arg1: i32) -> i32 { // CHECK: %[[VAL_16:.*]]:2 = "handshake.control_merge"(%[[VAL_12]]) {control = true} : (none) -> (none, index) // CHECK: %[[VAL_17:.*]]:2 = "handshake.fork"(%[[VAL_16]]#0) {control = true} : (none) -> (none, none) // CHECK: "handshake.sink"(%[[VAL_16]]#1) : (index) -> () -// CHECK: %[[VAL_18:.*]] = handshake.instance @add(%[[VAL_14]], %[[VAL_15]], %[[VAL_17]]#1) : (i32, i32, none) -> i32 +// CHECK: %[[VAL_18:.*]]:2 = handshake.instance @add(%[[VAL_14]], %[[VAL_15]], %[[VAL_17]]#1) : (i32, i32, none) -> (i32, none) +// CHECK: "handshake.sink"(%[[VAL_18]]#1) : (none) -> () // CHECK: %[[VAL_19:.*]] = "handshake.branch"(%[[VAL_17]]#0) {control = true} : (none) -> none -// CHECK: %[[VAL_20:.*]] = "handshake.branch"(%[[VAL_18]]) {control = false} : (i32) -> i32 +// CHECK: %[[VAL_20:.*]] = "handshake.branch"(%[[VAL_18]]#0) {control = false} : (i32) -> i32 // CHECK: %[[VAL_21:.*]] = "handshake.merge"(%[[VAL_9]]) : (i32) -> i32 // CHECK: %[[VAL_22:.*]] = "handshake.merge"(%[[VAL_11]]) : (i32) -> i32 // CHECK: %[[VAL_23:.*]]:2 = "handshake.control_merge"(%[[VAL_13]]) {control = true} : (none) -> (none, index) // CHECK: %[[VAL_24:.*]]:2 = "handshake.fork"(%[[VAL_23]]#0) {control = true} : (none) -> (none, none) // CHECK: "handshake.sink"(%[[VAL_23]]#1) : (index) -> () -// CHECK: %[[VAL_25:.*]] = handshake.instance @sub(%[[VAL_21]], %[[VAL_22]], %[[VAL_24]]#1) : (i32, i32, none) -> i32 +// CHECK: %[[VAL_25:.*]]:2 = handshake.instance @sub(%[[VAL_21]], %[[VAL_22]], %[[VAL_24]]#1) : (i32, i32, none) -> (i32, none) +// CHECK: "handshake.sink"(%[[VAL_25]]#1) : (none) -> () // CHECK: %[[VAL_26:.*]] = "handshake.branch"(%[[VAL_24]]#0) {control = true} : (none) -> none -// CHECK: %[[VAL_27:.*]] = "handshake.branch"(%[[VAL_25]]) {control = false} : (i32) -> i32 +// CHECK: %[[VAL_27:.*]] = "handshake.branch"(%[[VAL_25]]#0) {control = false} : (i32) -> i32 // CHECK: %[[VAL_28:.*]]:2 = "handshake.control_merge"(%[[VAL_26]], %[[VAL_19]]) {control = true} : (none, none) -> (none, index) // CHECK: %[[VAL_29:.*]] = "handshake.mux"(%[[VAL_28]]#1, %[[VAL_27]], %[[VAL_20]]) : (index, i32, i32) -> i32 // CHECK: handshake.return %[[VAL_29]], %[[VAL_28]]#0 : i32, none diff --git a/test/handshake-runner/call_bb.mlir b/test/handshake-runner/call_bb.mlir index cc2e381ce0..b20cfa1f96 100644 --- a/test/handshake-runner/call_bb.mlir +++ b/test/handshake-runner/call_bb.mlir @@ -23,7 +23,7 @@ module { %c108 = arith.constant 108 : index %c109 = arith.constant 109 : index %c2 = arith.constant 2 : index - %3 = call @muladd(%c104, %c2, %c103) : (index, index, index) -> index + %3 = call @muladd(%c104, %c2, %c103) : (index, index, index) -> index %c3 = arith.constant 3 : index %4 = arith.muli %c105, %c3 : index %5 = arith.addi %3, %4 : index diff --git a/tools/handshake-runner/Simulation.cpp b/tools/handshake-runner/Simulation.cpp index 3a0bfc7222..0b640cd444 100644 --- a/tools/handshake-runner/Simulation.cpp +++ b/tools/handshake-runner/Simulation.cpp @@ -586,8 +586,8 @@ LogicalResult HandshakeExecuter::execute(handshake::InstanceOp instanceOp, entryBlock.getArguments(); // Create a new value map containing only the arguments of the - // InstanceOp. This will be the value and time map for the scope of the - // function pointed to by the InstanceOp. + // InstanceOp. This will be the value and time map for the callee scope of + // the function pointed to by the InstanceOp. llvm::DenseMap scopeValueMap; llvm::DenseMap scopeTimeMap; @@ -605,9 +605,12 @@ LogicalResult HandshakeExecuter::execute(handshake::InstanceOp instanceOp, apnonearg; std::vector nestedResults(nRealFuncOuts); std::vector nestedResTimes(nRealFuncOuts); + + // Go execute! HandshakeExecuter(func, scopeValueMap, scopeTimeMap, nestedResults, nestedResTimes, store, storeTimes, *module); + // Place the output arguments in the caller scope. for (auto nestedRes : enumerate(nestedResults)) { out[nestedRes.index()] = nestedRes.value(); valueMap[instanceOp.getResults()[nestedRes.index()]] = @@ -615,6 +618,11 @@ LogicalResult HandshakeExecuter::execute(handshake::InstanceOp instanceOp, timeMap[instanceOp.getResults()[nestedRes.index()]] = nestedResTimes[nestedRes.index()]; } + // ... and the implicit none argument + unsigned ctrlResultIdx = instanceOp.getNumResults() - 1; + valueMap[instanceOp->getResult(ctrlResultIdx)] = apnonearg; + out[ctrlResultIdx] = apnonearg; + return success(); } else { return instanceOp.emitError()