xref: /llvm-project/mlir/test/Dialect/Shape/ops.mlir (revision af29db64b2c7091070dd623c81872559657e7b3d)
// Verify that the printed output can be parsed again.
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
// Verify that the generic form can be parsed.
// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s
5
// CHECK-LABEL: shape_num_elements
// Folds a !shape.shape with shape.reduce, starting from a constant size of 1
// and multiplying the running value by each extent in turn.
func.func @shape_num_elements(%shape : !shape.shape) -> !shape.size {
  %one = shape.const_size 1
  %product = shape.reduce(%shape, %one) : !shape.shape -> !shape.size {
    ^bb0(%i : index, %dim : !shape.size, %running : !shape.size):
      %next = shape.mul %running, %dim : !shape.size, !shape.size -> !shape.size
      shape.yield %next : !shape.size
  }
  return %product : !shape.size
}
17
// CHECK-LABEL: extent_tensor_num_elements
// Same reduction shape as @shape_num_elements, but over an extent tensor with
// plain `index` values for the init value, the extents and the accumulator.
func.func @extent_tensor_num_elements(%shape : tensor<?xindex>) -> index {
  %one = arith.constant 1 : index
  %product = shape.reduce(%shape, %one) : tensor<?xindex> -> index {
    ^bb0(%i : index, %dim : index, %running : index):
      %next = shape.mul %running, %dim : index, index -> index
      shape.yield %next : index
  }
  return %product : index
}
28
// Calls @shape_num_elements on the result of a generic-form
// "shape.unknown_shape" op and hands the count to generic-form "shape.print".
func.func @test_shape_num_elements_unknown() {
  %unknown = "shape.unknown_shape"() : () -> !shape.shape
  %count = call @shape_num_elements(%unknown) : (!shape.shape) -> (!shape.size)
  %printed = "shape.print"(%count) : (!shape.size) -> !shape.size
  return
}
35
// Parses shape.const_shape with both supported result types: the opaque
// !shape.shape and a statically sized extent tensor.
func.func @const_shape() {
  %as_shape = shape.const_shape [1, 2, 3] : !shape.shape
  %as_tensor = shape.const_shape [4, 5, 6] : tensor<3xindex>
  return
}
41
// Calls @shape_num_elements on the constant shape [1, 57, 92] and hands the
// count to generic-form "shape.print".
func.func @test_shape_num_elements_fixed() {
  %fixed = shape.const_shape [1, 57, 92] : !shape.shape
  %count = call @shape_num_elements(%fixed) : (!shape.shape) -> (!shape.size)
  %printed = "shape.print"(%count) : (!shape.size) -> !shape.size
  return
}
48
// shape.broadcast of two constant !shape.shape values with 4 and 3 extents.
func.func @test_broadcast_fixed() {
  %lhs = shape.const_shape [10, 1, 57, 92] : !shape.shape
  %rhs = shape.const_shape [4, 57, 92] : !shape.shape
  %bcast = shape.broadcast %lhs, %rhs
      : !shape.shape, !shape.shape -> !shape.shape
  %printed = "shape.print"(%bcast) : (!shape.shape) -> !shape.shape
  return
}
56
// shape.broadcast of two constant extent tensors with 4 and 3 extents,
// typed to produce a tensor<4xindex>.
func.func @test_broadcast_extents() -> tensor<4xindex> {
  %lhs = shape.const_shape [10, 1, 57, 92] : tensor<4xindex>
  %rhs = shape.const_shape [4, 57, 92] : tensor<3xindex>
  %bcast = shape.broadcast %lhs, %rhs
      : tensor<4xindex>, tensor<3xindex> -> tensor<4xindex>
  return %bcast : tensor<4xindex>
}
63
// Despite the name, this exercises generic-form "shape.meet" on two identical
// constant shapes.
func.func @test_shape_any_fixed() {
  %a = shape.const_shape [4, 57, 92] : !shape.shape
  %b = shape.const_shape [4, 57, 92] : !shape.shape
  %met = "shape.meet"(%a, %b) : (!shape.shape, !shape.shape) -> !shape.shape
  %printed = "shape.print"(%met) : (!shape.shape) -> !shape.shape
  return
}
71
// Despite the name, this exercises generic-form "shape.meet" on constant
// shapes that each contain a -1 extent in a different position.
func.func @test_shape_any_unknown() {
  %a = shape.const_shape [4, -1, 92] : !shape.shape
  %b = shape.const_shape [-1, 57, 92] : !shape.shape
  %met = "shape.meet"(%a, %b) : (!shape.shape, !shape.shape) -> !shape.shape
  %printed = "shape.print"(%met) : (!shape.shape) -> !shape.shape
  return
}
79
// Despite the name, this exercises generic-form "shape.meet" on constant
// shapes whose leading extents differ (4 vs. 2).
func.func @test_shape_any_fixed_mismatch() {
  %a = shape.const_shape [4, 57, 92] : !shape.shape
  %b = shape.const_shape [2, 57, 92] : !shape.shape
  %met = "shape.meet"(%a, %b) : (!shape.shape, !shape.shape) -> !shape.shape
  %printed = "shape.print"(%met) : (!shape.shape) -> !shape.shape
  return
}
87
// Parses shape.const_shape with an empty extent list, and with a non-empty
// list in both the !shape.shape and extent-tensor result types.
func.func @test_parse_const_shape() {
  %empty = shape.const_shape [] : !shape.shape
  %shape = shape.const_shape [1, 2, 3] : !shape.shape
  %extents = shape.const_shape [1, 2, 3] : tensor<3xindex>
  return
}
94
// shape.shape_of on a tensor<?xf32>, typed to return an extent tensor.
func.func @test_shape_of(%arg0: tensor<?xf32>) -> tensor<?xindex> {
  %extents = shape.shape_of %arg0 : tensor<?xf32> -> tensor<?xindex>
  return %extents : tensor<?xindex>
}
99
// shape.value_of from a !shape.value_shape to a tensor<?xf32>.
func.func @test_value_of(%arg0: !shape.value_shape) -> tensor<?xf32> {
  %value = shape.value_of %arg0 : tensor<?xf32>
  return %value : tensor<?xf32>
}
104
// Builds witnesses with cstr_broadcastable, cstr_eq, const_witness and
// cstr_require, combines them with shape.assuming_all, and yields a
// generic-form "shape.any" result out of a shape.assuming region.
func.func @test_constraints() {
  %empty = shape.const_shape [] : !shape.shape
  %s123 = shape.const_shape [1, 2, 3] : !shape.shape
  %true = arith.constant true
  %w_bcast = shape.cstr_broadcastable %empty, %s123 : !shape.shape, !shape.shape
  %w_eq = shape.cstr_eq %empty, %s123 : !shape.shape, !shape.shape
  %w_true = shape.const_witness true
  %w_false = shape.const_witness false
  %w_req = shape.cstr_require %true, "msg"
  %w_all = shape.assuming_all %w_bcast, %w_eq, %w_true, %w_false, %w_req
  shape.assuming %w_all -> !shape.shape {
    %any = "shape.any"(%empty, %s123) : (!shape.shape, !shape.shape) -> !shape.shape
    shape.assuming_yield %any : !shape.shape
  }
  return
}
121
// shape.cstr_eq applied directly to two extent tensors.
func.func @eq_on_extent_tensors(%lhs : tensor<?xindex>,
                                %rhs : tensor<?xindex>) {
  %witness = shape.cstr_eq %lhs, %rhs
      : tensor<?xindex>, tensor<?xindex>
  return
}
127
// shape.cstr_broadcastable applied directly to two extent tensors.
func.func @broadcastable_on_extent_tensors(%lhs : tensor<?xindex>,
                                           %rhs : tensor<?xindex>) {
  %witness = shape.cstr_broadcastable %lhs, %rhs
      : tensor<?xindex>, tensor<?xindex>
  return
}
133
// shape.mul with two !shape.size operands, two index operands, and one of each.
func.func @mul(%size_arg : !shape.size, %index_arg : index) {
  %sizes = shape.mul %size_arg, %size_arg : !shape.size, !shape.size -> !shape.size
  %indices = shape.mul %index_arg, %index_arg : index, index -> index
  %mixed = shape.mul %size_arg, %index_arg : !shape.size, index -> !shape.size
  return
}
142
// shape.div with two !shape.size operands, two index operands, and one of each.
func.func @div(%size_arg : !shape.size, %index_arg : index) {
  %sizes = shape.div %size_arg, %size_arg : !shape.size, !shape.size -> !shape.size
  %indices = shape.div %index_arg, %index_arg : index, index -> index
  %mixed = shape.div %size_arg, %index_arg : !shape.size, index -> !shape.size
  return
}
151
// shape.add with two !shape.size operands, two index operands, and one of each.
func.func @add(%size_arg : !shape.size, %index_arg : index) {
  %sizes = shape.add %size_arg, %size_arg : !shape.size, !shape.size -> !shape.size
  %indices = shape.add %index_arg, %index_arg : index, index -> index
  %mixed = shape.add %size_arg, %index_arg : !shape.size, index -> !shape.size
  return
}
160
// CHECK-LABEL: @const_size
// The three shape.const_size results are expected to print with
// value-derived names, with the repeated constant 2 uniqued to %c2_0.
// The label above pins these expectations to this function; without it they
// could be satisfied by constants printed anywhere later in the file.
func.func @const_size() {
  // CHECK:      %c1 = shape.const_size 1
  // CHECK-NEXT: %c2 = shape.const_size 2
  // CHECK-NEXT: %c2_0 = shape.const_size 2
  %0 = shape.const_size 1
  %1 = shape.const_size 2
  %2 = shape.const_size 2
  return
}
170
// shape.to_extent_tensor from a !shape.shape to a tensor<3xindex>.
func.func @test_to_extent_tensor(%arg: !shape.shape) -> tensor<3xindex> {
  %extents = shape.to_extent_tensor %arg : !shape.shape -> tensor<3xindex>
  return %extents : tensor<3xindex>
}
175
// shape.to_extent_tensor whose operand and result are the same tensor type.
func.func @test_identity_to_extent_tensor(%arg: tensor<3xindex>) -> tensor<3xindex> {
  %same = shape.to_extent_tensor %arg : tensor<3xindex> -> tensor<3xindex>
  return %same : tensor<3xindex>
}
180
// shape.from_extent_tensor from a dynamically sized extent tensor.
func.func @test_from_extent_tensor(%arg: tensor<?xindex>) -> !shape.shape {
  %shape = shape.from_extent_tensor %arg : tensor<?xindex>
  return %shape : !shape.shape
}
185
// shape.rank on a !shape.shape, typed to return a !shape.size.
func.func @rank(%shape : !shape.shape) -> !shape.size {
  %r = shape.rank %shape : !shape.shape -> !shape.size
  return %r : !shape.size
}
190
// shape.rank on an extent tensor, typed to return an index.
func.func @rank_on_extent_tensor(%shape : tensor<?xindex>) -> index {
  %r = shape.rank %shape : tensor<?xindex> -> index
  return %r : index
}
195
// shape.shape_eq on two !shape.shape operands.
func.func @shape_eq_on_shapes(%a : !shape.shape, %b : !shape.shape) -> i1 {
  %same = shape.shape_eq %a, %b : !shape.shape, !shape.shape
  return %same : i1
}
200
// shape.shape_eq on two extent-tensor operands.
func.func @shape_eq_on_tensors(%a : tensor<?xindex>, %b : tensor<?xindex>) -> i1 {
  %same = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
  return %same : i1
}
205
// shape.shape_eq comparing an extent tensor with a !shape.shape.
func.func @shape_eq_on_mixed(%a : tensor<?xindex>, %b : !shape.shape) -> i1 {
  %same = shape.shape_eq %a, %b : tensor<?xindex>, !shape.shape
  return %same : i1
}
210
// shape.get_extent on a !shape.shape with a constant !shape.size index of 0.
func.func @get_extent_on_shape(%arg : !shape.shape) -> !shape.size {
  %zero = shape.const_size 0
  %extent = shape.get_extent %arg, %zero : !shape.shape, !shape.size -> !shape.size
  return %extent : !shape.size
}
217
// shape.get_extent on an extent tensor with a constant index of 0.
func.func @get_extent_on_extent_tensor(%arg : tensor<?xindex>) -> index {
  %zero = arith.constant 0 : index
  %extent = shape.get_extent %arg, %zero
      : tensor<?xindex>, index -> index
  return %extent : index
}
223
// shape.dim on a memref operand with a constant dimension index of 0.
func.func @get_dim(%arg : memref<?x?xindex>) -> index {
  %zero = arith.constant 0 : index
  %dim = shape.dim %arg, %zero
      : memref<?x?xindex>, index -> index
  return %dim : index
}
229
// shape.get_extent mixing an extent-tensor operand with a !shape.size index.
func.func @get_extent_on_mixed_operands(%arg : tensor<?xindex>) -> !shape.size {
  %zero = shape.const_size 0
  %extent = shape.get_extent %arg, %zero
      : tensor<?xindex>, !shape.size -> !shape.size
  return %extent : !shape.size
}
235
// Generic-form "shape.any" on a pair of !shape.shape constants and on a pair
// of tensor<3xindex> constants.
func.func @any() {
  %s0 = shape.const_shape [1, 2, 3] : !shape.shape
  %s1 = shape.const_shape [4, 5, 6] : !shape.shape
  %s_any = "shape.any"(%s0, %s1) : (!shape.shape, !shape.shape) -> !shape.shape
  %t0 = shape.const_shape [1, 2, 3] : tensor<3xindex>
  %t1 = shape.const_shape [4, 5, 6] : tensor<3xindex>
  %t_any = "shape.any"(%t0, %t1)
      : (tensor<3xindex>, tensor<3xindex>) -> tensor<3xindex>
  return
}
245
// shape.num_elements on an extent tensor, typed to return an index.
func.func @num_elements_extent_tensor(%arg : tensor<?xindex>) -> index {
  %count = shape.num_elements %arg : tensor<?xindex> -> index
  return %count : index
}
250
// shape.num_elements on a !shape.shape, typed to return a !shape.size.
func.func @num_elements_shape(%arg : !shape.shape) -> !shape.size {
  %count = shape.num_elements %arg : !shape.shape -> !shape.size
  return %count : !shape.size
}
255
// Test calling one shape function from another. shape_equal_shapes is merely
// a trivial helper; it is called from shape_with_shape below.
// Takes shape.shape_of of both value shapes and combines the two results with
// generic-form "shape.meet".
func.func @shape_equal_shapes(%a : !shape.value_shape, %b : !shape.value_shape) -> !shape.shape {
  %shape_a = shape.shape_of %a : !shape.value_shape -> !shape.shape
  %shape_b = shape.shape_of %b : !shape.value_shape -> !shape.shape
  %met = "shape.meet"(%shape_a, %shape_b)
      : (!shape.shape, !shape.shape) -> !shape.shape
  return %met : !shape.shape
}
// Applies shape.with_shape to %b and the shape of %a, then passes %a together
// with that result to @shape_equal_shapes.
func.func @shape_with_shape(%a : !shape.value_shape, %b : !shape.value_shape) -> !shape.shape {
  %shape_a = shape.shape_of %a : !shape.value_shape -> !shape.shape
  %b_with_a = shape.with_shape %b, %shape_a : !shape.value_shape, !shape.shape
  %met = call @shape_equal_shapes(%a, %b_with_a)
      : (!shape.value_shape, !shape.value_shape) -> !shape.shape
  return %met : !shape.shape
}
270
// shape.with_shape where the shape operand is an extent tensor (tensor<3xindex>)
// obtained from shape.shape_of on a rank-3 tensor.
func.func @shape_with_shape_extent_tensor_type(%a : tensor<?x?x?xf32>, %b : !shape.value_shape) -> !shape.value_shape {
  %extents = shape.shape_of %a : tensor<?x?x?xf32> -> tensor<3xindex>
  %reshaped = shape.with_shape %b, %extents : !shape.value_shape, tensor<3xindex>
  return %reshaped : !shape.value_shape
}
276
// Custom-form shape.any with three !shape.shape operands.
func.func @any_on_shape(%a : !shape.shape, %b : !shape.shape,
                        %c : !shape.shape) -> !shape.shape {
  %picked = shape.any %a, %b, %c
      : !shape.shape, !shape.shape, !shape.shape -> !shape.shape
  return %picked : !shape.shape
}
283
// Custom-form shape.any mixing extent-tensor and !shape.shape operands,
// typed to return a !shape.shape.
func.func @any_on_mixed(%a : tensor<?xindex>, %b : tensor<?xindex>,
                        %c : !shape.shape) -> !shape.shape {
  %picked = shape.any %a, %b, %c
      : tensor<?xindex>, tensor<?xindex>, !shape.shape -> !shape.shape
  return %picked : !shape.shape
}
291
// Custom-form shape.any where all operands and the result are extent tensors.
func.func @any_on_extent_tensors(%a : tensor<?xindex>, %b : tensor<?xindex>,
                                 %c : tensor<?xindex>) -> tensor<?xindex> {
  %picked = shape.any %a, %b, %c
      : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
  return %picked : tensor<?xindex>
}
299
// shape.is_broadcastable on two extent tensors, returning an i1.
func.func @is_broadcastable_on_extent_tensors(%a : tensor<?xindex>,
                                              %b : tensor<?xindex>) -> i1 {
  %ok = shape.is_broadcastable %a, %b : tensor<?xindex>, tensor<?xindex>
  return %ok : i1
}
306
// shape.is_broadcastable on two !shape.shape values, returning an i1.
func.func @is_broadcastable_on_shapes(%a : !shape.shape,
                                      %b : !shape.shape) -> i1 {
  %ok = shape.is_broadcastable %a, %b : !shape.shape, !shape.shape
  return %ok : i1
}
313
// Applies shape.max to %a and a constant shape, then shape.meet of that
// constant with the result, using the optional `error` message attribute.
func.func @shape_upper_bounded_by_constant(%a: !shape.shape) -> !shape.shape {
  %bound = shape.const_shape [4, 57, 92] : !shape.shape
  %maxed = shape.max %a, %bound : !shape.shape, !shape.shape -> !shape.shape
  %met = shape.meet %bound, %maxed, error="exceeded element-wise upper bound"
      : !shape.shape, !shape.shape -> !shape.shape
  return %met : !shape.shape
}
321
// Applies shape.min to %a and a constant shape, then shape.meet of that
// constant with the result, using the optional `error` message attribute.
func.func @shape_lower_bounded_by_constant(%a: !shape.shape) -> !shape.shape {
  %bound = shape.const_shape [4, 57, 92] : !shape.shape
  %mined = shape.min %a, %bound : !shape.shape, !shape.shape -> !shape.shape
  %met = shape.meet %bound, %mined, error="lower bound element-wise exceeded"
      : !shape.shape, !shape.shape -> !shape.shape
  return %met : !shape.shape
}
329
// !shape.size variant: shape.max against the constant 5, then shape.meet of
// the constant with that result, carrying an `error` message.
func.func @size_upper_bounded_by_constant(%a: !shape.size) -> !shape.size {
  %bound = shape.const_size 5
  %maxed = shape.max %a, %bound : !shape.size, !shape.size -> !shape.size
  %met = shape.meet %bound, %maxed, error="exceeded element-wise upper bound"
      : !shape.size, !shape.size -> !shape.size
  return %met : !shape.size
}
337
// !shape.size variant: shape.min against the constant 9, then shape.meet of
// the constant with that result, carrying an `error` message.
func.func @size_lower_bounded_by_constant(%a: !shape.size) -> !shape.size {
  %bound = shape.const_size 9
  %mined = shape.min %a, %bound : !shape.size, !shape.size -> !shape.size
  %met = shape.meet %bound, %mined, error="lower bound element-wise exceeded"
      : !shape.size, !shape.size -> !shape.size
  return %met : !shape.size
}
345
// Custom-form shape.meet on two plain index operands, without an error message.
func.func @meet_index(%arg0 : index, %arg1 : index) -> index {
  %met = shape.meet %arg0, %arg1
      : index, index -> index
  return %met : index
}
350
351