1 //===- BytecodeWriter.cpp - MLIR Bytecode Writer --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Bytecode/BytecodeWriter.h" 10 #include "IRNumbering.h" 11 #include "mlir/Bytecode/BytecodeImplementation.h" 12 #include "mlir/Bytecode/BytecodeOpInterface.h" 13 #include "mlir/Bytecode/Encoding.h" 14 #include "mlir/IR/Attributes.h" 15 #include "mlir/IR/Diagnostics.h" 16 #include "mlir/IR/OpImplementation.h" 17 #include "llvm/ADT/ArrayRef.h" 18 #include "llvm/ADT/CachedHashString.h" 19 #include "llvm/ADT/MapVector.h" 20 #include "llvm/ADT/SmallVector.h" 21 #include "llvm/Support/Debug.h" 22 #include "llvm/Support/Endian.h" 23 #include "llvm/Support/raw_ostream.h" 24 #include <optional> 25 26 #define DEBUG_TYPE "mlir-bytecode-writer" 27 28 using namespace mlir; 29 using namespace mlir::bytecode::detail; 30 31 //===----------------------------------------------------------------------===// 32 // BytecodeWriterConfig 33 //===----------------------------------------------------------------------===// 34 35 struct BytecodeWriterConfig::Impl { 36 Impl(StringRef producer) : producer(producer) {} 37 38 /// Version to use when writing. 39 /// Note: This only differs from kVersion if a specific version is set. 40 int64_t bytecodeVersion = bytecode::kVersion; 41 42 /// A flag specifying whether to elide emission of resources into the bytecode 43 /// file. 44 bool shouldElideResourceData = false; 45 46 /// A map containing dialect version information for each dialect to emit. 47 llvm::StringMap<std::unique_ptr<DialectVersion>> dialectVersionMap; 48 49 /// The producer of the bytecode. 50 StringRef producer; 51 52 /// Printer callbacks used to emit custom type and attribute encodings. 53 llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeWriter<Attribute>>> 54 attributeWriterCallbacks; 55 llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeWriter<Type>>> 56 typeWriterCallbacks; 57 58 /// A collection of non-dialect resource printers. 59 SmallVector<std::unique_ptr<AsmResourcePrinter>> externalResourcePrinters; 60 }; 61 62 BytecodeWriterConfig::BytecodeWriterConfig(StringRef producer) 63 : impl(std::make_unique<Impl>(producer)) {} 64 BytecodeWriterConfig::BytecodeWriterConfig(FallbackAsmResourceMap &map, 65 StringRef producer) 66 : BytecodeWriterConfig(producer) { 67 attachFallbackResourcePrinter(map); 68 } 69 BytecodeWriterConfig::~BytecodeWriterConfig() = default; 70 71 ArrayRef<std::unique_ptr<AttrTypeBytecodeWriter<Attribute>>> 72 BytecodeWriterConfig::getAttributeWriterCallbacks() const { 73 return impl->attributeWriterCallbacks; 74 } 75 76 ArrayRef<std::unique_ptr<AttrTypeBytecodeWriter<Type>>> 77 BytecodeWriterConfig::getTypeWriterCallbacks() const { 78 return impl->typeWriterCallbacks; 79 } 80 81 void BytecodeWriterConfig::attachAttributeCallback( 82 std::unique_ptr<AttrTypeBytecodeWriter<Attribute>> callback) { 83 impl->attributeWriterCallbacks.emplace_back(std::move(callback)); 84 } 85 86 void BytecodeWriterConfig::attachTypeCallback( 87 std::unique_ptr<AttrTypeBytecodeWriter<Type>> callback) { 88 impl->typeWriterCallbacks.emplace_back(std::move(callback)); 89 } 90 91 void BytecodeWriterConfig::attachResourcePrinter( 92 std::unique_ptr<AsmResourcePrinter> printer) { 93 impl->externalResourcePrinters.emplace_back(std::move(printer)); 94 } 95 96 void BytecodeWriterConfig::setElideResourceDataFlag( 97 bool shouldElideResourceData) { 98 impl->shouldElideResourceData = shouldElideResourceData; 99 } 100 101 void BytecodeWriterConfig::setDesiredBytecodeVersion(int64_t bytecodeVersion) { 102 impl->bytecodeVersion = bytecodeVersion; 103 } 104 105 int64_t BytecodeWriterConfig::getDesiredBytecodeVersion() const { 106 return impl->bytecodeVersion; 107 } 108 109 llvm::StringMap<std::unique_ptr<DialectVersion>> & 110 BytecodeWriterConfig::getDialectVersionMap() const { 111 return impl->dialectVersionMap; 112 } 113 114 void BytecodeWriterConfig::setDialectVersion( 115 llvm::StringRef dialectName, 116 std::unique_ptr<DialectVersion> dialectVersion) const { 117 assert(!impl->dialectVersionMap.contains(dialectName) && 118 "cannot override a previously set dialect version"); 119 impl->dialectVersionMap.insert({dialectName, std::move(dialectVersion)}); 120 } 121 122 //===----------------------------------------------------------------------===// 123 // EncodingEmitter 124 //===----------------------------------------------------------------------===// 125 126 namespace { 127 /// This class functions as the underlying encoding emitter for the bytecode 128 /// writer. This class is a bit different compared to other types of encoders; 129 /// it does not use a single buffer, but instead may contain several buffers 130 /// (some owned by the writer, and some not) that get concatted during the final 131 /// emission. 132 class EncodingEmitter { 133 public: 134 EncodingEmitter() = default; 135 EncodingEmitter(const EncodingEmitter &) = delete; 136 EncodingEmitter &operator=(const EncodingEmitter &) = delete; 137 138 /// Write the current contents to the provided stream. 139 void writeTo(raw_ostream &os) const; 140 141 /// Return the current size of the encoded buffer. 142 size_t size() const { return prevResultSize + currentResult.size(); } 143 144 //===--------------------------------------------------------------------===// 145 // Emission 146 //===--------------------------------------------------------------------===// 147 148 /// Backpatch a byte in the result buffer at the given offset. 149 void patchByte(uint64_t offset, uint8_t value, StringLiteral desc) { 150 LLVM_DEBUG(llvm::dbgs() << "patchByte(" << offset << ',' << uint64_t(value) 151 << ")\t" << desc << '\n'); 152 assert(offset < size() && offset >= prevResultSize && 153 "cannot patch previously emitted data"); 154 currentResult[offset - prevResultSize] = value; 155 } 156 157 /// Emit the provided blob of data, which is owned by the caller and is 158 /// guaranteed to not die before the end of the bytecode process. 159 void emitOwnedBlob(ArrayRef<uint8_t> data, StringLiteral desc) { 160 LLVM_DEBUG(llvm::dbgs() 161 << "emitOwnedBlob(" << data.size() << "b)\t" << desc << '\n'); 162 // Push the current buffer before adding the provided data. 163 appendResult(std::move(currentResult)); 164 appendOwnedResult(data); 165 } 166 167 /// Emit the provided blob of data that has the given alignment, which is 168 /// owned by the caller and is guaranteed to not die before the end of the 169 /// bytecode process. The alignment value is also encoded, making it available 170 /// on load. 171 void emitOwnedBlobAndAlignment(ArrayRef<uint8_t> data, uint32_t alignment, 172 StringLiteral desc) { 173 emitVarInt(alignment, desc); 174 emitVarInt(data.size(), desc); 175 176 alignTo(alignment); 177 emitOwnedBlob(data, desc); 178 } 179 void emitOwnedBlobAndAlignment(ArrayRef<char> data, uint32_t alignment, 180 StringLiteral desc) { 181 ArrayRef<uint8_t> castedData(reinterpret_cast<const uint8_t *>(data.data()), 182 data.size()); 183 emitOwnedBlobAndAlignment(castedData, alignment, desc); 184 } 185 186 /// Align the emitter to the given alignment. 187 void alignTo(unsigned alignment) { 188 if (alignment < 2) 189 return; 190 assert(llvm::isPowerOf2_32(alignment) && "expected valid alignment"); 191 192 // Check to see if we need to emit any padding bytes to meet the desired 193 // alignment. 194 size_t curOffset = size(); 195 size_t paddingSize = llvm::alignTo(curOffset, alignment) - curOffset; 196 while (paddingSize--) 197 emitByte(bytecode::kAlignmentByte, "alignment byte"); 198 199 // Keep track of the maximum required alignment. 200 requiredAlignment = std::max(requiredAlignment, alignment); 201 } 202 203 //===--------------------------------------------------------------------===// 204 // Integer Emission 205 206 /// Emit a single byte. 207 template <typename T> 208 void emitByte(T byte, StringLiteral desc) { 209 LLVM_DEBUG(llvm::dbgs() 210 << "emitByte(" << uint64_t(byte) << ")\t" << desc << '\n'); 211 currentResult.push_back(static_cast<uint8_t>(byte)); 212 } 213 214 /// Emit a range of bytes. 215 void emitBytes(ArrayRef<uint8_t> bytes, StringLiteral desc) { 216 LLVM_DEBUG(llvm::dbgs() 217 << "emitBytes(" << bytes.size() << "b)\t" << desc << '\n'); 218 llvm::append_range(currentResult, bytes); 219 } 220 221 /// Emit a variable length integer. The first encoded byte contains a prefix 222 /// in the low bits indicating the encoded length of the value. This length 223 /// prefix is a bit sequence of '0's followed by a '1'. The number of '0' bits 224 /// indicate the number of _additional_ bytes (not including the prefix byte). 225 /// All remaining bits in the first byte, along with all of the bits in 226 /// additional bytes, provide the value of the integer encoded in 227 /// little-endian order. 228 void emitVarInt(uint64_t value, StringLiteral desc) { 229 LLVM_DEBUG(llvm::dbgs() << "emitVarInt(" << value << ")\t" << desc << '\n'); 230 231 // In the most common case, the value can be represented in a single byte. 232 // Given how hot this case is, explicitly handle that here. 233 if ((value >> 7) == 0) 234 return emitByte((value << 1) | 0x1, desc); 235 emitMultiByteVarInt(value, desc); 236 } 237 238 /// Emit a signed variable length integer. Signed varints are encoded using 239 /// a varint with zigzag encoding, meaning that we use the low bit of the 240 /// value to indicate the sign of the value. This allows for more efficient 241 /// encoding of negative values by limiting the number of active bits 242 void emitSignedVarInt(uint64_t value, StringLiteral desc) { 243 emitVarInt((value << 1) ^ (uint64_t)((int64_t)value >> 63), desc); 244 } 245 246 /// Emit a variable length integer whose low bit is used to encode the 247 /// provided flag, i.e. encoded as: (value << 1) | (flag ? 1 : 0). 248 void emitVarIntWithFlag(uint64_t value, bool flag, StringLiteral desc) { 249 emitVarInt((value << 1) | (flag ? 1 : 0), desc); 250 } 251 252 //===--------------------------------------------------------------------===// 253 // String Emission 254 255 /// Emit the given string as a nul terminated string. 256 void emitNulTerminatedString(StringRef str, StringLiteral desc) { 257 emitString(str, desc); 258 emitByte(0, "null terminator"); 259 } 260 261 /// Emit the given string without a nul terminator. 262 void emitString(StringRef str, StringLiteral desc) { 263 emitBytes({reinterpret_cast<const uint8_t *>(str.data()), str.size()}, 264 desc); 265 } 266 267 //===--------------------------------------------------------------------===// 268 // Section Emission 269 270 /// Emit a nested section of the given code, whose contents are encoded in the 271 /// provided emitter. 272 void emitSection(bytecode::Section::ID code, EncodingEmitter &&emitter) { 273 // Emit the section code and length. The high bit of the code is used to 274 // indicate whether the section alignment is present, so save an offset to 275 // it. 276 uint64_t codeOffset = currentResult.size(); 277 emitByte(code, "section code"); 278 emitVarInt(emitter.size(), "section size"); 279 280 // Integrate the alignment of the section into this emitter if necessary. 281 unsigned emitterAlign = emitter.requiredAlignment; 282 if (emitterAlign > 1) { 283 if (size() & (emitterAlign - 1)) { 284 emitVarInt(emitterAlign, "section alignment"); 285 alignTo(emitterAlign); 286 287 // Indicate that we needed to align the section, the high bit of the 288 // code field is used for this. 289 currentResult[codeOffset] |= 0b10000000; 290 } else { 291 // Otherwise, if we happen to be at a compatible offset, we just 292 // remember that we need this alignment. 293 requiredAlignment = std::max(requiredAlignment, emitterAlign); 294 } 295 } 296 297 // Push our current buffer and then merge the provided section body into 298 // ours. 299 appendResult(std::move(currentResult)); 300 for (std::vector<uint8_t> &result : emitter.prevResultStorage) 301 prevResultStorage.push_back(std::move(result)); 302 llvm::append_range(prevResultList, emitter.prevResultList); 303 prevResultSize += emitter.prevResultSize; 304 appendResult(std::move(emitter.currentResult)); 305 } 306 307 private: 308 /// Emit the given value using a variable width encoding. This method is a 309 /// fallback when the number of bytes needed to encode the value is greater 310 /// than 1. We mark it noinline here so that the single byte hot path isn't 311 /// pessimized. 312 LLVM_ATTRIBUTE_NOINLINE void emitMultiByteVarInt(uint64_t value, 313 StringLiteral desc); 314 315 /// Append a new result buffer to the current contents. 316 void appendResult(std::vector<uint8_t> &&result) { 317 if (result.empty()) 318 return; 319 prevResultStorage.emplace_back(std::move(result)); 320 appendOwnedResult(prevResultStorage.back()); 321 } 322 void appendOwnedResult(ArrayRef<uint8_t> result) { 323 if (result.empty()) 324 return; 325 prevResultSize += result.size(); 326 prevResultList.emplace_back(result); 327 } 328 329 /// The result of the emitter currently being built. We refrain from building 330 /// a single buffer to simplify emitting sections, large data, and more. The 331 /// result is thus represented using multiple distinct buffers, some of which 332 /// we own (via prevResultStorage), and some of which are just pointers into 333 /// externally owned buffers. 334 std::vector<uint8_t> currentResult; 335 std::vector<ArrayRef<uint8_t>> prevResultList; 336 std::vector<std::vector<uint8_t>> prevResultStorage; 337 338 /// An up-to-date total size of all of the buffers within `prevResultList`. 339 /// This enables O(1) size checks of the current encoding. 340 size_t prevResultSize = 0; 341 342 /// The highest required alignment for the start of this section. 343 unsigned requiredAlignment = 1; 344 }; 345 346 //===----------------------------------------------------------------------===// 347 // StringSectionBuilder 348 //===----------------------------------------------------------------------===// 349 350 namespace { 351 /// This class is used to simplify the process of emitting the string section. 352 class StringSectionBuilder { 353 public: 354 /// Add the given string to the string section, and return the index of the 355 /// string within the section. 356 size_t insert(StringRef str) { 357 auto it = strings.insert({llvm::CachedHashStringRef(str), strings.size()}); 358 return it.first->second; 359 } 360 361 /// Write the current set of strings to the given emitter. 362 void write(EncodingEmitter &emitter) { 363 emitter.emitVarInt(strings.size(), "string section size"); 364 365 // Emit the sizes in reverse order, so that we don't need to backpatch an 366 // offset to the string data or have a separate section. 367 for (const auto &it : llvm::reverse(strings)) 368 emitter.emitVarInt(it.first.size() + 1, "string size"); 369 // Emit the string data itself. 370 for (const auto &it : strings) 371 emitter.emitNulTerminatedString(it.first.val(), "string"); 372 } 373 374 private: 375 /// A set of strings referenced within the bytecode. The value of the map is 376 /// unused. 377 llvm::MapVector<llvm::CachedHashStringRef, size_t> strings; 378 }; 379 } // namespace 380 381 class DialectWriter : public DialectBytecodeWriter { 382 using DialectVersionMapT = llvm::StringMap<std::unique_ptr<DialectVersion>>; 383 384 public: 385 DialectWriter(int64_t bytecodeVersion, EncodingEmitter &emitter, 386 IRNumberingState &numberingState, 387 StringSectionBuilder &stringSection, 388 const DialectVersionMapT &dialectVersionMap) 389 : bytecodeVersion(bytecodeVersion), emitter(emitter), 390 numberingState(numberingState), stringSection(stringSection), 391 dialectVersionMap(dialectVersionMap) {} 392 393 //===--------------------------------------------------------------------===// 394 // IR 395 //===--------------------------------------------------------------------===// 396 397 void writeAttribute(Attribute attr) override { 398 emitter.emitVarInt(numberingState.getNumber(attr), "dialect attr"); 399 } 400 void writeOptionalAttribute(Attribute attr) override { 401 if (!attr) { 402 emitter.emitVarInt(0, "dialect optional attr none"); 403 return; 404 } 405 emitter.emitVarIntWithFlag(numberingState.getNumber(attr), true, 406 "dialect optional attr"); 407 } 408 409 void writeType(Type type) override { 410 emitter.emitVarInt(numberingState.getNumber(type), "dialect type"); 411 } 412 413 void writeResourceHandle(const AsmDialectResourceHandle &resource) override { 414 emitter.emitVarInt(numberingState.getNumber(resource), "dialect resource"); 415 } 416 417 //===--------------------------------------------------------------------===// 418 // Primitives 419 //===--------------------------------------------------------------------===// 420 421 void writeVarInt(uint64_t value) override { 422 emitter.emitVarInt(value, "dialect writer"); 423 } 424 425 void writeSignedVarInt(int64_t value) override { 426 emitter.emitSignedVarInt(value, "dialect writer"); 427 } 428 429 void writeAPIntWithKnownWidth(const APInt &value) override { 430 size_t bitWidth = value.getBitWidth(); 431 432 // If the value is a single byte, just emit it directly without going 433 // through a varint. 434 if (bitWidth <= 8) 435 return emitter.emitByte(value.getLimitedValue(), "dialect APInt"); 436 437 // If the value fits within a single varint, emit it directly. 438 if (bitWidth <= 64) 439 return emitter.emitSignedVarInt(value.getLimitedValue(), "dialect APInt"); 440 441 // Otherwise, we need to encode a variable number of active words. We use 442 // active words instead of the number of total words under the observation 443 // that smaller values will be more common. 444 unsigned numActiveWords = value.getActiveWords(); 445 emitter.emitVarInt(numActiveWords, "dialect APInt word count"); 446 447 const uint64_t *rawValueData = value.getRawData(); 448 for (unsigned i = 0; i < numActiveWords; ++i) 449 emitter.emitSignedVarInt(rawValueData[i], "dialect APInt word"); 450 } 451 452 void writeAPFloatWithKnownSemantics(const APFloat &value) override { 453 writeAPIntWithKnownWidth(value.bitcastToAPInt()); 454 } 455 456 void writeOwnedString(StringRef str) override { 457 emitter.emitVarInt(stringSection.insert(str), "dialect string"); 458 } 459 460 void writeOwnedBlob(ArrayRef<char> blob) override { 461 emitter.emitVarInt(blob.size(), "dialect blob"); 462 emitter.emitOwnedBlob( 463 ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(blob.data()), 464 blob.size()), 465 "dialect blob"); 466 } 467 468 void writeOwnedBool(bool value) override { 469 emitter.emitByte(value, "dialect bool"); 470 } 471 472 int64_t getBytecodeVersion() const override { return bytecodeVersion; } 473 474 FailureOr<const DialectVersion *> 475 getDialectVersion(StringRef dialectName) const override { 476 auto dialectEntry = dialectVersionMap.find(dialectName); 477 if (dialectEntry == dialectVersionMap.end()) 478 return failure(); 479 return dialectEntry->getValue().get(); 480 } 481 482 private: 483 int64_t bytecodeVersion; 484 EncodingEmitter &emitter; 485 IRNumberingState &numberingState; 486 StringSectionBuilder &stringSection; 487 const DialectVersionMapT &dialectVersionMap; 488 }; 489 490 namespace { 491 class PropertiesSectionBuilder { 492 public: 493 PropertiesSectionBuilder(IRNumberingState &numberingState, 494 StringSectionBuilder &stringSection, 495 const BytecodeWriterConfig::Impl &config) 496 : numberingState(numberingState), stringSection(stringSection), 497 config(config) {} 498 499 /// Emit the op properties in the properties section and return the index of 500 /// the properties within the section. Return -1 if no properties was emitted. 501 std::optional<ssize_t> emit(Operation *op) { 502 EncodingEmitter propertiesEmitter; 503 if (!op->getPropertiesStorageSize()) 504 return std::nullopt; 505 if (!op->isRegistered()) { 506 // Unregistered op are storing properties as an optional attribute. 507 Attribute prop = *op->getPropertiesStorage().as<Attribute *>(); 508 if (!prop) 509 return std::nullopt; 510 EncodingEmitter sizeEmitter; 511 sizeEmitter.emitVarInt(numberingState.getNumber(prop), "properties size"); 512 scratch.clear(); 513 llvm::raw_svector_ostream os(scratch); 514 sizeEmitter.writeTo(os); 515 return emit(scratch); 516 } 517 518 EncodingEmitter emitter; 519 DialectWriter propertiesWriter(config.bytecodeVersion, emitter, 520 numberingState, stringSection, 521 config.dialectVersionMap); 522 auto iface = cast<BytecodeOpInterface>(op); 523 iface.writeProperties(propertiesWriter); 524 scratch.clear(); 525 llvm::raw_svector_ostream os(scratch); 526 emitter.writeTo(os); 527 return emit(scratch); 528 } 529 530 /// Write the current set of properties to the given emitter. 531 void write(EncodingEmitter &emitter) { 532 emitter.emitVarInt(propertiesStorage.size(), "properties size"); 533 if (propertiesStorage.empty()) 534 return; 535 for (const auto &storage : propertiesStorage) { 536 if (storage.empty()) { 537 emitter.emitBytes(ArrayRef<uint8_t>(), "empty properties"); 538 continue; 539 } 540 emitter.emitBytes(ArrayRef(reinterpret_cast<const uint8_t *>(&storage[0]), 541 storage.size()), 542 "property"); 543 } 544 } 545 546 /// Returns true if the section is empty. 547 bool empty() { return propertiesStorage.empty(); } 548 549 private: 550 /// Emit raw data and returns the offset in the internal buffer. 551 /// Data are deduplicated and will be copied in the internal buffer only if 552 /// they don't exist there already. 553 ssize_t emit(ArrayRef<char> rawProperties) { 554 // Populate a scratch buffer with the properties size. 555 SmallVector<char> sizeScratch; 556 { 557 EncodingEmitter sizeEmitter; 558 sizeEmitter.emitVarInt(rawProperties.size(), "properties"); 559 llvm::raw_svector_ostream os(sizeScratch); 560 sizeEmitter.writeTo(os); 561 } 562 // Append a new storage to the table now. 563 size_t index = propertiesStorage.size(); 564 propertiesStorage.emplace_back(); 565 std::vector<char> &newStorage = propertiesStorage.back(); 566 size_t propertiesSize = sizeScratch.size() + rawProperties.size(); 567 newStorage.reserve(propertiesSize); 568 newStorage.insert(newStorage.end(), sizeScratch.begin(), sizeScratch.end()); 569 newStorage.insert(newStorage.end(), rawProperties.begin(), 570 rawProperties.end()); 571 572 // Try to de-duplicate the new serialized properties. 573 // If the properties is a duplicate, pop it back from the storage. 574 auto inserted = propertiesUniquing.insert( 575 std::make_pair(ArrayRef<char>(newStorage), index)); 576 if (!inserted.second) 577 propertiesStorage.pop_back(); 578 return inserted.first->getSecond(); 579 } 580 581 /// Storage for properties. 582 std::vector<std::vector<char>> propertiesStorage; 583 SmallVector<char> scratch; 584 DenseMap<ArrayRef<char>, int64_t> propertiesUniquing; 585 IRNumberingState &numberingState; 586 StringSectionBuilder &stringSection; 587 const BytecodeWriterConfig::Impl &config; 588 }; 589 } // namespace 590 591 /// A simple raw_ostream wrapper around a EncodingEmitter. This removes the need 592 /// to go through an intermediate buffer when interacting with code that wants a 593 /// raw_ostream. 594 class RawEmitterOstream : public raw_ostream { 595 public: 596 explicit RawEmitterOstream(EncodingEmitter &emitter) : emitter(emitter) { 597 SetUnbuffered(); 598 } 599 600 private: 601 void write_impl(const char *ptr, size_t size) override { 602 emitter.emitBytes({reinterpret_cast<const uint8_t *>(ptr), size}, 603 "raw emitter"); 604 } 605 uint64_t current_pos() const override { return emitter.size(); } 606 607 /// The section being emitted to. 608 EncodingEmitter &emitter; 609 }; 610 } // namespace 611 612 void EncodingEmitter::writeTo(raw_ostream &os) const { 613 for (auto &prevResult : prevResultList) 614 os.write((const char *)prevResult.data(), prevResult.size()); 615 os.write((const char *)currentResult.data(), currentResult.size()); 616 } 617 618 void EncodingEmitter::emitMultiByteVarInt(uint64_t value, StringLiteral desc) { 619 // Compute the number of bytes needed to encode the value. Each byte can hold 620 // up to 7-bits of data. We only check up to the number of bits we can encode 621 // in the first byte (8). 622 uint64_t it = value >> 7; 623 for (size_t numBytes = 2; numBytes < 9; ++numBytes) { 624 if (LLVM_LIKELY(it >>= 7) == 0) { 625 uint64_t encodedValue = (value << 1) | 0x1; 626 encodedValue <<= (numBytes - 1); 627 llvm::support::ulittle64_t encodedValueLE(encodedValue); 628 emitBytes({reinterpret_cast<uint8_t *>(&encodedValueLE), numBytes}, desc); 629 return; 630 } 631 } 632 633 // If the value is too large to encode in a single byte, emit a special all 634 // zero marker byte and splat the value directly. 635 emitByte(0, desc); 636 llvm::support::ulittle64_t valueLE(value); 637 emitBytes({reinterpret_cast<uint8_t *>(&valueLE), sizeof(valueLE)}, desc); 638 } 639 640 //===----------------------------------------------------------------------===// 641 // Bytecode Writer 642 //===----------------------------------------------------------------------===// 643 644 namespace { 645 class BytecodeWriter { 646 public: 647 BytecodeWriter(Operation *op, const BytecodeWriterConfig &config) 648 : numberingState(op, config), config(config.getImpl()), 649 propertiesSection(numberingState, stringSection, config.getImpl()) {} 650 651 /// Write the bytecode for the given root operation. 652 LogicalResult write(Operation *rootOp, raw_ostream &os); 653 654 private: 655 //===--------------------------------------------------------------------===// 656 // Dialects 657 658 void writeDialectSection(EncodingEmitter &emitter); 659 660 //===--------------------------------------------------------------------===// 661 // Attributes and Types 662 663 void writeAttrTypeSection(EncodingEmitter &emitter); 664 665 //===--------------------------------------------------------------------===// 666 // Operations 667 668 LogicalResult writeBlock(EncodingEmitter &emitter, Block *block); 669 LogicalResult writeOp(EncodingEmitter &emitter, Operation *op); 670 LogicalResult writeRegion(EncodingEmitter &emitter, Region *region); 671 LogicalResult writeIRSection(EncodingEmitter &emitter, Operation *op); 672 673 LogicalResult writeRegions(EncodingEmitter &emitter, 674 MutableArrayRef<Region> regions) { 675 return success(llvm::all_of(regions, [&](Region ®ion) { 676 return succeeded(writeRegion(emitter, ®ion)); 677 })); 678 } 679 680 //===--------------------------------------------------------------------===// 681 // Resources 682 683 void writeResourceSection(Operation *op, EncodingEmitter &emitter); 684 685 //===--------------------------------------------------------------------===// 686 // Strings 687 688 void writeStringSection(EncodingEmitter &emitter); 689 690 //===--------------------------------------------------------------------===// 691 // Properties 692 693 void writePropertiesSection(EncodingEmitter &emitter); 694 695 //===--------------------------------------------------------------------===// 696 // Helpers 697 698 void writeUseListOrders(EncodingEmitter &emitter, uint8_t &opEncodingMask, 699 ValueRange range); 700 701 //===--------------------------------------------------------------------===// 702 // Fields 703 704 /// The builder used for the string section. 705 StringSectionBuilder stringSection; 706 707 /// The IR numbering state generated for the root operation. 708 IRNumberingState numberingState; 709 710 /// Configuration dictating bytecode emission. 711 const BytecodeWriterConfig::Impl &config; 712 713 /// Storage for the properties section 714 PropertiesSectionBuilder propertiesSection; 715 }; 716 } // namespace 717 718 LogicalResult BytecodeWriter::write(Operation *rootOp, raw_ostream &os) { 719 EncodingEmitter emitter; 720 721 // Emit the bytecode file header. This is how we identify the output as a 722 // bytecode file. 723 emitter.emitString("ML\xefR", "bytecode header"); 724 725 // Emit the bytecode version. 726 if (config.bytecodeVersion < bytecode::kMinSupportedVersion || 727 config.bytecodeVersion > bytecode::kVersion) 728 return rootOp->emitError() 729 << "unsupported version requested " << config.bytecodeVersion 730 << ", must be in range [" 731 << static_cast<int64_t>(bytecode::kMinSupportedVersion) << ", " 732 << static_cast<int64_t>(bytecode::kVersion) << ']'; 733 emitter.emitVarInt(config.bytecodeVersion, "bytecode version"); 734 735 // Emit the producer. 736 emitter.emitNulTerminatedString(config.producer, "bytecode producer"); 737 738 // Emit the dialect section. 739 writeDialectSection(emitter); 740 741 // Emit the attributes and types section. 742 writeAttrTypeSection(emitter); 743 744 // Emit the IR section. 745 if (failed(writeIRSection(emitter, rootOp))) 746 return failure(); 747 748 // Emit the resources section. 749 writeResourceSection(rootOp, emitter); 750 751 // Emit the string section. 752 writeStringSection(emitter); 753 754 // Emit the properties section. 755 if (config.bytecodeVersion >= bytecode::kNativePropertiesEncoding) 756 writePropertiesSection(emitter); 757 else if (!propertiesSection.empty()) 758 return rootOp->emitError( 759 "unexpected properties emitted incompatible with bytecode <5"); 760 761 // Write the generated bytecode to the provided output stream. 762 emitter.writeTo(os); 763 764 return success(); 765 } 766 767 //===----------------------------------------------------------------------===// 768 // Dialects 769 770 /// Write the given entries in contiguous groups with the same parent dialect. 771 /// Each dialect sub-group is encoded with the parent dialect and number of 772 /// elements, followed by the encoding for the entries. The given callback is 773 /// invoked to encode each individual entry. 774 template <typename EntriesT, typename EntryCallbackT> 775 static void writeDialectGrouping(EncodingEmitter &emitter, EntriesT &&entries, 776 EntryCallbackT &&callback) { 777 for (auto it = entries.begin(), e = entries.end(); it != e;) { 778 auto groupStart = it++; 779 780 // Find the end of the group that shares the same parent dialect. 781 DialectNumbering *currentDialect = groupStart->dialect; 782 it = std::find_if(it, e, [&](const auto &entry) { 783 return entry.dialect != currentDialect; 784 }); 785 786 // Emit the dialect and number of elements. 787 emitter.emitVarInt(currentDialect->number, "dialect number"); 788 emitter.emitVarInt(std::distance(groupStart, it), "dialect offset"); 789 790 // Emit the entries within the group. 791 for (auto &entry : llvm::make_range(groupStart, it)) 792 callback(entry); 793 } 794 } 795 796 void BytecodeWriter::writeDialectSection(EncodingEmitter &emitter) { 797 EncodingEmitter dialectEmitter; 798 799 // Emit the referenced dialects. 800 auto dialects = numberingState.getDialects(); 801 dialectEmitter.emitVarInt(llvm::size(dialects), "dialects count"); 802 for (DialectNumbering &dialect : dialects) { 803 // Write the string section and get the ID. 804 size_t nameID = stringSection.insert(dialect.name); 805 806 if (config.bytecodeVersion < bytecode::kDialectVersioning) { 807 dialectEmitter.emitVarInt(nameID, "dialect name ID"); 808 continue; 809 } 810 811 // Try writing the version to the versionEmitter. 812 EncodingEmitter versionEmitter; 813 if (dialect.interface) { 814 // The writer used when emitting using a custom bytecode encoding. 815 DialectWriter versionWriter(config.bytecodeVersion, versionEmitter, 816 numberingState, stringSection, 817 config.dialectVersionMap); 818 dialect.interface->writeVersion(versionWriter); 819 } 820 821 // If the version emitter is empty, version is not available. We can encode 822 // this in the dialect ID, so if there is no version, we don't write the 823 // section. 824 size_t versionAvailable = versionEmitter.size() > 0; 825 dialectEmitter.emitVarIntWithFlag(nameID, versionAvailable, 826 "dialect version"); 827 if (versionAvailable) 828 dialectEmitter.emitSection(bytecode::Section::kDialectVersions, 829 std::move(versionEmitter)); 830 } 831 832 if (config.bytecodeVersion >= bytecode::kElideUnknownBlockArgLocation) 833 dialectEmitter.emitVarInt(size(numberingState.getOpNames()), 834 "op names count"); 835 836 // Emit the referenced operation names grouped by dialect. 837 auto emitOpName = [&](OpNameNumbering &name) { 838 size_t stringId = stringSection.insert(name.name.stripDialect()); 839 if (config.bytecodeVersion < bytecode::kNativePropertiesEncoding) 840 dialectEmitter.emitVarInt(stringId, "dialect op name"); 841 else 842 dialectEmitter.emitVarIntWithFlag(stringId, name.name.isRegistered(), 843 "dialect op name"); 844 }; 845 writeDialectGrouping(dialectEmitter, numberingState.getOpNames(), emitOpName); 846 847 emitter.emitSection(bytecode::Section::kDialect, std::move(dialectEmitter)); 848 } 849 850 //===----------------------------------------------------------------------===// 851 // Attributes and Types 852 853 void BytecodeWriter::writeAttrTypeSection(EncodingEmitter &emitter) { 854 EncodingEmitter attrTypeEmitter; 855 EncodingEmitter offsetEmitter; 856 offsetEmitter.emitVarInt(llvm::size(numberingState.getAttributes()), 857 "attributes count"); 858 offsetEmitter.emitVarInt(llvm::size(numberingState.getTypes()), 859 "types count"); 860 861 // A functor used to emit an attribute or type entry. 862 uint64_t prevOffset = 0; 863 auto emitAttrOrType = [&](auto &entry) { 864 auto entryValue = entry.getValue(); 865 866 auto emitAttrOrTypeRawImpl = [&]() -> void { 867 RawEmitterOstream(attrTypeEmitter) << entryValue; 868 attrTypeEmitter.emitByte(0, "attr/type separator"); 869 }; 870 auto emitAttrOrTypeImpl = [&]() -> bool { 871 // TODO: We don't currently support custom encoded mutable types and 872 // attributes. 873 if (entryValue.template hasTrait<TypeTrait::IsMutable>() || 874 entryValue.template hasTrait<AttributeTrait::IsMutable>()) { 875 emitAttrOrTypeRawImpl(); 876 return false; 877 } 878 879 DialectWriter dialectWriter(config.bytecodeVersion, attrTypeEmitter, 880 numberingState, stringSection, 881 config.dialectVersionMap); 882 if constexpr (std::is_same_v<std::decay_t<decltype(entryValue)>, Type>) { 883 for (const auto &callback : config.typeWriterCallbacks) { 884 if (succeeded(callback->write(entryValue, dialectWriter))) 885 return true; 886 } 887 if (const BytecodeDialectInterface *interface = 888 entry.dialect->interface) { 889 if (succeeded(interface->writeType(entryValue, dialectWriter))) 890 return true; 891 } 892 } else { 893 for (const auto &callback : config.attributeWriterCallbacks) { 894 if (succeeded(callback->write(entryValue, dialectWriter))) 895 return true; 896 } 897 if (const BytecodeDialectInterface *interface = 898 entry.dialect->interface) { 899 if (succeeded(interface->writeAttribute(entryValue, dialectWriter))) 900 return true; 901 } 902 } 903 904 // If the entry was not emitted using a callback or a dialect interface, 905 // emit it using the textual format. 906 emitAttrOrTypeRawImpl(); 907 return false; 908 }; 909 910 bool hasCustomEncoding = emitAttrOrTypeImpl(); 911 912 // Record the offset of this entry. 913 uint64_t curOffset = attrTypeEmitter.size(); 914 offsetEmitter.emitVarIntWithFlag(curOffset - prevOffset, hasCustomEncoding, 915 "attr/type offset"); 916 prevOffset = curOffset; 917 }; 918 919 // Emit the attribute and type entries for each dialect. 920 writeDialectGrouping(offsetEmitter, numberingState.getAttributes(), 921 emitAttrOrType); 922 writeDialectGrouping(offsetEmitter, numberingState.getTypes(), 923 emitAttrOrType); 924 925 // Emit the sections to the stream. 926 emitter.emitSection(bytecode::Section::kAttrTypeOffset, 927 std::move(offsetEmitter)); 928 emitter.emitSection(bytecode::Section::kAttrType, std::move(attrTypeEmitter)); 929 } 930 931 //===----------------------------------------------------------------------===// 932 // Operations 933 934 LogicalResult BytecodeWriter::writeBlock(EncodingEmitter &emitter, 935 Block *block) { 936 ArrayRef<BlockArgument> args = block->getArguments(); 937 bool hasArgs = !args.empty(); 938 939 // Emit the number of operations in this block, and if it has arguments. We 940 // use the low bit of the operation count to indicate if the block has 941 // arguments. 942 unsigned numOps = numberingState.getOperationCount(block); 943 emitter.emitVarIntWithFlag(numOps, hasArgs, "block num ops"); 944 945 // Emit the arguments of the block. 946 if (hasArgs) { 947 emitter.emitVarInt(args.size(), "block args count"); 948 for (BlockArgument arg : args) { 949 Location argLoc = arg.getLoc(); 950 if (config.bytecodeVersion >= bytecode::kElideUnknownBlockArgLocation) { 951 emitter.emitVarIntWithFlag(numberingState.getNumber(arg.getType()), 952 !isa<UnknownLoc>(argLoc), "block arg type"); 953 if (!isa<UnknownLoc>(argLoc)) 954 emitter.emitVarInt(numberingState.getNumber(argLoc), 955 "block arg location"); 956 } else { 957 emitter.emitVarInt(numberingState.getNumber(arg.getType()), 958 "block arg type"); 959 emitter.emitVarInt(numberingState.getNumber(argLoc), 960 "block arg location"); 961 } 962 } 963 if (config.bytecodeVersion >= bytecode::kUseListOrdering) { 964 uint64_t maskOffset = emitter.size(); 965 uint8_t encodingMask = 0; 966 emitter.emitByte(0, "use-list separator"); 967 writeUseListOrders(emitter, encodingMask, args); 968 if (encodingMask) 969 emitter.patchByte(maskOffset, encodingMask, "block patch encoding"); 970 } 971 } 972 973 // Emit the operations within the block. 974 for (Operation &op : *block) 975 if (failed(writeOp(emitter, &op))) 976 return failure(); 977 return success(); 978 } 979 980 LogicalResult BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op) { 981 emitter.emitVarInt(numberingState.getNumber(op->getName()), "op name ID"); 982 983 // Emit a mask for the operation components. We need to fill this in later 984 // (when we actually know what needs to be emitted), so emit a placeholder for 985 // now. 986 uint64_t maskOffset = emitter.size(); 987 uint8_t opEncodingMask = 0; 988 emitter.emitByte(0, "op separator"); 989 990 // Emit the location for this operation. 991 emitter.emitVarInt(numberingState.getNumber(op->getLoc()), "op location"); 992 993 // Emit the attributes of this operation. 994 DictionaryAttr attrs = op->getDiscardableAttrDictionary(); 995 // Allow deployment to version <kNativePropertiesEncoding by merging inherent 996 // attribute with the discardable ones. We should fail if there are any 997 // conflicts. When properties are not used by the op, also store everything as 998 // attributes. 999 if (config.bytecodeVersion < bytecode::kNativePropertiesEncoding || 1000 !op->getPropertiesStorage()) { 1001 attrs = op->getAttrDictionary(); 1002 } 1003 if (!attrs.empty()) { 1004 opEncodingMask |= bytecode::OpEncodingMask::kHasAttrs; 1005 emitter.emitVarInt(numberingState.getNumber(attrs), "op attrs count"); 1006 } 1007 1008 // Emit the properties of this operation, for now we still support deployment 1009 // to version <kNativePropertiesEncoding. 1010 if (config.bytecodeVersion >= bytecode::kNativePropertiesEncoding) { 1011 std::optional<ssize_t> propertiesId = propertiesSection.emit(op); 1012 if (propertiesId.has_value()) { 1013 opEncodingMask |= bytecode::OpEncodingMask::kHasProperties; 1014 emitter.emitVarInt(*propertiesId, "op properties ID"); 1015 } 1016 } 1017 1018 // Emit the result types of the operation. 1019 if (unsigned numResults = op->getNumResults()) { 1020 opEncodingMask |= bytecode::OpEncodingMask::kHasResults; 1021 emitter.emitVarInt(numResults, "op results count"); 1022 for (Type type : op->getResultTypes()) 1023 emitter.emitVarInt(numberingState.getNumber(type), "op result type"); 1024 } 1025 1026 // Emit the operands of the operation. 1027 if (unsigned numOperands = op->getNumOperands()) { 1028 opEncodingMask |= bytecode::OpEncodingMask::kHasOperands; 1029 emitter.emitVarInt(numOperands, "op operands count"); 1030 for (Value operand : op->getOperands()) 1031 emitter.emitVarInt(numberingState.getNumber(operand), "op operand types"); 1032 } 1033 1034 // Emit the successors of the operation. 1035 if (unsigned numSuccessors = op->getNumSuccessors()) { 1036 opEncodingMask |= bytecode::OpEncodingMask::kHasSuccessors; 1037 emitter.emitVarInt(numSuccessors, "op successors count"); 1038 for (Block *successor : op->getSuccessors()) 1039 emitter.emitVarInt(numberingState.getNumber(successor), "op successor"); 1040 } 1041 1042 // Emit the use-list orders to bytecode, so we can reconstruct the same order 1043 // at parsing. 1044 if (config.bytecodeVersion >= bytecode::kUseListOrdering) 1045 writeUseListOrders(emitter, opEncodingMask, ValueRange(op->getResults())); 1046 1047 // Check for regions. 1048 unsigned numRegions = op->getNumRegions(); 1049 if (numRegions) 1050 opEncodingMask |= bytecode::OpEncodingMask::kHasInlineRegions; 1051 1052 // Update the mask for the operation. 1053 emitter.patchByte(maskOffset, opEncodingMask, "op encoding mask"); 1054 1055 // With the mask emitted, we can now emit the regions of the operation. We do 1056 // this after mask emission to avoid offset complications that may arise by 1057 // emitting the regions first (e.g. if the regions are huge, backpatching the 1058 // op encoding mask is more annoying). 1059 if (numRegions) { 1060 bool isIsolatedFromAbove = numberingState.isIsolatedFromAbove(op); 1061 emitter.emitVarIntWithFlag(numRegions, isIsolatedFromAbove, 1062 "op regions count"); 1063 1064 // If the region is not isolated from above, or we are emitting bytecode 1065 // targeting version <kLazyLoading, we don't use a section. 1066 if (isIsolatedFromAbove && 1067 config.bytecodeVersion >= bytecode::kLazyLoading) { 1068 EncodingEmitter regionEmitter; 1069 if (failed(writeRegions(regionEmitter, op->getRegions()))) 1070 return failure(); 1071 emitter.emitSection(bytecode::Section::kIR, std::move(regionEmitter)); 1072 1073 } else if (failed(writeRegions(emitter, op->getRegions()))) { 1074 return failure(); 1075 } 1076 } 1077 return success(); 1078 } 1079 1080 void BytecodeWriter::writeUseListOrders(EncodingEmitter &emitter, 1081 uint8_t &opEncodingMask, 1082 ValueRange range) { 1083 // Loop over the results and store the use-list order per result index. 1084 DenseMap<unsigned, llvm::SmallVector<unsigned>> map; 1085 for (auto item : llvm::enumerate(range)) { 1086 auto value = item.value(); 1087 // No need to store a custom use-list order if the result does not have 1088 // multiple uses. 1089 if (value.use_empty() || value.hasOneUse()) 1090 continue; 1091 1092 // For each result, assemble the list of pairs (use-list-index, 1093 // global-value-index). While doing so, detect if the global-value-index is 1094 // already ordered with respect to the use-list-index. 1095 bool alreadyOrdered = true; 1096 auto &firstUse = *value.use_begin(); 1097 uint64_t prevID = bytecode::getUseID( 1098 firstUse, numberingState.getNumber(firstUse.getOwner())); 1099 llvm::SmallVector<std::pair<unsigned, uint64_t>> useListPairs( 1100 {{0, prevID}}); 1101 1102 for (auto use : llvm::drop_begin(llvm::enumerate(value.getUses()))) { 1103 uint64_t currentID = bytecode::getUseID( 1104 use.value(), numberingState.getNumber(use.value().getOwner())); 1105 // The use-list order achieved when building the IR at parsing always 1106 // pushes new uses on front. Hence, if the order by unique ID is 1107 // monotonically decreasing, a roundtrip to bytecode preserves such order. 1108 alreadyOrdered &= (prevID > currentID); 1109 useListPairs.push_back({use.index(), currentID}); 1110 prevID = currentID; 1111 } 1112 1113 // Do not emit if the order is already sorted. 1114 if (alreadyOrdered) 1115 continue; 1116 1117 // Sort the use indices by the unique ID indices in descending order. 1118 std::sort( 1119 useListPairs.begin(), useListPairs.end(), 1120 [](auto elem1, auto elem2) { return elem1.second > elem2.second; }); 1121 1122 map.try_emplace(item.index(), llvm::map_range(useListPairs, [](auto elem) { 1123 return elem.first; 1124 })); 1125 } 1126 1127 if (map.empty()) 1128 return; 1129 1130 opEncodingMask |= bytecode::OpEncodingMask::kHasUseListOrders; 1131 // Emit the number of results that have a custom use-list order if the number 1132 // of results is greater than one. 1133 if (range.size() != 1) { 1134 emitter.emitVarInt(map.size(), "custom use-list size"); 1135 } 1136 1137 for (const auto &item : map) { 1138 auto resultIdx = item.getFirst(); 1139 auto useListOrder = item.getSecond(); 1140 1141 // Compute the number of uses that are actually shuffled. If those are less 1142 // than half of the total uses, encoding the index pair `(src, dst)` is more 1143 // space efficient. 1144 size_t shuffledElements = 1145 llvm::count_if(llvm::enumerate(useListOrder), 1146 [](auto item) { return item.index() != item.value(); }); 1147 bool indexPairEncoding = shuffledElements < (useListOrder.size() / 2); 1148 1149 // For single result, we don't need to store the result index. 1150 if (range.size() != 1) 1151 emitter.emitVarInt(resultIdx, "use-list result index"); 1152 1153 if (indexPairEncoding) { 1154 emitter.emitVarIntWithFlag(shuffledElements * 2, indexPairEncoding, 1155 "use-list index pair size"); 1156 for (auto pair : llvm::enumerate(useListOrder)) { 1157 if (pair.index() != pair.value()) { 1158 emitter.emitVarInt(pair.value(), "use-list index pair first"); 1159 emitter.emitVarInt(pair.index(), "use-list index pair second"); 1160 } 1161 } 1162 } else { 1163 emitter.emitVarIntWithFlag(useListOrder.size(), indexPairEncoding, 1164 "use-list size"); 1165 for (const auto &index : useListOrder) 1166 emitter.emitVarInt(index, "use-list order"); 1167 } 1168 } 1169 } 1170 1171 LogicalResult BytecodeWriter::writeRegion(EncodingEmitter &emitter, 1172 Region *region) { 1173 // If the region is empty, we only need to emit the number of blocks (which is 1174 // zero). 1175 if (region->empty()) { 1176 emitter.emitVarInt(/*numBlocks*/ 0, "region block count empty"); 1177 return success(); 1178 } 1179 1180 // Emit the number of blocks and values within the region. 1181 unsigned numBlocks, numValues; 1182 std::tie(numBlocks, numValues) = numberingState.getBlockValueCount(region); 1183 emitter.emitVarInt(numBlocks, "region block count"); 1184 emitter.emitVarInt(numValues, "region value count"); 1185 1186 // Emit the blocks within the region. 1187 for (Block &block : *region) 1188 if (failed(writeBlock(emitter, &block))) 1189 return failure(); 1190 return success(); 1191 } 1192 1193 LogicalResult BytecodeWriter::writeIRSection(EncodingEmitter &emitter, 1194 Operation *op) { 1195 EncodingEmitter irEmitter; 1196 1197 // Write the IR section the same way as a block with no arguments. Note that 1198 // the low-bit of the operation count for a block is used to indicate if the 1199 // block has arguments, which in this case is always false. 1200 irEmitter.emitVarIntWithFlag(/*numOps*/ 1, /*hasArgs*/ false, "ir section"); 1201 1202 // Emit the operations. 1203 if (failed(writeOp(irEmitter, op))) 1204 return failure(); 1205 1206 emitter.emitSection(bytecode::Section::kIR, std::move(irEmitter)); 1207 return success(); 1208 } 1209 1210 //===----------------------------------------------------------------------===// 1211 // Resources 1212 1213 namespace { 1214 /// This class represents a resource builder implementation for the MLIR 1215 /// bytecode format. 1216 class ResourceBuilder : public AsmResourceBuilder { 1217 public: 1218 using PostProcessFn = function_ref<void(StringRef, AsmResourceEntryKind)>; 1219 1220 ResourceBuilder(EncodingEmitter &emitter, StringSectionBuilder &stringSection, 1221 PostProcessFn postProcessFn, bool shouldElideData) 1222 : emitter(emitter), stringSection(stringSection), 1223 postProcessFn(postProcessFn), shouldElideData(shouldElideData) {} 1224 ~ResourceBuilder() override = default; 1225 1226 void buildBlob(StringRef key, ArrayRef<char> data, 1227 uint32_t dataAlignment) final { 1228 if (!shouldElideData) 1229 emitter.emitOwnedBlobAndAlignment(data, dataAlignment, "resource blob"); 1230 postProcessFn(key, AsmResourceEntryKind::Blob); 1231 } 1232 void buildBool(StringRef key, bool data) final { 1233 if (!shouldElideData) 1234 emitter.emitByte(data, "resource bool"); 1235 postProcessFn(key, AsmResourceEntryKind::Bool); 1236 } 1237 void buildString(StringRef key, StringRef data) final { 1238 if (!shouldElideData) 1239 emitter.emitVarInt(stringSection.insert(data), "resource string"); 1240 postProcessFn(key, AsmResourceEntryKind::String); 1241 } 1242 1243 private: 1244 EncodingEmitter &emitter; 1245 StringSectionBuilder &stringSection; 1246 PostProcessFn postProcessFn; 1247 bool shouldElideData = false; 1248 }; 1249 } // namespace 1250 1251 void BytecodeWriter::writeResourceSection(Operation *op, 1252 EncodingEmitter &emitter) { 1253 EncodingEmitter resourceEmitter; 1254 EncodingEmitter resourceOffsetEmitter; 1255 uint64_t prevOffset = 0; 1256 SmallVector<std::tuple<StringRef, AsmResourceEntryKind, uint64_t>> 1257 curResourceEntries; 1258 1259 // Functor used to process the offset for a resource of `kind` defined by 1260 // 'key'. 1261 auto appendResourceOffset = [&](StringRef key, AsmResourceEntryKind kind) { 1262 uint64_t curOffset = resourceEmitter.size(); 1263 curResourceEntries.emplace_back(key, kind, curOffset - prevOffset); 1264 prevOffset = curOffset; 1265 }; 1266 1267 // Functor used to emit a resource group defined by 'key'. 1268 auto emitResourceGroup = [&](uint64_t key) { 1269 resourceOffsetEmitter.emitVarInt(key, "resource group key"); 1270 resourceOffsetEmitter.emitVarInt(curResourceEntries.size(), 1271 "resource group size"); 1272 for (auto [key, kind, size] : curResourceEntries) { 1273 resourceOffsetEmitter.emitVarInt(stringSection.insert(key), 1274 "resource key"); 1275 resourceOffsetEmitter.emitVarInt(size, "resource size"); 1276 resourceOffsetEmitter.emitByte(kind, "resource kind"); 1277 } 1278 }; 1279 1280 // Builder used to emit resources. 1281 ResourceBuilder entryBuilder(resourceEmitter, stringSection, 1282 appendResourceOffset, 1283 config.shouldElideResourceData); 1284 1285 // Emit the external resource entries. 1286 resourceOffsetEmitter.emitVarInt(config.externalResourcePrinters.size(), 1287 "external resource printer count"); 1288 for (const auto &printer : config.externalResourcePrinters) { 1289 curResourceEntries.clear(); 1290 printer->buildResources(op, entryBuilder); 1291 emitResourceGroup(stringSection.insert(printer->getName())); 1292 } 1293 1294 // Emit the dialect resource entries. 1295 for (DialectNumbering &dialect : numberingState.getDialects()) { 1296 if (!dialect.asmInterface) 1297 continue; 1298 curResourceEntries.clear(); 1299 dialect.asmInterface->buildResources(op, dialect.resources, entryBuilder); 1300 1301 // Emit the declaration resources for this dialect, these didn't get emitted 1302 // by the interface. These resources don't have data attached, so just use a 1303 // "blob" kind as a placeholder. 1304 for (const auto &resource : dialect.resourceMap) 1305 if (resource.second->isDeclaration) 1306 appendResourceOffset(resource.first, AsmResourceEntryKind::Blob); 1307 1308 // Emit the resource group for this dialect. 1309 if (!curResourceEntries.empty()) 1310 emitResourceGroup(dialect.number); 1311 } 1312 1313 // If we didn't emit any resource groups, elide the resource sections. 1314 if (resourceOffsetEmitter.size() == 0) 1315 return; 1316 1317 emitter.emitSection(bytecode::Section::kResourceOffset, 1318 std::move(resourceOffsetEmitter)); 1319 emitter.emitSection(bytecode::Section::kResource, std::move(resourceEmitter)); 1320 } 1321 1322 //===----------------------------------------------------------------------===// 1323 // Strings 1324 1325 void BytecodeWriter::writeStringSection(EncodingEmitter &emitter) { 1326 EncodingEmitter stringEmitter; 1327 stringSection.write(stringEmitter); 1328 emitter.emitSection(bytecode::Section::kString, std::move(stringEmitter)); 1329 } 1330 1331 //===----------------------------------------------------------------------===// 1332 // Properties 1333 1334 void BytecodeWriter::writePropertiesSection(EncodingEmitter &emitter) { 1335 EncodingEmitter propertiesEmitter; 1336 propertiesSection.write(propertiesEmitter); 1337 emitter.emitSection(bytecode::Section::kProperties, 1338 std::move(propertiesEmitter)); 1339 } 1340 1341 //===----------------------------------------------------------------------===// 1342 // Entry Points 1343 //===----------------------------------------------------------------------===// 1344 1345 LogicalResult mlir::writeBytecodeToFile(Operation *op, raw_ostream &os, 1346 const BytecodeWriterConfig &config) { 1347 BytecodeWriter writer(op, config); 1348 return writer.write(op, os); 1349 } 1350