1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements MLIR to byte-code generation and the interpreter. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "ByteCode.h" 14 #include "mlir/Analysis/Liveness.h" 15 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 16 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/RegionGraphTraits.h" 19 #include "llvm/ADT/IntervalMap.h" 20 #include "llvm/ADT/PostOrderIterator.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 #include "llvm/Support/Debug.h" 23 #include "llvm/Support/Format.h" 24 #include "llvm/Support/FormatVariadic.h" 25 #include <numeric> 26 27 #define DEBUG_TYPE "pdl-bytecode" 28 29 using namespace mlir; 30 using namespace mlir::detail; 31 32 //===----------------------------------------------------------------------===// 33 // PDLByteCodePattern 34 //===----------------------------------------------------------------------===// 35 36 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp, 37 ByteCodeAddr rewriterAddr) { 38 SmallVector<StringRef, 8> generatedOps; 39 if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr()) 40 generatedOps = 41 llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>()); 42 43 PatternBenefit benefit = matchOp.getBenefit(); 44 MLIRContext *ctx = matchOp.getContext(); 45 46 // Check to see if this is pattern matches a specific operation type. 
47 if (Optional<StringRef> rootKind = matchOp.getRootKind()) 48 return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx, 49 generatedOps); 50 return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx, 51 generatedOps); 52 } 53 54 //===----------------------------------------------------------------------===// 55 // PDLByteCodeMutableState 56 //===----------------------------------------------------------------------===// 57 58 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds 59 /// to the position of the pattern within the range returned by 60 /// `PDLByteCode::getPatterns`. 61 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex, 62 PatternBenefit benefit) { 63 currentPatternBenefits[patternIndex] = benefit; 64 } 65 66 /// Cleanup any allocated state after a full match/rewrite has been completed. 67 /// This method should be called irregardless of whether the match+rewrite was a 68 /// success or not. 69 void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() { 70 allocatedTypeRangeMemory.clear(); 71 allocatedValueRangeMemory.clear(); 72 } 73 74 //===----------------------------------------------------------------------===// 75 // Bytecode OpCodes 76 //===----------------------------------------------------------------------===// 77 78 namespace { 79 enum OpCode : ByteCodeField { 80 /// Apply an externally registered constraint. 81 ApplyConstraint, 82 /// Apply an externally registered rewrite. 83 ApplyRewrite, 84 /// Check if two generic values are equal. 85 AreEqual, 86 /// Check if two ranges are equal. 87 AreRangesEqual, 88 /// Unconditional branch. 89 Branch, 90 /// Compare the operand count of an operation with a constant. 91 CheckOperandCount, 92 /// Compare the name of an operation with a constant. 93 CheckOperationName, 94 /// Compare the result count of an operation with a constant. 95 CheckResultCount, 96 /// Compare a range of types to a constant range of types. 
97 CheckTypes, 98 /// Continue to the next iteration of a loop. 99 Continue, 100 /// Create an operation. 101 CreateOperation, 102 /// Create a range of types. 103 CreateTypes, 104 /// Erase an operation. 105 EraseOp, 106 /// Extract the op from a range at the specified index. 107 ExtractOp, 108 /// Extract the type from a range at the specified index. 109 ExtractType, 110 /// Extract the value from a range at the specified index. 111 ExtractValue, 112 /// Terminate a matcher or rewrite sequence. 113 Finalize, 114 /// Iterate over a range of values. 115 ForEach, 116 /// Get a specific attribute of an operation. 117 GetAttribute, 118 /// Get the type of an attribute. 119 GetAttributeType, 120 /// Get the defining operation of a value. 121 GetDefiningOp, 122 /// Get a specific operand of an operation. 123 GetOperand0, 124 GetOperand1, 125 GetOperand2, 126 GetOperand3, 127 GetOperandN, 128 /// Get a specific operand group of an operation. 129 GetOperands, 130 /// Get a specific result of an operation. 131 GetResult0, 132 GetResult1, 133 GetResult2, 134 GetResult3, 135 GetResultN, 136 /// Get a specific result group of an operation. 137 GetResults, 138 /// Get the users of a value or a range of values. 139 GetUsers, 140 /// Get the type of a value. 141 GetValueType, 142 /// Get the types of a value range. 143 GetValueRangeTypes, 144 /// Check if a generic value is not null. 145 IsNotNull, 146 /// Record a successful pattern match. 147 RecordMatch, 148 /// Replace an operation. 149 ReplaceOp, 150 /// Compare an attribute with a set of constants. 151 SwitchAttribute, 152 /// Compare the operand count of an operation with a set of constants. 153 SwitchOperandCount, 154 /// Compare the name of an operation with a set of constants. 155 SwitchOperationName, 156 /// Compare the result count of an operation with a set of constants. 157 SwitchResultCount, 158 /// Compare a type with a set of constants. 159 SwitchType, 160 /// Compare a range of types with a set of constants. 
161 SwitchTypes, 162 }; 163 } // namespace 164 165 /// A marker used to indicate if an operation should infer types. 166 static constexpr ByteCodeField kInferTypesMarker = 167 std::numeric_limits<ByteCodeField>::max(); 168 169 //===----------------------------------------------------------------------===// 170 // ByteCode Generation 171 //===----------------------------------------------------------------------===// 172 173 //===----------------------------------------------------------------------===// 174 // Generator 175 176 namespace { 177 struct ByteCodeLiveRange; 178 struct ByteCodeWriter; 179 180 /// Check if the given class `T` can be converted to an opaque pointer. 181 template <typename T, typename... Args> 182 using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer()); 183 184 /// This class represents the main generator for the pattern bytecode. 185 class Generator { 186 public: 187 Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData, 188 SmallVectorImpl<ByteCodeField> &matcherByteCode, 189 SmallVectorImpl<ByteCodeField> &rewriterByteCode, 190 SmallVectorImpl<PDLByteCodePattern> &patterns, 191 ByteCodeField &maxValueMemoryIndex, 192 ByteCodeField &maxOpRangeMemoryIndex, 193 ByteCodeField &maxTypeRangeMemoryIndex, 194 ByteCodeField &maxValueRangeMemoryIndex, 195 ByteCodeField &maxLoopLevel, 196 llvm::StringMap<PDLConstraintFunction> &constraintFns, 197 llvm::StringMap<PDLRewriteFunction> &rewriteFns) 198 : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), 199 rewriterByteCode(rewriterByteCode), patterns(patterns), 200 maxValueMemoryIndex(maxValueMemoryIndex), 201 maxOpRangeMemoryIndex(maxOpRangeMemoryIndex), 202 maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex), 203 maxValueRangeMemoryIndex(maxValueRangeMemoryIndex), 204 maxLoopLevel(maxLoopLevel) { 205 for (const auto &it : llvm::enumerate(constraintFns)) 206 constraintToMemIndex.try_emplace(it.value().first(), it.index()); 207 for (const auto &it : 
llvm::enumerate(rewriteFns)) 208 externalRewriterToMemIndex.try_emplace(it.value().first(), it.index()); 209 } 210 211 /// Generate the bytecode for the given PDL interpreter module. 212 void generate(ModuleOp module); 213 214 /// Return the memory index to use for the given value. 215 ByteCodeField &getMemIndex(Value value) { 216 assert(valueToMemIndex.count(value) && 217 "expected memory index to be assigned"); 218 return valueToMemIndex[value]; 219 } 220 221 /// Return the range memory index used to store the given range value. 222 ByteCodeField &getRangeStorageIndex(Value value) { 223 assert(valueToRangeIndex.count(value) && 224 "expected range index to be assigned"); 225 return valueToRangeIndex[value]; 226 } 227 228 /// Return an index to use when referring to the given data that is uniqued in 229 /// the MLIR context. 230 template <typename T> 231 std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &> 232 getMemIndex(T val) { 233 const void *opaqueVal = val.getAsOpaquePointer(); 234 235 // Get or insert a reference to this value. 236 auto it = uniquedDataToMemIndex.try_emplace( 237 opaqueVal, maxValueMemoryIndex + uniquedData.size()); 238 if (it.second) 239 uniquedData.push_back(opaqueVal); 240 return it.first->second; 241 } 242 243 private: 244 /// Allocate memory indices for the results of operations within the matcher 245 /// and rewriters. 246 void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc, 247 ModuleOp rewriterModule); 248 249 /// Generate the bytecode for the given operation. 
250 void generate(Region *region, ByteCodeWriter &writer); 251 void generate(Operation *op, ByteCodeWriter &writer); 252 void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer); 253 void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer); 254 void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer); 255 void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer); 256 void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer); 257 void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer); 258 void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer); 259 void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer); 260 void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer); 261 void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer); 262 void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer); 263 void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); 264 void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); 265 void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); 266 void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer); 267 void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer); 268 void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer); 269 void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer); 270 void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer); 271 void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer); 272 void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer); 273 void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer); 274 void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer); 275 void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer); 276 void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer); 277 void generate(pdl_interp::GetResultsOp op, 
ByteCodeWriter &writer); 278 void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer); 279 void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer); 280 void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer); 281 void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer); 282 void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer); 283 void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer); 284 void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer); 285 void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer); 286 void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer); 287 void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer); 288 void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer); 289 290 /// Mapping from value to its corresponding memory index. 291 DenseMap<Value, ByteCodeField> valueToMemIndex; 292 293 /// Mapping from a range value to its corresponding range storage index. 294 DenseMap<Value, ByteCodeField> valueToRangeIndex; 295 296 /// Mapping from the name of an externally registered rewrite to its index in 297 /// the bytecode registry. 298 llvm::StringMap<ByteCodeField> externalRewriterToMemIndex; 299 300 /// Mapping from the name of an externally registered constraint to its index 301 /// in the bytecode registry. 302 llvm::StringMap<ByteCodeField> constraintToMemIndex; 303 304 /// Mapping from rewriter function name to the bytecode address of the 305 /// rewriter function in byte. 306 llvm::StringMap<ByteCodeAddr> rewriterToAddr; 307 308 /// Mapping from a uniqued storage object to its memory index within 309 /// `uniquedData`. 310 DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex; 311 312 /// The current level of the foreach loop. 313 ByteCodeField curLoopLevel = 0; 314 315 /// The current MLIR context. 316 MLIRContext *ctx; 317 318 /// Mapping from block to its address. 
319 DenseMap<Block *, ByteCodeAddr> blockToAddr; 320 321 /// Data of the ByteCode class to be populated. 322 std::vector<const void *> &uniquedData; 323 SmallVectorImpl<ByteCodeField> &matcherByteCode; 324 SmallVectorImpl<ByteCodeField> &rewriterByteCode; 325 SmallVectorImpl<PDLByteCodePattern> &patterns; 326 ByteCodeField &maxValueMemoryIndex; 327 ByteCodeField &maxOpRangeMemoryIndex; 328 ByteCodeField &maxTypeRangeMemoryIndex; 329 ByteCodeField &maxValueRangeMemoryIndex; 330 ByteCodeField &maxLoopLevel; 331 }; 332 333 /// This class provides utilities for writing a bytecode stream. 334 struct ByteCodeWriter { 335 ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator) 336 : bytecode(bytecode), generator(generator) {} 337 338 /// Append a field to the bytecode. 339 void append(ByteCodeField field) { bytecode.push_back(field); } 340 void append(OpCode opCode) { bytecode.push_back(opCode); } 341 342 /// Append an address to the bytecode. 343 void append(ByteCodeAddr field) { 344 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 345 "unexpected ByteCode address size"); 346 347 ByteCodeField fieldParts[2]; 348 std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr)); 349 bytecode.append({fieldParts[0], fieldParts[1]}); 350 } 351 352 /// Append a single successor to the bytecode, the exact address will need to 353 /// be resolved later. 354 void append(Block *successor) { 355 // Add back a reference to the successor so that the address can be resolved 356 // later. 357 unresolvedSuccessorRefs[successor].push_back(bytecode.size()); 358 append(ByteCodeAddr(0)); 359 } 360 361 /// Append a successor range to the bytecode, the exact address will need to 362 /// be resolved later. 363 void append(SuccessorRange successors) { 364 for (Block *successor : successors) 365 append(successor); 366 } 367 368 /// Append a range of values that will be read as generic PDLValues. 
369 void appendPDLValueList(OperandRange values) { 370 bytecode.push_back(values.size()); 371 for (Value value : values) 372 appendPDLValue(value); 373 } 374 375 /// Append a value as a PDLValue. 376 void appendPDLValue(Value value) { 377 appendPDLValueKind(value); 378 append(value); 379 } 380 381 /// Append the PDLValue::Kind of the given value. 382 void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); } 383 384 /// Append the PDLValue::Kind of the given type. 385 void appendPDLValueKind(Type type) { 386 PDLValue::Kind kind = 387 TypeSwitch<Type, PDLValue::Kind>(type) 388 .Case<pdl::AttributeType>( 389 [](Type) { return PDLValue::Kind::Attribute; }) 390 .Case<pdl::OperationType>( 391 [](Type) { return PDLValue::Kind::Operation; }) 392 .Case<pdl::RangeType>([](pdl::RangeType rangeTy) { 393 if (rangeTy.getElementType().isa<pdl::TypeType>()) 394 return PDLValue::Kind::TypeRange; 395 return PDLValue::Kind::ValueRange; 396 }) 397 .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; }) 398 .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; }); 399 bytecode.push_back(static_cast<ByteCodeField>(kind)); 400 } 401 402 /// Append a value that will be stored in a memory slot and not inline within 403 /// the bytecode. 404 template <typename T> 405 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value || 406 std::is_pointer<T>::value> 407 append(T value) { 408 bytecode.push_back(generator.getMemIndex(value)); 409 } 410 411 /// Append a range of values. 412 template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>> 413 std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value> 414 append(T range) { 415 bytecode.push_back(llvm::size(range)); 416 for (auto it : range) 417 append(it); 418 } 419 420 /// Append a variadic number of fields to the bytecode. 421 template <typename FieldTy, typename Field2Ty, typename... FieldTys> 422 void append(FieldTy field, Field2Ty field2, FieldTys... 
fields) { 423 append(field); 424 append(field2, fields...); 425 } 426 427 /// Appends a value as a pointer, stored inline within the bytecode. 428 template <typename T> 429 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value> 430 appendInline(T value) { 431 constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField); 432 const void *pointer = value.getAsOpaquePointer(); 433 ByteCodeField fieldParts[numParts]; 434 std::memcpy(fieldParts, &pointer, sizeof(const void *)); 435 bytecode.append(fieldParts, fieldParts + numParts); 436 } 437 438 /// Successor references in the bytecode that have yet to be resolved. 439 DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs; 440 441 /// The underlying bytecode buffer. 442 SmallVectorImpl<ByteCodeField> &bytecode; 443 444 /// The main generator producing PDL. 445 Generator &generator; 446 }; 447 448 /// This class represents a live range of PDL Interpreter values, containing 449 /// information about when values are live within a match/rewrite. 450 struct ByteCodeLiveRange { 451 using Set = llvm::IntervalMap<uint64_t, char, 16>; 452 using Allocator = Set::Allocator; 453 454 ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {} 455 456 /// Union this live range with the one provided. 457 void unionWith(const ByteCodeLiveRange &rhs) { 458 for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e; 459 ++it) 460 liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0); 461 } 462 463 /// Returns true if this range overlaps with the one provided. 464 bool overlaps(const ByteCodeLiveRange &rhs) const { 465 return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness) 466 .valid(); 467 } 468 469 /// A map representing the ranges of the match/rewrite that a value is live in 470 /// the interpreter. 471 /// 472 /// We use std::unique_ptr here, because IntervalMap does not provide a 473 /// correct copy or move constructor. 
We can eliminate the pointer once 474 /// https://reviews.llvm.org/D113240 lands. 475 std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness; 476 477 /// The operation range storage index for this range. 478 Optional<unsigned> opRangeIndex; 479 480 /// The type range storage index for this range. 481 Optional<unsigned> typeRangeIndex; 482 483 /// The value range storage index for this range. 484 Optional<unsigned> valueRangeIndex; 485 }; 486 } // namespace 487 488 void Generator::generate(ModuleOp module) { 489 auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>( 490 pdl_interp::PDLInterpDialect::getMatcherFunctionName()); 491 ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>( 492 pdl_interp::PDLInterpDialect::getRewriterModuleName()); 493 assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module"); 494 495 // Allocate memory indices for the results of operations within the matcher 496 // and rewriters. 497 allocateMemoryIndices(matcherFunc, rewriterModule); 498 499 // Generate code for the rewriter functions. 500 ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this); 501 for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) { 502 rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size()); 503 for (Operation &op : rewriterFunc.getOps()) 504 generate(&op, rewriterByteCodeWriter); 505 } 506 assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() && 507 "unexpected branches in rewriter function"); 508 509 // Generate code for the matcher function. 510 ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this); 511 generate(&matcherFunc.getBody(), matcherByteCodeWriter); 512 513 // Resolve successor references in the matcher. 
514 for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) { 515 ByteCodeAddr addr = blockToAddr[it.first]; 516 for (unsigned offsetToFix : it.second) 517 std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr)); 518 } 519 } 520 521 void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc, 522 ModuleOp rewriterModule) { 523 // Rewriters use simplistic allocation scheme that simply assigns an index to 524 // each result. 525 for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) { 526 ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0; 527 auto processRewriterValue = [&](Value val) { 528 valueToMemIndex.try_emplace(val, index++); 529 if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) { 530 Type elementTy = rangeType.getElementType(); 531 if (elementTy.isa<pdl::TypeType>()) 532 valueToRangeIndex.try_emplace(val, typeRangeIndex++); 533 else if (elementTy.isa<pdl::ValueType>()) 534 valueToRangeIndex.try_emplace(val, valueRangeIndex++); 535 } 536 }; 537 538 for (BlockArgument arg : rewriterFunc.getArguments()) 539 processRewriterValue(arg); 540 rewriterFunc.getBody().walk([&](Operation *op) { 541 for (Value result : op->getResults()) 542 processRewriterValue(result); 543 }); 544 if (index > maxValueMemoryIndex) 545 maxValueMemoryIndex = index; 546 if (typeRangeIndex > maxTypeRangeMemoryIndex) 547 maxTypeRangeMemoryIndex = typeRangeIndex; 548 if (valueRangeIndex > maxValueRangeMemoryIndex) 549 maxValueRangeMemoryIndex = valueRangeIndex; 550 } 551 552 // The matcher function uses a more sophisticated numbering that tries to 553 // minimize the number of memory indices assigned. This is done by determining 554 // a live range of the values within the matcher, then the allocation is just 555 // finding the minimal number of overlapping live ranges. 
This is essentially 556 // a simplified form of register allocation where we don't necessarily have a 557 // limited number of registers, but we still want to minimize the number used. 558 DenseMap<Operation *, unsigned> opToFirstIndex; 559 DenseMap<Operation *, unsigned> opToLastIndex; 560 561 // A custom walk that marks the first and the last index of each operation. 562 // The entry marks the beginning of the liveness range for this operation, 563 // followed by nested operations, followed by the end of the liveness range. 564 unsigned index = 0; 565 llvm::unique_function<void(Operation *)> walk = [&](Operation *op) { 566 opToFirstIndex.try_emplace(op, index++); 567 for (Region ®ion : op->getRegions()) 568 for (Block &block : region.getBlocks()) 569 for (Operation &nested : block) 570 walk(&nested); 571 opToLastIndex.try_emplace(op, index++); 572 }; 573 walk(matcherFunc); 574 575 // Liveness info for each of the defs within the matcher. 576 ByteCodeLiveRange::Allocator allocator; 577 DenseMap<Value, ByteCodeLiveRange> valueDefRanges; 578 579 // Assign the root operation being matched to slot 0. 580 BlockArgument rootOpArg = matcherFunc.getArgument(0); 581 valueToMemIndex[rootOpArg] = 0; 582 583 // Walk each of the blocks, computing the def interval that the value is used. 584 Liveness matcherLiveness(matcherFunc); 585 matcherFunc->walk([&](Block *block) { 586 const LivenessBlockInfo *info = matcherLiveness.getLiveness(block); 587 assert(info && "expected liveness info for block"); 588 auto processValue = [&](Value value, Operation *firstUseOrDef) { 589 // We don't need to process the root op argument, this value is always 590 // assigned to the first memory slot. 591 if (value == rootOpArg) 592 return; 593 594 // Set indices for the range of this block that the value is used. 
595 auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first; 596 defRangeIt->second.liveness->insert( 597 opToFirstIndex[firstUseOrDef], 598 opToLastIndex[info->getEndOperation(value, firstUseOrDef)], 599 /*dummyValue*/ 0); 600 601 // Check to see if this value is a range type. 602 if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) { 603 Type eleType = rangeTy.getElementType(); 604 if (eleType.isa<pdl::OperationType>()) 605 defRangeIt->second.opRangeIndex = 0; 606 else if (eleType.isa<pdl::TypeType>()) 607 defRangeIt->second.typeRangeIndex = 0; 608 else if (eleType.isa<pdl::ValueType>()) 609 defRangeIt->second.valueRangeIndex = 0; 610 } 611 }; 612 613 // Process the live-ins of this block. 614 for (Value liveIn : info->in()) { 615 // Only process the value if it has been defined in the current region. 616 // Other values that span across pdl_interp.foreach will be added higher 617 // up. This ensures that the we keep them alive for the entire duration 618 // of the loop. 619 if (liveIn.getParentRegion() == block->getParent()) 620 processValue(liveIn, &block->front()); 621 } 622 623 // Process the block arguments for the entry block (those are not live-in). 624 if (block->isEntryBlock()) { 625 for (Value argument : block->getArguments()) 626 processValue(argument, &block->front()); 627 } 628 629 // Process any new defs within this block. 630 for (Operation &op : *block) 631 for (Value result : op.getResults()) 632 processValue(result, &op); 633 }); 634 635 // Greedily allocate memory slots using the computed def live ranges. 636 std::vector<ByteCodeLiveRange> allocatedIndices; 637 638 // The number of memory indices currently allocated (and its next value). 639 // Recall that the root gets allocated memory index 0. 640 ByteCodeField numIndices = 1; 641 642 // The number of memory ranges of various types (and their next values). 
643 ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0; 644 645 for (auto &defIt : valueDefRanges) { 646 ByteCodeField &memIndex = valueToMemIndex[defIt.first]; 647 ByteCodeLiveRange &defRange = defIt.second; 648 649 // Try to allocate to an existing index. 650 for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) { 651 ByteCodeLiveRange &existingRange = existingIndexIt.value(); 652 if (!defRange.overlaps(existingRange)) { 653 existingRange.unionWith(defRange); 654 memIndex = existingIndexIt.index() + 1; 655 656 if (defRange.opRangeIndex) { 657 if (!existingRange.opRangeIndex) 658 existingRange.opRangeIndex = numOpRanges++; 659 valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex; 660 } else if (defRange.typeRangeIndex) { 661 if (!existingRange.typeRangeIndex) 662 existingRange.typeRangeIndex = numTypeRanges++; 663 valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex; 664 } else if (defRange.valueRangeIndex) { 665 if (!existingRange.valueRangeIndex) 666 existingRange.valueRangeIndex = numValueRanges++; 667 valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex; 668 } 669 break; 670 } 671 } 672 673 // If no existing index could be used, add a new one. 674 if (memIndex == 0) { 675 allocatedIndices.emplace_back(allocator); 676 ByteCodeLiveRange &newRange = allocatedIndices.back(); 677 newRange.unionWith(defRange); 678 679 // Allocate an index for op/type/value ranges. 
680 if (defRange.opRangeIndex) { 681 newRange.opRangeIndex = numOpRanges; 682 valueToRangeIndex[defIt.first] = numOpRanges++; 683 } else if (defRange.typeRangeIndex) { 684 newRange.typeRangeIndex = numTypeRanges; 685 valueToRangeIndex[defIt.first] = numTypeRanges++; 686 } else if (defRange.valueRangeIndex) { 687 newRange.valueRangeIndex = numValueRanges; 688 valueToRangeIndex[defIt.first] = numValueRanges++; 689 } 690 691 memIndex = allocatedIndices.size(); 692 ++numIndices; 693 } 694 } 695 696 // Print the index usage and ensure that we did not run out of index space. 697 LLVM_DEBUG({ 698 llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices " 699 << "(down from initial " << valueDefRanges.size() << ").\n"; 700 }); 701 assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() && 702 "Ran out of memory for allocated indices"); 703 704 // Update the max number of indices. 705 if (numIndices > maxValueMemoryIndex) 706 maxValueMemoryIndex = numIndices; 707 if (numOpRanges > maxOpRangeMemoryIndex) 708 maxOpRangeMemoryIndex = numOpRanges; 709 if (numTypeRanges > maxTypeRangeMemoryIndex) 710 maxTypeRangeMemoryIndex = numTypeRanges; 711 if (numValueRanges > maxValueRangeMemoryIndex) 712 maxValueRangeMemoryIndex = numValueRanges; 713 } 714 715 void Generator::generate(Region *region, ByteCodeWriter &writer) { 716 llvm::ReversePostOrderTraversal<Region *> rpot(region); 717 for (Block *block : rpot) { 718 // Keep track of where this block begins within the matcher function. 719 blockToAddr.try_emplace(block, matcherByteCode.size()); 720 for (Operation &op : *block) 721 generate(&op, writer); 722 } 723 } 724 725 void Generator::generate(Operation *op, ByteCodeWriter &writer) { 726 LLVM_DEBUG({ 727 // The following list must contain all the operations that do not 728 // produce any bytecode. 
if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
      writer.appendInline(op->getLoc());
  });
  // Dispatch to the overload below that matches the concrete `pdl_interp`
  // operation; any other operation reaching this point is a programming error.
  TypeSwitch<Operation *>(op)
      .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
            pdl_interp::AreEqualOp, pdl_interp::BranchOp,
            pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
            pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
            pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
            pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
            pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp,
            pdl_interp::CreateTypesOp, pdl_interp::EraseOp,
            pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
            pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
            pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
            pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
            pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
            pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
            pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
            pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
            pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
            pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
            pdl_interp::SwitchResultCountOp>(
          [&](auto interpOp) { this->generate(interpOp, writer); })
      .Default([](Operation *) {
        llvm_unreachable("unknown `pdl_interp` operation");
      });
}

// Each overload below appends one operation's encoding to `writer`. The order
// of the appended fields is the wire format: it must match, field for field,
// the order in which the corresponding `ByteCodeExecutor::execute*` routine
// reads them back.

/// Emit `ApplyConstraint`: the index of the named constraint function, the PDL
/// argument list, and the operation's successors.
void Generator::generate(pdl_interp::ApplyConstraintOp op,
                         ByteCodeWriter &writer) {
  assert(constraintToMemIndex.count(op.getName()) &&
         "expected index for constraint function");
  writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
  writer.appendPDLValueList(op.getArgs());
  writer.append(op.getSuccessors());
}
/// Emit `ApplyRewrite`: the index of the external rewrite function, the PDL
/// argument list, the number of results, and for each result the slot(s) that
/// will receive it.
void Generator::generate(pdl_interp::ApplyRewriteOp op,
                         ByteCodeWriter &writer) {
  assert(externalRewriterToMemIndex.count(op.getName()) &&
         "expected index for rewrite function");
  writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
  writer.appendPDLValueList(op.getArgs());

  ResultRange results = op.getResults();
  writer.append(ByteCodeField(results.size()));
  for (Value result : results) {
    // In debug mode we also record the expected kind of the result, so that we
    // can provide extra verification of the native rewrite function.
#ifndef NDEBUG
    writer.appendPDLValueKind(result);
#endif

    // Range results also need to append the range storage index.
    if (result.getType().isa<pdl::RangeType>())
      writer.append(getRangeStorageIndex(result));
    writer.append(result);
  }
}
/// Emit an equality check. Range-typed operands use `AreRangesEqual`, which
/// also records the range kind so the executor knows how to compare them; all
/// other operands use `AreEqual`.
void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
  Value lhs = op.getLhs();
  if (lhs.getType().isa<pdl::RangeType>()) {
    writer.append(OpCode::AreRangesEqual);
    writer.appendPDLValueKind(lhs);
    writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
    return;
  }

  writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
}
/// Emit an unconditional branch to the operation's successor.
void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
}
/// Attribute checks reuse `AreEqual` against the constant attribute value.
void Generator::generate(pdl_interp::CheckAttributeOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
                op.getSuccessors());
}
/// Emit `CheckOperandCount`: the operation, the expected count, whether the
/// comparison is ">=" (otherwise "=="), and the successors.
void Generator::generate(pdl_interp::CheckOperandCountOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
                static_cast<ByteCodeField>(op.getCompareAtLeast()),
                op.getSuccessors());
}
/// Emit `CheckOperationName` with the expected name resolved to an
/// `OperationName` in this context.
void Generator::generate(pdl_interp::CheckOperationNameOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::CheckOperationName, op.getInputOp(),
                OperationName(op.getName(), ctx), op.getSuccessors());
}
/// Emit `CheckResultCount`, laid out like `CheckOperandCount`.
void Generator::generate(pdl_interp::CheckResultCountOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
                static_cast<ByteCodeField>(op.getCompareAtLeast()),
                op.getSuccessors());
}
/// Single-type checks reuse `AreEqual` against the constant type.
void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
                op.getSuccessors());
}
/// Emit `CheckTypes`: a type range compared against a constant array of types.
void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
                op.getSuccessors());
}
/// Emit `Continue` for the innermost enclosing loop (levels are zero-based, so
/// the current level minus one).
void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
  assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
  writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
}
/// Emits no bytecode: the result simply aliases the constant's memory slot.
void Generator::generate(pdl_interp::CreateAttributeOp op,
                         ByteCodeWriter &writer) {
  // Simply repoint the memory index of the result to the constant.
  getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
}
/// Emit `CreateOperation`: the result slot, the operation name, the operand
/// list, the (name, value) attribute pairs, and either the explicit result
/// types or `kInferTypesMarker`.
void Generator::generate(pdl_interp::CreateOperationOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::CreateOperation, op.getResultOp(),
                OperationName(op.getName(), ctx));
  writer.appendPDLValueList(op.getInputOperands());

  // Add the attributes.
  OperandRange attributes = op.getInputAttributes();
  writer.append(static_cast<ByteCodeField>(attributes.size()));
  for (auto it : llvm::zip(op.getInputAttributeNames(), attributes))
    writer.append(std::get<0>(it), std::get<1>(it));

  // Add the result types. If the operation has inferred results, we use a
  // marker "size" value. Otherwise, we add the list of explicit result types.
  if (op.getInferredResultTypes())
    writer.append(kInferTypesMarker);
  else
    writer.appendPDLValueList(op.getInputResultTypes());
}
/// Emits no bytecode: the result simply aliases the constant's memory slot.
void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
  // Simply repoint the memory index of the result to the constant.
  getMemIndex(op.getResult()) = getMemIndex(op.getValue());
}
/// Emit `CreateTypes`: the result slot, its range storage index, and the
/// constant array of types to materialize.
void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::CreateTypes, op.getResult(),
                getRangeStorageIndex(op.getResult()), op.getValue());
}
/// Emit `EraseOp` for the input operation.
void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::EraseOp, op.getInputOp());
}
/// Emit the element-kind-specific extract opcode, followed by the range, the
/// element index, and the result slot.
void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
  OpCode opCode =
      TypeSwitch<Type, OpCode>(op.getResult().getType())
          .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
          .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
          .Case([](pdl::TypeType) { return OpCode::ExtractType; })
          .Default([](Type) -> OpCode {
            llvm_unreachable("unsupported element type");
          });
  writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
}
/// Emit `Finalize`, which carries no operands.
void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::Finalize);
}
/// Emit `ForEach` (range storage index, loop variable slot, element kind, loop
/// level, exit successor), then the loop body at one deeper loop level. The
/// deepest level seen is tracked in `maxLoopLevel` to size the loop-index
/// storage.
void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
  BlockArgument arg = op.getLoopVariable();
  writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
  writer.appendPDLValueKind(arg.getType());
  writer.append(curLoopLevel, op.getSuccessor());
  ++curLoopLevel;
  if (curLoopLevel > maxLoopLevel)
    maxLoopLevel = curLoopLevel;
  generate(&op.getRegion(), writer);
  --curLoopLevel;
}
/// Emit `GetAttribute`: the result slot, the operation, and the attribute name.
void Generator::generate(pdl_interp::GetAttributeOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
                op.getNameAttr());
}
/// Emit `GetAttributeType`: the result slot and the attribute.
void Generator::generate(pdl_interp::GetAttributeTypeOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
}
/// Emit `GetDefiningOp`: the result slot, then the value (or value range)
/// tagged with its PDL kind.
void Generator::generate(pdl_interp::GetDefiningOpOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::GetDefiningOp, op.getInputOp());
  writer.appendPDLValue(op.getValue());
}
/// Indices 0-3 use the dedicated `GetOperand0`..`GetOperand3` opcodes (the
/// index is implied by the opcode); larger indices use `GetOperandN` with an
/// explicit index.
void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
  uint32_t index = op.getIndex();
  if (index < 4)
    writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
  else
    writer.append(OpCode::GetOperandN, index);
  writer.append(op.getInputOp(), op.getValue());
}
/// Emit `GetOperands`. A missing index is encoded as the `uint32_t` max
/// sentinel, which the executor treats as "all operands". A non-range result
/// writes the `ByteCodeField` max value where a range storage index would go
/// (presumably meaning "no range storage"; verify against the executor).
void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
  Value result = op.getValue();
  Optional<uint32_t> index = op.getIndex();
  writer.append(OpCode::GetOperands,
                index.value_or(std::numeric_limits<uint32_t>::max()),
                op.getInputOp());
  if (result.getType().isa<pdl::RangeType>())
    writer.append(getRangeStorageIndex(result));
  else
    writer.append(std::numeric_limits<ByteCodeField>::max());
  writer.append(result);
}
/// Result counterpart of the `GetOperandOp` encoding (specialized opcodes for
/// indices 0-3, `GetResultN` otherwise).
void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
  uint32_t index = op.getIndex();
  if (index < 4)
    writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
  else
    writer.append(OpCode::GetResultN, index);
  writer.append(op.getInputOp(), op.getValue());
}
/// Result counterpart of the `GetOperandsOp` encoding.
void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
  Value result = op.getValue();
  Optional<uint32_t> index = op.getIndex();
  writer.append(OpCode::GetResults,
                index.value_or(std::numeric_limits<uint32_t>::max()),
                op.getInputOp());
  if (result.getType().isa<pdl::RangeType>())
    writer.append(getRangeStorageIndex(result));
  else
    writer.append(std::numeric_limits<ByteCodeField>::max());
  writer.append(result);
}
/// Emit `GetUsers`: the result slot, its op-range storage index, then the value
/// (or value range) tagged with its PDL kind.
void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
  Value operations = op.getOperations();
  ByteCodeField rangeIndex = getRangeStorageIndex(operations);
  writer.append(OpCode::GetUsers, operations, rangeIndex);
  writer.appendPDLValue(op.getValue());
}
/// Range results use `GetValueRangeTypes` (which also needs a range storage
/// index); single values use `GetValueType`.
void Generator::generate(pdl_interp::GetValueTypeOp op,
                         ByteCodeWriter &writer) {
  if (op.getType().isa<pdl::RangeType>()) {
    Value result = op.getResult();
    writer.append(OpCode::GetValueRangeTypes, result,
                  getRangeStorageIndex(result), op.getValue());
  } else {
    writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
  }
}
/// Emit `IsNotNull` for the value and its successors.
void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
}
/// Register a new pattern (pointing at its rewriter's bytecode address) and
/// emit `RecordMatch` with that pattern's index, the successors, the matched
/// operations, and the inputs to forward to the rewriter.
void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
  ByteCodeField patternIndex = patterns.size();
  patterns.emplace_back(PDLByteCodePattern::create(
      op, rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
  writer.append(OpCode::RecordMatch, patternIndex,
                SuccessorRange(op.getOperation()), op.getMatchedOps());
  writer.appendPDLValueList(op.getInputs());
}
/// Emit `ReplaceOp`: the operation to replace and its replacement values.
void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::ReplaceOp, op.getInputOp());
  writer.appendPDLValueList(op.getReplValues());
}
// The switch encodings below all emit: opcode, the value being switched on,
// the case values, and the successors.
void Generator::generate(pdl_interp::SwitchAttributeOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchAttribute, op.getAttribute(),
                op.getCaseValuesAttr(), op.getSuccessors());
}
void Generator::generate(pdl_interp::SwitchOperandCountOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
                op.getCaseValuesAttr(), op.getSuccessors());
}
/// Case names are resolved to `OperationName`s in this context before being
/// written.
void Generator::generate(pdl_interp::SwitchOperationNameOp op,
                         ByteCodeWriter &writer) {
  auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) {
    return OperationName(attr.cast<StringAttr>().getValue(), ctx);
  });
  writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
                op.getSuccessors());
}
void Generator::generate(pdl_interp::SwitchResultCountOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchResultCount, op.getInputOp(),
                op.getCaseValuesAttr(), op.getSuccessors());
}
void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
                op.getSuccessors());
}
void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
                op.getSuccessors());
}

//===----------------------------------------------------------------------===//
// PDLByteCode
//===----------------------------------------------------------------------===//

/// Generate the matcher and rewriter bytecode for `module`. The `Generator` is
/// handed this object's bytecode buffers, pattern list, and memory/range/loop
/// size counters to fill in. The external constraint and rewrite functions are
/// then moved into flat vectors that the executor indexes by position.
PDLByteCode::PDLByteCode(ModuleOp module,
                         llvm::StringMap<PDLConstraintFunction> constraintFns,
                         llvm::StringMap<PDLRewriteFunction> rewriteFns) {
  Generator generator(module.getContext(), uniquedData, matcherByteCode,
                      rewriterByteCode, patterns, maxValueMemoryIndex,
                      maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
                      maxLoopLevel, constraintFns, rewriteFns);
  generator.generate(module);

  // Initialize the external functions.
  // NOTE(review): the executor looks these up by the index the generator wrote
  // into the bytecode, so this relies on the StringMap iteration order here
  // matching the order the generator used to assign those indices. Confirm in
  // the Generator constructor.
  for (auto &it : constraintFns)
    constraintFunctions.push_back(std::move(it.second));
  for (auto &it : rewriteFns)
    rewriteFunctions.push_back(std::move(it.second));
}

/// Initialize the given state such that it can be used to execute the current
/// bytecode: its value, range, and loop-index storage are sized to the maxima
/// recorded during generation, and each pattern's benefit is seeded.
1035 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const { 1036 state.memory.resize(maxValueMemoryIndex, nullptr); 1037 state.opRangeMemory.resize(maxOpRangeCount); 1038 state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange()); 1039 state.valueRangeMemory.resize(maxValueRangeCount, ValueRange()); 1040 state.loopIndex.resize(maxLoopLevel, 0); 1041 state.currentPatternBenefits.reserve(patterns.size()); 1042 for (const PDLByteCodePattern &pattern : patterns) 1043 state.currentPatternBenefits.push_back(pattern.getBenefit()); 1044 } 1045 1046 //===----------------------------------------------------------------------===// 1047 // ByteCode Execution 1048 1049 namespace { 1050 /// This class provides support for executing a bytecode stream. 1051 class ByteCodeExecutor { 1052 public: 1053 ByteCodeExecutor( 1054 const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory, 1055 MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory, 1056 MutableArrayRef<TypeRange> typeRangeMemory, 1057 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory, 1058 MutableArrayRef<ValueRange> valueRangeMemory, 1059 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory, 1060 MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory, 1061 ArrayRef<ByteCodeField> code, 1062 ArrayRef<PatternBenefit> currentPatternBenefits, 1063 ArrayRef<PDLByteCodePattern> patterns, 1064 ArrayRef<PDLConstraintFunction> constraintFunctions, 1065 ArrayRef<PDLRewriteFunction> rewriteFunctions) 1066 : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory), 1067 typeRangeMemory(typeRangeMemory), 1068 allocatedTypeRangeMemory(allocatedTypeRangeMemory), 1069 valueRangeMemory(valueRangeMemory), 1070 allocatedValueRangeMemory(allocatedValueRangeMemory), 1071 loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code), 1072 currentPatternBenefits(currentPatternBenefits), patterns(patterns), 1073 
constraintFunctions(constraintFunctions), 1074 rewriteFunctions(rewriteFunctions) {} 1075 1076 /// Start executing the code at the current bytecode index. `matches` is an 1077 /// optional field provided when this function is executed in a matching 1078 /// context. 1079 void execute(PatternRewriter &rewriter, 1080 SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr, 1081 Optional<Location> mainRewriteLoc = {}); 1082 1083 private: 1084 /// Internal implementation of executing each of the bytecode commands. 1085 void executeApplyConstraint(PatternRewriter &rewriter); 1086 void executeApplyRewrite(PatternRewriter &rewriter); 1087 void executeAreEqual(); 1088 void executeAreRangesEqual(); 1089 void executeBranch(); 1090 void executeCheckOperandCount(); 1091 void executeCheckOperationName(); 1092 void executeCheckResultCount(); 1093 void executeCheckTypes(); 1094 void executeContinue(); 1095 void executeCreateOperation(PatternRewriter &rewriter, 1096 Location mainRewriteLoc); 1097 void executeCreateTypes(); 1098 void executeEraseOp(PatternRewriter &rewriter); 1099 template <typename T, typename Range, PDLValue::Kind kind> 1100 void executeExtract(); 1101 void executeFinalize(); 1102 void executeForEach(); 1103 void executeGetAttribute(); 1104 void executeGetAttributeType(); 1105 void executeGetDefiningOp(); 1106 void executeGetOperand(unsigned index); 1107 void executeGetOperands(); 1108 void executeGetResult(unsigned index); 1109 void executeGetResults(); 1110 void executeGetUsers(); 1111 void executeGetValueType(); 1112 void executeGetValueRangeTypes(); 1113 void executeIsNotNull(); 1114 void executeRecordMatch(PatternRewriter &rewriter, 1115 SmallVectorImpl<PDLByteCode::MatchResult> &matches); 1116 void executeReplaceOp(PatternRewriter &rewriter); 1117 void executeSwitchAttribute(); 1118 void executeSwitchOperandCount(); 1119 void executeSwitchOperationName(); 1120 void executeSwitchResultCount(); 1121 void executeSwitchType(); 1122 void 
executeSwitchTypes(); 1123 1124 /// Pushes a code iterator to the stack. 1125 void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); } 1126 1127 /// Pops a code iterator from the stack, returning true on success. 1128 void popCodeIt() { 1129 assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack"); 1130 curCodeIt = resumeCodeIt.back(); 1131 resumeCodeIt.pop_back(); 1132 } 1133 1134 /// Return the bytecode iterator at the start of the current op code. 1135 const ByteCodeField *getPrevCodeIt() const { 1136 LLVM_DEBUG({ 1137 // Account for the op code and the Location stored inline. 1138 return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField); 1139 }); 1140 1141 // Account for the op code only. 1142 return curCodeIt - 1; 1143 } 1144 1145 /// Read a value from the bytecode buffer, optionally skipping a certain 1146 /// number of prefix values. These methods always update the buffer to point 1147 /// to the next field after the read data. 1148 template <typename T = ByteCodeField> 1149 T read(size_t skipN = 0) { 1150 curCodeIt += skipN; 1151 return readImpl<T>(); 1152 } 1153 ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); } 1154 1155 /// Read a list of values from the bytecode buffer. 1156 template <typename ValueT, typename T> 1157 void readList(SmallVectorImpl<T> &list) { 1158 list.clear(); 1159 for (unsigned i = 0, e = read(); i != e; ++i) 1160 list.push_back(read<ValueT>()); 1161 } 1162 1163 /// Read a list of values from the bytecode buffer. The values may be encoded 1164 /// as either Value or ValueRange elements. 
1165 void readValueList(SmallVectorImpl<Value> &list) { 1166 for (unsigned i = 0, e = read(); i != e; ++i) { 1167 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1168 list.push_back(read<Value>()); 1169 } else { 1170 ValueRange *values = read<ValueRange *>(); 1171 list.append(values->begin(), values->end()); 1172 } 1173 } 1174 } 1175 1176 /// Read a value stored inline as a pointer. 1177 template <typename T> 1178 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T> 1179 readInline() { 1180 const void *pointer; 1181 std::memcpy(&pointer, curCodeIt, sizeof(const void *)); 1182 curCodeIt += sizeof(const void *) / sizeof(ByteCodeField); 1183 return T::getFromOpaquePointer(pointer); 1184 } 1185 1186 /// Jump to a specific successor based on a predicate value. 1187 void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); } 1188 /// Jump to a specific successor based on a destination index. 1189 void selectJump(size_t destIndex) { 1190 curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)]; 1191 } 1192 1193 /// Handle a switch operation with the provided value and cases. 1194 template <typename T, typename RangeT, typename Comparator = std::equal_to<T>> 1195 void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) { 1196 LLVM_DEBUG({ 1197 llvm::dbgs() << " * Value: " << value << "\n" 1198 << " * Cases: "; 1199 llvm::interleaveComma(cases, llvm::dbgs()); 1200 llvm::dbgs() << "\n"; 1201 }); 1202 1203 // Check to see if the attribute value is within the case list. Jump to 1204 // the correct successor index based on the result. 1205 for (auto it = cases.begin(), e = cases.end(); it != e; ++it) 1206 if (cmp(*it, value)) 1207 return selectJump(size_t((it - cases.begin()) + 1)); 1208 selectJump(size_t(0)); 1209 } 1210 1211 /// Store a pointer to memory. 1212 void storeToMemory(unsigned index, const void *value) { 1213 memory[index] = value; 1214 } 1215 1216 /// Store a value to memory as an opaque pointer. 
1217 template <typename T> 1218 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value> 1219 storeToMemory(unsigned index, T value) { 1220 memory[index] = value.getAsOpaquePointer(); 1221 } 1222 1223 /// Internal implementation of reading various data types from the bytecode 1224 /// stream. 1225 template <typename T> 1226 const void *readFromMemory() { 1227 size_t index = *curCodeIt++; 1228 1229 // If this type is an SSA value, it can only be stored in non-const memory. 1230 if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *, 1231 Value>::value || 1232 index < memory.size()) 1233 return memory[index]; 1234 1235 // Otherwise, if this index is not inbounds it is uniqued. 1236 return uniquedMemory[index - memory.size()]; 1237 } 1238 template <typename T> 1239 std::enable_if_t<std::is_pointer<T>::value, T> readImpl() { 1240 return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>())); 1241 } 1242 template <typename T> 1243 std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value, 1244 T> 1245 readImpl() { 1246 return T(T::getFromOpaquePointer(readFromMemory<T>())); 1247 } 1248 template <typename T> 1249 std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() { 1250 switch (read<PDLValue::Kind>()) { 1251 case PDLValue::Kind::Attribute: 1252 return read<Attribute>(); 1253 case PDLValue::Kind::Operation: 1254 return read<Operation *>(); 1255 case PDLValue::Kind::Type: 1256 return read<Type>(); 1257 case PDLValue::Kind::Value: 1258 return read<Value>(); 1259 case PDLValue::Kind::TypeRange: 1260 return read<TypeRange *>(); 1261 case PDLValue::Kind::ValueRange: 1262 return read<ValueRange *>(); 1263 } 1264 llvm_unreachable("unhandled PDLValue::Kind"); 1265 } 1266 template <typename T> 1267 std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() { 1268 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 1269 "unexpected ByteCode address size"); 1270 ByteCodeAddr result; 1271 
std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr)); 1272 curCodeIt += 2; 1273 return result; 1274 } 1275 template <typename T> 1276 std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() { 1277 return *curCodeIt++; 1278 } 1279 template <typename T> 1280 std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() { 1281 return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>()); 1282 } 1283 1284 /// The underlying bytecode buffer. 1285 const ByteCodeField *curCodeIt; 1286 1287 /// The stack of bytecode positions at which to resume operation. 1288 SmallVector<const ByteCodeField *> resumeCodeIt; 1289 1290 /// The current execution memory. 1291 MutableArrayRef<const void *> memory; 1292 MutableArrayRef<OwningOpRange> opRangeMemory; 1293 MutableArrayRef<TypeRange> typeRangeMemory; 1294 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory; 1295 MutableArrayRef<ValueRange> valueRangeMemory; 1296 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory; 1297 1298 /// The current loop indices. 1299 MutableArrayRef<unsigned> loopIndex; 1300 1301 /// References to ByteCode data necessary for execution. 1302 ArrayRef<const void *> uniquedMemory; 1303 ArrayRef<ByteCodeField> code; 1304 ArrayRef<PatternBenefit> currentPatternBenefits; 1305 ArrayRef<PDLByteCodePattern> patterns; 1306 ArrayRef<PDLConstraintFunction> constraintFunctions; 1307 ArrayRef<PDLRewriteFunction> rewriteFunctions; 1308 }; 1309 1310 /// This class is an instantiation of the PDLResultList that provides access to 1311 /// the returned results. This API is not on `PDLResultList` to avoid 1312 /// overexposing access to information specific solely to the ByteCode. 1313 class ByteCodeRewriteResultList : public PDLResultList { 1314 public: 1315 ByteCodeRewriteResultList(unsigned maxNumResults) 1316 : PDLResultList(maxNumResults) {} 1317 1318 /// Return the list of PDL results. 
1319 MutableArrayRef<PDLValue> getResults() { return results; } 1320 1321 /// Return the type ranges allocated by this list. 1322 MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() { 1323 return allocatedTypeRanges; 1324 } 1325 1326 /// Return the value ranges allocated by this list. 1327 MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() { 1328 return allocatedValueRanges; 1329 } 1330 }; 1331 } // namespace 1332 1333 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { 1334 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); 1335 const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; 1336 SmallVector<PDLValue, 16> args; 1337 readList<PDLValue>(args); 1338 1339 LLVM_DEBUG({ 1340 llvm::dbgs() << " * Arguments: "; 1341 llvm::interleaveComma(args, llvm::dbgs()); 1342 }); 1343 1344 // Invoke the constraint and jump to the proper destination. 1345 selectJump(succeeded(constraintFn(rewriter, args))); 1346 } 1347 1348 void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { 1349 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); 1350 const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; 1351 SmallVector<PDLValue, 16> args; 1352 readList<PDLValue>(args); 1353 1354 LLVM_DEBUG({ 1355 llvm::dbgs() << " * Arguments: "; 1356 llvm::interleaveComma(args, llvm::dbgs()); 1357 }); 1358 1359 // Execute the rewrite function. 1360 ByteCodeField numResults = read(); 1361 ByteCodeRewriteResultList results(numResults); 1362 rewriteFn(rewriter, results, args); 1363 1364 assert(results.getResults().size() == numResults && 1365 "native PDL rewrite function returned unexpected number of results"); 1366 1367 // Store the results in the bytecode memory. 1368 for (PDLValue &result : results.getResults()) { 1369 LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n"); 1370 1371 // In debug mode we also verify the expected kind of the result. 
1372 #ifndef NDEBUG 1373 assert(result.getKind() == read<PDLValue::Kind>() && 1374 "native PDL rewrite function returned an unexpected type of result"); 1375 #endif 1376 1377 // If the result is a range, we need to copy it over to the bytecodes 1378 // range memory. 1379 if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) { 1380 unsigned rangeIndex = read(); 1381 typeRangeMemory[rangeIndex] = *typeRange; 1382 memory[read()] = &typeRangeMemory[rangeIndex]; 1383 } else if (Optional<ValueRange> valueRange = 1384 result.dyn_cast<ValueRange>()) { 1385 unsigned rangeIndex = read(); 1386 valueRangeMemory[rangeIndex] = *valueRange; 1387 memory[read()] = &valueRangeMemory[rangeIndex]; 1388 } else { 1389 memory[read()] = result.getAsOpaquePointer(); 1390 } 1391 } 1392 1393 // Copy over any underlying storage allocated for result ranges. 1394 for (auto &it : results.getAllocatedTypeRanges()) 1395 allocatedTypeRangeMemory.push_back(std::move(it)); 1396 for (auto &it : results.getAllocatedValueRanges()) 1397 allocatedValueRangeMemory.push_back(std::move(it)); 1398 } 1399 1400 void ByteCodeExecutor::executeAreEqual() { 1401 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 1402 const void *lhs = read<const void *>(); 1403 const void *rhs = read<const void *>(); 1404 1405 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n"); 1406 selectJump(lhs == rhs); 1407 } 1408 1409 void ByteCodeExecutor::executeAreRangesEqual() { 1410 LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n"); 1411 PDLValue::Kind valueKind = read<PDLValue::Kind>(); 1412 const void *lhs = read<const void *>(); 1413 const void *rhs = read<const void *>(); 1414 1415 switch (valueKind) { 1416 case PDLValue::Kind::TypeRange: { 1417 const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs); 1418 const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs); 1419 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1420 selectJump(*lhsRange == 
*rhsRange); 1421 break; 1422 } 1423 case PDLValue::Kind::ValueRange: { 1424 const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs); 1425 const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs); 1426 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1427 selectJump(*lhsRange == *rhsRange); 1428 break; 1429 } 1430 default: 1431 llvm_unreachable("unexpected `AreRangesEqual` value kind"); 1432 } 1433 } 1434 1435 void ByteCodeExecutor::executeBranch() { 1436 LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n"); 1437 curCodeIt = &code[read<ByteCodeAddr>()]; 1438 } 1439 1440 void ByteCodeExecutor::executeCheckOperandCount() { 1441 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n"); 1442 Operation *op = read<Operation *>(); 1443 uint32_t expectedCount = read<uint32_t>(); 1444 bool compareAtLeast = read(); 1445 1446 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n" 1447 << " * Expected: " << expectedCount << "\n" 1448 << " * Comparator: " 1449 << (compareAtLeast ? 
">=" : "==") << "\n"); 1450 if (compareAtLeast) 1451 selectJump(op->getNumOperands() >= expectedCount); 1452 else 1453 selectJump(op->getNumOperands() == expectedCount); 1454 } 1455 1456 void ByteCodeExecutor::executeCheckOperationName() { 1457 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n"); 1458 Operation *op = read<Operation *>(); 1459 OperationName expectedName = read<OperationName>(); 1460 1461 LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n" 1462 << " * Expected: \"" << expectedName << "\"\n"); 1463 selectJump(op->getName() == expectedName); 1464 } 1465 1466 void ByteCodeExecutor::executeCheckResultCount() { 1467 LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n"); 1468 Operation *op = read<Operation *>(); 1469 uint32_t expectedCount = read<uint32_t>(); 1470 bool compareAtLeast = read(); 1471 1472 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n" 1473 << " * Expected: " << expectedCount << "\n" 1474 << " * Comparator: " 1475 << (compareAtLeast ? 
">=" : "==") << "\n"); 1476 if (compareAtLeast) 1477 selectJump(op->getNumResults() >= expectedCount); 1478 else 1479 selectJump(op->getNumResults() == expectedCount); 1480 } 1481 1482 void ByteCodeExecutor::executeCheckTypes() { 1483 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 1484 TypeRange *lhs = read<TypeRange *>(); 1485 Attribute rhs = read<Attribute>(); 1486 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1487 1488 selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>()); 1489 } 1490 1491 void ByteCodeExecutor::executeContinue() { 1492 ByteCodeField level = read(); 1493 LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n" 1494 << " * Level: " << level << "\n"); 1495 ++loopIndex[level]; 1496 popCodeIt(); 1497 } 1498 1499 void ByteCodeExecutor::executeCreateTypes() { 1500 LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n"); 1501 unsigned memIndex = read(); 1502 unsigned rangeIndex = read(); 1503 ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>(); 1504 1505 LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n"); 1506 1507 // Allocate a buffer for this type range. 1508 llvm::OwningArrayRef<Type> storage(typesAttr.size()); 1509 llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin()); 1510 allocatedTypeRangeMemory.emplace_back(std::move(storage)); 1511 1512 // Assign this to the range slot and use the range as the value for the 1513 // memory index. 
1514 typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back(); 1515 memory[memIndex] = &typeRangeMemory[rangeIndex]; 1516 } 1517 1518 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, 1519 Location mainRewriteLoc) { 1520 LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n"); 1521 1522 unsigned memIndex = read(); 1523 OperationState state(mainRewriteLoc, read<OperationName>()); 1524 readValueList(state.operands); 1525 for (unsigned i = 0, e = read(); i != e; ++i) { 1526 StringAttr name = read<StringAttr>(); 1527 if (Attribute attr = read<Attribute>()) 1528 state.addAttribute(name, attr); 1529 } 1530 1531 // Read in the result types. If the "size" is the sentinel value, this 1532 // indicates that the result types should be inferred. 1533 unsigned numResults = read(); 1534 if (numResults == kInferTypesMarker) { 1535 InferTypeOpInterface::Concept *inferInterface = 1536 state.name.getRegisteredInfo()->getInterface<InferTypeOpInterface>(); 1537 assert(inferInterface && 1538 "expected operation to provide InferTypeOpInterface"); 1539 1540 // TODO: Handle failure. 1541 if (failed(inferInterface->inferReturnTypes( 1542 state.getContext(), state.location, state.operands, 1543 state.attributes.getDictionary(state.getContext()), state.regions, 1544 state.types))) 1545 return; 1546 } else { 1547 // Otherwise, this is a fixed number of results. 
1548 for (unsigned i = 0; i != numResults; ++i) { 1549 if (read<PDLValue::Kind>() == PDLValue::Kind::Type) { 1550 state.types.push_back(read<Type>()); 1551 } else { 1552 TypeRange *resultTypes = read<TypeRange *>(); 1553 state.types.append(resultTypes->begin(), resultTypes->end()); 1554 } 1555 } 1556 } 1557 1558 Operation *resultOp = rewriter.create(state); 1559 memory[memIndex] = resultOp; 1560 1561 LLVM_DEBUG({ 1562 llvm::dbgs() << " * Attributes: " 1563 << state.attributes.getDictionary(state.getContext()) 1564 << "\n * Operands: "; 1565 llvm::interleaveComma(state.operands, llvm::dbgs()); 1566 llvm::dbgs() << "\n * Result Types: "; 1567 llvm::interleaveComma(state.types, llvm::dbgs()); 1568 llvm::dbgs() << "\n * Result: " << *resultOp << "\n"; 1569 }); 1570 } 1571 1572 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) { 1573 LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n"); 1574 Operation *op = read<Operation *>(); 1575 1576 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1577 rewriter.eraseOp(op); 1578 } 1579 1580 template <typename T, typename Range, PDLValue::Kind kind> 1581 void ByteCodeExecutor::executeExtract() { 1582 LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n"); 1583 Range *range = read<Range *>(); 1584 unsigned index = read<uint32_t>(); 1585 unsigned memIndex = read(); 1586 1587 if (!range) { 1588 memory[memIndex] = nullptr; 1589 return; 1590 } 1591 1592 T result = index < range->size() ? 
(*range)[index] : T(); 1593 LLVM_DEBUG(llvm::dbgs() << " * " << kind << "s(" << range->size() << ")\n" 1594 << " * Index: " << index << "\n" 1595 << " * Result: " << result << "\n"); 1596 storeToMemory(memIndex, result); 1597 } 1598 1599 void ByteCodeExecutor::executeFinalize() { 1600 LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n"); 1601 } 1602 1603 void ByteCodeExecutor::executeForEach() { 1604 LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n"); 1605 const ByteCodeField *prevCodeIt = getPrevCodeIt(); 1606 unsigned rangeIndex = read(); 1607 unsigned memIndex = read(); 1608 const void *value = nullptr; 1609 1610 switch (read<PDLValue::Kind>()) { 1611 case PDLValue::Kind::Operation: { 1612 unsigned &index = loopIndex[read()]; 1613 ArrayRef<Operation *> array = opRangeMemory[rangeIndex]; 1614 assert(index <= array.size() && "iterated past the end"); 1615 if (index < array.size()) { 1616 LLVM_DEBUG(llvm::dbgs() << " * Result: " << array[index] << "\n"); 1617 value = array[index]; 1618 break; 1619 } 1620 1621 LLVM_DEBUG(llvm::dbgs() << " * Done\n"); 1622 index = 0; 1623 selectJump(size_t(0)); 1624 return; 1625 } 1626 default: 1627 llvm_unreachable("unexpected `ForEach` value kind"); 1628 } 1629 1630 // Store the iterate value and the stack address. 1631 memory[memIndex] = value; 1632 pushCodeIt(prevCodeIt); 1633 1634 // Skip over the successor (we will enter the body of the loop). 
1635 read<ByteCodeAddr>(); 1636 } 1637 1638 void ByteCodeExecutor::executeGetAttribute() { 1639 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n"); 1640 unsigned memIndex = read(); 1641 Operation *op = read<Operation *>(); 1642 StringAttr attrName = read<StringAttr>(); 1643 Attribute attr = op->getAttr(attrName); 1644 1645 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1646 << " * Attribute: " << attrName << "\n" 1647 << " * Result: " << attr << "\n"); 1648 memory[memIndex] = attr.getAsOpaquePointer(); 1649 } 1650 1651 void ByteCodeExecutor::executeGetAttributeType() { 1652 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n"); 1653 unsigned memIndex = read(); 1654 Attribute attr = read<Attribute>(); 1655 Type type; 1656 if (auto typedAttr = attr.dyn_cast<TypedAttr>()) 1657 type = typedAttr.getType(); 1658 1659 LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n" 1660 << " * Result: " << type << "\n"); 1661 memory[memIndex] = type.getAsOpaquePointer(); 1662 } 1663 1664 void ByteCodeExecutor::executeGetDefiningOp() { 1665 LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n"); 1666 unsigned memIndex = read(); 1667 Operation *op = nullptr; 1668 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1669 Value value = read<Value>(); 1670 if (value) 1671 op = value.getDefiningOp(); 1672 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1673 } else { 1674 ValueRange *values = read<ValueRange *>(); 1675 if (values && !values->empty()) { 1676 op = values->front().getDefiningOp(); 1677 } 1678 LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n"); 1679 } 1680 1681 LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n"); 1682 memory[memIndex] = op; 1683 } 1684 1685 void ByteCodeExecutor::executeGetOperand(unsigned index) { 1686 Operation *op = read<Operation *>(); 1687 unsigned memIndex = read(); 1688 Value operand = 1689 index < op->getNumOperands() ? 
op->getOperand(index) : Value(); 1690 1691 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1692 << " * Index: " << index << "\n" 1693 << " * Result: " << operand << "\n"); 1694 memory[memIndex] = operand.getAsOpaquePointer(); 1695 } 1696 1697 /// This function is the internal implementation of `GetResults` and 1698 /// `GetOperands` that provides support for extracting a value range from the 1699 /// given operation. 1700 template <template <typename> class AttrSizedSegmentsT, typename RangeT> 1701 static void * 1702 executeGetOperandsResults(RangeT values, Operation *op, unsigned index, 1703 ByteCodeField rangeIndex, StringRef attrSizedSegments, 1704 MutableArrayRef<ValueRange> valueRangeMemory) { 1705 // Check for the sentinel index that signals that all values should be 1706 // returned. 1707 if (index == std::numeric_limits<uint32_t>::max()) { 1708 LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n"); 1709 // `values` is already the full value range. 1710 1711 // Otherwise, check to see if this operation uses AttrSizedSegments. 1712 } else if (op->hasTrait<AttrSizedSegmentsT>()) { 1713 LLVM_DEBUG(llvm::dbgs() 1714 << " * Extracting values from `" << attrSizedSegments << "`\n"); 1715 1716 auto segmentAttr = op->getAttrOfType<DenseElementsAttr>(attrSizedSegments); 1717 if (!segmentAttr || segmentAttr.getNumElements() <= index) 1718 return nullptr; 1719 1720 auto segments = segmentAttr.getValues<int32_t>(); 1721 unsigned startIndex = 1722 std::accumulate(segments.begin(), segments.begin() + index, 0); 1723 values = values.slice(startIndex, *std::next(segments.begin(), index)); 1724 1725 LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex << ", " 1726 << *std::next(segments.begin(), index) << "]\n"); 1727 1728 // Otherwise, assume this is the last operand group of the operation. 
1729 // FIXME: We currently don't support operations with 1730 // SameVariadicOperandSize/SameVariadicResultSize here given that we don't 1731 // have a way to detect it's presence. 1732 } else if (values.size() >= index) { 1733 LLVM_DEBUG(llvm::dbgs() 1734 << " * Treating values as trailing variadic range\n"); 1735 values = values.drop_front(index); 1736 1737 // If we couldn't detect a way to compute the values, bail out. 1738 } else { 1739 return nullptr; 1740 } 1741 1742 // If the range index is valid, we are returning a range. 1743 if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) { 1744 valueRangeMemory[rangeIndex] = values; 1745 return &valueRangeMemory[rangeIndex]; 1746 } 1747 1748 // If a range index wasn't provided, the range is required to be non-variadic. 1749 return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer(); 1750 } 1751 1752 void ByteCodeExecutor::executeGetOperands() { 1753 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n"); 1754 unsigned index = read<uint32_t>(); 1755 Operation *op = read<Operation *>(); 1756 ByteCodeField rangeIndex = read(); 1757 1758 void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>( 1759 op->getOperands(), op, index, rangeIndex, "operand_segment_sizes", 1760 valueRangeMemory); 1761 if (!result) 1762 LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n"); 1763 memory[read()] = result; 1764 } 1765 1766 void ByteCodeExecutor::executeGetResult(unsigned index) { 1767 Operation *op = read<Operation *>(); 1768 unsigned memIndex = read(); 1769 OpResult result = 1770 index < op->getNumResults() ? 
op->getResult(index) : OpResult(); 1771 1772 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1773 << " * Index: " << index << "\n" 1774 << " * Result: " << result << "\n"); 1775 memory[memIndex] = result.getAsOpaquePointer(); 1776 } 1777 1778 void ByteCodeExecutor::executeGetResults() { 1779 LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n"); 1780 unsigned index = read<uint32_t>(); 1781 Operation *op = read<Operation *>(); 1782 ByteCodeField rangeIndex = read(); 1783 1784 void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>( 1785 op->getResults(), op, index, rangeIndex, "result_segment_sizes", 1786 valueRangeMemory); 1787 if (!result) 1788 LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n"); 1789 memory[read()] = result; 1790 } 1791 1792 void ByteCodeExecutor::executeGetUsers() { 1793 LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n"); 1794 unsigned memIndex = read(); 1795 unsigned rangeIndex = read(); 1796 OwningOpRange &range = opRangeMemory[rangeIndex]; 1797 memory[memIndex] = ⦥ 1798 1799 range = OwningOpRange(); 1800 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1801 // Read the value. 1802 Value value = read<Value>(); 1803 if (!value) 1804 return; 1805 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1806 1807 // Extract the users of a single value. 1808 range = OwningOpRange(std::distance(value.user_begin(), value.user_end())); 1809 llvm::copy(value.getUsers(), range.begin()); 1810 } else { 1811 // Read a range of values. 1812 ValueRange *values = read<ValueRange *>(); 1813 if (!values) 1814 return; 1815 LLVM_DEBUG({ 1816 llvm::dbgs() << " * Values (" << values->size() << "): "; 1817 llvm::interleaveComma(*values, llvm::dbgs()); 1818 llvm::dbgs() << "\n"; 1819 }); 1820 1821 // Extract all the users of a range of values. 
1822 SmallVector<Operation *> users; 1823 for (Value value : *values) 1824 users.append(value.user_begin(), value.user_end()); 1825 range = OwningOpRange(users.size()); 1826 llvm::copy(users, range.begin()); 1827 } 1828 1829 LLVM_DEBUG(llvm::dbgs() << " * Result: " << range.size() << " operations\n"); 1830 } 1831 1832 void ByteCodeExecutor::executeGetValueType() { 1833 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n"); 1834 unsigned memIndex = read(); 1835 Value value = read<Value>(); 1836 Type type = value ? value.getType() : Type(); 1837 1838 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" 1839 << " * Result: " << type << "\n"); 1840 memory[memIndex] = type.getAsOpaquePointer(); 1841 } 1842 1843 void ByteCodeExecutor::executeGetValueRangeTypes() { 1844 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n"); 1845 unsigned memIndex = read(); 1846 unsigned rangeIndex = read(); 1847 ValueRange *values = read<ValueRange *>(); 1848 if (!values) { 1849 LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n"); 1850 memory[memIndex] = nullptr; 1851 return; 1852 } 1853 1854 LLVM_DEBUG({ 1855 llvm::dbgs() << " * Values (" << values->size() << "): "; 1856 llvm::interleaveComma(*values, llvm::dbgs()); 1857 llvm::dbgs() << "\n * Result: "; 1858 llvm::interleaveComma(values->getType(), llvm::dbgs()); 1859 llvm::dbgs() << "\n"; 1860 }); 1861 typeRangeMemory[rangeIndex] = values->getType(); 1862 memory[memIndex] = &typeRangeMemory[rangeIndex]; 1863 } 1864 1865 void ByteCodeExecutor::executeIsNotNull() { 1866 LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n"); 1867 const void *value = read<const void *>(); 1868 1869 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1870 selectJump(value != nullptr); 1871 } 1872 1873 void ByteCodeExecutor::executeRecordMatch( 1874 PatternRewriter &rewriter, 1875 SmallVectorImpl<PDLByteCode::MatchResult> &matches) { 1876 LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n"); 1877 unsigned patternIndex = read(); 1878 
PatternBenefit benefit = currentPatternBenefits[patternIndex]; 1879 const ByteCodeField *dest = &code[read<ByteCodeAddr>()]; 1880 1881 // If the benefit of the pattern is impossible, skip the processing of the 1882 // rest of the pattern. 1883 if (benefit.isImpossibleToMatch()) { 1884 LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n"); 1885 curCodeIt = dest; 1886 return; 1887 } 1888 1889 // Create a fused location containing the locations of each of the 1890 // operations used in the match. This will be used as the location for 1891 // created operations during the rewrite that don't already have an 1892 // explicit location set. 1893 unsigned numMatchLocs = read(); 1894 SmallVector<Location, 4> matchLocs; 1895 matchLocs.reserve(numMatchLocs); 1896 for (unsigned i = 0; i != numMatchLocs; ++i) 1897 matchLocs.push_back(read<Operation *>()->getLoc()); 1898 Location matchLoc = rewriter.getFusedLoc(matchLocs); 1899 1900 LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n" 1901 << " * Location: " << matchLoc << "\n"); 1902 matches.emplace_back(matchLoc, patterns[patternIndex], benefit); 1903 PDLByteCode::MatchResult &match = matches.back(); 1904 1905 // Record all of the inputs to the match. If any of the inputs are ranges, we 1906 // will also need to remap the range pointer to memory stored in the match 1907 // state. 
1908 unsigned numInputs = read(); 1909 match.values.reserve(numInputs); 1910 match.typeRangeValues.reserve(numInputs); 1911 match.valueRangeValues.reserve(numInputs); 1912 for (unsigned i = 0; i < numInputs; ++i) { 1913 switch (read<PDLValue::Kind>()) { 1914 case PDLValue::Kind::TypeRange: 1915 match.typeRangeValues.push_back(*read<TypeRange *>()); 1916 match.values.push_back(&match.typeRangeValues.back()); 1917 break; 1918 case PDLValue::Kind::ValueRange: 1919 match.valueRangeValues.push_back(*read<ValueRange *>()); 1920 match.values.push_back(&match.valueRangeValues.back()); 1921 break; 1922 default: 1923 match.values.push_back(read<const void *>()); 1924 break; 1925 } 1926 } 1927 curCodeIt = dest; 1928 } 1929 1930 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) { 1931 LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n"); 1932 Operation *op = read<Operation *>(); 1933 SmallVector<Value, 16> args; 1934 readValueList(args); 1935 1936 LLVM_DEBUG({ 1937 llvm::dbgs() << " * Operation: " << *op << "\n" 1938 << " * Values: "; 1939 llvm::interleaveComma(args, llvm::dbgs()); 1940 llvm::dbgs() << "\n"; 1941 }); 1942 rewriter.replaceOp(op, args); 1943 } 1944 1945 void ByteCodeExecutor::executeSwitchAttribute() { 1946 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n"); 1947 Attribute value = read<Attribute>(); 1948 ArrayAttr cases = read<ArrayAttr>(); 1949 handleSwitch(value, cases); 1950 } 1951 1952 void ByteCodeExecutor::executeSwitchOperandCount() { 1953 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n"); 1954 Operation *op = read<Operation *>(); 1955 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1956 1957 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1958 handleSwitch(op->getNumOperands(), cases); 1959 } 1960 1961 void ByteCodeExecutor::executeSwitchOperationName() { 1962 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n"); 1963 OperationName value = read<Operation *>()->getName(); 
1964 size_t caseCount = read(); 1965 1966 // The operation names are stored in-line, so to print them out for 1967 // debugging purposes we need to read the array before executing the 1968 // switch so that we can display all of the possible values. 1969 LLVM_DEBUG({ 1970 const ByteCodeField *prevCodeIt = curCodeIt; 1971 llvm::dbgs() << " * Value: " << value << "\n" 1972 << " * Cases: "; 1973 llvm::interleaveComma( 1974 llvm::map_range(llvm::seq<size_t>(0, caseCount), 1975 [&](size_t) { return read<OperationName>(); }), 1976 llvm::dbgs()); 1977 llvm::dbgs() << "\n"; 1978 curCodeIt = prevCodeIt; 1979 }); 1980 1981 // Try to find the switch value within any of the cases. 1982 for (size_t i = 0; i != caseCount; ++i) { 1983 if (read<OperationName>() == value) { 1984 curCodeIt += (caseCount - i - 1); 1985 return selectJump(i + 1); 1986 } 1987 } 1988 selectJump(size_t(0)); 1989 } 1990 1991 void ByteCodeExecutor::executeSwitchResultCount() { 1992 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n"); 1993 Operation *op = read<Operation *>(); 1994 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1995 1996 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1997 handleSwitch(op->getNumResults(), cases); 1998 } 1999 2000 void ByteCodeExecutor::executeSwitchType() { 2001 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n"); 2002 Type value = read<Type>(); 2003 auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>(); 2004 handleSwitch(value, cases); 2005 } 2006 2007 void ByteCodeExecutor::executeSwitchTypes() { 2008 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n"); 2009 TypeRange *value = read<TypeRange *>(); 2010 auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>(); 2011 if (!value) { 2012 LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n"); 2013 return selectJump(size_t(0)); 2014 } 2015 handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) { 2016 return value == caseValue.getAsValueRange<TypeAttr>(); 2017 }); 
2018 } 2019 2020 void ByteCodeExecutor::execute( 2021 PatternRewriter &rewriter, 2022 SmallVectorImpl<PDLByteCode::MatchResult> *matches, 2023 Optional<Location> mainRewriteLoc) { 2024 while (true) { 2025 // Print the location of the operation being executed. 2026 LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n"); 2027 2028 OpCode opCode = static_cast<OpCode>(read()); 2029 switch (opCode) { 2030 case ApplyConstraint: 2031 executeApplyConstraint(rewriter); 2032 break; 2033 case ApplyRewrite: 2034 executeApplyRewrite(rewriter); 2035 break; 2036 case AreEqual: 2037 executeAreEqual(); 2038 break; 2039 case AreRangesEqual: 2040 executeAreRangesEqual(); 2041 break; 2042 case Branch: 2043 executeBranch(); 2044 break; 2045 case CheckOperandCount: 2046 executeCheckOperandCount(); 2047 break; 2048 case CheckOperationName: 2049 executeCheckOperationName(); 2050 break; 2051 case CheckResultCount: 2052 executeCheckResultCount(); 2053 break; 2054 case CheckTypes: 2055 executeCheckTypes(); 2056 break; 2057 case Continue: 2058 executeContinue(); 2059 break; 2060 case CreateOperation: 2061 executeCreateOperation(rewriter, *mainRewriteLoc); 2062 break; 2063 case CreateTypes: 2064 executeCreateTypes(); 2065 break; 2066 case EraseOp: 2067 executeEraseOp(rewriter); 2068 break; 2069 case ExtractOp: 2070 executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>(); 2071 break; 2072 case ExtractType: 2073 executeExtract<Type, TypeRange, PDLValue::Kind::Type>(); 2074 break; 2075 case ExtractValue: 2076 executeExtract<Value, ValueRange, PDLValue::Kind::Value>(); 2077 break; 2078 case Finalize: 2079 executeFinalize(); 2080 LLVM_DEBUG(llvm::dbgs() << "\n"); 2081 return; 2082 case ForEach: 2083 executeForEach(); 2084 break; 2085 case GetAttribute: 2086 executeGetAttribute(); 2087 break; 2088 case GetAttributeType: 2089 executeGetAttributeType(); 2090 break; 2091 case GetDefiningOp: 2092 executeGetDefiningOp(); 2093 break; 2094 case GetOperand0: 2095 case GetOperand1: 2096 
case GetOperand2: 2097 case GetOperand3: { 2098 unsigned index = opCode - GetOperand0; 2099 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n"); 2100 executeGetOperand(index); 2101 break; 2102 } 2103 case GetOperandN: 2104 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n"); 2105 executeGetOperand(read<uint32_t>()); 2106 break; 2107 case GetOperands: 2108 executeGetOperands(); 2109 break; 2110 case GetResult0: 2111 case GetResult1: 2112 case GetResult2: 2113 case GetResult3: { 2114 unsigned index = opCode - GetResult0; 2115 LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n"); 2116 executeGetResult(index); 2117 break; 2118 } 2119 case GetResultN: 2120 LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n"); 2121 executeGetResult(read<uint32_t>()); 2122 break; 2123 case GetResults: 2124 executeGetResults(); 2125 break; 2126 case GetUsers: 2127 executeGetUsers(); 2128 break; 2129 case GetValueType: 2130 executeGetValueType(); 2131 break; 2132 case GetValueRangeTypes: 2133 executeGetValueRangeTypes(); 2134 break; 2135 case IsNotNull: 2136 executeIsNotNull(); 2137 break; 2138 case RecordMatch: 2139 assert(matches && 2140 "expected matches to be provided when executing the matcher"); 2141 executeRecordMatch(rewriter, *matches); 2142 break; 2143 case ReplaceOp: 2144 executeReplaceOp(rewriter); 2145 break; 2146 case SwitchAttribute: 2147 executeSwitchAttribute(); 2148 break; 2149 case SwitchOperandCount: 2150 executeSwitchOperandCount(); 2151 break; 2152 case SwitchOperationName: 2153 executeSwitchOperationName(); 2154 break; 2155 case SwitchResultCount: 2156 executeSwitchResultCount(); 2157 break; 2158 case SwitchType: 2159 executeSwitchType(); 2160 break; 2161 case SwitchTypes: 2162 executeSwitchTypes(); 2163 break; 2164 } 2165 LLVM_DEBUG(llvm::dbgs() << "\n"); 2166 } 2167 } 2168 2169 /// Run the pattern matcher on the given root operation, collecting the matched 2170 /// patterns in `matches`. 
2171 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, 2172 SmallVectorImpl<MatchResult> &matches, 2173 PDLByteCodeMutableState &state) const { 2174 // The first memory slot is always the root operation. 2175 state.memory[0] = op; 2176 2177 // The matcher function always starts at code address 0. 2178 ByteCodeExecutor executor( 2179 matcherByteCode.data(), state.memory, state.opRangeMemory, 2180 state.typeRangeMemory, state.allocatedTypeRangeMemory, 2181 state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex, 2182 uniquedData, matcherByteCode, state.currentPatternBenefits, patterns, 2183 constraintFunctions, rewriteFunctions); 2184 executor.execute(rewriter, &matches); 2185 2186 // Order the found matches by benefit. 2187 std::stable_sort(matches.begin(), matches.end(), 2188 [](const MatchResult &lhs, const MatchResult &rhs) { 2189 return lhs.benefit > rhs.benefit; 2190 }); 2191 } 2192 2193 /// Run the rewriter of the given pattern on the root operation `op`. 2194 void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match, 2195 PDLByteCodeMutableState &state) const { 2196 // The arguments of the rewrite function are stored at the start of the 2197 // memory buffer. 2198 llvm::copy(match.values, state.memory.begin()); 2199 2200 ByteCodeExecutor executor( 2201 &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory, 2202 state.opRangeMemory, state.typeRangeMemory, 2203 state.allocatedTypeRangeMemory, state.valueRangeMemory, 2204 state.allocatedValueRangeMemory, state.loopIndex, uniquedData, 2205 rewriterByteCode, state.currentPatternBenefits, patterns, 2206 constraintFunctions, rewriteFunctions); 2207 executor.execute(rewriter, /*matches=*/nullptr, match.location); 2208 } 2209