xref: /llvm-project/mlir/tools/mlir-tblgen/RewriterGen.cpp (revision 0f9436e56a2c366c839cf43784c1c7a74f5c9ee3)
1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/TableGen/Attribute.h"
23 #include "mlir/TableGen/GenInfo.h"
24 #include "mlir/TableGen/Operator.h"
25 #include "mlir/TableGen/Pattern.h"
26 #include "mlir/TableGen/Predicate.h"
27 #include "mlir/TableGen/Type.h"
28 #include "llvm/ADT/StringExtras.h"
29 #include "llvm/ADT/StringSet.h"
30 #include "llvm/Support/CommandLine.h"
31 #include "llvm/Support/FormatVariadic.h"
32 #include "llvm/Support/PrettyStackTrace.h"
33 #include "llvm/Support/Signals.h"
34 #include "llvm/TableGen/Error.h"
35 #include "llvm/TableGen/Main.h"
36 #include "llvm/TableGen/Record.h"
37 #include "llvm/TableGen/TableGenBackend.h"
38 
39 using namespace llvm;
40 using namespace mlir;
41 
42 using mlir::tblgen::Argument;
43 using mlir::tblgen::Attribute;
44 using mlir::tblgen::DagNode;
45 using mlir::tblgen::NamedAttribute;
46 using mlir::tblgen::Operand;
47 using mlir::tblgen::Operator;
48 using mlir::tblgen::Pattern;
49 using mlir::tblgen::RecordOperatorMap;
50 using mlir::tblgen::Type;
51 
52 namespace {
53 
54 // Wrapper around DAG argument.
55 struct DagArg {
56   DagArg(Argument arg, Init *constraintInit)
57       : arg(arg), constraintInit(constraintInit) {}
58   bool isAttr();
59 
60   Argument arg;
61   Init *constraintInit;
62 };
63 
64 } // end namespace
65 
66 bool DagArg::isAttr() { return arg.is<NamedAttribute *>(); }
67 
68 namespace {
69 class PatternEmitter {
70 public:
71   static void emit(StringRef rewriteName, Record *p, RecordOperatorMap *mapper,
72                    raw_ostream &os);
73 
74 private:
75   PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os)
76       : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), os(os) {}
77 
78   // Emits the mlir::RewritePattern struct named `rewriteName`.
79   void emit(StringRef rewriteName);
80 
81   // Emits the match() method.
82   void emitMatchMethod(DagNode tree);
83 
84   // Emits the rewrite() method.
85   void emitRewriteMethod();
86 
87   // Emits the C++ statement to replace the matched DAG with an existing value.
88   void emitReplaceWithExistingValue(DagNode resultTree);
89   // Emits the C++ statement to replace the matched DAG with a new op.
90   void emitReplaceOpWithNewOp(DagNode resultTree);
91 
92   // Emits the value of constant attribute to `os`.
93   void emitAttributeValue(Record *constAttr);
94 
95   // Emits C++ statements for matching the op constrained by the given DAG
96   // `tree`.
97   void emitOpMatch(DagNode tree, int depth);
98 
99 private:
100   // Pattern instantiation location followed by the location of multiclass
101   // prototypes used. This is intended to be used as a whole to
102   // PrintFatalError() on errors.
103   ArrayRef<llvm::SMLoc> loc;
104   // Op's TableGen Record to wrapper object
105   RecordOperatorMap *opMap;
106   // Handy wrapper for pattern being emitted
107   Pattern pattern;
108   raw_ostream &os;
109 };
110 } // end namespace
111 
112 void PatternEmitter::emitAttributeValue(Record *constAttr) {
113   Attribute attr(constAttr->getValueAsDef("attr"));
114   auto value = constAttr->getValue("value");
115 
116   if (!attr.isConstBuildable())
117     PrintFatalError(loc, "Attribute " + attr.getTableGenDefName() +
118                              " does not have the 'constBuilderCall' field");
119 
120   // TODO(jpienaar): Verify the constants here
121   os << formatv(attr.getConstBuilderTemplate().str().c_str(), "rewriter",
122                 value->getValue()->getAsUnquotedString());
123 }
124 
125 // Helper function to match patterns.
126 void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
127   Operator &op = tree.getDialectOp(opMap);
128   int indent = 4 + 2 * depth;
129   // Skip the operand matching at depth 0 as the pattern rewriter already does.
130   if (depth != 0) {
131     // Skip if there is no defining instruction (e.g., arguments to function).
132     os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth);
133     os.indent(indent) << formatv(
134         "if (!op{0}->isa<{1}>()) return matchFailure();\n", depth,
135         op.qualifiedCppClassName());
136   }
137   if (tree.getNumArgs() != op.getNumArgs())
138     PrintFatalError(loc, Twine("mismatch in number of arguments to op '") +
139                              op.getOperationName() +
140                              "' in pattern and op's definition");
141   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
142     auto opArg = op.getArg(i);
143 
144     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
145       os.indent(indent) << "{\n";
146       os.indent(indent + 2) << formatv(
147           "auto op{0} = op{1}->getOperand({2})->getDefiningInst();\n",
148           depth + 1, depth, i);
149       emitOpMatch(argTree, depth + 1);
150       os.indent(indent) << "}\n";
151       continue;
152     }
153 
154     // Verify arguments.
155     if (auto defInit = tree.getArgAsDefInit(i)) {
156       // Verify operands.
157       if (auto *operand = opArg.dyn_cast<Operand *>()) {
158         // Skip verification where not needed due to definition of op.
159         if (operand->defInit == defInit)
160           goto StateCapture;
161 
162         if (!defInit->getDef()->isSubClassOf("Type"))
163           PrintFatalError(loc, "type argument required for operand");
164 
165         auto constraint = tblgen::TypeConstraint(*defInit);
166         os.indent(indent)
167             << "if (!("
168             << formatv(constraint.getConditionTemplate().c_str(),
169                        formatv("op{0}->getOperand({1})->getType()", depth, i))
170             << ")) return matchFailure();\n";
171       }
172 
173       // TODO(jpienaar): Verify attributes.
174       if (auto *namedAttr = opArg.dyn_cast<NamedAttribute *>()) {
175         auto constraint = tblgen::AttrConstraint(defInit);
176         std::string condition = formatv(
177             constraint.getConditionTemplate().c_str(),
178             formatv("op{0}->getAttrOfType<{1}>(\"{2}\")", depth,
179                     namedAttr->attr.getStorageType(), namedAttr->getName()));
180         os.indent(indent) << "if (!(" << condition
181                           << ")) return matchFailure();\n";
182       }
183     }
184 
185   StateCapture:
186     auto name = tree.getArgName(i);
187     if (name.empty())
188       continue;
189     if (opArg.is<Operand *>())
190       os.indent(indent) << "state->" << name << " = op" << depth
191                         << "->getOperand(" << i << ");\n";
192     if (auto namedAttr = opArg.dyn_cast<NamedAttribute *>()) {
193       os.indent(indent) << "state->" << name << " = op" << depth
194                         << "->getAttrOfType<"
195                         << namedAttr->attr.getStorageType() << ">(\""
196                         << namedAttr->getName() << "\");\n";
197     }
198   }
199 }
200 
201 void PatternEmitter::emitMatchMethod(DagNode tree) {
202   // Emit the heading.
203   os << R"(
204   PatternMatchResult match(OperationInst *op0) const override {
205     // TODO: This just handle 1 result
206     if (op0->getNumResults() != 1) return matchFailure();
207     auto state = std::make_unique<MatchedState>();)"
208      << "\n";
209   emitOpMatch(tree, 0);
210   os.indent(4) << "return matchSuccess(std::move(state));\n  }\n";
211 }
212 
213 void PatternEmitter::emit(StringRef rewriteName) {
214   // Get the DAG tree for the source pattern
215   DagNode tree = pattern.getSourcePattern();
216 
217   // TODO(jpienaar): the benefit metric is simply number of ops matched at the
218   // moment, revise.
219   unsigned benefit = tree.getNumOps();
220 
221   const Operator &rootOp = pattern.getSourceRootOp();
222   auto rootName = rootOp.getOperationName();
223 
224   // Emit RewritePattern for Pattern.
225   os << formatv(R"(struct {0} : public RewritePattern {
226   {0}(MLIRContext *context) : RewritePattern("{1}", {2}, context) {{})",
227                 rewriteName, rootName, benefit)
228      << "\n";
229 
230   // Emit matched state.
231   os << "  struct MatchedState : public PatternState {\n";
232   for (const auto &arg : pattern.getSourcePatternBoundArgs()) {
233     if (auto namedAttr = arg.second.arg.dyn_cast<NamedAttribute *>()) {
234       os.indent(4) << namedAttr->attr.getStorageType() << " " << arg.first()
235                    << ";\n";
236     } else {
237       os.indent(4) << "Value* " << arg.first() << ";\n";
238     }
239   }
240   os << "  };\n";
241 
242   emitMatchMethod(tree);
243   emitRewriteMethod();
244 
245   os << "};\n";
246 }
247 
248 void PatternEmitter::emitRewriteMethod() {
249   if (pattern.getNumResults() != 1)
250     PrintFatalError("only single result rules supported");
251 
252   DagNode resultTree = pattern.getResultPattern(0);
253 
254   // TODO(jpienaar): Expand to multiple results.
255   for (unsigned i = 0, e = resultTree.getNumArgs(); i != e; ++i)
256     if (resultTree.getArgAsNestedDag(i))
257       PrintFatalError(loc, "only single op result supported");
258 
259   os << R"(
260   void rewrite(OperationInst *op, std::unique_ptr<PatternState> state,
261                PatternRewriter &rewriter) const override {
262     auto& s = *static_cast<MatchedState *>(state.get());
263 )";
264 
265   if (resultTree.isReplaceWithValue())
266     emitReplaceWithExistingValue(resultTree);
267   else
268     emitReplaceOpWithNewOp(resultTree);
269 
270   os << "  }\n";
271 }
272 
273 void PatternEmitter::emitReplaceWithExistingValue(DagNode resultTree) {
274   if (resultTree.getNumArgs() != 1) {
275     PrintFatalError(loc, "exactly one argument needed in the result pattern");
276   }
277 
278   auto name = resultTree.getArgName(0);
279   pattern.ensureArgBoundInSourcePattern(name);
280   os.indent(4) << "rewriter.replaceOp(op, {s." << name << "});\n";
281 }
282 
283 void PatternEmitter::emitReplaceOpWithNewOp(DagNode resultTree) {
284   Operator &resultOp = resultTree.getDialectOp(opMap);
285   auto numOpArgs =
286       resultOp.getNumOperands() + resultOp.getNumNativeAttributes();
287 
288   os << formatv(R"(
289     rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())",
290                 resultOp.cppClassName());
291   if (numOpArgs != resultTree.getNumArgs()) {
292     PrintFatalError(loc, Twine("mismatch between arguments of resultant op (") +
293                              Twine(numOpArgs) +
294                              ") and arguments provided for rewrite (" +
295                              Twine(resultTree.getNumArgs()) + Twine(')'));
296   }
297 
298   // Create the builder call for the result.
299   // Add operands.
300   int i = 0;
301   for (auto operand : resultOp.getOperands()) {
302     // Start each operand on its own line.
303     (os << ",\n").indent(6);
304 
305     auto name = resultTree.getArgName(i);
306     pattern.ensureArgBoundInSourcePattern(name);
307     if (operand.name)
308       os << "/*" << operand.name->getAsUnquotedString() << "=*/";
309     os << "s." << name;
310     // TODO(jpienaar): verify types
311     ++i;
312   }
313 
314   // Add attributes.
315   for (int e = resultTree.getNumArgs(); i != e; ++i) {
316     // Start each attribute on its own line.
317     (os << ",\n").indent(6);
318 
319     // The argument in the result DAG pattern.
320     auto argName = resultTree.getArgName(i);
321     auto opName = resultOp.getArgName(i);
322     auto *defInit = resultTree.getArgAsDefInit(i);
323     auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr;
324     if (!value) {
325       pattern.ensureArgBoundInSourcePattern(argName);
326       auto result = "s." + argName;
327       os << "/*" << opName << "=*/";
328       if (defInit) {
329         auto transform = defInit->getDef();
330         if (transform->isSubClassOf("tAttr")) {
331           // TODO(jpienaar): move to helper class.
332           os << formatv(
333               transform->getValueAsString("attrTransform").str().c_str(),
334               result);
335           continue;
336         }
337       }
338       os << result;
339       continue;
340     }
341 
342     // TODO(jpienaar): Refactor out into map to avoid recomputing these.
343     auto argument = resultOp.getArg(i);
344     if (!argument.is<NamedAttribute *>())
345       PrintFatalError(loc, Twine("expected attribute ") + Twine(i));
346 
347     if (!argName.empty())
348       os << "/*" << argName << "=*/";
349     emitAttributeValue(defInit->getDef());
350     // TODO(jpienaar): verify types
351   }
352   os << "\n    );\n";
353 }
354 
355 void PatternEmitter::emit(StringRef rewriteName, Record *p,
356                           RecordOperatorMap *mapper, raw_ostream &os) {
357   PatternEmitter(p, mapper, os).emit(rewriteName);
358 }
359 
360 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
361   emitSourceFileHeader("Rewriters", os);
362   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
363 
364   // We put the map here because it can be shared among multiple patterns.
365   RecordOperatorMap recordOpMap;
366 
367   // Ensure unique patterns simply by appending unique suffix.
368   std::string baseRewriteName = "GeneratedConvert";
369   int rewritePatternCount = 0;
370   for (Record *p : patterns) {
371     PatternEmitter::emit(baseRewriteName + llvm::utostr(rewritePatternCount++),
372                          p, &recordOpMap, os);
373   }
374 
375   // Emit function to add the generated matchers to the pattern list.
376   os << "void populateWithGenerated(MLIRContext *context, "
377      << "OwningRewritePatternList *patterns) {\n";
378   for (unsigned i = 0; i != rewritePatternCount; ++i) {
379     os.indent(2) << "patterns->push_back(std::make_unique<" << baseRewriteName
380                  << i << ">(context));\n";
381   }
382   os << "}\n";
383 }
384 
385 static mlir::GenRegistration
386     genRewriters("gen-rewriters", "Generate pattern rewriters",
387                  [](const RecordKeeper &records, raw_ostream &os) {
388                    emitRewriters(records, os);
389                    return false;
390                  });
391