// RUN: mlir-opt --split-input-file --convert-shape-to-std --verify-diagnostics %s | FileCheck %s

// Lower binary ops.
// CHECK-LABEL: @binary_ops
// CHECK-SAME: (%[[LHS:.*]]: index, %[[RHS:.*]]: index)
func.func @binary_ops(%lhs : index, %rhs : index) {
  // CHECK: arith.addi %[[LHS]], %[[RHS]] : index
  %sum = shape.add %lhs, %rhs : index, index -> index
  // CHECK: arith.muli %[[LHS]], %[[RHS]] : index
  %product = shape.mul %lhs, %rhs : index, index -> index
  return
}

// -----

// Don't lower binary ops when they operate on `shape.size`.
// CHECK-LABEL: @binary_ops_on_size
// CHECK-SAME: (%[[LHS:.*]]: !shape.size, %[[RHS:.*]]: !shape.size)
func.func @binary_ops_on_size(%lhs : !shape.size, %rhs : !shape.size) {
  // CHECK: shape.add %[[LHS]], %[[RHS]] : !shape.size, !shape.size -> !shape.size
  // CHECK: shape.mul %[[LHS]], %[[RHS]] : !shape.size, !shape.size -> !shape.size
  %sum = shape.add %lhs, %rhs : !shape.size, !shape.size -> !shape.size
  %prod = shape.mul %lhs, %rhs : !shape.size, !shape.size -> !shape.size
  return
}

// -----

// Lower `rank` to `tensor.dim` of the first dimension.
// CHECK-LABEL: @rank
// CHECK-SAME: (%[[SHAPE:.*]]: tensor<?xindex>) -> index
func.func @rank(%shape : tensor<?xindex>) -> index {
  // CHECK: %[[C0:.*]] = arith.constant 0 : index
  // CHECK: %[[RESULT:.*]] = tensor.dim %[[SHAPE]], %[[C0]]
  // CHECK: return %[[RESULT]] : index
  %rank = shape.rank %shape : tensor<?xindex> -> index
  return %rank : index
}

// -----

// Don't lower `get_extent` when it operates on `shape.size` values.
// CHECK-LABEL: @get_extent
func.func @get_extent(%shape : tensor<?xindex>, %idx : !shape.size) -> !shape.size {
  // CHECK: shape.get_extent
  %result = shape.get_extent %shape, %idx
      : tensor<?xindex>, !shape.size -> !shape.size
  return %result : !shape.size
}

// -----

// Don't lower `rank` when the operand type (`!shape.shape`) is not error-free.
// CHECK-LABEL: @rank
func.func @rank(%shape : !shape.shape) {
  // CHECK: shape.rank
  %rank = shape.rank %shape : !shape.shape -> !shape.size
  return
}

// -----

// Express `shape.dim` as `tensor.dim` when valid.
// CHECK-LABEL: @dim
// CHECK-SAME:  (%[[ARG:.*]]: tensor<2x3xf32>, %[[IDX:.*]]: index) -> index
func.func @dim(%arg : tensor<2x3xf32>, %idx : index) -> index {
  // CHECK: %[[RESULT:.*]] = tensor.dim %[[ARG]], %[[IDX]] : tensor<2x3xf32>
  // CHECK: return %[[RESULT]] : index
  %result = shape.dim %arg, %idx : tensor<2x3xf32>, index -> index
  return %result : index
}

// -----

// Express `get_extent` as `tensor.dim` when it operates directly on the result
// of a `shape_of` operation.
// CHECK-LABEL: @get_extent_shape_of
// CHECK-SAME:  (%[[ARG:.*]]: tensor<2x3xf32>, %[[IDX:.*]]: index) -> index
func.func @get_extent_shape_of(%arg : tensor<2x3xf32>, %idx : index) -> index {
  // CHECK: %[[RESULT:.*]] = tensor.dim %[[ARG]], %[[IDX]] : tensor<2x3xf32>
  // CHECK: return %[[RESULT]] : index
  %shape = shape.shape_of %arg : tensor<2x3xf32> -> tensor<?xindex>
  %result = shape.get_extent %shape, %idx : tensor<?xindex>, index -> index
  return %result : index
}

// -----

// Express `get_extent` as `tensor.extract`.
// CHECK-LABEL: @get_extent_from_extent_tensor
// CHECK-SAME: (%[[EXTENTS:.*]]: tensor<?xindex>, %[[IDX:.*]]: index) -> index
func.func @get_extent_from_extent_tensor(%extents : tensor<?xindex>, %idx : index)
    -> index {
  // CHECK: %[[RESULT:.*]] = tensor.extract %[[EXTENTS]][%[[IDX]]] : tensor<?xindex>
  // CHECK: return %[[RESULT]] : index
  %result = shape.get_extent %extents, %idx : tensor<?xindex>, index -> index
  return %result : index
}

// -----

// Lower `const_shape` to `tensor.from_elements`.
// CHECK-LABEL: @const_shape
// CHECK-SAME: () -> tensor<3xindex>
func.func @const_shape() -> tensor<3xindex> {
  // CHECK: %[[C1:.*]] = arith.constant 1 : index
  // CHECK: %[[C2:.*]] = arith.constant 2 : index
  // CHECK: %[[C3:.*]] = arith.constant 3 : index
  // CHECK: %[[TENSOR3:.*]] = tensor.from_elements %[[C1]], %[[C2]], %[[C3]]
  // CHECK: %[[RESULT:.*]] = tensor.cast %[[TENSOR3]] : tensor<3xindex> to tensor<3xindex>
  // CHECK: return %[[RESULT]] : tensor<3xindex>
  %shape = shape.const_shape [1, 2, 3] : tensor<3xindex>
  return %shape : tensor<3xindex>
}
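
// Note: the `tensor<3xindex> to tensor<3xindex>` cast checked above is a no-op
// that later canonicalization is expected to fold away.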

// -----

// Lower `const_shape` for an empty shape (rank 0).
// CHECK-LABEL: func @const_shape_zero_elements
// CHECK-SAME: () -> tensor<0xindex>
func.func @const_shape_zero_elements() -> tensor<0xindex> {
  // CHECK: %[[TENSOR:.*]] = tensor.from_elements : tensor<0xindex>
  // CHECK: %[[RESULT:.*]] = tensor.cast %[[TENSOR]] : tensor<0xindex> to tensor<0xindex>
  // CHECK: return %[[RESULT]] : tensor<0xindex>
  %shape = shape.const_shape [] : tensor<0xindex>
  return %shape : tensor<0xindex>
}

// -----

// Lower `any` to its first operand.
// CHECK-LABEL: @any_of_three
// CHECK-SAME:  (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>) -> tensor<?xindex>
func.func @any_of_three(%a : tensor<?xindex>,
                   %b : tensor<?xindex>,
                   %c : tensor<?xindex>) -> tensor<?xindex> {
  // CHECK: return %[[A]] : tensor<?xindex>
  %result = "shape.any"(%a, %b, %c) : (tensor<?xindex>, tensor<?xindex>, tensor<?xindex>) -> tensor<?xindex>
  return %result : tensor<?xindex>
}

// -----

// Lower `any` to its first operand.
// CHECK-LABEL: @any_of_one
// CHECK-SAME:  (%[[A:.*]]: tensor<?xindex>) -> tensor<?xindex>
func.func @any_of_one(%a : tensor<?xindex>) -> tensor<?xindex> {
  // CHECK: return %[[A]] : tensor<?xindex>
  %result = "shape.any"(%a) : (tensor<?xindex>) -> tensor<?xindex>
  return %result : tensor<?xindex>
}

// -----

// Lower `const_size` to `arith.constant`.
// CHECK-LABEL: @const_size
func.func @const_size() -> index {
  // CHECK: %[[RES:.*]] = arith.constant 42 : index
  %size = shape.const_size 42
  %result = shape.size_to_index %size : !shape.size
  // CHECK: return %[[RES]]
  return %result : index
}

// -----

// Lower `to_extent_tensor` to `tensor.cast`.
// Fold `to_extent_tensor` when the operand is already an extent tensor.
// CHECK-LABEL: @to_extent_tensor
// CHECK-SAME: (%[[ARG:.*]]: tensor<?xindex>
func.func @to_extent_tensor(%arg: tensor<?xindex>) -> tensor<3xindex> {
  // CHECK-NOT: to_extent_tensor
  // CHECK: %[[RES:.*]] = tensor.cast %[[ARG]] : tensor<?xindex> to tensor<3xindex>
  %casted = shape.to_extent_tensor %arg : tensor<?xindex> -> tensor<3xindex>
  // CHECK: return %[[RES]]
  return %casted : tensor<3xindex>
}

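// -----

// Lower `shape.reduce` to an `scf.for` loop that threads the accumulator
// through `iter_args`.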
// CHECK-LABEL: @shape_reduce
// CHECK-SAME:  (%[[SHAPE:.*]]: tensor<?xindex>) -> index
func.func @shape_reduce(%shape : tensor<?xindex>) -> index {
  %init = arith.constant 1 : index
  %num_elements = shape.reduce(%shape, %init) : tensor<?xindex> -> index {
    ^bb0(%index : index, %extent : index, %acc: index):
      %new_acc = arith.muli %acc, %extent : index
      shape.yield %new_acc : index
  }
  return %num_elements : index
}
// CHECK-NEXT: %[[INIT:.*]] = arith.constant 1 : index
// CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
// CHECK-NEXT: %[[RANK:.*]] = tensor.dim %[[SHAPE]], %[[C0]] : tensor<?xindex>
// CHECK-NEXT: %[[RESULT:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] iter_args(%[[ACC:.*]] = %[[INIT]]) -> (index)
// CHECK-NEXT:   %[[EXTENT:.*]] = tensor.extract %[[SHAPE]][%[[I]]]
// CHECK-NEXT:   %[[NEW_ACC:.*]] = arith.muli %[[ACC]], %[[EXTENT]] : index
// CHECK-NEXT:   scf.yield %[[NEW_ACC]] : index
// CHECK-NEXT: }
// CHECK-NEXT: return %[[RESULT]] : index

// -----

// Don't lower `shape_of` when the result type is `!shape.shape`.
// CHECK-LABEL: @shape_of
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
func.func @shape_of(%arg : tensor<*xf32>) {
  // CHECK: shape.shape
  %shape = shape.shape_of %arg : tensor<*xf32> -> !shape.shape
  return
}

// -----

// Lower `shape_of` for unranked tensors.
// CHECK-LABEL: @shape_of_unranked
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
func.func @shape_of_unranked(%arg : tensor<*xf32>) {
  // CHECK: %[[RANK:.*]] = tensor.rank %[[ARG]] : tensor<*xf32>
  // CHECK: %[[SHAPE:.*]] = tensor.generate %[[RANK]] {
  // CHECK: ^bb0(%[[I:.*]]: index):
  // CHECK:   %[[EXTENT:.*]] = tensor.dim %[[ARG]], %[[I]] : tensor<*xf32>
  // CHECK:   yield %[[EXTENT]] : index
  // CHECK: } : tensor<?xindex>
  %shape = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
  return
}

// -----

// Don't lower `shape_of` for a statically shaped tensor when the result type
// is `!shape.shape`.
// CHECK-LABEL: @shape_of_stat
// CHECK-SAME: (%[[ARG:.*]]: tensor<1x2x3xf32>)
func.func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
  // CHECK: shape.shape_of %[[ARG]] : tensor<1x2x3xf32> -> !shape.shape
  %shape = shape.shape_of %arg : tensor<1x2x3xf32> -> !shape.shape
  return
}

// -----

// Lower `shape_of` for a statically shaped tensor.
// CHECK-LABEL: @shape_of_stat
// CHECK-SAME: (%[[ARG:.*]]: tensor<1x2x3xf32>)
func.func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
  // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
  // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor.from_elements %[[C1]], %[[C2]], %[[C3]] : tensor<3xindex>
  %shape = shape.shape_of %arg : tensor<1x2x3xf32> -> tensor<?xindex>
  return
}

// -----

// Lower `shape_of` for a 0-D tensor.
// CHECK-LABEL: @shape_of_zero_d
// CHECK-SAME: (%[[ARG:.*]]: tensor<f32>)
func.func @shape_of_zero_d(%arg : tensor<f32>) {
  // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor.from_elements : tensor<0xindex>
  %shape = shape.shape_of %arg : tensor<f32> -> tensor<?xindex>
  return
}

// -----

// Lower `shape_of` for a dynamically shaped tensor.
// CHECK-LABEL: @shape_of_dyn
// CHECK-SAME: (%[[ARG:.*]]: tensor<1x5x?xf32>)
func.func @shape_of_dyn(%arg : tensor<1x5x?xf32>) {
  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
  // CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
  // CHECK-DAG: %[[DYN_DIM:.*]] = tensor.dim %[[ARG]], %[[C2]] : tensor<1x5x?xf32>
  // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor.from_elements %[[C1]], %[[C5]], %[[DYN_DIM]] : tensor<3xindex>
  %shape = shape.shape_of %arg : tensor<1x5x?xf32> -> tensor<?xindex>
  return
}

// -----

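// Lower `shape_eq` of two extent tensors: guard on equal ranks with `scf.if`,
// then accumulate extent-wise equality into a conjunction with `scf.for`.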
// CHECK-LABEL:  @shape_eq
// CHECK-SAME:   (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>) -> i1
func.func @shape_eq(%a : tensor<?xindex>, %b : tensor<?xindex>) -> i1 {
  // CHECK: %[[C0:.*]] = arith.constant 0 : index
  // CHECK: %[[RANK_A:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?xindex>
  // CHECK: %[[RANK_B:.*]] = tensor.dim %[[B]], %[[C0]] : tensor<?xindex>
  // CHECK: %[[RANK_EQ:.*]] = arith.cmpi eq, %[[RANK_A]], %[[RANK_B]]
  // CHECK: %[[SHAPE_EQ:.*]] = scf.if %[[RANK_EQ]] -> (i1) {
  // CHECK:   %[[C1:.*]] = arith.constant 1 : index
  // CHECK:   %[[INIT:.*]] = arith.constant true
  // CHECK:   %[[SHAPE_EQ_INNER:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK_A]] step %[[C1]] iter_args(%[[CONJ:.*]] = %[[INIT]]) -> (i1) {
  // CHECK:     %[[EXTENT_A:.*]] = tensor.extract %[[A]][%[[I]]] : tensor<?xindex>
  // CHECK:     %[[EXTENT_B:.*]] = tensor.extract %[[B]][%[[I]]] : tensor<?xindex>
  // CHECK:     %[[EXTENT_EQ:.*]] = arith.cmpi eq, %[[EXTENT_A]], %[[EXTENT_B]]
  // CHECK:     %[[CONJ_NEXT:.*]] = arith.andi %[[CONJ]], %[[EXTENT_EQ]]
  // CHECK:     scf.yield %[[CONJ_NEXT]] : i1
  // CHECK:   }
  // CHECK:   scf.yield %[[SHAPE_EQ_INNER]] : i1
  // CHECK: } else {
  // CHECK:   %[[SHAPE_EQ_INNER:.*]] = arith.constant false
  // CHECK:   scf.yield %[[SHAPE_EQ_INNER]] : i1
  // CHECK: }
  // CHECK: return %[[SHAPE_EQ]] : i1
  %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
  return %result : i1
}

// -----

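// Lower a variadic `shape_eq` as a conjunction of pairwise comparisons of the
// first operand against each remaining operand.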
// CHECK-LABEL:  @shape_eq
// CHECK-SAME:   (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>) -> i1
func.func @shape_eq(%a : tensor<?xindex>, %b : tensor<?xindex>, %c : tensor<?xindex>) -> i1 {
  // CHECK: %[[C0:.*]] = arith.constant 0 : index
  // CHECK: %[[RANK_A:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?xindex>
  // CHECK: %[[RANK_B:.*]] = tensor.dim %[[B]], %[[C0]] : tensor<?xindex>
  // CHECK: %[[RANK_EQ:.*]] = arith.cmpi eq, %[[RANK_A]], %[[RANK_B]]
  // CHECK: %[[SHAPE_EQ:.*]] = scf.if %[[RANK_EQ]] -> (i1) {
  // CHECK:   %[[C1:.*]] = arith.constant 1 : index
  // CHECK:   %[[INIT:.*]] = arith.constant true
  // CHECK:   %[[SHAPE_EQ_INNER:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK_A]] step %[[C1]] iter_args(%[[CONJ:.*]] = %[[INIT]]) -> (i1) {
  // CHECK:     %[[EXTENT_A:.*]] = tensor.extract %[[A]][%[[I]]] : tensor<?xindex>
  // CHECK:     %[[EXTENT_B:.*]] = tensor.extract %[[B]][%[[I]]] : tensor<?xindex>
  // CHECK:     %[[EXTENT_EQ:.*]] = arith.cmpi eq, %[[EXTENT_A]], %[[EXTENT_B]]
  // CHECK:     %[[CONJ_NEXT:.*]] = arith.andi %[[CONJ]], %[[EXTENT_EQ]]
  // CHECK:     scf.yield %[[CONJ_NEXT]] : i1
  // CHECK:   }
  // CHECK:   scf.yield %[[SHAPE_EQ_INNER]] : i1
  // CHECK: } else {
  // CHECK:   %[[SHAPE_EQ_INNER:.*]] = arith.constant false
  // CHECK:   scf.yield %[[SHAPE_EQ_INNER]] : i1
  // CHECK: }
  // CHECK: %[[RANK_C:.*]] = tensor.dim %[[C]], %[[C0]] : tensor<?xindex>
  // CHECK: %[[RANK_EQ:.*]] = arith.cmpi eq, %[[RANK_A]], %[[RANK_C]]
  // CHECK: %[[SHAPE_EQ2:.*]] = scf.if %[[RANK_EQ]] -> (i1) {
  // CHECK:   %[[C1:.*]] = arith.constant 1 : index
  // CHECK:   %[[INIT:.*]] = arith.constant true
  // CHECK:   %[[SHAPE_EQ_INNER:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK_A]] step %[[C1]] iter_args(%[[CONJ:.*]] = %[[INIT]]) -> (i1) {
  // CHECK:     %[[EXTENT_A:.*]] = tensor.extract %[[A]][%[[I]]] : tensor<?xindex>
  // CHECK:     %[[EXTENT_C:.*]] = tensor.extract %[[C]][%[[I]]] : tensor<?xindex>
  // CHECK:     %[[EXTENT_EQ:.*]] = arith.cmpi eq, %[[EXTENT_A]], %[[EXTENT_C]]
  // CHECK:     %[[CONJ_NEXT:.*]] = arith.andi %[[CONJ]], %[[EXTENT_EQ]]
  // CHECK:     scf.yield %[[CONJ_NEXT]] : i1
  // CHECK:   }
  // CHECK:   scf.yield %[[SHAPE_EQ_INNER]] : i1
  // CHECK: } else {
  // CHECK:   %[[SHAPE_EQ_INNER:.*]] = arith.constant false
  // CHECK:   scf.yield %[[SHAPE_EQ_INNER]] : i1
  // CHECK: }
  // CHECK: %[[RESULT:.*]] = arith.andi %[[SHAPE_EQ]], %[[SHAPE_EQ2]] : i1
  // CHECK: return %[[RESULT]] : i1
  %result = shape.shape_eq %a, %b, %c : tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
  return %result : i1
}

// -----

// Don't lower `shape.broadcast` if a `shape.shape` type is involved.
// CHECK-LABEL: @broadcast
func.func @broadcast(%a : tensor<?xindex>, %b : !shape.shape) -> !shape.shape {
  // CHECK: shape.broadcast
  %c = shape.broadcast %a, %b : tensor<?xindex>, !shape.shape -> !shape.shape
  return %c : !shape.shape
}

// -----

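// Lower `shape.is_broadcastable` to a loop over the maximum rank of the
// operands: for each dimension, compute the broadcasted extent (missing
// leading extents are treated as 1) and check that every operand extent is
// either 1 or equal to it. E.g. the extents [4, 1] and [3, 4, 2] are
// broadcastable, whereas [4, 1] and [3, 2, 2] are not.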
func.func @try_is_broadcastable(%a : tensor<2xindex>, %b : tensor<3xindex>, %c : tensor<2xindex>) -> i1 {
  %0 = shape.is_broadcastable %a, %b, %c : tensor<2xindex>, tensor<3xindex>, tensor<2xindex>
  return %0 : i1
}
// CHECK-LABEL: @try_is_broadcastable
// CHECK-SAME:          %[[ARG0:.*]]: tensor<2xindex>,
// CHECK-SAME:          %[[ARG1:.*]]: tensor<3xindex>,
// CHECK-SAME:          %[[ARG2:.*]]: tensor<2xindex>)
// CHECK:           %[[C0:.*]] = arith.constant 0 : index
// CHECK:           %[[C1:.*]] = arith.constant 1 : index
// CHECK:           %[[RANK0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<2xindex>
// CHECK:           %[[RANK1:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<3xindex>
// CHECK:           %[[RANK2:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<2xindex>
// CHECK:           %[[MAX0:.*]] = arith.maxui %[[RANK1]], %[[RANK0]] : index
// CHECK:           %[[MAX_RANK:.*]] = arith.maxui %[[RANK2]], %[[MAX0]] : index
// CHECK:           %[[DIM_DIFF0:.*]] = arith.subi %[[MAX_RANK]], %[[RANK0]] : index
// CHECK:           %[[DIM_DIFF1:.*]] = arith.subi %[[MAX_RANK]], %[[RANK1]] : index
// CHECK:           %[[DIM_DIFF2:.*]] = arith.subi %[[MAX_RANK]], %[[RANK2]] : index
// CHECK:           %[[TRUE:.*]] = arith.constant true
// CHECK:           %[[ALL_RESULT:.*]] = scf.for %[[IDX:.*]] = %[[C0]] to %[[MAX_RANK]] step %[[C1]] iter_args(%[[ALL_SO_FAR:.*]] = %[[TRUE]]) -> (i1) {
// CHECK:             %[[C1_0:.*]] = arith.constant 1 : index
// CHECK:             %[[OUTBOUNDS0:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF0]] : index
// CHECK:             %[[DIM0:.*]] = scf.if %[[OUTBOUNDS0]] -> (index) {
// CHECK:               scf.yield %[[C1_0]] : index
// CHECK:             } else {
// CHECK:               %[[IDX0:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF0]] : index
// CHECK:               %[[EXTRACTED_0:.*]] = tensor.extract %[[ARG0]]{{\[}}%[[IDX0]]] : tensor<2xindex>
// CHECK:               %[[DIM0_IS_1:.*]] = arith.cmpi eq, %[[EXTRACTED_0]], %[[C1_0]] : index
// CHECK:               %[[MAX_DIM0:.*]] = arith.select %[[DIM0_IS_1]], %[[C1_0]], %[[EXTRACTED_0]] : index
// CHECK:             }
// CHECK:             %[[VAL_28:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF1]] : index
// CHECK:             %[[DIM1:.*]] = scf.if %[[VAL_28]] -> (index) {
// CHECK:               scf.yield %[[DIM0]] : index
// CHECK:             } else {
// CHECK:               %[[IDX1:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF1]] : index
// CHECK:               %[[EXTRACTED_1:.*]] = tensor.extract %[[ARG1]]{{\[}}%[[IDX1]]] : tensor<3xindex>
// CHECK:               %[[DIM1_IS_1:.*]] = arith.cmpi eq, %[[EXTRACTED_1]], %[[C1_0]] : index
// CHECK:               %[[MAX_DIM1:.*]] = arith.select %[[DIM1_IS_1]], %[[DIM0]], %[[EXTRACTED_1]] : index
// CHECK:             }
// CHECK:             %[[VAL_36:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF2]] : index
// CHECK:             %[[DIM2:.*]] = scf.if %[[VAL_36]] -> (index) {
// CHECK:               scf.yield %[[DIM1]] : index
// CHECK:             } else {
// CHECK:               %[[IDX2:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF2]] : index
// CHECK:               %[[EXTRACTED_2:.*]] = tensor.extract %[[ARG2]]{{\[}}%[[IDX2]]] : tensor<2xindex>
// CHECK:               %[[DIM2_IS_1:.*]] = arith.cmpi eq, %[[EXTRACTED_2]], %[[C1_0]] : index
// CHECK:               %[[MAX_DIM2:.*]] = arith.select %[[DIM2_IS_1]], %[[DIM1]], %[[EXTRACTED_2]] : index
// CHECK:             }
// CHECK:             %[[OUT_BOUND_0:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF0]] : index
// CHECK:             %[[REDUCTION_0:.*]] = scf.if %[[OUT_BOUND_0]] -> (i1) {
// CHECK:                scf.yield %[[ALL_SO_FAR]] : i1
// CHECK:             } else {
// CHECK:                %[[SHIFTED:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF0]] : index
// CHECK:                %[[EXTRACTED:.*]] = tensor.extract %arg0[%[[SHIFTED]]] : tensor<2xindex>
// CHECK:                %[[EQUALS_1:.*]] = arith.cmpi eq, %[[EXTRACTED]], %c1 : index
// CHECK:                %[[EQUALS_BROADCASTED:.*]] = arith.cmpi eq, %[[EXTRACTED]], %[[DIM2]] : index
// CHECK:                %[[GOOD:.*]] = arith.ori %[[EQUALS_1]], %[[EQUALS_BROADCASTED]] : i1
// CHECK:                %[[AND_REDUCTION:.*]] = arith.andi %[[ALL_SO_FAR]], %[[GOOD]] : i1
// CHECK:                scf.yield %[[AND_REDUCTION]] : i1
// CHECK:             }
// CHECK:             %[[OUT_BOUND_1:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF1]] : index
// CHECK:             %[[SECOND_REDUCTION:.*]] = scf.if %[[OUT_BOUND_1]] -> (i1) {
// CHECK:                scf.yield %[[REDUCTION_0]] : i1
// CHECK:             } else {
// CHECK:                %[[SHIFTED:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF1]] : index
// CHECK:                %[[EXTRACTED:.*]] = tensor.extract %arg1[%[[SHIFTED]]] : tensor<3xindex>
// CHECK:                %[[EQUALS_1:.*]] = arith.cmpi eq, %[[EXTRACTED]], %c1 : index
// CHECK:                %[[EQUALS_BROADCASTED:.*]] = arith.cmpi eq, %[[EXTRACTED]], %[[DIM2]] : index
// CHECK:                %[[GOOD:.*]] = arith.ori %[[EQUALS_1]], %[[EQUALS_BROADCASTED]] : i1
// CHECK:                %[[AND_REDUCTION:.*]] = arith.andi %[[REDUCTION_0]], %[[GOOD]] : i1
// CHECK:                scf.yield %[[AND_REDUCTION]] : i1
// CHECK:             }
// CHECK:             %[[OUT_BOUND_2:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF2]] : index
// CHECK:             %[[FINAL_RESULT:.*]] = scf.if %[[OUT_BOUND_2]] -> (i1) {
// CHECK:                scf.yield %[[SECOND_REDUCTION]] : i1
// CHECK:             } else {
// CHECK:                %[[SHIFTED:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF2]] : index
// CHECK:                %[[EXTRACTED:.*]] = tensor.extract %arg2[%[[SHIFTED]]] : tensor<2xindex>
// CHECK:                %[[EQUALS_1:.*]] = arith.cmpi eq, %[[EXTRACTED]], %c1 : index
// CHECK:                %[[EQUALS_BROADCASTED:.*]] = arith.cmpi eq, %[[EXTRACTED]], %[[DIM2]] : index
// CHECK:                %[[GOOD:.*]] = arith.ori %[[EQUALS_1]], %[[EQUALS_BROADCASTED]] : i1
// CHECK:                %[[AND_REDUCTION:.*]] = arith.andi %[[SECOND_REDUCTION]], %[[GOOD]] : i1
// CHECK:                scf.yield %[[AND_REDUCTION]] : i1
// CHECK:             }
// CHECK:             scf.yield %[[FINAL_RESULT]] : i1

// -----

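// Lower `shape.cstr_broadcastable` like `is_broadcastable`, then wrap the
// resulting i1 in `shape.cstr_require` to produce the witness.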
func.func @broadcast(%a : tensor<2xindex>, %b : tensor<3xindex>, %c : tensor<2xindex>) -> !shape.witness {
  %0 = shape.cstr_broadcastable %a, %b, %c : tensor<2xindex>, tensor<3xindex>, tensor<2xindex>
  return %0 : !shape.witness
}
// CHECK-LABEL:   func @broadcast(
// CHECK-SAME:          %[[ARG0:.*]]: tensor<2xindex>,
// CHECK-SAME:          %[[ARG1:.*]]: tensor<3xindex>,
// CHECK-SAME:          %[[ARG2:.*]]: tensor<2xindex>)
// CHECK:           %[[C0:.*]] = arith.constant 0 : index
// CHECK:           %[[C1:.*]] = arith.constant 1 : index
// CHECK:           %[[RANK0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<2xindex>
// CHECK:           %[[RANK1:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<3xindex>
// CHECK:           %[[RANK2:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<2xindex>
// CHECK:           %[[MAX0:.*]] = arith.maxui %[[RANK1]], %[[RANK0]] : index
// CHECK:           %[[MAX_RANK:.*]] = arith.maxui %[[RANK2]], %[[MAX0]] : index
// CHECK:           %[[DIM_DIFF0:.*]] = arith.subi %[[MAX_RANK]], %[[RANK0]] : index
// CHECK:           %[[DIM_DIFF1:.*]] = arith.subi %[[MAX_RANK]], %[[RANK1]] : index
// CHECK:           %[[DIM_DIFF2:.*]] = arith.subi %[[MAX_RANK]], %[[RANK2]] : index
// CHECK:           %[[TRUE:.*]] = arith.constant true
// CHECK:           %[[ALL_RESULT:.*]] = scf.for %[[IDX:.*]] = %[[C0]] to %[[MAX_RANK]] step %[[C1]] iter_args(%[[ALL_SO_FAR:.*]] = %[[TRUE]]) -> (i1) {
// CHECK:             %[[C1_0:.*]] = arith.constant 1 : index
// CHECK:             %[[OUTBOUNDS0:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF0]] : index
// CHECK:             %[[DIM0:.*]] = scf.if %[[OUTBOUNDS0]] -> (index) {
// CHECK:               scf.yield %[[C1_0]] : index
// CHECK:             } else {
// CHECK:               %[[IDX0:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF0]] : index
// CHECK:               %[[EXTRACTED_0:.*]] = tensor.extract %[[ARG0]]{{\[}}%[[IDX0]]] : tensor<2xindex>
// CHECK:               %[[DIM0_IS_1:.*]] = arith.cmpi eq, %[[EXTRACTED_0]], %[[C1_0]] : index
// CHECK:               %[[MAX_DIM0:.*]] = arith.select %[[DIM0_IS_1]], %[[C1_0]], %[[EXTRACTED_0]] : index
// CHECK:             }
// CHECK:             %[[VAL_28:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF1]] : index
// CHECK:             %[[DIM1:.*]] = scf.if %[[VAL_28]] -> (index) {
// CHECK:               scf.yield %[[DIM0]] : index
// CHECK:             } else {
// CHECK:               %[[IDX1:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF1]] : index
// CHECK:               %[[EXTRACTED_1:.*]] = tensor.extract %[[ARG1]]{{\[}}%[[IDX1]]] : tensor<3xindex>
// CHECK:               %[[DIM1_IS_1:.*]] = arith.cmpi eq, %[[EXTRACTED_1]], %[[C1_0]] : index
// CHECK:               %[[MAX_DIM1:.*]] = arith.select %[[DIM1_IS_1]], %[[DIM0]], %[[EXTRACTED_1]] : index
// CHECK:             }
// CHECK:             %[[VAL_36:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF2]] : index
// CHECK:             %[[DIM2:.*]] = scf.if %[[VAL_36]] -> (index) {
// CHECK:               scf.yield %[[DIM1]] : index
// CHECK:             } else {
// CHECK:               %[[IDX2:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF2]] : index
// CHECK:               %[[EXTRACTED_2:.*]] = tensor.extract %[[ARG2]]{{\[}}%[[IDX2]]] : tensor<2xindex>
// CHECK:               %[[DIM2_IS_1:.*]] = arith.cmpi eq, %[[EXTRACTED_2]], %[[C1_0]] : index
// CHECK:               %[[MAX_DIM2:.*]] = arith.select %[[DIM2_IS_1]], %[[DIM1]], %[[EXTRACTED_2]] : index
// CHECK:             }
// CHECK:             %[[OUT_BOUND_0:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF0]] : index
// CHECK:             %[[REDUCTION_0:.*]] = scf.if %[[OUT_BOUND_0]] -> (i1) {
// CHECK:                scf.yield %[[ALL_SO_FAR]] : i1
// CHECK:             } else {
// CHECK:                %[[SHIFTED:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF0]] : index
// CHECK:                %[[EXTRACTED:.*]] = tensor.extract %arg0[%[[SHIFTED]]] : tensor<2xindex>
// CHECK:                %[[EQUALS_1:.*]] = arith.cmpi eq, %[[EXTRACTED]], %c1 : index
// CHECK:                %[[EQUALS_BROADCASTED:.*]] = arith.cmpi eq, %[[EXTRACTED]], %[[DIM2]] : index
// CHECK:                %[[GOOD:.*]] = arith.ori %[[EQUALS_1]], %[[EQUALS_BROADCASTED]] : i1
// CHECK:                %[[AND_REDUCTION:.*]] = arith.andi %[[ALL_SO_FAR]], %[[GOOD]] : i1
// CHECK:                scf.yield %[[AND_REDUCTION]] : i1
// CHECK:             }
// CHECK:             %[[OUT_BOUND_1:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF1]] : index
// CHECK:             %[[SECOND_REDUCTION:.*]] = scf.if %[[OUT_BOUND_1]] -> (i1) {
// CHECK:                scf.yield %[[REDUCTION_0]] : i1
// CHECK:             } else {
// CHECK:                %[[SHIFTED:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF1]] : index
// CHECK:                %[[EXTRACTED:.*]] = tensor.extract %arg1[%[[SHIFTED]]] : tensor<3xindex>
// CHECK:                %[[EQUALS_1:.*]] = arith.cmpi eq, %[[EXTRACTED]], %c1 : index
// CHECK:                %[[EQUALS_BROADCASTED:.*]] = arith.cmpi eq, %[[EXTRACTED]], %[[DIM2]] : index
// CHECK:                %[[GOOD:.*]] = arith.ori %[[EQUALS_1]], %[[EQUALS_BROADCASTED]] : i1
// CHECK:                %[[AND_REDUCTION:.*]] = arith.andi %[[REDUCTION_0]], %[[GOOD]] : i1
// CHECK:                scf.yield %[[AND_REDUCTION]] : i1
// CHECK:             }
// CHECK:             %[[OUT_BOUND_2:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF2]] : index
// CHECK:             %[[FINAL_RESULT:.*]] = scf.if %[[OUT_BOUND_2]] -> (i1) {
// CHECK:                scf.yield %[[SECOND_REDUCTION]] : i1
// CHECK:             } else {
// CHECK:                %[[SHIFTED:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF2]] : index
// CHECK:                %[[EXTRACTED:.*]] = tensor.extract %arg2[%[[SHIFTED]]] : tensor<2xindex>
// CHECK:                %[[EQUALS_1:.*]] = arith.cmpi eq, %[[EXTRACTED]], %c1 : index
// CHECK:                %[[EQUALS_BROADCASTED:.*]] = arith.cmpi eq, %[[EXTRACTED]], %[[DIM2]] : index
// CHECK:                %[[GOOD:.*]] = arith.ori %[[EQUALS_1]], %[[EQUALS_BROADCASTED]] : i1
// CHECK:                %[[AND_REDUCTION:.*]] = arith.andi %[[SECOND_REDUCTION]], %[[GOOD]] : i1
// CHECK:                scf.yield %[[AND_REDUCTION]] : i1
// CHECK:             }
// CHECK:             scf.yield %[[FINAL_RESULT]] : i1

// CHECK:           %[[RESULT:.*]] = shape.cstr_require %[[ALL_RESULT]], "required broadcastable shapes"
// CHECK:           return %[[RESULT]] : !shape.witness
// CHECK:         }

// -----

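// Lower a variadic `shape.broadcast` to `tensor.generate`: for each result
// dimension, the operands' right-aligned extents are combined, keeping the
// non-unit extent (missing leading extents count as 1). E.g. the extents
// [4, 1] and [3, 4, 2] broadcast to [3, 4, 2].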
func.func @broadcast_3_shapes_different_extents(%a : tensor<2xindex>,
                                           %b : tensor<3xindex>,
                                           %c : tensor<2xindex>) {
// CHECK-LABEL:   func @broadcast_3_shapes_different_extents(
// CHECK-SAME:          %[[ARG0:.*]]: tensor<2xindex>,
// CHECK-SAME:          %[[ARG1:.*]]: tensor<3xindex>,
// CHECK-SAME:          %[[ARG2:.*]]: tensor<2xindex>) {
// CHECK:           %[[C0:.*]] = arith.constant 0 : index
// CHECK:           %[[RANK0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<2xindex>
// CHECK:           %[[RANK1:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<3xindex>
// CHECK:           %[[RANK2:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<2xindex>
// CHECK:           %[[MAX0:.*]] = arith.maxui %[[RANK1]], %[[RANK0]] : index
// CHECK:           %[[MAX_RANK:.*]] = arith.maxui %[[RANK2]], %[[MAX0]] : index
// CHECK:           %[[DIM_DIFF0:.*]] = arith.subi %[[MAX_RANK]], %[[RANK0]] : index
// CHECK:           %[[DIM_DIFF1:.*]] = arith.subi %[[MAX_RANK]], %[[RANK1]] : index
// CHECK:           %[[DIM_DIFF2:.*]] = arith.subi %[[MAX_RANK]], %[[RANK2]] : index
// CHECK:           %[[RESULT:.*]] = tensor.generate %[[MAX_RANK]]  {
// CHECK:           ^bb0(%[[IDX:.*]]: index):
// CHECK:             %[[C1:.*]] = arith.constant 1 : index
// CHECK:             %[[OUTBOUNDS0:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF0]] : index
// CHECK:             %[[DIM0:.*]] = scf.if %[[OUTBOUNDS0]] -> (index) {
// CHECK:               scf.yield %[[C1]] : index
// CHECK:             } else {
// CHECK:               %[[IDX0:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF0]] : index
// CHECK:               %[[EXTRACTED_0:.*]] = tensor.extract %[[ARG0]]{{\[}}%[[IDX0]]] : tensor<2xindex>
// CHECK:               %[[DIM0_IS_1:.*]] = arith.cmpi eq, %[[EXTRACTED_0]], %[[C1]] : index
// CHECK:               %[[MAX_DIM0:.*]] = arith.select %[[DIM0_IS_1]], %[[C1]], %[[EXTRACTED_0]] : index
// CHECK:             }
// CHECK:             %[[VAL_28:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF1]] : index
// CHECK:             %[[DIM1:.*]] = scf.if %[[VAL_28]] -> (index) {
// CHECK:               scf.yield %[[DIM0]] : index
// CHECK:             } else {
// CHECK:               %[[IDX1:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF1]] : index
// CHECK:               %[[EXTRACTED_1:.*]] = tensor.extract %[[ARG1]]{{\[}}%[[IDX1]]] : tensor<3xindex>
// CHECK:               %[[DIM1_IS_1:.*]] = arith.cmpi eq, %[[EXTRACTED_1]], %[[C1]] : index
// CHECK:               %[[MAX_DIM1:.*]] = arith.select %[[DIM1_IS_1]], %[[DIM0]], %[[EXTRACTED_1]] : index
// CHECK:             }
// CHECK:             %[[VAL_36:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF2]] : index
// CHECK:             %[[DIM2:.*]] = scf.if %[[VAL_36]] -> (index) {
// CHECK:               scf.yield %[[DIM1]] : index
// CHECK:             } else {
// CHECK:               %[[IDX2:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF2]] : index
// CHECK:               %[[EXTRACTED_2:.*]] = tensor.extract %[[ARG2]]{{\[}}%[[IDX2]]] : tensor<2xindex>
// CHECK:               %[[DIM2_IS_1:.*]] = arith.cmpi eq, %[[EXTRACTED_2]], %[[C1]] : index
// CHECK:               %[[MAX_DIM2:.*]] = arith.select %[[DIM2_IS_1]], %[[DIM1]], %[[EXTRACTED_2]] : index
// CHECK:             }
// CHECK:             tensor.yield %[[DIM2]] : index
// CHECK:           } : tensor<?xindex>
// CHECK:           return
// CHECK:         }
  %0 = shape.broadcast %a, %b, %c
      : tensor<2xindex>, tensor<3xindex>, tensor<2xindex> -> tensor<?xindex>
  return
}

// -----

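// Lower `shape.broadcast` with a statically known result rank; the generated
// extent tensor is cast from `tensor<?xindex>` back to `tensor<3xindex>`.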
// CHECK-LABEL: @broadcast_to_known_rank
func.func @broadcast_to_known_rank(%a : tensor<1xindex>, %b : tensor<3xindex>)
    -> tensor<3xindex> {
  // CHECK: %[[RES:.*]] = tensor.cast %{{.*}} : tensor<?xindex> to tensor<3xindex>
  // CHECK: return %[[RES]] : tensor<3xindex>
  %0 = shape.broadcast %a, %b : tensor<1xindex>, tensor<3xindex> -> tensor<3xindex>
  return %0 : tensor<3xindex>
}

// -----

// Lower `split_at`.
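// A negative split index is interpreted relative to the end of the shape:
// splitting [a, b, c] at index -1 yields head [a, b] and tail [c].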
// CHECK-LABEL: @split_at
// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>, %[[INDEX:.*]]: index
func.func @split_at(%shape: tensor<?xindex>, %index: index) -> (tensor<?xindex>, tensor<?xindex>) {
  // CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index
  // CHECK-NEXT: %[[RANK:.*]] = tensor.dim %[[SHAPE]], %[[C0]] : tensor<?xindex>
  // CHECK-NEXT: %[[POSINDEX:.*]] = arith.addi %[[INDEX]], %[[RANK]] : index
  // CHECK-NEXT: %[[ISNEG:.*]] = arith.cmpi slt, %[[INDEX]], %[[C0]] : index
  // CHECK-NEXT: %[[SELECT:.*]] = arith.select %[[ISNEG]], %[[POSINDEX]], %[[INDEX]] : index
  // CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
  // CHECK-NEXT: %[[HEAD:.*]] = tensor.extract_slice %[[SHAPE]][%[[C0]]] [%[[SELECT]]] [%[[C1]]] : tensor<?xindex> to tensor<?xindex>
  // CHECK-NEXT: %[[TAIL_SIZE:.*]] = arith.subi %[[RANK]], %[[SELECT]] : index
  // CHECK-NEXT: %[[TAIL:.*]] = tensor.extract_slice %[[SHAPE]][%[[SELECT]]] [%[[TAIL_SIZE]]] [%[[C1]]] : tensor<?xindex> to tensor<?xindex>
  // CHECK-NEXT: return %[[HEAD]], %[[TAIL]] : tensor<?xindex>, tensor<?xindex>
  %head, %tail = "shape.split_at"(%shape, %index) : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
  return %head, %tail : tensor<?xindex>, tensor<?xindex>
}