// RUN: mlir-opt \
// RUN:   --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
// RUN:   --split-input-file \
// RUN:   %s | FileCheck %s

// Spmdization of linalg ops on a 1-D mesh: sharded operands are rewritten to
// their per-process local shapes, and collectives are inserted where
// resharding or a reduction across processes is required.

// CHECK: #[[$MAP_IDENTITY_1D:.*]] = affine_map<(d0) -> (d0)>
#map_identity_1d = affine_map<(d0) -> (d0)>

mesh.mesh @mesh_1d(shape = 2)

// Elementwise op on a 1-D tensor split across a 2-process mesh: each process
// computes the multiplication directly on its 1-element local shard.
// CHECK-LABEL: func @elementwise_static_1d_mesh_static_1d_tensor
func.func @elementwise_static_1d_mesh_static_1d_tensor(
    // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<1xi8>,
    %in1: tensor<2xi8>,
    // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<1xi8>,
    %in2: tensor<2xi8>,
    // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<1xi8>
    %dps_out: tensor<2xi8>
    // CHECK-SAME: -> tensor<1xi8> {
) -> tensor<2xi8> {
  %sharding = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %in1_sharded1 = mesh.shard %in1 to %sharding : tensor<2xi8>
  %in1_sharded2 = mesh.shard %in1_sharded1 to %sharding annotate_for_users : tensor<2xi8>
  %in2_sharded1 = mesh.shard %in2 to %sharding : tensor<2xi8>
  %in2_sharded2 = mesh.shard %in2_sharded1 to %sharding annotate_for_users : tensor<2xi8>
  %dps_out_sharded1 = mesh.shard %dps_out to %sharding : tensor<2xi8>
  %dps_out_sharded2 = mesh.shard %dps_out_sharded1 to %sharding annotate_for_users : tensor<2xi8>
  // CHECK: %[[RES:.*]] = linalg.generic {
  // CHECK-SAME: indexing_maps = [#[[$MAP_IDENTITY_1D]], #[[$MAP_IDENTITY_1D]], #[[$MAP_IDENTITY_1D]]],
  // CHECK-SAME: iterator_types = ["parallel"]}
  // CHECK-SAME: ins(%[[IN1]], %[[IN2]] : tensor<1xi8>, tensor<1xi8>)
  // CHECK-SAME: outs(%[[DPS_OUT]] : tensor<1xi8>) {
  %res = linalg.generic {
      indexing_maps = [#map_identity_1d, #map_identity_1d, #map_identity_1d],
      iterator_types = ["parallel"]
    } ins(%in1_sharded2, %in2_sharded2 : tensor<2xi8>, tensor<2xi8>)
      outs(%dps_out_sharded2 : tensor<2xi8>) {
  ^bb0(%in1_scalar: i8, %in2_scalar: i8, %out: i8):
    %res_scalar = arith.muli %in1_scalar, %in2_scalar : i8
    linalg.yield %res_scalar : i8
  } -> tensor<2xi8>
  %res_sharded1 = mesh.shard %res to %sharding : tensor<2xi8>
  %res_sharded2 = mesh.shard %res_sharded1 to %sharding annotate_for_users : tensor<2xi8>
  // CHECK: return %[[RES]] : tensor<1xi8>
  return %res_sharded2 : tensor<2xi8>
}

// -----

mesh.mesh @mesh_1d(shape = 4)

// Matmul with the parallel (row) iteration dimension sharded across 4
// processes: each process computes a 1x8 slice of the result from its 1x3
// slice of the left-hand side and a replicated right-hand side.
// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_parallel_iterator_sharding
func.func @matmul_1d_mesh_static_tensors_parallel_iterator_sharding(
    // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<1x3xi8>,
    %in1: tensor<4x3xi8>,
    // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<3x8xi8>,
    %in2: tensor<3x8xi8>,
    // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<1x8xi8>
    %dps_out: tensor<4x8xi8>
    // CHECK-SAME: -> tensor<1x8xi8> {
) -> tensor<4x8xi8> {
  %sharding = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %in1_shared1 = mesh.shard %in1 to %sharding : tensor<4x3xi8>
  %in1_shared2 = mesh.shard %in1_shared1 to %sharding annotate_for_users : tensor<4x3xi8>
  %sharding2 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %in2_shared1 = mesh.shard %in2 to %sharding2 : tensor<3x8xi8>
  %in2_shared2 = mesh.shard %in2_shared1 to %sharding2 annotate_for_users : tensor<3x8xi8>
  %dps_out_shared1 = mesh.shard %dps_out to %sharding : tensor<4x8xi8>
  %dps_out_shared2 = mesh.shard %dps_out_shared1 to %sharding annotate_for_users : tensor<4x8xi8>
  // CHECK: %[[RES:.*]] = linalg.matmul
  // CHECK-SAME: ins(%[[IN1]], %[[IN2]] : tensor<1x3xi8>, tensor<3x8xi8>)
  // CHECK-SAME: outs(%[[DPS_OUT]] : tensor<1x8xi8>)
  // CHECK-SAME: -> tensor<1x8xi8>
  %res = linalg.matmul ins(%in1_shared2, %in2_shared2 : tensor<4x3xi8>, tensor<3x8xi8>)
      outs(%dps_out_shared2 : tensor<4x8xi8>) -> tensor<4x8xi8>
  %res_shared1 = mesh.shard %res to %sharding : tensor<4x8xi8>
  %res_shared2 = mesh.shard %res_shared1 to %sharding annotate_for_users : tensor<4x8xi8>
  // CHECK: return %[[RES]] : tensor<1x8xi8>
  return %res_shared2 : tensor<4x8xi8>
}

// -----

mesh.mesh @mesh_1d(shape = 3)

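// Matmul with the reduction (contraction) dimension split across the 3
// processes: each process computes a partial product over its 4x2 / 2x8
// shards. Only process 0 may seed the computation with the original DPS init
// operand; the others start from a zero-filled tensor so that the concluding
// mesh.all_reduce sums the partial results without double-counting the init.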
// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding
func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding(
    // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<4x2xi8>,
    %in1: tensor<4x6xi8>,
    // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<2x8xi8>,
    %in2: tensor<6x8xi8>,
    // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<4x8xi8>
    %dps_out: tensor<4x8xi8>
    // CHECK-SAME: -> tensor<4x8xi8> {
) -> tensor<4x8xi8> {
  %sharding = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
  %in1_shared1 = mesh.shard %in1 to %sharding : tensor<4x6xi8>
  %in1_shared2 = mesh.shard %in1_shared1 to %sharding annotate_for_users : tensor<4x6xi8>
  %sharding2 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %in2_shared1 = mesh.shard %in2 to %sharding2 : tensor<6x8xi8>
  %in2_shared2 = mesh.shard %in2_shared1 to %sharding2 annotate_for_users : tensor<6x8xi8>
  %sharding3 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %dps_out_shared1 = mesh.shard %dps_out to %sharding3 : tensor<4x8xi8>
  %dps_out_shared2 = mesh.shard %dps_out_shared1 to %sharding3 annotate_for_users : tensor<4x8xi8>
  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
  // CHECK-DAG: %[[C0_I8:.*]] = arith.constant 0 : i8
  // CHECK-DAG: %[[PROCESS_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
  // CHECK-DAG: %[[MESH_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index
  // CHECK: %[[DPS_INIT_OPERAND_CONDITION:.*]] = arith.cmpi eq, %[[PROCESS_IDX]], %[[C0]] : index
  // CHECK: %[[DPS_INIT_OPERAND:.*]] = scf.if %[[DPS_INIT_OPERAND_CONDITION]] -> (tensor<4x8xi8>) {
  // CHECK: scf.yield %[[DPS_OUT]] : tensor<4x8xi8>
  // CHECK: } else {
  // CHECK-DAG: %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<4x8xi8>
  // CHECK: %[[NEUTRAL_ELEMENT_FILLED_TENSOR:.*]] = linalg.fill ins(%[[C0_I8]] : i8)
  // CHECK-SAME: outs(%[[EMPTY_TENSOR]] : tensor<4x8xi8>) -> tensor<4x8xi8>
  // CHECK: scf.yield %[[NEUTRAL_ELEMENT_FILLED_TENSOR]] : tensor<4x8xi8>
  // CHECK: }
  // CHECK: %[[SHARDED_MATMUL:.*]] = linalg.matmul ins(%[[IN1]], %[[IN2]] : tensor<4x2xi8>, tensor<2x8xi8>)
  // CHECK-SAME: outs(%[[DPS_INIT_OPERAND]] : tensor<4x8xi8>) -> tensor<4x8xi8>
  // CHECK: %[[ALL_REDUCED:.*]] = mesh.all_reduce %[[SHARDED_MATMUL]] on @mesh_1d mesh_axes = [0] : tensor<4x8xi8> -> tensor<4x8xi8>
  %res = linalg.matmul ins(%in1_shared2, %in2_shared2 : tensor<4x6xi8>, tensor<6x8xi8>)
      outs(%dps_out_shared2 : tensor<4x8xi8>) -> tensor<4x8xi8>
  %res_shared1 = mesh.shard %res to %sharding3 : tensor<4x8xi8>
  %res_shared2 = mesh.shard %res_shared1 to %sharding3 annotate_for_users : tensor<4x8xi8>
  // CHECK: return %[[ALL_REDUCED]] : tensor<4x8xi8>
  return %res_shared2 : tensor<4x8xi8>
}

// -----

mesh.mesh @mesh_1d(shape = 3)

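// Same reduction-dimension split as above, but the result is annotated with a
// partial = sum[0] sharding, so no mesh.all_reduce is inserted and each
// process returns its local partial product.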
// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding_with_partial_result
func.func
    @matmul_1d_mesh_static_tensors_reduction_iterator_sharding_with_partial_result(
    // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<4x2xi8>,
    %in1: tensor<4x6xi8>,
    // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<2x8xi8>,
    %in2: tensor<6x8xi8>,
    // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<4x8xi8>
    %dps_out: tensor<4x8xi8>
    // CHECK-SAME: -> tensor<4x8xi8> {
) -> tensor<4x8xi8> {
  %sharding = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
  %in1_shared1 = mesh.shard %in1 to %sharding : tensor<4x6xi8>
  %in1_shared2 = mesh.shard %in1_shared1 to %sharding annotate_for_users : tensor<4x6xi8>
  %sharding2 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %in2_shared1 = mesh.shard %in2 to %sharding2 : tensor<6x8xi8>
  %in2_shared2 = mesh.shard %in2_shared1 to %sharding2 annotate_for_users : tensor<6x8xi8>
  %sharding3 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %dps_out_shared1 = mesh.shard %dps_out to %sharding3 : tensor<4x8xi8>
  %dps_out_shared2 = mesh.shard %dps_out_shared1 to %sharding3 annotate_for_users : tensor<4x8xi8>
  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
  // CHECK-DAG: %[[C0_I8:.*]] = arith.constant 0 : i8
  // CHECK-DAG: %[[PROCESS_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
  // CHECK-DAG: %[[MESH_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index
  // CHECK: %[[DPS_INIT_OPERAND_CONDITION:.*]] = arith.cmpi eq, %[[PROCESS_IDX]], %[[C0]] : index
  // CHECK: %[[DPS_INIT_OPERAND:.*]] = scf.if %[[DPS_INIT_OPERAND_CONDITION]] -> (tensor<4x8xi8>) {
  // CHECK: scf.yield %[[DPS_OUT]] : tensor<4x8xi8>
  // CHECK: } else {
  // CHECK-DAG: %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<4x8xi8>
  // CHECK: %[[NEUTRAL_ELEMENT_FILLED_TENSOR:.*]] = linalg.fill ins(%[[C0_I8]] : i8)
  // CHECK-SAME: outs(%[[EMPTY_TENSOR]] : tensor<4x8xi8>) -> tensor<4x8xi8>
  // CHECK: scf.yield %[[NEUTRAL_ELEMENT_FILLED_TENSOR]] : tensor<4x8xi8>
  // CHECK: }
  // CHECK: %[[SHARDED_MATMUL:.*]] = linalg.matmul ins(%[[IN1]], %[[IN2]] : tensor<4x2xi8>, tensor<2x8xi8>)
  // CHECK-SAME: outs(%[[DPS_INIT_OPERAND]] : tensor<4x8xi8>) -> tensor<4x8xi8>
  %res = linalg.matmul ins(%in1_shared2, %in2_shared2 : tensor<4x6xi8>, tensor<6x8xi8>)
      outs(%dps_out_shared2 : tensor<4x8xi8>) -> tensor<4x8xi8>
  %sharding4 = mesh.sharding @mesh_1d split_axes = [[]] partial = sum[0] : !mesh.sharding
  %res_shared1 = mesh.shard %res to %sharding4 : tensor<4x8xi8>
  %res_shared2 = mesh.shard %res_shared1 to %sharding4 annotate_for_users : tensor<4x8xi8>
  // CHECK: return %[[SHARDED_MATMUL]] : tensor<4x8xi8>
  return %res_shared2 : tensor<4x8xi8>
}

// -----

mesh.mesh @mesh_1d(shape = 4)

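// Fully replicated inputs whose users request a split along the last
// (non-reduction) axis: the right-hand side and the DPS init operand are
// resharded with mesh.all_slice along axis 1, the matmul runs on the 6x2 and
// 4x2 shards, and the replicated result is reassembled with mesh.all_gather.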
// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis
func.func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis(
    // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<4x6xi8>,
    %in1: tensor<4x6xi8>,
    // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<6x8xi8>,
    %in2: tensor<6x8xi8>,
    // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<4x8xi8>
    %dps_out: tensor<4x8xi8>
    // CHECK-SAME: -> tensor<4x8xi8> {
) -> tensor<4x8xi8> {
  %sharding1 = mesh.sharding @mesh_1d split_axes = [[], []] : !mesh.sharding
  %in1_replicated1 = mesh.shard %in1 to %sharding1 : tensor<4x6xi8>
  %in1_replicated2 = mesh.shard %in1_replicated1 to %sharding1 annotate_for_users : tensor<4x6xi8>
  // CHECK: %[[ALL_SLICE1:.*]] = mesh.all_slice %[[IN2]] on @mesh_1d mesh_axes = [0] slice_axis = 1
  %in2_replicated = mesh.shard %in2 to %sharding1 : tensor<6x8xi8>
  %sharding2 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
  %in2_sharded = mesh.shard %in2_replicated to %sharding2 annotate_for_users : tensor<6x8xi8>
  // CHECK: %[[ALL_SLICE2:.*]] = mesh.all_slice %[[DPS_OUT]] on @mesh_1d mesh_axes = [0] slice_axis = 1
  %dps_out_replicated = mesh.shard %dps_out to %sharding1 : tensor<4x8xi8>
  %dps_out_sharded = mesh.shard %dps_out_replicated to %sharding2 annotate_for_users : tensor<4x8xi8>
  // CHECK: %[[MATMUL_RES:.*]] = linalg.matmul
  // CHECK-SAME: ins(%[[IN1]], %[[ALL_SLICE1]] : tensor<4x6xi8>, tensor<6x2xi8>)
  // CHECK-SAME: outs(%[[ALL_SLICE2]] : tensor<4x2xi8>)
  // CHECK-SAME: -> tensor<4x2xi8>
  %res = linalg.matmul ins(%in1_replicated2, %in2_sharded : tensor<4x6xi8>, tensor<6x8xi8>)
      outs(%dps_out_sharded : tensor<4x8xi8>) -> tensor<4x8xi8>
  // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[MATMUL_RES]] on @mesh_1d mesh_axes = [0] gather_axis = 1 : tensor<4x2xi8> -> tensor<4x8xi8>
  %res_sharded = mesh.shard %res to %sharding2 : tensor<4x8xi8>
  %res_replicated = mesh.shard %res_sharded to %sharding1 annotate_for_users : tensor<4x8xi8>
  // CHECK: return %[[ALL_GATHER]] : tensor<4x8xi8>
  return %res_replicated : tensor<4x8xi8>
}