xref: /llvm-project/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
//===- TestReshardingSpmdization.cpp - Test resharding spmdization --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Func/IR/FuncOps.h"
10 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
11 #include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
12 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
13 #include "mlir/IR/BuiltinDialect.h"
14 #include "mlir/IR/BuiltinOps.h"
15 #include "mlir/IR/BuiltinTypeInterfaces.h"
16 #include "mlir/IR/Diagnostics.h"
17 #include "mlir/IR/ImplicitLocOpBuilder.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/SymbolTable.h"
20 #include "mlir/IR/Value.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 
24 using namespace mlir;
25 using namespace mlir::mesh;
26 
27 namespace {
28 
29 struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
30   using OpRewritePattern<ShardOp>::OpRewritePattern;
31 
32   LogicalResult matchAndRewrite(ShardOp op,
33                                 PatternRewriter &rewriter) const override {
34     if (op.getAnnotateForUsers()) {
35       return failure();
36     }
37 
38     SymbolTableCollection symbolTable;
39     mesh::MeshOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
40         op, cast<ShardingOp>(op.getSharding().getDefiningOp()).getMeshAttr());
41 
42     bool foundUser = false;
43     for (auto user : op->getUsers()) {
44       if (auto targetShardOp = llvm::dyn_cast<ShardOp>(user)) {
45         if (targetShardOp.getAnnotateForUsers() &&
46             mesh == symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
47                         targetShardOp,
48                         cast<ShardingOp>(
49                             targetShardOp.getSharding().getDefiningOp())
50                             .getMeshAttr())) {
51           foundUser = true;
52           break;
53         }
54       }
55     }
56 
57     if (!foundUser) {
58       return failure();
59     }
60 
61     for (auto user : op->getUsers()) {
62       auto targetShardOp = llvm::dyn_cast<ShardOp>(user);
63       if (!targetShardOp || !targetShardOp.getAnnotateForUsers() ||
64           symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
65               targetShardOp,
66               cast<ShardingOp>(targetShardOp.getSharding().getDefiningOp())
67                   .getMeshAttr()) != mesh) {
68         continue;
69       }
70 
71       ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
72       ShapedType sourceShardShape =
73           shardShapedType(op.getResult().getType(), mesh, op.getSharding());
74       TypedValue<ShapedType> sourceShard = cast<TypedValue<ShapedType>>(
75           builder
76               .create<UnrealizedConversionCastOp>(sourceShardShape, op.getSrc())
77               ->getResult(0));
78       TypedValue<ShapedType> targetShard =
79           reshard(builder, mesh, op, targetShardOp, sourceShard);
80       Value newTargetUnsharded =
81           builder
82               .create<UnrealizedConversionCastOp>(
83                   targetShardOp.getResult().getType(), targetShard)
84               ->getResult(0);
85       rewriter.replaceAllUsesWith(targetShardOp.getResult(),
86                                   newTargetUnsharded);
87     }
88 
89     return success();
90   }
91 };
92 
93 struct TestMeshReshardingPass
94     : public PassWrapper<TestMeshReshardingPass, OperationPass<ModuleOp>> {
95   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshReshardingPass)
96 
97   void runOnOperation() override {
98     RewritePatternSet patterns(&getContext());
99     patterns.insert<TestMeshReshardingRewritePattern>(&getContext());
100     if (failed(applyPatternsGreedily(getOperation().getOperation(),
101                                      std::move(patterns)))) {
102       return signalPassFailure();
103     }
104   }
105   void getDependentDialects(DialectRegistry &registry) const override {
106     reshardingRegisterDependentDialects(registry);
107     registry.insert<BuiltinDialect>();
108   }
109   StringRef getArgument() const final {
110     return "test-mesh-resharding-spmdization";
111   }
112   StringRef getDescription() const final {
113     return "Test Mesh dialect resharding spmdization.";
114   }
115 };
116 } // namespace
117 
118 namespace mlir {
119 namespace test {
120 void registerTestMeshReshardingSpmdizationPass() {
121   PassRegistration<TestMeshReshardingPass>();
122 }
123 } // namespace test
124 } // namespace mlir
125