xref: /llvm-project/mlir/tools/mlir-tblgen/RewriterGen.cpp (revision aae85ddce104fd5dca4dbb1b97a8b9949b249c93)
1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/TableGen/GenInfo.h"
23 #include "mlir/TableGen/Operator.h"
24 #include "mlir/TableGen/Predicate.h"
25 #include "llvm/ADT/StringExtras.h"
26 #include "llvm/ADT/StringSet.h"
27 #include "llvm/Support/CommandLine.h"
28 #include "llvm/Support/FormatVariadic.h"
29 #include "llvm/Support/PrettyStackTrace.h"
30 #include "llvm/Support/Signals.h"
31 #include "llvm/TableGen/Error.h"
32 #include "llvm/TableGen/Main.h"
33 #include "llvm/TableGen/Record.h"
34 #include "llvm/TableGen/TableGenBackend.h"
35 
36 using namespace llvm;
37 using namespace mlir;
38 
39 namespace {
40 
41 // Wrapper around dag argument.
42 struct DagArg {
43   DagArg(Init *init) : init(init) {}
44   bool isAttr();
45 
46   Init *init;
47 };
48 
49 } // end namespace
50 
51 bool DagArg::isAttr() {
52   if (auto defInit = dyn_cast<DefInit>(init))
53     return defInit->getDef()->isSubClassOf("Attr");
54   return false;
55 }
56 
57 namespace {
58 class Pattern {
59 public:
60   static void emit(StringRef rewriteName, Record *p, raw_ostream &os);
61 
62 private:
63   Pattern(Record *pattern, raw_ostream &os) : pattern(pattern), os(os) {}
64 
65   // Emit the rewrite pattern named `rewriteName`.
66   void emit(StringRef rewriteName);
67 
68   // Emit the matcher.
69   void emitMatcher(DagInit *tree);
70 
71   // Emits the value of constant attribute to `os`.
72   void emitAttributeValue(Record *constAttr);
73 
74   // Collect bound arguments.
75   void collectBoundArguments(DagInit *tree);
76 
77   // Map from bound argument name to DagArg.
78   StringMap<DagArg> boundArguments;
79 
80   // Number of the operations in the input pattern.
81   int numberOfOpsMatched = 0;
82 
83   Record *pattern;
84   raw_ostream &os;
85 };
86 } // end namespace
87 
88 void Pattern::emitAttributeValue(Record *constAttr) {
89   Record *attr = constAttr->getValueAsDef("attr");
90   auto value = constAttr->getValue("value");
91   Record *type = attr->getValueAsDef("type");
92   auto storageType = attr->getValueAsString("storageType").trim();
93 
94   // For attributes stored as strings we do not need to query builder etc.
95   if (storageType == "StringAttr") {
96     os << formatv("rewriter.getStringAttr({0})",
97                   value->getValue()->getAsString());
98     return;
99   }
100 
101   // Construct the attribute based on storage type and builder.
102   if (auto b = type->getValue("builderCall")) {
103     if (isa<UnsetInit>(b->getValue()))
104       PrintFatalError(pattern->getLoc(),
105                       "no builder specified for " + type->getName());
106     CodeInit *builder = cast<CodeInit>(b->getValue());
107     // TODO(jpienaar): Verify the constants here
108     os << formatv("{0}::get(rewriter.{1}, {2})", storageType,
109                   builder->getValue(),
110                   value->getValue()->getAsUnquotedString());
111     return;
112   }
113 
114   PrintFatalError(pattern->getLoc(), "unable to emit attribute");
115 }
116 
117 void Pattern::collectBoundArguments(DagInit *tree) {
118   ++numberOfOpsMatched;
119   // TODO(jpienaar): Expand to multiple matches.
120   for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
121     auto arg = tree->getArg(i);
122     if (auto argTree = dyn_cast<DagInit>(arg)) {
123       collectBoundArguments(argTree);
124       continue;
125     }
126     auto name = tree->getArgNameStr(i);
127     if (name.empty())
128       continue;
129     boundArguments.try_emplace(name, arg);
130   }
131 }
132 
133 // Helper function to match patterns.
134 static void matchOp(Record *pattern, DagInit *tree, int depth,
135                     raw_ostream &os) {
136   Operator op(cast<DefInit>(tree->getOperator())->getDef());
137   int indent = 4 + 2 * depth;
138   // Skip the operand matching at depth 0 as the pattern rewriter already does.
139   if (depth != 0) {
140     // Skip if there is no defining instruction (e.g., arguments to function).
141     os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth);
142     os.indent(indent) << formatv(
143         "if (!op{0}->isa<{1}>()) return matchFailure();\n", depth,
144         op.qualifiedCppClassName());
145   }
146   if (tree->getNumArgs() != op.getNumArgs())
147     PrintFatalError(pattern->getLoc(),
148                     Twine("mismatch in number of arguments to op '") +
149                         op.getOperationName() +
150                         "' in pattern and op's definition");
151   for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
152     auto arg = tree->getArg(i);
153     auto opArg = op.getArg(i);
154 
155     if (auto argTree = dyn_cast<DagInit>(arg)) {
156       os.indent(indent) << "{\n";
157       os.indent(indent + 2) << formatv(
158           "auto op{0} = op{1}->getOperand({2})->getDefiningInst();\n",
159           depth + 1, depth, i);
160       matchOp(pattern, argTree, depth + 1, os);
161       os.indent(indent) << "}\n";
162       continue;
163     }
164 
165     // Verify arguments.
166     if (auto defInit = dyn_cast<DefInit>(arg)) {
167       // Verify operands.
168       if (auto *operand = opArg.dyn_cast<Operator::Operand *>()) {
169         // Skip verification where not needed due to definition of op.
170         if (operand->defInit == defInit)
171           goto StateCapture;
172 
173         if (!defInit->getDef()->isSubClassOf("Type"))
174           PrintFatalError(pattern->getLoc(),
175                           "type argument required for operand");
176 
177         // TODO(jpienaar): Factor out type class and move these there.
178         auto predicate = defInit->getDef()->getValue("predicate")->getValue();
179         auto predCnf = cast<DefInit>(predicate);
180         auto conjunctiveList =
181             predCnf->getDef()->getValueAsListInit("conditions");
182         PredCNF pred(conjunctiveList);
183         os.indent(indent)
184             << "if (!("
185             << formatv(pred.createTypeMatcherTemplate().c_str(),
186                        formatv("op{0}->getOperand({1})->getType()", depth, i))
187             << ")) return matchFailure();\n";
188       }
189 
190       // TODO(jpienaar): Verify attributes.
191       if (auto *attr = opArg.dyn_cast<Operator::Attribute *>()) {
192       }
193     }
194 
195   StateCapture:
196     auto name = tree->getArgNameStr(i);
197     if (name.empty())
198       continue;
199     if (opArg.is<Operator::Operand *>())
200       os.indent(indent) << "state->" << name << " = op" << depth
201                         << "->getOperand(" << i << ");\n";
202     if (auto attr = opArg.dyn_cast<Operator::Attribute *>()) {
203       os.indent(indent) << "state->" << name << " = op" << depth
204                         << "->getAttrOfType<" << attr->getStorageType()
205                         << ">(\"" << attr->getName() << "\");\n";
206     }
207   }
208 }
209 
210 void Pattern::emitMatcher(DagInit *tree) {
211   // Emit the heading.
212   os << R"(
213   PatternMatchResult match(OperationInst *op0) const override {
214     // TODO: This just handle 1 result
215     if (op0->getNumResults() != 1) return matchFailure();
216     auto state = std::make_unique<MatchedState>();)"
217      << "\n";
218   matchOp(pattern, tree, 0, os);
219   os.indent(4) << "return matchSuccess(std::move(state));\n  }\n";
220 }
221 
222 void Pattern::emit(StringRef rewriteName) {
223   DagInit *tree = pattern->getValueAsDag("PatternToMatch");
224   // Collect bound arguments and compute number of ops matched.
225   // TODO(jpienaar): the benefit metric is simply number of ops matched at the
226   // moment, revise.
227   collectBoundArguments(tree);
228 
229   // Emit RewritePattern for Pattern.
230   DefInit *root = cast<DefInit>(tree->getOperator());
231   auto *rootName = cast<StringInit>(root->getDef()->getValueInit("opName"));
232   os << formatv(R"(struct {0} : public RewritePattern {
233   {0}(MLIRContext *context) : RewritePattern({1}, {2}, context) {{})",
234                 rewriteName, rootName->getAsString(), numberOfOpsMatched)
235      << "\n";
236 
237   // Emit matched state.
238   os << "  struct MatchedState : public PatternState {\n";
239   for (auto &arg : boundArguments) {
240     if (arg.second.isAttr()) {
241       DefInit *defInit = cast<DefInit>(arg.second.init);
242       os.indent(4) << defInit->getDef()->getValueAsString("storageType").trim()
243                    << " " << arg.first() << ";\n";
244     } else {
245       os.indent(4) << "Value* " << arg.first() << ";\n";
246     }
247   }
248   os << "  };\n";
249 
250   emitMatcher(tree);
251   ListInit *resultOps = pattern->getValueAsListInit("ResultOps");
252   if (resultOps->size() != 1)
253     PrintFatalError("only single result rules supported");
254   DagInit *resultTree = cast<DagInit>(resultOps->getElement(0));
255 
256   // TODO(jpienaar): Expand to multiple results.
257   for (auto result : resultTree->getArgs()) {
258     if (isa<DagInit>(result))
259       PrintFatalError(pattern->getLoc(), "only single op result supported");
260   }
261 
262   DefInit *resultRoot = cast<DefInit>(resultTree->getOperator());
263   Operator resultOp(*resultRoot->getDef());
264   auto resultOperands = resultRoot->getDef()->getValueAsDag("arguments");
265 
266   os << formatv(R"(
267   void rewrite(OperationInst *op, std::unique_ptr<PatternState> state,
268                PatternRewriter &rewriter) const override {
269     auto& s = *static_cast<MatchedState *>(state.get());
270     rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())",
271                 resultOp.cppClassName());
272   if (resultOperands->getNumArgs() != resultTree->getNumArgs()) {
273     PrintFatalError(pattern->getLoc(),
274                     Twine("mismatch between arguments of resultant op (") +
275                         Twine(resultOperands->getNumArgs()) +
276                         ") and arguments provided for rewrite (" +
277                         Twine(resultTree->getNumArgs()) + Twine(')'));
278   }
279 
280   // Create the builder call for the result.
281   // Add operands.
282   int i = 0;
283   for (auto operand : resultOp.getOperands()) {
284     // Start each operand on its own line.
285     (os << ",\n").indent(6);
286 
287     auto name = resultTree->getArgNameStr(i);
288     if (boundArguments.find(name) == boundArguments.end())
289       PrintFatalError(pattern->getLoc(),
290                       Twine("referencing unbound variable '") + name + "'");
291     if (operand.name)
292       os << "/*" << operand.name->getAsUnquotedString() << "=*/";
293     os << "s." << name;
294     // TODO(jpienaar): verify types
295     ++i;
296   }
297 
298   // Add attributes.
299   for (int e = resultTree->getNumArgs(); i != e; ++i) {
300     // Start each attribute on its own line.
301     (os << ",\n").indent(6);
302 
303     // The argument in the result DAG pattern.
304     auto name = resultTree->getArgNameStr(i);
305     auto opName = resultOp.getArgName(i);
306     auto defInit = dyn_cast<DefInit>(resultTree->getArg(i));
307     auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr;
308     if (!value) {
309       if (boundArguments.find(name) == boundArguments.end())
310         PrintFatalError(pattern->getLoc(),
311                         Twine("referencing unbound variable '") + name + "'");
312       os << "/*" << opName << "=*/"
313          << "s." << name;
314       continue;
315     }
316 
317     // TODO(jpienaar): Refactor out into map to avoid recomputing these.
318     auto argument = resultOp.getArg(i);
319     if (!argument.is<mlir::Operator::Attribute *>())
320       PrintFatalError(pattern->getLoc(),
321                       Twine("expected attribute ") + Twine(i));
322 
323     if (!name.empty())
324       os << "/*" << name << "=*/";
325     emitAttributeValue(defInit->getDef());
326     // TODO(jpienaar): verify types
327   }
328   os << "\n    );\n  }\n};\n";
329 }
330 
331 void Pattern::emit(StringRef rewriteName, Record *p, raw_ostream &os) {
332   Pattern pattern(p, os);
333   pattern.emit(rewriteName);
334 }
335 
336 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
337   emitSourceFileHeader("Rewriters", os);
338   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
339 
340   // Ensure unique patterns simply by appending unique suffix.
341   std::string baseRewriteName = "GeneratedConvert";
342   int rewritePatternCount = 0;
343   for (Record *p : patterns) {
344     Pattern::emit(baseRewriteName + llvm::utostr(rewritePatternCount++), p, os);
345   }
346 
347   // Emit function to add the generated matchers to the pattern list.
348   os << "void populateWithGenerated(MLIRContext *context, "
349      << "OwningRewritePatternList *patterns) {\n";
350   for (unsigned i = 0; i != rewritePatternCount; ++i) {
351     os.indent(2) << "patterns->push_back(std::make_unique<" << baseRewriteName
352                  << i << ">(context));\n";
353   }
354   os << "}\n";
355 }
356 
357 mlir::GenRegistration
358     genRewriters("gen-rewriters", "Generate pattern rewriters",
359                  [](const RecordKeeper &records, raw_ostream &os) {
360                    emitRewriters(records, os);
361                    return false;
362                  });
363