1 //===- ShardingInterface.cpp -------------------------------------*- C++-*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" 10 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h" 11 12 #include "mlir/Dialect/Mesh/IR/MeshOps.h" 13 #include "mlir/IR/AffineMap.h" 14 #include "mlir/IR/IRMapping.h" 15 #include "mlir/Support/LLVM.h" 16 #include "llvm/ADT/ArrayRef.h" 17 #include "llvm/ADT/STLExtras.h" 18 #include "llvm/ADT/SmallSet.h" 19 #include "llvm/Support/Debug.h" 20 21 #include <utility> 22 23 #define DEBUG_TYPE "sharding-interface" 24 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") 25 26 using namespace mlir; 27 using namespace mlir::mesh; 28 29 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc" 30 31 //===----------------------------------------------------------------------===// 32 // common util functions 33 //===----------------------------------------------------------------------===// 34 35 static LogicalResult 36 checkOperandAffineExprRecursively(AffineExpr expr, 37 SmallVectorImpl<bool> &seenIds) { 38 switch (expr.getKind()) { 39 case AffineExprKind::Add: { 40 auto binOpExpr = cast<AffineBinaryOpExpr>(expr); 41 AffineExpr lhs = binOpExpr.getLHS(); 42 AffineExpr rhs = binOpExpr.getRHS(); 43 if (failed(checkOperandAffineExprRecursively(lhs, seenIds))) 44 return failure(); 45 if (failed(checkOperandAffineExprRecursively(rhs, seenIds))) 46 return failure(); 47 return success(); 48 } 49 case AffineExprKind::Mul: { 50 auto binOpExpr = cast<AffineBinaryOpExpr>(expr); 51 AffineExpr lhs = binOpExpr.getLHS(); 52 AffineExpr rhs = binOpExpr.getRHS(); 53 AffineExpr dimExpr; 54 if (lhs.getKind() == AffineExprKind::DimId && 55 rhs.getKind() == AffineExprKind::Constant) { 56 dimExpr = lhs; 57 } else if (rhs.getKind() == AffineExprKind::DimId && 58 lhs.getKind() == AffineExprKind::Constant) { 59 dimExpr = rhs; 60 } else 61 return failure(); 62 unsigned position = cast<AffineDimExpr>(dimExpr).getPosition(); 63 if ((size_t)position >= seenIds.size() || seenIds[position]) 64 return failure(); 65 seenIds[position] = true; 66 return success(); 67 } 68 case AffineExprKind::DimId: { 69 unsigned position = cast<AffineDimExpr>(expr).getPosition(); 70 if ((size_t)position >= seenIds.size() || seenIds[position]) 71 return failure(); 72 seenIds[position] = true; 73 return success(); 74 } 75 default: 76 return failure(); 77 } 78 } 79 80 static FailureOr<llvm::SmallSet<unsigned, 2>> 81 checkOperandAffineExpr(AffineExpr expr, unsigned numDims) { 82 SmallVector<bool> seenIds(numDims, false); 83 if (failed(checkOperandAffineExprRecursively(expr, seenIds))) 84 return failure(); 85 86 llvm::SmallSet<unsigned, 2> positions; 87 for (auto it : llvm::enumerate(seenIds)) { 88 if (it.value()) 89 positions.insert((unsigned)it.index()); 90 } 91 return positions; 92 } 93 94 template <typename T> 95 SmallVector<MeshAxesAttr> 96 fromArrayOfVector(MLIRContext *ctxt, const SmallVector<SmallVector<T>> &vec) { 97 SmallVector<MeshAxesAttr> res; 98 for (const auto &v : vec) { 99 res.emplace_back(MeshAxesAttr::get(ctxt, v)); 100 } 101 return res; 102 } 103 104 //===----------------------------------------------------------------------===// 105 // mesh::getMeshSharding 106 //===----------------------------------------------------------------------===// 107 108 FailureOr<std::pair<bool, MeshSharding>> 109 mesh::getMeshSharding(OpResult result) { 110 Value val = cast<Value>(result); 111 bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) { 112 auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user); 113 if (!shardOp) 114 return false; 115 return !shardOp.getAnnotateForUsers(); 116 }); 117 118 if (anyShardedForDef) { 119 // expected to have exact one use if it has a use of `mesh.shard` without 120 // unit attr annotate_for_users 121 if (!val.hasOneUse()) 122 return failure(); 123 auto shardOp = llvm::cast<mesh::ShardOp>(*val.getUsers().begin()); 124 return std::make_pair(false, MeshSharding(shardOp.getSharding())); 125 } 126 127 bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) { 128 auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user); 129 if (!shardOp) 130 return false; 131 return shardOp.getAnnotateForUsers(); 132 }); 133 if (anyShardedForUsers) { 134 SmallVector<ShardOp> shardOps; 135 for (Operation *user : val.getUsers()) { 136 ShardOp shardOp = llvm::dyn_cast<ShardOp>(user); 137 if (shardOp) 138 shardOps.push_back(shardOp); 139 } 140 MeshSharding shardForDef = shardOps[0].getSharding(); 141 for (size_t i = 1; i < shardOps.size(); ++i) { 142 // TODO: Deduce a reasonable mesh sharding attr for def when they are 143 // different 144 assert(shardForDef == shardOps[i].getSharding() && 145 "only support all shard ops have the same mesh sharding attr"); 146 } 147 return std::make_pair(true, shardForDef); 148 } 149 return failure(); 150 } 151 152 FailureOr<std::pair<bool, MeshSharding>> 153 mesh::getMeshSharding(OpOperand &opOperand) { 154 Value val = opOperand.get(); 155 if (ShardOp shardOp = val.getDefiningOp<ShardOp>()) 156 return std::make_pair(shardOp.getAnnotateForUsers(), 157 MeshSharding(shardOp.getSharding())); 158 159 return failure(); 160 } 161 162 //===----------------------------------------------------------------------===// 163 // ShardingInterface::verifyShardingInterfaceImpl 164 //===----------------------------------------------------------------------===// 165 166 LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() { 167 Operation *op = getOperation(); 168 169 // check operands and results type 170 for (Type type : op->getOperandTypes()) 171 if (!llvm::isa<RankedTensorType>(type)) 172 return failure(); 173 for (Type type : op->getResultTypes()) 174 if (!llvm::isa<RankedTensorType>(type)) 175 return failure(); 176 177 // check loop types 178 SmallVector<utils::IteratorType> loopTypes = getLoopIteratorTypes(); 179 if (loopTypes.empty()) 180 return failure(); 181 182 // check maps 183 SmallVector<AffineMap> maps = getIndexingMaps(); 184 if (maps.empty()) 185 return failure(); 186 unsigned numOperands = op->getNumOperands(); 187 unsigned numResults = op->getNumResults(); 188 if (numOperands + numResults != maps.size()) 189 return failure(); 190 191 for (OpResult result : op->getResults()) { 192 auto resultType = dyn_cast<RankedTensorType>(result.getType()); 193 if (!resultType) 194 return failure(); 195 AffineMap map = maps[numOperands + result.getResultNumber()]; 196 if (!map.isProjectedPermutation()) { 197 return failure(); 198 } 199 } 200 201 return success(); 202 } 203 204 //===----------------------------------------------------------------------===// 205 // ShardingInterface::printLoopTypesAndIndexingMaps 206 //===----------------------------------------------------------------------===// 207 208 void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) { 209 os << "print loop types and indexing maps for: \n"; 210 getOperation()->print(os); 211 os << "\n"; 212 os << "loop types: ["; 213 for (utils::IteratorType type : getLoopIteratorTypes()) { 214 os << stringifyEnum(type) << " "; 215 } 216 os << "]\n"; 217 os << "indexing maps: \n"; 218 for (AffineMap map : getIndexingMaps()) 219 os << map << "\n"; 220 os << "\n"; 221 } 222 223 //===----------------------------------------------------------------------===// 224 // detail::defaultGetShardingOption 225 //===----------------------------------------------------------------------===// 226 227 namespace { 228 229 // Update the given `shardingOption` according to `meshAxes` and `loopIdx` 230 static LogicalResult fillShardingOption(Operation *op, 231 ShardingOption &shardingOption, 232 FlatSymbolRefAttr mesh, 233 ArrayRef<MeshAxis> meshAxes, 234 unsigned loopIdx) { 235 if ((shardingOption.mesh && mesh && shardingOption.mesh != mesh) || 236 (!shardingOption.shardingArray[loopIdx].empty() && 237 shardingOption.shardingArray[loopIdx] != meshAxes)) { 238 LLVM_DEBUG(DBGS() << "sharding option conflicts on loop iterator " 239 << loopIdx << "\n"); 240 return failure(); 241 } 242 for (size_t i = 0; i < shardingOption.shardingArray.size(); ++i) { 243 if (i == loopIdx) 244 continue; 245 246 for (MeshAxis axis : meshAxes) { 247 if (llvm::is_contained(shardingOption.shardingArray[i], axis)) { 248 LLVM_DEBUG(DBGS() << "sharding option conflicts because mesh axes " 249 << axis << " duplicate"); 250 return failure(); 251 } 252 } 253 } 254 if (mesh) 255 shardingOption.mesh = mesh; 256 if (shardingOption.shardingArray[loopIdx].empty()) 257 shardingOption.shardingArray[loopIdx].append(meshAxes.begin(), 258 meshAxes.end()); 259 return success(); 260 } 261 262 } // namespace 263 264 FailureOr<ShardingOption> 265 mesh::detail::defaultGetShardingOption(Operation *op, 266 ArrayRef<MeshSharding> operandShardings, 267 ArrayRef<MeshSharding> resultShardings) { 268 ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op); 269 ShardingOption shardingOption; 270 271 if (failed(shardingOp.verifyShardingInterfaceImpl())) 272 return op->emitOpError() << "invalid sharding interface implementation"; 273 SmallVector<utils::IteratorType> loopTypes = 274 shardingOp.getLoopIteratorTypes(); 275 SmallVector<AffineMap> maps = shardingOp.getIndexingMaps(); 276 unsigned numOperands = op->getNumOperands(); 277 shardingOption.shardingArray.resize(loopTypes.size()); 278 llvm::SmallVector<MeshAxis> partialMeshAxes; 279 llvm::SmallSet<unsigned, 4> visitedLoopIndices; 280 bool anyShardingInResultsOrOperands = false; 281 282 // 1. Fill sharding option based on op results 283 for (auto shardingIt : llvm::enumerate(resultShardings)) { 284 MeshSharding shardAttr = shardingIt.value(); 285 if (!shardAttr) 286 continue; 287 AffineMap map = maps[numOperands + shardingIt.index()]; 288 anyShardingInResultsOrOperands = true; 289 // Handle the split axes: calculate the corresponding loop index for each 290 // split axes sub-array, and then store the sub-array to 291 // shardingOption[index] 292 for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) { 293 AffineExpr expr = std::get<0>(it); 294 ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef(); 295 auto dim = cast<AffineDimExpr>(expr); 296 unsigned index = dim.getPosition(); 297 visitedLoopIndices.insert(index); 298 if (failed(fillShardingOption(op, shardingOption, shardAttr.getMeshAttr(), 299 axes, index))) 300 return failure(); 301 } 302 303 // Handle the partial axes: at this stage, the exact loop index/indices 304 // cannot be decided because there could be multiple reduction loops. 305 ArrayRef<MeshAxis> partialAxes = shardAttr.getPartialAxes(); 306 if (!partialAxes.empty()) { 307 if (!partialMeshAxes.empty()) 308 return op->emitOpError() << "at most one result with partial axes is " 309 "supported at present"; 310 partialMeshAxes.append(partialAxes.begin(), partialAxes.end()); 311 // Add all the reduction loop indices to `visitedLoopIndices` if 312 // `partialAxes` is not empty 313 for (size_t loopIdx = 0; loopIdx < loopTypes.size(); ++loopIdx) { 314 if (isReductionLoop(loopTypes[loopIdx])) 315 visitedLoopIndices.insert(loopIdx); 316 } 317 } 318 } 319 320 // 2. Fill sharding option based on operands 321 for (auto shardingIt : llvm::enumerate(operandShardings)) { 322 MeshSharding shardAttr = shardingIt.value(); 323 if (!shardAttr) 324 continue; 325 326 anyShardingInResultsOrOperands = true; 327 AffineMap map = maps[shardingIt.index()]; 328 unsigned numDims = map.getNumDims(); 329 330 // Handle the split axes. Partial axes don't need to be handled because they 331 // only affect the defining op of the operand. 332 // 333 // TODO: Change to process the operands with single loop index first and 334 // then the operands with multiple loop indices. 335 for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) { 336 AffineExpr expr = std::get<0>(it); 337 ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef(); 338 FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices = 339 checkOperandAffineExpr(expr, numDims); 340 if (failed(loopIndices)) 341 return op->emitOpError() 342 << "operand's affine expression is restricted to const_i * " 343 "dim_i + const_j + dim_j + ..."; 344 if (loopIndices->empty()) 345 continue; 346 if (loopIndices->size() == 1) { 347 unsigned loopIdx = *loopIndices->begin(); 348 visitedLoopIndices.insert(loopIdx); 349 if (failed(fillShardingOption(op, shardingOption, 350 shardAttr.getMeshAttr(), axes, loopIdx))) 351 return failure(); 352 } 353 // If multiple loop indices correspond to a dimension of an operand, it is 354 // difficult to infer which loop indices are responsible for sharding. 355 // Therefore, the exact loop index must be specified by others. 356 if (loopIndices->size() > 1) { 357 bool seenLoopIndices = false; 358 for (unsigned loopIdx : *loopIndices) { 359 if (visitedLoopIndices.contains(loopIdx)) { 360 seenLoopIndices = true; 361 break; 362 } 363 } 364 if (!seenLoopIndices) 365 return op->emitOpError() 366 << "the operand " << shardingIt.index() 367 << " has multiple loop indices in a dimension, but none of " 368 "them could be found in the exactly specified annotation " 369 "of op results or operands."; 370 } 371 } 372 } 373 374 // 3. Finalize sharding option 375 if (!partialMeshAxes.empty()) { 376 bool anyNonEmptyReductionLoop = llvm::any_of( 377 llvm::enumerate(shardingOption.shardingArray), [&](auto it) { 378 SmallVector<MeshAxis> &subArray = it.value(); 379 int64_t idx = it.index(); 380 return isReductionLoop(loopTypes[idx]) && !subArray.empty(); 381 }); 382 if (!anyNonEmptyReductionLoop) { 383 bool filled = false; 384 for (size_t idx = 0; idx < loopTypes.size(); ++idx) { 385 if (isReductionLoop(loopTypes[idx])) { 386 std::ignore = fillShardingOption(op, shardingOption, nullptr, 387 partialMeshAxes, idx); 388 filled = true; 389 break; 390 } 391 } 392 if (!filled) 393 return op->emitOpError() << "no matched reduction loop found for the " 394 "result's partial type"; 395 } 396 } 397 removeTrailingEmptySubArray(shardingOption.shardingArray); 398 if (!anyShardingInResultsOrOperands) 399 shardingOption.empty = true; 400 return shardingOption; 401 } 402 403 // Get the sharding attributed for the given result and sharding option. 404 MeshSharding getSharding(OpResult result, const ShardingOption &shardingOption, 405 AffineMap map, ArrayRef<utils::IteratorType> loopTypes, 406 ArrayRef<ReductionKind> reductionLoopKinds) { 407 auto resultType = cast<RankedTensorType>(result.getType()); 408 SmallVector<SmallVector<MeshAxis>> splitAxes(resultType.getRank()); 409 SmallVector<MeshAxis> partialAxes; 410 411 // process the split axes 412 for (auto it : llvm::enumerate(map.getResults())) { 413 SmallVector<MeshAxis> tmp_axes; 414 AffineExpr expr = it.value(); 415 // `expr` must be an `AffineDimExpr` because `map` is verified by 416 // isProjectedPermutation 417 auto dim = cast<AffineDimExpr>(expr); 418 unsigned loopIdx = dim.getPosition(); 419 if (loopIdx < shardingOption.shardingArray.size()) 420 splitAxes[it.index()].append(shardingOption.shardingArray[loopIdx]); 421 } 422 423 // process the partial axes 424 // partialType will be ignored if partialAxes is empty 425 ReductionKind partialType = ReductionKind::Sum; 426 size_t reductionLoopKindsIdx = 0; 427 for (auto it : llvm::zip(loopTypes, shardingOption.shardingArray)) { 428 utils::IteratorType iType = std::get<0>(it); 429 if (isReductionLoop(iType)) { 430 ReductionKind curPartialType = reductionLoopKinds[reductionLoopKindsIdx]; 431 ++reductionLoopKindsIdx; 432 if (!partialAxes.empty()) 433 assert(partialType == curPartialType && 434 "Only one reduction type is supported"); 435 partialType = curPartialType; 436 const SmallVector<MeshAxis> &axis = std::get<1>(it); 437 partialAxes.append(axis); 438 } 439 } 440 441 removeTrailingEmptySubArray(splitAxes); 442 return MeshSharding::get(shardingOption.mesh, 443 fromArrayOfVector(result.getContext(), splitAxes), 444 partialAxes, partialType); 445 } 446 447 static FailureOr<MeshSharding> getSharding(OpOperand &opOperand, 448 const ShardingOption &shardingOption, 449 AffineMap map) { 450 Value operandValue = opOperand.get(); 451 auto operandType = cast<RankedTensorType>(operandValue.getType()); 452 SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank()); 453 unsigned numDims = map.getNumDims(); 454 for (auto it : llvm::enumerate(map.getResults())) { 455 int64_t idx = it.index(); 456 AffineExpr expr = it.value(); 457 FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices = 458 checkOperandAffineExpr(expr, numDims); 459 if (failed(loopIndices)) 460 return failure(); 461 SmallVector<unsigned> shardedLoopIndices; 462 for (unsigned loopIdx : *loopIndices) { 463 if ((size_t)loopIdx < shardingOption.shardingArray.size() && 464 !shardingOption.shardingArray[loopIdx].empty()) 465 shardedLoopIndices.push_back(loopIdx); 466 } 467 // mostly one sharded loop index is accepted 468 if (shardedLoopIndices.size() > 1) 469 return failure(); 470 if (shardedLoopIndices.size() == 1) { 471 splitAxes[idx].append( 472 shardingOption.shardingArray[shardedLoopIndices[0]]); 473 } 474 } 475 476 removeTrailingEmptySubArray(splitAxes); 477 return MeshSharding::get( 478 shardingOption.mesh, 479 fromArrayOfVector(opOperand.get().getContext(), splitAxes)); 480 } 481 482 FailureOr<std::vector<MeshSharding>> 483 mesh::detail::defaultGetShardingAnnotations( 484 Operation *op, const ShardingOption &shardingOption) { 485 std::vector<MeshSharding> res; 486 487 ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op); 488 SmallVector<utils::IteratorType> loopTypes = 489 shardingOp.getLoopIteratorTypes(); 490 SmallVector<ReductionKind> reductionKinds = 491 shardingOp.getReductionLoopIteratorKinds(); 492 SmallVector<AffineMap> maps = shardingOp.getIndexingMaps(); 493 unsigned numOperands = op->getNumOperands(); 494 495 for (OpOperand &opOperand : op->getOpOperands()) { 496 FailureOr<MeshSharding> shardingAttr = getSharding( 497 opOperand, shardingOption, maps[opOperand.getOperandNumber()]); 498 if (failed(shardingAttr)) 499 return failure(); 500 res.push_back(*shardingAttr); 501 } 502 503 for (OpResult result : op->getResults()) { 504 res.push_back(getSharding(result, shardingOption, 505 maps[numOperands + result.getResultNumber()], 506 loopTypes, reductionKinds)); 507 } 508 509 return res; 510 } 511 512 //===----------------------------------------------------------------------===// 513 // detail::defaultAddShardingAnnotations 514 //===----------------------------------------------------------------------===// 515 516 // To add a `mesh.shard` op for the given result, based on the details provided 517 // in `shardingOption`, `map`, and `loopTypes`. 518 static LogicalResult addShardOp(OpBuilder &b, OpResult result, 519 const ShardingOption &shardingOption, 520 AffineMap map, 521 ArrayRef<utils::IteratorType> loopTypes, 522 ArrayRef<ReductionKind> reductionLoopKinds) { 523 MeshSharding sharding = 524 getSharding(result, shardingOption, map, loopTypes, reductionLoopKinds); 525 maybeInsertTargetShardingAnnotation(sharding, result, b); 526 527 return success(); 528 } 529 530 // To add a `mesh.shard` op for the given operand, based on the details provided 531 // in `shardingOption`, `map`, and `loopTypes`. 532 static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand, 533 const ShardingOption &shardingOption, 534 AffineMap map) { 535 536 FailureOr<MeshSharding> sharding = 537 getSharding(opOperand, shardingOption, map); 538 if (failed(sharding)) { 539 return failure(); 540 } 541 OpBuilder::InsertionGuard guard(b); 542 maybeInsertSourceShardingAnnotation(sharding.value(), opOperand, b); 543 544 return success(); 545 } 546 547 LogicalResult mesh::detail::defaultAddShardingAnnotations( 548 Operation *op, OpBuilder &b, const ShardingOption &shardingOption) { 549 assert(!shardingOption.empty && shardingOption.mesh); 550 551 ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op); 552 SmallVector<utils::IteratorType> loopTypes = 553 shardingOp.getLoopIteratorTypes(); 554 SmallVector<ReductionKind> reductionKinds = 555 shardingOp.getReductionLoopIteratorKinds(); 556 SmallVector<AffineMap> maps = shardingOp.getIndexingMaps(); 557 unsigned numOperands = op->getNumOperands(); 558 559 // 1. add mesh.shard ops for all op results 560 for (OpResult result : op->getResults()) { 561 if (failed(addShardOp(b, result, shardingOption, 562 maps[numOperands + result.getResultNumber()], 563 loopTypes, reductionKinds))) 564 return failure(); 565 } 566 567 // 2. add mesh.shard ops for all operands 568 for (OpOperand &opOperand : op->getOpOperands()) { 569 if (failed(addShardOp(b, opOperand, shardingOption, 570 maps[opOperand.getOperandNumber()]))) 571 return failure(); 572 } 573 574 return success(); 575 } 576 577 #ifndef NDEBUG 578 static bool 579 isValueCompatibleWithFullReplicationSharding(Value value, 580 MeshSharding sharding) { 581 if (isa<RankedTensorType>(value.getType())) { 582 return sharding && isFullReplication(sharding); 583 } 584 585 return !sharding; 586 } 587 588 template <typename ValueRange, typename MeshShardingRage> 589 static bool 590 areValuesCompatibleWithFullReplicationShardings(ValueRange &&values, 591 MeshShardingRage &&shardings) { 592 if (std::size(values) != std::size(shardings)) { 593 return false; 594 } 595 return llvm::all_of( 596 llvm::zip_equal(std::forward<ValueRange>(values), 597 std::forward<MeshShardingRage>(shardings)), 598 [](auto valueAndSharding) { 599 return isValueCompatibleWithFullReplicationSharding( 600 std::get<0>(valueAndSharding), std::get<1>(valueAndSharding)); 601 }); 602 } 603 #endif // NDEBUG 604 605 void mesh::spmdizeFullyReplicatedOperation( 606 Operation &op, ArrayRef<Value> spmdizedOperands, 607 ArrayRef<MeshSharding> operandShardings, 608 ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap, 609 SymbolTableCollection &symbolTable, OpBuilder &builder) { 610 assert(spmdizedOperands.size() == operandShardings.size()); 611 assert(areValuesCompatibleWithFullReplicationShardings(op.getOperands(), 612 operandShardings)); 613 assert(areValuesCompatibleWithFullReplicationShardings(op.getResults(), 614 resultShardings)); 615 // `clone` will populate the mapping of old to new results. 616 builder.clone(op, spmdizationMap); 617 } 618 619 static void updateMeshAxisAssignmentForLoopIterators( 620 ArrayRef<MeshAxis> meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr, 621 SmallVector<std::optional<SmallVector<MeshAxis>>> 622 &meshAxesAssignmentForLoopIterators) { 623 AffineDimExpr affineDimExpr = cast<AffineDimExpr>(indexingExpr); 624 unsigned loopIteratorIdx = affineDimExpr.getPosition(); 625 if (meshAxesAssignmentForLoopIterators[loopIteratorIdx]) { 626 assert(llvm::equal(meshAxesAssignmentForTensorAxis, 627 *meshAxesAssignmentForLoopIterators[loopIteratorIdx])); 628 } else { 629 meshAxesAssignmentForLoopIterators[loopIteratorIdx] = 630 llvm::to_vector(meshAxesAssignmentForTensorAxis); 631 } 632 } 633 634 ShardingArray mesh::getMeshAxisAssignmentForLoopIterators( 635 ArrayRef<MeshSharding> operandShardings, 636 ArrayRef<MeshSharding> resultShardings, 637 ArrayRef<utils::IteratorType> loopIteratorTypes, 638 ArrayRef<AffineMap> indexingMaps) { 639 SmallVector<std::optional<SmallVector<MeshAxis>>> 640 meshAxisAssignmentForLoopIterators(loopIteratorTypes.size()); 641 std::vector<MeshSharding> operatorAndResultShardings; 642 operatorAndResultShardings.reserve(operandShardings.size() + 643 resultShardings.size()); 644 llvm::append_range(operatorAndResultShardings, operandShardings); 645 for (auto [sharding, affineMap] : 646 llvm::zip_equal(operatorAndResultShardings, indexingMaps)) { 647 if (!sharding) { 648 continue; 649 } 650 for (auto [meshAxesAssignmentForTensorAxis, indexingExpr] : 651 llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) { 652 updateMeshAxisAssignmentForLoopIterators( 653 meshAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr, 654 meshAxisAssignmentForLoopIterators); 655 } 656 // Missing trailing split axes means replication on those tensor dimensions. 657 for (unsigned i = sharding.getSplitAxes().size(); 658 i < affineMap.getNumResults(); ++i) { 659 updateMeshAxisAssignmentForLoopIterators( 660 {}, affineMap.getResults()[i], meshAxisAssignmentForLoopIterators); 661 } 662 } 663 664 ShardingArray res; 665 llvm::transform(meshAxisAssignmentForLoopIterators, std::back_inserter(res), 666 [](std::optional<SmallVector<MeshAxis>> &axes) { 667 if (!axes) { 668 return SmallVector<MeshAxis>(); 669 }; 670 return std::move(*axes); 671 }); 672 return res; 673 } 674 675 bool mesh::isAtLeastOneReductionIteratorSharded( 676 ArrayRef<utils::IteratorType> loopIteratorTypes, 677 ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) { 678 for (auto [loopIteratorType, meshAxisAssignment] : 679 llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) { 680 if (loopIteratorType == utils::IteratorType::reduction && 681 !meshAxisAssignment.empty()) { 682 return true; 683 } 684 } 685 return false; 686 } 687 688 SmallVector<MeshAxis> mesh::getReductionMeshAxes( 689 ArrayRef<utils::IteratorType> loopIteratorTypes, 690 ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) { 691 SmallVector<MeshAxis> meshAxes; 692 for (auto [loopIteratorType, meshAxisAssignment] : 693 llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) { 694 if (loopIteratorType == utils::IteratorType::reduction) { 695 llvm::append_range(meshAxes, meshAxisAssignment); 696 } 697 } 698 return meshAxes; 699 } 700 701 void mesh::spmdizeTriviallyShardableOperation( 702 Operation &op, ArrayRef<Value> spmdizedOperands, 703 ArrayRef<MeshSharding> operandShardings, 704 ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap, 705 SymbolTableCollection &symbolTable, OpBuilder &builder) { 706 // `clone` will populate the mapping of old to new results. 707 Operation *newOp = builder.clone(op, spmdizationMap); 708 // Set the result types to the sharded counterparts. 709 for (auto [oldResult, newResult, sharding] : 710 llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) { 711 newResult.setType( 712 shardType(newResult.getType(), 713 getMesh(&op, sharding.getMeshAttr(), symbolTable), sharding)); 714 } 715 } 716