[mlir][Linalg] NFC: Refactor methods in `ElementwiseOpFusion`.
Reorder the methods and patterns to move related patterns/methods closer (textually). Reviewed By: gysit Differential Revision: https://reviews.llvm.org/D118870
This commit is contained in:
parent
4f3f4d6722
commit
32288d3722
|
@ -27,6 +27,10 @@
|
|||
using namespace mlir;
|
||||
using namespace mlir::linalg;
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// Methods and patterns that fuse elementwise `linalg.generic` operations.
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
/// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
|
||||
/// the `producer` to use in the fused operation given the indexing map of the
|
||||
/// result of the producer in the consumer.
|
||||
|
@ -345,6 +349,58 @@ fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
|
|||
return SmallVector<Value>(fusedOp->getResults());
|
||||
}
|
||||
|
||||
static Optional<SmallVector<Value>>
|
||||
fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
|
||||
GenericOp producer,
|
||||
const ControlElementwiseOpsFusionFn &controlFn) {
|
||||
if (producer->getNumResults() != 1)
|
||||
return llvm::None;
|
||||
|
||||
return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
|
||||
rewriter);
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// Patterns to fuse a generic op, with the producer of its operands.
|
||||
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
|
||||
public:
|
||||
FuseElementwiseOps(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
|
||||
PatternBenefit benefit = 1)
|
||||
: OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
|
||||
|
||||
LogicalResult matchAndRewrite(GenericOp genericOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Find the first operand that is defined by another generic op on tensors.
|
||||
for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
|
||||
auto producer =
|
||||
dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
|
||||
if (!producer || !producer.hasTensorSemantics())
|
||||
continue;
|
||||
Optional<SmallVector<Value>> fusedOpResults =
|
||||
fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
|
||||
if (fusedOpResults) {
|
||||
rewriter.replaceOp(genericOp, *fusedOpResults);
|
||||
return success();
|
||||
}
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
private:
|
||||
ControlElementwiseOpsFusionFn controlFn;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// Methods and patterns that fuse reshape ops with elementwise operations by
|
||||
// linearization of indexing maps.
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
// TODO(ravishankarm): These patterns need to be deprecated. The indexing maps
|
||||
// these produce in the general case are detrimental to transformations.
|
||||
// They are useful now only in the limited case of unit-dimension folding.
|
||||
// Remove these in favor of more general folding by dimension contraction.
|
||||
|
||||
/// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
|
||||
/// provided, given the shape of the source tensor that corresponds to the
|
||||
/// `sourceMap`. Note that this implicitly assumes that the tensors dimensions
|
||||
|
@ -445,6 +501,157 @@ static bool isUnitDimExpansionOnly(TensorReshapeOp reshapeOp) {
|
|||
return true;
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// Pattern to fold tensor_expand_shape op with its consumer by using the source
|
||||
/// of the reshape op as the operand in the consumer (instead of the result of
|
||||
/// the tensor_collapse_shape). The corresponding index map in the consumer
|
||||
/// needs to be modified to linearize the folded dimension.
|
||||
///
|
||||
/// For example,
|
||||
///
|
||||
/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||
/// %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]]
|
||||
/// tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
|
||||
/// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
|
||||
/// ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
|
||||
/// -> tensor<?x?x4x?xf32>
|
||||
///
|
||||
/// can be folded into
|
||||
///
|
||||
/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
|
||||
/// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||
/// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
|
||||
/// ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
|
||||
/// -> tensor<?x?x4x?xf32>
|
||||
template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
|
||||
struct FoldProducerReshapeOpByLinearization
|
||||
: public OpRewritePattern<GenericOp> {
|
||||
using OpRewritePattern<GenericOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(GenericOp genericOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (!genericOp.hasTensorSemantics())
|
||||
return failure();
|
||||
SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
|
||||
for (const auto &en : llvm::enumerate(inputOperands)) {
|
||||
auto reshapeOp = en.value()->get().getDefiningOp<TensorReshapeOp>();
|
||||
if (!reshapeOp)
|
||||
continue;
|
||||
|
||||
if (!isTensorReshapeOpFoldableByLinearization(
|
||||
reshapeOp, genericOp.getTiedIndexingMap(en.value()),
|
||||
/*asProducer =*/true) ||
|
||||
(foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
|
||||
continue;
|
||||
|
||||
// Compute the fused operands list,
|
||||
SmallVector<Value> fusedOperands = genericOp.getInputOperands();
|
||||
fusedOperands[en.index()] = reshapeOp.src();
|
||||
SmallVector<Value> outputOperands = genericOp.getOutputOperands();
|
||||
llvm::append_range(fusedOperands, outputOperands);
|
||||
|
||||
// Compute indexing_maps for the fused operation. The indexing_maps for
|
||||
// the operands of the consumers that aren't fused are the same.
|
||||
SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps();
|
||||
|
||||
// Compute the indexing map to use for the result of the producer.
|
||||
AffineMap modifiedMap =
|
||||
linearizeCollapsedDims(fusedIndexMaps[en.index()], reshapeOp);
|
||||
// The modified map cannot have symbols.
|
||||
if (modifiedMap.getNumSymbols())
|
||||
return failure();
|
||||
for (AffineExpr expr : modifiedMap.getResults()) {
|
||||
if (!expr.isPureAffine())
|
||||
return failure();
|
||||
}
|
||||
fusedIndexMaps[en.index()] = modifiedMap;
|
||||
|
||||
// Further check that the resulting index maps can be fused and
|
||||
// inverted. Without this the resultant op is not legal.
|
||||
if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
genericOp, "fused op loop bound computation failed");
|
||||
}
|
||||
|
||||
rewriter.startRootUpdate(genericOp);
|
||||
genericOp->setOperands(fusedOperands);
|
||||
genericOp.indexing_mapsAttr(
|
||||
rewriter.getAffineMapArrayAttr(fusedIndexMaps));
|
||||
rewriter.finalizeRootUpdate(genericOp);
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
/// Pattern to fold tensor_collapse_shape or tensor_expand_shape op with its
|
||||
/// producer. The corresponding index map in the consumer needs to be modified
|
||||
/// to linearize the folded dimension.
|
||||
template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
|
||||
struct FoldConsumerReshapeOpByLinearization
|
||||
: public OpRewritePattern<TensorReshapeOp> {
|
||||
using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
GenericOp producer = reshapeOp.src().template getDefiningOp<GenericOp>();
|
||||
if (!producer || !producer.hasTensorSemantics() ||
|
||||
producer.getNumOutputs() != 1 ||
|
||||
!isTensorReshapeOpFoldableByLinearization(
|
||||
reshapeOp,
|
||||
producer.getTiedIndexingMap(producer.getOutputOperand(0)),
|
||||
/*asProducer =*/false) ||
|
||||
(foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
|
||||
return failure();
|
||||
// The indexing_maps for the operands of the fused operation are same as
|
||||
// those for the operands of the producer.
|
||||
SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps();
|
||||
|
||||
// Compute the indexing map to use for the operand of the producer.
|
||||
AffineMap modifiedMap = linearizeCollapsedDims(
|
||||
producer.getTiedIndexingMap(producer.getOutputOperand(0)), reshapeOp);
|
||||
for (AffineExpr expr : modifiedMap.getResults()) {
|
||||
if (!expr.isPureAffine()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
producer, "fused op indexing map is not affine");
|
||||
}
|
||||
}
|
||||
fusedIndexMaps.back() = modifiedMap;
|
||||
|
||||
// Further check that the resulting index maps can be fused and
|
||||
// inverted. Without this the resultant op is not legal.
|
||||
if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
producer, "fused op loop bound computation failed");
|
||||
}
|
||||
|
||||
Location loc = producer.getLoc();
|
||||
SmallVector<Value> inputOperands = producer.getInputOperands();
|
||||
Value output = rewriter.create<TensorReshapeOp>(
|
||||
loc, producer.getOutputOperand(0)->get(),
|
||||
reshapeOp.getReassociationExprs());
|
||||
auto fusedOp = rewriter.create<GenericOp>(
|
||||
loc, reshapeOp.getResultType(),
|
||||
/*inputs=*/inputOperands,
|
||||
// TODO: handle outputs.
|
||||
/*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
|
||||
producer.iterator_types(),
|
||||
/*doc=*/nullptr,
|
||||
/*library_call=*/nullptr);
|
||||
auto &fusedRegion = fusedOp->getRegion(0);
|
||||
rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion,
|
||||
fusedRegion.begin());
|
||||
rewriter.replaceOp(reshapeOp, fusedOp->getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// Methods and patterns that fuse reshape ops with elementwise operations by
|
||||
// expanding the dimensionality of the elementwise operations.
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
/// Conditions for folding a generic operation with a reshape op by expanding
|
||||
/// the iteration space dimensionality for tensor operations. These are
|
||||
/// preconditions assumed by `foldReshapeByDimExpansion` which implements the
|
||||
|
@ -612,9 +819,9 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
|
|||
/// Note that this could be extended to handle dynamic case, but the
|
||||
/// implementation below uses `affine.apply` which seems to have issues when the
|
||||
/// shapes are not static.
|
||||
LogicalResult isGenericOpExpandable(GenericOp genericOp,
|
||||
const ExpansionInfo &expansionInfo,
|
||||
PatternRewriter &rewriter) {
|
||||
static LogicalResult isGenericOpExpandable(GenericOp genericOp,
|
||||
const ExpansionInfo &expansionInfo,
|
||||
PatternRewriter &rewriter) {
|
||||
if (!genericOp.hasIndexSemantics())
|
||||
return success();
|
||||
for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
|
||||
|
@ -863,88 +1070,85 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
|
|||
|
||||
namespace {
|
||||
|
||||
/// Pattern to fold tensor_expand_shape op with its consumer by using the source
|
||||
/// of the reshape op as the operand in the consumer (instead of the result of
|
||||
/// the tensor_collapse_shape). The corresponding index map in the consumer
|
||||
/// needs to be modified to linearize the folded dimension.
|
||||
///
|
||||
/// For example,
|
||||
///
|
||||
/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||
/// %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]]
|
||||
/// tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
|
||||
/// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
|
||||
/// ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
|
||||
/// -> tensor<?x?x4x?xf32>
|
||||
///
|
||||
/// can be folded into
|
||||
///
|
||||
/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
|
||||
/// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||
/// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
|
||||
/// ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
|
||||
/// -> tensor<?x?x4x?xf32>
|
||||
template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
|
||||
struct FoldProducerReshapeOpByLinearization
|
||||
/// Pattern to fuse a tensor_collapse_shape op with its consumer generic op,
|
||||
/// when the reshape op is collapsing dimensions. The dimensionality of the loop
|
||||
/// in the consumer is expanded.
|
||||
class FoldWithProducerReshapeOpByExpansion
|
||||
: public OpRewritePattern<GenericOp> {
|
||||
using OpRewritePattern<GenericOp>::OpRewritePattern;
|
||||
public:
|
||||
FoldWithProducerReshapeOpByExpansion(
|
||||
MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
|
||||
PatternBenefit benefit = 1)
|
||||
: OpRewritePattern<GenericOp>(context, benefit),
|
||||
controlFoldingReshapes(std::move(foldReshapes)) {}
|
||||
|
||||
LogicalResult matchAndRewrite(GenericOp genericOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (!genericOp.hasTensorSemantics())
|
||||
return failure();
|
||||
SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
|
||||
for (const auto &en : llvm::enumerate(inputOperands)) {
|
||||
auto reshapeOp = en.value()->get().getDefiningOp<TensorReshapeOp>();
|
||||
for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
|
||||
tensor::CollapseShapeOp reshapeOp =
|
||||
opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
|
||||
if (!reshapeOp)
|
||||
continue;
|
||||
|
||||
if (!isTensorReshapeOpFoldableByLinearization(
|
||||
reshapeOp, genericOp.getTiedIndexingMap(en.value()),
|
||||
/*asProducer =*/true) ||
|
||||
(foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
|
||||
// Fold only if
|
||||
// - The tensor reshape op is folding.
|
||||
// - All constraints of fusing with reshape by expansion are met.
|
||||
if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
|
||||
(!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
|
||||
continue;
|
||||
|
||||
// Compute the fused operands list,
|
||||
SmallVector<Value> fusedOperands = genericOp.getInputOperands();
|
||||
fusedOperands[en.index()] = reshapeOp.src();
|
||||
SmallVector<Value> outputOperands = genericOp.getOutputOperands();
|
||||
llvm::append_range(fusedOperands, outputOperands);
|
||||
|
||||
// Compute indexing_maps for the fused operation. The indexing_maps for
|
||||
// the operands of the consumers that aren't fused are the same.
|
||||
SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps();
|
||||
|
||||
// Compute the indexing map to use for the result of the producer.
|
||||
AffineMap modifiedMap =
|
||||
linearizeCollapsedDims(fusedIndexMaps[en.index()], reshapeOp);
|
||||
// The modified map cannot have symbols.
|
||||
if (modifiedMap.getNumSymbols())
|
||||
Optional<SmallVector<Value>> replacementValues =
|
||||
fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
|
||||
if (!replacementValues)
|
||||
return failure();
|
||||
for (AffineExpr expr : modifiedMap.getResults()) {
|
||||
if (!expr.isPureAffine())
|
||||
return failure();
|
||||
}
|
||||
fusedIndexMaps[en.index()] = modifiedMap;
|
||||
|
||||
// Further check that the resulting index maps can be fused and
|
||||
// inverted. Without this the resultant op is not legal.
|
||||
if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
genericOp, "fused op loop bound computation failed");
|
||||
}
|
||||
|
||||
rewriter.startRootUpdate(genericOp);
|
||||
genericOp->setOperands(fusedOperands);
|
||||
genericOp.indexing_mapsAttr(
|
||||
rewriter.getAffineMapArrayAttr(fusedIndexMaps));
|
||||
rewriter.finalizeRootUpdate(genericOp);
|
||||
rewriter.replaceOp(genericOp, replacementValues.getValue());
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
private:
|
||||
ControlElementwiseOpsFusionFn controlFoldingReshapes;
|
||||
};
|
||||
|
||||
/// Pattern to fold a tensor_expand_shape op with its producer generic op
|
||||
/// by expanding the dimensionality of the loop in the producer op.
|
||||
struct FoldReshapeWithGenericOpByExpansion
|
||||
: public OpRewritePattern<tensor::ExpandShapeOp> {
|
||||
|
||||
FoldReshapeWithGenericOpByExpansion(
|
||||
MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
|
||||
PatternBenefit benefit = 1)
|
||||
: OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
|
||||
controlFoldingReshapes(std::move(foldReshapes)) {}
|
||||
|
||||
LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Fold only if all constraints of fusing with reshape by expansion are met.
|
||||
GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
|
||||
if (!producer || producer.getNumOutputs() != 1 ||
|
||||
!isFusableWithReshapeByDimExpansion(producer,
|
||||
producer.getOutputOperand(0)) ||
|
||||
!controlFoldingReshapes(producer->getResult(0),
|
||||
reshapeOp->getOpOperand(0)))
|
||||
return failure();
|
||||
Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
|
||||
producer, reshapeOp, producer.getOutputOperand(0), rewriter);
|
||||
if (!replacementValues)
|
||||
return failure();
|
||||
rewriter.replaceOp(reshapeOp, replacementValues.getValue());
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
ControlElementwiseOpsFusionFn controlFoldingReshapes;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// Methods and patterns to convert tensor.expand_shape -> linalg.generic
|
||||
// into linalg.generic -> tensor.expand_shape, i.e. push the reshape down.
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
static SmallVector<ReassociationIndices>
|
||||
getReassociationIndices(ArrayRef<AffineMap> maps) {
|
||||
SmallVector<ReassociationIndices> reassociation;
|
||||
|
@ -959,6 +1163,7 @@ getReassociationIndices(ArrayRef<AffineMap> maps) {
|
|||
return reassociation;
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// Pattern to move rank reducing reshape after an elementwise linalg generic
|
||||
/// op. This is useful to expose more fusion opportunities between named ops and
|
||||
/// generic ops. This can only be done if there is no broadcast or permutation
|
||||
|
@ -1100,142 +1305,13 @@ struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
|
|||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
/// Pattern to fuse a tensor_collapse_shape op with its consumer generic op,
|
||||
/// when the reshape op is collapsing dimensions. The dimensionality of the loop
|
||||
/// in the consumer is expanded.
|
||||
class FoldWithProducerReshapeOpByExpansion
|
||||
: public OpRewritePattern<GenericOp> {
|
||||
public:
|
||||
FoldWithProducerReshapeOpByExpansion(
|
||||
MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
|
||||
PatternBenefit benefit = 1)
|
||||
: OpRewritePattern<GenericOp>(context, benefit),
|
||||
controlFoldingReshapes(std::move(foldReshapes)) {}
|
||||
|
||||
LogicalResult matchAndRewrite(GenericOp genericOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
|
||||
tensor::CollapseShapeOp reshapeOp =
|
||||
opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
|
||||
if (!reshapeOp)
|
||||
continue;
|
||||
// Fold only if
|
||||
// - The tensor reshape op is folding.
|
||||
// - All constraints of fusing with reshape by expansion are met.
|
||||
if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
|
||||
(!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
|
||||
continue;
|
||||
|
||||
Optional<SmallVector<Value>> replacementValues =
|
||||
fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
|
||||
if (!replacementValues)
|
||||
return failure();
|
||||
rewriter.replaceOp(genericOp, replacementValues.getValue());
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
private:
|
||||
ControlElementwiseOpsFusionFn controlFoldingReshapes;
|
||||
};
|
||||
|
||||
/// Pattern to fold tensor_collapse_shape or tensor_expand_shape op with its
|
||||
/// producer. The corresponding index map in the consumer needs to be modified
|
||||
/// to linearize the folded dimension.
|
||||
template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
|
||||
struct FoldConsumerReshapeOpByLinearization
|
||||
: public OpRewritePattern<TensorReshapeOp> {
|
||||
using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
GenericOp producer = reshapeOp.src().template getDefiningOp<GenericOp>();
|
||||
if (!producer || !producer.hasTensorSemantics() ||
|
||||
producer.getNumOutputs() != 1 ||
|
||||
!isTensorReshapeOpFoldableByLinearization(
|
||||
reshapeOp,
|
||||
producer.getTiedIndexingMap(producer.getOutputOperand(0)),
|
||||
/*asProducer =*/false) ||
|
||||
(foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
|
||||
return failure();
|
||||
// The indexing_maps for the operands of the fused operation are same as
|
||||
// those for the operands of the producer.
|
||||
SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps();
|
||||
|
||||
// Compute the indexing map to use for the operand of the producer.
|
||||
AffineMap modifiedMap = linearizeCollapsedDims(
|
||||
producer.getTiedIndexingMap(producer.getOutputOperand(0)), reshapeOp);
|
||||
for (AffineExpr expr : modifiedMap.getResults()) {
|
||||
if (!expr.isPureAffine()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
producer, "fused op indexing map is not affine");
|
||||
}
|
||||
}
|
||||
fusedIndexMaps.back() = modifiedMap;
|
||||
|
||||
// Further check that the resulting index maps can be fused and
|
||||
// inverted. Without this the resultant op is not legal.
|
||||
if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
producer, "fused op loop bound computation failed");
|
||||
}
|
||||
|
||||
Location loc = producer.getLoc();
|
||||
SmallVector<Value> inputOperands = producer.getInputOperands();
|
||||
Value output = rewriter.create<TensorReshapeOp>(
|
||||
loc, producer.getOutputOperand(0)->get(),
|
||||
reshapeOp.getReassociationExprs());
|
||||
auto fusedOp = rewriter.create<GenericOp>(
|
||||
loc, reshapeOp.getResultType(),
|
||||
/*inputs=*/inputOperands,
|
||||
// TODO: handle outputs.
|
||||
/*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
|
||||
producer.iterator_types(),
|
||||
/*doc=*/nullptr,
|
||||
/*library_call=*/nullptr);
|
||||
auto &fusedRegion = fusedOp->getRegion(0);
|
||||
rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion,
|
||||
fusedRegion.begin());
|
||||
rewriter.replaceOp(reshapeOp, fusedOp->getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Pattern to fold a tensor_expand_shape op with its producer generic op
|
||||
/// by expanding the dimensionality of the loop in the producer op.
|
||||
struct FoldReshapeWithGenericOpByExpansion
|
||||
: public OpRewritePattern<tensor::ExpandShapeOp> {
|
||||
|
||||
FoldReshapeWithGenericOpByExpansion(
|
||||
MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
|
||||
PatternBenefit benefit = 1)
|
||||
: OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
|
||||
controlFoldingReshapes(std::move(foldReshapes)) {}
|
||||
|
||||
LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Fold only if all constraints of fusing with reshape by expansion are met.
|
||||
GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
|
||||
if (!producer || producer.getNumOutputs() != 1 ||
|
||||
!isFusableWithReshapeByDimExpansion(producer,
|
||||
producer.getOutputOperand(0)) ||
|
||||
!controlFoldingReshapes(producer->getResult(0),
|
||||
reshapeOp->getOpOperand(0)))
|
||||
return failure();
|
||||
Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
|
||||
producer, reshapeOp, producer.getOutputOperand(0), rewriter);
|
||||
if (!replacementValues)
|
||||
return failure();
|
||||
rewriter.replaceOp(reshapeOp, replacementValues.getValue());
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
ControlElementwiseOpsFusionFn controlFoldingReshapes;
|
||||
};
|
||||
//===---------------------------------------------------------------------===//
|
||||
// Methods and patterns that fuse constants with linalg.generic operations.
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
/// Pattern to fold a generic op with a splat constant/scalar constant. Does not
|
||||
/// handle cases where the constant is not single-valued.
|
||||
class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
|
||||
|
@ -1624,98 +1700,11 @@ struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
|
|||
|
||||
} // namespace
|
||||
|
||||
static Optional<SmallVector<Value>>
|
||||
fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
|
||||
GenericOp producer,
|
||||
const ControlElementwiseOpsFusionFn &controlFn) {
|
||||
if (producer->getNumResults() != 1)
|
||||
return llvm::None;
|
||||
|
||||
return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
|
||||
rewriter);
|
||||
}
|
||||
|
||||
bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
|
||||
OpOperand &consumer) {
|
||||
if (auto producerCollapseOp =
|
||||
dyn_cast<tensor::CollapseShapeOp>(producer.getOwner())) {
|
||||
return !isUnitDimExpansionOnly(producerCollapseOp);
|
||||
}
|
||||
if (auto consumerExpandOp =
|
||||
dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
|
||||
return !isUnitDimExpansionOnly(consumerExpandOp);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
//===---------------------------------------------------------------------===//
|
||||
// Miscellaneous patterns that help fusion.
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
/// Patterns to fuse a generic op, with the producer of its operands.
|
||||
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
|
||||
public:
|
||||
FuseElementwiseOps(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
|
||||
PatternBenefit benefit = 1)
|
||||
: OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
|
||||
|
||||
LogicalResult matchAndRewrite(GenericOp genericOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Find the first operand that is defined by another generic op on tensors.
|
||||
for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
|
||||
auto producer =
|
||||
dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
|
||||
if (!producer || !producer.hasTensorSemantics())
|
||||
continue;
|
||||
Optional<SmallVector<Value>> fusedOpResults =
|
||||
fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
|
||||
if (fusedOpResults) {
|
||||
rewriter.replaceOp(genericOp, *fusedOpResults);
|
||||
return success();
|
||||
}
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
private:
|
||||
ControlElementwiseOpsFusionFn controlFn;
|
||||
};
|
||||
|
||||
/// Pass that fuses generic ops on tensors. Used only for testing.
|
||||
struct LinalgElementwiseOpFusionPass
|
||||
: public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> {
|
||||
void runOnOperation() override {
|
||||
Operation *op = getOperation();
|
||||
RewritePatternSet patterns(op->getContext());
|
||||
ControlElementwiseOpsFusionFn allowFoldingFn =
|
||||
[](const OpResult &producer, const OpOperand &consumer) {
|
||||
return true;
|
||||
};
|
||||
populateElementwiseOpsFusionPatterns(
|
||||
patterns,
|
||||
LinalgElementwiseFusionOptions().setControlFoldingReshapes(
|
||||
allowFoldingUnitDimReshapes ? allowFoldingFn : skipUnitDimReshape));
|
||||
|
||||
// Use TopDownTraversal for compile time reasons
|
||||
GreedyRewriteConfig grc;
|
||||
grc.useTopDownTraversal = true;
|
||||
(void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
|
||||
grc);
|
||||
}
|
||||
};
|
||||
|
||||
/// Pass to test folding of reshape ops with generic ops by linearization.
|
||||
struct FoldReshapeOpsByLinearizationPass
|
||||
: public LinalgFoldReshapeOpsByLinearizationBase<
|
||||
FoldReshapeOpsByLinearizationPass> {
|
||||
void runOnOperation() override {
|
||||
Operation *op = getOperation();
|
||||
RewritePatternSet patterns(op->getContext());
|
||||
populateFoldReshapeOpsByLinearizationPatterns(patterns);
|
||||
if (allowFoldingUnitDimReshapes) {
|
||||
populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
|
||||
}
|
||||
(void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
||||
/// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if
|
||||
/// the value of the `outs` operand is not used within the op. This is only
|
||||
/// implemented for `linalg.generic` operations for now, but should hold for all
|
||||
|
@ -1761,9 +1750,12 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
|
|||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// Methods that add patterns described in this file to a pattern list.
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
patterns
|
||||
|
@ -1815,6 +1807,65 @@ void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
|
|||
patterns.add<PushExpandingReshape>(context);
|
||||
}
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// Passes
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
|
||||
OpOperand &consumer) {
|
||||
if (auto producerCollapseOp =
|
||||
dyn_cast<tensor::CollapseShapeOp>(producer.getOwner())) {
|
||||
return !isUnitDimExpansionOnly(producerCollapseOp);
|
||||
}
|
||||
if (auto consumerExpandOp =
|
||||
dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
|
||||
return !isUnitDimExpansionOnly(consumerExpandOp);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
/// Pass that fuses generic ops on tensors. Used only for testing.
|
||||
struct LinalgElementwiseOpFusionPass
|
||||
: public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> {
|
||||
void runOnOperation() override {
|
||||
Operation *op = getOperation();
|
||||
RewritePatternSet patterns(op->getContext());
|
||||
ControlElementwiseOpsFusionFn allowFoldingFn =
|
||||
[](const OpResult &producer, const OpOperand &consumer) {
|
||||
return true;
|
||||
};
|
||||
populateElementwiseOpsFusionPatterns(
|
||||
patterns,
|
||||
LinalgElementwiseFusionOptions().setControlFoldingReshapes(
|
||||
allowFoldingUnitDimReshapes ? allowFoldingFn : skipUnitDimReshape));
|
||||
|
||||
// Use TopDownTraversal for compile time reasons
|
||||
GreedyRewriteConfig grc;
|
||||
grc.useTopDownTraversal = true;
|
||||
(void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
|
||||
grc);
|
||||
}
|
||||
};
|
||||
|
||||
/// Pass to test folding of reshape ops with generic ops by linearization.
|
||||
struct FoldReshapeOpsByLinearizationPass
|
||||
: public LinalgFoldReshapeOpsByLinearizationBase<
|
||||
FoldReshapeOpsByLinearizationPass> {
|
||||
void runOnOperation() override {
|
||||
Operation *op = getOperation();
|
||||
RewritePatternSet patterns(op->getContext());
|
||||
populateFoldReshapeOpsByLinearizationPatterns(patterns);
|
||||
if (allowFoldingUnitDimReshapes) {
|
||||
populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
|
||||
}
|
||||
(void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() {
|
||||
return std::make_unique<LinalgElementwiseOpFusionPass>();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue