1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements MLIR to byte-code generation and the interpreter. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "ByteCode.h" 14 #include "mlir/Analysis/Liveness.h" 15 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 16 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/RegionGraphTraits.h" 19 #include "llvm/ADT/IntervalMap.h" 20 #include "llvm/ADT/PostOrderIterator.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 #include "llvm/Support/Debug.h" 23 #include "llvm/Support/Format.h" 24 #include "llvm/Support/FormatVariadic.h" 25 #include <numeric> 26 27 #define DEBUG_TYPE "pdl-bytecode" 28 29 using namespace mlir; 30 using namespace mlir::detail; 31 32 //===----------------------------------------------------------------------===// 33 // PDLByteCodePattern 34 //===----------------------------------------------------------------------===// 35 36 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp, 37 PDLPatternConfigSet *configSet, 38 ByteCodeAddr rewriterAddr) { 39 PatternBenefit benefit = matchOp.getBenefit(); 40 MLIRContext *ctx = matchOp.getContext(); 41 42 // Collect the set of generated operations. 43 SmallVector<StringRef, 8> generatedOps; 44 if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr()) 45 generatedOps = 46 llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>()); 47 48 // Check to see if this is pattern matches a specific operation type. 49 if (Optional<StringRef> rootKind = matchOp.getRootKind()) 50 return PDLByteCodePattern(rewriterAddr, configSet, *rootKind, benefit, ctx, 51 generatedOps); 52 return PDLByteCodePattern(rewriterAddr, configSet, MatchAnyOpTypeTag(), 53 benefit, ctx, generatedOps); 54 } 55 56 //===----------------------------------------------------------------------===// 57 // PDLByteCodeMutableState 58 //===----------------------------------------------------------------------===// 59 60 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds 61 /// to the position of the pattern within the range returned by 62 /// `PDLByteCode::getPatterns`. 63 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex, 64 PatternBenefit benefit) { 65 currentPatternBenefits[patternIndex] = benefit; 66 } 67 68 /// Cleanup any allocated state after a full match/rewrite has been completed. 69 /// This method should be called irregardless of whether the match+rewrite was a 70 /// success or not. 71 void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() { 72 allocatedTypeRangeMemory.clear(); 73 allocatedValueRangeMemory.clear(); 74 } 75 76 //===----------------------------------------------------------------------===// 77 // Bytecode OpCodes 78 //===----------------------------------------------------------------------===// 79 80 namespace { 81 enum OpCode : ByteCodeField { 82 /// Apply an externally registered constraint. 83 ApplyConstraint, 84 /// Apply an externally registered rewrite. 85 ApplyRewrite, 86 /// Check if two generic values are equal. 87 AreEqual, 88 /// Check if two ranges are equal. 89 AreRangesEqual, 90 /// Unconditional branch. 91 Branch, 92 /// Compare the operand count of an operation with a constant. 93 CheckOperandCount, 94 /// Compare the name of an operation with a constant. 95 CheckOperationName, 96 /// Compare the result count of an operation with a constant. 97 CheckResultCount, 98 /// Compare a range of types to a constant range of types. 99 CheckTypes, 100 /// Continue to the next iteration of a loop. 101 Continue, 102 /// Create a type range from a list of constant types. 103 CreateConstantTypeRange, 104 /// Create an operation. 105 CreateOperation, 106 /// Create a type range from a list of dynamic types. 107 CreateDynamicTypeRange, 108 /// Create a value range. 109 CreateDynamicValueRange, 110 /// Erase an operation. 111 EraseOp, 112 /// Extract the op from a range at the specified index. 113 ExtractOp, 114 /// Extract the type from a range at the specified index. 115 ExtractType, 116 /// Extract the value from a range at the specified index. 117 ExtractValue, 118 /// Terminate a matcher or rewrite sequence. 119 Finalize, 120 /// Iterate over a range of values. 121 ForEach, 122 /// Get a specific attribute of an operation. 123 GetAttribute, 124 /// Get the type of an attribute. 125 GetAttributeType, 126 /// Get the defining operation of a value. 127 GetDefiningOp, 128 /// Get a specific operand of an operation. 129 GetOperand0, 130 GetOperand1, 131 GetOperand2, 132 GetOperand3, 133 GetOperandN, 134 /// Get a specific operand group of an operation. 135 GetOperands, 136 /// Get a specific result of an operation. 137 GetResult0, 138 GetResult1, 139 GetResult2, 140 GetResult3, 141 GetResultN, 142 /// Get a specific result group of an operation. 143 GetResults, 144 /// Get the users of a value or a range of values. 145 GetUsers, 146 /// Get the type of a value. 147 GetValueType, 148 /// Get the types of a value range. 149 GetValueRangeTypes, 150 /// Check if a generic value is not null. 151 IsNotNull, 152 /// Record a successful pattern match. 153 RecordMatch, 154 /// Replace an operation. 155 ReplaceOp, 156 /// Compare an attribute with a set of constants. 157 SwitchAttribute, 158 /// Compare the operand count of an operation with a set of constants. 159 SwitchOperandCount, 160 /// Compare the name of an operation with a set of constants. 161 SwitchOperationName, 162 /// Compare the result count of an operation with a set of constants. 163 SwitchResultCount, 164 /// Compare a type with a set of constants. 165 SwitchType, 166 /// Compare a range of types with a set of constants. 167 SwitchTypes, 168 }; 169 } // namespace 170 171 /// A marker used to indicate if an operation should infer types. 172 static constexpr ByteCodeField kInferTypesMarker = 173 std::numeric_limits<ByteCodeField>::max(); 174 175 //===----------------------------------------------------------------------===// 176 // ByteCode Generation 177 //===----------------------------------------------------------------------===// 178 179 //===----------------------------------------------------------------------===// 180 // Generator 181 182 namespace { 183 struct ByteCodeLiveRange; 184 struct ByteCodeWriter; 185 186 /// Check if the given class `T` can be converted to an opaque pointer. 187 template <typename T, typename... Args> 188 using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer()); 189 190 /// This class represents the main generator for the pattern bytecode. 191 class Generator { 192 public: 193 Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData, 194 SmallVectorImpl<ByteCodeField> &matcherByteCode, 195 SmallVectorImpl<ByteCodeField> &rewriterByteCode, 196 SmallVectorImpl<PDLByteCodePattern> &patterns, 197 ByteCodeField &maxValueMemoryIndex, 198 ByteCodeField &maxOpRangeMemoryIndex, 199 ByteCodeField &maxTypeRangeMemoryIndex, 200 ByteCodeField &maxValueRangeMemoryIndex, 201 ByteCodeField &maxLoopLevel, 202 llvm::StringMap<PDLConstraintFunction> &constraintFns, 203 llvm::StringMap<PDLRewriteFunction> &rewriteFns, 204 const DenseMap<Operation *, PDLPatternConfigSet *> &configMap) 205 : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), 206 rewriterByteCode(rewriterByteCode), patterns(patterns), 207 maxValueMemoryIndex(maxValueMemoryIndex), 208 maxOpRangeMemoryIndex(maxOpRangeMemoryIndex), 209 maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex), 210 maxValueRangeMemoryIndex(maxValueRangeMemoryIndex), 211 maxLoopLevel(maxLoopLevel), configMap(configMap) { 212 for (const auto &it : llvm::enumerate(constraintFns)) 213 constraintToMemIndex.try_emplace(it.value().first(), it.index()); 214 for (const auto &it : llvm::enumerate(rewriteFns)) 215 externalRewriterToMemIndex.try_emplace(it.value().first(), it.index()); 216 } 217 218 /// Generate the bytecode for the given PDL interpreter module. 219 void generate(ModuleOp module); 220 221 /// Return the memory index to use for the given value. 222 ByteCodeField &getMemIndex(Value value) { 223 assert(valueToMemIndex.count(value) && 224 "expected memory index to be assigned"); 225 return valueToMemIndex[value]; 226 } 227 228 /// Return the range memory index used to store the given range value. 229 ByteCodeField &getRangeStorageIndex(Value value) { 230 assert(valueToRangeIndex.count(value) && 231 "expected range index to be assigned"); 232 return valueToRangeIndex[value]; 233 } 234 235 /// Return an index to use when referring to the given data that is uniqued in 236 /// the MLIR context. 237 template <typename T> 238 std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &> 239 getMemIndex(T val) { 240 const void *opaqueVal = val.getAsOpaquePointer(); 241 242 // Get or insert a reference to this value. 243 auto it = uniquedDataToMemIndex.try_emplace( 244 opaqueVal, maxValueMemoryIndex + uniquedData.size()); 245 if (it.second) 246 uniquedData.push_back(opaqueVal); 247 return it.first->second; 248 } 249 250 private: 251 /// Allocate memory indices for the results of operations within the matcher 252 /// and rewriters. 253 void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc, 254 ModuleOp rewriterModule); 255 256 /// Generate the bytecode for the given operation. 257 void generate(Region *region, ByteCodeWriter &writer); 258 void generate(Operation *op, ByteCodeWriter &writer); 259 void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer); 260 void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer); 261 void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer); 262 void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer); 263 void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer); 264 void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer); 265 void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer); 266 void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer); 267 void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer); 268 void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer); 269 void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer); 270 void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); 271 void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); 272 void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer); 273 void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); 274 void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer); 275 void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer); 276 void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer); 277 void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer); 278 void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer); 279 void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer); 280 void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer); 281 void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer); 282 void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer); 283 void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer); 284 void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer); 285 void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer); 286 void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer); 287 void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer); 288 void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer); 289 void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer); 290 void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer); 291 void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer); 292 void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer); 293 void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer); 294 void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer); 295 void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer); 296 void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer); 297 298 /// Mapping from value to its corresponding memory index. 299 DenseMap<Value, ByteCodeField> valueToMemIndex; 300 301 /// Mapping from a range value to its corresponding range storage index. 302 DenseMap<Value, ByteCodeField> valueToRangeIndex; 303 304 /// Mapping from the name of an externally registered rewrite to its index in 305 /// the bytecode registry. 306 llvm::StringMap<ByteCodeField> externalRewriterToMemIndex; 307 308 /// Mapping from the name of an externally registered constraint to its index 309 /// in the bytecode registry. 310 llvm::StringMap<ByteCodeField> constraintToMemIndex; 311 312 /// Mapping from rewriter function name to the bytecode address of the 313 /// rewriter function in byte. 314 llvm::StringMap<ByteCodeAddr> rewriterToAddr; 315 316 /// Mapping from a uniqued storage object to its memory index within 317 /// `uniquedData`. 318 DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex; 319 320 /// The current level of the foreach loop. 321 ByteCodeField curLoopLevel = 0; 322 323 /// The current MLIR context. 324 MLIRContext *ctx; 325 326 /// Mapping from block to its address. 327 DenseMap<Block *, ByteCodeAddr> blockToAddr; 328 329 /// Data of the ByteCode class to be populated. 330 std::vector<const void *> &uniquedData; 331 SmallVectorImpl<ByteCodeField> &matcherByteCode; 332 SmallVectorImpl<ByteCodeField> &rewriterByteCode; 333 SmallVectorImpl<PDLByteCodePattern> &patterns; 334 ByteCodeField &maxValueMemoryIndex; 335 ByteCodeField &maxOpRangeMemoryIndex; 336 ByteCodeField &maxTypeRangeMemoryIndex; 337 ByteCodeField &maxValueRangeMemoryIndex; 338 ByteCodeField &maxLoopLevel; 339 340 /// A map of pattern configurations. 341 const DenseMap<Operation *, PDLPatternConfigSet *> &configMap; 342 }; 343 344 /// This class provides utilities for writing a bytecode stream. 345 struct ByteCodeWriter { 346 ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator) 347 : bytecode(bytecode), generator(generator) {} 348 349 /// Append a field to the bytecode. 350 void append(ByteCodeField field) { bytecode.push_back(field); } 351 void append(OpCode opCode) { bytecode.push_back(opCode); } 352 353 /// Append an address to the bytecode. 354 void append(ByteCodeAddr field) { 355 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 356 "unexpected ByteCode address size"); 357 358 ByteCodeField fieldParts[2]; 359 std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr)); 360 bytecode.append({fieldParts[0], fieldParts[1]}); 361 } 362 363 /// Append a single successor to the bytecode, the exact address will need to 364 /// be resolved later. 365 void append(Block *successor) { 366 // Add back a reference to the successor so that the address can be resolved 367 // later. 368 unresolvedSuccessorRefs[successor].push_back(bytecode.size()); 369 append(ByteCodeAddr(0)); 370 } 371 372 /// Append a successor range to the bytecode, the exact address will need to 373 /// be resolved later. 374 void append(SuccessorRange successors) { 375 for (Block *successor : successors) 376 append(successor); 377 } 378 379 /// Append a range of values that will be read as generic PDLValues. 380 void appendPDLValueList(OperandRange values) { 381 bytecode.push_back(values.size()); 382 for (Value value : values) 383 appendPDLValue(value); 384 } 385 386 /// Append a value as a PDLValue. 387 void appendPDLValue(Value value) { 388 appendPDLValueKind(value); 389 append(value); 390 } 391 392 /// Append the PDLValue::Kind of the given value. 393 void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); } 394 395 /// Append the PDLValue::Kind of the given type. 396 void appendPDLValueKind(Type type) { 397 PDLValue::Kind kind = 398 TypeSwitch<Type, PDLValue::Kind>(type) 399 .Case<pdl::AttributeType>( 400 [](Type) { return PDLValue::Kind::Attribute; }) 401 .Case<pdl::OperationType>( 402 [](Type) { return PDLValue::Kind::Operation; }) 403 .Case<pdl::RangeType>([](pdl::RangeType rangeTy) { 404 if (rangeTy.getElementType().isa<pdl::TypeType>()) 405 return PDLValue::Kind::TypeRange; 406 return PDLValue::Kind::ValueRange; 407 }) 408 .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; }) 409 .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; }); 410 bytecode.push_back(static_cast<ByteCodeField>(kind)); 411 } 412 413 /// Append a value that will be stored in a memory slot and not inline within 414 /// the bytecode. 415 template <typename T> 416 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value || 417 std::is_pointer<T>::value> 418 append(T value) { 419 bytecode.push_back(generator.getMemIndex(value)); 420 } 421 422 /// Append a range of values. 423 template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>> 424 std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value> 425 append(T range) { 426 bytecode.push_back(llvm::size(range)); 427 for (auto it : range) 428 append(it); 429 } 430 431 /// Append a variadic number of fields to the bytecode. 432 template <typename FieldTy, typename Field2Ty, typename... FieldTys> 433 void append(FieldTy field, Field2Ty field2, FieldTys... fields) { 434 append(field); 435 append(field2, fields...); 436 } 437 438 /// Appends a value as a pointer, stored inline within the bytecode. 439 template <typename T> 440 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value> 441 appendInline(T value) { 442 constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField); 443 const void *pointer = value.getAsOpaquePointer(); 444 ByteCodeField fieldParts[numParts]; 445 std::memcpy(fieldParts, &pointer, sizeof(const void *)); 446 bytecode.append(fieldParts, fieldParts + numParts); 447 } 448 449 /// Successor references in the bytecode that have yet to be resolved. 450 DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs; 451 452 /// The underlying bytecode buffer. 453 SmallVectorImpl<ByteCodeField> &bytecode; 454 455 /// The main generator producing PDL. 456 Generator &generator; 457 }; 458 459 /// This class represents a live range of PDL Interpreter values, containing 460 /// information about when values are live within a match/rewrite. 461 struct ByteCodeLiveRange { 462 using Set = llvm::IntervalMap<uint64_t, char, 16>; 463 using Allocator = Set::Allocator; 464 465 ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {} 466 467 /// Union this live range with the one provided. 468 void unionWith(const ByteCodeLiveRange &rhs) { 469 for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e; 470 ++it) 471 liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0); 472 } 473 474 /// Returns true if this range overlaps with the one provided. 475 bool overlaps(const ByteCodeLiveRange &rhs) const { 476 return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness) 477 .valid(); 478 } 479 480 /// A map representing the ranges of the match/rewrite that a value is live in 481 /// the interpreter. 482 /// 483 /// We use std::unique_ptr here, because IntervalMap does not provide a 484 /// correct copy or move constructor. We can eliminate the pointer once 485 /// https://reviews.llvm.org/D113240 lands. 486 std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness; 487 488 /// The operation range storage index for this range. 489 Optional<unsigned> opRangeIndex; 490 491 /// The type range storage index for this range. 492 Optional<unsigned> typeRangeIndex; 493 494 /// The value range storage index for this range. 495 Optional<unsigned> valueRangeIndex; 496 }; 497 } // namespace 498 499 void Generator::generate(ModuleOp module) { 500 auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>( 501 pdl_interp::PDLInterpDialect::getMatcherFunctionName()); 502 ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>( 503 pdl_interp::PDLInterpDialect::getRewriterModuleName()); 504 assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module"); 505 506 // Allocate memory indices for the results of operations within the matcher 507 // and rewriters. 508 allocateMemoryIndices(matcherFunc, rewriterModule); 509 510 // Generate code for the rewriter functions. 511 ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this); 512 for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) { 513 rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size()); 514 for (Operation &op : rewriterFunc.getOps()) 515 generate(&op, rewriterByteCodeWriter); 516 } 517 assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() && 518 "unexpected branches in rewriter function"); 519 520 // Generate code for the matcher function. 521 ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this); 522 generate(&matcherFunc.getBody(), matcherByteCodeWriter); 523 524 // Resolve successor references in the matcher. 525 for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) { 526 ByteCodeAddr addr = blockToAddr[it.first]; 527 for (unsigned offsetToFix : it.second) 528 std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr)); 529 } 530 } 531 532 void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc, 533 ModuleOp rewriterModule) { 534 // Rewriters use simplistic allocation scheme that simply assigns an index to 535 // each result. 536 for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) { 537 ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0; 538 auto processRewriterValue = [&](Value val) { 539 valueToMemIndex.try_emplace(val, index++); 540 if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) { 541 Type elementTy = rangeType.getElementType(); 542 if (elementTy.isa<pdl::TypeType>()) 543 valueToRangeIndex.try_emplace(val, typeRangeIndex++); 544 else if (elementTy.isa<pdl::ValueType>()) 545 valueToRangeIndex.try_emplace(val, valueRangeIndex++); 546 } 547 }; 548 549 for (BlockArgument arg : rewriterFunc.getArguments()) 550 processRewriterValue(arg); 551 rewriterFunc.getBody().walk([&](Operation *op) { 552 for (Value result : op->getResults()) 553 processRewriterValue(result); 554 }); 555 if (index > maxValueMemoryIndex) 556 maxValueMemoryIndex = index; 557 if (typeRangeIndex > maxTypeRangeMemoryIndex) 558 maxTypeRangeMemoryIndex = typeRangeIndex; 559 if (valueRangeIndex > maxValueRangeMemoryIndex) 560 maxValueRangeMemoryIndex = valueRangeIndex; 561 } 562 563 // The matcher function uses a more sophisticated numbering that tries to 564 // minimize the number of memory indices assigned. This is done by determining 565 // a live range of the values within the matcher, then the allocation is just 566 // finding the minimal number of overlapping live ranges. This is essentially 567 // a simplified form of register allocation where we don't necessarily have a 568 // limited number of registers, but we still want to minimize the number used. 569 DenseMap<Operation *, unsigned> opToFirstIndex; 570 DenseMap<Operation *, unsigned> opToLastIndex; 571 572 // A custom walk that marks the first and the last index of each operation. 573 // The entry marks the beginning of the liveness range for this operation, 574 // followed by nested operations, followed by the end of the liveness range. 575 unsigned index = 0; 576 llvm::unique_function<void(Operation *)> walk = [&](Operation *op) { 577 opToFirstIndex.try_emplace(op, index++); 578 for (Region ®ion : op->getRegions()) 579 for (Block &block : region.getBlocks()) 580 for (Operation &nested : block) 581 walk(&nested); 582 opToLastIndex.try_emplace(op, index++); 583 }; 584 walk(matcherFunc); 585 586 // Liveness info for each of the defs within the matcher. 587 ByteCodeLiveRange::Allocator allocator; 588 DenseMap<Value, ByteCodeLiveRange> valueDefRanges; 589 590 // Assign the root operation being matched to slot 0. 591 BlockArgument rootOpArg = matcherFunc.getArgument(0); 592 valueToMemIndex[rootOpArg] = 0; 593 594 // Walk each of the blocks, computing the def interval that the value is used. 595 Liveness matcherLiveness(matcherFunc); 596 matcherFunc->walk([&](Block *block) { 597 const LivenessBlockInfo *info = matcherLiveness.getLiveness(block); 598 assert(info && "expected liveness info for block"); 599 auto processValue = [&](Value value, Operation *firstUseOrDef) { 600 // We don't need to process the root op argument, this value is always 601 // assigned to the first memory slot. 602 if (value == rootOpArg) 603 return; 604 605 // Set indices for the range of this block that the value is used. 606 auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first; 607 defRangeIt->second.liveness->insert( 608 opToFirstIndex[firstUseOrDef], 609 opToLastIndex[info->getEndOperation(value, firstUseOrDef)], 610 /*dummyValue*/ 0); 611 612 // Check to see if this value is a range type. 613 if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) { 614 Type eleType = rangeTy.getElementType(); 615 if (eleType.isa<pdl::OperationType>()) 616 defRangeIt->second.opRangeIndex = 0; 617 else if (eleType.isa<pdl::TypeType>()) 618 defRangeIt->second.typeRangeIndex = 0; 619 else if (eleType.isa<pdl::ValueType>()) 620 defRangeIt->second.valueRangeIndex = 0; 621 } 622 }; 623 624 // Process the live-ins of this block. 625 for (Value liveIn : info->in()) { 626 // Only process the value if it has been defined in the current region. 627 // Other values that span across pdl_interp.foreach will be added higher 628 // up. This ensures that the we keep them alive for the entire duration 629 // of the loop. 630 if (liveIn.getParentRegion() == block->getParent()) 631 processValue(liveIn, &block->front()); 632 } 633 634 // Process the block arguments for the entry block (those are not live-in). 635 if (block->isEntryBlock()) { 636 for (Value argument : block->getArguments()) 637 processValue(argument, &block->front()); 638 } 639 640 // Process any new defs within this block. 641 for (Operation &op : *block) 642 for (Value result : op.getResults()) 643 processValue(result, &op); 644 }); 645 646 // Greedily allocate memory slots using the computed def live ranges. 647 std::vector<ByteCodeLiveRange> allocatedIndices; 648 649 // The number of memory indices currently allocated (and its next value). 650 // Recall that the root gets allocated memory index 0. 651 ByteCodeField numIndices = 1; 652 653 // The number of memory ranges of various types (and their next values). 654 ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0; 655 656 for (auto &defIt : valueDefRanges) { 657 ByteCodeField &memIndex = valueToMemIndex[defIt.first]; 658 ByteCodeLiveRange &defRange = defIt.second; 659 660 // Try to allocate to an existing index. 661 for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) { 662 ByteCodeLiveRange &existingRange = existingIndexIt.value(); 663 if (!defRange.overlaps(existingRange)) { 664 existingRange.unionWith(defRange); 665 memIndex = existingIndexIt.index() + 1; 666 667 if (defRange.opRangeIndex) { 668 if (!existingRange.opRangeIndex) 669 existingRange.opRangeIndex = numOpRanges++; 670 valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex; 671 } else if (defRange.typeRangeIndex) { 672 if (!existingRange.typeRangeIndex) 673 existingRange.typeRangeIndex = numTypeRanges++; 674 valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex; 675 } else if (defRange.valueRangeIndex) { 676 if (!existingRange.valueRangeIndex) 677 existingRange.valueRangeIndex = numValueRanges++; 678 valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex; 679 } 680 break; 681 } 682 } 683 684 // If no existing index could be used, add a new one. 685 if (memIndex == 0) { 686 allocatedIndices.emplace_back(allocator); 687 ByteCodeLiveRange &newRange = allocatedIndices.back(); 688 newRange.unionWith(defRange); 689 690 // Allocate an index for op/type/value ranges. 691 if (defRange.opRangeIndex) { 692 newRange.opRangeIndex = numOpRanges; 693 valueToRangeIndex[defIt.first] = numOpRanges++; 694 } else if (defRange.typeRangeIndex) { 695 newRange.typeRangeIndex = numTypeRanges; 696 valueToRangeIndex[defIt.first] = numTypeRanges++; 697 } else if (defRange.valueRangeIndex) { 698 newRange.valueRangeIndex = numValueRanges; 699 valueToRangeIndex[defIt.first] = numValueRanges++; 700 } 701 702 memIndex = allocatedIndices.size(); 703 ++numIndices; 704 } 705 } 706 707 // Print the index usage and ensure that we did not run out of index space. 708 LLVM_DEBUG({ 709 llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices " 710 << "(down from initial " << valueDefRanges.size() << ").\n"; 711 }); 712 assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() && 713 "Ran out of memory for allocated indices"); 714 715 // Update the max number of indices. 716 if (numIndices > maxValueMemoryIndex) 717 maxValueMemoryIndex = numIndices; 718 if (numOpRanges > maxOpRangeMemoryIndex) 719 maxOpRangeMemoryIndex = numOpRanges; 720 if (numTypeRanges > maxTypeRangeMemoryIndex) 721 maxTypeRangeMemoryIndex = numTypeRanges; 722 if (numValueRanges > maxValueRangeMemoryIndex) 723 maxValueRangeMemoryIndex = numValueRanges; 724 } 725 726 void Generator::generate(Region *region, ByteCodeWriter &writer) { 727 llvm::ReversePostOrderTraversal<Region *> rpot(region); 728 for (Block *block : rpot) { 729 // Keep track of where this block begins within the matcher function. 730 blockToAddr.try_emplace(block, matcherByteCode.size()); 731 for (Operation &op : *block) 732 generate(&op, writer); 733 } 734 } 735 736 void Generator::generate(Operation *op, ByteCodeWriter &writer) { 737 LLVM_DEBUG({ 738 // The following list must contain all the operations that do not 739 // produce any bytecode. 740 if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op)) 741 writer.appendInline(op->getLoc()); 742 }); 743 TypeSwitch<Operation *>(op) 744 .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp, 745 pdl_interp::AreEqualOp, pdl_interp::BranchOp, 746 pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp, 747 pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp, 748 pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp, 749 pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp, 750 pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp, 751 pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp, 752 pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp, 753 pdl_interp::ForEachOp, pdl_interp::GetAttributeOp, 754 pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp, 755 pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp, 756 pdl_interp::GetResultOp, pdl_interp::GetResultsOp, 757 pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp, 758 pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp, 759 pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp, 760 pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp, 761 pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp, 762 pdl_interp::SwitchResultCountOp>( 763 [&](auto interpOp) { this->generate(interpOp, writer); }) 764 .Default([](Operation *) { 765 llvm_unreachable("unknown `pdl_interp` operation"); 766 }); 767 } 768 769 void Generator::generate(pdl_interp::ApplyConstraintOp op, 770 ByteCodeWriter &writer) { 771 assert(constraintToMemIndex.count(op.getName()) && 772 "expected index for constraint function"); 773 writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]); 774 writer.appendPDLValueList(op.getArgs()); 775 writer.append(op.getSuccessors()); 776 } 777 void Generator::generate(pdl_interp::ApplyRewriteOp op, 778 ByteCodeWriter &writer) { 779 assert(externalRewriterToMemIndex.count(op.getName()) && 780 "expected index for rewrite function"); 781 writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]); 782 writer.appendPDLValueList(op.getArgs()); 783 784 ResultRange results = op.getResults(); 785 writer.append(ByteCodeField(results.size())); 786 for (Value result : results) { 787 // In debug mode we also record the expected kind of the result, so that we 788 // can provide extra verification of the native rewrite function. 789 #ifndef NDEBUG 790 writer.appendPDLValueKind(result); 791 #endif 792 793 // Range results also need to append the range storage index. 794 if (result.getType().isa<pdl::RangeType>()) 795 writer.append(getRangeStorageIndex(result)); 796 writer.append(result); 797 } 798 } 799 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) { 800 Value lhs = op.getLhs(); 801 if (lhs.getType().isa<pdl::RangeType>()) { 802 writer.append(OpCode::AreRangesEqual); 803 writer.appendPDLValueKind(lhs); 804 writer.append(op.getLhs(), op.getRhs(), op.getSuccessors()); 805 return; 806 } 807 808 writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors()); 809 } 810 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) { 811 writer.append(OpCode::Branch, SuccessorRange(op.getOperation())); 812 } 813 void Generator::generate(pdl_interp::CheckAttributeOp op, 814 ByteCodeWriter &writer) { 815 writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(), 816 op.getSuccessors()); 817 } 818 void Generator::generate(pdl_interp::CheckOperandCountOp op, 819 ByteCodeWriter &writer) { 820 writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(), 821 static_cast<ByteCodeField>(op.getCompareAtLeast()), 822 op.getSuccessors()); 823 } 824 void Generator::generate(pdl_interp::CheckOperationNameOp op, 825 ByteCodeWriter &writer) { 826 writer.append(OpCode::CheckOperationName, op.getInputOp(), 827 OperationName(op.getName(), ctx), op.getSuccessors()); 828 } 829 void Generator::generate(pdl_interp::CheckResultCountOp op, 830 ByteCodeWriter &writer) { 831 writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(), 832 static_cast<ByteCodeField>(op.getCompareAtLeast()), 833 op.getSuccessors()); 834 } 835 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) { 836 writer.append(OpCode::AreEqual, op.getValue(), op.getType(), 837 op.getSuccessors()); 838 } 839 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) { 840 writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(), 841 op.getSuccessors()); 842 } 843 void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) { 844 assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level"); 845 writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1)); 846 } 847 void Generator::generate(pdl_interp::CreateAttributeOp op, 848 ByteCodeWriter &writer) { 849 // Simply repoint the memory index of the result to the constant. 850 getMemIndex(op.getAttribute()) = getMemIndex(op.getValue()); 851 } 852 void Generator::generate(pdl_interp::CreateOperationOp op, 853 ByteCodeWriter &writer) { 854 writer.append(OpCode::CreateOperation, op.getResultOp(), 855 OperationName(op.getName(), ctx)); 856 writer.appendPDLValueList(op.getInputOperands()); 857 858 // Add the attributes. 859 OperandRange attributes = op.getInputAttributes(); 860 writer.append(static_cast<ByteCodeField>(attributes.size())); 861 for (auto it : llvm::zip(op.getInputAttributeNames(), attributes)) 862 writer.append(std::get<0>(it), std::get<1>(it)); 863 864 // Add the result types. If the operation has inferred results, we use a 865 // marker "size" value. Otherwise, we add the list of explicit result types. 866 if (op.getInferredResultTypes()) 867 writer.append(kInferTypesMarker); 868 else 869 writer.appendPDLValueList(op.getInputResultTypes()); 870 } 871 void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) { 872 // Append the correct opcode for the range type. 873 TypeSwitch<Type>(op.getType().getElementType()) 874 .Case( 875 [&](pdl::TypeType) { writer.append(OpCode::CreateDynamicTypeRange); }) 876 .Case([&](pdl::ValueType) { 877 writer.append(OpCode::CreateDynamicValueRange); 878 }); 879 880 writer.append(op.getResult(), getRangeStorageIndex(op.getResult())); 881 writer.appendPDLValueList(op->getOperands()); 882 } 883 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) { 884 // Simply repoint the memory index of the result to the constant. 885 getMemIndex(op.getResult()) = getMemIndex(op.getValue()); 886 } 887 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) { 888 writer.append(OpCode::CreateConstantTypeRange, op.getResult(), 889 getRangeStorageIndex(op.getResult()), op.getValue()); 890 } 891 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) { 892 writer.append(OpCode::EraseOp, op.getInputOp()); 893 } 894 void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) { 895 OpCode opCode = 896 TypeSwitch<Type, OpCode>(op.getResult().getType()) 897 .Case([](pdl::OperationType) { return OpCode::ExtractOp; }) 898 .Case([](pdl::ValueType) { return OpCode::ExtractValue; }) 899 .Case([](pdl::TypeType) { return OpCode::ExtractType; }) 900 .Default([](Type) -> OpCode { 901 llvm_unreachable("unsupported element type"); 902 }); 903 writer.append(opCode, op.getRange(), op.getIndex(), op.getResult()); 904 } 905 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) { 906 writer.append(OpCode::Finalize); 907 } 908 void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) { 909 BlockArgument arg = op.getLoopVariable(); 910 writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg); 911 writer.appendPDLValueKind(arg.getType()); 912 writer.append(curLoopLevel, op.getSuccessor()); 913 ++curLoopLevel; 914 if (curLoopLevel > maxLoopLevel) 915 maxLoopLevel = curLoopLevel; 916 generate(&op.getRegion(), writer); 917 --curLoopLevel; 918 } 919 void Generator::generate(pdl_interp::GetAttributeOp op, 920 ByteCodeWriter &writer) { 921 writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(), 922 op.getNameAttr()); 923 } 924 void Generator::generate(pdl_interp::GetAttributeTypeOp op, 925 ByteCodeWriter &writer) { 926 writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue()); 927 } 928 void Generator::generate(pdl_interp::GetDefiningOpOp op, 929 ByteCodeWriter &writer) { 930 writer.append(OpCode::GetDefiningOp, op.getInputOp()); 931 writer.appendPDLValue(op.getValue()); 932 } 933 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) { 934 uint32_t index = op.getIndex(); 935 if (index < 4) 936 writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index)); 937 else 938 writer.append(OpCode::GetOperandN, index); 939 writer.append(op.getInputOp(), op.getValue()); 940 } 941 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) { 942 Value result = op.getValue(); 943 Optional<uint32_t> index = op.getIndex(); 944 writer.append(OpCode::GetOperands, 945 index.value_or(std::numeric_limits<uint32_t>::max()), 946 op.getInputOp()); 947 if (result.getType().isa<pdl::RangeType>()) 948 writer.append(getRangeStorageIndex(result)); 949 else 950 writer.append(std::numeric_limits<ByteCodeField>::max()); 951 writer.append(result); 952 } 953 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) { 954 uint32_t index = op.getIndex(); 955 if (index < 4) 956 writer.append(static_cast<OpCode>(OpCode::GetResult0 + index)); 957 else 958 writer.append(OpCode::GetResultN, index); 959 writer.append(op.getInputOp(), op.getValue()); 960 } 961 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) { 962 Value result = op.getValue(); 963 Optional<uint32_t> index = op.getIndex(); 964 writer.append(OpCode::GetResults, 965 index.value_or(std::numeric_limits<uint32_t>::max()), 966 op.getInputOp()); 967 if (result.getType().isa<pdl::RangeType>()) 968 writer.append(getRangeStorageIndex(result)); 969 else 970 writer.append(std::numeric_limits<ByteCodeField>::max()); 971 writer.append(result); 972 } 973 void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) { 974 Value operations = op.getOperations(); 975 ByteCodeField rangeIndex = getRangeStorageIndex(operations); 976 writer.append(OpCode::GetUsers, operations, rangeIndex); 977 writer.appendPDLValue(op.getValue()); 978 } 979 void Generator::generate(pdl_interp::GetValueTypeOp op, 980 ByteCodeWriter &writer) { 981 if (op.getType().isa<pdl::RangeType>()) { 982 Value result = op.getResult(); 983 writer.append(OpCode::GetValueRangeTypes, result, 984 getRangeStorageIndex(result), op.getValue()); 985 } else { 986 writer.append(OpCode::GetValueType, op.getResult(), op.getValue()); 987 } 988 } 989 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) { 990 writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors()); 991 } 992 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) { 993 ByteCodeField patternIndex = patterns.size(); 994 patterns.emplace_back(PDLByteCodePattern::create( 995 op, configMap.lookup(op), 996 rewriterToAddr[op.getRewriter().getLeafReference().getValue()])); 997 writer.append(OpCode::RecordMatch, patternIndex, 998 SuccessorRange(op.getOperation()), op.getMatchedOps()); 999 writer.appendPDLValueList(op.getInputs()); 1000 } 1001 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) { 1002 writer.append(OpCode::ReplaceOp, op.getInputOp()); 1003 writer.appendPDLValueList(op.getReplValues()); 1004 } 1005 void Generator::generate(pdl_interp::SwitchAttributeOp op, 1006 ByteCodeWriter &writer) { 1007 writer.append(OpCode::SwitchAttribute, op.getAttribute(), 1008 op.getCaseValuesAttr(), op.getSuccessors()); 1009 } 1010 void Generator::generate(pdl_interp::SwitchOperandCountOp op, 1011 ByteCodeWriter &writer) { 1012 writer.append(OpCode::SwitchOperandCount, op.getInputOp(), 1013 op.getCaseValuesAttr(), op.getSuccessors()); 1014 } 1015 void Generator::generate(pdl_interp::SwitchOperationNameOp op, 1016 ByteCodeWriter &writer) { 1017 auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) { 1018 return OperationName(attr.cast<StringAttr>().getValue(), ctx); 1019 }); 1020 writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases, 1021 op.getSuccessors()); 1022 } 1023 void Generator::generate(pdl_interp::SwitchResultCountOp op, 1024 ByteCodeWriter &writer) { 1025 writer.append(OpCode::SwitchResultCount, op.getInputOp(), 1026 op.getCaseValuesAttr(), op.getSuccessors()); 1027 } 1028 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) { 1029 writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(), 1030 op.getSuccessors()); 1031 } 1032 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) { 1033 writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(), 1034 op.getSuccessors()); 1035 } 1036 1037 //===----------------------------------------------------------------------===// 1038 // PDLByteCode 1039 //===----------------------------------------------------------------------===// 1040 1041 PDLByteCode::PDLByteCode( 1042 ModuleOp module, SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs, 1043 const DenseMap<Operation *, PDLPatternConfigSet *> &configMap, 1044 llvm::StringMap<PDLConstraintFunction> constraintFns, 1045 llvm::StringMap<PDLRewriteFunction> rewriteFns) 1046 : configs(std::move(configs)) { 1047 Generator generator(module.getContext(), uniquedData, matcherByteCode, 1048 rewriterByteCode, patterns, maxValueMemoryIndex, 1049 maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount, 1050 maxLoopLevel, constraintFns, rewriteFns, configMap); 1051 generator.generate(module); 1052 1053 // Initialize the external functions. 1054 for (auto &it : constraintFns) 1055 constraintFunctions.push_back(std::move(it.second)); 1056 for (auto &it : rewriteFns) 1057 rewriteFunctions.push_back(std::move(it.second)); 1058 } 1059 1060 /// Initialize the given state such that it can be used to execute the current 1061 /// bytecode. 1062 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const { 1063 state.memory.resize(maxValueMemoryIndex, nullptr); 1064 state.opRangeMemory.resize(maxOpRangeCount); 1065 state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange()); 1066 state.valueRangeMemory.resize(maxValueRangeCount, ValueRange()); 1067 state.loopIndex.resize(maxLoopLevel, 0); 1068 state.currentPatternBenefits.reserve(patterns.size()); 1069 for (const PDLByteCodePattern &pattern : patterns) 1070 state.currentPatternBenefits.push_back(pattern.getBenefit()); 1071 } 1072 1073 //===----------------------------------------------------------------------===// 1074 // ByteCode Execution 1075 1076 namespace { 1077 /// This class provides support for executing a bytecode stream. 1078 class ByteCodeExecutor { 1079 public: 1080 ByteCodeExecutor( 1081 const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory, 1082 MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory, 1083 MutableArrayRef<TypeRange> typeRangeMemory, 1084 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory, 1085 MutableArrayRef<ValueRange> valueRangeMemory, 1086 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory, 1087 MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory, 1088 ArrayRef<ByteCodeField> code, 1089 ArrayRef<PatternBenefit> currentPatternBenefits, 1090 ArrayRef<PDLByteCodePattern> patterns, 1091 ArrayRef<PDLConstraintFunction> constraintFunctions, 1092 ArrayRef<PDLRewriteFunction> rewriteFunctions) 1093 : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory), 1094 typeRangeMemory(typeRangeMemory), 1095 allocatedTypeRangeMemory(allocatedTypeRangeMemory), 1096 valueRangeMemory(valueRangeMemory), 1097 allocatedValueRangeMemory(allocatedValueRangeMemory), 1098 loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code), 1099 currentPatternBenefits(currentPatternBenefits), patterns(patterns), 1100 constraintFunctions(constraintFunctions), 1101 rewriteFunctions(rewriteFunctions) {} 1102 1103 /// Start executing the code at the current bytecode index. `matches` is an 1104 /// optional field provided when this function is executed in a matching 1105 /// context. 1106 LogicalResult 1107 execute(PatternRewriter &rewriter, 1108 SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr, 1109 Optional<Location> mainRewriteLoc = {}); 1110 1111 private: 1112 /// Internal implementation of executing each of the bytecode commands. 1113 void executeApplyConstraint(PatternRewriter &rewriter); 1114 LogicalResult executeApplyRewrite(PatternRewriter &rewriter); 1115 void executeAreEqual(); 1116 void executeAreRangesEqual(); 1117 void executeBranch(); 1118 void executeCheckOperandCount(); 1119 void executeCheckOperationName(); 1120 void executeCheckResultCount(); 1121 void executeCheckTypes(); 1122 void executeContinue(); 1123 void executeCreateConstantTypeRange(); 1124 void executeCreateOperation(PatternRewriter &rewriter, 1125 Location mainRewriteLoc); 1126 template <typename T> 1127 void executeDynamicCreateRange(StringRef type); 1128 void executeEraseOp(PatternRewriter &rewriter); 1129 template <typename T, typename Range, PDLValue::Kind kind> 1130 void executeExtract(); 1131 void executeFinalize(); 1132 void executeForEach(); 1133 void executeGetAttribute(); 1134 void executeGetAttributeType(); 1135 void executeGetDefiningOp(); 1136 void executeGetOperand(unsigned index); 1137 void executeGetOperands(); 1138 void executeGetResult(unsigned index); 1139 void executeGetResults(); 1140 void executeGetUsers(); 1141 void executeGetValueType(); 1142 void executeGetValueRangeTypes(); 1143 void executeIsNotNull(); 1144 void executeRecordMatch(PatternRewriter &rewriter, 1145 SmallVectorImpl<PDLByteCode::MatchResult> &matches); 1146 void executeReplaceOp(PatternRewriter &rewriter); 1147 void executeSwitchAttribute(); 1148 void executeSwitchOperandCount(); 1149 void executeSwitchOperationName(); 1150 void executeSwitchResultCount(); 1151 void executeSwitchType(); 1152 void executeSwitchTypes(); 1153 1154 /// Pushes a code iterator to the stack. 1155 void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); } 1156 1157 /// Pops a code iterator from the stack, returning true on success. 1158 void popCodeIt() { 1159 assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack"); 1160 curCodeIt = resumeCodeIt.back(); 1161 resumeCodeIt.pop_back(); 1162 } 1163 1164 /// Return the bytecode iterator at the start of the current op code. 1165 const ByteCodeField *getPrevCodeIt() const { 1166 LLVM_DEBUG({ 1167 // Account for the op code and the Location stored inline. 1168 return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField); 1169 }); 1170 1171 // Account for the op code only. 1172 return curCodeIt - 1; 1173 } 1174 1175 /// Read a value from the bytecode buffer, optionally skipping a certain 1176 /// number of prefix values. These methods always update the buffer to point 1177 /// to the next field after the read data. 1178 template <typename T = ByteCodeField> 1179 T read(size_t skipN = 0) { 1180 curCodeIt += skipN; 1181 return readImpl<T>(); 1182 } 1183 ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); } 1184 1185 /// Read a list of values from the bytecode buffer. 1186 template <typename ValueT, typename T> 1187 void readList(SmallVectorImpl<T> &list) { 1188 list.clear(); 1189 for (unsigned i = 0, e = read(); i != e; ++i) 1190 list.push_back(read<ValueT>()); 1191 } 1192 1193 /// Read a list of values from the bytecode buffer. The values may be encoded 1194 /// either as a single element or a range of elements. 1195 void readList(SmallVectorImpl<Type> &list) { 1196 for (unsigned i = 0, e = read(); i != e; ++i) { 1197 if (read<PDLValue::Kind>() == PDLValue::Kind::Type) { 1198 list.push_back(read<Type>()); 1199 } else { 1200 TypeRange *values = read<TypeRange *>(); 1201 list.append(values->begin(), values->end()); 1202 } 1203 } 1204 } 1205 void readList(SmallVectorImpl<Value> &list) { 1206 for (unsigned i = 0, e = read(); i != e; ++i) { 1207 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1208 list.push_back(read<Value>()); 1209 } else { 1210 ValueRange *values = read<ValueRange *>(); 1211 list.append(values->begin(), values->end()); 1212 } 1213 } 1214 } 1215 1216 /// Read a value stored inline as a pointer. 1217 template <typename T> 1218 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T> 1219 readInline() { 1220 const void *pointer; 1221 std::memcpy(&pointer, curCodeIt, sizeof(const void *)); 1222 curCodeIt += sizeof(const void *) / sizeof(ByteCodeField); 1223 return T::getFromOpaquePointer(pointer); 1224 } 1225 1226 /// Jump to a specific successor based on a predicate value. 1227 void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); } 1228 /// Jump to a specific successor based on a destination index. 1229 void selectJump(size_t destIndex) { 1230 curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)]; 1231 } 1232 1233 /// Handle a switch operation with the provided value and cases. 1234 template <typename T, typename RangeT, typename Comparator = std::equal_to<T>> 1235 void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) { 1236 LLVM_DEBUG({ 1237 llvm::dbgs() << " * Value: " << value << "\n" 1238 << " * Cases: "; 1239 llvm::interleaveComma(cases, llvm::dbgs()); 1240 llvm::dbgs() << "\n"; 1241 }); 1242 1243 // Check to see if the attribute value is within the case list. Jump to 1244 // the correct successor index based on the result. 1245 for (auto it = cases.begin(), e = cases.end(); it != e; ++it) 1246 if (cmp(*it, value)) 1247 return selectJump(size_t((it - cases.begin()) + 1)); 1248 selectJump(size_t(0)); 1249 } 1250 1251 /// Store a pointer to memory. 1252 void storeToMemory(unsigned index, const void *value) { 1253 memory[index] = value; 1254 } 1255 1256 /// Store a value to memory as an opaque pointer. 1257 template <typename T> 1258 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value> 1259 storeToMemory(unsigned index, T value) { 1260 memory[index] = value.getAsOpaquePointer(); 1261 } 1262 1263 /// Internal implementation of reading various data types from the bytecode 1264 /// stream. 1265 template <typename T> 1266 const void *readFromMemory() { 1267 size_t index = *curCodeIt++; 1268 1269 // If this type is an SSA value, it can only be stored in non-const memory. 1270 if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *, 1271 Value>::value || 1272 index < memory.size()) 1273 return memory[index]; 1274 1275 // Otherwise, if this index is not inbounds it is uniqued. 1276 return uniquedMemory[index - memory.size()]; 1277 } 1278 template <typename T> 1279 std::enable_if_t<std::is_pointer<T>::value, T> readImpl() { 1280 return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>())); 1281 } 1282 template <typename T> 1283 std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value, 1284 T> 1285 readImpl() { 1286 return T(T::getFromOpaquePointer(readFromMemory<T>())); 1287 } 1288 template <typename T> 1289 std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() { 1290 switch (read<PDLValue::Kind>()) { 1291 case PDLValue::Kind::Attribute: 1292 return read<Attribute>(); 1293 case PDLValue::Kind::Operation: 1294 return read<Operation *>(); 1295 case PDLValue::Kind::Type: 1296 return read<Type>(); 1297 case PDLValue::Kind::Value: 1298 return read<Value>(); 1299 case PDLValue::Kind::TypeRange: 1300 return read<TypeRange *>(); 1301 case PDLValue::Kind::ValueRange: 1302 return read<ValueRange *>(); 1303 } 1304 llvm_unreachable("unhandled PDLValue::Kind"); 1305 } 1306 template <typename T> 1307 std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() { 1308 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 1309 "unexpected ByteCode address size"); 1310 ByteCodeAddr result; 1311 std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr)); 1312 curCodeIt += 2; 1313 return result; 1314 } 1315 template <typename T> 1316 std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() { 1317 return *curCodeIt++; 1318 } 1319 template <typename T> 1320 std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() { 1321 return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>()); 1322 } 1323 1324 /// Assign the given range to the given memory index. This allocates a new 1325 /// range object if necessary. 1326 template <typename RangeT, typename T = llvm::detail::ValueOfRange<RangeT>> 1327 void assignRangeToMemory(RangeT &&range, unsigned memIndex, 1328 unsigned rangeIndex) { 1329 // Utility functor used to type-erase the assignment. 1330 auto assignRange = [&](auto &allocatedRangeMemory, auto &rangeMemory) { 1331 // If the input range is empty, we don't need to allocate anything. 1332 if (range.empty()) { 1333 rangeMemory[rangeIndex] = {}; 1334 } else { 1335 // Allocate a buffer for this type range. 1336 llvm::OwningArrayRef<T> storage(llvm::size(range)); 1337 llvm::copy(range, storage.begin()); 1338 1339 // Assign this to the range slot and use the range as the value for the 1340 // memory index. 1341 allocatedRangeMemory.emplace_back(std::move(storage)); 1342 rangeMemory[rangeIndex] = allocatedRangeMemory.back(); 1343 } 1344 memory[memIndex] = &rangeMemory[rangeIndex]; 1345 }; 1346 1347 // Dispatch based on the concrete range type. 1348 if constexpr (std::is_same_v<T, Type>) { 1349 return assignRange(allocatedTypeRangeMemory, typeRangeMemory); 1350 } else if constexpr (std::is_same_v<T, Value>) { 1351 return assignRange(allocatedValueRangeMemory, valueRangeMemory); 1352 } else { 1353 llvm_unreachable("unhandled range type"); 1354 } 1355 } 1356 1357 /// The underlying bytecode buffer. 1358 const ByteCodeField *curCodeIt; 1359 1360 /// The stack of bytecode positions at which to resume operation. 1361 SmallVector<const ByteCodeField *> resumeCodeIt; 1362 1363 /// The current execution memory. 1364 MutableArrayRef<const void *> memory; 1365 MutableArrayRef<OwningOpRange> opRangeMemory; 1366 MutableArrayRef<TypeRange> typeRangeMemory; 1367 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory; 1368 MutableArrayRef<ValueRange> valueRangeMemory; 1369 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory; 1370 1371 /// The current loop indices. 1372 MutableArrayRef<unsigned> loopIndex; 1373 1374 /// References to ByteCode data necessary for execution. 1375 ArrayRef<const void *> uniquedMemory; 1376 ArrayRef<ByteCodeField> code; 1377 ArrayRef<PatternBenefit> currentPatternBenefits; 1378 ArrayRef<PDLByteCodePattern> patterns; 1379 ArrayRef<PDLConstraintFunction> constraintFunctions; 1380 ArrayRef<PDLRewriteFunction> rewriteFunctions; 1381 }; 1382 1383 /// This class is an instantiation of the PDLResultList that provides access to 1384 /// the returned results. This API is not on `PDLResultList` to avoid 1385 /// overexposing access to information specific solely to the ByteCode. 1386 class ByteCodeRewriteResultList : public PDLResultList { 1387 public: 1388 ByteCodeRewriteResultList(unsigned maxNumResults) 1389 : PDLResultList(maxNumResults) {} 1390 1391 /// Return the list of PDL results. 1392 MutableArrayRef<PDLValue> getResults() { return results; } 1393 1394 /// Return the type ranges allocated by this list. 1395 MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() { 1396 return allocatedTypeRanges; 1397 } 1398 1399 /// Return the value ranges allocated by this list. 1400 MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() { 1401 return allocatedValueRanges; 1402 } 1403 }; 1404 } // namespace 1405 1406 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { 1407 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); 1408 const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; 1409 SmallVector<PDLValue, 16> args; 1410 readList<PDLValue>(args); 1411 1412 LLVM_DEBUG({ 1413 llvm::dbgs() << " * Arguments: "; 1414 llvm::interleaveComma(args, llvm::dbgs()); 1415 }); 1416 1417 // Invoke the constraint and jump to the proper destination. 1418 selectJump(succeeded(constraintFn(rewriter, args))); 1419 } 1420 1421 LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { 1422 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); 1423 const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; 1424 SmallVector<PDLValue, 16> args; 1425 readList<PDLValue>(args); 1426 1427 LLVM_DEBUG({ 1428 llvm::dbgs() << " * Arguments: "; 1429 llvm::interleaveComma(args, llvm::dbgs()); 1430 }); 1431 1432 // Execute the rewrite function. 1433 ByteCodeField numResults = read(); 1434 ByteCodeRewriteResultList results(numResults); 1435 LogicalResult rewriteResult = rewriteFn(rewriter, results, args); 1436 1437 assert(results.getResults().size() == numResults && 1438 "native PDL rewrite function returned unexpected number of results"); 1439 1440 // Store the results in the bytecode memory. 1441 for (PDLValue &result : results.getResults()) { 1442 LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n"); 1443 1444 // In debug mode we also verify the expected kind of the result. 1445 #ifndef NDEBUG 1446 assert(result.getKind() == read<PDLValue::Kind>() && 1447 "native PDL rewrite function returned an unexpected type of result"); 1448 #endif 1449 1450 // If the result is a range, we need to copy it over to the bytecodes 1451 // range memory. 1452 if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) { 1453 unsigned rangeIndex = read(); 1454 typeRangeMemory[rangeIndex] = *typeRange; 1455 memory[read()] = &typeRangeMemory[rangeIndex]; 1456 } else if (Optional<ValueRange> valueRange = 1457 result.dyn_cast<ValueRange>()) { 1458 unsigned rangeIndex = read(); 1459 valueRangeMemory[rangeIndex] = *valueRange; 1460 memory[read()] = &valueRangeMemory[rangeIndex]; 1461 } else { 1462 memory[read()] = result.getAsOpaquePointer(); 1463 } 1464 } 1465 1466 // Copy over any underlying storage allocated for result ranges. 1467 for (auto &it : results.getAllocatedTypeRanges()) 1468 allocatedTypeRangeMemory.push_back(std::move(it)); 1469 for (auto &it : results.getAllocatedValueRanges()) 1470 allocatedValueRangeMemory.push_back(std::move(it)); 1471 1472 // Process the result of the rewrite. 1473 if (failed(rewriteResult)) { 1474 LLVM_DEBUG(llvm::dbgs() << " - Failed"); 1475 return failure(); 1476 } 1477 return success(); 1478 } 1479 1480 void ByteCodeExecutor::executeAreEqual() { 1481 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 1482 const void *lhs = read<const void *>(); 1483 const void *rhs = read<const void *>(); 1484 1485 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n"); 1486 selectJump(lhs == rhs); 1487 } 1488 1489 void ByteCodeExecutor::executeAreRangesEqual() { 1490 LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n"); 1491 PDLValue::Kind valueKind = read<PDLValue::Kind>(); 1492 const void *lhs = read<const void *>(); 1493 const void *rhs = read<const void *>(); 1494 1495 switch (valueKind) { 1496 case PDLValue::Kind::TypeRange: { 1497 const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs); 1498 const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs); 1499 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1500 selectJump(*lhsRange == *rhsRange); 1501 break; 1502 } 1503 case PDLValue::Kind::ValueRange: { 1504 const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs); 1505 const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs); 1506 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1507 selectJump(*lhsRange == *rhsRange); 1508 break; 1509 } 1510 default: 1511 llvm_unreachable("unexpected `AreRangesEqual` value kind"); 1512 } 1513 } 1514 1515 void ByteCodeExecutor::executeBranch() { 1516 LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n"); 1517 curCodeIt = &code[read<ByteCodeAddr>()]; 1518 } 1519 1520 void ByteCodeExecutor::executeCheckOperandCount() { 1521 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n"); 1522 Operation *op = read<Operation *>(); 1523 uint32_t expectedCount = read<uint32_t>(); 1524 bool compareAtLeast = read(); 1525 1526 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n" 1527 << " * Expected: " << expectedCount << "\n" 1528 << " * Comparator: " 1529 << (compareAtLeast ? ">=" : "==") << "\n"); 1530 if (compareAtLeast) 1531 selectJump(op->getNumOperands() >= expectedCount); 1532 else 1533 selectJump(op->getNumOperands() == expectedCount); 1534 } 1535 1536 void ByteCodeExecutor::executeCheckOperationName() { 1537 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n"); 1538 Operation *op = read<Operation *>(); 1539 OperationName expectedName = read<OperationName>(); 1540 1541 LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n" 1542 << " * Expected: \"" << expectedName << "\"\n"); 1543 selectJump(op->getName() == expectedName); 1544 } 1545 1546 void ByteCodeExecutor::executeCheckResultCount() { 1547 LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n"); 1548 Operation *op = read<Operation *>(); 1549 uint32_t expectedCount = read<uint32_t>(); 1550 bool compareAtLeast = read(); 1551 1552 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n" 1553 << " * Expected: " << expectedCount << "\n" 1554 << " * Comparator: " 1555 << (compareAtLeast ? ">=" : "==") << "\n"); 1556 if (compareAtLeast) 1557 selectJump(op->getNumResults() >= expectedCount); 1558 else 1559 selectJump(op->getNumResults() == expectedCount); 1560 } 1561 1562 void ByteCodeExecutor::executeCheckTypes() { 1563 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 1564 TypeRange *lhs = read<TypeRange *>(); 1565 Attribute rhs = read<Attribute>(); 1566 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1567 1568 selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>()); 1569 } 1570 1571 void ByteCodeExecutor::executeContinue() { 1572 ByteCodeField level = read(); 1573 LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n" 1574 << " * Level: " << level << "\n"); 1575 ++loopIndex[level]; 1576 popCodeIt(); 1577 } 1578 1579 void ByteCodeExecutor::executeCreateConstantTypeRange() { 1580 LLVM_DEBUG(llvm::dbgs() << "Executing CreateConstantTypeRange:\n"); 1581 unsigned memIndex = read(); 1582 unsigned rangeIndex = read(); 1583 ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>(); 1584 1585 LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n"); 1586 assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex, 1587 rangeIndex); 1588 } 1589 1590 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, 1591 Location mainRewriteLoc) { 1592 LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n"); 1593 1594 unsigned memIndex = read(); 1595 OperationState state(mainRewriteLoc, read<OperationName>()); 1596 readList(state.operands); 1597 for (unsigned i = 0, e = read(); i != e; ++i) { 1598 StringAttr name = read<StringAttr>(); 1599 if (Attribute attr = read<Attribute>()) 1600 state.addAttribute(name, attr); 1601 } 1602 1603 // Read in the result types. If the "size" is the sentinel value, this 1604 // indicates that the result types should be inferred. 1605 unsigned numResults = read(); 1606 if (numResults == kInferTypesMarker) { 1607 InferTypeOpInterface::Concept *inferInterface = 1608 state.name.getRegisteredInfo()->getInterface<InferTypeOpInterface>(); 1609 assert(inferInterface && 1610 "expected operation to provide InferTypeOpInterface"); 1611 1612 // TODO: Handle failure. 1613 if (failed(inferInterface->inferReturnTypes( 1614 state.getContext(), state.location, state.operands, 1615 state.attributes.getDictionary(state.getContext()), state.regions, 1616 state.types))) 1617 return; 1618 } else { 1619 // Otherwise, this is a fixed number of results. 1620 for (unsigned i = 0; i != numResults; ++i) { 1621 if (read<PDLValue::Kind>() == PDLValue::Kind::Type) { 1622 state.types.push_back(read<Type>()); 1623 } else { 1624 TypeRange *resultTypes = read<TypeRange *>(); 1625 state.types.append(resultTypes->begin(), resultTypes->end()); 1626 } 1627 } 1628 } 1629 1630 Operation *resultOp = rewriter.create(state); 1631 memory[memIndex] = resultOp; 1632 1633 LLVM_DEBUG({ 1634 llvm::dbgs() << " * Attributes: " 1635 << state.attributes.getDictionary(state.getContext()) 1636 << "\n * Operands: "; 1637 llvm::interleaveComma(state.operands, llvm::dbgs()); 1638 llvm::dbgs() << "\n * Result Types: "; 1639 llvm::interleaveComma(state.types, llvm::dbgs()); 1640 llvm::dbgs() << "\n * Result: " << *resultOp << "\n"; 1641 }); 1642 } 1643 1644 template <typename T> 1645 void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) { 1646 LLVM_DEBUG(llvm::dbgs() << "Executing CreateDynamic" << type << "Range:\n"); 1647 unsigned memIndex = read(); 1648 unsigned rangeIndex = read(); 1649 SmallVector<T> values; 1650 readList(values); 1651 1652 LLVM_DEBUG({ 1653 llvm::dbgs() << "\n * " << type << "s: "; 1654 llvm::interleaveComma(values, llvm::dbgs()); 1655 llvm::dbgs() << "\n"; 1656 }); 1657 1658 assignRangeToMemory(values, memIndex, rangeIndex); 1659 } 1660 1661 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) { 1662 LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n"); 1663 Operation *op = read<Operation *>(); 1664 1665 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1666 rewriter.eraseOp(op); 1667 } 1668 1669 template <typename T, typename Range, PDLValue::Kind kind> 1670 void ByteCodeExecutor::executeExtract() { 1671 LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n"); 1672 Range *range = read<Range *>(); 1673 unsigned index = read<uint32_t>(); 1674 unsigned memIndex = read(); 1675 1676 if (!range) { 1677 memory[memIndex] = nullptr; 1678 return; 1679 } 1680 1681 T result = index < range->size() ? (*range)[index] : T(); 1682 LLVM_DEBUG(llvm::dbgs() << " * " << kind << "s(" << range->size() << ")\n" 1683 << " * Index: " << index << "\n" 1684 << " * Result: " << result << "\n"); 1685 storeToMemory(memIndex, result); 1686 } 1687 1688 void ByteCodeExecutor::executeFinalize() { 1689 LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n"); 1690 } 1691 1692 void ByteCodeExecutor::executeForEach() { 1693 LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n"); 1694 const ByteCodeField *prevCodeIt = getPrevCodeIt(); 1695 unsigned rangeIndex = read(); 1696 unsigned memIndex = read(); 1697 const void *value = nullptr; 1698 1699 switch (read<PDLValue::Kind>()) { 1700 case PDLValue::Kind::Operation: { 1701 unsigned &index = loopIndex[read()]; 1702 ArrayRef<Operation *> array = opRangeMemory[rangeIndex]; 1703 assert(index <= array.size() && "iterated past the end"); 1704 if (index < array.size()) { 1705 LLVM_DEBUG(llvm::dbgs() << " * Result: " << array[index] << "\n"); 1706 value = array[index]; 1707 break; 1708 } 1709 1710 LLVM_DEBUG(llvm::dbgs() << " * Done\n"); 1711 index = 0; 1712 selectJump(size_t(0)); 1713 return; 1714 } 1715 default: 1716 llvm_unreachable("unexpected `ForEach` value kind"); 1717 } 1718 1719 // Store the iterate value and the stack address. 1720 memory[memIndex] = value; 1721 pushCodeIt(prevCodeIt); 1722 1723 // Skip over the successor (we will enter the body of the loop). 1724 read<ByteCodeAddr>(); 1725 } 1726 1727 void ByteCodeExecutor::executeGetAttribute() { 1728 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n"); 1729 unsigned memIndex = read(); 1730 Operation *op = read<Operation *>(); 1731 StringAttr attrName = read<StringAttr>(); 1732 Attribute attr = op->getAttr(attrName); 1733 1734 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1735 << " * Attribute: " << attrName << "\n" 1736 << " * Result: " << attr << "\n"); 1737 memory[memIndex] = attr.getAsOpaquePointer(); 1738 } 1739 1740 void ByteCodeExecutor::executeGetAttributeType() { 1741 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n"); 1742 unsigned memIndex = read(); 1743 Attribute attr = read<Attribute>(); 1744 Type type; 1745 if (auto typedAttr = attr.dyn_cast<TypedAttr>()) 1746 type = typedAttr.getType(); 1747 1748 LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n" 1749 << " * Result: " << type << "\n"); 1750 memory[memIndex] = type.getAsOpaquePointer(); 1751 } 1752 1753 void ByteCodeExecutor::executeGetDefiningOp() { 1754 LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n"); 1755 unsigned memIndex = read(); 1756 Operation *op = nullptr; 1757 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1758 Value value = read<Value>(); 1759 if (value) 1760 op = value.getDefiningOp(); 1761 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1762 } else { 1763 ValueRange *values = read<ValueRange *>(); 1764 if (values && !values->empty()) { 1765 op = values->front().getDefiningOp(); 1766 } 1767 LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n"); 1768 } 1769 1770 LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n"); 1771 memory[memIndex] = op; 1772 } 1773 1774 void ByteCodeExecutor::executeGetOperand(unsigned index) { 1775 Operation *op = read<Operation *>(); 1776 unsigned memIndex = read(); 1777 Value operand = 1778 index < op->getNumOperands() ? op->getOperand(index) : Value(); 1779 1780 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1781 << " * Index: " << index << "\n" 1782 << " * Result: " << operand << "\n"); 1783 memory[memIndex] = operand.getAsOpaquePointer(); 1784 } 1785 1786 /// This function is the internal implementation of `GetResults` and 1787 /// `GetOperands` that provides support for extracting a value range from the 1788 /// given operation. 1789 template <template <typename> class AttrSizedSegmentsT, typename RangeT> 1790 static void * 1791 executeGetOperandsResults(RangeT values, Operation *op, unsigned index, 1792 ByteCodeField rangeIndex, StringRef attrSizedSegments, 1793 MutableArrayRef<ValueRange> valueRangeMemory) { 1794 // Check for the sentinel index that signals that all values should be 1795 // returned. 1796 if (index == std::numeric_limits<uint32_t>::max()) { 1797 LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n"); 1798 // `values` is already the full value range. 1799 1800 // Otherwise, check to see if this operation uses AttrSizedSegments. 1801 } else if (op->hasTrait<AttrSizedSegmentsT>()) { 1802 LLVM_DEBUG(llvm::dbgs() 1803 << " * Extracting values from `" << attrSizedSegments << "`\n"); 1804 1805 auto segmentAttr = op->getAttrOfType<DenseI32ArrayAttr>(attrSizedSegments); 1806 if (!segmentAttr || segmentAttr.asArrayRef().size() <= index) 1807 return nullptr; 1808 1809 ArrayRef<int32_t> segments = segmentAttr; 1810 unsigned startIndex = 1811 std::accumulate(segments.begin(), segments.begin() + index, 0); 1812 values = values.slice(startIndex, *std::next(segments.begin(), index)); 1813 1814 LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex << ", " 1815 << *std::next(segments.begin(), index) << "]\n"); 1816 1817 // Otherwise, assume this is the last operand group of the operation. 1818 // FIXME: We currently don't support operations with 1819 // SameVariadicOperandSize/SameVariadicResultSize here given that we don't 1820 // have a way to detect it's presence. 1821 } else if (values.size() >= index) { 1822 LLVM_DEBUG(llvm::dbgs() 1823 << " * Treating values as trailing variadic range\n"); 1824 values = values.drop_front(index); 1825 1826 // If we couldn't detect a way to compute the values, bail out. 1827 } else { 1828 return nullptr; 1829 } 1830 1831 // If the range index is valid, we are returning a range. 1832 if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) { 1833 valueRangeMemory[rangeIndex] = values; 1834 return &valueRangeMemory[rangeIndex]; 1835 } 1836 1837 // If a range index wasn't provided, the range is required to be non-variadic. 1838 return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer(); 1839 } 1840 1841 void ByteCodeExecutor::executeGetOperands() { 1842 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n"); 1843 unsigned index = read<uint32_t>(); 1844 Operation *op = read<Operation *>(); 1845 ByteCodeField rangeIndex = read(); 1846 1847 void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>( 1848 op->getOperands(), op, index, rangeIndex, "operand_segment_sizes", 1849 valueRangeMemory); 1850 if (!result) 1851 LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n"); 1852 memory[read()] = result; 1853 } 1854 1855 void ByteCodeExecutor::executeGetResult(unsigned index) { 1856 Operation *op = read<Operation *>(); 1857 unsigned memIndex = read(); 1858 OpResult result = 1859 index < op->getNumResults() ? op->getResult(index) : OpResult(); 1860 1861 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1862 << " * Index: " << index << "\n" 1863 << " * Result: " << result << "\n"); 1864 memory[memIndex] = result.getAsOpaquePointer(); 1865 } 1866 1867 void ByteCodeExecutor::executeGetResults() { 1868 LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n"); 1869 unsigned index = read<uint32_t>(); 1870 Operation *op = read<Operation *>(); 1871 ByteCodeField rangeIndex = read(); 1872 1873 void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>( 1874 op->getResults(), op, index, rangeIndex, "result_segment_sizes", 1875 valueRangeMemory); 1876 if (!result) 1877 LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n"); 1878 memory[read()] = result; 1879 } 1880 1881 void ByteCodeExecutor::executeGetUsers() { 1882 LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n"); 1883 unsigned memIndex = read(); 1884 unsigned rangeIndex = read(); 1885 OwningOpRange &range = opRangeMemory[rangeIndex]; 1886 memory[memIndex] = ⦥ 1887 1888 range = OwningOpRange(); 1889 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1890 // Read the value. 1891 Value value = read<Value>(); 1892 if (!value) 1893 return; 1894 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1895 1896 // Extract the users of a single value. 1897 range = OwningOpRange(std::distance(value.user_begin(), value.user_end())); 1898 llvm::copy(value.getUsers(), range.begin()); 1899 } else { 1900 // Read a range of values. 1901 ValueRange *values = read<ValueRange *>(); 1902 if (!values) 1903 return; 1904 LLVM_DEBUG({ 1905 llvm::dbgs() << " * Values (" << values->size() << "): "; 1906 llvm::interleaveComma(*values, llvm::dbgs()); 1907 llvm::dbgs() << "\n"; 1908 }); 1909 1910 // Extract all the users of a range of values. 1911 SmallVector<Operation *> users; 1912 for (Value value : *values) 1913 users.append(value.user_begin(), value.user_end()); 1914 range = OwningOpRange(users.size()); 1915 llvm::copy(users, range.begin()); 1916 } 1917 1918 LLVM_DEBUG(llvm::dbgs() << " * Result: " << range.size() << " operations\n"); 1919 } 1920 1921 void ByteCodeExecutor::executeGetValueType() { 1922 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n"); 1923 unsigned memIndex = read(); 1924 Value value = read<Value>(); 1925 Type type = value ? value.getType() : Type(); 1926 1927 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" 1928 << " * Result: " << type << "\n"); 1929 memory[memIndex] = type.getAsOpaquePointer(); 1930 } 1931 1932 void ByteCodeExecutor::executeGetValueRangeTypes() { 1933 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n"); 1934 unsigned memIndex = read(); 1935 unsigned rangeIndex = read(); 1936 ValueRange *values = read<ValueRange *>(); 1937 if (!values) { 1938 LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n"); 1939 memory[memIndex] = nullptr; 1940 return; 1941 } 1942 1943 LLVM_DEBUG({ 1944 llvm::dbgs() << " * Values (" << values->size() << "): "; 1945 llvm::interleaveComma(*values, llvm::dbgs()); 1946 llvm::dbgs() << "\n * Result: "; 1947 llvm::interleaveComma(values->getType(), llvm::dbgs()); 1948 llvm::dbgs() << "\n"; 1949 }); 1950 typeRangeMemory[rangeIndex] = values->getType(); 1951 memory[memIndex] = &typeRangeMemory[rangeIndex]; 1952 } 1953 1954 void ByteCodeExecutor::executeIsNotNull() { 1955 LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n"); 1956 const void *value = read<const void *>(); 1957 1958 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1959 selectJump(value != nullptr); 1960 } 1961 1962 void ByteCodeExecutor::executeRecordMatch( 1963 PatternRewriter &rewriter, 1964 SmallVectorImpl<PDLByteCode::MatchResult> &matches) { 1965 LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n"); 1966 unsigned patternIndex = read(); 1967 PatternBenefit benefit = currentPatternBenefits[patternIndex]; 1968 const ByteCodeField *dest = &code[read<ByteCodeAddr>()]; 1969 1970 // If the benefit of the pattern is impossible, skip the processing of the 1971 // rest of the pattern. 1972 if (benefit.isImpossibleToMatch()) { 1973 LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n"); 1974 curCodeIt = dest; 1975 return; 1976 } 1977 1978 // Create a fused location containing the locations of each of the 1979 // operations used in the match. This will be used as the location for 1980 // created operations during the rewrite that don't already have an 1981 // explicit location set. 1982 unsigned numMatchLocs = read(); 1983 SmallVector<Location, 4> matchLocs; 1984 matchLocs.reserve(numMatchLocs); 1985 for (unsigned i = 0; i != numMatchLocs; ++i) 1986 matchLocs.push_back(read<Operation *>()->getLoc()); 1987 Location matchLoc = rewriter.getFusedLoc(matchLocs); 1988 1989 LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n" 1990 << " * Location: " << matchLoc << "\n"); 1991 matches.emplace_back(matchLoc, patterns[patternIndex], benefit); 1992 PDLByteCode::MatchResult &match = matches.back(); 1993 1994 // Record all of the inputs to the match. If any of the inputs are ranges, we 1995 // will also need to remap the range pointer to memory stored in the match 1996 // state. 1997 unsigned numInputs = read(); 1998 match.values.reserve(numInputs); 1999 match.typeRangeValues.reserve(numInputs); 2000 match.valueRangeValues.reserve(numInputs); 2001 for (unsigned i = 0; i < numInputs; ++i) { 2002 switch (read<PDLValue::Kind>()) { 2003 case PDLValue::Kind::TypeRange: 2004 match.typeRangeValues.push_back(*read<TypeRange *>()); 2005 match.values.push_back(&match.typeRangeValues.back()); 2006 break; 2007 case PDLValue::Kind::ValueRange: 2008 match.valueRangeValues.push_back(*read<ValueRange *>()); 2009 match.values.push_back(&match.valueRangeValues.back()); 2010 break; 2011 default: 2012 match.values.push_back(read<const void *>()); 2013 break; 2014 } 2015 } 2016 curCodeIt = dest; 2017 } 2018 2019 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) { 2020 LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n"); 2021 Operation *op = read<Operation *>(); 2022 SmallVector<Value, 16> args; 2023 readList(args); 2024 2025 LLVM_DEBUG({ 2026 llvm::dbgs() << " * Operation: " << *op << "\n" 2027 << " * Values: "; 2028 llvm::interleaveComma(args, llvm::dbgs()); 2029 llvm::dbgs() << "\n"; 2030 }); 2031 rewriter.replaceOp(op, args); 2032 } 2033 2034 void ByteCodeExecutor::executeSwitchAttribute() { 2035 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n"); 2036 Attribute value = read<Attribute>(); 2037 ArrayAttr cases = read<ArrayAttr>(); 2038 handleSwitch(value, cases); 2039 } 2040 2041 void ByteCodeExecutor::executeSwitchOperandCount() { 2042 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n"); 2043 Operation *op = read<Operation *>(); 2044 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 2045 2046 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 2047 handleSwitch(op->getNumOperands(), cases); 2048 } 2049 2050 void ByteCodeExecutor::executeSwitchOperationName() { 2051 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n"); 2052 OperationName value = read<Operation *>()->getName(); 2053 size_t caseCount = read(); 2054 2055 // The operation names are stored in-line, so to print them out for 2056 // debugging purposes we need to read the array before executing the 2057 // switch so that we can display all of the possible values. 2058 LLVM_DEBUG({ 2059 const ByteCodeField *prevCodeIt = curCodeIt; 2060 llvm::dbgs() << " * Value: " << value << "\n" 2061 << " * Cases: "; 2062 llvm::interleaveComma( 2063 llvm::map_range(llvm::seq<size_t>(0, caseCount), 2064 [&](size_t) { return read<OperationName>(); }), 2065 llvm::dbgs()); 2066 llvm::dbgs() << "\n"; 2067 curCodeIt = prevCodeIt; 2068 }); 2069 2070 // Try to find the switch value within any of the cases. 2071 for (size_t i = 0; i != caseCount; ++i) { 2072 if (read<OperationName>() == value) { 2073 curCodeIt += (caseCount - i - 1); 2074 return selectJump(i + 1); 2075 } 2076 } 2077 selectJump(size_t(0)); 2078 } 2079 2080 void ByteCodeExecutor::executeSwitchResultCount() { 2081 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n"); 2082 Operation *op = read<Operation *>(); 2083 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 2084 2085 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 2086 handleSwitch(op->getNumResults(), cases); 2087 } 2088 2089 void ByteCodeExecutor::executeSwitchType() { 2090 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n"); 2091 Type value = read<Type>(); 2092 auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>(); 2093 handleSwitch(value, cases); 2094 } 2095 2096 void ByteCodeExecutor::executeSwitchTypes() { 2097 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n"); 2098 TypeRange *value = read<TypeRange *>(); 2099 auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>(); 2100 if (!value) { 2101 LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n"); 2102 return selectJump(size_t(0)); 2103 } 2104 handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) { 2105 return value == caseValue.getAsValueRange<TypeAttr>(); 2106 }); 2107 } 2108 2109 LogicalResult 2110 ByteCodeExecutor::execute(PatternRewriter &rewriter, 2111 SmallVectorImpl<PDLByteCode::MatchResult> *matches, 2112 Optional<Location> mainRewriteLoc) { 2113 while (true) { 2114 // Print the location of the operation being executed. 2115 LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n"); 2116 2117 OpCode opCode = static_cast<OpCode>(read()); 2118 switch (opCode) { 2119 case ApplyConstraint: 2120 executeApplyConstraint(rewriter); 2121 break; 2122 case ApplyRewrite: 2123 if (failed(executeApplyRewrite(rewriter))) 2124 return failure(); 2125 break; 2126 case AreEqual: 2127 executeAreEqual(); 2128 break; 2129 case AreRangesEqual: 2130 executeAreRangesEqual(); 2131 break; 2132 case Branch: 2133 executeBranch(); 2134 break; 2135 case CheckOperandCount: 2136 executeCheckOperandCount(); 2137 break; 2138 case CheckOperationName: 2139 executeCheckOperationName(); 2140 break; 2141 case CheckResultCount: 2142 executeCheckResultCount(); 2143 break; 2144 case CheckTypes: 2145 executeCheckTypes(); 2146 break; 2147 case Continue: 2148 executeContinue(); 2149 break; 2150 case CreateConstantTypeRange: 2151 executeCreateConstantTypeRange(); 2152 break; 2153 case CreateOperation: 2154 executeCreateOperation(rewriter, *mainRewriteLoc); 2155 break; 2156 case CreateDynamicTypeRange: 2157 executeDynamicCreateRange<Type>("Type"); 2158 break; 2159 case CreateDynamicValueRange: 2160 executeDynamicCreateRange<Value>("Value"); 2161 break; 2162 case EraseOp: 2163 executeEraseOp(rewriter); 2164 break; 2165 case ExtractOp: 2166 executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>(); 2167 break; 2168 case ExtractType: 2169 executeExtract<Type, TypeRange, PDLValue::Kind::Type>(); 2170 break; 2171 case ExtractValue: 2172 executeExtract<Value, ValueRange, PDLValue::Kind::Value>(); 2173 break; 2174 case Finalize: 2175 executeFinalize(); 2176 LLVM_DEBUG(llvm::dbgs() << "\n"); 2177 return success(); 2178 case ForEach: 2179 executeForEach(); 2180 break; 2181 case GetAttribute: 2182 executeGetAttribute(); 2183 break; 2184 case GetAttributeType: 2185 executeGetAttributeType(); 2186 break; 2187 case GetDefiningOp: 2188 executeGetDefiningOp(); 2189 break; 2190 case GetOperand0: 2191 case GetOperand1: 2192 case GetOperand2: 2193 case GetOperand3: { 2194 unsigned index = opCode - GetOperand0; 2195 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n"); 2196 executeGetOperand(index); 2197 break; 2198 } 2199 case GetOperandN: 2200 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n"); 2201 executeGetOperand(read<uint32_t>()); 2202 break; 2203 case GetOperands: 2204 executeGetOperands(); 2205 break; 2206 case GetResult0: 2207 case GetResult1: 2208 case GetResult2: 2209 case GetResult3: { 2210 unsigned index = opCode - GetResult0; 2211 LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n"); 2212 executeGetResult(index); 2213 break; 2214 } 2215 case GetResultN: 2216 LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n"); 2217 executeGetResult(read<uint32_t>()); 2218 break; 2219 case GetResults: 2220 executeGetResults(); 2221 break; 2222 case GetUsers: 2223 executeGetUsers(); 2224 break; 2225 case GetValueType: 2226 executeGetValueType(); 2227 break; 2228 case GetValueRangeTypes: 2229 executeGetValueRangeTypes(); 2230 break; 2231 case IsNotNull: 2232 executeIsNotNull(); 2233 break; 2234 case RecordMatch: 2235 assert(matches && 2236 "expected matches to be provided when executing the matcher"); 2237 executeRecordMatch(rewriter, *matches); 2238 break; 2239 case ReplaceOp: 2240 executeReplaceOp(rewriter); 2241 break; 2242 case SwitchAttribute: 2243 executeSwitchAttribute(); 2244 break; 2245 case SwitchOperandCount: 2246 executeSwitchOperandCount(); 2247 break; 2248 case SwitchOperationName: 2249 executeSwitchOperationName(); 2250 break; 2251 case SwitchResultCount: 2252 executeSwitchResultCount(); 2253 break; 2254 case SwitchType: 2255 executeSwitchType(); 2256 break; 2257 case SwitchTypes: 2258 executeSwitchTypes(); 2259 break; 2260 } 2261 LLVM_DEBUG(llvm::dbgs() << "\n"); 2262 } 2263 } 2264 2265 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, 2266 SmallVectorImpl<MatchResult> &matches, 2267 PDLByteCodeMutableState &state) const { 2268 // The first memory slot is always the root operation. 2269 state.memory[0] = op; 2270 2271 // The matcher function always starts at code address 0. 2272 ByteCodeExecutor executor( 2273 matcherByteCode.data(), state.memory, state.opRangeMemory, 2274 state.typeRangeMemory, state.allocatedTypeRangeMemory, 2275 state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex, 2276 uniquedData, matcherByteCode, state.currentPatternBenefits, patterns, 2277 constraintFunctions, rewriteFunctions); 2278 LogicalResult executeResult = executor.execute(rewriter, &matches); 2279 assert(succeeded(executeResult) && "unexpected matcher execution failure"); 2280 2281 // Order the found matches by benefit. 2282 std::stable_sort(matches.begin(), matches.end(), 2283 [](const MatchResult &lhs, const MatchResult &rhs) { 2284 return lhs.benefit > rhs.benefit; 2285 }); 2286 } 2287 2288 LogicalResult PDLByteCode::rewrite(PatternRewriter &rewriter, 2289 const MatchResult &match, 2290 PDLByteCodeMutableState &state) const { 2291 auto *configSet = match.pattern->getConfigSet(); 2292 if (configSet) 2293 configSet->notifyRewriteBegin(rewriter); 2294 2295 // The arguments of the rewrite function are stored at the start of the 2296 // memory buffer. 2297 llvm::copy(match.values, state.memory.begin()); 2298 2299 ByteCodeExecutor executor( 2300 &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory, 2301 state.opRangeMemory, state.typeRangeMemory, 2302 state.allocatedTypeRangeMemory, state.valueRangeMemory, 2303 state.allocatedValueRangeMemory, state.loopIndex, uniquedData, 2304 rewriterByteCode, state.currentPatternBenefits, patterns, 2305 constraintFunctions, rewriteFunctions); 2306 LogicalResult result = 2307 executor.execute(rewriter, /*matches=*/nullptr, match.location); 2308 2309 if (configSet) 2310 configSet->notifyRewriteEnd(rewriter); 2311 2312 // If the rewrite failed, check if the pattern rewriter can recover. If it 2313 // can, we can signal to the pattern applicator to keep trying patterns. If it 2314 // doesn't, we need to bail. Bailing here should be fine, given that we have 2315 // no means to propagate such a failure to the user, and it also indicates a 2316 // bug in the user code (i.e. failable rewrites should not be used with 2317 // pattern rewriters that don't support it). 2318 if (failed(result) && !rewriter.canRecoverFromRewriteFailure()) { 2319 LLVM_DEBUG(llvm::dbgs() << " and rollback is not supported - aborting"); 2320 llvm::report_fatal_error( 2321 "Native PDL Rewrite failed, but the pattern " 2322 "rewriter doesn't support recovery. Failable pattern rewrites should " 2323 "not be used with pattern rewriters that do not support them."); 2324 } 2325 return result; 2326 } 2327