xref: /llvm-project/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp (revision a9d003ef855ff7ed1bf4f8229ee9944b55936e6f)
1 //===- BytecodeDialectGen.cpp - Dialect bytecode read/writer gen  ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Support/IndentedOstream.h"
10 #include "mlir/TableGen/GenInfo.h"
11 #include "llvm/ADT/MapVector.h"
12 #include "llvm/ADT/STLExtras.h"
13 #include "llvm/Support/CommandLine.h"
14 #include "llvm/Support/FormatVariadic.h"
15 #include "llvm/TableGen/Error.h"
16 #include "llvm/TableGen/Record.h"
17 #include <regex>
18 
19 using namespace llvm;
20 
21 static llvm::cl::OptionCategory dialectGenCat("Options for -gen-bytecode");
22 static llvm::cl::opt<std::string>
23     selectedBcDialect("bytecode-dialect",
24                       llvm::cl::desc("The dialect to gen for"),
25                       llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated);
26 
27 namespace {
28 
29 /// Helper class to generate C++ bytecode parser helpers.
30 class Generator {
31 public:
32   Generator(raw_ostream &output) : output(output) {}
33 
34   /// Returns whether successfully emitted attribute/type parsers.
35   void emitParse(StringRef kind, Record &x);
36 
37   /// Returns whether successfully emitted attribute/type printers.
38   void emitPrint(StringRef kind, StringRef type,
39                  ArrayRef<std::pair<int64_t, Record *>> vec);
40 
41   /// Emits parse dispatch table.
42   void emitParseDispatch(StringRef kind, ArrayRef<Record *> vec);
43 
44   /// Emits print dispatch table.
45   void emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec);
46 
47 private:
48   /// Emits parse calls to construct given kind.
49   void emitParseHelper(StringRef kind, StringRef returnType, StringRef builder,
50                        ArrayRef<Init *> args, ArrayRef<std::string> argNames,
51                        StringRef failure, mlir::raw_indented_ostream &ios);
52 
53   /// Emits print instructions.
54   void emitPrintHelper(Record *memberRec, StringRef kind, StringRef parent,
55                        StringRef name, mlir::raw_indented_ostream &ios);
56 
57   raw_ostream &output;
58 };
59 } // namespace
60 
/// Substitutes placeholders in `templ` and returns the result.
///
/// Every key of `map` (e.g. `$_var`) is searched for literally, and each
/// occurrence is replaced by the mapped value, which is also inserted
/// literally. Keys are applied one after another in the map's sorted order,
/// so text produced by one substitution can still be matched by a later key.
/// Callers are assumed to use non-overlapping placeholders. Empty keys are
/// ignored.
static std::string format(std::string_view templ,
                          std::map<std::string, std::string> &&map) {
  std::string s(templ);
  for (const auto &[from, to] : map) {
    // An empty key would match at every position; it is not a placeholder.
    if (from.empty())
      continue;
    // Resume the search after the inserted text so a value is never rescanned
    // for the key it replaced.
    for (std::string::size_type pos = s.find(from); pos != std::string::npos;
         pos = s.find(from, pos + to.size()))
      s.replace(pos, from.size(), to);
  }
  return s;
}
71 
72 /// Return string with first character capitalized.
73 static std::string capitalize(StringRef str) {
74   return ((Twine)toUpper(str[0]) + str.drop_front()).str();
75 }
76 
77 /// Return the C++ type for the given record.
78 static std::string getCType(Record *def) {
79   std::string format = "{0}";
80   if (def->isSubClassOf("Array")) {
81     def = def->getValueAsDef("elemT");
82     format = "SmallVector<{0}>";
83   }
84 
85   StringRef cType = def->getValueAsString("cType");
86   if (cType.empty()) {
87     if (def->isAnonymous())
88       PrintFatalError(def->getLoc(), "Unable to determine cType");
89 
90     return formatv(format.c_str(), def->getName().str());
91   }
92   return formatv(format.c_str(), cType.str());
93 }
94 
95 void Generator::emitParseDispatch(StringRef kind, ArrayRef<Record *> vec) {
96   mlir::raw_indented_ostream os(output);
97   char const *head =
98       R"(static {0} read{0}(MLIRContext* context, DialectBytecodeReader &reader))";
99   os << formatv(head, capitalize(kind));
100   auto funScope = os.scope(" {\n", "}\n\n");
101 
102   os << "uint64_t kind;\n";
103   os << "if (failed(reader.readVarInt(kind)))\n"
104      << "  return " << capitalize(kind) << "();\n";
105   os << "switch (kind) ";
106   {
107     auto switchScope = os.scope("{\n", "}\n");
108     for (const auto &it : llvm::enumerate(vec)) {
109       if (it.value()->getName() == "ReservedOrDead")
110         continue;
111 
112       os << formatv("case {1}:\n  return read{0}(context, reader);\n",
113                     it.value()->getName(), it.index());
114     }
115     os << "default:\n"
116        << "  reader.emitError() << \"unknown attribute code: \" "
117        << "<< kind;\n"
118        << "  return " << capitalize(kind) << "();\n";
119   }
120   os << "return " << capitalize(kind) << "();\n";
121 }
122 
123 void Generator::emitParse(StringRef kind, Record &x) {
124   if (x.getNameInitAsString() == "ReservedOrDead")
125     return;
126 
127   char const *head =
128       R"(static {0} read{1}(MLIRContext* context, DialectBytecodeReader &reader) )";
129   mlir::raw_indented_ostream os(output);
130   std::string returnType = getCType(&x);
131   os << formatv(head, returnType, x.getName());
132   DagInit *members = x.getValueAsDag("members");
133   SmallVector<std::string> argNames =
134       llvm::to_vector(map_range(members->getArgNames(), [](StringInit *init) {
135         return init->getAsUnquotedString();
136       }));
137   StringRef builder = x.getValueAsString("cBuilder");
138   emitParseHelper(kind, returnType, builder, members->getArgs(), argNames,
139                   returnType + "()", os);
140   os << "\n\n";
141 }
142 
143 void printParseConditional(mlir::raw_indented_ostream &ios,
144                            ArrayRef<Init *> args,
145                            ArrayRef<std::string> argNames) {
146   ios << "if ";
147   auto parenScope = ios.scope("(", ") {");
148   ios.indent();
149 
150   auto listHelperName = [](StringRef name) {
151     return formatv("read{0}", capitalize(name));
152   };
153 
154   auto parsedArgs =
155       llvm::to_vector(make_filter_range(args, [](Init *const attr) {
156         Record *def = cast<DefInit>(attr)->getDef();
157         if (def->isSubClassOf("Array"))
158           return true;
159         return !def->getValueAsString("cParser").empty();
160       }));
161 
162   interleave(
163       zip(parsedArgs, argNames),
164       [&](std::tuple<llvm::Init *&, const std::string &> it) {
165         Record *attr = cast<DefInit>(std::get<0>(it))->getDef();
166         std::string parser;
167         if (auto optParser = attr->getValueAsOptionalString("cParser")) {
168           parser = *optParser;
169         } else if (attr->isSubClassOf("Array")) {
170           Record *def = attr->getValueAsDef("elemT");
171           bool composite = def->isSubClassOf("CompositeBytecode");
172           if (!composite && def->isSubClassOf("AttributeKind"))
173             parser = "succeeded($_reader.readAttributes($_var))";
174           else if (!composite && def->isSubClassOf("TypeKind"))
175             parser = "succeeded($_reader.readTypes($_var))";
176           else
177             parser = ("succeeded($_reader.readList($_var, " +
178                       listHelperName(std::get<1>(it)) + "))")
179                          .str();
180         } else {
181           PrintFatalError(attr->getLoc(), "No parser specified");
182         }
183         std::string type = getCType(attr);
184         ios << format(parser, {{"$_reader", "reader"},
185                                {"$_resultType", type},
186                                {"$_var", std::get<1>(it)}});
187       },
188       [&]() { ios << " &&\n"; });
189 }
190 
191 void Generator::emitParseHelper(StringRef kind, StringRef returnType,
192                                 StringRef builder, ArrayRef<Init *> args,
193                                 ArrayRef<std::string> argNames,
194                                 StringRef failure,
195                                 mlir::raw_indented_ostream &ios) {
196   auto funScope = ios.scope("{\n", "}");
197 
198   if (args.empty()) {
199     ios << formatv("return get<{0}>(context);\n", returnType);
200     return;
201   }
202 
203   // Print decls.
204   std::string lastCType = "";
205   for (auto [arg, name] : zip(args, argNames)) {
206     DefInit *first = dyn_cast<DefInit>(arg);
207     if (!first)
208       PrintFatalError("Unexpected type for " + name);
209     Record *def = first->getDef();
210 
211     // Create variable decls, if there are a block of same type then create
212     // comma separated list of them.
213     std::string cType = getCType(def);
214     if (lastCType == cType) {
215       ios << ", ";
216     } else {
217       if (!lastCType.empty())
218         ios << ";\n";
219       ios << cType << " ";
220     }
221     ios << name;
222     lastCType = cType;
223   }
224   ios << ";\n";
225 
226   // Returns the name of the helper used in list parsing. E.g., the name of the
227   // lambda passed to array parsing.
228   auto listHelperName = [](StringRef name) {
229     return formatv("read{0}", capitalize(name));
230   };
231 
232   // Emit list helper functions.
233   for (auto [arg, name] : zip(args, argNames)) {
234     Record *attr = cast<DefInit>(arg)->getDef();
235     if (!attr->isSubClassOf("Array"))
236       continue;
237 
238     // TODO: Dedupe readers.
239     Record *def = attr->getValueAsDef("elemT");
240     if (!def->isSubClassOf("CompositeBytecode") &&
241         (def->isSubClassOf("AttributeKind") || def->isSubClassOf("TypeKind")))
242       continue;
243 
244     std::string returnType = getCType(def);
245     ios << "auto " << listHelperName(name) << " = [&]() -> FailureOr<"
246         << returnType << "> ";
247     SmallVector<Init *> args;
248     SmallVector<std::string> argNames;
249     if (def->isSubClassOf("CompositeBytecode")) {
250       DagInit *members = def->getValueAsDag("members");
251       args = llvm::to_vector(members->getArgs());
252       argNames = llvm::to_vector(
253           map_range(members->getArgNames(), [](StringInit *init) {
254             return init->getAsUnquotedString();
255           }));
256     } else {
257       args = {def->getDefInit()};
258       argNames = {"temp"};
259     }
260     StringRef builder = def->getValueAsString("cBuilder");
261     emitParseHelper(kind, returnType, builder, args, argNames, "failure()",
262                     ios);
263     ios << ";\n";
264   }
265 
266   // Print parse conditional.
267   printParseConditional(ios, args, argNames);
268 
269   // Compute args to pass to create method.
270   auto passedArgs = llvm::to_vector(make_filter_range(
271       argNames, [](StringRef str) { return !str.starts_with("_"); }));
272   std::string argStr;
273   raw_string_ostream argStream(argStr);
274   interleaveComma(passedArgs, argStream,
275                   [&](const std::string &str) { argStream << str; });
276   // Return the invoked constructor.
277   ios << "\nreturn "
278       << format(builder, {{"$_resultType", returnType.str()},
279                           {"$_args", argStream.str()}})
280       << ";\n";
281   ios.unindent();
282 
283   // TODO: Emit error in debug.
284   // This assumes the result types in error case can always be empty
285   // constructed.
286   ios << "}\nreturn " << failure << ";\n";
287 }
288 
289 void Generator::emitPrint(StringRef kind, StringRef type,
290                           ArrayRef<std::pair<int64_t, Record *>> vec) {
291   if (type == "ReservedOrDead")
292     return;
293 
294   char const *head =
295       R"(static void write({0} {1}, DialectBytecodeWriter &writer) )";
296   mlir::raw_indented_ostream os(output);
297   os << formatv(head, type, kind);
298   auto funScope = os.scope("{\n", "}\n\n");
299 
300   // Check that predicates specified if multiple bytecode instances.
301   for (llvm::Record *rec : make_second_range(vec)) {
302     StringRef pred = rec->getValueAsString("printerPredicate");
303     if (vec.size() > 1 && pred.empty()) {
304       for (auto [index, rec] : vec) {
305         (void)index;
306         StringRef pred = rec->getValueAsString("printerPredicate");
307         if (vec.size() > 1 && pred.empty())
308           PrintError(rec->getLoc(),
309                      "Requires parsing predicate given common cType");
310       }
311       PrintFatalError("Unspecified for shared cType " + type);
312     }
313   }
314 
315   for (auto [index, rec] : vec) {
316     StringRef pred = rec->getValueAsString("printerPredicate");
317     if (!pred.empty()) {
318       os << "if (" << format(pred, {{"$_val", kind.str()}}) << ") {\n";
319       os.indent();
320     }
321 
322     os << "writer.writeVarInt(/* " << rec->getName() << " */ " << index
323        << ");\n";
324 
325     auto *members = rec->getValueAsDag("members");
326     for (auto [arg, name] :
327          llvm::zip(members->getArgs(), members->getArgNames())) {
328       DefInit *def = dyn_cast<DefInit>(arg);
329       assert(def);
330       Record *memberRec = def->getDef();
331       emitPrintHelper(memberRec, kind, kind, name->getAsUnquotedString(), os);
332     }
333 
334     if (!pred.empty()) {
335       os.unindent();
336       os << "}\n";
337     }
338   }
339 }
340 
341 void Generator::emitPrintHelper(Record *memberRec, StringRef kind,
342                                 StringRef parent, StringRef name,
343                                 mlir::raw_indented_ostream &ios) {
344   std::string getter;
345   if (auto cGetter = memberRec->getValueAsOptionalString("cGetter");
346       cGetter && !cGetter->empty()) {
347     getter = format(
348         *cGetter,
349         {{"$_attrType", parent.str()},
350          {"$_member", name.str()},
351          {"$_getMember", "get" + convertToCamelFromSnakeCase(name, true)}});
352   } else {
353     getter =
354         formatv("{0}.get{1}()", parent, convertToCamelFromSnakeCase(name, true))
355             .str();
356   }
357 
358   if (memberRec->isSubClassOf("Array")) {
359     Record *def = memberRec->getValueAsDef("elemT");
360     if (!def->isSubClassOf("CompositeBytecode")) {
361       if (def->isSubClassOf("AttributeKind")) {
362         ios << "writer.writeAttributes(" << getter << ");\n";
363         return;
364       }
365       if (def->isSubClassOf("TypeKind")) {
366         ios << "writer.writeTypes(" << getter << ");\n";
367         return;
368       }
369     }
370     std::string returnType = getCType(def);
371     ios << "writer.writeList(" << getter << ", [&](" << returnType << " "
372         << kind << ") ";
373     auto lambdaScope = ios.scope("{\n", "});\n");
374     return emitPrintHelper(def, kind, kind, kind, ios);
375   }
376   if (memberRec->isSubClassOf("CompositeBytecode")) {
377     auto *members = memberRec->getValueAsDag("members");
378     for (auto [arg, argName] :
379          zip(members->getArgs(), members->getArgNames())) {
380       DefInit *def = dyn_cast<DefInit>(arg);
381       assert(def);
382       emitPrintHelper(def->getDef(), kind, parent,
383                       argName->getAsUnquotedString(), ios);
384     }
385   }
386 
387   if (std::string printer = memberRec->getValueAsString("cPrinter").str();
388       !printer.empty())
389     ios << format(printer, {{"$_writer", "writer"},
390                             {"$_name", kind.str()},
391                             {"$_getter", getter}})
392         << ";\n";
393 }
394 
395 void Generator::emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec) {
396   mlir::raw_indented_ostream os(output);
397   char const *head = R"(static LogicalResult write{0}({0} {1},
398                                 DialectBytecodeWriter &writer))";
399   os << formatv(head, capitalize(kind), kind);
400   auto funScope = os.scope(" {\n", "}\n\n");
401 
402   os << "return TypeSwitch<" << capitalize(kind) << ", LogicalResult>(" << kind
403      << ")";
404   auto switchScope = os.scope("", "");
405   for (StringRef type : vec) {
406     if (type == "ReservedOrDead")
407       continue;
408 
409     os << "\n.Case([&](" << type << " t)";
410     auto caseScope = os.scope(" {\n", "})");
411     os << "return write(t, writer), success();\n";
412   }
413   os << "\n.Default([&](" << capitalize(kind) << ") { return failure(); });\n";
414 }
415 
416 namespace {
417 /// Container of Attribute or Type for Dialect.
418 struct AttrOrType {
419   std::vector<Record *> attr, type;
420 };
421 } // namespace
422 
423 static bool emitBCRW(const RecordKeeper &records, raw_ostream &os) {
424   MapVector<StringRef, AttrOrType> dialectAttrOrType;
425   for (auto &it : records.getAllDerivedDefinitions("DialectAttributes")) {
426     if (!selectedBcDialect.empty() &&
427         it->getValueAsString("dialect") != selectedBcDialect)
428       continue;
429     dialectAttrOrType[it->getValueAsString("dialect")].attr =
430         it->getValueAsListOfDefs("elems");
431   }
432   for (auto &it : records.getAllDerivedDefinitions("DialectTypes")) {
433     if (!selectedBcDialect.empty() &&
434         it->getValueAsString("dialect") != selectedBcDialect)
435       continue;
436     dialectAttrOrType[it->getValueAsString("dialect")].type =
437         it->getValueAsListOfDefs("elems");
438   }
439 
440   if (dialectAttrOrType.size() != 1)
441     PrintFatalError("Single dialect per invocation required (either only "
442                     "one in input file or specified via dialect option)");
443 
444   auto it = dialectAttrOrType.front();
445   Generator gen(os);
446 
447   SmallVector<std::vector<Record *> *, 2> vecs;
448   SmallVector<std::string, 2> kinds;
449   vecs.push_back(&it.second.attr);
450   kinds.push_back("attribute");
451   vecs.push_back(&it.second.type);
452   kinds.push_back("type");
453   for (auto [vec, kind] : zip(vecs, kinds)) {
454     // Handle Attribute/Type emission.
455     std::map<std::string, std::vector<std::pair<int64_t, Record *>>> perType;
456     for (auto kt : llvm::enumerate(*vec))
457       perType[getCType(kt.value())].emplace_back(kt.index(), kt.value());
458     for (const auto &jt : perType) {
459       for (auto kt : jt.second)
460         gen.emitParse(kind, *std::get<1>(kt));
461       gen.emitPrint(kind, jt.first, jt.second);
462     }
463     gen.emitParseDispatch(kind, *vec);
464 
465     SmallVector<std::string> types;
466     for (const auto &it : perType) {
467       types.push_back(it.first);
468     }
469     gen.emitPrintDispatch(kind, types);
470   }
471 
472   return false;
473 }
474 
475 static mlir::GenRegistration
476     genBCRW("gen-bytecode", "Generate dialect bytecode readers/writers",
477             [](const RecordKeeper &records, raw_ostream &os) {
478               return emitBCRW(records, os);
479             });
480