xref: /llvm-project/mlir/lib/Rewrite/ByteCode.cpp (revision 0ec318e57b0c7cbc0d9c899390daf3248cbe6a53)
1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MLIR to byte-code generation and the interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "ByteCode.h"
14 #include "mlir/Analysis/Liveness.h"
15 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
16 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/RegionGraphTraits.h"
19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/Format.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include <numeric>
26 #include <optional>
27 
28 #define DEBUG_TYPE "pdl-bytecode"
29 
30 using namespace mlir;
31 using namespace mlir::detail;
32 
33 //===----------------------------------------------------------------------===//
34 // PDLByteCodePattern
35 //===----------------------------------------------------------------------===//
36 
create(pdl_interp::RecordMatchOp matchOp,PDLPatternConfigSet * configSet,ByteCodeAddr rewriterAddr)37 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
38                                               PDLPatternConfigSet *configSet,
39                                               ByteCodeAddr rewriterAddr) {
40   PatternBenefit benefit = matchOp.getBenefit();
41   MLIRContext *ctx = matchOp.getContext();
42 
43   // Collect the set of generated operations.
44   SmallVector<StringRef, 8> generatedOps;
45   if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
46     generatedOps =
47         llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
48 
49   // Check to see if this is pattern matches a specific operation type.
50   if (std::optional<StringRef> rootKind = matchOp.getRootKind())
51     return PDLByteCodePattern(rewriterAddr, configSet, *rootKind, benefit, ctx,
52                               generatedOps);
53   return PDLByteCodePattern(rewriterAddr, configSet, MatchAnyOpTypeTag(),
54                             benefit, ctx, generatedOps);
55 }
56 
57 //===----------------------------------------------------------------------===//
58 // PDLByteCodeMutableState
59 //===----------------------------------------------------------------------===//
60 
61 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
62 /// to the position of the pattern within the range returned by
63 /// `PDLByteCode::getPatterns`.
updatePatternBenefit(unsigned patternIndex,PatternBenefit benefit)64 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
65                                                    PatternBenefit benefit) {
66   currentPatternBenefits[patternIndex] = benefit;
67 }
68 
69 /// Cleanup any allocated state after a full match/rewrite has been completed.
70 /// This method should be called irregardless of whether the match+rewrite was a
71 /// success or not.
cleanupAfterMatchAndRewrite()72 void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() {
73   allocatedTypeRangeMemory.clear();
74   allocatedValueRangeMemory.clear();
75 }
76 
77 //===----------------------------------------------------------------------===//
78 // Bytecode OpCodes
79 //===----------------------------------------------------------------------===//
80 
81 namespace {
82 enum OpCode : ByteCodeField {
83   /// Apply an externally registered constraint.
84   ApplyConstraint,
85   /// Apply an externally registered rewrite.
86   ApplyRewrite,
87   /// Check if two generic values are equal.
88   AreEqual,
89   /// Check if two ranges are equal.
90   AreRangesEqual,
91   /// Unconditional branch.
92   Branch,
93   /// Compare the operand count of an operation with a constant.
94   CheckOperandCount,
95   /// Compare the name of an operation with a constant.
96   CheckOperationName,
97   /// Compare the result count of an operation with a constant.
98   CheckResultCount,
99   /// Compare a range of types to a constant range of types.
100   CheckTypes,
101   /// Continue to the next iteration of a loop.
102   Continue,
103   /// Create a type range from a list of constant types.
104   CreateConstantTypeRange,
105   /// Create an operation.
106   CreateOperation,
107   /// Create a type range from a list of dynamic types.
108   CreateDynamicTypeRange,
109   /// Create a value range.
110   CreateDynamicValueRange,
111   /// Erase an operation.
112   EraseOp,
113   /// Extract the op from a range at the specified index.
114   ExtractOp,
115   /// Extract the type from a range at the specified index.
116   ExtractType,
117   /// Extract the value from a range at the specified index.
118   ExtractValue,
119   /// Terminate a matcher or rewrite sequence.
120   Finalize,
121   /// Iterate over a range of values.
122   ForEach,
123   /// Get a specific attribute of an operation.
124   GetAttribute,
125   /// Get the type of an attribute.
126   GetAttributeType,
127   /// Get the defining operation of a value.
128   GetDefiningOp,
129   /// Get a specific operand of an operation.
130   GetOperand0,
131   GetOperand1,
132   GetOperand2,
133   GetOperand3,
134   GetOperandN,
135   /// Get a specific operand group of an operation.
136   GetOperands,
137   /// Get a specific result of an operation.
138   GetResult0,
139   GetResult1,
140   GetResult2,
141   GetResult3,
142   GetResultN,
143   /// Get a specific result group of an operation.
144   GetResults,
145   /// Get the users of a value or a range of values.
146   GetUsers,
147   /// Get the type of a value.
148   GetValueType,
149   /// Get the types of a value range.
150   GetValueRangeTypes,
151   /// Check if a generic value is not null.
152   IsNotNull,
153   /// Record a successful pattern match.
154   RecordMatch,
155   /// Replace an operation.
156   ReplaceOp,
157   /// Compare an attribute with a set of constants.
158   SwitchAttribute,
159   /// Compare the operand count of an operation with a set of constants.
160   SwitchOperandCount,
161   /// Compare the name of an operation with a set of constants.
162   SwitchOperationName,
163   /// Compare the result count of an operation with a set of constants.
164   SwitchResultCount,
165   /// Compare a type with a set of constants.
166   SwitchType,
167   /// Compare a range of types with a set of constants.
168   SwitchTypes,
169 };
170 } // namespace
171 
172 /// A marker used to indicate if an operation should infer types.
173 static constexpr ByteCodeField kInferTypesMarker =
174     std::numeric_limits<ByteCodeField>::max();
175 
176 //===----------------------------------------------------------------------===//
177 // ByteCode Generation
178 //===----------------------------------------------------------------------===//
179 
180 //===----------------------------------------------------------------------===//
181 // Generator
182 
183 namespace {
184 struct ByteCodeLiveRange;
185 struct ByteCodeWriter;
186 
/// Detection alias: well-formed only when a value of type `ValueT` provides a
/// `getAsOpaquePointer()` method, i.e. when it can be converted to an opaque
/// pointer. The trailing parameter pack is unused.
template <typename ValueT, typename... UnusedTs>
using has_pointer_traits =
    decltype(std::declval<ValueT>().getAsOpaquePointer());
190 
191 /// This class represents the main generator for the pattern bytecode.
192 class Generator {
193 public:
Generator(MLIRContext * ctx,std::vector<const void * > & uniquedData,SmallVectorImpl<ByteCodeField> & matcherByteCode,SmallVectorImpl<ByteCodeField> & rewriterByteCode,SmallVectorImpl<PDLByteCodePattern> & patterns,ByteCodeField & maxValueMemoryIndex,ByteCodeField & maxOpRangeMemoryIndex,ByteCodeField & maxTypeRangeMemoryIndex,ByteCodeField & maxValueRangeMemoryIndex,ByteCodeField & maxLoopLevel,llvm::StringMap<PDLConstraintFunction> & constraintFns,llvm::StringMap<PDLRewriteFunction> & rewriteFns,const DenseMap<Operation *,PDLPatternConfigSet * > & configMap)194   Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
195             SmallVectorImpl<ByteCodeField> &matcherByteCode,
196             SmallVectorImpl<ByteCodeField> &rewriterByteCode,
197             SmallVectorImpl<PDLByteCodePattern> &patterns,
198             ByteCodeField &maxValueMemoryIndex,
199             ByteCodeField &maxOpRangeMemoryIndex,
200             ByteCodeField &maxTypeRangeMemoryIndex,
201             ByteCodeField &maxValueRangeMemoryIndex,
202             ByteCodeField &maxLoopLevel,
203             llvm::StringMap<PDLConstraintFunction> &constraintFns,
204             llvm::StringMap<PDLRewriteFunction> &rewriteFns,
205             const DenseMap<Operation *, PDLPatternConfigSet *> &configMap)
206       : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
207         rewriterByteCode(rewriterByteCode), patterns(patterns),
208         maxValueMemoryIndex(maxValueMemoryIndex),
209         maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
210         maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
211         maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
212         maxLoopLevel(maxLoopLevel), configMap(configMap) {
213     for (const auto &it : llvm::enumerate(constraintFns))
214       constraintToMemIndex.try_emplace(it.value().first(), it.index());
215     for (const auto &it : llvm::enumerate(rewriteFns))
216       externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
217   }
218 
219   /// Generate the bytecode for the given PDL interpreter module.
220   void generate(ModuleOp module);
221 
222   /// Return the memory index to use for the given value.
getMemIndex(Value value)223   ByteCodeField &getMemIndex(Value value) {
224     assert(valueToMemIndex.count(value) &&
225            "expected memory index to be assigned");
226     return valueToMemIndex[value];
227   }
228 
229   /// Return the range memory index used to store the given range value.
getRangeStorageIndex(Value value)230   ByteCodeField &getRangeStorageIndex(Value value) {
231     assert(valueToRangeIndex.count(value) &&
232            "expected range index to be assigned");
233     return valueToRangeIndex[value];
234   }
235 
236   /// Return an index to use when referring to the given data that is uniqued in
237   /// the MLIR context.
238   template <typename T>
239   std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
getMemIndex(T val)240   getMemIndex(T val) {
241     const void *opaqueVal = val.getAsOpaquePointer();
242 
243     // Get or insert a reference to this value.
244     auto it = uniquedDataToMemIndex.try_emplace(
245         opaqueVal, maxValueMemoryIndex + uniquedData.size());
246     if (it.second)
247       uniquedData.push_back(opaqueVal);
248     return it.first->second;
249   }
250 
251 private:
252   /// Allocate memory indices for the results of operations within the matcher
253   /// and rewriters.
254   void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
255                              ModuleOp rewriterModule);
256 
257   /// Generate the bytecode for the given operation.
258   void generate(Region *region, ByteCodeWriter &writer);
259   void generate(Operation *op, ByteCodeWriter &writer);
260   void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
261   void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
262   void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
263   void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
264   void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
265   void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
266   void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
267   void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
268   void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
269   void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
270   void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
271   void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
272   void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
273   void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer);
274   void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
275   void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
276   void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
277   void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
278   void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
279   void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
280   void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
281   void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
282   void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
283   void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
284   void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
285   void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
286   void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
287   void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
288   void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
289   void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
290   void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
291   void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
292   void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
293   void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
294   void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
295   void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
296   void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
297   void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
298 
299   /// Mapping from value to its corresponding memory index.
300   DenseMap<Value, ByteCodeField> valueToMemIndex;
301 
302   /// Mapping from a range value to its corresponding range storage index.
303   DenseMap<Value, ByteCodeField> valueToRangeIndex;
304 
305   /// Mapping from the name of an externally registered rewrite to its index in
306   /// the bytecode registry.
307   llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
308 
309   /// Mapping from the name of an externally registered constraint to its index
310   /// in the bytecode registry.
311   llvm::StringMap<ByteCodeField> constraintToMemIndex;
312 
313   /// Mapping from rewriter function name to the bytecode address of the
314   /// rewriter function in byte.
315   llvm::StringMap<ByteCodeAddr> rewriterToAddr;
316 
317   /// Mapping from a uniqued storage object to its memory index within
318   /// `uniquedData`.
319   DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
320 
321   /// The current level of the foreach loop.
322   ByteCodeField curLoopLevel = 0;
323 
324   /// The current MLIR context.
325   MLIRContext *ctx;
326 
327   /// Mapping from block to its address.
328   DenseMap<Block *, ByteCodeAddr> blockToAddr;
329 
330   /// Data of the ByteCode class to be populated.
331   std::vector<const void *> &uniquedData;
332   SmallVectorImpl<ByteCodeField> &matcherByteCode;
333   SmallVectorImpl<ByteCodeField> &rewriterByteCode;
334   SmallVectorImpl<PDLByteCodePattern> &patterns;
335   ByteCodeField &maxValueMemoryIndex;
336   ByteCodeField &maxOpRangeMemoryIndex;
337   ByteCodeField &maxTypeRangeMemoryIndex;
338   ByteCodeField &maxValueRangeMemoryIndex;
339   ByteCodeField &maxLoopLevel;
340 
341   /// A map of pattern configurations.
342   const DenseMap<Operation *, PDLPatternConfigSet *> &configMap;
343 };
344 
345 /// This class provides utilities for writing a bytecode stream.
346 struct ByteCodeWriter {
ByteCodeWriter__anon22ebf8dc0211::ByteCodeWriter347   ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
348       : bytecode(bytecode), generator(generator) {}
349 
350   /// Append a field to the bytecode.
append__anon22ebf8dc0211::ByteCodeWriter351   void append(ByteCodeField field) { bytecode.push_back(field); }
append__anon22ebf8dc0211::ByteCodeWriter352   void append(OpCode opCode) { bytecode.push_back(opCode); }
353 
354   /// Append an address to the bytecode.
append__anon22ebf8dc0211::ByteCodeWriter355   void append(ByteCodeAddr field) {
356     static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
357                   "unexpected ByteCode address size");
358 
359     ByteCodeField fieldParts[2];
360     std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
361     bytecode.append({fieldParts[0], fieldParts[1]});
362   }
363 
364   /// Append a single successor to the bytecode, the exact address will need to
365   /// be resolved later.
append__anon22ebf8dc0211::ByteCodeWriter366   void append(Block *successor) {
367     // Add back a reference to the successor so that the address can be resolved
368     // later.
369     unresolvedSuccessorRefs[successor].push_back(bytecode.size());
370     append(ByteCodeAddr(0));
371   }
372 
373   /// Append a successor range to the bytecode, the exact address will need to
374   /// be resolved later.
append__anon22ebf8dc0211::ByteCodeWriter375   void append(SuccessorRange successors) {
376     for (Block *successor : successors)
377       append(successor);
378   }
379 
380   /// Append a range of values that will be read as generic PDLValues.
appendPDLValueList__anon22ebf8dc0211::ByteCodeWriter381   void appendPDLValueList(OperandRange values) {
382     bytecode.push_back(values.size());
383     for (Value value : values)
384       appendPDLValue(value);
385   }
386 
387   /// Append a value as a PDLValue.
appendPDLValue__anon22ebf8dc0211::ByteCodeWriter388   void appendPDLValue(Value value) {
389     appendPDLValueKind(value);
390     append(value);
391   }
392 
393   /// Append the PDLValue::Kind of the given value.
appendPDLValueKind__anon22ebf8dc0211::ByteCodeWriter394   void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); }
395 
396   /// Append the PDLValue::Kind of the given type.
appendPDLValueKind__anon22ebf8dc0211::ByteCodeWriter397   void appendPDLValueKind(Type type) {
398     PDLValue::Kind kind =
399         TypeSwitch<Type, PDLValue::Kind>(type)
400             .Case<pdl::AttributeType>(
401                 [](Type) { return PDLValue::Kind::Attribute; })
402             .Case<pdl::OperationType>(
403                 [](Type) { return PDLValue::Kind::Operation; })
404             .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
405               if (isa<pdl::TypeType>(rangeTy.getElementType()))
406                 return PDLValue::Kind::TypeRange;
407               return PDLValue::Kind::ValueRange;
408             })
409             .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
410             .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
411     bytecode.push_back(static_cast<ByteCodeField>(kind));
412   }
413 
414   /// Append a value that will be stored in a memory slot and not inline within
415   /// the bytecode.
416   template <typename T>
417   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
418                    std::is_pointer<T>::value>
append__anon22ebf8dc0211::ByteCodeWriter419   append(T value) {
420     bytecode.push_back(generator.getMemIndex(value));
421   }
422 
423   /// Append a range of values.
424   template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
425   std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
append__anon22ebf8dc0211::ByteCodeWriter426   append(T range) {
427     bytecode.push_back(llvm::size(range));
428     for (auto it : range)
429       append(it);
430   }
431 
432   /// Append a variadic number of fields to the bytecode.
433   template <typename FieldTy, typename Field2Ty, typename... FieldTys>
append__anon22ebf8dc0211::ByteCodeWriter434   void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
435     append(field);
436     append(field2, fields...);
437   }
438 
439   /// Appends a value as a pointer, stored inline within the bytecode.
440   template <typename T>
441   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
appendInline__anon22ebf8dc0211::ByteCodeWriter442   appendInline(T value) {
443     constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField);
444     const void *pointer = value.getAsOpaquePointer();
445     ByteCodeField fieldParts[numParts];
446     std::memcpy(fieldParts, &pointer, sizeof(const void *));
447     bytecode.append(fieldParts, fieldParts + numParts);
448   }
449 
450   /// Successor references in the bytecode that have yet to be resolved.
451   DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
452 
453   /// The underlying bytecode buffer.
454   SmallVectorImpl<ByteCodeField> &bytecode;
455 
456   /// The main generator producing PDL.
457   Generator &generator;
458 };
459 
460 /// This class represents a live range of PDL Interpreter values, containing
461 /// information about when values are live within a match/rewrite.
462 struct ByteCodeLiveRange {
463   using Set = llvm::IntervalMap<uint64_t, char, 16>;
464   using Allocator = Set::Allocator;
465 
ByteCodeLiveRange__anon22ebf8dc0211::ByteCodeLiveRange466   ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
467 
468   /// Union this live range with the one provided.
unionWith__anon22ebf8dc0211::ByteCodeLiveRange469   void unionWith(const ByteCodeLiveRange &rhs) {
470     for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
471          ++it)
472       liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0);
473   }
474 
475   /// Returns true if this range overlaps with the one provided.
overlaps__anon22ebf8dc0211::ByteCodeLiveRange476   bool overlaps(const ByteCodeLiveRange &rhs) const {
477     return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
478         .valid();
479   }
480 
481   /// A map representing the ranges of the match/rewrite that a value is live in
482   /// the interpreter.
483   ///
484   /// We use std::unique_ptr here, because IntervalMap does not provide a
485   /// correct copy or move constructor. We can eliminate the pointer once
486   /// https://reviews.llvm.org/D113240 lands.
487   std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
488 
489   /// The operation range storage index for this range.
490   std::optional<unsigned> opRangeIndex;
491 
492   /// The type range storage index for this range.
493   std::optional<unsigned> typeRangeIndex;
494 
495   /// The value range storage index for this range.
496   std::optional<unsigned> valueRangeIndex;
497 };
498 } // namespace
499 
generate(ModuleOp module)500 void Generator::generate(ModuleOp module) {
501   auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
502       pdl_interp::PDLInterpDialect::getMatcherFunctionName());
503   ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
504       pdl_interp::PDLInterpDialect::getRewriterModuleName());
505   assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
506 
507   // Allocate memory indices for the results of operations within the matcher
508   // and rewriters.
509   allocateMemoryIndices(matcherFunc, rewriterModule);
510 
511   // Generate code for the rewriter functions.
512   ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
513   for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
514     rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
515     for (Operation &op : rewriterFunc.getOps())
516       generate(&op, rewriterByteCodeWriter);
517   }
518   assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
519          "unexpected branches in rewriter function");
520 
521   // Generate code for the matcher function.
522   ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
523   generate(&matcherFunc.getBody(), matcherByteCodeWriter);
524 
525   // Resolve successor references in the matcher.
526   for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
527     ByteCodeAddr addr = blockToAddr[it.first];
528     for (unsigned offsetToFix : it.second)
529       std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
530   }
531 }
532 
allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,ModuleOp rewriterModule)533 void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
534                                       ModuleOp rewriterModule) {
  // Rewriters use a simplistic allocation scheme that simply assigns an index
  // to each argument and result.
537   for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
538     ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
539     auto processRewriterValue = [&](Value val) {
540       valueToMemIndex.try_emplace(val, index++);
541       if (pdl::RangeType rangeType = dyn_cast<pdl::RangeType>(val.getType())) {
542         Type elementTy = rangeType.getElementType();
543         if (isa<pdl::TypeType>(elementTy))
544           valueToRangeIndex.try_emplace(val, typeRangeIndex++);
545         else if (isa<pdl::ValueType>(elementTy))
546           valueToRangeIndex.try_emplace(val, valueRangeIndex++);
547       }
548     };
549 
550     for (BlockArgument arg : rewriterFunc.getArguments())
551       processRewriterValue(arg);
552     rewriterFunc.getBody().walk([&](Operation *op) {
553       for (Value result : op->getResults())
554         processRewriterValue(result);
555     });
556     if (index > maxValueMemoryIndex)
557       maxValueMemoryIndex = index;
558     if (typeRangeIndex > maxTypeRangeMemoryIndex)
559       maxTypeRangeMemoryIndex = typeRangeIndex;
560     if (valueRangeIndex > maxValueRangeMemoryIndex)
561       maxValueRangeMemoryIndex = valueRangeIndex;
562   }
563 
564   // The matcher function uses a more sophisticated numbering that tries to
565   // minimize the number of memory indices assigned. This is done by determining
566   // a live range of the values within the matcher, then the allocation is just
567   // finding the minimal number of overlapping live ranges. This is essentially
568   // a simplified form of register allocation where we don't necessarily have a
569   // limited number of registers, but we still want to minimize the number used.
570   DenseMap<Operation *, unsigned> opToFirstIndex;
571   DenseMap<Operation *, unsigned> opToLastIndex;
572 
573   // A custom walk that marks the first and the last index of each operation.
574   // The entry marks the beginning of the liveness range for this operation,
575   // followed by nested operations, followed by the end of the liveness range.
576   unsigned index = 0;
577   llvm::unique_function<void(Operation *)> walk = [&](Operation *op) {
578     opToFirstIndex.try_emplace(op, index++);
579     for (Region &region : op->getRegions())
580       for (Block &block : region.getBlocks())
581         for (Operation &nested : block)
582           walk(&nested);
583     opToLastIndex.try_emplace(op, index++);
584   };
585   walk(matcherFunc);
586 
587   // Liveness info for each of the defs within the matcher.
588   ByteCodeLiveRange::Allocator allocator;
589   DenseMap<Value, ByteCodeLiveRange> valueDefRanges;
590 
591   // Assign the root operation being matched to slot 0.
592   BlockArgument rootOpArg = matcherFunc.getArgument(0);
593   valueToMemIndex[rootOpArg] = 0;
594 
595   // Walk each of the blocks, computing the def interval that the value is used.
596   Liveness matcherLiveness(matcherFunc);
597   matcherFunc->walk([&](Block *block) {
598     const LivenessBlockInfo *info = matcherLiveness.getLiveness(block);
599     assert(info && "expected liveness info for block");
600     auto processValue = [&](Value value, Operation *firstUseOrDef) {
601       // We don't need to process the root op argument, this value is always
602       // assigned to the first memory slot.
603       if (value == rootOpArg)
604         return;
605 
606       // Set indices for the range of this block that the value is used.
607       auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
608       defRangeIt->second.liveness->insert(
609           opToFirstIndex[firstUseOrDef],
610           opToLastIndex[info->getEndOperation(value, firstUseOrDef)],
611           /*dummyValue*/ 0);
612 
613       // Check to see if this value is a range type.
614       if (auto rangeTy = dyn_cast<pdl::RangeType>(value.getType())) {
615         Type eleType = rangeTy.getElementType();
616         if (isa<pdl::OperationType>(eleType))
617           defRangeIt->second.opRangeIndex = 0;
618         else if (isa<pdl::TypeType>(eleType))
619           defRangeIt->second.typeRangeIndex = 0;
620         else if (isa<pdl::ValueType>(eleType))
621           defRangeIt->second.valueRangeIndex = 0;
622       }
623     };
624 
625     // Process the live-ins of this block.
626     for (Value liveIn : info->in()) {
627       // Only process the value if it has been defined in the current region.
628       // Other values that span across pdl_interp.foreach will be added higher
      // up. This ensures that we keep them alive for the entire duration
630       // of the loop.
631       if (liveIn.getParentRegion() == block->getParent())
632         processValue(liveIn, &block->front());
633     }
634 
635     // Process the block arguments for the entry block (those are not live-in).
636     if (block->isEntryBlock()) {
637       for (Value argument : block->getArguments())
638         processValue(argument, &block->front());
639     }
640 
641     // Process any new defs within this block.
642     for (Operation &op : *block)
643       for (Value result : op.getResults())
644         processValue(result, &op);
645   });
646 
647   // Greedily allocate memory slots using the computed def live ranges.
648   std::vector<ByteCodeLiveRange> allocatedIndices;
649 
650   // The number of memory indices currently allocated (and its next value).
651   // Recall that the root gets allocated memory index 0.
652   ByteCodeField numIndices = 1;
653 
654   // The number of memory ranges of various types (and their next values).
655   ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
656 
657   for (auto &defIt : valueDefRanges) {
658     ByteCodeField &memIndex = valueToMemIndex[defIt.first];
659     ByteCodeLiveRange &defRange = defIt.second;
660 
661     // Try to allocate to an existing index.
662     for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) {
663       ByteCodeLiveRange &existingRange = existingIndexIt.value();
664       if (!defRange.overlaps(existingRange)) {
665         existingRange.unionWith(defRange);
666         memIndex = existingIndexIt.index() + 1;
667 
668         if (defRange.opRangeIndex) {
669           if (!existingRange.opRangeIndex)
670             existingRange.opRangeIndex = numOpRanges++;
671           valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
672         } else if (defRange.typeRangeIndex) {
673           if (!existingRange.typeRangeIndex)
674             existingRange.typeRangeIndex = numTypeRanges++;
675           valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
676         } else if (defRange.valueRangeIndex) {
677           if (!existingRange.valueRangeIndex)
678             existingRange.valueRangeIndex = numValueRanges++;
679           valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
680         }
681         break;
682       }
683     }
684 
685     // If no existing index could be used, add a new one.
686     if (memIndex == 0) {
687       allocatedIndices.emplace_back(allocator);
688       ByteCodeLiveRange &newRange = allocatedIndices.back();
689       newRange.unionWith(defRange);
690 
691       // Allocate an index for op/type/value ranges.
692       if (defRange.opRangeIndex) {
693         newRange.opRangeIndex = numOpRanges;
694         valueToRangeIndex[defIt.first] = numOpRanges++;
695       } else if (defRange.typeRangeIndex) {
696         newRange.typeRangeIndex = numTypeRanges;
697         valueToRangeIndex[defIt.first] = numTypeRanges++;
698       } else if (defRange.valueRangeIndex) {
699         newRange.valueRangeIndex = numValueRanges;
700         valueToRangeIndex[defIt.first] = numValueRanges++;
701       }
702 
703       memIndex = allocatedIndices.size();
704       ++numIndices;
705     }
706   }
707 
708   // Print the index usage and ensure that we did not run out of index space.
709   LLVM_DEBUG({
710     llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices "
711                  << "(down from initial " << valueDefRanges.size() << ").\n";
712   });
713   assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
714          "Ran out of memory for allocated indices");
715 
716   // Update the max number of indices.
717   if (numIndices > maxValueMemoryIndex)
718     maxValueMemoryIndex = numIndices;
719   if (numOpRanges > maxOpRangeMemoryIndex)
720     maxOpRangeMemoryIndex = numOpRanges;
721   if (numTypeRanges > maxTypeRangeMemoryIndex)
722     maxTypeRangeMemoryIndex = numTypeRanges;
723   if (numValueRanges > maxValueRangeMemoryIndex)
724     maxValueRangeMemoryIndex = numValueRanges;
725 }
726 
generate(Region * region,ByteCodeWriter & writer)727 void Generator::generate(Region *region, ByteCodeWriter &writer) {
728   llvm::ReversePostOrderTraversal<Region *> rpot(region);
729   for (Block *block : rpot) {
730     // Keep track of where this block begins within the matcher function.
731     blockToAddr.try_emplace(block, matcherByteCode.size());
732     for (Operation &op : *block)
733       generate(&op, writer);
734   }
735 }
736 
generate(Operation * op,ByteCodeWriter & writer)737 void Generator::generate(Operation *op, ByteCodeWriter &writer) {
738   LLVM_DEBUG({
739     // The following list must contain all the operations that do not
740     // produce any bytecode.
741     if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
742       writer.appendInline(op->getLoc());
743   });
744   TypeSwitch<Operation *>(op)
745       .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
746             pdl_interp::AreEqualOp, pdl_interp::BranchOp,
747             pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
748             pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
749             pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
750             pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
751             pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp,
752             pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
753             pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
754             pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
755             pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
756             pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
757             pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
758             pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
759             pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
760             pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
761             pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
762             pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
763             pdl_interp::SwitchResultCountOp>(
764           [&](auto interpOp) { this->generate(interpOp, writer); })
765       .Default([](Operation *) {
766         llvm_unreachable("unknown `pdl_interp` operation");
767       });
768 }
769 
generate(pdl_interp::ApplyConstraintOp op,ByteCodeWriter & writer)770 void Generator::generate(pdl_interp::ApplyConstraintOp op,
771                          ByteCodeWriter &writer) {
772   // Constraints that should return a value have to be registered as rewrites.
773   // If a constraint and a rewrite of similar name are registered the
774   // constraint takes precedence
775   writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
776   writer.appendPDLValueList(op.getArgs());
777   writer.append(ByteCodeField(op.getIsNegated()));
778   ResultRange results = op.getResults();
779   writer.append(ByteCodeField(results.size()));
780   for (Value result : results) {
781     // We record the expected kind of the result, so that we can provide extra
782     // verification of the native rewrite function and handle the failure case
783     // of constraints accordingly.
784     writer.appendPDLValueKind(result);
785 
786     // Range results also need to append the range storage index.
787     if (isa<pdl::RangeType>(result.getType()))
788       writer.append(getRangeStorageIndex(result));
789     writer.append(result);
790   }
791   writer.append(op.getSuccessors());
792 }
generate(pdl_interp::ApplyRewriteOp op,ByteCodeWriter & writer)793 void Generator::generate(pdl_interp::ApplyRewriteOp op,
794                          ByteCodeWriter &writer) {
795   assert(externalRewriterToMemIndex.count(op.getName()) &&
796          "expected index for rewrite function");
797   writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
798   writer.appendPDLValueList(op.getArgs());
799 
800   ResultRange results = op.getResults();
801   writer.append(ByteCodeField(results.size()));
802   for (Value result : results) {
803     // We record the expected kind of the result, so that we
804     // can provide extra verification of the native rewrite function.
805     writer.appendPDLValueKind(result);
806 
807     // Range results also need to append the range storage index.
808     if (isa<pdl::RangeType>(result.getType()))
809       writer.append(getRangeStorageIndex(result));
810     writer.append(result);
811   }
812 }
generate(pdl_interp::AreEqualOp op,ByteCodeWriter & writer)813 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
814   Value lhs = op.getLhs();
815   if (isa<pdl::RangeType>(lhs.getType())) {
816     writer.append(OpCode::AreRangesEqual);
817     writer.appendPDLValueKind(lhs);
818     writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
819     return;
820   }
821 
822   writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
823 }
generate(pdl_interp::BranchOp op,ByteCodeWriter & writer)824 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
825   writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
826 }
generate(pdl_interp::CheckAttributeOp op,ByteCodeWriter & writer)827 void Generator::generate(pdl_interp::CheckAttributeOp op,
828                          ByteCodeWriter &writer) {
829   writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
830                 op.getSuccessors());
831 }
generate(pdl_interp::CheckOperandCountOp op,ByteCodeWriter & writer)832 void Generator::generate(pdl_interp::CheckOperandCountOp op,
833                          ByteCodeWriter &writer) {
834   writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
835                 static_cast<ByteCodeField>(op.getCompareAtLeast()),
836                 op.getSuccessors());
837 }
generate(pdl_interp::CheckOperationNameOp op,ByteCodeWriter & writer)838 void Generator::generate(pdl_interp::CheckOperationNameOp op,
839                          ByteCodeWriter &writer) {
840   writer.append(OpCode::CheckOperationName, op.getInputOp(),
841                 OperationName(op.getName(), ctx), op.getSuccessors());
842 }
generate(pdl_interp::CheckResultCountOp op,ByteCodeWriter & writer)843 void Generator::generate(pdl_interp::CheckResultCountOp op,
844                          ByteCodeWriter &writer) {
845   writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
846                 static_cast<ByteCodeField>(op.getCompareAtLeast()),
847                 op.getSuccessors());
848 }
generate(pdl_interp::CheckTypeOp op,ByteCodeWriter & writer)849 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
850   writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
851                 op.getSuccessors());
852 }
generate(pdl_interp::CheckTypesOp op,ByteCodeWriter & writer)853 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
854   writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
855                 op.getSuccessors());
856 }
generate(pdl_interp::ContinueOp op,ByteCodeWriter & writer)857 void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
858   assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
859   writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
860 }
generate(pdl_interp::CreateAttributeOp op,ByteCodeWriter & writer)861 void Generator::generate(pdl_interp::CreateAttributeOp op,
862                          ByteCodeWriter &writer) {
863   // Simply repoint the memory index of the result to the constant.
864   getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
865 }
generate(pdl_interp::CreateOperationOp op,ByteCodeWriter & writer)866 void Generator::generate(pdl_interp::CreateOperationOp op,
867                          ByteCodeWriter &writer) {
868   writer.append(OpCode::CreateOperation, op.getResultOp(),
869                 OperationName(op.getName(), ctx));
870   writer.appendPDLValueList(op.getInputOperands());
871 
872   // Add the attributes.
873   OperandRange attributes = op.getInputAttributes();
874   writer.append(static_cast<ByteCodeField>(attributes.size()));
875   for (auto it : llvm::zip(op.getInputAttributeNames(), attributes))
876     writer.append(std::get<0>(it), std::get<1>(it));
877 
878   // Add the result types. If the operation has inferred results, we use a
879   // marker "size" value. Otherwise, we add the list of explicit result types.
880   if (op.getInferredResultTypes())
881     writer.append(kInferTypesMarker);
882   else
883     writer.appendPDLValueList(op.getInputResultTypes());
884 }
generate(pdl_interp::CreateRangeOp op,ByteCodeWriter & writer)885 void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) {
886   // Append the correct opcode for the range type.
887   TypeSwitch<Type>(op.getType().getElementType())
888       .Case(
889           [&](pdl::TypeType) { writer.append(OpCode::CreateDynamicTypeRange); })
890       .Case([&](pdl::ValueType) {
891         writer.append(OpCode::CreateDynamicValueRange);
892       });
893 
894   writer.append(op.getResult(), getRangeStorageIndex(op.getResult()));
895   writer.appendPDLValueList(op->getOperands());
896 }
generate(pdl_interp::CreateTypeOp op,ByteCodeWriter & writer)897 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
898   // Simply repoint the memory index of the result to the constant.
899   getMemIndex(op.getResult()) = getMemIndex(op.getValue());
900 }
generate(pdl_interp::CreateTypesOp op,ByteCodeWriter & writer)901 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
902   writer.append(OpCode::CreateConstantTypeRange, op.getResult(),
903                 getRangeStorageIndex(op.getResult()), op.getValue());
904 }
generate(pdl_interp::EraseOp op,ByteCodeWriter & writer)905 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
906   writer.append(OpCode::EraseOp, op.getInputOp());
907 }
generate(pdl_interp::ExtractOp op,ByteCodeWriter & writer)908 void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
909   OpCode opCode =
910       TypeSwitch<Type, OpCode>(op.getResult().getType())
911           .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
912           .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
913           .Case([](pdl::TypeType) { return OpCode::ExtractType; })
914           .Default([](Type) -> OpCode {
915             llvm_unreachable("unsupported element type");
916           });
917   writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
918 }
generate(pdl_interp::FinalizeOp op,ByteCodeWriter & writer)919 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
920   writer.append(OpCode::Finalize);
921 }
generate(pdl_interp::ForEachOp op,ByteCodeWriter & writer)922 void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
923   BlockArgument arg = op.getLoopVariable();
924   writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
925   writer.appendPDLValueKind(arg.getType());
926   writer.append(curLoopLevel, op.getSuccessor());
927   ++curLoopLevel;
928   if (curLoopLevel > maxLoopLevel)
929     maxLoopLevel = curLoopLevel;
930   generate(&op.getRegion(), writer);
931   --curLoopLevel;
932 }
generate(pdl_interp::GetAttributeOp op,ByteCodeWriter & writer)933 void Generator::generate(pdl_interp::GetAttributeOp op,
934                          ByteCodeWriter &writer) {
935   writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
936                 op.getNameAttr());
937 }
generate(pdl_interp::GetAttributeTypeOp op,ByteCodeWriter & writer)938 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
939                          ByteCodeWriter &writer) {
940   writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
941 }
generate(pdl_interp::GetDefiningOpOp op,ByteCodeWriter & writer)942 void Generator::generate(pdl_interp::GetDefiningOpOp op,
943                          ByteCodeWriter &writer) {
944   writer.append(OpCode::GetDefiningOp, op.getInputOp());
945   writer.appendPDLValue(op.getValue());
946 }
generate(pdl_interp::GetOperandOp op,ByteCodeWriter & writer)947 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
948   uint32_t index = op.getIndex();
949   if (index < 4)
950     writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
951   else
952     writer.append(OpCode::GetOperandN, index);
953   writer.append(op.getInputOp(), op.getValue());
954 }
generate(pdl_interp::GetOperandsOp op,ByteCodeWriter & writer)955 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
956   Value result = op.getValue();
957   std::optional<uint32_t> index = op.getIndex();
958   writer.append(OpCode::GetOperands,
959                 index.value_or(std::numeric_limits<uint32_t>::max()),
960                 op.getInputOp());
961   if (isa<pdl::RangeType>(result.getType()))
962     writer.append(getRangeStorageIndex(result));
963   else
964     writer.append(std::numeric_limits<ByteCodeField>::max());
965   writer.append(result);
966 }
generate(pdl_interp::GetResultOp op,ByteCodeWriter & writer)967 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
968   uint32_t index = op.getIndex();
969   if (index < 4)
970     writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
971   else
972     writer.append(OpCode::GetResultN, index);
973   writer.append(op.getInputOp(), op.getValue());
974 }
generate(pdl_interp::GetResultsOp op,ByteCodeWriter & writer)975 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
976   Value result = op.getValue();
977   std::optional<uint32_t> index = op.getIndex();
978   writer.append(OpCode::GetResults,
979                 index.value_or(std::numeric_limits<uint32_t>::max()),
980                 op.getInputOp());
981   if (isa<pdl::RangeType>(result.getType()))
982     writer.append(getRangeStorageIndex(result));
983   else
984     writer.append(std::numeric_limits<ByteCodeField>::max());
985   writer.append(result);
986 }
generate(pdl_interp::GetUsersOp op,ByteCodeWriter & writer)987 void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
988   Value operations = op.getOperations();
989   ByteCodeField rangeIndex = getRangeStorageIndex(operations);
990   writer.append(OpCode::GetUsers, operations, rangeIndex);
991   writer.appendPDLValue(op.getValue());
992 }
generate(pdl_interp::GetValueTypeOp op,ByteCodeWriter & writer)993 void Generator::generate(pdl_interp::GetValueTypeOp op,
994                          ByteCodeWriter &writer) {
995   if (isa<pdl::RangeType>(op.getType())) {
996     Value result = op.getResult();
997     writer.append(OpCode::GetValueRangeTypes, result,
998                   getRangeStorageIndex(result), op.getValue());
999   } else {
1000     writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
1001   }
1002 }
generate(pdl_interp::IsNotNullOp op,ByteCodeWriter & writer)1003 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
1004   writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
1005 }
generate(pdl_interp::RecordMatchOp op,ByteCodeWriter & writer)1006 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
1007   ByteCodeField patternIndex = patterns.size();
1008   patterns.emplace_back(PDLByteCodePattern::create(
1009       op, configMap.lookup(op),
1010       rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
1011   writer.append(OpCode::RecordMatch, patternIndex,
1012                 SuccessorRange(op.getOperation()), op.getMatchedOps());
1013   writer.appendPDLValueList(op.getInputs());
1014 }
generate(pdl_interp::ReplaceOp op,ByteCodeWriter & writer)1015 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
1016   writer.append(OpCode::ReplaceOp, op.getInputOp());
1017   writer.appendPDLValueList(op.getReplValues());
1018 }
generate(pdl_interp::SwitchAttributeOp op,ByteCodeWriter & writer)1019 void Generator::generate(pdl_interp::SwitchAttributeOp op,
1020                          ByteCodeWriter &writer) {
1021   writer.append(OpCode::SwitchAttribute, op.getAttribute(),
1022                 op.getCaseValuesAttr(), op.getSuccessors());
1023 }
generate(pdl_interp::SwitchOperandCountOp op,ByteCodeWriter & writer)1024 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
1025                          ByteCodeWriter &writer) {
1026   writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
1027                 op.getCaseValuesAttr(), op.getSuccessors());
1028 }
generate(pdl_interp::SwitchOperationNameOp op,ByteCodeWriter & writer)1029 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
1030                          ByteCodeWriter &writer) {
1031   auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) {
1032     return OperationName(cast<StringAttr>(attr).getValue(), ctx);
1033   });
1034   writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
1035                 op.getSuccessors());
1036 }
generate(pdl_interp::SwitchResultCountOp op,ByteCodeWriter & writer)1037 void Generator::generate(pdl_interp::SwitchResultCountOp op,
1038                          ByteCodeWriter &writer) {
1039   writer.append(OpCode::SwitchResultCount, op.getInputOp(),
1040                 op.getCaseValuesAttr(), op.getSuccessors());
1041 }
generate(pdl_interp::SwitchTypeOp op,ByteCodeWriter & writer)1042 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
1043   writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
1044                 op.getSuccessors());
1045 }
generate(pdl_interp::SwitchTypesOp op,ByteCodeWriter & writer)1046 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
1047   writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
1048                 op.getSuccessors());
1049 }
1050 
1051 //===----------------------------------------------------------------------===//
1052 // PDLByteCode
1053 //===----------------------------------------------------------------------===//
1054 
PDLByteCode(ModuleOp module,SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,const DenseMap<Operation *,PDLPatternConfigSet * > & configMap,llvm::StringMap<PDLConstraintFunction> constraintFns,llvm::StringMap<PDLRewriteFunction> rewriteFns)1055 PDLByteCode::PDLByteCode(
1056     ModuleOp module, SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
1057     const DenseMap<Operation *, PDLPatternConfigSet *> &configMap,
1058     llvm::StringMap<PDLConstraintFunction> constraintFns,
1059     llvm::StringMap<PDLRewriteFunction> rewriteFns)
1060     : configs(std::move(configs)) {
1061   Generator generator(module.getContext(), uniquedData, matcherByteCode,
1062                       rewriterByteCode, patterns, maxValueMemoryIndex,
1063                       maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
1064                       maxLoopLevel, constraintFns, rewriteFns, configMap);
1065   generator.generate(module);
1066 
1067   // Initialize the external functions.
1068   for (auto &it : constraintFns)
1069     constraintFunctions.push_back(std::move(it.second));
1070   for (auto &it : rewriteFns)
1071     rewriteFunctions.push_back(std::move(it.second));
1072 }
1073 
1074 /// Initialize the given state such that it can be used to execute the current
1075 /// bytecode.
initializeMutableState(PDLByteCodeMutableState & state) const1076 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
1077   state.memory.resize(maxValueMemoryIndex, nullptr);
1078   state.opRangeMemory.resize(maxOpRangeCount);
1079   state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
1080   state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
1081   state.loopIndex.resize(maxLoopLevel, 0);
1082   state.currentPatternBenefits.reserve(patterns.size());
1083   for (const PDLByteCodePattern &pattern : patterns)
1084     state.currentPatternBenefits.push_back(pattern.getBenefit());
1085 }
1086 
1087 //===----------------------------------------------------------------------===//
1088 // ByteCode Execution
1089 
1090 namespace {
1091 /// This class is an instantiation of the PDLResultList that provides access to
1092 /// the returned results. This API is not on `PDLResultList` to avoid
1093 /// overexposing access to information specific solely to the ByteCode.
1094 class ByteCodeRewriteResultList : public PDLResultList {
1095 public:
ByteCodeRewriteResultList(unsigned maxNumResults)1096   ByteCodeRewriteResultList(unsigned maxNumResults)
1097       : PDLResultList(maxNumResults) {}
1098 
1099   /// Return the list of PDL results.
getResults()1100   MutableArrayRef<PDLValue> getResults() { return results; }
1101 
1102   /// Return the type ranges allocated by this list.
getAllocatedTypeRanges()1103   MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
1104     return allocatedTypeRanges;
1105   }
1106 
1107   /// Return the value ranges allocated by this list.
getAllocatedValueRanges()1108   MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
1109     return allocatedValueRanges;
1110   }
1111 };
1112 
1113 /// This class provides support for executing a bytecode stream.
1114 class ByteCodeExecutor {
1115 public:
ByteCodeExecutor(const ByteCodeField * curCodeIt,MutableArrayRef<const void * > memory,MutableArrayRef<llvm::OwningArrayRef<Operation * >> opRangeMemory,MutableArrayRef<TypeRange> typeRangeMemory,std::vector<llvm::OwningArrayRef<Type>> & allocatedTypeRangeMemory,MutableArrayRef<ValueRange> valueRangeMemory,std::vector<llvm::OwningArrayRef<Value>> & allocatedValueRangeMemory,MutableArrayRef<unsigned> loopIndex,ArrayRef<const void * > uniquedMemory,ArrayRef<ByteCodeField> code,ArrayRef<PatternBenefit> currentPatternBenefits,ArrayRef<PDLByteCodePattern> patterns,ArrayRef<PDLConstraintFunction> constraintFunctions,ArrayRef<PDLRewriteFunction> rewriteFunctions)1116   ByteCodeExecutor(
1117       const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
1118       MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory,
1119       MutableArrayRef<TypeRange> typeRangeMemory,
1120       std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
1121       MutableArrayRef<ValueRange> valueRangeMemory,
1122       std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
1123       MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory,
1124       ArrayRef<ByteCodeField> code,
1125       ArrayRef<PatternBenefit> currentPatternBenefits,
1126       ArrayRef<PDLByteCodePattern> patterns,
1127       ArrayRef<PDLConstraintFunction> constraintFunctions,
1128       ArrayRef<PDLRewriteFunction> rewriteFunctions)
1129       : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
1130         typeRangeMemory(typeRangeMemory),
1131         allocatedTypeRangeMemory(allocatedTypeRangeMemory),
1132         valueRangeMemory(valueRangeMemory),
1133         allocatedValueRangeMemory(allocatedValueRangeMemory),
1134         loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
1135         currentPatternBenefits(currentPatternBenefits), patterns(patterns),
1136         constraintFunctions(constraintFunctions),
1137         rewriteFunctions(rewriteFunctions) {}
1138 
1139   /// Start executing the code at the current bytecode index. `matches` is an
1140   /// optional field provided when this function is executed in a matching
1141   /// context.
1142   LogicalResult
1143   execute(PatternRewriter &rewriter,
1144           SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
1145           std::optional<Location> mainRewriteLoc = {});
1146 
1147 private:
1148   /// Internal implementation of executing each of the bytecode commands.
1149   void executeApplyConstraint(PatternRewriter &rewriter);
1150   LogicalResult executeApplyRewrite(PatternRewriter &rewriter);
1151   void executeAreEqual();
1152   void executeAreRangesEqual();
1153   void executeBranch();
1154   void executeCheckOperandCount();
1155   void executeCheckOperationName();
1156   void executeCheckResultCount();
1157   void executeCheckTypes();
1158   void executeContinue();
1159   void executeCreateConstantTypeRange();
1160   void executeCreateOperation(PatternRewriter &rewriter,
1161                               Location mainRewriteLoc);
1162   template <typename T>
1163   void executeDynamicCreateRange(StringRef type);
1164   void executeEraseOp(PatternRewriter &rewriter);
1165   template <typename T, typename Range, PDLValue::Kind kind>
1166   void executeExtract();
1167   void executeFinalize();
1168   void executeForEach();
1169   void executeGetAttribute();
1170   void executeGetAttributeType();
1171   void executeGetDefiningOp();
1172   void executeGetOperand(unsigned index);
1173   void executeGetOperands();
1174   void executeGetResult(unsigned index);
1175   void executeGetResults();
1176   void executeGetUsers();
1177   void executeGetValueType();
1178   void executeGetValueRangeTypes();
1179   void executeIsNotNull();
1180   void executeRecordMatch(PatternRewriter &rewriter,
1181                           SmallVectorImpl<PDLByteCode::MatchResult> &matches);
1182   void executeReplaceOp(PatternRewriter &rewriter);
1183   void executeSwitchAttribute();
1184   void executeSwitchOperandCount();
1185   void executeSwitchOperationName();
1186   void executeSwitchResultCount();
1187   void executeSwitchType();
1188   void executeSwitchTypes();
1189   void processNativeFunResults(ByteCodeRewriteResultList &results,
1190                                unsigned numResults,
1191                                LogicalResult &rewriteResult);
1192 
1193   /// Pushes a code iterator to the stack.
pushCodeIt(const ByteCodeField * it)1194   void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }
1195 
1196   /// Pops a code iterator from the stack, returning true on success.
popCodeIt()1197   void popCodeIt() {
1198     assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack");
1199     curCodeIt = resumeCodeIt.back();
1200     resumeCodeIt.pop_back();
1201   }
1202 
1203   /// Return the bytecode iterator at the start of the current op code.
getPrevCodeIt() const1204   const ByteCodeField *getPrevCodeIt() const {
1205     LLVM_DEBUG({
1206       // Account for the op code and the Location stored inline.
1207       return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField);
1208     });
1209 
1210     // Account for the op code only.
1211     return curCodeIt - 1;
1212   }
1213 
1214   /// Read a value from the bytecode buffer, optionally skipping a certain
1215   /// number of prefix values. These methods always update the buffer to point
1216   /// to the next field after the read data.
1217   template <typename T = ByteCodeField>
read(size_t skipN=0)1218   T read(size_t skipN = 0) {
1219     curCodeIt += skipN;
1220     return readImpl<T>();
1221   }
read(size_t skipN=0)1222   ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }
1223 
1224   /// Read a list of values from the bytecode buffer.
1225   template <typename ValueT, typename T>
readList(SmallVectorImpl<T> & list)1226   void readList(SmallVectorImpl<T> &list) {
1227     list.clear();
1228     for (unsigned i = 0, e = read(); i != e; ++i)
1229       list.push_back(read<ValueT>());
1230   }
1231 
1232   /// Read a list of values from the bytecode buffer. The values may be encoded
1233   /// either as a single element or a range of elements.
readList(SmallVectorImpl<Type> & list)1234   void readList(SmallVectorImpl<Type> &list) {
1235     for (unsigned i = 0, e = read(); i != e; ++i) {
1236       if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1237         list.push_back(read<Type>());
1238       } else {
1239         TypeRange *values = read<TypeRange *>();
1240         list.append(values->begin(), values->end());
1241       }
1242     }
1243   }
readList(SmallVectorImpl<Value> & list)1244   void readList(SmallVectorImpl<Value> &list) {
1245     for (unsigned i = 0, e = read(); i != e; ++i) {
1246       if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1247         list.push_back(read<Value>());
1248       } else {
1249         ValueRange *values = read<ValueRange *>();
1250         list.append(values->begin(), values->end());
1251       }
1252     }
1253   }
1254 
1255   /// Read a value stored inline as a pointer.
1256   template <typename T>
1257   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T>
readInline()1258   readInline() {
1259     const void *pointer;
1260     std::memcpy(&pointer, curCodeIt, sizeof(const void *));
1261     curCodeIt += sizeof(const void *) / sizeof(ByteCodeField);
1262     return T::getFromOpaquePointer(pointer);
1263   }
1264 
skip(size_t skipN)1265   void skip(size_t skipN) { curCodeIt += skipN; }
1266 
1267   /// Jump to a specific successor based on a predicate value.
selectJump(bool isTrue)1268   void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
1269   /// Jump to a specific successor based on a destination index.
selectJump(size_t destIndex)1270   void selectJump(size_t destIndex) {
1271     curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
1272   }
1273 
1274   /// Handle a switch operation with the provided value and cases.
1275   template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
handleSwitch(const T & value,RangeT && cases,Comparator cmp={})1276   void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
1277     LLVM_DEBUG({
1278       llvm::dbgs() << "  * Value: " << value << "\n"
1279                    << "  * Cases: ";
1280       llvm::interleaveComma(cases, llvm::dbgs());
1281       llvm::dbgs() << "\n";
1282     });
1283 
1284     // Check to see if the attribute value is within the case list. Jump to
1285     // the correct successor index based on the result.
1286     for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
1287       if (cmp(*it, value))
1288         return selectJump(size_t((it - cases.begin()) + 1));
1289     selectJump(size_t(0));
1290   }
1291 
1292   /// Store a pointer to memory.
storeToMemory(unsigned index,const void * value)1293   void storeToMemory(unsigned index, const void *value) {
1294     memory[index] = value;
1295   }
1296 
1297   /// Store a value to memory as an opaque pointer.
1298   template <typename T>
1299   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
storeToMemory(unsigned index,T value)1300   storeToMemory(unsigned index, T value) {
1301     memory[index] = value.getAsOpaquePointer();
1302   }
1303 
1304   /// Internal implementation of reading various data types from the bytecode
1305   /// stream.
1306   template <typename T>
readFromMemory()1307   const void *readFromMemory() {
1308     size_t index = *curCodeIt++;
1309 
1310     // If this type is an SSA value, it can only be stored in non-const memory.
1311     if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
1312                         Value>::value ||
1313         index < memory.size())
1314       return memory[index];
1315 
1316     // Otherwise, if this index is not inbounds it is uniqued.
1317     return uniquedMemory[index - memory.size()];
1318   }
1319   template <typename T>
readImpl()1320   std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
1321     return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
1322   }
1323   template <typename T>
1324   std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
1325                    T>
readImpl()1326   readImpl() {
1327     return T(T::getFromOpaquePointer(readFromMemory<T>()));
1328   }
1329   template <typename T>
readImpl()1330   std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
1331     switch (read<PDLValue::Kind>()) {
1332     case PDLValue::Kind::Attribute:
1333       return read<Attribute>();
1334     case PDLValue::Kind::Operation:
1335       return read<Operation *>();
1336     case PDLValue::Kind::Type:
1337       return read<Type>();
1338     case PDLValue::Kind::Value:
1339       return read<Value>();
1340     case PDLValue::Kind::TypeRange:
1341       return read<TypeRange *>();
1342     case PDLValue::Kind::ValueRange:
1343       return read<ValueRange *>();
1344     }
1345     llvm_unreachable("unhandled PDLValue::Kind");
1346   }
1347   template <typename T>
readImpl()1348   std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
1349     static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
1350                   "unexpected ByteCode address size");
1351     ByteCodeAddr result;
1352     std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
1353     curCodeIt += 2;
1354     return result;
1355   }
1356   template <typename T>
readImpl()1357   std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
1358     return *curCodeIt++;
1359   }
1360   template <typename T>
readImpl()1361   std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
1362     return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
1363   }
1364 
1365   /// Assign the given range to the given memory index. This allocates a new
1366   /// range object if necessary.
1367   template <typename RangeT, typename T = llvm::detail::ValueOfRange<RangeT>>
assignRangeToMemory(RangeT && range,unsigned memIndex,unsigned rangeIndex)1368   void assignRangeToMemory(RangeT &&range, unsigned memIndex,
1369                            unsigned rangeIndex) {
1370     // Utility functor used to type-erase the assignment.
1371     auto assignRange = [&](auto &allocatedRangeMemory, auto &rangeMemory) {
1372       // If the input range is empty, we don't need to allocate anything.
1373       if (range.empty()) {
1374         rangeMemory[rangeIndex] = {};
1375       } else {
1376         // Allocate a buffer for this type range.
1377         llvm::OwningArrayRef<T> storage(llvm::size(range));
1378         llvm::copy(range, storage.begin());
1379 
1380         // Assign this to the range slot and use the range as the value for the
1381         // memory index.
1382         allocatedRangeMemory.emplace_back(std::move(storage));
1383         rangeMemory[rangeIndex] = allocatedRangeMemory.back();
1384       }
1385       memory[memIndex] = &rangeMemory[rangeIndex];
1386     };
1387 
1388     // Dispatch based on the concrete range type.
1389     if constexpr (std::is_same_v<T, Type>) {
1390       return assignRange(allocatedTypeRangeMemory, typeRangeMemory);
1391     } else if constexpr (std::is_same_v<T, Value>) {
1392       return assignRange(allocatedValueRangeMemory, valueRangeMemory);
1393     } else {
1394       llvm_unreachable("unhandled range type");
1395     }
1396   }
1397 
1398   /// The underlying bytecode buffer.
1399   const ByteCodeField *curCodeIt;
1400 
1401   /// The stack of bytecode positions at which to resume operation.
1402   SmallVector<const ByteCodeField *> resumeCodeIt;
1403 
1404   /// The current execution memory.
1405   MutableArrayRef<const void *> memory;
1406   MutableArrayRef<OwningOpRange> opRangeMemory;
1407   MutableArrayRef<TypeRange> typeRangeMemory;
1408   std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
1409   MutableArrayRef<ValueRange> valueRangeMemory;
1410   std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;
1411 
1412   /// The current loop indices.
1413   MutableArrayRef<unsigned> loopIndex;
1414 
1415   /// References to ByteCode data necessary for execution.
1416   ArrayRef<const void *> uniquedMemory;
1417   ArrayRef<ByteCodeField> code;
1418   ArrayRef<PatternBenefit> currentPatternBenefits;
1419   ArrayRef<PDLByteCodePattern> patterns;
1420   ArrayRef<PDLConstraintFunction> constraintFunctions;
1421   ArrayRef<PDLRewriteFunction> rewriteFunctions;
1422 };
1423 } // namespace
1424 
executeApplyConstraint(PatternRewriter & rewriter)1425 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1426   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
1427   ByteCodeField fun_idx = read();
1428   SmallVector<PDLValue, 16> args;
1429   readList<PDLValue>(args);
1430 
1431   LLVM_DEBUG({
1432     llvm::dbgs() << "  * Arguments: ";
1433     llvm::interleaveComma(args, llvm::dbgs());
1434     llvm::dbgs() << "\n";
1435   });
1436 
1437   ByteCodeField isNegated = read();
1438   LLVM_DEBUG({
1439     llvm::dbgs() << "  * isNegated: " << isNegated << "\n";
1440     llvm::interleaveComma(args, llvm::dbgs());
1441   });
1442 
1443   ByteCodeField numResults = read();
1444   const PDLRewriteFunction &constraintFn = constraintFunctions[fun_idx];
1445   ByteCodeRewriteResultList results(numResults);
1446   LogicalResult rewriteResult = constraintFn(rewriter, results, args);
1447   [[maybe_unused]] ArrayRef<PDLValue> constraintResults = results.getResults();
1448   LLVM_DEBUG({
1449     if (succeeded(rewriteResult)) {
1450       llvm::dbgs() << "  * Constraint succeeded\n";
1451       llvm::dbgs() << "  * Results: ";
1452       llvm::interleaveComma(constraintResults, llvm::dbgs());
1453       llvm::dbgs() << "\n";
1454     } else {
1455       llvm::dbgs() << "  * Constraint failed\n";
1456     }
1457   });
1458   assert((failed(rewriteResult) || constraintResults.size() == numResults) &&
1459          "native PDL rewrite function succeeded but returned "
1460          "unexpected number of results");
1461   processNativeFunResults(results, numResults, rewriteResult);
1462 
1463   // Depending on the constraint jump to the proper destination.
1464   selectJump(isNegated != succeeded(rewriteResult));
1465 }
1466 
executeApplyRewrite(PatternRewriter & rewriter)1467 LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
1468   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
1469   const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1470   SmallVector<PDLValue, 16> args;
1471   readList<PDLValue>(args);
1472 
1473   LLVM_DEBUG({
1474     llvm::dbgs() << "  * Arguments: ";
1475     llvm::interleaveComma(args, llvm::dbgs());
1476   });
1477 
1478   // Execute the rewrite function.
1479   ByteCodeField numResults = read();
1480   ByteCodeRewriteResultList results(numResults);
1481   LogicalResult rewriteResult = rewriteFn(rewriter, results, args);
1482 
1483   assert(results.getResults().size() == numResults &&
1484          "native PDL rewrite function returned unexpected number of results");
1485 
1486   processNativeFunResults(results, numResults, rewriteResult);
1487 
1488   if (failed(rewriteResult)) {
1489     LLVM_DEBUG(llvm::dbgs() << "  - Failed");
1490     return failure();
1491   }
1492   return success();
1493 }
1494 
processNativeFunResults(ByteCodeRewriteResultList & results,unsigned numResults,LogicalResult & rewriteResult)1495 void ByteCodeExecutor::processNativeFunResults(
1496     ByteCodeRewriteResultList &results, unsigned numResults,
1497     LogicalResult &rewriteResult) {
1498   // Store the results in the bytecode memory or handle missing results on
1499   // failure.
1500   for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
1501     PDLValue::Kind resultKind = read<PDLValue::Kind>();
1502 
1503     // Skip the according number of values on the buffer on failure and exit
1504     // early as there are no results to process.
1505     if (failed(rewriteResult)) {
1506       if (resultKind == PDLValue::Kind::TypeRange ||
1507           resultKind == PDLValue::Kind::ValueRange) {
1508         skip(2);
1509       } else {
1510         skip(1);
1511       }
1512       return;
1513     }
1514     PDLValue result = results.getResults()[resultIdx];
1515     LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n");
1516     assert(result.getKind() == resultKind &&
1517            "native PDL rewrite function returned an unexpected type of "
1518            "result");
1519     // If the result is a range, we need to copy it over to the bytecodes
1520     // range memory.
1521     if (std::optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
1522       unsigned rangeIndex = read();
1523       typeRangeMemory[rangeIndex] = *typeRange;
1524       memory[read()] = &typeRangeMemory[rangeIndex];
1525     } else if (std::optional<ValueRange> valueRange =
1526                    result.dyn_cast<ValueRange>()) {
1527       unsigned rangeIndex = read();
1528       valueRangeMemory[rangeIndex] = *valueRange;
1529       memory[read()] = &valueRangeMemory[rangeIndex];
1530     } else {
1531       memory[read()] = result.getAsOpaquePointer();
1532     }
1533   }
1534 
1535   // Copy over any underlying storage allocated for result ranges.
1536   for (auto &it : results.getAllocatedTypeRanges())
1537     allocatedTypeRangeMemory.push_back(std::move(it));
1538   for (auto &it : results.getAllocatedValueRanges())
1539     allocatedValueRangeMemory.push_back(std::move(it));
1540 }
1541 
executeAreEqual()1542 void ByteCodeExecutor::executeAreEqual() {
1543   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1544   const void *lhs = read<const void *>();
1545   const void *rhs = read<const void *>();
1546 
1547   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n");
1548   selectJump(lhs == rhs);
1549 }
1550 
executeAreRangesEqual()1551 void ByteCodeExecutor::executeAreRangesEqual() {
1552   LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
1553   PDLValue::Kind valueKind = read<PDLValue::Kind>();
1554   const void *lhs = read<const void *>();
1555   const void *rhs = read<const void *>();
1556 
1557   switch (valueKind) {
1558   case PDLValue::Kind::TypeRange: {
1559     const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
1560     const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
1561     LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1562     selectJump(*lhsRange == *rhsRange);
1563     break;
1564   }
1565   case PDLValue::Kind::ValueRange: {
1566     const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
1567     const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
1568     LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1569     selectJump(*lhsRange == *rhsRange);
1570     break;
1571   }
1572   default:
1573     llvm_unreachable("unexpected `AreRangesEqual` value kind");
1574   }
1575 }
1576 
executeBranch()1577 void ByteCodeExecutor::executeBranch() {
1578   LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
1579   curCodeIt = &code[read<ByteCodeAddr>()];
1580 }
1581 
executeCheckOperandCount()1582 void ByteCodeExecutor::executeCheckOperandCount() {
1583   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
1584   Operation *op = read<Operation *>();
1585   uint32_t expectedCount = read<uint32_t>();
1586   bool compareAtLeast = read();
1587 
1588   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumOperands() << "\n"
1589                           << "  * Expected: " << expectedCount << "\n"
1590                           << "  * Comparator: "
1591                           << (compareAtLeast ? ">=" : "==") << "\n");
1592   if (compareAtLeast)
1593     selectJump(op->getNumOperands() >= expectedCount);
1594   else
1595     selectJump(op->getNumOperands() == expectedCount);
1596 }
1597 
executeCheckOperationName()1598 void ByteCodeExecutor::executeCheckOperationName() {
1599   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
1600   Operation *op = read<Operation *>();
1601   OperationName expectedName = read<OperationName>();
1602 
1603   LLVM_DEBUG(llvm::dbgs() << "  * Found: \"" << op->getName() << "\"\n"
1604                           << "  * Expected: \"" << expectedName << "\"\n");
1605   selectJump(op->getName() == expectedName);
1606 }
1607 
executeCheckResultCount()1608 void ByteCodeExecutor::executeCheckResultCount() {
1609   LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
1610   Operation *op = read<Operation *>();
1611   uint32_t expectedCount = read<uint32_t>();
1612   bool compareAtLeast = read();
1613 
1614   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumResults() << "\n"
1615                           << "  * Expected: " << expectedCount << "\n"
1616                           << "  * Comparator: "
1617                           << (compareAtLeast ? ">=" : "==") << "\n");
1618   if (compareAtLeast)
1619     selectJump(op->getNumResults() >= expectedCount);
1620   else
1621     selectJump(op->getNumResults() == expectedCount);
1622 }
1623 
executeCheckTypes()1624 void ByteCodeExecutor::executeCheckTypes() {
1625   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1626   TypeRange *lhs = read<TypeRange *>();
1627   Attribute rhs = read<Attribute>();
1628   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1629 
1630   selectJump(*lhs == cast<ArrayAttr>(rhs).getAsValueRange<TypeAttr>());
1631 }
1632 
executeContinue()1633 void ByteCodeExecutor::executeContinue() {
1634   ByteCodeField level = read();
1635   LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n"
1636                           << "  * Level: " << level << "\n");
1637   ++loopIndex[level];
1638   popCodeIt();
1639 }
1640 
executeCreateConstantTypeRange()1641 void ByteCodeExecutor::executeCreateConstantTypeRange() {
1642   LLVM_DEBUG(llvm::dbgs() << "Executing CreateConstantTypeRange:\n");
1643   unsigned memIndex = read();
1644   unsigned rangeIndex = read();
1645   ArrayAttr typesAttr = cast<ArrayAttr>(read<Attribute>());
1646 
1647   LLVM_DEBUG(llvm::dbgs() << "  * Types: " << typesAttr << "\n\n");
1648   assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex,
1649                       rangeIndex);
1650 }
1651 
executeCreateOperation(PatternRewriter & rewriter,Location mainRewriteLoc)1652 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
1653                                               Location mainRewriteLoc) {
1654   LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
1655 
1656   unsigned memIndex = read();
1657   OperationState state(mainRewriteLoc, read<OperationName>());
1658   readList(state.operands);
1659   for (unsigned i = 0, e = read(); i != e; ++i) {
1660     StringAttr name = read<StringAttr>();
1661     if (Attribute attr = read<Attribute>())
1662       state.addAttribute(name, attr);
1663   }
1664 
1665   // Read in the result types. If the "size" is the sentinel value, this
1666   // indicates that the result types should be inferred.
1667   unsigned numResults = read();
1668   if (numResults == kInferTypesMarker) {
1669     InferTypeOpInterface::Concept *inferInterface =
1670         state.name.getInterface<InferTypeOpInterface>();
1671     assert(inferInterface &&
1672            "expected operation to provide InferTypeOpInterface");
1673 
1674     // TODO: Handle failure.
1675     if (failed(inferInterface->inferReturnTypes(
1676             state.getContext(), state.location, state.operands,
1677             state.attributes.getDictionary(state.getContext()),
1678             state.getRawProperties(), state.regions, state.types)))
1679       return;
1680   } else {
1681     // Otherwise, this is a fixed number of results.
1682     for (unsigned i = 0; i != numResults; ++i) {
1683       if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1684         state.types.push_back(read<Type>());
1685       } else {
1686         TypeRange *resultTypes = read<TypeRange *>();
1687         state.types.append(resultTypes->begin(), resultTypes->end());
1688       }
1689     }
1690   }
1691 
1692   Operation *resultOp = rewriter.create(state);
1693   memory[memIndex] = resultOp;
1694 
1695   LLVM_DEBUG({
1696     llvm::dbgs() << "  * Attributes: "
1697                  << state.attributes.getDictionary(state.getContext())
1698                  << "\n  * Operands: ";
1699     llvm::interleaveComma(state.operands, llvm::dbgs());
1700     llvm::dbgs() << "\n  * Result Types: ";
1701     llvm::interleaveComma(state.types, llvm::dbgs());
1702     llvm::dbgs() << "\n  * Result: " << *resultOp << "\n";
1703   });
1704 }
1705 
1706 template <typename T>
executeDynamicCreateRange(StringRef type)1707 void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) {
1708   LLVM_DEBUG(llvm::dbgs() << "Executing CreateDynamic" << type << "Range:\n");
1709   unsigned memIndex = read();
1710   unsigned rangeIndex = read();
1711   SmallVector<T> values;
1712   readList(values);
1713 
1714   LLVM_DEBUG({
1715     llvm::dbgs() << "\n  * " << type << "s: ";
1716     llvm::interleaveComma(values, llvm::dbgs());
1717     llvm::dbgs() << "\n";
1718   });
1719 
1720   assignRangeToMemory(values, memIndex, rangeIndex);
1721 }
1722 
executeEraseOp(PatternRewriter & rewriter)1723 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1724   LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1725   Operation *op = read<Operation *>();
1726 
1727   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1728   rewriter.eraseOp(op);
1729 }
1730 
1731 template <typename T, typename Range, PDLValue::Kind kind>
executeExtract()1732 void ByteCodeExecutor::executeExtract() {
1733   LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n");
1734   Range *range = read<Range *>();
1735   unsigned index = read<uint32_t>();
1736   unsigned memIndex = read();
1737 
1738   if (!range) {
1739     memory[memIndex] = nullptr;
1740     return;
1741   }
1742 
1743   T result = index < range->size() ? (*range)[index] : T();
1744   LLVM_DEBUG(llvm::dbgs() << "  * " << kind << "s(" << range->size() << ")\n"
1745                           << "  * Index: " << index << "\n"
1746                           << "  * Result: " << result << "\n");
1747   storeToMemory(memIndex, result);
1748 }
1749 
executeFinalize()1750 void ByteCodeExecutor::executeFinalize() {
1751   LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n");
1752 }
1753 
executeForEach()1754 void ByteCodeExecutor::executeForEach() {
1755   LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n");
1756   const ByteCodeField *prevCodeIt = getPrevCodeIt();
1757   unsigned rangeIndex = read();
1758   unsigned memIndex = read();
1759   const void *value = nullptr;
1760 
1761   switch (read<PDLValue::Kind>()) {
1762   case PDLValue::Kind::Operation: {
1763     unsigned &index = loopIndex[read()];
1764     ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
1765     assert(index <= array.size() && "iterated past the end");
1766     if (index < array.size()) {
1767       LLVM_DEBUG(llvm::dbgs() << "  * Result: " << array[index] << "\n");
1768       value = array[index];
1769       break;
1770     }
1771 
1772     LLVM_DEBUG(llvm::dbgs() << "  * Done\n");
1773     index = 0;
1774     selectJump(size_t(0));
1775     return;
1776   }
1777   default:
1778     llvm_unreachable("unexpected `ForEach` value kind");
1779   }
1780 
1781   // Store the iterate value and the stack address.
1782   memory[memIndex] = value;
1783   pushCodeIt(prevCodeIt);
1784 
1785   // Skip over the successor (we will enter the body of the loop).
1786   read<ByteCodeAddr>();
1787 }
1788 
executeGetAttribute()1789 void ByteCodeExecutor::executeGetAttribute() {
1790   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1791   unsigned memIndex = read();
1792   Operation *op = read<Operation *>();
1793   StringAttr attrName = read<StringAttr>();
1794   Attribute attr = op->getAttr(attrName);
1795 
1796   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1797                           << "  * Attribute: " << attrName << "\n"
1798                           << "  * Result: " << attr << "\n");
1799   memory[memIndex] = attr.getAsOpaquePointer();
1800 }
1801 
executeGetAttributeType()1802 void ByteCodeExecutor::executeGetAttributeType() {
1803   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1804   unsigned memIndex = read();
1805   Attribute attr = read<Attribute>();
1806   Type type;
1807   if (auto typedAttr = dyn_cast<TypedAttr>(attr))
1808     type = typedAttr.getType();
1809 
1810   LLVM_DEBUG(llvm::dbgs() << "  * Attribute: " << attr << "\n"
1811                           << "  * Result: " << type << "\n");
1812   memory[memIndex] = type.getAsOpaquePointer();
1813 }
1814 
executeGetDefiningOp()1815 void ByteCodeExecutor::executeGetDefiningOp() {
1816   LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1817   unsigned memIndex = read();
1818   Operation *op = nullptr;
1819   if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1820     Value value = read<Value>();
1821     if (value)
1822       op = value.getDefiningOp();
1823     LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1824   } else {
1825     ValueRange *values = read<ValueRange *>();
1826     if (values && !values->empty()) {
1827       op = values->front().getDefiningOp();
1828     }
1829     LLVM_DEBUG(llvm::dbgs() << "  * Values: " << values << "\n");
1830   }
1831 
1832   LLVM_DEBUG(llvm::dbgs() << "  * Result: " << op << "\n");
1833   memory[memIndex] = op;
1834 }
1835 
executeGetOperand(unsigned index)1836 void ByteCodeExecutor::executeGetOperand(unsigned index) {
1837   Operation *op = read<Operation *>();
1838   unsigned memIndex = read();
1839   Value operand =
1840       index < op->getNumOperands() ? op->getOperand(index) : Value();
1841 
1842   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1843                           << "  * Index: " << index << "\n"
1844                           << "  * Result: " << operand << "\n");
1845   memory[memIndex] = operand.getAsOpaquePointer();
1846 }
1847 
1848 /// This function is the internal implementation of `GetResults` and
1849 /// `GetOperands` that provides support for extracting a value range from the
1850 /// given operation.
1851 template <template <typename> class AttrSizedSegmentsT, typename RangeT>
1852 static void *
executeGetOperandsResults(RangeT values,Operation * op,unsigned index,ByteCodeField rangeIndex,StringRef attrSizedSegments,MutableArrayRef<ValueRange> valueRangeMemory)1853 executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
1854                           ByteCodeField rangeIndex, StringRef attrSizedSegments,
1855                           MutableArrayRef<ValueRange> valueRangeMemory) {
1856   // Check for the sentinel index that signals that all values should be
1857   // returned.
1858   if (index == std::numeric_limits<uint32_t>::max()) {
1859     LLVM_DEBUG(llvm::dbgs() << "  * Getting all values\n");
1860     // `values` is already the full value range.
1861 
1862     // Otherwise, check to see if this operation uses AttrSizedSegments.
1863   } else if (op->hasTrait<AttrSizedSegmentsT>()) {
1864     LLVM_DEBUG(llvm::dbgs()
1865                << "  * Extracting values from `" << attrSizedSegments << "`\n");
1866 
1867     auto segmentAttr = op->getAttrOfType<DenseI32ArrayAttr>(attrSizedSegments);
1868     if (!segmentAttr || segmentAttr.asArrayRef().size() <= index)
1869       return nullptr;
1870 
1871     ArrayRef<int32_t> segments = segmentAttr;
1872     unsigned startIndex =
1873         std::accumulate(segments.begin(), segments.begin() + index, 0);
1874     values = values.slice(startIndex, *std::next(segments.begin(), index));
1875 
1876     LLVM_DEBUG(llvm::dbgs() << "  * Extracting range[" << startIndex << ", "
1877                             << *std::next(segments.begin(), index) << "]\n");
1878 
1879     // Otherwise, assume this is the last operand group of the operation.
1880     // FIXME: We currently don't support operations with
1881     // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
1882     // have a way to detect it's presence.
1883   } else if (values.size() >= index) {
1884     LLVM_DEBUG(llvm::dbgs()
1885                << "  * Treating values as trailing variadic range\n");
1886     values = values.drop_front(index);
1887 
1888     // If we couldn't detect a way to compute the values, bail out.
1889   } else {
1890     return nullptr;
1891   }
1892 
1893   // If the range index is valid, we are returning a range.
1894   if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
1895     valueRangeMemory[rangeIndex] = values;
1896     return &valueRangeMemory[rangeIndex];
1897   }
1898 
1899   // If a range index wasn't provided, the range is required to be non-variadic.
1900   return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1901 }
1902 
executeGetOperands()1903 void ByteCodeExecutor::executeGetOperands() {
1904   LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
1905   unsigned index = read<uint32_t>();
1906   Operation *op = read<Operation *>();
1907   ByteCodeField rangeIndex = read();
1908 
1909   void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1910       op->getOperands(), op, index, rangeIndex, "operandSegmentSizes",
1911       valueRangeMemory);
1912   if (!result)
1913     LLVM_DEBUG(llvm::dbgs() << "  * Invalid operand range\n");
1914   memory[read()] = result;
1915 }
1916 
executeGetResult(unsigned index)1917 void ByteCodeExecutor::executeGetResult(unsigned index) {
1918   Operation *op = read<Operation *>();
1919   unsigned memIndex = read();
1920   OpResult result =
1921       index < op->getNumResults() ? op->getResult(index) : OpResult();
1922 
1923   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1924                           << "  * Index: " << index << "\n"
1925                           << "  * Result: " << result << "\n");
1926   memory[memIndex] = result.getAsOpaquePointer();
1927 }
1928 
executeGetResults()1929 void ByteCodeExecutor::executeGetResults() {
1930   LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
1931   unsigned index = read<uint32_t>();
1932   Operation *op = read<Operation *>();
1933   ByteCodeField rangeIndex = read();
1934 
1935   void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1936       op->getResults(), op, index, rangeIndex, "resultSegmentSizes",
1937       valueRangeMemory);
1938   if (!result)
1939     LLVM_DEBUG(llvm::dbgs() << "  * Invalid result range\n");
1940   memory[read()] = result;
1941 }
1942 
executeGetUsers()1943 void ByteCodeExecutor::executeGetUsers() {
1944   LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n");
1945   unsigned memIndex = read();
1946   unsigned rangeIndex = read();
1947   OwningOpRange &range = opRangeMemory[rangeIndex];
1948   memory[memIndex] = &range;
1949 
1950   range = OwningOpRange();
1951   if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1952     // Read the value.
1953     Value value = read<Value>();
1954     if (!value)
1955       return;
1956     LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1957 
1958     // Extract the users of a single value.
1959     range = OwningOpRange(std::distance(value.user_begin(), value.user_end()));
1960     llvm::copy(value.getUsers(), range.begin());
1961   } else {
1962     // Read a range of values.
1963     ValueRange *values = read<ValueRange *>();
1964     if (!values)
1965       return;
1966     LLVM_DEBUG({
1967       llvm::dbgs() << "  * Values (" << values->size() << "): ";
1968       llvm::interleaveComma(*values, llvm::dbgs());
1969       llvm::dbgs() << "\n";
1970     });
1971 
1972     // Extract all the users of a range of values.
1973     SmallVector<Operation *> users;
1974     for (Value value : *values)
1975       users.append(value.user_begin(), value.user_end());
1976     range = OwningOpRange(users.size());
1977     llvm::copy(users, range.begin());
1978   }
1979 
1980   LLVM_DEBUG(llvm::dbgs() << "  * Result: " << range.size() << " operations\n");
1981 }
1982 
executeGetValueType()1983 void ByteCodeExecutor::executeGetValueType() {
1984   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1985   unsigned memIndex = read();
1986   Value value = read<Value>();
1987   Type type = value ? value.getType() : Type();
1988 
1989   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n"
1990                           << "  * Result: " << type << "\n");
1991   memory[memIndex] = type.getAsOpaquePointer();
1992 }
1993 
executeGetValueRangeTypes()1994 void ByteCodeExecutor::executeGetValueRangeTypes() {
1995   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
1996   unsigned memIndex = read();
1997   unsigned rangeIndex = read();
1998   ValueRange *values = read<ValueRange *>();
1999   if (!values) {
2000     LLVM_DEBUG(llvm::dbgs() << "  * Values: <NULL>\n\n");
2001     memory[memIndex] = nullptr;
2002     return;
2003   }
2004 
2005   LLVM_DEBUG({
2006     llvm::dbgs() << "  * Values (" << values->size() << "): ";
2007     llvm::interleaveComma(*values, llvm::dbgs());
2008     llvm::dbgs() << "\n  * Result: ";
2009     llvm::interleaveComma(values->getType(), llvm::dbgs());
2010     llvm::dbgs() << "\n";
2011   });
2012   typeRangeMemory[rangeIndex] = values->getType();
2013   memory[memIndex] = &typeRangeMemory[rangeIndex];
2014 }
2015 
executeIsNotNull()2016 void ByteCodeExecutor::executeIsNotNull() {
2017   LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
2018   const void *value = read<const void *>();
2019 
2020   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
2021   selectJump(value != nullptr);
2022 }
2023 
executeRecordMatch(PatternRewriter & rewriter,SmallVectorImpl<PDLByteCode::MatchResult> & matches)2024 void ByteCodeExecutor::executeRecordMatch(
2025     PatternRewriter &rewriter,
2026     SmallVectorImpl<PDLByteCode::MatchResult> &matches) {
2027   LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
2028   unsigned patternIndex = read();
2029   PatternBenefit benefit = currentPatternBenefits[patternIndex];
2030   const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
2031 
2032   // If the benefit of the pattern is impossible, skip the processing of the
2033   // rest of the pattern.
2034   if (benefit.isImpossibleToMatch()) {
2035     LLVM_DEBUG(llvm::dbgs() << "  * Benefit: Impossible To Match\n");
2036     curCodeIt = dest;
2037     return;
2038   }
2039 
2040   // Create a fused location containing the locations of each of the
2041   // operations used in the match. This will be used as the location for
2042   // created operations during the rewrite that don't already have an
2043   // explicit location set.
2044   unsigned numMatchLocs = read();
2045   SmallVector<Location, 4> matchLocs;
2046   matchLocs.reserve(numMatchLocs);
2047   for (unsigned i = 0; i != numMatchLocs; ++i)
2048     matchLocs.push_back(read<Operation *>()->getLoc());
2049   Location matchLoc = rewriter.getFusedLoc(matchLocs);
2050 
2051   LLVM_DEBUG(llvm::dbgs() << "  * Benefit: " << benefit.getBenefit() << "\n"
2052                           << "  * Location: " << matchLoc << "\n");
2053   matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
2054   PDLByteCode::MatchResult &match = matches.back();
2055 
2056   // Record all of the inputs to the match. If any of the inputs are ranges, we
2057   // will also need to remap the range pointer to memory stored in the match
2058   // state.
2059   unsigned numInputs = read();
2060   match.values.reserve(numInputs);
2061   match.typeRangeValues.reserve(numInputs);
2062   match.valueRangeValues.reserve(numInputs);
2063   for (unsigned i = 0; i < numInputs; ++i) {
2064     switch (read<PDLValue::Kind>()) {
2065     case PDLValue::Kind::TypeRange:
2066       match.typeRangeValues.push_back(*read<TypeRange *>());
2067       match.values.push_back(&match.typeRangeValues.back());
2068       break;
2069     case PDLValue::Kind::ValueRange:
2070       match.valueRangeValues.push_back(*read<ValueRange *>());
2071       match.values.push_back(&match.valueRangeValues.back());
2072       break;
2073     default:
2074       match.values.push_back(read<const void *>());
2075       break;
2076     }
2077   }
2078   curCodeIt = dest;
2079 }
2080 
executeReplaceOp(PatternRewriter & rewriter)2081 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
2082   LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
2083   Operation *op = read<Operation *>();
2084   SmallVector<Value, 16> args;
2085   readList(args);
2086 
2087   LLVM_DEBUG({
2088     llvm::dbgs() << "  * Operation: " << *op << "\n"
2089                  << "  * Values: ";
2090     llvm::interleaveComma(args, llvm::dbgs());
2091     llvm::dbgs() << "\n";
2092   });
2093   rewriter.replaceOp(op, args);
2094 }
2095 
executeSwitchAttribute()2096 void ByteCodeExecutor::executeSwitchAttribute() {
2097   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
2098   Attribute value = read<Attribute>();
2099   ArrayAttr cases = read<ArrayAttr>();
2100   handleSwitch(value, cases);
2101 }
2102 
executeSwitchOperandCount()2103 void ByteCodeExecutor::executeSwitchOperandCount() {
2104   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
2105   Operation *op = read<Operation *>();
2106   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2107 
2108   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
2109   handleSwitch(op->getNumOperands(), cases);
2110 }
2111 
executeSwitchOperationName()2112 void ByteCodeExecutor::executeSwitchOperationName() {
2113   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
2114   OperationName value = read<Operation *>()->getName();
2115   size_t caseCount = read();
2116 
2117   // The operation names are stored in-line, so to print them out for
2118   // debugging purposes we need to read the array before executing the
2119   // switch so that we can display all of the possible values.
2120   LLVM_DEBUG({
2121     const ByteCodeField *prevCodeIt = curCodeIt;
2122     llvm::dbgs() << "  * Value: " << value << "\n"
2123                  << "  * Cases: ";
2124     llvm::interleaveComma(
2125         llvm::map_range(llvm::seq<size_t>(0, caseCount),
2126                         [&](size_t) { return read<OperationName>(); }),
2127         llvm::dbgs());
2128     llvm::dbgs() << "\n";
2129     curCodeIt = prevCodeIt;
2130   });
2131 
2132   // Try to find the switch value within any of the cases.
2133   for (size_t i = 0; i != caseCount; ++i) {
2134     if (read<OperationName>() == value) {
2135       curCodeIt += (caseCount - i - 1);
2136       return selectJump(i + 1);
2137     }
2138   }
2139   selectJump(size_t(0));
2140 }
2141 
executeSwitchResultCount()2142 void ByteCodeExecutor::executeSwitchResultCount() {
2143   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
2144   Operation *op = read<Operation *>();
2145   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2146 
2147   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
2148   handleSwitch(op->getNumResults(), cases);
2149 }
2150 
executeSwitchType()2151 void ByteCodeExecutor::executeSwitchType() {
2152   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
2153   Type value = read<Type>();
2154   auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2155   handleSwitch(value, cases);
2156 }
2157 
executeSwitchTypes()2158 void ByteCodeExecutor::executeSwitchTypes() {
2159   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
2160   TypeRange *value = read<TypeRange *>();
2161   auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
2162   if (!value) {
2163     LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n");
2164     return selectJump(size_t(0));
2165   }
2166   handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
2167     return value == caseValue.getAsValueRange<TypeAttr>();
2168   });
2169 }
2170 
2171 LogicalResult
execute(PatternRewriter & rewriter,SmallVectorImpl<PDLByteCode::MatchResult> * matches,std::optional<Location> mainRewriteLoc)2172 ByteCodeExecutor::execute(PatternRewriter &rewriter,
2173                           SmallVectorImpl<PDLByteCode::MatchResult> *matches,
2174                           std::optional<Location> mainRewriteLoc) {
2175   while (true) {
2176     // Print the location of the operation being executed.
2177     LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n");
2178 
2179     OpCode opCode = static_cast<OpCode>(read());
2180     switch (opCode) {
2181     case ApplyConstraint:
2182       executeApplyConstraint(rewriter);
2183       break;
2184     case ApplyRewrite:
2185       if (failed(executeApplyRewrite(rewriter)))
2186         return failure();
2187       break;
2188     case AreEqual:
2189       executeAreEqual();
2190       break;
2191     case AreRangesEqual:
2192       executeAreRangesEqual();
2193       break;
2194     case Branch:
2195       executeBranch();
2196       break;
2197     case CheckOperandCount:
2198       executeCheckOperandCount();
2199       break;
2200     case CheckOperationName:
2201       executeCheckOperationName();
2202       break;
2203     case CheckResultCount:
2204       executeCheckResultCount();
2205       break;
2206     case CheckTypes:
2207       executeCheckTypes();
2208       break;
2209     case Continue:
2210       executeContinue();
2211       break;
2212     case CreateConstantTypeRange:
2213       executeCreateConstantTypeRange();
2214       break;
2215     case CreateOperation:
2216       executeCreateOperation(rewriter, *mainRewriteLoc);
2217       break;
2218     case CreateDynamicTypeRange:
2219       executeDynamicCreateRange<Type>("Type");
2220       break;
2221     case CreateDynamicValueRange:
2222       executeDynamicCreateRange<Value>("Value");
2223       break;
2224     case EraseOp:
2225       executeEraseOp(rewriter);
2226       break;
2227     case ExtractOp:
2228       executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
2229       break;
2230     case ExtractType:
2231       executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
2232       break;
2233     case ExtractValue:
2234       executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
2235       break;
2236     case Finalize:
2237       executeFinalize();
2238       LLVM_DEBUG(llvm::dbgs() << "\n");
2239       return success();
2240     case ForEach:
2241       executeForEach();
2242       break;
2243     case GetAttribute:
2244       executeGetAttribute();
2245       break;
2246     case GetAttributeType:
2247       executeGetAttributeType();
2248       break;
2249     case GetDefiningOp:
2250       executeGetDefiningOp();
2251       break;
2252     case GetOperand0:
2253     case GetOperand1:
2254     case GetOperand2:
2255     case GetOperand3: {
2256       unsigned index = opCode - GetOperand0;
2257       LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
2258       executeGetOperand(index);
2259       break;
2260     }
2261     case GetOperandN:
2262       LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
2263       executeGetOperand(read<uint32_t>());
2264       break;
2265     case GetOperands:
2266       executeGetOperands();
2267       break;
2268     case GetResult0:
2269     case GetResult1:
2270     case GetResult2:
2271     case GetResult3: {
2272       unsigned index = opCode - GetResult0;
2273       LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
2274       executeGetResult(index);
2275       break;
2276     }
2277     case GetResultN:
2278       LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
2279       executeGetResult(read<uint32_t>());
2280       break;
2281     case GetResults:
2282       executeGetResults();
2283       break;
2284     case GetUsers:
2285       executeGetUsers();
2286       break;
2287     case GetValueType:
2288       executeGetValueType();
2289       break;
2290     case GetValueRangeTypes:
2291       executeGetValueRangeTypes();
2292       break;
2293     case IsNotNull:
2294       executeIsNotNull();
2295       break;
2296     case RecordMatch:
2297       assert(matches &&
2298              "expected matches to be provided when executing the matcher");
2299       executeRecordMatch(rewriter, *matches);
2300       break;
2301     case ReplaceOp:
2302       executeReplaceOp(rewriter);
2303       break;
2304     case SwitchAttribute:
2305       executeSwitchAttribute();
2306       break;
2307     case SwitchOperandCount:
2308       executeSwitchOperandCount();
2309       break;
2310     case SwitchOperationName:
2311       executeSwitchOperationName();
2312       break;
2313     case SwitchResultCount:
2314       executeSwitchResultCount();
2315       break;
2316     case SwitchType:
2317       executeSwitchType();
2318       break;
2319     case SwitchTypes:
2320       executeSwitchTypes();
2321       break;
2322     }
2323     LLVM_DEBUG(llvm::dbgs() << "\n");
2324   }
2325 }
2326 
match(Operation * op,PatternRewriter & rewriter,SmallVectorImpl<MatchResult> & matches,PDLByteCodeMutableState & state) const2327 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
2328                         SmallVectorImpl<MatchResult> &matches,
2329                         PDLByteCodeMutableState &state) const {
2330   // The first memory slot is always the root operation.
2331   state.memory[0] = op;
2332 
2333   // The matcher function always starts at code address 0.
2334   ByteCodeExecutor executor(
2335       matcherByteCode.data(), state.memory, state.opRangeMemory,
2336       state.typeRangeMemory, state.allocatedTypeRangeMemory,
2337       state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2338       uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
2339       constraintFunctions, rewriteFunctions);
2340   LogicalResult executeResult = executor.execute(rewriter, &matches);
2341   (void)executeResult;
2342   assert(succeeded(executeResult) && "unexpected matcher execution failure");
2343 
2344   // Order the found matches by benefit.
2345   std::stable_sort(matches.begin(), matches.end(),
2346                    [](const MatchResult &lhs, const MatchResult &rhs) {
2347                      return lhs.benefit > rhs.benefit;
2348                    });
2349 }
2350 
rewrite(PatternRewriter & rewriter,const MatchResult & match,PDLByteCodeMutableState & state) const2351 LogicalResult PDLByteCode::rewrite(PatternRewriter &rewriter,
2352                                    const MatchResult &match,
2353                                    PDLByteCodeMutableState &state) const {
2354   auto *configSet = match.pattern->getConfigSet();
2355   if (configSet)
2356     configSet->notifyRewriteBegin(rewriter);
2357 
2358   // The arguments of the rewrite function are stored at the start of the
2359   // memory buffer.
2360   llvm::copy(match.values, state.memory.begin());
2361 
2362   ByteCodeExecutor executor(
2363       &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
2364       state.opRangeMemory, state.typeRangeMemory,
2365       state.allocatedTypeRangeMemory, state.valueRangeMemory,
2366       state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2367       rewriterByteCode, state.currentPatternBenefits, patterns,
2368       constraintFunctions, rewriteFunctions);
2369   LogicalResult result =
2370       executor.execute(rewriter, /*matches=*/nullptr, match.location);
2371 
2372   if (configSet)
2373     configSet->notifyRewriteEnd(rewriter);
2374 
2375   // If the rewrite failed, check if the pattern rewriter can recover. If it
2376   // can, we can signal to the pattern applicator to keep trying patterns. If it
2377   // doesn't, we need to bail. Bailing here should be fine, given that we have
2378   // no means to propagate such a failure to the user, and it also indicates a
2379   // bug in the user code (i.e. failable rewrites should not be used with
2380   // pattern rewriters that don't support it).
2381   if (failed(result) && !rewriter.canRecoverFromRewriteFailure()) {
2382     LLVM_DEBUG(llvm::dbgs() << " and rollback is not supported - aborting");
2383     llvm::report_fatal_error(
2384         "Native PDL Rewrite failed, but the pattern "
2385         "rewriter doesn't support recovery. Failable pattern rewrites should "
2386         "not be used with pattern rewriters that do not support them.");
2387   }
2388   return result;
2389 }
2390