xref: /llvm-project/mlir/tools/mlir-tblgen/RewriterGen.cpp (revision 3e5ee82b817ce0c0188f046b5ef3b01a329d8041)
1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/TableGen/GenInfo.h"
23 #include "mlir/TableGen/Operator.h"
24 #include "mlir/TableGen/Predicate.h"
25 #include "mlir/TableGen/Type.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/StringSet.h"
28 #include "llvm/Support/CommandLine.h"
29 #include "llvm/Support/FormatVariadic.h"
30 #include "llvm/Support/PrettyStackTrace.h"
31 #include "llvm/Support/Signals.h"
32 #include "llvm/TableGen/Error.h"
33 #include "llvm/TableGen/Main.h"
34 #include "llvm/TableGen/Record.h"
35 #include "llvm/TableGen/TableGenBackend.h"
36 
37 using namespace llvm;
38 using namespace mlir;
39 
40 using mlir::tblgen::Operator;
41 using mlir::tblgen::Type;
42 
43 namespace {
44 
45 // Wrapper around dag argument.
46 struct DagArg {
47   DagArg(Init *init) : init(init) {}
48   bool isAttr();
49 
50   Init *init;
51 };
52 
53 } // end namespace
54 
55 bool DagArg::isAttr() {
56   if (auto defInit = dyn_cast<DefInit>(init))
57     return defInit->getDef()->isSubClassOf("Attr");
58   return false;
59 }
60 
61 namespace {
62 class Pattern {
63 public:
64   static void emit(StringRef rewriteName, Record *p, raw_ostream &os);
65 
66 private:
67   Pattern(Record *pattern, raw_ostream &os) : pattern(pattern), os(os) {}
68 
69   // Emit the rewrite pattern named `rewriteName`.
70   void emit(StringRef rewriteName);
71 
72   // Emit the matcher.
73   void emitMatcher(DagInit *tree);
74 
75   // Emits the value of constant attribute to `os`.
76   void emitAttributeValue(Record *constAttr);
77 
78   // Collect bound arguments.
79   void collectBoundArguments(DagInit *tree);
80 
81   // Map from bound argument name to DagArg.
82   StringMap<DagArg> boundArguments;
83 
84   // Number of the operations in the input pattern.
85   int numberOfOpsMatched = 0;
86 
87   Record *pattern;
88   raw_ostream &os;
89 };
90 } // end namespace
91 
92 void Pattern::emitAttributeValue(Record *constAttr) {
93   Record *attr = constAttr->getValueAsDef("attr");
94   auto value = constAttr->getValue("value");
95   Type type(attr->getValueAsDef("type"));
96   auto storageType = attr->getValueAsString("storageType").trim();
97 
98   // For attributes stored as strings we do not need to query builder etc.
99   if (storageType == "StringAttr") {
100     os << formatv("rewriter.getStringAttr({0})",
101                   value->getValue()->getAsString());
102     return;
103   }
104 
105   auto builder = type.getBuilderCall();
106   if (builder.empty())
107     PrintFatalError(pattern->getLoc(),
108                     "no builder specified for " + type.getTableGenDefName());
109 
110   // Construct the attribute based on storage type and builder.
111   // TODO(jpienaar): Verify the constants here
112   os << formatv("{0}::get(rewriter.{1}, {2})", storageType, builder,
113                 value->getValue()->getAsUnquotedString());
114 }
115 
116 void Pattern::collectBoundArguments(DagInit *tree) {
117   ++numberOfOpsMatched;
118   // TODO(jpienaar): Expand to multiple matches.
119   for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
120     auto arg = tree->getArg(i);
121     if (auto argTree = dyn_cast<DagInit>(arg)) {
122       collectBoundArguments(argTree);
123       continue;
124     }
125     auto name = tree->getArgNameStr(i);
126     if (name.empty())
127       continue;
128     boundArguments.try_emplace(name, arg);
129   }
130 }
131 
132 // Helper function to match patterns.
133 static void matchOp(Record *pattern, DagInit *tree, int depth,
134                     raw_ostream &os) {
135   Operator op(cast<DefInit>(tree->getOperator())->getDef());
136   int indent = 4 + 2 * depth;
137   // Skip the operand matching at depth 0 as the pattern rewriter already does.
138   if (depth != 0) {
139     // Skip if there is no defining instruction (e.g., arguments to function).
140     os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth);
141     os.indent(indent) << formatv(
142         "if (!op{0}->isa<{1}>()) return matchFailure();\n", depth,
143         op.qualifiedCppClassName());
144   }
145   if (tree->getNumArgs() != op.getNumArgs())
146     PrintFatalError(pattern->getLoc(),
147                     Twine("mismatch in number of arguments to op '") +
148                         op.getOperationName() +
149                         "' in pattern and op's definition");
150   for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
151     auto arg = tree->getArg(i);
152     auto opArg = op.getArg(i);
153 
154     if (auto argTree = dyn_cast<DagInit>(arg)) {
155       os.indent(indent) << "{\n";
156       os.indent(indent + 2) << formatv(
157           "auto op{0} = op{1}->getOperand({2})->getDefiningInst();\n",
158           depth + 1, depth, i);
159       matchOp(pattern, argTree, depth + 1, os);
160       os.indent(indent) << "}\n";
161       continue;
162     }
163 
164     // Verify arguments.
165     if (auto defInit = dyn_cast<DefInit>(arg)) {
166       // Verify operands.
167       if (auto *operand = opArg.dyn_cast<Operator::Operand *>()) {
168         // Skip verification where not needed due to definition of op.
169         if (operand->defInit == defInit)
170           goto StateCapture;
171 
172         if (!defInit->getDef()->isSubClassOf("Type"))
173           PrintFatalError(pattern->getLoc(),
174                           "type argument required for operand");
175 
176         auto pred = tblgen::Type(defInit).getPredicate();
177 
178         os.indent(indent)
179             << "if (!("
180             << formatv(pred.createTypeMatcherTemplate().c_str(),
181                        formatv("op{0}->getOperand({1})->getType()", depth, i))
182             << ")) return matchFailure();\n";
183       }
184 
185       // TODO(jpienaar): Verify attributes.
186       if (auto *attr = opArg.dyn_cast<Operator::Attribute *>()) {
187       }
188     }
189 
190   StateCapture:
191     auto name = tree->getArgNameStr(i);
192     if (name.empty())
193       continue;
194     if (opArg.is<Operator::Operand *>())
195       os.indent(indent) << "state->" << name << " = op" << depth
196                         << "->getOperand(" << i << ");\n";
197     if (auto attr = opArg.dyn_cast<Operator::Attribute *>()) {
198       os.indent(indent) << "state->" << name << " = op" << depth
199                         << "->getAttrOfType<" << attr->getStorageType()
200                         << ">(\"" << attr->getName() << "\");\n";
201     }
202   }
203 }
204 
205 void Pattern::emitMatcher(DagInit *tree) {
206   // Emit the heading.
207   os << R"(
208   PatternMatchResult match(OperationInst *op0) const override {
209     // TODO: This just handle 1 result
210     if (op0->getNumResults() != 1) return matchFailure();
211     auto state = std::make_unique<MatchedState>();)"
212      << "\n";
213   matchOp(pattern, tree, 0, os);
214   os.indent(4) << "return matchSuccess(std::move(state));\n  }\n";
215 }
216 
217 void Pattern::emit(StringRef rewriteName) {
218   DagInit *tree = pattern->getValueAsDag("PatternToMatch");
219   // Collect bound arguments and compute number of ops matched.
220   // TODO(jpienaar): the benefit metric is simply number of ops matched at the
221   // moment, revise.
222   collectBoundArguments(tree);
223 
224   // Emit RewritePattern for Pattern.
225   DefInit *root = cast<DefInit>(tree->getOperator());
226   auto *rootName = cast<StringInit>(root->getDef()->getValueInit("opName"));
227   os << formatv(R"(struct {0} : public RewritePattern {
228   {0}(MLIRContext *context) : RewritePattern({1}, {2}, context) {{})",
229                 rewriteName, rootName->getAsString(), numberOfOpsMatched)
230      << "\n";
231 
232   // Emit matched state.
233   os << "  struct MatchedState : public PatternState {\n";
234   for (auto &arg : boundArguments) {
235     if (arg.second.isAttr()) {
236       DefInit *defInit = cast<DefInit>(arg.second.init);
237       os.indent(4) << defInit->getDef()->getValueAsString("storageType").trim()
238                    << " " << arg.first() << ";\n";
239     } else {
240       os.indent(4) << "Value* " << arg.first() << ";\n";
241     }
242   }
243   os << "  };\n";
244 
245   emitMatcher(tree);
246   ListInit *resultOps = pattern->getValueAsListInit("ResultOps");
247   if (resultOps->size() != 1)
248     PrintFatalError("only single result rules supported");
249   DagInit *resultTree = cast<DagInit>(resultOps->getElement(0));
250 
251   // TODO(jpienaar): Expand to multiple results.
252   for (auto result : resultTree->getArgs()) {
253     if (isa<DagInit>(result))
254       PrintFatalError(pattern->getLoc(), "only single op result supported");
255   }
256 
257   DefInit *resultRoot = cast<DefInit>(resultTree->getOperator());
258   Operator resultOp(*resultRoot->getDef());
259   auto resultOperands = resultRoot->getDef()->getValueAsDag("arguments");
260 
261   os << formatv(R"(
262   void rewrite(OperationInst *op, std::unique_ptr<PatternState> state,
263                PatternRewriter &rewriter) const override {
264     auto& s = *static_cast<MatchedState *>(state.get());
265     rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())",
266                 resultOp.cppClassName());
267   if (resultOperands->getNumArgs() != resultTree->getNumArgs()) {
268     PrintFatalError(pattern->getLoc(),
269                     Twine("mismatch between arguments of resultant op (") +
270                         Twine(resultOperands->getNumArgs()) +
271                         ") and arguments provided for rewrite (" +
272                         Twine(resultTree->getNumArgs()) + Twine(')'));
273   }
274 
275   // Create the builder call for the result.
276   // Add operands.
277   int i = 0;
278   for (auto operand : resultOp.getOperands()) {
279     // Start each operand on its own line.
280     (os << ",\n").indent(6);
281 
282     auto name = resultTree->getArgNameStr(i);
283     if (boundArguments.find(name) == boundArguments.end())
284       PrintFatalError(pattern->getLoc(),
285                       Twine("referencing unbound variable '") + name + "'");
286     if (operand.name)
287       os << "/*" << operand.name->getAsUnquotedString() << "=*/";
288     os << "s." << name;
289     // TODO(jpienaar): verify types
290     ++i;
291   }
292 
293   // Add attributes.
294   for (int e = resultTree->getNumArgs(); i != e; ++i) {
295     // Start each attribute on its own line.
296     (os << ",\n").indent(6);
297 
298     // The argument in the result DAG pattern.
299     auto name = resultTree->getArgNameStr(i);
300     auto opName = resultOp.getArgName(i);
301     auto defInit = dyn_cast<DefInit>(resultTree->getArg(i));
302     auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr;
303     if (!value) {
304       if (boundArguments.find(name) == boundArguments.end())
305         PrintFatalError(pattern->getLoc(),
306                         Twine("referencing unbound variable '") + name + "'");
307       os << "/*" << opName << "=*/"
308          << "s." << name;
309       continue;
310     }
311 
312     // TODO(jpienaar): Refactor out into map to avoid recomputing these.
313     auto argument = resultOp.getArg(i);
314     if (!argument.is<Operator::Attribute *>())
315       PrintFatalError(pattern->getLoc(),
316                       Twine("expected attribute ") + Twine(i));
317 
318     if (!name.empty())
319       os << "/*" << name << "=*/";
320     emitAttributeValue(defInit->getDef());
321     // TODO(jpienaar): verify types
322   }
323   os << "\n    );\n  }\n};\n";
324 }
325 
326 void Pattern::emit(StringRef rewriteName, Record *p, raw_ostream &os) {
327   Pattern pattern(p, os);
328   pattern.emit(rewriteName);
329 }
330 
331 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
332   emitSourceFileHeader("Rewriters", os);
333   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
334 
335   // Ensure unique patterns simply by appending unique suffix.
336   std::string baseRewriteName = "GeneratedConvert";
337   int rewritePatternCount = 0;
338   for (Record *p : patterns) {
339     Pattern::emit(baseRewriteName + llvm::utostr(rewritePatternCount++), p, os);
340   }
341 
342   // Emit function to add the generated matchers to the pattern list.
343   os << "void populateWithGenerated(MLIRContext *context, "
344      << "OwningRewritePatternList *patterns) {\n";
345   for (unsigned i = 0; i != rewritePatternCount; ++i) {
346     os.indent(2) << "patterns->push_back(std::make_unique<" << baseRewriteName
347                  << i << ">(context));\n";
348   }
349   os << "}\n";
350 }
351 
352 mlir::GenRegistration
353     genRewriters("gen-rewriters", "Generate pattern rewriters",
354                  [](const RecordKeeper &records, raw_ostream &os) {
355                    emitRewriters(records, os);
356                    return false;
357                  });
358