1 //===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 10 #include "mlir/Pass/Pass.h" 11 #include "mlir/Pass/PassManager.h" 12 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 13 14 using namespace mlir; 15 16 /// Custom constraint invoked from PDL. 17 static LogicalResult customSingleEntityConstraint(PatternRewriter &rewriter, 18 Operation *rootOp) { 19 return success(rootOp->getName().getStringRef() == "test.op"); 20 } 21 static LogicalResult customMultiEntityConstraint(PatternRewriter &rewriter, 22 Operation *root, 23 Operation *rootCopy) { 24 return customSingleEntityConstraint(rewriter, rootCopy); 25 } 26 static LogicalResult customMultiEntityVariadicConstraint( 27 PatternRewriter &rewriter, ValueRange operandValues, TypeRange typeValues) { 28 if (operandValues.size() != 2 || typeValues.size() != 2) 29 return failure(); 30 return success(); 31 } 32 33 // Custom constraint that returns a value if the op is named test.success_op 34 static LogicalResult customValueResultConstraint(PatternRewriter &rewriter, 35 PDLResultList &results, 36 ArrayRef<PDLValue> args) { 37 auto *op = args[0].cast<Operation *>(); 38 if (op->getName().getStringRef() == "test.success_op") { 39 StringAttr customAttr = rewriter.getStringAttr("test.success"); 40 results.push_back(customAttr); 41 return success(); 42 } 43 return failure(); 44 } 45 46 // Custom constraint that returns a type if the op is named test.success_op 47 static LogicalResult customTypeResultConstraint(PatternRewriter &rewriter, 48 PDLResultList &results, 49 ArrayRef<PDLValue> args) { 50 auto *op = args[0].cast<Operation *>(); 51 if (op->getName().getStringRef() == "test.success_op") { 52 results.push_back(rewriter.getF32Type()); 53 return success(); 54 } 55 return failure(); 56 } 57 58 // Custom constraint that returns a type range of variable length if the op is 59 // named test.success_op 60 static LogicalResult customTypeRangeResultConstraint(PatternRewriter &rewriter, 61 PDLResultList &results, 62 ArrayRef<PDLValue> args) { 63 auto *op = args[0].cast<Operation *>(); 64 int numTypes = cast<IntegerAttr>(args[1].cast<Attribute>()).getInt(); 65 66 if (op->getName().getStringRef() == "test.success_op") { 67 SmallVector<Type> types; 68 for (int i = 0; i < numTypes; i++) { 69 types.push_back(rewriter.getF32Type()); 70 } 71 results.push_back(TypeRange(types)); 72 return success(); 73 } 74 return failure(); 75 } 76 77 // Custom creator invoked from PDL. 78 static Operation *customCreate(PatternRewriter &rewriter, Operation *op) { 79 return rewriter.create(OperationState(op->getLoc(), "test.success")); 80 } 81 static auto customVariadicResultCreate(PatternRewriter &rewriter, 82 Operation *root) { 83 return std::make_pair(root->getOperands(), root->getOperands().getTypes()); 84 } 85 static Type customCreateType(PatternRewriter &rewriter) { 86 return rewriter.getF32Type(); 87 } 88 static std::string customCreateStrAttr(PatternRewriter &rewriter) { 89 return "test.str"; 90 } 91 92 /// Custom rewriter invoked from PDL. 93 static void customRewriter(PatternRewriter &rewriter, Operation *root, 94 Value input) { 95 rewriter.create(root->getLoc(), rewriter.getStringAttr("test.success"), 96 input); 97 rewriter.eraseOp(root); 98 } 99 100 namespace { 101 struct TestPDLByteCodePass 102 : public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> { 103 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPDLByteCodePass) 104 105 StringRef getArgument() const final { return "test-pdl-bytecode-pass"; } 106 StringRef getDescription() const final { 107 return "Test PDL ByteCode functionality"; 108 } 109 void getDependentDialects(DialectRegistry ®istry) const override { 110 // Mark the pdl_interp dialect as a dependent. This is needed, because we 111 // create ops from that dialect as a part of the PDL-to-PDLInterp lowering. 112 registry.insert<pdl_interp::PDLInterpDialect>(); 113 } 114 void runOnOperation() final { 115 ModuleOp module = getOperation(); 116 117 // The test cases are encompassed via two modules, one containing the 118 // patterns and one containing the operations to rewrite. 119 ModuleOp patternModule = module.lookupSymbol<ModuleOp>( 120 StringAttr::get(module->getContext(), "patterns")); 121 ModuleOp irModule = module.lookupSymbol<ModuleOp>( 122 StringAttr::get(module->getContext(), "ir")); 123 if (!patternModule || !irModule) 124 return; 125 126 RewritePatternSet patternList(module->getContext()); 127 128 // Register ahead of time to test when functions are registered without a 129 // pattern. 130 patternList.getPDLPatterns().registerConstraintFunction( 131 "multi_entity_constraint", customMultiEntityConstraint); 132 patternList.getPDLPatterns().registerConstraintFunction( 133 "single_entity_constraint", customSingleEntityConstraint); 134 135 // Process the pattern module. 136 patternModule.getOperation()->remove(); 137 PDLPatternModule pdlPattern(patternModule); 138 139 // Note: This constraint was already registered, but we re-register here to 140 // ensure that duplication registration is allowed (the duplicate mapping 141 // will be ignored). This tests that we support separating the registration 142 // of library functions from the construction of patterns, and also that we 143 // allow multiple patterns to depend on the same library functions (without 144 // asserting/crashing). 145 pdlPattern.registerConstraintFunction("multi_entity_constraint", 146 customMultiEntityConstraint); 147 pdlPattern.registerConstraintFunction("multi_entity_var_constraint", 148 customMultiEntityVariadicConstraint); 149 pdlPattern.registerConstraintFunction("op_constr_return_attr", 150 customValueResultConstraint); 151 pdlPattern.registerConstraintFunction("op_constr_return_type", 152 customTypeResultConstraint); 153 pdlPattern.registerConstraintFunction("op_constr_return_type_range", 154 customTypeRangeResultConstraint); 155 pdlPattern.registerRewriteFunction("creator", customCreate); 156 pdlPattern.registerRewriteFunction("var_creator", 157 customVariadicResultCreate); 158 pdlPattern.registerRewriteFunction("type_creator", customCreateType); 159 pdlPattern.registerRewriteFunction("str_creator", customCreateStrAttr); 160 pdlPattern.registerRewriteFunction("rewriter", customRewriter); 161 patternList.add(std::move(pdlPattern)); 162 163 // Invoke the pattern driver with the provided patterns. 164 (void)applyPatternsGreedily(irModule.getBodyRegion(), 165 std::move(patternList)); 166 } 167 }; 168 } // namespace 169 170 namespace mlir { 171 namespace test { 172 void registerTestPDLByteCodePass() { PassRegistration<TestPDLByteCodePass>(); } 173 } // namespace test 174 } // namespace mlir 175