xref: /llvm-project/mlir/lib/Rewrite/ByteCode.cpp (revision 8c66344ee9f67f76b3cb6b3345a46345a2d3975a)
1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MLIR to byte-code generation and the interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "ByteCode.h"
14 #include "mlir/Analysis/Liveness.h"
15 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
16 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/RegionGraphTraits.h"
19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/Format.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include <numeric>
26 
27 #define DEBUG_TYPE "pdl-bytecode"
28 
29 using namespace mlir;
30 using namespace mlir::detail;
31 
32 //===----------------------------------------------------------------------===//
33 // PDLByteCodePattern
34 //===----------------------------------------------------------------------===//
35 
36 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
37                                               PDLPatternConfigSet *configSet,
38                                               ByteCodeAddr rewriterAddr) {
39   PatternBenefit benefit = matchOp.getBenefit();
40   MLIRContext *ctx = matchOp.getContext();
41 
42   // Collect the set of generated operations.
43   SmallVector<StringRef, 8> generatedOps;
44   if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
45     generatedOps =
46         llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
47 
48   // Check to see if this is pattern matches a specific operation type.
49   if (Optional<StringRef> rootKind = matchOp.getRootKind())
50     return PDLByteCodePattern(rewriterAddr, configSet, *rootKind, benefit, ctx,
51                               generatedOps);
52   return PDLByteCodePattern(rewriterAddr, configSet, MatchAnyOpTypeTag(),
53                             benefit, ctx, generatedOps);
54 }
55 
56 //===----------------------------------------------------------------------===//
57 // PDLByteCodeMutableState
58 //===----------------------------------------------------------------------===//
59 
60 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
61 /// to the position of the pattern within the range returned by
62 /// `PDLByteCode::getPatterns`.
63 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
64                                                    PatternBenefit benefit) {
65   currentPatternBenefits[patternIndex] = benefit;
66 }
67 
68 /// Cleanup any allocated state after a full match/rewrite has been completed.
69 /// This method should be called irregardless of whether the match+rewrite was a
70 /// success or not.
71 void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() {
72   allocatedTypeRangeMemory.clear();
73   allocatedValueRangeMemory.clear();
74 }
75 
76 //===----------------------------------------------------------------------===//
77 // Bytecode OpCodes
78 //===----------------------------------------------------------------------===//
79 
80 namespace {
/// The opcodes that may appear in a bytecode stream. The underlying type is
/// `ByteCodeField`, so an opcode occupies exactly one field of the stream.
/// NOTE: the numeric value of each opcode is determined by its position in
/// this list.
enum OpCode : ByteCodeField {
  /// Apply an externally registered constraint.
  ApplyConstraint,
  /// Apply an externally registered rewrite.
  ApplyRewrite,
  /// Check if two generic values are equal.
  AreEqual,
  /// Check if two ranges are equal.
  AreRangesEqual,
  /// Unconditional branch.
  Branch,
  /// Compare the operand count of an operation with a constant.
  CheckOperandCount,
  /// Compare the name of an operation with a constant.
  CheckOperationName,
  /// Compare the result count of an operation with a constant.
  CheckResultCount,
  /// Compare a range of types to a constant range of types.
  CheckTypes,
  /// Continue to the next iteration of a loop.
  Continue,
  /// Create an operation.
  CreateOperation,
  /// Create a range of types.
  CreateTypes,
  /// Erase an operation.
  EraseOp,
  /// Extract the op from a range at the specified index.
  ExtractOp,
  /// Extract the type from a range at the specified index.
  ExtractType,
  /// Extract the value from a range at the specified index.
  ExtractValue,
  /// Terminate a matcher or rewrite sequence.
  Finalize,
  /// Iterate over a range of values.
  ForEach,
  /// Get a specific attribute of an operation.
  GetAttribute,
  /// Get the type of an attribute.
  GetAttributeType,
  /// Get the defining operation of a value.
  GetDefiningOp,
  /// Get a specific operand of an operation.
  /// NOTE(review): the numbered variants presumably encode the operand index
  /// in the opcode itself, with `GetOperandN` handling any other index --
  /// confirm against the generator/interpreter code for GetOperandOp.
  GetOperand0,
  GetOperand1,
  GetOperand2,
  GetOperand3,
  GetOperandN,
  /// Get a specific operand group of an operation.
  GetOperands,
  /// Get a specific result of an operation.
  /// NOTE(review): presumably follows the same numbered/`N` scheme as the
  /// operand opcodes above -- confirm against the code for GetResultOp.
  GetResult0,
  GetResult1,
  GetResult2,
  GetResult3,
  GetResultN,
  /// Get a specific result group of an operation.
  GetResults,
  /// Get the users of a value or a range of values.
  GetUsers,
  /// Get the type of a value.
  GetValueType,
  /// Get the types of a value range.
  GetValueRangeTypes,
  /// Check if a generic value is not null.
  IsNotNull,
  /// Record a successful pattern match.
  RecordMatch,
  /// Replace an operation.
  ReplaceOp,
  /// Compare an attribute with a set of constants.
  SwitchAttribute,
  /// Compare the operand count of an operation with a set of constants.
  SwitchOperandCount,
  /// Compare the name of an operation with a set of constants.
  SwitchOperationName,
  /// Compare the result count of an operation with a set of constants.
  SwitchResultCount,
  /// Compare a type with a set of constants.
  SwitchType,
  /// Compare a range of types with a set of constants.
  SwitchTypes,
};
165 } // namespace
166 
/// A marker used to indicate if an operation should infer types. Its value is
/// the largest value representable by `ByteCodeField`.
static constexpr ByteCodeField kInferTypesMarker =
    std::numeric_limits<ByteCodeField>::max();
170 
171 //===----------------------------------------------------------------------===//
172 // ByteCode Generation
173 //===----------------------------------------------------------------------===//
174 
175 //===----------------------------------------------------------------------===//
176 // Generator
177 
178 namespace {
179 struct ByteCodeLiveRange;
180 struct ByteCodeWriter;
181 
/// Check if the given class `T` can be converted to an opaque pointer. The
/// alias names the type of `T::getAsOpaquePointer()`, so it is only
/// well-formed when that member call is valid; it is meant to be used as the
/// detection expression of `llvm::is_detected`.
template <typename T, typename... Args>
using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
185 
186 /// This class represents the main generator for the pattern bytecode.
187 class Generator {
188 public:
189   Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
190             SmallVectorImpl<ByteCodeField> &matcherByteCode,
191             SmallVectorImpl<ByteCodeField> &rewriterByteCode,
192             SmallVectorImpl<PDLByteCodePattern> &patterns,
193             ByteCodeField &maxValueMemoryIndex,
194             ByteCodeField &maxOpRangeMemoryIndex,
195             ByteCodeField &maxTypeRangeMemoryIndex,
196             ByteCodeField &maxValueRangeMemoryIndex,
197             ByteCodeField &maxLoopLevel,
198             llvm::StringMap<PDLConstraintFunction> &constraintFns,
199             llvm::StringMap<PDLRewriteFunction> &rewriteFns,
200             const DenseMap<Operation *, PDLPatternConfigSet *> &configMap)
201       : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
202         rewriterByteCode(rewriterByteCode), patterns(patterns),
203         maxValueMemoryIndex(maxValueMemoryIndex),
204         maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
205         maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
206         maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
207         maxLoopLevel(maxLoopLevel), configMap(configMap) {
208     for (const auto &it : llvm::enumerate(constraintFns))
209       constraintToMemIndex.try_emplace(it.value().first(), it.index());
210     for (const auto &it : llvm::enumerate(rewriteFns))
211       externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
212   }
213 
214   /// Generate the bytecode for the given PDL interpreter module.
215   void generate(ModuleOp module);
216 
217   /// Return the memory index to use for the given value.
218   ByteCodeField &getMemIndex(Value value) {
219     assert(valueToMemIndex.count(value) &&
220            "expected memory index to be assigned");
221     return valueToMemIndex[value];
222   }
223 
224   /// Return the range memory index used to store the given range value.
225   ByteCodeField &getRangeStorageIndex(Value value) {
226     assert(valueToRangeIndex.count(value) &&
227            "expected range index to be assigned");
228     return valueToRangeIndex[value];
229   }
230 
231   /// Return an index to use when referring to the given data that is uniqued in
232   /// the MLIR context.
233   template <typename T>
234   std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
235   getMemIndex(T val) {
236     const void *opaqueVal = val.getAsOpaquePointer();
237 
238     // Get or insert a reference to this value.
239     auto it = uniquedDataToMemIndex.try_emplace(
240         opaqueVal, maxValueMemoryIndex + uniquedData.size());
241     if (it.second)
242       uniquedData.push_back(opaqueVal);
243     return it.first->second;
244   }
245 
246 private:
247   /// Allocate memory indices for the results of operations within the matcher
248   /// and rewriters.
249   void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
250                              ModuleOp rewriterModule);
251 
252   /// Generate the bytecode for the given operation.
253   void generate(Region *region, ByteCodeWriter &writer);
254   void generate(Operation *op, ByteCodeWriter &writer);
255   void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
256   void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
257   void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
258   void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
259   void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
260   void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
261   void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
262   void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
263   void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
264   void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
265   void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
266   void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
267   void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
268   void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
269   void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
270   void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
271   void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
272   void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
273   void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
274   void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
275   void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
276   void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
277   void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
278   void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
279   void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
280   void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
281   void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
282   void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
283   void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
284   void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
285   void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
286   void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
287   void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
288   void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
289   void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
290   void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
291   void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
292 
293   /// Mapping from value to its corresponding memory index.
294   DenseMap<Value, ByteCodeField> valueToMemIndex;
295 
296   /// Mapping from a range value to its corresponding range storage index.
297   DenseMap<Value, ByteCodeField> valueToRangeIndex;
298 
299   /// Mapping from the name of an externally registered rewrite to its index in
300   /// the bytecode registry.
301   llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
302 
303   /// Mapping from the name of an externally registered constraint to its index
304   /// in the bytecode registry.
305   llvm::StringMap<ByteCodeField> constraintToMemIndex;
306 
307   /// Mapping from rewriter function name to the bytecode address of the
308   /// rewriter function in byte.
309   llvm::StringMap<ByteCodeAddr> rewriterToAddr;
310 
311   /// Mapping from a uniqued storage object to its memory index within
312   /// `uniquedData`.
313   DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
314 
315   /// The current level of the foreach loop.
316   ByteCodeField curLoopLevel = 0;
317 
318   /// The current MLIR context.
319   MLIRContext *ctx;
320 
321   /// Mapping from block to its address.
322   DenseMap<Block *, ByteCodeAddr> blockToAddr;
323 
324   /// Data of the ByteCode class to be populated.
325   std::vector<const void *> &uniquedData;
326   SmallVectorImpl<ByteCodeField> &matcherByteCode;
327   SmallVectorImpl<ByteCodeField> &rewriterByteCode;
328   SmallVectorImpl<PDLByteCodePattern> &patterns;
329   ByteCodeField &maxValueMemoryIndex;
330   ByteCodeField &maxOpRangeMemoryIndex;
331   ByteCodeField &maxTypeRangeMemoryIndex;
332   ByteCodeField &maxValueRangeMemoryIndex;
333   ByteCodeField &maxLoopLevel;
334 
335   /// A map of pattern configurations.
336   const DenseMap<Operation *, PDLPatternConfigSet *> &configMap;
337 };
338 
339 /// This class provides utilities for writing a bytecode stream.
340 struct ByteCodeWriter {
341   ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
342       : bytecode(bytecode), generator(generator) {}
343 
344   /// Append a field to the bytecode.
345   void append(ByteCodeField field) { bytecode.push_back(field); }
346   void append(OpCode opCode) { bytecode.push_back(opCode); }
347 
348   /// Append an address to the bytecode.
349   void append(ByteCodeAddr field) {
350     static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
351                   "unexpected ByteCode address size");
352 
353     ByteCodeField fieldParts[2];
354     std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
355     bytecode.append({fieldParts[0], fieldParts[1]});
356   }
357 
358   /// Append a single successor to the bytecode, the exact address will need to
359   /// be resolved later.
360   void append(Block *successor) {
361     // Add back a reference to the successor so that the address can be resolved
362     // later.
363     unresolvedSuccessorRefs[successor].push_back(bytecode.size());
364     append(ByteCodeAddr(0));
365   }
366 
367   /// Append a successor range to the bytecode, the exact address will need to
368   /// be resolved later.
369   void append(SuccessorRange successors) {
370     for (Block *successor : successors)
371       append(successor);
372   }
373 
374   /// Append a range of values that will be read as generic PDLValues.
375   void appendPDLValueList(OperandRange values) {
376     bytecode.push_back(values.size());
377     for (Value value : values)
378       appendPDLValue(value);
379   }
380 
381   /// Append a value as a PDLValue.
382   void appendPDLValue(Value value) {
383     appendPDLValueKind(value);
384     append(value);
385   }
386 
387   /// Append the PDLValue::Kind of the given value.
388   void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); }
389 
390   /// Append the PDLValue::Kind of the given type.
391   void appendPDLValueKind(Type type) {
392     PDLValue::Kind kind =
393         TypeSwitch<Type, PDLValue::Kind>(type)
394             .Case<pdl::AttributeType>(
395                 [](Type) { return PDLValue::Kind::Attribute; })
396             .Case<pdl::OperationType>(
397                 [](Type) { return PDLValue::Kind::Operation; })
398             .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
399               if (rangeTy.getElementType().isa<pdl::TypeType>())
400                 return PDLValue::Kind::TypeRange;
401               return PDLValue::Kind::ValueRange;
402             })
403             .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
404             .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
405     bytecode.push_back(static_cast<ByteCodeField>(kind));
406   }
407 
408   /// Append a value that will be stored in a memory slot and not inline within
409   /// the bytecode.
410   template <typename T>
411   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
412                    std::is_pointer<T>::value>
413   append(T value) {
414     bytecode.push_back(generator.getMemIndex(value));
415   }
416 
417   /// Append a range of values.
418   template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
419   std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
420   append(T range) {
421     bytecode.push_back(llvm::size(range));
422     for (auto it : range)
423       append(it);
424   }
425 
426   /// Append a variadic number of fields to the bytecode.
427   template <typename FieldTy, typename Field2Ty, typename... FieldTys>
428   void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
429     append(field);
430     append(field2, fields...);
431   }
432 
433   /// Appends a value as a pointer, stored inline within the bytecode.
434   template <typename T>
435   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
436   appendInline(T value) {
437     constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField);
438     const void *pointer = value.getAsOpaquePointer();
439     ByteCodeField fieldParts[numParts];
440     std::memcpy(fieldParts, &pointer, sizeof(const void *));
441     bytecode.append(fieldParts, fieldParts + numParts);
442   }
443 
444   /// Successor references in the bytecode that have yet to be resolved.
445   DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
446 
447   /// The underlying bytecode buffer.
448   SmallVectorImpl<ByteCodeField> &bytecode;
449 
450   /// The main generator producing PDL.
451   Generator &generator;
452 };
453 
454 /// This class represents a live range of PDL Interpreter values, containing
455 /// information about when values are live within a match/rewrite.
456 struct ByteCodeLiveRange {
457   using Set = llvm::IntervalMap<uint64_t, char, 16>;
458   using Allocator = Set::Allocator;
459 
460   ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
461 
462   /// Union this live range with the one provided.
463   void unionWith(const ByteCodeLiveRange &rhs) {
464     for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
465          ++it)
466       liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0);
467   }
468 
469   /// Returns true if this range overlaps with the one provided.
470   bool overlaps(const ByteCodeLiveRange &rhs) const {
471     return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
472         .valid();
473   }
474 
475   /// A map representing the ranges of the match/rewrite that a value is live in
476   /// the interpreter.
477   ///
478   /// We use std::unique_ptr here, because IntervalMap does not provide a
479   /// correct copy or move constructor. We can eliminate the pointer once
480   /// https://reviews.llvm.org/D113240 lands.
481   std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
482 
483   /// The operation range storage index for this range.
484   Optional<unsigned> opRangeIndex;
485 
486   /// The type range storage index for this range.
487   Optional<unsigned> typeRangeIndex;
488 
489   /// The value range storage index for this range.
490   Optional<unsigned> valueRangeIndex;
491 };
492 } // namespace
493 
494 void Generator::generate(ModuleOp module) {
495   auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
496       pdl_interp::PDLInterpDialect::getMatcherFunctionName());
497   ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
498       pdl_interp::PDLInterpDialect::getRewriterModuleName());
499   assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
500 
501   // Allocate memory indices for the results of operations within the matcher
502   // and rewriters.
503   allocateMemoryIndices(matcherFunc, rewriterModule);
504 
505   // Generate code for the rewriter functions.
506   ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
507   for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
508     rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
509     for (Operation &op : rewriterFunc.getOps())
510       generate(&op, rewriterByteCodeWriter);
511   }
512   assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
513          "unexpected branches in rewriter function");
514 
515   // Generate code for the matcher function.
516   ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
517   generate(&matcherFunc.getBody(), matcherByteCodeWriter);
518 
519   // Resolve successor references in the matcher.
520   for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
521     ByteCodeAddr addr = blockToAddr[it.first];
522     for (unsigned offsetToFix : it.second)
523       std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
524   }
525 }
526 
527 void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
528                                       ModuleOp rewriterModule) {
529   // Rewriters use simplistic allocation scheme that simply assigns an index to
530   // each result.
531   for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
532     ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
533     auto processRewriterValue = [&](Value val) {
534       valueToMemIndex.try_emplace(val, index++);
535       if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) {
536         Type elementTy = rangeType.getElementType();
537         if (elementTy.isa<pdl::TypeType>())
538           valueToRangeIndex.try_emplace(val, typeRangeIndex++);
539         else if (elementTy.isa<pdl::ValueType>())
540           valueToRangeIndex.try_emplace(val, valueRangeIndex++);
541       }
542     };
543 
544     for (BlockArgument arg : rewriterFunc.getArguments())
545       processRewriterValue(arg);
546     rewriterFunc.getBody().walk([&](Operation *op) {
547       for (Value result : op->getResults())
548         processRewriterValue(result);
549     });
550     if (index > maxValueMemoryIndex)
551       maxValueMemoryIndex = index;
552     if (typeRangeIndex > maxTypeRangeMemoryIndex)
553       maxTypeRangeMemoryIndex = typeRangeIndex;
554     if (valueRangeIndex > maxValueRangeMemoryIndex)
555       maxValueRangeMemoryIndex = valueRangeIndex;
556   }
557 
558   // The matcher function uses a more sophisticated numbering that tries to
559   // minimize the number of memory indices assigned. This is done by determining
560   // a live range of the values within the matcher, then the allocation is just
561   // finding the minimal number of overlapping live ranges. This is essentially
562   // a simplified form of register allocation where we don't necessarily have a
563   // limited number of registers, but we still want to minimize the number used.
564   DenseMap<Operation *, unsigned> opToFirstIndex;
565   DenseMap<Operation *, unsigned> opToLastIndex;
566 
567   // A custom walk that marks the first and the last index of each operation.
568   // The entry marks the beginning of the liveness range for this operation,
569   // followed by nested operations, followed by the end of the liveness range.
570   unsigned index = 0;
571   llvm::unique_function<void(Operation *)> walk = [&](Operation *op) {
572     opToFirstIndex.try_emplace(op, index++);
573     for (Region &region : op->getRegions())
574       for (Block &block : region.getBlocks())
575         for (Operation &nested : block)
576           walk(&nested);
577     opToLastIndex.try_emplace(op, index++);
578   };
579   walk(matcherFunc);
580 
581   // Liveness info for each of the defs within the matcher.
582   ByteCodeLiveRange::Allocator allocator;
583   DenseMap<Value, ByteCodeLiveRange> valueDefRanges;
584 
585   // Assign the root operation being matched to slot 0.
586   BlockArgument rootOpArg = matcherFunc.getArgument(0);
587   valueToMemIndex[rootOpArg] = 0;
588 
589   // Walk each of the blocks, computing the def interval that the value is used.
590   Liveness matcherLiveness(matcherFunc);
591   matcherFunc->walk([&](Block *block) {
592     const LivenessBlockInfo *info = matcherLiveness.getLiveness(block);
593     assert(info && "expected liveness info for block");
594     auto processValue = [&](Value value, Operation *firstUseOrDef) {
595       // We don't need to process the root op argument, this value is always
596       // assigned to the first memory slot.
597       if (value == rootOpArg)
598         return;
599 
600       // Set indices for the range of this block that the value is used.
601       auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
602       defRangeIt->second.liveness->insert(
603           opToFirstIndex[firstUseOrDef],
604           opToLastIndex[info->getEndOperation(value, firstUseOrDef)],
605           /*dummyValue*/ 0);
606 
607       // Check to see if this value is a range type.
608       if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) {
609         Type eleType = rangeTy.getElementType();
610         if (eleType.isa<pdl::OperationType>())
611           defRangeIt->second.opRangeIndex = 0;
612         else if (eleType.isa<pdl::TypeType>())
613           defRangeIt->second.typeRangeIndex = 0;
614         else if (eleType.isa<pdl::ValueType>())
615           defRangeIt->second.valueRangeIndex = 0;
616       }
617     };
618 
619     // Process the live-ins of this block.
620     for (Value liveIn : info->in()) {
621       // Only process the value if it has been defined in the current region.
622       // Other values that span across pdl_interp.foreach will be added higher
623       // up. This ensures that the we keep them alive for the entire duration
624       // of the loop.
625       if (liveIn.getParentRegion() == block->getParent())
626         processValue(liveIn, &block->front());
627     }
628 
629     // Process the block arguments for the entry block (those are not live-in).
630     if (block->isEntryBlock()) {
631       for (Value argument : block->getArguments())
632         processValue(argument, &block->front());
633     }
634 
635     // Process any new defs within this block.
636     for (Operation &op : *block)
637       for (Value result : op.getResults())
638         processValue(result, &op);
639   });
640 
641   // Greedily allocate memory slots using the computed def live ranges.
642   std::vector<ByteCodeLiveRange> allocatedIndices;
643 
644   // The number of memory indices currently allocated (and its next value).
645   // Recall that the root gets allocated memory index 0.
646   ByteCodeField numIndices = 1;
647 
648   // The number of memory ranges of various types (and their next values).
649   ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
650 
651   for (auto &defIt : valueDefRanges) {
652     ByteCodeField &memIndex = valueToMemIndex[defIt.first];
653     ByteCodeLiveRange &defRange = defIt.second;
654 
655     // Try to allocate to an existing index.
656     for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) {
657       ByteCodeLiveRange &existingRange = existingIndexIt.value();
658       if (!defRange.overlaps(existingRange)) {
659         existingRange.unionWith(defRange);
660         memIndex = existingIndexIt.index() + 1;
661 
662         if (defRange.opRangeIndex) {
663           if (!existingRange.opRangeIndex)
664             existingRange.opRangeIndex = numOpRanges++;
665           valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
666         } else if (defRange.typeRangeIndex) {
667           if (!existingRange.typeRangeIndex)
668             existingRange.typeRangeIndex = numTypeRanges++;
669           valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
670         } else if (defRange.valueRangeIndex) {
671           if (!existingRange.valueRangeIndex)
672             existingRange.valueRangeIndex = numValueRanges++;
673           valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
674         }
675         break;
676       }
677     }
678 
679     // If no existing index could be used, add a new one.
680     if (memIndex == 0) {
681       allocatedIndices.emplace_back(allocator);
682       ByteCodeLiveRange &newRange = allocatedIndices.back();
683       newRange.unionWith(defRange);
684 
685       // Allocate an index for op/type/value ranges.
686       if (defRange.opRangeIndex) {
687         newRange.opRangeIndex = numOpRanges;
688         valueToRangeIndex[defIt.first] = numOpRanges++;
689       } else if (defRange.typeRangeIndex) {
690         newRange.typeRangeIndex = numTypeRanges;
691         valueToRangeIndex[defIt.first] = numTypeRanges++;
692       } else if (defRange.valueRangeIndex) {
693         newRange.valueRangeIndex = numValueRanges;
694         valueToRangeIndex[defIt.first] = numValueRanges++;
695       }
696 
697       memIndex = allocatedIndices.size();
698       ++numIndices;
699     }
700   }
701 
702   // Print the index usage and ensure that we did not run out of index space.
703   LLVM_DEBUG({
704     llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices "
705                  << "(down from initial " << valueDefRanges.size() << ").\n";
706   });
707   assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
708          "Ran out of memory for allocated indices");
709 
710   // Update the max number of indices.
711   if (numIndices > maxValueMemoryIndex)
712     maxValueMemoryIndex = numIndices;
713   if (numOpRanges > maxOpRangeMemoryIndex)
714     maxOpRangeMemoryIndex = numOpRanges;
715   if (numTypeRanges > maxTypeRangeMemoryIndex)
716     maxTypeRangeMemoryIndex = numTypeRanges;
717   if (numValueRanges > maxValueRangeMemoryIndex)
718     maxValueRangeMemoryIndex = numValueRanges;
719 }
720 
721 void Generator::generate(Region *region, ByteCodeWriter &writer) {
722   llvm::ReversePostOrderTraversal<Region *> rpot(region);
723   for (Block *block : rpot) {
724     // Keep track of where this block begins within the matcher function.
725     blockToAddr.try_emplace(block, matcherByteCode.size());
726     for (Operation &op : *block)
727       generate(&op, writer);
728   }
729 }
730 
731 void Generator::generate(Operation *op, ByteCodeWriter &writer) {
732   LLVM_DEBUG({
733     // The following list must contain all the operations that do not
734     // produce any bytecode.
735     if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
736       writer.appendInline(op->getLoc());
737   });
738   TypeSwitch<Operation *>(op)
739       .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
740             pdl_interp::AreEqualOp, pdl_interp::BranchOp,
741             pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
742             pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
743             pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
744             pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
745             pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp,
746             pdl_interp::CreateTypesOp, pdl_interp::EraseOp,
747             pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
748             pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
749             pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
750             pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
751             pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
752             pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
753             pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
754             pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
755             pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
756             pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
757             pdl_interp::SwitchResultCountOp>(
758           [&](auto interpOp) { this->generate(interpOp, writer); })
759       .Default([](Operation *) {
760         llvm_unreachable("unknown `pdl_interp` operation");
761       });
762 }
763 
764 void Generator::generate(pdl_interp::ApplyConstraintOp op,
765                          ByteCodeWriter &writer) {
766   assert(constraintToMemIndex.count(op.getName()) &&
767          "expected index for constraint function");
768   writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
769   writer.appendPDLValueList(op.getArgs());
770   writer.append(op.getSuccessors());
771 }
772 void Generator::generate(pdl_interp::ApplyRewriteOp op,
773                          ByteCodeWriter &writer) {
774   assert(externalRewriterToMemIndex.count(op.getName()) &&
775          "expected index for rewrite function");
776   writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
777   writer.appendPDLValueList(op.getArgs());
778 
779   ResultRange results = op.getResults();
780   writer.append(ByteCodeField(results.size()));
781   for (Value result : results) {
782     // In debug mode we also record the expected kind of the result, so that we
783     // can provide extra verification of the native rewrite function.
784 #ifndef NDEBUG
785     writer.appendPDLValueKind(result);
786 #endif
787 
788     // Range results also need to append the range storage index.
789     if (result.getType().isa<pdl::RangeType>())
790       writer.append(getRangeStorageIndex(result));
791     writer.append(result);
792   }
793 }
794 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
795   Value lhs = op.getLhs();
796   if (lhs.getType().isa<pdl::RangeType>()) {
797     writer.append(OpCode::AreRangesEqual);
798     writer.appendPDLValueKind(lhs);
799     writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
800     return;
801   }
802 
803   writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
804 }
805 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
806   writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
807 }
808 void Generator::generate(pdl_interp::CheckAttributeOp op,
809                          ByteCodeWriter &writer) {
810   writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
811                 op.getSuccessors());
812 }
813 void Generator::generate(pdl_interp::CheckOperandCountOp op,
814                          ByteCodeWriter &writer) {
815   writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
816                 static_cast<ByteCodeField>(op.getCompareAtLeast()),
817                 op.getSuccessors());
818 }
819 void Generator::generate(pdl_interp::CheckOperationNameOp op,
820                          ByteCodeWriter &writer) {
821   writer.append(OpCode::CheckOperationName, op.getInputOp(),
822                 OperationName(op.getName(), ctx), op.getSuccessors());
823 }
824 void Generator::generate(pdl_interp::CheckResultCountOp op,
825                          ByteCodeWriter &writer) {
826   writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
827                 static_cast<ByteCodeField>(op.getCompareAtLeast()),
828                 op.getSuccessors());
829 }
830 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
831   writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
832                 op.getSuccessors());
833 }
834 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
835   writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
836                 op.getSuccessors());
837 }
838 void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
839   assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
840   writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
841 }
842 void Generator::generate(pdl_interp::CreateAttributeOp op,
843                          ByteCodeWriter &writer) {
844   // Simply repoint the memory index of the result to the constant.
845   getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
846 }
847 void Generator::generate(pdl_interp::CreateOperationOp op,
848                          ByteCodeWriter &writer) {
849   writer.append(OpCode::CreateOperation, op.getResultOp(),
850                 OperationName(op.getName(), ctx));
851   writer.appendPDLValueList(op.getInputOperands());
852 
853   // Add the attributes.
854   OperandRange attributes = op.getInputAttributes();
855   writer.append(static_cast<ByteCodeField>(attributes.size()));
856   for (auto it : llvm::zip(op.getInputAttributeNames(), attributes))
857     writer.append(std::get<0>(it), std::get<1>(it));
858 
859   // Add the result types. If the operation has inferred results, we use a
860   // marker "size" value. Otherwise, we add the list of explicit result types.
861   if (op.getInferredResultTypes())
862     writer.append(kInferTypesMarker);
863   else
864     writer.appendPDLValueList(op.getInputResultTypes());
865 }
866 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
867   // Simply repoint the memory index of the result to the constant.
868   getMemIndex(op.getResult()) = getMemIndex(op.getValue());
869 }
870 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
871   writer.append(OpCode::CreateTypes, op.getResult(),
872                 getRangeStorageIndex(op.getResult()), op.getValue());
873 }
874 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
875   writer.append(OpCode::EraseOp, op.getInputOp());
876 }
877 void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
878   OpCode opCode =
879       TypeSwitch<Type, OpCode>(op.getResult().getType())
880           .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
881           .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
882           .Case([](pdl::TypeType) { return OpCode::ExtractType; })
883           .Default([](Type) -> OpCode {
884             llvm_unreachable("unsupported element type");
885           });
886   writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
887 }
888 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
889   writer.append(OpCode::Finalize);
890 }
891 void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
892   BlockArgument arg = op.getLoopVariable();
893   writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
894   writer.appendPDLValueKind(arg.getType());
895   writer.append(curLoopLevel, op.getSuccessor());
896   ++curLoopLevel;
897   if (curLoopLevel > maxLoopLevel)
898     maxLoopLevel = curLoopLevel;
899   generate(&op.getRegion(), writer);
900   --curLoopLevel;
901 }
902 void Generator::generate(pdl_interp::GetAttributeOp op,
903                          ByteCodeWriter &writer) {
904   writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
905                 op.getNameAttr());
906 }
907 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
908                          ByteCodeWriter &writer) {
909   writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
910 }
911 void Generator::generate(pdl_interp::GetDefiningOpOp op,
912                          ByteCodeWriter &writer) {
913   writer.append(OpCode::GetDefiningOp, op.getInputOp());
914   writer.appendPDLValue(op.getValue());
915 }
916 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
917   uint32_t index = op.getIndex();
918   if (index < 4)
919     writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
920   else
921     writer.append(OpCode::GetOperandN, index);
922   writer.append(op.getInputOp(), op.getValue());
923 }
924 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
925   Value result = op.getValue();
926   Optional<uint32_t> index = op.getIndex();
927   writer.append(OpCode::GetOperands,
928                 index.value_or(std::numeric_limits<uint32_t>::max()),
929                 op.getInputOp());
930   if (result.getType().isa<pdl::RangeType>())
931     writer.append(getRangeStorageIndex(result));
932   else
933     writer.append(std::numeric_limits<ByteCodeField>::max());
934   writer.append(result);
935 }
936 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
937   uint32_t index = op.getIndex();
938   if (index < 4)
939     writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
940   else
941     writer.append(OpCode::GetResultN, index);
942   writer.append(op.getInputOp(), op.getValue());
943 }
944 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
945   Value result = op.getValue();
946   Optional<uint32_t> index = op.getIndex();
947   writer.append(OpCode::GetResults,
948                 index.value_or(std::numeric_limits<uint32_t>::max()),
949                 op.getInputOp());
950   if (result.getType().isa<pdl::RangeType>())
951     writer.append(getRangeStorageIndex(result));
952   else
953     writer.append(std::numeric_limits<ByteCodeField>::max());
954   writer.append(result);
955 }
956 void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
957   Value operations = op.getOperations();
958   ByteCodeField rangeIndex = getRangeStorageIndex(operations);
959   writer.append(OpCode::GetUsers, operations, rangeIndex);
960   writer.appendPDLValue(op.getValue());
961 }
962 void Generator::generate(pdl_interp::GetValueTypeOp op,
963                          ByteCodeWriter &writer) {
964   if (op.getType().isa<pdl::RangeType>()) {
965     Value result = op.getResult();
966     writer.append(OpCode::GetValueRangeTypes, result,
967                   getRangeStorageIndex(result), op.getValue());
968   } else {
969     writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
970   }
971 }
972 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
973   writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
974 }
975 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
976   ByteCodeField patternIndex = patterns.size();
977   patterns.emplace_back(PDLByteCodePattern::create(
978       op, configMap.lookup(op),
979       rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
980   writer.append(OpCode::RecordMatch, patternIndex,
981                 SuccessorRange(op.getOperation()), op.getMatchedOps());
982   writer.appendPDLValueList(op.getInputs());
983 }
984 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
985   writer.append(OpCode::ReplaceOp, op.getInputOp());
986   writer.appendPDLValueList(op.getReplValues());
987 }
988 void Generator::generate(pdl_interp::SwitchAttributeOp op,
989                          ByteCodeWriter &writer) {
990   writer.append(OpCode::SwitchAttribute, op.getAttribute(),
991                 op.getCaseValuesAttr(), op.getSuccessors());
992 }
993 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
994                          ByteCodeWriter &writer) {
995   writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
996                 op.getCaseValuesAttr(), op.getSuccessors());
997 }
998 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
999                          ByteCodeWriter &writer) {
1000   auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) {
1001     return OperationName(attr.cast<StringAttr>().getValue(), ctx);
1002   });
1003   writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
1004                 op.getSuccessors());
1005 }
1006 void Generator::generate(pdl_interp::SwitchResultCountOp op,
1007                          ByteCodeWriter &writer) {
1008   writer.append(OpCode::SwitchResultCount, op.getInputOp(),
1009                 op.getCaseValuesAttr(), op.getSuccessors());
1010 }
1011 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
1012   writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
1013                 op.getSuccessors());
1014 }
1015 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
1016   writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
1017                 op.getSuccessors());
1018 }
1019 
1020 //===----------------------------------------------------------------------===//
1021 // PDLByteCode
1022 //===----------------------------------------------------------------------===//
1023 
1024 PDLByteCode::PDLByteCode(
1025     ModuleOp module, SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
1026     const DenseMap<Operation *, PDLPatternConfigSet *> &configMap,
1027     llvm::StringMap<PDLConstraintFunction> constraintFns,
1028     llvm::StringMap<PDLRewriteFunction> rewriteFns)
1029     : configs(std::move(configs)) {
1030   Generator generator(module.getContext(), uniquedData, matcherByteCode,
1031                       rewriterByteCode, patterns, maxValueMemoryIndex,
1032                       maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
1033                       maxLoopLevel, constraintFns, rewriteFns, configMap);
1034   generator.generate(module);
1035 
1036   // Initialize the external functions.
1037   for (auto &it : constraintFns)
1038     constraintFunctions.push_back(std::move(it.second));
1039   for (auto &it : rewriteFns)
1040     rewriteFunctions.push_back(std::move(it.second));
1041 }
1042 
1043 /// Initialize the given state such that it can be used to execute the current
1044 /// bytecode.
1045 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
1046   state.memory.resize(maxValueMemoryIndex, nullptr);
1047   state.opRangeMemory.resize(maxOpRangeCount);
1048   state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
1049   state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
1050   state.loopIndex.resize(maxLoopLevel, 0);
1051   state.currentPatternBenefits.reserve(patterns.size());
1052   for (const PDLByteCodePattern &pattern : patterns)
1053     state.currentPatternBenefits.push_back(pattern.getBenefit());
1054 }
1055 
//===----------------------------------------------------------------------===//
// ByteCode Execution
//===----------------------------------------------------------------------===//
1058 
1059 namespace {
1060 /// This class provides support for executing a bytecode stream.
1061 class ByteCodeExecutor {
1062 public:
1063   ByteCodeExecutor(
1064       const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
1065       MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory,
1066       MutableArrayRef<TypeRange> typeRangeMemory,
1067       std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
1068       MutableArrayRef<ValueRange> valueRangeMemory,
1069       std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
1070       MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory,
1071       ArrayRef<ByteCodeField> code,
1072       ArrayRef<PatternBenefit> currentPatternBenefits,
1073       ArrayRef<PDLByteCodePattern> patterns,
1074       ArrayRef<PDLConstraintFunction> constraintFunctions,
1075       ArrayRef<PDLRewriteFunction> rewriteFunctions)
1076       : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
1077         typeRangeMemory(typeRangeMemory),
1078         allocatedTypeRangeMemory(allocatedTypeRangeMemory),
1079         valueRangeMemory(valueRangeMemory),
1080         allocatedValueRangeMemory(allocatedValueRangeMemory),
1081         loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
1082         currentPatternBenefits(currentPatternBenefits), patterns(patterns),
1083         constraintFunctions(constraintFunctions),
1084         rewriteFunctions(rewriteFunctions) {}
1085 
1086   /// Start executing the code at the current bytecode index. `matches` is an
1087   /// optional field provided when this function is executed in a matching
1088   /// context.
1089   LogicalResult
1090   execute(PatternRewriter &rewriter,
1091           SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
1092           Optional<Location> mainRewriteLoc = {});
1093 
1094 private:
1095   /// Internal implementation of executing each of the bytecode commands.
1096   void executeApplyConstraint(PatternRewriter &rewriter);
1097   LogicalResult executeApplyRewrite(PatternRewriter &rewriter);
1098   void executeAreEqual();
1099   void executeAreRangesEqual();
1100   void executeBranch();
1101   void executeCheckOperandCount();
1102   void executeCheckOperationName();
1103   void executeCheckResultCount();
1104   void executeCheckTypes();
1105   void executeContinue();
1106   void executeCreateOperation(PatternRewriter &rewriter,
1107                               Location mainRewriteLoc);
1108   void executeCreateTypes();
1109   void executeEraseOp(PatternRewriter &rewriter);
1110   template <typename T, typename Range, PDLValue::Kind kind>
1111   void executeExtract();
1112   void executeFinalize();
1113   void executeForEach();
1114   void executeGetAttribute();
1115   void executeGetAttributeType();
1116   void executeGetDefiningOp();
1117   void executeGetOperand(unsigned index);
1118   void executeGetOperands();
1119   void executeGetResult(unsigned index);
1120   void executeGetResults();
1121   void executeGetUsers();
1122   void executeGetValueType();
1123   void executeGetValueRangeTypes();
1124   void executeIsNotNull();
1125   void executeRecordMatch(PatternRewriter &rewriter,
1126                           SmallVectorImpl<PDLByteCode::MatchResult> &matches);
1127   void executeReplaceOp(PatternRewriter &rewriter);
1128   void executeSwitchAttribute();
1129   void executeSwitchOperandCount();
1130   void executeSwitchOperationName();
1131   void executeSwitchResultCount();
1132   void executeSwitchType();
1133   void executeSwitchTypes();
1134 
1135   /// Pushes a code iterator to the stack.
1136   void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }
1137 
1138   /// Pops a code iterator from the stack, returning true on success.
1139   void popCodeIt() {
1140     assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack");
1141     curCodeIt = resumeCodeIt.back();
1142     resumeCodeIt.pop_back();
1143   }
1144 
1145   /// Return the bytecode iterator at the start of the current op code.
1146   const ByteCodeField *getPrevCodeIt() const {
1147     LLVM_DEBUG({
1148       // Account for the op code and the Location stored inline.
1149       return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField);
1150     });
1151 
1152     // Account for the op code only.
1153     return curCodeIt - 1;
1154   }
1155 
1156   /// Read a value from the bytecode buffer, optionally skipping a certain
1157   /// number of prefix values. These methods always update the buffer to point
1158   /// to the next field after the read data.
1159   template <typename T = ByteCodeField>
1160   T read(size_t skipN = 0) {
1161     curCodeIt += skipN;
1162     return readImpl<T>();
1163   }
1164   ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }
1165 
1166   /// Read a list of values from the bytecode buffer.
1167   template <typename ValueT, typename T>
1168   void readList(SmallVectorImpl<T> &list) {
1169     list.clear();
1170     for (unsigned i = 0, e = read(); i != e; ++i)
1171       list.push_back(read<ValueT>());
1172   }
1173 
1174   /// Read a list of values from the bytecode buffer. The values may be encoded
1175   /// as either Value or ValueRange elements.
1176   void readValueList(SmallVectorImpl<Value> &list) {
1177     for (unsigned i = 0, e = read(); i != e; ++i) {
1178       if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1179         list.push_back(read<Value>());
1180       } else {
1181         ValueRange *values = read<ValueRange *>();
1182         list.append(values->begin(), values->end());
1183       }
1184     }
1185   }
1186 
1187   /// Read a value stored inline as a pointer.
1188   template <typename T>
1189   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T>
1190   readInline() {
1191     const void *pointer;
1192     std::memcpy(&pointer, curCodeIt, sizeof(const void *));
1193     curCodeIt += sizeof(const void *) / sizeof(ByteCodeField);
1194     return T::getFromOpaquePointer(pointer);
1195   }
1196 
1197   /// Jump to a specific successor based on a predicate value.
1198   void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
1199   /// Jump to a specific successor based on a destination index.
1200   void selectJump(size_t destIndex) {
1201     curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
1202   }
1203 
1204   /// Handle a switch operation with the provided value and cases.
1205   template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
1206   void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
1207     LLVM_DEBUG({
1208       llvm::dbgs() << "  * Value: " << value << "\n"
1209                    << "  * Cases: ";
1210       llvm::interleaveComma(cases, llvm::dbgs());
1211       llvm::dbgs() << "\n";
1212     });
1213 
1214     // Check to see if the attribute value is within the case list. Jump to
1215     // the correct successor index based on the result.
1216     for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
1217       if (cmp(*it, value))
1218         return selectJump(size_t((it - cases.begin()) + 1));
1219     selectJump(size_t(0));
1220   }
1221 
1222   /// Store a pointer to memory.
1223   void storeToMemory(unsigned index, const void *value) {
1224     memory[index] = value;
1225   }
1226 
1227   /// Store a value to memory as an opaque pointer.
1228   template <typename T>
1229   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
1230   storeToMemory(unsigned index, T value) {
1231     memory[index] = value.getAsOpaquePointer();
1232   }
1233 
1234   /// Internal implementation of reading various data types from the bytecode
1235   /// stream.
1236   template <typename T>
1237   const void *readFromMemory() {
1238     size_t index = *curCodeIt++;
1239 
1240     // If this type is an SSA value, it can only be stored in non-const memory.
1241     if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
1242                         Value>::value ||
1243         index < memory.size())
1244       return memory[index];
1245 
1246     // Otherwise, if this index is not inbounds it is uniqued.
1247     return uniquedMemory[index - memory.size()];
1248   }
1249   template <typename T>
1250   std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
1251     return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
1252   }
1253   template <typename T>
1254   std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
1255                    T>
1256   readImpl() {
1257     return T(T::getFromOpaquePointer(readFromMemory<T>()));
1258   }
1259   template <typename T>
1260   std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
1261     switch (read<PDLValue::Kind>()) {
1262     case PDLValue::Kind::Attribute:
1263       return read<Attribute>();
1264     case PDLValue::Kind::Operation:
1265       return read<Operation *>();
1266     case PDLValue::Kind::Type:
1267       return read<Type>();
1268     case PDLValue::Kind::Value:
1269       return read<Value>();
1270     case PDLValue::Kind::TypeRange:
1271       return read<TypeRange *>();
1272     case PDLValue::Kind::ValueRange:
1273       return read<ValueRange *>();
1274     }
1275     llvm_unreachable("unhandled PDLValue::Kind");
1276   }
1277   template <typename T>
1278   std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
1279     static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
1280                   "unexpected ByteCode address size");
1281     ByteCodeAddr result;
1282     std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
1283     curCodeIt += 2;
1284     return result;
1285   }
1286   template <typename T>
1287   std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
1288     return *curCodeIt++;
1289   }
1290   template <typename T>
1291   std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
1292     return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
1293   }
1294 
1295   /// The underlying bytecode buffer.
1296   const ByteCodeField *curCodeIt;
1297 
1298   /// The stack of bytecode positions at which to resume operation.
1299   SmallVector<const ByteCodeField *> resumeCodeIt;
1300 
1301   /// The current execution memory.
1302   MutableArrayRef<const void *> memory;
1303   MutableArrayRef<OwningOpRange> opRangeMemory;
1304   MutableArrayRef<TypeRange> typeRangeMemory;
1305   std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
1306   MutableArrayRef<ValueRange> valueRangeMemory;
1307   std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;
1308 
1309   /// The current loop indices.
1310   MutableArrayRef<unsigned> loopIndex;
1311 
1312   /// References to ByteCode data necessary for execution.
1313   ArrayRef<const void *> uniquedMemory;
1314   ArrayRef<ByteCodeField> code;
1315   ArrayRef<PatternBenefit> currentPatternBenefits;
1316   ArrayRef<PDLByteCodePattern> patterns;
1317   ArrayRef<PDLConstraintFunction> constraintFunctions;
1318   ArrayRef<PDLRewriteFunction> rewriteFunctions;
1319 };
1320 
1321 /// This class is an instantiation of the PDLResultList that provides access to
1322 /// the returned results. This API is not on `PDLResultList` to avoid
1323 /// overexposing access to information specific solely to the ByteCode.
1324 class ByteCodeRewriteResultList : public PDLResultList {
1325 public:
1326   ByteCodeRewriteResultList(unsigned maxNumResults)
1327       : PDLResultList(maxNumResults) {}
1328 
1329   /// Return the list of PDL results.
1330   MutableArrayRef<PDLValue> getResults() { return results; }
1331 
1332   /// Return the type ranges allocated by this list.
1333   MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
1334     return allocatedTypeRanges;
1335   }
1336 
1337   /// Return the value ranges allocated by this list.
1338   MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
1339     return allocatedValueRanges;
1340   }
1341 };
1342 } // namespace
1343 
1344 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1345   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
1346   const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
1347   SmallVector<PDLValue, 16> args;
1348   readList<PDLValue>(args);
1349 
1350   LLVM_DEBUG({
1351     llvm::dbgs() << "  * Arguments: ";
1352     llvm::interleaveComma(args, llvm::dbgs());
1353   });
1354 
1355   // Invoke the constraint and jump to the proper destination.
1356   selectJump(succeeded(constraintFn(rewriter, args)));
1357 }
1358 
1359 LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
1360   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
1361   const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1362   SmallVector<PDLValue, 16> args;
1363   readList<PDLValue>(args);
1364 
1365   LLVM_DEBUG({
1366     llvm::dbgs() << "  * Arguments: ";
1367     llvm::interleaveComma(args, llvm::dbgs());
1368   });
1369 
1370   // Execute the rewrite function.
1371   ByteCodeField numResults = read();
1372   ByteCodeRewriteResultList results(numResults);
1373   LogicalResult rewriteResult = rewriteFn(rewriter, results, args);
1374 
1375   assert(results.getResults().size() == numResults &&
1376          "native PDL rewrite function returned unexpected number of results");
1377 
1378   // Store the results in the bytecode memory.
1379   for (PDLValue &result : results.getResults()) {
1380     LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n");
1381 
1382 // In debug mode we also verify the expected kind of the result.
1383 #ifndef NDEBUG
1384     assert(result.getKind() == read<PDLValue::Kind>() &&
1385            "native PDL rewrite function returned an unexpected type of result");
1386 #endif
1387 
1388     // If the result is a range, we need to copy it over to the bytecodes
1389     // range memory.
1390     if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
1391       unsigned rangeIndex = read();
1392       typeRangeMemory[rangeIndex] = *typeRange;
1393       memory[read()] = &typeRangeMemory[rangeIndex];
1394     } else if (Optional<ValueRange> valueRange =
1395                    result.dyn_cast<ValueRange>()) {
1396       unsigned rangeIndex = read();
1397       valueRangeMemory[rangeIndex] = *valueRange;
1398       memory[read()] = &valueRangeMemory[rangeIndex];
1399     } else {
1400       memory[read()] = result.getAsOpaquePointer();
1401     }
1402   }
1403 
1404   // Copy over any underlying storage allocated for result ranges.
1405   for (auto &it : results.getAllocatedTypeRanges())
1406     allocatedTypeRangeMemory.push_back(std::move(it));
1407   for (auto &it : results.getAllocatedValueRanges())
1408     allocatedValueRangeMemory.push_back(std::move(it));
1409 
1410   // Process the result of the rewrite.
1411   if (failed(rewriteResult)) {
1412     LLVM_DEBUG(llvm::dbgs() << "  - Failed");
1413     return failure();
1414   }
1415   return success();
1416 }
1417 
1418 void ByteCodeExecutor::executeAreEqual() {
1419   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1420   const void *lhs = read<const void *>();
1421   const void *rhs = read<const void *>();
1422 
1423   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n");
1424   selectJump(lhs == rhs);
1425 }
1426 
1427 void ByteCodeExecutor::executeAreRangesEqual() {
1428   LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
1429   PDLValue::Kind valueKind = read<PDLValue::Kind>();
1430   const void *lhs = read<const void *>();
1431   const void *rhs = read<const void *>();
1432 
1433   switch (valueKind) {
1434   case PDLValue::Kind::TypeRange: {
1435     const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
1436     const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
1437     LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1438     selectJump(*lhsRange == *rhsRange);
1439     break;
1440   }
1441   case PDLValue::Kind::ValueRange: {
1442     const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
1443     const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
1444     LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1445     selectJump(*lhsRange == *rhsRange);
1446     break;
1447   }
1448   default:
1449     llvm_unreachable("unexpected `AreRangesEqual` value kind");
1450   }
1451 }
1452 
1453 void ByteCodeExecutor::executeBranch() {
1454   LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
1455   curCodeIt = &code[read<ByteCodeAddr>()];
1456 }
1457 
1458 void ByteCodeExecutor::executeCheckOperandCount() {
1459   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
1460   Operation *op = read<Operation *>();
1461   uint32_t expectedCount = read<uint32_t>();
1462   bool compareAtLeast = read();
1463 
1464   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumOperands() << "\n"
1465                           << "  * Expected: " << expectedCount << "\n"
1466                           << "  * Comparator: "
1467                           << (compareAtLeast ? ">=" : "==") << "\n");
1468   if (compareAtLeast)
1469     selectJump(op->getNumOperands() >= expectedCount);
1470   else
1471     selectJump(op->getNumOperands() == expectedCount);
1472 }
1473 
1474 void ByteCodeExecutor::executeCheckOperationName() {
1475   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
1476   Operation *op = read<Operation *>();
1477   OperationName expectedName = read<OperationName>();
1478 
1479   LLVM_DEBUG(llvm::dbgs() << "  * Found: \"" << op->getName() << "\"\n"
1480                           << "  * Expected: \"" << expectedName << "\"\n");
1481   selectJump(op->getName() == expectedName);
1482 }
1483 
1484 void ByteCodeExecutor::executeCheckResultCount() {
1485   LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
1486   Operation *op = read<Operation *>();
1487   uint32_t expectedCount = read<uint32_t>();
1488   bool compareAtLeast = read();
1489 
1490   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumResults() << "\n"
1491                           << "  * Expected: " << expectedCount << "\n"
1492                           << "  * Comparator: "
1493                           << (compareAtLeast ? ">=" : "==") << "\n");
1494   if (compareAtLeast)
1495     selectJump(op->getNumResults() >= expectedCount);
1496   else
1497     selectJump(op->getNumResults() == expectedCount);
1498 }
1499 
1500 void ByteCodeExecutor::executeCheckTypes() {
1501   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1502   TypeRange *lhs = read<TypeRange *>();
1503   Attribute rhs = read<Attribute>();
1504   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1505 
1506   selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>());
1507 }
1508 
1509 void ByteCodeExecutor::executeContinue() {
1510   ByteCodeField level = read();
1511   LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n"
1512                           << "  * Level: " << level << "\n");
1513   ++loopIndex[level];
1514   popCodeIt();
1515 }
1516 
1517 void ByteCodeExecutor::executeCreateTypes() {
1518   LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n");
1519   unsigned memIndex = read();
1520   unsigned rangeIndex = read();
1521   ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>();
1522 
1523   LLVM_DEBUG(llvm::dbgs() << "  * Types: " << typesAttr << "\n\n");
1524 
1525   // Allocate a buffer for this type range.
1526   llvm::OwningArrayRef<Type> storage(typesAttr.size());
1527   llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin());
1528   allocatedTypeRangeMemory.emplace_back(std::move(storage));
1529 
1530   // Assign this to the range slot and use the range as the value for the
1531   // memory index.
1532   typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back();
1533   memory[memIndex] = &typeRangeMemory[rangeIndex];
1534 }
1535 
1536 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
1537                                               Location mainRewriteLoc) {
1538   LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
1539 
1540   unsigned memIndex = read();
1541   OperationState state(mainRewriteLoc, read<OperationName>());
1542   readValueList(state.operands);
1543   for (unsigned i = 0, e = read(); i != e; ++i) {
1544     StringAttr name = read<StringAttr>();
1545     if (Attribute attr = read<Attribute>())
1546       state.addAttribute(name, attr);
1547   }
1548 
1549   // Read in the result types. If the "size" is the sentinel value, this
1550   // indicates that the result types should be inferred.
1551   unsigned numResults = read();
1552   if (numResults == kInferTypesMarker) {
1553     InferTypeOpInterface::Concept *inferInterface =
1554         state.name.getRegisteredInfo()->getInterface<InferTypeOpInterface>();
1555     assert(inferInterface &&
1556            "expected operation to provide InferTypeOpInterface");
1557 
1558     // TODO: Handle failure.
1559     if (failed(inferInterface->inferReturnTypes(
1560             state.getContext(), state.location, state.operands,
1561             state.attributes.getDictionary(state.getContext()), state.regions,
1562             state.types)))
1563       return;
1564   } else {
1565     // Otherwise, this is a fixed number of results.
1566     for (unsigned i = 0; i != numResults; ++i) {
1567       if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1568         state.types.push_back(read<Type>());
1569       } else {
1570         TypeRange *resultTypes = read<TypeRange *>();
1571         state.types.append(resultTypes->begin(), resultTypes->end());
1572       }
1573     }
1574   }
1575 
1576   Operation *resultOp = rewriter.create(state);
1577   memory[memIndex] = resultOp;
1578 
1579   LLVM_DEBUG({
1580     llvm::dbgs() << "  * Attributes: "
1581                  << state.attributes.getDictionary(state.getContext())
1582                  << "\n  * Operands: ";
1583     llvm::interleaveComma(state.operands, llvm::dbgs());
1584     llvm::dbgs() << "\n  * Result Types: ";
1585     llvm::interleaveComma(state.types, llvm::dbgs());
1586     llvm::dbgs() << "\n  * Result: " << *resultOp << "\n";
1587   });
1588 }
1589 
1590 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1591   LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1592   Operation *op = read<Operation *>();
1593 
1594   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1595   rewriter.eraseOp(op);
1596 }
1597 
1598 template <typename T, typename Range, PDLValue::Kind kind>
1599 void ByteCodeExecutor::executeExtract() {
1600   LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n");
1601   Range *range = read<Range *>();
1602   unsigned index = read<uint32_t>();
1603   unsigned memIndex = read();
1604 
1605   if (!range) {
1606     memory[memIndex] = nullptr;
1607     return;
1608   }
1609 
1610   T result = index < range->size() ? (*range)[index] : T();
1611   LLVM_DEBUG(llvm::dbgs() << "  * " << kind << "s(" << range->size() << ")\n"
1612                           << "  * Index: " << index << "\n"
1613                           << "  * Result: " << result << "\n");
1614   storeToMemory(memIndex, result);
1615 }
1616 
1617 void ByteCodeExecutor::executeFinalize() {
1618   LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n");
1619 }
1620 
1621 void ByteCodeExecutor::executeForEach() {
1622   LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n");
1623   const ByteCodeField *prevCodeIt = getPrevCodeIt();
1624   unsigned rangeIndex = read();
1625   unsigned memIndex = read();
1626   const void *value = nullptr;
1627 
1628   switch (read<PDLValue::Kind>()) {
1629   case PDLValue::Kind::Operation: {
1630     unsigned &index = loopIndex[read()];
1631     ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
1632     assert(index <= array.size() && "iterated past the end");
1633     if (index < array.size()) {
1634       LLVM_DEBUG(llvm::dbgs() << "  * Result: " << array[index] << "\n");
1635       value = array[index];
1636       break;
1637     }
1638 
1639     LLVM_DEBUG(llvm::dbgs() << "  * Done\n");
1640     index = 0;
1641     selectJump(size_t(0));
1642     return;
1643   }
1644   default:
1645     llvm_unreachable("unexpected `ForEach` value kind");
1646   }
1647 
1648   // Store the iterate value and the stack address.
1649   memory[memIndex] = value;
1650   pushCodeIt(prevCodeIt);
1651 
1652   // Skip over the successor (we will enter the body of the loop).
1653   read<ByteCodeAddr>();
1654 }
1655 
1656 void ByteCodeExecutor::executeGetAttribute() {
1657   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1658   unsigned memIndex = read();
1659   Operation *op = read<Operation *>();
1660   StringAttr attrName = read<StringAttr>();
1661   Attribute attr = op->getAttr(attrName);
1662 
1663   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1664                           << "  * Attribute: " << attrName << "\n"
1665                           << "  * Result: " << attr << "\n");
1666   memory[memIndex] = attr.getAsOpaquePointer();
1667 }
1668 
1669 void ByteCodeExecutor::executeGetAttributeType() {
1670   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1671   unsigned memIndex = read();
1672   Attribute attr = read<Attribute>();
1673   Type type;
1674   if (auto typedAttr = attr.dyn_cast<TypedAttr>())
1675     type = typedAttr.getType();
1676 
1677   LLVM_DEBUG(llvm::dbgs() << "  * Attribute: " << attr << "\n"
1678                           << "  * Result: " << type << "\n");
1679   memory[memIndex] = type.getAsOpaquePointer();
1680 }
1681 
1682 void ByteCodeExecutor::executeGetDefiningOp() {
1683   LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1684   unsigned memIndex = read();
1685   Operation *op = nullptr;
1686   if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1687     Value value = read<Value>();
1688     if (value)
1689       op = value.getDefiningOp();
1690     LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1691   } else {
1692     ValueRange *values = read<ValueRange *>();
1693     if (values && !values->empty()) {
1694       op = values->front().getDefiningOp();
1695     }
1696     LLVM_DEBUG(llvm::dbgs() << "  * Values: " << values << "\n");
1697   }
1698 
1699   LLVM_DEBUG(llvm::dbgs() << "  * Result: " << op << "\n");
1700   memory[memIndex] = op;
1701 }
1702 
1703 void ByteCodeExecutor::executeGetOperand(unsigned index) {
1704   Operation *op = read<Operation *>();
1705   unsigned memIndex = read();
1706   Value operand =
1707       index < op->getNumOperands() ? op->getOperand(index) : Value();
1708 
1709   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1710                           << "  * Index: " << index << "\n"
1711                           << "  * Result: " << operand << "\n");
1712   memory[memIndex] = operand.getAsOpaquePointer();
1713 }
1714 
1715 /// This function is the internal implementation of `GetResults` and
1716 /// `GetOperands` that provides support for extracting a value range from the
1717 /// given operation.
1718 template <template <typename> class AttrSizedSegmentsT, typename RangeT>
1719 static void *
1720 executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
1721                           ByteCodeField rangeIndex, StringRef attrSizedSegments,
1722                           MutableArrayRef<ValueRange> valueRangeMemory) {
1723   // Check for the sentinel index that signals that all values should be
1724   // returned.
1725   if (index == std::numeric_limits<uint32_t>::max()) {
1726     LLVM_DEBUG(llvm::dbgs() << "  * Getting all values\n");
1727     // `values` is already the full value range.
1728 
1729     // Otherwise, check to see if this operation uses AttrSizedSegments.
1730   } else if (op->hasTrait<AttrSizedSegmentsT>()) {
1731     LLVM_DEBUG(llvm::dbgs()
1732                << "  * Extracting values from `" << attrSizedSegments << "`\n");
1733 
1734     auto segmentAttr = op->getAttrOfType<DenseI32ArrayAttr>(attrSizedSegments);
1735     if (!segmentAttr || segmentAttr.asArrayRef().size() <= index)
1736       return nullptr;
1737 
1738     ArrayRef<int32_t> segments = segmentAttr;
1739     unsigned startIndex =
1740         std::accumulate(segments.begin(), segments.begin() + index, 0);
1741     values = values.slice(startIndex, *std::next(segments.begin(), index));
1742 
1743     LLVM_DEBUG(llvm::dbgs() << "  * Extracting range[" << startIndex << ", "
1744                             << *std::next(segments.begin(), index) << "]\n");
1745 
1746     // Otherwise, assume this is the last operand group of the operation.
1747     // FIXME: We currently don't support operations with
1748     // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
1749     // have a way to detect it's presence.
1750   } else if (values.size() >= index) {
1751     LLVM_DEBUG(llvm::dbgs()
1752                << "  * Treating values as trailing variadic range\n");
1753     values = values.drop_front(index);
1754 
1755     // If we couldn't detect a way to compute the values, bail out.
1756   } else {
1757     return nullptr;
1758   }
1759 
1760   // If the range index is valid, we are returning a range.
1761   if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
1762     valueRangeMemory[rangeIndex] = values;
1763     return &valueRangeMemory[rangeIndex];
1764   }
1765 
1766   // If a range index wasn't provided, the range is required to be non-variadic.
1767   return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1768 }
1769 
1770 void ByteCodeExecutor::executeGetOperands() {
1771   LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
1772   unsigned index = read<uint32_t>();
1773   Operation *op = read<Operation *>();
1774   ByteCodeField rangeIndex = read();
1775 
1776   void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1777       op->getOperands(), op, index, rangeIndex, "operand_segment_sizes",
1778       valueRangeMemory);
1779   if (!result)
1780     LLVM_DEBUG(llvm::dbgs() << "  * Invalid operand range\n");
1781   memory[read()] = result;
1782 }
1783 
1784 void ByteCodeExecutor::executeGetResult(unsigned index) {
1785   Operation *op = read<Operation *>();
1786   unsigned memIndex = read();
1787   OpResult result =
1788       index < op->getNumResults() ? op->getResult(index) : OpResult();
1789 
1790   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1791                           << "  * Index: " << index << "\n"
1792                           << "  * Result: " << result << "\n");
1793   memory[memIndex] = result.getAsOpaquePointer();
1794 }
1795 
1796 void ByteCodeExecutor::executeGetResults() {
1797   LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
1798   unsigned index = read<uint32_t>();
1799   Operation *op = read<Operation *>();
1800   ByteCodeField rangeIndex = read();
1801 
1802   void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1803       op->getResults(), op, index, rangeIndex, "result_segment_sizes",
1804       valueRangeMemory);
1805   if (!result)
1806     LLVM_DEBUG(llvm::dbgs() << "  * Invalid result range\n");
1807   memory[read()] = result;
1808 }
1809 
1810 void ByteCodeExecutor::executeGetUsers() {
1811   LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n");
1812   unsigned memIndex = read();
1813   unsigned rangeIndex = read();
1814   OwningOpRange &range = opRangeMemory[rangeIndex];
1815   memory[memIndex] = &range;
1816 
1817   range = OwningOpRange();
1818   if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1819     // Read the value.
1820     Value value = read<Value>();
1821     if (!value)
1822       return;
1823     LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1824 
1825     // Extract the users of a single value.
1826     range = OwningOpRange(std::distance(value.user_begin(), value.user_end()));
1827     llvm::copy(value.getUsers(), range.begin());
1828   } else {
1829     // Read a range of values.
1830     ValueRange *values = read<ValueRange *>();
1831     if (!values)
1832       return;
1833     LLVM_DEBUG({
1834       llvm::dbgs() << "  * Values (" << values->size() << "): ";
1835       llvm::interleaveComma(*values, llvm::dbgs());
1836       llvm::dbgs() << "\n";
1837     });
1838 
1839     // Extract all the users of a range of values.
1840     SmallVector<Operation *> users;
1841     for (Value value : *values)
1842       users.append(value.user_begin(), value.user_end());
1843     range = OwningOpRange(users.size());
1844     llvm::copy(users, range.begin());
1845   }
1846 
1847   LLVM_DEBUG(llvm::dbgs() << "  * Result: " << range.size() << " operations\n");
1848 }
1849 
1850 void ByteCodeExecutor::executeGetValueType() {
1851   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1852   unsigned memIndex = read();
1853   Value value = read<Value>();
1854   Type type = value ? value.getType() : Type();
1855 
1856   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n"
1857                           << "  * Result: " << type << "\n");
1858   memory[memIndex] = type.getAsOpaquePointer();
1859 }
1860 
1861 void ByteCodeExecutor::executeGetValueRangeTypes() {
1862   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
1863   unsigned memIndex = read();
1864   unsigned rangeIndex = read();
1865   ValueRange *values = read<ValueRange *>();
1866   if (!values) {
1867     LLVM_DEBUG(llvm::dbgs() << "  * Values: <NULL>\n\n");
1868     memory[memIndex] = nullptr;
1869     return;
1870   }
1871 
1872   LLVM_DEBUG({
1873     llvm::dbgs() << "  * Values (" << values->size() << "): ";
1874     llvm::interleaveComma(*values, llvm::dbgs());
1875     llvm::dbgs() << "\n  * Result: ";
1876     llvm::interleaveComma(values->getType(), llvm::dbgs());
1877     llvm::dbgs() << "\n";
1878   });
1879   typeRangeMemory[rangeIndex] = values->getType();
1880   memory[memIndex] = &typeRangeMemory[rangeIndex];
1881 }
1882 
1883 void ByteCodeExecutor::executeIsNotNull() {
1884   LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
1885   const void *value = read<const void *>();
1886 
1887   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1888   selectJump(value != nullptr);
1889 }
1890 
1891 void ByteCodeExecutor::executeRecordMatch(
1892     PatternRewriter &rewriter,
1893     SmallVectorImpl<PDLByteCode::MatchResult> &matches) {
1894   LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
1895   unsigned patternIndex = read();
1896   PatternBenefit benefit = currentPatternBenefits[patternIndex];
1897   const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1898 
1899   // If the benefit of the pattern is impossible, skip the processing of the
1900   // rest of the pattern.
1901   if (benefit.isImpossibleToMatch()) {
1902     LLVM_DEBUG(llvm::dbgs() << "  * Benefit: Impossible To Match\n");
1903     curCodeIt = dest;
1904     return;
1905   }
1906 
1907   // Create a fused location containing the locations of each of the
1908   // operations used in the match. This will be used as the location for
1909   // created operations during the rewrite that don't already have an
1910   // explicit location set.
1911   unsigned numMatchLocs = read();
1912   SmallVector<Location, 4> matchLocs;
1913   matchLocs.reserve(numMatchLocs);
1914   for (unsigned i = 0; i != numMatchLocs; ++i)
1915     matchLocs.push_back(read<Operation *>()->getLoc());
1916   Location matchLoc = rewriter.getFusedLoc(matchLocs);
1917 
1918   LLVM_DEBUG(llvm::dbgs() << "  * Benefit: " << benefit.getBenefit() << "\n"
1919                           << "  * Location: " << matchLoc << "\n");
1920   matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
1921   PDLByteCode::MatchResult &match = matches.back();
1922 
1923   // Record all of the inputs to the match. If any of the inputs are ranges, we
1924   // will also need to remap the range pointer to memory stored in the match
1925   // state.
1926   unsigned numInputs = read();
1927   match.values.reserve(numInputs);
1928   match.typeRangeValues.reserve(numInputs);
1929   match.valueRangeValues.reserve(numInputs);
1930   for (unsigned i = 0; i < numInputs; ++i) {
1931     switch (read<PDLValue::Kind>()) {
1932     case PDLValue::Kind::TypeRange:
1933       match.typeRangeValues.push_back(*read<TypeRange *>());
1934       match.values.push_back(&match.typeRangeValues.back());
1935       break;
1936     case PDLValue::Kind::ValueRange:
1937       match.valueRangeValues.push_back(*read<ValueRange *>());
1938       match.values.push_back(&match.valueRangeValues.back());
1939       break;
1940     default:
1941       match.values.push_back(read<const void *>());
1942       break;
1943     }
1944   }
1945   curCodeIt = dest;
1946 }
1947 
1948 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
1949   LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
1950   Operation *op = read<Operation *>();
1951   SmallVector<Value, 16> args;
1952   readValueList(args);
1953 
1954   LLVM_DEBUG({
1955     llvm::dbgs() << "  * Operation: " << *op << "\n"
1956                  << "  * Values: ";
1957     llvm::interleaveComma(args, llvm::dbgs());
1958     llvm::dbgs() << "\n";
1959   });
1960   rewriter.replaceOp(op, args);
1961 }
1962 
1963 void ByteCodeExecutor::executeSwitchAttribute() {
1964   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
1965   Attribute value = read<Attribute>();
1966   ArrayAttr cases = read<ArrayAttr>();
1967   handleSwitch(value, cases);
1968 }
1969 
1970 void ByteCodeExecutor::executeSwitchOperandCount() {
1971   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
1972   Operation *op = read<Operation *>();
1973   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1974 
1975   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1976   handleSwitch(op->getNumOperands(), cases);
1977 }
1978 
1979 void ByteCodeExecutor::executeSwitchOperationName() {
1980   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
1981   OperationName value = read<Operation *>()->getName();
1982   size_t caseCount = read();
1983 
1984   // The operation names are stored in-line, so to print them out for
1985   // debugging purposes we need to read the array before executing the
1986   // switch so that we can display all of the possible values.
1987   LLVM_DEBUG({
1988     const ByteCodeField *prevCodeIt = curCodeIt;
1989     llvm::dbgs() << "  * Value: " << value << "\n"
1990                  << "  * Cases: ";
1991     llvm::interleaveComma(
1992         llvm::map_range(llvm::seq<size_t>(0, caseCount),
1993                         [&](size_t) { return read<OperationName>(); }),
1994         llvm::dbgs());
1995     llvm::dbgs() << "\n";
1996     curCodeIt = prevCodeIt;
1997   });
1998 
1999   // Try to find the switch value within any of the cases.
2000   for (size_t i = 0; i != caseCount; ++i) {
2001     if (read<OperationName>() == value) {
2002       curCodeIt += (caseCount - i - 1);
2003       return selectJump(i + 1);
2004     }
2005   }
2006   selectJump(size_t(0));
2007 }
2008 
2009 void ByteCodeExecutor::executeSwitchResultCount() {
2010   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
2011   Operation *op = read<Operation *>();
2012   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2013 
2014   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
2015   handleSwitch(op->getNumResults(), cases);
2016 }
2017 
2018 void ByteCodeExecutor::executeSwitchType() {
2019   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
2020   Type value = read<Type>();
2021   auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2022   handleSwitch(value, cases);
2023 }
2024 
2025 void ByteCodeExecutor::executeSwitchTypes() {
2026   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
2027   TypeRange *value = read<TypeRange *>();
2028   auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
2029   if (!value) {
2030     LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n");
2031     return selectJump(size_t(0));
2032   }
2033   handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
2034     return value == caseValue.getAsValueRange<TypeAttr>();
2035   });
2036 }
2037 
2038 LogicalResult
2039 ByteCodeExecutor::execute(PatternRewriter &rewriter,
2040                           SmallVectorImpl<PDLByteCode::MatchResult> *matches,
2041                           Optional<Location> mainRewriteLoc) {
2042   while (true) {
2043     // Print the location of the operation being executed.
2044     LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n");
2045 
2046     OpCode opCode = static_cast<OpCode>(read());
2047     switch (opCode) {
2048     case ApplyConstraint:
2049       executeApplyConstraint(rewriter);
2050       break;
2051     case ApplyRewrite:
2052       if (failed(executeApplyRewrite(rewriter)))
2053         return failure();
2054       break;
2055     case AreEqual:
2056       executeAreEqual();
2057       break;
2058     case AreRangesEqual:
2059       executeAreRangesEqual();
2060       break;
2061     case Branch:
2062       executeBranch();
2063       break;
2064     case CheckOperandCount:
2065       executeCheckOperandCount();
2066       break;
2067     case CheckOperationName:
2068       executeCheckOperationName();
2069       break;
2070     case CheckResultCount:
2071       executeCheckResultCount();
2072       break;
2073     case CheckTypes:
2074       executeCheckTypes();
2075       break;
2076     case Continue:
2077       executeContinue();
2078       break;
2079     case CreateOperation:
2080       executeCreateOperation(rewriter, *mainRewriteLoc);
2081       break;
2082     case CreateTypes:
2083       executeCreateTypes();
2084       break;
2085     case EraseOp:
2086       executeEraseOp(rewriter);
2087       break;
2088     case ExtractOp:
2089       executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
2090       break;
2091     case ExtractType:
2092       executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
2093       break;
2094     case ExtractValue:
2095       executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
2096       break;
2097     case Finalize:
2098       executeFinalize();
2099       LLVM_DEBUG(llvm::dbgs() << "\n");
2100       return success();
2101     case ForEach:
2102       executeForEach();
2103       break;
2104     case GetAttribute:
2105       executeGetAttribute();
2106       break;
2107     case GetAttributeType:
2108       executeGetAttributeType();
2109       break;
2110     case GetDefiningOp:
2111       executeGetDefiningOp();
2112       break;
2113     case GetOperand0:
2114     case GetOperand1:
2115     case GetOperand2:
2116     case GetOperand3: {
2117       unsigned index = opCode - GetOperand0;
2118       LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
2119       executeGetOperand(index);
2120       break;
2121     }
2122     case GetOperandN:
2123       LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
2124       executeGetOperand(read<uint32_t>());
2125       break;
2126     case GetOperands:
2127       executeGetOperands();
2128       break;
2129     case GetResult0:
2130     case GetResult1:
2131     case GetResult2:
2132     case GetResult3: {
2133       unsigned index = opCode - GetResult0;
2134       LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
2135       executeGetResult(index);
2136       break;
2137     }
2138     case GetResultN:
2139       LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
2140       executeGetResult(read<uint32_t>());
2141       break;
2142     case GetResults:
2143       executeGetResults();
2144       break;
2145     case GetUsers:
2146       executeGetUsers();
2147       break;
2148     case GetValueType:
2149       executeGetValueType();
2150       break;
2151     case GetValueRangeTypes:
2152       executeGetValueRangeTypes();
2153       break;
2154     case IsNotNull:
2155       executeIsNotNull();
2156       break;
2157     case RecordMatch:
2158       assert(matches &&
2159              "expected matches to be provided when executing the matcher");
2160       executeRecordMatch(rewriter, *matches);
2161       break;
2162     case ReplaceOp:
2163       executeReplaceOp(rewriter);
2164       break;
2165     case SwitchAttribute:
2166       executeSwitchAttribute();
2167       break;
2168     case SwitchOperandCount:
2169       executeSwitchOperandCount();
2170       break;
2171     case SwitchOperationName:
2172       executeSwitchOperationName();
2173       break;
2174     case SwitchResultCount:
2175       executeSwitchResultCount();
2176       break;
2177     case SwitchType:
2178       executeSwitchType();
2179       break;
2180     case SwitchTypes:
2181       executeSwitchTypes();
2182       break;
2183     }
2184     LLVM_DEBUG(llvm::dbgs() << "\n");
2185   }
2186 }
2187 
2188 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
2189                         SmallVectorImpl<MatchResult> &matches,
2190                         PDLByteCodeMutableState &state) const {
2191   // The first memory slot is always the root operation.
2192   state.memory[0] = op;
2193 
2194   // The matcher function always starts at code address 0.
2195   ByteCodeExecutor executor(
2196       matcherByteCode.data(), state.memory, state.opRangeMemory,
2197       state.typeRangeMemory, state.allocatedTypeRangeMemory,
2198       state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2199       uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
2200       constraintFunctions, rewriteFunctions);
2201   LogicalResult executeResult = executor.execute(rewriter, &matches);
2202   assert(succeeded(executeResult) && "unexpected matcher execution failure");
2203 
2204   // Order the found matches by benefit.
2205   std::stable_sort(matches.begin(), matches.end(),
2206                    [](const MatchResult &lhs, const MatchResult &rhs) {
2207                      return lhs.benefit > rhs.benefit;
2208                    });
2209 }
2210 
2211 LogicalResult PDLByteCode::rewrite(PatternRewriter &rewriter,
2212                                    const MatchResult &match,
2213                                    PDLByteCodeMutableState &state) const {
2214   auto *configSet = match.pattern->getConfigSet();
2215   if (configSet)
2216     configSet->notifyRewriteBegin(rewriter);
2217 
2218   // The arguments of the rewrite function are stored at the start of the
2219   // memory buffer.
2220   llvm::copy(match.values, state.memory.begin());
2221 
2222   ByteCodeExecutor executor(
2223       &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
2224       state.opRangeMemory, state.typeRangeMemory,
2225       state.allocatedTypeRangeMemory, state.valueRangeMemory,
2226       state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2227       rewriterByteCode, state.currentPatternBenefits, patterns,
2228       constraintFunctions, rewriteFunctions);
2229   LogicalResult result =
2230       executor.execute(rewriter, /*matches=*/nullptr, match.location);
2231 
2232   if (configSet)
2233     configSet->notifyRewriteEnd(rewriter);
2234 
2235   // If the rewrite failed, check if the pattern rewriter can recover. If it
2236   // can, we can signal to the pattern applicator to keep trying patterns. If it
2237   // doesn't, we need to bail. Bailing here should be fine, given that we have
2238   // no means to propagate such a failure to the user, and it also indicates a
2239   // bug in the user code (i.e. failable rewrites should not be used with
2240   // pattern rewriters that don't support it).
2241   if (failed(result) && !rewriter.canRecoverFromRewriteFailure()) {
2242     LLVM_DEBUG(llvm::dbgs() << " and rollback is not supported - aborting");
2243     llvm::report_fatal_error(
2244         "Native PDL Rewrite failed, but the pattern "
2245         "rewriter doesn't support recovery. Failable pattern rewrites should "
2246         "not be used with pattern rewriters that do not support them.");
2247   }
2248   return result;
2249 }
2250