1 //===- IRNumbering.cpp - MLIR Bytecode IR numbering -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRNumbering.h" 10 #include "mlir/Bytecode/BytecodeImplementation.h" 11 #include "mlir/Bytecode/BytecodeOpInterface.h" 12 #include "mlir/Bytecode/BytecodeWriter.h" 13 #include "mlir/Bytecode/Encoding.h" 14 #include "mlir/IR/AsmState.h" 15 #include "mlir/IR/BuiltinTypes.h" 16 #include "mlir/IR/OpDefinition.h" 17 18 using namespace mlir; 19 using namespace mlir::bytecode::detail; 20 21 //===----------------------------------------------------------------------===// 22 // NumberingDialectWriter 23 //===----------------------------------------------------------------------===// 24 25 struct IRNumberingState::NumberingDialectWriter : public DialectBytecodeWriter { 26 NumberingDialectWriter( 27 IRNumberingState &state, 28 llvm::StringMap<std::unique_ptr<DialectVersion>> &dialectVersionMap) 29 : state(state), dialectVersionMap(dialectVersionMap) {} 30 31 void writeAttribute(Attribute attr) override { state.number(attr); } 32 void writeOptionalAttribute(Attribute attr) override { 33 if (attr) 34 state.number(attr); 35 } 36 void writeType(Type type) override { state.number(type); } 37 void writeResourceHandle(const AsmDialectResourceHandle &resource) override { 38 state.number(resource.getDialect(), resource); 39 } 40 41 /// Stubbed out methods that are not used for numbering. 42 void writeVarInt(uint64_t) override {} 43 void writeSignedVarInt(int64_t value) override {} 44 void writeAPIntWithKnownWidth(const APInt &value) override {} 45 void writeAPFloatWithKnownSemantics(const APFloat &value) override {} 46 void writeOwnedString(StringRef) override { 47 // TODO: It might be nice to prenumber strings and sort by the number of 48 // references. This could potentially be useful for optimizing things like 49 // file locations. 50 } 51 void writeOwnedBlob(ArrayRef<char> blob) override {} 52 void writeOwnedBool(bool value) override {} 53 54 int64_t getBytecodeVersion() const override { 55 return state.getDesiredBytecodeVersion(); 56 } 57 58 FailureOr<const DialectVersion *> 59 getDialectVersion(StringRef dialectName) const override { 60 auto dialectEntry = dialectVersionMap.find(dialectName); 61 if (dialectEntry == dialectVersionMap.end()) 62 return failure(); 63 return dialectEntry->getValue().get(); 64 } 65 66 /// The parent numbering state that is populated by this writer. 67 IRNumberingState &state; 68 69 /// A map containing dialect version information for each dialect to emit. 70 llvm::StringMap<std::unique_ptr<DialectVersion>> &dialectVersionMap; 71 }; 72 73 //===----------------------------------------------------------------------===// 74 // IR Numbering 75 //===----------------------------------------------------------------------===// 76 77 /// Group and sort the elements of the given range by their parent dialect. This 78 /// grouping is applied to sub-sections of the ranged defined by how many bytes 79 /// it takes to encode a varint index to that sub-section. 80 template <typename T> 81 static void groupByDialectPerByte(T range) { 82 if (range.empty()) 83 return; 84 85 // A functor used to sort by a given dialect, with a desired dialect to be 86 // ordered first (to better enable sharing of dialects across byte groups). 87 auto sortByDialect = [](unsigned dialectToOrderFirst, const auto &lhs, 88 const auto &rhs) { 89 if (lhs->dialect->number == dialectToOrderFirst) 90 return rhs->dialect->number != dialectToOrderFirst; 91 if (rhs->dialect->number == dialectToOrderFirst) 92 return false; 93 return lhs->dialect->number < rhs->dialect->number; 94 }; 95 96 unsigned dialectToOrderFirst = 0; 97 size_t elementsInByteGroup = 0; 98 auto iterRange = range; 99 for (unsigned i = 1; i < 9; ++i) { 100 // Update the number of elements in the current byte grouping. Reminder 101 // that varint encodes 7-bits per byte, so that's how we compute the 102 // number of elements in each byte grouping. 103 elementsInByteGroup = (1ULL << (7ULL * i)) - elementsInByteGroup; 104 105 // Slice out the sub-set of elements that are in the current byte grouping 106 // to be sorted. 107 auto byteSubRange = iterRange.take_front(elementsInByteGroup); 108 iterRange = iterRange.drop_front(byteSubRange.size()); 109 110 // Sort the sub range for this byte. 111 llvm::stable_sort(byteSubRange, [&](const auto &lhs, const auto &rhs) { 112 return sortByDialect(dialectToOrderFirst, lhs, rhs); 113 }); 114 115 // Update the dialect to order first to be the dialect at the end of the 116 // current grouping. This seeks to allow larger dialect groupings across 117 // byte boundaries. 118 dialectToOrderFirst = byteSubRange.back()->dialect->number; 119 120 // If the data range is now empty, we are done. 121 if (iterRange.empty()) 122 break; 123 } 124 125 // Assign the entry numbers based on the sort order. 126 for (auto [idx, value] : llvm::enumerate(range)) 127 value->number = idx; 128 } 129 130 IRNumberingState::IRNumberingState(Operation *op, 131 const BytecodeWriterConfig &config) 132 : config(config) { 133 computeGlobalNumberingState(op); 134 135 // Number the root operation. 136 number(*op); 137 138 // A worklist of region contexts to number and the next value id before that 139 // region. 140 SmallVector<std::pair<Region *, unsigned>, 8> numberContext; 141 142 // Functor to push the regions of the given operation onto the numbering 143 // context. 144 auto addOpRegionsToNumber = [&](Operation *op) { 145 MutableArrayRef<Region> regions = op->getRegions(); 146 if (regions.empty()) 147 return; 148 149 // Isolated regions don't share value numbers with their parent, so we can 150 // start numbering these regions at zero. 151 unsigned opFirstValueID = isIsolatedFromAbove(op) ? 0 : nextValueID; 152 for (Region ®ion : regions) 153 numberContext.emplace_back(®ion, opFirstValueID); 154 }; 155 addOpRegionsToNumber(op); 156 157 // Iteratively process each of the nested regions. 158 while (!numberContext.empty()) { 159 Region *region; 160 std::tie(region, nextValueID) = numberContext.pop_back_val(); 161 number(*region); 162 163 // Traverse into nested regions. 164 for (Operation &op : region->getOps()) 165 addOpRegionsToNumber(&op); 166 } 167 168 // Number each of the dialects. For now this is just in the order they were 169 // found, given that the number of dialects on average is small enough to fit 170 // within a singly byte (128). If we ever have real world use cases that have 171 // a huge number of dialects, this could be made more intelligent. 172 for (auto [idx, dialect] : llvm::enumerate(dialects)) 173 dialect.second->number = idx; 174 175 // Number each of the recorded components within each dialect. 176 177 // First sort by ref count so that the most referenced elements are first. We 178 // try to bias more heavily used elements to the front. This allows for more 179 // frequently referenced things to be encoded using smaller varints. 180 auto sortByRefCountFn = [](const auto &lhs, const auto &rhs) { 181 return lhs->refCount > rhs->refCount; 182 }; 183 llvm::stable_sort(orderedAttrs, sortByRefCountFn); 184 llvm::stable_sort(orderedOpNames, sortByRefCountFn); 185 llvm::stable_sort(orderedTypes, sortByRefCountFn); 186 187 // After that, we apply a secondary ordering based on the parent dialect. This 188 // ordering is applied to sub-sections of the element list defined by how many 189 // bytes it takes to encode a varint index to that sub-section. This allows 190 // for more efficiently encoding components of the same dialect (e.g. we only 191 // have to encode the dialect reference once). 192 groupByDialectPerByte(llvm::MutableArrayRef(orderedAttrs)); 193 groupByDialectPerByte(llvm::MutableArrayRef(orderedOpNames)); 194 groupByDialectPerByte(llvm::MutableArrayRef(orderedTypes)); 195 196 // Finalize the numbering of the dialect resources. 197 finalizeDialectResourceNumberings(op); 198 } 199 200 void IRNumberingState::computeGlobalNumberingState(Operation *rootOp) { 201 // A simple state struct tracking data used when walking operations. 202 struct StackState { 203 /// The operation currently being walked. 204 Operation *op; 205 206 /// The numbering of the operation. 207 OperationNumbering *numbering; 208 209 /// A flag indicating if the current state or one of its parents has 210 /// unresolved isolation status. This is tracked separately from the 211 /// isIsolatedFromAbove bit on `numbering` because we need to be able to 212 /// handle the given case: 213 /// top.op { 214 /// %value = ... 215 /// middle.op { 216 /// %value2 = ... 217 /// inner.op { 218 /// // Here we mark `inner.op` as not isolated. Note `middle.op` 219 /// // isn't known not isolated yet. 220 /// use.op %value2 221 /// 222 /// // Here inner.op is already known to be non-isolated, but 223 /// // `middle.op` is now also discovered to be non-isolated. 224 /// use.op %value 225 /// } 226 /// } 227 /// } 228 bool hasUnresolvedIsolation; 229 }; 230 231 // Compute a global operation ID numbering according to the pre-order walk of 232 // the IR. This is used as reference to construct use-list orders. 233 unsigned operationID = 0; 234 235 // Walk each of the operations within the IR, tracking a stack of operations 236 // as we recurse into nested regions. This walk method hooks in at two stages 237 // during the walk: 238 // 239 // BeforeAllRegions: 240 // Here we generate a numbering for the operation and push it onto the 241 // stack if it has regions. We also compute the isolation status of parent 242 // regions at this stage. This is done by checking the parent regions of 243 // operands used by the operation, and marking each region between the 244 // the operand region and the current as not isolated. See 245 // StackState::hasUnresolvedIsolation above for an example. 246 // 247 // AfterAllRegions: 248 // Here we pop the operation from the stack, and if it hasn't been marked 249 // as non-isolated, we mark it as so. A non-isolated use would have been 250 // found while walking the regions, so it is safe to mark the operation at 251 // this point. 252 // 253 SmallVector<StackState> opStack; 254 rootOp->walk([&](Operation *op, const WalkStage &stage) { 255 // After visiting all nested regions, we pop the operation from the stack. 256 if (op->getNumRegions() && stage.isAfterAllRegions()) { 257 // If no non-isolated uses were found, we can safely mark this operation 258 // as isolated from above. 259 OperationNumbering *numbering = opStack.pop_back_val().numbering; 260 if (!numbering->isIsolatedFromAbove.has_value()) 261 numbering->isIsolatedFromAbove = true; 262 return; 263 } 264 265 // When visiting before nested regions, we process "IsolatedFromAbove" 266 // checks and compute the number for this operation. 267 if (!stage.isBeforeAllRegions()) 268 return; 269 // Update the isolation status of parent regions if any have yet to be 270 // resolved. 271 if (!opStack.empty() && opStack.back().hasUnresolvedIsolation) { 272 Region *parentRegion = op->getParentRegion(); 273 for (Value operand : op->getOperands()) { 274 Region *operandRegion = operand.getParentRegion(); 275 if (operandRegion == parentRegion) 276 continue; 277 // We've found a use of an operand outside of the current region, 278 // walk the operation stack searching for the parent operation, 279 // marking every region on the way as not isolated. 280 Operation *operandContainerOp = operandRegion->getParentOp(); 281 auto it = std::find_if( 282 opStack.rbegin(), opStack.rend(), [=](const StackState &it) { 283 // We only need to mark up to the container region, or the first 284 // that has an unresolved status. 285 return !it.hasUnresolvedIsolation || it.op == operandContainerOp; 286 }); 287 assert(it != opStack.rend() && "expected to find the container"); 288 for (auto &state : llvm::make_range(opStack.rbegin(), it)) { 289 // If we stopped at a region that knows its isolation status, we can 290 // stop updating the isolation status for the parent regions. 291 state.hasUnresolvedIsolation = it->hasUnresolvedIsolation; 292 state.numbering->isIsolatedFromAbove = false; 293 } 294 } 295 } 296 297 // Compute the number for this op and push it onto the stack. 298 auto *numbering = 299 new (opAllocator.Allocate()) OperationNumbering(operationID++); 300 if (op->hasTrait<OpTrait::IsIsolatedFromAbove>()) 301 numbering->isIsolatedFromAbove = true; 302 operations.try_emplace(op, numbering); 303 if (op->getNumRegions()) { 304 opStack.emplace_back(StackState{ 305 op, numbering, !numbering->isIsolatedFromAbove.has_value()}); 306 } 307 }); 308 } 309 310 void IRNumberingState::number(Attribute attr) { 311 auto it = attrs.insert({attr, nullptr}); 312 if (!it.second) { 313 ++it.first->second->refCount; 314 return; 315 } 316 auto *numbering = new (attrAllocator.Allocate()) AttributeNumbering(attr); 317 it.first->second = numbering; 318 orderedAttrs.push_back(numbering); 319 320 // Check for OpaqueAttr, which is a dialect-specific attribute that didn't 321 // have a registered dialect when it got created. We don't want to encode this 322 // as the builtin OpaqueAttr, we want to encode it as if the dialect was 323 // actually loaded. 324 if (OpaqueAttr opaqueAttr = dyn_cast<OpaqueAttr>(attr)) { 325 numbering->dialect = &numberDialect(opaqueAttr.getDialectNamespace()); 326 return; 327 } 328 numbering->dialect = &numberDialect(&attr.getDialect()); 329 330 // If this attribute will be emitted using the bytecode format, perform a 331 // dummy writing to number any nested components. 332 // TODO: We don't allow custom encodings for mutable attributes right now. 333 if (!attr.hasTrait<AttributeTrait::IsMutable>()) { 334 // Try overriding emission with callbacks. 335 for (const auto &callback : config.getAttributeWriterCallbacks()) { 336 NumberingDialectWriter writer(*this, config.getDialectVersionMap()); 337 // The client has the ability to override the group name through the 338 // callback. 339 std::optional<StringRef> groupNameOverride; 340 if (succeeded(callback->write(attr, groupNameOverride, writer))) { 341 if (groupNameOverride.has_value()) 342 numbering->dialect = &numberDialect(*groupNameOverride); 343 return; 344 } 345 } 346 347 if (const auto *interface = numbering->dialect->interface) { 348 NumberingDialectWriter writer(*this, config.getDialectVersionMap()); 349 if (succeeded(interface->writeAttribute(attr, writer))) 350 return; 351 } 352 } 353 // If this attribute will be emitted using the fallback, number the nested 354 // dialect resources. We don't number everything (e.g. no nested 355 // attributes/types), because we don't want to encode things we won't decode 356 // (the textual format can't really share much). 357 AsmState tempState(attr.getContext()); 358 llvm::raw_null_ostream dummyOS; 359 attr.print(dummyOS, tempState); 360 361 // Number the used dialect resources. 362 for (const auto &it : tempState.getDialectResources()) 363 number(it.getFirst(), it.getSecond().getArrayRef()); 364 } 365 366 void IRNumberingState::number(Block &block) { 367 // Number the arguments of the block. 368 for (BlockArgument arg : block.getArguments()) { 369 valueIDs.try_emplace(arg, nextValueID++); 370 number(arg.getLoc()); 371 number(arg.getType()); 372 } 373 374 // Number the operations in this block. 375 unsigned &numOps = blockOperationCounts[&block]; 376 for (Operation &op : block) { 377 number(op); 378 ++numOps; 379 } 380 } 381 382 auto IRNumberingState::numberDialect(Dialect *dialect) -> DialectNumbering & { 383 DialectNumbering *&numbering = registeredDialects[dialect]; 384 if (!numbering) { 385 numbering = &numberDialect(dialect->getNamespace()); 386 numbering->interface = dyn_cast<BytecodeDialectInterface>(dialect); 387 numbering->asmInterface = dyn_cast<OpAsmDialectInterface>(dialect); 388 } 389 return *numbering; 390 } 391 392 auto IRNumberingState::numberDialect(StringRef dialect) -> DialectNumbering & { 393 DialectNumbering *&numbering = dialects[dialect]; 394 if (!numbering) { 395 numbering = new (dialectAllocator.Allocate()) 396 DialectNumbering(dialect, dialects.size() - 1); 397 } 398 return *numbering; 399 } 400 401 void IRNumberingState::number(Region ®ion) { 402 if (region.empty()) 403 return; 404 size_t firstValueID = nextValueID; 405 406 // Number the blocks within this region. 407 size_t blockCount = 0; 408 for (auto it : llvm::enumerate(region)) { 409 blockIDs.try_emplace(&it.value(), it.index()); 410 number(it.value()); 411 ++blockCount; 412 } 413 414 // Remember the number of blocks and values in this region. 415 regionBlockValueCounts.try_emplace(®ion, blockCount, 416 nextValueID - firstValueID); 417 } 418 419 void IRNumberingState::number(Operation &op) { 420 // Number the components of an operation that won't be numbered elsewhere 421 // (e.g. we don't number operands, regions, or successors here). 422 number(op.getName()); 423 for (OpResult result : op.getResults()) { 424 valueIDs.try_emplace(result, nextValueID++); 425 number(result.getType()); 426 } 427 428 // Prior to a version with native property encoding, or when properties are 429 // not used, we need to number also the merged dictionary containing both the 430 // inherent and discardable attribute. 431 DictionaryAttr dictAttr; 432 if (config.getDesiredBytecodeVersion() >= bytecode::kNativePropertiesEncoding) 433 dictAttr = op.getRawDictionaryAttrs(); 434 else 435 dictAttr = op.getAttrDictionary(); 436 // Only number the operation's dictionary if it isn't empty. 437 if (!dictAttr.empty()) 438 number(dictAttr); 439 440 // Visit the operation properties (if any) to make sure referenced attributes 441 // are numbered. 442 if (config.getDesiredBytecodeVersion() >= 443 bytecode::kNativePropertiesEncoding && 444 op.getPropertiesStorageSize()) { 445 if (op.isRegistered()) { 446 // Operation that have properties *must* implement this interface. 447 auto iface = cast<BytecodeOpInterface>(op); 448 NumberingDialectWriter writer(*this, config.getDialectVersionMap()); 449 iface.writeProperties(writer); 450 } else { 451 // Unregistered op are storing properties as an optional attribute. 452 if (Attribute prop = *op.getPropertiesStorage().as<Attribute *>()) 453 number(prop); 454 } 455 } 456 457 number(op.getLoc()); 458 } 459 460 void IRNumberingState::number(OperationName opName) { 461 OpNameNumbering *&numbering = opNames[opName]; 462 if (numbering) { 463 ++numbering->refCount; 464 return; 465 } 466 DialectNumbering *dialectNumber = nullptr; 467 if (Dialect *dialect = opName.getDialect()) 468 dialectNumber = &numberDialect(dialect); 469 else 470 dialectNumber = &numberDialect(opName.getDialectNamespace()); 471 472 numbering = 473 new (opNameAllocator.Allocate()) OpNameNumbering(dialectNumber, opName); 474 orderedOpNames.push_back(numbering); 475 } 476 477 void IRNumberingState::number(Type type) { 478 auto it = types.insert({type, nullptr}); 479 if (!it.second) { 480 ++it.first->second->refCount; 481 return; 482 } 483 auto *numbering = new (typeAllocator.Allocate()) TypeNumbering(type); 484 it.first->second = numbering; 485 orderedTypes.push_back(numbering); 486 487 // Check for OpaqueType, which is a dialect-specific type that didn't have a 488 // registered dialect when it got created. We don't want to encode this as the 489 // builtin OpaqueType, we want to encode it as if the dialect was actually 490 // loaded. 491 if (OpaqueType opaqueType = dyn_cast<OpaqueType>(type)) { 492 numbering->dialect = &numberDialect(opaqueType.getDialectNamespace()); 493 return; 494 } 495 numbering->dialect = &numberDialect(&type.getDialect()); 496 497 // If this type will be emitted using the bytecode format, perform a dummy 498 // writing to number any nested components. 499 // TODO: We don't allow custom encodings for mutable types right now. 500 if (!type.hasTrait<TypeTrait::IsMutable>()) { 501 // Try overriding emission with callbacks. 502 for (const auto &callback : config.getTypeWriterCallbacks()) { 503 NumberingDialectWriter writer(*this, config.getDialectVersionMap()); 504 // The client has the ability to override the group name through the 505 // callback. 506 std::optional<StringRef> groupNameOverride; 507 if (succeeded(callback->write(type, groupNameOverride, writer))) { 508 if (groupNameOverride.has_value()) 509 numbering->dialect = &numberDialect(*groupNameOverride); 510 return; 511 } 512 } 513 514 // If this attribute will be emitted using the bytecode format, perform a 515 // dummy writing to number any nested components. 516 if (const auto *interface = numbering->dialect->interface) { 517 NumberingDialectWriter writer(*this, config.getDialectVersionMap()); 518 if (succeeded(interface->writeType(type, writer))) 519 return; 520 } 521 } 522 // If this type will be emitted using the fallback, number the nested dialect 523 // resources. We don't number everything (e.g. no nested attributes/types), 524 // because we don't want to encode things we won't decode (the textual format 525 // can't really share much). 526 AsmState tempState(type.getContext()); 527 llvm::raw_null_ostream dummyOS; 528 type.print(dummyOS, tempState); 529 530 // Number the used dialect resources. 531 for (const auto &it : tempState.getDialectResources()) 532 number(it.getFirst(), it.getSecond().getArrayRef()); 533 } 534 535 void IRNumberingState::number(Dialect *dialect, 536 ArrayRef<AsmDialectResourceHandle> resources) { 537 DialectNumbering &dialectNumber = numberDialect(dialect); 538 assert( 539 dialectNumber.asmInterface && 540 "expected dialect owning a resource to implement OpAsmDialectInterface"); 541 542 for (const auto &resource : resources) { 543 // Check if this is a newly seen resource. 544 if (!dialectNumber.resources.insert(resource)) 545 return; 546 547 auto *numbering = 548 new (resourceAllocator.Allocate()) DialectResourceNumbering( 549 dialectNumber.asmInterface->getResourceKey(resource)); 550 dialectNumber.resourceMap.insert({numbering->key, numbering}); 551 dialectResources.try_emplace(resource, numbering); 552 } 553 } 554 555 int64_t IRNumberingState::getDesiredBytecodeVersion() const { 556 return config.getDesiredBytecodeVersion(); 557 } 558 559 namespace { 560 /// A dummy resource builder used to number dialect resources. 561 struct NumberingResourceBuilder : public AsmResourceBuilder { 562 NumberingResourceBuilder(DialectNumbering *dialect, unsigned &nextResourceID) 563 : dialect(dialect), nextResourceID(nextResourceID) {} 564 ~NumberingResourceBuilder() override = default; 565 566 void buildBlob(StringRef key, ArrayRef<char>, uint32_t) final { 567 numberEntry(key); 568 } 569 void buildBool(StringRef key, bool) final { numberEntry(key); } 570 void buildString(StringRef key, StringRef) final { 571 // TODO: We could pre-number the value string here as well. 572 numberEntry(key); 573 } 574 575 /// Number the dialect entry for the given key. 576 void numberEntry(StringRef key) { 577 // TODO: We could pre-number resource key strings here as well. 578 579 auto *it = dialect->resourceMap.find(key); 580 if (it != dialect->resourceMap.end()) { 581 it->second->number = nextResourceID++; 582 it->second->isDeclaration = false; 583 } 584 } 585 586 DialectNumbering *dialect; 587 unsigned &nextResourceID; 588 }; 589 } // namespace 590 591 void IRNumberingState::finalizeDialectResourceNumberings(Operation *rootOp) { 592 unsigned nextResourceID = 0; 593 for (DialectNumbering &dialect : getDialects()) { 594 if (!dialect.asmInterface) 595 continue; 596 NumberingResourceBuilder entryBuilder(&dialect, nextResourceID); 597 dialect.asmInterface->buildResources(rootOp, dialect.resources, 598 entryBuilder); 599 600 // Number any resources that weren't added by the dialect. This can happen 601 // if there was no backing data to the resource, but we still want these 602 // resource references to roundtrip, so we number them and indicate that the 603 // data is missing. 604 for (const auto &it : dialect.resourceMap) 605 if (it.second->isDeclaration) 606 it.second->number = nextResourceID++; 607 } 608 } 609