xref: /llvm-project/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp (revision 681ae0972205e575ff1fd1d7ab0ef710ae364348)
//===- MeshShardingExtensions.cpp -----------------------------------------===//
2*681ae097SFrank Schlimbach //
3*681ae097SFrank Schlimbach // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*681ae097SFrank Schlimbach // See https://llvm.org/LICENSE.txt for license information.
5*681ae097SFrank Schlimbach // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*681ae097SFrank Schlimbach //
7*681ae097SFrank Schlimbach //===----------------------------------------------------------------------===//
8*681ae097SFrank Schlimbach 
9*681ae097SFrank Schlimbach #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
10*681ae097SFrank Schlimbach #include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
11*681ae097SFrank Schlimbach #include "mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h"
12*681ae097SFrank Schlimbach #include "mlir/Dialect/Tensor/IR/Tensor.h"
13*681ae097SFrank Schlimbach #include "mlir/IR/DialectRegistry.h"
14*681ae097SFrank Schlimbach #include "llvm/Support/Debug.h"
15*681ae097SFrank Schlimbach 
16*681ae097SFrank Schlimbach #define DEBUG_TYPE "tensor-sharding-impl"
17*681ae097SFrank Schlimbach #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
18*681ae097SFrank Schlimbach 
19*681ae097SFrank Schlimbach using namespace mlir;
20*681ae097SFrank Schlimbach using namespace mlir::tensor;
21*681ae097SFrank Schlimbach using namespace mlir::mesh;
22*681ae097SFrank Schlimbach 
23*681ae097SFrank Schlimbach namespace {
24*681ae097SFrank Schlimbach 
25*681ae097SFrank Schlimbach // Sharding of tensor.empty
26*681ae097SFrank Schlimbach struct EmptyOpShardingInterface
27*681ae097SFrank Schlimbach     : public ShardingInterface::ExternalModel<EmptyOpShardingInterface,
28*681ae097SFrank Schlimbach                                               tensor::EmptyOp> {
29*681ae097SFrank Schlimbach   SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
30*681ae097SFrank Schlimbach     auto ndims = mlir::cast<ShapedType>(op->getResult(0).getType()).getRank();
31*681ae097SFrank Schlimbach     return SmallVector<utils::IteratorType>(ndims,
32*681ae097SFrank Schlimbach                                             utils::IteratorType::parallel);
33*681ae097SFrank Schlimbach   }
34*681ae097SFrank Schlimbach 
35*681ae097SFrank Schlimbach   SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
36*681ae097SFrank Schlimbach     MLIRContext *ctx = op->getContext();
37*681ae097SFrank Schlimbach     Value val = op->getResult(0);
38*681ae097SFrank Schlimbach     auto type = dyn_cast<RankedTensorType>(val.getType());
39*681ae097SFrank Schlimbach     if (!type)
40*681ae097SFrank Schlimbach       return {};
41*681ae097SFrank Schlimbach     return {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)};
42*681ae097SFrank Schlimbach   }
43*681ae097SFrank Schlimbach 
44*681ae097SFrank Schlimbach   LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
45*681ae097SFrank Schlimbach                         ArrayRef<MeshSharding> operandShardings,
46*681ae097SFrank Schlimbach                         ArrayRef<MeshSharding> resultShardings,
47*681ae097SFrank Schlimbach                         IRMapping &spmdizationMap,
48*681ae097SFrank Schlimbach                         SymbolTableCollection &symbolTable,
49*681ae097SFrank Schlimbach                         OpBuilder &builder) const {
50*681ae097SFrank Schlimbach     auto shardType = cast<ShapedType>(mesh::shardType(
51*681ae097SFrank Schlimbach         op->getResult(0).getType(),
52*681ae097SFrank Schlimbach         mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable),
53*681ae097SFrank Schlimbach         resultShardings[0]));
54*681ae097SFrank Schlimbach     Operation *newOp = nullptr;
55*681ae097SFrank Schlimbach     // if the sharding introduces a new dynamic dimension, we take it from
56*681ae097SFrank Schlimbach     // the dynamic sharding info. For now bail out if it's not
57*681ae097SFrank Schlimbach     // provided.
58*681ae097SFrank Schlimbach     assert(resultShardings.size() == 1);
59*681ae097SFrank Schlimbach     if (!shardType.hasStaticShape()) {
60*681ae097SFrank Schlimbach       assert(op->getResult(0).hasOneUse());
61*681ae097SFrank Schlimbach       SmallVector<Value> newOperands;
62*681ae097SFrank Schlimbach       auto oldType = cast<ShapedType>(op->getResult(0).getType());
63*681ae097SFrank Schlimbach       assert(oldType.getRank() == shardType.getRank());
64*681ae097SFrank Schlimbach       int currOldOprndNum = -1;
65*681ae097SFrank Schlimbach       mesh::ShardShapeOp shapeForDevice;
66*681ae097SFrank Schlimbach       Value device;
67*681ae097SFrank Schlimbach       Operation *newSharding = nullptr;
68*681ae097SFrank Schlimbach       for (auto i = 0; i < oldType.getRank(); ++i) {
69*681ae097SFrank Schlimbach         if (!oldType.isDynamicDim(i) && shardType.isDynamicDim(i)) {
70*681ae097SFrank Schlimbach           if (!newSharding) {
71*681ae097SFrank Schlimbach             newSharding =
72*681ae097SFrank Schlimbach                 builder.create<ShardingOp>(op->getLoc(), resultShardings[0]);
73*681ae097SFrank Schlimbach             device = builder.create<mesh::ProcessLinearIndexOp>(
74*681ae097SFrank Schlimbach                 op->getLoc(), resultShardings[0].getMesh());
75*681ae097SFrank Schlimbach             shapeForDevice = builder.create<mesh::ShardShapeOp>(
76*681ae097SFrank Schlimbach                 op->getLoc(), oldType.getShape(), newSharding->getResult(0),
77*681ae097SFrank Schlimbach                 device);
78*681ae097SFrank Schlimbach           }
79*681ae097SFrank Schlimbach           newOperands.emplace_back(shapeForDevice.getResult()[i]);
80*681ae097SFrank Schlimbach         } else if (oldType.isDynamicDim(i)) {
81*681ae097SFrank Schlimbach           assert(shardType.isDynamicDim(i));
82*681ae097SFrank Schlimbach           newOperands.emplace_back(spmdizedOperands[++currOldOprndNum]);
83*681ae097SFrank Schlimbach         }
84*681ae097SFrank Schlimbach       }
85*681ae097SFrank Schlimbach       newOp =
86*681ae097SFrank Schlimbach           builder.create<tensor::EmptyOp>(op->getLoc(), shardType, newOperands);
87*681ae097SFrank Schlimbach       spmdizationMap.map(op->getResult(0), newOp->getResult(0));
88*681ae097SFrank Schlimbach     } else {
89*681ae097SFrank Schlimbach       // `clone` will populate the mapping of old to new results.
90*681ae097SFrank Schlimbach       newOp = builder.clone(*op, spmdizationMap);
91*681ae097SFrank Schlimbach     }
92*681ae097SFrank Schlimbach     newOp->getResult(0).setType(shardType);
93*681ae097SFrank Schlimbach 
94*681ae097SFrank Schlimbach     return success();
95*681ae097SFrank Schlimbach   }
96*681ae097SFrank Schlimbach };
97*681ae097SFrank Schlimbach } // namespace
98*681ae097SFrank Schlimbach 
99*681ae097SFrank Schlimbach void mlir::tensor::registerShardingInterfaceExternalModels(
100*681ae097SFrank Schlimbach     DialectRegistry &registry) {
101*681ae097SFrank Schlimbach 
102*681ae097SFrank Schlimbach   registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
103*681ae097SFrank Schlimbach     EmptyOp::template attachInterface<EmptyOpShardingInterface>(*ctx);
104*681ae097SFrank Schlimbach   });
105*681ae097SFrank Schlimbach }
106