xref: /llvm-project/mlir/tools/mlir-tblgen/RewriterGen.cpp (revision ad637f3ccee4be5aa9f738d6665d0b1a326613a7)
1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/TableGen/Attribute.h"
23 #include "mlir/TableGen/GenInfo.h"
24 #include "mlir/TableGen/Operator.h"
25 #include "mlir/TableGen/Pattern.h"
26 #include "mlir/TableGen/Predicate.h"
27 #include "mlir/TableGen/Type.h"
28 #include "llvm/ADT/StringExtras.h"
29 #include "llvm/ADT/StringSet.h"
30 #include "llvm/Support/CommandLine.h"
31 #include "llvm/Support/FormatVariadic.h"
32 #include "llvm/Support/PrettyStackTrace.h"
33 #include "llvm/Support/Signals.h"
34 #include "llvm/TableGen/Error.h"
35 #include "llvm/TableGen/Main.h"
36 #include "llvm/TableGen/Record.h"
37 #include "llvm/TableGen/TableGenBackend.h"
38 
39 using namespace llvm;
40 using namespace mlir;
41 
42 using mlir::tblgen::Argument;
43 using mlir::tblgen::Attribute;
44 using mlir::tblgen::DagNode;
45 using mlir::tblgen::NamedAttribute;
46 using mlir::tblgen::Operand;
47 using mlir::tblgen::Operator;
48 using mlir::tblgen::Pattern;
49 using mlir::tblgen::RecordOperatorMap;
50 using mlir::tblgen::Type;
51 
52 namespace {
53 
54 // Wrapper around DAG argument.
55 struct DagArg {
56   DagArg(Argument arg, Init *constraintInit)
57       : arg(arg), constraintInit(constraintInit) {}
58   bool isAttr();
59 
60   Argument arg;
61   Init *constraintInit;
62 };
63 
64 } // end namespace
65 
66 bool DagArg::isAttr() { return arg.is<NamedAttribute *>(); }
67 
68 namespace {
69 class PatternEmitter {
70 public:
71   static void emit(StringRef rewriteName, Record *p, RecordOperatorMap *mapper,
72                    raw_ostream &os);
73 
74 private:
75   PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os)
76       : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), os(os) {}
77 
78   // Emits the mlir::RewritePattern struct named `rewriteName`.
79   void emit(StringRef rewriteName);
80 
81   // Emits the match() method.
82   void emitMatchMethod(DagNode tree);
83 
84   // Emits the rewrite() method.
85   void emitRewriteMethod();
86 
87   // Emits the C++ statement to replace the matched DAG with an existing value.
88   void emitReplaceWithExistingValue(DagNode resultTree);
89   // Emits the C++ statement to replace the matched DAG with a new op.
90   void emitReplaceOpWithNewOp(DagNode resultTree);
91 
92   // Emits the value of constant attribute to `os`.
93   void emitAttributeValue(Record *constAttr);
94 
95   // Emits C++ statements for matching the op constrained by the given DAG
96   // `tree`.
97   void emitOpMatch(DagNode tree, int depth);
98 
99 private:
100   // Pattern instantiation location followed by the location of multiclass
101   // prototypes used. This is intended to be used as a whole to
102   // PrintFatalError() on errors.
103   ArrayRef<llvm::SMLoc> loc;
104   // Op's TableGen Record to wrapper object
105   RecordOperatorMap *opMap;
106   // Handy wrapper for pattern being emitted
107   Pattern pattern;
108   raw_ostream &os;
109 };
110 } // end namespace
111 
112 void PatternEmitter::emitAttributeValue(Record *constAttr) {
113   Attribute attr(constAttr->getValueAsDef("attr"));
114   auto value = constAttr->getValue("value");
115 
116   if (!attr.isConstBuildable())
117     PrintFatalError(loc, "Attribute " + attr.getTableGenDefName() +
118                              " does not have the 'constBuilderCall' field");
119 
120   // TODO(jpienaar): Verify the constants here
121   os << formatv(attr.getConstBuilderTemplate().str().c_str(), "rewriter",
122                 value->getValue()->getAsUnquotedString());
123 }
124 
125 // Helper function to match patterns.
126 void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
127   Operator &op = tree.getDialectOp(opMap);
128   int indent = 4 + 2 * depth;
129   // Skip the operand matching at depth 0 as the pattern rewriter already does.
130   if (depth != 0) {
131     // Skip if there is no defining instruction (e.g., arguments to function).
132     os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth);
133     os.indent(indent) << formatv(
134         "if (!op{0}->isa<{1}>()) return matchFailure();\n", depth,
135         op.qualifiedCppClassName());
136   }
137   if (tree.getNumArgs() != op.getNumArgs())
138     PrintFatalError(loc, Twine("mismatch in number of arguments to op '") +
139                              op.getOperationName() +
140                              "' in pattern and op's definition");
141   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
142     auto opArg = op.getArg(i);
143 
144     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
145       os.indent(indent) << "{\n";
146       os.indent(indent + 2) << formatv(
147           "auto op{0} = op{1}->getOperand({2})->getDefiningInst();\n",
148           depth + 1, depth, i);
149       emitOpMatch(argTree, depth + 1);
150       os.indent(indent) << "}\n";
151       continue;
152     }
153 
154     // Verify arguments.
155     if (auto defInit = tree.getArgAsDefInit(i)) {
156       // Verify operands.
157       if (auto *operand = opArg.dyn_cast<Operand *>()) {
158         // Skip verification where not needed due to definition of op.
159         if (operand->defInit == defInit)
160           goto StateCapture;
161 
162         if (!defInit->getDef()->isSubClassOf("Type"))
163           PrintFatalError(loc, "type argument required for operand");
164 
165         auto constraint = tblgen::TypeConstraint(*defInit);
166         os.indent(indent)
167             << "if (!("
168             << formatv(constraint.getConditionTemplate().c_str(),
169                        formatv("op{0}->getOperand({1})->getType()", depth, i))
170             << ")) return matchFailure();\n";
171       }
172 
173       // TODO(jpienaar): Verify attributes.
174       if (auto *namedAttr = opArg.dyn_cast<NamedAttribute *>()) {
175         auto constraint = tblgen::AttrConstraint(defInit);
176         std::string condition = formatv(
177             constraint.getConditionTemplate().c_str(),
178             formatv("op{0}->getAttrOfType<{1}>(\"{2}\")", depth,
179                     namedAttr->attr.getStorageType(), namedAttr->getName()));
180         os.indent(indent) << "if (!(" << condition
181                           << ")) return matchFailure();\n";
182       }
183     }
184 
185   StateCapture:
186     auto name = tree.getArgName(i);
187     if (name.empty())
188       continue;
189     if (opArg.is<Operand *>())
190       os.indent(indent) << "state->" << name << " = op" << depth
191                         << "->getOperand(" << i << ");\n";
192     if (auto namedAttr = opArg.dyn_cast<NamedAttribute *>()) {
193       os.indent(indent) << "state->" << name << " = op" << depth
194                         << "->getAttrOfType<"
195                         << namedAttr->attr.getStorageType() << ">(\""
196                         << namedAttr->getName() << "\");\n";
197     }
198   }
199 }
200 
201 void PatternEmitter::emitMatchMethod(DagNode tree) {
202   // Emit the heading.
203   os << R"(
204   PatternMatchResult match(OperationInst *op0) const override {
205     // TODO: This just handle 1 result
206     if (op0->getNumResults() != 1) return matchFailure();
207     auto ctx = op0->getContext(); (void)ctx;
208     auto state = std::make_unique<MatchedState>();)"
209      << "\n";
210   emitOpMatch(tree, 0);
211   os.indent(4) << "return matchSuccess(std::move(state));\n  }\n";
212 }
213 
214 void PatternEmitter::emit(StringRef rewriteName) {
215   // Get the DAG tree for the source pattern
216   DagNode tree = pattern.getSourcePattern();
217 
218   // TODO(jpienaar): the benefit metric is simply number of ops matched at the
219   // moment, revise.
220   unsigned benefit = tree.getNumOps();
221 
222   const Operator &rootOp = pattern.getSourceRootOp();
223   auto rootName = rootOp.getOperationName();
224 
225   // Emit RewritePattern for Pattern.
226   os << formatv(R"(struct {0} : public RewritePattern {
227   {0}(MLIRContext *context) : RewritePattern("{1}", {2}, context) {{})",
228                 rewriteName, rootName, benefit)
229      << "\n";
230 
231   // Emit matched state.
232   os << "  struct MatchedState : public PatternState {\n";
233   for (const auto &arg : pattern.getSourcePatternBoundArgs()) {
234     if (auto namedAttr = arg.second.arg.dyn_cast<NamedAttribute *>()) {
235       os.indent(4) << namedAttr->attr.getStorageType() << " " << arg.first()
236                    << ";\n";
237     } else {
238       os.indent(4) << "Value* " << arg.first() << ";\n";
239     }
240   }
241   os << "  };\n";
242 
243   emitMatchMethod(tree);
244   emitRewriteMethod();
245 
246   os << "};\n";
247 }
248 
249 void PatternEmitter::emitRewriteMethod() {
250   if (pattern.getNumResults() != 1)
251     PrintFatalError("only single result rules supported");
252 
253   DagNode resultTree = pattern.getResultPattern(0);
254 
255   // TODO(jpienaar): Expand to multiple results.
256   for (unsigned i = 0, e = resultTree.getNumArgs(); i != e; ++i)
257     if (resultTree.getArgAsNestedDag(i))
258       PrintFatalError(loc, "only single op result supported");
259 
260   os << R"(
261   void rewrite(OperationInst *op, std::unique_ptr<PatternState> state,
262                PatternRewriter &rewriter) const override {
263     auto& s = *static_cast<MatchedState *>(state.get());
264 )";
265 
266   if (resultTree.isReplaceWithValue())
267     emitReplaceWithExistingValue(resultTree);
268   else
269     emitReplaceOpWithNewOp(resultTree);
270 
271   os << "  }\n";
272 }
273 
274 void PatternEmitter::emitReplaceWithExistingValue(DagNode resultTree) {
275   if (resultTree.getNumArgs() != 1) {
276     PrintFatalError(loc, "exactly one argument needed in the result pattern");
277   }
278 
279   auto name = resultTree.getArgName(0);
280   pattern.ensureArgBoundInSourcePattern(name);
281   os.indent(4) << "rewriter.replaceOp(op, {s." << name << "});\n";
282 }
283 
284 void PatternEmitter::emitReplaceOpWithNewOp(DagNode resultTree) {
285   Operator &resultOp = resultTree.getDialectOp(opMap);
286   auto numOpArgs =
287       resultOp.getNumOperands() + resultOp.getNumNativeAttributes();
288 
289   os << formatv(R"(
290     rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())",
291                 resultOp.cppClassName());
292   if (numOpArgs != resultTree.getNumArgs()) {
293     PrintFatalError(loc, Twine("mismatch between arguments of resultant op (") +
294                              Twine(numOpArgs) +
295                              ") and arguments provided for rewrite (" +
296                              Twine(resultTree.getNumArgs()) + Twine(')'));
297   }
298 
299   // Create the builder call for the result.
300   // Add operands.
301   int i = 0;
302   for (auto operand : resultOp.getOperands()) {
303     // Start each operand on its own line.
304     (os << ",\n").indent(6);
305 
306     auto name = resultTree.getArgName(i);
307     pattern.ensureArgBoundInSourcePattern(name);
308     if (operand.name)
309       os << "/*" << operand.name->getAsUnquotedString() << "=*/";
310     os << "s." << name;
311     // TODO(jpienaar): verify types
312     ++i;
313   }
314 
315   // Add attributes.
316   for (int e = resultTree.getNumArgs(); i != e; ++i) {
317     // Start each attribute on its own line.
318     (os << ",\n").indent(6);
319 
320     // The argument in the result DAG pattern.
321     auto argName = resultTree.getArgName(i);
322     auto opName = resultOp.getArgName(i);
323     auto *defInit = resultTree.getArgAsDefInit(i);
324     auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr;
325     if (!value) {
326       pattern.ensureArgBoundInSourcePattern(argName);
327       auto result = "s." + argName;
328       os << "/*" << opName << "=*/";
329       if (defInit) {
330         auto transform = defInit->getDef();
331         if (transform->isSubClassOf("tAttr")) {
332           // TODO(jpienaar): move to helper class.
333           os << formatv(
334               transform->getValueAsString("attrTransform").str().c_str(),
335               result);
336           continue;
337         }
338       }
339       os << result;
340       continue;
341     }
342 
343     // TODO(jpienaar): Refactor out into map to avoid recomputing these.
344     auto argument = resultOp.getArg(i);
345     if (!argument.is<NamedAttribute *>())
346       PrintFatalError(loc, Twine("expected attribute ") + Twine(i));
347 
348     if (!argName.empty())
349       os << "/*" << argName << "=*/";
350     emitAttributeValue(defInit->getDef());
351     // TODO(jpienaar): verify types
352   }
353   os << "\n    );\n";
354 }
355 
356 void PatternEmitter::emit(StringRef rewriteName, Record *p,
357                           RecordOperatorMap *mapper, raw_ostream &os) {
358   PatternEmitter(p, mapper, os).emit(rewriteName);
359 }
360 
361 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
362   emitSourceFileHeader("Rewriters", os);
363   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
364 
365   // We put the map here because it can be shared among multiple patterns.
366   RecordOperatorMap recordOpMap;
367 
368   // Ensure unique patterns simply by appending unique suffix.
369   std::string baseRewriteName = "GeneratedConvert";
370   int rewritePatternCount = 0;
371   for (Record *p : patterns) {
372     PatternEmitter::emit(baseRewriteName + llvm::utostr(rewritePatternCount++),
373                          p, &recordOpMap, os);
374   }
375 
376   // Emit function to add the generated matchers to the pattern list.
377   os << "void populateWithGenerated(MLIRContext *context, "
378      << "OwningRewritePatternList *patterns) {\n";
379   for (unsigned i = 0; i != rewritePatternCount; ++i) {
380     os.indent(2) << "patterns->push_back(std::make_unique<" << baseRewriteName
381                  << i << ">(context));\n";
382   }
383   os << "}\n";
384 }
385 
386 static mlir::GenRegistration
387     genRewriters("gen-rewriters", "Generate pattern rewriters",
388                  [](const RecordKeeper &records, raw_ostream &os) {
389                    emitRewriters(records, os);
390                    return false;
391                  });
392