1abfd1a8bSRiver Riddle //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2abfd1a8bSRiver Riddle //
3abfd1a8bSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4abfd1a8bSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5abfd1a8bSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6abfd1a8bSRiver Riddle //
7abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
8abfd1a8bSRiver Riddle //
9abfd1a8bSRiver Riddle // This file implements MLIR to byte-code generation and the interpreter.
10abfd1a8bSRiver Riddle //
11abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
12abfd1a8bSRiver Riddle
#include "ByteCode.h"
#include "mlir/Analysis/Liveness.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/RegionGraphTraits.h"
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"
#include <memory>
#include <numeric>
#include <optional>
27abfd1a8bSRiver Riddle
28abfd1a8bSRiver Riddle #define DEBUG_TYPE "pdl-bytecode"
29abfd1a8bSRiver Riddle
30abfd1a8bSRiver Riddle using namespace mlir;
31abfd1a8bSRiver Riddle using namespace mlir::detail;
32abfd1a8bSRiver Riddle
33abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
34abfd1a8bSRiver Riddle // PDLByteCodePattern
35abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
36abfd1a8bSRiver Riddle
create(pdl_interp::RecordMatchOp matchOp,PDLPatternConfigSet * configSet,ByteCodeAddr rewriterAddr)37abfd1a8bSRiver Riddle PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
388c66344eSRiver Riddle PDLPatternConfigSet *configSet,
39abfd1a8bSRiver Riddle ByteCodeAddr rewriterAddr) {
408c66344eSRiver Riddle PatternBenefit benefit = matchOp.getBenefit();
418c66344eSRiver Riddle MLIRContext *ctx = matchOp.getContext();
428c66344eSRiver Riddle
438c66344eSRiver Riddle // Collect the set of generated operations.
44abfd1a8bSRiver Riddle SmallVector<StringRef, 8> generatedOps;
453c405c3bSRiver Riddle if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
46abfd1a8bSRiver Riddle generatedOps =
47abfd1a8bSRiver Riddle llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
48abfd1a8bSRiver Riddle
49abfd1a8bSRiver Riddle // Check to see if this is pattern matches a specific operation type.
5022426110SRamkumar Ramachandra if (std::optional<StringRef> rootKind = matchOp.getRootKind())
518c66344eSRiver Riddle return PDLByteCodePattern(rewriterAddr, configSet, *rootKind, benefit, ctx,
5276f3c2f3SRiver Riddle generatedOps);
538c66344eSRiver Riddle return PDLByteCodePattern(rewriterAddr, configSet, MatchAnyOpTypeTag(),
548c66344eSRiver Riddle benefit, ctx, generatedOps);
55abfd1a8bSRiver Riddle }
56abfd1a8bSRiver Riddle
57abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
58abfd1a8bSRiver Riddle // PDLByteCodeMutableState
59abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
60abfd1a8bSRiver Riddle
61abfd1a8bSRiver Riddle /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
62abfd1a8bSRiver Riddle /// to the position of the pattern within the range returned by
63abfd1a8bSRiver Riddle /// `PDLByteCode::getPatterns`.
updatePatternBenefit(unsigned patternIndex,PatternBenefit benefit)64abfd1a8bSRiver Riddle void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
65abfd1a8bSRiver Riddle PatternBenefit benefit) {
66abfd1a8bSRiver Riddle currentPatternBenefits[patternIndex] = benefit;
67abfd1a8bSRiver Riddle }
68abfd1a8bSRiver Riddle
6985ab413bSRiver Riddle /// Cleanup any allocated state after a full match/rewrite has been completed.
7085ab413bSRiver Riddle /// This method should be called irregardless of whether the match+rewrite was a
7185ab413bSRiver Riddle /// success or not.
cleanupAfterMatchAndRewrite()7285ab413bSRiver Riddle void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() {
7385ab413bSRiver Riddle allocatedTypeRangeMemory.clear();
7485ab413bSRiver Riddle allocatedValueRangeMemory.clear();
7585ab413bSRiver Riddle }
7685ab413bSRiver Riddle
77abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
78abfd1a8bSRiver Riddle // Bytecode OpCodes
79abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
80abfd1a8bSRiver Riddle
namespace {
/// The opcodes of the PDL bytecode. ByteCodeWriter writes an opcode into the
/// stream as a single ByteCodeField, so the numeric value of each enumerator
/// (its position in this list) is part of the encoding. Do not reorder
/// existing enumerators.
///
/// NOTE(review): there are no dedicated CheckAttribute/CheckType opcodes, even
/// though the generator handles pdl_interp.check_attribute and check_type.
/// Presumably these lower to AreEqual; confirm against the generator.
enum OpCode : ByteCodeField {
  /// Apply an externally registered constraint.
  ApplyConstraint,
  /// Apply an externally registered rewrite.
  ApplyRewrite,
  /// Check if two generic values are equal.
  AreEqual,
  /// Check if two ranges are equal.
  AreRangesEqual,
  /// Unconditional branch.
  Branch,
  /// Compare the operand count of an operation with a constant.
  CheckOperandCount,
  /// Compare the name of an operation with a constant.
  CheckOperationName,
  /// Compare the result count of an operation with a constant.
  CheckResultCount,
  /// Compare a range of types to a constant range of types.
  CheckTypes,
  /// Continue to the next iteration of a loop.
  Continue,
  /// Create a type range from a list of constant types.
  CreateConstantTypeRange,
  /// Create an operation.
  CreateOperation,
  /// Create a type range from a list of dynamic types.
  CreateDynamicTypeRange,
  /// Create a value range.
  CreateDynamicValueRange,
  /// Erase an operation.
  EraseOp,
  /// Extract the op from a range at the specified index.
  ExtractOp,
  /// Extract the type from a range at the specified index.
  ExtractType,
  /// Extract the value from a range at the specified index.
  ExtractValue,
  /// Terminate a matcher or rewrite sequence.
  Finalize,
  /// Iterate over a range of values.
  ForEach,
  /// Get a specific attribute of an operation.
  GetAttribute,
  /// Get the type of an attribute.
  GetAttributeType,
  /// Get the defining operation of a value.
  GetDefiningOp,
  /// Get a specific operand of an operation. GetOperand0-GetOperand3 name a
  /// fixed operand index; GetOperandN covers any other index (presumably
  /// encoded as an extra field; verify against the generator).
  GetOperand0,
  GetOperand1,
  GetOperand2,
  GetOperand3,
  GetOperandN,
  /// Get a specific operand group of an operation.
  GetOperands,
  /// Get a specific result of an operation. GetResult0-GetResult3 name a fixed
  /// result index; GetResultN covers any other index.
  GetResult0,
  GetResult1,
  GetResult2,
  GetResult3,
  GetResultN,
  /// Get a specific result group of an operation.
  GetResults,
  /// Get the users of a value or a range of values.
  GetUsers,
  /// Get the type of a value.
  GetValueType,
  /// Get the types of a value range.
  GetValueRangeTypes,
  /// Check if a generic value is not null.
  IsNotNull,
  /// Record a successful pattern match.
  RecordMatch,
  /// Replace an operation.
  ReplaceOp,
  /// Compare an attribute with a set of constants.
  SwitchAttribute,
  /// Compare the operand count of an operation with a set of constants.
  SwitchOperandCount,
  /// Compare the name of an operation with a set of constants.
  SwitchOperationName,
  /// Compare the result count of an operation with a set of constants.
  SwitchResultCount,
  /// Compare a type with a set of constants.
  SwitchType,
  /// Compare a range of types with a set of constants.
  SwitchTypes,
};
} // namespace
171abfd1a8bSRiver Riddle
1723c752289SRiver Riddle /// A marker used to indicate if an operation should infer types.
1733c752289SRiver Riddle static constexpr ByteCodeField kInferTypesMarker =
1743c752289SRiver Riddle std::numeric_limits<ByteCodeField>::max();
1753c752289SRiver Riddle
176abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
177abfd1a8bSRiver Riddle // ByteCode Generation
178abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
179abfd1a8bSRiver Riddle
180abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
181abfd1a8bSRiver Riddle // Generator
182abfd1a8bSRiver Riddle
183abfd1a8bSRiver Riddle namespace {
// Forward declarations; both structs are defined later in this namespace.
struct ByteCodeLiveRange;
struct ByteCodeWriter;

/// Detection alias for use with `llvm::is_detected`. It is well-formed only
/// when `T` provides a `getAsOpaquePointer()` method, i.e. when `T` can be
/// converted to an opaque pointer. The trailing `Args` pack is not used by the
/// alias itself.
template <typename T, typename... Args>
using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
1903eb1647aSStanislav Funiak
/// This class represents the main generator for the pattern bytecode. It
/// writes the matcher and rewriter bytecode into buffers owned by the caller
/// and assigns memory indices to interpreter values and uniqued data.
class Generator {
public:
  /// All reference parameters are caller-owned outputs/inputs that the
  /// generator populates; they must outlive the generator.
  Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
            SmallVectorImpl<ByteCodeField> &matcherByteCode,
            SmallVectorImpl<ByteCodeField> &rewriterByteCode,
            SmallVectorImpl<PDLByteCodePattern> &patterns,
            ByteCodeField &maxValueMemoryIndex,
            ByteCodeField &maxOpRangeMemoryIndex,
            ByteCodeField &maxTypeRangeMemoryIndex,
            ByteCodeField &maxValueRangeMemoryIndex,
            ByteCodeField &maxLoopLevel,
            llvm::StringMap<PDLConstraintFunction> &constraintFns,
            llvm::StringMap<PDLRewriteFunction> &rewriteFns,
            const DenseMap<Operation *, PDLPatternConfigSet *> &configMap)
      : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
        rewriterByteCode(rewriterByteCode), patterns(patterns),
        maxValueMemoryIndex(maxValueMemoryIndex),
        maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
        maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
        maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
        maxLoopLevel(maxLoopLevel), configMap(configMap) {
    // Each registered function's index is its position when iterating the
    // StringMap. NOTE(review): this assumes the interpreter builds its function
    // tables by iterating the same maps in the same order; confirm.
    for (const auto &it : llvm::enumerate(constraintFns))
      constraintToMemIndex.try_emplace(it.value().first(), it.index());
    for (const auto &it : llvm::enumerate(rewriteFns))
      externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
  }

  /// Generate the bytecode for the given PDL interpreter module.
  void generate(ModuleOp module);

  /// Return the memory index to use for the given value. The index must
  /// already have been assigned. This is checked by an assert only; in release
  /// builds `operator[]` would silently default-insert an index of 0.
  ByteCodeField &getMemIndex(Value value) {
    assert(valueToMemIndex.count(value) &&
           "expected memory index to be assigned");
    return valueToMemIndex[value];
  }

  /// Return the range memory index used to store the given range value. As
  /// with getMemIndex(Value), a missing entry is caught only by the assert.
  ByteCodeField &getRangeStorageIndex(Value value) {
    assert(valueToRangeIndex.count(value) &&
           "expected range index to be assigned");
    return valueToRangeIndex[value];
  }

  /// Return an index to use when referring to the given data that is uniqued in
  /// the MLIR context. The first time an opaque pointer is seen, it is appended
  /// to `uniquedData` and given the index
  /// `maxValueMemoryIndex + uniquedData.size()`, so uniqued data lives after
  /// all value memory slots. Later lookups return the same index.
  /// NOTE(review): this assumes `maxValueMemoryIndex` is final by the time the
  /// first uniqued value is requested; confirm.
  template <typename T>
  std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
  getMemIndex(T val) {
    const void *opaqueVal = val.getAsOpaquePointer();

    // Get or insert a reference to this value.
    auto it = uniquedDataToMemIndex.try_emplace(
        opaqueVal, maxValueMemoryIndex + uniquedData.size());
    if (it.second)
      uniquedData.push_back(opaqueVal);
    return it.first->second;
  }

private:
  /// Allocate memory indices for the results of operations within the matcher
  /// and rewriters.
  void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
                             ModuleOp rewriterModule);

  /// Generate the bytecode for the given region, the given generic operation,
  /// or a specific PDL interpreter operation.
  void generate(Region *region, ByteCodeWriter &writer);
  void generate(Operation *op, ByteCodeWriter &writer);
  void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);

  /// Mapping from value to its corresponding memory index.
  DenseMap<Value, ByteCodeField> valueToMemIndex;

  /// Mapping from a range value to its corresponding range storage index.
  DenseMap<Value, ByteCodeField> valueToRangeIndex;

  /// Mapping from the name of an externally registered rewrite to its index in
  /// the bytecode registry.
  llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;

  /// Mapping from the name of an externally registered constraint to its index
  /// in the bytecode registry.
  llvm::StringMap<ByteCodeField> constraintToMemIndex;

  /// Mapping from rewriter function name to the bytecode address of that
  /// rewriter function within the rewriter bytecode.
  llvm::StringMap<ByteCodeAddr> rewriterToAddr;

  /// Mapping from a uniqued storage object to its memory index within
  /// `uniquedData`.
  DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;

  /// The current level of the foreach loop.
  ByteCodeField curLoopLevel = 0;

  /// The current MLIR context.
  MLIRContext *ctx;

  /// Mapping from block to its address.
  DenseMap<Block *, ByteCodeAddr> blockToAddr;

  /// Data of the ByteCode class to be populated. These reference storage owned
  /// by the caller of the constructor. Note that member declaration order
  /// determines initialization order; keep it stable.
  std::vector<const void *> &uniquedData;
  SmallVectorImpl<ByteCodeField> &matcherByteCode;
  SmallVectorImpl<ByteCodeField> &rewriterByteCode;
  SmallVectorImpl<PDLByteCodePattern> &patterns;
  ByteCodeField &maxValueMemoryIndex;
  ByteCodeField &maxOpRangeMemoryIndex;
  ByteCodeField &maxTypeRangeMemoryIndex;
  ByteCodeField &maxValueRangeMemoryIndex;
  ByteCodeField &maxLoopLevel;

  /// A map of pattern configurations.
  const DenseMap<Operation *, PDLPatternConfigSet *> &configMap;
};
344abfd1a8bSRiver Riddle
/// This class provides utilities for writing a bytecode stream. Fields are
/// appended to a caller-owned buffer. Memory indices for values and uniqued
/// data are obtained from the owning Generator.
struct ByteCodeWriter {
  ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
      : bytecode(bytecode), generator(generator) {}

  /// Append a single field, or a single opcode, to the bytecode.
  void append(ByteCodeField field) { bytecode.push_back(field); }
  void append(OpCode opCode) { bytecode.push_back(opCode); }

  /// Append an address to the bytecode. The address is split byte-wise into
  /// exactly two fields; the static_assert guards that size assumption.
  void append(ByteCodeAddr field) {
    static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
                  "unexpected ByteCode address size");

    ByteCodeField fieldParts[2];
    std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
    bytecode.append({fieldParts[0], fieldParts[1]});
  }

  /// Append a single successor to the bytecode. The exact address must be
  /// resolved later: a placeholder address of 0 is written, and its position
  /// is recorded in `unresolvedSuccessorRefs`.
  void append(Block *successor) {
    // Add back a reference to the successor so that the address can be resolved
    // later.
    unresolvedSuccessorRefs[successor].push_back(bytecode.size());
    append(ByteCodeAddr(0));
  }

  /// Append a successor range to the bytecode. Each exact address must be
  /// resolved later.
  void append(SuccessorRange successors) {
    for (Block *successor : successors)
      append(successor);
  }

  /// Append a range of values that will be read as generic PDLValues. The
  /// encoding is the value count followed by each value as a PDLValue.
  void appendPDLValueList(OperandRange values) {
    bytecode.push_back(values.size());
    for (Value value : values)
      appendPDLValue(value);
  }

  /// Append a value as a PDLValue: its kind field followed by its memory index.
  void appendPDLValue(Value value) {
    appendPDLValueKind(value);
    append(value);
  }

  /// Append the PDLValue::Kind of the given value.
  void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); }

  /// Append the PDLValue::Kind of the given type. A range of pdl.type maps to
  /// TypeRange and any other range maps to ValueRange. The TypeSwitch has no
  /// Default case, so types outside the listed PDL types are not expected here.
  void appendPDLValueKind(Type type) {
    PDLValue::Kind kind =
        TypeSwitch<Type, PDLValue::Kind>(type)
            .Case<pdl::AttributeType>(
                [](Type) { return PDLValue::Kind::Attribute; })
            .Case<pdl::OperationType>(
                [](Type) { return PDLValue::Kind::Operation; })
            .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
              if (isa<pdl::TypeType>(rangeTy.getElementType()))
                return PDLValue::Kind::TypeRange;
              return PDLValue::Kind::ValueRange;
            })
            .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
            .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
    bytecode.push_back(static_cast<ByteCodeField>(kind));
  }

  /// Append a value that will be stored in a memory slot and not inline within
  /// the bytecode. Only the slot index, obtained from the generator, is written.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
                   std::is_pointer<T>::value>
  append(T value) {
    bytecode.push_back(generator.getMemIndex(value));
  }

  /// Append a range of values: the element count, then each element through
  /// the matching `append` overload.
  template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
  std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
  append(T range) {
    bytecode.push_back(llvm::size(range));
    for (auto it : range)
      append(it);
  }

  /// Append a variadic number of fields to the bytecode, one at a time, in
  /// order.
  template <typename FieldTy, typename Field2Ty, typename... FieldTys>
  void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
    append(field);
    append(field2, fields...);
  }

  /// Appends a value as a pointer, stored inline within the bytecode. The
  /// opaque pointer is split byte-wise across
  /// `sizeof(const void *) / sizeof(ByteCodeField)` fields.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
  appendInline(T value) {
    constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField);
    const void *pointer = value.getAsOpaquePointer();
    ByteCodeField fieldParts[numParts];
    std::memcpy(fieldParts, &pointer, sizeof(const void *));
    bytecode.append(fieldParts, fieldParts + numParts);
  }

  /// Successor references in the bytecode that have yet to be resolved: for
  /// each block, the bytecode offsets of its placeholder addresses.
  DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;

  /// The underlying bytecode buffer.
  SmallVectorImpl<ByteCodeField> &bytecode;

  /// The main generator producing PDL.
  Generator &generator;
};
45985ab413bSRiver Riddle
46085ab413bSRiver Riddle /// This class represents a live range of PDL Interpreter values, containing
46185ab413bSRiver Riddle /// information about when values are live within a match/rewrite.
46285ab413bSRiver Riddle struct ByteCodeLiveRange {
4633eb1647aSStanislav Funiak using Set = llvm::IntervalMap<uint64_t, char, 16>;
46485ab413bSRiver Riddle using Allocator = Set::Allocator;
46585ab413bSRiver Riddle
ByteCodeLiveRange__anon22ebf8dc0211::ByteCodeLiveRange4663eb1647aSStanislav Funiak ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
46785ab413bSRiver Riddle
46885ab413bSRiver Riddle /// Union this live range with the one provided.
unionWith__anon22ebf8dc0211::ByteCodeLiveRange46985ab413bSRiver Riddle void unionWith(const ByteCodeLiveRange &rhs) {
4703eb1647aSStanislav Funiak for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
4713eb1647aSStanislav Funiak ++it)
4723eb1647aSStanislav Funiak liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0);
47385ab413bSRiver Riddle }
47485ab413bSRiver Riddle
47585ab413bSRiver Riddle /// Returns true if this range overlaps with the one provided.
overlaps__anon22ebf8dc0211::ByteCodeLiveRange47685ab413bSRiver Riddle bool overlaps(const ByteCodeLiveRange &rhs) const {
4773eb1647aSStanislav Funiak return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
4783eb1647aSStanislav Funiak .valid();
47985ab413bSRiver Riddle }
48085ab413bSRiver Riddle
48185ab413bSRiver Riddle /// A map representing the ranges of the match/rewrite that a value is live in
48285ab413bSRiver Riddle /// the interpreter.
4833eb1647aSStanislav Funiak ///
4843eb1647aSStanislav Funiak /// We use std::unique_ptr here, because IntervalMap does not provide a
4853eb1647aSStanislav Funiak /// correct copy or move constructor. We can eliminate the pointer once
4863eb1647aSStanislav Funiak /// https://reviews.llvm.org/D113240 lands.
4873eb1647aSStanislav Funiak std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
4883eb1647aSStanislav Funiak
4893eb1647aSStanislav Funiak /// The operation range storage index for this range.
4900a81ace0SKazu Hirata std::optional<unsigned> opRangeIndex;
49185ab413bSRiver Riddle
49285ab413bSRiver Riddle /// The type range storage index for this range.
4930a81ace0SKazu Hirata std::optional<unsigned> typeRangeIndex;
49485ab413bSRiver Riddle
49585ab413bSRiver Riddle /// The value range storage index for this range.
4960a81ace0SKazu Hirata std::optional<unsigned> valueRangeIndex;
49785ab413bSRiver Riddle };
498be0a7e9fSMehdi Amini } // namespace
499abfd1a8bSRiver Riddle
generate(ModuleOp module)500abfd1a8bSRiver Riddle void Generator::generate(ModuleOp module) {
501f96a8675SRiver Riddle auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
502abfd1a8bSRiver Riddle pdl_interp::PDLInterpDialect::getMatcherFunctionName());
503abfd1a8bSRiver Riddle ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
504abfd1a8bSRiver Riddle pdl_interp::PDLInterpDialect::getRewriterModuleName());
505abfd1a8bSRiver Riddle assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
506abfd1a8bSRiver Riddle
507abfd1a8bSRiver Riddle // Allocate memory indices for the results of operations within the matcher
508abfd1a8bSRiver Riddle // and rewriters.
509abfd1a8bSRiver Riddle allocateMemoryIndices(matcherFunc, rewriterModule);
510abfd1a8bSRiver Riddle
511abfd1a8bSRiver Riddle // Generate code for the rewriter functions.
512abfd1a8bSRiver Riddle ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
513f96a8675SRiver Riddle for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
514abfd1a8bSRiver Riddle rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
515abfd1a8bSRiver Riddle for (Operation &op : rewriterFunc.getOps())
516abfd1a8bSRiver Riddle generate(&op, rewriterByteCodeWriter);
517abfd1a8bSRiver Riddle }
518abfd1a8bSRiver Riddle assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
519abfd1a8bSRiver Riddle "unexpected branches in rewriter function");
520abfd1a8bSRiver Riddle
521abfd1a8bSRiver Riddle // Generate code for the matcher function.
522abfd1a8bSRiver Riddle ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
5233eb1647aSStanislav Funiak generate(&matcherFunc.getBody(), matcherByteCodeWriter);
524abfd1a8bSRiver Riddle
525abfd1a8bSRiver Riddle // Resolve successor references in the matcher.
526abfd1a8bSRiver Riddle for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
527abfd1a8bSRiver Riddle ByteCodeAddr addr = blockToAddr[it.first];
528abfd1a8bSRiver Riddle for (unsigned offsetToFix : it.second)
529abfd1a8bSRiver Riddle std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
530abfd1a8bSRiver Riddle }
531abfd1a8bSRiver Riddle }
532abfd1a8bSRiver Riddle
allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,ModuleOp rewriterModule)533f96a8675SRiver Riddle void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
534abfd1a8bSRiver Riddle ModuleOp rewriterModule) {
535abfd1a8bSRiver Riddle // Rewriters use simplistic allocation scheme that simply assigns an index to
536abfd1a8bSRiver Riddle // each result.
537f96a8675SRiver Riddle for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
53885ab413bSRiver Riddle ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
53985ab413bSRiver Riddle auto processRewriterValue = [&](Value val) {
54085ab413bSRiver Riddle valueToMemIndex.try_emplace(val, index++);
5415550c821STres Popp if (pdl::RangeType rangeType = dyn_cast<pdl::RangeType>(val.getType())) {
54285ab413bSRiver Riddle Type elementTy = rangeType.getElementType();
5435550c821STres Popp if (isa<pdl::TypeType>(elementTy))
54485ab413bSRiver Riddle valueToRangeIndex.try_emplace(val, typeRangeIndex++);
5455550c821STres Popp else if (isa<pdl::ValueType>(elementTy))
54685ab413bSRiver Riddle valueToRangeIndex.try_emplace(val, valueRangeIndex++);
54785ab413bSRiver Riddle }
54885ab413bSRiver Riddle };
54985ab413bSRiver Riddle
550abfd1a8bSRiver Riddle for (BlockArgument arg : rewriterFunc.getArguments())
55185ab413bSRiver Riddle processRewriterValue(arg);
552abfd1a8bSRiver Riddle rewriterFunc.getBody().walk([&](Operation *op) {
553abfd1a8bSRiver Riddle for (Value result : op->getResults())
55485ab413bSRiver Riddle processRewriterValue(result);
555abfd1a8bSRiver Riddle });
556abfd1a8bSRiver Riddle if (index > maxValueMemoryIndex)
557abfd1a8bSRiver Riddle maxValueMemoryIndex = index;
55885ab413bSRiver Riddle if (typeRangeIndex > maxTypeRangeMemoryIndex)
55985ab413bSRiver Riddle maxTypeRangeMemoryIndex = typeRangeIndex;
56085ab413bSRiver Riddle if (valueRangeIndex > maxValueRangeMemoryIndex)
56185ab413bSRiver Riddle maxValueRangeMemoryIndex = valueRangeIndex;
562abfd1a8bSRiver Riddle }
563abfd1a8bSRiver Riddle
564abfd1a8bSRiver Riddle // The matcher function uses a more sophisticated numbering that tries to
565abfd1a8bSRiver Riddle // minimize the number of memory indices assigned. This is done by determining
566abfd1a8bSRiver Riddle // a live range of the values within the matcher, then the allocation is just
567abfd1a8bSRiver Riddle // finding the minimal number of overlapping live ranges. This is essentially
568abfd1a8bSRiver Riddle // a simplified form of register allocation where we don't necessarily have a
569abfd1a8bSRiver Riddle // limited number of registers, but we still want to minimize the number used.
570b4130e9eSStanislav Funiak DenseMap<Operation *, unsigned> opToFirstIndex;
571b4130e9eSStanislav Funiak DenseMap<Operation *, unsigned> opToLastIndex;
572b4130e9eSStanislav Funiak
573b4130e9eSStanislav Funiak // A custom walk that marks the first and the last index of each operation.
574b4130e9eSStanislav Funiak // The entry marks the beginning of the liveness range for this operation,
575b4130e9eSStanislav Funiak // followed by nested operations, followed by the end of the liveness range.
576b4130e9eSStanislav Funiak unsigned index = 0;
577b4130e9eSStanislav Funiak llvm::unique_function<void(Operation *)> walk = [&](Operation *op) {
578b4130e9eSStanislav Funiak opToFirstIndex.try_emplace(op, index++);
579b4130e9eSStanislav Funiak for (Region ®ion : op->getRegions())
580b4130e9eSStanislav Funiak for (Block &block : region.getBlocks())
581b4130e9eSStanislav Funiak for (Operation &nested : block)
582b4130e9eSStanislav Funiak walk(&nested);
583b4130e9eSStanislav Funiak opToLastIndex.try_emplace(op, index++);
584b4130e9eSStanislav Funiak };
585b4130e9eSStanislav Funiak walk(matcherFunc);
586abfd1a8bSRiver Riddle
587abfd1a8bSRiver Riddle // Liveness info for each of the defs within the matcher.
58885ab413bSRiver Riddle ByteCodeLiveRange::Allocator allocator;
58985ab413bSRiver Riddle DenseMap<Value, ByteCodeLiveRange> valueDefRanges;
590abfd1a8bSRiver Riddle
591abfd1a8bSRiver Riddle // Assign the root operation being matched to slot 0.
592abfd1a8bSRiver Riddle BlockArgument rootOpArg = matcherFunc.getArgument(0);
593abfd1a8bSRiver Riddle valueToMemIndex[rootOpArg] = 0;
594abfd1a8bSRiver Riddle
595abfd1a8bSRiver Riddle // Walk each of the blocks, computing the def interval that the value is used.
596abfd1a8bSRiver Riddle Liveness matcherLiveness(matcherFunc);
5973eb1647aSStanislav Funiak matcherFunc->walk([&](Block *block) {
5983eb1647aSStanislav Funiak const LivenessBlockInfo *info = matcherLiveness.getLiveness(block);
599abfd1a8bSRiver Riddle assert(info && "expected liveness info for block");
600abfd1a8bSRiver Riddle auto processValue = [&](Value value, Operation *firstUseOrDef) {
601abfd1a8bSRiver Riddle // We don't need to process the root op argument, this value is always
602abfd1a8bSRiver Riddle // assigned to the first memory slot.
603abfd1a8bSRiver Riddle if (value == rootOpArg)
604abfd1a8bSRiver Riddle return;
605abfd1a8bSRiver Riddle
606abfd1a8bSRiver Riddle // Set indices for the range of this block that the value is used.
607abfd1a8bSRiver Riddle auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
6083eb1647aSStanislav Funiak defRangeIt->second.liveness->insert(
609b4130e9eSStanislav Funiak opToFirstIndex[firstUseOrDef],
610b4130e9eSStanislav Funiak opToLastIndex[info->getEndOperation(value, firstUseOrDef)],
611abfd1a8bSRiver Riddle /*dummyValue*/ 0);
61285ab413bSRiver Riddle
61385ab413bSRiver Riddle // Check to see if this value is a range type.
6145550c821STres Popp if (auto rangeTy = dyn_cast<pdl::RangeType>(value.getType())) {
61585ab413bSRiver Riddle Type eleType = rangeTy.getElementType();
6165550c821STres Popp if (isa<pdl::OperationType>(eleType))
6173eb1647aSStanislav Funiak defRangeIt->second.opRangeIndex = 0;
6185550c821STres Popp else if (isa<pdl::TypeType>(eleType))
61985ab413bSRiver Riddle defRangeIt->second.typeRangeIndex = 0;
6205550c821STres Popp else if (isa<pdl::ValueType>(eleType))
62185ab413bSRiver Riddle defRangeIt->second.valueRangeIndex = 0;
62285ab413bSRiver Riddle }
623abfd1a8bSRiver Riddle };
624abfd1a8bSRiver Riddle
625abfd1a8bSRiver Riddle // Process the live-ins of this block.
6263eb1647aSStanislav Funiak for (Value liveIn : info->in()) {
6273eb1647aSStanislav Funiak // Only process the value if it has been defined in the current region.
6283eb1647aSStanislav Funiak // Other values that span across pdl_interp.foreach will be added higher
6293eb1647aSStanislav Funiak // up. This ensures that the we keep them alive for the entire duration
6303eb1647aSStanislav Funiak // of the loop.
6313eb1647aSStanislav Funiak if (liveIn.getParentRegion() == block->getParent())
6323eb1647aSStanislav Funiak processValue(liveIn, &block->front());
6333eb1647aSStanislav Funiak }
6343eb1647aSStanislav Funiak
6353eb1647aSStanislav Funiak // Process the block arguments for the entry block (those are not live-in).
6363eb1647aSStanislav Funiak if (block->isEntryBlock()) {
6373eb1647aSStanislav Funiak for (Value argument : block->getArguments())
6383eb1647aSStanislav Funiak processValue(argument, &block->front());
6393eb1647aSStanislav Funiak }
640abfd1a8bSRiver Riddle
641abfd1a8bSRiver Riddle // Process any new defs within this block.
6423eb1647aSStanislav Funiak for (Operation &op : *block)
643abfd1a8bSRiver Riddle for (Value result : op.getResults())
644abfd1a8bSRiver Riddle processValue(result, &op);
6453eb1647aSStanislav Funiak });
646abfd1a8bSRiver Riddle
647abfd1a8bSRiver Riddle // Greedily allocate memory slots using the computed def live ranges.
64885ab413bSRiver Riddle std::vector<ByteCodeLiveRange> allocatedIndices;
6493eb1647aSStanislav Funiak
6503eb1647aSStanislav Funiak // The number of memory indices currently allocated (and its next value).
6513eb1647aSStanislav Funiak // Recall that the root gets allocated memory index 0.
6523eb1647aSStanislav Funiak ByteCodeField numIndices = 1;
6533eb1647aSStanislav Funiak
6543eb1647aSStanislav Funiak // The number of memory ranges of various types (and their next values).
6553eb1647aSStanislav Funiak ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
6563eb1647aSStanislav Funiak
657abfd1a8bSRiver Riddle for (auto &defIt : valueDefRanges) {
658abfd1a8bSRiver Riddle ByteCodeField &memIndex = valueToMemIndex[defIt.first];
65985ab413bSRiver Riddle ByteCodeLiveRange &defRange = defIt.second;
660abfd1a8bSRiver Riddle
661abfd1a8bSRiver Riddle // Try to allocate to an existing index.
662e4853be2SMehdi Amini for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) {
66385ab413bSRiver Riddle ByteCodeLiveRange &existingRange = existingIndexIt.value();
66485ab413bSRiver Riddle if (!defRange.overlaps(existingRange)) {
66585ab413bSRiver Riddle existingRange.unionWith(defRange);
666abfd1a8bSRiver Riddle memIndex = existingIndexIt.index() + 1;
66785ab413bSRiver Riddle
6683eb1647aSStanislav Funiak if (defRange.opRangeIndex) {
6693eb1647aSStanislav Funiak if (!existingRange.opRangeIndex)
6703eb1647aSStanislav Funiak existingRange.opRangeIndex = numOpRanges++;
6713eb1647aSStanislav Funiak valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
6723eb1647aSStanislav Funiak } else if (defRange.typeRangeIndex) {
67385ab413bSRiver Riddle if (!existingRange.typeRangeIndex)
67485ab413bSRiver Riddle existingRange.typeRangeIndex = numTypeRanges++;
67585ab413bSRiver Riddle valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
67685ab413bSRiver Riddle } else if (defRange.valueRangeIndex) {
67785ab413bSRiver Riddle if (!existingRange.valueRangeIndex)
67885ab413bSRiver Riddle existingRange.valueRangeIndex = numValueRanges++;
67985ab413bSRiver Riddle valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
68085ab413bSRiver Riddle }
68185ab413bSRiver Riddle break;
68285ab413bSRiver Riddle }
683abfd1a8bSRiver Riddle }
684abfd1a8bSRiver Riddle
685abfd1a8bSRiver Riddle // If no existing index could be used, add a new one.
686abfd1a8bSRiver Riddle if (memIndex == 0) {
687abfd1a8bSRiver Riddle allocatedIndices.emplace_back(allocator);
68885ab413bSRiver Riddle ByteCodeLiveRange &newRange = allocatedIndices.back();
68985ab413bSRiver Riddle newRange.unionWith(defRange);
69085ab413bSRiver Riddle
6913eb1647aSStanislav Funiak // Allocate an index for op/type/value ranges.
6923eb1647aSStanislav Funiak if (defRange.opRangeIndex) {
6933eb1647aSStanislav Funiak newRange.opRangeIndex = numOpRanges;
6943eb1647aSStanislav Funiak valueToRangeIndex[defIt.first] = numOpRanges++;
6953eb1647aSStanislav Funiak } else if (defRange.typeRangeIndex) {
69685ab413bSRiver Riddle newRange.typeRangeIndex = numTypeRanges;
69785ab413bSRiver Riddle valueToRangeIndex[defIt.first] = numTypeRanges++;
69885ab413bSRiver Riddle } else if (defRange.valueRangeIndex) {
69985ab413bSRiver Riddle newRange.valueRangeIndex = numValueRanges;
70085ab413bSRiver Riddle valueToRangeIndex[defIt.first] = numValueRanges++;
70185ab413bSRiver Riddle }
70285ab413bSRiver Riddle
703abfd1a8bSRiver Riddle memIndex = allocatedIndices.size();
70485ab413bSRiver Riddle ++numIndices;
705abfd1a8bSRiver Riddle }
706abfd1a8bSRiver Riddle }
707abfd1a8bSRiver Riddle
7083eb1647aSStanislav Funiak // Print the index usage and ensure that we did not run out of index space.
7093eb1647aSStanislav Funiak LLVM_DEBUG({
7103eb1647aSStanislav Funiak llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices "
7113eb1647aSStanislav Funiak << "(down from initial " << valueDefRanges.size() << ").\n";
7123eb1647aSStanislav Funiak });
7133eb1647aSStanislav Funiak assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
7143eb1647aSStanislav Funiak "Ran out of memory for allocated indices");
7153eb1647aSStanislav Funiak
716abfd1a8bSRiver Riddle // Update the max number of indices.
71785ab413bSRiver Riddle if (numIndices > maxValueMemoryIndex)
71885ab413bSRiver Riddle maxValueMemoryIndex = numIndices;
7193eb1647aSStanislav Funiak if (numOpRanges > maxOpRangeMemoryIndex)
7203eb1647aSStanislav Funiak maxOpRangeMemoryIndex = numOpRanges;
72185ab413bSRiver Riddle if (numTypeRanges > maxTypeRangeMemoryIndex)
72285ab413bSRiver Riddle maxTypeRangeMemoryIndex = numTypeRanges;
72385ab413bSRiver Riddle if (numValueRanges > maxValueRangeMemoryIndex)
72485ab413bSRiver Riddle maxValueRangeMemoryIndex = numValueRanges;
725abfd1a8bSRiver Riddle }
726abfd1a8bSRiver Riddle
generate(Region * region,ByteCodeWriter & writer)7273eb1647aSStanislav Funiak void Generator::generate(Region *region, ByteCodeWriter &writer) {
7283eb1647aSStanislav Funiak llvm::ReversePostOrderTraversal<Region *> rpot(region);
7293eb1647aSStanislav Funiak for (Block *block : rpot) {
7303eb1647aSStanislav Funiak // Keep track of where this block begins within the matcher function.
7313eb1647aSStanislav Funiak blockToAddr.try_emplace(block, matcherByteCode.size());
7323eb1647aSStanislav Funiak for (Operation &op : *block)
7333eb1647aSStanislav Funiak generate(&op, writer);
7343eb1647aSStanislav Funiak }
7353eb1647aSStanislav Funiak }
7363eb1647aSStanislav Funiak
generate(Operation * op,ByteCodeWriter & writer)737abfd1a8bSRiver Riddle void Generator::generate(Operation *op, ByteCodeWriter &writer) {
738d35f1190SStanislav Funiak LLVM_DEBUG({
739d35f1190SStanislav Funiak // The following list must contain all the operations that do not
740d35f1190SStanislav Funiak // produce any bytecode.
7413c752289SRiver Riddle if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
742d35f1190SStanislav Funiak writer.appendInline(op->getLoc());
743d35f1190SStanislav Funiak });
744abfd1a8bSRiver Riddle TypeSwitch<Operation *>(op)
745abfd1a8bSRiver Riddle .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
746abfd1a8bSRiver Riddle pdl_interp::AreEqualOp, pdl_interp::BranchOp,
747abfd1a8bSRiver Riddle pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
748abfd1a8bSRiver Riddle pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
74985ab413bSRiver Riddle pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
7503eb1647aSStanislav Funiak pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
751ce57789dSRiver Riddle pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp,
752ce57789dSRiver Riddle pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
753ce57789dSRiver Riddle pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
7543eb1647aSStanislav Funiak pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
7553eb1647aSStanislav Funiak pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
7563eb1647aSStanislav Funiak pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
7573eb1647aSStanislav Funiak pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
7583eb1647aSStanislav Funiak pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
7593c752289SRiver Riddle pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
7603c752289SRiver Riddle pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
7613c752289SRiver Riddle pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
7623c752289SRiver Riddle pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
7633c752289SRiver Riddle pdl_interp::SwitchResultCountOp>(
764abfd1a8bSRiver Riddle [&](auto interpOp) { this->generate(interpOp, writer); })
765abfd1a8bSRiver Riddle .Default([](Operation *) {
766abfd1a8bSRiver Riddle llvm_unreachable("unknown `pdl_interp` operation");
767abfd1a8bSRiver Riddle });
768abfd1a8bSRiver Riddle }
769abfd1a8bSRiver Riddle
generate(pdl_interp::ApplyConstraintOp op,ByteCodeWriter & writer)770abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::ApplyConstraintOp op,
771abfd1a8bSRiver Riddle ByteCodeWriter &writer) {
7728ec28af8SMatthias Gehre // Constraints that should return a value have to be registered as rewrites.
7738ec28af8SMatthias Gehre // If a constraint and a rewrite of similar name are registered the
7748ec28af8SMatthias Gehre // constraint takes precedence
7759595f356SRiver Riddle writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
7763c405c3bSRiver Riddle writer.appendPDLValueList(op.getArgs());
7776d2b2b8eSMartin Lücke writer.append(ByteCodeField(op.getIsNegated()));
7788ec28af8SMatthias Gehre ResultRange results = op.getResults();
7798ec28af8SMatthias Gehre writer.append(ByteCodeField(results.size()));
7808ec28af8SMatthias Gehre for (Value result : results) {
7818ec28af8SMatthias Gehre // We record the expected kind of the result, so that we can provide extra
7828ec28af8SMatthias Gehre // verification of the native rewrite function and handle the failure case
7838ec28af8SMatthias Gehre // of constraints accordingly.
7848ec28af8SMatthias Gehre writer.appendPDLValueKind(result);
7858ec28af8SMatthias Gehre
7868ec28af8SMatthias Gehre // Range results also need to append the range storage index.
7878ec28af8SMatthias Gehre if (isa<pdl::RangeType>(result.getType()))
7888ec28af8SMatthias Gehre writer.append(getRangeStorageIndex(result));
7898ec28af8SMatthias Gehre writer.append(result);
7908ec28af8SMatthias Gehre }
791abfd1a8bSRiver Riddle writer.append(op.getSuccessors());
792abfd1a8bSRiver Riddle }
generate(pdl_interp::ApplyRewriteOp op,ByteCodeWriter & writer)793abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::ApplyRewriteOp op,
794abfd1a8bSRiver Riddle ByteCodeWriter &writer) {
7953c405c3bSRiver Riddle assert(externalRewriterToMemIndex.count(op.getName()) &&
796abfd1a8bSRiver Riddle "expected index for rewrite function");
7979595f356SRiver Riddle writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
7983c405c3bSRiver Riddle writer.appendPDLValueList(op.getArgs());
79902c4c0d5SRiver Riddle
8003c405c3bSRiver Riddle ResultRange results = op.getResults();
80185ab413bSRiver Riddle writer.append(ByteCodeField(results.size()));
80285ab413bSRiver Riddle for (Value result : results) {
8038ec28af8SMatthias Gehre // We record the expected kind of the result, so that we
80485ab413bSRiver Riddle // can provide extra verification of the native rewrite function.
80585ab413bSRiver Riddle writer.appendPDLValueKind(result);
80685ab413bSRiver Riddle
80785ab413bSRiver Riddle // Range results also need to append the range storage index.
8085550c821STres Popp if (isa<pdl::RangeType>(result.getType()))
80985ab413bSRiver Riddle writer.append(getRangeStorageIndex(result));
81002c4c0d5SRiver Riddle writer.append(result);
811abfd1a8bSRiver Riddle }
81285ab413bSRiver Riddle }
generate(pdl_interp::AreEqualOp op,ByteCodeWriter & writer)813abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
8143c405c3bSRiver Riddle Value lhs = op.getLhs();
8155550c821STres Popp if (isa<pdl::RangeType>(lhs.getType())) {
81685ab413bSRiver Riddle writer.append(OpCode::AreRangesEqual);
81785ab413bSRiver Riddle writer.appendPDLValueKind(lhs);
8183c405c3bSRiver Riddle writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
81985ab413bSRiver Riddle return;
82085ab413bSRiver Riddle }
82185ab413bSRiver Riddle
8223c405c3bSRiver Riddle writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
823abfd1a8bSRiver Riddle }
generate(pdl_interp::BranchOp op,ByteCodeWriter & writer)824abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
8258affe881SRiver Riddle writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
826abfd1a8bSRiver Riddle }
generate(pdl_interp::CheckAttributeOp op,ByteCodeWriter & writer)827abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckAttributeOp op,
828abfd1a8bSRiver Riddle ByteCodeWriter &writer) {
8293c405c3bSRiver Riddle writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
830abfd1a8bSRiver Riddle op.getSuccessors());
831abfd1a8bSRiver Riddle }
generate(pdl_interp::CheckOperandCountOp op,ByteCodeWriter & writer)832abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckOperandCountOp op,
833abfd1a8bSRiver Riddle ByteCodeWriter &writer) {
8343c405c3bSRiver Riddle writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
8353c405c3bSRiver Riddle static_cast<ByteCodeField>(op.getCompareAtLeast()),
836abfd1a8bSRiver Riddle op.getSuccessors());
837abfd1a8bSRiver Riddle }
generate(pdl_interp::CheckOperationNameOp op,ByteCodeWriter & writer)838abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckOperationNameOp op,
839abfd1a8bSRiver Riddle ByteCodeWriter &writer) {
8403c405c3bSRiver Riddle writer.append(OpCode::CheckOperationName, op.getInputOp(),
8413c405c3bSRiver Riddle OperationName(op.getName(), ctx), op.getSuccessors());
842abfd1a8bSRiver Riddle }
generate(pdl_interp::CheckResultCountOp op,ByteCodeWriter & writer)843abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckResultCountOp op,
844abfd1a8bSRiver Riddle ByteCodeWriter &writer) {
8453c405c3bSRiver Riddle writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
8463c405c3bSRiver Riddle static_cast<ByteCodeField>(op.getCompareAtLeast()),
847abfd1a8bSRiver Riddle op.getSuccessors());
848abfd1a8bSRiver Riddle }
generate(pdl_interp::CheckTypeOp op,ByteCodeWriter & writer)849abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
8503c405c3bSRiver Riddle writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
8513c405c3bSRiver Riddle op.getSuccessors());
852abfd1a8bSRiver Riddle }
generate(pdl_interp::CheckTypesOp op,ByteCodeWriter & writer)85385ab413bSRiver Riddle void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
8543c405c3bSRiver Riddle writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
8553c405c3bSRiver Riddle op.getSuccessors());
85685ab413bSRiver Riddle }
generate(pdl_interp::ContinueOp op,ByteCodeWriter & writer)8573eb1647aSStanislav Funiak void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
8583eb1647aSStanislav Funiak assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
8593eb1647aSStanislav Funiak writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
8603eb1647aSStanislav Funiak }
generate(pdl_interp::CreateAttributeOp op,ByteCodeWriter & writer)861abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CreateAttributeOp op,
862abfd1a8bSRiver Riddle ByteCodeWriter &writer) {
863abfd1a8bSRiver Riddle // Simply repoint the memory index of the result to the constant.
8643c405c3bSRiver Riddle getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
865abfd1a8bSRiver Riddle }
generate(pdl_interp::CreateOperationOp op,ByteCodeWriter & writer)866abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CreateOperationOp op,
867abfd1a8bSRiver Riddle ByteCodeWriter &writer) {
8683c405c3bSRiver Riddle writer.append(OpCode::CreateOperation, op.getResultOp(),
8693c405c3bSRiver Riddle OperationName(op.getName(), ctx));
8703c405c3bSRiver Riddle writer.appendPDLValueList(op.getInputOperands());
871abfd1a8bSRiver Riddle
872abfd1a8bSRiver Riddle // Add the attributes.
8733c405c3bSRiver Riddle OperandRange attributes = op.getInputAttributes();
874abfd1a8bSRiver Riddle writer.append(static_cast<ByteCodeField>(attributes.size()));
8753c405c3bSRiver Riddle for (auto it : llvm::zip(op.getInputAttributeNames(), attributes))
876195730a6SRiver Riddle writer.append(std::get<0>(it), std::get<1>(it));
8773c752289SRiver Riddle
8783c752289SRiver Riddle // Add the result types. If the operation has inferred results, we use a
8793c752289SRiver Riddle // marker "size" value. Otherwise, we add the list of explicit result types.
8803c752289SRiver Riddle if (op.getInferredResultTypes())
8813c752289SRiver Riddle writer.append(kInferTypesMarker);
8823c752289SRiver Riddle else
8833c405c3bSRiver Riddle writer.appendPDLValueList(op.getInputResultTypes());
884abfd1a8bSRiver Riddle }
generate(pdl_interp::CreateRangeOp op,ByteCodeWriter & writer)885ce57789dSRiver Riddle void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) {
886ce57789dSRiver Riddle // Append the correct opcode for the range type.
887ce57789dSRiver Riddle TypeSwitch<Type>(op.getType().getElementType())
888ce57789dSRiver Riddle .Case(
889ce57789dSRiver Riddle [&](pdl::TypeType) { writer.append(OpCode::CreateDynamicTypeRange); })
890ce57789dSRiver Riddle .Case([&](pdl::ValueType) {
891ce57789dSRiver Riddle writer.append(OpCode::CreateDynamicValueRange);
892ce57789dSRiver Riddle });
893ce57789dSRiver Riddle
894ce57789dSRiver Riddle writer.append(op.getResult(), getRangeStorageIndex(op.getResult()));
895ce57789dSRiver Riddle writer.appendPDLValueList(op->getOperands());
896ce57789dSRiver Riddle }
generate(pdl_interp::CreateTypeOp op,ByteCodeWriter & writer)897abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
898abfd1a8bSRiver Riddle // Simply repoint the memory index of the result to the constant.
8993c405c3bSRiver Riddle getMemIndex(op.getResult()) = getMemIndex(op.getValue());
900abfd1a8bSRiver Riddle }
generate(pdl_interp::CreateTypesOp op,ByteCodeWriter & writer)90185ab413bSRiver Riddle void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
902ce57789dSRiver Riddle writer.append(OpCode::CreateConstantTypeRange, op.getResult(),
9033c405c3bSRiver Riddle getRangeStorageIndex(op.getResult()), op.getValue());
90485ab413bSRiver Riddle }
generate(pdl_interp::EraseOp op,ByteCodeWriter & writer)905abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
9063c405c3bSRiver Riddle writer.append(OpCode::EraseOp, op.getInputOp());
907abfd1a8bSRiver Riddle }
generate(pdl_interp::ExtractOp op,ByteCodeWriter & writer)9083eb1647aSStanislav Funiak void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
9093eb1647aSStanislav Funiak OpCode opCode =
9103c405c3bSRiver Riddle TypeSwitch<Type, OpCode>(op.getResult().getType())
9113eb1647aSStanislav Funiak .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
9123eb1647aSStanislav Funiak .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
9133eb1647aSStanislav Funiak .Case([](pdl::TypeType) { return OpCode::ExtractType; })
9143eb1647aSStanislav Funiak .Default([](Type) -> OpCode {
9153eb1647aSStanislav Funiak llvm_unreachable("unsupported element type");
9163eb1647aSStanislav Funiak });
9173c405c3bSRiver Riddle writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
9183eb1647aSStanislav Funiak }
generate(pdl_interp::FinalizeOp op,ByteCodeWriter & writer)919abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
920abfd1a8bSRiver Riddle writer.append(OpCode::Finalize);
921abfd1a8bSRiver Riddle }
generate(pdl_interp::ForEachOp op,ByteCodeWriter & writer)9223eb1647aSStanislav Funiak void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
9233eb1647aSStanislav Funiak BlockArgument arg = op.getLoopVariable();
9243c405c3bSRiver Riddle writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
9253eb1647aSStanislav Funiak writer.appendPDLValueKind(arg.getType());
9263c405c3bSRiver Riddle writer.append(curLoopLevel, op.getSuccessor());
9273eb1647aSStanislav Funiak ++curLoopLevel;
9283eb1647aSStanislav Funiak if (curLoopLevel > maxLoopLevel)
9293eb1647aSStanislav Funiak maxLoopLevel = curLoopLevel;
9303c405c3bSRiver Riddle generate(&op.getRegion(), writer);
9313eb1647aSStanislav Funiak --curLoopLevel;
9323eb1647aSStanislav Funiak }
generate(pdl_interp::GetAttributeOp op,ByteCodeWriter & writer)933abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetAttributeOp op,
934abfd1a8bSRiver Riddle ByteCodeWriter &writer) {
9353c405c3bSRiver Riddle writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
9363c405c3bSRiver Riddle op.getNameAttr());
937abfd1a8bSRiver Riddle }
generate(pdl_interp::GetAttributeTypeOp op,ByteCodeWriter & writer)938abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetAttributeTypeOp op,
939abfd1a8bSRiver Riddle ByteCodeWriter &writer) {
9403c405c3bSRiver Riddle writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
941abfd1a8bSRiver Riddle }
generate(pdl_interp::GetDefiningOpOp op,ByteCodeWriter & writer)942abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetDefiningOpOp op,
943abfd1a8bSRiver Riddle ByteCodeWriter &writer) {
9443c405c3bSRiver Riddle writer.append(OpCode::GetDefiningOp, op.getInputOp());
9453c405c3bSRiver Riddle writer.appendPDLValue(op.getValue());
946abfd1a8bSRiver Riddle }
generate(pdl_interp::GetOperandOp op,ByteCodeWriter & writer)947abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
9483c405c3bSRiver Riddle uint32_t index = op.getIndex();
949abfd1a8bSRiver Riddle if (index < 4)
950abfd1a8bSRiver Riddle writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
951abfd1a8bSRiver Riddle else
952abfd1a8bSRiver Riddle writer.append(OpCode::GetOperandN, index);
9533c405c3bSRiver Riddle writer.append(op.getInputOp(), op.getValue());
954abfd1a8bSRiver Riddle }
generate(pdl_interp::GetOperandsOp op,ByteCodeWriter & writer)95585ab413bSRiver Riddle void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
9563c405c3bSRiver Riddle Value result = op.getValue();
95722426110SRamkumar Ramachandra std::optional<uint32_t> index = op.getIndex();
95885ab413bSRiver Riddle writer.append(OpCode::GetOperands,
95930c67587SKazu Hirata index.value_or(std::numeric_limits<uint32_t>::max()),
9603c405c3bSRiver Riddle op.getInputOp());
9615550c821STres Popp if (isa<pdl::RangeType>(result.getType()))
96285ab413bSRiver Riddle writer.append(getRangeStorageIndex(result));
96385ab413bSRiver Riddle else
96485ab413bSRiver Riddle writer.append(std::numeric_limits<ByteCodeField>::max());
96585ab413bSRiver Riddle writer.append(result);
96685ab413bSRiver Riddle }
generate(pdl_interp::GetResultOp op,ByteCodeWriter & writer)967abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
9683c405c3bSRiver Riddle uint32_t index = op.getIndex();
969abfd1a8bSRiver Riddle if (index < 4)
970abfd1a8bSRiver Riddle writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
971abfd1a8bSRiver Riddle else
972abfd1a8bSRiver Riddle writer.append(OpCode::GetResultN, index);
9733c405c3bSRiver Riddle writer.append(op.getInputOp(), op.getValue());
974abfd1a8bSRiver Riddle }
generate(pdl_interp::GetResultsOp op,ByteCodeWriter & writer)97585ab413bSRiver Riddle void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
9763c405c3bSRiver Riddle Value result = op.getValue();
97722426110SRamkumar Ramachandra std::optional<uint32_t> index = op.getIndex();
97885ab413bSRiver Riddle writer.append(OpCode::GetResults,
97930c67587SKazu Hirata index.value_or(std::numeric_limits<uint32_t>::max()),
9803c405c3bSRiver Riddle op.getInputOp());
9815550c821STres Popp if (isa<pdl::RangeType>(result.getType()))
98285ab413bSRiver Riddle writer.append(getRangeStorageIndex(result));
98385ab413bSRiver Riddle else
98485ab413bSRiver Riddle writer.append(std::numeric_limits<ByteCodeField>::max());
98585ab413bSRiver Riddle writer.append(result);
98685ab413bSRiver Riddle }
generate(pdl_interp::GetUsersOp op,ByteCodeWriter & writer)9873eb1647aSStanislav Funiak void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
9883c405c3bSRiver Riddle Value operations = op.getOperations();
9893eb1647aSStanislav Funiak ByteCodeField rangeIndex = getRangeStorageIndex(operations);
9903eb1647aSStanislav Funiak writer.append(OpCode::GetUsers, operations, rangeIndex);
9913c405c3bSRiver Riddle writer.appendPDLValue(op.getValue());
9923eb1647aSStanislav Funiak }
generate(pdl_interp::GetValueTypeOp op,ByteCodeWriter & writer)993abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetValueTypeOp op,
994abfd1a8bSRiver Riddle ByteCodeWriter &writer) {
9955550c821STres Popp if (isa<pdl::RangeType>(op.getType())) {
9963c405c3bSRiver Riddle Value result = op.getResult();
99785ab413bSRiver Riddle writer.append(OpCode::GetValueRangeTypes, result,
9983c405c3bSRiver Riddle getRangeStorageIndex(result), op.getValue());
99985ab413bSRiver Riddle } else {
10003c405c3bSRiver Riddle writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
1001abfd1a8bSRiver Riddle }
100285ab413bSRiver Riddle }
generate(pdl_interp::IsNotNullOp op,ByteCodeWriter & writer)1003abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
10043c405c3bSRiver Riddle writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
1005abfd1a8bSRiver Riddle }
generate(pdl_interp::RecordMatchOp op,ByteCodeWriter & writer)1006abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
1007abfd1a8bSRiver Riddle ByteCodeField patternIndex = patterns.size();
1008abfd1a8bSRiver Riddle patterns.emplace_back(PDLByteCodePattern::create(
10098c66344eSRiver Riddle op, configMap.lookup(op),
10108c66344eSRiver Riddle rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
10118affe881SRiver Riddle writer.append(OpCode::RecordMatch, patternIndex,
10123c405c3bSRiver Riddle SuccessorRange(op.getOperation()), op.getMatchedOps());
10133c405c3bSRiver Riddle writer.appendPDLValueList(op.getInputs());
1014abfd1a8bSRiver Riddle }
generate(pdl_interp::ReplaceOp op,ByteCodeWriter & writer)1015abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
10163c405c3bSRiver Riddle writer.append(OpCode::ReplaceOp, op.getInputOp());
10173c405c3bSRiver Riddle writer.appendPDLValueList(op.getReplValues());
1018abfd1a8bSRiver Riddle }
generate(pdl_interp::SwitchAttributeOp op,ByteCodeWriter & writer)1019abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchAttributeOp op,
1020abfd1a8bSRiver Riddle ByteCodeWriter &writer) {
10213c405c3bSRiver Riddle writer.append(OpCode::SwitchAttribute, op.getAttribute(),
10223c405c3bSRiver Riddle op.getCaseValuesAttr(), op.getSuccessors());
1023abfd1a8bSRiver Riddle }
generate(pdl_interp::SwitchOperandCountOp op,ByteCodeWriter & writer)1024abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchOperandCountOp op,
1025abfd1a8bSRiver Riddle ByteCodeWriter &writer) {
10263c405c3bSRiver Riddle writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
10273c405c3bSRiver Riddle op.getCaseValuesAttr(), op.getSuccessors());
1028abfd1a8bSRiver Riddle }
generate(pdl_interp::SwitchOperationNameOp op,ByteCodeWriter & writer)1029abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchOperationNameOp op,
1030abfd1a8bSRiver Riddle ByteCodeWriter &writer) {
10313c405c3bSRiver Riddle auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) {
10325550c821STres Popp return OperationName(cast<StringAttr>(attr).getValue(), ctx);
1033abfd1a8bSRiver Riddle });
10343c405c3bSRiver Riddle writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
1035abfd1a8bSRiver Riddle op.getSuccessors());
1036abfd1a8bSRiver Riddle }
generate(pdl_interp::SwitchResultCountOp op,ByteCodeWriter & writer)1037abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchResultCountOp op,
1038abfd1a8bSRiver Riddle ByteCodeWriter &writer) {
10393c405c3bSRiver Riddle writer.append(OpCode::SwitchResultCount, op.getInputOp(),
10403c405c3bSRiver Riddle op.getCaseValuesAttr(), op.getSuccessors());
1041abfd1a8bSRiver Riddle }
generate(pdl_interp::SwitchTypeOp op,ByteCodeWriter & writer)1042abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
10433c405c3bSRiver Riddle writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
1044abfd1a8bSRiver Riddle op.getSuccessors());
1045abfd1a8bSRiver Riddle }
generate(pdl_interp::SwitchTypesOp op,ByteCodeWriter & writer)104685ab413bSRiver Riddle void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
10473c405c3bSRiver Riddle writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
104885ab413bSRiver Riddle op.getSuccessors());
104985ab413bSRiver Riddle }
1050abfd1a8bSRiver Riddle
1051abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
1052abfd1a8bSRiver Riddle // PDLByteCode
1053abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
1054abfd1a8bSRiver Riddle
PDLByteCode(ModuleOp module,SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,const DenseMap<Operation *,PDLPatternConfigSet * > & configMap,llvm::StringMap<PDLConstraintFunction> constraintFns,llvm::StringMap<PDLRewriteFunction> rewriteFns)10558c66344eSRiver Riddle PDLByteCode::PDLByteCode(
10568c66344eSRiver Riddle ModuleOp module, SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
10578c66344eSRiver Riddle const DenseMap<Operation *, PDLPatternConfigSet *> &configMap,
1058abfd1a8bSRiver Riddle llvm::StringMap<PDLConstraintFunction> constraintFns,
10598c66344eSRiver Riddle llvm::StringMap<PDLRewriteFunction> rewriteFns)
10608c66344eSRiver Riddle : configs(std::move(configs)) {
1061abfd1a8bSRiver Riddle Generator generator(module.getContext(), uniquedData, matcherByteCode,
1062abfd1a8bSRiver Riddle rewriterByteCode, patterns, maxValueMemoryIndex,
10633eb1647aSStanislav Funiak maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
10648c66344eSRiver Riddle maxLoopLevel, constraintFns, rewriteFns, configMap);
1065abfd1a8bSRiver Riddle generator.generate(module);
1066abfd1a8bSRiver Riddle
1067abfd1a8bSRiver Riddle // Initialize the external functions.
1068abfd1a8bSRiver Riddle for (auto &it : constraintFns)
1069abfd1a8bSRiver Riddle constraintFunctions.push_back(std::move(it.second));
1070abfd1a8bSRiver Riddle for (auto &it : rewriteFns)
1071abfd1a8bSRiver Riddle rewriteFunctions.push_back(std::move(it.second));
1072abfd1a8bSRiver Riddle }
1073abfd1a8bSRiver Riddle
1074abfd1a8bSRiver Riddle /// Initialize the given state such that it can be used to execute the current
1075abfd1a8bSRiver Riddle /// bytecode.
initializeMutableState(PDLByteCodeMutableState & state) const1076abfd1a8bSRiver Riddle void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
1077abfd1a8bSRiver Riddle state.memory.resize(maxValueMemoryIndex, nullptr);
10783eb1647aSStanislav Funiak state.opRangeMemory.resize(maxOpRangeCount);
107985ab413bSRiver Riddle state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
108085ab413bSRiver Riddle state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
10813eb1647aSStanislav Funiak state.loopIndex.resize(maxLoopLevel, 0);
1082abfd1a8bSRiver Riddle state.currentPatternBenefits.reserve(patterns.size());
1083abfd1a8bSRiver Riddle for (const PDLByteCodePattern &pattern : patterns)
1084abfd1a8bSRiver Riddle state.currentPatternBenefits.push_back(pattern.getBenefit());
1085abfd1a8bSRiver Riddle }
1086abfd1a8bSRiver Riddle
//===----------------------------------------------------------------------===//
// ByteCode Execution
//===----------------------------------------------------------------------===//
1089abfd1a8bSRiver Riddle
1090abfd1a8bSRiver Riddle namespace {
10918ec28af8SMatthias Gehre /// This class is an instantiation of the PDLResultList that provides access to
10928ec28af8SMatthias Gehre /// the returned results. This API is not on `PDLResultList` to avoid
10938ec28af8SMatthias Gehre /// overexposing access to information specific solely to the ByteCode.
10948ec28af8SMatthias Gehre class ByteCodeRewriteResultList : public PDLResultList {
10958ec28af8SMatthias Gehre public:
ByteCodeRewriteResultList(unsigned maxNumResults)10968ec28af8SMatthias Gehre ByteCodeRewriteResultList(unsigned maxNumResults)
10978ec28af8SMatthias Gehre : PDLResultList(maxNumResults) {}
10988ec28af8SMatthias Gehre
10998ec28af8SMatthias Gehre /// Return the list of PDL results.
getResults()11008ec28af8SMatthias Gehre MutableArrayRef<PDLValue> getResults() { return results; }
11018ec28af8SMatthias Gehre
11028ec28af8SMatthias Gehre /// Return the type ranges allocated by this list.
getAllocatedTypeRanges()11038ec28af8SMatthias Gehre MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
11048ec28af8SMatthias Gehre return allocatedTypeRanges;
11058ec28af8SMatthias Gehre }
11068ec28af8SMatthias Gehre
11078ec28af8SMatthias Gehre /// Return the value ranges allocated by this list.
getAllocatedValueRanges()11088ec28af8SMatthias Gehre MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
11098ec28af8SMatthias Gehre return allocatedValueRanges;
11108ec28af8SMatthias Gehre }
11118ec28af8SMatthias Gehre };
11128ec28af8SMatthias Gehre
1113abfd1a8bSRiver Riddle /// This class provides support for executing a bytecode stream.
1114abfd1a8bSRiver Riddle class ByteCodeExecutor {
1115abfd1a8bSRiver Riddle public:
ByteCodeExecutor(const ByteCodeField * curCodeIt,MutableArrayRef<const void * > memory,MutableArrayRef<llvm::OwningArrayRef<Operation * >> opRangeMemory,MutableArrayRef<TypeRange> typeRangeMemory,std::vector<llvm::OwningArrayRef<Type>> & allocatedTypeRangeMemory,MutableArrayRef<ValueRange> valueRangeMemory,std::vector<llvm::OwningArrayRef<Value>> & allocatedValueRangeMemory,MutableArrayRef<unsigned> loopIndex,ArrayRef<const void * > uniquedMemory,ArrayRef<ByteCodeField> code,ArrayRef<PatternBenefit> currentPatternBenefits,ArrayRef<PDLByteCodePattern> patterns,ArrayRef<PDLConstraintFunction> constraintFunctions,ArrayRef<PDLRewriteFunction> rewriteFunctions)111685ab413bSRiver Riddle ByteCodeExecutor(
111785ab413bSRiver Riddle const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
11183eb1647aSStanislav Funiak MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory,
111985ab413bSRiver Riddle MutableArrayRef<TypeRange> typeRangeMemory,
112085ab413bSRiver Riddle std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
112185ab413bSRiver Riddle MutableArrayRef<ValueRange> valueRangeMemory,
112285ab413bSRiver Riddle std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
11233eb1647aSStanislav Funiak MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory,
11243eb1647aSStanislav Funiak ArrayRef<ByteCodeField> code,
1125abfd1a8bSRiver Riddle ArrayRef<PatternBenefit> currentPatternBenefits,
1126abfd1a8bSRiver Riddle ArrayRef<PDLByteCodePattern> patterns,
1127abfd1a8bSRiver Riddle ArrayRef<PDLConstraintFunction> constraintFunctions,
1128abfd1a8bSRiver Riddle ArrayRef<PDLRewriteFunction> rewriteFunctions)
11293eb1647aSStanislav Funiak : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
11303eb1647aSStanislav Funiak typeRangeMemory(typeRangeMemory),
113185ab413bSRiver Riddle allocatedTypeRangeMemory(allocatedTypeRangeMemory),
113285ab413bSRiver Riddle valueRangeMemory(valueRangeMemory),
113385ab413bSRiver Riddle allocatedValueRangeMemory(allocatedValueRangeMemory),
11343eb1647aSStanislav Funiak loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
113585ab413bSRiver Riddle currentPatternBenefits(currentPatternBenefits), patterns(patterns),
113685ab413bSRiver Riddle constraintFunctions(constraintFunctions),
113702c4c0d5SRiver Riddle rewriteFunctions(rewriteFunctions) {}
1138abfd1a8bSRiver Riddle
1139abfd1a8bSRiver Riddle /// Start executing the code at the current bytecode index. `matches` is an
1140abfd1a8bSRiver Riddle /// optional field provided when this function is executed in a matching
1141abfd1a8bSRiver Riddle /// context.
11428c66344eSRiver Riddle LogicalResult
11438c66344eSRiver Riddle execute(PatternRewriter &rewriter,
1144abfd1a8bSRiver Riddle SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
11450a81ace0SKazu Hirata std::optional<Location> mainRewriteLoc = {});
1146abfd1a8bSRiver Riddle
1147abfd1a8bSRiver Riddle private:
1148154cabe7SRiver Riddle /// Internal implementation of executing each of the bytecode commands.
1149154cabe7SRiver Riddle void executeApplyConstraint(PatternRewriter &rewriter);
11508c66344eSRiver Riddle LogicalResult executeApplyRewrite(PatternRewriter &rewriter);
1151154cabe7SRiver Riddle void executeAreEqual();
115285ab413bSRiver Riddle void executeAreRangesEqual();
1153154cabe7SRiver Riddle void executeBranch();
1154154cabe7SRiver Riddle void executeCheckOperandCount();
1155154cabe7SRiver Riddle void executeCheckOperationName();
1156154cabe7SRiver Riddle void executeCheckResultCount();
115785ab413bSRiver Riddle void executeCheckTypes();
11583eb1647aSStanislav Funiak void executeContinue();
1159ce57789dSRiver Riddle void executeCreateConstantTypeRange();
1160154cabe7SRiver Riddle void executeCreateOperation(PatternRewriter &rewriter,
1161154cabe7SRiver Riddle Location mainRewriteLoc);
1162ce57789dSRiver Riddle template <typename T>
1163ce57789dSRiver Riddle void executeDynamicCreateRange(StringRef type);
1164154cabe7SRiver Riddle void executeEraseOp(PatternRewriter &rewriter);
11653eb1647aSStanislav Funiak template <typename T, typename Range, PDLValue::Kind kind>
11663eb1647aSStanislav Funiak void executeExtract();
11673eb1647aSStanislav Funiak void executeFinalize();
11683eb1647aSStanislav Funiak void executeForEach();
1169154cabe7SRiver Riddle void executeGetAttribute();
1170154cabe7SRiver Riddle void executeGetAttributeType();
1171154cabe7SRiver Riddle void executeGetDefiningOp();
1172154cabe7SRiver Riddle void executeGetOperand(unsigned index);
117385ab413bSRiver Riddle void executeGetOperands();
1174154cabe7SRiver Riddle void executeGetResult(unsigned index);
117585ab413bSRiver Riddle void executeGetResults();
11763eb1647aSStanislav Funiak void executeGetUsers();
1177154cabe7SRiver Riddle void executeGetValueType();
117885ab413bSRiver Riddle void executeGetValueRangeTypes();
1179154cabe7SRiver Riddle void executeIsNotNull();
1180154cabe7SRiver Riddle void executeRecordMatch(PatternRewriter &rewriter,
1181154cabe7SRiver Riddle SmallVectorImpl<PDLByteCode::MatchResult> &matches);
1182154cabe7SRiver Riddle void executeReplaceOp(PatternRewriter &rewriter);
1183154cabe7SRiver Riddle void executeSwitchAttribute();
1184154cabe7SRiver Riddle void executeSwitchOperandCount();
1185154cabe7SRiver Riddle void executeSwitchOperationName();
1186154cabe7SRiver Riddle void executeSwitchResultCount();
1187154cabe7SRiver Riddle void executeSwitchType();
118885ab413bSRiver Riddle void executeSwitchTypes();
11898ec28af8SMatthias Gehre void processNativeFunResults(ByteCodeRewriteResultList &results,
11908ec28af8SMatthias Gehre unsigned numResults,
11918ec28af8SMatthias Gehre LogicalResult &rewriteResult);
1192154cabe7SRiver Riddle
11933eb1647aSStanislav Funiak /// Pushes a code iterator to the stack.
pushCodeIt(const ByteCodeField * it)11943eb1647aSStanislav Funiak void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }
11953eb1647aSStanislav Funiak
11963eb1647aSStanislav Funiak /// Pops a code iterator from the stack, returning true on success.
popCodeIt()11973eb1647aSStanislav Funiak void popCodeIt() {
11983eb1647aSStanislav Funiak assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack");
11993eb1647aSStanislav Funiak curCodeIt = resumeCodeIt.back();
12003eb1647aSStanislav Funiak resumeCodeIt.pop_back();
12013eb1647aSStanislav Funiak }
12023eb1647aSStanislav Funiak
1203d35f1190SStanislav Funiak /// Return the bytecode iterator at the start of the current op code.
getPrevCodeIt() const1204d35f1190SStanislav Funiak const ByteCodeField *getPrevCodeIt() const {
1205d35f1190SStanislav Funiak LLVM_DEBUG({
1206d35f1190SStanislav Funiak // Account for the op code and the Location stored inline.
1207d35f1190SStanislav Funiak return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField);
1208d35f1190SStanislav Funiak });
1209d35f1190SStanislav Funiak
1210d35f1190SStanislav Funiak // Account for the op code only.
1211d35f1190SStanislav Funiak return curCodeIt - 1;
1212d35f1190SStanislav Funiak }
1213d35f1190SStanislav Funiak
1214abfd1a8bSRiver Riddle /// Read a value from the bytecode buffer, optionally skipping a certain
1215abfd1a8bSRiver Riddle /// number of prefix values. These methods always update the buffer to point
1216abfd1a8bSRiver Riddle /// to the next field after the read data.
1217abfd1a8bSRiver Riddle template <typename T = ByteCodeField>
read(size_t skipN=0)1218abfd1a8bSRiver Riddle T read(size_t skipN = 0) {
1219abfd1a8bSRiver Riddle curCodeIt += skipN;
1220abfd1a8bSRiver Riddle return readImpl<T>();
1221abfd1a8bSRiver Riddle }
read(size_t skipN=0)1222abfd1a8bSRiver Riddle ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }
1223abfd1a8bSRiver Riddle
1224abfd1a8bSRiver Riddle /// Read a list of values from the bytecode buffer.
1225abfd1a8bSRiver Riddle template <typename ValueT, typename T>
readList(SmallVectorImpl<T> & list)1226abfd1a8bSRiver Riddle void readList(SmallVectorImpl<T> &list) {
1227abfd1a8bSRiver Riddle list.clear();
1228abfd1a8bSRiver Riddle for (unsigned i = 0, e = read(); i != e; ++i)
1229abfd1a8bSRiver Riddle list.push_back(read<ValueT>());
1230abfd1a8bSRiver Riddle }
1231abfd1a8bSRiver Riddle
123285ab413bSRiver Riddle /// Read a list of values from the bytecode buffer. The values may be encoded
1233ce57789dSRiver Riddle /// either as a single element or a range of elements.
readList(SmallVectorImpl<Type> & list)1234ce57789dSRiver Riddle void readList(SmallVectorImpl<Type> &list) {
1235ce57789dSRiver Riddle for (unsigned i = 0, e = read(); i != e; ++i) {
1236ce57789dSRiver Riddle if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1237ce57789dSRiver Riddle list.push_back(read<Type>());
1238ce57789dSRiver Riddle } else {
1239ce57789dSRiver Riddle TypeRange *values = read<TypeRange *>();
1240ce57789dSRiver Riddle list.append(values->begin(), values->end());
1241ce57789dSRiver Riddle }
1242ce57789dSRiver Riddle }
1243ce57789dSRiver Riddle }
readList(SmallVectorImpl<Value> & list)1244ce57789dSRiver Riddle void readList(SmallVectorImpl<Value> &list) {
124585ab413bSRiver Riddle for (unsigned i = 0, e = read(); i != e; ++i) {
124685ab413bSRiver Riddle if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
124785ab413bSRiver Riddle list.push_back(read<Value>());
124885ab413bSRiver Riddle } else {
124985ab413bSRiver Riddle ValueRange *values = read<ValueRange *>();
125085ab413bSRiver Riddle list.append(values->begin(), values->end());
125185ab413bSRiver Riddle }
125285ab413bSRiver Riddle }
125385ab413bSRiver Riddle }
125485ab413bSRiver Riddle
1255d35f1190SStanislav Funiak /// Read a value stored inline as a pointer.
1256d35f1190SStanislav Funiak template <typename T>
1257d35f1190SStanislav Funiak std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T>
readInline()1258d35f1190SStanislav Funiak readInline() {
1259d35f1190SStanislav Funiak const void *pointer;
1260d35f1190SStanislav Funiak std::memcpy(&pointer, curCodeIt, sizeof(const void *));
1261d35f1190SStanislav Funiak curCodeIt += sizeof(const void *) / sizeof(ByteCodeField);
1262d35f1190SStanislav Funiak return T::getFromOpaquePointer(pointer);
1263d35f1190SStanislav Funiak }
1264d35f1190SStanislav Funiak
skip(size_t skipN)12658ec28af8SMatthias Gehre void skip(size_t skipN) { curCodeIt += skipN; }
12668ec28af8SMatthias Gehre
1267abfd1a8bSRiver Riddle /// Jump to a specific successor based on a predicate value.
selectJump(bool isTrue)1268abfd1a8bSRiver Riddle void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
1269abfd1a8bSRiver Riddle /// Jump to a specific successor based on a destination index.
selectJump(size_t destIndex)1270abfd1a8bSRiver Riddle void selectJump(size_t destIndex) {
1271abfd1a8bSRiver Riddle curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
1272abfd1a8bSRiver Riddle }
1273abfd1a8bSRiver Riddle
1274abfd1a8bSRiver Riddle /// Handle a switch operation with the provided value and cases.
127585ab413bSRiver Riddle template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
handleSwitch(const T & value,RangeT && cases,Comparator cmp={})127685ab413bSRiver Riddle void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
1277abfd1a8bSRiver Riddle LLVM_DEBUG({
1278abfd1a8bSRiver Riddle llvm::dbgs() << " * Value: " << value << "\n"
1279abfd1a8bSRiver Riddle << " * Cases: ";
1280abfd1a8bSRiver Riddle llvm::interleaveComma(cases, llvm::dbgs());
1281154cabe7SRiver Riddle llvm::dbgs() << "\n";
1282abfd1a8bSRiver Riddle });
1283abfd1a8bSRiver Riddle
1284abfd1a8bSRiver Riddle // Check to see if the attribute value is within the case list. Jump to
1285abfd1a8bSRiver Riddle // the correct successor index based on the result.
1286f80b6304SRiver Riddle for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
128785ab413bSRiver Riddle if (cmp(*it, value))
1288f80b6304SRiver Riddle return selectJump(size_t((it - cases.begin()) + 1));
1289f80b6304SRiver Riddle selectJump(size_t(0));
1290abfd1a8bSRiver Riddle }
1291abfd1a8bSRiver Riddle
12923eb1647aSStanislav Funiak /// Store a pointer to memory.
storeToMemory(unsigned index,const void * value)12933eb1647aSStanislav Funiak void storeToMemory(unsigned index, const void *value) {
12943eb1647aSStanislav Funiak memory[index] = value;
12953eb1647aSStanislav Funiak }
12963eb1647aSStanislav Funiak
12973eb1647aSStanislav Funiak /// Store a value to memory as an opaque pointer.
12983eb1647aSStanislav Funiak template <typename T>
12993eb1647aSStanislav Funiak std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
storeToMemory(unsigned index,T value)13003eb1647aSStanislav Funiak storeToMemory(unsigned index, T value) {
13013eb1647aSStanislav Funiak memory[index] = value.getAsOpaquePointer();
13023eb1647aSStanislav Funiak }
13033eb1647aSStanislav Funiak
1304abfd1a8bSRiver Riddle /// Internal implementation of reading various data types from the bytecode
1305abfd1a8bSRiver Riddle /// stream.
1306abfd1a8bSRiver Riddle template <typename T>
readFromMemory()1307abfd1a8bSRiver Riddle const void *readFromMemory() {
1308abfd1a8bSRiver Riddle size_t index = *curCodeIt++;
1309abfd1a8bSRiver Riddle
1310abfd1a8bSRiver Riddle // If this type is an SSA value, it can only be stored in non-const memory.
131185ab413bSRiver Riddle if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
131285ab413bSRiver Riddle Value>::value ||
131385ab413bSRiver Riddle index < memory.size())
1314abfd1a8bSRiver Riddle return memory[index];
1315abfd1a8bSRiver Riddle
1316abfd1a8bSRiver Riddle // Otherwise, if this index is not inbounds it is uniqued.
1317abfd1a8bSRiver Riddle return uniquedMemory[index - memory.size()];
1318abfd1a8bSRiver Riddle }
1319abfd1a8bSRiver Riddle template <typename T>
readImpl()1320abfd1a8bSRiver Riddle std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
1321abfd1a8bSRiver Riddle return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
1322abfd1a8bSRiver Riddle }
1323abfd1a8bSRiver Riddle template <typename T>
1324abfd1a8bSRiver Riddle std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
1325abfd1a8bSRiver Riddle T>
readImpl()1326abfd1a8bSRiver Riddle readImpl() {
1327abfd1a8bSRiver Riddle return T(T::getFromOpaquePointer(readFromMemory<T>()));
1328abfd1a8bSRiver Riddle }
1329abfd1a8bSRiver Riddle template <typename T>
readImpl()1330abfd1a8bSRiver Riddle std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
133185ab413bSRiver Riddle switch (read<PDLValue::Kind>()) {
133285ab413bSRiver Riddle case PDLValue::Kind::Attribute:
1333abfd1a8bSRiver Riddle return read<Attribute>();
133485ab413bSRiver Riddle case PDLValue::Kind::Operation:
1335abfd1a8bSRiver Riddle return read<Operation *>();
133685ab413bSRiver Riddle case PDLValue::Kind::Type:
1337abfd1a8bSRiver Riddle return read<Type>();
133885ab413bSRiver Riddle case PDLValue::Kind::Value:
1339abfd1a8bSRiver Riddle return read<Value>();
134085ab413bSRiver Riddle case PDLValue::Kind::TypeRange:
134185ab413bSRiver Riddle return read<TypeRange *>();
134285ab413bSRiver Riddle case PDLValue::Kind::ValueRange:
134385ab413bSRiver Riddle return read<ValueRange *>();
1344abfd1a8bSRiver Riddle }
134585ab413bSRiver Riddle llvm_unreachable("unhandled PDLValue::Kind");
1346abfd1a8bSRiver Riddle }
1347abfd1a8bSRiver Riddle template <typename T>
readImpl()1348abfd1a8bSRiver Riddle std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
1349abfd1a8bSRiver Riddle static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
1350abfd1a8bSRiver Riddle "unexpected ByteCode address size");
1351abfd1a8bSRiver Riddle ByteCodeAddr result;
1352abfd1a8bSRiver Riddle std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
1353abfd1a8bSRiver Riddle curCodeIt += 2;
1354abfd1a8bSRiver Riddle return result;
1355abfd1a8bSRiver Riddle }
1356abfd1a8bSRiver Riddle template <typename T>
readImpl()1357abfd1a8bSRiver Riddle std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
1358abfd1a8bSRiver Riddle return *curCodeIt++;
1359abfd1a8bSRiver Riddle }
136085ab413bSRiver Riddle template <typename T>
readImpl()136185ab413bSRiver Riddle std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
136285ab413bSRiver Riddle return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
136385ab413bSRiver Riddle }
1364abfd1a8bSRiver Riddle
1365ce57789dSRiver Riddle /// Assign the given range to the given memory index. This allocates a new
1366ce57789dSRiver Riddle /// range object if necessary.
1367ce57789dSRiver Riddle template <typename RangeT, typename T = llvm::detail::ValueOfRange<RangeT>>
assignRangeToMemory(RangeT && range,unsigned memIndex,unsigned rangeIndex)1368ce57789dSRiver Riddle void assignRangeToMemory(RangeT &&range, unsigned memIndex,
1369ce57789dSRiver Riddle unsigned rangeIndex) {
1370ce57789dSRiver Riddle // Utility functor used to type-erase the assignment.
1371ce57789dSRiver Riddle auto assignRange = [&](auto &allocatedRangeMemory, auto &rangeMemory) {
1372ce57789dSRiver Riddle // If the input range is empty, we don't need to allocate anything.
1373ce57789dSRiver Riddle if (range.empty()) {
1374ce57789dSRiver Riddle rangeMemory[rangeIndex] = {};
1375ce57789dSRiver Riddle } else {
1376ce57789dSRiver Riddle // Allocate a buffer for this type range.
1377ce57789dSRiver Riddle llvm::OwningArrayRef<T> storage(llvm::size(range));
1378ce57789dSRiver Riddle llvm::copy(range, storage.begin());
1379ce57789dSRiver Riddle
1380ce57789dSRiver Riddle // Assign this to the range slot and use the range as the value for the
1381ce57789dSRiver Riddle // memory index.
1382ce57789dSRiver Riddle allocatedRangeMemory.emplace_back(std::move(storage));
1383ce57789dSRiver Riddle rangeMemory[rangeIndex] = allocatedRangeMemory.back();
1384ce57789dSRiver Riddle }
1385ce57789dSRiver Riddle memory[memIndex] = &rangeMemory[rangeIndex];
1386ce57789dSRiver Riddle };
1387ce57789dSRiver Riddle
1388ce57789dSRiver Riddle // Dispatch based on the concrete range type.
1389ce57789dSRiver Riddle if constexpr (std::is_same_v<T, Type>) {
1390ce57789dSRiver Riddle return assignRange(allocatedTypeRangeMemory, typeRangeMemory);
1391ce57789dSRiver Riddle } else if constexpr (std::is_same_v<T, Value>) {
1392ce57789dSRiver Riddle return assignRange(allocatedValueRangeMemory, valueRangeMemory);
1393ce57789dSRiver Riddle } else {
1394ce57789dSRiver Riddle llvm_unreachable("unhandled range type");
1395ce57789dSRiver Riddle }
1396ce57789dSRiver Riddle }
1397ce57789dSRiver Riddle
1398abfd1a8bSRiver Riddle /// The underlying bytecode buffer.
1399abfd1a8bSRiver Riddle const ByteCodeField *curCodeIt;
1400abfd1a8bSRiver Riddle
14013eb1647aSStanislav Funiak /// The stack of bytecode positions at which to resume operation.
14023eb1647aSStanislav Funiak SmallVector<const ByteCodeField *> resumeCodeIt;
14033eb1647aSStanislav Funiak
1404abfd1a8bSRiver Riddle /// The current execution memory.
1405abfd1a8bSRiver Riddle MutableArrayRef<const void *> memory;
14063eb1647aSStanislav Funiak MutableArrayRef<OwningOpRange> opRangeMemory;
140785ab413bSRiver Riddle MutableArrayRef<TypeRange> typeRangeMemory;
140885ab413bSRiver Riddle std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
140985ab413bSRiver Riddle MutableArrayRef<ValueRange> valueRangeMemory;
141085ab413bSRiver Riddle std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;
1411abfd1a8bSRiver Riddle
14123eb1647aSStanislav Funiak /// The current loop indices.
14133eb1647aSStanislav Funiak MutableArrayRef<unsigned> loopIndex;
14143eb1647aSStanislav Funiak
1415abfd1a8bSRiver Riddle /// References to ByteCode data necessary for execution.
1416abfd1a8bSRiver Riddle ArrayRef<const void *> uniquedMemory;
1417abfd1a8bSRiver Riddle ArrayRef<ByteCodeField> code;
1418abfd1a8bSRiver Riddle ArrayRef<PatternBenefit> currentPatternBenefits;
1419abfd1a8bSRiver Riddle ArrayRef<PDLByteCodePattern> patterns;
1420abfd1a8bSRiver Riddle ArrayRef<PDLConstraintFunction> constraintFunctions;
1421abfd1a8bSRiver Riddle ArrayRef<PDLRewriteFunction> rewriteFunctions;
1422abfd1a8bSRiver Riddle };
1423be0a7e9fSMehdi Amini } // namespace
1424abfd1a8bSRiver Riddle
executeApplyConstraint(PatternRewriter & rewriter)1425154cabe7SRiver Riddle void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1426abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
14278ec28af8SMatthias Gehre ByteCodeField fun_idx = read();
1428abfd1a8bSRiver Riddle SmallVector<PDLValue, 16> args;
1429abfd1a8bSRiver Riddle readList<PDLValue>(args);
1430154cabe7SRiver Riddle
1431abfd1a8bSRiver Riddle LLVM_DEBUG({
1432abfd1a8bSRiver Riddle llvm::dbgs() << " * Arguments: ";
1433abfd1a8bSRiver Riddle llvm::interleaveComma(args, llvm::dbgs());
14346d2b2b8eSMartin Lücke llvm::dbgs() << "\n";
1435abfd1a8bSRiver Riddle });
1436abfd1a8bSRiver Riddle
14376d2b2b8eSMartin Lücke ByteCodeField isNegated = read();
14386d2b2b8eSMartin Lücke LLVM_DEBUG({
14396d2b2b8eSMartin Lücke llvm::dbgs() << " * isNegated: " << isNegated << "\n";
14406d2b2b8eSMartin Lücke llvm::interleaveComma(args, llvm::dbgs());
14416d2b2b8eSMartin Lücke });
14428ec28af8SMatthias Gehre
14438ec28af8SMatthias Gehre ByteCodeField numResults = read();
14448ec28af8SMatthias Gehre const PDLRewriteFunction &constraintFn = constraintFunctions[fun_idx];
14458ec28af8SMatthias Gehre ByteCodeRewriteResultList results(numResults);
14468ec28af8SMatthias Gehre LogicalResult rewriteResult = constraintFn(rewriter, results, args);
1447*0ec318e5SMatthias Gehre [[maybe_unused]] ArrayRef<PDLValue> constraintResults = results.getResults();
14488ec28af8SMatthias Gehre LLVM_DEBUG({
14498ec28af8SMatthias Gehre if (succeeded(rewriteResult)) {
14508ec28af8SMatthias Gehre llvm::dbgs() << " * Constraint succeeded\n";
14518ec28af8SMatthias Gehre llvm::dbgs() << " * Results: ";
14528ec28af8SMatthias Gehre llvm::interleaveComma(constraintResults, llvm::dbgs());
14538ec28af8SMatthias Gehre llvm::dbgs() << "\n";
14548ec28af8SMatthias Gehre } else {
14558ec28af8SMatthias Gehre llvm::dbgs() << " * Constraint failed\n";
14568ec28af8SMatthias Gehre }
14578ec28af8SMatthias Gehre });
14588ec28af8SMatthias Gehre assert((failed(rewriteResult) || constraintResults.size() == numResults) &&
14598ec28af8SMatthias Gehre "native PDL rewrite function succeeded but returned "
14608ec28af8SMatthias Gehre "unexpected number of results");
14618ec28af8SMatthias Gehre processNativeFunResults(results, numResults, rewriteResult);
14628ec28af8SMatthias Gehre
14638ec28af8SMatthias Gehre // Depending on the constraint jump to the proper destination.
14648ec28af8SMatthias Gehre selectJump(isNegated != succeeded(rewriteResult));
1465abfd1a8bSRiver Riddle }
1466154cabe7SRiver Riddle
executeApplyRewrite(PatternRewriter & rewriter)14678c66344eSRiver Riddle LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
1468abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
1469abfd1a8bSRiver Riddle const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1470abfd1a8bSRiver Riddle SmallVector<PDLValue, 16> args;
1471abfd1a8bSRiver Riddle readList<PDLValue>(args);
1472abfd1a8bSRiver Riddle
1473abfd1a8bSRiver Riddle LLVM_DEBUG({
147402c4c0d5SRiver Riddle llvm::dbgs() << " * Arguments: ";
1475abfd1a8bSRiver Riddle llvm::interleaveComma(args, llvm::dbgs());
1476abfd1a8bSRiver Riddle });
147785ab413bSRiver Riddle
147885ab413bSRiver Riddle // Execute the rewrite function.
147985ab413bSRiver Riddle ByteCodeField numResults = read();
148085ab413bSRiver Riddle ByteCodeRewriteResultList results(numResults);
14818c66344eSRiver Riddle LogicalResult rewriteResult = rewriteFn(rewriter, results, args);
1482154cabe7SRiver Riddle
148385ab413bSRiver Riddle assert(results.getResults().size() == numResults &&
148402c4c0d5SRiver Riddle "native PDL rewrite function returned unexpected number of results");
148502c4c0d5SRiver Riddle
14868ec28af8SMatthias Gehre processNativeFunResults(results, numResults, rewriteResult);
14878ec28af8SMatthias Gehre
14888ec28af8SMatthias Gehre if (failed(rewriteResult)) {
14898ec28af8SMatthias Gehre LLVM_DEBUG(llvm::dbgs() << " - Failed");
14908ec28af8SMatthias Gehre return failure();
14918ec28af8SMatthias Gehre }
14928ec28af8SMatthias Gehre return success();
14938ec28af8SMatthias Gehre }
14948ec28af8SMatthias Gehre
processNativeFunResults(ByteCodeRewriteResultList & results,unsigned numResults,LogicalResult & rewriteResult)14958ec28af8SMatthias Gehre void ByteCodeExecutor::processNativeFunResults(
14968ec28af8SMatthias Gehre ByteCodeRewriteResultList &results, unsigned numResults,
14978ec28af8SMatthias Gehre LogicalResult &rewriteResult) {
14988ec28af8SMatthias Gehre // Store the results in the bytecode memory or handle missing results on
14998ec28af8SMatthias Gehre // failure.
15008ec28af8SMatthias Gehre for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
15018ec28af8SMatthias Gehre PDLValue::Kind resultKind = read<PDLValue::Kind>();
15028ec28af8SMatthias Gehre
15038ec28af8SMatthias Gehre // Skip the according number of values on the buffer on failure and exit
15048ec28af8SMatthias Gehre // early as there are no results to process.
15058ec28af8SMatthias Gehre if (failed(rewriteResult)) {
15068ec28af8SMatthias Gehre if (resultKind == PDLValue::Kind::TypeRange ||
15078ec28af8SMatthias Gehre resultKind == PDLValue::Kind::ValueRange) {
15088ec28af8SMatthias Gehre skip(2);
15098ec28af8SMatthias Gehre } else {
15108ec28af8SMatthias Gehre skip(1);
15118ec28af8SMatthias Gehre }
15128ec28af8SMatthias Gehre return;
15138ec28af8SMatthias Gehre }
15148ec28af8SMatthias Gehre PDLValue result = results.getResults()[resultIdx];
151502c4c0d5SRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
15168ec28af8SMatthias Gehre assert(result.getKind() == resultKind &&
15178ec28af8SMatthias Gehre "native PDL rewrite function returned an unexpected type of "
15188ec28af8SMatthias Gehre "result");
151985ab413bSRiver Riddle // If the result is a range, we need to copy it over to the bytecodes
152085ab413bSRiver Riddle // range memory.
15210a81ace0SKazu Hirata if (std::optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
152285ab413bSRiver Riddle unsigned rangeIndex = read();
152385ab413bSRiver Riddle typeRangeMemory[rangeIndex] = *typeRange;
152485ab413bSRiver Riddle memory[read()] = &typeRangeMemory[rangeIndex];
15250a81ace0SKazu Hirata } else if (std::optional<ValueRange> valueRange =
152685ab413bSRiver Riddle result.dyn_cast<ValueRange>()) {
152785ab413bSRiver Riddle unsigned rangeIndex = read();
152885ab413bSRiver Riddle valueRangeMemory[rangeIndex] = *valueRange;
152985ab413bSRiver Riddle memory[read()] = &valueRangeMemory[rangeIndex];
153085ab413bSRiver Riddle } else {
153102c4c0d5SRiver Riddle memory[read()] = result.getAsOpaquePointer();
153202c4c0d5SRiver Riddle }
1533abfd1a8bSRiver Riddle }
1534154cabe7SRiver Riddle
153585ab413bSRiver Riddle // Copy over any underlying storage allocated for result ranges.
153685ab413bSRiver Riddle for (auto &it : results.getAllocatedTypeRanges())
153785ab413bSRiver Riddle allocatedTypeRangeMemory.push_back(std::move(it));
153885ab413bSRiver Riddle for (auto &it : results.getAllocatedValueRanges())
153985ab413bSRiver Riddle allocatedValueRangeMemory.push_back(std::move(it));
154085ab413bSRiver Riddle }
154185ab413bSRiver Riddle
executeAreEqual()1542154cabe7SRiver Riddle void ByteCodeExecutor::executeAreEqual() {
1543abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1544abfd1a8bSRiver Riddle const void *lhs = read<const void *>();
1545abfd1a8bSRiver Riddle const void *rhs = read<const void *>();
1546abfd1a8bSRiver Riddle
1547154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n");
1548abfd1a8bSRiver Riddle selectJump(lhs == rhs);
1549abfd1a8bSRiver Riddle }
1550154cabe7SRiver Riddle
executeAreRangesEqual()155185ab413bSRiver Riddle void ByteCodeExecutor::executeAreRangesEqual() {
155285ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
155385ab413bSRiver Riddle PDLValue::Kind valueKind = read<PDLValue::Kind>();
155485ab413bSRiver Riddle const void *lhs = read<const void *>();
155585ab413bSRiver Riddle const void *rhs = read<const void *>();
155685ab413bSRiver Riddle
155785ab413bSRiver Riddle switch (valueKind) {
155885ab413bSRiver Riddle case PDLValue::Kind::TypeRange: {
155985ab413bSRiver Riddle const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
156085ab413bSRiver Riddle const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
156185ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
156285ab413bSRiver Riddle selectJump(*lhsRange == *rhsRange);
156385ab413bSRiver Riddle break;
156485ab413bSRiver Riddle }
156585ab413bSRiver Riddle case PDLValue::Kind::ValueRange: {
156685ab413bSRiver Riddle const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
156785ab413bSRiver Riddle const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
156885ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
156985ab413bSRiver Riddle selectJump(*lhsRange == *rhsRange);
157085ab413bSRiver Riddle break;
157185ab413bSRiver Riddle }
157285ab413bSRiver Riddle default:
157385ab413bSRiver Riddle llvm_unreachable("unexpected `AreRangesEqual` value kind");
157485ab413bSRiver Riddle }
157585ab413bSRiver Riddle }
157685ab413bSRiver Riddle
executeBranch()1577154cabe7SRiver Riddle void ByteCodeExecutor::executeBranch() {
1578154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
1579abfd1a8bSRiver Riddle curCodeIt = &code[read<ByteCodeAddr>()];
1580abfd1a8bSRiver Riddle }
1581154cabe7SRiver Riddle
executeCheckOperandCount()1582154cabe7SRiver Riddle void ByteCodeExecutor::executeCheckOperandCount() {
1583abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
1584abfd1a8bSRiver Riddle Operation *op = read<Operation *>();
1585abfd1a8bSRiver Riddle uint32_t expectedCount = read<uint32_t>();
158685ab413bSRiver Riddle bool compareAtLeast = read();
1587abfd1a8bSRiver Riddle
1588abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n"
158985ab413bSRiver Riddle << " * Expected: " << expectedCount << "\n"
159085ab413bSRiver Riddle << " * Comparator: "
159185ab413bSRiver Riddle << (compareAtLeast ? ">=" : "==") << "\n");
159285ab413bSRiver Riddle if (compareAtLeast)
159385ab413bSRiver Riddle selectJump(op->getNumOperands() >= expectedCount);
159485ab413bSRiver Riddle else
1595abfd1a8bSRiver Riddle selectJump(op->getNumOperands() == expectedCount);
1596abfd1a8bSRiver Riddle }
1597154cabe7SRiver Riddle
executeCheckOperationName()1598154cabe7SRiver Riddle void ByteCodeExecutor::executeCheckOperationName() {
1599abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
1600abfd1a8bSRiver Riddle Operation *op = read<Operation *>();
1601abfd1a8bSRiver Riddle OperationName expectedName = read<OperationName>();
1602abfd1a8bSRiver Riddle
1603154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n"
1604154cabe7SRiver Riddle << " * Expected: \"" << expectedName << "\"\n");
1605abfd1a8bSRiver Riddle selectJump(op->getName() == expectedName);
1606abfd1a8bSRiver Riddle }
1607154cabe7SRiver Riddle
executeCheckResultCount()1608154cabe7SRiver Riddle void ByteCodeExecutor::executeCheckResultCount() {
1609abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
1610abfd1a8bSRiver Riddle Operation *op = read<Operation *>();
1611abfd1a8bSRiver Riddle uint32_t expectedCount = read<uint32_t>();
161285ab413bSRiver Riddle bool compareAtLeast = read();
1613abfd1a8bSRiver Riddle
1614abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n"
161585ab413bSRiver Riddle << " * Expected: " << expectedCount << "\n"
161685ab413bSRiver Riddle << " * Comparator: "
161785ab413bSRiver Riddle << (compareAtLeast ? ">=" : "==") << "\n");
161885ab413bSRiver Riddle if (compareAtLeast)
161985ab413bSRiver Riddle selectJump(op->getNumResults() >= expectedCount);
162085ab413bSRiver Riddle else
1621abfd1a8bSRiver Riddle selectJump(op->getNumResults() == expectedCount);
1622abfd1a8bSRiver Riddle }
1623154cabe7SRiver Riddle
executeCheckTypes()162485ab413bSRiver Riddle void ByteCodeExecutor::executeCheckTypes() {
162585ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
162685ab413bSRiver Riddle TypeRange *lhs = read<TypeRange *>();
162785ab413bSRiver Riddle Attribute rhs = read<Attribute>();
162885ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
162985ab413bSRiver Riddle
16305550c821STres Popp selectJump(*lhs == cast<ArrayAttr>(rhs).getAsValueRange<TypeAttr>());
163185ab413bSRiver Riddle }
163285ab413bSRiver Riddle
executeContinue()16333eb1647aSStanislav Funiak void ByteCodeExecutor::executeContinue() {
16343eb1647aSStanislav Funiak ByteCodeField level = read();
16353eb1647aSStanislav Funiak LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n"
16363eb1647aSStanislav Funiak << " * Level: " << level << "\n");
16373eb1647aSStanislav Funiak ++loopIndex[level];
16383eb1647aSStanislav Funiak popCodeIt();
16393eb1647aSStanislav Funiak }
16403eb1647aSStanislav Funiak
executeCreateConstantTypeRange()1641ce57789dSRiver Riddle void ByteCodeExecutor::executeCreateConstantTypeRange() {
1642ce57789dSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing CreateConstantTypeRange:\n");
164385ab413bSRiver Riddle unsigned memIndex = read();
164485ab413bSRiver Riddle unsigned rangeIndex = read();
16455550c821STres Popp ArrayAttr typesAttr = cast<ArrayAttr>(read<Attribute>());
164685ab413bSRiver Riddle
164785ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n");
1648ce57789dSRiver Riddle assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex,
1649ce57789dSRiver Riddle rangeIndex);
165085ab413bSRiver Riddle }
165185ab413bSRiver Riddle
/// Execute a `CreateOperation` instruction.
///
/// Bytecode operands are read in this order: the destination memory index,
/// the operation name, the operand list, an attribute count followed by that
/// many (name, attribute) pairs, and a result count. If the result count equals
/// `kInferTypesMarker`, the result types are inferred through the operation's
/// `InferTypeOpInterface`. Otherwise each result entry is a kind tag followed
/// by either a single `Type` or a `TypeRange *`. The new operation is created
/// at `mainRewriteLoc` through `rewriter` and stored in `memory[memIndex]`.
void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
                                              Location mainRewriteLoc) {
  LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");

  unsigned memIndex = read();
  OperationState state(mainRewriteLoc, read<OperationName>());
  readList(state.operands);
  for (unsigned i = 0, e = read(); i != e; ++i) {
    StringAttr name = read<StringAttr>();
    // Null attributes are dropped instead of being added to the state.
    if (Attribute attr = read<Attribute>())
      state.addAttribute(name, attr);
  }

  // Read in the result types. If the "size" is the sentinel value, this
  // indicates that the result types should be inferred.
  unsigned numResults = read();
  if (numResults == kInferTypesMarker) {
    InferTypeOpInterface::Concept *inferInterface =
        state.name.getInterface<InferTypeOpInterface>();
    assert(inferInterface &&
           "expected operation to provide InferTypeOpInterface");

    // TODO: Handle failure.
    // NOTE(review): on inference failure this returns without creating an
    // operation and without writing `memory[memIndex]`, so that slot keeps
    // whatever it held before. Confirm that the generator guarantees inference
    // succeeds, or that later instructions never read the slot in that case.
    if (failed(inferInterface->inferReturnTypes(
            state.getContext(), state.location, state.operands,
            state.attributes.getDictionary(state.getContext()),
            state.getRawProperties(), state.regions, state.types)))
      return;
  } else {
    // Otherwise, this is a fixed number of results.
    for (unsigned i = 0; i != numResults; ++i) {
      // Each entry is either a single type or a range of types to append.
      if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
        state.types.push_back(read<Type>());
      } else {
        TypeRange *resultTypes = read<TypeRange *>();
        state.types.append(resultTypes->begin(), resultTypes->end());
      }
    }
  }

  // Create through the rewriter (not directly) and publish the new op.
  Operation *resultOp = rewriter.create(state);
  memory[memIndex] = resultOp;

  LLVM_DEBUG({
    llvm::dbgs() << " * Attributes: "
                 << state.attributes.getDictionary(state.getContext())
                 << "\n * Operands: ";
    llvm::interleaveComma(state.operands, llvm::dbgs());
    llvm::dbgs() << "\n * Result Types: ";
    llvm::interleaveComma(state.types, llvm::dbgs());
    llvm::dbgs() << "\n * Result: " << *resultOp << "\n";
  });
}
1705154cabe7SRiver Riddle
1706ce57789dSRiver Riddle template <typename T>
executeDynamicCreateRange(StringRef type)1707ce57789dSRiver Riddle void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) {
1708ce57789dSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing CreateDynamic" << type << "Range:\n");
1709ce57789dSRiver Riddle unsigned memIndex = read();
1710ce57789dSRiver Riddle unsigned rangeIndex = read();
1711ce57789dSRiver Riddle SmallVector<T> values;
1712ce57789dSRiver Riddle readList(values);
1713ce57789dSRiver Riddle
1714ce57789dSRiver Riddle LLVM_DEBUG({
1715ce57789dSRiver Riddle llvm::dbgs() << "\n * " << type << "s: ";
1716ce57789dSRiver Riddle llvm::interleaveComma(values, llvm::dbgs());
1717ce57789dSRiver Riddle llvm::dbgs() << "\n";
1718ce57789dSRiver Riddle });
1719ce57789dSRiver Riddle
1720ce57789dSRiver Riddle assignRangeToMemory(values, memIndex, rangeIndex);
1721ce57789dSRiver Riddle }
1722ce57789dSRiver Riddle
executeEraseOp(PatternRewriter & rewriter)1723154cabe7SRiver Riddle void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1724abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1725abfd1a8bSRiver Riddle Operation *op = read<Operation *>();
1726abfd1a8bSRiver Riddle
1727154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
1728abfd1a8bSRiver Riddle rewriter.eraseOp(op);
1729abfd1a8bSRiver Riddle }
1730154cabe7SRiver Riddle
17313eb1647aSStanislav Funiak template <typename T, typename Range, PDLValue::Kind kind>
executeExtract()17323eb1647aSStanislav Funiak void ByteCodeExecutor::executeExtract() {
17333eb1647aSStanislav Funiak LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n");
17343eb1647aSStanislav Funiak Range *range = read<Range *>();
17353eb1647aSStanislav Funiak unsigned index = read<uint32_t>();
17363eb1647aSStanislav Funiak unsigned memIndex = read();
17373eb1647aSStanislav Funiak
17383eb1647aSStanislav Funiak if (!range) {
17393eb1647aSStanislav Funiak memory[memIndex] = nullptr;
17403eb1647aSStanislav Funiak return;
17413eb1647aSStanislav Funiak }
17423eb1647aSStanislav Funiak
17433eb1647aSStanislav Funiak T result = index < range->size() ? (*range)[index] : T();
17443eb1647aSStanislav Funiak LLVM_DEBUG(llvm::dbgs() << " * " << kind << "s(" << range->size() << ")\n"
17453eb1647aSStanislav Funiak << " * Index: " << index << "\n"
17463eb1647aSStanislav Funiak << " * Result: " << result << "\n");
17473eb1647aSStanislav Funiak storeToMemory(memIndex, result);
17483eb1647aSStanislav Funiak }
17493eb1647aSStanislav Funiak
executeFinalize()17503eb1647aSStanislav Funiak void ByteCodeExecutor::executeFinalize() {
17513eb1647aSStanislav Funiak LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n");
17523eb1647aSStanislav Funiak }
17533eb1647aSStanislav Funiak
/// Execute one step of a `ForEach` loop over an operation range.
///
/// While elements remain, the current element is stored into
/// `memory[memIndex]`, and the position from `getPrevCodeIt()` is pushed onto
/// the resume stack so that a later `Continue` (which calls `popCodeIt()`) can
/// come back. The successor address is then skipped so that execution falls
/// into the loop body. Once the range is exhausted, the loop counter is reset
/// to 0 and `selectJump(size_t(0))` is taken. That is presumably the loop-exit
/// successor; TODO confirm against the bytecode generator.
void ByteCodeExecutor::executeForEach() {
  LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n");
  // Captured before reading any operand. NOTE(review): presumably this is the
  // start of this ForEach instruction, so resuming re-executes it.
  const ByteCodeField *prevCodeIt = getPrevCodeIt();
  unsigned rangeIndex = read();
  unsigned memIndex = read();
  const void *value = nullptr;

  switch (read<PDLValue::Kind>()) {
  case PDLValue::Kind::Operation: {
    // Loop counter for this nesting level; `executeContinue` increments it.
    unsigned &index = loopIndex[read()];
    ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
    assert(index <= array.size() && "iterated past the end");
    if (index < array.size()) {
      LLVM_DEBUG(llvm::dbgs() << " * Result: " << array[index] << "\n");
      value = array[index];
      break;
    }

    // Range exhausted: reset the counter so the loop can run again later.
    LLVM_DEBUG(llvm::dbgs() << " * Done\n");
    index = 0;
    selectJump(size_t(0));
    return;
  }
  default:
    llvm_unreachable("unexpected `ForEach` value kind");
  }

  // Store the iterate value and the stack address.
  memory[memIndex] = value;
  pushCodeIt(prevCodeIt);

  // Skip over the successor (we will enter the body of the loop).
  read<ByteCodeAddr>();
}
17883eb1647aSStanislav Funiak
executeGetAttribute()1789154cabe7SRiver Riddle void ByteCodeExecutor::executeGetAttribute() {
1790abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1791abfd1a8bSRiver Riddle unsigned memIndex = read();
1792abfd1a8bSRiver Riddle Operation *op = read<Operation *>();
1793195730a6SRiver Riddle StringAttr attrName = read<StringAttr>();
1794abfd1a8bSRiver Riddle Attribute attr = op->getAttr(attrName);
1795abfd1a8bSRiver Riddle
1796abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1797abfd1a8bSRiver Riddle << " * Attribute: " << attrName << "\n"
1798154cabe7SRiver Riddle << " * Result: " << attr << "\n");
1799abfd1a8bSRiver Riddle memory[memIndex] = attr.getAsOpaquePointer();
1800abfd1a8bSRiver Riddle }
1801154cabe7SRiver Riddle
executeGetAttributeType()1802154cabe7SRiver Riddle void ByteCodeExecutor::executeGetAttributeType() {
1803abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1804abfd1a8bSRiver Riddle unsigned memIndex = read();
1805abfd1a8bSRiver Riddle Attribute attr = read<Attribute>();
1806e1795322SJeff Niu Type type;
18075550c821STres Popp if (auto typedAttr = dyn_cast<TypedAttr>(attr))
1808e1795322SJeff Niu type = typedAttr.getType();
1809abfd1a8bSRiver Riddle
1810abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n"
1811154cabe7SRiver Riddle << " * Result: " << type << "\n");
1812154cabe7SRiver Riddle memory[memIndex] = type.getAsOpaquePointer();
1813abfd1a8bSRiver Riddle }
1814154cabe7SRiver Riddle
executeGetDefiningOp()1815154cabe7SRiver Riddle void ByteCodeExecutor::executeGetDefiningOp() {
1816abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1817abfd1a8bSRiver Riddle unsigned memIndex = read();
181885ab413bSRiver Riddle Operation *op = nullptr;
181985ab413bSRiver Riddle if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1820abfd1a8bSRiver Riddle Value value = read<Value>();
182185ab413bSRiver Riddle if (value)
182285ab413bSRiver Riddle op = value.getDefiningOp();
182385ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
182485ab413bSRiver Riddle } else {
182585ab413bSRiver Riddle ValueRange *values = read<ValueRange *>();
182685ab413bSRiver Riddle if (values && !values->empty()) {
182785ab413bSRiver Riddle op = values->front().getDefiningOp();
182885ab413bSRiver Riddle }
182985ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n");
183085ab413bSRiver Riddle }
1831abfd1a8bSRiver Riddle
183285ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n");
1833abfd1a8bSRiver Riddle memory[memIndex] = op;
1834abfd1a8bSRiver Riddle }
1835154cabe7SRiver Riddle
executeGetOperand(unsigned index)1836154cabe7SRiver Riddle void ByteCodeExecutor::executeGetOperand(unsigned index) {
1837abfd1a8bSRiver Riddle Operation *op = read<Operation *>();
1838abfd1a8bSRiver Riddle unsigned memIndex = read();
1839abfd1a8bSRiver Riddle Value operand =
1840abfd1a8bSRiver Riddle index < op->getNumOperands() ? op->getOperand(index) : Value();
1841abfd1a8bSRiver Riddle
1842abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1843abfd1a8bSRiver Riddle << " * Index: " << index << "\n"
1844154cabe7SRiver Riddle << " * Result: " << operand << "\n");
1845abfd1a8bSRiver Riddle memory[memIndex] = operand.getAsOpaquePointer();
1846abfd1a8bSRiver Riddle }
1847154cabe7SRiver Riddle
184885ab413bSRiver Riddle /// This function is the internal implementation of `GetResults` and
184985ab413bSRiver Riddle /// `GetOperands` that provides support for extracting a value range from the
185085ab413bSRiver Riddle /// given operation.
185185ab413bSRiver Riddle template <template <typename> class AttrSizedSegmentsT, typename RangeT>
185285ab413bSRiver Riddle static void *
executeGetOperandsResults(RangeT values,Operation * op,unsigned index,ByteCodeField rangeIndex,StringRef attrSizedSegments,MutableArrayRef<ValueRange> valueRangeMemory)185385ab413bSRiver Riddle executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
185485ab413bSRiver Riddle ByteCodeField rangeIndex, StringRef attrSizedSegments,
18553eb1647aSStanislav Funiak MutableArrayRef<ValueRange> valueRangeMemory) {
185685ab413bSRiver Riddle // Check for the sentinel index that signals that all values should be
185785ab413bSRiver Riddle // returned.
185885ab413bSRiver Riddle if (index == std::numeric_limits<uint32_t>::max()) {
185985ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n");
186085ab413bSRiver Riddle // `values` is already the full value range.
186185ab413bSRiver Riddle
186285ab413bSRiver Riddle // Otherwise, check to see if this operation uses AttrSizedSegments.
186385ab413bSRiver Riddle } else if (op->hasTrait<AttrSizedSegmentsT>()) {
186485ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs()
186585ab413bSRiver Riddle << " * Extracting values from `" << attrSizedSegments << "`\n");
186685ab413bSRiver Riddle
186758a47508SJeff Niu auto segmentAttr = op->getAttrOfType<DenseI32ArrayAttr>(attrSizedSegments);
186858a47508SJeff Niu if (!segmentAttr || segmentAttr.asArrayRef().size() <= index)
186985ab413bSRiver Riddle return nullptr;
187085ab413bSRiver Riddle
187158a47508SJeff Niu ArrayRef<int32_t> segments = segmentAttr;
187285ab413bSRiver Riddle unsigned startIndex =
187385ab413bSRiver Riddle std::accumulate(segments.begin(), segments.begin() + index, 0);
187485ab413bSRiver Riddle values = values.slice(startIndex, *std::next(segments.begin(), index));
187585ab413bSRiver Riddle
187685ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex << ", "
187785ab413bSRiver Riddle << *std::next(segments.begin(), index) << "]\n");
187885ab413bSRiver Riddle
187985ab413bSRiver Riddle // Otherwise, assume this is the last operand group of the operation.
188085ab413bSRiver Riddle // FIXME: We currently don't support operations with
188185ab413bSRiver Riddle // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
188285ab413bSRiver Riddle // have a way to detect it's presence.
188385ab413bSRiver Riddle } else if (values.size() >= index) {
188485ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs()
188585ab413bSRiver Riddle << " * Treating values as trailing variadic range\n");
188685ab413bSRiver Riddle values = values.drop_front(index);
188785ab413bSRiver Riddle
188885ab413bSRiver Riddle // If we couldn't detect a way to compute the values, bail out.
188985ab413bSRiver Riddle } else {
189085ab413bSRiver Riddle return nullptr;
189185ab413bSRiver Riddle }
189285ab413bSRiver Riddle
189385ab413bSRiver Riddle // If the range index is valid, we are returning a range.
189485ab413bSRiver Riddle if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
189585ab413bSRiver Riddle valueRangeMemory[rangeIndex] = values;
189685ab413bSRiver Riddle return &valueRangeMemory[rangeIndex];
189785ab413bSRiver Riddle }
189885ab413bSRiver Riddle
189985ab413bSRiver Riddle // If a range index wasn't provided, the range is required to be non-variadic.
190085ab413bSRiver Riddle return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
190185ab413bSRiver Riddle }
190285ab413bSRiver Riddle
executeGetOperands()190385ab413bSRiver Riddle void ByteCodeExecutor::executeGetOperands() {
190485ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
190585ab413bSRiver Riddle unsigned index = read<uint32_t>();
190685ab413bSRiver Riddle Operation *op = read<Operation *>();
190785ab413bSRiver Riddle ByteCodeField rangeIndex = read();
190885ab413bSRiver Riddle
190985ab413bSRiver Riddle void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1910363b6559SMehdi Amini op->getOperands(), op, index, rangeIndex, "operandSegmentSizes",
191185ab413bSRiver Riddle valueRangeMemory);
191285ab413bSRiver Riddle if (!result)
191385ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n");
191485ab413bSRiver Riddle memory[read()] = result;
191585ab413bSRiver Riddle }
191685ab413bSRiver Riddle
executeGetResult(unsigned index)1917154cabe7SRiver Riddle void ByteCodeExecutor::executeGetResult(unsigned index) {
1918abfd1a8bSRiver Riddle Operation *op = read<Operation *>();
1919abfd1a8bSRiver Riddle unsigned memIndex = read();
1920abfd1a8bSRiver Riddle OpResult result =
1921abfd1a8bSRiver Riddle index < op->getNumResults() ? op->getResult(index) : OpResult();
1922abfd1a8bSRiver Riddle
1923abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1924abfd1a8bSRiver Riddle << " * Index: " << index << "\n"
1925154cabe7SRiver Riddle << " * Result: " << result << "\n");
1926abfd1a8bSRiver Riddle memory[memIndex] = result.getAsOpaquePointer();
1927abfd1a8bSRiver Riddle }
1928154cabe7SRiver Riddle
executeGetResults()192985ab413bSRiver Riddle void ByteCodeExecutor::executeGetResults() {
193085ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
193185ab413bSRiver Riddle unsigned index = read<uint32_t>();
193285ab413bSRiver Riddle Operation *op = read<Operation *>();
193385ab413bSRiver Riddle ByteCodeField rangeIndex = read();
193485ab413bSRiver Riddle
193585ab413bSRiver Riddle void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1936363b6559SMehdi Amini op->getResults(), op, index, rangeIndex, "resultSegmentSizes",
193785ab413bSRiver Riddle valueRangeMemory);
193885ab413bSRiver Riddle if (!result)
193985ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n");
194085ab413bSRiver Riddle memory[read()] = result;
194185ab413bSRiver Riddle }
194285ab413bSRiver Riddle
executeGetUsers()19433eb1647aSStanislav Funiak void ByteCodeExecutor::executeGetUsers() {
19443eb1647aSStanislav Funiak LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n");
19453eb1647aSStanislav Funiak unsigned memIndex = read();
19463eb1647aSStanislav Funiak unsigned rangeIndex = read();
19473eb1647aSStanislav Funiak OwningOpRange &range = opRangeMemory[rangeIndex];
19483eb1647aSStanislav Funiak memory[memIndex] = ⦥
19493eb1647aSStanislav Funiak
19503eb1647aSStanislav Funiak range = OwningOpRange();
19513eb1647aSStanislav Funiak if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
19523eb1647aSStanislav Funiak // Read the value.
19533eb1647aSStanislav Funiak Value value = read<Value>();
19543eb1647aSStanislav Funiak if (!value)
19553eb1647aSStanislav Funiak return;
19563eb1647aSStanislav Funiak LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
19573eb1647aSStanislav Funiak
19583eb1647aSStanislav Funiak // Extract the users of a single value.
19593eb1647aSStanislav Funiak range = OwningOpRange(std::distance(value.user_begin(), value.user_end()));
19603eb1647aSStanislav Funiak llvm::copy(value.getUsers(), range.begin());
19613eb1647aSStanislav Funiak } else {
19623eb1647aSStanislav Funiak // Read a range of values.
19633eb1647aSStanislav Funiak ValueRange *values = read<ValueRange *>();
19643eb1647aSStanislav Funiak if (!values)
19653eb1647aSStanislav Funiak return;
19663eb1647aSStanislav Funiak LLVM_DEBUG({
19673eb1647aSStanislav Funiak llvm::dbgs() << " * Values (" << values->size() << "): ";
19683eb1647aSStanislav Funiak llvm::interleaveComma(*values, llvm::dbgs());
19693eb1647aSStanislav Funiak llvm::dbgs() << "\n";
19703eb1647aSStanislav Funiak });
19713eb1647aSStanislav Funiak
19723eb1647aSStanislav Funiak // Extract all the users of a range of values.
19733eb1647aSStanislav Funiak SmallVector<Operation *> users;
19743eb1647aSStanislav Funiak for (Value value : *values)
19753eb1647aSStanislav Funiak users.append(value.user_begin(), value.user_end());
19763eb1647aSStanislav Funiak range = OwningOpRange(users.size());
19773eb1647aSStanislav Funiak llvm::copy(users, range.begin());
19783eb1647aSStanislav Funiak }
19793eb1647aSStanislav Funiak
19803eb1647aSStanislav Funiak LLVM_DEBUG(llvm::dbgs() << " * Result: " << range.size() << " operations\n");
19813eb1647aSStanislav Funiak }
19823eb1647aSStanislav Funiak
executeGetValueType()1983154cabe7SRiver Riddle void ByteCodeExecutor::executeGetValueType() {
1984abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1985abfd1a8bSRiver Riddle unsigned memIndex = read();
1986abfd1a8bSRiver Riddle Value value = read<Value>();
1987154cabe7SRiver Riddle Type type = value ? value.getType() : Type();
1988abfd1a8bSRiver Riddle
1989abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"
1990154cabe7SRiver Riddle << " * Result: " << type << "\n");
1991154cabe7SRiver Riddle memory[memIndex] = type.getAsOpaquePointer();
1992abfd1a8bSRiver Riddle }
1993154cabe7SRiver Riddle
executeGetValueRangeTypes()199485ab413bSRiver Riddle void ByteCodeExecutor::executeGetValueRangeTypes() {
199585ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
199685ab413bSRiver Riddle unsigned memIndex = read();
199785ab413bSRiver Riddle unsigned rangeIndex = read();
199885ab413bSRiver Riddle ValueRange *values = read<ValueRange *>();
199985ab413bSRiver Riddle if (!values) {
200085ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n");
200185ab413bSRiver Riddle memory[memIndex] = nullptr;
200285ab413bSRiver Riddle return;
200385ab413bSRiver Riddle }
200485ab413bSRiver Riddle
200585ab413bSRiver Riddle LLVM_DEBUG({
200685ab413bSRiver Riddle llvm::dbgs() << " * Values (" << values->size() << "): ";
200785ab413bSRiver Riddle llvm::interleaveComma(*values, llvm::dbgs());
200885ab413bSRiver Riddle llvm::dbgs() << "\n * Result: ";
200985ab413bSRiver Riddle llvm::interleaveComma(values->getType(), llvm::dbgs());
201085ab413bSRiver Riddle llvm::dbgs() << "\n";
201185ab413bSRiver Riddle });
201285ab413bSRiver Riddle typeRangeMemory[rangeIndex] = values->getType();
201385ab413bSRiver Riddle memory[memIndex] = &typeRangeMemory[rangeIndex];
201485ab413bSRiver Riddle }
201585ab413bSRiver Riddle
executeIsNotNull()2016154cabe7SRiver Riddle void ByteCodeExecutor::executeIsNotNull() {
2017abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
2018abfd1a8bSRiver Riddle const void *value = read<const void *>();
2019abfd1a8bSRiver Riddle
2020154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
2021abfd1a8bSRiver Riddle selectJump(value != nullptr);
2022abfd1a8bSRiver Riddle }
2023154cabe7SRiver Riddle
executeRecordMatch(PatternRewriter & rewriter,SmallVectorImpl<PDLByteCode::MatchResult> & matches)2024154cabe7SRiver Riddle void ByteCodeExecutor::executeRecordMatch(
2025154cabe7SRiver Riddle PatternRewriter &rewriter,
2026154cabe7SRiver Riddle SmallVectorImpl<PDLByteCode::MatchResult> &matches) {
2027abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
2028abfd1a8bSRiver Riddle unsigned patternIndex = read();
2029abfd1a8bSRiver Riddle PatternBenefit benefit = currentPatternBenefits[patternIndex];
2030abfd1a8bSRiver Riddle const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
2031abfd1a8bSRiver Riddle
2032abfd1a8bSRiver Riddle // If the benefit of the pattern is impossible, skip the processing of the
2033abfd1a8bSRiver Riddle // rest of the pattern.
2034abfd1a8bSRiver Riddle if (benefit.isImpossibleToMatch()) {
2035154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n");
2036abfd1a8bSRiver Riddle curCodeIt = dest;
2037154cabe7SRiver Riddle return;
2038abfd1a8bSRiver Riddle }
2039abfd1a8bSRiver Riddle
2040abfd1a8bSRiver Riddle // Create a fused location containing the locations of each of the
2041abfd1a8bSRiver Riddle // operations used in the match. This will be used as the location for
2042abfd1a8bSRiver Riddle // created operations during the rewrite that don't already have an
2043abfd1a8bSRiver Riddle // explicit location set.
2044abfd1a8bSRiver Riddle unsigned numMatchLocs = read();
2045abfd1a8bSRiver Riddle SmallVector<Location, 4> matchLocs;
2046abfd1a8bSRiver Riddle matchLocs.reserve(numMatchLocs);
2047abfd1a8bSRiver Riddle for (unsigned i = 0; i != numMatchLocs; ++i)
2048abfd1a8bSRiver Riddle matchLocs.push_back(read<Operation *>()->getLoc());
2049abfd1a8bSRiver Riddle Location matchLoc = rewriter.getFusedLoc(matchLocs);
2050abfd1a8bSRiver Riddle
2051abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n"
2052154cabe7SRiver Riddle << " * Location: " << matchLoc << "\n");
2053154cabe7SRiver Riddle matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
205485ab413bSRiver Riddle PDLByteCode::MatchResult &match = matches.back();
205585ab413bSRiver Riddle
205685ab413bSRiver Riddle // Record all of the inputs to the match. If any of the inputs are ranges, we
205785ab413bSRiver Riddle // will also need to remap the range pointer to memory stored in the match
205885ab413bSRiver Riddle // state.
205985ab413bSRiver Riddle unsigned numInputs = read();
206085ab413bSRiver Riddle match.values.reserve(numInputs);
206185ab413bSRiver Riddle match.typeRangeValues.reserve(numInputs);
206285ab413bSRiver Riddle match.valueRangeValues.reserve(numInputs);
206385ab413bSRiver Riddle for (unsigned i = 0; i < numInputs; ++i) {
206485ab413bSRiver Riddle switch (read<PDLValue::Kind>()) {
206585ab413bSRiver Riddle case PDLValue::Kind::TypeRange:
206685ab413bSRiver Riddle match.typeRangeValues.push_back(*read<TypeRange *>());
206785ab413bSRiver Riddle match.values.push_back(&match.typeRangeValues.back());
206885ab413bSRiver Riddle break;
206985ab413bSRiver Riddle case PDLValue::Kind::ValueRange:
207085ab413bSRiver Riddle match.valueRangeValues.push_back(*read<ValueRange *>());
207185ab413bSRiver Riddle match.values.push_back(&match.valueRangeValues.back());
207285ab413bSRiver Riddle break;
207385ab413bSRiver Riddle default:
207485ab413bSRiver Riddle match.values.push_back(read<const void *>());
207585ab413bSRiver Riddle break;
207685ab413bSRiver Riddle }
207785ab413bSRiver Riddle }
2078abfd1a8bSRiver Riddle curCodeIt = dest;
2079abfd1a8bSRiver Riddle }
2080154cabe7SRiver Riddle
executeReplaceOp(PatternRewriter & rewriter)2081154cabe7SRiver Riddle void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
2082abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
2083abfd1a8bSRiver Riddle Operation *op = read<Operation *>();
2084abfd1a8bSRiver Riddle SmallVector<Value, 16> args;
2085ce57789dSRiver Riddle readList(args);
2086abfd1a8bSRiver Riddle
2087abfd1a8bSRiver Riddle LLVM_DEBUG({
2088abfd1a8bSRiver Riddle llvm::dbgs() << " * Operation: " << *op << "\n"
2089abfd1a8bSRiver Riddle << " * Values: ";
2090abfd1a8bSRiver Riddle llvm::interleaveComma(args, llvm::dbgs());
2091154cabe7SRiver Riddle llvm::dbgs() << "\n";
2092abfd1a8bSRiver Riddle });
2093abfd1a8bSRiver Riddle rewriter.replaceOp(op, args);
2094abfd1a8bSRiver Riddle }
2095154cabe7SRiver Riddle
executeSwitchAttribute()2096154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchAttribute() {
2097abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
2098abfd1a8bSRiver Riddle Attribute value = read<Attribute>();
2099abfd1a8bSRiver Riddle ArrayAttr cases = read<ArrayAttr>();
2100abfd1a8bSRiver Riddle handleSwitch(value, cases);
2101abfd1a8bSRiver Riddle }
2102154cabe7SRiver Riddle
executeSwitchOperandCount()2103154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchOperandCount() {
2104abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
2105abfd1a8bSRiver Riddle Operation *op = read<Operation *>();
2106abfd1a8bSRiver Riddle auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2107abfd1a8bSRiver Riddle
2108abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
2109abfd1a8bSRiver Riddle handleSwitch(op->getNumOperands(), cases);
2110abfd1a8bSRiver Riddle }
2111154cabe7SRiver Riddle
executeSwitchOperationName()2112154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchOperationName() {
2113abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
2114abfd1a8bSRiver Riddle OperationName value = read<Operation *>()->getName();
2115abfd1a8bSRiver Riddle size_t caseCount = read();
2116abfd1a8bSRiver Riddle
2117abfd1a8bSRiver Riddle // The operation names are stored in-line, so to print them out for
2118abfd1a8bSRiver Riddle // debugging purposes we need to read the array before executing the
2119abfd1a8bSRiver Riddle // switch so that we can display all of the possible values.
2120abfd1a8bSRiver Riddle LLVM_DEBUG({
2121abfd1a8bSRiver Riddle const ByteCodeField *prevCodeIt = curCodeIt;
2122abfd1a8bSRiver Riddle llvm::dbgs() << " * Value: " << value << "\n"
2123abfd1a8bSRiver Riddle << " * Cases: ";
2124abfd1a8bSRiver Riddle llvm::interleaveComma(
2125abfd1a8bSRiver Riddle llvm::map_range(llvm::seq<size_t>(0, caseCount),
2126154cabe7SRiver Riddle [&](size_t) { return read<OperationName>(); }),
2127abfd1a8bSRiver Riddle llvm::dbgs());
2128154cabe7SRiver Riddle llvm::dbgs() << "\n";
2129abfd1a8bSRiver Riddle curCodeIt = prevCodeIt;
2130abfd1a8bSRiver Riddle });
2131abfd1a8bSRiver Riddle
2132abfd1a8bSRiver Riddle // Try to find the switch value within any of the cases.
2133abfd1a8bSRiver Riddle for (size_t i = 0; i != caseCount; ++i) {
2134abfd1a8bSRiver Riddle if (read<OperationName>() == value) {
2135abfd1a8bSRiver Riddle curCodeIt += (caseCount - i - 1);
2136154cabe7SRiver Riddle return selectJump(i + 1);
2137abfd1a8bSRiver Riddle }
2138abfd1a8bSRiver Riddle }
2139154cabe7SRiver Riddle selectJump(size_t(0));
2140abfd1a8bSRiver Riddle }
2141154cabe7SRiver Riddle
executeSwitchResultCount()2142154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchResultCount() {
2143abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
2144abfd1a8bSRiver Riddle Operation *op = read<Operation *>();
2145abfd1a8bSRiver Riddle auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2146abfd1a8bSRiver Riddle
2147abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
2148abfd1a8bSRiver Riddle handleSwitch(op->getNumResults(), cases);
2149abfd1a8bSRiver Riddle }
2150154cabe7SRiver Riddle
executeSwitchType()2151154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchType() {
2152abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
2153abfd1a8bSRiver Riddle Type value = read<Type>();
2154abfd1a8bSRiver Riddle auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2155abfd1a8bSRiver Riddle handleSwitch(value, cases);
2156154cabe7SRiver Riddle }
2157154cabe7SRiver Riddle
executeSwitchTypes()215885ab413bSRiver Riddle void ByteCodeExecutor::executeSwitchTypes() {
215985ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
216085ab413bSRiver Riddle TypeRange *value = read<TypeRange *>();
216185ab413bSRiver Riddle auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
216285ab413bSRiver Riddle if (!value) {
216385ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n");
216485ab413bSRiver Riddle return selectJump(size_t(0));
216585ab413bSRiver Riddle }
216685ab413bSRiver Riddle handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
216785ab413bSRiver Riddle return value == caseValue.getAsValueRange<TypeAttr>();
216885ab413bSRiver Riddle });
216985ab413bSRiver Riddle }
217085ab413bSRiver Riddle
21718c66344eSRiver Riddle LogicalResult
execute(PatternRewriter & rewriter,SmallVectorImpl<PDLByteCode::MatchResult> * matches,std::optional<Location> mainRewriteLoc)21728c66344eSRiver Riddle ByteCodeExecutor::execute(PatternRewriter &rewriter,
2173154cabe7SRiver Riddle SmallVectorImpl<PDLByteCode::MatchResult> *matches,
21740a81ace0SKazu Hirata std::optional<Location> mainRewriteLoc) {
2175154cabe7SRiver Riddle while (true) {
2176d35f1190SStanislav Funiak // Print the location of the operation being executed.
2177d35f1190SStanislav Funiak LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n");
2178d35f1190SStanislav Funiak
2179154cabe7SRiver Riddle OpCode opCode = static_cast<OpCode>(read());
2180154cabe7SRiver Riddle switch (opCode) {
2181154cabe7SRiver Riddle case ApplyConstraint:
2182154cabe7SRiver Riddle executeApplyConstraint(rewriter);
2183154cabe7SRiver Riddle break;
2184154cabe7SRiver Riddle case ApplyRewrite:
21858c66344eSRiver Riddle if (failed(executeApplyRewrite(rewriter)))
21868c66344eSRiver Riddle return failure();
2187154cabe7SRiver Riddle break;
2188154cabe7SRiver Riddle case AreEqual:
2189154cabe7SRiver Riddle executeAreEqual();
2190154cabe7SRiver Riddle break;
219185ab413bSRiver Riddle case AreRangesEqual:
219285ab413bSRiver Riddle executeAreRangesEqual();
219385ab413bSRiver Riddle break;
2194154cabe7SRiver Riddle case Branch:
2195154cabe7SRiver Riddle executeBranch();
2196154cabe7SRiver Riddle break;
2197154cabe7SRiver Riddle case CheckOperandCount:
2198154cabe7SRiver Riddle executeCheckOperandCount();
2199154cabe7SRiver Riddle break;
2200154cabe7SRiver Riddle case CheckOperationName:
2201154cabe7SRiver Riddle executeCheckOperationName();
2202154cabe7SRiver Riddle break;
2203154cabe7SRiver Riddle case CheckResultCount:
2204154cabe7SRiver Riddle executeCheckResultCount();
2205154cabe7SRiver Riddle break;
220685ab413bSRiver Riddle case CheckTypes:
220785ab413bSRiver Riddle executeCheckTypes();
220885ab413bSRiver Riddle break;
22093eb1647aSStanislav Funiak case Continue:
22103eb1647aSStanislav Funiak executeContinue();
22113eb1647aSStanislav Funiak break;
2212ce57789dSRiver Riddle case CreateConstantTypeRange:
2213ce57789dSRiver Riddle executeCreateConstantTypeRange();
2214ce57789dSRiver Riddle break;
2215154cabe7SRiver Riddle case CreateOperation:
2216154cabe7SRiver Riddle executeCreateOperation(rewriter, *mainRewriteLoc);
2217154cabe7SRiver Riddle break;
2218ce57789dSRiver Riddle case CreateDynamicTypeRange:
2219ce57789dSRiver Riddle executeDynamicCreateRange<Type>("Type");
2220ce57789dSRiver Riddle break;
2221ce57789dSRiver Riddle case CreateDynamicValueRange:
2222ce57789dSRiver Riddle executeDynamicCreateRange<Value>("Value");
222385ab413bSRiver Riddle break;
2224154cabe7SRiver Riddle case EraseOp:
2225154cabe7SRiver Riddle executeEraseOp(rewriter);
2226154cabe7SRiver Riddle break;
22273eb1647aSStanislav Funiak case ExtractOp:
22283eb1647aSStanislav Funiak executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
22293eb1647aSStanislav Funiak break;
22303eb1647aSStanislav Funiak case ExtractType:
22313eb1647aSStanislav Funiak executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
22323eb1647aSStanislav Funiak break;
22333eb1647aSStanislav Funiak case ExtractValue:
22343eb1647aSStanislav Funiak executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
22353eb1647aSStanislav Funiak break;
2236154cabe7SRiver Riddle case Finalize:
22373eb1647aSStanislav Funiak executeFinalize();
22383eb1647aSStanislav Funiak LLVM_DEBUG(llvm::dbgs() << "\n");
22398c66344eSRiver Riddle return success();
22403eb1647aSStanislav Funiak case ForEach:
22413eb1647aSStanislav Funiak executeForEach();
22423eb1647aSStanislav Funiak break;
2243154cabe7SRiver Riddle case GetAttribute:
2244154cabe7SRiver Riddle executeGetAttribute();
2245154cabe7SRiver Riddle break;
2246154cabe7SRiver Riddle case GetAttributeType:
2247154cabe7SRiver Riddle executeGetAttributeType();
2248154cabe7SRiver Riddle break;
2249154cabe7SRiver Riddle case GetDefiningOp:
2250154cabe7SRiver Riddle executeGetDefiningOp();
2251154cabe7SRiver Riddle break;
2252154cabe7SRiver Riddle case GetOperand0:
2253154cabe7SRiver Riddle case GetOperand1:
2254154cabe7SRiver Riddle case GetOperand2:
2255154cabe7SRiver Riddle case GetOperand3: {
2256154cabe7SRiver Riddle unsigned index = opCode - GetOperand0;
2257154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
22581fff7c89SFrederik Gossen executeGetOperand(index);
2259abfd1a8bSRiver Riddle break;
2260abfd1a8bSRiver Riddle }
2261154cabe7SRiver Riddle case GetOperandN:
2262154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
2263154cabe7SRiver Riddle executeGetOperand(read<uint32_t>());
2264154cabe7SRiver Riddle break;
226585ab413bSRiver Riddle case GetOperands:
226685ab413bSRiver Riddle executeGetOperands();
226785ab413bSRiver Riddle break;
2268154cabe7SRiver Riddle case GetResult0:
2269154cabe7SRiver Riddle case GetResult1:
2270154cabe7SRiver Riddle case GetResult2:
2271154cabe7SRiver Riddle case GetResult3: {
2272154cabe7SRiver Riddle unsigned index = opCode - GetResult0;
2273154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
22741fff7c89SFrederik Gossen executeGetResult(index);
2275154cabe7SRiver Riddle break;
2276abfd1a8bSRiver Riddle }
2277154cabe7SRiver Riddle case GetResultN:
2278154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
2279154cabe7SRiver Riddle executeGetResult(read<uint32_t>());
2280154cabe7SRiver Riddle break;
228185ab413bSRiver Riddle case GetResults:
228285ab413bSRiver Riddle executeGetResults();
228385ab413bSRiver Riddle break;
22843eb1647aSStanislav Funiak case GetUsers:
22853eb1647aSStanislav Funiak executeGetUsers();
22863eb1647aSStanislav Funiak break;
2287154cabe7SRiver Riddle case GetValueType:
2288154cabe7SRiver Riddle executeGetValueType();
2289154cabe7SRiver Riddle break;
229085ab413bSRiver Riddle case GetValueRangeTypes:
229185ab413bSRiver Riddle executeGetValueRangeTypes();
229285ab413bSRiver Riddle break;
2293154cabe7SRiver Riddle case IsNotNull:
2294154cabe7SRiver Riddle executeIsNotNull();
2295154cabe7SRiver Riddle break;
2296154cabe7SRiver Riddle case RecordMatch:
2297154cabe7SRiver Riddle assert(matches &&
2298154cabe7SRiver Riddle "expected matches to be provided when executing the matcher");
2299154cabe7SRiver Riddle executeRecordMatch(rewriter, *matches);
2300154cabe7SRiver Riddle break;
2301154cabe7SRiver Riddle case ReplaceOp:
2302154cabe7SRiver Riddle executeReplaceOp(rewriter);
2303154cabe7SRiver Riddle break;
2304154cabe7SRiver Riddle case SwitchAttribute:
2305154cabe7SRiver Riddle executeSwitchAttribute();
2306154cabe7SRiver Riddle break;
2307154cabe7SRiver Riddle case SwitchOperandCount:
2308154cabe7SRiver Riddle executeSwitchOperandCount();
2309154cabe7SRiver Riddle break;
2310154cabe7SRiver Riddle case SwitchOperationName:
2311154cabe7SRiver Riddle executeSwitchOperationName();
2312154cabe7SRiver Riddle break;
2313154cabe7SRiver Riddle case SwitchResultCount:
2314154cabe7SRiver Riddle executeSwitchResultCount();
2315154cabe7SRiver Riddle break;
2316154cabe7SRiver Riddle case SwitchType:
2317154cabe7SRiver Riddle executeSwitchType();
2318154cabe7SRiver Riddle break;
231985ab413bSRiver Riddle case SwitchTypes:
232085ab413bSRiver Riddle executeSwitchTypes();
232185ab413bSRiver Riddle break;
2322154cabe7SRiver Riddle }
2323154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "\n");
2324abfd1a8bSRiver Riddle }
2325abfd1a8bSRiver Riddle }
2326abfd1a8bSRiver Riddle
match(Operation * op,PatternRewriter & rewriter,SmallVectorImpl<MatchResult> & matches,PDLByteCodeMutableState & state) const2327abfd1a8bSRiver Riddle void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
2328abfd1a8bSRiver Riddle SmallVectorImpl<MatchResult> &matches,
2329abfd1a8bSRiver Riddle PDLByteCodeMutableState &state) const {
2330abfd1a8bSRiver Riddle // The first memory slot is always the root operation.
2331abfd1a8bSRiver Riddle state.memory[0] = op;
2332abfd1a8bSRiver Riddle
2333abfd1a8bSRiver Riddle // The matcher function always starts at code address 0.
233485ab413bSRiver Riddle ByteCodeExecutor executor(
23353eb1647aSStanislav Funiak matcherByteCode.data(), state.memory, state.opRangeMemory,
23363eb1647aSStanislav Funiak state.typeRangeMemory, state.allocatedTypeRangeMemory,
23373eb1647aSStanislav Funiak state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
23383eb1647aSStanislav Funiak uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
23393eb1647aSStanislav Funiak constraintFunctions, rewriteFunctions);
23408c66344eSRiver Riddle LogicalResult executeResult = executor.execute(rewriter, &matches);
2341eee5c385SJohannes Reifferscheid (void)executeResult;
23428c66344eSRiver Riddle assert(succeeded(executeResult) && "unexpected matcher execution failure");
2343abfd1a8bSRiver Riddle
2344abfd1a8bSRiver Riddle // Order the found matches by benefit.
2345abfd1a8bSRiver Riddle std::stable_sort(matches.begin(), matches.end(),
2346abfd1a8bSRiver Riddle [](const MatchResult &lhs, const MatchResult &rhs) {
2347abfd1a8bSRiver Riddle return lhs.benefit > rhs.benefit;
2348abfd1a8bSRiver Riddle });
2349abfd1a8bSRiver Riddle }
2350abfd1a8bSRiver Riddle
rewrite(PatternRewriter & rewriter,const MatchResult & match,PDLByteCodeMutableState & state) const23518c66344eSRiver Riddle LogicalResult PDLByteCode::rewrite(PatternRewriter &rewriter,
23528c66344eSRiver Riddle const MatchResult &match,
2353abfd1a8bSRiver Riddle PDLByteCodeMutableState &state) const {
23548c66344eSRiver Riddle auto *configSet = match.pattern->getConfigSet();
23558c66344eSRiver Riddle if (configSet)
23568c66344eSRiver Riddle configSet->notifyRewriteBegin(rewriter);
23578c66344eSRiver Riddle
2358abfd1a8bSRiver Riddle // The arguments of the rewrite function are stored at the start of the
2359abfd1a8bSRiver Riddle // memory buffer.
2360abfd1a8bSRiver Riddle llvm::copy(match.values, state.memory.begin());
2361abfd1a8bSRiver Riddle
236285ab413bSRiver Riddle ByteCodeExecutor executor(
236385ab413bSRiver Riddle &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
23643eb1647aSStanislav Funiak state.opRangeMemory, state.typeRangeMemory,
23653eb1647aSStanislav Funiak state.allocatedTypeRangeMemory, state.valueRangeMemory,
23663eb1647aSStanislav Funiak state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
236785ab413bSRiver Riddle rewriterByteCode, state.currentPatternBenefits, patterns,
236802c4c0d5SRiver Riddle constraintFunctions, rewriteFunctions);
23698c66344eSRiver Riddle LogicalResult result =
2370abfd1a8bSRiver Riddle executor.execute(rewriter, /*matches=*/nullptr, match.location);
23718c66344eSRiver Riddle
23728c66344eSRiver Riddle if (configSet)
23738c66344eSRiver Riddle configSet->notifyRewriteEnd(rewriter);
23748c66344eSRiver Riddle
23758c66344eSRiver Riddle // If the rewrite failed, check if the pattern rewriter can recover. If it
23768c66344eSRiver Riddle // can, we can signal to the pattern applicator to keep trying patterns. If it
23778c66344eSRiver Riddle // doesn't, we need to bail. Bailing here should be fine, given that we have
23788c66344eSRiver Riddle // no means to propagate such a failure to the user, and it also indicates a
23798c66344eSRiver Riddle // bug in the user code (i.e. failable rewrites should not be used with
23808c66344eSRiver Riddle // pattern rewriters that don't support it).
23818c66344eSRiver Riddle if (failed(result) && !rewriter.canRecoverFromRewriteFailure()) {
23828c66344eSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " and rollback is not supported - aborting");
23838c66344eSRiver Riddle llvm::report_fatal_error(
23848c66344eSRiver Riddle "Native PDL Rewrite failed, but the pattern "
23858c66344eSRiver Riddle "rewriter doesn't support recovery. Failable pattern rewrites should "
23868c66344eSRiver Riddle "not be used with pattern rewriters that do not support them.");
23878c66344eSRiver Riddle }
23888c66344eSRiver Riddle return result;
2389abfd1a8bSRiver Riddle }
2390