1 //===- MeshShardingInterfaceImpl.cpp --------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h" 10 11 #include "mlir/Analysis/SliceAnalysis.h" 12 #include "mlir/Dialect/Affine/IR/AffineOps.h" 13 #include "mlir/Dialect/Arith/IR/Arith.h" 14 #include "mlir/Dialect/Linalg/IR/Linalg.h" 15 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" 16 #include "mlir/Dialect/Mesh/IR/MeshOps.h" 17 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" 18 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h" 19 #include "mlir/Dialect/Mesh/Transforms/Transforms.h" 20 #include "mlir/Dialect/SCF/IR/SCF.h" 21 #include "mlir/Dialect/Tensor/IR/Tensor.h" 22 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 23 #include "mlir/IR/AffineExpr.h" 24 #include "mlir/IR/DialectRegistry.h" 25 #include "mlir/IR/IRMapping.h" 26 #include "mlir/IR/ImplicitLocOpBuilder.h" 27 #include "mlir/IR/MLIRContext.h" 28 #include "mlir/IR/OpDefinition.h" 29 #include "mlir/IR/Operation.h" 30 #include "mlir/IR/SymbolTable.h" 31 #include "mlir/IR/Value.h" 32 #include "mlir/Interfaces/TilingInterface.h" 33 #include "llvm/ADT/ArrayRef.h" 34 #include "llvm/ADT/STLExtras.h" 35 #include "llvm/ADT/SmallVector.h" 36 #include "llvm/ADT/TypeSwitch.h" 37 #include <iterator> 38 #include <numeric> 39 #include <optional> 40 #include <utility> 41 42 namespace mlir::linalg { 43 44 using MeshAxis = mesh::MeshAxis; 45 using ReductionKind = mesh::ReductionKind; 46 using MeshSharding = mesh::MeshSharding; 47 using ShardingArray = mesh::ShardingArray; 48 using MeshOp = mesh::MeshOp; 49 50 // Returns the corresponding mesh reduction kind for the given arith op. 51 static ReductionKind getReductionKind(Operation *op) { 52 return llvm::TypeSwitch<Operation *, ReductionKind>(op) 53 // Floating-point operations. 54 .Case([](arith::AddFOp op) { return ReductionKind::Sum; }) 55 .Case([](arith::MulFOp op) { return ReductionKind::Product; }) 56 // TODO: handle maxnumf and minnumf. 57 .Case([](arith::MaximumFOp op) { return ReductionKind::Max; }) 58 .Case([](arith::MinimumFOp op) { return ReductionKind::Min; }) 59 // Integer operations. 60 .Case([](arith::AddIOp op) { return ReductionKind::Sum; }) 61 .Case([](arith::OrIOp op) { return ReductionKind::BitwiseOr; }) 62 .Case([](arith::XOrIOp op) { return ReductionKind::BitwiseXor; }) 63 .Case([](arith::AndIOp op) { return ReductionKind::Sum; }) 64 // TODO: handle signless, signed and unsigned types properly. 65 // It is assumed that the element type of the collective operands and 66 // result drive the meaning of the reduction kind, whether it is signed 67 // or unsigned. 68 // The reduction op inside the linalg op may have different result type 69 // from the element type of the linalg op's result. 70 // Also signed and unsigned Arith dialect ops may accept signed, unsigned 71 // or signless operands. 72 // Maybe expand the reduction kinds. 73 .Case([](arith::MaxUIOp op) { return ReductionKind::Max; }) 74 .Case([](arith::MinUIOp op) { return ReductionKind::Min; }) 75 .Case([](arith::MaxSIOp op) { return ReductionKind::Max; }) 76 .Case([](arith::MinSIOp op) { return ReductionKind::Min; }) 77 .Case([](arith::MulIOp op) { return ReductionKind::Product; }) 78 .Default([](Operation *op) { return ReductionKind::Generic; }); 79 } 80 81 static std::optional<Operation *> getCombinerOp(LinalgOp op) { 82 SmallVector<Operation *> combinerOps; 83 Value reducedValue = matchReduction(op.getRegionOutputArgs(), 0, combinerOps); 84 if (!reducedValue || combinerOps.size() != 1) { 85 return std::nullopt; 86 } 87 88 return combinerOps[0]; 89 } 90 91 static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) { 92 std::optional<Operation *> reductionOp = getCombinerOp(op); 93 if (!reductionOp) { 94 return ReductionKind::Generic; 95 } 96 [[maybe_unused]] Type resultElementType = 97 llvm::cast<RankedTensorType>(op->getResult(0).getType()).getElementType(); 98 // TODO: handle case when result type of the reduction op does not match the 99 // element type of the result tensor. 100 // Would it makes sense at all? 101 assert(resultElementType == reductionOp.value()->getResult(0).getType()); 102 return getReductionKind(reductionOp.value()); 103 } 104 105 static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings, 106 ArrayRef<MeshSharding> resultShardings, 107 SymbolTableCollection &symbolTable) { 108 for (const MeshSharding &sharding : operandShardings) { 109 if (sharding) { 110 return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable); 111 } 112 } 113 114 for (const MeshSharding &sharding : resultShardings) { 115 if (sharding) { 116 return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable); 117 } 118 } 119 120 assert(false); 121 return nullptr; 122 } 123 124 // Choose the operand based on the current process index along the reduction 125 // mesh axes. 126 // We need to use the initial value only once to avoid including it in the 127 // reduction multiple times. 128 // In each process group only the leading process with linear index 0 would use 129 // the original operand. 130 // The other processes would use the reduction operation neutral tensor. 131 static Value createDestinationPassingStyleInitOperand( 132 LinalgOp op, int operandNumber, Value spmdizedOperand, 133 ArrayRef<MeshAxis> reductionMeshAxes, MeshOp meshOp, 134 ImplicitLocOpBuilder &builder) { 135 Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex( 136 meshOp.getSymName(), reductionMeshAxes, builder); 137 Value zero = builder.create<arith::ConstantIndexOp>(0); 138 Value isLeadProcess = builder.create<arith::CmpIOp>( 139 builder.getI1Type(), arith::CmpIPredicate::eq, 140 processLinearIndexInReductionGroup, zero); 141 scf::IfOp ifOp = builder.create<scf::IfOp>(spmdizedOperand.getType(), 142 isLeadProcess, true, true); 143 // Then block. 144 { 145 OpBuilder::InsertionGuard insertionGuard(builder); 146 builder.setInsertionPointToEnd(&ifOp.getThenRegion().front()); 147 builder.create<scf::YieldOp>(spmdizedOperand); 148 } 149 150 // Else block. 151 { 152 OpBuilder::InsertionGuard insertionGuard(builder); 153 builder.setInsertionPointToEnd(&ifOp.getElseRegion().front()); 154 SmallVector<OpFoldResult> shape = 155 tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand); 156 157 SmallVector<Operation *> combinerOps; 158 matchReduction(op.getRegionOutputArgs(), operandNumber, combinerOps); 159 assert(combinerOps.size() == 1); 160 std::optional<TypedAttr> neutralEl = 161 arith::getNeutralElement(combinerOps[0]); 162 163 Value init = builder.create<tensor::EmptyOp>(op.getLoc(), shape, 164 neutralEl.value().getType()); 165 Value constant = 166 builder.create<arith::ConstantOp>(op.getLoc(), neutralEl.value()); 167 Value fill = builder.create<linalg::FillOp>(op.getLoc(), constant, init) 168 .getResult(0); 169 170 builder.create<scf::YieldOp>(fill); 171 } 172 return ifOp.getResult(0); 173 } 174 175 // Create the DPS init operands for the spmdized Linalg op. 176 // Return all the new spmdized operands. 177 static SmallVector<Value> createDestinationPassingStyleInitOperands( 178 LinalgOp op, MeshOp meshOp, ArrayRef<Value> spmdizedOperands, 179 ArrayRef<MeshAxis> reductionMeshAxes, IRMapping &spmdizationMap, 180 ImplicitLocOpBuilder &builder) { 181 // TODO: add support for multiple destination passing style initial value 182 // operands. 183 assert(op.getNumDpsInits() == 1 && "Multiple initial values not supported."); 184 SmallVector<Value> newOperands = llvm::to_vector(spmdizedOperands); 185 auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber(); 186 Value spmdizedInitOperand = 187 spmdizationMap.lookup(op->getOperands()[operandIdx]); 188 newOperands[operandIdx] = createDestinationPassingStyleInitOperand( 189 op, 0, spmdizedInitOperand, reductionMeshAxes, meshOp, builder); 190 return newOperands; 191 } 192 193 static void createAllReduceForResultWithoutPartialSharding( 194 Value unshardedLinalgOpResult, ArrayRef<MeshAxis> opReductionMeshAxes, 195 MeshSharding resultSharding, ReductionKind reductionKind, 196 IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder) { 197 SmallVector<MeshAxis> allReduceMeshAxes; 198 llvm::copy_if(opReductionMeshAxes, std::back_inserter(allReduceMeshAxes), 199 [&resultSharding](MeshAxis axis) { 200 return !llvm::is_contained(resultSharding.getPartialAxes(), 201 axis); 202 }); 203 if (allReduceMeshAxes.empty()) { 204 return; 205 } 206 207 Value spmdizedLinalgOpResult = spmdizationMap.lookup(unshardedLinalgOpResult); 208 Value reducedValue = builder.create<mesh::AllReduceOp>( 209 spmdizedLinalgOpResult, resultSharding.getMesh(), allReduceMeshAxes, 210 reductionKind); 211 spmdizationMap.map(unshardedLinalgOpResult, reducedValue); 212 } 213 214 static void createAllReduceForResultsWithoutPartialShardings( 215 LinalgOp unshardedOp, ArrayRef<MeshAxis> opReductionMeshAxes, 216 ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap, 217 ImplicitLocOpBuilder &builder) { 218 ReductionKind reductionKind = getReductionKindOfLinalgOp(unshardedOp); 219 for (auto [unshardedLinalgOpResult, resultSharding] : 220 llvm::zip_equal(unshardedOp->getResults(), resultShardings)) { 221 createAllReduceForResultWithoutPartialSharding( 222 unshardedLinalgOpResult, opReductionMeshAxes, resultSharding, 223 reductionKind, spmdizationMap, builder); 224 } 225 } 226 227 static void spmdizeLinalgOpWithShardedReduction( 228 LinalgOp op, ArrayRef<Value> spmdizedOperands, 229 ArrayRef<MeshSharding> operandShardings, 230 ArrayRef<MeshSharding> resultShardings, 231 ArrayRef<utils::IteratorType> loopIteratorTypes, 232 ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators, 233 IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, 234 ImplicitLocOpBuilder &builder) { 235 MeshOp mesh = getMesh(op, operandShardings, resultShardings, symbolTable); 236 SmallVector<MeshAxis> reductionMeshAxes = mesh::getReductionMeshAxes( 237 loopIteratorTypes, meshAxisAssignmentForLoopIterators); 238 SmallVector<Value> spmdizedLinalgOpOperands = 239 createDestinationPassingStyleInitOperands(op, mesh, spmdizedOperands, 240 reductionMeshAxes, 241 spmdizationMap, builder); 242 // We must not change the operand mappings of the original spmdizationMap as 243 // they are the mappings for the whole spmdization blob and may be used by 244 // others. 245 IRMapping internalSpmdizationMap; 246 for (auto [unshardedOperand, spmdizedOperand] : 247 llvm::zip_equal(op->getOperands(), spmdizedLinalgOpOperands)) { 248 internalSpmdizationMap.map(unshardedOperand, spmdizedOperand); 249 } 250 spmdizeTriviallyShardableOperation( 251 *op, spmdizedLinalgOpOperands, operandShardings, resultShardings, 252 internalSpmdizationMap, symbolTable, builder); 253 for (Value result : op->getResults()) { 254 spmdizationMap.map(result, internalSpmdizationMap.lookup(result)); 255 } 256 257 // Handle partial shardings. 258 createAllReduceForResultsWithoutPartialShardings( 259 op, reductionMeshAxes, resultShardings, spmdizationMap, builder); 260 } 261 262 namespace { 263 264 // ShardingInterface for ops that implement LinalgStructuredInterface. 265 // The supported ops are only those where the indexing maps are projected 266 // permutations. 267 template <typename Op> 268 struct StructuredOpShardingInterface 269 : public mesh::ShardingInterface::ExternalModel< 270 StructuredOpShardingInterface<Op>, Op> { 271 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const { 272 return llvm::cast<LinalgOp>(op).getIteratorTypesArray(); 273 } 274 275 SmallVector<AffineMap> getIndexingMaps(Operation *op) const { 276 LinalgOp linalgOp = llvm::cast<LinalgOp>(op); 277 SmallVector<AffineMap> res = linalgOp.getIndexingMapsArray(); 278 279 // Results must have the same indexing as destination passing style initial 280 // operands. 281 for (int64_t i = 0; i < linalgOp.getNumDpsInits(); ++i) { 282 res.push_back(res[linalgOp.getDpsInitOperand(i)->getOperandNumber()]); 283 } 284 285 return res; 286 } 287 288 SmallVector<ReductionKind> 289 getReductionLoopIteratorKinds(Operation *op) const { 290 LinalgOp linalgOp = llvm::cast<LinalgOp>(op); 291 SmallVector<utils::IteratorType> iteratorTypes = 292 linalgOp.getIteratorTypesArray(); 293 unsigned reductionItersCount = std::accumulate( 294 iteratorTypes.begin(), iteratorTypes.end(), 0, 295 [](unsigned count, utils::IteratorType iter) { 296 return count + (iter == utils::IteratorType::reduction); 297 }); 298 mesh::ReductionKind reductionKind = getReductionKindOfLinalgOp(linalgOp); 299 return SmallVector<ReductionKind>(reductionItersCount, reductionKind); 300 } 301 302 LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands, 303 ArrayRef<MeshSharding> operandShardings, 304 ArrayRef<MeshSharding> resultShardings, 305 IRMapping &spmdizationMap, 306 SymbolTableCollection &symbolTable, 307 OpBuilder &builder) const { 308 LinalgOp linalgOp = llvm::cast<LinalgOp>(op); 309 310 SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray(); 311 bool allIndexingMapsAreProjectedPermutation = 312 llvm::all_of(indexingMaps, [](AffineMap map) { 313 return map.isProjectedPermutation(); 314 }); 315 if (!allIndexingMapsAreProjectedPermutation) { 316 // TODO: handle non-projected permutations. 317 return op->emitOpError() 318 << "supports indexing maps that are only projected permutation."; 319 } 320 321 SmallVector<utils::IteratorType> loopIteratorTypes = 322 linalgOp.getIteratorTypesArray(); 323 ShardingArray meshAxisAssignmentForLoopIterators = 324 getMeshAxisAssignmentForLoopIterators(operandShardings, resultShardings, 325 loopIteratorTypes, indexingMaps); 326 if (mesh::isAtLeastOneReductionIteratorSharded( 327 loopIteratorTypes, meshAxisAssignmentForLoopIterators)) { 328 ImplicitLocOpBuilder implicitLocBuilder(op->getLoc(), builder); 329 spmdizeLinalgOpWithShardedReduction( 330 linalgOp, spmdizedOperands, operandShardings, resultShardings, 331 loopIteratorTypes, meshAxisAssignmentForLoopIterators, spmdizationMap, 332 symbolTable, implicitLocBuilder); 333 } else { 334 spmdizeTriviallyShardableOperation(*op, spmdizedOperands, 335 operandShardings, resultShardings, 336 spmdizationMap, symbolTable, builder); 337 } 338 339 return success(); 340 } 341 }; 342 343 } // namespace 344 345 template <typename OpType> 346 static void registerOne(MLIRContext *ctx) { 347 OpType::template attachInterface<StructuredOpShardingInterface<OpType>>(*ctx); 348 } 349 350 /// Variadic helper function. 351 template <typename... OpTypes> 352 static void registerAll(MLIRContext *ctx) { 353 (registerOne<OpTypes>(ctx), ...); 354 } 355 356 void registerMeshShardingInterfaceExternalModels(DialectRegistry ®istry) { 357 registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *dialect) { 358 DialectRegistry registry; 359 registry.insert<affine::AffineDialect, arith::ArithDialect, scf::SCFDialect, 360 tensor::TensorDialect>(); 361 ctx->appendDialectRegistry(registry); 362 for (StringRef name : registry.getDialectNames()) 363 ctx->getOrLoadDialect(name); 364 365 registerOne<linalg::GenericOp>(ctx); 366 registerAll< 367 #define GET_OP_LIST 368 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" 369 >(ctx); 370 }); 371 } 372 373 } // namespace mlir::linalg 374