// Verify the printed output can be parsed.
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
// Verify the generic form can be parsed.
// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s

// Each function is anchored by its own label directive. This confirms the
// function is still present after the round trip, and it keeps any plain
// directives scoped to that function.

// Multiplies the extents of %shape together via shape.reduce, starting from 1.
// CHECK-LABEL: @shape_num_elements
func.func @shape_num_elements(%shape : !shape.shape) -> !shape.size {
  %init = shape.const_size 1
  %num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
    ^bb0(%index : index, %extent : !shape.size, %acc : !shape.size):
      %acc_next = shape.mul %acc, %extent
          : !shape.size, !shape.size -> !shape.size
      shape.yield %acc_next : !shape.size
  }
  return %num_elements : !shape.size
}

// Same reduction over an extent tensor, with index-typed values throughout.
// CHECK-LABEL: @extent_tensor_num_elements
func.func @extent_tensor_num_elements(%shape : tensor<?xindex>) -> index {
  %init = arith.constant 1 : index
  %num_elements = shape.reduce(%shape, %init) : tensor<?xindex> -> index {
    ^bb0(%index : index, %extent : index, %acc : index):
      %acc_next = shape.mul %acc, %extent : index, index -> index
      shape.yield %acc_next : index
  }
  return %num_elements : index
}

// Calls @shape_num_elements on a shape produced by the generic-form op
// "shape.unknown_shape".
// CHECK-LABEL: @test_shape_num_elements_unknown
func.func @test_shape_num_elements_unknown() {
  %0 = "shape.unknown_shape"() : () -> !shape.shape
  %1 = call @shape_num_elements(%0) : (!shape.shape) -> (!shape.size)
  %2 = "shape.print"(%1) : (!shape.size) -> !shape.size
  return
}

// shape.const_shape producing both a !shape.shape and an extent tensor.
// CHECK-LABEL: @const_shape
func.func @const_shape() {
  %0 = shape.const_shape [1, 2, 3] : !shape.shape
  %1 = shape.const_shape [4, 5, 6] : tensor<3xindex>
  return
}

// CHECK-LABEL: @test_shape_num_elements_fixed
func.func @test_shape_num_elements_fixed() {
  %0 = shape.const_shape [1, 57, 92] : !shape.shape
  %1 = call @shape_num_elements(%0) : (!shape.shape) -> (!shape.size)
  %2 = "shape.print"(%1) : (!shape.size) -> !shape.size
  return
}

// CHECK-LABEL: @test_broadcast_fixed
func.func @test_broadcast_fixed() {
  %0 = shape.const_shape [10, 1, 57, 92] : !shape.shape
  %1 = shape.const_shape [4, 57, 92] : !shape.shape
  %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape -> !shape.shape
  %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
  return
}

// shape.broadcast over extent tensors of different static lengths.
// CHECK-LABEL: @test_broadcast_extents
func.func @test_broadcast_extents() -> tensor<4xindex> {
  %0 = shape.const_shape [10, 1, 57, 92] : tensor<4xindex>
  %1 = shape.const_shape [4, 57, 92] : tensor<3xindex>
  %2 = shape.broadcast %0, %1 : tensor<4xindex>, tensor<3xindex> -> tensor<4xindex>
  return %2 : tensor<4xindex>
}

// Despite their names, the three test_shape_any_* functions below exercise
// shape.meet (written in generic form), not shape.any.

// CHECK-LABEL: @test_shape_any_fixed
func.func @test_shape_any_fixed() {
  %0 = shape.const_shape [4, 57, 92] : !shape.shape
  %1 = shape.const_shape [4, 57, 92] : !shape.shape
  %2 = "shape.meet"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
  %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
  return
}

// CHECK-LABEL: @test_shape_any_unknown
func.func @test_shape_any_unknown() {
  // NOTE(review): the -1 entries are presumably meant as unknown extents.
  %0 = shape.const_shape [4, -1, 92] : !shape.shape
  %1 = shape.const_shape [-1, 57, 92] : !shape.shape
  %2 = "shape.meet"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
  %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
  return
}

// CHECK-LABEL: @test_shape_any_fixed_mismatch
func.func @test_shape_any_fixed_mismatch() {
  %0 = shape.const_shape [4, 57, 92] : !shape.shape
  %1 = shape.const_shape [2, 57, 92] : !shape.shape
  %2 = "shape.meet"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
  %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
  return
}

// Parses an empty constant shape as well as non-empty ones of both types.
// CHECK-LABEL: @test_parse_const_shape
func.func @test_parse_const_shape() {
  %0 = shape.const_shape [] : !shape.shape
  %1 = shape.const_shape [1, 2, 3] : !shape.shape
  %2 = shape.const_shape [1, 2, 3] : tensor<3xindex>
  return
}

// CHECK-LABEL: @test_shape_of
func.func @test_shape_of(%arg0: tensor<?xf32>) -> tensor<?xindex> {
  %0 = shape.shape_of %arg0 : tensor<?xf32> -> tensor<?xindex>
  return %0 : tensor<?xindex>
}

// CHECK-LABEL: @test_value_of
func.func @test_value_of(%arg0: !shape.value_shape) -> tensor<?xf32> {
  %0 = shape.value_of %arg0 : tensor<?xf32>
  return %0 : tensor<?xf32>
}

// Builds a witness from each constraint op, joins them with
// shape.assuming_all, and guards a region with shape.assuming.
// CHECK-LABEL: @test_constraints
func.func @test_constraints() {
  %0 = shape.const_shape [] : !shape.shape
  %1 = shape.const_shape [1, 2, 3] : !shape.shape
  %true = arith.constant true
  %w0 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
  %w1 = shape.cstr_eq %0, %1 : !shape.shape, !shape.shape
  %w2 = shape.const_witness true
  %w3 = shape.const_witness false
  %w4 = shape.cstr_require %true, "msg"
  %w_all = shape.assuming_all %w0, %w1, %w2, %w3, %w4
  shape.assuming %w_all -> !shape.shape {
    %2 = "shape.any"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
    shape.assuming_yield %2 : !shape.shape
  }
  return
}

// CHECK-LABEL: @eq_on_extent_tensors
func.func @eq_on_extent_tensors(%lhs : tensor<?xindex>,
                                %rhs : tensor<?xindex>) {
  %w0 = shape.cstr_eq %lhs, %rhs : tensor<?xindex>, tensor<?xindex>
  return
}

// CHECK-LABEL: @broadcastable_on_extent_tensors
func.func @broadcastable_on_extent_tensors(%lhs : tensor<?xindex>,
                                           %rhs : tensor<?xindex>) {
  %w0 = shape.cstr_broadcastable %lhs, %rhs : tensor<?xindex>, tensor<?xindex>
  return
}

// The arithmetic ops (mul/div/add) are exercised with !shape.size operands,
// index operands, and a mixed !shape.size/index pair.
// CHECK-LABEL: @mul
func.func @mul(%size_arg : !shape.size, %index_arg : index) {
  %size_prod = shape.mul %size_arg, %size_arg
      : !shape.size, !shape.size -> !shape.size
  %index_prod = shape.mul %index_arg, %index_arg : index, index -> index
  %mixed_prod = shape.mul %size_arg, %index_arg
      : !shape.size, index -> !shape.size
  return
}

// CHECK-LABEL: @div
func.func @div(%size_arg : !shape.size, %index_arg : index) {
  %size_div = shape.div %size_arg, %size_arg
      : !shape.size, !shape.size -> !shape.size
  %index_div = shape.div %index_arg, %index_arg : index, index -> index
  %mixed_div = shape.div %size_arg, %index_arg
      : !shape.size, index -> !shape.size
  return
}

// CHECK-LABEL: @add
func.func @add(%size_arg : !shape.size, %index_arg : index) {
  %size_sum = shape.add %size_arg, %size_arg
      : !shape.size, !shape.size -> !shape.size
  %index_sum = shape.add %index_arg, %index_arg : index, index -> index
  %mixed_sum = shape.add %size_arg, %index_arg
      : !shape.size, index -> !shape.size
  return
}

// The directives below expect printer-assigned result names (%c1, %c2, %c2_0)
// rather than the %0..%2 names used in the input. Note the suffix on the
// second constant 2.
// CHECK-LABEL: @const_size
func.func @const_size() {
  // CHECK: %c1 = shape.const_size 1
  // CHECK: %c2 = shape.const_size 2
  // CHECK: %c2_0 = shape.const_size 2
  %0 = shape.const_size 1
  %1 = shape.const_size 2
  %2 = shape.const_size 2
  return
}

// CHECK-LABEL: @test_to_extent_tensor
func.func @test_to_extent_tensor(%arg: !shape.shape) -> tensor<3xindex> {
  %0 = shape.to_extent_tensor %arg : !shape.shape -> tensor<3xindex>
  return %0 : tensor<3xindex>
}

// shape.to_extent_tensor whose operand already has the result type.
// CHECK-LABEL: @test_identity_to_extent_tensor
func.func @test_identity_to_extent_tensor(%arg: tensor<3xindex>) -> tensor<3xindex> {
  %0 = shape.to_extent_tensor %arg : tensor<3xindex> -> tensor<3xindex>
  return %0 : tensor<3xindex>
}

// CHECK-LABEL: @test_from_extent_tensor
func.func @test_from_extent_tensor(%arg: tensor<?xindex>) -> !shape.shape {
  %0 = shape.from_extent_tensor %arg : tensor<?xindex>
  return %0 : !shape.shape
}

// CHECK-LABEL: @rank
func.func @rank(%shape : !shape.shape) -> !shape.size {
  %rank = shape.rank %shape : !shape.shape -> !shape.size
  return %rank : !shape.size
}

// CHECK-LABEL: @rank_on_extent_tensor
func.func @rank_on_extent_tensor(%shape : tensor<?xindex>) -> index {
  %rank = shape.rank %shape : tensor<?xindex> -> index
  return %rank : index
}

// shape.shape_eq on shapes, on extent tensors, and on a mix of the two.
// CHECK-LABEL: @shape_eq_on_shapes
func.func @shape_eq_on_shapes(%a : !shape.shape, %b : !shape.shape) -> i1 {
  %result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
  return %result : i1
}

// CHECK-LABEL: @shape_eq_on_tensors
func.func @shape_eq_on_tensors(%a : tensor<?xindex>, %b : tensor<?xindex>) -> i1 {
  %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
  return %result : i1
}

// CHECK-LABEL: @shape_eq_on_mixed
func.func @shape_eq_on_mixed(%a : tensor<?xindex>, %b : !shape.shape) -> i1 {
  %result = shape.shape_eq %a, %b : tensor<?xindex>, !shape.shape
  return %result : i1
}

// CHECK-LABEL: @get_extent_on_shape
func.func @get_extent_on_shape(%arg : !shape.shape) -> !shape.size {
  %c0 = shape.const_size 0
  %result = shape.get_extent %arg, %c0 :
      !shape.shape, !shape.size -> !shape.size
  return %result : !shape.size
}

// CHECK-LABEL: @get_extent_on_extent_tensor
func.func @get_extent_on_extent_tensor(%arg : tensor<?xindex>) -> index {
  %c0 = arith.constant 0 : index
  %result = shape.get_extent %arg, %c0 : tensor<?xindex>, index -> index
  return %result : index
}

// CHECK-LABEL: @get_dim
func.func @get_dim(%arg : memref<?x?xindex>) -> index {
  %c0 = arith.constant 0 : index
  %result = shape.dim %arg, %c0 : memref<?x?xindex>, index -> index
  return %result : index
}

// shape.get_extent with an extent-tensor operand and a !shape.size index.
// CHECK-LABEL: @get_extent_on_mixed_operands
func.func @get_extent_on_mixed_operands(%arg : tensor<?xindex>) -> !shape.size {
  %c0 = shape.const_size 0
  %result = shape.get_extent %arg, %c0 : tensor<?xindex>, !shape.size -> !shape.size
  return %result : !shape.size
}

// Generic-form shape.any on !shape.shape operands and on extent tensors.
// CHECK-LABEL: @any
func.func @any() {
  %0 = shape.const_shape [1, 2, 3] : !shape.shape
  %1 = shape.const_shape [4, 5, 6] : !shape.shape
  %2 = "shape.any"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
  %3 = shape.const_shape [1, 2, 3] : tensor<3xindex>
  %4 = shape.const_shape [4, 5, 6] : tensor<3xindex>
  %5 = "shape.any"(%3, %4) : (tensor<3xindex>, tensor<3xindex>) -> tensor<3xindex>
  return
}

// CHECK-LABEL: @num_elements_extent_tensor
func.func @num_elements_extent_tensor(%arg : tensor<?xindex>) -> index {
  %result = shape.num_elements %arg : tensor<?xindex> -> index
  return %result : index
}

// CHECK-LABEL: @num_elements_shape
func.func @num_elements_shape(%arg : !shape.shape) -> !shape.size {
  %result = shape.num_elements %arg : !shape.shape -> !shape.size
  return %result : !shape.size
}

// Test invoking one shape function from another: @shape_equal_shapes is merely
// a trivial helper, called from @shape_with_shape.
// CHECK-LABEL: @shape_equal_shapes
func.func @shape_equal_shapes(%a : !shape.value_shape, %b : !shape.value_shape) -> !shape.shape {
  %0 = shape.shape_of %a : !shape.value_shape -> !shape.shape
  %1 = shape.shape_of %b : !shape.value_shape -> !shape.shape
  %2 = "shape.meet"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
  return %2 : !shape.shape
}

// Attaches the shape of %a to %b via shape.with_shape, then calls the helper.
// CHECK-LABEL: @shape_with_shape
func.func @shape_with_shape(%a : !shape.value_shape, %b : !shape.value_shape) -> !shape.shape {
  %0 = shape.shape_of %a : !shape.value_shape -> !shape.shape
  %1 = shape.with_shape %b, %0 : !shape.value_shape, !shape.shape
  %2 = call @shape_equal_shapes(%a, %1) : (!shape.value_shape, !shape.value_shape) -> !shape.shape
  return %2 : !shape.shape
}

// shape.with_shape taking the shape as a static-length extent tensor.
// CHECK-LABEL: @shape_with_shape_extent_tensor_type
func.func @shape_with_shape_extent_tensor_type(%a : tensor<?x?x?xf32>, %b : !shape.value_shape) -> !shape.value_shape {
  %0 = shape.shape_of %a : tensor<?x?x?xf32> -> tensor<3xindex>
  %1 = shape.with_shape %b, %0 : !shape.value_shape, tensor<3xindex>
  return %1 : !shape.value_shape
}

// Custom-form shape.any with three operands: all shapes, mixed, and all
// extent tensors.
// CHECK-LABEL: @any_on_shape
func.func @any_on_shape(%a : !shape.shape, %b : !shape.shape, %c : !shape.shape)
    -> !shape.shape {
  %result = shape.any %a, %b, %c
      : !shape.shape, !shape.shape, !shape.shape -> !shape.shape
  return %result : !shape.shape
}

// CHECK-LABEL: @any_on_mixed
func.func @any_on_mixed(%a : tensor<?xindex>,
                        %b : tensor<?xindex>,
                        %c : !shape.shape) -> !shape.shape {
  %result = shape.any %a, %b, %c
      : tensor<?xindex>, tensor<?xindex>, !shape.shape -> !shape.shape
  return %result : !shape.shape
}

// CHECK-LABEL: @any_on_extent_tensors
func.func @any_on_extent_tensors(%a : tensor<?xindex>,
                                 %b : tensor<?xindex>,
                                 %c : tensor<?xindex>) -> tensor<?xindex> {
  %result = shape.any %a, %b, %c
      : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
  return %result : tensor<?xindex>
}

// CHECK-LABEL: @is_broadcastable_on_extent_tensors
func.func @is_broadcastable_on_extent_tensors(%a : tensor<?xindex>,
                                              %b : tensor<?xindex>) -> i1 {
  %result = shape.is_broadcastable %a, %b
      : tensor<?xindex>, tensor<?xindex>
  return %result : i1
}

// CHECK-LABEL: @is_broadcastable_on_shapes
func.func @is_broadcastable_on_shapes(%a : !shape.shape,
                                      %b : !shape.shape) -> i1 {
  %result = shape.is_broadcastable %a, %b
      : !shape.shape, !shape.shape
  return %result : i1
}

// The next four functions combine shape.max / shape.min with a constant and
// feed the result to custom-form shape.meet carrying the optional error
// attribute, on both !shape.shape and !shape.size values.
// CHECK-LABEL: @shape_upper_bounded_by_constant
func.func @shape_upper_bounded_by_constant(%a: !shape.shape) -> !shape.shape {
  %0 = shape.const_shape [4, 57, 92] : !shape.shape
  %1 = shape.max %a, %0 : !shape.shape, !shape.shape -> !shape.shape
  %2 = shape.meet %0, %1, error="exceeded element-wise upper bound" :
      !shape.shape, !shape.shape -> !shape.shape
  return %2 : !shape.shape
}

// CHECK-LABEL: @shape_lower_bounded_by_constant
func.func @shape_lower_bounded_by_constant(%a: !shape.shape) -> !shape.shape {
  %0 = shape.const_shape [4, 57, 92] : !shape.shape
  %1 = shape.min %a, %0 : !shape.shape, !shape.shape -> !shape.shape
  %2 = shape.meet %0, %1, error="lower bound element-wise exceeded" :
      !shape.shape, !shape.shape -> !shape.shape
  return %2 : !shape.shape
}

// CHECK-LABEL: @size_upper_bounded_by_constant
func.func @size_upper_bounded_by_constant(%a: !shape.size) -> !shape.size {
  %0 = shape.const_size 5
  %1 = shape.max %a, %0 : !shape.size, !shape.size -> !shape.size
  %2 = shape.meet %0, %1, error="exceeded element-wise upper bound" :
      !shape.size, !shape.size -> !shape.size
  return %2 : !shape.size
}

// CHECK-LABEL: @size_lower_bounded_by_constant
func.func @size_lower_bounded_by_constant(%a: !shape.size) -> !shape.size {
  %0 = shape.const_size 9
  %1 = shape.min %a, %0 : !shape.size, !shape.size -> !shape.size
  %2 = shape.meet %0, %1, error="lower bound element-wise exceeded" :
      !shape.size, !shape.size -> !shape.size
  return %2 : !shape.size
}

// shape.meet on plain index values, without the error attribute.
// CHECK-LABEL: @meet_index
func.func @meet_index(%arg0 : index, %arg1 : index) -> index {
  %result = shape.meet %arg0, %arg1 : index, index -> index
  return %result : index
}