xref: /llvm-project/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (revision 79eb406a67fe08458548289da72cda18248a9313)
1 //===- MeshOps.cpp - Mesh Dialect Operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
10 
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Dialect/Mesh/IR/MeshDialect.h"
13 #include "mlir/Dialect/Utils/StaticValueUtils.h"
14 #include "mlir/IR/Attributes.h"
15 #include "mlir/IR/BuiltinAttributes.h"
16 #include "mlir/IR/BuiltinTypeInterfaces.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/Diagnostics.h"
19 #include "mlir/IR/DialectImplementation.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/Location.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/IR/TypeUtilities.h"
24 #include "mlir/IR/Value.h"
25 #include "mlir/Interfaces/ViewLikeInterface.h"
26 #include "mlir/Support/LLVM.h"
27 #include "mlir/Transforms/InliningUtils.h"
28 #include "llvm/ADT/ArrayRef.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/SmallSet.h"
31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/Support/Casting.h"
34 #include <algorithm>
35 #include <functional>
36 #include <iterator>
37 #include <numeric>
38 #include <optional>
39 #include <utility>
40 
41 #define DEBUG_TYPE "mesh-ops"
42 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
43 
44 using namespace mlir;
45 using namespace mlir::mesh;
46 
47 #include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc"
48 
49 namespace {
50 
51 struct DimensionSize {
52   static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); }
53   DimensionSize(int64_t val) : val(val) {}
54   int64_t value() const { return val; }
55   operator int64_t() const { return val; }
56   bool isDynamic() const { return ShapedType::isDynamic(val); }
57 
58 private:
59   int64_t val;
60 };
61 
62 } // namespace
63 
64 static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
65   if (lhs.isDynamic() || rhs.isDynamic()) {
66     return DimensionSize::dynamic();
67   }
68   return lhs.value() / rhs.value();
69 }
70 
71 static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
72   if (lhs.isDynamic() || rhs.isDynamic()) {
73     return DimensionSize::dynamic();
74   }
75   return lhs.value() * rhs.value();
76 }
77 
78 //===----------------------------------------------------------------------===//
79 // Inliner
80 //===----------------------------------------------------------------------===//
81 
82 namespace {
83 struct MeshInlinerInterface : public DialectInlinerInterface {
84   using DialectInlinerInterface::DialectInlinerInterface;
85   // Currently no restrictions are encoded for inlining.
86   bool isLegalToInline(Operation *, Operation *, bool) const final {
87     return true;
88   }
89   bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final {
90     return true;
91   }
92   bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
93     return true;
94   }
95 };
96 } // namespace
97 
98 //===----------------------------------------------------------------------===//
99 // Mesh dialect
100 //===----------------------------------------------------------------------===//
101 
/// Registers the dialect's TableGen-generated operations, attributes and types,
/// and attaches the inliner interface for mesh ops.
void MeshDialect::initialize() {
  // Each GET_*_LIST include expands to the list of generated classes used as
  // the template arguments of the corresponding add* call.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
      >();
  addInterface<MeshInlinerInterface>();
}
117 
118 Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
119                                             Type type, Location loc) {
120   return arith::ConstantOp::materialize(builder, value, type, loc);
121 }
122 
123 //===----------------------------------------------------------------------===//
124 // Mesh utilities
125 //===----------------------------------------------------------------------===//
126 
127 static FailureOr<MeshOp> getMeshAndVerify(Operation *op,
128                                           FlatSymbolRefAttr meshSymbol,
129                                           SymbolTableCollection &symbolTable) {
130   mesh::MeshOp mesh = getMeshOrNull(op, meshSymbol, symbolTable);
131   if (!mesh) {
132     return op->emitError() << "Undefined required mesh symbol \""
133                            << meshSymbol.getValue() << "\".";
134   }
135 
136   return mesh;
137 }
138 
/// Returns true if no two *adjacent* elements of [begin, end) compare equal.
///
/// This means "all elements are distinct" only for sorted ranges, so callers
/// must sort first. The function is file-local (`static`) so that this
/// generic name does not get external linkage from a .cpp file.
template <typename It>
static bool isUnique(It begin, It end) {
  return std::adjacent_find(begin, end) == end;
}
155 
156 static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
157                                     MeshOp mesh) {
158   SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
159   llvm::sort(sorted);
160   if (!isUnique(sorted.begin(), sorted.end())) {
161     return emitError(loc) << "Mesh axes contains duplicate elements.";
162   }
163 
164   MeshAxis rank = mesh.getRank();
165   for (auto axis : axes) {
166     if (axis >= rank || axis < 0) {
167       return emitError(loc)
168              << "0-based mesh axis index " << axis
169              << " is out of bounds. The referenced mesh \"" << mesh.getSymName()
170              << "\" is of rank " << rank << ".";
171     }
172   }
173 
174   return success();
175 }
176 
177 template <typename Op>
178 static FailureOr<MeshOp>
179 getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
180   auto mesh =
181       ::getMeshAndVerify(op.getOperation(), op.getMeshAttr(), symbolTable);
182   if (failed(mesh)) {
183     return failure();
184   }
185   if (failed(verifyMeshAxes(op.getLoc(), op.getMeshAxes(), mesh.value()))) {
186     return failure();
187   }
188   return mesh;
189 }
190 
191 template <typename InShape, typename MeshShape, typename SplitAxes,
192           typename OutShape>
193 static void shardShape(const InShape &inShape, const MeshShape &meshShape,
194                        const SplitAxes &splitAxes, OutShape &outShape,
195                        ArrayRef<int64_t> shardedDimsOffsets = {},
196                        ArrayRef<int64_t> haloSizes = {}) {
197   std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
198             llvm::adl_begin(outShape));
199 
200   if (!shardedDimsOffsets.empty()) {
201     auto isDynShape = ShapedType::isDynamicShape(meshShape);
202     uint64_t pos = 1;
203     for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
204       if (!innerSplitAxes.empty()) {
205         auto sz = shardedDimsOffsets[pos];
206         bool same = !isDynShape;
207         if (same) {
208           // Find sharded dims in shardedDimsOffsets with same static size on
209           // all devices. Use kDynamic for dimensions with dynamic or
210           // non-uniform offs in shardedDimsOffsets.
211           uint64_t numShards = 0;
212           for (auto i : innerSplitAxes.asArrayRef()) {
213             numShards += meshShape[i];
214           }
215           for (size_t i = 1; i < numShards; ++i) {
216             if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] !=
217                 sz) {
218               same = false;
219               break;
220             }
221           }
222           pos += numShards + 1;
223         }
224         outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
225       }
226     }
227   } else {
228     for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
229       outShape[tensorAxis] = shardDimension(
230           inShape[tensorAxis],
231           collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape));
232     }
233 
234     if (!haloSizes.empty()) {
235       // add halo sizes if requested
236       int haloAxis = 0;
237       for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
238         if (!ShapedType::isDynamic(outShape[tensorAxis]) &&
239             !innerSplitAxes.empty()) {
240           if (haloSizes[haloAxis * 2] >= 0 &&
241               haloSizes[haloAxis * 2 + 1] >= 0) {
242             outShape[tensorAxis] +=
243                 haloSizes[haloAxis * 2] + haloSizes[haloAxis * 2 + 1];
244             ++haloAxis;
245           } else {
246             outShape[tensorAxis] = ShapedType::kDynamic;
247           }
248         }
249       }
250     }
251   }
252 }
253 
254 ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
255                                  MeshSharding sharding) {
256   using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
257   SmallVector<Dim> resShapeArr(shape.getShape().size());
258   shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
259              resShapeArr, sharding.getStaticShardedDimsOffsets(),
260              sharding.getStaticHaloSizes());
261   return shape.clone(resShapeArr);
262 }
263 
264 Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
265   RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
266   if (rankedTensorType) {
267     return shardShapedType(rankedTensorType, mesh, sharding);
268   }
269   return type;
270 }
271 
/// Annotates the value used by `operand` with `sharding` as its target
/// sharding.
///
/// A `mesh.sharding` and a `mesh.shard` (annotate_for_users = false) of the
/// value are created right after the value's definition. All uses of the value
/// by `operand`'s owner are then rerouted through the new `mesh.shard`.
void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
                                                     OpOperand &operand,
                                                     OpBuilder &builder) {
  OpBuilder::InsertionGuard insertionGuard(builder);
  Value operandValue = operand.get();
  Operation *operandOp = operand.getOwner();
  builder.setInsertionPointAfterValue(operandValue);
  // `operandOp` is the *user* of the value. If it already is a mesh.shard
  // with this sharding and annotate_for_users = false, nothing is needed.
  ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
  if (shardOp && sharding == shardOp.getSharding() &&
      !shardOp.getAnnotateForUsers()) {
    // No need for anything; the correct sharding is already set.
    return;
  }

  auto shardingOp = builder.create<ShardingOp>(operandValue.getLoc(), sharding);
  auto newShardOp =
      builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
                              /*annotate_for_users*/ false);
  IRRewriter rewriter(builder);
  // Only the owner's uses are rerouted; newShardOp's own use of the value is
  // left in place because its owner differs.
  rewriter.replaceUsesWithIf(
      operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
        return use.getOwner() == operandOp && use.get() == operandValue;
      });

  if (!shardOp || shardOp.getAnnotateForUsers()) {
    return;
  }

  // The user was a mesh.shard with annotate_for_users = false but a different
  // sharding. Add a second annotation (annotate_for_users = true, same new
  // sharding) on top of newShardOp. Every other user of newShardOp, i.e. the
  // original user, now consumes the second annotation instead.
  auto newShardOp2 =
      builder.create<ShardOp>(operandValue.getLoc(), newShardOp, shardingOp,
                              /*annotate_for_users*/ true);
  rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
}
305 
306 void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
307                                                      OpResult result,
308                                                      OpBuilder &builder) {
309   for (auto &use : llvm::make_early_inc_range(result.getUses())) {
310     maybeInsertTargetShardingAnnotation(sharding, use, builder);
311   }
312 }
313 
/// Annotates `operand` with `sharding` as the sharding its owner expects.
///
/// A `mesh.sharding` and a `mesh.shard` (annotate_for_users = true) of the
/// value are inserted immediately before the owning op, and the owner's uses
/// of the value are rerouted to the new `mesh.shard`. Nothing is inserted if
/// the value already comes from a `mesh.shard` with annotate_for_users = true
/// and the same sharding.
void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding,
                                                     OpOperand &operand,
                                                     OpBuilder &builder) {
  OpBuilder::InsertionGuard insertionGuard(builder);
  Value operandValue = operand.get();
  Operation *operandOp = operand.getOwner();
  // Block arguments have no defining op.
  Operation *operandSrcOp = operandValue.getDefiningOp();
  bool isBlockArg = !operandSrcOp;
  ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);

  if (shardOp && sharding == shardOp.getSharding() &&
      shardOp.getAnnotateForUsers()) {
    // No need for anything; the correct sharding is already set.
    return;
  }

  builder.setInsertionPoint(operandOp);
  auto shardingOp =
      builder.create<ShardingOp>(operand.get().getLoc(), sharding);
  auto newShardOp =
      builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
                              /*annotate_for_users*/ true);
  IRRewriter rewriter(builder);
  // Reroute only the owner's uses of the value.
  rewriter.replaceUsesWithIf(
      operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
        return use.getOwner() == operandOp && use.get() == operandValue;
      });

  if (isBlockArg || !shardOp || !shardOp.getAnnotateForUsers()) {
    // No need for resharding.
    return;
  }

  // The value came from a mesh.shard with annotate_for_users = true but a
  // different sharding. Place a mesh.shard with annotate_for_users = false
  // (using the new sharding) between the value and newShardOp, so that
  // newShardOp consumes it.
  builder.setInsertionPoint(newShardOp);
  auto newPreceedingShardOp =
      builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
                              /*annotate_for_users*/ false);
  rewriter.replaceUsesWithIf(
      newShardOp.getSrc(), newPreceedingShardOp, [&newShardOp](OpOperand &use) {
        return use.getOwner() == newShardOp.getOperation();
      });
}
356 
357 //===----------------------------------------------------------------------===//
358 // mesh.mesh op
359 //===----------------------------------------------------------------------===//
360 
361 LogicalResult MeshOp::verify() {
362   int64_t rank = getRank();
363 
364   if (rank <= 0)
365     return emitOpError("rank of mesh is expected to be a positive integer");
366 
367   for (int64_t dimSize : getShape()) {
368     if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
369       return emitOpError("dimension size of a mesh is expected to be "
370                          "non-negative or dynamic");
371   }
372 
373   return success();
374 }
375 
376 //===----------------------------------------------------------------------===//
377 // mesh.mesh_shape op
378 //===----------------------------------------------------------------------===//
379 
380 LogicalResult
381 MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
382   auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
383   if (failed(mesh)) {
384     return failure();
385   }
386   if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
387     return failure();
388   }
389 
390   size_t expectedResultsCount =
391       getAxes().empty() ? mesh->getRank() : getAxes().size();
392   if (getResult().size() != expectedResultsCount) {
393     return emitError() << "Unexpected number of results " << getResult().size()
394                        << ". Expected " << expectedResultsCount << ".";
395   }
396 
397   return success();
398 }
399 
400 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
401                         MeshOp mesh) {
402   build(odsBuilder, odsState, mesh, SmallVector<MeshAxis>());
403 }
404 
405 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
406                         MeshOp mesh, ArrayRef<MeshAxis> axes) {
407   build(odsBuilder, odsState,
408         SmallVector<Type>(axes.empty() ? mesh.getRank() : axes.size(),
409                           odsBuilder.getIndexType()),
410         mesh.getSymName(), MeshAxesAttr::get(odsBuilder.getContext(), axes));
411 }
412 
413 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
414                         StringRef mesh, ArrayRef<MeshAxis> axes) {
415   assert(!axes.empty());
416   build(odsBuilder, odsState,
417         SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
418         MeshAxesAttr::get(odsBuilder.getContext(), axes));
419 }
420 
421 void MeshShapeOp::getAsmResultNames(
422     function_ref<void(Value, StringRef)> setNameFn) {
423   setNameFn(getResults()[0], "mesh_shape");
424 }
425 
426 //===----------------------------------------------------------------------===//
427 // mesh.sharding
428 //===----------------------------------------------------------------------===//
429 
430 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
431                        FlatSymbolRefAttr mesh,
432                        ArrayRef<MeshAxesAttr> split_axes,
433                        ArrayRef<MeshAxis> partial_axes,
434                        mesh::ReductionKind partial_type,
435                        ArrayRef<int64_t> static_halo_sizes,
436                        ArrayRef<int64_t> static_sharded_dims_offsets) {
437   return build(
438       b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
439       ::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes),
440       ::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type),
441       ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halo_sizes), {},
442       ::mlir::DenseI64ArrayAttr::get(b.getContext(),
443                                      static_sharded_dims_offsets),
444       {});
445 }
446 
447 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
448                        FlatSymbolRefAttr mesh,
449                        ArrayRef<MeshAxesAttr> split_axes) {
450   return build(
451       b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
452       ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
453       {}, {}, {}, {});
454 }
455 
456 void ShardingOp::build(
457     ::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
458     FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes,
459     ::mlir::ArrayRef<::mlir::OpFoldResult> halo_sizes,
460     ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_offsets) {
461   mlir::SmallVector<int64_t> staticHalos, staticDims;
462   mlir::SmallVector<mlir::Value> dynamicHalos, dynamicDims;
463   dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos);
464   dispatchIndexOpFoldResults(sharded_dims_offsets, dynamicDims, staticDims);
465   return build(
466       b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
467       ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
468       ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), dynamicHalos,
469       ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticDims), dynamicDims);
470 }
471 
472 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
473                        mlir::mesh::MeshSharding from) {
474 
475   build(b, odsState, ShardingType::get(b.getContext()), from.getMeshAttr(),
476         MeshAxesArrayAttr::get(b.getContext(), from.getSplitAxes()),
477         from.getPartialAxes().empty()
478             ? DenseI16ArrayAttr()
479             : b.getDenseI16ArrayAttr(from.getPartialAxes()),
480         ::mlir::mesh::ReductionKindAttr::get(b.getContext(),
481                                              from.getPartialType()),
482         from.getStaticShardedDimsOffsets().empty()
483             ? DenseI64ArrayAttr()
484             : b.getDenseI64ArrayAttr(from.getStaticShardedDimsOffsets()),
485         from.getDynamicShardedDimsOffsets(),
486         from.getStaticHaloSizes().empty()
487             ? DenseI64ArrayAttr()
488             : b.getDenseI64ArrayAttr(from.getStaticHaloSizes()),
489         from.getDynamicHaloSizes());
490 }
491 
492 LogicalResult ShardingOp::verify() {
493   llvm::SmallSet<MeshAxis, 4> visitedAxes;
494 
495   auto checkMeshAxis = [&](ArrayRef<MeshAxis> axesArray) -> LogicalResult {
496     for (MeshAxis axis : axesArray) {
497       if (axis < 0)
498         return emitError() << "mesh axis is expected to be non-negative";
499       if (!visitedAxes.insert(axis).second)
500         return emitError() << "mesh axis duplicated";
501     }
502     return success();
503   };
504 
505   for (auto subAxes : getSplitAxes().getAxes()) {
506     ArrayRef<MeshAxis> subAxesArray = subAxes.asArrayRef();
507     if (failed(checkMeshAxis(subAxesArray)))
508       return failure();
509   }
510   if (getPartialAxes().has_value() &&
511       failed(checkMeshAxis(getPartialAxes().value())))
512     return failure();
513 
514   if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) {
515     return emitOpError("halo sizes and shard offsets are mutually exclusive");
516   }
517 
518   if (!getStaticHaloSizes().empty()) {
519     auto numSplitAxes = getSplitAxes().getAxes().size();
520     for (auto splitAxis : getSplitAxes().getAxes()) {
521       if (splitAxis.empty()) {
522         --numSplitAxes;
523       }
524     }
525     if (getStaticHaloSizes().size() != numSplitAxes * 2) {
526       return emitError() << "halo sizes must be specified for all split axes.";
527     }
528   }
529 
530   return success();
531 }
532 
533 void ShardingOp::getAsmResultNames(
534     function_ref<void(Value, StringRef)> setNameFn) {
535   setNameFn(getResult(), "sharding");
536 }
537 
538 LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
539   auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
540   if (failed(mesh)) {
541     return failure();
542   }
543   if (mlir::ShapedType::isDynamicShape(mesh->getShape()) &&
544       getStaticShardedDimsOffsets().size() > 0) {
545     return emitError() << "sharded dims offsets are not allowed for "
546                           "devices meshes with dynamic shape.";
547   }
548 
549   auto shardedDimsOffsets = getStaticShardedDimsOffsets();
550   if (!shardedDimsOffsets.empty()) {
551     auto meshShape = mesh.value().getShape();
552     assert(!ShapedType::isDynamicShape(meshShape));
553     uint64_t pos = 0;
554     for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(getSplitAxes())) {
555       if (!innerSplitAxes.empty()) {
556         int64_t numShards = 0, off = 0;
557         for (auto i : innerSplitAxes.asArrayRef()) {
558           numShards += meshShape[i];
559         }
560         for (int64_t i = 0; i <= numShards; ++i) {
561           if (shardedDimsOffsets.size() <= pos + i) {
562             return emitError() << "sharded dims offsets has wrong size.";
563           }
564           if (!ShapedType::isDynamic(shardedDimsOffsets[pos + i])) {
565             if (shardedDimsOffsets[pos + i] < off) {
566               return emitError()
567                      << "sharded dims offsets must be non-decreasing.";
568             }
569             off = shardedDimsOffsets[pos + i];
570           }
571         }
572         pos += numShards + 1;
573       }
574     }
575   }
576   return success();
577 }
578 
579 namespace {
580 // Sharding annotations "halo sizes" and "sharded dims offsets"
581 // are a mix of attributes and dynamic values. This canonicalization moves
582 // constant values to the respective attribute lists and so minimizes the number
583 // of values.
584 class FoldDynamicLists final : public OpRewritePattern<ShardingOp> {
585 public:
586   using OpRewritePattern<ShardingOp>::OpRewritePattern;
587 
588   LogicalResult matchAndRewrite(ShardingOp op,
589                                 PatternRewriter &b) const override {
590     auto mixedHalos =
591         getMixedValues(op.getStaticHaloSizes(), op.getDynamicHaloSizes(), b);
592     auto mixedOffs = getMixedValues(op.getStaticShardedDimsOffsets(),
593                                     op.getDynamicShardedDimsOffsets(), b);
594 
595     // No constant operands were folded, just return;
596     if (failed(foldDynamicIndexList(mixedHalos, /*onlyNonNegative=*/true)) &&
597         failed(foldDynamicIndexList(mixedOffs, /*onlyNonNegative=*/true))) {
598       return failure();
599     }
600 
601     auto halos = decomposeMixedValues(mixedHalos);
602     auto offs = decomposeMixedValues(mixedOffs);
603 
604     op.setStaticHaloSizes(halos.first);
605     op.getDynamicHaloSizesMutable().assign(halos.second);
606     op.setStaticShardedDimsOffsets(offs.first);
607     op.getDynamicShardedDimsOffsetsMutable().assign(offs.second);
608 
609     return success();
610   }
611 };
612 } // namespace
613 
614 void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
615                                              mlir::MLIRContext *context) {
616   results.add<FoldDynamicLists>(context);
617 }
618 
619 //===----------------------------------------------------------------------===//
620 // MeshSharding
621 //===----------------------------------------------------------------------===//
622 
623 bool MeshSharding::equalSplitAndPartialAxes(const MeshSharding &rhs) const {
624   if (getMesh() != rhs.getMesh()) {
625     return false;
626   }
627 
628   if (getPartialAxes().size() != rhs.getPartialAxes().size() ||
629       (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) ||
630       !llvm::equal(
631           llvm::make_range(getPartialAxes().begin(), getPartialAxes().end()),
632           llvm::make_range(rhs.getPartialAxes().begin(),
633                            rhs.getPartialAxes().end()))) {
634     return false;
635   }
636 
637   auto minSize = std::min(getSplitAxes().size(), rhs.getSplitAxes().size());
638   if (!llvm::equal(llvm::make_range(getSplitAxes().begin(),
639                                     getSplitAxes().begin() + minSize),
640                    llvm::make_range(rhs.getSplitAxes().begin(),
641                                     rhs.getSplitAxes().begin() + minSize))) {
642     return false;
643   }
644 
645   return llvm::all_of(llvm::make_range(getSplitAxes().begin() + minSize,
646                                        getSplitAxes().end()),
647                       std::mem_fn(&MeshAxesAttr::empty)) &&
648          llvm::all_of(llvm::make_range(rhs.getSplitAxes().begin() + minSize,
649                                        rhs.getSplitAxes().end()),
650                       std::mem_fn(&MeshAxesAttr::empty));
651 }
652 
653 bool MeshSharding::equalHaloAndShardSizes(const MeshSharding &rhs) const {
654   return equalShardSizes(rhs) && equalHaloSizes(rhs);
655 }
656 
657 bool MeshSharding::equalShardSizes(const MeshSharding &rhs) const {
658   if (rhs.getStaticShardedDimsOffsets().size() !=
659           getStaticShardedDimsOffsets().size() ||
660       !llvm::equal(llvm::make_range(getStaticShardedDimsOffsets().begin(),
661                                     getStaticShardedDimsOffsets().end()),
662                    llvm::make_range(rhs.getStaticShardedDimsOffsets().begin(),
663                                     rhs.getStaticShardedDimsOffsets().end()))) {
664     return false;
665   }
666   if (rhs.getDynamicShardedDimsOffsets().size() !=
667           getDynamicShardedDimsOffsets().size() ||
668       !llvm::equal(
669           llvm::make_range(getDynamicShardedDimsOffsets().begin(),
670                            getDynamicShardedDimsOffsets().end()),
671           llvm::make_range(rhs.getDynamicShardedDimsOffsets().begin(),
672                            rhs.getDynamicShardedDimsOffsets().end()))) {
673     return false;
674   }
675   return true;
676 }
677 
678 bool MeshSharding::equalHaloSizes(const MeshSharding &rhs) const {
679   if (rhs.getStaticHaloSizes().size() != getStaticHaloSizes().size() ||
680       !llvm::equal(llvm::make_range(getStaticHaloSizes().begin(),
681                                     getStaticHaloSizes().end()),
682                    llvm::make_range(rhs.getStaticHaloSizes().begin(),
683                                     rhs.getStaticHaloSizes().end()))) {
684     return false;
685   }
686   if (rhs.getDynamicHaloSizes().size() != getDynamicHaloSizes().size() ||
687       !llvm::equal(llvm::make_range(getDynamicHaloSizes().begin(),
688                                     getDynamicHaloSizes().end()),
689                    llvm::make_range(rhs.getDynamicHaloSizes().begin(),
690                                     rhs.getDynamicHaloSizes().end()))) {
691     return false;
692   }
693   return true;
694 }
695 
696 bool MeshSharding::operator==(Value rhs) const {
697   return equalSplitAndPartialAxes(rhs) && equalHaloAndShardSizes(rhs);
698 }
699 
700 bool MeshSharding::operator!=(Value rhs) const { return !(*this == rhs); }
701 
702 bool MeshSharding::operator==(const MeshSharding &rhs) const {
703   return equalSplitAndPartialAxes(rhs) && equalHaloAndShardSizes(rhs);
704 }
705 
706 bool MeshSharding::operator!=(const MeshSharding &rhs) const {
707   return !(*this == rhs);
708 }
709 
710 MeshSharding::MeshSharding(Value rhs) {
711   auto shardingOp = mlir::dyn_cast<ShardingOp>(rhs.getDefiningOp());
712   assert(shardingOp && "expected sharding op");
713   *this = get(shardingOp.getMeshAttr(), shardingOp.getSplitAxes().getAxes(),
714               shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>()),
715               shardingOp.getPartialType().value_or(ReductionKind::Sum),
716               shardingOp.getStaticHaloSizes(),
717               shardingOp.getStaticShardedDimsOffsets(),
718               SmallVector<Value>(shardingOp.getDynamicHaloSizes()),
719               SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets()));
720 }
721 
722 MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
723                                ArrayRef<MeshAxesAttr> split_axes_,
724                                ArrayRef<MeshAxis> partial_axes_,
725                                ReductionKind partial_type_,
726                                ArrayRef<int64_t> static_halo_sizes_,
727                                ArrayRef<int64_t> static_sharded_dims_offsets_,
728                                ArrayRef<Value> dynamic_halo_sizes_,
729                                ArrayRef<Value> dynamic_sharded_dims_offsets_) {
730   MeshSharding res;
731   res.mesh = mesh_;
732   res.split_axes.resize(split_axes_.size());
733   for (auto [i, axis] : llvm::enumerate(split_axes_)) {
734     res.split_axes[i] =
735         MeshAxesAttr::get(mesh_.getContext(), axis.asArrayRef());
736   }
737 
738   auto clone = [](const auto src, auto &dst) {
739     dst.resize(src.size());
740     llvm::copy(src, dst.begin());
741   };
742 
743   clone(partial_axes_, res.partial_axes);
744   res.partial_type = partial_type_;
745   clone(static_halo_sizes_, res.static_halo_sizes);
746   clone(static_sharded_dims_offsets_, res.static_sharded_dims_offsets);
747   clone(dynamic_halo_sizes_, res.dynamic_halo_sizes);
748   clone(dynamic_sharded_dims_offsets_, res.dynamic_sharded_dims_offsets);
749 
750   return res;
751 }
752 
753 //===----------------------------------------------------------------------===//
754 // mesh.shard_shape
755 //===----------------------------------------------------------------------===//
756 
757 void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder,
758                          ::mlir::OperationState &odsState,
759                          ::llvm::ArrayRef<int64_t> shape,
760                          ::mlir::Value sharding, ::mlir::Value device) {
761   SmallVector<mlir::Type> resType(shape.size(), odsBuilder.getIndexType());
762   build(odsBuilder, odsState, resType, shape, sharding, device);
763 }
764 
765 //===----------------------------------------------------------------------===//
766 // mesh.shard op
767 //===----------------------------------------------------------------------===//
768 
769 void ShardOp::getAsmResultNames(
770     function_ref<void(Value, StringRef)> setNameFn) {
771   setNameFn(getResult(), "sharding_annotated");
772 }
773 
774 //===----------------------------------------------------------------------===//
775 // mesh.process_multi_index op
776 //===----------------------------------------------------------------------===//
777 
778 LogicalResult
779 ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
780   auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
781   if (failed(mesh)) {
782     return failure();
783   }
784   if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
785     return failure();
786   }
787 
788   size_t expectedResultsCount =
789       getAxes().empty() ? mesh->getRank() : getAxes().size();
790   if (getResult().size() != expectedResultsCount) {
791     return emitError() << "Unexpected number of results " << getResult().size()
792                        << ". Expected " << expectedResultsCount << ".";
793   }
794 
795   return success();
796 }
797 
798 void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
799                                 MeshOp mesh) {
800   build(odsBuilder, odsState,
801         SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
802         mesh.getSymName(), ArrayRef<MeshAxis>());
803 }
804 
805 void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
806                                 StringRef mesh, ArrayRef<MeshAxis> axes) {
807   build(odsBuilder, odsState,
808         SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
809         MeshAxesAttr::get(odsBuilder.getContext(), axes));
810 }
811 
812 void ProcessMultiIndexOp::getAsmResultNames(
813     function_ref<void(Value, StringRef)> setNameFn) {
814   setNameFn(getResults()[0], "proc_linear_idx");
815 }
816 
817 //===----------------------------------------------------------------------===//
818 // mesh.process_linear_index op
819 //===----------------------------------------------------------------------===//
820 
821 LogicalResult
822 ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
823   auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
824   if (failed(mesh)) {
825     return failure();
826   }
827   return success();
828 }
829 
830 void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
831                                  OperationState &odsState, MeshOp mesh) {
832   build(odsBuilder, odsState, mesh.getSymName());
833 }
834 
835 void ProcessLinearIndexOp::getAsmResultNames(
836     function_ref<void(Value, StringRef)> setNameFn) {
837   setNameFn(getResult(), "proc_linear_idx");
838 }
839 
840 //===----------------------------------------------------------------------===//
841 // mesh.neighbors_linear_indices op
842 //===----------------------------------------------------------------------===//
843 
844 LogicalResult
845 NeighborsLinearIndicesOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
846   auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
847   if (failed(mesh)) {
848     return failure();
849   }
850   return success();
851 }
852 
853 void NeighborsLinearIndicesOp::getAsmResultNames(
854     function_ref<void(Value, StringRef)> setNameFn) {
855   setNameFn(getNeighborDown(), "down_linear_idx");
856   setNameFn(getNeighborUp(), "up_linear_idx");
857 }
858 
859 //===----------------------------------------------------------------------===//
860 // collective communication ops
861 //===----------------------------------------------------------------------===//
862 
863 namespace {
864 
865 template <typename Op>
866 struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
867   using OpRewritePattern<Op>::OpRewritePattern;
868   LogicalResult matchAndRewrite(Op op,
869                                 PatternRewriter &rewriter) const override {
870     auto meshAxes = op.getMeshAxes();
871     if (!meshAxes.empty()) {
872       return failure();
873     }
874     if (op.getInput().getType() != op.getResult().getType()) {
875       return failure();
876     }
877 
878     rewriter.replaceAllUsesWith(op.getResult(), op.getInput());
879     rewriter.eraseOp(op.getOperation());
880     return success();
881   }
882 };
883 
884 } // namespace
885 
886 static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
887                                          ArrayRef<int64_t> device,
888                                          Operation::operand_range deviceDynamic,
889                                          ArrayRef<MeshAxis> meshAxes,
890                                          ArrayRef<int64_t> meshShape) {
891   if (device.size() != meshAxes.size()) {
892     return emitError(loc) << "In-group device \"" << deviceName
893                           << "\" has unexpected multi-index size "
894                           << device.size() << ". Expected " << meshAxes.size()
895                           << ".";
896   }
897 
898   for (size_t i = 0; i < device.size(); ++i) {
899     if (!ShapedType::isDynamic(device[i]) &&
900         !ShapedType::isDynamic(meshShape[meshAxes[i]]) &&
901         meshShape[meshAxes[i]] <= device[i]) {
902       return emitError(loc)
903              << "Out of bounds coordinate " << i << " for in-group device \""
904              << deviceName << "\"."
905              << " Got " << device[i] << ", but expected value in the range [0, "
906              << (meshShape[meshAxes[i]] - 1) << "].";
907     }
908   }
909   return success();
910 }
911 
912 template <typename It>
913 static auto product(It begin, It end) {
914   using ElementType = std::decay_t<decltype(*begin)>;
915   return std::accumulate(begin, end, static_cast<ElementType>(1),
916                          std::multiplies<ElementType>());
917 }
918 
919 template <typename R>
920 static auto product(R &&range) {
921   return product(adl_begin(range), adl_end(range));
922 }
923 
924 static LogicalResult verifyDimensionCompatibility(Location loc,
925                                                   int64_t expectedDimSize,
926                                                   int64_t resultDimSize,
927                                                   int64_t resultAxis) {
928   if (!ShapedType::isDynamic(resultDimSize) &&
929       expectedDimSize != resultDimSize) {
930     return emitError(loc) << "Dimension size mismatch for result axis "
931                           << resultAxis << ". Expected "
932                           << (ShapedType::isDynamic(expectedDimSize)
933                                   ? Twine("dynamic")
934                                   : Twine(expectedDimSize))
935                           << ", but got " << resultDimSize << ".";
936   }
937 
938   return success();
939 }
940 
941 static LogicalResult verifyGatherOperandAndResultShape(
942     Value operand, Value result, int64_t gatherAxis,
943     ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
944   auto resultRank = cast<ShapedType>(result.getType()).getRank();
945   if (gatherAxis < 0 || gatherAxis >= resultRank) {
946     return emitError(result.getLoc())
947            << "Gather axis " << gatherAxis << " is out of bounds [0, "
948            << resultRank << ").";
949   }
950 
951   ShapedType operandType = cast<ShapedType>(operand.getType());
952   ShapedType resultType = cast<ShapedType>(result.getType());
953   auto deviceGroupSize =
954       DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
955   for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
956     auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
957     auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
958     auto expectedResultDimSize =
959         axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
960     if (failed(verifyDimensionCompatibility(
961             result.getLoc(), expectedResultDimSize, resultDimSize, axis))) {
962       return failure();
963     }
964   }
965   return success();
966 }
967 
968 static LogicalResult verifyAllToAllOperandAndResultShape(
969     Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
970     ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
971   ShapedType operandType = cast<ShapedType>(operand.getType());
972   ShapedType resultType = cast<ShapedType>(result.getType());
973   for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
974     if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
975       if (failed(verifyDimensionCompatibility(
976               result.getLoc(), operandType.getDimSize(axis),
977               resultType.getDimSize(axis), axis))) {
978         return failure();
979       }
980     }
981   }
982 
983   if (splitAxis == concatAxis) {
984     return success();
985   }
986 
987   auto deviceGroupSize =
988       DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
989   auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
990   auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
991   DimensionSize expectedResultConcatDimSize =
992       operandConcatDimSize * deviceGroupSize;
993   DimensionSize expectedResultSplitDimSize =
994       operandSplitDimSize / deviceGroupSize;
995   if (!expectedResultSplitDimSize.isDynamic() &&
996       int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
997     expectedResultSplitDimSize = DimensionSize::dynamic();
998   }
999   if (failed(verifyDimensionCompatibility(
1000           result.getLoc(), expectedResultConcatDimSize.value(),
1001           resultType.getDimSize(concatAxis), concatAxis))) {
1002     return failure();
1003   }
1004   if (failed(verifyDimensionCompatibility(
1005           result.getLoc(), expectedResultSplitDimSize.value(),
1006           resultType.getDimSize(splitAxis), splitAxis))) {
1007     return failure();
1008   }
1009 
1010   return success();
1011 }
1012 
1013 static LogicalResult verifyScatterOrSliceOperandAndResultShape(
1014     Value operand, Value result, int64_t tensorAxis,
1015     ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
1016   ShapedType operandType = cast<ShapedType>(operand.getType());
1017   ShapedType resultType = cast<ShapedType>(result.getType());
1018   for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1019     if (axis != tensorAxis) {
1020       if (failed(verifyDimensionCompatibility(
1021               result.getLoc(), operandType.getDimSize(axis),
1022               resultType.getDimSize(axis), axis))) {
1023         return failure();
1024       }
1025     }
1026   }
1027 
1028   auto deviceGroupSize =
1029       DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
1030   auto operandScatterDimSize =
1031       DimensionSize(operandType.getDimSize(tensorAxis));
1032   if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
1033       int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
1034     return emitError(result.getLoc())
1035            << "Operand dimension size " << int64_t(operandScatterDimSize)
1036            << " is not divisible by collective device group size "
1037            << int64_t(deviceGroupSize) << " for tensor axis " << tensorAxis
1038            << ".";
1039   }
1040   DimensionSize expectedResultTensorDimSize =
1041       operandScatterDimSize / deviceGroupSize;
1042   if (failed(verifyDimensionCompatibility(
1043           result.getLoc(), expectedResultTensorDimSize.value(),
1044           resultType.getDimSize(tensorAxis), tensorAxis))) {
1045     return failure();
1046   }
1047 
1048   return success();
1049 }
1050 
1051 static RankedTensorType sliceResultType(Type operandType, MeshOp mesh,
1052                                         ArrayRef<MeshAxis> meshAxes,
1053                                         int64_t sliceAxis) {
1054   RankedTensorType operandRankedTensorType =
1055       cast<RankedTensorType>(operandType);
1056   DimensionSize operandSliceAxisSize =
1057       operandRankedTensorType.getShape()[sliceAxis];
1058   SmallVector<int64_t> resultShape =
1059       llvm::to_vector(operandRankedTensorType.getShape());
1060 
1061   resultShape[sliceAxis] =
1062       operandSliceAxisSize /
1063       DimensionSize(collectiveProcessGroupSize(meshAxes, mesh));
1064   return operandRankedTensorType.clone(resultShape);
1065 }
1066 
1067 //===----------------------------------------------------------------------===//
1068 // mesh.all_gather op
1069 //===----------------------------------------------------------------------===//
1070 
1071 LogicalResult
1072 AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1073   auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1074   if (failed(mesh)) {
1075     return failure();
1076   }
1077   auto gatherAxis = getGatherAxis().getSExtValue();
1078   return verifyGatherOperandAndResultShape(getOperand(), getResult(),
1079                                            gatherAxis, getMeshAxes(),
1080                                            mesh.value().getShape());
1081 }
1082 
1083 void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1084                                               MLIRContext *context) {
1085   patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
1086 }
1087 
1088 void AllGatherOp::getAsmResultNames(
1089     function_ref<void(Value, StringRef)> setNameFn) {
1090   setNameFn(getResult(), "all_gather");
1091 }
1092 
1093 //===----------------------------------------------------------------------===//
1094 // mesh.all_reduce op
1095 //===----------------------------------------------------------------------===//
1096 
1097 LogicalResult
1098 AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1099   return getMeshAndVerifyAxes(*this, symbolTable);
1100 }
1101 
1102 void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1103                                               MLIRContext *context) {
1104   patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
1105 }
1106 
1107 void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1108                         Value input, StringRef mesh,
1109                         ArrayRef<MeshAxis> meshAxes, ReductionKind reduction) {
1110   build(odsBuilder, odsState, input.getType(), mesh, meshAxes, input,
1111         reduction);
1112 }
1113 
1114 void AllReduceOp::getAsmResultNames(
1115     function_ref<void(Value, StringRef)> setNameFn) {
1116   setNameFn(getResult(), "all_reduce");
1117 }
1118 
1119 //===----------------------------------------------------------------------===//
1120 // mesh.all_slice op
1121 //===----------------------------------------------------------------------===//
1122 
1123 LogicalResult AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1124   auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1125   if (failed(mesh)) {
1126     return failure();
1127   }
1128   return verifyScatterOrSliceOperandAndResultShape(
1129       getOperand(), getResult(), getSliceAxis().getSExtValue(), getMeshAxes(),
1130       mesh.value().getShape());
1131 }
1132 
1133 void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1134                                              MLIRContext *context) {
1135   patterns.add<EmptyMeshAxesCanonicalizationPattern<AllSliceOp>>(context);
1136 }
1137 
1138 void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1139                        Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
1140                        int64_t sliceAxis) {
1141   Type resultType = sliceResultType(input.getType(), mesh, meshAxes, sliceAxis);
1142   build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
1143         sliceAxis);
1144 }
1145 
1146 void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1147                        Type resultType, Value input, StringRef mesh,
1148                        ArrayRef<MeshAxis> meshAxes, int64_t sliceAxis) {
1149   build(odsBuilder, odsState, resultType, mesh, meshAxes, input,
1150         APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
1151 }
1152 
1153 void AllSliceOp::getAsmResultNames(
1154     function_ref<void(Value, StringRef)> setNameFn) {
1155   setNameFn(getResult(), "all_slice");
1156 }
1157 
1158 //===----------------------------------------------------------------------===//
1159 // mesh.all_to_all op
1160 //===----------------------------------------------------------------------===//
1161 
1162 LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1163   auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1164   if (failed(mesh)) {
1165     return failure();
1166   }
1167 
1168   return verifyAllToAllOperandAndResultShape(
1169       getOperand(), getResult(), getSplitAxis().getSExtValue(),
1170       getConcatAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape());
1171 }
1172 
1173 void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1174                                              MLIRContext *context) {
1175   patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
1176 }
1177 
1178 void AllToAllOp::getAsmResultNames(
1179     function_ref<void(Value, StringRef)> setNameFn) {
1180   setNameFn(getResult(), "all_to_all");
1181 }
1182 
1183 //===----------------------------------------------------------------------===//
1184 // mesh.broadcast op
1185 //===----------------------------------------------------------------------===//
1186 
1187 LogicalResult
1188 BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1189   auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1190   if (failed(mesh)) {
1191     return failure();
1192   }
1193   if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1194                                  getRootDynamic(), getMeshAxes(),
1195                                  mesh.value().getShape()))) {
1196     return failure();
1197   }
1198 
1199   return success();
1200 }
1201 
1202 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1203                                               MLIRContext *context) {
1204   patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
1205 }
1206 
1207 void BroadcastOp::getAsmResultNames(
1208     function_ref<void(Value, StringRef)> setNameFn) {
1209   setNameFn(getResult(), "broadcast");
1210 }
1211 
1212 //===----------------------------------------------------------------------===//
1213 // mesh.gather op
1214 //===----------------------------------------------------------------------===//
1215 
1216 LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1217   auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1218   if (failed(mesh)) {
1219     return failure();
1220   }
1221   if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1222                                  getRootDynamic(), getMeshAxes(),
1223                                  mesh.value().getShape()))) {
1224     return failure();
1225   }
1226 
1227   auto gatherAxis = getGatherAxis().getSExtValue();
1228   return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
1229                                            getMeshAxes(),
1230                                            mesh.value().getShape());
1231 }
1232 
1233 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1234                                            MLIRContext *context) {
1235   patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
1236 }
1237 
1238 void GatherOp::getAsmResultNames(
1239     function_ref<void(Value, StringRef)> setNameFn) {
1240   setNameFn(getResult(), "gather");
1241 }
1242 
1243 //===----------------------------------------------------------------------===//
1244 // mesh.recv op
1245 //===----------------------------------------------------------------------===//
1246 
1247 LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1248   auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1249   if (failed(mesh)) {
1250     return failure();
1251   }
1252   if (getSource() &&
1253       failed(verifyInGroupDevice(getLoc(), getSourceAttrName(),
1254                                  getSource().value(), getSourceDynamic(),
1255                                  getMeshAxes(), mesh.value().getShape()))) {
1256     return failure();
1257   }
1258   return success();
1259 }
1260 
1261 void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1262                                          MLIRContext *context) {
1263   patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
1264 }
1265 
1266 void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1267   setNameFn(getResult(), "recv");
1268 }
1269 
1270 //===----------------------------------------------------------------------===//
1271 // mesh.reduce op
1272 //===----------------------------------------------------------------------===//
1273 
1274 LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1275   auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1276   if (failed(mesh)) {
1277     return failure();
1278   }
1279   if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1280                                  getRootDynamic(), getMeshAxes(),
1281                                  mesh.value().getShape()))) {
1282     return failure();
1283   }
1284 
1285   return success();
1286 }
1287 
1288 void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1289                                            MLIRContext *context) {
1290   patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
1291 }
1292 
1293 void ReduceOp::getAsmResultNames(
1294     function_ref<void(Value, StringRef)> setNameFn) {
1295   setNameFn(getResult(), "reduce");
1296 }
1297 
1298 //===----------------------------------------------------------------------===//
1299 // mesh.reduce_scatter op
1300 //===----------------------------------------------------------------------===//
1301 
1302 LogicalResult
1303 ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1304   auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1305   if (failed(mesh)) {
1306     return failure();
1307   }
1308 
1309   return verifyScatterOrSliceOperandAndResultShape(
1310       getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
1311       mesh.value().getShape());
1312 }
1313 
1314 void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1315                                                   MLIRContext *context) {
1316   patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
1317 }
1318 
1319 void ReduceScatterOp::getAsmResultNames(
1320     function_ref<void(Value, StringRef)> setNameFn) {
1321   setNameFn(getResult(), "reduce_scatter");
1322 }
1323 
1324 //===----------------------------------------------------------------------===//
1325 // mesh.scatter op
1326 //===----------------------------------------------------------------------===//
1327 
1328 LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1329   auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1330   if (failed(mesh)) {
1331     return failure();
1332   }
1333   if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1334                                  getRootDynamic(), getMeshAxes(),
1335                                  mesh.value().getShape()))) {
1336     return failure();
1337   }
1338 
1339   auto scatterAxis = getScatterAxis().getSExtValue();
1340   return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(),
1341                                                    scatterAxis, getMeshAxes(),
1342                                                    mesh.value().getShape());
1343 }
1344 
1345 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1346                                             MLIRContext *context) {
1347   patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
1348 }
1349 
1350 void ScatterOp::getAsmResultNames(
1351     function_ref<void(Value, StringRef)> setNameFn) {
1352   setNameFn(getResult(), "scatter");
1353 }
1354 
1355 //===----------------------------------------------------------------------===//
1356 // mesh.send op
1357 //===----------------------------------------------------------------------===//
1358 
1359 LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1360   auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1361   if (failed(mesh)) {
1362     return failure();
1363   }
1364   if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
1365                                  getDestination(), getDestinationDynamic(),
1366                                  getMeshAxes(), mesh.value().getShape()))) {
1367     return failure();
1368   }
1369   return success();
1370 }
1371 
1372 void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1373                                          MLIRContext *context) {
1374   patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
1375 }
1376 
1377 void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1378   setNameFn(getResult(), "send");
1379 }
1380 
1381 //===----------------------------------------------------------------------===//
1382 // mesh.shift op
1383 //===----------------------------------------------------------------------===//
1384 
1385 LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1386   auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1387   if (failed(mesh)) {
1388     return failure();
1389   }
1390 
1391   auto meshAxes = getMeshAxes();
1392   auto shiftAxis = getShiftAxis().getZExtValue();
1393   if (llvm::find(meshAxes, shiftAxis) == meshAxes.end()) {
1394     return emitError() << "Invalid shift axis " << shiftAxis
1395                        << ". It must be one of the grouping mesh axes.";
1396   }
1397 
1398   return success();
1399 }
1400 
1401 void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1402                                           MLIRContext *context) {
1403   // TODO: remove op when offset is 0 or if it is a rotate with and
1404   // offset % shift_axis_mesh_dim_size == 0.
1405 }
1406 
1407 void ShiftOp::getAsmResultNames(
1408     function_ref<void(Value, StringRef)> setNameFn) {
1409   setNameFn(getResult(), "shift");
1410 }
1411 
1412 //===----------------------------------------------------------------------===//
1413 // mesh.update_halo op
1414 //===----------------------------------------------------------------------===//
1415 
1416 LogicalResult
1417 UpdateHaloOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1418   auto mesh = getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
1419   if (failed(mesh)) {
1420     return failure();
1421   }
1422 
1423   return success();
1424 }
1425 
1426 //===----------------------------------------------------------------------===//
1427 // TableGen'd op method definitions
1428 //===----------------------------------------------------------------------===//
1429 
1430 #define GET_OP_CLASSES
1431 #include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
1432 
1433 #define GET_ATTRDEF_CLASSES
1434 #include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
1435 
1436 #define GET_TYPEDEF_CLASSES
1437 #include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
1438 
1439 #include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc"
1440