1fb582b6aSBoian Petkantchin //===- MeshShardingInterfaceImpl.cpp --------------------------------------===// 2fb582b6aSBoian Petkantchin // 3fb582b6aSBoian Petkantchin // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4fb582b6aSBoian Petkantchin // See https://llvm.org/LICENSE.txt for license information. 5fb582b6aSBoian Petkantchin // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6fb582b6aSBoian Petkantchin // 7fb582b6aSBoian Petkantchin //===----------------------------------------------------------------------===// 8fb582b6aSBoian Petkantchin 9fb582b6aSBoian Petkantchin #include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h" 10fb582b6aSBoian Petkantchin 11fb582b6aSBoian Petkantchin #include "mlir/Analysis/SliceAnalysis.h" 12fb582b6aSBoian Petkantchin #include "mlir/Dialect/Affine/IR/AffineOps.h" 13fb582b6aSBoian Petkantchin #include "mlir/Dialect/Arith/IR/Arith.h" 14fb582b6aSBoian Petkantchin #include "mlir/Dialect/Linalg/IR/Linalg.h" 15fb582b6aSBoian Petkantchin #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" 16fb582b6aSBoian Petkantchin #include "mlir/Dialect/Mesh/IR/MeshOps.h" 17fb582b6aSBoian Petkantchin #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" 18fb582b6aSBoian Petkantchin #include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h" 19fb582b6aSBoian Petkantchin #include "mlir/Dialect/Mesh/Transforms/Transforms.h" 20fb582b6aSBoian Petkantchin #include "mlir/Dialect/SCF/IR/SCF.h" 21fb582b6aSBoian Petkantchin #include "mlir/Dialect/Tensor/IR/Tensor.h" 22fb582b6aSBoian Petkantchin #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 23fb582b6aSBoian Petkantchin #include "mlir/IR/AffineExpr.h" 24fb582b6aSBoian Petkantchin #include "mlir/IR/DialectRegistry.h" 25fb582b6aSBoian Petkantchin #include "mlir/IR/IRMapping.h" 26fb582b6aSBoian Petkantchin #include "mlir/IR/ImplicitLocOpBuilder.h" 27fb582b6aSBoian Petkantchin #include "mlir/IR/MLIRContext.h" 28fb582b6aSBoian Petkantchin #include "mlir/IR/OpDefinition.h" 29fb582b6aSBoian Petkantchin #include "mlir/IR/Operation.h" 30fb582b6aSBoian Petkantchin #include "mlir/IR/SymbolTable.h" 31fb582b6aSBoian Petkantchin #include "mlir/IR/Value.h" 32fb582b6aSBoian Petkantchin #include "mlir/Interfaces/TilingInterface.h" 33fb582b6aSBoian Petkantchin #include "llvm/ADT/ArrayRef.h" 34fb582b6aSBoian Petkantchin #include "llvm/ADT/STLExtras.h" 35fb582b6aSBoian Petkantchin #include "llvm/ADT/SmallVector.h" 36fb582b6aSBoian Petkantchin #include "llvm/ADT/TypeSwitch.h" 37fb582b6aSBoian Petkantchin #include <iterator> 38d635b860SBoian Petkantchin #include <numeric> 39fb582b6aSBoian Petkantchin #include <optional> 40fb582b6aSBoian Petkantchin #include <utility> 41fb582b6aSBoian Petkantchin 42fb582b6aSBoian Petkantchin namespace mlir::linalg { 43fb582b6aSBoian Petkantchin 44fb582b6aSBoian Petkantchin using MeshAxis = mesh::MeshAxis; 45fb582b6aSBoian Petkantchin using ReductionKind = mesh::ReductionKind; 46baabcb28SFrank Schlimbach using MeshSharding = mesh::MeshSharding; 47fb582b6aSBoian Petkantchin using ShardingArray = mesh::ShardingArray; 48fb582b6aSBoian Petkantchin using MeshOp = mesh::MeshOp; 49fb582b6aSBoian Petkantchin 50fb582b6aSBoian Petkantchin // Returns the corresponding mesh reduction kind for the given arith op. 51fb582b6aSBoian Petkantchin static ReductionKind getReductionKind(Operation *op) { 52fb582b6aSBoian Petkantchin return llvm::TypeSwitch<Operation *, ReductionKind>(op) 53fb582b6aSBoian Petkantchin // Floating-point operations. 54fb582b6aSBoian Petkantchin .Case([](arith::AddFOp op) { return ReductionKind::Sum; }) 55fb582b6aSBoian Petkantchin .Case([](arith::MulFOp op) { return ReductionKind::Product; }) 56fb582b6aSBoian Petkantchin // TODO: handle maxnumf and minnumf. 57fb582b6aSBoian Petkantchin .Case([](arith::MaximumFOp op) { return ReductionKind::Max; }) 58fb582b6aSBoian Petkantchin .Case([](arith::MinimumFOp op) { return ReductionKind::Min; }) 59fb582b6aSBoian Petkantchin // Integer operations. 60fb582b6aSBoian Petkantchin .Case([](arith::AddIOp op) { return ReductionKind::Sum; }) 61fb582b6aSBoian Petkantchin .Case([](arith::OrIOp op) { return ReductionKind::BitwiseOr; }) 62fb582b6aSBoian Petkantchin .Case([](arith::XOrIOp op) { return ReductionKind::BitwiseXor; }) 63fb582b6aSBoian Petkantchin .Case([](arith::AndIOp op) { return ReductionKind::Sum; }) 64fb582b6aSBoian Petkantchin // TODO: handle signless, signed and unsigned types properly. 65fb582b6aSBoian Petkantchin // It is assumed that the element type of the collective operands and 66fb582b6aSBoian Petkantchin // result drive the meaning of the reduction kind, whether it is signed 67fb582b6aSBoian Petkantchin // or unsigned. 68fb582b6aSBoian Petkantchin // The reduction op inside the linalg op may have different result type 69fb582b6aSBoian Petkantchin // from the element type of the linalg op's result. 70fb582b6aSBoian Petkantchin // Also signed and unsigned Arith dialect ops may accept signed, unsigned 71fb582b6aSBoian Petkantchin // or signless operands. 72fb582b6aSBoian Petkantchin // Maybe expand the reduction kinds. 73fb582b6aSBoian Petkantchin .Case([](arith::MaxUIOp op) { return ReductionKind::Max; }) 74fb582b6aSBoian Petkantchin .Case([](arith::MinUIOp op) { return ReductionKind::Min; }) 75fb582b6aSBoian Petkantchin .Case([](arith::MaxSIOp op) { return ReductionKind::Max; }) 76fb582b6aSBoian Petkantchin .Case([](arith::MinSIOp op) { return ReductionKind::Min; }) 77fb582b6aSBoian Petkantchin .Case([](arith::MulIOp op) { return ReductionKind::Product; }) 78fb582b6aSBoian Petkantchin .Default([](Operation *op) { return ReductionKind::Generic; }); 79fb582b6aSBoian Petkantchin } 80fb582b6aSBoian Petkantchin 81fb582b6aSBoian Petkantchin static std::optional<Operation *> getCombinerOp(LinalgOp op) { 82fb582b6aSBoian Petkantchin SmallVector<Operation *> combinerOps; 83fb582b6aSBoian Petkantchin Value reducedValue = matchReduction(op.getRegionOutputArgs(), 0, combinerOps); 84fb582b6aSBoian Petkantchin if (!reducedValue || combinerOps.size() != 1) { 85fb582b6aSBoian Petkantchin return std::nullopt; 86fb582b6aSBoian Petkantchin } 87fb582b6aSBoian Petkantchin 88fb582b6aSBoian Petkantchin return combinerOps[0]; 89fb582b6aSBoian Petkantchin } 90fb582b6aSBoian Petkantchin 91fb582b6aSBoian Petkantchin static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) { 92fb582b6aSBoian Petkantchin std::optional<Operation *> reductionOp = getCombinerOp(op); 93fb582b6aSBoian Petkantchin if (!reductionOp) { 94fb582b6aSBoian Petkantchin return ReductionKind::Generic; 95fb582b6aSBoian Petkantchin } 96474a73d9SJie Fu [[maybe_unused]] Type resultElementType = 97fb582b6aSBoian Petkantchin llvm::cast<RankedTensorType>(op->getResult(0).getType()).getElementType(); 98fb582b6aSBoian Petkantchin // TODO: handle case when result type of the reduction op does not match the 99fb582b6aSBoian Petkantchin // element type of the result tensor. 100fb582b6aSBoian Petkantchin // Would it makes sense at all? 101fb582b6aSBoian Petkantchin assert(resultElementType == reductionOp.value()->getResult(0).getType()); 102fb582b6aSBoian Petkantchin return getReductionKind(reductionOp.value()); 103fb582b6aSBoian Petkantchin } 104fb582b6aSBoian Petkantchin 105baabcb28SFrank Schlimbach static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings, 106baabcb28SFrank Schlimbach ArrayRef<MeshSharding> resultShardings, 107fb582b6aSBoian Petkantchin SymbolTableCollection &symbolTable) { 108b7981a78SAdrian Kuegel for (const MeshSharding &sharding : operandShardings) { 109fb582b6aSBoian Petkantchin if (sharding) { 110baabcb28SFrank Schlimbach return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable); 111fb582b6aSBoian Petkantchin } 112fb582b6aSBoian Petkantchin } 113fb582b6aSBoian Petkantchin 114b7981a78SAdrian Kuegel for (const MeshSharding &sharding : resultShardings) { 115fb582b6aSBoian Petkantchin if (sharding) { 116baabcb28SFrank Schlimbach return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable); 117fb582b6aSBoian Petkantchin } 118fb582b6aSBoian Petkantchin } 119fb582b6aSBoian Petkantchin 120fb582b6aSBoian Petkantchin assert(false); 121474a73d9SJie Fu return nullptr; 122fb582b6aSBoian Petkantchin } 123fb582b6aSBoian Petkantchin 124fb582b6aSBoian Petkantchin // Choose the operand based on the current process index along the reduction 125fb582b6aSBoian Petkantchin // mesh axes. 126fb582b6aSBoian Petkantchin // We need to use the initial value only once to avoid including it in the 127fb582b6aSBoian Petkantchin // reduction multiple times. 128fb582b6aSBoian Petkantchin // In each process group only the leading process with linear index 0 would use 129fb582b6aSBoian Petkantchin // the original operand. 130fb582b6aSBoian Petkantchin // The other processes would use the reduction operation neutral tensor. 131fb582b6aSBoian Petkantchin static Value createDestinationPassingStyleInitOperand( 132*91bbebc7SKunwar Grover LinalgOp op, int operandNumber, Value spmdizedOperand, 133*91bbebc7SKunwar Grover ArrayRef<MeshAxis> reductionMeshAxes, MeshOp meshOp, 134*91bbebc7SKunwar Grover ImplicitLocOpBuilder &builder) { 135fb582b6aSBoian Petkantchin Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex( 136fb582b6aSBoian Petkantchin meshOp.getSymName(), reductionMeshAxes, builder); 137fb582b6aSBoian Petkantchin Value zero = builder.create<arith::ConstantIndexOp>(0); 138fb582b6aSBoian Petkantchin Value isLeadProcess = builder.create<arith::CmpIOp>( 139fb582b6aSBoian Petkantchin builder.getI1Type(), arith::CmpIPredicate::eq, 140fb582b6aSBoian Petkantchin processLinearIndexInReductionGroup, zero); 141fb582b6aSBoian Petkantchin scf::IfOp ifOp = builder.create<scf::IfOp>(spmdizedOperand.getType(), 142fb582b6aSBoian Petkantchin isLeadProcess, true, true); 143fb582b6aSBoian Petkantchin // Then block. 144fb582b6aSBoian Petkantchin { 145fb582b6aSBoian Petkantchin OpBuilder::InsertionGuard insertionGuard(builder); 146fb582b6aSBoian Petkantchin builder.setInsertionPointToEnd(&ifOp.getThenRegion().front()); 147fb582b6aSBoian Petkantchin builder.create<scf::YieldOp>(spmdizedOperand); 148fb582b6aSBoian Petkantchin } 149fb582b6aSBoian Petkantchin 150fb582b6aSBoian Petkantchin // Else block. 151fb582b6aSBoian Petkantchin { 152fb582b6aSBoian Petkantchin OpBuilder::InsertionGuard insertionGuard(builder); 153fb582b6aSBoian Petkantchin builder.setInsertionPointToEnd(&ifOp.getElseRegion().front()); 154fb582b6aSBoian Petkantchin SmallVector<OpFoldResult> shape = 155fb582b6aSBoian Petkantchin tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand); 156*91bbebc7SKunwar Grover 157*91bbebc7SKunwar Grover SmallVector<Operation *> combinerOps; 158*91bbebc7SKunwar Grover matchReduction(op.getRegionOutputArgs(), operandNumber, combinerOps); 159*91bbebc7SKunwar Grover assert(combinerOps.size() == 1); 160*91bbebc7SKunwar Grover std::optional<TypedAttr> neutralEl = 161*91bbebc7SKunwar Grover arith::getNeutralElement(combinerOps[0]); 162*91bbebc7SKunwar Grover 163*91bbebc7SKunwar Grover Value init = builder.create<tensor::EmptyOp>(op.getLoc(), shape, 164*91bbebc7SKunwar Grover neutralEl.value().getType()); 165*91bbebc7SKunwar Grover Value constant = 166*91bbebc7SKunwar Grover builder.create<arith::ConstantOp>(op.getLoc(), neutralEl.value()); 167*91bbebc7SKunwar Grover Value fill = builder.create<linalg::FillOp>(op.getLoc(), constant, init) 168*91bbebc7SKunwar Grover .getResult(0); 169*91bbebc7SKunwar Grover 170*91bbebc7SKunwar Grover builder.create<scf::YieldOp>(fill); 171fb582b6aSBoian Petkantchin } 172fb582b6aSBoian Petkantchin return ifOp.getResult(0); 173fb582b6aSBoian Petkantchin } 174fb582b6aSBoian Petkantchin 175fb582b6aSBoian Petkantchin // Create the DPS init operands for the spmdized Linalg op. 176fb582b6aSBoian Petkantchin // Return all the new spmdized operands. 177fb582b6aSBoian Petkantchin static SmallVector<Value> createDestinationPassingStyleInitOperands( 178fb582b6aSBoian Petkantchin LinalgOp op, MeshOp meshOp, ArrayRef<Value> spmdizedOperands, 179fb582b6aSBoian Petkantchin ArrayRef<MeshAxis> reductionMeshAxes, IRMapping &spmdizationMap, 180fb582b6aSBoian Petkantchin ImplicitLocOpBuilder &builder) { 181fb582b6aSBoian Petkantchin // TODO: add support for multiple destination passing style initial value 182fb582b6aSBoian Petkantchin // operands. 1839329b20dSKunwar Grover assert(op.getNumDpsInits() == 1 && "Multiple initial values not supported."); 184fb582b6aSBoian Petkantchin SmallVector<Value> newOperands = llvm::to_vector(spmdizedOperands); 185fb582b6aSBoian Petkantchin auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber(); 186fb582b6aSBoian Petkantchin Value spmdizedInitOperand = 187fb582b6aSBoian Petkantchin spmdizationMap.lookup(op->getOperands()[operandIdx]); 188fb582b6aSBoian Petkantchin newOperands[operandIdx] = createDestinationPassingStyleInitOperand( 189*91bbebc7SKunwar Grover op, 0, spmdizedInitOperand, reductionMeshAxes, meshOp, builder); 190fb582b6aSBoian Petkantchin return newOperands; 191fb582b6aSBoian Petkantchin } 192fb582b6aSBoian Petkantchin 193fb582b6aSBoian Petkantchin static void createAllReduceForResultWithoutPartialSharding( 194fb582b6aSBoian Petkantchin Value unshardedLinalgOpResult, ArrayRef<MeshAxis> opReductionMeshAxes, 195baabcb28SFrank Schlimbach MeshSharding resultSharding, ReductionKind reductionKind, 196fb582b6aSBoian Petkantchin IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder) { 197fb582b6aSBoian Petkantchin SmallVector<MeshAxis> allReduceMeshAxes; 198fb582b6aSBoian Petkantchin llvm::copy_if(opReductionMeshAxes, std::back_inserter(allReduceMeshAxes), 199fb582b6aSBoian Petkantchin [&resultSharding](MeshAxis axis) { 200fb582b6aSBoian Petkantchin return !llvm::is_contained(resultSharding.getPartialAxes(), 201fb582b6aSBoian Petkantchin axis); 202fb582b6aSBoian Petkantchin }); 203fb582b6aSBoian Petkantchin if (allReduceMeshAxes.empty()) { 204fb582b6aSBoian Petkantchin return; 205fb582b6aSBoian Petkantchin } 206fb582b6aSBoian Petkantchin 207fb582b6aSBoian Petkantchin Value spmdizedLinalgOpResult = spmdizationMap.lookup(unshardedLinalgOpResult); 208fb582b6aSBoian Petkantchin Value reducedValue = builder.create<mesh::AllReduceOp>( 209baabcb28SFrank Schlimbach spmdizedLinalgOpResult, resultSharding.getMesh(), allReduceMeshAxes, 210baabcb28SFrank Schlimbach reductionKind); 211fb582b6aSBoian Petkantchin spmdizationMap.map(unshardedLinalgOpResult, reducedValue); 212fb582b6aSBoian Petkantchin } 213fb582b6aSBoian Petkantchin 214fb582b6aSBoian Petkantchin static void createAllReduceForResultsWithoutPartialShardings( 215fb582b6aSBoian Petkantchin LinalgOp unshardedOp, ArrayRef<MeshAxis> opReductionMeshAxes, 216baabcb28SFrank Schlimbach ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap, 217fb582b6aSBoian Petkantchin ImplicitLocOpBuilder &builder) { 218fb582b6aSBoian Petkantchin ReductionKind reductionKind = getReductionKindOfLinalgOp(unshardedOp); 219fb582b6aSBoian Petkantchin for (auto [unshardedLinalgOpResult, resultSharding] : 220fb582b6aSBoian Petkantchin llvm::zip_equal(unshardedOp->getResults(), resultShardings)) { 221fb582b6aSBoian Petkantchin createAllReduceForResultWithoutPartialSharding( 222fb582b6aSBoian Petkantchin unshardedLinalgOpResult, opReductionMeshAxes, resultSharding, 223fb582b6aSBoian Petkantchin reductionKind, spmdizationMap, builder); 224fb582b6aSBoian Petkantchin } 225fb582b6aSBoian Petkantchin } 226fb582b6aSBoian Petkantchin 227fb582b6aSBoian Petkantchin static void spmdizeLinalgOpWithShardedReduction( 228fb582b6aSBoian Petkantchin LinalgOp op, ArrayRef<Value> spmdizedOperands, 229baabcb28SFrank Schlimbach ArrayRef<MeshSharding> operandShardings, 230baabcb28SFrank Schlimbach ArrayRef<MeshSharding> resultShardings, 231fb582b6aSBoian Petkantchin ArrayRef<utils::IteratorType> loopIteratorTypes, 232fb582b6aSBoian Petkantchin ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators, 233fb582b6aSBoian Petkantchin IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, 234fb582b6aSBoian Petkantchin ImplicitLocOpBuilder &builder) { 235fb582b6aSBoian Petkantchin MeshOp mesh = getMesh(op, operandShardings, resultShardings, symbolTable); 236fb582b6aSBoian Petkantchin SmallVector<MeshAxis> reductionMeshAxes = mesh::getReductionMeshAxes( 237fb582b6aSBoian Petkantchin loopIteratorTypes, meshAxisAssignmentForLoopIterators); 238fb582b6aSBoian Petkantchin SmallVector<Value> spmdizedLinalgOpOperands = 239fb582b6aSBoian Petkantchin createDestinationPassingStyleInitOperands(op, mesh, spmdizedOperands, 240fb582b6aSBoian Petkantchin reductionMeshAxes, 241fb582b6aSBoian Petkantchin spmdizationMap, builder); 242fb582b6aSBoian Petkantchin // We must not change the operand mappings of the original spmdizationMap as 243fb582b6aSBoian Petkantchin // they are the mappings for the whole spmdization blob and may be used by 244fb582b6aSBoian Petkantchin // others. 245fb582b6aSBoian Petkantchin IRMapping internalSpmdizationMap; 246fb582b6aSBoian Petkantchin for (auto [unshardedOperand, spmdizedOperand] : 247fb582b6aSBoian Petkantchin llvm::zip_equal(op->getOperands(), spmdizedLinalgOpOperands)) { 248fb582b6aSBoian Petkantchin internalSpmdizationMap.map(unshardedOperand, spmdizedOperand); 249fb582b6aSBoian Petkantchin } 250fb582b6aSBoian Petkantchin spmdizeTriviallyShardableOperation( 251fb582b6aSBoian Petkantchin *op, spmdizedLinalgOpOperands, operandShardings, resultShardings, 252fb582b6aSBoian Petkantchin internalSpmdizationMap, symbolTable, builder); 253fb582b6aSBoian Petkantchin for (Value result : op->getResults()) { 254fb582b6aSBoian Petkantchin spmdizationMap.map(result, internalSpmdizationMap.lookup(result)); 255fb582b6aSBoian Petkantchin } 256fb582b6aSBoian Petkantchin 257fb582b6aSBoian Petkantchin // Handle partial shardings. 258fb582b6aSBoian Petkantchin createAllReduceForResultsWithoutPartialShardings( 259fb582b6aSBoian Petkantchin op, reductionMeshAxes, resultShardings, spmdizationMap, builder); 260fb582b6aSBoian Petkantchin } 261fb582b6aSBoian Petkantchin 262fb582b6aSBoian Petkantchin namespace { 263fb582b6aSBoian Petkantchin 264fb582b6aSBoian Petkantchin // ShardingInterface for ops that implement LinalgStructuredInterface. 265fb582b6aSBoian Petkantchin // The supported ops are only those where the indexing maps are projected 266fb582b6aSBoian Petkantchin // permutations. 267fb582b6aSBoian Petkantchin template <typename Op> 268fb582b6aSBoian Petkantchin struct StructuredOpShardingInterface 269fb582b6aSBoian Petkantchin : public mesh::ShardingInterface::ExternalModel< 270fb582b6aSBoian Petkantchin StructuredOpShardingInterface<Op>, Op> { 271fb582b6aSBoian Petkantchin SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const { 272fb582b6aSBoian Petkantchin return llvm::cast<LinalgOp>(op).getIteratorTypesArray(); 273fb582b6aSBoian Petkantchin } 274fb582b6aSBoian Petkantchin 275fb582b6aSBoian Petkantchin SmallVector<AffineMap> getIndexingMaps(Operation *op) const { 276fb582b6aSBoian Petkantchin LinalgOp linalgOp = llvm::cast<LinalgOp>(op); 277fb582b6aSBoian Petkantchin SmallVector<AffineMap> res = linalgOp.getIndexingMapsArray(); 278fb582b6aSBoian Petkantchin 279fb582b6aSBoian Petkantchin // Results must have the same indexing as destination passing style initial 280fb582b6aSBoian Petkantchin // operands. 281fb582b6aSBoian Petkantchin for (int64_t i = 0; i < linalgOp.getNumDpsInits(); ++i) { 282fb582b6aSBoian Petkantchin res.push_back(res[linalgOp.getDpsInitOperand(i)->getOperandNumber()]); 283fb582b6aSBoian Petkantchin } 284fb582b6aSBoian Petkantchin 285fb582b6aSBoian Petkantchin return res; 286fb582b6aSBoian Petkantchin } 287fb582b6aSBoian Petkantchin 288d635b860SBoian Petkantchin SmallVector<ReductionKind> 289d635b860SBoian Petkantchin getReductionLoopIteratorKinds(Operation *op) const { 290d635b860SBoian Petkantchin LinalgOp linalgOp = llvm::cast<LinalgOp>(op); 291d635b860SBoian Petkantchin SmallVector<utils::IteratorType> iteratorTypes = 292d635b860SBoian Petkantchin linalgOp.getIteratorTypesArray(); 293d635b860SBoian Petkantchin unsigned reductionItersCount = std::accumulate( 294d635b860SBoian Petkantchin iteratorTypes.begin(), iteratorTypes.end(), 0, 295d635b860SBoian Petkantchin [](unsigned count, utils::IteratorType iter) { 296d635b860SBoian Petkantchin return count + (iter == utils::IteratorType::reduction); 297d635b860SBoian Petkantchin }); 298d635b860SBoian Petkantchin mesh::ReductionKind reductionKind = getReductionKindOfLinalgOp(linalgOp); 299d635b860SBoian Petkantchin return SmallVector<ReductionKind>(reductionItersCount, reductionKind); 300d635b860SBoian Petkantchin } 301d635b860SBoian Petkantchin 302fb582b6aSBoian Petkantchin LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands, 303baabcb28SFrank Schlimbach ArrayRef<MeshSharding> operandShardings, 304baabcb28SFrank Schlimbach ArrayRef<MeshSharding> resultShardings, 305fb582b6aSBoian Petkantchin IRMapping &spmdizationMap, 306fb582b6aSBoian Petkantchin SymbolTableCollection &symbolTable, 307fb582b6aSBoian Petkantchin OpBuilder &builder) const { 308fb582b6aSBoian Petkantchin LinalgOp linalgOp = llvm::cast<LinalgOp>(op); 309fb582b6aSBoian Petkantchin 310fb582b6aSBoian Petkantchin SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray(); 311fb582b6aSBoian Petkantchin bool allIndexingMapsAreProjectedPermutation = 312fb582b6aSBoian Petkantchin llvm::all_of(indexingMaps, [](AffineMap map) { 313fb582b6aSBoian Petkantchin return map.isProjectedPermutation(); 314fb582b6aSBoian Petkantchin }); 315fb582b6aSBoian Petkantchin if (!allIndexingMapsAreProjectedPermutation) { 316fb582b6aSBoian Petkantchin // TODO: handle non-projected permutations. 317fb582b6aSBoian Petkantchin return op->emitOpError() 318fb582b6aSBoian Petkantchin << "supports indexing maps that are only projected permutation."; 319fb582b6aSBoian Petkantchin } 320fb582b6aSBoian Petkantchin 321fb582b6aSBoian Petkantchin SmallVector<utils::IteratorType> loopIteratorTypes = 322fb582b6aSBoian Petkantchin linalgOp.getIteratorTypesArray(); 323fb582b6aSBoian Petkantchin ShardingArray meshAxisAssignmentForLoopIterators = 324fb582b6aSBoian Petkantchin getMeshAxisAssignmentForLoopIterators(operandShardings, resultShardings, 325fb582b6aSBoian Petkantchin loopIteratorTypes, indexingMaps); 326fb582b6aSBoian Petkantchin if (mesh::isAtLeastOneReductionIteratorSharded( 327fb582b6aSBoian Petkantchin loopIteratorTypes, meshAxisAssignmentForLoopIterators)) { 328fb582b6aSBoian Petkantchin ImplicitLocOpBuilder implicitLocBuilder(op->getLoc(), builder); 329fb582b6aSBoian Petkantchin spmdizeLinalgOpWithShardedReduction( 330fb582b6aSBoian Petkantchin linalgOp, spmdizedOperands, operandShardings, resultShardings, 331fb582b6aSBoian Petkantchin loopIteratorTypes, meshAxisAssignmentForLoopIterators, spmdizationMap, 332fb582b6aSBoian Petkantchin symbolTable, implicitLocBuilder); 333fb582b6aSBoian Petkantchin } else { 334fb582b6aSBoian Petkantchin spmdizeTriviallyShardableOperation(*op, spmdizedOperands, 335fb582b6aSBoian Petkantchin operandShardings, resultShardings, 336fb582b6aSBoian Petkantchin spmdizationMap, symbolTable, builder); 337fb582b6aSBoian Petkantchin } 338fb582b6aSBoian Petkantchin 339fb582b6aSBoian Petkantchin return success(); 340fb582b6aSBoian Petkantchin } 341fb582b6aSBoian Petkantchin }; 342fb582b6aSBoian Petkantchin 343fb582b6aSBoian Petkantchin } // namespace 344fb582b6aSBoian Petkantchin 345fb582b6aSBoian Petkantchin template <typename OpType> 346fb582b6aSBoian Petkantchin static void registerOne(MLIRContext *ctx) { 347fb582b6aSBoian Petkantchin OpType::template attachInterface<StructuredOpShardingInterface<OpType>>(*ctx); 348fb582b6aSBoian Petkantchin } 349fb582b6aSBoian Petkantchin 350fb582b6aSBoian Petkantchin /// Variadic helper function. 351fb582b6aSBoian Petkantchin template <typename... OpTypes> 352fb582b6aSBoian Petkantchin static void registerAll(MLIRContext *ctx) { 353fb582b6aSBoian Petkantchin (registerOne<OpTypes>(ctx), ...); 354fb582b6aSBoian Petkantchin } 355fb582b6aSBoian Petkantchin 356fb582b6aSBoian Petkantchin void registerMeshShardingInterfaceExternalModels(DialectRegistry ®istry) { 357fb582b6aSBoian Petkantchin registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *dialect) { 358fb582b6aSBoian Petkantchin DialectRegistry registry; 359fb582b6aSBoian Petkantchin registry.insert<affine::AffineDialect, arith::ArithDialect, scf::SCFDialect, 360fb582b6aSBoian Petkantchin tensor::TensorDialect>(); 361fb582b6aSBoian Petkantchin ctx->appendDialectRegistry(registry); 362fb582b6aSBoian Petkantchin for (StringRef name : registry.getDialectNames()) 363fb582b6aSBoian Petkantchin ctx->getOrLoadDialect(name); 364fb582b6aSBoian Petkantchin 365fb582b6aSBoian Petkantchin registerOne<linalg::GenericOp>(ctx); 366fb582b6aSBoian Petkantchin registerAll< 367fb582b6aSBoian Petkantchin #define GET_OP_LIST 368fb582b6aSBoian Petkantchin #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" 369fb582b6aSBoian Petkantchin >(ctx); 370fb582b6aSBoian Petkantchin }); 371fb582b6aSBoian Petkantchin } 372fb582b6aSBoian Petkantchin 373fb582b6aSBoian Petkantchin } // namespace mlir::linalg 374