xref: /llvm-project/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp (revision bccd37f69fdc7b5cd00d9231cabbe74bfe38f598)
1 //===- BytecodeDialectGen.cpp - Dialect bytecode read/writer gen  ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Support/IndentedOstream.h"
10 #include "mlir/TableGen/GenInfo.h"
11 #include "llvm/ADT/MapVector.h"
12 #include "llvm/ADT/STLExtras.h"
13 #include "llvm/Support/CommandLine.h"
14 #include "llvm/Support/FormatVariadic.h"
15 #include "llvm/TableGen/Error.h"
16 #include "llvm/TableGen/Record.h"
17 #include <regex>
18 
19 using namespace llvm;
20 
21 static cl::OptionCategory dialectGenCat("Options for -gen-bytecode");
22 static cl::opt<std::string>
23     selectedBcDialect("bytecode-dialect", cl::desc("The dialect to gen for"),
24                       cl::cat(dialectGenCat), cl::CommaSeparated);
25 
26 namespace {
27 
28 /// Helper class to generate C++ bytecode parser helpers.
29 class Generator {
30 public:
31   Generator(raw_ostream &output) : output(output) {}
32 
33   /// Returns whether successfully emitted attribute/type parsers.
34   void emitParse(StringRef kind, const Record &x);
35 
36   /// Returns whether successfully emitted attribute/type printers.
37   void emitPrint(StringRef kind, StringRef type,
38                  ArrayRef<std::pair<int64_t, const Record *>> vec);
39 
40   /// Emits parse dispatch table.
41   void emitParseDispatch(StringRef kind, ArrayRef<const Record *> vec);
42 
43   /// Emits print dispatch table.
44   void emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec);
45 
46 private:
47   /// Emits parse calls to construct given kind.
48   void emitParseHelper(StringRef kind, StringRef returnType, StringRef builder,
49                        ArrayRef<Init *> args, ArrayRef<std::string> argNames,
50                        StringRef failure, mlir::raw_indented_ostream &ios);
51 
52   /// Emits print instructions.
53   void emitPrintHelper(const Record *memberRec, StringRef kind,
54                        StringRef parent, StringRef name,
55                        mlir::raw_indented_ostream &ios);
56 
57   raw_ostream &output;
58 };
59 } // namespace
60 
61 /// Helper to replace set of from strings to target in `s`.
62 /// Assumed: non-overlapping replacements.
63 static std::string format(StringRef templ,
64                           std::map<std::string, std::string> &&map) {
65   std::string s = templ.str();
66   for (const auto &[from, to] : map)
67     // All replacements start with $, don't treat as anchor.
68     s = std::regex_replace(s, std::regex("\\" + from), to);
69   return s;
70 }
71 
72 /// Return string with first character capitalized.
73 static std::string capitalize(StringRef str) {
74   return ((Twine)toUpper(str[0]) + str.drop_front()).str();
75 }
76 
77 /// Return the C++ type for the given record.
78 static std::string getCType(const Record *def) {
79   std::string format = "{0}";
80   if (def->isSubClassOf("Array")) {
81     def = def->getValueAsDef("elemT");
82     format = "SmallVector<{0}>";
83   }
84 
85   StringRef cType = def->getValueAsString("cType");
86   if (cType.empty()) {
87     if (def->isAnonymous())
88       PrintFatalError(def->getLoc(), "Unable to determine cType");
89 
90     return formatv(format.c_str(), def->getName().str());
91   }
92   return formatv(format.c_str(), cType.str());
93 }
94 
95 void Generator::emitParseDispatch(StringRef kind,
96                                   ArrayRef<const Record *> vec) {
97   mlir::raw_indented_ostream os(output);
98   char const *head =
99       R"(static {0} read{0}(MLIRContext* context, DialectBytecodeReader &reader))";
100   os << formatv(head, capitalize(kind));
101   auto funScope = os.scope(" {\n", "}\n\n");
102 
103   if (vec.empty()) {
104     os << "return reader.emitError() << \"unknown attribute\", "
105        << capitalize(kind) << "();\n";
106     return;
107   }
108 
109   os << "uint64_t kind;\n";
110   os << "if (failed(reader.readVarInt(kind)))\n"
111      << "  return " << capitalize(kind) << "();\n";
112   os << "switch (kind) ";
113   {
114     auto switchScope = os.scope("{\n", "}\n");
115     for (const auto &it : llvm::enumerate(vec)) {
116       if (it.value()->getName() == "ReservedOrDead")
117         continue;
118 
119       os << formatv("case {1}:\n  return read{0}(context, reader);\n",
120                     it.value()->getName(), it.index());
121     }
122     os << "default:\n"
123        << "  reader.emitError() << \"unknown attribute code: \" "
124        << "<< kind;\n"
125        << "  return " << capitalize(kind) << "();\n";
126   }
127   os << "return " << capitalize(kind) << "();\n";
128 }
129 
130 void Generator::emitParse(StringRef kind, const Record &x) {
131   if (x.getNameInitAsString() == "ReservedOrDead")
132     return;
133 
134   char const *head =
135       R"(static {0} read{1}(MLIRContext* context, DialectBytecodeReader &reader) )";
136   mlir::raw_indented_ostream os(output);
137   std::string returnType = getCType(&x);
138   os << formatv(head, kind == "attribute" ? "::mlir::Attribute" : "::mlir::Type", x.getName());
139   DagInit *members = x.getValueAsDag("members");
140   SmallVector<std::string> argNames =
141       llvm::to_vector(map_range(members->getArgNames(), [](StringInit *init) {
142         return init->getAsUnquotedString();
143       }));
144   StringRef builder = x.getValueAsString("cBuilder").trim();
145   emitParseHelper(kind, returnType, builder, members->getArgs(), argNames,
146                   returnType + "()", os);
147   os << "\n\n";
148 }
149 
150 void printParseConditional(mlir::raw_indented_ostream &ios,
151                            ArrayRef<Init *> args,
152                            ArrayRef<std::string> argNames) {
153   ios << "if ";
154   auto parenScope = ios.scope("(", ") {");
155   ios.indent();
156 
157   auto listHelperName = [](StringRef name) {
158     return formatv("read{0}", capitalize(name));
159   };
160 
161   auto parsedArgs =
162       llvm::to_vector(make_filter_range(args, [](Init *const attr) {
163         const Record *def = cast<DefInit>(attr)->getDef();
164         if (def->isSubClassOf("Array"))
165           return true;
166         return !def->getValueAsString("cParser").empty();
167       }));
168 
169   interleave(
170       zip(parsedArgs, argNames),
171       [&](std::tuple<llvm::Init *&, const std::string &> it) {
172         const Record *attr = cast<DefInit>(std::get<0>(it))->getDef();
173         std::string parser;
174         if (auto optParser = attr->getValueAsOptionalString("cParser")) {
175           parser = *optParser;
176         } else if (attr->isSubClassOf("Array")) {
177           const Record *def = attr->getValueAsDef("elemT");
178           bool composite = def->isSubClassOf("CompositeBytecode");
179           if (!composite && def->isSubClassOf("AttributeKind"))
180             parser = "succeeded($_reader.readAttributes($_var))";
181           else if (!composite && def->isSubClassOf("TypeKind"))
182             parser = "succeeded($_reader.readTypes($_var))";
183           else
184             parser = ("succeeded($_reader.readList($_var, " +
185                       listHelperName(std::get<1>(it)) + "))")
186                          .str();
187         } else {
188           PrintFatalError(attr->getLoc(), "No parser specified");
189         }
190         std::string type = getCType(attr);
191         ios << format(parser, {{"$_reader", "reader"},
192                                {"$_resultType", type},
193                                {"$_var", std::get<1>(it)}});
194       },
195       [&]() { ios << " &&\n"; });
196 }
197 
198 void Generator::emitParseHelper(StringRef kind, StringRef returnType,
199                                 StringRef builder, ArrayRef<Init *> args,
200                                 ArrayRef<std::string> argNames,
201                                 StringRef failure,
202                                 mlir::raw_indented_ostream &ios) {
203   auto funScope = ios.scope("{\n", "}");
204 
205   if (args.empty()) {
206     ios << formatv("return get<{0}>(context);\n", returnType);
207     return;
208   }
209 
210   // Print decls.
211   std::string lastCType = "";
212   for (auto [arg, name] : zip(args, argNames)) {
213     DefInit *first = dyn_cast<DefInit>(arg);
214     if (!first)
215       PrintFatalError("Unexpected type for " + name);
216     const Record *def = first->getDef();
217 
218     // Create variable decls, if there are a block of same type then create
219     // comma separated list of them.
220     std::string cType = getCType(def);
221     if (lastCType == cType) {
222       ios << ", ";
223     } else {
224       if (!lastCType.empty())
225         ios << ";\n";
226       ios << cType << " ";
227     }
228     ios << name;
229     lastCType = cType;
230   }
231   ios << ";\n";
232 
233   // Returns the name of the helper used in list parsing. E.g., the name of the
234   // lambda passed to array parsing.
235   auto listHelperName = [](StringRef name) {
236     return formatv("read{0}", capitalize(name));
237   };
238 
239   // Emit list helper functions.
240   for (auto [arg, name] : zip(args, argNames)) {
241     const Record *attr = cast<DefInit>(arg)->getDef();
242     if (!attr->isSubClassOf("Array"))
243       continue;
244 
245     // TODO: Dedupe readers.
246     const Record *def = attr->getValueAsDef("elemT");
247     if (!def->isSubClassOf("CompositeBytecode") &&
248         (def->isSubClassOf("AttributeKind") || def->isSubClassOf("TypeKind")))
249       continue;
250 
251     std::string returnType = getCType(def);
252     ios << "auto " << listHelperName(name) << " = [&]() -> FailureOr<"
253         << returnType << "> ";
254     SmallVector<Init *> args;
255     SmallVector<std::string> argNames;
256     if (def->isSubClassOf("CompositeBytecode")) {
257       DagInit *members = def->getValueAsDag("members");
258       args = llvm::to_vector(members->getArgs());
259       argNames = llvm::to_vector(
260           map_range(members->getArgNames(), [](StringInit *init) {
261             return init->getAsUnquotedString();
262           }));
263     } else {
264       args = {def->getDefInit()};
265       argNames = {"temp"};
266     }
267     StringRef builder = def->getValueAsString("cBuilder");
268     emitParseHelper(kind, returnType, builder, args, argNames, "failure()",
269                     ios);
270     ios << ";\n";
271   }
272 
273   // Print parse conditional.
274   printParseConditional(ios, args, argNames);
275 
276   // Compute args to pass to create method.
277   auto passedArgs = llvm::to_vector(make_filter_range(
278       argNames, [](StringRef str) { return !str.starts_with("_"); }));
279   std::string argStr;
280   raw_string_ostream argStream(argStr);
281   interleaveComma(passedArgs, argStream,
282                   [&](const std::string &str) { argStream << str; });
283   // Return the invoked constructor.
284   ios << "\nreturn "
285       << format(builder, {{"$_resultType", returnType.str()},
286                           {"$_args", argStream.str()}})
287       << ";\n";
288   ios.unindent();
289 
290   // TODO: Emit error in debug.
291   // This assumes the result types in error case can always be empty
292   // constructed.
293   ios << "}\nreturn " << failure << ";\n";
294 }
295 
296 void Generator::emitPrint(StringRef kind, StringRef type,
297                           ArrayRef<std::pair<int64_t, const Record *>> vec) {
298   if (type == "ReservedOrDead")
299     return;
300 
301   char const *head =
302       R"(static void write({0} {1}, DialectBytecodeWriter &writer) )";
303   mlir::raw_indented_ostream os(output);
304   os << formatv(head, type, kind);
305   auto funScope = os.scope("{\n", "}\n\n");
306 
307   // Check that predicates specified if multiple bytecode instances.
308   for (const Record *rec : make_second_range(vec)) {
309     StringRef pred = rec->getValueAsString("printerPredicate");
310     if (vec.size() > 1 && pred.empty()) {
311       for (auto [index, rec] : vec) {
312         (void)index;
313         StringRef pred = rec->getValueAsString("printerPredicate");
314         if (vec.size() > 1 && pred.empty())
315           PrintError(rec->getLoc(),
316                      "Requires parsing predicate given common cType");
317       }
318       PrintFatalError("Unspecified for shared cType " + type);
319     }
320   }
321 
322   for (auto [index, rec] : vec) {
323     StringRef pred = rec->getValueAsString("printerPredicate");
324     if (!pred.empty()) {
325       os << "if (" << format(pred, {{"$_val", kind.str()}}) << ") {\n";
326       os.indent();
327     }
328 
329     os << "writer.writeVarInt(/* " << rec->getName() << " */ " << index
330        << ");\n";
331 
332     auto *members = rec->getValueAsDag("members");
333     for (auto [arg, name] :
334          llvm::zip(members->getArgs(), members->getArgNames())) {
335       DefInit *def = dyn_cast<DefInit>(arg);
336       assert(def);
337       const Record *memberRec = def->getDef();
338       emitPrintHelper(memberRec, kind, kind, name->getAsUnquotedString(), os);
339     }
340 
341     if (!pred.empty()) {
342       os.unindent();
343       os << "}\n";
344     }
345   }
346 }
347 
/// Emits the statements that write member `name` of `parent`, described by the
/// record `memberRec`, to `writer`.
///
/// The member value expression (`getter`) comes from the record's `cGetter`
/// template when it is set and non-empty, with `$_attrType`, `$_member` and
/// `$_getMember` substituted. Otherwise it is `parent.get<CamelName>()`.
/// Then:
///  - Arrays of non-composite attribute/type elements are written with
///    `writeAttributes`/`writeTypes`.
///  - Other arrays are written with a `writeList` lambda whose parameter is
///    named after `kind`; the element record is emitted recursively inside it.
///  - CompositeBytecode records recurse into each of their members using the
///    same `parent`.
///  - For non-array records, a non-empty `cPrinter` template is emitted last,
///    with `$_writer`, `$_name` (bound to `kind`) and `$_getter` substituted.
void Generator::emitPrintHelper(const Record *memberRec, StringRef kind,
                                StringRef parent, StringRef name,
                                mlir::raw_indented_ostream &ios) {
  std::string getter;
  if (auto cGetter = memberRec->getValueAsOptionalString("cGetter");
      cGetter && !cGetter->empty()) {
    getter = format(
        *cGetter,
        {{"$_attrType", parent.str()},
         {"$_member", name.str()},
         {"$_getMember", "get" + convertToCamelFromSnakeCase(name, true)}});
  } else {
    // Default accessor: snake_case member name -> `get<CamelCase>()`.
    getter =
        formatv("{0}.get{1}()", parent, convertToCamelFromSnakeCase(name, true))
            .str();
  }

  if (memberRec->isSubClassOf("Array")) {
    const Record *def = memberRec->getValueAsDef("elemT");
    if (!def->isSubClassOf("CompositeBytecode")) {
      if (def->isSubClassOf("AttributeKind")) {
        ios << "writer.writeAttributes(" << getter << ");\n";
        return;
      }
      if (def->isSubClassOf("TypeKind")) {
        ios << "writer.writeTypes(" << getter << ");\n";
        return;
      }
    }
    std::string returnType = getCType(def);
    // The lambda parameter reuses the `kind` spelling; the element is then
    // printed with it as both parent and member name.
    std::string nestedName = kind.str();
    ios << "writer.writeList(" << getter << ", [&](" << returnType << " "
        << nestedName << ") ";
    auto lambdaScope = ios.scope("{\n", "});\n");
    // `lambdaScope` is destroyed only after this recursive call returns, so
    // the element's statements are emitted inside the lambda body.
    return emitPrintHelper(def, kind, nestedName, nestedName, ios);
  }
  if (memberRec->isSubClassOf("CompositeBytecode")) {
    auto *members = memberRec->getValueAsDag("members");
    for (auto [arg, argName] :
         zip(members->getArgs(), members->getArgNames())) {
      DefInit *def = dyn_cast<DefInit>(arg);
      assert(def);
      emitPrintHelper(def->getDef(), kind, parent,
                      argName->getAsUnquotedString(), ios);
    }
    // No return here: a composite's own `cPrinter` (if any) is emitted below.
  }

  if (std::string printer = memberRec->getValueAsString("cPrinter").str();
      !printer.empty())
    ios << format(printer, {{"$_writer", "writer"},
                            {"$_name", kind.str()},
                            {"$_getter", getter}})
        << ";\n";
}
402 
403 void Generator::emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec) {
404   mlir::raw_indented_ostream os(output);
405   char const *head = R"(static LogicalResult write{0}({0} {1},
406                                 DialectBytecodeWriter &writer))";
407   os << formatv(head, capitalize(kind), kind);
408   auto funScope = os.scope(" {\n", "}\n\n");
409 
410   os << "return TypeSwitch<" << capitalize(kind) << ", LogicalResult>(" << kind
411      << ")";
412   auto switchScope = os.scope("", "");
413   for (StringRef type : vec) {
414     if (type == "ReservedOrDead")
415       continue;
416 
417     os << "\n.Case([&](" << type << " t)";
418     auto caseScope = os.scope(" {\n", "})");
419     os << "return write(t, writer), success();\n";
420   }
421   os << "\n.Default([&](" << capitalize(kind) << ") { return failure(); });\n";
422 }
423 
424 namespace {
425 /// Container of Attribute or Type for Dialect.
426 struct AttrOrType {
427   std::vector<const Record *> attr, type;
428 };
429 } // namespace
430 
431 static bool emitBCRW(const RecordKeeper &records, raw_ostream &os) {
432   MapVector<StringRef, AttrOrType> dialectAttrOrType;
433   for (const Record *it :
434        records.getAllDerivedDefinitions("DialectAttributes")) {
435     if (!selectedBcDialect.empty() &&
436         it->getValueAsString("dialect") != selectedBcDialect)
437       continue;
438     dialectAttrOrType[it->getValueAsString("dialect")].attr =
439         it->getValueAsListOfDefs("elems");
440   }
441   for (const Record *it : records.getAllDerivedDefinitions("DialectTypes")) {
442     if (!selectedBcDialect.empty() &&
443         it->getValueAsString("dialect") != selectedBcDialect)
444       continue;
445     dialectAttrOrType[it->getValueAsString("dialect")].type =
446         it->getValueAsListOfDefs("elems");
447   }
448 
449   if (dialectAttrOrType.size() != 1)
450     PrintFatalError("Single dialect per invocation required (either only "
451                     "one in input file or specified via dialect option)");
452 
453   auto it = dialectAttrOrType.front();
454   Generator gen(os);
455 
456   SmallVector<std::vector<const Record *> *, 2> vecs;
457   SmallVector<std::string, 2> kinds;
458   vecs.push_back(&it.second.attr);
459   kinds.push_back("attribute");
460   vecs.push_back(&it.second.type);
461   kinds.push_back("type");
462   for (auto [vec, kind] : zip(vecs, kinds)) {
463     // Handle Attribute/Type emission.
464     std::map<std::string, std::vector<std::pair<int64_t, const Record *>>>
465         perType;
466     for (auto kt : llvm::enumerate(*vec))
467       perType[getCType(kt.value())].emplace_back(kt.index(), kt.value());
468     for (const auto &jt : perType) {
469       for (auto kt : jt.second)
470         gen.emitParse(kind, *std::get<1>(kt));
471       gen.emitPrint(kind, jt.first, jt.second);
472     }
473     gen.emitParseDispatch(kind, *vec);
474 
475     SmallVector<std::string> types;
476     for (const auto &it : perType) {
477       types.push_back(it.first);
478     }
479     gen.emitPrintDispatch(kind, types);
480   }
481 
482   return false;
483 }
484 
485 static mlir::GenRegistration
486     genBCRW("gen-bytecode", "Generate dialect bytecode readers/writers",
487             [](const RecordKeeper &records, raw_ostream &os) {
488               return emitBCRW(records, os);
489             });
490