//===- Spmdization.cpp --------------------------------------------- C++ --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include <iterator>
#include <optional>
#include <tuple>
#include <type_traits>

namespace mlir::mesh {

template <typename SourceAxes, typename TargetAxes>
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
                                     const TargetAxes &targetAxes) {
  return llvm::all_of(targetAxes, [&sourceAxes](auto &targetAxis) {
    return sourceAxes.contains(targetAxis);
  });
}

// Return the reduced value and its corresponding sharding.
// Example:
// sourceSharding = <@mesh_1d, [[0]], partial = sum[0]>
// targetSharding = <@mesh_1d, [[]]>
// In this case an all-reduce over mesh axis 0 is applied to the source value
// and the result is returned with the sharding <@mesh_1d, [[0]]>.
static std::tuple<TypedValue<ShapedType>, MeshSharding>
handlePartialAxesDuringResharding(OpBuilder &builder,
                                  MeshSharding sourceSharding,
                                  MeshSharding targetSharding,
                                  TypedValue<ShapedType> sourceShard) {
  if (sourceSharding.getPartialAxes().empty() &&
      targetSharding.getPartialAxes().empty()) {
    return {sourceShard, sourceSharding};
  }
  assert(targetSharding.getPartialAxes().empty() ||
         (!sourceSharding.getPartialAxes().empty() &&
          sourceSharding.getPartialType() == targetSharding.getPartialType()));
  using Axis = std::decay_t<decltype(sourceSharding.getPartialAxes().front())>;
  using AxisSet = llvm::SmallDenseSet<Axis>;
  AxisSet sourceShardingPartialAxesSet(sourceSharding.getPartialAxes().begin(),
                                       sourceSharding.getPartialAxes().end());
  AxisSet targetShardingPartialAxesSet(targetSharding.getPartialAxes().begin(),
                                       targetSharding.getPartialAxes().end());
  assert(arePartialAxesCompatible(sourceShardingPartialAxesSet,
                                  targetShardingPartialAxesSet));
  llvm::SmallVector<MeshAxis> allReduceMeshAxes;
  llvm::copy_if(sourceShardingPartialAxesSet,
                std::back_inserter(allReduceMeshAxes),
                [&targetShardingPartialAxesSet](Axis a) {
                  return !targetShardingPartialAxesSet.contains(a);
                });
  if (allReduceMeshAxes.empty()) {
    return {sourceShard, sourceSharding};
  }

  builder.setInsertionPointAfterValue(sourceShard);
  TypedValue<ShapedType> resultValue = cast<TypedValue<ShapedType>>(
      builder
          .create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
                               sourceSharding.getMeshAttr().getLeafReference(),
                               allReduceMeshAxes, sourceShard,
                               sourceSharding.getPartialType())
          .getResult());

  llvm::SmallVector<MeshAxis> remainingPartialAxes;
  llvm::copy_if(sourceShardingPartialAxesSet,
                std::back_inserter(remainingPartialAxes),
                [&targetShardingPartialAxesSet](Axis a) {
                  return targetShardingPartialAxesSet.contains(a);
                });
  MeshSharding resultSharding = MeshSharding::get(
      sourceSharding.getMeshAttr(), sourceSharding.getSplitAxes(),
      remainingPartialAxes, sourceSharding.getPartialType());
  return {resultValue, resultSharding};
}
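// Illustrative trace of handlePartialAxesDuringResharding (example shardings
// made up for exposition): with
// sourceSharding = <@mesh_2d, [[]], partial = sum[0, 1]> and
// targetSharding = <@mesh_2d, [[]], partial = sum[1]>, only mesh axis 0 must
// be reduced away, so allReduceMeshAxes = [0], a single all-reduce over mesh
// axis 0 is emitted, and the resulting sharding keeps the remaining partial
// axis: <@mesh_2d, [[]], partial = sum[1]>.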

static MeshSharding targetShardingInSplitLastAxis(MLIRContext *ctx,
                                                  MeshSharding sourceSharding,
                                                  int64_t splitTensorAxis,
                                                  MeshAxis splitMeshAxis) {
  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
         splitTensorAxis) {
    targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
  }
  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
  targetSplitAxes.push_back(splitMeshAxis);
  targetShardingSplitAxes[splitTensorAxis] =
      MeshAxesAttr::get(ctx, targetSplitAxes);
  return MeshSharding::get(
      sourceSharding.getMeshAttr(), targetShardingSplitAxes,
      sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
}
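// Illustrative example for targetShardingInSplitLastAxis (made up for
// exposition): starting from split axes [[0]] with splitTensorAxis = 1 and
// splitMeshAxis = 2, the axes list is first padded with an empty entry for
// tensor axis 1 and then extended, yielding [[0], [2]]. With
// splitTensorAxis = 0 the result would be [[0, 2]].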

// Split a replicated tensor along a mesh axis.
// E.g. [[0, 1]] -> [[0, 1, 2]].
// Returns the spmdized target value with its sharding.
static std::tuple<TypedValue<ShapedType>, MeshSharding>
splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                          MeshSharding sourceSharding,
                          TypedValue<ShapedType> sourceShard, MeshOp mesh,
                          int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
      builder
          .create<AllSliceOp>(sourceShard, mesh,
                              ArrayRef<MeshAxis>(splitMeshAxis),
                              splitTensorAxis)
          .getResult());
  MeshSharding targetSharding = targetShardingInSplitLastAxis(
      builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
  return {targetShard, targetSharding};
}

// Detect if the resharding is of type e.g.
// [[0, 1]] -> [[0, 1, 2]].
// If detected, returns the corresponding (tensor axis, mesh axis) pair.
// Does not detect insertions like
// [[0, 1]] -> [[0, 2, 1]].
static std::optional<std::tuple<int64_t, MeshAxis>>
detectSplitLastAxisInResharding(MeshSharding sourceSharding,
                                MeshSharding targetSharding) {
  for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size();
       ++tensorAxis) {
    if (sourceSharding.getSplitAxes().size() > tensorAxis) {
      if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 !=
          targetSharding.getSplitAxes()[tensorAxis].size()) {
        continue;
      }
      if (!llvm::equal(
              sourceSharding.getSplitAxes()[tensorAxis].asArrayRef(),
              llvm::make_range(
                  targetSharding.getSplitAxes()[tensorAxis]
                      .asArrayRef()
                      .begin(),
                  targetSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
                      1))) {
        continue;
      }
    } else {
      if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) {
        continue;
      }
    }
    return std::make_tuple(
        tensorAxis,
        targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
  }
  return std::nullopt;
}
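// Illustrative example for detectSplitLastAxisInResharding (made up for
// exposition): for source [[0]] and target [[0, 1]] it returns
// (tensorAxis = 0, meshAxis = 1); for source [[0]] and target [[0], [1]] it
// returns (tensorAxis = 1, meshAxis = 1), since tensor axis 1 goes from
// unsharded to sharded on mesh axis 1.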

static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                             MeshSharding sourceSharding,
                             MeshSharding targetSharding,
                             TypedValue<ShapedType> sourceShard) {
  if (auto detectRes =
          detectSplitLastAxisInResharding(sourceSharding, targetSharding)) {
    auto [tensorAxis, meshAxis] = detectRes.value();
    return splitLastAxisInResharding(builder, sourceSharding, sourceShard, mesh,
                                     tensorAxis, meshAxis);
  }

  return std::nullopt;
}

// Detect if the resharding is of type e.g.
// [[0, 1, 2]] -> [[0, 1]].
// If detected, returns the corresponding (tensor axis, mesh axis) pair.
static std::optional<std::tuple<int64_t, MeshAxis>>
detectUnsplitLastAxisInResharding(MeshSharding sourceSharding,
                                  MeshSharding targetSharding) {
  for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
       ++tensorAxis) {
    if (targetSharding.getSplitAxes().size() > tensorAxis) {
      if (sourceSharding.getSplitAxes()[tensorAxis].size() !=
          targetSharding.getSplitAxes()[tensorAxis].size() + 1)
        continue;
      if (!llvm::equal(
              llvm::make_range(
                  sourceSharding.getSplitAxes()[tensorAxis]
                      .asArrayRef()
                      .begin(),
                  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
                      1),
              targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
        continue;
    } else {
      if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1)
        continue;
    }
    return std::make_tuple(
        tensorAxis,
        sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
  }
  return std::nullopt;
}
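// Illustrative example for detectUnsplitLastAxisInResharding (made up for
// exposition): for source [[0, 1]] and target [[0]] it returns
// (tensorAxis = 0, meshAxis = 1), i.e. mesh axis 1 is dropped from the end of
// the split axes of tensor axis 0.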

static MeshSharding targetShardingInUnsplitLastAxis(MLIRContext *ctx,
                                                    MeshSharding sourceSharding,
                                                    int64_t splitTensorAxis) {
  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
         splitTensorAxis);
  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());

  targetSplitAxes.pop_back();
  targetShardingSplitAxes[splitTensorAxis] =
      MeshAxesAttr::get(ctx, targetSplitAxes);
  return MeshSharding::get(
      sourceSharding.getMeshAttr(), targetShardingSplitAxes,
      sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
}

static ShapedType allGatherResultShapeInUnsplitLastAxis(
    ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
  targetShape[splitTensorAxis] =
      gatherDimension(targetShape[splitTensorAxis], splitCount);
  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
}
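// Illustrative example for allGatherResultShapeInUnsplitLastAxis (made up for
// exposition, static sizes only): for a local shard of type tensor<2x3xf32>,
// splitCount = 4 and splitTensorAxis = 0, the sharded dimension is multiplied
// by the split count, giving an all-gather result type of tensor<8x3xf32>.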

static std::tuple<TypedValue<ShapedType>, MeshSharding>
unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                            MeshSharding sourceSharding,
                            ShapedType sourceUnshardedShape,
                            TypedValue<ShapedType> sourceShard, MeshOp mesh,
                            int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
  MLIRContext *ctx = builder.getContext();
  builder.setInsertionPointAfterValue(sourceShard);

  MeshSharding targetSharding =
      targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis);
  ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
      sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
  Value allGatherResult = builder.create<AllGatherOp>(
      RankedTensorType::get(allGatherResultShape.getShape(),
                            allGatherResultShape.getElementType()),
      mesh.getSymName(), SmallVector<MeshAxis>({splitMeshAxis}), sourceShard,
      APInt(64, splitTensorAxis));
  ShapedType targetShape =
      shardShapedType(sourceUnshardedShape, mesh, targetSharding);
  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
      builder.create<tensor::CastOp>(targetShape, allGatherResult).getResult());
  return {targetShard, targetSharding};
}

static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                               MeshSharding sourceSharding,
                               MeshSharding targetSharding,
                               ShapedType sourceUnshardedShape,
                               TypedValue<ShapedType> sourceShard) {
  if (auto detectRes =
          detectUnsplitLastAxisInResharding(sourceSharding, targetSharding)) {
    auto [tensorAxis, meshAxis] = detectRes.value();
    return unsplitLastAxisInResharding(builder, sourceSharding,
                                       sourceUnshardedShape, sourceShard, mesh,
                                       tensorAxis, meshAxis);
  }

  return std::nullopt;
}

// Detect if the resharding is of type e.g.
// [[0, 1], [2]] -> [[0], [1, 2]].
// Only moving the last axis counts.
// If detected, returns the corresponding (source_tensor_axis,
// target_tensor_axis, mesh_axis) tuple.
static std::optional<std::tuple<int64_t, int64_t, MeshAxis>>
detectMoveLastSplitAxisInResharding(MeshSharding sourceSharding,
                                    MeshSharding targetSharding) {
  for (size_t sourceTensorAxis = 0;
       sourceTensorAxis < sourceSharding.getSplitAxes().size();
       ++sourceTensorAxis) {
    for (size_t targetTensorAxis = 0;
         targetTensorAxis < targetSharding.getSplitAxes().size();
         ++targetTensorAxis) {
      if (sourceTensorAxis == targetTensorAxis)
        continue;
      if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||
          targetSharding.getSplitAxes()[targetTensorAxis].empty() ||
          sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
              targetSharding.getSplitAxes()[targetTensorAxis]
                  .asArrayRef()
                  .back())
        continue;
      if (!llvm::equal(
              llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis]
                                   .asArrayRef()
                                   .begin(),
                               sourceSharding.getSplitAxes()[sourceTensorAxis]
                                       .asArrayRef()
                                       .end() -
                                   1),
              llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis]
                                   .asArrayRef()
                                   .begin(),
                               targetSharding.getSplitAxes()[targetTensorAxis]
                                       .asArrayRef()
                                       .end() -
                                   1)))
        continue;
      return std::make_tuple(
          sourceTensorAxis, targetTensorAxis,
          sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back());
    }
  }
  return std::nullopt;
}
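// Illustrative example for detectMoveLastSplitAxisInResharding (made up for
// exposition): for source [[0], []] and target [[], [0]] it returns
// (sourceTensorAxis = 0, targetTensorAxis = 1, meshAxis = 0), since mesh
// axis 0 is the last split axis of tensor axis 0 in the source and of tensor
// axis 1 in the target, and the remaining prefixes (both empty) match.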

static MeshSharding targetShardingInMoveLastAxis(MLIRContext *ctx,
                                                 MeshSharding sourceSharding,
                                                 int64_t sourceTensorAxis,
                                                 int64_t targetTensorAxis) {
  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
         targetTensorAxis) {
    targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
  }

  auto sourceSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
  assert(!sourceSplitAxes.empty());
  auto meshAxis = sourceSplitAxes.back();
  sourceSplitAxes.pop_back();
  targetShardingSplitAxes[sourceTensorAxis] =
      MeshAxesAttr::get(ctx, sourceSplitAxes);

  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
  targetSplitAxes.push_back(meshAxis);
  targetShardingSplitAxes[targetTensorAxis] =
      MeshAxesAttr::get(ctx, targetSplitAxes);

  return MeshSharding::get(
      sourceSharding.getMeshAttr(), targetShardingSplitAxes,
      sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
}

static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
                                                    int64_t splitCount,
                                                    int64_t sourceTensorAxis,
                                                    int64_t targetTensorAxis) {
  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
  targetShape[sourceTensorAxis] =
      gatherDimension(targetShape[sourceTensorAxis], splitCount);
  targetShape[targetTensorAxis] =
      shardDimension(targetShape[targetTensorAxis], splitCount);
  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
}
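// Illustrative example for allToAllResultShapeInMoveLastAxis (made up for
// exposition, static sizes only): for a local shard of type tensor<2x12xf32>,
// splitCount = 4, sourceTensorAxis = 0 and targetTensorAxis = 1, the source
// dimension is gathered (2 * 4 = 8) and the target dimension is sharded
// (12 / 4 = 3), giving an all-to-all result type of tensor<8x3xf32>.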

static std::tuple<TypedValue<ShapedType>, MeshSharding>
moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                              MeshSharding sourceSharding,
                              ShapedType sourceUnshardedShape,
                              TypedValue<ShapedType> sourceShard,
                              int64_t sourceTensorAxis,
                              int64_t targetTensorAxis, MeshAxis meshAxis) {
  MLIRContext *ctx = builder.getContext();
  builder.setInsertionPointAfterValue(sourceShard);

  MeshSharding targetSharding = targetShardingInMoveLastAxis(
      ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
  ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
      sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis,
      targetTensorAxis);
  Value allToAllResult = builder.create<AllToAllOp>(
      RankedTensorType::get(allToAllResultShape.getShape(),
                            allToAllResultShape.getElementType()),
      mesh.getSymName(), SmallVector<MeshAxis>({meshAxis}), sourceShard,
      APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
  ShapedType targetShape =
      shardShapedType(sourceUnshardedShape, mesh, targetSharding);
  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
      builder.create<tensor::CastOp>(targetShape, allToAllResult).getResult());
  return {targetShard, targetSharding};
}

static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                                 MeshSharding sourceSharding,
                                 MeshSharding targetSharding,
                                 ShapedType sourceUnshardedShape,
                                 TypedValue<ShapedType> sourceShard) {
  if (auto detectRes =
          detectMoveLastSplitAxisInResharding(sourceSharding, targetSharding)) {
    auto [sourceTensorAxis, targetTensorAxis, meshAxis] = detectRes.value();
    return moveLastSplitAxisInResharding(
        builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard,
        sourceTensorAxis, targetTensorAxis, meshAxis);
  }

  return std::nullopt;
}

// Detect a change in the halo size (only) and create the necessary operations
// if needed. A change in halo sizes requires copying the "core" of the source
// tensor into the "core" of the destination tensor, followed by an update halo
// operation.
static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                          MeshSharding sourceSharding,
                          MeshSharding targetSharding,
                          ShapedType sourceUnshardedShape,
                          TypedValue<ShapedType> sourceShard) {
  // Currently handles only cases where halo sizes differ but everything else
  // stays the same (from source to destination sharding).
  if (!sourceSharding.equalSplitAndPartialAxes(targetSharding) ||
      !sourceSharding.getPartialAxes().empty() ||
      !targetSharding.getPartialAxes().empty() ||
      !sourceSharding.getStaticShardedDimsOffsets().empty() ||
      !targetSharding.getStaticShardedDimsOffsets().empty() ||
      sourceSharding.equalHaloSizes(targetSharding)) {
    return std::nullopt;
  }

  auto srcHaloSizes = sourceSharding.getStaticHaloSizes();
  auto tgtHaloSizes = targetSharding.getStaticHaloSizes();
  assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size());
  assert(((srcHaloSizes.empty() || !ShapedType::isDynamicShape(srcHaloSizes)) &&
          !ShapedType::isDynamicShape(tgtHaloSizes) &&
          sourceShard.getType().hasStaticShape()) &&
         "dynamic shapes/halos are not supported yet for mesh-spmdization");
  auto rank = sourceShard.getType().getRank();
  auto splitAxes = sourceSharding.getSplitAxes();
  SmallVector<int64_t> srcCoreOffs(rank, 0), tgtCoreOffs(rank, 0),
      strides(rank, 1), outShape(sourceShard.getType().getShape()),
      coreShape(sourceShard.getType().getShape());

  // Determine "core" of source and destination.
  // The core is the local part of the shard excluding halo regions.
  for (auto i = 0u; i < rank; ++i) {
    if (i < splitAxes.size() && !splitAxes[i].empty()) {
      if (!srcHaloSizes.empty()) {
        coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1];
        srcCoreOffs[i] = srcHaloSizes[i * 2];
      }
      tgtCoreOffs[i] = tgtHaloSizes[i * 2];
      outShape[i] =
          coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1];
    }
  }

  // Extract core from source and copy into destination core.
  auto noVals = ValueRange{};
  auto initVal = builder.create<tensor::EmptyOp>(
      sourceShard.getLoc(), outShape, sourceShard.getType().getElementType());
  auto core = builder.create<tensor::ExtractSliceOp>(
      sourceShard.getLoc(),
      RankedTensorType::get(coreShape, sourceShard.getType().getElementType()),
      sourceShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides);
  auto initOprnd = builder.create<tensor::InsertSliceOp>(
      sourceShard.getLoc(), core, initVal, noVals, noVals, noVals, tgtCoreOffs,
      coreShape, strides);

  // Finally update the halo.
  auto updateHaloResult =
      builder
          .create<UpdateHaloOp>(
              sourceShard.getLoc(),
              RankedTensorType::get(outShape,
                                    sourceShard.getType().getElementType()),
              initOprnd, mesh.getSymName(),
              MeshAxesArrayAttr::get(builder.getContext(),
                                     sourceSharding.getSplitAxes()),
              targetSharding.getDynamicHaloSizes(),
              targetSharding.getStaticHaloSizes())
          .getResult();
  return std::make_tuple(cast<TypedValue<ShapedType>>(updateHaloResult),
                         targetSharding);
}
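// Illustrative trace of tryUpdateHaloInResharding (a made-up 1D example with
// static sizes): for a local shard tensor<12xf32> with source halo sizes
// [1, 1] and target halo sizes [2, 2], the core has size 12 - 1 - 1 = 10,
// srcCoreOffs = [1], tgtCoreOffs = [2] and outShape = [10 + 2 + 2] = [14];
// the core is copied into a new tensor<14xf32> and the update halo op then
// fills the enlarged halo regions from the neighboring shards.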

// Handles only resharding on a 1D mesh.
// Currently the sizes of the sharded tensor axes must be exactly divisible by
// the single mesh axis size.
static TypedValue<ShapedType>
reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh,
                MeshSharding sourceSharding, MeshSharding targetSharding,
                TypedValue<ShapedType> sourceUnshardedValue,
                TypedValue<ShapedType> sourceShard) {
  assert(sourceShard.getType() ==
         shardShapedType(sourceUnshardedValue.getType(), mesh, sourceSharding));
  [[maybe_unused]] ShapedType targetShardType =
      shardShapedType(sourceUnshardedValue.getType(), mesh, targetSharding);
  assert(sourceShard.getType().getRank() == targetShardType.getRank());
  assert(mesh.getRank() == 1 && "Only 1D meshes are currently supported.");

  auto [reducedSourceShard, reducedSourceSharding] =
      handlePartialAxesDuringResharding(builder, sourceSharding, targetSharding,
                                        sourceShard);

  if (reducedSourceSharding == targetSharding) {
    return reducedSourceShard;
  }

  TypedValue<ShapedType> targetShard;
  MeshSharding actualTargetSharding;
  if (reducedSourceSharding.getStaticShardedDimsOffsets().empty() &&
      targetSharding.getStaticShardedDimsOffsets().empty() &&
      reducedSourceSharding.getStaticHaloSizes().empty() &&
      targetSharding.getStaticHaloSizes().empty()) {
    if (auto tryRes = tryMoveLastSplitAxisInResharding(
            builder, mesh, reducedSourceSharding, targetSharding,
            sourceUnshardedValue.getType(), reducedSourceShard)) {
      std::tie(targetShard, actualTargetSharding) = tryRes.value();
    } else if (auto tryRes = trySplitLastAxisInResharding(
                   builder, mesh, reducedSourceSharding, targetSharding,
                   reducedSourceShard)) {
      std::tie(targetShard, actualTargetSharding) = tryRes.value();
    } else if (auto tryRes = tryUnsplitLastAxisInResharding(
                   builder, mesh, reducedSourceSharding, targetSharding,
                   sourceUnshardedValue.getType(), reducedSourceShard)) {
      std::tie(targetShard, actualTargetSharding) = tryRes.value();
    }
  }
  assert(targetShard && "Did not find any pattern to apply.");
  assert(actualTargetSharding == targetSharding);
  assert(targetShard.getType() == targetShardType);
  return targetShard;
}

TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
                               MeshSharding sourceSharding,
                               MeshSharding targetSharding,
                               TypedValue<ShapedType> sourceUnshardedValue,
                               TypedValue<ShapedType> sourceShard) {
  // If source and destination sharding are the same, no need to do anything.
  if (sourceSharding == targetSharding) {
    return sourceShard;
  }

  // Tries to handle the case where the resharding is needed because the halo
  // sizes are different. Supports arbitrary mesh dimensionality.
  if (auto tryRes = tryUpdateHaloInResharding(
          builder, mesh, sourceSharding, targetSharding,
          sourceUnshardedValue.getType(), sourceShard)) {
    return std::get<0>(tryRes.value()); // targetShard
  }

  // Resort to handling only 1D meshes, since the general case is complicated
  // if it needs to be communication efficient in terms of minimizing the data
  // transferred between devices.
  return reshardOn1DMesh(builder, mesh, sourceSharding, targetSharding,
                         sourceUnshardedValue, sourceShard);
}

TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
                               ShardOp target,
                               TypedValue<ShapedType> sourceShardValue) {
  assert(source.getResult() == target.getSrc());
  auto sourceSharding = source.getSharding();
  auto targetSharding = target.getSharding();
  ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
  return reshard(implicitLocOpBuilder, mesh, sourceSharding, targetSharding,
                 cast<TypedValue<ShapedType>>(source.getSrc()),
                 sourceShardValue);
}

TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
                               ShardOp target,
                               TypedValue<ShapedType> sourceShardValue,
                               SymbolTableCollection &symbolTableCollection) {
  MeshOp srcMesh = getMesh(source, symbolTableCollection);
  assert(srcMesh && srcMesh == getMesh(target, symbolTableCollection));
  return reshard(builder, srcMesh, source, target, sourceShardValue);
}

void reshardingRegisterDependentDialects(DialectRegistry &registry) {
  registry.insert<mesh::MeshDialect, tensor::TensorDialect>();
}

#define GEN_PASS_DEF_SPMDIZATION
#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"

using UnshardedToShardedValueMap = DenseMap<Value, Value>;

// Get the types of block arguments for an spmdized block.
// Reads the sharding annotations of the arguments to deduce the sharded types.
// Types that are not ranked tensors are left unchanged.
SmallVector<Type>
shardedBlockArgumentTypes(Block &block,
                          SymbolTableCollection &symbolTableCollection) {
  SmallVector<Type> res;
  llvm::transform(
      block.getArguments(), std::back_inserter(res),
      [&symbolTableCollection](BlockArgument arg) {
        auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
        if (!rankedTensorArg) {
          return arg.getType();
        }

        assert(rankedTensorArg.hasOneUse());
        Operation *useOp = *rankedTensorArg.getUsers().begin();
        ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
        assert(shardOp);
        MeshOp mesh = getMesh(shardOp, symbolTableCollection);
        return cast<Type>(shardShapedType(rankedTensorArg.getType(), mesh,
                                          shardOp.getSharding()));
      });
  return res;
}
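// Illustrative example for shardedBlockArgumentTypes (made up for
// exposition): a block argument of type tensor<8x16xf32> whose single use is
// a mesh.shard with split axes [[0]] on a 4-device 1D mesh is rewritten to
// tensor<2x16xf32> by shardShapedType; non-tensor arguments keep their
// original type.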

void spmdizeTriviallyShardableOperation(Operation &op,
                                        ArrayRef<Value> spmdizedOperands,
                                        ArrayRef<MeshSharding> operandShardings,
                                        ArrayRef<MeshSharding> resultShardings,
                                        IRMapping &spmdizationMap,
                                        SymbolTableCollection &symbolTable,
                                        OpBuilder &builder);

static LogicalResult spmdizeOperation(
    Operation &op, ArrayRef<Value> spmdizedOperands,
    ArrayRef<MeshSharding> operandShardings,
    ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
    SymbolTableCollection &symbolTableCollection, OpBuilder &builder) {
  ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
  if (!shardingInterface) {
    // If there is no sharding interface we are conservative and assume that
    // the op should be fully replicated on all devices.
    spmdizeFullyReplicatedOperation(op, spmdizedOperands, operandShardings,
                                    resultShardings, spmdizationMap,
                                    symbolTableCollection, builder);
  } else {
    if (failed(shardingInterface.spmdize(spmdizedOperands, operandShardings,
                                         resultShardings, spmdizationMap,
                                         symbolTableCollection, builder))) {
      return failure();
    }
  }

  assert(llvm::all_of(op.getResults(), [&spmdizationMap](OpResult result) {
    return spmdizationMap.contains(result);
  }));

  return success();
}

// Retrieve the sharding annotations for the operands of the given operation.
// If the type is not a ranked tensor it is not required to have an annotation.
static std::vector<MeshSharding> getOperandShardings(Operation &op) {
  std::vector<MeshSharding> res;
  res.reserve(op.getNumOperands());
  llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
    TypedValue<RankedTensorType> rankedTensor =
        dyn_cast<TypedValue<RankedTensorType>>(operand);
    if (!rankedTensor) {
      return MeshSharding();
    }

    Operation *definingOp = operand.getDefiningOp();
    assert(definingOp);
    ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
    return MeshSharding(shardOp.getSharding());
  });
  return res;
}

// Retrieve the sharding annotations for the results of the given operation.
// If the type is not a ranked tensor it is not required to have an annotation.
static std::vector<MeshSharding> getResultShardings(Operation &op) {
  std::vector<MeshSharding> res;
  res.reserve(op.getNumResults());
  llvm::transform(op.getResults(), std::back_inserter(res),
                  [](OpResult result) {
                    TypedValue<RankedTensorType> rankedTensor =
                        dyn_cast<TypedValue<RankedTensorType>>(result);
                    if (!rankedTensor) {
                      return MeshSharding();
                    }

                    assert(result.hasOneUse());
                    Operation *userOp = *result.getUsers().begin();
                    ShardOp shardOp = llvm::cast<ShardOp>(userOp);
                    return MeshSharding(shardOp.getSharding());
                  });
  return res;
}

static LogicalResult
spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
                 SymbolTableCollection &symbolTableCollection,
                 OpBuilder &builder) {
  Value targetSpmdValue;

  // Check if two shard ops are chained. If not, there is no need for
  // resharding as the source and target share the same sharding.
  ShardOp srcShardOp =
      dyn_cast_or_null<ShardOp>(shardOp.getSrc().getDefiningOp());
  if (!srcShardOp) {
    targetSpmdValue = spmdizationMap.lookup(shardOp.getSrc());
  } else {
    // Insert resharding.
    TypedValue<ShapedType> srcSpmdValue =
        cast<TypedValue<ShapedType>>(spmdizationMap.lookup(srcShardOp));
    targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
                              symbolTableCollection);
  }

  assert(!spmdizationMap.contains(shardOp.getResult()));
  spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
  return success();
}

static LogicalResult
spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
                 SymbolTableCollection &symbolTableCollection,
                 OpBuilder &builder) {
  if (isa<ShardingOp>(op)) {
    return success();
  }

  ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
  if (shardOp) {
    return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection,
                            builder);
  }

  SmallVector<Value> spmdizedOperands;
  llvm::transform(op.getOperands(), std::back_inserter(spmdizedOperands),
                  [&spmdizationMap](Value operand) {
                    assert(spmdizationMap.contains(operand));
                    return spmdizationMap.lookup(operand);
                  });
  return spmdizeOperation(op, spmdizedOperands, getOperandShardings(op),
                          getResultShardings(op), spmdizationMap,
                          symbolTableCollection, builder);
}

static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
                                  SymbolTableCollection &symbolTableCollection,
                                  OpBuilder &builder) {
  SmallVector<Location> argLocations;
  llvm::transform(block.getArguments(), std::back_inserter(argLocations),
                  [](BlockArgument arg) { return arg.getLoc(); });
  Block *newBlock = builder.createBlock(
      block.getParent(), {},
      shardedBlockArgumentTypes(block, symbolTableCollection), argLocations);
  for (auto [unshardedBlockArg, spmdizedBlockArg] :
       llvm::zip(block.getArguments(), newBlock->getArguments())) {
    spmdizationMap.map(unshardedBlockArg, spmdizedBlockArg);
  }

  OpBuilder::InsertionGuard insertionGuard(builder);
  builder.setInsertionPointToEnd(newBlock);
  for (Operation &op : block.getOperations()) {
    if (failed(spmdizeOperation(op, spmdizationMap, symbolTableCollection,
                                builder))) {
      return failure();
    }
  }

  return success();
}

static LogicalResult
spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
              SymbolTableCollection &symbolTableCollection) {
  OpBuilder builder(op.getFunctionBody());

  // Snapshot the original blocks to not mess up the iteration when adding new
  // blocks.
  SmallVector<Block *> originalBlocks;
  llvm::transform(op.getBlocks(), std::back_inserter(originalBlocks),
                  [](Block &b) { return &b; });

  for (Block *block : originalBlocks) {
    if (failed(spmdizeBlock(*block, spmdizationMap, symbolTableCollection,
                            builder))) {
      return failure();
    }
  }

  for (Block *block : originalBlocks) {
    block->erase();
  }

  // Find a return op and change the function result types to match the return
  // op's operand types.
  Operation *returnOp = nullptr;
  for (Block &block : op.getFunctionBody()) {
    if (block.empty()) {
      continue;
    }

    if (block.back().hasTrait<OpTrait::ReturnLike>()) {
      returnOp = &block.back();
      break;
    }
  }
  assert(returnOp);
  op.setType(FunctionType::get(op->getContext(),
                               op.getFunctionBody().front().getArgumentTypes(),
                               returnOp->getOperandTypes()));

  return success();
}

namespace {

struct Spmdization : public impl::SpmdizationBase<Spmdization> {
  void runOnOperation() override {
    IRMapping spmdizationMap;
    SymbolTableCollection symbolTableCollection;
    if (failed(spmdizeFuncOp(getOperation(), spmdizationMap,
                             symbolTableCollection))) {
      return signalPassFailure();
    }
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    reshardingRegisterDependentDialects(registry);
    registry.insert<mesh::MeshDialect>();
  }
};

} // namespace

} // namespace mlir::mesh