[HLSKernel] update BLASOps def; [HLSKernelToAffine] start of BLASOps lowering

This commit is contained in:
Hanchen Ye 2020-11-30 21:02:11 -06:00
parent 10050bc6a3
commit 29aff5e6fb
2 changed files with 34 additions and 24 deletions

View File

@ -11,7 +11,7 @@ def GemmOp : HLSKernelOp<"gemm", [HLSKernelOpInterface]> {
TRANSA (false / true): op(A) = A / op(A) = A^T,
TRANSB (false / true): op(B) = B / op(B) = B^T,
Cout = alpha * op(A) * op(B) + beta * C,
C = alpha * op(A) * op(B) + beta * C,
A: M x K,
B: K x N,
@ -27,20 +27,17 @@ def GemmOp : HLSKernelOp<"gemm", [HLSKernelOpInterface]> {
AnyTypeOf<[AnyTensor, AnyMemRef]>:$B,
AnyTypeOf<[AnyTensor, AnyMemRef]>:$C
);
let results = (outs
AnyTypeOf<[AnyTensor, AnyMemRef]>:$Cout
);
}
def SymmOp : HLSKernelOp<"symm", [HLSKernelOpInterface]> {
let summary = "symm operation";
let description = [{
SIDE (false):
Cout = alpha * A * B + beta * C,
C = alpha * A * B + beta * C,
A: M x M, symmetric,
SIDE (true):
Cout = alpha * B * A + beta * C,
C = alpha * B * A + beta * C,
A: N x N, symmetric,
UPLO (false / true): A is upper / lower triangular,
@ -58,20 +55,17 @@ def SymmOp : HLSKernelOp<"symm", [HLSKernelOpInterface]> {
AnyTypeOf<[AnyTensor, AnyMemRef]>:$B,
AnyTypeOf<[AnyTensor, AnyMemRef]>:$C
);
let results = (outs
AnyTypeOf<[AnyTensor, AnyMemRef]>:$Cout
);
}
def SyrkOp : HLSKernelOp<"syrk", [HLSKernelOpInterface]> {
let summary = "syrk operation";
let description = [{
TRANS (false):
Cout = alpha * A * A^T + beta * C,
C = alpha * A * A^T + beta * C,
A: N x K,
TRANS (true):
Cout = alpha * A^T * A + beta * C,
C = alpha * A^T * A + beta * C,
A: K x N,
UPLO (false / true): C is upper / lower triangular,
@ -87,21 +81,18 @@ def SyrkOp : HLSKernelOp<"syrk", [HLSKernelOpInterface]> {
AnyTypeOf<[AnyTensor, AnyMemRef]>:$A,
AnyTypeOf<[AnyTensor, AnyMemRef]>:$C
);
let results = (outs
AnyTypeOf<[AnyTensor, AnyMemRef]>:$Cout
);
}
def Syr2kOp : HLSKernelOp<"syr2k", [HLSKernelOpInterface]> {
let summary = "syr2k operation";
let description = [{
TRANS (false):
Cout = alpha * A * B^T + alpha * B * A^T + beta * C,
C = alpha * A * B^T + alpha * B * A^T + beta * C,
A: N x K,
B: N x K,
TRANS (true):
Cout = alpha * A^T * B + alpha * B^T * A + beta * C,
C = alpha * A^T * B + alpha * B^T * A + beta * C,
A: K x N,
B: K x N,
@ -119,20 +110,17 @@ def Syr2kOp : HLSKernelOp<"syr2k", [HLSKernelOpInterface]> {
AnyTypeOf<[AnyTensor, AnyMemRef]>:$B,
AnyTypeOf<[AnyTensor, AnyMemRef]>:$C
);
let results = (outs
AnyTypeOf<[AnyTensor, AnyMemRef]>:$Cout
);
}
def TrmmOp : HLSKernelOp<"trmm", [HLSKernelOpInterface]> {
let summary = "trmm operation";
let description = [{
SIDE (false):
Bout = alpha * op(A) * B,
B = alpha * op(A) * B,
A: M x M, triangular,
SIDE (true):
Bout = alpha * B * op(A),
B = alpha * B * op(A),
A: N x N, triangular,
UPLO (false / true): A is upper / lower triangular,
@ -152,9 +140,6 @@ def TrmmOp : HLSKernelOp<"trmm", [HLSKernelOpInterface]> {
AnyTypeOf<[AnyTensor, AnyMemRef]>:$A,
AnyTypeOf<[AnyTensor, AnyMemRef]>:$B
);
let results = (outs
AnyTypeOf<[AnyTensor, AnyMemRef]>:$Bout
);
}
#endif // SCALEHLS_DIALECT_HLSKERNEL_BLASOPS_TD

View File

@ -34,6 +34,12 @@ public:
bool visitOp(ReluOp op);
bool visitOp(MergeOp op);
bool visitOp(GemmOp op);
bool visitOp(SymmOp op);
bool visitOp(SyrkOp op);
bool visitOp(Syr2kOp op);
bool visitOp(TrmmOp op);
private:
OpBuilder &builder;
Location loc;
@ -73,6 +79,10 @@ private:
};
} // namespace
//===----------------------------------------------------------------------===//
// CNNOps Handler
//===----------------------------------------------------------------------===//
bool HLSKernelVisitor::visitOp(DenseOp op) {
auto I = op.getOperand(0);
auto K = op.getOperand(1);
@ -304,6 +314,21 @@ bool HLSKernelVisitor::visitOp(MergeOp op) {
return true;
}
//===----------------------------------------------------------------------===//
// BLASOps Handler
//===----------------------------------------------------------------------===//
// Only default attributes configuration are supported.
bool HLSKernelVisitor::visitOp(GemmOp op) { return true; }
bool HLSKernelVisitor::visitOp(SymmOp op) { return true; }
bool HLSKernelVisitor::visitOp(SyrkOp op) { return true; }
bool HLSKernelVisitor::visitOp(Syr2kOp op) { return true; }
bool HLSKernelVisitor::visitOp(TrmmOp op) { return true; }
//===----------------------------------------------------------------------===//
// HLSkernel to Affine Lowering Pass
//===----------------------------------------------------------------------===//