1 //===- ByteCode.h - Pattern byte-code interpreter ---------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file declares a byte-code and interpreter for pattern rewrites in MLIR. 10 // The byte-code is constructed from the PDL Interpreter dialect. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef MLIR_REWRITE_BYTECODE_H_ 15 #define MLIR_REWRITE_BYTECODE_H_ 16 17 #include "mlir/IR/PatternMatch.h" 18 19 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH 20 21 namespace mlir { 22 namespace pdl_interp { 23 class RecordMatchOp; 24 } // namespace pdl_interp 25 26 namespace detail { 27 class PDLByteCode; 28 29 /// Use generic bytecode types. ByteCodeField refers to the actual bytecode 30 /// entries. ByteCodeAddr refers to size of indices into the bytecode. 31 using ByteCodeField = uint16_t; 32 using ByteCodeAddr = uint32_t; 33 using OwningOpRange = llvm::OwningArrayRef<Operation *>; 34 35 //===----------------------------------------------------------------------===// 36 // PDLByteCodePattern 37 //===----------------------------------------------------------------------===// 38 39 /// All of the data pertaining to a specific pattern within the bytecode. 40 class PDLByteCodePattern : public Pattern { 41 public: 42 static PDLByteCodePattern create(pdl_interp::RecordMatchOp matchOp, 43 PDLPatternConfigSet *configSet, 44 ByteCodeAddr rewriterAddr); 45 46 /// Return the bytecode address of the rewriter for this pattern. getRewriterAddr()47 ByteCodeAddr getRewriterAddr() const { return rewriterAddr; } 48 49 /// Return the configuration set for this pattern, or null if there is none. getConfigSet()50 PDLPatternConfigSet *getConfigSet() const { return configSet; } 51 52 private: 53 template <typename... Args> PDLByteCodePattern(ByteCodeAddr rewriterAddr,PDLPatternConfigSet * configSet,Args &&...patternArgs)54 PDLByteCodePattern(ByteCodeAddr rewriterAddr, PDLPatternConfigSet *configSet, 55 Args &&...patternArgs) 56 : Pattern(std::forward<Args>(patternArgs)...), rewriterAddr(rewriterAddr), 57 configSet(configSet) {} 58 59 /// The address of the rewriter for this pattern. 60 ByteCodeAddr rewriterAddr; 61 62 /// The optional config set for this pattern. 63 PDLPatternConfigSet *configSet; 64 }; 65 66 //===----------------------------------------------------------------------===// 67 // PDLByteCodeMutableState 68 //===----------------------------------------------------------------------===// 69 70 /// This class contains the mutable state of a bytecode instance. This allows 71 /// for a bytecode instance to be cached and reused across various different 72 /// threads/drivers. 73 class PDLByteCodeMutableState { 74 public: 75 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds 76 /// to the position of the pattern within the range returned by 77 /// `PDLByteCode::getPatterns`. 78 void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit); 79 80 /// Cleanup any allocated state after a match/rewrite has been completed. This 81 /// method should be called irregardless of whether the match+rewrite was a 82 /// success or not. 83 void cleanupAfterMatchAndRewrite(); 84 85 private: 86 /// Allow access to data fields. 87 friend class PDLByteCode; 88 89 /// The mutable block of memory used during the matching and rewriting phases 90 /// of the bytecode. 91 std::vector<const void *> memory; 92 93 /// A mutable block of memory used during the matching and rewriting phase of 94 /// the bytecode to store ranges of operations. These are always stored by 95 /// owning references, because at no point in the execution of the byte code 96 /// we get an indexed range (view) of operations. 97 std::vector<OwningOpRange> opRangeMemory; 98 99 /// A mutable block of memory used during the matching and rewriting phase of 100 /// the bytecode to store ranges of types. 101 std::vector<TypeRange> typeRangeMemory; 102 /// A set of type ranges that have been allocated by the byte code interpreter 103 /// to provide a guaranteed lifetime. 104 std::vector<llvm::OwningArrayRef<Type>> allocatedTypeRangeMemory; 105 106 /// A mutable block of memory used during the matching and rewriting phase of 107 /// the bytecode to store ranges of values. 108 std::vector<ValueRange> valueRangeMemory; 109 /// A set of value ranges that have been allocated by the byte code 110 /// interpreter to provide a guaranteed lifetime. 111 std::vector<llvm::OwningArrayRef<Value>> allocatedValueRangeMemory; 112 113 /// The current index of ranges being iterated over for each level of nesting. 114 /// These are always maintained at 0 for the loops that are not active, so we 115 /// do not need to have a separate initialization phase for each loop. 116 std::vector<unsigned> loopIndex; 117 118 /// The up-to-date benefits of the patterns held by the bytecode. The order 119 /// of this array corresponds 1-1 with the array of patterns in `PDLByteCode`. 120 std::vector<PatternBenefit> currentPatternBenefits; 121 }; 122 123 //===----------------------------------------------------------------------===// 124 // PDLByteCode 125 //===----------------------------------------------------------------------===// 126 127 /// The bytecode class is also the interpreter. Contains the bytecode itself, 128 /// the static info, addresses of the rewriter functions, the interpreter 129 /// memory buffer, and the execution context. 130 class PDLByteCode { 131 public: 132 /// Each successful match returns a MatchResult, which contains information 133 /// necessary to execute the rewriter and indicates the originating pattern. 134 struct MatchResult { MatchResultMatchResult135 MatchResult(Location loc, const PDLByteCodePattern &pattern, 136 PatternBenefit benefit) 137 : location(loc), pattern(&pattern), benefit(benefit) {} 138 MatchResult(const MatchResult &) = delete; 139 MatchResult &operator=(const MatchResult &) = delete; 140 MatchResult(MatchResult &&other) = default; 141 MatchResult &operator=(MatchResult &&) = default; 142 143 /// The location of operations to be replaced. 144 Location location; 145 /// Memory values defined in the matcher that are passed to the rewriter. 146 SmallVector<const void *> values; 147 /// Memory used for the range input values. 148 SmallVector<TypeRange, 0> typeRangeValues; 149 SmallVector<ValueRange, 0> valueRangeValues; 150 151 /// The originating pattern that was matched. This is always non-null, but 152 /// represented with a pointer to allow for assignment. 153 const PDLByteCodePattern *pattern; 154 /// The current benefit of the pattern that was matched. 155 PatternBenefit benefit; 156 }; 157 158 /// Create a ByteCode instance from the given module containing operations in 159 /// the PDL interpreter dialect. 160 PDLByteCode(ModuleOp module, 161 SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs, 162 const DenseMap<Operation *, PDLPatternConfigSet *> &configMap, 163 llvm::StringMap<PDLConstraintFunction> constraintFns, 164 llvm::StringMap<PDLRewriteFunction> rewriteFns); 165 166 /// Return the patterns held by the bytecode. getPatterns()167 ArrayRef<PDLByteCodePattern> getPatterns() const { return patterns; } 168 169 /// Initialize the given state such that it can be used to execute the current 170 /// bytecode. 171 void initializeMutableState(PDLByteCodeMutableState &state) const; 172 173 /// Run the pattern matcher on the given root operation, collecting the 174 /// matched patterns in `matches`. 175 void match(Operation *op, PatternRewriter &rewriter, 176 SmallVectorImpl<MatchResult> &matches, 177 PDLByteCodeMutableState &state) const; 178 179 /// Run the rewriter of the given pattern that was previously matched in 180 /// `match`. Returns if a failure was encountered during the rewrite. 181 LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match, 182 PDLByteCodeMutableState &state) const; 183 184 private: 185 /// Execute the given byte code starting at the provided instruction `inst`. 186 /// `matches` is an optional field provided when this function is executed in 187 /// a matching context. 188 void executeByteCode(const ByteCodeField *inst, PatternRewriter &rewriter, 189 PDLByteCodeMutableState &state, 190 SmallVectorImpl<MatchResult> *matches) const; 191 192 /// The set of pattern configs referenced within the bytecode. 193 SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs; 194 195 /// A vector containing pointers to uniqued data. The storage is intentionally 196 /// opaque such that we can store a wide range of data types. The types of 197 /// data stored here include: 198 /// * Attribute, OperationName, Type 199 std::vector<const void *> uniquedData; 200 201 /// A vector containing the generated bytecode for the matcher. 202 SmallVector<ByteCodeField, 64> matcherByteCode; 203 204 /// A vector containing the generated bytecode for all of the rewriters. 205 SmallVector<ByteCodeField, 64> rewriterByteCode; 206 207 /// The set of patterns contained within the bytecode. 208 SmallVector<PDLByteCodePattern, 32> patterns; 209 210 /// A set of user defined functions invoked via PDL. 211 std::vector<PDLConstraintFunction> constraintFunctions; 212 std::vector<PDLRewriteFunction> rewriteFunctions; 213 214 /// The maximum memory index used by a value. 215 ByteCodeField maxValueMemoryIndex = 0; 216 217 /// The maximum number of different types of ranges. 218 ByteCodeField maxOpRangeCount = 0; 219 ByteCodeField maxTypeRangeCount = 0; 220 ByteCodeField maxValueRangeCount = 0; 221 222 /// The maximum number of nested loops. 223 ByteCodeField maxLoopLevel = 0; 224 }; 225 226 } // namespace detail 227 } // namespace mlir 228 229 #else 230 231 namespace mlir::detail { 232 233 class PDLByteCodeMutableState { 234 public: cleanupAfterMatchAndRewrite()235 void cleanupAfterMatchAndRewrite() {} updatePatternBenefit(unsigned patternIndex,PatternBenefit benefit)236 void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit) {} 237 }; 238 239 class PDLByteCodePattern : public Pattern {}; 240 241 class PDLByteCode { 242 public: 243 struct MatchResult { 244 const PDLByteCodePattern *pattern = nullptr; 245 PatternBenefit benefit; 246 }; 247 initializeMutableState(PDLByteCodeMutableState & state)248 void initializeMutableState(PDLByteCodeMutableState &state) const {} match(Operation * op,PatternRewriter & rewriter,SmallVectorImpl<MatchResult> & matches,PDLByteCodeMutableState & state)249 void match(Operation *op, PatternRewriter &rewriter, 250 SmallVectorImpl<MatchResult> &matches, 251 PDLByteCodeMutableState &state) const {} rewrite(PatternRewriter & rewriter,const MatchResult & match,PDLByteCodeMutableState & state)252 LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match, 253 PDLByteCodeMutableState &state) const { 254 return failure(); 255 } getPatterns()256 ArrayRef<PDLByteCodePattern> getPatterns() const { return {}; } 257 }; 258 259 } // namespace mlir::detail 260 261 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH 262 263 #endif // MLIR_REWRITE_BYTECODE_H_ 264