xref: /llvm-project/mlir/tools/mlir-tblgen/RewriterGen.cpp (revision bbe3f4d9f5047abb5366c4d94c5c0199d8616b3f)
1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/TableGen/GenInfo.h"
23 #include "mlir/TableGen/Operator.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/ADT/StringSet.h"
26 #include "llvm/Support/CommandLine.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "llvm/Support/PrettyStackTrace.h"
29 #include "llvm/Support/Signals.h"
30 #include "llvm/TableGen/Error.h"
31 #include "llvm/TableGen/Main.h"
32 #include "llvm/TableGen/Record.h"
33 #include "llvm/TableGen/TableGenBackend.h"
34 
35 using namespace llvm;
36 using namespace mlir;
37 
38 namespace {
39 
40 // Wrapper around dag argument.
41 struct DagArg {
42   DagArg(Init *init) : init(init){};
43   bool isAttr();
44 
45   Init *init;
46 };
47 
48 } // end namespace
49 
50 bool DagArg::isAttr() {
51   if (auto defInit = dyn_cast<DefInit>(init))
52     return defInit->getDef()->isSubClassOf("Attr");
53   return false;
54 }
55 
56 namespace {
57 class Pattern {
58 public:
59   static void emit(StringRef rewriteName, Record *p, raw_ostream &os);
60 
61 private:
62   Pattern(Record *pattern, raw_ostream &os) : pattern(pattern), os(os){};
63 
64   // Emit the rewrite pattern named `rewriteName`.
65   void emit(StringRef rewriteName);
66 
67   // Emit the matcher.
68   void emitMatcher(DagInit *tree);
69 
70   // Emits the value of constant attribute to `os`.
71   void emitAttributeValue(Record *constAttr);
72 
73   // Collect bound arguments.
74   void collectBoundArguments(DagInit *tree);
75 
76   // Map from bound argument name to DagArg.
77   StringMap<DagArg> boundArguments;
78 
79   // Number of the operations in the input pattern.
80   int numberOfOpsMatched = 0;
81 
82   Record *pattern;
83   raw_ostream &os;
84 };
85 } // end namespace
86 
87 void Pattern::emitAttributeValue(Record *constAttr) {
88   Record *type = constAttr->getValueAsDef("type");
89   auto value = constAttr->getValue("value");
90 
91   // Construct the attribute based on `type`.
92   // TODO(jpienaar): Generalize this to avoid hardcoding here.
93   if (type->isSubClassOf("F")) {
94     string mlirType;
95     switch (type->getValueAsInt("bitwidth")) {
96     case 32:
97       mlirType = "Type::getF32(context)";
98       break;
99     default:
100       PrintFatalError("unsupported floating point width");
101     }
102     // TODO(jpienaar): Verify the floating point constant here.
103     os << formatv("FloatAttr::get({0}, {1})", mlirType,
104                   value->getValue()->getAsUnquotedString());
105     return;
106   }
107 
108   // Fallback to the type of value.
109   switch (value->getType()->getRecTyKind()) {
110   case RecTy::IntRecTyKind:
111     // TODO(jpienaar): This is using 64-bits for all the bitwidth of the
112     // type could instead be queried. These are expected to be mostly used
113     // for enums or constant indices and so no arithmetic operations are
114     // expected on these.
115     os << formatv("IntegerAttr::get(Type::getInteger(64, context), {0})",
116                   value->getValue()->getAsString());
117     break;
118   case RecTy::StringRecTyKind:
119     os << formatv("StringAttr::get({0}, context)",
120                   value->getValue()->getAsString());
121     break;
122   default:
123     PrintFatalError(pattern->getLoc(),
124                     Twine("unsupported/unimplemented value type for ") +
125                         value->getName());
126   }
127 }
128 
129 void Pattern::collectBoundArguments(DagInit *tree) {
130   ++numberOfOpsMatched;
131   // TODO(jpienaar): Expand to multiple matches.
132   for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
133     auto arg = tree->getArg(i);
134     if (auto argTree = dyn_cast<DagInit>(arg)) {
135       collectBoundArguments(argTree);
136       continue;
137     }
138     auto name = tree->getArgNameStr(i);
139     if (name.empty())
140       continue;
141     boundArguments.try_emplace(name, arg);
142   }
143 }
144 
145 // Helper function to match patterns.
146 static void matchOp(DagInit *tree, int depth, raw_ostream &os) {
147   Operator op(cast<DefInit>(tree->getOperator())->getDef());
148   int indent = 4 + 2 * depth;
149   // Skip the operand matching at depth 0 as the pattern rewriter already does.
150   if (depth != 0) {
151     // Skip if there is no defining instruction (e.g., arguments to function).
152     os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth);
153     // TODO(jpienaar): This is bad, we should not be checking strings here, we
154     // should be matching using mOp (and helpers). Currently doing this to allow
155     // for TF ops that aren't registed. Fix it.
156     os.indent(indent) << formatv(
157                              "if (op{0}->getName().getStringRef() != \"{1}\")",
158                              depth, op.getOperationName())
159                       << "\n";
160     os.indent(indent + 2) << "return matchFailure();\n";
161   }
162   for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
163     auto arg = tree->getArg(i);
164     if (auto argTree = dyn_cast<DagInit>(arg)) {
165       os.indent(indent) << "{\n";
166       os.indent(indent + 2) << formatv(
167           "auto op{0} = op{1}->getOperand({2})->getDefiningInst();\n",
168           depth + 1, depth, i);
169       matchOp(argTree, depth + 1, os);
170       os.indent(indent) << "}\n";
171       continue;
172     }
173     auto name = tree->getArgNameStr(i);
174     if (name.empty())
175       continue;
176     os.indent(indent) << "state->" << name << " = op" << depth
177                       << "->getOperand(" << i << ");\n";
178   }
179 }
180 
181 void Pattern::emitMatcher(DagInit *tree) {
182   // Emit the heading.
183   os << R"(
184   PatternMatchResult match(OperationInst *op0) const override {
185     // TODO: This just handle 1 result
186     if (op0->getNumResults() != 1) return matchFailure();
187     auto state = std::make_unique<MatchedState>();)"
188      << "\n";
189   matchOp(tree, 0, os);
190   os.indent(4) << "return matchSuccess(std::move(state));\n  }\n";
191 }
192 
193 void Pattern::emit(StringRef rewriteName) {
194   DagInit *tree = pattern->getValueAsDag("PatternToMatch");
195   // Collect bound arguments and compute number of ops matched.
196   // TODO(jpienaar): the benefit metric is simply number of ops matched at the
197   // moment, revise.
198   collectBoundArguments(tree);
199 
200   // Emit RewritePattern for Pattern.
201   DefInit *root = cast<DefInit>(tree->getOperator());
202   auto *rootName = cast<StringInit>(root->getDef()->getValueInit("opName"));
203   os << formatv(R"(struct {0} : public RewritePattern {
204   {0}(MLIRContext *context) : RewritePattern({1}, {2}, context) {{})",
205                 rewriteName, rootName->getAsString(), numberOfOpsMatched)
206      << "\n";
207 
208   // Emit matched state.
209   os << "  struct MatchedState : public PatternState {\n";
210   for (auto &arg : boundArguments) {
211     if (arg.second.isAttr()) {
212       DefInit *defInit = cast<DefInit>(arg.second.init);
213       os.indent(4) << defInit->getDef()->getValueAsString("storageType").trim()
214                    << " " << arg.first() << ";\n";
215     } else {
216       os.indent(4) << "Value* " << arg.first() << ";\n";
217     }
218   }
219   os << "  };\n";
220 
221   emitMatcher(tree);
222   ListInit *resultOps = pattern->getValueAsListInit("ResultOps");
223   if (resultOps->size() != 1)
224     PrintFatalError("only single result rules supported");
225   DagInit *resultTree = cast<DagInit>(resultOps->getElement(0));
226 
227   // TODO(jpienaar): Expand to multiple results.
228   for (auto result : resultTree->getArgs()) {
229     if (isa<DagInit>(result))
230       PrintFatalError(pattern->getLoc(), "only single op result supported");
231   }
232 
233   DefInit *resultRoot = cast<DefInit>(resultTree->getOperator());
234   Operator resultOp(*resultRoot->getDef());
235   auto resultOperands = resultRoot->getDef()->getValueAsDag("arguments");
236 
237   os << formatv(R"(
238   void rewrite(OperationInst *op, std::unique_ptr<PatternState> state,
239                PatternRewriter &rewriter) const override {
240     auto* context = op->getContext(); (void)context;
241     auto& s = *static_cast<MatchedState *>(state.get());
242     rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())",
243                 resultOp.cppClassName());
244   if (resultOperands->getNumArgs() != resultTree->getNumArgs()) {
245     PrintFatalError(pattern->getLoc(),
246                     Twine("mismatch between arguments of resultant op (") +
247                         Twine(resultOperands->getNumArgs()) +
248                         ") and arguments provided for rewrite (" +
249                         Twine(resultTree->getNumArgs()) + Twine(')'));
250   }
251 
252   // Create the builder call for the result.
253   // Add operands.
254   int i = 0;
255   for (auto operand : resultOp.getOperands()) {
256     // Start each operand on its own line.
257     (os << ",\n").indent(6);
258 
259     auto name = resultTree->getArgNameStr(i);
260     if (boundArguments.find(name) == boundArguments.end())
261       PrintFatalError(pattern->getLoc(),
262                       Twine("referencing unbound variable '") + name + "'");
263     if (operand.name)
264       os << "/*" << operand.name->getAsUnquotedString() << "=*/";
265     os << "s." << name;
266     // TODO(jpienaar): verify types
267     ++i;
268   }
269 
270   // Add attributes.
271   for (int e = resultTree->getNumArgs(); i != e; ++i) {
272     // Start each attribute on its own line.
273     (os << ",\n").indent(6);
274 
275     // The argument in the result DAG pattern.
276     std::string name = resultTree->getArgNameStr(i);
277     auto defInit = dyn_cast<DefInit>(resultTree->getArg(i));
278     auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr;
279     if (!value)
280       PrintFatalError(pattern->getLoc(),
281                       Twine("attribute '") + name +
282                           "' needs to be constant initialized");
283 
284     // TODO(jpienaar): Refactor out into map to avoid recomputing these.
285     auto argument = resultOp.getArg(i);
286     if (!argument.is<mlir::Operator::Attribute *>())
287       PrintFatalError(pattern->getLoc(),
288                       Twine("expected attribute ") + Twine(i));
289 
290     if (!name.empty())
291       os << "/*" << name << "=*/";
292     emitAttributeValue(defInit->getDef());
293     // TODO(jpienaar): verify types
294   }
295   os << "\n    );\n  }\n};\n";
296 }
297 
298 void Pattern::emit(StringRef rewriteName, Record *p, raw_ostream &os) {
299   Pattern pattern(p, os);
300   pattern.emit(rewriteName);
301 }
302 
303 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
304   emitSourceFileHeader("Rewriters", os);
305   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
306 
307   // Ensure unique patterns simply by appending unique suffix.
308   std::string baseRewriteName = "GeneratedConvert";
309   int rewritePatternCount = 0;
310   for (Record *p : patterns) {
311     Pattern::emit(baseRewriteName + llvm::utostr(rewritePatternCount++), p, os);
312   }
313 
314   // Emit function to add the generated matchers to the pattern list.
315   os << "void populateWithGenerated(MLIRContext *context, "
316      << "OwningRewritePatternList *patterns) {\n";
317   for (unsigned i = 0; i != rewritePatternCount; ++i) {
318     os.indent(2) << "patterns->push_back(std::make_unique<" << baseRewriteName
319                  << i << ">(context));\n";
320   }
321   os << "}\n";
322 }
323 
324 mlir::GenRegistration
325     genRewriters("gen-rewriters", "Generate pattern rewriters",
326                  [](const RecordKeeper &records, raw_ostream &os) {
327                    emitRewriters(records, os);
328                    return false;
329                  });
330