//===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements MLIR to byte-code generation and the interpreter. // //===----------------------------------------------------------------------===// #include "ByteCode.h" #include "mlir/Analysis/Liveness.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/RegionGraphTraits.h" #include "llvm/ADT/IntervalMap.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" #include "llvm/Support/Format.h" #include "llvm/Support/FormatVariadic.h" #include #include #define DEBUG_TYPE "pdl-bytecode" using namespace mlir; using namespace mlir::detail; //===----------------------------------------------------------------------===// // PDLByteCodePattern //===----------------------------------------------------------------------===// PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp, PDLPatternConfigSet *configSet, ByteCodeAddr rewriterAddr) { PatternBenefit benefit = matchOp.getBenefit(); MLIRContext *ctx = matchOp.getContext(); // Collect the set of generated operations. SmallVector generatedOps; if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr()) generatedOps = llvm::to_vector<8>(generatedOpsAttr.getAsValueRange()); // Check to see if this is pattern matches a specific operation type. 
if (std::optional rootKind = matchOp.getRootKind()) return PDLByteCodePattern(rewriterAddr, configSet, *rootKind, benefit, ctx, generatedOps); return PDLByteCodePattern(rewriterAddr, configSet, MatchAnyOpTypeTag(), benefit, ctx, generatedOps); } //===----------------------------------------------------------------------===// // PDLByteCodeMutableState //===----------------------------------------------------------------------===// /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds /// to the position of the pattern within the range returned by /// `PDLByteCode::getPatterns`. void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit) { currentPatternBenefits[patternIndex] = benefit; } /// Cleanup any allocated state after a full match/rewrite has been completed. /// This method should be called irregardless of whether the match+rewrite was a /// success or not. void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() { allocatedTypeRangeMemory.clear(); allocatedValueRangeMemory.clear(); } //===----------------------------------------------------------------------===// // Bytecode OpCodes //===----------------------------------------------------------------------===// namespace { enum OpCode : ByteCodeField { /// Apply an externally registered constraint. ApplyConstraint, /// Apply an externally registered rewrite. ApplyRewrite, /// Check if two generic values are equal. AreEqual, /// Check if two ranges are equal. AreRangesEqual, /// Unconditional branch. Branch, /// Compare the operand count of an operation with a constant. CheckOperandCount, /// Compare the name of an operation with a constant. CheckOperationName, /// Compare the result count of an operation with a constant. CheckResultCount, /// Compare a range of types to a constant range of types. CheckTypes, /// Continue to the next iteration of a loop. Continue, /// Create a type range from a list of constant types. 
CreateConstantTypeRange, /// Create an operation. CreateOperation, /// Create a type range from a list of dynamic types. CreateDynamicTypeRange, /// Create a value range. CreateDynamicValueRange, /// Erase an operation. EraseOp, /// Extract the op from a range at the specified index. ExtractOp, /// Extract the type from a range at the specified index. ExtractType, /// Extract the value from a range at the specified index. ExtractValue, /// Terminate a matcher or rewrite sequence. Finalize, /// Iterate over a range of values. ForEach, /// Get a specific attribute of an operation. GetAttribute, /// Get the type of an attribute. GetAttributeType, /// Get the defining operation of a value. GetDefiningOp, /// Get a specific operand of an operation. GetOperand0, GetOperand1, GetOperand2, GetOperand3, GetOperandN, /// Get a specific operand group of an operation. GetOperands, /// Get a specific result of an operation. GetResult0, GetResult1, GetResult2, GetResult3, GetResultN, /// Get a specific result group of an operation. GetResults, /// Get the users of a value or a range of values. GetUsers, /// Get the type of a value. GetValueType, /// Get the types of a value range. GetValueRangeTypes, /// Check if a generic value is not null. IsNotNull, /// Record a successful pattern match. RecordMatch, /// Replace an operation. ReplaceOp, /// Compare an attribute with a set of constants. SwitchAttribute, /// Compare the operand count of an operation with a set of constants. SwitchOperandCount, /// Compare the name of an operation with a set of constants. SwitchOperationName, /// Compare the result count of an operation with a set of constants. SwitchResultCount, /// Compare a type with a set of constants. SwitchType, /// Compare a range of types with a set of constants. SwitchTypes, }; } // namespace /// A marker used to indicate if an operation should infer types. 
static constexpr ByteCodeField kInferTypesMarker = std::numeric_limits::max(); //===----------------------------------------------------------------------===// // ByteCode Generation //===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===// // Generator namespace { struct ByteCodeLiveRange; struct ByteCodeWriter; /// Check if the given class `T` can be converted to an opaque pointer. template using has_pointer_traits = decltype(std::declval().getAsOpaquePointer()); /// This class represents the main generator for the pattern bytecode. class Generator { public: Generator(MLIRContext *ctx, std::vector &uniquedData, SmallVectorImpl &matcherByteCode, SmallVectorImpl &rewriterByteCode, SmallVectorImpl &patterns, ByteCodeField &maxValueMemoryIndex, ByteCodeField &maxOpRangeMemoryIndex, ByteCodeField &maxTypeRangeMemoryIndex, ByteCodeField &maxValueRangeMemoryIndex, ByteCodeField &maxLoopLevel, llvm::StringMap &constraintFns, llvm::StringMap &rewriteFns, const DenseMap &configMap) : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), rewriterByteCode(rewriterByteCode), patterns(patterns), maxValueMemoryIndex(maxValueMemoryIndex), maxOpRangeMemoryIndex(maxOpRangeMemoryIndex), maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex), maxValueRangeMemoryIndex(maxValueRangeMemoryIndex), maxLoopLevel(maxLoopLevel), configMap(configMap) { for (const auto &it : llvm::enumerate(constraintFns)) constraintToMemIndex.try_emplace(it.value().first(), it.index()); for (const auto &it : llvm::enumerate(rewriteFns)) externalRewriterToMemIndex.try_emplace(it.value().first(), it.index()); } /// Generate the bytecode for the given PDL interpreter module. void generate(ModuleOp module); /// Return the memory index to use for the given value. 
ByteCodeField &getMemIndex(Value value) { assert(valueToMemIndex.count(value) && "expected memory index to be assigned"); return valueToMemIndex[value]; } /// Return the range memory index used to store the given range value. ByteCodeField &getRangeStorageIndex(Value value) { assert(valueToRangeIndex.count(value) && "expected range index to be assigned"); return valueToRangeIndex[value]; } /// Return an index to use when referring to the given data that is uniqued in /// the MLIR context. template std::enable_if_t::value, ByteCodeField &> getMemIndex(T val) { const void *opaqueVal = val.getAsOpaquePointer(); // Get or insert a reference to this value. auto it = uniquedDataToMemIndex.try_emplace( opaqueVal, maxValueMemoryIndex + uniquedData.size()); if (it.second) uniquedData.push_back(opaqueVal); return it.first->second; } private: /// Allocate memory indices for the results of operations within the matcher /// and rewriters. void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule); /// Generate the bytecode for the given operation. 
void generate(Region *region, ByteCodeWriter &writer); void generate(Operation *op, ByteCodeWriter &writer); void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer); void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer); void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer); void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer); void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer); void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer); void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer); void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer); void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer); void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer); void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer); void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer); void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer); void 
generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer); void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer); void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer); void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer); void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer); void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer); void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer); void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer); /// Mapping from value to its corresponding memory index. DenseMap valueToMemIndex; /// Mapping from a range value to its corresponding range storage index. DenseMap valueToRangeIndex; /// Mapping from the name of an externally registered rewrite to its index in /// the bytecode registry. llvm::StringMap externalRewriterToMemIndex; /// Mapping from the name of an externally registered constraint to its index /// in the bytecode registry. llvm::StringMap constraintToMemIndex; /// Mapping from rewriter function name to the bytecode address of the /// rewriter function in byte. llvm::StringMap rewriterToAddr; /// Mapping from a uniqued storage object to its memory index within /// `uniquedData`. DenseMap uniquedDataToMemIndex; /// The current level of the foreach loop. ByteCodeField curLoopLevel = 0; /// The current MLIR context. MLIRContext *ctx; /// Mapping from block to its address. DenseMap blockToAddr; /// Data of the ByteCode class to be populated. 
std::vector &uniquedData; SmallVectorImpl &matcherByteCode; SmallVectorImpl &rewriterByteCode; SmallVectorImpl &patterns; ByteCodeField &maxValueMemoryIndex; ByteCodeField &maxOpRangeMemoryIndex; ByteCodeField &maxTypeRangeMemoryIndex; ByteCodeField &maxValueRangeMemoryIndex; ByteCodeField &maxLoopLevel; /// A map of pattern configurations. const DenseMap &configMap; }; /// This class provides utilities for writing a bytecode stream. struct ByteCodeWriter { ByteCodeWriter(SmallVectorImpl &bytecode, Generator &generator) : bytecode(bytecode), generator(generator) {} /// Append a field to the bytecode. void append(ByteCodeField field) { bytecode.push_back(field); } void append(OpCode opCode) { bytecode.push_back(opCode); } /// Append an address to the bytecode. void append(ByteCodeAddr field) { static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, "unexpected ByteCode address size"); ByteCodeField fieldParts[2]; std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr)); bytecode.append({fieldParts[0], fieldParts[1]}); } /// Append a single successor to the bytecode, the exact address will need to /// be resolved later. void append(Block *successor) { // Add back a reference to the successor so that the address can be resolved // later. unresolvedSuccessorRefs[successor].push_back(bytecode.size()); append(ByteCodeAddr(0)); } /// Append a successor range to the bytecode, the exact address will need to /// be resolved later. void append(SuccessorRange successors) { for (Block *successor : successors) append(successor); } /// Append a range of values that will be read as generic PDLValues. void appendPDLValueList(OperandRange values) { bytecode.push_back(values.size()); for (Value value : values) appendPDLValue(value); } /// Append a value as a PDLValue. void appendPDLValue(Value value) { appendPDLValueKind(value); append(value); } /// Append the PDLValue::Kind of the given value. 
void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); } /// Append the PDLValue::Kind of the given type. void appendPDLValueKind(Type type) { PDLValue::Kind kind = TypeSwitch(type) .Case( [](Type) { return PDLValue::Kind::Attribute; }) .Case( [](Type) { return PDLValue::Kind::Operation; }) .Case([](pdl::RangeType rangeTy) { if (isa(rangeTy.getElementType())) return PDLValue::Kind::TypeRange; return PDLValue::Kind::ValueRange; }) .Case([](Type) { return PDLValue::Kind::Type; }) .Case([](Type) { return PDLValue::Kind::Value; }); bytecode.push_back(static_cast(kind)); } /// Append a value that will be stored in a memory slot and not inline within /// the bytecode. template std::enable_if_t::value || std::is_pointer::value> append(T value) { bytecode.push_back(generator.getMemIndex(value)); } /// Append a range of values. template > std::enable_if_t::value> append(T range) { bytecode.push_back(llvm::size(range)); for (auto it : range) append(it); } /// Append a variadic number of fields to the bytecode. template void append(FieldTy field, Field2Ty field2, FieldTys... fields) { append(field); append(field2, fields...); } /// Appends a value as a pointer, stored inline within the bytecode. template std::enable_if_t::value> appendInline(T value) { constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField); const void *pointer = value.getAsOpaquePointer(); ByteCodeField fieldParts[numParts]; std::memcpy(fieldParts, &pointer, sizeof(const void *)); bytecode.append(fieldParts, fieldParts + numParts); } /// Successor references in the bytecode that have yet to be resolved. DenseMap> unresolvedSuccessorRefs; /// The underlying bytecode buffer. SmallVectorImpl &bytecode; /// The main generator producing PDL. Generator &generator; }; /// This class represents a live range of PDL Interpreter values, containing /// information about when values are live within a match/rewrite. 
struct ByteCodeLiveRange { using Set = llvm::IntervalMap; using Allocator = Set::Allocator; ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {} /// Union this live range with the one provided. void unionWith(const ByteCodeLiveRange &rhs) { for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e; ++it) liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0); } /// Returns true if this range overlaps with the one provided. bool overlaps(const ByteCodeLiveRange &rhs) const { return llvm::IntervalMapOverlaps(*liveness, *rhs.liveness) .valid(); } /// A map representing the ranges of the match/rewrite that a value is live in /// the interpreter. /// /// We use std::unique_ptr here, because IntervalMap does not provide a /// correct copy or move constructor. We can eliminate the pointer once /// https://reviews.llvm.org/D113240 lands. std::unique_ptr> liveness; /// The operation range storage index for this range. std::optional opRangeIndex; /// The type range storage index for this range. std::optional typeRangeIndex; /// The value range storage index for this range. std::optional valueRangeIndex; }; } // namespace void Generator::generate(ModuleOp module) { auto matcherFunc = module.lookupSymbol( pdl_interp::PDLInterpDialect::getMatcherFunctionName()); ModuleOp rewriterModule = module.lookupSymbol( pdl_interp::PDLInterpDialect::getRewriterModuleName()); assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module"); // Allocate memory indices for the results of operations within the matcher // and rewriters. allocateMemoryIndices(matcherFunc, rewriterModule); // Generate code for the rewriter functions. 
ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this); for (auto rewriterFunc : rewriterModule.getOps()) { rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size()); for (Operation &op : rewriterFunc.getOps()) generate(&op, rewriterByteCodeWriter); } assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() && "unexpected branches in rewriter function"); // Generate code for the matcher function. ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this); generate(&matcherFunc.getBody(), matcherByteCodeWriter); // Resolve successor references in the matcher. for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) { ByteCodeAddr addr = blockToAddr[it.first]; for (unsigned offsetToFix : it.second) std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr)); } } void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule) { // Rewriters use simplistic allocation scheme that simply assigns an index to // each result. 
for (auto rewriterFunc : rewriterModule.getOps()) { ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0; auto processRewriterValue = [&](Value val) { valueToMemIndex.try_emplace(val, index++); if (pdl::RangeType rangeType = dyn_cast(val.getType())) { Type elementTy = rangeType.getElementType(); if (isa(elementTy)) valueToRangeIndex.try_emplace(val, typeRangeIndex++); else if (isa(elementTy)) valueToRangeIndex.try_emplace(val, valueRangeIndex++); } }; for (BlockArgument arg : rewriterFunc.getArguments()) processRewriterValue(arg); rewriterFunc.getBody().walk([&](Operation *op) { for (Value result : op->getResults()) processRewriterValue(result); }); if (index > maxValueMemoryIndex) maxValueMemoryIndex = index; if (typeRangeIndex > maxTypeRangeMemoryIndex) maxTypeRangeMemoryIndex = typeRangeIndex; if (valueRangeIndex > maxValueRangeMemoryIndex) maxValueRangeMemoryIndex = valueRangeIndex; } // The matcher function uses a more sophisticated numbering that tries to // minimize the number of memory indices assigned. This is done by determining // a live range of the values within the matcher, then the allocation is just // finding the minimal number of overlapping live ranges. This is essentially // a simplified form of register allocation where we don't necessarily have a // limited number of registers, but we still want to minimize the number used. DenseMap opToFirstIndex; DenseMap opToLastIndex; // A custom walk that marks the first and the last index of each operation. // The entry marks the beginning of the liveness range for this operation, // followed by nested operations, followed by the end of the liveness range. 
unsigned index = 0; llvm::unique_function walk = [&](Operation *op) { opToFirstIndex.try_emplace(op, index++); for (Region ®ion : op->getRegions()) for (Block &block : region.getBlocks()) for (Operation &nested : block) walk(&nested); opToLastIndex.try_emplace(op, index++); }; walk(matcherFunc); // Liveness info for each of the defs within the matcher. ByteCodeLiveRange::Allocator allocator; DenseMap valueDefRanges; // Assign the root operation being matched to slot 0. BlockArgument rootOpArg = matcherFunc.getArgument(0); valueToMemIndex[rootOpArg] = 0; // Walk each of the blocks, computing the def interval that the value is used. Liveness matcherLiveness(matcherFunc); matcherFunc->walk([&](Block *block) { const LivenessBlockInfo *info = matcherLiveness.getLiveness(block); assert(info && "expected liveness info for block"); auto processValue = [&](Value value, Operation *firstUseOrDef) { // We don't need to process the root op argument, this value is always // assigned to the first memory slot. if (value == rootOpArg) return; // Set indices for the range of this block that the value is used. auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first; defRangeIt->second.liveness->insert( opToFirstIndex[firstUseOrDef], opToLastIndex[info->getEndOperation(value, firstUseOrDef)], /*dummyValue*/ 0); // Check to see if this value is a range type. if (auto rangeTy = dyn_cast(value.getType())) { Type eleType = rangeTy.getElementType(); if (isa(eleType)) defRangeIt->second.opRangeIndex = 0; else if (isa(eleType)) defRangeIt->second.typeRangeIndex = 0; else if (isa(eleType)) defRangeIt->second.valueRangeIndex = 0; } }; // Process the live-ins of this block. for (Value liveIn : info->in()) { // Only process the value if it has been defined in the current region. // Other values that span across pdl_interp.foreach will be added higher // up. This ensures that the we keep them alive for the entire duration // of the loop. 
if (liveIn.getParentRegion() == block->getParent()) processValue(liveIn, &block->front()); } // Process the block arguments for the entry block (those are not live-in). if (block->isEntryBlock()) { for (Value argument : block->getArguments()) processValue(argument, &block->front()); } // Process any new defs within this block. for (Operation &op : *block) for (Value result : op.getResults()) processValue(result, &op); }); // Greedily allocate memory slots using the computed def live ranges. std::vector allocatedIndices; // The number of memory indices currently allocated (and its next value). // Recall that the root gets allocated memory index 0. ByteCodeField numIndices = 1; // The number of memory ranges of various types (and their next values). ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0; for (auto &defIt : valueDefRanges) { ByteCodeField &memIndex = valueToMemIndex[defIt.first]; ByteCodeLiveRange &defRange = defIt.second; // Try to allocate to an existing index. for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) { ByteCodeLiveRange &existingRange = existingIndexIt.value(); if (!defRange.overlaps(existingRange)) { existingRange.unionWith(defRange); memIndex = existingIndexIt.index() + 1; if (defRange.opRangeIndex) { if (!existingRange.opRangeIndex) existingRange.opRangeIndex = numOpRanges++; valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex; } else if (defRange.typeRangeIndex) { if (!existingRange.typeRangeIndex) existingRange.typeRangeIndex = numTypeRanges++; valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex; } else if (defRange.valueRangeIndex) { if (!existingRange.valueRangeIndex) existingRange.valueRangeIndex = numValueRanges++; valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex; } break; } } // If no existing index could be used, add a new one. 
if (memIndex == 0) { allocatedIndices.emplace_back(allocator); ByteCodeLiveRange &newRange = allocatedIndices.back(); newRange.unionWith(defRange); // Allocate an index for op/type/value ranges. if (defRange.opRangeIndex) { newRange.opRangeIndex = numOpRanges; valueToRangeIndex[defIt.first] = numOpRanges++; } else if (defRange.typeRangeIndex) { newRange.typeRangeIndex = numTypeRanges; valueToRangeIndex[defIt.first] = numTypeRanges++; } else if (defRange.valueRangeIndex) { newRange.valueRangeIndex = numValueRanges; valueToRangeIndex[defIt.first] = numValueRanges++; } memIndex = allocatedIndices.size(); ++numIndices; } } // Print the index usage and ensure that we did not run out of index space. LLVM_DEBUG({ llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices " << "(down from initial " << valueDefRanges.size() << ").\n"; }); assert(allocatedIndices.size() <= std::numeric_limits::max() && "Ran out of memory for allocated indices"); // Update the max number of indices. if (numIndices > maxValueMemoryIndex) maxValueMemoryIndex = numIndices; if (numOpRanges > maxOpRangeMemoryIndex) maxOpRangeMemoryIndex = numOpRanges; if (numTypeRanges > maxTypeRangeMemoryIndex) maxTypeRangeMemoryIndex = numTypeRanges; if (numValueRanges > maxValueRangeMemoryIndex) maxValueRangeMemoryIndex = numValueRanges; } void Generator::generate(Region *region, ByteCodeWriter &writer) { llvm::ReversePostOrderTraversal rpot(region); for (Block *block : rpot) { // Keep track of where this block begins within the matcher function. blockToAddr.try_emplace(block, matcherByteCode.size()); for (Operation &op : *block) generate(&op, writer); } } void Generator::generate(Operation *op, ByteCodeWriter &writer) { LLVM_DEBUG({ // The following list must contain all the operations that do not // produce any bytecode. 
if (!isa(op)) writer.appendInline(op->getLoc()); }); TypeSwitch(op) .Case( [&](auto interpOp) { this->generate(interpOp, writer); }) .Default([](Operation *) { llvm_unreachable("unknown `pdl_interp` operation"); }); } void Generator::generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer) { // Constraints that should return a value have to be registered as rewrites. // If a constraint and a rewrite of similar name are registered the // constraint takes precedence writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]); writer.appendPDLValueList(op.getArgs()); writer.append(ByteCodeField(op.getIsNegated())); ResultRange results = op.getResults(); writer.append(ByteCodeField(results.size())); for (Value result : results) { // We record the expected kind of the result, so that we can provide extra // verification of the native rewrite function and handle the failure case // of constraints accordingly. writer.appendPDLValueKind(result); // Range results also need to append the range storage index. if (isa(result.getType())) writer.append(getRangeStorageIndex(result)); writer.append(result); } writer.append(op.getSuccessors()); } void Generator::generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer) { assert(externalRewriterToMemIndex.count(op.getName()) && "expected index for rewrite function"); writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]); writer.appendPDLValueList(op.getArgs()); ResultRange results = op.getResults(); writer.append(ByteCodeField(results.size())); for (Value result : results) { // We record the expected kind of the result, so that we // can provide extra verification of the native rewrite function. writer.appendPDLValueKind(result); // Range results also need to append the range storage index. 
if (isa(result.getType())) writer.append(getRangeStorageIndex(result)); writer.append(result); } } void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) { Value lhs = op.getLhs(); if (isa(lhs.getType())) { writer.append(OpCode::AreRangesEqual); writer.appendPDLValueKind(lhs); writer.append(op.getLhs(), op.getRhs(), op.getSuccessors()); return; } writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors()); } void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) { writer.append(OpCode::Branch, SuccessorRange(op.getOperation())); } void Generator::generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer) { writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(), op.getSuccessors()); } void Generator::generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer) { writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(), static_cast(op.getCompareAtLeast()), op.getSuccessors()); } void Generator::generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer) { writer.append(OpCode::CheckOperationName, op.getInputOp(), OperationName(op.getName(), ctx), op.getSuccessors()); } void Generator::generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer) { writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(), static_cast(op.getCompareAtLeast()), op.getSuccessors()); } void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) { writer.append(OpCode::AreEqual, op.getValue(), op.getType(), op.getSuccessors()); } void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) { writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(), op.getSuccessors()); } void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) { assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level"); writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1)); } void 
Generator::generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer) { // Simply repoint the memory index of the result to the constant. getMemIndex(op.getAttribute()) = getMemIndex(op.getValue()); } void Generator::generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer) { writer.append(OpCode::CreateOperation, op.getResultOp(), OperationName(op.getName(), ctx)); writer.appendPDLValueList(op.getInputOperands()); // Add the attributes. OperandRange attributes = op.getInputAttributes(); writer.append(static_cast(attributes.size())); for (auto it : llvm::zip(op.getInputAttributeNames(), attributes)) writer.append(std::get<0>(it), std::get<1>(it)); // Add the result types. If the operation has inferred results, we use a // marker "size" value. Otherwise, we add the list of explicit result types. if (op.getInferredResultTypes()) writer.append(kInferTypesMarker); else writer.appendPDLValueList(op.getInputResultTypes()); } void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) { // Append the correct opcode for the range type. TypeSwitch(op.getType().getElementType()) .Case( [&](pdl::TypeType) { writer.append(OpCode::CreateDynamicTypeRange); }) .Case([&](pdl::ValueType) { writer.append(OpCode::CreateDynamicValueRange); }); writer.append(op.getResult(), getRangeStorageIndex(op.getResult())); writer.appendPDLValueList(op->getOperands()); } void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) { // Simply repoint the memory index of the result to the constant. 
getMemIndex(op.getResult()) = getMemIndex(op.getValue());
}

/// Emit `CreateConstantTypeRange`, followed by the result, the range storage
/// index of the result, and the constant value. `executeCreateConstantTypeRange`
/// reads the fields back in this order.
void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::CreateConstantTypeRange, op.getResult(),
                getRangeStorageIndex(op.getResult()), op.getValue());
}

/// Emit `EraseOp` followed by the operation to erase.
void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::EraseOp, op.getInputOp());
}

/// Emit one of the `Extract*` opcodes, selected from the type of the result,
/// followed by the range, the index, and the result.
void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
  OpCode opCode =
      TypeSwitch(op.getResult().getType())
          .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
          .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
          .Case([](pdl::TypeType) { return OpCode::ExtractType; })
          .Default([](Type) -> OpCode {
            llvm_unreachable("unsupported element type");
          });
  writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
}

/// Emit `Finalize`; the opcode carries no operands.
void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::Finalize);
}

/// Emit `ForEach` followed by the range storage index of the iterated values,
/// the loop variable, the kind of the loop variable, the current loop level,
/// and the successor. The loop body is generated one level deeper, and
/// `maxLoopLevel` tracks the deepest level reached.
void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
  BlockArgument arg = op.getLoopVariable();
  writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
  writer.appendPDLValueKind(arg.getType());
  writer.append(curLoopLevel, op.getSuccessor());
  // Enter the loop: bump the nesting level while the region is generated.
  ++curLoopLevel;
  if (curLoopLevel > maxLoopLevel)
    maxLoopLevel = curLoopLevel;
  generate(&op.getRegion(), writer);
  --curLoopLevel;
}

/// Emit `GetAttribute` followed by the result, the input operation, and the
/// attribute name.
void Generator::generate(pdl_interp::GetAttributeOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
                op.getNameAttr());
}

/// Emit `GetAttributeType` followed by the result and the input attribute.
void Generator::generate(pdl_interp::GetAttributeTypeOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
}

/// Emit `GetDefiningOp` followed by the result and the input value, the latter
/// through `appendPDLValue`.
void Generator::generate(pdl_interp::GetDefiningOpOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::GetDefiningOp, op.getInputOp());
  writer.appendPDLValue(op.getValue());
}

/// Emit an operand access. Indices below 4 use a dedicated opcode computed as
/// `GetOperand0 + index`; larger indices use `GetOperandN` with the index as
/// an explicit field. The input operation and the result follow in both cases.
void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
  uint32_t index = op.getIndex();
  if (index < 4)
    writer.append(static_cast(OpCode::GetOperand0 + index));
  else
    writer.append(OpCode::GetOperandN, index);
  writer.append(op.getInputOp(), op.getValue());
}

/// Emit `GetOperands` followed by the optional group index (or the numeric
/// maximum when absent), the input operation, a range storage index, and the
/// result. When the `isa` check on the result type fails, the numeric maximum
/// is written in place of the range storage index.
void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
  Value result = op.getValue();
  std::optional index = op.getIndex();
  writer.append(OpCode::GetOperands,
                index.value_or(std::numeric_limits::max()),
                op.getInputOp());
  if (isa(result.getType()))
    writer.append(getRangeStorageIndex(result));
  else
    writer.append(std::numeric_limits::max());
  writer.append(result);
}

/// Emit a result access. Mirrors the `GetOperandOp` overload: indices below 4
/// use `GetResult0 + index`, larger ones use `GetResultN` plus the index.
void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
  uint32_t index = op.getIndex();
  if (index < 4)
    writer.append(static_cast(OpCode::GetResult0 + index));
  else
    writer.append(OpCode::GetResultN, index);
  writer.append(op.getInputOp(), op.getValue());
}

/// Emit `GetResults`. Mirrors the `GetOperandsOp` overload, including the use
/// of the numeric maximum as the placeholder for an absent index or range
/// storage slot.
void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
  Value result = op.getValue();
  std::optional index = op.getIndex();
  writer.append(OpCode::GetResults,
                index.value_or(std::numeric_limits::max()),
                op.getInputOp());
  if (isa(result.getType()))
    writer.append(getRangeStorageIndex(result));
  else
    writer.append(std::numeric_limits::max());
  writer.append(result);
}

/// Emit `GetUsers` followed by the result, its range storage index, and the
/// input value (through `appendPDLValue`).
void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
  Value operations = op.getOperations();
  ByteCodeField rangeIndex = getRangeStorageIndex(operations);
  writer.append(OpCode::GetUsers, operations, rangeIndex);
  writer.appendPDLValue(op.getValue());
}

/// Emit `GetValueRangeTypes` (with a range storage index for the result) when
/// the `isa` check on the op type succeeds, and `GetValueType` otherwise.
void Generator::generate(pdl_interp::GetValueTypeOp op,
                         ByteCodeWriter &writer) {
  if (isa(op.getType())) {
    Value result = op.getResult();
    writer.append(OpCode::GetValueRangeTypes, result,
                  getRangeStorageIndex(result), op.getValue());
  } else {
    writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
  }
}

/// Emit `IsNotNull` followed by the tested value and the successors.
void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
}

/// Register a new pattern for this match and emit `RecordMatch` followed by
/// the index of that pattern, the successors, the matched operations, and the
/// list of inputs. The pattern index is its position in `patterns`, and the
/// rewriter address is looked up by the leaf reference of the rewriter symbol.
void
Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
  ByteCodeField patternIndex = patterns.size();
  patterns.emplace_back(PDLByteCodePattern::create(
      op, configMap.lookup(op),
      rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
  writer.append(OpCode::RecordMatch, patternIndex,
                SuccessorRange(op.getOperation()), op.getMatchedOps());
  writer.appendPDLValueList(op.getInputs());
}

/// Emit `ReplaceOp` followed by the operation to replace and the list of
/// replacement values.
void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::ReplaceOp, op.getInputOp());
  writer.appendPDLValueList(op.getReplValues());
}

/// Emit `SwitchAttribute` followed by the attribute, the case values, and the
/// successors.
void Generator::generate(pdl_interp::SwitchAttributeOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchAttribute, op.getAttribute(),
                op.getCaseValuesAttr(), op.getSuccessors());
}

/// Emit `SwitchOperandCount` followed by the operation, the case values, and
/// the successors.
void Generator::generate(pdl_interp::SwitchOperandCountOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
                op.getCaseValuesAttr(), op.getSuccessors());
}

/// Emit `SwitchOperationName` followed by the operation, the case names, and
/// the successors. Each case attribute is converted to an `OperationName` in
/// `ctx` before being written.
void Generator::generate(pdl_interp::SwitchOperationNameOp op,
                         ByteCodeWriter &writer) {
  auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) {
    return OperationName(cast(attr).getValue(), ctx);
  });
  writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
                op.getSuccessors());
}

/// Emit `SwitchResultCount` followed by the operation, the case values, and
/// the successors.
void Generator::generate(pdl_interp::SwitchResultCountOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchResultCount, op.getInputOp(),
                op.getCaseValuesAttr(), op.getSuccessors());
}

/// Emit `SwitchType` followed by the value, the case values, and the
/// successors.
void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
                op.getSuccessors());
}

/// Emit `SwitchTypes` followed by the value, the case values, and the
/// successors.
void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
                op.getSuccessors());
}

//===----------------------------------------------------------------------===//
// PDLByteCode
//===----------------------------------------------------------------------===// PDLByteCode::PDLByteCode( ModuleOp module, SmallVector> configs, const DenseMap &configMap, llvm::StringMap constraintFns, llvm::StringMap rewriteFns) : configs(std::move(configs)) { Generator generator(module.getContext(), uniquedData, matcherByteCode, rewriterByteCode, patterns, maxValueMemoryIndex, maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount, maxLoopLevel, constraintFns, rewriteFns, configMap); generator.generate(module); // Initialize the external functions. for (auto &it : constraintFns) constraintFunctions.push_back(std::move(it.second)); for (auto &it : rewriteFns) rewriteFunctions.push_back(std::move(it.second)); } /// Initialize the given state such that it can be used to execute the current /// bytecode. void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const { state.memory.resize(maxValueMemoryIndex, nullptr); state.opRangeMemory.resize(maxOpRangeCount); state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange()); state.valueRangeMemory.resize(maxValueRangeCount, ValueRange()); state.loopIndex.resize(maxLoopLevel, 0); state.currentPatternBenefits.reserve(patterns.size()); for (const PDLByteCodePattern &pattern : patterns) state.currentPatternBenefits.push_back(pattern.getBenefit()); } //===----------------------------------------------------------------------===// // ByteCode Execution namespace { /// This class is an instantiation of the PDLResultList that provides access to /// the returned results. This API is not on `PDLResultList` to avoid /// overexposing access to information specific solely to the ByteCode. class ByteCodeRewriteResultList : public PDLResultList { public: ByteCodeRewriteResultList(unsigned maxNumResults) : PDLResultList(maxNumResults) {} /// Return the list of PDL results. MutableArrayRef getResults() { return results; } /// Return the type ranges allocated by this list. 
MutableArrayRef> getAllocatedTypeRanges() { return allocatedTypeRanges; }

  /// Return the value ranges allocated by this list.
  MutableArrayRef> getAllocatedValueRanges() {
    return allocatedValueRanges;
  }
};

/// This class provides support for executing a bytecode stream.
class ByteCodeExecutor {
public:
  /// Build an executor that starts at `curCodeIt`. The memory, range, and loop
  /// buffers, the allocated range storage, and the bytecode tables are all
  /// stored as views or references onto the arguments.
  ByteCodeExecutor(
      const ByteCodeField *curCodeIt, MutableArrayRef memory,
      MutableArrayRef> opRangeMemory,
      MutableArrayRef typeRangeMemory,
      std::vector> &allocatedTypeRangeMemory,
      MutableArrayRef valueRangeMemory,
      std::vector> &allocatedValueRangeMemory,
      MutableArrayRef loopIndex, ArrayRef uniquedMemory,
      ArrayRef code,
      ArrayRef currentPatternBenefits,
      ArrayRef patterns,
      ArrayRef constraintFunctions,
      ArrayRef rewriteFunctions)
      : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
        typeRangeMemory(typeRangeMemory),
        allocatedTypeRangeMemory(allocatedTypeRangeMemory),
        valueRangeMemory(valueRangeMemory),
        allocatedValueRangeMemory(allocatedValueRangeMemory),
        loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
        currentPatternBenefits(currentPatternBenefits), patterns(patterns),
        constraintFunctions(constraintFunctions),
        rewriteFunctions(rewriteFunctions) {}

  /// Start executing the code at the current bytecode index. `matches` is an
  /// optional field provided when this function is executed in a matching
  /// context.
  LogicalResult
  execute(PatternRewriter &rewriter,
          SmallVectorImpl *matches = nullptr,
          std::optional mainRewriteLoc = {});

private:
  /// Internal implementation of executing each of the bytecode commands.
  void executeApplyConstraint(PatternRewriter &rewriter);
  LogicalResult executeApplyRewrite(PatternRewriter &rewriter);
  void executeAreEqual();
  void executeAreRangesEqual();
  void executeBranch();
  void executeCheckOperandCount();
  void executeCheckOperationName();
  void executeCheckResultCount();
  void executeCheckTypes();
  void executeContinue();
  void executeCreateConstantTypeRange();
  void executeCreateOperation(PatternRewriter &rewriter,
                              Location mainRewriteLoc);
  template
  void executeDynamicCreateRange(StringRef type);
  void executeEraseOp(PatternRewriter &rewriter);
  template
  void executeExtract();
  void executeFinalize();
  void executeForEach();
  void executeGetAttribute();
  void executeGetAttributeType();
  void executeGetDefiningOp();
  void executeGetOperand(unsigned index);
  void executeGetOperands();
  void executeGetResult(unsigned index);
  void executeGetResults();
  void executeGetUsers();
  void executeGetValueType();
  void executeGetValueRangeTypes();
  void executeIsNotNull();
  void executeRecordMatch(PatternRewriter &rewriter,
                          SmallVectorImpl &matches);
  void executeReplaceOp(PatternRewriter &rewriter);
  void executeSwitchAttribute();
  void executeSwitchOperandCount();
  void executeSwitchOperationName();
  void executeSwitchResultCount();
  void executeSwitchType();
  void executeSwitchTypes();
  /// Shared tail of the native constraint/rewrite opcodes: consumes the
  /// per-result fields of the bytecode and stores the results of the native
  /// function.
  void processNativeFunResults(ByteCodeRewriteResultList &results,
                               unsigned numResults,
                               LogicalResult &rewriteResult);

  /// Pushes a code iterator to the stack.
  void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }

  /// Pops a code iterator from the stack and makes it the current position.
  /// Asserts that the stack is not empty; nothing is returned.
  void popCodeIt() {
    assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack");
    curCodeIt = resumeCodeIt.back();
    resumeCodeIt.pop_back();
  }

  /// Return the bytecode iterator at the start of the current op code.
  const ByteCodeField *getPrevCodeIt() const {
    // When debug output is enabled, the first return below is taken.
    LLVM_DEBUG({
      // Account for the op code and the Location stored inline.
      return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField);
    });

    // Account for the op code only.
    return curCodeIt - 1;
  }

  /// Read a value from the bytecode buffer, optionally skipping a certain
  /// number of prefix values. These methods always update the buffer to point
  /// to the next field after the read data.
  template
  T read(size_t skipN = 0) {
    curCodeIt += skipN;
    return readImpl();
  }
  ByteCodeField read(size_t skipN = 0) { return read(skipN); }

  /// Read a list of values from the bytecode buffer. The list is cleared
  /// first; the element count is read from the buffer ahead of the elements.
  template
  void readList(SmallVectorImpl &list) {
    list.clear();
    for (unsigned i = 0, e = read(); i != e; ++i)
      list.push_back(read());
  }

  /// Read a list of values from the bytecode buffer. The values may be encoded
  /// either as a single element or a range of elements. Unlike the generic
  /// overload above, these two overloads append to `list` without clearing it.
  void readList(SmallVectorImpl &list) {
    for (unsigned i = 0, e = read(); i != e; ++i) {
      // Each entry is prefixed with its kind: a single type or a type range.
      if (read() == PDLValue::Kind::Type) {
        list.push_back(read());
      } else {
        TypeRange *values = read();
        list.append(values->begin(), values->end());
      }
    }
  }
  void readList(SmallVectorImpl &list) {
    for (unsigned i = 0, e = read(); i != e; ++i) {
      // Each entry is prefixed with its kind: a single value or a value range.
      if (read() == PDLValue::Kind::Value) {
        list.push_back(read());
      } else {
        ValueRange *values = read();
        list.append(values->begin(), values->end());
      }
    }
  }

  /// Read a value stored inline as a pointer. The pointer is copied out of the
  /// buffer with `memcpy` and the iterator advances by the number of fields a
  /// pointer occupies.
  template
  std::enable_if_t::value, T> readInline() {
    const void *pointer;
    std::memcpy(&pointer, curCodeIt, sizeof(const void *));
    curCodeIt += sizeof(const void *) / sizeof(ByteCodeField);
    return T::getFromOpaquePointer(pointer);
  }

  /// Advance the bytecode iterator by `skipN` fields without reading them.
  void skip(size_t skipN) { curCodeIt += skipN; }

  /// Jump to a specific successor based on a predicate value: successor 0 when
  /// `isTrue`, successor 1 otherwise.
  void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
  /// Jump to a specific successor based on a destination index. The address is
  /// read after skipping `destIndex * 2` fields; presumably each successor
  /// address occupies two fields, as asserted in `readImpl` - TODO confirm.
  void selectJump(size_t destIndex) {
    curCodeIt = &code[read(destIndex * 2)];
  }

  /// Handle a switch operation with the provided value and cases. Successor 0
  /// is the default; case `i` maps to successor `i + 1`.
  template >
  void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
    LLVM_DEBUG({
      llvm::dbgs() << " * Value: " << value << "\n"
                   << " * Cases: ";
      llvm::interleaveComma(cases, llvm::dbgs());
      llvm::dbgs() << "\n";
    });

    // Check to see if the attribute value is within the case list. Jump to
    // the correct successor index based on the result.
    for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
      if (cmp(*it, value))
        return selectJump(size_t((it - cases.begin()) + 1));
    selectJump(size_t(0));
  }

  /// Store a pointer to memory.
  void storeToMemory(unsigned index, const void *value) {
    memory[index] = value;
  }
  /// Store a value to memory as an opaque pointer.
  template
  std::enable_if_t::value>
  storeToMemory(unsigned index, T value) {
    memory[index] = value.getAsOpaquePointer();
  }

  /// Internal implementation of reading various data types from the bytecode
  /// stream. Indices below `memory.size()` address the execution memory;
  /// larger indices address `uniquedMemory`, offset by `memory.size()`.
  template
  const void *readFromMemory() {
    size_t index = *curCodeIt++;

    // If this type is an SSA value, it can only be stored in non-const memory.
    if (llvm::is_one_of::value ||
        index < memory.size())
      return memory[index];

    // Otherwise, if this index is not inbounds it is uniqued.
    return uniquedMemory[index - memory.size()];
  }

  // The `readImpl` overloads below are selected through `std::enable_if_t` on
  // the requested type `T`.

  /// Read a memory slot and reinterpret it as `T`.
  template
  std::enable_if_t::value, T> readImpl() {
    return reinterpret_cast(const_cast(readFromMemory()));
  }
  /// Read a memory slot and rebuild `T` through `T::getFromOpaquePointer`.
  template
  std::enable_if_t::value && !std::is_same::value,
                   T>
  readImpl() {
    return T(T::getFromOpaquePointer(readFromMemory()));
  }
  /// Read a kind tag, then read the payload according to that kind.
  template
  std::enable_if_t::value, T> readImpl() {
    switch (read()) {
    case PDLValue::Kind::Attribute:
      return read();
    case PDLValue::Kind::Operation:
      return read();
    case PDLValue::Kind::Type:
      return read();
    case PDLValue::Kind::Value:
      return read();
    case PDLValue::Kind::TypeRange:
      return read();
    case PDLValue::Kind::ValueRange:
      return read();
    }
    llvm_unreachable("unhandled PDLValue::Kind");
  }
  /// Read a bytecode address, which spans exactly two fields.
  template
  std::enable_if_t::value, T> readImpl() {
    static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
                  "unexpected ByteCode address size");
    ByteCodeAddr result;
    std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
    curCodeIt += 2;
    return result;
  }
  /// Read a single raw field.
  template
  std::enable_if_t::value, T> readImpl() {
    return *curCodeIt++;
  }
  /// Read a value through another `readImpl` overload and `static_cast` it.
  template
  std::enable_if_t::value, T> readImpl() {
    return static_cast(readImpl());
  }

  /// Assign the given range to the given memory index. This allocates a new
  /// range object if necessary. The slot `memory[memIndex]` always ends up
  /// pointing at the range slot `rangeIndex`.
  template >
  void assignRangeToMemory(RangeT &&range, unsigned memIndex,
                           unsigned rangeIndex) {
    // Utility functor used to type-erase the assignment.
    auto assignRange = [&](auto &allocatedRangeMemory, auto &rangeMemory) {
      // If the input range is empty, we don't need to allocate anything.
      if (range.empty()) {
        rangeMemory[rangeIndex] = {};
      } else {
        // Allocate a buffer for this type range.
        llvm::OwningArrayRef storage(llvm::size(range));
        llvm::copy(range, storage.begin());

        // Assign this to the range slot and use the range as the value for the
        // memory index.
        allocatedRangeMemory.emplace_back(std::move(storage));
        rangeMemory[rangeIndex] = allocatedRangeMemory.back();
      }
      memory[memIndex] = &rangeMemory[rangeIndex];
    };

    // Dispatch based on the concrete range type.
if constexpr (std::is_same_v) { return assignRange(allocatedTypeRangeMemory, typeRangeMemory); } else if constexpr (std::is_same_v) { return assignRange(allocatedValueRangeMemory, valueRangeMemory); } else { llvm_unreachable("unhandled range type"); } } /// The underlying bytecode buffer. const ByteCodeField *curCodeIt; /// The stack of bytecode positions at which to resume operation. SmallVector resumeCodeIt; /// The current execution memory. MutableArrayRef memory; MutableArrayRef opRangeMemory; MutableArrayRef typeRangeMemory; std::vector> &allocatedTypeRangeMemory; MutableArrayRef valueRangeMemory; std::vector> &allocatedValueRangeMemory; /// The current loop indices. MutableArrayRef loopIndex; /// References to ByteCode data necessary for execution. ArrayRef uniquedMemory; ArrayRef code; ArrayRef currentPatternBenefits; ArrayRef patterns; ArrayRef constraintFunctions; ArrayRef rewriteFunctions; }; } // namespace void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); ByteCodeField fun_idx = read(); SmallVector args; readList(args); LLVM_DEBUG({ llvm::dbgs() << " * Arguments: "; llvm::interleaveComma(args, llvm::dbgs()); llvm::dbgs() << "\n"; }); ByteCodeField isNegated = read(); LLVM_DEBUG({ llvm::dbgs() << " * isNegated: " << isNegated << "\n"; llvm::interleaveComma(args, llvm::dbgs()); }); ByteCodeField numResults = read(); const PDLRewriteFunction &constraintFn = constraintFunctions[fun_idx]; ByteCodeRewriteResultList results(numResults); LogicalResult rewriteResult = constraintFn(rewriter, results, args); [[maybe_unused]] ArrayRef constraintResults = results.getResults(); LLVM_DEBUG({ if (succeeded(rewriteResult)) { llvm::dbgs() << " * Constraint succeeded\n"; llvm::dbgs() << " * Results: "; llvm::interleaveComma(constraintResults, llvm::dbgs()); llvm::dbgs() << "\n"; } else { llvm::dbgs() << " * Constraint failed\n"; } }); assert((failed(rewriteResult) || 
constraintResults.size() == numResults) && "native PDL rewrite function succeeded but returned " "unexpected number of results"); processNativeFunResults(results, numResults, rewriteResult); // Depending on the constraint jump to the proper destination. selectJump(isNegated != succeeded(rewriteResult)); } LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; SmallVector args; readList(args); LLVM_DEBUG({ llvm::dbgs() << " * Arguments: "; llvm::interleaveComma(args, llvm::dbgs()); }); // Execute the rewrite function. ByteCodeField numResults = read(); ByteCodeRewriteResultList results(numResults); LogicalResult rewriteResult = rewriteFn(rewriter, results, args); assert(results.getResults().size() == numResults && "native PDL rewrite function returned unexpected number of results"); processNativeFunResults(results, numResults, rewriteResult); if (failed(rewriteResult)) { LLVM_DEBUG(llvm::dbgs() << " - Failed"); return failure(); } return success(); } void ByteCodeExecutor::processNativeFunResults( ByteCodeRewriteResultList &results, unsigned numResults, LogicalResult &rewriteResult) { // Store the results in the bytecode memory or handle missing results on // failure. for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) { PDLValue::Kind resultKind = read(); // Skip the according number of values on the buffer on failure and exit // early as there are no results to process. 
if (failed(rewriteResult)) { if (resultKind == PDLValue::Kind::TypeRange || resultKind == PDLValue::Kind::ValueRange) { skip(2); } else { skip(1); } return; } PDLValue result = results.getResults()[resultIdx]; LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n"); assert(result.getKind() == resultKind && "native PDL rewrite function returned an unexpected type of " "result"); // If the result is a range, we need to copy it over to the bytecodes // range memory. if (std::optional typeRange = result.dyn_cast()) { unsigned rangeIndex = read(); typeRangeMemory[rangeIndex] = *typeRange; memory[read()] = &typeRangeMemory[rangeIndex]; } else if (std::optional valueRange = result.dyn_cast()) { unsigned rangeIndex = read(); valueRangeMemory[rangeIndex] = *valueRange; memory[read()] = &valueRangeMemory[rangeIndex]; } else { memory[read()] = result.getAsOpaquePointer(); } } // Copy over any underlying storage allocated for result ranges. for (auto &it : results.getAllocatedTypeRanges()) allocatedTypeRangeMemory.push_back(std::move(it)); for (auto &it : results.getAllocatedValueRanges()) allocatedValueRangeMemory.push_back(std::move(it)); } void ByteCodeExecutor::executeAreEqual() { LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); const void *lhs = read(); const void *rhs = read(); LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n"); selectJump(lhs == rhs); } void ByteCodeExecutor::executeAreRangesEqual() { LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n"); PDLValue::Kind valueKind = read(); const void *lhs = read(); const void *rhs = read(); switch (valueKind) { case PDLValue::Kind::TypeRange: { const TypeRange *lhsRange = reinterpret_cast(lhs); const TypeRange *rhsRange = reinterpret_cast(rhs); LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); selectJump(*lhsRange == *rhsRange); break; } case PDLValue::Kind::ValueRange: { const auto *lhsRange = reinterpret_cast(lhs); const auto *rhsRange = reinterpret_cast(rhs); 
LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); selectJump(*lhsRange == *rhsRange); break; } default: llvm_unreachable("unexpected `AreRangesEqual` value kind"); } } void ByteCodeExecutor::executeBranch() { LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n"); curCodeIt = &code[read()]; } void ByteCodeExecutor::executeCheckOperandCount() { LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n"); Operation *op = read(); uint32_t expectedCount = read(); bool compareAtLeast = read(); LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n" << " * Expected: " << expectedCount << "\n" << " * Comparator: " << (compareAtLeast ? ">=" : "==") << "\n"); if (compareAtLeast) selectJump(op->getNumOperands() >= expectedCount); else selectJump(op->getNumOperands() == expectedCount); } void ByteCodeExecutor::executeCheckOperationName() { LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n"); Operation *op = read(); OperationName expectedName = read(); LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n" << " * Expected: \"" << expectedName << "\"\n"); selectJump(op->getName() == expectedName); } void ByteCodeExecutor::executeCheckResultCount() { LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n"); Operation *op = read(); uint32_t expectedCount = read(); bool compareAtLeast = read(); LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n" << " * Expected: " << expectedCount << "\n" << " * Comparator: " << (compareAtLeast ? 
">=" : "==") << "\n"); if (compareAtLeast) selectJump(op->getNumResults() >= expectedCount); else selectJump(op->getNumResults() == expectedCount); } void ByteCodeExecutor::executeCheckTypes() { LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); TypeRange *lhs = read(); Attribute rhs = read(); LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); selectJump(*lhs == cast(rhs).getAsValueRange()); } void ByteCodeExecutor::executeContinue() { ByteCodeField level = read(); LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n" << " * Level: " << level << "\n"); ++loopIndex[level]; popCodeIt(); } void ByteCodeExecutor::executeCreateConstantTypeRange() { LLVM_DEBUG(llvm::dbgs() << "Executing CreateConstantTypeRange:\n"); unsigned memIndex = read(); unsigned rangeIndex = read(); ArrayAttr typesAttr = cast(read()); LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n"); assignRangeToMemory(typesAttr.getAsValueRange(), memIndex, rangeIndex); } void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, Location mainRewriteLoc) { LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n"); unsigned memIndex = read(); OperationState state(mainRewriteLoc, read()); readList(state.operands); for (unsigned i = 0, e = read(); i != e; ++i) { StringAttr name = read(); if (Attribute attr = read()) state.addAttribute(name, attr); } // Read in the result types. If the "size" is the sentinel value, this // indicates that the result types should be inferred. unsigned numResults = read(); if (numResults == kInferTypesMarker) { InferTypeOpInterface::Concept *inferInterface = state.name.getInterface(); assert(inferInterface && "expected operation to provide InferTypeOpInterface"); // TODO: Handle failure. 
if (failed(inferInterface->inferReturnTypes(
            state.getContext(), state.location, state.operands,
            state.attributes.getDictionary(state.getContext()),
            state.getRawProperties(), state.regions, state.types)))
      // NOTE(review): this returns without writing `memory[memIndex]`, so the
      // slot keeps whatever value it held before.
      return;
  } else {
    // Otherwise, this is a fixed number of results.
    for (unsigned i = 0; i != numResults; ++i) {
      // Each entry is either a single type or a range of types.
      if (read() == PDLValue::Kind::Type) {
        state.types.push_back(read());
      } else {
        TypeRange *resultTypes = read();
        state.types.append(resultTypes->begin(), resultTypes->end());
      }
    }
  }

  Operation *resultOp = rewriter.create(state);
  memory[memIndex] = resultOp;

  LLVM_DEBUG({
    llvm::dbgs() << " * Attributes: "
                 << state.attributes.getDictionary(state.getContext())
                 << "\n * Operands: ";
    llvm::interleaveComma(state.operands, llvm::dbgs());
    llvm::dbgs() << "\n * Result Types: ";
    llvm::interleaveComma(state.types, llvm::dbgs());
    llvm::dbgs() << "\n * Result: " << *resultOp << "\n";
  });
}

/// Execute a dynamic range creation: read the memory index, the range index,
/// and the list of elements, then assign the elements to the range slot.
/// `type` is only used to label the debug output.
template
void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) {
  LLVM_DEBUG(llvm::dbgs() << "Executing CreateDynamic" << type << "Range:\n");
  unsigned memIndex = read();
  unsigned rangeIndex = read();
  SmallVector values;
  readList(values);

  LLVM_DEBUG({
    llvm::dbgs() << "\n * " << type << "s: ";
    llvm::interleaveComma(values, llvm::dbgs());
    llvm::dbgs() << "\n";
  });

  assignRangeToMemory(values, memIndex, rangeIndex);
}

/// Execute `EraseOp`: erase the operation read from the buffer through
/// `rewriter`.
void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
  LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
  Operation *op = read();

  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
  rewriter.eraseOp(op);
}

/// Execute an `Extract*` opcode: read the range, the element index, and the
/// memory index of the result. A null range stores null; an out-of-bounds
/// index stores a default-constructed element.
template
void ByteCodeExecutor::executeExtract() {
  LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n");
  Range *range = read();
  unsigned index = read();
  unsigned memIndex = read();

  if (!range) {
    memory[memIndex] = nullptr;
    return;
  }

  T result = index < range->size() ? (*range)[index] : T();
  LLVM_DEBUG(llvm::dbgs() << " * " << kind << "s(" << range->size() << ")\n"
                          << " * Index: " << index << "\n"
                          << " * Result: " << result << "\n");
  storeToMemory(memIndex, result);
}

/// Execute `Finalize`: only emits debug output here.
void ByteCodeExecutor::executeFinalize() {
  LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n");
}

/// Execute `ForEach`. Reads the range index, the memory index of the loop
/// variable, the element kind, and the loop level. While elements remain, the
/// current element is stored, the start of this opcode is pushed on the resume
/// stack, and execution falls into the loop body. Once the range is exhausted
/// the loop index is reset and execution jumps to the successor.
void ByteCodeExecutor::executeForEach() {
  LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n");
  // Remember where this opcode starts so `Continue` can re-execute it.
  const ByteCodeField *prevCodeIt = getPrevCodeIt();
  unsigned rangeIndex = read();
  unsigned memIndex = read();
  const void *value = nullptr;

  switch (read()) {
  case PDLValue::Kind::Operation: {
    unsigned &index = loopIndex[read()];
    ArrayRef array = opRangeMemory[rangeIndex];
    assert(index <= array.size() && "iterated past the end");
    if (index < array.size()) {
      LLVM_DEBUG(llvm::dbgs() << " * Result: " << array[index] << "\n");
      value = array[index];
      break;
    }

    LLVM_DEBUG(llvm::dbgs() << " * Done\n");
    index = 0;
    selectJump(size_t(0));
    return;
  }
  default:
    llvm_unreachable("unexpected `ForEach` value kind");
  }

  // Store the iterate value and the stack address.
  memory[memIndex] = value;
  pushCodeIt(prevCodeIt);

  // Skip over the successor (we will enter the body of the loop).
  read();
}

/// Execute `GetAttribute`: look up the named attribute on the operation and
/// store it (possibly null) in the result slot.
void ByteCodeExecutor::executeGetAttribute() {
  LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
  unsigned memIndex = read();
  Operation *op = read();
  StringAttr attrName = read();
  Attribute attr = op->getAttr(attrName);

  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
                          << " * Attribute: " << attrName << "\n"
                          << " * Result: " << attr << "\n");
  memory[memIndex] = attr.getAsOpaquePointer();
}

/// Execute `GetAttributeType`: store the type of the attribute when the
/// `dyn_cast` succeeds, and a null type otherwise.
void ByteCodeExecutor::executeGetAttributeType() {
  LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
  unsigned memIndex = read();
  Attribute attr = read();
  Type type;
  if (auto typedAttr = dyn_cast(attr))
    type = typedAttr.getType();

  LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n"
                          << " * Result: " << type << "\n");
  memory[memIndex] = type.getAsOpaquePointer();
}

/// Execute `GetDefiningOp`: store the defining operation of a value, or of the
/// first element of a value range. A null value, a null range, or an empty
/// range stores null.
void ByteCodeExecutor::executeGetDefiningOp() {
  LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
  unsigned memIndex = read();
  Operation *op = nullptr;
  if (read() == PDLValue::Kind::Value) {
    Value value = read();
    if (value)
      op = value.getDefiningOp();
    LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
  } else {
    ValueRange *values = read();
    if (values && !values->empty()) {
      op = values->front().getDefiningOp();
    }
    LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n");
  }

  LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n");
  memory[memIndex] = op;
}

/// Execute a `GetOperand*` opcode for the given operand `index`: store the
/// operand, or a null value when the index is out of bounds.
void ByteCodeExecutor::executeGetOperand(unsigned index) {
  Operation *op = read();
  unsigned memIndex = read();
  Value operand =
      index < op->getNumOperands() ? op->getOperand(index) : Value();

  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
                          << " * Index: " << index << "\n"
                          << " * Result: " << operand << "\n");
  memory[memIndex] = operand.getAsOpaquePointer();
}

/// This function is the internal implementation of `GetResults` and
/// `GetOperands` that provides support for extracting a value range from the
/// given operation.
template