1 //===- BytecodeDialectGen.cpp - Dialect bytecode read/writer gen ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Support/IndentedOstream.h" 10 #include "mlir/TableGen/GenInfo.h" 11 #include "llvm/ADT/MapVector.h" 12 #include "llvm/ADT/STLExtras.h" 13 #include "llvm/ADT/SmallVectorExtras.h" 14 #include "llvm/Support/CommandLine.h" 15 #include "llvm/Support/FormatVariadic.h" 16 #include "llvm/TableGen/Error.h" 17 #include "llvm/TableGen/Record.h" 18 #include <regex> 19 20 using namespace llvm; 21 22 static cl::OptionCategory dialectGenCat("Options for -gen-bytecode"); 23 static cl::opt<std::string> 24 selectedBcDialect("bytecode-dialect", cl::desc("The dialect to gen for"), 25 cl::cat(dialectGenCat), cl::CommaSeparated); 26 27 namespace { 28 29 /// Helper class to generate C++ bytecode parser helpers. 30 class Generator { 31 public: 32 Generator(raw_ostream &output) : output(output) {} 33 34 /// Returns whether successfully emitted attribute/type parsers. 35 void emitParse(StringRef kind, const Record &x); 36 37 /// Returns whether successfully emitted attribute/type printers. 38 void emitPrint(StringRef kind, StringRef type, 39 ArrayRef<std::pair<int64_t, const Record *>> vec); 40 41 /// Emits parse dispatch table. 42 void emitParseDispatch(StringRef kind, ArrayRef<const Record *> vec); 43 44 /// Emits print dispatch table. 45 void emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec); 46 47 private: 48 /// Emits parse calls to construct given kind. 49 void emitParseHelper(StringRef kind, StringRef returnType, StringRef builder, 50 ArrayRef<const Init *> args, 51 ArrayRef<std::string> argNames, StringRef failure, 52 mlir::raw_indented_ostream &ios); 53 54 /// Emits print instructions. 55 void emitPrintHelper(const Record *memberRec, StringRef kind, 56 StringRef parent, StringRef name, 57 mlir::raw_indented_ostream &ios); 58 59 raw_ostream &output; 60 }; 61 } // namespace 62 63 /// Helper to replace set of from strings to target in `s`. 64 /// Assumed: non-overlapping replacements. 65 static std::string format(StringRef templ, 66 std::map<std::string, std::string> &&map) { 67 std::string s = templ.str(); 68 for (const auto &[from, to] : map) 69 // All replacements start with $, don't treat as anchor. 70 s = std::regex_replace(s, std::regex("\\" + from), to); 71 return s; 72 } 73 74 /// Return string with first character capitalized. 75 static std::string capitalize(StringRef str) { 76 return ((Twine)toUpper(str[0]) + str.drop_front()).str(); 77 } 78 79 /// Return the C++ type for the given record. 80 static std::string getCType(const Record *def) { 81 std::string format = "{0}"; 82 if (def->isSubClassOf("Array")) { 83 def = def->getValueAsDef("elemT"); 84 format = "SmallVector<{0}>"; 85 } 86 87 StringRef cType = def->getValueAsString("cType"); 88 if (cType.empty()) { 89 if (def->isAnonymous()) 90 PrintFatalError(def->getLoc(), "Unable to determine cType"); 91 92 return formatv(format.c_str(), def->getName().str()); 93 } 94 return formatv(format.c_str(), cType.str()); 95 } 96 97 void Generator::emitParseDispatch(StringRef kind, 98 ArrayRef<const Record *> vec) { 99 mlir::raw_indented_ostream os(output); 100 char const *head = 101 R"(static {0} read{0}(MLIRContext* context, DialectBytecodeReader &reader))"; 102 os << formatv(head, capitalize(kind)); 103 auto funScope = os.scope(" {\n", "}\n\n"); 104 105 if (vec.empty()) { 106 os << "return reader.emitError() << \"unknown attribute\", " 107 << capitalize(kind) << "();\n"; 108 return; 109 } 110 111 os << "uint64_t kind;\n"; 112 os << "if (failed(reader.readVarInt(kind)))\n" 113 << " return " << capitalize(kind) << "();\n"; 114 os << "switch (kind) "; 115 { 116 auto switchScope = os.scope("{\n", "}\n"); 117 for (const auto &it : llvm::enumerate(vec)) { 118 if (it.value()->getName() == "ReservedOrDead") 119 continue; 120 121 os << formatv("case {1}:\n return read{0}(context, reader);\n", 122 it.value()->getName(), it.index()); 123 } 124 os << "default:\n" 125 << " reader.emitError() << \"unknown attribute code: \" " 126 << "<< kind;\n" 127 << " return " << capitalize(kind) << "();\n"; 128 } 129 os << "return " << capitalize(kind) << "();\n"; 130 } 131 132 void Generator::emitParse(StringRef kind, const Record &x) { 133 if (x.getNameInitAsString() == "ReservedOrDead") 134 return; 135 136 char const *head = 137 R"(static {0} read{1}(MLIRContext* context, DialectBytecodeReader &reader) )"; 138 mlir::raw_indented_ostream os(output); 139 std::string returnType = getCType(&x); 140 os << formatv(head, 141 kind == "attribute" ? "::mlir::Attribute" : "::mlir::Type", 142 x.getName()); 143 const DagInit *members = x.getValueAsDag("members"); 144 SmallVector<std::string> argNames = llvm::to_vector( 145 map_range(members->getArgNames(), [](const StringInit *init) { 146 return init->getAsUnquotedString(); 147 })); 148 StringRef builder = x.getValueAsString("cBuilder").trim(); 149 emitParseHelper(kind, returnType, builder, members->getArgs(), argNames, 150 returnType + "()", os); 151 os << "\n\n"; 152 } 153 154 void printParseConditional(mlir::raw_indented_ostream &ios, 155 ArrayRef<const Init *> args, 156 ArrayRef<std::string> argNames) { 157 ios << "if "; 158 auto parenScope = ios.scope("(", ") {"); 159 ios.indent(); 160 161 auto listHelperName = [](StringRef name) { 162 return formatv("read{0}", capitalize(name)); 163 }; 164 165 auto parsedArgs = llvm::filter_to_vector(args, [](const Init *const attr) { 166 const Record *def = cast<DefInit>(attr)->getDef(); 167 if (def->isSubClassOf("Array")) 168 return true; 169 return !def->getValueAsString("cParser").empty(); 170 }); 171 172 interleave( 173 zip(parsedArgs, argNames), 174 [&](std::tuple<const Init *&, const std::string &> it) { 175 const Record *attr = cast<DefInit>(std::get<0>(it))->getDef(); 176 std::string parser; 177 if (auto optParser = attr->getValueAsOptionalString("cParser")) { 178 parser = *optParser; 179 } else if (attr->isSubClassOf("Array")) { 180 const Record *def = attr->getValueAsDef("elemT"); 181 bool composite = def->isSubClassOf("CompositeBytecode"); 182 if (!composite && def->isSubClassOf("AttributeKind")) 183 parser = "succeeded($_reader.readAttributes($_var))"; 184 else if (!composite && def->isSubClassOf("TypeKind")) 185 parser = "succeeded($_reader.readTypes($_var))"; 186 else 187 parser = ("succeeded($_reader.readList($_var, " + 188 listHelperName(std::get<1>(it)) + "))") 189 .str(); 190 } else { 191 PrintFatalError(attr->getLoc(), "No parser specified"); 192 } 193 std::string type = getCType(attr); 194 ios << format(parser, {{"$_reader", "reader"}, 195 {"$_resultType", type}, 196 {"$_var", std::get<1>(it)}}); 197 }, 198 [&]() { ios << " &&\n"; }); 199 } 200 201 void Generator::emitParseHelper(StringRef kind, StringRef returnType, 202 StringRef builder, ArrayRef<const Init *> args, 203 ArrayRef<std::string> argNames, 204 StringRef failure, 205 mlir::raw_indented_ostream &ios) { 206 auto funScope = ios.scope("{\n", "}"); 207 208 if (args.empty()) { 209 ios << formatv("return get<{0}>(context);\n", returnType); 210 return; 211 } 212 213 // Print decls. 214 std::string lastCType = ""; 215 for (auto [arg, name] : zip(args, argNames)) { 216 const DefInit *first = dyn_cast<DefInit>(arg); 217 if (!first) 218 PrintFatalError("Unexpected type for " + name); 219 const Record *def = first->getDef(); 220 221 // Create variable decls, if there are a block of same type then create 222 // comma separated list of them. 223 std::string cType = getCType(def); 224 if (lastCType == cType) { 225 ios << ", "; 226 } else { 227 if (!lastCType.empty()) 228 ios << ";\n"; 229 ios << cType << " "; 230 } 231 ios << name; 232 lastCType = cType; 233 } 234 ios << ";\n"; 235 236 // Returns the name of the helper used in list parsing. E.g., the name of the 237 // lambda passed to array parsing. 238 auto listHelperName = [](StringRef name) { 239 return formatv("read{0}", capitalize(name)); 240 }; 241 242 // Emit list helper functions. 243 for (auto [arg, name] : zip(args, argNames)) { 244 const Record *attr = cast<DefInit>(arg)->getDef(); 245 if (!attr->isSubClassOf("Array")) 246 continue; 247 248 // TODO: Dedupe readers. 249 const Record *def = attr->getValueAsDef("elemT"); 250 if (!def->isSubClassOf("CompositeBytecode") && 251 (def->isSubClassOf("AttributeKind") || def->isSubClassOf("TypeKind"))) 252 continue; 253 254 std::string returnType = getCType(def); 255 ios << "auto " << listHelperName(name) << " = [&]() -> FailureOr<" 256 << returnType << "> "; 257 SmallVector<const Init *> args; 258 SmallVector<std::string> argNames; 259 if (def->isSubClassOf("CompositeBytecode")) { 260 const DagInit *members = def->getValueAsDag("members"); 261 args = llvm::to_vector(members->getArgs()); 262 argNames = llvm::to_vector( 263 map_range(members->getArgNames(), [](const StringInit *init) { 264 return init->getAsUnquotedString(); 265 })); 266 } else { 267 args = {def->getDefInit()}; 268 argNames = {"temp"}; 269 } 270 StringRef builder = def->getValueAsString("cBuilder"); 271 emitParseHelper(kind, returnType, builder, args, argNames, "failure()", 272 ios); 273 ios << ";\n"; 274 } 275 276 // Print parse conditional. 277 printParseConditional(ios, args, argNames); 278 279 // Compute args to pass to create method. 280 auto passedArgs = llvm::filter_to_vector( 281 argNames, [](StringRef str) { return !str.starts_with("_"); }); 282 std::string argStr; 283 raw_string_ostream argStream(argStr); 284 interleaveComma(passedArgs, argStream, 285 [&](const std::string &str) { argStream << str; }); 286 // Return the invoked constructor. 287 ios << "\nreturn " 288 << format(builder, {{"$_resultType", returnType.str()}, 289 {"$_args", argStream.str()}}) 290 << ";\n"; 291 ios.unindent(); 292 293 // TODO: Emit error in debug. 294 // This assumes the result types in error case can always be empty 295 // constructed. 296 ios << "}\nreturn " << failure << ";\n"; 297 } 298 299 void Generator::emitPrint(StringRef kind, StringRef type, 300 ArrayRef<std::pair<int64_t, const Record *>> vec) { 301 if (type == "ReservedOrDead") 302 return; 303 304 char const *head = 305 R"(static void write({0} {1}, DialectBytecodeWriter &writer) )"; 306 mlir::raw_indented_ostream os(output); 307 os << formatv(head, type, kind); 308 auto funScope = os.scope("{\n", "}\n\n"); 309 310 // Check that predicates specified if multiple bytecode instances. 311 for (const Record *rec : make_second_range(vec)) { 312 StringRef pred = rec->getValueAsString("printerPredicate"); 313 if (vec.size() > 1 && pred.empty()) { 314 for (auto [index, rec] : vec) { 315 (void)index; 316 StringRef pred = rec->getValueAsString("printerPredicate"); 317 if (vec.size() > 1 && pred.empty()) 318 PrintError(rec->getLoc(), 319 "Requires parsing predicate given common cType"); 320 } 321 PrintFatalError("Unspecified for shared cType " + type); 322 } 323 } 324 325 for (auto [index, rec] : vec) { 326 StringRef pred = rec->getValueAsString("printerPredicate"); 327 if (!pred.empty()) { 328 os << "if (" << format(pred, {{"$_val", kind.str()}}) << ") {\n"; 329 os.indent(); 330 } 331 332 os << "writer.writeVarInt(/* " << rec->getName() << " */ " << index 333 << ");\n"; 334 335 auto *members = rec->getValueAsDag("members"); 336 for (auto [arg, name] : 337 llvm::zip(members->getArgs(), members->getArgNames())) { 338 const DefInit *def = dyn_cast<DefInit>(arg); 339 assert(def); 340 const Record *memberRec = def->getDef(); 341 emitPrintHelper(memberRec, kind, kind, name->getAsUnquotedString(), os); 342 } 343 344 if (!pred.empty()) { 345 os.unindent(); 346 os << "}\n"; 347 } 348 } 349 } 350 351 void Generator::emitPrintHelper(const Record *memberRec, StringRef kind, 352 StringRef parent, StringRef name, 353 mlir::raw_indented_ostream &ios) { 354 std::string getter; 355 if (auto cGetter = memberRec->getValueAsOptionalString("cGetter"); 356 cGetter && !cGetter->empty()) { 357 getter = format( 358 *cGetter, 359 {{"$_attrType", parent.str()}, 360 {"$_member", name.str()}, 361 {"$_getMember", "get" + convertToCamelFromSnakeCase(name, true)}}); 362 } else { 363 getter = 364 formatv("{0}.get{1}()", parent, convertToCamelFromSnakeCase(name, true)) 365 .str(); 366 } 367 368 if (memberRec->isSubClassOf("Array")) { 369 const Record *def = memberRec->getValueAsDef("elemT"); 370 if (!def->isSubClassOf("CompositeBytecode")) { 371 if (def->isSubClassOf("AttributeKind")) { 372 ios << "writer.writeAttributes(" << getter << ");\n"; 373 return; 374 } 375 if (def->isSubClassOf("TypeKind")) { 376 ios << "writer.writeTypes(" << getter << ");\n"; 377 return; 378 } 379 } 380 std::string returnType = getCType(def); 381 std::string nestedName = kind.str(); 382 ios << "writer.writeList(" << getter << ", [&](" << returnType << " " 383 << nestedName << ") "; 384 auto lambdaScope = ios.scope("{\n", "});\n"); 385 return emitPrintHelper(def, kind, nestedName, nestedName, ios); 386 } 387 if (memberRec->isSubClassOf("CompositeBytecode")) { 388 auto *members = memberRec->getValueAsDag("members"); 389 for (auto [arg, argName] : 390 zip(members->getArgs(), members->getArgNames())) { 391 const DefInit *def = dyn_cast<DefInit>(arg); 392 assert(def); 393 emitPrintHelper(def->getDef(), kind, parent, 394 argName->getAsUnquotedString(), ios); 395 } 396 } 397 398 if (std::string printer = memberRec->getValueAsString("cPrinter").str(); 399 !printer.empty()) 400 ios << format(printer, {{"$_writer", "writer"}, 401 {"$_name", kind.str()}, 402 {"$_getter", getter}}) 403 << ";\n"; 404 } 405 406 void Generator::emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec) { 407 mlir::raw_indented_ostream os(output); 408 char const *head = R"(static LogicalResult write{0}({0} {1}, 409 DialectBytecodeWriter &writer))"; 410 os << formatv(head, capitalize(kind), kind); 411 auto funScope = os.scope(" {\n", "}\n\n"); 412 413 os << "return TypeSwitch<" << capitalize(kind) << ", LogicalResult>(" << kind 414 << ")"; 415 auto switchScope = os.scope("", ""); 416 for (StringRef type : vec) { 417 if (type == "ReservedOrDead") 418 continue; 419 420 os << "\n.Case([&](" << type << " t)"; 421 auto caseScope = os.scope(" {\n", "})"); 422 os << "return write(t, writer), success();\n"; 423 } 424 os << "\n.Default([&](" << capitalize(kind) << ") { return failure(); });\n"; 425 } 426 427 namespace { 428 /// Container of Attribute or Type for Dialect. 429 struct AttrOrType { 430 std::vector<const Record *> attr, type; 431 }; 432 } // namespace 433 434 static bool emitBCRW(const RecordKeeper &records, raw_ostream &os) { 435 MapVector<StringRef, AttrOrType> dialectAttrOrType; 436 for (const Record *it : 437 records.getAllDerivedDefinitions("DialectAttributes")) { 438 if (!selectedBcDialect.empty() && 439 it->getValueAsString("dialect") != selectedBcDialect) 440 continue; 441 dialectAttrOrType[it->getValueAsString("dialect")].attr = 442 it->getValueAsListOfDefs("elems"); 443 } 444 for (const Record *it : records.getAllDerivedDefinitions("DialectTypes")) { 445 if (!selectedBcDialect.empty() && 446 it->getValueAsString("dialect") != selectedBcDialect) 447 continue; 448 dialectAttrOrType[it->getValueAsString("dialect")].type = 449 it->getValueAsListOfDefs("elems"); 450 } 451 452 if (dialectAttrOrType.size() != 1) 453 PrintFatalError("Single dialect per invocation required (either only " 454 "one in input file or specified via dialect option)"); 455 456 auto it = dialectAttrOrType.front(); 457 Generator gen(os); 458 459 SmallVector<std::vector<const Record *> *, 2> vecs; 460 SmallVector<std::string, 2> kinds; 461 vecs.push_back(&it.second.attr); 462 kinds.push_back("attribute"); 463 vecs.push_back(&it.second.type); 464 kinds.push_back("type"); 465 for (auto [vec, kind] : zip(vecs, kinds)) { 466 // Handle Attribute/Type emission. 467 std::map<std::string, std::vector<std::pair<int64_t, const Record *>>> 468 perType; 469 for (auto kt : llvm::enumerate(*vec)) 470 perType[getCType(kt.value())].emplace_back(kt.index(), kt.value()); 471 for (const auto &jt : perType) { 472 for (auto kt : jt.second) 473 gen.emitParse(kind, *std::get<1>(kt)); 474 gen.emitPrint(kind, jt.first, jt.second); 475 } 476 gen.emitParseDispatch(kind, *vec); 477 478 SmallVector<std::string> types; 479 for (const auto &it : perType) { 480 types.push_back(it.first); 481 } 482 gen.emitPrintDispatch(kind, types); 483 } 484 485 return false; 486 } 487 488 static mlir::GenRegistration 489 genBCRW("gen-bytecode", "Generate dialect bytecode readers/writers", 490 [](const RecordKeeper &records, raw_ostream &os) { 491 return emitBCRW(records, os); 492 }); 493