//===- Spmdization.cpp --------------------------------------------- C++ --===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Mesh/Transforms/Spmdization.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Mesh/IR/MeshDialect.h" #include "mlir/Dialect/Mesh/IR/MeshOps.h" #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/SymbolTable.h" #include "mlir/IR/Value.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" #include #include #include #include namespace mlir::mesh { template static bool arePartialAxesCompatible(const SourceAxes &sourceAxes, const TargetAxes &targetAxes) { return llvm::all_of(targetAxes, [&sourceAxes](auto &targetAxis) { return sourceAxes.contains(targetAxis); }); } // Return the reduced value and its corresponding sharding. // Example: // sourceSharding = <@mesh_1d, [[0]], partial = sum[0]> // targetSharding = <@mesh_1d, [[]]> // Then will apply all-reduce on the source value // and return it with the sharding <@mesh_1d, [[0]]>. static std::tuple, MeshSharding> handlePartialAxesDuringResharding(OpBuilder &builder, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue sourceShard) { if (sourceSharding.getPartialAxes().empty() && targetSharding.getPartialAxes().empty()) { return {sourceShard, sourceSharding}; } assert(targetSharding.getPartialAxes().empty() || (!sourceSharding.getPartialAxes().empty() && sourceSharding.getPartialType() == targetSharding.getPartialType())); using Axis = std::decay_t; using AxisSet = llvm::SmallDenseSet; AxisSet sourceShardingPartialAxesSet(sourceSharding.getPartialAxes().begin(), sourceSharding.getPartialAxes().end()); AxisSet targetShardingPartialAxesSet(targetSharding.getPartialAxes().begin(), targetSharding.getPartialAxes().end()); assert(arePartialAxesCompatible(sourceShardingPartialAxesSet, targetShardingPartialAxesSet)); llvm::SmallVector allReduceMeshAxes; llvm::copy_if(sourceShardingPartialAxesSet, std::back_inserter(allReduceMeshAxes), [&targetShardingPartialAxesSet](Axis a) { return !targetShardingPartialAxesSet.contains(a); }); if (allReduceMeshAxes.empty()) { return {sourceShard, sourceSharding}; } builder.setInsertionPointAfterValue(sourceShard); TypedValue resultValue = cast>( builder .create(sourceShard.getLoc(), sourceShard.getType(), sourceSharding.getMeshAttr().getLeafReference(), allReduceMeshAxes, sourceShard, sourceSharding.getPartialType()) .getResult()); llvm::SmallVector remainingPartialAxes; llvm::copy_if(sourceShardingPartialAxesSet, std::back_inserter(allReduceMeshAxes), [&targetShardingPartialAxesSet](Axis a) { return targetShardingPartialAxesSet.contains(a); }); MeshSharding resultSharding = MeshSharding::get( sourceSharding.getMeshAttr(), sourceSharding.getSplitAxes(), remainingPartialAxes, sourceSharding.getPartialType()); return {resultValue, resultSharding}; } static MeshSharding targetShardingInSplitLastAxis(MLIRContext *ctx, MeshSharding sourceSharding, int64_t splitTensorAxis, MeshAxis splitMeshAxis) { SmallVector targetShardingSplitAxes = llvm::to_vector(sourceSharding.getSplitAxes()); while (static_cast(targetShardingSplitAxes.size()) <= splitTensorAxis) { targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {})); } auto targetSplitAxes = llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef()); targetSplitAxes.push_back(splitMeshAxis); targetShardingSplitAxes[splitTensorAxis] = MeshAxesAttr::get(ctx, targetSplitAxes); return MeshSharding::get( sourceSharding.getMeshAttr(), targetShardingSplitAxes, sourceSharding.getPartialAxes(), sourceSharding.getPartialType()); } // Split a replicated tensor along a mesh axis. // E.g. [[0, 1]] -> [[0, 1, 2]]. // Returns the spmdized target value with its sharding. static std::tuple, MeshSharding> splitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshSharding sourceSharding, TypedValue sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis) { TypedValue targetShard = cast>( builder .create(sourceShard, mesh, ArrayRef(splitMeshAxis), splitTensorAxis) .getResult()); MeshSharding targetSharding = targetShardingInSplitLastAxis( builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis); return {targetShard, targetSharding}; } // Detect if the resharding is of type e.g. // [[0, 1]] -> [[0, 1, 2]]. // If detected, returns the corresponding tensor axis mesh axis pair. // Does not detect insertions like // [[0, 1]] -> [[0, 2, 1]]. static std::optional> detectSplitLastAxisInResharding(MeshSharding sourceSharding, MeshSharding targetSharding) { for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size(); ++tensorAxis) { if (sourceSharding.getSplitAxes().size() > tensorAxis) { if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 != targetSharding.getSplitAxes()[tensorAxis].size()) { continue; } if (!llvm::equal( sourceSharding.getSplitAxes()[tensorAxis].asArrayRef(), llvm::make_range( targetSharding.getSplitAxes()[tensorAxis] .asArrayRef() .begin(), targetSharding.getSplitAxes()[tensorAxis].asArrayRef().end() - 1))) { continue; } } else { if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) { continue; } } return std::make_tuple( tensorAxis, targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back()); } return std::nullopt; } static std::optional, MeshSharding>> trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue sourceShard) { if (auto detectRes = detectSplitLastAxisInResharding(sourceSharding, targetSharding)) { auto [tensorAxis, meshAxis] = detectRes.value(); return splitLastAxisInResharding(builder, sourceSharding, sourceShard, mesh, tensorAxis, meshAxis); } return std::nullopt; } // Detect if the resharding is of type e.g. // [[0, 1, 2]] -> [[0, 1]]. // If detected, returns the corresponding tensor axis mesh axis pair. static std::optional> detectUnsplitLastAxisInResharding(MeshSharding sourceSharding, MeshSharding targetSharding) { for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size(); ++tensorAxis) { if (targetSharding.getSplitAxes().size() > tensorAxis) { if (sourceSharding.getSplitAxes()[tensorAxis].size() != targetSharding.getSplitAxes()[tensorAxis].size() + 1) continue; if (!llvm::equal( llvm::make_range( sourceSharding.getSplitAxes()[tensorAxis] .asArrayRef() .begin(), sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() - 1), targetSharding.getSplitAxes()[tensorAxis].asArrayRef())) continue; } else { if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1) continue; } return std::make_tuple( tensorAxis, sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back()); } return std::nullopt; } static MeshSharding targetShardingInUnsplitLastAxis(MLIRContext *ctx, MeshSharding sourceSharding, int64_t splitTensorAxis) { SmallVector targetShardingSplitAxes = llvm::to_vector(sourceSharding.getSplitAxes()); assert(static_cast(targetShardingSplitAxes.size()) > splitTensorAxis); auto targetSplitAxes = llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef()); targetSplitAxes.pop_back(); targetShardingSplitAxes[splitTensorAxis] = MeshAxesAttr::get(ctx, targetSplitAxes); return MeshSharding::get( sourceSharding.getMeshAttr(), targetShardingSplitAxes, sourceSharding.getPartialAxes(), sourceSharding.getPartialType()); } static ShapedType allGatherResultShapeInUnsplitLastAxis( ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) { SmallVector targetShape = llvm::to_vector(sourceShape.getShape()); targetShape[splitTensorAxis] = gatherDimension(targetShape[splitTensorAxis], splitCount); return sourceShape.cloneWith(targetShape, sourceShape.getElementType()); } static std::tuple, MeshSharding> unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshSharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis) { MLIRContext *ctx = builder.getContext(); builder.setInsertionPointAfterValue(sourceShard); MeshSharding targetSharding = targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis); ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis( sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis); Value allGatherResult = builder.create( RankedTensorType::get(allGatherResultShape.getShape(), allGatherResultShape.getElementType()), mesh.getSymName(), SmallVector({splitMeshAxis}), sourceShard, APInt(64, splitTensorAxis)); ShapedType targetShape = shardShapedType(sourceUnshardedShape, mesh, targetSharding); TypedValue targetShard = cast>( builder.create(targetShape, allGatherResult).getResult()); return {targetShard, targetSharding}; } static std::optional, MeshSharding>> tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, ShapedType sourceUnshardedShape, TypedValue sourceShard) { if (auto detectRes = detectUnsplitLastAxisInResharding(sourceSharding, targetSharding)) { auto [tensorAxis, meshAxis] = detectRes.value(); return unsplitLastAxisInResharding(builder, sourceSharding, sourceUnshardedShape, sourceShard, mesh, tensorAxis, meshAxis); } return std::nullopt; } // Detect if the resharding is of type e.g. // [[0, 1], [2]] -> [[0], [1, 2]]. // Only moving the last axis counts. // If detected, returns the corresponding (source_tensor_axis, // target_tensor_axis, mesh_axis) tuple. static std::optional> detectMoveLastSplitAxisInResharding(MeshSharding sourceSharding, MeshSharding targetSharding) { for (size_t sourceTensorAxis = 0; sourceTensorAxis < sourceSharding.getSplitAxes().size(); ++sourceTensorAxis) { for (size_t targetTensorAxis = 0; targetTensorAxis < targetSharding.getSplitAxes().size(); ++targetTensorAxis) { if (sourceTensorAxis == targetTensorAxis) continue; if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() || targetSharding.getSplitAxes()[targetTensorAxis].empty() || sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() != targetSharding.getSplitAxes()[targetTensorAxis] .asArrayRef() .back()) continue; if (!llvm::equal( llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis] .asArrayRef() .begin(), sourceSharding.getSplitAxes()[sourceTensorAxis] .asArrayRef() .end() - 1), llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis] .asArrayRef() .begin(), targetSharding.getSplitAxes()[targetTensorAxis] .asArrayRef() .end() - 1))) continue; return std::make_tuple( sourceTensorAxis, targetTensorAxis, sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back()); } } return std::nullopt; } static MeshSharding targetShardingInMoveLastAxis(MLIRContext *ctx, MeshSharding sourceSharding, int64_t sourceTensorAxis, int64_t targetTensorAxis) { SmallVector targetShardingSplitAxes = llvm::to_vector(sourceSharding.getSplitAxes()); while (static_cast(targetShardingSplitAxes.size()) <= targetTensorAxis) { targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {})); } auto sourceSplitAxes = llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef()); assert(!sourceSplitAxes.empty()); auto meshAxis = sourceSplitAxes.back(); sourceSplitAxes.pop_back(); targetShardingSplitAxes[sourceTensorAxis] = MeshAxesAttr::get(ctx, sourceSplitAxes); auto targetSplitAxes = llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef()); targetSplitAxes.push_back(meshAxis); targetShardingSplitAxes[targetTensorAxis] = MeshAxesAttr::get(ctx, targetSplitAxes); return MeshSharding::get( sourceSharding.getMeshAttr(), targetShardingSplitAxes, sourceSharding.getPartialAxes(), sourceSharding.getPartialType()); } static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape, int64_t splitCount, int64_t sourceTensorAxis, int64_t targetTensorAxis) { SmallVector targetShape = llvm::to_vector(sourceShape.getShape()); targetShape[sourceTensorAxis] = gatherDimension(targetShape[sourceTensorAxis], splitCount); targetShape[targetTensorAxis] = shardDimension(targetShape[targetTensorAxis], splitCount); return sourceShape.cloneWith(targetShape, sourceShape.getElementType()); } static std::tuple, MeshSharding> moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue sourceShard, int64_t sourceTensorAxis, int64_t targetTensorAxis, MeshAxis meshAxis) { MLIRContext *ctx = builder.getContext(); builder.setInsertionPointAfterValue(sourceShard); MeshSharding targetSharding = targetShardingInMoveLastAxis( ctx, sourceSharding, sourceTensorAxis, targetTensorAxis); ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis( sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis, targetTensorAxis); Value allToAllResult = builder.create( RankedTensorType::get(allToAllResultShape.getShape(), allToAllResultShape.getElementType()), mesh.getSymName(), SmallVector({meshAxis}), sourceShard, APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis)); ShapedType targetShape = shardShapedType(sourceUnshardedShape, mesh, targetSharding); TypedValue targetShard = cast>( builder.create(targetShape, allToAllResult).getResult()); return {targetShard, targetSharding}; } static std::optional, MeshSharding>> tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, ShapedType sourceUnshardedShape, TypedValue sourceShard) { if (auto detectRes = detectMoveLastSplitAxisInResharding(sourceSharding, targetSharding)) { auto [sourceTensorAxis, targetTensorAxis, meshAxis] = detectRes.value(); return moveLastSplitAxisInResharding( builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard, sourceTensorAxis, targetTensorAxis, meshAxis); } return std::nullopt; } // Detect a change in the halo size (only) and create necessary operations if // needed. A changed halo sizes requires copying the "core" of the source tensor // into the "core" of the destination tensor followed by an update halo // operation. static std::optional, MeshSharding>> tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, ShapedType sourceUnshardedShape, TypedValue sourceShard) { // Currently handles only cases where halo sizes differ but everything else // stays the same (from source to destination sharding). if (!sourceSharding.equalSplitAndPartialAxes(targetSharding) || !sourceSharding.getPartialAxes().empty() || !targetSharding.getPartialAxes().empty() || !sourceSharding.getStaticShardedDimsOffsets().empty() || !targetSharding.getStaticShardedDimsOffsets().empty() || sourceSharding.equalHaloSizes(targetSharding)) { return std::nullopt; } auto srcHaloSizes = sourceSharding.getStaticHaloSizes(); auto tgtHaloSizes = targetSharding.getStaticHaloSizes(); assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size()); assert(((srcHaloSizes.empty() || !ShapedType::isDynamicShape(srcHaloSizes)) && !ShapedType::isDynamicShape(tgtHaloSizes) && sourceShard.getType().hasStaticShape()) && "dynamic shapes/halos are not supported yet for mesh-spmdization"); auto rank = sourceShard.getType().getRank(); auto splitAxes = sourceSharding.getSplitAxes(); SmallVector srcCoreOffs(rank, 0), tgtCoreOffs(rank, 0), strides(rank, 1), outShape(sourceShard.getType().getShape()), coreShape(sourceShard.getType().getShape()); // Determine "core" of source and destination. // The core is the local part of the shard excluding halo regions. for (auto i = 0u; i < rank; ++i) { if (i < splitAxes.size() && !splitAxes[i].empty()) { if (!srcHaloSizes.empty()) { coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1]; srcCoreOffs[i] = srcHaloSizes[i * 2]; } tgtCoreOffs[i] = tgtHaloSizes[i * 2]; outShape[i] = coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1]; } } // Extract core from source and copy into destination core. auto noVals = ValueRange{}; auto initVal = builder.create( sourceShard.getLoc(), outShape, sourceShard.getType().getElementType()); auto core = builder.create( sourceShard.getLoc(), RankedTensorType::get(coreShape, sourceShard.getType().getElementType()), sourceShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides); auto initOprnd = builder.create( sourceShard.getLoc(), core, initVal, noVals, noVals, noVals, tgtCoreOffs, coreShape, strides); // Finally update the halo. auto updateHaloResult = builder .create( sourceShard.getLoc(), RankedTensorType::get(outShape, sourceShard.getType().getElementType()), initOprnd, mesh.getSymName(), MeshAxesArrayAttr::get(builder.getContext(), sourceSharding.getSplitAxes()), targetSharding.getDynamicHaloSizes(), targetSharding.getStaticHaloSizes()) .getResult(); return std::make_tuple(cast>(updateHaloResult), targetSharding); } // Handles only resharding on a 1D mesh. // Currently the sharded tensor axes must be exactly divisible by the single // mesh axis size. static TypedValue reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue sourceUnshardedValue, TypedValue sourceShard) { assert(sourceShard.getType() == shardShapedType(sourceUnshardedValue.getType(), mesh, sourceSharding)); [[maybe_unused]] ShapedType targetShardType = shardShapedType(sourceUnshardedValue.getType(), mesh, targetSharding); assert(sourceShard.getType().getRank() == targetShardType.getRank()); assert(mesh.getRank() == 1 && "Only 1D meshes are currently supported."); auto [reducedSourceShard, reducedSourceSharding] = handlePartialAxesDuringResharding(builder, sourceSharding, targetSharding, sourceShard); if (reducedSourceSharding == targetSharding) { return reducedSourceShard; } TypedValue targetShard; MeshSharding actualTargetSharding; if (reducedSourceSharding.getStaticShardedDimsOffsets().empty() && targetSharding.getStaticShardedDimsOffsets().empty() && reducedSourceSharding.getStaticHaloSizes().empty() && targetSharding.getStaticHaloSizes().empty()) { if (auto tryRes = tryMoveLastSplitAxisInResharding( builder, mesh, reducedSourceSharding, targetSharding, sourceUnshardedValue.getType(), reducedSourceShard)) { std::tie(targetShard, actualTargetSharding) = tryRes.value(); } else if (auto tryRes = trySplitLastAxisInResharding( builder, mesh, reducedSourceSharding, targetSharding, reducedSourceShard)) { std::tie(targetShard, actualTargetSharding) = tryRes.value(); } else if (auto tryRes = tryUnsplitLastAxisInResharding( builder, mesh, reducedSourceSharding, targetSharding, sourceUnshardedValue.getType(), reducedSourceShard)) { std::tie(targetShard, actualTargetSharding) = tryRes.value(); } } assert(targetShard && "Did not find any pattern to apply."); assert(actualTargetSharding == targetSharding); assert(targetShard.getType() == targetShardType); return targetShard; } TypedValue reshard(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue sourceUnshardedValue, TypedValue sourceShard) { // If source and destination sharding are the same, no need to do anything. if (sourceSharding == targetSharding) { return sourceShard; } // Tries to handle the case where the resharding is needed because the halo // sizes are different. Supports arbitrary mesh dimensionality. if (auto tryRes = tryUpdateHaloInResharding( builder, mesh, sourceSharding, targetSharding, sourceUnshardedValue.getType(), sourceShard)) { return std::get<0>(tryRes.value()); // targetShard } // Resort to handling only 1D meshes since the general case is complicated if // it needs to be communication efficient in terms of minimizing the data // transfered between devices. return reshardOn1DMesh(builder, mesh, sourceSharding, targetSharding, sourceUnshardedValue, sourceShard); } TypedValue reshard(OpBuilder &builder, MeshOp mesh, ShardOp source, ShardOp target, TypedValue sourceShardValue) { assert(source.getResult() == target.getSrc()); auto sourceSharding = source.getSharding(); auto targetSharding = target.getSharding(); ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder); return reshard(implicitLocOpBuilder, mesh, sourceSharding, targetSharding, cast>(source.getSrc()), sourceShardValue); } TypedValue reshard(OpBuilder &builder, ShardOp source, ShardOp target, TypedValue sourceShardValue, SymbolTableCollection &symbolTableCollection) { MeshOp srcMesh = getMesh(source, symbolTableCollection); assert(srcMesh && srcMesh == getMesh(target, symbolTableCollection)); return reshard(builder, srcMesh, source, target, sourceShardValue); } void reshardingRegisterDependentDialects(DialectRegistry ®istry) { registry.insert(); } #define GEN_PASS_DEF_SPMDIZATION #include "mlir/Dialect/Mesh/Transforms/Passes.h.inc" using UnshardedToShardedValueMap = DenseMap; // Get the types of block arguments for an spmdized block. // Reads the sharding annotations of the arguments to deduce the sharded types. // Types that are not ranked tensors are left unchanged. SmallVector shardedBlockArgumentTypes(Block &block, SymbolTableCollection &symbolTableCollection) { SmallVector res; llvm::transform( block.getArguments(), std::back_inserter(res), [&symbolTableCollection](BlockArgument arg) { auto rankedTensorArg = dyn_cast>(arg); if (!rankedTensorArg) { return arg.getType(); } assert(rankedTensorArg.hasOneUse()); Operation *useOp = *rankedTensorArg.getUsers().begin(); ShardOp shardOp = llvm::dyn_cast(useOp); assert(shardOp); MeshOp mesh = getMesh(shardOp, symbolTableCollection); return cast(shardShapedType(rankedTensorArg.getType(), mesh, shardOp.getSharding())); }); return res; } void spmdizeTriviallyShardableOperation(Operation &op, ArrayRef spmdizedOperands, ArrayRef operandShardings, ArrayRef resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder); static LogicalResult spmdizeOperation( Operation &op, ArrayRef spmdizedOperands, ArrayRef operandShardings, ArrayRef resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) { ShardingInterface shardingInterface = llvm::dyn_cast(op); if (!shardingInterface) { // If there is no sharding interface we are conservative and assume that // the op should be fully replicated no all devices. spmdizeFullyReplicatedOperation(op, spmdizedOperands, operandShardings, resultShardings, spmdizationMap, symbolTableCollection, builder); } else { if (failed(shardingInterface.spmdize(spmdizedOperands, operandShardings, resultShardings, spmdizationMap, symbolTableCollection, builder))) { return failure(); } } assert(llvm::all_of(op.getResults(), [&spmdizationMap](OpResult result) { return spmdizationMap.contains(result); })); return success(); } // Retrieve the sharding annotations for the operands of the given operation. // If the type is not a ranked tensor it is not require to have an annotation. static std::vector getOperandShardings(Operation &op) { std::vector res; res.reserve(op.getNumOperands()); llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) { TypedValue rankedTensor = dyn_cast>(operand); if (!rankedTensor) { return MeshSharding(); } Operation *definingOp = operand.getDefiningOp(); assert(definingOp); ShardOp shardOp = llvm::cast(definingOp); return MeshSharding(shardOp.getSharding()); }); return res; } // Retrieve the sharding annotations for the results of the given operation. // If the type is not a ranked tensor it is not require to have an annotation. static std::vector getResultShardings(Operation &op) { std::vector res; res.reserve(op.getNumResults()); llvm::transform(op.getResults(), std::back_inserter(res), [](OpResult result) { TypedValue rankedTensor = dyn_cast>(result); if (!rankedTensor) { return MeshSharding(); } assert(result.hasOneUse()); Operation *userOp = *result.getUsers().begin(); ShardOp shardOp = llvm::cast(userOp); return MeshSharding(shardOp.getSharding()); }); return res; } static LogicalResult spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) { Value targetSpmdValue; // Check if 2 shard ops are chained. If not there is no need for resharding // as the source and target shared the same sharding. ShardOp srcShardOp = dyn_cast_or_null(shardOp.getSrc().getDefiningOp()); if (!srcShardOp) { targetSpmdValue = spmdizationMap.lookup(shardOp.getSrc()); } else { // Insert resharding. TypedValue srcSpmdValue = cast>(spmdizationMap.lookup(srcShardOp)); targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue, symbolTableCollection); } assert(!spmdizationMap.contains(shardOp.getResult())); spmdizationMap.map(shardOp.getResult(), targetSpmdValue); return success(); } static LogicalResult spmdizeOperation(Operation &op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) { if (isa(op)) { return success(); } ShardOp shardOp = llvm::dyn_cast(op); if (shardOp) { return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection, builder); } SmallVector spmdizedOperands; llvm::transform(op.getOperands(), std::back_inserter(spmdizedOperands), [&spmdizationMap](Value operand) { assert(spmdizationMap.contains(operand)); return spmdizationMap.lookup(operand); }); return spmdizeOperation(op, spmdizedOperands, getOperandShardings(op), getResultShardings(op), spmdizationMap, symbolTableCollection, builder); } static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) { SmallVector argLocations; llvm::transform(block.getArguments(), std::back_inserter(argLocations), [](BlockArgument arg) { return arg.getLoc(); }); Block *newBlock = builder.createBlock( block.getParent(), {}, shardedBlockArgumentTypes(block, symbolTableCollection), argLocations); for (auto [unshardedBlockArg, spmdizedBlockArg] : llvm::zip(block.getArguments(), newBlock->getArguments())) { spmdizationMap.map(unshardedBlockArg, spmdizedBlockArg); } OpBuilder::InsertionGuard insertionGuard(builder); builder.setInsertionPointToEnd(newBlock); for (Operation &op : block.getOperations()) { if (failed(spmdizeOperation(op, spmdizationMap, symbolTableCollection, builder))) { return failure(); } } return success(); } static LogicalResult spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection) { OpBuilder builder(op.getFunctionBody()); // Snapshot the original blocks to not mess up the iteration when adding new // blocks. SmallVector originalBlocks; llvm::transform(op.getBlocks(), std::back_inserter(originalBlocks), [](Block &b) { return &b; }); for (Block *block : originalBlocks) { if (failed(spmdizeBlock(*block, spmdizationMap, symbolTableCollection, builder))) { return failure(); } } for (Block *block : originalBlocks) { block->erase(); } // Find a return op and change the function results signature to its operands // signature. Operation *returnOp = nullptr; for (Block &block : op.getFunctionBody()) { if (block.empty()) { continue; } if (block.back().hasTrait()) { returnOp = &block.back(); break; } } assert(returnOp); op.setType(FunctionType::get(op->getContext(), op.getFunctionBody().front().getArgumentTypes(), returnOp->getOperandTypes())); return success(); } namespace { struct Spmdization : public impl::SpmdizationBase { void runOnOperation() override { IRMapping spmdizationMap; SymbolTableCollection symbolTableCollection; if (failed(spmdizeFuncOp(getOperation(), spmdizationMap, symbolTableCollection))) { return signalPassFailure(); } } void getDependentDialects(DialectRegistry ®istry) const override { reshardingRegisterDependentDialects(registry); registry.insert(); } }; } // namespace } // namespace mlir::mesh