[HLSKernel] update BLASOps def; [HLSKernelToAffine] start of BLASOps lowering
This commit is contained in:
parent
10050bc6a3
commit
29aff5e6fb
|
@ -11,7 +11,7 @@ def GemmOp : HLSKernelOp<"gemm", [HLSKernelOpInterface]> {
|
|||
TRANSA (false / true): op(A) = A / op(A) = A^T,
|
||||
TRANSB (false / true): op(B) = B / op(B) = B^T,
|
||||
|
||||
Cout = alpha * op(A) * op(B) + beta * C,
|
||||
C = alpha * op(A) * op(B) + beta * C,
|
||||
|
||||
A: M x K,
|
||||
B: K x N,
|
||||
|
@ -27,20 +27,17 @@ def GemmOp : HLSKernelOp<"gemm", [HLSKernelOpInterface]> {
|
|||
AnyTypeOf<[AnyTensor, AnyMemRef]>:$B,
|
||||
AnyTypeOf<[AnyTensor, AnyMemRef]>:$C
|
||||
);
|
||||
let results = (outs
|
||||
AnyTypeOf<[AnyTensor, AnyMemRef]>:$Cout
|
||||
);
|
||||
}
|
||||
|
||||
def SymmOp : HLSKernelOp<"symm", [HLSKernelOpInterface]> {
|
||||
let summary = "symm operation";
|
||||
let description = [{
|
||||
SIDE (false):
|
||||
Cout = alpha * A * B + beta * C,
|
||||
C = alpha * A * B + beta * C,
|
||||
A: M x M, symmetric,
|
||||
|
||||
SIDE (true):
|
||||
Cout = alpha * B * A + beta * C,
|
||||
C = alpha * B * A + beta * C,
|
||||
A: N x N, symmetric,
|
||||
|
||||
UPLO (false / true): A is upper / lower triangular,
|
||||
|
@ -58,20 +55,17 @@ def SymmOp : HLSKernelOp<"symm", [HLSKernelOpInterface]> {
|
|||
AnyTypeOf<[AnyTensor, AnyMemRef]>:$B,
|
||||
AnyTypeOf<[AnyTensor, AnyMemRef]>:$C
|
||||
);
|
||||
let results = (outs
|
||||
AnyTypeOf<[AnyTensor, AnyMemRef]>:$Cout
|
||||
);
|
||||
}
|
||||
|
||||
def SyrkOp : HLSKernelOp<"syrk", [HLSKernelOpInterface]> {
|
||||
let summary = "syrk operation";
|
||||
let description = [{
|
||||
TRANS (false):
|
||||
Cout = alpha * A * A^T + beta * C,
|
||||
C = alpha * A * A^T + beta * C,
|
||||
A: N x K,
|
||||
|
||||
TRANS (true):
|
||||
Cout = alpha * A^T * A + beta * C,
|
||||
C = alpha * A^T * A + beta * C,
|
||||
A: K x N,
|
||||
|
||||
UPLO (false / true): C is upper / lower triangular,
|
||||
|
@ -87,21 +81,18 @@ def SyrkOp : HLSKernelOp<"syrk", [HLSKernelOpInterface]> {
|
|||
AnyTypeOf<[AnyTensor, AnyMemRef]>:$A,
|
||||
AnyTypeOf<[AnyTensor, AnyMemRef]>:$C
|
||||
);
|
||||
let results = (outs
|
||||
AnyTypeOf<[AnyTensor, AnyMemRef]>:$Cout
|
||||
);
|
||||
}
|
||||
|
||||
def Syr2kOp : HLSKernelOp<"syr2k", [HLSKernelOpInterface]> {
|
||||
let summary = "syr2k operation";
|
||||
let description = [{
|
||||
TRANS (false):
|
||||
Cout = alpha * A * B^T + alpha * B * A^T + beta * C,
|
||||
C = alpha * A * B^T + alpha * B * A^T + beta * C,
|
||||
A: N x K,
|
||||
B: N x K,
|
||||
|
||||
TRANS (true):
|
||||
Cout = alpha * A^T * B + alpha * B^T * A + beta * C,
|
||||
C = alpha * A^T * B + alpha * B^T * A + beta * C,
|
||||
A: K x N,
|
||||
B: K x N,
|
||||
|
||||
|
@ -119,20 +110,17 @@ def Syr2kOp : HLSKernelOp<"syr2k", [HLSKernelOpInterface]> {
|
|||
AnyTypeOf<[AnyTensor, AnyMemRef]>:$B,
|
||||
AnyTypeOf<[AnyTensor, AnyMemRef]>:$C
|
||||
);
|
||||
let results = (outs
|
||||
AnyTypeOf<[AnyTensor, AnyMemRef]>:$Cout
|
||||
);
|
||||
}
|
||||
|
||||
def TrmmOp : HLSKernelOp<"trmm", [HLSKernelOpInterface]> {
|
||||
let summary = "trmm operation";
|
||||
let description = [{
|
||||
SIDE (false):
|
||||
Bout = alpha * op(A) * B,
|
||||
B = alpha * op(A) * B,
|
||||
A: M x M, triangular,
|
||||
|
||||
SIDE (true):
|
||||
Bout = alpha * B * op(A),
|
||||
B = alpha * B * op(A),
|
||||
A: N x N, triangular,
|
||||
|
||||
UPLO (false / true): A is upper / lower triangular,
|
||||
|
@ -152,9 +140,6 @@ def TrmmOp : HLSKernelOp<"trmm", [HLSKernelOpInterface]> {
|
|||
AnyTypeOf<[AnyTensor, AnyMemRef]>:$A,
|
||||
AnyTypeOf<[AnyTensor, AnyMemRef]>:$B
|
||||
);
|
||||
let results = (outs
|
||||
AnyTypeOf<[AnyTensor, AnyMemRef]>:$Bout
|
||||
);
|
||||
}
|
||||
|
||||
#endif // SCALEHLS_DIALECT_HLSKERNEL_BLASOPS_TD
|
||||
|
|
|
@ -34,6 +34,12 @@ public:
|
|||
bool visitOp(ReluOp op);
|
||||
bool visitOp(MergeOp op);
|
||||
|
||||
bool visitOp(GemmOp op);
|
||||
bool visitOp(SymmOp op);
|
||||
bool visitOp(SyrkOp op);
|
||||
bool visitOp(Syr2kOp op);
|
||||
bool visitOp(TrmmOp op);
|
||||
|
||||
private:
|
||||
OpBuilder &builder;
|
||||
Location loc;
|
||||
|
@ -73,6 +79,10 @@ private:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CNNOps Handler
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool HLSKernelVisitor::visitOp(DenseOp op) {
|
||||
auto I = op.getOperand(0);
|
||||
auto K = op.getOperand(1);
|
||||
|
@ -304,6 +314,21 @@ bool HLSKernelVisitor::visitOp(MergeOp op) {
|
|||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BLASOps Handler
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Only default attributes configuration are supported.
|
||||
bool HLSKernelVisitor::visitOp(GemmOp op) { return true; }
|
||||
|
||||
bool HLSKernelVisitor::visitOp(SymmOp op) { return true; }
|
||||
|
||||
bool HLSKernelVisitor::visitOp(SyrkOp op) { return true; }
|
||||
|
||||
bool HLSKernelVisitor::visitOp(Syr2kOp op) { return true; }
|
||||
|
||||
bool HLSKernelVisitor::visitOp(TrmmOp op) { return true; }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// HLSkernel to Affine Lowering Pass
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
Loading…
Reference in New Issue