11a8fb887SBoian Petkantchin //===- TestSimplification.cpp - Test simplification -----------------------===// 21a8fb887SBoian Petkantchin // 31a8fb887SBoian Petkantchin // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 41a8fb887SBoian Petkantchin // See https://llvm.org/LICENSE.txt for license information. 51a8fb887SBoian Petkantchin // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 61a8fb887SBoian Petkantchin // 71a8fb887SBoian Petkantchin //===----------------------------------------------------------------------===// 81a8fb887SBoian Petkantchin 91a8fb887SBoian Petkantchin #include "mlir/Dialect/Func/IR/FuncOps.h" 101a8fb887SBoian Petkantchin #include "mlir/Dialect/Mesh/IR/MeshOps.h" 111a8fb887SBoian Petkantchin #include "mlir/Dialect/Mesh/Transforms/Spmdization.h" 121a8fb887SBoian Petkantchin #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 131a8fb887SBoian Petkantchin #include "mlir/IR/BuiltinDialect.h" 141a8fb887SBoian Petkantchin #include "mlir/IR/BuiltinOps.h" 151a8fb887SBoian Petkantchin #include "mlir/IR/BuiltinTypeInterfaces.h" 161a8fb887SBoian Petkantchin #include "mlir/IR/Diagnostics.h" 171a8fb887SBoian Petkantchin #include "mlir/IR/ImplicitLocOpBuilder.h" 181a8fb887SBoian Petkantchin #include "mlir/IR/PatternMatch.h" 191a8fb887SBoian Petkantchin #include "mlir/IR/SymbolTable.h" 201a8fb887SBoian Petkantchin #include "mlir/IR/Value.h" 211a8fb887SBoian Petkantchin #include "mlir/Pass/Pass.h" 221a8fb887SBoian Petkantchin #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 231a8fb887SBoian Petkantchin 241a8fb887SBoian Petkantchin using namespace mlir; 251a8fb887SBoian Petkantchin using namespace mlir::mesh; 261a8fb887SBoian Petkantchin 271a8fb887SBoian Petkantchin namespace { 281a8fb887SBoian Petkantchin 291a8fb887SBoian Petkantchin struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> { 301a8fb887SBoian Petkantchin using OpRewritePattern<ShardOp>::OpRewritePattern; 311a8fb887SBoian Petkantchin 321a8fb887SBoian Petkantchin LogicalResult matchAndRewrite(ShardOp op, 331a8fb887SBoian Petkantchin PatternRewriter &rewriter) const override { 341a8fb887SBoian Petkantchin if (op.getAnnotateForUsers()) { 351a8fb887SBoian Petkantchin return failure(); 361a8fb887SBoian Petkantchin } 371a8fb887SBoian Petkantchin 381a8fb887SBoian Petkantchin SymbolTableCollection symbolTable; 399a8437f5SBoian Petkantchin mesh::MeshOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>( 40baabcb28SFrank Schlimbach op, cast<ShardingOp>(op.getSharding().getDefiningOp()).getMeshAttr()); 411a8fb887SBoian Petkantchin 421a8fb887SBoian Petkantchin bool foundUser = false; 431a8fb887SBoian Petkantchin for (auto user : op->getUsers()) { 441a8fb887SBoian Petkantchin if (auto targetShardOp = llvm::dyn_cast<ShardOp>(user)) { 451a8fb887SBoian Petkantchin if (targetShardOp.getAnnotateForUsers() && 469a8437f5SBoian Petkantchin mesh == symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>( 47baabcb28SFrank Schlimbach targetShardOp, 48baabcb28SFrank Schlimbach cast<ShardingOp>( 49baabcb28SFrank Schlimbach targetShardOp.getSharding().getDefiningOp()) 50baabcb28SFrank Schlimbach .getMeshAttr())) { 511a8fb887SBoian Petkantchin foundUser = true; 521a8fb887SBoian Petkantchin break; 531a8fb887SBoian Petkantchin } 541a8fb887SBoian Petkantchin } 551a8fb887SBoian Petkantchin } 561a8fb887SBoian Petkantchin 571a8fb887SBoian Petkantchin if (!foundUser) { 581a8fb887SBoian Petkantchin return failure(); 591a8fb887SBoian Petkantchin } 601a8fb887SBoian Petkantchin 611a8fb887SBoian Petkantchin for (auto user : op->getUsers()) { 621a8fb887SBoian Petkantchin auto targetShardOp = llvm::dyn_cast<ShardOp>(user); 631a8fb887SBoian Petkantchin if (!targetShardOp || !targetShardOp.getAnnotateForUsers() || 649a8437f5SBoian Petkantchin symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>( 65baabcb28SFrank Schlimbach targetShardOp, 66baabcb28SFrank Schlimbach cast<ShardingOp>(targetShardOp.getSharding().getDefiningOp()) 67baabcb28SFrank Schlimbach .getMeshAttr()) != mesh) { 681a8fb887SBoian Petkantchin continue; 691a8fb887SBoian Petkantchin } 701a8fb887SBoian Petkantchin 711a8fb887SBoian Petkantchin ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 721a8fb887SBoian Petkantchin ShapedType sourceShardShape = 73baabcb28SFrank Schlimbach shardShapedType(op.getResult().getType(), mesh, op.getSharding()); 74a5757c5bSChristian Sigg TypedValue<ShapedType> sourceShard = cast<TypedValue<ShapedType>>( 751a8fb887SBoian Petkantchin builder 76baabcb28SFrank Schlimbach .create<UnrealizedConversionCastOp>(sourceShardShape, op.getSrc()) 77a5757c5bSChristian Sigg ->getResult(0)); 781a8fb887SBoian Petkantchin TypedValue<ShapedType> targetShard = 791a8fb887SBoian Petkantchin reshard(builder, mesh, op, targetShardOp, sourceShard); 801a8fb887SBoian Petkantchin Value newTargetUnsharded = 811a8fb887SBoian Petkantchin builder 821a8fb887SBoian Petkantchin .create<UnrealizedConversionCastOp>( 831a8fb887SBoian Petkantchin targetShardOp.getResult().getType(), targetShard) 841a8fb887SBoian Petkantchin ->getResult(0); 851a8fb887SBoian Petkantchin rewriter.replaceAllUsesWith(targetShardOp.getResult(), 861a8fb887SBoian Petkantchin newTargetUnsharded); 871a8fb887SBoian Petkantchin } 881a8fb887SBoian Petkantchin 891a8fb887SBoian Petkantchin return success(); 901a8fb887SBoian Petkantchin } 911a8fb887SBoian Petkantchin }; 921a8fb887SBoian Petkantchin 931a8fb887SBoian Petkantchin struct TestMeshReshardingPass 941a8fb887SBoian Petkantchin : public PassWrapper<TestMeshReshardingPass, OperationPass<ModuleOp>> { 951a8fb887SBoian Petkantchin MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshReshardingPass) 961a8fb887SBoian Petkantchin 971a8fb887SBoian Petkantchin void runOnOperation() override { 981a8fb887SBoian Petkantchin RewritePatternSet patterns(&getContext()); 991a8fb887SBoian Petkantchin patterns.insert<TestMeshReshardingRewritePattern>(&getContext()); 100*09dfc571SJacques Pienaar if (failed(applyPatternsGreedily(getOperation().getOperation(), 1011a8fb887SBoian Petkantchin std::move(patterns)))) { 1021a8fb887SBoian Petkantchin return signalPassFailure(); 1031a8fb887SBoian Petkantchin } 1041a8fb887SBoian Petkantchin } 1051a8fb887SBoian Petkantchin void getDependentDialects(DialectRegistry ®istry) const override { 1061a8fb887SBoian Petkantchin reshardingRegisterDependentDialects(registry); 1071a8fb887SBoian Petkantchin registry.insert<BuiltinDialect>(); 1081a8fb887SBoian Petkantchin } 1091a8fb887SBoian Petkantchin StringRef getArgument() const final { 1101a8fb887SBoian Petkantchin return "test-mesh-resharding-spmdization"; 1111a8fb887SBoian Petkantchin } 1121a8fb887SBoian Petkantchin StringRef getDescription() const final { 1131a8fb887SBoian Petkantchin return "Test Mesh dialect resharding spmdization."; 1141a8fb887SBoian Petkantchin } 1151a8fb887SBoian Petkantchin }; 1161a8fb887SBoian Petkantchin } // namespace 1171a8fb887SBoian Petkantchin 1181a8fb887SBoian Petkantchin namespace mlir { 1191a8fb887SBoian Petkantchin namespace test { 1201a8fb887SBoian Petkantchin void registerTestMeshReshardingSpmdizationPass() { 1211a8fb887SBoian Petkantchin PassRegistration<TestMeshReshardingPass>(); 1221a8fb887SBoian Petkantchin } 1231a8fb887SBoian Petkantchin } // namespace test 1241a8fb887SBoian Petkantchin } // namespace mlir 125