xref: /llvm-project/mlir/test/Dialect/Affine/parallelize.mlir (revision 227ed2f448e26cfc646e30ac17b03eac9578c297)
1// RUN: mlir-opt %s -allow-unregistered-dialect -affine-parallelize | FileCheck %s
2// RUN: mlir-opt %s -allow-unregistered-dialect -affine-parallelize='max-nested=1' | FileCheck --check-prefix=MAX-NESTED %s
3// RUN: mlir-opt %s -allow-unregistered-dialect -affine-parallelize='parallel-reductions=1' | FileCheck --check-prefix=REDUCE %s
4
// Windowed max reduction (3x3 window, stride 2) checked under the default pass
// options. The first nest only stores %cst into %0, and all four of its loops
// are expected to become affine.parallel. In the second nest the %arg5 and
// %arg6 loops (trip count 3) repeatedly load and store the same element
// %0[%arg0, %arg1, %arg2, %arg3]; the expectations keep those two loops as
// affine.for. The other six loops, including the single-iteration %arg4 and
// %arg7 loops, are expected to become affine.parallel.
5// CHECK-LABEL:    func @reduce_window_max() {
6func.func @reduce_window_max() {
7  %cst = arith.constant 0.000000e+00 : f32
8  %0 = memref.alloc() : memref<1x8x8x64xf32>
9  %1 = memref.alloc() : memref<1x18x18x64xf32>
  // Initialization nest: each iteration writes a distinct element of %0.
10  affine.for %arg0 = 0 to 1 {
11    affine.for %arg1 = 0 to 8 {
12      affine.for %arg2 = 0 to 8 {
13        affine.for %arg3 = 0 to 64 {
14          affine.store %cst, %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32>
15        }
16      }
17    }
18  }
  // Reduction nest: %arg4..%arg7 walk the window of %1, while the accumulator
  // element of %0 is addressed by %arg0..%arg3 only.
19  affine.for %arg0 = 0 to 1 {
20    affine.for %arg1 = 0 to 8 {
21      affine.for %arg2 = 0 to 8 {
22        affine.for %arg3 = 0 to 64 {
23          affine.for %arg4 = 0 to 1 {
24            affine.for %arg5 = 0 to 3 {
25              affine.for %arg6 = 0 to 3 {
26                affine.for %arg7 = 0 to 1 {
27                  %2 = affine.load %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32>
28                  %3 = affine.load %1[%arg0 + %arg4, %arg1 * 2 + %arg5, %arg2 * 2 + %arg6, %arg3 + %arg7] : memref<1x18x18x64xf32>
29                  %4 = arith.cmpf ogt, %2, %3 : f32
30                  %5 = arith.select %4, %2, %3 : f32
31                  affine.store %5, %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32>
32                }
33              }
34            }
35          }
36        }
37      }
38    }
39  }
40  return
41}
42
43// CHECK:        %[[cst:.*]] = arith.constant 0.000000e+00 : f32
44// CHECK:        %[[v0:.*]] = memref.alloc() : memref<1x8x8x64xf32>
45// CHECK:        %[[v1:.*]] = memref.alloc() : memref<1x18x18x64xf32>
46// CHECK:        affine.parallel (%[[arg0:.*]]) = (0) to (1) {
47// CHECK:          affine.parallel (%[[arg1:.*]]) = (0) to (8) {
48// CHECK:            affine.parallel (%[[arg2:.*]]) = (0) to (8) {
49// CHECK:              affine.parallel (%[[arg3:.*]]) = (0) to (64) {
50// CHECK:                affine.store %[[cst]], %[[v0]][%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]]] : memref<1x8x8x64xf32>
51// CHECK:              }
52// CHECK:            }
53// CHECK:          }
54// CHECK:        }
55// CHECK:        affine.parallel (%[[a0:.*]]) = (0) to (1) {
56// CHECK:          affine.parallel (%[[a1:.*]]) = (0) to (8) {
57// CHECK:            affine.parallel (%[[a2:.*]]) = (0) to (8) {
58// CHECK:              affine.parallel (%[[a3:.*]]) = (0) to (64) {
59// CHECK:                affine.parallel (%[[a4:.*]]) = (0) to (1) {
60// CHECK:                  affine.for %[[a5:.*]] = 0 to 3 {
61// CHECK:                    affine.for %[[a6:.*]] = 0 to 3 {
62// CHECK:                      affine.parallel (%[[a7:.*]]) = (0) to (1) {
63// CHECK:                        %[[lhs:.*]] = affine.load %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] : memref<1x8x8x64xf32>
64// CHECK:                        %[[rhs:.*]] = affine.load %[[v1]][%[[a0]] + %[[a4]], %[[a1]] * 2 + %[[a5]], %[[a2]] * 2 + %[[a6]], %[[a3]] + %[[a7]]] : memref<1x18x18x64xf32>
65// CHECK:                        %[[res:.*]] = arith.cmpf ogt, %[[lhs]], %[[rhs]] : f32
66// CHECK:                        %[[sel:.*]] = arith.select %[[res]], %[[lhs]], %[[rhs]] : f32
67// CHECK:                        affine.store %[[sel]], %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] : memref<1x8x8x64xf32>
68// CHECK:                      }
69// CHECK:                    }
70// CHECK:                  }
71// CHECK:                }
72// CHECK:              }
73// CHECK:            }
74// CHECK:          }
75// CHECK:        }
76// CHECK:      }
77
// Three-deep nest with a symbolic bound %N. Every iteration of the %k loop
// stores to the same element %2[%i, %j], so the default run expects only the
// %i and %j loops to become affine.parallel while the %k loop stays affine.for.
// The label anchors the expectations to this function, and the bound operand
// is captured from the signature instead of relying on its printed name.
// CHECK-LABEL: func @loop_nest_3d_outer_two_parallel
// CHECK-SAME:  (%[[N:.*]]: index)
78func.func @loop_nest_3d_outer_two_parallel(%N : index) {
79  %0 = memref.alloc() : memref<1024 x 1024 x vector<64xf32>>
80  %1 = memref.alloc() : memref<1024 x 1024 x vector<64xf32>>
81  %2 = memref.alloc() : memref<1024 x 1024 x vector<64xf32>>
82  affine.for %i = 0 to %N {
83    affine.for %j = 0 to %N {
84      %7 = affine.load %2[%i, %j] : memref<1024x1024xvector<64xf32>>
85      affine.for %k = 0 to %N {
86        %5 = affine.load %0[%i, %k] : memref<1024x1024xvector<64xf32>>
87        %6 = affine.load %1[%k, %j] : memref<1024x1024xvector<64xf32>>
88        %8 = arith.mulf %5, %6 : vector<64xf32>
89        %9 = arith.addf %7, %8 : vector<64xf32>
90        affine.store %9, %2[%i, %j] : memref<1024x1024xvector<64xf32>>
91      }
92    }
93  }
94  return
95}
96
97// CHECK:      affine.parallel (%[[arg1:.*]]) = (0) to (symbol(%[[N]])) {
98// CHECK-NEXT:        affine.parallel (%[[arg2:.*]]) = (0) to (symbol(%[[N]])) {
99// CHECK:          affine.for %[[arg3:.*]] = 0 to %[[N]] {
100
// The loop body holds only an op from an unregistered dialect ("unknown").
// The default run expects the loop to be left as an affine.for.
101// CHECK-LABEL: unknown_op_conservative
102func.func @unknown_op_conservative() {
103  affine.for %i = 0 to 10 {
104// CHECK:  affine.for %[[arg1:.*]] = 0 to 10 {
105    "unknown"() : () -> ()
106  }
107  return
108}
109
// The loop body reads memory through memref.load rather than affine.load.
// The default run expects the loop to be left as an affine.for.
110// CHECK-LABEL: non_affine_load
111func.func @non_affine_load() {
112  %0 = memref.alloc() : memref<100 x f32>
113  affine.for %i = 0 to 100 {
114// CHECK:  affine.for %{{.*}} = 0 to 100 {
115    memref.load %0[%i] : memref<100 x f32>
116  }
117  return
118}
119
// The lower bound is a max over two operands and the upper bound a min over
// two operands. The default run expects an affine.parallel whose bounds keep
// the max(...) / min(...) form.
120// CHECK-LABEL: for_with_minmax
121func.func @for_with_minmax(%m: memref<?xf32>, %lb0: index, %lb1: index,
122                      %ub0: index, %ub1: index) {
123  // CHECK: affine.parallel (%{{.*}}) = (max(%{{.*}}, %{{.*}})) to (min(%{{.*}}, %{{.*}}))
124  affine.for %i = max affine_map<(d0, d1) -> (d0, d1)>(%lb0, %lb1)
125          to min affine_map<(d0, d1) -> (d0, d1)>(%ub0, %ub1) {
126    affine.load %m[%i] : memref<?xf32>
127  }
128  return
129}
130
// Two-deep nest where the inner loop's max lower bound uses the outer
// induction variable %j. The default run expects both loops to become
// affine.parallel, with the second max operand of the inner loop being the
// outer parallel induction variable (captured as I).
131// CHECK-LABEL: nested_for_with_minmax
132func.func @nested_for_with_minmax(%m: memref<?xf32>, %lb0: index,
133                             %ub0: index, %ub1: index) {
134  // CHECK: affine.parallel (%[[I:.*]]) =
135  affine.for %j = 0 to 10 {
136    // CHECK: affine.parallel (%{{.*}}) = (max(%{{.*}}, %[[I]])) to (min(%{{.*}}, %{{.*}}))
137    affine.for %i = max affine_map<(d0, d1) -> (d0, d1)>(%lb0, %j)
138            to min affine_map<(d0, d1) -> (d0, d1)>(%ub0, %ub1) {
139      affine.load %m[%i] : memref<?xf32>
140    }
141  }
142  return
143}
144
// Checked only by the max-nested=1 run: the outer loop is expected to become
// affine.parallel while the inner loop is expected to stay an affine.for.
145// MAX-NESTED-LABEL: @max_nested
146func.func @max_nested(%m: memref<?x?xf32>, %lb0: index, %lb1: index,
147                 %ub0: index, %ub1: index) {
148  // MAX-NESTED: affine.parallel
149  affine.for %i = affine_map<(d0) -> (d0)>(%lb0) to affine_map<(d0) -> (d0)>(%ub0) {
150    // MAX-NESTED: affine.for
151    affine.for %j = affine_map<(d0) -> (d0)>(%lb1) to affine_map<(d0) -> (d0)>(%ub1) {
152      affine.load %m[%i, %j] : memref<?x?xf32>
153    }
154  }
155  return
156}
157
// Checked only by the max-nested=1 run on a three-deep nest: the outermost
// loop is expected to become affine.parallel, and the next two output lines
// are expected to be affine.for loops (the -NEXT form requires them to be
// directly consecutive in the output).
158// MAX-NESTED-LABEL: @max_nested_1
159func.func @max_nested_1(%arg0: memref<4096x4096xf32>, %arg1: memref<4096x4096xf32>, %arg2: memref<4096x4096xf32>) {
160  %0 = memref.alloc() : memref<4096x4096xf32>
161  // MAX-NESTED: affine.parallel
162  affine.for %arg3 = 0 to 4096 {
163    // MAX-NESTED-NEXT: affine.for
164    affine.for %arg4 = 0 to 4096 {
165      // MAX-NESTED-NEXT: affine.for
166      affine.for %arg5 = 0 to 4096 {
167        %1 = affine.load %arg0[%arg3, %arg5] : memref<4096x4096xf32>
168        %2 = affine.load %arg1[%arg5, %arg4] : memref<4096x4096xf32>
169        %3 = affine.load %0[%arg3, %arg4] : memref<4096x4096xf32>
170        %4 = arith.mulf %1, %2 : f32
171        %5 = arith.addf %3, %4 : f32
172        affine.store %5, %0[%arg3, %arg4] : memref<4096x4096xf32>
173      }
174    }
175  }
176  return
177}
178
// Loop carrying a single f32 sum through iter_args.
// Default run: no affine.parallel is expected anywhere in this function.
// parallel-reductions=1 run: the loop is expected to become an affine.parallel
// with an "addf" reduction that yields the loaded value directly (no arith.addf
// between the load and the yield), followed after the loop by an arith.addf
// combining the initial constant with the reduced result.
179// CHECK-LABEL: @iter_args
180// REDUCE-LABEL: @iter_args
181func.func @iter_args(%in: memref<10xf32>) {
182  // REDUCE: %[[init:.*]] = arith.constant
183  %cst = arith.constant 0.000000e+00 : f32
184  // CHECK-NOT: affine.parallel
185  // REDUCE: %[[reduced:.*]] = affine.parallel (%{{.*}}) = (0) to (10) reduce ("addf")
186  %final_red = affine.for %i = 0 to 10 iter_args(%red_iter = %cst) -> (f32) {
187    // REDUCE: %[[red_value:.*]] = affine.load
188    %ld = affine.load %in[%i] : memref<10xf32>
189    // REDUCE-NOT: arith.addf
190    %add = arith.addf %red_iter, %ld : f32
191    // REDUCE: affine.yield %[[red_value]]
192    affine.yield %add : f32
193  }
194  // REDUCE: arith.addf %[[init]], %[[reduced]]
195  return
196}
197
// Outer loop without iter_args wrapping an inner sum carried through iter_args.
// Default run: exactly one affine.parallel is expected (the outer loop); no
// further affine.parallel may follow it in this function.
// parallel-reductions=1 run: an affine.parallel followed by an "addf"
// reduction clause is expected.
198// CHECK-LABEL: @nested_iter_args
199// REDUCE-LABEL: @nested_iter_args
200func.func @nested_iter_args(%in: memref<20x10xf32>) {
201  %cst = arith.constant 0.000000e+00 : f32
202  // CHECK: affine.parallel
203  affine.for %i = 0 to 20 {
204    // CHECK-NOT: affine.parallel
205    // REDUCE: affine.parallel
206    // REDUCE: reduce ("addf")
207    %final_red = affine.for %j = 0 to 10 iter_args(%red_iter = %cst) -> (f32) {
208      %ld = affine.load %in[%i, %j] : memref<20x10xf32>
209      %add = arith.addf %red_iter, %ld : f32
210      affine.yield %add : f32
211    }
212  }
213  return
214}
215
// Two iter args are added together and the same sum is yielded for both, so
// each carried value depends on the other one. The parallel-reductions=1 run
// expects this loop not to be converted to affine.parallel.
216// REDUCE-LABEL: @strange_butterfly
217func.func @strange_butterfly() {
218  %cst1 = arith.constant 0.0 : f32
219  %cst2 = arith.constant 1.0 : f32
220  // REDUCE-NOT: affine.parallel
221  affine.for %i = 0 to 10 iter_args(%it1 = %cst1, %it2 = %cst2) -> (f32, f32) {
222    %0 = arith.addf %it1, %it2 : f32
223    affine.yield %0, %0 : f32, f32
224  }
225  return
226}
227
228// An iter arg is used more than once. This is not a simple reduction and
229// should not be parallelized.
// Here %it1 appears as both operands of the arith.addf. Only the
// parallel-reductions=1 run checks this function.
230// REDUCE-LABEL: @repeated_use
231func.func @repeated_use() {
232  %cst1 = arith.constant 0.0 : f32
233  // REDUCE-NOT: affine.parallel
234  affine.for %i = 0 to 10 iter_args(%it1 = %cst1) -> (f32) {
235    %0 = arith.addf %it1, %it1 : f32
236    affine.yield %0 : f32
237  }
238  return
239}
240
241// An iter arg is used in the chain of operations defining the value being
242// reduced, this is not a simple reduction and should not be parallelized.
// Here %it2 feeds "test.some_modification", whose result is the value added to
// %it1. Only the parallel-reductions=1 run checks this function.
243// REDUCE-LABEL: @use_in_backward_slice
244func.func @use_in_backward_slice() {
245  %cst1 = arith.constant 0.0 : f32
246  %cst2 = arith.constant 1.0 : f32
247  // REDUCE-NOT: affine.parallel
248  affine.for %i = 0 to 10 iter_args(%it1 = %cst1, %it2 = %cst2) -> (f32, f32) {
249    %0 = "test.some_modification"(%it2) : (f32) -> f32
250    %1 = arith.addf %it1, %0 : f32
251    affine.yield %1, %1 : f32, f32
252  }
253  return
254}
255
// Two-deep nest whose inner loop has a max lower bound over (%lb0, %j) and a
// min upper bound over (%ub0, %ub1). The default run captures the function
// arguments from the signature and expects the inner affine.parallel bounds to
// reference exactly those operands plus the outer parallel induction variable.
// The parallel-reductions=1 run only uses this function as a label.
256// REDUCE-LABEL: @nested_min_max
257// CHECK-LABEL: @nested_min_max
258// CHECK: (%{{.*}}, %[[LB0:.*]]: index, %[[UB0:.*]]: index, %[[UB1:.*]]: index)
259func.func @nested_min_max(%m: memref<?xf32>, %lb0: index,
260                     %ub0: index, %ub1: index) {
261  // CHECK: affine.parallel (%[[J:.*]]) =
262  affine.for %j = 0 to 10 {
263    // CHECK: affine.parallel (%{{.*}}) = (max(%[[LB0]], %[[J]]))
264    // CHECK:                          to (min(%[[UB0]], %[[UB1]]))
265    affine.for %i = max affine_map<(d0, d1) -> (d0, d1)>(%lb0, %j)
266            to min affine_map<(d0, d1) -> (d0, d1)>(%ub0, %ub1) {
267      affine.load %m[%i] : memref<?xf32>
268    }
269  }
270  return
271}
272
273// Test in the presence of locally allocated memrefs.
// Both %m and %ma are allocated inside the loop body, and the only store
// targets %m. The default run expects the loop to become an affine.parallel.
// The header is matched as a label (with the opening parenthesis so it cannot
// also match @local_alloc_cast) to anchor the expectation to this function.
274
275// CHECK-LABEL: func @local_alloc(
276func.func @local_alloc() {
277  %cst = arith.constant 0.0 : f32
278  affine.for %i = 0 to 100 {
279    %m = memref.alloc() : memref<1xf32>
280    %ma = memref.alloca() : memref<1xf32>
281    affine.store %cst, %m[0] : memref<1xf32>
282  }
283  // CHECK: affine.parallel
284  return
285}
286
// A memref allocated inside the outer loop is written by three inner loops,
// the last one through a memref.reinterpret_cast view. The default run expects
// the outer loop and the first and third inner loops to become
// affine.parallel, while the second inner loop (which stores to %m[0] on
// every iteration) stays an affine.for. The header is matched as a label to
// anchor the expectations to this function.
287// CHECK-LABEL: func @local_alloc_cast
288func.func @local_alloc_cast() {
289  %cst = arith.constant 0.0 : f32
290  affine.for %i = 0 to 100 {
291    %m = memref.alloc() : memref<128xf32>
292    affine.for %j = 0 to 128 {
293      affine.store %cst, %m[%j] : memref<128xf32>
294    }
295    affine.for %j = 0 to 128 {
296      affine.store %cst, %m[0] : memref<128xf32>
297    }
298    %r = memref.reinterpret_cast %m to offset: [0], sizes: [8, 16],
299           strides: [16, 1] : memref<128xf32> to memref<8x16xf32>
300    affine.for %j = 0 to 8 {
301      affine.store %cst, %r[%j, %j] : memref<8x16xf32>
302    }
303  }
304  // CHECK: affine.parallel
305  // CHECK:   affine.parallel
306  // CHECK:   }
307  // CHECK:   affine.for
308  // CHECK:   }
309  // CHECK:   affine.parallel
310  // CHECK:   }
311  // CHECK: }
312
313  return
314}
315
// The loop carries a memref value through iter_args and yields it unchanged.
// The default run expects the loop to stay an affine.for.
316// CHECK-LABEL: @iter_arg_memrefs
317func.func @iter_arg_memrefs(%in: memref<10xf32>) {
318  %mi = memref.alloc() : memref<f32>
319  // Loop-carried memrefs are treated as serializing the loop.
320  // CHECK: affine.for
321  %mo = affine.for %i = 0 to 10 iter_args(%m_arg = %mi) -> (memref<f32>) {
322    affine.yield %m_arg : memref<f32>
323  }
324  return
325}
326