1b0d5b4d2SChengji Yao //===- ShardingPropagation.cpp ------------------------------------- C++ --===// 2b0d5b4d2SChengji Yao // 3b0d5b4d2SChengji Yao // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4b0d5b4d2SChengji Yao // See https://llvm.org/LICENSE.txt for license information. 5b0d5b4d2SChengji Yao // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6b0d5b4d2SChengji Yao // 7b0d5b4d2SChengji Yao //===----------------------------------------------------------------------===// 8b0d5b4d2SChengji Yao 9b0d5b4d2SChengji Yao #include "mlir/Dialect/Mesh/Transforms/Passes.h" 10b0d5b4d2SChengji Yao 11b0d5b4d2SChengji Yao #include "mlir/Dialect/Func/IR/FuncOps.h" 1231fc0a12SBoian Petkantchin #include "mlir/Dialect/Mesh/IR/MeshDialect.h" 13b0d5b4d2SChengji Yao #include "mlir/Dialect/Mesh/IR/MeshOps.h" 14b0d5b4d2SChengji Yao #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" 15d635b860SBoian Petkantchin #include "mlir/IR/Verifier.h" 16abfac563SBoian Petkantchin #include "mlir/Interfaces/FunctionInterfaces.h" 17b0d5b4d2SChengji Yao #include "mlir/Pass/Pass.h" 18d635b860SBoian Petkantchin #include "llvm/ADT/STLExtras.h" 19d635b860SBoian Petkantchin #include "llvm/ADT/SmallVector.h" 20d635b860SBoian Petkantchin #include "llvm/ADT/iterator_range.h" 21b0d5b4d2SChengji Yao #include "llvm/Support/Debug.h" 22d635b860SBoian Petkantchin #include "llvm/Support/raw_ostream.h" 23d635b860SBoian Petkantchin #include <algorithm> 24b0d5b4d2SChengji Yao #include <vector> 25b0d5b4d2SChengji Yao 26b0d5b4d2SChengji Yao namespace mlir { 27b0d5b4d2SChengji Yao namespace mesh { 28b0d5b4d2SChengji Yao #define GEN_PASS_DEF_SHARDINGPROPAGATION 29b0d5b4d2SChengji Yao #include "mlir/Dialect/Mesh/Transforms/Passes.h.inc" 30b0d5b4d2SChengji Yao } // namespace mesh 31b0d5b4d2SChengji Yao } // namespace mlir 32b0d5b4d2SChengji Yao 33b0d5b4d2SChengji Yao #define DEBUG_TYPE "sharding-propagation" 34b0d5b4d2SChengji Yao #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") 35b0d5b4d2SChengji Yao 36b0d5b4d2SChengji Yao using namespace mlir; 37b0d5b4d2SChengji Yao using namespace mlir::mesh; 38b0d5b4d2SChengji Yao 39d635b860SBoian Petkantchin enum class ReshardingRquirementKind { 40d635b860SBoian Petkantchin NO_RESHARDING = 0, 41d635b860SBoian Petkantchin NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS, 42d635b860SBoian Petkantchin RESHARDING_FOR_EXPLICIT_ANNOTATIONS 43d635b860SBoian Petkantchin }; 44d635b860SBoian Petkantchin 45d635b860SBoian Petkantchin #ifdef LLVM_DEBUG 46d635b860SBoian Petkantchin 47d635b860SBoian Petkantchin template <typename T> 48d635b860SBoian Petkantchin static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream, 49d635b860SBoian Petkantchin const SmallVector<T> &vec); 50d635b860SBoian Petkantchin template <typename... Ts> 51d635b860SBoian Petkantchin static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream, 52d635b860SBoian Petkantchin const std::tuple<Ts...> &t); 53d635b860SBoian Petkantchin static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream, 54d635b860SBoian Petkantchin ReshardingRquirementKind v); 55d635b860SBoian Petkantchin 56d635b860SBoian Petkantchin template <typename Stream, typename Range> 57d635b860SBoian Petkantchin static Stream &printRange(Stream &stream, Range &&range) { 58d635b860SBoian Petkantchin stream << "["; 59d635b860SBoian Petkantchin llvm::for_each(range, [&stream](auto &v) { 60d635b860SBoian Petkantchin stream << v; 61d635b860SBoian Petkantchin stream << ", "; 62d635b860SBoian Petkantchin }); 63d635b860SBoian Petkantchin return stream << "]"; 64d635b860SBoian Petkantchin } 65d635b860SBoian Petkantchin 66d635b860SBoian Petkantchin template <typename T> 67d635b860SBoian Petkantchin static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream, 68d635b860SBoian Petkantchin const SmallVector<T> &vec) { 69d635b860SBoian Petkantchin return printRange(stream, vec); 70d635b860SBoian Petkantchin } 71d635b860SBoian Petkantchin 72bc0cdeffSKazu Hirata [[maybe_unused]] static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream, 73d635b860SBoian Petkantchin const ShardingOption &v) { 74d635b860SBoian Petkantchin return stream << "{empty = " << v.empty << ", mesh" << v.mesh 75d635b860SBoian Petkantchin << ", shardingArray = " << v.shardingArray << "}"; 76d635b860SBoian Petkantchin } 77d635b860SBoian Petkantchin 78d635b860SBoian Petkantchin template <typename Stream, typename... Ts, size_t... Is> 79d635b860SBoian Petkantchin static Stream &printTuple(Stream &stream, std::tuple<Ts...> tuple, 80d635b860SBoian Petkantchin std::index_sequence<Is...>) { 81d635b860SBoian Petkantchin static_assert(sizeof...(Is) == sizeof...(Ts), 82d635b860SBoian Petkantchin "Indices must have same number of elements as tuple types!"); 83d635b860SBoian Petkantchin static_assert(sizeof...(Ts) > 0, "Cannot insert empty tuple into stream."); 84d635b860SBoian Petkantchin 85d635b860SBoian Petkantchin stream << "{"; 86d635b860SBoian Petkantchin ((stream << std::get<Is>(tuple) << ", "), ...); 87d635b860SBoian Petkantchin return stream << "}"; 88d635b860SBoian Petkantchin } 89d635b860SBoian Petkantchin 90d635b860SBoian Petkantchin template <typename... Ts> 91d635b860SBoian Petkantchin static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream, 92d635b860SBoian Petkantchin const std::tuple<Ts...> &t) { 93d635b860SBoian Petkantchin return printTuple(stream, t, std::index_sequence_for<Ts...>{}); 94d635b860SBoian Petkantchin } 95d635b860SBoian Petkantchin 96bc0cdeffSKazu Hirata [[maybe_unused]] static llvm::raw_ostream & 97bc0cdeffSKazu Hirata operator<<(llvm::raw_ostream &stream, ReshardingRquirementKind v) { 98d635b860SBoian Petkantchin return stream << static_cast<int>(v); 99d635b860SBoian Petkantchin } 100d635b860SBoian Petkantchin 101d635b860SBoian Petkantchin #endif // LLVM_DEBUG 102d635b860SBoian Petkantchin 103b0d5b4d2SChengji Yao //===----------------------------------------------------------------------===// 104b0d5b4d2SChengji Yao // Utilities 105b0d5b4d2SChengji Yao //===----------------------------------------------------------------------===// 106b0d5b4d2SChengji Yao 107b0d5b4d2SChengji Yao // This method retrieves all potential sharding attributes, prioritizing 108b0d5b4d2SChengji Yao // specific shardings. For example, mustShardings = [shard0, None] and 109b0d5b4d2SChengji Yao // optionalShardings = [None, shard1], the result will be [[shard0, shard1], 110b0d5b4d2SChengji Yao // [shard0, None]] 111*baabcb28SFrank Schlimbach static SmallVector<std::vector<MeshSharding>> 112*baabcb28SFrank Schlimbach getOrderedPossibleShardingAttrs(ArrayRef<MeshSharding> mustShardings, 113*baabcb28SFrank Schlimbach ArrayRef<MeshSharding> optionalShardings) { 114*baabcb28SFrank Schlimbach SmallVector<std::vector<MeshSharding>> allShardingAttrs; 115*baabcb28SFrank Schlimbach std::vector<MeshSharding> curShardingAttrs; 116b0d5b4d2SChengji Yao 117b0d5b4d2SChengji Yao std::function<void(size_t)> dfsCreateShardingAttrs = [&](size_t i) { 118b0d5b4d2SChengji Yao if (i == mustShardings.size()) { 119*baabcb28SFrank Schlimbach allShardingAttrs.push_back(std::vector<MeshSharding>(curShardingAttrs)); 120b0d5b4d2SChengji Yao return; 121b0d5b4d2SChengji Yao } 122b0d5b4d2SChengji Yao 123b0d5b4d2SChengji Yao if (mustShardings[i]) { 124b0d5b4d2SChengji Yao curShardingAttrs.push_back(mustShardings[i]); 125b0d5b4d2SChengji Yao dfsCreateShardingAttrs(i + 1); 126b0d5b4d2SChengji Yao curShardingAttrs.pop_back(); 127b0d5b4d2SChengji Yao return; 128b0d5b4d2SChengji Yao } 129b0d5b4d2SChengji Yao 130b0d5b4d2SChengji Yao if (optionalShardings[i]) { 131b0d5b4d2SChengji Yao curShardingAttrs.push_back(optionalShardings[i]); 132b0d5b4d2SChengji Yao dfsCreateShardingAttrs(i + 1); 133b0d5b4d2SChengji Yao curShardingAttrs.pop_back(); 134*baabcb28SFrank Schlimbach curShardingAttrs.push_back({}); 135b0d5b4d2SChengji Yao dfsCreateShardingAttrs(i + 1); 136b0d5b4d2SChengji Yao curShardingAttrs.pop_back(); 137b0d5b4d2SChengji Yao return; 138b0d5b4d2SChengji Yao } 139b0d5b4d2SChengji Yao 140*baabcb28SFrank Schlimbach curShardingAttrs.push_back({}); 141b0d5b4d2SChengji Yao dfsCreateShardingAttrs(i + 1); 142b0d5b4d2SChengji Yao curShardingAttrs.pop_back(); 143b0d5b4d2SChengji Yao }; 144b0d5b4d2SChengji Yao 145b0d5b4d2SChengji Yao dfsCreateShardingAttrs(0); 146b0d5b4d2SChengji Yao return allShardingAttrs; 147b0d5b4d2SChengji Yao } 148b0d5b4d2SChengji Yao 149d635b860SBoian Petkantchin // The order of preference is form highest to lowest: 150d635b860SBoian Petkantchin // 1. No resharding is required (all existing annotations are compatible). 151d635b860SBoian Petkantchin // 2. No resharding for operands/results that have annotation specifically 152d635b860SBoian Petkantchin // targeting this operation. This means 153d635b860SBoian Petkantchin // * operands that are the result of `mesh.shard` ops marked with 154d635b860SBoian Petkantchin // `annotate_for_users`. 155d635b860SBoian Petkantchin // * results that are annotated with `mesh.shard` ops without 156d635b860SBoian Petkantchin // `annotate_for_users`. 157d635b860SBoian Petkantchin // 3. All other cases. Resharding is required for operands/results with 158d635b860SBoian Petkantchin // annotation targeting explicitly this operation. 159d635b860SBoian Petkantchin ReshardingRquirementKind getReshardingRquirementKind( 160*baabcb28SFrank Schlimbach Operation *op, const std::vector<MeshSharding> &operandAndResultShardings) { 161d635b860SBoian Petkantchin ReshardingRquirementKind res = ReshardingRquirementKind::NO_RESHARDING; 162d635b860SBoian Petkantchin 163d635b860SBoian Petkantchin size_t operandsCount = op->getOperands().size(); 164d635b860SBoian Petkantchin auto operandShardings = 165d635b860SBoian Petkantchin llvm::make_range(operandAndResultShardings.begin(), 166d635b860SBoian Petkantchin operandAndResultShardings.begin() + operandsCount); 167d635b860SBoian Petkantchin auto resultShardings = 168d635b860SBoian Petkantchin llvm::make_range(operandAndResultShardings.begin() + operandsCount, 169d635b860SBoian Petkantchin operandAndResultShardings.end()); 170d635b860SBoian Petkantchin 171d635b860SBoian Petkantchin for (auto [operand, sharding] : 172d635b860SBoian Petkantchin llvm::zip_equal(op->getOperands(), operandShardings)) { 173d635b860SBoian Petkantchin ShardOp shardOp = llvm::dyn_cast_or_null<ShardOp>(operand.getDefiningOp()); 174d635b860SBoian Petkantchin if (!shardOp) { 175d635b860SBoian Petkantchin continue; 176d635b860SBoian Petkantchin } 177*baabcb28SFrank Schlimbach bool needsResharding = sharding != shardOp.getSharding(); 178d635b860SBoian Petkantchin bool isExplicitAnnotationForThisOp = shardOp.getAnnotateForUsers(); 179d635b860SBoian Petkantchin if (needsResharding) { 180d635b860SBoian Petkantchin if (isExplicitAnnotationForThisOp) { 181d635b860SBoian Petkantchin // This is the worst case. No need to continue. 182d635b860SBoian Petkantchin return ReshardingRquirementKind::RESHARDING_FOR_EXPLICIT_ANNOTATIONS; 183d635b860SBoian Petkantchin } 184d635b860SBoian Petkantchin res = ReshardingRquirementKind::NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS; 185d635b860SBoian Petkantchin } 186d635b860SBoian Petkantchin } 187d635b860SBoian Petkantchin 188d635b860SBoian Petkantchin for (auto [result, sharding] : 189d635b860SBoian Petkantchin llvm::zip_equal(op->getResults(), resultShardings)) { 190d635b860SBoian Petkantchin for (auto user : result.getUsers()) { 191d635b860SBoian Petkantchin ShardOp shardOp = llvm::dyn_cast<ShardOp>(user); 192d635b860SBoian Petkantchin if (!shardOp) { 193d635b860SBoian Petkantchin continue; 194d635b860SBoian Petkantchin } 195*baabcb28SFrank Schlimbach bool needsResharding = sharding != shardOp.getSharding(); 196d635b860SBoian Petkantchin bool isExplicitAnnotationForThisOp = !shardOp.getAnnotateForUsers(); 197d635b860SBoian Petkantchin if (needsResharding) { 198d635b860SBoian Petkantchin if (isExplicitAnnotationForThisOp) { 199d635b860SBoian Petkantchin // This is the worst case. No need to continue. 200d635b860SBoian Petkantchin return ReshardingRquirementKind::RESHARDING_FOR_EXPLICIT_ANNOTATIONS; 201d635b860SBoian Petkantchin } 202d635b860SBoian Petkantchin res = ReshardingRquirementKind::NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS; 203d635b860SBoian Petkantchin } 204d635b860SBoian Petkantchin } 205d635b860SBoian Petkantchin } 206d635b860SBoian Petkantchin 207d635b860SBoian Petkantchin return res; 208d635b860SBoian Petkantchin } 209d635b860SBoian Petkantchin 210d635b860SBoian Petkantchin // From all the operand and result sharding combinations, 211d635b860SBoian Petkantchin // return the one that is most desirable. 212d635b860SBoian Petkantchin // The order of preference is: 213d635b860SBoian Petkantchin // 1. No resharding with respect to existing sharding annotations. 214d635b860SBoian Petkantchin // 2. Resharding for values that have already annotations that do not target 215d635b860SBoian Petkantchin // this op. 216d635b860SBoian Petkantchin // 3. Resharding of existing explicit sharding annotations for this op. 217d635b860SBoian Petkantchin static FailureOr<ShardingOption> selectShardingOption( 218d635b860SBoian Petkantchin ShardingInterface shardingOp, 219*baabcb28SFrank Schlimbach ArrayRef<std::vector<MeshSharding>> possibleOperandShardingAttrs, 220*baabcb28SFrank Schlimbach ArrayRef<std::vector<MeshSharding>> possibleResultShardingAttrs) { 221d635b860SBoian Petkantchin SmallVector<std::tuple<ShardingOption, ReshardingRquirementKind>> 222d635b860SBoian Petkantchin shardingOptionsAndReshardingRequirements; 223d635b860SBoian Petkantchin 224*baabcb28SFrank Schlimbach for (ArrayRef<MeshSharding> resultShardings : possibleResultShardingAttrs) { 225*baabcb28SFrank Schlimbach for (ArrayRef<MeshSharding> operandShardings : 226d635b860SBoian Petkantchin possibleOperandShardingAttrs) { 227d635b860SBoian Petkantchin FailureOr<ShardingOption> shardingOption = 228d635b860SBoian Petkantchin shardingOp.getShardingOption(operandShardings, resultShardings); 229d635b860SBoian Petkantchin if (failed(shardingOption) || shardingOption->empty) { 230d635b860SBoian Petkantchin continue; 231d635b860SBoian Petkantchin } 232d635b860SBoian Petkantchin // These shardings may not be the same as those in operandShardings and 233d635b860SBoian Petkantchin // resultShardings. 234d635b860SBoian Petkantchin // They may be missing some annotations. 235d635b860SBoian Petkantchin // Whatever is returned by getShardingAnnotations is exactly what the op 236d635b860SBoian Petkantchin // needs. 237*baabcb28SFrank Schlimbach FailureOr<std::vector<MeshSharding>> operandAndResultShardings = 238d635b860SBoian Petkantchin shardingOp.getShardingAnnotations(*shardingOption); 239d635b860SBoian Petkantchin if (failed(operandAndResultShardings)) { 240d635b860SBoian Petkantchin return failure(); 241d635b860SBoian Petkantchin } 242d635b860SBoian Petkantchin 243*baabcb28SFrank Schlimbach // LLVM_DEBUG(DBGS() << "operandAndResultShardings = " 244*baabcb28SFrank Schlimbach // << *operandAndResultShardings << "\n";); 245d635b860SBoian Petkantchin 246d635b860SBoian Petkantchin ReshardingRquirementKind reshardingRquirement = 247d635b860SBoian Petkantchin getReshardingRquirementKind(shardingOp, *operandAndResultShardings); 248d635b860SBoian Petkantchin if (reshardingRquirement == ReshardingRquirementKind::NO_RESHARDING) { 249d635b860SBoian Petkantchin // This is the best case. No need to go on. 250d635b860SBoian Petkantchin return *shardingOption; 251d635b860SBoian Petkantchin } 252d635b860SBoian Petkantchin 253d635b860SBoian Petkantchin shardingOptionsAndReshardingRequirements.emplace_back( 254d635b860SBoian Petkantchin std::move(*shardingOption), reshardingRquirement); 255d635b860SBoian Petkantchin } 256d635b860SBoian Petkantchin } 257d635b860SBoian Petkantchin 258d635b860SBoian Petkantchin if (shardingOptionsAndReshardingRequirements.empty()) { 259d635b860SBoian Petkantchin return ShardingOption::makeEmpty(); 260d635b860SBoian Petkantchin } 261d635b860SBoian Petkantchin 262d635b860SBoian Petkantchin std::partial_sort( 263d635b860SBoian Petkantchin shardingOptionsAndReshardingRequirements.begin(), 264d635b860SBoian Petkantchin shardingOptionsAndReshardingRequirements.begin() + 1, 265d635b860SBoian Petkantchin shardingOptionsAndReshardingRequirements.end(), 266d635b860SBoian Petkantchin [](const std::tuple<ShardingOption, ReshardingRquirementKind> &a, 267d635b860SBoian Petkantchin const std::tuple<ShardingOption, ReshardingRquirementKind> &b) { 268d635b860SBoian Petkantchin return std::get<ReshardingRquirementKind>(a) < 269d635b860SBoian Petkantchin std::get<ReshardingRquirementKind>(b); 270d635b860SBoian Petkantchin }); 271d635b860SBoian Petkantchin 272d635b860SBoian Petkantchin LLVM_DEBUG(DBGS() << "shardingOptionsAndReshardingRequirements = " 273d635b860SBoian Petkantchin << shardingOptionsAndReshardingRequirements << "\n";); 274d635b860SBoian Petkantchin 275d635b860SBoian Petkantchin return std::get<ShardingOption>( 276d635b860SBoian Petkantchin shardingOptionsAndReshardingRequirements.front()); 277d635b860SBoian Petkantchin } 278d635b860SBoian Petkantchin 279b0d5b4d2SChengji Yao // For each operation that implements the ShardingInterface, infer the sharding 280b0d5b4d2SChengji Yao // option of the operation from its operands and/or results using the 281b0d5b4d2SChengji Yao // `getShardingOption` method. If the inferred sharding option is not empty, add 282b0d5b4d2SChengji Yao // a `mesh.shard` operation for all remaining operands and results that do not 283b0d5b4d2SChengji Yao // have sharding annotations. 284adbf21f1SBoian Petkantchin static LogicalResult visitOp(Operation *op, OpBuilder &builder) { 285*baabcb28SFrank Schlimbach if (op->hasTrait<OpTrait::IsTerminator>() || 286*baabcb28SFrank Schlimbach llvm::isa<mesh::ShardOp, mesh::ShardingOp>(op)) 287b0d5b4d2SChengji Yao return success(); 288b0d5b4d2SChengji Yao 289b0d5b4d2SChengji Yao ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op); 290b0d5b4d2SChengji Yao if (!shardingOp) { 291b0d5b4d2SChengji Yao op->emitOpError() << "sharding interface is not implemented."; 292b0d5b4d2SChengji Yao return failure(); 293b0d5b4d2SChengji Yao } 294b0d5b4d2SChengji Yao 295*baabcb28SFrank Schlimbach // collect MeshSharding from results 296*baabcb28SFrank Schlimbach std::vector<MeshSharding> allowConflictsResultShardings; 297b0d5b4d2SChengji Yao allowConflictsResultShardings.resize(op->getNumResults()); 298*baabcb28SFrank Schlimbach std::vector<MeshSharding> resultMustShardings; 299b0d5b4d2SChengji Yao resultMustShardings.resize(op->getNumResults()); 300b0d5b4d2SChengji Yao for (OpResult result : op->getResults()) { 301*baabcb28SFrank Schlimbach FailureOr<std::pair<bool, MeshSharding>> maybeShardAttr = 302*baabcb28SFrank Schlimbach getMeshSharding(result); 303b0d5b4d2SChengji Yao if (failed(maybeShardAttr)) 304b0d5b4d2SChengji Yao continue; 305b0d5b4d2SChengji Yao if (!maybeShardAttr->first) 306b0d5b4d2SChengji Yao resultMustShardings[result.getResultNumber()] = maybeShardAttr->second; 307b0d5b4d2SChengji Yao else 308b0d5b4d2SChengji Yao allowConflictsResultShardings[result.getResultNumber()] = 309b0d5b4d2SChengji Yao maybeShardAttr->second; 310b0d5b4d2SChengji Yao } 311b0d5b4d2SChengji Yao 312*baabcb28SFrank Schlimbach // collect MeshSharding from operands 313*baabcb28SFrank Schlimbach std::vector<MeshSharding> allowConflictsOperandShardings; 314b0d5b4d2SChengji Yao allowConflictsOperandShardings.resize(op->getNumOperands()); 315*baabcb28SFrank Schlimbach std::vector<MeshSharding> operandMustShardings; 316b0d5b4d2SChengji Yao operandMustShardings.resize(op->getNumOperands()); 317b0d5b4d2SChengji Yao for (OpOperand &opOperand : op->getOpOperands()) { 318*baabcb28SFrank Schlimbach FailureOr<std::pair<bool, MeshSharding>> maybeShardAttr = 319*baabcb28SFrank Schlimbach getMeshSharding(opOperand); 320b0d5b4d2SChengji Yao if (failed(maybeShardAttr)) 321b0d5b4d2SChengji Yao continue; 322b0d5b4d2SChengji Yao 323b0d5b4d2SChengji Yao if (maybeShardAttr->first) 324b0d5b4d2SChengji Yao operandMustShardings[opOperand.getOperandNumber()] = 325b0d5b4d2SChengji Yao maybeShardAttr->second; 326b0d5b4d2SChengji Yao else 327b0d5b4d2SChengji Yao allowConflictsOperandShardings[opOperand.getOperandNumber()] = 328b0d5b4d2SChengji Yao maybeShardAttr->second; 329b0d5b4d2SChengji Yao } 330b0d5b4d2SChengji Yao 331b0d5b4d2SChengji Yao // try to get the sharding option 332*baabcb28SFrank Schlimbach SmallVector<std::vector<MeshSharding>> possibleOperandShardingAttrs = 333b0d5b4d2SChengji Yao getOrderedPossibleShardingAttrs(operandMustShardings, 334b0d5b4d2SChengji Yao allowConflictsOperandShardings); 335*baabcb28SFrank Schlimbach SmallVector<std::vector<MeshSharding>> possibleResultShardingAttrs = 336b0d5b4d2SChengji Yao getOrderedPossibleShardingAttrs(resultMustShardings, 337b0d5b4d2SChengji Yao allowConflictsResultShardings); 338d635b860SBoian Petkantchin FailureOr<ShardingOption> shardingOption = selectShardingOption( 339d635b860SBoian Petkantchin shardingOp, possibleOperandShardingAttrs, possibleResultShardingAttrs); 340b0d5b4d2SChengji Yao 341d635b860SBoian Petkantchin if (failed(shardingOption)) { 342b0d5b4d2SChengji Yao op->emitOpError() << "fail to get sharding option."; 343b0d5b4d2SChengji Yao return failure(); 344b0d5b4d2SChengji Yao } 345d635b860SBoian Petkantchin 346d635b860SBoian Petkantchin LLVM_DEBUG(DBGS() << "Selected sharding option: " << *shardingOption << "\n"); 347d635b860SBoian Petkantchin 348b0d5b4d2SChengji Yao // sharding info is empty, return immediately 349d635b860SBoian Petkantchin if (shardingOption->empty) 350b0d5b4d2SChengji Yao return success(); 351b0d5b4d2SChengji Yao 352d635b860SBoian Petkantchin if (failed(shardingOp.addShardingAnnotations(builder, *shardingOption))) { 353b0d5b4d2SChengji Yao op->emitOpError() << "fail to set sharding annotations."; 354b0d5b4d2SChengji Yao return failure(); 355b0d5b4d2SChengji Yao } 356b0d5b4d2SChengji Yao return success(); 357b0d5b4d2SChengji Yao } 358b0d5b4d2SChengji Yao 359b0d5b4d2SChengji Yao //===----------------------------------------------------------------------===// 360b0d5b4d2SChengji Yao // ShardingPropagation 361b0d5b4d2SChengji Yao //===----------------------------------------------------------------------===// 362b0d5b4d2SChengji Yao struct ShardingPropagation 363b0d5b4d2SChengji Yao : public mesh::impl::ShardingPropagationBase<ShardingPropagation> { 364b0d5b4d2SChengji Yao void runOnOperation() override { 365abfac563SBoian Petkantchin FunctionOpInterface funcOp = getOperation(); 366b0d5b4d2SChengji Yao MLIRContext *ctx = funcOp.getContext(); 367abfac563SBoian Petkantchin Region ®ion = funcOp.getFunctionBody(); 368b0d5b4d2SChengji Yao OpBuilder builder(ctx); 369b0d5b4d2SChengji Yao if (!region.hasOneBlock()) { 370b0d5b4d2SChengji Yao funcOp.emitOpError() << "only one block is supported!"; 371b0d5b4d2SChengji Yao signalPassFailure(); 372b0d5b4d2SChengji Yao } 373b0d5b4d2SChengji Yao Block &block = region.front(); 374b0d5b4d2SChengji Yao 375b0d5b4d2SChengji Yao LLVM_DEBUG( 376b0d5b4d2SChengji Yao DBGS() << "print all the ops' iterator types and indexing maps in the " 377b0d5b4d2SChengji Yao "block.\n"; 378b0d5b4d2SChengji Yao for (Operation &op 379b0d5b4d2SChengji Yao : block.getOperations()) { 380b0d5b4d2SChengji Yao if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op)) 381b0d5b4d2SChengji Yao shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs()); 382b0d5b4d2SChengji Yao }); 383b0d5b4d2SChengji Yao 384b0d5b4d2SChengji Yao // 1. propagate in reversed order 385b0d5b4d2SChengji Yao for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block))) 386b0d5b4d2SChengji Yao if (failed(visitOp(&op, builder))) 387b0d5b4d2SChengji Yao return signalPassFailure(); 388b0d5b4d2SChengji Yao 389b0d5b4d2SChengji Yao LLVM_DEBUG(DBGS() << "After reversed order propagation:\n" 390b0d5b4d2SChengji Yao << funcOp << "\n"); 391d635b860SBoian Petkantchin LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp)))); 392b0d5b4d2SChengji Yao 393b0d5b4d2SChengji Yao // 2. propagate in original order 394b0d5b4d2SChengji Yao for (Operation &op : llvm::make_early_inc_range(block)) 395b0d5b4d2SChengji Yao if (failed(visitOp(&op, builder))) 396b0d5b4d2SChengji Yao return signalPassFailure(); 397b0d5b4d2SChengji Yao } 398b0d5b4d2SChengji Yao }; 399