xref: /llvm-project/mlir/tools/mlir-tblgen/RewriterGen.cpp (revision 2de5e9fd196ab8783552e842192f9c7832169b78)
1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
19 //
20 //===----------------------------------------------------------------------===//
21 
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/Predicate.h"
#include "mlir/TableGen/Type.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Main.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
#include <memory>
37 
38 using namespace llvm;
39 using namespace mlir;
40 
41 using mlir::tblgen::Attribute;
42 using mlir::tblgen::Operator;
43 using mlir::tblgen::Type;
44 
45 namespace {
46 
47 // Wrapper around DAG argument.
48 struct DagArg {
49   DagArg(mlir::tblgen::Operator::Argument arg, Init *constraintInit)
50       : arg(arg), constraintInit(constraintInit) {}
51   bool isAttr();
52 
53   mlir::tblgen::Operator::Argument arg;
54   Init *constraintInit;
55 };
56 
57 } // end namespace
58 
59 bool DagArg::isAttr() { return arg.is<Operator::NamedAttribute *>(); }
60 
61 namespace {
62 class Pattern {
63 public:
64   static void emit(StringRef rewriteName, Record *p, raw_ostream &os);
65 
66 private:
67   Pattern(Record *pattern, raw_ostream &os) : pattern(pattern), os(os) {}
68 
69   // Emits the rewrite pattern named `rewriteName`.
70   void emit(StringRef rewriteName);
71 
72   // Emits the matcher.
73   void emitMatcher(DagInit *tree);
74 
75   // Emits the rewrite() method.
76   void emitRewriteMethod();
77 
78   // Emits the C++ statement to replace the matched DAG with an existing value.
79   void emitReplaceWithExistingValue(DagInit *resultTree);
80   // Emits the C++ statement to replace the matched DAG with a new op.
81   void emitReplaceOpWithNewOp(DagInit *resultTree);
82 
83   // Emits the value of constant attribute to `os`.
84   void emitAttributeValue(Record *constAttr);
85 
86   // Collects bound arguments.
87   void collectBoundArguments(DagInit *tree);
88 
89   // Checks whether an argument with the given `name` is bound in source
90   // pattern. Prints fatal error if not; does nothing otherwise.
91   void checkArgumentBound(StringRef name) const;
92 
93   // Helper function to match patterns.
94   void matchOp(DagInit *tree, int depth);
95 
96   // Returns the Operator stored for the given record.
97   Operator &getOperator(const llvm::Record *record);
98 
99   // Map from bound argument name to DagArg.
100   StringMap<DagArg> boundArguments;
101 
102   // Map from Record* to Operator.
103   DenseMap<const llvm::Record *, Operator> opMap;
104 
105   // Number of the operations in the input pattern.
106   int numberOfOpsMatched = 0;
107 
108   Record *pattern;
109   raw_ostream &os;
110 };
111 } // end namespace
112 
113 // Returns the Operator stored for the given record.
114 auto Pattern::getOperator(const llvm::Record *record) -> Operator & {
115   return opMap.try_emplace(record, record).first->second;
116 }
117 
118 void Pattern::emitAttributeValue(Record *constAttr) {
119   Attribute attr(constAttr->getValueAsDef("attr"));
120   auto value = constAttr->getValue("value");
121 
122   if (!attr.isConstBuildable())
123     PrintFatalError(pattern->getLoc(),
124                     "Attribute " + attr.getTableGenDefName() +
125                         " does not have the 'constBuilderCall' field");
126 
127   // TODO(jpienaar): Verify the constants here
128   os << formatv(attr.getConstBuilderTemplate().str().c_str(), "rewriter",
129                 value->getValue()->getAsUnquotedString());
130 }
131 
132 void Pattern::collectBoundArguments(DagInit *tree) {
133   ++numberOfOpsMatched;
134   Operator &op = getOperator(cast<DefInit>(tree->getOperator())->getDef());
135   // TODO(jpienaar): Expand to multiple matches.
136   for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
137     auto arg = tree->getArg(i);
138     if (auto argTree = dyn_cast<DagInit>(arg)) {
139       collectBoundArguments(argTree);
140       continue;
141     }
142     auto name = tree->getArgNameStr(i);
143     if (name.empty())
144       continue;
145     boundArguments.try_emplace(name, op.getArg(i), arg);
146   }
147 }
148 
149 void Pattern::checkArgumentBound(StringRef name) const {
150   if (boundArguments.find(name) == boundArguments.end())
151     PrintFatalError(pattern->getLoc(),
152                     Twine("referencing unbound variable '") + name + "'");
153 }
154 
155 // Helper function to match patterns.
156 void Pattern::matchOp(DagInit *tree, int depth) {
157   Operator &op = getOperator(cast<DefInit>(tree->getOperator())->getDef());
158   int indent = 4 + 2 * depth;
159   // Skip the operand matching at depth 0 as the pattern rewriter already does.
160   if (depth != 0) {
161     // Skip if there is no defining instruction (e.g., arguments to function).
162     os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth);
163     os.indent(indent) << formatv(
164         "if (!op{0}->isa<{1}>()) return matchFailure();\n", depth,
165         op.qualifiedCppClassName());
166   }
167   if (tree->getNumArgs() != op.getNumArgs())
168     PrintFatalError(pattern->getLoc(),
169                     Twine("mismatch in number of arguments to op '") +
170                         op.getOperationName() +
171                         "' in pattern and op's definition");
172   for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
173     auto arg = tree->getArg(i);
174     auto opArg = op.getArg(i);
175 
176     if (auto argTree = dyn_cast<DagInit>(arg)) {
177       os.indent(indent) << "{\n";
178       os.indent(indent + 2) << formatv(
179           "auto op{0} = op{1}->getOperand({2})->getDefiningInst();\n",
180           depth + 1, depth, i);
181       matchOp(argTree, depth + 1);
182       os.indent(indent) << "}\n";
183       continue;
184     }
185 
186     // Verify arguments.
187     if (auto defInit = dyn_cast<DefInit>(arg)) {
188       // Verify operands.
189       if (auto *operand = opArg.dyn_cast<Operator::Operand *>()) {
190         // Skip verification where not needed due to definition of op.
191         if (operand->defInit == defInit)
192           goto StateCapture;
193 
194         if (!defInit->getDef()->isSubClassOf("Type"))
195           PrintFatalError(pattern->getLoc(),
196                           "type argument required for operand");
197 
198         auto constraint = tblgen::TypeConstraint(*defInit);
199         os.indent(indent)
200             << "if (!("
201             << formatv(constraint.getConditionTemplate().c_str(),
202                        formatv("op{0}->getOperand({1})->getType()", depth, i))
203             << ")) return matchFailure();\n";
204       }
205 
206       // TODO(jpienaar): Verify attributes.
207       if (auto *namedAttr = opArg.dyn_cast<Operator::NamedAttribute *>()) {
208         // TODO(jpienaar): move to helper class.
209         if (defInit->getDef()->isSubClassOf("mAttr")) {
210           auto pred =
211               tblgen::Pred(defInit->getDef()->getValueInit("predicate"));
212           os.indent(indent)
213               << "if (!("
214               << formatv(pred.getCondition().c_str(),
215                          formatv("op{0}->getAttrOfType<{1}>(\"{2}\")", depth,
216                                  namedAttr->attr.getStorageType(),
217                                  namedAttr->getName()))
218               << ")) return matchFailure();\n";
219         }
220       }
221     }
222 
223   StateCapture:
224     auto name = tree->getArgNameStr(i);
225     if (name.empty())
226       continue;
227     if (opArg.is<Operator::Operand *>())
228       os.indent(indent) << "state->" << name << " = op" << depth
229                         << "->getOperand(" << i << ");\n";
230     if (auto namedAttr = opArg.dyn_cast<Operator::NamedAttribute *>()) {
231       os.indent(indent) << "state->" << name << " = op" << depth
232                         << "->getAttrOfType<"
233                         << namedAttr->attr.getStorageType() << ">(\""
234                         << namedAttr->getName() << "\");\n";
235     }
236   }
237 }
238 
239 void Pattern::emitMatcher(DagInit *tree) {
240   // Emit the heading.
241   os << R"(
242   PatternMatchResult match(OperationInst *op0) const override {
243     // TODO: This just handle 1 result
244     if (op0->getNumResults() != 1) return matchFailure();
245     auto state = std::make_unique<MatchedState>();)"
246      << "\n";
247   matchOp(tree, 0);
248   os.indent(4) << "return matchSuccess(std::move(state));\n  }\n";
249 }
250 
251 void Pattern::emit(StringRef rewriteName) {
252   DagInit *tree = pattern->getValueAsDag("PatternToMatch");
253   // Collect bound arguments and compute number of ops matched.
254   // TODO(jpienaar): the benefit metric is simply number of ops matched at the
255   // moment, revise.
256   collectBoundArguments(tree);
257 
258   // Emit RewritePattern for Pattern.
259   DefInit *root = cast<DefInit>(tree->getOperator());
260   auto *rootName = cast<StringInit>(root->getDef()->getValueInit("opName"));
261   os << formatv(R"(struct {0} : public RewritePattern {
262   {0}(MLIRContext *context) : RewritePattern({1}, {2}, context) {{})",
263                 rewriteName, rootName->getAsString(), numberOfOpsMatched)
264      << "\n";
265 
266   // Emit matched state.
267   os << "  struct MatchedState : public PatternState {\n";
268   for (auto &arg : boundArguments) {
269     if (auto namedAttr =
270             arg.second.arg.dyn_cast<Operator::NamedAttribute *>()) {
271       os.indent(4) << namedAttr->attr.getStorageType() << " " << arg.first()
272                    << ";\n";
273     } else {
274       os.indent(4) << "Value* " << arg.first() << ";\n";
275     }
276   }
277   os << "  };\n";
278 
279   emitMatcher(tree);
280   emitRewriteMethod();
281 
282   os << "};\n";
283 }
284 
285 void Pattern::emitRewriteMethod() {
286   ListInit *resultOps = pattern->getValueAsListInit("ResultOps");
287   if (resultOps->size() != 1)
288     PrintFatalError("only single result rules supported");
289   DagInit *resultTree = cast<DagInit>(resultOps->getElement(0));
290 
291   // TODO(jpienaar): Expand to multiple results.
292   for (auto result : resultTree->getArgs()) {
293     if (isa<DagInit>(result))
294       PrintFatalError(pattern->getLoc(), "only single op result supported");
295   }
296 
297   os << R"(
298   void rewrite(OperationInst *op, std::unique_ptr<PatternState> state,
299                PatternRewriter &rewriter) const override {
300     auto& s = *static_cast<MatchedState *>(state.get());
301 )";
302 
303   auto *dagOpDef = cast<DefInit>(resultTree->getOperator())->getDef();
304   if (dagOpDef->getName() == "replaceWithValue")
305     emitReplaceWithExistingValue(resultTree);
306   else
307     emitReplaceOpWithNewOp(resultTree);
308 
309   os << "  }\n";
310 }
311 
312 void Pattern::emitReplaceWithExistingValue(DagInit *resultTree) {
313   if (resultTree->getNumArgs() != 1) {
314     PrintFatalError(pattern->getLoc(),
315                     "exactly one argument needed in the result pattern");
316   }
317 
318   auto name = resultTree->getArgNameStr(0);
319   checkArgumentBound(name);
320   os.indent(4) << "rewriter.replaceOp(op, {s." << name << "});\n";
321 }
322 
323 void Pattern::emitReplaceOpWithNewOp(DagInit *resultTree) {
324   DefInit *dagOperator = cast<DefInit>(resultTree->getOperator());
325   Operator &resultOp = getOperator(dagOperator->getDef());
326   auto resultOperands = dagOperator->getDef()->getValueAsDag("arguments");
327 
328   os << formatv(R"(
329     rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())",
330                 resultOp.cppClassName());
331   if (resultOperands->getNumArgs() != resultTree->getNumArgs()) {
332     PrintFatalError(pattern->getLoc(),
333                     Twine("mismatch between arguments of resultant op (") +
334                         Twine(resultOperands->getNumArgs()) +
335                         ") and arguments provided for rewrite (" +
336                         Twine(resultTree->getNumArgs()) + Twine(')'));
337   }
338 
339   // Create the builder call for the result.
340   // Add operands.
341   int i = 0;
342   for (auto operand : resultOp.getOperands()) {
343     // Start each operand on its own line.
344     (os << ",\n").indent(6);
345 
346     auto name = resultTree->getArgNameStr(i);
347     checkArgumentBound(name);
348     if (operand.name)
349       os << "/*" << operand.name->getAsUnquotedString() << "=*/";
350     os << "s." << name;
351     // TODO(jpienaar): verify types
352     ++i;
353   }
354 
355   // Add attributes.
356   for (int e = resultTree->getNumArgs(); i != e; ++i) {
357     // Start each attribute on its own line.
358     (os << ",\n").indent(6);
359 
360     // The argument in the result DAG pattern.
361     auto name = resultTree->getArgNameStr(i);
362     auto opName = resultOp.getArgName(i);
363     auto defInit = dyn_cast<DefInit>(resultTree->getArg(i));
364     auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr;
365     if (!value) {
366       checkArgumentBound(name);
367       auto result = "s." + name;
368       os << "/*" << opName << "=*/";
369       if (defInit) {
370         auto transform = defInit->getDef();
371         if (transform->isSubClassOf("tAttr")) {
372           // TODO(jpienaar): move to helper class.
373           os << formatv(
374               transform->getValueAsString("attrTransform").str().c_str(),
375               result);
376           continue;
377         }
378       }
379       os << result;
380       continue;
381     }
382 
383     // TODO(jpienaar): Refactor out into map to avoid recomputing these.
384     auto argument = resultOp.getArg(i);
385     if (!argument.is<Operator::NamedAttribute *>())
386       PrintFatalError(pattern->getLoc(),
387                       Twine("expected attribute ") + Twine(i));
388 
389     if (!name.empty())
390       os << "/*" << name << "=*/";
391     emitAttributeValue(defInit->getDef());
392     // TODO(jpienaar): verify types
393   }
394   os << "\n    );\n";
395 }
396 
397 void Pattern::emit(StringRef rewriteName, Record *p, raw_ostream &os) {
398   Pattern pattern(p, os);
399   pattern.emit(rewriteName);
400 }
401 
402 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
403   emitSourceFileHeader("Rewriters", os);
404   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
405 
406   // Ensure unique patterns simply by appending unique suffix.
407   std::string baseRewriteName = "GeneratedConvert";
408   int rewritePatternCount = 0;
409   for (Record *p : patterns) {
410     Pattern::emit(baseRewriteName + llvm::utostr(rewritePatternCount++), p, os);
411   }
412 
413   // Emit function to add the generated matchers to the pattern list.
414   os << "void populateWithGenerated(MLIRContext *context, "
415      << "OwningRewritePatternList *patterns) {\n";
416   for (unsigned i = 0; i != rewritePatternCount; ++i) {
417     os.indent(2) << "patterns->push_back(std::make_unique<" << baseRewriteName
418                  << i << ">(context));\n";
419   }
420   os << "}\n";
421 }
422 
423 static mlir::GenRegistration
424     genRewriters("gen-rewriters", "Generate pattern rewriters",
425                  [](const RecordKeeper &records, raw_ostream &os) {
426                    emitRewriters(records, os);
427                    return false;
428                  });
429