1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements MLIR to byte-code generation and the interpreter. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "ByteCode.h" 14 #include "mlir/Analysis/Liveness.h" 15 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 16 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/RegionGraphTraits.h" 19 #include "llvm/ADT/IntervalMap.h" 20 #include "llvm/ADT/PostOrderIterator.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 #include "llvm/Support/Debug.h" 23 #include "llvm/Support/Format.h" 24 #include "llvm/Support/FormatVariadic.h" 25 #include <numeric> 26 #include <optional> 27 28 #define DEBUG_TYPE "pdl-bytecode" 29 30 using namespace mlir; 31 using namespace mlir::detail; 32 33 //===----------------------------------------------------------------------===// 34 // PDLByteCodePattern 35 //===----------------------------------------------------------------------===// 36 37 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp, 38 PDLPatternConfigSet *configSet, 39 ByteCodeAddr rewriterAddr) { 40 PatternBenefit benefit = matchOp.getBenefit(); 41 MLIRContext *ctx = matchOp.getContext(); 42 43 // Collect the set of generated operations. 44 SmallVector<StringRef, 8> generatedOps; 45 if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr()) 46 generatedOps = 47 llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>()); 48 49 // Check to see if this is pattern matches a specific operation type. 50 if (std::optional<StringRef> rootKind = matchOp.getRootKind()) 51 return PDLByteCodePattern(rewriterAddr, configSet, *rootKind, benefit, ctx, 52 generatedOps); 53 return PDLByteCodePattern(rewriterAddr, configSet, MatchAnyOpTypeTag(), 54 benefit, ctx, generatedOps); 55 } 56 57 //===----------------------------------------------------------------------===// 58 // PDLByteCodeMutableState 59 //===----------------------------------------------------------------------===// 60 61 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds 62 /// to the position of the pattern within the range returned by 63 /// `PDLByteCode::getPatterns`. 64 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex, 65 PatternBenefit benefit) { 66 currentPatternBenefits[patternIndex] = benefit; 67 } 68 69 /// Cleanup any allocated state after a full match/rewrite has been completed. 70 /// This method should be called irregardless of whether the match+rewrite was a 71 /// success or not. 72 void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() { 73 allocatedTypeRangeMemory.clear(); 74 allocatedValueRangeMemory.clear(); 75 } 76 77 //===----------------------------------------------------------------------===// 78 // Bytecode OpCodes 79 //===----------------------------------------------------------------------===// 80 81 namespace { 82 enum OpCode : ByteCodeField { 83 /// Apply an externally registered constraint. 84 ApplyConstraint, 85 /// Apply an externally registered rewrite. 86 ApplyRewrite, 87 /// Check if two generic values are equal. 88 AreEqual, 89 /// Check if two ranges are equal. 90 AreRangesEqual, 91 /// Unconditional branch. 92 Branch, 93 /// Compare the operand count of an operation with a constant. 94 CheckOperandCount, 95 /// Compare the name of an operation with a constant. 96 CheckOperationName, 97 /// Compare the result count of an operation with a constant. 98 CheckResultCount, 99 /// Compare a range of types to a constant range of types. 100 CheckTypes, 101 /// Continue to the next iteration of a loop. 102 Continue, 103 /// Create a type range from a list of constant types. 104 CreateConstantTypeRange, 105 /// Create an operation. 106 CreateOperation, 107 /// Create a type range from a list of dynamic types. 108 CreateDynamicTypeRange, 109 /// Create a value range. 110 CreateDynamicValueRange, 111 /// Erase an operation. 112 EraseOp, 113 /// Extract the op from a range at the specified index. 114 ExtractOp, 115 /// Extract the type from a range at the specified index. 116 ExtractType, 117 /// Extract the value from a range at the specified index. 118 ExtractValue, 119 /// Terminate a matcher or rewrite sequence. 120 Finalize, 121 /// Iterate over a range of values. 122 ForEach, 123 /// Get a specific attribute of an operation. 124 GetAttribute, 125 /// Get the type of an attribute. 126 GetAttributeType, 127 /// Get the defining operation of a value. 128 GetDefiningOp, 129 /// Get a specific operand of an operation. 130 GetOperand0, 131 GetOperand1, 132 GetOperand2, 133 GetOperand3, 134 GetOperandN, 135 /// Get a specific operand group of an operation. 136 GetOperands, 137 /// Get a specific result of an operation. 138 GetResult0, 139 GetResult1, 140 GetResult2, 141 GetResult3, 142 GetResultN, 143 /// Get a specific result group of an operation. 144 GetResults, 145 /// Get the users of a value or a range of values. 146 GetUsers, 147 /// Get the type of a value. 148 GetValueType, 149 /// Get the types of a value range. 150 GetValueRangeTypes, 151 /// Check if a generic value is not null. 152 IsNotNull, 153 /// Record a successful pattern match. 154 RecordMatch, 155 /// Replace an operation. 156 ReplaceOp, 157 /// Compare an attribute with a set of constants. 158 SwitchAttribute, 159 /// Compare the operand count of an operation with a set of constants. 160 SwitchOperandCount, 161 /// Compare the name of an operation with a set of constants. 162 SwitchOperationName, 163 /// Compare the result count of an operation with a set of constants. 164 SwitchResultCount, 165 /// Compare a type with a set of constants. 166 SwitchType, 167 /// Compare a range of types with a set of constants. 168 SwitchTypes, 169 }; 170 } // namespace 171 172 /// A marker used to indicate if an operation should infer types. 173 static constexpr ByteCodeField kInferTypesMarker = 174 std::numeric_limits<ByteCodeField>::max(); 175 176 //===----------------------------------------------------------------------===// 177 // ByteCode Generation 178 //===----------------------------------------------------------------------===// 179 180 //===----------------------------------------------------------------------===// 181 // Generator 182 183 namespace { 184 struct ByteCodeLiveRange; 185 struct ByteCodeWriter; 186 187 /// Check if the given class `T` can be converted to an opaque pointer. 188 template <typename T, typename... Args> 189 using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer()); 190 191 /// This class represents the main generator for the pattern bytecode. 192 class Generator { 193 public: 194 Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData, 195 SmallVectorImpl<ByteCodeField> &matcherByteCode, 196 SmallVectorImpl<ByteCodeField> &rewriterByteCode, 197 SmallVectorImpl<PDLByteCodePattern> &patterns, 198 ByteCodeField &maxValueMemoryIndex, 199 ByteCodeField &maxOpRangeMemoryIndex, 200 ByteCodeField &maxTypeRangeMemoryIndex, 201 ByteCodeField &maxValueRangeMemoryIndex, 202 ByteCodeField &maxLoopLevel, 203 llvm::StringMap<PDLConstraintFunction> &constraintFns, 204 llvm::StringMap<PDLRewriteFunction> &rewriteFns, 205 const DenseMap<Operation *, PDLPatternConfigSet *> &configMap) 206 : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), 207 rewriterByteCode(rewriterByteCode), patterns(patterns), 208 maxValueMemoryIndex(maxValueMemoryIndex), 209 maxOpRangeMemoryIndex(maxOpRangeMemoryIndex), 210 maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex), 211 maxValueRangeMemoryIndex(maxValueRangeMemoryIndex), 212 maxLoopLevel(maxLoopLevel), configMap(configMap) { 213 for (const auto &it : llvm::enumerate(constraintFns)) 214 constraintToMemIndex.try_emplace(it.value().first(), it.index()); 215 for (const auto &it : llvm::enumerate(rewriteFns)) 216 externalRewriterToMemIndex.try_emplace(it.value().first(), it.index()); 217 } 218 219 /// Generate the bytecode for the given PDL interpreter module. 220 void generate(ModuleOp module); 221 222 /// Return the memory index to use for the given value. 223 ByteCodeField &getMemIndex(Value value) { 224 assert(valueToMemIndex.count(value) && 225 "expected memory index to be assigned"); 226 return valueToMemIndex[value]; 227 } 228 229 /// Return the range memory index used to store the given range value. 230 ByteCodeField &getRangeStorageIndex(Value value) { 231 assert(valueToRangeIndex.count(value) && 232 "expected range index to be assigned"); 233 return valueToRangeIndex[value]; 234 } 235 236 /// Return an index to use when referring to the given data that is uniqued in 237 /// the MLIR context. 238 template <typename T> 239 std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &> 240 getMemIndex(T val) { 241 const void *opaqueVal = val.getAsOpaquePointer(); 242 243 // Get or insert a reference to this value. 244 auto it = uniquedDataToMemIndex.try_emplace( 245 opaqueVal, maxValueMemoryIndex + uniquedData.size()); 246 if (it.second) 247 uniquedData.push_back(opaqueVal); 248 return it.first->second; 249 } 250 251 private: 252 /// Allocate memory indices for the results of operations within the matcher 253 /// and rewriters. 254 void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc, 255 ModuleOp rewriterModule); 256 257 /// Generate the bytecode for the given operation. 258 void generate(Region *region, ByteCodeWriter &writer); 259 void generate(Operation *op, ByteCodeWriter &writer); 260 void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer); 261 void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer); 262 void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer); 263 void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer); 264 void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer); 265 void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer); 266 void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer); 267 void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer); 268 void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer); 269 void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer); 270 void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer); 271 void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); 272 void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); 273 void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer); 274 void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); 275 void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer); 276 void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer); 277 void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer); 278 void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer); 279 void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer); 280 void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer); 281 void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer); 282 void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer); 283 void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer); 284 void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer); 285 void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer); 286 void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer); 287 void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer); 288 void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer); 289 void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer); 290 void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer); 291 void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer); 292 void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer); 293 void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer); 294 void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer); 295 void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer); 296 void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer); 297 void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer); 298 299 /// Mapping from value to its corresponding memory index. 300 DenseMap<Value, ByteCodeField> valueToMemIndex; 301 302 /// Mapping from a range value to its corresponding range storage index. 303 DenseMap<Value, ByteCodeField> valueToRangeIndex; 304 305 /// Mapping from the name of an externally registered rewrite to its index in 306 /// the bytecode registry. 307 llvm::StringMap<ByteCodeField> externalRewriterToMemIndex; 308 309 /// Mapping from the name of an externally registered constraint to its index 310 /// in the bytecode registry. 311 llvm::StringMap<ByteCodeField> constraintToMemIndex; 312 313 /// Mapping from rewriter function name to the bytecode address of the 314 /// rewriter function in byte. 315 llvm::StringMap<ByteCodeAddr> rewriterToAddr; 316 317 /// Mapping from a uniqued storage object to its memory index within 318 /// `uniquedData`. 319 DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex; 320 321 /// The current level of the foreach loop. 322 ByteCodeField curLoopLevel = 0; 323 324 /// The current MLIR context. 325 MLIRContext *ctx; 326 327 /// Mapping from block to its address. 328 DenseMap<Block *, ByteCodeAddr> blockToAddr; 329 330 /// Data of the ByteCode class to be populated. 331 std::vector<const void *> &uniquedData; 332 SmallVectorImpl<ByteCodeField> &matcherByteCode; 333 SmallVectorImpl<ByteCodeField> &rewriterByteCode; 334 SmallVectorImpl<PDLByteCodePattern> &patterns; 335 ByteCodeField &maxValueMemoryIndex; 336 ByteCodeField &maxOpRangeMemoryIndex; 337 ByteCodeField &maxTypeRangeMemoryIndex; 338 ByteCodeField &maxValueRangeMemoryIndex; 339 ByteCodeField &maxLoopLevel; 340 341 /// A map of pattern configurations. 342 const DenseMap<Operation *, PDLPatternConfigSet *> &configMap; 343 }; 344 345 /// This class provides utilities for writing a bytecode stream. 346 struct ByteCodeWriter { 347 ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator) 348 : bytecode(bytecode), generator(generator) {} 349 350 /// Append a field to the bytecode. 351 void append(ByteCodeField field) { bytecode.push_back(field); } 352 void append(OpCode opCode) { bytecode.push_back(opCode); } 353 354 /// Append an address to the bytecode. 355 void append(ByteCodeAddr field) { 356 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 357 "unexpected ByteCode address size"); 358 359 ByteCodeField fieldParts[2]; 360 std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr)); 361 bytecode.append({fieldParts[0], fieldParts[1]}); 362 } 363 364 /// Append a single successor to the bytecode, the exact address will need to 365 /// be resolved later. 366 void append(Block *successor) { 367 // Add back a reference to the successor so that the address can be resolved 368 // later. 369 unresolvedSuccessorRefs[successor].push_back(bytecode.size()); 370 append(ByteCodeAddr(0)); 371 } 372 373 /// Append a successor range to the bytecode, the exact address will need to 374 /// be resolved later. 375 void append(SuccessorRange successors) { 376 for (Block *successor : successors) 377 append(successor); 378 } 379 380 /// Append a range of values that will be read as generic PDLValues. 381 void appendPDLValueList(OperandRange values) { 382 bytecode.push_back(values.size()); 383 for (Value value : values) 384 appendPDLValue(value); 385 } 386 387 /// Append a value as a PDLValue. 388 void appendPDLValue(Value value) { 389 appendPDLValueKind(value); 390 append(value); 391 } 392 393 /// Append the PDLValue::Kind of the given value. 394 void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); } 395 396 /// Append the PDLValue::Kind of the given type. 397 void appendPDLValueKind(Type type) { 398 PDLValue::Kind kind = 399 TypeSwitch<Type, PDLValue::Kind>(type) 400 .Case<pdl::AttributeType>( 401 [](Type) { return PDLValue::Kind::Attribute; }) 402 .Case<pdl::OperationType>( 403 [](Type) { return PDLValue::Kind::Operation; }) 404 .Case<pdl::RangeType>([](pdl::RangeType rangeTy) { 405 if (rangeTy.getElementType().isa<pdl::TypeType>()) 406 return PDLValue::Kind::TypeRange; 407 return PDLValue::Kind::ValueRange; 408 }) 409 .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; }) 410 .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; }); 411 bytecode.push_back(static_cast<ByteCodeField>(kind)); 412 } 413 414 /// Append a value that will be stored in a memory slot and not inline within 415 /// the bytecode. 416 template <typename T> 417 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value || 418 std::is_pointer<T>::value> 419 append(T value) { 420 bytecode.push_back(generator.getMemIndex(value)); 421 } 422 423 /// Append a range of values. 424 template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>> 425 std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value> 426 append(T range) { 427 bytecode.push_back(llvm::size(range)); 428 for (auto it : range) 429 append(it); 430 } 431 432 /// Append a variadic number of fields to the bytecode. 433 template <typename FieldTy, typename Field2Ty, typename... FieldTys> 434 void append(FieldTy field, Field2Ty field2, FieldTys... fields) { 435 append(field); 436 append(field2, fields...); 437 } 438 439 /// Appends a value as a pointer, stored inline within the bytecode. 440 template <typename T> 441 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value> 442 appendInline(T value) { 443 constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField); 444 const void *pointer = value.getAsOpaquePointer(); 445 ByteCodeField fieldParts[numParts]; 446 std::memcpy(fieldParts, &pointer, sizeof(const void *)); 447 bytecode.append(fieldParts, fieldParts + numParts); 448 } 449 450 /// Successor references in the bytecode that have yet to be resolved. 451 DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs; 452 453 /// The underlying bytecode buffer. 454 SmallVectorImpl<ByteCodeField> &bytecode; 455 456 /// The main generator producing PDL. 457 Generator &generator; 458 }; 459 460 /// This class represents a live range of PDL Interpreter values, containing 461 /// information about when values are live within a match/rewrite. 462 struct ByteCodeLiveRange { 463 using Set = llvm::IntervalMap<uint64_t, char, 16>; 464 using Allocator = Set::Allocator; 465 466 ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {} 467 468 /// Union this live range with the one provided. 469 void unionWith(const ByteCodeLiveRange &rhs) { 470 for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e; 471 ++it) 472 liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0); 473 } 474 475 /// Returns true if this range overlaps with the one provided. 476 bool overlaps(const ByteCodeLiveRange &rhs) const { 477 return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness) 478 .valid(); 479 } 480 481 /// A map representing the ranges of the match/rewrite that a value is live in 482 /// the interpreter. 483 /// 484 /// We use std::unique_ptr here, because IntervalMap does not provide a 485 /// correct copy or move constructor. We can eliminate the pointer once 486 /// https://reviews.llvm.org/D113240 lands. 487 std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness; 488 489 /// The operation range storage index for this range. 490 std::optional<unsigned> opRangeIndex; 491 492 /// The type range storage index for this range. 493 std::optional<unsigned> typeRangeIndex; 494 495 /// The value range storage index for this range. 496 std::optional<unsigned> valueRangeIndex; 497 }; 498 } // namespace 499 500 void Generator::generate(ModuleOp module) { 501 auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>( 502 pdl_interp::PDLInterpDialect::getMatcherFunctionName()); 503 ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>( 504 pdl_interp::PDLInterpDialect::getRewriterModuleName()); 505 assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module"); 506 507 // Allocate memory indices for the results of operations within the matcher 508 // and rewriters. 509 allocateMemoryIndices(matcherFunc, rewriterModule); 510 511 // Generate code for the rewriter functions. 512 ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this); 513 for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) { 514 rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size()); 515 for (Operation &op : rewriterFunc.getOps()) 516 generate(&op, rewriterByteCodeWriter); 517 } 518 assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() && 519 "unexpected branches in rewriter function"); 520 521 // Generate code for the matcher function. 522 ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this); 523 generate(&matcherFunc.getBody(), matcherByteCodeWriter); 524 525 // Resolve successor references in the matcher. 526 for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) { 527 ByteCodeAddr addr = blockToAddr[it.first]; 528 for (unsigned offsetToFix : it.second) 529 std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr)); 530 } 531 } 532 533 void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc, 534 ModuleOp rewriterModule) { 535 // Rewriters use simplistic allocation scheme that simply assigns an index to 536 // each result. 537 for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) { 538 ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0; 539 auto processRewriterValue = [&](Value val) { 540 valueToMemIndex.try_emplace(val, index++); 541 if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) { 542 Type elementTy = rangeType.getElementType(); 543 if (elementTy.isa<pdl::TypeType>()) 544 valueToRangeIndex.try_emplace(val, typeRangeIndex++); 545 else if (elementTy.isa<pdl::ValueType>()) 546 valueToRangeIndex.try_emplace(val, valueRangeIndex++); 547 } 548 }; 549 550 for (BlockArgument arg : rewriterFunc.getArguments()) 551 processRewriterValue(arg); 552 rewriterFunc.getBody().walk([&](Operation *op) { 553 for (Value result : op->getResults()) 554 processRewriterValue(result); 555 }); 556 if (index > maxValueMemoryIndex) 557 maxValueMemoryIndex = index; 558 if (typeRangeIndex > maxTypeRangeMemoryIndex) 559 maxTypeRangeMemoryIndex = typeRangeIndex; 560 if (valueRangeIndex > maxValueRangeMemoryIndex) 561 maxValueRangeMemoryIndex = valueRangeIndex; 562 } 563 564 // The matcher function uses a more sophisticated numbering that tries to 565 // minimize the number of memory indices assigned. This is done by determining 566 // a live range of the values within the matcher, then the allocation is just 567 // finding the minimal number of overlapping live ranges. This is essentially 568 // a simplified form of register allocation where we don't necessarily have a 569 // limited number of registers, but we still want to minimize the number used. 570 DenseMap<Operation *, unsigned> opToFirstIndex; 571 DenseMap<Operation *, unsigned> opToLastIndex; 572 573 // A custom walk that marks the first and the last index of each operation. 574 // The entry marks the beginning of the liveness range for this operation, 575 // followed by nested operations, followed by the end of the liveness range. 576 unsigned index = 0; 577 llvm::unique_function<void(Operation *)> walk = [&](Operation *op) { 578 opToFirstIndex.try_emplace(op, index++); 579 for (Region ®ion : op->getRegions()) 580 for (Block &block : region.getBlocks()) 581 for (Operation &nested : block) 582 walk(&nested); 583 opToLastIndex.try_emplace(op, index++); 584 }; 585 walk(matcherFunc); 586 587 // Liveness info for each of the defs within the matcher. 588 ByteCodeLiveRange::Allocator allocator; 589 DenseMap<Value, ByteCodeLiveRange> valueDefRanges; 590 591 // Assign the root operation being matched to slot 0. 592 BlockArgument rootOpArg = matcherFunc.getArgument(0); 593 valueToMemIndex[rootOpArg] = 0; 594 595 // Walk each of the blocks, computing the def interval that the value is used. 596 Liveness matcherLiveness(matcherFunc); 597 matcherFunc->walk([&](Block *block) { 598 const LivenessBlockInfo *info = matcherLiveness.getLiveness(block); 599 assert(info && "expected liveness info for block"); 600 auto processValue = [&](Value value, Operation *firstUseOrDef) { 601 // We don't need to process the root op argument, this value is always 602 // assigned to the first memory slot. 603 if (value == rootOpArg) 604 return; 605 606 // Set indices for the range of this block that the value is used. 607 auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first; 608 defRangeIt->second.liveness->insert( 609 opToFirstIndex[firstUseOrDef], 610 opToLastIndex[info->getEndOperation(value, firstUseOrDef)], 611 /*dummyValue*/ 0); 612 613 // Check to see if this value is a range type. 614 if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) { 615 Type eleType = rangeTy.getElementType(); 616 if (eleType.isa<pdl::OperationType>()) 617 defRangeIt->second.opRangeIndex = 0; 618 else if (eleType.isa<pdl::TypeType>()) 619 defRangeIt->second.typeRangeIndex = 0; 620 else if (eleType.isa<pdl::ValueType>()) 621 defRangeIt->second.valueRangeIndex = 0; 622 } 623 }; 624 625 // Process the live-ins of this block. 626 for (Value liveIn : info->in()) { 627 // Only process the value if it has been defined in the current region. 628 // Other values that span across pdl_interp.foreach will be added higher 629 // up. This ensures that the we keep them alive for the entire duration 630 // of the loop. 631 if (liveIn.getParentRegion() == block->getParent()) 632 processValue(liveIn, &block->front()); 633 } 634 635 // Process the block arguments for the entry block (those are not live-in). 636 if (block->isEntryBlock()) { 637 for (Value argument : block->getArguments()) 638 processValue(argument, &block->front()); 639 } 640 641 // Process any new defs within this block. 642 for (Operation &op : *block) 643 for (Value result : op.getResults()) 644 processValue(result, &op); 645 }); 646 647 // Greedily allocate memory slots using the computed def live ranges. 648 std::vector<ByteCodeLiveRange> allocatedIndices; 649 650 // The number of memory indices currently allocated (and its next value). 651 // Recall that the root gets allocated memory index 0. 652 ByteCodeField numIndices = 1; 653 654 // The number of memory ranges of various types (and their next values). 655 ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0; 656 657 for (auto &defIt : valueDefRanges) { 658 ByteCodeField &memIndex = valueToMemIndex[defIt.first]; 659 ByteCodeLiveRange &defRange = defIt.second; 660 661 // Try to allocate to an existing index. 662 for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) { 663 ByteCodeLiveRange &existingRange = existingIndexIt.value(); 664 if (!defRange.overlaps(existingRange)) { 665 existingRange.unionWith(defRange); 666 memIndex = existingIndexIt.index() + 1; 667 668 if (defRange.opRangeIndex) { 669 if (!existingRange.opRangeIndex) 670 existingRange.opRangeIndex = numOpRanges++; 671 valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex; 672 } else if (defRange.typeRangeIndex) { 673 if (!existingRange.typeRangeIndex) 674 existingRange.typeRangeIndex = numTypeRanges++; 675 valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex; 676 } else if (defRange.valueRangeIndex) { 677 if (!existingRange.valueRangeIndex) 678 existingRange.valueRangeIndex = numValueRanges++; 679 valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex; 680 } 681 break; 682 } 683 } 684 685 // If no existing index could be used, add a new one. 686 if (memIndex == 0) { 687 allocatedIndices.emplace_back(allocator); 688 ByteCodeLiveRange &newRange = allocatedIndices.back(); 689 newRange.unionWith(defRange); 690 691 // Allocate an index for op/type/value ranges. 692 if (defRange.opRangeIndex) { 693 newRange.opRangeIndex = numOpRanges; 694 valueToRangeIndex[defIt.first] = numOpRanges++; 695 } else if (defRange.typeRangeIndex) { 696 newRange.typeRangeIndex = numTypeRanges; 697 valueToRangeIndex[defIt.first] = numTypeRanges++; 698 } else if (defRange.valueRangeIndex) { 699 newRange.valueRangeIndex = numValueRanges; 700 valueToRangeIndex[defIt.first] = numValueRanges++; 701 } 702 703 memIndex = allocatedIndices.size(); 704 ++numIndices; 705 } 706 } 707 708 // Print the index usage and ensure that we did not run out of index space. 709 LLVM_DEBUG({ 710 llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices " 711 << "(down from initial " << valueDefRanges.size() << ").\n"; 712 }); 713 assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() && 714 "Ran out of memory for allocated indices"); 715 716 // Update the max number of indices. 717 if (numIndices > maxValueMemoryIndex) 718 maxValueMemoryIndex = numIndices; 719 if (numOpRanges > maxOpRangeMemoryIndex) 720 maxOpRangeMemoryIndex = numOpRanges; 721 if (numTypeRanges > maxTypeRangeMemoryIndex) 722 maxTypeRangeMemoryIndex = numTypeRanges; 723 if (numValueRanges > maxValueRangeMemoryIndex) 724 maxValueRangeMemoryIndex = numValueRanges; 725 } 726 727 void Generator::generate(Region *region, ByteCodeWriter &writer) { 728 llvm::ReversePostOrderTraversal<Region *> rpot(region); 729 for (Block *block : rpot) { 730 // Keep track of where this block begins within the matcher function. 731 blockToAddr.try_emplace(block, matcherByteCode.size()); 732 for (Operation &op : *block) 733 generate(&op, writer); 734 } 735 } 736 737 void Generator::generate(Operation *op, ByteCodeWriter &writer) { 738 LLVM_DEBUG({ 739 // The following list must contain all the operations that do not 740 // produce any bytecode. 741 if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op)) 742 writer.appendInline(op->getLoc()); 743 }); 744 TypeSwitch<Operation *>(op) 745 .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp, 746 pdl_interp::AreEqualOp, pdl_interp::BranchOp, 747 pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp, 748 pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp, 749 pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp, 750 pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp, 751 pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp, 752 pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp, 753 pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp, 754 pdl_interp::ForEachOp, pdl_interp::GetAttributeOp, 755 pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp, 756 pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp, 757 pdl_interp::GetResultOp, pdl_interp::GetResultsOp, 758 pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp, 759 pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp, 760 pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp, 761 pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp, 762 pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp, 763 pdl_interp::SwitchResultCountOp>( 764 [&](auto interpOp) { this->generate(interpOp, writer); }) 765 .Default([](Operation *) { 766 llvm_unreachable("unknown `pdl_interp` operation"); 767 }); 768 } 769 770 void Generator::generate(pdl_interp::ApplyConstraintOp op, 771 ByteCodeWriter &writer) { 772 assert(constraintToMemIndex.count(op.getName()) && 773 "expected index for constraint function"); 774 writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]); 775 writer.appendPDLValueList(op.getArgs()); 776 writer.append(op.getSuccessors()); 777 } 778 void Generator::generate(pdl_interp::ApplyRewriteOp op, 779 ByteCodeWriter &writer) { 780 assert(externalRewriterToMemIndex.count(op.getName()) && 781 "expected index for rewrite function"); 782 writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]); 783 writer.appendPDLValueList(op.getArgs()); 784 785 ResultRange results = op.getResults(); 786 writer.append(ByteCodeField(results.size())); 787 for (Value result : results) { 788 // In debug mode we also record the expected kind of the result, so that we 789 // can provide extra verification of the native rewrite function. 790 #ifndef NDEBUG 791 writer.appendPDLValueKind(result); 792 #endif 793 794 // Range results also need to append the range storage index. 795 if (result.getType().isa<pdl::RangeType>()) 796 writer.append(getRangeStorageIndex(result)); 797 writer.append(result); 798 } 799 } 800 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) { 801 Value lhs = op.getLhs(); 802 if (lhs.getType().isa<pdl::RangeType>()) { 803 writer.append(OpCode::AreRangesEqual); 804 writer.appendPDLValueKind(lhs); 805 writer.append(op.getLhs(), op.getRhs(), op.getSuccessors()); 806 return; 807 } 808 809 writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors()); 810 } 811 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) { 812 writer.append(OpCode::Branch, SuccessorRange(op.getOperation())); 813 } 814 void Generator::generate(pdl_interp::CheckAttributeOp op, 815 ByteCodeWriter &writer) { 816 writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(), 817 op.getSuccessors()); 818 } 819 void Generator::generate(pdl_interp::CheckOperandCountOp op, 820 ByteCodeWriter &writer) { 821 writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(), 822 static_cast<ByteCodeField>(op.getCompareAtLeast()), 823 op.getSuccessors()); 824 } 825 void Generator::generate(pdl_interp::CheckOperationNameOp op, 826 ByteCodeWriter &writer) { 827 writer.append(OpCode::CheckOperationName, op.getInputOp(), 828 OperationName(op.getName(), ctx), op.getSuccessors()); 829 } 830 void Generator::generate(pdl_interp::CheckResultCountOp op, 831 ByteCodeWriter &writer) { 832 writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(), 833 static_cast<ByteCodeField>(op.getCompareAtLeast()), 834 op.getSuccessors()); 835 } 836 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) { 837 writer.append(OpCode::AreEqual, op.getValue(), op.getType(), 838 op.getSuccessors()); 839 } 840 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) { 841 writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(), 842 op.getSuccessors()); 843 } 844 void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) { 845 assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level"); 846 writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1)); 847 } 848 void Generator::generate(pdl_interp::CreateAttributeOp op, 849 ByteCodeWriter &writer) { 850 // Simply repoint the memory index of the result to the constant. 851 getMemIndex(op.getAttribute()) = getMemIndex(op.getValue()); 852 } 853 void Generator::generate(pdl_interp::CreateOperationOp op, 854 ByteCodeWriter &writer) { 855 writer.append(OpCode::CreateOperation, op.getResultOp(), 856 OperationName(op.getName(), ctx)); 857 writer.appendPDLValueList(op.getInputOperands()); 858 859 // Add the attributes. 860 OperandRange attributes = op.getInputAttributes(); 861 writer.append(static_cast<ByteCodeField>(attributes.size())); 862 for (auto it : llvm::zip(op.getInputAttributeNames(), attributes)) 863 writer.append(std::get<0>(it), std::get<1>(it)); 864 865 // Add the result types. If the operation has inferred results, we use a 866 // marker "size" value. Otherwise, we add the list of explicit result types. 867 if (op.getInferredResultTypes()) 868 writer.append(kInferTypesMarker); 869 else 870 writer.appendPDLValueList(op.getInputResultTypes()); 871 } 872 void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) { 873 // Append the correct opcode for the range type. 874 TypeSwitch<Type>(op.getType().getElementType()) 875 .Case( 876 [&](pdl::TypeType) { writer.append(OpCode::CreateDynamicTypeRange); }) 877 .Case([&](pdl::ValueType) { 878 writer.append(OpCode::CreateDynamicValueRange); 879 }); 880 881 writer.append(op.getResult(), getRangeStorageIndex(op.getResult())); 882 writer.appendPDLValueList(op->getOperands()); 883 } 884 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) { 885 // Simply repoint the memory index of the result to the constant. 886 getMemIndex(op.getResult()) = getMemIndex(op.getValue()); 887 } 888 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) { 889 writer.append(OpCode::CreateConstantTypeRange, op.getResult(), 890 getRangeStorageIndex(op.getResult()), op.getValue()); 891 } 892 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) { 893 writer.append(OpCode::EraseOp, op.getInputOp()); 894 } 895 void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) { 896 OpCode opCode = 897 TypeSwitch<Type, OpCode>(op.getResult().getType()) 898 .Case([](pdl::OperationType) { return OpCode::ExtractOp; }) 899 .Case([](pdl::ValueType) { return OpCode::ExtractValue; }) 900 .Case([](pdl::TypeType) { return OpCode::ExtractType; }) 901 .Default([](Type) -> OpCode { 902 llvm_unreachable("unsupported element type"); 903 }); 904 writer.append(opCode, op.getRange(), op.getIndex(), op.getResult()); 905 } 906 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) { 907 writer.append(OpCode::Finalize); 908 } 909 void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) { 910 BlockArgument arg = op.getLoopVariable(); 911 writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg); 912 writer.appendPDLValueKind(arg.getType()); 913 writer.append(curLoopLevel, op.getSuccessor()); 914 ++curLoopLevel; 915 if (curLoopLevel > maxLoopLevel) 916 maxLoopLevel = curLoopLevel; 917 generate(&op.getRegion(), writer); 918 --curLoopLevel; 919 } 920 void Generator::generate(pdl_interp::GetAttributeOp op, 921 ByteCodeWriter &writer) { 922 writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(), 923 op.getNameAttr()); 924 } 925 void Generator::generate(pdl_interp::GetAttributeTypeOp op, 926 ByteCodeWriter &writer) { 927 writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue()); 928 } 929 void Generator::generate(pdl_interp::GetDefiningOpOp op, 930 ByteCodeWriter &writer) { 931 writer.append(OpCode::GetDefiningOp, op.getInputOp()); 932 writer.appendPDLValue(op.getValue()); 933 } 934 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) { 935 uint32_t index = op.getIndex(); 936 if (index < 4) 937 writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index)); 938 else 939 writer.append(OpCode::GetOperandN, index); 940 writer.append(op.getInputOp(), op.getValue()); 941 } 942 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) { 943 Value result = op.getValue(); 944 std::optional<uint32_t> index = op.getIndex(); 945 writer.append(OpCode::GetOperands, 946 index.value_or(std::numeric_limits<uint32_t>::max()), 947 op.getInputOp()); 948 if (result.getType().isa<pdl::RangeType>()) 949 writer.append(getRangeStorageIndex(result)); 950 else 951 writer.append(std::numeric_limits<ByteCodeField>::max()); 952 writer.append(result); 953 } 954 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) { 955 uint32_t index = op.getIndex(); 956 if (index < 4) 957 writer.append(static_cast<OpCode>(OpCode::GetResult0 + index)); 958 else 959 writer.append(OpCode::GetResultN, index); 960 writer.append(op.getInputOp(), op.getValue()); 961 } 962 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) { 963 Value result = op.getValue(); 964 std::optional<uint32_t> index = op.getIndex(); 965 writer.append(OpCode::GetResults, 966 index.value_or(std::numeric_limits<uint32_t>::max()), 967 op.getInputOp()); 968 if (result.getType().isa<pdl::RangeType>()) 969 writer.append(getRangeStorageIndex(result)); 970 else 971 writer.append(std::numeric_limits<ByteCodeField>::max()); 972 writer.append(result); 973 } 974 void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) { 975 Value operations = op.getOperations(); 976 ByteCodeField rangeIndex = getRangeStorageIndex(operations); 977 writer.append(OpCode::GetUsers, operations, rangeIndex); 978 writer.appendPDLValue(op.getValue()); 979 } 980 void Generator::generate(pdl_interp::GetValueTypeOp op, 981 ByteCodeWriter &writer) { 982 if (op.getType().isa<pdl::RangeType>()) { 983 Value result = op.getResult(); 984 writer.append(OpCode::GetValueRangeTypes, result, 985 getRangeStorageIndex(result), op.getValue()); 986 } else { 987 writer.append(OpCode::GetValueType, op.getResult(), op.getValue()); 988 } 989 } 990 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) { 991 writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors()); 992 } 993 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) { 994 ByteCodeField patternIndex = patterns.size(); 995 patterns.emplace_back(PDLByteCodePattern::create( 996 op, configMap.lookup(op), 997 rewriterToAddr[op.getRewriter().getLeafReference().getValue()])); 998 writer.append(OpCode::RecordMatch, patternIndex, 999 SuccessorRange(op.getOperation()), op.getMatchedOps()); 1000 writer.appendPDLValueList(op.getInputs()); 1001 } 1002 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) { 1003 writer.append(OpCode::ReplaceOp, op.getInputOp()); 1004 writer.appendPDLValueList(op.getReplValues()); 1005 } 1006 void Generator::generate(pdl_interp::SwitchAttributeOp op, 1007 ByteCodeWriter &writer) { 1008 writer.append(OpCode::SwitchAttribute, op.getAttribute(), 1009 op.getCaseValuesAttr(), op.getSuccessors()); 1010 } 1011 void Generator::generate(pdl_interp::SwitchOperandCountOp op, 1012 ByteCodeWriter &writer) { 1013 writer.append(OpCode::SwitchOperandCount, op.getInputOp(), 1014 op.getCaseValuesAttr(), op.getSuccessors()); 1015 } 1016 void Generator::generate(pdl_interp::SwitchOperationNameOp op, 1017 ByteCodeWriter &writer) { 1018 auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) { 1019 return OperationName(attr.cast<StringAttr>().getValue(), ctx); 1020 }); 1021 writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases, 1022 op.getSuccessors()); 1023 } 1024 void Generator::generate(pdl_interp::SwitchResultCountOp op, 1025 ByteCodeWriter &writer) { 1026 writer.append(OpCode::SwitchResultCount, op.getInputOp(), 1027 op.getCaseValuesAttr(), op.getSuccessors()); 1028 } 1029 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) { 1030 writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(), 1031 op.getSuccessors()); 1032 } 1033 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) { 1034 writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(), 1035 op.getSuccessors()); 1036 } 1037 1038 //===----------------------------------------------------------------------===// 1039 // PDLByteCode 1040 //===----------------------------------------------------------------------===// 1041 1042 PDLByteCode::PDLByteCode( 1043 ModuleOp module, SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs, 1044 const DenseMap<Operation *, PDLPatternConfigSet *> &configMap, 1045 llvm::StringMap<PDLConstraintFunction> constraintFns, 1046 llvm::StringMap<PDLRewriteFunction> rewriteFns) 1047 : configs(std::move(configs)) { 1048 Generator generator(module.getContext(), uniquedData, matcherByteCode, 1049 rewriterByteCode, patterns, maxValueMemoryIndex, 1050 maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount, 1051 maxLoopLevel, constraintFns, rewriteFns, configMap); 1052 generator.generate(module); 1053 1054 // Initialize the external functions. 1055 for (auto &it : constraintFns) 1056 constraintFunctions.push_back(std::move(it.second)); 1057 for (auto &it : rewriteFns) 1058 rewriteFunctions.push_back(std::move(it.second)); 1059 } 1060 1061 /// Initialize the given state such that it can be used to execute the current 1062 /// bytecode. 1063 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const { 1064 state.memory.resize(maxValueMemoryIndex, nullptr); 1065 state.opRangeMemory.resize(maxOpRangeCount); 1066 state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange()); 1067 state.valueRangeMemory.resize(maxValueRangeCount, ValueRange()); 1068 state.loopIndex.resize(maxLoopLevel, 0); 1069 state.currentPatternBenefits.reserve(patterns.size()); 1070 for (const PDLByteCodePattern &pattern : patterns) 1071 state.currentPatternBenefits.push_back(pattern.getBenefit()); 1072 } 1073 1074 //===----------------------------------------------------------------------===// 1075 // ByteCode Execution 1076 1077 namespace { 1078 /// This class provides support for executing a bytecode stream. 1079 class ByteCodeExecutor { 1080 public: 1081 ByteCodeExecutor( 1082 const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory, 1083 MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory, 1084 MutableArrayRef<TypeRange> typeRangeMemory, 1085 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory, 1086 MutableArrayRef<ValueRange> valueRangeMemory, 1087 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory, 1088 MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory, 1089 ArrayRef<ByteCodeField> code, 1090 ArrayRef<PatternBenefit> currentPatternBenefits, 1091 ArrayRef<PDLByteCodePattern> patterns, 1092 ArrayRef<PDLConstraintFunction> constraintFunctions, 1093 ArrayRef<PDLRewriteFunction> rewriteFunctions) 1094 : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory), 1095 typeRangeMemory(typeRangeMemory), 1096 allocatedTypeRangeMemory(allocatedTypeRangeMemory), 1097 valueRangeMemory(valueRangeMemory), 1098 allocatedValueRangeMemory(allocatedValueRangeMemory), 1099 loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code), 1100 currentPatternBenefits(currentPatternBenefits), patterns(patterns), 1101 constraintFunctions(constraintFunctions), 1102 rewriteFunctions(rewriteFunctions) {} 1103 1104 /// Start executing the code at the current bytecode index. `matches` is an 1105 /// optional field provided when this function is executed in a matching 1106 /// context. 1107 LogicalResult 1108 execute(PatternRewriter &rewriter, 1109 SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr, 1110 std::optional<Location> mainRewriteLoc = {}); 1111 1112 private: 1113 /// Internal implementation of executing each of the bytecode commands. 1114 void executeApplyConstraint(PatternRewriter &rewriter); 1115 LogicalResult executeApplyRewrite(PatternRewriter &rewriter); 1116 void executeAreEqual(); 1117 void executeAreRangesEqual(); 1118 void executeBranch(); 1119 void executeCheckOperandCount(); 1120 void executeCheckOperationName(); 1121 void executeCheckResultCount(); 1122 void executeCheckTypes(); 1123 void executeContinue(); 1124 void executeCreateConstantTypeRange(); 1125 void executeCreateOperation(PatternRewriter &rewriter, 1126 Location mainRewriteLoc); 1127 template <typename T> 1128 void executeDynamicCreateRange(StringRef type); 1129 void executeEraseOp(PatternRewriter &rewriter); 1130 template <typename T, typename Range, PDLValue::Kind kind> 1131 void executeExtract(); 1132 void executeFinalize(); 1133 void executeForEach(); 1134 void executeGetAttribute(); 1135 void executeGetAttributeType(); 1136 void executeGetDefiningOp(); 1137 void executeGetOperand(unsigned index); 1138 void executeGetOperands(); 1139 void executeGetResult(unsigned index); 1140 void executeGetResults(); 1141 void executeGetUsers(); 1142 void executeGetValueType(); 1143 void executeGetValueRangeTypes(); 1144 void executeIsNotNull(); 1145 void executeRecordMatch(PatternRewriter &rewriter, 1146 SmallVectorImpl<PDLByteCode::MatchResult> &matches); 1147 void executeReplaceOp(PatternRewriter &rewriter); 1148 void executeSwitchAttribute(); 1149 void executeSwitchOperandCount(); 1150 void executeSwitchOperationName(); 1151 void executeSwitchResultCount(); 1152 void executeSwitchType(); 1153 void executeSwitchTypes(); 1154 1155 /// Pushes a code iterator to the stack. 1156 void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); } 1157 1158 /// Pops a code iterator from the stack, returning true on success. 1159 void popCodeIt() { 1160 assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack"); 1161 curCodeIt = resumeCodeIt.back(); 1162 resumeCodeIt.pop_back(); 1163 } 1164 1165 /// Return the bytecode iterator at the start of the current op code. 1166 const ByteCodeField *getPrevCodeIt() const { 1167 LLVM_DEBUG({ 1168 // Account for the op code and the Location stored inline. 1169 return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField); 1170 }); 1171 1172 // Account for the op code only. 1173 return curCodeIt - 1; 1174 } 1175 1176 /// Read a value from the bytecode buffer, optionally skipping a certain 1177 /// number of prefix values. These methods always update the buffer to point 1178 /// to the next field after the read data. 1179 template <typename T = ByteCodeField> 1180 T read(size_t skipN = 0) { 1181 curCodeIt += skipN; 1182 return readImpl<T>(); 1183 } 1184 ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); } 1185 1186 /// Read a list of values from the bytecode buffer. 1187 template <typename ValueT, typename T> 1188 void readList(SmallVectorImpl<T> &list) { 1189 list.clear(); 1190 for (unsigned i = 0, e = read(); i != e; ++i) 1191 list.push_back(read<ValueT>()); 1192 } 1193 1194 /// Read a list of values from the bytecode buffer. The values may be encoded 1195 /// either as a single element or a range of elements. 1196 void readList(SmallVectorImpl<Type> &list) { 1197 for (unsigned i = 0, e = read(); i != e; ++i) { 1198 if (read<PDLValue::Kind>() == PDLValue::Kind::Type) { 1199 list.push_back(read<Type>()); 1200 } else { 1201 TypeRange *values = read<TypeRange *>(); 1202 list.append(values->begin(), values->end()); 1203 } 1204 } 1205 } 1206 void readList(SmallVectorImpl<Value> &list) { 1207 for (unsigned i = 0, e = read(); i != e; ++i) { 1208 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1209 list.push_back(read<Value>()); 1210 } else { 1211 ValueRange *values = read<ValueRange *>(); 1212 list.append(values->begin(), values->end()); 1213 } 1214 } 1215 } 1216 1217 /// Read a value stored inline as a pointer. 1218 template <typename T> 1219 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T> 1220 readInline() { 1221 const void *pointer; 1222 std::memcpy(&pointer, curCodeIt, sizeof(const void *)); 1223 curCodeIt += sizeof(const void *) / sizeof(ByteCodeField); 1224 return T::getFromOpaquePointer(pointer); 1225 } 1226 1227 /// Jump to a specific successor based on a predicate value. 1228 void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); } 1229 /// Jump to a specific successor based on a destination index. 1230 void selectJump(size_t destIndex) { 1231 curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)]; 1232 } 1233 1234 /// Handle a switch operation with the provided value and cases. 1235 template <typename T, typename RangeT, typename Comparator = std::equal_to<T>> 1236 void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) { 1237 LLVM_DEBUG({ 1238 llvm::dbgs() << " * Value: " << value << "\n" 1239 << " * Cases: "; 1240 llvm::interleaveComma(cases, llvm::dbgs()); 1241 llvm::dbgs() << "\n"; 1242 }); 1243 1244 // Check to see if the attribute value is within the case list. Jump to 1245 // the correct successor index based on the result. 1246 for (auto it = cases.begin(), e = cases.end(); it != e; ++it) 1247 if (cmp(*it, value)) 1248 return selectJump(size_t((it - cases.begin()) + 1)); 1249 selectJump(size_t(0)); 1250 } 1251 1252 /// Store a pointer to memory. 1253 void storeToMemory(unsigned index, const void *value) { 1254 memory[index] = value; 1255 } 1256 1257 /// Store a value to memory as an opaque pointer. 1258 template <typename T> 1259 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value> 1260 storeToMemory(unsigned index, T value) { 1261 memory[index] = value.getAsOpaquePointer(); 1262 } 1263 1264 /// Internal implementation of reading various data types from the bytecode 1265 /// stream. 1266 template <typename T> 1267 const void *readFromMemory() { 1268 size_t index = *curCodeIt++; 1269 1270 // If this type is an SSA value, it can only be stored in non-const memory. 1271 if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *, 1272 Value>::value || 1273 index < memory.size()) 1274 return memory[index]; 1275 1276 // Otherwise, if this index is not inbounds it is uniqued. 1277 return uniquedMemory[index - memory.size()]; 1278 } 1279 template <typename T> 1280 std::enable_if_t<std::is_pointer<T>::value, T> readImpl() { 1281 return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>())); 1282 } 1283 template <typename T> 1284 std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value, 1285 T> 1286 readImpl() { 1287 return T(T::getFromOpaquePointer(readFromMemory<T>())); 1288 } 1289 template <typename T> 1290 std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() { 1291 switch (read<PDLValue::Kind>()) { 1292 case PDLValue::Kind::Attribute: 1293 return read<Attribute>(); 1294 case PDLValue::Kind::Operation: 1295 return read<Operation *>(); 1296 case PDLValue::Kind::Type: 1297 return read<Type>(); 1298 case PDLValue::Kind::Value: 1299 return read<Value>(); 1300 case PDLValue::Kind::TypeRange: 1301 return read<TypeRange *>(); 1302 case PDLValue::Kind::ValueRange: 1303 return read<ValueRange *>(); 1304 } 1305 llvm_unreachable("unhandled PDLValue::Kind"); 1306 } 1307 template <typename T> 1308 std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() { 1309 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 1310 "unexpected ByteCode address size"); 1311 ByteCodeAddr result; 1312 std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr)); 1313 curCodeIt += 2; 1314 return result; 1315 } 1316 template <typename T> 1317 std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() { 1318 return *curCodeIt++; 1319 } 1320 template <typename T> 1321 std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() { 1322 return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>()); 1323 } 1324 1325 /// Assign the given range to the given memory index. This allocates a new 1326 /// range object if necessary. 1327 template <typename RangeT, typename T = llvm::detail::ValueOfRange<RangeT>> 1328 void assignRangeToMemory(RangeT &&range, unsigned memIndex, 1329 unsigned rangeIndex) { 1330 // Utility functor used to type-erase the assignment. 1331 auto assignRange = [&](auto &allocatedRangeMemory, auto &rangeMemory) { 1332 // If the input range is empty, we don't need to allocate anything. 1333 if (range.empty()) { 1334 rangeMemory[rangeIndex] = {}; 1335 } else { 1336 // Allocate a buffer for this type range. 1337 llvm::OwningArrayRef<T> storage(llvm::size(range)); 1338 llvm::copy(range, storage.begin()); 1339 1340 // Assign this to the range slot and use the range as the value for the 1341 // memory index. 1342 allocatedRangeMemory.emplace_back(std::move(storage)); 1343 rangeMemory[rangeIndex] = allocatedRangeMemory.back(); 1344 } 1345 memory[memIndex] = &rangeMemory[rangeIndex]; 1346 }; 1347 1348 // Dispatch based on the concrete range type. 1349 if constexpr (std::is_same_v<T, Type>) { 1350 return assignRange(allocatedTypeRangeMemory, typeRangeMemory); 1351 } else if constexpr (std::is_same_v<T, Value>) { 1352 return assignRange(allocatedValueRangeMemory, valueRangeMemory); 1353 } else { 1354 llvm_unreachable("unhandled range type"); 1355 } 1356 } 1357 1358 /// The underlying bytecode buffer. 1359 const ByteCodeField *curCodeIt; 1360 1361 /// The stack of bytecode positions at which to resume operation. 1362 SmallVector<const ByteCodeField *> resumeCodeIt; 1363 1364 /// The current execution memory. 1365 MutableArrayRef<const void *> memory; 1366 MutableArrayRef<OwningOpRange> opRangeMemory; 1367 MutableArrayRef<TypeRange> typeRangeMemory; 1368 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory; 1369 MutableArrayRef<ValueRange> valueRangeMemory; 1370 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory; 1371 1372 /// The current loop indices. 1373 MutableArrayRef<unsigned> loopIndex; 1374 1375 /// References to ByteCode data necessary for execution. 1376 ArrayRef<const void *> uniquedMemory; 1377 ArrayRef<ByteCodeField> code; 1378 ArrayRef<PatternBenefit> currentPatternBenefits; 1379 ArrayRef<PDLByteCodePattern> patterns; 1380 ArrayRef<PDLConstraintFunction> constraintFunctions; 1381 ArrayRef<PDLRewriteFunction> rewriteFunctions; 1382 }; 1383 1384 /// This class is an instantiation of the PDLResultList that provides access to 1385 /// the returned results. This API is not on `PDLResultList` to avoid 1386 /// overexposing access to information specific solely to the ByteCode. 1387 class ByteCodeRewriteResultList : public PDLResultList { 1388 public: 1389 ByteCodeRewriteResultList(unsigned maxNumResults) 1390 : PDLResultList(maxNumResults) {} 1391 1392 /// Return the list of PDL results. 1393 MutableArrayRef<PDLValue> getResults() { return results; } 1394 1395 /// Return the type ranges allocated by this list. 1396 MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() { 1397 return allocatedTypeRanges; 1398 } 1399 1400 /// Return the value ranges allocated by this list. 1401 MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() { 1402 return allocatedValueRanges; 1403 } 1404 }; 1405 } // namespace 1406 1407 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { 1408 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); 1409 const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; 1410 SmallVector<PDLValue, 16> args; 1411 readList<PDLValue>(args); 1412 1413 LLVM_DEBUG({ 1414 llvm::dbgs() << " * Arguments: "; 1415 llvm::interleaveComma(args, llvm::dbgs()); 1416 }); 1417 1418 // Invoke the constraint and jump to the proper destination. 1419 selectJump(succeeded(constraintFn(rewriter, args))); 1420 } 1421 1422 LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { 1423 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); 1424 const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; 1425 SmallVector<PDLValue, 16> args; 1426 readList<PDLValue>(args); 1427 1428 LLVM_DEBUG({ 1429 llvm::dbgs() << " * Arguments: "; 1430 llvm::interleaveComma(args, llvm::dbgs()); 1431 }); 1432 1433 // Execute the rewrite function. 1434 ByteCodeField numResults = read(); 1435 ByteCodeRewriteResultList results(numResults); 1436 LogicalResult rewriteResult = rewriteFn(rewriter, results, args); 1437 1438 assert(results.getResults().size() == numResults && 1439 "native PDL rewrite function returned unexpected number of results"); 1440 1441 // Store the results in the bytecode memory. 1442 for (PDLValue &result : results.getResults()) { 1443 LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n"); 1444 1445 // In debug mode we also verify the expected kind of the result. 1446 #ifndef NDEBUG 1447 assert(result.getKind() == read<PDLValue::Kind>() && 1448 "native PDL rewrite function returned an unexpected type of result"); 1449 #endif 1450 1451 // If the result is a range, we need to copy it over to the bytecodes 1452 // range memory. 1453 if (std::optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) { 1454 unsigned rangeIndex = read(); 1455 typeRangeMemory[rangeIndex] = *typeRange; 1456 memory[read()] = &typeRangeMemory[rangeIndex]; 1457 } else if (std::optional<ValueRange> valueRange = 1458 result.dyn_cast<ValueRange>()) { 1459 unsigned rangeIndex = read(); 1460 valueRangeMemory[rangeIndex] = *valueRange; 1461 memory[read()] = &valueRangeMemory[rangeIndex]; 1462 } else { 1463 memory[read()] = result.getAsOpaquePointer(); 1464 } 1465 } 1466 1467 // Copy over any underlying storage allocated for result ranges. 1468 for (auto &it : results.getAllocatedTypeRanges()) 1469 allocatedTypeRangeMemory.push_back(std::move(it)); 1470 for (auto &it : results.getAllocatedValueRanges()) 1471 allocatedValueRangeMemory.push_back(std::move(it)); 1472 1473 // Process the result of the rewrite. 1474 if (failed(rewriteResult)) { 1475 LLVM_DEBUG(llvm::dbgs() << " - Failed"); 1476 return failure(); 1477 } 1478 return success(); 1479 } 1480 1481 void ByteCodeExecutor::executeAreEqual() { 1482 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 1483 const void *lhs = read<const void *>(); 1484 const void *rhs = read<const void *>(); 1485 1486 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n"); 1487 selectJump(lhs == rhs); 1488 } 1489 1490 void ByteCodeExecutor::executeAreRangesEqual() { 1491 LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n"); 1492 PDLValue::Kind valueKind = read<PDLValue::Kind>(); 1493 const void *lhs = read<const void *>(); 1494 const void *rhs = read<const void *>(); 1495 1496 switch (valueKind) { 1497 case PDLValue::Kind::TypeRange: { 1498 const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs); 1499 const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs); 1500 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1501 selectJump(*lhsRange == *rhsRange); 1502 break; 1503 } 1504 case PDLValue::Kind::ValueRange: { 1505 const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs); 1506 const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs); 1507 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1508 selectJump(*lhsRange == *rhsRange); 1509 break; 1510 } 1511 default: 1512 llvm_unreachable("unexpected `AreRangesEqual` value kind"); 1513 } 1514 } 1515 1516 void ByteCodeExecutor::executeBranch() { 1517 LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n"); 1518 curCodeIt = &code[read<ByteCodeAddr>()]; 1519 } 1520 1521 void ByteCodeExecutor::executeCheckOperandCount() { 1522 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n"); 1523 Operation *op = read<Operation *>(); 1524 uint32_t expectedCount = read<uint32_t>(); 1525 bool compareAtLeast = read(); 1526 1527 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n" 1528 << " * Expected: " << expectedCount << "\n" 1529 << " * Comparator: " 1530 << (compareAtLeast ? ">=" : "==") << "\n"); 1531 if (compareAtLeast) 1532 selectJump(op->getNumOperands() >= expectedCount); 1533 else 1534 selectJump(op->getNumOperands() == expectedCount); 1535 } 1536 1537 void ByteCodeExecutor::executeCheckOperationName() { 1538 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n"); 1539 Operation *op = read<Operation *>(); 1540 OperationName expectedName = read<OperationName>(); 1541 1542 LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n" 1543 << " * Expected: \"" << expectedName << "\"\n"); 1544 selectJump(op->getName() == expectedName); 1545 } 1546 1547 void ByteCodeExecutor::executeCheckResultCount() { 1548 LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n"); 1549 Operation *op = read<Operation *>(); 1550 uint32_t expectedCount = read<uint32_t>(); 1551 bool compareAtLeast = read(); 1552 1553 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n" 1554 << " * Expected: " << expectedCount << "\n" 1555 << " * Comparator: " 1556 << (compareAtLeast ? ">=" : "==") << "\n"); 1557 if (compareAtLeast) 1558 selectJump(op->getNumResults() >= expectedCount); 1559 else 1560 selectJump(op->getNumResults() == expectedCount); 1561 } 1562 1563 void ByteCodeExecutor::executeCheckTypes() { 1564 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 1565 TypeRange *lhs = read<TypeRange *>(); 1566 Attribute rhs = read<Attribute>(); 1567 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1568 1569 selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>()); 1570 } 1571 1572 void ByteCodeExecutor::executeContinue() { 1573 ByteCodeField level = read(); 1574 LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n" 1575 << " * Level: " << level << "\n"); 1576 ++loopIndex[level]; 1577 popCodeIt(); 1578 } 1579 1580 void ByteCodeExecutor::executeCreateConstantTypeRange() { 1581 LLVM_DEBUG(llvm::dbgs() << "Executing CreateConstantTypeRange:\n"); 1582 unsigned memIndex = read(); 1583 unsigned rangeIndex = read(); 1584 ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>(); 1585 1586 LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n"); 1587 assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex, 1588 rangeIndex); 1589 } 1590 1591 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, 1592 Location mainRewriteLoc) { 1593 LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n"); 1594 1595 unsigned memIndex = read(); 1596 OperationState state(mainRewriteLoc, read<OperationName>()); 1597 readList(state.operands); 1598 for (unsigned i = 0, e = read(); i != e; ++i) { 1599 StringAttr name = read<StringAttr>(); 1600 if (Attribute attr = read<Attribute>()) 1601 state.addAttribute(name, attr); 1602 } 1603 1604 // Read in the result types. If the "size" is the sentinel value, this 1605 // indicates that the result types should be inferred. 1606 unsigned numResults = read(); 1607 if (numResults == kInferTypesMarker) { 1608 InferTypeOpInterface::Concept *inferInterface = 1609 state.name.getInterface<InferTypeOpInterface>(); 1610 assert(inferInterface && 1611 "expected operation to provide InferTypeOpInterface"); 1612 1613 // TODO: Handle failure. 1614 if (failed(inferInterface->inferReturnTypes( 1615 state.getContext(), state.location, state.operands, 1616 state.attributes.getDictionary(state.getContext()), state.regions, 1617 state.types))) 1618 return; 1619 } else { 1620 // Otherwise, this is a fixed number of results. 1621 for (unsigned i = 0; i != numResults; ++i) { 1622 if (read<PDLValue::Kind>() == PDLValue::Kind::Type) { 1623 state.types.push_back(read<Type>()); 1624 } else { 1625 TypeRange *resultTypes = read<TypeRange *>(); 1626 state.types.append(resultTypes->begin(), resultTypes->end()); 1627 } 1628 } 1629 } 1630 1631 Operation *resultOp = rewriter.create(state); 1632 memory[memIndex] = resultOp; 1633 1634 LLVM_DEBUG({ 1635 llvm::dbgs() << " * Attributes: " 1636 << state.attributes.getDictionary(state.getContext()) 1637 << "\n * Operands: "; 1638 llvm::interleaveComma(state.operands, llvm::dbgs()); 1639 llvm::dbgs() << "\n * Result Types: "; 1640 llvm::interleaveComma(state.types, llvm::dbgs()); 1641 llvm::dbgs() << "\n * Result: " << *resultOp << "\n"; 1642 }); 1643 } 1644 1645 template <typename T> 1646 void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) { 1647 LLVM_DEBUG(llvm::dbgs() << "Executing CreateDynamic" << type << "Range:\n"); 1648 unsigned memIndex = read(); 1649 unsigned rangeIndex = read(); 1650 SmallVector<T> values; 1651 readList(values); 1652 1653 LLVM_DEBUG({ 1654 llvm::dbgs() << "\n * " << type << "s: "; 1655 llvm::interleaveComma(values, llvm::dbgs()); 1656 llvm::dbgs() << "\n"; 1657 }); 1658 1659 assignRangeToMemory(values, memIndex, rangeIndex); 1660 } 1661 1662 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) { 1663 LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n"); 1664 Operation *op = read<Operation *>(); 1665 1666 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1667 rewriter.eraseOp(op); 1668 } 1669 1670 template <typename T, typename Range, PDLValue::Kind kind> 1671 void ByteCodeExecutor::executeExtract() { 1672 LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n"); 1673 Range *range = read<Range *>(); 1674 unsigned index = read<uint32_t>(); 1675 unsigned memIndex = read(); 1676 1677 if (!range) { 1678 memory[memIndex] = nullptr; 1679 return; 1680 } 1681 1682 T result = index < range->size() ? (*range)[index] : T(); 1683 LLVM_DEBUG(llvm::dbgs() << " * " << kind << "s(" << range->size() << ")\n" 1684 << " * Index: " << index << "\n" 1685 << " * Result: " << result << "\n"); 1686 storeToMemory(memIndex, result); 1687 } 1688 1689 void ByteCodeExecutor::executeFinalize() { 1690 LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n"); 1691 } 1692 1693 void ByteCodeExecutor::executeForEach() { 1694 LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n"); 1695 const ByteCodeField *prevCodeIt = getPrevCodeIt(); 1696 unsigned rangeIndex = read(); 1697 unsigned memIndex = read(); 1698 const void *value = nullptr; 1699 1700 switch (read<PDLValue::Kind>()) { 1701 case PDLValue::Kind::Operation: { 1702 unsigned &index = loopIndex[read()]; 1703 ArrayRef<Operation *> array = opRangeMemory[rangeIndex]; 1704 assert(index <= array.size() && "iterated past the end"); 1705 if (index < array.size()) { 1706 LLVM_DEBUG(llvm::dbgs() << " * Result: " << array[index] << "\n"); 1707 value = array[index]; 1708 break; 1709 } 1710 1711 LLVM_DEBUG(llvm::dbgs() << " * Done\n"); 1712 index = 0; 1713 selectJump(size_t(0)); 1714 return; 1715 } 1716 default: 1717 llvm_unreachable("unexpected `ForEach` value kind"); 1718 } 1719 1720 // Store the iterate value and the stack address. 1721 memory[memIndex] = value; 1722 pushCodeIt(prevCodeIt); 1723 1724 // Skip over the successor (we will enter the body of the loop). 1725 read<ByteCodeAddr>(); 1726 } 1727 1728 void ByteCodeExecutor::executeGetAttribute() { 1729 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n"); 1730 unsigned memIndex = read(); 1731 Operation *op = read<Operation *>(); 1732 StringAttr attrName = read<StringAttr>(); 1733 Attribute attr = op->getAttr(attrName); 1734 1735 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1736 << " * Attribute: " << attrName << "\n" 1737 << " * Result: " << attr << "\n"); 1738 memory[memIndex] = attr.getAsOpaquePointer(); 1739 } 1740 1741 void ByteCodeExecutor::executeGetAttributeType() { 1742 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n"); 1743 unsigned memIndex = read(); 1744 Attribute attr = read<Attribute>(); 1745 Type type; 1746 if (auto typedAttr = attr.dyn_cast<TypedAttr>()) 1747 type = typedAttr.getType(); 1748 1749 LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n" 1750 << " * Result: " << type << "\n"); 1751 memory[memIndex] = type.getAsOpaquePointer(); 1752 } 1753 1754 void ByteCodeExecutor::executeGetDefiningOp() { 1755 LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n"); 1756 unsigned memIndex = read(); 1757 Operation *op = nullptr; 1758 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1759 Value value = read<Value>(); 1760 if (value) 1761 op = value.getDefiningOp(); 1762 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1763 } else { 1764 ValueRange *values = read<ValueRange *>(); 1765 if (values && !values->empty()) { 1766 op = values->front().getDefiningOp(); 1767 } 1768 LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n"); 1769 } 1770 1771 LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n"); 1772 memory[memIndex] = op; 1773 } 1774 1775 void ByteCodeExecutor::executeGetOperand(unsigned index) { 1776 Operation *op = read<Operation *>(); 1777 unsigned memIndex = read(); 1778 Value operand = 1779 index < op->getNumOperands() ? op->getOperand(index) : Value(); 1780 1781 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1782 << " * Index: " << index << "\n" 1783 << " * Result: " << operand << "\n"); 1784 memory[memIndex] = operand.getAsOpaquePointer(); 1785 } 1786 1787 /// This function is the internal implementation of `GetResults` and 1788 /// `GetOperands` that provides support for extracting a value range from the 1789 /// given operation. 1790 template <template <typename> class AttrSizedSegmentsT, typename RangeT> 1791 static void * 1792 executeGetOperandsResults(RangeT values, Operation *op, unsigned index, 1793 ByteCodeField rangeIndex, StringRef attrSizedSegments, 1794 MutableArrayRef<ValueRange> valueRangeMemory) { 1795 // Check for the sentinel index that signals that all values should be 1796 // returned. 1797 if (index == std::numeric_limits<uint32_t>::max()) { 1798 LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n"); 1799 // `values` is already the full value range. 1800 1801 // Otherwise, check to see if this operation uses AttrSizedSegments. 1802 } else if (op->hasTrait<AttrSizedSegmentsT>()) { 1803 LLVM_DEBUG(llvm::dbgs() 1804 << " * Extracting values from `" << attrSizedSegments << "`\n"); 1805 1806 auto segmentAttr = op->getAttrOfType<DenseI32ArrayAttr>(attrSizedSegments); 1807 if (!segmentAttr || segmentAttr.asArrayRef().size() <= index) 1808 return nullptr; 1809 1810 ArrayRef<int32_t> segments = segmentAttr; 1811 unsigned startIndex = 1812 std::accumulate(segments.begin(), segments.begin() + index, 0); 1813 values = values.slice(startIndex, *std::next(segments.begin(), index)); 1814 1815 LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex << ", " 1816 << *std::next(segments.begin(), index) << "]\n"); 1817 1818 // Otherwise, assume this is the last operand group of the operation. 1819 // FIXME: We currently don't support operations with 1820 // SameVariadicOperandSize/SameVariadicResultSize here given that we don't 1821 // have a way to detect it's presence. 1822 } else if (values.size() >= index) { 1823 LLVM_DEBUG(llvm::dbgs() 1824 << " * Treating values as trailing variadic range\n"); 1825 values = values.drop_front(index); 1826 1827 // If we couldn't detect a way to compute the values, bail out. 1828 } else { 1829 return nullptr; 1830 } 1831 1832 // If the range index is valid, we are returning a range. 1833 if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) { 1834 valueRangeMemory[rangeIndex] = values; 1835 return &valueRangeMemory[rangeIndex]; 1836 } 1837 1838 // If a range index wasn't provided, the range is required to be non-variadic. 1839 return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer(); 1840 } 1841 1842 void ByteCodeExecutor::executeGetOperands() { 1843 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n"); 1844 unsigned index = read<uint32_t>(); 1845 Operation *op = read<Operation *>(); 1846 ByteCodeField rangeIndex = read(); 1847 1848 void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>( 1849 op->getOperands(), op, index, rangeIndex, "operand_segment_sizes", 1850 valueRangeMemory); 1851 if (!result) 1852 LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n"); 1853 memory[read()] = result; 1854 } 1855 1856 void ByteCodeExecutor::executeGetResult(unsigned index) { 1857 Operation *op = read<Operation *>(); 1858 unsigned memIndex = read(); 1859 OpResult result = 1860 index < op->getNumResults() ? op->getResult(index) : OpResult(); 1861 1862 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1863 << " * Index: " << index << "\n" 1864 << " * Result: " << result << "\n"); 1865 memory[memIndex] = result.getAsOpaquePointer(); 1866 } 1867 1868 void ByteCodeExecutor::executeGetResults() { 1869 LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n"); 1870 unsigned index = read<uint32_t>(); 1871 Operation *op = read<Operation *>(); 1872 ByteCodeField rangeIndex = read(); 1873 1874 void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>( 1875 op->getResults(), op, index, rangeIndex, "result_segment_sizes", 1876 valueRangeMemory); 1877 if (!result) 1878 LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n"); 1879 memory[read()] = result; 1880 } 1881 1882 void ByteCodeExecutor::executeGetUsers() { 1883 LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n"); 1884 unsigned memIndex = read(); 1885 unsigned rangeIndex = read(); 1886 OwningOpRange &range = opRangeMemory[rangeIndex]; 1887 memory[memIndex] = ⦥ 1888 1889 range = OwningOpRange(); 1890 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1891 // Read the value. 1892 Value value = read<Value>(); 1893 if (!value) 1894 return; 1895 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1896 1897 // Extract the users of a single value. 1898 range = OwningOpRange(std::distance(value.user_begin(), value.user_end())); 1899 llvm::copy(value.getUsers(), range.begin()); 1900 } else { 1901 // Read a range of values. 1902 ValueRange *values = read<ValueRange *>(); 1903 if (!values) 1904 return; 1905 LLVM_DEBUG({ 1906 llvm::dbgs() << " * Values (" << values->size() << "): "; 1907 llvm::interleaveComma(*values, llvm::dbgs()); 1908 llvm::dbgs() << "\n"; 1909 }); 1910 1911 // Extract all the users of a range of values. 1912 SmallVector<Operation *> users; 1913 for (Value value : *values) 1914 users.append(value.user_begin(), value.user_end()); 1915 range = OwningOpRange(users.size()); 1916 llvm::copy(users, range.begin()); 1917 } 1918 1919 LLVM_DEBUG(llvm::dbgs() << " * Result: " << range.size() << " operations\n"); 1920 } 1921 1922 void ByteCodeExecutor::executeGetValueType() { 1923 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n"); 1924 unsigned memIndex = read(); 1925 Value value = read<Value>(); 1926 Type type = value ? value.getType() : Type(); 1927 1928 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" 1929 << " * Result: " << type << "\n"); 1930 memory[memIndex] = type.getAsOpaquePointer(); 1931 } 1932 1933 void ByteCodeExecutor::executeGetValueRangeTypes() { 1934 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n"); 1935 unsigned memIndex = read(); 1936 unsigned rangeIndex = read(); 1937 ValueRange *values = read<ValueRange *>(); 1938 if (!values) { 1939 LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n"); 1940 memory[memIndex] = nullptr; 1941 return; 1942 } 1943 1944 LLVM_DEBUG({ 1945 llvm::dbgs() << " * Values (" << values->size() << "): "; 1946 llvm::interleaveComma(*values, llvm::dbgs()); 1947 llvm::dbgs() << "\n * Result: "; 1948 llvm::interleaveComma(values->getType(), llvm::dbgs()); 1949 llvm::dbgs() << "\n"; 1950 }); 1951 typeRangeMemory[rangeIndex] = values->getType(); 1952 memory[memIndex] = &typeRangeMemory[rangeIndex]; 1953 } 1954 1955 void ByteCodeExecutor::executeIsNotNull() { 1956 LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n"); 1957 const void *value = read<const void *>(); 1958 1959 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1960 selectJump(value != nullptr); 1961 } 1962 1963 void ByteCodeExecutor::executeRecordMatch( 1964 PatternRewriter &rewriter, 1965 SmallVectorImpl<PDLByteCode::MatchResult> &matches) { 1966 LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n"); 1967 unsigned patternIndex = read(); 1968 PatternBenefit benefit = currentPatternBenefits[patternIndex]; 1969 const ByteCodeField *dest = &code[read<ByteCodeAddr>()]; 1970 1971 // If the benefit of the pattern is impossible, skip the processing of the 1972 // rest of the pattern. 1973 if (benefit.isImpossibleToMatch()) { 1974 LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n"); 1975 curCodeIt = dest; 1976 return; 1977 } 1978 1979 // Create a fused location containing the locations of each of the 1980 // operations used in the match. This will be used as the location for 1981 // created operations during the rewrite that don't already have an 1982 // explicit location set. 1983 unsigned numMatchLocs = read(); 1984 SmallVector<Location, 4> matchLocs; 1985 matchLocs.reserve(numMatchLocs); 1986 for (unsigned i = 0; i != numMatchLocs; ++i) 1987 matchLocs.push_back(read<Operation *>()->getLoc()); 1988 Location matchLoc = rewriter.getFusedLoc(matchLocs); 1989 1990 LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n" 1991 << " * Location: " << matchLoc << "\n"); 1992 matches.emplace_back(matchLoc, patterns[patternIndex], benefit); 1993 PDLByteCode::MatchResult &match = matches.back(); 1994 1995 // Record all of the inputs to the match. If any of the inputs are ranges, we 1996 // will also need to remap the range pointer to memory stored in the match 1997 // state. 1998 unsigned numInputs = read(); 1999 match.values.reserve(numInputs); 2000 match.typeRangeValues.reserve(numInputs); 2001 match.valueRangeValues.reserve(numInputs); 2002 for (unsigned i = 0; i < numInputs; ++i) { 2003 switch (read<PDLValue::Kind>()) { 2004 case PDLValue::Kind::TypeRange: 2005 match.typeRangeValues.push_back(*read<TypeRange *>()); 2006 match.values.push_back(&match.typeRangeValues.back()); 2007 break; 2008 case PDLValue::Kind::ValueRange: 2009 match.valueRangeValues.push_back(*read<ValueRange *>()); 2010 match.values.push_back(&match.valueRangeValues.back()); 2011 break; 2012 default: 2013 match.values.push_back(read<const void *>()); 2014 break; 2015 } 2016 } 2017 curCodeIt = dest; 2018 } 2019 2020 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) { 2021 LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n"); 2022 Operation *op = read<Operation *>(); 2023 SmallVector<Value, 16> args; 2024 readList(args); 2025 2026 LLVM_DEBUG({ 2027 llvm::dbgs() << " * Operation: " << *op << "\n" 2028 << " * Values: "; 2029 llvm::interleaveComma(args, llvm::dbgs()); 2030 llvm::dbgs() << "\n"; 2031 }); 2032 rewriter.replaceOp(op, args); 2033 } 2034 2035 void ByteCodeExecutor::executeSwitchAttribute() { 2036 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n"); 2037 Attribute value = read<Attribute>(); 2038 ArrayAttr cases = read<ArrayAttr>(); 2039 handleSwitch(value, cases); 2040 } 2041 2042 void ByteCodeExecutor::executeSwitchOperandCount() { 2043 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n"); 2044 Operation *op = read<Operation *>(); 2045 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 2046 2047 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 2048 handleSwitch(op->getNumOperands(), cases); 2049 } 2050 2051 void ByteCodeExecutor::executeSwitchOperationName() { 2052 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n"); 2053 OperationName value = read<Operation *>()->getName(); 2054 size_t caseCount = read(); 2055 2056 // The operation names are stored in-line, so to print them out for 2057 // debugging purposes we need to read the array before executing the 2058 // switch so that we can display all of the possible values. 2059 LLVM_DEBUG({ 2060 const ByteCodeField *prevCodeIt = curCodeIt; 2061 llvm::dbgs() << " * Value: " << value << "\n" 2062 << " * Cases: "; 2063 llvm::interleaveComma( 2064 llvm::map_range(llvm::seq<size_t>(0, caseCount), 2065 [&](size_t) { return read<OperationName>(); }), 2066 llvm::dbgs()); 2067 llvm::dbgs() << "\n"; 2068 curCodeIt = prevCodeIt; 2069 }); 2070 2071 // Try to find the switch value within any of the cases. 2072 for (size_t i = 0; i != caseCount; ++i) { 2073 if (read<OperationName>() == value) { 2074 curCodeIt += (caseCount - i - 1); 2075 return selectJump(i + 1); 2076 } 2077 } 2078 selectJump(size_t(0)); 2079 } 2080 2081 void ByteCodeExecutor::executeSwitchResultCount() { 2082 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n"); 2083 Operation *op = read<Operation *>(); 2084 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 2085 2086 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 2087 handleSwitch(op->getNumResults(), cases); 2088 } 2089 2090 void ByteCodeExecutor::executeSwitchType() { 2091 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n"); 2092 Type value = read<Type>(); 2093 auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>(); 2094 handleSwitch(value, cases); 2095 } 2096 2097 void ByteCodeExecutor::executeSwitchTypes() { 2098 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n"); 2099 TypeRange *value = read<TypeRange *>(); 2100 auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>(); 2101 if (!value) { 2102 LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n"); 2103 return selectJump(size_t(0)); 2104 } 2105 handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) { 2106 return value == caseValue.getAsValueRange<TypeAttr>(); 2107 }); 2108 } 2109 2110 LogicalResult 2111 ByteCodeExecutor::execute(PatternRewriter &rewriter, 2112 SmallVectorImpl<PDLByteCode::MatchResult> *matches, 2113 std::optional<Location> mainRewriteLoc) { 2114 while (true) { 2115 // Print the location of the operation being executed. 2116 LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n"); 2117 2118 OpCode opCode = static_cast<OpCode>(read()); 2119 switch (opCode) { 2120 case ApplyConstraint: 2121 executeApplyConstraint(rewriter); 2122 break; 2123 case ApplyRewrite: 2124 if (failed(executeApplyRewrite(rewriter))) 2125 return failure(); 2126 break; 2127 case AreEqual: 2128 executeAreEqual(); 2129 break; 2130 case AreRangesEqual: 2131 executeAreRangesEqual(); 2132 break; 2133 case Branch: 2134 executeBranch(); 2135 break; 2136 case CheckOperandCount: 2137 executeCheckOperandCount(); 2138 break; 2139 case CheckOperationName: 2140 executeCheckOperationName(); 2141 break; 2142 case CheckResultCount: 2143 executeCheckResultCount(); 2144 break; 2145 case CheckTypes: 2146 executeCheckTypes(); 2147 break; 2148 case Continue: 2149 executeContinue(); 2150 break; 2151 case CreateConstantTypeRange: 2152 executeCreateConstantTypeRange(); 2153 break; 2154 case CreateOperation: 2155 executeCreateOperation(rewriter, *mainRewriteLoc); 2156 break; 2157 case CreateDynamicTypeRange: 2158 executeDynamicCreateRange<Type>("Type"); 2159 break; 2160 case CreateDynamicValueRange: 2161 executeDynamicCreateRange<Value>("Value"); 2162 break; 2163 case EraseOp: 2164 executeEraseOp(rewriter); 2165 break; 2166 case ExtractOp: 2167 executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>(); 2168 break; 2169 case ExtractType: 2170 executeExtract<Type, TypeRange, PDLValue::Kind::Type>(); 2171 break; 2172 case ExtractValue: 2173 executeExtract<Value, ValueRange, PDLValue::Kind::Value>(); 2174 break; 2175 case Finalize: 2176 executeFinalize(); 2177 LLVM_DEBUG(llvm::dbgs() << "\n"); 2178 return success(); 2179 case ForEach: 2180 executeForEach(); 2181 break; 2182 case GetAttribute: 2183 executeGetAttribute(); 2184 break; 2185 case GetAttributeType: 2186 executeGetAttributeType(); 2187 break; 2188 case GetDefiningOp: 2189 executeGetDefiningOp(); 2190 break; 2191 case GetOperand0: 2192 case GetOperand1: 2193 case GetOperand2: 2194 case GetOperand3: { 2195 unsigned index = opCode - GetOperand0; 2196 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n"); 2197 executeGetOperand(index); 2198 break; 2199 } 2200 case GetOperandN: 2201 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n"); 2202 executeGetOperand(read<uint32_t>()); 2203 break; 2204 case GetOperands: 2205 executeGetOperands(); 2206 break; 2207 case GetResult0: 2208 case GetResult1: 2209 case GetResult2: 2210 case GetResult3: { 2211 unsigned index = opCode - GetResult0; 2212 LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n"); 2213 executeGetResult(index); 2214 break; 2215 } 2216 case GetResultN: 2217 LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n"); 2218 executeGetResult(read<uint32_t>()); 2219 break; 2220 case GetResults: 2221 executeGetResults(); 2222 break; 2223 case GetUsers: 2224 executeGetUsers(); 2225 break; 2226 case GetValueType: 2227 executeGetValueType(); 2228 break; 2229 case GetValueRangeTypes: 2230 executeGetValueRangeTypes(); 2231 break; 2232 case IsNotNull: 2233 executeIsNotNull(); 2234 break; 2235 case RecordMatch: 2236 assert(matches && 2237 "expected matches to be provided when executing the matcher"); 2238 executeRecordMatch(rewriter, *matches); 2239 break; 2240 case ReplaceOp: 2241 executeReplaceOp(rewriter); 2242 break; 2243 case SwitchAttribute: 2244 executeSwitchAttribute(); 2245 break; 2246 case SwitchOperandCount: 2247 executeSwitchOperandCount(); 2248 break; 2249 case SwitchOperationName: 2250 executeSwitchOperationName(); 2251 break; 2252 case SwitchResultCount: 2253 executeSwitchResultCount(); 2254 break; 2255 case SwitchType: 2256 executeSwitchType(); 2257 break; 2258 case SwitchTypes: 2259 executeSwitchTypes(); 2260 break; 2261 } 2262 LLVM_DEBUG(llvm::dbgs() << "\n"); 2263 } 2264 } 2265 2266 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, 2267 SmallVectorImpl<MatchResult> &matches, 2268 PDLByteCodeMutableState &state) const { 2269 // The first memory slot is always the root operation. 2270 state.memory[0] = op; 2271 2272 // The matcher function always starts at code address 0. 2273 ByteCodeExecutor executor( 2274 matcherByteCode.data(), state.memory, state.opRangeMemory, 2275 state.typeRangeMemory, state.allocatedTypeRangeMemory, 2276 state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex, 2277 uniquedData, matcherByteCode, state.currentPatternBenefits, patterns, 2278 constraintFunctions, rewriteFunctions); 2279 LogicalResult executeResult = executor.execute(rewriter, &matches); 2280 (void)executeResult; 2281 assert(succeeded(executeResult) && "unexpected matcher execution failure"); 2282 2283 // Order the found matches by benefit. 2284 std::stable_sort(matches.begin(), matches.end(), 2285 [](const MatchResult &lhs, const MatchResult &rhs) { 2286 return lhs.benefit > rhs.benefit; 2287 }); 2288 } 2289 2290 LogicalResult PDLByteCode::rewrite(PatternRewriter &rewriter, 2291 const MatchResult &match, 2292 PDLByteCodeMutableState &state) const { 2293 auto *configSet = match.pattern->getConfigSet(); 2294 if (configSet) 2295 configSet->notifyRewriteBegin(rewriter); 2296 2297 // The arguments of the rewrite function are stored at the start of the 2298 // memory buffer. 2299 llvm::copy(match.values, state.memory.begin()); 2300 2301 ByteCodeExecutor executor( 2302 &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory, 2303 state.opRangeMemory, state.typeRangeMemory, 2304 state.allocatedTypeRangeMemory, state.valueRangeMemory, 2305 state.allocatedValueRangeMemory, state.loopIndex, uniquedData, 2306 rewriterByteCode, state.currentPatternBenefits, patterns, 2307 constraintFunctions, rewriteFunctions); 2308 LogicalResult result = 2309 executor.execute(rewriter, /*matches=*/nullptr, match.location); 2310 2311 if (configSet) 2312 configSet->notifyRewriteEnd(rewriter); 2313 2314 // If the rewrite failed, check if the pattern rewriter can recover. If it 2315 // can, we can signal to the pattern applicator to keep trying patterns. If it 2316 // doesn't, we need to bail. Bailing here should be fine, given that we have 2317 // no means to propagate such a failure to the user, and it also indicates a 2318 // bug in the user code (i.e. failable rewrites should not be used with 2319 // pattern rewriters that don't support it). 2320 if (failed(result) && !rewriter.canRecoverFromRewriteFailure()) { 2321 LLVM_DEBUG(llvm::dbgs() << " and rollback is not supported - aborting"); 2322 llvm::report_fatal_error( 2323 "Native PDL Rewrite failed, but the pattern " 2324 "rewriter doesn't support recovery. Failable pattern rewrites should " 2325 "not be used with pattern rewriters that do not support them."); 2326 } 2327 return result; 2328 } 2329