xref: /llvm-project/mlir/test/lib/Rewrite/TestPDLByteCode.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1abfd1a8bSRiver Riddle //===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===//
2abfd1a8bSRiver Riddle //
3abfd1a8bSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4abfd1a8bSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5abfd1a8bSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6abfd1a8bSRiver Riddle //
7abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
8abfd1a8bSRiver Riddle 
9b4130e9eSStanislav Funiak #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
10abfd1a8bSRiver Riddle #include "mlir/Pass/Pass.h"
11abfd1a8bSRiver Riddle #include "mlir/Pass/PassManager.h"
12abfd1a8bSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
13abfd1a8bSRiver Riddle 
14abfd1a8bSRiver Riddle using namespace mlir;
15abfd1a8bSRiver Riddle 
16abfd1a8bSRiver Riddle /// Custom constraint invoked from PDL.
17ea64828aSRiver Riddle static LogicalResult customSingleEntityConstraint(PatternRewriter &rewriter,
18ea64828aSRiver Riddle                                                   Operation *rootOp) {
19abfd1a8bSRiver Riddle   return success(rootOp->getName().getStringRef() == "test.op");
20abfd1a8bSRiver Riddle }
21ea64828aSRiver Riddle static LogicalResult customMultiEntityConstraint(PatternRewriter &rewriter,
22ea64828aSRiver Riddle                                                  Operation *root,
23ea64828aSRiver Riddle                                                  Operation *rootCopy) {
24ea64828aSRiver Riddle   return customSingleEntityConstraint(rewriter, rootCopy);
25abfd1a8bSRiver Riddle }
26ea64828aSRiver Riddle static LogicalResult customMultiEntityVariadicConstraint(
27ea64828aSRiver Riddle     PatternRewriter &rewriter, ValueRange operandValues, TypeRange typeValues) {
2885ab413bSRiver Riddle   if (operandValues.size() != 2 || typeValues.size() != 2)
2985ab413bSRiver Riddle     return failure();
3085ab413bSRiver Riddle   return success();
3185ab413bSRiver Riddle }
32abfd1a8bSRiver Riddle 
338ec28af8SMatthias Gehre // Custom constraint that returns a value if the op is named test.success_op
348ec28af8SMatthias Gehre static LogicalResult customValueResultConstraint(PatternRewriter &rewriter,
358ec28af8SMatthias Gehre                                                  PDLResultList &results,
368ec28af8SMatthias Gehre                                                  ArrayRef<PDLValue> args) {
378ec28af8SMatthias Gehre   auto *op = args[0].cast<Operation *>();
388ec28af8SMatthias Gehre   if (op->getName().getStringRef() == "test.success_op") {
398ec28af8SMatthias Gehre     StringAttr customAttr = rewriter.getStringAttr("test.success");
408ec28af8SMatthias Gehre     results.push_back(customAttr);
418ec28af8SMatthias Gehre     return success();
428ec28af8SMatthias Gehre   }
438ec28af8SMatthias Gehre   return failure();
448ec28af8SMatthias Gehre }
458ec28af8SMatthias Gehre 
468ec28af8SMatthias Gehre // Custom constraint that returns a type if the op is named test.success_op
478ec28af8SMatthias Gehre static LogicalResult customTypeResultConstraint(PatternRewriter &rewriter,
488ec28af8SMatthias Gehre                                                 PDLResultList &results,
498ec28af8SMatthias Gehre                                                 ArrayRef<PDLValue> args) {
508ec28af8SMatthias Gehre   auto *op = args[0].cast<Operation *>();
518ec28af8SMatthias Gehre   if (op->getName().getStringRef() == "test.success_op") {
528ec28af8SMatthias Gehre     results.push_back(rewriter.getF32Type());
538ec28af8SMatthias Gehre     return success();
548ec28af8SMatthias Gehre   }
558ec28af8SMatthias Gehre   return failure();
568ec28af8SMatthias Gehre }
578ec28af8SMatthias Gehre 
588ec28af8SMatthias Gehre // Custom constraint that returns a type range of variable length if the op is
598ec28af8SMatthias Gehre // named test.success_op
608ec28af8SMatthias Gehre static LogicalResult customTypeRangeResultConstraint(PatternRewriter &rewriter,
618ec28af8SMatthias Gehre                                                      PDLResultList &results,
628ec28af8SMatthias Gehre                                                      ArrayRef<PDLValue> args) {
638ec28af8SMatthias Gehre   auto *op = args[0].cast<Operation *>();
64a5757c5bSChristian Sigg   int numTypes = cast<IntegerAttr>(args[1].cast<Attribute>()).getInt();
658ec28af8SMatthias Gehre 
668ec28af8SMatthias Gehre   if (op->getName().getStringRef() == "test.success_op") {
678ec28af8SMatthias Gehre     SmallVector<Type> types;
688ec28af8SMatthias Gehre     for (int i = 0; i < numTypes; i++) {
698ec28af8SMatthias Gehre       types.push_back(rewriter.getF32Type());
708ec28af8SMatthias Gehre     }
718ec28af8SMatthias Gehre     results.push_back(TypeRange(types));
728ec28af8SMatthias Gehre     return success();
738ec28af8SMatthias Gehre   }
748ec28af8SMatthias Gehre   return failure();
758ec28af8SMatthias Gehre }
768ec28af8SMatthias Gehre 
77abfd1a8bSRiver Riddle // Custom creator invoked from PDL.
78ea64828aSRiver Riddle static Operation *customCreate(PatternRewriter &rewriter, Operation *op) {
79ea64828aSRiver Riddle   return rewriter.create(OperationState(op->getLoc(), "test.success"));
80abfd1a8bSRiver Riddle }
81ea64828aSRiver Riddle static auto customVariadicResultCreate(PatternRewriter &rewriter,
82ea64828aSRiver Riddle                                        Operation *root) {
83ea64828aSRiver Riddle   return std::make_pair(root->getOperands(), root->getOperands().getTypes());
8485ab413bSRiver Riddle }
85ea64828aSRiver Riddle static Type customCreateType(PatternRewriter &rewriter) {
86ea64828aSRiver Riddle   return rewriter.getF32Type();
87ea64828aSRiver Riddle }
88ea64828aSRiver Riddle static std::string customCreateStrAttr(PatternRewriter &rewriter) {
89ea64828aSRiver Riddle   return "test.str";
9085ab413bSRiver Riddle }
91abfd1a8bSRiver Riddle 
92abfd1a8bSRiver Riddle /// Custom rewriter invoked from PDL.
93ea64828aSRiver Riddle static void customRewriter(PatternRewriter &rewriter, Operation *root,
94ea64828aSRiver Riddle                            Value input) {
95ea64828aSRiver Riddle   rewriter.create(root->getLoc(), rewriter.getStringAttr("test.success"),
96ea64828aSRiver Riddle                   input);
97abfd1a8bSRiver Riddle   rewriter.eraseOp(root);
98abfd1a8bSRiver Riddle }
99abfd1a8bSRiver Riddle 
100abfd1a8bSRiver Riddle namespace {
101abfd1a8bSRiver Riddle struct TestPDLByteCodePass
102abfd1a8bSRiver Riddle     : public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> {
1035e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPDLByteCodePass)
1045e50dd04SRiver Riddle 
105b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "test-pdl-bytecode-pass"; }
106b5e22e6dSMehdi Amini   StringRef getDescription() const final {
107b5e22e6dSMehdi Amini     return "Test PDL ByteCode functionality";
108b5e22e6dSMehdi Amini   }
109b4130e9eSStanislav Funiak   void getDependentDialects(DialectRegistry &registry) const override {
110b4130e9eSStanislav Funiak     // Mark the pdl_interp dialect as a dependent. This is needed, because we
111b4130e9eSStanislav Funiak     // create ops from that dialect as a part of the PDL-to-PDLInterp lowering.
112b4130e9eSStanislav Funiak     registry.insert<pdl_interp::PDLInterpDialect>();
113b4130e9eSStanislav Funiak   }
114abfd1a8bSRiver Riddle   void runOnOperation() final {
115abfd1a8bSRiver Riddle     ModuleOp module = getOperation();
116abfd1a8bSRiver Riddle 
117abfd1a8bSRiver Riddle     // The test cases are encompassed via two modules, one containing the
118abfd1a8bSRiver Riddle     // patterns and one containing the operations to rewrite.
11941d4aa7dSChris Lattner     ModuleOp patternModule = module.lookupSymbol<ModuleOp>(
12041d4aa7dSChris Lattner         StringAttr::get(module->getContext(), "patterns"));
12141d4aa7dSChris Lattner     ModuleOp irModule = module.lookupSymbol<ModuleOp>(
12241d4aa7dSChris Lattner         StringAttr::get(module->getContext(), "ir"));
123abfd1a8bSRiver Riddle     if (!patternModule || !irModule)
124abfd1a8bSRiver Riddle       return;
125abfd1a8bSRiver Riddle 
12606c3b9c7SRiver Riddle     RewritePatternSet patternList(module->getContext());
12706c3b9c7SRiver Riddle 
12806c3b9c7SRiver Riddle     // Register ahead of time to test when functions are registered without a
12906c3b9c7SRiver Riddle     // pattern.
13006c3b9c7SRiver Riddle     patternList.getPDLPatterns().registerConstraintFunction(
13106c3b9c7SRiver Riddle         "multi_entity_constraint", customMultiEntityConstraint);
13206c3b9c7SRiver Riddle     patternList.getPDLPatterns().registerConstraintFunction(
13306c3b9c7SRiver Riddle         "single_entity_constraint", customSingleEntityConstraint);
13406c3b9c7SRiver Riddle 
135abfd1a8bSRiver Riddle     // Process the pattern module.
136abfd1a8bSRiver Riddle     patternModule.getOperation()->remove();
137abfd1a8bSRiver Riddle     PDLPatternModule pdlPattern(patternModule);
13806c3b9c7SRiver Riddle 
13906c3b9c7SRiver Riddle     // Note: This constraint was already registered, but we re-register here to
14006c3b9c7SRiver Riddle     // ensure that duplication registration is allowed (the duplicate mapping
14106c3b9c7SRiver Riddle     // will be ignored). This tests that we support separating the registration
14206c3b9c7SRiver Riddle     // of library functions from the construction of patterns, and also that we
14306c3b9c7SRiver Riddle     // allow multiple patterns to depend on the same library functions (without
14406c3b9c7SRiver Riddle     // asserting/crashing).
145abfd1a8bSRiver Riddle     pdlPattern.registerConstraintFunction("multi_entity_constraint",
146abfd1a8bSRiver Riddle                                           customMultiEntityConstraint);
14785ab413bSRiver Riddle     pdlPattern.registerConstraintFunction("multi_entity_var_constraint",
14885ab413bSRiver Riddle                                           customMultiEntityVariadicConstraint);
1498ec28af8SMatthias Gehre     pdlPattern.registerConstraintFunction("op_constr_return_attr",
1508ec28af8SMatthias Gehre                                           customValueResultConstraint);
1518ec28af8SMatthias Gehre     pdlPattern.registerConstraintFunction("op_constr_return_type",
1528ec28af8SMatthias Gehre                                           customTypeResultConstraint);
1538ec28af8SMatthias Gehre     pdlPattern.registerConstraintFunction("op_constr_return_type_range",
1548ec28af8SMatthias Gehre                                           customTypeRangeResultConstraint);
15502c4c0d5SRiver Riddle     pdlPattern.registerRewriteFunction("creator", customCreate);
15685ab413bSRiver Riddle     pdlPattern.registerRewriteFunction("var_creator",
15785ab413bSRiver Riddle                                        customVariadicResultCreate);
15885ab413bSRiver Riddle     pdlPattern.registerRewriteFunction("type_creator", customCreateType);
159ea64828aSRiver Riddle     pdlPattern.registerRewriteFunction("str_creator", customCreateStrAttr);
160abfd1a8bSRiver Riddle     pdlPattern.registerRewriteFunction("rewriter", customRewriter);
16106c3b9c7SRiver Riddle     patternList.add(std::move(pdlPattern));
162abfd1a8bSRiver Riddle 
163abfd1a8bSRiver Riddle     // Invoke the pattern driver with the provided patterns.
164*09dfc571SJacques Pienaar     (void)applyPatternsGreedily(irModule.getBodyRegion(),
165abfd1a8bSRiver Riddle                                 std::move(patternList));
166abfd1a8bSRiver Riddle   }
167abfd1a8bSRiver Riddle };
168be0a7e9fSMehdi Amini } // namespace
169abfd1a8bSRiver Riddle 
170abfd1a8bSRiver Riddle namespace mlir {
171abfd1a8bSRiver Riddle namespace test {
172b5e22e6dSMehdi Amini void registerTestPDLByteCodePass() { PassRegistration<TestPDLByteCodePass>(); }
173abfd1a8bSRiver Riddle } // namespace test
174abfd1a8bSRiver Riddle } // namespace mlir
175