xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp (revision 91bbebc7e118cceae1fc0e349de08094a3cd2fe7)
1 //===- MeshShardingInterfaceImpl.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h"
10 
11 #include "mlir/Analysis/SliceAnalysis.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Arith/IR/Arith.h"
14 #include "mlir/Dialect/Linalg/IR/Linalg.h"
15 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
16 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
17 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
18 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
19 #include "mlir/Dialect/Mesh/Transforms/Transforms.h"
20 #include "mlir/Dialect/SCF/IR/SCF.h"
21 #include "mlir/Dialect/Tensor/IR/Tensor.h"
22 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
23 #include "mlir/IR/AffineExpr.h"
24 #include "mlir/IR/DialectRegistry.h"
25 #include "mlir/IR/IRMapping.h"
26 #include "mlir/IR/ImplicitLocOpBuilder.h"
27 #include "mlir/IR/MLIRContext.h"
28 #include "mlir/IR/OpDefinition.h"
29 #include "mlir/IR/Operation.h"
30 #include "mlir/IR/SymbolTable.h"
31 #include "mlir/IR/Value.h"
32 #include "mlir/Interfaces/TilingInterface.h"
33 #include "llvm/ADT/ArrayRef.h"
34 #include "llvm/ADT/STLExtras.h"
35 #include "llvm/ADT/SmallVector.h"
36 #include "llvm/ADT/TypeSwitch.h"
37 #include <iterator>
38 #include <numeric>
39 #include <optional>
40 #include <utility>
41 
42 namespace mlir::linalg {
43 
44 using MeshAxis = mesh::MeshAxis;
45 using ReductionKind = mesh::ReductionKind;
46 using MeshSharding = mesh::MeshSharding;
47 using ShardingArray = mesh::ShardingArray;
48 using MeshOp = mesh::MeshOp;
49 
50 // Returns the corresponding mesh reduction kind for the given arith op.
51 static ReductionKind getReductionKind(Operation *op) {
52   return llvm::TypeSwitch<Operation *, ReductionKind>(op)
53       // Floating-point operations.
54       .Case([](arith::AddFOp op) { return ReductionKind::Sum; })
55       .Case([](arith::MulFOp op) { return ReductionKind::Product; })
56       // TODO: handle maxnumf and minnumf.
57       .Case([](arith::MaximumFOp op) { return ReductionKind::Max; })
58       .Case([](arith::MinimumFOp op) { return ReductionKind::Min; })
59       // Integer operations.
60       .Case([](arith::AddIOp op) { return ReductionKind::Sum; })
61       .Case([](arith::OrIOp op) { return ReductionKind::BitwiseOr; })
62       .Case([](arith::XOrIOp op) { return ReductionKind::BitwiseXor; })
63       .Case([](arith::AndIOp op) { return ReductionKind::Sum; })
64       // TODO: handle signless, signed and unsigned types properly.
65       // It is assumed that the element type of the collective operands and
66       // result drive the meaning of the reduction kind, whether it is signed
67       // or unsigned.
68       // The reduction op inside the linalg op may have different result type
69       // from the element type of the linalg op's result.
70       // Also signed and unsigned Arith dialect ops may accept signed, unsigned
71       // or signless operands.
72       // Maybe expand the reduction kinds.
73       .Case([](arith::MaxUIOp op) { return ReductionKind::Max; })
74       .Case([](arith::MinUIOp op) { return ReductionKind::Min; })
75       .Case([](arith::MaxSIOp op) { return ReductionKind::Max; })
76       .Case([](arith::MinSIOp op) { return ReductionKind::Min; })
77       .Case([](arith::MulIOp op) { return ReductionKind::Product; })
78       .Default([](Operation *op) { return ReductionKind::Generic; });
79 }
80 
81 static std::optional<Operation *> getCombinerOp(LinalgOp op) {
82   SmallVector<Operation *> combinerOps;
83   Value reducedValue = matchReduction(op.getRegionOutputArgs(), 0, combinerOps);
84   if (!reducedValue || combinerOps.size() != 1) {
85     return std::nullopt;
86   }
87 
88   return combinerOps[0];
89 }
90 
91 static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) {
92   std::optional<Operation *> reductionOp = getCombinerOp(op);
93   if (!reductionOp) {
94     return ReductionKind::Generic;
95   }
96   [[maybe_unused]] Type resultElementType =
97       llvm::cast<RankedTensorType>(op->getResult(0).getType()).getElementType();
98   // TODO: handle case when result type of the reduction op does not match the
99   // element type of the result tensor.
100   // Would it makes sense at all?
101   assert(resultElementType == reductionOp.value()->getResult(0).getType());
102   return getReductionKind(reductionOp.value());
103 }
104 
105 static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
106                       ArrayRef<MeshSharding> resultShardings,
107                       SymbolTableCollection &symbolTable) {
108   for (const MeshSharding &sharding : operandShardings) {
109     if (sharding) {
110       return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
111     }
112   }
113 
114   for (const MeshSharding &sharding : resultShardings) {
115     if (sharding) {
116       return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
117     }
118   }
119 
120   assert(false);
121   return nullptr;
122 }
123 
124 // Choose the operand based on the current process index along the reduction
125 // mesh axes.
126 // We need to use the initial value only once to avoid including it in the
127 // reduction multiple times.
128 // In each process group only the leading process with linear index 0 would use
129 // the original operand.
130 // The other processes would use the reduction operation neutral tensor.
131 static Value createDestinationPassingStyleInitOperand(
132     LinalgOp op, int operandNumber, Value spmdizedOperand,
133     ArrayRef<MeshAxis> reductionMeshAxes, MeshOp meshOp,
134     ImplicitLocOpBuilder &builder) {
135   Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex(
136       meshOp.getSymName(), reductionMeshAxes, builder);
137   Value zero = builder.create<arith::ConstantIndexOp>(0);
138   Value isLeadProcess = builder.create<arith::CmpIOp>(
139       builder.getI1Type(), arith::CmpIPredicate::eq,
140       processLinearIndexInReductionGroup, zero);
141   scf::IfOp ifOp = builder.create<scf::IfOp>(spmdizedOperand.getType(),
142                                              isLeadProcess, true, true);
143   // Then block.
144   {
145     OpBuilder::InsertionGuard insertionGuard(builder);
146     builder.setInsertionPointToEnd(&ifOp.getThenRegion().front());
147     builder.create<scf::YieldOp>(spmdizedOperand);
148   }
149 
150   // Else block.
151   {
152     OpBuilder::InsertionGuard insertionGuard(builder);
153     builder.setInsertionPointToEnd(&ifOp.getElseRegion().front());
154     SmallVector<OpFoldResult> shape =
155         tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);
156 
157     SmallVector<Operation *> combinerOps;
158     matchReduction(op.getRegionOutputArgs(), operandNumber, combinerOps);
159     assert(combinerOps.size() == 1);
160     std::optional<TypedAttr> neutralEl =
161         arith::getNeutralElement(combinerOps[0]);
162 
163     Value init = builder.create<tensor::EmptyOp>(op.getLoc(), shape,
164                                                  neutralEl.value().getType());
165     Value constant =
166         builder.create<arith::ConstantOp>(op.getLoc(), neutralEl.value());
167     Value fill = builder.create<linalg::FillOp>(op.getLoc(), constant, init)
168                      .getResult(0);
169 
170     builder.create<scf::YieldOp>(fill);
171   }
172   return ifOp.getResult(0);
173 }
174 
175 // Create the DPS init operands for the spmdized Linalg op.
176 // Return all the new spmdized operands.
177 static SmallVector<Value> createDestinationPassingStyleInitOperands(
178     LinalgOp op, MeshOp meshOp, ArrayRef<Value> spmdizedOperands,
179     ArrayRef<MeshAxis> reductionMeshAxes, IRMapping &spmdizationMap,
180     ImplicitLocOpBuilder &builder) {
181   // TODO: add support for multiple destination passing style initial value
182   // operands.
183   assert(op.getNumDpsInits() == 1 && "Multiple initial values not supported.");
184   SmallVector<Value> newOperands = llvm::to_vector(spmdizedOperands);
185   auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber();
186   Value spmdizedInitOperand =
187       spmdizationMap.lookup(op->getOperands()[operandIdx]);
188   newOperands[operandIdx] = createDestinationPassingStyleInitOperand(
189       op, 0, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
190   return newOperands;
191 }
192 
193 static void createAllReduceForResultWithoutPartialSharding(
194     Value unshardedLinalgOpResult, ArrayRef<MeshAxis> opReductionMeshAxes,
195     MeshSharding resultSharding, ReductionKind reductionKind,
196     IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder) {
197   SmallVector<MeshAxis> allReduceMeshAxes;
198   llvm::copy_if(opReductionMeshAxes, std::back_inserter(allReduceMeshAxes),
199                 [&resultSharding](MeshAxis axis) {
200                   return !llvm::is_contained(resultSharding.getPartialAxes(),
201                                              axis);
202                 });
203   if (allReduceMeshAxes.empty()) {
204     return;
205   }
206 
207   Value spmdizedLinalgOpResult = spmdizationMap.lookup(unshardedLinalgOpResult);
208   Value reducedValue = builder.create<mesh::AllReduceOp>(
209       spmdizedLinalgOpResult, resultSharding.getMesh(), allReduceMeshAxes,
210       reductionKind);
211   spmdizationMap.map(unshardedLinalgOpResult, reducedValue);
212 }
213 
214 static void createAllReduceForResultsWithoutPartialShardings(
215     LinalgOp unshardedOp, ArrayRef<MeshAxis> opReductionMeshAxes,
216     ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
217     ImplicitLocOpBuilder &builder) {
218   ReductionKind reductionKind = getReductionKindOfLinalgOp(unshardedOp);
219   for (auto [unshardedLinalgOpResult, resultSharding] :
220        llvm::zip_equal(unshardedOp->getResults(), resultShardings)) {
221     createAllReduceForResultWithoutPartialSharding(
222         unshardedLinalgOpResult, opReductionMeshAxes, resultSharding,
223         reductionKind, spmdizationMap, builder);
224   }
225 }
226 
227 static void spmdizeLinalgOpWithShardedReduction(
228     LinalgOp op, ArrayRef<Value> spmdizedOperands,
229     ArrayRef<MeshSharding> operandShardings,
230     ArrayRef<MeshSharding> resultShardings,
231     ArrayRef<utils::IteratorType> loopIteratorTypes,
232     ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators,
233     IRMapping &spmdizationMap, SymbolTableCollection &symbolTable,
234     ImplicitLocOpBuilder &builder) {
235   MeshOp mesh = getMesh(op, operandShardings, resultShardings, symbolTable);
236   SmallVector<MeshAxis> reductionMeshAxes = mesh::getReductionMeshAxes(
237       loopIteratorTypes, meshAxisAssignmentForLoopIterators);
238   SmallVector<Value> spmdizedLinalgOpOperands =
239       createDestinationPassingStyleInitOperands(op, mesh, spmdizedOperands,
240                                                 reductionMeshAxes,
241                                                 spmdizationMap, builder);
242   // We must not change the operand mappings of the original spmdizationMap as
243   // they are the mappings for the whole spmdization blob and may be used by
244   // others.
245   IRMapping internalSpmdizationMap;
246   for (auto [unshardedOperand, spmdizedOperand] :
247        llvm::zip_equal(op->getOperands(), spmdizedLinalgOpOperands)) {
248     internalSpmdizationMap.map(unshardedOperand, spmdizedOperand);
249   }
250   spmdizeTriviallyShardableOperation(
251       *op, spmdizedLinalgOpOperands, operandShardings, resultShardings,
252       internalSpmdizationMap, symbolTable, builder);
253   for (Value result : op->getResults()) {
254     spmdizationMap.map(result, internalSpmdizationMap.lookup(result));
255   }
256 
257   // Handle partial shardings.
258   createAllReduceForResultsWithoutPartialShardings(
259       op, reductionMeshAxes, resultShardings, spmdizationMap, builder);
260 }
261 
262 namespace {
263 
264 // ShardingInterface for ops that implement LinalgStructuredInterface.
265 // The supported ops are only those where the indexing maps are projected
266 // permutations.
267 template <typename Op>
268 struct StructuredOpShardingInterface
269     : public mesh::ShardingInterface::ExternalModel<
270           StructuredOpShardingInterface<Op>, Op> {
271   SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
272     return llvm::cast<LinalgOp>(op).getIteratorTypesArray();
273   }
274 
275   SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
276     LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
277     SmallVector<AffineMap> res = linalgOp.getIndexingMapsArray();
278 
279     // Results must have the same indexing as destination passing style initial
280     // operands.
281     for (int64_t i = 0; i < linalgOp.getNumDpsInits(); ++i) {
282       res.push_back(res[linalgOp.getDpsInitOperand(i)->getOperandNumber()]);
283     }
284 
285     return res;
286   }
287 
288   SmallVector<ReductionKind>
289   getReductionLoopIteratorKinds(Operation *op) const {
290     LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
291     SmallVector<utils::IteratorType> iteratorTypes =
292         linalgOp.getIteratorTypesArray();
293     unsigned reductionItersCount = std::accumulate(
294         iteratorTypes.begin(), iteratorTypes.end(), 0,
295         [](unsigned count, utils::IteratorType iter) {
296           return count + (iter == utils::IteratorType::reduction);
297         });
298     mesh::ReductionKind reductionKind = getReductionKindOfLinalgOp(linalgOp);
299     return SmallVector<ReductionKind>(reductionItersCount, reductionKind);
300   }
301 
302   LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
303                         ArrayRef<MeshSharding> operandShardings,
304                         ArrayRef<MeshSharding> resultShardings,
305                         IRMapping &spmdizationMap,
306                         SymbolTableCollection &symbolTable,
307                         OpBuilder &builder) const {
308     LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
309 
310     SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
311     bool allIndexingMapsAreProjectedPermutation =
312         llvm::all_of(indexingMaps, [](AffineMap map) {
313           return map.isProjectedPermutation();
314         });
315     if (!allIndexingMapsAreProjectedPermutation) {
316       // TODO: handle non-projected permutations.
317       return op->emitOpError()
318              << "supports indexing maps that are only projected permutation.";
319     }
320 
321     SmallVector<utils::IteratorType> loopIteratorTypes =
322         linalgOp.getIteratorTypesArray();
323     ShardingArray meshAxisAssignmentForLoopIterators =
324         getMeshAxisAssignmentForLoopIterators(operandShardings, resultShardings,
325                                               loopIteratorTypes, indexingMaps);
326     if (mesh::isAtLeastOneReductionIteratorSharded(
327             loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
328       ImplicitLocOpBuilder implicitLocBuilder(op->getLoc(), builder);
329       spmdizeLinalgOpWithShardedReduction(
330           linalgOp, spmdizedOperands, operandShardings, resultShardings,
331           loopIteratorTypes, meshAxisAssignmentForLoopIterators, spmdizationMap,
332           symbolTable, implicitLocBuilder);
333     } else {
334       spmdizeTriviallyShardableOperation(*op, spmdizedOperands,
335                                          operandShardings, resultShardings,
336                                          spmdizationMap, symbolTable, builder);
337     }
338 
339     return success();
340   }
341 };
342 
343 } // namespace
344 
345 template <typename OpType>
346 static void registerOne(MLIRContext *ctx) {
347   OpType::template attachInterface<StructuredOpShardingInterface<OpType>>(*ctx);
348 }
349 
350 /// Variadic helper function.
351 template <typename... OpTypes>
352 static void registerAll(MLIRContext *ctx) {
353   (registerOne<OpTypes>(ctx), ...);
354 }
355 
356 void registerMeshShardingInterfaceExternalModels(DialectRegistry &registry) {
357   registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *dialect) {
358     DialectRegistry registry;
359     registry.insert<affine::AffineDialect, arith::ArithDialect, scf::SCFDialect,
360                     tensor::TensorDialect>();
361     ctx->appendDialectRegistry(registry);
362     for (StringRef name : registry.getDialectNames())
363       ctx->getOrLoadDialect(name);
364 
365     registerOne<linalg::GenericOp>(ctx);
366     registerAll<
367 #define GET_OP_LIST
368 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
369         >(ctx);
370   });
371 }
372 
373 } // namespace mlir::linalg
374