xref: /llvm-project/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp (revision baabcb28983edf8f20e39b89e2b1745412073b44)
1 //===- ShardingPropagation.cpp ------------------------------------- C++ --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Mesh/Transforms/Passes.h"
10 
11 #include "mlir/Dialect/Func/IR/FuncOps.h"
12 #include "mlir/Dialect/Mesh/IR/MeshDialect.h"
13 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
14 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
15 #include "mlir/IR/Verifier.h"
16 #include "mlir/Interfaces/FunctionInterfaces.h"
17 #include "mlir/Pass/Pass.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/ADT/iterator_range.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/raw_ostream.h"
23 #include <algorithm>
24 #include <vector>
25 
26 namespace mlir {
27 namespace mesh {
28 #define GEN_PASS_DEF_SHARDINGPROPAGATION
29 #include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
30 } // namespace mesh
31 } // namespace mlir
32 
33 #define DEBUG_TYPE "sharding-propagation"
34 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
35 
36 using namespace mlir;
37 using namespace mlir::mesh;
38 
// How much resharding a candidate sharding assignment would require, relative
// to the `mesh.shard` annotations already present around an operation.
// Enumerators are listed from most to least desirable. Candidates are ranked by
// comparing these values with `<`, so the numeric order is significant.
enum class ReshardingRquirementKind {
  // Every existing annotation is compatible with the candidate.
  NO_RESHARDING = 0,
  // Only annotations that do not specifically target the operation conflict.
  NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS = 1,
  // At least one annotation that targets the operation conflicts.
  RESHARDING_FOR_EXPLICIT_ANNOTATIONS = 2,
};
44 
45 #ifdef LLVM_DEBUG
46 
47 template <typename T>
48 static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
49                                      const SmallVector<T> &vec);
50 template <typename... Ts>
51 static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
52                                      const std::tuple<Ts...> &t);
53 static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
54                                      ReshardingRquirementKind v);
55 
// Writes `range` to `stream` as "[e0, e1, ..., ]". Every element, including
// the last one, is followed by ", ". Returns the result of the final `<<`.
template <typename Stream, typename Range>
static Stream &printRange(Stream &stream, Range &&range) {
  stream << "[";
  for (auto &element : range) {
    stream << element;
    stream << ", ";
  }
  return stream << "]";
}
65 
66 template <typename T>
67 static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
68                                      const SmallVector<T> &vec) {
69   return printRange(stream, vec);
70 }
71 
72 [[maybe_unused]] static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
73                                                       const ShardingOption &v) {
74   return stream << "{empty = " << v.empty << ", mesh" << v.mesh
75                 << ", shardingArray = " << v.shardingArray << "}";
76 }
77 
// Writes the elements of `tuple` to `stream` as "{e0, e1, ..., }". Each
// element, including the last one, is followed by ", ". `Is` must enumerate
// every index of the tuple (pass std::index_sequence_for<Ts...>{}).
//
// The tuple is taken by const reference so that printing does not copy its
// elements (which may be large, e.g. ShardingOption) and so that tuples with
// move-only elements can be printed.
template <typename Stream, typename... Ts, size_t... Is>
static Stream &printTuple(Stream &stream, const std::tuple<Ts...> &tuple,
                          std::index_sequence<Is...>) {
  static_assert(sizeof...(Is) == sizeof...(Ts),
                "Indices must have same number of elements as tuple types!");
  static_assert(sizeof...(Ts) > 0, "Cannot insert empty tuple into stream.");

  stream << "{";
  ((stream << std::get<Is>(tuple) << ", "), ...);
  return stream << "}";
}
89 
90 template <typename... Ts>
91 static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
92                                      const std::tuple<Ts...> &t) {
93   return printTuple(stream, t, std::index_sequence_for<Ts...>{});
94 }
95 
96 [[maybe_unused]] static llvm::raw_ostream &
97 operator<<(llvm::raw_ostream &stream, ReshardingRquirementKind v) {
98   return stream << static_cast<int>(v);
99 }
100 
101 #endif // LLVM_DEBUG
102 
103 //===----------------------------------------------------------------------===//
104 // Utilities
105 //===----------------------------------------------------------------------===//
106 
107 // This method retrieves all potential sharding attributes, prioritizing
108 // specific shardings. For example, mustShardings = [shard0, None] and
109 // optionalShardings = [None, shard1], the result will be [[shard0, shard1],
110 // [shard0, None]]
111 static SmallVector<std::vector<MeshSharding>>
112 getOrderedPossibleShardingAttrs(ArrayRef<MeshSharding> mustShardings,
113                                 ArrayRef<MeshSharding> optionalShardings) {
114   SmallVector<std::vector<MeshSharding>> allShardingAttrs;
115   std::vector<MeshSharding> curShardingAttrs;
116 
117   std::function<void(size_t)> dfsCreateShardingAttrs = [&](size_t i) {
118     if (i == mustShardings.size()) {
119       allShardingAttrs.push_back(std::vector<MeshSharding>(curShardingAttrs));
120       return;
121     }
122 
123     if (mustShardings[i]) {
124       curShardingAttrs.push_back(mustShardings[i]);
125       dfsCreateShardingAttrs(i + 1);
126       curShardingAttrs.pop_back();
127       return;
128     }
129 
130     if (optionalShardings[i]) {
131       curShardingAttrs.push_back(optionalShardings[i]);
132       dfsCreateShardingAttrs(i + 1);
133       curShardingAttrs.pop_back();
134       curShardingAttrs.push_back({});
135       dfsCreateShardingAttrs(i + 1);
136       curShardingAttrs.pop_back();
137       return;
138     }
139 
140     curShardingAttrs.push_back({});
141     dfsCreateShardingAttrs(i + 1);
142     curShardingAttrs.pop_back();
143   };
144 
145   dfsCreateShardingAttrs(0);
146   return allShardingAttrs;
147 }
148 
149 // The order of preference is form highest to lowest:
150 // 1. No resharding is required (all existing annotations are compatible).
151 // 2. No resharding for operands/results that have annotation specifically
152 //   targeting this operation. This means
153 //   * operands that are the result of `mesh.shard` ops marked with
154 //     `annotate_for_users`.
155 //   * results that are annotated with `mesh.shard` ops without
156 //     `annotate_for_users`.
157 // 3. All other cases. Resharding is required for operands/results with
158 //   annotation targeting explicitly this operation.
159 ReshardingRquirementKind getReshardingRquirementKind(
160     Operation *op, const std::vector<MeshSharding> &operandAndResultShardings) {
161   ReshardingRquirementKind res = ReshardingRquirementKind::NO_RESHARDING;
162 
163   size_t operandsCount = op->getOperands().size();
164   auto operandShardings =
165       llvm::make_range(operandAndResultShardings.begin(),
166                        operandAndResultShardings.begin() + operandsCount);
167   auto resultShardings =
168       llvm::make_range(operandAndResultShardings.begin() + operandsCount,
169                        operandAndResultShardings.end());
170 
171   for (auto [operand, sharding] :
172        llvm::zip_equal(op->getOperands(), operandShardings)) {
173     ShardOp shardOp = llvm::dyn_cast_or_null<ShardOp>(operand.getDefiningOp());
174     if (!shardOp) {
175       continue;
176     }
177     bool needsResharding = sharding != shardOp.getSharding();
178     bool isExplicitAnnotationForThisOp = shardOp.getAnnotateForUsers();
179     if (needsResharding) {
180       if (isExplicitAnnotationForThisOp) {
181         // This is the worst case. No need to continue.
182         return ReshardingRquirementKind::RESHARDING_FOR_EXPLICIT_ANNOTATIONS;
183       }
184       res = ReshardingRquirementKind::NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS;
185     }
186   }
187 
188   for (auto [result, sharding] :
189        llvm::zip_equal(op->getResults(), resultShardings)) {
190     for (auto user : result.getUsers()) {
191       ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
192       if (!shardOp) {
193         continue;
194       }
195       bool needsResharding = sharding != shardOp.getSharding();
196       bool isExplicitAnnotationForThisOp = !shardOp.getAnnotateForUsers();
197       if (needsResharding) {
198         if (isExplicitAnnotationForThisOp) {
199           // This is the worst case. No need to continue.
200           return ReshardingRquirementKind::RESHARDING_FOR_EXPLICIT_ANNOTATIONS;
201         }
202         res = ReshardingRquirementKind::NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS;
203       }
204     }
205   }
206 
207   return res;
208 }
209 
210 // From all the operand and result sharding combinations,
211 // return the one that is most desirable.
212 // The order of preference is:
213 // 1. No resharding with respect to existing sharding annotations.
214 // 2. Resharding for values that have already annotations that do not target
215 //    this op.
216 // 3. Resharding of existing explicit sharding annotations for this op.
217 static FailureOr<ShardingOption> selectShardingOption(
218     ShardingInterface shardingOp,
219     ArrayRef<std::vector<MeshSharding>> possibleOperandShardingAttrs,
220     ArrayRef<std::vector<MeshSharding>> possibleResultShardingAttrs) {
221   SmallVector<std::tuple<ShardingOption, ReshardingRquirementKind>>
222       shardingOptionsAndReshardingRequirements;
223 
224   for (ArrayRef<MeshSharding> resultShardings : possibleResultShardingAttrs) {
225     for (ArrayRef<MeshSharding> operandShardings :
226          possibleOperandShardingAttrs) {
227       FailureOr<ShardingOption> shardingOption =
228           shardingOp.getShardingOption(operandShardings, resultShardings);
229       if (failed(shardingOption) || shardingOption->empty) {
230         continue;
231       }
232       // These shardings may not be the same as those in operandShardings and
233       // resultShardings.
234       // They may be missing some annotations.
235       // Whatever is returned by getShardingAnnotations is exactly what the op
236       // needs.
237       FailureOr<std::vector<MeshSharding>> operandAndResultShardings =
238           shardingOp.getShardingAnnotations(*shardingOption);
239       if (failed(operandAndResultShardings)) {
240         return failure();
241       }
242 
243       // LLVM_DEBUG(DBGS() << "operandAndResultShardings = "
244       //                   << *operandAndResultShardings << "\n";);
245 
246       ReshardingRquirementKind reshardingRquirement =
247           getReshardingRquirementKind(shardingOp, *operandAndResultShardings);
248       if (reshardingRquirement == ReshardingRquirementKind::NO_RESHARDING) {
249         // This is the best case. No need to go on.
250         return *shardingOption;
251       }
252 
253       shardingOptionsAndReshardingRequirements.emplace_back(
254           std::move(*shardingOption), reshardingRquirement);
255     }
256   }
257 
258   if (shardingOptionsAndReshardingRequirements.empty()) {
259     return ShardingOption::makeEmpty();
260   }
261 
262   std::partial_sort(
263       shardingOptionsAndReshardingRequirements.begin(),
264       shardingOptionsAndReshardingRequirements.begin() + 1,
265       shardingOptionsAndReshardingRequirements.end(),
266       [](const std::tuple<ShardingOption, ReshardingRquirementKind> &a,
267          const std::tuple<ShardingOption, ReshardingRquirementKind> &b) {
268         return std::get<ReshardingRquirementKind>(a) <
269                std::get<ReshardingRquirementKind>(b);
270       });
271 
272   LLVM_DEBUG(DBGS() << "shardingOptionsAndReshardingRequirements = "
273                     << shardingOptionsAndReshardingRequirements << "\n";);
274 
275   return std::get<ShardingOption>(
276       shardingOptionsAndReshardingRequirements.front());
277 }
278 
279 // For each operation that implements the ShardingInterface, infer the sharding
280 // option of the operation from its operands and/or results using the
281 // `getShardingOption` method. If the inferred sharding option is not empty, add
282 // a `mesh.shard` operation for all remaining operands and results that do not
283 // have sharding annotations.
284 static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
285   if (op->hasTrait<OpTrait::IsTerminator>() ||
286       llvm::isa<mesh::ShardOp, mesh::ShardingOp>(op))
287     return success();
288 
289   ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
290   if (!shardingOp) {
291     op->emitOpError() << "sharding interface is not implemented.";
292     return failure();
293   }
294 
295   // collect MeshSharding from results
296   std::vector<MeshSharding> allowConflictsResultShardings;
297   allowConflictsResultShardings.resize(op->getNumResults());
298   std::vector<MeshSharding> resultMustShardings;
299   resultMustShardings.resize(op->getNumResults());
300   for (OpResult result : op->getResults()) {
301     FailureOr<std::pair<bool, MeshSharding>> maybeShardAttr =
302         getMeshSharding(result);
303     if (failed(maybeShardAttr))
304       continue;
305     if (!maybeShardAttr->first)
306       resultMustShardings[result.getResultNumber()] = maybeShardAttr->second;
307     else
308       allowConflictsResultShardings[result.getResultNumber()] =
309           maybeShardAttr->second;
310   }
311 
312   // collect MeshSharding from operands
313   std::vector<MeshSharding> allowConflictsOperandShardings;
314   allowConflictsOperandShardings.resize(op->getNumOperands());
315   std::vector<MeshSharding> operandMustShardings;
316   operandMustShardings.resize(op->getNumOperands());
317   for (OpOperand &opOperand : op->getOpOperands()) {
318     FailureOr<std::pair<bool, MeshSharding>> maybeShardAttr =
319         getMeshSharding(opOperand);
320     if (failed(maybeShardAttr))
321       continue;
322 
323     if (maybeShardAttr->first)
324       operandMustShardings[opOperand.getOperandNumber()] =
325           maybeShardAttr->second;
326     else
327       allowConflictsOperandShardings[opOperand.getOperandNumber()] =
328           maybeShardAttr->second;
329   }
330 
331   // try to get the sharding option
332   SmallVector<std::vector<MeshSharding>> possibleOperandShardingAttrs =
333       getOrderedPossibleShardingAttrs(operandMustShardings,
334                                       allowConflictsOperandShardings);
335   SmallVector<std::vector<MeshSharding>> possibleResultShardingAttrs =
336       getOrderedPossibleShardingAttrs(resultMustShardings,
337                                       allowConflictsResultShardings);
338   FailureOr<ShardingOption> shardingOption = selectShardingOption(
339       shardingOp, possibleOperandShardingAttrs, possibleResultShardingAttrs);
340 
341   if (failed(shardingOption)) {
342     op->emitOpError() << "fail to get sharding option.";
343     return failure();
344   }
345 
346   LLVM_DEBUG(DBGS() << "Selected sharding option: " << *shardingOption << "\n");
347 
348   // sharding info is empty, return immediately
349   if (shardingOption->empty)
350     return success();
351 
352   if (failed(shardingOp.addShardingAnnotations(builder, *shardingOption))) {
353     op->emitOpError() << "fail to set sharding annotations.";
354     return failure();
355   }
356   return success();
357 }
358 
359 //===----------------------------------------------------------------------===//
360 // ShardingPropagation
361 //===----------------------------------------------------------------------===//
362 struct ShardingPropagation
363     : public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
364   void runOnOperation() override {
365     FunctionOpInterface funcOp = getOperation();
366     MLIRContext *ctx = funcOp.getContext();
367     Region &region = funcOp.getFunctionBody();
368     OpBuilder builder(ctx);
369     if (!region.hasOneBlock()) {
370       funcOp.emitOpError() << "only one block is supported!";
371       signalPassFailure();
372     }
373     Block &block = region.front();
374 
375     LLVM_DEBUG(
376         DBGS() << "print all the ops' iterator types and indexing maps in the "
377                   "block.\n";
378         for (Operation &op
379              : block.getOperations()) {
380           if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op))
381             shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
382         });
383 
384     // 1. propagate in reversed order
385     for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
386       if (failed(visitOp(&op, builder)))
387         return signalPassFailure();
388 
389     LLVM_DEBUG(DBGS() << "After reversed order propagation:\n"
390                       << funcOp << "\n");
391     LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
392 
393     // 2. propagate in original order
394     for (Operation &op : llvm::make_early_inc_range(block))
395       if (failed(visitOp(&op, builder)))
396         return signalPassFailure();
397   }
398 };
399