11a8fb887SBoian Petkantchin //===- Spmdization.cpp --------------------------------------------- C++ --===// 21a8fb887SBoian Petkantchin // 31a8fb887SBoian Petkantchin // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 41a8fb887SBoian Petkantchin // See https://llvm.org/LICENSE.txt for license information. 51a8fb887SBoian Petkantchin // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 61a8fb887SBoian Petkantchin // 71a8fb887SBoian Petkantchin //===----------------------------------------------------------------------===// 81a8fb887SBoian Petkantchin 91a8fb887SBoian Petkantchin #include "mlir/Dialect/Mesh/Transforms/Spmdization.h" 10adbf21f1SBoian Petkantchin 11adbf21f1SBoian Petkantchin #include "mlir/Dialect/Func/IR/FuncOps.h" 1231fc0a12SBoian Petkantchin #include "mlir/Dialect/Mesh/IR/MeshDialect.h" 131a8fb887SBoian Petkantchin #include "mlir/Dialect/Mesh/IR/MeshOps.h" 14adbf21f1SBoian Petkantchin #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" 151a8fb887SBoian Petkantchin #include "mlir/Dialect/Tensor/IR/Tensor.h" 161a8fb887SBoian Petkantchin #include "mlir/IR/Builders.h" 171a8fb887SBoian Petkantchin #include "mlir/IR/BuiltinAttributes.h" 181a8fb887SBoian Petkantchin #include "mlir/IR/BuiltinTypeInterfaces.h" 191a8fb887SBoian Petkantchin #include "mlir/IR/BuiltinTypes.h" 20adbf21f1SBoian Petkantchin #include "mlir/IR/Diagnostics.h" 21adbf21f1SBoian Petkantchin #include "mlir/IR/IRMapping.h" 221a8fb887SBoian Petkantchin #include "mlir/IR/ImplicitLocOpBuilder.h" 231a8fb887SBoian Petkantchin #include "mlir/IR/Location.h" 241a8fb887SBoian Petkantchin #include "mlir/IR/MLIRContext.h" 25adbf21f1SBoian Petkantchin #include "mlir/IR/SymbolTable.h" 261a8fb887SBoian Petkantchin #include "mlir/IR/Value.h" 27abfac563SBoian Petkantchin #include "mlir/Interfaces/ControlFlowInterfaces.h" 28abfac563SBoian Petkantchin #include "mlir/Interfaces/FunctionInterfaces.h" 29adbf21f1SBoian Petkantchin #include "mlir/Pass/Pass.h" 301a8fb887SBoian Petkantchin #include "mlir/Support/LLVM.h" 311a8fb887SBoian Petkantchin #include "llvm/ADT/APInt.h" 321a8fb887SBoian Petkantchin #include "llvm/ADT/DenseSet.h" 331a8fb887SBoian Petkantchin #include "llvm/ADT/STLExtras.h" 341a8fb887SBoian Petkantchin #include "llvm/ADT/SmallVector.h" 35adbf21f1SBoian Petkantchin #include "llvm/Support/Casting.h" 361a8fb887SBoian Petkantchin #include <iterator> 371a8fb887SBoian Petkantchin #include <optional> 381a8fb887SBoian Petkantchin #include <tuple> 391a8fb887SBoian Petkantchin #include <type_traits> 401a8fb887SBoian Petkantchin 41adbf21f1SBoian Petkantchin namespace mlir::mesh { 421a8fb887SBoian Petkantchin 431a8fb887SBoian Petkantchin template <typename SourceAxes, typename TargetAxes> 441a8fb887SBoian Petkantchin static bool arePartialAxesCompatible(const SourceAxes &sourceAxes, 451a8fb887SBoian Petkantchin const TargetAxes &targetAxes) { 461a8fb887SBoian Petkantchin return llvm::all_of(targetAxes, [&sourceAxes](auto &targetAxis) { 471a8fb887SBoian Petkantchin return sourceAxes.contains(targetAxis); 481a8fb887SBoian Petkantchin }); 491a8fb887SBoian Petkantchin } 501a8fb887SBoian Petkantchin 511a8fb887SBoian Petkantchin // Return the reduced value and its corresponding sharding. 521a8fb887SBoian Petkantchin // Example: 531a8fb887SBoian Petkantchin // sourceSharding = <@mesh_1d, [[0]], partial = sum[0]> 541a8fb887SBoian Petkantchin // targetSharding = <@mesh_1d, [[]]> 551a8fb887SBoian Petkantchin // Then will apply all-reduce on the source value 561a8fb887SBoian Petkantchin // and return it with the sharding <@mesh_1d, [[0]]>. 57baabcb28SFrank Schlimbach static std::tuple<TypedValue<ShapedType>, MeshSharding> 581a8fb887SBoian Petkantchin handlePartialAxesDuringResharding(OpBuilder &builder, 59baabcb28SFrank Schlimbach MeshSharding sourceSharding, 60baabcb28SFrank Schlimbach MeshSharding targetSharding, 611a8fb887SBoian Petkantchin TypedValue<ShapedType> sourceShard) { 621a8fb887SBoian Petkantchin if (sourceSharding.getPartialAxes().empty() && 631a8fb887SBoian Petkantchin targetSharding.getPartialAxes().empty()) { 641a8fb887SBoian Petkantchin return {sourceShard, sourceSharding}; 651a8fb887SBoian Petkantchin } 661a8fb887SBoian Petkantchin assert(targetSharding.getPartialAxes().empty() || 671a8fb887SBoian Petkantchin (!sourceSharding.getPartialAxes().empty() && 681a8fb887SBoian Petkantchin sourceSharding.getPartialType() == targetSharding.getPartialType())); 691a8fb887SBoian Petkantchin using Axis = std::decay_t<decltype(sourceSharding.getPartialAxes().front())>; 701a8fb887SBoian Petkantchin using AxisSet = llvm::SmallDenseSet<Axis>; 711a8fb887SBoian Petkantchin AxisSet sourceShardingPartialAxesSet(sourceSharding.getPartialAxes().begin(), 721a8fb887SBoian Petkantchin sourceSharding.getPartialAxes().end()); 731a8fb887SBoian Petkantchin AxisSet targetShardingPartialAxesSet(targetSharding.getPartialAxes().begin(), 741a8fb887SBoian Petkantchin targetSharding.getPartialAxes().end()); 751a8fb887SBoian Petkantchin assert(arePartialAxesCompatible(sourceShardingPartialAxesSet, 761a8fb887SBoian Petkantchin targetShardingPartialAxesSet)); 771a8fb887SBoian Petkantchin llvm::SmallVector<MeshAxis> allReduceMeshAxes; 781a8fb887SBoian Petkantchin llvm::copy_if(sourceShardingPartialAxesSet, 791a8fb887SBoian Petkantchin std::back_inserter(allReduceMeshAxes), 801a8fb887SBoian Petkantchin [&targetShardingPartialAxesSet](Axis a) { 811a8fb887SBoian Petkantchin return !targetShardingPartialAxesSet.contains(a); 821a8fb887SBoian Petkantchin }); 831a8fb887SBoian Petkantchin if (allReduceMeshAxes.empty()) { 841a8fb887SBoian Petkantchin return {sourceShard, sourceSharding}; 851a8fb887SBoian Petkantchin } 861a8fb887SBoian Petkantchin 871a8fb887SBoian Petkantchin builder.setInsertionPointAfterValue(sourceShard); 88a5757c5bSChristian Sigg TypedValue<ShapedType> resultValue = cast<TypedValue<ShapedType>>( 891a8fb887SBoian Petkantchin builder 901a8fb887SBoian Petkantchin .create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(), 91baabcb28SFrank Schlimbach sourceSharding.getMeshAttr().getLeafReference(), 921a8fb887SBoian Petkantchin allReduceMeshAxes, sourceShard, 931a8fb887SBoian Petkantchin sourceSharding.getPartialType()) 94a5757c5bSChristian Sigg .getResult()); 951a8fb887SBoian Petkantchin 967a4c4975SBoian Petkantchin llvm::SmallVector<MeshAxis> remainingPartialAxes; 971a8fb887SBoian Petkantchin llvm::copy_if(sourceShardingPartialAxesSet, 981a8fb887SBoian Petkantchin std::back_inserter(allReduceMeshAxes), 991a8fb887SBoian Petkantchin [&targetShardingPartialAxesSet](Axis a) { 1001a8fb887SBoian Petkantchin return targetShardingPartialAxesSet.contains(a); 1011a8fb887SBoian Petkantchin }); 102baabcb28SFrank Schlimbach MeshSharding resultSharding = MeshSharding::get( 103baabcb28SFrank Schlimbach sourceSharding.getMeshAttr(), sourceSharding.getSplitAxes(), 104baabcb28SFrank Schlimbach remainingPartialAxes, sourceSharding.getPartialType()); 1051a8fb887SBoian Petkantchin return {resultValue, resultSharding}; 1061a8fb887SBoian Petkantchin } 1071a8fb887SBoian Petkantchin 108baabcb28SFrank Schlimbach static MeshSharding targetShardingInSplitLastAxis(MLIRContext *ctx, 109baabcb28SFrank Schlimbach MeshSharding sourceSharding, 110baabcb28SFrank Schlimbach int64_t splitTensorAxis, 111baabcb28SFrank Schlimbach MeshAxis splitMeshAxis) { 1127a4c4975SBoian Petkantchin SmallVector<MeshAxesAttr> targetShardingSplitAxes = 1131a8fb887SBoian Petkantchin llvm::to_vector(sourceSharding.getSplitAxes()); 1141a8fb887SBoian Petkantchin while (static_cast<int64_t>(targetShardingSplitAxes.size()) <= 1151a8fb887SBoian Petkantchin splitTensorAxis) { 1167a4c4975SBoian Petkantchin targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {})); 1171a8fb887SBoian Petkantchin } 1181a8fb887SBoian Petkantchin auto targetSplitAxes = 1191a8fb887SBoian Petkantchin llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef()); 1201a8fb887SBoian Petkantchin targetSplitAxes.push_back(splitMeshAxis); 1211a8fb887SBoian Petkantchin targetShardingSplitAxes[splitTensorAxis] = 1227a4c4975SBoian Petkantchin MeshAxesAttr::get(ctx, targetSplitAxes); 123baabcb28SFrank Schlimbach return MeshSharding::get( 124baabcb28SFrank Schlimbach sourceSharding.getMeshAttr(), targetShardingSplitAxes, 1251a8fb887SBoian Petkantchin sourceSharding.getPartialAxes(), sourceSharding.getPartialType()); 1261a8fb887SBoian Petkantchin } 1271a8fb887SBoian Petkantchin 1281a8fb887SBoian Petkantchin // Split a replicated tensor along a mesh axis. 129ffc7feadSFrank Schlimbach // E.g. [[0, 1]] -> [[0, 1, 2]]. 1301a8fb887SBoian Petkantchin // Returns the spmdized target value with its sharding. 131baabcb28SFrank Schlimbach static std::tuple<TypedValue<ShapedType>, MeshSharding> 1321a8fb887SBoian Petkantchin splitLastAxisInResharding(ImplicitLocOpBuilder &builder, 133baabcb28SFrank Schlimbach MeshSharding sourceSharding, 1349a8437f5SBoian Petkantchin TypedValue<ShapedType> sourceShard, MeshOp mesh, 1351a8fb887SBoian Petkantchin int64_t splitTensorAxis, MeshAxis splitMeshAxis) { 136a5757c5bSChristian Sigg TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>( 1371a8fb887SBoian Petkantchin builder 138dc3258c6SBoian Petkantchin .create<AllSliceOp>(sourceShard, mesh, 139dc3258c6SBoian Petkantchin ArrayRef<MeshAxis>(splitMeshAxis), 140dc3258c6SBoian Petkantchin splitTensorAxis) 141a5757c5bSChristian Sigg .getResult()); 142baabcb28SFrank Schlimbach MeshSharding targetSharding = targetShardingInSplitLastAxis( 143dc3258c6SBoian Petkantchin builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis); 144dc3258c6SBoian Petkantchin return {targetShard, targetSharding}; 1451a8fb887SBoian Petkantchin } 1461a8fb887SBoian Petkantchin 1471a8fb887SBoian Petkantchin // Detect if the resharding is of type e.g. 1481a8fb887SBoian Petkantchin // [[0, 1]] -> [[0, 1, 2]]. 1491a8fb887SBoian Petkantchin // If detected, returns the corresponding tensor axis mesh axis pair. 1501a8fb887SBoian Petkantchin // Does not detect insertions like 1511a8fb887SBoian Petkantchin // [[0, 1]] -> [[0, 2, 1]]. 1521a8fb887SBoian Petkantchin static std::optional<std::tuple<int64_t, MeshAxis>> 153baabcb28SFrank Schlimbach detectSplitLastAxisInResharding(MeshSharding sourceSharding, 154baabcb28SFrank Schlimbach MeshSharding targetSharding) { 1551a8fb887SBoian Petkantchin for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size(); 1561a8fb887SBoian Petkantchin ++tensorAxis) { 1571a8fb887SBoian Petkantchin if (sourceSharding.getSplitAxes().size() > tensorAxis) { 1581a8fb887SBoian Petkantchin if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 != 1591a8fb887SBoian Petkantchin targetSharding.getSplitAxes()[tensorAxis].size()) { 1601a8fb887SBoian Petkantchin continue; 1611a8fb887SBoian Petkantchin } 1621a8fb887SBoian Petkantchin if (!llvm::equal( 1631a8fb887SBoian Petkantchin sourceSharding.getSplitAxes()[tensorAxis].asArrayRef(), 1641a8fb887SBoian Petkantchin llvm::make_range( 1651a8fb887SBoian Petkantchin targetSharding.getSplitAxes()[tensorAxis] 1661a8fb887SBoian Petkantchin .asArrayRef() 1671a8fb887SBoian Petkantchin .begin(), 1681a8fb887SBoian Petkantchin targetSharding.getSplitAxes()[tensorAxis].asArrayRef().end() - 1691a8fb887SBoian Petkantchin 1))) { 1701a8fb887SBoian Petkantchin continue; 1711a8fb887SBoian Petkantchin } 1721a8fb887SBoian Petkantchin } else { 1731a8fb887SBoian Petkantchin if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) { 1741a8fb887SBoian Petkantchin continue; 1751a8fb887SBoian Petkantchin } 1761a8fb887SBoian Petkantchin } 1771a8fb887SBoian Petkantchin return std::make_tuple( 1781a8fb887SBoian Petkantchin tensorAxis, 1791a8fb887SBoian Petkantchin targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back()); 1801a8fb887SBoian Petkantchin } 1811a8fb887SBoian Petkantchin return std::nullopt; 1821a8fb887SBoian Petkantchin } 1831a8fb887SBoian Petkantchin 184baabcb28SFrank Schlimbach static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>> 1859a8437f5SBoian Petkantchin trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, 186baabcb28SFrank Schlimbach MeshSharding sourceSharding, 187baabcb28SFrank Schlimbach MeshSharding targetSharding, 1881a8fb887SBoian Petkantchin TypedValue<ShapedType> sourceShard) { 1891a8fb887SBoian Petkantchin if (auto detectRes = 1901a8fb887SBoian Petkantchin detectSplitLastAxisInResharding(sourceSharding, targetSharding)) { 1911a8fb887SBoian Petkantchin auto [tensorAxis, meshAxis] = detectRes.value(); 1921a8fb887SBoian Petkantchin return splitLastAxisInResharding(builder, sourceSharding, sourceShard, mesh, 1931a8fb887SBoian Petkantchin tensorAxis, meshAxis); 1941a8fb887SBoian Petkantchin } 1951a8fb887SBoian Petkantchin 1961a8fb887SBoian Petkantchin return std::nullopt; 1971a8fb887SBoian Petkantchin } 1981a8fb887SBoian Petkantchin 1991a8fb887SBoian Petkantchin // Detect if the resharding is of type e.g. 2001a8fb887SBoian Petkantchin // [[0, 1, 2]] -> [[0, 1]]. 2011a8fb887SBoian Petkantchin // If detected, returns the corresponding tensor axis mesh axis pair. 2021a8fb887SBoian Petkantchin static std::optional<std::tuple<int64_t, MeshAxis>> 203baabcb28SFrank Schlimbach detectUnsplitLastAxisInResharding(MeshSharding sourceSharding, 204baabcb28SFrank Schlimbach MeshSharding targetSharding) { 2051a8fb887SBoian Petkantchin for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size(); 2061a8fb887SBoian Petkantchin ++tensorAxis) { 2071a8fb887SBoian Petkantchin if (targetSharding.getSplitAxes().size() > tensorAxis) { 2081a8fb887SBoian Petkantchin if (sourceSharding.getSplitAxes()[tensorAxis].size() != 2091a8fb887SBoian Petkantchin targetSharding.getSplitAxes()[tensorAxis].size() + 1) 2101a8fb887SBoian Petkantchin continue; 2111a8fb887SBoian Petkantchin if (!llvm::equal( 2121a8fb887SBoian Petkantchin llvm::make_range( 2131a8fb887SBoian Petkantchin sourceSharding.getSplitAxes()[tensorAxis] 2141a8fb887SBoian Petkantchin .asArrayRef() 2151a8fb887SBoian Petkantchin .begin(), 2161a8fb887SBoian Petkantchin sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() - 2171a8fb887SBoian Petkantchin 1), 2181a8fb887SBoian Petkantchin targetSharding.getSplitAxes()[tensorAxis].asArrayRef())) 2191a8fb887SBoian Petkantchin continue; 2201a8fb887SBoian Petkantchin } else { 2211a8fb887SBoian Petkantchin if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1) 2221a8fb887SBoian Petkantchin continue; 2231a8fb887SBoian Petkantchin } 2241a8fb887SBoian Petkantchin return std::make_tuple( 2251a8fb887SBoian Petkantchin tensorAxis, 2261a8fb887SBoian Petkantchin sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back()); 2271a8fb887SBoian Petkantchin } 2281a8fb887SBoian Petkantchin return std::nullopt; 2291a8fb887SBoian Petkantchin } 2301a8fb887SBoian Petkantchin 231baabcb28SFrank Schlimbach static MeshSharding targetShardingInUnsplitLastAxis(MLIRContext *ctx, 232baabcb28SFrank Schlimbach MeshSharding sourceSharding, 2331a8fb887SBoian Petkantchin int64_t splitTensorAxis) { 2347a4c4975SBoian Petkantchin SmallVector<MeshAxesAttr> targetShardingSplitAxes = 2351a8fb887SBoian Petkantchin llvm::to_vector(sourceSharding.getSplitAxes()); 2361a8fb887SBoian Petkantchin assert(static_cast<int64_t>(targetShardingSplitAxes.size()) > 2371a8fb887SBoian Petkantchin splitTensorAxis); 2381a8fb887SBoian Petkantchin auto targetSplitAxes = 2391a8fb887SBoian Petkantchin llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef()); 2401a8fb887SBoian Petkantchin 2411a8fb887SBoian Petkantchin targetSplitAxes.pop_back(); 2421a8fb887SBoian Petkantchin targetShardingSplitAxes[splitTensorAxis] = 2437a4c4975SBoian Petkantchin MeshAxesAttr::get(ctx, targetSplitAxes); 244baabcb28SFrank Schlimbach return MeshSharding::get( 245baabcb28SFrank Schlimbach sourceSharding.getMeshAttr(), targetShardingSplitAxes, 2461a8fb887SBoian Petkantchin sourceSharding.getPartialAxes(), sourceSharding.getPartialType()); 2471a8fb887SBoian Petkantchin } 2481a8fb887SBoian Petkantchin 2491a8fb887SBoian Petkantchin static ShapedType allGatherResultShapeInUnsplitLastAxis( 2501a8fb887SBoian Petkantchin ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) { 2511a8fb887SBoian Petkantchin SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape()); 2521a8fb887SBoian Petkantchin targetShape[splitTensorAxis] = 253adbf21f1SBoian Petkantchin gatherDimension(targetShape[splitTensorAxis], splitCount); 2541a8fb887SBoian Petkantchin return sourceShape.cloneWith(targetShape, sourceShape.getElementType()); 2551a8fb887SBoian Petkantchin } 2561a8fb887SBoian Petkantchin 257baabcb28SFrank Schlimbach static std::tuple<TypedValue<ShapedType>, MeshSharding> 2581a8fb887SBoian Petkantchin unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, 259baabcb28SFrank Schlimbach MeshSharding sourceSharding, 2601a8fb887SBoian Petkantchin ShapedType sourceUnshardedShape, 2619a8437f5SBoian Petkantchin TypedValue<ShapedType> sourceShard, MeshOp mesh, 2621a8fb887SBoian Petkantchin int64_t splitTensorAxis, MeshAxis splitMeshAxis) { 2631a8fb887SBoian Petkantchin MLIRContext *ctx = builder.getContext(); 2641a8fb887SBoian Petkantchin builder.setInsertionPointAfterValue(sourceShard); 2651a8fb887SBoian Petkantchin 266baabcb28SFrank Schlimbach MeshSharding targetSharding = 26701a429c4SArda Unal targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis); 2681a8fb887SBoian Petkantchin ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis( 2695df2c00aSBoian Petkantchin sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis); 2701a8fb887SBoian Petkantchin Value allGatherResult = builder.create<AllGatherOp>( 2711a8fb887SBoian Petkantchin RankedTensorType::get(allGatherResultShape.getShape(), 2721a8fb887SBoian Petkantchin allGatherResultShape.getElementType()), 2731a8fb887SBoian Petkantchin mesh.getSymName(), SmallVector<MeshAxis>({splitMeshAxis}), sourceShard, 2741a8fb887SBoian Petkantchin APInt(64, splitTensorAxis)); 2751a8fb887SBoian Petkantchin ShapedType targetShape = 2761a8fb887SBoian Petkantchin shardShapedType(sourceUnshardedShape, mesh, targetSharding); 277a5757c5bSChristian Sigg TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>( 278a5757c5bSChristian Sigg builder.create<tensor::CastOp>(targetShape, allGatherResult).getResult()); 2791a8fb887SBoian Petkantchin return {targetShard, targetSharding}; 2801a8fb887SBoian Petkantchin } 2811a8fb887SBoian Petkantchin 282baabcb28SFrank Schlimbach static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>> 2839a8437f5SBoian Petkantchin tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, 284baabcb28SFrank Schlimbach MeshSharding sourceSharding, 285baabcb28SFrank Schlimbach MeshSharding targetSharding, 2861a8fb887SBoian Petkantchin ShapedType sourceUnshardedShape, 2871a8fb887SBoian Petkantchin TypedValue<ShapedType> sourceShard) { 2881a8fb887SBoian Petkantchin if (auto detectRes = 2891a8fb887SBoian Petkantchin detectUnsplitLastAxisInResharding(sourceSharding, targetSharding)) { 2901a8fb887SBoian Petkantchin auto [tensorAxis, meshAxis] = detectRes.value(); 2911a8fb887SBoian Petkantchin return unsplitLastAxisInResharding(builder, sourceSharding, 2921a8fb887SBoian Petkantchin sourceUnshardedShape, sourceShard, mesh, 2931a8fb887SBoian Petkantchin tensorAxis, meshAxis); 2941a8fb887SBoian Petkantchin } 2951a8fb887SBoian Petkantchin 2961a8fb887SBoian Petkantchin return std::nullopt; 2971a8fb887SBoian Petkantchin } 2981a8fb887SBoian Petkantchin 2991a8fb887SBoian Petkantchin // Detect if the resharding is of type e.g. 3001a8fb887SBoian Petkantchin // [[0, 1], [2]] -> [[0], [1, 2]]. 3011a8fb887SBoian Petkantchin // Only moving the last axis counts. 3021a8fb887SBoian Petkantchin // If detected, returns the corresponding (source_tensor_axis, 3031a8fb887SBoian Petkantchin // target_tensor_axis, mesh_axis) tuple. 3041a8fb887SBoian Petkantchin static std::optional<std::tuple<int64_t, int64_t, MeshAxis>> 305baabcb28SFrank Schlimbach detectMoveLastSplitAxisInResharding(MeshSharding sourceSharding, 306baabcb28SFrank Schlimbach MeshSharding targetSharding) { 3071a8fb887SBoian Petkantchin for (size_t sourceTensorAxis = 0; 3081a8fb887SBoian Petkantchin sourceTensorAxis < sourceSharding.getSplitAxes().size(); 3091a8fb887SBoian Petkantchin ++sourceTensorAxis) { 3101a8fb887SBoian Petkantchin for (size_t targetTensorAxis = 0; 3111a8fb887SBoian Petkantchin targetTensorAxis < targetSharding.getSplitAxes().size(); 3121a8fb887SBoian Petkantchin ++targetTensorAxis) { 3131a8fb887SBoian Petkantchin if (sourceTensorAxis == targetTensorAxis) 3141a8fb887SBoian Petkantchin continue; 3151a8fb887SBoian Petkantchin if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() || 3161a8fb887SBoian Petkantchin targetSharding.getSplitAxes()[targetTensorAxis].empty() || 3171a8fb887SBoian Petkantchin sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() != 3181a8fb887SBoian Petkantchin targetSharding.getSplitAxes()[targetTensorAxis] 3191a8fb887SBoian Petkantchin .asArrayRef() 3201a8fb887SBoian Petkantchin .back()) 3211a8fb887SBoian Petkantchin continue; 3221a8fb887SBoian Petkantchin if (!llvm::equal( 3231a8fb887SBoian Petkantchin llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis] 3241a8fb887SBoian Petkantchin .asArrayRef() 3251a8fb887SBoian Petkantchin .begin(), 3261a8fb887SBoian Petkantchin sourceSharding.getSplitAxes()[sourceTensorAxis] 3271a8fb887SBoian Petkantchin .asArrayRef() 3281a8fb887SBoian Petkantchin .end() - 3291a8fb887SBoian Petkantchin 1), 3301a8fb887SBoian Petkantchin llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis] 3311a8fb887SBoian Petkantchin .asArrayRef() 3321a8fb887SBoian Petkantchin .begin(), 3331a8fb887SBoian Petkantchin targetSharding.getSplitAxes()[targetTensorAxis] 3341a8fb887SBoian Petkantchin .asArrayRef() 3351a8fb887SBoian Petkantchin .end() - 3361a8fb887SBoian Petkantchin 1))) 3371a8fb887SBoian Petkantchin continue; 3381a8fb887SBoian Petkantchin return std::make_tuple( 3391a8fb887SBoian Petkantchin sourceTensorAxis, targetTensorAxis, 3401a8fb887SBoian Petkantchin sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back()); 3411a8fb887SBoian Petkantchin } 3421a8fb887SBoian Petkantchin } 3431a8fb887SBoian Petkantchin return std::nullopt; 3441a8fb887SBoian Petkantchin } 3451a8fb887SBoian Petkantchin 346baabcb28SFrank Schlimbach static MeshSharding targetShardingInMoveLastAxis(MLIRContext *ctx, 347baabcb28SFrank Schlimbach MeshSharding sourceSharding, 3481a8fb887SBoian Petkantchin int64_t sourceTensorAxis, 3491a8fb887SBoian Petkantchin int64_t targetTensorAxis) { 3507a4c4975SBoian Petkantchin SmallVector<MeshAxesAttr> targetShardingSplitAxes = 3511a8fb887SBoian Petkantchin llvm::to_vector(sourceSharding.getSplitAxes()); 3521a8fb887SBoian Petkantchin while (static_cast<int64_t>(targetShardingSplitAxes.size()) <= 3531a8fb887SBoian Petkantchin targetTensorAxis) { 3547a4c4975SBoian Petkantchin targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {})); 3551a8fb887SBoian Petkantchin } 3561a8fb887SBoian Petkantchin 3571a8fb887SBoian Petkantchin auto sourceSplitAxes = 3581a8fb887SBoian Petkantchin llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef()); 3591a8fb887SBoian Petkantchin assert(!sourceSplitAxes.empty()); 3601a8fb887SBoian Petkantchin auto meshAxis = sourceSplitAxes.back(); 3611a8fb887SBoian Petkantchin sourceSplitAxes.pop_back(); 3621a8fb887SBoian Petkantchin targetShardingSplitAxes[sourceTensorAxis] = 3637a4c4975SBoian Petkantchin MeshAxesAttr::get(ctx, sourceSplitAxes); 3641a8fb887SBoian Petkantchin 3651a8fb887SBoian Petkantchin auto targetSplitAxes = 3661a8fb887SBoian Petkantchin llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef()); 3671a8fb887SBoian Petkantchin targetSplitAxes.push_back(meshAxis); 3681a8fb887SBoian Petkantchin targetShardingSplitAxes[targetTensorAxis] = 3697a4c4975SBoian Petkantchin MeshAxesAttr::get(ctx, targetSplitAxes); 3701a8fb887SBoian Petkantchin 371baabcb28SFrank Schlimbach return MeshSharding::get( 372baabcb28SFrank Schlimbach sourceSharding.getMeshAttr(), targetShardingSplitAxes, 3731a8fb887SBoian Petkantchin sourceSharding.getPartialAxes(), sourceSharding.getPartialType()); 3741a8fb887SBoian Petkantchin } 3751a8fb887SBoian Petkantchin 3761a8fb887SBoian Petkantchin static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape, 3771a8fb887SBoian Petkantchin int64_t splitCount, 3781a8fb887SBoian Petkantchin int64_t sourceTensorAxis, 3791a8fb887SBoian Petkantchin int64_t targetTensorAxis) { 3801a8fb887SBoian Petkantchin SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape()); 3811a8fb887SBoian Petkantchin targetShape[sourceTensorAxis] = 382adbf21f1SBoian Petkantchin gatherDimension(targetShape[sourceTensorAxis], splitCount); 3831a8fb887SBoian Petkantchin targetShape[targetTensorAxis] = 3841a8fb887SBoian Petkantchin shardDimension(targetShape[targetTensorAxis], splitCount); 3851a8fb887SBoian Petkantchin return sourceShape.cloneWith(targetShape, sourceShape.getElementType()); 3861a8fb887SBoian Petkantchin } 3871a8fb887SBoian Petkantchin 388baabcb28SFrank Schlimbach static std::tuple<TypedValue<ShapedType>, MeshSharding> 3899a8437f5SBoian Petkantchin moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, 390baabcb28SFrank Schlimbach MeshSharding sourceSharding, 3911a8fb887SBoian Petkantchin ShapedType sourceUnshardedShape, 3921a8fb887SBoian Petkantchin TypedValue<ShapedType> sourceShard, 3931a8fb887SBoian Petkantchin int64_t sourceTensorAxis, 3941a8fb887SBoian Petkantchin int64_t targetTensorAxis, MeshAxis meshAxis) { 3951a8fb887SBoian Petkantchin MLIRContext *ctx = builder.getContext(); 3961a8fb887SBoian Petkantchin builder.setInsertionPointAfterValue(sourceShard); 3971a8fb887SBoian Petkantchin 398baabcb28SFrank Schlimbach MeshSharding targetSharding = targetShardingInMoveLastAxis( 3991a8fb887SBoian Petkantchin ctx, sourceSharding, sourceTensorAxis, targetTensorAxis); 4001a8fb887SBoian Petkantchin ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis( 4015df2c00aSBoian Petkantchin sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis, 4025df2c00aSBoian Petkantchin targetTensorAxis); 4031a8fb887SBoian Petkantchin Value allToAllResult = builder.create<AllToAllOp>( 4041a8fb887SBoian Petkantchin RankedTensorType::get(allToAllResultShape.getShape(), 4051a8fb887SBoian Petkantchin allToAllResultShape.getElementType()), 4061a8fb887SBoian Petkantchin mesh.getSymName(), SmallVector<MeshAxis>({meshAxis}), sourceShard, 4071a8fb887SBoian Petkantchin APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis)); 4081a8fb887SBoian Petkantchin ShapedType targetShape = 4091a8fb887SBoian Petkantchin shardShapedType(sourceUnshardedShape, mesh, targetSharding); 410a5757c5bSChristian Sigg TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>( 411a5757c5bSChristian Sigg builder.create<tensor::CastOp>(targetShape, allToAllResult).getResult()); 4121a8fb887SBoian Petkantchin return {targetShard, targetSharding}; 4131a8fb887SBoian Petkantchin } 4141a8fb887SBoian Petkantchin 415baabcb28SFrank Schlimbach static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>> 4169a8437f5SBoian Petkantchin tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, 417baabcb28SFrank Schlimbach MeshSharding sourceSharding, 418baabcb28SFrank Schlimbach MeshSharding targetSharding, 4191a8fb887SBoian Petkantchin ShapedType sourceUnshardedShape, 4201a8fb887SBoian Petkantchin TypedValue<ShapedType> sourceShard) { 4211a8fb887SBoian Petkantchin if (auto detectRes = 4221a8fb887SBoian Petkantchin detectMoveLastSplitAxisInResharding(sourceSharding, targetSharding)) { 4231a8fb887SBoian Petkantchin auto [sourceTensorAxis, targetTensorAxis, meshAxis] = detectRes.value(); 4241a8fb887SBoian Petkantchin return moveLastSplitAxisInResharding( 4251a8fb887SBoian Petkantchin builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard, 4261a8fb887SBoian Petkantchin sourceTensorAxis, targetTensorAxis, meshAxis); 4271a8fb887SBoian Petkantchin } 4281a8fb887SBoian Petkantchin 4291a8fb887SBoian Petkantchin return std::nullopt; 4301a8fb887SBoian Petkantchin } 4311a8fb887SBoian Petkantchin 432ffc7feadSFrank Schlimbach // Detect a change in the halo size (only) and create necessary operations if 433ffc7feadSFrank Schlimbach // needed. A changed halo sizes requires copying the "core" of the source tensor 434ffc7feadSFrank Schlimbach // into the "core" of the destination tensor followed by an update halo 435ffc7feadSFrank Schlimbach // operation. 436ffc7feadSFrank Schlimbach static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>> 437ffc7feadSFrank Schlimbach tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, 438ffc7feadSFrank Schlimbach MeshSharding sourceSharding, 439ffc7feadSFrank Schlimbach MeshSharding targetSharding, 440ffc7feadSFrank Schlimbach ShapedType sourceUnshardedShape, 441ffc7feadSFrank Schlimbach TypedValue<ShapedType> sourceShard) { 442ffc7feadSFrank Schlimbach // Currently handles only cases where halo sizes differ but everything else 443ffc7feadSFrank Schlimbach // stays the same (from source to destination sharding). 444ffc7feadSFrank Schlimbach if (!sourceSharding.equalSplitAndPartialAxes(targetSharding) || 445ffc7feadSFrank Schlimbach !sourceSharding.getPartialAxes().empty() || 446ffc7feadSFrank Schlimbach !targetSharding.getPartialAxes().empty() || 447ffc7feadSFrank Schlimbach !sourceSharding.getStaticShardedDimsOffsets().empty() || 448ffc7feadSFrank Schlimbach !targetSharding.getStaticShardedDimsOffsets().empty() || 449ffc7feadSFrank Schlimbach sourceSharding.equalHaloSizes(targetSharding)) { 450ffc7feadSFrank Schlimbach return std::nullopt; 451ffc7feadSFrank Schlimbach } 452ffc7feadSFrank Schlimbach 453ffc7feadSFrank Schlimbach auto srcHaloSizes = sourceSharding.getStaticHaloSizes(); 454ffc7feadSFrank Schlimbach auto tgtHaloSizes = targetSharding.getStaticHaloSizes(); 455ffc7feadSFrank Schlimbach assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size()); 456ffc7feadSFrank Schlimbach assert(((srcHaloSizes.empty() || !ShapedType::isDynamicShape(srcHaloSizes)) && 457ffc7feadSFrank Schlimbach !ShapedType::isDynamicShape(tgtHaloSizes) && 458ffc7feadSFrank Schlimbach sourceShard.getType().hasStaticShape()) && 459ffc7feadSFrank Schlimbach "dynamic shapes/halos are not supported yet for mesh-spmdization"); 460ffc7feadSFrank Schlimbach auto rank = sourceShard.getType().getRank(); 461ffc7feadSFrank Schlimbach auto splitAxes = sourceSharding.getSplitAxes(); 462ffc7feadSFrank Schlimbach SmallVector<int64_t> srcCoreOffs(rank, 0), tgtCoreOffs(rank, 0), 463ffc7feadSFrank Schlimbach strides(rank, 1), outShape(sourceShard.getType().getShape()), 464ffc7feadSFrank Schlimbach coreShape(sourceShard.getType().getShape()); 465ffc7feadSFrank Schlimbach 466ffc7feadSFrank Schlimbach // Determine "core" of source and destination. 467ffc7feadSFrank Schlimbach // The core is the local part of the shard excluding halo regions. 468ffc7feadSFrank Schlimbach for (auto i = 0u; i < rank; ++i) { 469ffc7feadSFrank Schlimbach if (i < splitAxes.size() && !splitAxes[i].empty()) { 470ffc7feadSFrank Schlimbach if (!srcHaloSizes.empty()) { 471ffc7feadSFrank Schlimbach coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1]; 472ffc7feadSFrank Schlimbach srcCoreOffs[i] = srcHaloSizes[i * 2]; 473ffc7feadSFrank Schlimbach } 474ffc7feadSFrank Schlimbach tgtCoreOffs[i] = tgtHaloSizes[i * 2]; 475ffc7feadSFrank Schlimbach outShape[i] = 476ffc7feadSFrank Schlimbach coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1]; 477ffc7feadSFrank Schlimbach } 478ffc7feadSFrank Schlimbach } 479ffc7feadSFrank Schlimbach 480ffc7feadSFrank Schlimbach // Extract core from source and copy into destination core. 481ffc7feadSFrank Schlimbach auto noVals = ValueRange{}; 482ffc7feadSFrank Schlimbach auto initVal = builder.create<tensor::EmptyOp>( 483ffc7feadSFrank Schlimbach sourceShard.getLoc(), outShape, sourceShard.getType().getElementType()); 484ffc7feadSFrank Schlimbach auto core = builder.create<tensor::ExtractSliceOp>( 485ffc7feadSFrank Schlimbach sourceShard.getLoc(), 486ffc7feadSFrank Schlimbach RankedTensorType::get(coreShape, sourceShard.getType().getElementType()), 487ffc7feadSFrank Schlimbach sourceShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides); 488ffc7feadSFrank Schlimbach auto initOprnd = builder.create<tensor::InsertSliceOp>( 489ffc7feadSFrank Schlimbach sourceShard.getLoc(), core, initVal, noVals, noVals, noVals, tgtCoreOffs, 490ffc7feadSFrank Schlimbach coreShape, strides); 491ffc7feadSFrank Schlimbach 492ffc7feadSFrank Schlimbach // Finally update the halo. 493ffc7feadSFrank Schlimbach auto updateHaloResult = 494ffc7feadSFrank Schlimbach builder 495ffc7feadSFrank Schlimbach .create<UpdateHaloOp>( 496ffc7feadSFrank Schlimbach sourceShard.getLoc(), 497ffc7feadSFrank Schlimbach RankedTensorType::get(outShape, 498ffc7feadSFrank Schlimbach sourceShard.getType().getElementType()), 499*79eb406aSFrank Schlimbach initOprnd, mesh.getSymName(), 500ffc7feadSFrank Schlimbach MeshAxesArrayAttr::get(builder.getContext(), 501ffc7feadSFrank Schlimbach sourceSharding.getSplitAxes()), 502ffc7feadSFrank Schlimbach targetSharding.getDynamicHaloSizes(), 503ffc7feadSFrank Schlimbach targetSharding.getStaticHaloSizes()) 504ffc7feadSFrank Schlimbach .getResult(); 505ffc7feadSFrank Schlimbach return std::make_tuple(cast<TypedValue<ShapedType>>(updateHaloResult), 506ffc7feadSFrank Schlimbach targetSharding); 507ffc7feadSFrank Schlimbach } 508ffc7feadSFrank Schlimbach 5091a8fb887SBoian Petkantchin // Handles only resharding on a 1D mesh. 5101a8fb887SBoian Petkantchin // Currently the sharded tensor axes must be exactly divisible by the single 5111a8fb887SBoian Petkantchin // mesh axis size. 5121a8fb887SBoian Petkantchin static TypedValue<ShapedType> 5139a8437f5SBoian Petkantchin reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh, 514baabcb28SFrank Schlimbach MeshSharding sourceSharding, MeshSharding targetSharding, 5151a8fb887SBoian Petkantchin TypedValue<ShapedType> sourceUnshardedValue, 5161a8fb887SBoian Petkantchin TypedValue<ShapedType> sourceShard) { 5171a8fb887SBoian Petkantchin assert(sourceShard.getType() == 5181a8fb887SBoian Petkantchin shardShapedType(sourceUnshardedValue.getType(), mesh, sourceSharding)); 519ab43cf26SJie Fu [[maybe_unused]] ShapedType targetShardType = 5201a8fb887SBoian Petkantchin shardShapedType(sourceUnshardedValue.getType(), mesh, targetSharding); 5211a8fb887SBoian Petkantchin assert(sourceShard.getType().getRank() == targetShardType.getRank()); 5221a8fb887SBoian Petkantchin assert(mesh.getRank() == 1 && "Only 1D meshes are currently supported."); 5231a8fb887SBoian Petkantchin 5241a8fb887SBoian Petkantchin auto [reducedSourceShard, reducedSourceSharding] = 5251a8fb887SBoian Petkantchin handlePartialAxesDuringResharding(builder, sourceSharding, targetSharding, 5261a8fb887SBoian Petkantchin sourceShard); 5271a8fb887SBoian Petkantchin 5281a8fb887SBoian Petkantchin if (reducedSourceSharding == targetSharding) { 5291a8fb887SBoian Petkantchin return reducedSourceShard; 5301a8fb887SBoian Petkantchin } 5311a8fb887SBoian Petkantchin 5321a8fb887SBoian Petkantchin TypedValue<ShapedType> targetShard; 533baabcb28SFrank Schlimbach MeshSharding actualTargetSharding; 534ffc7feadSFrank Schlimbach if (reducedSourceSharding.getStaticShardedDimsOffsets().empty() && 535ffc7feadSFrank Schlimbach targetSharding.getStaticShardedDimsOffsets().empty() && 536ffc7feadSFrank Schlimbach reducedSourceSharding.getStaticHaloSizes().empty() && 537ffc7feadSFrank Schlimbach targetSharding.getStaticHaloSizes().empty()) { 5381a8fb887SBoian Petkantchin if (auto tryRes = tryMoveLastSplitAxisInResharding( 5391a8fb887SBoian Petkantchin builder, mesh, reducedSourceSharding, targetSharding, 5401a8fb887SBoian Petkantchin sourceUnshardedValue.getType(), reducedSourceShard)) { 5411a8fb887SBoian Petkantchin std::tie(targetShard, actualTargetSharding) = tryRes.value(); 5421a8fb887SBoian Petkantchin } else if (auto tryRes = trySplitLastAxisInResharding( 5431a8fb887SBoian Petkantchin builder, mesh, reducedSourceSharding, targetSharding, 5441a8fb887SBoian Petkantchin reducedSourceShard)) { 5451a8fb887SBoian Petkantchin std::tie(targetShard, actualTargetSharding) = tryRes.value(); 5461a8fb887SBoian Petkantchin } else if (auto tryRes = tryUnsplitLastAxisInResharding( 5471a8fb887SBoian Petkantchin builder, mesh, reducedSourceSharding, targetSharding, 5481a8fb887SBoian Petkantchin sourceUnshardedValue.getType(), reducedSourceShard)) { 5491a8fb887SBoian Petkantchin std::tie(targetShard, actualTargetSharding) = tryRes.value(); 5501a8fb887SBoian Petkantchin } 551baabcb28SFrank Schlimbach } 552baabcb28SFrank Schlimbach assert(targetShard && "Did not find any pattern to apply."); 5531a8fb887SBoian Petkantchin assert(actualTargetSharding == targetSharding); 5541a8fb887SBoian Petkantchin assert(targetShard.getType() == targetShardType); 5551a8fb887SBoian Petkantchin return targetShard; 5561a8fb887SBoian Petkantchin } 5571a8fb887SBoian Petkantchin 5589a8437f5SBoian Petkantchin TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh, 559baabcb28SFrank Schlimbach MeshSharding sourceSharding, 560baabcb28SFrank Schlimbach MeshSharding targetSharding, 5611a8fb887SBoian Petkantchin TypedValue<ShapedType> sourceUnshardedValue, 5621a8fb887SBoian Petkantchin TypedValue<ShapedType> sourceShard) { 563ffc7feadSFrank Schlimbach // If source and destination sharding are the same, no need to do anything. 564ffc7feadSFrank Schlimbach if (sourceSharding == targetSharding) { 565ffc7feadSFrank Schlimbach return sourceShard; 566ffc7feadSFrank Schlimbach } 567ffc7feadSFrank Schlimbach 568ffc7feadSFrank Schlimbach // Tries to handle the case where the resharding is needed because the halo 569ffc7feadSFrank Schlimbach // sizes are different. Supports arbitrary mesh dimensionality. 570ffc7feadSFrank Schlimbach if (auto tryRes = tryUpdateHaloInResharding( 571ffc7feadSFrank Schlimbach builder, mesh, sourceSharding, targetSharding, 572ffc7feadSFrank Schlimbach sourceUnshardedValue.getType(), sourceShard)) { 573ffc7feadSFrank Schlimbach return std::get<0>(tryRes.value()); // targetShard 574ffc7feadSFrank Schlimbach } 575ffc7feadSFrank Schlimbach 5761a8fb887SBoian Petkantchin // Resort to handling only 1D meshes since the general case is complicated if 5771a8fb887SBoian Petkantchin // it needs to be communication efficient in terms of minimizing the data 5781a8fb887SBoian Petkantchin // transfered between devices. 5791a8fb887SBoian Petkantchin return reshardOn1DMesh(builder, mesh, sourceSharding, targetSharding, 5801a8fb887SBoian Petkantchin sourceUnshardedValue, sourceShard); 5811a8fb887SBoian Petkantchin } 5821a8fb887SBoian Petkantchin 5839a8437f5SBoian Petkantchin TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source, 5849a8437f5SBoian Petkantchin ShardOp target, 5851a8fb887SBoian Petkantchin TypedValue<ShapedType> sourceShardValue) { 586baabcb28SFrank Schlimbach assert(source.getResult() == target.getSrc()); 587baabcb28SFrank Schlimbach auto sourceSharding = source.getSharding(); 588baabcb28SFrank Schlimbach auto targetSharding = target.getSharding(); 5891a8fb887SBoian Petkantchin ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder); 590baabcb28SFrank Schlimbach return reshard(implicitLocOpBuilder, mesh, sourceSharding, targetSharding, 591baabcb28SFrank Schlimbach cast<TypedValue<ShapedType>>(source.getSrc()), 592baabcb28SFrank Schlimbach sourceShardValue); 5931a8fb887SBoian Petkantchin } 5941a8fb887SBoian Petkantchin 595adbf21f1SBoian Petkantchin TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source, 596adbf21f1SBoian Petkantchin ShardOp target, 597adbf21f1SBoian Petkantchin TypedValue<ShapedType> sourceShardValue, 598adbf21f1SBoian Petkantchin SymbolTableCollection &symbolTableCollection) { 599adbf21f1SBoian Petkantchin MeshOp srcMesh = getMesh(source, symbolTableCollection); 600adbf21f1SBoian Petkantchin assert(srcMesh && srcMesh == getMesh(target, symbolTableCollection)); 601adbf21f1SBoian Petkantchin return reshard(builder, srcMesh, source, target, sourceShardValue); 602adbf21f1SBoian Petkantchin } 603adbf21f1SBoian Petkantchin 6041a8fb887SBoian Petkantchin void reshardingRegisterDependentDialects(DialectRegistry ®istry) { 605dc3258c6SBoian Petkantchin registry.insert<mesh::MeshDialect, tensor::TensorDialect>(); 6061a8fb887SBoian Petkantchin } 6071a8fb887SBoian Petkantchin 608adbf21f1SBoian Petkantchin #define GEN_PASS_DEF_SPMDIZATION 609adbf21f1SBoian Petkantchin #include "mlir/Dialect/Mesh/Transforms/Passes.h.inc" 610adbf21f1SBoian Petkantchin 611adbf21f1SBoian Petkantchin using UnshardedToShardedValueMap = DenseMap<Value, Value>; 612adbf21f1SBoian Petkantchin 613adbf21f1SBoian Petkantchin // Get the types of block arguments for an spmdized block. 614adbf21f1SBoian Petkantchin // Reads the sharding annotations of the arguments to deduce the sharded types. 615adbf21f1SBoian Petkantchin // Types that are not ranked tensors are left unchanged. 616adbf21f1SBoian Petkantchin SmallVector<Type> 617adbf21f1SBoian Petkantchin shardedBlockArgumentTypes(Block &block, 618adbf21f1SBoian Petkantchin SymbolTableCollection &symbolTableCollection) { 619adbf21f1SBoian Petkantchin SmallVector<Type> res; 620a5757c5bSChristian Sigg llvm::transform( 621a5757c5bSChristian Sigg block.getArguments(), std::back_inserter(res), 622adbf21f1SBoian Petkantchin [&symbolTableCollection](BlockArgument arg) { 623a5757c5bSChristian Sigg auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg); 624adbf21f1SBoian Petkantchin if (!rankedTensorArg) { 625adbf21f1SBoian Petkantchin return arg.getType(); 626adbf21f1SBoian Petkantchin } 627adbf21f1SBoian Petkantchin 628adbf21f1SBoian Petkantchin assert(rankedTensorArg.hasOneUse()); 629adbf21f1SBoian Petkantchin Operation *useOp = *rankedTensorArg.getUsers().begin(); 630adbf21f1SBoian Petkantchin ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp); 631adbf21f1SBoian Petkantchin assert(shardOp); 632adbf21f1SBoian Petkantchin MeshOp mesh = getMesh(shardOp, symbolTableCollection); 633a5757c5bSChristian Sigg return cast<Type>(shardShapedType(rankedTensorArg.getType(), mesh, 634baabcb28SFrank Schlimbach shardOp.getSharding())); 635adbf21f1SBoian Petkantchin }); 636adbf21f1SBoian Petkantchin return res; 637adbf21f1SBoian Petkantchin } 638adbf21f1SBoian Petkantchin 639baabcb28SFrank Schlimbach void spmdizeTriviallyShardableOperation(Operation &op, 640baabcb28SFrank Schlimbach ArrayRef<Value> spmdizedOperands, 641baabcb28SFrank Schlimbach ArrayRef<MeshSharding> operandShardings, 642baabcb28SFrank Schlimbach ArrayRef<MeshSharding> resultShardings, 643baabcb28SFrank Schlimbach IRMapping &spmdizationMap, 644baabcb28SFrank Schlimbach SymbolTableCollection &symbolTable, 645baabcb28SFrank Schlimbach OpBuilder &builder); 646baabcb28SFrank Schlimbach 647adbf21f1SBoian Petkantchin static LogicalResult spmdizeOperation( 648adbf21f1SBoian Petkantchin Operation &op, ArrayRef<Value> spmdizedOperands, 649baabcb28SFrank Schlimbach ArrayRef<MeshSharding> operandShardings, 650baabcb28SFrank Schlimbach ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap, 651adbf21f1SBoian Petkantchin SymbolTableCollection &symbolTableCollection, OpBuilder &builder) { 652adbf21f1SBoian Petkantchin ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op); 653adbf21f1SBoian Petkantchin if (!shardingInterface) { 654adbf21f1SBoian Petkantchin // If there is no sharding interface we are conservative and assume that 655adbf21f1SBoian Petkantchin // the op should be fully replicated no all devices. 656adbf21f1SBoian Petkantchin spmdizeFullyReplicatedOperation(op, spmdizedOperands, operandShardings, 657adbf21f1SBoian Petkantchin resultShardings, spmdizationMap, 658adbf21f1SBoian Petkantchin symbolTableCollection, builder); 659adbf21f1SBoian Petkantchin } else { 660adbf21f1SBoian Petkantchin if (failed(shardingInterface.spmdize(spmdizedOperands, operandShardings, 661adbf21f1SBoian Petkantchin resultShardings, spmdizationMap, 662adbf21f1SBoian Petkantchin symbolTableCollection, builder))) { 663adbf21f1SBoian Petkantchin return failure(); 664adbf21f1SBoian Petkantchin } 665adbf21f1SBoian Petkantchin } 666adbf21f1SBoian Petkantchin 667adbf21f1SBoian Petkantchin assert(llvm::all_of(op.getResults(), [&spmdizationMap](OpResult result) { 668adbf21f1SBoian Petkantchin return spmdizationMap.contains(result); 669adbf21f1SBoian Petkantchin })); 670adbf21f1SBoian Petkantchin 671adbf21f1SBoian Petkantchin return success(); 672adbf21f1SBoian Petkantchin } 673adbf21f1SBoian Petkantchin 674adbf21f1SBoian Petkantchin // Retrieve the sharding annotations for the operands of the given operation. 675adbf21f1SBoian Petkantchin // If the type is not a ranked tensor it is not require to have an annotation. 676baabcb28SFrank Schlimbach static std::vector<MeshSharding> getOperandShardings(Operation &op) { 677baabcb28SFrank Schlimbach std::vector<MeshSharding> res; 678adbf21f1SBoian Petkantchin res.reserve(op.getNumOperands()); 679adbf21f1SBoian Petkantchin llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) { 680adbf21f1SBoian Petkantchin TypedValue<RankedTensorType> rankedTensor = 681a5757c5bSChristian Sigg dyn_cast<TypedValue<RankedTensorType>>(operand); 682adbf21f1SBoian Petkantchin if (!rankedTensor) { 683baabcb28SFrank Schlimbach return MeshSharding(); 684adbf21f1SBoian Petkantchin } 685adbf21f1SBoian Petkantchin 686adbf21f1SBoian Petkantchin Operation *definingOp = operand.getDefiningOp(); 687adbf21f1SBoian Petkantchin assert(definingOp); 688adbf21f1SBoian Petkantchin ShardOp shardOp = llvm::cast<ShardOp>(definingOp); 689baabcb28SFrank Schlimbach return MeshSharding(shardOp.getSharding()); 690adbf21f1SBoian Petkantchin }); 691adbf21f1SBoian Petkantchin return res; 692adbf21f1SBoian Petkantchin } 693adbf21f1SBoian Petkantchin 694adbf21f1SBoian Petkantchin // Retrieve the sharding annotations for the results of the given operation. 695adbf21f1SBoian Petkantchin // If the type is not a ranked tensor it is not require to have an annotation. 696baabcb28SFrank Schlimbach static std::vector<MeshSharding> getResultShardings(Operation &op) { 697baabcb28SFrank Schlimbach std::vector<MeshSharding> res; 698adbf21f1SBoian Petkantchin res.reserve(op.getNumResults()); 699adbf21f1SBoian Petkantchin llvm::transform(op.getResults(), std::back_inserter(res), 700adbf21f1SBoian Petkantchin [](OpResult result) { 701adbf21f1SBoian Petkantchin TypedValue<RankedTensorType> rankedTensor = 702a5757c5bSChristian Sigg dyn_cast<TypedValue<RankedTensorType>>(result); 703adbf21f1SBoian Petkantchin if (!rankedTensor) { 704baabcb28SFrank Schlimbach return MeshSharding(); 705adbf21f1SBoian Petkantchin } 706adbf21f1SBoian Petkantchin 707adbf21f1SBoian Petkantchin assert(result.hasOneUse()); 708adbf21f1SBoian Petkantchin Operation *userOp = *result.getUsers().begin(); 709adbf21f1SBoian Petkantchin ShardOp shardOp = llvm::cast<ShardOp>(userOp); 710baabcb28SFrank Schlimbach return MeshSharding(shardOp.getSharding()); 711adbf21f1SBoian Petkantchin }); 712adbf21f1SBoian Petkantchin return res; 713adbf21f1SBoian Petkantchin } 714adbf21f1SBoian Petkantchin 715adbf21f1SBoian Petkantchin static LogicalResult 7164f7ab789SBoian Petkantchin spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap, 7174f7ab789SBoian Petkantchin SymbolTableCollection &symbolTableCollection, 7184f7ab789SBoian Petkantchin OpBuilder &builder) { 7194f7ab789SBoian Petkantchin Value targetSpmdValue; 7204f7ab789SBoian Petkantchin 7214f7ab789SBoian Petkantchin // Check if 2 shard ops are chained. If not there is no need for resharding 7224f7ab789SBoian Petkantchin // as the source and target shared the same sharding. 7234f7ab789SBoian Petkantchin ShardOp srcShardOp = 724baabcb28SFrank Schlimbach dyn_cast_or_null<ShardOp>(shardOp.getSrc().getDefiningOp()); 7254f7ab789SBoian Petkantchin if (!srcShardOp) { 726baabcb28SFrank Schlimbach targetSpmdValue = spmdizationMap.lookup(shardOp.getSrc()); 7274f7ab789SBoian Petkantchin } else { 7284f7ab789SBoian Petkantchin // Insert resharding. 729ffc7feadSFrank Schlimbach TypedValue<ShapedType> srcSpmdValue = 730ffc7feadSFrank Schlimbach cast<TypedValue<ShapedType>>(spmdizationMap.lookup(srcShardOp)); 7314f7ab789SBoian Petkantchin targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue, 7324f7ab789SBoian Petkantchin symbolTableCollection); 7334f7ab789SBoian Petkantchin } 7344f7ab789SBoian Petkantchin 7354f7ab789SBoian Petkantchin assert(!spmdizationMap.contains(shardOp.getResult())); 7364f7ab789SBoian Petkantchin spmdizationMap.map(shardOp.getResult(), targetSpmdValue); 7374f7ab789SBoian Petkantchin return success(); 7384f7ab789SBoian Petkantchin } 7394f7ab789SBoian Petkantchin 7404f7ab789SBoian Petkantchin static LogicalResult 741adbf21f1SBoian Petkantchin spmdizeOperation(Operation &op, IRMapping &spmdizationMap, 742adbf21f1SBoian Petkantchin SymbolTableCollection &symbolTableCollection, 743adbf21f1SBoian Petkantchin OpBuilder &builder) { 744baabcb28SFrank Schlimbach if (isa<ShardingOp>(op)) { 745baabcb28SFrank Schlimbach return success(); 746baabcb28SFrank Schlimbach } 747baabcb28SFrank Schlimbach 748adbf21f1SBoian Petkantchin ShardOp shardOp = llvm::dyn_cast<ShardOp>(op); 749adbf21f1SBoian Petkantchin if (shardOp) { 7504f7ab789SBoian Petkantchin return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection, 7514f7ab789SBoian Petkantchin builder); 752adbf21f1SBoian Petkantchin } 753adbf21f1SBoian Petkantchin 754adbf21f1SBoian Petkantchin SmallVector<Value> spmdizedOperands; 755adbf21f1SBoian Petkantchin llvm::transform(op.getOperands(), std::back_inserter(spmdizedOperands), 756adbf21f1SBoian Petkantchin [&spmdizationMap](Value operand) { 757adbf21f1SBoian Petkantchin assert(spmdizationMap.contains(operand)); 758adbf21f1SBoian Petkantchin return spmdizationMap.lookup(operand); 759adbf21f1SBoian Petkantchin }); 760adbf21f1SBoian Petkantchin return spmdizeOperation(op, spmdizedOperands, getOperandShardings(op), 761adbf21f1SBoian Petkantchin getResultShardings(op), spmdizationMap, 762adbf21f1SBoian Petkantchin symbolTableCollection, builder); 763adbf21f1SBoian Petkantchin } 764adbf21f1SBoian Petkantchin 765adbf21f1SBoian Petkantchin static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap, 766adbf21f1SBoian Petkantchin SymbolTableCollection &symbolTableCollection, 767adbf21f1SBoian Petkantchin OpBuilder &builder) { 768adbf21f1SBoian Petkantchin SmallVector<Location> argLocations; 769adbf21f1SBoian Petkantchin llvm::transform(block.getArguments(), std::back_inserter(argLocations), 770adbf21f1SBoian Petkantchin [](BlockArgument arg) { return arg.getLoc(); }); 771adbf21f1SBoian Petkantchin Block *newBlock = builder.createBlock( 772adbf21f1SBoian Petkantchin block.getParent(), {}, 773adbf21f1SBoian Petkantchin shardedBlockArgumentTypes(block, symbolTableCollection), argLocations); 774adbf21f1SBoian Petkantchin for (auto [unshardedBlockArg, spmdizedBlockArg] : 775adbf21f1SBoian Petkantchin llvm::zip(block.getArguments(), newBlock->getArguments())) { 776adbf21f1SBoian Petkantchin spmdizationMap.map(unshardedBlockArg, spmdizedBlockArg); 777adbf21f1SBoian Petkantchin } 778adbf21f1SBoian Petkantchin 779adbf21f1SBoian Petkantchin OpBuilder::InsertionGuard insertionGuard(builder); 780adbf21f1SBoian Petkantchin builder.setInsertionPointToEnd(newBlock); 781adbf21f1SBoian Petkantchin for (Operation &op : block.getOperations()) { 782adbf21f1SBoian Petkantchin if (failed(spmdizeOperation(op, spmdizationMap, symbolTableCollection, 783adbf21f1SBoian Petkantchin builder))) { 784adbf21f1SBoian Petkantchin return failure(); 785adbf21f1SBoian Petkantchin } 786adbf21f1SBoian Petkantchin } 787adbf21f1SBoian Petkantchin 788adbf21f1SBoian Petkantchin return success(); 789adbf21f1SBoian Petkantchin } 790adbf21f1SBoian Petkantchin 791adbf21f1SBoian Petkantchin static LogicalResult 792abfac563SBoian Petkantchin spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap, 793adbf21f1SBoian Petkantchin SymbolTableCollection &symbolTableCollection) { 794adbf21f1SBoian Petkantchin OpBuilder builder(op.getFunctionBody()); 795adbf21f1SBoian Petkantchin 796adbf21f1SBoian Petkantchin // Snapshot the original blocks to not mess up the iteration when adding new 797adbf21f1SBoian Petkantchin // blocks. 798adbf21f1SBoian Petkantchin SmallVector<Block *> originalBlocks; 799adbf21f1SBoian Petkantchin llvm::transform(op.getBlocks(), std::back_inserter(originalBlocks), 800adbf21f1SBoian Petkantchin [](Block &b) { return &b; }); 801adbf21f1SBoian Petkantchin 802adbf21f1SBoian Petkantchin for (Block *block : originalBlocks) { 803adbf21f1SBoian Petkantchin if (failed(spmdizeBlock(*block, spmdizationMap, symbolTableCollection, 804adbf21f1SBoian Petkantchin builder))) { 805adbf21f1SBoian Petkantchin return failure(); 806adbf21f1SBoian Petkantchin } 807adbf21f1SBoian Petkantchin } 808adbf21f1SBoian Petkantchin 809adbf21f1SBoian Petkantchin for (Block *block : originalBlocks) { 810adbf21f1SBoian Petkantchin block->erase(); 811adbf21f1SBoian Petkantchin } 812adbf21f1SBoian Petkantchin 813adbf21f1SBoian Petkantchin // Find a return op and change the function results signature to its operands 814adbf21f1SBoian Petkantchin // signature. 815abfac563SBoian Petkantchin Operation *returnOp = nullptr; 816abfac563SBoian Petkantchin for (Block &block : op.getFunctionBody()) { 817adbf21f1SBoian Petkantchin if (block.empty()) { 818adbf21f1SBoian Petkantchin continue; 819adbf21f1SBoian Petkantchin } 820adbf21f1SBoian Petkantchin 821abfac563SBoian Petkantchin if (block.back().hasTrait<OpTrait::ReturnLike>()) { 822abfac563SBoian Petkantchin returnOp = &block.back(); 823adbf21f1SBoian Petkantchin break; 824adbf21f1SBoian Petkantchin } 825adbf21f1SBoian Petkantchin } 826adbf21f1SBoian Petkantchin assert(returnOp); 827abfac563SBoian Petkantchin op.setType(FunctionType::get(op->getContext(), 828abfac563SBoian Petkantchin op.getFunctionBody().front().getArgumentTypes(), 829adbf21f1SBoian Petkantchin returnOp->getOperandTypes())); 830adbf21f1SBoian Petkantchin 831adbf21f1SBoian Petkantchin return success(); 832adbf21f1SBoian Petkantchin } 833adbf21f1SBoian Petkantchin 834adbf21f1SBoian Petkantchin namespace { 835adbf21f1SBoian Petkantchin 836adbf21f1SBoian Petkantchin struct Spmdization : public impl::SpmdizationBase<Spmdization> { 837adbf21f1SBoian Petkantchin void runOnOperation() override { 838adbf21f1SBoian Petkantchin IRMapping spmdizationMap; 839adbf21f1SBoian Petkantchin SymbolTableCollection symbolTableCollection; 840adbf21f1SBoian Petkantchin if (failed(spmdizeFuncOp(getOperation(), spmdizationMap, 841adbf21f1SBoian Petkantchin symbolTableCollection))) { 842adbf21f1SBoian Petkantchin return signalPassFailure(); 843adbf21f1SBoian Petkantchin } 844adbf21f1SBoian Petkantchin } 845adbf21f1SBoian Petkantchin 846adbf21f1SBoian Petkantchin void getDependentDialects(DialectRegistry ®istry) const override { 847adbf21f1SBoian Petkantchin reshardingRegisterDependentDialects(registry); 848adbf21f1SBoian Petkantchin registry.insert<mesh::MeshDialect>(); 849adbf21f1SBoian Petkantchin } 850adbf21f1SBoian Petkantchin }; 851adbf21f1SBoian Petkantchin 852adbf21f1SBoian Petkantchin } // namespace 853adbf21f1SBoian Petkantchin 854adbf21f1SBoian Petkantchin } // namespace mlir::mesh 855