xref: /llvm-project/mlir/tools/mlir-tblgen/RewriterGen.cpp (revision 9b034f0bfd641799c5ae16acd4fa53093a83c510)
1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/TableGen/GenInfo.h"
23 #include "mlir/TableGen/Operator.h"
24 #include "mlir/TableGen/Predicate.h"
25 #include "mlir/TableGen/Type.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/StringSet.h"
28 #include "llvm/Support/CommandLine.h"
29 #include "llvm/Support/FormatVariadic.h"
30 #include "llvm/Support/PrettyStackTrace.h"
31 #include "llvm/Support/Signals.h"
32 #include "llvm/TableGen/Error.h"
33 #include "llvm/TableGen/Main.h"
34 #include "llvm/TableGen/Record.h"
35 #include "llvm/TableGen/TableGenBackend.h"
36 
37 using namespace llvm;
38 using namespace mlir;
39 
40 using mlir::tblgen::Attribute;
41 using mlir::tblgen::Operator;
42 using mlir::tblgen::Type;
43 
44 namespace {
45 
46 // Wrapper around dag argument.
47 struct DagArg {
48   DagArg(Init *init) : init(init) {}
49   bool isAttr();
50 
51   Init *init;
52 };
53 
54 } // end namespace
55 
56 bool DagArg::isAttr() {
57   if (auto defInit = dyn_cast<DefInit>(init))
58     return defInit->getDef()->isSubClassOf("Attr");
59   return false;
60 }
61 
62 namespace {
63 class Pattern {
64 public:
65   static void emit(StringRef rewriteName, Record *p, raw_ostream &os);
66 
67 private:
68   Pattern(Record *pattern, raw_ostream &os) : pattern(pattern), os(os) {}
69 
70   // Emit the rewrite pattern named `rewriteName`.
71   void emit(StringRef rewriteName);
72 
73   // Emit the matcher.
74   void emitMatcher(DagInit *tree);
75 
76   // Emits the value of constant attribute to `os`.
77   void emitAttributeValue(Record *constAttr);
78 
79   // Collect bound arguments.
80   void collectBoundArguments(DagInit *tree);
81 
82   // Map from bound argument name to DagArg.
83   StringMap<DagArg> boundArguments;
84 
85   // Number of the operations in the input pattern.
86   int numberOfOpsMatched = 0;
87 
88   Record *pattern;
89   raw_ostream &os;
90 };
91 } // end namespace
92 
93 void Pattern::emitAttributeValue(Record *constAttr) {
94   Attribute attr(constAttr->getValueAsDef("attr"));
95   auto value = constAttr->getValue("value");
96   Type type = attr.getType();
97   auto storageType = attr.getStorageType();
98 
99   // For attributes stored as strings we do not need to query builder etc.
100   if (storageType == "StringAttr") {
101     os << formatv("rewriter.getStringAttr({0})",
102                   value->getValue()->getAsString());
103     return;
104   }
105 
106   auto builder = type.getBuilderCall();
107   if (builder.empty())
108     PrintFatalError(pattern->getLoc(),
109                     "no builder specified for " + type.getTableGenDefName());
110 
111   // Construct the attribute based on storage type and builder.
112   // TODO(jpienaar): Verify the constants here
113   os << formatv("{0}::get(rewriter.{1}, {2})", storageType, builder,
114                 value->getValue()->getAsUnquotedString());
115 }
116 
117 void Pattern::collectBoundArguments(DagInit *tree) {
118   ++numberOfOpsMatched;
119   // TODO(jpienaar): Expand to multiple matches.
120   for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
121     auto arg = tree->getArg(i);
122     if (auto argTree = dyn_cast<DagInit>(arg)) {
123       collectBoundArguments(argTree);
124       continue;
125     }
126     auto name = tree->getArgNameStr(i);
127     if (name.empty())
128       continue;
129     boundArguments.try_emplace(name, arg);
130   }
131 }
132 
133 // Helper function to match patterns.
134 static void matchOp(Record *pattern, DagInit *tree, int depth,
135                     raw_ostream &os) {
136   Operator op(cast<DefInit>(tree->getOperator())->getDef());
137   int indent = 4 + 2 * depth;
138   // Skip the operand matching at depth 0 as the pattern rewriter already does.
139   if (depth != 0) {
140     // Skip if there is no defining instruction (e.g., arguments to function).
141     os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth);
142     os.indent(indent) << formatv(
143         "if (!op{0}->isa<{1}>()) return matchFailure();\n", depth,
144         op.qualifiedCppClassName());
145   }
146   if (tree->getNumArgs() != op.getNumArgs())
147     PrintFatalError(pattern->getLoc(),
148                     Twine("mismatch in number of arguments to op '") +
149                         op.getOperationName() +
150                         "' in pattern and op's definition");
151   for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
152     auto arg = tree->getArg(i);
153     auto opArg = op.getArg(i);
154 
155     if (auto argTree = dyn_cast<DagInit>(arg)) {
156       os.indent(indent) << "{\n";
157       os.indent(indent + 2) << formatv(
158           "auto op{0} = op{1}->getOperand({2})->getDefiningInst();\n",
159           depth + 1, depth, i);
160       matchOp(pattern, argTree, depth + 1, os);
161       os.indent(indent) << "}\n";
162       continue;
163     }
164 
165     // Verify arguments.
166     if (auto defInit = dyn_cast<DefInit>(arg)) {
167       // Verify operands.
168       if (auto *operand = opArg.dyn_cast<Operator::Operand *>()) {
169         // Skip verification where not needed due to definition of op.
170         if (operand->defInit == defInit)
171           goto StateCapture;
172 
173         if (!defInit->getDef()->isSubClassOf("Type"))
174           PrintFatalError(pattern->getLoc(),
175                           "type argument required for operand");
176 
177         auto pred = tblgen::Type(defInit).getPredicate();
178 
179         os.indent(indent)
180             << "if (!("
181             << formatv(pred.createTypeMatcherTemplate().c_str(),
182                        formatv("op{0}->getOperand({1})->getType()", depth, i))
183             << ")) return matchFailure();\n";
184       }
185 
186       // TODO(jpienaar): Verify attributes.
187       if (auto *attr = opArg.dyn_cast<Operator::NamedAttribute *>()) {
188       }
189     }
190 
191   StateCapture:
192     auto name = tree->getArgNameStr(i);
193     if (name.empty())
194       continue;
195     if (opArg.is<Operator::Operand *>())
196       os.indent(indent) << "state->" << name << " = op" << depth
197                         << "->getOperand(" << i << ");\n";
198     if (auto namedAttr = opArg.dyn_cast<Operator::NamedAttribute *>()) {
199       os.indent(indent) << "state->" << name << " = op" << depth
200                         << "->getAttrOfType<"
201                         << namedAttr->attr.getStorageType() << ">(\""
202                         << namedAttr->getName() << "\");\n";
203     }
204   }
205 }
206 
207 void Pattern::emitMatcher(DagInit *tree) {
208   // Emit the heading.
209   os << R"(
210   PatternMatchResult match(OperationInst *op0) const override {
211     // TODO: This just handle 1 result
212     if (op0->getNumResults() != 1) return matchFailure();
213     auto state = std::make_unique<MatchedState>();)"
214      << "\n";
215   matchOp(pattern, tree, 0, os);
216   os.indent(4) << "return matchSuccess(std::move(state));\n  }\n";
217 }
218 
219 void Pattern::emit(StringRef rewriteName) {
220   DagInit *tree = pattern->getValueAsDag("PatternToMatch");
221   // Collect bound arguments and compute number of ops matched.
222   // TODO(jpienaar): the benefit metric is simply number of ops matched at the
223   // moment, revise.
224   collectBoundArguments(tree);
225 
226   // Emit RewritePattern for Pattern.
227   DefInit *root = cast<DefInit>(tree->getOperator());
228   auto *rootName = cast<StringInit>(root->getDef()->getValueInit("opName"));
229   os << formatv(R"(struct {0} : public RewritePattern {
230   {0}(MLIRContext *context) : RewritePattern({1}, {2}, context) {{})",
231                 rewriteName, rootName->getAsString(), numberOfOpsMatched)
232      << "\n";
233 
234   // Emit matched state.
235   os << "  struct MatchedState : public PatternState {\n";
236   for (auto &arg : boundArguments) {
237     if (arg.second.isAttr()) {
238       DefInit *defInit = cast<DefInit>(arg.second.init);
239       os.indent(4) << Attribute(defInit).getStorageType() << " " << arg.first()
240                    << ";\n";
241     } else {
242       os.indent(4) << "Value* " << arg.first() << ";\n";
243     }
244   }
245   os << "  };\n";
246 
247   emitMatcher(tree);
248   ListInit *resultOps = pattern->getValueAsListInit("ResultOps");
249   if (resultOps->size() != 1)
250     PrintFatalError("only single result rules supported");
251   DagInit *resultTree = cast<DagInit>(resultOps->getElement(0));
252 
253   // TODO(jpienaar): Expand to multiple results.
254   for (auto result : resultTree->getArgs()) {
255     if (isa<DagInit>(result))
256       PrintFatalError(pattern->getLoc(), "only single op result supported");
257   }
258 
259   DefInit *resultRoot = cast<DefInit>(resultTree->getOperator());
260   Operator resultOp(*resultRoot->getDef());
261   auto resultOperands = resultRoot->getDef()->getValueAsDag("arguments");
262 
263   os << formatv(R"(
264   void rewrite(OperationInst *op, std::unique_ptr<PatternState> state,
265                PatternRewriter &rewriter) const override {
266     auto& s = *static_cast<MatchedState *>(state.get());
267     rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())",
268                 resultOp.cppClassName());
269   if (resultOperands->getNumArgs() != resultTree->getNumArgs()) {
270     PrintFatalError(pattern->getLoc(),
271                     Twine("mismatch between arguments of resultant op (") +
272                         Twine(resultOperands->getNumArgs()) +
273                         ") and arguments provided for rewrite (" +
274                         Twine(resultTree->getNumArgs()) + Twine(')'));
275   }
276 
277   // Create the builder call for the result.
278   // Add operands.
279   int i = 0;
280   for (auto operand : resultOp.getOperands()) {
281     // Start each operand on its own line.
282     (os << ",\n").indent(6);
283 
284     auto name = resultTree->getArgNameStr(i);
285     if (boundArguments.find(name) == boundArguments.end())
286       PrintFatalError(pattern->getLoc(),
287                       Twine("referencing unbound variable '") + name + "'");
288     if (operand.name)
289       os << "/*" << operand.name->getAsUnquotedString() << "=*/";
290     os << "s." << name;
291     // TODO(jpienaar): verify types
292     ++i;
293   }
294 
295   // Add attributes.
296   for (int e = resultTree->getNumArgs(); i != e; ++i) {
297     // Start each attribute on its own line.
298     (os << ",\n").indent(6);
299 
300     // The argument in the result DAG pattern.
301     auto name = resultTree->getArgNameStr(i);
302     auto opName = resultOp.getArgName(i);
303     auto defInit = dyn_cast<DefInit>(resultTree->getArg(i));
304     auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr;
305     if (!value) {
306       if (boundArguments.find(name) == boundArguments.end())
307         PrintFatalError(pattern->getLoc(),
308                         Twine("referencing unbound variable '") + name + "'");
309       os << "/*" << opName << "=*/"
310          << "s." << name;
311       continue;
312     }
313 
314     // TODO(jpienaar): Refactor out into map to avoid recomputing these.
315     auto argument = resultOp.getArg(i);
316     if (!argument.is<Operator::NamedAttribute *>())
317       PrintFatalError(pattern->getLoc(),
318                       Twine("expected attribute ") + Twine(i));
319 
320     if (!name.empty())
321       os << "/*" << name << "=*/";
322     emitAttributeValue(defInit->getDef());
323     // TODO(jpienaar): verify types
324   }
325   os << "\n    );\n  }\n};\n";
326 }
327 
328 void Pattern::emit(StringRef rewriteName, Record *p, raw_ostream &os) {
329   Pattern pattern(p, os);
330   pattern.emit(rewriteName);
331 }
332 
333 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
334   emitSourceFileHeader("Rewriters", os);
335   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
336 
337   // Ensure unique patterns simply by appending unique suffix.
338   std::string baseRewriteName = "GeneratedConvert";
339   int rewritePatternCount = 0;
340   for (Record *p : patterns) {
341     Pattern::emit(baseRewriteName + llvm::utostr(rewritePatternCount++), p, os);
342   }
343 
344   // Emit function to add the generated matchers to the pattern list.
345   os << "void populateWithGenerated(MLIRContext *context, "
346      << "OwningRewritePatternList *patterns) {\n";
347   for (unsigned i = 0; i != rewritePatternCount; ++i) {
348     os.indent(2) << "patterns->push_back(std::make_unique<" << baseRewriteName
349                  << i << ">(context));\n";
350   }
351   os << "}\n";
352 }
353 
354 mlir::GenRegistration
355     genRewriters("gen-rewriters", "Generate pattern rewriters",
356                  [](const RecordKeeper &records, raw_ostream &os) {
357                    emitRewriters(records, os);
358                    return false;
359                  });
360