xref: /llvm-project/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp (revision baabcb28983edf8f20e39b89e2b1745412073b44)
//===- ShardingInterface.cpp -------------------------------------*- C++-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8b0d5b4d2SChengji Yao 
9b0d5b4d2SChengji Yao #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
10adbf21f1SBoian Petkantchin #include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
11adbf21f1SBoian Petkantchin 
12b0d5b4d2SChengji Yao #include "mlir/Dialect/Mesh/IR/MeshOps.h"
13b0d5b4d2SChengji Yao #include "mlir/IR/AffineMap.h"
14adbf21f1SBoian Petkantchin #include "mlir/IR/IRMapping.h"
15b0d5b4d2SChengji Yao #include "mlir/Support/LLVM.h"
16ff2720d1SBoian Petkantchin #include "llvm/ADT/ArrayRef.h"
17adbf21f1SBoian Petkantchin #include "llvm/ADT/STLExtras.h"
18b0d5b4d2SChengji Yao #include "llvm/ADT/SmallSet.h"
19b0d5b4d2SChengji Yao #include "llvm/Support/Debug.h"
20b0d5b4d2SChengji Yao 
21b0d5b4d2SChengji Yao #include <utility>
22b0d5b4d2SChengji Yao 
23b0d5b4d2SChengji Yao #define DEBUG_TYPE "sharding-interface"
24b0d5b4d2SChengji Yao #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
25b0d5b4d2SChengji Yao 
26b0d5b4d2SChengji Yao using namespace mlir;
27b0d5b4d2SChengji Yao using namespace mlir::mesh;
28b0d5b4d2SChengji Yao 
29b0d5b4d2SChengji Yao #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc"
30b0d5b4d2SChengji Yao 
//===----------------------------------------------------------------------===//
// common util functions
//===----------------------------------------------------------------------===//
34b0d5b4d2SChengji Yao 
35b0d5b4d2SChengji Yao static LogicalResult
36b0d5b4d2SChengji Yao checkOperandAffineExprRecursively(AffineExpr expr,
37b0d5b4d2SChengji Yao                                   SmallVectorImpl<bool> &seenIds) {
38b0d5b4d2SChengji Yao   switch (expr.getKind()) {
39b0d5b4d2SChengji Yao   case AffineExprKind::Add: {
401609f1c2Slong.chen     auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
41b0d5b4d2SChengji Yao     AffineExpr lhs = binOpExpr.getLHS();
42b0d5b4d2SChengji Yao     AffineExpr rhs = binOpExpr.getRHS();
43b0d5b4d2SChengji Yao     if (failed(checkOperandAffineExprRecursively(lhs, seenIds)))
44b0d5b4d2SChengji Yao       return failure();
45b0d5b4d2SChengji Yao     if (failed(checkOperandAffineExprRecursively(rhs, seenIds)))
46b0d5b4d2SChengji Yao       return failure();
47b0d5b4d2SChengji Yao     return success();
48b0d5b4d2SChengji Yao   }
49b0d5b4d2SChengji Yao   case AffineExprKind::Mul: {
501609f1c2Slong.chen     auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
51b0d5b4d2SChengji Yao     AffineExpr lhs = binOpExpr.getLHS();
52b0d5b4d2SChengji Yao     AffineExpr rhs = binOpExpr.getRHS();
53b0d5b4d2SChengji Yao     AffineExpr dimExpr;
54b0d5b4d2SChengji Yao     if (lhs.getKind() == AffineExprKind::DimId &&
55b0d5b4d2SChengji Yao         rhs.getKind() == AffineExprKind::Constant) {
56b0d5b4d2SChengji Yao       dimExpr = lhs;
57b0d5b4d2SChengji Yao     } else if (rhs.getKind() == AffineExprKind::DimId &&
58b0d5b4d2SChengji Yao                lhs.getKind() == AffineExprKind::Constant) {
59b0d5b4d2SChengji Yao       dimExpr = rhs;
60b0d5b4d2SChengji Yao     } else
61b0d5b4d2SChengji Yao       return failure();
621609f1c2Slong.chen     unsigned position = cast<AffineDimExpr>(dimExpr).getPosition();
63b0d5b4d2SChengji Yao     if ((size_t)position >= seenIds.size() || seenIds[position])
64b0d5b4d2SChengji Yao       return failure();
65b0d5b4d2SChengji Yao     seenIds[position] = true;
66b0d5b4d2SChengji Yao     return success();
67b0d5b4d2SChengji Yao   }
68b0d5b4d2SChengji Yao   case AffineExprKind::DimId: {
691609f1c2Slong.chen     unsigned position = cast<AffineDimExpr>(expr).getPosition();
70b0d5b4d2SChengji Yao     if ((size_t)position >= seenIds.size() || seenIds[position])
71b0d5b4d2SChengji Yao       return failure();
72b0d5b4d2SChengji Yao     seenIds[position] = true;
73b0d5b4d2SChengji Yao     return success();
74b0d5b4d2SChengji Yao   }
75b0d5b4d2SChengji Yao   default:
76b0d5b4d2SChengji Yao     return failure();
77b0d5b4d2SChengji Yao   }
78b0d5b4d2SChengji Yao }
79b0d5b4d2SChengji Yao 
80b0d5b4d2SChengji Yao static FailureOr<llvm::SmallSet<unsigned, 2>>
81b0d5b4d2SChengji Yao checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
82b0d5b4d2SChengji Yao   SmallVector<bool> seenIds(numDims, false);
83b0d5b4d2SChengji Yao   if (failed(checkOperandAffineExprRecursively(expr, seenIds)))
84b0d5b4d2SChengji Yao     return failure();
85b0d5b4d2SChengji Yao 
86b0d5b4d2SChengji Yao   llvm::SmallSet<unsigned, 2> positions;
87b0d5b4d2SChengji Yao   for (auto it : llvm::enumerate(seenIds)) {
88b0d5b4d2SChengji Yao     if (it.value())
89b0d5b4d2SChengji Yao       positions.insert((unsigned)it.index());
90b0d5b4d2SChengji Yao   }
91b0d5b4d2SChengji Yao   return positions;
92b0d5b4d2SChengji Yao }
93b0d5b4d2SChengji Yao 
94*baabcb28SFrank Schlimbach template <typename T>
95*baabcb28SFrank Schlimbach SmallVector<MeshAxesAttr>
96*baabcb28SFrank Schlimbach fromArrayOfVector(MLIRContext *ctxt, const SmallVector<SmallVector<T>> &vec) {
97*baabcb28SFrank Schlimbach   SmallVector<MeshAxesAttr> res;
98*baabcb28SFrank Schlimbach   for (const auto &v : vec) {
99*baabcb28SFrank Schlimbach     res.emplace_back(MeshAxesAttr::get(ctxt, v));
100*baabcb28SFrank Schlimbach   }
101*baabcb28SFrank Schlimbach   return res;
102*baabcb28SFrank Schlimbach }
103*baabcb28SFrank Schlimbach 
//===----------------------------------------------------------------------===//
// mesh::getMeshSharding
//===----------------------------------------------------------------------===//
107b0d5b4d2SChengji Yao 
108*baabcb28SFrank Schlimbach FailureOr<std::pair<bool, MeshSharding>>
109*baabcb28SFrank Schlimbach mesh::getMeshSharding(OpResult result) {
110a5757c5bSChristian Sigg   Value val = cast<Value>(result);
111b0d5b4d2SChengji Yao   bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
112b0d5b4d2SChengji Yao     auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
113b0d5b4d2SChengji Yao     if (!shardOp)
114b0d5b4d2SChengji Yao       return false;
115b0d5b4d2SChengji Yao     return !shardOp.getAnnotateForUsers();
116b0d5b4d2SChengji Yao   });
117b0d5b4d2SChengji Yao 
118b0d5b4d2SChengji Yao   if (anyShardedForDef) {
119b0d5b4d2SChengji Yao     // expected to have exact one use if it has a use of `mesh.shard` without
120b0d5b4d2SChengji Yao     // unit attr annotate_for_users
121b0d5b4d2SChengji Yao     if (!val.hasOneUse())
122b0d5b4d2SChengji Yao       return failure();
123b0d5b4d2SChengji Yao     auto shardOp = llvm::cast<mesh::ShardOp>(*val.getUsers().begin());
124*baabcb28SFrank Schlimbach     return std::make_pair(false, MeshSharding(shardOp.getSharding()));
125b0d5b4d2SChengji Yao   }
126b0d5b4d2SChengji Yao 
127b0d5b4d2SChengji Yao   bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) {
128b0d5b4d2SChengji Yao     auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
129b0d5b4d2SChengji Yao     if (!shardOp)
130b0d5b4d2SChengji Yao       return false;
131b0d5b4d2SChengji Yao     return shardOp.getAnnotateForUsers();
132b0d5b4d2SChengji Yao   });
133b0d5b4d2SChengji Yao   if (anyShardedForUsers) {
134b0d5b4d2SChengji Yao     SmallVector<ShardOp> shardOps;
135b0d5b4d2SChengji Yao     for (Operation *user : val.getUsers()) {
136b0d5b4d2SChengji Yao       ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
137b0d5b4d2SChengji Yao       if (shardOp)
138b0d5b4d2SChengji Yao         shardOps.push_back(shardOp);
139b0d5b4d2SChengji Yao     }
140*baabcb28SFrank Schlimbach     MeshSharding shardForDef = shardOps[0].getSharding();
141b0d5b4d2SChengji Yao     for (size_t i = 1; i < shardOps.size(); ++i) {
142b0d5b4d2SChengji Yao       // TODO: Deduce a reasonable mesh sharding attr for def when they are
143b0d5b4d2SChengji Yao       // different
144*baabcb28SFrank Schlimbach       assert(shardForDef == shardOps[i].getSharding() &&
145b0d5b4d2SChengji Yao              "only support all shard ops have the same mesh sharding attr");
146b0d5b4d2SChengji Yao     }
147b0d5b4d2SChengji Yao     return std::make_pair(true, shardForDef);
148b0d5b4d2SChengji Yao   }
149b0d5b4d2SChengji Yao   return failure();
150b0d5b4d2SChengji Yao }
151b0d5b4d2SChengji Yao 
152*baabcb28SFrank Schlimbach FailureOr<std::pair<bool, MeshSharding>>
153*baabcb28SFrank Schlimbach mesh::getMeshSharding(OpOperand &opOperand) {
154b0d5b4d2SChengji Yao   Value val = opOperand.get();
155b0d5b4d2SChengji Yao   if (ShardOp shardOp = val.getDefiningOp<ShardOp>())
156*baabcb28SFrank Schlimbach     return std::make_pair(shardOp.getAnnotateForUsers(),
157*baabcb28SFrank Schlimbach                           MeshSharding(shardOp.getSharding()));
158b0d5b4d2SChengji Yao 
159b0d5b4d2SChengji Yao   return failure();
160b0d5b4d2SChengji Yao }
161b0d5b4d2SChengji Yao 
//===----------------------------------------------------------------------===//
// ShardingInterface::verifyShardingInterfaceImpl
//===----------------------------------------------------------------------===//
165b0d5b4d2SChengji Yao 
166b0d5b4d2SChengji Yao LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
167b0d5b4d2SChengji Yao   Operation *op = getOperation();
168b0d5b4d2SChengji Yao 
169b0d5b4d2SChengji Yao   // check operands and results type
170b0d5b4d2SChengji Yao   for (Type type : op->getOperandTypes())
171b0d5b4d2SChengji Yao     if (!llvm::isa<RankedTensorType>(type))
172b0d5b4d2SChengji Yao       return failure();
173b0d5b4d2SChengji Yao   for (Type type : op->getResultTypes())
174b0d5b4d2SChengji Yao     if (!llvm::isa<RankedTensorType>(type))
175b0d5b4d2SChengji Yao       return failure();
176b0d5b4d2SChengji Yao 
177b0d5b4d2SChengji Yao   // check loop types
178ff2720d1SBoian Petkantchin   SmallVector<utils::IteratorType> loopTypes = getLoopIteratorTypes();
1793b6462c5SAdrian Kuegel   if (loopTypes.empty())
180b0d5b4d2SChengji Yao     return failure();
181b0d5b4d2SChengji Yao 
182b0d5b4d2SChengji Yao   // check maps
183b0d5b4d2SChengji Yao   SmallVector<AffineMap> maps = getIndexingMaps();
1843b6462c5SAdrian Kuegel   if (maps.empty())
185b0d5b4d2SChengji Yao     return failure();
186b0d5b4d2SChengji Yao   unsigned numOperands = op->getNumOperands();
187b0d5b4d2SChengji Yao   unsigned numResults = op->getNumResults();
188b0d5b4d2SChengji Yao   if (numOperands + numResults != maps.size())
189b0d5b4d2SChengji Yao     return failure();
190b0d5b4d2SChengji Yao 
191b0d5b4d2SChengji Yao   for (OpResult result : op->getResults()) {
192a5757c5bSChristian Sigg     auto resultType = dyn_cast<RankedTensorType>(result.getType());
193b0d5b4d2SChengji Yao     if (!resultType)
194b0d5b4d2SChengji Yao       return failure();
195b0d5b4d2SChengji Yao     AffineMap map = maps[numOperands + result.getResultNumber()];
196b0d5b4d2SChengji Yao     if (!map.isProjectedPermutation()) {
197b0d5b4d2SChengji Yao       return failure();
198b0d5b4d2SChengji Yao     }
199b0d5b4d2SChengji Yao   }
200b0d5b4d2SChengji Yao 
201b0d5b4d2SChengji Yao   return success();
202b0d5b4d2SChengji Yao }
203b0d5b4d2SChengji Yao 
//===----------------------------------------------------------------------===//
// ShardingInterface::printLoopTypesAndIndexingMaps
//===----------------------------------------------------------------------===//
207b0d5b4d2SChengji Yao 
208b0d5b4d2SChengji Yao void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
209b0d5b4d2SChengji Yao   os << "print loop types and indexing maps for: \n";
210b0d5b4d2SChengji Yao   getOperation()->print(os);
211b0d5b4d2SChengji Yao   os << "\n";
212b0d5b4d2SChengji Yao   os << "loop types: [";
213ff2720d1SBoian Petkantchin   for (utils::IteratorType type : getLoopIteratorTypes()) {
214b0d5b4d2SChengji Yao     os << stringifyEnum(type) << " ";
215b0d5b4d2SChengji Yao   }
216b0d5b4d2SChengji Yao   os << "]\n";
217b0d5b4d2SChengji Yao   os << "indexing maps: \n";
218b0d5b4d2SChengji Yao   for (AffineMap map : getIndexingMaps())
219b0d5b4d2SChengji Yao     os << map << "\n";
220b0d5b4d2SChengji Yao   os << "\n";
221b0d5b4d2SChengji Yao }
222b0d5b4d2SChengji Yao 
//===----------------------------------------------------------------------===//
// detail::defaultGetShardingOption
//===----------------------------------------------------------------------===//
226b0d5b4d2SChengji Yao 
227b0d5b4d2SChengji Yao namespace {
228b0d5b4d2SChengji Yao 
229b0d5b4d2SChengji Yao // Update the given `shardingOption` according to `meshAxes` and `loopIdx`
230b0d5b4d2SChengji Yao static LogicalResult fillShardingOption(Operation *op,
231b0d5b4d2SChengji Yao                                         ShardingOption &shardingOption,
2329a8437f5SBoian Petkantchin                                         FlatSymbolRefAttr mesh,
2337a4c4975SBoian Petkantchin                                         ArrayRef<MeshAxis> meshAxes,
234b0d5b4d2SChengji Yao                                         unsigned loopIdx) {
2359a8437f5SBoian Petkantchin   if ((shardingOption.mesh && mesh && shardingOption.mesh != mesh) ||
236b0d5b4d2SChengji Yao       (!shardingOption.shardingArray[loopIdx].empty() &&
237b0d5b4d2SChengji Yao        shardingOption.shardingArray[loopIdx] != meshAxes)) {
238b0d5b4d2SChengji Yao     LLVM_DEBUG(DBGS() << "sharding option conflicts on loop iterator "
239b0d5b4d2SChengji Yao                       << loopIdx << "\n");
240b0d5b4d2SChengji Yao     return failure();
241b0d5b4d2SChengji Yao   }
242b0d5b4d2SChengji Yao   for (size_t i = 0; i < shardingOption.shardingArray.size(); ++i) {
243b0d5b4d2SChengji Yao     if (i == loopIdx)
244b0d5b4d2SChengji Yao       continue;
245b0d5b4d2SChengji Yao 
2467a4c4975SBoian Petkantchin     for (MeshAxis axis : meshAxes) {
24766555810SKazu Hirata       if (llvm::is_contained(shardingOption.shardingArray[i], axis)) {
248b0d5b4d2SChengji Yao         LLVM_DEBUG(DBGS() << "sharding option conflicts because mesh axes "
249b0d5b4d2SChengji Yao                           << axis << " duplicate");
250b0d5b4d2SChengji Yao         return failure();
251b0d5b4d2SChengji Yao       }
252b0d5b4d2SChengji Yao     }
253b0d5b4d2SChengji Yao   }
2549a8437f5SBoian Petkantchin   if (mesh)
2559a8437f5SBoian Petkantchin     shardingOption.mesh = mesh;
256b0d5b4d2SChengji Yao   if (shardingOption.shardingArray[loopIdx].empty())
257b0d5b4d2SChengji Yao     shardingOption.shardingArray[loopIdx].append(meshAxes.begin(),
258b0d5b4d2SChengji Yao                                                  meshAxes.end());
259b0d5b4d2SChengji Yao   return success();
260b0d5b4d2SChengji Yao }
261b0d5b4d2SChengji Yao 
262b0d5b4d2SChengji Yao } // namespace
263b0d5b4d2SChengji Yao 
264*baabcb28SFrank Schlimbach FailureOr<ShardingOption>
265*baabcb28SFrank Schlimbach mesh::detail::defaultGetShardingOption(Operation *op,
266*baabcb28SFrank Schlimbach                                        ArrayRef<MeshSharding> operandShardings,
267*baabcb28SFrank Schlimbach                                        ArrayRef<MeshSharding> resultShardings) {
268b0d5b4d2SChengji Yao   ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
269b0d5b4d2SChengji Yao   ShardingOption shardingOption;
270b0d5b4d2SChengji Yao 
271b0d5b4d2SChengji Yao   if (failed(shardingOp.verifyShardingInterfaceImpl()))
272b0d5b4d2SChengji Yao     return op->emitOpError() << "invalid sharding interface implementation";
273ff2720d1SBoian Petkantchin   SmallVector<utils::IteratorType> loopTypes =
274ff2720d1SBoian Petkantchin       shardingOp.getLoopIteratorTypes();
275b0d5b4d2SChengji Yao   SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
276b0d5b4d2SChengji Yao   unsigned numOperands = op->getNumOperands();
277b0d5b4d2SChengji Yao   shardingOption.shardingArray.resize(loopTypes.size());
2787a4c4975SBoian Petkantchin   llvm::SmallVector<MeshAxis> partialMeshAxes;
279b0d5b4d2SChengji Yao   llvm::SmallSet<unsigned, 4> visitedLoopIndices;
280b0d5b4d2SChengji Yao   bool anyShardingInResultsOrOperands = false;
281b0d5b4d2SChengji Yao 
282b0d5b4d2SChengji Yao   // 1. Fill sharding option based on op results
283b0d5b4d2SChengji Yao   for (auto shardingIt : llvm::enumerate(resultShardings)) {
284*baabcb28SFrank Schlimbach     MeshSharding shardAttr = shardingIt.value();
285b0d5b4d2SChengji Yao     if (!shardAttr)
286b0d5b4d2SChengji Yao       continue;
287b0d5b4d2SChengji Yao     AffineMap map = maps[numOperands + shardingIt.index()];
288b0d5b4d2SChengji Yao     anyShardingInResultsOrOperands = true;
289b0d5b4d2SChengji Yao     // Handle the split axes: calculate the corresponding loop index for each
290b0d5b4d2SChengji Yao     // split axes sub-array, and then store the sub-array to
291b0d5b4d2SChengji Yao     // shardingOption[index]
292b0d5b4d2SChengji Yao     for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
293b0d5b4d2SChengji Yao       AffineExpr expr = std::get<0>(it);
2947a4c4975SBoian Petkantchin       ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
2951609f1c2Slong.chen       auto dim = cast<AffineDimExpr>(expr);
296b0d5b4d2SChengji Yao       unsigned index = dim.getPosition();
297b0d5b4d2SChengji Yao       visitedLoopIndices.insert(index);
298*baabcb28SFrank Schlimbach       if (failed(fillShardingOption(op, shardingOption, shardAttr.getMeshAttr(),
299b0d5b4d2SChengji Yao                                     axes, index)))
300b0d5b4d2SChengji Yao         return failure();
301b0d5b4d2SChengji Yao     }
302b0d5b4d2SChengji Yao 
303b0d5b4d2SChengji Yao     // Handle the partial axes: at this stage, the exact loop index/indices
304b0d5b4d2SChengji Yao     // cannot be decided because there could be multiple reduction loops.
3057a4c4975SBoian Petkantchin     ArrayRef<MeshAxis> partialAxes = shardAttr.getPartialAxes();
306b0d5b4d2SChengji Yao     if (!partialAxes.empty()) {
307b0d5b4d2SChengji Yao       if (!partialMeshAxes.empty())
308b0d5b4d2SChengji Yao         return op->emitOpError() << "at most one result with partial axes is "
309b0d5b4d2SChengji Yao                                     "supported at present";
310b0d5b4d2SChengji Yao       partialMeshAxes.append(partialAxes.begin(), partialAxes.end());
311b0d5b4d2SChengji Yao       // Add all the reduction loop indices to `visitedLoopIndices` if
312b0d5b4d2SChengji Yao       // `partialAxes` is not empty
313b0d5b4d2SChengji Yao       for (size_t loopIdx = 0; loopIdx < loopTypes.size(); ++loopIdx) {
314b0d5b4d2SChengji Yao         if (isReductionLoop(loopTypes[loopIdx]))
315b0d5b4d2SChengji Yao           visitedLoopIndices.insert(loopIdx);
316b0d5b4d2SChengji Yao       }
317b0d5b4d2SChengji Yao     }
318b0d5b4d2SChengji Yao   }
319b0d5b4d2SChengji Yao 
320b0d5b4d2SChengji Yao   // 2. Fill sharding option based on operands
321b0d5b4d2SChengji Yao   for (auto shardingIt : llvm::enumerate(operandShardings)) {
322*baabcb28SFrank Schlimbach     MeshSharding shardAttr = shardingIt.value();
323b0d5b4d2SChengji Yao     if (!shardAttr)
324b0d5b4d2SChengji Yao       continue;
325b0d5b4d2SChengji Yao 
326b0d5b4d2SChengji Yao     anyShardingInResultsOrOperands = true;
327b0d5b4d2SChengji Yao     AffineMap map = maps[shardingIt.index()];
328b0d5b4d2SChengji Yao     unsigned numDims = map.getNumDims();
329b0d5b4d2SChengji Yao 
330b0d5b4d2SChengji Yao     // Handle the split axes. Partial axes don't need to be handled because they
331b0d5b4d2SChengji Yao     // only affect the defining op of the operand.
332b0d5b4d2SChengji Yao     //
333b0d5b4d2SChengji Yao     // TODO: Change to process the operands with single loop index first and
334b0d5b4d2SChengji Yao     // then the operands with multiple loop indices.
335b0d5b4d2SChengji Yao     for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
336b0d5b4d2SChengji Yao       AffineExpr expr = std::get<0>(it);
3377a4c4975SBoian Petkantchin       ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
338b0d5b4d2SChengji Yao       FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
339b0d5b4d2SChengji Yao           checkOperandAffineExpr(expr, numDims);
340b0d5b4d2SChengji Yao       if (failed(loopIndices))
341b0d5b4d2SChengji Yao         return op->emitOpError()
342b0d5b4d2SChengji Yao                << "operand's affine expression is restricted to const_i * "
343b0d5b4d2SChengji Yao                   "dim_i + const_j + dim_j + ...";
344b0d5b4d2SChengji Yao       if (loopIndices->empty())
345b0d5b4d2SChengji Yao         continue;
346b0d5b4d2SChengji Yao       if (loopIndices->size() == 1) {
347b0d5b4d2SChengji Yao         unsigned loopIdx = *loopIndices->begin();
348b0d5b4d2SChengji Yao         visitedLoopIndices.insert(loopIdx);
349*baabcb28SFrank Schlimbach         if (failed(fillShardingOption(op, shardingOption,
350*baabcb28SFrank Schlimbach                                       shardAttr.getMeshAttr(), axes, loopIdx)))
351b0d5b4d2SChengji Yao           return failure();
352b0d5b4d2SChengji Yao       }
353b0d5b4d2SChengji Yao       // If multiple loop indices correspond to a dimension of an operand, it is
354b0d5b4d2SChengji Yao       // difficult to infer which loop indices are responsible for sharding.
355b0d5b4d2SChengji Yao       // Therefore, the exact loop index must be specified by others.
356b0d5b4d2SChengji Yao       if (loopIndices->size() > 1) {
357b0d5b4d2SChengji Yao         bool seenLoopIndices = false;
358b0d5b4d2SChengji Yao         for (unsigned loopIdx : *loopIndices) {
359b0d5b4d2SChengji Yao           if (visitedLoopIndices.contains(loopIdx)) {
360b0d5b4d2SChengji Yao             seenLoopIndices = true;
361b0d5b4d2SChengji Yao             break;
362b0d5b4d2SChengji Yao           }
363b0d5b4d2SChengji Yao         }
364b0d5b4d2SChengji Yao         if (!seenLoopIndices)
365b0d5b4d2SChengji Yao           return op->emitOpError()
366b0d5b4d2SChengji Yao                  << "the operand " << shardingIt.index()
367b0d5b4d2SChengji Yao                  << " has multiple loop indices in a dimension, but none of "
368b0d5b4d2SChengji Yao                     "them could be found in the exactly specified annotation "
369b0d5b4d2SChengji Yao                     "of op results or operands.";
370b0d5b4d2SChengji Yao       }
371b0d5b4d2SChengji Yao     }
372b0d5b4d2SChengji Yao   }
373b0d5b4d2SChengji Yao 
374b0d5b4d2SChengji Yao   // 3. Finalize sharding option
375b0d5b4d2SChengji Yao   if (!partialMeshAxes.empty()) {
376b0d5b4d2SChengji Yao     bool anyNonEmptyReductionLoop = llvm::any_of(
377b0d5b4d2SChengji Yao         llvm::enumerate(shardingOption.shardingArray), [&](auto it) {
3787a4c4975SBoian Petkantchin           SmallVector<MeshAxis> &subArray = it.value();
379b0d5b4d2SChengji Yao           int64_t idx = it.index();
380b0d5b4d2SChengji Yao           return isReductionLoop(loopTypes[idx]) && !subArray.empty();
381b0d5b4d2SChengji Yao         });
382b0d5b4d2SChengji Yao     if (!anyNonEmptyReductionLoop) {
383b0d5b4d2SChengji Yao       bool filled = false;
384b0d5b4d2SChengji Yao       for (size_t idx = 0; idx < loopTypes.size(); ++idx) {
385ff2720d1SBoian Petkantchin         if (isReductionLoop(loopTypes[idx])) {
386b0d5b4d2SChengji Yao           std::ignore = fillShardingOption(op, shardingOption, nullptr,
387b0d5b4d2SChengji Yao                                            partialMeshAxes, idx);
388b0d5b4d2SChengji Yao           filled = true;
389b0d5b4d2SChengji Yao           break;
390b0d5b4d2SChengji Yao         }
391b0d5b4d2SChengji Yao       }
392b0d5b4d2SChengji Yao       if (!filled)
393b0d5b4d2SChengji Yao         return op->emitOpError() << "no matched reduction loop found for the "
394b0d5b4d2SChengji Yao                                     "result's partial type";
395b0d5b4d2SChengji Yao     }
396b0d5b4d2SChengji Yao   }
397b0d5b4d2SChengji Yao   removeTrailingEmptySubArray(shardingOption.shardingArray);
398b0d5b4d2SChengji Yao   if (!anyShardingInResultsOrOperands)
399b0d5b4d2SChengji Yao     shardingOption.empty = true;
400b0d5b4d2SChengji Yao   return shardingOption;
401b0d5b4d2SChengji Yao }
402b0d5b4d2SChengji Yao 
403d635b860SBoian Petkantchin // Get the sharding attributed for the given result and sharding option.
404*baabcb28SFrank Schlimbach MeshSharding getSharding(OpResult result, const ShardingOption &shardingOption,
405d635b860SBoian Petkantchin                          AffineMap map, ArrayRef<utils::IteratorType> loopTypes,
406ff2720d1SBoian Petkantchin                          ArrayRef<ReductionKind> reductionLoopKinds) {
407a5757c5bSChristian Sigg   auto resultType = cast<RankedTensorType>(result.getType());
4087a4c4975SBoian Petkantchin   SmallVector<SmallVector<MeshAxis>> splitAxes(resultType.getRank());
4097a4c4975SBoian Petkantchin   SmallVector<MeshAxis> partialAxes;
410b0d5b4d2SChengji Yao 
411b0d5b4d2SChengji Yao   // process the split axes
412b0d5b4d2SChengji Yao   for (auto it : llvm::enumerate(map.getResults())) {
413*baabcb28SFrank Schlimbach     SmallVector<MeshAxis> tmp_axes;
414b0d5b4d2SChengji Yao     AffineExpr expr = it.value();
415b0d5b4d2SChengji Yao     // `expr` must be an `AffineDimExpr` because `map` is verified by
416b0d5b4d2SChengji Yao     // isProjectedPermutation
4171609f1c2Slong.chen     auto dim = cast<AffineDimExpr>(expr);
418b0d5b4d2SChengji Yao     unsigned loopIdx = dim.getPosition();
419b0d5b4d2SChengji Yao     if (loopIdx < shardingOption.shardingArray.size())
420b0d5b4d2SChengji Yao       splitAxes[it.index()].append(shardingOption.shardingArray[loopIdx]);
421b0d5b4d2SChengji Yao   }
422b0d5b4d2SChengji Yao 
423b0d5b4d2SChengji Yao   // process the partial axes
4247d367bc9SChengji Yao   // partialType will be ignored if partialAxes is empty
425ff2720d1SBoian Petkantchin   ReductionKind partialType = ReductionKind::Sum;
426ff2720d1SBoian Petkantchin   size_t reductionLoopKindsIdx = 0;
427b0d5b4d2SChengji Yao   for (auto it : llvm::zip(loopTypes, shardingOption.shardingArray)) {
428ff2720d1SBoian Petkantchin     utils::IteratorType iType = std::get<0>(it);
429b0d5b4d2SChengji Yao     if (isReductionLoop(iType)) {
430ff2720d1SBoian Petkantchin       ReductionKind curPartialType = reductionLoopKinds[reductionLoopKindsIdx];
431ff2720d1SBoian Petkantchin       ++reductionLoopKindsIdx;
432b0d5b4d2SChengji Yao       if (!partialAxes.empty())
433b0d5b4d2SChengji Yao         assert(partialType == curPartialType &&
434b0d5b4d2SChengji Yao                "Only one reduction type is supported");
435b0d5b4d2SChengji Yao       partialType = curPartialType;
4367a4c4975SBoian Petkantchin       const SmallVector<MeshAxis> &axis = std::get<1>(it);
437b0d5b4d2SChengji Yao       partialAxes.append(axis);
438b0d5b4d2SChengji Yao     }
439b0d5b4d2SChengji Yao   }
440b0d5b4d2SChengji Yao 
441b0d5b4d2SChengji Yao   removeTrailingEmptySubArray(splitAxes);
442*baabcb28SFrank Schlimbach   return MeshSharding::get(shardingOption.mesh,
443*baabcb28SFrank Schlimbach                            fromArrayOfVector(result.getContext(), splitAxes),
444*baabcb28SFrank Schlimbach                            partialAxes, partialType);
445b0d5b4d2SChengji Yao }
446b0d5b4d2SChengji Yao 
447*baabcb28SFrank Schlimbach static FailureOr<MeshSharding> getSharding(OpOperand &opOperand,
448*baabcb28SFrank Schlimbach                                            const ShardingOption &shardingOption,
449ff2720d1SBoian Petkantchin                                            AffineMap map) {
450d635b860SBoian Petkantchin   Value operandValue = opOperand.get();
451d635b860SBoian Petkantchin   auto operandType = cast<RankedTensorType>(operandValue.getType());
4527a4c4975SBoian Petkantchin   SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank());
453b0d5b4d2SChengji Yao   unsigned numDims = map.getNumDims();
454b0d5b4d2SChengji Yao   for (auto it : llvm::enumerate(map.getResults())) {
455b0d5b4d2SChengji Yao     int64_t idx = it.index();
456b0d5b4d2SChengji Yao     AffineExpr expr = it.value();
457b0d5b4d2SChengji Yao     FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
458b0d5b4d2SChengji Yao         checkOperandAffineExpr(expr, numDims);
459b0d5b4d2SChengji Yao     if (failed(loopIndices))
460b0d5b4d2SChengji Yao       return failure();
461b0d5b4d2SChengji Yao     SmallVector<unsigned> shardedLoopIndices;
462b0d5b4d2SChengji Yao     for (unsigned loopIdx : *loopIndices) {
463b0d5b4d2SChengji Yao       if ((size_t)loopIdx < shardingOption.shardingArray.size() &&
464b0d5b4d2SChengji Yao           !shardingOption.shardingArray[loopIdx].empty())
465b0d5b4d2SChengji Yao         shardedLoopIndices.push_back(loopIdx);
466b0d5b4d2SChengji Yao     }
467b0d5b4d2SChengji Yao     // mostly one sharded loop index is accepted
468b0d5b4d2SChengji Yao     if (shardedLoopIndices.size() > 1)
469b0d5b4d2SChengji Yao       return failure();
470b0d5b4d2SChengji Yao     if (shardedLoopIndices.size() == 1) {
471b0d5b4d2SChengji Yao       splitAxes[idx].append(
472b0d5b4d2SChengji Yao           shardingOption.shardingArray[shardedLoopIndices[0]]);
473b0d5b4d2SChengji Yao     }
474b0d5b4d2SChengji Yao   }
475b0d5b4d2SChengji Yao 
476b0d5b4d2SChengji Yao   removeTrailingEmptySubArray(splitAxes);
477*baabcb28SFrank Schlimbach   return MeshSharding::get(
478*baabcb28SFrank Schlimbach       shardingOption.mesh,
479*baabcb28SFrank Schlimbach       fromArrayOfVector(opOperand.get().getContext(), splitAxes));
480d635b860SBoian Petkantchin }
481d635b860SBoian Petkantchin 
482*baabcb28SFrank Schlimbach FailureOr<std::vector<MeshSharding>>
483d635b860SBoian Petkantchin mesh::detail::defaultGetShardingAnnotations(
484d635b860SBoian Petkantchin     Operation *op, const ShardingOption &shardingOption) {
485*baabcb28SFrank Schlimbach   std::vector<MeshSharding> res;
486d635b860SBoian Petkantchin 
487d635b860SBoian Petkantchin   ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
488d635b860SBoian Petkantchin   SmallVector<utils::IteratorType> loopTypes =
489d635b860SBoian Petkantchin       shardingOp.getLoopIteratorTypes();
490d635b860SBoian Petkantchin   SmallVector<ReductionKind> reductionKinds =
491d635b860SBoian Petkantchin       shardingOp.getReductionLoopIteratorKinds();
492d635b860SBoian Petkantchin   SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
493d635b860SBoian Petkantchin   unsigned numOperands = op->getNumOperands();
494d635b860SBoian Petkantchin 
495d635b860SBoian Petkantchin   for (OpOperand &opOperand : op->getOpOperands()) {
496*baabcb28SFrank Schlimbach     FailureOr<MeshSharding> shardingAttr = getSharding(
497d635b860SBoian Petkantchin         opOperand, shardingOption, maps[opOperand.getOperandNumber()]);
498d635b860SBoian Petkantchin     if (failed(shardingAttr))
499d635b860SBoian Petkantchin       return failure();
500d635b860SBoian Petkantchin     res.push_back(*shardingAttr);
501d635b860SBoian Petkantchin   }
502d635b860SBoian Petkantchin 
503d635b860SBoian Petkantchin   for (OpResult result : op->getResults()) {
504*baabcb28SFrank Schlimbach     res.push_back(getSharding(result, shardingOption,
505*baabcb28SFrank Schlimbach                               maps[numOperands + result.getResultNumber()],
506d635b860SBoian Petkantchin                               loopTypes, reductionKinds));
507d635b860SBoian Petkantchin   }
508d635b860SBoian Petkantchin 
509d635b860SBoian Petkantchin   return res;
510d635b860SBoian Petkantchin }
511d635b860SBoian Petkantchin 
512d635b860SBoian Petkantchin //===----------------------------------------------------------------------===//
513d635b860SBoian Petkantchin // detail::defaultAddShardingAnnotations
514d635b860SBoian Petkantchin //===----------------------------------------------------------------------===//
515d635b860SBoian Petkantchin 
516d635b860SBoian Petkantchin // To add a `mesh.shard` op for the given result, based on the details provided
517d635b860SBoian Petkantchin // in `shardingOption`, `map`, and `loopTypes`.
518d635b860SBoian Petkantchin static LogicalResult addShardOp(OpBuilder &b, OpResult result,
519d635b860SBoian Petkantchin                                 const ShardingOption &shardingOption,
520d635b860SBoian Petkantchin                                 AffineMap map,
521d635b860SBoian Petkantchin                                 ArrayRef<utils::IteratorType> loopTypes,
522d635b860SBoian Petkantchin                                 ArrayRef<ReductionKind> reductionLoopKinds) {
523*baabcb28SFrank Schlimbach   MeshSharding sharding =
524*baabcb28SFrank Schlimbach       getSharding(result, shardingOption, map, loopTypes, reductionLoopKinds);
525*baabcb28SFrank Schlimbach   maybeInsertTargetShardingAnnotation(sharding, result, b);
526d635b860SBoian Petkantchin 
527d635b860SBoian Petkantchin   return success();
528d635b860SBoian Petkantchin }
529d635b860SBoian Petkantchin 
530d635b860SBoian Petkantchin // To add a `mesh.shard` op for the given operand, based on the details provided
531d635b860SBoian Petkantchin // in `shardingOption`, `map`, and `loopTypes`.
532d635b860SBoian Petkantchin static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
533d635b860SBoian Petkantchin                                 const ShardingOption &shardingOption,
534d635b860SBoian Petkantchin                                 AffineMap map) {
535d635b860SBoian Petkantchin 
536*baabcb28SFrank Schlimbach   FailureOr<MeshSharding> sharding =
537*baabcb28SFrank Schlimbach       getSharding(opOperand, shardingOption, map);
538*baabcb28SFrank Schlimbach   if (failed(sharding)) {
539d635b860SBoian Petkantchin     return failure();
540d635b860SBoian Petkantchin   }
541b0d5b4d2SChengji Yao   OpBuilder::InsertionGuard guard(b);
542*baabcb28SFrank Schlimbach   maybeInsertSourceShardingAnnotation(sharding.value(), opOperand, b);
543b0d5b4d2SChengji Yao 
544b0d5b4d2SChengji Yao   return success();
545b0d5b4d2SChengji Yao }
546b0d5b4d2SChengji Yao 
547b0d5b4d2SChengji Yao LogicalResult mesh::detail::defaultAddShardingAnnotations(
548b0d5b4d2SChengji Yao     Operation *op, OpBuilder &b, const ShardingOption &shardingOption) {
549d635b860SBoian Petkantchin   assert(!shardingOption.empty && shardingOption.mesh);
550d635b860SBoian Petkantchin 
551b0d5b4d2SChengji Yao   ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
552ff2720d1SBoian Petkantchin   SmallVector<utils::IteratorType> loopTypes =
553ff2720d1SBoian Petkantchin       shardingOp.getLoopIteratorTypes();
554ff2720d1SBoian Petkantchin   SmallVector<ReductionKind> reductionKinds =
555ff2720d1SBoian Petkantchin       shardingOp.getReductionLoopIteratorKinds();
556b0d5b4d2SChengji Yao   SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
557b0d5b4d2SChengji Yao   unsigned numOperands = op->getNumOperands();
558b0d5b4d2SChengji Yao 
559b0d5b4d2SChengji Yao   // 1. add mesh.shard ops for all op results
560b0d5b4d2SChengji Yao   for (OpResult result : op->getResults()) {
561b0d5b4d2SChengji Yao     if (failed(addShardOp(b, result, shardingOption,
562b0d5b4d2SChengji Yao                           maps[numOperands + result.getResultNumber()],
563ff2720d1SBoian Petkantchin                           loopTypes, reductionKinds)))
564b0d5b4d2SChengji Yao       return failure();
565b0d5b4d2SChengji Yao   }
566b0d5b4d2SChengji Yao 
567b0d5b4d2SChengji Yao   // 2. add mesh.shard ops for all operands
568b0d5b4d2SChengji Yao   for (OpOperand &opOperand : op->getOpOperands()) {
569b0d5b4d2SChengji Yao     if (failed(addShardOp(b, opOperand, shardingOption,
570ff2720d1SBoian Petkantchin                           maps[opOperand.getOperandNumber()])))
571b0d5b4d2SChengji Yao       return failure();
572b0d5b4d2SChengji Yao   }
573b0d5b4d2SChengji Yao 
574b0d5b4d2SChengji Yao   return success();
575b0d5b4d2SChengji Yao }
576adbf21f1SBoian Petkantchin 
57786fa21e9SJie Fu #ifndef NDEBUG
578adbf21f1SBoian Petkantchin static bool
579adbf21f1SBoian Petkantchin isValueCompatibleWithFullReplicationSharding(Value value,
580*baabcb28SFrank Schlimbach                                              MeshSharding sharding) {
581a5757c5bSChristian Sigg   if (isa<RankedTensorType>(value.getType())) {
582adbf21f1SBoian Petkantchin     return sharding && isFullReplication(sharding);
583adbf21f1SBoian Petkantchin   }
584adbf21f1SBoian Petkantchin 
585adbf21f1SBoian Petkantchin   return !sharding;
586adbf21f1SBoian Petkantchin }
587adbf21f1SBoian Petkantchin 
588*baabcb28SFrank Schlimbach template <typename ValueRange, typename MeshShardingRage>
589*baabcb28SFrank Schlimbach static bool
590*baabcb28SFrank Schlimbach areValuesCompatibleWithFullReplicationShardings(ValueRange &&values,
591*baabcb28SFrank Schlimbach                                                 MeshShardingRage &&shardings) {
592adbf21f1SBoian Petkantchin   if (std::size(values) != std::size(shardings)) {
593adbf21f1SBoian Petkantchin     return false;
594adbf21f1SBoian Petkantchin   }
595*baabcb28SFrank Schlimbach   return llvm::all_of(
596*baabcb28SFrank Schlimbach       llvm::zip_equal(std::forward<ValueRange>(values),
597*baabcb28SFrank Schlimbach                       std::forward<MeshShardingRage>(shardings)),
598adbf21f1SBoian Petkantchin       [](auto valueAndSharding) {
599adbf21f1SBoian Petkantchin         return isValueCompatibleWithFullReplicationSharding(
600*baabcb28SFrank Schlimbach             std::get<0>(valueAndSharding), std::get<1>(valueAndSharding));
601adbf21f1SBoian Petkantchin       });
602adbf21f1SBoian Petkantchin }
60386fa21e9SJie Fu #endif // NDEBUG
604adbf21f1SBoian Petkantchin 
605adbf21f1SBoian Petkantchin void mesh::spmdizeFullyReplicatedOperation(
606adbf21f1SBoian Petkantchin     Operation &op, ArrayRef<Value> spmdizedOperands,
607*baabcb28SFrank Schlimbach     ArrayRef<MeshSharding> operandShardings,
608*baabcb28SFrank Schlimbach     ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
609adbf21f1SBoian Petkantchin     SymbolTableCollection &symbolTable, OpBuilder &builder) {
610adbf21f1SBoian Petkantchin   assert(spmdizedOperands.size() == operandShardings.size());
611adbf21f1SBoian Petkantchin   assert(areValuesCompatibleWithFullReplicationShardings(op.getOperands(),
612adbf21f1SBoian Petkantchin                                                          operandShardings));
613adbf21f1SBoian Petkantchin   assert(areValuesCompatibleWithFullReplicationShardings(op.getResults(),
614adbf21f1SBoian Petkantchin                                                          resultShardings));
615adbf21f1SBoian Petkantchin   // `clone` will populate the mapping of old to new results.
616adbf21f1SBoian Petkantchin   builder.clone(op, spmdizationMap);
617adbf21f1SBoian Petkantchin }
618adbf21f1SBoian Petkantchin 
619fb582b6aSBoian Petkantchin static void updateMeshAxisAssignmentForLoopIterators(
620fb582b6aSBoian Petkantchin     ArrayRef<MeshAxis> meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr,
621fb582b6aSBoian Petkantchin     SmallVector<std::optional<SmallVector<MeshAxis>>>
622fb582b6aSBoian Petkantchin         &meshAxesAssignmentForLoopIterators) {
623fb582b6aSBoian Petkantchin   AffineDimExpr affineDimExpr = cast<AffineDimExpr>(indexingExpr);
624fb582b6aSBoian Petkantchin   unsigned loopIteratorIdx = affineDimExpr.getPosition();
625fb582b6aSBoian Petkantchin   if (meshAxesAssignmentForLoopIterators[loopIteratorIdx]) {
626fb582b6aSBoian Petkantchin     assert(llvm::equal(meshAxesAssignmentForTensorAxis,
627fb582b6aSBoian Petkantchin                        *meshAxesAssignmentForLoopIterators[loopIteratorIdx]));
628fb582b6aSBoian Petkantchin   } else {
629fb582b6aSBoian Petkantchin     meshAxesAssignmentForLoopIterators[loopIteratorIdx] =
630fb582b6aSBoian Petkantchin         llvm::to_vector(meshAxesAssignmentForTensorAxis);
631fb582b6aSBoian Petkantchin   }
632fb582b6aSBoian Petkantchin }
633fb582b6aSBoian Petkantchin 
634fb582b6aSBoian Petkantchin ShardingArray mesh::getMeshAxisAssignmentForLoopIterators(
635*baabcb28SFrank Schlimbach     ArrayRef<MeshSharding> operandShardings,
636*baabcb28SFrank Schlimbach     ArrayRef<MeshSharding> resultShardings,
637fb582b6aSBoian Petkantchin     ArrayRef<utils::IteratorType> loopIteratorTypes,
638fb582b6aSBoian Petkantchin     ArrayRef<AffineMap> indexingMaps) {
639fb582b6aSBoian Petkantchin   SmallVector<std::optional<SmallVector<MeshAxis>>>
640fb582b6aSBoian Petkantchin       meshAxisAssignmentForLoopIterators(loopIteratorTypes.size());
641*baabcb28SFrank Schlimbach   std::vector<MeshSharding> operatorAndResultShardings;
642fb582b6aSBoian Petkantchin   operatorAndResultShardings.reserve(operandShardings.size() +
643fb582b6aSBoian Petkantchin                                      resultShardings.size());
644fb582b6aSBoian Petkantchin   llvm::append_range(operatorAndResultShardings, operandShardings);
645fb582b6aSBoian Petkantchin   for (auto [sharding, affineMap] :
646fb582b6aSBoian Petkantchin        llvm::zip_equal(operatorAndResultShardings, indexingMaps)) {
647fb582b6aSBoian Petkantchin     if (!sharding) {
648fb582b6aSBoian Petkantchin       continue;
649fb582b6aSBoian Petkantchin     }
650fb582b6aSBoian Petkantchin     for (auto [meshAxesAssignmentForTensorAxis, indexingExpr] :
651fb582b6aSBoian Petkantchin          llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) {
652fb582b6aSBoian Petkantchin       updateMeshAxisAssignmentForLoopIterators(
653fb582b6aSBoian Petkantchin           meshAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr,
654fb582b6aSBoian Petkantchin           meshAxisAssignmentForLoopIterators);
655fb582b6aSBoian Petkantchin     }
656fb582b6aSBoian Petkantchin     // Missing trailing split axes means replication on those tensor dimensions.
657fb582b6aSBoian Petkantchin     for (unsigned i = sharding.getSplitAxes().size();
658fb582b6aSBoian Petkantchin          i < affineMap.getNumResults(); ++i) {
659fb582b6aSBoian Petkantchin       updateMeshAxisAssignmentForLoopIterators(
660fb582b6aSBoian Petkantchin           {}, affineMap.getResults()[i], meshAxisAssignmentForLoopIterators);
661fb582b6aSBoian Petkantchin     }
662fb582b6aSBoian Petkantchin   }
663fb582b6aSBoian Petkantchin 
664fb582b6aSBoian Petkantchin   ShardingArray res;
665fb582b6aSBoian Petkantchin   llvm::transform(meshAxisAssignmentForLoopIterators, std::back_inserter(res),
666fb582b6aSBoian Petkantchin                   [](std::optional<SmallVector<MeshAxis>> &axes) {
667fb582b6aSBoian Petkantchin                     if (!axes) {
668fb582b6aSBoian Petkantchin                       return SmallVector<MeshAxis>();
669fb582b6aSBoian Petkantchin                     };
670fb582b6aSBoian Petkantchin                     return std::move(*axes);
671fb582b6aSBoian Petkantchin                   });
672fb582b6aSBoian Petkantchin   return res;
673fb582b6aSBoian Petkantchin }
674fb582b6aSBoian Petkantchin 
675fb582b6aSBoian Petkantchin bool mesh::isAtLeastOneReductionIteratorSharded(
676fb582b6aSBoian Petkantchin     ArrayRef<utils::IteratorType> loopIteratorTypes,
677fb582b6aSBoian Petkantchin     ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
678fb582b6aSBoian Petkantchin   for (auto [loopIteratorType, meshAxisAssignment] :
679fb582b6aSBoian Petkantchin        llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
680fb582b6aSBoian Petkantchin     if (loopIteratorType == utils::IteratorType::reduction &&
681fb582b6aSBoian Petkantchin         !meshAxisAssignment.empty()) {
682fb582b6aSBoian Petkantchin       return true;
683fb582b6aSBoian Petkantchin     }
684fb582b6aSBoian Petkantchin   }
685fb582b6aSBoian Petkantchin   return false;
686fb582b6aSBoian Petkantchin }
687fb582b6aSBoian Petkantchin 
688fb582b6aSBoian Petkantchin SmallVector<MeshAxis> mesh::getReductionMeshAxes(
689fb582b6aSBoian Petkantchin     ArrayRef<utils::IteratorType> loopIteratorTypes,
690fb582b6aSBoian Petkantchin     ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
691fb582b6aSBoian Petkantchin   SmallVector<MeshAxis> meshAxes;
692fb582b6aSBoian Petkantchin   for (auto [loopIteratorType, meshAxisAssignment] :
693fb582b6aSBoian Petkantchin        llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
694fb582b6aSBoian Petkantchin     if (loopIteratorType == utils::IteratorType::reduction) {
695fb582b6aSBoian Petkantchin       llvm::append_range(meshAxes, meshAxisAssignment);
696fb582b6aSBoian Petkantchin     }
697fb582b6aSBoian Petkantchin   }
698fb582b6aSBoian Petkantchin   return meshAxes;
699fb582b6aSBoian Petkantchin }
700fb582b6aSBoian Petkantchin 
701adbf21f1SBoian Petkantchin void mesh::spmdizeTriviallyShardableOperation(
702adbf21f1SBoian Petkantchin     Operation &op, ArrayRef<Value> spmdizedOperands,
703*baabcb28SFrank Schlimbach     ArrayRef<MeshSharding> operandShardings,
704*baabcb28SFrank Schlimbach     ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
705adbf21f1SBoian Petkantchin     SymbolTableCollection &symbolTable, OpBuilder &builder) {
706adbf21f1SBoian Petkantchin   // `clone` will populate the mapping of old to new results.
707adbf21f1SBoian Petkantchin   Operation *newOp = builder.clone(op, spmdizationMap);
708adbf21f1SBoian Petkantchin   // Set the result types to the sharded counterparts.
709adbf21f1SBoian Petkantchin   for (auto [oldResult, newResult, sharding] :
710fb582b6aSBoian Petkantchin        llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) {
711*baabcb28SFrank Schlimbach     newResult.setType(
712*baabcb28SFrank Schlimbach         shardType(newResult.getType(),
713*baabcb28SFrank Schlimbach                   getMesh(&op, sharding.getMeshAttr(), symbolTable), sharding));
714adbf21f1SBoian Petkantchin   }
715adbf21f1SBoian Petkantchin }
716