diff --git a/include/Dialect/HLSKernel/BLASOps.td b/include/Dialect/HLSKernel/BLASOps.td index a9fc775..57a759b 100644 --- a/include/Dialect/HLSKernel/BLASOps.td +++ b/include/Dialect/HLSKernel/BLASOps.td @@ -40,7 +40,7 @@ def SymmOp : HLSKernelOp<"symm", [HLSKernelOpInterface]> { C = alpha * B * A + beta * C, A: N x N, symmetric, - UPLO (false / true): A is upper / lower triangular, + UPLO (false / true): A is lower / upper triangular, B: M x N, C: M x N @@ -68,7 +68,7 @@ def SyrkOp : HLSKernelOp<"syrk", [HLSKernelOpInterface]> { C = alpha * A^T * A + beta * C, A: K x N, - UPLO (false / true): C is upper / lower triangular, + UPLO (false / true): C is lower / upper triangular, C: N x N, symmetric }]; @@ -96,7 +96,7 @@ def Syr2kOp : HLSKernelOp<"syr2k", [HLSKernelOpInterface]> { A: K x N, B: K x N, - UPLO (false / true): C is upper / lower triangular, + UPLO (false / true): C is lower / upper triangular, C: N x N, symmetric }]; @@ -123,7 +123,7 @@ def TrmmOp : HLSKernelOp<"trmm", [HLSKernelOpInterface]> { B = alpha * B * op(A), A: N x N, triangular, - UPLO (false / true): A is upper / lower triangular, + UPLO (false / true): A is lower / upper triangular, TRANSA (false / true): op(A) = A / op(A) = A^T, DIAG (false / true): A is non-unit / unit triangular, diff --git a/lib/Conversion/HLSKernelToAffine/HLSKernelToAffine.cpp b/lib/Conversion/HLSKernelToAffine/HLSKernelToAffine.cpp index a9079bc..af87ad7 100644 --- a/lib/Conversion/HLSKernelToAffine/HLSKernelToAffine.cpp +++ b/lib/Conversion/HLSKernelToAffine/HLSKernelToAffine.cpp @@ -44,13 +44,42 @@ private: OpBuilder &builder; Location loc; - // Helpers for creating loops, loads, stores and binary operations. - Value createLoop(unsigned upper, unsigned step = 1, unsigned lower = 0) { + // Helpers for creating loops. + // Constant upper and lower bound. 
+  Value createLoop(int64_t upper, int64_t lower = 0, int64_t step = 1) {
     auto loop = builder.create(loc, lower, upper, step);
     builder.setInsertionPointToStart(&loop.getLoopBody().front());
     return loop.getInductionVar();
   }
 
+  // General case: each bound is given as an affine map together with the
+  // operands that map is applied to. The builder is left positioned at the
+  // start of the new loop body and the loop's induction variable is returned.
+  Value createLoop(std::initializer_list upper, AffineMap upperMap,
+                   std::initializer_list lower, AffineMap lowerMap,
+                   int64_t step = 1) {
+    auto loop = builder.create(loc, lower, lowerMap, upper,
+                               upperMap, step);
+    builder.setInsertionPointToStart(&loop.getLoopBody().front());
+    return loop.getInductionVar();
+  }
+
+  // Both bounds are values; each is passed through a one-dimensional map built
+  // from getDim(0). `step` is forwarded so that a stride requested by the
+  // caller is not silently replaced by the general overload's default of 1.
+  Value createLoop(Value upper, Value lower, int64_t step = 1) {
+    auto indexMap = AffineMap::get(1, 0, getDim(0), builder.getContext());
+    return createLoop({upper}, indexMap, {lower}, indexMap, step);
+  }
+
+  // Constant upper bound, value lower bound. The constant is wrapped in a
+  // zero-dimensional map, so the upper bound takes no operands.
+  Value createLoop(int64_t upper, Value lower, int64_t step = 1) {
+    auto lowerMap = AffineMap::get(1, 0, getDim(0), builder.getContext());
+    auto upperMap = AffineMap::get(0, 0, getConst(upper), builder.getContext());
+    return createLoop({}, upperMap, {lower}, lowerMap, step);
+  }
+
+  // Value upper bound, constant lower bound. The constant is wrapped in a
+  // zero-dimensional map, so the lower bound takes no operands.
+  Value createLoop(Value upper, int64_t lower, int64_t step = 1) {
+    auto lowerMap = AffineMap::get(0, 0, getConst(lower), builder.getContext());
+    auto upperMap = AffineMap::get(1, 0, getDim(0), builder.getContext());
+    return createLoop({upper}, upperMap, {}, lowerMap, step);
+  }
+
+  // Helpers for creating loads, stores and binary operations.
   Value createLoad(Value array, std::initializer_list index) {
     return builder.create(loc, array, ArrayRef(index));
@@ -319,9 +348,100 @@ bool HLSKernelVisitor::visitOp(MergeOp op) {
 //===----------------------------------------------------------------------===//
 
 // Only default attributes configuration are supported.
-bool HLSKernelVisitor::visitOp(GemmOp op) { return true; }
+// Emits an affine loop nest for hlskernel.gemm immediately before the op.
+// Operands 0-4 are read as alpha, beta, A, B and C. Loops run over C's two
+// dimensions (m, n) and then over A's second dimension (k). C[m][n] is first
+// rewritten from beta and its old value, then updated once per k from
+// A[m][k], B[k][n] and alpha. Always returns true.
+// NOTE(review): the op itself is not erased here, and the builder is left
+// positioned inside the innermost loop on return - confirm the caller handles
+// both.
+bool HLSKernelVisitor::visitOp(GemmOp op) {
+  auto alpha = op.getOperand(0);
+  auto beta = op.getOperand(1);
 
-bool HLSKernelVisitor::visitOp(SymmOp op) { return true; }
+  auto A = op.getOperand(2);
+  auto B = op.getOperand(3);
+  auto C = op.getOperand(4);
+
+  auto AShape = A.getType().cast().getShape();
+  auto CShape = C.getType().cast().getShape();
+
+  // Set insertion point of builder.
+  builder.setInsertionPoint(op);
+
+  // Create M dimension loop.
+  auto m = createLoop(CShape[0]);
+
+  // Create N dimension loop.
+  auto n = createLoop(CShape[1]);
+
+  // Update C with beta * C.
+  auto initC = createLoad(C, {m, n});
+  auto betaC = createBinaryOp(beta, initC);
+  createStore(betaC, C, {m, n});
+
+  // Create K dimension loop.
+  // It is emitted after the store above, inside the N loop body.
+  auto k = createLoop(AShape[1]);
+
+  // Accumulate C with alpha * A * B.
+  auto valA = createLoad(A, {m, k});
+  auto valB = createLoad(B, {k, n});
+  auto valC = createLoad(C, {m, n});
+
+  auto alphaA = createBinaryOp(alpha, valA);
+  auto alphaAB = createBinaryOp(alphaA, valB);
+  auto accumC = createBinaryOp(alphaAB, valC);
+  createStore(accumC, C, {m, n});
+
+  return true;
+}
+
+// Emits an affine loop nest for hlskernel.symm immediately before the op.
+// Operands 0-4 are read as alpha, beta, A, B and C. For every (m, n) of C the
+// k range is split into two sibling loops so that A is only ever indexed with
+// row >= column: k in [0, m) reads A[m][k], and k in [m, CShape[0]) reads
+// A[k][m]. Always returns true.
+// NOTE(review): the op itself is not erased here, and the builder is left
+// positioned inside the second K loop on return - confirm the caller handles
+// both.
+bool HLSKernelVisitor::visitOp(SymmOp op) {
+  auto alpha = op.getOperand(0);
+  auto beta = op.getOperand(1);
+
+  auto A = op.getOperand(2);
+  auto B = op.getOperand(3);
+  auto C = op.getOperand(4);
+
+  auto CShape = C.getType().cast().getShape();
+
+  // Set insertion point of builder.
+  builder.setInsertionPoint(op);
+
+  // Create M dimension loop.
+  auto m = createLoop(CShape[0]);
+
+  // Create N dimension loop.
+  auto n = createLoop(CShape[1]);
+
+  // Update C with beta * C.
+  auto initC = createLoad(C, {m, n});
+  auto betaC = createBinaryOp(beta, initC);
+  createStore(betaC, C, {m, n});
+
+  // Create K dimension loop for lower triangle.
+  // Arguments are (upper = m, lower = 0), so the bound depends on the M loop.
+  auto lk = createLoop(m, 0);
+
+  // Accumulate C with alpha * A * B.
+  auto valA = createLoad(A, {m, lk});
+  auto valB = createLoad(B, {lk, n});
+  auto valC = createLoad(C, {m, n});
+
+  auto alphaA = createBinaryOp(alpha, valA);
+  auto alphaAB = createBinaryOp(alphaA, valB);
+  auto accumC = createBinaryOp(alphaAB, valC);
+  createStore(accumC, C, {m, n});
+
+  // Create K dimension loop for upper triangle.
+  // Step back out to the end of the N loop body first, so this loop is emitted
+  // after the previous K loop rather than nested inside it.
+  builder.setInsertionPoint(n.getParentBlock()->getTerminator());
+  // Arguments are (upper = CShape[0], lower = m).
+  auto hk = createLoop(CShape[0], m);
+
+  // Accumulate C with alpha * A * B.
+  // A is indexed as {hk, m}, i.e. with its indices swapped relative to the
+  // first K loop.
+  valA = createLoad(A, {hk, m});
+  valB = createLoad(B, {hk, n});
+  valC = createLoad(C, {m, n});
+
+  alphaA = createBinaryOp(alpha, valA);
+  alphaAB = createBinaryOp(alphaA, valB);
+  accumC = createBinaryOp(alphaAB, valC);
+  createStore(accumC, C, {m, n});
+
+  return true;
+}
 
 bool HLSKernelVisitor::visitOp(SyrkOp op) { return true; }
 
diff --git a/test/Conversion/HLSKernelToAffine/test_gemm.mlir b/test/Conversion/HLSKernelToAffine/test_gemm.mlir
new file mode 100644
index 0000000..24b3eff
--- /dev/null
+++ b/test/Conversion/HLSKernelToAffine/test_gemm.mlir
@@ -0,0 +1,9 @@
+// RUN: scalehls-opt -hlskernel-to-affine %s | FileCheck %s
+
+// CHECK: module {
+func @test_gemm(%A: memref<32x16xf32>, %B: memref<16x8xf32>, %C: memref<32x8xf32>) -> () {
+  %alpha = constant 11.0 : f32
+  %beta = constant 42.0 : f32
+  "hlskernel.gemm" (%alpha, %beta, %A, %B, %C) {} : (f32, f32, memref<32x16xf32>, memref<16x8xf32>, memref<32x8xf32>) -> ()
+  return
+}
diff --git a/test/Conversion/HLSKernelToAffine/test_symm.mlir b/test/Conversion/HLSKernelToAffine/test_symm.mlir
new file mode 100644
index 0000000..af92b53
--- /dev/null
+++ b/test/Conversion/HLSKernelToAffine/test_symm.mlir
@@ -0,0 +1,9 @@
+// RUN: scalehls-opt -hlskernel-to-affine %s | FileCheck %s
+
+// CHECK: module {
+func @test_symm(%A: memref<32x32xf32>, %B: memref<32x8xf32>, %C: memref<32x8xf32>) -> () {
+  %alpha = constant 11.0 : f32
+  %beta = constant 42.0 : f32
+  "hlskernel.symm" (%alpha, %beta, %A, %B, %C) {} : (f32, f32, memref<32x32xf32>, memref<32x8xf32>, memref<32x8xf32>) -> ()
+  return
+}