xref: /llvm-project/mlir/lib/Rewrite/ByteCode.cpp (revision c80e6edba4a9593f0587e27fa0ac825ebe174afd)
1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MLIR to byte-code generation and the interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "ByteCode.h"
14 #include "mlir/Analysis/Liveness.h"
15 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
16 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/RegionGraphTraits.h"
19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/Format.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include <numeric>
26 #include <optional>
27 
28 #define DEBUG_TYPE "pdl-bytecode"
29 
30 using namespace mlir;
31 using namespace mlir::detail;
32 
33 //===----------------------------------------------------------------------===//
34 // PDLByteCodePattern
35 //===----------------------------------------------------------------------===//
36 
37 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
38                                               PDLPatternConfigSet *configSet,
39                                               ByteCodeAddr rewriterAddr) {
40   PatternBenefit benefit = matchOp.getBenefit();
41   MLIRContext *ctx = matchOp.getContext();
42 
43   // Collect the set of generated operations.
44   SmallVector<StringRef, 8> generatedOps;
45   if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
46     generatedOps =
47         llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
48 
49   // Check to see if this is pattern matches a specific operation type.
50   if (std::optional<StringRef> rootKind = matchOp.getRootKind())
51     return PDLByteCodePattern(rewriterAddr, configSet, *rootKind, benefit, ctx,
52                               generatedOps);
53   return PDLByteCodePattern(rewriterAddr, configSet, MatchAnyOpTypeTag(),
54                             benefit, ctx, generatedOps);
55 }
56 
57 //===----------------------------------------------------------------------===//
58 // PDLByteCodeMutableState
59 //===----------------------------------------------------------------------===//
60 
61 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
62 /// to the position of the pattern within the range returned by
63 /// `PDLByteCode::getPatterns`.
64 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
65                                                    PatternBenefit benefit) {
66   currentPatternBenefits[patternIndex] = benefit;
67 }
68 
69 /// Cleanup any allocated state after a full match/rewrite has been completed.
70 /// This method should be called irregardless of whether the match+rewrite was a
71 /// success or not.
72 void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() {
73   allocatedTypeRangeMemory.clear();
74   allocatedValueRangeMemory.clear();
75 }
76 
77 //===----------------------------------------------------------------------===//
78 // Bytecode OpCodes
79 //===----------------------------------------------------------------------===//
80 
namespace {
/// The opcodes emitted into the bytecode stream. The enumerators have no
/// explicit values, so each one is numbered by its position in this list and
/// is stored using `ByteCodeField` as the underlying type. Reordering the
/// enumerators therefore changes the encoding.
enum OpCode : ByteCodeField {
  /// Apply an externally registered constraint.
  ApplyConstraint,
  /// Apply an externally registered rewrite.
  ApplyRewrite,
  /// Check if two generic values are equal.
  AreEqual,
  /// Check if two ranges are equal.
  AreRangesEqual,
  /// Unconditional branch.
  Branch,
  /// Compare the operand count of an operation with a constant.
  CheckOperandCount,
  /// Compare the name of an operation with a constant.
  CheckOperationName,
  /// Compare the result count of an operation with a constant.
  CheckResultCount,
  /// Compare a range of types to a constant range of types.
  CheckTypes,
  /// Continue to the next iteration of a loop.
  Continue,
  /// Create a type range from a list of constant types.
  CreateConstantTypeRange,
  /// Create an operation.
  CreateOperation,
  /// Create a type range from a list of dynamic types.
  CreateDynamicTypeRange,
  /// Create a value range.
  CreateDynamicValueRange,
  /// Erase an operation.
  EraseOp,
  /// Extract the op from a range at the specified index.
  ExtractOp,
  /// Extract the type from a range at the specified index.
  ExtractType,
  /// Extract the value from a range at the specified index.
  ExtractValue,
  /// Terminate a matcher or rewrite sequence.
  Finalize,
  /// Iterate over a range of values.
  ForEach,
  /// Get a specific attribute of an operation.
  GetAttribute,
  /// Get the type of an attribute.
  GetAttributeType,
  /// Get the defining operation of a value.
  GetDefiningOp,
  /// Get a specific operand of an operation.
  /// NOTE(review): the 0-3 variants presumably encode a fixed operand index in
  /// the opcode itself while `N` handles any other index; confirm against the
  /// generator and interpreter code for these opcodes.
  GetOperand0,
  GetOperand1,
  GetOperand2,
  GetOperand3,
  GetOperandN,
  /// Get a specific operand group of an operation.
  GetOperands,
  /// Get a specific result of an operation.
  /// NOTE(review): presumably mirrors the GetOperand0-3/N scheme above;
  /// confirm against the generator and interpreter code for these opcodes.
  GetResult0,
  GetResult1,
  GetResult2,
  GetResult3,
  GetResultN,
  /// Get a specific result group of an operation.
  GetResults,
  /// Get the users of a value or a range of values.
  GetUsers,
  /// Get the type of a value.
  GetValueType,
  /// Get the types of a value range.
  GetValueRangeTypes,
  /// Check if a generic value is not null.
  IsNotNull,
  /// Record a successful pattern match.
  RecordMatch,
  /// Replace an operation.
  ReplaceOp,
  /// Compare an attribute with a set of constants.
  SwitchAttribute,
  /// Compare the operand count of an operation with a set of constants.
  SwitchOperandCount,
  /// Compare the name of an operation with a set of constants.
  SwitchOperationName,
  /// Compare the result count of an operation with a set of constants.
  SwitchResultCount,
  /// Compare a type with a set of constants.
  SwitchType,
  /// Compare a range of types with a set of constants.
  SwitchTypes,
};
} // namespace
171 
/// A marker used to indicate if an operation should infer types. The marker is
/// the largest value representable by a `ByteCodeField`.
static constexpr ByteCodeField kInferTypesMarker =
    std::numeric_limits<ByteCodeField>::max();
175 
176 //===----------------------------------------------------------------------===//
177 // ByteCode Generation
178 //===----------------------------------------------------------------------===//
179 
180 //===----------------------------------------------------------------------===//
181 // Generator
182 
183 namespace {
184 struct ByteCodeLiveRange;
185 struct ByteCodeWriter;
186 
/// Check if the given class `T` can be converted to an opaque pointer. The
/// alias is only well-formed when `T` provides a `getAsOpaquePointer()` member,
/// which makes it usable as a detector with `llvm::is_detected`. The trailing
/// `Args` pack is not referenced by the alias itself.
template <typename T, typename... Args>
using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
190 
191 /// This class represents the main generator for the pattern bytecode.
192 class Generator {
193 public:
194   Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
195             SmallVectorImpl<ByteCodeField> &matcherByteCode,
196             SmallVectorImpl<ByteCodeField> &rewriterByteCode,
197             SmallVectorImpl<PDLByteCodePattern> &patterns,
198             ByteCodeField &maxValueMemoryIndex,
199             ByteCodeField &maxOpRangeMemoryIndex,
200             ByteCodeField &maxTypeRangeMemoryIndex,
201             ByteCodeField &maxValueRangeMemoryIndex,
202             ByteCodeField &maxLoopLevel,
203             llvm::StringMap<PDLConstraintFunction> &constraintFns,
204             llvm::StringMap<PDLRewriteFunction> &rewriteFns,
205             const DenseMap<Operation *, PDLPatternConfigSet *> &configMap)
206       : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
207         rewriterByteCode(rewriterByteCode), patterns(patterns),
208         maxValueMemoryIndex(maxValueMemoryIndex),
209         maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
210         maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
211         maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
212         maxLoopLevel(maxLoopLevel), configMap(configMap) {
213     for (const auto &it : llvm::enumerate(constraintFns))
214       constraintToMemIndex.try_emplace(it.value().first(), it.index());
215     for (const auto &it : llvm::enumerate(rewriteFns))
216       externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
217   }
218 
219   /// Generate the bytecode for the given PDL interpreter module.
220   void generate(ModuleOp module);
221 
222   /// Return the memory index to use for the given value.
223   ByteCodeField &getMemIndex(Value value) {
224     assert(valueToMemIndex.count(value) &&
225            "expected memory index to be assigned");
226     return valueToMemIndex[value];
227   }
228 
229   /// Return the range memory index used to store the given range value.
230   ByteCodeField &getRangeStorageIndex(Value value) {
231     assert(valueToRangeIndex.count(value) &&
232            "expected range index to be assigned");
233     return valueToRangeIndex[value];
234   }
235 
236   /// Return an index to use when referring to the given data that is uniqued in
237   /// the MLIR context.
238   template <typename T>
239   std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
240   getMemIndex(T val) {
241     const void *opaqueVal = val.getAsOpaquePointer();
242 
243     // Get or insert a reference to this value.
244     auto it = uniquedDataToMemIndex.try_emplace(
245         opaqueVal, maxValueMemoryIndex + uniquedData.size());
246     if (it.second)
247       uniquedData.push_back(opaqueVal);
248     return it.first->second;
249   }
250 
251 private:
252   /// Allocate memory indices for the results of operations within the matcher
253   /// and rewriters.
254   void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
255                              ModuleOp rewriterModule);
256 
257   /// Generate the bytecode for the given operation.
258   void generate(Region *region, ByteCodeWriter &writer);
259   void generate(Operation *op, ByteCodeWriter &writer);
260   void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
261   void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
262   void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
263   void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
264   void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
265   void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
266   void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
267   void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
268   void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
269   void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
270   void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
271   void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
272   void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
273   void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer);
274   void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
275   void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
276   void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
277   void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
278   void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
279   void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
280   void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
281   void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
282   void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
283   void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
284   void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
285   void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
286   void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
287   void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
288   void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
289   void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
290   void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
291   void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
292   void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
293   void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
294   void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
295   void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
296   void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
297   void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
298 
299   /// Mapping from value to its corresponding memory index.
300   DenseMap<Value, ByteCodeField> valueToMemIndex;
301 
302   /// Mapping from a range value to its corresponding range storage index.
303   DenseMap<Value, ByteCodeField> valueToRangeIndex;
304 
305   /// Mapping from the name of an externally registered rewrite to its index in
306   /// the bytecode registry.
307   llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
308 
309   /// Mapping from the name of an externally registered constraint to its index
310   /// in the bytecode registry.
311   llvm::StringMap<ByteCodeField> constraintToMemIndex;
312 
313   /// Mapping from rewriter function name to the bytecode address of the
314   /// rewriter function in byte.
315   llvm::StringMap<ByteCodeAddr> rewriterToAddr;
316 
317   /// Mapping from a uniqued storage object to its memory index within
318   /// `uniquedData`.
319   DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
320 
321   /// The current level of the foreach loop.
322   ByteCodeField curLoopLevel = 0;
323 
324   /// The current MLIR context.
325   MLIRContext *ctx;
326 
327   /// Mapping from block to its address.
328   DenseMap<Block *, ByteCodeAddr> blockToAddr;
329 
330   /// Data of the ByteCode class to be populated.
331   std::vector<const void *> &uniquedData;
332   SmallVectorImpl<ByteCodeField> &matcherByteCode;
333   SmallVectorImpl<ByteCodeField> &rewriterByteCode;
334   SmallVectorImpl<PDLByteCodePattern> &patterns;
335   ByteCodeField &maxValueMemoryIndex;
336   ByteCodeField &maxOpRangeMemoryIndex;
337   ByteCodeField &maxTypeRangeMemoryIndex;
338   ByteCodeField &maxValueRangeMemoryIndex;
339   ByteCodeField &maxLoopLevel;
340 
341   /// A map of pattern configurations.
342   const DenseMap<Operation *, PDLPatternConfigSet *> &configMap;
343 };
344 
345 /// This class provides utilities for writing a bytecode stream.
346 struct ByteCodeWriter {
347   ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
348       : bytecode(bytecode), generator(generator) {}
349 
350   /// Append a field to the bytecode.
351   void append(ByteCodeField field) { bytecode.push_back(field); }
352   void append(OpCode opCode) { bytecode.push_back(opCode); }
353 
354   /// Append an address to the bytecode.
355   void append(ByteCodeAddr field) {
356     static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
357                   "unexpected ByteCode address size");
358 
359     ByteCodeField fieldParts[2];
360     std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
361     bytecode.append({fieldParts[0], fieldParts[1]});
362   }
363 
364   /// Append a single successor to the bytecode, the exact address will need to
365   /// be resolved later.
366   void append(Block *successor) {
367     // Add back a reference to the successor so that the address can be resolved
368     // later.
369     unresolvedSuccessorRefs[successor].push_back(bytecode.size());
370     append(ByteCodeAddr(0));
371   }
372 
373   /// Append a successor range to the bytecode, the exact address will need to
374   /// be resolved later.
375   void append(SuccessorRange successors) {
376     for (Block *successor : successors)
377       append(successor);
378   }
379 
380   /// Append a range of values that will be read as generic PDLValues.
381   void appendPDLValueList(OperandRange values) {
382     bytecode.push_back(values.size());
383     for (Value value : values)
384       appendPDLValue(value);
385   }
386 
387   /// Append a value as a PDLValue.
388   void appendPDLValue(Value value) {
389     appendPDLValueKind(value);
390     append(value);
391   }
392 
393   /// Append the PDLValue::Kind of the given value.
394   void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); }
395 
396   /// Append the PDLValue::Kind of the given type.
397   void appendPDLValueKind(Type type) {
398     PDLValue::Kind kind =
399         TypeSwitch<Type, PDLValue::Kind>(type)
400             .Case<pdl::AttributeType>(
401                 [](Type) { return PDLValue::Kind::Attribute; })
402             .Case<pdl::OperationType>(
403                 [](Type) { return PDLValue::Kind::Operation; })
404             .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
405               if (isa<pdl::TypeType>(rangeTy.getElementType()))
406                 return PDLValue::Kind::TypeRange;
407               return PDLValue::Kind::ValueRange;
408             })
409             .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
410             .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
411     bytecode.push_back(static_cast<ByteCodeField>(kind));
412   }
413 
414   /// Append a value that will be stored in a memory slot and not inline within
415   /// the bytecode.
416   template <typename T>
417   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
418                    std::is_pointer<T>::value>
419   append(T value) {
420     bytecode.push_back(generator.getMemIndex(value));
421   }
422 
423   /// Append a range of values.
424   template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
425   std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
426   append(T range) {
427     bytecode.push_back(llvm::size(range));
428     for (auto it : range)
429       append(it);
430   }
431 
432   /// Append a variadic number of fields to the bytecode.
433   template <typename FieldTy, typename Field2Ty, typename... FieldTys>
434   void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
435     append(field);
436     append(field2, fields...);
437   }
438 
439   /// Appends a value as a pointer, stored inline within the bytecode.
440   template <typename T>
441   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
442   appendInline(T value) {
443     constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField);
444     const void *pointer = value.getAsOpaquePointer();
445     ByteCodeField fieldParts[numParts];
446     std::memcpy(fieldParts, &pointer, sizeof(const void *));
447     bytecode.append(fieldParts, fieldParts + numParts);
448   }
449 
450   /// Successor references in the bytecode that have yet to be resolved.
451   DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
452 
453   /// The underlying bytecode buffer.
454   SmallVectorImpl<ByteCodeField> &bytecode;
455 
456   /// The main generator producing PDL.
457   Generator &generator;
458 };
459 
460 /// This class represents a live range of PDL Interpreter values, containing
461 /// information about when values are live within a match/rewrite.
462 struct ByteCodeLiveRange {
463   using Set = llvm::IntervalMap<uint64_t, char, 16>;
464   using Allocator = Set::Allocator;
465 
466   ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
467 
468   /// Union this live range with the one provided.
469   void unionWith(const ByteCodeLiveRange &rhs) {
470     for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
471          ++it)
472       liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0);
473   }
474 
475   /// Returns true if this range overlaps with the one provided.
476   bool overlaps(const ByteCodeLiveRange &rhs) const {
477     return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
478         .valid();
479   }
480 
481   /// A map representing the ranges of the match/rewrite that a value is live in
482   /// the interpreter.
483   ///
484   /// We use std::unique_ptr here, because IntervalMap does not provide a
485   /// correct copy or move constructor. We can eliminate the pointer once
486   /// https://reviews.llvm.org/D113240 lands.
487   std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
488 
489   /// The operation range storage index for this range.
490   std::optional<unsigned> opRangeIndex;
491 
492   /// The type range storage index for this range.
493   std::optional<unsigned> typeRangeIndex;
494 
495   /// The value range storage index for this range.
496   std::optional<unsigned> valueRangeIndex;
497 };
498 } // namespace
499 
500 void Generator::generate(ModuleOp module) {
501   auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
502       pdl_interp::PDLInterpDialect::getMatcherFunctionName());
503   ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
504       pdl_interp::PDLInterpDialect::getRewriterModuleName());
505   assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
506 
507   // Allocate memory indices for the results of operations within the matcher
508   // and rewriters.
509   allocateMemoryIndices(matcherFunc, rewriterModule);
510 
511   // Generate code for the rewriter functions.
512   ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
513   for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
514     rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
515     for (Operation &op : rewriterFunc.getOps())
516       generate(&op, rewriterByteCodeWriter);
517   }
518   assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
519          "unexpected branches in rewriter function");
520 
521   // Generate code for the matcher function.
522   ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
523   generate(&matcherFunc.getBody(), matcherByteCodeWriter);
524 
525   // Resolve successor references in the matcher.
526   for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
527     ByteCodeAddr addr = blockToAddr[it.first];
528     for (unsigned offsetToFix : it.second)
529       std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
530   }
531 }
532 
533 void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
534                                       ModuleOp rewriterModule) {
535   // Rewriters use simplistic allocation scheme that simply assigns an index to
536   // each result.
537   for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
538     ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
539     auto processRewriterValue = [&](Value val) {
540       valueToMemIndex.try_emplace(val, index++);
541       if (pdl::RangeType rangeType = dyn_cast<pdl::RangeType>(val.getType())) {
542         Type elementTy = rangeType.getElementType();
543         if (isa<pdl::TypeType>(elementTy))
544           valueToRangeIndex.try_emplace(val, typeRangeIndex++);
545         else if (isa<pdl::ValueType>(elementTy))
546           valueToRangeIndex.try_emplace(val, valueRangeIndex++);
547       }
548     };
549 
550     for (BlockArgument arg : rewriterFunc.getArguments())
551       processRewriterValue(arg);
552     rewriterFunc.getBody().walk([&](Operation *op) {
553       for (Value result : op->getResults())
554         processRewriterValue(result);
555     });
556     if (index > maxValueMemoryIndex)
557       maxValueMemoryIndex = index;
558     if (typeRangeIndex > maxTypeRangeMemoryIndex)
559       maxTypeRangeMemoryIndex = typeRangeIndex;
560     if (valueRangeIndex > maxValueRangeMemoryIndex)
561       maxValueRangeMemoryIndex = valueRangeIndex;
562   }
563 
564   // The matcher function uses a more sophisticated numbering that tries to
565   // minimize the number of memory indices assigned. This is done by determining
566   // a live range of the values within the matcher, then the allocation is just
567   // finding the minimal number of overlapping live ranges. This is essentially
568   // a simplified form of register allocation where we don't necessarily have a
569   // limited number of registers, but we still want to minimize the number used.
570   DenseMap<Operation *, unsigned> opToFirstIndex;
571   DenseMap<Operation *, unsigned> opToLastIndex;
572 
573   // A custom walk that marks the first and the last index of each operation.
574   // The entry marks the beginning of the liveness range for this operation,
575   // followed by nested operations, followed by the end of the liveness range.
576   unsigned index = 0;
577   llvm::unique_function<void(Operation *)> walk = [&](Operation *op) {
578     opToFirstIndex.try_emplace(op, index++);
579     for (Region &region : op->getRegions())
580       for (Block &block : region.getBlocks())
581         for (Operation &nested : block)
582           walk(&nested);
583     opToLastIndex.try_emplace(op, index++);
584   };
585   walk(matcherFunc);
586 
587   // Liveness info for each of the defs within the matcher.
588   ByteCodeLiveRange::Allocator allocator;
589   DenseMap<Value, ByteCodeLiveRange> valueDefRanges;
590 
591   // Assign the root operation being matched to slot 0.
592   BlockArgument rootOpArg = matcherFunc.getArgument(0);
593   valueToMemIndex[rootOpArg] = 0;
594 
595   // Walk each of the blocks, computing the def interval that the value is used.
596   Liveness matcherLiveness(matcherFunc);
597   matcherFunc->walk([&](Block *block) {
598     const LivenessBlockInfo *info = matcherLiveness.getLiveness(block);
599     assert(info && "expected liveness info for block");
600     auto processValue = [&](Value value, Operation *firstUseOrDef) {
601       // We don't need to process the root op argument, this value is always
602       // assigned to the first memory slot.
603       if (value == rootOpArg)
604         return;
605 
606       // Set indices for the range of this block that the value is used.
607       auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
608       defRangeIt->second.liveness->insert(
609           opToFirstIndex[firstUseOrDef],
610           opToLastIndex[info->getEndOperation(value, firstUseOrDef)],
611           /*dummyValue*/ 0);
612 
613       // Check to see if this value is a range type.
614       if (auto rangeTy = dyn_cast<pdl::RangeType>(value.getType())) {
615         Type eleType = rangeTy.getElementType();
616         if (isa<pdl::OperationType>(eleType))
617           defRangeIt->second.opRangeIndex = 0;
618         else if (isa<pdl::TypeType>(eleType))
619           defRangeIt->second.typeRangeIndex = 0;
620         else if (isa<pdl::ValueType>(eleType))
621           defRangeIt->second.valueRangeIndex = 0;
622       }
623     };
624 
625     // Process the live-ins of this block.
626     for (Value liveIn : info->in()) {
627       // Only process the value if it has been defined in the current region.
628       // Other values that span across pdl_interp.foreach will be added higher
629       // up. This ensures that the we keep them alive for the entire duration
630       // of the loop.
631       if (liveIn.getParentRegion() == block->getParent())
632         processValue(liveIn, &block->front());
633     }
634 
635     // Process the block arguments for the entry block (those are not live-in).
636     if (block->isEntryBlock()) {
637       for (Value argument : block->getArguments())
638         processValue(argument, &block->front());
639     }
640 
641     // Process any new defs within this block.
642     for (Operation &op : *block)
643       for (Value result : op.getResults())
644         processValue(result, &op);
645   });
646 
647   // Greedily allocate memory slots using the computed def live ranges.
648   std::vector<ByteCodeLiveRange> allocatedIndices;
649 
650   // The number of memory indices currently allocated (and its next value).
651   // Recall that the root gets allocated memory index 0.
652   ByteCodeField numIndices = 1;
653 
654   // The number of memory ranges of various types (and their next values).
655   ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
656 
657   for (auto &defIt : valueDefRanges) {
658     ByteCodeField &memIndex = valueToMemIndex[defIt.first];
659     ByteCodeLiveRange &defRange = defIt.second;
660 
661     // Try to allocate to an existing index.
662     for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) {
663       ByteCodeLiveRange &existingRange = existingIndexIt.value();
664       if (!defRange.overlaps(existingRange)) {
665         existingRange.unionWith(defRange);
666         memIndex = existingIndexIt.index() + 1;
667 
668         if (defRange.opRangeIndex) {
669           if (!existingRange.opRangeIndex)
670             existingRange.opRangeIndex = numOpRanges++;
671           valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
672         } else if (defRange.typeRangeIndex) {
673           if (!existingRange.typeRangeIndex)
674             existingRange.typeRangeIndex = numTypeRanges++;
675           valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
676         } else if (defRange.valueRangeIndex) {
677           if (!existingRange.valueRangeIndex)
678             existingRange.valueRangeIndex = numValueRanges++;
679           valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
680         }
681         break;
682       }
683     }
684 
685     // If no existing index could be used, add a new one.
686     if (memIndex == 0) {
687       allocatedIndices.emplace_back(allocator);
688       ByteCodeLiveRange &newRange = allocatedIndices.back();
689       newRange.unionWith(defRange);
690 
691       // Allocate an index for op/type/value ranges.
692       if (defRange.opRangeIndex) {
693         newRange.opRangeIndex = numOpRanges;
694         valueToRangeIndex[defIt.first] = numOpRanges++;
695       } else if (defRange.typeRangeIndex) {
696         newRange.typeRangeIndex = numTypeRanges;
697         valueToRangeIndex[defIt.first] = numTypeRanges++;
698       } else if (defRange.valueRangeIndex) {
699         newRange.valueRangeIndex = numValueRanges;
700         valueToRangeIndex[defIt.first] = numValueRanges++;
701       }
702 
703       memIndex = allocatedIndices.size();
704       ++numIndices;
705     }
706   }
707 
708   // Print the index usage and ensure that we did not run out of index space.
709   LLVM_DEBUG({
710     llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices "
711                  << "(down from initial " << valueDefRanges.size() << ").\n";
712   });
713   assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
714          "Ran out of memory for allocated indices");
715 
716   // Update the max number of indices.
717   if (numIndices > maxValueMemoryIndex)
718     maxValueMemoryIndex = numIndices;
719   if (numOpRanges > maxOpRangeMemoryIndex)
720     maxOpRangeMemoryIndex = numOpRanges;
721   if (numTypeRanges > maxTypeRangeMemoryIndex)
722     maxTypeRangeMemoryIndex = numTypeRanges;
723   if (numValueRanges > maxValueRangeMemoryIndex)
724     maxValueRangeMemoryIndex = numValueRanges;
725 }
726 
727 void Generator::generate(Region *region, ByteCodeWriter &writer) {
728   llvm::ReversePostOrderTraversal<Region *> rpot(region);
729   for (Block *block : rpot) {
730     // Keep track of where this block begins within the matcher function.
731     blockToAddr.try_emplace(block, matcherByteCode.size());
732     for (Operation &op : *block)
733       generate(&op, writer);
734   }
735 }
736 
737 void Generator::generate(Operation *op, ByteCodeWriter &writer) {
738   LLVM_DEBUG({
739     // The following list must contain all the operations that do not
740     // produce any bytecode.
741     if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
742       writer.appendInline(op->getLoc());
743   });
744   TypeSwitch<Operation *>(op)
745       .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
746             pdl_interp::AreEqualOp, pdl_interp::BranchOp,
747             pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
748             pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
749             pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
750             pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
751             pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp,
752             pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
753             pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
754             pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
755             pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
756             pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
757             pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
758             pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
759             pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
760             pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
761             pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
762             pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
763             pdl_interp::SwitchResultCountOp>(
764           [&](auto interpOp) { this->generate(interpOp, writer); })
765       .Default([](Operation *) {
766         llvm_unreachable("unknown `pdl_interp` operation");
767       });
768 }
769 
770 void Generator::generate(pdl_interp::ApplyConstraintOp op,
771                          ByteCodeWriter &writer) {
772   assert(constraintToMemIndex.count(op.getName()) &&
773          "expected index for constraint function");
774   writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
775   writer.appendPDLValueList(op.getArgs());
776   writer.append(ByteCodeField(op.getIsNegated()));
777   writer.append(op.getSuccessors());
778 }
779 void Generator::generate(pdl_interp::ApplyRewriteOp op,
780                          ByteCodeWriter &writer) {
781   assert(externalRewriterToMemIndex.count(op.getName()) &&
782          "expected index for rewrite function");
783   writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
784   writer.appendPDLValueList(op.getArgs());
785 
786   ResultRange results = op.getResults();
787   writer.append(ByteCodeField(results.size()));
788   for (Value result : results) {
789     // In debug mode we also record the expected kind of the result, so that we
790     // can provide extra verification of the native rewrite function.
791 #ifndef NDEBUG
792     writer.appendPDLValueKind(result);
793 #endif
794 
795     // Range results also need to append the range storage index.
796     if (isa<pdl::RangeType>(result.getType()))
797       writer.append(getRangeStorageIndex(result));
798     writer.append(result);
799   }
800 }
801 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
802   Value lhs = op.getLhs();
803   if (isa<pdl::RangeType>(lhs.getType())) {
804     writer.append(OpCode::AreRangesEqual);
805     writer.appendPDLValueKind(lhs);
806     writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
807     return;
808   }
809 
810   writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
811 }
812 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
813   writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
814 }
815 void Generator::generate(pdl_interp::CheckAttributeOp op,
816                          ByteCodeWriter &writer) {
817   writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
818                 op.getSuccessors());
819 }
820 void Generator::generate(pdl_interp::CheckOperandCountOp op,
821                          ByteCodeWriter &writer) {
822   writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
823                 static_cast<ByteCodeField>(op.getCompareAtLeast()),
824                 op.getSuccessors());
825 }
826 void Generator::generate(pdl_interp::CheckOperationNameOp op,
827                          ByteCodeWriter &writer) {
828   writer.append(OpCode::CheckOperationName, op.getInputOp(),
829                 OperationName(op.getName(), ctx), op.getSuccessors());
830 }
831 void Generator::generate(pdl_interp::CheckResultCountOp op,
832                          ByteCodeWriter &writer) {
833   writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
834                 static_cast<ByteCodeField>(op.getCompareAtLeast()),
835                 op.getSuccessors());
836 }
837 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
838   writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
839                 op.getSuccessors());
840 }
841 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
842   writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
843                 op.getSuccessors());
844 }
845 void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
846   assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
847   writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
848 }
849 void Generator::generate(pdl_interp::CreateAttributeOp op,
850                          ByteCodeWriter &writer) {
851   // Simply repoint the memory index of the result to the constant.
852   getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
853 }
854 void Generator::generate(pdl_interp::CreateOperationOp op,
855                          ByteCodeWriter &writer) {
856   writer.append(OpCode::CreateOperation, op.getResultOp(),
857                 OperationName(op.getName(), ctx));
858   writer.appendPDLValueList(op.getInputOperands());
859 
860   // Add the attributes.
861   OperandRange attributes = op.getInputAttributes();
862   writer.append(static_cast<ByteCodeField>(attributes.size()));
863   for (auto it : llvm::zip(op.getInputAttributeNames(), attributes))
864     writer.append(std::get<0>(it), std::get<1>(it));
865 
866   // Add the result types. If the operation has inferred results, we use a
867   // marker "size" value. Otherwise, we add the list of explicit result types.
868   if (op.getInferredResultTypes())
869     writer.append(kInferTypesMarker);
870   else
871     writer.appendPDLValueList(op.getInputResultTypes());
872 }
873 void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) {
874   // Append the correct opcode for the range type.
875   TypeSwitch<Type>(op.getType().getElementType())
876       .Case(
877           [&](pdl::TypeType) { writer.append(OpCode::CreateDynamicTypeRange); })
878       .Case([&](pdl::ValueType) {
879         writer.append(OpCode::CreateDynamicValueRange);
880       });
881 
882   writer.append(op.getResult(), getRangeStorageIndex(op.getResult()));
883   writer.appendPDLValueList(op->getOperands());
884 }
885 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
886   // Simply repoint the memory index of the result to the constant.
887   getMemIndex(op.getResult()) = getMemIndex(op.getValue());
888 }
889 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
890   writer.append(OpCode::CreateConstantTypeRange, op.getResult(),
891                 getRangeStorageIndex(op.getResult()), op.getValue());
892 }
893 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
894   writer.append(OpCode::EraseOp, op.getInputOp());
895 }
896 void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
897   OpCode opCode =
898       TypeSwitch<Type, OpCode>(op.getResult().getType())
899           .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
900           .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
901           .Case([](pdl::TypeType) { return OpCode::ExtractType; })
902           .Default([](Type) -> OpCode {
903             llvm_unreachable("unsupported element type");
904           });
905   writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
906 }
907 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
908   writer.append(OpCode::Finalize);
909 }
910 void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
911   BlockArgument arg = op.getLoopVariable();
912   writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
913   writer.appendPDLValueKind(arg.getType());
914   writer.append(curLoopLevel, op.getSuccessor());
915   ++curLoopLevel;
916   if (curLoopLevel > maxLoopLevel)
917     maxLoopLevel = curLoopLevel;
918   generate(&op.getRegion(), writer);
919   --curLoopLevel;
920 }
921 void Generator::generate(pdl_interp::GetAttributeOp op,
922                          ByteCodeWriter &writer) {
923   writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
924                 op.getNameAttr());
925 }
926 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
927                          ByteCodeWriter &writer) {
928   writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
929 }
930 void Generator::generate(pdl_interp::GetDefiningOpOp op,
931                          ByteCodeWriter &writer) {
932   writer.append(OpCode::GetDefiningOp, op.getInputOp());
933   writer.appendPDLValue(op.getValue());
934 }
935 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
936   uint32_t index = op.getIndex();
937   if (index < 4)
938     writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
939   else
940     writer.append(OpCode::GetOperandN, index);
941   writer.append(op.getInputOp(), op.getValue());
942 }
943 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
944   Value result = op.getValue();
945   std::optional<uint32_t> index = op.getIndex();
946   writer.append(OpCode::GetOperands,
947                 index.value_or(std::numeric_limits<uint32_t>::max()),
948                 op.getInputOp());
949   if (isa<pdl::RangeType>(result.getType()))
950     writer.append(getRangeStorageIndex(result));
951   else
952     writer.append(std::numeric_limits<ByteCodeField>::max());
953   writer.append(result);
954 }
955 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
956   uint32_t index = op.getIndex();
957   if (index < 4)
958     writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
959   else
960     writer.append(OpCode::GetResultN, index);
961   writer.append(op.getInputOp(), op.getValue());
962 }
963 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
964   Value result = op.getValue();
965   std::optional<uint32_t> index = op.getIndex();
966   writer.append(OpCode::GetResults,
967                 index.value_or(std::numeric_limits<uint32_t>::max()),
968                 op.getInputOp());
969   if (isa<pdl::RangeType>(result.getType()))
970     writer.append(getRangeStorageIndex(result));
971   else
972     writer.append(std::numeric_limits<ByteCodeField>::max());
973   writer.append(result);
974 }
975 void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
976   Value operations = op.getOperations();
977   ByteCodeField rangeIndex = getRangeStorageIndex(operations);
978   writer.append(OpCode::GetUsers, operations, rangeIndex);
979   writer.appendPDLValue(op.getValue());
980 }
981 void Generator::generate(pdl_interp::GetValueTypeOp op,
982                          ByteCodeWriter &writer) {
983   if (isa<pdl::RangeType>(op.getType())) {
984     Value result = op.getResult();
985     writer.append(OpCode::GetValueRangeTypes, result,
986                   getRangeStorageIndex(result), op.getValue());
987   } else {
988     writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
989   }
990 }
991 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
992   writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
993 }
994 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
995   ByteCodeField patternIndex = patterns.size();
996   patterns.emplace_back(PDLByteCodePattern::create(
997       op, configMap.lookup(op),
998       rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
999   writer.append(OpCode::RecordMatch, patternIndex,
1000                 SuccessorRange(op.getOperation()), op.getMatchedOps());
1001   writer.appendPDLValueList(op.getInputs());
1002 }
1003 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
1004   writer.append(OpCode::ReplaceOp, op.getInputOp());
1005   writer.appendPDLValueList(op.getReplValues());
1006 }
1007 void Generator::generate(pdl_interp::SwitchAttributeOp op,
1008                          ByteCodeWriter &writer) {
1009   writer.append(OpCode::SwitchAttribute, op.getAttribute(),
1010                 op.getCaseValuesAttr(), op.getSuccessors());
1011 }
1012 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
1013                          ByteCodeWriter &writer) {
1014   writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
1015                 op.getCaseValuesAttr(), op.getSuccessors());
1016 }
1017 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
1018                          ByteCodeWriter &writer) {
1019   auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) {
1020     return OperationName(cast<StringAttr>(attr).getValue(), ctx);
1021   });
1022   writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
1023                 op.getSuccessors());
1024 }
1025 void Generator::generate(pdl_interp::SwitchResultCountOp op,
1026                          ByteCodeWriter &writer) {
1027   writer.append(OpCode::SwitchResultCount, op.getInputOp(),
1028                 op.getCaseValuesAttr(), op.getSuccessors());
1029 }
1030 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
1031   writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
1032                 op.getSuccessors());
1033 }
1034 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
1035   writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
1036                 op.getSuccessors());
1037 }
1038 
1039 //===----------------------------------------------------------------------===//
1040 // PDLByteCode
1041 //===----------------------------------------------------------------------===//
1042 
1043 PDLByteCode::PDLByteCode(
1044     ModuleOp module, SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
1045     const DenseMap<Operation *, PDLPatternConfigSet *> &configMap,
1046     llvm::StringMap<PDLConstraintFunction> constraintFns,
1047     llvm::StringMap<PDLRewriteFunction> rewriteFns)
1048     : configs(std::move(configs)) {
1049   Generator generator(module.getContext(), uniquedData, matcherByteCode,
1050                       rewriterByteCode, patterns, maxValueMemoryIndex,
1051                       maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
1052                       maxLoopLevel, constraintFns, rewriteFns, configMap);
1053   generator.generate(module);
1054 
1055   // Initialize the external functions.
1056   for (auto &it : constraintFns)
1057     constraintFunctions.push_back(std::move(it.second));
1058   for (auto &it : rewriteFns)
1059     rewriteFunctions.push_back(std::move(it.second));
1060 }
1061 
1062 /// Initialize the given state such that it can be used to execute the current
1063 /// bytecode.
1064 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
1065   state.memory.resize(maxValueMemoryIndex, nullptr);
1066   state.opRangeMemory.resize(maxOpRangeCount);
1067   state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
1068   state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
1069   state.loopIndex.resize(maxLoopLevel, 0);
1070   state.currentPatternBenefits.reserve(patterns.size());
1071   for (const PDLByteCodePattern &pattern : patterns)
1072     state.currentPatternBenefits.push_back(pattern.getBenefit());
1073 }
1074 
//===----------------------------------------------------------------------===//
// ByteCode Execution
//===----------------------------------------------------------------------===//
1077 
1078 namespace {
1079 /// This class provides support for executing a bytecode stream.
1080 class ByteCodeExecutor {
1081 public:
1082   ByteCodeExecutor(
1083       const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
1084       MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory,
1085       MutableArrayRef<TypeRange> typeRangeMemory,
1086       std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
1087       MutableArrayRef<ValueRange> valueRangeMemory,
1088       std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
1089       MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory,
1090       ArrayRef<ByteCodeField> code,
1091       ArrayRef<PatternBenefit> currentPatternBenefits,
1092       ArrayRef<PDLByteCodePattern> patterns,
1093       ArrayRef<PDLConstraintFunction> constraintFunctions,
1094       ArrayRef<PDLRewriteFunction> rewriteFunctions)
1095       : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
1096         typeRangeMemory(typeRangeMemory),
1097         allocatedTypeRangeMemory(allocatedTypeRangeMemory),
1098         valueRangeMemory(valueRangeMemory),
1099         allocatedValueRangeMemory(allocatedValueRangeMemory),
1100         loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
1101         currentPatternBenefits(currentPatternBenefits), patterns(patterns),
1102         constraintFunctions(constraintFunctions),
1103         rewriteFunctions(rewriteFunctions) {}
1104 
1105   /// Start executing the code at the current bytecode index. `matches` is an
1106   /// optional field provided when this function is executed in a matching
1107   /// context.
1108   LogicalResult
1109   execute(PatternRewriter &rewriter,
1110           SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
1111           std::optional<Location> mainRewriteLoc = {});
1112 
1113 private:
1114   /// Internal implementation of executing each of the bytecode commands.
1115   void executeApplyConstraint(PatternRewriter &rewriter);
1116   LogicalResult executeApplyRewrite(PatternRewriter &rewriter);
1117   void executeAreEqual();
1118   void executeAreRangesEqual();
1119   void executeBranch();
1120   void executeCheckOperandCount();
1121   void executeCheckOperationName();
1122   void executeCheckResultCount();
1123   void executeCheckTypes();
1124   void executeContinue();
1125   void executeCreateConstantTypeRange();
1126   void executeCreateOperation(PatternRewriter &rewriter,
1127                               Location mainRewriteLoc);
1128   template <typename T>
1129   void executeDynamicCreateRange(StringRef type);
1130   void executeEraseOp(PatternRewriter &rewriter);
1131   template <typename T, typename Range, PDLValue::Kind kind>
1132   void executeExtract();
1133   void executeFinalize();
1134   void executeForEach();
1135   void executeGetAttribute();
1136   void executeGetAttributeType();
1137   void executeGetDefiningOp();
1138   void executeGetOperand(unsigned index);
1139   void executeGetOperands();
1140   void executeGetResult(unsigned index);
1141   void executeGetResults();
1142   void executeGetUsers();
1143   void executeGetValueType();
1144   void executeGetValueRangeTypes();
1145   void executeIsNotNull();
1146   void executeRecordMatch(PatternRewriter &rewriter,
1147                           SmallVectorImpl<PDLByteCode::MatchResult> &matches);
1148   void executeReplaceOp(PatternRewriter &rewriter);
1149   void executeSwitchAttribute();
1150   void executeSwitchOperandCount();
1151   void executeSwitchOperationName();
1152   void executeSwitchResultCount();
1153   void executeSwitchType();
1154   void executeSwitchTypes();
1155 
1156   /// Pushes a code iterator to the stack.
1157   void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }
1158 
1159   /// Pops a code iterator from the stack, returning true on success.
1160   void popCodeIt() {
1161     assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack");
1162     curCodeIt = resumeCodeIt.back();
1163     resumeCodeIt.pop_back();
1164   }
1165 
1166   /// Return the bytecode iterator at the start of the current op code.
1167   const ByteCodeField *getPrevCodeIt() const {
1168     LLVM_DEBUG({
1169       // Account for the op code and the Location stored inline.
1170       return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField);
1171     });
1172 
1173     // Account for the op code only.
1174     return curCodeIt - 1;
1175   }
1176 
1177   /// Read a value from the bytecode buffer, optionally skipping a certain
1178   /// number of prefix values. These methods always update the buffer to point
1179   /// to the next field after the read data.
1180   template <typename T = ByteCodeField>
1181   T read(size_t skipN = 0) {
1182     curCodeIt += skipN;
1183     return readImpl<T>();
1184   }
1185   ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }
1186 
1187   /// Read a list of values from the bytecode buffer.
1188   template <typename ValueT, typename T>
1189   void readList(SmallVectorImpl<T> &list) {
1190     list.clear();
1191     for (unsigned i = 0, e = read(); i != e; ++i)
1192       list.push_back(read<ValueT>());
1193   }
1194 
1195   /// Read a list of values from the bytecode buffer. The values may be encoded
1196   /// either as a single element or a range of elements.
1197   void readList(SmallVectorImpl<Type> &list) {
1198     for (unsigned i = 0, e = read(); i != e; ++i) {
1199       if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1200         list.push_back(read<Type>());
1201       } else {
1202         TypeRange *values = read<TypeRange *>();
1203         list.append(values->begin(), values->end());
1204       }
1205     }
1206   }
1207   void readList(SmallVectorImpl<Value> &list) {
1208     for (unsigned i = 0, e = read(); i != e; ++i) {
1209       if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1210         list.push_back(read<Value>());
1211       } else {
1212         ValueRange *values = read<ValueRange *>();
1213         list.append(values->begin(), values->end());
1214       }
1215     }
1216   }
1217 
1218   /// Read a value stored inline as a pointer.
1219   template <typename T>
1220   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T>
1221   readInline() {
1222     const void *pointer;
1223     std::memcpy(&pointer, curCodeIt, sizeof(const void *));
1224     curCodeIt += sizeof(const void *) / sizeof(ByteCodeField);
1225     return T::getFromOpaquePointer(pointer);
1226   }
1227 
1228   /// Jump to a specific successor based on a predicate value.
1229   void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
1230   /// Jump to a specific successor based on a destination index.
1231   void selectJump(size_t destIndex) {
1232     curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
1233   }
1234 
1235   /// Handle a switch operation with the provided value and cases.
1236   template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
1237   void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
1238     LLVM_DEBUG({
1239       llvm::dbgs() << "  * Value: " << value << "\n"
1240                    << "  * Cases: ";
1241       llvm::interleaveComma(cases, llvm::dbgs());
1242       llvm::dbgs() << "\n";
1243     });
1244 
1245     // Check to see if the attribute value is within the case list. Jump to
1246     // the correct successor index based on the result.
1247     for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
1248       if (cmp(*it, value))
1249         return selectJump(size_t((it - cases.begin()) + 1));
1250     selectJump(size_t(0));
1251   }
1252 
1253   /// Store a pointer to memory.
1254   void storeToMemory(unsigned index, const void *value) {
1255     memory[index] = value;
1256   }
1257 
1258   /// Store a value to memory as an opaque pointer.
1259   template <typename T>
1260   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
1261   storeToMemory(unsigned index, T value) {
1262     memory[index] = value.getAsOpaquePointer();
1263   }
1264 
1265   /// Internal implementation of reading various data types from the bytecode
1266   /// stream.
1267   template <typename T>
1268   const void *readFromMemory() {
1269     size_t index = *curCodeIt++;
1270 
1271     // If this type is an SSA value, it can only be stored in non-const memory.
1272     if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
1273                         Value>::value ||
1274         index < memory.size())
1275       return memory[index];
1276 
1277     // Otherwise, if this index is not inbounds it is uniqued.
1278     return uniquedMemory[index - memory.size()];
1279   }
1280   template <typename T>
1281   std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
1282     return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
1283   }
1284   template <typename T>
1285   std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
1286                    T>
1287   readImpl() {
1288     return T(T::getFromOpaquePointer(readFromMemory<T>()));
1289   }
1290   template <typename T>
1291   std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
1292     switch (read<PDLValue::Kind>()) {
1293     case PDLValue::Kind::Attribute:
1294       return read<Attribute>();
1295     case PDLValue::Kind::Operation:
1296       return read<Operation *>();
1297     case PDLValue::Kind::Type:
1298       return read<Type>();
1299     case PDLValue::Kind::Value:
1300       return read<Value>();
1301     case PDLValue::Kind::TypeRange:
1302       return read<TypeRange *>();
1303     case PDLValue::Kind::ValueRange:
1304       return read<ValueRange *>();
1305     }
1306     llvm_unreachable("unhandled PDLValue::Kind");
1307   }
1308   template <typename T>
1309   std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
1310     static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
1311                   "unexpected ByteCode address size");
1312     ByteCodeAddr result;
1313     std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
1314     curCodeIt += 2;
1315     return result;
1316   }
1317   template <typename T>
1318   std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
1319     return *curCodeIt++;
1320   }
1321   template <typename T>
1322   std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
1323     return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
1324   }
1325 
1326   /// Assign the given range to the given memory index. This allocates a new
1327   /// range object if necessary.
1328   template <typename RangeT, typename T = llvm::detail::ValueOfRange<RangeT>>
1329   void assignRangeToMemory(RangeT &&range, unsigned memIndex,
1330                            unsigned rangeIndex) {
1331     // Utility functor used to type-erase the assignment.
1332     auto assignRange = [&](auto &allocatedRangeMemory, auto &rangeMemory) {
1333       // If the input range is empty, we don't need to allocate anything.
1334       if (range.empty()) {
1335         rangeMemory[rangeIndex] = {};
1336       } else {
1337         // Allocate a buffer for this type range.
1338         llvm::OwningArrayRef<T> storage(llvm::size(range));
1339         llvm::copy(range, storage.begin());
1340 
1341         // Assign this to the range slot and use the range as the value for the
1342         // memory index.
1343         allocatedRangeMemory.emplace_back(std::move(storage));
1344         rangeMemory[rangeIndex] = allocatedRangeMemory.back();
1345       }
1346       memory[memIndex] = &rangeMemory[rangeIndex];
1347     };
1348 
1349     // Dispatch based on the concrete range type.
1350     if constexpr (std::is_same_v<T, Type>) {
1351       return assignRange(allocatedTypeRangeMemory, typeRangeMemory);
1352     } else if constexpr (std::is_same_v<T, Value>) {
1353       return assignRange(allocatedValueRangeMemory, valueRangeMemory);
1354     } else {
1355       llvm_unreachable("unhandled range type");
1356     }
1357   }
1358 
1359   /// The underlying bytecode buffer.
1360   const ByteCodeField *curCodeIt;
1361 
1362   /// The stack of bytecode positions at which to resume operation.
1363   SmallVector<const ByteCodeField *> resumeCodeIt;
1364 
1365   /// The current execution memory.
1366   MutableArrayRef<const void *> memory;
1367   MutableArrayRef<OwningOpRange> opRangeMemory;
1368   MutableArrayRef<TypeRange> typeRangeMemory;
1369   std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
1370   MutableArrayRef<ValueRange> valueRangeMemory;
1371   std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;
1372 
1373   /// The current loop indices.
1374   MutableArrayRef<unsigned> loopIndex;
1375 
1376   /// References to ByteCode data necessary for execution.
1377   ArrayRef<const void *> uniquedMemory;
1378   ArrayRef<ByteCodeField> code;
1379   ArrayRef<PatternBenefit> currentPatternBenefits;
1380   ArrayRef<PDLByteCodePattern> patterns;
1381   ArrayRef<PDLConstraintFunction> constraintFunctions;
1382   ArrayRef<PDLRewriteFunction> rewriteFunctions;
1383 };
1384 
1385 /// This class is an instantiation of the PDLResultList that provides access to
1386 /// the returned results. This API is not on `PDLResultList` to avoid
1387 /// overexposing access to information specific solely to the ByteCode.
1388 class ByteCodeRewriteResultList : public PDLResultList {
1389 public:
1390   ByteCodeRewriteResultList(unsigned maxNumResults)
1391       : PDLResultList(maxNumResults) {}
1392 
1393   /// Return the list of PDL results.
1394   MutableArrayRef<PDLValue> getResults() { return results; }
1395 
1396   /// Return the type ranges allocated by this list.
1397   MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
1398     return allocatedTypeRanges;
1399   }
1400 
1401   /// Return the value ranges allocated by this list.
1402   MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
1403     return allocatedValueRanges;
1404   }
1405 };
1406 } // namespace
1407 
1408 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1409   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
1410   const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
1411   SmallVector<PDLValue, 16> args;
1412   readList<PDLValue>(args);
1413 
1414   LLVM_DEBUG({
1415     llvm::dbgs() << "  * Arguments: ";
1416     llvm::interleaveComma(args, llvm::dbgs());
1417     llvm::dbgs() << "\n";
1418   });
1419 
1420   ByteCodeField isNegated = read();
1421   LLVM_DEBUG({
1422     llvm::dbgs() << "  * isNegated: " << isNegated << "\n";
1423     llvm::interleaveComma(args, llvm::dbgs());
1424   });
1425   // Invoke the constraint and jump to the proper destination.
1426   selectJump(isNegated != succeeded(constraintFn(rewriter, args)));
1427 }
1428 
1429 LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
1430   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
1431   const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1432   SmallVector<PDLValue, 16> args;
1433   readList<PDLValue>(args);
1434 
1435   LLVM_DEBUG({
1436     llvm::dbgs() << "  * Arguments: ";
1437     llvm::interleaveComma(args, llvm::dbgs());
1438   });
1439 
1440   // Execute the rewrite function.
1441   ByteCodeField numResults = read();
1442   ByteCodeRewriteResultList results(numResults);
1443   LogicalResult rewriteResult = rewriteFn(rewriter, results, args);
1444 
1445   assert(results.getResults().size() == numResults &&
1446          "native PDL rewrite function returned unexpected number of results");
1447 
1448   // Store the results in the bytecode memory.
1449   for (PDLValue &result : results.getResults()) {
1450     LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n");
1451 
1452 // In debug mode we also verify the expected kind of the result.
1453 #ifndef NDEBUG
1454     assert(result.getKind() == read<PDLValue::Kind>() &&
1455            "native PDL rewrite function returned an unexpected type of result");
1456 #endif
1457 
1458     // If the result is a range, we need to copy it over to the bytecodes
1459     // range memory.
1460     if (std::optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
1461       unsigned rangeIndex = read();
1462       typeRangeMemory[rangeIndex] = *typeRange;
1463       memory[read()] = &typeRangeMemory[rangeIndex];
1464     } else if (std::optional<ValueRange> valueRange =
1465                    result.dyn_cast<ValueRange>()) {
1466       unsigned rangeIndex = read();
1467       valueRangeMemory[rangeIndex] = *valueRange;
1468       memory[read()] = &valueRangeMemory[rangeIndex];
1469     } else {
1470       memory[read()] = result.getAsOpaquePointer();
1471     }
1472   }
1473 
1474   // Copy over any underlying storage allocated for result ranges.
1475   for (auto &it : results.getAllocatedTypeRanges())
1476     allocatedTypeRangeMemory.push_back(std::move(it));
1477   for (auto &it : results.getAllocatedValueRanges())
1478     allocatedValueRangeMemory.push_back(std::move(it));
1479 
1480   // Process the result of the rewrite.
1481   if (failed(rewriteResult)) {
1482     LLVM_DEBUG(llvm::dbgs() << "  - Failed");
1483     return failure();
1484   }
1485   return success();
1486 }
1487 
1488 void ByteCodeExecutor::executeAreEqual() {
1489   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1490   const void *lhs = read<const void *>();
1491   const void *rhs = read<const void *>();
1492 
1493   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n");
1494   selectJump(lhs == rhs);
1495 }
1496 
1497 void ByteCodeExecutor::executeAreRangesEqual() {
1498   LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
1499   PDLValue::Kind valueKind = read<PDLValue::Kind>();
1500   const void *lhs = read<const void *>();
1501   const void *rhs = read<const void *>();
1502 
1503   switch (valueKind) {
1504   case PDLValue::Kind::TypeRange: {
1505     const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
1506     const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
1507     LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1508     selectJump(*lhsRange == *rhsRange);
1509     break;
1510   }
1511   case PDLValue::Kind::ValueRange: {
1512     const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
1513     const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
1514     LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1515     selectJump(*lhsRange == *rhsRange);
1516     break;
1517   }
1518   default:
1519     llvm_unreachable("unexpected `AreRangesEqual` value kind");
1520   }
1521 }
1522 
1523 void ByteCodeExecutor::executeBranch() {
1524   LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
1525   curCodeIt = &code[read<ByteCodeAddr>()];
1526 }
1527 
1528 void ByteCodeExecutor::executeCheckOperandCount() {
1529   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
1530   Operation *op = read<Operation *>();
1531   uint32_t expectedCount = read<uint32_t>();
1532   bool compareAtLeast = read();
1533 
1534   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumOperands() << "\n"
1535                           << "  * Expected: " << expectedCount << "\n"
1536                           << "  * Comparator: "
1537                           << (compareAtLeast ? ">=" : "==") << "\n");
1538   if (compareAtLeast)
1539     selectJump(op->getNumOperands() >= expectedCount);
1540   else
1541     selectJump(op->getNumOperands() == expectedCount);
1542 }
1543 
1544 void ByteCodeExecutor::executeCheckOperationName() {
1545   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
1546   Operation *op = read<Operation *>();
1547   OperationName expectedName = read<OperationName>();
1548 
1549   LLVM_DEBUG(llvm::dbgs() << "  * Found: \"" << op->getName() << "\"\n"
1550                           << "  * Expected: \"" << expectedName << "\"\n");
1551   selectJump(op->getName() == expectedName);
1552 }
1553 
1554 void ByteCodeExecutor::executeCheckResultCount() {
1555   LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
1556   Operation *op = read<Operation *>();
1557   uint32_t expectedCount = read<uint32_t>();
1558   bool compareAtLeast = read();
1559 
1560   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumResults() << "\n"
1561                           << "  * Expected: " << expectedCount << "\n"
1562                           << "  * Comparator: "
1563                           << (compareAtLeast ? ">=" : "==") << "\n");
1564   if (compareAtLeast)
1565     selectJump(op->getNumResults() >= expectedCount);
1566   else
1567     selectJump(op->getNumResults() == expectedCount);
1568 }
1569 
1570 void ByteCodeExecutor::executeCheckTypes() {
1571   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1572   TypeRange *lhs = read<TypeRange *>();
1573   Attribute rhs = read<Attribute>();
1574   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1575 
1576   selectJump(*lhs == cast<ArrayAttr>(rhs).getAsValueRange<TypeAttr>());
1577 }
1578 
1579 void ByteCodeExecutor::executeContinue() {
1580   ByteCodeField level = read();
1581   LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n"
1582                           << "  * Level: " << level << "\n");
1583   ++loopIndex[level];
1584   popCodeIt();
1585 }
1586 
1587 void ByteCodeExecutor::executeCreateConstantTypeRange() {
1588   LLVM_DEBUG(llvm::dbgs() << "Executing CreateConstantTypeRange:\n");
1589   unsigned memIndex = read();
1590   unsigned rangeIndex = read();
1591   ArrayAttr typesAttr = cast<ArrayAttr>(read<Attribute>());
1592 
1593   LLVM_DEBUG(llvm::dbgs() << "  * Types: " << typesAttr << "\n\n");
1594   assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex,
1595                       rangeIndex);
1596 }
1597 
1598 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
1599                                               Location mainRewriteLoc) {
1600   LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
1601 
1602   unsigned memIndex = read();
1603   OperationState state(mainRewriteLoc, read<OperationName>());
1604   readList(state.operands);
1605   for (unsigned i = 0, e = read(); i != e; ++i) {
1606     StringAttr name = read<StringAttr>();
1607     if (Attribute attr = read<Attribute>())
1608       state.addAttribute(name, attr);
1609   }
1610 
1611   // Read in the result types. If the "size" is the sentinel value, this
1612   // indicates that the result types should be inferred.
1613   unsigned numResults = read();
1614   if (numResults == kInferTypesMarker) {
1615     InferTypeOpInterface::Concept *inferInterface =
1616         state.name.getInterface<InferTypeOpInterface>();
1617     assert(inferInterface &&
1618            "expected operation to provide InferTypeOpInterface");
1619 
1620     // TODO: Handle failure.
1621     if (failed(inferInterface->inferReturnTypes(
1622             state.getContext(), state.location, state.operands,
1623             state.attributes.getDictionary(state.getContext()),
1624             state.getRawProperties(), state.regions, state.types)))
1625       return;
1626   } else {
1627     // Otherwise, this is a fixed number of results.
1628     for (unsigned i = 0; i != numResults; ++i) {
1629       if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1630         state.types.push_back(read<Type>());
1631       } else {
1632         TypeRange *resultTypes = read<TypeRange *>();
1633         state.types.append(resultTypes->begin(), resultTypes->end());
1634       }
1635     }
1636   }
1637 
1638   Operation *resultOp = rewriter.create(state);
1639   memory[memIndex] = resultOp;
1640 
1641   LLVM_DEBUG({
1642     llvm::dbgs() << "  * Attributes: "
1643                  << state.attributes.getDictionary(state.getContext())
1644                  << "\n  * Operands: ";
1645     llvm::interleaveComma(state.operands, llvm::dbgs());
1646     llvm::dbgs() << "\n  * Result Types: ";
1647     llvm::interleaveComma(state.types, llvm::dbgs());
1648     llvm::dbgs() << "\n  * Result: " << *resultOp << "\n";
1649   });
1650 }
1651 
1652 template <typename T>
1653 void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) {
1654   LLVM_DEBUG(llvm::dbgs() << "Executing CreateDynamic" << type << "Range:\n");
1655   unsigned memIndex = read();
1656   unsigned rangeIndex = read();
1657   SmallVector<T> values;
1658   readList(values);
1659 
1660   LLVM_DEBUG({
1661     llvm::dbgs() << "\n  * " << type << "s: ";
1662     llvm::interleaveComma(values, llvm::dbgs());
1663     llvm::dbgs() << "\n";
1664   });
1665 
1666   assignRangeToMemory(values, memIndex, rangeIndex);
1667 }
1668 
1669 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1670   LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1671   Operation *op = read<Operation *>();
1672 
1673   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1674   rewriter.eraseOp(op);
1675 }
1676 
1677 template <typename T, typename Range, PDLValue::Kind kind>
1678 void ByteCodeExecutor::executeExtract() {
1679   LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n");
1680   Range *range = read<Range *>();
1681   unsigned index = read<uint32_t>();
1682   unsigned memIndex = read();
1683 
1684   if (!range) {
1685     memory[memIndex] = nullptr;
1686     return;
1687   }
1688 
1689   T result = index < range->size() ? (*range)[index] : T();
1690   LLVM_DEBUG(llvm::dbgs() << "  * " << kind << "s(" << range->size() << ")\n"
1691                           << "  * Index: " << index << "\n"
1692                           << "  * Result: " << result << "\n");
1693   storeToMemory(memIndex, result);
1694 }
1695 
1696 void ByteCodeExecutor::executeFinalize() {
1697   LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n");
1698 }
1699 
1700 void ByteCodeExecutor::executeForEach() {
1701   LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n");
1702   const ByteCodeField *prevCodeIt = getPrevCodeIt();
1703   unsigned rangeIndex = read();
1704   unsigned memIndex = read();
1705   const void *value = nullptr;
1706 
1707   switch (read<PDLValue::Kind>()) {
1708   case PDLValue::Kind::Operation: {
1709     unsigned &index = loopIndex[read()];
1710     ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
1711     assert(index <= array.size() && "iterated past the end");
1712     if (index < array.size()) {
1713       LLVM_DEBUG(llvm::dbgs() << "  * Result: " << array[index] << "\n");
1714       value = array[index];
1715       break;
1716     }
1717 
1718     LLVM_DEBUG(llvm::dbgs() << "  * Done\n");
1719     index = 0;
1720     selectJump(size_t(0));
1721     return;
1722   }
1723   default:
1724     llvm_unreachable("unexpected `ForEach` value kind");
1725   }
1726 
1727   // Store the iterate value and the stack address.
1728   memory[memIndex] = value;
1729   pushCodeIt(prevCodeIt);
1730 
1731   // Skip over the successor (we will enter the body of the loop).
1732   read<ByteCodeAddr>();
1733 }
1734 
1735 void ByteCodeExecutor::executeGetAttribute() {
1736   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1737   unsigned memIndex = read();
1738   Operation *op = read<Operation *>();
1739   StringAttr attrName = read<StringAttr>();
1740   Attribute attr = op->getAttr(attrName);
1741 
1742   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1743                           << "  * Attribute: " << attrName << "\n"
1744                           << "  * Result: " << attr << "\n");
1745   memory[memIndex] = attr.getAsOpaquePointer();
1746 }
1747 
1748 void ByteCodeExecutor::executeGetAttributeType() {
1749   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1750   unsigned memIndex = read();
1751   Attribute attr = read<Attribute>();
1752   Type type;
1753   if (auto typedAttr = dyn_cast<TypedAttr>(attr))
1754     type = typedAttr.getType();
1755 
1756   LLVM_DEBUG(llvm::dbgs() << "  * Attribute: " << attr << "\n"
1757                           << "  * Result: " << type << "\n");
1758   memory[memIndex] = type.getAsOpaquePointer();
1759 }
1760 
1761 void ByteCodeExecutor::executeGetDefiningOp() {
1762   LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1763   unsigned memIndex = read();
1764   Operation *op = nullptr;
1765   if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1766     Value value = read<Value>();
1767     if (value)
1768       op = value.getDefiningOp();
1769     LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1770   } else {
1771     ValueRange *values = read<ValueRange *>();
1772     if (values && !values->empty()) {
1773       op = values->front().getDefiningOp();
1774     }
1775     LLVM_DEBUG(llvm::dbgs() << "  * Values: " << values << "\n");
1776   }
1777 
1778   LLVM_DEBUG(llvm::dbgs() << "  * Result: " << op << "\n");
1779   memory[memIndex] = op;
1780 }
1781 
1782 void ByteCodeExecutor::executeGetOperand(unsigned index) {
1783   Operation *op = read<Operation *>();
1784   unsigned memIndex = read();
1785   Value operand =
1786       index < op->getNumOperands() ? op->getOperand(index) : Value();
1787 
1788   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1789                           << "  * Index: " << index << "\n"
1790                           << "  * Result: " << operand << "\n");
1791   memory[memIndex] = operand.getAsOpaquePointer();
1792 }
1793 
1794 /// This function is the internal implementation of `GetResults` and
1795 /// `GetOperands` that provides support for extracting a value range from the
1796 /// given operation.
1797 template <template <typename> class AttrSizedSegmentsT, typename RangeT>
1798 static void *
1799 executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
1800                           ByteCodeField rangeIndex, StringRef attrSizedSegments,
1801                           MutableArrayRef<ValueRange> valueRangeMemory) {
1802   // Check for the sentinel index that signals that all values should be
1803   // returned.
1804   if (index == std::numeric_limits<uint32_t>::max()) {
1805     LLVM_DEBUG(llvm::dbgs() << "  * Getting all values\n");
1806     // `values` is already the full value range.
1807 
1808     // Otherwise, check to see if this operation uses AttrSizedSegments.
1809   } else if (op->hasTrait<AttrSizedSegmentsT>()) {
1810     LLVM_DEBUG(llvm::dbgs()
1811                << "  * Extracting values from `" << attrSizedSegments << "`\n");
1812 
1813     auto segmentAttr = op->getAttrOfType<DenseI32ArrayAttr>(attrSizedSegments);
1814     if (!segmentAttr || segmentAttr.asArrayRef().size() <= index)
1815       return nullptr;
1816 
1817     ArrayRef<int32_t> segments = segmentAttr;
1818     unsigned startIndex =
1819         std::accumulate(segments.begin(), segments.begin() + index, 0);
1820     values = values.slice(startIndex, *std::next(segments.begin(), index));
1821 
1822     LLVM_DEBUG(llvm::dbgs() << "  * Extracting range[" << startIndex << ", "
1823                             << *std::next(segments.begin(), index) << "]\n");
1824 
1825     // Otherwise, assume this is the last operand group of the operation.
1826     // FIXME: We currently don't support operations with
1827     // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
1828     // have a way to detect it's presence.
1829   } else if (values.size() >= index) {
1830     LLVM_DEBUG(llvm::dbgs()
1831                << "  * Treating values as trailing variadic range\n");
1832     values = values.drop_front(index);
1833 
1834     // If we couldn't detect a way to compute the values, bail out.
1835   } else {
1836     return nullptr;
1837   }
1838 
1839   // If the range index is valid, we are returning a range.
1840   if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
1841     valueRangeMemory[rangeIndex] = values;
1842     return &valueRangeMemory[rangeIndex];
1843   }
1844 
1845   // If a range index wasn't provided, the range is required to be non-variadic.
1846   return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1847 }
1848 
1849 void ByteCodeExecutor::executeGetOperands() {
1850   LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
1851   unsigned index = read<uint32_t>();
1852   Operation *op = read<Operation *>();
1853   ByteCodeField rangeIndex = read();
1854 
1855   void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1856       op->getOperands(), op, index, rangeIndex, "operandSegmentSizes",
1857       valueRangeMemory);
1858   if (!result)
1859     LLVM_DEBUG(llvm::dbgs() << "  * Invalid operand range\n");
1860   memory[read()] = result;
1861 }
1862 
1863 void ByteCodeExecutor::executeGetResult(unsigned index) {
1864   Operation *op = read<Operation *>();
1865   unsigned memIndex = read();
1866   OpResult result =
1867       index < op->getNumResults() ? op->getResult(index) : OpResult();
1868 
1869   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1870                           << "  * Index: " << index << "\n"
1871                           << "  * Result: " << result << "\n");
1872   memory[memIndex] = result.getAsOpaquePointer();
1873 }
1874 
1875 void ByteCodeExecutor::executeGetResults() {
1876   LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
1877   unsigned index = read<uint32_t>();
1878   Operation *op = read<Operation *>();
1879   ByteCodeField rangeIndex = read();
1880 
1881   void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1882       op->getResults(), op, index, rangeIndex, "resultSegmentSizes",
1883       valueRangeMemory);
1884   if (!result)
1885     LLVM_DEBUG(llvm::dbgs() << "  * Invalid result range\n");
1886   memory[read()] = result;
1887 }
1888 
1889 void ByteCodeExecutor::executeGetUsers() {
1890   LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n");
1891   unsigned memIndex = read();
1892   unsigned rangeIndex = read();
1893   OwningOpRange &range = opRangeMemory[rangeIndex];
1894   memory[memIndex] = &range;
1895 
1896   range = OwningOpRange();
1897   if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1898     // Read the value.
1899     Value value = read<Value>();
1900     if (!value)
1901       return;
1902     LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1903 
1904     // Extract the users of a single value.
1905     range = OwningOpRange(std::distance(value.user_begin(), value.user_end()));
1906     llvm::copy(value.getUsers(), range.begin());
1907   } else {
1908     // Read a range of values.
1909     ValueRange *values = read<ValueRange *>();
1910     if (!values)
1911       return;
1912     LLVM_DEBUG({
1913       llvm::dbgs() << "  * Values (" << values->size() << "): ";
1914       llvm::interleaveComma(*values, llvm::dbgs());
1915       llvm::dbgs() << "\n";
1916     });
1917 
1918     // Extract all the users of a range of values.
1919     SmallVector<Operation *> users;
1920     for (Value value : *values)
1921       users.append(value.user_begin(), value.user_end());
1922     range = OwningOpRange(users.size());
1923     llvm::copy(users, range.begin());
1924   }
1925 
1926   LLVM_DEBUG(llvm::dbgs() << "  * Result: " << range.size() << " operations\n");
1927 }
1928 
1929 void ByteCodeExecutor::executeGetValueType() {
1930   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1931   unsigned memIndex = read();
1932   Value value = read<Value>();
1933   Type type = value ? value.getType() : Type();
1934 
1935   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n"
1936                           << "  * Result: " << type << "\n");
1937   memory[memIndex] = type.getAsOpaquePointer();
1938 }
1939 
1940 void ByteCodeExecutor::executeGetValueRangeTypes() {
1941   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
1942   unsigned memIndex = read();
1943   unsigned rangeIndex = read();
1944   ValueRange *values = read<ValueRange *>();
1945   if (!values) {
1946     LLVM_DEBUG(llvm::dbgs() << "  * Values: <NULL>\n\n");
1947     memory[memIndex] = nullptr;
1948     return;
1949   }
1950 
1951   LLVM_DEBUG({
1952     llvm::dbgs() << "  * Values (" << values->size() << "): ";
1953     llvm::interleaveComma(*values, llvm::dbgs());
1954     llvm::dbgs() << "\n  * Result: ";
1955     llvm::interleaveComma(values->getType(), llvm::dbgs());
1956     llvm::dbgs() << "\n";
1957   });
1958   typeRangeMemory[rangeIndex] = values->getType();
1959   memory[memIndex] = &typeRangeMemory[rangeIndex];
1960 }
1961 
1962 void ByteCodeExecutor::executeIsNotNull() {
1963   LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
1964   const void *value = read<const void *>();
1965 
1966   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1967   selectJump(value != nullptr);
1968 }
1969 
1970 void ByteCodeExecutor::executeRecordMatch(
1971     PatternRewriter &rewriter,
1972     SmallVectorImpl<PDLByteCode::MatchResult> &matches) {
1973   LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
1974   unsigned patternIndex = read();
1975   PatternBenefit benefit = currentPatternBenefits[patternIndex];
1976   const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1977 
1978   // If the benefit of the pattern is impossible, skip the processing of the
1979   // rest of the pattern.
1980   if (benefit.isImpossibleToMatch()) {
1981     LLVM_DEBUG(llvm::dbgs() << "  * Benefit: Impossible To Match\n");
1982     curCodeIt = dest;
1983     return;
1984   }
1985 
1986   // Create a fused location containing the locations of each of the
1987   // operations used in the match. This will be used as the location for
1988   // created operations during the rewrite that don't already have an
1989   // explicit location set.
1990   unsigned numMatchLocs = read();
1991   SmallVector<Location, 4> matchLocs;
1992   matchLocs.reserve(numMatchLocs);
1993   for (unsigned i = 0; i != numMatchLocs; ++i)
1994     matchLocs.push_back(read<Operation *>()->getLoc());
1995   Location matchLoc = rewriter.getFusedLoc(matchLocs);
1996 
1997   LLVM_DEBUG(llvm::dbgs() << "  * Benefit: " << benefit.getBenefit() << "\n"
1998                           << "  * Location: " << matchLoc << "\n");
1999   matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
2000   PDLByteCode::MatchResult &match = matches.back();
2001 
2002   // Record all of the inputs to the match. If any of the inputs are ranges, we
2003   // will also need to remap the range pointer to memory stored in the match
2004   // state.
2005   unsigned numInputs = read();
2006   match.values.reserve(numInputs);
2007   match.typeRangeValues.reserve(numInputs);
2008   match.valueRangeValues.reserve(numInputs);
2009   for (unsigned i = 0; i < numInputs; ++i) {
2010     switch (read<PDLValue::Kind>()) {
2011     case PDLValue::Kind::TypeRange:
2012       match.typeRangeValues.push_back(*read<TypeRange *>());
2013       match.values.push_back(&match.typeRangeValues.back());
2014       break;
2015     case PDLValue::Kind::ValueRange:
2016       match.valueRangeValues.push_back(*read<ValueRange *>());
2017       match.values.push_back(&match.valueRangeValues.back());
2018       break;
2019     default:
2020       match.values.push_back(read<const void *>());
2021       break;
2022     }
2023   }
2024   curCodeIt = dest;
2025 }
2026 
2027 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
2028   LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
2029   Operation *op = read<Operation *>();
2030   SmallVector<Value, 16> args;
2031   readList(args);
2032 
2033   LLVM_DEBUG({
2034     llvm::dbgs() << "  * Operation: " << *op << "\n"
2035                  << "  * Values: ";
2036     llvm::interleaveComma(args, llvm::dbgs());
2037     llvm::dbgs() << "\n";
2038   });
2039   rewriter.replaceOp(op, args);
2040 }
2041 
2042 void ByteCodeExecutor::executeSwitchAttribute() {
2043   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
2044   Attribute value = read<Attribute>();
2045   ArrayAttr cases = read<ArrayAttr>();
2046   handleSwitch(value, cases);
2047 }
2048 
2049 void ByteCodeExecutor::executeSwitchOperandCount() {
2050   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
2051   Operation *op = read<Operation *>();
2052   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2053 
2054   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
2055   handleSwitch(op->getNumOperands(), cases);
2056 }
2057 
2058 void ByteCodeExecutor::executeSwitchOperationName() {
2059   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
2060   OperationName value = read<Operation *>()->getName();
2061   size_t caseCount = read();
2062 
2063   // The operation names are stored in-line, so to print them out for
2064   // debugging purposes we need to read the array before executing the
2065   // switch so that we can display all of the possible values.
2066   LLVM_DEBUG({
2067     const ByteCodeField *prevCodeIt = curCodeIt;
2068     llvm::dbgs() << "  * Value: " << value << "\n"
2069                  << "  * Cases: ";
2070     llvm::interleaveComma(
2071         llvm::map_range(llvm::seq<size_t>(0, caseCount),
2072                         [&](size_t) { return read<OperationName>(); }),
2073         llvm::dbgs());
2074     llvm::dbgs() << "\n";
2075     curCodeIt = prevCodeIt;
2076   });
2077 
2078   // Try to find the switch value within any of the cases.
2079   for (size_t i = 0; i != caseCount; ++i) {
2080     if (read<OperationName>() == value) {
2081       curCodeIt += (caseCount - i - 1);
2082       return selectJump(i + 1);
2083     }
2084   }
2085   selectJump(size_t(0));
2086 }
2087 
2088 void ByteCodeExecutor::executeSwitchResultCount() {
2089   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
2090   Operation *op = read<Operation *>();
2091   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2092 
2093   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
2094   handleSwitch(op->getNumResults(), cases);
2095 }
2096 
2097 void ByteCodeExecutor::executeSwitchType() {
2098   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
2099   Type value = read<Type>();
2100   auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2101   handleSwitch(value, cases);
2102 }
2103 
2104 void ByteCodeExecutor::executeSwitchTypes() {
2105   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
2106   TypeRange *value = read<TypeRange *>();
2107   auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
2108   if (!value) {
2109     LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n");
2110     return selectJump(size_t(0));
2111   }
2112   handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
2113     return value == caseValue.getAsValueRange<TypeAttr>();
2114   });
2115 }
2116 
2117 LogicalResult
2118 ByteCodeExecutor::execute(PatternRewriter &rewriter,
2119                           SmallVectorImpl<PDLByteCode::MatchResult> *matches,
2120                           std::optional<Location> mainRewriteLoc) {
2121   while (true) {
2122     // Print the location of the operation being executed.
2123     LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n");
2124 
2125     OpCode opCode = static_cast<OpCode>(read());
2126     switch (opCode) {
2127     case ApplyConstraint:
2128       executeApplyConstraint(rewriter);
2129       break;
2130     case ApplyRewrite:
2131       if (failed(executeApplyRewrite(rewriter)))
2132         return failure();
2133       break;
2134     case AreEqual:
2135       executeAreEqual();
2136       break;
2137     case AreRangesEqual:
2138       executeAreRangesEqual();
2139       break;
2140     case Branch:
2141       executeBranch();
2142       break;
2143     case CheckOperandCount:
2144       executeCheckOperandCount();
2145       break;
2146     case CheckOperationName:
2147       executeCheckOperationName();
2148       break;
2149     case CheckResultCount:
2150       executeCheckResultCount();
2151       break;
2152     case CheckTypes:
2153       executeCheckTypes();
2154       break;
2155     case Continue:
2156       executeContinue();
2157       break;
2158     case CreateConstantTypeRange:
2159       executeCreateConstantTypeRange();
2160       break;
2161     case CreateOperation:
2162       executeCreateOperation(rewriter, *mainRewriteLoc);
2163       break;
2164     case CreateDynamicTypeRange:
2165       executeDynamicCreateRange<Type>("Type");
2166       break;
2167     case CreateDynamicValueRange:
2168       executeDynamicCreateRange<Value>("Value");
2169       break;
2170     case EraseOp:
2171       executeEraseOp(rewriter);
2172       break;
2173     case ExtractOp:
2174       executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
2175       break;
2176     case ExtractType:
2177       executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
2178       break;
2179     case ExtractValue:
2180       executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
2181       break;
2182     case Finalize:
2183       executeFinalize();
2184       LLVM_DEBUG(llvm::dbgs() << "\n");
2185       return success();
2186     case ForEach:
2187       executeForEach();
2188       break;
2189     case GetAttribute:
2190       executeGetAttribute();
2191       break;
2192     case GetAttributeType:
2193       executeGetAttributeType();
2194       break;
2195     case GetDefiningOp:
2196       executeGetDefiningOp();
2197       break;
2198     case GetOperand0:
2199     case GetOperand1:
2200     case GetOperand2:
2201     case GetOperand3: {
2202       unsigned index = opCode - GetOperand0;
2203       LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
2204       executeGetOperand(index);
2205       break;
2206     }
2207     case GetOperandN:
2208       LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
2209       executeGetOperand(read<uint32_t>());
2210       break;
2211     case GetOperands:
2212       executeGetOperands();
2213       break;
2214     case GetResult0:
2215     case GetResult1:
2216     case GetResult2:
2217     case GetResult3: {
2218       unsigned index = opCode - GetResult0;
2219       LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
2220       executeGetResult(index);
2221       break;
2222     }
2223     case GetResultN:
2224       LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
2225       executeGetResult(read<uint32_t>());
2226       break;
2227     case GetResults:
2228       executeGetResults();
2229       break;
2230     case GetUsers:
2231       executeGetUsers();
2232       break;
2233     case GetValueType:
2234       executeGetValueType();
2235       break;
2236     case GetValueRangeTypes:
2237       executeGetValueRangeTypes();
2238       break;
2239     case IsNotNull:
2240       executeIsNotNull();
2241       break;
2242     case RecordMatch:
2243       assert(matches &&
2244              "expected matches to be provided when executing the matcher");
2245       executeRecordMatch(rewriter, *matches);
2246       break;
2247     case ReplaceOp:
2248       executeReplaceOp(rewriter);
2249       break;
2250     case SwitchAttribute:
2251       executeSwitchAttribute();
2252       break;
2253     case SwitchOperandCount:
2254       executeSwitchOperandCount();
2255       break;
2256     case SwitchOperationName:
2257       executeSwitchOperationName();
2258       break;
2259     case SwitchResultCount:
2260       executeSwitchResultCount();
2261       break;
2262     case SwitchType:
2263       executeSwitchType();
2264       break;
2265     case SwitchTypes:
2266       executeSwitchTypes();
2267       break;
2268     }
2269     LLVM_DEBUG(llvm::dbgs() << "\n");
2270   }
2271 }
2272 
2273 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
2274                         SmallVectorImpl<MatchResult> &matches,
2275                         PDLByteCodeMutableState &state) const {
2276   // The first memory slot is always the root operation.
2277   state.memory[0] = op;
2278 
2279   // The matcher function always starts at code address 0.
2280   ByteCodeExecutor executor(
2281       matcherByteCode.data(), state.memory, state.opRangeMemory,
2282       state.typeRangeMemory, state.allocatedTypeRangeMemory,
2283       state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2284       uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
2285       constraintFunctions, rewriteFunctions);
2286   LogicalResult executeResult = executor.execute(rewriter, &matches);
2287   (void)executeResult;
2288   assert(succeeded(executeResult) && "unexpected matcher execution failure");
2289 
2290   // Order the found matches by benefit.
2291   std::stable_sort(matches.begin(), matches.end(),
2292                    [](const MatchResult &lhs, const MatchResult &rhs) {
2293                      return lhs.benefit > rhs.benefit;
2294                    });
2295 }
2296 
2297 LogicalResult PDLByteCode::rewrite(PatternRewriter &rewriter,
2298                                    const MatchResult &match,
2299                                    PDLByteCodeMutableState &state) const {
2300   auto *configSet = match.pattern->getConfigSet();
2301   if (configSet)
2302     configSet->notifyRewriteBegin(rewriter);
2303 
2304   // The arguments of the rewrite function are stored at the start of the
2305   // memory buffer.
2306   llvm::copy(match.values, state.memory.begin());
2307 
2308   ByteCodeExecutor executor(
2309       &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
2310       state.opRangeMemory, state.typeRangeMemory,
2311       state.allocatedTypeRangeMemory, state.valueRangeMemory,
2312       state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2313       rewriterByteCode, state.currentPatternBenefits, patterns,
2314       constraintFunctions, rewriteFunctions);
2315   LogicalResult result =
2316       executor.execute(rewriter, /*matches=*/nullptr, match.location);
2317 
2318   if (configSet)
2319     configSet->notifyRewriteEnd(rewriter);
2320 
2321   // If the rewrite failed, check if the pattern rewriter can recover. If it
2322   // can, we can signal to the pattern applicator to keep trying patterns. If it
2323   // doesn't, we need to bail. Bailing here should be fine, given that we have
2324   // no means to propagate such a failure to the user, and it also indicates a
2325   // bug in the user code (i.e. failable rewrites should not be used with
2326   // pattern rewriters that don't support it).
2327   if (failed(result) && !rewriter.canRecoverFromRewriteFailure()) {
2328     LLVM_DEBUG(llvm::dbgs() << " and rollback is not supported - aborting");
2329     llvm::report_fatal_error(
2330         "Native PDL Rewrite failed, but the pattern "
2331         "rewriter doesn't support recovery. Failable pattern rewrites should "
2332         "not be used with pattern rewriters that do not support them.");
2333   }
2334   return result;
2335 }
2336