1 //===- ShardingPropagation.cpp ------------------------------------- C++ --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Mesh/Transforms/Passes.h" 10 11 #include "mlir/Dialect/Func/IR/FuncOps.h" 12 #include "mlir/Dialect/Mesh/IR/MeshDialect.h" 13 #include "mlir/Dialect/Mesh/IR/MeshOps.h" 14 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" 15 #include "mlir/IR/Verifier.h" 16 #include "mlir/Interfaces/FunctionInterfaces.h" 17 #include "mlir/Pass/Pass.h" 18 #include "llvm/ADT/STLExtras.h" 19 #include "llvm/ADT/SmallVector.h" 20 #include "llvm/ADT/iterator_range.h" 21 #include "llvm/Support/Debug.h" 22 #include "llvm/Support/raw_ostream.h" 23 #include <algorithm> 24 #include <vector> 25 26 namespace mlir { 27 namespace mesh { 28 #define GEN_PASS_DEF_SHARDINGPROPAGATION 29 #include "mlir/Dialect/Mesh/Transforms/Passes.h.inc" 30 } // namespace mesh 31 } // namespace mlir 32 33 #define DEBUG_TYPE "sharding-propagation" 34 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") 35 36 using namespace mlir; 37 using namespace mlir::mesh; 38 39 enum class ReshardingRquirementKind { 40 NO_RESHARDING = 0, 41 NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS, 42 RESHARDING_FOR_EXPLICIT_ANNOTATIONS 43 }; 44 45 #ifdef LLVM_DEBUG 46 47 template <typename T> 48 static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream, 49 const SmallVector<T> &vec); 50 template <typename... Ts> 51 static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream, 52 const std::tuple<Ts...> &t); 53 static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream, 54 ReshardingRquirementKind v); 55 56 template <typename Stream, typename Range> 57 static Stream &printRange(Stream &stream, Range &&range) { 58 stream << "["; 59 llvm::for_each(range, [&stream](auto &v) { 60 stream << v; 61 stream << ", "; 62 }); 63 return stream << "]"; 64 } 65 66 template <typename T> 67 static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream, 68 const SmallVector<T> &vec) { 69 return printRange(stream, vec); 70 } 71 72 [[maybe_unused]] static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream, 73 const ShardingOption &v) { 74 return stream << "{empty = " << v.empty << ", mesh" << v.mesh 75 << ", shardingArray = " << v.shardingArray << "}"; 76 } 77 78 template <typename Stream, typename... Ts, size_t... Is> 79 static Stream &printTuple(Stream &stream, std::tuple<Ts...> tuple, 80 std::index_sequence<Is...>) { 81 static_assert(sizeof...(Is) == sizeof...(Ts), 82 "Indices must have same number of elements as tuple types!"); 83 static_assert(sizeof...(Ts) > 0, "Cannot insert empty tuple into stream."); 84 85 stream << "{"; 86 ((stream << std::get<Is>(tuple) << ", "), ...); 87 return stream << "}"; 88 } 89 90 template <typename... Ts> 91 static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream, 92 const std::tuple<Ts...> &t) { 93 return printTuple(stream, t, std::index_sequence_for<Ts...>{}); 94 } 95 96 [[maybe_unused]] static llvm::raw_ostream & 97 operator<<(llvm::raw_ostream &stream, ReshardingRquirementKind v) { 98 return stream << static_cast<int>(v); 99 } 100 101 #endif // LLVM_DEBUG 102 103 //===----------------------------------------------------------------------===// 104 // Utilities 105 //===----------------------------------------------------------------------===// 106 107 // This method retrieves all potential sharding attributes, prioritizing 108 // specific shardings. For example, mustShardings = [shard0, None] and 109 // optionalShardings = [None, shard1], the result will be [[shard0, shard1], 110 // [shard0, None]] 111 static SmallVector<std::vector<MeshSharding>> 112 getOrderedPossibleShardingAttrs(ArrayRef<MeshSharding> mustShardings, 113 ArrayRef<MeshSharding> optionalShardings) { 114 SmallVector<std::vector<MeshSharding>> allShardingAttrs; 115 std::vector<MeshSharding> curShardingAttrs; 116 117 std::function<void(size_t)> dfsCreateShardingAttrs = [&](size_t i) { 118 if (i == mustShardings.size()) { 119 allShardingAttrs.push_back(std::vector<MeshSharding>(curShardingAttrs)); 120 return; 121 } 122 123 if (mustShardings[i]) { 124 curShardingAttrs.push_back(mustShardings[i]); 125 dfsCreateShardingAttrs(i + 1); 126 curShardingAttrs.pop_back(); 127 return; 128 } 129 130 if (optionalShardings[i]) { 131 curShardingAttrs.push_back(optionalShardings[i]); 132 dfsCreateShardingAttrs(i + 1); 133 curShardingAttrs.pop_back(); 134 curShardingAttrs.push_back({}); 135 dfsCreateShardingAttrs(i + 1); 136 curShardingAttrs.pop_back(); 137 return; 138 } 139 140 curShardingAttrs.push_back({}); 141 dfsCreateShardingAttrs(i + 1); 142 curShardingAttrs.pop_back(); 143 }; 144 145 dfsCreateShardingAttrs(0); 146 return allShardingAttrs; 147 } 148 149 // The order of preference is form highest to lowest: 150 // 1. No resharding is required (all existing annotations are compatible). 151 // 2. No resharding for operands/results that have annotation specifically 152 // targeting this operation. This means 153 // * operands that are the result of `mesh.shard` ops marked with 154 // `annotate_for_users`. 155 // * results that are annotated with `mesh.shard` ops without 156 // `annotate_for_users`. 157 // 3. All other cases. Resharding is required for operands/results with 158 // annotation targeting explicitly this operation. 159 ReshardingRquirementKind getReshardingRquirementKind( 160 Operation *op, const std::vector<MeshSharding> &operandAndResultShardings) { 161 ReshardingRquirementKind res = ReshardingRquirementKind::NO_RESHARDING; 162 163 size_t operandsCount = op->getOperands().size(); 164 auto operandShardings = 165 llvm::make_range(operandAndResultShardings.begin(), 166 operandAndResultShardings.begin() + operandsCount); 167 auto resultShardings = 168 llvm::make_range(operandAndResultShardings.begin() + operandsCount, 169 operandAndResultShardings.end()); 170 171 for (auto [operand, sharding] : 172 llvm::zip_equal(op->getOperands(), operandShardings)) { 173 ShardOp shardOp = llvm::dyn_cast_or_null<ShardOp>(operand.getDefiningOp()); 174 if (!shardOp) { 175 continue; 176 } 177 bool needsResharding = sharding != shardOp.getSharding(); 178 bool isExplicitAnnotationForThisOp = shardOp.getAnnotateForUsers(); 179 if (needsResharding) { 180 if (isExplicitAnnotationForThisOp) { 181 // This is the worst case. No need to continue. 182 return ReshardingRquirementKind::RESHARDING_FOR_EXPLICIT_ANNOTATIONS; 183 } 184 res = ReshardingRquirementKind::NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS; 185 } 186 } 187 188 for (auto [result, sharding] : 189 llvm::zip_equal(op->getResults(), resultShardings)) { 190 for (auto user : result.getUsers()) { 191 ShardOp shardOp = llvm::dyn_cast<ShardOp>(user); 192 if (!shardOp) { 193 continue; 194 } 195 bool needsResharding = sharding != shardOp.getSharding(); 196 bool isExplicitAnnotationForThisOp = !shardOp.getAnnotateForUsers(); 197 if (needsResharding) { 198 if (isExplicitAnnotationForThisOp) { 199 // This is the worst case. No need to continue. 200 return ReshardingRquirementKind::RESHARDING_FOR_EXPLICIT_ANNOTATIONS; 201 } 202 res = ReshardingRquirementKind::NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS; 203 } 204 } 205 } 206 207 return res; 208 } 209 210 // From all the operand and result sharding combinations, 211 // return the one that is most desirable. 212 // The order of preference is: 213 // 1. No resharding with respect to existing sharding annotations. 214 // 2. Resharding for values that have already annotations that do not target 215 // this op. 216 // 3. Resharding of existing explicit sharding annotations for this op. 217 static FailureOr<ShardingOption> selectShardingOption( 218 ShardingInterface shardingOp, 219 ArrayRef<std::vector<MeshSharding>> possibleOperandShardingAttrs, 220 ArrayRef<std::vector<MeshSharding>> possibleResultShardingAttrs) { 221 SmallVector<std::tuple<ShardingOption, ReshardingRquirementKind>> 222 shardingOptionsAndReshardingRequirements; 223 224 for (ArrayRef<MeshSharding> resultShardings : possibleResultShardingAttrs) { 225 for (ArrayRef<MeshSharding> operandShardings : 226 possibleOperandShardingAttrs) { 227 FailureOr<ShardingOption> shardingOption = 228 shardingOp.getShardingOption(operandShardings, resultShardings); 229 if (failed(shardingOption) || shardingOption->empty) { 230 continue; 231 } 232 // These shardings may not be the same as those in operandShardings and 233 // resultShardings. 234 // They may be missing some annotations. 235 // Whatever is returned by getShardingAnnotations is exactly what the op 236 // needs. 237 FailureOr<std::vector<MeshSharding>> operandAndResultShardings = 238 shardingOp.getShardingAnnotations(*shardingOption); 239 if (failed(operandAndResultShardings)) { 240 return failure(); 241 } 242 243 // LLVM_DEBUG(DBGS() << "operandAndResultShardings = " 244 // << *operandAndResultShardings << "\n";); 245 246 ReshardingRquirementKind reshardingRquirement = 247 getReshardingRquirementKind(shardingOp, *operandAndResultShardings); 248 if (reshardingRquirement == ReshardingRquirementKind::NO_RESHARDING) { 249 // This is the best case. No need to go on. 250 return *shardingOption; 251 } 252 253 shardingOptionsAndReshardingRequirements.emplace_back( 254 std::move(*shardingOption), reshardingRquirement); 255 } 256 } 257 258 if (shardingOptionsAndReshardingRequirements.empty()) { 259 return ShardingOption::makeEmpty(); 260 } 261 262 std::partial_sort( 263 shardingOptionsAndReshardingRequirements.begin(), 264 shardingOptionsAndReshardingRequirements.begin() + 1, 265 shardingOptionsAndReshardingRequirements.end(), 266 [](const std::tuple<ShardingOption, ReshardingRquirementKind> &a, 267 const std::tuple<ShardingOption, ReshardingRquirementKind> &b) { 268 return std::get<ReshardingRquirementKind>(a) < 269 std::get<ReshardingRquirementKind>(b); 270 }); 271 272 LLVM_DEBUG(DBGS() << "shardingOptionsAndReshardingRequirements = " 273 << shardingOptionsAndReshardingRequirements << "\n";); 274 275 return std::get<ShardingOption>( 276 shardingOptionsAndReshardingRequirements.front()); 277 } 278 279 // For each operation that implements the ShardingInterface, infer the sharding 280 // option of the operation from its operands and/or results using the 281 // `getShardingOption` method. If the inferred sharding option is not empty, add 282 // a `mesh.shard` operation for all remaining operands and results that do not 283 // have sharding annotations. 284 static LogicalResult visitOp(Operation *op, OpBuilder &builder) { 285 if (op->hasTrait<OpTrait::IsTerminator>() || 286 llvm::isa<mesh::ShardOp, mesh::ShardingOp>(op)) 287 return success(); 288 289 ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op); 290 if (!shardingOp) { 291 op->emitOpError() << "sharding interface is not implemented."; 292 return failure(); 293 } 294 295 // collect MeshSharding from results 296 std::vector<MeshSharding> allowConflictsResultShardings; 297 allowConflictsResultShardings.resize(op->getNumResults()); 298 std::vector<MeshSharding> resultMustShardings; 299 resultMustShardings.resize(op->getNumResults()); 300 for (OpResult result : op->getResults()) { 301 FailureOr<std::pair<bool, MeshSharding>> maybeShardAttr = 302 getMeshSharding(result); 303 if (failed(maybeShardAttr)) 304 continue; 305 if (!maybeShardAttr->first) 306 resultMustShardings[result.getResultNumber()] = maybeShardAttr->second; 307 else 308 allowConflictsResultShardings[result.getResultNumber()] = 309 maybeShardAttr->second; 310 } 311 312 // collect MeshSharding from operands 313 std::vector<MeshSharding> allowConflictsOperandShardings; 314 allowConflictsOperandShardings.resize(op->getNumOperands()); 315 std::vector<MeshSharding> operandMustShardings; 316 operandMustShardings.resize(op->getNumOperands()); 317 for (OpOperand &opOperand : op->getOpOperands()) { 318 FailureOr<std::pair<bool, MeshSharding>> maybeShardAttr = 319 getMeshSharding(opOperand); 320 if (failed(maybeShardAttr)) 321 continue; 322 323 if (maybeShardAttr->first) 324 operandMustShardings[opOperand.getOperandNumber()] = 325 maybeShardAttr->second; 326 else 327 allowConflictsOperandShardings[opOperand.getOperandNumber()] = 328 maybeShardAttr->second; 329 } 330 331 // try to get the sharding option 332 SmallVector<std::vector<MeshSharding>> possibleOperandShardingAttrs = 333 getOrderedPossibleShardingAttrs(operandMustShardings, 334 allowConflictsOperandShardings); 335 SmallVector<std::vector<MeshSharding>> possibleResultShardingAttrs = 336 getOrderedPossibleShardingAttrs(resultMustShardings, 337 allowConflictsResultShardings); 338 FailureOr<ShardingOption> shardingOption = selectShardingOption( 339 shardingOp, possibleOperandShardingAttrs, possibleResultShardingAttrs); 340 341 if (failed(shardingOption)) { 342 op->emitOpError() << "fail to get sharding option."; 343 return failure(); 344 } 345 346 LLVM_DEBUG(DBGS() << "Selected sharding option: " << *shardingOption << "\n"); 347 348 // sharding info is empty, return immediately 349 if (shardingOption->empty) 350 return success(); 351 352 if (failed(shardingOp.addShardingAnnotations(builder, *shardingOption))) { 353 op->emitOpError() << "fail to set sharding annotations."; 354 return failure(); 355 } 356 return success(); 357 } 358 359 //===----------------------------------------------------------------------===// 360 // ShardingPropagation 361 //===----------------------------------------------------------------------===// 362 struct ShardingPropagation 363 : public mesh::impl::ShardingPropagationBase<ShardingPropagation> { 364 void runOnOperation() override { 365 FunctionOpInterface funcOp = getOperation(); 366 MLIRContext *ctx = funcOp.getContext(); 367 Region ®ion = funcOp.getFunctionBody(); 368 OpBuilder builder(ctx); 369 if (!region.hasOneBlock()) { 370 funcOp.emitOpError() << "only one block is supported!"; 371 signalPassFailure(); 372 } 373 Block &block = region.front(); 374 375 LLVM_DEBUG( 376 DBGS() << "print all the ops' iterator types and indexing maps in the " 377 "block.\n"; 378 for (Operation &op 379 : block.getOperations()) { 380 if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op)) 381 shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs()); 382 }); 383 384 // 1. propagate in reversed order 385 for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block))) 386 if (failed(visitOp(&op, builder))) 387 return signalPassFailure(); 388 389 LLVM_DEBUG(DBGS() << "After reversed order propagation:\n" 390 << funcOp << "\n"); 391 LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp)))); 392 393 // 2. propagate in original order 394 for (Operation &op : llvm::make_early_inc_range(block)) 395 if (failed(visitOp(&op, builder))) 396 return signalPassFailure(); 397 } 398 }; 399