diff --git a/lib/Dialect/HLSKernel/HLSKernel.cpp b/lib/Dialect/HLSKernel/HLSKernel.cpp
index 89ce503..50d101d 100644
--- a/lib/Dialect/HLSKernel/HLSKernel.cpp
+++ b/lib/Dialect/HLSKernel/HLSKernel.cpp
@@ -26,7 +26,7 @@ void HLSKernelDialect::initialize() {
 /// Verify that all memref operands of the operation have static shape.
 static bool verifyStaticShape(Operation *op) {
   for (auto operand : op->getOperands()) {
-    if (auto operandType = operand.getType().dyn_cast<MemRefType>()) {
+    if (auto operandType = operand.getType().dyn_cast<ShapedType>()) {
       if (!operandType.hasStaticShape())
         return false;
     }
@@ -42,10 +42,14 @@ static LogicalResult verify(DenseOp op) {
   if (!verifyStaticShape(op))
     return op.emitError("not all operands have static shape");
 
-  auto IShape = op.getOperand(0).getType().cast<MemRefType>().getShape();
-  auto KShape = op.getOperand(1).getType().cast<MemRefType>().getShape();
-  auto BShape = op.getOperand(2).getType().cast<MemRefType>().getShape();
-  auto OShape = op.getOperand(3).getType().cast<MemRefType>().getShape();
+  auto IShape = op.getOperand(0).getType().cast<ShapedType>().getShape();
+  auto KShape = op.getOperand(1).getType().cast<ShapedType>().getShape();
+  auto BShape = op.getOperand(2).getType().cast<ShapedType>().getShape();
+  ArrayRef<int64_t> OShape;
+  if (op.getNumResults())
+    OShape = op.getResult(0).getType().cast<ShapedType>().getShape();
+  else
+    OShape = op.getOperand(3).getType().cast<ShapedType>().getShape();
 
   if ((IShape.size() != 2 && IShape.size() != 4) ||
       (KShape.size() != 2 && KShape.size() != 4) || BShape.size() != 1 ||
@@ -78,10 +82,14 @@ static LogicalResult verify(ConvOp op) {
   if (!verifyStaticShape(op))
     return op.emitError("not all operands have static shape");
 
-  auto IShape = op.getOperand(0).getType().cast<MemRefType>().getShape();
-  auto KShape = op.getOperand(1).getType().cast<MemRefType>().getShape();
-  auto BShape = op.getOperand(2).getType().cast<MemRefType>().getShape();
-  auto OShape = op.getOperand(3).getType().cast<MemRefType>().getShape();
+  auto IShape = op.getOperand(0).getType().cast<ShapedType>().getShape();
+  auto KShape = op.getOperand(1).getType().cast<ShapedType>().getShape();
+  auto BShape = op.getOperand(2).getType().cast<ShapedType>().getShape();
+  ArrayRef<int64_t> OShape;
+  if (op.getNumResults())
+    OShape = op.getResult(0).getType().cast<ShapedType>().getShape();
+  else
+    OShape = op.getOperand(3).getType().cast<ShapedType>().getShape();
 
   SmallVector<int64_t, 4> padding;
   for (auto shape : op.getAttrOfType<ArrayAttr>("padding"))
@@ -125,8 +133,12 @@ static LogicalResult verify(MaxPoolOp op) {
   if (!verifyStaticShape(op))
     return op.emitError("not all operands have static shape");
 
-  auto IShape = op.getOperand(0).getType().cast<MemRefType>().getShape();
-  auto OShape = op.getOperand(1).getType().cast<MemRefType>().getShape();
+  auto IShape = op.getOperand(0).getType().cast<ShapedType>().getShape();
+  ArrayRef<int64_t> OShape;
+  if (op.getNumResults())
+    OShape = op.getResult(0).getType().cast<ShapedType>().getShape();
+  else
+    OShape = op.getOperand(1).getType().cast<ShapedType>().getShape();
 
   SmallVector<int64_t, 2> kernelShape;
   for (auto shape : op.getAttrOfType<ArrayAttr>("kernel_shape"))
@@ -176,8 +188,12 @@ static LogicalResult verify(ReluOp op) {
   if (!verifyStaticShape(op))
     return op.emitError("not all operands have static shape");
 
-  auto IShape = op.getOperand(0).getType().cast<MemRefType>().getShape();
-  auto OShape = op.getOperand(1).getType().cast<MemRefType>().getShape();
+  auto IShape = op.getOperand(0).getType().cast<ShapedType>().getShape();
+  ArrayRef<int64_t> OShape;
+  if (op.getNumResults())
+    OShape = op.getResult(0).getType().cast<ShapedType>().getShape();
+  else
+    OShape = op.getOperand(1).getType().cast<ShapedType>().getShape();
 
   if (IShape != OShape)
     return op.emitError("incorrect operand shape, please refer to the op "
@@ -194,9 +210,13 @@ static LogicalResult verify(MergeOp op) {
   if (!verifyStaticShape(op))
     return op.emitError("not all operands have static shape");
 
-  auto I0Shape = op.getOperand(0).getType().cast<MemRefType>().getShape();
-  auto I1Shape = op.getOperand(1).getType().cast<MemRefType>().getShape();
-  auto OShape = op.getOperand(2).getType().cast<MemRefType>().getShape();
+  auto I0Shape = op.getOperand(0).getType().cast<ShapedType>().getShape();
+  auto I1Shape = op.getOperand(1).getType().cast<ShapedType>().getShape();
+  ArrayRef<int64_t> OShape;
+  if (op.getNumResults())
+    OShape = op.getResult(0).getType().cast<ShapedType>().getShape();
+  else
+    OShape = op.getOperand(2).getType().cast<ShapedType>().getShape();
 
   if (I0Shape != OShape || I1Shape != OShape)
     return op.emitError("incorrect operand shape, please refer to the op "
diff --git a/tools/benchmark-gen/benchmark-gen.cpp b/tools/benchmark-gen/benchmark-gen.cpp
index 953eb0f..1c2f8f1 100644
--- a/tools/benchmark-gen/benchmark-gen.cpp
+++ b/tools/benchmark-gen/benchmark-gen.cpp
@@ -128,8 +128,8 @@ LogicalResult BenchmarkGenerator::genCNN(INIReader config) {
   SmallVector<Type, 2> inputTypes;
   inputTypes.push_back(
       getTensorType({batchSize, inputChannel, inputHeight, inputWidth}));
-  inputTypes.push_back(getTensorType({batchSize, outputChannel}));
   SmallVector<Type, 2> outputTypes;
+  outputTypes.push_back(getTensorType({batchSize, outputChannel}));
 
   auto func = builder.create<FuncOp>(
       loc, "tmp", builder.getFunctionType(inputTypes, outputTypes));
@@ -205,17 +205,17 @@ LogicalResult BenchmarkGenerator::genCNN(INIReader config) {
   }
 
   // Create the last dense layer.
-  fmaps.push_back(func.getArgument(1));
   kernels.push_back(builder.create<mlir::AllocOp>(
       loc, getMemType({outputChannel, topChannel, topHeight, topWidth})));
   biases.push_back(
       builder.create<mlir::AllocOp>(loc, getMemType({outputChannel})));
 
-  builder.create<DenseOp>(loc, ArrayRef<Type>(),
-                          *std::prev(fmaps.end(), 2), kernels.back(),
-                          biases.back(), fmaps.back());
+  auto denseLayer = builder.create<DenseOp>(
+      loc, getTensorType({batchSize, outputChannel}), fmaps.back(),
+      kernels.back(), biases.back(), nullptr);
+  fmaps.push_back(denseLayer.getResult(0));
 
-  builder.create<mlir::ReturnOp>(loc);
+  builder.create<mlir::ReturnOp>(loc, denseLayer.getResult(0));
 
   // Add bypass paths to the current model.
   // Ensure the specified bypass number is available. Since the last dense layer