[mlir] Add LogOp lowering from Complex dialect to Standard/Math dialect.

Differential Revision: https://reviews.llvm.org/D105342
This commit is contained in:
Adrian Kuegel 2021-07-02 13:20:13 +02:00
parent 21a1bcbd4d
commit 380fa71fb0
2 changed files with 44 additions and 2 deletions

View File

@ -315,6 +315,28 @@ struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
}
};
struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
using OpConversionPattern<complex::LogOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(complex::LogOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
complex::LogOp::Adaptor transformed(operands);
auto type = transformed.complex().getType().cast<ComplexType>();
auto elementType = type.getElementType().cast<FloatType>();
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
Value abs = b.create<complex::AbsOp>(elementType, transformed.complex());
Value resultReal = b.create<math::LogOp>(elementType, abs);
Value real = b.create<complex::ReOp>(elementType, transformed.complex());
Value imag = b.create<complex::ImOp>(elementType, transformed.complex());
Value resultImag = b.create<math::Atan2Op>(elementType, imag, real);
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
resultImag);
return success();
}
};
struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
using OpConversionPattern<complex::NegOp>::OpConversionPattern;
@ -374,6 +396,7 @@ void mlir::populateComplexToStandardConversionPatterns(
ComparisonOpConversion<complex::NotEqualOp, CmpFPredicate::UNE>,
DivOpConversion,
ExpOpConversion,
LogOpConversion,
NegOpConversion,
SignOpConversion>(patterns.getContext());
// clang-format on
@ -396,8 +419,8 @@ void ConvertComplexToStandardPass::runOnFunction() {
target.addLegalDialect<StandardOpsDialect, math::MathDialect,
complex::ComplexDialect>();
target.addIllegalOp<complex::AbsOp, complex::DivOp, complex::EqualOp,
complex::ExpOp, complex::NotEqualOp, complex::NegOp,
complex::SignOp>();
complex::ExpOp, complex::LogOp, complex::NotEqualOp,
complex::NegOp, complex::SignOp>();
if (failed(applyPartialConversion(function, target, std::move(patterns))))
signalPassFailure();
}

View File

@ -154,6 +154,25 @@ func @complex_exp(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
// CHECK-LABEL: func @complex_log
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func @complex_log(%arg: complex<f32>) -> complex<f32> {
%log = complex.log %arg: complex<f32>
return %log : complex<f32>
}
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
// CHECK: %[[SQR_REAL:.*]] = mulf %[[REAL]], %[[REAL]] : f32
// CHECK: %[[SQR_IMAG:.*]] = mulf %[[IMAG]], %[[IMAG]] : f32
// CHECK: %[[SQ_NORM:.*]] = addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32
// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
// CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] : f32
// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
// CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG2]], %[[REAL2]] : f32
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
// CHECK-LABEL: func @complex_neg
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func @complex_neg(%arg: complex<f32>) -> complex<f32> {