// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries" -canonicalize -buffer-loop-hoisting -drop-equivalent-buffer-results -split-input-file | FileCheck %s

// Run fuzzer with different seeds.
// RUN: mlir-opt %s -one-shot-bufferize="test-analysis-only analysis-heuristic=fuzzer analysis-fuzzer-seed=23 bufferize-function-boundaries" -split-input-file -o /dev/null
// RUN: mlir-opt %s -one-shot-bufferize="test-analysis-only analysis-heuristic=fuzzer analysis-fuzzer-seed=59 bufferize-function-boundaries" -split-input-file -o /dev/null
// RUN: mlir-opt %s -one-shot-bufferize="test-analysis-only analysis-heuristic=fuzzer analysis-fuzzer-seed=91 bufferize-function-boundaries" -split-input-file -o /dev/null

// Test bufferization using memref types that have no layout map.
// RUN: mlir-opt %s -one-shot-bufferize="unknown-type-conversion=identity-layout-map function-boundary-type-conversion=identity-layout-map bufferize-function-boundaries" -drop-equivalent-buffer-results -split-input-file | FileCheck %s --check-prefix=CHECK-NO-LAYOUT-MAP

// TODO: Some test cases from this file should be moved to other dialects.

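/// %A is writable, so the fill is expected to bufferize in place into %A's
/// buffer without an allocation.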
// CHECK-LABEL: func @fill_inplace(
//  CHECK-SAME:   %[[A:[a-zA-Z0-9]*]]: memref<?xf32, strided<[?], offset: ?>>
// CHECK-NO-LAYOUT-MAP-LABEL: func @fill_inplace(%{{.*}}: memref<?xf32>) {
func.func @fill_inplace(
    %A : tensor<?xf32> {bufferization.writable = true})
  -> tensor<?xf32>
{
  //     CHECK: %[[F0:.*]] = arith.constant 0.000000e+00 : f32
  %f0 = arith.constant 0.0 : f32

  /// Inplaceable, no alloc
  // CHECK-NOT: alloc
  //     CHECK: linalg.fill ins(%[[F0]] : f32) outs(%[[A]] : memref<?xf32, strided<[?], offset: ?>>)
  %r = linalg.fill ins(%f0 : f32) outs(%A : tensor<?xf32>) -> tensor<?xf32>

  //     CHECK: return
  // CHECK-NOT: tensor
  return %r: tensor<?xf32>
}

// -----

/// %A is not writable (bufferization.writable = false), so the fill must allocate.
// CHECK-LABEL: func @not_inplace(
//  CHECK-SAME:   %[[A:[a-zA-Z0-9]*]]: memref<?xf32, strided<[?], offset: ?>>) -> memref<?xf32> {
// CHECK-NO-LAYOUT-MAP-LABEL: func @not_inplace(%{{.*}}: memref<?xf32>) -> memref<?xf32>
func.func @not_inplace(
    %A : tensor<?xf32> {bufferization.writable = false})
  -> tensor<?xf32>
{
  //     CHECK: %[[F0:.*]] = arith.constant 0.000000e+00 : f32
  %f0 = arith.constant 0.0 : f32

  //     CHECK: %[[D0:.*]] = memref.dim %[[A]], {{.*}} : memref<?xf32, strided<[?], offset: ?>>
  //     CHECK: %[[ALLOC:.*]] = memref.alloc(%[[D0]]) {alignment = 64 : i64} : memref<?xf32>
  //     CHECK: linalg.fill ins(%[[F0]] : f32) outs(%[[ALLOC]] : memref<?xf32>)
  %r = linalg.fill ins(%f0 : f32) outs(%A : tensor<?xf32>) -> tensor<?xf32>

  // CHECK-NOT: dealloc
  //     CHECK: return %[[ALLOC]] : memref<?xf32>
  return %r: tensor<?xf32>
}

// -----


// CHECK-LABEL: func @not_inplace
//  CHECK-SAME:   %[[A:[a-zA-Z0-9]*]]: memref<?x?xf32, strided<[?, ?], offset: ?>>) {
// CHECK-NO-LAYOUT-MAP-LABEL: func @not_inplace(%{{.*}}: memref<?x?xf32>) {
func.func @not_inplace(
    %A : tensor<?x?xf32> {bufferization.writable = true})
  -> tensor<?x?xf32>
{
  %f0 = arith.constant 0.0 : f32

  /// Cross-op multiple uses of %A: the first op, which has interfering reads, must alloc.
  //       CHECK: %[[ALLOC:.*]] = memref.alloc
  //       CHECK: linalg.fill ins({{.*}}{{.*}}outs(%[[ALLOC]]
  %f = linalg.fill ins(%f0 : f32) outs(%A : tensor<?x?xf32>) -> tensor<?x?xf32>

  /// The second op has no interfering reads and can reuse the buffer of %A.
  //   CHECK-NOT: alloc
  //       CHECK: linalg.matmul ins(%[[ALLOC]], %[[ALLOC]]{{.*}}) outs(%[[A]]
  %r = linalg.matmul  ins(%f, %f: tensor<?x?xf32>, tensor<?x?xf32>)
                     outs(%A: tensor<?x?xf32>)
    -> tensor<?x?xf32>

  //     CHECK: return
  // CHECK-NOT: tensor
  return %r: tensor<?x?xf32>
}

// -----

// CHECK-LABEL: func @not_inplace
func.func @not_inplace(
    %A : tensor<?x?xf32> {bufferization.writable = true}) -> tensor<?x?xf32> {
  /// Multiple uses of %A within a single op: must alloc.
  // CHECK: alloc
  %r = linalg.matmul  ins(%A, %A: tensor<?x?xf32>, tensor<?x?xf32>)
                     outs(%A: tensor<?x?xf32>)
    -> tensor<?x?xf32>
  // CHECK-NOT: dealloc
  return %r: tensor<?x?xf32>
}
// -----

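/// %A is writable and has no conflicting uses, so the transfer_write is
/// expected to bufferize in place without an allocation.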
// CHECK-LABEL: func @vec_inplace
func.func @vec_inplace(
    %A : tensor<?xf32> {bufferization.writable = true}, %vec : vector<4xf32>)
  -> tensor<?xf32>
{
  %c0 = arith.constant 0 : index

  // CHECK-NOT: alloc
  %r = vector.transfer_write %vec, %A[%c0] : vector<4xf32>, tensor<?xf32>

  //     CHECK: return
  // CHECK-NOT: tensor
  return %r: tensor<?xf32>
}

// -----

// CHECK-LABEL: func @vec_not_inplace
//  CHECK-SAME:   %[[A:[a-zA-Z0-9]*]]: memref<?xf32, strided<[?], offset: ?>>
func.func @vec_not_inplace(
    %A : tensor<?xf32> {bufferization.writable = true}, %vec : vector<4xf32>)
  -> (tensor<?xf32>, tensor<?xf32>)
{
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index

  /// Cross-op multiple uses of %A: the first vector.transfer, which has interfering reads, must alloc.
  //      CHECK: %[[ALLOC:.*]] = memref.alloc
  //      CHECK: memref.copy {{.*}}, %[[ALLOC]]
  // CHECK-NEXT: vector.transfer_write {{.*}}, %[[ALLOC]]
  %r0 = vector.transfer_write %vec, %A[%c0] : vector<4xf32>, tensor<?xf32>

  /// The second vector.transfer has no interfering reads and can reuse the buffer.
  //  CHECK-NOT: alloc
  // CHECK-NEXT: vector.transfer_write {{.*}}, %[[A]]
  %r1 = vector.transfer_write %vec, %A[%c1] : vector<4xf32>, tensor<?xf32>

  //     CHECK: return
  // CHECK-NOT: tensor
  return %r0, %r1: tensor<?xf32>, tensor<?xf32>
}

// -----

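/// Tiled matmul: the tile buffer is expected to be hoisted out of the loop
/// nest, and the artificially out-of-place extract_slice from %C below forces
/// a copy back into the large output buffer.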
//      CHECK: func @matmul(
// CHECK-SAME:   %[[A:[0-9a-zA-Z]*]]: memref<128x256xf32>
// CHECK-SAME:   %[[B:[0-9a-zA-Z]*]]: memref<256x192xf32>
// CHECK-SAME:   %[[C:[0-9a-zA-Z]*]]: memref<128x192xf32>
func.func @matmul(
    %A: tensor<128x256xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, bufferization.writable = false},
    %B: tensor<256x192xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, bufferization.writable = false},
    %C: tensor<128x192xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, bufferization.writable = true})
  -> tensor<128x192xf32> {
  %c0 = arith.constant 0 : index
  %c256 = arith.constant 256 : index
  %c32 = arith.constant 32 : index
  %cst = arith.constant 0.000000e+00 : f32
  %c128 = arith.constant 128 : index
  %c192 = arith.constant 192 : index
  %c8 = arith.constant 8 : index
  %c16 = arith.constant 16 : index

  // Hoisted alloc.
  // CHECK: %[[ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x16xf32>

  // CHECK: scf.for %[[I:.*]] =
  %0 = scf.for %arg3 = %c0 to %c128 step %c8 iter_args(%arg4 = %C) -> (tensor<128x192xf32>) {
    %1 = tensor.extract_slice %A[%arg3, 0] [8, 256] [1, 1] :
      tensor<128x256xf32> to tensor<8x256xf32>

    // CHECK: scf.for %[[J:.*]] =
    %2 = scf.for %arg5 = %c0 to %c192 step %c16 iter_args(%arg6 = %arg4) -> (tensor<128x192xf32>) {
      %3 = tensor.extract_slice %B[0, %arg5] [256, 16] [1, 1] :
        tensor<256x192xf32> to tensor<256x16xf32>

      // Extract from %C instead of %arg6 to create an artificial out-of-place
      // bufferization.
      %4 = tensor.extract_slice %C[%arg3, %arg5] [8, 16] [1, 1] :
        tensor<128x192xf32> to tensor<8x16xf32>

      // CHECK: linalg.fill ins(%{{.*}} : f32) outs(%[[ALLOC]]
      %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<8x16xf32>) -> tensor<8x16xf32>

      // CHECK: scf.for %[[K:.*]] =
      %6 = scf.for %arg7 = %c0 to %c256 step %c32 iter_args(%arg8 = %5) -> (tensor<8x16xf32>) {
        %8 = tensor.extract_slice %1[0, %arg7] [8, 32] [1, 1] :
          tensor<8x256xf32> to tensor<8x32xf32>
        %9 = tensor.extract_slice %3[%arg7, 0] [32, 16] [1, 1] :
          tensor<256x16xf32> to tensor<32x16xf32>

        // linalg.matmul bufferizes in place, as does the enclosing scf.for.
        // CHECK: linalg.matmul ins({{.*}} outs(%[[ALLOC]]
        %10 = linalg.matmul ins(%8, %9 : tensor<8x32xf32>, tensor<32x16xf32>)
                           outs(%arg8 : tensor<8x16xf32>)
          -> tensor<8x16xf32>
        scf.yield %10 : tensor<8x16xf32>
      }

      // The insert_slice is in place, but its source comes from an equivalent
      // buffer that was bufferized out of place, so the small buffer must be
      // copied into the bigger buffer.
      // CHECK: %[[T:.*]] = memref.subview %[[C]][%[[I]], %[[J]]] [8, 16] [1, 1]
      // CHECK: memref.copy %[[ALLOC]], %[[T]]
      %7 = tensor.insert_slice %6 into %arg6[%arg3, %arg5] [8, 16] [1, 1] :
        tensor<8x16xf32> into tensor<128x192xf32>

      scf.yield %7 : tensor<128x192xf32>
    }
    scf.yield %2 : tensor<128x192xf32>
  }

  return %0 : tensor<128x192xf32>
}

// -----

/// This test just checks that the produced IR is valid and has no dominance
/// errors in its def-use chains.

// CHECK-LABEL: func @dominance_violation_bug_1
func.func @dominance_violation_bug_1(
    %A : tensor<?x?xf32> {bufferization.writable = false},
    %idx : index)
  -> tensor<?x?xf32>
{
  %f0 = arith.constant 0.0 : f32

  %sA = tensor.extract_slice %A[0, 0][%idx, %idx][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
  %ssA = tensor.extract_slice %sA[0, 0][4, 4][1, 1] : tensor<?x?xf32> to tensor<4x4xf32>
  %FA = linalg.fill ins(%f0 : f32) outs(%ssA : tensor<4x4xf32>) -> tensor<4x4xf32>
  %rsA = tensor.insert_slice %FA into %sA[0, 0][4, 4][1, 1] : tensor<4x4xf32> into tensor<?x?xf32>
  %rA = tensor.insert_slice %rsA into %A[0, 0][%idx, %idx][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>

  return %rA : tensor<?x?xf32>
}

// -----

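/// Gather-like pattern: the body reads %arg0 via tensor.extract, which is
/// expected to bufferize to a memref.load from the read-only input, while the
/// output %arg2 is reused in place.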
func.func @gather_like(
    %arg0 : tensor<?x?xf32> {bufferization.writable = false},
    %arg1 : tensor<?xi32> {bufferization.writable = false},
    %arg2 : tensor<?x?xf32> {bufferization.writable = true})
  -> tensor<?x?xf32>
{
  %0 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0)>,
                       affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%arg1 : tensor<?xi32>) outs(%arg2 : tensor<?x?xf32>) {
      ^bb0(%arg3: i32, %arg4 : f32):
        %iv1 = linalg.index 1 : index
        %1 = arith.index_cast %arg3: i32 to index
        %2 = tensor.extract %arg0[%1, %iv1] : tensor<?x?xf32>
        linalg.yield %2 : f32
      } -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
// CHECK-LABEL: func @gather_like(
//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32,
//  CHECK-SAME:     %[[ARG1:.+]]: memref<?xi32
//  CHECK-SAME:     %[[ARG2:.+]]: memref<?x?xf32
//  CHECK-SAME:   ) {
//       CHECK:   linalg.generic
//  CHECK-SAME:       ins(%[[ARG1]] :
//  CHECK-SAME:       outs(%[[ARG2]] :
//       CHECK:     %[[YIELD:.+]] = memref.load %[[ARG0]]
//       CHECK:     linalg.yield %[[YIELD]]

// -----

// CHECK-LABEL: func @linalg_op_bufferizes_inplace_with_input
//  CHECK-SAME:     %[[t1:.*]]: memref<?x?xf32, strided{{.*}}>, %[[t2:.*]]: memref<?xf32, strided{{.*}}>, %[[t3:.*]]: memref<?x?xf32, strided{{.*}}>
func.func @linalg_op_bufferizes_inplace_with_input(
    %t1: tensor<?x?xf32> {bufferization.writable = true},
    %t2: tensor<?xf32> {bufferization.writable = true},
    %t3: tensor<?x?xf32> {bufferization.writable = true},
    %s1: index, %s2: index, %cst: f32)
  -> tensor<?x?xf32>
{
  // CHECK: linalg.generic {{.*}} ins(%[[t1]], %[[t2]] : {{.*}}) outs(%[[t3]] : {{.*}})
  %r = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d1)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]}
    ins(%t1, %t2 : tensor<?x?xf32>, tensor<?xf32>)
    outs(%t3 : tensor<?x?xf32>) {
      ^bb0(%arg0 : f32, %arg1 : f32, %arg2 : f32) :
        %add = arith.addf %arg0, %arg1 : f32
        linalg.yield %add : f32
    } -> tensor<?x?xf32>
  return %r : tensor<?x?xf32>
}

// -----

#accesses = [
  affine_map<(i) -> (i)>
]
#trait = {
  indexing_maps = #accesses,
  iterator_types = ["parallel"]
}

// CHECK-LABEL: func @op_is_reading_but_following_ops_are_not
//  CHECK-SAME:     %[[t0:.*]]: memref<?xf32
func.func @op_is_reading_but_following_ops_are_not(
    %t0 : tensor<?xf32> {bufferization.writable = false},
    %cst : f32)
  -> tensor<?xf32>
{
  // Make sure that a copy is inserted here.
  // CHECK: %[[ALLOC:.*]] = memref.alloc
  // CHECK: memref.copy %[[t0]], %[[ALLOC]]
  // CHECK: linalg.generic {{.*}} outs(%[[ALLOC]] : memref
  %r0 = linalg.generic #trait outs (%t0 : tensor<?xf32>) {
      ^bb(%0: f32) :
        %a = arith.addf %cst, %0 : f32
        linalg.yield %a : f32
    } -> (tensor<?xf32>)

  // CHECK: linalg.generic {{.*}} outs(%[[ALLOC]] : memref
  %r1 = linalg.generic #trait outs (%r0 : tensor<?xf32>) {
      ^bb(%0: f32) :
        linalg.yield %cst : f32
    } -> (tensor<?xf32>)

  // CHECK: return %[[ALLOC]]
  return %r1 : tensor<?xf32>
}

// -----

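/// linalg.map is expected to survive bufferization as-is, with its ins/outs
/// switched to memref operands.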
// CHECK-LABEL: func @map_binary
// CHECK-SAME:  %[[LHS:[0-9a-zA-Z]*]]: memref<64xf32
// CHECK-SAME:  %[[RHS:[0-9a-zA-Z]*]]: memref<64xf32
func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
                      %init: tensor<64xf32>) -> tensor<64xf32> {
   // CHECK:      linalg.map { arith.addf } ins(%[[LHS]], %[[RHS]] : memref<64xf32
   %add = linalg.map
          ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
          outs(%init:tensor<64xf32>)
          (%lhs_elem: f32, %rhs_elem: f32) {
            %0 = arith.addf %lhs_elem, %rhs_elem: f32
            linalg.yield %0: f32
          }
  func.return %add : tensor<64xf32>
}

// -----

// CHECK-LABEL: func @reduce
// CHECK-SAME:  %[[INPUT:.*]]: memref<16x32x64xf32
func.func @reduce(%input: tensor<16x32x64xf32>,
                  %init: tensor<16x64xf32>) -> tensor<16x64xf32> {
  // CHECK:     linalg.reduce { arith.addf } ins(%[[INPUT]] : memref<16x32x64xf32
  %reduce = linalg.reduce
      ins(%input:tensor<16x32x64xf32>)
      outs(%init:tensor<16x64xf32>)
      dimensions = [1]
      (%in: f32, %out: f32) {
        %0 = arith.addf %out, %in: f32
        linalg.yield %0: f32
      }
  func.return %reduce : tensor<16x64xf32>
}

// -----

// CHECK-LABEL: func @transpose
// CHECK-SAME:  %[[ARG0:.*]]: memref<16x32x64xf32
func.func @transpose(%input: tensor<16x32x64xf32>,
                     %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
  // CHECK:      linalg.transpose ins(%[[ARG0]] : memref<16x32x64xf32
  %transpose = linalg.transpose
      ins(%input:tensor<16x32x64xf32>)
      outs(%init:tensor<32x64x16xf32>)
      permutation = [1, 2, 0]
  func.return %transpose : tensor<32x64x16xf32>
}

// -----

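/// Only the bufferized function signature is checked here: the broadcast input
/// is expected to become a memref<8x32xf32> argument.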
// CHECK-LABEL: func @broadcast
// CHECK-SAME:  %[[ARG0:.*]]: memref<8x32xf32
func.func @broadcast(%input: tensor<8x32xf32>,
                     %init: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
  %bcast = linalg.broadcast
      ins(%input:tensor<8x32xf32>)
      outs(%init:tensor<8x16x32xf32>)
      dimensions = [1]
  func.return %bcast : tensor<8x16x32xf32>
}

// -----

//===----------------------------------------------------------------------===//
// AllocTensorOp elimination would produce SSA violations for the example below.
//===----------------------------------------------------------------------===//

func.func @depthwise_conv_1d_nwc_wc(%arg0: index, %arg1: index, %arg2: tensor<8x18x32xf32>)
    -> tensor<?x1x6x8xf32> {
  %c0 = arith.constant 0 : index
  %c32 = arith.constant 32 : index
  %c8 = arith.constant 8 : index
  %0 = bufferization.alloc_tensor() : tensor<4x1x6x8xf32>
  %1 = tensor.cast %0 : tensor<4x1x6x8xf32> to tensor<?x1x6x8xf32>
  %2 = bufferization.alloc_tensor() : tensor<1x6x8xf32>
  %3 = scf.for %arg3 = %c0 to %c32 step %c8 iter_args(%arg4 = %1) -> (tensor<?x1x6x8xf32>) {
    %4 = affine.apply affine_map<(d0) -> (d0 ceildiv 8)>(%arg3)
    %5 = tensor.insert_slice %2 into %arg4[%4, 0, 0, 0] [1, 1, 6, 8] [1, 1, 1, 1] :
      tensor<1x6x8xf32> into tensor<?x1x6x8xf32>
    scf.yield %5 : tensor<?x1x6x8xf32>
  }
  return %3 : tensor<?x1x6x8xf32>
}

// -----

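/// Two insertions into the same alloc_tensor: expect two separate allocations
/// and no copy.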
// CHECK-LABEL: func @do_not_copy_alloc_tensors(
func.func @do_not_copy_alloc_tensors(%f1: f32, %f2: f32, %idx: index)
  -> (tensor<5xf32>, tensor<5xf32>)
{
  // CHECK: memref.alloc
  // CHECK: memref.alloc
  // CHECK-NOT: copy
  // CHECK: memref.store
  // CHECK: memref.store
  %0 = bufferization.alloc_tensor() : tensor<5xf32>
  %1 = tensor.insert %f1 into %0[%idx] : tensor<5xf32>
  %2 = tensor.insert %f2 into %0[%idx] : tensor<5xf32>
  return %1, %2 : tensor<5xf32>, tensor<5xf32>
}
