[MLIR][Shape] Lower `shape.const_shape` to `tensor_from_elements`

Differential Revision: https://reviews.llvm.org/D82848
This commit is contained in:
Frederik Gossen 2020-07-28 15:39:49 +00:00
parent a4edc04693
commit dfcc09890a
2 changed files with 50 additions and 0 deletions

View File

@ -103,6 +103,39 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
return success();
}
namespace {
class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
public:
using OpConversionPattern<ConstShapeOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
} // namespace
LogicalResult ConstShapeOpConverter::matchAndRewrite(
ConstShapeOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
// For now, this lowering supports only extent tensors, not `shape.shape`
// types.
if (op.getType().isa<ShapeType>())
return failure();
auto loc = op.getLoc();
SmallVector<Value, 4> extentOperands;
for (auto extent : op.shape()) {
extentOperands.push_back(
rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue()));
}
Value tensor = rewriter.create<TensorFromElementsOp>(loc, extentOperands);
Type indexTy = rewriter.getIndexType();
Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy);
return success();
}
namespace {
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
using OpConversionPattern<GetExtentOp>::OpConversionPattern;
@ -209,6 +242,7 @@ void mlir::populateShapeToStandardConversionPatterns(
patterns.insert<
AnyOpConversion,
BinaryOpConversion<AddOp, AddIOp>,
ConstShapeOpConverter,
BinaryOpConversion<MulOp, MulIOp>,
GetExtentOpConverter,
RankOpConverter,

View File

@ -111,6 +111,22 @@ func @get_extent_from_extent_tensor(%extents : tensor<?xindex>, %idx : index)
// -----
// Lower `const_shape` to `tensor_from_elements`.
// CHECK-LABEL: @const_shape
// CHECK-SAME: () -> tensor<?xindex>
func @const_shape() -> tensor<?xindex> {
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[C3:.*]] = constant 3 : index
// CHECK: %[[TENSOR3:.*]] = tensor_from_elements(%[[C1]], %[[C2]], %[[C3]])
// CHECK: %[[RESULT:.*]] = tensor_cast %[[TENSOR3]] : tensor<3xindex> to tensor<?xindex>
// CHECK: return %[[RESULT]] : tensor<?xindex>
%shape = shape.const_shape [1, 2, 3] : tensor<?xindex>
return %shape : tensor<?xindex>
}
// -----
// Lower `any` to its first operand.
// CHECK-LABEL: @any_of_three
// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>) -> tensor<?xindex>