1 //===- BytecodeDialectGen.cpp - Dialect bytecode read/writer gen ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Support/IndentedOstream.h" 10 #include "mlir/TableGen/GenInfo.h" 11 #include "llvm/ADT/MapVector.h" 12 #include "llvm/ADT/STLExtras.h" 13 #include "llvm/Support/CommandLine.h" 14 #include "llvm/Support/FormatVariadic.h" 15 #include "llvm/TableGen/Error.h" 16 #include "llvm/TableGen/Record.h" 17 #include <regex> 18 19 using namespace llvm; 20 21 static llvm::cl::OptionCategory dialectGenCat("Options for -gen-bytecode"); 22 static llvm::cl::opt<std::string> 23 selectedBcDialect("bytecode-dialect", 24 llvm::cl::desc("The dialect to gen for"), 25 llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated); 26 27 namespace { 28 29 /// Helper class to generate C++ bytecode parser helpers. 30 class Generator { 31 public: 32 Generator(raw_ostream &output) : output(output) {} 33 34 /// Returns whether successfully emitted attribute/type parsers. 35 void emitParse(StringRef kind, const Record &x); 36 37 /// Returns whether successfully emitted attribute/type printers. 38 void emitPrint(StringRef kind, StringRef type, 39 ArrayRef<std::pair<int64_t, const Record *>> vec); 40 41 /// Emits parse dispatch table. 42 void emitParseDispatch(StringRef kind, ArrayRef<const Record *> vec); 43 44 /// Emits print dispatch table. 45 void emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec); 46 47 private: 48 /// Emits parse calls to construct given kind. 49 void emitParseHelper(StringRef kind, StringRef returnType, StringRef builder, 50 ArrayRef<Init *> args, ArrayRef<std::string> argNames, 51 StringRef failure, mlir::raw_indented_ostream &ios); 52 53 /// Emits print instructions. 54 void emitPrintHelper(const Record *memberRec, StringRef kind, 55 StringRef parent, StringRef name, 56 mlir::raw_indented_ostream &ios); 57 58 raw_ostream &output; 59 }; 60 } // namespace 61 62 /// Helper to replace set of from strings to target in `s`. 63 /// Assumed: non-overlapping replacements. 64 static std::string format(StringRef templ, 65 std::map<std::string, std::string> &&map) { 66 std::string s = templ.str(); 67 for (const auto &[from, to] : map) 68 // All replacements start with $, don't treat as anchor. 69 s = std::regex_replace(s, std::regex("\\" + from), to); 70 return s; 71 } 72 73 /// Return string with first character capitalized. 74 static std::string capitalize(StringRef str) { 75 return ((Twine)toUpper(str[0]) + str.drop_front()).str(); 76 } 77 78 /// Return the C++ type for the given record. 79 static std::string getCType(const Record *def) { 80 std::string format = "{0}"; 81 if (def->isSubClassOf("Array")) { 82 def = def->getValueAsDef("elemT"); 83 format = "SmallVector<{0}>"; 84 } 85 86 StringRef cType = def->getValueAsString("cType"); 87 if (cType.empty()) { 88 if (def->isAnonymous()) 89 PrintFatalError(def->getLoc(), "Unable to determine cType"); 90 91 return formatv(format.c_str(), def->getName().str()); 92 } 93 return formatv(format.c_str(), cType.str()); 94 } 95 96 void Generator::emitParseDispatch(StringRef kind, 97 ArrayRef<const Record *> vec) { 98 mlir::raw_indented_ostream os(output); 99 char const *head = 100 R"(static {0} read{0}(MLIRContext* context, DialectBytecodeReader &reader))"; 101 os << formatv(head, capitalize(kind)); 102 auto funScope = os.scope(" {\n", "}\n\n"); 103 104 if (vec.empty()) { 105 os << "return reader.emitError() << \"unknown attribute\", " 106 << capitalize(kind) << "();\n"; 107 return; 108 } 109 110 os << "uint64_t kind;\n"; 111 os << "if (failed(reader.readVarInt(kind)))\n" 112 << " return " << capitalize(kind) << "();\n"; 113 os << "switch (kind) "; 114 { 115 auto switchScope = os.scope("{\n", "}\n"); 116 for (const auto &it : llvm::enumerate(vec)) { 117 if (it.value()->getName() == "ReservedOrDead") 118 continue; 119 120 os << formatv("case {1}:\n return read{0}(context, reader);\n", 121 it.value()->getName(), it.index()); 122 } 123 os << "default:\n" 124 << " reader.emitError() << \"unknown attribute code: \" " 125 << "<< kind;\n" 126 << " return " << capitalize(kind) << "();\n"; 127 } 128 os << "return " << capitalize(kind) << "();\n"; 129 } 130 131 void Generator::emitParse(StringRef kind, const Record &x) { 132 if (x.getNameInitAsString() == "ReservedOrDead") 133 return; 134 135 char const *head = 136 R"(static {0} read{1}(MLIRContext* context, DialectBytecodeReader &reader) )"; 137 mlir::raw_indented_ostream os(output); 138 std::string returnType = getCType(&x); 139 os << formatv(head, kind == "attribute" ? "::mlir::Attribute" : "::mlir::Type", x.getName()); 140 DagInit *members = x.getValueAsDag("members"); 141 SmallVector<std::string> argNames = 142 llvm::to_vector(map_range(members->getArgNames(), [](StringInit *init) { 143 return init->getAsUnquotedString(); 144 })); 145 StringRef builder = x.getValueAsString("cBuilder").trim(); 146 emitParseHelper(kind, returnType, builder, members->getArgs(), argNames, 147 returnType + "()", os); 148 os << "\n\n"; 149 } 150 151 void printParseConditional(mlir::raw_indented_ostream &ios, 152 ArrayRef<Init *> args, 153 ArrayRef<std::string> argNames) { 154 ios << "if "; 155 auto parenScope = ios.scope("(", ") {"); 156 ios.indent(); 157 158 auto listHelperName = [](StringRef name) { 159 return formatv("read{0}", capitalize(name)); 160 }; 161 162 auto parsedArgs = 163 llvm::to_vector(make_filter_range(args, [](Init *const attr) { 164 Record *def = cast<DefInit>(attr)->getDef(); 165 if (def->isSubClassOf("Array")) 166 return true; 167 return !def->getValueAsString("cParser").empty(); 168 })); 169 170 interleave( 171 zip(parsedArgs, argNames), 172 [&](std::tuple<llvm::Init *&, const std::string &> it) { 173 Record *attr = cast<DefInit>(std::get<0>(it))->getDef(); 174 std::string parser; 175 if (auto optParser = attr->getValueAsOptionalString("cParser")) { 176 parser = *optParser; 177 } else if (attr->isSubClassOf("Array")) { 178 Record *def = attr->getValueAsDef("elemT"); 179 bool composite = def->isSubClassOf("CompositeBytecode"); 180 if (!composite && def->isSubClassOf("AttributeKind")) 181 parser = "succeeded($_reader.readAttributes($_var))"; 182 else if (!composite && def->isSubClassOf("TypeKind")) 183 parser = "succeeded($_reader.readTypes($_var))"; 184 else 185 parser = ("succeeded($_reader.readList($_var, " + 186 listHelperName(std::get<1>(it)) + "))") 187 .str(); 188 } else { 189 PrintFatalError(attr->getLoc(), "No parser specified"); 190 } 191 std::string type = getCType(attr); 192 ios << format(parser, {{"$_reader", "reader"}, 193 {"$_resultType", type}, 194 {"$_var", std::get<1>(it)}}); 195 }, 196 [&]() { ios << " &&\n"; }); 197 } 198 199 void Generator::emitParseHelper(StringRef kind, StringRef returnType, 200 StringRef builder, ArrayRef<Init *> args, 201 ArrayRef<std::string> argNames, 202 StringRef failure, 203 mlir::raw_indented_ostream &ios) { 204 auto funScope = ios.scope("{\n", "}"); 205 206 if (args.empty()) { 207 ios << formatv("return get<{0}>(context);\n", returnType); 208 return; 209 } 210 211 // Print decls. 212 std::string lastCType = ""; 213 for (auto [arg, name] : zip(args, argNames)) { 214 DefInit *first = dyn_cast<DefInit>(arg); 215 if (!first) 216 PrintFatalError("Unexpected type for " + name); 217 Record *def = first->getDef(); 218 219 // Create variable decls, if there are a block of same type then create 220 // comma separated list of them. 221 std::string cType = getCType(def); 222 if (lastCType == cType) { 223 ios << ", "; 224 } else { 225 if (!lastCType.empty()) 226 ios << ";\n"; 227 ios << cType << " "; 228 } 229 ios << name; 230 lastCType = cType; 231 } 232 ios << ";\n"; 233 234 // Returns the name of the helper used in list parsing. E.g., the name of the 235 // lambda passed to array parsing. 236 auto listHelperName = [](StringRef name) { 237 return formatv("read{0}", capitalize(name)); 238 }; 239 240 // Emit list helper functions. 241 for (auto [arg, name] : zip(args, argNames)) { 242 Record *attr = cast<DefInit>(arg)->getDef(); 243 if (!attr->isSubClassOf("Array")) 244 continue; 245 246 // TODO: Dedupe readers. 247 Record *def = attr->getValueAsDef("elemT"); 248 if (!def->isSubClassOf("CompositeBytecode") && 249 (def->isSubClassOf("AttributeKind") || def->isSubClassOf("TypeKind"))) 250 continue; 251 252 std::string returnType = getCType(def); 253 ios << "auto " << listHelperName(name) << " = [&]() -> FailureOr<" 254 << returnType << "> "; 255 SmallVector<Init *> args; 256 SmallVector<std::string> argNames; 257 if (def->isSubClassOf("CompositeBytecode")) { 258 DagInit *members = def->getValueAsDag("members"); 259 args = llvm::to_vector(members->getArgs()); 260 argNames = llvm::to_vector( 261 map_range(members->getArgNames(), [](StringInit *init) { 262 return init->getAsUnquotedString(); 263 })); 264 } else { 265 args = {def->getDefInit()}; 266 argNames = {"temp"}; 267 } 268 StringRef builder = def->getValueAsString("cBuilder"); 269 emitParseHelper(kind, returnType, builder, args, argNames, "failure()", 270 ios); 271 ios << ";\n"; 272 } 273 274 // Print parse conditional. 275 printParseConditional(ios, args, argNames); 276 277 // Compute args to pass to create method. 278 auto passedArgs = llvm::to_vector(make_filter_range( 279 argNames, [](StringRef str) { return !str.starts_with("_"); })); 280 std::string argStr; 281 raw_string_ostream argStream(argStr); 282 interleaveComma(passedArgs, argStream, 283 [&](const std::string &str) { argStream << str; }); 284 // Return the invoked constructor. 285 ios << "\nreturn " 286 << format(builder, {{"$_resultType", returnType.str()}, 287 {"$_args", argStream.str()}}) 288 << ";\n"; 289 ios.unindent(); 290 291 // TODO: Emit error in debug. 292 // This assumes the result types in error case can always be empty 293 // constructed. 294 ios << "}\nreturn " << failure << ";\n"; 295 } 296 297 void Generator::emitPrint(StringRef kind, StringRef type, 298 ArrayRef<std::pair<int64_t, const Record *>> vec) { 299 if (type == "ReservedOrDead") 300 return; 301 302 char const *head = 303 R"(static void write({0} {1}, DialectBytecodeWriter &writer) )"; 304 mlir::raw_indented_ostream os(output); 305 os << formatv(head, type, kind); 306 auto funScope = os.scope("{\n", "}\n\n"); 307 308 // Check that predicates specified if multiple bytecode instances. 309 for (const llvm::Record *rec : make_second_range(vec)) { 310 StringRef pred = rec->getValueAsString("printerPredicate"); 311 if (vec.size() > 1 && pred.empty()) { 312 for (auto [index, rec] : vec) { 313 (void)index; 314 StringRef pred = rec->getValueAsString("printerPredicate"); 315 if (vec.size() > 1 && pred.empty()) 316 PrintError(rec->getLoc(), 317 "Requires parsing predicate given common cType"); 318 } 319 PrintFatalError("Unspecified for shared cType " + type); 320 } 321 } 322 323 for (auto [index, rec] : vec) { 324 StringRef pred = rec->getValueAsString("printerPredicate"); 325 if (!pred.empty()) { 326 os << "if (" << format(pred, {{"$_val", kind.str()}}) << ") {\n"; 327 os.indent(); 328 } 329 330 os << "writer.writeVarInt(/* " << rec->getName() << " */ " << index 331 << ");\n"; 332 333 auto *members = rec->getValueAsDag("members"); 334 for (auto [arg, name] : 335 llvm::zip(members->getArgs(), members->getArgNames())) { 336 DefInit *def = dyn_cast<DefInit>(arg); 337 assert(def); 338 Record *memberRec = def->getDef(); 339 emitPrintHelper(memberRec, kind, kind, name->getAsUnquotedString(), os); 340 } 341 342 if (!pred.empty()) { 343 os.unindent(); 344 os << "}\n"; 345 } 346 } 347 } 348 349 void Generator::emitPrintHelper(const Record *memberRec, StringRef kind, 350 StringRef parent, StringRef name, 351 mlir::raw_indented_ostream &ios) { 352 std::string getter; 353 if (auto cGetter = memberRec->getValueAsOptionalString("cGetter"); 354 cGetter && !cGetter->empty()) { 355 getter = format( 356 *cGetter, 357 {{"$_attrType", parent.str()}, 358 {"$_member", name.str()}, 359 {"$_getMember", "get" + convertToCamelFromSnakeCase(name, true)}}); 360 } else { 361 getter = 362 formatv("{0}.get{1}()", parent, convertToCamelFromSnakeCase(name, true)) 363 .str(); 364 } 365 366 if (memberRec->isSubClassOf("Array")) { 367 Record *def = memberRec->getValueAsDef("elemT"); 368 if (!def->isSubClassOf("CompositeBytecode")) { 369 if (def->isSubClassOf("AttributeKind")) { 370 ios << "writer.writeAttributes(" << getter << ");\n"; 371 return; 372 } 373 if (def->isSubClassOf("TypeKind")) { 374 ios << "writer.writeTypes(" << getter << ");\n"; 375 return; 376 } 377 } 378 std::string returnType = getCType(def); 379 std::string nestedName = kind.str(); 380 ios << "writer.writeList(" << getter << ", [&](" << returnType << " " 381 << nestedName << ") "; 382 auto lambdaScope = ios.scope("{\n", "});\n"); 383 return emitPrintHelper(def, kind, nestedName, nestedName, ios); 384 } 385 if (memberRec->isSubClassOf("CompositeBytecode")) { 386 auto *members = memberRec->getValueAsDag("members"); 387 for (auto [arg, argName] : 388 zip(members->getArgs(), members->getArgNames())) { 389 DefInit *def = dyn_cast<DefInit>(arg); 390 assert(def); 391 emitPrintHelper(def->getDef(), kind, parent, 392 argName->getAsUnquotedString(), ios); 393 } 394 } 395 396 if (std::string printer = memberRec->getValueAsString("cPrinter").str(); 397 !printer.empty()) 398 ios << format(printer, {{"$_writer", "writer"}, 399 {"$_name", kind.str()}, 400 {"$_getter", getter}}) 401 << ";\n"; 402 } 403 404 void Generator::emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec) { 405 mlir::raw_indented_ostream os(output); 406 char const *head = R"(static LogicalResult write{0}({0} {1}, 407 DialectBytecodeWriter &writer))"; 408 os << formatv(head, capitalize(kind), kind); 409 auto funScope = os.scope(" {\n", "}\n\n"); 410 411 os << "return TypeSwitch<" << capitalize(kind) << ", LogicalResult>(" << kind 412 << ")"; 413 auto switchScope = os.scope("", ""); 414 for (StringRef type : vec) { 415 if (type == "ReservedOrDead") 416 continue; 417 418 os << "\n.Case([&](" << type << " t)"; 419 auto caseScope = os.scope(" {\n", "})"); 420 os << "return write(t, writer), success();\n"; 421 } 422 os << "\n.Default([&](" << capitalize(kind) << ") { return failure(); });\n"; 423 } 424 425 namespace { 426 /// Container of Attribute or Type for Dialect. 427 struct AttrOrType { 428 std::vector<const Record *> attr, type; 429 }; 430 } // namespace 431 432 static bool emitBCRW(const RecordKeeper &records, raw_ostream &os) { 433 MapVector<StringRef, AttrOrType> dialectAttrOrType; 434 for (const Record *it : 435 records.getAllDerivedDefinitions("DialectAttributes")) { 436 if (!selectedBcDialect.empty() && 437 it->getValueAsString("dialect") != selectedBcDialect) 438 continue; 439 dialectAttrOrType[it->getValueAsString("dialect")].attr = 440 it->getValueAsListOfDefs("elems"); 441 } 442 for (const Record *it : records.getAllDerivedDefinitions("DialectTypes")) { 443 if (!selectedBcDialect.empty() && 444 it->getValueAsString("dialect") != selectedBcDialect) 445 continue; 446 dialectAttrOrType[it->getValueAsString("dialect")].type = 447 it->getValueAsListOfDefs("elems"); 448 } 449 450 if (dialectAttrOrType.size() != 1) 451 PrintFatalError("Single dialect per invocation required (either only " 452 "one in input file or specified via dialect option)"); 453 454 auto it = dialectAttrOrType.front(); 455 Generator gen(os); 456 457 SmallVector<std::vector<const Record *> *, 2> vecs; 458 SmallVector<std::string, 2> kinds; 459 vecs.push_back(&it.second.attr); 460 kinds.push_back("attribute"); 461 vecs.push_back(&it.second.type); 462 kinds.push_back("type"); 463 for (auto [vec, kind] : zip(vecs, kinds)) { 464 // Handle Attribute/Type emission. 465 std::map<std::string, std::vector<std::pair<int64_t, const Record *>>> 466 perType; 467 for (auto kt : llvm::enumerate(*vec)) 468 perType[getCType(kt.value())].emplace_back(kt.index(), kt.value()); 469 for (const auto &jt : perType) { 470 for (auto kt : jt.second) 471 gen.emitParse(kind, *std::get<1>(kt)); 472 gen.emitPrint(kind, jt.first, jt.second); 473 } 474 gen.emitParseDispatch(kind, *vec); 475 476 SmallVector<std::string> types; 477 for (const auto &it : perType) { 478 types.push_back(it.first); 479 } 480 gen.emitPrintDispatch(kind, types); 481 } 482 483 return false; 484 } 485 486 static mlir::GenRegistration 487 genBCRW("gen-bytecode", "Generate dialect bytecode readers/writers", 488 [](const RecordKeeper &records, raw_ostream &os) { 489 return emitBCRW(records, os); 490 }); 491