xref: /llvm-project/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (revision 79eb406a67fe08458548289da72cda18248a9313)
11a8fb887SBoian Petkantchin //===- Spmdization.cpp --------------------------------------------- C++ --===//
21a8fb887SBoian Petkantchin //
31a8fb887SBoian Petkantchin // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
41a8fb887SBoian Petkantchin // See https://llvm.org/LICENSE.txt for license information.
51a8fb887SBoian Petkantchin // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
61a8fb887SBoian Petkantchin //
71a8fb887SBoian Petkantchin //===----------------------------------------------------------------------===//
81a8fb887SBoian Petkantchin 
91a8fb887SBoian Petkantchin #include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
10adbf21f1SBoian Petkantchin 
11adbf21f1SBoian Petkantchin #include "mlir/Dialect/Func/IR/FuncOps.h"
1231fc0a12SBoian Petkantchin #include "mlir/Dialect/Mesh/IR/MeshDialect.h"
131a8fb887SBoian Petkantchin #include "mlir/Dialect/Mesh/IR/MeshOps.h"
14adbf21f1SBoian Petkantchin #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
151a8fb887SBoian Petkantchin #include "mlir/Dialect/Tensor/IR/Tensor.h"
161a8fb887SBoian Petkantchin #include "mlir/IR/Builders.h"
171a8fb887SBoian Petkantchin #include "mlir/IR/BuiltinAttributes.h"
181a8fb887SBoian Petkantchin #include "mlir/IR/BuiltinTypeInterfaces.h"
191a8fb887SBoian Petkantchin #include "mlir/IR/BuiltinTypes.h"
20adbf21f1SBoian Petkantchin #include "mlir/IR/Diagnostics.h"
21adbf21f1SBoian Petkantchin #include "mlir/IR/IRMapping.h"
221a8fb887SBoian Petkantchin #include "mlir/IR/ImplicitLocOpBuilder.h"
231a8fb887SBoian Petkantchin #include "mlir/IR/Location.h"
241a8fb887SBoian Petkantchin #include "mlir/IR/MLIRContext.h"
25adbf21f1SBoian Petkantchin #include "mlir/IR/SymbolTable.h"
261a8fb887SBoian Petkantchin #include "mlir/IR/Value.h"
27abfac563SBoian Petkantchin #include "mlir/Interfaces/ControlFlowInterfaces.h"
28abfac563SBoian Petkantchin #include "mlir/Interfaces/FunctionInterfaces.h"
29adbf21f1SBoian Petkantchin #include "mlir/Pass/Pass.h"
301a8fb887SBoian Petkantchin #include "mlir/Support/LLVM.h"
311a8fb887SBoian Petkantchin #include "llvm/ADT/APInt.h"
321a8fb887SBoian Petkantchin #include "llvm/ADT/DenseSet.h"
331a8fb887SBoian Petkantchin #include "llvm/ADT/STLExtras.h"
341a8fb887SBoian Petkantchin #include "llvm/ADT/SmallVector.h"
35adbf21f1SBoian Petkantchin #include "llvm/Support/Casting.h"
361a8fb887SBoian Petkantchin #include <iterator>
371a8fb887SBoian Petkantchin #include <optional>
381a8fb887SBoian Petkantchin #include <tuple>
391a8fb887SBoian Petkantchin #include <type_traits>
401a8fb887SBoian Petkantchin 
41adbf21f1SBoian Petkantchin namespace mlir::mesh {
421a8fb887SBoian Petkantchin 
431a8fb887SBoian Petkantchin template <typename SourceAxes, typename TargetAxes>
441a8fb887SBoian Petkantchin static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
451a8fb887SBoian Petkantchin                                      const TargetAxes &targetAxes) {
461a8fb887SBoian Petkantchin   return llvm::all_of(targetAxes, [&sourceAxes](auto &targetAxis) {
471a8fb887SBoian Petkantchin     return sourceAxes.contains(targetAxis);
481a8fb887SBoian Petkantchin   });
491a8fb887SBoian Petkantchin }
501a8fb887SBoian Petkantchin 
511a8fb887SBoian Petkantchin // Return the reduced value and its corresponding sharding.
521a8fb887SBoian Petkantchin // Example:
531a8fb887SBoian Petkantchin // sourceSharding = <@mesh_1d, [[0]], partial = sum[0]>
541a8fb887SBoian Petkantchin // targetSharding = <@mesh_1d, [[]]>
551a8fb887SBoian Petkantchin // Then will apply all-reduce on the source value
561a8fb887SBoian Petkantchin // and return it with the sharding <@mesh_1d, [[0]]>.
57baabcb28SFrank Schlimbach static std::tuple<TypedValue<ShapedType>, MeshSharding>
581a8fb887SBoian Petkantchin handlePartialAxesDuringResharding(OpBuilder &builder,
59baabcb28SFrank Schlimbach                                   MeshSharding sourceSharding,
60baabcb28SFrank Schlimbach                                   MeshSharding targetSharding,
611a8fb887SBoian Petkantchin                                   TypedValue<ShapedType> sourceShard) {
621a8fb887SBoian Petkantchin   if (sourceSharding.getPartialAxes().empty() &&
631a8fb887SBoian Petkantchin       targetSharding.getPartialAxes().empty()) {
641a8fb887SBoian Petkantchin     return {sourceShard, sourceSharding};
651a8fb887SBoian Petkantchin   }
661a8fb887SBoian Petkantchin   assert(targetSharding.getPartialAxes().empty() ||
671a8fb887SBoian Petkantchin          (!sourceSharding.getPartialAxes().empty() &&
681a8fb887SBoian Petkantchin           sourceSharding.getPartialType() == targetSharding.getPartialType()));
691a8fb887SBoian Petkantchin   using Axis = std::decay_t<decltype(sourceSharding.getPartialAxes().front())>;
701a8fb887SBoian Petkantchin   using AxisSet = llvm::SmallDenseSet<Axis>;
711a8fb887SBoian Petkantchin   AxisSet sourceShardingPartialAxesSet(sourceSharding.getPartialAxes().begin(),
721a8fb887SBoian Petkantchin                                        sourceSharding.getPartialAxes().end());
731a8fb887SBoian Petkantchin   AxisSet targetShardingPartialAxesSet(targetSharding.getPartialAxes().begin(),
741a8fb887SBoian Petkantchin                                        targetSharding.getPartialAxes().end());
751a8fb887SBoian Petkantchin   assert(arePartialAxesCompatible(sourceShardingPartialAxesSet,
761a8fb887SBoian Petkantchin                                   targetShardingPartialAxesSet));
771a8fb887SBoian Petkantchin   llvm::SmallVector<MeshAxis> allReduceMeshAxes;
781a8fb887SBoian Petkantchin   llvm::copy_if(sourceShardingPartialAxesSet,
791a8fb887SBoian Petkantchin                 std::back_inserter(allReduceMeshAxes),
801a8fb887SBoian Petkantchin                 [&targetShardingPartialAxesSet](Axis a) {
811a8fb887SBoian Petkantchin                   return !targetShardingPartialAxesSet.contains(a);
821a8fb887SBoian Petkantchin                 });
831a8fb887SBoian Petkantchin   if (allReduceMeshAxes.empty()) {
841a8fb887SBoian Petkantchin     return {sourceShard, sourceSharding};
851a8fb887SBoian Petkantchin   }
861a8fb887SBoian Petkantchin 
871a8fb887SBoian Petkantchin   builder.setInsertionPointAfterValue(sourceShard);
88a5757c5bSChristian Sigg   TypedValue<ShapedType> resultValue = cast<TypedValue<ShapedType>>(
891a8fb887SBoian Petkantchin       builder
901a8fb887SBoian Petkantchin           .create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
91baabcb28SFrank Schlimbach                                sourceSharding.getMeshAttr().getLeafReference(),
921a8fb887SBoian Petkantchin                                allReduceMeshAxes, sourceShard,
931a8fb887SBoian Petkantchin                                sourceSharding.getPartialType())
94a5757c5bSChristian Sigg           .getResult());
951a8fb887SBoian Petkantchin 
967a4c4975SBoian Petkantchin   llvm::SmallVector<MeshAxis> remainingPartialAxes;
971a8fb887SBoian Petkantchin   llvm::copy_if(sourceShardingPartialAxesSet,
981a8fb887SBoian Petkantchin                 std::back_inserter(allReduceMeshAxes),
991a8fb887SBoian Petkantchin                 [&targetShardingPartialAxesSet](Axis a) {
1001a8fb887SBoian Petkantchin                   return targetShardingPartialAxesSet.contains(a);
1011a8fb887SBoian Petkantchin                 });
102baabcb28SFrank Schlimbach   MeshSharding resultSharding = MeshSharding::get(
103baabcb28SFrank Schlimbach       sourceSharding.getMeshAttr(), sourceSharding.getSplitAxes(),
104baabcb28SFrank Schlimbach       remainingPartialAxes, sourceSharding.getPartialType());
1051a8fb887SBoian Petkantchin   return {resultValue, resultSharding};
1061a8fb887SBoian Petkantchin }
1071a8fb887SBoian Petkantchin 
108baabcb28SFrank Schlimbach static MeshSharding targetShardingInSplitLastAxis(MLIRContext *ctx,
109baabcb28SFrank Schlimbach                                                   MeshSharding sourceSharding,
110baabcb28SFrank Schlimbach                                                   int64_t splitTensorAxis,
111baabcb28SFrank Schlimbach                                                   MeshAxis splitMeshAxis) {
1127a4c4975SBoian Petkantchin   SmallVector<MeshAxesAttr> targetShardingSplitAxes =
1131a8fb887SBoian Petkantchin       llvm::to_vector(sourceSharding.getSplitAxes());
1141a8fb887SBoian Petkantchin   while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
1151a8fb887SBoian Petkantchin          splitTensorAxis) {
1167a4c4975SBoian Petkantchin     targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
1171a8fb887SBoian Petkantchin   }
1181a8fb887SBoian Petkantchin   auto targetSplitAxes =
1191a8fb887SBoian Petkantchin       llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
1201a8fb887SBoian Petkantchin   targetSplitAxes.push_back(splitMeshAxis);
1211a8fb887SBoian Petkantchin   targetShardingSplitAxes[splitTensorAxis] =
1227a4c4975SBoian Petkantchin       MeshAxesAttr::get(ctx, targetSplitAxes);
123baabcb28SFrank Schlimbach   return MeshSharding::get(
124baabcb28SFrank Schlimbach       sourceSharding.getMeshAttr(), targetShardingSplitAxes,
1251a8fb887SBoian Petkantchin       sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
1261a8fb887SBoian Petkantchin }
1271a8fb887SBoian Petkantchin 
1281a8fb887SBoian Petkantchin // Split a replicated tensor along a mesh axis.
129ffc7feadSFrank Schlimbach // E.g. [[0, 1]] -> [[0, 1, 2]].
1301a8fb887SBoian Petkantchin // Returns the spmdized target value with its sharding.
131baabcb28SFrank Schlimbach static std::tuple<TypedValue<ShapedType>, MeshSharding>
1321a8fb887SBoian Petkantchin splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
133baabcb28SFrank Schlimbach                           MeshSharding sourceSharding,
1349a8437f5SBoian Petkantchin                           TypedValue<ShapedType> sourceShard, MeshOp mesh,
1351a8fb887SBoian Petkantchin                           int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
136a5757c5bSChristian Sigg   TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
1371a8fb887SBoian Petkantchin       builder
138dc3258c6SBoian Petkantchin           .create<AllSliceOp>(sourceShard, mesh,
139dc3258c6SBoian Petkantchin                               ArrayRef<MeshAxis>(splitMeshAxis),
140dc3258c6SBoian Petkantchin                               splitTensorAxis)
141a5757c5bSChristian Sigg           .getResult());
142baabcb28SFrank Schlimbach   MeshSharding targetSharding = targetShardingInSplitLastAxis(
143dc3258c6SBoian Petkantchin       builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
144dc3258c6SBoian Petkantchin   return {targetShard, targetSharding};
1451a8fb887SBoian Petkantchin }
1461a8fb887SBoian Petkantchin 
1471a8fb887SBoian Petkantchin // Detect if the resharding is of type e.g.
1481a8fb887SBoian Petkantchin // [[0, 1]] -> [[0, 1, 2]].
1491a8fb887SBoian Petkantchin // If detected, returns the corresponding tensor axis mesh axis pair.
1501a8fb887SBoian Petkantchin // Does not detect insertions like
1511a8fb887SBoian Petkantchin // [[0, 1]] -> [[0, 2, 1]].
1521a8fb887SBoian Petkantchin static std::optional<std::tuple<int64_t, MeshAxis>>
153baabcb28SFrank Schlimbach detectSplitLastAxisInResharding(MeshSharding sourceSharding,
154baabcb28SFrank Schlimbach                                 MeshSharding targetSharding) {
1551a8fb887SBoian Petkantchin   for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size();
1561a8fb887SBoian Petkantchin        ++tensorAxis) {
1571a8fb887SBoian Petkantchin     if (sourceSharding.getSplitAxes().size() > tensorAxis) {
1581a8fb887SBoian Petkantchin       if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 !=
1591a8fb887SBoian Petkantchin           targetSharding.getSplitAxes()[tensorAxis].size()) {
1601a8fb887SBoian Petkantchin         continue;
1611a8fb887SBoian Petkantchin       }
1621a8fb887SBoian Petkantchin       if (!llvm::equal(
1631a8fb887SBoian Petkantchin               sourceSharding.getSplitAxes()[tensorAxis].asArrayRef(),
1641a8fb887SBoian Petkantchin               llvm::make_range(
1651a8fb887SBoian Petkantchin                   targetSharding.getSplitAxes()[tensorAxis]
1661a8fb887SBoian Petkantchin                       .asArrayRef()
1671a8fb887SBoian Petkantchin                       .begin(),
1681a8fb887SBoian Petkantchin                   targetSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
1691a8fb887SBoian Petkantchin                       1))) {
1701a8fb887SBoian Petkantchin         continue;
1711a8fb887SBoian Petkantchin       }
1721a8fb887SBoian Petkantchin     } else {
1731a8fb887SBoian Petkantchin       if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) {
1741a8fb887SBoian Petkantchin         continue;
1751a8fb887SBoian Petkantchin       }
1761a8fb887SBoian Petkantchin     }
1771a8fb887SBoian Petkantchin     return std::make_tuple(
1781a8fb887SBoian Petkantchin         tensorAxis,
1791a8fb887SBoian Petkantchin         targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
1801a8fb887SBoian Petkantchin   }
1811a8fb887SBoian Petkantchin   return std::nullopt;
1821a8fb887SBoian Petkantchin }
1831a8fb887SBoian Petkantchin 
184baabcb28SFrank Schlimbach static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
1859a8437f5SBoian Petkantchin trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
186baabcb28SFrank Schlimbach                              MeshSharding sourceSharding,
187baabcb28SFrank Schlimbach                              MeshSharding targetSharding,
1881a8fb887SBoian Petkantchin                              TypedValue<ShapedType> sourceShard) {
1891a8fb887SBoian Petkantchin   if (auto detectRes =
1901a8fb887SBoian Petkantchin           detectSplitLastAxisInResharding(sourceSharding, targetSharding)) {
1911a8fb887SBoian Petkantchin     auto [tensorAxis, meshAxis] = detectRes.value();
1921a8fb887SBoian Petkantchin     return splitLastAxisInResharding(builder, sourceSharding, sourceShard, mesh,
1931a8fb887SBoian Petkantchin                                      tensorAxis, meshAxis);
1941a8fb887SBoian Petkantchin   }
1951a8fb887SBoian Petkantchin 
1961a8fb887SBoian Petkantchin   return std::nullopt;
1971a8fb887SBoian Petkantchin }
1981a8fb887SBoian Petkantchin 
1991a8fb887SBoian Petkantchin // Detect if the resharding is of type e.g.
2001a8fb887SBoian Petkantchin // [[0, 1, 2]] -> [[0, 1]].
2011a8fb887SBoian Petkantchin // If detected, returns the corresponding tensor axis mesh axis pair.
2021a8fb887SBoian Petkantchin static std::optional<std::tuple<int64_t, MeshAxis>>
203baabcb28SFrank Schlimbach detectUnsplitLastAxisInResharding(MeshSharding sourceSharding,
204baabcb28SFrank Schlimbach                                   MeshSharding targetSharding) {
2051a8fb887SBoian Petkantchin   for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
2061a8fb887SBoian Petkantchin        ++tensorAxis) {
2071a8fb887SBoian Petkantchin     if (targetSharding.getSplitAxes().size() > tensorAxis) {
2081a8fb887SBoian Petkantchin       if (sourceSharding.getSplitAxes()[tensorAxis].size() !=
2091a8fb887SBoian Petkantchin           targetSharding.getSplitAxes()[tensorAxis].size() + 1)
2101a8fb887SBoian Petkantchin         continue;
2111a8fb887SBoian Petkantchin       if (!llvm::equal(
2121a8fb887SBoian Petkantchin               llvm::make_range(
2131a8fb887SBoian Petkantchin                   sourceSharding.getSplitAxes()[tensorAxis]
2141a8fb887SBoian Petkantchin                       .asArrayRef()
2151a8fb887SBoian Petkantchin                       .begin(),
2161a8fb887SBoian Petkantchin                   sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
2171a8fb887SBoian Petkantchin                       1),
2181a8fb887SBoian Petkantchin               targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
2191a8fb887SBoian Petkantchin         continue;
2201a8fb887SBoian Petkantchin     } else {
2211a8fb887SBoian Petkantchin       if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1)
2221a8fb887SBoian Petkantchin         continue;
2231a8fb887SBoian Petkantchin     }
2241a8fb887SBoian Petkantchin     return std::make_tuple(
2251a8fb887SBoian Petkantchin         tensorAxis,
2261a8fb887SBoian Petkantchin         sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
2271a8fb887SBoian Petkantchin   }
2281a8fb887SBoian Petkantchin   return std::nullopt;
2291a8fb887SBoian Petkantchin }
2301a8fb887SBoian Petkantchin 
231baabcb28SFrank Schlimbach static MeshSharding targetShardingInUnsplitLastAxis(MLIRContext *ctx,
232baabcb28SFrank Schlimbach                                                     MeshSharding sourceSharding,
2331a8fb887SBoian Petkantchin                                                     int64_t splitTensorAxis) {
2347a4c4975SBoian Petkantchin   SmallVector<MeshAxesAttr> targetShardingSplitAxes =
2351a8fb887SBoian Petkantchin       llvm::to_vector(sourceSharding.getSplitAxes());
2361a8fb887SBoian Petkantchin   assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
2371a8fb887SBoian Petkantchin          splitTensorAxis);
2381a8fb887SBoian Petkantchin   auto targetSplitAxes =
2391a8fb887SBoian Petkantchin       llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
2401a8fb887SBoian Petkantchin 
2411a8fb887SBoian Petkantchin   targetSplitAxes.pop_back();
2421a8fb887SBoian Petkantchin   targetShardingSplitAxes[splitTensorAxis] =
2437a4c4975SBoian Petkantchin       MeshAxesAttr::get(ctx, targetSplitAxes);
244baabcb28SFrank Schlimbach   return MeshSharding::get(
245baabcb28SFrank Schlimbach       sourceSharding.getMeshAttr(), targetShardingSplitAxes,
2461a8fb887SBoian Petkantchin       sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
2471a8fb887SBoian Petkantchin }
2481a8fb887SBoian Petkantchin 
2491a8fb887SBoian Petkantchin static ShapedType allGatherResultShapeInUnsplitLastAxis(
2501a8fb887SBoian Petkantchin     ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
2511a8fb887SBoian Petkantchin   SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
2521a8fb887SBoian Petkantchin   targetShape[splitTensorAxis] =
253adbf21f1SBoian Petkantchin       gatherDimension(targetShape[splitTensorAxis], splitCount);
2541a8fb887SBoian Petkantchin   return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
2551a8fb887SBoian Petkantchin }
2561a8fb887SBoian Petkantchin 
257baabcb28SFrank Schlimbach static std::tuple<TypedValue<ShapedType>, MeshSharding>
2581a8fb887SBoian Petkantchin unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
259baabcb28SFrank Schlimbach                             MeshSharding sourceSharding,
2601a8fb887SBoian Petkantchin                             ShapedType sourceUnshardedShape,
2619a8437f5SBoian Petkantchin                             TypedValue<ShapedType> sourceShard, MeshOp mesh,
2621a8fb887SBoian Petkantchin                             int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
2631a8fb887SBoian Petkantchin   MLIRContext *ctx = builder.getContext();
2641a8fb887SBoian Petkantchin   builder.setInsertionPointAfterValue(sourceShard);
2651a8fb887SBoian Petkantchin 
266baabcb28SFrank Schlimbach   MeshSharding targetSharding =
26701a429c4SArda Unal       targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis);
2681a8fb887SBoian Petkantchin   ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
2695df2c00aSBoian Petkantchin       sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
2701a8fb887SBoian Petkantchin   Value allGatherResult = builder.create<AllGatherOp>(
2711a8fb887SBoian Petkantchin       RankedTensorType::get(allGatherResultShape.getShape(),
2721a8fb887SBoian Petkantchin                             allGatherResultShape.getElementType()),
2731a8fb887SBoian Petkantchin       mesh.getSymName(), SmallVector<MeshAxis>({splitMeshAxis}), sourceShard,
2741a8fb887SBoian Petkantchin       APInt(64, splitTensorAxis));
2751a8fb887SBoian Petkantchin   ShapedType targetShape =
2761a8fb887SBoian Petkantchin       shardShapedType(sourceUnshardedShape, mesh, targetSharding);
277a5757c5bSChristian Sigg   TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
278a5757c5bSChristian Sigg       builder.create<tensor::CastOp>(targetShape, allGatherResult).getResult());
2791a8fb887SBoian Petkantchin   return {targetShard, targetSharding};
2801a8fb887SBoian Petkantchin }
2811a8fb887SBoian Petkantchin 
282baabcb28SFrank Schlimbach static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
2839a8437f5SBoian Petkantchin tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
284baabcb28SFrank Schlimbach                                MeshSharding sourceSharding,
285baabcb28SFrank Schlimbach                                MeshSharding targetSharding,
2861a8fb887SBoian Petkantchin                                ShapedType sourceUnshardedShape,
2871a8fb887SBoian Petkantchin                                TypedValue<ShapedType> sourceShard) {
2881a8fb887SBoian Petkantchin   if (auto detectRes =
2891a8fb887SBoian Petkantchin           detectUnsplitLastAxisInResharding(sourceSharding, targetSharding)) {
2901a8fb887SBoian Petkantchin     auto [tensorAxis, meshAxis] = detectRes.value();
2911a8fb887SBoian Petkantchin     return unsplitLastAxisInResharding(builder, sourceSharding,
2921a8fb887SBoian Petkantchin                                        sourceUnshardedShape, sourceShard, mesh,
2931a8fb887SBoian Petkantchin                                        tensorAxis, meshAxis);
2941a8fb887SBoian Petkantchin   }
2951a8fb887SBoian Petkantchin 
2961a8fb887SBoian Petkantchin   return std::nullopt;
2971a8fb887SBoian Petkantchin }
2981a8fb887SBoian Petkantchin 
2991a8fb887SBoian Petkantchin // Detect if the resharding is of type e.g.
3001a8fb887SBoian Petkantchin // [[0, 1], [2]] -> [[0], [1, 2]].
3011a8fb887SBoian Petkantchin // Only moving the last axis counts.
3021a8fb887SBoian Petkantchin // If detected, returns the corresponding (source_tensor_axis,
3031a8fb887SBoian Petkantchin // target_tensor_axis, mesh_axis) tuple.
3041a8fb887SBoian Petkantchin static std::optional<std::tuple<int64_t, int64_t, MeshAxis>>
305baabcb28SFrank Schlimbach detectMoveLastSplitAxisInResharding(MeshSharding sourceSharding,
306baabcb28SFrank Schlimbach                                     MeshSharding targetSharding) {
3071a8fb887SBoian Petkantchin   for (size_t sourceTensorAxis = 0;
3081a8fb887SBoian Petkantchin        sourceTensorAxis < sourceSharding.getSplitAxes().size();
3091a8fb887SBoian Petkantchin        ++sourceTensorAxis) {
3101a8fb887SBoian Petkantchin     for (size_t targetTensorAxis = 0;
3111a8fb887SBoian Petkantchin          targetTensorAxis < targetSharding.getSplitAxes().size();
3121a8fb887SBoian Petkantchin          ++targetTensorAxis) {
3131a8fb887SBoian Petkantchin       if (sourceTensorAxis == targetTensorAxis)
3141a8fb887SBoian Petkantchin         continue;
3151a8fb887SBoian Petkantchin       if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||
3161a8fb887SBoian Petkantchin           targetSharding.getSplitAxes()[targetTensorAxis].empty() ||
3171a8fb887SBoian Petkantchin           sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
3181a8fb887SBoian Petkantchin               targetSharding.getSplitAxes()[targetTensorAxis]
3191a8fb887SBoian Petkantchin                   .asArrayRef()
3201a8fb887SBoian Petkantchin                   .back())
3211a8fb887SBoian Petkantchin         continue;
3221a8fb887SBoian Petkantchin       if (!llvm::equal(
3231a8fb887SBoian Petkantchin               llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis]
3241a8fb887SBoian Petkantchin                                    .asArrayRef()
3251a8fb887SBoian Petkantchin                                    .begin(),
3261a8fb887SBoian Petkantchin                                sourceSharding.getSplitAxes()[sourceTensorAxis]
3271a8fb887SBoian Petkantchin                                        .asArrayRef()
3281a8fb887SBoian Petkantchin                                        .end() -
3291a8fb887SBoian Petkantchin                                    1),
3301a8fb887SBoian Petkantchin               llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis]
3311a8fb887SBoian Petkantchin                                    .asArrayRef()
3321a8fb887SBoian Petkantchin                                    .begin(),
3331a8fb887SBoian Petkantchin                                targetSharding.getSplitAxes()[targetTensorAxis]
3341a8fb887SBoian Petkantchin                                        .asArrayRef()
3351a8fb887SBoian Petkantchin                                        .end() -
3361a8fb887SBoian Petkantchin                                    1)))
3371a8fb887SBoian Petkantchin         continue;
3381a8fb887SBoian Petkantchin       return std::make_tuple(
3391a8fb887SBoian Petkantchin           sourceTensorAxis, targetTensorAxis,
3401a8fb887SBoian Petkantchin           sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back());
3411a8fb887SBoian Petkantchin     }
3421a8fb887SBoian Petkantchin   }
3431a8fb887SBoian Petkantchin   return std::nullopt;
3441a8fb887SBoian Petkantchin }
3451a8fb887SBoian Petkantchin 
346baabcb28SFrank Schlimbach static MeshSharding targetShardingInMoveLastAxis(MLIRContext *ctx,
347baabcb28SFrank Schlimbach                                                  MeshSharding sourceSharding,
3481a8fb887SBoian Petkantchin                                                  int64_t sourceTensorAxis,
3491a8fb887SBoian Petkantchin                                                  int64_t targetTensorAxis) {
3507a4c4975SBoian Petkantchin   SmallVector<MeshAxesAttr> targetShardingSplitAxes =
3511a8fb887SBoian Petkantchin       llvm::to_vector(sourceSharding.getSplitAxes());
3521a8fb887SBoian Petkantchin   while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
3531a8fb887SBoian Petkantchin          targetTensorAxis) {
3547a4c4975SBoian Petkantchin     targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
3551a8fb887SBoian Petkantchin   }
3561a8fb887SBoian Petkantchin 
3571a8fb887SBoian Petkantchin   auto sourceSplitAxes =
3581a8fb887SBoian Petkantchin       llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
3591a8fb887SBoian Petkantchin   assert(!sourceSplitAxes.empty());
3601a8fb887SBoian Petkantchin   auto meshAxis = sourceSplitAxes.back();
3611a8fb887SBoian Petkantchin   sourceSplitAxes.pop_back();
3621a8fb887SBoian Petkantchin   targetShardingSplitAxes[sourceTensorAxis] =
3637a4c4975SBoian Petkantchin       MeshAxesAttr::get(ctx, sourceSplitAxes);
3641a8fb887SBoian Petkantchin 
3651a8fb887SBoian Petkantchin   auto targetSplitAxes =
3661a8fb887SBoian Petkantchin       llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
3671a8fb887SBoian Petkantchin   targetSplitAxes.push_back(meshAxis);
3681a8fb887SBoian Petkantchin   targetShardingSplitAxes[targetTensorAxis] =
3697a4c4975SBoian Petkantchin       MeshAxesAttr::get(ctx, targetSplitAxes);
3701a8fb887SBoian Petkantchin 
371baabcb28SFrank Schlimbach   return MeshSharding::get(
372baabcb28SFrank Schlimbach       sourceSharding.getMeshAttr(), targetShardingSplitAxes,
3731a8fb887SBoian Petkantchin       sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
3741a8fb887SBoian Petkantchin }
3751a8fb887SBoian Petkantchin 
3761a8fb887SBoian Petkantchin static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
3771a8fb887SBoian Petkantchin                                                     int64_t splitCount,
3781a8fb887SBoian Petkantchin                                                     int64_t sourceTensorAxis,
3791a8fb887SBoian Petkantchin                                                     int64_t targetTensorAxis) {
3801a8fb887SBoian Petkantchin   SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
3811a8fb887SBoian Petkantchin   targetShape[sourceTensorAxis] =
382adbf21f1SBoian Petkantchin       gatherDimension(targetShape[sourceTensorAxis], splitCount);
3831a8fb887SBoian Petkantchin   targetShape[targetTensorAxis] =
3841a8fb887SBoian Petkantchin       shardDimension(targetShape[targetTensorAxis], splitCount);
3851a8fb887SBoian Petkantchin   return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
3861a8fb887SBoian Petkantchin }
3871a8fb887SBoian Petkantchin 
388baabcb28SFrank Schlimbach static std::tuple<TypedValue<ShapedType>, MeshSharding>
3899a8437f5SBoian Petkantchin moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
390baabcb28SFrank Schlimbach                               MeshSharding sourceSharding,
3911a8fb887SBoian Petkantchin                               ShapedType sourceUnshardedShape,
3921a8fb887SBoian Petkantchin                               TypedValue<ShapedType> sourceShard,
3931a8fb887SBoian Petkantchin                               int64_t sourceTensorAxis,
3941a8fb887SBoian Petkantchin                               int64_t targetTensorAxis, MeshAxis meshAxis) {
3951a8fb887SBoian Petkantchin   MLIRContext *ctx = builder.getContext();
3961a8fb887SBoian Petkantchin   builder.setInsertionPointAfterValue(sourceShard);
3971a8fb887SBoian Petkantchin 
398baabcb28SFrank Schlimbach   MeshSharding targetSharding = targetShardingInMoveLastAxis(
3991a8fb887SBoian Petkantchin       ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
4001a8fb887SBoian Petkantchin   ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
4015df2c00aSBoian Petkantchin       sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis,
4025df2c00aSBoian Petkantchin       targetTensorAxis);
4031a8fb887SBoian Petkantchin   Value allToAllResult = builder.create<AllToAllOp>(
4041a8fb887SBoian Petkantchin       RankedTensorType::get(allToAllResultShape.getShape(),
4051a8fb887SBoian Petkantchin                             allToAllResultShape.getElementType()),
4061a8fb887SBoian Petkantchin       mesh.getSymName(), SmallVector<MeshAxis>({meshAxis}), sourceShard,
4071a8fb887SBoian Petkantchin       APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
4081a8fb887SBoian Petkantchin   ShapedType targetShape =
4091a8fb887SBoian Petkantchin       shardShapedType(sourceUnshardedShape, mesh, targetSharding);
410a5757c5bSChristian Sigg   TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
411a5757c5bSChristian Sigg       builder.create<tensor::CastOp>(targetShape, allToAllResult).getResult());
4121a8fb887SBoian Petkantchin   return {targetShard, targetSharding};
4131a8fb887SBoian Petkantchin }
4141a8fb887SBoian Petkantchin 
415baabcb28SFrank Schlimbach static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
4169a8437f5SBoian Petkantchin tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
417baabcb28SFrank Schlimbach                                  MeshSharding sourceSharding,
418baabcb28SFrank Schlimbach                                  MeshSharding targetSharding,
4191a8fb887SBoian Petkantchin                                  ShapedType sourceUnshardedShape,
4201a8fb887SBoian Petkantchin                                  TypedValue<ShapedType> sourceShard) {
4211a8fb887SBoian Petkantchin   if (auto detectRes =
4221a8fb887SBoian Petkantchin           detectMoveLastSplitAxisInResharding(sourceSharding, targetSharding)) {
4231a8fb887SBoian Petkantchin     auto [sourceTensorAxis, targetTensorAxis, meshAxis] = detectRes.value();
4241a8fb887SBoian Petkantchin     return moveLastSplitAxisInResharding(
4251a8fb887SBoian Petkantchin         builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard,
4261a8fb887SBoian Petkantchin         sourceTensorAxis, targetTensorAxis, meshAxis);
4271a8fb887SBoian Petkantchin   }
4281a8fb887SBoian Petkantchin 
4291a8fb887SBoian Petkantchin   return std::nullopt;
4301a8fb887SBoian Petkantchin }
4311a8fb887SBoian Petkantchin 
432ffc7feadSFrank Schlimbach // Detect a change in the halo size (only) and create necessary operations if
433ffc7feadSFrank Schlimbach // needed. A changed halo sizes requires copying the "core" of the source tensor
434ffc7feadSFrank Schlimbach // into the "core" of the destination tensor followed by an update halo
435ffc7feadSFrank Schlimbach // operation.
436ffc7feadSFrank Schlimbach static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
437ffc7feadSFrank Schlimbach tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
438ffc7feadSFrank Schlimbach                           MeshSharding sourceSharding,
439ffc7feadSFrank Schlimbach                           MeshSharding targetSharding,
440ffc7feadSFrank Schlimbach                           ShapedType sourceUnshardedShape,
441ffc7feadSFrank Schlimbach                           TypedValue<ShapedType> sourceShard) {
442ffc7feadSFrank Schlimbach   // Currently handles only cases where halo sizes differ but everything else
443ffc7feadSFrank Schlimbach   // stays the same (from source to destination sharding).
444ffc7feadSFrank Schlimbach   if (!sourceSharding.equalSplitAndPartialAxes(targetSharding) ||
445ffc7feadSFrank Schlimbach       !sourceSharding.getPartialAxes().empty() ||
446ffc7feadSFrank Schlimbach       !targetSharding.getPartialAxes().empty() ||
447ffc7feadSFrank Schlimbach       !sourceSharding.getStaticShardedDimsOffsets().empty() ||
448ffc7feadSFrank Schlimbach       !targetSharding.getStaticShardedDimsOffsets().empty() ||
449ffc7feadSFrank Schlimbach       sourceSharding.equalHaloSizes(targetSharding)) {
450ffc7feadSFrank Schlimbach     return std::nullopt;
451ffc7feadSFrank Schlimbach   }
452ffc7feadSFrank Schlimbach 
453ffc7feadSFrank Schlimbach   auto srcHaloSizes = sourceSharding.getStaticHaloSizes();
454ffc7feadSFrank Schlimbach   auto tgtHaloSizes = targetSharding.getStaticHaloSizes();
455ffc7feadSFrank Schlimbach   assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size());
456ffc7feadSFrank Schlimbach   assert(((srcHaloSizes.empty() || !ShapedType::isDynamicShape(srcHaloSizes)) &&
457ffc7feadSFrank Schlimbach           !ShapedType::isDynamicShape(tgtHaloSizes) &&
458ffc7feadSFrank Schlimbach           sourceShard.getType().hasStaticShape()) &&
459ffc7feadSFrank Schlimbach          "dynamic shapes/halos are not supported yet for mesh-spmdization");
460ffc7feadSFrank Schlimbach   auto rank = sourceShard.getType().getRank();
461ffc7feadSFrank Schlimbach   auto splitAxes = sourceSharding.getSplitAxes();
462ffc7feadSFrank Schlimbach   SmallVector<int64_t> srcCoreOffs(rank, 0), tgtCoreOffs(rank, 0),
463ffc7feadSFrank Schlimbach       strides(rank, 1), outShape(sourceShard.getType().getShape()),
464ffc7feadSFrank Schlimbach       coreShape(sourceShard.getType().getShape());
465ffc7feadSFrank Schlimbach 
466ffc7feadSFrank Schlimbach   // Determine "core" of source and destination.
467ffc7feadSFrank Schlimbach   // The core is the local part of the shard excluding halo regions.
468ffc7feadSFrank Schlimbach   for (auto i = 0u; i < rank; ++i) {
469ffc7feadSFrank Schlimbach     if (i < splitAxes.size() && !splitAxes[i].empty()) {
470ffc7feadSFrank Schlimbach       if (!srcHaloSizes.empty()) {
471ffc7feadSFrank Schlimbach         coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1];
472ffc7feadSFrank Schlimbach         srcCoreOffs[i] = srcHaloSizes[i * 2];
473ffc7feadSFrank Schlimbach       }
474ffc7feadSFrank Schlimbach       tgtCoreOffs[i] = tgtHaloSizes[i * 2];
475ffc7feadSFrank Schlimbach       outShape[i] =
476ffc7feadSFrank Schlimbach           coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1];
477ffc7feadSFrank Schlimbach     }
478ffc7feadSFrank Schlimbach   }
479ffc7feadSFrank Schlimbach 
480ffc7feadSFrank Schlimbach   // Extract core from source and copy into destination core.
481ffc7feadSFrank Schlimbach   auto noVals = ValueRange{};
482ffc7feadSFrank Schlimbach   auto initVal = builder.create<tensor::EmptyOp>(
483ffc7feadSFrank Schlimbach       sourceShard.getLoc(), outShape, sourceShard.getType().getElementType());
484ffc7feadSFrank Schlimbach   auto core = builder.create<tensor::ExtractSliceOp>(
485ffc7feadSFrank Schlimbach       sourceShard.getLoc(),
486ffc7feadSFrank Schlimbach       RankedTensorType::get(coreShape, sourceShard.getType().getElementType()),
487ffc7feadSFrank Schlimbach       sourceShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides);
488ffc7feadSFrank Schlimbach   auto initOprnd = builder.create<tensor::InsertSliceOp>(
489ffc7feadSFrank Schlimbach       sourceShard.getLoc(), core, initVal, noVals, noVals, noVals, tgtCoreOffs,
490ffc7feadSFrank Schlimbach       coreShape, strides);
491ffc7feadSFrank Schlimbach 
492ffc7feadSFrank Schlimbach   // Finally update the halo.
493ffc7feadSFrank Schlimbach   auto updateHaloResult =
494ffc7feadSFrank Schlimbach       builder
495ffc7feadSFrank Schlimbach           .create<UpdateHaloOp>(
496ffc7feadSFrank Schlimbach               sourceShard.getLoc(),
497ffc7feadSFrank Schlimbach               RankedTensorType::get(outShape,
498ffc7feadSFrank Schlimbach                                     sourceShard.getType().getElementType()),
499*79eb406aSFrank Schlimbach               initOprnd, mesh.getSymName(),
500ffc7feadSFrank Schlimbach               MeshAxesArrayAttr::get(builder.getContext(),
501ffc7feadSFrank Schlimbach                                      sourceSharding.getSplitAxes()),
502ffc7feadSFrank Schlimbach               targetSharding.getDynamicHaloSizes(),
503ffc7feadSFrank Schlimbach               targetSharding.getStaticHaloSizes())
504ffc7feadSFrank Schlimbach           .getResult();
505ffc7feadSFrank Schlimbach   return std::make_tuple(cast<TypedValue<ShapedType>>(updateHaloResult),
506ffc7feadSFrank Schlimbach                          targetSharding);
507ffc7feadSFrank Schlimbach }
508ffc7feadSFrank Schlimbach 
5091a8fb887SBoian Petkantchin // Handles only resharding on a 1D mesh.
5101a8fb887SBoian Petkantchin // Currently the sharded tensor axes must be exactly divisible by the single
5111a8fb887SBoian Petkantchin // mesh axis size.
5121a8fb887SBoian Petkantchin static TypedValue<ShapedType>
5139a8437f5SBoian Petkantchin reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh,
514baabcb28SFrank Schlimbach                 MeshSharding sourceSharding, MeshSharding targetSharding,
5151a8fb887SBoian Petkantchin                 TypedValue<ShapedType> sourceUnshardedValue,
5161a8fb887SBoian Petkantchin                 TypedValue<ShapedType> sourceShard) {
5171a8fb887SBoian Petkantchin   assert(sourceShard.getType() ==
5181a8fb887SBoian Petkantchin          shardShapedType(sourceUnshardedValue.getType(), mesh, sourceSharding));
519ab43cf26SJie Fu   [[maybe_unused]] ShapedType targetShardType =
5201a8fb887SBoian Petkantchin       shardShapedType(sourceUnshardedValue.getType(), mesh, targetSharding);
5211a8fb887SBoian Petkantchin   assert(sourceShard.getType().getRank() == targetShardType.getRank());
5221a8fb887SBoian Petkantchin   assert(mesh.getRank() == 1 && "Only 1D meshes are currently supported.");
5231a8fb887SBoian Petkantchin 
5241a8fb887SBoian Petkantchin   auto [reducedSourceShard, reducedSourceSharding] =
5251a8fb887SBoian Petkantchin       handlePartialAxesDuringResharding(builder, sourceSharding, targetSharding,
5261a8fb887SBoian Petkantchin                                         sourceShard);
5271a8fb887SBoian Petkantchin 
5281a8fb887SBoian Petkantchin   if (reducedSourceSharding == targetSharding) {
5291a8fb887SBoian Petkantchin     return reducedSourceShard;
5301a8fb887SBoian Petkantchin   }
5311a8fb887SBoian Petkantchin 
5321a8fb887SBoian Petkantchin   TypedValue<ShapedType> targetShard;
533baabcb28SFrank Schlimbach   MeshSharding actualTargetSharding;
534ffc7feadSFrank Schlimbach   if (reducedSourceSharding.getStaticShardedDimsOffsets().empty() &&
535ffc7feadSFrank Schlimbach       targetSharding.getStaticShardedDimsOffsets().empty() &&
536ffc7feadSFrank Schlimbach       reducedSourceSharding.getStaticHaloSizes().empty() &&
537ffc7feadSFrank Schlimbach       targetSharding.getStaticHaloSizes().empty()) {
5381a8fb887SBoian Petkantchin     if (auto tryRes = tryMoveLastSplitAxisInResharding(
5391a8fb887SBoian Petkantchin             builder, mesh, reducedSourceSharding, targetSharding,
5401a8fb887SBoian Petkantchin             sourceUnshardedValue.getType(), reducedSourceShard)) {
5411a8fb887SBoian Petkantchin       std::tie(targetShard, actualTargetSharding) = tryRes.value();
5421a8fb887SBoian Petkantchin     } else if (auto tryRes = trySplitLastAxisInResharding(
5431a8fb887SBoian Petkantchin                    builder, mesh, reducedSourceSharding, targetSharding,
5441a8fb887SBoian Petkantchin                    reducedSourceShard)) {
5451a8fb887SBoian Petkantchin       std::tie(targetShard, actualTargetSharding) = tryRes.value();
5461a8fb887SBoian Petkantchin     } else if (auto tryRes = tryUnsplitLastAxisInResharding(
5471a8fb887SBoian Petkantchin                    builder, mesh, reducedSourceSharding, targetSharding,
5481a8fb887SBoian Petkantchin                    sourceUnshardedValue.getType(), reducedSourceShard)) {
5491a8fb887SBoian Petkantchin       std::tie(targetShard, actualTargetSharding) = tryRes.value();
5501a8fb887SBoian Petkantchin     }
551baabcb28SFrank Schlimbach   }
552baabcb28SFrank Schlimbach   assert(targetShard && "Did not find any pattern to apply.");
5531a8fb887SBoian Petkantchin   assert(actualTargetSharding == targetSharding);
5541a8fb887SBoian Petkantchin   assert(targetShard.getType() == targetShardType);
5551a8fb887SBoian Petkantchin   return targetShard;
5561a8fb887SBoian Petkantchin }
5571a8fb887SBoian Petkantchin 
5589a8437f5SBoian Petkantchin TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
559baabcb28SFrank Schlimbach                                MeshSharding sourceSharding,
560baabcb28SFrank Schlimbach                                MeshSharding targetSharding,
5611a8fb887SBoian Petkantchin                                TypedValue<ShapedType> sourceUnshardedValue,
5621a8fb887SBoian Petkantchin                                TypedValue<ShapedType> sourceShard) {
563ffc7feadSFrank Schlimbach   // If source and destination sharding are the same, no need to do anything.
564ffc7feadSFrank Schlimbach   if (sourceSharding == targetSharding) {
565ffc7feadSFrank Schlimbach     return sourceShard;
566ffc7feadSFrank Schlimbach   }
567ffc7feadSFrank Schlimbach 
568ffc7feadSFrank Schlimbach   // Tries to handle the case where the resharding is needed because the halo
569ffc7feadSFrank Schlimbach   // sizes are different. Supports arbitrary mesh dimensionality.
570ffc7feadSFrank Schlimbach   if (auto tryRes = tryUpdateHaloInResharding(
571ffc7feadSFrank Schlimbach           builder, mesh, sourceSharding, targetSharding,
572ffc7feadSFrank Schlimbach           sourceUnshardedValue.getType(), sourceShard)) {
573ffc7feadSFrank Schlimbach     return std::get<0>(tryRes.value()); // targetShard
574ffc7feadSFrank Schlimbach   }
575ffc7feadSFrank Schlimbach 
5761a8fb887SBoian Petkantchin   // Resort to handling only 1D meshes since the general case is complicated if
5771a8fb887SBoian Petkantchin   // it needs to be communication efficient in terms of minimizing the data
5781a8fb887SBoian Petkantchin   // transfered between devices.
5791a8fb887SBoian Petkantchin   return reshardOn1DMesh(builder, mesh, sourceSharding, targetSharding,
5801a8fb887SBoian Petkantchin                          sourceUnshardedValue, sourceShard);
5811a8fb887SBoian Petkantchin }
5821a8fb887SBoian Petkantchin 
5839a8437f5SBoian Petkantchin TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
5849a8437f5SBoian Petkantchin                                ShardOp target,
5851a8fb887SBoian Petkantchin                                TypedValue<ShapedType> sourceShardValue) {
586baabcb28SFrank Schlimbach   assert(source.getResult() == target.getSrc());
587baabcb28SFrank Schlimbach   auto sourceSharding = source.getSharding();
588baabcb28SFrank Schlimbach   auto targetSharding = target.getSharding();
5891a8fb887SBoian Petkantchin   ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
590baabcb28SFrank Schlimbach   return reshard(implicitLocOpBuilder, mesh, sourceSharding, targetSharding,
591baabcb28SFrank Schlimbach                  cast<TypedValue<ShapedType>>(source.getSrc()),
592baabcb28SFrank Schlimbach                  sourceShardValue);
5931a8fb887SBoian Petkantchin }
5941a8fb887SBoian Petkantchin 
595adbf21f1SBoian Petkantchin TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
596adbf21f1SBoian Petkantchin                                ShardOp target,
597adbf21f1SBoian Petkantchin                                TypedValue<ShapedType> sourceShardValue,
598adbf21f1SBoian Petkantchin                                SymbolTableCollection &symbolTableCollection) {
599adbf21f1SBoian Petkantchin   MeshOp srcMesh = getMesh(source, symbolTableCollection);
600adbf21f1SBoian Petkantchin   assert(srcMesh && srcMesh == getMesh(target, symbolTableCollection));
601adbf21f1SBoian Petkantchin   return reshard(builder, srcMesh, source, target, sourceShardValue);
602adbf21f1SBoian Petkantchin }
603adbf21f1SBoian Petkantchin 
6041a8fb887SBoian Petkantchin void reshardingRegisterDependentDialects(DialectRegistry &registry) {
605dc3258c6SBoian Petkantchin   registry.insert<mesh::MeshDialect, tensor::TensorDialect>();
6061a8fb887SBoian Petkantchin }
6071a8fb887SBoian Petkantchin 
608adbf21f1SBoian Petkantchin #define GEN_PASS_DEF_SPMDIZATION
609adbf21f1SBoian Petkantchin #include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
610adbf21f1SBoian Petkantchin 
// Alias for a map from values to values.
// NOTE(review): judging by the name, keys are presumably values of the
// unsharded program and mapped values their spmdized counterparts. The
// functions visible in this part of the file use IRMapping for that purpose;
// confirm whether this alias is still needed.
using UnshardedToShardedValueMap = DenseMap<Value, Value>;
612adbf21f1SBoian Petkantchin 
613adbf21f1SBoian Petkantchin // Get the types of block arguments for an spmdized block.
614adbf21f1SBoian Petkantchin // Reads the sharding annotations of the arguments to deduce the sharded types.
615adbf21f1SBoian Petkantchin // Types that are not ranked tensors are left unchanged.
616adbf21f1SBoian Petkantchin SmallVector<Type>
617adbf21f1SBoian Petkantchin shardedBlockArgumentTypes(Block &block,
618adbf21f1SBoian Petkantchin                           SymbolTableCollection &symbolTableCollection) {
619adbf21f1SBoian Petkantchin   SmallVector<Type> res;
620a5757c5bSChristian Sigg   llvm::transform(
621a5757c5bSChristian Sigg       block.getArguments(), std::back_inserter(res),
622adbf21f1SBoian Petkantchin       [&symbolTableCollection](BlockArgument arg) {
623a5757c5bSChristian Sigg         auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
624adbf21f1SBoian Petkantchin         if (!rankedTensorArg) {
625adbf21f1SBoian Petkantchin           return arg.getType();
626adbf21f1SBoian Petkantchin         }
627adbf21f1SBoian Petkantchin 
628adbf21f1SBoian Petkantchin         assert(rankedTensorArg.hasOneUse());
629adbf21f1SBoian Petkantchin         Operation *useOp = *rankedTensorArg.getUsers().begin();
630adbf21f1SBoian Petkantchin         ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
631adbf21f1SBoian Petkantchin         assert(shardOp);
632adbf21f1SBoian Petkantchin         MeshOp mesh = getMesh(shardOp, symbolTableCollection);
633a5757c5bSChristian Sigg         return cast<Type>(shardShapedType(rankedTensorArg.getType(), mesh,
634baabcb28SFrank Schlimbach                                           shardOp.getSharding()));
635adbf21f1SBoian Petkantchin       });
636adbf21f1SBoian Petkantchin   return res;
637adbf21f1SBoian Petkantchin }
638adbf21f1SBoian Petkantchin 
// Forward declaration; the definition is not part of this section of the file.
// NOTE(review): the parameter list has the same types, in the same order, as
// the static spmdizeOperation overload that takes explicit operand/result
// shardings. The exact contract should be confirmed against the definition.
void spmdizeTriviallyShardableOperation(Operation &op,
                                        ArrayRef<Value> spmdizedOperands,
                                        ArrayRef<MeshSharding> operandShardings,
                                        ArrayRef<MeshSharding> resultShardings,
                                        IRMapping &spmdizationMap,
                                        SymbolTableCollection &symbolTable,
                                        OpBuilder &builder);
646baabcb28SFrank Schlimbach 
647adbf21f1SBoian Petkantchin static LogicalResult spmdizeOperation(
648adbf21f1SBoian Petkantchin     Operation &op, ArrayRef<Value> spmdizedOperands,
649baabcb28SFrank Schlimbach     ArrayRef<MeshSharding> operandShardings,
650baabcb28SFrank Schlimbach     ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
651adbf21f1SBoian Petkantchin     SymbolTableCollection &symbolTableCollection, OpBuilder &builder) {
652adbf21f1SBoian Petkantchin   ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
653adbf21f1SBoian Petkantchin   if (!shardingInterface) {
654adbf21f1SBoian Petkantchin     // If there is no sharding interface we are conservative and assume that
655adbf21f1SBoian Petkantchin     // the op should be fully replicated no all devices.
656adbf21f1SBoian Petkantchin     spmdizeFullyReplicatedOperation(op, spmdizedOperands, operandShardings,
657adbf21f1SBoian Petkantchin                                     resultShardings, spmdizationMap,
658adbf21f1SBoian Petkantchin                                     symbolTableCollection, builder);
659adbf21f1SBoian Petkantchin   } else {
660adbf21f1SBoian Petkantchin     if (failed(shardingInterface.spmdize(spmdizedOperands, operandShardings,
661adbf21f1SBoian Petkantchin                                          resultShardings, spmdizationMap,
662adbf21f1SBoian Petkantchin                                          symbolTableCollection, builder))) {
663adbf21f1SBoian Petkantchin       return failure();
664adbf21f1SBoian Petkantchin     }
665adbf21f1SBoian Petkantchin   }
666adbf21f1SBoian Petkantchin 
667adbf21f1SBoian Petkantchin   assert(llvm::all_of(op.getResults(), [&spmdizationMap](OpResult result) {
668adbf21f1SBoian Petkantchin     return spmdizationMap.contains(result);
669adbf21f1SBoian Petkantchin   }));
670adbf21f1SBoian Petkantchin 
671adbf21f1SBoian Petkantchin   return success();
672adbf21f1SBoian Petkantchin }
673adbf21f1SBoian Petkantchin 
674adbf21f1SBoian Petkantchin // Retrieve the sharding annotations for the operands of the given operation.
675adbf21f1SBoian Petkantchin // If the type is not a ranked tensor it is not require to have an annotation.
676baabcb28SFrank Schlimbach static std::vector<MeshSharding> getOperandShardings(Operation &op) {
677baabcb28SFrank Schlimbach   std::vector<MeshSharding> res;
678adbf21f1SBoian Petkantchin   res.reserve(op.getNumOperands());
679adbf21f1SBoian Petkantchin   llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
680adbf21f1SBoian Petkantchin     TypedValue<RankedTensorType> rankedTensor =
681a5757c5bSChristian Sigg         dyn_cast<TypedValue<RankedTensorType>>(operand);
682adbf21f1SBoian Petkantchin     if (!rankedTensor) {
683baabcb28SFrank Schlimbach       return MeshSharding();
684adbf21f1SBoian Petkantchin     }
685adbf21f1SBoian Petkantchin 
686adbf21f1SBoian Petkantchin     Operation *definingOp = operand.getDefiningOp();
687adbf21f1SBoian Petkantchin     assert(definingOp);
688adbf21f1SBoian Petkantchin     ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
689baabcb28SFrank Schlimbach     return MeshSharding(shardOp.getSharding());
690adbf21f1SBoian Petkantchin   });
691adbf21f1SBoian Petkantchin   return res;
692adbf21f1SBoian Petkantchin }
693adbf21f1SBoian Petkantchin 
694adbf21f1SBoian Petkantchin // Retrieve the sharding annotations for the results of the given operation.
695adbf21f1SBoian Petkantchin // If the type is not a ranked tensor it is not require to have an annotation.
696baabcb28SFrank Schlimbach static std::vector<MeshSharding> getResultShardings(Operation &op) {
697baabcb28SFrank Schlimbach   std::vector<MeshSharding> res;
698adbf21f1SBoian Petkantchin   res.reserve(op.getNumResults());
699adbf21f1SBoian Petkantchin   llvm::transform(op.getResults(), std::back_inserter(res),
700adbf21f1SBoian Petkantchin                   [](OpResult result) {
701adbf21f1SBoian Petkantchin                     TypedValue<RankedTensorType> rankedTensor =
702a5757c5bSChristian Sigg                         dyn_cast<TypedValue<RankedTensorType>>(result);
703adbf21f1SBoian Petkantchin                     if (!rankedTensor) {
704baabcb28SFrank Schlimbach                       return MeshSharding();
705adbf21f1SBoian Petkantchin                     }
706adbf21f1SBoian Petkantchin 
707adbf21f1SBoian Petkantchin                     assert(result.hasOneUse());
708adbf21f1SBoian Petkantchin                     Operation *userOp = *result.getUsers().begin();
709adbf21f1SBoian Petkantchin                     ShardOp shardOp = llvm::cast<ShardOp>(userOp);
710baabcb28SFrank Schlimbach                     return MeshSharding(shardOp.getSharding());
711adbf21f1SBoian Petkantchin                   });
712adbf21f1SBoian Petkantchin   return res;
713adbf21f1SBoian Petkantchin }
714adbf21f1SBoian Petkantchin 
715adbf21f1SBoian Petkantchin static LogicalResult
7164f7ab789SBoian Petkantchin spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
7174f7ab789SBoian Petkantchin                  SymbolTableCollection &symbolTableCollection,
7184f7ab789SBoian Petkantchin                  OpBuilder &builder) {
7194f7ab789SBoian Petkantchin   Value targetSpmdValue;
7204f7ab789SBoian Petkantchin 
7214f7ab789SBoian Petkantchin   // Check if 2 shard ops are chained. If not there is no need for resharding
7224f7ab789SBoian Petkantchin   // as the source and target shared the same sharding.
7234f7ab789SBoian Petkantchin   ShardOp srcShardOp =
724baabcb28SFrank Schlimbach       dyn_cast_or_null<ShardOp>(shardOp.getSrc().getDefiningOp());
7254f7ab789SBoian Petkantchin   if (!srcShardOp) {
726baabcb28SFrank Schlimbach     targetSpmdValue = spmdizationMap.lookup(shardOp.getSrc());
7274f7ab789SBoian Petkantchin   } else {
7284f7ab789SBoian Petkantchin     // Insert resharding.
729ffc7feadSFrank Schlimbach     TypedValue<ShapedType> srcSpmdValue =
730ffc7feadSFrank Schlimbach         cast<TypedValue<ShapedType>>(spmdizationMap.lookup(srcShardOp));
7314f7ab789SBoian Petkantchin     targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
7324f7ab789SBoian Petkantchin                               symbolTableCollection);
7334f7ab789SBoian Petkantchin   }
7344f7ab789SBoian Petkantchin 
7354f7ab789SBoian Petkantchin   assert(!spmdizationMap.contains(shardOp.getResult()));
7364f7ab789SBoian Petkantchin   spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
7374f7ab789SBoian Petkantchin   return success();
7384f7ab789SBoian Petkantchin }
7394f7ab789SBoian Petkantchin 
7404f7ab789SBoian Petkantchin static LogicalResult
741adbf21f1SBoian Petkantchin spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
742adbf21f1SBoian Petkantchin                  SymbolTableCollection &symbolTableCollection,
743adbf21f1SBoian Petkantchin                  OpBuilder &builder) {
744baabcb28SFrank Schlimbach   if (isa<ShardingOp>(op)) {
745baabcb28SFrank Schlimbach     return success();
746baabcb28SFrank Schlimbach   }
747baabcb28SFrank Schlimbach 
748adbf21f1SBoian Petkantchin   ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
749adbf21f1SBoian Petkantchin   if (shardOp) {
7504f7ab789SBoian Petkantchin     return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection,
7514f7ab789SBoian Petkantchin                             builder);
752adbf21f1SBoian Petkantchin   }
753adbf21f1SBoian Petkantchin 
754adbf21f1SBoian Petkantchin   SmallVector<Value> spmdizedOperands;
755adbf21f1SBoian Petkantchin   llvm::transform(op.getOperands(), std::back_inserter(spmdizedOperands),
756adbf21f1SBoian Petkantchin                   [&spmdizationMap](Value operand) {
757adbf21f1SBoian Petkantchin                     assert(spmdizationMap.contains(operand));
758adbf21f1SBoian Petkantchin                     return spmdizationMap.lookup(operand);
759adbf21f1SBoian Petkantchin                   });
760adbf21f1SBoian Petkantchin   return spmdizeOperation(op, spmdizedOperands, getOperandShardings(op),
761adbf21f1SBoian Petkantchin                           getResultShardings(op), spmdizationMap,
762adbf21f1SBoian Petkantchin                           symbolTableCollection, builder);
763adbf21f1SBoian Petkantchin }
764adbf21f1SBoian Petkantchin 
765adbf21f1SBoian Petkantchin static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
766adbf21f1SBoian Petkantchin                                   SymbolTableCollection &symbolTableCollection,
767adbf21f1SBoian Petkantchin                                   OpBuilder &builder) {
768adbf21f1SBoian Petkantchin   SmallVector<Location> argLocations;
769adbf21f1SBoian Petkantchin   llvm::transform(block.getArguments(), std::back_inserter(argLocations),
770adbf21f1SBoian Petkantchin                   [](BlockArgument arg) { return arg.getLoc(); });
771adbf21f1SBoian Petkantchin   Block *newBlock = builder.createBlock(
772adbf21f1SBoian Petkantchin       block.getParent(), {},
773adbf21f1SBoian Petkantchin       shardedBlockArgumentTypes(block, symbolTableCollection), argLocations);
774adbf21f1SBoian Petkantchin   for (auto [unshardedBlockArg, spmdizedBlockArg] :
775adbf21f1SBoian Petkantchin        llvm::zip(block.getArguments(), newBlock->getArguments())) {
776adbf21f1SBoian Petkantchin     spmdizationMap.map(unshardedBlockArg, spmdizedBlockArg);
777adbf21f1SBoian Petkantchin   }
778adbf21f1SBoian Petkantchin 
779adbf21f1SBoian Petkantchin   OpBuilder::InsertionGuard insertionGuard(builder);
780adbf21f1SBoian Petkantchin   builder.setInsertionPointToEnd(newBlock);
781adbf21f1SBoian Petkantchin   for (Operation &op : block.getOperations()) {
782adbf21f1SBoian Petkantchin     if (failed(spmdizeOperation(op, spmdizationMap, symbolTableCollection,
783adbf21f1SBoian Petkantchin                                 builder))) {
784adbf21f1SBoian Petkantchin       return failure();
785adbf21f1SBoian Petkantchin     }
786adbf21f1SBoian Petkantchin   }
787adbf21f1SBoian Petkantchin 
788adbf21f1SBoian Petkantchin   return success();
789adbf21f1SBoian Petkantchin }
790adbf21f1SBoian Petkantchin 
791adbf21f1SBoian Petkantchin static LogicalResult
792abfac563SBoian Petkantchin spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
793adbf21f1SBoian Petkantchin               SymbolTableCollection &symbolTableCollection) {
794adbf21f1SBoian Petkantchin   OpBuilder builder(op.getFunctionBody());
795adbf21f1SBoian Petkantchin 
796adbf21f1SBoian Petkantchin   // Snapshot the original blocks to not mess up the iteration when adding new
797adbf21f1SBoian Petkantchin   // blocks.
798adbf21f1SBoian Petkantchin   SmallVector<Block *> originalBlocks;
799adbf21f1SBoian Petkantchin   llvm::transform(op.getBlocks(), std::back_inserter(originalBlocks),
800adbf21f1SBoian Petkantchin                   [](Block &b) { return &b; });
801adbf21f1SBoian Petkantchin 
802adbf21f1SBoian Petkantchin   for (Block *block : originalBlocks) {
803adbf21f1SBoian Petkantchin     if (failed(spmdizeBlock(*block, spmdizationMap, symbolTableCollection,
804adbf21f1SBoian Petkantchin                             builder))) {
805adbf21f1SBoian Petkantchin       return failure();
806adbf21f1SBoian Petkantchin     }
807adbf21f1SBoian Petkantchin   }
808adbf21f1SBoian Petkantchin 
809adbf21f1SBoian Petkantchin   for (Block *block : originalBlocks) {
810adbf21f1SBoian Petkantchin     block->erase();
811adbf21f1SBoian Petkantchin   }
812adbf21f1SBoian Petkantchin 
813adbf21f1SBoian Petkantchin   // Find a return op and change the function results signature to its operands
814adbf21f1SBoian Petkantchin   // signature.
815abfac563SBoian Petkantchin   Operation *returnOp = nullptr;
816abfac563SBoian Petkantchin   for (Block &block : op.getFunctionBody()) {
817adbf21f1SBoian Petkantchin     if (block.empty()) {
818adbf21f1SBoian Petkantchin       continue;
819adbf21f1SBoian Petkantchin     }
820adbf21f1SBoian Petkantchin 
821abfac563SBoian Petkantchin     if (block.back().hasTrait<OpTrait::ReturnLike>()) {
822abfac563SBoian Petkantchin       returnOp = &block.back();
823adbf21f1SBoian Petkantchin       break;
824adbf21f1SBoian Petkantchin     }
825adbf21f1SBoian Petkantchin   }
826adbf21f1SBoian Petkantchin   assert(returnOp);
827abfac563SBoian Petkantchin   op.setType(FunctionType::get(op->getContext(),
828abfac563SBoian Petkantchin                                op.getFunctionBody().front().getArgumentTypes(),
829adbf21f1SBoian Petkantchin                                returnOp->getOperandTypes()));
830adbf21f1SBoian Petkantchin 
831adbf21f1SBoian Petkantchin   return success();
832adbf21f1SBoian Petkantchin }
833adbf21f1SBoian Petkantchin 
834adbf21f1SBoian Petkantchin namespace {
835adbf21f1SBoian Petkantchin 
836adbf21f1SBoian Petkantchin struct Spmdization : public impl::SpmdizationBase<Spmdization> {
837adbf21f1SBoian Petkantchin   void runOnOperation() override {
838adbf21f1SBoian Petkantchin     IRMapping spmdizationMap;
839adbf21f1SBoian Petkantchin     SymbolTableCollection symbolTableCollection;
840adbf21f1SBoian Petkantchin     if (failed(spmdizeFuncOp(getOperation(), spmdizationMap,
841adbf21f1SBoian Petkantchin                              symbolTableCollection))) {
842adbf21f1SBoian Petkantchin       return signalPassFailure();
843adbf21f1SBoian Petkantchin     }
844adbf21f1SBoian Petkantchin   }
845adbf21f1SBoian Petkantchin 
846adbf21f1SBoian Petkantchin   void getDependentDialects(DialectRegistry &registry) const override {
847adbf21f1SBoian Petkantchin     reshardingRegisterDependentDialects(registry);
848adbf21f1SBoian Petkantchin     registry.insert<mesh::MeshDialect>();
849adbf21f1SBoian Petkantchin   }
850adbf21f1SBoian Petkantchin };
851adbf21f1SBoian Petkantchin 
852adbf21f1SBoian Petkantchin } // namespace
853adbf21f1SBoian Petkantchin 
854adbf21f1SBoian Petkantchin } // namespace mlir::mesh
855