Revert "[mlir][Linalg] NFC: Combine elementwise fusion test passes."

This reverts commit d730336411.
This commit is contained in:
Mahesh Ravishankar 2022-02-07 22:50:36 +00:00
parent 157bbe6aea
commit 7568f7101f
5 changed files with 75 additions and 61 deletions

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-generic-ops -split-input-file | FileCheck %s
// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns -split-input-file | FileCheck %s
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#binary2Dpointwise = {

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=push-expanding-reshape -split-input-file | FileCheck %s
// RUN: mlir-opt %s -test-linalg-push-reshape -split-input-file | FileCheck %s
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1)>

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt -test-linalg-elementwise-fusion-patterns=control-fusion-by-expansion %s -split-input-file | FileCheck %s
// RUN: mlir-opt -test-linalg-control-fusion-by-expansion %s -split-input-file | FileCheck %s
func @control_producer_reshape_fusion(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?xf32>) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index

View File

@ -58,29 +58,11 @@ struct TestLinalgElementwiseFusion
return "Test Linalg element wise operation fusion patterns";
}
Option<bool>
fuseGenericOps(*this, "fuse-generic-ops",
llvm::cl::desc("Test fusion of generic operations."),
llvm::cl::init(false));
Option<bool> controlFuseByExpansion(
*this, "control-fusion-by-expansion",
llvm::cl::desc(
"Test controlling fusion of reshape with generic op by expansion"),
llvm::cl::init(false));
Option<bool>
pushExpandingReshape(*this, "push-expanding-reshape",
llvm::cl::desc("Test linalg expand_shape -> generic "
"to generic -> expand_shape pattern"),
llvm::cl::init(false));
void runOnOperation() override {
MLIRContext *context = &this->getContext();
FuncOp funcOp = this->getOperation();
if (fuseGenericOps) {
RewritePatternSet fusionPatterns(context);
linalg::populateElementwiseOpsFusionPatterns(
fusionPatterns,
linalg::LinalgElementwiseFusionOptions()
@ -88,10 +70,27 @@ struct TestLinalgElementwiseFusion
(void)applyPatternsAndFoldGreedily(funcOp.getBody(),
std::move(fusionPatterns));
return;
}
};
struct TestLinalgControlFuseByExpansion
: public PassWrapper<TestLinalgControlFuseByExpansion,
OperationPass<FuncOp>> {
void getDependentDialects(DialectRegistry &registry) const override {
registry
.insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
}
StringRef getArgument() const final {
return "test-linalg-control-fusion-by-expansion";
}
StringRef getDescription() const final {
return "Test controlling of fusion of elementwise ops with reshape by "
"expansion";
}
if (controlFuseByExpansion) {
void runOnOperation() override {
MLIRContext *context = &this->getContext();
FuncOp funcOp = this->getOperation();
RewritePatternSet fusionPatterns(context);
linalg::ControlElementwiseOpsFusionFn controlReshapeFusionFn =
@ -118,17 +117,28 @@ struct TestLinalgElementwiseFusion
controlReshapeFusionFn);
(void)applyPatternsAndFoldGreedily(funcOp.getBody(),
std::move(fusionPatterns));
return;
}
};
struct TestPushExpandingReshape
: public PassWrapper<TestPushExpandingReshape, OperationPass<FuncOp>> {
void getDependentDialects(DialectRegistry &registry) const override {
registry
.insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
}
StringRef getArgument() const final { return "test-linalg-push-reshape"; }
StringRef getDescription() const final {
return "Test Linalg reshape push patterns";
}
if (pushExpandingReshape) {
void runOnOperation() override {
MLIRContext *context = &this->getContext();
FuncOp funcOp = this->getOperation();
RewritePatternSet patterns(context);
linalg::populatePushReshapeOpsPatterns(patterns);
(void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
}
}
};
} // namespace
namespace test {

View File

@ -81,8 +81,10 @@ void registerTestGenericIRVisitorsPass();
void registerTestGenericIRVisitorsInterruptPass();
void registerTestInterfaces();
void registerTestLinalgCodegenStrategy();
void registerTestLinalgControlFuseByExpansion();
void registerTestLinalgDistribution();
void registerTestLinalgElementwiseFusion();
void registerTestPushExpandingReshape();
void registerTestLinalgFusionTransforms();
void registerTestLinalgTensorFusionTransforms();
void registerTestLinalgTiledLoopFusionTransforms();
@ -170,8 +172,10 @@ void registerTestPasses() {
mlir::test::registerTestGenericIRVisitorsPass();
mlir::test::registerTestInterfaces();
mlir::test::registerTestLinalgCodegenStrategy();
mlir::test::registerTestLinalgControlFuseByExpansion();
mlir::test::registerTestLinalgDistribution();
mlir::test::registerTestLinalgElementwiseFusion();
mlir::test::registerTestPushExpandingReshape();
mlir::test::registerTestLinalgFusionTransforms();
mlir::test::registerTestLinalgTensorFusionTransforms();
mlir::test::registerTestLinalgTiledLoopFusionTransforms();