1 //===- MeshOps.cpp - Mesh Dialect Operations ------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Mesh/IR/MeshOps.h" 10 11 #include "mlir/Dialect/Arith/IR/Arith.h" 12 #include "mlir/Dialect/Mesh/IR/MeshDialect.h" 13 #include "mlir/Dialect/Utils/StaticValueUtils.h" 14 #include "mlir/IR/Attributes.h" 15 #include "mlir/IR/BuiltinAttributes.h" 16 #include "mlir/IR/BuiltinTypeInterfaces.h" 17 #include "mlir/IR/BuiltinTypes.h" 18 #include "mlir/IR/Diagnostics.h" 19 #include "mlir/IR/DialectImplementation.h" 20 #include "mlir/IR/IRMapping.h" 21 #include "mlir/IR/Location.h" 22 #include "mlir/IR/PatternMatch.h" 23 #include "mlir/IR/TypeUtilities.h" 24 #include "mlir/IR/Value.h" 25 #include "mlir/Interfaces/ViewLikeInterface.h" 26 #include "mlir/Support/LLVM.h" 27 #include "mlir/Transforms/InliningUtils.h" 28 #include "llvm/ADT/ArrayRef.h" 29 #include "llvm/ADT/STLExtras.h" 30 #include "llvm/ADT/SmallSet.h" 31 #include "llvm/ADT/SmallVector.h" 32 #include "llvm/ADT/TypeSwitch.h" 33 #include "llvm/Support/Casting.h" 34 #include <algorithm> 35 #include <functional> 36 #include <iterator> 37 #include <numeric> 38 #include <optional> 39 #include <utility> 40 41 #define DEBUG_TYPE "mesh-ops" 42 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") 43 44 using namespace mlir; 45 using namespace mlir::mesh; 46 47 #include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc" 48 49 namespace { 50 51 struct DimensionSize { 52 static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); } 53 DimensionSize(int64_t val) : val(val) {} 54 int64_t value() const { return val; } 55 operator int64_t() const { return val; } 56 bool isDynamic() const { return ShapedType::isDynamic(val); } 57 58 private: 59 int64_t val; 60 }; 61 62 } // namespace 63 64 static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) { 65 if (lhs.isDynamic() || rhs.isDynamic()) { 66 return DimensionSize::dynamic(); 67 } 68 return lhs.value() / rhs.value(); 69 } 70 71 static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) { 72 if (lhs.isDynamic() || rhs.isDynamic()) { 73 return DimensionSize::dynamic(); 74 } 75 return lhs.value() * rhs.value(); 76 } 77 78 //===----------------------------------------------------------------------===// 79 // Inliner 80 //===----------------------------------------------------------------------===// 81 82 namespace { 83 struct MeshInlinerInterface : public DialectInlinerInterface { 84 using DialectInlinerInterface::DialectInlinerInterface; 85 // Currently no restrictions are encoded for inlining. 86 bool isLegalToInline(Operation *, Operation *, bool) const final { 87 return true; 88 } 89 bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final { 90 return true; 91 } 92 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { 93 return true; 94 } 95 }; 96 } // namespace 97 98 //===----------------------------------------------------------------------===// 99 // Mesh dialect 100 //===----------------------------------------------------------------------===// 101 102 void MeshDialect::initialize() { 103 addOperations< 104 #define GET_OP_LIST 105 #include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc" 106 >(); 107 addAttributes< 108 #define GET_ATTRDEF_LIST 109 #include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc" 110 >(); 111 addTypes< 112 #define GET_TYPEDEF_LIST 113 #include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc" 114 >(); 115 addInterface<MeshInlinerInterface>(); 116 } 117 118 Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value, 119 Type type, Location loc) { 120 return arith::ConstantOp::materialize(builder, value, type, loc); 121 } 122 123 //===----------------------------------------------------------------------===// 124 // Mesh utilities 125 //===----------------------------------------------------------------------===// 126 127 static FailureOr<MeshOp> getMeshAndVerify(Operation *op, 128 FlatSymbolRefAttr meshSymbol, 129 SymbolTableCollection &symbolTable) { 130 mesh::MeshOp mesh = getMeshOrNull(op, meshSymbol, symbolTable); 131 if (!mesh) { 132 return op->emitError() << "Undefined required mesh symbol \"" 133 << meshSymbol.getValue() << "\"."; 134 } 135 136 return mesh; 137 } 138 139 template <typename It> 140 bool isUnique(It begin, It end) { 141 if (begin == end) { 142 return true; 143 } 144 It next = std::next(begin); 145 if (next == end) { 146 return true; 147 } 148 for (; next != end; ++next, ++begin) { 149 if (*begin == *next) { 150 return false; 151 } 152 } 153 return true; 154 } 155 156 static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes, 157 MeshOp mesh) { 158 SmallVector<MeshAxis> sorted = llvm::to_vector(axes); 159 llvm::sort(sorted); 160 if (!isUnique(sorted.begin(), sorted.end())) { 161 return emitError(loc) << "Mesh axes contains duplicate elements."; 162 } 163 164 MeshAxis rank = mesh.getRank(); 165 for (auto axis : axes) { 166 if (axis >= rank || axis < 0) { 167 return emitError(loc) 168 << "0-based mesh axis index " << axis 169 << " is out of bounds. The referenced mesh \"" << mesh.getSymName() 170 << "\" is of rank " << rank << "."; 171 } 172 } 173 174 return success(); 175 } 176 177 template <typename Op> 178 static FailureOr<MeshOp> 179 getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) { 180 auto mesh = 181 ::getMeshAndVerify(op.getOperation(), op.getMeshAttr(), symbolTable); 182 if (failed(mesh)) { 183 return failure(); 184 } 185 if (failed(verifyMeshAxes(op.getLoc(), op.getMeshAxes(), mesh.value()))) { 186 return failure(); 187 } 188 return mesh; 189 } 190 191 template <typename InShape, typename MeshShape, typename SplitAxes, 192 typename OutShape> 193 static void shardShape(const InShape &inShape, const MeshShape &meshShape, 194 const SplitAxes &splitAxes, OutShape &outShape, 195 ArrayRef<int64_t> shardedDimsOffsets = {}, 196 ArrayRef<int64_t> haloSizes = {}) { 197 std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape), 198 llvm::adl_begin(outShape)); 199 200 if (!shardedDimsOffsets.empty()) { 201 auto isDynShape = ShapedType::isDynamicShape(meshShape); 202 uint64_t pos = 1; 203 for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) { 204 if (!innerSplitAxes.empty()) { 205 auto sz = shardedDimsOffsets[pos]; 206 bool same = !isDynShape; 207 if (same) { 208 // Find sharded dims in shardedDimsOffsets with same static size on 209 // all devices. Use kDynamic for dimensions with dynamic or 210 // non-uniform offs in shardedDimsOffsets. 211 uint64_t numShards = 0; 212 for (auto i : innerSplitAxes.asArrayRef()) { 213 numShards += meshShape[i]; 214 } 215 for (size_t i = 1; i < numShards; ++i) { 216 if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] != 217 sz) { 218 same = false; 219 break; 220 } 221 } 222 pos += numShards + 1; 223 } 224 outShape[tensorAxis] = same ? sz : ShapedType::kDynamic; 225 } 226 } 227 } else { 228 for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) { 229 outShape[tensorAxis] = shardDimension( 230 inShape[tensorAxis], 231 collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape)); 232 } 233 234 if (!haloSizes.empty()) { 235 // add halo sizes if requested 236 int haloAxis = 0; 237 for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) { 238 if (!ShapedType::isDynamic(outShape[tensorAxis]) && 239 !innerSplitAxes.empty()) { 240 if (haloSizes[haloAxis * 2] >= 0 && 241 haloSizes[haloAxis * 2 + 1] >= 0) { 242 outShape[tensorAxis] += 243 haloSizes[haloAxis * 2] + haloSizes[haloAxis * 2 + 1]; 244 ++haloAxis; 245 } else { 246 outShape[tensorAxis] = ShapedType::kDynamic; 247 } 248 } 249 } 250 } 251 } 252 } 253 254 ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh, 255 MeshSharding sharding) { 256 using Dim = std::decay_t<decltype(shape.getDimSize(0))>; 257 SmallVector<Dim> resShapeArr(shape.getShape().size()); 258 shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(), 259 resShapeArr, sharding.getStaticShardedDimsOffsets(), 260 sharding.getStaticHaloSizes()); 261 return shape.clone(resShapeArr); 262 } 263 264 Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) { 265 RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type); 266 if (rankedTensorType) { 267 return shardShapedType(rankedTensorType, mesh, sharding); 268 } 269 return type; 270 } 271 272 void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding, 273 OpOperand &operand, 274 OpBuilder &builder) { 275 OpBuilder::InsertionGuard insertionGuard(builder); 276 Value operandValue = operand.get(); 277 Operation *operandOp = operand.getOwner(); 278 builder.setInsertionPointAfterValue(operandValue); 279 ShardOp shardOp = dyn_cast<ShardOp>(operandOp); 280 if (shardOp && sharding == shardOp.getSharding() && 281 !shardOp.getAnnotateForUsers()) { 282 // No need for anything the correct sharding is already set. 283 return; 284 } 285 286 auto shardingOp = builder.create<ShardingOp>(operandValue.getLoc(), sharding); 287 auto newShardOp = 288 builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp, 289 /*annotate_for_users*/ false); 290 IRRewriter rewriter(builder); 291 rewriter.replaceUsesWithIf( 292 operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) { 293 return use.getOwner() == operandOp && use.get() == operandValue; 294 }); 295 296 if (!shardOp || shardOp.getAnnotateForUsers()) { 297 return; 298 } 299 300 auto newShardOp2 = 301 builder.create<ShardOp>(operandValue.getLoc(), newShardOp, shardingOp, 302 /*annotate_for_users*/ true); 303 rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2); 304 } 305 306 void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding, 307 OpResult result, 308 OpBuilder &builder) { 309 for (auto &use : llvm::make_early_inc_range(result.getUses())) { 310 maybeInsertTargetShardingAnnotation(sharding, use, builder); 311 } 312 } 313 314 void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding, 315 OpOperand &operand, 316 OpBuilder &builder) { 317 OpBuilder::InsertionGuard insertionGuard(builder); 318 Value operandValue = operand.get(); 319 Operation *operandOp = operand.getOwner(); 320 Operation *operandSrcOp = operandValue.getDefiningOp(); 321 bool isBlockArg = !operandSrcOp; 322 ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp); 323 324 if (shardOp && sharding == shardOp.getSharding() && 325 shardOp.getAnnotateForUsers()) { 326 // No need for anything the correct sharding is already set. 327 return; 328 } 329 330 builder.setInsertionPoint(operandOp); 331 auto shardingOp = 332 builder.create<ShardingOp>(operand.get().getLoc(), sharding); 333 auto newShardOp = 334 builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp, 335 /*annotate_for_users*/ true); 336 IRRewriter rewriter(builder); 337 rewriter.replaceUsesWithIf( 338 operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) { 339 return use.getOwner() == operandOp && use.get() == operandValue; 340 }); 341 342 if (isBlockArg || !shardOp || !shardOp.getAnnotateForUsers()) { 343 // No need for resharding. 344 return; 345 } 346 347 builder.setInsertionPoint(newShardOp); 348 auto newPreceedingShardOp = 349 builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp, 350 /*annotate_for_users*/ false); 351 rewriter.replaceUsesWithIf( 352 newShardOp.getSrc(), newPreceedingShardOp, [&newShardOp](OpOperand &use) { 353 return use.getOwner() == newShardOp.getOperation(); 354 }); 355 } 356 357 //===----------------------------------------------------------------------===// 358 // mesh.mesh op 359 //===----------------------------------------------------------------------===// 360 361 LogicalResult MeshOp::verify() { 362 int64_t rank = getRank(); 363 364 if (rank <= 0) 365 return emitOpError("rank of mesh is expected to be a positive integer"); 366 367 for (int64_t dimSize : getShape()) { 368 if (dimSize < 0 && !ShapedType::isDynamic(dimSize)) 369 return emitOpError("dimension size of a mesh is expected to be " 370 "non-negative or dynamic"); 371 } 372 373 return success(); 374 } 375 376 //===----------------------------------------------------------------------===// 377 // mesh.mesh_shape op 378 //===----------------------------------------------------------------------===// 379 380 LogicalResult 381 MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 382 auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable); 383 if (failed(mesh)) { 384 return failure(); 385 } 386 if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) { 387 return failure(); 388 } 389 390 size_t expectedResultsCount = 391 getAxes().empty() ? mesh->getRank() : getAxes().size(); 392 if (getResult().size() != expectedResultsCount) { 393 return emitError() << "Unexpected number of results " << getResult().size() 394 << ". Expected " << expectedResultsCount << "."; 395 } 396 397 return success(); 398 } 399 400 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState, 401 MeshOp mesh) { 402 build(odsBuilder, odsState, mesh, SmallVector<MeshAxis>()); 403 } 404 405 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState, 406 MeshOp mesh, ArrayRef<MeshAxis> axes) { 407 build(odsBuilder, odsState, 408 SmallVector<Type>(axes.empty() ? mesh.getRank() : axes.size(), 409 odsBuilder.getIndexType()), 410 mesh.getSymName(), MeshAxesAttr::get(odsBuilder.getContext(), axes)); 411 } 412 413 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState, 414 StringRef mesh, ArrayRef<MeshAxis> axes) { 415 assert(!axes.empty()); 416 build(odsBuilder, odsState, 417 SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh, 418 MeshAxesAttr::get(odsBuilder.getContext(), axes)); 419 } 420 421 void MeshShapeOp::getAsmResultNames( 422 function_ref<void(Value, StringRef)> setNameFn) { 423 setNameFn(getResults()[0], "mesh_shape"); 424 } 425 426 //===----------------------------------------------------------------------===// 427 // mesh.sharding 428 //===----------------------------------------------------------------------===// 429 430 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, 431 FlatSymbolRefAttr mesh, 432 ArrayRef<MeshAxesAttr> split_axes, 433 ArrayRef<MeshAxis> partial_axes, 434 mesh::ReductionKind partial_type, 435 ArrayRef<int64_t> static_halo_sizes, 436 ArrayRef<int64_t> static_sharded_dims_offsets) { 437 return build( 438 b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), 439 ::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes), 440 ::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type), 441 ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halo_sizes), {}, 442 ::mlir::DenseI64ArrayAttr::get(b.getContext(), 443 static_sharded_dims_offsets), 444 {}); 445 } 446 447 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, 448 FlatSymbolRefAttr mesh, 449 ArrayRef<MeshAxesAttr> split_axes) { 450 return build( 451 b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {}, 452 ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum), 453 {}, {}, {}, {}); 454 } 455 456 void ShardingOp::build( 457 ::mlir::OpBuilder &b, ::mlir::OperationState &odsState, 458 FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes, 459 ::mlir::ArrayRef<::mlir::OpFoldResult> halo_sizes, 460 ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_offsets) { 461 mlir::SmallVector<int64_t> staticHalos, staticDims; 462 mlir::SmallVector<mlir::Value> dynamicHalos, dynamicDims; 463 dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos); 464 dispatchIndexOpFoldResults(sharded_dims_offsets, dynamicDims, staticDims); 465 return build( 466 b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {}, 467 ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum), 468 ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), dynamicHalos, 469 ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticDims), dynamicDims); 470 } 471 472 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, 473 mlir::mesh::MeshSharding from) { 474 475 build(b, odsState, ShardingType::get(b.getContext()), from.getMeshAttr(), 476 MeshAxesArrayAttr::get(b.getContext(), from.getSplitAxes()), 477 from.getPartialAxes().empty() 478 ? DenseI16ArrayAttr() 479 : b.getDenseI16ArrayAttr(from.getPartialAxes()), 480 ::mlir::mesh::ReductionKindAttr::get(b.getContext(), 481 from.getPartialType()), 482 from.getStaticShardedDimsOffsets().empty() 483 ? DenseI64ArrayAttr() 484 : b.getDenseI64ArrayAttr(from.getStaticShardedDimsOffsets()), 485 from.getDynamicShardedDimsOffsets(), 486 from.getStaticHaloSizes().empty() 487 ? DenseI64ArrayAttr() 488 : b.getDenseI64ArrayAttr(from.getStaticHaloSizes()), 489 from.getDynamicHaloSizes()); 490 } 491 492 LogicalResult ShardingOp::verify() { 493 llvm::SmallSet<MeshAxis, 4> visitedAxes; 494 495 auto checkMeshAxis = [&](ArrayRef<MeshAxis> axesArray) -> LogicalResult { 496 for (MeshAxis axis : axesArray) { 497 if (axis < 0) 498 return emitError() << "mesh axis is expected to be non-negative"; 499 if (!visitedAxes.insert(axis).second) 500 return emitError() << "mesh axis duplicated"; 501 } 502 return success(); 503 }; 504 505 for (auto subAxes : getSplitAxes().getAxes()) { 506 ArrayRef<MeshAxis> subAxesArray = subAxes.asArrayRef(); 507 if (failed(checkMeshAxis(subAxesArray))) 508 return failure(); 509 } 510 if (getPartialAxes().has_value() && 511 failed(checkMeshAxis(getPartialAxes().value()))) 512 return failure(); 513 514 if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) { 515 return emitOpError("halo sizes and shard offsets are mutually exclusive"); 516 } 517 518 if (!getStaticHaloSizes().empty()) { 519 auto numSplitAxes = getSplitAxes().getAxes().size(); 520 for (auto splitAxis : getSplitAxes().getAxes()) { 521 if (splitAxis.empty()) { 522 --numSplitAxes; 523 } 524 } 525 if (getStaticHaloSizes().size() != numSplitAxes * 2) { 526 return emitError() << "halo sizes must be specified for all split axes."; 527 } 528 } 529 530 return success(); 531 } 532 533 void ShardingOp::getAsmResultNames( 534 function_ref<void(Value, StringRef)> setNameFn) { 535 setNameFn(getResult(), "sharding"); 536 } 537 538 LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 539 auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable); 540 if (failed(mesh)) { 541 return failure(); 542 } 543 if (mlir::ShapedType::isDynamicShape(mesh->getShape()) && 544 getStaticShardedDimsOffsets().size() > 0) { 545 return emitError() << "sharded dims offsets are not allowed for " 546 "devices meshes with dynamic shape."; 547 } 548 549 auto shardedDimsOffsets = getStaticShardedDimsOffsets(); 550 if (!shardedDimsOffsets.empty()) { 551 auto meshShape = mesh.value().getShape(); 552 assert(!ShapedType::isDynamicShape(meshShape)); 553 uint64_t pos = 0; 554 for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(getSplitAxes())) { 555 if (!innerSplitAxes.empty()) { 556 int64_t numShards = 0, off = 0; 557 for (auto i : innerSplitAxes.asArrayRef()) { 558 numShards += meshShape[i]; 559 } 560 for (int64_t i = 0; i <= numShards; ++i) { 561 if (shardedDimsOffsets.size() <= pos + i) { 562 return emitError() << "sharded dims offsets has wrong size."; 563 } 564 if (!ShapedType::isDynamic(shardedDimsOffsets[pos + i])) { 565 if (shardedDimsOffsets[pos + i] < off) { 566 return emitError() 567 << "sharded dims offsets must be non-decreasing."; 568 } 569 off = shardedDimsOffsets[pos + i]; 570 } 571 } 572 pos += numShards + 1; 573 } 574 } 575 } 576 return success(); 577 } 578 579 namespace { 580 // Sharding annotations "halo sizes" and "sharded dims offsets" 581 // are a mix of attributes and dynamic values. This canonicalization moves 582 // constant values to the respective attribute lists and so minimizes the number 583 // of values. 584 class FoldDynamicLists final : public OpRewritePattern<ShardingOp> { 585 public: 586 using OpRewritePattern<ShardingOp>::OpRewritePattern; 587 588 LogicalResult matchAndRewrite(ShardingOp op, 589 PatternRewriter &b) const override { 590 auto mixedHalos = 591 getMixedValues(op.getStaticHaloSizes(), op.getDynamicHaloSizes(), b); 592 auto mixedOffs = getMixedValues(op.getStaticShardedDimsOffsets(), 593 op.getDynamicShardedDimsOffsets(), b); 594 595 // No constant operands were folded, just return; 596 if (failed(foldDynamicIndexList(mixedHalos, /*onlyNonNegative=*/true)) && 597 failed(foldDynamicIndexList(mixedOffs, /*onlyNonNegative=*/true))) { 598 return failure(); 599 } 600 601 auto halos = decomposeMixedValues(mixedHalos); 602 auto offs = decomposeMixedValues(mixedOffs); 603 604 op.setStaticHaloSizes(halos.first); 605 op.getDynamicHaloSizesMutable().assign(halos.second); 606 op.setStaticShardedDimsOffsets(offs.first); 607 op.getDynamicShardedDimsOffsetsMutable().assign(offs.second); 608 609 return success(); 610 } 611 }; 612 } // namespace 613 614 void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results, 615 mlir::MLIRContext *context) { 616 results.add<FoldDynamicLists>(context); 617 } 618 619 //===----------------------------------------------------------------------===// 620 // MeshSharding 621 //===----------------------------------------------------------------------===// 622 623 bool MeshSharding::equalSplitAndPartialAxes(const MeshSharding &rhs) const { 624 if (getMesh() != rhs.getMesh()) { 625 return false; 626 } 627 628 if (getPartialAxes().size() != rhs.getPartialAxes().size() || 629 (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) || 630 !llvm::equal( 631 llvm::make_range(getPartialAxes().begin(), getPartialAxes().end()), 632 llvm::make_range(rhs.getPartialAxes().begin(), 633 rhs.getPartialAxes().end()))) { 634 return false; 635 } 636 637 auto minSize = std::min(getSplitAxes().size(), rhs.getSplitAxes().size()); 638 if (!llvm::equal(llvm::make_range(getSplitAxes().begin(), 639 getSplitAxes().begin() + minSize), 640 llvm::make_range(rhs.getSplitAxes().begin(), 641 rhs.getSplitAxes().begin() + minSize))) { 642 return false; 643 } 644 645 return llvm::all_of(llvm::make_range(getSplitAxes().begin() + minSize, 646 getSplitAxes().end()), 647 std::mem_fn(&MeshAxesAttr::empty)) && 648 llvm::all_of(llvm::make_range(rhs.getSplitAxes().begin() + minSize, 649 rhs.getSplitAxes().end()), 650 std::mem_fn(&MeshAxesAttr::empty)); 651 } 652 653 bool MeshSharding::equalHaloAndShardSizes(const MeshSharding &rhs) const { 654 return equalShardSizes(rhs) && equalHaloSizes(rhs); 655 } 656 657 bool MeshSharding::equalShardSizes(const MeshSharding &rhs) const { 658 if (rhs.getStaticShardedDimsOffsets().size() != 659 getStaticShardedDimsOffsets().size() || 660 !llvm::equal(llvm::make_range(getStaticShardedDimsOffsets().begin(), 661 getStaticShardedDimsOffsets().end()), 662 llvm::make_range(rhs.getStaticShardedDimsOffsets().begin(), 663 rhs.getStaticShardedDimsOffsets().end()))) { 664 return false; 665 } 666 if (rhs.getDynamicShardedDimsOffsets().size() != 667 getDynamicShardedDimsOffsets().size() || 668 !llvm::equal( 669 llvm::make_range(getDynamicShardedDimsOffsets().begin(), 670 getDynamicShardedDimsOffsets().end()), 671 llvm::make_range(rhs.getDynamicShardedDimsOffsets().begin(), 672 rhs.getDynamicShardedDimsOffsets().end()))) { 673 return false; 674 } 675 return true; 676 } 677 678 bool MeshSharding::equalHaloSizes(const MeshSharding &rhs) const { 679 if (rhs.getStaticHaloSizes().size() != getStaticHaloSizes().size() || 680 !llvm::equal(llvm::make_range(getStaticHaloSizes().begin(), 681 getStaticHaloSizes().end()), 682 llvm::make_range(rhs.getStaticHaloSizes().begin(), 683 rhs.getStaticHaloSizes().end()))) { 684 return false; 685 } 686 if (rhs.getDynamicHaloSizes().size() != getDynamicHaloSizes().size() || 687 !llvm::equal(llvm::make_range(getDynamicHaloSizes().begin(), 688 getDynamicHaloSizes().end()), 689 llvm::make_range(rhs.getDynamicHaloSizes().begin(), 690 rhs.getDynamicHaloSizes().end()))) { 691 return false; 692 } 693 return true; 694 } 695 696 bool MeshSharding::operator==(Value rhs) const { 697 return equalSplitAndPartialAxes(rhs) && equalHaloAndShardSizes(rhs); 698 } 699 700 bool MeshSharding::operator!=(Value rhs) const { return !(*this == rhs); } 701 702 bool MeshSharding::operator==(const MeshSharding &rhs) const { 703 return equalSplitAndPartialAxes(rhs) && equalHaloAndShardSizes(rhs); 704 } 705 706 bool MeshSharding::operator!=(const MeshSharding &rhs) const { 707 return !(*this == rhs); 708 } 709 710 MeshSharding::MeshSharding(Value rhs) { 711 auto shardingOp = mlir::dyn_cast<ShardingOp>(rhs.getDefiningOp()); 712 assert(shardingOp && "expected sharding op"); 713 *this = get(shardingOp.getMeshAttr(), shardingOp.getSplitAxes().getAxes(), 714 shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>()), 715 shardingOp.getPartialType().value_or(ReductionKind::Sum), 716 shardingOp.getStaticHaloSizes(), 717 shardingOp.getStaticShardedDimsOffsets(), 718 SmallVector<Value>(shardingOp.getDynamicHaloSizes()), 719 SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets())); 720 } 721 722 MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_, 723 ArrayRef<MeshAxesAttr> split_axes_, 724 ArrayRef<MeshAxis> partial_axes_, 725 ReductionKind partial_type_, 726 ArrayRef<int64_t> static_halo_sizes_, 727 ArrayRef<int64_t> static_sharded_dims_offsets_, 728 ArrayRef<Value> dynamic_halo_sizes_, 729 ArrayRef<Value> dynamic_sharded_dims_offsets_) { 730 MeshSharding res; 731 res.mesh = mesh_; 732 res.split_axes.resize(split_axes_.size()); 733 for (auto [i, axis] : llvm::enumerate(split_axes_)) { 734 res.split_axes[i] = 735 MeshAxesAttr::get(mesh_.getContext(), axis.asArrayRef()); 736 } 737 738 auto clone = [](const auto src, auto &dst) { 739 dst.resize(src.size()); 740 llvm::copy(src, dst.begin()); 741 }; 742 743 clone(partial_axes_, res.partial_axes); 744 res.partial_type = partial_type_; 745 clone(static_halo_sizes_, res.static_halo_sizes); 746 clone(static_sharded_dims_offsets_, res.static_sharded_dims_offsets); 747 clone(dynamic_halo_sizes_, res.dynamic_halo_sizes); 748 clone(dynamic_sharded_dims_offsets_, res.dynamic_sharded_dims_offsets); 749 750 return res; 751 } 752 753 //===----------------------------------------------------------------------===// 754 // mesh.shard_shape 755 //===----------------------------------------------------------------------===// 756 757 void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder, 758 ::mlir::OperationState &odsState, 759 ::llvm::ArrayRef<int64_t> shape, 760 ::mlir::Value sharding, ::mlir::Value device) { 761 SmallVector<mlir::Type> resType(shape.size(), odsBuilder.getIndexType()); 762 build(odsBuilder, odsState, resType, shape, sharding, device); 763 } 764 765 //===----------------------------------------------------------------------===// 766 // mesh.shard op 767 //===----------------------------------------------------------------------===// 768 769 void ShardOp::getAsmResultNames( 770 function_ref<void(Value, StringRef)> setNameFn) { 771 setNameFn(getResult(), "sharding_annotated"); 772 } 773 774 //===----------------------------------------------------------------------===// 775 // mesh.process_multi_index op 776 //===----------------------------------------------------------------------===// 777 778 LogicalResult 779 ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 780 auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable); 781 if (failed(mesh)) { 782 return failure(); 783 } 784 if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) { 785 return failure(); 786 } 787 788 size_t expectedResultsCount = 789 getAxes().empty() ? mesh->getRank() : getAxes().size(); 790 if (getResult().size() != expectedResultsCount) { 791 return emitError() << "Unexpected number of results " << getResult().size() 792 << ". Expected " << expectedResultsCount << "."; 793 } 794 795 return success(); 796 } 797 798 void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState, 799 MeshOp mesh) { 800 build(odsBuilder, odsState, 801 SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()), 802 mesh.getSymName(), ArrayRef<MeshAxis>()); 803 } 804 805 void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState, 806 StringRef mesh, ArrayRef<MeshAxis> axes) { 807 build(odsBuilder, odsState, 808 SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh, 809 MeshAxesAttr::get(odsBuilder.getContext(), axes)); 810 } 811 812 void ProcessMultiIndexOp::getAsmResultNames( 813 function_ref<void(Value, StringRef)> setNameFn) { 814 setNameFn(getResults()[0], "proc_linear_idx"); 815 } 816 817 //===----------------------------------------------------------------------===// 818 // mesh.process_linear_index op 819 //===----------------------------------------------------------------------===// 820 821 LogicalResult 822 ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 823 auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable); 824 if (failed(mesh)) { 825 return failure(); 826 } 827 return success(); 828 } 829 830 void ProcessLinearIndexOp::build(OpBuilder &odsBuilder, 831 OperationState &odsState, MeshOp mesh) { 832 build(odsBuilder, odsState, mesh.getSymName()); 833 } 834 835 void ProcessLinearIndexOp::getAsmResultNames( 836 function_ref<void(Value, StringRef)> setNameFn) { 837 setNameFn(getResult(), "proc_linear_idx"); 838 } 839 840 //===----------------------------------------------------------------------===// 841 // mesh.neighbors_linear_indices op 842 //===----------------------------------------------------------------------===// 843 844 LogicalResult 845 NeighborsLinearIndicesOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 846 auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable); 847 if (failed(mesh)) { 848 return failure(); 849 } 850 return success(); 851 } 852 853 void NeighborsLinearIndicesOp::getAsmResultNames( 854 function_ref<void(Value, StringRef)> setNameFn) { 855 setNameFn(getNeighborDown(), "down_linear_idx"); 856 setNameFn(getNeighborUp(), "up_linear_idx"); 857 } 858 859 //===----------------------------------------------------------------------===// 860 // collective communication ops 861 //===----------------------------------------------------------------------===// 862 863 namespace { 864 865 template <typename Op> 866 struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> { 867 using OpRewritePattern<Op>::OpRewritePattern; 868 LogicalResult matchAndRewrite(Op op, 869 PatternRewriter &rewriter) const override { 870 auto meshAxes = op.getMeshAxes(); 871 if (!meshAxes.empty()) { 872 return failure(); 873 } 874 if (op.getInput().getType() != op.getResult().getType()) { 875 return failure(); 876 } 877 878 rewriter.replaceAllUsesWith(op.getResult(), op.getInput()); 879 rewriter.eraseOp(op.getOperation()); 880 return success(); 881 } 882 }; 883 884 } // namespace 885 886 static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName, 887 ArrayRef<int64_t> device, 888 Operation::operand_range deviceDynamic, 889 ArrayRef<MeshAxis> meshAxes, 890 ArrayRef<int64_t> meshShape) { 891 if (device.size() != meshAxes.size()) { 892 return emitError(loc) << "In-group device \"" << deviceName 893 << "\" has unexpected multi-index size " 894 << device.size() << ". Expected " << meshAxes.size() 895 << "."; 896 } 897 898 for (size_t i = 0; i < device.size(); ++i) { 899 if (!ShapedType::isDynamic(device[i]) && 900 !ShapedType::isDynamic(meshShape[meshAxes[i]]) && 901 meshShape[meshAxes[i]] <= device[i]) { 902 return emitError(loc) 903 << "Out of bounds coordinate " << i << " for in-group device \"" 904 << deviceName << "\"." 905 << " Got " << device[i] << ", but expected value in the range [0, " 906 << (meshShape[meshAxes[i]] - 1) << "]."; 907 } 908 } 909 return success(); 910 } 911 912 template <typename It> 913 static auto product(It begin, It end) { 914 using ElementType = std::decay_t<decltype(*begin)>; 915 return std::accumulate(begin, end, static_cast<ElementType>(1), 916 std::multiplies<ElementType>()); 917 } 918 919 template <typename R> 920 static auto product(R &&range) { 921 return product(adl_begin(range), adl_end(range)); 922 } 923 924 static LogicalResult verifyDimensionCompatibility(Location loc, 925 int64_t expectedDimSize, 926 int64_t resultDimSize, 927 int64_t resultAxis) { 928 if (!ShapedType::isDynamic(resultDimSize) && 929 expectedDimSize != resultDimSize) { 930 return emitError(loc) << "Dimension size mismatch for result axis " 931 << resultAxis << ". Expected " 932 << (ShapedType::isDynamic(expectedDimSize) 933 ? Twine("dynamic") 934 : Twine(expectedDimSize)) 935 << ", but got " << resultDimSize << "."; 936 } 937 938 return success(); 939 } 940 941 static LogicalResult verifyGatherOperandAndResultShape( 942 Value operand, Value result, int64_t gatherAxis, 943 ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) { 944 auto resultRank = cast<ShapedType>(result.getType()).getRank(); 945 if (gatherAxis < 0 || gatherAxis >= resultRank) { 946 return emitError(result.getLoc()) 947 << "Gather axis " << gatherAxis << " is out of bounds [0, " 948 << resultRank << ")."; 949 } 950 951 ShapedType operandType = cast<ShapedType>(operand.getType()); 952 ShapedType resultType = cast<ShapedType>(result.getType()); 953 auto deviceGroupSize = 954 DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape)); 955 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) { 956 auto operandDimSize = DimensionSize(operandType.getDimSize(axis)); 957 auto resultDimSize = DimensionSize(resultType.getDimSize(axis)); 958 auto expectedResultDimSize = 959 axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize; 960 if (failed(verifyDimensionCompatibility( 961 result.getLoc(), expectedResultDimSize, resultDimSize, axis))) { 962 return failure(); 963 } 964 } 965 return success(); 966 } 967 968 static LogicalResult verifyAllToAllOperandAndResultShape( 969 Value operand, Value result, int64_t splitAxis, int64_t concatAxis, 970 ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) { 971 ShapedType operandType = cast<ShapedType>(operand.getType()); 972 ShapedType resultType = cast<ShapedType>(result.getType()); 973 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) { 974 if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) { 975 if (failed(verifyDimensionCompatibility( 976 result.getLoc(), operandType.getDimSize(axis), 977 resultType.getDimSize(axis), axis))) { 978 return failure(); 979 } 980 } 981 } 982 983 if (splitAxis == concatAxis) { 984 return success(); 985 } 986 987 auto deviceGroupSize = 988 DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape)); 989 auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis)); 990 auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis)); 991 DimensionSize expectedResultConcatDimSize = 992 operandConcatDimSize * deviceGroupSize; 993 DimensionSize expectedResultSplitDimSize = 994 operandSplitDimSize / deviceGroupSize; 995 if (!expectedResultSplitDimSize.isDynamic() && 996 int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) { 997 expectedResultSplitDimSize = DimensionSize::dynamic(); 998 } 999 if (failed(verifyDimensionCompatibility( 1000 result.getLoc(), expectedResultConcatDimSize.value(), 1001 resultType.getDimSize(concatAxis), concatAxis))) { 1002 return failure(); 1003 } 1004 if (failed(verifyDimensionCompatibility( 1005 result.getLoc(), expectedResultSplitDimSize.value(), 1006 resultType.getDimSize(splitAxis), splitAxis))) { 1007 return failure(); 1008 } 1009 1010 return success(); 1011 } 1012 1013 static LogicalResult verifyScatterOrSliceOperandAndResultShape( 1014 Value operand, Value result, int64_t tensorAxis, 1015 ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) { 1016 ShapedType operandType = cast<ShapedType>(operand.getType()); 1017 ShapedType resultType = cast<ShapedType>(result.getType()); 1018 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) { 1019 if (axis != tensorAxis) { 1020 if (failed(verifyDimensionCompatibility( 1021 result.getLoc(), operandType.getDimSize(axis), 1022 resultType.getDimSize(axis), axis))) { 1023 return failure(); 1024 } 1025 } 1026 } 1027 1028 auto deviceGroupSize = 1029 DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape)); 1030 auto operandScatterDimSize = 1031 DimensionSize(operandType.getDimSize(tensorAxis)); 1032 if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() && 1033 int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) { 1034 return emitError(result.getLoc()) 1035 << "Operand dimension size " << int64_t(operandScatterDimSize) 1036 << " is not divisible by collective device group size " 1037 << int64_t(deviceGroupSize) << " for tensor axis " << tensorAxis 1038 << "."; 1039 } 1040 DimensionSize expectedResultTensorDimSize = 1041 operandScatterDimSize / deviceGroupSize; 1042 if (failed(verifyDimensionCompatibility( 1043 result.getLoc(), expectedResultTensorDimSize.value(), 1044 resultType.getDimSize(tensorAxis), tensorAxis))) { 1045 return failure(); 1046 } 1047 1048 return success(); 1049 } 1050 1051 static RankedTensorType sliceResultType(Type operandType, MeshOp mesh, 1052 ArrayRef<MeshAxis> meshAxes, 1053 int64_t sliceAxis) { 1054 RankedTensorType operandRankedTensorType = 1055 cast<RankedTensorType>(operandType); 1056 DimensionSize operandSliceAxisSize = 1057 operandRankedTensorType.getShape()[sliceAxis]; 1058 SmallVector<int64_t> resultShape = 1059 llvm::to_vector(operandRankedTensorType.getShape()); 1060 1061 resultShape[sliceAxis] = 1062 operandSliceAxisSize / 1063 DimensionSize(collectiveProcessGroupSize(meshAxes, mesh)); 1064 return operandRankedTensorType.clone(resultShape); 1065 } 1066 1067 //===----------------------------------------------------------------------===// 1068 // mesh.all_gather op 1069 //===----------------------------------------------------------------------===// 1070 1071 LogicalResult 1072 AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1073 auto mesh = getMeshAndVerifyAxes(*this, symbolTable); 1074 if (failed(mesh)) { 1075 return failure(); 1076 } 1077 auto gatherAxis = getGatherAxis().getSExtValue(); 1078 return verifyGatherOperandAndResultShape(getOperand(), getResult(), 1079 gatherAxis, getMeshAxes(), 1080 mesh.value().getShape()); 1081 } 1082 1083 void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1084 MLIRContext *context) { 1085 patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context); 1086 } 1087 1088 void AllGatherOp::getAsmResultNames( 1089 function_ref<void(Value, StringRef)> setNameFn) { 1090 setNameFn(getResult(), "all_gather"); 1091 } 1092 1093 //===----------------------------------------------------------------------===// 1094 // mesh.all_reduce op 1095 //===----------------------------------------------------------------------===// 1096 1097 LogicalResult 1098 AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1099 return getMeshAndVerifyAxes(*this, symbolTable); 1100 } 1101 1102 void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1103 MLIRContext *context) { 1104 patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context); 1105 } 1106 1107 void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState, 1108 Value input, StringRef mesh, 1109 ArrayRef<MeshAxis> meshAxes, ReductionKind reduction) { 1110 build(odsBuilder, odsState, input.getType(), mesh, meshAxes, input, 1111 reduction); 1112 } 1113 1114 void AllReduceOp::getAsmResultNames( 1115 function_ref<void(Value, StringRef)> setNameFn) { 1116 setNameFn(getResult(), "all_reduce"); 1117 } 1118 1119 //===----------------------------------------------------------------------===// 1120 // mesh.all_slice op 1121 //===----------------------------------------------------------------------===// 1122 1123 LogicalResult AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1124 auto mesh = getMeshAndVerifyAxes(*this, symbolTable); 1125 if (failed(mesh)) { 1126 return failure(); 1127 } 1128 return verifyScatterOrSliceOperandAndResultShape( 1129 getOperand(), getResult(), getSliceAxis().getSExtValue(), getMeshAxes(), 1130 mesh.value().getShape()); 1131 } 1132 1133 void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1134 MLIRContext *context) { 1135 patterns.add<EmptyMeshAxesCanonicalizationPattern<AllSliceOp>>(context); 1136 } 1137 1138 void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState, 1139 Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes, 1140 int64_t sliceAxis) { 1141 Type resultType = sliceResultType(input.getType(), mesh, meshAxes, sliceAxis); 1142 build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes, 1143 sliceAxis); 1144 } 1145 1146 void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState, 1147 Type resultType, Value input, StringRef mesh, 1148 ArrayRef<MeshAxis> meshAxes, int64_t sliceAxis) { 1149 build(odsBuilder, odsState, resultType, mesh, meshAxes, input, 1150 APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis)); 1151 } 1152 1153 void AllSliceOp::getAsmResultNames( 1154 function_ref<void(Value, StringRef)> setNameFn) { 1155 setNameFn(getResult(), "all_slice"); 1156 } 1157 1158 //===----------------------------------------------------------------------===// 1159 // mesh.all_to_all op 1160 //===----------------------------------------------------------------------===// 1161 1162 LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1163 auto mesh = getMeshAndVerifyAxes(*this, symbolTable); 1164 if (failed(mesh)) { 1165 return failure(); 1166 } 1167 1168 return verifyAllToAllOperandAndResultShape( 1169 getOperand(), getResult(), getSplitAxis().getSExtValue(), 1170 getConcatAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape()); 1171 } 1172 1173 void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1174 MLIRContext *context) { 1175 patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context); 1176 } 1177 1178 void AllToAllOp::getAsmResultNames( 1179 function_ref<void(Value, StringRef)> setNameFn) { 1180 setNameFn(getResult(), "all_to_all"); 1181 } 1182 1183 //===----------------------------------------------------------------------===// 1184 // mesh.broadcast op 1185 //===----------------------------------------------------------------------===// 1186 1187 LogicalResult 1188 BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1189 auto mesh = getMeshAndVerifyAxes(*this, symbolTable); 1190 if (failed(mesh)) { 1191 return failure(); 1192 } 1193 if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(), 1194 getRootDynamic(), getMeshAxes(), 1195 mesh.value().getShape()))) { 1196 return failure(); 1197 } 1198 1199 return success(); 1200 } 1201 1202 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1203 MLIRContext *context) { 1204 patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context); 1205 } 1206 1207 void BroadcastOp::getAsmResultNames( 1208 function_ref<void(Value, StringRef)> setNameFn) { 1209 setNameFn(getResult(), "broadcast"); 1210 } 1211 1212 //===----------------------------------------------------------------------===// 1213 // mesh.gather op 1214 //===----------------------------------------------------------------------===// 1215 1216 LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1217 auto mesh = getMeshAndVerifyAxes(*this, symbolTable); 1218 if (failed(mesh)) { 1219 return failure(); 1220 } 1221 if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(), 1222 getRootDynamic(), getMeshAxes(), 1223 mesh.value().getShape()))) { 1224 return failure(); 1225 } 1226 1227 auto gatherAxis = getGatherAxis().getSExtValue(); 1228 return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis, 1229 getMeshAxes(), 1230 mesh.value().getShape()); 1231 } 1232 1233 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1234 MLIRContext *context) { 1235 patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context); 1236 } 1237 1238 void GatherOp::getAsmResultNames( 1239 function_ref<void(Value, StringRef)> setNameFn) { 1240 setNameFn(getResult(), "gather"); 1241 } 1242 1243 //===----------------------------------------------------------------------===// 1244 // mesh.recv op 1245 //===----------------------------------------------------------------------===// 1246 1247 LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1248 auto mesh = getMeshAndVerifyAxes(*this, symbolTable); 1249 if (failed(mesh)) { 1250 return failure(); 1251 } 1252 if (getSource() && 1253 failed(verifyInGroupDevice(getLoc(), getSourceAttrName(), 1254 getSource().value(), getSourceDynamic(), 1255 getMeshAxes(), mesh.value().getShape()))) { 1256 return failure(); 1257 } 1258 return success(); 1259 } 1260 1261 void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1262 MLIRContext *context) { 1263 patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context); 1264 } 1265 1266 void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) { 1267 setNameFn(getResult(), "recv"); 1268 } 1269 1270 //===----------------------------------------------------------------------===// 1271 // mesh.reduce op 1272 //===----------------------------------------------------------------------===// 1273 1274 LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1275 auto mesh = getMeshAndVerifyAxes(*this, symbolTable); 1276 if (failed(mesh)) { 1277 return failure(); 1278 } 1279 if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(), 1280 getRootDynamic(), getMeshAxes(), 1281 mesh.value().getShape()))) { 1282 return failure(); 1283 } 1284 1285 return success(); 1286 } 1287 1288 void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1289 MLIRContext *context) { 1290 patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context); 1291 } 1292 1293 void ReduceOp::getAsmResultNames( 1294 function_ref<void(Value, StringRef)> setNameFn) { 1295 setNameFn(getResult(), "reduce"); 1296 } 1297 1298 //===----------------------------------------------------------------------===// 1299 // mesh.reduce_scatter op 1300 //===----------------------------------------------------------------------===// 1301 1302 LogicalResult 1303 ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1304 auto mesh = getMeshAndVerifyAxes(*this, symbolTable); 1305 if (failed(mesh)) { 1306 return failure(); 1307 } 1308 1309 return verifyScatterOrSliceOperandAndResultShape( 1310 getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(), 1311 mesh.value().getShape()); 1312 } 1313 1314 void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1315 MLIRContext *context) { 1316 patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context); 1317 } 1318 1319 void ReduceScatterOp::getAsmResultNames( 1320 function_ref<void(Value, StringRef)> setNameFn) { 1321 setNameFn(getResult(), "reduce_scatter"); 1322 } 1323 1324 //===----------------------------------------------------------------------===// 1325 // mesh.scatter op 1326 //===----------------------------------------------------------------------===// 1327 1328 LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1329 auto mesh = getMeshAndVerifyAxes(*this, symbolTable); 1330 if (failed(mesh)) { 1331 return failure(); 1332 } 1333 if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(), 1334 getRootDynamic(), getMeshAxes(), 1335 mesh.value().getShape()))) { 1336 return failure(); 1337 } 1338 1339 auto scatterAxis = getScatterAxis().getSExtValue(); 1340 return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(), 1341 scatterAxis, getMeshAxes(), 1342 mesh.value().getShape()); 1343 } 1344 1345 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1346 MLIRContext *context) { 1347 patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context); 1348 } 1349 1350 void ScatterOp::getAsmResultNames( 1351 function_ref<void(Value, StringRef)> setNameFn) { 1352 setNameFn(getResult(), "scatter"); 1353 } 1354 1355 //===----------------------------------------------------------------------===// 1356 // mesh.send op 1357 //===----------------------------------------------------------------------===// 1358 1359 LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1360 auto mesh = getMeshAndVerifyAxes(*this, symbolTable); 1361 if (failed(mesh)) { 1362 return failure(); 1363 } 1364 if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(), 1365 getDestination(), getDestinationDynamic(), 1366 getMeshAxes(), mesh.value().getShape()))) { 1367 return failure(); 1368 } 1369 return success(); 1370 } 1371 1372 void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1373 MLIRContext *context) { 1374 patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context); 1375 } 1376 1377 void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) { 1378 setNameFn(getResult(), "send"); 1379 } 1380 1381 //===----------------------------------------------------------------------===// 1382 // mesh.shift op 1383 //===----------------------------------------------------------------------===// 1384 1385 LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1386 auto mesh = getMeshAndVerifyAxes(*this, symbolTable); 1387 if (failed(mesh)) { 1388 return failure(); 1389 } 1390 1391 auto meshAxes = getMeshAxes(); 1392 auto shiftAxis = getShiftAxis().getZExtValue(); 1393 if (llvm::find(meshAxes, shiftAxis) == meshAxes.end()) { 1394 return emitError() << "Invalid shift axis " << shiftAxis 1395 << ". It must be one of the grouping mesh axes."; 1396 } 1397 1398 return success(); 1399 } 1400 1401 void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1402 MLIRContext *context) { 1403 // TODO: remove op when offset is 0 or if it is a rotate with and 1404 // offset % shift_axis_mesh_dim_size == 0. 1405 } 1406 1407 void ShiftOp::getAsmResultNames( 1408 function_ref<void(Value, StringRef)> setNameFn) { 1409 setNameFn(getResult(), "shift"); 1410 } 1411 1412 //===----------------------------------------------------------------------===// 1413 // mesh.update_halo op 1414 //===----------------------------------------------------------------------===// 1415 1416 LogicalResult 1417 UpdateHaloOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1418 auto mesh = getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable); 1419 if (failed(mesh)) { 1420 return failure(); 1421 } 1422 1423 return success(); 1424 } 1425 1426 //===----------------------------------------------------------------------===// 1427 // TableGen'd op method definitions 1428 //===----------------------------------------------------------------------===// 1429 1430 #define GET_OP_CLASSES 1431 #include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc" 1432 1433 #define GET_ATTRDEF_CLASSES 1434 #include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc" 1435 1436 #define GET_TYPEDEF_CLASSES 1437 #include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc" 1438 1439 #include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc" 1440