xref: /llvm-project/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (revision ffc7feadece139c88f0e6930f16bfa9293747adc)
1 //===- MeshOps.h - Mesh Dialect Operations ----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_MESH_IR_MESHOPS_H
10 #define MLIR_DIALECT_MESH_IR_MESHOPS_H
11 
12 #include "mlir/Bytecode/BytecodeOpInterface.h"
13 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
14 #include "mlir/IR/BuiltinTypeInterfaces.h"
15 #include "mlir/IR/OpDefinition.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/SymbolTable.h"
18 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
19 #include "mlir/Interfaces/InferTypeOpInterface.h"
20 #include "mlir/Interfaces/SideEffectInterfaces.h"
21 #include "llvm/Support/MathExtras.h"
22 
23 namespace mlir {
24 namespace mesh {
25 
// Integer type identifying one axis of a mesh. Values of this type are used
// as indices into a mesh shape (see `collectiveProcessGroupSize`).
using MeshAxis = int16_t;
// Attribute holding an array of 16-bit integers; the element width matches
// `MeshAxis`.
using MeshAxesAttr = DenseI16ArrayAttr;
// Attribute holding an array of 64-bit integers.
// NOTE(review): presumably the per-dimension shape of a shard - confirm.
using ShardShapeAttr = DenseI64ArrayAttr;
// Attribute holding an array of 64-bit integers.
// NOTE(review): presumably (low, high) halo sizes - confirm.
using HaloSizePairAttr = DenseI64ArrayAttr;
30 
31 } // namespace mesh
32 } // namespace mlir
33 
34 #include "mlir/Dialect/Mesh/IR/MeshEnums.h.inc"
35 
36 #define GET_ATTRDEF_CLASSES
37 #include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc"
38 
39 namespace mlir {
40 namespace mesh {
41 
42 class MeshSharding {
43 private:
44   ::mlir::FlatSymbolRefAttr mesh;
45   SmallVector<MeshAxesAttr> split_axes;
46   SmallVector<MeshAxis> partial_axes;
47   ReductionKind partial_type;
48   SmallVector<int64_t> static_halo_sizes;
49   SmallVector<int64_t> static_sharded_dims_offsets;
50   SmallVector<Value> dynamic_halo_sizes;
51   SmallVector<Value> dynamic_sharded_dims_offsets;
52 
53 public:
54   MeshSharding() = default;
55   MeshSharding(Value rhs);
56   static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_,
57                           ArrayRef<MeshAxesAttr> split_axes_,
58                           ArrayRef<MeshAxis> partial_axes_ = {},
59                           ReductionKind partial_type_ = ReductionKind::Sum,
60                           ArrayRef<int64_t> static_halo_sizes_ = {},
61                           ArrayRef<int64_t> static_sharded_dims_offsets_ = {},
62                           ArrayRef<Value> dynamic_halo_sizes_ = {},
63                           ArrayRef<Value> dynamic_sharded_dims_offsets_ = {});
64   ::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
65   ::llvm::StringRef getMesh() const { return mesh.getValue(); }
66   ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
67   ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
68   ReductionKind getPartialType() const { return partial_type; }
69   ArrayRef<int64_t> getStaticHaloSizes() const { return static_halo_sizes; }
70   ArrayRef<int64_t> getStaticShardedDimsOffsets() const {
71     return static_sharded_dims_offsets;
72   }
73   ArrayRef<Value> getDynamicHaloSizes() const { return dynamic_halo_sizes; }
74   ArrayRef<Value> getDynamicShardedDimsOffsets() const {
75     return dynamic_sharded_dims_offsets;
76   }
77   operator bool() const { return (!mesh) == false; }
78   bool operator==(Value rhs) const;
79   bool operator!=(Value rhs) const;
80   bool operator==(const MeshSharding &rhs) const;
81   bool operator!=(const MeshSharding &rhs) const;
82   bool equalSplitAndPartialAxes(const MeshSharding &rhs) const;
83   bool equalHaloAndShardSizes(const MeshSharding &rhs) const;
84   bool equalHaloSizes(const MeshSharding &rhs) const;
85   bool equalShardSizes(const MeshSharding &rhs) const;
86 };
87 
88 } // namespace mesh
89 } // namespace mlir
90 
91 #define GET_TYPEDEF_CLASSES
92 #include "mlir/Dialect/Mesh/IR/MeshTypes.h.inc"
93 
94 #define GET_OP_CLASSES
95 #include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
96 
97 namespace mlir {
98 namespace mesh {
99 
100 inline bool isReductionLoop(utils::IteratorType iType) {
101   return iType == utils::IteratorType::reduction;
102 }
103 
104 // Remove empty subarrays of `array` until a minimum lengh of one is reached.
105 template <typename T>
106 void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
107   while (array.size() > 1 && array.back().empty())
108     array.pop_back();
109 }
110 
111 // Is the same tensor replicated on all processes.
112 inline bool isFullReplication(MeshSharding sharding) {
113   return sharding.getPartialAxes().empty() &&
114          llvm::all_of(sharding.getSplitAxes(), [](MeshAxesAttr axes) {
115            return axes.asArrayRef().empty();
116          });
117 }
118 
119 inline mesh::MeshOp
120 getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol,
121               SymbolTableCollection &symbolTableCollection) {
122   return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
123       op, meshSymbol);
124 }
125 
126 inline mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
127                             SymbolTableCollection &symbolTableCollection) {
128   mesh::MeshOp meshOp = getMeshOrNull(op, meshSymbol, symbolTableCollection);
129   assert(meshOp);
130   return meshOp;
131 }
132 
133 // Get the corresponding mesh op using the standard attribute nomenclature.
134 template <typename Op>
135 mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
136   return getMesh(op.getOperation(), op.getMeshAttr(), symbolTableCollection);
137 }
138 
139 template <>
140 inline mesh::MeshOp
141 getMesh<ShardOp>(ShardOp op, SymbolTableCollection &symbolTableCollection) {
142   return getMesh(
143       op.getOperation(),
144       cast<ShardingOp>(op.getSharding().getDefiningOp()).getMeshAttr(),
145       symbolTableCollection);
146 }
147 
148 // Get the number of processes that participate in each group
149 // induced by `meshAxes`.
150 template <typename MeshAxesRange, typename MeshShapeRange>
151 int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
152                                    MeshShapeRange &&meshShape) {
153   int64_t res = 1;
154 
155   for (MeshAxis axis : meshAxes) {
156     auto axisSize = *(std::begin(meshShape) + axis);
157     if (ShapedType::isDynamic(axisSize)) {
158       return ShapedType::kDynamic;
159     }
160     res *= axisSize;
161   }
162 
163   return res;
164 }
165 
166 template <typename MeshAxesRange>
167 int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshOp mesh) {
168   return collectiveProcessGroupSize(std::forward<MeshAxesRange>(meshAxes),
169                                     mesh.getShape());
170 }
171 
172 // Get the size of a sharded dimension.
173 inline int64_t shardDimension(int64_t dimSize, int64_t shardCount) {
174   if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
175     return ShapedType::kDynamic;
176 
177   assert(dimSize % shardCount == 0);
178   return dimSize / shardCount;
179 }
180 
181 // Get the size of an unsharded dimension.
182 inline int64_t gatherDimension(int64_t dimSize, int64_t shardCount) {
183   if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
184     return ShapedType::kDynamic;
185 
186   return dimSize * shardCount;
187 }
188 
// Return the sharded shape of `shape` according to sharding `sharding`,
// i.e. the shape of the tensor on each device in the mesh.
// Example:
// On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1 would
// result in a shape for each shard of ?x2x?.
// (Declaration only; the definition is not part of this header.)
ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
                           MeshSharding sharding);
196 
// If `type` is a ranked tensor type, return its sharded counterpart.
//
// If `type` is not a ranked tensor type, return `type` itself;
// `sharding` must be null in that case.
// (Declaration only; the definition is not part of this header.)
Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
202 
// Insert a shard op if there is not one that already has the same sharding.
// May insert resharding if required.
// Overloads exist for annotating an operand and for annotating a result; any
// new ops are created through `builder`.
// (Declarations only; the definitions are not part of this header.)
void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
                                         OpOperand &operand,
                                         OpBuilder &builder);
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
                                         OpBuilder &builder);
// NOTE(review): presumably the source-side counterpart of the target
// annotation above for `operand` - confirm against the definition.
void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
                                         OpOperand &operand,
                                         OpBuilder &builder);
213 
214 } // namespace mesh
215 } // namespace mlir
216 
217 #endif // MLIR_DIALECT_MESH_IR_MESHOPS_H
218