1 //===- MeshOps.h - Mesh Dialect Operations ----------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_DIALECT_MESH_IR_MESHOPS_H 10 #define MLIR_DIALECT_MESH_IR_MESHOPS_H 11 12 #include "mlir/Bytecode/BytecodeOpInterface.h" 13 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 14 #include "mlir/IR/BuiltinTypeInterfaces.h" 15 #include "mlir/IR/OpDefinition.h" 16 #include "mlir/IR/PatternMatch.h" 17 #include "mlir/IR/SymbolTable.h" 18 #include "mlir/Interfaces/DestinationStyleOpInterface.h" 19 #include "mlir/Interfaces/InferTypeOpInterface.h" 20 #include "mlir/Interfaces/SideEffectInterfaces.h" 21 #include "llvm/Support/MathExtras.h" 22 23 namespace mlir { 24 namespace mesh { 25 26 using MeshAxis = int16_t; 27 using MeshAxesAttr = DenseI16ArrayAttr; 28 using ShardShapeAttr = DenseI64ArrayAttr; 29 using HaloSizePairAttr = DenseI64ArrayAttr; 30 31 } // namespace mesh 32 } // namespace mlir 33 34 #include "mlir/Dialect/Mesh/IR/MeshEnums.h.inc" 35 36 #define GET_ATTRDEF_CLASSES 37 #include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc" 38 39 namespace mlir { 40 namespace mesh { 41 42 class MeshSharding { 43 private: 44 ::mlir::FlatSymbolRefAttr mesh; 45 SmallVector<MeshAxesAttr> split_axes; 46 SmallVector<MeshAxis> partial_axes; 47 ReductionKind partial_type; 48 SmallVector<int64_t> static_halo_sizes; 49 SmallVector<int64_t> static_sharded_dims_offsets; 50 SmallVector<Value> dynamic_halo_sizes; 51 SmallVector<Value> dynamic_sharded_dims_offsets; 52 53 public: 54 MeshSharding() = default; 55 MeshSharding(Value rhs); 56 static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_, 57 ArrayRef<MeshAxesAttr> split_axes_, 58 ArrayRef<MeshAxis> partial_axes_ = {}, 59 ReductionKind partial_type_ = ReductionKind::Sum, 60 ArrayRef<int64_t> static_halo_sizes_ = {}, 61 ArrayRef<int64_t> static_sharded_dims_offsets_ = {}, 62 ArrayRef<Value> dynamic_halo_sizes_ = {}, 63 ArrayRef<Value> dynamic_sharded_dims_offsets_ = {}); 64 ::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; } 65 ::llvm::StringRef getMesh() const { return mesh.getValue(); } 66 ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; } 67 ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; } 68 ReductionKind getPartialType() const { return partial_type; } 69 ArrayRef<int64_t> getStaticHaloSizes() const { return static_halo_sizes; } 70 ArrayRef<int64_t> getStaticShardedDimsOffsets() const { 71 return static_sharded_dims_offsets; 72 } 73 ArrayRef<Value> getDynamicHaloSizes() const { return dynamic_halo_sizes; } 74 ArrayRef<Value> getDynamicShardedDimsOffsets() const { 75 return dynamic_sharded_dims_offsets; 76 } 77 operator bool() const { return (!mesh) == false; } 78 bool operator==(Value rhs) const; 79 bool operator!=(Value rhs) const; 80 bool operator==(const MeshSharding &rhs) const; 81 bool operator!=(const MeshSharding &rhs) const; 82 bool equalSplitAndPartialAxes(const MeshSharding &rhs) const; 83 bool equalHaloAndShardSizes(const MeshSharding &rhs) const; 84 bool equalHaloSizes(const MeshSharding &rhs) const; 85 bool equalShardSizes(const MeshSharding &rhs) const; 86 }; 87 88 } // namespace mesh 89 } // namespace mlir 90 91 #define GET_TYPEDEF_CLASSES 92 #include "mlir/Dialect/Mesh/IR/MeshTypes.h.inc" 93 94 #define GET_OP_CLASSES 95 #include "mlir/Dialect/Mesh/IR/MeshOps.h.inc" 96 97 namespace mlir { 98 namespace mesh { 99 100 inline bool isReductionLoop(utils::IteratorType iType) { 101 return iType == utils::IteratorType::reduction; 102 } 103 104 // Remove empty subarrays of `array` until a minimum lengh of one is reached. 105 template <typename T> 106 void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) { 107 while (array.size() > 1 && array.back().empty()) 108 array.pop_back(); 109 } 110 111 // Is the same tensor replicated on all processes. 112 inline bool isFullReplication(MeshSharding sharding) { 113 return sharding.getPartialAxes().empty() && 114 llvm::all_of(sharding.getSplitAxes(), [](MeshAxesAttr axes) { 115 return axes.asArrayRef().empty(); 116 }); 117 } 118 119 inline mesh::MeshOp 120 getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol, 121 SymbolTableCollection &symbolTableCollection) { 122 return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>( 123 op, meshSymbol); 124 } 125 126 inline mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, 127 SymbolTableCollection &symbolTableCollection) { 128 mesh::MeshOp meshOp = getMeshOrNull(op, meshSymbol, symbolTableCollection); 129 assert(meshOp); 130 return meshOp; 131 } 132 133 // Get the corresponding mesh op using the standard attribute nomenclature. 134 template <typename Op> 135 mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) { 136 return getMesh(op.getOperation(), op.getMeshAttr(), symbolTableCollection); 137 } 138 139 template <> 140 inline mesh::MeshOp 141 getMesh<ShardOp>(ShardOp op, SymbolTableCollection &symbolTableCollection) { 142 return getMesh( 143 op.getOperation(), 144 cast<ShardingOp>(op.getSharding().getDefiningOp()).getMeshAttr(), 145 symbolTableCollection); 146 } 147 148 // Get the number of processes that participate in each group 149 // induced by `meshAxes`. 150 template <typename MeshAxesRange, typename MeshShapeRange> 151 int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, 152 MeshShapeRange &&meshShape) { 153 int64_t res = 1; 154 155 for (MeshAxis axis : meshAxes) { 156 auto axisSize = *(std::begin(meshShape) + axis); 157 if (ShapedType::isDynamic(axisSize)) { 158 return ShapedType::kDynamic; 159 } 160 res *= axisSize; 161 } 162 163 return res; 164 } 165 166 template <typename MeshAxesRange> 167 int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshOp mesh) { 168 return collectiveProcessGroupSize(std::forward<MeshAxesRange>(meshAxes), 169 mesh.getShape()); 170 } 171 172 // Get the size of a sharded dimension. 173 inline int64_t shardDimension(int64_t dimSize, int64_t shardCount) { 174 if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount)) 175 return ShapedType::kDynamic; 176 177 assert(dimSize % shardCount == 0); 178 return dimSize / shardCount; 179 } 180 181 // Get the size of an unsharded dimension. 182 inline int64_t gatherDimension(int64_t dimSize, int64_t shardCount) { 183 if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount)) 184 return ShapedType::kDynamic; 185 186 return dimSize * shardCount; 187 } 188 189 // Return the sharded shape `shape` according ot sharding `sharding`. 190 // The shape for the tensor on each device in the mesh. 191 // Example: 192 // On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1 would 193 // result in a shape for each shard of ?x2x?. 194 ShapedType shardShapedType(ShapedType shape, MeshOp mesh, 195 MeshSharding sharding); 196 197 // If ranked tensor type return its sharded counterpart. 198 // 199 // If not ranked tensor type return `type`. 200 // `sharding` in that case must be null. 201 Type shardType(Type type, MeshOp mesh, MeshSharding sharding); 202 203 // Insert shard op if there is not one that already has the same sharding. 204 // May insert resharding if required. 205 void maybeInsertTargetShardingAnnotation(MeshSharding sharding, 206 OpOperand &operand, 207 OpBuilder &builder); 208 void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result, 209 OpBuilder &builder); 210 void maybeInsertSourceShardingAnnotation(MeshSharding sharding, 211 OpOperand &operand, 212 OpBuilder &builder); 213 214 } // namespace mesh 215 } // namespace mlir 216 217 #endif // MLIR_DIALECT_MESH_IR_MESHOPS_H 218