// RUN: mlir-opt %s --one-shot-bufferize="dialect-filter=tensor,bufferization copy-before-write unknown-type-conversion=identity-layout-map" -cse -split-input-file | FileCheck %s

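// tensor.dim on a bufferized operand lowers to memref.dim on the underlying buffer.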
// CHECK-LABEL:   func @dim(
// CHECK-SAME:              %[[TENSOR:.*]]: tensor<*xf32>,
// CHECK-SAME:              %[[INDEX:.*]]: index) -> index {
// CHECK:           %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : tensor<*xf32> to memref<*xf32>
// CHECK:           %[[EXTENT:.*]] = memref.dim %[[MEMREF]], %[[INDEX]] : memref<*xf32>
// CHECK:           return %[[EXTENT]] : index
func.func @dim(%arg0: tensor<*xf32>, %arg1: index) -> index {
  %0 = tensor.dim %arg0, %arg1 : tensor<*xf32>
  return %0 : index
}

// -----

// CHECK-LABEL: func @rank(
// CHECK-SAME:    %[[TENSOR:.*]]: tensor<*xf32>) -> index {
// CHECK:           %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]]
// CHECK:           %[[EXTENT:.*]] = memref.rank %[[MEMREF]] : memref<*xf32>
func.func @rank(%arg0: tensor<*xf32>) -> index {
  %0 = tensor.rank %arg0 : tensor<*xf32>
  return %0 : index
}

// -----

// CHECK-LABEL:   func @tensor.cast(
// CHECK-SAME:                      %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> {
// CHECK:           %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]]
// CHECK:           %[[CASTED:.*]] = memref.cast %[[MEMREF]] : memref<?xindex> to memref<2xindex>
// CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[CASTED]]
// CHECK:           return %[[RET]] : tensor<2xindex>
func.func @tensor.cast(%arg0: tensor<?xindex>) -> tensor<2xindex> {
  %0 = tensor.cast %arg0 : tensor<?xindex> to tensor<2xindex>
  return %0 : tensor<2xindex>
}

// -----

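// Casting an unranked tensor to a ranked one must assume a fully dynamic
// strided layout, since the layout of the unranked source buffer is unknown.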
// CHECK-LABEL:   func @tensor.cast_from_unranked(
// CHECK-SAME:                                    %[[TENSOR:.*]]: tensor<*xf32>) -> tensor<2xf32> {
// CHECK:           %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : tensor<*xf32> to memref<*xf32>
// CHECK:           %[[CASTED_MEMREF:.*]] = memref.cast %[[MEMREF]] : memref<*xf32> to memref<2xf32, strided<[?], offset: ?>>
// CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[CASTED_MEMREF]] : memref<2xf32, strided<[?], offset: ?>>
// CHECK:           return %[[RET]] : tensor<2xf32>
func.func @tensor.cast_from_unranked(%arg0: tensor<*xf32>) -> tensor<2xf32> {
  %0 = tensor.cast %arg0 : tensor<*xf32> to tensor<2xf32>
  return %0 : tensor<2xf32>
}

// -----

// CHECK-LABEL:   func @tensor.cast_to_unranked(
// CHECK-SAME:                                  %[[TENSOR:.*]]: tensor<2xf32>) -> tensor<*xf32> {
// CHECK:           %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : tensor<2xf32> to memref<2xf32>
// CHECK:           %[[CASTED_MEMREF:.*]] = memref.cast %[[MEMREF]] : memref<2xf32> to memref<*xf32>
// CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[CASTED_MEMREF]] : memref<*xf32>
// CHECK:           return %[[RET]] : tensor<*xf32>
func.func @tensor.cast_to_unranked(%arg0: tensor<2xf32>) -> tensor<*xf32> {
  %0 = tensor.cast %arg0 : tensor<2xf32> to tensor<*xf32>
  return %0 : tensor<*xf32>
}

// -----

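// tensor.empty bufferizes to a bare allocation; the buffer contents are
// undefined, so no initialization is expected.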
// CHECK-LABEL:   func @tensor.empty(
// CHECK:           %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<5xf32>
// CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]] : memref<5xf32>
// CHECK:           return %[[RET]] : tensor<5xf32>
func.func @tensor.empty() -> tensor<5xf32> {
  %0 = tensor.empty() : tensor<5xf32>
  return %0 : tensor<5xf32>
}

// -----

// CHECK-LABEL:   func @tensor.extract(
// CHECK-SAME:                  %[[TENSOR:.*]]: tensor<?xf32>,
// CHECK-SAME:                  %[[IDX:.*]]: index) -> f32 {
// CHECK:           %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : tensor<?xf32> to memref<?xf32>
// CHECK:           %[[RET:.*]] = memref.load %[[MEMREF]][%[[IDX]]] : memref<?xf32>
// CHECK:           return %[[RET]] : f32
// CHECK:         }
func.func @tensor.extract(%arg0: tensor<?xf32>, %arg1: index) -> f32 {
  %0 = tensor.extract %arg0[%arg1] : tensor<?xf32>
  return %0 : f32
}

// -----

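// tensor.from_elements bufferizes to an allocation followed by one store per
// element.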
// CHECK-LABEL:   func @tensor.from_elements_0d(
// CHECK-SAME:        %[[ELEM0:.*]]: index) -> tensor<index> {
// CHECK:           %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<index>
// CHECK:           store %[[ELEM0]], %[[MEMREF]]
// CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
// CHECK:           return %[[RET]] : tensor<index>
func.func @tensor.from_elements_0d(%arg0: index) -> tensor<index> {
  %0 = tensor.from_elements %arg0 : tensor<index>
  return %0 : tensor<index>
}

// -----

// CHECK-LABEL:   func @tensor.from_elements_1d(
// CHECK-SAME:                               %[[ELEM0:.*]]: index,
// CHECK-SAME:                               %[[ELEM1:.*]]: index) -> tensor<2xindex> {
// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<2xindex>
// CHECK:           store %[[ELEM0]], %[[MEMREF]][%[[C0]]]
// CHECK:           store %[[ELEM1]], %[[MEMREF]][%[[C1]]]
// CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
// CHECK:           return %[[RET]] : tensor<2xindex>
func.func @tensor.from_elements_1d(%arg0: index, %arg1: index) -> tensor<2xindex> {
  %0 = tensor.from_elements %arg0, %arg1 : tensor<2xindex>
  return %0 : tensor<2xindex>
}

// -----

// CHECK-LABEL: func @tensor.from_elements_2d(
// CHECK-SAME:      %[[ELEM0:.*]]: index, %[[ELEM1:.*]]: index)
// CHECK-SAME:      -> tensor<3x2xindex> {
// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG:     %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG:     %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2xindex>
// CHECK:         store %[[ELEM0]], %[[MEMREF]][%[[C0]], %[[C0]]]
// CHECK:         store %[[ELEM1]], %[[MEMREF]][%[[C0]], %[[C1]]]
// CHECK:         store %[[ELEM0]], %[[MEMREF]][%[[C1]], %[[C0]]]
// CHECK:         store %[[ELEM1]], %[[MEMREF]][%[[C1]], %[[C1]]]
// CHECK:         store %[[ELEM0]], %[[MEMREF]][%[[C2]], %[[C0]]]
// CHECK:         store %[[ELEM1]], %[[MEMREF]][%[[C2]], %[[C1]]]
// CHECK:         %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
// CHECK:         return %[[RET]] : tensor<3x2xindex>
func.func @tensor.from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xindex> {
  %0 = tensor.from_elements %arg0, %arg1, %arg0, %arg1, %arg0, %arg1
         : tensor<3x2xindex>
  return %0 : tensor<3x2xindex>
}

// -----

// CHECK-LABEL: func @tensor.from_elements_3d(
//  CHECK-SAME:     %[[F0:.*]]: f32

// CHECK-DAG: %[[F1:.*]] = arith.constant 1.0{{0+}}e+00
// CHECK-DAG: %[[F2:.*]] = arith.constant 2.0
// CHECK-DAG: %[[F3:.*]] = arith.constant 3.0
// CHECK-DAG: %[[F4:.*]] = arith.constant 4.0
// CHECK-DAG: %[[F5:.*]] = arith.constant 5.0
// CHECK-DAG: %[[F6:.*]] = arith.constant 6.0
// CHECK-DAG: %[[F7:.*]] = arith.constant 7.0
// CHECK-DAG: %[[F8:.*]] = arith.constant 8.0
// CHECK-DAG: %[[F9:.*]] = arith.constant 9.0
// CHECK-DAG: %[[F10:.*]] = arith.constant 1.0{{0+}}e+01
// CHECK-DAG: %[[F11:.*]] = arith.constant 1.1{{0+}}e+01

// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index

// CHECK-DAG: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2x2xf32>

// CHECK: store %[[F0]], %[[MEMREF]][%[[C0]], %[[C0]], %[[C0]]]
// CHECK: store %[[F1]], %[[MEMREF]][%[[C0]], %[[C0]], %[[C1]]]
// CHECK: store %[[F2]], %[[MEMREF]][%[[C0]], %[[C1]], %[[C0]]]
// CHECK: store %[[F3]], %[[MEMREF]][%[[C0]], %[[C1]], %[[C1]]]
// CHECK: store %[[F4]], %[[MEMREF]][%[[C1]], %[[C0]], %[[C0]]]
// CHECK: store %[[F5]], %[[MEMREF]][%[[C1]], %[[C0]], %[[C1]]]
// CHECK: store %[[F6]], %[[MEMREF]][%[[C1]], %[[C1]], %[[C0]]]
// CHECK: store %[[F7]], %[[MEMREF]][%[[C1]], %[[C1]], %[[C1]]]
// CHECK: store %[[F8]], %[[MEMREF]][%[[C2]], %[[C0]], %[[C0]]]
// CHECK: store %[[F9]], %[[MEMREF]][%[[C2]], %[[C0]], %[[C1]]]
// CHECK: store %[[F10]], %[[MEMREF]][%[[C2]], %[[C1]], %[[C0]]]
// CHECK: store %[[F11]], %[[MEMREF]][%[[C2]], %[[C1]], %[[C1]]]

// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
// CHECK: return %[[RET]] : tensor<3x2x2xf32>
func.func @tensor.from_elements_3d(%f0 : f32) -> tensor<3x2x2xf32> {
  %f1 = arith.constant 1.0 : f32
  %f2 = arith.constant 2.0 : f32
  %f3 = arith.constant 3.0 : f32
  %f4 = arith.constant 4.0 : f32
  %f5 = arith.constant 5.0 : f32
  %f6 = arith.constant 6.0 : f32
  %f7 = arith.constant 7.0 : f32
  %f8 = arith.constant 8.0 : f32
  %f9 = arith.constant 9.0 : f32
  %f10 = arith.constant 10.0 : f32
  %f11 = arith.constant 11.0 : f32
  %0 = tensor.from_elements %f0,%f1,%f2,%f3,%f4,%f5,%f6,%f7,%f8,%f9,%f10,%f11
         : tensor<3x2x2xf32>
  return %0 : tensor<3x2x2xf32>
}

// -----

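// tensor.generate bufferizes to an allocation plus a linalg.map whose body
// computes each element from its indices.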
// CHECK-LABEL:   func @tensor.generate(
// CHECK-SAME:        %[[ARG:.*]]: tensor<*xf32>,
// CHECK-SAME:        %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<?xindex> {
// CHECK-DAG:       %[[ARG_M:.*]] = bufferization.to_memref %[[ARG]] : tensor<*xf32> to memref<*xf32>
// CHECK-DAG:       %[[ALLOC:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<?xindex>
// CHECK:           %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]]
// CHECK:           %[[MAPPED:.*]] = linalg.map
// CHECK:                 outs(%[[ALLOC_T]] : tensor<?xindex>)
// CHECK:             %[[INDEX:.*]] = linalg.index 0 : index
// CHECK:             %[[ELEM:.*]] = memref.dim %[[ARG_M]], %[[INDEX]] : memref<*xf32>
// CHECK:             linalg.yield %[[ELEM]]
// CHECK:           }
// CHECK:           return %[[MAPPED]] : tensor<?xindex>
// CHECK:         }
func.func @tensor.generate(%arg: tensor<*xf32>, %dynamic_extent: index) -> tensor<?xindex> {
  %result = tensor.generate %dynamic_extent {
  ^bb0(%i : index):
    %elem = tensor.dim %arg, %i : tensor<*xf32>
    tensor.yield %elem : index
  } : tensor<?xindex>
  return %result : tensor<?xindex>
}

// -----

// Additional test that checks the logic for intermixed static and dynamic
// extents.
//
// CHECK-LABEL:   func @tensor.generate_static_and_dynamic(
// CHECK-SAME:        %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<16x?xindex> {
// CHECK:           %[[ALLOC:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<16x?xindex>
// CHECK:           %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]]
// CHECK:           %[[MAPPED:.*]] = linalg.map
// CHECK:                 outs(%[[ALLOC_T]] : tensor<16x?xindex>)
// CHECK:             %[[INDEX0:.*]] = linalg.index 0
// CHECK:             %[[INDEX1:.*]] = linalg.index 1
// CHECK:             %[[ADD:.*]] = arith.addi %[[INDEX0]], %[[INDEX1]]
// CHECK:             linalg.yield %[[ADD]]
// CHECK:           }
// CHECK:           return %[[MAPPED]] : tensor<16x?xindex>
// CHECK:         }
func.func @tensor.generate_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> {
  %result = tensor.generate %arg0 {
  ^bb0(%i: index, %j: index):
    %sum = arith.addi %i, %j : index
    tensor.yield %sum : index
  } : tensor<16x?xindex>
  return %result : tensor<16x?xindex>
}

// -----

// CHECK-LABEL: func @tensor.generate_unknown_ops_in_body
func.func @tensor.generate_unknown_ops_in_body(%arg0: index) -> tensor<?xindex> {
  // CHECK-NOT: tensor.generate
  %tensor = tensor.generate %arg0 {
  ^bb0(%iv: index):
    // CHECK: test.source
    %0 = "test.source"() : () -> index
    tensor.yield %0 : index
  } : tensor<?xindex>
  return %tensor : tensor<?xindex>
}

// -----

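// tensor.extract_slice bufferizes to a memref.subview; the result type carries
// a strided layout with dynamic offset and strides.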
// CHECK-LABEL: func @tensor.extract_slice(
//  CHECK-SAME:     %[[t1:.*]]: tensor<?x?xf32>, %[[idx1:.*]]: index, %[[idx2:.*]]: index
func.func @tensor.extract_slice(
    %t1: tensor<?x?xf32>, %idx1: index, %idx2: index) -> tensor<?x10xf32> {
  // CHECK: %[[m:.*]] = bufferization.to_memref %[[t1]] : tensor<?x?xf32> to memref<?x?xf32>
  // CHECK: %[[r:.*]] = memref.subview %[[m]][5, %[[idx2]]] [%[[idx1]], 10] [1, 1] : memref<?x?xf32> to memref<?x10xf32, strided<[?, 1], offset: ?>>
  %0 = tensor.extract_slice %t1[5, %idx2][%idx1, 10][1, 1]
      : tensor<?x?xf32> to tensor<?x10xf32>
  // CHECK: %[[r_tensor:.*]] = bufferization.to_tensor %[[r]]
  // CHECK: return %[[r_tensor]]
  return %0 : tensor<?x10xf32>
}

// -----

// CHECK-LABEL: func @tensor.extract_slice_rank_reducing(
//  CHECK-SAME:     %[[t1:.*]]: tensor<?x10x?xf32>, %[[idx1:.*]]: index,
//  CHECK-SAME:     %[[idx2:.*]]: index
func.func @tensor.extract_slice_rank_reducing(
    %t1: tensor<?x10x?xf32>, %idx1: index, %idx2: index) -> tensor<?x15xf32> {
  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : tensor<?x10x?xf32> to memref<?x10x?xf32>
  // CHECK: %[[r:.*]] = memref.subview %[[m1]][5, %[[idx1]], 10] [%[[idx2]], 1, 15] [1, 1, 1] : memref<?x10x?xf32> to memref<?x15xf32, strided<[?, 1], offset: ?>>
  %0 = tensor.extract_slice %t1[5, %idx1, 10][%idx2, 1, 15][1, 1, 1]
      : tensor<?x10x?xf32> to tensor<?x15xf32>
  // CHECK: %[[r_tensor:.*]] = bufferization.to_tensor %[[r]]
  // CHECK: return %[[r_tensor]]
  return %0 : tensor<?x15xf32>
}

// -----

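// With copy-before-write, the destination buffer is first copied into a fresh
// allocation; the source is then copied into a subview of that allocation.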
// CHECK-LABEL: func @tensor.insert_slice(
//  CHECK-SAME:     %[[t1:.*]]: tensor<?x?xf32>, %[[t2:.*]]: tensor<?x10xf32>,
//  CHECK-SAME:     %[[idx1:.*]]: index, %[[idx2:.*]]: index
func.func @tensor.insert_slice(%t1: tensor<?x?xf32>, %t2: tensor<?x10xf32>,
                               %idx1: index, %idx2: index) -> tensor<?x?xf32> {
  // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
  // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
  // CHECK-DAG: %[[m1:.*]] = bufferization.to_memref %[[t1]] : tensor<?x?xf32> to memref<?x?xf32>
  // CHECK-DAG: %[[m2:.*]] = bufferization.to_memref %[[t2]] : tensor<?x10xf32> to memref<?x10xf32>
  // CHECK-DAG: %[[dim0:.*]] = memref.dim %[[m1]], %[[c0]]
  // CHECK-DAG: %[[dim1:.*]] = memref.dim %[[m1]], %[[c1]]
  //     CHECK: %[[alloc:.*]] = memref.alloc(%[[dim0]], %[[dim1]])
  //     CHECK: memref.copy %[[m1]], %[[alloc]]
  //     CHECK: %[[subview:.*]] = memref.subview %[[alloc]][%[[idx1]], 5] [%[[idx2]], 10] [1, 1]
  //     CHECK: memref.copy %[[m2]], %[[subview]]
  %0 = tensor.insert_slice %t2 into %t1[%idx1, 5][%idx2, 10][1, 1]
      : tensor<?x10xf32> into tensor<?x?xf32>

  //     CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]]
  //     CHECK: return %[[r]]
  return %0 : tensor<?x?xf32>
}

// -----

// CHECK-LABEL: func @tensor.insert_slice_rank_reducing_1(
func.func @tensor.insert_slice_rank_reducing_1(
    %t1: tensor<?x?xf32>, %f: tensor<f32>, %idx1: index, %idx2: index)
  -> tensor<?x?xf32>
{
  // CHECK: %[[alloc:.*]] = memref.alloc{{.*}} : memref<?x?xf32>
  // CHECK: memref.subview %[[alloc]][%{{.*}}, %{{.*}}] [1, 1] [1, 1] : memref<?x?xf32> to memref<f32, strided<[], offset: ?>>
  // CHECK: memref.copy {{.*}} : memref<f32> to memref<f32, strided<[], offset: ?>>
  %0 = tensor.insert_slice %f into %t1[%idx1, %idx2][1, 1][1, 1]
      : tensor<f32> into tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

// -----

// CHECK-LABEL: func @tensor.insert_slice_rank_reducing_2(
func.func @tensor.insert_slice_rank_reducing_2(
    %t1: tensor<?x?x?x?x?x?x?xf32>, %t2: tensor<2x1x4x1x1xf32>, %i: index)
  -> tensor<?x?x?x?x?x?x?xf32>
{
  // CHECK: %[[alloc:.*]] = memref.alloc{{.*}} : memref<?x?x?x?x?x?x?xf32>
  // CHECK: memref.subview %[[alloc]][{{.*}}] [1, 2, 1, 4, 1, 1, 1] [1, 1, 1, 1, 1, 1, 1] : memref<?x?x?x?x?x?x?xf32> to memref<2x1x4x1x1xf32, strided<[?, ?, ?, ?, ?], offset: ?>>
  // CHECK: memref.copy {{.*}} : memref<2x1x4x1x1xf32> to memref<2x1x4x1x1xf32, strided<[?, ?, ?, ?, ?], offset: ?>>
  %0 = tensor.insert_slice %t2 into %t1[%i, %i, %i, %i, %i, %i, %i][1, 2, 1, 4, 1, 1, 1][1, 1, 1, 1, 1, 1, 1]
      : tensor<2x1x4x1x1xf32> into tensor<?x?x?x?x?x?x?xf32>
  return %0 : tensor<?x?x?x?x?x?x?xf32>
}

// -----

// CHECK-LABEL: func @tensor.insert(
//  CHECK-SAME:     %[[t1:.*]]: tensor<5xf32>, %[[idx1:.*]]: index,
//  CHECK-SAME:     %[[f:.*]]: f32
func.func @tensor.insert(%t1: tensor<5xf32>, %idx1: index, %f: f32) -> tensor<5xf32> {
  // CHECK-DAG: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<5xf32>
  // CHECK-DAG: %[[m1:.*]] = bufferization.to_memref %[[t1]] : tensor<5xf32> to memref<5xf32>
  // CHECK: memref.copy %[[m1]], %[[alloc]]
  // CHECK: memref.store %[[f]], %[[alloc]][%[[idx1]]]
  %0 = tensor.insert %f into %t1[%idx1] : tensor<5xf32>

  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]]
  // CHECK: return %[[r]]
  return %0 : tensor<5xf32>
}

// -----

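// The dynamic extent of the expanded group is recomputed from the source
// dimension via arith.divsi.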
// CHECK-LABEL: func @tensor.expand_shape(
//  CHECK-SAME:     %[[t1:.*]]: tensor<?x10xf32>
func.func @tensor.expand_shape(%t1: tensor<?x10xf32>, %sz0: index) -> tensor<2x?x10xf32> {
  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]]
  // CHECK: %[[C0:.*]] = arith.constant 0 : index
  // CHECK: %[[DIM:.*]] = memref.dim %[[m1]], %[[C0]] : memref<?x10xf32>
  // CHECK: %[[C2:.*]] = arith.constant 2 : index
  // CHECK: %[[VAL_1:.*]] = arith.divsi %[[DIM]], %[[C2]] : index
  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] {{\[\[}}0, 1], [2]] output_shape [2, %[[VAL_1]], 10] : memref<?x10xf32> into memref<2x?x10xf32>
  %0 = tensor.expand_shape %t1 [[0, 1], [2]] output_shape [2, %sz0, 10]
      : tensor<?x10xf32> into tensor<2x?x10xf32>

  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
  // CHECK: return %[[r]]
  return %0 : tensor<2x?x10xf32>
}

// -----

// CHECK-LABEL: func @tensor.expand_shape_of_slice(
//  CHECK-SAME:     %[[t1:.*]]: tensor<?x20xf32>
func.func @tensor.expand_shape_of_slice(
    %t1: tensor<?x20xf32>, %o1: index, %s1: index, %sz0: index) -> tensor<?x7x2x5xf32> {
  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] :
  // CHECK: %[[subview:.*]] = memref.subview %[[m1]][%{{.*}}, 5] [%{{.*}}, 10] [1, 1] : memref<?x20xf32> to memref<?x10xf32, strided<[20, 1], offset: ?>>
  %0 = tensor.extract_slice %t1[%o1, 5][%s1, 10][1, 1] :
      tensor<?x20xf32> to tensor<?x10xf32>
  // CHECK: %[[C7:.*]] = arith.constant 7 : index
  // CHECK: %[[VAL_1:.*]] = arith.divsi %{{.*}}, %[[C7]] : index
  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_1]], 7, 2, 5] : memref<?x10xf32, strided<[20, 1], offset: ?>> into memref<?x7x2x5xf32, strided<[140, 20, 5, 1], offset: ?>>
  %1 = tensor.expand_shape %0 [[0, 1], [2, 3]] output_shape [%sz0, 7, 2, 5] :
      tensor<?x10xf32> into tensor<?x7x2x5xf32>
  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
  // CHECK: return %[[r]]
  return %1 : tensor<?x7x2x5xf32>
}

// -----

// CHECK-LABEL: func @tensor.expand_shape_of_scalar_slice(
//  CHECK-SAME:     %[[t1:.*]]: tensor<?xf32>
func.func @tensor.expand_shape_of_scalar_slice(
    %t1: tensor<?xf32>, %o1: index, %s1: index) -> tensor<1xf32> {
  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : tensor<?xf32> to memref<?xf32>
  // CHECK: %[[subview:.*]] = memref.subview %[[m1]][%{{.*}}] [1] [1] : memref<?xf32> to memref<f32, strided<[], offset: ?>>
  %0 = tensor.extract_slice %t1[%o1][1][1] : tensor<?xf32> to tensor<f32>
  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] [] output_shape [1] : memref<f32, strided{{.*}}> into memref<1xf32, strided<[1], offset: ?>>
  %1 = tensor.expand_shape %0 [] output_shape [1] : tensor<f32> into tensor<1xf32>
  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
  // CHECK: return %[[r]]
  return %1 : tensor<1xf32>
}

// -----

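// A contiguous buffer can be collapsed directly with memref.collapse_shape;
// no copy is needed.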
// CHECK-LABEL: func @tensor.collapse_shape(
//  CHECK-SAME:     %[[t1:.*]]: tensor<2x?x?xf32>
func.func @tensor.collapse_shape(%t1: tensor<2x?x?xf32>) -> tensor<?x?xf32> {
  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : tensor<2x?x?xf32> to memref<2x?x?xf32>
  // CHECK: %[[collapsed:.*]] = memref.collapse_shape %[[m1]] [
  // CHECK-SAME: [0, 1], [2]] : memref<2x?x?xf32> into memref<?x?xf32>
  %0 = tensor.collapse_shape %t1 [[0, 1], [2]]
      : tensor<2x?x?xf32> into tensor<?x?xf32>

  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[collapsed]]
  // CHECK: return %[[r]]
  return %0 : tensor<?x?xf32>
}

// -----

// CHECK-LABEL: func @tensor.collapse_shape_to_scalar(
//  CHECK-SAME:     %[[t1:.*]]: tensor<1x1x1xf32>
func.func @tensor.collapse_shape_to_scalar(%t1: tensor<1x1x1xf32>) -> tensor<f32> {
  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : tensor<1x1x1xf32> to memref<1x1x1xf32>
  // CHECK: %[[collapsed:.*]] = memref.collapse_shape %[[m1]] [] : memref<1x1x1xf32> into memref<f32>
  %0 = tensor.collapse_shape %t1 []
      : tensor<1x1x1xf32> into tensor<f32>

  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[collapsed]]
  // CHECK: return %[[r]]
  return %0 : tensor<f32>
}

// -----

// CHECK-LABEL: func @tensor.collapse_shape_of_slice(
func.func @tensor.collapse_shape_of_slice(%arg0: tensor<2xi32>) -> tensor<i32> {
  // CHECK: memref.subview %{{.*}}[1] [1] [1] : memref<2xi32> to memref<1xi32, strided<[1], offset: 1>>
  %0 = tensor.extract_slice %arg0[1] [1] [1] : tensor<2xi32> to tensor<1xi32>
  // CHECK: memref.collapse_shape %{{.*}} [] : memref<1xi32, strided<[1], offset: 1>> into memref<i32, strided<[], offset: 1>>
  %1 = tensor.collapse_shape %0 [] : tensor<1xi32> into tensor<i32>
  return %1 : tensor<i32>
}

// -----

// CHECK-LABEL: func @tensor.collapse_shape_of_slice2(
func.func @tensor.collapse_shape_of_slice2(
    %arg0: tensor<?x?x?x?xi64>, %o1: index, %o2: index, %o3: index, %o4: index)
    -> tensor<87x63648xi64> {
  // CHECK: %[[subview:.*]] = memref.subview %{{.*}} : memref<?x?x?x?xi64> to memref<87x78x68x12xi64, strided{{.*}}>
  %0 = tensor.extract_slice %arg0[%o1, %o2, %o3, %o4] [87, 78, 68, 12] [1, 1, 1, 1] : tensor<?x?x?x?xi64> to tensor<87x78x68x12xi64>

  // This memref may not be collapsible, so the buffer must be copied to get rid
  // of the layout map.
  // CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<87x78x68x12xi64>
  // CHECK: memref.copy %[[subview]], %[[alloc]]
  // CHECK: memref.collapse_shape %[[alloc]] [
  // CHECK-SAME: [0], [1, 2, 3]] : memref<87x78x68x12xi64> into memref<87x63648xi64>
  %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]] : tensor<87x78x68x12xi64> into tensor<87x63648xi64>
  return %1 : tensor<87x63648xi64>
}

// -----

// CHECK-LABEL: func @tensor.collapse_shape_of_slice3(
//  CHECK-SAME:     %[[t1:.*]]: tensor<1x2xf32>
func.func @tensor.collapse_shape_of_slice3(%t1: tensor<1x2xf32>) -> tensor<1xf32> {
  // CHECK: memref.subview {{.*}} : memref<1x2xf32> to memref<1x1xf32, strided<[2, 1]>>
  %0 = tensor.extract_slice %t1[0, 0][1, 1][1, 1] : tensor<1x2xf32> to tensor<1x1xf32>
  // CHECK: memref.collapse_shape %{{.*}} [
  // CHECK-SAME: [0, 1]] : memref<1x1xf32, strided<[2, 1]>> into memref<1xf32, strided<[2]>>
  %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x1xf32> into tensor<1xf32>
  return %1 : tensor<1xf32>
}

// -----

// CHECK-LABEL:   func @tensor.collapse_shape_of_slice4(
//  CHECK-SAME:     %[[t1:.*]]: tensor<?x2x4xf32>,
// CHECK-SAME:      %[[OFFSET:.*]]: index) -> tensor<8xf32> {
func.func @tensor.collapse_shape_of_slice4(%arg0: tensor<?x2x4xf32>, %offset: index, %size: index) -> tensor<8xf32> {
  // CHECK: memref.subview %{{.*}} : memref<?x2x4xf32> to memref<4x2x1xf32, strided<[8, 4, 1], offset: ?>>
  %0 = tensor.extract_slice %arg0[0, 0, %offset] [4, 2, 1] [1, 1, 1] : tensor<?x2x4xf32> to tensor<4x2x1xf32>
  // CHECK: memref.collapse_shape %{{.*}} [
  // CHECK-SAME: [0, 1, 2]] : memref<4x2x1xf32, strided<[8, 4, 1], offset: ?>> into memref<8xf32, strided<[4], offset: ?>>
  %ret = tensor.collapse_shape %0 [[0, 1, 2]] : tensor<4x2x1xf32> into tensor<8xf32>
  return %ret: tensor<8xf32>
}

// -----

// CHECK-LABEL: func @tensor.collapse_shape_of_slice5(
func.func @tensor.collapse_shape_of_slice5(%arg0: tensor<2x2x2xi64>) -> tensor<4xi64> {
  // CHECK: %[[subview:.*]] = memref.subview %{{.*}} : memref<2x2x2xi64> to memref<2x1x2xi64, {{.*}}>
  %0 = tensor.extract_slice %arg0[0, 0, 0] [2, 1, 2] [1, 1, 1] : tensor<2x2x2xi64> to tensor<2x1x2xi64>

  // This memref is not collapsible, so the buffer must be copied to get rid of
  // the layout map.
  // CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<2x1x2xi64>
  // CHECK: memref.copy %[[subview]], %[[alloc]]
  // CHECK: memref.collapse_shape %[[alloc]] [
  // CHECK-SAME: [0, 1, 2]] : memref<2x1x2xi64> into memref<4xi64>
  %1 = tensor.collapse_shape %0 [[0, 1, 2]] : tensor<2x1x2xi64> into tensor<4xi64>
  return %1 : tensor<4xi64>
}

// -----

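// The shape operand, produced by tensor.from_elements, bufferizes to its own
// small allocation that is filled with stores and passed to memref.reshape.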
// CHECK-LABEL: func @tensor.reshape(
//  CHECK-SAME:     %[[t1:.*]]: tensor<?x10xf32>
func.func @tensor.reshape(%t1: tensor<?x10xf32>) -> tensor<2x2x5xf32> {
  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : tensor<?x10xf32> to memref<?x10xf32>

  // CHECK: %[[two:.*]] = arith.constant 2 : i64
  %two = arith.constant 2 : i64
  // CHECK: %[[five:.*]] = arith.constant 5 : i64
  %five = arith.constant 5 : i64

  // CHECK: %[[alloc:.*]] = memref.alloc() {alignment = 64 : i64} : memref<3xi64>
  // CHECK: %[[zero_idx:.*]] = arith.constant 0 : index
  // CHECK: %[[one_idx:.*]] = arith.constant 1 : index
  // CHECK: %[[two_idx:.*]] = arith.constant 2 : index
  // CHECK: memref.store %[[two]], %[[alloc]][%[[zero_idx]]] : memref<3xi64>
  // CHECK: memref.store %[[two]], %[[alloc]][%[[one_idx]]] : memref<3xi64>
  // CHECK: memref.store %[[five]], %[[alloc]][%[[two_idx]]] : memref<3xi64>
  %shape = tensor.from_elements %two, %two, %five : tensor<3xi64>

  // CHECK: %[[reshaped:.*]] = memref.reshape %[[m1]](%[[alloc]]) : (memref<?x10xf32>, memref<3xi64>) -> memref<2x2x5xf32>
  %reshaped = tensor.reshape %t1(%shape) : (tensor<?x10xf32>, tensor<3xi64>) -> tensor<2x2x5xf32>

  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[reshaped]]
  // CHECK: return %[[r]]
  return %reshaped : tensor<2x2x5xf32>
}

// -----

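// tensor.pad bufferizes to an allocation of the padded size, filled with the
// padding body via linalg.map; the source buffer is then copied into the
// interior subview.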
// CHECK:       #[[$sum_map_1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
// CHECK:       #[[$sum_map_2:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 10)>
// CHECK-LABEL: func @tensor.pad(
//  CHECK-SAME:   %[[t1:.*]]: tensor<?x10xindex>, %[[l2:.*]]: index, %[[h1:.*]]: index, %[[h2:.*]]: index
func.func @tensor.pad(%t1: tensor<?x10xindex>, %l2: index, %h1: index,
                      %h2: index) -> tensor<?x?xindex> {
  // CHECK-DAG: %[[m1:.*]] = bufferization.to_memref %[[t1]] : tensor<?x10xindex> to memref<?x10xindex>
  // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
  // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
  // CHECK-DAG: %[[dim0:.*]] = memref.dim %[[m1]], %[[c0]]
  // CHECK-DAG: %[[dim1:.*]] = memref.dim %[[m1]], %[[c1]]
  // CHECK-DAG: %[[size0:.*]] = affine.apply #[[$sum_map_1]]()[%[[h1]], %[[dim0]]]
  // CHECK-DAG: %[[size1:.*]] = affine.apply #[[$sum_map_2]]()[%[[l2]], %[[h2]]]
  // CHECK:     %[[alloc:.*]] = memref.alloc(%[[size0]], %[[size1]]) {{.*}} : memref<?x?xindex>
  // CHECK:     %[[alloc_t:.*]] = bufferization.to_tensor %[[alloc]]
  // CHECK:     %[[mapped:.*]] = linalg.map
  // CHECK:           outs(%[[alloc_t]] : tensor<?x?xindex>)
  // CHECK:       %[[index0:.*]] = linalg.index 0
  // CHECK:       %[[index1:.*]] = linalg.index 1
  // CHECK:       %[[mul:.*]] = arith.muli %[[index0]], %[[index1]]
  // CHECK:       linalg.yield %[[mul]]
  // CHECK:     }
  // CHECK:     %[[mapped_m:.*]] = bufferization.to_memref %[[mapped]]
  // CHECK:     %[[subview:.*]] = memref.subview %[[mapped_m]][5, %[[l2]]] [%[[dim0]], 10] [1, 1]
  // CHECK:     memref.copy %[[m1]], %[[subview]]
  %0 = tensor.pad %t1 low[5, %l2] high[%h1, %h2] {
  ^bb0(%arg0: index, %arg1: index):
    %m = arith.muli %arg0, %arg1 : index
    tensor.yield %m : index
  } : tensor<?x10xindex> to tensor<?x?xindex>

  // CHECK:     %[[r:.*]] = bufferization.to_tensor %[[mapped_m]]
  // CHECK:     return %[[r]] : tensor<?x?xindex>
  return %0 : tensor<?x?xindex>
}

// -----

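// tensor.splat bufferizes to an allocation and a linalg.map that yields the
// splat value for every element.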
// CHECK-LABEL:   func @tensor.splat(
// CHECK-SAME:        %[[F:.*]]: f32)
// CHECK-DAG:       %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<10x2x4xf32>
// CHECK:           %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]]
// CHECK:           %[[MAPPED:.*]] = linalg.map
// CHECK:                 outs(%[[ALLOC_T]] : tensor<10x2x4xf32>)
// CHECK:             linalg.yield %[[F]]
// CHECK:           }
// CHECK:           return %[[MAPPED]] : tensor<10x2x4xf32>
// CHECK:         }
func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> {
  %t = tensor.splat %f : tensor<10x2x4xf32>
  return %t : tensor<10x2x4xf32>
}

// -----

// CHECK-LABEL: func @tensor.splat_dynamic(
// CHECK-SAME:  %[[F:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME:  %[[M:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:  %[[N:[a-zA-Z0-9_]+]]: index
// CHECK-DAG:     %[[ALLOC:.*]] = memref.alloc(%[[M]], %[[N]]) {{.*}} : memref<?x3x?xf32>
// CHECK:         %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]]
// CHECK:         %[[MAPPED:.*]] = linalg.map outs(%[[ALLOC_T]] : tensor<?x3x?xf32>)
// CHECK:         () {
// CHECK:           linalg.yield %[[F]] : f32
// CHECK:         }
// CHECK:         return %[[MAPPED]] : tensor<?x3x?xf32>
// CHECK:       }
func.func @tensor.splat_dynamic(%f: f32, %m: index, %n: index) -> tensor<?x3x?xf32> {
  %0 = tensor.splat %f[%m, %n] : tensor<?x3x?xf32>
  return %0 : tensor<?x3x?xf32>
}

// -----

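// Both the extracted and the inserted slice become subviews of the bufferized
// operands inside the scf.forall.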
// CHECK-LABEL: func.func @parallel_insert_slice_copy_before_write
func.func @parallel_insert_slice_copy_before_write(%in: tensor<4xf32>, %out: tensor<4xf32>) {
  %c1 = arith.constant 1 : index
  %num_threads = arith.constant 4 : index

  // CHECK: scf.forall {{.*}} {
  %result = scf.forall (%thread_idx) in (%num_threads) shared_outs (%o = %out) -> tensor<4xf32> {
      %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<4xf32> to tensor<1xf32>
      scf.forall.in_parallel {
        // CHECK: memref.subview %{{.*}}[%{{.*}}] [1] [1] : memref<4xf32> to memref<1xf32, strided<[1], offset: ?>>
        // CHECK: memref.subview %{{.*}}[%{{.*}}] [1] [1] : memref<4xf32> to memref<1xf32, strided<[1], offset: ?>>
        tensor.parallel_insert_slice %1 into %o[%thread_idx][1][1] :
          tensor<1xf32> into tensor<4xf32>
      }
  }
  // CHECK: }
  return
}