// Source: llvm-project/mlir/test/Dialect/Bufferization/canonicalize.mlir (revision ced2fc7819d5ddea616ec330f18e08ff284c1868)
// RUN: mlir-opt %s \
// RUN:   -canonicalize="test-convergence" \
// RUN:   --split-input-file -allow-unregistered-dialect | \
// RUN: FileCheck %s

// Basic folding of to_tensor(to_memref(t)) -> t
// CHECK-LABEL: func @tensor_load_of_buffer_cast(
func.func @tensor_load_of_buffer_cast(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %0 = bufferization.to_memref %arg0 : tensor<?xf32> to memref<?xf32>
  %1 = bufferization.to_tensor %0 : memref<?xf32> to tensor<?xf32>
  return %1 : tensor<?xf32>
}
// CHECK-SAME:   %[[TENSOR:.*]]: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: return %[[TENSOR]]

// -----

// Basic folding of to_memref(to_tensor(m)) -> m
// CHECK-LABEL: func @buffer_cast_of_tensor_load(
func.func @buffer_cast_of_tensor_load(%arg0: memref<?xf32>) -> memref<?xf32> {
  %0 = bufferization.to_tensor %arg0 : memref<?xf32> to tensor<?xf32>
  %1 = bufferization.to_memref %0 : tensor<?xf32> to memref<?xf32>
  return %1 : memref<?xf32>
}
// CHECK-SAME:   %[[MEMREF:.*]]: memref<?xf32>) -> memref<?xf32> {
// CHECK: return %[[MEMREF]]

// -----

// If the memrefs are not the same type, don't fold them.
// If the memrefs are not cast-compatible (e.g. different address space), don't
// canonicalize them either.
// CHECK-LABEL: func @no_fold_buffer_cast_of_tensor_load(
//  CHECK-SAME:   %[[MEMREF_ADDRSPACE2:.*]]: memref<?xf32, 2>)
//  CHECK-SAME:     -> memref<?xf32, 7> {
//       CHECK: %[[TENSOR:.*]] = bufferization.to_tensor
//  CHECK-SAME:   %[[MEMREF_ADDRSPACE2]] : memref<?xf32, 2> to tensor<?xf32, 7 : i64>
//       CHECK: %[[MEMREF_ADDRSPACE7:.*]] = bufferization.to_memref
//  CHECK-SAME:   %[[TENSOR]] : tensor<?xf32, 7 : i64> to memref<?xf32, 7>
//       CHECK: return %[[MEMREF_ADDRSPACE7]]
func.func @no_fold_buffer_cast_of_tensor_load(%arg0: memref<?xf32, 2>)
    -> memref<?xf32, 7> {
  %0 = bufferization.to_tensor %arg0 : memref<?xf32, 2> to tensor<?xf32, 7>
  %1 = bufferization.to_memref %0 : tensor<?xf32, 7> to memref<?xf32, 7>
  return %1 : memref<?xf32, 7>
}

// -----

// If the memrefs are definitely cast-compatible, canonicalize to
//            cast.
// CHECK-LABEL: func @canonicalize_buffer_cast_of_tensor_load(
//  CHECK-SAME:   %[[M:.*]]: memref<?xf32, strided<[1], offset: 3>>)
//  CHECK-SAME:     -> memref<?xf32, strided<[1], offset: ?>> {
//   CHECK-NOT: bufferization.to_tensor
//   CHECK-NOT: bufferization.to_memref
//       CHECK: %[[R:.*]] = memref.cast %[[M]]
//  CHECK-SAME:   memref<?xf32, strided<[1], offset: 3>> to memref<?xf32, strided<[1], offset: ?>>
//       CHECK: return %[[R]]
func.func @canonicalize_buffer_cast_of_tensor_load(
  %arg0: memref<?xf32, strided<[1], offset: 3>>)
  -> memref<?xf32, strided<[1], offset: ?>>
{
  %0 = bufferization.to_tensor %arg0 : memref<?xf32, strided<[1], offset: 3>> to tensor<?xf32>
  %1 = bufferization.to_memref %0 : tensor<?xf32> to memref<?xf32, strided<[1], offset: ?>>
  return %1 : memref<?xf32, strided<[1], offset: ?>>
}

// -----

// If the memrefs are potentially cast-compatible, canonicalize to
//            copy.
// CHECK-LABEL: func @canonicalize_buffer_cast_of_tensor_load_to_copy(
func.func @canonicalize_buffer_cast_of_tensor_load_to_copy(
  %arg0: memref<?xf32, strided<[1], offset: ?>>)
  -> memref<?xf32, strided<[1], offset: 3>> {
  %0 = bufferization.to_tensor %arg0 : memref<?xf32, strided<[1], offset: ?>> to tensor<?xf32>
  %1 = bufferization.to_memref %0 : tensor<?xf32> to memref<?xf32, strided<[1], offset: 3>>
  return %1 : memref<?xf32, strided<[1], offset: 3>>
}
// CHECK-SAME:   %[[M:.*]]: memref<?xf32, strided<[1], offset: ?>>)
// CHECK-SAME:     -> memref<?xf32, strided<[1], offset: 3>> {
//  CHECK-NOT: bufferization.to_tensor
//  CHECK-NOT: bufferization.to_memref
//      CHECK: %[[C0:.*]] = arith.constant 0 : index
//      CHECK: %[[DIM:.*]] = memref.dim %[[M]], %[[C0]] : memref<?xf32, strided<[1], offset: ?>>
//      CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM]]) : memref<?xf32, strided<[1], offset: 3>>
//      CHECK: memref.copy %[[M]], %[[ALLOC]]
// CHECK-SAME:   memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: 3>>
//      CHECK: return %[[ALLOC]]

// -----


// Basic folding of tensor.dim(to_tensor(m)) -> memref.dim(m).
// CHECK-LABEL: func @dim_of_tensor_load(
//  CHECK-SAME:     %[[MEMREF:[0-9a-z]*]]: memref<?xf32>
//       CHECK:   %[[C0:.*]] = arith.constant 0
//       CHECK:   %[[D:.*]] = memref.dim %[[MEMREF]], %[[C0]]
//       CHECK:   return %[[D]] : index
func.func @dim_of_tensor_load(%arg0: memref<?xf32>) -> index {
  %c0 = arith.constant 0 : index
  %0 = bufferization.to_tensor %arg0 : memref<?xf32> to tensor<?xf32>
  %1 = tensor.dim %0, %c0 : tensor<?xf32>
  return %1 : index
}

// -----

// A clone whose source is deallocated right after folds away: the source
// buffer is returned directly.
// CHECK-LABEL: @clone_before_dealloc
func.func @clone_before_dealloc(%arg0: memref<?xf32>) -> memref<?xf32> {
  %0 = bufferization.clone %arg0 : memref<?xf32> to memref<?xf32>
  memref.dealloc %arg0 : memref<?xf32>
  return %0 : memref<?xf32>
}
// CHECK-SAME: %[[ARG:.*]]: memref<?xf32>
// CHECK-NEXT: return %[[ARG]]

// -----

// Same folding when the clone (not the source) is used and deallocated:
// the use is redirected to the source buffer.
// CHECK-LABEL: @clone_before_dealloc
func.func @clone_before_dealloc(%arg0: memref<?xf32>) -> memref<?xf32> {
  %0 = bufferization.clone %arg0 : memref<?xf32> to memref<?xf32>
  "use"(%0) : (memref<?xf32>) -> ()
  memref.dealloc %0 : memref<?xf32>
  return %arg0 : memref<?xf32>
}
// CHECK-SAME: %[[ARG:.*]]: memref<?xf32>
// CHECK-NEXT: "use"(%arg0)
// CHECK-NEXT: return %[[ARG]]

// -----

// A memref.cast feeding a clone is folded into the clone's source operand.
// CHECK-LABEL: @clone_after_cast
func.func @clone_after_cast(%arg0: memref<?xf32>) -> memref<32xf32> {
  %0 = memref.cast %arg0 : memref<?xf32> to memref<32xf32>
  %1 = bufferization.clone %0 : memref<32xf32> to memref<32xf32>
  return %1 : memref<32xf32>
}
// CHECK-SAME: %[[ARG:.*]]: memref<?xf32>
// CHECK-NEXT: bufferization.clone %[[ARG]] : memref<?xf32> to memref<32xf32>
// CHECK-NOT: memref.cast

// -----

// Clone to a different (cast-compatible) type with the source deallocated
// afterwards canonicalizes to a plain memref.cast.
// CHECK-LABEL: @clone_and_cast
func.func @clone_and_cast(%arg0: memref<?xf32>) -> memref<32xf32> {
  %0 = bufferization.clone %arg0 : memref<?xf32> to memref<32xf32>
  memref.dealloc %arg0 : memref<?xf32>
  return %0 : memref<32xf32>
}
// CHECK-SAME: %[[ARG:.*]]: memref<?xf32>
// CHECK-NEXT: %[[RES:.*]] = memref.cast %[[ARG]]
// CHECK-SAME:   memref<?xf32> to memref<32xf32>
// CHECK-NEXT: return %[[RES]]

// -----

// Source and result layouts differ (stride 2 vs. identity), so the clone
// cannot be replaced by a cast and must be kept.
// CHECK-LABEL: @clone_incompatible
func.func @clone_incompatible(%arg0: memref<32xf32, strided<[2]>>) -> memref<32xf32> {
  %0 = bufferization.clone %arg0 : memref<32xf32, strided<[2]>> to memref<32xf32>
  memref.dealloc %arg0 : memref<32xf32, strided<[2]>>
  return %0 : memref<32xf32>
}
// CHECK-SAME: %[[ARG:.*]]: memref<32xf32, strided<[2]>>
// CHECK-NEXT: bufferization.clone %[[ARG]] : memref<32xf32, strided<[2]>> to memref<32xf32>
// CHECK-NOT: memref.cast

// -----

// The clone's source is deallocated through an alias (the pre-cast value),
// so the clone and both deallocations must be preserved.
// CHECK-LABEL: @alias_is_freed
func.func @alias_is_freed(%arg0 : memref<?xf32>) {
  %0 = memref.cast %arg0 : memref<?xf32> to memref<32xf32>
  %1 = bufferization.clone %0 : memref<32xf32> to memref<32xf32>
  memref.dealloc %arg0 : memref<?xf32>
  "use"(%1) : (memref<32xf32>) -> ()
  memref.dealloc %1 : memref<32xf32>
  return
}
// CHECK: bufferization.clone
// CHECK: memref.dealloc
// CHECK: memref.dealloc

// -----

// Verify SimplifyClones skips clones with multiple deallocations.
// CHECK-LABEL: @clone_multiple_dealloc_of_source
func.func @clone_multiple_dealloc_of_source(%arg0: memref<?xf32>) -> memref<?xf32> {
  %0 = bufferization.clone %arg0 : memref<?xf32> to memref<?xf32>
  "if_else"() ({
    memref.dealloc %arg0 : memref<?xf32>
    }, {
    memref.dealloc %arg0 : memref<?xf32>
    }) : () -> ()
  return %0 : memref<?xf32>
}
// CHECK-SAME: %[[ARG:.*]]: memref<?xf32>
// CHECK-NEXT: %[[RES:.*]] = bufferization.clone %[[ARG]]
// CHECK: memref.dealloc %[[ARG]]
// CHECK: memref.dealloc %[[ARG]]
// CHECK: return %[[RES]]

// -----

// As above, but the multiple deallocations target the clone itself;
// the clone must likewise be kept.
// CHECK-LABEL: @clone_multiple_dealloc_of_clone
// CHECK-SAME: %[[ARG:.*]]: memref<?xf32>
func.func @clone_multiple_dealloc_of_clone(%arg0: memref<?xf32>) -> memref<?xf32> {
  // CHECK-NEXT: %[[CLONE:.*]] = bufferization.clone %[[ARG]]
  // CHECK: memref.dealloc %[[CLONE]]
  // CHECK: memref.dealloc %[[CLONE]]
  // CHECK: return %[[ARG]]
  %0 = bufferization.clone %arg0 : memref<?xf32> to memref<?xf32>
  "use"(%0) : (memref<?xf32>) -> ()
  "if_else"() ({
    memref.dealloc %0 : memref<?xf32>
    }, {
    memref.dealloc %0 : memref<?xf32>
    }) : () -> ()
  return %arg0 : memref<?xf32>
}

// -----

// Verify SimplifyClones skips clones followed by realloc.
// CHECK-LABEL: @clone_and_realloc
func.func @clone_and_realloc(%arg0: memref<?xf32>) {
  %0 = bufferization.clone %arg0 : memref<?xf32> to memref<32xf32>
  "use"(%0) : (memref<32xf32>) -> ()
  %1 = memref.realloc %0 : memref<32xf32> to memref<64xf32>
  memref.dealloc %1 : memref<64xf32>
  return
}
// CHECK-SAME: %[[ARG:.*]]: memref<?xf32>
// CHECK-NOT: %cast = memref.cast %[[ARG]]

// -----

// Verify SimplifyClones skips clones with preceding deallocation.
// CHECK-LABEL: @clone_and_preceding_dealloc
func.func @clone_and_preceding_dealloc(%arg0: memref<?xf32>) -> memref<32xf32> {
  memref.dealloc %arg0 : memref<?xf32>
  %0 = bufferization.clone %arg0 : memref<?xf32> to memref<32xf32>
  return %0 : memref<32xf32>
}
// CHECK-SAME: %[[ARG:.*]]: memref<?xf32>
// CHECK-NOT: %cast = memref.cast %[[ARG]]

// -----

// A tensor.cast feeding to_memref folds into to_memref of the source plus a
// memref.cast of the resulting buffer.
// CHECK-LABEL: func @tensor_cast_to_memref
//  CHECK-SAME:   %[[ARG0:.+]]: tensor<4x6x16x32xi8>
func.func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) ->
  memref<?x?x16x32xi8> {
  %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
  %1 = bufferization.to_memref %0 : tensor<?x?x16x32xi8> to memref<?x?x16x32xi8>
  return %1 : memref<?x?x16x32xi8>
}
// CHECK:   %[[M:.+]] = bufferization.to_memref %[[ARG0]] : tensor<4x6x16x32xi8>
// CHECK:   %[[M1:.+]] = memref.cast %[[M]]
// CHECK-SAME: memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
// CHECK:   return %[[M1]] : memref<?x?x16x32xi8>

// -----

// Folding of memref.load(to_memref(%v, %idxs)) -> tensor.extract(%v, %idx)
// CHECK-LABEL: func @load_from_buffer_cast(
func.func @load_from_buffer_cast(%arg0: index, %arg1: index,
                            %arg2: tensor<?x?xf32>) -> f32 {
  %0 = bufferization.to_memref %arg2 : tensor<?x?xf32> to memref<?x?xf32>
  %1 = memref.load %0[%arg0, %arg1] : memref<?x?xf32>
  return %1 : f32
}
//  CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
//  CHECK-SAME: %[[TENSOR:[0-9a-z]+]]: tensor<?x?xf32>
//       CHECK: %[[RES:.*]] = tensor.extract %[[TENSOR]][%[[IDX0]], %[[IDX1]]]
//   CHECK-NOT: memref.load
//       CHECK: return %[[RES]] : f32


// -----

// A constant dynamic extent is folded into alloc_tensor's static result type;
// a tensor.cast restores the original dynamic type for users.
func.func @alloc_tensor_canonicalize() -> (tensor<4x5x?xf32>) {
  %c6 = arith.constant 6 : index
  %0 = bufferization.alloc_tensor(%c6) : tensor<4x5x?xf32>
  return %0 : tensor<4x5x?xf32>
}
// CHECK: func @alloc_tensor_canonicalize
// CHECK:   %[[T0:.+]] = bufferization.alloc_tensor() : tensor<4x5x6xf32>
// CHECK:   %[[T1:.+]] = tensor.cast %[[T0]] : tensor<4x5x6xf32> to tensor<4x5x?xf32>
// CHECK:   return %[[T1]]

// -----

// Both the clone and the dealloc of its (reshaped) source are removed.
func.func @dealloc_canonicalize_clone_removal(%arg0: memref<?xindex>) -> memref<*xf32> {
  %c1 = arith.constant 1 : index
  %0 = memref.alloc(%c1) : memref<?xf32>
  %1 = memref.reshape %0(%arg0) : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
  %2 = bufferization.clone %1 : memref<*xf32> to memref<*xf32>
  memref.dealloc %0 : memref<?xf32>
  return %2 : memref<*xf32>
}
// CHECK-LABEL: @dealloc_canonicalize_clone_removal
//   CHECK-NOT:   bufferization.clone
//   CHECK-NOT:   memref.dealloc
//       CHECK:   return {{.*}}

// -----

// Duplicate memrefs in the dealloc and retain lists are deduplicated; the
// conditions of merged duplicates are combined with arith.ori, and results
// for duplicated retain values are reused.
func.func @dealloc_canonicalize_duplicates(%arg0: memref<2xi32>, %arg1: i1, %arg2: i1, %arg3: memref<2xi32>, %arg4: memref<2xi32>, %arg5: memref<2xi32>) -> (i1, i1, i1) {
  %0:3 = bufferization.dealloc (%arg4, %arg0, %arg0 : memref<2xi32>, memref<2xi32>, memref<2xi32>) if (%arg1, %arg1, %arg1) retain (%arg3, %arg5, %arg3 : memref<2xi32>, memref<2xi32>, memref<2xi32>)
  bufferization.dealloc (%arg0, %arg0 : memref<2xi32>, memref<2xi32>) if (%arg1, %arg2)
  return %0#0, %0#1, %0#2 : i1, i1, i1
}

// CHECK-LABEL: func @dealloc_canonicalize_duplicates
//  CHECK-SAME:  ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: i1, [[ARG3:%.+]]: memref<2xi32>, [[ARG4:%.+]]: memref<2xi32>, [[ARG5:%.+]]: memref<2xi32>)
//  CHECK-NEXT:   [[V0:%.+]]:2 = bufferization.dealloc ([[ARG4]], [[ARG0]] : memref<2xi32>, memref<2xi32>) if ([[ARG1]], [[ARG1]]) retain ([[ARG3]], [[ARG5]] : memref<2xi32>, memref<2xi32>)
//  CHECK-NEXT:   [[NEW_COND:%.+]] = arith.ori [[ARG1]], [[ARG2]] : i1
//  CHECK-NEXT:   bufferization.dealloc ([[ARG0]] : memref<2xi32>) if ([[NEW_COND]])
//  CHECK-NEXT:   return [[V0]]#0, [[V0]]#1, [[V0]]#0 :

// -----

// An empty dealloc is erased; a dealloc with only retained values folds its
// results to constant false.
func.func @dealloc_erase_empty(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>) -> i1 {
  bufferization.dealloc
  %0 = bufferization.dealloc retain (%arg0 : memref<2xi32>)
  return %0 : i1
}

// CHECK-LABEL: func @dealloc_erase_empty
//  CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>)
//  CHECK-NEXT: [[FALSE:%.+]] = arith.constant false
//  CHECK-NEXT: return [[FALSE]] :

// -----

// Dealloc operands guarded by a statically-false condition are dropped.
func.func @dealloc_always_false_condition(%arg0: memref<2xi32>, %arg1: memref<2xi32>, %arg2: i1) {
  %false = arith.constant false
  bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>) if (%false, %arg2)
  return
}

// CHECK-LABEL: func @dealloc_always_false_condition
//  CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: memref<2xi32>, [[ARG2:%.+]]: i1)
//  CHECK-NEXT: bufferization.dealloc ([[ARG1]] : {{.*}}) if ([[ARG2]])
//  CHECK-NEXT: return

// -----

// For a locally allocated memref, the base buffer obtained via
// extract_strided_metadata is replaced by the alloc itself in the dealloc
// operand list; the metadata op for the function argument is kept.
func.func @dealloc_base_memref_extract_of_alloc(%arg0: memref<2xi32>, %arg1: i1, %arg2: i1, %arg3: memref<2xi32>) -> memref<2xi32> {
  %alloc = memref.alloc() : memref<2xi32>
  %base0, %size0, %stride0, %offset0 = memref.extract_strided_metadata %alloc : memref<2xi32> -> memref<i32>, index, index, index
  %base1, %size1, %stride1, %offset1 = memref.extract_strided_metadata %arg3 : memref<2xi32> -> memref<i32>, index, index, index
  bufferization.dealloc (%base0, %arg0, %base1 : memref<i32>, memref<2xi32>, memref<i32>) if (%arg1, %arg2, %arg2)
  return %alloc : memref<2xi32>
}

// CHECK-LABEL: func @dealloc_base_memref_extract_of_alloc
//  CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: i1, [[ARG3:%.+]]: memref<2xi32>)
//  CHECK-NEXT: [[ALLOC:%.+]] = memref.alloc() : memref<2xi32>
//  CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG3]] :
//  CHECK-NEXT: bufferization.dealloc ([[ALLOC]], [[ARG0]], [[BASE]] : memref<2xi32>, memref<2xi32>, memref<i32>) if ([[ARG1]], [[ARG2]], [[ARG2]])
//  CHECK-NEXT: return

// -----

// A local alloc that is unconditionally deallocated and otherwise unused is
// removed together with its dealloc operand.
func.func @dealloc_base_memref_extract_of_alloc(%arg0: memref<2xi32>) {
  %true = arith.constant true
  %alloc = memref.alloc() : memref<2xi32>
  bufferization.dealloc (%alloc, %arg0 : memref<2xi32>, memref<2xi32>) if (%true, %true)
  return
}

// CHECK-LABEL: func @dealloc_base_memref_extract_of_alloc
//  CHECK-SAME:([[ARG0:%.+]]: memref<2xi32>)
//   CHECK-NOT: memref.alloc(
//       CHECK: bufferization.dealloc ([[ARG0]] : memref<2xi32>) if (%true

// -----

// A negative constant extent must not be folded into a static result type;
// that dimension stays dynamic (only 10 and 27 become static below).
// CHECK-LABEL: func @negative_input
func.func @negative_input() -> tensor<?x?x?xf16> {
  %idx27 = index.constant 27
  %idx-3 = index.constant -3  // negative integer?
  %c10 = arith.constant 10 : index
// CHECK: bufferization.alloc_tensor
// CHECK-SAME: tensor<10x?x27xf16>
  %11 = bufferization.alloc_tensor(%c10, %idx-3, %idx27) : tensor<?x?x?xf16>
  return %11 : tensor<?x?x?xf16>
}
