1 //===- ShardingInterface.h --------------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_ 10 #define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_ 11 12 #include "mlir/Dialect/Mesh/IR/MeshOps.h" 13 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 14 #include "mlir/IR/Value.h" 15 #include "mlir/Support/LLVM.h" 16 17 namespace mlir { 18 19 class Operation; 20 class IRMapping; 21 class SymbolTableCollection; 22 23 namespace mesh { 24 25 using ShardingArray = SmallVector<SmallVector<MeshAxis>>; 26 using ShardingArrayRef = ArrayRef<SmallVector<MeshAxis>>; 27 28 struct ShardingOption { 29 // An array of int array. The sub-array at the i-th position signifies the 30 // mesh axes the i-th loop will be sharded on. 31 ShardingArray shardingArray = {}; 32 FlatSymbolRefAttr mesh = nullptr; 33 // `empty` being true indicates that no sharding information can be inferred 34 // at present. Note that it is different from the case where an operation is 35 // not sharded. 36 bool empty = false; 37 ShardingOption() = default; 38 ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr mesh) 39 : shardingArray(std::move(shardingArray)), mesh(mesh) {} 40 static ShardingOption makeEmpty() { 41 auto res = ShardingOption(); 42 res.empty = true; 43 return res; 44 } 45 }; 46 47 // This method retrieves the 'MeshSharding' from a given operation 48 // result and includes the 'annotate_for_users' information. 49 FailureOr<std::pair<bool, MeshSharding>> getMeshSharding(OpResult result); 50 51 // This method retrieves the 'MeshSharding' from a given operation 52 // operand and includes the 'annotate_for_users' information. 53 FailureOr<std::pair<bool, MeshSharding>> getMeshSharding(OpOperand &opOperand); 54 55 namespace detail { 56 57 FailureOr<ShardingOption> 58 defaultGetShardingOption(Operation *op, ArrayRef<MeshSharding> operandShardings, 59 ArrayRef<MeshSharding> resultShardings); 60 61 FailureOr<std::vector<MeshSharding>> 62 defaultGetShardingAnnotations(Operation *op, 63 const ShardingOption &shardingOption); 64 65 LogicalResult 66 defaultAddShardingAnnotations(Operation *op, OpBuilder &b, 67 const ShardingOption &shardingOption); 68 69 } // namespace detail 70 71 // Assumes full replication on all ranked tensor arguments and results. 72 void spmdizeFullyReplicatedOperation(Operation &op, 73 ArrayRef<Value> spmdizedOperands, 74 ArrayRef<MeshSharding> operandShardings, 75 ArrayRef<MeshSharding> resultShardings, 76 IRMapping &spmdizationMap, 77 SymbolTableCollection &symbolTable, 78 OpBuilder &builder); 79 80 } // namespace mesh 81 } // namespace mlir 82 83 /// Include the ODS generated interface header files. 84 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h.inc" 85 86 #endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_ 87