From 76d72c941df19779a7abfdfda66d3f452cef1db8 Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Mon, 29 Jun 2020 08:17:00 +0000 Subject: [PATCH] [MLIR][Shape] Lower `shape.get_extent` to `std.dim` when possible When the shape is derived from a tensor argument the shape extent can be derived directly from that tensor with `std.dim`. This lowering pattern circumvents the necessity to materialize the shape in memory. Differential Revision: https://reviews.llvm.org/D82644 --- .../ShapeToStandard/ShapeToStandardPatterns.td | 7 +++++++ .../ShapeToStandard/shape-to-standard.mlir | 16 ++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td index a1335487f5ab..473da36a84ec 100644 --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td @@ -19,3 +19,10 @@ def SizeToIndexOpConversion : Pat< (Shape_SizeToIndexOp $arg), (replaceWithValue $arg)>; +// Derive shape extent directly from shape origin if possible. +// This circumvents the necessity to materialize the shape in memory. +def GetExtentShapeOfConversion : Pat< + (Shape_GetExtentOp (Shape_ShapeOfOp $arg), $idx), + (Shape_IndexToSizeOp (DimOp $arg, (Shape_SizeToIndexOp $idx))), + [], + (addBenefit 10)>; diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir index a9b4bf701909..f9daadd03196 100644 --- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir +++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir @@ -127,3 +127,19 @@ func @rank(%shape : !shape.shape) -> !shape.size { %rank = shape.rank %shape return %rank : !shape.size } + +// ----- + +// Express `get_extent` as `std.dim` when it relies directly on the outcome of a +// `shape_of` operation. +// CHECK-LABEL: @get_extent_shape_of +// CHECK-SAME: (%[[ARG:.*]]: tensor<2x3xf32>, %[[IDX:.*]]: index) -> index +func @get_extent_shape_of(%arg : tensor<2x3xf32>, %idx : !shape.size) + -> !shape.size { + // CHECK: %[[RESULT:.*]] = dim %[[ARG]], %[[IDX]] : tensor<2x3xf32> + // CHECK: return %[[RESULT]] : index + %shape = shape.shape_of %arg : tensor<2x3xf32> + %result = shape.get_extent %shape, %idx + return %result : !shape.size +} +