xref: /llvm-project/mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir (revision 5e7dea225be10d3ba0d01e87fb36e80c6764bd83)
1// RUN: mlir-opt %s -split-input-file -async-parallel-for=async-dispatch=true  \
2// RUN: | FileCheck %s --dump-input=always
3
4// CHECK-LABEL: @loop_1d(
5// CHECK-SAME:    %[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index
6func.func @loop_1d(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<?xf32>) {
  // Test input: a one-dimensional scf.parallel loop. %arg0, %arg1 and %arg2 are
  // its lower bound, upper bound and step; %arg3 is the buffer it writes.
  //
  // Expected output, part 1: the trip count is computed as
  // ceildivsi(UB - LB, STEP) and compared for equality with the zero constant.
7  // CHECK:      %[[C0:.*]] = arith.constant 0 : index
8
9  // CHECK:      %[[RANGE:.*]] = arith.subi %[[UB]], %[[LB]]
10  // CHECK:      %[[TRIP_CNT:.*]] = arith.ceildivsi %[[RANGE]], %[[STEP]]
11  // CHECK:      %[[IS_NOOP:.*]] = arith.cmpi eq, %[[TRIP_CNT]], %[[C0]] : index
12
  // Expected output, part 2: a zero trip count takes an empty branch. Otherwise
  // a nested scf.if (its condition is not pinned down by this test) either
  // calls @parallel_compute_fn directly, or creates an async group, calls
  // @async_dispatch_fn and awaits the whole group.
13  // CHECK:      scf.if %[[IS_NOOP]] {
14  // CHECK-NEXT: } else {
15  // CHECK:        scf.if {{.*}} {
16  // CHECK:          call @parallel_compute_fn(%[[C0]]
17  // CHECK:        } else {
18  // CHECK:          %[[GROUP:.*]] = async.create_group
19  // CHECK:          call @async_dispatch_fn
20  // CHECK:          async.await_all %[[GROUP]]
21  // CHECK:        }
22  // CHECK:      }
  // Loop under test: stores the f32 constant 1.0 into %arg3[%i] on each
  // iteration.
23  scf.parallel (%i) = (%arg0) to (%arg1) step (%arg2) {
24    %one = arith.constant 1.0 : f32
25    memref.store %one, %arg3[%i] : memref<?xf32>
26  }
27  return
28}
29
30// CHECK-LABEL: func private @parallel_compute_fn
31// CHECK:       scf.for
32// CHECK:         memref.store
33
34// CHECK-LABEL: func private @async_dispatch_fn
35// CHECK-SAME:  (
36// CHECK-SAME:    %[[GROUP:arg0]]: !async.group,
37// CHECK-SAME:    %[[BLOCK_START:arg1]]: index
38// CHECK-SAME:    %[[BLOCK_END:arg2]]: index
39// CHECK-SAME:  )
40// CHECK:         %[[C1:.*]] = arith.constant 1 : index
41// CHECK:         %[[C2:.*]] = arith.constant 2 : index
42// CHECK:         scf.while (%[[S0:.*]] = %[[BLOCK_START]],
43// CHECK-SAME:               %[[E0:.*]] = %[[BLOCK_END]])
44// While loop `before` block decides if we need to dispatch more tasks.
45// CHECK:         {
46// CHECK:           %[[DIFF0:.*]] = arith.subi %[[E0]], %[[S0]]
47// CHECK:           %[[COND:.*]] = arith.cmpi sgt, %[[DIFF0]], %[[C1]]
48// CHECK:           scf.condition(%[[COND]])
49// While loop `after` block splits the range in half and submits an async task
50// to process the second half by calling the same dispatch function.
51// CHECK:         } do {
52// CHECK:         ^bb0(%[[S1:.*]]: index, %[[E1:.*]]: index):
53// CHECK:           %[[DIFF1:.*]] = arith.subi %[[E1]], %[[S1]]
54// CHECK:           %[[HALF:.*]] = arith.divsi %[[DIFF1]], %[[C2]]
55// CHECK:           %[[MID:.*]] = arith.addi %[[S1]], %[[HALF]]
56// CHECK:           %[[TOKEN:.*]] = async.execute
57// CHECK:             call @async_dispatch_fn
58// CHECK:           async.add_to_group
59// CHECK:           scf.yield %[[S1]], %[[MID]]
60// CHECK:         }
61// After async dispatch, the first block is processed in the caller thread.
62// CHECK:         call @parallel_compute_fn(%[[BLOCK_START]]
63
64// -----
65
66// CHECK-LABEL: @loop_2d
67func.func @loop_2d(%arg0: index, %arg1: index, %arg2: index, // dimension 0: lb, ub, step
68              %arg3: index, %arg4: index, %arg5: index, // dimension 1: lb, ub, step
69              %arg6: memref<?x?xf32>) {
  // Test input: a two-dimensional scf.parallel loop writing into %arg6.
  //
  // Expected output: an async group is created, @async_dispatch_fn is called,
  // and that same group is awaited, in this order.
70  // CHECK: %[[GROUP:.*]] = async.create_group
71  // CHECK: call @async_dispatch_fn
72  // CHECK: async.await_all %[[GROUP]]
  // Loop under test: stores the f32 constant 1.0 into %arg6[%i0, %i1] on each
  // iteration.
73  scf.parallel (%i0, %i1) = (%arg0, %arg3) to (%arg1, %arg4)
74                            step (%arg2, %arg5) {
75    %one = arith.constant 1.0 : f32
76    memref.store %one, %arg6[%i0, %i1] : memref<?x?xf32>
77  }
78  return
79}
80
81// CHECK-LABEL: func private @parallel_compute_fn
82// CHECK:       scf.for
83// CHECK:         scf.for
84// CHECK:           memref.store
85
86// CHECK-LABEL: func private @async_dispatch_fn
87