// RUN: mlir-opt -test-mesh-resharding-spmdization %s | FileCheck %s

// Resharding tests: every function annotates its argument with a source
// sharding (`mesh.shard`) and then with a target sharding
// (`mesh.shard ... annotate_for_users`). The check lines pin the collective
// communication ops expected to move the data from the source to the target
// sharding, together with the casts between sharded and unsharded types.

// One-dimensional mesh with a static size of 2.
mesh.mesh @mesh_1d(shape = 2)
// One-dimensional mesh whose size is dynamic (`?`).
mesh.mesh @mesh_1d_dynamic(shape = ?)

// Source and target shardings are separate but equal values: no communication
// is expected and the argument is returned as is.
// CHECK-LABEL: func @same_source_and_target_sharding
func.func @same_source_and_target_sharding(
  // CHECK-SAME: %[[ARG:.*]]: tensor<2xf32>
  %arg0: tensor<2xf32>
) -> tensor<2xf32> {
  %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 : tensor<2xf32>
  %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2xf32>
  // CHECK: return %[[ARG]]
  return %1 : tensor<2xf32>
}

// Source and target use the very same sharding value: the argument is
// returned as is.
// CHECK-LABEL: func @identical_source_and_target_sharding
func.func @identical_source_and_target_sharding(
  // CHECK-SAME: %[[ARG:.*]]: tensor<2xf32>
  %arg0: tensor<2xf32>
) -> tensor<2xf32> {
  %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 : tensor<2xf32>
  %1 = mesh.shard %0 to %s0 annotate_for_users : tensor<2xf32>
  // CHECK: return %[[ARG]]
  return %1 : tensor<2xf32>
}

// Replicated -> split along tensor axis 1: expects a `mesh.all_slice` that
// halves the static dimension (14 -> 7) on the mesh of size 2.
// CHECK-LABEL: func @split_replicated_tensor_axis
func.func @split_replicated_tensor_axis(
  // CHECK-SAME: %[[ARG:.*]]: tensor<3x14xf32>
  %arg0: tensor<3x14xf32>
) -> tensor<3x14xf32> {
  // CHECK: %[[ALL_SLICE:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 1
  // CHECK-SAME: tensor<3x14xf32> -> tensor<3x7xf32>
  // CHECK: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[ALL_SLICE]] : tensor<3x7xf32> to tensor<3x14xf32>
  %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 : tensor<3x14xf32>
  %s1 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<3x14xf32>
  // CHECK: return %[[RESULT]] : tensor<3x14xf32>
  return %1 : tensor<3x14xf32>
}

// Replicated -> split along a dynamic tensor axis on the dynamic mesh: the
// `mesh.all_slice` keeps the same (dynamic) tensor type and no cast is needed.
// CHECK-LABEL: func @split_replicated_tensor_axis_dynamic
func.func @split_replicated_tensor_axis_dynamic(
  // CHECK-SAME: %[[ARG:.*]]: tensor<?x3x?xf32>
  %arg0: tensor<?x3x?xf32>
) -> tensor<?x3x?xf32> {
  // CHECK: %[[RESULT:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d_dynamic mesh_axes = [0] slice_axis = 0
  // CHECK-SAME: tensor<?x3x?xf32> -> tensor<?x3x?xf32>
  %s0 = mesh.sharding @mesh_1d_dynamic split_axes = [[], [], []] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 : tensor<?x3x?xf32>
  %s1 = mesh.sharding @mesh_1d_dynamic split_axes = [[0]] : !mesh.sharding
  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<?x3x?xf32>
  // CHECK: return %[[RESULT]] : tensor<?x3x?xf32>
  return %1 : tensor<?x3x?xf32>
}

// Split on tensor axis 0 -> split on tensor axis 1: expects a single
// `mesh.all_to_all` between the 5x14 source shard and the 10x7 target shard.
// CHECK-LABEL: func @move_split_axis
func.func @move_split_axis(
  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
  %arg0: tensor<10x14xf32>
) -> tensor<10x14xf32> {
  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32>
  // CHECK: %[[TARGET_SHARD:.*]] = mesh.all_to_all %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<5x14xf32> -> tensor<10x7xf32>
  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x7xf32> to tensor<10x14xf32>
  %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
  %s1 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
  // CHECK: return %[[RES]] : tensor<10x14xf32>
  return %1 : tensor<10x14xf32>
}

// Same move of the split axis, but on the dynamically sized mesh: the shard
// types carry `?` in the split dimensions and a `tensor.cast` restores the
// static size 10 of the concatenated axis.
// CHECK-LABEL: func @move_split_axis_dynamic_mesh
func.func @move_split_axis_dynamic_mesh(
  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
  %arg0: tensor<10x14xf32>
) -> tensor<10x14xf32> {
  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32>
  // CHECK: %[[ALL_TO_ALL:.*]] = mesh.all_to_all %[[SOURCE_SHARD]] on @mesh_1d_dynamic mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x?xf32>
  // CHECK: %[[TARGET_SHARD:.*]] = tensor.cast %[[ALL_TO_ALL]] : tensor<?x?xf32> to tensor<10x?xf32>
  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x?xf32> to tensor<10x14xf32>
  %s0 = mesh.sharding @mesh_1d_dynamic split_axes = [[0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
  %s1 = mesh.sharding @mesh_1d_dynamic split_axes = [[], [0]] : !mesh.sharding
  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
  // CHECK: return %[[RES]] : tensor<10x14xf32>
  return %1 : tensor<10x14xf32>
}

// Move the split from a dynamic tensor axis to a static one on the static
// mesh: the source shard type equals the argument type, so only the target
// side needs a cast.
// CHECK-LABEL: func @move_split_dynamic_axis
func.func @move_split_dynamic_axis(
  // CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32>
  %arg0: tensor<?x14xf32>
) -> tensor<?x14xf32> {
  // CHECK: %[[TARGET_SHARD:.*]] = mesh.all_to_all %[[ARG]] on @mesh_1d mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x7xf32>
  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<?x7xf32> to tensor<?x14xf32>
  %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 : tensor<?x14xf32>
  %s1 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<?x14xf32>
  // CHECK: return %[[RES]] : tensor<?x14xf32>
  return %1 : tensor<?x14xf32>
}

// Split on tensor axis 0 -> replicated: expects a `mesh.all_gather` from the
// 5x14 shard back to the full 10x14 tensor.
// CHECK-LABEL: func @unshard_static_axis
func.func @unshard_static_axis(
  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
  %arg0: tensor<10x14xf32>
) -> tensor<10x14xf32> {
  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32>
  // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<5x14xf32> -> tensor<10x14xf32>
  %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
  %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
  // CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32>
  return %1 : tensor<10x14xf32>
}

// Same as above, but the gathered axis is the last tensor axis (axis 1).
// CHECK-LABEL: func @unshard_static_last_axis
func.func @unshard_static_last_axis(
  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
  %arg0: tensor<10x14xf32>
) -> tensor<10x14xf32> {
  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<10x7xf32>
  // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] gather_axis = 1 : tensor<10x7xf32> -> tensor<10x14xf32>
  %s0 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
  %s1 = mesh.sharding @mesh_1d split_axes = [[], []] : !mesh.sharding
  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
  // CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32>
  return %1 : tensor<10x14xf32>
}

// Unshard a dynamic tensor axis on the static mesh: shard and result types
// are both tensor<?x14xf32>, so no cast is expected around the all-gather.
// CHECK-LABEL: func @unshard_dynamic_axis
func.func @unshard_dynamic_axis(
  // CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32>
  %arg0: tensor<?x14xf32>
) -> tensor<?x14xf32> {
  // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[ARG]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32>
  %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 : tensor<?x14xf32>
  %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<?x14xf32>
  // CHECK: return %[[ALL_GATHER]] : tensor<?x14xf32>
  return %1 : tensor<?x14xf32>
}

// Unshard a static tensor axis on the dynamic mesh: the all-gather yields a
// dynamic dimension that a `tensor.cast` turns back into the static size 10.
// CHECK-LABEL: func @unshard_static_axis_on_dynamic_mesh_axis
func.func @unshard_static_axis_on_dynamic_mesh_axis(
  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
  %arg0: tensor<10x14xf32>
) -> tensor<10x14xf32> {
  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32>
  // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d_dynamic mesh_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32>
  // CHECK: %[[RES:.*]] = tensor.cast %[[ALL_GATHER]] : tensor<?x14xf32> to tensor<10x14xf32>
  %s0 = mesh.sharding @mesh_1d_dynamic split_axes = [[0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
  %s1 = mesh.sharding @mesh_1d_dynamic split_axes = [[]] : !mesh.sharding
  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
  // CHECK: return %[[RES]] : tensor<10x14xf32>
  return %1 : tensor<10x14xf32>
}

// Partial (sum) along mesh axis 0 -> fully replicated: expects a
// `mesh.all_reduce` whose result is returned directly.
// CHECK-LABEL: func @partial_axis_to_full_replication
func.func @partial_axis_to_full_replication(
  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
  %arg0: tensor<10x14xf32>
) -> tensor<10x14xf32> {
  // CHECK: %[[ALL_REDUCE:.*]] = mesh.all_reduce %[[ARG]] on @mesh_1d mesh_axes = [0] : tensor<10x14xf32> -> tensor<10x14xf32>
  %s0 = mesh.sharding @mesh_1d split_axes = [[]] partial = sum[0] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
  %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
  // Anchor on `return`, like every other test in this file, so the check
  // cannot be satisfied by some other use of the all-reduce result.
  // CHECK: return %[[ALL_REDUCE]] : tensor<10x14xf32>
  return %1 : tensor<10x14xf32>
}