[HLSKernel] update CNNOps verifiers; [BenchmarkGen] fix bugs in CNNOps generation

Hanchen Ye 2020-12-10 19:07:11 -06:00
parent cbc7573c16
commit 1f740b15f7
2 changed files with 42 additions and 22 deletions

@@ -26,7 +26,7 @@ void HLSKernelDialect::initialize() {
 /// Verify that all memref operands of the operation have static shape.
 static bool verifyStaticShape(Operation *op) {
   for (auto operand : op->getOperands()) {
-    if (auto operandType = operand.getType().dyn_cast<MemRefType>()) {
+    if (auto operandType = operand.getType().dyn_cast<ShapedType>()) {
       if (!operandType.hasStaticShape())
         return false;
     }
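Note: ShapedType is the common base class of MemRefType and TensorType in MLIR, so this single cast now covers both the buffer form and the tensor form of every operand. A minimal sketch of the check in isolation, assuming the member dyn_cast<> idiom of MLIR at this commit's vintage (where ShapedType lived in mlir/IR/StandardTypes.h):

    // Returns true when `type` either is not a shaped type at all (and is
    // therefore unconstrained here) or is a shaped type whose dimensions
    // are all static. memref<2x2xf32> and tensor<2x2xf32> both pass;
    // memref<?xf32> and tensor<?xf32> both fail.
    static bool hasStaticShapeIfShaped(mlir::Type type) {
      if (auto shapedType = type.dyn_cast<mlir::ShapedType>())
        return shapedType.hasStaticShape();
      return true;
    }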
@@ -42,10 +42,14 @@ static LogicalResult verify(DenseOp op) {
   if (!verifyStaticShape(op))
     return op.emitError("not all operands have static shape");

-  auto IShape = op.getOperand(0).getType().cast<MemRefType>().getShape();
-  auto KShape = op.getOperand(1).getType().cast<MemRefType>().getShape();
-  auto BShape = op.getOperand(2).getType().cast<MemRefType>().getShape();
-  auto OShape = op.getOperand(3).getType().cast<MemRefType>().getShape();
+  auto IShape = op.getOperand(0).getType().cast<ShapedType>().getShape();
+  auto KShape = op.getOperand(1).getType().cast<ShapedType>().getShape();
+  auto BShape = op.getOperand(2).getType().cast<ShapedType>().getShape();
+  ArrayRef<int64_t> OShape;
+  if (op.getNumResults())
+    OShape = op.getResult(0).getType().cast<ShapedType>().getShape();
+  else
+    OShape = op.getOperand(3).getType().cast<ShapedType>().getShape();

   if ((IShape.size() != 2 && IShape.size() != 4) ||
       (KShape.size() != 2 && KShape.size() != 4) || BShape.size() != 1 ||
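The same result-or-operand dispatch recurs in every verifier below: built in tensor form, the op carries its output as a result; built in buffer form, the output is the trailing memref operand. A hypothetical helper (not part of this commit) capturing the shared pattern:

    // Reads the output shape from the single result if the op has one,
    // otherwise from the output operand at `outOperandIdx`.
    static ArrayRef<int64_t> getOutputShape(Operation *op,
                                            unsigned outOperandIdx) {
      if (op->getNumResults())
        return op->getResult(0).getType().cast<ShapedType>().getShape();
      return op->getOperand(outOperandIdx)
          .getType().cast<ShapedType>().getShape();
    }

With such a helper, the DenseOp and ConvOp verifiers would call getOutputShape(op, 3), MaxPoolOp and ReluOp getOutputShape(op, 1), and MergeOp getOutputShape(op, 2).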
@@ -78,10 +82,14 @@ static LogicalResult verify(ConvOp op) {
   if (!verifyStaticShape(op))
     return op.emitError("not all operands have static shape");

-  auto IShape = op.getOperand(0).getType().cast<MemRefType>().getShape();
-  auto KShape = op.getOperand(1).getType().cast<MemRefType>().getShape();
-  auto BShape = op.getOperand(2).getType().cast<MemRefType>().getShape();
-  auto OShape = op.getOperand(3).getType().cast<MemRefType>().getShape();
+  auto IShape = op.getOperand(0).getType().cast<ShapedType>().getShape();
+  auto KShape = op.getOperand(1).getType().cast<ShapedType>().getShape();
+  auto BShape = op.getOperand(2).getType().cast<ShapedType>().getShape();
+  ArrayRef<int64_t> OShape;
+  if (op.getNumResults())
+    OShape = op.getResult(0).getType().cast<ShapedType>().getShape();
+  else
+    OShape = op.getOperand(3).getType().cast<ShapedType>().getShape();

   SmallVector<int64_t, 2> padding;
   for (auto shape : op.getAttrOfType<ArrayAttr>("padding"))
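The body of the "padding" loop is cut off by the diff context; a sketch of the presumable body, assuming the attribute is an ArrayAttr of IntegerAttrs as is conventional in MLIR:

    SmallVector<int64_t, 2> padding;
    for (auto attr : op.getAttrOfType<ArrayAttr>("padding"))
      padding.push_back(attr.cast<IntegerAttr>().getInt());

The "kernel_shape" loop in the MaxPoolOp verifier below follows the same shape.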
@@ -125,8 +133,12 @@ static LogicalResult verify(MaxPoolOp op) {
   if (!verifyStaticShape(op))
     return op.emitError("not all operands have static shape");

-  auto IShape = op.getOperand(0).getType().cast<MemRefType>().getShape();
-  auto OShape = op.getOperand(1).getType().cast<MemRefType>().getShape();
+  auto IShape = op.getOperand(0).getType().cast<ShapedType>().getShape();
+  ArrayRef<int64_t> OShape;
+  if (op.getNumResults())
+    OShape = op.getResult(0).getType().cast<ShapedType>().getShape();
+  else
+    OShape = op.getOperand(1).getType().cast<ShapedType>().getShape();

   SmallVector<int64_t, 2> kernelShape;
   for (auto shape : op.getAttrOfType<ArrayAttr>("kernel_shape"))
@@ -176,8 +188,12 @@ static LogicalResult verify(ReluOp op) {
   if (!verifyStaticShape(op))
     return op.emitError("not all operands have static shape");

-  auto IShape = op.getOperand(0).getType().cast<MemRefType>().getShape();
-  auto OShape = op.getOperand(1).getType().cast<MemRefType>().getShape();
+  auto IShape = op.getOperand(0).getType().cast<ShapedType>().getShape();
+  ArrayRef<int64_t> OShape;
+  if (op.getNumResults())
+    OShape = op.getResult(0).getType().cast<ShapedType>().getShape();
+  else
+    OShape = op.getOperand(1).getType().cast<ShapedType>().getShape();

   if (IShape != OShape)
     return op.emitError("incorrect operand shape, please refer to the op "
@@ -194,9 +210,13 @@ static LogicalResult verify(MergeOp op) {
   if (!verifyStaticShape(op))
     return op.emitError("not all operands have static shape");

-  auto I0Shape = op.getOperand(0).getType().cast<MemRefType>().getShape();
-  auto I1Shape = op.getOperand(1).getType().cast<MemRefType>().getShape();
-  auto OShape = op.getOperand(2).getType().cast<MemRefType>().getShape();
+  auto I0Shape = op.getOperand(0).getType().cast<ShapedType>().getShape();
+  auto I1Shape = op.getOperand(1).getType().cast<ShapedType>().getShape();
+  ArrayRef<int64_t> OShape;
+  if (op.getNumResults())
+    OShape = op.getResult(0).getType().cast<ShapedType>().getShape();
+  else
+    OShape = op.getOperand(2).getType().cast<ShapedType>().getShape();

   if (I0Shape != OShape || I1Shape != OShape)
     return op.emitError("incorrect operand shape, please refer to the op "
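For reference, the two op forms these verifiers now accept, written in MLIR's generic op syntax (the dialect prefix and element types here are illustrative assumptions):

    // Buffer form: the output is the trailing memref operand, no results.
    "hlskernel.relu"(%I, %O) : (memref<1x64x32x32xf32>, memref<1x64x32x32xf32>) -> ()
    // Tensor form: the output is returned as a result, so OShape is taken
    // from the result type instead of an operand.
    %O = "hlskernel.relu"(%I) : (tensor<1x64x32x32xf32>) -> tensor<1x64x32x32xf32>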

@@ -128,8 +128,8 @@ LogicalResult BenchmarkGenerator::genCNN(INIReader config) {
   SmallVector<mlir::Type, 2> inputTypes;
   inputTypes.push_back(
       getTensorType({batchSize, inputChannel, inputHeight, inputWidth}));
-  inputTypes.push_back(getTensorType({batchSize, outputChannel}));
   SmallVector<mlir::Type, 2> outputTypes;
+  outputTypes.push_back(getTensorType({batchSize, outputChannel}));

   auto func = builder.create<FuncOp>(
       loc, "tmp", builder.getFunctionType(inputTypes, outputTypes));
@@ -205,17 +205,17 @@ LogicalResult BenchmarkGenerator::genCNN(INIReader config) {
   }

   // Create the last dense layer.
-  fmaps.push_back(func.getArgument(1));
   kernels.push_back(builder.create<mlir::AllocOp>(
       loc, getMemType({outputChannel, topChannel, topHeight, topWidth})));
   biases.push_back(
       builder.create<mlir::AllocOp>(loc, getMemType({outputChannel})));

-  builder.create<DenseOp>(loc, ArrayRef<mlir::Type>(),
-                          *std::prev(fmaps.end(), 2), kernels.back(),
-                          biases.back(), fmaps.back());
+  auto denseLayer = builder.create<DenseOp>(
+      loc, getTensorType({batchSize, outputChannel}), fmaps.back(),
+      kernels.back(), biases.back(), nullptr);
+  fmaps.push_back(denseLayer.getResult(0));

-  builder.create<mlir::ReturnOp>(loc);
+  builder.create<mlir::ReturnOp>(loc, denseLayer.getResult(0));

   // Add bypass paths to the current model.
   // Ensure the specified bypass number is available. Since the last dense layer
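Read together with the signature change above, the tail of the generated function now ends in tensor form. A hedged sketch of the IR this should emit (dialect prefix, value names, and abbreviated types are illustrative assumptions):

    // %fmap is the feature map produced by the last conv/pool layer; the
    // kernel and bias stay as memref buffers created by AllocOp, as above.
    %k = alloc() : memref<...>   // {outputChannel, topChannel, topHeight, topWidth}
    %b = alloc() : memref<...>   // {outputChannel}
    %out = "hlskernel.dense"(%fmap, %k, %b)
        : (tensor<...>, memref<...>, memref<...>) -> tensor<...>
    return %out : tensor<...>

Because the DenseOp is now built with a non-empty result type and a nullptr output operand, the updated DenseOp verifier reads OShape from that result, which is exactly the getTensorType({batchSize, outputChannel}) supplied at the call site.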