[HLSKernelToLinalg] add SymmOp lowering

This commit is contained in:
Hanchen Ye 2020-11-23 15:42:36 -06:00
parent 78eecf18da
commit c9d694c65c
2 changed files with 129 additions and 0 deletions

View File

@ -28,6 +28,7 @@ public:
using HLSKernelVisitorBase::visitOp;
bool visitOp(GemmOp op);
bool visitOp(SymmOp op);
};
} // namespace
@ -104,6 +105,125 @@ bool HLSKernelVisitor::visitOp(GemmOp op) {
return true;
}
/// Lower an hlskernel SymmOp to linalg.generic ops operating on memrefs.
///
/// Operands are read positionally as (alpha, beta, A, B, C). A new buffer Cout
/// with the same memref type as C is allocated, every use of the op's result
/// is redirected to it, and two generic ops are emitted in front of `op`:
///   1. Cout[i][j]  = beta * C[i][j]
///   2. Cout[i][j] += alpha * A[i][k] * B[k][j], reduced over every k.
///
/// Step 2 already covers the k == i terms, so no separate diagonal pass is
/// emitted; emitting one in addition would accumulate alpha * A[i][i] * B[i][j]
/// into Cout twice.
///
/// NOTE(review): step 2 reads every element of A, so the result is only the
/// symmetric product when both triangles of A are populated. If the op is
/// meant to reference just one triangle of A, the lowering needs a triangular
/// (k < i) iteration space instead - TODO confirm the intended storage
/// convention.
///
/// `op` itself is not erased here. Always returns true.
bool HLSKernelVisitor::visitOp(SymmOp op) {
  OpBuilder builder(op);

  auto alpha = op.getOperand(0);
  auto beta = op.getOperand(1);
  auto A = op.getOperand(2);
  auto B = op.getOperand(3);
  auto C = op.getOperand(4);

  auto CType = C.getType().cast<MemRefType>();
  auto ElementType = CType.getElementType();

  // The lowered ops write into a freshly allocated buffer rather than into C,
  // so users of the original result must be pointed at that buffer.
  auto Cout = builder.create<mlir::AllocOp>(op.getLoc(), CType);
  op.getResult().replaceAllUsesWith(Cout);

  // Calculate beta * C and store to Cout. Both C and Cout are accessed through
  // the same (d0, d1) map.
  SmallVector<AffineExpr, 2> idxExprs;
  idxExprs.push_back(builder.getAffineDimExpr(0));
  idxExprs.push_back(builder.getAffineDimExpr(1));
  auto betaCMap = AffineMap::get(2, 0, idxExprs, op.getContext());

  auto betaCOp = builder.create<linalg::GenericOp>(
      op.getLoc(), ArrayRef<Value>(C), ArrayRef<Value>(Cout),
      ArrayRef<AffineMap>({betaCMap, betaCMap}),
      ArrayRef<StringRef>({"parallel", "parallel"}));

  // Block arguments are the current elements of C and Cout, in that order;
  // only the C element is consumed.
  auto betaCBlock = builder.createBlock(
      &betaCOp.getRegion(), {}, ArrayRef<Type>({ElementType, ElementType}));
  builder.setInsertionPointToStart(betaCOp.getBody());

  auto betaCResult = builder.create<mlir::MulFOp>(
      op.getLoc(), ElementType, beta, betaCBlock->getArgument(0));
  builder.create<linalg::YieldOp>(op.getLoc(), ArrayRef<Value>(betaCResult));

  // Calculate alpha * A[i][k] * B[k][j] and accumulate to Cout[i][j] for every
  // k (d2 is the reduction dimension). Creating a block moved the insertion
  // point into the previous region, so reset it to just before `op`.
  builder.setInsertionPoint(op);

  idxExprs.clear();
  idxExprs.push_back(builder.getAffineDimExpr(0));
  idxExprs.push_back(builder.getAffineDimExpr(2));
  auto alphaABMapA = AffineMap::get(3, 0, idxExprs, op.getContext());

  idxExprs.clear();
  idxExprs.push_back(builder.getAffineDimExpr(2));
  idxExprs.push_back(builder.getAffineDimExpr(1));
  auto alphaABMapB = AffineMap::get(3, 0, idxExprs, op.getContext());

  idxExprs.clear();
  idxExprs.push_back(builder.getAffineDimExpr(0));
  idxExprs.push_back(builder.getAffineDimExpr(1));
  auto alphaABMapC = AffineMap::get(3, 0, idxExprs, op.getContext());

  auto alphaABOp = builder.create<linalg::GenericOp>(
      op.getLoc(), ArrayRef<Value>({A, B}), ArrayRef<Value>(Cout),
      ArrayRef<AffineMap>({alphaABMapA, alphaABMapB, alphaABMapC}),
      ArrayRef<StringRef>({"parallel", "parallel", "reduction"}));

  // Block arguments are the current elements of A, B and Cout, in that order.
  auto alphaABBlock = builder.createBlock(
      &alphaABOp.getRegion(), {},
      ArrayRef<Type>({ElementType, ElementType, ElementType}));
  builder.setInsertionPointToStart(alphaABOp.getBody());

  auto alphaAResult = builder.create<mlir::MulFOp>(
      op.getLoc(), ElementType, alpha, alphaABBlock->getArgument(0));
  auto alphaABResult = builder.create<mlir::MulFOp>(
      op.getLoc(), ElementType, alphaAResult, alphaABBlock->getArgument(1));
  auto result = builder.create<mlir::AddFOp>(
      op.getLoc(), ElementType, alphaABResult, alphaABBlock->getArgument(2));
  builder.create<linalg::YieldOp>(op.getLoc(), ArrayRef<Value>(result));

  return true;
}
//===----------------------------------------------------------------------===//
// HLSKernel to Linalg Lowering Pass
//===----------------------------------------------------------------------===//

View File

@ -0,0 +1,9 @@
// RUN: scalehls-opt -hlskernel-to-linalg %s | FileCheck %s
// CHECK: module {
// Smoke test for the hlskernel.symm lowering: builds a single symm op with a
// 16x16 A and 16x32 B / C (all f32) and returns its result buffer.
func @test_symm(%A: memref<16x16xf32>, %B: memref<16x32xf32>, %C: memref<16x32xf32>) -> (memref<16x32xf32>){
// Scalar coefficients passed to the op as operands 0 and 1.
%alpha = constant 1.1 : f32
%beta = constant 4.2 : f32
// Operand order is (alpha, beta, A, B, C); the result has the same type as C.
%Cout = "hlskernel.symm" (%alpha, %beta, %A, %B, %C) {} : (f32, f32, memref<16x16xf32>, memref<16x32xf32>, memref<16x32xf32>) -> (memref<16x32xf32>)
return %Cout : memref<16x32xf32>
}