[HLSKernelToLinalg] add SymmOp lowering
This commit is contained in:
parent
78eecf18da
commit
c9d694c65c
|
@ -28,6 +28,7 @@ public:
|
|||
|
||||
using HLSKernelVisitorBase::visitOp;
|
||||
bool visitOp(GemmOp op);
|
||||
bool visitOp(SymmOp op);
|
||||
};
|
||||
} // namespace
|
||||
|
||||
|
@ -104,6 +105,125 @@ bool HLSKernelVisitor::visitOp(GemmOp op) {
|
|||
return true;
|
||||
}
|
||||
|
||||
/// Lower an hlskernel SymmOp to linalg.generic operations.
///
/// The operands are read positionally as (alpha, beta, A, B, C). A new buffer
/// `Cout` with the same memref type as C is allocated right before `op`, every
/// use of the op result is redirected to it, and two linalg.generic ops are
/// emitted before `op` so that
///
///   Cout[i][j] = beta * C[i][j] + sum_k alpha * A[i][k] * B[k][j]
///
/// The product iterates over the whole of A, so each A[i][k] (including the
/// diagonal entry A[i][i]) contributes exactly once. No separate pass over the
/// diagonal is emitted: adding one on top of the full product would count the
/// diagonal terms twice.
///
/// TODO: this assumes A is stored as a fully populated symmetric matrix. A
/// lowering that only reads the lower triangle of A (k < i, mirroring the
/// contribution into Cout[k][j]) is not implemented here.
///
/// The op itself is not erased by this function. Always returns true.
bool HLSKernelVisitor::visitOp(SymmOp op) {
  OpBuilder builder(op);
  auto loc = op.getLoc();

  auto alpha = op.getOperand(0);
  auto beta = op.getOperand(1);
  auto A = op.getOperand(2);
  auto B = op.getOperand(3);
  auto C = op.getOperand(4);

  auto CType = C.getType().cast<MemRefType>();
  auto elementType = CType.getElementType();

  // The result buffer replaces the op result for all existing users.
  auto Cout = builder.create<mlir::AllocOp>(loc, CType);
  op.getResult().replaceAllUsesWith(Cout);

  // Build the indexing map (d0, ..., d{numDims-1}) -> (d{rowDim}, d{colDim}).
  auto getMap = [&](unsigned numDims, unsigned rowDim, unsigned colDim) {
    SmallVector<AffineExpr, 2> exprs;
    exprs.push_back(builder.getAffineDimExpr(rowDim));
    exprs.push_back(builder.getAffineDimExpr(colDim));
    return AffineMap::get(numDims, 0, exprs, op.getContext());
  };

  // Step 1: Cout[i][j] = beta * C[i][j].
  auto betaCMap = getMap(2, 0, 1);

  auto betaCOp = builder.create<linalg::GenericOp>(
      loc, ArrayRef<Value>(C), ArrayRef<Value>(Cout),
      ArrayRef<AffineMap>({betaCMap, betaCMap}),
      ArrayRef<StringRef>({"parallel", "parallel"}));

  // Block arguments: (C element, Cout element).
  auto betaCBlock = builder.createBlock(
      &betaCOp.getRegion(), {}, ArrayRef<Type>({elementType, elementType}));

  builder.setInsertionPointToStart(betaCOp.getBody());
  auto betaCResult = builder.create<mlir::MulFOp>(loc, elementType, beta,
                                                  betaCBlock->getArgument(0));
  builder.create<linalg::YieldOp>(loc, ArrayRef<Value>(betaCResult));

  // Step 2: Cout[i][j] += alpha * A[i][k] * B[k][j], reducing over k (d2).
  // Move back in front of `op` so this generic op follows the previous one
  // instead of being nested inside its region.
  builder.setInsertionPoint(op);

  auto alphaABMapA = getMap(3, 0, 2);
  auto alphaABMapB = getMap(3, 2, 1);
  auto alphaABMapC = getMap(3, 0, 1);

  auto alphaABOp = builder.create<linalg::GenericOp>(
      loc, ArrayRef<Value>({A, B}), ArrayRef<Value>(Cout),
      ArrayRef<AffineMap>({alphaABMapA, alphaABMapB, alphaABMapC}),
      ArrayRef<StringRef>({"parallel", "parallel", "reduction"}));

  // Block arguments: (A element, B element, current Cout element).
  auto alphaABBlock = builder.createBlock(
      &alphaABOp.getRegion(), {},
      ArrayRef<Type>({elementType, elementType, elementType}));

  builder.setInsertionPointToStart(alphaABOp.getBody());
  auto alphaAResult = builder.create<mlir::MulFOp>(
      loc, elementType, alpha, alphaABBlock->getArgument(0));
  auto alphaABResult = builder.create<mlir::MulFOp>(
      loc, elementType, alphaAResult, alphaABBlock->getArgument(1));
  auto result = builder.create<mlir::AddFOp>(loc, elementType, alphaABResult,
                                             alphaABBlock->getArgument(2));
  builder.create<linalg::YieldOp>(loc, ArrayRef<Value>(result));

  return true;
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// HLSkernel to Linalg Lowering Pass
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
// RUN: scalehls-opt -hlskernel-to-linalg %s | FileCheck %s
|
||||
|
||||
// CHECK: module {
|
||||
// Smoke test for the "hlskernel.symm" lowering. The op is given two f32
// scalars (%alpha, %beta) and three memrefs (%A: 16x16, %B: 16x32,
// %C: 16x32) and produces a 16x32 memref that the function returns.
// Only the presence of the printed module is verified; the generated
// operations themselves are not matched.
func @test_symm(%A: memref<16x16xf32>, %B: memref<16x32xf32>, %C: memref<16x32xf32>) -> (memref<16x32xf32>){
|
||||
%alpha = constant 1.1 : f32
|
||||
%beta = constant 4.2 : f32
|
||||
%Cout = "hlskernel.symm" (%alpha, %beta, %A, %B, %C) {} : (f32, f32, memref<16x16xf32>, memref<16x32xf32>, memref<16x32xf32>) -> (memref<16x32xf32>)
|
||||
return %Cout : memref<16x32xf32>
|
||||
}
|
Loading…
Reference in New Issue