xref: /llvm-project/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp (revision a140931be5080543372ed833aea4e8f9c96bc4b5)
1 //===- BytecodeDialectGen.cpp - Dialect bytecode read/writer gen  ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Support/IndentedOstream.h"
10 #include "mlir/TableGen/GenInfo.h"
11 #include "llvm/ADT/MapVector.h"
12 #include "llvm/ADT/STLExtras.h"
13 #include "llvm/Support/CommandLine.h"
14 #include "llvm/Support/FormatVariadic.h"
15 #include "llvm/TableGen/Error.h"
16 #include "llvm/TableGen/Record.h"
17 #include <regex>
18 
19 using namespace llvm;
20 
21 static llvm::cl::OptionCategory dialectGenCat("Options for -gen-bytecode");
22 static llvm::cl::opt<std::string>
23     selectedBcDialect("bytecode-dialect",
24                       llvm::cl::desc("The dialect to gen for"),
25                       llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated);
26 
27 namespace {
28 
29 /// Helper class to generate C++ bytecode parser helpers.
30 class Generator {
31 public:
32   Generator(raw_ostream &output) : output(output) {}
33 
34   /// Returns whether successfully emitted attribute/type parsers.
35   void emitParse(StringRef kind, const Record &x);
36 
37   /// Returns whether successfully emitted attribute/type printers.
38   void emitPrint(StringRef kind, StringRef type,
39                  ArrayRef<std::pair<int64_t, const Record *>> vec);
40 
41   /// Emits parse dispatch table.
42   void emitParseDispatch(StringRef kind, ArrayRef<const Record *> vec);
43 
44   /// Emits print dispatch table.
45   void emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec);
46 
47 private:
48   /// Emits parse calls to construct given kind.
49   void emitParseHelper(StringRef kind, StringRef returnType, StringRef builder,
50                        ArrayRef<Init *> args, ArrayRef<std::string> argNames,
51                        StringRef failure, mlir::raw_indented_ostream &ios);
52 
53   /// Emits print instructions.
54   void emitPrintHelper(const Record *memberRec, StringRef kind,
55                        StringRef parent, StringRef name,
56                        mlir::raw_indented_ostream &ios);
57 
58   raw_ostream &output;
59 };
60 } // namespace
61 
62 /// Helper to replace set of from strings to target in `s`.
63 /// Assumed: non-overlapping replacements.
64 static std::string format(StringRef templ,
65                           std::map<std::string, std::string> &&map) {
66   std::string s = templ.str();
67   for (const auto &[from, to] : map)
68     // All replacements start with $, don't treat as anchor.
69     s = std::regex_replace(s, std::regex("\\" + from), to);
70   return s;
71 }
72 
73 /// Return string with first character capitalized.
74 static std::string capitalize(StringRef str) {
75   return ((Twine)toUpper(str[0]) + str.drop_front()).str();
76 }
77 
78 /// Return the C++ type for the given record.
79 static std::string getCType(const Record *def) {
80   std::string format = "{0}";
81   if (def->isSubClassOf("Array")) {
82     def = def->getValueAsDef("elemT");
83     format = "SmallVector<{0}>";
84   }
85 
86   StringRef cType = def->getValueAsString("cType");
87   if (cType.empty()) {
88     if (def->isAnonymous())
89       PrintFatalError(def->getLoc(), "Unable to determine cType");
90 
91     return formatv(format.c_str(), def->getName().str());
92   }
93   return formatv(format.c_str(), cType.str());
94 }
95 
96 void Generator::emitParseDispatch(StringRef kind,
97                                   ArrayRef<const Record *> vec) {
98   mlir::raw_indented_ostream os(output);
99   char const *head =
100       R"(static {0} read{0}(MLIRContext* context, DialectBytecodeReader &reader))";
101   os << formatv(head, capitalize(kind));
102   auto funScope = os.scope(" {\n", "}\n\n");
103 
104   if (vec.empty()) {
105     os << "return reader.emitError() << \"unknown attribute\", "
106        << capitalize(kind) << "();\n";
107     return;
108   }
109 
110   os << "uint64_t kind;\n";
111   os << "if (failed(reader.readVarInt(kind)))\n"
112      << "  return " << capitalize(kind) << "();\n";
113   os << "switch (kind) ";
114   {
115     auto switchScope = os.scope("{\n", "}\n");
116     for (const auto &it : llvm::enumerate(vec)) {
117       if (it.value()->getName() == "ReservedOrDead")
118         continue;
119 
120       os << formatv("case {1}:\n  return read{0}(context, reader);\n",
121                     it.value()->getName(), it.index());
122     }
123     os << "default:\n"
124        << "  reader.emitError() << \"unknown attribute code: \" "
125        << "<< kind;\n"
126        << "  return " << capitalize(kind) << "();\n";
127   }
128   os << "return " << capitalize(kind) << "();\n";
129 }
130 
131 void Generator::emitParse(StringRef kind, const Record &x) {
132   if (x.getNameInitAsString() == "ReservedOrDead")
133     return;
134 
135   char const *head =
136       R"(static {0} read{1}(MLIRContext* context, DialectBytecodeReader &reader) )";
137   mlir::raw_indented_ostream os(output);
138   std::string returnType = getCType(&x);
139   os << formatv(head, kind == "attribute" ? "::mlir::Attribute" : "::mlir::Type", x.getName());
140   DagInit *members = x.getValueAsDag("members");
141   SmallVector<std::string> argNames =
142       llvm::to_vector(map_range(members->getArgNames(), [](StringInit *init) {
143         return init->getAsUnquotedString();
144       }));
145   StringRef builder = x.getValueAsString("cBuilder").trim();
146   emitParseHelper(kind, returnType, builder, members->getArgs(), argNames,
147                   returnType + "()", os);
148   os << "\n\n";
149 }
150 
151 void printParseConditional(mlir::raw_indented_ostream &ios,
152                            ArrayRef<Init *> args,
153                            ArrayRef<std::string> argNames) {
154   ios << "if ";
155   auto parenScope = ios.scope("(", ") {");
156   ios.indent();
157 
158   auto listHelperName = [](StringRef name) {
159     return formatv("read{0}", capitalize(name));
160   };
161 
162   auto parsedArgs =
163       llvm::to_vector(make_filter_range(args, [](Init *const attr) {
164         Record *def = cast<DefInit>(attr)->getDef();
165         if (def->isSubClassOf("Array"))
166           return true;
167         return !def->getValueAsString("cParser").empty();
168       }));
169 
170   interleave(
171       zip(parsedArgs, argNames),
172       [&](std::tuple<llvm::Init *&, const std::string &> it) {
173         Record *attr = cast<DefInit>(std::get<0>(it))->getDef();
174         std::string parser;
175         if (auto optParser = attr->getValueAsOptionalString("cParser")) {
176           parser = *optParser;
177         } else if (attr->isSubClassOf("Array")) {
178           Record *def = attr->getValueAsDef("elemT");
179           bool composite = def->isSubClassOf("CompositeBytecode");
180           if (!composite && def->isSubClassOf("AttributeKind"))
181             parser = "succeeded($_reader.readAttributes($_var))";
182           else if (!composite && def->isSubClassOf("TypeKind"))
183             parser = "succeeded($_reader.readTypes($_var))";
184           else
185             parser = ("succeeded($_reader.readList($_var, " +
186                       listHelperName(std::get<1>(it)) + "))")
187                          .str();
188         } else {
189           PrintFatalError(attr->getLoc(), "No parser specified");
190         }
191         std::string type = getCType(attr);
192         ios << format(parser, {{"$_reader", "reader"},
193                                {"$_resultType", type},
194                                {"$_var", std::get<1>(it)}});
195       },
196       [&]() { ios << " &&\n"; });
197 }
198 
199 void Generator::emitParseHelper(StringRef kind, StringRef returnType,
200                                 StringRef builder, ArrayRef<Init *> args,
201                                 ArrayRef<std::string> argNames,
202                                 StringRef failure,
203                                 mlir::raw_indented_ostream &ios) {
204   auto funScope = ios.scope("{\n", "}");
205 
206   if (args.empty()) {
207     ios << formatv("return get<{0}>(context);\n", returnType);
208     return;
209   }
210 
211   // Print decls.
212   std::string lastCType = "";
213   for (auto [arg, name] : zip(args, argNames)) {
214     DefInit *first = dyn_cast<DefInit>(arg);
215     if (!first)
216       PrintFatalError("Unexpected type for " + name);
217     Record *def = first->getDef();
218 
219     // Create variable decls, if there are a block of same type then create
220     // comma separated list of them.
221     std::string cType = getCType(def);
222     if (lastCType == cType) {
223       ios << ", ";
224     } else {
225       if (!lastCType.empty())
226         ios << ";\n";
227       ios << cType << " ";
228     }
229     ios << name;
230     lastCType = cType;
231   }
232   ios << ";\n";
233 
234   // Returns the name of the helper used in list parsing. E.g., the name of the
235   // lambda passed to array parsing.
236   auto listHelperName = [](StringRef name) {
237     return formatv("read{0}", capitalize(name));
238   };
239 
240   // Emit list helper functions.
241   for (auto [arg, name] : zip(args, argNames)) {
242     Record *attr = cast<DefInit>(arg)->getDef();
243     if (!attr->isSubClassOf("Array"))
244       continue;
245 
246     // TODO: Dedupe readers.
247     Record *def = attr->getValueAsDef("elemT");
248     if (!def->isSubClassOf("CompositeBytecode") &&
249         (def->isSubClassOf("AttributeKind") || def->isSubClassOf("TypeKind")))
250       continue;
251 
252     std::string returnType = getCType(def);
253     ios << "auto " << listHelperName(name) << " = [&]() -> FailureOr<"
254         << returnType << "> ";
255     SmallVector<Init *> args;
256     SmallVector<std::string> argNames;
257     if (def->isSubClassOf("CompositeBytecode")) {
258       DagInit *members = def->getValueAsDag("members");
259       args = llvm::to_vector(members->getArgs());
260       argNames = llvm::to_vector(
261           map_range(members->getArgNames(), [](StringInit *init) {
262             return init->getAsUnquotedString();
263           }));
264     } else {
265       args = {def->getDefInit()};
266       argNames = {"temp"};
267     }
268     StringRef builder = def->getValueAsString("cBuilder");
269     emitParseHelper(kind, returnType, builder, args, argNames, "failure()",
270                     ios);
271     ios << ";\n";
272   }
273 
274   // Print parse conditional.
275   printParseConditional(ios, args, argNames);
276 
277   // Compute args to pass to create method.
278   auto passedArgs = llvm::to_vector(make_filter_range(
279       argNames, [](StringRef str) { return !str.starts_with("_"); }));
280   std::string argStr;
281   raw_string_ostream argStream(argStr);
282   interleaveComma(passedArgs, argStream,
283                   [&](const std::string &str) { argStream << str; });
284   // Return the invoked constructor.
285   ios << "\nreturn "
286       << format(builder, {{"$_resultType", returnType.str()},
287                           {"$_args", argStream.str()}})
288       << ";\n";
289   ios.unindent();
290 
291   // TODO: Emit error in debug.
292   // This assumes the result types in error case can always be empty
293   // constructed.
294   ios << "}\nreturn " << failure << ";\n";
295 }
296 
297 void Generator::emitPrint(StringRef kind, StringRef type,
298                           ArrayRef<std::pair<int64_t, const Record *>> vec) {
299   if (type == "ReservedOrDead")
300     return;
301 
302   char const *head =
303       R"(static void write({0} {1}, DialectBytecodeWriter &writer) )";
304   mlir::raw_indented_ostream os(output);
305   os << formatv(head, type, kind);
306   auto funScope = os.scope("{\n", "}\n\n");
307 
308   // Check that predicates specified if multiple bytecode instances.
309   for (const llvm::Record *rec : make_second_range(vec)) {
310     StringRef pred = rec->getValueAsString("printerPredicate");
311     if (vec.size() > 1 && pred.empty()) {
312       for (auto [index, rec] : vec) {
313         (void)index;
314         StringRef pred = rec->getValueAsString("printerPredicate");
315         if (vec.size() > 1 && pred.empty())
316           PrintError(rec->getLoc(),
317                      "Requires parsing predicate given common cType");
318       }
319       PrintFatalError("Unspecified for shared cType " + type);
320     }
321   }
322 
323   for (auto [index, rec] : vec) {
324     StringRef pred = rec->getValueAsString("printerPredicate");
325     if (!pred.empty()) {
326       os << "if (" << format(pred, {{"$_val", kind.str()}}) << ") {\n";
327       os.indent();
328     }
329 
330     os << "writer.writeVarInt(/* " << rec->getName() << " */ " << index
331        << ");\n";
332 
333     auto *members = rec->getValueAsDag("members");
334     for (auto [arg, name] :
335          llvm::zip(members->getArgs(), members->getArgNames())) {
336       DefInit *def = dyn_cast<DefInit>(arg);
337       assert(def);
338       Record *memberRec = def->getDef();
339       emitPrintHelper(memberRec, kind, kind, name->getAsUnquotedString(), os);
340     }
341 
342     if (!pred.empty()) {
343       os.unindent();
344       os << "}\n";
345     }
346   }
347 }
348 
349 void Generator::emitPrintHelper(const Record *memberRec, StringRef kind,
350                                 StringRef parent, StringRef name,
351                                 mlir::raw_indented_ostream &ios) {
352   std::string getter;
353   if (auto cGetter = memberRec->getValueAsOptionalString("cGetter");
354       cGetter && !cGetter->empty()) {
355     getter = format(
356         *cGetter,
357         {{"$_attrType", parent.str()},
358          {"$_member", name.str()},
359          {"$_getMember", "get" + convertToCamelFromSnakeCase(name, true)}});
360   } else {
361     getter =
362         formatv("{0}.get{1}()", parent, convertToCamelFromSnakeCase(name, true))
363             .str();
364   }
365 
366   if (memberRec->isSubClassOf("Array")) {
367     Record *def = memberRec->getValueAsDef("elemT");
368     if (!def->isSubClassOf("CompositeBytecode")) {
369       if (def->isSubClassOf("AttributeKind")) {
370         ios << "writer.writeAttributes(" << getter << ");\n";
371         return;
372       }
373       if (def->isSubClassOf("TypeKind")) {
374         ios << "writer.writeTypes(" << getter << ");\n";
375         return;
376       }
377     }
378     std::string returnType = getCType(def);
379     std::string nestedName = kind.str();
380     ios << "writer.writeList(" << getter << ", [&](" << returnType << " "
381         << nestedName << ") ";
382     auto lambdaScope = ios.scope("{\n", "});\n");
383     return emitPrintHelper(def, kind, nestedName, nestedName, ios);
384   }
385   if (memberRec->isSubClassOf("CompositeBytecode")) {
386     auto *members = memberRec->getValueAsDag("members");
387     for (auto [arg, argName] :
388          zip(members->getArgs(), members->getArgNames())) {
389       DefInit *def = dyn_cast<DefInit>(arg);
390       assert(def);
391       emitPrintHelper(def->getDef(), kind, parent,
392                       argName->getAsUnquotedString(), ios);
393     }
394   }
395 
396   if (std::string printer = memberRec->getValueAsString("cPrinter").str();
397       !printer.empty())
398     ios << format(printer, {{"$_writer", "writer"},
399                             {"$_name", kind.str()},
400                             {"$_getter", getter}})
401         << ";\n";
402 }
403 
404 void Generator::emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec) {
405   mlir::raw_indented_ostream os(output);
406   char const *head = R"(static LogicalResult write{0}({0} {1},
407                                 DialectBytecodeWriter &writer))";
408   os << formatv(head, capitalize(kind), kind);
409   auto funScope = os.scope(" {\n", "}\n\n");
410 
411   os << "return TypeSwitch<" << capitalize(kind) << ", LogicalResult>(" << kind
412      << ")";
413   auto switchScope = os.scope("", "");
414   for (StringRef type : vec) {
415     if (type == "ReservedOrDead")
416       continue;
417 
418     os << "\n.Case([&](" << type << " t)";
419     auto caseScope = os.scope(" {\n", "})");
420     os << "return write(t, writer), success();\n";
421   }
422   os << "\n.Default([&](" << capitalize(kind) << ") { return failure(); });\n";
423 }
424 
425 namespace {
426 /// Container of Attribute or Type for Dialect.
427 struct AttrOrType {
428   std::vector<const Record *> attr, type;
429 };
430 } // namespace
431 
432 static bool emitBCRW(const RecordKeeper &records, raw_ostream &os) {
433   MapVector<StringRef, AttrOrType> dialectAttrOrType;
434   for (const Record *it :
435        records.getAllDerivedDefinitions("DialectAttributes")) {
436     if (!selectedBcDialect.empty() &&
437         it->getValueAsString("dialect") != selectedBcDialect)
438       continue;
439     dialectAttrOrType[it->getValueAsString("dialect")].attr =
440         it->getValueAsListOfDefs("elems");
441   }
442   for (const Record *it : records.getAllDerivedDefinitions("DialectTypes")) {
443     if (!selectedBcDialect.empty() &&
444         it->getValueAsString("dialect") != selectedBcDialect)
445       continue;
446     dialectAttrOrType[it->getValueAsString("dialect")].type =
447         it->getValueAsListOfDefs("elems");
448   }
449 
450   if (dialectAttrOrType.size() != 1)
451     PrintFatalError("Single dialect per invocation required (either only "
452                     "one in input file or specified via dialect option)");
453 
454   auto it = dialectAttrOrType.front();
455   Generator gen(os);
456 
457   SmallVector<std::vector<const Record *> *, 2> vecs;
458   SmallVector<std::string, 2> kinds;
459   vecs.push_back(&it.second.attr);
460   kinds.push_back("attribute");
461   vecs.push_back(&it.second.type);
462   kinds.push_back("type");
463   for (auto [vec, kind] : zip(vecs, kinds)) {
464     // Handle Attribute/Type emission.
465     std::map<std::string, std::vector<std::pair<int64_t, const Record *>>>
466         perType;
467     for (auto kt : llvm::enumerate(*vec))
468       perType[getCType(kt.value())].emplace_back(kt.index(), kt.value());
469     for (const auto &jt : perType) {
470       for (auto kt : jt.second)
471         gen.emitParse(kind, *std::get<1>(kt));
472       gen.emitPrint(kind, jt.first, jt.second);
473     }
474     gen.emitParseDispatch(kind, *vec);
475 
476     SmallVector<std::string> types;
477     for (const auto &it : perType) {
478       types.push_back(it.first);
479     }
480     gen.emitPrintDispatch(kind, types);
481   }
482 
483   return false;
484 }
485 
486 static mlir::GenRegistration
487     genBCRW("gen-bytecode", "Generate dialect bytecode readers/writers",
488             [](const RecordKeeper &records, raw_ostream &os) {
489               return emitBCRW(records, os);
490             });
491