[mlir] Don't allow dynamic extent tensor types for ConstShapeOp.

ConstShapeOp has a constant shape, so its result type can always be static.
We still allow it to have ShapeType, though.

Differential Revision: https://reviews.llvm.org/D111139
Adrian Kuegel 2021-10-05 14:37:30 +02:00
parent 14cb138b15
commit 2bb208ddfd
4 changed files with 54 additions and 55 deletions
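
To illustrate the new rule (a hand-written sketch, not part of the commit): the extents of a shape.const_shape are known when the op is built, so a static extent tensor type is always available, and dynamic extent tensor result types are now rejected by the verifier.

// Valid: static extent tensor type whose size matches the three extents.
%0 = shape.const_shape [1, 2, 3] : tensor<3xindex>
// Also valid: the opaque shape type.
%1 = shape.const_shape [1, 2, 3] : !shape.shape
// Rejected after this change: dynamic extent tensor type.
// %2 = shape.const_shape [1, 2, 3] : tensor<?xindex>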

mlir/lib/Dialect/Shape/IR/Shape.cpp

@@ -827,14 +827,10 @@ bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
 
   Type lhs = l.front();
   Type rhs = r.front();
-  if (lhs == rhs)
-    return true;
 
   if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
     // Shape type is compatible with all other valid return types.
     return true;
-
-  return succeeded(verifyCompatibleShapes(lhs, rhs));
+  return lhs == rhs;
 }
 
 //===----------------------------------------------------------------------===//
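
The simplified predicate reads: a !shape.shape type is compatible with every valid return type, and otherwise the two types must match exactly; verifyCompatibleShapes no longer papers over a dynamic/static extent tensor mismatch. A sketch of the resulting decision table, with invented call notation for illustration:

// isCompatibleReturnTypes({!shape.shape},    {tensor<3xindex>}) -> true
// isCompatibleReturnTypes({tensor<3xindex>}, {tensor<3xindex>}) -> true
// isCompatibleReturnTypes({tensor<?xindex>}, {tensor<3xindex>}) -> false (was true before)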

mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td

@@ -12,6 +12,10 @@ def HasSingleElement : Constraint<CPred< [{
   $0.size() == 1
 }]>>;
 
+def HasStaticShape : Constraint<CPred< [{
+  $0.getType().dyn_cast<ShapedType>().hasStaticShape()
+}]>>;
+
 // Canonicalization patterns.
 def AssumingAllOneOp : Pat<(Shape_AssumingAllOp $args),
@@ -37,4 +41,5 @@ def SizeToIndexToSizeCanonicalization : Pat<
 // Fold tensor.cast(const_shape) to const_shape. This changes the type of
 // const_shape to the destination type of the cast.
 def TensorCastConstShape : Pat <
-  (Tensor_CastOp (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg)>;
+  (Tensor_CastOp:$res (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg),
+  [(HasStaticShape $res)]>;
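
Sketched effect of the new HasStaticShape guard (a hand-written example mirroring the tests below): the tensor.cast folds into the const_shape only when the cast's result type is static, because the rewritten const_shape takes on that type.

// Folds away: the destination type tensor<1xindex> is static.
%0 = shape.const_shape [2] : tensor<1xindex>
%1 = tensor.cast %0 : tensor<1xindex> to tensor<1xindex>
// Kept: folding would produce a const_shape with the now-illegal
// dynamic type tensor<?xindex>.
%2 = tensor.cast %0 : tensor<1xindex> to tensor<?xindex>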

mlir/test/Dialect/Shape/canonicalize.mlir

@@ -1,10 +1,10 @@
 // RUN: mlir-opt -split-input-file -allow-unregistered-dialect -canonicalize %s | FileCheck %s
 
 // CHECK-LABEL: func @f
-func @f(%arg0: tensor<2x3x4xf32>) -> tensor<?xindex> {
-  // CHECK: shape.const_shape [2, 3, 4] : tensor<?xindex>
-  %0 = shape.shape_of %arg0 : tensor<2x3x4xf32> -> tensor<?xindex>
-  return %0 : tensor<?xindex>
+func @f(%arg0: tensor<2x3x4xf32>) -> tensor<3xindex> {
+  // CHECK: shape.const_shape [2, 3, 4] : tensor<3xindex>
+  %0 = shape.shape_of %arg0 : tensor<2x3x4xf32> -> tensor<3xindex>
+  return %0 : tensor<3xindex>
 }
 
 // -----
@@ -62,13 +62,13 @@ func @f() -> !shape.shape {
 
 // Basic case including extent tensors.
 // CHECK-LABEL: @broadcast
-func @broadcast() -> tensor<?xindex> {
-  // CHECK: shape.const_shape [7, 2] : tensor<?xindex>
-  %0 = shape.const_shape [1, 2] : tensor<?xindex>
-  %1 = shape.const_shape [7, 1] : tensor<?xindex>
+func @broadcast() -> tensor<2xindex> {
+  // CHECK: shape.const_shape [7, 2] : tensor<2xindex>
+  %0 = shape.const_shape [1, 2] : tensor<2xindex>
+  %1 = shape.const_shape [7, 1] : tensor<2xindex>
   %2 = shape.broadcast %0, %1
-      : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
-  return %2 : tensor<?xindex>
+      : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex>
+  return %2 : tensor<2xindex>
 }
 
 // -----
@@ -77,9 +77,9 @@ func @broadcast() -> tensor<?xindex> {
 // CHECK-LABEL: @broadcast
 func @broadcast() -> !shape.shape {
   // CHECK: shape.const_shape [7, 2] : !shape.shape
-  %0 = shape.const_shape [1, 2] : tensor<?xindex>
-  %1 = shape.const_shape [7, 1] : tensor<?xindex>
-  %2 = shape.broadcast %0, %1 : tensor<?xindex>, tensor<?xindex> -> !shape.shape
+  %0 = shape.const_shape [1, 2] : tensor<2xindex>
+  %1 = shape.const_shape [7, 1] : tensor<2xindex>
+  %2 = shape.broadcast %0, %1 : tensor<2xindex>, tensor<2xindex> -> !shape.shape
   return %2 : !shape.shape
 }
@@ -317,9 +317,9 @@ func @nonfoldable_num_elements(%shape : !shape.shape) -> !shape.size {
 
 // CHECK-LABEL: func @basic
 func @basic() -> index {
   // CHECK: constant 2 : index
-  %0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
+  %0 = shape.const_shape [0, 1, 2] : tensor<3xindex>
   %c2 = constant 2 : index
-  %1 = shape.get_extent %0, %c2 : tensor<?xindex>, index -> index
+  %1 = shape.get_extent %0, %c2 : tensor<3xindex>, index -> index
   return %1 : index
 }
@@ -330,9 +330,9 @@ func @basic() -> index {
 func @out_of_bounds() -> index {
   // CHECK: shape.const_shape
   // CHECK: shape.get_extent
-  %0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
+  %0 = shape.const_shape [0, 1, 2] : tensor<3xindex>
   %c3 = constant 3 : index
-  %1 = shape.get_extent %0, %c3 : tensor<?xindex>, index -> index
+  %1 = shape.get_extent %0, %c3 : tensor<3xindex>, index -> index
   return %1 : index
 }
@@ -559,12 +559,12 @@ func @f(%arg : !shape.shape) -> !shape.shape {
 
 // any can be replaced with a constant input if it has one.
 // CHECK-LABEL: func @f
-func @f(%arg : tensor<?xindex>) -> tensor<?xindex> {
-  // CHECK-NEXT: %[[CS:.*]] = shape.const_shape [2, 3, 4] : tensor<?xindex>
-  // CHECK-NEXT: return %[[CS]] : tensor<?xindex>
-  %0 = shape.const_shape [2, 3, 4] : tensor<?xindex>
-  %1 = shape.any %0, %arg : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
-  return %1 : tensor<?xindex>
+func @f(%arg : tensor<?xindex>) -> tensor<3xindex> {
+  // CHECK-NEXT: %[[CS:.*]] = shape.const_shape [2, 3, 4] : tensor<3xindex>
+  // CHECK-NEXT: return %[[CS]] : tensor<3xindex>
+  %0 = shape.const_shape [2, 3, 4] : tensor<3xindex>
+  %1 = shape.any %0, %arg : tensor<3xindex>, tensor<?xindex> -> tensor<3xindex>
+  return %1 : tensor<3xindex>
 }
 
 // -----
@@ -837,8 +837,8 @@ func @dont_fold_rank(%shape : !shape.shape) -> !shape.size {
 func @fold_rank() -> index {
   // CHECK: %[[RESULT:.*]] = constant 5 : index
   // CHECK: return %[[RESULT]] : index
-  %shape = shape.const_shape [3, 4, 5, 6, 7] : tensor<?xindex>
-  %rank = shape.rank %shape : tensor<?xindex> -> index
+  %shape = shape.const_shape [3, 4, 5, 6, 7] : tensor<5xindex>
+  %rank = shape.rank %shape : tensor<5xindex> -> index
   return %rank : index
 }
@@ -971,9 +971,9 @@ func @shape_eq_fold_1() -> i1 {
   // CHECK: %[[RESULT:.*]] = constant true
   // CHECK: return %[[RESULT]] : i1
   %a = shape.const_shape [1, 2, 3] : !shape.shape
-  %b = shape.const_shape [1, 2, 3] : tensor<?xindex>
-  %c = shape.const_shape [1, 2, 3] : tensor<?xindex>
-  %result = shape.shape_eq %a, %b, %c : !shape.shape, tensor<?xindex>, tensor<?xindex>
+  %b = shape.const_shape [1, 2, 3] : tensor<3xindex>
+  %c = shape.const_shape [1, 2, 3] : tensor<3xindex>
+  %result = shape.shape_eq %a, %b, %c : !shape.shape, tensor<3xindex>, tensor<3xindex>
   return %result : i1
 }
@@ -984,10 +984,10 @@ func @shape_eq_fold_1() -> i1 {
 func @shape_eq_fold_0() -> i1 {
   // CHECK: %[[RESULT:.*]] = constant false
   // CHECK: return %[[RESULT]] : i1
-  %a = shape.const_shape [1, 2, 3] : tensor<?xindex>
-  %b = shape.const_shape [4, 5, 6] : tensor<?xindex>
-  %c = shape.const_shape [4, 5, 6] : tensor<?xindex>
-  %result = shape.shape_eq %a, %b, %c : tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
+  %a = shape.const_shape [1, 2, 3] : tensor<3xindex>
+  %b = shape.const_shape [4, 5, 6] : tensor<3xindex>
+  %c = shape.const_shape [4, 5, 6] : tensor<3xindex>
+  %result = shape.shape_eq %a, %b, %c : tensor<3xindex>, tensor<3xindex>, tensor<3xindex>
   return %result : i1
 }
@@ -1161,18 +1161,17 @@ func @fold_assuming_all_single_element(%arg: tensor<?xindex>) {
 func @fold_tensor.cast_of_const_shape_returned(%arg: i1) -> tensor<1xindex> {
   // CHECK: shape.const_shape [2] : tensor<1xindex>
   // CHECK-NOT: tensor.cast
-  %0 = shape.const_shape [2] : tensor<?xindex>
-  %1 = tensor.cast %0 : tensor<?xindex> to tensor<1xindex>
+  %0 = shape.const_shape [2] : tensor<1xindex>
+  %1 = tensor.cast %0 : tensor<1xindex> to tensor<1xindex>
   return %1 : tensor<1xindex>
 }
 
 // -----
 
 // Verify that tensor.cast folding uses the correct type
-// CHECK-LABEL: @fold_tensor.cast_of_const_shape_returned_dynamic
-func @fold_tensor.cast_of_const_shape_returned_dynamic(%arg: i1) -> tensor<?xindex> {
-  // CHECK: shape.const_shape [2] : tensor<?xindex>
-  // CHECK-NOT: tensor.cast
+// CHECK-LABEL: @dont_fold_tensor.cast_of_const_shape_returned_dynamic
+func @dont_fold_tensor.cast_of_const_shape_returned_dynamic(%arg: i1) -> tensor<?xindex> {
+  // CHECK: %[[CONST_SHAPE:.*]] = shape.const_shape [2] : tensor<1xindex>
+  // CHECK: tensor.cast %[[CONST_SHAPE]] : tensor<1xindex> to tensor<?xindex>
   %0 = shape.const_shape [2] : tensor<1xindex>
   %1 = tensor.cast %0 : tensor<1xindex> to tensor<?xindex>
   return %1 : tensor<?xindex>

mlir/test/Dialect/Shape/ops.mlir

@@ -35,7 +35,6 @@ func @test_shape_num_elements_unknown() {
 func @const_shape() {
   %0 = shape.const_shape [1, 2, 3] : !shape.shape
-  %1 = shape.const_shape [4, 5, 6] : tensor<?xindex>
   %2 = shape.const_shape [4, 5, 6] : tensor<3xindex>
   return
 }
@@ -55,11 +54,11 @@ func @test_broadcast_fixed() {
   return
 }
 
-func @test_broadcast_extents() -> tensor<?xindex> {
-  %0 = shape.const_shape [10, 1, 57, 92] : tensor<?xindex>
-  %1 = shape.const_shape [4, 57, 92] : tensor<?xindex>
-  %2 = shape.broadcast %0, %1 : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
-  return %2 : tensor<?xindex>
+func @test_broadcast_extents() -> tensor<4xindex> {
+  %0 = shape.const_shape [10, 1, 57, 92] : tensor<4xindex>
+  %1 = shape.const_shape [4, 57, 92] : tensor<3xindex>
+  %2 = shape.broadcast %0, %1 : tensor<4xindex>, tensor<3xindex> -> tensor<4xindex>
+  return %2 : tensor<4xindex>
 }
 
 func @test_shape_any_fixed() {
@@ -89,7 +88,7 @@ func @test_shape_any_fixed_mismatch() {
 func @test_parse_const_shape() {
   %0 = shape.const_shape [] : !shape.shape
   %1 = shape.const_shape [1, 2, 3] : !shape.shape
-  %2 = shape.const_shape [1, 2, 3] : tensor<?xindex>
+  %2 = shape.const_shape [1, 2, 3] : tensor<3xindex>
   return
 }
@@ -222,9 +221,9 @@ func @any() {
   %0 = shape.const_shape [1, 2, 3] : !shape.shape
   %1 = shape.const_shape [4, 5, 6] : !shape.shape
   %2 = "shape.any"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
-  %3 = shape.const_shape [1, 2, 3] : tensor<?xindex>
-  %4 = shape.const_shape [4, 5, 6] : tensor<?xindex>
-  %5 = "shape.any"(%3, %4) : (tensor<?xindex>, tensor<?xindex>) -> tensor<?xindex>
+  %3 = shape.const_shape [1, 2, 3] : tensor<3xindex>
+  %4 = shape.const_shape [4, 5, 6] : tensor<3xindex>
+  %5 = "shape.any"(%3, %4) : (tensor<3xindex>, tensor<3xindex>) -> tensor<3xindex>
   return
 }