1 //===- Simplifications.h - Mesh Simplifications -----------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_DIALECT_MESH_TRANSFORMS_SPMDIZATION_H 10 #define MLIR_DIALECT_MESH_TRANSFORMS_SPMDIZATION_H 11 12 #include "mlir/Dialect/Mesh/IR/MeshOps.h" 13 #include "mlir/IR/DialectRegistry.h" 14 15 namespace mlir { 16 namespace mesh { 17 18 // Insert resharding spmdization of the value `sourceShardValue` 19 // from sharding `source` to sharding `target`. 20 // `sourceShardValue` is the already sharded value according to `source`. 21 // 22 // Example 23 // 24 // ```mlir 25 // mesh.mesh @mesh_1d(shape = 2) 26 // ... 27 // %1 = mesh.shard %0 to <@mesh_1d, [[0]]> : tensor<2xi8> 28 // %2 = mesh.shard %1 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8> 29 // ``` 30 // 31 // Will result in 32 // 33 // ```mlir 34 // %1 = mesh.all_gather %0 on @mesh_1d mesh_axes = [0] gather_axis = 0 : 35 // tensor<1xi8> -> tensor<2xi8> 36 // ``` 37 TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source, 38 ShardOp target, 39 TypedValue<ShapedType> sourceShardValue); 40 TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source, 41 ShardOp target, 42 TypedValue<ShapedType> sourceShardValue, 43 SymbolTableCollection &symbolTableCollection); 44 45 void reshardingRegisterDependentDialects(DialectRegistry ®istry); 46 47 } // namespace mesh 48 } // namespace mlir 49 50 #endif // MLIR_DIALECT_MESH_TRANSFORMS_SPMDIZATION_H 51