xref: /llvm-project/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h (revision baabcb28983edf8f20e39b89e2b1745412073b44)
1 //===- ShardingInterfaceImpl.h ----------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
10 #define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
11 
12 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
13 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/Value.h"
16 
17 namespace mlir {
18 
19 class Operation;
20 class IRMapping;
21 class SymbolTableCollection;
22 
23 namespace mesh {
24 
25 // Retrieve the mesh axes corresponding to each operation loop iterator based
26 // on the provided shardings for the op's operands and results.
27 // Assumes that the indexingMaps are projected permutations.
28 ShardingArray getMeshAxisAssignmentForLoopIterators(
29     ArrayRef<MeshSharding> operandShardings,
30     ArrayRef<MeshSharding> resultShardings,
31     ArrayRef<utils::IteratorType> loopIteratorTypes,
32     ArrayRef<AffineMap> indexingMaps);
33 
34 bool isAtLeastOneReductionIteratorSharded(
35     ArrayRef<utils::IteratorType> loopIteratorTypes,
36     ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);
37 
38 // Get the set of mesh axes that correspond to reduction loop iterators.
39 SmallVector<MeshAxis> getReductionMeshAxes(
40     ArrayRef<utils::IteratorType> loopIteratorTypes,
41     ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);
42 
43 // Inserts a clone of the operation that has all ranked tensor
44 // arguments/results sharded.
45 void spmdizeTriviallyShardableOperation(Operation &op,
46                                         ArrayRef<Value> spmdizedOperands,
47                                         ArrayRef<MeshSharding> operandShardings,
48                                         ArrayRef<MeshSharding> resultShardings,
49                                         IRMapping &spmdizationMap,
50                                         SymbolTableCollection &symbolTable,
51                                         OpBuilder &builder);
52 
53 // All ranked tensor argument and result dimensions have
54 // independent parallel loop iterators.
55 template <typename Op>
56 struct IndependentParallelIteratorDomainShardingInterface
57     : public ShardingInterface::ExternalModel<
58           IndependentParallelIteratorDomainShardingInterface<Op>, Op> {
59   SmallVector<utils::IteratorType>
60   getLoopIteratorTypes(Operation *operation) const {
61     SmallVector<utils::IteratorType> iterTypes;
62     for (Type t : operation->getOperandTypes()) {
63       populateIteratorTypes(t, iterTypes);
64     }
65     for (Type t : operation->getResultTypes()) {
66       populateIteratorTypes(t, iterTypes);
67     }
68     return iterTypes;
69   }
70 
71   SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
72     // TODO: implement.
73     return SmallVector<AffineMap>();
74   }
75 
76   LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
77                         ArrayRef<MeshSharding> operandShardings,
78                         ArrayRef<MeshSharding> resultShardings,
79                         IRMapping &spmdizationMap,
80                         SymbolTableCollection &symbolTable,
81                         OpBuilder &builder) const {
82     spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
83                                        resultShardings, spmdizationMap,
84                                        symbolTable, builder);
85     return success();
86   }
87 
88 private:
89   void
90   populateIteratorTypes(Type t,
91                         SmallVector<utils::IteratorType> &iterTypes) const {
92     RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(t);
93     if (!rankedTensorType) {
94       return;
95     }
96 
97     iterTypes.reserve(iterTypes.size() + rankedTensorType.getRank());
98     for (int64_t i = 0; i < rankedTensorType.getRank(); ++i) {
99       iterTypes.push_back(utils::IteratorType::parallel);
100     }
101   }
102 };
103 
104 // Sharding of elementwise operations like tensor addition and multiplication.
105 template <typename ElemwiseOp>
106 struct ElementwiseShardingInterface
107     : public ShardingInterface::ExternalModel<
108           ElementwiseShardingInterface<ElemwiseOp>, ElemwiseOp> {
109   SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
110     Value val = op->getOperand(0);
111     auto type = dyn_cast<RankedTensorType>(val.getType());
112     if (!type)
113       return {};
114     SmallVector<utils::IteratorType> types(type.getRank(),
115                                            utils::IteratorType::parallel);
116     return types;
117   }
118 
119   SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
120     MLIRContext *ctx = op->getContext();
121     Value val = op->getOperand(0);
122     auto type = dyn_cast<RankedTensorType>(val.getType());
123     if (!type)
124       return {};
125     int64_t rank = type.getRank();
126     int64_t num = op->getNumOperands() + op->getNumResults();
127     SmallVector<AffineMap> maps(num,
128                                 AffineMap::getMultiDimIdentityMap(rank, ctx));
129     return maps;
130   }
131 
132   LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
133                         ArrayRef<MeshSharding> operandShardings,
134                         ArrayRef<MeshSharding> resultShardings,
135                         IRMapping &spmdizationMap,
136                         SymbolTableCollection &symbolTable,
137                         OpBuilder &builder) const {
138     spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
139                                        resultShardings, spmdizationMap,
140                                        symbolTable, builder);
141     return success();
142   }
143 };
144 
145 } // namespace mesh
146 } // namespace mlir
147 
148 #endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
149