// RUN: mlir-opt %s -convert-mesh-to-mpi -canonicalize -split-input-file | FileCheck %s

// -----
// CHECK: mesh.mesh @mesh0
mesh.mesh @mesh0(shape = 3x4x5)
func.func @process_multi_index() -> (index, index, index) {
  // CHECK: mpi.comm_rank : !mpi.retval, i32
  // CHECK-DAG: %[[v4:.*]] = arith.remsi
  // CHECK-DAG: %[[v0:.*]] = arith.remsi
  // CHECK-DAG: %[[v1:.*]] = arith.remsi
  %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
  // CHECK: return %[[v1]], %[[v0]], %[[v4]] : index, index, index
  return %0#0, %0#1, %0#2 : index, index, index
}

// CHECK-LABEL: func @process_linear_index
func.func @process_linear_index() -> index {
  // CHECK: %[[RES:.*]], %[[rank:.*]] = mpi.comm_rank : !mpi.retval, i32
  // CHECK: %[[cast:.*]] = arith.index_cast %[[rank]] : i32 to index
  %0 = mesh.process_linear_index on @mesh0 : index
  // CHECK: return %[[cast]] : index
  return %0 : index
}

// CHECK-LABEL: func @neighbors_dim0
func.func @neighbors_dim0(%arg0 : tensor<120x120x120xi8>) -> (index, index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  // CHECK-DAG: [[up:%.*]] = arith.constant 44 : index
  // CHECK-DAG: [[down:%.*]] = arith.constant 4 : index
  %idx:2 = mesh.neighbors_linear_indices on @mesh0[%c1, %c0, %c4] split_axes = [0] : index, index
  // CHECK: return [[down]], [[up]] : index, index
  return %idx#0, %idx#1 : index, index
}

// CHECK-LABEL: func @neighbors_dim1
func.func @neighbors_dim1(%arg0 : tensor<120x120x120xi8>) -> (index, index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  // CHECK-DAG: [[up:%.*]] = arith.constant 29 : index
  // CHECK-DAG: [[down:%.*]] = arith.constant -1 : index
  %idx:2 = mesh.neighbors_linear_indices on @mesh0[%c1, %c0, %c4] split_axes = [1] : index, index
  // CHECK: return [[down]], [[up]] : index, index
  return %idx#0, %idx#1 : index, index
}

// CHECK-LABEL: func @neighbors_dim2
func.func @neighbors_dim2(%arg0 : tensor<120x120x120xi8>) -> (index, index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  // CHECK-DAG: [[up:%.*]] = arith.constant -1 : index
  // CHECK-DAG: [[down:%.*]] = arith.constant 23 : index
  %idx:2 = mesh.neighbors_linear_indices on @mesh0[%c1, %c0, %c4] split_axes = [2] : index, index
  // CHECK: return [[down]], [[up]] : index, index
  return %idx#0, %idx#1 : index, index
}

// -----
// CHECK: mesh.mesh @mesh0
mesh.mesh @mesh0(shape = 3x4x5)
memref.global constant @static_mpi_rank : memref<index> = dense<24>
func.func @process_multi_index() -> (index, index, index) {
  // CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
  // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
  // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
  %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
  // CHECK: return %[[c1]], %[[c0]], %[[c4]] : index, index, index
  return %0#0, %0#1, %0#2 : index, index, index
}

// CHECK-LABEL: func @process_linear_index
func.func @process_linear_index() -> index {
  // CHECK: %[[c24:.*]] = arith.constant 24 : index
  %0 = mesh.process_linear_index on @mesh0 : index
  // CHECK: return %[[c24]] : index
  return %0 : index
}

// -----
mesh.mesh @mesh0(shape = 3x4x5)
// CHECK-LABEL: func @update_halo_1d_first
func.func @update_halo_1d_first(
  // CHECK-SAME: [[arg0:%.*]]: memref<120x120x120xi8>
  %arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> {
  // CHECK: memref.subview [[arg0]][115, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8
  // CHECK: mpi.send(
  // CHECK-SAME: : memref<2x120x120xi8>, i32, i32
  // CHECK: mpi.recv(
  // CHECK-SAME: : memref<2x120x120xi8>, i32, i32
  // CHECK-NEXT: memref.subview [[arg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8
  // CHECK: memref.subview [[arg0]][2, 0, 0] [3, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<3x120x120xi8
  // CHECK: mpi.send(
  // CHECK-SAME: : memref<3x120x120xi8>, i32, i32
  // CHECK: mpi.recv(
  // CHECK-SAME: : memref<3x120x120xi8>, i32, i32
  // CHECK-NEXT: memref.subview [[arg0]][117, 0, 0] [3, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<3x120x120xi8
  %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 3] : memref<120x120x120xi8>
  // CHECK: return [[res:%.*]] : memref<120x120x120xi8>
  return %res : memref<120x120x120xi8>
}

// -----
mesh.mesh @mesh0(shape = 3x4x5)
memref.global constant @static_mpi_rank : memref<index> = dense<24>
// CHECK-LABEL: func @update_halo_3d
func.func @update_halo_3d(
  // CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8>
  %arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> {
  // CHECK: [[vc23_i32:%.*]] = arith.constant 23 : i32
  // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32
  // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
  // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
  // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
  // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
  // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
  // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
  // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
  // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
  // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
  // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
  // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[varg0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
  // CHECK-NEXT: memref.copy [[vsubview_3]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
  // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
  // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
  // CHECK-NEXT: [[vsubview_4:%.*]] = memref.subview [[varg0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
  // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_4]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
  // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
  // CHECK-NEXT: [[valloc_5:%.*]] = memref.alloc() : memref<117x3x120xi8>
  // CHECK-NEXT: mpi.recv([[valloc_5]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
  // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[varg0]][1, 0, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
  // CHECK-NEXT: memref.copy [[valloc_5]], [[vsubview_7]] : memref<117x3x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
  // CHECK-NEXT: memref.dealloc [[valloc_5]] : memref<117x3x120xi8>
  // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<117x4x120xi8>
  // CHECK-NEXT: [[vsubview_10:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>>
  // CHECK-NEXT: memref.copy [[vsubview_10]], [[valloc_8]] : memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>> to memref<117x4x120xi8>
  // CHECK-NEXT: mpi.send([[valloc_8]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
  // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<117x4x120xi8>
  // CHECK-NEXT: [[valloc_11:%.*]] = memref.alloc() : memref<1x120x120xi8>
  // CHECK-NEXT: [[vsubview_12:%.*]] = memref.subview [[varg0]][117, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>>
  // CHECK-NEXT: memref.copy [[vsubview_12]], [[valloc_11]] : memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>> to memref<1x120x120xi8>
  // CHECK-NEXT: mpi.send([[valloc_11]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
  // CHECK-NEXT: memref.dealloc [[valloc_11]] : memref<1x120x120xi8>
  // CHECK-NEXT: [[valloc_13:%.*]] = memref.alloc() : memref<2x120x120xi8>
  // CHECK-NEXT: mpi.recv([[valloc_13]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
  // CHECK-NEXT: [[vsubview_14:%.*]] = memref.subview [[varg0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
  // CHECK-NEXT: memref.copy [[valloc_13]], [[vsubview_14]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
  // CHECK-NEXT: memref.dealloc [[valloc_13]] : memref<2x120x120xi8>
  %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : memref<120x120x120xi8>
  // CHECK: return [[varg0]] : memref<120x120x120xi8>
  return %res : memref<120x120x120xi8>
}

// CHECK-LABEL: func @update_halo_3d_tensor
func.func @update_halo_3d_tensor(
  // CHECK-SAME: [[varg0:%.*]]: tensor<120x120x120xi8>
  %arg0 : tensor<120x120x120xi8>) -> tensor<120x120x120xi8> {
  // CHECK: [[vc23_i32:%.*]] = arith.constant 23 : i32
  // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32
  // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
  // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
  // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
  // CHECK-NEXT: [[v0:%.*]] = bufferization.to_memref [[varg0]] : tensor<120x120x120xi8> to memref<120x120x120xi8>
  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
  // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[v0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
  // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
  // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
  // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
  // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
  // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
  // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
  // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[v0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
  // CHECK-NEXT: memref.copy [[vsubview_3]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
  // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
  // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
  // CHECK-NEXT: [[vsubview_4:%.*]] = memref.subview [[v0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
  // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_4]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
  // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
  // CHECK-NEXT: [[valloc_5:%.*]] = memref.alloc() : memref<117x3x120xi8>
  // CHECK-NEXT: mpi.recv([[valloc_5]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
  // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[v0]][1, 0, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
  // CHECK-NEXT: memref.copy [[valloc_5]], [[vsubview_7]] : memref<117x3x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
  // CHECK-NEXT: memref.dealloc [[valloc_5]] : memref<117x3x120xi8>
  // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<117x4x120xi8>
  // CHECK-NEXT: [[vsubview_10:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>>
  // CHECK-NEXT: memref.copy [[vsubview_10]], [[valloc_8]] : memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>> to memref<117x4x120xi8>
  // CHECK-NEXT: mpi.send([[valloc_8]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
  // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<117x4x120xi8>
  // CHECK-NEXT: [[valloc_11:%.*]] = memref.alloc() : memref<1x120x120xi8>
  // CHECK-NEXT: [[vsubview_12:%.*]] = memref.subview [[v0]][117, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>>
  // CHECK-NEXT: memref.copy [[vsubview_12]], [[valloc_11]] : memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>> to memref<1x120x120xi8>
  // CHECK-NEXT: mpi.send([[valloc_11]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
  // CHECK-NEXT: memref.dealloc [[valloc_11]] : memref<1x120x120xi8>
  // CHECK-NEXT: [[valloc_13:%.*]] = memref.alloc() : memref<2x120x120xi8>
  // CHECK-NEXT: mpi.recv([[valloc_13]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
  // CHECK-NEXT: [[vsubview_14:%.*]] = memref.subview [[v0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
  // CHECK-NEXT: memref.copy [[valloc_13]], [[vsubview_14]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
  // CHECK-NEXT: memref.dealloc [[valloc_13]] : memref<2x120x120xi8>
  // CHECK-NEXT: [[v1:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8>
  %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8>
  // CHECK: return [[v1]] : tensor<120x120x120xi8>
  return %res : tensor<120x120x120xi8>
}