xref: /llvm-project/mlir/test/lib/Rewrite/TestPDLByteCode.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
10 #include "mlir/Pass/Pass.h"
11 #include "mlir/Pass/PassManager.h"
12 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
13 
14 using namespace mlir;
15 
16 /// Custom constraint invoked from PDL.
17 static LogicalResult customSingleEntityConstraint(PatternRewriter &rewriter,
18                                                   Operation *rootOp) {
19   return success(rootOp->getName().getStringRef() == "test.op");
20 }
21 static LogicalResult customMultiEntityConstraint(PatternRewriter &rewriter,
22                                                  Operation *root,
23                                                  Operation *rootCopy) {
24   return customSingleEntityConstraint(rewriter, rootCopy);
25 }
26 static LogicalResult customMultiEntityVariadicConstraint(
27     PatternRewriter &rewriter, ValueRange operandValues, TypeRange typeValues) {
28   if (operandValues.size() != 2 || typeValues.size() != 2)
29     return failure();
30   return success();
31 }
32 
33 // Custom constraint that returns a value if the op is named test.success_op
34 static LogicalResult customValueResultConstraint(PatternRewriter &rewriter,
35                                                  PDLResultList &results,
36                                                  ArrayRef<PDLValue> args) {
37   auto *op = args[0].cast<Operation *>();
38   if (op->getName().getStringRef() == "test.success_op") {
39     StringAttr customAttr = rewriter.getStringAttr("test.success");
40     results.push_back(customAttr);
41     return success();
42   }
43   return failure();
44 }
45 
46 // Custom constraint that returns a type if the op is named test.success_op
47 static LogicalResult customTypeResultConstraint(PatternRewriter &rewriter,
48                                                 PDLResultList &results,
49                                                 ArrayRef<PDLValue> args) {
50   auto *op = args[0].cast<Operation *>();
51   if (op->getName().getStringRef() == "test.success_op") {
52     results.push_back(rewriter.getF32Type());
53     return success();
54   }
55   return failure();
56 }
57 
58 // Custom constraint that returns a type range of variable length if the op is
59 // named test.success_op
60 static LogicalResult customTypeRangeResultConstraint(PatternRewriter &rewriter,
61                                                      PDLResultList &results,
62                                                      ArrayRef<PDLValue> args) {
63   auto *op = args[0].cast<Operation *>();
64   int numTypes = cast<IntegerAttr>(args[1].cast<Attribute>()).getInt();
65 
66   if (op->getName().getStringRef() == "test.success_op") {
67     SmallVector<Type> types;
68     for (int i = 0; i < numTypes; i++) {
69       types.push_back(rewriter.getF32Type());
70     }
71     results.push_back(TypeRange(types));
72     return success();
73   }
74   return failure();
75 }
76 
77 // Custom creator invoked from PDL.
78 static Operation *customCreate(PatternRewriter &rewriter, Operation *op) {
79   return rewriter.create(OperationState(op->getLoc(), "test.success"));
80 }
81 static auto customVariadicResultCreate(PatternRewriter &rewriter,
82                                        Operation *root) {
83   return std::make_pair(root->getOperands(), root->getOperands().getTypes());
84 }
85 static Type customCreateType(PatternRewriter &rewriter) {
86   return rewriter.getF32Type();
87 }
88 static std::string customCreateStrAttr(PatternRewriter &rewriter) {
89   return "test.str";
90 }
91 
92 /// Custom rewriter invoked from PDL.
93 static void customRewriter(PatternRewriter &rewriter, Operation *root,
94                            Value input) {
95   rewriter.create(root->getLoc(), rewriter.getStringAttr("test.success"),
96                   input);
97   rewriter.eraseOp(root);
98 }
99 
100 namespace {
101 struct TestPDLByteCodePass
102     : public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> {
103   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPDLByteCodePass)
104 
105   StringRef getArgument() const final { return "test-pdl-bytecode-pass"; }
106   StringRef getDescription() const final {
107     return "Test PDL ByteCode functionality";
108   }
109   void getDependentDialects(DialectRegistry &registry) const override {
110     // Mark the pdl_interp dialect as a dependent. This is needed, because we
111     // create ops from that dialect as a part of the PDL-to-PDLInterp lowering.
112     registry.insert<pdl_interp::PDLInterpDialect>();
113   }
114   void runOnOperation() final {
115     ModuleOp module = getOperation();
116 
117     // The test cases are encompassed via two modules, one containing the
118     // patterns and one containing the operations to rewrite.
119     ModuleOp patternModule = module.lookupSymbol<ModuleOp>(
120         StringAttr::get(module->getContext(), "patterns"));
121     ModuleOp irModule = module.lookupSymbol<ModuleOp>(
122         StringAttr::get(module->getContext(), "ir"));
123     if (!patternModule || !irModule)
124       return;
125 
126     RewritePatternSet patternList(module->getContext());
127 
128     // Register ahead of time to test when functions are registered without a
129     // pattern.
130     patternList.getPDLPatterns().registerConstraintFunction(
131         "multi_entity_constraint", customMultiEntityConstraint);
132     patternList.getPDLPatterns().registerConstraintFunction(
133         "single_entity_constraint", customSingleEntityConstraint);
134 
135     // Process the pattern module.
136     patternModule.getOperation()->remove();
137     PDLPatternModule pdlPattern(patternModule);
138 
139     // Note: This constraint was already registered, but we re-register here to
140     // ensure that duplication registration is allowed (the duplicate mapping
141     // will be ignored). This tests that we support separating the registration
142     // of library functions from the construction of patterns, and also that we
143     // allow multiple patterns to depend on the same library functions (without
144     // asserting/crashing).
145     pdlPattern.registerConstraintFunction("multi_entity_constraint",
146                                           customMultiEntityConstraint);
147     pdlPattern.registerConstraintFunction("multi_entity_var_constraint",
148                                           customMultiEntityVariadicConstraint);
149     pdlPattern.registerConstraintFunction("op_constr_return_attr",
150                                           customValueResultConstraint);
151     pdlPattern.registerConstraintFunction("op_constr_return_type",
152                                           customTypeResultConstraint);
153     pdlPattern.registerConstraintFunction("op_constr_return_type_range",
154                                           customTypeRangeResultConstraint);
155     pdlPattern.registerRewriteFunction("creator", customCreate);
156     pdlPattern.registerRewriteFunction("var_creator",
157                                        customVariadicResultCreate);
158     pdlPattern.registerRewriteFunction("type_creator", customCreateType);
159     pdlPattern.registerRewriteFunction("str_creator", customCreateStrAttr);
160     pdlPattern.registerRewriteFunction("rewriter", customRewriter);
161     patternList.add(std::move(pdlPattern));
162 
163     // Invoke the pattern driver with the provided patterns.
164     (void)applyPatternsGreedily(irModule.getBodyRegion(),
165                                 std::move(patternList));
166   }
167 };
168 } // namespace
169 
170 namespace mlir {
171 namespace test {
172 void registerTestPDLByteCodePass() { PassRegistration<TestPDLByteCodePass>(); }
173 } // namespace test
174 } // namespace mlir
175