xref: /llvm-project/mlir/lib/Rewrite/ByteCode.cpp (revision ce57789d8e5dc109dc9bd330232b31a22a80ad3a)
1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MLIR to byte-code generation and the interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "ByteCode.h"
14 #include "mlir/Analysis/Liveness.h"
15 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
16 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/RegionGraphTraits.h"
19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/Format.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include <numeric>
26 
27 #define DEBUG_TYPE "pdl-bytecode"
28 
29 using namespace mlir;
30 using namespace mlir::detail;
31 
32 //===----------------------------------------------------------------------===//
33 // PDLByteCodePattern
34 //===----------------------------------------------------------------------===//
35 
36 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
37                                               PDLPatternConfigSet *configSet,
38                                               ByteCodeAddr rewriterAddr) {
39   PatternBenefit benefit = matchOp.getBenefit();
40   MLIRContext *ctx = matchOp.getContext();
41 
42   // Collect the set of generated operations.
43   SmallVector<StringRef, 8> generatedOps;
44   if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
45     generatedOps =
46         llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
47 
48   // Check to see if this is pattern matches a specific operation type.
49   if (Optional<StringRef> rootKind = matchOp.getRootKind())
50     return PDLByteCodePattern(rewriterAddr, configSet, *rootKind, benefit, ctx,
51                               generatedOps);
52   return PDLByteCodePattern(rewriterAddr, configSet, MatchAnyOpTypeTag(),
53                             benefit, ctx, generatedOps);
54 }
55 
56 //===----------------------------------------------------------------------===//
57 // PDLByteCodeMutableState
58 //===----------------------------------------------------------------------===//
59 
60 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
61 /// to the position of the pattern within the range returned by
62 /// `PDLByteCode::getPatterns`.
63 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
64                                                    PatternBenefit benefit) {
65   currentPatternBenefits[patternIndex] = benefit;
66 }
67 
68 /// Cleanup any allocated state after a full match/rewrite has been completed.
69 /// This method should be called irregardless of whether the match+rewrite was a
70 /// success or not.
71 void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() {
72   allocatedTypeRangeMemory.clear();
73   allocatedValueRangeMemory.clear();
74 }
75 
76 //===----------------------------------------------------------------------===//
77 // Bytecode OpCodes
78 //===----------------------------------------------------------------------===//
79 
80 namespace {
81 enum OpCode : ByteCodeField {
82   /// Apply an externally registered constraint.
83   ApplyConstraint,
84   /// Apply an externally registered rewrite.
85   ApplyRewrite,
86   /// Check if two generic values are equal.
87   AreEqual,
88   /// Check if two ranges are equal.
89   AreRangesEqual,
90   /// Unconditional branch.
91   Branch,
92   /// Compare the operand count of an operation with a constant.
93   CheckOperandCount,
94   /// Compare the name of an operation with a constant.
95   CheckOperationName,
96   /// Compare the result count of an operation with a constant.
97   CheckResultCount,
98   /// Compare a range of types to a constant range of types.
99   CheckTypes,
100   /// Continue to the next iteration of a loop.
101   Continue,
102   /// Create a type range from a list of constant types.
103   CreateConstantTypeRange,
104   /// Create an operation.
105   CreateOperation,
106   /// Create a type range from a list of dynamic types.
107   CreateDynamicTypeRange,
108   /// Create a value range.
109   CreateDynamicValueRange,
110   /// Erase an operation.
111   EraseOp,
112   /// Extract the op from a range at the specified index.
113   ExtractOp,
114   /// Extract the type from a range at the specified index.
115   ExtractType,
116   /// Extract the value from a range at the specified index.
117   ExtractValue,
118   /// Terminate a matcher or rewrite sequence.
119   Finalize,
120   /// Iterate over a range of values.
121   ForEach,
122   /// Get a specific attribute of an operation.
123   GetAttribute,
124   /// Get the type of an attribute.
125   GetAttributeType,
126   /// Get the defining operation of a value.
127   GetDefiningOp,
128   /// Get a specific operand of an operation.
129   GetOperand0,
130   GetOperand1,
131   GetOperand2,
132   GetOperand3,
133   GetOperandN,
134   /// Get a specific operand group of an operation.
135   GetOperands,
136   /// Get a specific result of an operation.
137   GetResult0,
138   GetResult1,
139   GetResult2,
140   GetResult3,
141   GetResultN,
142   /// Get a specific result group of an operation.
143   GetResults,
144   /// Get the users of a value or a range of values.
145   GetUsers,
146   /// Get the type of a value.
147   GetValueType,
148   /// Get the types of a value range.
149   GetValueRangeTypes,
150   /// Check if a generic value is not null.
151   IsNotNull,
152   /// Record a successful pattern match.
153   RecordMatch,
154   /// Replace an operation.
155   ReplaceOp,
156   /// Compare an attribute with a set of constants.
157   SwitchAttribute,
158   /// Compare the operand count of an operation with a set of constants.
159   SwitchOperandCount,
160   /// Compare the name of an operation with a set of constants.
161   SwitchOperationName,
162   /// Compare the result count of an operation with a set of constants.
163   SwitchResultCount,
164   /// Compare a type with a set of constants.
165   SwitchType,
166   /// Compare a range of types with a set of constants.
167   SwitchTypes,
168 };
169 } // namespace
170 
171 /// A marker used to indicate if an operation should infer types.
172 static constexpr ByteCodeField kInferTypesMarker =
173     std::numeric_limits<ByteCodeField>::max();
174 
175 //===----------------------------------------------------------------------===//
176 // ByteCode Generation
177 //===----------------------------------------------------------------------===//
178 
179 //===----------------------------------------------------------------------===//
180 // Generator
181 
182 namespace {
183 struct ByteCodeLiveRange;
184 struct ByteCodeWriter;
185 
/// Detection helper: this alias is well-formed only when an object of type `T`
/// provides a `getAsOpaquePointer()` member, i.e. when `T` can be converted to
/// an opaque pointer. It is meant to be used through `llvm::is_detected`.
/// NOTE(review): the unused `Args` pack is presumably kept so the alias can be
/// instantiated from a pack expansion -- confirm before removing it.
template <typename T, typename... Args>
using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
189 
190 /// This class represents the main generator for the pattern bytecode.
191 class Generator {
192 public:
193   Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
194             SmallVectorImpl<ByteCodeField> &matcherByteCode,
195             SmallVectorImpl<ByteCodeField> &rewriterByteCode,
196             SmallVectorImpl<PDLByteCodePattern> &patterns,
197             ByteCodeField &maxValueMemoryIndex,
198             ByteCodeField &maxOpRangeMemoryIndex,
199             ByteCodeField &maxTypeRangeMemoryIndex,
200             ByteCodeField &maxValueRangeMemoryIndex,
201             ByteCodeField &maxLoopLevel,
202             llvm::StringMap<PDLConstraintFunction> &constraintFns,
203             llvm::StringMap<PDLRewriteFunction> &rewriteFns,
204             const DenseMap<Operation *, PDLPatternConfigSet *> &configMap)
205       : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
206         rewriterByteCode(rewriterByteCode), patterns(patterns),
207         maxValueMemoryIndex(maxValueMemoryIndex),
208         maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
209         maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
210         maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
211         maxLoopLevel(maxLoopLevel), configMap(configMap) {
212     for (const auto &it : llvm::enumerate(constraintFns))
213       constraintToMemIndex.try_emplace(it.value().first(), it.index());
214     for (const auto &it : llvm::enumerate(rewriteFns))
215       externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
216   }
217 
218   /// Generate the bytecode for the given PDL interpreter module.
219   void generate(ModuleOp module);
220 
221   /// Return the memory index to use for the given value.
222   ByteCodeField &getMemIndex(Value value) {
223     assert(valueToMemIndex.count(value) &&
224            "expected memory index to be assigned");
225     return valueToMemIndex[value];
226   }
227 
228   /// Return the range memory index used to store the given range value.
229   ByteCodeField &getRangeStorageIndex(Value value) {
230     assert(valueToRangeIndex.count(value) &&
231            "expected range index to be assigned");
232     return valueToRangeIndex[value];
233   }
234 
235   /// Return an index to use when referring to the given data that is uniqued in
236   /// the MLIR context.
237   template <typename T>
238   std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
239   getMemIndex(T val) {
240     const void *opaqueVal = val.getAsOpaquePointer();
241 
242     // Get or insert a reference to this value.
243     auto it = uniquedDataToMemIndex.try_emplace(
244         opaqueVal, maxValueMemoryIndex + uniquedData.size());
245     if (it.second)
246       uniquedData.push_back(opaqueVal);
247     return it.first->second;
248   }
249 
250 private:
251   /// Allocate memory indices for the results of operations within the matcher
252   /// and rewriters.
253   void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
254                              ModuleOp rewriterModule);
255 
256   /// Generate the bytecode for the given operation.
257   void generate(Region *region, ByteCodeWriter &writer);
258   void generate(Operation *op, ByteCodeWriter &writer);
259   void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
260   void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
261   void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
262   void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
263   void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
264   void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
265   void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
266   void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
267   void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
268   void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
269   void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
270   void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
271   void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
272   void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer);
273   void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
274   void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
275   void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
276   void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
277   void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
278   void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
279   void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
280   void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
281   void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
282   void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
283   void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
284   void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
285   void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
286   void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
287   void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
288   void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
289   void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
290   void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
291   void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
292   void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
293   void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
294   void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
295   void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
296   void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
297 
298   /// Mapping from value to its corresponding memory index.
299   DenseMap<Value, ByteCodeField> valueToMemIndex;
300 
301   /// Mapping from a range value to its corresponding range storage index.
302   DenseMap<Value, ByteCodeField> valueToRangeIndex;
303 
304   /// Mapping from the name of an externally registered rewrite to its index in
305   /// the bytecode registry.
306   llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
307 
308   /// Mapping from the name of an externally registered constraint to its index
309   /// in the bytecode registry.
310   llvm::StringMap<ByteCodeField> constraintToMemIndex;
311 
312   /// Mapping from rewriter function name to the bytecode address of the
313   /// rewriter function in byte.
314   llvm::StringMap<ByteCodeAddr> rewriterToAddr;
315 
316   /// Mapping from a uniqued storage object to its memory index within
317   /// `uniquedData`.
318   DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
319 
320   /// The current level of the foreach loop.
321   ByteCodeField curLoopLevel = 0;
322 
323   /// The current MLIR context.
324   MLIRContext *ctx;
325 
326   /// Mapping from block to its address.
327   DenseMap<Block *, ByteCodeAddr> blockToAddr;
328 
329   /// Data of the ByteCode class to be populated.
330   std::vector<const void *> &uniquedData;
331   SmallVectorImpl<ByteCodeField> &matcherByteCode;
332   SmallVectorImpl<ByteCodeField> &rewriterByteCode;
333   SmallVectorImpl<PDLByteCodePattern> &patterns;
334   ByteCodeField &maxValueMemoryIndex;
335   ByteCodeField &maxOpRangeMemoryIndex;
336   ByteCodeField &maxTypeRangeMemoryIndex;
337   ByteCodeField &maxValueRangeMemoryIndex;
338   ByteCodeField &maxLoopLevel;
339 
340   /// A map of pattern configurations.
341   const DenseMap<Operation *, PDLPatternConfigSet *> &configMap;
342 };
343 
344 /// This class provides utilities for writing a bytecode stream.
345 struct ByteCodeWriter {
346   ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
347       : bytecode(bytecode), generator(generator) {}
348 
349   /// Append a field to the bytecode.
350   void append(ByteCodeField field) { bytecode.push_back(field); }
351   void append(OpCode opCode) { bytecode.push_back(opCode); }
352 
353   /// Append an address to the bytecode.
354   void append(ByteCodeAddr field) {
355     static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
356                   "unexpected ByteCode address size");
357 
358     ByteCodeField fieldParts[2];
359     std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
360     bytecode.append({fieldParts[0], fieldParts[1]});
361   }
362 
363   /// Append a single successor to the bytecode, the exact address will need to
364   /// be resolved later.
365   void append(Block *successor) {
366     // Add back a reference to the successor so that the address can be resolved
367     // later.
368     unresolvedSuccessorRefs[successor].push_back(bytecode.size());
369     append(ByteCodeAddr(0));
370   }
371 
372   /// Append a successor range to the bytecode, the exact address will need to
373   /// be resolved later.
374   void append(SuccessorRange successors) {
375     for (Block *successor : successors)
376       append(successor);
377   }
378 
379   /// Append a range of values that will be read as generic PDLValues.
380   void appendPDLValueList(OperandRange values) {
381     bytecode.push_back(values.size());
382     for (Value value : values)
383       appendPDLValue(value);
384   }
385 
386   /// Append a value as a PDLValue.
387   void appendPDLValue(Value value) {
388     appendPDLValueKind(value);
389     append(value);
390   }
391 
392   /// Append the PDLValue::Kind of the given value.
393   void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); }
394 
395   /// Append the PDLValue::Kind of the given type.
396   void appendPDLValueKind(Type type) {
397     PDLValue::Kind kind =
398         TypeSwitch<Type, PDLValue::Kind>(type)
399             .Case<pdl::AttributeType>(
400                 [](Type) { return PDLValue::Kind::Attribute; })
401             .Case<pdl::OperationType>(
402                 [](Type) { return PDLValue::Kind::Operation; })
403             .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
404               if (rangeTy.getElementType().isa<pdl::TypeType>())
405                 return PDLValue::Kind::TypeRange;
406               return PDLValue::Kind::ValueRange;
407             })
408             .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
409             .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
410     bytecode.push_back(static_cast<ByteCodeField>(kind));
411   }
412 
413   /// Append a value that will be stored in a memory slot and not inline within
414   /// the bytecode.
415   template <typename T>
416   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
417                    std::is_pointer<T>::value>
418   append(T value) {
419     bytecode.push_back(generator.getMemIndex(value));
420   }
421 
422   /// Append a range of values.
423   template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
424   std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
425   append(T range) {
426     bytecode.push_back(llvm::size(range));
427     for (auto it : range)
428       append(it);
429   }
430 
431   /// Append a variadic number of fields to the bytecode.
432   template <typename FieldTy, typename Field2Ty, typename... FieldTys>
433   void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
434     append(field);
435     append(field2, fields...);
436   }
437 
438   /// Appends a value as a pointer, stored inline within the bytecode.
439   template <typename T>
440   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
441   appendInline(T value) {
442     constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField);
443     const void *pointer = value.getAsOpaquePointer();
444     ByteCodeField fieldParts[numParts];
445     std::memcpy(fieldParts, &pointer, sizeof(const void *));
446     bytecode.append(fieldParts, fieldParts + numParts);
447   }
448 
449   /// Successor references in the bytecode that have yet to be resolved.
450   DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
451 
452   /// The underlying bytecode buffer.
453   SmallVectorImpl<ByteCodeField> &bytecode;
454 
455   /// The main generator producing PDL.
456   Generator &generator;
457 };
458 
459 /// This class represents a live range of PDL Interpreter values, containing
460 /// information about when values are live within a match/rewrite.
461 struct ByteCodeLiveRange {
462   using Set = llvm::IntervalMap<uint64_t, char, 16>;
463   using Allocator = Set::Allocator;
464 
465   ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
466 
467   /// Union this live range with the one provided.
468   void unionWith(const ByteCodeLiveRange &rhs) {
469     for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
470          ++it)
471       liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0);
472   }
473 
474   /// Returns true if this range overlaps with the one provided.
475   bool overlaps(const ByteCodeLiveRange &rhs) const {
476     return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
477         .valid();
478   }
479 
480   /// A map representing the ranges of the match/rewrite that a value is live in
481   /// the interpreter.
482   ///
483   /// We use std::unique_ptr here, because IntervalMap does not provide a
484   /// correct copy or move constructor. We can eliminate the pointer once
485   /// https://reviews.llvm.org/D113240 lands.
486   std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
487 
488   /// The operation range storage index for this range.
489   Optional<unsigned> opRangeIndex;
490 
491   /// The type range storage index for this range.
492   Optional<unsigned> typeRangeIndex;
493 
494   /// The value range storage index for this range.
495   Optional<unsigned> valueRangeIndex;
496 };
497 } // namespace
498 
499 void Generator::generate(ModuleOp module) {
500   auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
501       pdl_interp::PDLInterpDialect::getMatcherFunctionName());
502   ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
503       pdl_interp::PDLInterpDialect::getRewriterModuleName());
504   assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
505 
506   // Allocate memory indices for the results of operations within the matcher
507   // and rewriters.
508   allocateMemoryIndices(matcherFunc, rewriterModule);
509 
510   // Generate code for the rewriter functions.
511   ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
512   for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
513     rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
514     for (Operation &op : rewriterFunc.getOps())
515       generate(&op, rewriterByteCodeWriter);
516   }
517   assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
518          "unexpected branches in rewriter function");
519 
520   // Generate code for the matcher function.
521   ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
522   generate(&matcherFunc.getBody(), matcherByteCodeWriter);
523 
524   // Resolve successor references in the matcher.
525   for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
526     ByteCodeAddr addr = blockToAddr[it.first];
527     for (unsigned offsetToFix : it.second)
528       std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
529   }
530 }
531 
532 void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
533                                       ModuleOp rewriterModule) {
534   // Rewriters use simplistic allocation scheme that simply assigns an index to
535   // each result.
536   for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
537     ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
538     auto processRewriterValue = [&](Value val) {
539       valueToMemIndex.try_emplace(val, index++);
540       if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) {
541         Type elementTy = rangeType.getElementType();
542         if (elementTy.isa<pdl::TypeType>())
543           valueToRangeIndex.try_emplace(val, typeRangeIndex++);
544         else if (elementTy.isa<pdl::ValueType>())
545           valueToRangeIndex.try_emplace(val, valueRangeIndex++);
546       }
547     };
548 
549     for (BlockArgument arg : rewriterFunc.getArguments())
550       processRewriterValue(arg);
551     rewriterFunc.getBody().walk([&](Operation *op) {
552       for (Value result : op->getResults())
553         processRewriterValue(result);
554     });
555     if (index > maxValueMemoryIndex)
556       maxValueMemoryIndex = index;
557     if (typeRangeIndex > maxTypeRangeMemoryIndex)
558       maxTypeRangeMemoryIndex = typeRangeIndex;
559     if (valueRangeIndex > maxValueRangeMemoryIndex)
560       maxValueRangeMemoryIndex = valueRangeIndex;
561   }
562 
563   // The matcher function uses a more sophisticated numbering that tries to
564   // minimize the number of memory indices assigned. This is done by determining
565   // a live range of the values within the matcher, then the allocation is just
566   // finding the minimal number of overlapping live ranges. This is essentially
567   // a simplified form of register allocation where we don't necessarily have a
568   // limited number of registers, but we still want to minimize the number used.
569   DenseMap<Operation *, unsigned> opToFirstIndex;
570   DenseMap<Operation *, unsigned> opToLastIndex;
571 
572   // A custom walk that marks the first and the last index of each operation.
573   // The entry marks the beginning of the liveness range for this operation,
574   // followed by nested operations, followed by the end of the liveness range.
575   unsigned index = 0;
576   llvm::unique_function<void(Operation *)> walk = [&](Operation *op) {
577     opToFirstIndex.try_emplace(op, index++);
578     for (Region &region : op->getRegions())
579       for (Block &block : region.getBlocks())
580         for (Operation &nested : block)
581           walk(&nested);
582     opToLastIndex.try_emplace(op, index++);
583   };
584   walk(matcherFunc);
585 
586   // Liveness info for each of the defs within the matcher.
587   ByteCodeLiveRange::Allocator allocator;
588   DenseMap<Value, ByteCodeLiveRange> valueDefRanges;
589 
590   // Assign the root operation being matched to slot 0.
591   BlockArgument rootOpArg = matcherFunc.getArgument(0);
592   valueToMemIndex[rootOpArg] = 0;
593 
594   // Walk each of the blocks, computing the def interval that the value is used.
595   Liveness matcherLiveness(matcherFunc);
596   matcherFunc->walk([&](Block *block) {
597     const LivenessBlockInfo *info = matcherLiveness.getLiveness(block);
598     assert(info && "expected liveness info for block");
599     auto processValue = [&](Value value, Operation *firstUseOrDef) {
600       // We don't need to process the root op argument, this value is always
601       // assigned to the first memory slot.
602       if (value == rootOpArg)
603         return;
604 
605       // Set indices for the range of this block that the value is used.
606       auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
607       defRangeIt->second.liveness->insert(
608           opToFirstIndex[firstUseOrDef],
609           opToLastIndex[info->getEndOperation(value, firstUseOrDef)],
610           /*dummyValue*/ 0);
611 
612       // Check to see if this value is a range type.
613       if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) {
614         Type eleType = rangeTy.getElementType();
615         if (eleType.isa<pdl::OperationType>())
616           defRangeIt->second.opRangeIndex = 0;
617         else if (eleType.isa<pdl::TypeType>())
618           defRangeIt->second.typeRangeIndex = 0;
619         else if (eleType.isa<pdl::ValueType>())
620           defRangeIt->second.valueRangeIndex = 0;
621       }
622     };
623 
624     // Process the live-ins of this block.
625     for (Value liveIn : info->in()) {
626       // Only process the value if it has been defined in the current region.
627       // Other values that span across pdl_interp.foreach will be added higher
628       // up. This ensures that the we keep them alive for the entire duration
629       // of the loop.
630       if (liveIn.getParentRegion() == block->getParent())
631         processValue(liveIn, &block->front());
632     }
633 
634     // Process the block arguments for the entry block (those are not live-in).
635     if (block->isEntryBlock()) {
636       for (Value argument : block->getArguments())
637         processValue(argument, &block->front());
638     }
639 
640     // Process any new defs within this block.
641     for (Operation &op : *block)
642       for (Value result : op.getResults())
643         processValue(result, &op);
644   });
645 
646   // Greedily allocate memory slots using the computed def live ranges.
647   std::vector<ByteCodeLiveRange> allocatedIndices;
648 
649   // The number of memory indices currently allocated (and its next value).
650   // Recall that the root gets allocated memory index 0.
651   ByteCodeField numIndices = 1;
652 
653   // The number of memory ranges of various types (and their next values).
654   ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
655 
656   for (auto &defIt : valueDefRanges) {
657     ByteCodeField &memIndex = valueToMemIndex[defIt.first];
658     ByteCodeLiveRange &defRange = defIt.second;
659 
660     // Try to allocate to an existing index.
661     for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) {
662       ByteCodeLiveRange &existingRange = existingIndexIt.value();
663       if (!defRange.overlaps(existingRange)) {
664         existingRange.unionWith(defRange);
665         memIndex = existingIndexIt.index() + 1;
666 
667         if (defRange.opRangeIndex) {
668           if (!existingRange.opRangeIndex)
669             existingRange.opRangeIndex = numOpRanges++;
670           valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
671         } else if (defRange.typeRangeIndex) {
672           if (!existingRange.typeRangeIndex)
673             existingRange.typeRangeIndex = numTypeRanges++;
674           valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
675         } else if (defRange.valueRangeIndex) {
676           if (!existingRange.valueRangeIndex)
677             existingRange.valueRangeIndex = numValueRanges++;
678           valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
679         }
680         break;
681       }
682     }
683 
684     // If no existing index could be used, add a new one.
685     if (memIndex == 0) {
686       allocatedIndices.emplace_back(allocator);
687       ByteCodeLiveRange &newRange = allocatedIndices.back();
688       newRange.unionWith(defRange);
689 
690       // Allocate an index for op/type/value ranges.
691       if (defRange.opRangeIndex) {
692         newRange.opRangeIndex = numOpRanges;
693         valueToRangeIndex[defIt.first] = numOpRanges++;
694       } else if (defRange.typeRangeIndex) {
695         newRange.typeRangeIndex = numTypeRanges;
696         valueToRangeIndex[defIt.first] = numTypeRanges++;
697       } else if (defRange.valueRangeIndex) {
698         newRange.valueRangeIndex = numValueRanges;
699         valueToRangeIndex[defIt.first] = numValueRanges++;
700       }
701 
702       memIndex = allocatedIndices.size();
703       ++numIndices;
704     }
705   }
706 
707   // Print the index usage and ensure that we did not run out of index space.
708   LLVM_DEBUG({
709     llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices "
710                  << "(down from initial " << valueDefRanges.size() << ").\n";
711   });
712   assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
713          "Ran out of memory for allocated indices");
714 
715   // Update the max number of indices.
716   if (numIndices > maxValueMemoryIndex)
717     maxValueMemoryIndex = numIndices;
718   if (numOpRanges > maxOpRangeMemoryIndex)
719     maxOpRangeMemoryIndex = numOpRanges;
720   if (numTypeRanges > maxTypeRangeMemoryIndex)
721     maxTypeRangeMemoryIndex = numTypeRanges;
722   if (numValueRanges > maxValueRangeMemoryIndex)
723     maxValueRangeMemoryIndex = numValueRanges;
724 }
725 
726 void Generator::generate(Region *region, ByteCodeWriter &writer) {
727   llvm::ReversePostOrderTraversal<Region *> rpot(region);
728   for (Block *block : rpot) {
729     // Keep track of where this block begins within the matcher function.
730     blockToAddr.try_emplace(block, matcherByteCode.size());
731     for (Operation &op : *block)
732       generate(&op, writer);
733   }
734 }
735 
736 void Generator::generate(Operation *op, ByteCodeWriter &writer) {
737   LLVM_DEBUG({
738     // The following list must contain all the operations that do not
739     // produce any bytecode.
740     if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
741       writer.appendInline(op->getLoc());
742   });
743   TypeSwitch<Operation *>(op)
744       .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
745             pdl_interp::AreEqualOp, pdl_interp::BranchOp,
746             pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
747             pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
748             pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
749             pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
750             pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp,
751             pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
752             pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
753             pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
754             pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
755             pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
756             pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
757             pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
758             pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
759             pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
760             pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
761             pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
762             pdl_interp::SwitchResultCountOp>(
763           [&](auto interpOp) { this->generate(interpOp, writer); })
764       .Default([](Operation *) {
765         llvm_unreachable("unknown `pdl_interp` operation");
766       });
767 }
768 
769 void Generator::generate(pdl_interp::ApplyConstraintOp op,
770                          ByteCodeWriter &writer) {
771   assert(constraintToMemIndex.count(op.getName()) &&
772          "expected index for constraint function");
773   writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
774   writer.appendPDLValueList(op.getArgs());
775   writer.append(op.getSuccessors());
776 }
777 void Generator::generate(pdl_interp::ApplyRewriteOp op,
778                          ByteCodeWriter &writer) {
779   assert(externalRewriterToMemIndex.count(op.getName()) &&
780          "expected index for rewrite function");
781   writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
782   writer.appendPDLValueList(op.getArgs());
783 
784   ResultRange results = op.getResults();
785   writer.append(ByteCodeField(results.size()));
786   for (Value result : results) {
787     // In debug mode we also record the expected kind of the result, so that we
788     // can provide extra verification of the native rewrite function.
789 #ifndef NDEBUG
790     writer.appendPDLValueKind(result);
791 #endif
792 
793     // Range results also need to append the range storage index.
794     if (result.getType().isa<pdl::RangeType>())
795       writer.append(getRangeStorageIndex(result));
796     writer.append(result);
797   }
798 }
799 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
800   Value lhs = op.getLhs();
801   if (lhs.getType().isa<pdl::RangeType>()) {
802     writer.append(OpCode::AreRangesEqual);
803     writer.appendPDLValueKind(lhs);
804     writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
805     return;
806   }
807 
808   writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
809 }
810 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
811   writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
812 }
813 void Generator::generate(pdl_interp::CheckAttributeOp op,
814                          ByteCodeWriter &writer) {
815   writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
816                 op.getSuccessors());
817 }
818 void Generator::generate(pdl_interp::CheckOperandCountOp op,
819                          ByteCodeWriter &writer) {
820   writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
821                 static_cast<ByteCodeField>(op.getCompareAtLeast()),
822                 op.getSuccessors());
823 }
824 void Generator::generate(pdl_interp::CheckOperationNameOp op,
825                          ByteCodeWriter &writer) {
826   writer.append(OpCode::CheckOperationName, op.getInputOp(),
827                 OperationName(op.getName(), ctx), op.getSuccessors());
828 }
829 void Generator::generate(pdl_interp::CheckResultCountOp op,
830                          ByteCodeWriter &writer) {
831   writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
832                 static_cast<ByteCodeField>(op.getCompareAtLeast()),
833                 op.getSuccessors());
834 }
835 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
836   writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
837                 op.getSuccessors());
838 }
839 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
840   writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
841                 op.getSuccessors());
842 }
843 void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
844   assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
845   writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
846 }
847 void Generator::generate(pdl_interp::CreateAttributeOp op,
848                          ByteCodeWriter &writer) {
849   // Simply repoint the memory index of the result to the constant.
850   getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
851 }
852 void Generator::generate(pdl_interp::CreateOperationOp op,
853                          ByteCodeWriter &writer) {
854   writer.append(OpCode::CreateOperation, op.getResultOp(),
855                 OperationName(op.getName(), ctx));
856   writer.appendPDLValueList(op.getInputOperands());
857 
858   // Add the attributes.
859   OperandRange attributes = op.getInputAttributes();
860   writer.append(static_cast<ByteCodeField>(attributes.size()));
861   for (auto it : llvm::zip(op.getInputAttributeNames(), attributes))
862     writer.append(std::get<0>(it), std::get<1>(it));
863 
864   // Add the result types. If the operation has inferred results, we use a
865   // marker "size" value. Otherwise, we add the list of explicit result types.
866   if (op.getInferredResultTypes())
867     writer.append(kInferTypesMarker);
868   else
869     writer.appendPDLValueList(op.getInputResultTypes());
870 }
871 void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) {
872   // Append the correct opcode for the range type.
873   TypeSwitch<Type>(op.getType().getElementType())
874       .Case(
875           [&](pdl::TypeType) { writer.append(OpCode::CreateDynamicTypeRange); })
876       .Case([&](pdl::ValueType) {
877         writer.append(OpCode::CreateDynamicValueRange);
878       });
879 
880   writer.append(op.getResult(), getRangeStorageIndex(op.getResult()));
881   writer.appendPDLValueList(op->getOperands());
882 }
883 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
884   // Simply repoint the memory index of the result to the constant.
885   getMemIndex(op.getResult()) = getMemIndex(op.getValue());
886 }
887 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
888   writer.append(OpCode::CreateConstantTypeRange, op.getResult(),
889                 getRangeStorageIndex(op.getResult()), op.getValue());
890 }
891 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
892   writer.append(OpCode::EraseOp, op.getInputOp());
893 }
894 void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
895   OpCode opCode =
896       TypeSwitch<Type, OpCode>(op.getResult().getType())
897           .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
898           .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
899           .Case([](pdl::TypeType) { return OpCode::ExtractType; })
900           .Default([](Type) -> OpCode {
901             llvm_unreachable("unsupported element type");
902           });
903   writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
904 }
905 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
906   writer.append(OpCode::Finalize);
907 }
908 void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
909   BlockArgument arg = op.getLoopVariable();
910   writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
911   writer.appendPDLValueKind(arg.getType());
912   writer.append(curLoopLevel, op.getSuccessor());
913   ++curLoopLevel;
914   if (curLoopLevel > maxLoopLevel)
915     maxLoopLevel = curLoopLevel;
916   generate(&op.getRegion(), writer);
917   --curLoopLevel;
918 }
919 void Generator::generate(pdl_interp::GetAttributeOp op,
920                          ByteCodeWriter &writer) {
921   writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
922                 op.getNameAttr());
923 }
924 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
925                          ByteCodeWriter &writer) {
926   writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
927 }
928 void Generator::generate(pdl_interp::GetDefiningOpOp op,
929                          ByteCodeWriter &writer) {
930   writer.append(OpCode::GetDefiningOp, op.getInputOp());
931   writer.appendPDLValue(op.getValue());
932 }
933 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
934   uint32_t index = op.getIndex();
935   if (index < 4)
936     writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
937   else
938     writer.append(OpCode::GetOperandN, index);
939   writer.append(op.getInputOp(), op.getValue());
940 }
941 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
942   Value result = op.getValue();
943   Optional<uint32_t> index = op.getIndex();
944   writer.append(OpCode::GetOperands,
945                 index.value_or(std::numeric_limits<uint32_t>::max()),
946                 op.getInputOp());
947   if (result.getType().isa<pdl::RangeType>())
948     writer.append(getRangeStorageIndex(result));
949   else
950     writer.append(std::numeric_limits<ByteCodeField>::max());
951   writer.append(result);
952 }
953 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
954   uint32_t index = op.getIndex();
955   if (index < 4)
956     writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
957   else
958     writer.append(OpCode::GetResultN, index);
959   writer.append(op.getInputOp(), op.getValue());
960 }
961 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
962   Value result = op.getValue();
963   Optional<uint32_t> index = op.getIndex();
964   writer.append(OpCode::GetResults,
965                 index.value_or(std::numeric_limits<uint32_t>::max()),
966                 op.getInputOp());
967   if (result.getType().isa<pdl::RangeType>())
968     writer.append(getRangeStorageIndex(result));
969   else
970     writer.append(std::numeric_limits<ByteCodeField>::max());
971   writer.append(result);
972 }
973 void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
974   Value operations = op.getOperations();
975   ByteCodeField rangeIndex = getRangeStorageIndex(operations);
976   writer.append(OpCode::GetUsers, operations, rangeIndex);
977   writer.appendPDLValue(op.getValue());
978 }
979 void Generator::generate(pdl_interp::GetValueTypeOp op,
980                          ByteCodeWriter &writer) {
981   if (op.getType().isa<pdl::RangeType>()) {
982     Value result = op.getResult();
983     writer.append(OpCode::GetValueRangeTypes, result,
984                   getRangeStorageIndex(result), op.getValue());
985   } else {
986     writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
987   }
988 }
989 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
990   writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
991 }
992 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
993   ByteCodeField patternIndex = patterns.size();
994   patterns.emplace_back(PDLByteCodePattern::create(
995       op, configMap.lookup(op),
996       rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
997   writer.append(OpCode::RecordMatch, patternIndex,
998                 SuccessorRange(op.getOperation()), op.getMatchedOps());
999   writer.appendPDLValueList(op.getInputs());
1000 }
1001 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
1002   writer.append(OpCode::ReplaceOp, op.getInputOp());
1003   writer.appendPDLValueList(op.getReplValues());
1004 }
1005 void Generator::generate(pdl_interp::SwitchAttributeOp op,
1006                          ByteCodeWriter &writer) {
1007   writer.append(OpCode::SwitchAttribute, op.getAttribute(),
1008                 op.getCaseValuesAttr(), op.getSuccessors());
1009 }
1010 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
1011                          ByteCodeWriter &writer) {
1012   writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
1013                 op.getCaseValuesAttr(), op.getSuccessors());
1014 }
1015 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
1016                          ByteCodeWriter &writer) {
1017   auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) {
1018     return OperationName(attr.cast<StringAttr>().getValue(), ctx);
1019   });
1020   writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
1021                 op.getSuccessors());
1022 }
1023 void Generator::generate(pdl_interp::SwitchResultCountOp op,
1024                          ByteCodeWriter &writer) {
1025   writer.append(OpCode::SwitchResultCount, op.getInputOp(),
1026                 op.getCaseValuesAttr(), op.getSuccessors());
1027 }
1028 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
1029   writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
1030                 op.getSuccessors());
1031 }
1032 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
1033   writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
1034                 op.getSuccessors());
1035 }
1036 
1037 //===----------------------------------------------------------------------===//
1038 // PDLByteCode
1039 //===----------------------------------------------------------------------===//
1040 
1041 PDLByteCode::PDLByteCode(
1042     ModuleOp module, SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
1043     const DenseMap<Operation *, PDLPatternConfigSet *> &configMap,
1044     llvm::StringMap<PDLConstraintFunction> constraintFns,
1045     llvm::StringMap<PDLRewriteFunction> rewriteFns)
1046     : configs(std::move(configs)) {
1047   Generator generator(module.getContext(), uniquedData, matcherByteCode,
1048                       rewriterByteCode, patterns, maxValueMemoryIndex,
1049                       maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
1050                       maxLoopLevel, constraintFns, rewriteFns, configMap);
1051   generator.generate(module);
1052 
1053   // Initialize the external functions.
1054   for (auto &it : constraintFns)
1055     constraintFunctions.push_back(std::move(it.second));
1056   for (auto &it : rewriteFns)
1057     rewriteFunctions.push_back(std::move(it.second));
1058 }
1059 
1060 /// Initialize the given state such that it can be used to execute the current
1061 /// bytecode.
1062 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
1063   state.memory.resize(maxValueMemoryIndex, nullptr);
1064   state.opRangeMemory.resize(maxOpRangeCount);
1065   state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
1066   state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
1067   state.loopIndex.resize(maxLoopLevel, 0);
1068   state.currentPatternBenefits.reserve(patterns.size());
1069   for (const PDLByteCodePattern &pattern : patterns)
1070     state.currentPatternBenefits.push_back(pattern.getBenefit());
1071 }
1072 
//===----------------------------------------------------------------------===//
// ByteCode Execution
//===----------------------------------------------------------------------===//
1075 
1076 namespace {
1077 /// This class provides support for executing a bytecode stream.
1078 class ByteCodeExecutor {
1079 public:
1080   ByteCodeExecutor(
1081       const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
1082       MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory,
1083       MutableArrayRef<TypeRange> typeRangeMemory,
1084       std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
1085       MutableArrayRef<ValueRange> valueRangeMemory,
1086       std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
1087       MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory,
1088       ArrayRef<ByteCodeField> code,
1089       ArrayRef<PatternBenefit> currentPatternBenefits,
1090       ArrayRef<PDLByteCodePattern> patterns,
1091       ArrayRef<PDLConstraintFunction> constraintFunctions,
1092       ArrayRef<PDLRewriteFunction> rewriteFunctions)
1093       : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
1094         typeRangeMemory(typeRangeMemory),
1095         allocatedTypeRangeMemory(allocatedTypeRangeMemory),
1096         valueRangeMemory(valueRangeMemory),
1097         allocatedValueRangeMemory(allocatedValueRangeMemory),
1098         loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
1099         currentPatternBenefits(currentPatternBenefits), patterns(patterns),
1100         constraintFunctions(constraintFunctions),
1101         rewriteFunctions(rewriteFunctions) {}
1102 
1103   /// Start executing the code at the current bytecode index. `matches` is an
1104   /// optional field provided when this function is executed in a matching
1105   /// context.
1106   LogicalResult
1107   execute(PatternRewriter &rewriter,
1108           SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
1109           Optional<Location> mainRewriteLoc = {});
1110 
1111 private:
1112   /// Internal implementation of executing each of the bytecode commands.
1113   void executeApplyConstraint(PatternRewriter &rewriter);
1114   LogicalResult executeApplyRewrite(PatternRewriter &rewriter);
1115   void executeAreEqual();
1116   void executeAreRangesEqual();
1117   void executeBranch();
1118   void executeCheckOperandCount();
1119   void executeCheckOperationName();
1120   void executeCheckResultCount();
1121   void executeCheckTypes();
1122   void executeContinue();
1123   void executeCreateConstantTypeRange();
1124   void executeCreateOperation(PatternRewriter &rewriter,
1125                               Location mainRewriteLoc);
1126   template <typename T>
1127   void executeDynamicCreateRange(StringRef type);
1128   void executeEraseOp(PatternRewriter &rewriter);
1129   template <typename T, typename Range, PDLValue::Kind kind>
1130   void executeExtract();
1131   void executeFinalize();
1132   void executeForEach();
1133   void executeGetAttribute();
1134   void executeGetAttributeType();
1135   void executeGetDefiningOp();
1136   void executeGetOperand(unsigned index);
1137   void executeGetOperands();
1138   void executeGetResult(unsigned index);
1139   void executeGetResults();
1140   void executeGetUsers();
1141   void executeGetValueType();
1142   void executeGetValueRangeTypes();
1143   void executeIsNotNull();
1144   void executeRecordMatch(PatternRewriter &rewriter,
1145                           SmallVectorImpl<PDLByteCode::MatchResult> &matches);
1146   void executeReplaceOp(PatternRewriter &rewriter);
1147   void executeSwitchAttribute();
1148   void executeSwitchOperandCount();
1149   void executeSwitchOperationName();
1150   void executeSwitchResultCount();
1151   void executeSwitchType();
1152   void executeSwitchTypes();
1153 
1154   /// Pushes a code iterator to the stack.
1155   void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }
1156 
1157   /// Pops a code iterator from the stack, returning true on success.
1158   void popCodeIt() {
1159     assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack");
1160     curCodeIt = resumeCodeIt.back();
1161     resumeCodeIt.pop_back();
1162   }
1163 
1164   /// Return the bytecode iterator at the start of the current op code.
1165   const ByteCodeField *getPrevCodeIt() const {
1166     LLVM_DEBUG({
1167       // Account for the op code and the Location stored inline.
1168       return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField);
1169     });
1170 
1171     // Account for the op code only.
1172     return curCodeIt - 1;
1173   }
1174 
1175   /// Read a value from the bytecode buffer, optionally skipping a certain
1176   /// number of prefix values. These methods always update the buffer to point
1177   /// to the next field after the read data.
1178   template <typename T = ByteCodeField>
1179   T read(size_t skipN = 0) {
1180     curCodeIt += skipN;
1181     return readImpl<T>();
1182   }
1183   ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }
1184 
1185   /// Read a list of values from the bytecode buffer.
1186   template <typename ValueT, typename T>
1187   void readList(SmallVectorImpl<T> &list) {
1188     list.clear();
1189     for (unsigned i = 0, e = read(); i != e; ++i)
1190       list.push_back(read<ValueT>());
1191   }
1192 
1193   /// Read a list of values from the bytecode buffer. The values may be encoded
1194   /// either as a single element or a range of elements.
1195   void readList(SmallVectorImpl<Type> &list) {
1196     for (unsigned i = 0, e = read(); i != e; ++i) {
1197       if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1198         list.push_back(read<Type>());
1199       } else {
1200         TypeRange *values = read<TypeRange *>();
1201         list.append(values->begin(), values->end());
1202       }
1203     }
1204   }
1205   void readList(SmallVectorImpl<Value> &list) {
1206     for (unsigned i = 0, e = read(); i != e; ++i) {
1207       if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1208         list.push_back(read<Value>());
1209       } else {
1210         ValueRange *values = read<ValueRange *>();
1211         list.append(values->begin(), values->end());
1212       }
1213     }
1214   }
1215 
1216   /// Read a value stored inline as a pointer.
1217   template <typename T>
1218   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T>
1219   readInline() {
1220     const void *pointer;
1221     std::memcpy(&pointer, curCodeIt, sizeof(const void *));
1222     curCodeIt += sizeof(const void *) / sizeof(ByteCodeField);
1223     return T::getFromOpaquePointer(pointer);
1224   }
1225 
1226   /// Jump to a specific successor based on a predicate value.
1227   void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
1228   /// Jump to a specific successor based on a destination index.
1229   void selectJump(size_t destIndex) {
1230     curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
1231   }
1232 
1233   /// Handle a switch operation with the provided value and cases.
1234   template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
1235   void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
1236     LLVM_DEBUG({
1237       llvm::dbgs() << "  * Value: " << value << "\n"
1238                    << "  * Cases: ";
1239       llvm::interleaveComma(cases, llvm::dbgs());
1240       llvm::dbgs() << "\n";
1241     });
1242 
1243     // Check to see if the attribute value is within the case list. Jump to
1244     // the correct successor index based on the result.
1245     for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
1246       if (cmp(*it, value))
1247         return selectJump(size_t((it - cases.begin()) + 1));
1248     selectJump(size_t(0));
1249   }
1250 
1251   /// Store a pointer to memory.
1252   void storeToMemory(unsigned index, const void *value) {
1253     memory[index] = value;
1254   }
1255 
1256   /// Store a value to memory as an opaque pointer.
1257   template <typename T>
1258   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
1259   storeToMemory(unsigned index, T value) {
1260     memory[index] = value.getAsOpaquePointer();
1261   }
1262 
1263   /// Internal implementation of reading various data types from the bytecode
1264   /// stream.
1265   template <typename T>
1266   const void *readFromMemory() {
1267     size_t index = *curCodeIt++;
1268 
1269     // If this type is an SSA value, it can only be stored in non-const memory.
1270     if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
1271                         Value>::value ||
1272         index < memory.size())
1273       return memory[index];
1274 
1275     // Otherwise, if this index is not inbounds it is uniqued.
1276     return uniquedMemory[index - memory.size()];
1277   }
1278   template <typename T>
1279   std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
1280     return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
1281   }
1282   template <typename T>
1283   std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
1284                    T>
1285   readImpl() {
1286     return T(T::getFromOpaquePointer(readFromMemory<T>()));
1287   }
1288   template <typename T>
1289   std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
1290     switch (read<PDLValue::Kind>()) {
1291     case PDLValue::Kind::Attribute:
1292       return read<Attribute>();
1293     case PDLValue::Kind::Operation:
1294       return read<Operation *>();
1295     case PDLValue::Kind::Type:
1296       return read<Type>();
1297     case PDLValue::Kind::Value:
1298       return read<Value>();
1299     case PDLValue::Kind::TypeRange:
1300       return read<TypeRange *>();
1301     case PDLValue::Kind::ValueRange:
1302       return read<ValueRange *>();
1303     }
1304     llvm_unreachable("unhandled PDLValue::Kind");
1305   }
1306   template <typename T>
1307   std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
1308     static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
1309                   "unexpected ByteCode address size");
1310     ByteCodeAddr result;
1311     std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
1312     curCodeIt += 2;
1313     return result;
1314   }
1315   template <typename T>
1316   std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
1317     return *curCodeIt++;
1318   }
1319   template <typename T>
1320   std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
1321     return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
1322   }
1323 
1324   /// Assign the given range to the given memory index. This allocates a new
1325   /// range object if necessary.
1326   template <typename RangeT, typename T = llvm::detail::ValueOfRange<RangeT>>
1327   void assignRangeToMemory(RangeT &&range, unsigned memIndex,
1328                            unsigned rangeIndex) {
1329     // Utility functor used to type-erase the assignment.
1330     auto assignRange = [&](auto &allocatedRangeMemory, auto &rangeMemory) {
1331       // If the input range is empty, we don't need to allocate anything.
1332       if (range.empty()) {
1333         rangeMemory[rangeIndex] = {};
1334       } else {
1335         // Allocate a buffer for this type range.
1336         llvm::OwningArrayRef<T> storage(llvm::size(range));
1337         llvm::copy(range, storage.begin());
1338 
1339         // Assign this to the range slot and use the range as the value for the
1340         // memory index.
1341         allocatedRangeMemory.emplace_back(std::move(storage));
1342         rangeMemory[rangeIndex] = allocatedRangeMemory.back();
1343       }
1344       memory[memIndex] = &rangeMemory[rangeIndex];
1345     };
1346 
1347     // Dispatch based on the concrete range type.
1348     if constexpr (std::is_same_v<T, Type>) {
1349       return assignRange(allocatedTypeRangeMemory, typeRangeMemory);
1350     } else if constexpr (std::is_same_v<T, Value>) {
1351       return assignRange(allocatedValueRangeMemory, valueRangeMemory);
1352     } else {
1353       llvm_unreachable("unhandled range type");
1354     }
1355   }
1356 
1357   /// The underlying bytecode buffer.
1358   const ByteCodeField *curCodeIt;
1359 
1360   /// The stack of bytecode positions at which to resume operation.
1361   SmallVector<const ByteCodeField *> resumeCodeIt;
1362 
1363   /// The current execution memory.
1364   MutableArrayRef<const void *> memory;
1365   MutableArrayRef<OwningOpRange> opRangeMemory;
1366   MutableArrayRef<TypeRange> typeRangeMemory;
1367   std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
1368   MutableArrayRef<ValueRange> valueRangeMemory;
1369   std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;
1370 
1371   /// The current loop indices.
1372   MutableArrayRef<unsigned> loopIndex;
1373 
1374   /// References to ByteCode data necessary for execution.
1375   ArrayRef<const void *> uniquedMemory;
1376   ArrayRef<ByteCodeField> code;
1377   ArrayRef<PatternBenefit> currentPatternBenefits;
1378   ArrayRef<PDLByteCodePattern> patterns;
1379   ArrayRef<PDLConstraintFunction> constraintFunctions;
1380   ArrayRef<PDLRewriteFunction> rewriteFunctions;
1381 };
1382 
1383 /// This class is an instantiation of the PDLResultList that provides access to
1384 /// the returned results. This API is not on `PDLResultList` to avoid
1385 /// overexposing access to information specific solely to the ByteCode.
1386 class ByteCodeRewriteResultList : public PDLResultList {
1387 public:
1388   ByteCodeRewriteResultList(unsigned maxNumResults)
1389       : PDLResultList(maxNumResults) {}
1390 
1391   /// Return the list of PDL results.
1392   MutableArrayRef<PDLValue> getResults() { return results; }
1393 
1394   /// Return the type ranges allocated by this list.
1395   MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
1396     return allocatedTypeRanges;
1397   }
1398 
1399   /// Return the value ranges allocated by this list.
1400   MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
1401     return allocatedValueRanges;
1402   }
1403 };
1404 } // namespace
1405 
1406 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1407   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
1408   const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
1409   SmallVector<PDLValue, 16> args;
1410   readList<PDLValue>(args);
1411 
1412   LLVM_DEBUG({
1413     llvm::dbgs() << "  * Arguments: ";
1414     llvm::interleaveComma(args, llvm::dbgs());
1415   });
1416 
1417   // Invoke the constraint and jump to the proper destination.
1418   selectJump(succeeded(constraintFn(rewriter, args)));
1419 }
1420 
1421 LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
1422   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
1423   const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1424   SmallVector<PDLValue, 16> args;
1425   readList<PDLValue>(args);
1426 
1427   LLVM_DEBUG({
1428     llvm::dbgs() << "  * Arguments: ";
1429     llvm::interleaveComma(args, llvm::dbgs());
1430   });
1431 
1432   // Execute the rewrite function.
1433   ByteCodeField numResults = read();
1434   ByteCodeRewriteResultList results(numResults);
1435   LogicalResult rewriteResult = rewriteFn(rewriter, results, args);
1436 
1437   assert(results.getResults().size() == numResults &&
1438          "native PDL rewrite function returned unexpected number of results");
1439 
1440   // Store the results in the bytecode memory.
1441   for (PDLValue &result : results.getResults()) {
1442     LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n");
1443 
1444 // In debug mode we also verify the expected kind of the result.
1445 #ifndef NDEBUG
1446     assert(result.getKind() == read<PDLValue::Kind>() &&
1447            "native PDL rewrite function returned an unexpected type of result");
1448 #endif
1449 
1450     // If the result is a range, we need to copy it over to the bytecodes
1451     // range memory.
1452     if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
1453       unsigned rangeIndex = read();
1454       typeRangeMemory[rangeIndex] = *typeRange;
1455       memory[read()] = &typeRangeMemory[rangeIndex];
1456     } else if (Optional<ValueRange> valueRange =
1457                    result.dyn_cast<ValueRange>()) {
1458       unsigned rangeIndex = read();
1459       valueRangeMemory[rangeIndex] = *valueRange;
1460       memory[read()] = &valueRangeMemory[rangeIndex];
1461     } else {
1462       memory[read()] = result.getAsOpaquePointer();
1463     }
1464   }
1465 
1466   // Copy over any underlying storage allocated for result ranges.
1467   for (auto &it : results.getAllocatedTypeRanges())
1468     allocatedTypeRangeMemory.push_back(std::move(it));
1469   for (auto &it : results.getAllocatedValueRanges())
1470     allocatedValueRangeMemory.push_back(std::move(it));
1471 
1472   // Process the result of the rewrite.
1473   if (failed(rewriteResult)) {
1474     LLVM_DEBUG(llvm::dbgs() << "  - Failed");
1475     return failure();
1476   }
1477   return success();
1478 }
1479 
1480 void ByteCodeExecutor::executeAreEqual() {
1481   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1482   const void *lhs = read<const void *>();
1483   const void *rhs = read<const void *>();
1484 
1485   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n");
1486   selectJump(lhs == rhs);
1487 }
1488 
1489 void ByteCodeExecutor::executeAreRangesEqual() {
1490   LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
1491   PDLValue::Kind valueKind = read<PDLValue::Kind>();
1492   const void *lhs = read<const void *>();
1493   const void *rhs = read<const void *>();
1494 
1495   switch (valueKind) {
1496   case PDLValue::Kind::TypeRange: {
1497     const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
1498     const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
1499     LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1500     selectJump(*lhsRange == *rhsRange);
1501     break;
1502   }
1503   case PDLValue::Kind::ValueRange: {
1504     const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
1505     const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
1506     LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1507     selectJump(*lhsRange == *rhsRange);
1508     break;
1509   }
1510   default:
1511     llvm_unreachable("unexpected `AreRangesEqual` value kind");
1512   }
1513 }
1514 
1515 void ByteCodeExecutor::executeBranch() {
1516   LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
1517   curCodeIt = &code[read<ByteCodeAddr>()];
1518 }
1519 
1520 void ByteCodeExecutor::executeCheckOperandCount() {
1521   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
1522   Operation *op = read<Operation *>();
1523   uint32_t expectedCount = read<uint32_t>();
1524   bool compareAtLeast = read();
1525 
1526   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumOperands() << "\n"
1527                           << "  * Expected: " << expectedCount << "\n"
1528                           << "  * Comparator: "
1529                           << (compareAtLeast ? ">=" : "==") << "\n");
1530   if (compareAtLeast)
1531     selectJump(op->getNumOperands() >= expectedCount);
1532   else
1533     selectJump(op->getNumOperands() == expectedCount);
1534 }
1535 
1536 void ByteCodeExecutor::executeCheckOperationName() {
1537   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
1538   Operation *op = read<Operation *>();
1539   OperationName expectedName = read<OperationName>();
1540 
1541   LLVM_DEBUG(llvm::dbgs() << "  * Found: \"" << op->getName() << "\"\n"
1542                           << "  * Expected: \"" << expectedName << "\"\n");
1543   selectJump(op->getName() == expectedName);
1544 }
1545 
1546 void ByteCodeExecutor::executeCheckResultCount() {
1547   LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
1548   Operation *op = read<Operation *>();
1549   uint32_t expectedCount = read<uint32_t>();
1550   bool compareAtLeast = read();
1551 
1552   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumResults() << "\n"
1553                           << "  * Expected: " << expectedCount << "\n"
1554                           << "  * Comparator: "
1555                           << (compareAtLeast ? ">=" : "==") << "\n");
1556   if (compareAtLeast)
1557     selectJump(op->getNumResults() >= expectedCount);
1558   else
1559     selectJump(op->getNumResults() == expectedCount);
1560 }
1561 
1562 void ByteCodeExecutor::executeCheckTypes() {
1563   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1564   TypeRange *lhs = read<TypeRange *>();
1565   Attribute rhs = read<Attribute>();
1566   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1567 
1568   selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>());
1569 }
1570 
1571 void ByteCodeExecutor::executeContinue() {
1572   ByteCodeField level = read();
1573   LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n"
1574                           << "  * Level: " << level << "\n");
1575   ++loopIndex[level];
1576   popCodeIt();
1577 }
1578 
1579 void ByteCodeExecutor::executeCreateConstantTypeRange() {
1580   LLVM_DEBUG(llvm::dbgs() << "Executing CreateConstantTypeRange:\n");
1581   unsigned memIndex = read();
1582   unsigned rangeIndex = read();
1583   ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>();
1584 
1585   LLVM_DEBUG(llvm::dbgs() << "  * Types: " << typesAttr << "\n\n");
1586   assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex,
1587                       rangeIndex);
1588 }
1589 
1590 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
1591                                               Location mainRewriteLoc) {
1592   LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
1593 
1594   unsigned memIndex = read();
1595   OperationState state(mainRewriteLoc, read<OperationName>());
1596   readList(state.operands);
1597   for (unsigned i = 0, e = read(); i != e; ++i) {
1598     StringAttr name = read<StringAttr>();
1599     if (Attribute attr = read<Attribute>())
1600       state.addAttribute(name, attr);
1601   }
1602 
1603   // Read in the result types. If the "size" is the sentinel value, this
1604   // indicates that the result types should be inferred.
1605   unsigned numResults = read();
1606   if (numResults == kInferTypesMarker) {
1607     InferTypeOpInterface::Concept *inferInterface =
1608         state.name.getRegisteredInfo()->getInterface<InferTypeOpInterface>();
1609     assert(inferInterface &&
1610            "expected operation to provide InferTypeOpInterface");
1611 
1612     // TODO: Handle failure.
1613     if (failed(inferInterface->inferReturnTypes(
1614             state.getContext(), state.location, state.operands,
1615             state.attributes.getDictionary(state.getContext()), state.regions,
1616             state.types)))
1617       return;
1618   } else {
1619     // Otherwise, this is a fixed number of results.
1620     for (unsigned i = 0; i != numResults; ++i) {
1621       if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1622         state.types.push_back(read<Type>());
1623       } else {
1624         TypeRange *resultTypes = read<TypeRange *>();
1625         state.types.append(resultTypes->begin(), resultTypes->end());
1626       }
1627     }
1628   }
1629 
1630   Operation *resultOp = rewriter.create(state);
1631   memory[memIndex] = resultOp;
1632 
1633   LLVM_DEBUG({
1634     llvm::dbgs() << "  * Attributes: "
1635                  << state.attributes.getDictionary(state.getContext())
1636                  << "\n  * Operands: ";
1637     llvm::interleaveComma(state.operands, llvm::dbgs());
1638     llvm::dbgs() << "\n  * Result Types: ";
1639     llvm::interleaveComma(state.types, llvm::dbgs());
1640     llvm::dbgs() << "\n  * Result: " << *resultOp << "\n";
1641   });
1642 }
1643 
1644 template <typename T>
1645 void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) {
1646   LLVM_DEBUG(llvm::dbgs() << "Executing CreateDynamic" << type << "Range:\n");
1647   unsigned memIndex = read();
1648   unsigned rangeIndex = read();
1649   SmallVector<T> values;
1650   readList(values);
1651 
1652   LLVM_DEBUG({
1653     llvm::dbgs() << "\n  * " << type << "s: ";
1654     llvm::interleaveComma(values, llvm::dbgs());
1655     llvm::dbgs() << "\n";
1656   });
1657 
1658   assignRangeToMemory(values, memIndex, rangeIndex);
1659 }
1660 
1661 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1662   LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1663   Operation *op = read<Operation *>();
1664 
1665   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1666   rewriter.eraseOp(op);
1667 }
1668 
1669 template <typename T, typename Range, PDLValue::Kind kind>
1670 void ByteCodeExecutor::executeExtract() {
1671   LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n");
1672   Range *range = read<Range *>();
1673   unsigned index = read<uint32_t>();
1674   unsigned memIndex = read();
1675 
1676   if (!range) {
1677     memory[memIndex] = nullptr;
1678     return;
1679   }
1680 
1681   T result = index < range->size() ? (*range)[index] : T();
1682   LLVM_DEBUG(llvm::dbgs() << "  * " << kind << "s(" << range->size() << ")\n"
1683                           << "  * Index: " << index << "\n"
1684                           << "  * Result: " << result << "\n");
1685   storeToMemory(memIndex, result);
1686 }
1687 
1688 void ByteCodeExecutor::executeFinalize() {
1689   LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n");
1690 }
1691 
1692 void ByteCodeExecutor::executeForEach() {
1693   LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n");
1694   const ByteCodeField *prevCodeIt = getPrevCodeIt();
1695   unsigned rangeIndex = read();
1696   unsigned memIndex = read();
1697   const void *value = nullptr;
1698 
1699   switch (read<PDLValue::Kind>()) {
1700   case PDLValue::Kind::Operation: {
1701     unsigned &index = loopIndex[read()];
1702     ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
1703     assert(index <= array.size() && "iterated past the end");
1704     if (index < array.size()) {
1705       LLVM_DEBUG(llvm::dbgs() << "  * Result: " << array[index] << "\n");
1706       value = array[index];
1707       break;
1708     }
1709 
1710     LLVM_DEBUG(llvm::dbgs() << "  * Done\n");
1711     index = 0;
1712     selectJump(size_t(0));
1713     return;
1714   }
1715   default:
1716     llvm_unreachable("unexpected `ForEach` value kind");
1717   }
1718 
1719   // Store the iterate value and the stack address.
1720   memory[memIndex] = value;
1721   pushCodeIt(prevCodeIt);
1722 
1723   // Skip over the successor (we will enter the body of the loop).
1724   read<ByteCodeAddr>();
1725 }
1726 
1727 void ByteCodeExecutor::executeGetAttribute() {
1728   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1729   unsigned memIndex = read();
1730   Operation *op = read<Operation *>();
1731   StringAttr attrName = read<StringAttr>();
1732   Attribute attr = op->getAttr(attrName);
1733 
1734   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1735                           << "  * Attribute: " << attrName << "\n"
1736                           << "  * Result: " << attr << "\n");
1737   memory[memIndex] = attr.getAsOpaquePointer();
1738 }
1739 
1740 void ByteCodeExecutor::executeGetAttributeType() {
1741   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1742   unsigned memIndex = read();
1743   Attribute attr = read<Attribute>();
1744   Type type;
1745   if (auto typedAttr = attr.dyn_cast<TypedAttr>())
1746     type = typedAttr.getType();
1747 
1748   LLVM_DEBUG(llvm::dbgs() << "  * Attribute: " << attr << "\n"
1749                           << "  * Result: " << type << "\n");
1750   memory[memIndex] = type.getAsOpaquePointer();
1751 }
1752 
1753 void ByteCodeExecutor::executeGetDefiningOp() {
1754   LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1755   unsigned memIndex = read();
1756   Operation *op = nullptr;
1757   if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1758     Value value = read<Value>();
1759     if (value)
1760       op = value.getDefiningOp();
1761     LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1762   } else {
1763     ValueRange *values = read<ValueRange *>();
1764     if (values && !values->empty()) {
1765       op = values->front().getDefiningOp();
1766     }
1767     LLVM_DEBUG(llvm::dbgs() << "  * Values: " << values << "\n");
1768   }
1769 
1770   LLVM_DEBUG(llvm::dbgs() << "  * Result: " << op << "\n");
1771   memory[memIndex] = op;
1772 }
1773 
1774 void ByteCodeExecutor::executeGetOperand(unsigned index) {
1775   Operation *op = read<Operation *>();
1776   unsigned memIndex = read();
1777   Value operand =
1778       index < op->getNumOperands() ? op->getOperand(index) : Value();
1779 
1780   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1781                           << "  * Index: " << index << "\n"
1782                           << "  * Result: " << operand << "\n");
1783   memory[memIndex] = operand.getAsOpaquePointer();
1784 }
1785 
1786 /// This function is the internal implementation of `GetResults` and
1787 /// `GetOperands` that provides support for extracting a value range from the
1788 /// given operation.
1789 template <template <typename> class AttrSizedSegmentsT, typename RangeT>
1790 static void *
1791 executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
1792                           ByteCodeField rangeIndex, StringRef attrSizedSegments,
1793                           MutableArrayRef<ValueRange> valueRangeMemory) {
1794   // Check for the sentinel index that signals that all values should be
1795   // returned.
1796   if (index == std::numeric_limits<uint32_t>::max()) {
1797     LLVM_DEBUG(llvm::dbgs() << "  * Getting all values\n");
1798     // `values` is already the full value range.
1799 
1800     // Otherwise, check to see if this operation uses AttrSizedSegments.
1801   } else if (op->hasTrait<AttrSizedSegmentsT>()) {
1802     LLVM_DEBUG(llvm::dbgs()
1803                << "  * Extracting values from `" << attrSizedSegments << "`\n");
1804 
1805     auto segmentAttr = op->getAttrOfType<DenseI32ArrayAttr>(attrSizedSegments);
1806     if (!segmentAttr || segmentAttr.asArrayRef().size() <= index)
1807       return nullptr;
1808 
1809     ArrayRef<int32_t> segments = segmentAttr;
1810     unsigned startIndex =
1811         std::accumulate(segments.begin(), segments.begin() + index, 0);
1812     values = values.slice(startIndex, *std::next(segments.begin(), index));
1813 
1814     LLVM_DEBUG(llvm::dbgs() << "  * Extracting range[" << startIndex << ", "
1815                             << *std::next(segments.begin(), index) << "]\n");
1816 
1817     // Otherwise, assume this is the last operand group of the operation.
1818     // FIXME: We currently don't support operations with
1819     // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
1820     // have a way to detect it's presence.
1821   } else if (values.size() >= index) {
1822     LLVM_DEBUG(llvm::dbgs()
1823                << "  * Treating values as trailing variadic range\n");
1824     values = values.drop_front(index);
1825 
1826     // If we couldn't detect a way to compute the values, bail out.
1827   } else {
1828     return nullptr;
1829   }
1830 
1831   // If the range index is valid, we are returning a range.
1832   if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
1833     valueRangeMemory[rangeIndex] = values;
1834     return &valueRangeMemory[rangeIndex];
1835   }
1836 
1837   // If a range index wasn't provided, the range is required to be non-variadic.
1838   return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1839 }
1840 
1841 void ByteCodeExecutor::executeGetOperands() {
1842   LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
1843   unsigned index = read<uint32_t>();
1844   Operation *op = read<Operation *>();
1845   ByteCodeField rangeIndex = read();
1846 
1847   void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1848       op->getOperands(), op, index, rangeIndex, "operand_segment_sizes",
1849       valueRangeMemory);
1850   if (!result)
1851     LLVM_DEBUG(llvm::dbgs() << "  * Invalid operand range\n");
1852   memory[read()] = result;
1853 }
1854 
1855 void ByteCodeExecutor::executeGetResult(unsigned index) {
1856   Operation *op = read<Operation *>();
1857   unsigned memIndex = read();
1858   OpResult result =
1859       index < op->getNumResults() ? op->getResult(index) : OpResult();
1860 
1861   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1862                           << "  * Index: " << index << "\n"
1863                           << "  * Result: " << result << "\n");
1864   memory[memIndex] = result.getAsOpaquePointer();
1865 }
1866 
1867 void ByteCodeExecutor::executeGetResults() {
1868   LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
1869   unsigned index = read<uint32_t>();
1870   Operation *op = read<Operation *>();
1871   ByteCodeField rangeIndex = read();
1872 
1873   void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1874       op->getResults(), op, index, rangeIndex, "result_segment_sizes",
1875       valueRangeMemory);
1876   if (!result)
1877     LLVM_DEBUG(llvm::dbgs() << "  * Invalid result range\n");
1878   memory[read()] = result;
1879 }
1880 
1881 void ByteCodeExecutor::executeGetUsers() {
1882   LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n");
1883   unsigned memIndex = read();
1884   unsigned rangeIndex = read();
1885   OwningOpRange &range = opRangeMemory[rangeIndex];
1886   memory[memIndex] = &range;
1887 
1888   range = OwningOpRange();
1889   if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1890     // Read the value.
1891     Value value = read<Value>();
1892     if (!value)
1893       return;
1894     LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1895 
1896     // Extract the users of a single value.
1897     range = OwningOpRange(std::distance(value.user_begin(), value.user_end()));
1898     llvm::copy(value.getUsers(), range.begin());
1899   } else {
1900     // Read a range of values.
1901     ValueRange *values = read<ValueRange *>();
1902     if (!values)
1903       return;
1904     LLVM_DEBUG({
1905       llvm::dbgs() << "  * Values (" << values->size() << "): ";
1906       llvm::interleaveComma(*values, llvm::dbgs());
1907       llvm::dbgs() << "\n";
1908     });
1909 
1910     // Extract all the users of a range of values.
1911     SmallVector<Operation *> users;
1912     for (Value value : *values)
1913       users.append(value.user_begin(), value.user_end());
1914     range = OwningOpRange(users.size());
1915     llvm::copy(users, range.begin());
1916   }
1917 
1918   LLVM_DEBUG(llvm::dbgs() << "  * Result: " << range.size() << " operations\n");
1919 }
1920 
1921 void ByteCodeExecutor::executeGetValueType() {
1922   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1923   unsigned memIndex = read();
1924   Value value = read<Value>();
1925   Type type = value ? value.getType() : Type();
1926 
1927   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n"
1928                           << "  * Result: " << type << "\n");
1929   memory[memIndex] = type.getAsOpaquePointer();
1930 }
1931 
1932 void ByteCodeExecutor::executeGetValueRangeTypes() {
1933   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
1934   unsigned memIndex = read();
1935   unsigned rangeIndex = read();
1936   ValueRange *values = read<ValueRange *>();
1937   if (!values) {
1938     LLVM_DEBUG(llvm::dbgs() << "  * Values: <NULL>\n\n");
1939     memory[memIndex] = nullptr;
1940     return;
1941   }
1942 
1943   LLVM_DEBUG({
1944     llvm::dbgs() << "  * Values (" << values->size() << "): ";
1945     llvm::interleaveComma(*values, llvm::dbgs());
1946     llvm::dbgs() << "\n  * Result: ";
1947     llvm::interleaveComma(values->getType(), llvm::dbgs());
1948     llvm::dbgs() << "\n";
1949   });
1950   typeRangeMemory[rangeIndex] = values->getType();
1951   memory[memIndex] = &typeRangeMemory[rangeIndex];
1952 }
1953 
1954 void ByteCodeExecutor::executeIsNotNull() {
1955   LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
1956   const void *value = read<const void *>();
1957 
1958   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1959   selectJump(value != nullptr);
1960 }
1961 
1962 void ByteCodeExecutor::executeRecordMatch(
1963     PatternRewriter &rewriter,
1964     SmallVectorImpl<PDLByteCode::MatchResult> &matches) {
1965   LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
1966   unsigned patternIndex = read();
1967   PatternBenefit benefit = currentPatternBenefits[patternIndex];
1968   const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1969 
1970   // If the benefit of the pattern is impossible, skip the processing of the
1971   // rest of the pattern.
1972   if (benefit.isImpossibleToMatch()) {
1973     LLVM_DEBUG(llvm::dbgs() << "  * Benefit: Impossible To Match\n");
1974     curCodeIt = dest;
1975     return;
1976   }
1977 
1978   // Create a fused location containing the locations of each of the
1979   // operations used in the match. This will be used as the location for
1980   // created operations during the rewrite that don't already have an
1981   // explicit location set.
1982   unsigned numMatchLocs = read();
1983   SmallVector<Location, 4> matchLocs;
1984   matchLocs.reserve(numMatchLocs);
1985   for (unsigned i = 0; i != numMatchLocs; ++i)
1986     matchLocs.push_back(read<Operation *>()->getLoc());
1987   Location matchLoc = rewriter.getFusedLoc(matchLocs);
1988 
1989   LLVM_DEBUG(llvm::dbgs() << "  * Benefit: " << benefit.getBenefit() << "\n"
1990                           << "  * Location: " << matchLoc << "\n");
1991   matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
1992   PDLByteCode::MatchResult &match = matches.back();
1993 
1994   // Record all of the inputs to the match. If any of the inputs are ranges, we
1995   // will also need to remap the range pointer to memory stored in the match
1996   // state.
1997   unsigned numInputs = read();
1998   match.values.reserve(numInputs);
1999   match.typeRangeValues.reserve(numInputs);
2000   match.valueRangeValues.reserve(numInputs);
2001   for (unsigned i = 0; i < numInputs; ++i) {
2002     switch (read<PDLValue::Kind>()) {
2003     case PDLValue::Kind::TypeRange:
2004       match.typeRangeValues.push_back(*read<TypeRange *>());
2005       match.values.push_back(&match.typeRangeValues.back());
2006       break;
2007     case PDLValue::Kind::ValueRange:
2008       match.valueRangeValues.push_back(*read<ValueRange *>());
2009       match.values.push_back(&match.valueRangeValues.back());
2010       break;
2011     default:
2012       match.values.push_back(read<const void *>());
2013       break;
2014     }
2015   }
2016   curCodeIt = dest;
2017 }
2018 
2019 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
2020   LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
2021   Operation *op = read<Operation *>();
2022   SmallVector<Value, 16> args;
2023   readList(args);
2024 
2025   LLVM_DEBUG({
2026     llvm::dbgs() << "  * Operation: " << *op << "\n"
2027                  << "  * Values: ";
2028     llvm::interleaveComma(args, llvm::dbgs());
2029     llvm::dbgs() << "\n";
2030   });
2031   rewriter.replaceOp(op, args);
2032 }
2033 
2034 void ByteCodeExecutor::executeSwitchAttribute() {
2035   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
2036   Attribute value = read<Attribute>();
2037   ArrayAttr cases = read<ArrayAttr>();
2038   handleSwitch(value, cases);
2039 }
2040 
2041 void ByteCodeExecutor::executeSwitchOperandCount() {
2042   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
2043   Operation *op = read<Operation *>();
2044   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2045 
2046   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
2047   handleSwitch(op->getNumOperands(), cases);
2048 }
2049 
2050 void ByteCodeExecutor::executeSwitchOperationName() {
2051   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
2052   OperationName value = read<Operation *>()->getName();
2053   size_t caseCount = read();
2054 
2055   // The operation names are stored in-line, so to print them out for
2056   // debugging purposes we need to read the array before executing the
2057   // switch so that we can display all of the possible values.
2058   LLVM_DEBUG({
2059     const ByteCodeField *prevCodeIt = curCodeIt;
2060     llvm::dbgs() << "  * Value: " << value << "\n"
2061                  << "  * Cases: ";
2062     llvm::interleaveComma(
2063         llvm::map_range(llvm::seq<size_t>(0, caseCount),
2064                         [&](size_t) { return read<OperationName>(); }),
2065         llvm::dbgs());
2066     llvm::dbgs() << "\n";
2067     curCodeIt = prevCodeIt;
2068   });
2069 
2070   // Try to find the switch value within any of the cases.
2071   for (size_t i = 0; i != caseCount; ++i) {
2072     if (read<OperationName>() == value) {
2073       curCodeIt += (caseCount - i - 1);
2074       return selectJump(i + 1);
2075     }
2076   }
2077   selectJump(size_t(0));
2078 }
2079 
2080 void ByteCodeExecutor::executeSwitchResultCount() {
2081   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
2082   Operation *op = read<Operation *>();
2083   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2084 
2085   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
2086   handleSwitch(op->getNumResults(), cases);
2087 }
2088 
2089 void ByteCodeExecutor::executeSwitchType() {
2090   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
2091   Type value = read<Type>();
2092   auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2093   handleSwitch(value, cases);
2094 }
2095 
2096 void ByteCodeExecutor::executeSwitchTypes() {
2097   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
2098   TypeRange *value = read<TypeRange *>();
2099   auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
2100   if (!value) {
2101     LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n");
2102     return selectJump(size_t(0));
2103   }
2104   handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
2105     return value == caseValue.getAsValueRange<TypeAttr>();
2106   });
2107 }
2108 
2109 LogicalResult
2110 ByteCodeExecutor::execute(PatternRewriter &rewriter,
2111                           SmallVectorImpl<PDLByteCode::MatchResult> *matches,
2112                           Optional<Location> mainRewriteLoc) {
2113   while (true) {
2114     // Print the location of the operation being executed.
2115     LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n");
2116 
2117     OpCode opCode = static_cast<OpCode>(read());
2118     switch (opCode) {
2119     case ApplyConstraint:
2120       executeApplyConstraint(rewriter);
2121       break;
2122     case ApplyRewrite:
2123       if (failed(executeApplyRewrite(rewriter)))
2124         return failure();
2125       break;
2126     case AreEqual:
2127       executeAreEqual();
2128       break;
2129     case AreRangesEqual:
2130       executeAreRangesEqual();
2131       break;
2132     case Branch:
2133       executeBranch();
2134       break;
2135     case CheckOperandCount:
2136       executeCheckOperandCount();
2137       break;
2138     case CheckOperationName:
2139       executeCheckOperationName();
2140       break;
2141     case CheckResultCount:
2142       executeCheckResultCount();
2143       break;
2144     case CheckTypes:
2145       executeCheckTypes();
2146       break;
2147     case Continue:
2148       executeContinue();
2149       break;
2150     case CreateConstantTypeRange:
2151       executeCreateConstantTypeRange();
2152       break;
2153     case CreateOperation:
2154       executeCreateOperation(rewriter, *mainRewriteLoc);
2155       break;
2156     case CreateDynamicTypeRange:
2157       executeDynamicCreateRange<Type>("Type");
2158       break;
2159     case CreateDynamicValueRange:
2160       executeDynamicCreateRange<Value>("Value");
2161       break;
2162     case EraseOp:
2163       executeEraseOp(rewriter);
2164       break;
2165     case ExtractOp:
2166       executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
2167       break;
2168     case ExtractType:
2169       executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
2170       break;
2171     case ExtractValue:
2172       executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
2173       break;
2174     case Finalize:
2175       executeFinalize();
2176       LLVM_DEBUG(llvm::dbgs() << "\n");
2177       return success();
2178     case ForEach:
2179       executeForEach();
2180       break;
2181     case GetAttribute:
2182       executeGetAttribute();
2183       break;
2184     case GetAttributeType:
2185       executeGetAttributeType();
2186       break;
2187     case GetDefiningOp:
2188       executeGetDefiningOp();
2189       break;
2190     case GetOperand0:
2191     case GetOperand1:
2192     case GetOperand2:
2193     case GetOperand3: {
2194       unsigned index = opCode - GetOperand0;
2195       LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
2196       executeGetOperand(index);
2197       break;
2198     }
2199     case GetOperandN:
2200       LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
2201       executeGetOperand(read<uint32_t>());
2202       break;
2203     case GetOperands:
2204       executeGetOperands();
2205       break;
2206     case GetResult0:
2207     case GetResult1:
2208     case GetResult2:
2209     case GetResult3: {
2210       unsigned index = opCode - GetResult0;
2211       LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
2212       executeGetResult(index);
2213       break;
2214     }
2215     case GetResultN:
2216       LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
2217       executeGetResult(read<uint32_t>());
2218       break;
2219     case GetResults:
2220       executeGetResults();
2221       break;
2222     case GetUsers:
2223       executeGetUsers();
2224       break;
2225     case GetValueType:
2226       executeGetValueType();
2227       break;
2228     case GetValueRangeTypes:
2229       executeGetValueRangeTypes();
2230       break;
2231     case IsNotNull:
2232       executeIsNotNull();
2233       break;
2234     case RecordMatch:
2235       assert(matches &&
2236              "expected matches to be provided when executing the matcher");
2237       executeRecordMatch(rewriter, *matches);
2238       break;
2239     case ReplaceOp:
2240       executeReplaceOp(rewriter);
2241       break;
2242     case SwitchAttribute:
2243       executeSwitchAttribute();
2244       break;
2245     case SwitchOperandCount:
2246       executeSwitchOperandCount();
2247       break;
2248     case SwitchOperationName:
2249       executeSwitchOperationName();
2250       break;
2251     case SwitchResultCount:
2252       executeSwitchResultCount();
2253       break;
2254     case SwitchType:
2255       executeSwitchType();
2256       break;
2257     case SwitchTypes:
2258       executeSwitchTypes();
2259       break;
2260     }
2261     LLVM_DEBUG(llvm::dbgs() << "\n");
2262   }
2263 }
2264 
2265 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
2266                         SmallVectorImpl<MatchResult> &matches,
2267                         PDLByteCodeMutableState &state) const {
2268   // The first memory slot is always the root operation.
2269   state.memory[0] = op;
2270 
2271   // The matcher function always starts at code address 0.
2272   ByteCodeExecutor executor(
2273       matcherByteCode.data(), state.memory, state.opRangeMemory,
2274       state.typeRangeMemory, state.allocatedTypeRangeMemory,
2275       state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2276       uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
2277       constraintFunctions, rewriteFunctions);
2278   LogicalResult executeResult = executor.execute(rewriter, &matches);
2279   assert(succeeded(executeResult) && "unexpected matcher execution failure");
2280 
2281   // Order the found matches by benefit.
2282   std::stable_sort(matches.begin(), matches.end(),
2283                    [](const MatchResult &lhs, const MatchResult &rhs) {
2284                      return lhs.benefit > rhs.benefit;
2285                    });
2286 }
2287 
2288 LogicalResult PDLByteCode::rewrite(PatternRewriter &rewriter,
2289                                    const MatchResult &match,
2290                                    PDLByteCodeMutableState &state) const {
2291   auto *configSet = match.pattern->getConfigSet();
2292   if (configSet)
2293     configSet->notifyRewriteBegin(rewriter);
2294 
2295   // The arguments of the rewrite function are stored at the start of the
2296   // memory buffer.
2297   llvm::copy(match.values, state.memory.begin());
2298 
2299   ByteCodeExecutor executor(
2300       &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
2301       state.opRangeMemory, state.typeRangeMemory,
2302       state.allocatedTypeRangeMemory, state.valueRangeMemory,
2303       state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2304       rewriterByteCode, state.currentPatternBenefits, patterns,
2305       constraintFunctions, rewriteFunctions);
2306   LogicalResult result =
2307       executor.execute(rewriter, /*matches=*/nullptr, match.location);
2308 
2309   if (configSet)
2310     configSet->notifyRewriteEnd(rewriter);
2311 
2312   // If the rewrite failed, check if the pattern rewriter can recover. If it
2313   // can, we can signal to the pattern applicator to keep trying patterns. If it
2314   // doesn't, we need to bail. Bailing here should be fine, given that we have
2315   // no means to propagate such a failure to the user, and it also indicates a
2316   // bug in the user code (i.e. failable rewrites should not be used with
2317   // pattern rewriters that don't support it).
2318   if (failed(result) && !rewriter.canRecoverFromRewriteFailure()) {
2319     LLVM_DEBUG(llvm::dbgs() << " and rollback is not supported - aborting");
2320     llvm::report_fatal_error(
2321         "Native PDL Rewrite failed, but the pattern "
2322         "rewriter doesn't support recovery. Failable pattern rewrites should "
2323         "not be used with pattern rewriters that do not support them.");
2324   }
2325   return result;
2326 }
2327