1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements MLIR to byte-code generation and the interpreter. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "ByteCode.h" 14 #include "mlir/Analysis/Liveness.h" 15 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 16 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/RegionGraphTraits.h" 19 #include "llvm/ADT/IntervalMap.h" 20 #include "llvm/ADT/PostOrderIterator.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 #include "llvm/Support/Debug.h" 23 #include "llvm/Support/Format.h" 24 #include "llvm/Support/FormatVariadic.h" 25 #include <numeric> 26 27 #define DEBUG_TYPE "pdl-bytecode" 28 29 using namespace mlir; 30 using namespace mlir::detail; 31 32 //===----------------------------------------------------------------------===// 33 // PDLByteCodePattern 34 //===----------------------------------------------------------------------===// 35 36 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp, 37 PDLPatternConfigSet *configSet, 38 ByteCodeAddr rewriterAddr) { 39 PatternBenefit benefit = matchOp.getBenefit(); 40 MLIRContext *ctx = matchOp.getContext(); 41 42 // Collect the set of generated operations. 43 SmallVector<StringRef, 8> generatedOps; 44 if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr()) 45 generatedOps = 46 llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>()); 47 48 // Check to see if this is pattern matches a specific operation type. 
49 if (Optional<StringRef> rootKind = matchOp.getRootKind()) 50 return PDLByteCodePattern(rewriterAddr, configSet, *rootKind, benefit, ctx, 51 generatedOps); 52 return PDLByteCodePattern(rewriterAddr, configSet, MatchAnyOpTypeTag(), 53 benefit, ctx, generatedOps); 54 } 55 56 //===----------------------------------------------------------------------===// 57 // PDLByteCodeMutableState 58 //===----------------------------------------------------------------------===// 59 60 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds 61 /// to the position of the pattern within the range returned by 62 /// `PDLByteCode::getPatterns`. 63 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex, 64 PatternBenefit benefit) { 65 currentPatternBenefits[patternIndex] = benefit; 66 } 67 68 /// Cleanup any allocated state after a full match/rewrite has been completed. 69 /// This method should be called irregardless of whether the match+rewrite was a 70 /// success or not. 71 void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() { 72 allocatedTypeRangeMemory.clear(); 73 allocatedValueRangeMemory.clear(); 74 } 75 76 //===----------------------------------------------------------------------===// 77 // Bytecode OpCodes 78 //===----------------------------------------------------------------------===// 79 80 namespace { 81 enum OpCode : ByteCodeField { 82 /// Apply an externally registered constraint. 83 ApplyConstraint, 84 /// Apply an externally registered rewrite. 85 ApplyRewrite, 86 /// Check if two generic values are equal. 87 AreEqual, 88 /// Check if two ranges are equal. 89 AreRangesEqual, 90 /// Unconditional branch. 91 Branch, 92 /// Compare the operand count of an operation with a constant. 93 CheckOperandCount, 94 /// Compare the name of an operation with a constant. 95 CheckOperationName, 96 /// Compare the result count of an operation with a constant. 
97 CheckResultCount, 98 /// Compare a range of types to a constant range of types. 99 CheckTypes, 100 /// Continue to the next iteration of a loop. 101 Continue, 102 /// Create an operation. 103 CreateOperation, 104 /// Create a range of types. 105 CreateTypes, 106 /// Erase an operation. 107 EraseOp, 108 /// Extract the op from a range at the specified index. 109 ExtractOp, 110 /// Extract the type from a range at the specified index. 111 ExtractType, 112 /// Extract the value from a range at the specified index. 113 ExtractValue, 114 /// Terminate a matcher or rewrite sequence. 115 Finalize, 116 /// Iterate over a range of values. 117 ForEach, 118 /// Get a specific attribute of an operation. 119 GetAttribute, 120 /// Get the type of an attribute. 121 GetAttributeType, 122 /// Get the defining operation of a value. 123 GetDefiningOp, 124 /// Get a specific operand of an operation. 125 GetOperand0, 126 GetOperand1, 127 GetOperand2, 128 GetOperand3, 129 GetOperandN, 130 /// Get a specific operand group of an operation. 131 GetOperands, 132 /// Get a specific result of an operation. 133 GetResult0, 134 GetResult1, 135 GetResult2, 136 GetResult3, 137 GetResultN, 138 /// Get a specific result group of an operation. 139 GetResults, 140 /// Get the users of a value or a range of values. 141 GetUsers, 142 /// Get the type of a value. 143 GetValueType, 144 /// Get the types of a value range. 145 GetValueRangeTypes, 146 /// Check if a generic value is not null. 147 IsNotNull, 148 /// Record a successful pattern match. 149 RecordMatch, 150 /// Replace an operation. 151 ReplaceOp, 152 /// Compare an attribute with a set of constants. 153 SwitchAttribute, 154 /// Compare the operand count of an operation with a set of constants. 155 SwitchOperandCount, 156 /// Compare the name of an operation with a set of constants. 157 SwitchOperationName, 158 /// Compare the result count of an operation with a set of constants. 
159 SwitchResultCount, 160 /// Compare a type with a set of constants. 161 SwitchType, 162 /// Compare a range of types with a set of constants. 163 SwitchTypes, 164 }; 165 } // namespace 166 167 /// A marker used to indicate if an operation should infer types. 168 static constexpr ByteCodeField kInferTypesMarker = 169 std::numeric_limits<ByteCodeField>::max(); 170 171 //===----------------------------------------------------------------------===// 172 // ByteCode Generation 173 //===----------------------------------------------------------------------===// 174 175 //===----------------------------------------------------------------------===// 176 // Generator 177 178 namespace { 179 struct ByteCodeLiveRange; 180 struct ByteCodeWriter; 181 182 /// Check if the given class `T` can be converted to an opaque pointer. 183 template <typename T, typename... Args> 184 using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer()); 185 186 /// This class represents the main generator for the pattern bytecode. 
187 class Generator { 188 public: 189 Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData, 190 SmallVectorImpl<ByteCodeField> &matcherByteCode, 191 SmallVectorImpl<ByteCodeField> &rewriterByteCode, 192 SmallVectorImpl<PDLByteCodePattern> &patterns, 193 ByteCodeField &maxValueMemoryIndex, 194 ByteCodeField &maxOpRangeMemoryIndex, 195 ByteCodeField &maxTypeRangeMemoryIndex, 196 ByteCodeField &maxValueRangeMemoryIndex, 197 ByteCodeField &maxLoopLevel, 198 llvm::StringMap<PDLConstraintFunction> &constraintFns, 199 llvm::StringMap<PDLRewriteFunction> &rewriteFns, 200 const DenseMap<Operation *, PDLPatternConfigSet *> &configMap) 201 : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), 202 rewriterByteCode(rewriterByteCode), patterns(patterns), 203 maxValueMemoryIndex(maxValueMemoryIndex), 204 maxOpRangeMemoryIndex(maxOpRangeMemoryIndex), 205 maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex), 206 maxValueRangeMemoryIndex(maxValueRangeMemoryIndex), 207 maxLoopLevel(maxLoopLevel), configMap(configMap) { 208 for (const auto &it : llvm::enumerate(constraintFns)) 209 constraintToMemIndex.try_emplace(it.value().first(), it.index()); 210 for (const auto &it : llvm::enumerate(rewriteFns)) 211 externalRewriterToMemIndex.try_emplace(it.value().first(), it.index()); 212 } 213 214 /// Generate the bytecode for the given PDL interpreter module. 215 void generate(ModuleOp module); 216 217 /// Return the memory index to use for the given value. 218 ByteCodeField &getMemIndex(Value value) { 219 assert(valueToMemIndex.count(value) && 220 "expected memory index to be assigned"); 221 return valueToMemIndex[value]; 222 } 223 224 /// Return the range memory index used to store the given range value. 
225 ByteCodeField &getRangeStorageIndex(Value value) { 226 assert(valueToRangeIndex.count(value) && 227 "expected range index to be assigned"); 228 return valueToRangeIndex[value]; 229 } 230 231 /// Return an index to use when referring to the given data that is uniqued in 232 /// the MLIR context. 233 template <typename T> 234 std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &> 235 getMemIndex(T val) { 236 const void *opaqueVal = val.getAsOpaquePointer(); 237 238 // Get or insert a reference to this value. 239 auto it = uniquedDataToMemIndex.try_emplace( 240 opaqueVal, maxValueMemoryIndex + uniquedData.size()); 241 if (it.second) 242 uniquedData.push_back(opaqueVal); 243 return it.first->second; 244 } 245 246 private: 247 /// Allocate memory indices for the results of operations within the matcher 248 /// and rewriters. 249 void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc, 250 ModuleOp rewriterModule); 251 252 /// Generate the bytecode for the given operation. 
253 void generate(Region *region, ByteCodeWriter &writer); 254 void generate(Operation *op, ByteCodeWriter &writer); 255 void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer); 256 void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer); 257 void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer); 258 void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer); 259 void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer); 260 void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer); 261 void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer); 262 void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer); 263 void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer); 264 void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer); 265 void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer); 266 void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); 267 void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); 268 void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); 269 void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer); 270 void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer); 271 void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer); 272 void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer); 273 void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer); 274 void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer); 275 void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer); 276 void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer); 277 void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer); 278 void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer); 279 void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer); 280 void generate(pdl_interp::GetResultsOp op, 
ByteCodeWriter &writer); 281 void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer); 282 void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer); 283 void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer); 284 void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer); 285 void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer); 286 void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer); 287 void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer); 288 void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer); 289 void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer); 290 void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer); 291 void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer); 292 293 /// Mapping from value to its corresponding memory index. 294 DenseMap<Value, ByteCodeField> valueToMemIndex; 295 296 /// Mapping from a range value to its corresponding range storage index. 297 DenseMap<Value, ByteCodeField> valueToRangeIndex; 298 299 /// Mapping from the name of an externally registered rewrite to its index in 300 /// the bytecode registry. 301 llvm::StringMap<ByteCodeField> externalRewriterToMemIndex; 302 303 /// Mapping from the name of an externally registered constraint to its index 304 /// in the bytecode registry. 305 llvm::StringMap<ByteCodeField> constraintToMemIndex; 306 307 /// Mapping from rewriter function name to the bytecode address of the 308 /// rewriter function in byte. 309 llvm::StringMap<ByteCodeAddr> rewriterToAddr; 310 311 /// Mapping from a uniqued storage object to its memory index within 312 /// `uniquedData`. 313 DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex; 314 315 /// The current level of the foreach loop. 316 ByteCodeField curLoopLevel = 0; 317 318 /// The current MLIR context. 319 MLIRContext *ctx; 320 321 /// Mapping from block to its address. 
322 DenseMap<Block *, ByteCodeAddr> blockToAddr; 323 324 /// Data of the ByteCode class to be populated. 325 std::vector<const void *> &uniquedData; 326 SmallVectorImpl<ByteCodeField> &matcherByteCode; 327 SmallVectorImpl<ByteCodeField> &rewriterByteCode; 328 SmallVectorImpl<PDLByteCodePattern> &patterns; 329 ByteCodeField &maxValueMemoryIndex; 330 ByteCodeField &maxOpRangeMemoryIndex; 331 ByteCodeField &maxTypeRangeMemoryIndex; 332 ByteCodeField &maxValueRangeMemoryIndex; 333 ByteCodeField &maxLoopLevel; 334 335 /// A map of pattern configurations. 336 const DenseMap<Operation *, PDLPatternConfigSet *> &configMap; 337 }; 338 339 /// This class provides utilities for writing a bytecode stream. 340 struct ByteCodeWriter { 341 ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator) 342 : bytecode(bytecode), generator(generator) {} 343 344 /// Append a field to the bytecode. 345 void append(ByteCodeField field) { bytecode.push_back(field); } 346 void append(OpCode opCode) { bytecode.push_back(opCode); } 347 348 /// Append an address to the bytecode. 349 void append(ByteCodeAddr field) { 350 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 351 "unexpected ByteCode address size"); 352 353 ByteCodeField fieldParts[2]; 354 std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr)); 355 bytecode.append({fieldParts[0], fieldParts[1]}); 356 } 357 358 /// Append a single successor to the bytecode, the exact address will need to 359 /// be resolved later. 360 void append(Block *successor) { 361 // Add back a reference to the successor so that the address can be resolved 362 // later. 363 unresolvedSuccessorRefs[successor].push_back(bytecode.size()); 364 append(ByteCodeAddr(0)); 365 } 366 367 /// Append a successor range to the bytecode, the exact address will need to 368 /// be resolved later. 
369 void append(SuccessorRange successors) { 370 for (Block *successor : successors) 371 append(successor); 372 } 373 374 /// Append a range of values that will be read as generic PDLValues. 375 void appendPDLValueList(OperandRange values) { 376 bytecode.push_back(values.size()); 377 for (Value value : values) 378 appendPDLValue(value); 379 } 380 381 /// Append a value as a PDLValue. 382 void appendPDLValue(Value value) { 383 appendPDLValueKind(value); 384 append(value); 385 } 386 387 /// Append the PDLValue::Kind of the given value. 388 void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); } 389 390 /// Append the PDLValue::Kind of the given type. 391 void appendPDLValueKind(Type type) { 392 PDLValue::Kind kind = 393 TypeSwitch<Type, PDLValue::Kind>(type) 394 .Case<pdl::AttributeType>( 395 [](Type) { return PDLValue::Kind::Attribute; }) 396 .Case<pdl::OperationType>( 397 [](Type) { return PDLValue::Kind::Operation; }) 398 .Case<pdl::RangeType>([](pdl::RangeType rangeTy) { 399 if (rangeTy.getElementType().isa<pdl::TypeType>()) 400 return PDLValue::Kind::TypeRange; 401 return PDLValue::Kind::ValueRange; 402 }) 403 .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; }) 404 .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; }); 405 bytecode.push_back(static_cast<ByteCodeField>(kind)); 406 } 407 408 /// Append a value that will be stored in a memory slot and not inline within 409 /// the bytecode. 410 template <typename T> 411 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value || 412 std::is_pointer<T>::value> 413 append(T value) { 414 bytecode.push_back(generator.getMemIndex(value)); 415 } 416 417 /// Append a range of values. 
418 template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>> 419 std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value> 420 append(T range) { 421 bytecode.push_back(llvm::size(range)); 422 for (auto it : range) 423 append(it); 424 } 425 426 /// Append a variadic number of fields to the bytecode. 427 template <typename FieldTy, typename Field2Ty, typename... FieldTys> 428 void append(FieldTy field, Field2Ty field2, FieldTys... fields) { 429 append(field); 430 append(field2, fields...); 431 } 432 433 /// Appends a value as a pointer, stored inline within the bytecode. 434 template <typename T> 435 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value> 436 appendInline(T value) { 437 constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField); 438 const void *pointer = value.getAsOpaquePointer(); 439 ByteCodeField fieldParts[numParts]; 440 std::memcpy(fieldParts, &pointer, sizeof(const void *)); 441 bytecode.append(fieldParts, fieldParts + numParts); 442 } 443 444 /// Successor references in the bytecode that have yet to be resolved. 445 DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs; 446 447 /// The underlying bytecode buffer. 448 SmallVectorImpl<ByteCodeField> &bytecode; 449 450 /// The main generator producing PDL. 451 Generator &generator; 452 }; 453 454 /// This class represents a live range of PDL Interpreter values, containing 455 /// information about when values are live within a match/rewrite. 456 struct ByteCodeLiveRange { 457 using Set = llvm::IntervalMap<uint64_t, char, 16>; 458 using Allocator = Set::Allocator; 459 460 ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {} 461 462 /// Union this live range with the one provided. 
463 void unionWith(const ByteCodeLiveRange &rhs) { 464 for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e; 465 ++it) 466 liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0); 467 } 468 469 /// Returns true if this range overlaps with the one provided. 470 bool overlaps(const ByteCodeLiveRange &rhs) const { 471 return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness) 472 .valid(); 473 } 474 475 /// A map representing the ranges of the match/rewrite that a value is live in 476 /// the interpreter. 477 /// 478 /// We use std::unique_ptr here, because IntervalMap does not provide a 479 /// correct copy or move constructor. We can eliminate the pointer once 480 /// https://reviews.llvm.org/D113240 lands. 481 std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness; 482 483 /// The operation range storage index for this range. 484 Optional<unsigned> opRangeIndex; 485 486 /// The type range storage index for this range. 487 Optional<unsigned> typeRangeIndex; 488 489 /// The value range storage index for this range. 490 Optional<unsigned> valueRangeIndex; 491 }; 492 } // namespace 493 494 void Generator::generate(ModuleOp module) { 495 auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>( 496 pdl_interp::PDLInterpDialect::getMatcherFunctionName()); 497 ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>( 498 pdl_interp::PDLInterpDialect::getRewriterModuleName()); 499 assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module"); 500 501 // Allocate memory indices for the results of operations within the matcher 502 // and rewriters. 503 allocateMemoryIndices(matcherFunc, rewriterModule); 504 505 // Generate code for the rewriter functions. 
506 ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this); 507 for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) { 508 rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size()); 509 for (Operation &op : rewriterFunc.getOps()) 510 generate(&op, rewriterByteCodeWriter); 511 } 512 assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() && 513 "unexpected branches in rewriter function"); 514 515 // Generate code for the matcher function. 516 ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this); 517 generate(&matcherFunc.getBody(), matcherByteCodeWriter); 518 519 // Resolve successor references in the matcher. 520 for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) { 521 ByteCodeAddr addr = blockToAddr[it.first]; 522 for (unsigned offsetToFix : it.second) 523 std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr)); 524 } 525 } 526 527 void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc, 528 ModuleOp rewriterModule) { 529 // Rewriters use simplistic allocation scheme that simply assigns an index to 530 // each result. 
531 for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) { 532 ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0; 533 auto processRewriterValue = [&](Value val) { 534 valueToMemIndex.try_emplace(val, index++); 535 if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) { 536 Type elementTy = rangeType.getElementType(); 537 if (elementTy.isa<pdl::TypeType>()) 538 valueToRangeIndex.try_emplace(val, typeRangeIndex++); 539 else if (elementTy.isa<pdl::ValueType>()) 540 valueToRangeIndex.try_emplace(val, valueRangeIndex++); 541 } 542 }; 543 544 for (BlockArgument arg : rewriterFunc.getArguments()) 545 processRewriterValue(arg); 546 rewriterFunc.getBody().walk([&](Operation *op) { 547 for (Value result : op->getResults()) 548 processRewriterValue(result); 549 }); 550 if (index > maxValueMemoryIndex) 551 maxValueMemoryIndex = index; 552 if (typeRangeIndex > maxTypeRangeMemoryIndex) 553 maxTypeRangeMemoryIndex = typeRangeIndex; 554 if (valueRangeIndex > maxValueRangeMemoryIndex) 555 maxValueRangeMemoryIndex = valueRangeIndex; 556 } 557 558 // The matcher function uses a more sophisticated numbering that tries to 559 // minimize the number of memory indices assigned. This is done by determining 560 // a live range of the values within the matcher, then the allocation is just 561 // finding the minimal number of overlapping live ranges. This is essentially 562 // a simplified form of register allocation where we don't necessarily have a 563 // limited number of registers, but we still want to minimize the number used. 564 DenseMap<Operation *, unsigned> opToFirstIndex; 565 DenseMap<Operation *, unsigned> opToLastIndex; 566 567 // A custom walk that marks the first and the last index of each operation. 568 // The entry marks the beginning of the liveness range for this operation, 569 // followed by nested operations, followed by the end of the liveness range. 
570 unsigned index = 0; 571 llvm::unique_function<void(Operation *)> walk = [&](Operation *op) { 572 opToFirstIndex.try_emplace(op, index++); 573 for (Region ®ion : op->getRegions()) 574 for (Block &block : region.getBlocks()) 575 for (Operation &nested : block) 576 walk(&nested); 577 opToLastIndex.try_emplace(op, index++); 578 }; 579 walk(matcherFunc); 580 581 // Liveness info for each of the defs within the matcher. 582 ByteCodeLiveRange::Allocator allocator; 583 DenseMap<Value, ByteCodeLiveRange> valueDefRanges; 584 585 // Assign the root operation being matched to slot 0. 586 BlockArgument rootOpArg = matcherFunc.getArgument(0); 587 valueToMemIndex[rootOpArg] = 0; 588 589 // Walk each of the blocks, computing the def interval that the value is used. 590 Liveness matcherLiveness(matcherFunc); 591 matcherFunc->walk([&](Block *block) { 592 const LivenessBlockInfo *info = matcherLiveness.getLiveness(block); 593 assert(info && "expected liveness info for block"); 594 auto processValue = [&](Value value, Operation *firstUseOrDef) { 595 // We don't need to process the root op argument, this value is always 596 // assigned to the first memory slot. 597 if (value == rootOpArg) 598 return; 599 600 // Set indices for the range of this block that the value is used. 601 auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first; 602 defRangeIt->second.liveness->insert( 603 opToFirstIndex[firstUseOrDef], 604 opToLastIndex[info->getEndOperation(value, firstUseOrDef)], 605 /*dummyValue*/ 0); 606 607 // Check to see if this value is a range type. 
608 if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) { 609 Type eleType = rangeTy.getElementType(); 610 if (eleType.isa<pdl::OperationType>()) 611 defRangeIt->second.opRangeIndex = 0; 612 else if (eleType.isa<pdl::TypeType>()) 613 defRangeIt->second.typeRangeIndex = 0; 614 else if (eleType.isa<pdl::ValueType>()) 615 defRangeIt->second.valueRangeIndex = 0; 616 } 617 }; 618 619 // Process the live-ins of this block. 620 for (Value liveIn : info->in()) { 621 // Only process the value if it has been defined in the current region. 622 // Other values that span across pdl_interp.foreach will be added higher 623 // up. This ensures that the we keep them alive for the entire duration 624 // of the loop. 625 if (liveIn.getParentRegion() == block->getParent()) 626 processValue(liveIn, &block->front()); 627 } 628 629 // Process the block arguments for the entry block (those are not live-in). 630 if (block->isEntryBlock()) { 631 for (Value argument : block->getArguments()) 632 processValue(argument, &block->front()); 633 } 634 635 // Process any new defs within this block. 636 for (Operation &op : *block) 637 for (Value result : op.getResults()) 638 processValue(result, &op); 639 }); 640 641 // Greedily allocate memory slots using the computed def live ranges. 642 std::vector<ByteCodeLiveRange> allocatedIndices; 643 644 // The number of memory indices currently allocated (and its next value). 645 // Recall that the root gets allocated memory index 0. 646 ByteCodeField numIndices = 1; 647 648 // The number of memory ranges of various types (and their next values). 649 ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0; 650 651 for (auto &defIt : valueDefRanges) { 652 ByteCodeField &memIndex = valueToMemIndex[defIt.first]; 653 ByteCodeLiveRange &defRange = defIt.second; 654 655 // Try to allocate to an existing index. 
656 for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) { 657 ByteCodeLiveRange &existingRange = existingIndexIt.value(); 658 if (!defRange.overlaps(existingRange)) { 659 existingRange.unionWith(defRange); 660 memIndex = existingIndexIt.index() + 1; 661 662 if (defRange.opRangeIndex) { 663 if (!existingRange.opRangeIndex) 664 existingRange.opRangeIndex = numOpRanges++; 665 valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex; 666 } else if (defRange.typeRangeIndex) { 667 if (!existingRange.typeRangeIndex) 668 existingRange.typeRangeIndex = numTypeRanges++; 669 valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex; 670 } else if (defRange.valueRangeIndex) { 671 if (!existingRange.valueRangeIndex) 672 existingRange.valueRangeIndex = numValueRanges++; 673 valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex; 674 } 675 break; 676 } 677 } 678 679 // If no existing index could be used, add a new one. 680 if (memIndex == 0) { 681 allocatedIndices.emplace_back(allocator); 682 ByteCodeLiveRange &newRange = allocatedIndices.back(); 683 newRange.unionWith(defRange); 684 685 // Allocate an index for op/type/value ranges. 686 if (defRange.opRangeIndex) { 687 newRange.opRangeIndex = numOpRanges; 688 valueToRangeIndex[defIt.first] = numOpRanges++; 689 } else if (defRange.typeRangeIndex) { 690 newRange.typeRangeIndex = numTypeRanges; 691 valueToRangeIndex[defIt.first] = numTypeRanges++; 692 } else if (defRange.valueRangeIndex) { 693 newRange.valueRangeIndex = numValueRanges; 694 valueToRangeIndex[defIt.first] = numValueRanges++; 695 } 696 697 memIndex = allocatedIndices.size(); 698 ++numIndices; 699 } 700 } 701 702 // Print the index usage and ensure that we did not run out of index space. 
// Report how many memory indices were allocated compared with the initial
// number of value definition ranges (only when LLVM_DEBUG output is enabled).
  LLVM_DEBUG({
    llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices "
                 << "(down from initial " << valueDefRanges.size() << ").\n";
  });
  // Every allocated index must be representable as a ByteCodeField.
  assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
         "Ran out of memory for allocated indices");

  // Update the max number of indices. Each of these is a high-water mark: it
  // is only raised, never lowered.
  if (numIndices > maxValueMemoryIndex)
    maxValueMemoryIndex = numIndices;
  if (numOpRanges > maxOpRangeMemoryIndex)
    maxOpRangeMemoryIndex = numOpRanges;
  if (numTypeRanges > maxTypeRangeMemoryIndex)
    maxTypeRangeMemoryIndex = numTypeRanges;
  if (numValueRanges > maxValueRangeMemoryIndex)
    maxValueRangeMemoryIndex = numValueRanges;
}

/// Emit bytecode for every block of `region`, visiting the blocks in reverse
/// post-order. The starting offset of each block (taken from the current size
/// of `matcherByteCode`) is recorded in `blockToAddr`.
void Generator::generate(Region *region, ByteCodeWriter &writer) {
  llvm::ReversePostOrderTraversal<Region *> rpot(region);
  for (Block *block : rpot) {
    // Keep track of where this block begins within the matcher function.
    blockToAddr.try_emplace(block, matcherByteCode.size());
    for (Operation &op : *block)
      generate(&op, writer);
  }
}

/// Dispatch `op` to the overload for its concrete `pdl_interp` operation.
/// When LLVM_DEBUG output is enabled, the operation's Location is appended
/// inline before its bytecode (except for operations that emit no bytecode).
/// Any operation not listed in the switch is a fatal error.
void Generator::generate(Operation *op, ByteCodeWriter &writer) {
  LLVM_DEBUG({
    // The following list must contain all the operations that do not
    // produce any bytecode.
    if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
      writer.appendInline(op->getLoc());
  });
  TypeSwitch<Operation *>(op)
      .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
            pdl_interp::AreEqualOp, pdl_interp::BranchOp,
            pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
            pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
            pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
            pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
            pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp,
            pdl_interp::CreateTypesOp, pdl_interp::EraseOp,
            pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
            pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
            pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
            pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
            pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
            pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
            pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
            pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
            pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
            pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
            pdl_interp::SwitchResultCountOp>(
          [&](auto interpOp) { this->generate(interpOp, writer); })
      .Default([](Operation *) {
        llvm_unreachable("unknown `pdl_interp` operation");
      });
}

/// Emits: ApplyConstraint, constraint function index, the argument list
/// (via appendPDLValueList), then the successors.
void Generator::generate(pdl_interp::ApplyConstraintOp op,
                         ByteCodeWriter &writer) {
  assert(constraintToMemIndex.count(op.getName()) &&
         "expected index for constraint function");
  writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
  writer.appendPDLValueList(op.getArgs());
  writer.append(op.getSuccessors());
}
/// Emits: ApplyRewrite, rewrite function index, the argument list, the number
/// of results, then for each result: its PDL value kind (only when NDEBUG is
/// not defined), its range storage index (range-typed results only), and the
/// result value itself.
void Generator::generate(pdl_interp::ApplyRewriteOp op,
                         ByteCodeWriter &writer) {
  assert(externalRewriterToMemIndex.count(op.getName()) &&
         "expected index for rewrite function");
  writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
  writer.appendPDLValueList(op.getArgs());

  ResultRange results = op.getResults();
  writer.append(ByteCodeField(results.size()));
  for (Value result : results) {
    // In debug mode we also record the expected kind of the result, so that we
    // can provide extra verification of the native rewrite function.
#ifndef NDEBUG
    writer.appendPDLValueKind(result);
#endif

    // Range results also need to append the range storage index.
    if (result.getType().isa<pdl::RangeType>())
      writer.append(getRangeStorageIndex(result));
    writer.append(result);
  }
}
/// Range-typed operands are compared with AreRangesEqual, which additionally
/// records the PDL value kind of the lhs; everything else uses AreEqual.
void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
  Value lhs = op.getLhs();
  if (lhs.getType().isa<pdl::RangeType>()) {
    writer.append(OpCode::AreRangesEqual);
    writer.appendPDLValueKind(lhs);
    writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
    return;
  }

  writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
}
/// Unconditional branch: emits the operation's successors.
void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
}
/// Lowered to a generic AreEqual between the attribute and the constant.
void Generator::generate(pdl_interp::CheckAttributeOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
                op.getSuccessors());
}
/// The compare-at-least flag is emitted as a ByteCodeField (0 or 1).
void Generator::generate(pdl_interp::CheckOperandCountOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
                static_cast<ByteCodeField>(op.getCompareAtLeast()),
                op.getSuccessors());
}
/// The expected name is emitted as an OperationName built in `ctx`.
void Generator::generate(pdl_interp::CheckOperationNameOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::CheckOperationName, op.getInputOp(),
                OperationName(op.getName(), ctx), op.getSuccessors());
}
/// The compare-at-least flag is emitted as a ByteCodeField (0 or 1).
void Generator::generate(pdl_interp::CheckResultCountOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
                static_cast<ByteCodeField>(op.getCompareAtLeast()),
                op.getSuccessors());
}
/// Lowered to a generic AreEqual between the value and the expected type.
void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
                op.getSuccessors());
}
/// Emits: CheckTypes, the type range, the expected types, the successors.
void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
                op.getSuccessors());
}
/// Emits the level of the innermost enclosing loop (`curLoopLevel - 1`).
void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
  assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
  writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
}
/// Emits no bytecode.
void Generator::generate(pdl_interp::CreateAttributeOp op,
                         ByteCodeWriter &writer) {
  // Simply repoint the memory index of the result to the constant.
  getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
}
/// Emits: CreateOperation, result slot, OperationName, operand list,
/// attribute count followed by (name, value) pairs, then either
/// `kInferTypesMarker` or the explicit result type list.
void Generator::generate(pdl_interp::CreateOperationOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::CreateOperation, op.getResultOp(),
                OperationName(op.getName(), ctx));
  writer.appendPDLValueList(op.getInputOperands());

  // Add the attributes.
  OperandRange attributes = op.getInputAttributes();
  writer.append(static_cast<ByteCodeField>(attributes.size()));
  for (auto it : llvm::zip(op.getInputAttributeNames(), attributes))
    writer.append(std::get<0>(it), std::get<1>(it));

  // Add the result types. If the operation has inferred results, we use a
  // marker "size" value. Otherwise, we add the list of explicit result types.
  if (op.getInferredResultTypes())
    writer.append(kInferTypesMarker);
  else
    writer.appendPDLValueList(op.getInputResultTypes());
}
/// Emits no bytecode.
void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
  // Simply repoint the memory index of the result to the constant.
  getMemIndex(op.getResult()) = getMemIndex(op.getValue());
}
/// Emits: CreateTypes, result, its range storage index, the type list.
void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::CreateTypes, op.getResult(),
                getRangeStorageIndex(op.getResult()), op.getValue());
}
void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::EraseOp, op.getInputOp());
}
/// The opcode is chosen from the element type of the result (operation,
/// value or type); any other element type is a fatal error.
void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
  OpCode opCode =
      TypeSwitch<Type, OpCode>(op.getResult().getType())
          .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
          .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
          .Case([](pdl::TypeType) { return OpCode::ExtractType; })
          .Default([](Type) -> OpCode {
            llvm_unreachable("unsupported element type");
          });
  writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
}
void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::Finalize);
}
/// Emits the loop header (range storage index, loop variable, its value
/// kind, the current loop level and the exit successor), then the body region
/// with `curLoopLevel` incremented. `maxLoopLevel` records the deepest nesting.
void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
  BlockArgument arg = op.getLoopVariable();
  writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
  writer.appendPDLValueKind(arg.getType());
  writer.append(curLoopLevel, op.getSuccessor());
  ++curLoopLevel;
  if (curLoopLevel > maxLoopLevel)
    maxLoopLevel = curLoopLevel;
  generate(&op.getRegion(), writer);
  --curLoopLevel;
}
void Generator::generate(pdl_interp::GetAttributeOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
                op.getNameAttr());
}
void Generator::generate(pdl_interp::GetAttributeTypeOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
}
void Generator::generate(pdl_interp::GetDefiningOpOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::GetDefiningOp, op.getInputOp());
  writer.appendPDLValue(op.getValue());
}
/// Indices 0-3 use a dedicated opcode (GetOperand0 + index); larger indices
/// use GetOperandN followed by the explicit index.
void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
  uint32_t index = op.getIndex();
  if (index < 4)
    writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
  else
    writer.append(OpCode::GetOperandN, index);
  writer.append(op.getInputOp(), op.getValue());
}
/// A missing index is encoded as UINT32_MAX. A non-range result gets the
/// maximum ByteCodeField in place of a range storage index.
void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
  Value result = op.getValue();
  Optional<uint32_t> index = op.getIndex();
  writer.append(OpCode::GetOperands,
                index.value_or(std::numeric_limits<uint32_t>::max()),
                op.getInputOp());
  if (result.getType().isa<pdl::RangeType>())
    writer.append(getRangeStorageIndex(result));
  else
    writer.append(std::numeric_limits<ByteCodeField>::max());
  writer.append(result);
}
/// Indices 0-3 use a dedicated opcode (GetResult0 + index); larger indices
/// use GetResultN followed by the explicit index.
void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
  uint32_t index = op.getIndex();
  if (index < 4)
    writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
  else
    writer.append(OpCode::GetResultN, index);
  writer.append(op.getInputOp(), op.getValue());
}
/// A missing index is encoded as UINT32_MAX. A non-range result gets the
/// maximum ByteCodeField in place of a range storage index.
void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
  Value result = op.getValue();
  Optional<uint32_t> index = op.getIndex();
  writer.append(OpCode::GetResults,
                index.value_or(std::numeric_limits<uint32_t>::max()),
                op.getInputOp());
  if (result.getType().isa<pdl::RangeType>())
    writer.append(getRangeStorageIndex(result));
  else
    writer.append(std::numeric_limits<ByteCodeField>::max());
  writer.append(result);
}
/// Emits: GetUsers, the result range, its range storage index, then the
/// queried value (via appendPDLValue).
void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
  Value operations = op.getOperations();
  ByteCodeField rangeIndex = getRangeStorageIndex(operations);
  writer.append(OpCode::GetUsers, operations, rangeIndex);
  writer.appendPDLValue(op.getValue());
}
/// Range-typed results use GetValueRangeTypes (with a range storage index);
/// single values use GetValueType.
void Generator::generate(pdl_interp::GetValueTypeOp op,
                         ByteCodeWriter &writer) {
  if (op.getType().isa<pdl::RangeType>()) {
    Value result = op.getResult();
    writer.append(OpCode::GetValueRangeTypes, result,
                  getRangeStorageIndex(result), op.getValue());
  } else {
    writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
  }
}
void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
}
/// Registers a new PDLByteCodePattern (its index is `patterns.size()` before
/// insertion) pointing at the rewriter's address, then emits that index, the
/// successors, the matched ops and the rewriter inputs.
void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
  ByteCodeField patternIndex = patterns.size();
  patterns.emplace_back(PDLByteCodePattern::create(
      op, configMap.lookup(op),
      rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
  writer.append(OpCode::RecordMatch, patternIndex,
                SuccessorRange(op.getOperation()), op.getMatchedOps());
  writer.appendPDLValueList(op.getInputs());
}
void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::ReplaceOp, op.getInputOp());
  writer.appendPDLValueList(op.getReplValues());
}
void Generator::generate(pdl_interp::SwitchAttributeOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchAttribute, op.getAttribute(),
                op.getCaseValuesAttr(), op.getSuccessors());
}
void Generator::generate(pdl_interp::SwitchOperandCountOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
                op.getCaseValuesAttr(), op.getSuccessors());
}
/// The case names (string attributes) are mapped to OperationNames in `ctx`.
void Generator::generate(pdl_interp::SwitchOperationNameOp op,
                         ByteCodeWriter &writer) {
  auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) {
    return OperationName(attr.cast<StringAttr>().getValue(), ctx);
  });
  writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
                op.getSuccessors());
}
void Generator::generate(pdl_interp::SwitchResultCountOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchResultCount, op.getInputOp(),
                op.getCaseValuesAttr(), op.getSuccessors());
}
void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
                op.getSuccessors());
}
void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
                op.getSuccessors());
}

//===----------------------------------------------------------------------===//
// PDLByteCode
//===----------------------------------------------------------------------===//

/// Generate the bytecode for `module`, filling the matcher/rewriter bytecode,
/// uniqued data, patterns and memory high-water marks, and take ownership of
/// the external constraint and rewrite functions.
PDLByteCode::PDLByteCode(
    ModuleOp module, SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
    const DenseMap<Operation *, PDLPatternConfigSet *> &configMap,
    llvm::StringMap<PDLConstraintFunction> constraintFns,
    llvm::StringMap<PDLRewriteFunction> rewriteFns)
    : configs(std::move(configs)) {
  Generator generator(module.getContext(), uniquedData, matcherByteCode,
                      rewriterByteCode, patterns, maxValueMemoryIndex,
                      maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
                      maxLoopLevel, constraintFns, rewriteFns, configMap);
  generator.generate(module);

  // Initialize the external functions. They are stored in StringMap iteration
  // order; NOTE(review): this presumably has to match the indices the
  // Generator assigned to each function name -- confirm it enumerates the
  // same maps in the same order.
  for (auto &it : constraintFns)
    constraintFunctions.push_back(std::move(it.second));
  for (auto &it : rewriteFns)
    rewriteFunctions.push_back(std::move(it.second));
}

/// Initialize the given state such that it can be used to execute the current
/// bytecode. Each memory pool is sized to the high-water mark computed during
/// generation, and the per-pattern benefits are seeded from the patterns'
/// initial benefits.
void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
  state.memory.resize(maxValueMemoryIndex, nullptr);
  state.opRangeMemory.resize(maxOpRangeCount);
  state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
  state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
  state.loopIndex.resize(maxLoopLevel, 0);
  state.currentPatternBenefits.reserve(patterns.size());
  for (const PDLByteCodePattern &pattern : patterns)
    state.currentPatternBenefits.push_back(pattern.getBenefit());
}

//===----------------------------------------------------------------------===//
// ByteCode Execution
//===----------------------------------------------------------------------===//

namespace {
/// This class provides support for executing a bytecode stream.
1061 class ByteCodeExecutor { 1062 public: 1063 ByteCodeExecutor( 1064 const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory, 1065 MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory, 1066 MutableArrayRef<TypeRange> typeRangeMemory, 1067 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory, 1068 MutableArrayRef<ValueRange> valueRangeMemory, 1069 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory, 1070 MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory, 1071 ArrayRef<ByteCodeField> code, 1072 ArrayRef<PatternBenefit> currentPatternBenefits, 1073 ArrayRef<PDLByteCodePattern> patterns, 1074 ArrayRef<PDLConstraintFunction> constraintFunctions, 1075 ArrayRef<PDLRewriteFunction> rewriteFunctions) 1076 : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory), 1077 typeRangeMemory(typeRangeMemory), 1078 allocatedTypeRangeMemory(allocatedTypeRangeMemory), 1079 valueRangeMemory(valueRangeMemory), 1080 allocatedValueRangeMemory(allocatedValueRangeMemory), 1081 loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code), 1082 currentPatternBenefits(currentPatternBenefits), patterns(patterns), 1083 constraintFunctions(constraintFunctions), 1084 rewriteFunctions(rewriteFunctions) {} 1085 1086 /// Start executing the code at the current bytecode index. `matches` is an 1087 /// optional field provided when this function is executed in a matching 1088 /// context. 1089 LogicalResult 1090 execute(PatternRewriter &rewriter, 1091 SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr, 1092 Optional<Location> mainRewriteLoc = {}); 1093 1094 private: 1095 /// Internal implementation of executing each of the bytecode commands. 
1096 void executeApplyConstraint(PatternRewriter &rewriter); 1097 LogicalResult executeApplyRewrite(PatternRewriter &rewriter); 1098 void executeAreEqual(); 1099 void executeAreRangesEqual(); 1100 void executeBranch(); 1101 void executeCheckOperandCount(); 1102 void executeCheckOperationName(); 1103 void executeCheckResultCount(); 1104 void executeCheckTypes(); 1105 void executeContinue(); 1106 void executeCreateOperation(PatternRewriter &rewriter, 1107 Location mainRewriteLoc); 1108 void executeCreateTypes(); 1109 void executeEraseOp(PatternRewriter &rewriter); 1110 template <typename T, typename Range, PDLValue::Kind kind> 1111 void executeExtract(); 1112 void executeFinalize(); 1113 void executeForEach(); 1114 void executeGetAttribute(); 1115 void executeGetAttributeType(); 1116 void executeGetDefiningOp(); 1117 void executeGetOperand(unsigned index); 1118 void executeGetOperands(); 1119 void executeGetResult(unsigned index); 1120 void executeGetResults(); 1121 void executeGetUsers(); 1122 void executeGetValueType(); 1123 void executeGetValueRangeTypes(); 1124 void executeIsNotNull(); 1125 void executeRecordMatch(PatternRewriter &rewriter, 1126 SmallVectorImpl<PDLByteCode::MatchResult> &matches); 1127 void executeReplaceOp(PatternRewriter &rewriter); 1128 void executeSwitchAttribute(); 1129 void executeSwitchOperandCount(); 1130 void executeSwitchOperationName(); 1131 void executeSwitchResultCount(); 1132 void executeSwitchType(); 1133 void executeSwitchTypes(); 1134 1135 /// Pushes a code iterator to the stack. 1136 void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); } 1137 1138 /// Pops a code iterator from the stack, returning true on success. 1139 void popCodeIt() { 1140 assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack"); 1141 curCodeIt = resumeCodeIt.back(); 1142 resumeCodeIt.pop_back(); 1143 } 1144 1145 /// Return the bytecode iterator at the start of the current op code. 
1146 const ByteCodeField *getPrevCodeIt() const { 1147 LLVM_DEBUG({ 1148 // Account for the op code and the Location stored inline. 1149 return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField); 1150 }); 1151 1152 // Account for the op code only. 1153 return curCodeIt - 1; 1154 } 1155 1156 /// Read a value from the bytecode buffer, optionally skipping a certain 1157 /// number of prefix values. These methods always update the buffer to point 1158 /// to the next field after the read data. 1159 template <typename T = ByteCodeField> 1160 T read(size_t skipN = 0) { 1161 curCodeIt += skipN; 1162 return readImpl<T>(); 1163 } 1164 ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); } 1165 1166 /// Read a list of values from the bytecode buffer. 1167 template <typename ValueT, typename T> 1168 void readList(SmallVectorImpl<T> &list) { 1169 list.clear(); 1170 for (unsigned i = 0, e = read(); i != e; ++i) 1171 list.push_back(read<ValueT>()); 1172 } 1173 1174 /// Read a list of values from the bytecode buffer. The values may be encoded 1175 /// as either Value or ValueRange elements. 1176 void readValueList(SmallVectorImpl<Value> &list) { 1177 for (unsigned i = 0, e = read(); i != e; ++i) { 1178 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1179 list.push_back(read<Value>()); 1180 } else { 1181 ValueRange *values = read<ValueRange *>(); 1182 list.append(values->begin(), values->end()); 1183 } 1184 } 1185 } 1186 1187 /// Read a value stored inline as a pointer. 1188 template <typename T> 1189 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T> 1190 readInline() { 1191 const void *pointer; 1192 std::memcpy(&pointer, curCodeIt, sizeof(const void *)); 1193 curCodeIt += sizeof(const void *) / sizeof(ByteCodeField); 1194 return T::getFromOpaquePointer(pointer); 1195 } 1196 1197 /// Jump to a specific successor based on a predicate value. 1198 void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 
0 : 1)); } 1199 /// Jump to a specific successor based on a destination index. 1200 void selectJump(size_t destIndex) { 1201 curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)]; 1202 } 1203 1204 /// Handle a switch operation with the provided value and cases. 1205 template <typename T, typename RangeT, typename Comparator = std::equal_to<T>> 1206 void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) { 1207 LLVM_DEBUG({ 1208 llvm::dbgs() << " * Value: " << value << "\n" 1209 << " * Cases: "; 1210 llvm::interleaveComma(cases, llvm::dbgs()); 1211 llvm::dbgs() << "\n"; 1212 }); 1213 1214 // Check to see if the attribute value is within the case list. Jump to 1215 // the correct successor index based on the result. 1216 for (auto it = cases.begin(), e = cases.end(); it != e; ++it) 1217 if (cmp(*it, value)) 1218 return selectJump(size_t((it - cases.begin()) + 1)); 1219 selectJump(size_t(0)); 1220 } 1221 1222 /// Store a pointer to memory. 1223 void storeToMemory(unsigned index, const void *value) { 1224 memory[index] = value; 1225 } 1226 1227 /// Store a value to memory as an opaque pointer. 1228 template <typename T> 1229 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value> 1230 storeToMemory(unsigned index, T value) { 1231 memory[index] = value.getAsOpaquePointer(); 1232 } 1233 1234 /// Internal implementation of reading various data types from the bytecode 1235 /// stream. 1236 template <typename T> 1237 const void *readFromMemory() { 1238 size_t index = *curCodeIt++; 1239 1240 // If this type is an SSA value, it can only be stored in non-const memory. 1241 if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *, 1242 Value>::value || 1243 index < memory.size()) 1244 return memory[index]; 1245 1246 // Otherwise, if this index is not inbounds it is uniqued. 
1247 return uniquedMemory[index - memory.size()]; 1248 } 1249 template <typename T> 1250 std::enable_if_t<std::is_pointer<T>::value, T> readImpl() { 1251 return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>())); 1252 } 1253 template <typename T> 1254 std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value, 1255 T> 1256 readImpl() { 1257 return T(T::getFromOpaquePointer(readFromMemory<T>())); 1258 } 1259 template <typename T> 1260 std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() { 1261 switch (read<PDLValue::Kind>()) { 1262 case PDLValue::Kind::Attribute: 1263 return read<Attribute>(); 1264 case PDLValue::Kind::Operation: 1265 return read<Operation *>(); 1266 case PDLValue::Kind::Type: 1267 return read<Type>(); 1268 case PDLValue::Kind::Value: 1269 return read<Value>(); 1270 case PDLValue::Kind::TypeRange: 1271 return read<TypeRange *>(); 1272 case PDLValue::Kind::ValueRange: 1273 return read<ValueRange *>(); 1274 } 1275 llvm_unreachable("unhandled PDLValue::Kind"); 1276 } 1277 template <typename T> 1278 std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() { 1279 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 1280 "unexpected ByteCode address size"); 1281 ByteCodeAddr result; 1282 std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr)); 1283 curCodeIt += 2; 1284 return result; 1285 } 1286 template <typename T> 1287 std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() { 1288 return *curCodeIt++; 1289 } 1290 template <typename T> 1291 std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() { 1292 return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>()); 1293 } 1294 1295 /// The underlying bytecode buffer. 1296 const ByteCodeField *curCodeIt; 1297 1298 /// The stack of bytecode positions at which to resume operation. 1299 SmallVector<const ByteCodeField *> resumeCodeIt; 1300 1301 /// The current execution memory. 
1302 MutableArrayRef<const void *> memory; 1303 MutableArrayRef<OwningOpRange> opRangeMemory; 1304 MutableArrayRef<TypeRange> typeRangeMemory; 1305 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory; 1306 MutableArrayRef<ValueRange> valueRangeMemory; 1307 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory; 1308 1309 /// The current loop indices. 1310 MutableArrayRef<unsigned> loopIndex; 1311 1312 /// References to ByteCode data necessary for execution. 1313 ArrayRef<const void *> uniquedMemory; 1314 ArrayRef<ByteCodeField> code; 1315 ArrayRef<PatternBenefit> currentPatternBenefits; 1316 ArrayRef<PDLByteCodePattern> patterns; 1317 ArrayRef<PDLConstraintFunction> constraintFunctions; 1318 ArrayRef<PDLRewriteFunction> rewriteFunctions; 1319 }; 1320 1321 /// This class is an instantiation of the PDLResultList that provides access to 1322 /// the returned results. This API is not on `PDLResultList` to avoid 1323 /// overexposing access to information specific solely to the ByteCode. 1324 class ByteCodeRewriteResultList : public PDLResultList { 1325 public: 1326 ByteCodeRewriteResultList(unsigned maxNumResults) 1327 : PDLResultList(maxNumResults) {} 1328 1329 /// Return the list of PDL results. 1330 MutableArrayRef<PDLValue> getResults() { return results; } 1331 1332 /// Return the type ranges allocated by this list. 1333 MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() { 1334 return allocatedTypeRanges; 1335 } 1336 1337 /// Return the value ranges allocated by this list. 
1338 MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() { 1339 return allocatedValueRanges; 1340 } 1341 }; 1342 } // namespace 1343 1344 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { 1345 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); 1346 const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; 1347 SmallVector<PDLValue, 16> args; 1348 readList<PDLValue>(args); 1349 1350 LLVM_DEBUG({ 1351 llvm::dbgs() << " * Arguments: "; 1352 llvm::interleaveComma(args, llvm::dbgs()); 1353 }); 1354 1355 // Invoke the constraint and jump to the proper destination. 1356 selectJump(succeeded(constraintFn(rewriter, args))); 1357 } 1358 1359 LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { 1360 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); 1361 const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; 1362 SmallVector<PDLValue, 16> args; 1363 readList<PDLValue>(args); 1364 1365 LLVM_DEBUG({ 1366 llvm::dbgs() << " * Arguments: "; 1367 llvm::interleaveComma(args, llvm::dbgs()); 1368 }); 1369 1370 // Execute the rewrite function. 1371 ByteCodeField numResults = read(); 1372 ByteCodeRewriteResultList results(numResults); 1373 LogicalResult rewriteResult = rewriteFn(rewriter, results, args); 1374 1375 assert(results.getResults().size() == numResults && 1376 "native PDL rewrite function returned unexpected number of results"); 1377 1378 // Store the results in the bytecode memory. 1379 for (PDLValue &result : results.getResults()) { 1380 LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n"); 1381 1382 // In debug mode we also verify the expected kind of the result. 1383 #ifndef NDEBUG 1384 assert(result.getKind() == read<PDLValue::Kind>() && 1385 "native PDL rewrite function returned an unexpected type of result"); 1386 #endif 1387 1388 // If the result is a range, we need to copy it over to the bytecodes 1389 // range memory. 
1390 if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) { 1391 unsigned rangeIndex = read(); 1392 typeRangeMemory[rangeIndex] = *typeRange; 1393 memory[read()] = &typeRangeMemory[rangeIndex]; 1394 } else if (Optional<ValueRange> valueRange = 1395 result.dyn_cast<ValueRange>()) { 1396 unsigned rangeIndex = read(); 1397 valueRangeMemory[rangeIndex] = *valueRange; 1398 memory[read()] = &valueRangeMemory[rangeIndex]; 1399 } else { 1400 memory[read()] = result.getAsOpaquePointer(); 1401 } 1402 } 1403 1404 // Copy over any underlying storage allocated for result ranges. 1405 for (auto &it : results.getAllocatedTypeRanges()) 1406 allocatedTypeRangeMemory.push_back(std::move(it)); 1407 for (auto &it : results.getAllocatedValueRanges()) 1408 allocatedValueRangeMemory.push_back(std::move(it)); 1409 1410 // Process the result of the rewrite. 1411 if (failed(rewriteResult)) { 1412 LLVM_DEBUG(llvm::dbgs() << " - Failed"); 1413 return failure(); 1414 } 1415 return success(); 1416 } 1417 1418 void ByteCodeExecutor::executeAreEqual() { 1419 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 1420 const void *lhs = read<const void *>(); 1421 const void *rhs = read<const void *>(); 1422 1423 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n"); 1424 selectJump(lhs == rhs); 1425 } 1426 1427 void ByteCodeExecutor::executeAreRangesEqual() { 1428 LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n"); 1429 PDLValue::Kind valueKind = read<PDLValue::Kind>(); 1430 const void *lhs = read<const void *>(); 1431 const void *rhs = read<const void *>(); 1432 1433 switch (valueKind) { 1434 case PDLValue::Kind::TypeRange: { 1435 const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs); 1436 const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs); 1437 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1438 selectJump(*lhsRange == *rhsRange); 1439 break; 1440 } 1441 case PDLValue::Kind::ValueRange: { 1442 const auto 
*lhsRange = reinterpret_cast<const ValueRange *>(lhs); 1443 const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs); 1444 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1445 selectJump(*lhsRange == *rhsRange); 1446 break; 1447 } 1448 default: 1449 llvm_unreachable("unexpected `AreRangesEqual` value kind"); 1450 } 1451 } 1452 1453 void ByteCodeExecutor::executeBranch() { 1454 LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n"); 1455 curCodeIt = &code[read<ByteCodeAddr>()]; 1456 } 1457 1458 void ByteCodeExecutor::executeCheckOperandCount() { 1459 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n"); 1460 Operation *op = read<Operation *>(); 1461 uint32_t expectedCount = read<uint32_t>(); 1462 bool compareAtLeast = read(); 1463 1464 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n" 1465 << " * Expected: " << expectedCount << "\n" 1466 << " * Comparator: " 1467 << (compareAtLeast ? ">=" : "==") << "\n"); 1468 if (compareAtLeast) 1469 selectJump(op->getNumOperands() >= expectedCount); 1470 else 1471 selectJump(op->getNumOperands() == expectedCount); 1472 } 1473 1474 void ByteCodeExecutor::executeCheckOperationName() { 1475 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n"); 1476 Operation *op = read<Operation *>(); 1477 OperationName expectedName = read<OperationName>(); 1478 1479 LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n" 1480 << " * Expected: \"" << expectedName << "\"\n"); 1481 selectJump(op->getName() == expectedName); 1482 } 1483 1484 void ByteCodeExecutor::executeCheckResultCount() { 1485 LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n"); 1486 Operation *op = read<Operation *>(); 1487 uint32_t expectedCount = read<uint32_t>(); 1488 bool compareAtLeast = read(); 1489 1490 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n" 1491 << " * Expected: " << expectedCount << "\n" 1492 << " * Comparator: " 1493 << (compareAtLeast ? 
">=" : "==") << "\n"); 1494 if (compareAtLeast) 1495 selectJump(op->getNumResults() >= expectedCount); 1496 else 1497 selectJump(op->getNumResults() == expectedCount); 1498 } 1499 1500 void ByteCodeExecutor::executeCheckTypes() { 1501 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 1502 TypeRange *lhs = read<TypeRange *>(); 1503 Attribute rhs = read<Attribute>(); 1504 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1505 1506 selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>()); 1507 } 1508 1509 void ByteCodeExecutor::executeContinue() { 1510 ByteCodeField level = read(); 1511 LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n" 1512 << " * Level: " << level << "\n"); 1513 ++loopIndex[level]; 1514 popCodeIt(); 1515 } 1516 1517 void ByteCodeExecutor::executeCreateTypes() { 1518 LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n"); 1519 unsigned memIndex = read(); 1520 unsigned rangeIndex = read(); 1521 ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>(); 1522 1523 LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n"); 1524 1525 // Allocate a buffer for this type range. 1526 llvm::OwningArrayRef<Type> storage(typesAttr.size()); 1527 llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin()); 1528 allocatedTypeRangeMemory.emplace_back(std::move(storage)); 1529 1530 // Assign this to the range slot and use the range as the value for the 1531 // memory index. 
1532 typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back(); 1533 memory[memIndex] = &typeRangeMemory[rangeIndex]; 1534 } 1535 1536 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, 1537 Location mainRewriteLoc) { 1538 LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n"); 1539 1540 unsigned memIndex = read(); 1541 OperationState state(mainRewriteLoc, read<OperationName>()); 1542 readValueList(state.operands); 1543 for (unsigned i = 0, e = read(); i != e; ++i) { 1544 StringAttr name = read<StringAttr>(); 1545 if (Attribute attr = read<Attribute>()) 1546 state.addAttribute(name, attr); 1547 } 1548 1549 // Read in the result types. If the "size" is the sentinel value, this 1550 // indicates that the result types should be inferred. 1551 unsigned numResults = read(); 1552 if (numResults == kInferTypesMarker) { 1553 InferTypeOpInterface::Concept *inferInterface = 1554 state.name.getRegisteredInfo()->getInterface<InferTypeOpInterface>(); 1555 assert(inferInterface && 1556 "expected operation to provide InferTypeOpInterface"); 1557 1558 // TODO: Handle failure. 1559 if (failed(inferInterface->inferReturnTypes( 1560 state.getContext(), state.location, state.operands, 1561 state.attributes.getDictionary(state.getContext()), state.regions, 1562 state.types))) 1563 return; 1564 } else { 1565 // Otherwise, this is a fixed number of results. 
1566 for (unsigned i = 0; i != numResults; ++i) { 1567 if (read<PDLValue::Kind>() == PDLValue::Kind::Type) { 1568 state.types.push_back(read<Type>()); 1569 } else { 1570 TypeRange *resultTypes = read<TypeRange *>(); 1571 state.types.append(resultTypes->begin(), resultTypes->end()); 1572 } 1573 } 1574 } 1575 1576 Operation *resultOp = rewriter.create(state); 1577 memory[memIndex] = resultOp; 1578 1579 LLVM_DEBUG({ 1580 llvm::dbgs() << " * Attributes: " 1581 << state.attributes.getDictionary(state.getContext()) 1582 << "\n * Operands: "; 1583 llvm::interleaveComma(state.operands, llvm::dbgs()); 1584 llvm::dbgs() << "\n * Result Types: "; 1585 llvm::interleaveComma(state.types, llvm::dbgs()); 1586 llvm::dbgs() << "\n * Result: " << *resultOp << "\n"; 1587 }); 1588 } 1589 1590 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) { 1591 LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n"); 1592 Operation *op = read<Operation *>(); 1593 1594 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1595 rewriter.eraseOp(op); 1596 } 1597 1598 template <typename T, typename Range, PDLValue::Kind kind> 1599 void ByteCodeExecutor::executeExtract() { 1600 LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n"); 1601 Range *range = read<Range *>(); 1602 unsigned index = read<uint32_t>(); 1603 unsigned memIndex = read(); 1604 1605 if (!range) { 1606 memory[memIndex] = nullptr; 1607 return; 1608 } 1609 1610 T result = index < range->size() ? 
(*range)[index] : T(); 1611 LLVM_DEBUG(llvm::dbgs() << " * " << kind << "s(" << range->size() << ")\n" 1612 << " * Index: " << index << "\n" 1613 << " * Result: " << result << "\n"); 1614 storeToMemory(memIndex, result); 1615 } 1616 1617 void ByteCodeExecutor::executeFinalize() { 1618 LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n"); 1619 } 1620 1621 void ByteCodeExecutor::executeForEach() { 1622 LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n"); 1623 const ByteCodeField *prevCodeIt = getPrevCodeIt(); 1624 unsigned rangeIndex = read(); 1625 unsigned memIndex = read(); 1626 const void *value = nullptr; 1627 1628 switch (read<PDLValue::Kind>()) { 1629 case PDLValue::Kind::Operation: { 1630 unsigned &index = loopIndex[read()]; 1631 ArrayRef<Operation *> array = opRangeMemory[rangeIndex]; 1632 assert(index <= array.size() && "iterated past the end"); 1633 if (index < array.size()) { 1634 LLVM_DEBUG(llvm::dbgs() << " * Result: " << array[index] << "\n"); 1635 value = array[index]; 1636 break; 1637 } 1638 1639 LLVM_DEBUG(llvm::dbgs() << " * Done\n"); 1640 index = 0; 1641 selectJump(size_t(0)); 1642 return; 1643 } 1644 default: 1645 llvm_unreachable("unexpected `ForEach` value kind"); 1646 } 1647 1648 // Store the iterate value and the stack address. 1649 memory[memIndex] = value; 1650 pushCodeIt(prevCodeIt); 1651 1652 // Skip over the successor (we will enter the body of the loop). 
1653 read<ByteCodeAddr>(); 1654 } 1655 1656 void ByteCodeExecutor::executeGetAttribute() { 1657 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n"); 1658 unsigned memIndex = read(); 1659 Operation *op = read<Operation *>(); 1660 StringAttr attrName = read<StringAttr>(); 1661 Attribute attr = op->getAttr(attrName); 1662 1663 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1664 << " * Attribute: " << attrName << "\n" 1665 << " * Result: " << attr << "\n"); 1666 memory[memIndex] = attr.getAsOpaquePointer(); 1667 } 1668 1669 void ByteCodeExecutor::executeGetAttributeType() { 1670 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n"); 1671 unsigned memIndex = read(); 1672 Attribute attr = read<Attribute>(); 1673 Type type; 1674 if (auto typedAttr = attr.dyn_cast<TypedAttr>()) 1675 type = typedAttr.getType(); 1676 1677 LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n" 1678 << " * Result: " << type << "\n"); 1679 memory[memIndex] = type.getAsOpaquePointer(); 1680 } 1681 1682 void ByteCodeExecutor::executeGetDefiningOp() { 1683 LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n"); 1684 unsigned memIndex = read(); 1685 Operation *op = nullptr; 1686 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1687 Value value = read<Value>(); 1688 if (value) 1689 op = value.getDefiningOp(); 1690 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1691 } else { 1692 ValueRange *values = read<ValueRange *>(); 1693 if (values && !values->empty()) { 1694 op = values->front().getDefiningOp(); 1695 } 1696 LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n"); 1697 } 1698 1699 LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n"); 1700 memory[memIndex] = op; 1701 } 1702 1703 void ByteCodeExecutor::executeGetOperand(unsigned index) { 1704 Operation *op = read<Operation *>(); 1705 unsigned memIndex = read(); 1706 Value operand = 1707 index < op->getNumOperands() ? 
op->getOperand(index) : Value(); 1708 1709 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1710 << " * Index: " << index << "\n" 1711 << " * Result: " << operand << "\n"); 1712 memory[memIndex] = operand.getAsOpaquePointer(); 1713 } 1714 1715 /// This function is the internal implementation of `GetResults` and 1716 /// `GetOperands` that provides support for extracting a value range from the 1717 /// given operation. 1718 template <template <typename> class AttrSizedSegmentsT, typename RangeT> 1719 static void * 1720 executeGetOperandsResults(RangeT values, Operation *op, unsigned index, 1721 ByteCodeField rangeIndex, StringRef attrSizedSegments, 1722 MutableArrayRef<ValueRange> valueRangeMemory) { 1723 // Check for the sentinel index that signals that all values should be 1724 // returned. 1725 if (index == std::numeric_limits<uint32_t>::max()) { 1726 LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n"); 1727 // `values` is already the full value range. 1728 1729 // Otherwise, check to see if this operation uses AttrSizedSegments. 1730 } else if (op->hasTrait<AttrSizedSegmentsT>()) { 1731 LLVM_DEBUG(llvm::dbgs() 1732 << " * Extracting values from `" << attrSizedSegments << "`\n"); 1733 1734 auto segmentAttr = op->getAttrOfType<DenseI32ArrayAttr>(attrSizedSegments); 1735 if (!segmentAttr || segmentAttr.asArrayRef().size() <= index) 1736 return nullptr; 1737 1738 ArrayRef<int32_t> segments = segmentAttr; 1739 unsigned startIndex = 1740 std::accumulate(segments.begin(), segments.begin() + index, 0); 1741 values = values.slice(startIndex, *std::next(segments.begin(), index)); 1742 1743 LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex << ", " 1744 << *std::next(segments.begin(), index) << "]\n"); 1745 1746 // Otherwise, assume this is the last operand group of the operation. 
1747 // FIXME: We currently don't support operations with 1748 // SameVariadicOperandSize/SameVariadicResultSize here given that we don't 1749 // have a way to detect it's presence. 1750 } else if (values.size() >= index) { 1751 LLVM_DEBUG(llvm::dbgs() 1752 << " * Treating values as trailing variadic range\n"); 1753 values = values.drop_front(index); 1754 1755 // If we couldn't detect a way to compute the values, bail out. 1756 } else { 1757 return nullptr; 1758 } 1759 1760 // If the range index is valid, we are returning a range. 1761 if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) { 1762 valueRangeMemory[rangeIndex] = values; 1763 return &valueRangeMemory[rangeIndex]; 1764 } 1765 1766 // If a range index wasn't provided, the range is required to be non-variadic. 1767 return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer(); 1768 } 1769 1770 void ByteCodeExecutor::executeGetOperands() { 1771 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n"); 1772 unsigned index = read<uint32_t>(); 1773 Operation *op = read<Operation *>(); 1774 ByteCodeField rangeIndex = read(); 1775 1776 void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>( 1777 op->getOperands(), op, index, rangeIndex, "operand_segment_sizes", 1778 valueRangeMemory); 1779 if (!result) 1780 LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n"); 1781 memory[read()] = result; 1782 } 1783 1784 void ByteCodeExecutor::executeGetResult(unsigned index) { 1785 Operation *op = read<Operation *>(); 1786 unsigned memIndex = read(); 1787 OpResult result = 1788 index < op->getNumResults() ? 
op->getResult(index) : OpResult(); 1789 1790 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1791 << " * Index: " << index << "\n" 1792 << " * Result: " << result << "\n"); 1793 memory[memIndex] = result.getAsOpaquePointer(); 1794 } 1795 1796 void ByteCodeExecutor::executeGetResults() { 1797 LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n"); 1798 unsigned index = read<uint32_t>(); 1799 Operation *op = read<Operation *>(); 1800 ByteCodeField rangeIndex = read(); 1801 1802 void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>( 1803 op->getResults(), op, index, rangeIndex, "result_segment_sizes", 1804 valueRangeMemory); 1805 if (!result) 1806 LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n"); 1807 memory[read()] = result; 1808 } 1809 1810 void ByteCodeExecutor::executeGetUsers() { 1811 LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n"); 1812 unsigned memIndex = read(); 1813 unsigned rangeIndex = read(); 1814 OwningOpRange &range = opRangeMemory[rangeIndex]; 1815 memory[memIndex] = ⦥ 1816 1817 range = OwningOpRange(); 1818 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1819 // Read the value. 1820 Value value = read<Value>(); 1821 if (!value) 1822 return; 1823 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1824 1825 // Extract the users of a single value. 1826 range = OwningOpRange(std::distance(value.user_begin(), value.user_end())); 1827 llvm::copy(value.getUsers(), range.begin()); 1828 } else { 1829 // Read a range of values. 1830 ValueRange *values = read<ValueRange *>(); 1831 if (!values) 1832 return; 1833 LLVM_DEBUG({ 1834 llvm::dbgs() << " * Values (" << values->size() << "): "; 1835 llvm::interleaveComma(*values, llvm::dbgs()); 1836 llvm::dbgs() << "\n"; 1837 }); 1838 1839 // Extract all the users of a range of values. 
1840 SmallVector<Operation *> users; 1841 for (Value value : *values) 1842 users.append(value.user_begin(), value.user_end()); 1843 range = OwningOpRange(users.size()); 1844 llvm::copy(users, range.begin()); 1845 } 1846 1847 LLVM_DEBUG(llvm::dbgs() << " * Result: " << range.size() << " operations\n"); 1848 } 1849 1850 void ByteCodeExecutor::executeGetValueType() { 1851 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n"); 1852 unsigned memIndex = read(); 1853 Value value = read<Value>(); 1854 Type type = value ? value.getType() : Type(); 1855 1856 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" 1857 << " * Result: " << type << "\n"); 1858 memory[memIndex] = type.getAsOpaquePointer(); 1859 } 1860 1861 void ByteCodeExecutor::executeGetValueRangeTypes() { 1862 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n"); 1863 unsigned memIndex = read(); 1864 unsigned rangeIndex = read(); 1865 ValueRange *values = read<ValueRange *>(); 1866 if (!values) { 1867 LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n"); 1868 memory[memIndex] = nullptr; 1869 return; 1870 } 1871 1872 LLVM_DEBUG({ 1873 llvm::dbgs() << " * Values (" << values->size() << "): "; 1874 llvm::interleaveComma(*values, llvm::dbgs()); 1875 llvm::dbgs() << "\n * Result: "; 1876 llvm::interleaveComma(values->getType(), llvm::dbgs()); 1877 llvm::dbgs() << "\n"; 1878 }); 1879 typeRangeMemory[rangeIndex] = values->getType(); 1880 memory[memIndex] = &typeRangeMemory[rangeIndex]; 1881 } 1882 1883 void ByteCodeExecutor::executeIsNotNull() { 1884 LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n"); 1885 const void *value = read<const void *>(); 1886 1887 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1888 selectJump(value != nullptr); 1889 } 1890 1891 void ByteCodeExecutor::executeRecordMatch( 1892 PatternRewriter &rewriter, 1893 SmallVectorImpl<PDLByteCode::MatchResult> &matches) { 1894 LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n"); 1895 unsigned patternIndex = read(); 1896 
PatternBenefit benefit = currentPatternBenefits[patternIndex]; 1897 const ByteCodeField *dest = &code[read<ByteCodeAddr>()]; 1898 1899 // If the benefit of the pattern is impossible, skip the processing of the 1900 // rest of the pattern. 1901 if (benefit.isImpossibleToMatch()) { 1902 LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n"); 1903 curCodeIt = dest; 1904 return; 1905 } 1906 1907 // Create a fused location containing the locations of each of the 1908 // operations used in the match. This will be used as the location for 1909 // created operations during the rewrite that don't already have an 1910 // explicit location set. 1911 unsigned numMatchLocs = read(); 1912 SmallVector<Location, 4> matchLocs; 1913 matchLocs.reserve(numMatchLocs); 1914 for (unsigned i = 0; i != numMatchLocs; ++i) 1915 matchLocs.push_back(read<Operation *>()->getLoc()); 1916 Location matchLoc = rewriter.getFusedLoc(matchLocs); 1917 1918 LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n" 1919 << " * Location: " << matchLoc << "\n"); 1920 matches.emplace_back(matchLoc, patterns[patternIndex], benefit); 1921 PDLByteCode::MatchResult &match = matches.back(); 1922 1923 // Record all of the inputs to the match. If any of the inputs are ranges, we 1924 // will also need to remap the range pointer to memory stored in the match 1925 // state. 
1926 unsigned numInputs = read(); 1927 match.values.reserve(numInputs); 1928 match.typeRangeValues.reserve(numInputs); 1929 match.valueRangeValues.reserve(numInputs); 1930 for (unsigned i = 0; i < numInputs; ++i) { 1931 switch (read<PDLValue::Kind>()) { 1932 case PDLValue::Kind::TypeRange: 1933 match.typeRangeValues.push_back(*read<TypeRange *>()); 1934 match.values.push_back(&match.typeRangeValues.back()); 1935 break; 1936 case PDLValue::Kind::ValueRange: 1937 match.valueRangeValues.push_back(*read<ValueRange *>()); 1938 match.values.push_back(&match.valueRangeValues.back()); 1939 break; 1940 default: 1941 match.values.push_back(read<const void *>()); 1942 break; 1943 } 1944 } 1945 curCodeIt = dest; 1946 } 1947 1948 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) { 1949 LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n"); 1950 Operation *op = read<Operation *>(); 1951 SmallVector<Value, 16> args; 1952 readValueList(args); 1953 1954 LLVM_DEBUG({ 1955 llvm::dbgs() << " * Operation: " << *op << "\n" 1956 << " * Values: "; 1957 llvm::interleaveComma(args, llvm::dbgs()); 1958 llvm::dbgs() << "\n"; 1959 }); 1960 rewriter.replaceOp(op, args); 1961 } 1962 1963 void ByteCodeExecutor::executeSwitchAttribute() { 1964 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n"); 1965 Attribute value = read<Attribute>(); 1966 ArrayAttr cases = read<ArrayAttr>(); 1967 handleSwitch(value, cases); 1968 } 1969 1970 void ByteCodeExecutor::executeSwitchOperandCount() { 1971 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n"); 1972 Operation *op = read<Operation *>(); 1973 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1974 1975 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1976 handleSwitch(op->getNumOperands(), cases); 1977 } 1978 1979 void ByteCodeExecutor::executeSwitchOperationName() { 1980 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n"); 1981 OperationName value = read<Operation *>()->getName(); 
1982 size_t caseCount = read(); 1983 1984 // The operation names are stored in-line, so to print them out for 1985 // debugging purposes we need to read the array before executing the 1986 // switch so that we can display all of the possible values. 1987 LLVM_DEBUG({ 1988 const ByteCodeField *prevCodeIt = curCodeIt; 1989 llvm::dbgs() << " * Value: " << value << "\n" 1990 << " * Cases: "; 1991 llvm::interleaveComma( 1992 llvm::map_range(llvm::seq<size_t>(0, caseCount), 1993 [&](size_t) { return read<OperationName>(); }), 1994 llvm::dbgs()); 1995 llvm::dbgs() << "\n"; 1996 curCodeIt = prevCodeIt; 1997 }); 1998 1999 // Try to find the switch value within any of the cases. 2000 for (size_t i = 0; i != caseCount; ++i) { 2001 if (read<OperationName>() == value) { 2002 curCodeIt += (caseCount - i - 1); 2003 return selectJump(i + 1); 2004 } 2005 } 2006 selectJump(size_t(0)); 2007 } 2008 2009 void ByteCodeExecutor::executeSwitchResultCount() { 2010 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n"); 2011 Operation *op = read<Operation *>(); 2012 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 2013 2014 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 2015 handleSwitch(op->getNumResults(), cases); 2016 } 2017 2018 void ByteCodeExecutor::executeSwitchType() { 2019 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n"); 2020 Type value = read<Type>(); 2021 auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>(); 2022 handleSwitch(value, cases); 2023 } 2024 2025 void ByteCodeExecutor::executeSwitchTypes() { 2026 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n"); 2027 TypeRange *value = read<TypeRange *>(); 2028 auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>(); 2029 if (!value) { 2030 LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n"); 2031 return selectJump(size_t(0)); 2032 } 2033 handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) { 2034 return value == caseValue.getAsValueRange<TypeAttr>(); 2035 }); 
2036 } 2037 2038 LogicalResult 2039 ByteCodeExecutor::execute(PatternRewriter &rewriter, 2040 SmallVectorImpl<PDLByteCode::MatchResult> *matches, 2041 Optional<Location> mainRewriteLoc) { 2042 while (true) { 2043 // Print the location of the operation being executed. 2044 LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n"); 2045 2046 OpCode opCode = static_cast<OpCode>(read()); 2047 switch (opCode) { 2048 case ApplyConstraint: 2049 executeApplyConstraint(rewriter); 2050 break; 2051 case ApplyRewrite: 2052 if (failed(executeApplyRewrite(rewriter))) 2053 return failure(); 2054 break; 2055 case AreEqual: 2056 executeAreEqual(); 2057 break; 2058 case AreRangesEqual: 2059 executeAreRangesEqual(); 2060 break; 2061 case Branch: 2062 executeBranch(); 2063 break; 2064 case CheckOperandCount: 2065 executeCheckOperandCount(); 2066 break; 2067 case CheckOperationName: 2068 executeCheckOperationName(); 2069 break; 2070 case CheckResultCount: 2071 executeCheckResultCount(); 2072 break; 2073 case CheckTypes: 2074 executeCheckTypes(); 2075 break; 2076 case Continue: 2077 executeContinue(); 2078 break; 2079 case CreateOperation: 2080 executeCreateOperation(rewriter, *mainRewriteLoc); 2081 break; 2082 case CreateTypes: 2083 executeCreateTypes(); 2084 break; 2085 case EraseOp: 2086 executeEraseOp(rewriter); 2087 break; 2088 case ExtractOp: 2089 executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>(); 2090 break; 2091 case ExtractType: 2092 executeExtract<Type, TypeRange, PDLValue::Kind::Type>(); 2093 break; 2094 case ExtractValue: 2095 executeExtract<Value, ValueRange, PDLValue::Kind::Value>(); 2096 break; 2097 case Finalize: 2098 executeFinalize(); 2099 LLVM_DEBUG(llvm::dbgs() << "\n"); 2100 return success(); 2101 case ForEach: 2102 executeForEach(); 2103 break; 2104 case GetAttribute: 2105 executeGetAttribute(); 2106 break; 2107 case GetAttributeType: 2108 executeGetAttributeType(); 2109 break; 2110 case GetDefiningOp: 2111 executeGetDefiningOp(); 2112 break; 
2113 case GetOperand0: 2114 case GetOperand1: 2115 case GetOperand2: 2116 case GetOperand3: { 2117 unsigned index = opCode - GetOperand0; 2118 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n"); 2119 executeGetOperand(index); 2120 break; 2121 } 2122 case GetOperandN: 2123 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n"); 2124 executeGetOperand(read<uint32_t>()); 2125 break; 2126 case GetOperands: 2127 executeGetOperands(); 2128 break; 2129 case GetResult0: 2130 case GetResult1: 2131 case GetResult2: 2132 case GetResult3: { 2133 unsigned index = opCode - GetResult0; 2134 LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n"); 2135 executeGetResult(index); 2136 break; 2137 } 2138 case GetResultN: 2139 LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n"); 2140 executeGetResult(read<uint32_t>()); 2141 break; 2142 case GetResults: 2143 executeGetResults(); 2144 break; 2145 case GetUsers: 2146 executeGetUsers(); 2147 break; 2148 case GetValueType: 2149 executeGetValueType(); 2150 break; 2151 case GetValueRangeTypes: 2152 executeGetValueRangeTypes(); 2153 break; 2154 case IsNotNull: 2155 executeIsNotNull(); 2156 break; 2157 case RecordMatch: 2158 assert(matches && 2159 "expected matches to be provided when executing the matcher"); 2160 executeRecordMatch(rewriter, *matches); 2161 break; 2162 case ReplaceOp: 2163 executeReplaceOp(rewriter); 2164 break; 2165 case SwitchAttribute: 2166 executeSwitchAttribute(); 2167 break; 2168 case SwitchOperandCount: 2169 executeSwitchOperandCount(); 2170 break; 2171 case SwitchOperationName: 2172 executeSwitchOperationName(); 2173 break; 2174 case SwitchResultCount: 2175 executeSwitchResultCount(); 2176 break; 2177 case SwitchType: 2178 executeSwitchType(); 2179 break; 2180 case SwitchTypes: 2181 executeSwitchTypes(); 2182 break; 2183 } 2184 LLVM_DEBUG(llvm::dbgs() << "\n"); 2185 } 2186 } 2187 2188 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, 2189 SmallVectorImpl<MatchResult> 
&matches, 2190 PDLByteCodeMutableState &state) const { 2191 // The first memory slot is always the root operation. 2192 state.memory[0] = op; 2193 2194 // The matcher function always starts at code address 0. 2195 ByteCodeExecutor executor( 2196 matcherByteCode.data(), state.memory, state.opRangeMemory, 2197 state.typeRangeMemory, state.allocatedTypeRangeMemory, 2198 state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex, 2199 uniquedData, matcherByteCode, state.currentPatternBenefits, patterns, 2200 constraintFunctions, rewriteFunctions); 2201 LogicalResult executeResult = executor.execute(rewriter, &matches); 2202 assert(succeeded(executeResult) && "unexpected matcher execution failure"); 2203 2204 // Order the found matches by benefit. 2205 std::stable_sort(matches.begin(), matches.end(), 2206 [](const MatchResult &lhs, const MatchResult &rhs) { 2207 return lhs.benefit > rhs.benefit; 2208 }); 2209 } 2210 2211 LogicalResult PDLByteCode::rewrite(PatternRewriter &rewriter, 2212 const MatchResult &match, 2213 PDLByteCodeMutableState &state) const { 2214 auto *configSet = match.pattern->getConfigSet(); 2215 if (configSet) 2216 configSet->notifyRewriteBegin(rewriter); 2217 2218 // The arguments of the rewrite function are stored at the start of the 2219 // memory buffer. 
2220 llvm::copy(match.values, state.memory.begin()); 2221 2222 ByteCodeExecutor executor( 2223 &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory, 2224 state.opRangeMemory, state.typeRangeMemory, 2225 state.allocatedTypeRangeMemory, state.valueRangeMemory, 2226 state.allocatedValueRangeMemory, state.loopIndex, uniquedData, 2227 rewriterByteCode, state.currentPatternBenefits, patterns, 2228 constraintFunctions, rewriteFunctions); 2229 LogicalResult result = 2230 executor.execute(rewriter, /*matches=*/nullptr, match.location); 2231 2232 if (configSet) 2233 configSet->notifyRewriteEnd(rewriter); 2234 2235 // If the rewrite failed, check if the pattern rewriter can recover. If it 2236 // can, we can signal to the pattern applicator to keep trying patterns. If it 2237 // doesn't, we need to bail. Bailing here should be fine, given that we have 2238 // no means to propagate such a failure to the user, and it also indicates a 2239 // bug in the user code (i.e. failable rewrites should not be used with 2240 // pattern rewriters that don't support it). 2241 if (failed(result) && !rewriter.canRecoverFromRewriteFailure()) { 2242 LLVM_DEBUG(llvm::dbgs() << " and rollback is not supported - aborting"); 2243 llvm::report_fatal_error( 2244 "Native PDL Rewrite failed, but the pattern " 2245 "rewriter doesn't support recovery. Failable pattern rewrites should " 2246 "not be used with pattern rewriters that do not support them."); 2247 } 2248 return result; 2249 } 2250