1abfd1a8bSRiver Riddle //===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===// 2abfd1a8bSRiver Riddle // 3abfd1a8bSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4abfd1a8bSRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 5abfd1a8bSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6abfd1a8bSRiver Riddle // 7abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 8abfd1a8bSRiver Riddle 9b4130e9eSStanislav Funiak #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 10abfd1a8bSRiver Riddle #include "mlir/Pass/Pass.h" 11abfd1a8bSRiver Riddle #include "mlir/Pass/PassManager.h" 12abfd1a8bSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 13abfd1a8bSRiver Riddle 14abfd1a8bSRiver Riddle using namespace mlir; 15abfd1a8bSRiver Riddle 16abfd1a8bSRiver Riddle /// Custom constraint invoked from PDL. 17ea64828aSRiver Riddle static LogicalResult customSingleEntityConstraint(PatternRewriter &rewriter, 18ea64828aSRiver Riddle Operation *rootOp) { 19abfd1a8bSRiver Riddle return success(rootOp->getName().getStringRef() == "test.op"); 20abfd1a8bSRiver Riddle } 21ea64828aSRiver Riddle static LogicalResult customMultiEntityConstraint(PatternRewriter &rewriter, 22ea64828aSRiver Riddle Operation *root, 23ea64828aSRiver Riddle Operation *rootCopy) { 24ea64828aSRiver Riddle return customSingleEntityConstraint(rewriter, rootCopy); 25abfd1a8bSRiver Riddle } 26ea64828aSRiver Riddle static LogicalResult customMultiEntityVariadicConstraint( 27ea64828aSRiver Riddle PatternRewriter &rewriter, ValueRange operandValues, TypeRange typeValues) { 2885ab413bSRiver Riddle if (operandValues.size() != 2 || typeValues.size() != 2) 2985ab413bSRiver Riddle return failure(); 3085ab413bSRiver Riddle return success(); 3185ab413bSRiver Riddle } 32abfd1a8bSRiver Riddle 338ec28af8SMatthias Gehre // Custom constraint that returns a value if the op is named test.success_op 348ec28af8SMatthias Gehre static LogicalResult customValueResultConstraint(PatternRewriter &rewriter, 358ec28af8SMatthias Gehre PDLResultList &results, 368ec28af8SMatthias Gehre ArrayRef<PDLValue> args) { 378ec28af8SMatthias Gehre auto *op = args[0].cast<Operation *>(); 388ec28af8SMatthias Gehre if (op->getName().getStringRef() == "test.success_op") { 398ec28af8SMatthias Gehre StringAttr customAttr = rewriter.getStringAttr("test.success"); 408ec28af8SMatthias Gehre results.push_back(customAttr); 418ec28af8SMatthias Gehre return success(); 428ec28af8SMatthias Gehre } 438ec28af8SMatthias Gehre return failure(); 448ec28af8SMatthias Gehre } 458ec28af8SMatthias Gehre 468ec28af8SMatthias Gehre // Custom constraint that returns a type if the op is named test.success_op 478ec28af8SMatthias Gehre static LogicalResult customTypeResultConstraint(PatternRewriter &rewriter, 488ec28af8SMatthias Gehre PDLResultList &results, 498ec28af8SMatthias Gehre ArrayRef<PDLValue> args) { 508ec28af8SMatthias Gehre auto *op = args[0].cast<Operation *>(); 518ec28af8SMatthias Gehre if (op->getName().getStringRef() == "test.success_op") { 528ec28af8SMatthias Gehre results.push_back(rewriter.getF32Type()); 538ec28af8SMatthias Gehre return success(); 548ec28af8SMatthias Gehre } 558ec28af8SMatthias Gehre return failure(); 568ec28af8SMatthias Gehre } 578ec28af8SMatthias Gehre 588ec28af8SMatthias Gehre // Custom constraint that returns a type range of variable length if the op is 598ec28af8SMatthias Gehre // named test.success_op 608ec28af8SMatthias Gehre static LogicalResult customTypeRangeResultConstraint(PatternRewriter &rewriter, 618ec28af8SMatthias Gehre PDLResultList &results, 628ec28af8SMatthias Gehre ArrayRef<PDLValue> args) { 638ec28af8SMatthias Gehre auto *op = args[0].cast<Operation *>(); 64a5757c5bSChristian Sigg int numTypes = cast<IntegerAttr>(args[1].cast<Attribute>()).getInt(); 658ec28af8SMatthias Gehre 668ec28af8SMatthias Gehre if (op->getName().getStringRef() == "test.success_op") { 678ec28af8SMatthias Gehre SmallVector<Type> types; 688ec28af8SMatthias Gehre for (int i = 0; i < numTypes; i++) { 698ec28af8SMatthias Gehre types.push_back(rewriter.getF32Type()); 708ec28af8SMatthias Gehre } 718ec28af8SMatthias Gehre results.push_back(TypeRange(types)); 728ec28af8SMatthias Gehre return success(); 738ec28af8SMatthias Gehre } 748ec28af8SMatthias Gehre return failure(); 758ec28af8SMatthias Gehre } 768ec28af8SMatthias Gehre 77abfd1a8bSRiver Riddle // Custom creator invoked from PDL. 78ea64828aSRiver Riddle static Operation *customCreate(PatternRewriter &rewriter, Operation *op) { 79ea64828aSRiver Riddle return rewriter.create(OperationState(op->getLoc(), "test.success")); 80abfd1a8bSRiver Riddle } 81ea64828aSRiver Riddle static auto customVariadicResultCreate(PatternRewriter &rewriter, 82ea64828aSRiver Riddle Operation *root) { 83ea64828aSRiver Riddle return std::make_pair(root->getOperands(), root->getOperands().getTypes()); 8485ab413bSRiver Riddle } 85ea64828aSRiver Riddle static Type customCreateType(PatternRewriter &rewriter) { 86ea64828aSRiver Riddle return rewriter.getF32Type(); 87ea64828aSRiver Riddle } 88ea64828aSRiver Riddle static std::string customCreateStrAttr(PatternRewriter &rewriter) { 89ea64828aSRiver Riddle return "test.str"; 9085ab413bSRiver Riddle } 91abfd1a8bSRiver Riddle 92abfd1a8bSRiver Riddle /// Custom rewriter invoked from PDL. 93ea64828aSRiver Riddle static void customRewriter(PatternRewriter &rewriter, Operation *root, 94ea64828aSRiver Riddle Value input) { 95ea64828aSRiver Riddle rewriter.create(root->getLoc(), rewriter.getStringAttr("test.success"), 96ea64828aSRiver Riddle input); 97abfd1a8bSRiver Riddle rewriter.eraseOp(root); 98abfd1a8bSRiver Riddle } 99abfd1a8bSRiver Riddle 100abfd1a8bSRiver Riddle namespace { 101abfd1a8bSRiver Riddle struct TestPDLByteCodePass 102abfd1a8bSRiver Riddle : public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> { 1035e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPDLByteCodePass) 1045e50dd04SRiver Riddle 105b5e22e6dSMehdi Amini StringRef getArgument() const final { return "test-pdl-bytecode-pass"; } 106b5e22e6dSMehdi Amini StringRef getDescription() const final { 107b5e22e6dSMehdi Amini return "Test PDL ByteCode functionality"; 108b5e22e6dSMehdi Amini } 109b4130e9eSStanislav Funiak void getDependentDialects(DialectRegistry ®istry) const override { 110b4130e9eSStanislav Funiak // Mark the pdl_interp dialect as a dependent. This is needed, because we 111b4130e9eSStanislav Funiak // create ops from that dialect as a part of the PDL-to-PDLInterp lowering. 112b4130e9eSStanislav Funiak registry.insert<pdl_interp::PDLInterpDialect>(); 113b4130e9eSStanislav Funiak } 114abfd1a8bSRiver Riddle void runOnOperation() final { 115abfd1a8bSRiver Riddle ModuleOp module = getOperation(); 116abfd1a8bSRiver Riddle 117abfd1a8bSRiver Riddle // The test cases are encompassed via two modules, one containing the 118abfd1a8bSRiver Riddle // patterns and one containing the operations to rewrite. 11941d4aa7dSChris Lattner ModuleOp patternModule = module.lookupSymbol<ModuleOp>( 12041d4aa7dSChris Lattner StringAttr::get(module->getContext(), "patterns")); 12141d4aa7dSChris Lattner ModuleOp irModule = module.lookupSymbol<ModuleOp>( 12241d4aa7dSChris Lattner StringAttr::get(module->getContext(), "ir")); 123abfd1a8bSRiver Riddle if (!patternModule || !irModule) 124abfd1a8bSRiver Riddle return; 125abfd1a8bSRiver Riddle 12606c3b9c7SRiver Riddle RewritePatternSet patternList(module->getContext()); 12706c3b9c7SRiver Riddle 12806c3b9c7SRiver Riddle // Register ahead of time to test when functions are registered without a 12906c3b9c7SRiver Riddle // pattern. 13006c3b9c7SRiver Riddle patternList.getPDLPatterns().registerConstraintFunction( 13106c3b9c7SRiver Riddle "multi_entity_constraint", customMultiEntityConstraint); 13206c3b9c7SRiver Riddle patternList.getPDLPatterns().registerConstraintFunction( 13306c3b9c7SRiver Riddle "single_entity_constraint", customSingleEntityConstraint); 13406c3b9c7SRiver Riddle 135abfd1a8bSRiver Riddle // Process the pattern module. 136abfd1a8bSRiver Riddle patternModule.getOperation()->remove(); 137abfd1a8bSRiver Riddle PDLPatternModule pdlPattern(patternModule); 13806c3b9c7SRiver Riddle 13906c3b9c7SRiver Riddle // Note: This constraint was already registered, but we re-register here to 14006c3b9c7SRiver Riddle // ensure that duplication registration is allowed (the duplicate mapping 14106c3b9c7SRiver Riddle // will be ignored). This tests that we support separating the registration 14206c3b9c7SRiver Riddle // of library functions from the construction of patterns, and also that we 14306c3b9c7SRiver Riddle // allow multiple patterns to depend on the same library functions (without 14406c3b9c7SRiver Riddle // asserting/crashing). 145abfd1a8bSRiver Riddle pdlPattern.registerConstraintFunction("multi_entity_constraint", 146abfd1a8bSRiver Riddle customMultiEntityConstraint); 14785ab413bSRiver Riddle pdlPattern.registerConstraintFunction("multi_entity_var_constraint", 14885ab413bSRiver Riddle customMultiEntityVariadicConstraint); 1498ec28af8SMatthias Gehre pdlPattern.registerConstraintFunction("op_constr_return_attr", 1508ec28af8SMatthias Gehre customValueResultConstraint); 1518ec28af8SMatthias Gehre pdlPattern.registerConstraintFunction("op_constr_return_type", 1528ec28af8SMatthias Gehre customTypeResultConstraint); 1538ec28af8SMatthias Gehre pdlPattern.registerConstraintFunction("op_constr_return_type_range", 1548ec28af8SMatthias Gehre customTypeRangeResultConstraint); 15502c4c0d5SRiver Riddle pdlPattern.registerRewriteFunction("creator", customCreate); 15685ab413bSRiver Riddle pdlPattern.registerRewriteFunction("var_creator", 15785ab413bSRiver Riddle customVariadicResultCreate); 15885ab413bSRiver Riddle pdlPattern.registerRewriteFunction("type_creator", customCreateType); 159ea64828aSRiver Riddle pdlPattern.registerRewriteFunction("str_creator", customCreateStrAttr); 160abfd1a8bSRiver Riddle pdlPattern.registerRewriteFunction("rewriter", customRewriter); 16106c3b9c7SRiver Riddle patternList.add(std::move(pdlPattern)); 162abfd1a8bSRiver Riddle 163abfd1a8bSRiver Riddle // Invoke the pattern driver with the provided patterns. 164*09dfc571SJacques Pienaar (void)applyPatternsGreedily(irModule.getBodyRegion(), 165abfd1a8bSRiver Riddle std::move(patternList)); 166abfd1a8bSRiver Riddle } 167abfd1a8bSRiver Riddle }; 168be0a7e9fSMehdi Amini } // namespace 169abfd1a8bSRiver Riddle 170abfd1a8bSRiver Riddle namespace mlir { 171abfd1a8bSRiver Riddle namespace test { 172b5e22e6dSMehdi Amini void registerTestPDLByteCodePass() { PassRegistration<TestPDLByteCodePass>(); } 173abfd1a8bSRiver Riddle } // namespace test 174abfd1a8bSRiver Riddle } // namespace mlir 175