xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp (revision 91bbebc7e118cceae1fc0e349de08094a3cd2fe7)
1fb582b6aSBoian Petkantchin //===- MeshShardingInterfaceImpl.cpp --------------------------------------===//
2fb582b6aSBoian Petkantchin //
3fb582b6aSBoian Petkantchin // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4fb582b6aSBoian Petkantchin // See https://llvm.org/LICENSE.txt for license information.
5fb582b6aSBoian Petkantchin // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6fb582b6aSBoian Petkantchin //
7fb582b6aSBoian Petkantchin //===----------------------------------------------------------------------===//
8fb582b6aSBoian Petkantchin 
9fb582b6aSBoian Petkantchin #include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h"
10fb582b6aSBoian Petkantchin 
11fb582b6aSBoian Petkantchin #include "mlir/Analysis/SliceAnalysis.h"
12fb582b6aSBoian Petkantchin #include "mlir/Dialect/Affine/IR/AffineOps.h"
13fb582b6aSBoian Petkantchin #include "mlir/Dialect/Arith/IR/Arith.h"
14fb582b6aSBoian Petkantchin #include "mlir/Dialect/Linalg/IR/Linalg.h"
15fb582b6aSBoian Petkantchin #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
16fb582b6aSBoian Petkantchin #include "mlir/Dialect/Mesh/IR/MeshOps.h"
17fb582b6aSBoian Petkantchin #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
18fb582b6aSBoian Petkantchin #include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
19fb582b6aSBoian Petkantchin #include "mlir/Dialect/Mesh/Transforms/Transforms.h"
20fb582b6aSBoian Petkantchin #include "mlir/Dialect/SCF/IR/SCF.h"
21fb582b6aSBoian Petkantchin #include "mlir/Dialect/Tensor/IR/Tensor.h"
22fb582b6aSBoian Petkantchin #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
23fb582b6aSBoian Petkantchin #include "mlir/IR/AffineExpr.h"
24fb582b6aSBoian Petkantchin #include "mlir/IR/DialectRegistry.h"
25fb582b6aSBoian Petkantchin #include "mlir/IR/IRMapping.h"
26fb582b6aSBoian Petkantchin #include "mlir/IR/ImplicitLocOpBuilder.h"
27fb582b6aSBoian Petkantchin #include "mlir/IR/MLIRContext.h"
28fb582b6aSBoian Petkantchin #include "mlir/IR/OpDefinition.h"
29fb582b6aSBoian Petkantchin #include "mlir/IR/Operation.h"
30fb582b6aSBoian Petkantchin #include "mlir/IR/SymbolTable.h"
31fb582b6aSBoian Petkantchin #include "mlir/IR/Value.h"
32fb582b6aSBoian Petkantchin #include "mlir/Interfaces/TilingInterface.h"
33fb582b6aSBoian Petkantchin #include "llvm/ADT/ArrayRef.h"
34fb582b6aSBoian Petkantchin #include "llvm/ADT/STLExtras.h"
35fb582b6aSBoian Petkantchin #include "llvm/ADT/SmallVector.h"
36fb582b6aSBoian Petkantchin #include "llvm/ADT/TypeSwitch.h"
37fb582b6aSBoian Petkantchin #include <iterator>
38d635b860SBoian Petkantchin #include <numeric>
39fb582b6aSBoian Petkantchin #include <optional>
40fb582b6aSBoian Petkantchin #include <utility>
41fb582b6aSBoian Petkantchin 
42fb582b6aSBoian Petkantchin namespace mlir::linalg {
43fb582b6aSBoian Petkantchin 
44fb582b6aSBoian Petkantchin using MeshAxis = mesh::MeshAxis;
45fb582b6aSBoian Petkantchin using ReductionKind = mesh::ReductionKind;
46baabcb28SFrank Schlimbach using MeshSharding = mesh::MeshSharding;
47fb582b6aSBoian Petkantchin using ShardingArray = mesh::ShardingArray;
48fb582b6aSBoian Petkantchin using MeshOp = mesh::MeshOp;
49fb582b6aSBoian Petkantchin 
50fb582b6aSBoian Petkantchin // Returns the corresponding mesh reduction kind for the given arith op.
51fb582b6aSBoian Petkantchin static ReductionKind getReductionKind(Operation *op) {
52fb582b6aSBoian Petkantchin   return llvm::TypeSwitch<Operation *, ReductionKind>(op)
53fb582b6aSBoian Petkantchin       // Floating-point operations.
54fb582b6aSBoian Petkantchin       .Case([](arith::AddFOp op) { return ReductionKind::Sum; })
55fb582b6aSBoian Petkantchin       .Case([](arith::MulFOp op) { return ReductionKind::Product; })
56fb582b6aSBoian Petkantchin       // TODO: handle maxnumf and minnumf.
57fb582b6aSBoian Petkantchin       .Case([](arith::MaximumFOp op) { return ReductionKind::Max; })
58fb582b6aSBoian Petkantchin       .Case([](arith::MinimumFOp op) { return ReductionKind::Min; })
59fb582b6aSBoian Petkantchin       // Integer operations.
60fb582b6aSBoian Petkantchin       .Case([](arith::AddIOp op) { return ReductionKind::Sum; })
61fb582b6aSBoian Petkantchin       .Case([](arith::OrIOp op) { return ReductionKind::BitwiseOr; })
62fb582b6aSBoian Petkantchin       .Case([](arith::XOrIOp op) { return ReductionKind::BitwiseXor; })
63fb582b6aSBoian Petkantchin       .Case([](arith::AndIOp op) { return ReductionKind::Sum; })
64fb582b6aSBoian Petkantchin       // TODO: handle signless, signed and unsigned types properly.
65fb582b6aSBoian Petkantchin       // It is assumed that the element type of the collective operands and
66fb582b6aSBoian Petkantchin       // result drive the meaning of the reduction kind, whether it is signed
67fb582b6aSBoian Petkantchin       // or unsigned.
68fb582b6aSBoian Petkantchin       // The reduction op inside the linalg op may have different result type
69fb582b6aSBoian Petkantchin       // from the element type of the linalg op's result.
70fb582b6aSBoian Petkantchin       // Also signed and unsigned Arith dialect ops may accept signed, unsigned
71fb582b6aSBoian Petkantchin       // or signless operands.
72fb582b6aSBoian Petkantchin       // Maybe expand the reduction kinds.
73fb582b6aSBoian Petkantchin       .Case([](arith::MaxUIOp op) { return ReductionKind::Max; })
74fb582b6aSBoian Petkantchin       .Case([](arith::MinUIOp op) { return ReductionKind::Min; })
75fb582b6aSBoian Petkantchin       .Case([](arith::MaxSIOp op) { return ReductionKind::Max; })
76fb582b6aSBoian Petkantchin       .Case([](arith::MinSIOp op) { return ReductionKind::Min; })
77fb582b6aSBoian Petkantchin       .Case([](arith::MulIOp op) { return ReductionKind::Product; })
78fb582b6aSBoian Petkantchin       .Default([](Operation *op) { return ReductionKind::Generic; });
79fb582b6aSBoian Petkantchin }
80fb582b6aSBoian Petkantchin 
81fb582b6aSBoian Petkantchin static std::optional<Operation *> getCombinerOp(LinalgOp op) {
82fb582b6aSBoian Petkantchin   SmallVector<Operation *> combinerOps;
83fb582b6aSBoian Petkantchin   Value reducedValue = matchReduction(op.getRegionOutputArgs(), 0, combinerOps);
84fb582b6aSBoian Petkantchin   if (!reducedValue || combinerOps.size() != 1) {
85fb582b6aSBoian Petkantchin     return std::nullopt;
86fb582b6aSBoian Petkantchin   }
87fb582b6aSBoian Petkantchin 
88fb582b6aSBoian Petkantchin   return combinerOps[0];
89fb582b6aSBoian Petkantchin }
90fb582b6aSBoian Petkantchin 
91fb582b6aSBoian Petkantchin static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) {
92fb582b6aSBoian Petkantchin   std::optional<Operation *> reductionOp = getCombinerOp(op);
93fb582b6aSBoian Petkantchin   if (!reductionOp) {
94fb582b6aSBoian Petkantchin     return ReductionKind::Generic;
95fb582b6aSBoian Petkantchin   }
96474a73d9SJie Fu   [[maybe_unused]] Type resultElementType =
97fb582b6aSBoian Petkantchin       llvm::cast<RankedTensorType>(op->getResult(0).getType()).getElementType();
98fb582b6aSBoian Petkantchin   // TODO: handle case when result type of the reduction op does not match the
99fb582b6aSBoian Petkantchin   // element type of the result tensor.
100fb582b6aSBoian Petkantchin   // Would it makes sense at all?
101fb582b6aSBoian Petkantchin   assert(resultElementType == reductionOp.value()->getResult(0).getType());
102fb582b6aSBoian Petkantchin   return getReductionKind(reductionOp.value());
103fb582b6aSBoian Petkantchin }
104fb582b6aSBoian Petkantchin 
105baabcb28SFrank Schlimbach static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
106baabcb28SFrank Schlimbach                       ArrayRef<MeshSharding> resultShardings,
107fb582b6aSBoian Petkantchin                       SymbolTableCollection &symbolTable) {
108b7981a78SAdrian Kuegel   for (const MeshSharding &sharding : operandShardings) {
109fb582b6aSBoian Petkantchin     if (sharding) {
110baabcb28SFrank Schlimbach       return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
111fb582b6aSBoian Petkantchin     }
112fb582b6aSBoian Petkantchin   }
113fb582b6aSBoian Petkantchin 
114b7981a78SAdrian Kuegel   for (const MeshSharding &sharding : resultShardings) {
115fb582b6aSBoian Petkantchin     if (sharding) {
116baabcb28SFrank Schlimbach       return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
117fb582b6aSBoian Petkantchin     }
118fb582b6aSBoian Petkantchin   }
119fb582b6aSBoian Petkantchin 
120fb582b6aSBoian Petkantchin   assert(false);
121474a73d9SJie Fu   return nullptr;
122fb582b6aSBoian Petkantchin }
123fb582b6aSBoian Petkantchin 
124fb582b6aSBoian Petkantchin // Choose the operand based on the current process index along the reduction
125fb582b6aSBoian Petkantchin // mesh axes.
126fb582b6aSBoian Petkantchin // We need to use the initial value only once to avoid including it in the
127fb582b6aSBoian Petkantchin // reduction multiple times.
128fb582b6aSBoian Petkantchin // In each process group only the leading process with linear index 0 would use
129fb582b6aSBoian Petkantchin // the original operand.
130fb582b6aSBoian Petkantchin // The other processes would use the reduction operation neutral tensor.
131fb582b6aSBoian Petkantchin static Value createDestinationPassingStyleInitOperand(
132*91bbebc7SKunwar Grover     LinalgOp op, int operandNumber, Value spmdizedOperand,
133*91bbebc7SKunwar Grover     ArrayRef<MeshAxis> reductionMeshAxes, MeshOp meshOp,
134*91bbebc7SKunwar Grover     ImplicitLocOpBuilder &builder) {
135fb582b6aSBoian Petkantchin   Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex(
136fb582b6aSBoian Petkantchin       meshOp.getSymName(), reductionMeshAxes, builder);
137fb582b6aSBoian Petkantchin   Value zero = builder.create<arith::ConstantIndexOp>(0);
138fb582b6aSBoian Petkantchin   Value isLeadProcess = builder.create<arith::CmpIOp>(
139fb582b6aSBoian Petkantchin       builder.getI1Type(), arith::CmpIPredicate::eq,
140fb582b6aSBoian Petkantchin       processLinearIndexInReductionGroup, zero);
141fb582b6aSBoian Petkantchin   scf::IfOp ifOp = builder.create<scf::IfOp>(spmdizedOperand.getType(),
142fb582b6aSBoian Petkantchin                                              isLeadProcess, true, true);
143fb582b6aSBoian Petkantchin   // Then block.
144fb582b6aSBoian Petkantchin   {
145fb582b6aSBoian Petkantchin     OpBuilder::InsertionGuard insertionGuard(builder);
146fb582b6aSBoian Petkantchin     builder.setInsertionPointToEnd(&ifOp.getThenRegion().front());
147fb582b6aSBoian Petkantchin     builder.create<scf::YieldOp>(spmdizedOperand);
148fb582b6aSBoian Petkantchin   }
149fb582b6aSBoian Petkantchin 
150fb582b6aSBoian Petkantchin   // Else block.
151fb582b6aSBoian Petkantchin   {
152fb582b6aSBoian Petkantchin     OpBuilder::InsertionGuard insertionGuard(builder);
153fb582b6aSBoian Petkantchin     builder.setInsertionPointToEnd(&ifOp.getElseRegion().front());
154fb582b6aSBoian Petkantchin     SmallVector<OpFoldResult> shape =
155fb582b6aSBoian Petkantchin         tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);
156*91bbebc7SKunwar Grover 
157*91bbebc7SKunwar Grover     SmallVector<Operation *> combinerOps;
158*91bbebc7SKunwar Grover     matchReduction(op.getRegionOutputArgs(), operandNumber, combinerOps);
159*91bbebc7SKunwar Grover     assert(combinerOps.size() == 1);
160*91bbebc7SKunwar Grover     std::optional<TypedAttr> neutralEl =
161*91bbebc7SKunwar Grover         arith::getNeutralElement(combinerOps[0]);
162*91bbebc7SKunwar Grover 
163*91bbebc7SKunwar Grover     Value init = builder.create<tensor::EmptyOp>(op.getLoc(), shape,
164*91bbebc7SKunwar Grover                                                  neutralEl.value().getType());
165*91bbebc7SKunwar Grover     Value constant =
166*91bbebc7SKunwar Grover         builder.create<arith::ConstantOp>(op.getLoc(), neutralEl.value());
167*91bbebc7SKunwar Grover     Value fill = builder.create<linalg::FillOp>(op.getLoc(), constant, init)
168*91bbebc7SKunwar Grover                      .getResult(0);
169*91bbebc7SKunwar Grover 
170*91bbebc7SKunwar Grover     builder.create<scf::YieldOp>(fill);
171fb582b6aSBoian Petkantchin   }
172fb582b6aSBoian Petkantchin   return ifOp.getResult(0);
173fb582b6aSBoian Petkantchin }
174fb582b6aSBoian Petkantchin 
175fb582b6aSBoian Petkantchin // Create the DPS init operands for the spmdized Linalg op.
176fb582b6aSBoian Petkantchin // Return all the new spmdized operands.
177fb582b6aSBoian Petkantchin static SmallVector<Value> createDestinationPassingStyleInitOperands(
178fb582b6aSBoian Petkantchin     LinalgOp op, MeshOp meshOp, ArrayRef<Value> spmdizedOperands,
179fb582b6aSBoian Petkantchin     ArrayRef<MeshAxis> reductionMeshAxes, IRMapping &spmdizationMap,
180fb582b6aSBoian Petkantchin     ImplicitLocOpBuilder &builder) {
181fb582b6aSBoian Petkantchin   // TODO: add support for multiple destination passing style initial value
182fb582b6aSBoian Petkantchin   // operands.
1839329b20dSKunwar Grover   assert(op.getNumDpsInits() == 1 && "Multiple initial values not supported.");
184fb582b6aSBoian Petkantchin   SmallVector<Value> newOperands = llvm::to_vector(spmdizedOperands);
185fb582b6aSBoian Petkantchin   auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber();
186fb582b6aSBoian Petkantchin   Value spmdizedInitOperand =
187fb582b6aSBoian Petkantchin       spmdizationMap.lookup(op->getOperands()[operandIdx]);
188fb582b6aSBoian Petkantchin   newOperands[operandIdx] = createDestinationPassingStyleInitOperand(
189*91bbebc7SKunwar Grover       op, 0, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
190fb582b6aSBoian Petkantchin   return newOperands;
191fb582b6aSBoian Petkantchin }
192fb582b6aSBoian Petkantchin 
193fb582b6aSBoian Petkantchin static void createAllReduceForResultWithoutPartialSharding(
194fb582b6aSBoian Petkantchin     Value unshardedLinalgOpResult, ArrayRef<MeshAxis> opReductionMeshAxes,
195baabcb28SFrank Schlimbach     MeshSharding resultSharding, ReductionKind reductionKind,
196fb582b6aSBoian Petkantchin     IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder) {
197fb582b6aSBoian Petkantchin   SmallVector<MeshAxis> allReduceMeshAxes;
198fb582b6aSBoian Petkantchin   llvm::copy_if(opReductionMeshAxes, std::back_inserter(allReduceMeshAxes),
199fb582b6aSBoian Petkantchin                 [&resultSharding](MeshAxis axis) {
200fb582b6aSBoian Petkantchin                   return !llvm::is_contained(resultSharding.getPartialAxes(),
201fb582b6aSBoian Petkantchin                                              axis);
202fb582b6aSBoian Petkantchin                 });
203fb582b6aSBoian Petkantchin   if (allReduceMeshAxes.empty()) {
204fb582b6aSBoian Petkantchin     return;
205fb582b6aSBoian Petkantchin   }
206fb582b6aSBoian Petkantchin 
207fb582b6aSBoian Petkantchin   Value spmdizedLinalgOpResult = spmdizationMap.lookup(unshardedLinalgOpResult);
208fb582b6aSBoian Petkantchin   Value reducedValue = builder.create<mesh::AllReduceOp>(
209baabcb28SFrank Schlimbach       spmdizedLinalgOpResult, resultSharding.getMesh(), allReduceMeshAxes,
210baabcb28SFrank Schlimbach       reductionKind);
211fb582b6aSBoian Petkantchin   spmdizationMap.map(unshardedLinalgOpResult, reducedValue);
212fb582b6aSBoian Petkantchin }
213fb582b6aSBoian Petkantchin 
214fb582b6aSBoian Petkantchin static void createAllReduceForResultsWithoutPartialShardings(
215fb582b6aSBoian Petkantchin     LinalgOp unshardedOp, ArrayRef<MeshAxis> opReductionMeshAxes,
216baabcb28SFrank Schlimbach     ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
217fb582b6aSBoian Petkantchin     ImplicitLocOpBuilder &builder) {
218fb582b6aSBoian Petkantchin   ReductionKind reductionKind = getReductionKindOfLinalgOp(unshardedOp);
219fb582b6aSBoian Petkantchin   for (auto [unshardedLinalgOpResult, resultSharding] :
220fb582b6aSBoian Petkantchin        llvm::zip_equal(unshardedOp->getResults(), resultShardings)) {
221fb582b6aSBoian Petkantchin     createAllReduceForResultWithoutPartialSharding(
222fb582b6aSBoian Petkantchin         unshardedLinalgOpResult, opReductionMeshAxes, resultSharding,
223fb582b6aSBoian Petkantchin         reductionKind, spmdizationMap, builder);
224fb582b6aSBoian Petkantchin   }
225fb582b6aSBoian Petkantchin }
226fb582b6aSBoian Petkantchin 
227fb582b6aSBoian Petkantchin static void spmdizeLinalgOpWithShardedReduction(
228fb582b6aSBoian Petkantchin     LinalgOp op, ArrayRef<Value> spmdizedOperands,
229baabcb28SFrank Schlimbach     ArrayRef<MeshSharding> operandShardings,
230baabcb28SFrank Schlimbach     ArrayRef<MeshSharding> resultShardings,
231fb582b6aSBoian Petkantchin     ArrayRef<utils::IteratorType> loopIteratorTypes,
232fb582b6aSBoian Petkantchin     ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators,
233fb582b6aSBoian Petkantchin     IRMapping &spmdizationMap, SymbolTableCollection &symbolTable,
234fb582b6aSBoian Petkantchin     ImplicitLocOpBuilder &builder) {
235fb582b6aSBoian Petkantchin   MeshOp mesh = getMesh(op, operandShardings, resultShardings, symbolTable);
236fb582b6aSBoian Petkantchin   SmallVector<MeshAxis> reductionMeshAxes = mesh::getReductionMeshAxes(
237fb582b6aSBoian Petkantchin       loopIteratorTypes, meshAxisAssignmentForLoopIterators);
238fb582b6aSBoian Petkantchin   SmallVector<Value> spmdizedLinalgOpOperands =
239fb582b6aSBoian Petkantchin       createDestinationPassingStyleInitOperands(op, mesh, spmdizedOperands,
240fb582b6aSBoian Petkantchin                                                 reductionMeshAxes,
241fb582b6aSBoian Petkantchin                                                 spmdizationMap, builder);
242fb582b6aSBoian Petkantchin   // We must not change the operand mappings of the original spmdizationMap as
243fb582b6aSBoian Petkantchin   // they are the mappings for the whole spmdization blob and may be used by
244fb582b6aSBoian Petkantchin   // others.
245fb582b6aSBoian Petkantchin   IRMapping internalSpmdizationMap;
246fb582b6aSBoian Petkantchin   for (auto [unshardedOperand, spmdizedOperand] :
247fb582b6aSBoian Petkantchin        llvm::zip_equal(op->getOperands(), spmdizedLinalgOpOperands)) {
248fb582b6aSBoian Petkantchin     internalSpmdizationMap.map(unshardedOperand, spmdizedOperand);
249fb582b6aSBoian Petkantchin   }
250fb582b6aSBoian Petkantchin   spmdizeTriviallyShardableOperation(
251fb582b6aSBoian Petkantchin       *op, spmdizedLinalgOpOperands, operandShardings, resultShardings,
252fb582b6aSBoian Petkantchin       internalSpmdizationMap, symbolTable, builder);
253fb582b6aSBoian Petkantchin   for (Value result : op->getResults()) {
254fb582b6aSBoian Petkantchin     spmdizationMap.map(result, internalSpmdizationMap.lookup(result));
255fb582b6aSBoian Petkantchin   }
256fb582b6aSBoian Petkantchin 
257fb582b6aSBoian Petkantchin   // Handle partial shardings.
258fb582b6aSBoian Petkantchin   createAllReduceForResultsWithoutPartialShardings(
259fb582b6aSBoian Petkantchin       op, reductionMeshAxes, resultShardings, spmdizationMap, builder);
260fb582b6aSBoian Petkantchin }
261fb582b6aSBoian Petkantchin 
262fb582b6aSBoian Petkantchin namespace {
263fb582b6aSBoian Petkantchin 
264fb582b6aSBoian Petkantchin // ShardingInterface for ops that implement LinalgStructuredInterface.
265fb582b6aSBoian Petkantchin // The supported ops are only those where the indexing maps are projected
266fb582b6aSBoian Petkantchin // permutations.
267fb582b6aSBoian Petkantchin template <typename Op>
268fb582b6aSBoian Petkantchin struct StructuredOpShardingInterface
269fb582b6aSBoian Petkantchin     : public mesh::ShardingInterface::ExternalModel<
270fb582b6aSBoian Petkantchin           StructuredOpShardingInterface<Op>, Op> {
271fb582b6aSBoian Petkantchin   SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
272fb582b6aSBoian Petkantchin     return llvm::cast<LinalgOp>(op).getIteratorTypesArray();
273fb582b6aSBoian Petkantchin   }
274fb582b6aSBoian Petkantchin 
275fb582b6aSBoian Petkantchin   SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
276fb582b6aSBoian Petkantchin     LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
277fb582b6aSBoian Petkantchin     SmallVector<AffineMap> res = linalgOp.getIndexingMapsArray();
278fb582b6aSBoian Petkantchin 
279fb582b6aSBoian Petkantchin     // Results must have the same indexing as destination passing style initial
280fb582b6aSBoian Petkantchin     // operands.
281fb582b6aSBoian Petkantchin     for (int64_t i = 0; i < linalgOp.getNumDpsInits(); ++i) {
282fb582b6aSBoian Petkantchin       res.push_back(res[linalgOp.getDpsInitOperand(i)->getOperandNumber()]);
283fb582b6aSBoian Petkantchin     }
284fb582b6aSBoian Petkantchin 
285fb582b6aSBoian Petkantchin     return res;
286fb582b6aSBoian Petkantchin   }
287fb582b6aSBoian Petkantchin 
288d635b860SBoian Petkantchin   SmallVector<ReductionKind>
289d635b860SBoian Petkantchin   getReductionLoopIteratorKinds(Operation *op) const {
290d635b860SBoian Petkantchin     LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
291d635b860SBoian Petkantchin     SmallVector<utils::IteratorType> iteratorTypes =
292d635b860SBoian Petkantchin         linalgOp.getIteratorTypesArray();
293d635b860SBoian Petkantchin     unsigned reductionItersCount = std::accumulate(
294d635b860SBoian Petkantchin         iteratorTypes.begin(), iteratorTypes.end(), 0,
295d635b860SBoian Petkantchin         [](unsigned count, utils::IteratorType iter) {
296d635b860SBoian Petkantchin           return count + (iter == utils::IteratorType::reduction);
297d635b860SBoian Petkantchin         });
298d635b860SBoian Petkantchin     mesh::ReductionKind reductionKind = getReductionKindOfLinalgOp(linalgOp);
299d635b860SBoian Petkantchin     return SmallVector<ReductionKind>(reductionItersCount, reductionKind);
300d635b860SBoian Petkantchin   }
301d635b860SBoian Petkantchin 
302fb582b6aSBoian Petkantchin   LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
303baabcb28SFrank Schlimbach                         ArrayRef<MeshSharding> operandShardings,
304baabcb28SFrank Schlimbach                         ArrayRef<MeshSharding> resultShardings,
305fb582b6aSBoian Petkantchin                         IRMapping &spmdizationMap,
306fb582b6aSBoian Petkantchin                         SymbolTableCollection &symbolTable,
307fb582b6aSBoian Petkantchin                         OpBuilder &builder) const {
308fb582b6aSBoian Petkantchin     LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
309fb582b6aSBoian Petkantchin 
310fb582b6aSBoian Petkantchin     SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
311fb582b6aSBoian Petkantchin     bool allIndexingMapsAreProjectedPermutation =
312fb582b6aSBoian Petkantchin         llvm::all_of(indexingMaps, [](AffineMap map) {
313fb582b6aSBoian Petkantchin           return map.isProjectedPermutation();
314fb582b6aSBoian Petkantchin         });
315fb582b6aSBoian Petkantchin     if (!allIndexingMapsAreProjectedPermutation) {
316fb582b6aSBoian Petkantchin       // TODO: handle non-projected permutations.
317fb582b6aSBoian Petkantchin       return op->emitOpError()
318fb582b6aSBoian Petkantchin              << "supports indexing maps that are only projected permutation.";
319fb582b6aSBoian Petkantchin     }
320fb582b6aSBoian Petkantchin 
321fb582b6aSBoian Petkantchin     SmallVector<utils::IteratorType> loopIteratorTypes =
322fb582b6aSBoian Petkantchin         linalgOp.getIteratorTypesArray();
323fb582b6aSBoian Petkantchin     ShardingArray meshAxisAssignmentForLoopIterators =
324fb582b6aSBoian Petkantchin         getMeshAxisAssignmentForLoopIterators(operandShardings, resultShardings,
325fb582b6aSBoian Petkantchin                                               loopIteratorTypes, indexingMaps);
326fb582b6aSBoian Petkantchin     if (mesh::isAtLeastOneReductionIteratorSharded(
327fb582b6aSBoian Petkantchin             loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
328fb582b6aSBoian Petkantchin       ImplicitLocOpBuilder implicitLocBuilder(op->getLoc(), builder);
329fb582b6aSBoian Petkantchin       spmdizeLinalgOpWithShardedReduction(
330fb582b6aSBoian Petkantchin           linalgOp, spmdizedOperands, operandShardings, resultShardings,
331fb582b6aSBoian Petkantchin           loopIteratorTypes, meshAxisAssignmentForLoopIterators, spmdizationMap,
332fb582b6aSBoian Petkantchin           symbolTable, implicitLocBuilder);
333fb582b6aSBoian Petkantchin     } else {
334fb582b6aSBoian Petkantchin       spmdizeTriviallyShardableOperation(*op, spmdizedOperands,
335fb582b6aSBoian Petkantchin                                          operandShardings, resultShardings,
336fb582b6aSBoian Petkantchin                                          spmdizationMap, symbolTable, builder);
337fb582b6aSBoian Petkantchin     }
338fb582b6aSBoian Petkantchin 
339fb582b6aSBoian Petkantchin     return success();
340fb582b6aSBoian Petkantchin   }
341fb582b6aSBoian Petkantchin };
342fb582b6aSBoian Petkantchin 
343fb582b6aSBoian Petkantchin } // namespace
344fb582b6aSBoian Petkantchin 
345fb582b6aSBoian Petkantchin template <typename OpType>
346fb582b6aSBoian Petkantchin static void registerOne(MLIRContext *ctx) {
347fb582b6aSBoian Petkantchin   OpType::template attachInterface<StructuredOpShardingInterface<OpType>>(*ctx);
348fb582b6aSBoian Petkantchin }
349fb582b6aSBoian Petkantchin 
350fb582b6aSBoian Petkantchin /// Variadic helper function.
351fb582b6aSBoian Petkantchin template <typename... OpTypes>
352fb582b6aSBoian Petkantchin static void registerAll(MLIRContext *ctx) {
353fb582b6aSBoian Petkantchin   (registerOne<OpTypes>(ctx), ...);
354fb582b6aSBoian Petkantchin }
355fb582b6aSBoian Petkantchin 
356fb582b6aSBoian Petkantchin void registerMeshShardingInterfaceExternalModels(DialectRegistry &registry) {
357fb582b6aSBoian Petkantchin   registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *dialect) {
358fb582b6aSBoian Petkantchin     DialectRegistry registry;
359fb582b6aSBoian Petkantchin     registry.insert<affine::AffineDialect, arith::ArithDialect, scf::SCFDialect,
360fb582b6aSBoian Petkantchin                     tensor::TensorDialect>();
361fb582b6aSBoian Petkantchin     ctx->appendDialectRegistry(registry);
362fb582b6aSBoian Petkantchin     for (StringRef name : registry.getDialectNames())
363fb582b6aSBoian Petkantchin       ctx->getOrLoadDialect(name);
364fb582b6aSBoian Petkantchin 
365fb582b6aSBoian Petkantchin     registerOne<linalg::GenericOp>(ctx);
366fb582b6aSBoian Petkantchin     registerAll<
367fb582b6aSBoian Petkantchin #define GET_OP_LIST
368fb582b6aSBoian Petkantchin #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
369fb582b6aSBoian Petkantchin         >(ctx);
370fb582b6aSBoian Petkantchin   });
371fb582b6aSBoian Petkantchin }
372fb582b6aSBoian Petkantchin 
373fb582b6aSBoian Petkantchin } // namespace mlir::linalg
374