xref: /llvm-project/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp (revision 1ae9fe5ea0c502195f0c857759e2277ba9c8b338)
1 //===- BytecodeDialectGen.cpp - Dialect bytecode read/writer gen  ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Support/IndentedOstream.h"
10 #include "mlir/TableGen/GenInfo.h"
11 #include "llvm/ADT/MapVector.h"
12 #include "llvm/ADT/STLExtras.h"
13 #include "llvm/Support/CommandLine.h"
14 #include "llvm/Support/FormatVariadic.h"
15 #include "llvm/TableGen/Error.h"
16 #include "llvm/TableGen/Record.h"
17 #include <regex>
18 
19 using namespace llvm;
20 
// Command-line options of the bytecode generator.
//
// `bytecode-dialect` names the dialect to generate for. When it is left empty
// emitBCRW applies no dialect filtering.
// NOTE(review): cl::CommaSeparated is set on a single-valued cl::opt here;
// confirm that this modifier is intended for a non-list option.
static cl::OptionCategory dialectGenCat("Options for -gen-bytecode");
static cl::opt<std::string>
    selectedBcDialect("bytecode-dialect", cl::desc("The dialect to gen for"),
                      cl::cat(dialectGenCat), cl::CommaSeparated);
25 
26 namespace {
27 
/// Helper class to generate C++ bytecode parser helpers.
///
/// Every emit* method writes generated C++ source text to the stream handed to
/// the constructor; none of them returns a status.
class Generator {
public:
  Generator(raw_ostream &output) : output(output) {}

  /// Emits the `read<Name>` function for the attribute/type record `x`.
  /// Nothing is emitted for a record named `ReservedOrDead`.
  void emitParse(StringRef kind, const Record &x);

  /// Emits the `write` function for the C++ type `type`. `vec` holds the
  /// (bytecode code, record) pairs that share this C++ type. Nothing is
  /// emitted when `type` is `ReservedOrDead`.
  void emitPrint(StringRef kind, StringRef type,
                 ArrayRef<std::pair<int64_t, const Record *>> vec);

  /// Emits parse dispatch table.
  void emitParseDispatch(StringRef kind, ArrayRef<const Record *> vec);

  /// Emits print dispatch table.
  void emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec);

private:
  /// Emits parse calls to construct given kind.
  /// `args` and `argNames` are parallel lists describing the members to
  /// parse, `builder` is the template of the construction expression and
  /// `failure` is the expression the generated code returns when parsing
  /// fails.
  void emitParseHelper(StringRef kind, StringRef returnType, StringRef builder,
                       ArrayRef<const Init *> args,
                       ArrayRef<std::string> argNames, StringRef failure,
                       mlir::raw_indented_ostream &ios);

  /// Emits print instructions.
  /// Writes the statements serializing member `name` (described by
  /// `memberRec`) of the value named `parent`.
  void emitPrintHelper(const Record *memberRec, StringRef kind,
                       StringRef parent, StringRef name,
                       mlir::raw_indented_ostream &ios);

  /// Destination of all generated code.
  raw_ostream &output;
};
60 } // namespace
61 
62 /// Helper to replace set of from strings to target in `s`.
63 /// Assumed: non-overlapping replacements.
64 static std::string format(StringRef templ,
65                           std::map<std::string, std::string> &&map) {
66   std::string s = templ.str();
67   for (const auto &[from, to] : map)
68     // All replacements start with $, don't treat as anchor.
69     s = std::regex_replace(s, std::regex("\\" + from), to);
70   return s;
71 }
72 
73 /// Return string with first character capitalized.
74 static std::string capitalize(StringRef str) {
75   return ((Twine)toUpper(str[0]) + str.drop_front()).str();
76 }
77 
78 /// Return the C++ type for the given record.
79 static std::string getCType(const Record *def) {
80   std::string format = "{0}";
81   if (def->isSubClassOf("Array")) {
82     def = def->getValueAsDef("elemT");
83     format = "SmallVector<{0}>";
84   }
85 
86   StringRef cType = def->getValueAsString("cType");
87   if (cType.empty()) {
88     if (def->isAnonymous())
89       PrintFatalError(def->getLoc(), "Unable to determine cType");
90 
91     return formatv(format.c_str(), def->getName().str());
92   }
93   return formatv(format.c_str(), cType.str());
94 }
95 
96 void Generator::emitParseDispatch(StringRef kind,
97                                   ArrayRef<const Record *> vec) {
98   mlir::raw_indented_ostream os(output);
99   char const *head =
100       R"(static {0} read{0}(MLIRContext* context, DialectBytecodeReader &reader))";
101   os << formatv(head, capitalize(kind));
102   auto funScope = os.scope(" {\n", "}\n\n");
103 
104   if (vec.empty()) {
105     os << "return reader.emitError() << \"unknown attribute\", "
106        << capitalize(kind) << "();\n";
107     return;
108   }
109 
110   os << "uint64_t kind;\n";
111   os << "if (failed(reader.readVarInt(kind)))\n"
112      << "  return " << capitalize(kind) << "();\n";
113   os << "switch (kind) ";
114   {
115     auto switchScope = os.scope("{\n", "}\n");
116     for (const auto &it : llvm::enumerate(vec)) {
117       if (it.value()->getName() == "ReservedOrDead")
118         continue;
119 
120       os << formatv("case {1}:\n  return read{0}(context, reader);\n",
121                     it.value()->getName(), it.index());
122     }
123     os << "default:\n"
124        << "  reader.emitError() << \"unknown attribute code: \" "
125        << "<< kind;\n"
126        << "  return " << capitalize(kind) << "();\n";
127   }
128   os << "return " << capitalize(kind) << "();\n";
129 }
130 
131 void Generator::emitParse(StringRef kind, const Record &x) {
132   if (x.getNameInitAsString() == "ReservedOrDead")
133     return;
134 
135   char const *head =
136       R"(static {0} read{1}(MLIRContext* context, DialectBytecodeReader &reader) )";
137   mlir::raw_indented_ostream os(output);
138   std::string returnType = getCType(&x);
139   os << formatv(head,
140                 kind == "attribute" ? "::mlir::Attribute" : "::mlir::Type",
141                 x.getName());
142   const DagInit *members = x.getValueAsDag("members");
143   SmallVector<std::string> argNames = llvm::to_vector(
144       map_range(members->getArgNames(), [](const StringInit *init) {
145         return init->getAsUnquotedString();
146       }));
147   StringRef builder = x.getValueAsString("cBuilder").trim();
148   emitParseHelper(kind, returnType, builder, members->getArgs(), argNames,
149                   returnType + "()", os);
150   os << "\n\n";
151 }
152 
153 void printParseConditional(mlir::raw_indented_ostream &ios,
154                            ArrayRef<const Init *> args,
155                            ArrayRef<std::string> argNames) {
156   ios << "if ";
157   auto parenScope = ios.scope("(", ") {");
158   ios.indent();
159 
160   auto listHelperName = [](StringRef name) {
161     return formatv("read{0}", capitalize(name));
162   };
163 
164   auto parsedArgs =
165       llvm::to_vector(make_filter_range(args, [](const Init *const attr) {
166         const Record *def = cast<DefInit>(attr)->getDef();
167         if (def->isSubClassOf("Array"))
168           return true;
169         return !def->getValueAsString("cParser").empty();
170       }));
171 
172   interleave(
173       zip(parsedArgs, argNames),
174       [&](std::tuple<const Init *&, const std::string &> it) {
175         const Record *attr = cast<DefInit>(std::get<0>(it))->getDef();
176         std::string parser;
177         if (auto optParser = attr->getValueAsOptionalString("cParser")) {
178           parser = *optParser;
179         } else if (attr->isSubClassOf("Array")) {
180           const Record *def = attr->getValueAsDef("elemT");
181           bool composite = def->isSubClassOf("CompositeBytecode");
182           if (!composite && def->isSubClassOf("AttributeKind"))
183             parser = "succeeded($_reader.readAttributes($_var))";
184           else if (!composite && def->isSubClassOf("TypeKind"))
185             parser = "succeeded($_reader.readTypes($_var))";
186           else
187             parser = ("succeeded($_reader.readList($_var, " +
188                       listHelperName(std::get<1>(it)) + "))")
189                          .str();
190         } else {
191           PrintFatalError(attr->getLoc(), "No parser specified");
192         }
193         std::string type = getCType(attr);
194         ios << format(parser, {{"$_reader", "reader"},
195                                {"$_resultType", type},
196                                {"$_var", std::get<1>(it)}});
197       },
198       [&]() { ios << " &&\n"; });
199 }
200 
/// Emits the brace-enclosed body of a generated reader.
///
/// With no members the body is a single `return get<returnType>(context);`.
/// Otherwise it consists of, in order: declarations of one variable per
/// member, a reader lambda for every list member that needs one, an `if` that
/// runs all member parsers, a `return` of the expanded `builder` inside that
/// `if`, and a final `return` of `failure`.
///
///   kind       - passed on unchanged to the recursive calls.
///   returnType - substituted for `$_resultType` in `builder` and used in the
///                `get<...>` return of the member-less case.
///   builder    - template of the construction expression; `$_resultType` and
///                `$_args` are substituted.
///   args       - member definitions; each must be a DefInit.
///   argNames   - member names, parallel to `args`.
///   failure    - expression returned by the generated code if parsing fails.
///   ios        - stream the body is written to.
void Generator::emitParseHelper(StringRef kind, StringRef returnType,
                                StringRef builder, ArrayRef<const Init *> args,
                                ArrayRef<std::string> argNames,
                                StringRef failure,
                                mlir::raw_indented_ostream &ios) {
  auto funScope = ios.scope("{\n", "}");

  if (args.empty()) {
    ios << formatv("return get<{0}>(context);\n", returnType);
    return;
  }

  // Print decls.
  std::string lastCType = "";
  for (auto [arg, name] : zip(args, argNames)) {
    const DefInit *first = dyn_cast<DefInit>(arg);
    if (!first)
      PrintFatalError("Unexpected type for " + name);
    const Record *def = first->getDef();

    // Create variable decls, if there are a block of same type then create
    // comma separated list of them.
    std::string cType = getCType(def);
    if (lastCType == cType) {
      ios << ", ";
    } else {
      if (!lastCType.empty())
        ios << ";\n";
      ios << cType << " ";
    }
    ios << name;
    lastCType = cType;
  }
  ios << ";\n";

  // Returns the name of the helper used in list parsing. E.g., the name of the
  // lambda passed to array parsing.
  auto listHelperName = [](StringRef name) {
    return formatv("read{0}", capitalize(name));
  };

  // Emit list helper functions.
  for (auto [arg, name] : zip(args, argNames)) {
    const Record *attr = cast<DefInit>(arg)->getDef();
    if (!attr->isSubClassOf("Array"))
      continue;

    // TODO: Dedupe readers.
    // Lists of non-composite attributes/types need no element lambda.
    const Record *def = attr->getValueAsDef("elemT");
    if (!def->isSubClassOf("CompositeBytecode") &&
        (def->isSubClassOf("AttributeKind") || def->isSubClassOf("TypeKind")))
      continue;

    // The names below intentionally describe the list element and shadow the
    // parameters of the same name for the rest of this loop iteration.
    std::string returnType = getCType(def);
    ios << "auto " << listHelperName(name) << " = [&]() -> FailureOr<"
        << returnType << "> ";
    SmallVector<const Init *> args;
    SmallVector<std::string> argNames;
    if (def->isSubClassOf("CompositeBytecode")) {
      // A composite element is parsed member by member.
      const DagInit *members = def->getValueAsDag("members");
      args = llvm::to_vector(map_range(
          members->getArgs(), [](Init *init) { return (const Init *)init; }));
      argNames = llvm::to_vector(
          map_range(members->getArgNames(), [](const StringInit *init) {
            return init->getAsUnquotedString();
          }));
    } else {
      // Any other element is parsed as a single value named `temp`.
      args = {def->getDefInit()};
      argNames = {"temp"};
    }
    StringRef builder = def->getValueAsString("cBuilder");
    // The lambda body has the same shape as a reader body; it returns
    // `failure()` when parsing fails.
    emitParseHelper(kind, returnType, builder, args, argNames, "failure()",
                    ios);
    ios << ";\n";
  }

  // Print parse conditional.
  printParseConditional(ios, args, argNames);

  // Compute args to pass to create method.
  // Members whose name starts with `_` are parsed but not passed on.
  auto passedArgs = llvm::to_vector(make_filter_range(
      argNames, [](StringRef str) { return !str.starts_with("_"); }));
  std::string argStr;
  raw_string_ostream argStream(argStr);
  interleaveComma(passedArgs, argStream,
                  [&](const std::string &str) { argStream << str; });
  // Return the invoked constructor.
  ios << "\nreturn "
      << format(builder, {{"$_resultType", returnType.str()},
                          {"$_args", argStream.str()}})
      << ";\n";
  // Drops the indent that printParseConditional left active.
  ios.unindent();

  // TODO: Emit error in debug.
  // This assumes the result types in error case can always be empty
  // constructed.
  ios << "}\nreturn " << failure << ";\n";
}
299 
300 void Generator::emitPrint(StringRef kind, StringRef type,
301                           ArrayRef<std::pair<int64_t, const Record *>> vec) {
302   if (type == "ReservedOrDead")
303     return;
304 
305   char const *head =
306       R"(static void write({0} {1}, DialectBytecodeWriter &writer) )";
307   mlir::raw_indented_ostream os(output);
308   os << formatv(head, type, kind);
309   auto funScope = os.scope("{\n", "}\n\n");
310 
311   // Check that predicates specified if multiple bytecode instances.
312   for (const Record *rec : make_second_range(vec)) {
313     StringRef pred = rec->getValueAsString("printerPredicate");
314     if (vec.size() > 1 && pred.empty()) {
315       for (auto [index, rec] : vec) {
316         (void)index;
317         StringRef pred = rec->getValueAsString("printerPredicate");
318         if (vec.size() > 1 && pred.empty())
319           PrintError(rec->getLoc(),
320                      "Requires parsing predicate given common cType");
321       }
322       PrintFatalError("Unspecified for shared cType " + type);
323     }
324   }
325 
326   for (auto [index, rec] : vec) {
327     StringRef pred = rec->getValueAsString("printerPredicate");
328     if (!pred.empty()) {
329       os << "if (" << format(pred, {{"$_val", kind.str()}}) << ") {\n";
330       os.indent();
331     }
332 
333     os << "writer.writeVarInt(/* " << rec->getName() << " */ " << index
334        << ");\n";
335 
336     auto *members = rec->getValueAsDag("members");
337     for (auto [arg, name] :
338          llvm::zip(members->getArgs(), members->getArgNames())) {
339       const DefInit *def = dyn_cast<DefInit>(arg);
340       assert(def);
341       const Record *memberRec = def->getDef();
342       emitPrintHelper(memberRec, kind, kind, name->getAsUnquotedString(), os);
343     }
344 
345     if (!pred.empty()) {
346       os.unindent();
347       os << "}\n";
348     }
349   }
350 }
351 
352 void Generator::emitPrintHelper(const Record *memberRec, StringRef kind,
353                                 StringRef parent, StringRef name,
354                                 mlir::raw_indented_ostream &ios) {
355   std::string getter;
356   if (auto cGetter = memberRec->getValueAsOptionalString("cGetter");
357       cGetter && !cGetter->empty()) {
358     getter = format(
359         *cGetter,
360         {{"$_attrType", parent.str()},
361          {"$_member", name.str()},
362          {"$_getMember", "get" + convertToCamelFromSnakeCase(name, true)}});
363   } else {
364     getter =
365         formatv("{0}.get{1}()", parent, convertToCamelFromSnakeCase(name, true))
366             .str();
367   }
368 
369   if (memberRec->isSubClassOf("Array")) {
370     const Record *def = memberRec->getValueAsDef("elemT");
371     if (!def->isSubClassOf("CompositeBytecode")) {
372       if (def->isSubClassOf("AttributeKind")) {
373         ios << "writer.writeAttributes(" << getter << ");\n";
374         return;
375       }
376       if (def->isSubClassOf("TypeKind")) {
377         ios << "writer.writeTypes(" << getter << ");\n";
378         return;
379       }
380     }
381     std::string returnType = getCType(def);
382     std::string nestedName = kind.str();
383     ios << "writer.writeList(" << getter << ", [&](" << returnType << " "
384         << nestedName << ") ";
385     auto lambdaScope = ios.scope("{\n", "});\n");
386     return emitPrintHelper(def, kind, nestedName, nestedName, ios);
387   }
388   if (memberRec->isSubClassOf("CompositeBytecode")) {
389     auto *members = memberRec->getValueAsDag("members");
390     for (auto [arg, argName] :
391          zip(members->getArgs(), members->getArgNames())) {
392       const DefInit *def = dyn_cast<DefInit>(arg);
393       assert(def);
394       emitPrintHelper(def->getDef(), kind, parent,
395                       argName->getAsUnquotedString(), ios);
396     }
397   }
398 
399   if (std::string printer = memberRec->getValueAsString("cPrinter").str();
400       !printer.empty())
401     ios << format(printer, {{"$_writer", "writer"},
402                             {"$_name", kind.str()},
403                             {"$_getter", getter}})
404         << ";\n";
405 }
406 
407 void Generator::emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec) {
408   mlir::raw_indented_ostream os(output);
409   char const *head = R"(static LogicalResult write{0}({0} {1},
410                                 DialectBytecodeWriter &writer))";
411   os << formatv(head, capitalize(kind), kind);
412   auto funScope = os.scope(" {\n", "}\n\n");
413 
414   os << "return TypeSwitch<" << capitalize(kind) << ", LogicalResult>(" << kind
415      << ")";
416   auto switchScope = os.scope("", "");
417   for (StringRef type : vec) {
418     if (type == "ReservedOrDead")
419       continue;
420 
421     os << "\n.Case([&](" << type << " t)";
422     auto caseScope = os.scope(" {\n", "})");
423     os << "return write(t, writer), success();\n";
424   }
425   os << "\n.Default([&](" << capitalize(kind) << ") { return failure(); });\n";
426 }
427 
428 namespace {
429 /// Container of Attribute or Type for Dialect.
430 struct AttrOrType {
431   std::vector<const Record *> attr, type;
432 };
433 } // namespace
434 
435 static bool emitBCRW(const RecordKeeper &records, raw_ostream &os) {
436   MapVector<StringRef, AttrOrType> dialectAttrOrType;
437   for (const Record *it :
438        records.getAllDerivedDefinitions("DialectAttributes")) {
439     if (!selectedBcDialect.empty() &&
440         it->getValueAsString("dialect") != selectedBcDialect)
441       continue;
442     dialectAttrOrType[it->getValueAsString("dialect")].attr =
443         it->getValueAsListOfDefs("elems");
444   }
445   for (const Record *it : records.getAllDerivedDefinitions("DialectTypes")) {
446     if (!selectedBcDialect.empty() &&
447         it->getValueAsString("dialect") != selectedBcDialect)
448       continue;
449     dialectAttrOrType[it->getValueAsString("dialect")].type =
450         it->getValueAsListOfDefs("elems");
451   }
452 
453   if (dialectAttrOrType.size() != 1)
454     PrintFatalError("Single dialect per invocation required (either only "
455                     "one in input file or specified via dialect option)");
456 
457   auto it = dialectAttrOrType.front();
458   Generator gen(os);
459 
460   SmallVector<std::vector<const Record *> *, 2> vecs;
461   SmallVector<std::string, 2> kinds;
462   vecs.push_back(&it.second.attr);
463   kinds.push_back("attribute");
464   vecs.push_back(&it.second.type);
465   kinds.push_back("type");
466   for (auto [vec, kind] : zip(vecs, kinds)) {
467     // Handle Attribute/Type emission.
468     std::map<std::string, std::vector<std::pair<int64_t, const Record *>>>
469         perType;
470     for (auto kt : llvm::enumerate(*vec))
471       perType[getCType(kt.value())].emplace_back(kt.index(), kt.value());
472     for (const auto &jt : perType) {
473       for (auto kt : jt.second)
474         gen.emitParse(kind, *std::get<1>(kt));
475       gen.emitPrint(kind, jt.first, jt.second);
476     }
477     gen.emitParseDispatch(kind, *vec);
478 
479     SmallVector<std::string> types;
480     for (const auto &it : perType) {
481       types.push_back(it.first);
482     }
483     gen.emitPrintDispatch(kind, types);
484   }
485 
486   return false;
487 }
488 
// Registers the generator under the name `gen-bytecode`; the callback simply
// forwards to emitBCRW and returns its result.
static mlir::GenRegistration
    genBCRW("gen-bytecode", "Generate dialect bytecode readers/writers",
            [](const RecordKeeper &records, raw_ostream &os) {
              return emitBCRW(records, os);
            });
494