// RUN: mlir-opt %s -split-input-file                                          \
// RUN:    -async-parallel-for=async-dispatch=true                             \
// RUN: | FileCheck %s

// RUN: mlir-opt %s -split-input-file                                          \
// RUN:    -async-parallel-for=async-dispatch=false                            \
// RUN:    -canonicalize -inline -symbol-dce                                   \
// RUN: | FileCheck %s

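// Note: both pipelines are checked against the same patterns; presumably the
// -canonicalize -inline -symbol-dce cleanup makes the async-dispatch=false
// output comparable to the async-dispatch=true one.
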
// Check that constants defined outside of the `scf.parallel` body will be
// sunk into the parallel compute function to avoid blowing up the number
// of parallel compute function arguments.

// CHECK-LABEL: func @clone_constant(
func.func @clone_constant(%arg0: memref<?xf32>, %lb: index, %ub: index, %st: index) {
  %one = arith.constant 1.0 : f32

  scf.parallel (%i) = (%lb) to (%ub) step (%st) {
    memref.store %one, %arg0[%i] : memref<?xf32>
  }

  return
}

// CHECK-LABEL: func private @parallel_compute_fn(
// CHECK-SAME:   %[[BLOCK_INDEX:arg[0-9]+]]: index,
// CHECK-SAME:   %[[BLOCK_SIZE:arg[0-9]+]]: index,
// CHECK-SAME:   %[[TRIP_COUNT:arg[0-9]+]]: index,
// CHECK-SAME:   %[[LB:arg[0-9]+]]: index,
// CHECK-SAME:   %[[UB:arg[0-9]+]]: index,
// CHECK-SAME:   %[[STEP:arg[0-9]+]]: index,
// CHECK-SAME:   %[[MEMREF:arg[0-9]+]]: memref<?xf32>
// CHECK-SAME: ) {
// CHECK:        %[[CST:.*]] = arith.constant 1.0{{.*}} : f32
// CHECK:        scf.for
// CHECK:          memref.store %[[CST]], %[[MEMREF]]
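
// For reference, the outlined function is expected to look roughly like the
// sketch below (illustrative only, not matched by FileCheck; value names are
// made up):
//
//   func private @parallel_compute_fn(%blockIndex: index, %blockSize: index,
//                                     %tripCount: index, %lb: index,
//                                     %ub: index, %step: index,
//                                     %memref: memref<?xf32>) {
//     %cst = arith.constant 1.000000e+00 : f32  // constant cloned into body
//     scf.for %i = ... {                        // iterates over one block
//       memref.store %cst, %memref[...] : memref<?xf32>
//     }
//     return
//   }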

// -----

// Check that the constant loop step is sunk into the parallel compute
// function.

// CHECK-LABEL: func @sink_constant_step(
func.func @sink_constant_step(%arg0: memref<?xf32>, %lb: index, %ub: index) {
  %one = arith.constant 1.0 : f32
  %st = arith.constant 123 : index

  scf.parallel (%i) = (%lb) to (%ub) step (%st) {
    memref.store %one, %arg0[%i] : memref<?xf32>
  }

  return
}

// CHECK-LABEL: func private @parallel_compute_fn(
// CHECK-SAME:   %[[BLOCK_INDEX:arg[0-9]+]]: index,
// CHECK-SAME:   %[[BLOCK_SIZE:arg[0-9]+]]: index,
// CHECK-SAME:   %[[TRIP_COUNT:arg[0-9]+]]: index,
// CHECK-SAME:   %[[LB:arg[0-9]+]]: index,
// CHECK-SAME:   %[[UB:arg[0-9]+]]: index,
// CHECK-SAME:   %[[STEP:arg[0-9]+]]: index,
// CHECK-SAME:   %[[MEMREF:arg[0-9]+]]: memref<?xf32>
// CHECK-SAME: ) {
// CHECK:        %[[CSTEP:.*]] = arith.constant 123 : index
// CHECK-NOT:    %[[STEP]]
// CHECK:        scf.for %[[I:arg[0-9]+]]
// CHECK:          %[[TMP:.*]] = arith.muli %[[I]], %[[CSTEP]]
// CHECK:          %[[IDX:.*]] = arith.addi %[[LB]], %[[TMP]]
// CHECK:          memref.store
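
// The checks above pin down the index computation: the sunk constant step is
// used to recover the original induction variable from the block-local one,
// roughly (illustrative sketch, not matched by FileCheck):
//
//   %tmp = arith.muli %i, %c123 : index  // scale by the constant step
//   %idx = arith.addi %lb, %tmp : index  // shift by the original lower bound
//   memref.store %one, %memref[%idx] : memref<?xf32>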

// -----

// Smoke test that the pass doesn't crash when the func dialect is not used.

// CHECK-LABEL: llvm.func @without_func_dialect()
llvm.func @without_func_dialect() {
  %cst = arith.constant 0.0 : f32

  %c0 = arith.constant 0 : index
  %c22 = arith.constant 22 : index
  %c1 = arith.constant 1 : index
  %54 = memref.alloc() : memref<22xf32>
  %alloc_4 = memref.alloc() : memref<22xf32>
  scf.parallel (%arg0) = (%c0) to (%c22) step (%c1) {
    memref.store %cst, %alloc_4[%arg0] : memref<22xf32>
  }
  llvm.return
}

// -----

// Check that for a statically known inner loop bound the block size is
// aligned and the inner loop uses statically known loop trip counts.

// CHECK-LABEL: func @sink_constant_step(
func.func @sink_constant_step(%arg0: memref<?x10xf32>, %lb: index, %ub: index) {
  %one = arith.constant 1.0 : f32

  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c10 = arith.constant 10 : index

  scf.parallel (%i, %j) = (%lb, %c0) to (%ub, %c10) step (%c1, %c1) {
    memref.store %one, %arg0[%i, %j] : memref<?x10xf32>
  }

  return
}
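
// Two compute functions are expected below: an aligned variant whose inner
// loop runs over the full static trip count without clamping, and a generic
// variant that clamps the inner bound with `arith.select`. Presumably the
// aligned variant is dispatched when a block covers only complete inner loop
// iterations; the CHECK lines only pin down the shape of each function.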

// CHECK-LABEL: func private @parallel_compute_fn_with_aligned_loops(
// CHECK-SAME:   %[[BLOCK_INDEX:arg[0-9]+]]: index,
// CHECK-SAME:   %[[BLOCK_SIZE:arg[0-9]+]]: index,
// CHECK-SAME:   %[[TRIP_COUNT0:arg[0-9]+]]: index,
// CHECK-SAME:   %[[TRIP_COUNT1:arg[0-9]+]]: index,
// CHECK-SAME:   %[[LB0:arg[0-9]+]]: index,
// CHECK-SAME:   %[[LB1:arg[0-9]+]]: index,
// CHECK-SAME:   %[[UB0:arg[0-9]+]]: index,
// CHECK-SAME:   %[[UB1:arg[0-9]+]]: index,
// CHECK-SAME:   %[[STEP0:arg[0-9]+]]: index,
// CHECK-SAME:   %[[STEP1:arg[0-9]+]]: index,
// CHECK-SAME:   %[[MEMREF:arg[0-9]+]]: memref<?x10xf32>
// CHECK-SAME: ) {
// CHECK:        %[[C0:.*]] = arith.constant 0 : index
// CHECK:        %[[C1:.*]] = arith.constant 1 : index
// CHECK:        %[[C10:.*]] = arith.constant 10 : index
// CHECK:        scf.for %[[I:arg[0-9]+]]
// CHECK-NOT:      arith.select
// CHECK:          scf.for %[[J:arg[0-9]+]] = %[[C0]] to %[[C10]] step %[[C1]]
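
// In the aligned variant the inner loop is expected to use the constant
// bounds directly, roughly (illustrative sketch, not matched by FileCheck):
//
//   scf.for %j = %c0 to %c10 step %c1 {  // static trip count, no clamping
//     memref.store %one, %memref[%i, %j] : memref<?x10xf32>
//   }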

// CHECK-LABEL: func private @parallel_compute_fn(
// CHECK-SAME:   %[[BLOCK_INDEX:arg[0-9]+]]: index,
// CHECK-SAME:   %[[BLOCK_SIZE:arg[0-9]+]]: index,
// CHECK-SAME:   %[[TRIP_COUNT0:arg[0-9]+]]: index,
// CHECK-SAME:   %[[TRIP_COUNT1:arg[0-9]+]]: index,
// CHECK-SAME:   %[[LB0:arg[0-9]+]]: index,
// CHECK-SAME:   %[[LB1:arg[0-9]+]]: index,
// CHECK-SAME:   %[[UB0:arg[0-9]+]]: index,
// CHECK-SAME:   %[[UB1:arg[0-9]+]]: index,
// CHECK-SAME:   %[[STEP0:arg[0-9]+]]: index,
// CHECK-SAME:   %[[STEP1:arg[0-9]+]]: index,
// CHECK-SAME:   %[[MEMREF:arg[0-9]+]]: memref<?x10xf32>
// CHECK-SAME: ) {
// CHECK:        scf.for %[[I:arg[0-9]+]]
// CHECK:          arith.select
// CHECK:          scf.for %[[J:arg[0-9]+]]
// CHECK:          memref.store
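// In the generic variant the inner loop upper bound is clamped before the
// loop, roughly (illustrative sketch, not matched by FileCheck; names are
// made up):
//
//   %end = arith.select %isPartialBlock, %blockEnd1, %c10 : index
//   scf.for %j = %lb1 to %end step %c1 {
//     memref.store %one, %memref[%i, %j] : memref<?x10xf32>
//   }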