1 //===- BytecodeDialectGen.cpp - Dialect bytecode read/writer gen ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Support/IndentedOstream.h" 10 #include "mlir/TableGen/GenInfo.h" 11 #include "llvm/ADT/MapVector.h" 12 #include "llvm/ADT/STLExtras.h" 13 #include "llvm/Support/CommandLine.h" 14 #include "llvm/Support/FormatVariadic.h" 15 #include "llvm/TableGen/Error.h" 16 #include "llvm/TableGen/Record.h" 17 #include <regex> 18 19 using namespace llvm; 20 21 static llvm::cl::OptionCategory dialectGenCat("Options for -gen-bytecode"); 22 static llvm::cl::opt<std::string> 23 selectedBcDialect("bytecode-dialect", 24 llvm::cl::desc("The dialect to gen for"), 25 llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated); 26 27 namespace { 28 29 /// Helper class to generate C++ bytecode parser helpers. 30 class Generator { 31 public: 32 Generator(raw_ostream &output) : output(output) {} 33 34 /// Returns whether successfully emitted attribute/type parsers. 35 void emitParse(StringRef kind, Record &x); 36 37 /// Returns whether successfully emitted attribute/type printers. 38 void emitPrint(StringRef kind, StringRef type, 39 ArrayRef<std::pair<int64_t, Record *>> vec); 40 41 /// Emits parse dispatch table. 42 void emitParseDispatch(StringRef kind, ArrayRef<Record *> vec); 43 44 /// Emits print dispatch table. 45 void emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec); 46 47 private: 48 /// Emits parse calls to construct given kind. 49 void emitParseHelper(StringRef kind, StringRef returnType, StringRef builder, 50 ArrayRef<Init *> args, ArrayRef<std::string> argNames, 51 StringRef failure, mlir::raw_indented_ostream &ios); 52 53 /// Emits print instructions. 54 void emitPrintHelper(Record *memberRec, StringRef kind, StringRef parent, 55 StringRef name, mlir::raw_indented_ostream &ios); 56 57 raw_ostream &output; 58 }; 59 } // namespace 60 61 /// Helper to replace set of from strings to target in `s`. 62 /// Assumed: non-overlapping replacements. 63 static std::string format(StringRef templ, 64 std::map<std::string, std::string> &&map) { 65 std::string s = templ.str(); 66 for (const auto &[from, to] : map) 67 // All replacements start with $, don't treat as anchor. 68 s = std::regex_replace(s, std::regex("\\" + from), to); 69 return s; 70 } 71 72 /// Return string with first character capitalized. 73 static std::string capitalize(StringRef str) { 74 return ((Twine)toUpper(str[0]) + str.drop_front()).str(); 75 } 76 77 /// Return the C++ type for the given record. 78 static std::string getCType(Record *def) { 79 std::string format = "{0}"; 80 if (def->isSubClassOf("Array")) { 81 def = def->getValueAsDef("elemT"); 82 format = "SmallVector<{0}>"; 83 } 84 85 StringRef cType = def->getValueAsString("cType"); 86 if (cType.empty()) { 87 if (def->isAnonymous()) 88 PrintFatalError(def->getLoc(), "Unable to determine cType"); 89 90 return formatv(format.c_str(), def->getName().str()); 91 } 92 return formatv(format.c_str(), cType.str()); 93 } 94 95 void Generator::emitParseDispatch(StringRef kind, ArrayRef<Record *> vec) { 96 mlir::raw_indented_ostream os(output); 97 char const *head = 98 R"(static {0} read{0}(MLIRContext* context, DialectBytecodeReader &reader))"; 99 os << formatv(head, capitalize(kind)); 100 auto funScope = os.scope(" {\n", "}\n\n"); 101 102 os << "uint64_t kind;\n"; 103 os << "if (failed(reader.readVarInt(kind)))\n" 104 << " return " << capitalize(kind) << "();\n"; 105 os << "switch (kind) "; 106 { 107 auto switchScope = os.scope("{\n", "}\n"); 108 for (const auto &it : llvm::enumerate(vec)) { 109 if (it.value()->getName() == "ReservedOrDead") 110 continue; 111 112 os << formatv("case {1}:\n return read{0}(context, reader);\n", 113 it.value()->getName(), it.index()); 114 } 115 os << "default:\n" 116 << " reader.emitError() << \"unknown attribute code: \" " 117 << "<< kind;\n" 118 << " return " << capitalize(kind) << "();\n"; 119 } 120 os << "return " << capitalize(kind) << "();\n"; 121 } 122 123 void Generator::emitParse(StringRef kind, Record &x) { 124 if (x.getNameInitAsString() == "ReservedOrDead") 125 return; 126 127 char const *head = 128 R"(static {0} read{1}(MLIRContext* context, DialectBytecodeReader &reader) )"; 129 mlir::raw_indented_ostream os(output); 130 std::string returnType = getCType(&x); 131 os << formatv(head, returnType, x.getName()); 132 DagInit *members = x.getValueAsDag("members"); 133 SmallVector<std::string> argNames = 134 llvm::to_vector(map_range(members->getArgNames(), [](StringInit *init) { 135 return init->getAsUnquotedString(); 136 })); 137 StringRef builder = x.getValueAsString("cBuilder"); 138 emitParseHelper(kind, returnType, builder, members->getArgs(), argNames, 139 returnType + "()", os); 140 os << "\n\n"; 141 } 142 143 void printParseConditional(mlir::raw_indented_ostream &ios, 144 ArrayRef<Init *> args, 145 ArrayRef<std::string> argNames) { 146 ios << "if "; 147 auto parenScope = ios.scope("(", ") {"); 148 ios.indent(); 149 150 auto listHelperName = [](StringRef name) { 151 return formatv("read{0}", capitalize(name)); 152 }; 153 154 auto parsedArgs = 155 llvm::to_vector(make_filter_range(args, [](Init *const attr) { 156 Record *def = cast<DefInit>(attr)->getDef(); 157 if (def->isSubClassOf("Array")) 158 return true; 159 return !def->getValueAsString("cParser").empty(); 160 })); 161 162 interleave( 163 zip(parsedArgs, argNames), 164 [&](std::tuple<llvm::Init *&, const std::string &> it) { 165 Record *attr = cast<DefInit>(std::get<0>(it))->getDef(); 166 std::string parser; 167 if (auto optParser = attr->getValueAsOptionalString("cParser")) { 168 parser = *optParser; 169 } else if (attr->isSubClassOf("Array")) { 170 Record *def = attr->getValueAsDef("elemT"); 171 bool composite = def->isSubClassOf("CompositeBytecode"); 172 if (!composite && def->isSubClassOf("AttributeKind")) 173 parser = "succeeded($_reader.readAttributes($_var))"; 174 else if (!composite && def->isSubClassOf("TypeKind")) 175 parser = "succeeded($_reader.readTypes($_var))"; 176 else 177 parser = ("succeeded($_reader.readList($_var, " + 178 listHelperName(std::get<1>(it)) + "))") 179 .str(); 180 } else { 181 PrintFatalError(attr->getLoc(), "No parser specified"); 182 } 183 std::string type = getCType(attr); 184 ios << format(parser, {{"$_reader", "reader"}, 185 {"$_resultType", type}, 186 {"$_var", std::get<1>(it)}}); 187 }, 188 [&]() { ios << " &&\n"; }); 189 } 190 191 void Generator::emitParseHelper(StringRef kind, StringRef returnType, 192 StringRef builder, ArrayRef<Init *> args, 193 ArrayRef<std::string> argNames, 194 StringRef failure, 195 mlir::raw_indented_ostream &ios) { 196 auto funScope = ios.scope("{\n", "}"); 197 198 if (args.empty()) { 199 ios << formatv("return get<{0}>(context);\n", returnType); 200 return; 201 } 202 203 // Print decls. 204 std::string lastCType = ""; 205 for (auto [arg, name] : zip(args, argNames)) { 206 DefInit *first = dyn_cast<DefInit>(arg); 207 if (!first) 208 PrintFatalError("Unexpected type for " + name); 209 Record *def = first->getDef(); 210 211 // Create variable decls, if there are a block of same type then create 212 // comma separated list of them. 213 std::string cType = getCType(def); 214 if (lastCType == cType) { 215 ios << ", "; 216 } else { 217 if (!lastCType.empty()) 218 ios << ";\n"; 219 ios << cType << " "; 220 } 221 ios << name; 222 lastCType = cType; 223 } 224 ios << ";\n"; 225 226 // Returns the name of the helper used in list parsing. E.g., the name of the 227 // lambda passed to array parsing. 228 auto listHelperName = [](StringRef name) { 229 return formatv("read{0}", capitalize(name)); 230 }; 231 232 // Emit list helper functions. 233 for (auto [arg, name] : zip(args, argNames)) { 234 Record *attr = cast<DefInit>(arg)->getDef(); 235 if (!attr->isSubClassOf("Array")) 236 continue; 237 238 // TODO: Dedupe readers. 239 Record *def = attr->getValueAsDef("elemT"); 240 if (!def->isSubClassOf("CompositeBytecode") && 241 (def->isSubClassOf("AttributeKind") || def->isSubClassOf("TypeKind"))) 242 continue; 243 244 std::string returnType = getCType(def); 245 ios << "auto " << listHelperName(name) << " = [&]() -> FailureOr<" 246 << returnType << "> "; 247 SmallVector<Init *> args; 248 SmallVector<std::string> argNames; 249 if (def->isSubClassOf("CompositeBytecode")) { 250 DagInit *members = def->getValueAsDag("members"); 251 args = llvm::to_vector(members->getArgs()); 252 argNames = llvm::to_vector( 253 map_range(members->getArgNames(), [](StringInit *init) { 254 return init->getAsUnquotedString(); 255 })); 256 } else { 257 args = {def->getDefInit()}; 258 argNames = {"temp"}; 259 } 260 StringRef builder = def->getValueAsString("cBuilder"); 261 emitParseHelper(kind, returnType, builder, args, argNames, "failure()", 262 ios); 263 ios << ";\n"; 264 } 265 266 // Print parse conditional. 267 printParseConditional(ios, args, argNames); 268 269 // Compute args to pass to create method. 270 auto passedArgs = llvm::to_vector(make_filter_range( 271 argNames, [](StringRef str) { return !str.starts_with("_"); })); 272 std::string argStr; 273 raw_string_ostream argStream(argStr); 274 interleaveComma(passedArgs, argStream, 275 [&](const std::string &str) { argStream << str; }); 276 // Return the invoked constructor. 277 ios << "\nreturn " 278 << format(builder, {{"$_resultType", returnType.str()}, 279 {"$_args", argStream.str()}}) 280 << ";\n"; 281 ios.unindent(); 282 283 // TODO: Emit error in debug. 284 // This assumes the result types in error case can always be empty 285 // constructed. 286 ios << "}\nreturn " << failure << ";\n"; 287 } 288 289 void Generator::emitPrint(StringRef kind, StringRef type, 290 ArrayRef<std::pair<int64_t, Record *>> vec) { 291 if (type == "ReservedOrDead") 292 return; 293 294 char const *head = 295 R"(static void write({0} {1}, DialectBytecodeWriter &writer) )"; 296 mlir::raw_indented_ostream os(output); 297 os << formatv(head, type, kind); 298 auto funScope = os.scope("{\n", "}\n\n"); 299 300 // Check that predicates specified if multiple bytecode instances. 301 for (llvm::Record *rec : make_second_range(vec)) { 302 StringRef pred = rec->getValueAsString("printerPredicate"); 303 if (vec.size() > 1 && pred.empty()) { 304 for (auto [index, rec] : vec) { 305 (void)index; 306 StringRef pred = rec->getValueAsString("printerPredicate"); 307 if (vec.size() > 1 && pred.empty()) 308 PrintError(rec->getLoc(), 309 "Requires parsing predicate given common cType"); 310 } 311 PrintFatalError("Unspecified for shared cType " + type); 312 } 313 } 314 315 for (auto [index, rec] : vec) { 316 StringRef pred = rec->getValueAsString("printerPredicate"); 317 if (!pred.empty()) { 318 os << "if (" << format(pred, {{"$_val", kind.str()}}) << ") {\n"; 319 os.indent(); 320 } 321 322 os << "writer.writeVarInt(/* " << rec->getName() << " */ " << index 323 << ");\n"; 324 325 auto *members = rec->getValueAsDag("members"); 326 for (auto [arg, name] : 327 llvm::zip(members->getArgs(), members->getArgNames())) { 328 DefInit *def = dyn_cast<DefInit>(arg); 329 assert(def); 330 Record *memberRec = def->getDef(); 331 emitPrintHelper(memberRec, kind, kind, name->getAsUnquotedString(), os); 332 } 333 334 if (!pred.empty()) { 335 os.unindent(); 336 os << "}\n"; 337 } 338 } 339 } 340 341 void Generator::emitPrintHelper(Record *memberRec, StringRef kind, 342 StringRef parent, StringRef name, 343 mlir::raw_indented_ostream &ios) { 344 std::string getter; 345 if (auto cGetter = memberRec->getValueAsOptionalString("cGetter"); 346 cGetter && !cGetter->empty()) { 347 getter = format( 348 *cGetter, 349 {{"$_attrType", parent.str()}, 350 {"$_member", name.str()}, 351 {"$_getMember", "get" + convertToCamelFromSnakeCase(name, true)}}); 352 } else { 353 getter = 354 formatv("{0}.get{1}()", parent, convertToCamelFromSnakeCase(name, true)) 355 .str(); 356 } 357 358 if (memberRec->isSubClassOf("Array")) { 359 Record *def = memberRec->getValueAsDef("elemT"); 360 if (!def->isSubClassOf("CompositeBytecode")) { 361 if (def->isSubClassOf("AttributeKind")) { 362 ios << "writer.writeAttributes(" << getter << ");\n"; 363 return; 364 } 365 if (def->isSubClassOf("TypeKind")) { 366 ios << "writer.writeTypes(" << getter << ");\n"; 367 return; 368 } 369 } 370 std::string returnType = getCType(def); 371 ios << "writer.writeList(" << getter << ", [&](" << returnType << " " 372 << kind << ") "; 373 auto lambdaScope = ios.scope("{\n", "});\n"); 374 return emitPrintHelper(def, kind, kind, kind, ios); 375 } 376 if (memberRec->isSubClassOf("CompositeBytecode")) { 377 auto *members = memberRec->getValueAsDag("members"); 378 for (auto [arg, argName] : 379 zip(members->getArgs(), members->getArgNames())) { 380 DefInit *def = dyn_cast<DefInit>(arg); 381 assert(def); 382 emitPrintHelper(def->getDef(), kind, parent, 383 argName->getAsUnquotedString(), ios); 384 } 385 } 386 387 if (std::string printer = memberRec->getValueAsString("cPrinter").str(); 388 !printer.empty()) 389 ios << format(printer, {{"$_writer", "writer"}, 390 {"$_name", kind.str()}, 391 {"$_getter", getter}}) 392 << ";\n"; 393 } 394 395 void Generator::emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec) { 396 mlir::raw_indented_ostream os(output); 397 char const *head = R"(static LogicalResult write{0}({0} {1}, 398 DialectBytecodeWriter &writer))"; 399 os << formatv(head, capitalize(kind), kind); 400 auto funScope = os.scope(" {\n", "}\n\n"); 401 402 os << "return TypeSwitch<" << capitalize(kind) << ", LogicalResult>(" << kind 403 << ")"; 404 auto switchScope = os.scope("", ""); 405 for (StringRef type : vec) { 406 if (type == "ReservedOrDead") 407 continue; 408 409 os << "\n.Case([&](" << type << " t)"; 410 auto caseScope = os.scope(" {\n", "})"); 411 os << "return write(t, writer), success();\n"; 412 } 413 os << "\n.Default([&](" << capitalize(kind) << ") { return failure(); });\n"; 414 } 415 416 namespace { 417 /// Container of Attribute or Type for Dialect. 418 struct AttrOrType { 419 std::vector<Record *> attr, type; 420 }; 421 } // namespace 422 423 static bool emitBCRW(const RecordKeeper &records, raw_ostream &os) { 424 MapVector<StringRef, AttrOrType> dialectAttrOrType; 425 for (auto &it : records.getAllDerivedDefinitions("DialectAttributes")) { 426 if (!selectedBcDialect.empty() && 427 it->getValueAsString("dialect") != selectedBcDialect) 428 continue; 429 dialectAttrOrType[it->getValueAsString("dialect")].attr = 430 it->getValueAsListOfDefs("elems"); 431 } 432 for (auto &it : records.getAllDerivedDefinitions("DialectTypes")) { 433 if (!selectedBcDialect.empty() && 434 it->getValueAsString("dialect") != selectedBcDialect) 435 continue; 436 dialectAttrOrType[it->getValueAsString("dialect")].type = 437 it->getValueAsListOfDefs("elems"); 438 } 439 440 if (dialectAttrOrType.size() != 1) 441 PrintFatalError("Single dialect per invocation required (either only " 442 "one in input file or specified via dialect option)"); 443 444 auto it = dialectAttrOrType.front(); 445 Generator gen(os); 446 447 SmallVector<std::vector<Record *> *, 2> vecs; 448 SmallVector<std::string, 2> kinds; 449 vecs.push_back(&it.second.attr); 450 kinds.push_back("attribute"); 451 vecs.push_back(&it.second.type); 452 kinds.push_back("type"); 453 for (auto [vec, kind] : zip(vecs, kinds)) { 454 // Handle Attribute/Type emission. 455 std::map<std::string, std::vector<std::pair<int64_t, Record *>>> perType; 456 for (auto kt : llvm::enumerate(*vec)) 457 perType[getCType(kt.value())].emplace_back(kt.index(), kt.value()); 458 for (const auto &jt : perType) { 459 for (auto kt : jt.second) 460 gen.emitParse(kind, *std::get<1>(kt)); 461 gen.emitPrint(kind, jt.first, jt.second); 462 } 463 gen.emitParseDispatch(kind, *vec); 464 465 SmallVector<std::string> types; 466 for (const auto &it : perType) { 467 types.push_back(it.first); 468 } 469 gen.emitPrintDispatch(kind, types); 470 } 471 472 return false; 473 } 474 475 static mlir::GenRegistration 476 genBCRW("gen-bytecode", "Generate dialect bytecode readers/writers", 477 [](const RecordKeeper &records, raw_ostream &os) { 478 return emitBCRW(records, os); 479 }); 480