xref: /llvm-project/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp (revision f4d758634305304c0deb49a4ed3f99180a2488ea)
//===- BytecodeDialectGen.cpp - Dialect bytecode reader/writer gen --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Support/IndentedOstream.h"
10 #include "mlir/TableGen/GenInfo.h"
11 #include "llvm/ADT/MapVector.h"
12 #include "llvm/ADT/STLExtras.h"
13 #include "llvm/ADT/SmallVectorExtras.h"
14 #include "llvm/Support/CommandLine.h"
15 #include "llvm/Support/FormatVariadic.h"
16 #include "llvm/TableGen/Error.h"
17 #include "llvm/TableGen/Record.h"
18 #include <regex>
19 
using namespace llvm;

// Command-line option category grouping the -gen-bytecode specific options.
static cl::OptionCategory dialectGenCat("Options for -gen-bytecode");
// Name of the dialect to generate readers/writers for. When empty, the input
// must describe attributes/types for exactly one dialect (emitBCRW rejects
// anything else).
// NOTE(review): cl::CommaSeparated is normally paired with cl::list; on this
// single-valued cl::opt a value such as "a,b" would presumably be split into
// separate occurrences rather than selecting several dialects — confirm this
// modifier is intended.
static cl::opt<std::string>
    selectedBcDialect("bytecode-dialect", cl::desc("The dialect to gen for"),
                      cl::cat(dialectGenCat), cl::CommaSeparated);
26 
namespace {

/// Helper class that emits the C++ bytecode reader and writer helpers for a
/// dialect's attributes and types to `output`.
class Generator {
public:
  Generator(raw_ostream &output) : output(output) {}

  /// Emits `read<Name>`, the reader for attribute/type record `x`. `kind` is
  /// "attribute" or "type". Records named "ReservedOrDead" are skipped.
  void emitParse(StringRef kind, const Record &x);

  /// Emits the `write` function shared by the records in `vec`, all of which
  /// map to the C++ type `type`. Each entry pairs a record with its bytecode
  /// code (its index in the dialect's element list).
  void emitPrint(StringRef kind, StringRef type,
                 ArrayRef<std::pair<int64_t, const Record *>> vec);

  /// Emits `read<Kind>`, which reads a bytecode code and dispatches to the
  /// reader of the record at that index in `vec`.
  void emitParseDispatch(StringRef kind, ArrayRef<const Record *> vec);

  /// Emits `write<Kind>`, which dispatches on the C++ types listed in `vec`
  /// to the matching `write` function.
  void emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec);

private:
  /// Emits the body of a reader: member declarations, list-reading helper
  /// lambdas, and the read-then-build logic using `builder`. `failure` is the
  /// expression returned when reading fails.
  void emitParseHelper(StringRef kind, StringRef returnType, StringRef builder,
                       ArrayRef<const Init *> args,
                       ArrayRef<std::string> argNames, StringRef failure,
                       mlir::raw_indented_ostream &ios);

  /// Emits the statements writing member `name` of `parent`, as described by
  /// `memberRec`.
  void emitPrintHelper(const Record *memberRec, StringRef kind,
                       StringRef parent, StringRef name,
                       mlir::raw_indented_ostream &ios);

  /// Stream receiving all generated code.
  raw_ostream &output;
};
} // namespace
62 
63 /// Helper to replace set of from strings to target in `s`.
64 /// Assumed: non-overlapping replacements.
65 static std::string format(StringRef templ,
66                           std::map<std::string, std::string> &&map) {
67   std::string s = templ.str();
68   for (const auto &[from, to] : map)
69     // All replacements start with $, don't treat as anchor.
70     s = std::regex_replace(s, std::regex("\\" + from), to);
71   return s;
72 }
73 
74 /// Return string with first character capitalized.
75 static std::string capitalize(StringRef str) {
76   return ((Twine)toUpper(str[0]) + str.drop_front()).str();
77 }
78 
79 /// Return the C++ type for the given record.
80 static std::string getCType(const Record *def) {
81   std::string format = "{0}";
82   if (def->isSubClassOf("Array")) {
83     def = def->getValueAsDef("elemT");
84     format = "SmallVector<{0}>";
85   }
86 
87   StringRef cType = def->getValueAsString("cType");
88   if (cType.empty()) {
89     if (def->isAnonymous())
90       PrintFatalError(def->getLoc(), "Unable to determine cType");
91 
92     return formatv(format.c_str(), def->getName().str());
93   }
94   return formatv(format.c_str(), cType.str());
95 }
96 
97 void Generator::emitParseDispatch(StringRef kind,
98                                   ArrayRef<const Record *> vec) {
99   mlir::raw_indented_ostream os(output);
100   char const *head =
101       R"(static {0} read{0}(MLIRContext* context, DialectBytecodeReader &reader))";
102   os << formatv(head, capitalize(kind));
103   auto funScope = os.scope(" {\n", "}\n\n");
104 
105   if (vec.empty()) {
106     os << "return reader.emitError() << \"unknown attribute\", "
107        << capitalize(kind) << "();\n";
108     return;
109   }
110 
111   os << "uint64_t kind;\n";
112   os << "if (failed(reader.readVarInt(kind)))\n"
113      << "  return " << capitalize(kind) << "();\n";
114   os << "switch (kind) ";
115   {
116     auto switchScope = os.scope("{\n", "}\n");
117     for (const auto &it : llvm::enumerate(vec)) {
118       if (it.value()->getName() == "ReservedOrDead")
119         continue;
120 
121       os << formatv("case {1}:\n  return read{0}(context, reader);\n",
122                     it.value()->getName(), it.index());
123     }
124     os << "default:\n"
125        << "  reader.emitError() << \"unknown attribute code: \" "
126        << "<< kind;\n"
127        << "  return " << capitalize(kind) << "();\n";
128   }
129   os << "return " << capitalize(kind) << "();\n";
130 }
131 
132 void Generator::emitParse(StringRef kind, const Record &x) {
133   if (x.getNameInitAsString() == "ReservedOrDead")
134     return;
135 
136   char const *head =
137       R"(static {0} read{1}(MLIRContext* context, DialectBytecodeReader &reader) )";
138   mlir::raw_indented_ostream os(output);
139   std::string returnType = getCType(&x);
140   os << formatv(head,
141                 kind == "attribute" ? "::mlir::Attribute" : "::mlir::Type",
142                 x.getName());
143   const DagInit *members = x.getValueAsDag("members");
144   SmallVector<std::string> argNames = llvm::to_vector(
145       map_range(members->getArgNames(), [](const StringInit *init) {
146         return init->getAsUnquotedString();
147       }));
148   StringRef builder = x.getValueAsString("cBuilder").trim();
149   emitParseHelper(kind, returnType, builder, members->getArgs(), argNames,
150                   returnType + "()", os);
151   os << "\n\n";
152 }
153 
154 void printParseConditional(mlir::raw_indented_ostream &ios,
155                            ArrayRef<const Init *> args,
156                            ArrayRef<std::string> argNames) {
157   ios << "if ";
158   auto parenScope = ios.scope("(", ") {");
159   ios.indent();
160 
161   auto listHelperName = [](StringRef name) {
162     return formatv("read{0}", capitalize(name));
163   };
164 
165   auto parsedArgs = llvm::filter_to_vector(args, [](const Init *const attr) {
166     const Record *def = cast<DefInit>(attr)->getDef();
167     if (def->isSubClassOf("Array"))
168       return true;
169     return !def->getValueAsString("cParser").empty();
170   });
171 
172   interleave(
173       zip(parsedArgs, argNames),
174       [&](std::tuple<const Init *&, const std::string &> it) {
175         const Record *attr = cast<DefInit>(std::get<0>(it))->getDef();
176         std::string parser;
177         if (auto optParser = attr->getValueAsOptionalString("cParser")) {
178           parser = *optParser;
179         } else if (attr->isSubClassOf("Array")) {
180           const Record *def = attr->getValueAsDef("elemT");
181           bool composite = def->isSubClassOf("CompositeBytecode");
182           if (!composite && def->isSubClassOf("AttributeKind"))
183             parser = "succeeded($_reader.readAttributes($_var))";
184           else if (!composite && def->isSubClassOf("TypeKind"))
185             parser = "succeeded($_reader.readTypes($_var))";
186           else
187             parser = ("succeeded($_reader.readList($_var, " +
188                       listHelperName(std::get<1>(it)) + "))")
189                          .str();
190         } else {
191           PrintFatalError(attr->getLoc(), "No parser specified");
192         }
193         std::string type = getCType(attr);
194         ios << format(parser, {{"$_reader", "reader"},
195                                {"$_resultType", type},
196                                {"$_var", std::get<1>(it)}});
197       },
198       [&]() { ios << " &&\n"; });
199 }
200 
/// Emits the brace-enclosed body of a generated reader for a record with
/// members `args` (variable names `argNames`).
///
/// Emission order:
///   1. one local declaration per member, with runs of consecutive members of
///      the same C++ type merged into a single comma-separated declaration;
///   2. a `read<Name>` lambda for every array member whose elements are not
///      plain attributes/types (generated by recursing into this function);
///   3. the `if (...) {` that reads all members (printParseConditional);
///   4. inside it, `return <builder>;` with `$_resultType` -> `returnType` and
///      `$_args` -> the member names that do not start with '_';
///   5. after it, `return <failure>;`.
/// A record without members instead emits `return get<returnType>(context);`.
void Generator::emitParseHelper(StringRef kind, StringRef returnType,
                                StringRef builder, ArrayRef<const Init *> args,
                                ArrayRef<std::string> argNames,
                                StringRef failure,
                                mlir::raw_indented_ostream &ios) {
  // Closing "}" without newline so callers can append e.g. ";" for lambdas.
  auto funScope = ios.scope("{\n", "}");

  if (args.empty()) {
    ios << formatv("return get<{0}>(context);\n", returnType);
    return;
  }

  // Print decls.
  std::string lastCType = "";
  for (auto [arg, name] : zip(args, argNames)) {
    const DefInit *first = dyn_cast<DefInit>(arg);
    if (!first)
      PrintFatalError("Unexpected type for " + name);
    const Record *def = first->getDef();

    // Create variable decls, if there are a block of same type then create
    // comma separated list of them.
    std::string cType = getCType(def);
    if (lastCType == cType) {
      ios << ", ";
    } else {
      if (!lastCType.empty())
        ios << ";\n";
      ios << cType << " ";
    }
    ios << name;
    lastCType = cType;
  }
  ios << ";\n";

  // Returns the name of the helper used in list parsing. E.g., the name of the
  // lambda passed to array parsing.
  auto listHelperName = [](StringRef name) {
    return formatv("read{0}", capitalize(name));
  };

  // Emit list helper functions.
  for (auto [arg, name] : zip(args, argNames)) {
    const Record *attr = cast<DefInit>(arg)->getDef();
    if (!attr->isSubClassOf("Array"))
      continue;

    // TODO: Dedupe readers.
    const Record *def = attr->getValueAsDef("elemT");
    // Plain attribute/type arrays are read with readAttributes/readTypes and
    // need no helper.
    if (!def->isSubClassOf("CompositeBytecode") &&
        (def->isSubClassOf("AttributeKind") || def->isSubClassOf("TypeKind")))
      continue;

    std::string returnType = getCType(def);
    ios << "auto " << listHelperName(name) << " = [&]() -> FailureOr<"
        << returnType << "> ";
    // These intentionally shadow the parameters: they describe one list
    // element, which is read by the recursive call below.
    SmallVector<const Init *> args;
    SmallVector<std::string> argNames;
    if (def->isSubClassOf("CompositeBytecode")) {
      const DagInit *members = def->getValueAsDag("members");
      args = llvm::to_vector(members->getArgs());
      // NOTE(review): an unnamed DAG member yields a null name here, which
      // would be dereferenced — confirm composite members are always named.
      argNames = llvm::to_vector(
          map_range(members->getArgNames(), [](const StringInit *init) {
            return init->getAsUnquotedString();
          }));
    } else {
      // A non-composite element is read as a single value named "temp".
      args = {def->getDefInit()};
      argNames = {"temp"};
    }
    StringRef builder = def->getValueAsString("cBuilder");
    emitParseHelper(kind, returnType, builder, args, argNames, "failure()",
                    ios);
    // Terminates the `auto readX = [&]() ... { ... }` lambda definition.
    ios << ";\n";
  }

  // Print parse conditional.
  printParseConditional(ios, args, argNames);

  // Compute args to pass to create method.
  auto passedArgs = llvm::filter_to_vector(
      argNames, [](StringRef str) { return !str.starts_with("_"); });
  std::string argStr;
  raw_string_ostream argStream(argStr);
  interleaveComma(passedArgs, argStream,
                  [&](const std::string &str) { argStream << str; });
  // Return the invoked constructor.
  ios << "\nreturn "
      << format(builder, {{"$_resultType", returnType.str()},
                          {"$_args", argStream.str()}})
      << ";\n";
  // Balances the extra indentation left by printParseConditional.
  ios.unindent();

  // TODO: Emit error in debug.
  // This assumes the result types in error case can always be empty
  // constructed.
  ios << "}\nreturn " << failure << ";\n";
}
298 
299 void Generator::emitPrint(StringRef kind, StringRef type,
300                           ArrayRef<std::pair<int64_t, const Record *>> vec) {
301   if (type == "ReservedOrDead")
302     return;
303 
304   char const *head =
305       R"(static void write({0} {1}, DialectBytecodeWriter &writer) )";
306   mlir::raw_indented_ostream os(output);
307   os << formatv(head, type, kind);
308   auto funScope = os.scope("{\n", "}\n\n");
309 
310   // Check that predicates specified if multiple bytecode instances.
311   for (const Record *rec : make_second_range(vec)) {
312     StringRef pred = rec->getValueAsString("printerPredicate");
313     if (vec.size() > 1 && pred.empty()) {
314       for (auto [index, rec] : vec) {
315         (void)index;
316         StringRef pred = rec->getValueAsString("printerPredicate");
317         if (vec.size() > 1 && pred.empty())
318           PrintError(rec->getLoc(),
319                      "Requires parsing predicate given common cType");
320       }
321       PrintFatalError("Unspecified for shared cType " + type);
322     }
323   }
324 
325   for (auto [index, rec] : vec) {
326     StringRef pred = rec->getValueAsString("printerPredicate");
327     if (!pred.empty()) {
328       os << "if (" << format(pred, {{"$_val", kind.str()}}) << ") {\n";
329       os.indent();
330     }
331 
332     os << "writer.writeVarInt(/* " << rec->getName() << " */ " << index
333        << ");\n";
334 
335     auto *members = rec->getValueAsDag("members");
336     for (auto [arg, name] :
337          llvm::zip(members->getArgs(), members->getArgNames())) {
338       const DefInit *def = dyn_cast<DefInit>(arg);
339       assert(def);
340       const Record *memberRec = def->getDef();
341       emitPrintHelper(memberRec, kind, kind, name->getAsUnquotedString(), os);
342     }
343 
344     if (!pred.empty()) {
345       os.unindent();
346       os << "}\n";
347     }
348   }
349 }
350 
/// Emits the statements that write member `name` of the value `parent`,
/// described by record `memberRec`. `kind` ("attribute"/"type") is also used
/// as the lambda parameter name for list elements and substituted for
/// `$_name` in printers.
///
/// The getter expression is the member's `cGetter` (expanding `$_attrType`,
/// `$_member` and `$_getMember`) or, by default, `<parent>.get<Name>()` with
/// `name` converted from snake_case to CamelCase.
void Generator::emitPrintHelper(const Record *memberRec, StringRef kind,
                                StringRef parent, StringRef name,
                                mlir::raw_indented_ostream &ios) {
  std::string getter;
  if (auto cGetter = memberRec->getValueAsOptionalString("cGetter");
      cGetter && !cGetter->empty()) {
    getter = format(
        *cGetter,
        {{"$_attrType", parent.str()},
         {"$_member", name.str()},
         {"$_getMember", "get" + convertToCamelFromSnakeCase(name, true)}});
  } else {
    getter =
        formatv("{0}.get{1}()", parent, convertToCamelFromSnakeCase(name, true))
            .str();
  }

  if (memberRec->isSubClassOf("Array")) {
    const Record *def = memberRec->getValueAsDef("elemT");
    // Plain attribute/type arrays have dedicated writer entry points.
    if (!def->isSubClassOf("CompositeBytecode")) {
      if (def->isSubClassOf("AttributeKind")) {
        ios << "writer.writeAttributes(" << getter << ");\n";
        return;
      }
      if (def->isSubClassOf("TypeKind")) {
        ios << "writer.writeTypes(" << getter << ");\n";
        return;
      }
    }
    // Otherwise write each element through a lambda whose parameter is named
    // `kind`; the element is then printed recursively with that parameter as
    // both parent and member name.
    // NOTE(review): unless the element record defines cGetter, its getter
    // becomes `<kind>.get<Kind>()` on the lambda parameter — confirm element
    // types used in arrays provide a suitable cGetter.
    std::string returnType = getCType(def);
    std::string nestedName = kind.str();
    ios << "writer.writeList(" << getter << ", [&](" << returnType << " "
        << nestedName << ") ";
    // The scope closes the lambda ("});") when it is destroyed, i.e. after
    // the recursive call below has emitted the lambda body.
    auto lambdaScope = ios.scope("{\n", "});\n");
    return emitPrintHelper(def, kind, nestedName, nestedName, ios);
  }
  if (memberRec->isSubClassOf("CompositeBytecode")) {
    // Write each component, keeping the same parent value.
    auto *members = memberRec->getValueAsDag("members");
    for (auto [arg, argName] :
         zip(members->getArgs(), members->getArgNames())) {
      const DefInit *def = dyn_cast<DefInit>(arg);
      assert(def);
      emitPrintHelper(def->getDef(), kind, parent,
                      argName->getAsUnquotedString(), ios);
    }
    // NOTE(review): execution falls through to the cPrinter check below;
    // presumably composites leave cPrinter empty — confirm.
  }

  // Leaf member: expand its printer with `$_writer`, `$_name` and `$_getter`.
  if (std::string printer = memberRec->getValueAsString("cPrinter").str();
      !printer.empty())
    ios << format(printer, {{"$_writer", "writer"},
                            {"$_name", kind.str()},
                            {"$_getter", getter}})
        << ";\n";
}
405 
/// Emits `write<Kind>(<Kind> <kind>, DialectBytecodeWriter &writer)`, which
/// uses a TypeSwitch over the C++ types in `vec` (skipping "ReservedOrDead").
/// Each case forwards to the matching `write` overload emitted by emitPrint
/// and returns success; any other value yields failure.
void Generator::emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec) {
  mlir::raw_indented_ostream os(output);
  char const *head = R"(static LogicalResult write{0}({0} {1},
                                DialectBytecodeWriter &writer))";
  os << formatv(head, capitalize(kind), kind);
  auto funScope = os.scope(" {\n", "}\n\n");

  os << "return TypeSwitch<" << capitalize(kind) << ", LogicalResult>(" << kind
     << ")";
  // Empty delimiters: this scope only indents the chained `.Case` calls.
  auto switchScope = os.scope("", "");
  for (StringRef type : vec) {
    if (type == "ReservedOrDead")
      continue;

    os << "\n.Case([&](" << type << " t)";
    auto caseScope = os.scope(" {\n", "})");
    os << "return write(t, writer), success();\n";
  }
  os << "\n.Default([&](" << capitalize(kind) << ") { return failure(); });\n";
}
426 
427 namespace {
428 /// Container of Attribute or Type for Dialect.
429 struct AttrOrType {
430   std::vector<const Record *> attr, type;
431 };
432 } // namespace
433 
434 static bool emitBCRW(const RecordKeeper &records, raw_ostream &os) {
435   MapVector<StringRef, AttrOrType> dialectAttrOrType;
436   for (const Record *it :
437        records.getAllDerivedDefinitions("DialectAttributes")) {
438     if (!selectedBcDialect.empty() &&
439         it->getValueAsString("dialect") != selectedBcDialect)
440       continue;
441     dialectAttrOrType[it->getValueAsString("dialect")].attr =
442         it->getValueAsListOfDefs("elems");
443   }
444   for (const Record *it : records.getAllDerivedDefinitions("DialectTypes")) {
445     if (!selectedBcDialect.empty() &&
446         it->getValueAsString("dialect") != selectedBcDialect)
447       continue;
448     dialectAttrOrType[it->getValueAsString("dialect")].type =
449         it->getValueAsListOfDefs("elems");
450   }
451 
452   if (dialectAttrOrType.size() != 1)
453     PrintFatalError("Single dialect per invocation required (either only "
454                     "one in input file or specified via dialect option)");
455 
456   auto it = dialectAttrOrType.front();
457   Generator gen(os);
458 
459   SmallVector<std::vector<const Record *> *, 2> vecs;
460   SmallVector<std::string, 2> kinds;
461   vecs.push_back(&it.second.attr);
462   kinds.push_back("attribute");
463   vecs.push_back(&it.second.type);
464   kinds.push_back("type");
465   for (auto [vec, kind] : zip(vecs, kinds)) {
466     // Handle Attribute/Type emission.
467     std::map<std::string, std::vector<std::pair<int64_t, const Record *>>>
468         perType;
469     for (auto kt : llvm::enumerate(*vec))
470       perType[getCType(kt.value())].emplace_back(kt.index(), kt.value());
471     for (const auto &jt : perType) {
472       for (auto kt : jt.second)
473         gen.emitParse(kind, *std::get<1>(kt));
474       gen.emitPrint(kind, jt.first, jt.second);
475     }
476     gen.emitParseDispatch(kind, *vec);
477 
478     SmallVector<std::string> types;
479     for (const auto &it : perType) {
480       types.push_back(it.first);
481     }
482     gen.emitPrintDispatch(kind, types);
483   }
484 
485   return false;
486 }
487 
488 static mlir::GenRegistration
489     genBCRW("gen-bytecode", "Generate dialect bytecode readers/writers",
490             [](const RecordKeeper &records, raw_ostream &os) {
491               return emitBCRW(records, os);
492             });
493