//===- ShardingInterface.cpp -------------------------------------*- C++-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"

#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Debug.h"

#include <utility>

#define DEBUG_TYPE "sharding-interface"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

using namespace mlir;
using namespace mlir::mesh;

#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc"

//===----------------------------------------------------------------------===//
// common util functions
//===----------------------------------------------------------------------===//

static LogicalResult
checkOperandAffineExprRecursively(AffineExpr expr,
                                  SmallVectorImpl<bool> &seenIds) {
  switch (expr.getKind()) {
  case AffineExprKind::Add: {
    auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
    AffineExpr lhs = binOpExpr.getLHS();
    AffineExpr rhs = binOpExpr.getRHS();
    if (failed(checkOperandAffineExprRecursively(lhs, seenIds)))
      return failure();
    if (failed(checkOperandAffineExprRecursively(rhs, seenIds)))
      return failure();
    return success();
  }
  case AffineExprKind::Mul: {
    auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
    AffineExpr lhs = binOpExpr.getLHS();
    AffineExpr rhs = binOpExpr.getRHS();
    AffineExpr dimExpr;
    if (lhs.getKind() == AffineExprKind::DimId &&
        rhs.getKind() == AffineExprKind::Constant) {
      dimExpr = lhs;
    } else if (rhs.getKind() == AffineExprKind::DimId &&
               lhs.getKind() == AffineExprKind::Constant) {
      dimExpr = rhs;
    } else
      return failure();
    unsigned position = cast<AffineDimExpr>(dimExpr).getPosition();
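    // Reject dimensions that are out of range or that already appeared in
    // another term of this operand expression.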
    if ((size_t)position >= seenIds.size() || seenIds[position])
      return failure();
    seenIds[position] = true;
    return success();
  }
  case AffineExprKind::DimId: {
    unsigned position = cast<AffineDimExpr>(expr).getPosition();
    if ((size_t)position >= seenIds.size() || seenIds[position])
      return failure();
    seenIds[position] = true;
    return success();
  }
  default:
    return failure();
  }
}

static FailureOr<llvm::SmallSet<unsigned, 2>>
checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
  SmallVector<bool> seenIds(numDims, false);
  if (failed(checkOperandAffineExprRecursively(expr, seenIds)))
    return failure();

  llvm::SmallSet<unsigned, 2> positions;
  for (auto it : llvm::enumerate(seenIds)) {
    if (it.value())
      positions.insert((unsigned)it.index());
  }
  return positions;
}

template <typename T>
SmallVector<MeshAxesAttr>
fromArrayOfVector(MLIRContext *ctxt, const SmallVector<SmallVector<T>> &vec) {
  SmallVector<MeshAxesAttr> res;
  for (const auto &v : vec) {
    res.emplace_back(MeshAxesAttr::get(ctxt, v));
  }
  return res;
}

//===----------------------------------------------------------------------===//
// mesh::getMeshSharding
//===----------------------------------------------------------------------===//

FailureOr<std::pair<bool, MeshSharding>>
mesh::getMeshSharding(OpResult result) {
  Value val = cast<Value>(result);
  bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
    auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
    if (!shardOp)
      return false;
    return !shardOp.getAnnotateForUsers();
  });

  if (anyShardedForDef) {
    // expected to have exactly one use if it has a use of `mesh.shard` without
    // unit attr annotate_for_users
    if (!val.hasOneUse())
      return failure();
    auto shardOp = llvm::cast<mesh::ShardOp>(*val.getUsers().begin());
    return std::make_pair(false, MeshSharding(shardOp.getSharding()));
  }

  bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) {
    auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
    if (!shardOp)
      return false;
    return shardOp.getAnnotateForUsers();
  });
  if (anyShardedForUsers) {
    SmallVector<ShardOp> shardOps;
    for (Operation *user : val.getUsers()) {
      ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
      if (shardOp)
        shardOps.push_back(shardOp);
    }
    MeshSharding shardForDef = shardOps[0].getSharding();
    for (size_t i = 1; i < shardOps.size(); ++i) {
      // TODO: Deduce a reasonable mesh sharding attr for def when they are
      // different
      assert(shardForDef == shardOps[i].getSharding() &&
             "only supported if all shard ops have the same mesh sharding");
    }
    return std::make_pair(true, shardForDef);
  }
  return failure();
}

FailureOr<std::pair<bool, MeshSharding>>
mesh::getMeshSharding(OpOperand &opOperand) {
  Value val = opOperand.get();
  if (ShardOp shardOp = val.getDefiningOp<ShardOp>())
    return std::make_pair(shardOp.getAnnotateForUsers(),
                          MeshSharding(shardOp.getSharding()));

  return failure();
}

//===----------------------------------------------------------------------===//
// ShardingInterface::verifyShardingInterfaceImpl
//===----------------------------------------------------------------------===//

LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
  Operation *op = getOperation();

  // check operand and result types
  for (Type type : op->getOperandTypes())
    if (!llvm::isa<RankedTensorType>(type))
      return failure();
  for (Type type : op->getResultTypes())
    if (!llvm::isa<RankedTensorType>(type))
      return failure();

  // check loop types
  SmallVector<utils::IteratorType> loopTypes = getLoopIteratorTypes();
  if (loopTypes.empty())
    return failure();

  // check maps
  SmallVector<AffineMap> maps = getIndexingMaps();
  if (maps.empty())
    return failure();
  unsigned numOperands = op->getNumOperands();
  unsigned numResults = op->getNumResults();
  if (numOperands + numResults != maps.size())
    return failure();

  for (OpResult result : op->getResults()) {
    auto resultType = dyn_cast<RankedTensorType>(result.getType());
    if (!resultType)
      return failure();
    AffineMap map = maps[numOperands + result.getResultNumber()];
    if (!map.isProjectedPermutation()) {
      return failure();
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// ShardingInterface::printLoopTypesAndIndexingMaps
//===----------------------------------------------------------------------===//

void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
  os << "print loop types and indexing maps for: \n";
  getOperation()->print(os);
  os << "\n";
  os << "loop types: [";
  for (utils::IteratorType type : getLoopIteratorTypes()) {
    os << stringifyEnum(type) << " ";
  }
  os << "]\n";
  os << "indexing maps: \n";
  for (AffineMap map : getIndexingMaps())
    os << map << "\n";
  os << "\n";
}

//===----------------------------------------------------------------------===//
// detail::defaultGetShardingOption
//===----------------------------------------------------------------------===//

namespace {

// Update the given `shardingOption` according to `meshAxes` and `loopIdx`
static LogicalResult fillShardingOption(Operation *op,
                                        ShardingOption &shardingOption,
                                        FlatSymbolRefAttr mesh,
                                        ArrayRef<MeshAxis> meshAxes,
                                        unsigned loopIdx) {
  if ((shardingOption.mesh && mesh && shardingOption.mesh != mesh) ||
      (!shardingOption.shardingArray[loopIdx].empty() &&
       shardingOption.shardingArray[loopIdx] != meshAxes)) {
    LLVM_DEBUG(DBGS() << "sharding option conflicts on loop iterator "
                      << loopIdx << "\n");
    return failure();
  }
  for (size_t i = 0; i < shardingOption.shardingArray.size(); ++i) {
    if (i == loopIdx)
      continue;

    for (MeshAxis axis : meshAxes) {
      if (llvm::is_contained(shardingOption.shardingArray[i], axis)) {
        LLVM_DEBUG(DBGS() << "sharding option conflicts because mesh axes "
                          << axis << " duplicate");
        return failure();
      }
    }
  }
  if (mesh)
    shardingOption.mesh = mesh;
  if (shardingOption.shardingArray[loopIdx].empty())
    shardingOption.shardingArray[loopIdx].append(meshAxes.begin(),
                                                 meshAxes.end());
  return success();
}

} // namespace

FailureOr<ShardingOption>
mesh::detail::defaultGetShardingOption(Operation *op,
                                       ArrayRef<MeshSharding> operandShardings,
                                       ArrayRef<MeshSharding> resultShardings) {
  ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
  ShardingOption shardingOption;

  if (failed(shardingOp.verifyShardingInterfaceImpl()))
    return op->emitOpError() << "invalid sharding interface implementation";
  SmallVector<utils::IteratorType> loopTypes =
      shardingOp.getLoopIteratorTypes();
  SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
  unsigned numOperands = op->getNumOperands();
  shardingOption.shardingArray.resize(loopTypes.size());
  llvm::SmallVector<MeshAxis> partialMeshAxes;
  llvm::SmallSet<unsigned, 4> visitedLoopIndices;
  bool anyShardingInResultsOrOperands = false;

  // 1. Fill sharding option based on op results
  for (auto shardingIt : llvm::enumerate(resultShardings)) {
    MeshSharding shardAttr = shardingIt.value();
    if (!shardAttr)
      continue;
    AffineMap map = maps[numOperands + shardingIt.index()];
    anyShardingInResultsOrOperands = true;
    // Handle the split axes: calculate the corresponding loop index for each
    // split-axes sub-array, and then store the sub-array in
    // shardingOption[index].
    for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
      AffineExpr expr = std::get<0>(it);
      ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
      auto dim = cast<AffineDimExpr>(expr);
      unsigned index = dim.getPosition();
      visitedLoopIndices.insert(index);
      if (failed(fillShardingOption(op, shardingOption, shardAttr.getMeshAttr(),
                                    axes, index)))
        return failure();
    }

    // Handle the partial axes: at this stage, the exact loop index/indices
    // cannot be decided because there could be multiple reduction loops.
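    // They are collected in `partialMeshAxes` and matched against the
    // reduction loops when the sharding option is finalized in step 3.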
    ArrayRef<MeshAxis> partialAxes = shardAttr.getPartialAxes();
    if (!partialAxes.empty()) {
      if (!partialMeshAxes.empty())
        return op->emitOpError() << "at most one result with partial axes is "
                                    "supported at present";
      partialMeshAxes.append(partialAxes.begin(), partialAxes.end());
      // Add all the reduction loop indices to `visitedLoopIndices` if
      // `partialAxes` is not empty
      for (size_t loopIdx = 0; loopIdx < loopTypes.size(); ++loopIdx) {
        if (isReductionLoop(loopTypes[loopIdx]))
          visitedLoopIndices.insert(loopIdx);
      }
    }
  }

  // 2. Fill sharding option based on operands
  for (auto shardingIt : llvm::enumerate(operandShardings)) {
    MeshSharding shardAttr = shardingIt.value();
    if (!shardAttr)
      continue;

    anyShardingInResultsOrOperands = true;
    AffineMap map = maps[shardingIt.index()];
    unsigned numDims = map.getNumDims();

    // Handle the split axes. Partial axes don't need to be handled because
    // they only affect the defining op of the operand.
    //
    // TODO: Change to process the operands with single loop index first and
    // then the operands with multiple loop indices.
    for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
      AffineExpr expr = std::get<0>(it);
      ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
      FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
          checkOperandAffineExpr(expr, numDims);
      if (failed(loopIndices))
        return op->emitOpError()
               << "operand's affine expression is restricted to const_i * "
                  "dim_i + const_j * dim_j + ...";
      if (loopIndices->empty())
        continue;
      if (loopIndices->size() == 1) {
        unsigned loopIdx = *loopIndices->begin();
        visitedLoopIndices.insert(loopIdx);
        if (failed(fillShardingOption(op, shardingOption,
                                      shardAttr.getMeshAttr(), axes, loopIdx)))
          return failure();
      }
      // If multiple loop indices correspond to a dimension of an operand, it
      // is difficult to infer which loop indices are responsible for sharding.
      // Therefore, the exact loop index must be specified by others.
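      // All that can be checked here is that at least one of these loop
      // indices was already pinned down by another annotation.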
      if (loopIndices->size() > 1) {
        bool seenLoopIndices = false;
        for (unsigned loopIdx : *loopIndices) {
          if (visitedLoopIndices.contains(loopIdx)) {
            seenLoopIndices = true;
            break;
          }
        }
        if (!seenLoopIndices)
          return op->emitOpError()
                 << "the operand " << shardingIt.index()
                 << " has multiple loop indices in a dimension, but none of "
                    "them could be found in the exactly specified annotation "
                    "of op results or operands.";
      }
    }
  }

  // 3. Finalize sharding option
  if (!partialMeshAxes.empty()) {
    bool anyNonEmptyReductionLoop = llvm::any_of(
        llvm::enumerate(shardingOption.shardingArray), [&](auto it) {
          SmallVector<MeshAxis> &subArray = it.value();
          int64_t idx = it.index();
          return isReductionLoop(loopTypes[idx]) && !subArray.empty();
        });
    if (!anyNonEmptyReductionLoop) {
      bool filled = false;
      for (size_t idx = 0; idx < loopTypes.size(); ++idx) {
        if (isReductionLoop(loopTypes[idx])) {
          std::ignore = fillShardingOption(op, shardingOption, nullptr,
                                           partialMeshAxes, idx);
          filled = true;
          break;
        }
      }
      if (!filled)
        return op->emitOpError() << "no matched reduction loop found for the "
                                    "result's partial type";
    }
  }
  removeTrailingEmptySubArray(shardingOption.shardingArray);
  if (!anyShardingInResultsOrOperands)
    shardingOption.empty = true;
  return shardingOption;
}

// Get the sharding for the given result under the given sharding option.
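// The result's split axes come from the sharding-array entries of the loop
// iterators that index it; sharded reduction iterators contribute the partial
// axes.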
MeshSharding getSharding(OpResult result, const ShardingOption &shardingOption,
                         AffineMap map, ArrayRef<utils::IteratorType> loopTypes,
                         ArrayRef<ReductionKind> reductionLoopKinds) {
  auto resultType = cast<RankedTensorType>(result.getType());
  SmallVector<SmallVector<MeshAxis>> splitAxes(resultType.getRank());
  SmallVector<MeshAxis> partialAxes;

  // process the split axes
  for (auto it : llvm::enumerate(map.getResults())) {
    AffineExpr expr = it.value();
    // `expr` must be an `AffineDimExpr` because `map` is verified by
    // isProjectedPermutation
    auto dim = cast<AffineDimExpr>(expr);
    unsigned loopIdx = dim.getPosition();
    if (loopIdx < shardingOption.shardingArray.size())
      splitAxes[it.index()].append(shardingOption.shardingArray[loopIdx]);
  }

  // process the partial axes
  // partialType will be ignored if partialAxes is empty
  ReductionKind partialType = ReductionKind::Sum;
  size_t reductionLoopKindsIdx = 0;
  for (auto it : llvm::zip(loopTypes, shardingOption.shardingArray)) {
    utils::IteratorType iType = std::get<0>(it);
    if (isReductionLoop(iType)) {
      ReductionKind curPartialType = reductionLoopKinds[reductionLoopKindsIdx];
      ++reductionLoopKindsIdx;
      if (!partialAxes.empty())
        assert(partialType == curPartialType &&
               "Only one reduction type is supported");
      partialType = curPartialType;
      const SmallVector<MeshAxis> &axis = std::get<1>(it);
      partialAxes.append(axis);
    }
  }

  removeTrailingEmptySubArray(splitAxes);
  return MeshSharding::get(shardingOption.mesh,
                           fromArrayOfVector(result.getContext(), splitAxes),
                           partialAxes, partialType);
}

static FailureOr<MeshSharding> getSharding(OpOperand &opOperand,
                                           const ShardingOption &shardingOption,
                                           AffineMap map) {
  Value operandValue = opOperand.get();
  auto operandType = cast<RankedTensorType>(operandValue.getType());
  SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank());
  unsigned numDims = map.getNumDims();
  for (auto it : llvm::enumerate(map.getResults())) {
    int64_t idx = it.index();
    AffineExpr expr = it.value();
    FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
        checkOperandAffineExpr(expr, numDims);
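    // If the operand's affine expression is not of the supported form, no
    // sharding can be derived for this operand.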
    if (failed(loopIndices))
      return failure();
    SmallVector<unsigned> shardedLoopIndices;
    for (unsigned loopIdx : *loopIndices) {
      if ((size_t)loopIdx < shardingOption.shardingArray.size() &&
          !shardingOption.shardingArray[loopIdx].empty())
        shardedLoopIndices.push_back(loopIdx);
    }
    // at most one sharded loop index is accepted
    if (shardedLoopIndices.size() > 1)
      return failure();
    if (shardedLoopIndices.size() == 1) {
      splitAxes[idx].append(
          shardingOption.shardingArray[shardedLoopIndices[0]]);
    }
  }

  removeTrailingEmptySubArray(splitAxes);
  return MeshSharding::get(
      shardingOption.mesh,
      fromArrayOfVector(opOperand.get().getContext(), splitAxes));
}

FailureOr<std::vector<MeshSharding>>
mesh::detail::defaultGetShardingAnnotations(
    Operation *op, const ShardingOption &shardingOption) {
  std::vector<MeshSharding> res;

  ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
  SmallVector<utils::IteratorType> loopTypes =
      shardingOp.getLoopIteratorTypes();
  SmallVector<ReductionKind> reductionKinds =
      shardingOp.getReductionLoopIteratorKinds();
  SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
  unsigned numOperands = op->getNumOperands();

  for (OpOperand &opOperand : op->getOpOperands()) {
    FailureOr<MeshSharding> shardingAttr = getSharding(
        opOperand, shardingOption, maps[opOperand.getOperandNumber()]);
    if (failed(shardingAttr))
      return failure();
    res.push_back(*shardingAttr);
  }

  for (OpResult result : op->getResults()) {
    res.push_back(getSharding(result, shardingOption,
                              maps[numOperands + result.getResultNumber()],
                              loopTypes, reductionKinds));
  }

  return res;
}

//===----------------------------------------------------------------------===//
// detail::defaultAddShardingAnnotations
//===----------------------------------------------------------------------===//

// Add a `mesh.shard` op for the given result, based on the details provided in
// `shardingOption`, `map`, and `loopTypes`.
static LogicalResult addShardOp(OpBuilder &b, OpResult result,
                                const ShardingOption &shardingOption,
                                AffineMap map,
                                ArrayRef<utils::IteratorType> loopTypes,
                                ArrayRef<ReductionKind> reductionLoopKinds) {
  MeshSharding sharding =
      getSharding(result, shardingOption, map, loopTypes, reductionLoopKinds);
  maybeInsertTargetShardingAnnotation(sharding, result, b);

  return success();
}

// Add a `mesh.shard` op for the given operand, based on the details provided
// in `shardingOption` and `map`.
static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
                                const ShardingOption &shardingOption,
                                AffineMap map) {
  FailureOr<MeshSharding> sharding =
      getSharding(opOperand, shardingOption, map);
  if (failed(sharding)) {
    return failure();
  }
  OpBuilder::InsertionGuard guard(b);
  maybeInsertSourceShardingAnnotation(sharding.value(), opOperand, b);

  return success();
}

LogicalResult mesh::detail::defaultAddShardingAnnotations(
    Operation *op, OpBuilder &b, const ShardingOption &shardingOption) {
  assert(!shardingOption.empty && shardingOption.mesh);

  ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
  SmallVector<utils::IteratorType> loopTypes =
      shardingOp.getLoopIteratorTypes();
  SmallVector<ReductionKind> reductionKinds =
      shardingOp.getReductionLoopIteratorKinds();
  SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
  unsigned numOperands = op->getNumOperands();

  // 1. add mesh.shard ops for all op results
  for (OpResult result : op->getResults()) {
    if (failed(addShardOp(b, result, shardingOption,
                          maps[numOperands + result.getResultNumber()],
                          loopTypes, reductionKinds)))
      return failure();
  }

  // 2. add mesh.shard ops for all operands
  for (OpOperand &opOperand : op->getOpOperands()) {
    if (failed(addShardOp(b, opOperand, shardingOption,
                          maps[opOperand.getOperandNumber()])))
      return failure();
  }

  return success();
}

#ifndef NDEBUG
static bool
isValueCompatibleWithFullReplicationSharding(Value value,
                                             MeshSharding sharding) {
  if (isa<RankedTensorType>(value.getType())) {
    return sharding && isFullReplication(sharding);
  }

  return !sharding;
}

template <typename ValueRange, typename MeshShardingRange>
static bool
areValuesCompatibleWithFullReplicationShardings(ValueRange &&values,
                                                MeshShardingRange &&shardings) {
  if (std::size(values) != std::size(shardings)) {
    return false;
  }
  return llvm::all_of(
      llvm::zip_equal(std::forward<ValueRange>(values),
                      std::forward<MeshShardingRange>(shardings)),
      [](auto valueAndSharding) {
        return isValueCompatibleWithFullReplicationSharding(
            std::get<0>(valueAndSharding), std::get<1>(valueAndSharding));
      });
}
#endif // NDEBUG

void mesh::spmdizeFullyReplicatedOperation(
    Operation &op, ArrayRef<Value> spmdizedOperands,
    ArrayRef<MeshSharding> operandShardings,
    ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
    SymbolTableCollection &symbolTable, OpBuilder &builder) {
  assert(spmdizedOperands.size() == operandShardings.size());
  assert(areValuesCompatibleWithFullReplicationShardings(op.getOperands(),
                                                         operandShardings));
  assert(areValuesCompatibleWithFullReplicationShardings(op.getResults(),
                                                         resultShardings));
  // `clone` will populate the mapping of old to new results.
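  // Fully replicated values keep their unsharded types, so plain cloning is
  // sufficient and no result types need to be adjusted.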
  builder.clone(op, spmdizationMap);
}

static void updateMeshAxisAssignmentForLoopIterators(
    ArrayRef<MeshAxis> meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr,
    SmallVector<std::optional<SmallVector<MeshAxis>>>
        &meshAxesAssignmentForLoopIterators) {
  AffineDimExpr affineDimExpr = cast<AffineDimExpr>(indexingExpr);
  unsigned loopIteratorIdx = affineDimExpr.getPosition();
  if (meshAxesAssignmentForLoopIterators[loopIteratorIdx]) {
    assert(llvm::equal(meshAxesAssignmentForTensorAxis,
                       *meshAxesAssignmentForLoopIterators[loopIteratorIdx]));
  } else {
    meshAxesAssignmentForLoopIterators[loopIteratorIdx] =
        llvm::to_vector(meshAxesAssignmentForTensorAxis);
  }
}

ShardingArray mesh::getMeshAxisAssignmentForLoopIterators(
    ArrayRef<MeshSharding> operandShardings,
    ArrayRef<MeshSharding> resultShardings,
    ArrayRef<utils::IteratorType> loopIteratorTypes,
    ArrayRef<AffineMap> indexingMaps) {
  SmallVector<std::optional<SmallVector<MeshAxis>>>
      meshAxisAssignmentForLoopIterators(loopIteratorTypes.size());
  std::vector<MeshSharding> operatorAndResultShardings;
  operatorAndResultShardings.reserve(operandShardings.size() +
                                     resultShardings.size());
  llvm::append_range(operatorAndResultShardings, operandShardings);
  for (auto [sharding, affineMap] :
       llvm::zip_equal(operatorAndResultShardings, indexingMaps)) {
    if (!sharding) {
      continue;
    }
    for (auto [meshAxesAssignmentForTensorAxis, indexingExpr] :
         llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) {
      updateMeshAxisAssignmentForLoopIterators(
          meshAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr,
          meshAxisAssignmentForLoopIterators);
    }
    // Missing trailing split axes means replication on those tensor dimensions.
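    // Record an empty mesh-axis assignment for the loop iterators indexing
    // those trailing dimensions.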
    for (unsigned i = sharding.getSplitAxes().size();
         i < affineMap.getNumResults(); ++i) {
      updateMeshAxisAssignmentForLoopIterators(
          {}, affineMap.getResults()[i], meshAxisAssignmentForLoopIterators);
    }
  }

  ShardingArray res;
  llvm::transform(meshAxisAssignmentForLoopIterators, std::back_inserter(res),
                  [](std::optional<SmallVector<MeshAxis>> &axes) {
                    if (!axes) {
                      return SmallVector<MeshAxis>();
                    }
                    return std::move(*axes);
                  });
  return res;
}

bool mesh::isAtLeastOneReductionIteratorSharded(
    ArrayRef<utils::IteratorType> loopIteratorTypes,
    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
  for (auto [loopIteratorType, meshAxisAssignment] :
       llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
    if (loopIteratorType == utils::IteratorType::reduction &&
        !meshAxisAssignment.empty()) {
      return true;
    }
  }
  return false;
}

SmallVector<MeshAxis> mesh::getReductionMeshAxes(
    ArrayRef<utils::IteratorType> loopIteratorTypes,
    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
  SmallVector<MeshAxis> meshAxes;
  for (auto [loopIteratorType, meshAxisAssignment] :
       llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
    if (loopIteratorType == utils::IteratorType::reduction) {
      llvm::append_range(meshAxes, meshAxisAssignment);
    }
  }
  return meshAxes;
}

void mesh::spmdizeTriviallyShardableOperation(
    Operation &op, ArrayRef<Value> spmdizedOperands,
    ArrayRef<MeshSharding> operandShardings,
    ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
    SymbolTableCollection &symbolTable, OpBuilder &builder) {
  // `clone` will populate the mapping of old to new results.
  Operation *newOp = builder.clone(op, spmdizationMap);
  // Set the result types to the sharded counterparts.
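  // Each new result takes the shape of its local shard on the mesh referenced
  // by its sharding.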
  for (auto [oldResult, newResult, sharding] :
       llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) {
    newResult.setType(
        shardType(newResult.getType(),
                  getMesh(&op, sharding.getMeshAttr(), symbolTable), sharding));
  }
}