xref: /llvm-project/mlir/test/mlir-opt/async.mlir (revision 52556c8e3561e7f3fa620e9d0c8f60cd4736b10f)
1// Check that mlir-opt marks each function matched by the CHECK lines below with the `presplitcoroutine` attribute.
2//
3// RUN:   mlir-opt %s -pass-pipeline="builtin.module(async-to-async-runtime,func.func(async-runtime-ref-counting,async-runtime-ref-counting-opt),convert-async-to-llvm,func.func(convert-linalg-to-loops,convert-scf-to-cf),finalize-memref-to-llvm,func.func(convert-arith-to-llvm),convert-func-to-llvm,reconcile-unrealized-casts)" \
4// RUN: | FileCheck %s
5
6// CHECK: llvm.func @async_execute_fn{{.*}}attributes{{.*}}presplitcoroutine
7// CHECK: llvm.func @async_execute_fn_0{{.*}}attributes{{.*}}presplitcoroutine
8// CHECK: llvm.func @async_execute_fn_1{{.*}}attributes{{.*}}presplitcoroutine
9
// Test input: fills a 4-element f32 buffer, then writes one element at a time
// from the function body and from nested async.execute regions, calling the
// two external helpers declared at the bottom of the file after each write.
// The body contains three async.execute ops; these are presumably what the
// pipeline outlines into the three `async_execute_fn*` functions the CHECK
// lines look for (which region gets which suffix is not visible here --
// confirm against the pass output).
10func.func @main() {
  // Indices of the four buffer slots.
11  %i0 = arith.constant 0 : index
12  %i1 = arith.constant 1 : index
13  %i2 = arith.constant 2 : index
14  %i3 = arith.constant 3 : index
15
  // %c0 is the fill value; %c1..%c4 are stored to slots 0..3 respectively.
16  %c0 = arith.constant 0.0 : f32
17  %c1 = arith.constant 1.0 : f32
18  %c2 = arith.constant 2.0 : f32
19  %c3 = arith.constant 3.0 : f32
20  %c4 = arith.constant 4.0 : f32
21
  // Allocate the buffer and set every element to %c0.
22  %A = memref.alloc() : memref<4xf32>
23  linalg.fill ins(%c0 : f32) outs(%A : memref<4xf32>)
24
  // @printMemrefF32 takes an unranked memref, so cast the ranked buffer once
  // and reuse %U for every call below.
25  %U = memref.cast %A :  memref<4xf32> to memref<*xf32>
26  call @printMemrefF32(%U): (memref<*xf32>) -> ()
27
  // Slot 0 is written directly in the function body, outside any async region.
28  memref.store %c1, %A[%i0]: memref<4xf32>
29  call @mlirAsyncRuntimePrintCurrentThreadId(): () -> ()
30  call @printMemrefF32(%U): (memref<*xf32>) -> ()
31
  // Outer async region: writes slot 1, spawns two nested regions, awaits the
  // second one, then writes slot 3.
32  %outer = async.execute {
33    memref.store %c2, %A[%i1]: memref<4xf32>
34    func.call @mlirAsyncRuntimePrintCurrentThreadId(): () -> ()
35    func.call @printMemrefF32(%U): (memref<*xf32>) -> ()
36
37    // No op async region to create a token for testing async dependency.
38    %noop = async.execute {
39      func.call @mlirAsyncRuntimePrintCurrentThreadId(): () -> ()
40      async.yield
41    }
42
    // Nested region that lists %noop as a dependency; writes slot 2.
43    %inner = async.execute [%noop] {
44      memref.store %c3, %A[%i2]: memref<4xf32>
45      func.call @mlirAsyncRuntimePrintCurrentThreadId(): () -> ()
46      func.call @printMemrefF32(%U): (memref<*xf32>) -> ()
47
48      async.yield
49    }
    // Wait for %inner before the slot-3 store below.
50    async.await %inner : !async.token
51
52    memref.store %c4, %A[%i3]: memref<4xf32>
53    func.call @mlirAsyncRuntimePrintCurrentThreadId(): () -> ()
54    func.call @printMemrefF32(%U): (memref<*xf32>) -> ()
55
56    async.yield
57  }
  // Wait for the outer region before the final print and the dealloc.
58  async.await %outer : !async.token
59
60  call @mlirAsyncRuntimePrintCurrentThreadId(): () -> ()
61  call @printMemrefF32(%U): (memref<*xf32>) -> ()
62
63  memref.dealloc %A : memref<4xf32>
64
65  return
66}
67
// External declaration (no body): takes no arguments and returns nothing.
// Called from @main's body and from inside each async.execute region.
// NOTE(review): presumably provided by the MLIR async runtime library and
// prints the calling thread's id, as the name suggests -- not verifiable here.
68func.func private @mlirAsyncRuntimePrintCurrentThreadId() -> ()
69
// External declaration (no body): takes an unranked f32 memref and returns
// nothing. @main passes it the unranked cast of its 4-element buffer.
// The `llvm.emit_c_interface` attribute is attached so the func-to-LLVM
// lowering handles this symbol through the C-compatible interface
// (NOTE(review): presumably resolved by an MLIR runner utility library that
// prints the memref contents -- confirm).
70func.func private @printMemrefF32(memref<*xf32>) attributes { llvm.emit_c_interface }
71