[mlir][math] Expand coverage of atan2 expansion
Reuse the higher-precision F32 approximation for the F16 one (by expanding to F32 and truncating back). This is partly an RFC, as I'm not sure what the expectations are here: e.g., whether these approximations are only for F32 and should not be expanded; whether reusing higher-precision approximations for lower precision is undesirable due to increased compute cost, so that one approximation per exact type is preferred; or whether this is appropriate (at least as a fallback) and we need to see how to make it more generic across all the patterns here. Differential Revision: https://reviews.llvm.org/D118968
This commit is contained in:
parent
0dcb370d43
commit
bbddd19ec7
|
@ -23,11 +23,15 @@
|
|||
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
|
||||
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::math;
|
||||
|
@ -279,6 +283,65 @@ Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
|
|||
}
|
||||
} // namespace
|
||||
|
||||
//----------------------------------------------------------------------------//
|
||||
// Helper function/pattern to insert casts for reusing F32 bit expansion.
|
||||
//----------------------------------------------------------------------------//
|
||||
|
||||
template <typename T>
|
||||
LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter) {
|
||||
// Conservatively only allow where the operand and result types are exactly 1.
|
||||
Type origType = op->getResultTypes().front();
|
||||
for (Type t : llvm::drop_begin(op->getResultTypes()))
|
||||
if (origType != t)
|
||||
return rewriter.notifyMatchFailure(op, "required all types to match");
|
||||
for (Type t : op->getOperandTypes())
|
||||
if (origType != t)
|
||||
return rewriter.notifyMatchFailure(op, "required all types to match");
|
||||
|
||||
// Skip if already F32 or larger than 32 bits.
|
||||
if (getElementTypeOrSelf(origType).isF32() ||
|
||||
getElementTypeOrSelf(origType).getIntOrFloatBitWidth() > 32)
|
||||
return failure();
|
||||
|
||||
// Create F32 equivalent type.
|
||||
Type newType;
|
||||
if (auto shaped = origType.dyn_cast<ShapedType>()) {
|
||||
newType = shaped.clone(rewriter.getF32Type());
|
||||
} else if (origType.isa<FloatType>()) {
|
||||
newType = rewriter.getF32Type();
|
||||
} else {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"unable to find F32 equivalent type");
|
||||
}
|
||||
|
||||
Location loc = op->getLoc();
|
||||
SmallVector<Value> operands;
|
||||
for (auto operand : op->getOperands())
|
||||
operands.push_back(rewriter.create<arith::ExtFOp>(loc, newType, operand));
|
||||
auto result = rewriter.create<math::Atan2Op>(loc, newType, operands);
|
||||
rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, origType, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Pattern to cast to F32 to reuse F32 expansion as fallback for single-result
|
||||
// op.
|
||||
// TODO: Consider revising to avoid adding multiple casts for a subgraph that is
|
||||
// all in lower precision. Currently this is only fallback support and performs
|
||||
// simplistic casting.
|
||||
template <typename T>
|
||||
struct ReuseF32Expansion : public OpRewritePattern<T> {
|
||||
public:
|
||||
using OpRewritePattern<T>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter) const final {
|
||||
static_assert(
|
||||
T::template hasTrait<mlir::OpTrait::SameOperandsAndResultType>(),
|
||||
"requires same operands and result types");
|
||||
return insertCasts<T>(op, rewriter);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
//----------------------------------------------------------------------------//
|
||||
// AtanOp approximation.
|
||||
//----------------------------------------------------------------------------//
|
||||
|
@ -1209,6 +1272,7 @@ void mlir::populateMathPolynomialApproximationPatterns(
|
|||
patterns.add<AtanApproximation, Atan2Approximation, TanhApproximation,
|
||||
LogApproximation, Log2Approximation, Log1pApproximation,
|
||||
ErfPolynomialApproximation, ExpApproximation, ExpM1Approximation,
|
||||
ReuseF32Expansion<math::Atan2Op>,
|
||||
SinAndCosApproximation<true, math::SinOp>,
|
||||
SinAndCosApproximation<false, math::CosOp>>(
|
||||
patterns.getContext());
|
||||
|
|
|
@ -542,7 +542,9 @@ func @atan_scalar(%arg0: f32) -> f32 {
|
|||
// CHECK-DAG: %[[N3:.+]] = arith.constant -0.0106783099
|
||||
// CHECK-DAG: %[[N4:.+]] = arith.constant 1.00209987
|
||||
// CHECK-DAG: %[[HALF_PI:.+]] = arith.constant 1.57079637
|
||||
// CHECK-DAG: %[[RATIO:.+]] = arith.divf %arg0, %arg1
|
||||
// CHECK-DAG: %[[ARG0:.+]] = arith.extf %arg0 : f16 to f32
|
||||
// CHECK-DAG: %[[ARG1:.+]] = arith.extf %arg1 : f16 to f32
|
||||
// CHECK-DAG: %[[RATIO:.+]] = arith.divf %[[ARG0]], %[[ARG1]]
|
||||
// CHECK-DAG: %[[ABS:.+]] = math.abs %[[RATIO]]
|
||||
// CHECK-DAG: %[[DIV:.+]] = arith.divf %cst, %[[ABS]]
|
||||
// CHECK-DAG: %[[CMP:.+]] = arith.cmpf olt, %[[ABS]], %[[DIV]]
|
||||
|
@ -562,30 +564,31 @@ func @atan_scalar(%arg0: f32) -> f32 {
|
|||
// CHECK-DAG: %[[SUB_PI:.+]] = arith.subf %[[ATAN]], %[[PI]]
|
||||
// CHECK-DAG: %[[CMP_ATAN:.+]] = arith.cmpf ogt, %[[ATAN]], %[[ZERO]]
|
||||
// CHECK-DAG: %[[ATAN_ADJUST:.+]] = arith.select %[[CMP_ATAN]], %[[SUB_PI]], %[[ADD_PI]]
|
||||
// CHECK-DAG: %[[X_NEG:.+]] = arith.cmpf ogt, %arg1, %[[ZERO]]
|
||||
// CHECK-DAG: %[[X_NEG:.+]] = arith.cmpf ogt, %[[ARG1]], %[[ZERO]]
|
||||
// CHECK-DAG: %[[ATAN_EST:.+]] = arith.select %[[X_NEG]], %[[ATAN]], %[[ATAN_ADJUST]]
|
||||
|
||||
// Handle PI / 2 edge case:
|
||||
// CHECK-DAG: %[[X_ZERO:.+]] = arith.cmpf oeq, %arg1, %[[ZERO]]
|
||||
// CHECK-DAG: %[[Y_POS:.+]] = arith.cmpf ogt, %arg0, %[[ZERO]]
|
||||
// CHECK-DAG: %[[X_ZERO:.+]] = arith.cmpf oeq, %[[ARG1]], %[[ZERO]]
|
||||
// CHECK-DAG: %[[Y_POS:.+]] = arith.cmpf ogt, %[[ARG0]], %[[ZERO]]
|
||||
// CHECK-DAG: %[[IS_HALF_PI:.+]] = arith.andi %[[X_ZERO]], %[[Y_POS]]
|
||||
// CHECK-DAG: %[[EDGE1:.+]] = arith.select %[[IS_HALF_PI]], %[[HALF_PI]], %[[ATAN_EST]]
|
||||
|
||||
// Handle -PI / 2 edge case:
|
||||
// CHECK-DAG: %[[NEG_HALF_PI:.+]] = arith.constant -1.57079637
|
||||
// CHECK-DAG: %[[Y_NEG:.+]] = arith.cmpf olt, %arg0, %[[ZERO]]
|
||||
// CHECK-DAG: %[[Y_NEG:.+]] = arith.cmpf olt, %[[ARG0]], %[[ZERO]]
|
||||
// CHECK-DAG: %[[IS_NEG_HALF_PI:.+]] = arith.andi %[[X_ZERO]], %[[Y_NEG]]
|
||||
// CHECK-DAG: %[[EDGE2:.+]] = arith.select %[[IS_NEG_HALF_PI]], %[[NEG_HALF_PI]], %[[EDGE1]]
|
||||
|
||||
// Handle Nan edgecase:
|
||||
// CHECK-DAG: %[[Y_ZERO:.+]] = arith.cmpf oeq, %arg0, %[[ZERO]]
|
||||
// CHECK-DAG: %[[Y_ZERO:.+]] = arith.cmpf oeq, %[[ARG0]], %[[ZERO]]
|
||||
// CHECK-DAG: %[[X_Y_ZERO:.+]] = arith.andi %[[X_ZERO]], %[[Y_ZERO]]
|
||||
// CHECK-DAG: %[[NAN:.+]] = arith.constant 0x7FC00000
|
||||
// CHECK-DAG: %[[EDGE3:.+]] = arith.select %[[X_Y_ZERO]], %[[NAN]], %[[EDGE2]]
|
||||
// CHECK: return %[[EDGE3]]
|
||||
// CHECK: %[[RET:.+]] = arith.truncf %[[EDGE3]]
|
||||
// CHECK: return %[[RET]]
|
||||
|
||||
func @atan2_scalar(%arg0: f32, %arg1: f32) -> f32 {
|
||||
%0 = math.atan2 %arg0, %arg1 : f32
|
||||
return %0 : f32
|
||||
func @atan2_scalar(%arg0: f16, %arg1: f16) -> f16 {
|
||||
%0 = math.atan2 %arg0, %arg1 : f16
|
||||
return %0 : f16
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue