[HLSKernel] update CNNOps verifiers; [BenchmarkGen] fix bugs in CNNOps generation
This commit is contained in:
parent
cbc7573c16
commit
1f740b15f7
|
@ -26,7 +26,7 @@ void HLSKernelDialect::initialize() {
|
|||
/// Verify that all memref operands of the operation have static shape.
|
||||
static bool verifyStaticShape(Operation *op) {
|
||||
for (auto operand : op->getOperands()) {
|
||||
if (auto operandType = operand.getType().dyn_cast<MemRefType>()) {
|
||||
if (auto operandType = operand.getType().dyn_cast<ShapedType>()) {
|
||||
if (!operandType.hasStaticShape())
|
||||
return false;
|
||||
}
|
||||
|
@ -42,10 +42,14 @@ static LogicalResult verify(DenseOp op) {
|
|||
if (!verifyStaticShape(op))
|
||||
return op.emitError("not all operands have static shape");
|
||||
|
||||
auto IShape = op.getOperand(0).getType().cast<MemRefType>().getShape();
|
||||
auto KShape = op.getOperand(1).getType().cast<MemRefType>().getShape();
|
||||
auto BShape = op.getOperand(2).getType().cast<MemRefType>().getShape();
|
||||
auto OShape = op.getOperand(3).getType().cast<MemRefType>().getShape();
|
||||
auto IShape = op.getOperand(0).getType().cast<ShapedType>().getShape();
|
||||
auto KShape = op.getOperand(1).getType().cast<ShapedType>().getShape();
|
||||
auto BShape = op.getOperand(2).getType().cast<ShapedType>().getShape();
|
||||
ArrayRef<int64_t> OShape;
|
||||
if (op.getNumResults())
|
||||
OShape = op.getResult(0).getType().cast<ShapedType>().getShape();
|
||||
else
|
||||
OShape = op.getOperand(3).getType().cast<ShapedType>().getShape();
|
||||
|
||||
if ((IShape.size() != 2 && IShape.size() != 4) ||
|
||||
(KShape.size() != 2 && KShape.size() != 4) || BShape.size() != 1 ||
|
||||
|
@ -78,10 +82,14 @@ static LogicalResult verify(ConvOp op) {
|
|||
if (!verifyStaticShape(op))
|
||||
return op.emitError("not all operands have static shape");
|
||||
|
||||
auto IShape = op.getOperand(0).getType().cast<MemRefType>().getShape();
|
||||
auto KShape = op.getOperand(1).getType().cast<MemRefType>().getShape();
|
||||
auto BShape = op.getOperand(2).getType().cast<MemRefType>().getShape();
|
||||
auto OShape = op.getOperand(3).getType().cast<MemRefType>().getShape();
|
||||
auto IShape = op.getOperand(0).getType().cast<ShapedType>().getShape();
|
||||
auto KShape = op.getOperand(1).getType().cast<ShapedType>().getShape();
|
||||
auto BShape = op.getOperand(2).getType().cast<ShapedType>().getShape();
|
||||
ArrayRef<int64_t> OShape;
|
||||
if (op.getNumResults())
|
||||
OShape = op.getResult(0).getType().cast<ShapedType>().getShape();
|
||||
else
|
||||
OShape = op.getOperand(3).getType().cast<ShapedType>().getShape();
|
||||
|
||||
SmallVector<int64_t, 2> padding;
|
||||
for (auto shape : op.getAttrOfType<ArrayAttr>("padding"))
|
||||
|
@ -125,8 +133,12 @@ static LogicalResult verify(MaxPoolOp op) {
|
|||
if (!verifyStaticShape(op))
|
||||
return op.emitError("not all operands have static shape");
|
||||
|
||||
auto IShape = op.getOperand(0).getType().cast<MemRefType>().getShape();
|
||||
auto OShape = op.getOperand(1).getType().cast<MemRefType>().getShape();
|
||||
auto IShape = op.getOperand(0).getType().cast<ShapedType>().getShape();
|
||||
ArrayRef<int64_t> OShape;
|
||||
if (op.getNumResults())
|
||||
OShape = op.getResult(0).getType().cast<ShapedType>().getShape();
|
||||
else
|
||||
OShape = op.getOperand(1).getType().cast<ShapedType>().getShape();
|
||||
|
||||
SmallVector<int64_t, 2> kernelShape;
|
||||
for (auto shape : op.getAttrOfType<ArrayAttr>("kernel_shape"))
|
||||
|
@ -176,8 +188,12 @@ static LogicalResult verify(ReluOp op) {
|
|||
if (!verifyStaticShape(op))
|
||||
return op.emitError("not all operands have static shape");
|
||||
|
||||
auto IShape = op.getOperand(0).getType().cast<MemRefType>().getShape();
|
||||
auto OShape = op.getOperand(1).getType().cast<MemRefType>().getShape();
|
||||
auto IShape = op.getOperand(0).getType().cast<ShapedType>().getShape();
|
||||
ArrayRef<int64_t> OShape;
|
||||
if (op.getNumResults())
|
||||
OShape = op.getResult(0).getType().cast<ShapedType>().getShape();
|
||||
else
|
||||
OShape = op.getOperand(1).getType().cast<ShapedType>().getShape();
|
||||
|
||||
if (IShape != OShape)
|
||||
return op.emitError("incorrect operand shape, please refer to the op "
|
||||
|
@ -194,9 +210,13 @@ static LogicalResult verify(MergeOp op) {
|
|||
if (!verifyStaticShape(op))
|
||||
return op.emitError("not all operands have static shape");
|
||||
|
||||
auto I0Shape = op.getOperand(0).getType().cast<MemRefType>().getShape();
|
||||
auto I1Shape = op.getOperand(1).getType().cast<MemRefType>().getShape();
|
||||
auto OShape = op.getOperand(2).getType().cast<MemRefType>().getShape();
|
||||
auto I0Shape = op.getOperand(0).getType().cast<ShapedType>().getShape();
|
||||
auto I1Shape = op.getOperand(1).getType().cast<ShapedType>().getShape();
|
||||
ArrayRef<int64_t> OShape;
|
||||
if (op.getNumResults())
|
||||
OShape = op.getResult(0).getType().cast<ShapedType>().getShape();
|
||||
else
|
||||
OShape = op.getOperand(2).getType().cast<ShapedType>().getShape();
|
||||
|
||||
if (I0Shape != OShape || I1Shape != OShape)
|
||||
return op.emitError("incorrect operand shape, please refer to the op "
|
||||
|
|
|
@ -128,8 +128,8 @@ LogicalResult BenchmarkGenerator::genCNN(INIReader config) {
|
|||
SmallVector<mlir::Type, 2> inputTypes;
|
||||
inputTypes.push_back(
|
||||
getTensorType({batchSize, inputChannel, inputHeight, inputWidth}));
|
||||
inputTypes.push_back(getTensorType({batchSize, outputChannel}));
|
||||
SmallVector<mlir::Type, 2> outputTypes;
|
||||
outputTypes.push_back(getTensorType({batchSize, outputChannel}));
|
||||
|
||||
auto func = builder.create<FuncOp>(
|
||||
loc, "tmp", builder.getFunctionType(inputTypes, outputTypes));
|
||||
|
@ -205,17 +205,17 @@ LogicalResult BenchmarkGenerator::genCNN(INIReader config) {
|
|||
}
|
||||
|
||||
// Create the last dense layer.
|
||||
fmaps.push_back(func.getArgument(1));
|
||||
kernels.push_back(builder.create<mlir::AllocOp>(
|
||||
loc, getMemType({outputChannel, topChannel, topHeight, topWidth})));
|
||||
biases.push_back(
|
||||
builder.create<mlir::AllocOp>(loc, getMemType({outputChannel})));
|
||||
|
||||
builder.create<DenseOp>(loc, ArrayRef<mlir::Type>(),
|
||||
*std::prev(fmaps.end(), 2), kernels.back(),
|
||||
biases.back(), fmaps.back());
|
||||
auto denseLayer = builder.create<DenseOp>(
|
||||
loc, getTensorType({batchSize, outputChannel}), fmaps.back(),
|
||||
kernels.back(), biases.back(), nullptr);
|
||||
fmaps.push_back(denseLayer.getResult(0));
|
||||
|
||||
builder.create<mlir::ReturnOp>(loc);
|
||||
builder.create<mlir::ReturnOp>(loc, denseLayer.getResult(0));
|
||||
|
||||
// Add bypass paths to the current model.
|
||||
// Ensure the specified bypass number is available. Since the last dense layer
|
||||
|
|
Loading…
Reference in New Issue