[HLSKernel] add maxpool, relu, and gemm operations
This commit is contained in:
parent
e74dde111a
commit
6adc65da7a
|
@ -12,7 +12,8 @@ def ConvOp : HLSKernelOp<"conv", [HLSKernelOpInterface]> {
|
|||
row, and col), 4-dims W (input channel, output channel, kernel row, and
|
||||
kernel col), 1-dim B (output channel), and 4-dims Y (batch, channel, row,
|
||||
and col) are supported. Meanwhile, in the current lowering, padding is not
|
||||
allowed. Verifiers will ensure the legalness of the operation}];
|
||||
allowed. Verifiers will ensure the legalness of the operation.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$X,
|
||||
|
@ -26,4 +27,59 @@ def ConvOp : HLSKernelOp<"conv", [HLSKernelOpInterface]> {
|
|||
);
|
||||
}
|
||||
|
||||
// 2-D max pooling kernel op. Attributes mirror the ONNX-style MaxPool
// contract (kernel_shape required; pads/strides optional).
def MaxPoolOp : HLSKernelOp<"maxpool", [HLSKernelOpInterface]> {
  let summary = "max pooling operation";
  let description = [{
    Max pooling operation. For now, only static shaped 4-dims X (batch, channel,
    row, and col) and 4-dims Y (batch, channel, row, and col) are supported.
    Verifiers will ensure the legalness of the operation.
  }];

  let arguments = (ins
    // Input feature map; accepted as a float tensor (f16/f32/f64) or any memref.
    AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$X,
    // Pooling window shape, e.g. [2, 2].
    I64ArrayAttr:$kernel_shape,
    // Optional paddings; per the ConvOp description above, padding is not
    // allowed in the current lowering.
    OptionalAttr<I64ArrayAttr>:$pads,
    // Optional strides; default semantics when absent are decided by the
    // lowering/verifier (not shown here).
    OptionalAttr<I64ArrayAttr>:$strides
  );
  let results = (outs
    // Pooled output; NoneType permitted (e.g. before result materialization).
    AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef, NoneType]>:$Y
  );
}
|
||||
|
||||
// Elementwise ReLU kernel op: single input X, single result Y of matching
// static 4-D shape (see description).
def ReluOp : HLSKernelOp<"relu", [HLSKernelOpInterface]> {
  let summary = "relu operation";
  let description = [{
    ReLU operation. For now, only static shaped 4-dims X (batch, channel, row,
    and col) and 4-dims Y (batch, channel, row, and col) are supported.
    Verifiers will ensure the legalness of the operation.
  }];

  let arguments = (ins
    // Input feature map; accepted as a float tensor (f16/f32/f64) or any memref.
    AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$X
  );
  let results = (outs
    // Activated output; NoneType permitted (e.g. before result materialization).
    AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef, NoneType]>:$Y
  );
}
|
||||
|
||||
// GEMM (fully-connected) kernel op: Y from 2-D X, 2-D W, and optional 1-D
// bias B (B may be NoneType).
def GemmOp : HLSKernelOp<"gemm", [HLSKernelOpInterface]> {
  let summary = "gemm operation";
  let description = [{
    GEMM operation. For now, only static shaped 2-dims X (batch and channel),
    2-dims W (input and output channel), 1-dim B (output channel), and 2-dims Y
    (batch and channel) are supported. Verifiers will ensure the legalness of
    the operation.
  }];

  let arguments = (ins
    // Input activations, shape (batch, input channel).
    AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$X,
    // Weights, shape (input channel, output channel) per the description.
    AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$W,
    // Optional bias, shape (output channel); NoneType when absent.
    AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef, NoneType]>:$B
  );
  let results = (outs
    // Output, shape (batch, output channel); NoneType permitted.
    AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef, NoneType]>:$Y
  );
}
|
||||
|
||||
|
||||
#endif // SCALEHLS_DIALECT_HLSKERNEL_OPS_TD
|
||||
|
|
|
@ -21,7 +21,7 @@ public:
|
|||
return TypeSwitch<Operation *, ResultType>(op)
|
||||
.template Case<
|
||||
// HLSKernel operations.
|
||||
ConvOp>([&](auto opNode) -> ResultType {
|
||||
ConvOp, MaxPoolOp, ReluOp, GemmOp>([&](auto opNode) -> ResultType {
|
||||
return thisCast->visitOp(opNode, args...);
|
||||
})
|
||||
.Default([&](auto opNode) -> ResultType {
|
||||
|
@ -48,6 +48,9 @@ public:
|
|||
|
||||
// HLSKernel operations.
|
||||
HANDLE(ConvOp);
|
||||
HANDLE(MaxPoolOp);
|
||||
HANDLE(ReluOp);
|
||||
HANDLE(GemmOp);
|
||||
|
||||
#undef HANDLE
|
||||
};
|
||||
|
|
|
@ -27,6 +27,9 @@ public:
|
|||
|
||||
using HLSKernelVisitorBase::visitOp;
|
||||
bool visitOp(ConvOp op);
|
||||
bool visitOp(MaxPoolOp op);
|
||||
bool visitOp(ReluOp op);
|
||||
bool visitOp(GemmOp op);
|
||||
};
|
||||
} // namespace
|
||||
|
||||
|
@ -112,6 +115,12 @@ bool HLSKernelVisitor::visitOp(ConvOp op) {
|
|||
return true;
|
||||
}
|
||||
|
||||
// TODO: lower hlskernel.maxpool to affine. Currently a no-op stub that
// reports success, so the pass accepts the op without rewriting it (and the
// pass leaves the op in place — erase() is commented out there).
bool HLSKernelVisitor::visitOp(MaxPoolOp op) { return true; }
|
||||
|
||||
// TODO: lower hlskernel.relu to affine. No-op stub; returning true marks the
// op as handled so the pass does not emit a lowering error.
bool HLSKernelVisitor::visitOp(ReluOp op) { return true; }
|
||||
|
||||
// TODO: lower hlskernel.gemm to affine. No-op stub; returning true marks the
// op as handled so the pass does not emit a lowering error.
bool HLSKernelVisitor::visitOp(GemmOp op) { return true; }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// HLSkernel to Affine Lowering Pass
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -130,9 +139,9 @@ void HLSKernelToAffinePass::runOnOperation() {
|
|||
for (auto &op : getOperation()) {
|
||||
if (auto func = dyn_cast<FuncOp>(op)) {
|
||||
func.walk([&](HLSKernelOpInterface kernelOp) {
|
||||
if (visitor.dispatchVisitor(kernelOp))
|
||||
kernelOp.erase();
|
||||
else
|
||||
if (visitor.dispatchVisitor(kernelOp)) {
|
||||
// kernelOp.erase();
|
||||
} else
|
||||
kernelOp.emitError("can't be correctly lowered.");
|
||||
});
|
||||
} else if (!isa<ModuleTerminatorOp>(op))
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
// RUN: scalehls-opt -hlskernel-to-affine %s | FileCheck %s

// Smoke test for the hlskernel.gemm lowering. Shapes imply Y(10x16) is
// computed from X(10x128) and W(16x128) plus bias B(16) — presumably
// Y = X * W^T + B; confirm against the GemmOp lowering once implemented.
// Only the module wrapper is checked: the lowering is still a stub.
// CHECK: module {
func @test_gemm(%x: memref<10x128xf32>, %w: memref<16x128xf32>, %b: memref<16xf32>) -> (memref<10x16xf32>){
  %y = "hlskernel.gemm" (%x, %w, %b) {} : (memref<10x128xf32>, memref<16x128xf32>, memref<16xf32>) -> (memref<10x16xf32>)
  return %y : memref<10x16xf32>
}
|
|
@ -0,0 +1,7 @@
|
|||
// RUN: scalehls-opt -hlskernel-to-affine %s | FileCheck %s

// Smoke test for the hlskernel.maxpool lowering: 2x2 window over a
// 10x16x28x28 input producing 10x16x14x14 — the halved spatial dims suggest
// an implicit stride of 2 when the strides attribute is omitted; confirm
// the default in the op verifier/lowering.
// Only the module wrapper is checked: the lowering is still a stub.
// CHECK: module {
func @test_maxpool(%x: memref<10x16x28x28xf32>) -> (memref<10x16x14x14xf32>){
  %y = "hlskernel.maxpool" (%x) {kernel_shape=[2, 2]} : (memref<10x16x28x28xf32>) -> (memref<10x16x14x14xf32>)
  return %y : memref<10x16x14x14xf32>
}
|
|
@ -0,0 +1,7 @@
|
|||
// RUN: scalehls-opt -hlskernel-to-affine %s | FileCheck %s

// Smoke test for the hlskernel.relu lowering: elementwise op, so the output
// shape matches the input (10x16x28x28).
// Only the module wrapper is checked: the lowering is still a stub.
// CHECK: module {
func @test_relu(%x: memref<10x16x28x28xf32>) -> (memref<10x16x28x28xf32>){
  %y = "hlskernel.relu" (%x) {} : (memref<10x16x28x28xf32>) -> (memref<10x16x28x28xf32>)
  return %y : memref<10x16x28x28xf32>
}
|
Loading…
Reference in New Issue