diff --git a/lib/Dialect/LLHD/IR/LLHDOps.cpp b/lib/Dialect/LLHD/IR/LLHDOps.cpp index 3a4e5f88b9..5f99fa4266 100644 --- a/lib/Dialect/LLHD/IR/LLHDOps.cpp +++ b/lib/Dialect/LLHD/IR/LLHDOps.cpp @@ -841,10 +841,6 @@ static LogicalResult verify(llhd::InstOp op) { auto proc = op->getParentOfType().lookupSymbol( calleeAttr.getValue()); - auto entity = op->getParentOfType().lookupSymbol( - calleeAttr.getValue()); - - // Verify that the input and output types match the callee. if (proc) { auto type = proc.getType(); @@ -856,12 +852,16 @@ static LogicalResult verify(llhd::InstOp op) { return op.emitOpError( "incorrect number of outputs for proc instantiation"); - for (size_t i = 0, e = type.getNumInputs(); i != e; ++i) + for (size_t i = 0, e = type.getNumInputs(); i != e; ++i) { if (op.getOperand(i).getType() != type.getInput(i)) return op.emitOpError("operand type mismatch"); + } return success(); } + + auto entity = op->getParentOfType().lookupSymbol( + calleeAttr.getValue()); if (entity) { auto type = entity.getType(); @@ -873,14 +873,51 @@ static LogicalResult verify(llhd::InstOp op) { return op.emitOpError( "incorrect number of outputs for entity instantiation"); - for (size_t i = 0, e = type.getNumInputs(); i != e; ++i) + for (size_t i = 0, e = type.getNumInputs(); i != e; ++i) { if (op.getOperand(i).getType() != type.getInput(i)) return op.emitOpError("operand type mismatch"); + } return success(); } - return op.emitOpError() << "'" << calleeAttr.getValue() - << "' does not reference a valid proc or entity"; + + auto module = op->getParentOfType().lookupSymbol( + calleeAttr.getValue()); + if (module) { + auto type = module.getType(); + + if (type.getNumInputs() != op.inputs().size()) + return op.emitOpError( + "incorrect number of inputs for hw.module instantiation"); + + if (type.getNumResults() + type.getNumInputs() != op.getNumOperands()) + return op.emitOpError( + "incorrect number of outputs for hw.module instantiation"); + + // Check input types + for (size_t i = 0, e = type.getNumInputs(); i != e; ++i) { + if (op.getOperand(i) + .getType() + .cast() + .getUnderlyingType() != type.getInput(i)) + return op.emitOpError("input type mismatch"); + } + + // Check output types + for (size_t i = 0, e = type.getNumResults(); i != e; ++i) { + if (op.getOperand(type.getNumInputs() + i) + .getType() + .cast() + .getUnderlyingType() != type.getResult(i)) + return op.emitOpError("output type mismatch"); + } + + return success(); + } + + return op.emitOpError() + << "'" << calleeAttr.getValue() + << "' does not reference a valid proc, entity, or hw.module"; } FunctionType llhd::InstOp::getCalleeType() { diff --git a/test/Dialect/LLHD/IR/inst-errors.mlir b/test/Dialect/LLHD/IR/inst-errors.mlir index 0ffe1de2a2..0edeac0faa 100644 --- a/test/Dialect/LLHD/IR/inst-errors.mlir +++ b/test/Dialect/LLHD/IR/inst-errors.mlir @@ -28,7 +28,7 @@ llhd.entity @caller(%arg : !llhd.sig) -> () { // ----- llhd.entity @caller() -> () { - // expected-error @+1 {{does not reference a valid proc or entity}} + // expected-error @+1 {{does not reference a valid proc, entity, or hw.module}} llhd.inst "does_not_exist" @does_not_exist() -> () : () -> () } @@ -41,3 +41,24 @@ llhd.entity @test_uniqueness() -> () { // expected-error @+1 {{Redefinition of instance named 'inst'!}} llhd.inst "inst" @empty() -> () : () -> () } + +// ----- + +hw.module @module(%arg0: i2) -> () {} + +llhd.entity @moduleTypeMismatch(%arg0: !llhd.sig) -> () { + // expected-error @+1 {{input type mismatch}} + llhd.inst "inst" @module(%arg0) -> () : (!llhd.sig) -> () +} + +// ----- + +hw.module @module() -> (arg0: i2) { + %0 = hw.constant 0 : i2 + hw.output %0 : i2 +} + +llhd.entity @moduleTypeMismatch() -> (%arg0: !llhd.sig) { + // expected-error @+1 {{output type mismatch}} + llhd.inst "inst" @module() -> (%arg0) : () -> !llhd.sig +} diff --git a/test/Dialect/LLHD/IR/inst.mlir b/test/Dialect/LLHD/IR/inst.mlir index 2367c01eab..efbf8d426a 100644 --- a/test/Dialect/LLHD/IR/inst.mlir +++ b/test/Dialect/LLHD/IR/inst.mlir @@ -39,6 +39,12 @@ llhd.proc @proc(%arg0 : !llhd.sig, %arg1 : !llhd.sig) -> (%out0 : !llh llhd.halt } +// CHECK-LABEL: @hwModule +hw.module @hwModule(%arg0 : i32, %arg1 : i16) -> (arg2: i8) { + %0 = hw.constant 2 : i8 + hw.output %0 : i8 +} + // CHECK: llhd.entity @caller (%[[ARG0:.*]] : !llhd.sig, %[[ARG1:.*]] : !llhd.sig) -> (%[[OUT0:.*]] : !llhd.sig, %[[OUT1:.*]] : !llhd.sig) { llhd.entity @caller(%arg0 : !llhd.sig, %arg1 : !llhd.sig) -> (%out0 : !llhd.sig, %out1 : !llhd.sig) { // CHECK-NEXT: llhd.inst "empty_entity" @empty_entity() -> () : () -> () @@ -57,5 +63,7 @@ llhd.entity @caller(%arg0 : !llhd.sig, %arg1 : !llhd.sig) -> (%out0 : "llhd.inst"(%arg0, %arg1, %out0, %out1) {callee=@entity, operand_segment_sizes=dense<[2,2]> : vector<2xi32>, name="entity"} : (!llhd.sig, !llhd.sig, !llhd.sig, !llhd.sig) -> () // CHECK-NEXT: llhd.inst "proc" @proc(%[[ARG0]], %[[ARG1]]) -> (%[[OUT0]], %[[OUT1]]) : (!llhd.sig, !llhd.sig) -> (!llhd.sig, !llhd.sig) "llhd.inst"(%arg0, %arg1, %out0, %out1) {callee=@proc, operand_segment_sizes=dense<[2,2]> : vector<2xi32>, name="proc"} : (!llhd.sig, !llhd.sig, !llhd.sig, !llhd.sig) -> () + // CHECK-NEXT: llhd.inst "module" @hwModule(%[[ARG0]], %[[ARG1]]) -> (%[[OUT0]]) : (!llhd.sig, !llhd.sig) -> !llhd.sig + llhd.inst "module" @hwModule(%arg0, %arg1) -> (%out0) : (!llhd.sig, !llhd.sig) -> !llhd.sig // CHECK-NEXT: } }