//===- MeshShardingExtensions.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/DialectRegistry.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "tensor-sharding-impl"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

using namespace mlir;
using namespace mlir::tensor;
using namespace mlir::mesh;

namespace {

// Sharding interface implementation for tensor.empty.
struct EmptyOpShardingInterface
    : public ShardingInterface::ExternalModel<EmptyOpShardingInterface,
                                              tensor::EmptyOp> {
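  // Every dimension of tensor.empty is a parallel loop: the op performs no
  // computation, so each result dimension can be sharded independently.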
  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto ndims = mlir::cast<ShapedType>(op->getResult(0).getType()).getRank();
    return SmallVector<utils::IteratorType>(ndims,
                                            utils::IteratorType::parallel);
  }

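  // The single result is indexed by the identity map: loop dimension i maps
  // to result dimension i. Unranked tensors yield no maps.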
  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
    MLIRContext *ctx = op->getContext();
    Value val = op->getResult(0);
    auto type = dyn_cast<RankedTensorType>(val.getType());
    if (!type)
      return {};
    return {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)};
  }

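  // Rewrites tensor.empty to an equivalent op whose result has the per-device
  // shard shape: statically shaped shards are handled by cloning, while
  // dimensions that become dynamic get their sizes from mesh.shard_shape.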
  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
                        ArrayRef<MeshSharding> operandShardings,
                        ArrayRef<MeshSharding> resultShardings,
                        IRMapping &spmdizationMap,
                        SymbolTableCollection &symbolTable,
                        OpBuilder &builder) const {
    // Expect a single result sharding for the single result of tensor.empty.
    assert(resultShardings.size() == 1);
    auto shardType = cast<ShapedType>(mesh::shardType(
        op->getResult(0).getType(),
        mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable),
        resultShardings[0]));
    Operation *newOp = nullptr;
    // If the sharding introduces new dynamic dimensions, their per-device
    // sizes are taken from the shard-shape computation below. For now the
    // unsupported cases simply assert.
    if (!shardType.hasStaticShape()) {
      assert(op->getResult(0).hasOneUse());
      SmallVector<Value> newOperands;
      auto oldType = cast<ShapedType>(op->getResult(0).getType());
      assert(oldType.getRank() == shardType.getRank());
      int currOldOperandNum = -1;
      mesh::ShardShapeOp shapeForDevice;
      Value device;
      Operation *newSharding = nullptr;
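      // Walk the dimensions: sizes that become dynamic due to sharding are
      // taken from the per-device shard shape, while already-dynamic sizes
      // reuse the corresponding operand of the original tensor.empty.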
      for (int64_t i = 0; i < oldType.getRank(); ++i) {
        if (!oldType.isDynamicDim(i) && shardType.isDynamicDim(i)) {
          if (!newSharding) {
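            // Materialize the sharding, this device's linear index in the
            // mesh, and the resulting per-device shape exactly once.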
            newSharding =
                builder.create<ShardingOp>(op->getLoc(), resultShardings[0]);
            device = builder.create<mesh::ProcessLinearIndexOp>(
                op->getLoc(), resultShardings[0].getMesh());
            shapeForDevice = builder.create<mesh::ShardShapeOp>(
                op->getLoc(), oldType.getShape(), newSharding->getResult(0),
                device);
          }
          newOperands.emplace_back(shapeForDevice.getResult()[i]);
        } else if (oldType.isDynamicDim(i)) {
          assert(shardType.isDynamicDim(i));
          newOperands.emplace_back(spmdizedOperands[++currOldOperandNum]);
        }
      }
      newOp =
          builder.create<tensor::EmptyOp>(op->getLoc(), shardType, newOperands);
      spmdizationMap.map(op->getResult(0), newOp->getResult(0));
    } else {
      // `clone` will populate the mapping of old to new results.
      newOp = builder.clone(*op, spmdizationMap);
    }
    newOp->getResult(0).setType(shardType);

    return success();
  }
};
} // namespace

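// A minimal usage sketch (hypothetical caller, not part of this file): attach
// the external model by registering it on a DialectRegistry before creating
// the context, e.g.:
//
//   DialectRegistry registry;
//   registry.insert<tensor::TensorDialect, mesh::MeshDialect>();
//   tensor::registerShardingInterfaceExternalModels(registry);
//   MLIRContext context(registry);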
void mlir::tensor::registerShardingInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
    EmptyOp::attachInterface<EmptyOpShardingInterface>(*ctx);
  });
}