[HLSKernel] update CNNOps verifiers; [BenchmarkGen] fix bugs in CNNOps generation
This commit is contained in:
parent
cbc7573c16
commit
1f740b15f7
|
@ -26,7 +26,7 @@ void HLSKernelDialect::initialize() {
|
||||||
/// Verify that all memref operands of the operation have static shape.
|
/// Verify that all memref operands of the operation have static shape.
|
||||||
static bool verifyStaticShape(Operation *op) {
|
static bool verifyStaticShape(Operation *op) {
|
||||||
for (auto operand : op->getOperands()) {
|
for (auto operand : op->getOperands()) {
|
||||||
if (auto operandType = operand.getType().dyn_cast<MemRefType>()) {
|
if (auto operandType = operand.getType().dyn_cast<ShapedType>()) {
|
||||||
if (!operandType.hasStaticShape())
|
if (!operandType.hasStaticShape())
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -42,10 +42,14 @@ static LogicalResult verify(DenseOp op) {
|
||||||
if (!verifyStaticShape(op))
|
if (!verifyStaticShape(op))
|
||||||
return op.emitError("not all operands have static shape");
|
return op.emitError("not all operands have static shape");
|
||||||
|
|
||||||
auto IShape = op.getOperand(0).getType().cast<MemRefType>().getShape();
|
auto IShape = op.getOperand(0).getType().cast<ShapedType>().getShape();
|
||||||
auto KShape = op.getOperand(1).getType().cast<MemRefType>().getShape();
|
auto KShape = op.getOperand(1).getType().cast<ShapedType>().getShape();
|
||||||
auto BShape = op.getOperand(2).getType().cast<MemRefType>().getShape();
|
auto BShape = op.getOperand(2).getType().cast<ShapedType>().getShape();
|
||||||
auto OShape = op.getOperand(3).getType().cast<MemRefType>().getShape();
|
ArrayRef<int64_t> OShape;
|
||||||
|
if (op.getNumResults())
|
||||||
|
OShape = op.getResult(0).getType().cast<ShapedType>().getShape();
|
||||||
|
else
|
||||||
|
OShape = op.getOperand(3).getType().cast<ShapedType>().getShape();
|
||||||
|
|
||||||
if ((IShape.size() != 2 && IShape.size() != 4) ||
|
if ((IShape.size() != 2 && IShape.size() != 4) ||
|
||||||
(KShape.size() != 2 && KShape.size() != 4) || BShape.size() != 1 ||
|
(KShape.size() != 2 && KShape.size() != 4) || BShape.size() != 1 ||
|
||||||
|
@ -78,10 +82,14 @@ static LogicalResult verify(ConvOp op) {
|
||||||
if (!verifyStaticShape(op))
|
if (!verifyStaticShape(op))
|
||||||
return op.emitError("not all operands have static shape");
|
return op.emitError("not all operands have static shape");
|
||||||
|
|
||||||
auto IShape = op.getOperand(0).getType().cast<MemRefType>().getShape();
|
auto IShape = op.getOperand(0).getType().cast<ShapedType>().getShape();
|
||||||
auto KShape = op.getOperand(1).getType().cast<MemRefType>().getShape();
|
auto KShape = op.getOperand(1).getType().cast<ShapedType>().getShape();
|
||||||
auto BShape = op.getOperand(2).getType().cast<MemRefType>().getShape();
|
auto BShape = op.getOperand(2).getType().cast<ShapedType>().getShape();
|
||||||
auto OShape = op.getOperand(3).getType().cast<MemRefType>().getShape();
|
ArrayRef<int64_t> OShape;
|
||||||
|
if (op.getNumResults())
|
||||||
|
OShape = op.getResult(0).getType().cast<ShapedType>().getShape();
|
||||||
|
else
|
||||||
|
OShape = op.getOperand(3).getType().cast<ShapedType>().getShape();
|
||||||
|
|
||||||
SmallVector<int64_t, 2> padding;
|
SmallVector<int64_t, 2> padding;
|
||||||
for (auto shape : op.getAttrOfType<ArrayAttr>("padding"))
|
for (auto shape : op.getAttrOfType<ArrayAttr>("padding"))
|
||||||
|
@ -125,8 +133,12 @@ static LogicalResult verify(MaxPoolOp op) {
|
||||||
if (!verifyStaticShape(op))
|
if (!verifyStaticShape(op))
|
||||||
return op.emitError("not all operands have static shape");
|
return op.emitError("not all operands have static shape");
|
||||||
|
|
||||||
auto IShape = op.getOperand(0).getType().cast<MemRefType>().getShape();
|
auto IShape = op.getOperand(0).getType().cast<ShapedType>().getShape();
|
||||||
auto OShape = op.getOperand(1).getType().cast<MemRefType>().getShape();
|
ArrayRef<int64_t> OShape;
|
||||||
|
if (op.getNumResults())
|
||||||
|
OShape = op.getResult(0).getType().cast<ShapedType>().getShape();
|
||||||
|
else
|
||||||
|
OShape = op.getOperand(1).getType().cast<ShapedType>().getShape();
|
||||||
|
|
||||||
SmallVector<int64_t, 2> kernelShape;
|
SmallVector<int64_t, 2> kernelShape;
|
||||||
for (auto shape : op.getAttrOfType<ArrayAttr>("kernel_shape"))
|
for (auto shape : op.getAttrOfType<ArrayAttr>("kernel_shape"))
|
||||||
|
@ -176,8 +188,12 @@ static LogicalResult verify(ReluOp op) {
|
||||||
if (!verifyStaticShape(op))
|
if (!verifyStaticShape(op))
|
||||||
return op.emitError("not all operands have static shape");
|
return op.emitError("not all operands have static shape");
|
||||||
|
|
||||||
auto IShape = op.getOperand(0).getType().cast<MemRefType>().getShape();
|
auto IShape = op.getOperand(0).getType().cast<ShapedType>().getShape();
|
||||||
auto OShape = op.getOperand(1).getType().cast<MemRefType>().getShape();
|
ArrayRef<int64_t> OShape;
|
||||||
|
if (op.getNumResults())
|
||||||
|
OShape = op.getResult(0).getType().cast<ShapedType>().getShape();
|
||||||
|
else
|
||||||
|
OShape = op.getOperand(1).getType().cast<ShapedType>().getShape();
|
||||||
|
|
||||||
if (IShape != OShape)
|
if (IShape != OShape)
|
||||||
return op.emitError("incorrect operand shape, please refer to the op "
|
return op.emitError("incorrect operand shape, please refer to the op "
|
||||||
|
@ -194,9 +210,13 @@ static LogicalResult verify(MergeOp op) {
|
||||||
if (!verifyStaticShape(op))
|
if (!verifyStaticShape(op))
|
||||||
return op.emitError("not all operands have static shape");
|
return op.emitError("not all operands have static shape");
|
||||||
|
|
||||||
auto I0Shape = op.getOperand(0).getType().cast<MemRefType>().getShape();
|
auto I0Shape = op.getOperand(0).getType().cast<ShapedType>().getShape();
|
||||||
auto I1Shape = op.getOperand(1).getType().cast<MemRefType>().getShape();
|
auto I1Shape = op.getOperand(1).getType().cast<ShapedType>().getShape();
|
||||||
auto OShape = op.getOperand(2).getType().cast<MemRefType>().getShape();
|
ArrayRef<int64_t> OShape;
|
||||||
|
if (op.getNumResults())
|
||||||
|
OShape = op.getResult(0).getType().cast<ShapedType>().getShape();
|
||||||
|
else
|
||||||
|
OShape = op.getOperand(2).getType().cast<ShapedType>().getShape();
|
||||||
|
|
||||||
if (I0Shape != OShape || I1Shape != OShape)
|
if (I0Shape != OShape || I1Shape != OShape)
|
||||||
return op.emitError("incorrect operand shape, please refer to the op "
|
return op.emitError("incorrect operand shape, please refer to the op "
|
||||||
|
|
|
@ -128,8 +128,8 @@ LogicalResult BenchmarkGenerator::genCNN(INIReader config) {
|
||||||
SmallVector<mlir::Type, 2> inputTypes;
|
SmallVector<mlir::Type, 2> inputTypes;
|
||||||
inputTypes.push_back(
|
inputTypes.push_back(
|
||||||
getTensorType({batchSize, inputChannel, inputHeight, inputWidth}));
|
getTensorType({batchSize, inputChannel, inputHeight, inputWidth}));
|
||||||
inputTypes.push_back(getTensorType({batchSize, outputChannel}));
|
|
||||||
SmallVector<mlir::Type, 2> outputTypes;
|
SmallVector<mlir::Type, 2> outputTypes;
|
||||||
|
outputTypes.push_back(getTensorType({batchSize, outputChannel}));
|
||||||
|
|
||||||
auto func = builder.create<FuncOp>(
|
auto func = builder.create<FuncOp>(
|
||||||
loc, "tmp", builder.getFunctionType(inputTypes, outputTypes));
|
loc, "tmp", builder.getFunctionType(inputTypes, outputTypes));
|
||||||
|
@ -205,17 +205,17 @@ LogicalResult BenchmarkGenerator::genCNN(INIReader config) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the last dense layer.
|
// Create the last dense layer.
|
||||||
fmaps.push_back(func.getArgument(1));
|
|
||||||
kernels.push_back(builder.create<mlir::AllocOp>(
|
kernels.push_back(builder.create<mlir::AllocOp>(
|
||||||
loc, getMemType({outputChannel, topChannel, topHeight, topWidth})));
|
loc, getMemType({outputChannel, topChannel, topHeight, topWidth})));
|
||||||
biases.push_back(
|
biases.push_back(
|
||||||
builder.create<mlir::AllocOp>(loc, getMemType({outputChannel})));
|
builder.create<mlir::AllocOp>(loc, getMemType({outputChannel})));
|
||||||
|
|
||||||
builder.create<DenseOp>(loc, ArrayRef<mlir::Type>(),
|
auto denseLayer = builder.create<DenseOp>(
|
||||||
*std::prev(fmaps.end(), 2), kernels.back(),
|
loc, getTensorType({batchSize, outputChannel}), fmaps.back(),
|
||||||
biases.back(), fmaps.back());
|
kernels.back(), biases.back(), nullptr);
|
||||||
|
fmaps.push_back(denseLayer.getResult(0));
|
||||||
|
|
||||||
builder.create<mlir::ReturnOp>(loc);
|
builder.create<mlir::ReturnOp>(loc, denseLayer.getResult(0));
|
||||||
|
|
||||||
// Add bypass paths to the current model.
|
// Add bypass paths to the current model.
|
||||||
// Ensure the specified bypass number is available. Since the last dense layer
|
// Ensure the specified bypass number is available. Since the last dense layer
|
||||||
|
|
Loading…
Reference in New Issue