[Handshake] Expose module output control signal through `handshake.instance` (#2064)

This commit is contained in:
Morten Borup Petersen 2021-11-01 15:02:08 +00:00 committed by GitHub
parent 1ec2489571
commit 11431bf7d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 36 additions and 18 deletions

View File

@ -107,15 +107,15 @@ def InstanceOp : Handshake_Op<"instance", [CallOpInterface]> {
let summary = "module instantiate operation"; let summary = "module instantiate operation";
let description = [{ let description = [{
The `instance` operation represents the instantiation of a module. This The `instance` operation represents the instantiation of a module. This
is similar to a function call, except that different instances of the is similar to a function call, except that different instances of the
same module are guaranteed to have their own distinct state. same module are guaranteed to have their own distinct state.
The instantiated module is encoded as a The instantiated module is encoded as a symbol reference attribute named
symbol reference attribute named "module". "module". An instance operation takes a control input as its last argument
and returns a control output as its last result.
Example: Example:
```mlir ```mlir
%2 = handshake.instance @my_add(%0, %1) : (f32, f32) -> f32 %2:2 = handshake.instance @my_add(%0, %1, %ctrl) : (f32, f32, none) -> (f32, none)
``` ```
}]; }];
@ -127,12 +127,14 @@ def InstanceOp : Handshake_Op<"instance", [CallOpInterface]> {
$_state.addOperands(operands); $_state.addOperands(operands);
$_state.addAttribute("module", SymbolRefAttr::get(module)); $_state.addAttribute("module", SymbolRefAttr::get(module));
$_state.addTypes(module.getType().getResults()); $_state.addTypes(module.getType().getResults());
$_state.addTypes({$_builder.getType<::mlir::NoneType>()});
}]>, OpBuilder< }]>, OpBuilder<
(ins "SymbolRefAttr":$module, "TypeRange":$results, (ins "SymbolRefAttr":$module, "TypeRange":$results,
CArg<"ValueRange", "{}">:$operands), [{ CArg<"ValueRange", "{}">:$operands), [{
$_state.addOperands(operands); $_state.addOperands(operands);
$_state.addAttribute("module", module); $_state.addAttribute("module", module);
$_state.addTypes(results); $_state.addTypes(results);
$_state.addTypes({$_builder.getType<::mlir::NoneType>()});
}]>, OpBuilder< }]>, OpBuilder<
(ins "StringRef":$module, "TypeRange":$results, (ins "StringRef":$module, "TypeRange":$results,
CArg<"ValueRange", "{}">:$operands), [{ CArg<"ValueRange", "{}">:$operands), [{

View File

@ -1647,8 +1647,13 @@ LogicalResult replaceCallOps(handshake::FuncOp f,
llvm::copy(callOp.getOperands(), std::back_inserter(operands)); llvm::copy(callOp.getOperands(), std::back_inserter(operands));
operands.push_back(cntrlMg->getResult(0)); operands.push_back(cntrlMg->getResult(0));
rewriter.setInsertionPoint(callOp); rewriter.setInsertionPoint(callOp);
rewriter.replaceOpWithNewOp<handshake::InstanceOp>( auto instanceOp = rewriter.create<handshake::InstanceOp>(
callOp, callOp.getCallee(), callOp.getResultTypes(), operands); callOp.getLoc(), callOp.getCallee(), callOp.getResultTypes(),
operands);
// Replace all results of the source callOp.
for (auto it : llvm::zip(callOp.getResults(), instanceOp.getResults()))
std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
rewriter.eraseOp(callOp);
} }
} }
} }

View File

@ -16,9 +16,10 @@ func @foo(%0 : i32) -> i32 {
// CHECK-SAME: %[[VAL_0:.*]]: i32, // CHECK-SAME: %[[VAL_0:.*]]: i32,
// CHECK-SAME: %[[VAL_1:.*]]: none, ...) -> (i32, none) { // CHECK-SAME: %[[VAL_1:.*]]: none, ...) -> (i32, none) {
// CHECK: %[[VAL_2:.*]] = "handshake.merge"(%[[VAL_0]]) : (i32) -> i32 // CHECK: %[[VAL_2:.*]] = "handshake.merge"(%[[VAL_0]]) : (i32) -> i32
// CHECK: %[[VAL_4:.*]]:2 = "handshake.fork"(%[[VAL_1]]) {control = true} : (none) -> (none, none) // CHECK: %[[VAL_3:.*]]:2 = "handshake.fork"(%[[VAL_1]]) {control = true} : (none) -> (none, none)
// CHECK: %[[VAL_3:.*]] = handshake.instance @bar(%[[VAL_2]], %[[VAL_4]]#0) : (i32, none) -> i32 // CHECK: %[[VAL_4:.*]]:2 = handshake.instance @bar(%[[VAL_2]], %[[VAL_3]]#0) : (i32, none) -> (i32, none)
// CHECK: handshake.return %[[VAL_3]], %[[VAL_4]]#1 : i32, none // CHECK: "handshake.sink"(%[[VAL_4]]#1) : (none) -> ()
// CHECK: handshake.return %[[VAL_4]]#0, %[[VAL_3]]#1 : i32, none
// CHECK: } // CHECK: }
%a1 = call @bar(%0) : (i32) -> i32 %a1 = call @bar(%0) : (i32) -> i32
@ -54,17 +55,19 @@ func @sub(%arg0 : i32, %arg1: i32) -> i32 {
// CHECK: %[[VAL_16:.*]]:2 = "handshake.control_merge"(%[[VAL_12]]) {control = true} : (none) -> (none, index) // CHECK: %[[VAL_16:.*]]:2 = "handshake.control_merge"(%[[VAL_12]]) {control = true} : (none) -> (none, index)
// CHECK: %[[VAL_17:.*]]:2 = "handshake.fork"(%[[VAL_16]]#0) {control = true} : (none) -> (none, none) // CHECK: %[[VAL_17:.*]]:2 = "handshake.fork"(%[[VAL_16]]#0) {control = true} : (none) -> (none, none)
// CHECK: "handshake.sink"(%[[VAL_16]]#1) : (index) -> () // CHECK: "handshake.sink"(%[[VAL_16]]#1) : (index) -> ()
// CHECK: %[[VAL_18:.*]] = handshake.instance @add(%[[VAL_14]], %[[VAL_15]], %[[VAL_17]]#1) : (i32, i32, none) -> i32 // CHECK: %[[VAL_18:.*]]:2 = handshake.instance @add(%[[VAL_14]], %[[VAL_15]], %[[VAL_17]]#1) : (i32, i32, none) -> (i32, none)
// CHECK: "handshake.sink"(%[[VAL_18]]#1) : (none) -> ()
// CHECK: %[[VAL_19:.*]] = "handshake.branch"(%[[VAL_17]]#0) {control = true} : (none) -> none // CHECK: %[[VAL_19:.*]] = "handshake.branch"(%[[VAL_17]]#0) {control = true} : (none) -> none
// CHECK: %[[VAL_20:.*]] = "handshake.branch"(%[[VAL_18]]) {control = false} : (i32) -> i32 // CHECK: %[[VAL_20:.*]] = "handshake.branch"(%[[VAL_18]]#0) {control = false} : (i32) -> i32
// CHECK: %[[VAL_21:.*]] = "handshake.merge"(%[[VAL_9]]) : (i32) -> i32 // CHECK: %[[VAL_21:.*]] = "handshake.merge"(%[[VAL_9]]) : (i32) -> i32
// CHECK: %[[VAL_22:.*]] = "handshake.merge"(%[[VAL_11]]) : (i32) -> i32 // CHECK: %[[VAL_22:.*]] = "handshake.merge"(%[[VAL_11]]) : (i32) -> i32
// CHECK: %[[VAL_23:.*]]:2 = "handshake.control_merge"(%[[VAL_13]]) {control = true} : (none) -> (none, index) // CHECK: %[[VAL_23:.*]]:2 = "handshake.control_merge"(%[[VAL_13]]) {control = true} : (none) -> (none, index)
// CHECK: %[[VAL_24:.*]]:2 = "handshake.fork"(%[[VAL_23]]#0) {control = true} : (none) -> (none, none) // CHECK: %[[VAL_24:.*]]:2 = "handshake.fork"(%[[VAL_23]]#0) {control = true} : (none) -> (none, none)
// CHECK: "handshake.sink"(%[[VAL_23]]#1) : (index) -> () // CHECK: "handshake.sink"(%[[VAL_23]]#1) : (index) -> ()
// CHECK: %[[VAL_25:.*]] = handshake.instance @sub(%[[VAL_21]], %[[VAL_22]], %[[VAL_24]]#1) : (i32, i32, none) -> i32 // CHECK: %[[VAL_25:.*]]:2 = handshake.instance @sub(%[[VAL_21]], %[[VAL_22]], %[[VAL_24]]#1) : (i32, i32, none) -> (i32, none)
// CHECK: "handshake.sink"(%[[VAL_25]]#1) : (none) -> ()
// CHECK: %[[VAL_26:.*]] = "handshake.branch"(%[[VAL_24]]#0) {control = true} : (none) -> none // CHECK: %[[VAL_26:.*]] = "handshake.branch"(%[[VAL_24]]#0) {control = true} : (none) -> none
// CHECK: %[[VAL_27:.*]] = "handshake.branch"(%[[VAL_25]]) {control = false} : (i32) -> i32 // CHECK: %[[VAL_27:.*]] = "handshake.branch"(%[[VAL_25]]#0) {control = false} : (i32) -> i32
// CHECK: %[[VAL_28:.*]]:2 = "handshake.control_merge"(%[[VAL_26]], %[[VAL_19]]) {control = true} : (none, none) -> (none, index) // CHECK: %[[VAL_28:.*]]:2 = "handshake.control_merge"(%[[VAL_26]], %[[VAL_19]]) {control = true} : (none, none) -> (none, index)
// CHECK: %[[VAL_29:.*]] = "handshake.mux"(%[[VAL_28]]#1, %[[VAL_27]], %[[VAL_20]]) : (index, i32, i32) -> i32 // CHECK: %[[VAL_29:.*]] = "handshake.mux"(%[[VAL_28]]#1, %[[VAL_27]], %[[VAL_20]]) : (index, i32, i32) -> i32
// CHECK: handshake.return %[[VAL_29]], %[[VAL_28]]#0 : i32, none // CHECK: handshake.return %[[VAL_29]], %[[VAL_28]]#0 : i32, none

View File

@ -23,7 +23,7 @@ module {
%c108 = arith.constant 108 : index %c108 = arith.constant 108 : index
%c109 = arith.constant 109 : index %c109 = arith.constant 109 : index
%c2 = arith.constant 2 : index %c2 = arith.constant 2 : index
%3 = call @muladd(%c104, %c2, %c103) : (index, index, index) -> index %3 = call @muladd(%c104, %c2, %c103) : (index, index, index) -> index
%c3 = arith.constant 3 : index %c3 = arith.constant 3 : index
%4 = arith.muli %c105, %c3 : index %4 = arith.muli %c105, %c3 : index
%5 = arith.addi %3, %4 : index %5 = arith.addi %3, %4 : index

View File

@ -586,8 +586,8 @@ LogicalResult HandshakeExecuter::execute(handshake::InstanceOp instanceOp,
entryBlock.getArguments(); entryBlock.getArguments();
// Create a new value map containing only the arguments of the // Create a new value map containing only the arguments of the
// InstanceOp. This will be the value and time map for the scope of the // InstanceOp. This will be the value and time map for the callee scope of
// function pointed to by the InstanceOp. // the function pointed to by the InstanceOp.
llvm::DenseMap<mlir::Value, Any> scopeValueMap; llvm::DenseMap<mlir::Value, Any> scopeValueMap;
llvm::DenseMap<mlir::Value, double> scopeTimeMap; llvm::DenseMap<mlir::Value, double> scopeTimeMap;
@ -605,9 +605,12 @@ LogicalResult HandshakeExecuter::execute(handshake::InstanceOp instanceOp,
apnonearg; apnonearg;
std::vector<Any> nestedResults(nRealFuncOuts); std::vector<Any> nestedResults(nRealFuncOuts);
std::vector<double> nestedResTimes(nRealFuncOuts); std::vector<double> nestedResTimes(nRealFuncOuts);
// Go execute!
HandshakeExecuter(func, scopeValueMap, scopeTimeMap, nestedResults, HandshakeExecuter(func, scopeValueMap, scopeTimeMap, nestedResults,
nestedResTimes, store, storeTimes, *module); nestedResTimes, store, storeTimes, *module);
// Place the output arguments in the caller scope.
for (auto nestedRes : enumerate(nestedResults)) { for (auto nestedRes : enumerate(nestedResults)) {
out[nestedRes.index()] = nestedRes.value(); out[nestedRes.index()] = nestedRes.value();
valueMap[instanceOp.getResults()[nestedRes.index()]] = valueMap[instanceOp.getResults()[nestedRes.index()]] =
@ -615,6 +618,11 @@ LogicalResult HandshakeExecuter::execute(handshake::InstanceOp instanceOp,
timeMap[instanceOp.getResults()[nestedRes.index()]] = timeMap[instanceOp.getResults()[nestedRes.index()]] =
nestedResTimes[nestedRes.index()]; nestedResTimes[nestedRes.index()];
} }
// ... and the implicit none argument
unsigned ctrlResultIdx = instanceOp.getNumResults() - 1;
valueMap[instanceOp->getResult(ctrlResultIdx)] = apnonearg;
out[ctrlResultIdx] = apnonearg;
return success(); return success();
} else { } else {
return instanceOp.emitError() return instanceOp.emitError()