//===- Spmdization.cpp --------------------------------------------- C++ --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include <iterator>
#include <optional>
#include <tuple>
#include <type_traits>

namespace mlir::mesh {

template <typename SourceAxes, typename TargetAxes>
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
                                     const TargetAxes &targetAxes) {
  return llvm::all_of(targetAxes, [&sourceAxes](auto &targetAxis) {
    return sourceAxes.contains(targetAxis);
  });
}

// Return the reduced value and its corresponding sharding.
// Example:
// sourceSharding = <@mesh_1d, [[0]], partial = sum[0]>
// targetSharding = <@mesh_1d, [[]]>
// An all-reduce is then applied to the source value and the result is
// returned with the sharding <@mesh_1d, [[0]]>.
static std::tuple<TypedValue<ShapedType>, MeshSharding>
handlePartialAxesDuringResharding(OpBuilder &builder,
                                  MeshSharding sourceSharding,
                                  MeshSharding targetSharding,
                                  TypedValue<ShapedType> sourceShard) {
  if (sourceSharding.getPartialAxes().empty() &&
      targetSharding.getPartialAxes().empty()) {
    return {sourceShard, sourceSharding};
  }
  assert(targetSharding.getPartialAxes().empty() ||
         (!sourceSharding.getPartialAxes().empty() &&
          sourceSharding.getPartialType() == targetSharding.getPartialType()));
  using Axis = std::decay_t<decltype(sourceSharding.getPartialAxes().front())>;
  using AxisSet = llvm::SmallDenseSet<Axis>;
  AxisSet sourceShardingPartialAxesSet(sourceSharding.getPartialAxes().begin(),
                                       sourceSharding.getPartialAxes().end());
  AxisSet targetShardingPartialAxesSet(targetSharding.getPartialAxes().begin(),
                                       targetSharding.getPartialAxes().end());
  assert(arePartialAxesCompatible(sourceShardingPartialAxesSet,
                                  targetShardingPartialAxesSet));
  llvm::SmallVector<MeshAxis> allReduceMeshAxes;
  llvm::copy_if(sourceShardingPartialAxesSet,
                std::back_inserter(allReduceMeshAxes),
                [&targetShardingPartialAxesSet](Axis a) {
                  return !targetShardingPartialAxesSet.contains(a);
                });
  if (allReduceMeshAxes.empty()) {
    return {sourceShard, sourceSharding};
  }

  builder.setInsertionPointAfterValue(sourceShard);
  TypedValue<ShapedType> resultValue = cast<TypedValue<ShapedType>>(
      builder
          .create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
                               sourceSharding.getMeshAttr().getLeafReference(),
                               allReduceMeshAxes, sourceShard,
                               sourceSharding.getPartialType())
          .getResult());

  // Collect the partial axes that remain after the all-reduce (fixed to fill
  // remainingPartialAxes instead of appending to allReduceMeshAxes again).
  llvm::SmallVector<MeshAxis> remainingPartialAxes;
  llvm::copy_if(sourceShardingPartialAxesSet,
                std::back_inserter(remainingPartialAxes),
                [&targetShardingPartialAxesSet](Axis a) {
                  return targetShardingPartialAxesSet.contains(a);
                });
  MeshSharding resultSharding = MeshSharding::get(
      sourceSharding.getMeshAttr(), sourceSharding.getSplitAxes(),
      remainingPartialAxes, sourceSharding.getPartialType());
  return {resultValue, resultSharding};
}

static MeshSharding targetShardingInSplitLastAxis(MLIRContext *ctx,
                                                  MeshSharding sourceSharding,
                                                  int64_t splitTensorAxis,
                                                  MeshAxis splitMeshAxis) {
  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
         splitTensorAxis) {
    targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
  }
  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
  targetSplitAxes.push_back(splitMeshAxis);
  targetShardingSplitAxes[splitTensorAxis] =
      MeshAxesAttr::get(ctx, targetSplitAxes);
  return MeshSharding::get(
      sourceSharding.getMeshAttr(), targetShardingSplitAxes,
      sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
}

// Split a replicated tensor along a mesh axis.
// E.g. [[0, 1]] -> [[0, 1, 2]].
// Returns the spmdized target value with its sharding.
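// Illustrative example (hypothetical shapes and mesh): on a mesh of shape
// 2x2x2, the local shard of a tensor<8x8xf32> split along tensor axis 0 goes
// from tensor<2x8xf32> (sharding [[0, 1]]) to tensor<1x8xf32>
// (sharding [[0, 1, 2]]).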
static std::tuple<TypedValue<ShapedType>, MeshSharding>
splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                          MeshSharding sourceSharding,
                          TypedValue<ShapedType> sourceShard, MeshOp mesh,
                          int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
      builder
          .create<AllSliceOp>(sourceShard, mesh,
                              ArrayRef<MeshAxis>(splitMeshAxis),
                              splitTensorAxis)
          .getResult());
  MeshSharding targetSharding = targetShardingInSplitLastAxis(
      builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
  return {targetShard, targetSharding};
}

// Detect if the resharding is of type e.g.
// [[0, 1]] -> [[0, 1, 2]].
// If detected, returns the corresponding (tensor axis, mesh axis) pair.
// Does not detect insertions like
// [[0, 1]] -> [[0, 2, 1]].
static std::optional<std::tuple<int64_t, MeshAxis>>
detectSplitLastAxisInResharding(MeshSharding sourceSharding,
                                MeshSharding targetSharding) {
  for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size();
       ++tensorAxis) {
    if (sourceSharding.getSplitAxes().size() > tensorAxis) {
      if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 !=
          targetSharding.getSplitAxes()[tensorAxis].size()) {
        continue;
      }
      if (!llvm::equal(
              sourceSharding.getSplitAxes()[tensorAxis].asArrayRef(),
              llvm::make_range(
                  targetSharding.getSplitAxes()[tensorAxis]
                      .asArrayRef()
                      .begin(),
                  targetSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
                      1))) {
        continue;
      }
    } else {
      if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) {
        continue;
      }
    }
    return std::make_tuple(
        tensorAxis,
        targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
  }
  return std::nullopt;
}

static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                             MeshSharding sourceSharding,
                             MeshSharding targetSharding,
                             TypedValue<ShapedType> sourceShard) {
  if (auto detectRes =
          detectSplitLastAxisInResharding(sourceSharding, targetSharding)) {
    auto [tensorAxis, meshAxis] = detectRes.value();
    return splitLastAxisInResharding(builder, sourceSharding, sourceShard, mesh,
                                     tensorAxis, meshAxis);
  }

  return std::nullopt;
}

// Detect if the resharding is of type e.g.
// [[0, 1, 2]] -> [[0, 1]].
// If detected, returns the corresponding (tensor axis, mesh axis) pair.
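// Does not detect removals like
// [[0, 1, 2]] -> [[0, 2]].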
static std::optional<std::tuple<int64_t, MeshAxis>>
detectUnsplitLastAxisInResharding(MeshSharding sourceSharding,
                                  MeshSharding targetSharding) {
  for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
       ++tensorAxis) {
    if (targetSharding.getSplitAxes().size() > tensorAxis) {
      if (sourceSharding.getSplitAxes()[tensorAxis].size() !=
          targetSharding.getSplitAxes()[tensorAxis].size() + 1)
        continue;
      if (!llvm::equal(
              llvm::make_range(
                  sourceSharding.getSplitAxes()[tensorAxis]
                      .asArrayRef()
                      .begin(),
                  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
                      1),
              targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
        continue;
    } else {
      if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1)
        continue;
    }
    return std::make_tuple(
        tensorAxis,
        sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
  }
  return std::nullopt;
}

static MeshSharding targetShardingInUnsplitLastAxis(MLIRContext *ctx,
                                                    MeshSharding sourceSharding,
                                                    int64_t splitTensorAxis) {
  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
         splitTensorAxis);
  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());

  targetSplitAxes.pop_back();
  targetShardingSplitAxes[splitTensorAxis] =
      MeshAxesAttr::get(ctx, targetSplitAxes);
  return MeshSharding::get(
      sourceSharding.getMeshAttr(), targetShardingSplitAxes,
      sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
}

static ShapedType allGatherResultShapeInUnsplitLastAxis(
    ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
  targetShape[splitTensorAxis] =
      gatherDimension(targetShape[splitTensorAxis], splitCount);
  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
}

static std::tuple<TypedValue<ShapedType>, MeshSharding>
unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                            MeshSharding sourceSharding,
                            ShapedType sourceUnshardedShape,
                            TypedValue<ShapedType> sourceShard, MeshOp mesh,
                            int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
  MLIRContext *ctx = builder.getContext();
  builder.setInsertionPointAfterValue(sourceShard);

  MeshSharding targetSharding =
      targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis);
  ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
      sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
  Value allGatherResult = builder.create<AllGatherOp>(
      RankedTensorType::get(allGatherResultShape.getShape(),
                            allGatherResultShape.getElementType()),
      mesh.getSymName(), SmallVector<MeshAxis>({splitMeshAxis}), sourceShard,
      APInt(64, splitTensorAxis));
  ShapedType targetShape =
      shardShapedType(sourceUnshardedShape, mesh, targetSharding);
  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
      builder.create<tensor::CastOp>(targetShape, allGatherResult).getResult());
  return {targetShard, targetSharding};
}

static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                               MeshSharding sourceSharding,
                               MeshSharding targetSharding,
                               ShapedType sourceUnshardedShape,
                               TypedValue<ShapedType> sourceShard) {
  if (auto detectRes =
          detectUnsplitLastAxisInResharding(sourceSharding, targetSharding)) {
    auto [tensorAxis, meshAxis] = detectRes.value();
    return unsplitLastAxisInResharding(builder, sourceSharding,
                                       sourceUnshardedShape, sourceShard, mesh,
                                       tensorAxis, meshAxis);
  }

  return std::nullopt;
}

// Detect if the resharding is of type e.g.
// [[0, 1], [2]] -> [[0], [1, 2]].
// Only moving the last axis counts.
// If detected, returns the corresponding (source_tensor_axis,
// target_tensor_axis, mesh_axis) tuple.
static std::optional<std::tuple<int64_t, int64_t, MeshAxis>>
detectMoveLastSplitAxisInResharding(MeshSharding sourceSharding,
                                    MeshSharding targetSharding) {
  for (size_t sourceTensorAxis = 0;
       sourceTensorAxis < sourceSharding.getSplitAxes().size();
       ++sourceTensorAxis) {
    for (size_t targetTensorAxis = 0;
         targetTensorAxis < targetSharding.getSplitAxes().size();
         ++targetTensorAxis) {
      if (sourceTensorAxis == targetTensorAxis)
        continue;
      if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||
          targetSharding.getSplitAxes()[targetTensorAxis].empty() ||
          sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
              targetSharding.getSplitAxes()[targetTensorAxis]
                  .asArrayRef()
                  .back())
        continue;
      if (!llvm::equal(
              llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis]
                                   .asArrayRef()
                                   .begin(),
                               sourceSharding.getSplitAxes()[sourceTensorAxis]
                                       .asArrayRef()
                                       .end() -
                                   1),
              llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis]
                                   .asArrayRef()
                                   .begin(),
                               targetSharding.getSplitAxes()[targetTensorAxis]
                                       .asArrayRef()
                                       .end() -
                                   1)))
        continue;
      return std::make_tuple(
          sourceTensorAxis, targetTensorAxis,
          sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back());
    }
  }
  return std::nullopt;
}

static MeshSharding targetShardingInMoveLastAxis(MLIRContext *ctx,
                                                 MeshSharding sourceSharding,
                                                 int64_t sourceTensorAxis,
                                                 int64_t targetTensorAxis) {
  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
         targetTensorAxis) {
    targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
  }

  auto sourceSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
  assert(!sourceSplitAxes.empty());
  auto meshAxis = sourceSplitAxes.back();
  sourceSplitAxes.pop_back();
  targetShardingSplitAxes[sourceTensorAxis] =
      MeshAxesAttr::get(ctx, sourceSplitAxes);

  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
  targetSplitAxes.push_back(meshAxis);
  targetShardingSplitAxes[targetTensorAxis] =
      MeshAxesAttr::get(ctx, targetSplitAxes);

  return MeshSharding::get(
      sourceSharding.getMeshAttr(), targetShardingSplitAxes,
      sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
}

static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
                                                    int64_t splitCount,
                                                    int64_t sourceTensorAxis,
                                                    int64_t targetTensorAxis) {
  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
  targetShape[sourceTensorAxis] =
      gatherDimension(targetShape[sourceTensorAxis], splitCount);
  targetShape[targetTensorAxis] =
      shardDimension(targetShape[targetTensorAxis], splitCount);
  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
}

static std::tuple<TypedValue<ShapedType>, MeshSharding>
moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                              MeshSharding sourceSharding,
                              ShapedType sourceUnshardedShape,
                              TypedValue<ShapedType> sourceShard,
                              int64_t sourceTensorAxis,
                              int64_t targetTensorAxis, MeshAxis meshAxis) {
  MLIRContext *ctx = builder.getContext();
  builder.setInsertionPointAfterValue(sourceShard);

  MeshSharding targetSharding = targetShardingInMoveLastAxis(
      ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
  ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
      sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis,
      targetTensorAxis);
  Value allToAllResult = builder.create<AllToAllOp>(
      RankedTensorType::get(allToAllResultShape.getShape(),
                            allToAllResultShape.getElementType()),
      mesh.getSymName(), SmallVector<MeshAxis>({meshAxis}), sourceShard,
      APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
  ShapedType targetShape =
      shardShapedType(sourceUnshardedShape, mesh, targetSharding);
  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
      builder.create<tensor::CastOp>(targetShape, allToAllResult).getResult());
  return {targetShard, targetSharding};
}

static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                                 MeshSharding sourceSharding,
                                 MeshSharding targetSharding,
                                 ShapedType sourceUnshardedShape,
                                 TypedValue<ShapedType> sourceShard) {
  if (auto detectRes =
          detectMoveLastSplitAxisInResharding(sourceSharding, targetSharding)) {
    auto [sourceTensorAxis, targetTensorAxis, meshAxis] = detectRes.value();
    return moveLastSplitAxisInResharding(
        builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard,
        sourceTensorAxis, targetTensorAxis, meshAxis);
  }

  return std::nullopt;
}

// Detect a change in the halo sizes (only) and create the necessary operations
// if needed. Changed halo sizes require copying the "core" of the source
// tensor into the "core" of the destination tensor, followed by an update-halo
// operation.
static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                          MeshSharding sourceSharding,
                          MeshSharding targetSharding,
                          ShapedType sourceUnshardedShape,
                          TypedValue<ShapedType> sourceShard) {
  // Currently handles only cases where halo sizes differ but everything else
  // stays the same (from source to destination sharding).
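  // Illustrative example (hypothetical sizes) for one sharded tensor axis:
  // a source shard of size 10 with halo sizes [2, 1] has a core of size
  // 10 - 2 - 1 = 7 starting at offset 2; with target halo sizes [1, 3] the
  // result shard has size 7 + 1 + 3 = 11 and the core is inserted at offset 1
  // before the halo regions are exchanged.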
  if (!sourceSharding.equalSplitAndPartialAxes(targetSharding) ||
      !sourceSharding.getPartialAxes().empty() ||
      !targetSharding.getPartialAxes().empty() ||
      !sourceSharding.getStaticShardedDimsOffsets().empty() ||
      !targetSharding.getStaticShardedDimsOffsets().empty() ||
      sourceSharding.equalHaloSizes(targetSharding)) {
    return std::nullopt;
  }

  auto srcHaloSizes = sourceSharding.getStaticHaloSizes();
  auto tgtHaloSizes = targetSharding.getStaticHaloSizes();
  assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size());
  assert(((srcHaloSizes.empty() || !ShapedType::isDynamicShape(srcHaloSizes)) &&
          !ShapedType::isDynamicShape(tgtHaloSizes) &&
          sourceShard.getType().hasStaticShape()) &&
         "dynamic shapes/halos are not supported yet for mesh-spmdization");
  auto rank = sourceShard.getType().getRank();
  auto splitAxes = sourceSharding.getSplitAxes();
  SmallVector<int64_t> srcCoreOffs(rank, 0), tgtCoreOffs(rank, 0),
      strides(rank, 1), outShape(sourceShard.getType().getShape()),
      coreShape(sourceShard.getType().getShape());

  // Determine "core" of source and destination.
  // The core is the local part of the shard excluding halo regions.
  for (auto i = 0u; i < rank; ++i) {
    if (i < splitAxes.size() && !splitAxes[i].empty()) {
      if (!srcHaloSizes.empty()) {
        coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1];
        srcCoreOffs[i] = srcHaloSizes[i * 2];
      }
      tgtCoreOffs[i] = tgtHaloSizes[i * 2];
      outShape[i] =
          coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1];
    }
  }

  // Extract core from source and copy into destination core.
  auto noVals = ValueRange{};
  auto initVal = builder.create<tensor::EmptyOp>(
      sourceShard.getLoc(), outShape, sourceShard.getType().getElementType());
  auto core = builder.create<tensor::ExtractSliceOp>(
      sourceShard.getLoc(),
      RankedTensorType::get(coreShape, sourceShard.getType().getElementType()),
      sourceShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides);
  auto initOprnd = builder.create<tensor::InsertSliceOp>(
      sourceShard.getLoc(), core, initVal, noVals, noVals, noVals, tgtCoreOffs,
      coreShape, strides);

  // Finally update the halo.
  auto updateHaloResult =
      builder
          .create<UpdateHaloOp>(
              sourceShard.getLoc(),
              RankedTensorType::get(outShape,
                                    sourceShard.getType().getElementType()),
              initOprnd, mesh.getSymName(),
              MeshAxesArrayAttr::get(builder.getContext(),
                                     sourceSharding.getSplitAxes()),
              targetSharding.getDynamicHaloSizes(),
              targetSharding.getStaticHaloSizes())
          .getResult();
  return std::make_tuple(cast<TypedValue<ShapedType>>(updateHaloResult),
                         targetSharding);
}

// Handles only resharding on a 1D mesh.
// Currently the sharded tensor axes must be exactly divisible by the single
// mesh axis size.
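// Illustrative example (hypothetical shapes): on a 1D mesh of size 2,
// resharding a tensor<8xf32> from [[0]] to [[]] all-gathers the two
// tensor<4xf32> shards so every device ends up with the full tensor<8xf32>.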
static TypedValue<ShapedType>
reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh,
                MeshSharding sourceSharding, MeshSharding targetSharding,
                TypedValue<ShapedType> sourceUnshardedValue,
                TypedValue<ShapedType> sourceShard) {
  assert(sourceShard.getType() ==
         shardShapedType(sourceUnshardedValue.getType(), mesh, sourceSharding));
  [[maybe_unused]] ShapedType targetShardType =
      shardShapedType(sourceUnshardedValue.getType(), mesh, targetSharding);
  assert(sourceShard.getType().getRank() == targetShardType.getRank());
  assert(mesh.getRank() == 1 && "Only 1D meshes are currently supported.");

  auto [reducedSourceShard, reducedSourceSharding] =
      handlePartialAxesDuringResharding(builder, sourceSharding, targetSharding,
                                        sourceShard);

  if (reducedSourceSharding == targetSharding) {
    return reducedSourceShard;
  }

  TypedValue<ShapedType> targetShard;
  MeshSharding actualTargetSharding;
  if (reducedSourceSharding.getStaticShardedDimsOffsets().empty() &&
      targetSharding.getStaticShardedDimsOffsets().empty() &&
      reducedSourceSharding.getStaticHaloSizes().empty() &&
      targetSharding.getStaticHaloSizes().empty()) {
    if (auto tryRes = tryMoveLastSplitAxisInResharding(
            builder, mesh, reducedSourceSharding, targetSharding,
            sourceUnshardedValue.getType(), reducedSourceShard)) {
      std::tie(targetShard, actualTargetSharding) = tryRes.value();
    } else if (auto tryRes = trySplitLastAxisInResharding(
                   builder, mesh, reducedSourceSharding, targetSharding,
                   reducedSourceShard)) {
      std::tie(targetShard, actualTargetSharding) = tryRes.value();
    } else if (auto tryRes = tryUnsplitLastAxisInResharding(
                   builder, mesh, reducedSourceSharding, targetSharding,
                   sourceUnshardedValue.getType(), reducedSourceShard)) {
      std::tie(targetShard, actualTargetSharding) = tryRes.value();
    }
  }
  assert(targetShard && "Did not find any pattern to apply.");
  assert(actualTargetSharding == targetSharding);
  assert(targetShard.getType() == targetShardType);
  return targetShard;
}

TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
                               MeshSharding sourceSharding,
                               MeshSharding targetSharding,
                               TypedValue<ShapedType> sourceUnshardedValue,
                               TypedValue<ShapedType> sourceShard) {
  // If source and destination sharding are the same, no need to do anything.
  if (sourceSharding == targetSharding) {
    return sourceShard;
  }

  // Try to handle the case where the resharding is needed because the halo
  // sizes are different. Supports arbitrary mesh dimensionality.
  if (auto tryRes = tryUpdateHaloInResharding(
          builder, mesh, sourceSharding, targetSharding,
          sourceUnshardedValue.getType(), sourceShard)) {
    return std::get<0>(tryRes.value()); // targetShard
  }

  // Resort to handling only 1D meshes since the general case is complicated if
  // it needs to be communication efficient in terms of minimizing the data
  // transferred between devices.
  return reshardOn1DMesh(builder, mesh, sourceSharding, targetSharding,
                         sourceUnshardedValue, sourceShard);
}

TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
                               ShardOp target,
                               TypedValue<ShapedType> sourceShardValue) {
  assert(source.getResult() == target.getSrc());
  auto sourceSharding = source.getSharding();
  auto targetSharding = target.getSharding();
  ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
  return reshard(implicitLocOpBuilder, mesh, sourceSharding, targetSharding,
                 cast<TypedValue<ShapedType>>(source.getSrc()),
                 sourceShardValue);
}

TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
                               ShardOp target,
                               TypedValue<ShapedType> sourceShardValue,
                               SymbolTableCollection &symbolTableCollection) {
  MeshOp srcMesh = getMesh(source, symbolTableCollection);
  assert(srcMesh && srcMesh == getMesh(target, symbolTableCollection));
  return reshard(builder, srcMesh, source, target, sourceShardValue);
}

void reshardingRegisterDependentDialects(DialectRegistry &registry) {
  registry.insert<mesh::MeshDialect, tensor::TensorDialect>();
}

#define GEN_PASS_DEF_SPMDIZATION
#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"

using UnshardedToShardedValueMap = DenseMap<Value, Value>;

// Get the types of block arguments for an spmdized block.
// Reads the sharding annotations of the arguments to deduce the sharded types.
// Types that are not ranked tensors are left unchanged.
SmallVector<Type>
shardedBlockArgumentTypes(Block &block,
                          SymbolTableCollection &symbolTableCollection) {
  SmallVector<Type> res;
  llvm::transform(
      block.getArguments(), std::back_inserter(res),
      [&symbolTableCollection](BlockArgument arg) {
        auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
        if (!rankedTensorArg) {
          return arg.getType();
        }

        assert(rankedTensorArg.hasOneUse());
        Operation *useOp = *rankedTensorArg.getUsers().begin();
        ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
        assert(shardOp);
        MeshOp mesh = getMesh(shardOp, symbolTableCollection);
        return cast<Type>(shardShapedType(rankedTensorArg.getType(), mesh,
                                          shardOp.getSharding()));
      });
  return res;
}

void spmdizeTriviallyShardableOperation(Operation &op,
                                        ArrayRef<Value> spmdizedOperands,
                                        ArrayRef<MeshSharding> operandShardings,
                                        ArrayRef<MeshSharding> resultShardings,
                                        IRMapping &spmdizationMap,
                                        SymbolTableCollection &symbolTable,
                                        OpBuilder &builder);

static LogicalResult spmdizeOperation(
    Operation &op, ArrayRef<Value> spmdizedOperands,
    ArrayRef<MeshSharding> operandShardings,
    ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
    SymbolTableCollection &symbolTableCollection, OpBuilder &builder) {
  ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
  if (!shardingInterface) {
    // If there is no sharding interface we are conservative and assume that
    // the op should be fully replicated on all devices.
    spmdizeFullyReplicatedOperation(op, spmdizedOperands, operandShardings,
                                    resultShardings, spmdizationMap,
                                    symbolTableCollection, builder);
  } else {
    if (failed(shardingInterface.spmdize(spmdizedOperands, operandShardings,
                                         resultShardings, spmdizationMap,
                                         symbolTableCollection, builder))) {
      return failure();
    }
  }

  assert(llvm::all_of(op.getResults(), [&spmdizationMap](OpResult result) {
    return spmdizationMap.contains(result);
  }));

  return success();
}

// Retrieve the sharding annotations for the operands of the given operation.
// If the type is not a ranked tensor it is not required to have an annotation.
static std::vector<MeshSharding> getOperandShardings(Operation &op) {
  std::vector<MeshSharding> res;
  res.reserve(op.getNumOperands());
  llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
    TypedValue<RankedTensorType> rankedTensor =
        dyn_cast<TypedValue<RankedTensorType>>(operand);
    if (!rankedTensor) {
      return MeshSharding();
    }

    Operation *definingOp = operand.getDefiningOp();
    assert(definingOp);
    ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
    return MeshSharding(shardOp.getSharding());
  });
  return res;
}

// Retrieve the sharding annotations for the results of the given operation.
// If the type is not a ranked tensor it is not required to have an annotation.
static std::vector<MeshSharding> getResultShardings(Operation &op) {
  std::vector<MeshSharding> res;
  res.reserve(op.getNumResults());
  llvm::transform(op.getResults(), std::back_inserter(res),
                  [](OpResult result) {
                    TypedValue<RankedTensorType> rankedTensor =
                        dyn_cast<TypedValue<RankedTensorType>>(result);
                    if (!rankedTensor) {
                      return MeshSharding();
                    }

                    assert(result.hasOneUse());
                    Operation *userOp = *result.getUsers().begin();
                    ShardOp shardOp = llvm::cast<ShardOp>(userOp);
                    return MeshSharding(shardOp.getSharding());
                  });
  return res;
}

static LogicalResult
spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
                 SymbolTableCollection &symbolTableCollection,
                 OpBuilder &builder) {
  Value targetSpmdValue;

  // Check if two shard ops are chained. If not, there is no need for
  // resharding as the source and target share the same sharding.
  ShardOp srcShardOp =
      dyn_cast_or_null<ShardOp>(shardOp.getSrc().getDefiningOp());
  if (!srcShardOp) {
    targetSpmdValue = spmdizationMap.lookup(shardOp.getSrc());
  } else {
    // Insert resharding.
    TypedValue<ShapedType> srcSpmdValue =
        cast<TypedValue<ShapedType>>(spmdizationMap.lookup(srcShardOp));
    targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
                              symbolTableCollection);
  }

  assert(!spmdizationMap.contains(shardOp.getResult()));
  spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
  return success();
}

static LogicalResult
spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
                 SymbolTableCollection &symbolTableCollection,
                 OpBuilder &builder) {
  if (isa<ShardingOp>(op)) {
    return success();
  }

  ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
  if (shardOp) {
    return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection,
                            builder);
  }

  SmallVector<Value> spmdizedOperands;
  llvm::transform(op.getOperands(), std::back_inserter(spmdizedOperands),
                  [&spmdizationMap](Value operand) {
                    assert(spmdizationMap.contains(operand));
                    return spmdizationMap.lookup(operand);
                  });
  return spmdizeOperation(op, spmdizedOperands, getOperandShardings(op),
                          getResultShardings(op), spmdizationMap,
                          symbolTableCollection, builder);
}

static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
                                  SymbolTableCollection &symbolTableCollection,
                                  OpBuilder &builder) {
  SmallVector<Location> argLocations;
  llvm::transform(block.getArguments(), std::back_inserter(argLocations),
                  [](BlockArgument arg) { return arg.getLoc(); });
  Block *newBlock = builder.createBlock(
      block.getParent(), {},
      shardedBlockArgumentTypes(block, symbolTableCollection), argLocations);
  for (auto [unshardedBlockArg, spmdizedBlockArg] :
       llvm::zip(block.getArguments(), newBlock->getArguments())) {
    spmdizationMap.map(unshardedBlockArg, spmdizedBlockArg);
  }

  OpBuilder::InsertionGuard insertionGuard(builder);
  builder.setInsertionPointToEnd(newBlock);
  for (Operation &op : block.getOperations()) {
    if (failed(spmdizeOperation(op, spmdizationMap, symbolTableCollection,
                                builder))) {
      return failure();
    }
  }

  return success();
}

static LogicalResult
spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
              SymbolTableCollection &symbolTableCollection) {
  OpBuilder builder(op.getFunctionBody());

  // Snapshot the original blocks to not mess up the iteration when adding new
  // blocks.
  SmallVector<Block *> originalBlocks;
  llvm::transform(op.getBlocks(), std::back_inserter(originalBlocks),
                  [](Block &b) { return &b; });

  for (Block *block : originalBlocks) {
    if (failed(spmdizeBlock(*block, spmdizationMap, symbolTableCollection,
                            builder))) {
      return failure();
    }
  }

  for (Block *block : originalBlocks) {
    block->erase();
  }

  // Find a return op and change the function results signature to its operands
  // signature.
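  // For example (illustrative shapes), if spmdization rewrote a returned
  // tensor<8xf32> into a tensor<2xf32> shard, the function result type is
  // updated to tensor<2xf32> as well.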
  Operation *returnOp = nullptr;
  for (Block &block : op.getFunctionBody()) {
    if (block.empty()) {
      continue;
    }

    if (block.back().hasTrait<OpTrait::ReturnLike>()) {
      returnOp = &block.back();
      break;
    }
  }
  assert(returnOp);
  op.setType(FunctionType::get(op->getContext(),
                               op.getFunctionBody().front().getArgumentTypes(),
                               returnOp->getOperandTypes()));

  return success();
}

namespace {

struct Spmdization : public impl::SpmdizationBase<Spmdization> {
  void runOnOperation() override {
    IRMapping spmdizationMap;
    SymbolTableCollection symbolTableCollection;
    if (failed(spmdizeFuncOp(getOperation(), spmdizationMap,
                             symbolTableCollection))) {
      return signalPassFailure();
    }
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    reshardingRegisterDependentDialects(registry);
    registry.insert<mesh::MeshDialect>();
  }
};

} // namespace

} // namespace mlir::mesh