From 2abd7f13bc742dc6d08f00c41de42cb0c26f17dc Mon Sep 17 00:00:00 2001 From: Mahesh Ravishankar Date: Mon, 7 Feb 2022 17:45:28 +0000 Subject: [PATCH] [mlir][Linalg] NFC: Combine elementwise fusion test passes. There are a few different test passes that check elementwise fusion in Linalg. Consolidate them to a single pass controlled by different pass options (in keeping with how `TestLinalgTransforms` exists). --- .../Linalg/fusion-elementwise-options.mlir | 2 +- .../Dialect/Linalg/fusion-push-reshape.mlir | 2 +- .../Linalg/reshape_control_fusion.mlir | 2 +- .../Linalg/TestLinalgElementwiseFusion.cpp | 137 ++++++++---------- mlir/tools/mlir-opt/mlir-opt.cpp | 4 - 5 files changed, 64 insertions(+), 83 deletions(-) diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir index d81aab66491a..103a04d79ba4 100644 --- a/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir +++ b/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns -split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-generic-ops -split-input-file | FileCheck %s #map0 = affine_map<(d0, d1) -> (d0, d1)> #binary2Dpointwise = { diff --git a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir index 0c02ff8c54d1..a1d428865120 100644 --- a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir +++ b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-linalg-push-reshape -split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=push-expanding-reshape -split-input-file | FileCheck %s // CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1)> diff --git a/mlir/test/Dialect/Linalg/reshape_control_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_control_fusion.mlir index d9e440c96efd..c4e7d5552678 100644 --- a/mlir/test/Dialect/Linalg/reshape_control_fusion.mlir +++ b/mlir/test/Dialect/Linalg/reshape_control_fusion.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -test-linalg-control-fusion-by-expansion %s -split-input-file | FileCheck %s +// RUN: mlir-opt -test-linalg-elementwise-fusion-patterns=control-fusion-by-expansion %s -split-input-file | FileCheck %s func @control_producer_reshape_fusion(%arg0 : tensor, %arg1 : tensor) -> tensor { %c0 = arith.constant 0 : index diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp index 30bef4af8bcc..3efa97f94140 100644 --- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp @@ -47,6 +47,9 @@ static bool setFusedOpOperandLimit(const OpResult &producer, namespace { struct TestLinalgElementwiseFusion : public PassWrapper> { + TestLinalgElementwiseFusion() = default; + TestLinalgElementwiseFusion(const TestLinalgElementwiseFusion &pass) + : PassWrapper(pass) {} void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); @@ -58,101 +61,83 @@ struct TestLinalgElementwiseFusion return "Test Linalg element wise operation fusion patterns"; } - void runOnOperation() override { - MLIRContext *context = &this->getContext(); - FuncOp funcOp = this->getOperation(); - RewritePatternSet fusionPatterns(context); + Option fuseGenericOps{ + *this, "fuse-generic-ops", + llvm::cl::desc("Test fusion of generic operations."), + llvm::cl::init(false)}; - linalg::populateElementwiseOpsFusionPatterns( - fusionPatterns, - linalg::LinalgElementwiseFusionOptions() - .setControlElementwiseOpsFusionFn(setFusedOpOperandLimit<4>)); + Option controlFuseByExpansion{ + *this, "control-fusion-by-expansion", + llvm::cl::desc( + "Test controlling fusion of reshape with generic op by expansion"), + llvm::cl::init(false)}; - (void)applyPatternsAndFoldGreedily(funcOp.getBody(), - std::move(fusionPatterns)); - } -}; - -struct TestLinalgControlFuseByExpansion - : public PassWrapper> { - void getDependentDialects(DialectRegistry ®istry) const override { - registry - .insert(); - } - StringRef getArgument() const final { - return "test-linalg-control-fusion-by-expansion"; - } - StringRef getDescription() const final { - return "Test controlling of fusion of elementwise ops with reshape by " - "expansion"; - } + Option pushExpandingReshape{ + *this, "push-expanding-reshape", + llvm::cl::desc("Test linalg expand_shape -> generic " + "to generic -> expand_shape pattern"), + llvm::cl::init(false)}; void runOnOperation() override { MLIRContext *context = &this->getContext(); FuncOp funcOp = this->getOperation(); - RewritePatternSet fusionPatterns(context); - linalg::ControlElementwiseOpsFusionFn controlReshapeFusionFn = - [](const OpResult &producer, OpOperand &consumer) { - if (auto collapseOp = - producer.getDefiningOp()) { - if (!collapseOp.src().getDefiningOp()) { - return false; + if (fuseGenericOps) { + RewritePatternSet fusionPatterns(context); + linalg::populateElementwiseOpsFusionPatterns( + fusionPatterns, + linalg::LinalgElementwiseFusionOptions() + .setControlElementwiseOpsFusionFn(setFusedOpOperandLimit<4>)); + + (void)applyPatternsAndFoldGreedily(funcOp.getBody(), + std::move(fusionPatterns)); + return; + } + + if (controlFuseByExpansion) { + RewritePatternSet fusionPatterns(context); + + linalg::ControlElementwiseOpsFusionFn controlReshapeFusionFn = + [](const OpResult &producer, OpOperand &consumer) { + if (auto collapseOp = + producer.getDefiningOp()) { + if (!collapseOp.src().getDefiningOp()) { + return false; + } } - } - if (auto expandOp = - dyn_cast(consumer.getOwner())) { - if (expandOp->hasOneUse()) { - OpOperand &use = *expandOp->getUses().begin(); - auto linalgOp = dyn_cast(use.getOwner()); - if (linalgOp && linalgOp.isOutputTensor(&use)) - return true; + if (auto expandOp = + dyn_cast(consumer.getOwner())) { + if (expandOp->hasOneUse()) { + OpOperand &use = *expandOp->getUses().begin(); + auto linalgOp = dyn_cast(use.getOwner()); + if (linalgOp && linalgOp.isOutputTensor(&use)) + return true; + } } - } - return linalg::skipUnitDimReshape(producer, consumer); - }; + return linalg::skipUnitDimReshape(producer, consumer); + }; - linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns, - controlReshapeFusionFn); - (void)applyPatternsAndFoldGreedily(funcOp.getBody(), - std::move(fusionPatterns)); + linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns, + controlReshapeFusionFn); + (void)applyPatternsAndFoldGreedily(funcOp.getBody(), + std::move(fusionPatterns)); + return; + } + + if (pushExpandingReshape) { + RewritePatternSet patterns(context); + linalg::populatePushReshapeOpsPatterns(patterns); + (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns)); + } } }; -struct TestPushExpandingReshape - : public PassWrapper> { - void getDependentDialects(DialectRegistry ®istry) const override { - registry - .insert(); - } - StringRef getArgument() const final { return "test-linalg-push-reshape"; } - StringRef getDescription() const final { - return "Test Linalg reshape push patterns"; - } - - void runOnOperation() override { - MLIRContext *context = &this->getContext(); - FuncOp funcOp = this->getOperation(); - RewritePatternSet patterns(context); - linalg::populatePushReshapeOpsPatterns(patterns); - (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns)); - } -}; } // namespace namespace test { void registerTestLinalgElementwiseFusion() { PassRegistration(); } - -void registerTestLinalgControlFuseByExpansion() { - PassRegistration(); -} - -void registerTestPushExpandingReshape() { - PassRegistration(); -} } // namespace test } // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 5b09cb8671eb..73d1b54bbf4f 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -81,10 +81,8 @@ void registerTestGenericIRVisitorsPass(); void registerTestGenericIRVisitorsInterruptPass(); void registerTestInterfaces(); void registerTestLinalgCodegenStrategy(); -void registerTestLinalgControlFuseByExpansion(); void registerTestLinalgDistribution(); void registerTestLinalgElementwiseFusion(); -void registerTestPushExpandingReshape(); void registerTestLinalgFusionTransforms(); void registerTestLinalgTensorFusionTransforms(); void registerTestLinalgTiledLoopFusionTransforms(); @@ -172,10 +170,8 @@ void registerTestPasses() { mlir::test::registerTestGenericIRVisitorsPass(); mlir::test::registerTestInterfaces(); mlir::test::registerTestLinalgCodegenStrategy(); - mlir::test::registerTestLinalgControlFuseByExpansion(); mlir::test::registerTestLinalgDistribution(); mlir::test::registerTestLinalgElementwiseFusion(); - mlir::test::registerTestPushExpandingReshape(); mlir::test::registerTestLinalgFusionTransforms(); mlir::test::registerTestLinalgTensorFusionTransforms(); mlir::test::registerTestLinalgTiledLoopFusionTransforms();