xref: /llvm-project/mlir/tools/mlir-tblgen/RewriterGen.cpp (revision bd161ae5bcb87c72b94ea09ac06ca240f8b2b3e3)
1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/TableGen/Attribute.h"
23 #include "mlir/TableGen/GenInfo.h"
24 #include "mlir/TableGen/Operator.h"
25 #include "mlir/TableGen/Predicate.h"
26 #include "mlir/TableGen/Type.h"
27 #include "llvm/ADT/StringExtras.h"
28 #include "llvm/ADT/StringSet.h"
29 #include "llvm/Support/CommandLine.h"
30 #include "llvm/Support/FormatVariadic.h"
31 #include "llvm/Support/PrettyStackTrace.h"
32 #include "llvm/Support/Signals.h"
33 #include "llvm/TableGen/Error.h"
34 #include "llvm/TableGen/Main.h"
35 #include "llvm/TableGen/Record.h"
36 #include "llvm/TableGen/TableGenBackend.h"
37 
38 using namespace llvm;
39 using namespace mlir;
40 
41 using mlir::tblgen::Attribute;
42 using mlir::tblgen::Operator;
43 using mlir::tblgen::Type;
44 
45 namespace {
46 
47 // Wrapper around dag argument.
48 struct DagArg {
49   DagArg(Init *init) : init(init) {}
50   bool isAttr();
51 
52   Init *init;
53 };
54 
55 } // end namespace
56 
57 bool DagArg::isAttr() {
58   if (auto defInit = dyn_cast<DefInit>(init))
59     return defInit->getDef()->isSubClassOf("Attr");
60   return false;
61 }
62 
63 namespace {
64 class Pattern {
65 public:
66   static void emit(StringRef rewriteName, Record *p, raw_ostream &os);
67 
68 private:
69   Pattern(Record *pattern, raw_ostream &os) : pattern(pattern), os(os) {}
70 
71   // Emit the rewrite pattern named `rewriteName`.
72   void emit(StringRef rewriteName);
73 
74   // Emit the matcher.
75   void emitMatcher(DagInit *tree);
76 
77   // Emits the value of constant attribute to `os`.
78   void emitAttributeValue(Record *constAttr);
79 
80   // Collect bound arguments.
81   void collectBoundArguments(DagInit *tree);
82 
83   // Map from bound argument name to DagArg.
84   StringMap<DagArg> boundArguments;
85 
86   // Number of the operations in the input pattern.
87   int numberOfOpsMatched = 0;
88 
89   Record *pattern;
90   raw_ostream &os;
91 };
92 } // end namespace
93 
94 void Pattern::emitAttributeValue(Record *constAttr) {
95   Attribute attr(constAttr->getValueAsDef("attr"));
96   auto value = constAttr->getValue("value");
97 
98   if (!attr.isConstBuildable())
99     PrintFatalError(pattern->getLoc(),
100                     "Attribute " + attr.getTableGenDefName() +
101                         " does not have the 'constBuilderCall' field");
102 
103   // TODO(jpienaar): Verify the constants here
104   os << formatv(attr.getConstBuilderTemplate().str().c_str(), "rewriter",
105                 value->getValue()->getAsUnquotedString());
106 }
107 
108 void Pattern::collectBoundArguments(DagInit *tree) {
109   ++numberOfOpsMatched;
110   // TODO(jpienaar): Expand to multiple matches.
111   for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
112     auto arg = tree->getArg(i);
113     if (auto argTree = dyn_cast<DagInit>(arg)) {
114       collectBoundArguments(argTree);
115       continue;
116     }
117     auto name = tree->getArgNameStr(i);
118     if (name.empty())
119       continue;
120     boundArguments.try_emplace(name, arg);
121   }
122 }
123 
124 // Helper function to match patterns.
125 static void matchOp(Record *pattern, DagInit *tree, int depth,
126                     raw_ostream &os) {
127   Operator op(cast<DefInit>(tree->getOperator())->getDef());
128   int indent = 4 + 2 * depth;
129   // Skip the operand matching at depth 0 as the pattern rewriter already does.
130   if (depth != 0) {
131     // Skip if there is no defining instruction (e.g., arguments to function).
132     os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth);
133     os.indent(indent) << formatv(
134         "if (!op{0}->isa<{1}>()) return matchFailure();\n", depth,
135         op.qualifiedCppClassName());
136   }
137   if (tree->getNumArgs() != op.getNumArgs())
138     PrintFatalError(pattern->getLoc(),
139                     Twine("mismatch in number of arguments to op '") +
140                         op.getOperationName() +
141                         "' in pattern and op's definition");
142   for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
143     auto arg = tree->getArg(i);
144     auto opArg = op.getArg(i);
145 
146     if (auto argTree = dyn_cast<DagInit>(arg)) {
147       os.indent(indent) << "{\n";
148       os.indent(indent + 2) << formatv(
149           "auto op{0} = op{1}->getOperand({2})->getDefiningInst();\n",
150           depth + 1, depth, i);
151       matchOp(pattern, argTree, depth + 1, os);
152       os.indent(indent) << "}\n";
153       continue;
154     }
155 
156     // Verify arguments.
157     if (auto defInit = dyn_cast<DefInit>(arg)) {
158       // Verify operands.
159       if (auto *operand = opArg.dyn_cast<Operator::Operand *>()) {
160         // Skip verification where not needed due to definition of op.
161         if (operand->defInit == defInit)
162           goto StateCapture;
163 
164         if (!defInit->getDef()->isSubClassOf("Type"))
165           PrintFatalError(pattern->getLoc(),
166                           "type argument required for operand");
167 
168         auto constraint = tblgen::TypeConstraint(*defInit);
169         os.indent(indent)
170             << "if (!("
171             << formatv(constraint.getConditionTemplate().str().c_str(),
172                        formatv("op{0}->getOperand({1})->getType()", depth, i))
173             << ")) return matchFailure();\n";
174       }
175 
176       // TODO(jpienaar): Verify attributes.
177       if (auto *attr = opArg.dyn_cast<Operator::NamedAttribute *>()) {
178       }
179     }
180 
181   StateCapture:
182     auto name = tree->getArgNameStr(i);
183     if (name.empty())
184       continue;
185     if (opArg.is<Operator::Operand *>())
186       os.indent(indent) << "state->" << name << " = op" << depth
187                         << "->getOperand(" << i << ");\n";
188     if (auto namedAttr = opArg.dyn_cast<Operator::NamedAttribute *>()) {
189       os.indent(indent) << "state->" << name << " = op" << depth
190                         << "->getAttrOfType<"
191                         << namedAttr->attr.getStorageType() << ">(\""
192                         << namedAttr->getName() << "\");\n";
193     }
194   }
195 }
196 
197 void Pattern::emitMatcher(DagInit *tree) {
198   // Emit the heading.
199   os << R"(
200   PatternMatchResult match(OperationInst *op0) const override {
201     // TODO: This just handle 1 result
202     if (op0->getNumResults() != 1) return matchFailure();
203     auto state = std::make_unique<MatchedState>();)"
204      << "\n";
205   matchOp(pattern, tree, 0, os);
206   os.indent(4) << "return matchSuccess(std::move(state));\n  }\n";
207 }
208 
209 void Pattern::emit(StringRef rewriteName) {
210   DagInit *tree = pattern->getValueAsDag("PatternToMatch");
211   // Collect bound arguments and compute number of ops matched.
212   // TODO(jpienaar): the benefit metric is simply number of ops matched at the
213   // moment, revise.
214   collectBoundArguments(tree);
215 
216   // Emit RewritePattern for Pattern.
217   DefInit *root = cast<DefInit>(tree->getOperator());
218   auto *rootName = cast<StringInit>(root->getDef()->getValueInit("opName"));
219   os << formatv(R"(struct {0} : public RewritePattern {
220   {0}(MLIRContext *context) : RewritePattern({1}, {2}, context) {{})",
221                 rewriteName, rootName->getAsString(), numberOfOpsMatched)
222      << "\n";
223 
224   // Emit matched state.
225   os << "  struct MatchedState : public PatternState {\n";
226   for (auto &arg : boundArguments) {
227     if (arg.second.isAttr()) {
228       DefInit *defInit = cast<DefInit>(arg.second.init);
229       os.indent(4) << Attribute(defInit).getStorageType() << " " << arg.first()
230                    << ";\n";
231     } else {
232       os.indent(4) << "Value* " << arg.first() << ";\n";
233     }
234   }
235   os << "  };\n";
236 
237   emitMatcher(tree);
238   ListInit *resultOps = pattern->getValueAsListInit("ResultOps");
239   if (resultOps->size() != 1)
240     PrintFatalError("only single result rules supported");
241   DagInit *resultTree = cast<DagInit>(resultOps->getElement(0));
242 
243   // TODO(jpienaar): Expand to multiple results.
244   for (auto result : resultTree->getArgs()) {
245     if (isa<DagInit>(result))
246       PrintFatalError(pattern->getLoc(), "only single op result supported");
247   }
248 
249   DefInit *resultRoot = cast<DefInit>(resultTree->getOperator());
250   Operator resultOp(*resultRoot->getDef());
251   auto resultOperands = resultRoot->getDef()->getValueAsDag("arguments");
252 
253   os << formatv(R"(
254   void rewrite(OperationInst *op, std::unique_ptr<PatternState> state,
255                PatternRewriter &rewriter) const override {
256     auto& s = *static_cast<MatchedState *>(state.get());
257     rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())",
258                 resultOp.cppClassName());
259   if (resultOperands->getNumArgs() != resultTree->getNumArgs()) {
260     PrintFatalError(pattern->getLoc(),
261                     Twine("mismatch between arguments of resultant op (") +
262                         Twine(resultOperands->getNumArgs()) +
263                         ") and arguments provided for rewrite (" +
264                         Twine(resultTree->getNumArgs()) + Twine(')'));
265   }
266 
267   // Create the builder call for the result.
268   // Add operands.
269   int i = 0;
270   for (auto operand : resultOp.getOperands()) {
271     // Start each operand on its own line.
272     (os << ",\n").indent(6);
273 
274     auto name = resultTree->getArgNameStr(i);
275     if (boundArguments.find(name) == boundArguments.end())
276       PrintFatalError(pattern->getLoc(),
277                       Twine("referencing unbound variable '") + name + "'");
278     if (operand.name)
279       os << "/*" << operand.name->getAsUnquotedString() << "=*/";
280     os << "s." << name;
281     // TODO(jpienaar): verify types
282     ++i;
283   }
284 
285   // Add attributes.
286   for (int e = resultTree->getNumArgs(); i != e; ++i) {
287     // Start each attribute on its own line.
288     (os << ",\n").indent(6);
289 
290     // The argument in the result DAG pattern.
291     auto name = resultTree->getArgNameStr(i);
292     auto opName = resultOp.getArgName(i);
293     auto defInit = dyn_cast<DefInit>(resultTree->getArg(i));
294     auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr;
295     if (!value) {
296       if (boundArguments.find(name) == boundArguments.end())
297         PrintFatalError(pattern->getLoc(),
298                         Twine("referencing unbound variable '") + name + "'");
299       os << "/*" << opName << "=*/"
300          << "s." << name;
301       continue;
302     }
303 
304     // TODO(jpienaar): Refactor out into map to avoid recomputing these.
305     auto argument = resultOp.getArg(i);
306     if (!argument.is<Operator::NamedAttribute *>())
307       PrintFatalError(pattern->getLoc(),
308                       Twine("expected attribute ") + Twine(i));
309 
310     if (!name.empty())
311       os << "/*" << name << "=*/";
312     emitAttributeValue(defInit->getDef());
313     // TODO(jpienaar): verify types
314   }
315   os << "\n    );\n  }\n};\n";
316 }
317 
318 void Pattern::emit(StringRef rewriteName, Record *p, raw_ostream &os) {
319   Pattern pattern(p, os);
320   pattern.emit(rewriteName);
321 }
322 
323 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
324   emitSourceFileHeader("Rewriters", os);
325   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
326 
327   // Ensure unique patterns simply by appending unique suffix.
328   std::string baseRewriteName = "GeneratedConvert";
329   int rewritePatternCount = 0;
330   for (Record *p : patterns) {
331     Pattern::emit(baseRewriteName + llvm::utostr(rewritePatternCount++), p, os);
332   }
333 
334   // Emit function to add the generated matchers to the pattern list.
335   os << "void populateWithGenerated(MLIRContext *context, "
336      << "OwningRewritePatternList *patterns) {\n";
337   for (unsigned i = 0; i != rewritePatternCount; ++i) {
338     os.indent(2) << "patterns->push_back(std::make_unique<" << baseRewriteName
339                  << i << ">(context));\n";
340   }
341   os << "}\n";
342 }
343 
344 mlir::GenRegistration
345     genRewriters("gen-rewriters", "Generate pattern rewriters",
346                  [](const RecordKeeper &records, raw_ostream &os) {
347                    emitRewriters(records, os);
348                    return false;
349                  });
350