xref: /llvm-project/mlir/tools/mlir-tblgen/RewriterGen.cpp (revision a5827fc91d3c2ebe447a7f8e22ff7f2bfd9ba440)
1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/TableGen/Attribute.h"
23 #include "mlir/TableGen/GenInfo.h"
24 #include "mlir/TableGen/Operator.h"
25 #include "mlir/TableGen/Predicate.h"
26 #include "mlir/TableGen/Type.h"
27 #include "llvm/ADT/StringExtras.h"
28 #include "llvm/ADT/StringSet.h"
29 #include "llvm/Support/CommandLine.h"
30 #include "llvm/Support/FormatVariadic.h"
31 #include "llvm/Support/PrettyStackTrace.h"
32 #include "llvm/Support/Signals.h"
33 #include "llvm/TableGen/Error.h"
34 #include "llvm/TableGen/Main.h"
35 #include "llvm/TableGen/Record.h"
36 #include "llvm/TableGen/TableGenBackend.h"
37 
38 using namespace llvm;
39 using namespace mlir;
40 
41 using mlir::tblgen::Attribute;
42 using mlir::tblgen::Operator;
43 using mlir::tblgen::Type;
44 
45 namespace {
46 
47 // Wrapper around DAG argument.
48 struct DagArg {
49   DagArg(mlir::tblgen::Operator::Argument arg, Init *constraintInit)
50       : arg(arg), constraintInit(constraintInit) {}
51   bool isAttr();
52 
53   mlir::tblgen::Operator::Argument arg;
54   Init *constraintInit;
55 };
56 
57 } // end namespace
58 
59 bool DagArg::isAttr() { return arg.is<Operator::NamedAttribute *>(); }
60 
namespace {
// Generates the C++ RewritePattern subclass for a single TableGen `Pattern`
// record. An instance holds the per-pattern state (bound arguments, cached
// Operator wrappers, op count) and writes everything to `os`.
class Pattern {
public:
  // Emits the RewritePattern class for pattern record `p` to `os`, naming the
  // generated class `rewriteName`.
  static void emit(StringRef rewriteName, Record *p, raw_ostream &os);

private:
  // Only the static `emit` constructs instances, one per pattern record.
  Pattern(Record *pattern, raw_ostream &os) : pattern(pattern), os(os) {}

  // Emit the rewrite pattern named `rewriteName`.
  void emit(StringRef rewriteName);

  // Emit the matcher.
  void emitMatcher(DagInit *tree);

  // Emits the value of constant attribute to `os`.
  void emitAttributeValue(Record *constAttr);

  // Collect bound arguments (and count the ops) of the source pattern `tree`.
  void collectBoundArguments(DagInit *tree);

  // Helper function to match patterns; `depth` is the nesting level in the
  // source pattern.
  void matchOp(DagInit *tree, int depth);

  // Returns the Operator stored for the given record.
  Operator &getOperator(const llvm::Record *record);

  // Map from bound argument name to DagArg.
  StringMap<DagArg> boundArguments;

  // Map from Record* to Operator.
  // NOTE(review): DenseMap may move its values when it grows, so an
  // `Operator &` returned by getOperator() can be invalidated by a later
  // insertion. Presumably the Arguments stored in `boundArguments` also point
  // into these cached Operators (confirm in Operator.h); if so they share the
  // same hazard. Consider a container with stable element addresses.
  DenseMap<const llvm::Record *, Operator> opMap;

  // Number of the operations in the input pattern; also emitted as the
  // benefit of the generated RewritePattern.
  int numberOfOpsMatched = 0;

  // The TableGen `Pattern` record being generated.
  Record *pattern;
  // Output stream receiving the generated C++.
  raw_ostream &os;
};
} // end namespace
100 
101 // Returns the Operator stored for the given record.
102 auto Pattern::getOperator(const llvm::Record *record) -> Operator & {
103   return opMap.try_emplace(record, record).first->second;
104 }
105 
106 void Pattern::emitAttributeValue(Record *constAttr) {
107   Attribute attr(constAttr->getValueAsDef("attr"));
108   auto value = constAttr->getValue("value");
109 
110   if (!attr.isConstBuildable())
111     PrintFatalError(pattern->getLoc(),
112                     "Attribute " + attr.getTableGenDefName() +
113                         " does not have the 'constBuilderCall' field");
114 
115   // TODO(jpienaar): Verify the constants here
116   os << formatv(attr.getConstBuilderTemplate().str().c_str(), "rewriter",
117                 value->getValue()->getAsUnquotedString());
118 }
119 
120 void Pattern::collectBoundArguments(DagInit *tree) {
121   ++numberOfOpsMatched;
122   Operator &op = getOperator(cast<DefInit>(tree->getOperator())->getDef());
123   // TODO(jpienaar): Expand to multiple matches.
124   for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
125     auto arg = tree->getArg(i);
126     if (auto argTree = dyn_cast<DagInit>(arg)) {
127       collectBoundArguments(argTree);
128       continue;
129     }
130     auto name = tree->getArgNameStr(i);
131     if (name.empty())
132       continue;
133     boundArguments.try_emplace(name, op.getArg(i), arg);
134   }
135 }
136 
137 // Helper function to match patterns.
138 void Pattern::matchOp(DagInit *tree, int depth) {
139   Operator &op = getOperator(cast<DefInit>(tree->getOperator())->getDef());
140   int indent = 4 + 2 * depth;
141   // Skip the operand matching at depth 0 as the pattern rewriter already does.
142   if (depth != 0) {
143     // Skip if there is no defining instruction (e.g., arguments to function).
144     os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth);
145     os.indent(indent) << formatv(
146         "if (!op{0}->isa<{1}>()) return matchFailure();\n", depth,
147         op.qualifiedCppClassName());
148   }
149   if (tree->getNumArgs() != op.getNumArgs())
150     PrintFatalError(pattern->getLoc(),
151                     Twine("mismatch in number of arguments to op '") +
152                         op.getOperationName() +
153                         "' in pattern and op's definition");
154   for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
155     auto arg = tree->getArg(i);
156     auto opArg = op.getArg(i);
157 
158     if (auto argTree = dyn_cast<DagInit>(arg)) {
159       os.indent(indent) << "{\n";
160       os.indent(indent + 2) << formatv(
161           "auto op{0} = op{1}->getOperand({2})->getDefiningInst();\n",
162           depth + 1, depth, i);
163       matchOp(argTree, depth + 1);
164       os.indent(indent) << "}\n";
165       continue;
166     }
167 
168     // Verify arguments.
169     if (auto defInit = dyn_cast<DefInit>(arg)) {
170       // Verify operands.
171       if (auto *operand = opArg.dyn_cast<Operator::Operand *>()) {
172         // Skip verification where not needed due to definition of op.
173         if (operand->defInit == defInit)
174           goto StateCapture;
175 
176         if (!defInit->getDef()->isSubClassOf("Type"))
177           PrintFatalError(pattern->getLoc(),
178                           "type argument required for operand");
179 
180         auto constraint = tblgen::TypeConstraint(*defInit);
181         os.indent(indent)
182             << "if (!("
183             << formatv(constraint.getConditionTemplate().str().c_str(),
184                        formatv("op{0}->getOperand({1})->getType()", depth, i))
185             << ")) return matchFailure();\n";
186       }
187 
188       // TODO(jpienaar): Verify attributes.
189       if (auto *namedAttr = opArg.dyn_cast<Operator::NamedAttribute *>()) {
190         // TODO(jpienaar): move to helper class.
191         if (defInit->getDef()->isSubClassOf("mAttr")) {
192           auto pred =
193               tblgen::Pred(defInit->getDef()->getValueInit("predicate"));
194           os.indent(indent)
195               << "if (!("
196               << formatv(pred.getCondition().str().c_str(),
197                          formatv("op{0}->getAttrOfType<{1}>(\"{2}\")", depth,
198                                  namedAttr->attr.getStorageType(),
199                                  namedAttr->getName()))
200               << ")) return matchFailure();\n";
201         }
202       }
203     }
204 
205   StateCapture:
206     auto name = tree->getArgNameStr(i);
207     if (name.empty())
208       continue;
209     if (opArg.is<Operator::Operand *>())
210       os.indent(indent) << "state->" << name << " = op" << depth
211                         << "->getOperand(" << i << ");\n";
212     if (auto namedAttr = opArg.dyn_cast<Operator::NamedAttribute *>()) {
213       os.indent(indent) << "state->" << name << " = op" << depth
214                         << "->getAttrOfType<"
215                         << namedAttr->attr.getStorageType() << ">(\""
216                         << namedAttr->getName() << "\");\n";
217     }
218   }
219 }
220 
221 void Pattern::emitMatcher(DagInit *tree) {
222   // Emit the heading.
223   os << R"(
224   PatternMatchResult match(OperationInst *op0) const override {
225     // TODO: This just handle 1 result
226     if (op0->getNumResults() != 1) return matchFailure();
227     auto state = std::make_unique<MatchedState>();)"
228      << "\n";
229   matchOp(tree, 0);
230   os.indent(4) << "return matchSuccess(std::move(state));\n  }\n";
231 }
232 
233 void Pattern::emit(StringRef rewriteName) {
234   DagInit *tree = pattern->getValueAsDag("PatternToMatch");
235   // Collect bound arguments and compute number of ops matched.
236   // TODO(jpienaar): the benefit metric is simply number of ops matched at the
237   // moment, revise.
238   collectBoundArguments(tree);
239 
240   // Emit RewritePattern for Pattern.
241   DefInit *root = cast<DefInit>(tree->getOperator());
242   auto *rootName = cast<StringInit>(root->getDef()->getValueInit("opName"));
243   os << formatv(R"(struct {0} : public RewritePattern {
244   {0}(MLIRContext *context) : RewritePattern({1}, {2}, context) {{})",
245                 rewriteName, rootName->getAsString(), numberOfOpsMatched)
246      << "\n";
247 
248   // Emit matched state.
249   os << "  struct MatchedState : public PatternState {\n";
250   for (auto &arg : boundArguments) {
251     if (auto namedAttr =
252             arg.second.arg.dyn_cast<Operator::NamedAttribute *>()) {
253       os.indent(4) << namedAttr->attr.getStorageType() << " " << arg.first()
254                    << ";\n";
255     } else {
256       os.indent(4) << "Value* " << arg.first() << ";\n";
257     }
258   }
259   os << "  };\n";
260 
261   emitMatcher(tree);
262   ListInit *resultOps = pattern->getValueAsListInit("ResultOps");
263   if (resultOps->size() != 1)
264     PrintFatalError("only single result rules supported");
265   DagInit *resultTree = cast<DagInit>(resultOps->getElement(0));
266 
267   // TODO(jpienaar): Expand to multiple results.
268   for (auto result : resultTree->getArgs()) {
269     if (isa<DagInit>(result))
270       PrintFatalError(pattern->getLoc(), "only single op result supported");
271   }
272 
273   DefInit *resultRoot = cast<DefInit>(resultTree->getOperator());
274   Operator &resultOp = getOperator(resultRoot->getDef());
275   auto resultOperands = resultRoot->getDef()->getValueAsDag("arguments");
276 
277   os << formatv(R"(
278   void rewrite(OperationInst *op, std::unique_ptr<PatternState> state,
279                PatternRewriter &rewriter) const override {
280     auto& s = *static_cast<MatchedState *>(state.get());
281     rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())",
282                 resultOp.cppClassName());
283   if (resultOperands->getNumArgs() != resultTree->getNumArgs()) {
284     PrintFatalError(pattern->getLoc(),
285                     Twine("mismatch between arguments of resultant op (") +
286                         Twine(resultOperands->getNumArgs()) +
287                         ") and arguments provided for rewrite (" +
288                         Twine(resultTree->getNumArgs()) + Twine(')'));
289   }
290 
291   // Create the builder call for the result.
292   // Add operands.
293   int i = 0;
294   for (auto operand : resultOp.getOperands()) {
295     // Start each operand on its own line.
296     (os << ",\n").indent(6);
297 
298     auto name = resultTree->getArgNameStr(i);
299     if (boundArguments.find(name) == boundArguments.end())
300       PrintFatalError(pattern->getLoc(),
301                       Twine("referencing unbound variable '") + name + "'");
302     if (operand.name)
303       os << "/*" << operand.name->getAsUnquotedString() << "=*/";
304     os << "s." << name;
305     // TODO(jpienaar): verify types
306     ++i;
307   }
308 
309   // Add attributes.
310   for (int e = resultTree->getNumArgs(); i != e; ++i) {
311     // Start each attribute on its own line.
312     (os << ",\n").indent(6);
313 
314     // The argument in the result DAG pattern.
315     auto name = resultTree->getArgNameStr(i);
316     auto opName = resultOp.getArgName(i);
317     auto defInit = dyn_cast<DefInit>(resultTree->getArg(i));
318     auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr;
319     if (!value) {
320       if (boundArguments.find(name) == boundArguments.end())
321         PrintFatalError(pattern->getLoc(),
322                         Twine("referencing unbound variable '") + name + "'");
323       auto result = "s." + name;
324       os << "/*" << opName << "=*/";
325       if (defInit) {
326         auto transform = defInit->getDef();
327         if (transform->isSubClassOf("tAttr")) {
328           // TODO(jpienaar): move to helper class.
329           os << formatv(
330               transform->getValueAsString("attrTransform").str().c_str(),
331               result);
332           continue;
333         }
334       }
335       os << result;
336       continue;
337     }
338 
339     // TODO(jpienaar): Refactor out into map to avoid recomputing these.
340     auto argument = resultOp.getArg(i);
341     if (!argument.is<Operator::NamedAttribute *>())
342       PrintFatalError(pattern->getLoc(),
343                       Twine("expected attribute ") + Twine(i));
344 
345     if (!name.empty())
346       os << "/*" << name << "=*/";
347     emitAttributeValue(defInit->getDef());
348     // TODO(jpienaar): verify types
349   }
350   os << "\n    );\n  }\n};\n";
351 }
352 
353 void Pattern::emit(StringRef rewriteName, Record *p, raw_ostream &os) {
354   Pattern pattern(p, os);
355   pattern.emit(rewriteName);
356 }
357 
358 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
359   emitSourceFileHeader("Rewriters", os);
360   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
361 
362   // Ensure unique patterns simply by appending unique suffix.
363   std::string baseRewriteName = "GeneratedConvert";
364   int rewritePatternCount = 0;
365   for (Record *p : patterns) {
366     Pattern::emit(baseRewriteName + llvm::utostr(rewritePatternCount++), p, os);
367   }
368 
369   // Emit function to add the generated matchers to the pattern list.
370   os << "void populateWithGenerated(MLIRContext *context, "
371      << "OwningRewritePatternList *patterns) {\n";
372   for (unsigned i = 0; i != rewritePatternCount; ++i) {
373     os.indent(2) << "patterns->push_back(std::make_unique<" << baseRewriteName
374                  << i << ">(context));\n";
375   }
376   os << "}\n";
377 }
378 
379 mlir::GenRegistration
380     genRewriters("gen-rewriters", "Generate pattern rewriters",
381                  [](const RecordKeeper &records, raw_ostream &os) {
382                    emitRewriters(records, os);
383                    return false;
384                  });
385