xref: /llvm-project/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
//===- TestReshardingSpmdization.cpp - Test resharding spmdization --------===//
21a8fb887SBoian Petkantchin //
31a8fb887SBoian Petkantchin // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
41a8fb887SBoian Petkantchin // See https://llvm.org/LICENSE.txt for license information.
51a8fb887SBoian Petkantchin // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
61a8fb887SBoian Petkantchin //
71a8fb887SBoian Petkantchin //===----------------------------------------------------------------------===//
81a8fb887SBoian Petkantchin 
91a8fb887SBoian Petkantchin #include "mlir/Dialect/Func/IR/FuncOps.h"
101a8fb887SBoian Petkantchin #include "mlir/Dialect/Mesh/IR/MeshOps.h"
111a8fb887SBoian Petkantchin #include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
121a8fb887SBoian Petkantchin #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
131a8fb887SBoian Petkantchin #include "mlir/IR/BuiltinDialect.h"
141a8fb887SBoian Petkantchin #include "mlir/IR/BuiltinOps.h"
151a8fb887SBoian Petkantchin #include "mlir/IR/BuiltinTypeInterfaces.h"
161a8fb887SBoian Petkantchin #include "mlir/IR/Diagnostics.h"
171a8fb887SBoian Petkantchin #include "mlir/IR/ImplicitLocOpBuilder.h"
181a8fb887SBoian Petkantchin #include "mlir/IR/PatternMatch.h"
191a8fb887SBoian Petkantchin #include "mlir/IR/SymbolTable.h"
201a8fb887SBoian Petkantchin #include "mlir/IR/Value.h"
211a8fb887SBoian Petkantchin #include "mlir/Pass/Pass.h"
221a8fb887SBoian Petkantchin #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
231a8fb887SBoian Petkantchin 
241a8fb887SBoian Petkantchin using namespace mlir;
251a8fb887SBoian Petkantchin using namespace mlir::mesh;
261a8fb887SBoian Petkantchin 
271a8fb887SBoian Petkantchin namespace {
281a8fb887SBoian Petkantchin 
291a8fb887SBoian Petkantchin struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
301a8fb887SBoian Petkantchin   using OpRewritePattern<ShardOp>::OpRewritePattern;
311a8fb887SBoian Petkantchin 
321a8fb887SBoian Petkantchin   LogicalResult matchAndRewrite(ShardOp op,
331a8fb887SBoian Petkantchin                                 PatternRewriter &rewriter) const override {
341a8fb887SBoian Petkantchin     if (op.getAnnotateForUsers()) {
351a8fb887SBoian Petkantchin       return failure();
361a8fb887SBoian Petkantchin     }
371a8fb887SBoian Petkantchin 
381a8fb887SBoian Petkantchin     SymbolTableCollection symbolTable;
399a8437f5SBoian Petkantchin     mesh::MeshOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
40baabcb28SFrank Schlimbach         op, cast<ShardingOp>(op.getSharding().getDefiningOp()).getMeshAttr());
411a8fb887SBoian Petkantchin 
421a8fb887SBoian Petkantchin     bool foundUser = false;
431a8fb887SBoian Petkantchin     for (auto user : op->getUsers()) {
441a8fb887SBoian Petkantchin       if (auto targetShardOp = llvm::dyn_cast<ShardOp>(user)) {
451a8fb887SBoian Petkantchin         if (targetShardOp.getAnnotateForUsers() &&
469a8437f5SBoian Petkantchin             mesh == symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
47baabcb28SFrank Schlimbach                         targetShardOp,
48baabcb28SFrank Schlimbach                         cast<ShardingOp>(
49baabcb28SFrank Schlimbach                             targetShardOp.getSharding().getDefiningOp())
50baabcb28SFrank Schlimbach                             .getMeshAttr())) {
511a8fb887SBoian Petkantchin           foundUser = true;
521a8fb887SBoian Petkantchin           break;
531a8fb887SBoian Petkantchin         }
541a8fb887SBoian Petkantchin       }
551a8fb887SBoian Petkantchin     }
561a8fb887SBoian Petkantchin 
571a8fb887SBoian Petkantchin     if (!foundUser) {
581a8fb887SBoian Petkantchin       return failure();
591a8fb887SBoian Petkantchin     }
601a8fb887SBoian Petkantchin 
611a8fb887SBoian Petkantchin     for (auto user : op->getUsers()) {
621a8fb887SBoian Petkantchin       auto targetShardOp = llvm::dyn_cast<ShardOp>(user);
631a8fb887SBoian Petkantchin       if (!targetShardOp || !targetShardOp.getAnnotateForUsers() ||
649a8437f5SBoian Petkantchin           symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
65baabcb28SFrank Schlimbach               targetShardOp,
66baabcb28SFrank Schlimbach               cast<ShardingOp>(targetShardOp.getSharding().getDefiningOp())
67baabcb28SFrank Schlimbach                   .getMeshAttr()) != mesh) {
681a8fb887SBoian Petkantchin         continue;
691a8fb887SBoian Petkantchin       }
701a8fb887SBoian Petkantchin 
711a8fb887SBoian Petkantchin       ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
721a8fb887SBoian Petkantchin       ShapedType sourceShardShape =
73baabcb28SFrank Schlimbach           shardShapedType(op.getResult().getType(), mesh, op.getSharding());
74a5757c5bSChristian Sigg       TypedValue<ShapedType> sourceShard = cast<TypedValue<ShapedType>>(
751a8fb887SBoian Petkantchin           builder
76baabcb28SFrank Schlimbach               .create<UnrealizedConversionCastOp>(sourceShardShape, op.getSrc())
77a5757c5bSChristian Sigg               ->getResult(0));
781a8fb887SBoian Petkantchin       TypedValue<ShapedType> targetShard =
791a8fb887SBoian Petkantchin           reshard(builder, mesh, op, targetShardOp, sourceShard);
801a8fb887SBoian Petkantchin       Value newTargetUnsharded =
811a8fb887SBoian Petkantchin           builder
821a8fb887SBoian Petkantchin               .create<UnrealizedConversionCastOp>(
831a8fb887SBoian Petkantchin                   targetShardOp.getResult().getType(), targetShard)
841a8fb887SBoian Petkantchin               ->getResult(0);
851a8fb887SBoian Petkantchin       rewriter.replaceAllUsesWith(targetShardOp.getResult(),
861a8fb887SBoian Petkantchin                                   newTargetUnsharded);
871a8fb887SBoian Petkantchin     }
881a8fb887SBoian Petkantchin 
891a8fb887SBoian Petkantchin     return success();
901a8fb887SBoian Petkantchin   }
911a8fb887SBoian Petkantchin };
921a8fb887SBoian Petkantchin 
931a8fb887SBoian Petkantchin struct TestMeshReshardingPass
941a8fb887SBoian Petkantchin     : public PassWrapper<TestMeshReshardingPass, OperationPass<ModuleOp>> {
951a8fb887SBoian Petkantchin   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshReshardingPass)
961a8fb887SBoian Petkantchin 
971a8fb887SBoian Petkantchin   void runOnOperation() override {
981a8fb887SBoian Petkantchin     RewritePatternSet patterns(&getContext());
991a8fb887SBoian Petkantchin     patterns.insert<TestMeshReshardingRewritePattern>(&getContext());
100*09dfc571SJacques Pienaar     if (failed(applyPatternsGreedily(getOperation().getOperation(),
1011a8fb887SBoian Petkantchin                                      std::move(patterns)))) {
1021a8fb887SBoian Petkantchin       return signalPassFailure();
1031a8fb887SBoian Petkantchin     }
1041a8fb887SBoian Petkantchin   }
1051a8fb887SBoian Petkantchin   void getDependentDialects(DialectRegistry &registry) const override {
1061a8fb887SBoian Petkantchin     reshardingRegisterDependentDialects(registry);
1071a8fb887SBoian Petkantchin     registry.insert<BuiltinDialect>();
1081a8fb887SBoian Petkantchin   }
1091a8fb887SBoian Petkantchin   StringRef getArgument() const final {
1101a8fb887SBoian Petkantchin     return "test-mesh-resharding-spmdization";
1111a8fb887SBoian Petkantchin   }
1121a8fb887SBoian Petkantchin   StringRef getDescription() const final {
1131a8fb887SBoian Petkantchin     return "Test Mesh dialect resharding spmdization.";
1141a8fb887SBoian Petkantchin   }
1151a8fb887SBoian Petkantchin };
1161a8fb887SBoian Petkantchin } // namespace
1171a8fb887SBoian Petkantchin 
1181a8fb887SBoian Petkantchin namespace mlir {
1191a8fb887SBoian Petkantchin namespace test {
1201a8fb887SBoian Petkantchin void registerTestMeshReshardingSpmdizationPass() {
1211a8fb887SBoian Petkantchin   PassRegistration<TestMeshReshardingPass>();
1221a8fb887SBoian Petkantchin }
1231a8fb887SBoian Petkantchin } // namespace test
1241a8fb887SBoian Petkantchin } // namespace mlir
125