xref: /llvm-project/mlir/tools/mlir-tblgen/RewriterGen.cpp (revision b2cc2c344e6d295d1ad09221935e3430dd5bfb3c)
//===- RewriterGen.cpp - MLIR pattern rewriter generator ------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
//
//===----------------------------------------------------------------------===//
21 
22 #include "mlir/TableGen/GenInfo.h"
23 #include "mlir/TableGen/Operator.h"
24 #include "mlir/TableGen/Predicate.h"
25 #include "mlir/TableGen/Type.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/StringSet.h"
28 #include "llvm/Support/CommandLine.h"
29 #include "llvm/Support/FormatVariadic.h"
30 #include "llvm/Support/PrettyStackTrace.h"
31 #include "llvm/Support/Signals.h"
32 #include "llvm/TableGen/Error.h"
33 #include "llvm/TableGen/Main.h"
34 #include "llvm/TableGen/Record.h"
35 #include "llvm/TableGen/TableGenBackend.h"
36 
37 using namespace llvm;
38 using namespace mlir;
39 
40 namespace {
41 
42 // Wrapper around dag argument.
43 struct DagArg {
44   DagArg(Init *init) : init(init) {}
45   bool isAttr();
46 
47   Init *init;
48 };
49 
50 } // end namespace
51 
52 bool DagArg::isAttr() {
53   if (auto defInit = dyn_cast<DefInit>(init))
54     return defInit->getDef()->isSubClassOf("Attr");
55   return false;
56 }
57 
58 namespace {
59 class Pattern {
60 public:
61   static void emit(StringRef rewriteName, Record *p, raw_ostream &os);
62 
63 private:
64   Pattern(Record *pattern, raw_ostream &os) : pattern(pattern), os(os) {}
65 
66   // Emit the rewrite pattern named `rewriteName`.
67   void emit(StringRef rewriteName);
68 
69   // Emit the matcher.
70   void emitMatcher(DagInit *tree);
71 
72   // Emits the value of constant attribute to `os`.
73   void emitAttributeValue(Record *constAttr);
74 
75   // Collect bound arguments.
76   void collectBoundArguments(DagInit *tree);
77 
78   // Map from bound argument name to DagArg.
79   StringMap<DagArg> boundArguments;
80 
81   // Number of the operations in the input pattern.
82   int numberOfOpsMatched = 0;
83 
84   Record *pattern;
85   raw_ostream &os;
86 };
87 } // end namespace
88 
89 void Pattern::emitAttributeValue(Record *constAttr) {
90   Record *attr = constAttr->getValueAsDef("attr");
91   auto value = constAttr->getValue("value");
92   tblgen::Type type(attr->getValueAsDef("type"));
93   auto storageType = attr->getValueAsString("storageType").trim();
94 
95   // For attributes stored as strings we do not need to query builder etc.
96   if (storageType == "StringAttr") {
97     os << formatv("rewriter.getStringAttr({0})",
98                   value->getValue()->getAsString());
99     return;
100   }
101 
102   auto builder = type.getBuilderCall();
103   if (builder.empty())
104     PrintFatalError(pattern->getLoc(),
105                     "no builder specified for " + type.getTableGenDefName());
106 
107   // Construct the attribute based on storage type and builder.
108   // TODO(jpienaar): Verify the constants here
109   os << formatv("{0}::get(rewriter.{1}, {2})", storageType, builder,
110                 value->getValue()->getAsUnquotedString());
111 }
112 
113 void Pattern::collectBoundArguments(DagInit *tree) {
114   ++numberOfOpsMatched;
115   // TODO(jpienaar): Expand to multiple matches.
116   for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
117     auto arg = tree->getArg(i);
118     if (auto argTree = dyn_cast<DagInit>(arg)) {
119       collectBoundArguments(argTree);
120       continue;
121     }
122     auto name = tree->getArgNameStr(i);
123     if (name.empty())
124       continue;
125     boundArguments.try_emplace(name, arg);
126   }
127 }
128 
129 // Helper function to match patterns.
130 static void matchOp(Record *pattern, DagInit *tree, int depth,
131                     raw_ostream &os) {
132   Operator op(cast<DefInit>(tree->getOperator())->getDef());
133   int indent = 4 + 2 * depth;
134   // Skip the operand matching at depth 0 as the pattern rewriter already does.
135   if (depth != 0) {
136     // Skip if there is no defining instruction (e.g., arguments to function).
137     os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth);
138     os.indent(indent) << formatv(
139         "if (!op{0}->isa<{1}>()) return matchFailure();\n", depth,
140         op.qualifiedCppClassName());
141   }
142   if (tree->getNumArgs() != op.getNumArgs())
143     PrintFatalError(pattern->getLoc(),
144                     Twine("mismatch in number of arguments to op '") +
145                         op.getOperationName() +
146                         "' in pattern and op's definition");
147   for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
148     auto arg = tree->getArg(i);
149     auto opArg = op.getArg(i);
150 
151     if (auto argTree = dyn_cast<DagInit>(arg)) {
152       os.indent(indent) << "{\n";
153       os.indent(indent + 2) << formatv(
154           "auto op{0} = op{1}->getOperand({2})->getDefiningInst();\n",
155           depth + 1, depth, i);
156       matchOp(pattern, argTree, depth + 1, os);
157       os.indent(indent) << "}\n";
158       continue;
159     }
160 
161     // Verify arguments.
162     if (auto defInit = dyn_cast<DefInit>(arg)) {
163       // Verify operands.
164       if (auto *operand = opArg.dyn_cast<Operator::Operand *>()) {
165         // Skip verification where not needed due to definition of op.
166         if (operand->defInit == defInit)
167           goto StateCapture;
168 
169         if (!defInit->getDef()->isSubClassOf("Type"))
170           PrintFatalError(pattern->getLoc(),
171                           "type argument required for operand");
172 
173         auto pred = tblgen::Type(defInit).getPredicate();
174 
175         os.indent(indent)
176             << "if (!("
177             << formatv(pred.createTypeMatcherTemplate().c_str(),
178                        formatv("op{0}->getOperand({1})->getType()", depth, i))
179             << ")) return matchFailure();\n";
180       }
181 
182       // TODO(jpienaar): Verify attributes.
183       if (auto *attr = opArg.dyn_cast<Operator::Attribute *>()) {
184       }
185     }
186 
187   StateCapture:
188     auto name = tree->getArgNameStr(i);
189     if (name.empty())
190       continue;
191     if (opArg.is<Operator::Operand *>())
192       os.indent(indent) << "state->" << name << " = op" << depth
193                         << "->getOperand(" << i << ");\n";
194     if (auto attr = opArg.dyn_cast<Operator::Attribute *>()) {
195       os.indent(indent) << "state->" << name << " = op" << depth
196                         << "->getAttrOfType<" << attr->getStorageType()
197                         << ">(\"" << attr->getName() << "\");\n";
198     }
199   }
200 }
201 
202 void Pattern::emitMatcher(DagInit *tree) {
203   // Emit the heading.
204   os << R"(
205   PatternMatchResult match(OperationInst *op0) const override {
206     // TODO: This just handle 1 result
207     if (op0->getNumResults() != 1) return matchFailure();
208     auto state = std::make_unique<MatchedState>();)"
209      << "\n";
210   matchOp(pattern, tree, 0, os);
211   os.indent(4) << "return matchSuccess(std::move(state));\n  }\n";
212 }
213 
214 void Pattern::emit(StringRef rewriteName) {
215   DagInit *tree = pattern->getValueAsDag("PatternToMatch");
216   // Collect bound arguments and compute number of ops matched.
217   // TODO(jpienaar): the benefit metric is simply number of ops matched at the
218   // moment, revise.
219   collectBoundArguments(tree);
220 
221   // Emit RewritePattern for Pattern.
222   DefInit *root = cast<DefInit>(tree->getOperator());
223   auto *rootName = cast<StringInit>(root->getDef()->getValueInit("opName"));
224   os << formatv(R"(struct {0} : public RewritePattern {
225   {0}(MLIRContext *context) : RewritePattern({1}, {2}, context) {{})",
226                 rewriteName, rootName->getAsString(), numberOfOpsMatched)
227      << "\n";
228 
229   // Emit matched state.
230   os << "  struct MatchedState : public PatternState {\n";
231   for (auto &arg : boundArguments) {
232     if (arg.second.isAttr()) {
233       DefInit *defInit = cast<DefInit>(arg.second.init);
234       os.indent(4) << defInit->getDef()->getValueAsString("storageType").trim()
235                    << " " << arg.first() << ";\n";
236     } else {
237       os.indent(4) << "Value* " << arg.first() << ";\n";
238     }
239   }
240   os << "  };\n";
241 
242   emitMatcher(tree);
243   ListInit *resultOps = pattern->getValueAsListInit("ResultOps");
244   if (resultOps->size() != 1)
245     PrintFatalError("only single result rules supported");
246   DagInit *resultTree = cast<DagInit>(resultOps->getElement(0));
247 
248   // TODO(jpienaar): Expand to multiple results.
249   for (auto result : resultTree->getArgs()) {
250     if (isa<DagInit>(result))
251       PrintFatalError(pattern->getLoc(), "only single op result supported");
252   }
253 
254   DefInit *resultRoot = cast<DefInit>(resultTree->getOperator());
255   Operator resultOp(*resultRoot->getDef());
256   auto resultOperands = resultRoot->getDef()->getValueAsDag("arguments");
257 
258   os << formatv(R"(
259   void rewrite(OperationInst *op, std::unique_ptr<PatternState> state,
260                PatternRewriter &rewriter) const override {
261     auto& s = *static_cast<MatchedState *>(state.get());
262     rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())",
263                 resultOp.cppClassName());
264   if (resultOperands->getNumArgs() != resultTree->getNumArgs()) {
265     PrintFatalError(pattern->getLoc(),
266                     Twine("mismatch between arguments of resultant op (") +
267                         Twine(resultOperands->getNumArgs()) +
268                         ") and arguments provided for rewrite (" +
269                         Twine(resultTree->getNumArgs()) + Twine(')'));
270   }
271 
272   // Create the builder call for the result.
273   // Add operands.
274   int i = 0;
275   for (auto operand : resultOp.getOperands()) {
276     // Start each operand on its own line.
277     (os << ",\n").indent(6);
278 
279     auto name = resultTree->getArgNameStr(i);
280     if (boundArguments.find(name) == boundArguments.end())
281       PrintFatalError(pattern->getLoc(),
282                       Twine("referencing unbound variable '") + name + "'");
283     if (operand.name)
284       os << "/*" << operand.name->getAsUnquotedString() << "=*/";
285     os << "s." << name;
286     // TODO(jpienaar): verify types
287     ++i;
288   }
289 
290   // Add attributes.
291   for (int e = resultTree->getNumArgs(); i != e; ++i) {
292     // Start each attribute on its own line.
293     (os << ",\n").indent(6);
294 
295     // The argument in the result DAG pattern.
296     auto name = resultTree->getArgNameStr(i);
297     auto opName = resultOp.getArgName(i);
298     auto defInit = dyn_cast<DefInit>(resultTree->getArg(i));
299     auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr;
300     if (!value) {
301       if (boundArguments.find(name) == boundArguments.end())
302         PrintFatalError(pattern->getLoc(),
303                         Twine("referencing unbound variable '") + name + "'");
304       os << "/*" << opName << "=*/"
305          << "s." << name;
306       continue;
307     }
308 
309     // TODO(jpienaar): Refactor out into map to avoid recomputing these.
310     auto argument = resultOp.getArg(i);
311     if (!argument.is<mlir::Operator::Attribute *>())
312       PrintFatalError(pattern->getLoc(),
313                       Twine("expected attribute ") + Twine(i));
314 
315     if (!name.empty())
316       os << "/*" << name << "=*/";
317     emitAttributeValue(defInit->getDef());
318     // TODO(jpienaar): verify types
319   }
320   os << "\n    );\n  }\n};\n";
321 }
322 
323 void Pattern::emit(StringRef rewriteName, Record *p, raw_ostream &os) {
324   Pattern pattern(p, os);
325   pattern.emit(rewriteName);
326 }
327 
328 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
329   emitSourceFileHeader("Rewriters", os);
330   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
331 
332   // Ensure unique patterns simply by appending unique suffix.
333   std::string baseRewriteName = "GeneratedConvert";
334   int rewritePatternCount = 0;
335   for (Record *p : patterns) {
336     Pattern::emit(baseRewriteName + llvm::utostr(rewritePatternCount++), p, os);
337   }
338 
339   // Emit function to add the generated matchers to the pattern list.
340   os << "void populateWithGenerated(MLIRContext *context, "
341      << "OwningRewritePatternList *patterns) {\n";
342   for (unsigned i = 0; i != rewritePatternCount; ++i) {
343     os.indent(2) << "patterns->push_back(std::make_unique<" << baseRewriteName
344                  << i << ">(context));\n";
345   }
346   os << "}\n";
347 }
348 
349 mlir::GenRegistration
350     genRewriters("gen-rewriters", "Generate pattern rewriters",
351                  [](const RecordKeeper &records, raw_ostream &os) {
352                    emitRewriters(records, os);
353                    return false;
354                  });
355