// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries test-analysis-only" -split-input-file | FileCheck %s

// Two distinct input tensors with %a reused as the init operand. Element-wise
// access lets the analysis mark all three operands in-place.
// CHECK-LABEL: @elementwise_no_conflict
func.func @elementwise_no_conflict(%a: tensor<5xf32>,
                                   %b: tensor<5xf32>) -> tensor<5xf32> {
  // CHECK: linalg.elemwise_binary
  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "true"], fun = #linalg.binary_fn<add>}
  %0 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
      ins(%a, %b : tensor<5xf32>, tensor<5xf32>)
      outs(%a : tensor<5xf32>) -> tensor<5xf32>
  return %0 : tensor<5xf32>
}

// -----

// The same SSA value feeds both inputs and the init operand; element-wise
// access still permits fully in-place bufferization.
// CHECK-LABEL: @elementwise_no_conflict_2
func.func @elementwise_no_conflict_2(%a: tensor<5xf32>) -> tensor<5xf32> {
  // CHECK: linalg.elemwise_binary
  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "true"], fun = #linalg.binary_fn<add>}
  %0 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
      ins(%a, %a : tensor<5xf32>, tensor<5xf32>)
      outs(%a : tensor<5xf32>) -> tensor<5xf32>
  return %0 : tensor<5xf32>
}

// -----

// Same as above, but the second input is a scalar (f32). Non-tensor operands
// do not bufferize, so their in-place attribute is "none".
// CHECK-LABEL: @elementwise_no_conflict_3
func.func @elementwise_no_conflict_3(%a: tensor<5xf32>) -> tensor<5xf32> {
  %c0f = arith.constant 1.0 : f32
  // CHECK: linalg.elemwise_binary
  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "none", "true"], fun = #linalg.binary_fn<add>}
  %0 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
      ins(%a, %c0f : tensor<5xf32>, f32)
      outs(%a : tensor<5xf32>) -> tensor<5xf32>
  return %0 : tensor<5xf32>
}

// -----

// Negative case: %b is read through a broadcast map (d0, d1) -> (d1), so the
// generic op is not element-wise w.r.t. %a and the slice must be copied.
// CHECK-LABEL: @not_elementwise
func.func @not_elementwise(%a: tensor<5x6xf32>) -> tensor<5x6xf32> {
  // CHECK: tensor.extract_slice
  // CHECK-SAME: {__inplace_operands_attr__ = ["false"]}
  %b = tensor.extract_slice %a[0, 0] [1, 6] [1, 1]
      : tensor<5x6xf32> to tensor<6xf32>
  // CHECK: linalg.generic
  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]}
  %0 = linalg.generic
    { iterator_types = ["parallel", "parallel"],
      indexing_maps = [ affine_map<(d0, d1) -> (d1)>,
                        affine_map<(d0, d1) -> (d0, d1)>] }
    ins(%b: tensor<6xf32>) outs(%a: tensor<5x6xf32>) {
    ^bb0(%arg0: f32, %arg1: f32):
      %r = arith.addf %arg0, %arg1 : f32
      linalg.yield %r : f32
    } -> tensor<5x6xf32>
  return %0 : tensor<5x6xf32>
}

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d1)>

// Multiple ops inside an scf.forall all write into slices of the shared
// output; element-wise access and slice equivalence keep everything in-place.
// CHECK-LABEL: @elementwise_no_conflict_4
func.func @elementwise_no_conflict_4(%arg0: tensor<8x32x32x32xf32>, %arg1: tensor<32x32x32xf32>) -> tensor<8x32x32x32xf32> {
  %cst = arith.constant dense<3.000000e-02> : tensor<32x32x32xf32>
  %cst_0 = arith.constant dense<6.000000e-01> : tensor<32xf32>
  %cst_1 = arith.constant 0.000000e+00 : f32
  %r = scf.forall (%arg2, %arg3) in (8, 32) shared_outs(%arg4 = %arg0) -> (tensor<8x32x32x32xf32>) {
    // CHECK: tensor.extract_slice
    // CHECK-SAME: {__inplace_operands_attr__ = ["true", "none", "none"]}
    %extracted_slice = tensor.extract_slice %arg4[%arg2, %arg3, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<8x32x32x32xf32> to tensor<32x32xf32>

    // CHECK: linalg.fill
    // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"]}
    %4 = linalg.fill ins(%cst_1 : f32) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32>

    // CHECK: linalg.batch_reduce_matmul
    // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "true"]}
    %5 = linalg.batch_reduce_matmul ins(%arg1, %cst : tensor<32x32x32xf32>, tensor<32x32x32xf32>) outs(%4 : tensor<32x32xf32>) -> tensor<32x32xf32>

    // CHECK: linalg.generic
    // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "true"]}
    // %cst_0 may have a non-identity layout, but %5 and %extracted_slice still
    // bufferize to element-wise access.
    %6 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]} ins(%5, %cst_0 : tensor<32x32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
    ^bb0(%in: f32, %in_4: f32, %out: f32):
      %8 = arith.addf %in, %in_4 : f32
      linalg.yield %8 : f32
    } -> tensor<32x32xf32>

    // CHECK: linalg.generic
    // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]}
    // They are different SSA values, but %6 and %extracted_slice are equivalent.
    %7 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%6 : tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
    ^bb0(%in: f32, %out: f32):
      %8 = arith.maximumf %in, %cst_1 : f32
      linalg.yield %8 : f32
    } -> tensor<32x32xf32>
    scf.forall.in_parallel {
      // CHECK: tensor.parallel_insert_slice
      // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "none", "none"]}
      tensor.parallel_insert_slice %7 into %arg4[%arg2, %arg3, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xf32> into tensor<8x32x32x32xf32>
    }
  }
  return %r : tensor<8x32x32x32xf32>
}

// -----

// Regression test: the fill result %1 is read inside the loop while %0 is
// written there, so the fill must be marked out-of-place.
// CHECK-LABEL: func @elementwise_access_regression(
//       CHECK:   linalg.fill {__inplace_operands_attr__ = ["none", "false"]}
//       CHECK:   linalg.map
//  CHECK-SAME:   {__inplace_operands_attr__ = ["true", "true", "true"]}
//       CHECK:   linalg.map
//  CHECK-SAME:   {__inplace_operands_attr__ = ["true", "true", "true"]}
func.func private @f(%arg: tensor<32x1xf32>) -> ()
func.func @elementwise_access_regression(%arg0: i32, %arg2: tensor<32x1xf32>, %arg3: tensor<32x1xf32>) {
  %cst_0 = arith.constant 0.000000e+00 : f32
  %c0_i32 = arith.constant 0 : i32
  %c1_i32 = arith.constant 1 : i32
  %0 = tensor.empty() : tensor<32x1xf32>

  // This op must bufferize out-of-place so that the filled tensor is not
  // overwritten by the ops inside of the loop.
  %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<32x1xf32>) -> tensor<32x1xf32>

  scf.for %arg1 = %c0_i32 to %arg0 step %c1_i32 : i32 {
    %2 = linalg.map { arith.subf } ins(%1, %arg2 : tensor<32x1xf32>, tensor<32x1xf32>) outs(%0 : tensor<32x1xf32>)
    %3 = tensor.empty() : tensor<32x1xf32>
    %4 = linalg.map { arith.subf } ins(%2, %arg3 : tensor<32x1xf32>, tensor<32x1xf32>) outs(%3 : tensor<32x1xf32>)
    func.call @f(%4) : (tensor<32x1xf32>) -> ()
  }
  return
}