xref: /llvm-project/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp (revision baabcb28983edf8f20e39b89e2b1745412073b44)
1 //===- ShardingInterface.cpp -------------------------------------*- C++-*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
10 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
11 
12 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
13 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/IRMapping.h"
15 #include "mlir/Support/LLVM.h"
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallSet.h"
19 #include "llvm/Support/Debug.h"
20 
21 #include <utility>
22 
23 #define DEBUG_TYPE "sharding-interface"
24 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
25 
26 using namespace mlir;
27 using namespace mlir::mesh;
28 
29 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc"
30 
31 //===----------------------------------------------------------------------===//
32 // common util functions
33 //===----------------------------------------------------------------------===//
34 
35 static LogicalResult
36 checkOperandAffineExprRecursively(AffineExpr expr,
37                                   SmallVectorImpl<bool> &seenIds) {
38   switch (expr.getKind()) {
39   case AffineExprKind::Add: {
40     auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
41     AffineExpr lhs = binOpExpr.getLHS();
42     AffineExpr rhs = binOpExpr.getRHS();
43     if (failed(checkOperandAffineExprRecursively(lhs, seenIds)))
44       return failure();
45     if (failed(checkOperandAffineExprRecursively(rhs, seenIds)))
46       return failure();
47     return success();
48   }
49   case AffineExprKind::Mul: {
50     auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
51     AffineExpr lhs = binOpExpr.getLHS();
52     AffineExpr rhs = binOpExpr.getRHS();
53     AffineExpr dimExpr;
54     if (lhs.getKind() == AffineExprKind::DimId &&
55         rhs.getKind() == AffineExprKind::Constant) {
56       dimExpr = lhs;
57     } else if (rhs.getKind() == AffineExprKind::DimId &&
58                lhs.getKind() == AffineExprKind::Constant) {
59       dimExpr = rhs;
60     } else
61       return failure();
62     unsigned position = cast<AffineDimExpr>(dimExpr).getPosition();
63     if ((size_t)position >= seenIds.size() || seenIds[position])
64       return failure();
65     seenIds[position] = true;
66     return success();
67   }
68   case AffineExprKind::DimId: {
69     unsigned position = cast<AffineDimExpr>(expr).getPosition();
70     if ((size_t)position >= seenIds.size() || seenIds[position])
71       return failure();
72     seenIds[position] = true;
73     return success();
74   }
75   default:
76     return failure();
77   }
78 }
79 
80 static FailureOr<llvm::SmallSet<unsigned, 2>>
81 checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
82   SmallVector<bool> seenIds(numDims, false);
83   if (failed(checkOperandAffineExprRecursively(expr, seenIds)))
84     return failure();
85 
86   llvm::SmallSet<unsigned, 2> positions;
87   for (auto it : llvm::enumerate(seenIds)) {
88     if (it.value())
89       positions.insert((unsigned)it.index());
90   }
91   return positions;
92 }
93 
94 template <typename T>
95 SmallVector<MeshAxesAttr>
96 fromArrayOfVector(MLIRContext *ctxt, const SmallVector<SmallVector<T>> &vec) {
97   SmallVector<MeshAxesAttr> res;
98   for (const auto &v : vec) {
99     res.emplace_back(MeshAxesAttr::get(ctxt, v));
100   }
101   return res;
102 }
103 
104 //===----------------------------------------------------------------------===//
105 // mesh::getMeshSharding
106 //===----------------------------------------------------------------------===//
107 
// Return the sharding annotated on `result` through its `mesh.shard` users,
// together with a flag that is true when the annotation was made for the
// users (`annotate_for_users` set) and false when it was made for the
// defining op. Fails when no `mesh.shard` user exists, or when a
// def-annotation is not the result's unique use.
FailureOr<std::pair<bool, MeshSharding>>
mesh::getMeshSharding(OpResult result) {
  Value val = cast<Value>(result);
  // Is there a `mesh.shard` user annotating the definition (no
  // annotate_for_users)?
  bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
    auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
    if (!shardOp)
      return false;
    return !shardOp.getAnnotateForUsers();
  });

  if (anyShardedForDef) {
    // Expected to have exactly one use if it has a use of `mesh.shard` without
    // the unit attr annotate_for_users.
    if (!val.hasOneUse())
      return failure();
    auto shardOp = llvm::cast<mesh::ShardOp>(*val.getUsers().begin());
    return std::make_pair(false, MeshSharding(shardOp.getSharding()));
  }

  // Otherwise look for `mesh.shard` users that annotate for the users side.
  bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) {
    auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
    if (!shardOp)
      return false;
    return shardOp.getAnnotateForUsers();
  });
  if (anyShardedForUsers) {
    // Collect every `mesh.shard` user; they are all expected to carry the
    // same sharding (checked by the assert below in debug builds).
    SmallVector<ShardOp> shardOps;
    for (Operation *user : val.getUsers()) {
      ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
      if (shardOp)
        shardOps.push_back(shardOp);
    }
    MeshSharding shardForDef = shardOps[0].getSharding();
    for (size_t i = 1; i < shardOps.size(); ++i) {
      // TODO: Deduce a reasonable mesh sharding attr for def when they are
      // different
      assert(shardForDef == shardOps[i].getSharding() &&
             "only support all shard ops have the same mesh sharding attr");
    }
    return std::make_pair(true, shardForDef);
  }
  return failure();
}
151 
152 FailureOr<std::pair<bool, MeshSharding>>
153 mesh::getMeshSharding(OpOperand &opOperand) {
154   Value val = opOperand.get();
155   if (ShardOp shardOp = val.getDefiningOp<ShardOp>())
156     return std::make_pair(shardOp.getAnnotateForUsers(),
157                           MeshSharding(shardOp.getSharding()));
158 
159   return failure();
160 }
161 
162 //===----------------------------------------------------------------------===//
163 // ShardingInterface::verifyShardingInterfaceImpl
164 //===----------------------------------------------------------------------===//
165 
166 LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
167   Operation *op = getOperation();
168 
169   // check operands and results type
170   for (Type type : op->getOperandTypes())
171     if (!llvm::isa<RankedTensorType>(type))
172       return failure();
173   for (Type type : op->getResultTypes())
174     if (!llvm::isa<RankedTensorType>(type))
175       return failure();
176 
177   // check loop types
178   SmallVector<utils::IteratorType> loopTypes = getLoopIteratorTypes();
179   if (loopTypes.empty())
180     return failure();
181 
182   // check maps
183   SmallVector<AffineMap> maps = getIndexingMaps();
184   if (maps.empty())
185     return failure();
186   unsigned numOperands = op->getNumOperands();
187   unsigned numResults = op->getNumResults();
188   if (numOperands + numResults != maps.size())
189     return failure();
190 
191   for (OpResult result : op->getResults()) {
192     auto resultType = dyn_cast<RankedTensorType>(result.getType());
193     if (!resultType)
194       return failure();
195     AffineMap map = maps[numOperands + result.getResultNumber()];
196     if (!map.isProjectedPermutation()) {
197       return failure();
198     }
199   }
200 
201   return success();
202 }
203 
204 //===----------------------------------------------------------------------===//
205 // ShardingInterface::printLoopTypesAndIndexingMaps
206 //===----------------------------------------------------------------------===//
207 
208 void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
209   os << "print loop types and indexing maps for: \n";
210   getOperation()->print(os);
211   os << "\n";
212   os << "loop types: [";
213   for (utils::IteratorType type : getLoopIteratorTypes()) {
214     os << stringifyEnum(type) << " ";
215   }
216   os << "]\n";
217   os << "indexing maps: \n";
218   for (AffineMap map : getIndexingMaps())
219     os << map << "\n";
220   os << "\n";
221 }
222 
223 //===----------------------------------------------------------------------===//
224 // detail::defaultGetShardingOption
225 //===----------------------------------------------------------------------===//
226 
227 namespace {
228 
229 // Update the given `shardingOption` according to `meshAxes` and `loopIdx`
230 static LogicalResult fillShardingOption(Operation *op,
231                                         ShardingOption &shardingOption,
232                                         FlatSymbolRefAttr mesh,
233                                         ArrayRef<MeshAxis> meshAxes,
234                                         unsigned loopIdx) {
235   if ((shardingOption.mesh && mesh && shardingOption.mesh != mesh) ||
236       (!shardingOption.shardingArray[loopIdx].empty() &&
237        shardingOption.shardingArray[loopIdx] != meshAxes)) {
238     LLVM_DEBUG(DBGS() << "sharding option conflicts on loop iterator "
239                       << loopIdx << "\n");
240     return failure();
241   }
242   for (size_t i = 0; i < shardingOption.shardingArray.size(); ++i) {
243     if (i == loopIdx)
244       continue;
245 
246     for (MeshAxis axis : meshAxes) {
247       if (llvm::is_contained(shardingOption.shardingArray[i], axis)) {
248         LLVM_DEBUG(DBGS() << "sharding option conflicts because mesh axes "
249                           << axis << " duplicate");
250         return failure();
251       }
252     }
253   }
254   if (mesh)
255     shardingOption.mesh = mesh;
256   if (shardingOption.shardingArray[loopIdx].empty())
257     shardingOption.shardingArray[loopIdx].append(meshAxes.begin(),
258                                                  meshAxes.end());
259   return success();
260 }
261 
262 } // namespace
263 
// Default implementation of ShardingInterface::getShardingOption. Deduces a
// sharding option (mesh axes assigned to each loop iterator) for `op` from
// the shardings annotated on its results and operands: result annotations are
// consumed first (their maps are projected permutations, so each tensor dim
// maps to exactly one loop), then operand annotations, and finally any
// partial (reduction) axes are assigned to a reduction loop.
FailureOr<ShardingOption>
mesh::detail::defaultGetShardingOption(Operation *op,
                                       ArrayRef<MeshSharding> operandShardings,
                                       ArrayRef<MeshSharding> resultShardings) {
  ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
  ShardingOption shardingOption;

  if (failed(shardingOp.verifyShardingInterfaceImpl()))
    return op->emitOpError() << "invalid sharding interface implementation";
  SmallVector<utils::IteratorType> loopTypes =
      shardingOp.getLoopIteratorTypes();
  SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
  unsigned numOperands = op->getNumOperands();
  shardingOption.shardingArray.resize(loopTypes.size());
  llvm::SmallVector<MeshAxis> partialMeshAxes;
  // Loop indices whose sharding has been pinned down by some annotation.
  llvm::SmallSet<unsigned, 4> visitedLoopIndices;
  bool anyShardingInResultsOrOperands = false;

  // 1. Fill sharding option based on op results
  for (auto shardingIt : llvm::enumerate(resultShardings)) {
    MeshSharding shardAttr = shardingIt.value();
    if (!shardAttr)
      continue;
    AffineMap map = maps[numOperands + shardingIt.index()];
    anyShardingInResultsOrOperands = true;
    // Handle the split axes: calculate the corresponding loop index for each
    // split axes sub-array, and then store the sub-array to
    // shardingOption[index]
    for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
      AffineExpr expr = std::get<0>(it);
      ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
      // Result maps are projected permutations, so each expression is a plain
      // dimension.
      auto dim = cast<AffineDimExpr>(expr);
      unsigned index = dim.getPosition();
      visitedLoopIndices.insert(index);
      if (failed(fillShardingOption(op, shardingOption, shardAttr.getMeshAttr(),
                                    axes, index)))
        return failure();
    }

    // Handle the partial axes: at this stage, the exact loop index/indices
    // cannot be decided because there could be multiple reduction loops.
    ArrayRef<MeshAxis> partialAxes = shardAttr.getPartialAxes();
    if (!partialAxes.empty()) {
      if (!partialMeshAxes.empty())
        return op->emitOpError() << "at most one result with partial axes is "
                                    "supported at present";
      partialMeshAxes.append(partialAxes.begin(), partialAxes.end());
      // Add all the reduction loop indices to `visitedLoopIndices` if
      // `partialAxes` is not empty
      for (size_t loopIdx = 0; loopIdx < loopTypes.size(); ++loopIdx) {
        if (isReductionLoop(loopTypes[loopIdx]))
          visitedLoopIndices.insert(loopIdx);
      }
    }
  }

  // 2. Fill sharding option based on operands
  for (auto shardingIt : llvm::enumerate(operandShardings)) {
    MeshSharding shardAttr = shardingIt.value();
    if (!shardAttr)
      continue;

    anyShardingInResultsOrOperands = true;
    AffineMap map = maps[shardingIt.index()];
    unsigned numDims = map.getNumDims();

    // Handle the split axes. Partial axes don't need to be handled because they
    // only affect the defining op of the operand.
    //
    // TODO: Change to process the operands with single loop index first and
    // then the operands with multiple loop indices.
    for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
      AffineExpr expr = std::get<0>(it);
      ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
      FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
          checkOperandAffineExpr(expr, numDims);
      if (failed(loopIndices))
        return op->emitOpError()
               << "operand's affine expression is restricted to const_i * "
                  "dim_i + const_j + dim_j + ...";
      if (loopIndices->empty())
        continue;
      if (loopIndices->size() == 1) {
        unsigned loopIdx = *loopIndices->begin();
        visitedLoopIndices.insert(loopIdx);
        if (failed(fillShardingOption(op, shardingOption,
                                      shardAttr.getMeshAttr(), axes, loopIdx)))
          return failure();
      }
      // If multiple loop indices correspond to a dimension of an operand, it is
      // difficult to infer which loop indices are responsible for sharding.
      // Therefore, the exact loop index must be specified by others.
      if (loopIndices->size() > 1) {
        bool seenLoopIndices = false;
        for (unsigned loopIdx : *loopIndices) {
          if (visitedLoopIndices.contains(loopIdx)) {
            seenLoopIndices = true;
            break;
          }
        }
        if (!seenLoopIndices)
          return op->emitOpError()
                 << "the operand " << shardingIt.index()
                 << " has multiple loop indices in a dimension, but none of "
                    "them could be found in the exactly specified annotation "
                    "of op results or operands.";
      }
    }
  }

  // 3. Finalize sharding option
  // Assign the collected partial axes to a reduction loop if no reduction
  // loop carries mesh axes yet.
  if (!partialMeshAxes.empty()) {
    bool anyNonEmptyReductionLoop = llvm::any_of(
        llvm::enumerate(shardingOption.shardingArray), [&](auto it) {
          SmallVector<MeshAxis> &subArray = it.value();
          int64_t idx = it.index();
          return isReductionLoop(loopTypes[idx]) && !subArray.empty();
        });
    if (!anyNonEmptyReductionLoop) {
      bool filled = false;
      for (size_t idx = 0; idx < loopTypes.size(); ++idx) {
        if (isReductionLoop(loopTypes[idx])) {
          std::ignore = fillShardingOption(op, shardingOption, nullptr,
                                           partialMeshAxes, idx);
          filled = true;
          break;
        }
      }
      if (!filled)
        return op->emitOpError() << "no matched reduction loop found for the "
                                    "result's partial type";
    }
  }
  removeTrailingEmptySubArray(shardingOption.shardingArray);
  if (!anyShardingInResultsOrOperands)
    shardingOption.empty = true;
  return shardingOption;
}
402 
403 // Get the sharding attributed for the given result and sharding option.
404 MeshSharding getSharding(OpResult result, const ShardingOption &shardingOption,
405                          AffineMap map, ArrayRef<utils::IteratorType> loopTypes,
406                          ArrayRef<ReductionKind> reductionLoopKinds) {
407   auto resultType = cast<RankedTensorType>(result.getType());
408   SmallVector<SmallVector<MeshAxis>> splitAxes(resultType.getRank());
409   SmallVector<MeshAxis> partialAxes;
410 
411   // process the split axes
412   for (auto it : llvm::enumerate(map.getResults())) {
413     SmallVector<MeshAxis> tmp_axes;
414     AffineExpr expr = it.value();
415     // `expr` must be an `AffineDimExpr` because `map` is verified by
416     // isProjectedPermutation
417     auto dim = cast<AffineDimExpr>(expr);
418     unsigned loopIdx = dim.getPosition();
419     if (loopIdx < shardingOption.shardingArray.size())
420       splitAxes[it.index()].append(shardingOption.shardingArray[loopIdx]);
421   }
422 
423   // process the partial axes
424   // partialType will be ignored if partialAxes is empty
425   ReductionKind partialType = ReductionKind::Sum;
426   size_t reductionLoopKindsIdx = 0;
427   for (auto it : llvm::zip(loopTypes, shardingOption.shardingArray)) {
428     utils::IteratorType iType = std::get<0>(it);
429     if (isReductionLoop(iType)) {
430       ReductionKind curPartialType = reductionLoopKinds[reductionLoopKindsIdx];
431       ++reductionLoopKindsIdx;
432       if (!partialAxes.empty())
433         assert(partialType == curPartialType &&
434                "Only one reduction type is supported");
435       partialType = curPartialType;
436       const SmallVector<MeshAxis> &axis = std::get<1>(it);
437       partialAxes.append(axis);
438     }
439   }
440 
441   removeTrailingEmptySubArray(splitAxes);
442   return MeshSharding::get(shardingOption.mesh,
443                            fromArrayOfVector(result.getContext(), splitAxes),
444                            partialAxes, partialType);
445 }
446 
447 static FailureOr<MeshSharding> getSharding(OpOperand &opOperand,
448                                            const ShardingOption &shardingOption,
449                                            AffineMap map) {
450   Value operandValue = opOperand.get();
451   auto operandType = cast<RankedTensorType>(operandValue.getType());
452   SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank());
453   unsigned numDims = map.getNumDims();
454   for (auto it : llvm::enumerate(map.getResults())) {
455     int64_t idx = it.index();
456     AffineExpr expr = it.value();
457     FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
458         checkOperandAffineExpr(expr, numDims);
459     if (failed(loopIndices))
460       return failure();
461     SmallVector<unsigned> shardedLoopIndices;
462     for (unsigned loopIdx : *loopIndices) {
463       if ((size_t)loopIdx < shardingOption.shardingArray.size() &&
464           !shardingOption.shardingArray[loopIdx].empty())
465         shardedLoopIndices.push_back(loopIdx);
466     }
467     // mostly one sharded loop index is accepted
468     if (shardedLoopIndices.size() > 1)
469       return failure();
470     if (shardedLoopIndices.size() == 1) {
471       splitAxes[idx].append(
472           shardingOption.shardingArray[shardedLoopIndices[0]]);
473     }
474   }
475 
476   removeTrailingEmptySubArray(splitAxes);
477   return MeshSharding::get(
478       shardingOption.mesh,
479       fromArrayOfVector(opOperand.get().getContext(), splitAxes));
480 }
481 
482 FailureOr<std::vector<MeshSharding>>
483 mesh::detail::defaultGetShardingAnnotations(
484     Operation *op, const ShardingOption &shardingOption) {
485   std::vector<MeshSharding> res;
486 
487   ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
488   SmallVector<utils::IteratorType> loopTypes =
489       shardingOp.getLoopIteratorTypes();
490   SmallVector<ReductionKind> reductionKinds =
491       shardingOp.getReductionLoopIteratorKinds();
492   SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
493   unsigned numOperands = op->getNumOperands();
494 
495   for (OpOperand &opOperand : op->getOpOperands()) {
496     FailureOr<MeshSharding> shardingAttr = getSharding(
497         opOperand, shardingOption, maps[opOperand.getOperandNumber()]);
498     if (failed(shardingAttr))
499       return failure();
500     res.push_back(*shardingAttr);
501   }
502 
503   for (OpResult result : op->getResults()) {
504     res.push_back(getSharding(result, shardingOption,
505                               maps[numOperands + result.getResultNumber()],
506                               loopTypes, reductionKinds));
507   }
508 
509   return res;
510 }
511 
512 //===----------------------------------------------------------------------===//
513 // detail::defaultAddShardingAnnotations
514 //===----------------------------------------------------------------------===//
515 
516 // To add a `mesh.shard` op for the given result, based on the details provided
517 // in `shardingOption`, `map`, and `loopTypes`.
518 static LogicalResult addShardOp(OpBuilder &b, OpResult result,
519                                 const ShardingOption &shardingOption,
520                                 AffineMap map,
521                                 ArrayRef<utils::IteratorType> loopTypes,
522                                 ArrayRef<ReductionKind> reductionLoopKinds) {
523   MeshSharding sharding =
524       getSharding(result, shardingOption, map, loopTypes, reductionLoopKinds);
525   maybeInsertTargetShardingAnnotation(sharding, result, b);
526 
527   return success();
528 }
529 
530 // To add a `mesh.shard` op for the given operand, based on the details provided
531 // in `shardingOption`, `map`, and `loopTypes`.
532 static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
533                                 const ShardingOption &shardingOption,
534                                 AffineMap map) {
535 
536   FailureOr<MeshSharding> sharding =
537       getSharding(opOperand, shardingOption, map);
538   if (failed(sharding)) {
539     return failure();
540   }
541   OpBuilder::InsertionGuard guard(b);
542   maybeInsertSourceShardingAnnotation(sharding.value(), opOperand, b);
543 
544   return success();
545 }
546 
547 LogicalResult mesh::detail::defaultAddShardingAnnotations(
548     Operation *op, OpBuilder &b, const ShardingOption &shardingOption) {
549   assert(!shardingOption.empty && shardingOption.mesh);
550 
551   ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
552   SmallVector<utils::IteratorType> loopTypes =
553       shardingOp.getLoopIteratorTypes();
554   SmallVector<ReductionKind> reductionKinds =
555       shardingOp.getReductionLoopIteratorKinds();
556   SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
557   unsigned numOperands = op->getNumOperands();
558 
559   // 1. add mesh.shard ops for all op results
560   for (OpResult result : op->getResults()) {
561     if (failed(addShardOp(b, result, shardingOption,
562                           maps[numOperands + result.getResultNumber()],
563                           loopTypes, reductionKinds)))
564       return failure();
565   }
566 
567   // 2. add mesh.shard ops for all operands
568   for (OpOperand &opOperand : op->getOpOperands()) {
569     if (failed(addShardOp(b, opOperand, shardingOption,
570                           maps[opOperand.getOperandNumber()])))
571       return failure();
572   }
573 
574   return success();
575 }
576 
577 #ifndef NDEBUG
578 static bool
579 isValueCompatibleWithFullReplicationSharding(Value value,
580                                              MeshSharding sharding) {
581   if (isa<RankedTensorType>(value.getType())) {
582     return sharding && isFullReplication(sharding);
583   }
584 
585   return !sharding;
586 }
587 
588 template <typename ValueRange, typename MeshShardingRage>
589 static bool
590 areValuesCompatibleWithFullReplicationShardings(ValueRange &&values,
591                                                 MeshShardingRage &&shardings) {
592   if (std::size(values) != std::size(shardings)) {
593     return false;
594   }
595   return llvm::all_of(
596       llvm::zip_equal(std::forward<ValueRange>(values),
597                       std::forward<MeshShardingRage>(shardings)),
598       [](auto valueAndSharding) {
599         return isValueCompatibleWithFullReplicationSharding(
600             std::get<0>(valueAndSharding), std::get<1>(valueAndSharding));
601       });
602 }
603 #endif // NDEBUG
604 
// Spmdize an op whose operands and results are all fully replicated: every
// device computes the same values, so spmdization is a plain clone. The
// debug-only assertions verify that each ranked-tensor operand/result is
// annotated as full replication and every other value has no sharding.
void mesh::spmdizeFullyReplicatedOperation(
    Operation &op, ArrayRef<Value> spmdizedOperands,
    ArrayRef<MeshSharding> operandShardings,
    ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
    SymbolTableCollection &symbolTable, OpBuilder &builder) {
  assert(spmdizedOperands.size() == operandShardings.size());
  assert(areValuesCompatibleWithFullReplicationShardings(op.getOperands(),
                                                         operandShardings));
  assert(areValuesCompatibleWithFullReplicationShardings(op.getResults(),
                                                         resultShardings));
  // `clone` will populate the mapping of old to new results.
  builder.clone(op, spmdizationMap);
}
618 
619 static void updateMeshAxisAssignmentForLoopIterators(
620     ArrayRef<MeshAxis> meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr,
621     SmallVector<std::optional<SmallVector<MeshAxis>>>
622         &meshAxesAssignmentForLoopIterators) {
623   AffineDimExpr affineDimExpr = cast<AffineDimExpr>(indexingExpr);
624   unsigned loopIteratorIdx = affineDimExpr.getPosition();
625   if (meshAxesAssignmentForLoopIterators[loopIteratorIdx]) {
626     assert(llvm::equal(meshAxesAssignmentForTensorAxis,
627                        *meshAxesAssignmentForLoopIterators[loopIteratorIdx]));
628   } else {
629     meshAxesAssignmentForLoopIterators[loopIteratorIdx] =
630         llvm::to_vector(meshAxesAssignmentForTensorAxis);
631   }
632 }
633 
634 ShardingArray mesh::getMeshAxisAssignmentForLoopIterators(
635     ArrayRef<MeshSharding> operandShardings,
636     ArrayRef<MeshSharding> resultShardings,
637     ArrayRef<utils::IteratorType> loopIteratorTypes,
638     ArrayRef<AffineMap> indexingMaps) {
639   SmallVector<std::optional<SmallVector<MeshAxis>>>
640       meshAxisAssignmentForLoopIterators(loopIteratorTypes.size());
641   std::vector<MeshSharding> operatorAndResultShardings;
642   operatorAndResultShardings.reserve(operandShardings.size() +
643                                      resultShardings.size());
644   llvm::append_range(operatorAndResultShardings, operandShardings);
645   for (auto [sharding, affineMap] :
646        llvm::zip_equal(operatorAndResultShardings, indexingMaps)) {
647     if (!sharding) {
648       continue;
649     }
650     for (auto [meshAxesAssignmentForTensorAxis, indexingExpr] :
651          llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) {
652       updateMeshAxisAssignmentForLoopIterators(
653           meshAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr,
654           meshAxisAssignmentForLoopIterators);
655     }
656     // Missing trailing split axes means replication on those tensor dimensions.
657     for (unsigned i = sharding.getSplitAxes().size();
658          i < affineMap.getNumResults(); ++i) {
659       updateMeshAxisAssignmentForLoopIterators(
660           {}, affineMap.getResults()[i], meshAxisAssignmentForLoopIterators);
661     }
662   }
663 
664   ShardingArray res;
665   llvm::transform(meshAxisAssignmentForLoopIterators, std::back_inserter(res),
666                   [](std::optional<SmallVector<MeshAxis>> &axes) {
667                     if (!axes) {
668                       return SmallVector<MeshAxis>();
669                     };
670                     return std::move(*axes);
671                   });
672   return res;
673 }
674 
675 bool mesh::isAtLeastOneReductionIteratorSharded(
676     ArrayRef<utils::IteratorType> loopIteratorTypes,
677     ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
678   for (auto [loopIteratorType, meshAxisAssignment] :
679        llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
680     if (loopIteratorType == utils::IteratorType::reduction &&
681         !meshAxisAssignment.empty()) {
682       return true;
683     }
684   }
685   return false;
686 }
687 
688 SmallVector<MeshAxis> mesh::getReductionMeshAxes(
689     ArrayRef<utils::IteratorType> loopIteratorTypes,
690     ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
691   SmallVector<MeshAxis> meshAxes;
692   for (auto [loopIteratorType, meshAxisAssignment] :
693        llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
694     if (loopIteratorType == utils::IteratorType::reduction) {
695       llvm::append_range(meshAxes, meshAxisAssignment);
696     }
697   }
698   return meshAxes;
699 }
700 
// Spmdize an op that needs no communication: clone it and rewrite each result
// type to its sharded (per-device shard) counterpart.
void mesh::spmdizeTriviallyShardableOperation(
    Operation &op, ArrayRef<Value> spmdizedOperands,
    ArrayRef<MeshSharding> operandShardings,
    ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
    SymbolTableCollection &symbolTable, OpBuilder &builder) {
  // `clone` will populate the mapping of old to new results.
  Operation *newOp = builder.clone(op, spmdizationMap);
  // Set the result types to the sharded counterparts. `oldResult` is unused;
  // zipping over the original results only enforces (in debug builds) that
  // the clone has the same result count.
  for (auto [oldResult, newResult, sharding] :
       llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) {
    newResult.setType(
        shardType(newResult.getType(),
                  getMesh(&op, sharding.getMeshAttr(), symbolTable), sharding));
  }
}
716