1 //===- TestSimplification.cpp - Test simplification -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Func/IR/FuncOps.h" 10 #include "mlir/Dialect/Mesh/IR/MeshOps.h" 11 #include "mlir/Dialect/Mesh/Transforms/Spmdization.h" 12 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 13 #include "mlir/IR/BuiltinDialect.h" 14 #include "mlir/IR/BuiltinOps.h" 15 #include "mlir/IR/BuiltinTypeInterfaces.h" 16 #include "mlir/IR/Diagnostics.h" 17 #include "mlir/IR/ImplicitLocOpBuilder.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/IR/SymbolTable.h" 20 #include "mlir/IR/Value.h" 21 #include "mlir/Pass/Pass.h" 22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 23 24 using namespace mlir; 25 using namespace mlir::mesh; 26 27 namespace { 28 29 struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> { 30 using OpRewritePattern<ShardOp>::OpRewritePattern; 31 32 LogicalResult matchAndRewrite(ShardOp op, 33 PatternRewriter &rewriter) const override { 34 if (op.getAnnotateForUsers()) { 35 return failure(); 36 } 37 38 SymbolTableCollection symbolTable; 39 mesh::MeshOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>( 40 op, cast<ShardingOp>(op.getSharding().getDefiningOp()).getMeshAttr()); 41 42 bool foundUser = false; 43 for (auto user : op->getUsers()) { 44 if (auto targetShardOp = llvm::dyn_cast<ShardOp>(user)) { 45 if (targetShardOp.getAnnotateForUsers() && 46 mesh == symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>( 47 targetShardOp, 48 cast<ShardingOp>( 49 targetShardOp.getSharding().getDefiningOp()) 50 .getMeshAttr())) { 51 foundUser = true; 52 break; 53 } 54 } 55 } 56 57 if (!foundUser) { 58 return failure(); 59 } 60 61 for (auto user : op->getUsers()) { 62 auto targetShardOp = llvm::dyn_cast<ShardOp>(user); 63 if (!targetShardOp || !targetShardOp.getAnnotateForUsers() || 64 symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>( 65 targetShardOp, 66 cast<ShardingOp>(targetShardOp.getSharding().getDefiningOp()) 67 .getMeshAttr()) != mesh) { 68 continue; 69 } 70 71 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 72 ShapedType sourceShardShape = 73 shardShapedType(op.getResult().getType(), mesh, op.getSharding()); 74 TypedValue<ShapedType> sourceShard = cast<TypedValue<ShapedType>>( 75 builder 76 .create<UnrealizedConversionCastOp>(sourceShardShape, op.getSrc()) 77 ->getResult(0)); 78 TypedValue<ShapedType> targetShard = 79 reshard(builder, mesh, op, targetShardOp, sourceShard); 80 Value newTargetUnsharded = 81 builder 82 .create<UnrealizedConversionCastOp>( 83 targetShardOp.getResult().getType(), targetShard) 84 ->getResult(0); 85 rewriter.replaceAllUsesWith(targetShardOp.getResult(), 86 newTargetUnsharded); 87 } 88 89 return success(); 90 } 91 }; 92 93 struct TestMeshReshardingPass 94 : public PassWrapper<TestMeshReshardingPass, OperationPass<ModuleOp>> { 95 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshReshardingPass) 96 97 void runOnOperation() override { 98 RewritePatternSet patterns(&getContext()); 99 patterns.insert<TestMeshReshardingRewritePattern>(&getContext()); 100 if (failed(applyPatternsGreedily(getOperation().getOperation(), 101 std::move(patterns)))) { 102 return signalPassFailure(); 103 } 104 } 105 void getDependentDialects(DialectRegistry ®istry) const override { 106 reshardingRegisterDependentDialects(registry); 107 registry.insert<BuiltinDialect>(); 108 } 109 StringRef getArgument() const final { 110 return "test-mesh-resharding-spmdization"; 111 } 112 StringRef getDescription() const final { 113 return "Test Mesh dialect resharding spmdization."; 114 } 115 }; 116 } // namespace 117 118 namespace mlir { 119 namespace test { 120 void registerTestMeshReshardingSpmdizationPass() { 121 PassRegistration<TestMeshReshardingPass>(); 122 } 123 } // namespace test 124 } // namespace mlir 125