[Handshake] Expose module output control signal through `handshake.instance` (#2064)

This commit is contained in:
Morten Borup Petersen 2021-11-01 15:02:08 +00:00 committed by GitHub
parent 1ec2489571
commit 11431bf7d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 36 additions and 18 deletions

View File

@ -107,15 +107,15 @@ def InstanceOp : Handshake_Op<"instance", [CallOpInterface]> {
let summary = "module instantiate operation";
let description = [{
The `instance` operation represents the instantiation of a module. This
is similar to a function call, except that different instances of the
same module are guaranteed to have their own distinct state.
The instantiated module is encoded as a
symbol reference attribute named "module".
is similar to a function call, except that different instances of the
same module are guaranteed to have their own distinct state.
The instantiated module is encoded as a symbol reference attribute named
"module". An instance operation takes a control input as its last argument
and returns a control output as its last result.
Example:
```mlir
%2 = handshake.instance @my_add(%0, %1) : (f32, f32) -> f32
%2:2 = handshake.instance @my_add(%0, %1, %ctrl) : (f32, f32, none) -> (f32, none)
```
}];
@ -127,12 +127,14 @@ def InstanceOp : Handshake_Op<"instance", [CallOpInterface]> {
$_state.addOperands(operands);
$_state.addAttribute("module", SymbolRefAttr::get(module));
$_state.addTypes(module.getType().getResults());
$_state.addTypes({$_builder.getType<::mlir::NoneType>()});
}]>, OpBuilder<
(ins "SymbolRefAttr":$module, "TypeRange":$results,
CArg<"ValueRange", "{}">:$operands), [{
$_state.addOperands(operands);
$_state.addAttribute("module", module);
$_state.addTypes(results);
$_state.addTypes({$_builder.getType<::mlir::NoneType>()});
}]>, OpBuilder<
(ins "StringRef":$module, "TypeRange":$results,
CArg<"ValueRange", "{}">:$operands), [{

View File

@ -1647,8 +1647,13 @@ LogicalResult replaceCallOps(handshake::FuncOp f,
llvm::copy(callOp.getOperands(), std::back_inserter(operands));
operands.push_back(cntrlMg->getResult(0));
rewriter.setInsertionPoint(callOp);
rewriter.replaceOpWithNewOp<handshake::InstanceOp>(
callOp, callOp.getCallee(), callOp.getResultTypes(), operands);
auto instanceOp = rewriter.create<handshake::InstanceOp>(
callOp.getLoc(), callOp.getCallee(), callOp.getResultTypes(),
operands);
// Replace all results of the source callOp.
for (auto it : llvm::zip(callOp.getResults(), instanceOp.getResults()))
std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
rewriter.eraseOp(callOp);
}
}
}

View File

@ -16,9 +16,10 @@ func @foo(%0 : i32) -> i32 {
// CHECK-SAME: %[[VAL_0:.*]]: i32,
// CHECK-SAME: %[[VAL_1:.*]]: none, ...) -> (i32, none) {
// CHECK: %[[VAL_2:.*]] = "handshake.merge"(%[[VAL_0]]) : (i32) -> i32
// CHECK: %[[VAL_4:.*]]:2 = "handshake.fork"(%[[VAL_1]]) {control = true} : (none) -> (none, none)
// CHECK: %[[VAL_3:.*]] = handshake.instance @bar(%[[VAL_2]], %[[VAL_4]]#0) : (i32, none) -> i32
// CHECK: handshake.return %[[VAL_3]], %[[VAL_4]]#1 : i32, none
// CHECK: %[[VAL_3:.*]]:2 = "handshake.fork"(%[[VAL_1]]) {control = true} : (none) -> (none, none)
// CHECK: %[[VAL_4:.*]]:2 = handshake.instance @bar(%[[VAL_2]], %[[VAL_3]]#0) : (i32, none) -> (i32, none)
// CHECK: "handshake.sink"(%[[VAL_4]]#1) : (none) -> ()
// CHECK: handshake.return %[[VAL_4]]#0, %[[VAL_3]]#1 : i32, none
// CHECK: }
%a1 = call @bar(%0) : (i32) -> i32
@ -54,17 +55,19 @@ func @sub(%arg0 : i32, %arg1: i32) -> i32 {
// CHECK: %[[VAL_16:.*]]:2 = "handshake.control_merge"(%[[VAL_12]]) {control = true} : (none) -> (none, index)
// CHECK: %[[VAL_17:.*]]:2 = "handshake.fork"(%[[VAL_16]]#0) {control = true} : (none) -> (none, none)
// CHECK: "handshake.sink"(%[[VAL_16]]#1) : (index) -> ()
// CHECK: %[[VAL_18:.*]] = handshake.instance @add(%[[VAL_14]], %[[VAL_15]], %[[VAL_17]]#1) : (i32, i32, none) -> i32
// CHECK: %[[VAL_18:.*]]:2 = handshake.instance @add(%[[VAL_14]], %[[VAL_15]], %[[VAL_17]]#1) : (i32, i32, none) -> (i32, none)
// CHECK: "handshake.sink"(%[[VAL_18]]#1) : (none) -> ()
// CHECK: %[[VAL_19:.*]] = "handshake.branch"(%[[VAL_17]]#0) {control = true} : (none) -> none
// CHECK: %[[VAL_20:.*]] = "handshake.branch"(%[[VAL_18]]) {control = false} : (i32) -> i32
// CHECK: %[[VAL_20:.*]] = "handshake.branch"(%[[VAL_18]]#0) {control = false} : (i32) -> i32
// CHECK: %[[VAL_21:.*]] = "handshake.merge"(%[[VAL_9]]) : (i32) -> i32
// CHECK: %[[VAL_22:.*]] = "handshake.merge"(%[[VAL_11]]) : (i32) -> i32
// CHECK: %[[VAL_23:.*]]:2 = "handshake.control_merge"(%[[VAL_13]]) {control = true} : (none) -> (none, index)
// CHECK: %[[VAL_24:.*]]:2 = "handshake.fork"(%[[VAL_23]]#0) {control = true} : (none) -> (none, none)
// CHECK: "handshake.sink"(%[[VAL_23]]#1) : (index) -> ()
// CHECK: %[[VAL_25:.*]] = handshake.instance @sub(%[[VAL_21]], %[[VAL_22]], %[[VAL_24]]#1) : (i32, i32, none) -> i32
// CHECK: %[[VAL_25:.*]]:2 = handshake.instance @sub(%[[VAL_21]], %[[VAL_22]], %[[VAL_24]]#1) : (i32, i32, none) -> (i32, none)
// CHECK: "handshake.sink"(%[[VAL_25]]#1) : (none) -> ()
// CHECK: %[[VAL_26:.*]] = "handshake.branch"(%[[VAL_24]]#0) {control = true} : (none) -> none
// CHECK: %[[VAL_27:.*]] = "handshake.branch"(%[[VAL_25]]) {control = false} : (i32) -> i32
// CHECK: %[[VAL_27:.*]] = "handshake.branch"(%[[VAL_25]]#0) {control = false} : (i32) -> i32
// CHECK: %[[VAL_28:.*]]:2 = "handshake.control_merge"(%[[VAL_26]], %[[VAL_19]]) {control = true} : (none, none) -> (none, index)
// CHECK: %[[VAL_29:.*]] = "handshake.mux"(%[[VAL_28]]#1, %[[VAL_27]], %[[VAL_20]]) : (index, i32, i32) -> i32
// CHECK: handshake.return %[[VAL_29]], %[[VAL_28]]#0 : i32, none

View File

@ -23,7 +23,7 @@ module {
%c108 = arith.constant 108 : index
%c109 = arith.constant 109 : index
%c2 = arith.constant 2 : index
%3 = call @muladd(%c104, %c2, %c103) : (index, index, index) -> index
%3 = call @muladd(%c104, %c2, %c103) : (index, index, index) -> index
%c3 = arith.constant 3 : index
%4 = arith.muli %c105, %c3 : index
%5 = arith.addi %3, %4 : index

View File

@ -586,8 +586,8 @@ LogicalResult HandshakeExecuter::execute(handshake::InstanceOp instanceOp,
entryBlock.getArguments();
// Create a new value map containing only the arguments of the
// InstanceOp. This will be the value and time map for the scope of the
// function pointed to by the InstanceOp.
// InstanceOp. This will be the value and time map for the callee scope of
// the function pointed to by the InstanceOp.
llvm::DenseMap<mlir::Value, Any> scopeValueMap;
llvm::DenseMap<mlir::Value, double> scopeTimeMap;
@ -605,9 +605,12 @@ LogicalResult HandshakeExecuter::execute(handshake::InstanceOp instanceOp,
apnonearg;
std::vector<Any> nestedResults(nRealFuncOuts);
std::vector<double> nestedResTimes(nRealFuncOuts);
// Go execute!
HandshakeExecuter(func, scopeValueMap, scopeTimeMap, nestedResults,
nestedResTimes, store, storeTimes, *module);
// Place the output arguments in the caller scope.
for (auto nestedRes : enumerate(nestedResults)) {
out[nestedRes.index()] = nestedRes.value();
valueMap[instanceOp.getResults()[nestedRes.index()]] =
@ -615,6 +618,11 @@ LogicalResult HandshakeExecuter::execute(handshake::InstanceOp instanceOp,
timeMap[instanceOp.getResults()[nestedRes.index()]] =
nestedResTimes[nestedRes.index()];
}
// ... and the implicit none argument
unsigned ctrlResultIdx = instanceOp.getNumResults() - 1;
valueMap[instanceOp->getResult(ctrlResultIdx)] = apnonearg;
out[ctrlResultIdx] = apnonearg;
return success();
} else {
return instanceOp.emitError()