// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s

mesh.mesh @mesh0(shape = 4x2)
mesh.mesh @mesh1(shape = 4)

// Checks that `all_reduce(x) + all_reduce(y)` gets transformed to
// `all_reduce(x + y)`.
// CHECK-LABEL: func.func @all_reduce_arith_addf_endomorphism
func.func @all_reduce_arith_addf_endomorphism(
    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg0: tensor<5xf32>,
    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg1: tensor<5xf32>) -> tensor<5xf32> {
  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ARG0]], %[[ARG1]]
  %2 = arith.addf %0, %1 : tensor<5xf32>
  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]]
  // CHECK: return %[[ALL_REDUCE_RES]]
  return %2 : tensor<5xf32>
}

// The rewrite is still expected when the result of the addition has several
// uses: a single all-reduce of the sum feeds both returned values.
// CHECK-LABEL: func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_result
func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_result(
    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg0: tensor<5xf32>,
    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ARG0]], %[[ARG1]]
  %2 = arith.addf %0, %1 : tensor<5xf32>
  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]]
  // CHECK: return %[[ALL_REDUCE_RES]], %[[ALL_REDUCE_RES]]
  return %2, %2 : tensor<5xf32>, tensor<5xf32>
}

// Do not simplify if there is another use of one of the all-reduces.
// CHECK-LABEL: func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_all_reduce_result
func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_all_reduce_result(
    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg0: tensor<5xf32>,
    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
  // CHECK: %[[ALL_REDUCE_0_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]]
  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[ALL_REDUCE_1_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]]
  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE_0_RES]], %[[ALL_REDUCE_1_RES]]
  %2 = arith.addf %0, %1 : tensor<5xf32>
  // CHECK: return %[[ALL_REDUCE_0_RES]], %[[ADD_RES]]
  return %0, %2 : tensor<5xf32>, tensor<5xf32>
}

// Do not simplify if the two all-reduces operate on different meshes
// (@mesh0 vs. @mesh1).
// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_mesh
func.func @all_reduce_arith_addf_no_endomorphism_different_mesh(
    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg0: tensor<5xf32>,
    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg1: tensor<5xf32>) -> tensor<5xf32> {
  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0
  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh1
  %1 = mesh.all_reduce %arg1 on @mesh1 mesh_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
  %2 = arith.addf %0, %1 : tensor<5xf32>
  // CHECK: return %[[ADD_RES]]
  return %2 : tensor<5xf32>
}

// Do not simplify if the two all-reduces use the same mesh but different
// mesh axes ([0] vs. [1]).
// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_mesh_axes
func.func @all_reduce_arith_addf_no_endomorphism_different_mesh_axes(
    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg0: tensor<5xf32>,
    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg1: tensor<5xf32>) -> tensor<5xf32> {
  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0]
  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [1]
  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [1]
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
  %2 = arith.addf %0, %1 : tensor<5xf32>
  // CHECK: return %[[ADD_RES]]
  return %2 : tensor<5xf32>
}

// Do not simplify if one of the all-reduces uses a reduction kind (`max`)
// that does not correspond to the combining `arith.addf`.
// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_wrong_reduction_kind
func.func @all_reduce_arith_addf_no_endomorphism_wrong_reduction_kind(
    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg0: tensor<5xf32>,
    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg1: tensor<5xf32>) -> tensor<5xf32> {
  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0] reduction = max
  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = max
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [0]
  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
  %2 = arith.addf %0, %1 : tensor<5xf32>
  // CHECK: return %[[ADD_RES]]
  return %2 : tensor<5xf32>
}

// Do not simplify if the all-reduces change the element type between operand
// and result (f32 -> f64).
// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_operand_result_element_types
func.func @all_reduce_arith_addf_no_endomorphism_different_operand_result_element_types(
    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg0: tensor<5xf32>,
    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg1: tensor<5xf32>) -> tensor<5xf64> {
  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0]
  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
    : tensor<5xf32> -> tensor<5xf64>
  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [0]
  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
    : tensor<5xf32> -> tensor<5xf64>
  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
  %2 = arith.addf %0, %1 : tensor<5xf64>
  // CHECK: return %[[ADD_RES]]
  return %2 : tensor<5xf64>
}

// Checks that `min(all_reduce(x), all_reduce(y))` gets transformed to
// `all_reduce(min(x, y))`.
// CHECK-LABEL: func.func @all_reduce_arith_minimumf_endomorphism
func.func @all_reduce_arith_minimumf_endomorphism(
    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg0: tensor<5xf32>,
    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg1: tensor<5xf32>) -> tensor<5xf32> {
  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = min
    : tensor<5xf32> -> tensor<5xf32>
  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = min
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[MIN_RES:[A-Za-z0-9_]*]] = arith.minimumf %[[ARG0]], %[[ARG1]]
  %2 = arith.minimumf %0, %1 : tensor<5xf32>
  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[MIN_RES]] on @mesh0 mesh_axes = [0] reduction = min
  // CHECK: return %[[ALL_REDUCE_RES]]
  return %2 : tensor<5xf32>
}

// Same as above for the signed-integer minimum: `minsi(all_reduce(x),
// all_reduce(y))` with `min` reduction becomes `all_reduce(minsi(x, y))`.
// CHECK-LABEL: func.func @all_reduce_arith_minsi_endomorphism
func.func @all_reduce_arith_minsi_endomorphism(
    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xi32>
    %arg0: tensor<5xi32>,
    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xi32>
    %arg1: tensor<5xi32>) -> tensor<5xi32> {
  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = min
    : tensor<5xi32> -> tensor<5xi32>
  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = min
    : tensor<5xi32> -> tensor<5xi32>
  // CHECK: %[[MIN_RES:[A-Za-z0-9_]*]] = arith.minsi %[[ARG0]], %[[ARG1]]
  %2 = arith.minsi %0, %1 : tensor<5xi32>
  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[MIN_RES]] on @mesh0 mesh_axes = [0] reduction = min
  // CHECK: return %[[ALL_REDUCE_RES]]
  return %2 : tensor<5xi32>
}