From bbddd19ec723a15ea1558cce5e47cb2460fa8e24 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Tue, 8 Feb 2022 15:00:39 -0800 Subject: [PATCH] [mlir][math] Expand coverage of atan2 expansion Reuse the higher precision F32 approximation for the F16 one (by expanding and truncating). This is partly RFC as I'm not sure what the expectations are here (e.g., these are only for F32 and should not be expanded, that reusing higher-precision ones for lower precision is undesirable due to increased compute cost and only approximations per exact type is preferred, or this is appropriate [at least as fallback] but we need to see how to make it more generic across all the patterns here). Differential Revision: https://reviews.llvm.org/D118968 --- .../Transforms/PolynomialApproximation.cpp | 64 +++++++++++++++++++ .../Math/polynomial-approximation.mlir | 23 ++++--- 2 files changed, 77 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp index ed9170cdf55a..9c8e413b0e55 100644 --- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp +++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp @@ -23,11 +23,15 @@ #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/Dialect/X86Vector/X86VectorDialect.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" using namespace mlir; using namespace mlir::math; @@ -279,6 +283,65 @@ Value makePolynomialCalculation(ImplicitLocOpBuilder &builder, } } // namespace +//----------------------------------------------------------------------------// +// Helper function/pattern to insert casts for reusing F32 bit expansion. +//----------------------------------------------------------------------------// + +template +LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter) { + // Conservatively only allow where the operand and result types are exactly 1. + Type origType = op->getResultTypes().front(); + for (Type t : llvm::drop_begin(op->getResultTypes())) + if (origType != t) + return rewriter.notifyMatchFailure(op, "required all types to match"); + for (Type t : op->getOperandTypes()) + if (origType != t) + return rewriter.notifyMatchFailure(op, "required all types to match"); + + // Skip if already F32 or larger than 32 bits. + if (getElementTypeOrSelf(origType).isF32() || + getElementTypeOrSelf(origType).getIntOrFloatBitWidth() > 32) + return failure(); + + // Create F32 equivalent type. + Type newType; + if (auto shaped = origType.dyn_cast()) { + newType = shaped.clone(rewriter.getF32Type()); + } else if (origType.isa()) { + newType = rewriter.getF32Type(); + } else { + return rewriter.notifyMatchFailure(op, + "unable to find F32 equivalent type"); + } + + Location loc = op->getLoc(); + SmallVector operands; + for (auto operand : op->getOperands()) + operands.push_back(rewriter.create(loc, newType, operand)); + auto result = rewriter.create(loc, newType, operands); + rewriter.replaceOpWithNewOp(op, origType, result); + return success(); +} + +namespace { +// Pattern to cast to F32 to reuse F32 expansion as fallback for single-result +// op. +// TODO: Consider revising to avoid adding multiple casts for a subgraph that is +// all in lower precision. Currently this is only fallback support and performs +// simplistic casting. +template +struct ReuseF32Expansion : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter) const final { + static_assert( + T::template hasTrait(), + "requires same operands and result types"); + return insertCasts(op, rewriter); + } +}; +} // namespace + //----------------------------------------------------------------------------// // AtanOp approximation. //----------------------------------------------------------------------------// @@ -1209,6 +1272,7 @@ void mlir::populateMathPolynomialApproximationPatterns( patterns.add, SinAndCosApproximation, SinAndCosApproximation>( patterns.getContext()); diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir index 8655c1cf61b8..e8c09ba2ca0a 100644 --- a/mlir/test/Dialect/Math/polynomial-approximation.mlir +++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir @@ -542,7 +542,9 @@ func @atan_scalar(%arg0: f32) -> f32 { // CHECK-DAG: %[[N3:.+]] = arith.constant -0.0106783099 // CHECK-DAG: %[[N4:.+]] = arith.constant 1.00209987 // CHECK-DAG: %[[HALF_PI:.+]] = arith.constant 1.57079637 -// CHECK-DAG: %[[RATIO:.+]] = arith.divf %arg0, %arg1 +// CHECK-DAG: %[[ARG0:.+]] = arith.extf %arg0 : f16 to f32 +// CHECK-DAG: %[[ARG1:.+]] = arith.extf %arg1 : f16 to f32 +// CHECK-DAG: %[[RATIO:.+]] = arith.divf %[[ARG0]], %[[ARG1]] // CHECK-DAG: %[[ABS:.+]] = math.abs %[[RATIO]] // CHECK-DAG: %[[DIV:.+]] = arith.divf %cst, %[[ABS]] // CHECK-DAG: %[[CMP:.+]] = arith.cmpf olt, %[[ABS]], %[[DIV]] @@ -562,30 +564,31 @@ func @atan_scalar(%arg0: f32) -> f32 { // CHECK-DAG: %[[SUB_PI:.+]] = arith.subf %[[ATAN]], %[[PI]] // CHECK-DAG: %[[CMP_ATAN:.+]] = arith.cmpf ogt, %[[ATAN]], %[[ZERO]] // CHECK-DAG: %[[ATAN_ADJUST:.+]] = arith.select %[[CMP_ATAN]], %[[SUB_PI]], %[[ADD_PI]] -// CHECK-DAG: %[[X_NEG:.+]] = arith.cmpf ogt, %arg1, %[[ZERO]] +// CHECK-DAG: %[[X_NEG:.+]] = arith.cmpf ogt, %[[ARG1]], %[[ZERO]] // CHECK-DAG: %[[ATAN_EST:.+]] = arith.select %[[X_NEG]], %[[ATAN]], %[[ATAN_ADJUST]] // Handle PI / 2 edge case: -// CHECK-DAG: %[[X_ZERO:.+]] = arith.cmpf oeq, %arg1, %[[ZERO]] -// CHECK-DAG: %[[Y_POS:.+]] = arith.cmpf ogt, %arg0, %[[ZERO]] +// CHECK-DAG: %[[X_ZERO:.+]] = arith.cmpf oeq, %[[ARG1]], %[[ZERO]] +// CHECK-DAG: %[[Y_POS:.+]] = arith.cmpf ogt, %[[ARG0]], %[[ZERO]] // CHECK-DAG: %[[IS_HALF_PI:.+]] = arith.andi %[[X_ZERO]], %[[Y_POS]] // CHECK-DAG: %[[EDGE1:.+]] = arith.select %[[IS_HALF_PI]], %[[HALF_PI]], %[[ATAN_EST]] // Handle -PI / 2 edge case: // CHECK-DAG: %[[NEG_HALF_PI:.+]] = arith.constant -1.57079637 -// CHECK-DAG: %[[Y_NEG:.+]] = arith.cmpf olt, %arg0, %[[ZERO]] +// CHECK-DAG: %[[Y_NEG:.+]] = arith.cmpf olt, %[[ARG0]], %[[ZERO]] // CHECK-DAG: %[[IS_NEG_HALF_PI:.+]] = arith.andi %[[X_ZERO]], %[[Y_NEG]] // CHECK-DAG: %[[EDGE2:.+]] = arith.select %[[IS_NEG_HALF_PI]], %[[NEG_HALF_PI]], %[[EDGE1]] // Handle Nan edgecase: -// CHECK-DAG: %[[Y_ZERO:.+]] = arith.cmpf oeq, %arg0, %[[ZERO]] +// CHECK-DAG: %[[Y_ZERO:.+]] = arith.cmpf oeq, %[[ARG0]], %[[ZERO]] // CHECK-DAG: %[[X_Y_ZERO:.+]] = arith.andi %[[X_ZERO]], %[[Y_ZERO]] // CHECK-DAG: %[[NAN:.+]] = arith.constant 0x7FC00000 // CHECK-DAG: %[[EDGE3:.+]] = arith.select %[[X_Y_ZERO]], %[[NAN]], %[[EDGE2]] -// CHECK: return %[[EDGE3]] +// CHECK: %[[RET:.+]] = arith.truncf %[[EDGE3]] +// CHECK: return %[[RET]] -func @atan2_scalar(%arg0: f32, %arg1: f32) -> f32 { - %0 = math.atan2 %arg0, %arg1 : f32 - return %0 : f32 +func @atan2_scalar(%arg0: f16, %arg1: f16) -> f16 { + %0 = math.atan2 %arg0, %arg1 : f16 + return %0 : f16 }