[mlir][VectorOps] Lower vector.fma to llvm.fmuladd instead of llvm.fma
Summary: These are semantically equivalent, but fmuladd allows decaying the op into fmul+fadd if there is no fma instruction available. llvm.fma lowers to scalar calls to libm fmaf, which is a lot slower. Reviewers: nicolasvasilache, aartbik, ftynse Subscribers: mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, stephenneuendorffer, Joonsoo, grosul1, Kayjukh, jurahul, msifontes Tags: #mlir Differential Revision: https://reviews.llvm.org/D83666
This commit is contained in:
parent
ce23e54162
commit
3bffe6022c
|
@ -481,7 +481,7 @@ public:
|
||||||
/// ```
|
/// ```
|
||||||
/// is converted to:
|
/// is converted to:
|
||||||
/// ```
|
/// ```
|
||||||
/// llvm.intr.fma %va, %va, %va:
|
/// llvm.intr.fmuladd %va, %va, %va:
|
||||||
/// (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
|
/// (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
|
||||||
/// -> !llvm<"<8 x float>">
|
/// -> !llvm<"<8 x float>">
|
||||||
/// ```
|
/// ```
|
||||||
|
@ -500,8 +500,8 @@ public:
|
||||||
VectorType vType = fmaOp.getVectorType();
|
VectorType vType = fmaOp.getVectorType();
|
||||||
if (vType.getRank() != 1)
|
if (vType.getRank() != 1)
|
||||||
return failure();
|
return failure();
|
||||||
rewriter.replaceOpWithNewOp<LLVM::FMAOp>(op, adaptor.lhs(), adaptor.rhs(),
|
rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(op, adaptor.lhs(),
|
||||||
adaptor.acc());
|
adaptor.rhs(), adaptor.acc());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -236,7 +236,7 @@ func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: vector
|
||||||
// CHECK: %[[T5:.*]] = llvm.insertelement %[[T2]], %[[T3]][%[[T4]] : !llvm.i32] : !llvm<"<3 x float>">
|
// CHECK: %[[T5:.*]] = llvm.insertelement %[[T2]], %[[T3]][%[[T4]] : !llvm.i32] : !llvm<"<3 x float>">
|
||||||
// CHECK: %[[T6:.*]] = llvm.shufflevector %[[T5]], %[[T3]] [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>">
|
// CHECK: %[[T6:.*]] = llvm.shufflevector %[[T5]], %[[T3]] [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>">
|
||||||
// CHECK: %[[T7:.*]] = llvm.extractvalue %[[C]][0] : !llvm<"[2 x <3 x float>]">
|
// CHECK: %[[T7:.*]] = llvm.extractvalue %[[C]][0] : !llvm<"[2 x <3 x float>]">
|
||||||
// CHECK: %[[T8:.*]] = "llvm.intr.fma"(%[[T6]], %[[B]], %[[T7]]) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">)
|
// CHECK: %[[T8:.*]] = "llvm.intr.fmuladd"(%[[T6]], %[[B]], %[[T7]]) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">)
|
||||||
// CHECK: %[[T9:.*]] = llvm.insertvalue %[[T8]], %[[T0]][0] : !llvm<"[2 x <3 x float>]">
|
// CHECK: %[[T9:.*]] = llvm.insertvalue %[[T8]], %[[T0]][0] : !llvm<"[2 x <3 x float>]">
|
||||||
// CHECK: %[[T10:.*]] = llvm.mlir.constant(1 : i64) : !llvm.i64
|
// CHECK: %[[T10:.*]] = llvm.mlir.constant(1 : i64) : !llvm.i64
|
||||||
// CHECK: %[[T11:.*]] = llvm.extractelement %[[A]][%[[T10]] : !llvm.i64] : !llvm<"<2 x float>">
|
// CHECK: %[[T11:.*]] = llvm.extractelement %[[A]][%[[T10]] : !llvm.i64] : !llvm<"<2 x float>">
|
||||||
|
@ -245,7 +245,7 @@ func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: vector
|
||||||
// CHECK: %[[T14:.*]] = llvm.insertelement %[[T11]], %[[T12]][%[[T13]] : !llvm.i32] : !llvm<"<3 x float>">
|
// CHECK: %[[T14:.*]] = llvm.insertelement %[[T11]], %[[T12]][%[[T13]] : !llvm.i32] : !llvm<"<3 x float>">
|
||||||
// CHECK: %[[T15:.*]] = llvm.shufflevector %[[T14]], %[[T12]] [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>">
|
// CHECK: %[[T15:.*]] = llvm.shufflevector %[[T14]], %[[T12]] [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>">
|
||||||
// CHECK: %[[T16:.*]] = llvm.extractvalue %[[C]][1] : !llvm<"[2 x <3 x float>]">
|
// CHECK: %[[T16:.*]] = llvm.extractvalue %[[C]][1] : !llvm<"[2 x <3 x float>]">
|
||||||
// CHECK: %[[T17:.*]] = "llvm.intr.fma"(%[[T15]], %[[B]], %[[T16]]) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">)
|
// CHECK: %[[T17:.*]] = "llvm.intr.fmuladd"(%[[T15]], %[[B]], %[[T16]]) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">)
|
||||||
// CHECK: %[[T18:.*]] = llvm.insertvalue %[[T17]], %[[T9]][1] : !llvm<"[2 x <3 x float>]">
|
// CHECK: %[[T18:.*]] = llvm.insertvalue %[[T17]], %[[T9]][1] : !llvm<"[2 x <3 x float>]">
|
||||||
// CHECK: llvm.return %[[T18]] : !llvm<"[2 x <3 x float>]">
|
// CHECK: llvm.return %[[T18]] : !llvm<"[2 x <3 x float>]">
|
||||||
|
|
||||||
|
@ -688,20 +688,20 @@ func @extract_strides(%arg0: vector<3x3xf32>) -> vector<1x1xf32> {
|
||||||
// CHECK-SAME: %[[A:.*]]: !llvm<"<8 x float>">, %[[B:.*]]: !llvm<"[2 x <4 x float>]">)
|
// CHECK-SAME: %[[A:.*]]: !llvm<"<8 x float>">, %[[B:.*]]: !llvm<"[2 x <4 x float>]">)
|
||||||
// CHECK-SAME: -> !llvm<"{ <8 x float>, [2 x <4 x float>] }"> {
|
// CHECK-SAME: -> !llvm<"{ <8 x float>, [2 x <4 x float>] }"> {
|
||||||
func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>) -> (vector<8xf32>, vector<2x4xf32>) {
|
func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>) -> (vector<8xf32>, vector<2x4xf32>) {
|
||||||
// CHECK: "llvm.intr.fma"(%[[A]], %[[A]], %[[A]]) :
|
// CHECK: "llvm.intr.fmuladd"(%[[A]], %[[A]], %[[A]]) :
|
||||||
// CHECK-SAME: (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">) -> !llvm<"<8 x float>">
|
// CHECK-SAME: (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">) -> !llvm<"<8 x float>">
|
||||||
%0 = vector.fma %a, %a, %a : vector<8xf32>
|
%0 = vector.fma %a, %a, %a : vector<8xf32>
|
||||||
|
|
||||||
// CHECK: %[[b00:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]">
|
// CHECK: %[[b00:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]">
|
||||||
// CHECK: %[[b01:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]">
|
// CHECK: %[[b01:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]">
|
||||||
// CHECK: %[[b02:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]">
|
// CHECK: %[[b02:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]">
|
||||||
// CHECK: %[[B0:.*]] = "llvm.intr.fma"(%[[b00]], %[[b01]], %[[b02]]) :
|
// CHECK: %[[B0:.*]] = "llvm.intr.fmuladd"(%[[b00]], %[[b01]], %[[b02]]) :
|
||||||
// CHECK-SAME: (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>">
|
// CHECK-SAME: (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>">
|
||||||
// CHECK: llvm.insertvalue %[[B0]], {{.*}}[0] : !llvm<"[2 x <4 x float>]">
|
// CHECK: llvm.insertvalue %[[B0]], {{.*}}[0] : !llvm<"[2 x <4 x float>]">
|
||||||
// CHECK: %[[b10:.*]] = llvm.extractvalue %[[B]][1] : !llvm<"[2 x <4 x float>]">
|
// CHECK: %[[b10:.*]] = llvm.extractvalue %[[B]][1] : !llvm<"[2 x <4 x float>]">
|
||||||
// CHECK: %[[b11:.*]] = llvm.extractvalue %[[B]][1] : !llvm<"[2 x <4 x float>]">
|
// CHECK: %[[b11:.*]] = llvm.extractvalue %[[B]][1] : !llvm<"[2 x <4 x float>]">
|
||||||
// CHECK: %[[b12:.*]] = llvm.extractvalue %[[B]][1] : !llvm<"[2 x <4 x float>]">
|
// CHECK: %[[b12:.*]] = llvm.extractvalue %[[B]][1] : !llvm<"[2 x <4 x float>]">
|
||||||
// CHECK: %[[B1:.*]] = "llvm.intr.fma"(%[[b10]], %[[b11]], %[[b12]]) :
|
// CHECK: %[[B1:.*]] = "llvm.intr.fmuladd"(%[[b10]], %[[b11]], %[[b12]]) :
|
||||||
// CHECK-SAME: (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>">
|
// CHECK-SAME: (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>">
|
||||||
// CHECK: llvm.insertvalue %[[B1]], {{.*}}[1] : !llvm<"[2 x <4 x float>]">
|
// CHECK: llvm.insertvalue %[[B1]], {{.*}}[1] : !llvm<"[2 x <4 x float>]">
|
||||||
%1 = vector.fma %b, %b, %b : vector<2x4xf32>
|
%1 = vector.fma %b, %b, %b : vector<2x4xf32>
|
||||||
|
|
Loading…
Reference in New Issue