xref: /llvm-project/mlir/lib/Rewrite/ByteCode.cpp (revision e1795322844ca45ecbcdca8669929a46c666127e)
1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MLIR to byte-code generation and the interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "ByteCode.h"
14 #include "mlir/Analysis/Liveness.h"
15 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
16 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/RegionGraphTraits.h"
19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/Format.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include <numeric>
26 
27 #define DEBUG_TYPE "pdl-bytecode"
28 
29 using namespace mlir;
30 using namespace mlir::detail;
31 
32 //===----------------------------------------------------------------------===//
33 // PDLByteCodePattern
34 //===----------------------------------------------------------------------===//
35 
36 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
37                                               ByteCodeAddr rewriterAddr) {
38   SmallVector<StringRef, 8> generatedOps;
39   if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
40     generatedOps =
41         llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
42 
43   PatternBenefit benefit = matchOp.getBenefit();
44   MLIRContext *ctx = matchOp.getContext();
45 
46   // Check to see if this is pattern matches a specific operation type.
47   if (Optional<StringRef> rootKind = matchOp.getRootKind())
48     return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx,
49                               generatedOps);
50   return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx,
51                             generatedOps);
52 }
53 
54 //===----------------------------------------------------------------------===//
55 // PDLByteCodeMutableState
56 //===----------------------------------------------------------------------===//
57 
58 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
59 /// to the position of the pattern within the range returned by
60 /// `PDLByteCode::getPatterns`.
61 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
62                                                    PatternBenefit benefit) {
63   currentPatternBenefits[patternIndex] = benefit;
64 }
65 
66 /// Cleanup any allocated state after a full match/rewrite has been completed.
67 /// This method should be called irregardless of whether the match+rewrite was a
68 /// success or not.
69 void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() {
70   allocatedTypeRangeMemory.clear();
71   allocatedValueRangeMemory.clear();
72 }
73 
74 //===----------------------------------------------------------------------===//
75 // Bytecode OpCodes
76 //===----------------------------------------------------------------------===//
77 
namespace {
/// The set of opcodes used by the pattern bytecode. The underlying type is
/// `ByteCodeField`, so an opcode occupies a single field of the bytecode
/// stream. Enumerator values are assigned implicitly from the declaration
/// order below, so reordering the list changes the encoding.
enum OpCode : ByteCodeField {
  /// Apply an externally registered constraint.
  ApplyConstraint,
  /// Apply an externally registered rewrite.
  ApplyRewrite,
  /// Check if two generic values are equal.
  AreEqual,
  /// Check if two ranges are equal.
  AreRangesEqual,
  /// Unconditional branch.
  Branch,
  /// Compare the operand count of an operation with a constant.
  CheckOperandCount,
  /// Compare the name of an operation with a constant.
  CheckOperationName,
  /// Compare the result count of an operation with a constant.
  CheckResultCount,
  /// Compare a range of types to a constant range of types.
  CheckTypes,
  /// Continue to the next iteration of a loop.
  Continue,
  /// Create an operation.
  CreateOperation,
  /// Create a range of types.
  CreateTypes,
  /// Erase an operation.
  EraseOp,
  /// Extract the op from a range at the specified index.
  ExtractOp,
  /// Extract the type from a range at the specified index.
  ExtractType,
  /// Extract the value from a range at the specified index.
  ExtractValue,
  /// Terminate a matcher or rewrite sequence.
  Finalize,
  /// Iterate over a range of values.
  ForEach,
  /// Get a specific attribute of an operation.
  GetAttribute,
  /// Get the type of an attribute.
  GetAttributeType,
  /// Get the defining operation of a value.
  GetDefiningOp,
  /// Get a specific operand of an operation.
  /// NOTE(review): the numbered variants presumably encode a fixed operand
  /// index, with `GetOperandN` covering the general case -- confirm against
  /// the interpreter.
  GetOperand0,
  GetOperand1,
  GetOperand2,
  GetOperand3,
  GetOperandN,
  /// Get a specific operand group of an operation.
  GetOperands,
  /// Get a specific result of an operation.
  /// NOTE(review): the numbered variants presumably encode a fixed result
  /// index, with `GetResultN` covering the general case -- confirm against
  /// the interpreter.
  GetResult0,
  GetResult1,
  GetResult2,
  GetResult3,
  GetResultN,
  /// Get a specific result group of an operation.
  GetResults,
  /// Get the users of a value or a range of values.
  GetUsers,
  /// Get the type of a value.
  GetValueType,
  /// Get the types of a value range.
  GetValueRangeTypes,
  /// Check if a generic value is not null.
  IsNotNull,
  /// Record a successful pattern match.
  RecordMatch,
  /// Replace an operation.
  ReplaceOp,
  /// Compare an attribute with a set of constants.
  SwitchAttribute,
  /// Compare the operand count of an operation with a set of constants.
  SwitchOperandCount,
  /// Compare the name of an operation with a set of constants.
  SwitchOperationName,
  /// Compare the result count of an operation with a set of constants.
  SwitchResultCount,
  /// Compare a type with a set of constants.
  SwitchType,
  /// Compare a range of types with a set of constants.
  SwitchTypes,
};
} // namespace
164 
/// A marker used to indicate if an operation should infer types. The marker is
/// the largest value representable by `ByteCodeField`.
/// NOTE(review): presumably written in place of a count within the bytecode,
/// which would then never be allowed to reach this value -- confirm against
/// the uses of this constant.
static constexpr ByteCodeField kInferTypesMarker =
    std::numeric_limits<ByteCodeField>::max();
168 
169 //===----------------------------------------------------------------------===//
170 // ByteCode Generation
171 //===----------------------------------------------------------------------===//
172 
173 //===----------------------------------------------------------------------===//
174 // Generator
175 
176 namespace {
177 struct ByteCodeLiveRange;
178 struct ByteCodeWriter;
179 
/// Check if the given class `T` can be converted to an opaque pointer.
///
/// The alias is only well-formed when a value of type `T` has a
/// `getAsOpaquePointer()` member, which makes it usable as the detection
/// operation of `llvm::is_detected`. The `Args` pack is not referenced by the
/// alias itself.
template <typename T, typename... Args>
using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
183 
184 /// This class represents the main generator for the pattern bytecode.
185 class Generator {
186 public:
187   Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
188             SmallVectorImpl<ByteCodeField> &matcherByteCode,
189             SmallVectorImpl<ByteCodeField> &rewriterByteCode,
190             SmallVectorImpl<PDLByteCodePattern> &patterns,
191             ByteCodeField &maxValueMemoryIndex,
192             ByteCodeField &maxOpRangeMemoryIndex,
193             ByteCodeField &maxTypeRangeMemoryIndex,
194             ByteCodeField &maxValueRangeMemoryIndex,
195             ByteCodeField &maxLoopLevel,
196             llvm::StringMap<PDLConstraintFunction> &constraintFns,
197             llvm::StringMap<PDLRewriteFunction> &rewriteFns)
198       : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
199         rewriterByteCode(rewriterByteCode), patterns(patterns),
200         maxValueMemoryIndex(maxValueMemoryIndex),
201         maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
202         maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
203         maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
204         maxLoopLevel(maxLoopLevel) {
205     for (const auto &it : llvm::enumerate(constraintFns))
206       constraintToMemIndex.try_emplace(it.value().first(), it.index());
207     for (const auto &it : llvm::enumerate(rewriteFns))
208       externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
209   }
210 
211   /// Generate the bytecode for the given PDL interpreter module.
212   void generate(ModuleOp module);
213 
214   /// Return the memory index to use for the given value.
215   ByteCodeField &getMemIndex(Value value) {
216     assert(valueToMemIndex.count(value) &&
217            "expected memory index to be assigned");
218     return valueToMemIndex[value];
219   }
220 
221   /// Return the range memory index used to store the given range value.
222   ByteCodeField &getRangeStorageIndex(Value value) {
223     assert(valueToRangeIndex.count(value) &&
224            "expected range index to be assigned");
225     return valueToRangeIndex[value];
226   }
227 
228   /// Return an index to use when referring to the given data that is uniqued in
229   /// the MLIR context.
230   template <typename T>
231   std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
232   getMemIndex(T val) {
233     const void *opaqueVal = val.getAsOpaquePointer();
234 
235     // Get or insert a reference to this value.
236     auto it = uniquedDataToMemIndex.try_emplace(
237         opaqueVal, maxValueMemoryIndex + uniquedData.size());
238     if (it.second)
239       uniquedData.push_back(opaqueVal);
240     return it.first->second;
241   }
242 
243 private:
244   /// Allocate memory indices for the results of operations within the matcher
245   /// and rewriters.
246   void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
247                              ModuleOp rewriterModule);
248 
249   /// Generate the bytecode for the given operation.
250   void generate(Region *region, ByteCodeWriter &writer);
251   void generate(Operation *op, ByteCodeWriter &writer);
252   void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
253   void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
254   void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
255   void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
256   void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
257   void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
258   void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
259   void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
260   void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
261   void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
262   void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
263   void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
264   void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
265   void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
266   void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
267   void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
268   void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
269   void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
270   void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
271   void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
272   void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
273   void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
274   void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
275   void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
276   void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
277   void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
278   void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
279   void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
280   void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
281   void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
282   void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
283   void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
284   void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
285   void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
286   void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
287   void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
288   void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
289 
290   /// Mapping from value to its corresponding memory index.
291   DenseMap<Value, ByteCodeField> valueToMemIndex;
292 
293   /// Mapping from a range value to its corresponding range storage index.
294   DenseMap<Value, ByteCodeField> valueToRangeIndex;
295 
296   /// Mapping from the name of an externally registered rewrite to its index in
297   /// the bytecode registry.
298   llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
299 
300   /// Mapping from the name of an externally registered constraint to its index
301   /// in the bytecode registry.
302   llvm::StringMap<ByteCodeField> constraintToMemIndex;
303 
304   /// Mapping from rewriter function name to the bytecode address of the
305   /// rewriter function in byte.
306   llvm::StringMap<ByteCodeAddr> rewriterToAddr;
307 
308   /// Mapping from a uniqued storage object to its memory index within
309   /// `uniquedData`.
310   DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
311 
312   /// The current level of the foreach loop.
313   ByteCodeField curLoopLevel = 0;
314 
315   /// The current MLIR context.
316   MLIRContext *ctx;
317 
318   /// Mapping from block to its address.
319   DenseMap<Block *, ByteCodeAddr> blockToAddr;
320 
321   /// Data of the ByteCode class to be populated.
322   std::vector<const void *> &uniquedData;
323   SmallVectorImpl<ByteCodeField> &matcherByteCode;
324   SmallVectorImpl<ByteCodeField> &rewriterByteCode;
325   SmallVectorImpl<PDLByteCodePattern> &patterns;
326   ByteCodeField &maxValueMemoryIndex;
327   ByteCodeField &maxOpRangeMemoryIndex;
328   ByteCodeField &maxTypeRangeMemoryIndex;
329   ByteCodeField &maxValueRangeMemoryIndex;
330   ByteCodeField &maxLoopLevel;
331 };
332 
333 /// This class provides utilities for writing a bytecode stream.
334 struct ByteCodeWriter {
335   ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
336       : bytecode(bytecode), generator(generator) {}
337 
338   /// Append a field to the bytecode.
339   void append(ByteCodeField field) { bytecode.push_back(field); }
340   void append(OpCode opCode) { bytecode.push_back(opCode); }
341 
342   /// Append an address to the bytecode.
343   void append(ByteCodeAddr field) {
344     static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
345                   "unexpected ByteCode address size");
346 
347     ByteCodeField fieldParts[2];
348     std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
349     bytecode.append({fieldParts[0], fieldParts[1]});
350   }
351 
352   /// Append a single successor to the bytecode, the exact address will need to
353   /// be resolved later.
354   void append(Block *successor) {
355     // Add back a reference to the successor so that the address can be resolved
356     // later.
357     unresolvedSuccessorRefs[successor].push_back(bytecode.size());
358     append(ByteCodeAddr(0));
359   }
360 
361   /// Append a successor range to the bytecode, the exact address will need to
362   /// be resolved later.
363   void append(SuccessorRange successors) {
364     for (Block *successor : successors)
365       append(successor);
366   }
367 
368   /// Append a range of values that will be read as generic PDLValues.
369   void appendPDLValueList(OperandRange values) {
370     bytecode.push_back(values.size());
371     for (Value value : values)
372       appendPDLValue(value);
373   }
374 
375   /// Append a value as a PDLValue.
376   void appendPDLValue(Value value) {
377     appendPDLValueKind(value);
378     append(value);
379   }
380 
381   /// Append the PDLValue::Kind of the given value.
382   void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); }
383 
384   /// Append the PDLValue::Kind of the given type.
385   void appendPDLValueKind(Type type) {
386     PDLValue::Kind kind =
387         TypeSwitch<Type, PDLValue::Kind>(type)
388             .Case<pdl::AttributeType>(
389                 [](Type) { return PDLValue::Kind::Attribute; })
390             .Case<pdl::OperationType>(
391                 [](Type) { return PDLValue::Kind::Operation; })
392             .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
393               if (rangeTy.getElementType().isa<pdl::TypeType>())
394                 return PDLValue::Kind::TypeRange;
395               return PDLValue::Kind::ValueRange;
396             })
397             .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
398             .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
399     bytecode.push_back(static_cast<ByteCodeField>(kind));
400   }
401 
402   /// Append a value that will be stored in a memory slot and not inline within
403   /// the bytecode.
404   template <typename T>
405   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
406                    std::is_pointer<T>::value>
407   append(T value) {
408     bytecode.push_back(generator.getMemIndex(value));
409   }
410 
411   /// Append a range of values.
412   template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
413   std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
414   append(T range) {
415     bytecode.push_back(llvm::size(range));
416     for (auto it : range)
417       append(it);
418   }
419 
420   /// Append a variadic number of fields to the bytecode.
421   template <typename FieldTy, typename Field2Ty, typename... FieldTys>
422   void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
423     append(field);
424     append(field2, fields...);
425   }
426 
427   /// Appends a value as a pointer, stored inline within the bytecode.
428   template <typename T>
429   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
430   appendInline(T value) {
431     constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField);
432     const void *pointer = value.getAsOpaquePointer();
433     ByteCodeField fieldParts[numParts];
434     std::memcpy(fieldParts, &pointer, sizeof(const void *));
435     bytecode.append(fieldParts, fieldParts + numParts);
436   }
437 
438   /// Successor references in the bytecode that have yet to be resolved.
439   DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
440 
441   /// The underlying bytecode buffer.
442   SmallVectorImpl<ByteCodeField> &bytecode;
443 
444   /// The main generator producing PDL.
445   Generator &generator;
446 };
447 
448 /// This class represents a live range of PDL Interpreter values, containing
449 /// information about when values are live within a match/rewrite.
450 struct ByteCodeLiveRange {
451   using Set = llvm::IntervalMap<uint64_t, char, 16>;
452   using Allocator = Set::Allocator;
453 
454   ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
455 
456   /// Union this live range with the one provided.
457   void unionWith(const ByteCodeLiveRange &rhs) {
458     for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
459          ++it)
460       liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0);
461   }
462 
463   /// Returns true if this range overlaps with the one provided.
464   bool overlaps(const ByteCodeLiveRange &rhs) const {
465     return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
466         .valid();
467   }
468 
469   /// A map representing the ranges of the match/rewrite that a value is live in
470   /// the interpreter.
471   ///
472   /// We use std::unique_ptr here, because IntervalMap does not provide a
473   /// correct copy or move constructor. We can eliminate the pointer once
474   /// https://reviews.llvm.org/D113240 lands.
475   std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
476 
477   /// The operation range storage index for this range.
478   Optional<unsigned> opRangeIndex;
479 
480   /// The type range storage index for this range.
481   Optional<unsigned> typeRangeIndex;
482 
483   /// The value range storage index for this range.
484   Optional<unsigned> valueRangeIndex;
485 };
486 } // namespace
487 
488 void Generator::generate(ModuleOp module) {
489   auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
490       pdl_interp::PDLInterpDialect::getMatcherFunctionName());
491   ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
492       pdl_interp::PDLInterpDialect::getRewriterModuleName());
493   assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
494 
495   // Allocate memory indices for the results of operations within the matcher
496   // and rewriters.
497   allocateMemoryIndices(matcherFunc, rewriterModule);
498 
499   // Generate code for the rewriter functions.
500   ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
501   for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
502     rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
503     for (Operation &op : rewriterFunc.getOps())
504       generate(&op, rewriterByteCodeWriter);
505   }
506   assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
507          "unexpected branches in rewriter function");
508 
509   // Generate code for the matcher function.
510   ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
511   generate(&matcherFunc.getBody(), matcherByteCodeWriter);
512 
513   // Resolve successor references in the matcher.
514   for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
515     ByteCodeAddr addr = blockToAddr[it.first];
516     for (unsigned offsetToFix : it.second)
517       std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
518   }
519 }
520 
/// Assign a memory index to every value used by the matcher and the rewriters,
/// along with a range storage index for range values, and raise the
/// `max*MemoryIndex` counters so that they cover everything assigned here.
///
/// Rewriter functions number their arguments and results sequentially,
/// restarting from zero for each function. The matcher function instead shares
/// an index between values whose live ranges do not overlap.
void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
                                      ModuleOp rewriterModule) {
  // Rewriters use simplistic allocation scheme that simply assigns an index to
  // each result.
  for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
    ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
    // NOTE(review): only ranges of types and ranges of values receive a range
    // storage index here; ranges of operations are not handled -- confirm
    // that they cannot appear within a rewriter.
    auto processRewriterValue = [&](Value val) {
      valueToMemIndex.try_emplace(val, index++);
      if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) {
        Type elementTy = rangeType.getElementType();
        if (elementTy.isa<pdl::TypeType>())
          valueToRangeIndex.try_emplace(val, typeRangeIndex++);
        else if (elementTy.isa<pdl::ValueType>())
          valueToRangeIndex.try_emplace(val, valueRangeIndex++);
      }
    };

    for (BlockArgument arg : rewriterFunc.getArguments())
      processRewriterValue(arg);
    rewriterFunc.getBody().walk([&](Operation *op) {
      for (Value result : op->getResults())
        processRewriterValue(result);
    });
    // Keep the maximum over all of the rewriter functions.
    if (index > maxValueMemoryIndex)
      maxValueMemoryIndex = index;
    if (typeRangeIndex > maxTypeRangeMemoryIndex)
      maxTypeRangeMemoryIndex = typeRangeIndex;
    if (valueRangeIndex > maxValueRangeMemoryIndex)
      maxValueRangeMemoryIndex = valueRangeIndex;
  }

  // The matcher function uses a more sophisticated numbering that tries to
  // minimize the number of memory indices assigned. This is done by determining
  // a live range of the values within the matcher, then the allocation is just
  // finding the minimal number of overlapping live ranges. This is essentially
  // a simplified form of register allocation where we don't necessarily have a
  // limited number of registers, but we still want to minimize the number used.
  DenseMap<Operation *, unsigned> opToFirstIndex;
  DenseMap<Operation *, unsigned> opToLastIndex;

  // A custom walk that marks the first and the last index of each operation.
  // The entry marks the beginning of the liveness range for this operation,
  // followed by nested operations, followed by the end of the liveness range.
  unsigned index = 0;
  llvm::unique_function<void(Operation *)> walk = [&](Operation *op) {
    opToFirstIndex.try_emplace(op, index++);
    for (Region &region : op->getRegions())
      for (Block &block : region.getBlocks())
        for (Operation &nested : block)
          walk(&nested);
    opToLastIndex.try_emplace(op, index++);
  };
  walk(matcherFunc);

  // Liveness info for each of the defs within the matcher.
  ByteCodeLiveRange::Allocator allocator;
  DenseMap<Value, ByteCodeLiveRange> valueDefRanges;

  // Assign the root operation being matched to slot 0.
  BlockArgument rootOpArg = matcherFunc.getArgument(0);
  valueToMemIndex[rootOpArg] = 0;

  // Walk each of the blocks, computing the def interval that the value is used.
  Liveness matcherLiveness(matcherFunc);
  matcherFunc->walk([&](Block *block) {
    const LivenessBlockInfo *info = matcherLiveness.getLiveness(block);
    assert(info && "expected liveness info for block");
    auto processValue = [&](Value value, Operation *firstUseOrDef) {
      // We don't need to process the root op argument, this value is always
      // assigned to the first memory slot.
      if (value == rootOpArg)
        return;

      // Set indices for the range of this block that the value is used.
      auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
      defRangeIt->second.liveness->insert(
          opToFirstIndex[firstUseOrDef],
          opToLastIndex[info->getEndOperation(value, firstUseOrDef)],
          /*dummyValue*/ 0);

      // Check to see if this value is a range type. The index stored here is
      // only a placeholder that records the kind of range: the allocation
      // loop below checks which index is present and assigns the real one.
      if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) {
        Type eleType = rangeTy.getElementType();
        if (eleType.isa<pdl::OperationType>())
          defRangeIt->second.opRangeIndex = 0;
        else if (eleType.isa<pdl::TypeType>())
          defRangeIt->second.typeRangeIndex = 0;
        else if (eleType.isa<pdl::ValueType>())
          defRangeIt->second.valueRangeIndex = 0;
      }
    };

    // Process the live-ins of this block.
    for (Value liveIn : info->in()) {
      // Only process the value if it has been defined in the current region.
      // Other values that span across pdl_interp.foreach will be added higher
      // up. This ensures that we keep them alive for the entire duration of
      // the loop.
      if (liveIn.getParentRegion() == block->getParent())
        processValue(liveIn, &block->front());
    }

    // Process the block arguments for the entry block (those are not live-in).
    if (block->isEntryBlock()) {
      for (Value argument : block->getArguments())
        processValue(argument, &block->front());
    }

    // Process any new defs within this block.
    for (Operation &op : *block)
      for (Value result : op.getResults())
        processValue(result, &op);
  });

  // Greedily allocate memory slots using the computed def live ranges.
  std::vector<ByteCodeLiveRange> allocatedIndices;

  // The number of memory indices currently allocated (and its next value).
  // Recall that the root gets allocated memory index 0.
  ByteCodeField numIndices = 1;

  // The number of memory ranges of various types (and their next values).
  ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;

  for (auto &defIt : valueDefRanges) {
    // The lookup creates the entry when it is missing; a freshly created
    // entry holds 0, which is tested below to detect that no index was reused.
    ByteCodeField &memIndex = valueToMemIndex[defIt.first];
    ByteCodeLiveRange &defRange = defIt.second;

    // Try to allocate to an existing index.
    for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) {
      ByteCodeLiveRange &existingRange = existingIndexIt.value();
      if (!defRange.overlaps(existingRange)) {
        existingRange.unionWith(defRange);
        // Offset by one, since memory index 0 belongs to the root operation.
        memIndex = existingIndexIt.index() + 1;

        // Reuse the range storage index of the existing slot when it already
        // has one of the matching kind, otherwise allocate a new one.
        if (defRange.opRangeIndex) {
          if (!existingRange.opRangeIndex)
            existingRange.opRangeIndex = numOpRanges++;
          valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
        } else if (defRange.typeRangeIndex) {
          if (!existingRange.typeRangeIndex)
            existingRange.typeRangeIndex = numTypeRanges++;
          valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
        } else if (defRange.valueRangeIndex) {
          if (!existingRange.valueRangeIndex)
            existingRange.valueRangeIndex = numValueRanges++;
          valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
        }
        break;
      }
    }

    // If no existing index could be used, add a new one.
    if (memIndex == 0) {
      allocatedIndices.emplace_back(allocator);
      ByteCodeLiveRange &newRange = allocatedIndices.back();
      newRange.unionWith(defRange);

      // Allocate an index for op/type/value ranges.
      if (defRange.opRangeIndex) {
        newRange.opRangeIndex = numOpRanges;
        valueToRangeIndex[defIt.first] = numOpRanges++;
      } else if (defRange.typeRangeIndex) {
        newRange.typeRangeIndex = numTypeRanges;
        valueToRangeIndex[defIt.first] = numTypeRanges++;
      } else if (defRange.valueRangeIndex) {
        newRange.valueRangeIndex = numValueRanges;
        valueToRangeIndex[defIt.first] = numValueRanges++;
      }

      // The new slot is the last one; its index is offset by one for the root.
      memIndex = allocatedIndices.size();
      ++numIndices;
    }
  }

  // Print the index usage and ensure that we did not run out of index space.
  LLVM_DEBUG({
    llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices "
                 << "(down from initial " << valueDefRanges.size() << ").\n";
  });
  // NOTE(review): `numIndices` ends up as `allocatedIndices.size() + 1`, so a
  // size equal to the maximum passes this check while the count itself would
  // no longer fit in a ByteCodeField -- confirm whether the bound should be
  // strict.
  assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
         "Ran out of memory for allocated indices");

  // Update the max number of indices.
  if (numIndices > maxValueMemoryIndex)
    maxValueMemoryIndex = numIndices;
  if (numOpRanges > maxOpRangeMemoryIndex)
    maxOpRangeMemoryIndex = numOpRanges;
  if (numTypeRanges > maxTypeRangeMemoryIndex)
    maxTypeRangeMemoryIndex = numTypeRanges;
  if (numValueRanges > maxValueRangeMemoryIndex)
    maxValueRangeMemoryIndex = numValueRanges;
}
714 
715 void Generator::generate(Region *region, ByteCodeWriter &writer) {
716   llvm::ReversePostOrderTraversal<Region *> rpot(region);
717   for (Block *block : rpot) {
718     // Keep track of where this block begins within the matcher function.
719     blockToAddr.try_emplace(block, matcherByteCode.size());
720     for (Operation &op : *block)
721       generate(&op, writer);
722   }
723 }
724 
725 void Generator::generate(Operation *op, ByteCodeWriter &writer) {
726   LLVM_DEBUG({
727     // The following list must contain all the operations that do not
728     // produce any bytecode.
729     if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
730       writer.appendInline(op->getLoc());
731   });
732   TypeSwitch<Operation *>(op)
733       .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
734             pdl_interp::AreEqualOp, pdl_interp::BranchOp,
735             pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
736             pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
737             pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
738             pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
739             pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp,
740             pdl_interp::CreateTypesOp, pdl_interp::EraseOp,
741             pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
742             pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
743             pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
744             pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
745             pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
746             pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
747             pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
748             pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
749             pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
750             pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
751             pdl_interp::SwitchResultCountOp>(
752           [&](auto interpOp) { this->generate(interpOp, writer); })
753       .Default([](Operation *) {
754         llvm_unreachable("unknown `pdl_interp` operation");
755       });
756 }
757 
758 void Generator::generate(pdl_interp::ApplyConstraintOp op,
759                          ByteCodeWriter &writer) {
760   assert(constraintToMemIndex.count(op.getName()) &&
761          "expected index for constraint function");
762   writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
763   writer.appendPDLValueList(op.getArgs());
764   writer.append(op.getSuccessors());
765 }
766 void Generator::generate(pdl_interp::ApplyRewriteOp op,
767                          ByteCodeWriter &writer) {
768   assert(externalRewriterToMemIndex.count(op.getName()) &&
769          "expected index for rewrite function");
770   writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
771   writer.appendPDLValueList(op.getArgs());
772 
773   ResultRange results = op.getResults();
774   writer.append(ByteCodeField(results.size()));
775   for (Value result : results) {
776     // In debug mode we also record the expected kind of the result, so that we
777     // can provide extra verification of the native rewrite function.
778 #ifndef NDEBUG
779     writer.appendPDLValueKind(result);
780 #endif
781 
782     // Range results also need to append the range storage index.
783     if (result.getType().isa<pdl::RangeType>())
784       writer.append(getRangeStorageIndex(result));
785     writer.append(result);
786   }
787 }
788 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
789   Value lhs = op.getLhs();
790   if (lhs.getType().isa<pdl::RangeType>()) {
791     writer.append(OpCode::AreRangesEqual);
792     writer.appendPDLValueKind(lhs);
793     writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
794     return;
795   }
796 
797   writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
798 }
799 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
800   writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
801 }
802 void Generator::generate(pdl_interp::CheckAttributeOp op,
803                          ByteCodeWriter &writer) {
804   writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
805                 op.getSuccessors());
806 }
807 void Generator::generate(pdl_interp::CheckOperandCountOp op,
808                          ByteCodeWriter &writer) {
809   writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
810                 static_cast<ByteCodeField>(op.getCompareAtLeast()),
811                 op.getSuccessors());
812 }
813 void Generator::generate(pdl_interp::CheckOperationNameOp op,
814                          ByteCodeWriter &writer) {
815   writer.append(OpCode::CheckOperationName, op.getInputOp(),
816                 OperationName(op.getName(), ctx), op.getSuccessors());
817 }
818 void Generator::generate(pdl_interp::CheckResultCountOp op,
819                          ByteCodeWriter &writer) {
820   writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
821                 static_cast<ByteCodeField>(op.getCompareAtLeast()),
822                 op.getSuccessors());
823 }
824 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
825   writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
826                 op.getSuccessors());
827 }
828 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
829   writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
830                 op.getSuccessors());
831 }
832 void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
833   assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
834   writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
835 }
836 void Generator::generate(pdl_interp::CreateAttributeOp op,
837                          ByteCodeWriter &writer) {
838   // Simply repoint the memory index of the result to the constant.
839   getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
840 }
841 void Generator::generate(pdl_interp::CreateOperationOp op,
842                          ByteCodeWriter &writer) {
843   writer.append(OpCode::CreateOperation, op.getResultOp(),
844                 OperationName(op.getName(), ctx));
845   writer.appendPDLValueList(op.getInputOperands());
846 
847   // Add the attributes.
848   OperandRange attributes = op.getInputAttributes();
849   writer.append(static_cast<ByteCodeField>(attributes.size()));
850   for (auto it : llvm::zip(op.getInputAttributeNames(), attributes))
851     writer.append(std::get<0>(it), std::get<1>(it));
852 
853   // Add the result types. If the operation has inferred results, we use a
854   // marker "size" value. Otherwise, we add the list of explicit result types.
855   if (op.getInferredResultTypes())
856     writer.append(kInferTypesMarker);
857   else
858     writer.appendPDLValueList(op.getInputResultTypes());
859 }
860 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
861   // Simply repoint the memory index of the result to the constant.
862   getMemIndex(op.getResult()) = getMemIndex(op.getValue());
863 }
864 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
865   writer.append(OpCode::CreateTypes, op.getResult(),
866                 getRangeStorageIndex(op.getResult()), op.getValue());
867 }
868 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
869   writer.append(OpCode::EraseOp, op.getInputOp());
870 }
871 void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
872   OpCode opCode =
873       TypeSwitch<Type, OpCode>(op.getResult().getType())
874           .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
875           .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
876           .Case([](pdl::TypeType) { return OpCode::ExtractType; })
877           .Default([](Type) -> OpCode {
878             llvm_unreachable("unsupported element type");
879           });
880   writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
881 }
882 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
883   writer.append(OpCode::Finalize);
884 }
885 void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
886   BlockArgument arg = op.getLoopVariable();
887   writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
888   writer.appendPDLValueKind(arg.getType());
889   writer.append(curLoopLevel, op.getSuccessor());
890   ++curLoopLevel;
891   if (curLoopLevel > maxLoopLevel)
892     maxLoopLevel = curLoopLevel;
893   generate(&op.getRegion(), writer);
894   --curLoopLevel;
895 }
896 void Generator::generate(pdl_interp::GetAttributeOp op,
897                          ByteCodeWriter &writer) {
898   writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
899                 op.getNameAttr());
900 }
901 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
902                          ByteCodeWriter &writer) {
903   writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
904 }
905 void Generator::generate(pdl_interp::GetDefiningOpOp op,
906                          ByteCodeWriter &writer) {
907   writer.append(OpCode::GetDefiningOp, op.getInputOp());
908   writer.appendPDLValue(op.getValue());
909 }
910 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
911   uint32_t index = op.getIndex();
912   if (index < 4)
913     writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
914   else
915     writer.append(OpCode::GetOperandN, index);
916   writer.append(op.getInputOp(), op.getValue());
917 }
918 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
919   Value result = op.getValue();
920   Optional<uint32_t> index = op.getIndex();
921   writer.append(OpCode::GetOperands,
922                 index.value_or(std::numeric_limits<uint32_t>::max()),
923                 op.getInputOp());
924   if (result.getType().isa<pdl::RangeType>())
925     writer.append(getRangeStorageIndex(result));
926   else
927     writer.append(std::numeric_limits<ByteCodeField>::max());
928   writer.append(result);
929 }
930 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
931   uint32_t index = op.getIndex();
932   if (index < 4)
933     writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
934   else
935     writer.append(OpCode::GetResultN, index);
936   writer.append(op.getInputOp(), op.getValue());
937 }
938 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
939   Value result = op.getValue();
940   Optional<uint32_t> index = op.getIndex();
941   writer.append(OpCode::GetResults,
942                 index.value_or(std::numeric_limits<uint32_t>::max()),
943                 op.getInputOp());
944   if (result.getType().isa<pdl::RangeType>())
945     writer.append(getRangeStorageIndex(result));
946   else
947     writer.append(std::numeric_limits<ByteCodeField>::max());
948   writer.append(result);
949 }
950 void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
951   Value operations = op.getOperations();
952   ByteCodeField rangeIndex = getRangeStorageIndex(operations);
953   writer.append(OpCode::GetUsers, operations, rangeIndex);
954   writer.appendPDLValue(op.getValue());
955 }
956 void Generator::generate(pdl_interp::GetValueTypeOp op,
957                          ByteCodeWriter &writer) {
958   if (op.getType().isa<pdl::RangeType>()) {
959     Value result = op.getResult();
960     writer.append(OpCode::GetValueRangeTypes, result,
961                   getRangeStorageIndex(result), op.getValue());
962   } else {
963     writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
964   }
965 }
966 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
967   writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
968 }
969 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
970   ByteCodeField patternIndex = patterns.size();
971   patterns.emplace_back(PDLByteCodePattern::create(
972       op, rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
973   writer.append(OpCode::RecordMatch, patternIndex,
974                 SuccessorRange(op.getOperation()), op.getMatchedOps());
975   writer.appendPDLValueList(op.getInputs());
976 }
977 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
978   writer.append(OpCode::ReplaceOp, op.getInputOp());
979   writer.appendPDLValueList(op.getReplValues());
980 }
981 void Generator::generate(pdl_interp::SwitchAttributeOp op,
982                          ByteCodeWriter &writer) {
983   writer.append(OpCode::SwitchAttribute, op.getAttribute(),
984                 op.getCaseValuesAttr(), op.getSuccessors());
985 }
986 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
987                          ByteCodeWriter &writer) {
988   writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
989                 op.getCaseValuesAttr(), op.getSuccessors());
990 }
991 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
992                          ByteCodeWriter &writer) {
993   auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) {
994     return OperationName(attr.cast<StringAttr>().getValue(), ctx);
995   });
996   writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
997                 op.getSuccessors());
998 }
999 void Generator::generate(pdl_interp::SwitchResultCountOp op,
1000                          ByteCodeWriter &writer) {
1001   writer.append(OpCode::SwitchResultCount, op.getInputOp(),
1002                 op.getCaseValuesAttr(), op.getSuccessors());
1003 }
1004 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
1005   writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
1006                 op.getSuccessors());
1007 }
1008 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
1009   writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
1010                 op.getSuccessors());
1011 }
1012 
1013 //===----------------------------------------------------------------------===//
1014 // PDLByteCode
1015 //===----------------------------------------------------------------------===//
1016 
1017 PDLByteCode::PDLByteCode(ModuleOp module,
1018                          llvm::StringMap<PDLConstraintFunction> constraintFns,
1019                          llvm::StringMap<PDLRewriteFunction> rewriteFns) {
1020   Generator generator(module.getContext(), uniquedData, matcherByteCode,
1021                       rewriterByteCode, patterns, maxValueMemoryIndex,
1022                       maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
1023                       maxLoopLevel, constraintFns, rewriteFns);
1024   generator.generate(module);
1025 
1026   // Initialize the external functions.
1027   for (auto &it : constraintFns)
1028     constraintFunctions.push_back(std::move(it.second));
1029   for (auto &it : rewriteFns)
1030     rewriteFunctions.push_back(std::move(it.second));
1031 }
1032 
1033 /// Initialize the given state such that it can be used to execute the current
1034 /// bytecode.
1035 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
1036   state.memory.resize(maxValueMemoryIndex, nullptr);
1037   state.opRangeMemory.resize(maxOpRangeCount);
1038   state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
1039   state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
1040   state.loopIndex.resize(maxLoopLevel, 0);
1041   state.currentPatternBenefits.reserve(patterns.size());
1042   for (const PDLByteCodePattern &pattern : patterns)
1043     state.currentPatternBenefits.push_back(pattern.getBenefit());
1044 }
1045 
1046 //===----------------------------------------------------------------------===//
1047 // ByteCode Execution
1048 
1049 namespace {
/// This class provides support for executing a bytecode stream. It holds the
/// current position in the stream together with references to the mutable
/// execution memory and to the immutable data (code, uniqued values, patterns,
/// external functions) required while executing.
class ByteCodeExecutor {
public:
  ByteCodeExecutor(
      const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
      MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory,
      MutableArrayRef<TypeRange> typeRangeMemory,
      std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
      MutableArrayRef<ValueRange> valueRangeMemory,
      std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
      MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory,
      ArrayRef<ByteCodeField> code,
      ArrayRef<PatternBenefit> currentPatternBenefits,
      ArrayRef<PDLByteCodePattern> patterns,
      ArrayRef<PDLConstraintFunction> constraintFunctions,
      ArrayRef<PDLRewriteFunction> rewriteFunctions)
      : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
        typeRangeMemory(typeRangeMemory),
        allocatedTypeRangeMemory(allocatedTypeRangeMemory),
        valueRangeMemory(valueRangeMemory),
        allocatedValueRangeMemory(allocatedValueRangeMemory),
        loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
        currentPatternBenefits(currentPatternBenefits), patterns(patterns),
        constraintFunctions(constraintFunctions),
        rewriteFunctions(rewriteFunctions) {}

  /// Start executing the code at the current bytecode index. `matches` is an
  /// optional field provided when this function is executed in a matching
  /// context.
  void execute(PatternRewriter &rewriter,
               SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
               Optional<Location> mainRewriteLoc = {});

private:
  /// Internal implementation of executing each of the bytecode commands.
  void executeApplyConstraint(PatternRewriter &rewriter);
  void executeApplyRewrite(PatternRewriter &rewriter);
  void executeAreEqual();
  void executeAreRangesEqual();
  void executeBranch();
  void executeCheckOperandCount();
  void executeCheckOperationName();
  void executeCheckResultCount();
  void executeCheckTypes();
  void executeContinue();
  void executeCreateOperation(PatternRewriter &rewriter,
                              Location mainRewriteLoc);
  void executeCreateTypes();
  void executeEraseOp(PatternRewriter &rewriter);
  template <typename T, typename Range, PDLValue::Kind kind>
  void executeExtract();
  void executeFinalize();
  void executeForEach();
  void executeGetAttribute();
  void executeGetAttributeType();
  void executeGetDefiningOp();
  void executeGetOperand(unsigned index);
  void executeGetOperands();
  void executeGetResult(unsigned index);
  void executeGetResults();
  void executeGetUsers();
  void executeGetValueType();
  void executeGetValueRangeTypes();
  void executeIsNotNull();
  void executeRecordMatch(PatternRewriter &rewriter,
                          SmallVectorImpl<PDLByteCode::MatchResult> &matches);
  void executeReplaceOp(PatternRewriter &rewriter);
  void executeSwitchAttribute();
  void executeSwitchOperandCount();
  void executeSwitchOperationName();
  void executeSwitchResultCount();
  void executeSwitchType();
  void executeSwitchTypes();

  /// Pushes a code iterator to the stack.
  void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }

  /// Pops the most recently pushed code iterator from the stack and makes it
  /// the current position. The stack must not be empty.
  void popCodeIt() {
    assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack");
    curCodeIt = resumeCodeIt.back();
    resumeCodeIt.pop_back();
  }

  /// Return the bytecode iterator at the start of the current op code.
  const ByteCodeField *getPrevCodeIt() const {
    LLVM_DEBUG({
      // Account for the op code and the Location stored inline.
      return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField);
    });

    // Account for the op code only.
    return curCodeIt - 1;
  }

  /// Read a value from the bytecode buffer, optionally skipping a certain
  /// number of prefix values. These methods always update the buffer to point
  /// to the next field after the read data.
  template <typename T = ByteCodeField>
  T read(size_t skipN = 0) {
    curCodeIt += skipN;
    return readImpl<T>();
  }
  ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }

  /// Read a list of values from the bytecode buffer. The list is cleared
  /// first; the element count is read from the buffer, followed by that many
  /// elements read as `ValueT`.
  template <typename ValueT, typename T>
  void readList(SmallVectorImpl<T> &list) {
    list.clear();
    for (unsigned i = 0, e = read(); i != e; ++i)
      list.push_back(read<ValueT>());
  }

  /// Read a list of values from the bytecode buffer. The values may be encoded
  /// as either Value or ValueRange elements. Note that, unlike `readList`, the
  /// values are appended to `list` without clearing it first.
  void readValueList(SmallVectorImpl<Value> &list) {
    for (unsigned i = 0, e = read(); i != e; ++i) {
      if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
        list.push_back(read<Value>());
      } else {
        // Any other kind is read as a range whose elements are flattened into
        // the list.
        ValueRange *values = read<ValueRange *>();
        list.append(values->begin(), values->end());
      }
    }
  }

  /// Read a value stored inline as a pointer.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T>
  readInline() {
    const void *pointer;
    std::memcpy(&pointer, curCodeIt, sizeof(const void *));
    // Advance past the number of fields occupied by one pointer.
    curCodeIt += sizeof(const void *) / sizeof(ByteCodeField);
    return T::getFromOpaquePointer(pointer);
  }

  /// Jump to a specific successor based on a predicate value: successor 0 when
  /// the predicate is true, successor 1 otherwise.
  void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
  /// Jump to a specific successor based on a destination index. Each successor
  /// address occupies two bytecode fields, so `destIndex * 2` fields are
  /// skipped before the address is read.
  void selectJump(size_t destIndex) {
    curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
  }

  /// Handle a switch operation with the provided value and cases.
  template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
  void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
    LLVM_DEBUG({
      llvm::dbgs() << "  * Value: " << value << "\n"
                   << "  * Cases: ";
      llvm::interleaveComma(cases, llvm::dbgs());
      llvm::dbgs() << "\n";
    });

    // Check to see if the value is within the case list. Jump to the correct
    // successor index based on the result: successor 0 is taken when no case
    // matches, and case `i` maps to successor `i + 1`.
    for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
      if (cmp(*it, value))
        return selectJump(size_t((it - cases.begin()) + 1));
    selectJump(size_t(0));
  }

  /// Store a pointer to memory.
  void storeToMemory(unsigned index, const void *value) {
    memory[index] = value;
  }

  /// Store a value to memory as an opaque pointer.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
  storeToMemory(unsigned index, T value) {
    memory[index] = value.getAsOpaquePointer();
  }

  /// Internal implementation of reading various data types from the bytecode
  /// stream. The next field is interpreted as a memory index, and the pointer
  /// stored at that index is returned.
  template <typename T>
  const void *readFromMemory() {
    size_t index = *curCodeIt++;

    // If this type is an SSA value, it can only be stored in non-const memory.
    if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
                        Value>::value ||
        index < memory.size())
      return memory[index];

    // Otherwise, if this index is not inbounds it is uniqued.
    return uniquedMemory[index - memory.size()];
  }
  /// Read a pointer-typed value from memory.
  template <typename T>
  std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
    return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
  }
  /// Read a class-typed value (other than PDLValue) that is stored in memory
  /// as an opaque pointer.
  template <typename T>
  std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
                   T>
  readImpl() {
    return T(T::getFromOpaquePointer(readFromMemory<T>()));
  }
  /// Read a generic PDLValue: the kind is read first and selects the concrete
  /// type of the value read after it.
  template <typename T>
  std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
    switch (read<PDLValue::Kind>()) {
    case PDLValue::Kind::Attribute:
      return read<Attribute>();
    case PDLValue::Kind::Operation:
      return read<Operation *>();
    case PDLValue::Kind::Type:
      return read<Type>();
    case PDLValue::Kind::Value:
      return read<Value>();
    case PDLValue::Kind::TypeRange:
      return read<TypeRange *>();
    case PDLValue::Kind::ValueRange:
      return read<ValueRange *>();
    }
    llvm_unreachable("unhandled PDLValue::Kind");
  }
  /// Read a bytecode address, which is stored inline across two fields.
  template <typename T>
  std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
    static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
                  "unexpected ByteCode address size");
    ByteCodeAddr result;
    std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
    curCodeIt += 2;
    return result;
  }
  /// Read a single raw bytecode field.
  template <typename T>
  std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
    return *curCodeIt++;
  }
  /// Read a PDLValue kind, which is stored as a single raw field.
  template <typename T>
  std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
    return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
  }

  /// The current position within the underlying bytecode buffer.
  const ByteCodeField *curCodeIt;

  /// The stack of bytecode positions at which to resume operation.
  SmallVector<const ByteCodeField *> resumeCodeIt;

  /// The current execution memory.
  MutableArrayRef<const void *> memory;
  MutableArrayRef<OwningOpRange> opRangeMemory;
  MutableArrayRef<TypeRange> typeRangeMemory;
  std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
  MutableArrayRef<ValueRange> valueRangeMemory;
  std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;

  /// The current loop indices.
  MutableArrayRef<unsigned> loopIndex;

  /// References to ByteCode data necessary for execution.
  ArrayRef<const void *> uniquedMemory;
  ArrayRef<ByteCodeField> code;
  ArrayRef<PatternBenefit> currentPatternBenefits;
  ArrayRef<PDLByteCodePattern> patterns;
  ArrayRef<PDLConstraintFunction> constraintFunctions;
  ArrayRef<PDLRewriteFunction> rewriteFunctions;
};
1309 
1310 /// This class is an instantiation of the PDLResultList that provides access to
1311 /// the returned results. This API is not on `PDLResultList` to avoid
1312 /// overexposing access to information specific solely to the ByteCode.
1313 class ByteCodeRewriteResultList : public PDLResultList {
1314 public:
1315   ByteCodeRewriteResultList(unsigned maxNumResults)
1316       : PDLResultList(maxNumResults) {}
1317 
1318   /// Return the list of PDL results.
1319   MutableArrayRef<PDLValue> getResults() { return results; }
1320 
1321   /// Return the type ranges allocated by this list.
1322   MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
1323     return allocatedTypeRanges;
1324   }
1325 
1326   /// Return the value ranges allocated by this list.
1327   MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
1328     return allocatedValueRanges;
1329   }
1330 };
1331 } // namespace
1332 
1333 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1334   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
1335   const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
1336   SmallVector<PDLValue, 16> args;
1337   readList<PDLValue>(args);
1338 
1339   LLVM_DEBUG({
1340     llvm::dbgs() << "  * Arguments: ";
1341     llvm::interleaveComma(args, llvm::dbgs());
1342   });
1343 
1344   // Invoke the constraint and jump to the proper destination.
1345   selectJump(succeeded(constraintFn(rewriter, args)));
1346 }
1347 
1348 void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
1349   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
1350   const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1351   SmallVector<PDLValue, 16> args;
1352   readList<PDLValue>(args);
1353 
1354   LLVM_DEBUG({
1355     llvm::dbgs() << "  * Arguments: ";
1356     llvm::interleaveComma(args, llvm::dbgs());
1357   });
1358 
1359   // Execute the rewrite function.
1360   ByteCodeField numResults = read();
1361   ByteCodeRewriteResultList results(numResults);
1362   rewriteFn(rewriter, results, args);
1363 
1364   assert(results.getResults().size() == numResults &&
1365          "native PDL rewrite function returned unexpected number of results");
1366 
1367   // Store the results in the bytecode memory.
1368   for (PDLValue &result : results.getResults()) {
1369     LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n");
1370 
1371 // In debug mode we also verify the expected kind of the result.
1372 #ifndef NDEBUG
1373     assert(result.getKind() == read<PDLValue::Kind>() &&
1374            "native PDL rewrite function returned an unexpected type of result");
1375 #endif
1376 
1377     // If the result is a range, we need to copy it over to the bytecodes
1378     // range memory.
1379     if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
1380       unsigned rangeIndex = read();
1381       typeRangeMemory[rangeIndex] = *typeRange;
1382       memory[read()] = &typeRangeMemory[rangeIndex];
1383     } else if (Optional<ValueRange> valueRange =
1384                    result.dyn_cast<ValueRange>()) {
1385       unsigned rangeIndex = read();
1386       valueRangeMemory[rangeIndex] = *valueRange;
1387       memory[read()] = &valueRangeMemory[rangeIndex];
1388     } else {
1389       memory[read()] = result.getAsOpaquePointer();
1390     }
1391   }
1392 
1393   // Copy over any underlying storage allocated for result ranges.
1394   for (auto &it : results.getAllocatedTypeRanges())
1395     allocatedTypeRangeMemory.push_back(std::move(it));
1396   for (auto &it : results.getAllocatedValueRanges())
1397     allocatedValueRangeMemory.push_back(std::move(it));
1398 }
1399 
1400 void ByteCodeExecutor::executeAreEqual() {
1401   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1402   const void *lhs = read<const void *>();
1403   const void *rhs = read<const void *>();
1404 
1405   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n");
1406   selectJump(lhs == rhs);
1407 }
1408 
1409 void ByteCodeExecutor::executeAreRangesEqual() {
1410   LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
1411   PDLValue::Kind valueKind = read<PDLValue::Kind>();
1412   const void *lhs = read<const void *>();
1413   const void *rhs = read<const void *>();
1414 
1415   switch (valueKind) {
1416   case PDLValue::Kind::TypeRange: {
1417     const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
1418     const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
1419     LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1420     selectJump(*lhsRange == *rhsRange);
1421     break;
1422   }
1423   case PDLValue::Kind::ValueRange: {
1424     const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
1425     const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
1426     LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1427     selectJump(*lhsRange == *rhsRange);
1428     break;
1429   }
1430   default:
1431     llvm_unreachable("unexpected `AreRangesEqual` value kind");
1432   }
1433 }
1434 
1435 void ByteCodeExecutor::executeBranch() {
1436   LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
1437   curCodeIt = &code[read<ByteCodeAddr>()];
1438 }
1439 
1440 void ByteCodeExecutor::executeCheckOperandCount() {
1441   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
1442   Operation *op = read<Operation *>();
1443   uint32_t expectedCount = read<uint32_t>();
1444   bool compareAtLeast = read();
1445 
1446   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumOperands() << "\n"
1447                           << "  * Expected: " << expectedCount << "\n"
1448                           << "  * Comparator: "
1449                           << (compareAtLeast ? ">=" : "==") << "\n");
1450   if (compareAtLeast)
1451     selectJump(op->getNumOperands() >= expectedCount);
1452   else
1453     selectJump(op->getNumOperands() == expectedCount);
1454 }
1455 
1456 void ByteCodeExecutor::executeCheckOperationName() {
1457   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
1458   Operation *op = read<Operation *>();
1459   OperationName expectedName = read<OperationName>();
1460 
1461   LLVM_DEBUG(llvm::dbgs() << "  * Found: \"" << op->getName() << "\"\n"
1462                           << "  * Expected: \"" << expectedName << "\"\n");
1463   selectJump(op->getName() == expectedName);
1464 }
1465 
1466 void ByteCodeExecutor::executeCheckResultCount() {
1467   LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
1468   Operation *op = read<Operation *>();
1469   uint32_t expectedCount = read<uint32_t>();
1470   bool compareAtLeast = read();
1471 
1472   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumResults() << "\n"
1473                           << "  * Expected: " << expectedCount << "\n"
1474                           << "  * Comparator: "
1475                           << (compareAtLeast ? ">=" : "==") << "\n");
1476   if (compareAtLeast)
1477     selectJump(op->getNumResults() >= expectedCount);
1478   else
1479     selectJump(op->getNumResults() == expectedCount);
1480 }
1481 
1482 void ByteCodeExecutor::executeCheckTypes() {
1483   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1484   TypeRange *lhs = read<TypeRange *>();
1485   Attribute rhs = read<Attribute>();
1486   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1487 
1488   selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>());
1489 }
1490 
1491 void ByteCodeExecutor::executeContinue() {
1492   ByteCodeField level = read();
1493   LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n"
1494                           << "  * Level: " << level << "\n");
1495   ++loopIndex[level];
1496   popCodeIt();
1497 }
1498 
1499 void ByteCodeExecutor::executeCreateTypes() {
1500   LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n");
1501   unsigned memIndex = read();
1502   unsigned rangeIndex = read();
1503   ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>();
1504 
1505   LLVM_DEBUG(llvm::dbgs() << "  * Types: " << typesAttr << "\n\n");
1506 
1507   // Allocate a buffer for this type range.
1508   llvm::OwningArrayRef<Type> storage(typesAttr.size());
1509   llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin());
1510   allocatedTypeRangeMemory.emplace_back(std::move(storage));
1511 
1512   // Assign this to the range slot and use the range as the value for the
1513   // memory index.
1514   typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back();
1515   memory[memIndex] = &typeRangeMemory[rangeIndex];
1516 }
1517 
1518 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
1519                                               Location mainRewriteLoc) {
1520   LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
1521 
1522   unsigned memIndex = read();
1523   OperationState state(mainRewriteLoc, read<OperationName>());
1524   readValueList(state.operands);
1525   for (unsigned i = 0, e = read(); i != e; ++i) {
1526     StringAttr name = read<StringAttr>();
1527     if (Attribute attr = read<Attribute>())
1528       state.addAttribute(name, attr);
1529   }
1530 
1531   // Read in the result types. If the "size" is the sentinel value, this
1532   // indicates that the result types should be inferred.
1533   unsigned numResults = read();
1534   if (numResults == kInferTypesMarker) {
1535     InferTypeOpInterface::Concept *inferInterface =
1536         state.name.getRegisteredInfo()->getInterface<InferTypeOpInterface>();
1537     assert(inferInterface &&
1538            "expected operation to provide InferTypeOpInterface");
1539 
1540     // TODO: Handle failure.
1541     if (failed(inferInterface->inferReturnTypes(
1542             state.getContext(), state.location, state.operands,
1543             state.attributes.getDictionary(state.getContext()), state.regions,
1544             state.types)))
1545       return;
1546   } else {
1547     // Otherwise, this is a fixed number of results.
1548     for (unsigned i = 0; i != numResults; ++i) {
1549       if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1550         state.types.push_back(read<Type>());
1551       } else {
1552         TypeRange *resultTypes = read<TypeRange *>();
1553         state.types.append(resultTypes->begin(), resultTypes->end());
1554       }
1555     }
1556   }
1557 
1558   Operation *resultOp = rewriter.create(state);
1559   memory[memIndex] = resultOp;
1560 
1561   LLVM_DEBUG({
1562     llvm::dbgs() << "  * Attributes: "
1563                  << state.attributes.getDictionary(state.getContext())
1564                  << "\n  * Operands: ";
1565     llvm::interleaveComma(state.operands, llvm::dbgs());
1566     llvm::dbgs() << "\n  * Result Types: ";
1567     llvm::interleaveComma(state.types, llvm::dbgs());
1568     llvm::dbgs() << "\n  * Result: " << *resultOp << "\n";
1569   });
1570 }
1571 
1572 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1573   LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1574   Operation *op = read<Operation *>();
1575 
1576   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1577   rewriter.eraseOp(op);
1578 }
1579 
1580 template <typename T, typename Range, PDLValue::Kind kind>
1581 void ByteCodeExecutor::executeExtract() {
1582   LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n");
1583   Range *range = read<Range *>();
1584   unsigned index = read<uint32_t>();
1585   unsigned memIndex = read();
1586 
1587   if (!range) {
1588     memory[memIndex] = nullptr;
1589     return;
1590   }
1591 
1592   T result = index < range->size() ? (*range)[index] : T();
1593   LLVM_DEBUG(llvm::dbgs() << "  * " << kind << "s(" << range->size() << ")\n"
1594                           << "  * Index: " << index << "\n"
1595                           << "  * Result: " << result << "\n");
1596   storeToMemory(memIndex, result);
1597 }
1598 
1599 void ByteCodeExecutor::executeFinalize() {
1600   LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n");
1601 }
1602 
1603 void ByteCodeExecutor::executeForEach() {
1604   LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n");
1605   const ByteCodeField *prevCodeIt = getPrevCodeIt();
1606   unsigned rangeIndex = read();
1607   unsigned memIndex = read();
1608   const void *value = nullptr;
1609 
1610   switch (read<PDLValue::Kind>()) {
1611   case PDLValue::Kind::Operation: {
1612     unsigned &index = loopIndex[read()];
1613     ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
1614     assert(index <= array.size() && "iterated past the end");
1615     if (index < array.size()) {
1616       LLVM_DEBUG(llvm::dbgs() << "  * Result: " << array[index] << "\n");
1617       value = array[index];
1618       break;
1619     }
1620 
1621     LLVM_DEBUG(llvm::dbgs() << "  * Done\n");
1622     index = 0;
1623     selectJump(size_t(0));
1624     return;
1625   }
1626   default:
1627     llvm_unreachable("unexpected `ForEach` value kind");
1628   }
1629 
1630   // Store the iterate value and the stack address.
1631   memory[memIndex] = value;
1632   pushCodeIt(prevCodeIt);
1633 
1634   // Skip over the successor (we will enter the body of the loop).
1635   read<ByteCodeAddr>();
1636 }
1637 
1638 void ByteCodeExecutor::executeGetAttribute() {
1639   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1640   unsigned memIndex = read();
1641   Operation *op = read<Operation *>();
1642   StringAttr attrName = read<StringAttr>();
1643   Attribute attr = op->getAttr(attrName);
1644 
1645   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1646                           << "  * Attribute: " << attrName << "\n"
1647                           << "  * Result: " << attr << "\n");
1648   memory[memIndex] = attr.getAsOpaquePointer();
1649 }
1650 
1651 void ByteCodeExecutor::executeGetAttributeType() {
1652   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1653   unsigned memIndex = read();
1654   Attribute attr = read<Attribute>();
1655   Type type;
1656   if (auto typedAttr = attr.dyn_cast<TypedAttr>())
1657     type = typedAttr.getType();
1658 
1659   LLVM_DEBUG(llvm::dbgs() << "  * Attribute: " << attr << "\n"
1660                           << "  * Result: " << type << "\n");
1661   memory[memIndex] = type.getAsOpaquePointer();
1662 }
1663 
1664 void ByteCodeExecutor::executeGetDefiningOp() {
1665   LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1666   unsigned memIndex = read();
1667   Operation *op = nullptr;
1668   if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1669     Value value = read<Value>();
1670     if (value)
1671       op = value.getDefiningOp();
1672     LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1673   } else {
1674     ValueRange *values = read<ValueRange *>();
1675     if (values && !values->empty()) {
1676       op = values->front().getDefiningOp();
1677     }
1678     LLVM_DEBUG(llvm::dbgs() << "  * Values: " << values << "\n");
1679   }
1680 
1681   LLVM_DEBUG(llvm::dbgs() << "  * Result: " << op << "\n");
1682   memory[memIndex] = op;
1683 }
1684 
1685 void ByteCodeExecutor::executeGetOperand(unsigned index) {
1686   Operation *op = read<Operation *>();
1687   unsigned memIndex = read();
1688   Value operand =
1689       index < op->getNumOperands() ? op->getOperand(index) : Value();
1690 
1691   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1692                           << "  * Index: " << index << "\n"
1693                           << "  * Result: " << operand << "\n");
1694   memory[memIndex] = operand.getAsOpaquePointer();
1695 }
1696 
1697 /// This function is the internal implementation of `GetResults` and
1698 /// `GetOperands` that provides support for extracting a value range from the
1699 /// given operation.
1700 template <template <typename> class AttrSizedSegmentsT, typename RangeT>
1701 static void *
1702 executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
1703                           ByteCodeField rangeIndex, StringRef attrSizedSegments,
1704                           MutableArrayRef<ValueRange> valueRangeMemory) {
1705   // Check for the sentinel index that signals that all values should be
1706   // returned.
1707   if (index == std::numeric_limits<uint32_t>::max()) {
1708     LLVM_DEBUG(llvm::dbgs() << "  * Getting all values\n");
1709     // `values` is already the full value range.
1710 
1711     // Otherwise, check to see if this operation uses AttrSizedSegments.
1712   } else if (op->hasTrait<AttrSizedSegmentsT>()) {
1713     LLVM_DEBUG(llvm::dbgs()
1714                << "  * Extracting values from `" << attrSizedSegments << "`\n");
1715 
1716     auto segmentAttr = op->getAttrOfType<DenseElementsAttr>(attrSizedSegments);
1717     if (!segmentAttr || segmentAttr.getNumElements() <= index)
1718       return nullptr;
1719 
1720     auto segments = segmentAttr.getValues<int32_t>();
1721     unsigned startIndex =
1722         std::accumulate(segments.begin(), segments.begin() + index, 0);
1723     values = values.slice(startIndex, *std::next(segments.begin(), index));
1724 
1725     LLVM_DEBUG(llvm::dbgs() << "  * Extracting range[" << startIndex << ", "
1726                             << *std::next(segments.begin(), index) << "]\n");
1727 
1728     // Otherwise, assume this is the last operand group of the operation.
1729     // FIXME: We currently don't support operations with
1730     // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
1731     // have a way to detect it's presence.
1732   } else if (values.size() >= index) {
1733     LLVM_DEBUG(llvm::dbgs()
1734                << "  * Treating values as trailing variadic range\n");
1735     values = values.drop_front(index);
1736 
1737     // If we couldn't detect a way to compute the values, bail out.
1738   } else {
1739     return nullptr;
1740   }
1741 
1742   // If the range index is valid, we are returning a range.
1743   if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
1744     valueRangeMemory[rangeIndex] = values;
1745     return &valueRangeMemory[rangeIndex];
1746   }
1747 
1748   // If a range index wasn't provided, the range is required to be non-variadic.
1749   return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1750 }
1751 
1752 void ByteCodeExecutor::executeGetOperands() {
1753   LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
1754   unsigned index = read<uint32_t>();
1755   Operation *op = read<Operation *>();
1756   ByteCodeField rangeIndex = read();
1757 
1758   void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1759       op->getOperands(), op, index, rangeIndex, "operand_segment_sizes",
1760       valueRangeMemory);
1761   if (!result)
1762     LLVM_DEBUG(llvm::dbgs() << "  * Invalid operand range\n");
1763   memory[read()] = result;
1764 }
1765 
1766 void ByteCodeExecutor::executeGetResult(unsigned index) {
1767   Operation *op = read<Operation *>();
1768   unsigned memIndex = read();
1769   OpResult result =
1770       index < op->getNumResults() ? op->getResult(index) : OpResult();
1771 
1772   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1773                           << "  * Index: " << index << "\n"
1774                           << "  * Result: " << result << "\n");
1775   memory[memIndex] = result.getAsOpaquePointer();
1776 }
1777 
1778 void ByteCodeExecutor::executeGetResults() {
1779   LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
1780   unsigned index = read<uint32_t>();
1781   Operation *op = read<Operation *>();
1782   ByteCodeField rangeIndex = read();
1783 
1784   void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1785       op->getResults(), op, index, rangeIndex, "result_segment_sizes",
1786       valueRangeMemory);
1787   if (!result)
1788     LLVM_DEBUG(llvm::dbgs() << "  * Invalid result range\n");
1789   memory[read()] = result;
1790 }
1791 
1792 void ByteCodeExecutor::executeGetUsers() {
1793   LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n");
1794   unsigned memIndex = read();
1795   unsigned rangeIndex = read();
1796   OwningOpRange &range = opRangeMemory[rangeIndex];
1797   memory[memIndex] = &range;
1798 
1799   range = OwningOpRange();
1800   if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1801     // Read the value.
1802     Value value = read<Value>();
1803     if (!value)
1804       return;
1805     LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1806 
1807     // Extract the users of a single value.
1808     range = OwningOpRange(std::distance(value.user_begin(), value.user_end()));
1809     llvm::copy(value.getUsers(), range.begin());
1810   } else {
1811     // Read a range of values.
1812     ValueRange *values = read<ValueRange *>();
1813     if (!values)
1814       return;
1815     LLVM_DEBUG({
1816       llvm::dbgs() << "  * Values (" << values->size() << "): ";
1817       llvm::interleaveComma(*values, llvm::dbgs());
1818       llvm::dbgs() << "\n";
1819     });
1820 
1821     // Extract all the users of a range of values.
1822     SmallVector<Operation *> users;
1823     for (Value value : *values)
1824       users.append(value.user_begin(), value.user_end());
1825     range = OwningOpRange(users.size());
1826     llvm::copy(users, range.begin());
1827   }
1828 
1829   LLVM_DEBUG(llvm::dbgs() << "  * Result: " << range.size() << " operations\n");
1830 }
1831 
1832 void ByteCodeExecutor::executeGetValueType() {
1833   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1834   unsigned memIndex = read();
1835   Value value = read<Value>();
1836   Type type = value ? value.getType() : Type();
1837 
1838   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n"
1839                           << "  * Result: " << type << "\n");
1840   memory[memIndex] = type.getAsOpaquePointer();
1841 }
1842 
1843 void ByteCodeExecutor::executeGetValueRangeTypes() {
1844   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
1845   unsigned memIndex = read();
1846   unsigned rangeIndex = read();
1847   ValueRange *values = read<ValueRange *>();
1848   if (!values) {
1849     LLVM_DEBUG(llvm::dbgs() << "  * Values: <NULL>\n\n");
1850     memory[memIndex] = nullptr;
1851     return;
1852   }
1853 
1854   LLVM_DEBUG({
1855     llvm::dbgs() << "  * Values (" << values->size() << "): ";
1856     llvm::interleaveComma(*values, llvm::dbgs());
1857     llvm::dbgs() << "\n  * Result: ";
1858     llvm::interleaveComma(values->getType(), llvm::dbgs());
1859     llvm::dbgs() << "\n";
1860   });
1861   typeRangeMemory[rangeIndex] = values->getType();
1862   memory[memIndex] = &typeRangeMemory[rangeIndex];
1863 }
1864 
1865 void ByteCodeExecutor::executeIsNotNull() {
1866   LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
1867   const void *value = read<const void *>();
1868 
1869   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1870   selectJump(value != nullptr);
1871 }
1872 
1873 void ByteCodeExecutor::executeRecordMatch(
1874     PatternRewriter &rewriter,
1875     SmallVectorImpl<PDLByteCode::MatchResult> &matches) {
1876   LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
1877   unsigned patternIndex = read();
1878   PatternBenefit benefit = currentPatternBenefits[patternIndex];
1879   const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1880 
1881   // If the benefit of the pattern is impossible, skip the processing of the
1882   // rest of the pattern.
1883   if (benefit.isImpossibleToMatch()) {
1884     LLVM_DEBUG(llvm::dbgs() << "  * Benefit: Impossible To Match\n");
1885     curCodeIt = dest;
1886     return;
1887   }
1888 
1889   // Create a fused location containing the locations of each of the
1890   // operations used in the match. This will be used as the location for
1891   // created operations during the rewrite that don't already have an
1892   // explicit location set.
1893   unsigned numMatchLocs = read();
1894   SmallVector<Location, 4> matchLocs;
1895   matchLocs.reserve(numMatchLocs);
1896   for (unsigned i = 0; i != numMatchLocs; ++i)
1897     matchLocs.push_back(read<Operation *>()->getLoc());
1898   Location matchLoc = rewriter.getFusedLoc(matchLocs);
1899 
1900   LLVM_DEBUG(llvm::dbgs() << "  * Benefit: " << benefit.getBenefit() << "\n"
1901                           << "  * Location: " << matchLoc << "\n");
1902   matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
1903   PDLByteCode::MatchResult &match = matches.back();
1904 
1905   // Record all of the inputs to the match. If any of the inputs are ranges, we
1906   // will also need to remap the range pointer to memory stored in the match
1907   // state.
1908   unsigned numInputs = read();
1909   match.values.reserve(numInputs);
1910   match.typeRangeValues.reserve(numInputs);
1911   match.valueRangeValues.reserve(numInputs);
1912   for (unsigned i = 0; i < numInputs; ++i) {
1913     switch (read<PDLValue::Kind>()) {
1914     case PDLValue::Kind::TypeRange:
1915       match.typeRangeValues.push_back(*read<TypeRange *>());
1916       match.values.push_back(&match.typeRangeValues.back());
1917       break;
1918     case PDLValue::Kind::ValueRange:
1919       match.valueRangeValues.push_back(*read<ValueRange *>());
1920       match.values.push_back(&match.valueRangeValues.back());
1921       break;
1922     default:
1923       match.values.push_back(read<const void *>());
1924       break;
1925     }
1926   }
1927   curCodeIt = dest;
1928 }
1929 
1930 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
1931   LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
1932   Operation *op = read<Operation *>();
1933   SmallVector<Value, 16> args;
1934   readValueList(args);
1935 
1936   LLVM_DEBUG({
1937     llvm::dbgs() << "  * Operation: " << *op << "\n"
1938                  << "  * Values: ";
1939     llvm::interleaveComma(args, llvm::dbgs());
1940     llvm::dbgs() << "\n";
1941   });
1942   rewriter.replaceOp(op, args);
1943 }
1944 
1945 void ByteCodeExecutor::executeSwitchAttribute() {
1946   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
1947   Attribute value = read<Attribute>();
1948   ArrayAttr cases = read<ArrayAttr>();
1949   handleSwitch(value, cases);
1950 }
1951 
1952 void ByteCodeExecutor::executeSwitchOperandCount() {
1953   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
1954   Operation *op = read<Operation *>();
1955   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1956 
1957   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1958   handleSwitch(op->getNumOperands(), cases);
1959 }
1960 
1961 void ByteCodeExecutor::executeSwitchOperationName() {
1962   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
1963   OperationName value = read<Operation *>()->getName();
1964   size_t caseCount = read();
1965 
1966   // The operation names are stored in-line, so to print them out for
1967   // debugging purposes we need to read the array before executing the
1968   // switch so that we can display all of the possible values.
1969   LLVM_DEBUG({
1970     const ByteCodeField *prevCodeIt = curCodeIt;
1971     llvm::dbgs() << "  * Value: " << value << "\n"
1972                  << "  * Cases: ";
1973     llvm::interleaveComma(
1974         llvm::map_range(llvm::seq<size_t>(0, caseCount),
1975                         [&](size_t) { return read<OperationName>(); }),
1976         llvm::dbgs());
1977     llvm::dbgs() << "\n";
1978     curCodeIt = prevCodeIt;
1979   });
1980 
1981   // Try to find the switch value within any of the cases.
1982   for (size_t i = 0; i != caseCount; ++i) {
1983     if (read<OperationName>() == value) {
1984       curCodeIt += (caseCount - i - 1);
1985       return selectJump(i + 1);
1986     }
1987   }
1988   selectJump(size_t(0));
1989 }
1990 
1991 void ByteCodeExecutor::executeSwitchResultCount() {
1992   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
1993   Operation *op = read<Operation *>();
1994   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1995 
1996   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1997   handleSwitch(op->getNumResults(), cases);
1998 }
1999 
2000 void ByteCodeExecutor::executeSwitchType() {
2001   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
2002   Type value = read<Type>();
2003   auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2004   handleSwitch(value, cases);
2005 }
2006 
2007 void ByteCodeExecutor::executeSwitchTypes() {
2008   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
2009   TypeRange *value = read<TypeRange *>();
2010   auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
2011   if (!value) {
2012     LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n");
2013     return selectJump(size_t(0));
2014   }
2015   handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
2016     return value == caseValue.getAsValueRange<TypeAttr>();
2017   });
2018 }
2019 
2020 void ByteCodeExecutor::execute(
2021     PatternRewriter &rewriter,
2022     SmallVectorImpl<PDLByteCode::MatchResult> *matches,
2023     Optional<Location> mainRewriteLoc) {
2024   while (true) {
2025     // Print the location of the operation being executed.
2026     LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n");
2027 
2028     OpCode opCode = static_cast<OpCode>(read());
2029     switch (opCode) {
2030     case ApplyConstraint:
2031       executeApplyConstraint(rewriter);
2032       break;
2033     case ApplyRewrite:
2034       executeApplyRewrite(rewriter);
2035       break;
2036     case AreEqual:
2037       executeAreEqual();
2038       break;
2039     case AreRangesEqual:
2040       executeAreRangesEqual();
2041       break;
2042     case Branch:
2043       executeBranch();
2044       break;
2045     case CheckOperandCount:
2046       executeCheckOperandCount();
2047       break;
2048     case CheckOperationName:
2049       executeCheckOperationName();
2050       break;
2051     case CheckResultCount:
2052       executeCheckResultCount();
2053       break;
2054     case CheckTypes:
2055       executeCheckTypes();
2056       break;
2057     case Continue:
2058       executeContinue();
2059       break;
2060     case CreateOperation:
2061       executeCreateOperation(rewriter, *mainRewriteLoc);
2062       break;
2063     case CreateTypes:
2064       executeCreateTypes();
2065       break;
2066     case EraseOp:
2067       executeEraseOp(rewriter);
2068       break;
2069     case ExtractOp:
2070       executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
2071       break;
2072     case ExtractType:
2073       executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
2074       break;
2075     case ExtractValue:
2076       executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
2077       break;
2078     case Finalize:
2079       executeFinalize();
2080       LLVM_DEBUG(llvm::dbgs() << "\n");
2081       return;
2082     case ForEach:
2083       executeForEach();
2084       break;
2085     case GetAttribute:
2086       executeGetAttribute();
2087       break;
2088     case GetAttributeType:
2089       executeGetAttributeType();
2090       break;
2091     case GetDefiningOp:
2092       executeGetDefiningOp();
2093       break;
2094     case GetOperand0:
2095     case GetOperand1:
2096     case GetOperand2:
2097     case GetOperand3: {
2098       unsigned index = opCode - GetOperand0;
2099       LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
2100       executeGetOperand(index);
2101       break;
2102     }
2103     case GetOperandN:
2104       LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
2105       executeGetOperand(read<uint32_t>());
2106       break;
2107     case GetOperands:
2108       executeGetOperands();
2109       break;
2110     case GetResult0:
2111     case GetResult1:
2112     case GetResult2:
2113     case GetResult3: {
2114       unsigned index = opCode - GetResult0;
2115       LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
2116       executeGetResult(index);
2117       break;
2118     }
2119     case GetResultN:
2120       LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
2121       executeGetResult(read<uint32_t>());
2122       break;
2123     case GetResults:
2124       executeGetResults();
2125       break;
2126     case GetUsers:
2127       executeGetUsers();
2128       break;
2129     case GetValueType:
2130       executeGetValueType();
2131       break;
2132     case GetValueRangeTypes:
2133       executeGetValueRangeTypes();
2134       break;
2135     case IsNotNull:
2136       executeIsNotNull();
2137       break;
2138     case RecordMatch:
2139       assert(matches &&
2140              "expected matches to be provided when executing the matcher");
2141       executeRecordMatch(rewriter, *matches);
2142       break;
2143     case ReplaceOp:
2144       executeReplaceOp(rewriter);
2145       break;
2146     case SwitchAttribute:
2147       executeSwitchAttribute();
2148       break;
2149     case SwitchOperandCount:
2150       executeSwitchOperandCount();
2151       break;
2152     case SwitchOperationName:
2153       executeSwitchOperationName();
2154       break;
2155     case SwitchResultCount:
2156       executeSwitchResultCount();
2157       break;
2158     case SwitchType:
2159       executeSwitchType();
2160       break;
2161     case SwitchTypes:
2162       executeSwitchTypes();
2163       break;
2164     }
2165     LLVM_DEBUG(llvm::dbgs() << "\n");
2166   }
2167 }
2168 
2169 /// Run the pattern matcher on the given root operation, collecting the matched
2170 /// patterns in `matches`.
2171 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
2172                         SmallVectorImpl<MatchResult> &matches,
2173                         PDLByteCodeMutableState &state) const {
2174   // The first memory slot is always the root operation.
2175   state.memory[0] = op;
2176 
2177   // The matcher function always starts at code address 0.
2178   ByteCodeExecutor executor(
2179       matcherByteCode.data(), state.memory, state.opRangeMemory,
2180       state.typeRangeMemory, state.allocatedTypeRangeMemory,
2181       state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2182       uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
2183       constraintFunctions, rewriteFunctions);
2184   executor.execute(rewriter, &matches);
2185 
2186   // Order the found matches by benefit.
2187   std::stable_sort(matches.begin(), matches.end(),
2188                    [](const MatchResult &lhs, const MatchResult &rhs) {
2189                      return lhs.benefit > rhs.benefit;
2190                    });
2191 }
2192 
2193 /// Run the rewriter of the given pattern on the root operation `op`.
2194 void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
2195                           PDLByteCodeMutableState &state) const {
2196   // The arguments of the rewrite function are stored at the start of the
2197   // memory buffer.
2198   llvm::copy(match.values, state.memory.begin());
2199 
2200   ByteCodeExecutor executor(
2201       &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
2202       state.opRangeMemory, state.typeRangeMemory,
2203       state.allocatedTypeRangeMemory, state.valueRangeMemory,
2204       state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2205       rewriterByteCode, state.currentPatternBenefits, patterns,
2206       constraintFunctions, rewriteFunctions);
2207   executor.execute(rewriter, /*matches=*/nullptr, match.location);
2208 }
2209