[mlir] Add LogOp lowering from Complex dialect to Standard/Math dialect.
Differential Revision: https://reviews.llvm.org/D105342
This commit is contained in:
parent
21a1bcbd4d
commit
380fa71fb0
|
@ -315,6 +315,28 @@ struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
|
|||
}
|
||||
};
|
||||
|
||||
struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
|
||||
using OpConversionPattern<complex::LogOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(complex::LogOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
complex::LogOp::Adaptor transformed(operands);
|
||||
auto type = transformed.complex().getType().cast<ComplexType>();
|
||||
auto elementType = type.getElementType().cast<FloatType>();
|
||||
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
|
||||
|
||||
Value abs = b.create<complex::AbsOp>(elementType, transformed.complex());
|
||||
Value resultReal = b.create<math::LogOp>(elementType, abs);
|
||||
Value real = b.create<complex::ReOp>(elementType, transformed.complex());
|
||||
Value imag = b.create<complex::ImOp>(elementType, transformed.complex());
|
||||
Value resultImag = b.create<math::Atan2Op>(elementType, imag, real);
|
||||
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
|
||||
resultImag);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
|
||||
using OpConversionPattern<complex::NegOp>::OpConversionPattern;
|
||||
|
||||
|
@ -374,6 +396,7 @@ void mlir::populateComplexToStandardConversionPatterns(
|
|||
ComparisonOpConversion<complex::NotEqualOp, CmpFPredicate::UNE>,
|
||||
DivOpConversion,
|
||||
ExpOpConversion,
|
||||
LogOpConversion,
|
||||
NegOpConversion,
|
||||
SignOpConversion>(patterns.getContext());
|
||||
// clang-format on
|
||||
|
@ -396,8 +419,8 @@ void ConvertComplexToStandardPass::runOnFunction() {
|
|||
target.addLegalDialect<StandardOpsDialect, math::MathDialect,
|
||||
complex::ComplexDialect>();
|
||||
target.addIllegalOp<complex::AbsOp, complex::DivOp, complex::EqualOp,
|
||||
complex::ExpOp, complex::NotEqualOp, complex::NegOp,
|
||||
complex::SignOp>();
|
||||
complex::ExpOp, complex::LogOp, complex::NotEqualOp,
|
||||
complex::NegOp, complex::SignOp>();
|
||||
if (failed(applyPartialConversion(function, target, std::move(patterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
|
|
@ -154,6 +154,25 @@ func @complex_exp(%arg: complex<f32>) -> complex<f32> {
|
|||
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
|
||||
// CHECK: return %[[RESULT]] : complex<f32>
|
||||
|
||||
// CHECK-LABEL: func @complex_log
|
||||
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
|
||||
func @complex_log(%arg: complex<f32>) -> complex<f32> {
|
||||
%log = complex.log %arg: complex<f32>
|
||||
return %log : complex<f32>
|
||||
}
|
||||
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
|
||||
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
|
||||
// CHECK: %[[SQR_REAL:.*]] = mulf %[[REAL]], %[[REAL]] : f32
|
||||
// CHECK: %[[SQR_IMAG:.*]] = mulf %[[IMAG]], %[[IMAG]] : f32
|
||||
// CHECK: %[[SQ_NORM:.*]] = addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32
|
||||
// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
|
||||
// CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] : f32
|
||||
// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
|
||||
// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
|
||||
// CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG2]], %[[REAL2]] : f32
|
||||
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
|
||||
// CHECK: return %[[RESULT]] : complex<f32>
|
||||
|
||||
// CHECK-LABEL: func @complex_neg
|
||||
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
|
||||
func @complex_neg(%arg: complex<f32>) -> complex<f32> {
|
||||
|
|
Loading…
Reference in New Issue