1 //===- BytecodeDialectGen.cpp - Dialect bytecode read/writer gen ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Support/IndentedOstream.h" 10 #include "mlir/TableGen/GenInfo.h" 11 #include "llvm/ADT/MapVector.h" 12 #include "llvm/ADT/STLExtras.h" 13 #include "llvm/Support/CommandLine.h" 14 #include "llvm/Support/FormatVariadic.h" 15 #include "llvm/TableGen/Error.h" 16 #include "llvm/TableGen/Record.h" 17 #include <regex> 18 19 using namespace llvm; 20 21 static cl::OptionCategory dialectGenCat("Options for -gen-bytecode"); 22 static cl::opt<std::string> 23 selectedBcDialect("bytecode-dialect", cl::desc("The dialect to gen for"), 24 cl::cat(dialectGenCat), cl::CommaSeparated); 25 26 namespace { 27 28 /// Helper class to generate C++ bytecode parser helpers. 29 class Generator { 30 public: 31 Generator(raw_ostream &output) : output(output) {} 32 33 /// Returns whether successfully emitted attribute/type parsers. 34 void emitParse(StringRef kind, const Record &x); 35 36 /// Returns whether successfully emitted attribute/type printers. 37 void emitPrint(StringRef kind, StringRef type, 38 ArrayRef<std::pair<int64_t, const Record *>> vec); 39 40 /// Emits parse dispatch table. 41 void emitParseDispatch(StringRef kind, ArrayRef<const Record *> vec); 42 43 /// Emits print dispatch table. 44 void emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec); 45 46 private: 47 /// Emits parse calls to construct given kind. 48 void emitParseHelper(StringRef kind, StringRef returnType, StringRef builder, 49 ArrayRef<Init *> args, ArrayRef<std::string> argNames, 50 StringRef failure, mlir::raw_indented_ostream &ios); 51 52 /// Emits print instructions. 53 void emitPrintHelper(const Record *memberRec, StringRef kind, 54 StringRef parent, StringRef name, 55 mlir::raw_indented_ostream &ios); 56 57 raw_ostream &output; 58 }; 59 } // namespace 60 61 /// Helper to replace set of from strings to target in `s`. 62 /// Assumed: non-overlapping replacements. 63 static std::string format(StringRef templ, 64 std::map<std::string, std::string> &&map) { 65 std::string s = templ.str(); 66 for (const auto &[from, to] : map) 67 // All replacements start with $, don't treat as anchor. 68 s = std::regex_replace(s, std::regex("\\" + from), to); 69 return s; 70 } 71 72 /// Return string with first character capitalized. 73 static std::string capitalize(StringRef str) { 74 return ((Twine)toUpper(str[0]) + str.drop_front()).str(); 75 } 76 77 /// Return the C++ type for the given record. 78 static std::string getCType(const Record *def) { 79 std::string format = "{0}"; 80 if (def->isSubClassOf("Array")) { 81 def = def->getValueAsDef("elemT"); 82 format = "SmallVector<{0}>"; 83 } 84 85 StringRef cType = def->getValueAsString("cType"); 86 if (cType.empty()) { 87 if (def->isAnonymous()) 88 PrintFatalError(def->getLoc(), "Unable to determine cType"); 89 90 return formatv(format.c_str(), def->getName().str()); 91 } 92 return formatv(format.c_str(), cType.str()); 93 } 94 95 void Generator::emitParseDispatch(StringRef kind, 96 ArrayRef<const Record *> vec) { 97 mlir::raw_indented_ostream os(output); 98 char const *head = 99 R"(static {0} read{0}(MLIRContext* context, DialectBytecodeReader &reader))"; 100 os << formatv(head, capitalize(kind)); 101 auto funScope = os.scope(" {\n", "}\n\n"); 102 103 if (vec.empty()) { 104 os << "return reader.emitError() << \"unknown attribute\", " 105 << capitalize(kind) << "();\n"; 106 return; 107 } 108 109 os << "uint64_t kind;\n"; 110 os << "if (failed(reader.readVarInt(kind)))\n" 111 << " return " << capitalize(kind) << "();\n"; 112 os << "switch (kind) "; 113 { 114 auto switchScope = os.scope("{\n", "}\n"); 115 for (const auto &it : llvm::enumerate(vec)) { 116 if (it.value()->getName() == "ReservedOrDead") 117 continue; 118 119 os << formatv("case {1}:\n return read{0}(context, reader);\n", 120 it.value()->getName(), it.index()); 121 } 122 os << "default:\n" 123 << " reader.emitError() << \"unknown attribute code: \" " 124 << "<< kind;\n" 125 << " return " << capitalize(kind) << "();\n"; 126 } 127 os << "return " << capitalize(kind) << "();\n"; 128 } 129 130 void Generator::emitParse(StringRef kind, const Record &x) { 131 if (x.getNameInitAsString() == "ReservedOrDead") 132 return; 133 134 char const *head = 135 R"(static {0} read{1}(MLIRContext* context, DialectBytecodeReader &reader) )"; 136 mlir::raw_indented_ostream os(output); 137 std::string returnType = getCType(&x); 138 os << formatv(head, kind == "attribute" ? "::mlir::Attribute" : "::mlir::Type", x.getName()); 139 DagInit *members = x.getValueAsDag("members"); 140 SmallVector<std::string> argNames = 141 llvm::to_vector(map_range(members->getArgNames(), [](StringInit *init) { 142 return init->getAsUnquotedString(); 143 })); 144 StringRef builder = x.getValueAsString("cBuilder").trim(); 145 emitParseHelper(kind, returnType, builder, members->getArgs(), argNames, 146 returnType + "()", os); 147 os << "\n\n"; 148 } 149 150 void printParseConditional(mlir::raw_indented_ostream &ios, 151 ArrayRef<Init *> args, 152 ArrayRef<std::string> argNames) { 153 ios << "if "; 154 auto parenScope = ios.scope("(", ") {"); 155 ios.indent(); 156 157 auto listHelperName = [](StringRef name) { 158 return formatv("read{0}", capitalize(name)); 159 }; 160 161 auto parsedArgs = 162 llvm::to_vector(make_filter_range(args, [](Init *const attr) { 163 const Record *def = cast<DefInit>(attr)->getDef(); 164 if (def->isSubClassOf("Array")) 165 return true; 166 return !def->getValueAsString("cParser").empty(); 167 })); 168 169 interleave( 170 zip(parsedArgs, argNames), 171 [&](std::tuple<llvm::Init *&, const std::string &> it) { 172 const Record *attr = cast<DefInit>(std::get<0>(it))->getDef(); 173 std::string parser; 174 if (auto optParser = attr->getValueAsOptionalString("cParser")) { 175 parser = *optParser; 176 } else if (attr->isSubClassOf("Array")) { 177 const Record *def = attr->getValueAsDef("elemT"); 178 bool composite = def->isSubClassOf("CompositeBytecode"); 179 if (!composite && def->isSubClassOf("AttributeKind")) 180 parser = "succeeded($_reader.readAttributes($_var))"; 181 else if (!composite && def->isSubClassOf("TypeKind")) 182 parser = "succeeded($_reader.readTypes($_var))"; 183 else 184 parser = ("succeeded($_reader.readList($_var, " + 185 listHelperName(std::get<1>(it)) + "))") 186 .str(); 187 } else { 188 PrintFatalError(attr->getLoc(), "No parser specified"); 189 } 190 std::string type = getCType(attr); 191 ios << format(parser, {{"$_reader", "reader"}, 192 {"$_resultType", type}, 193 {"$_var", std::get<1>(it)}}); 194 }, 195 [&]() { ios << " &&\n"; }); 196 } 197 198 void Generator::emitParseHelper(StringRef kind, StringRef returnType, 199 StringRef builder, ArrayRef<Init *> args, 200 ArrayRef<std::string> argNames, 201 StringRef failure, 202 mlir::raw_indented_ostream &ios) { 203 auto funScope = ios.scope("{\n", "}"); 204 205 if (args.empty()) { 206 ios << formatv("return get<{0}>(context);\n", returnType); 207 return; 208 } 209 210 // Print decls. 211 std::string lastCType = ""; 212 for (auto [arg, name] : zip(args, argNames)) { 213 DefInit *first = dyn_cast<DefInit>(arg); 214 if (!first) 215 PrintFatalError("Unexpected type for " + name); 216 const Record *def = first->getDef(); 217 218 // Create variable decls, if there are a block of same type then create 219 // comma separated list of them. 220 std::string cType = getCType(def); 221 if (lastCType == cType) { 222 ios << ", "; 223 } else { 224 if (!lastCType.empty()) 225 ios << ";\n"; 226 ios << cType << " "; 227 } 228 ios << name; 229 lastCType = cType; 230 } 231 ios << ";\n"; 232 233 // Returns the name of the helper used in list parsing. E.g., the name of the 234 // lambda passed to array parsing. 235 auto listHelperName = [](StringRef name) { 236 return formatv("read{0}", capitalize(name)); 237 }; 238 239 // Emit list helper functions. 240 for (auto [arg, name] : zip(args, argNames)) { 241 const Record *attr = cast<DefInit>(arg)->getDef(); 242 if (!attr->isSubClassOf("Array")) 243 continue; 244 245 // TODO: Dedupe readers. 246 const Record *def = attr->getValueAsDef("elemT"); 247 if (!def->isSubClassOf("CompositeBytecode") && 248 (def->isSubClassOf("AttributeKind") || def->isSubClassOf("TypeKind"))) 249 continue; 250 251 std::string returnType = getCType(def); 252 ios << "auto " << listHelperName(name) << " = [&]() -> FailureOr<" 253 << returnType << "> "; 254 SmallVector<Init *> args; 255 SmallVector<std::string> argNames; 256 if (def->isSubClassOf("CompositeBytecode")) { 257 DagInit *members = def->getValueAsDag("members"); 258 args = llvm::to_vector(members->getArgs()); 259 argNames = llvm::to_vector( 260 map_range(members->getArgNames(), [](StringInit *init) { 261 return init->getAsUnquotedString(); 262 })); 263 } else { 264 args = {def->getDefInit()}; 265 argNames = {"temp"}; 266 } 267 StringRef builder = def->getValueAsString("cBuilder"); 268 emitParseHelper(kind, returnType, builder, args, argNames, "failure()", 269 ios); 270 ios << ";\n"; 271 } 272 273 // Print parse conditional. 274 printParseConditional(ios, args, argNames); 275 276 // Compute args to pass to create method. 277 auto passedArgs = llvm::to_vector(make_filter_range( 278 argNames, [](StringRef str) { return !str.starts_with("_"); })); 279 std::string argStr; 280 raw_string_ostream argStream(argStr); 281 interleaveComma(passedArgs, argStream, 282 [&](const std::string &str) { argStream << str; }); 283 // Return the invoked constructor. 284 ios << "\nreturn " 285 << format(builder, {{"$_resultType", returnType.str()}, 286 {"$_args", argStream.str()}}) 287 << ";\n"; 288 ios.unindent(); 289 290 // TODO: Emit error in debug. 291 // This assumes the result types in error case can always be empty 292 // constructed. 293 ios << "}\nreturn " << failure << ";\n"; 294 } 295 296 void Generator::emitPrint(StringRef kind, StringRef type, 297 ArrayRef<std::pair<int64_t, const Record *>> vec) { 298 if (type == "ReservedOrDead") 299 return; 300 301 char const *head = 302 R"(static void write({0} {1}, DialectBytecodeWriter &writer) )"; 303 mlir::raw_indented_ostream os(output); 304 os << formatv(head, type, kind); 305 auto funScope = os.scope("{\n", "}\n\n"); 306 307 // Check that predicates specified if multiple bytecode instances. 308 for (const Record *rec : make_second_range(vec)) { 309 StringRef pred = rec->getValueAsString("printerPredicate"); 310 if (vec.size() > 1 && pred.empty()) { 311 for (auto [index, rec] : vec) { 312 (void)index; 313 StringRef pred = rec->getValueAsString("printerPredicate"); 314 if (vec.size() > 1 && pred.empty()) 315 PrintError(rec->getLoc(), 316 "Requires parsing predicate given common cType"); 317 } 318 PrintFatalError("Unspecified for shared cType " + type); 319 } 320 } 321 322 for (auto [index, rec] : vec) { 323 StringRef pred = rec->getValueAsString("printerPredicate"); 324 if (!pred.empty()) { 325 os << "if (" << format(pred, {{"$_val", kind.str()}}) << ") {\n"; 326 os.indent(); 327 } 328 329 os << "writer.writeVarInt(/* " << rec->getName() << " */ " << index 330 << ");\n"; 331 332 auto *members = rec->getValueAsDag("members"); 333 for (auto [arg, name] : 334 llvm::zip(members->getArgs(), members->getArgNames())) { 335 DefInit *def = dyn_cast<DefInit>(arg); 336 assert(def); 337 const Record *memberRec = def->getDef(); 338 emitPrintHelper(memberRec, kind, kind, name->getAsUnquotedString(), os); 339 } 340 341 if (!pred.empty()) { 342 os.unindent(); 343 os << "}\n"; 344 } 345 } 346 } 347 348 void Generator::emitPrintHelper(const Record *memberRec, StringRef kind, 349 StringRef parent, StringRef name, 350 mlir::raw_indented_ostream &ios) { 351 std::string getter; 352 if (auto cGetter = memberRec->getValueAsOptionalString("cGetter"); 353 cGetter && !cGetter->empty()) { 354 getter = format( 355 *cGetter, 356 {{"$_attrType", parent.str()}, 357 {"$_member", name.str()}, 358 {"$_getMember", "get" + convertToCamelFromSnakeCase(name, true)}}); 359 } else { 360 getter = 361 formatv("{0}.get{1}()", parent, convertToCamelFromSnakeCase(name, true)) 362 .str(); 363 } 364 365 if (memberRec->isSubClassOf("Array")) { 366 const Record *def = memberRec->getValueAsDef("elemT"); 367 if (!def->isSubClassOf("CompositeBytecode")) { 368 if (def->isSubClassOf("AttributeKind")) { 369 ios << "writer.writeAttributes(" << getter << ");\n"; 370 return; 371 } 372 if (def->isSubClassOf("TypeKind")) { 373 ios << "writer.writeTypes(" << getter << ");\n"; 374 return; 375 } 376 } 377 std::string returnType = getCType(def); 378 std::string nestedName = kind.str(); 379 ios << "writer.writeList(" << getter << ", [&](" << returnType << " " 380 << nestedName << ") "; 381 auto lambdaScope = ios.scope("{\n", "});\n"); 382 return emitPrintHelper(def, kind, nestedName, nestedName, ios); 383 } 384 if (memberRec->isSubClassOf("CompositeBytecode")) { 385 auto *members = memberRec->getValueAsDag("members"); 386 for (auto [arg, argName] : 387 zip(members->getArgs(), members->getArgNames())) { 388 DefInit *def = dyn_cast<DefInit>(arg); 389 assert(def); 390 emitPrintHelper(def->getDef(), kind, parent, 391 argName->getAsUnquotedString(), ios); 392 } 393 } 394 395 if (std::string printer = memberRec->getValueAsString("cPrinter").str(); 396 !printer.empty()) 397 ios << format(printer, {{"$_writer", "writer"}, 398 {"$_name", kind.str()}, 399 {"$_getter", getter}}) 400 << ";\n"; 401 } 402 403 void Generator::emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec) { 404 mlir::raw_indented_ostream os(output); 405 char const *head = R"(static LogicalResult write{0}({0} {1}, 406 DialectBytecodeWriter &writer))"; 407 os << formatv(head, capitalize(kind), kind); 408 auto funScope = os.scope(" {\n", "}\n\n"); 409 410 os << "return TypeSwitch<" << capitalize(kind) << ", LogicalResult>(" << kind 411 << ")"; 412 auto switchScope = os.scope("", ""); 413 for (StringRef type : vec) { 414 if (type == "ReservedOrDead") 415 continue; 416 417 os << "\n.Case([&](" << type << " t)"; 418 auto caseScope = os.scope(" {\n", "})"); 419 os << "return write(t, writer), success();\n"; 420 } 421 os << "\n.Default([&](" << capitalize(kind) << ") { return failure(); });\n"; 422 } 423 424 namespace { 425 /// Container of Attribute or Type for Dialect. 426 struct AttrOrType { 427 std::vector<const Record *> attr, type; 428 }; 429 } // namespace 430 431 static bool emitBCRW(const RecordKeeper &records, raw_ostream &os) { 432 MapVector<StringRef, AttrOrType> dialectAttrOrType; 433 for (const Record *it : 434 records.getAllDerivedDefinitions("DialectAttributes")) { 435 if (!selectedBcDialect.empty() && 436 it->getValueAsString("dialect") != selectedBcDialect) 437 continue; 438 dialectAttrOrType[it->getValueAsString("dialect")].attr = 439 it->getValueAsListOfDefs("elems"); 440 } 441 for (const Record *it : records.getAllDerivedDefinitions("DialectTypes")) { 442 if (!selectedBcDialect.empty() && 443 it->getValueAsString("dialect") != selectedBcDialect) 444 continue; 445 dialectAttrOrType[it->getValueAsString("dialect")].type = 446 it->getValueAsListOfDefs("elems"); 447 } 448 449 if (dialectAttrOrType.size() != 1) 450 PrintFatalError("Single dialect per invocation required (either only " 451 "one in input file or specified via dialect option)"); 452 453 auto it = dialectAttrOrType.front(); 454 Generator gen(os); 455 456 SmallVector<std::vector<const Record *> *, 2> vecs; 457 SmallVector<std::string, 2> kinds; 458 vecs.push_back(&it.second.attr); 459 kinds.push_back("attribute"); 460 vecs.push_back(&it.second.type); 461 kinds.push_back("type"); 462 for (auto [vec, kind] : zip(vecs, kinds)) { 463 // Handle Attribute/Type emission. 464 std::map<std::string, std::vector<std::pair<int64_t, const Record *>>> 465 perType; 466 for (auto kt : llvm::enumerate(*vec)) 467 perType[getCType(kt.value())].emplace_back(kt.index(), kt.value()); 468 for (const auto &jt : perType) { 469 for (auto kt : jt.second) 470 gen.emitParse(kind, *std::get<1>(kt)); 471 gen.emitPrint(kind, jt.first, jt.second); 472 } 473 gen.emitParseDispatch(kind, *vec); 474 475 SmallVector<std::string> types; 476 for (const auto &it : perType) { 477 types.push_back(it.first); 478 } 479 gen.emitPrintDispatch(kind, types); 480 } 481 482 return false; 483 } 484 485 static mlir::GenRegistration 486 genBCRW("gen-bytecode", "Generate dialect bytecode readers/writers", 487 [](const RecordKeeper &records, raw_ostream &os) { 488 return emitBCRW(records, os); 489 }); 490