xref: /llvm-project/mlir/lib/Bytecode/Writer/IRNumbering.cpp (revision f1ac7725e4fd5afa21fb244f9bcc33de654ed80c)
1 //===- IRNumbering.cpp - MLIR Bytecode IR numbering -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRNumbering.h"
10 #include "mlir/Bytecode/BytecodeImplementation.h"
11 #include "mlir/Bytecode/BytecodeOpInterface.h"
12 #include "mlir/Bytecode/BytecodeWriter.h"
13 #include "mlir/Bytecode/Encoding.h"
14 #include "mlir/IR/AsmState.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/OpDefinition.h"
17 
18 using namespace mlir;
19 using namespace mlir::bytecode::detail;
20 
21 //===----------------------------------------------------------------------===//
22 // NumberingDialectWriter
23 //===----------------------------------------------------------------------===//
24 
25 struct IRNumberingState::NumberingDialectWriter : public DialectBytecodeWriter {
26   NumberingDialectWriter(
27       IRNumberingState &state,
28       llvm::StringMap<std::unique_ptr<DialectVersion>> &dialectVersionMap)
29       : state(state), dialectVersionMap(dialectVersionMap) {}
30 
31   void writeAttribute(Attribute attr) override { state.number(attr); }
32   void writeOptionalAttribute(Attribute attr) override {
33     if (attr)
34       state.number(attr);
35   }
36   void writeType(Type type) override { state.number(type); }
37   void writeResourceHandle(const AsmDialectResourceHandle &resource) override {
38     state.number(resource.getDialect(), resource);
39   }
40 
41   /// Stubbed out methods that are not used for numbering.
42   void writeVarInt(uint64_t) override {}
43   void writeSignedVarInt(int64_t value) override {}
44   void writeAPIntWithKnownWidth(const APInt &value) override {}
45   void writeAPFloatWithKnownSemantics(const APFloat &value) override {}
46   void writeOwnedString(StringRef) override {
47     // TODO: It might be nice to prenumber strings and sort by the number of
48     // references. This could potentially be useful for optimizing things like
49     // file locations.
50   }
51   void writeOwnedBlob(ArrayRef<char> blob) override {}
52   void writeOwnedBool(bool value) override {}
53 
54   int64_t getBytecodeVersion() const override {
55     return state.getDesiredBytecodeVersion();
56   }
57 
58   FailureOr<const DialectVersion *>
59   getDialectVersion(StringRef dialectName) const override {
60     auto dialectEntry = dialectVersionMap.find(dialectName);
61     if (dialectEntry == dialectVersionMap.end())
62       return failure();
63     return dialectEntry->getValue().get();
64   }
65 
66   /// The parent numbering state that is populated by this writer.
67   IRNumberingState &state;
68 
69   /// A map containing dialect version information for each dialect to emit.
70   llvm::StringMap<std::unique_ptr<DialectVersion>> &dialectVersionMap;
71 };
72 
73 //===----------------------------------------------------------------------===//
74 // IR Numbering
75 //===----------------------------------------------------------------------===//
76 
77 /// Group and sort the elements of the given range by their parent dialect. This
78 /// grouping is applied to sub-sections of the ranged defined by how many bytes
79 /// it takes to encode a varint index to that sub-section.
80 template <typename T>
81 static void groupByDialectPerByte(T range) {
82   if (range.empty())
83     return;
84 
85   // A functor used to sort by a given dialect, with a desired dialect to be
86   // ordered first (to better enable sharing of dialects across byte groups).
87   auto sortByDialect = [](unsigned dialectToOrderFirst, const auto &lhs,
88                           const auto &rhs) {
89     if (lhs->dialect->number == dialectToOrderFirst)
90       return rhs->dialect->number != dialectToOrderFirst;
91     if (rhs->dialect->number == dialectToOrderFirst)
92       return false;
93     return lhs->dialect->number < rhs->dialect->number;
94   };
95 
96   unsigned dialectToOrderFirst = 0;
97   size_t elementsInByteGroup = 0;
98   auto iterRange = range;
99   for (unsigned i = 1; i < 9; ++i) {
100     // Update the number of elements in the current byte grouping. Reminder
101     // that varint encodes 7-bits per byte, so that's how we compute the
102     // number of elements in each byte grouping.
103     elementsInByteGroup = (1ULL << (7ULL * i)) - elementsInByteGroup;
104 
105     // Slice out the sub-set of elements that are in the current byte grouping
106     // to be sorted.
107     auto byteSubRange = iterRange.take_front(elementsInByteGroup);
108     iterRange = iterRange.drop_front(byteSubRange.size());
109 
110     // Sort the sub range for this byte.
111     llvm::stable_sort(byteSubRange, [&](const auto &lhs, const auto &rhs) {
112       return sortByDialect(dialectToOrderFirst, lhs, rhs);
113     });
114 
115     // Update the dialect to order first to be the dialect at the end of the
116     // current grouping. This seeks to allow larger dialect groupings across
117     // byte boundaries.
118     dialectToOrderFirst = byteSubRange.back()->dialect->number;
119 
120     // If the data range is now empty, we are done.
121     if (iterRange.empty())
122       break;
123   }
124 
125   // Assign the entry numbers based on the sort order.
126   for (auto [idx, value] : llvm::enumerate(range))
127     value->number = idx;
128 }
129 
130 IRNumberingState::IRNumberingState(Operation *op,
131                                    const BytecodeWriterConfig &config)
132     : config(config) {
133   computeGlobalNumberingState(op);
134 
135   // Number the root operation.
136   number(*op);
137 
138   // A worklist of region contexts to number and the next value id before that
139   // region.
140   SmallVector<std::pair<Region *, unsigned>, 8> numberContext;
141 
142   // Functor to push the regions of the given operation onto the numbering
143   // context.
144   auto addOpRegionsToNumber = [&](Operation *op) {
145     MutableArrayRef<Region> regions = op->getRegions();
146     if (regions.empty())
147       return;
148 
149     // Isolated regions don't share value numbers with their parent, so we can
150     // start numbering these regions at zero.
151     unsigned opFirstValueID = isIsolatedFromAbove(op) ? 0 : nextValueID;
152     for (Region &region : regions)
153       numberContext.emplace_back(&region, opFirstValueID);
154   };
155   addOpRegionsToNumber(op);
156 
157   // Iteratively process each of the nested regions.
158   while (!numberContext.empty()) {
159     Region *region;
160     std::tie(region, nextValueID) = numberContext.pop_back_val();
161     number(*region);
162 
163     // Traverse into nested regions.
164     for (Operation &op : region->getOps())
165       addOpRegionsToNumber(&op);
166   }
167 
168   // Number each of the dialects. For now this is just in the order they were
169   // found, given that the number of dialects on average is small enough to fit
170   // within a singly byte (128). If we ever have real world use cases that have
171   // a huge number of dialects, this could be made more intelligent.
172   for (auto [idx, dialect] : llvm::enumerate(dialects))
173     dialect.second->number = idx;
174 
175   // Number each of the recorded components within each dialect.
176 
177   // First sort by ref count so that the most referenced elements are first. We
178   // try to bias more heavily used elements to the front. This allows for more
179   // frequently referenced things to be encoded using smaller varints.
180   auto sortByRefCountFn = [](const auto &lhs, const auto &rhs) {
181     return lhs->refCount > rhs->refCount;
182   };
183   llvm::stable_sort(orderedAttrs, sortByRefCountFn);
184   llvm::stable_sort(orderedOpNames, sortByRefCountFn);
185   llvm::stable_sort(orderedTypes, sortByRefCountFn);
186 
187   // After that, we apply a secondary ordering based on the parent dialect. This
188   // ordering is applied to sub-sections of the element list defined by how many
189   // bytes it takes to encode a varint index to that sub-section. This allows
190   // for more efficiently encoding components of the same dialect (e.g. we only
191   // have to encode the dialect reference once).
192   groupByDialectPerByte(llvm::MutableArrayRef(orderedAttrs));
193   groupByDialectPerByte(llvm::MutableArrayRef(orderedOpNames));
194   groupByDialectPerByte(llvm::MutableArrayRef(orderedTypes));
195 
196   // Finalize the numbering of the dialect resources.
197   finalizeDialectResourceNumberings(op);
198 }
199 
200 void IRNumberingState::computeGlobalNumberingState(Operation *rootOp) {
201   // A simple state struct tracking data used when walking operations.
202   struct StackState {
203     /// The operation currently being walked.
204     Operation *op;
205 
206     /// The numbering of the operation.
207     OperationNumbering *numbering;
208 
209     /// A flag indicating if the current state or one of its parents has
210     /// unresolved isolation status. This is tracked separately from the
211     /// isIsolatedFromAbove bit on `numbering` because we need to be able to
212     /// handle the given case:
213     ///   top.op {
214     ///     %value = ...
215     ///     middle.op {
216     ///       %value2 = ...
217     ///       inner.op {
218     ///         // Here we mark `inner.op` as not isolated. Note `middle.op`
219     ///         // isn't known not isolated yet.
220     ///         use.op %value2
221     ///
222     ///         // Here inner.op is already known to be non-isolated, but
223     ///         // `middle.op` is now also discovered to be non-isolated.
224     ///         use.op %value
225     ///       }
226     ///     }
227     ///   }
228     bool hasUnresolvedIsolation;
229   };
230 
231   // Compute a global operation ID numbering according to the pre-order walk of
232   // the IR. This is used as reference to construct use-list orders.
233   unsigned operationID = 0;
234 
235   // Walk each of the operations within the IR, tracking a stack of operations
236   // as we recurse into nested regions. This walk method hooks in at two stages
237   // during the walk:
238   //
239   //   BeforeAllRegions:
240   //     Here we generate a numbering for the operation and push it onto the
241   //     stack if it has regions. We also compute the isolation status of parent
242   //     regions at this stage. This is done by checking the parent regions of
243   //     operands used by the operation, and marking each region between the
244   //     the operand region and the current as not isolated. See
245   //     StackState::hasUnresolvedIsolation above for an example.
246   //
247   //   AfterAllRegions:
248   //     Here we pop the operation from the stack, and if it hasn't been marked
249   //     as non-isolated, we mark it as so. A non-isolated use would have been
250   //     found while walking the regions, so it is safe to mark the operation at
251   //     this point.
252   //
253   SmallVector<StackState> opStack;
254   rootOp->walk([&](Operation *op, const WalkStage &stage) {
255     // After visiting all nested regions, we pop the operation from the stack.
256     if (op->getNumRegions() && stage.isAfterAllRegions()) {
257       // If no non-isolated uses were found, we can safely mark this operation
258       // as isolated from above.
259       OperationNumbering *numbering = opStack.pop_back_val().numbering;
260       if (!numbering->isIsolatedFromAbove.has_value())
261         numbering->isIsolatedFromAbove = true;
262       return;
263     }
264 
265     // When visiting before nested regions, we process "IsolatedFromAbove"
266     // checks and compute the number for this operation.
267     if (!stage.isBeforeAllRegions())
268       return;
269     // Update the isolation status of parent regions if any have yet to be
270     // resolved.
271     if (!opStack.empty() && opStack.back().hasUnresolvedIsolation) {
272       Region *parentRegion = op->getParentRegion();
273       for (Value operand : op->getOperands()) {
274         Region *operandRegion = operand.getParentRegion();
275         if (operandRegion == parentRegion)
276           continue;
277         // We've found a use of an operand outside of the current region,
278         // walk the operation stack searching for the parent operation,
279         // marking every region on the way as not isolated.
280         Operation *operandContainerOp = operandRegion->getParentOp();
281         auto it = std::find_if(
282             opStack.rbegin(), opStack.rend(), [=](const StackState &it) {
283               // We only need to mark up to the container region, or the first
284               // that has an unresolved status.
285               return !it.hasUnresolvedIsolation || it.op == operandContainerOp;
286             });
287         assert(it != opStack.rend() && "expected to find the container");
288         for (auto &state : llvm::make_range(opStack.rbegin(), it)) {
289           // If we stopped at a region that knows its isolation status, we can
290           // stop updating the isolation status for the parent regions.
291           state.hasUnresolvedIsolation = it->hasUnresolvedIsolation;
292           state.numbering->isIsolatedFromAbove = false;
293         }
294       }
295     }
296 
297     // Compute the number for this op and push it onto the stack.
298     auto *numbering =
299         new (opAllocator.Allocate()) OperationNumbering(operationID++);
300     if (op->hasTrait<OpTrait::IsIsolatedFromAbove>())
301       numbering->isIsolatedFromAbove = true;
302     operations.try_emplace(op, numbering);
303     if (op->getNumRegions()) {
304       opStack.emplace_back(StackState{
305           op, numbering, !numbering->isIsolatedFromAbove.has_value()});
306     }
307   });
308 }
309 
310 void IRNumberingState::number(Attribute attr) {
311   auto it = attrs.insert({attr, nullptr});
312   if (!it.second) {
313     ++it.first->second->refCount;
314     return;
315   }
316   auto *numbering = new (attrAllocator.Allocate()) AttributeNumbering(attr);
317   it.first->second = numbering;
318   orderedAttrs.push_back(numbering);
319 
320   // Check for OpaqueAttr, which is a dialect-specific attribute that didn't
321   // have a registered dialect when it got created. We don't want to encode this
322   // as the builtin OpaqueAttr, we want to encode it as if the dialect was
323   // actually loaded.
324   if (OpaqueAttr opaqueAttr = dyn_cast<OpaqueAttr>(attr)) {
325     numbering->dialect = &numberDialect(opaqueAttr.getDialectNamespace());
326     return;
327   }
328   numbering->dialect = &numberDialect(&attr.getDialect());
329 
330   // If this attribute will be emitted using the bytecode format, perform a
331   // dummy writing to number any nested components.
332   // TODO: We don't allow custom encodings for mutable attributes right now.
333   if (!attr.hasTrait<AttributeTrait::IsMutable>()) {
334     // Try overriding emission with callbacks.
335     for (const auto &callback : config.getAttributeWriterCallbacks()) {
336       NumberingDialectWriter writer(*this, config.getDialectVersionMap());
337       // The client has the ability to override the group name through the
338       // callback.
339       std::optional<StringRef> groupNameOverride;
340       if (succeeded(callback->write(attr, groupNameOverride, writer))) {
341         if (groupNameOverride.has_value())
342           numbering->dialect = &numberDialect(*groupNameOverride);
343         return;
344       }
345     }
346 
347     if (const auto *interface = numbering->dialect->interface) {
348       NumberingDialectWriter writer(*this, config.getDialectVersionMap());
349       if (succeeded(interface->writeAttribute(attr, writer)))
350         return;
351     }
352   }
353   // If this attribute will be emitted using the fallback, number the nested
354   // dialect resources. We don't number everything (e.g. no nested
355   // attributes/types), because we don't want to encode things we won't decode
356   // (the textual format can't really share much).
357   AsmState tempState(attr.getContext());
358   llvm::raw_null_ostream dummyOS;
359   attr.print(dummyOS, tempState);
360 
361   // Number the used dialect resources.
362   for (const auto &it : tempState.getDialectResources())
363     number(it.getFirst(), it.getSecond().getArrayRef());
364 }
365 
366 void IRNumberingState::number(Block &block) {
367   // Number the arguments of the block.
368   for (BlockArgument arg : block.getArguments()) {
369     valueIDs.try_emplace(arg, nextValueID++);
370     number(arg.getLoc());
371     number(arg.getType());
372   }
373 
374   // Number the operations in this block.
375   unsigned &numOps = blockOperationCounts[&block];
376   for (Operation &op : block) {
377     number(op);
378     ++numOps;
379   }
380 }
381 
382 auto IRNumberingState::numberDialect(Dialect *dialect) -> DialectNumbering & {
383   DialectNumbering *&numbering = registeredDialects[dialect];
384   if (!numbering) {
385     numbering = &numberDialect(dialect->getNamespace());
386     numbering->interface = dyn_cast<BytecodeDialectInterface>(dialect);
387     numbering->asmInterface = dyn_cast<OpAsmDialectInterface>(dialect);
388   }
389   return *numbering;
390 }
391 
392 auto IRNumberingState::numberDialect(StringRef dialect) -> DialectNumbering & {
393   DialectNumbering *&numbering = dialects[dialect];
394   if (!numbering) {
395     numbering = new (dialectAllocator.Allocate())
396         DialectNumbering(dialect, dialects.size() - 1);
397   }
398   return *numbering;
399 }
400 
401 void IRNumberingState::number(Region &region) {
402   if (region.empty())
403     return;
404   size_t firstValueID = nextValueID;
405 
406   // Number the blocks within this region.
407   size_t blockCount = 0;
408   for (auto it : llvm::enumerate(region)) {
409     blockIDs.try_emplace(&it.value(), it.index());
410     number(it.value());
411     ++blockCount;
412   }
413 
414   // Remember the number of blocks and values in this region.
415   regionBlockValueCounts.try_emplace(&region, blockCount,
416                                      nextValueID - firstValueID);
417 }
418 
419 void IRNumberingState::number(Operation &op) {
420   // Number the components of an operation that won't be numbered elsewhere
421   // (e.g. we don't number operands, regions, or successors here).
422   number(op.getName());
423   for (OpResult result : op.getResults()) {
424     valueIDs.try_emplace(result, nextValueID++);
425     number(result.getType());
426   }
427 
428   // Prior to a version with native property encoding, or when properties are
429   // not used, we need to number also the merged dictionary containing both the
430   // inherent and discardable attribute.
431   DictionaryAttr dictAttr;
432   if (config.getDesiredBytecodeVersion() >= bytecode::kNativePropertiesEncoding)
433     dictAttr = op.getRawDictionaryAttrs();
434   else
435     dictAttr = op.getAttrDictionary();
436   // Only number the operation's dictionary if it isn't empty.
437   if (!dictAttr.empty())
438     number(dictAttr);
439 
440   // Visit the operation properties (if any) to make sure referenced attributes
441   // are numbered.
442   if (config.getDesiredBytecodeVersion() >=
443           bytecode::kNativePropertiesEncoding &&
444       op.getPropertiesStorageSize()) {
445     if (op.isRegistered()) {
446       // Operation that have properties *must* implement this interface.
447       auto iface = cast<BytecodeOpInterface>(op);
448       NumberingDialectWriter writer(*this, config.getDialectVersionMap());
449       iface.writeProperties(writer);
450     } else {
451       // Unregistered op are storing properties as an optional attribute.
452       if (Attribute prop = *op.getPropertiesStorage().as<Attribute *>())
453         number(prop);
454     }
455   }
456 
457   number(op.getLoc());
458 }
459 
460 void IRNumberingState::number(OperationName opName) {
461   OpNameNumbering *&numbering = opNames[opName];
462   if (numbering) {
463     ++numbering->refCount;
464     return;
465   }
466   DialectNumbering *dialectNumber = nullptr;
467   if (Dialect *dialect = opName.getDialect())
468     dialectNumber = &numberDialect(dialect);
469   else
470     dialectNumber = &numberDialect(opName.getDialectNamespace());
471 
472   numbering =
473       new (opNameAllocator.Allocate()) OpNameNumbering(dialectNumber, opName);
474   orderedOpNames.push_back(numbering);
475 }
476 
477 void IRNumberingState::number(Type type) {
478   auto it = types.insert({type, nullptr});
479   if (!it.second) {
480     ++it.first->second->refCount;
481     return;
482   }
483   auto *numbering = new (typeAllocator.Allocate()) TypeNumbering(type);
484   it.first->second = numbering;
485   orderedTypes.push_back(numbering);
486 
487   // Check for OpaqueType, which is a dialect-specific type that didn't have a
488   // registered dialect when it got created. We don't want to encode this as the
489   // builtin OpaqueType, we want to encode it as if the dialect was actually
490   // loaded.
491   if (OpaqueType opaqueType = dyn_cast<OpaqueType>(type)) {
492     numbering->dialect = &numberDialect(opaqueType.getDialectNamespace());
493     return;
494   }
495   numbering->dialect = &numberDialect(&type.getDialect());
496 
497   // If this type will be emitted using the bytecode format, perform a dummy
498   // writing to number any nested components.
499   // TODO: We don't allow custom encodings for mutable types right now.
500   if (!type.hasTrait<TypeTrait::IsMutable>()) {
501     // Try overriding emission with callbacks.
502     for (const auto &callback : config.getTypeWriterCallbacks()) {
503       NumberingDialectWriter writer(*this, config.getDialectVersionMap());
504       // The client has the ability to override the group name through the
505       // callback.
506       std::optional<StringRef> groupNameOverride;
507       if (succeeded(callback->write(type, groupNameOverride, writer))) {
508         if (groupNameOverride.has_value())
509           numbering->dialect = &numberDialect(*groupNameOverride);
510         return;
511       }
512     }
513 
514     // If this attribute will be emitted using the bytecode format, perform a
515     // dummy writing to number any nested components.
516     if (const auto *interface = numbering->dialect->interface) {
517       NumberingDialectWriter writer(*this, config.getDialectVersionMap());
518       if (succeeded(interface->writeType(type, writer)))
519         return;
520     }
521   }
522   // If this type will be emitted using the fallback, number the nested dialect
523   // resources. We don't number everything (e.g. no nested attributes/types),
524   // because we don't want to encode things we won't decode (the textual format
525   // can't really share much).
526   AsmState tempState(type.getContext());
527   llvm::raw_null_ostream dummyOS;
528   type.print(dummyOS, tempState);
529 
530   // Number the used dialect resources.
531   for (const auto &it : tempState.getDialectResources())
532     number(it.getFirst(), it.getSecond().getArrayRef());
533 }
534 
535 void IRNumberingState::number(Dialect *dialect,
536                               ArrayRef<AsmDialectResourceHandle> resources) {
537   DialectNumbering &dialectNumber = numberDialect(dialect);
538   assert(
539       dialectNumber.asmInterface &&
540       "expected dialect owning a resource to implement OpAsmDialectInterface");
541 
542   for (const auto &resource : resources) {
543     // Check if this is a newly seen resource.
544     if (!dialectNumber.resources.insert(resource))
545       return;
546 
547     auto *numbering =
548         new (resourceAllocator.Allocate()) DialectResourceNumbering(
549             dialectNumber.asmInterface->getResourceKey(resource));
550     dialectNumber.resourceMap.insert({numbering->key, numbering});
551     dialectResources.try_emplace(resource, numbering);
552   }
553 }
554 
555 int64_t IRNumberingState::getDesiredBytecodeVersion() const {
556   return config.getDesiredBytecodeVersion();
557 }
558 
559 namespace {
560 /// A dummy resource builder used to number dialect resources.
561 struct NumberingResourceBuilder : public AsmResourceBuilder {
562   NumberingResourceBuilder(DialectNumbering *dialect, unsigned &nextResourceID)
563       : dialect(dialect), nextResourceID(nextResourceID) {}
564   ~NumberingResourceBuilder() override = default;
565 
566   void buildBlob(StringRef key, ArrayRef<char>, uint32_t) final {
567     numberEntry(key);
568   }
569   void buildBool(StringRef key, bool) final { numberEntry(key); }
570   void buildString(StringRef key, StringRef) final {
571     // TODO: We could pre-number the value string here as well.
572     numberEntry(key);
573   }
574 
575   /// Number the dialect entry for the given key.
576   void numberEntry(StringRef key) {
577     // TODO: We could pre-number resource key strings here as well.
578 
579     auto *it = dialect->resourceMap.find(key);
580     if (it != dialect->resourceMap.end()) {
581       it->second->number = nextResourceID++;
582       it->second->isDeclaration = false;
583     }
584   }
585 
586   DialectNumbering *dialect;
587   unsigned &nextResourceID;
588 };
589 } // namespace
590 
591 void IRNumberingState::finalizeDialectResourceNumberings(Operation *rootOp) {
592   unsigned nextResourceID = 0;
593   for (DialectNumbering &dialect : getDialects()) {
594     if (!dialect.asmInterface)
595       continue;
596     NumberingResourceBuilder entryBuilder(&dialect, nextResourceID);
597     dialect.asmInterface->buildResources(rootOp, dialect.resources,
598                                          entryBuilder);
599 
600     // Number any resources that weren't added by the dialect. This can happen
601     // if there was no backing data to the resource, but we still want these
602     // resource references to roundtrip, so we number them and indicate that the
603     // data is missing.
604     for (const auto &it : dialect.resourceMap)
605       if (it.second->isDeclaration)
606         it.second->number = nextResourceID++;
607   }
608 }
609