[mlir][VectorOps] Lower vector.fma to llvm.fmuladd instead of llvm.fma

Summary:
These are semantically equivalent, but fmuladd allows decaying the op
into fmul+fadd if there is no fma instruction available. llvm.fma lowers
to scalar calls to libm fmaf, which is a lot slower.

Reviewers: nicolasvasilache, aartbik, ftynse

Subscribers: mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, stephenneuendorffer, Joonsoo, grosul1, Kayjukh, jurahul, msifontes

Tags: #mlir

Differential Revision: https://reviews.llvm.org/D83666
This commit is contained in:
Benjamin Kramer 2020-07-13 12:23:53 +02:00
parent ce23e54162
commit 3bffe6022c
2 changed files with 8 additions and 8 deletions

View File

@ -481,7 +481,7 @@ public:
/// ```
/// is converted to:
/// ```
/// llvm.intr.fma %va, %va, %va:
/// llvm.intr.fmuladd %va, %va, %va:
/// (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
/// -> !llvm<"<8 x float>">
/// ```
@ -500,8 +500,8 @@ public:
VectorType vType = fmaOp.getVectorType();
if (vType.getRank() != 1)
return failure();
rewriter.replaceOpWithNewOp<LLVM::FMAOp>(op, adaptor.lhs(), adaptor.rhs(),
adaptor.acc());
rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(op, adaptor.lhs(),
adaptor.rhs(), adaptor.acc());
return success();
}
};

View File

@ -236,7 +236,7 @@ func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: vector
// CHECK: %[[T5:.*]] = llvm.insertelement %[[T2]], %[[T3]][%[[T4]] : !llvm.i32] : !llvm<"<3 x float>">
// CHECK: %[[T6:.*]] = llvm.shufflevector %[[T5]], %[[T3]] [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>">
// CHECK: %[[T7:.*]] = llvm.extractvalue %[[C]][0] : !llvm<"[2 x <3 x float>]">
// CHECK: %[[T8:.*]] = "llvm.intr.fma"(%[[T6]], %[[B]], %[[T7]]) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">)
// CHECK: %[[T8:.*]] = "llvm.intr.fmuladd"(%[[T6]], %[[B]], %[[T7]]) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">)
// CHECK: %[[T9:.*]] = llvm.insertvalue %[[T8]], %[[T0]][0] : !llvm<"[2 x <3 x float>]">
// CHECK: %[[T10:.*]] = llvm.mlir.constant(1 : i64) : !llvm.i64
// CHECK: %[[T11:.*]] = llvm.extractelement %[[A]][%[[T10]] : !llvm.i64] : !llvm<"<2 x float>">
@ -245,7 +245,7 @@ func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: vector
// CHECK: %[[T14:.*]] = llvm.insertelement %[[T11]], %[[T12]][%[[T13]] : !llvm.i32] : !llvm<"<3 x float>">
// CHECK: %[[T15:.*]] = llvm.shufflevector %[[T14]], %[[T12]] [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>">
// CHECK: %[[T16:.*]] = llvm.extractvalue %[[C]][1] : !llvm<"[2 x <3 x float>]">
// CHECK: %[[T17:.*]] = "llvm.intr.fma"(%[[T15]], %[[B]], %[[T16]]) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">)
// CHECK: %[[T17:.*]] = "llvm.intr.fmuladd"(%[[T15]], %[[B]], %[[T16]]) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">)
// CHECK: %[[T18:.*]] = llvm.insertvalue %[[T17]], %[[T9]][1] : !llvm<"[2 x <3 x float>]">
// CHECK: llvm.return %[[T18]] : !llvm<"[2 x <3 x float>]">
@ -688,20 +688,20 @@ func @extract_strides(%arg0: vector<3x3xf32>) -> vector<1x1xf32> {
// CHECK-SAME: %[[A:.*]]: !llvm<"<8 x float>">, %[[B:.*]]: !llvm<"[2 x <4 x float>]">)
// CHECK-SAME: -> !llvm<"{ <8 x float>, [2 x <4 x float>] }"> {
func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>) -> (vector<8xf32>, vector<2x4xf32>) {
// CHECK: "llvm.intr.fma"(%[[A]], %[[A]], %[[A]]) :
// CHECK: "llvm.intr.fmuladd"(%[[A]], %[[A]], %[[A]]) :
// CHECK-SAME: (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">) -> !llvm<"<8 x float>">
%0 = vector.fma %a, %a, %a : vector<8xf32>
// CHECK: %[[b00:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]">
// CHECK: %[[b01:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]">
// CHECK: %[[b02:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]">
// CHECK: %[[B0:.*]] = "llvm.intr.fma"(%[[b00]], %[[b01]], %[[b02]]) :
// CHECK: %[[B0:.*]] = "llvm.intr.fmuladd"(%[[b00]], %[[b01]], %[[b02]]) :
// CHECK-SAME: (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>">
// CHECK: llvm.insertvalue %[[B0]], {{.*}}[0] : !llvm<"[2 x <4 x float>]">
// CHECK: %[[b10:.*]] = llvm.extractvalue %[[B]][1] : !llvm<"[2 x <4 x float>]">
// CHECK: %[[b11:.*]] = llvm.extractvalue %[[B]][1] : !llvm<"[2 x <4 x float>]">
// CHECK: %[[b12:.*]] = llvm.extractvalue %[[B]][1] : !llvm<"[2 x <4 x float>]">
// CHECK: %[[B1:.*]] = "llvm.intr.fma"(%[[b10]], %[[b11]], %[[b12]]) :
// CHECK: %[[B1:.*]] = "llvm.intr.fmuladd"(%[[b10]], %[[b11]], %[[b12]]) :
// CHECK-SAME: (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>">
// CHECK: llvm.insertvalue %[[B1]], {{.*}}[1] : !llvm<"[2 x <4 x float>]">
%1 = vector.fma %b, %b, %b : vector<2x4xf32>