// xref: /llvm-project/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir (revision b7749efb749541716b6785c48fc8d6c2a4453ffb)
// RUN: mlir-opt %s -convert-mesh-to-mpi -canonicalize -split-input-file | FileCheck %s

// -----
// A 3x4x5 process mesh; the mesh declaration itself is kept by the pass.
// CHECK: mesh.mesh @mesh0
mesh.mesh @mesh0(shape = 3x4x5)
// mesh.process_multi_index lowers to an mpi.comm_rank followed by a
// remainder/division decomposition of the linear rank into mesh coordinates.
func.func @process_multi_index() -> (index, index, index) {
  // CHECK: mpi.comm_rank : !mpi.retval, i32
  // CHECK-DAG: %[[v4:.*]] = arith.remsi
  // CHECK-DAG: %[[v0:.*]] = arith.remsi
  // CHECK-DAG: %[[v1:.*]] = arith.remsi
  %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
  // CHECK: return %[[v1]], %[[v0]], %[[v4]] : index, index, index
  return %0#0, %0#1, %0#2 : index, index, index
}

// mesh.process_linear_index lowers to mpi.comm_rank + index_cast.
// CHECK-LABEL: func @process_linear_index
func.func @process_linear_index() -> index {
  // CHECK: %[[RES:.*]], %[[rank:.*]] = mpi.comm_rank : !mpi.retval, i32
  // CHECK: %[[cast:.*]] = arith.index_cast %[[rank]] : i32 to index
  %0 = mesh.process_linear_index on @mesh0 : index
  // CHECK: return %[[cast]] : index
  return %0 : index
}

// Neighbor linear indices for process (1, 0, 4), split along mesh axis 0:
// both neighbors exist, so the lowering folds to the constants below.
// CHECK-LABEL: func @neighbors_dim0
func.func @neighbors_dim0(%arg0 : tensor<120x120x120xi8>) -> (index, index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  // CHECK-DAG: [[up:%.*]] = arith.constant 44 : index
  // CHECK-DAG: [[down:%.*]] = arith.constant 4 : index
  %idx:2 = mesh.neighbors_linear_indices on @mesh0[%c1, %c0, %c4] split_axes = [0] : index, index
  // CHECK: return [[down]], [[up]] : index, index
  return %idx#0, %idx#1 : index, index
}

// Split along mesh axis 1: the "down" neighbor is out of bounds (-1).
// CHECK-LABEL: func @neighbors_dim1
func.func @neighbors_dim1(%arg0 : tensor<120x120x120xi8>) -> (index, index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  // CHECK-DAG: [[up:%.*]] = arith.constant 29 : index
  // CHECK-DAG: [[down:%.*]] = arith.constant -1 : index
  %idx:2 = mesh.neighbors_linear_indices on @mesh0[%c1, %c0, %c4] split_axes = [1] : index, index
  // CHECK: return [[down]], [[up]] : index, index
  return %idx#0, %idx#1 : index, index
}

// Split along mesh axis 2: the "up" neighbor is out of bounds (-1).
// CHECK-LABEL: func @neighbors_dim2
func.func @neighbors_dim2(%arg0 : tensor<120x120x120xi8>) -> (index, index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  // CHECK-DAG: [[up:%.*]] = arith.constant -1 : index
  // CHECK-DAG: [[down:%.*]] = arith.constant 23 : index
  %idx:2 = mesh.neighbors_linear_indices on @mesh0[%c1, %c0, %c4] split_axes = [2] : index, index
  // CHECK: return [[down]], [[up]] : index, index
  return %idx#0, %idx#1 : index, index
}

// -----
// Same mesh, but with a statically-known MPI rank (24) provided via a global,
// letting the lowering fold rank queries to constants.
// CHECK: mesh.mesh @mesh0
mesh.mesh @mesh0(shape = 3x4x5)
memref.global constant @static_mpi_rank : memref<index> = dense<24>
// With a static rank of 24 in a 3x4x5 mesh, the multi-index folds to (1, 0, 4).
func.func @process_multi_index() -> (index, index, index) {
  // CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
  // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
  // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
  %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
  // CHECK: return %[[c1]], %[[c0]], %[[c4]] : index, index, index
  return %0#0, %0#1, %0#2 : index, index, index
}

// With a static rank, the linear index folds directly to the constant 24.
// CHECK-LABEL: func @process_linear_index
func.func @process_linear_index() -> index {
  // CHECK: %[[c24:.*]] = arith.constant 24 : index
  %0 = mesh.process_linear_index on @mesh0 : index
  // CHECK: return %[[c24]] : index
  return %0 : index
}

// -----
mesh.mesh @mesh0(shape = 3x4x5)
// Halo exchange on the first (outermost) memref dimension only:
// send/recv pairs over subviews at both ends, sized by halo_sizes = [2, 3].
// CHECK-LABEL: func @update_halo_1d_first
func.func @update_halo_1d_first(
  // CHECK-SAME: [[arg0:%.*]]: memref<120x120x120xi8>
  %arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> {
  // CHECK: memref.subview [[arg0]][115, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8
  // CHECK: mpi.send(
  // CHECK-SAME: : memref<2x120x120xi8>, i32, i32
  // CHECK: mpi.recv(
  // CHECK-SAME: : memref<2x120x120xi8>, i32, i32
  // CHECK-NEXT: memref.subview [[arg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8
  // CHECK: memref.subview [[arg0]][2, 0, 0] [3, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<3x120x120xi8
  // CHECK: mpi.send(
  // CHECK-SAME: : memref<3x120x120xi8>, i32, i32
  // CHECK: mpi.recv(
  // CHECK-SAME: : memref<3x120x120xi8>, i32, i32
  // CHECK-NEXT: memref.subview [[arg0]][117, 0, 0] [3, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<3x120x120xi8
  %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 3] : memref<120x120x120xi8>
  // CHECK: return [[res:%.*]] : memref<120x120x120xi8>
  return %res : memref<120x120x120xi8>
}

// -----
mesh.mesh @mesh0(shape = 3x4x5)
memref.global constant @static_mpi_rank : memref<index> = dense<24>
// Full 3D halo exchange (static rank 24): one alloc/copy/send/recv/dealloc
// sequence per mesh direction, with neighbor ranks folded to constants.
// CHECK-LABEL: func @update_halo_3d
func.func @update_halo_3d(
  // CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8>
  %arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> {
  // CHECK: [[vc23_i32:%.*]] = arith.constant 23 : i32
  // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32
  // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
  // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
  // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
  // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
  // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
  // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
  // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
  // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
  // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
  // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
  // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[varg0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
  // CHECK-NEXT: memref.copy [[vsubview_3]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
  // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
  // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
  // CHECK-NEXT: [[vsubview_4:%.*]] = memref.subview [[varg0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
  // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_4]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
  // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
  // CHECK-NEXT: [[valloc_5:%.*]] = memref.alloc() : memref<117x3x120xi8>
  // CHECK-NEXT: mpi.recv([[valloc_5]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
  // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[varg0]][1, 0, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
  // CHECK-NEXT: memref.copy [[valloc_5]], [[vsubview_7]] : memref<117x3x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
  // CHECK-NEXT: memref.dealloc [[valloc_5]] : memref<117x3x120xi8>
  // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<117x4x120xi8>
  // CHECK-NEXT: [[vsubview_10:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>>
  // CHECK-NEXT: memref.copy [[vsubview_10]], [[valloc_8]] : memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>> to memref<117x4x120xi8>
  // CHECK-NEXT: mpi.send([[valloc_8]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
  // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<117x4x120xi8>
  // CHECK-NEXT: [[valloc_11:%.*]] = memref.alloc() : memref<1x120x120xi8>
  // CHECK-NEXT: [[vsubview_12:%.*]] = memref.subview [[varg0]][117, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>>
  // CHECK-NEXT: memref.copy [[vsubview_12]], [[valloc_11]] : memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>> to memref<1x120x120xi8>
  // CHECK-NEXT: mpi.send([[valloc_11]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
  // CHECK-NEXT: memref.dealloc [[valloc_11]] : memref<1x120x120xi8>
  // CHECK-NEXT: [[valloc_13:%.*]] = memref.alloc() : memref<2x120x120xi8>
  // CHECK-NEXT: mpi.recv([[valloc_13]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
  // CHECK-NEXT: [[vsubview_14:%.*]] = memref.subview [[varg0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
  // CHECK-NEXT: memref.copy [[valloc_13]], [[vsubview_14]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
  // CHECK-NEXT: memref.dealloc [[valloc_13]] : memref<2x120x120xi8>
  %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : memref<120x120x120xi8>
  // CHECK: return [[varg0]] : memref<120x120x120xi8>
  return %res : memref<120x120x120xi8>
}

// Same 3D halo exchange on a tensor operand: the lowering bufferizes via
// bufferization.to_memref, exchanges halos in place, then returns to_tensor.
// CHECK-LABEL: func @update_halo_3d_tensor
func.func @update_halo_3d_tensor(
  // CHECK-SAME: [[varg0:%.*]]: tensor<120x120x120xi8>
  %arg0 : tensor<120x120x120xi8>) -> tensor<120x120x120xi8> {
  // CHECK: [[vc23_i32:%.*]] = arith.constant 23 : i32
  // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32
  // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
  // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
  // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
  // CHECK-NEXT: [[v0:%.*]] = bufferization.to_memref [[varg0]] : tensor<120x120x120xi8> to memref<120x120x120xi8>
  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
  // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[v0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
  // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
  // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
  // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
  // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
  // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
  // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
  // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[v0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
  // CHECK-NEXT: memref.copy [[vsubview_3]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
  // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
  // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
  // CHECK-NEXT: [[vsubview_4:%.*]] = memref.subview [[v0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
  // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_4]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
  // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
  // CHECK-NEXT: [[valloc_5:%.*]] = memref.alloc() : memref<117x3x120xi8>
  // CHECK-NEXT: mpi.recv([[valloc_5]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
  // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[v0]][1, 0, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
  // CHECK-NEXT: memref.copy [[valloc_5]], [[vsubview_7]] : memref<117x3x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
  // CHECK-NEXT: memref.dealloc [[valloc_5]] : memref<117x3x120xi8>
  // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<117x4x120xi8>
  // CHECK-NEXT: [[vsubview_10:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>>
  // CHECK-NEXT: memref.copy [[vsubview_10]], [[valloc_8]] : memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>> to memref<117x4x120xi8>
  // CHECK-NEXT: mpi.send([[valloc_8]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
  // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<117x4x120xi8>
  // CHECK-NEXT: [[valloc_11:%.*]] = memref.alloc() : memref<1x120x120xi8>
  // CHECK-NEXT: [[vsubview_12:%.*]] = memref.subview [[v0]][117, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>>
  // CHECK-NEXT: memref.copy [[vsubview_12]], [[valloc_11]] : memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>> to memref<1x120x120xi8>
  // CHECK-NEXT: mpi.send([[valloc_11]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
  // CHECK-NEXT: memref.dealloc [[valloc_11]] : memref<1x120x120xi8>
  // CHECK-NEXT: [[valloc_13:%.*]] = memref.alloc() : memref<2x120x120xi8>
  // CHECK-NEXT: mpi.recv([[valloc_13]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
  // CHECK-NEXT: [[vsubview_14:%.*]] = memref.subview [[v0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
  // CHECK-NEXT: memref.copy [[valloc_13]], [[vsubview_14]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
  // CHECK-NEXT: memref.dealloc [[valloc_13]] : memref<2x120x120xi8>
  // CHECK-NEXT: [[v1:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8>
  %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8>
  // CHECK: return [[v1]] : tensor<120x120x120xi8>
  return %res : tensor<120x120x120xi8>
}
