[MLIR][Shape] Fix lowering of `shape.get_extent`

The declarative conversion patterns caused crashes in the asan configuration.
The non-declarative implementation circumvents this.

Differential Revision: https://reviews.llvm.org/D82797
This commit is contained in:
Frederik Gossen 2020-06-30 08:33:49 +00:00
parent fe08ab542b
commit 8577a090f5
2 changed files with 24 additions and 17 deletions

View File

@ -90,6 +90,29 @@ public:
}
};
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
using OpConversionPattern<GetExtentOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
GetExtentOp::Adaptor transformed(operands);
// Derive shape extent directly from shape origin if possible.
// This circumvents the necessity to materialize the shape in memory.
if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(),
transformed.dim());
return success();
}
rewriter.replaceOpWithNewOp<ExtractElementOp>(
op, rewriter.getIndexType(), transformed.shape(),
ValueRange{transformed.dim()});
return success();
}
};
class RankOpConverter : public OpConversionPattern<shape::RankOp> {
public:
using OpConversionPattern<shape::RankOp>::OpConversionPattern;
@ -161,6 +184,7 @@ void mlir::populateShapeToStandardConversionPatterns(
BinaryOpConversion<AddOp, AddIOp>,
BinaryOpConversion<MulOp, MulIOp>,
ConstSizeOpConverter,
GetExtentOpConverter,
RankOpConverter,
ShapeOfOpConversion>(ctx);
// clang-format on

View File

@ -19,20 +19,3 @@ def SizeToIndexOpConversion : Pat<
(Shape_SizeToIndexOp $arg),
(replaceWithValue $arg)>;
// Derive shape extent directly from shape origin if possible.
// This circumvents the necessity to materialize the shape in memory.
def GetExtentShapeOfConversion : Pat<
(Shape_GetExtentOp (Shape_ShapeOfOp $arg), $idx),
(Shape_IndexToSizeOp (DimOp $arg, (Shape_SizeToIndexOp $idx))),
[],
(addBenefit 10)>;
def GetExtentFromExtentTensorConversion : Pattern<
(Shape_GetExtentOp (Shape_FromExtentTensorOp $extents), $idx),
[
(Shape_SizeToIndexOp:$std_idx $idx),
(ExtractElementOp:$std_result $extents, (NativeCodeCall<"ValueRange({$0})"> $std_idx)),
(Shape_IndexToSizeOp $std_result)
],
[],
(addBenefit 10)>;