1*681ae097SFrank Schlimbach //===- ShardingInterfaceImpl.cpp ------------------------------------------===// 2*681ae097SFrank Schlimbach // 3*681ae097SFrank Schlimbach // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*681ae097SFrank Schlimbach // See https://llvm.org/LICENSE.txt for license information. 5*681ae097SFrank Schlimbach // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*681ae097SFrank Schlimbach // 7*681ae097SFrank Schlimbach //===----------------------------------------------------------------------===// 8*681ae097SFrank Schlimbach 9*681ae097SFrank Schlimbach #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" 10*681ae097SFrank Schlimbach #include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h" 11*681ae097SFrank Schlimbach #include "mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h" 12*681ae097SFrank Schlimbach #include "mlir/Dialect/Tensor/IR/Tensor.h" 13*681ae097SFrank Schlimbach #include "mlir/IR/DialectRegistry.h" 14*681ae097SFrank Schlimbach #include "llvm/Support/Debug.h" 15*681ae097SFrank Schlimbach 16*681ae097SFrank Schlimbach #define DEBUG_TYPE "tensor-sharding-impl" 17*681ae097SFrank Schlimbach #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") 18*681ae097SFrank Schlimbach 19*681ae097SFrank Schlimbach using namespace mlir; 20*681ae097SFrank Schlimbach using namespace mlir::tensor; 21*681ae097SFrank Schlimbach using namespace mlir::mesh; 22*681ae097SFrank Schlimbach 23*681ae097SFrank Schlimbach namespace { 24*681ae097SFrank Schlimbach 25*681ae097SFrank Schlimbach // Sharding of tensor.empty 26*681ae097SFrank Schlimbach struct EmptyOpShardingInterface 27*681ae097SFrank Schlimbach : public ShardingInterface::ExternalModel<EmptyOpShardingInterface, 28*681ae097SFrank Schlimbach tensor::EmptyOp> { 29*681ae097SFrank Schlimbach SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const { 30*681ae097SFrank Schlimbach auto ndims = mlir::cast<ShapedType>(op->getResult(0).getType()).getRank(); 31*681ae097SFrank Schlimbach return SmallVector<utils::IteratorType>(ndims, 32*681ae097SFrank Schlimbach utils::IteratorType::parallel); 33*681ae097SFrank Schlimbach } 34*681ae097SFrank Schlimbach 35*681ae097SFrank Schlimbach SmallVector<AffineMap> getIndexingMaps(Operation *op) const { 36*681ae097SFrank Schlimbach MLIRContext *ctx = op->getContext(); 37*681ae097SFrank Schlimbach Value val = op->getResult(0); 38*681ae097SFrank Schlimbach auto type = dyn_cast<RankedTensorType>(val.getType()); 39*681ae097SFrank Schlimbach if (!type) 40*681ae097SFrank Schlimbach return {}; 41*681ae097SFrank Schlimbach return {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)}; 42*681ae097SFrank Schlimbach } 43*681ae097SFrank Schlimbach 44*681ae097SFrank Schlimbach LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands, 45*681ae097SFrank Schlimbach ArrayRef<MeshSharding> operandShardings, 46*681ae097SFrank Schlimbach ArrayRef<MeshSharding> resultShardings, 47*681ae097SFrank Schlimbach IRMapping &spmdizationMap, 48*681ae097SFrank Schlimbach SymbolTableCollection &symbolTable, 49*681ae097SFrank Schlimbach OpBuilder &builder) const { 50*681ae097SFrank Schlimbach auto shardType = cast<ShapedType>(mesh::shardType( 51*681ae097SFrank Schlimbach op->getResult(0).getType(), 52*681ae097SFrank Schlimbach mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable), 53*681ae097SFrank Schlimbach resultShardings[0])); 54*681ae097SFrank Schlimbach Operation *newOp = nullptr; 55*681ae097SFrank Schlimbach // if the sharding introduces a new dynamic dimension, we take it from 56*681ae097SFrank Schlimbach // the dynamic sharding info. For now bail out if it's not 57*681ae097SFrank Schlimbach // provided. 58*681ae097SFrank Schlimbach assert(resultShardings.size() == 1); 59*681ae097SFrank Schlimbach if (!shardType.hasStaticShape()) { 60*681ae097SFrank Schlimbach assert(op->getResult(0).hasOneUse()); 61*681ae097SFrank Schlimbach SmallVector<Value> newOperands; 62*681ae097SFrank Schlimbach auto oldType = cast<ShapedType>(op->getResult(0).getType()); 63*681ae097SFrank Schlimbach assert(oldType.getRank() == shardType.getRank()); 64*681ae097SFrank Schlimbach int currOldOprndNum = -1; 65*681ae097SFrank Schlimbach mesh::ShardShapeOp shapeForDevice; 66*681ae097SFrank Schlimbach Value device; 67*681ae097SFrank Schlimbach Operation *newSharding = nullptr; 68*681ae097SFrank Schlimbach for (auto i = 0; i < oldType.getRank(); ++i) { 69*681ae097SFrank Schlimbach if (!oldType.isDynamicDim(i) && shardType.isDynamicDim(i)) { 70*681ae097SFrank Schlimbach if (!newSharding) { 71*681ae097SFrank Schlimbach newSharding = 72*681ae097SFrank Schlimbach builder.create<ShardingOp>(op->getLoc(), resultShardings[0]); 73*681ae097SFrank Schlimbach device = builder.create<mesh::ProcessLinearIndexOp>( 74*681ae097SFrank Schlimbach op->getLoc(), resultShardings[0].getMesh()); 75*681ae097SFrank Schlimbach shapeForDevice = builder.create<mesh::ShardShapeOp>( 76*681ae097SFrank Schlimbach op->getLoc(), oldType.getShape(), newSharding->getResult(0), 77*681ae097SFrank Schlimbach device); 78*681ae097SFrank Schlimbach } 79*681ae097SFrank Schlimbach newOperands.emplace_back(shapeForDevice.getResult()[i]); 80*681ae097SFrank Schlimbach } else if (oldType.isDynamicDim(i)) { 81*681ae097SFrank Schlimbach assert(shardType.isDynamicDim(i)); 82*681ae097SFrank Schlimbach newOperands.emplace_back(spmdizedOperands[++currOldOprndNum]); 83*681ae097SFrank Schlimbach } 84*681ae097SFrank Schlimbach } 85*681ae097SFrank Schlimbach newOp = 86*681ae097SFrank Schlimbach builder.create<tensor::EmptyOp>(op->getLoc(), shardType, newOperands); 87*681ae097SFrank Schlimbach spmdizationMap.map(op->getResult(0), newOp->getResult(0)); 88*681ae097SFrank Schlimbach } else { 89*681ae097SFrank Schlimbach // `clone` will populate the mapping of old to new results. 90*681ae097SFrank Schlimbach newOp = builder.clone(*op, spmdizationMap); 91*681ae097SFrank Schlimbach } 92*681ae097SFrank Schlimbach newOp->getResult(0).setType(shardType); 93*681ae097SFrank Schlimbach 94*681ae097SFrank Schlimbach return success(); 95*681ae097SFrank Schlimbach } 96*681ae097SFrank Schlimbach }; 97*681ae097SFrank Schlimbach } // namespace 98*681ae097SFrank Schlimbach 99*681ae097SFrank Schlimbach void mlir::tensor::registerShardingInterfaceExternalModels( 100*681ae097SFrank Schlimbach DialectRegistry ®istry) { 101*681ae097SFrank Schlimbach 102*681ae097SFrank Schlimbach registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) { 103*681ae097SFrank Schlimbach EmptyOp::template attachInterface<EmptyOpShardingInterface>(*ctx); 104*681ae097SFrank Schlimbach }); 105*681ae097SFrank Schlimbach } 106