xref: /llvm-project/mlir/lib/IR/PDL/PDLPatternMatch.cpp (revision 6ae7f66ff5169ddc5a7b9ab545707042c77e036c)
1 //===- PDLPatternMatch.cpp - Base classes for PDL pattern match
2 //------------===//
3 //
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 
10 #include "mlir/IR/IRMapping.h"
11 #include "mlir/IR/Iterators.h"
12 #include "mlir/IR/PatternMatch.h"
13 #include "mlir/IR/RegionKindInterface.h"
14 
15 using namespace mlir;
16 
17 //===----------------------------------------------------------------------===//
18 // PDLValue
19 //===----------------------------------------------------------------------===//
20 
print(raw_ostream & os) const21 void PDLValue::print(raw_ostream &os) const {
22   if (!value) {
23     os << "<NULL-PDLValue>";
24     return;
25   }
26   switch (kind) {
27   case Kind::Attribute:
28     os << cast<Attribute>();
29     break;
30   case Kind::Operation:
31     os << *cast<Operation *>();
32     break;
33   case Kind::Type:
34     os << cast<Type>();
35     break;
36   case Kind::TypeRange:
37     llvm::interleaveComma(cast<TypeRange>(), os);
38     break;
39   case Kind::Value:
40     os << cast<Value>();
41     break;
42   case Kind::ValueRange:
43     llvm::interleaveComma(cast<ValueRange>(), os);
44     break;
45   }
46 }
47 
print(raw_ostream & os,Kind kind)48 void PDLValue::print(raw_ostream &os, Kind kind) {
49   switch (kind) {
50   case Kind::Attribute:
51     os << "Attribute";
52     break;
53   case Kind::Operation:
54     os << "Operation";
55     break;
56   case Kind::Type:
57     os << "Type";
58     break;
59   case Kind::TypeRange:
60     os << "TypeRange";
61     break;
62   case Kind::Value:
63     os << "Value";
64     break;
65   case Kind::ValueRange:
66     os << "ValueRange";
67     break;
68   }
69 }
70 
71 //===----------------------------------------------------------------------===//
72 // PDLPatternModule
73 //===----------------------------------------------------------------------===//
74 
mergeIn(PDLPatternModule && other)75 void PDLPatternModule::mergeIn(PDLPatternModule &&other) {
76   // Ignore the other module if it has no patterns.
77   if (!other.pdlModule)
78     return;
79 
80   // Steal the functions and config of the other module.
81   for (auto &it : other.constraintFunctions)
82     registerConstraintFunction(it.first(), std::move(it.second));
83   for (auto &it : other.rewriteFunctions)
84     registerRewriteFunction(it.first(), std::move(it.second));
85   for (auto &it : other.configs)
86     configs.emplace_back(std::move(it));
87   for (auto &it : other.configMap)
88     configMap.insert(it);
89 
90   // Steal the other state if we have no patterns.
91   if (!pdlModule) {
92     pdlModule = std::move(other.pdlModule);
93     return;
94   }
95 
96   // Merge the pattern operations from the other module into this one.
97   Block *block = pdlModule->getBody();
98   block->getOperations().splice(block->end(),
99                                 other.pdlModule->getBody()->getOperations());
100 }
101 
attachConfigToPatterns(ModuleOp module,PDLPatternConfigSet & configSet)102 void PDLPatternModule::attachConfigToPatterns(ModuleOp module,
103                                               PDLPatternConfigSet &configSet) {
104   // Attach the configuration to the symbols within the module. We only add
105   // to symbols to avoid hardcoding any specific operation names here (given
106   // that we don't depend on any PDL dialect). We can't use
107   // cast<SymbolOpInterface> here because patterns may be optional symbols.
108   module->walk([&](Operation *op) {
109     if (op->hasTrait<SymbolOpInterface::Trait>())
110       configMap[op] = &configSet;
111   });
112 }
113 
114 //===----------------------------------------------------------------------===//
115 // Function Registry
116 
registerConstraintFunction(StringRef name,PDLConstraintFunction constraintFn)117 void PDLPatternModule::registerConstraintFunction(
118     StringRef name, PDLConstraintFunction constraintFn) {
119   // TODO: Is it possible to diagnose when `name` is already registered to
120   // a function that is not equivalent to `constraintFn`?
121   // Allow existing mappings in the case multiple patterns depend on the same
122   // constraint.
123   constraintFunctions.try_emplace(name, std::move(constraintFn));
124 }
125 
registerRewriteFunction(StringRef name,PDLRewriteFunction rewriteFn)126 void PDLPatternModule::registerRewriteFunction(StringRef name,
127                                                PDLRewriteFunction rewriteFn) {
128   // TODO: Is it possible to diagnose when `name` is already registered to
129   // a function that is not equivalent to `rewriteFn`?
130   // Allow existing mappings in the case multiple patterns depend on the same
131   // rewrite.
132   rewriteFunctions.try_emplace(name, std::move(rewriteFn));
133 }
134