[mlir] Don't allow dynamic extent tensor types for ConstShapeOp.
ConstShapeOp has a constant shape, so its type can always be static. We still allow it to have ShapeType though. Differential Revision: https://reviews.llvm.org/D111139
This commit is contained in:
parent
14cb138b15
commit
2bb208ddfd
|
@ -827,14 +827,10 @@ bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
|
|||
Type lhs = l.front();
|
||||
Type rhs = r.front();
|
||||
|
||||
if (lhs == rhs)
|
||||
return true;
|
||||
|
||||
if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
|
||||
// Shape type is compatible with all other valid return types.
|
||||
return true;
|
||||
|
||||
return succeeded(verifyCompatibleShapes(lhs, rhs));
|
||||
return lhs == rhs;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -12,6 +12,10 @@ def HasSingleElement : Constraint<CPred< [{
|
|||
$0.size() == 1
|
||||
}]>>;
|
||||
|
||||
def HasStaticShape : Constraint<CPred< [{
|
||||
$0.getType().dyn_cast<ShapedType>().hasStaticShape()
|
||||
}]>>;
|
||||
|
||||
// Canonicalization patterns.
|
||||
|
||||
def AssumingAllOneOp : Pat<(Shape_AssumingAllOp $args),
|
||||
|
@ -37,4 +41,5 @@ def SizeToIndexToSizeCanonicalization : Pat<
|
|||
// Fold tensor.cast(const_shape) to const_shape. This changes the type of
|
||||
// const_shape to the destination type of the cast.
|
||||
def TensorCastConstShape : Pat <
|
||||
(Tensor_CastOp (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg)>;
|
||||
(Tensor_CastOp:$res (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg),
|
||||
[(HasStaticShape $res)]>;
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -canonicalize %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @f
|
||||
func @f(%arg0: tensor<2x3x4xf32>) -> tensor<?xindex> {
|
||||
// CHECK: shape.const_shape [2, 3, 4] : tensor<?xindex>
|
||||
%0 = shape.shape_of %arg0 : tensor<2x3x4xf32> -> tensor<?xindex>
|
||||
return %0 : tensor<?xindex>
|
||||
func @f(%arg0: tensor<2x3x4xf32>) -> tensor<3xindex> {
|
||||
// CHECK: shape.const_shape [2, 3, 4] : tensor<3xindex>
|
||||
%0 = shape.shape_of %arg0 : tensor<2x3x4xf32> -> tensor<3xindex>
|
||||
return %0 : tensor<3xindex>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -62,13 +62,13 @@ func @f() -> !shape.shape {
|
|||
|
||||
// Basic case including extent tensors.
|
||||
// CHECK-LABEL: @broadcast
|
||||
func @broadcast() -> tensor<?xindex> {
|
||||
// CHECK: shape.const_shape [7, 2] : tensor<?xindex>
|
||||
%0 = shape.const_shape [1, 2] : tensor<?xindex>
|
||||
%1 = shape.const_shape [7, 1] : tensor<?xindex>
|
||||
func @broadcast() -> tensor<2xindex> {
|
||||
// CHECK: shape.const_shape [7, 2] : tensor<2xindex>
|
||||
%0 = shape.const_shape [1, 2] : tensor<2xindex>
|
||||
%1 = shape.const_shape [7, 1] : tensor<2xindex>
|
||||
%2 = shape.broadcast %0, %1
|
||||
: tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
|
||||
return %2 : tensor<?xindex>
|
||||
: tensor<2xindex>, tensor<2xindex> -> tensor<2xindex>
|
||||
return %2 : tensor<2xindex>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -77,9 +77,9 @@ func @broadcast() -> tensor<?xindex> {
|
|||
// CHECK-LABEL: @broadcast
|
||||
func @broadcast() -> !shape.shape {
|
||||
// CHECK: shape.const_shape [7, 2] : !shape.shape
|
||||
%0 = shape.const_shape [1, 2] : tensor<?xindex>
|
||||
%1 = shape.const_shape [7, 1] : tensor<?xindex>
|
||||
%2 = shape.broadcast %0, %1 : tensor<?xindex>, tensor<?xindex> -> !shape.shape
|
||||
%0 = shape.const_shape [1, 2] : tensor<2xindex>
|
||||
%1 = shape.const_shape [7, 1] : tensor<2xindex>
|
||||
%2 = shape.broadcast %0, %1 : tensor<2xindex>, tensor<2xindex> -> !shape.shape
|
||||
return %2 : !shape.shape
|
||||
}
|
||||
|
||||
|
@ -317,9 +317,9 @@ func @nonfoldable_num_elements(%shape : !shape.shape) -> !shape.size {
|
|||
// CHECK-LABEL: func @basic
|
||||
func @basic() -> index {
|
||||
// CHECK: constant 2 : index
|
||||
%0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
|
||||
%0 = shape.const_shape [0, 1, 2] : tensor<3xindex>
|
||||
%c2 = constant 2 : index
|
||||
%1 = shape.get_extent %0, %c2 : tensor<?xindex>, index -> index
|
||||
%1 = shape.get_extent %0, %c2 : tensor<3xindex>, index -> index
|
||||
return %1 : index
|
||||
}
|
||||
|
||||
|
@ -330,9 +330,9 @@ func @basic() -> index {
|
|||
func @out_of_bounds() -> index {
|
||||
// CHECK: shape.const_shape
|
||||
// CHECK: shape.get_extent
|
||||
%0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
|
||||
%0 = shape.const_shape [0, 1, 2] : tensor<3xindex>
|
||||
%c3 = constant 3 : index
|
||||
%1 = shape.get_extent %0, %c3 : tensor<?xindex>, index -> index
|
||||
%1 = shape.get_extent %0, %c3 : tensor<3xindex>, index -> index
|
||||
return %1 : index
|
||||
}
|
||||
|
||||
|
@ -559,12 +559,12 @@ func @f(%arg : !shape.shape) -> !shape.shape {
|
|||
|
||||
// any can be replaced with a constant input if it has one.
|
||||
// CHECK-LABEL: func @f
|
||||
func @f(%arg : tensor<?xindex>) -> tensor<?xindex> {
|
||||
// CHECK-NEXT: %[[CS:.*]] = shape.const_shape [2, 3, 4] : tensor<?xindex>
|
||||
// CHECK-NEXT: return %[[CS]] : tensor<?xindex>
|
||||
%0 = shape.const_shape [2, 3, 4] : tensor<?xindex>
|
||||
%1 = shape.any %0, %arg : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
|
||||
return %1 : tensor<?xindex>
|
||||
func @f(%arg : tensor<?xindex>) -> tensor<3xindex> {
|
||||
// CHECK-NEXT: %[[CS:.*]] = shape.const_shape [2, 3, 4] : tensor<3xindex>
|
||||
// CHECK-NEXT: return %[[CS]] : tensor<3xindex>
|
||||
%0 = shape.const_shape [2, 3, 4] : tensor<3xindex>
|
||||
%1 = shape.any %0, %arg : tensor<3xindex>, tensor<?xindex> -> tensor<3xindex>
|
||||
return %1 : tensor<3xindex>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -837,8 +837,8 @@ func @dont_fold_rank(%shape : !shape.shape) -> !shape.size {
|
|||
func @fold_rank() -> index {
|
||||
// CHECK: %[[RESULT:.*]] = constant 5 : index
|
||||
// CHECK: return %[[RESULT]] : index
|
||||
%shape = shape.const_shape [3, 4, 5, 6, 7] : tensor<?xindex>
|
||||
%rank = shape.rank %shape : tensor<?xindex> -> index
|
||||
%shape = shape.const_shape [3, 4, 5, 6, 7] : tensor<5xindex>
|
||||
%rank = shape.rank %shape : tensor<5xindex> -> index
|
||||
return %rank : index
|
||||
}
|
||||
|
||||
|
@ -971,9 +971,9 @@ func @shape_eq_fold_1() -> i1 {
|
|||
// CHECK: %[[RESULT:.*]] = constant true
|
||||
// CHECK: return %[[RESULT]] : i1
|
||||
%a = shape.const_shape [1, 2, 3] : !shape.shape
|
||||
%b = shape.const_shape [1, 2, 3] : tensor<?xindex>
|
||||
%c = shape.const_shape [1, 2, 3] : tensor<?xindex>
|
||||
%result = shape.shape_eq %a, %b, %c : !shape.shape, tensor<?xindex>, tensor<?xindex>
|
||||
%b = shape.const_shape [1, 2, 3] : tensor<3xindex>
|
||||
%c = shape.const_shape [1, 2, 3] : tensor<3xindex>
|
||||
%result = shape.shape_eq %a, %b, %c : !shape.shape, tensor<3xindex>, tensor<3xindex>
|
||||
return %result : i1
|
||||
}
|
||||
|
||||
|
@ -984,10 +984,10 @@ func @shape_eq_fold_1() -> i1 {
|
|||
func @shape_eq_fold_0() -> i1 {
|
||||
// CHECK: %[[RESULT:.*]] = constant false
|
||||
// CHECK: return %[[RESULT]] : i1
|
||||
%a = shape.const_shape [1, 2, 3] : tensor<?xindex>
|
||||
%b = shape.const_shape [4, 5, 6] : tensor<?xindex>
|
||||
%c = shape.const_shape [4, 5, 6] : tensor<?xindex>
|
||||
%result = shape.shape_eq %a, %b, %c : tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
|
||||
%a = shape.const_shape [1, 2, 3] : tensor<3xindex>
|
||||
%b = shape.const_shape [4, 5, 6] : tensor<3xindex>
|
||||
%c = shape.const_shape [4, 5, 6] : tensor<3xindex>
|
||||
%result = shape.shape_eq %a, %b, %c : tensor<3xindex>, tensor<3xindex>, tensor<3xindex>
|
||||
return %result : i1
|
||||
}
|
||||
|
||||
|
@ -1161,18 +1161,17 @@ func @fold_assuming_all_single_element(%arg: tensor<?xindex>) {
|
|||
func @fold_tensor.cast_of_const_shape_returned(%arg: i1) -> tensor<1xindex> {
|
||||
// CHECK: shape.const_shape [2] : tensor<1xindex>
|
||||
// CHECK-NOT: tensor.cast
|
||||
%0 = shape.const_shape [2] : tensor<?xindex>
|
||||
%1 = tensor.cast %0 : tensor<?xindex> to tensor<1xindex>
|
||||
%0 = shape.const_shape [2] : tensor<1xindex>
|
||||
%1 = tensor.cast %0 : tensor<1xindex> to tensor<1xindex>
|
||||
return %1 : tensor<1xindex>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Verify that tensor.cast folding uses the correct type
|
||||
// CHECK-LABEL: @fold_tensor.cast_of_const_shape_returned_dynamic
|
||||
func @fold_tensor.cast_of_const_shape_returned_dynamic(%arg: i1) -> tensor<?xindex> {
|
||||
// CHECK: shape.const_shape [2] : tensor<?xindex>
|
||||
// CHECK-NOT: tensor.cast
|
||||
// CHECK-LABEL: @dont_fold_tensor.cast_of_const_shape_returned_dynamic
|
||||
func @dont_fold_tensor.cast_of_const_shape_returned_dynamic(%arg: i1) -> tensor<?xindex> {
|
||||
// CHECK: %[[CONST_SHAPE:.*]] = shape.const_shape [2] : tensor<1xindex>
|
||||
// CHECK: tensor.cast %[[CONST_SHAPE]] : tensor<1xindex> to tensor<?xindex>
|
||||
%0 = shape.const_shape [2] : tensor<1xindex>
|
||||
%1 = tensor.cast %0 : tensor<1xindex> to tensor<?xindex>
|
||||
return %1 : tensor<?xindex>
|
||||
|
|
|
@ -35,7 +35,6 @@ func @test_shape_num_elements_unknown() {
|
|||
|
||||
func @const_shape() {
|
||||
%0 = shape.const_shape [1, 2, 3] : !shape.shape
|
||||
%1 = shape.const_shape [4, 5, 6] : tensor<?xindex>
|
||||
%2 = shape.const_shape [4, 5, 6] : tensor<3xindex>
|
||||
return
|
||||
}
|
||||
|
@ -55,11 +54,11 @@ func @test_broadcast_fixed() {
|
|||
return
|
||||
}
|
||||
|
||||
func @test_broadcast_extents() -> tensor<?xindex> {
|
||||
%0 = shape.const_shape [10, 1, 57, 92] : tensor<?xindex>
|
||||
%1 = shape.const_shape [4, 57, 92] : tensor<?xindex>
|
||||
%2 = shape.broadcast %0, %1 : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
|
||||
return %2 : tensor<?xindex>
|
||||
func @test_broadcast_extents() -> tensor<4xindex> {
|
||||
%0 = shape.const_shape [10, 1, 57, 92] : tensor<4xindex>
|
||||
%1 = shape.const_shape [4, 57, 92] : tensor<3xindex>
|
||||
%2 = shape.broadcast %0, %1 : tensor<4xindex>, tensor<3xindex> -> tensor<4xindex>
|
||||
return %2 : tensor<4xindex>
|
||||
}
|
||||
|
||||
func @test_shape_any_fixed() {
|
||||
|
@ -89,7 +88,7 @@ func @test_shape_any_fixed_mismatch() {
|
|||
func @test_parse_const_shape() {
|
||||
%0 = shape.const_shape [] : !shape.shape
|
||||
%1 = shape.const_shape [1, 2, 3] : !shape.shape
|
||||
%2 = shape.const_shape [1, 2, 3] : tensor<?xindex>
|
||||
%2 = shape.const_shape [1, 2, 3] : tensor<3xindex>
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -222,9 +221,9 @@ func @any() {
|
|||
%0 = shape.const_shape [1, 2, 3] : !shape.shape
|
||||
%1 = shape.const_shape [4, 5, 6] : !shape.shape
|
||||
%2 = "shape.any"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
|
||||
%3 = shape.const_shape [1, 2, 3] : tensor<?xindex>
|
||||
%4 = shape.const_shape [4, 5, 6] : tensor<?xindex>
|
||||
%5 = "shape.any"(%3, %4) : (tensor<?xindex>, tensor<?xindex>) -> tensor<?xindex>
|
||||
%3 = shape.const_shape [1, 2, 3] : tensor<3xindex>
|
||||
%4 = shape.const_shape [4, 5, 6] : tensor<3xindex>
|
||||
%5 = "shape.any"(%3, %4) : (tensor<3xindex>, tensor<3xindex>) -> tensor<3xindex>
|
||||
return
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue