1 //===- ShardingInterfaceImpl.h ----------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_ 10 #define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_ 11 12 #include "mlir/Dialect/Mesh/IR/MeshOps.h" 13 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" 14 #include "mlir/IR/AffineMap.h" 15 #include "mlir/IR/Value.h" 16 17 namespace mlir { 18 19 class Operation; 20 class IRMapping; 21 class SymbolTableCollection; 22 23 namespace mesh { 24 25 // Retrieve the mesh axes corresponding to each operation loop iterator based 26 // on the provided shardings for the op's operands and results. 27 // Assumes that the indexingMaps are projected permutations. 28 ShardingArray getMeshAxisAssignmentForLoopIterators( 29 ArrayRef<MeshSharding> operandShardings, 30 ArrayRef<MeshSharding> resultShardings, 31 ArrayRef<utils::IteratorType> loopIteratorTypes, 32 ArrayRef<AffineMap> indexingMaps); 33 34 bool isAtLeastOneReductionIteratorSharded( 35 ArrayRef<utils::IteratorType> loopIteratorTypes, 36 ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators); 37 38 // Get the set of mesh axes that correspond to reduction loop iterators. 39 SmallVector<MeshAxis> getReductionMeshAxes( 40 ArrayRef<utils::IteratorType> loopIteratorTypes, 41 ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators); 42 43 // Inserts a clone of the operation that has all ranked tensor 44 // arguments/results sharded. 45 void spmdizeTriviallyShardableOperation(Operation &op, 46 ArrayRef<Value> spmdizedOperands, 47 ArrayRef<MeshSharding> operandShardings, 48 ArrayRef<MeshSharding> resultShardings, 49 IRMapping &spmdizationMap, 50 SymbolTableCollection &symbolTable, 51 OpBuilder &builder); 52 53 // All ranked tensor argument and result dimensions have 54 // independent parallel loop iterators. 55 template <typename Op> 56 struct IndependentParallelIteratorDomainShardingInterface 57 : public ShardingInterface::ExternalModel< 58 IndependentParallelIteratorDomainShardingInterface<Op>, Op> { 59 SmallVector<utils::IteratorType> 60 getLoopIteratorTypes(Operation *operation) const { 61 SmallVector<utils::IteratorType> iterTypes; 62 for (Type t : operation->getOperandTypes()) { 63 populateIteratorTypes(t, iterTypes); 64 } 65 for (Type t : operation->getResultTypes()) { 66 populateIteratorTypes(t, iterTypes); 67 } 68 return iterTypes; 69 } 70 71 SmallVector<AffineMap> getIndexingMaps(Operation *op) const { 72 // TODO: implement. 73 return SmallVector<AffineMap>(); 74 } 75 76 LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands, 77 ArrayRef<MeshSharding> operandShardings, 78 ArrayRef<MeshSharding> resultShardings, 79 IRMapping &spmdizationMap, 80 SymbolTableCollection &symbolTable, 81 OpBuilder &builder) const { 82 spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings, 83 resultShardings, spmdizationMap, 84 symbolTable, builder); 85 return success(); 86 } 87 88 private: 89 void 90 populateIteratorTypes(Type t, 91 SmallVector<utils::IteratorType> &iterTypes) const { 92 RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(t); 93 if (!rankedTensorType) { 94 return; 95 } 96 97 iterTypes.reserve(iterTypes.size() + rankedTensorType.getRank()); 98 for (int64_t i = 0; i < rankedTensorType.getRank(); ++i) { 99 iterTypes.push_back(utils::IteratorType::parallel); 100 } 101 } 102 }; 103 104 // Sharding of elementwise operations like tensor addition and multiplication. 105 template <typename ElemwiseOp> 106 struct ElementwiseShardingInterface 107 : public ShardingInterface::ExternalModel< 108 ElementwiseShardingInterface<ElemwiseOp>, ElemwiseOp> { 109 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const { 110 Value val = op->getOperand(0); 111 auto type = dyn_cast<RankedTensorType>(val.getType()); 112 if (!type) 113 return {}; 114 SmallVector<utils::IteratorType> types(type.getRank(), 115 utils::IteratorType::parallel); 116 return types; 117 } 118 119 SmallVector<AffineMap> getIndexingMaps(Operation *op) const { 120 MLIRContext *ctx = op->getContext(); 121 Value val = op->getOperand(0); 122 auto type = dyn_cast<RankedTensorType>(val.getType()); 123 if (!type) 124 return {}; 125 int64_t rank = type.getRank(); 126 int64_t num = op->getNumOperands() + op->getNumResults(); 127 SmallVector<AffineMap> maps(num, 128 AffineMap::getMultiDimIdentityMap(rank, ctx)); 129 return maps; 130 } 131 132 LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands, 133 ArrayRef<MeshSharding> operandShardings, 134 ArrayRef<MeshSharding> resultShardings, 135 IRMapping &spmdizationMap, 136 SymbolTableCollection &symbolTable, 137 OpBuilder &builder) const { 138 spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings, 139 resultShardings, spmdizationMap, 140 symbolTable, builder); 141 return success(); 142 } 143 }; 144 145 } // namespace mesh 146 } // namespace mlir 147 148 #endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_ 149