xref: /llvm-project/mlir/tools/mlir-tblgen/RewriterGen.cpp (revision 0fbf4ff232cdad9f6131527f2a23019bfb331b9e)
1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
19 //
20 //===----------------------------------------------------------------------===//
21 
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/Predicate.h"
#include "mlir/TableGen/Type.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Main.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
#include <map>
#include <tuple>
#include <utility>
37 
38 using namespace llvm;
39 using namespace mlir;
40 
41 using mlir::tblgen::Argument;
42 using mlir::tblgen::Attribute;
43 using mlir::tblgen::NamedAttribute;
44 using mlir::tblgen::Operand;
45 using mlir::tblgen::Operator;
46 using mlir::tblgen::Type;
47 
48 namespace {
49 
50 // Wrapper around DAG argument.
51 struct DagArg {
52   DagArg(Argument arg, Init *constraintInit)
53       : arg(arg), constraintInit(constraintInit) {}
54   bool isAttr();
55 
56   Argument arg;
57   Init *constraintInit;
58 };
59 
60 } // end namespace
61 
62 bool DagArg::isAttr() { return arg.is<NamedAttribute *>(); }
63 
64 namespace {
65 class Pattern {
66 public:
67   static void emit(StringRef rewriteName, Record *p, raw_ostream &os);
68 
69 private:
70   Pattern(Record *pattern, raw_ostream &os) : pattern(pattern), os(os) {}
71 
72   // Emits the rewrite pattern named `rewriteName`.
73   void emit(StringRef rewriteName);
74 
75   // Emits the matcher.
76   void emitMatcher(DagInit *tree);
77 
78   // Emits the rewrite() method.
79   void emitRewriteMethod();
80 
81   // Emits the C++ statement to replace the matched DAG with an existing value.
82   void emitReplaceWithExistingValue(DagInit *resultTree);
83   // Emits the C++ statement to replace the matched DAG with a new op.
84   void emitReplaceOpWithNewOp(DagInit *resultTree);
85 
86   // Emits the value of constant attribute to `os`.
87   void emitAttributeValue(Record *constAttr);
88 
89   // Collects bound arguments.
90   void collectBoundArguments(DagInit *tree);
91 
92   // Checks whether an argument with the given `name` is bound in source
93   // pattern. Prints fatal error if not; does nothing otherwise.
94   void checkArgumentBound(StringRef name) const;
95 
96   // Helper function to match patterns.
97   void matchOp(DagInit *tree, int depth);
98 
99   // Returns the Operator stored for the given record.
100   Operator &getOperator(const llvm::Record *record);
101 
102   // Map from bound argument name to DagArg.
103   StringMap<DagArg> boundArguments;
104 
105   // Map from Record* to Operator.
106   DenseMap<const llvm::Record *, Operator> opMap;
107 
108   // Number of the operations in the input pattern.
109   int numberOfOpsMatched = 0;
110 
111   Record *pattern;
112   raw_ostream &os;
113 };
114 } // end namespace
115 
116 // Returns the Operator stored for the given record.
117 auto Pattern::getOperator(const llvm::Record *record) -> Operator & {
118   return opMap.try_emplace(record, record).first->second;
119 }
120 
121 void Pattern::emitAttributeValue(Record *constAttr) {
122   Attribute attr(constAttr->getValueAsDef("attr"));
123   auto value = constAttr->getValue("value");
124 
125   if (!attr.isConstBuildable())
126     PrintFatalError(pattern->getLoc(),
127                     "Attribute " + attr.getTableGenDefName() +
128                         " does not have the 'constBuilderCall' field");
129 
130   // TODO(jpienaar): Verify the constants here
131   os << formatv(attr.getConstBuilderTemplate().str().c_str(), "rewriter",
132                 value->getValue()->getAsUnquotedString());
133 }
134 
135 void Pattern::collectBoundArguments(DagInit *tree) {
136   ++numberOfOpsMatched;
137   Operator &op = getOperator(cast<DefInit>(tree->getOperator())->getDef());
138   // TODO(jpienaar): Expand to multiple matches.
139   for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
140     auto arg = tree->getArg(i);
141     if (auto argTree = dyn_cast<DagInit>(arg)) {
142       collectBoundArguments(argTree);
143       continue;
144     }
145     auto name = tree->getArgNameStr(i);
146     if (name.empty())
147       continue;
148     boundArguments.try_emplace(name, op.getArg(i), arg);
149   }
150 }
151 
152 void Pattern::checkArgumentBound(StringRef name) const {
153   if (boundArguments.find(name) == boundArguments.end())
154     PrintFatalError(pattern->getLoc(),
155                     Twine("referencing unbound variable '") + name + "'");
156 }
157 
158 // Helper function to match patterns.
159 void Pattern::matchOp(DagInit *tree, int depth) {
160   Operator &op = getOperator(cast<DefInit>(tree->getOperator())->getDef());
161   int indent = 4 + 2 * depth;
162   // Skip the operand matching at depth 0 as the pattern rewriter already does.
163   if (depth != 0) {
164     // Skip if there is no defining instruction (e.g., arguments to function).
165     os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth);
166     os.indent(indent) << formatv(
167         "if (!op{0}->isa<{1}>()) return matchFailure();\n", depth,
168         op.qualifiedCppClassName());
169   }
170   if (tree->getNumArgs() != op.getNumArgs())
171     PrintFatalError(pattern->getLoc(),
172                     Twine("mismatch in number of arguments to op '") +
173                         op.getOperationName() +
174                         "' in pattern and op's definition");
175   for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
176     auto arg = tree->getArg(i);
177     auto opArg = op.getArg(i);
178 
179     if (auto argTree = dyn_cast<DagInit>(arg)) {
180       os.indent(indent) << "{\n";
181       os.indent(indent + 2) << formatv(
182           "auto op{0} = op{1}->getOperand({2})->getDefiningInst();\n",
183           depth + 1, depth, i);
184       matchOp(argTree, depth + 1);
185       os.indent(indent) << "}\n";
186       continue;
187     }
188 
189     // Verify arguments.
190     if (auto defInit = dyn_cast<DefInit>(arg)) {
191       // Verify operands.
192       if (auto *operand = opArg.dyn_cast<Operand *>()) {
193         // Skip verification where not needed due to definition of op.
194         if (operand->defInit == defInit)
195           goto StateCapture;
196 
197         if (!defInit->getDef()->isSubClassOf("Type"))
198           PrintFatalError(pattern->getLoc(),
199                           "type argument required for operand");
200 
201         auto constraint = tblgen::TypeConstraint(*defInit);
202         os.indent(indent)
203             << "if (!("
204             << formatv(constraint.getConditionTemplate().c_str(),
205                        formatv("op{0}->getOperand({1})->getType()", depth, i))
206             << ")) return matchFailure();\n";
207       }
208 
209       // TODO(jpienaar): Verify attributes.
210       if (auto *namedAttr = opArg.dyn_cast<NamedAttribute *>()) {
211         auto constraint = tblgen::AttrConstraint(defInit);
212         std::string condition = formatv(
213             constraint.getConditionTemplate().c_str(),
214             formatv("op{0}->getAttrOfType<{1}>(\"{2}\")", depth,
215                     namedAttr->attr.getStorageType(), namedAttr->getName()));
216         os.indent(indent) << "if (!(" << condition
217                           << ")) return matchFailure();\n";
218       }
219     }
220 
221   StateCapture:
222     auto name = tree->getArgNameStr(i);
223     if (name.empty())
224       continue;
225     if (opArg.is<Operand *>())
226       os.indent(indent) << "state->" << name << " = op" << depth
227                         << "->getOperand(" << i << ");\n";
228     if (auto namedAttr = opArg.dyn_cast<NamedAttribute *>()) {
229       os.indent(indent) << "state->" << name << " = op" << depth
230                         << "->getAttrOfType<"
231                         << namedAttr->attr.getStorageType() << ">(\""
232                         << namedAttr->getName() << "\");\n";
233     }
234   }
235 }
236 
237 void Pattern::emitMatcher(DagInit *tree) {
238   // Emit the heading.
239   os << R"(
240   PatternMatchResult match(OperationInst *op0) const override {
241     // TODO: This just handle 1 result
242     if (op0->getNumResults() != 1) return matchFailure();
243     auto state = std::make_unique<MatchedState>();)"
244      << "\n";
245   matchOp(tree, 0);
246   os.indent(4) << "return matchSuccess(std::move(state));\n  }\n";
247 }
248 
249 void Pattern::emit(StringRef rewriteName) {
250   DagInit *tree = pattern->getValueAsDag("PatternToMatch");
251   // Collect bound arguments and compute number of ops matched.
252   // TODO(jpienaar): the benefit metric is simply number of ops matched at the
253   // moment, revise.
254   collectBoundArguments(tree);
255 
256   // Emit RewritePattern for Pattern.
257   DefInit *root = cast<DefInit>(tree->getOperator());
258   auto *rootName = cast<StringInit>(root->getDef()->getValueInit("opName"));
259   os << formatv(R"(struct {0} : public RewritePattern {
260   {0}(MLIRContext *context) : RewritePattern({1}, {2}, context) {{})",
261                 rewriteName, rootName->getAsString(), numberOfOpsMatched)
262      << "\n";
263 
264   // Emit matched state.
265   os << "  struct MatchedState : public PatternState {\n";
266   for (auto &arg : boundArguments) {
267     if (auto namedAttr = arg.second.arg.dyn_cast<NamedAttribute *>()) {
268       os.indent(4) << namedAttr->attr.getStorageType() << " " << arg.first()
269                    << ";\n";
270     } else {
271       os.indent(4) << "Value* " << arg.first() << ";\n";
272     }
273   }
274   os << "  };\n";
275 
276   emitMatcher(tree);
277   emitRewriteMethod();
278 
279   os << "};\n";
280 }
281 
282 void Pattern::emitRewriteMethod() {
283   ListInit *resultOps = pattern->getValueAsListInit("ResultOps");
284   if (resultOps->size() != 1)
285     PrintFatalError("only single result rules supported");
286   DagInit *resultTree = cast<DagInit>(resultOps->getElement(0));
287 
288   // TODO(jpienaar): Expand to multiple results.
289   for (auto result : resultTree->getArgs()) {
290     if (isa<DagInit>(result))
291       PrintFatalError(pattern->getLoc(), "only single op result supported");
292   }
293 
294   os << R"(
295   void rewrite(OperationInst *op, std::unique_ptr<PatternState> state,
296                PatternRewriter &rewriter) const override {
297     auto& s = *static_cast<MatchedState *>(state.get());
298 )";
299 
300   auto *dagOpDef = cast<DefInit>(resultTree->getOperator())->getDef();
301   if (dagOpDef->getName() == "replaceWithValue")
302     emitReplaceWithExistingValue(resultTree);
303   else
304     emitReplaceOpWithNewOp(resultTree);
305 
306   os << "  }\n";
307 }
308 
309 void Pattern::emitReplaceWithExistingValue(DagInit *resultTree) {
310   if (resultTree->getNumArgs() != 1) {
311     PrintFatalError(pattern->getLoc(),
312                     "exactly one argument needed in the result pattern");
313   }
314 
315   auto name = resultTree->getArgNameStr(0);
316   checkArgumentBound(name);
317   os.indent(4) << "rewriter.replaceOp(op, {s." << name << "});\n";
318 }
319 
320 void Pattern::emitReplaceOpWithNewOp(DagInit *resultTree) {
321   DefInit *dagOperator = cast<DefInit>(resultTree->getOperator());
322   Operator &resultOp = getOperator(dagOperator->getDef());
323   auto resultOperands = dagOperator->getDef()->getValueAsDag("arguments");
324 
325   os << formatv(R"(
326     rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())",
327                 resultOp.cppClassName());
328   if (resultOperands->getNumArgs() != resultTree->getNumArgs()) {
329     PrintFatalError(pattern->getLoc(),
330                     Twine("mismatch between arguments of resultant op (") +
331                         Twine(resultOperands->getNumArgs()) +
332                         ") and arguments provided for rewrite (" +
333                         Twine(resultTree->getNumArgs()) + Twine(')'));
334   }
335 
336   // Create the builder call for the result.
337   // Add operands.
338   int i = 0;
339   for (auto operand : resultOp.getOperands()) {
340     // Start each operand on its own line.
341     (os << ",\n").indent(6);
342 
343     auto name = resultTree->getArgNameStr(i);
344     checkArgumentBound(name);
345     if (operand.name)
346       os << "/*" << operand.name->getAsUnquotedString() << "=*/";
347     os << "s." << name;
348     // TODO(jpienaar): verify types
349     ++i;
350   }
351 
352   // Add attributes.
353   for (int e = resultTree->getNumArgs(); i != e; ++i) {
354     // Start each attribute on its own line.
355     (os << ",\n").indent(6);
356 
357     // The argument in the result DAG pattern.
358     auto name = resultTree->getArgNameStr(i);
359     auto opName = resultOp.getArgName(i);
360     auto defInit = dyn_cast<DefInit>(resultTree->getArg(i));
361     auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr;
362     if (!value) {
363       checkArgumentBound(name);
364       auto result = "s." + name;
365       os << "/*" << opName << "=*/";
366       if (defInit) {
367         auto transform = defInit->getDef();
368         if (transform->isSubClassOf("tAttr")) {
369           // TODO(jpienaar): move to helper class.
370           os << formatv(
371               transform->getValueAsString("attrTransform").str().c_str(),
372               result);
373           continue;
374         }
375       }
376       os << result;
377       continue;
378     }
379 
380     // TODO(jpienaar): Refactor out into map to avoid recomputing these.
381     auto argument = resultOp.getArg(i);
382     if (!argument.is<NamedAttribute *>())
383       PrintFatalError(pattern->getLoc(),
384                       Twine("expected attribute ") + Twine(i));
385 
386     if (!name.empty())
387       os << "/*" << name << "=*/";
388     emitAttributeValue(defInit->getDef());
389     // TODO(jpienaar): verify types
390   }
391   os << "\n    );\n";
392 }
393 
394 void Pattern::emit(StringRef rewriteName, Record *p, raw_ostream &os) {
395   Pattern pattern(p, os);
396   pattern.emit(rewriteName);
397 }
398 
399 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
400   emitSourceFileHeader("Rewriters", os);
401   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
402 
403   // Ensure unique patterns simply by appending unique suffix.
404   std::string baseRewriteName = "GeneratedConvert";
405   int rewritePatternCount = 0;
406   for (Record *p : patterns) {
407     Pattern::emit(baseRewriteName + llvm::utostr(rewritePatternCount++), p, os);
408   }
409 
410   // Emit function to add the generated matchers to the pattern list.
411   os << "void populateWithGenerated(MLIRContext *context, "
412      << "OwningRewritePatternList *patterns) {\n";
413   for (unsigned i = 0; i != rewritePatternCount; ++i) {
414     os.indent(2) << "patterns->push_back(std::make_unique<" << baseRewriteName
415                  << i << ">(context));\n";
416   }
417   os << "}\n";
418 }
419 
420 static mlir::GenRegistration
421     genRewriters("gen-rewriters", "Generate pattern rewriters",
422                  [](const RecordKeeper &records, raw_ostream &os) {
423                    emitRewriters(records, os);
424                    return false;
425                  });
426