xref: /llvm-project/mlir/lib/Rewrite/ByteCode.cpp (revision 0a81ace0047a2de93e71c82cdf0977fc989660df)
1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MLIR to byte-code generation and the interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "ByteCode.h"
14 #include "mlir/Analysis/Liveness.h"
15 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
16 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/RegionGraphTraits.h"
19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/Format.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include <numeric>
26 #include <optional>
27 
28 #define DEBUG_TYPE "pdl-bytecode"
29 
30 using namespace mlir;
31 using namespace mlir::detail;
32 
33 //===----------------------------------------------------------------------===//
34 // PDLByteCodePattern
35 //===----------------------------------------------------------------------===//
36 
37 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
38                                               PDLPatternConfigSet *configSet,
39                                               ByteCodeAddr rewriterAddr) {
40   PatternBenefit benefit = matchOp.getBenefit();
41   MLIRContext *ctx = matchOp.getContext();
42 
43   // Collect the set of generated operations.
44   SmallVector<StringRef, 8> generatedOps;
45   if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
46     generatedOps =
47         llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
48 
49   // Check to see if this is pattern matches a specific operation type.
50   if (std::optional<StringRef> rootKind = matchOp.getRootKind())
51     return PDLByteCodePattern(rewriterAddr, configSet, *rootKind, benefit, ctx,
52                               generatedOps);
53   return PDLByteCodePattern(rewriterAddr, configSet, MatchAnyOpTypeTag(),
54                             benefit, ctx, generatedOps);
55 }
56 
57 //===----------------------------------------------------------------------===//
58 // PDLByteCodeMutableState
59 //===----------------------------------------------------------------------===//
60 
61 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
62 /// to the position of the pattern within the range returned by
63 /// `PDLByteCode::getPatterns`.
64 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
65                                                    PatternBenefit benefit) {
66   currentPatternBenefits[patternIndex] = benefit;
67 }
68 
69 /// Cleanup any allocated state after a full match/rewrite has been completed.
70 /// This method should be called irregardless of whether the match+rewrite was a
71 /// success or not.
72 void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() {
73   allocatedTypeRangeMemory.clear();
74   allocatedValueRangeMemory.clear();
75 }
76 
77 //===----------------------------------------------------------------------===//
78 // Bytecode OpCodes
79 //===----------------------------------------------------------------------===//
80 
81 namespace {
82 enum OpCode : ByteCodeField {
83   /// Apply an externally registered constraint.
84   ApplyConstraint,
85   /// Apply an externally registered rewrite.
86   ApplyRewrite,
87   /// Check if two generic values are equal.
88   AreEqual,
89   /// Check if two ranges are equal.
90   AreRangesEqual,
91   /// Unconditional branch.
92   Branch,
93   /// Compare the operand count of an operation with a constant.
94   CheckOperandCount,
95   /// Compare the name of an operation with a constant.
96   CheckOperationName,
97   /// Compare the result count of an operation with a constant.
98   CheckResultCount,
99   /// Compare a range of types to a constant range of types.
100   CheckTypes,
101   /// Continue to the next iteration of a loop.
102   Continue,
103   /// Create a type range from a list of constant types.
104   CreateConstantTypeRange,
105   /// Create an operation.
106   CreateOperation,
107   /// Create a type range from a list of dynamic types.
108   CreateDynamicTypeRange,
109   /// Create a value range.
110   CreateDynamicValueRange,
111   /// Erase an operation.
112   EraseOp,
113   /// Extract the op from a range at the specified index.
114   ExtractOp,
115   /// Extract the type from a range at the specified index.
116   ExtractType,
117   /// Extract the value from a range at the specified index.
118   ExtractValue,
119   /// Terminate a matcher or rewrite sequence.
120   Finalize,
121   /// Iterate over a range of values.
122   ForEach,
123   /// Get a specific attribute of an operation.
124   GetAttribute,
125   /// Get the type of an attribute.
126   GetAttributeType,
127   /// Get the defining operation of a value.
128   GetDefiningOp,
129   /// Get a specific operand of an operation.
130   GetOperand0,
131   GetOperand1,
132   GetOperand2,
133   GetOperand3,
134   GetOperandN,
135   /// Get a specific operand group of an operation.
136   GetOperands,
137   /// Get a specific result of an operation.
138   GetResult0,
139   GetResult1,
140   GetResult2,
141   GetResult3,
142   GetResultN,
143   /// Get a specific result group of an operation.
144   GetResults,
145   /// Get the users of a value or a range of values.
146   GetUsers,
147   /// Get the type of a value.
148   GetValueType,
149   /// Get the types of a value range.
150   GetValueRangeTypes,
151   /// Check if a generic value is not null.
152   IsNotNull,
153   /// Record a successful pattern match.
154   RecordMatch,
155   /// Replace an operation.
156   ReplaceOp,
157   /// Compare an attribute with a set of constants.
158   SwitchAttribute,
159   /// Compare the operand count of an operation with a set of constants.
160   SwitchOperandCount,
161   /// Compare the name of an operation with a set of constants.
162   SwitchOperationName,
163   /// Compare the result count of an operation with a set of constants.
164   SwitchResultCount,
165   /// Compare a type with a set of constants.
166   SwitchType,
167   /// Compare a range of types with a set of constants.
168   SwitchTypes,
169 };
170 } // namespace
171 
172 /// A marker used to indicate if an operation should infer types.
173 static constexpr ByteCodeField kInferTypesMarker =
174     std::numeric_limits<ByteCodeField>::max();
175 
176 //===----------------------------------------------------------------------===//
177 // ByteCode Generation
178 //===----------------------------------------------------------------------===//
179 
180 //===----------------------------------------------------------------------===//
181 // Generator
182 
183 namespace {
184 struct ByteCodeLiveRange;
185 struct ByteCodeWriter;
186 
/// Detection alias naming the result type of calling `getAsOpaquePointer()`
/// on a `CandidateT`. It is ill-formed when no such member exists, which is
/// what makes it usable as a "can be converted to an opaque pointer" check.
/// Any additional template arguments are accepted and ignored.
template <typename CandidateT, typename... IgnoredTs>
using has_pointer_traits =
    decltype(std::declval<CandidateT>().getAsOpaquePointer());
190 
191 /// This class represents the main generator for the pattern bytecode.
192 class Generator {
193 public:
194   Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
195             SmallVectorImpl<ByteCodeField> &matcherByteCode,
196             SmallVectorImpl<ByteCodeField> &rewriterByteCode,
197             SmallVectorImpl<PDLByteCodePattern> &patterns,
198             ByteCodeField &maxValueMemoryIndex,
199             ByteCodeField &maxOpRangeMemoryIndex,
200             ByteCodeField &maxTypeRangeMemoryIndex,
201             ByteCodeField &maxValueRangeMemoryIndex,
202             ByteCodeField &maxLoopLevel,
203             llvm::StringMap<PDLConstraintFunction> &constraintFns,
204             llvm::StringMap<PDLRewriteFunction> &rewriteFns,
205             const DenseMap<Operation *, PDLPatternConfigSet *> &configMap)
206       : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
207         rewriterByteCode(rewriterByteCode), patterns(patterns),
208         maxValueMemoryIndex(maxValueMemoryIndex),
209         maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
210         maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
211         maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
212         maxLoopLevel(maxLoopLevel), configMap(configMap) {
213     for (const auto &it : llvm::enumerate(constraintFns))
214       constraintToMemIndex.try_emplace(it.value().first(), it.index());
215     for (const auto &it : llvm::enumerate(rewriteFns))
216       externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
217   }
218 
219   /// Generate the bytecode for the given PDL interpreter module.
220   void generate(ModuleOp module);
221 
222   /// Return the memory index to use for the given value.
223   ByteCodeField &getMemIndex(Value value) {
224     assert(valueToMemIndex.count(value) &&
225            "expected memory index to be assigned");
226     return valueToMemIndex[value];
227   }
228 
229   /// Return the range memory index used to store the given range value.
230   ByteCodeField &getRangeStorageIndex(Value value) {
231     assert(valueToRangeIndex.count(value) &&
232            "expected range index to be assigned");
233     return valueToRangeIndex[value];
234   }
235 
236   /// Return an index to use when referring to the given data that is uniqued in
237   /// the MLIR context.
238   template <typename T>
239   std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
240   getMemIndex(T val) {
241     const void *opaqueVal = val.getAsOpaquePointer();
242 
243     // Get or insert a reference to this value.
244     auto it = uniquedDataToMemIndex.try_emplace(
245         opaqueVal, maxValueMemoryIndex + uniquedData.size());
246     if (it.second)
247       uniquedData.push_back(opaqueVal);
248     return it.first->second;
249   }
250 
251 private:
252   /// Allocate memory indices for the results of operations within the matcher
253   /// and rewriters.
254   void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
255                              ModuleOp rewriterModule);
256 
257   /// Generate the bytecode for the given operation.
258   void generate(Region *region, ByteCodeWriter &writer);
259   void generate(Operation *op, ByteCodeWriter &writer);
260   void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
261   void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
262   void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
263   void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
264   void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
265   void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
266   void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
267   void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
268   void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
269   void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
270   void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
271   void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
272   void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
273   void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer);
274   void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
275   void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
276   void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
277   void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
278   void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
279   void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
280   void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
281   void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
282   void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
283   void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
284   void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
285   void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
286   void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
287   void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
288   void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
289   void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
290   void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
291   void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
292   void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
293   void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
294   void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
295   void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
296   void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
297   void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
298 
299   /// Mapping from value to its corresponding memory index.
300   DenseMap<Value, ByteCodeField> valueToMemIndex;
301 
302   /// Mapping from a range value to its corresponding range storage index.
303   DenseMap<Value, ByteCodeField> valueToRangeIndex;
304 
305   /// Mapping from the name of an externally registered rewrite to its index in
306   /// the bytecode registry.
307   llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
308 
309   /// Mapping from the name of an externally registered constraint to its index
310   /// in the bytecode registry.
311   llvm::StringMap<ByteCodeField> constraintToMemIndex;
312 
313   /// Mapping from rewriter function name to the bytecode address of the
314   /// rewriter function in byte.
315   llvm::StringMap<ByteCodeAddr> rewriterToAddr;
316 
317   /// Mapping from a uniqued storage object to its memory index within
318   /// `uniquedData`.
319   DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
320 
321   /// The current level of the foreach loop.
322   ByteCodeField curLoopLevel = 0;
323 
324   /// The current MLIR context.
325   MLIRContext *ctx;
326 
327   /// Mapping from block to its address.
328   DenseMap<Block *, ByteCodeAddr> blockToAddr;
329 
330   /// Data of the ByteCode class to be populated.
331   std::vector<const void *> &uniquedData;
332   SmallVectorImpl<ByteCodeField> &matcherByteCode;
333   SmallVectorImpl<ByteCodeField> &rewriterByteCode;
334   SmallVectorImpl<PDLByteCodePattern> &patterns;
335   ByteCodeField &maxValueMemoryIndex;
336   ByteCodeField &maxOpRangeMemoryIndex;
337   ByteCodeField &maxTypeRangeMemoryIndex;
338   ByteCodeField &maxValueRangeMemoryIndex;
339   ByteCodeField &maxLoopLevel;
340 
341   /// A map of pattern configurations.
342   const DenseMap<Operation *, PDLPatternConfigSet *> &configMap;
343 };
344 
345 /// This class provides utilities for writing a bytecode stream.
346 struct ByteCodeWriter {
347   ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
348       : bytecode(bytecode), generator(generator) {}
349 
350   /// Append a field to the bytecode.
351   void append(ByteCodeField field) { bytecode.push_back(field); }
352   void append(OpCode opCode) { bytecode.push_back(opCode); }
353 
354   /// Append an address to the bytecode.
355   void append(ByteCodeAddr field) {
356     static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
357                   "unexpected ByteCode address size");
358 
359     ByteCodeField fieldParts[2];
360     std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
361     bytecode.append({fieldParts[0], fieldParts[1]});
362   }
363 
364   /// Append a single successor to the bytecode, the exact address will need to
365   /// be resolved later.
366   void append(Block *successor) {
367     // Add back a reference to the successor so that the address can be resolved
368     // later.
369     unresolvedSuccessorRefs[successor].push_back(bytecode.size());
370     append(ByteCodeAddr(0));
371   }
372 
373   /// Append a successor range to the bytecode, the exact address will need to
374   /// be resolved later.
375   void append(SuccessorRange successors) {
376     for (Block *successor : successors)
377       append(successor);
378   }
379 
380   /// Append a range of values that will be read as generic PDLValues.
381   void appendPDLValueList(OperandRange values) {
382     bytecode.push_back(values.size());
383     for (Value value : values)
384       appendPDLValue(value);
385   }
386 
387   /// Append a value as a PDLValue.
388   void appendPDLValue(Value value) {
389     appendPDLValueKind(value);
390     append(value);
391   }
392 
393   /// Append the PDLValue::Kind of the given value.
394   void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); }
395 
396   /// Append the PDLValue::Kind of the given type.
397   void appendPDLValueKind(Type type) {
398     PDLValue::Kind kind =
399         TypeSwitch<Type, PDLValue::Kind>(type)
400             .Case<pdl::AttributeType>(
401                 [](Type) { return PDLValue::Kind::Attribute; })
402             .Case<pdl::OperationType>(
403                 [](Type) { return PDLValue::Kind::Operation; })
404             .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
405               if (rangeTy.getElementType().isa<pdl::TypeType>())
406                 return PDLValue::Kind::TypeRange;
407               return PDLValue::Kind::ValueRange;
408             })
409             .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
410             .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
411     bytecode.push_back(static_cast<ByteCodeField>(kind));
412   }
413 
414   /// Append a value that will be stored in a memory slot and not inline within
415   /// the bytecode.
416   template <typename T>
417   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
418                    std::is_pointer<T>::value>
419   append(T value) {
420     bytecode.push_back(generator.getMemIndex(value));
421   }
422 
423   /// Append a range of values.
424   template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
425   std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
426   append(T range) {
427     bytecode.push_back(llvm::size(range));
428     for (auto it : range)
429       append(it);
430   }
431 
432   /// Append a variadic number of fields to the bytecode.
433   template <typename FieldTy, typename Field2Ty, typename... FieldTys>
434   void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
435     append(field);
436     append(field2, fields...);
437   }
438 
439   /// Appends a value as a pointer, stored inline within the bytecode.
440   template <typename T>
441   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
442   appendInline(T value) {
443     constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField);
444     const void *pointer = value.getAsOpaquePointer();
445     ByteCodeField fieldParts[numParts];
446     std::memcpy(fieldParts, &pointer, sizeof(const void *));
447     bytecode.append(fieldParts, fieldParts + numParts);
448   }
449 
450   /// Successor references in the bytecode that have yet to be resolved.
451   DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
452 
453   /// The underlying bytecode buffer.
454   SmallVectorImpl<ByteCodeField> &bytecode;
455 
456   /// The main generator producing PDL.
457   Generator &generator;
458 };
459 
460 /// This class represents a live range of PDL Interpreter values, containing
461 /// information about when values are live within a match/rewrite.
462 struct ByteCodeLiveRange {
463   using Set = llvm::IntervalMap<uint64_t, char, 16>;
464   using Allocator = Set::Allocator;
465 
466   ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
467 
468   /// Union this live range with the one provided.
469   void unionWith(const ByteCodeLiveRange &rhs) {
470     for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
471          ++it)
472       liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0);
473   }
474 
475   /// Returns true if this range overlaps with the one provided.
476   bool overlaps(const ByteCodeLiveRange &rhs) const {
477     return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
478         .valid();
479   }
480 
481   /// A map representing the ranges of the match/rewrite that a value is live in
482   /// the interpreter.
483   ///
484   /// We use std::unique_ptr here, because IntervalMap does not provide a
485   /// correct copy or move constructor. We can eliminate the pointer once
486   /// https://reviews.llvm.org/D113240 lands.
487   std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
488 
489   /// The operation range storage index for this range.
490   std::optional<unsigned> opRangeIndex;
491 
492   /// The type range storage index for this range.
493   std::optional<unsigned> typeRangeIndex;
494 
495   /// The value range storage index for this range.
496   std::optional<unsigned> valueRangeIndex;
497 };
498 } // namespace
499 
500 void Generator::generate(ModuleOp module) {
501   auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
502       pdl_interp::PDLInterpDialect::getMatcherFunctionName());
503   ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
504       pdl_interp::PDLInterpDialect::getRewriterModuleName());
505   assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
506 
507   // Allocate memory indices for the results of operations within the matcher
508   // and rewriters.
509   allocateMemoryIndices(matcherFunc, rewriterModule);
510 
511   // Generate code for the rewriter functions.
512   ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
513   for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
514     rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
515     for (Operation &op : rewriterFunc.getOps())
516       generate(&op, rewriterByteCodeWriter);
517   }
518   assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
519          "unexpected branches in rewriter function");
520 
521   // Generate code for the matcher function.
522   ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
523   generate(&matcherFunc.getBody(), matcherByteCodeWriter);
524 
525   // Resolve successor references in the matcher.
526   for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
527     ByteCodeAddr addr = blockToAddr[it.first];
528     for (unsigned offsetToFix : it.second)
529       std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
530   }
531 }
532 
533 void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
534                                       ModuleOp rewriterModule) {
535   // Rewriters use simplistic allocation scheme that simply assigns an index to
536   // each result.
537   for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
538     ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
539     auto processRewriterValue = [&](Value val) {
540       valueToMemIndex.try_emplace(val, index++);
541       if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) {
542         Type elementTy = rangeType.getElementType();
543         if (elementTy.isa<pdl::TypeType>())
544           valueToRangeIndex.try_emplace(val, typeRangeIndex++);
545         else if (elementTy.isa<pdl::ValueType>())
546           valueToRangeIndex.try_emplace(val, valueRangeIndex++);
547       }
548     };
549 
550     for (BlockArgument arg : rewriterFunc.getArguments())
551       processRewriterValue(arg);
552     rewriterFunc.getBody().walk([&](Operation *op) {
553       for (Value result : op->getResults())
554         processRewriterValue(result);
555     });
556     if (index > maxValueMemoryIndex)
557       maxValueMemoryIndex = index;
558     if (typeRangeIndex > maxTypeRangeMemoryIndex)
559       maxTypeRangeMemoryIndex = typeRangeIndex;
560     if (valueRangeIndex > maxValueRangeMemoryIndex)
561       maxValueRangeMemoryIndex = valueRangeIndex;
562   }
563 
564   // The matcher function uses a more sophisticated numbering that tries to
565   // minimize the number of memory indices assigned. This is done by determining
566   // a live range of the values within the matcher, then the allocation is just
567   // finding the minimal number of overlapping live ranges. This is essentially
568   // a simplified form of register allocation where we don't necessarily have a
569   // limited number of registers, but we still want to minimize the number used.
570   DenseMap<Operation *, unsigned> opToFirstIndex;
571   DenseMap<Operation *, unsigned> opToLastIndex;
572 
573   // A custom walk that marks the first and the last index of each operation.
574   // The entry marks the beginning of the liveness range for this operation,
575   // followed by nested operations, followed by the end of the liveness range.
576   unsigned index = 0;
577   llvm::unique_function<void(Operation *)> walk = [&](Operation *op) {
578     opToFirstIndex.try_emplace(op, index++);
579     for (Region &region : op->getRegions())
580       for (Block &block : region.getBlocks())
581         for (Operation &nested : block)
582           walk(&nested);
583     opToLastIndex.try_emplace(op, index++);
584   };
585   walk(matcherFunc);
586 
587   // Liveness info for each of the defs within the matcher.
588   ByteCodeLiveRange::Allocator allocator;
589   DenseMap<Value, ByteCodeLiveRange> valueDefRanges;
590 
591   // Assign the root operation being matched to slot 0.
592   BlockArgument rootOpArg = matcherFunc.getArgument(0);
593   valueToMemIndex[rootOpArg] = 0;
594 
595   // Walk each of the blocks, computing the def interval that the value is used.
596   Liveness matcherLiveness(matcherFunc);
597   matcherFunc->walk([&](Block *block) {
598     const LivenessBlockInfo *info = matcherLiveness.getLiveness(block);
599     assert(info && "expected liveness info for block");
600     auto processValue = [&](Value value, Operation *firstUseOrDef) {
601       // We don't need to process the root op argument, this value is always
602       // assigned to the first memory slot.
603       if (value == rootOpArg)
604         return;
605 
606       // Set indices for the range of this block that the value is used.
607       auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
608       defRangeIt->second.liveness->insert(
609           opToFirstIndex[firstUseOrDef],
610           opToLastIndex[info->getEndOperation(value, firstUseOrDef)],
611           /*dummyValue*/ 0);
612 
613       // Check to see if this value is a range type.
614       if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) {
615         Type eleType = rangeTy.getElementType();
616         if (eleType.isa<pdl::OperationType>())
617           defRangeIt->second.opRangeIndex = 0;
618         else if (eleType.isa<pdl::TypeType>())
619           defRangeIt->second.typeRangeIndex = 0;
620         else if (eleType.isa<pdl::ValueType>())
621           defRangeIt->second.valueRangeIndex = 0;
622       }
623     };
624 
625     // Process the live-ins of this block.
626     for (Value liveIn : info->in()) {
627       // Only process the value if it has been defined in the current region.
628       // Other values that span across pdl_interp.foreach will be added higher
629       // up. This ensures that the we keep them alive for the entire duration
630       // of the loop.
631       if (liveIn.getParentRegion() == block->getParent())
632         processValue(liveIn, &block->front());
633     }
634 
635     // Process the block arguments for the entry block (those are not live-in).
636     if (block->isEntryBlock()) {
637       for (Value argument : block->getArguments())
638         processValue(argument, &block->front());
639     }
640 
641     // Process any new defs within this block.
642     for (Operation &op : *block)
643       for (Value result : op.getResults())
644         processValue(result, &op);
645   });
646 
647   // Greedily allocate memory slots using the computed def live ranges.
648   std::vector<ByteCodeLiveRange> allocatedIndices;
649 
650   // The number of memory indices currently allocated (and its next value).
651   // Recall that the root gets allocated memory index 0.
652   ByteCodeField numIndices = 1;
653 
654   // The number of memory ranges of various types (and their next values).
655   ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
656 
657   for (auto &defIt : valueDefRanges) {
658     ByteCodeField &memIndex = valueToMemIndex[defIt.first];
659     ByteCodeLiveRange &defRange = defIt.second;
660 
661     // Try to allocate to an existing index.
662     for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) {
663       ByteCodeLiveRange &existingRange = existingIndexIt.value();
664       if (!defRange.overlaps(existingRange)) {
665         existingRange.unionWith(defRange);
666         memIndex = existingIndexIt.index() + 1;
667 
668         if (defRange.opRangeIndex) {
669           if (!existingRange.opRangeIndex)
670             existingRange.opRangeIndex = numOpRanges++;
671           valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
672         } else if (defRange.typeRangeIndex) {
673           if (!existingRange.typeRangeIndex)
674             existingRange.typeRangeIndex = numTypeRanges++;
675           valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
676         } else if (defRange.valueRangeIndex) {
677           if (!existingRange.valueRangeIndex)
678             existingRange.valueRangeIndex = numValueRanges++;
679           valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
680         }
681         break;
682       }
683     }
684 
685     // If no existing index could be used, add a new one.
686     if (memIndex == 0) {
687       allocatedIndices.emplace_back(allocator);
688       ByteCodeLiveRange &newRange = allocatedIndices.back();
689       newRange.unionWith(defRange);
690 
691       // Allocate an index for op/type/value ranges.
692       if (defRange.opRangeIndex) {
693         newRange.opRangeIndex = numOpRanges;
694         valueToRangeIndex[defIt.first] = numOpRanges++;
695       } else if (defRange.typeRangeIndex) {
696         newRange.typeRangeIndex = numTypeRanges;
697         valueToRangeIndex[defIt.first] = numTypeRanges++;
698       } else if (defRange.valueRangeIndex) {
699         newRange.valueRangeIndex = numValueRanges;
700         valueToRangeIndex[defIt.first] = numValueRanges++;
701       }
702 
703       memIndex = allocatedIndices.size();
704       ++numIndices;
705     }
706   }
707 
708   // Print the index usage and ensure that we did not run out of index space.
709   LLVM_DEBUG({
710     llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices "
711                  << "(down from initial " << valueDefRanges.size() << ").\n";
712   });
713   assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
714          "Ran out of memory for allocated indices");
715 
716   // Update the max number of indices.
717   if (numIndices > maxValueMemoryIndex)
718     maxValueMemoryIndex = numIndices;
719   if (numOpRanges > maxOpRangeMemoryIndex)
720     maxOpRangeMemoryIndex = numOpRanges;
721   if (numTypeRanges > maxTypeRangeMemoryIndex)
722     maxTypeRangeMemoryIndex = numTypeRanges;
723   if (numValueRanges > maxValueRangeMemoryIndex)
724     maxValueRangeMemoryIndex = numValueRanges;
725 }
726 
727 void Generator::generate(Region *region, ByteCodeWriter &writer) {
728   llvm::ReversePostOrderTraversal<Region *> rpot(region);
729   for (Block *block : rpot) {
730     // Keep track of where this block begins within the matcher function.
731     blockToAddr.try_emplace(block, matcherByteCode.size());
732     for (Operation &op : *block)
733       generate(&op, writer);
734   }
735 }
736 
737 void Generator::generate(Operation *op, ByteCodeWriter &writer) {
738   LLVM_DEBUG({
739     // The following list must contain all the operations that do not
740     // produce any bytecode.
741     if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
742       writer.appendInline(op->getLoc());
743   });
744   TypeSwitch<Operation *>(op)
745       .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
746             pdl_interp::AreEqualOp, pdl_interp::BranchOp,
747             pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
748             pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
749             pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
750             pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
751             pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp,
752             pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
753             pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
754             pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
755             pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
756             pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
757             pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
758             pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
759             pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
760             pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
761             pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
762             pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
763             pdl_interp::SwitchResultCountOp>(
764           [&](auto interpOp) { this->generate(interpOp, writer); })
765       .Default([](Operation *) {
766         llvm_unreachable("unknown `pdl_interp` operation");
767       });
768 }
769 
770 void Generator::generate(pdl_interp::ApplyConstraintOp op,
771                          ByteCodeWriter &writer) {
772   assert(constraintToMemIndex.count(op.getName()) &&
773          "expected index for constraint function");
774   writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
775   writer.appendPDLValueList(op.getArgs());
776   writer.append(op.getSuccessors());
777 }
778 void Generator::generate(pdl_interp::ApplyRewriteOp op,
779                          ByteCodeWriter &writer) {
780   assert(externalRewriterToMemIndex.count(op.getName()) &&
781          "expected index for rewrite function");
782   writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
783   writer.appendPDLValueList(op.getArgs());
784 
785   ResultRange results = op.getResults();
786   writer.append(ByteCodeField(results.size()));
787   for (Value result : results) {
788     // In debug mode we also record the expected kind of the result, so that we
789     // can provide extra verification of the native rewrite function.
790 #ifndef NDEBUG
791     writer.appendPDLValueKind(result);
792 #endif
793 
794     // Range results also need to append the range storage index.
795     if (result.getType().isa<pdl::RangeType>())
796       writer.append(getRangeStorageIndex(result));
797     writer.append(result);
798   }
799 }
800 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
801   Value lhs = op.getLhs();
802   if (lhs.getType().isa<pdl::RangeType>()) {
803     writer.append(OpCode::AreRangesEqual);
804     writer.appendPDLValueKind(lhs);
805     writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
806     return;
807   }
808 
809   writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
810 }
811 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
812   writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
813 }
814 void Generator::generate(pdl_interp::CheckAttributeOp op,
815                          ByteCodeWriter &writer) {
816   writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
817                 op.getSuccessors());
818 }
819 void Generator::generate(pdl_interp::CheckOperandCountOp op,
820                          ByteCodeWriter &writer) {
821   writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
822                 static_cast<ByteCodeField>(op.getCompareAtLeast()),
823                 op.getSuccessors());
824 }
825 void Generator::generate(pdl_interp::CheckOperationNameOp op,
826                          ByteCodeWriter &writer) {
827   writer.append(OpCode::CheckOperationName, op.getInputOp(),
828                 OperationName(op.getName(), ctx), op.getSuccessors());
829 }
830 void Generator::generate(pdl_interp::CheckResultCountOp op,
831                          ByteCodeWriter &writer) {
832   writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
833                 static_cast<ByteCodeField>(op.getCompareAtLeast()),
834                 op.getSuccessors());
835 }
836 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
837   writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
838                 op.getSuccessors());
839 }
840 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
841   writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
842                 op.getSuccessors());
843 }
844 void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
845   assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
846   writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
847 }
848 void Generator::generate(pdl_interp::CreateAttributeOp op,
849                          ByteCodeWriter &writer) {
850   // Simply repoint the memory index of the result to the constant.
851   getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
852 }
853 void Generator::generate(pdl_interp::CreateOperationOp op,
854                          ByteCodeWriter &writer) {
855   writer.append(OpCode::CreateOperation, op.getResultOp(),
856                 OperationName(op.getName(), ctx));
857   writer.appendPDLValueList(op.getInputOperands());
858 
859   // Add the attributes.
860   OperandRange attributes = op.getInputAttributes();
861   writer.append(static_cast<ByteCodeField>(attributes.size()));
862   for (auto it : llvm::zip(op.getInputAttributeNames(), attributes))
863     writer.append(std::get<0>(it), std::get<1>(it));
864 
865   // Add the result types. If the operation has inferred results, we use a
866   // marker "size" value. Otherwise, we add the list of explicit result types.
867   if (op.getInferredResultTypes())
868     writer.append(kInferTypesMarker);
869   else
870     writer.appendPDLValueList(op.getInputResultTypes());
871 }
872 void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) {
873   // Append the correct opcode for the range type.
874   TypeSwitch<Type>(op.getType().getElementType())
875       .Case(
876           [&](pdl::TypeType) { writer.append(OpCode::CreateDynamicTypeRange); })
877       .Case([&](pdl::ValueType) {
878         writer.append(OpCode::CreateDynamicValueRange);
879       });
880 
881   writer.append(op.getResult(), getRangeStorageIndex(op.getResult()));
882   writer.appendPDLValueList(op->getOperands());
883 }
884 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
885   // Simply repoint the memory index of the result to the constant.
886   getMemIndex(op.getResult()) = getMemIndex(op.getValue());
887 }
888 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
889   writer.append(OpCode::CreateConstantTypeRange, op.getResult(),
890                 getRangeStorageIndex(op.getResult()), op.getValue());
891 }
892 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
893   writer.append(OpCode::EraseOp, op.getInputOp());
894 }
895 void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
896   OpCode opCode =
897       TypeSwitch<Type, OpCode>(op.getResult().getType())
898           .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
899           .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
900           .Case([](pdl::TypeType) { return OpCode::ExtractType; })
901           .Default([](Type) -> OpCode {
902             llvm_unreachable("unsupported element type");
903           });
904   writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
905 }
906 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
907   writer.append(OpCode::Finalize);
908 }
909 void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
910   BlockArgument arg = op.getLoopVariable();
911   writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
912   writer.appendPDLValueKind(arg.getType());
913   writer.append(curLoopLevel, op.getSuccessor());
914   ++curLoopLevel;
915   if (curLoopLevel > maxLoopLevel)
916     maxLoopLevel = curLoopLevel;
917   generate(&op.getRegion(), writer);
918   --curLoopLevel;
919 }
920 void Generator::generate(pdl_interp::GetAttributeOp op,
921                          ByteCodeWriter &writer) {
922   writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
923                 op.getNameAttr());
924 }
925 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
926                          ByteCodeWriter &writer) {
927   writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
928 }
929 void Generator::generate(pdl_interp::GetDefiningOpOp op,
930                          ByteCodeWriter &writer) {
931   writer.append(OpCode::GetDefiningOp, op.getInputOp());
932   writer.appendPDLValue(op.getValue());
933 }
934 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
935   uint32_t index = op.getIndex();
936   if (index < 4)
937     writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
938   else
939     writer.append(OpCode::GetOperandN, index);
940   writer.append(op.getInputOp(), op.getValue());
941 }
942 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
943   Value result = op.getValue();
944   std::optional<uint32_t> index = op.getIndex();
945   writer.append(OpCode::GetOperands,
946                 index.value_or(std::numeric_limits<uint32_t>::max()),
947                 op.getInputOp());
948   if (result.getType().isa<pdl::RangeType>())
949     writer.append(getRangeStorageIndex(result));
950   else
951     writer.append(std::numeric_limits<ByteCodeField>::max());
952   writer.append(result);
953 }
954 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
955   uint32_t index = op.getIndex();
956   if (index < 4)
957     writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
958   else
959     writer.append(OpCode::GetResultN, index);
960   writer.append(op.getInputOp(), op.getValue());
961 }
962 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
963   Value result = op.getValue();
964   std::optional<uint32_t> index = op.getIndex();
965   writer.append(OpCode::GetResults,
966                 index.value_or(std::numeric_limits<uint32_t>::max()),
967                 op.getInputOp());
968   if (result.getType().isa<pdl::RangeType>())
969     writer.append(getRangeStorageIndex(result));
970   else
971     writer.append(std::numeric_limits<ByteCodeField>::max());
972   writer.append(result);
973 }
974 void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
975   Value operations = op.getOperations();
976   ByteCodeField rangeIndex = getRangeStorageIndex(operations);
977   writer.append(OpCode::GetUsers, operations, rangeIndex);
978   writer.appendPDLValue(op.getValue());
979 }
980 void Generator::generate(pdl_interp::GetValueTypeOp op,
981                          ByteCodeWriter &writer) {
982   if (op.getType().isa<pdl::RangeType>()) {
983     Value result = op.getResult();
984     writer.append(OpCode::GetValueRangeTypes, result,
985                   getRangeStorageIndex(result), op.getValue());
986   } else {
987     writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
988   }
989 }
990 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
991   writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
992 }
993 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
994   ByteCodeField patternIndex = patterns.size();
995   patterns.emplace_back(PDLByteCodePattern::create(
996       op, configMap.lookup(op),
997       rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
998   writer.append(OpCode::RecordMatch, patternIndex,
999                 SuccessorRange(op.getOperation()), op.getMatchedOps());
1000   writer.appendPDLValueList(op.getInputs());
1001 }
1002 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
1003   writer.append(OpCode::ReplaceOp, op.getInputOp());
1004   writer.appendPDLValueList(op.getReplValues());
1005 }
1006 void Generator::generate(pdl_interp::SwitchAttributeOp op,
1007                          ByteCodeWriter &writer) {
1008   writer.append(OpCode::SwitchAttribute, op.getAttribute(),
1009                 op.getCaseValuesAttr(), op.getSuccessors());
1010 }
1011 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
1012                          ByteCodeWriter &writer) {
1013   writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
1014                 op.getCaseValuesAttr(), op.getSuccessors());
1015 }
1016 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
1017                          ByteCodeWriter &writer) {
1018   auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) {
1019     return OperationName(attr.cast<StringAttr>().getValue(), ctx);
1020   });
1021   writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
1022                 op.getSuccessors());
1023 }
1024 void Generator::generate(pdl_interp::SwitchResultCountOp op,
1025                          ByteCodeWriter &writer) {
1026   writer.append(OpCode::SwitchResultCount, op.getInputOp(),
1027                 op.getCaseValuesAttr(), op.getSuccessors());
1028 }
1029 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
1030   writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
1031                 op.getSuccessors());
1032 }
1033 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
1034   writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
1035                 op.getSuccessors());
1036 }
1037 
1038 //===----------------------------------------------------------------------===//
1039 // PDLByteCode
1040 //===----------------------------------------------------------------------===//
1041 
1042 PDLByteCode::PDLByteCode(
1043     ModuleOp module, SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
1044     const DenseMap<Operation *, PDLPatternConfigSet *> &configMap,
1045     llvm::StringMap<PDLConstraintFunction> constraintFns,
1046     llvm::StringMap<PDLRewriteFunction> rewriteFns)
1047     : configs(std::move(configs)) {
1048   Generator generator(module.getContext(), uniquedData, matcherByteCode,
1049                       rewriterByteCode, patterns, maxValueMemoryIndex,
1050                       maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
1051                       maxLoopLevel, constraintFns, rewriteFns, configMap);
1052   generator.generate(module);
1053 
1054   // Initialize the external functions.
1055   for (auto &it : constraintFns)
1056     constraintFunctions.push_back(std::move(it.second));
1057   for (auto &it : rewriteFns)
1058     rewriteFunctions.push_back(std::move(it.second));
1059 }
1060 
1061 /// Initialize the given state such that it can be used to execute the current
1062 /// bytecode.
1063 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
1064   state.memory.resize(maxValueMemoryIndex, nullptr);
1065   state.opRangeMemory.resize(maxOpRangeCount);
1066   state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
1067   state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
1068   state.loopIndex.resize(maxLoopLevel, 0);
1069   state.currentPatternBenefits.reserve(patterns.size());
1070   for (const PDLByteCodePattern &pattern : patterns)
1071     state.currentPatternBenefits.push_back(pattern.getBenefit());
1072 }
1073 
//===----------------------------------------------------------------------===//
// ByteCode Execution
//===----------------------------------------------------------------------===//
1076 
1077 namespace {
/// This class provides support for executing a bytecode stream.
///
/// The executor keeps a cursor (`curCodeIt`) into the bytecode and offers
/// helpers to decode fields, look up values in memory, and jump to successor
/// addresses. The `execute*` members implement the individual opcodes; they
/// are defined out of line.
class ByteCodeExecutor {
public:
  /// Construct an executor that starts decoding at `curCodeIt`. The memory
  /// buffers, the bytecode and the pattern/function tables are held as array
  /// views or references to the arguments.
  ByteCodeExecutor(
      const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
      MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory,
      MutableArrayRef<TypeRange> typeRangeMemory,
      std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
      MutableArrayRef<ValueRange> valueRangeMemory,
      std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
      MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory,
      ArrayRef<ByteCodeField> code,
      ArrayRef<PatternBenefit> currentPatternBenefits,
      ArrayRef<PDLByteCodePattern> patterns,
      ArrayRef<PDLConstraintFunction> constraintFunctions,
      ArrayRef<PDLRewriteFunction> rewriteFunctions)
      : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
        typeRangeMemory(typeRangeMemory),
        allocatedTypeRangeMemory(allocatedTypeRangeMemory),
        valueRangeMemory(valueRangeMemory),
        allocatedValueRangeMemory(allocatedValueRangeMemory),
        loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
        currentPatternBenefits(currentPatternBenefits), patterns(patterns),
        constraintFunctions(constraintFunctions),
        rewriteFunctions(rewriteFunctions) {}

  /// Start executing the code at the current bytecode index. `matches` is an
  /// optional field provided when this function is executed in a matching
  /// context.
  LogicalResult
  execute(PatternRewriter &rewriter,
          SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
          std::optional<Location> mainRewriteLoc = {});

private:
  /// Internal implementation of executing each of the bytecode commands.
  void executeApplyConstraint(PatternRewriter &rewriter);
  LogicalResult executeApplyRewrite(PatternRewriter &rewriter);
  void executeAreEqual();
  void executeAreRangesEqual();
  void executeBranch();
  void executeCheckOperandCount();
  void executeCheckOperationName();
  void executeCheckResultCount();
  void executeCheckTypes();
  void executeContinue();
  void executeCreateConstantTypeRange();
  void executeCreateOperation(PatternRewriter &rewriter,
                              Location mainRewriteLoc);
  template <typename T>
  void executeDynamicCreateRange(StringRef type);
  void executeEraseOp(PatternRewriter &rewriter);
  template <typename T, typename Range, PDLValue::Kind kind>
  void executeExtract();
  void executeFinalize();
  void executeForEach();
  void executeGetAttribute();
  void executeGetAttributeType();
  void executeGetDefiningOp();
  void executeGetOperand(unsigned index);
  void executeGetOperands();
  void executeGetResult(unsigned index);
  void executeGetResults();
  void executeGetUsers();
  void executeGetValueType();
  void executeGetValueRangeTypes();
  void executeIsNotNull();
  void executeRecordMatch(PatternRewriter &rewriter,
                          SmallVectorImpl<PDLByteCode::MatchResult> &matches);
  void executeReplaceOp(PatternRewriter &rewriter);
  void executeSwitchAttribute();
  void executeSwitchOperandCount();
  void executeSwitchOperationName();
  void executeSwitchResultCount();
  void executeSwitchType();
  void executeSwitchTypes();

  /// Pushes a code iterator to the stack.
  void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }

  /// Pops a code iterator from the stack and makes it the current position.
  /// The stack must not be empty; this is checked with an assertion.
  void popCodeIt() {
    assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack");
    curCodeIt = resumeCodeIt.back();
    resumeCodeIt.pop_back();
  }

  /// Return the bytecode iterator at the start of the current op code.
  const ByteCodeField *getPrevCodeIt() const {
    // The `return` below only takes effect when the body of LLVM_DEBUG is
    // executed; otherwise control falls through to the second `return`.
    LLVM_DEBUG({
      // Account for the op code and the Location stored inline.
      return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField);
    });

    // Account for the op code only.
    return curCodeIt - 1;
  }

  /// Read a value from the bytecode buffer, optionally skipping a certain
  /// number of prefix values. These methods always update the buffer to point
  /// to the next field after the read data.
  template <typename T = ByteCodeField>
  T read(size_t skipN = 0) {
    curCodeIt += skipN;
    return readImpl<T>();
  }
  /// Non-template convenience overload that reads a single raw field.
  ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }

  /// Read a list of values from the bytecode buffer. The element count is read
  /// first as a single field. Any previous contents of `list` are discarded.
  template <typename ValueT, typename T>
  void readList(SmallVectorImpl<T> &list) {
    list.clear();
    for (unsigned i = 0, e = read(); i != e; ++i)
      list.push_back(read<ValueT>());
  }

  /// Read a list of values from the bytecode buffer. The values may be encoded
  /// either as a single element or a range of elements.
  /// Note that, unlike the generic overload above, these two overloads append
  /// to `list` without clearing it first.
  void readList(SmallVectorImpl<Type> &list) {
    for (unsigned i = 0, e = read(); i != e; ++i) {
      // Each entry is prefixed with its kind: a single type, or otherwise a
      // range whose elements are all appended.
      if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
        list.push_back(read<Type>());
      } else {
        TypeRange *values = read<TypeRange *>();
        list.append(values->begin(), values->end());
      }
    }
  }
  void readList(SmallVectorImpl<Value> &list) {
    for (unsigned i = 0, e = read(); i != e; ++i) {
      // Each entry is prefixed with its kind: a single value, or otherwise a
      // range whose elements are all appended.
      if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
        list.push_back(read<Value>());
      } else {
        ValueRange *values = read<ValueRange *>();
        list.append(values->begin(), values->end());
      }
    }
  }

  /// Read a value stored inline as a pointer. The pointer is copied out of the
  /// stream with `memcpy` and the cursor advances by the number of fields a
  /// pointer occupies.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T>
  readInline() {
    const void *pointer;
    std::memcpy(&pointer, curCodeIt, sizeof(const void *));
    curCodeIt += sizeof(const void *) / sizeof(ByteCodeField);
    return T::getFromOpaquePointer(pointer);
  }

  /// Jump to a specific successor based on a predicate value.
  /// `true` selects successor 0 and `false` selects successor 1.
  void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
  /// Jump to a specific successor based on a destination index.
  void selectJump(size_t destIndex) {
    // Every address occupies two fields, so skip `destIndex * 2` fields to
    // reach the requested successor address.
    curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
  }

  /// Handle a switch operation with the provided value and cases.
  /// Successor 0 is the default destination; a match on the case at position
  /// `i` jumps to successor `i + 1`.
  template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
  void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
    LLVM_DEBUG({
      llvm::dbgs() << "  * Value: " << value << "\n"
                   << "  * Cases: ";
      llvm::interleaveComma(cases, llvm::dbgs());
      llvm::dbgs() << "\n";
    });

    // Check to see if the attribute value is within the case list. Jump to
    // the correct successor index based on the result.
    for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
      if (cmp(*it, value))
        return selectJump(size_t((it - cases.begin()) + 1));
    selectJump(size_t(0));
  }

  /// Store a pointer to memory.
  void storeToMemory(unsigned index, const void *value) {
    memory[index] = value;
  }

  /// Store a value to memory as an opaque pointer.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
  storeToMemory(unsigned index, T value) {
    memory[index] = value.getAsOpaquePointer();
  }

  /// Internal implementation of reading various data types from the bytecode
  /// stream.
  /// Reads one field as a memory index and returns the opaque pointer stored
  /// at that index, either in the mutable memory or in the uniqued memory.
  template <typename T>
  const void *readFromMemory() {
    size_t index = *curCodeIt++;

    // If this type is an SSA value, it can only be stored in non-const memory.
    if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
                        Value>::value ||
        index < memory.size())
      return memory[index];

    // Otherwise, if this index is not inbounds it is uniqued.
    return uniquedMemory[index - memory.size()];
  }
  /// Read a pointer-typed value from memory.
  template <typename T>
  std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
    return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
  }
  /// Read a class-typed value (other than PDLValue) by rebuilding it from the
  /// opaque pointer stored in memory.
  template <typename T>
  std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
                   T>
  readImpl() {
    return T(T::getFromOpaquePointer(readFromMemory<T>()));
  }
  /// Read a generic PDLValue: the kind is read first and determines how the
  /// following field is decoded.
  template <typename T>
  std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
    switch (read<PDLValue::Kind>()) {
    case PDLValue::Kind::Attribute:
      return read<Attribute>();
    case PDLValue::Kind::Operation:
      return read<Operation *>();
    case PDLValue::Kind::Type:
      return read<Type>();
    case PDLValue::Kind::Value:
      return read<Value>();
    case PDLValue::Kind::TypeRange:
      return read<TypeRange *>();
    case PDLValue::Kind::ValueRange:
      return read<ValueRange *>();
    }
    llvm_unreachable("unhandled PDLValue::Kind");
  }
  /// Read a bytecode address, which is stored inline across two fields.
  template <typename T>
  std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
    static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
                  "unexpected ByteCode address size");
    ByteCodeAddr result;
    std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
    curCodeIt += 2;
    return result;
  }
  /// Read a single raw field.
  template <typename T>
  std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
    return *curCodeIt++;
  }
  /// Read a PDLValue kind, which is stored as a single raw field.
  template <typename T>
  std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
    return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
  }

  /// Assign the given range to the given memory index. This allocates a new
  /// range object if necessary.
  template <typename RangeT, typename T = llvm::detail::ValueOfRange<RangeT>>
  void assignRangeToMemory(RangeT &&range, unsigned memIndex,
                           unsigned rangeIndex) {
    // Utility functor used to type-erase the assignment.
    auto assignRange = [&](auto &allocatedRangeMemory, auto &rangeMemory) {
      // If the input range is empty, we don't need to allocate anything.
      if (range.empty()) {
        rangeMemory[rangeIndex] = {};
      } else {
        // Allocate a buffer for this type range.
        llvm::OwningArrayRef<T> storage(llvm::size(range));
        llvm::copy(range, storage.begin());

        // Assign this to the range slot and use the range as the value for the
        // memory index.
        allocatedRangeMemory.emplace_back(std::move(storage));
        rangeMemory[rangeIndex] = allocatedRangeMemory.back();
      }
      memory[memIndex] = &rangeMemory[rangeIndex];
    };

    // Dispatch based on the concrete range type.
    if constexpr (std::is_same_v<T, Type>) {
      return assignRange(allocatedTypeRangeMemory, typeRangeMemory);
    } else if constexpr (std::is_same_v<T, Value>) {
      return assignRange(allocatedValueRangeMemory, valueRangeMemory);
    } else {
      llvm_unreachable("unhandled range type");
    }
  }

  /// The underlying bytecode buffer.
  const ByteCodeField *curCodeIt;

  /// The stack of bytecode positions at which to resume operation.
  SmallVector<const ByteCodeField *> resumeCodeIt;

  /// The current execution memory.
  MutableArrayRef<const void *> memory;
  MutableArrayRef<OwningOpRange> opRangeMemory;
  MutableArrayRef<TypeRange> typeRangeMemory;
  std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
  MutableArrayRef<ValueRange> valueRangeMemory;
  std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;

  /// The current loop indices.
  MutableArrayRef<unsigned> loopIndex;

  /// References to ByteCode data necessary for execution.
  ArrayRef<const void *> uniquedMemory;
  ArrayRef<ByteCodeField> code;
  ArrayRef<PatternBenefit> currentPatternBenefits;
  ArrayRef<PDLByteCodePattern> patterns;
  ArrayRef<PDLConstraintFunction> constraintFunctions;
  ArrayRef<PDLRewriteFunction> rewriteFunctions;
};
1383 
1384 /// This class is an instantiation of the PDLResultList that provides access to
1385 /// the returned results. This API is not on `PDLResultList` to avoid
1386 /// overexposing access to information specific solely to the ByteCode.
1387 class ByteCodeRewriteResultList : public PDLResultList {
1388 public:
1389   ByteCodeRewriteResultList(unsigned maxNumResults)
1390       : PDLResultList(maxNumResults) {}
1391 
1392   /// Return the list of PDL results.
1393   MutableArrayRef<PDLValue> getResults() { return results; }
1394 
1395   /// Return the type ranges allocated by this list.
1396   MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
1397     return allocatedTypeRanges;
1398   }
1399 
1400   /// Return the value ranges allocated by this list.
1401   MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
1402     return allocatedValueRanges;
1403   }
1404 };
1405 } // namespace
1406 
1407 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1408   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
1409   const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
1410   SmallVector<PDLValue, 16> args;
1411   readList<PDLValue>(args);
1412 
1413   LLVM_DEBUG({
1414     llvm::dbgs() << "  * Arguments: ";
1415     llvm::interleaveComma(args, llvm::dbgs());
1416   });
1417 
1418   // Invoke the constraint and jump to the proper destination.
1419   selectJump(succeeded(constraintFn(rewriter, args)));
1420 }
1421 
1422 LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
1423   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
1424   const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1425   SmallVector<PDLValue, 16> args;
1426   readList<PDLValue>(args);
1427 
1428   LLVM_DEBUG({
1429     llvm::dbgs() << "  * Arguments: ";
1430     llvm::interleaveComma(args, llvm::dbgs());
1431   });
1432 
1433   // Execute the rewrite function.
1434   ByteCodeField numResults = read();
1435   ByteCodeRewriteResultList results(numResults);
1436   LogicalResult rewriteResult = rewriteFn(rewriter, results, args);
1437 
1438   assert(results.getResults().size() == numResults &&
1439          "native PDL rewrite function returned unexpected number of results");
1440 
1441   // Store the results in the bytecode memory.
1442   for (PDLValue &result : results.getResults()) {
1443     LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n");
1444 
1445 // In debug mode we also verify the expected kind of the result.
1446 #ifndef NDEBUG
1447     assert(result.getKind() == read<PDLValue::Kind>() &&
1448            "native PDL rewrite function returned an unexpected type of result");
1449 #endif
1450 
1451     // If the result is a range, we need to copy it over to the bytecodes
1452     // range memory.
1453     if (std::optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
1454       unsigned rangeIndex = read();
1455       typeRangeMemory[rangeIndex] = *typeRange;
1456       memory[read()] = &typeRangeMemory[rangeIndex];
1457     } else if (std::optional<ValueRange> valueRange =
1458                    result.dyn_cast<ValueRange>()) {
1459       unsigned rangeIndex = read();
1460       valueRangeMemory[rangeIndex] = *valueRange;
1461       memory[read()] = &valueRangeMemory[rangeIndex];
1462     } else {
1463       memory[read()] = result.getAsOpaquePointer();
1464     }
1465   }
1466 
1467   // Copy over any underlying storage allocated for result ranges.
1468   for (auto &it : results.getAllocatedTypeRanges())
1469     allocatedTypeRangeMemory.push_back(std::move(it));
1470   for (auto &it : results.getAllocatedValueRanges())
1471     allocatedValueRangeMemory.push_back(std::move(it));
1472 
1473   // Process the result of the rewrite.
1474   if (failed(rewriteResult)) {
1475     LLVM_DEBUG(llvm::dbgs() << "  - Failed");
1476     return failure();
1477   }
1478   return success();
1479 }
1480 
1481 void ByteCodeExecutor::executeAreEqual() {
1482   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1483   const void *lhs = read<const void *>();
1484   const void *rhs = read<const void *>();
1485 
1486   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n");
1487   selectJump(lhs == rhs);
1488 }
1489 
1490 void ByteCodeExecutor::executeAreRangesEqual() {
1491   LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
1492   PDLValue::Kind valueKind = read<PDLValue::Kind>();
1493   const void *lhs = read<const void *>();
1494   const void *rhs = read<const void *>();
1495 
1496   switch (valueKind) {
1497   case PDLValue::Kind::TypeRange: {
1498     const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
1499     const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
1500     LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1501     selectJump(*lhsRange == *rhsRange);
1502     break;
1503   }
1504   case PDLValue::Kind::ValueRange: {
1505     const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
1506     const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
1507     LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1508     selectJump(*lhsRange == *rhsRange);
1509     break;
1510   }
1511   default:
1512     llvm_unreachable("unexpected `AreRangesEqual` value kind");
1513   }
1514 }
1515 
1516 void ByteCodeExecutor::executeBranch() {
1517   LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
1518   curCodeIt = &code[read<ByteCodeAddr>()];
1519 }
1520 
1521 void ByteCodeExecutor::executeCheckOperandCount() {
1522   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
1523   Operation *op = read<Operation *>();
1524   uint32_t expectedCount = read<uint32_t>();
1525   bool compareAtLeast = read();
1526 
1527   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumOperands() << "\n"
1528                           << "  * Expected: " << expectedCount << "\n"
1529                           << "  * Comparator: "
1530                           << (compareAtLeast ? ">=" : "==") << "\n");
1531   if (compareAtLeast)
1532     selectJump(op->getNumOperands() >= expectedCount);
1533   else
1534     selectJump(op->getNumOperands() == expectedCount);
1535 }
1536 
1537 void ByteCodeExecutor::executeCheckOperationName() {
1538   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
1539   Operation *op = read<Operation *>();
1540   OperationName expectedName = read<OperationName>();
1541 
1542   LLVM_DEBUG(llvm::dbgs() << "  * Found: \"" << op->getName() << "\"\n"
1543                           << "  * Expected: \"" << expectedName << "\"\n");
1544   selectJump(op->getName() == expectedName);
1545 }
1546 
1547 void ByteCodeExecutor::executeCheckResultCount() {
1548   LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
1549   Operation *op = read<Operation *>();
1550   uint32_t expectedCount = read<uint32_t>();
1551   bool compareAtLeast = read();
1552 
1553   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumResults() << "\n"
1554                           << "  * Expected: " << expectedCount << "\n"
1555                           << "  * Comparator: "
1556                           << (compareAtLeast ? ">=" : "==") << "\n");
1557   if (compareAtLeast)
1558     selectJump(op->getNumResults() >= expectedCount);
1559   else
1560     selectJump(op->getNumResults() == expectedCount);
1561 }
1562 
1563 void ByteCodeExecutor::executeCheckTypes() {
1564   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1565   TypeRange *lhs = read<TypeRange *>();
1566   Attribute rhs = read<Attribute>();
1567   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1568 
1569   selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>());
1570 }
1571 
1572 void ByteCodeExecutor::executeContinue() {
1573   ByteCodeField level = read();
1574   LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n"
1575                           << "  * Level: " << level << "\n");
1576   ++loopIndex[level];
1577   popCodeIt();
1578 }
1579 
1580 void ByteCodeExecutor::executeCreateConstantTypeRange() {
1581   LLVM_DEBUG(llvm::dbgs() << "Executing CreateConstantTypeRange:\n");
1582   unsigned memIndex = read();
1583   unsigned rangeIndex = read();
1584   ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>();
1585 
1586   LLVM_DEBUG(llvm::dbgs() << "  * Types: " << typesAttr << "\n\n");
1587   assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex,
1588                       rangeIndex);
1589 }
1590 
1591 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
1592                                               Location mainRewriteLoc) {
1593   LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
1594 
1595   unsigned memIndex = read();
1596   OperationState state(mainRewriteLoc, read<OperationName>());
1597   readList(state.operands);
1598   for (unsigned i = 0, e = read(); i != e; ++i) {
1599     StringAttr name = read<StringAttr>();
1600     if (Attribute attr = read<Attribute>())
1601       state.addAttribute(name, attr);
1602   }
1603 
1604   // Read in the result types. If the "size" is the sentinel value, this
1605   // indicates that the result types should be inferred.
1606   unsigned numResults = read();
1607   if (numResults == kInferTypesMarker) {
1608     InferTypeOpInterface::Concept *inferInterface =
1609         state.name.getInterface<InferTypeOpInterface>();
1610     assert(inferInterface &&
1611            "expected operation to provide InferTypeOpInterface");
1612 
1613     // TODO: Handle failure.
1614     if (failed(inferInterface->inferReturnTypes(
1615             state.getContext(), state.location, state.operands,
1616             state.attributes.getDictionary(state.getContext()), state.regions,
1617             state.types)))
1618       return;
1619   } else {
1620     // Otherwise, this is a fixed number of results.
1621     for (unsigned i = 0; i != numResults; ++i) {
1622       if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1623         state.types.push_back(read<Type>());
1624       } else {
1625         TypeRange *resultTypes = read<TypeRange *>();
1626         state.types.append(resultTypes->begin(), resultTypes->end());
1627       }
1628     }
1629   }
1630 
1631   Operation *resultOp = rewriter.create(state);
1632   memory[memIndex] = resultOp;
1633 
1634   LLVM_DEBUG({
1635     llvm::dbgs() << "  * Attributes: "
1636                  << state.attributes.getDictionary(state.getContext())
1637                  << "\n  * Operands: ";
1638     llvm::interleaveComma(state.operands, llvm::dbgs());
1639     llvm::dbgs() << "\n  * Result Types: ";
1640     llvm::interleaveComma(state.types, llvm::dbgs());
1641     llvm::dbgs() << "\n  * Result: " << *resultOp << "\n";
1642   });
1643 }
1644 
1645 template <typename T>
1646 void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) {
1647   LLVM_DEBUG(llvm::dbgs() << "Executing CreateDynamic" << type << "Range:\n");
1648   unsigned memIndex = read();
1649   unsigned rangeIndex = read();
1650   SmallVector<T> values;
1651   readList(values);
1652 
1653   LLVM_DEBUG({
1654     llvm::dbgs() << "\n  * " << type << "s: ";
1655     llvm::interleaveComma(values, llvm::dbgs());
1656     llvm::dbgs() << "\n";
1657   });
1658 
1659   assignRangeToMemory(values, memIndex, rangeIndex);
1660 }
1661 
1662 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1663   LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1664   Operation *op = read<Operation *>();
1665 
1666   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1667   rewriter.eraseOp(op);
1668 }
1669 
1670 template <typename T, typename Range, PDLValue::Kind kind>
1671 void ByteCodeExecutor::executeExtract() {
1672   LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n");
1673   Range *range = read<Range *>();
1674   unsigned index = read<uint32_t>();
1675   unsigned memIndex = read();
1676 
1677   if (!range) {
1678     memory[memIndex] = nullptr;
1679     return;
1680   }
1681 
1682   T result = index < range->size() ? (*range)[index] : T();
1683   LLVM_DEBUG(llvm::dbgs() << "  * " << kind << "s(" << range->size() << ")\n"
1684                           << "  * Index: " << index << "\n"
1685                           << "  * Result: " << result << "\n");
1686   storeToMemory(memIndex, result);
1687 }
1688 
1689 void ByteCodeExecutor::executeFinalize() {
1690   LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n");
1691 }
1692 
1693 void ByteCodeExecutor::executeForEach() {
1694   LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n");
1695   const ByteCodeField *prevCodeIt = getPrevCodeIt();
1696   unsigned rangeIndex = read();
1697   unsigned memIndex = read();
1698   const void *value = nullptr;
1699 
1700   switch (read<PDLValue::Kind>()) {
1701   case PDLValue::Kind::Operation: {
1702     unsigned &index = loopIndex[read()];
1703     ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
1704     assert(index <= array.size() && "iterated past the end");
1705     if (index < array.size()) {
1706       LLVM_DEBUG(llvm::dbgs() << "  * Result: " << array[index] << "\n");
1707       value = array[index];
1708       break;
1709     }
1710 
1711     LLVM_DEBUG(llvm::dbgs() << "  * Done\n");
1712     index = 0;
1713     selectJump(size_t(0));
1714     return;
1715   }
1716   default:
1717     llvm_unreachable("unexpected `ForEach` value kind");
1718   }
1719 
1720   // Store the iterate value and the stack address.
1721   memory[memIndex] = value;
1722   pushCodeIt(prevCodeIt);
1723 
1724   // Skip over the successor (we will enter the body of the loop).
1725   read<ByteCodeAddr>();
1726 }
1727 
1728 void ByteCodeExecutor::executeGetAttribute() {
1729   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1730   unsigned memIndex = read();
1731   Operation *op = read<Operation *>();
1732   StringAttr attrName = read<StringAttr>();
1733   Attribute attr = op->getAttr(attrName);
1734 
1735   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1736                           << "  * Attribute: " << attrName << "\n"
1737                           << "  * Result: " << attr << "\n");
1738   memory[memIndex] = attr.getAsOpaquePointer();
1739 }
1740 
1741 void ByteCodeExecutor::executeGetAttributeType() {
1742   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1743   unsigned memIndex = read();
1744   Attribute attr = read<Attribute>();
1745   Type type;
1746   if (auto typedAttr = attr.dyn_cast<TypedAttr>())
1747     type = typedAttr.getType();
1748 
1749   LLVM_DEBUG(llvm::dbgs() << "  * Attribute: " << attr << "\n"
1750                           << "  * Result: " << type << "\n");
1751   memory[memIndex] = type.getAsOpaquePointer();
1752 }
1753 
1754 void ByteCodeExecutor::executeGetDefiningOp() {
1755   LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1756   unsigned memIndex = read();
1757   Operation *op = nullptr;
1758   if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1759     Value value = read<Value>();
1760     if (value)
1761       op = value.getDefiningOp();
1762     LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1763   } else {
1764     ValueRange *values = read<ValueRange *>();
1765     if (values && !values->empty()) {
1766       op = values->front().getDefiningOp();
1767     }
1768     LLVM_DEBUG(llvm::dbgs() << "  * Values: " << values << "\n");
1769   }
1770 
1771   LLVM_DEBUG(llvm::dbgs() << "  * Result: " << op << "\n");
1772   memory[memIndex] = op;
1773 }
1774 
1775 void ByteCodeExecutor::executeGetOperand(unsigned index) {
1776   Operation *op = read<Operation *>();
1777   unsigned memIndex = read();
1778   Value operand =
1779       index < op->getNumOperands() ? op->getOperand(index) : Value();
1780 
1781   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1782                           << "  * Index: " << index << "\n"
1783                           << "  * Result: " << operand << "\n");
1784   memory[memIndex] = operand.getAsOpaquePointer();
1785 }
1786 
1787 /// This function is the internal implementation of `GetResults` and
1788 /// `GetOperands` that provides support for extracting a value range from the
1789 /// given operation.
1790 template <template <typename> class AttrSizedSegmentsT, typename RangeT>
1791 static void *
1792 executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
1793                           ByteCodeField rangeIndex, StringRef attrSizedSegments,
1794                           MutableArrayRef<ValueRange> valueRangeMemory) {
1795   // Check for the sentinel index that signals that all values should be
1796   // returned.
1797   if (index == std::numeric_limits<uint32_t>::max()) {
1798     LLVM_DEBUG(llvm::dbgs() << "  * Getting all values\n");
1799     // `values` is already the full value range.
1800 
1801     // Otherwise, check to see if this operation uses AttrSizedSegments.
1802   } else if (op->hasTrait<AttrSizedSegmentsT>()) {
1803     LLVM_DEBUG(llvm::dbgs()
1804                << "  * Extracting values from `" << attrSizedSegments << "`\n");
1805 
1806     auto segmentAttr = op->getAttrOfType<DenseI32ArrayAttr>(attrSizedSegments);
1807     if (!segmentAttr || segmentAttr.asArrayRef().size() <= index)
1808       return nullptr;
1809 
1810     ArrayRef<int32_t> segments = segmentAttr;
1811     unsigned startIndex =
1812         std::accumulate(segments.begin(), segments.begin() + index, 0);
1813     values = values.slice(startIndex, *std::next(segments.begin(), index));
1814 
1815     LLVM_DEBUG(llvm::dbgs() << "  * Extracting range[" << startIndex << ", "
1816                             << *std::next(segments.begin(), index) << "]\n");
1817 
1818     // Otherwise, assume this is the last operand group of the operation.
1819     // FIXME: We currently don't support operations with
1820     // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
1821     // have a way to detect it's presence.
1822   } else if (values.size() >= index) {
1823     LLVM_DEBUG(llvm::dbgs()
1824                << "  * Treating values as trailing variadic range\n");
1825     values = values.drop_front(index);
1826 
1827     // If we couldn't detect a way to compute the values, bail out.
1828   } else {
1829     return nullptr;
1830   }
1831 
1832   // If the range index is valid, we are returning a range.
1833   if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
1834     valueRangeMemory[rangeIndex] = values;
1835     return &valueRangeMemory[rangeIndex];
1836   }
1837 
1838   // If a range index wasn't provided, the range is required to be non-variadic.
1839   return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1840 }
1841 
1842 void ByteCodeExecutor::executeGetOperands() {
1843   LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
1844   unsigned index = read<uint32_t>();
1845   Operation *op = read<Operation *>();
1846   ByteCodeField rangeIndex = read();
1847 
1848   void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1849       op->getOperands(), op, index, rangeIndex, "operand_segment_sizes",
1850       valueRangeMemory);
1851   if (!result)
1852     LLVM_DEBUG(llvm::dbgs() << "  * Invalid operand range\n");
1853   memory[read()] = result;
1854 }
1855 
1856 void ByteCodeExecutor::executeGetResult(unsigned index) {
1857   Operation *op = read<Operation *>();
1858   unsigned memIndex = read();
1859   OpResult result =
1860       index < op->getNumResults() ? op->getResult(index) : OpResult();
1861 
1862   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1863                           << "  * Index: " << index << "\n"
1864                           << "  * Result: " << result << "\n");
1865   memory[memIndex] = result.getAsOpaquePointer();
1866 }
1867 
1868 void ByteCodeExecutor::executeGetResults() {
1869   LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
1870   unsigned index = read<uint32_t>();
1871   Operation *op = read<Operation *>();
1872   ByteCodeField rangeIndex = read();
1873 
1874   void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1875       op->getResults(), op, index, rangeIndex, "result_segment_sizes",
1876       valueRangeMemory);
1877   if (!result)
1878     LLVM_DEBUG(llvm::dbgs() << "  * Invalid result range\n");
1879   memory[read()] = result;
1880 }
1881 
1882 void ByteCodeExecutor::executeGetUsers() {
1883   LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n");
1884   unsigned memIndex = read();
1885   unsigned rangeIndex = read();
1886   OwningOpRange &range = opRangeMemory[rangeIndex];
1887   memory[memIndex] = &range;
1888 
1889   range = OwningOpRange();
1890   if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1891     // Read the value.
1892     Value value = read<Value>();
1893     if (!value)
1894       return;
1895     LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1896 
1897     // Extract the users of a single value.
1898     range = OwningOpRange(std::distance(value.user_begin(), value.user_end()));
1899     llvm::copy(value.getUsers(), range.begin());
1900   } else {
1901     // Read a range of values.
1902     ValueRange *values = read<ValueRange *>();
1903     if (!values)
1904       return;
1905     LLVM_DEBUG({
1906       llvm::dbgs() << "  * Values (" << values->size() << "): ";
1907       llvm::interleaveComma(*values, llvm::dbgs());
1908       llvm::dbgs() << "\n";
1909     });
1910 
1911     // Extract all the users of a range of values.
1912     SmallVector<Operation *> users;
1913     for (Value value : *values)
1914       users.append(value.user_begin(), value.user_end());
1915     range = OwningOpRange(users.size());
1916     llvm::copy(users, range.begin());
1917   }
1918 
1919   LLVM_DEBUG(llvm::dbgs() << "  * Result: " << range.size() << " operations\n");
1920 }
1921 
1922 void ByteCodeExecutor::executeGetValueType() {
1923   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1924   unsigned memIndex = read();
1925   Value value = read<Value>();
1926   Type type = value ? value.getType() : Type();
1927 
1928   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n"
1929                           << "  * Result: " << type << "\n");
1930   memory[memIndex] = type.getAsOpaquePointer();
1931 }
1932 
1933 void ByteCodeExecutor::executeGetValueRangeTypes() {
1934   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
1935   unsigned memIndex = read();
1936   unsigned rangeIndex = read();
1937   ValueRange *values = read<ValueRange *>();
1938   if (!values) {
1939     LLVM_DEBUG(llvm::dbgs() << "  * Values: <NULL>\n\n");
1940     memory[memIndex] = nullptr;
1941     return;
1942   }
1943 
1944   LLVM_DEBUG({
1945     llvm::dbgs() << "  * Values (" << values->size() << "): ";
1946     llvm::interleaveComma(*values, llvm::dbgs());
1947     llvm::dbgs() << "\n  * Result: ";
1948     llvm::interleaveComma(values->getType(), llvm::dbgs());
1949     llvm::dbgs() << "\n";
1950   });
1951   typeRangeMemory[rangeIndex] = values->getType();
1952   memory[memIndex] = &typeRangeMemory[rangeIndex];
1953 }
1954 
1955 void ByteCodeExecutor::executeIsNotNull() {
1956   LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
1957   const void *value = read<const void *>();
1958 
1959   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1960   selectJump(value != nullptr);
1961 }
1962 
1963 void ByteCodeExecutor::executeRecordMatch(
1964     PatternRewriter &rewriter,
1965     SmallVectorImpl<PDLByteCode::MatchResult> &matches) {
1966   LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
1967   unsigned patternIndex = read();
1968   PatternBenefit benefit = currentPatternBenefits[patternIndex];
1969   const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1970 
1971   // If the benefit of the pattern is impossible, skip the processing of the
1972   // rest of the pattern.
1973   if (benefit.isImpossibleToMatch()) {
1974     LLVM_DEBUG(llvm::dbgs() << "  * Benefit: Impossible To Match\n");
1975     curCodeIt = dest;
1976     return;
1977   }
1978 
1979   // Create a fused location containing the locations of each of the
1980   // operations used in the match. This will be used as the location for
1981   // created operations during the rewrite that don't already have an
1982   // explicit location set.
1983   unsigned numMatchLocs = read();
1984   SmallVector<Location, 4> matchLocs;
1985   matchLocs.reserve(numMatchLocs);
1986   for (unsigned i = 0; i != numMatchLocs; ++i)
1987     matchLocs.push_back(read<Operation *>()->getLoc());
1988   Location matchLoc = rewriter.getFusedLoc(matchLocs);
1989 
1990   LLVM_DEBUG(llvm::dbgs() << "  * Benefit: " << benefit.getBenefit() << "\n"
1991                           << "  * Location: " << matchLoc << "\n");
1992   matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
1993   PDLByteCode::MatchResult &match = matches.back();
1994 
1995   // Record all of the inputs to the match. If any of the inputs are ranges, we
1996   // will also need to remap the range pointer to memory stored in the match
1997   // state.
1998   unsigned numInputs = read();
1999   match.values.reserve(numInputs);
2000   match.typeRangeValues.reserve(numInputs);
2001   match.valueRangeValues.reserve(numInputs);
2002   for (unsigned i = 0; i < numInputs; ++i) {
2003     switch (read<PDLValue::Kind>()) {
2004     case PDLValue::Kind::TypeRange:
2005       match.typeRangeValues.push_back(*read<TypeRange *>());
2006       match.values.push_back(&match.typeRangeValues.back());
2007       break;
2008     case PDLValue::Kind::ValueRange:
2009       match.valueRangeValues.push_back(*read<ValueRange *>());
2010       match.values.push_back(&match.valueRangeValues.back());
2011       break;
2012     default:
2013       match.values.push_back(read<const void *>());
2014       break;
2015     }
2016   }
2017   curCodeIt = dest;
2018 }
2019 
2020 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
2021   LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
2022   Operation *op = read<Operation *>();
2023   SmallVector<Value, 16> args;
2024   readList(args);
2025 
2026   LLVM_DEBUG({
2027     llvm::dbgs() << "  * Operation: " << *op << "\n"
2028                  << "  * Values: ";
2029     llvm::interleaveComma(args, llvm::dbgs());
2030     llvm::dbgs() << "\n";
2031   });
2032   rewriter.replaceOp(op, args);
2033 }
2034 
2035 void ByteCodeExecutor::executeSwitchAttribute() {
2036   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
2037   Attribute value = read<Attribute>();
2038   ArrayAttr cases = read<ArrayAttr>();
2039   handleSwitch(value, cases);
2040 }
2041 
2042 void ByteCodeExecutor::executeSwitchOperandCount() {
2043   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
2044   Operation *op = read<Operation *>();
2045   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2046 
2047   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
2048   handleSwitch(op->getNumOperands(), cases);
2049 }
2050 
2051 void ByteCodeExecutor::executeSwitchOperationName() {
2052   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
2053   OperationName value = read<Operation *>()->getName();
2054   size_t caseCount = read();
2055 
2056   // The operation names are stored in-line, so to print them out for
2057   // debugging purposes we need to read the array before executing the
2058   // switch so that we can display all of the possible values.
2059   LLVM_DEBUG({
2060     const ByteCodeField *prevCodeIt = curCodeIt;
2061     llvm::dbgs() << "  * Value: " << value << "\n"
2062                  << "  * Cases: ";
2063     llvm::interleaveComma(
2064         llvm::map_range(llvm::seq<size_t>(0, caseCount),
2065                         [&](size_t) { return read<OperationName>(); }),
2066         llvm::dbgs());
2067     llvm::dbgs() << "\n";
2068     curCodeIt = prevCodeIt;
2069   });
2070 
2071   // Try to find the switch value within any of the cases.
2072   for (size_t i = 0; i != caseCount; ++i) {
2073     if (read<OperationName>() == value) {
2074       curCodeIt += (caseCount - i - 1);
2075       return selectJump(i + 1);
2076     }
2077   }
2078   selectJump(size_t(0));
2079 }
2080 
2081 void ByteCodeExecutor::executeSwitchResultCount() {
2082   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
2083   Operation *op = read<Operation *>();
2084   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2085 
2086   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
2087   handleSwitch(op->getNumResults(), cases);
2088 }
2089 
2090 void ByteCodeExecutor::executeSwitchType() {
2091   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
2092   Type value = read<Type>();
2093   auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2094   handleSwitch(value, cases);
2095 }
2096 
2097 void ByteCodeExecutor::executeSwitchTypes() {
2098   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
2099   TypeRange *value = read<TypeRange *>();
2100   auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
2101   if (!value) {
2102     LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n");
2103     return selectJump(size_t(0));
2104   }
2105   handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
2106     return value == caseValue.getAsValueRange<TypeAttr>();
2107   });
2108 }
2109 
2110 LogicalResult
2111 ByteCodeExecutor::execute(PatternRewriter &rewriter,
2112                           SmallVectorImpl<PDLByteCode::MatchResult> *matches,
2113                           std::optional<Location> mainRewriteLoc) {
2114   while (true) {
2115     // Print the location of the operation being executed.
2116     LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n");
2117 
2118     OpCode opCode = static_cast<OpCode>(read());
2119     switch (opCode) {
2120     case ApplyConstraint:
2121       executeApplyConstraint(rewriter);
2122       break;
2123     case ApplyRewrite:
2124       if (failed(executeApplyRewrite(rewriter)))
2125         return failure();
2126       break;
2127     case AreEqual:
2128       executeAreEqual();
2129       break;
2130     case AreRangesEqual:
2131       executeAreRangesEqual();
2132       break;
2133     case Branch:
2134       executeBranch();
2135       break;
2136     case CheckOperandCount:
2137       executeCheckOperandCount();
2138       break;
2139     case CheckOperationName:
2140       executeCheckOperationName();
2141       break;
2142     case CheckResultCount:
2143       executeCheckResultCount();
2144       break;
2145     case CheckTypes:
2146       executeCheckTypes();
2147       break;
2148     case Continue:
2149       executeContinue();
2150       break;
2151     case CreateConstantTypeRange:
2152       executeCreateConstantTypeRange();
2153       break;
2154     case CreateOperation:
2155       executeCreateOperation(rewriter, *mainRewriteLoc);
2156       break;
2157     case CreateDynamicTypeRange:
2158       executeDynamicCreateRange<Type>("Type");
2159       break;
2160     case CreateDynamicValueRange:
2161       executeDynamicCreateRange<Value>("Value");
2162       break;
2163     case EraseOp:
2164       executeEraseOp(rewriter);
2165       break;
2166     case ExtractOp:
2167       executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
2168       break;
2169     case ExtractType:
2170       executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
2171       break;
2172     case ExtractValue:
2173       executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
2174       break;
2175     case Finalize:
2176       executeFinalize();
2177       LLVM_DEBUG(llvm::dbgs() << "\n");
2178       return success();
2179     case ForEach:
2180       executeForEach();
2181       break;
2182     case GetAttribute:
2183       executeGetAttribute();
2184       break;
2185     case GetAttributeType:
2186       executeGetAttributeType();
2187       break;
2188     case GetDefiningOp:
2189       executeGetDefiningOp();
2190       break;
2191     case GetOperand0:
2192     case GetOperand1:
2193     case GetOperand2:
2194     case GetOperand3: {
2195       unsigned index = opCode - GetOperand0;
2196       LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
2197       executeGetOperand(index);
2198       break;
2199     }
2200     case GetOperandN:
2201       LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
2202       executeGetOperand(read<uint32_t>());
2203       break;
2204     case GetOperands:
2205       executeGetOperands();
2206       break;
2207     case GetResult0:
2208     case GetResult1:
2209     case GetResult2:
2210     case GetResult3: {
2211       unsigned index = opCode - GetResult0;
2212       LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
2213       executeGetResult(index);
2214       break;
2215     }
2216     case GetResultN:
2217       LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
2218       executeGetResult(read<uint32_t>());
2219       break;
2220     case GetResults:
2221       executeGetResults();
2222       break;
2223     case GetUsers:
2224       executeGetUsers();
2225       break;
2226     case GetValueType:
2227       executeGetValueType();
2228       break;
2229     case GetValueRangeTypes:
2230       executeGetValueRangeTypes();
2231       break;
2232     case IsNotNull:
2233       executeIsNotNull();
2234       break;
2235     case RecordMatch:
2236       assert(matches &&
2237              "expected matches to be provided when executing the matcher");
2238       executeRecordMatch(rewriter, *matches);
2239       break;
2240     case ReplaceOp:
2241       executeReplaceOp(rewriter);
2242       break;
2243     case SwitchAttribute:
2244       executeSwitchAttribute();
2245       break;
2246     case SwitchOperandCount:
2247       executeSwitchOperandCount();
2248       break;
2249     case SwitchOperationName:
2250       executeSwitchOperationName();
2251       break;
2252     case SwitchResultCount:
2253       executeSwitchResultCount();
2254       break;
2255     case SwitchType:
2256       executeSwitchType();
2257       break;
2258     case SwitchTypes:
2259       executeSwitchTypes();
2260       break;
2261     }
2262     LLVM_DEBUG(llvm::dbgs() << "\n");
2263   }
2264 }
2265 
2266 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
2267                         SmallVectorImpl<MatchResult> &matches,
2268                         PDLByteCodeMutableState &state) const {
2269   // The first memory slot is always the root operation.
2270   state.memory[0] = op;
2271 
2272   // The matcher function always starts at code address 0.
2273   ByteCodeExecutor executor(
2274       matcherByteCode.data(), state.memory, state.opRangeMemory,
2275       state.typeRangeMemory, state.allocatedTypeRangeMemory,
2276       state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2277       uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
2278       constraintFunctions, rewriteFunctions);
2279   LogicalResult executeResult = executor.execute(rewriter, &matches);
2280   (void)executeResult;
2281   assert(succeeded(executeResult) && "unexpected matcher execution failure");
2282 
2283   // Order the found matches by benefit.
2284   std::stable_sort(matches.begin(), matches.end(),
2285                    [](const MatchResult &lhs, const MatchResult &rhs) {
2286                      return lhs.benefit > rhs.benefit;
2287                    });
2288 }
2289 
2290 LogicalResult PDLByteCode::rewrite(PatternRewriter &rewriter,
2291                                    const MatchResult &match,
2292                                    PDLByteCodeMutableState &state) const {
2293   auto *configSet = match.pattern->getConfigSet();
2294   if (configSet)
2295     configSet->notifyRewriteBegin(rewriter);
2296 
2297   // The arguments of the rewrite function are stored at the start of the
2298   // memory buffer.
2299   llvm::copy(match.values, state.memory.begin());
2300 
2301   ByteCodeExecutor executor(
2302       &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
2303       state.opRangeMemory, state.typeRangeMemory,
2304       state.allocatedTypeRangeMemory, state.valueRangeMemory,
2305       state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2306       rewriterByteCode, state.currentPatternBenefits, patterns,
2307       constraintFunctions, rewriteFunctions);
2308   LogicalResult result =
2309       executor.execute(rewriter, /*matches=*/nullptr, match.location);
2310 
2311   if (configSet)
2312     configSet->notifyRewriteEnd(rewriter);
2313 
2314   // If the rewrite failed, check if the pattern rewriter can recover. If it
2315   // can, we can signal to the pattern applicator to keep trying patterns. If it
2316   // doesn't, we need to bail. Bailing here should be fine, given that we have
2317   // no means to propagate such a failure to the user, and it also indicates a
2318   // bug in the user code (i.e. failable rewrites should not be used with
2319   // pattern rewriters that don't support it).
2320   if (failed(result) && !rewriter.canRecoverFromRewriteFailure()) {
2321     LLVM_DEBUG(llvm::dbgs() << " and rollback is not supported - aborting");
2322     llvm::report_fatal_error(
2323         "Native PDL Rewrite failed, but the pattern "
2324         "rewriter doesn't support recovery. Failable pattern rewrites should "
2325         "not be used with pattern rewriters that do not support them.");
2326   }
2327   return result;
2328 }
2329