// RUN: mlir-opt --split-input-file --convert-shape-to-std --verify-diagnostics %s | FileCheck %s

// Lower binary ops.
// CHECK-LABEL: @binary_ops
// CHECK-SAME: (%[[LHS:.*]]: index, %[[RHS:.*]]: index)
func.func @binary_ops(%lhs : index, %rhs : index) {
  // CHECK: arith.addi %[[LHS]], %[[RHS]] : index
  %sum = shape.add %lhs, %rhs : index, index -> index
  // CHECK: arith.muli %[[LHS]], %[[RHS]] : index
  %product = shape.mul %lhs, %rhs : index, index -> index
  return
}

// -----

// Don't lower binary ops when they operate on `shape.size`.
// CHECK-LABEL: @binary_ops_on_size
// CHECK-SAME: (%[[LHS:.*]]: !shape.size, %[[RHS:.*]]: !shape.size)
func.func @binary_ops_on_size(%lhs : !shape.size, %rhs : !shape.size) {
  // CHECK: shape.add %[[LHS]], %[[RHS]] : !shape.size, !shape.size -> !shape.size
  // CHECK: shape.mul %[[LHS]], %[[RHS]] : !shape.size, !shape.size -> !shape.size
  %sum = shape.add %lhs, %rhs : !shape.size, !shape.size -> !shape.size
  %prod = shape.mul %lhs, %rhs : !shape.size, !shape.size -> !shape.size
  return
}

// -----

// Convert `rank` to `dim` of the first dimension.
// CHECK-LABEL: @rank
// CHECK-SAME: (%[[SHAPE:.*]]: tensor<?xindex>) -> index
func.func @rank(%shape : tensor<?xindex>) -> index {
  // CHECK: %[[C0:.*]] = arith.constant 0 : index
  // CHECK: %[[RESULT:.*]] = tensor.dim %[[SHAPE]], %[[C0]]
  // CHECK: return %[[RESULT]] : index
  %rank = shape.rank %shape : tensor<?xindex> -> index
  return %rank : index
}

// -----

// Don't lower `get_extent` if it is of type `shape.size`.
// CHECK-LABEL: @get_extent
func.func @get_extent(%shape : tensor<?xindex>, %idx : !shape.size) -> !shape.size {
  // CHECK: shape.get_extent
  %result = shape.get_extent %shape, %idx
      : tensor<?xindex>, !shape.size -> !shape.size
  return %result : !shape.size
}

// -----

// Don't lower `rank` if type is not error-free.
// CHECK-LABEL: @rank
func.func @rank(%shape : !shape.shape) {
  // CHECK: shape.rank
  %rank = shape.rank %shape : !shape.shape -> !shape.size
  return
}

// -----

// Express `shape.dim` as `tensor.dim` when valid.
// CHECK-LABEL: @dim
// CHECK-SAME: (%[[ARG:.*]]: tensor<2x3xf32>, %[[IDX:.*]]: index) -> index
func.func @dim(%arg : tensor<2x3xf32>, %idx : index) -> index {
  // CHECK: %[[RESULT:.*]] = tensor.dim %[[ARG]], %[[IDX]] : tensor<2x3xf32>
  // CHECK: return %[[RESULT]] : index
  %result = shape.dim %arg, %idx : tensor<2x3xf32>, index -> index
  return %result : index
}

// -----

// Express `get_extent` as `tensor.dim` when it relies directly on the outcome of a
// `shape_of` operation.
// CHECK-LABEL: @get_extent_shape_of
// CHECK-SAME: (%[[ARG:.*]]: tensor<2x3xf32>, %[[IDX:.*]]: index) -> index
func.func @get_extent_shape_of(%arg : tensor<2x3xf32>, %idx : index) -> index {
  // CHECK: %[[RESULT:.*]] = tensor.dim %[[ARG]], %[[IDX]] : tensor<2x3xf32>
  // CHECK: return %[[RESULT]] : index
  %shape = shape.shape_of %arg : tensor<2x3xf32> -> tensor<?xindex>
  %result = shape.get_extent %shape, %idx : tensor<?xindex>, index -> index
  return %result : index
}

// -----

// Express `get_extent` as `tensor.extract`.
// CHECK-LABEL: @get_extent_from_extent_tensor
// CHECK-SAME: (%[[EXTENTS:.*]]: tensor<?xindex>, %[[IDX:.*]]: index) -> index
func.func @get_extent_from_extent_tensor(%extents : tensor<?xindex>, %idx : index)
    -> index {
  // CHECK: %[[RESULT:.*]] = tensor.extract %[[EXTENTS]][%[[IDX]]] : tensor<?xindex>
  // CHECK: return %[[RESULT]] : index
  %result = shape.get_extent %extents, %idx : tensor<?xindex>, index -> index
  return %result : index
}

// -----

// Lower `const_shape` to `tensor.from_elements`.
// CHECK-LABEL: @const_shape
// CHECK-SAME: () -> tensor<3xindex>
func.func @const_shape() -> tensor<3xindex> {
  // CHECK: %[[C1:.*]] = arith.constant 1 : index
  // CHECK: %[[C2:.*]] = arith.constant 2 : index
  // CHECK: %[[C3:.*]] = arith.constant 3 : index
  // CHECK: %[[TENSOR3:.*]] = tensor.from_elements %[[C1]], %[[C2]], %[[C3]]
  // CHECK: %[[RESULT:.*]] = tensor.cast %[[TENSOR3]] : tensor<3xindex> to tensor<3xindex>
  // CHECK: return %[[RESULT]] : tensor<3xindex>
  %shape = shape.const_shape [1, 2, 3] : tensor<3xindex>
  return %shape : tensor<3xindex>
}

// -----

// Lower `const_shape` in the case of rank 0.
// CHECK-LABEL: func @const_shape_zero_elements
// CHECK-SAME: () -> tensor<0xindex>
func.func @const_shape_zero_elements() -> tensor<0xindex> {
  // CHECK: %[[TENSOR:.*]] = tensor.from_elements : tensor<0xindex>
  // CHECK: %[[RESULT:.*]] = tensor.cast %[[TENSOR]] : tensor<0xindex> to tensor<0xindex>
  // CHECK: return %[[RESULT]] : tensor<0xindex>
  %shape = shape.const_shape [] : tensor<0xindex>
  return %shape : tensor<0xindex>
}

// -----

// Lower `any` to its first operand.
// CHECK-LABEL: @any_of_three
// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>) -> tensor<?xindex>
func.func @any_of_three(%a : tensor<?xindex>,
                        %b : tensor<?xindex>,
                        %c : tensor<?xindex>) -> tensor<?xindex> {
  // CHECK: return %[[A]] : tensor<?xindex>
  %result = "shape.any"(%a, %b, %c) : (tensor<?xindex>, tensor<?xindex>, tensor<?xindex>) -> tensor<?xindex>
  return %result : tensor<?xindex>
}

// -----

// Lower `any` to its first operand.
// CHECK-LABEL: @any_of_one
// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>) -> tensor<?xindex>
func.func @any_of_one(%a : tensor<?xindex>) -> tensor<?xindex> {
  // CHECK: return %[[A]] : tensor<?xindex>
  %result = "shape.any"(%a) : (tensor<?xindex>) -> tensor<?xindex>
  return %result : tensor<?xindex>
}

// -----

// Lower `const_size` to `arith.constant`.
// CHECK-LABEL: @const_size
func.func @const_size() -> index {
  // CHECK: %[[RES:.*]] = arith.constant 42 : index
  %size = shape.const_size 42
  %result = shape.size_to_index %size : !shape.size
  // CHECK: return %[[RES]]
  return %result : index
}

// -----

// Lower `to_extent_tensor` to `tensor.cast`.
// Fold `to_extent_tensor` when already on tensor.
// CHECK-LABEL: @to_extent_tensor
// CHECK-SAME: (%[[ARG:.*]]: tensor<?xindex>
func.func @to_extent_tensor(%arg: tensor<?xindex>) -> tensor<3xindex> {
  // CHECK-NOT: to_extent_tensor
  // CHECK: %[[RES:.*]] = tensor.cast %[[ARG]] : tensor<?xindex> to tensor<3xindex>
  %casted = shape.to_extent_tensor %arg : tensor<?xindex> -> tensor<3xindex>
  // CHECK: return %[[RES]]
  return %casted : tensor<3xindex>
}

// -----

// Lower `reduce` to an `scf.for` loop over the extent tensor, threading the
// accumulator through `iter_args`.
// CHECK-LABEL: @shape_reduce
// CHECK-SAME: (%[[SHAPE:.*]]: tensor<?xindex>) -> index
func.func @shape_reduce(%shape : tensor<?xindex>) -> index {
  %init = arith.constant 1 : index
  %num_elements = shape.reduce(%shape, %init) : tensor<?xindex> -> index {
    ^bb0(%index : index, %extent : index, %acc: index):
      %new_acc = arith.muli %acc, %extent : index
      shape.yield %new_acc : index
  }
  return %num_elements : index
}
// CHECK-NEXT: %[[INIT:.*]] = arith.constant 1 : index
// CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
// CHECK-NEXT: %[[RANK:.*]] = tensor.dim %[[SHAPE]], %[[C0]] : tensor<?xindex>
// CHECK-NEXT: %[[RESULT:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] iter_args(%[[ACC:.*]] = %[[INIT]]) -> (index)
// CHECK-NEXT:   %[[EXTENT:.*]] = tensor.extract %[[SHAPE]][%[[I]]]
// CHECK-NEXT:   %[[NEW_ACC:.*]] = arith.muli %[[ACC]], %[[EXTENT]] : index
// CHECK-NEXT:   scf.yield %[[NEW_ACC]] : index
// CHECK-NEXT: }
// CHECK-NEXT: return %[[RESULT]] : index

// -----

// Don't lower `shape_of` for result type of `shape.shape`.
// CHECK-LABEL: @shape_of
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
func.func @shape_of(%arg : tensor<*xf32>) {
  // CHECK: shape.shape_of %[[ARG]] : tensor<*xf32> -> !shape.shape
  %shape = shape.shape_of %arg : tensor<*xf32> -> !shape.shape
  return
}

// -----

// Lower `shape_of` for unranked tensors.
// CHECK-LABEL: @shape_of_unranked
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
func.func @shape_of_unranked(%arg : tensor<*xf32>) {
  // CHECK: %[[RANK:.*]] = tensor.rank %[[ARG]] : tensor<*xf32>
  // CHECK: %[[SHAPE:.*]] = tensor.generate %[[RANK]] {
  // CHECK: ^bb0(%[[I:.*]]: index):
  // CHECK:   %[[EXTENT:.*]] = tensor.dim %[[ARG]], %[[I]] : tensor<*xf32>
  // CHECK:   yield %[[EXTENT]] : index
  // CHECK: } : tensor<?xindex>
  %shape = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
  return
}

// -----

// Don't lower `shape_of` with `shape.shape` type.
// CHECK-LABEL: @shape_of_stat
// CHECK-SAME: (%[[ARG:.*]]: tensor<1x2x3xf32>)
func.func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
  // CHECK: shape.shape_of %[[ARG]] : tensor<1x2x3xf32> -> !shape.shape
  %shape = shape.shape_of %arg : tensor<1x2x3xf32> -> !shape.shape
  return
}

// -----

// Lower `shape_of` for statically shaped tensor.
// CHECK-LABEL: @shape_of_stat
// CHECK-SAME: (%[[ARG:.*]]: tensor<1x2x3xf32>)
func.func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
  // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
  // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor.from_elements %[[C1]], %[[C2]], %[[C3]] : tensor<3xindex>
  %shape = shape.shape_of %arg : tensor<1x2x3xf32> -> tensor<?xindex>
  return
}

// -----

// Lower `shape_of` for 0-D tensor.
// CHECK-LABEL: @shape_of_zero_d
// CHECK-SAME: (%[[ARG:.*]]: tensor<f32>)
func.func @shape_of_zero_d(%arg : tensor<f32>) {
  // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor.from_elements : tensor<0xindex>
  %shape = shape.shape_of %arg : tensor<f32> -> tensor<?xindex>
  return
}

// -----

// Lower `shape_of` for dynamically shaped tensor.
// CHECK-LABEL: @shape_of_dyn
// CHECK-SAME: (%[[ARG:.*]]: tensor<1x5x?xf32>)
func.func @shape_of_dyn(%arg : tensor<1x5x?xf32>) {
  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
  // CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
  // CHECK-DAG: %[[DYN_DIM:.*]] = tensor.dim %[[ARG]], %[[C2]] : tensor<1x5x?xf32>
  // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor.from_elements %[[C1]], %[[C5]], %[[DYN_DIM]] : tensor<3xindex>
  %shape = shape.shape_of %arg : tensor<1x5x?xf32> -> tensor<?xindex>
  return
}

// -----

// Lower `shape_eq` on two extent tensors: compare the ranks first, then
// and-reduce the per-extent comparisons inside an `scf.for`.
// CHECK-LABEL: @shape_eq
// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>) -> i1
func.func @shape_eq(%a : tensor<?xindex>, %b : tensor<?xindex>) -> i1 {
  // CHECK: %[[C0:.*]] = arith.constant 0 : index
  // CHECK: %[[RANK_A:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?xindex>
  // CHECK: %[[RANK_B:.*]] = tensor.dim %[[B]], %[[C0]] : tensor<?xindex>
  // CHECK: %[[RANK_EQ:.*]] = arith.cmpi eq, %[[RANK_A]], %[[RANK_B]]
  // CHECK: %[[SHAPE_EQ:.*]] = scf.if %[[RANK_EQ]] -> (i1) {
  // CHECK:   %[[C1:.*]] = arith.constant 1 : index
  // CHECK:   %[[INIT:.*]] = arith.constant true
  // CHECK:   %[[SHAPE_EQ_INNER:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK_A]] step %[[C1]] iter_args(%[[CONJ:.*]] = %[[INIT]]) -> (i1) {
  // CHECK:     %[[EXTENT_A:.*]] = tensor.extract %[[A]][%[[I]]] : tensor<?xindex>
  // CHECK:     %[[EXTENT_B:.*]] = tensor.extract %[[B]][%[[I]]] : tensor<?xindex>
  // CHECK:     %[[EXTENT_EQ:.*]] = arith.cmpi eq, %[[EXTENT_A]], %[[EXTENT_B]]
  // CHECK:     %[[CONJ_NEXT:.*]] = arith.andi %[[CONJ]], %[[EXTENT_EQ]]
  // CHECK:     scf.yield %[[CONJ_NEXT]] : i1
  // CHECK:   }
  // CHECK:   scf.yield %[[SHAPE_EQ_INNER]] : i1
  // CHECK: } else {
  // CHECK:   %[[SHAPE_EQ_INNER:.*]] = arith.constant false
  // CHECK:   scf.yield %[[SHAPE_EQ_INNER]] : i1
  // CHECK: }
  // CHECK: return %[[SHAPE_EQ]] : i1
  %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
  return %result : i1
}

// -----

// Lower `shape_eq` on three extent tensors: the first shape is compared against
// each of the others and the two results are combined with `arith.andi`.
// CHECK-LABEL: @shape_eq
// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>) -> i1
func.func @shape_eq(%a : tensor<?xindex>, %b : tensor<?xindex>, %c : tensor<?xindex>) -> i1 {
  // CHECK: %[[C0:.*]] = arith.constant 0 : index
  // CHECK: %[[RANK_A:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?xindex>
  // CHECK: %[[RANK_B:.*]] = tensor.dim %[[B]], %[[C0]] : tensor<?xindex>
  // CHECK: %[[RANK_EQ:.*]] = arith.cmpi eq, %[[RANK_A]], %[[RANK_B]]
  // CHECK: %[[SHAPE_EQ:.*]] = scf.if %[[RANK_EQ]] -> (i1) {
  // CHECK:   %[[C1:.*]] = arith.constant 1 : index
  // CHECK:   %[[INIT:.*]] = arith.constant true
  // CHECK:   %[[SHAPE_EQ_INNER:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK_A]] step %[[C1]] iter_args(%[[CONJ:.*]] = %[[INIT]]) -> (i1) {
  // CHECK:     %[[EXTENT_A:.*]] = tensor.extract %[[A]][%[[I]]] : tensor<?xindex>
  // CHECK:     %[[EXTENT_B:.*]] = tensor.extract %[[B]][%[[I]]] : tensor<?xindex>
  // CHECK:     %[[EXTENT_EQ:.*]] = arith.cmpi eq, %[[EXTENT_A]], %[[EXTENT_B]]
  // CHECK:     %[[CONJ_NEXT:.*]] = arith.andi %[[CONJ]], %[[EXTENT_EQ]]
  // CHECK:     scf.yield %[[CONJ_NEXT]] : i1
  // CHECK:   }
  // CHECK:   scf.yield %[[SHAPE_EQ_INNER]] : i1
  // CHECK: } else {
  // CHECK:   %[[SHAPE_EQ_INNER:.*]] = arith.constant false
  // CHECK:   scf.yield %[[SHAPE_EQ_INNER]] : i1
  // CHECK: }
  // CHECK: %[[RANK_C:.*]] = tensor.dim %[[C]], %[[C0]] : tensor<?xindex>
  // CHECK: %[[RANK_EQ:.*]] = arith.cmpi eq, %[[RANK_A]], %[[RANK_C]]
  // CHECK: %[[SHAPE_EQ2:.*]] = scf.if %[[RANK_EQ]] -> (i1) {
  // CHECK:   %[[C1:.*]] = arith.constant 1 : index
  // CHECK:   %[[INIT:.*]] = arith.constant true
  // CHECK:   %[[SHAPE_EQ_INNER:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK_A]] step %[[C1]] iter_args(%[[CONJ:.*]] = %[[INIT]]) -> (i1) {
  // CHECK:     %[[EXTENT_A:.*]] = tensor.extract %[[A]][%[[I]]] : tensor<?xindex>
  // CHECK:     %[[EXTENT_C:.*]] = tensor.extract %[[C]][%[[I]]] : tensor<?xindex>
  // CHECK:     %[[EXTENT_EQ:.*]] = arith.cmpi eq, %[[EXTENT_A]], %[[EXTENT_C]]
  // CHECK:     %[[CONJ_NEXT:.*]] = arith.andi %[[CONJ]], %[[EXTENT_EQ]]
  // CHECK:     scf.yield %[[CONJ_NEXT]] : i1
  // CHECK:   }
  // CHECK:   scf.yield %[[SHAPE_EQ_INNER]] : i1
  // CHECK: } else {
  // CHECK:   %[[SHAPE_EQ_INNER:.*]] = arith.constant false
  // CHECK:   scf.yield %[[SHAPE_EQ_INNER]] : i1
  // CHECK: }
  // CHECK: %[[RESULT:.*]] = arith.andi %[[SHAPE_EQ]], %[[SHAPE_EQ2]] : i1
  // CHECK: return %[[RESULT]] : i1
  %result = shape.shape_eq %a, %b, %c : tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
  return %result : i1
}

// -----

// Don't lower `shape.broadcast` if a `shape.shape` type is involved.
// CHECK-LABEL: @broadcast
func.func @broadcast(%a : tensor<?xindex>, %b : !shape.shape) -> !shape.shape {
  // CHECK: shape.broadcast
  %c = shape.broadcast %a, %b : tensor<?xindex>, !shape.shape -> !shape.shape
  return %c : !shape.shape
}

// -----

// Lower `is_broadcastable` on extent tensors to an `scf.for` over the maximum
// rank that and-reduces a per-dimension compatibility check.
func.func @try_is_broadcastable (%a : tensor<2xindex>, %b : tensor<3xindex>, %c : tensor<2xindex>) -> i1 {
  %0 = shape.is_broadcastable %a, %b, %c : tensor<2xindex>, tensor<3xindex>, tensor<2xindex>
  return %0 : i1
}
// CHECK-LABEL: @try_is_broadcastable
// CHECK-SAME: %[[ARG0:.*]]: tensor<2xindex>,
// CHECK-SAME: %[[ARG1:.*]]: tensor<3xindex>,
// CHECK-SAME: %[[ARG2:.*]]: tensor<2xindex>)
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[RANK0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<2xindex>
// CHECK: %[[RANK1:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<3xindex>
// CHECK: %[[RANK2:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<2xindex>
// CHECK: %[[MAX0:.*]] = arith.maxui %[[RANK1]], %[[RANK0]] : index
// CHECK: %[[MAX_RANK:.*]] = arith.maxui %[[RANK2]], %[[MAX0]] : index
// CHECK: %[[DIM_DIFF0:.*]] = arith.subi %[[MAX_RANK]], %[[RANK0]] : index
// CHECK: %[[DIM_DIFF1:.*]] = arith.subi %[[MAX_RANK]], %[[RANK1]] : index
// CHECK: %[[DIM_DIFF2:.*]] = arith.subi %[[MAX_RANK]], %[[RANK2]] : index
// CHECK: %[[TRUE:.*]] = arith.constant true
// CHECK: %[[ALL_RESULT:.*]] = scf.for %[[IDX:.*]] = %[[C0]] to %[[MAX_RANK]] step %[[C1]] iter_args(%[[ALL_SO_FAR:.*]] = %[[TRUE]]) -> (i1) {
// CHECK:   %[[C1_0:.*]] = arith.constant 1 : index
// CHECK:   %[[OUTBOUNDS0:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF0]] : index
// CHECK:   %[[DIM0:.*]] = scf.if %[[OUTBOUNDS0]] -> (index) {
// CHECK:     scf.yield %[[C1_0]] : index
// CHECK:   } else {
// CHECK:     %[[IDX0:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF0]] : index
// CHECK:     %[[EXTRACTED_0:.*]] = tensor.extract %[[ARG0]]{{\[}}%[[IDX0]]] : tensor<2xindex>
// CHECK:     %[[DIM0_IS_1:.*]] = arith.cmpi eq, %[[EXTRACTED_0]], %[[C1_0]] : index
// CHECK:     %[[MAX_DIM0:.*]] = arith.select %[[DIM0_IS_1]], %[[C1_0]], %[[EXTRACTED_0]] : index
// CHECK:   }
// CHECK:   %[[VAL_28:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF1]] : index
// CHECK:   %[[DIM1:.*]] = scf.if %[[VAL_28]] -> (index) {
// CHECK:     scf.yield %[[DIM0]] : index
// CHECK:   } else {
// CHECK:     %[[IDX1:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF1]] : index
// CHECK:     %[[EXTRACTED_1:.*]] = tensor.extract %[[ARG1]]{{\[}}%[[IDX1]]] : tensor<3xindex>
// CHECK:     %[[DIM1_IS_1:.*]] = arith.cmpi eq, %[[EXTRACTED_1]], %[[C1_0]] : index
// CHECK:     %[[MAX_DIM1:.*]] = arith.select %[[DIM1_IS_1]], %[[DIM0]], %[[EXTRACTED_1]] : index
// CHECK:   }
// CHECK:   %[[VAL_36:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF2]] : index
// CHECK:   %[[DIM2:.*]] = scf.if %[[VAL_36]] -> (index) {
// CHECK:     scf.yield %[[DIM1]] : index
// CHECK:   } else {
// CHECK:     %[[IDX2:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF2]] : index
// CHECK:     %[[EXTRACTED_2:.*]] = tensor.extract %[[ARG2]]{{\[}}%[[IDX2]]] : tensor<2xindex>
// CHECK:     %[[DIM2_IS_1:.*]] = arith.cmpi eq, %[[EXTRACTED_2]], %[[C1_0]] : index
// CHECK:     %[[MAX_DIM2:.*]] = arith.select %[[DIM2_IS_1]], %[[DIM1]], %[[EXTRACTED_2]] : index
// CHECK:   }
// CHECK:   %[[OUT_BOUND_0:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF0]] : index
// CHECK:   %[[REDUCTION_0:.*]] = scf.if %[[OUT_BOUND_0]] -> (i1) {
// CHECK:     scf.yield %[[ALL_SO_FAR]] : i1
// CHECK:   } else {
// CHECK:     %[[SHIFTED:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF0]] : index
// CHECK:     %[[EXTRACTED:.*]] = tensor.extract %[[ARG0]]{{\[}}%[[SHIFTED]]] : tensor<2xindex>
// CHECK:     %[[EQUALS_1:.*]] = arith.cmpi eq, %[[EXTRACTED]], %[[C1]] : index
// CHECK:     %[[EQUALS_BROADCASTED:.*]] = arith.cmpi eq, %[[EXTRACTED]], %[[DIM2]] : index
// CHECK:     %[[GOOD:.*]] = arith.ori %[[EQUALS_1]], %[[EQUALS_BROADCASTED]] : i1
// CHECK:     %[[AND_REDUCTION:.*]] = arith.andi %[[ALL_SO_FAR]], %[[GOOD]] : i1
// CHECK:     scf.yield %[[AND_REDUCTION]] : i1
// CHECK:   }
// CHECK:   %[[OUT_BOUND_1:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF1]] : index
// CHECK:   %[[SECOND_REDUCTION:.*]] = scf.if %[[OUT_BOUND_1]] -> (i1) {
// CHECK:     scf.yield %[[REDUCTION_0]] : i1
// CHECK:   } else {
// CHECK:     %[[SHIFTED:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF1]] : index
// CHECK:     %[[EXTRACTED:.*]] = tensor.extract %[[ARG1]]{{\[}}%[[SHIFTED]]] : tensor<3xindex>
// CHECK:     %[[EQUALS_1:.*]] = arith.cmpi eq, %[[EXTRACTED]], %[[C1]] : index
// CHECK:     %[[EQUALS_BROADCASTED:.*]] = arith.cmpi eq, %[[EXTRACTED]], %[[DIM2]] : index
// CHECK:     %[[GOOD:.*]] = arith.ori %[[EQUALS_1]], %[[EQUALS_BROADCASTED]] : i1
// CHECK:     %[[AND_REDUCTION:.*]] = arith.andi %[[REDUCTION_0]], %[[GOOD]] : i1
// CHECK:     scf.yield %[[AND_REDUCTION]] : i1
// CHECK:   }
// CHECK:   %[[OUT_BOUND_2:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF2]] : index
// CHECK:   %[[FINAL_RESULT:.*]] = scf.if %[[OUT_BOUND_2]] -> (i1) {
// CHECK:     scf.yield %[[SECOND_REDUCTION]] : i1
// CHECK:   } else {
// CHECK:     %[[SHIFTED:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF2]] : index
// CHECK:     %[[EXTRACTED:.*]] = tensor.extract %[[ARG2]]{{\[}}%[[SHIFTED]]] : tensor<2xindex>
// CHECK:     %[[EQUALS_1:.*]] = arith.cmpi eq, %[[EXTRACTED]], %[[C1]] : index
// CHECK:     %[[EQUALS_BROADCASTED:.*]] = arith.cmpi eq, %[[EXTRACTED]], %[[DIM2]] : index
// CHECK:     %[[GOOD:.*]] = arith.ori %[[EQUALS_1]], %[[EQUALS_BROADCASTED]] : i1
// CHECK:     %[[AND_REDUCTION:.*]] = arith.andi %[[SECOND_REDUCTION]], %[[GOOD]] : i1
// CHECK:     scf.yield %[[AND_REDUCTION]] : i1
// CHECK:   }
// CHECK:   scf.yield %[[FINAL_RESULT]] : i1

// -----

// Lower `cstr_broadcastable` on extent tensors to the same and-reducing
// `scf.for`, whose result feeds `shape.cstr_require`.
func.func @broadcast(%a : tensor<2xindex>, %b : tensor<3xindex>, %c : tensor<2xindex>) -> !shape.witness {
  %0 = shape.cstr_broadcastable %a, %b, %c : tensor<2xindex>, tensor<3xindex>, tensor<2xindex>
  return %0 : !shape.witness
}
// CHECK-LABEL: func @broadcast(
// CHECK-SAME: %[[ARG0:.*]]: tensor<2xindex>,
// CHECK-SAME: %[[ARG1:.*]]: tensor<3xindex>,
// CHECK-SAME: %[[ARG2:.*]]: tensor<2xindex>)
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[RANK0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<2xindex>
// CHECK: %[[RANK1:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<3xindex>
// CHECK: %[[RANK2:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<2xindex>
// CHECK: %[[MAX0:.*]] = arith.maxui %[[RANK1]], %[[RANK0]] : index
// CHECK: %[[MAX_RANK:.*]] = arith.maxui %[[RANK2]], %[[MAX0]] : index
// CHECK: %[[DIM_DIFF0:.*]] = arith.subi %[[MAX_RANK]], %[[RANK0]] : index
// CHECK: %[[DIM_DIFF1:.*]] = arith.subi %[[MAX_RANK]], %[[RANK1]] : index
// CHECK: %[[DIM_DIFF2:.*]] = arith.subi %[[MAX_RANK]], %[[RANK2]] : index
// CHECK: %[[TRUE:.*]] = arith.constant true
// CHECK: %[[ALL_RESULT:.*]] = scf.for %[[IDX:.*]] = %[[C0]] to %[[MAX_RANK]] step %[[C1]] iter_args(%[[ALL_SO_FAR:.*]] = %[[TRUE]]) -> (i1) {
// CHECK:   %[[C1_0:.*]] = arith.constant 1 : index
// CHECK:   %[[OUTBOUNDS0:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF0]] : index
// CHECK:   %[[DIM0:.*]] = scf.if %[[OUTBOUNDS0]] -> (index) {
// CHECK:     scf.yield %[[C1_0]] : index
// CHECK:   } else {
// CHECK:     %[[IDX0:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF0]] : index
// CHECK:     %[[EXTRACTED_0:.*]] = tensor.extract %[[ARG0]]{{\[}}%[[IDX0]]] : tensor<2xindex>
// CHECK:     %[[DIM0_IS_1:.*]] = arith.cmpi eq, %[[EXTRACTED_0]], %[[C1_0]] : index
// CHECK:     %[[MAX_DIM0:.*]] = arith.select %[[DIM0_IS_1]], %[[C1_0]], %[[EXTRACTED_0]] : index
// CHECK:   }
// CHECK:   %[[VAL_28:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF1]] : index
// CHECK:   %[[DIM1:.*]] = scf.if %[[VAL_28]] -> (index) {
// CHECK:     scf.yield %[[DIM0]] : index
// CHECK:   } else {
// CHECK:     %[[IDX1:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF1]] : index
// CHECK:     %[[EXTRACTED_1:.*]] = tensor.extract %[[ARG1]]{{\[}}%[[IDX1]]] : tensor<3xindex>
// CHECK:     %[[DIM1_IS_1:.*]] = arith.cmpi eq, %[[EXTRACTED_1]], %[[C1_0]] : index
// CHECK:     %[[MAX_DIM1:.*]] = arith.select %[[DIM1_IS_1]], %[[DIM0]], %[[EXTRACTED_1]] : index
// CHECK:   }
// CHECK:   %[[VAL_36:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF2]] : index
// CHECK:   %[[DIM2:.*]] = scf.if %[[VAL_36]] -> (index) {
// CHECK:     scf.yield %[[DIM1]] : index
// CHECK:   } else {
// CHECK:     %[[IDX2:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF2]] : index
// CHECK:     %[[EXTRACTED_2:.*]] = tensor.extract %[[ARG2]]{{\[}}%[[IDX2]]] : tensor<2xindex>
// CHECK:     %[[DIM2_IS_1:.*]] = arith.cmpi eq, %[[EXTRACTED_2]], %[[C1_0]] : index
// CHECK:     %[[MAX_DIM2:.*]] = arith.select %[[DIM2_IS_1]], %[[DIM1]], %[[EXTRACTED_2]] : index
// CHECK:   }
// CHECK:   %[[OUT_BOUND_0:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF0]] : index
// CHECK:   %[[REDUCTION_0:.*]] = scf.if %[[OUT_BOUND_0]] -> (i1) {
// CHECK:     scf.yield %[[ALL_SO_FAR]] : i1
// CHECK:   } else {
// CHECK:     %[[SHIFTED:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF0]] : index
// CHECK:     %[[EXTRACTED:.*]] = tensor.extract %[[ARG0]]{{\[}}%[[SHIFTED]]] : tensor<2xindex>
// CHECK:     %[[EQUALS_1:.*]] = arith.cmpi eq, %[[EXTRACTED]], %[[C1]] : index
// CHECK:     %[[EQUALS_BROADCASTED:.*]] = arith.cmpi eq, %[[EXTRACTED]], %[[DIM2]] : index
// CHECK:     %[[GOOD:.*]] = arith.ori %[[EQUALS_1]], %[[EQUALS_BROADCASTED]] : i1
// CHECK:     %[[AND_REDUCTION:.*]] = arith.andi %[[ALL_SO_FAR]], %[[GOOD]] : i1
// CHECK:     scf.yield %[[AND_REDUCTION]] : i1
// CHECK:   }
// CHECK:   %[[OUT_BOUND_1:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF1]] : index
// CHECK:   %[[SECOND_REDUCTION:.*]] = scf.if %[[OUT_BOUND_1]] -> (i1) {
// CHECK:     scf.yield %[[REDUCTION_0]] : i1
// CHECK:   } else {
// CHECK:     %[[SHIFTED:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF1]] : index
// CHECK:     %[[EXTRACTED:.*]] = tensor.extract %[[ARG1]]{{\[}}%[[SHIFTED]]] : tensor<3xindex>
// CHECK:     %[[EQUALS_1:.*]] = arith.cmpi eq, %[[EXTRACTED]], %[[C1]] : index
// CHECK:     %[[EQUALS_BROADCASTED:.*]] = arith.cmpi eq, %[[EXTRACTED]], %[[DIM2]] : index
// CHECK:     %[[GOOD:.*]] = arith.ori %[[EQUALS_1]], %[[EQUALS_BROADCASTED]] : i1
// CHECK:     %[[AND_REDUCTION:.*]] = arith.andi %[[REDUCTION_0]], %[[GOOD]] : i1
// CHECK:     scf.yield %[[AND_REDUCTION]] : i1
// CHECK:   }
// CHECK:   %[[OUT_BOUND_2:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF2]] : index
// CHECK:   %[[FINAL_RESULT:.*]] = scf.if %[[OUT_BOUND_2]] -> (i1) {
// CHECK:     scf.yield %[[SECOND_REDUCTION]] : i1
// CHECK:   } else {
// CHECK:     %[[SHIFTED:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF2]] : index
// CHECK:     %[[EXTRACTED:.*]] = tensor.extract %[[ARG2]]{{\[}}%[[SHIFTED]]] : tensor<2xindex>
// CHECK:     %[[EQUALS_1:.*]] = arith.cmpi eq, %[[EXTRACTED]], %[[C1]] : index
// CHECK:     %[[EQUALS_BROADCASTED:.*]] = arith.cmpi eq, %[[EXTRACTED]], %[[DIM2]] : index
// CHECK:     %[[GOOD:.*]] = arith.ori %[[EQUALS_1]], %[[EQUALS_BROADCASTED]] : i1
// CHECK:     %[[AND_REDUCTION:.*]] = arith.andi %[[SECOND_REDUCTION]], %[[GOOD]] : i1
// CHECK:     scf.yield %[[AND_REDUCTION]] : i1
// CHECK:   }
// CHECK:   scf.yield %[[FINAL_RESULT]] : i1

// CHECK: %[[RESULT:.*]] = shape.cstr_require %[[ALL_RESULT]], "required broadcastable shapes"
// CHECK: return %[[RESULT]] : !shape.witness
// CHECK: }

// -----

// Lower `broadcast` of three extent tensors with different static extents to a
// `tensor.generate` sized by the maximum rank.
func.func @broadcast_3_shapes_different_extents(%a : tensor<2xindex>,
                                                %b : tensor<3xindex>,
                                                %c : tensor<2xindex>) {
// CHECK-LABEL: func @broadcast_3_shapes_different_extents(
// CHECK-SAME: %[[ARG0:.*]]: tensor<2xindex>,
// CHECK-SAME: %[[ARG1:.*]]: tensor<3xindex>,
// CHECK-SAME: %[[ARG2:.*]]: tensor<2xindex>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[RANK0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<2xindex>
// CHECK: %[[RANK1:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<3xindex>
// CHECK: %[[RANK2:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<2xindex>
// CHECK: %[[MAX0:.*]] = arith.maxui %[[RANK1]], %[[RANK0]] : index
// CHECK: %[[MAX_RANK:.*]] = arith.maxui %[[RANK2]], %[[MAX0]] : index
// CHECK: %[[DIM_DIFF0:.*]] = arith.subi %[[MAX_RANK]], %[[RANK0]] : index
// CHECK: %[[DIM_DIFF1:.*]] = arith.subi %[[MAX_RANK]], %[[RANK1]] : index
// CHECK: %[[DIM_DIFF2:.*]] = arith.subi %[[MAX_RANK]], %[[RANK2]] : index
// CHECK: %[[RESULT:.*]] = tensor.generate %[[MAX_RANK]] {
// CHECK: ^bb0(%[[IDX:.*]]: index):
// CHECK:   %[[C1:.*]] = arith.constant 1 : index
// CHECK:   %[[OUTBOUNDS0:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF0]] : index
// CHECK:   %[[DIM0:.*]] = scf.if %[[OUTBOUNDS0]] -> (index) {
// CHECK:     scf.yield %[[C1]] : index
// CHECK:   } else {
// CHECK:     %[[IDX0:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF0]] : index
// CHECK:     %[[EXTRACTED_0:.*]] = tensor.extract %[[ARG0]]{{\[}}%[[IDX0]]] : tensor<2xindex>
// CHECK:     %[[DIM0_IS_1:.*]] = arith.cmpi eq, %[[EXTRACTED_0]], %[[C1]] : index
// CHECK:     %[[MAX_DIM0:.*]] = arith.select %[[DIM0_IS_1]], %[[C1]], %[[EXTRACTED_0]] : index
// CHECK:   }
// CHECK:   %[[VAL_28:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF1]] : index
// CHECK:   %[[DIM1:.*]] = scf.if %[[VAL_28]] -> (index) {
// CHECK:     scf.yield %[[DIM0]] : index
// CHECK:   } else {
// CHECK:     %[[IDX1:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF1]] : index
// CHECK:     %[[EXTRACTED_1:.*]] = tensor.extract %[[ARG1]]{{\[}}%[[IDX1]]] : tensor<3xindex>
// CHECK:     %[[DIM1_IS_1:.*]] = arith.cmpi eq, %[[EXTRACTED_1]], %[[C1]] : index
// CHECK:     %[[MAX_DIM1:.*]] = arith.select %[[DIM1_IS_1]], %[[DIM0]], %[[EXTRACTED_1]] : index
// CHECK:   }
// CHECK:   %[[VAL_36:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF2]] : index
// CHECK:   %[[DIM2:.*]] = scf.if %[[VAL_36]] -> (index) {
// CHECK:     scf.yield %[[DIM1]] : index
// CHECK:   } else {
// CHECK:     %[[IDX2:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF2]] : index
// CHECK:     %[[EXTRACTED_2:.*]] = tensor.extract %[[ARG2]]{{\[}}%[[IDX2]]] : tensor<2xindex>
// CHECK:     %[[DIM2_IS_1:.*]] = arith.cmpi eq, %[[EXTRACTED_2]], %[[C1]] : index
// CHECK:     %[[MAX_DIM2:.*]] = arith.select %[[DIM2_IS_1]], %[[DIM1]], %[[EXTRACTED_2]] : index
// CHECK:   }
// CHECK:   tensor.yield %[[DIM2]] : index
// CHECK: } : tensor<?xindex>
// CHECK: return
// CHECK: }
  %0 = shape.broadcast %a, %b, %c
      : tensor<2xindex>, tensor<3xindex>, tensor<2xindex> -> tensor<?xindex>
  return
}

// -----

// Lower `broadcast` with a statically shaped result: the dynamically shaped
// tensor is cast to the requested `tensor<3xindex>`.
// CHECK-LABEL: @broadcast_to_known_rank
func.func @broadcast_to_known_rank(%a : tensor<1xindex>, %b : tensor<3xindex>)
    -> tensor<3xindex> {
  // CHECK: %[[RES:.*]] = tensor.cast %{{.*}} : tensor<?xindex> to tensor<3xindex>
  // CHECK: return %[[RES]] : tensor<3xindex>
  %0 = shape.broadcast %a, %b : tensor<1xindex>, tensor<3xindex> -> tensor<3xindex>
  return %0 : tensor<3xindex>
}

// -----

// Lower `split_at`
// CHECK-LABEL: @split_at
// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>, %[[INDEX:.*]]: index
func.func @split_at(%shape: tensor<?xindex>, %index: index) -> (tensor<?xindex>, tensor<?xindex>) {
  // CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index
  // CHECK-NEXT: %[[RANK:.*]] = tensor.dim %[[SHAPE]], %[[C0]] : tensor<?xindex>
  // CHECK-NEXT: %[[POSINDEX:.*]] = arith.addi %[[INDEX]], %[[RANK]] : index
  // CHECK-NEXT: %[[ISNEG:.*]] = arith.cmpi slt, %[[INDEX]], %[[C0]] : index
  // CHECK-NEXT: %[[SELECT:.*]] = arith.select %[[ISNEG]], %[[POSINDEX]], %[[INDEX]] : index
  // CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
  // CHECK-NEXT: %[[HEAD:.*]] = tensor.extract_slice %[[SHAPE]][%[[C0]]] [%[[SELECT]]] [%[[C1]]] : tensor<?xindex> to tensor<?xindex>
  // CHECK-NEXT: %[[TAIL_SIZE:.*]] = arith.subi %[[RANK]], %[[SELECT]] : index
  // CHECK-NEXT: %[[TAIL:.*]] = tensor.extract_slice %[[SHAPE]][%[[SELECT]]] [%[[TAIL_SIZE]]] [%[[C1]]] : tensor<?xindex> to tensor<?xindex>
  // CHECK-NEXT: return %[[HEAD]], %[[TAIL]] : tensor<?xindex>, tensor<?xindex>
  %head, %tail = "shape.split_at"(%shape, %index) : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
  return %head, %tail : tensor<?xindex>, tensor<?xindex>
}