xref: /llvm-project/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp (revision baabcb28983edf8f20e39b89e2b1745412073b44)
1b0d5b4d2SChengji Yao //===- ShardingPropagation.cpp ------------------------------------- C++ --===//
2b0d5b4d2SChengji Yao //
3b0d5b4d2SChengji Yao // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4b0d5b4d2SChengji Yao // See https://llvm.org/LICENSE.txt for license information.
5b0d5b4d2SChengji Yao // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6b0d5b4d2SChengji Yao //
7b0d5b4d2SChengji Yao //===----------------------------------------------------------------------===//
8b0d5b4d2SChengji Yao 
9b0d5b4d2SChengji Yao #include "mlir/Dialect/Mesh/Transforms/Passes.h"
10b0d5b4d2SChengji Yao 
11b0d5b4d2SChengji Yao #include "mlir/Dialect/Func/IR/FuncOps.h"
1231fc0a12SBoian Petkantchin #include "mlir/Dialect/Mesh/IR/MeshDialect.h"
13b0d5b4d2SChengji Yao #include "mlir/Dialect/Mesh/IR/MeshOps.h"
14b0d5b4d2SChengji Yao #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
15d635b860SBoian Petkantchin #include "mlir/IR/Verifier.h"
16abfac563SBoian Petkantchin #include "mlir/Interfaces/FunctionInterfaces.h"
17b0d5b4d2SChengji Yao #include "mlir/Pass/Pass.h"
18d635b860SBoian Petkantchin #include "llvm/ADT/STLExtras.h"
19d635b860SBoian Petkantchin #include "llvm/ADT/SmallVector.h"
20d635b860SBoian Petkantchin #include "llvm/ADT/iterator_range.h"
21b0d5b4d2SChengji Yao #include "llvm/Support/Debug.h"
22d635b860SBoian Petkantchin #include "llvm/Support/raw_ostream.h"
23d635b860SBoian Petkantchin #include <algorithm>
24b0d5b4d2SChengji Yao #include <vector>
25b0d5b4d2SChengji Yao 
26b0d5b4d2SChengji Yao namespace mlir {
27b0d5b4d2SChengji Yao namespace mesh {
28b0d5b4d2SChengji Yao #define GEN_PASS_DEF_SHARDINGPROPAGATION
29b0d5b4d2SChengji Yao #include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
30b0d5b4d2SChengji Yao } // namespace mesh
31b0d5b4d2SChengji Yao } // namespace mlir
32b0d5b4d2SChengji Yao 
33b0d5b4d2SChengji Yao #define DEBUG_TYPE "sharding-propagation"
34b0d5b4d2SChengji Yao #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
35b0d5b4d2SChengji Yao 
36b0d5b4d2SChengji Yao using namespace mlir;
37b0d5b4d2SChengji Yao using namespace mlir::mesh;
38b0d5b4d2SChengji Yao 
/// How much resharding a candidate sharding option would require, listed from
/// most to least preferable. Candidates are ranked by comparing these values
/// with `<`, so the enumerator order is significant.
enum class ReshardingRquirementKind {
  /// All existing annotations are compatible; nothing has to be resharded.
  NO_RESHARDING = 0,
  /// Only annotations that do not explicitly target the op need resharding.
  NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS = 1,
  /// An annotation that explicitly targets the op would need resharding.
  RESHARDING_FOR_EXPLICIT_ANNOTATIONS = 2,
};
44d635b860SBoian Petkantchin 
45d635b860SBoian Petkantchin #ifdef LLVM_DEBUG
46d635b860SBoian Petkantchin 
47d635b860SBoian Petkantchin template <typename T>
48d635b860SBoian Petkantchin static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
49d635b860SBoian Petkantchin                                      const SmallVector<T> &vec);
50d635b860SBoian Petkantchin template <typename... Ts>
51d635b860SBoian Petkantchin static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
52d635b860SBoian Petkantchin                                      const std::tuple<Ts...> &t);
53d635b860SBoian Petkantchin static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
54d635b860SBoian Petkantchin                                      ReshardingRquirementKind v);
55d635b860SBoian Petkantchin 
56d635b860SBoian Petkantchin template <typename Stream, typename Range>
57d635b860SBoian Petkantchin static Stream &printRange(Stream &stream, Range &&range) {
58d635b860SBoian Petkantchin   stream << "[";
59d635b860SBoian Petkantchin   llvm::for_each(range, [&stream](auto &v) {
60d635b860SBoian Petkantchin     stream << v;
61d635b860SBoian Petkantchin     stream << ", ";
62d635b860SBoian Petkantchin   });
63d635b860SBoian Petkantchin   return stream << "]";
64d635b860SBoian Petkantchin }
65d635b860SBoian Petkantchin 
66d635b860SBoian Petkantchin template <typename T>
67d635b860SBoian Petkantchin static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
68d635b860SBoian Petkantchin                                      const SmallVector<T> &vec) {
69d635b860SBoian Petkantchin   return printRange(stream, vec);
70d635b860SBoian Petkantchin }
71d635b860SBoian Petkantchin 
72bc0cdeffSKazu Hirata [[maybe_unused]] static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
73d635b860SBoian Petkantchin                                                       const ShardingOption &v) {
74d635b860SBoian Petkantchin   return stream << "{empty = " << v.empty << ", mesh" << v.mesh
75d635b860SBoian Petkantchin                 << ", shardingArray = " << v.shardingArray << "}";
76d635b860SBoian Petkantchin }
77d635b860SBoian Petkantchin 
78d635b860SBoian Petkantchin template <typename Stream, typename... Ts, size_t... Is>
79d635b860SBoian Petkantchin static Stream &printTuple(Stream &stream, std::tuple<Ts...> tuple,
80d635b860SBoian Petkantchin                           std::index_sequence<Is...>) {
81d635b860SBoian Petkantchin   static_assert(sizeof...(Is) == sizeof...(Ts),
82d635b860SBoian Petkantchin                 "Indices must have same number of elements as tuple types!");
83d635b860SBoian Petkantchin   static_assert(sizeof...(Ts) > 0, "Cannot insert empty tuple into stream.");
84d635b860SBoian Petkantchin 
85d635b860SBoian Petkantchin   stream << "{";
86d635b860SBoian Petkantchin   ((stream << std::get<Is>(tuple) << ", "), ...);
87d635b860SBoian Petkantchin   return stream << "}";
88d635b860SBoian Petkantchin }
89d635b860SBoian Petkantchin 
90d635b860SBoian Petkantchin template <typename... Ts>
91d635b860SBoian Petkantchin static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
92d635b860SBoian Petkantchin                                      const std::tuple<Ts...> &t) {
93d635b860SBoian Petkantchin   return printTuple(stream, t, std::index_sequence_for<Ts...>{});
94d635b860SBoian Petkantchin }
95d635b860SBoian Petkantchin 
96bc0cdeffSKazu Hirata [[maybe_unused]] static llvm::raw_ostream &
97bc0cdeffSKazu Hirata operator<<(llvm::raw_ostream &stream, ReshardingRquirementKind v) {
98d635b860SBoian Petkantchin   return stream << static_cast<int>(v);
99d635b860SBoian Petkantchin }
100d635b860SBoian Petkantchin 
101d635b860SBoian Petkantchin #endif // LLVM_DEBUG
102d635b860SBoian Petkantchin 
103b0d5b4d2SChengji Yao //===----------------------------------------------------------------------===//
104b0d5b4d2SChengji Yao // Utilities
105b0d5b4d2SChengji Yao //===----------------------------------------------------------------------===//
106b0d5b4d2SChengji Yao 
107b0d5b4d2SChengji Yao // This method retrieves all potential sharding attributes, prioritizing
108b0d5b4d2SChengji Yao // specific shardings. For example, mustShardings = [shard0, None] and
109b0d5b4d2SChengji Yao // optionalShardings = [None, shard1], the result will be [[shard0, shard1],
110b0d5b4d2SChengji Yao // [shard0, None]]
111*baabcb28SFrank Schlimbach static SmallVector<std::vector<MeshSharding>>
112*baabcb28SFrank Schlimbach getOrderedPossibleShardingAttrs(ArrayRef<MeshSharding> mustShardings,
113*baabcb28SFrank Schlimbach                                 ArrayRef<MeshSharding> optionalShardings) {
114*baabcb28SFrank Schlimbach   SmallVector<std::vector<MeshSharding>> allShardingAttrs;
115*baabcb28SFrank Schlimbach   std::vector<MeshSharding> curShardingAttrs;
116b0d5b4d2SChengji Yao 
117b0d5b4d2SChengji Yao   std::function<void(size_t)> dfsCreateShardingAttrs = [&](size_t i) {
118b0d5b4d2SChengji Yao     if (i == mustShardings.size()) {
119*baabcb28SFrank Schlimbach       allShardingAttrs.push_back(std::vector<MeshSharding>(curShardingAttrs));
120b0d5b4d2SChengji Yao       return;
121b0d5b4d2SChengji Yao     }
122b0d5b4d2SChengji Yao 
123b0d5b4d2SChengji Yao     if (mustShardings[i]) {
124b0d5b4d2SChengji Yao       curShardingAttrs.push_back(mustShardings[i]);
125b0d5b4d2SChengji Yao       dfsCreateShardingAttrs(i + 1);
126b0d5b4d2SChengji Yao       curShardingAttrs.pop_back();
127b0d5b4d2SChengji Yao       return;
128b0d5b4d2SChengji Yao     }
129b0d5b4d2SChengji Yao 
130b0d5b4d2SChengji Yao     if (optionalShardings[i]) {
131b0d5b4d2SChengji Yao       curShardingAttrs.push_back(optionalShardings[i]);
132b0d5b4d2SChengji Yao       dfsCreateShardingAttrs(i + 1);
133b0d5b4d2SChengji Yao       curShardingAttrs.pop_back();
134*baabcb28SFrank Schlimbach       curShardingAttrs.push_back({});
135b0d5b4d2SChengji Yao       dfsCreateShardingAttrs(i + 1);
136b0d5b4d2SChengji Yao       curShardingAttrs.pop_back();
137b0d5b4d2SChengji Yao       return;
138b0d5b4d2SChengji Yao     }
139b0d5b4d2SChengji Yao 
140*baabcb28SFrank Schlimbach     curShardingAttrs.push_back({});
141b0d5b4d2SChengji Yao     dfsCreateShardingAttrs(i + 1);
142b0d5b4d2SChengji Yao     curShardingAttrs.pop_back();
143b0d5b4d2SChengji Yao   };
144b0d5b4d2SChengji Yao 
145b0d5b4d2SChengji Yao   dfsCreateShardingAttrs(0);
146b0d5b4d2SChengji Yao   return allShardingAttrs;
147b0d5b4d2SChengji Yao }
148b0d5b4d2SChengji Yao 
149d635b860SBoian Petkantchin // The order of preference is form highest to lowest:
150d635b860SBoian Petkantchin // 1. No resharding is required (all existing annotations are compatible).
151d635b860SBoian Petkantchin // 2. No resharding for operands/results that have annotation specifically
152d635b860SBoian Petkantchin //   targeting this operation. This means
153d635b860SBoian Petkantchin //   * operands that are the result of `mesh.shard` ops marked with
154d635b860SBoian Petkantchin //     `annotate_for_users`.
155d635b860SBoian Petkantchin //   * results that are annotated with `mesh.shard` ops without
156d635b860SBoian Petkantchin //     `annotate_for_users`.
157d635b860SBoian Petkantchin // 3. All other cases. Resharding is required for operands/results with
158d635b860SBoian Petkantchin //   annotation targeting explicitly this operation.
159d635b860SBoian Petkantchin ReshardingRquirementKind getReshardingRquirementKind(
160*baabcb28SFrank Schlimbach     Operation *op, const std::vector<MeshSharding> &operandAndResultShardings) {
161d635b860SBoian Petkantchin   ReshardingRquirementKind res = ReshardingRquirementKind::NO_RESHARDING;
162d635b860SBoian Petkantchin 
163d635b860SBoian Petkantchin   size_t operandsCount = op->getOperands().size();
164d635b860SBoian Petkantchin   auto operandShardings =
165d635b860SBoian Petkantchin       llvm::make_range(operandAndResultShardings.begin(),
166d635b860SBoian Petkantchin                        operandAndResultShardings.begin() + operandsCount);
167d635b860SBoian Petkantchin   auto resultShardings =
168d635b860SBoian Petkantchin       llvm::make_range(operandAndResultShardings.begin() + operandsCount,
169d635b860SBoian Petkantchin                        operandAndResultShardings.end());
170d635b860SBoian Petkantchin 
171d635b860SBoian Petkantchin   for (auto [operand, sharding] :
172d635b860SBoian Petkantchin        llvm::zip_equal(op->getOperands(), operandShardings)) {
173d635b860SBoian Petkantchin     ShardOp shardOp = llvm::dyn_cast_or_null<ShardOp>(operand.getDefiningOp());
174d635b860SBoian Petkantchin     if (!shardOp) {
175d635b860SBoian Petkantchin       continue;
176d635b860SBoian Petkantchin     }
177*baabcb28SFrank Schlimbach     bool needsResharding = sharding != shardOp.getSharding();
178d635b860SBoian Petkantchin     bool isExplicitAnnotationForThisOp = shardOp.getAnnotateForUsers();
179d635b860SBoian Petkantchin     if (needsResharding) {
180d635b860SBoian Petkantchin       if (isExplicitAnnotationForThisOp) {
181d635b860SBoian Petkantchin         // This is the worst case. No need to continue.
182d635b860SBoian Petkantchin         return ReshardingRquirementKind::RESHARDING_FOR_EXPLICIT_ANNOTATIONS;
183d635b860SBoian Petkantchin       }
184d635b860SBoian Petkantchin       res = ReshardingRquirementKind::NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS;
185d635b860SBoian Petkantchin     }
186d635b860SBoian Petkantchin   }
187d635b860SBoian Petkantchin 
188d635b860SBoian Petkantchin   for (auto [result, sharding] :
189d635b860SBoian Petkantchin        llvm::zip_equal(op->getResults(), resultShardings)) {
190d635b860SBoian Petkantchin     for (auto user : result.getUsers()) {
191d635b860SBoian Petkantchin       ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
192d635b860SBoian Petkantchin       if (!shardOp) {
193d635b860SBoian Petkantchin         continue;
194d635b860SBoian Petkantchin       }
195*baabcb28SFrank Schlimbach       bool needsResharding = sharding != shardOp.getSharding();
196d635b860SBoian Petkantchin       bool isExplicitAnnotationForThisOp = !shardOp.getAnnotateForUsers();
197d635b860SBoian Petkantchin       if (needsResharding) {
198d635b860SBoian Petkantchin         if (isExplicitAnnotationForThisOp) {
199d635b860SBoian Petkantchin           // This is the worst case. No need to continue.
200d635b860SBoian Petkantchin           return ReshardingRquirementKind::RESHARDING_FOR_EXPLICIT_ANNOTATIONS;
201d635b860SBoian Petkantchin         }
202d635b860SBoian Petkantchin         res = ReshardingRquirementKind::NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS;
203d635b860SBoian Petkantchin       }
204d635b860SBoian Petkantchin     }
205d635b860SBoian Petkantchin   }
206d635b860SBoian Petkantchin 
207d635b860SBoian Petkantchin   return res;
208d635b860SBoian Petkantchin }
209d635b860SBoian Petkantchin 
210d635b860SBoian Petkantchin // From all the operand and result sharding combinations,
211d635b860SBoian Petkantchin // return the one that is most desirable.
212d635b860SBoian Petkantchin // The order of preference is:
213d635b860SBoian Petkantchin // 1. No resharding with respect to existing sharding annotations.
214d635b860SBoian Petkantchin // 2. Resharding for values that have already annotations that do not target
215d635b860SBoian Petkantchin //    this op.
216d635b860SBoian Petkantchin // 3. Resharding of existing explicit sharding annotations for this op.
217d635b860SBoian Petkantchin static FailureOr<ShardingOption> selectShardingOption(
218d635b860SBoian Petkantchin     ShardingInterface shardingOp,
219*baabcb28SFrank Schlimbach     ArrayRef<std::vector<MeshSharding>> possibleOperandShardingAttrs,
220*baabcb28SFrank Schlimbach     ArrayRef<std::vector<MeshSharding>> possibleResultShardingAttrs) {
221d635b860SBoian Petkantchin   SmallVector<std::tuple<ShardingOption, ReshardingRquirementKind>>
222d635b860SBoian Petkantchin       shardingOptionsAndReshardingRequirements;
223d635b860SBoian Petkantchin 
224*baabcb28SFrank Schlimbach   for (ArrayRef<MeshSharding> resultShardings : possibleResultShardingAttrs) {
225*baabcb28SFrank Schlimbach     for (ArrayRef<MeshSharding> operandShardings :
226d635b860SBoian Petkantchin          possibleOperandShardingAttrs) {
227d635b860SBoian Petkantchin       FailureOr<ShardingOption> shardingOption =
228d635b860SBoian Petkantchin           shardingOp.getShardingOption(operandShardings, resultShardings);
229d635b860SBoian Petkantchin       if (failed(shardingOption) || shardingOption->empty) {
230d635b860SBoian Petkantchin         continue;
231d635b860SBoian Petkantchin       }
232d635b860SBoian Petkantchin       // These shardings may not be the same as those in operandShardings and
233d635b860SBoian Petkantchin       // resultShardings.
234d635b860SBoian Petkantchin       // They may be missing some annotations.
235d635b860SBoian Petkantchin       // Whatever is returned by getShardingAnnotations is exactly what the op
236d635b860SBoian Petkantchin       // needs.
237*baabcb28SFrank Schlimbach       FailureOr<std::vector<MeshSharding>> operandAndResultShardings =
238d635b860SBoian Petkantchin           shardingOp.getShardingAnnotations(*shardingOption);
239d635b860SBoian Petkantchin       if (failed(operandAndResultShardings)) {
240d635b860SBoian Petkantchin         return failure();
241d635b860SBoian Petkantchin       }
242d635b860SBoian Petkantchin 
243*baabcb28SFrank Schlimbach       // LLVM_DEBUG(DBGS() << "operandAndResultShardings = "
244*baabcb28SFrank Schlimbach       //                   << *operandAndResultShardings << "\n";);
245d635b860SBoian Petkantchin 
246d635b860SBoian Petkantchin       ReshardingRquirementKind reshardingRquirement =
247d635b860SBoian Petkantchin           getReshardingRquirementKind(shardingOp, *operandAndResultShardings);
248d635b860SBoian Petkantchin       if (reshardingRquirement == ReshardingRquirementKind::NO_RESHARDING) {
249d635b860SBoian Petkantchin         // This is the best case. No need to go on.
250d635b860SBoian Petkantchin         return *shardingOption;
251d635b860SBoian Petkantchin       }
252d635b860SBoian Petkantchin 
253d635b860SBoian Petkantchin       shardingOptionsAndReshardingRequirements.emplace_back(
254d635b860SBoian Petkantchin           std::move(*shardingOption), reshardingRquirement);
255d635b860SBoian Petkantchin     }
256d635b860SBoian Petkantchin   }
257d635b860SBoian Petkantchin 
258d635b860SBoian Petkantchin   if (shardingOptionsAndReshardingRequirements.empty()) {
259d635b860SBoian Petkantchin     return ShardingOption::makeEmpty();
260d635b860SBoian Petkantchin   }
261d635b860SBoian Petkantchin 
262d635b860SBoian Petkantchin   std::partial_sort(
263d635b860SBoian Petkantchin       shardingOptionsAndReshardingRequirements.begin(),
264d635b860SBoian Petkantchin       shardingOptionsAndReshardingRequirements.begin() + 1,
265d635b860SBoian Petkantchin       shardingOptionsAndReshardingRequirements.end(),
266d635b860SBoian Petkantchin       [](const std::tuple<ShardingOption, ReshardingRquirementKind> &a,
267d635b860SBoian Petkantchin          const std::tuple<ShardingOption, ReshardingRquirementKind> &b) {
268d635b860SBoian Petkantchin         return std::get<ReshardingRquirementKind>(a) <
269d635b860SBoian Petkantchin                std::get<ReshardingRquirementKind>(b);
270d635b860SBoian Petkantchin       });
271d635b860SBoian Petkantchin 
272d635b860SBoian Petkantchin   LLVM_DEBUG(DBGS() << "shardingOptionsAndReshardingRequirements = "
273d635b860SBoian Petkantchin                     << shardingOptionsAndReshardingRequirements << "\n";);
274d635b860SBoian Petkantchin 
275d635b860SBoian Petkantchin   return std::get<ShardingOption>(
276d635b860SBoian Petkantchin       shardingOptionsAndReshardingRequirements.front());
277d635b860SBoian Petkantchin }
278d635b860SBoian Petkantchin 
279b0d5b4d2SChengji Yao // For each operation that implements the ShardingInterface, infer the sharding
280b0d5b4d2SChengji Yao // option of the operation from its operands and/or results using the
281b0d5b4d2SChengji Yao // `getShardingOption` method. If the inferred sharding option is not empty, add
282b0d5b4d2SChengji Yao // a `mesh.shard` operation for all remaining operands and results that do not
283b0d5b4d2SChengji Yao // have sharding annotations.
284adbf21f1SBoian Petkantchin static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
285*baabcb28SFrank Schlimbach   if (op->hasTrait<OpTrait::IsTerminator>() ||
286*baabcb28SFrank Schlimbach       llvm::isa<mesh::ShardOp, mesh::ShardingOp>(op))
287b0d5b4d2SChengji Yao     return success();
288b0d5b4d2SChengji Yao 
289b0d5b4d2SChengji Yao   ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
290b0d5b4d2SChengji Yao   if (!shardingOp) {
291b0d5b4d2SChengji Yao     op->emitOpError() << "sharding interface is not implemented.";
292b0d5b4d2SChengji Yao     return failure();
293b0d5b4d2SChengji Yao   }
294b0d5b4d2SChengji Yao 
295*baabcb28SFrank Schlimbach   // collect MeshSharding from results
296*baabcb28SFrank Schlimbach   std::vector<MeshSharding> allowConflictsResultShardings;
297b0d5b4d2SChengji Yao   allowConflictsResultShardings.resize(op->getNumResults());
298*baabcb28SFrank Schlimbach   std::vector<MeshSharding> resultMustShardings;
299b0d5b4d2SChengji Yao   resultMustShardings.resize(op->getNumResults());
300b0d5b4d2SChengji Yao   for (OpResult result : op->getResults()) {
301*baabcb28SFrank Schlimbach     FailureOr<std::pair<bool, MeshSharding>> maybeShardAttr =
302*baabcb28SFrank Schlimbach         getMeshSharding(result);
303b0d5b4d2SChengji Yao     if (failed(maybeShardAttr))
304b0d5b4d2SChengji Yao       continue;
305b0d5b4d2SChengji Yao     if (!maybeShardAttr->first)
306b0d5b4d2SChengji Yao       resultMustShardings[result.getResultNumber()] = maybeShardAttr->second;
307b0d5b4d2SChengji Yao     else
308b0d5b4d2SChengji Yao       allowConflictsResultShardings[result.getResultNumber()] =
309b0d5b4d2SChengji Yao           maybeShardAttr->second;
310b0d5b4d2SChengji Yao   }
311b0d5b4d2SChengji Yao 
312*baabcb28SFrank Schlimbach   // collect MeshSharding from operands
313*baabcb28SFrank Schlimbach   std::vector<MeshSharding> allowConflictsOperandShardings;
314b0d5b4d2SChengji Yao   allowConflictsOperandShardings.resize(op->getNumOperands());
315*baabcb28SFrank Schlimbach   std::vector<MeshSharding> operandMustShardings;
316b0d5b4d2SChengji Yao   operandMustShardings.resize(op->getNumOperands());
317b0d5b4d2SChengji Yao   for (OpOperand &opOperand : op->getOpOperands()) {
318*baabcb28SFrank Schlimbach     FailureOr<std::pair<bool, MeshSharding>> maybeShardAttr =
319*baabcb28SFrank Schlimbach         getMeshSharding(opOperand);
320b0d5b4d2SChengji Yao     if (failed(maybeShardAttr))
321b0d5b4d2SChengji Yao       continue;
322b0d5b4d2SChengji Yao 
323b0d5b4d2SChengji Yao     if (maybeShardAttr->first)
324b0d5b4d2SChengji Yao       operandMustShardings[opOperand.getOperandNumber()] =
325b0d5b4d2SChengji Yao           maybeShardAttr->second;
326b0d5b4d2SChengji Yao     else
327b0d5b4d2SChengji Yao       allowConflictsOperandShardings[opOperand.getOperandNumber()] =
328b0d5b4d2SChengji Yao           maybeShardAttr->second;
329b0d5b4d2SChengji Yao   }
330b0d5b4d2SChengji Yao 
331b0d5b4d2SChengji Yao   // try to get the sharding option
332*baabcb28SFrank Schlimbach   SmallVector<std::vector<MeshSharding>> possibleOperandShardingAttrs =
333b0d5b4d2SChengji Yao       getOrderedPossibleShardingAttrs(operandMustShardings,
334b0d5b4d2SChengji Yao                                       allowConflictsOperandShardings);
335*baabcb28SFrank Schlimbach   SmallVector<std::vector<MeshSharding>> possibleResultShardingAttrs =
336b0d5b4d2SChengji Yao       getOrderedPossibleShardingAttrs(resultMustShardings,
337b0d5b4d2SChengji Yao                                       allowConflictsResultShardings);
338d635b860SBoian Petkantchin   FailureOr<ShardingOption> shardingOption = selectShardingOption(
339d635b860SBoian Petkantchin       shardingOp, possibleOperandShardingAttrs, possibleResultShardingAttrs);
340b0d5b4d2SChengji Yao 
341d635b860SBoian Petkantchin   if (failed(shardingOption)) {
342b0d5b4d2SChengji Yao     op->emitOpError() << "fail to get sharding option.";
343b0d5b4d2SChengji Yao     return failure();
344b0d5b4d2SChengji Yao   }
345d635b860SBoian Petkantchin 
346d635b860SBoian Petkantchin   LLVM_DEBUG(DBGS() << "Selected sharding option: " << *shardingOption << "\n");
347d635b860SBoian Petkantchin 
348b0d5b4d2SChengji Yao   // sharding info is empty, return immediately
349d635b860SBoian Petkantchin   if (shardingOption->empty)
350b0d5b4d2SChengji Yao     return success();
351b0d5b4d2SChengji Yao 
352d635b860SBoian Petkantchin   if (failed(shardingOp.addShardingAnnotations(builder, *shardingOption))) {
353b0d5b4d2SChengji Yao     op->emitOpError() << "fail to set sharding annotations.";
354b0d5b4d2SChengji Yao     return failure();
355b0d5b4d2SChengji Yao   }
356b0d5b4d2SChengji Yao   return success();
357b0d5b4d2SChengji Yao }
358b0d5b4d2SChengji Yao 
359b0d5b4d2SChengji Yao //===----------------------------------------------------------------------===//
360b0d5b4d2SChengji Yao // ShardingPropagation
361b0d5b4d2SChengji Yao //===----------------------------------------------------------------------===//
362b0d5b4d2SChengji Yao struct ShardingPropagation
363b0d5b4d2SChengji Yao     : public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
364b0d5b4d2SChengji Yao   void runOnOperation() override {
365abfac563SBoian Petkantchin     FunctionOpInterface funcOp = getOperation();
366b0d5b4d2SChengji Yao     MLIRContext *ctx = funcOp.getContext();
367abfac563SBoian Petkantchin     Region &region = funcOp.getFunctionBody();
368b0d5b4d2SChengji Yao     OpBuilder builder(ctx);
369b0d5b4d2SChengji Yao     if (!region.hasOneBlock()) {
370b0d5b4d2SChengji Yao       funcOp.emitOpError() << "only one block is supported!";
371b0d5b4d2SChengji Yao       signalPassFailure();
372b0d5b4d2SChengji Yao     }
373b0d5b4d2SChengji Yao     Block &block = region.front();
374b0d5b4d2SChengji Yao 
375b0d5b4d2SChengji Yao     LLVM_DEBUG(
376b0d5b4d2SChengji Yao         DBGS() << "print all the ops' iterator types and indexing maps in the "
377b0d5b4d2SChengji Yao                   "block.\n";
378b0d5b4d2SChengji Yao         for (Operation &op
379b0d5b4d2SChengji Yao              : block.getOperations()) {
380b0d5b4d2SChengji Yao           if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op))
381b0d5b4d2SChengji Yao             shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
382b0d5b4d2SChengji Yao         });
383b0d5b4d2SChengji Yao 
384b0d5b4d2SChengji Yao     // 1. propagate in reversed order
385b0d5b4d2SChengji Yao     for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
386b0d5b4d2SChengji Yao       if (failed(visitOp(&op, builder)))
387b0d5b4d2SChengji Yao         return signalPassFailure();
388b0d5b4d2SChengji Yao 
389b0d5b4d2SChengji Yao     LLVM_DEBUG(DBGS() << "After reversed order propagation:\n"
390b0d5b4d2SChengji Yao                       << funcOp << "\n");
391d635b860SBoian Petkantchin     LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
392b0d5b4d2SChengji Yao 
393b0d5b4d2SChengji Yao     // 2. propagate in original order
394b0d5b4d2SChengji Yao     for (Operation &op : llvm::make_early_inc_range(block))
395b0d5b4d2SChengji Yao       if (failed(visitOp(&op, builder)))
396b0d5b4d2SChengji Yao         return signalPassFailure();
397b0d5b4d2SChengji Yao   }
398b0d5b4d2SChengji Yao };
399