1 //===- BytecodeDialectGen.cpp - Dialect bytecode read/writer gen ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Support/IndentedOstream.h" 10 #include "mlir/TableGen/GenInfo.h" 11 #include "llvm/ADT/MapVector.h" 12 #include "llvm/ADT/STLExtras.h" 13 #include "llvm/Support/CommandLine.h" 14 #include "llvm/Support/FormatVariadic.h" 15 #include "llvm/TableGen/Error.h" 16 #include "llvm/TableGen/Record.h" 17 #include <regex> 18 19 using namespace llvm; 20 21 static cl::OptionCategory dialectGenCat("Options for -gen-bytecode"); 22 static cl::opt<std::string> 23 selectedBcDialect("bytecode-dialect", cl::desc("The dialect to gen for"), 24 cl::cat(dialectGenCat), cl::CommaSeparated); 25 26 namespace { 27 28 /// Helper class to generate C++ bytecode parser helpers. 29 class Generator { 30 public: 31 Generator(raw_ostream &output) : output(output) {} 32 33 /// Returns whether successfully emitted attribute/type parsers. 34 void emitParse(StringRef kind, const Record &x); 35 36 /// Returns whether successfully emitted attribute/type printers. 37 void emitPrint(StringRef kind, StringRef type, 38 ArrayRef<std::pair<int64_t, const Record *>> vec); 39 40 /// Emits parse dispatch table. 41 void emitParseDispatch(StringRef kind, ArrayRef<const Record *> vec); 42 43 /// Emits print dispatch table. 44 void emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec); 45 46 private: 47 /// Emits parse calls to construct given kind. 48 void emitParseHelper(StringRef kind, StringRef returnType, StringRef builder, 49 ArrayRef<const Init *> args, 50 ArrayRef<std::string> argNames, StringRef failure, 51 mlir::raw_indented_ostream &ios); 52 53 /// Emits print instructions. 54 void emitPrintHelper(const Record *memberRec, StringRef kind, 55 StringRef parent, StringRef name, 56 mlir::raw_indented_ostream &ios); 57 58 raw_ostream &output; 59 }; 60 } // namespace 61 62 /// Helper to replace set of from strings to target in `s`. 63 /// Assumed: non-overlapping replacements. 64 static std::string format(StringRef templ, 65 std::map<std::string, std::string> &&map) { 66 std::string s = templ.str(); 67 for (const auto &[from, to] : map) 68 // All replacements start with $, don't treat as anchor. 69 s = std::regex_replace(s, std::regex("\\" + from), to); 70 return s; 71 } 72 73 /// Return string with first character capitalized. 74 static std::string capitalize(StringRef str) { 75 return ((Twine)toUpper(str[0]) + str.drop_front()).str(); 76 } 77 78 /// Return the C++ type for the given record. 79 static std::string getCType(const Record *def) { 80 std::string format = "{0}"; 81 if (def->isSubClassOf("Array")) { 82 def = def->getValueAsDef("elemT"); 83 format = "SmallVector<{0}>"; 84 } 85 86 StringRef cType = def->getValueAsString("cType"); 87 if (cType.empty()) { 88 if (def->isAnonymous()) 89 PrintFatalError(def->getLoc(), "Unable to determine cType"); 90 91 return formatv(format.c_str(), def->getName().str()); 92 } 93 return formatv(format.c_str(), cType.str()); 94 } 95 96 void Generator::emitParseDispatch(StringRef kind, 97 ArrayRef<const Record *> vec) { 98 mlir::raw_indented_ostream os(output); 99 char const *head = 100 R"(static {0} read{0}(MLIRContext* context, DialectBytecodeReader &reader))"; 101 os << formatv(head, capitalize(kind)); 102 auto funScope = os.scope(" {\n", "}\n\n"); 103 104 if (vec.empty()) { 105 os << "return reader.emitError() << \"unknown attribute\", " 106 << capitalize(kind) << "();\n"; 107 return; 108 } 109 110 os << "uint64_t kind;\n"; 111 os << "if (failed(reader.readVarInt(kind)))\n" 112 << " return " << capitalize(kind) << "();\n"; 113 os << "switch (kind) "; 114 { 115 auto switchScope = os.scope("{\n", "}\n"); 116 for (const auto &it : llvm::enumerate(vec)) { 117 if (it.value()->getName() == "ReservedOrDead") 118 continue; 119 120 os << formatv("case {1}:\n return read{0}(context, reader);\n", 121 it.value()->getName(), it.index()); 122 } 123 os << "default:\n" 124 << " reader.emitError() << \"unknown attribute code: \" " 125 << "<< kind;\n" 126 << " return " << capitalize(kind) << "();\n"; 127 } 128 os << "return " << capitalize(kind) << "();\n"; 129 } 130 131 void Generator::emitParse(StringRef kind, const Record &x) { 132 if (x.getNameInitAsString() == "ReservedOrDead") 133 return; 134 135 char const *head = 136 R"(static {0} read{1}(MLIRContext* context, DialectBytecodeReader &reader) )"; 137 mlir::raw_indented_ostream os(output); 138 std::string returnType = getCType(&x); 139 os << formatv(head, 140 kind == "attribute" ? "::mlir::Attribute" : "::mlir::Type", 141 x.getName()); 142 const DagInit *members = x.getValueAsDag("members"); 143 SmallVector<std::string> argNames = llvm::to_vector( 144 map_range(members->getArgNames(), [](const StringInit *init) { 145 return init->getAsUnquotedString(); 146 })); 147 StringRef builder = x.getValueAsString("cBuilder").trim(); 148 emitParseHelper(kind, returnType, builder, members->getArgs(), argNames, 149 returnType + "()", os); 150 os << "\n\n"; 151 } 152 153 void printParseConditional(mlir::raw_indented_ostream &ios, 154 ArrayRef<const Init *> args, 155 ArrayRef<std::string> argNames) { 156 ios << "if "; 157 auto parenScope = ios.scope("(", ") {"); 158 ios.indent(); 159 160 auto listHelperName = [](StringRef name) { 161 return formatv("read{0}", capitalize(name)); 162 }; 163 164 auto parsedArgs = 165 llvm::to_vector(make_filter_range(args, [](const Init *const attr) { 166 const Record *def = cast<DefInit>(attr)->getDef(); 167 if (def->isSubClassOf("Array")) 168 return true; 169 return !def->getValueAsString("cParser").empty(); 170 })); 171 172 interleave( 173 zip(parsedArgs, argNames), 174 [&](std::tuple<const Init *&, const std::string &> it) { 175 const Record *attr = cast<DefInit>(std::get<0>(it))->getDef(); 176 std::string parser; 177 if (auto optParser = attr->getValueAsOptionalString("cParser")) { 178 parser = *optParser; 179 } else if (attr->isSubClassOf("Array")) { 180 const Record *def = attr->getValueAsDef("elemT"); 181 bool composite = def->isSubClassOf("CompositeBytecode"); 182 if (!composite && def->isSubClassOf("AttributeKind")) 183 parser = "succeeded($_reader.readAttributes($_var))"; 184 else if (!composite && def->isSubClassOf("TypeKind")) 185 parser = "succeeded($_reader.readTypes($_var))"; 186 else 187 parser = ("succeeded($_reader.readList($_var, " + 188 listHelperName(std::get<1>(it)) + "))") 189 .str(); 190 } else { 191 PrintFatalError(attr->getLoc(), "No parser specified"); 192 } 193 std::string type = getCType(attr); 194 ios << format(parser, {{"$_reader", "reader"}, 195 {"$_resultType", type}, 196 {"$_var", std::get<1>(it)}}); 197 }, 198 [&]() { ios << " &&\n"; }); 199 } 200 201 void Generator::emitParseHelper(StringRef kind, StringRef returnType, 202 StringRef builder, ArrayRef<const Init *> args, 203 ArrayRef<std::string> argNames, 204 StringRef failure, 205 mlir::raw_indented_ostream &ios) { 206 auto funScope = ios.scope("{\n", "}"); 207 208 if (args.empty()) { 209 ios << formatv("return get<{0}>(context);\n", returnType); 210 return; 211 } 212 213 // Print decls. 214 std::string lastCType = ""; 215 for (auto [arg, name] : zip(args, argNames)) { 216 const DefInit *first = dyn_cast<DefInit>(arg); 217 if (!first) 218 PrintFatalError("Unexpected type for " + name); 219 const Record *def = first->getDef(); 220 221 // Create variable decls, if there are a block of same type then create 222 // comma separated list of them. 223 std::string cType = getCType(def); 224 if (lastCType == cType) { 225 ios << ", "; 226 } else { 227 if (!lastCType.empty()) 228 ios << ";\n"; 229 ios << cType << " "; 230 } 231 ios << name; 232 lastCType = cType; 233 } 234 ios << ";\n"; 235 236 // Returns the name of the helper used in list parsing. E.g., the name of the 237 // lambda passed to array parsing. 238 auto listHelperName = [](StringRef name) { 239 return formatv("read{0}", capitalize(name)); 240 }; 241 242 // Emit list helper functions. 243 for (auto [arg, name] : zip(args, argNames)) { 244 const Record *attr = cast<DefInit>(arg)->getDef(); 245 if (!attr->isSubClassOf("Array")) 246 continue; 247 248 // TODO: Dedupe readers. 249 const Record *def = attr->getValueAsDef("elemT"); 250 if (!def->isSubClassOf("CompositeBytecode") && 251 (def->isSubClassOf("AttributeKind") || def->isSubClassOf("TypeKind"))) 252 continue; 253 254 std::string returnType = getCType(def); 255 ios << "auto " << listHelperName(name) << " = [&]() -> FailureOr<" 256 << returnType << "> "; 257 SmallVector<const Init *> args; 258 SmallVector<std::string> argNames; 259 if (def->isSubClassOf("CompositeBytecode")) { 260 const DagInit *members = def->getValueAsDag("members"); 261 args = llvm::to_vector(map_range( 262 members->getArgs(), [](Init *init) { return (const Init *)init; })); 263 argNames = llvm::to_vector( 264 map_range(members->getArgNames(), [](const StringInit *init) { 265 return init->getAsUnquotedString(); 266 })); 267 } else { 268 args = {def->getDefInit()}; 269 argNames = {"temp"}; 270 } 271 StringRef builder = def->getValueAsString("cBuilder"); 272 emitParseHelper(kind, returnType, builder, args, argNames, "failure()", 273 ios); 274 ios << ";\n"; 275 } 276 277 // Print parse conditional. 278 printParseConditional(ios, args, argNames); 279 280 // Compute args to pass to create method. 281 auto passedArgs = llvm::to_vector(make_filter_range( 282 argNames, [](StringRef str) { return !str.starts_with("_"); })); 283 std::string argStr; 284 raw_string_ostream argStream(argStr); 285 interleaveComma(passedArgs, argStream, 286 [&](const std::string &str) { argStream << str; }); 287 // Return the invoked constructor. 288 ios << "\nreturn " 289 << format(builder, {{"$_resultType", returnType.str()}, 290 {"$_args", argStream.str()}}) 291 << ";\n"; 292 ios.unindent(); 293 294 // TODO: Emit error in debug. 295 // This assumes the result types in error case can always be empty 296 // constructed. 297 ios << "}\nreturn " << failure << ";\n"; 298 } 299 300 void Generator::emitPrint(StringRef kind, StringRef type, 301 ArrayRef<std::pair<int64_t, const Record *>> vec) { 302 if (type == "ReservedOrDead") 303 return; 304 305 char const *head = 306 R"(static void write({0} {1}, DialectBytecodeWriter &writer) )"; 307 mlir::raw_indented_ostream os(output); 308 os << formatv(head, type, kind); 309 auto funScope = os.scope("{\n", "}\n\n"); 310 311 // Check that predicates specified if multiple bytecode instances. 312 for (const Record *rec : make_second_range(vec)) { 313 StringRef pred = rec->getValueAsString("printerPredicate"); 314 if (vec.size() > 1 && pred.empty()) { 315 for (auto [index, rec] : vec) { 316 (void)index; 317 StringRef pred = rec->getValueAsString("printerPredicate"); 318 if (vec.size() > 1 && pred.empty()) 319 PrintError(rec->getLoc(), 320 "Requires parsing predicate given common cType"); 321 } 322 PrintFatalError("Unspecified for shared cType " + type); 323 } 324 } 325 326 for (auto [index, rec] : vec) { 327 StringRef pred = rec->getValueAsString("printerPredicate"); 328 if (!pred.empty()) { 329 os << "if (" << format(pred, {{"$_val", kind.str()}}) << ") {\n"; 330 os.indent(); 331 } 332 333 os << "writer.writeVarInt(/* " << rec->getName() << " */ " << index 334 << ");\n"; 335 336 auto *members = rec->getValueAsDag("members"); 337 for (auto [arg, name] : 338 llvm::zip(members->getArgs(), members->getArgNames())) { 339 const DefInit *def = dyn_cast<DefInit>(arg); 340 assert(def); 341 const Record *memberRec = def->getDef(); 342 emitPrintHelper(memberRec, kind, kind, name->getAsUnquotedString(), os); 343 } 344 345 if (!pred.empty()) { 346 os.unindent(); 347 os << "}\n"; 348 } 349 } 350 } 351 352 void Generator::emitPrintHelper(const Record *memberRec, StringRef kind, 353 StringRef parent, StringRef name, 354 mlir::raw_indented_ostream &ios) { 355 std::string getter; 356 if (auto cGetter = memberRec->getValueAsOptionalString("cGetter"); 357 cGetter && !cGetter->empty()) { 358 getter = format( 359 *cGetter, 360 {{"$_attrType", parent.str()}, 361 {"$_member", name.str()}, 362 {"$_getMember", "get" + convertToCamelFromSnakeCase(name, true)}}); 363 } else { 364 getter = 365 formatv("{0}.get{1}()", parent, convertToCamelFromSnakeCase(name, true)) 366 .str(); 367 } 368 369 if (memberRec->isSubClassOf("Array")) { 370 const Record *def = memberRec->getValueAsDef("elemT"); 371 if (!def->isSubClassOf("CompositeBytecode")) { 372 if (def->isSubClassOf("AttributeKind")) { 373 ios << "writer.writeAttributes(" << getter << ");\n"; 374 return; 375 } 376 if (def->isSubClassOf("TypeKind")) { 377 ios << "writer.writeTypes(" << getter << ");\n"; 378 return; 379 } 380 } 381 std::string returnType = getCType(def); 382 std::string nestedName = kind.str(); 383 ios << "writer.writeList(" << getter << ", [&](" << returnType << " " 384 << nestedName << ") "; 385 auto lambdaScope = ios.scope("{\n", "});\n"); 386 return emitPrintHelper(def, kind, nestedName, nestedName, ios); 387 } 388 if (memberRec->isSubClassOf("CompositeBytecode")) { 389 auto *members = memberRec->getValueAsDag("members"); 390 for (auto [arg, argName] : 391 zip(members->getArgs(), members->getArgNames())) { 392 const DefInit *def = dyn_cast<DefInit>(arg); 393 assert(def); 394 emitPrintHelper(def->getDef(), kind, parent, 395 argName->getAsUnquotedString(), ios); 396 } 397 } 398 399 if (std::string printer = memberRec->getValueAsString("cPrinter").str(); 400 !printer.empty()) 401 ios << format(printer, {{"$_writer", "writer"}, 402 {"$_name", kind.str()}, 403 {"$_getter", getter}}) 404 << ";\n"; 405 } 406 407 void Generator::emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec) { 408 mlir::raw_indented_ostream os(output); 409 char const *head = R"(static LogicalResult write{0}({0} {1}, 410 DialectBytecodeWriter &writer))"; 411 os << formatv(head, capitalize(kind), kind); 412 auto funScope = os.scope(" {\n", "}\n\n"); 413 414 os << "return TypeSwitch<" << capitalize(kind) << ", LogicalResult>(" << kind 415 << ")"; 416 auto switchScope = os.scope("", ""); 417 for (StringRef type : vec) { 418 if (type == "ReservedOrDead") 419 continue; 420 421 os << "\n.Case([&](" << type << " t)"; 422 auto caseScope = os.scope(" {\n", "})"); 423 os << "return write(t, writer), success();\n"; 424 } 425 os << "\n.Default([&](" << capitalize(kind) << ") { return failure(); });\n"; 426 } 427 428 namespace { 429 /// Container of Attribute or Type for Dialect. 430 struct AttrOrType { 431 std::vector<const Record *> attr, type; 432 }; 433 } // namespace 434 435 static bool emitBCRW(const RecordKeeper &records, raw_ostream &os) { 436 MapVector<StringRef, AttrOrType> dialectAttrOrType; 437 for (const Record *it : 438 records.getAllDerivedDefinitions("DialectAttributes")) { 439 if (!selectedBcDialect.empty() && 440 it->getValueAsString("dialect") != selectedBcDialect) 441 continue; 442 dialectAttrOrType[it->getValueAsString("dialect")].attr = 443 it->getValueAsListOfDefs("elems"); 444 } 445 for (const Record *it : records.getAllDerivedDefinitions("DialectTypes")) { 446 if (!selectedBcDialect.empty() && 447 it->getValueAsString("dialect") != selectedBcDialect) 448 continue; 449 dialectAttrOrType[it->getValueAsString("dialect")].type = 450 it->getValueAsListOfDefs("elems"); 451 } 452 453 if (dialectAttrOrType.size() != 1) 454 PrintFatalError("Single dialect per invocation required (either only " 455 "one in input file or specified via dialect option)"); 456 457 auto it = dialectAttrOrType.front(); 458 Generator gen(os); 459 460 SmallVector<std::vector<const Record *> *, 2> vecs; 461 SmallVector<std::string, 2> kinds; 462 vecs.push_back(&it.second.attr); 463 kinds.push_back("attribute"); 464 vecs.push_back(&it.second.type); 465 kinds.push_back("type"); 466 for (auto [vec, kind] : zip(vecs, kinds)) { 467 // Handle Attribute/Type emission. 468 std::map<std::string, std::vector<std::pair<int64_t, const Record *>>> 469 perType; 470 for (auto kt : llvm::enumerate(*vec)) 471 perType[getCType(kt.value())].emplace_back(kt.index(), kt.value()); 472 for (const auto &jt : perType) { 473 for (auto kt : jt.second) 474 gen.emitParse(kind, *std::get<1>(kt)); 475 gen.emitPrint(kind, jt.first, jt.second); 476 } 477 gen.emitParseDispatch(kind, *vec); 478 479 SmallVector<std::string> types; 480 for (const auto &it : perType) { 481 types.push_back(it.first); 482 } 483 gen.emitPrintDispatch(kind, types); 484 } 485 486 return false; 487 } 488 489 static mlir::GenRegistration 490 genBCRW("gen-bytecode", "Generate dialect bytecode readers/writers", 491 [](const RecordKeeper &records, raw_ostream &os) { 492 return emitBCRW(records, os); 493 }); 494