1 //===- ShardingInterfaceImpl.cpp ------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" 10 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h" 11 #include "mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h" 12 #include "mlir/Dialect/Tensor/IR/Tensor.h" 13 #include "mlir/IR/DialectRegistry.h" 14 #include "llvm/Support/Debug.h" 15 16 #define DEBUG_TYPE "tensor-sharding-impl" 17 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") 18 19 using namespace mlir; 20 using namespace mlir::tensor; 21 using namespace mlir::mesh; 22 23 namespace { 24 25 // Sharding of tensor.empty 26 struct EmptyOpShardingInterface 27 : public ShardingInterface::ExternalModel<EmptyOpShardingInterface, 28 tensor::EmptyOp> { 29 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const { 30 auto ndims = mlir::cast<ShapedType>(op->getResult(0).getType()).getRank(); 31 return SmallVector<utils::IteratorType>(ndims, 32 utils::IteratorType::parallel); 33 } 34 35 SmallVector<AffineMap> getIndexingMaps(Operation *op) const { 36 MLIRContext *ctx = op->getContext(); 37 Value val = op->getResult(0); 38 auto type = dyn_cast<RankedTensorType>(val.getType()); 39 if (!type) 40 return {}; 41 return {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)}; 42 } 43 44 LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands, 45 ArrayRef<MeshSharding> operandShardings, 46 ArrayRef<MeshSharding> resultShardings, 47 IRMapping &spmdizationMap, 48 SymbolTableCollection &symbolTable, 49 OpBuilder &builder) const { 50 auto shardType = cast<ShapedType>(mesh::shardType( 51 op->getResult(0).getType(), 52 mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable), 53 resultShardings[0])); 54 Operation *newOp = nullptr; 55 // if the sharding introduces a new dynamic dimension, we take it from 56 // the dynamic sharding info. For now bail out if it's not 57 // provided. 58 assert(resultShardings.size() == 1); 59 if (!shardType.hasStaticShape()) { 60 assert(op->getResult(0).hasOneUse()); 61 SmallVector<Value> newOperands; 62 auto oldType = cast<ShapedType>(op->getResult(0).getType()); 63 assert(oldType.getRank() == shardType.getRank()); 64 int currOldOprndNum = -1; 65 mesh::ShardShapeOp shapeForDevice; 66 Value device; 67 Operation *newSharding = nullptr; 68 for (auto i = 0; i < oldType.getRank(); ++i) { 69 if (!oldType.isDynamicDim(i) && shardType.isDynamicDim(i)) { 70 if (!newSharding) { 71 newSharding = 72 builder.create<ShardingOp>(op->getLoc(), resultShardings[0]); 73 device = builder.create<mesh::ProcessLinearIndexOp>( 74 op->getLoc(), resultShardings[0].getMesh()); 75 shapeForDevice = builder.create<mesh::ShardShapeOp>( 76 op->getLoc(), oldType.getShape(), newSharding->getResult(0), 77 device); 78 } 79 newOperands.emplace_back(shapeForDevice.getResult()[i]); 80 } else if (oldType.isDynamicDim(i)) { 81 assert(shardType.isDynamicDim(i)); 82 newOperands.emplace_back(spmdizedOperands[++currOldOprndNum]); 83 } 84 } 85 newOp = 86 builder.create<tensor::EmptyOp>(op->getLoc(), shardType, newOperands); 87 spmdizationMap.map(op->getResult(0), newOp->getResult(0)); 88 } else { 89 // `clone` will populate the mapping of old to new results. 90 newOp = builder.clone(*op, spmdizationMap); 91 } 92 newOp->getResult(0).setType(shardType); 93 94 return success(); 95 } 96 }; 97 } // namespace 98 99 void mlir::tensor::registerShardingInterfaceExternalModels( 100 DialectRegistry ®istry) { 101 102 registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) { 103 EmptyOp::template attachInterface<EmptyOpShardingInterface>(*ctx); 104 }); 105 } 106