[mlir][Linalg] NFC: Refactor methods in `ElementwiseOpFusion`.

Reorder the methods and patterns to move related patterns/methods
closer (textually).

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D118870
Author: Mahesh Ravishankar
Date: 2022-02-03 18:40:26 +00:00
parent 4f3f4d6722
commit 32288d3722
1 changed file with 346 additions and 295 deletions


@@ -27,6 +27,10 @@
using namespace mlir;
using namespace mlir::linalg;
//===---------------------------------------------------------------------===//
// Methods and patterns that fuse elementwise `linalg.generic` operations.
//===---------------------------------------------------------------------===//
/// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
/// the `producer` to use in the fused operation given the indexing map of the
/// result of the producer in the consumer.
@@ -345,6 +349,58 @@ fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
return SmallVector<Value>(fusedOp->getResults());
}
static Optional<SmallVector<Value>>
fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
GenericOp producer,
const ControlElementwiseOpsFusionFn &controlFn) {
if (producer->getNumResults() != 1)
return llvm::None;
return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
rewriter);
}
namespace {
/// Pattern to fuse a generic op with the producers of its operands.
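/// As an illustrative sketch (the operand names here are hypothetical, not
/// part of this commit), two back-to-back elementwise generics such as
///
///   #id = affine_map<(d0) -> (d0)>
///   %0 = linalg.generic {indexing_maps = [#id, #id],
///                        iterator_types = ["parallel"]}
///       ins(%arg0 : tensor<?xf32>) outs(%init : tensor<?xf32>) {
///     ^bb0(%in: f32, %out: f32):
///       %add = arith.addf %in, %in : f32
///       linalg.yield %add : f32
///   } -> tensor<?xf32>
///   %1 = linalg.generic {indexing_maps = [#id, #id],
///                        iterator_types = ["parallel"]}
///       ins(%0 : tensor<?xf32>) outs(%init : tensor<?xf32>) {
///     ^bb0(%in: f32, %out: f32):
///       %mul = arith.mulf %in, %in : f32
///       linalg.yield %mul : f32
///   } -> tensor<?xf32>
///
/// are fused into a single linalg.generic whose region computes both payloads,
/// eliminating the intermediate tensor %0.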
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
public:
FuseElementwiseOps(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
PatternBenefit benefit = 1)
: OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
// Find the first operand that is defined by another generic op on tensors.
for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
auto producer =
dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
if (!producer || !producer.hasTensorSemantics())
continue;
Optional<SmallVector<Value>> fusedOpResults =
fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
if (fusedOpResults) {
rewriter.replaceOp(genericOp, *fusedOpResults);
return success();
}
}
return failure();
}
private:
ControlElementwiseOpsFusionFn controlFn;
};
} // namespace
//===---------------------------------------------------------------------===//
// Methods and patterns that fuse reshape ops with elementwise operations by
// linearization of indexing maps.
//===---------------------------------------------------------------------===//
// TODO(ravishankarm): These patterns need to be deprecated. The indexing maps
// these produce in the general case are detrimental to transformations.
// They are useful now only in the limited case of unit-dimension folding.
// Remove these in favor of more general folding by dimension contraction.
/// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
/// provided, given the shape of the source tensor that corresponds to the
/// `sourceMap`. Note that this implicitly assumes that the tensor's dimensions
@@ -445,6 +501,157 @@ static bool isUnitDimExpansionOnly(TensorReshapeOp reshapeOp) {
return true;
}
namespace {
/// Pattern to fold a tensor.expand_shape or tensor.collapse_shape op with its
/// consumer generic op, by using the source of the reshape op as the operand
/// in the consumer (instead of the result of the reshape op). The
/// corresponding indexing map in the consumer needs to be modified to
/// linearize the folded dimension.
///
/// For example,
///
/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
/// %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]]
/// tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
/// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
/// ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
/// -> tensor<?x?x4x?xf32>
///
/// can be folded into
///
/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
/// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
/// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
/// ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
/// -> tensor<?x?x4x?xf32>
template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
struct FoldProducerReshapeOpByLinearization
: public OpRewritePattern<GenericOp> {
using OpRewritePattern<GenericOp>::OpRewritePattern;
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
if (!genericOp.hasTensorSemantics())
return failure();
SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
for (const auto &en : llvm::enumerate(inputOperands)) {
auto reshapeOp = en.value()->get().getDefiningOp<TensorReshapeOp>();
if (!reshapeOp)
continue;
if (!isTensorReshapeOpFoldableByLinearization(
reshapeOp, genericOp.getTiedIndexingMap(en.value()),
/*asProducer =*/true) ||
(foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
continue;
// Compute the fused operands list.
SmallVector<Value> fusedOperands = genericOp.getInputOperands();
fusedOperands[en.index()] = reshapeOp.src();
SmallVector<Value> outputOperands = genericOp.getOutputOperands();
llvm::append_range(fusedOperands, outputOperands);
// Compute indexing_maps for the fused operation. The indexing_maps for
// the operands of the consumer that aren't fused are the same.
SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps();
// Compute the indexing map to use for the result of the producer.
AffineMap modifiedMap =
linearizeCollapsedDims(fusedIndexMaps[en.index()], reshapeOp);
// The modified map cannot have symbols.
if (modifiedMap.getNumSymbols())
return failure();
for (AffineExpr expr : modifiedMap.getResults()) {
if (!expr.isPureAffine())
return failure();
}
fusedIndexMaps[en.index()] = modifiedMap;
// Further check that the resulting index maps can be fused and
// inverted. Without this the resultant op is not legal.
if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
return rewriter.notifyMatchFailure(
genericOp, "fused op loop bound computation failed");
}
rewriter.startRootUpdate(genericOp);
genericOp->setOperands(fusedOperands);
genericOp.indexing_mapsAttr(
rewriter.getAffineMapArrayAttr(fusedIndexMaps));
rewriter.finalizeRootUpdate(genericOp);
return success();
}
return failure();
}
};
/// Pattern to fold tensor_collapse_shape or tensor_expand_shape op with its
/// producer. The corresponding index map in the consumer needs to be modified
/// to linearize the folded dimension.
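/// A minimal sketch of the intent (illustrative values, not from this commit):
///
///   #map = affine_map<(d0, d1) -> (d0, d1)>
///   %0 = linalg.generic {indexing_maps = [#map, #map], ...}
///       ins(%arg0 : tensor<?x4xf32>) ... -> tensor<?x4xf32>
///   %1 = tensor.collapse_shape %0 [[0, 1]]
///       : tensor<?x4xf32> into tensor<?xf32>
///
/// is rewritten so the generic produces the collapsed tensor directly, with
/// its output indexing map linearized to
/// affine_map<(d0, d1) -> (d0 * 4 + d1)>.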
template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
struct FoldConsumerReshapeOpByLinearization
: public OpRewritePattern<TensorReshapeOp> {
using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
PatternRewriter &rewriter) const override {
GenericOp producer = reshapeOp.src().template getDefiningOp<GenericOp>();
if (!producer || !producer.hasTensorSemantics() ||
producer.getNumOutputs() != 1 ||
!isTensorReshapeOpFoldableByLinearization(
reshapeOp,
producer.getTiedIndexingMap(producer.getOutputOperand(0)),
/*asProducer =*/false) ||
(foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
return failure();
// The indexing_maps for the operands of the fused operation are the same
// as those for the operands of the producer.
SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps();
// Compute the indexing map to use for the operand of the producer.
AffineMap modifiedMap = linearizeCollapsedDims(
producer.getTiedIndexingMap(producer.getOutputOperand(0)), reshapeOp);
for (AffineExpr expr : modifiedMap.getResults()) {
if (!expr.isPureAffine()) {
return rewriter.notifyMatchFailure(
producer, "fused op indexing map is not affine");
}
}
fusedIndexMaps.back() = modifiedMap;
// Further check that the resulting index maps can be fused and
// inverted. Without this the resultant op is not legal.
if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
return rewriter.notifyMatchFailure(
producer, "fused op loop bound computation failed");
}
Location loc = producer.getLoc();
SmallVector<Value> inputOperands = producer.getInputOperands();
Value output = rewriter.create<TensorReshapeOp>(
loc, producer.getOutputOperand(0)->get(),
reshapeOp.getReassociationExprs());
auto fusedOp = rewriter.create<GenericOp>(
loc, reshapeOp.getResultType(),
/*inputs=*/inputOperands,
// TODO: handle outputs.
/*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
producer.iterator_types(),
/*doc=*/nullptr,
/*library_call=*/nullptr);
auto &fusedRegion = fusedOp->getRegion(0);
rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion,
fusedRegion.begin());
rewriter.replaceOp(reshapeOp, fusedOp->getResults());
return success();
}
};
} // namespace
//===---------------------------------------------------------------------===//
// Methods and patterns that fuse reshape ops with elementwise operations by
// expanding the dimensionality of the elementwise operations.
//===---------------------------------------------------------------------===//
/// Conditions for folding a generic operation with a reshape op by expanding
/// the iteration space dimensionality for tensor operations. These are
/// preconditions assumed by `foldReshapeByDimExpansion` which implements the
@@ -612,9 +819,9 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
/// Note that this could be extended to handle the dynamic case, but the
/// implementation below uses `affine.apply` which seems to have issues when the
/// shapes are not static.
LogicalResult isGenericOpExpandable(GenericOp genericOp,
const ExpansionInfo &expansionInfo,
PatternRewriter &rewriter) {
static LogicalResult isGenericOpExpandable(GenericOp genericOp,
const ExpansionInfo &expansionInfo,
PatternRewriter &rewriter) {
if (!genericOp.hasIndexSemantics())
return success();
for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
@@ -863,88 +1070,85 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
namespace {
/// Pattern to fold a tensor.expand_shape or tensor.collapse_shape op with its
/// consumer generic op, by using the source of the reshape op as the operand
/// in the consumer (instead of the result of the reshape op). The
/// corresponding indexing map in the consumer needs to be modified to
/// linearize the folded dimension.
///
/// For example,
///
/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
/// %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]]
/// tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
/// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
/// ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
/// -> tensor<?x?x4x?xf32>
///
/// can be folded into
///
/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
/// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
/// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
/// ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
/// -> tensor<?x?x4x?xf32>
template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
struct FoldProducerReshapeOpByLinearization
/// Pattern to fuse a tensor_collapse_shape op with its consumer generic op,
/// when the reshape op is collapsing dimensions. The dimensionality of the loop
/// in the consumer is expanded.
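/// For instance (an illustrative sketch):
///
///   %0 = tensor.collapse_shape %arg0 [[0, 1]]
///       : tensor<?x4xf32> into tensor<?xf32>
///   %1 = linalg.generic {...} ins(%0 : tensor<?xf32>) ...
///
/// is rewritten so the generic iterates over the original 2-d shape and
/// consumes %arg0 directly; the remaining operands and the results are
/// adapted with tensor.expand_shape/tensor.collapse_shape ops as needed.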
class FoldWithProducerReshapeOpByExpansion
: public OpRewritePattern<GenericOp> {
using OpRewritePattern<GenericOp>::OpRewritePattern;
public:
FoldWithProducerReshapeOpByExpansion(
MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
PatternBenefit benefit = 1)
: OpRewritePattern<GenericOp>(context, benefit),
controlFoldingReshapes(std::move(foldReshapes)) {}
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
if (!genericOp.hasTensorSemantics())
return failure();
SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
for (const auto &en : llvm::enumerate(inputOperands)) {
auto reshapeOp = en.value()->get().getDefiningOp<TensorReshapeOp>();
for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
tensor::CollapseShapeOp reshapeOp =
opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
if (!reshapeOp)
continue;
if (!isTensorReshapeOpFoldableByLinearization(
reshapeOp, genericOp.getTiedIndexingMap(en.value()),
/*asProducer =*/true) ||
(foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
// Fold only if
// - The tensor reshape op is collapsing dimensions.
// - All constraints of fusing with reshape by expansion are met.
if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
(!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
continue;
// Compute the fused operands list.
SmallVector<Value> fusedOperands = genericOp.getInputOperands();
fusedOperands[en.index()] = reshapeOp.src();
SmallVector<Value> outputOperands = genericOp.getOutputOperands();
llvm::append_range(fusedOperands, outputOperands);
// Compute indexing_maps for the fused operation. The indexing_maps for
// the operands of the consumer that aren't fused are the same.
SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps();
// Compute the indexing map to use for the result of the producer.
AffineMap modifiedMap =
linearizeCollapsedDims(fusedIndexMaps[en.index()], reshapeOp);
// The modified map cannot have symbols.
if (modifiedMap.getNumSymbols())
Optional<SmallVector<Value>> replacementValues =
fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
if (!replacementValues)
return failure();
for (AffineExpr expr : modifiedMap.getResults()) {
if (!expr.isPureAffine())
return failure();
}
fusedIndexMaps[en.index()] = modifiedMap;
// Further check that the resulting index maps can be fused and
// inverted. Without this the resultant op is not legal.
if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
return rewriter.notifyMatchFailure(
genericOp, "fused op loop bound computation failed");
}
rewriter.startRootUpdate(genericOp);
genericOp->setOperands(fusedOperands);
genericOp.indexing_mapsAttr(
rewriter.getAffineMapArrayAttr(fusedIndexMaps));
rewriter.finalizeRootUpdate(genericOp);
rewriter.replaceOp(genericOp, replacementValues.getValue());
return success();
}
return failure();
}
private:
ControlElementwiseOpsFusionFn controlFoldingReshapes;
};
/// Pattern to fold a tensor_expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
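/// Sketch of the intent (illustrative only):
///
///   %0 = linalg.generic {...} ... -> tensor<?xf32>
///   %1 = tensor.expand_shape %0 [[0, 1]]
///       : tensor<?xf32> into tensor<?x4xf32>
///
/// becomes a generic with an expanded 2-d iteration space that produces
/// tensor<?x4xf32> directly, with expand_shape ops inserted on its operands.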
struct FoldReshapeWithGenericOpByExpansion
: public OpRewritePattern<tensor::ExpandShapeOp> {
FoldReshapeWithGenericOpByExpansion(
MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
PatternBenefit benefit = 1)
: OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
controlFoldingReshapes(std::move(foldReshapes)) {}
LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
PatternRewriter &rewriter) const override {
// Fold only if all constraints of fusing with reshape by expansion are met.
GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
if (!producer || producer.getNumOutputs() != 1 ||
!isFusableWithReshapeByDimExpansion(producer,
producer.getOutputOperand(0)) ||
!controlFoldingReshapes(producer->getResult(0),
reshapeOp->getOpOperand(0)))
return failure();
Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
producer, reshapeOp, producer.getOutputOperand(0), rewriter);
if (!replacementValues)
return failure();
rewriter.replaceOp(reshapeOp, replacementValues.getValue());
return success();
}
private:
ControlElementwiseOpsFusionFn controlFoldingReshapes;
};
} // namespace
//===---------------------------------------------------------------------===//
// Methods and patterns to convert tensor.expand_shape -> linalg.generic
// into linalg.generic -> tensor.expand_shape, i.e. push the reshape down.
//===---------------------------------------------------------------------===//
static SmallVector<ReassociationIndices>
getReassociationIndices(ArrayRef<AffineMap> maps) {
SmallVector<ReassociationIndices> reassociation;
@@ -959,6 +1163,7 @@ getReassociationIndices(ArrayRef<AffineMap> maps) {
return reassociation;
}
namespace {
/// Pattern to move a rank-reducing reshape after an elementwise linalg generic
/// op. This is useful to expose more fusion opportunities between named ops and
/// generic ops. This can only be done if there is no broadcast or permutation
@@ -1100,142 +1305,13 @@ struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
return success();
}
};
} // namespace
/// Pattern to fuse a tensor_collapse_shape op with its consumer generic op,
/// when the reshape op is collapsing dimensions. The dimensionality of the loop
/// in the consumer is expanded.
class FoldWithProducerReshapeOpByExpansion
: public OpRewritePattern<GenericOp> {
public:
FoldWithProducerReshapeOpByExpansion(
MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
PatternBenefit benefit = 1)
: OpRewritePattern<GenericOp>(context, benefit),
controlFoldingReshapes(std::move(foldReshapes)) {}
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
tensor::CollapseShapeOp reshapeOp =
opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
if (!reshapeOp)
continue;
// Fold only if
// - The tensor reshape op is collapsing dimensions.
// - All constraints of fusing with reshape by expansion are met.
if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
(!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
continue;
Optional<SmallVector<Value>> replacementValues =
fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
if (!replacementValues)
return failure();
rewriter.replaceOp(genericOp, replacementValues.getValue());
return success();
}
return failure();
}
private:
ControlElementwiseOpsFusionFn controlFoldingReshapes;
};
/// Pattern to fold tensor_collapse_shape or tensor_expand_shape op with its
/// producer. The corresponding index map in the consumer needs to be modified
/// to linearize the folded dimension.
template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
struct FoldConsumerReshapeOpByLinearization
: public OpRewritePattern<TensorReshapeOp> {
using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
PatternRewriter &rewriter) const override {
GenericOp producer = reshapeOp.src().template getDefiningOp<GenericOp>();
if (!producer || !producer.hasTensorSemantics() ||
producer.getNumOutputs() != 1 ||
!isTensorReshapeOpFoldableByLinearization(
reshapeOp,
producer.getTiedIndexingMap(producer.getOutputOperand(0)),
/*asProducer =*/false) ||
(foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
return failure();
// The indexing_maps for the operands of the fused operation are the same
// as those for the operands of the producer.
SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps();
// Compute the indexing map to use for the operand of the producer.
AffineMap modifiedMap = linearizeCollapsedDims(
producer.getTiedIndexingMap(producer.getOutputOperand(0)), reshapeOp);
for (AffineExpr expr : modifiedMap.getResults()) {
if (!expr.isPureAffine()) {
return rewriter.notifyMatchFailure(
producer, "fused op indexing map is not affine");
}
}
fusedIndexMaps.back() = modifiedMap;
// Further check that the resulting index maps can be fused and
// inverted. Without this the resultant op is not legal.
if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
return rewriter.notifyMatchFailure(
producer, "fused op loop bound computation failed");
}
Location loc = producer.getLoc();
SmallVector<Value> inputOperands = producer.getInputOperands();
Value output = rewriter.create<TensorReshapeOp>(
loc, producer.getOutputOperand(0)->get(),
reshapeOp.getReassociationExprs());
auto fusedOp = rewriter.create<GenericOp>(
loc, reshapeOp.getResultType(),
/*inputs=*/inputOperands,
// TODO: handle outputs.
/*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
producer.iterator_types(),
/*doc=*/nullptr,
/*library_call=*/nullptr);
auto &fusedRegion = fusedOp->getRegion(0);
rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion,
fusedRegion.begin());
rewriter.replaceOp(reshapeOp, fusedOp->getResults());
return success();
}
};
/// Pattern to fold a tensor_expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
: public OpRewritePattern<tensor::ExpandShapeOp> {
FoldReshapeWithGenericOpByExpansion(
MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
PatternBenefit benefit = 1)
: OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
controlFoldingReshapes(std::move(foldReshapes)) {}
LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
PatternRewriter &rewriter) const override {
// Fold only if all constraints of fusing with reshape by expansion are met.
GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
if (!producer || producer.getNumOutputs() != 1 ||
!isFusableWithReshapeByDimExpansion(producer,
producer.getOutputOperand(0)) ||
!controlFoldingReshapes(producer->getResult(0),
reshapeOp->getOpOperand(0)))
return failure();
Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
producer, reshapeOp, producer.getOutputOperand(0), rewriter);
if (!replacementValues)
return failure();
rewriter.replaceOp(reshapeOp, replacementValues.getValue());
return success();
}
private:
ControlElementwiseOpsFusionFn controlFoldingReshapes;
};
//===---------------------------------------------------------------------===//
// Methods and patterns that fuse constants with linalg.generic operations.
//===---------------------------------------------------------------------===//
namespace {
/// Pattern to fold a generic op with a splat constant/scalar constant. Does not
/// handle cases where the constant is not single-valued.
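/// A hedged sketch of the effect (the constant values are illustrative):
///
///   %cst = arith.constant dense<2.0> : tensor<4xf32>   // splat constant
///   %0 = linalg.generic {...}
///       ins(%arg0, %cst : tensor<4xf32>, tensor<4xf32>) ...
///
/// is folded so the constant operand is dropped and the splat value is used
/// inside the region as a scalar `arith.constant 2.0 : f32`.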
class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
@@ -1624,98 +1700,11 @@ struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
} // namespace
static Optional<SmallVector<Value>>
fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
GenericOp producer,
const ControlElementwiseOpsFusionFn &controlFn) {
if (producer->getNumResults() != 1)
return llvm::None;
return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
rewriter);
}
bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
OpOperand &consumer) {
if (auto producerCollapseOp =
dyn_cast<tensor::CollapseShapeOp>(producer.getOwner())) {
return !isUnitDimExpansionOnly(producerCollapseOp);
}
if (auto consumerExpandOp =
dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
return !isUnitDimExpansionOnly(consumerExpandOp);
}
return true;
}
//===---------------------------------------------------------------------===//
// Miscellaneous patterns that help fusion.
//===---------------------------------------------------------------------===//
namespace {
/// Pattern to fuse a generic op with the producers of its operands.
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
public:
FuseElementwiseOps(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
PatternBenefit benefit = 1)
: OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
// Find the first operand that is defined by another generic op on tensors.
for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
auto producer =
dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
if (!producer || !producer.hasTensorSemantics())
continue;
Optional<SmallVector<Value>> fusedOpResults =
fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
if (fusedOpResults) {
rewriter.replaceOp(genericOp, *fusedOpResults);
return success();
}
}
return failure();
}
private:
ControlElementwiseOpsFusionFn controlFn;
};
/// Pass that fuses generic ops on tensors. Used only for testing.
struct LinalgElementwiseOpFusionPass
: public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> {
void runOnOperation() override {
Operation *op = getOperation();
RewritePatternSet patterns(op->getContext());
ControlElementwiseOpsFusionFn allowFoldingFn =
[](const OpResult &producer, const OpOperand &consumer) {
return true;
};
populateElementwiseOpsFusionPatterns(
patterns,
LinalgElementwiseFusionOptions().setControlFoldingReshapes(
allowFoldingUnitDimReshapes ? allowFoldingFn : skipUnitDimReshape));
// Use TopDownTraversal for compile-time reasons.
GreedyRewriteConfig grc;
grc.useTopDownTraversal = true;
(void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
grc);
}
};
/// Pass to test folding of reshape ops with generic ops by linearization.
struct FoldReshapeOpsByLinearizationPass
: public LinalgFoldReshapeOpsByLinearizationBase<
FoldReshapeOpsByLinearizationPass> {
void runOnOperation() override {
Operation *op = getOperation();
RewritePatternSet patterns(op->getContext());
populateFoldReshapeOpsByLinearizationPatterns(patterns);
if (allowFoldingUnitDimReshapes) {
populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
}
(void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
}
};
/// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if
/// the value of the `outs` operand is not used within the op. This is only
/// implemented for `linalg.generic` operations for now, but should hold for
/// all linalg structured ops.
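/// Sketch (illustrative; %d stands for the dynamic size):
///
///   %0 = linalg.generic {...} ins(%a : tensor<?xf32>)
///       outs(%b : tensor<?xf32>) {...}   // value of %b unused in the region
///
/// becomes
///
///   %init = linalg.init_tensor [%d] : tensor<?xf32>
///   %0 = linalg.generic {...} ins(%a : tensor<?xf32>)
///       outs(%init : tensor<?xf32>) {...}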
@@ -1761,9 +1750,12 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
return success();
}
};
} // namespace
//===---------------------------------------------------------------------===//
// Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//
void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
RewritePatternSet &patterns) {
patterns
@@ -1815,6 +1807,65 @@ void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
patterns.add<PushExpandingReshape>(context);
}
//===---------------------------------------------------------------------===//
// Passes
//===---------------------------------------------------------------------===//
bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
OpOperand &consumer) {
if (auto producerCollapseOp =
dyn_cast<tensor::CollapseShapeOp>(producer.getOwner())) {
return !isUnitDimExpansionOnly(producerCollapseOp);
}
if (auto consumerExpandOp =
dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
return !isUnitDimExpansionOnly(consumerExpandOp);
}
return true;
}
namespace {
/// Pass that fuses generic ops on tensors. Used only for testing.
struct LinalgElementwiseOpFusionPass
: public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> {
void runOnOperation() override {
Operation *op = getOperation();
RewritePatternSet patterns(op->getContext());
ControlElementwiseOpsFusionFn allowFoldingFn =
[](const OpResult &producer, const OpOperand &consumer) {
return true;
};
populateElementwiseOpsFusionPatterns(
patterns,
LinalgElementwiseFusionOptions().setControlFoldingReshapes(
allowFoldingUnitDimReshapes ? allowFoldingFn : skipUnitDimReshape));
// Use TopDownTraversal for compile-time reasons.
GreedyRewriteConfig grc;
grc.useTopDownTraversal = true;
(void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
grc);
}
};
/// Pass to test folding of reshape ops with generic ops by linearization.
struct FoldReshapeOpsByLinearizationPass
: public LinalgFoldReshapeOpsByLinearizationBase<
FoldReshapeOpsByLinearizationPass> {
void runOnOperation() override {
Operation *op = getOperation();
RewritePatternSet patterns(op->getContext());
populateFoldReshapeOpsByLinearizationPatterns(patterns);
if (allowFoldingUnitDimReshapes) {
populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
}
(void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
}
};
} // namespace
std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() {
return std::make_unique<LinalgElementwiseOpFusionPass>();
}