[HLSKernel] add maxpool, relu, and gemm operations

Hanchen Ye 2020-11-04 22:40:29 -06:00
parent e74dde111a
commit 6adc65da7a
6 changed files with 94 additions and 5 deletions


@@ -12,7 +12,8 @@ def ConvOp : HLSKernelOp<"conv", [HLSKernelOpInterface]> {
    row, and col), 4-dims W (input channel, output channel, kernel row, and
    kernel col), 1-dim B (output channel), and 4-dims Y (batch, channel, row,
    and col) are supported. Meanwhile, in the current lowering, padding is not
    allowed. Verifiers will ensure the legalness of the operation}];
    allowed. Verifiers will ensure the legalness of the operation.
  }];

  let arguments = (ins
    AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$X,
@@ -26,4 +27,59 @@ def ConvOp : HLSKernelOp<"conv", [HLSKernelOpInterface]> {
  );
}
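For orientation, here is an editorial sketch of how the conv operation described above might be written in generic IR form. It is not part of this commit: the empty attribute dictionary and all shapes are illustrative, and the operand layout simply follows the description (X: batch, channel, row, col; W: input channel, output channel, kernel row, kernel col; B: output channel). With no padding and an assumed unit stride, a 32x32 input and 3x3 kernel yield a 30x30 output:

func @example_conv(%x: memref<1x3x32x32xf32>, %w: memref<3x8x3x3xf32>, %b: memref<8xf32>) -> (memref<1x8x30x30xf32>) {
  // Shapes follow the ConvOp description; 8 is an illustrative output channel count.
  %y = "hlskernel.conv" (%x, %w, %b) {} : (memref<1x3x32x32xf32>, memref<3x8x3x3xf32>, memref<8xf32>) -> (memref<1x8x30x30xf32>)
  return %y : memref<1x8x30x30xf32>
}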
def MaxPoolOp : HLSKernelOp<"maxpool", [HLSKernelOpInterface]> {
  let summary = "max pooling operation";
  let description = [{
    Max pooling operation. For now, only static shaped 4-dims X (batch, channel,
    row, and col) and 4-dims Y (batch, channel, row, and col) are supported.
    Verifiers will ensure the legalness of the operation.
  }];

  let arguments = (ins
    AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$X,
    I64ArrayAttr:$kernel_shape,
    OptionalAttr<I64ArrayAttr>:$pads,
    OptionalAttr<I64ArrayAttr>:$strides
  );

  let results = (outs
    AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef, NoneType]>:$Y
  );
}
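As a usage sketch (editorial, not part of this commit), the optional pads and strides attributes might be spelled out explicitly as below; the attribute layouts and the default behavior are assumptions, since the commit does not document them. With a 2x2 kernel, an assumed stride of 2, and zero padding, each spatial dimension shrinks as floor((28 + 0 - 2) / 2) + 1 = 14, which matches the 10x16x14x14 result type used by the maxpool test added later in this commit:

func @example_maxpool(%x: memref<10x16x28x28xf32>) -> (memref<10x16x14x14xf32>) {
  // kernel_shape is required; pads and strides are optional and shown here only for illustration.
  %y = "hlskernel.maxpool" (%x) {kernel_shape = [2, 2], pads = [0, 0, 0, 0], strides = [2, 2]} : (memref<10x16x28x28xf32>) -> (memref<10x16x14x14xf32>)
  return %y : memref<10x16x14x14xf32>
}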
def ReluOp : HLSKernelOp<"relu", [HLSKernelOpInterface]> {
  let summary = "relu operation";
  let description = [{
    ReLU operation. For now, only static shaped 4-dims X (batch, channel, row,
    and col) and 4-dims Y (batch, channel, row, and col) are supported.
    Verifiers will ensure the legalness of the operation.
  }];

  let arguments = (ins
    AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$X
  );

  let results = (outs
    AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef, NoneType]>:$Y
  );
}
def GemmOp : HLSKernelOp<"gemm", [HLSKernelOpInterface]> {
  let summary = "gemm operation";
  let description = [{
    GEMM operation. For now, only static shaped 2-dims X (batch and channel),
    2-dims W (input and output channel), 1-dim B (output channel), and 2-dims Y
    (batch and channel) are supported. Verifiers will ensure the legalness of
    the operation.
  }];

  let arguments = (ins
    AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$X,
    AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$W,
    AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef, NoneType]>:$B
  );

  let results = (outs
    AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef, NoneType]>:$Y
  );
}
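One more editorial sketch: the gemm test added later in this commit pairs a 10x128 X with a 16x128 W to produce a 10x16 Y, which suggests W is stored as (output channel x input channel), i.e. something like Y = X * transpose(W) + B; the commit does not spell out the exact formula, so treat that reading as an inference. Since the operands may also be tensors rather than memrefs, an equivalent tensor-typed form would look like:

func @example_gemm(%x: tensor<10x128xf32>, %w: tensor<16x128xf32>, %b: tensor<16xf32>) -> (tensor<10x16xf32>) {
  // Same shapes as the memref-based test, expressed on tensor types.
  %y = "hlskernel.gemm" (%x, %w, %b) {} : (tensor<10x128xf32>, tensor<16x128xf32>, tensor<16xf32>) -> (tensor<10x16xf32>)
  return %y : tensor<10x16xf32>
}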
#endif // SCALEHLS_DIALECT_HLSKERNEL_OPS_TD


@@ -21,7 +21,7 @@ public:
    return TypeSwitch<Operation *, ResultType>(op)
        .template Case<
            // HLSKernel operations.
            ConvOp>([&](auto opNode) -> ResultType {
            ConvOp, MaxPoolOp, ReluOp, GemmOp>([&](auto opNode) -> ResultType {
          return thisCast->visitOp(opNode, args...);
        })
        .Default([&](auto opNode) -> ResultType {

@@ -48,6 +48,9 @@ public:
  // HLSKernel operations.
  HANDLE(ConvOp);
  HANDLE(MaxPoolOp);
  HANDLE(ReluOp);
  HANDLE(GemmOp);
#undef HANDLE
};


@@ -27,6 +27,9 @@ public:
  using HLSKernelVisitorBase::visitOp;

  bool visitOp(ConvOp op);
  bool visitOp(MaxPoolOp op);
  bool visitOp(ReluOp op);
  bool visitOp(GemmOp op);
};

} // namespace

@@ -112,6 +115,12 @@ bool HLSKernelVisitor::visitOp(ConvOp op) {
  return true;
}

bool HLSKernelVisitor::visitOp(MaxPoolOp op) { return true; }

bool HLSKernelVisitor::visitOp(ReluOp op) { return true; }

bool HLSKernelVisitor::visitOp(GemmOp op) { return true; }

//===----------------------------------------------------------------------===//
// HLSkernel to Affine Lowering Pass
//===----------------------------------------------------------------------===//

@@ -130,9 +139,9 @@ void HLSKernelToAffinePass::runOnOperation() {
  for (auto &op : getOperation()) {
    if (auto func = dyn_cast<FuncOp>(op)) {
      func.walk([&](HLSKernelOpInterface kernelOp) {
        if (visitor.dispatchVisitor(kernelOp))
          kernelOp.erase();
        else
        if (visitor.dispatchVisitor(kernelOp)) {
          // kernelOp.erase();
        } else
          kernelOp.emitError("can't be correctly lowered.");
      });
    } else if (!isa<ModuleTerminatorOp>(op))

@@ -0,0 +1,7 @@
// RUN: scalehls-opt -hlskernel-to-affine %s | FileCheck %s

// CHECK: module {
func @test_gemm(%x: memref<10x128xf32>, %w: memref<16x128xf32>, %b: memref<16xf32>) -> (memref<10x16xf32>){
  %y = "hlskernel.gemm" (%x, %w, %b) {} : (memref<10x128xf32>, memref<16x128xf32>, memref<16xf32>) -> (memref<10x16xf32>)
  return %y : memref<10x16xf32>
}


@@ -0,0 +1,7 @@
// RUN: scalehls-opt -hlskernel-to-affine %s | FileCheck %s

// CHECK: module {
func @test_maxpool(%x: memref<10x16x28x28xf32>) -> (memref<10x16x14x14xf32>){
  %y = "hlskernel.maxpool" (%x) {kernel_shape=[2, 2]} : (memref<10x16x28x28xf32>) -> (memref<10x16x14x14xf32>)
  return %y : memref<10x16x14x14xf32>
}


@@ -0,0 +1,7 @@
// RUN: scalehls-opt -hlskernel-to-affine %s | FileCheck %s

// CHECK: module {
func @test_relu(%x: memref<10x16x28x28xf32>) -> (memref<10x16x28x28xf32>){
  %y = "hlskernel.relu" (%x) {} : (memref<10x16x28x28xf32>) -> (memref<10x16x28x28xf32>)
  return %y : memref<10x16x28x28xf32>
}