18a1ca2cdSRiver Riddle //===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===//
28a1ca2cdSRiver Riddle //
38a1ca2cdSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
48a1ca2cdSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
58a1ca2cdSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
68a1ca2cdSRiver Riddle //
78a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
88a1ca2cdSRiver Riddle
98a1ca2cdSRiver Riddle #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
1067d0d7acSMichele Scuttari
118a1ca2cdSRiver Riddle #include "PredicateTree.h"
128a1ca2cdSRiver Riddle #include "mlir/Dialect/PDL/IR/PDL.h"
138a1ca2cdSRiver Riddle #include "mlir/Dialect/PDL/IR/PDLTypes.h"
148a1ca2cdSRiver Riddle #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
158a1ca2cdSRiver Riddle #include "mlir/Pass/Pass.h"
168a1ca2cdSRiver Riddle #include "llvm/ADT/MapVector.h"
178a1ca2cdSRiver Riddle #include "llvm/ADT/ScopedHashTable.h"
181d49e535SGuillaume Chatelet #include "llvm/ADT/Sequence.h"
198a1ca2cdSRiver Riddle #include "llvm/ADT/SetVector.h"
201d49e535SGuillaume Chatelet #include "llvm/ADT/SmallVector.h"
218a1ca2cdSRiver Riddle #include "llvm/ADT/TypeSwitch.h"
228a1ca2cdSRiver Riddle
2367d0d7acSMichele Scuttari namespace mlir {
2467d0d7acSMichele Scuttari #define GEN_PASS_DEF_CONVERTPDLTOPDLINTERP
2567d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc"
2667d0d7acSMichele Scuttari } // namespace mlir
2767d0d7acSMichele Scuttari
288a1ca2cdSRiver Riddle using namespace mlir;
298a1ca2cdSRiver Riddle using namespace mlir::pdl_to_pdl_interp;
308a1ca2cdSRiver Riddle
318a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
328a1ca2cdSRiver Riddle // PatternLowering
338a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
348a1ca2cdSRiver Riddle
358a1ca2cdSRiver Riddle namespace {
368a1ca2cdSRiver Riddle /// This class generators operations within the PDL Interpreter dialect from a
378a1ca2cdSRiver Riddle /// given module containing PDL pattern operations.
388a1ca2cdSRiver Riddle struct PatternLowering {
398a1ca2cdSRiver Riddle public:
408c66344eSRiver Riddle PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule,
418c66344eSRiver Riddle DenseMap<Operation *, PDLPatternConfigSet *> *configMap);
428a1ca2cdSRiver Riddle
438a1ca2cdSRiver Riddle /// Generate code for matching and rewriting based on the pattern operations
448a1ca2cdSRiver Riddle /// within the module.
458a1ca2cdSRiver Riddle void lower(ModuleOp module);
468a1ca2cdSRiver Riddle
478a1ca2cdSRiver Riddle private:
488a1ca2cdSRiver Riddle using ValueMap = llvm::ScopedHashTable<Position *, Value>;
498a1ca2cdSRiver Riddle using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>;
508a1ca2cdSRiver Riddle
518a1ca2cdSRiver Riddle /// Generate interpreter operations for the tree rooted at the given matcher
52a76ee58fSStanislav Funiak /// node, in the specified region.
53*8ec28af8SMatthias Gehre Block *generateMatcher(MatcherNode &node, Region ®ion,
54*8ec28af8SMatthias Gehre Block *block = nullptr);
558a1ca2cdSRiver Riddle
56a76ee58fSStanislav Funiak /// Get or create an access to the provided positional value in the current
57a76ee58fSStanislav Funiak /// block. This operation may mutate the provided block pointer if nested
58a76ee58fSStanislav Funiak /// regions (i.e., pdl_interp.iterate) are required.
59a76ee58fSStanislav Funiak Value getValueAt(Block *¤tBlock, Position *pos);
608a1ca2cdSRiver Riddle
61a76ee58fSStanislav Funiak /// Create the interpreter predicate operations. This operation may mutate the
62a76ee58fSStanislav Funiak /// provided current block pointer if nested regions (iterates) are required.
63a76ee58fSStanislav Funiak void generate(BoolNode *boolNode, Block *¤tBlock, Value val);
648a1ca2cdSRiver Riddle
65a76ee58fSStanislav Funiak /// Create the interpreter switch / predicate operations, with several case
66a76ee58fSStanislav Funiak /// destinations. This operation never mutates the provided current block
67a76ee58fSStanislav Funiak /// pointer, because the switch operation does not need Values beyond `val`.
68a76ee58fSStanislav Funiak void generate(SwitchNode *switchNode, Block *currentBlock, Value val);
698a1ca2cdSRiver Riddle
70a76ee58fSStanislav Funiak /// Create the interpreter operations to record a successful pattern match
71a76ee58fSStanislav Funiak /// using the contained root operation. This operation may mutate the current
72a76ee58fSStanislav Funiak /// block pointer if nested regions (i.e., pdl_interp.iterate) are required.
73a76ee58fSStanislav Funiak void generate(SuccessNode *successNode, Block *¤tBlock);
748a1ca2cdSRiver Riddle
758a1ca2cdSRiver Riddle /// Generate a rewriter function for the given pattern operation, and returns
768a1ca2cdSRiver Riddle /// a reference to that function.
778a1ca2cdSRiver Riddle SymbolRefAttr generateRewriter(pdl::PatternOp pattern,
788a1ca2cdSRiver Riddle SmallVectorImpl<Position *> &usedMatchValues);
798a1ca2cdSRiver Riddle
808a1ca2cdSRiver Riddle /// Generate the rewriter code for the given operation.
8102c4c0d5SRiver Riddle void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp,
8202c4c0d5SRiver Riddle DenseMap<Value, Value> &rewriteValues,
8302c4c0d5SRiver Riddle function_ref<Value(Value)> mapRewriteValue);
848a1ca2cdSRiver Riddle void generateRewriter(pdl::AttributeOp attrOp,
858a1ca2cdSRiver Riddle DenseMap<Value, Value> &rewriteValues,
868a1ca2cdSRiver Riddle function_ref<Value(Value)> mapRewriteValue);
878a1ca2cdSRiver Riddle void generateRewriter(pdl::EraseOp eraseOp,
888a1ca2cdSRiver Riddle DenseMap<Value, Value> &rewriteValues,
898a1ca2cdSRiver Riddle function_ref<Value(Value)> mapRewriteValue);
908a1ca2cdSRiver Riddle void generateRewriter(pdl::OperationOp operationOp,
918a1ca2cdSRiver Riddle DenseMap<Value, Value> &rewriteValues,
928a1ca2cdSRiver Riddle function_ref<Value(Value)> mapRewriteValue);
93ce57789dSRiver Riddle void generateRewriter(pdl::RangeOp rangeOp,
94ce57789dSRiver Riddle DenseMap<Value, Value> &rewriteValues,
95ce57789dSRiver Riddle function_ref<Value(Value)> mapRewriteValue);
968a1ca2cdSRiver Riddle void generateRewriter(pdl::ReplaceOp replaceOp,
978a1ca2cdSRiver Riddle DenseMap<Value, Value> &rewriteValues,
988a1ca2cdSRiver Riddle function_ref<Value(Value)> mapRewriteValue);
99242762c9SRiver Riddle void generateRewriter(pdl::ResultOp resultOp,
100242762c9SRiver Riddle DenseMap<Value, Value> &rewriteValues,
101242762c9SRiver Riddle function_ref<Value(Value)> mapRewriteValue);
1023a833a0eSRiver Riddle void generateRewriter(pdl::ResultsOp resultOp,
1033a833a0eSRiver Riddle DenseMap<Value, Value> &rewriteValues,
1043a833a0eSRiver Riddle function_ref<Value(Value)> mapRewriteValue);
1058a1ca2cdSRiver Riddle void generateRewriter(pdl::TypeOp typeOp,
1068a1ca2cdSRiver Riddle DenseMap<Value, Value> &rewriteValues,
1078a1ca2cdSRiver Riddle function_ref<Value(Value)> mapRewriteValue);
1083a833a0eSRiver Riddle void generateRewriter(pdl::TypesOp typeOp,
1093a833a0eSRiver Riddle DenseMap<Value, Value> &rewriteValues,
1103a833a0eSRiver Riddle function_ref<Value(Value)> mapRewriteValue);
1118a1ca2cdSRiver Riddle
1128a1ca2cdSRiver Riddle /// Generate the values used for resolving the result types of an operation
1133c752289SRiver Riddle /// created within a dag rewriter region. If the result types of the operation
1143c752289SRiver Riddle /// should be inferred, `hasInferredResultTypes` is set to true.
1158a1ca2cdSRiver Riddle void generateOperationResultTypeRewriter(
1163c752289SRiver Riddle pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
1173c752289SRiver Riddle SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
1183c752289SRiver Riddle bool &hasInferredResultTypes);
1198a1ca2cdSRiver Riddle
1208a1ca2cdSRiver Riddle /// A builder to use when generating interpreter operations.
1218a1ca2cdSRiver Riddle OpBuilder builder;
1228a1ca2cdSRiver Riddle
1238a1ca2cdSRiver Riddle /// The matcher function used for all match related logic within PDL patterns.
124f96a8675SRiver Riddle pdl_interp::FuncOp matcherFunc;
1258a1ca2cdSRiver Riddle
1268a1ca2cdSRiver Riddle /// The rewriter module containing the all rewrite related logic within PDL
1278a1ca2cdSRiver Riddle /// patterns.
1288a1ca2cdSRiver Riddle ModuleOp rewriterModule;
1298a1ca2cdSRiver Riddle
1308a1ca2cdSRiver Riddle /// The symbol table of the rewriter module used for insertion.
1318a1ca2cdSRiver Riddle SymbolTable rewriterSymbolTable;
1328a1ca2cdSRiver Riddle
1338a1ca2cdSRiver Riddle /// A scoped map connecting a position with the corresponding interpreter
1348a1ca2cdSRiver Riddle /// value.
1358a1ca2cdSRiver Riddle ValueMap values;
1368a1ca2cdSRiver Riddle
1378a1ca2cdSRiver Riddle /// A stack of blocks used as the failure destination for matcher nodes that
1388a1ca2cdSRiver Riddle /// don't have an explicit failure path.
1398a1ca2cdSRiver Riddle SmallVector<Block *, 8> failureBlockStack;
1408a1ca2cdSRiver Riddle
1418a1ca2cdSRiver Riddle /// A mapping between values defined in a pattern match, and the corresponding
1428a1ca2cdSRiver Riddle /// positional value.
1438a1ca2cdSRiver Riddle DenseMap<Value, Position *> valueToPosition;
1448a1ca2cdSRiver Riddle
1457557530fSFangrui Song /// The set of operation values whose location will be used for newly
1468a1ca2cdSRiver Riddle /// generated operations.
1474efb7754SRiver Riddle SetVector<Value> locOps;
1488c66344eSRiver Riddle
1498c66344eSRiver Riddle /// A mapping between pattern operations and the corresponding configuration
1508c66344eSRiver Riddle /// set.
1518c66344eSRiver Riddle DenseMap<Operation *, PDLPatternConfigSet *> *configMap;
152*8ec28af8SMatthias Gehre
153*8ec28af8SMatthias Gehre /// A mapping from a constraint question to the ApplyConstraintOp
154*8ec28af8SMatthias Gehre /// that implements it.
155*8ec28af8SMatthias Gehre DenseMap<ConstraintQuestion *, pdl_interp::ApplyConstraintOp> constraintOpMap;
1568a1ca2cdSRiver Riddle };
157be0a7e9fSMehdi Amini } // namespace
1588a1ca2cdSRiver Riddle
PatternLowering(pdl_interp::FuncOp matcherFunc,ModuleOp rewriterModule,DenseMap<Operation *,PDLPatternConfigSet * > * configMap)1598c66344eSRiver Riddle PatternLowering::PatternLowering(
1608c66344eSRiver Riddle pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule,
1618c66344eSRiver Riddle DenseMap<Operation *, PDLPatternConfigSet *> *configMap)
1628a1ca2cdSRiver Riddle : builder(matcherFunc.getContext()), matcherFunc(matcherFunc),
1638c66344eSRiver Riddle rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule),
1648c66344eSRiver Riddle configMap(configMap) {}
1658a1ca2cdSRiver Riddle
lower(ModuleOp module)1668a1ca2cdSRiver Riddle void PatternLowering::lower(ModuleOp module) {
1678a1ca2cdSRiver Riddle PredicateUniquer predicateUniquer;
1688a1ca2cdSRiver Riddle PredicateBuilder predicateBuilder(predicateUniquer, module.getContext());
1698a1ca2cdSRiver Riddle
1708a1ca2cdSRiver Riddle // Define top-level scope for the arguments to the matcher function.
1718a1ca2cdSRiver Riddle ValueMapScope topLevelValueScope(values);
1728a1ca2cdSRiver Riddle
1738a1ca2cdSRiver Riddle // Insert the root operation, i.e. argument to the matcher, at the root
1748a1ca2cdSRiver Riddle // position.
175f96a8675SRiver Riddle Block *matcherEntryBlock = &matcherFunc.front();
1768a1ca2cdSRiver Riddle values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0));
1778a1ca2cdSRiver Riddle
1788a1ca2cdSRiver Riddle // Generate a root matcher node from the provided PDL module.
1798a1ca2cdSRiver Riddle std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree(
1808a1ca2cdSRiver Riddle module, predicateBuilder, valueToPosition);
181a76ee58fSStanislav Funiak Block *firstMatcherBlock = generateMatcher(*root, matcherFunc.getBody());
182a76ee58fSStanislav Funiak assert(failureBlockStack.empty() && "failed to empty the stack");
1838a1ca2cdSRiver Riddle
1848a1ca2cdSRiver Riddle // After generation, merged the first matched block into the entry.
1858a1ca2cdSRiver Riddle matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(),
1868a1ca2cdSRiver Riddle firstMatcherBlock->getOperations());
1878a1ca2cdSRiver Riddle firstMatcherBlock->erase();
1888a1ca2cdSRiver Riddle }
1898a1ca2cdSRiver Riddle
generateMatcher(MatcherNode & node,Region & region,Block * block)190*8ec28af8SMatthias Gehre Block *PatternLowering::generateMatcher(MatcherNode &node, Region ®ion,
191*8ec28af8SMatthias Gehre Block *block) {
1928a1ca2cdSRiver Riddle // Push a new scope for the values used by this matcher.
193*8ec28af8SMatthias Gehre if (!block)
194*8ec28af8SMatthias Gehre block = ®ion.emplaceBlock();
1958a1ca2cdSRiver Riddle ValueMapScope scope(values);
1968a1ca2cdSRiver Riddle
1978a1ca2cdSRiver Riddle // If this is the return node, simply insert the corresponding interpreter
1988a1ca2cdSRiver Riddle // finalize.
1998a1ca2cdSRiver Riddle if (isa<ExitNode>(node)) {
2008a1ca2cdSRiver Riddle builder.setInsertionPointToEnd(block);
2018a1ca2cdSRiver Riddle builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc());
2028a1ca2cdSRiver Riddle return block;
2038a1ca2cdSRiver Riddle }
2048a1ca2cdSRiver Riddle
2058a1ca2cdSRiver Riddle // Get the next block in the match sequence.
206a76ee58fSStanislav Funiak // This is intentionally executed first, before we get the value for the
207a76ee58fSStanislav Funiak // position associated with the node, so that we preserve an "there exist"
208a76ee58fSStanislav Funiak // semantics: if getting a value requires an upward traversal (going from a
209a76ee58fSStanislav Funiak // value to its consumers), we want to perform the check on all the consumers
210a76ee58fSStanislav Funiak // before we pass control to the failure node.
2118a1ca2cdSRiver Riddle std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode();
212a76ee58fSStanislav Funiak Block *failureBlock;
2138a1ca2cdSRiver Riddle if (failureNode) {
214a76ee58fSStanislav Funiak failureBlock = generateMatcher(*failureNode, region);
215a76ee58fSStanislav Funiak failureBlockStack.push_back(failureBlock);
2168a1ca2cdSRiver Riddle } else {
2178a1ca2cdSRiver Riddle assert(!failureBlockStack.empty() && "expected valid failure block");
218a76ee58fSStanislav Funiak failureBlock = failureBlockStack.back();
2198a1ca2cdSRiver Riddle }
2208a1ca2cdSRiver Riddle
221a76ee58fSStanislav Funiak // If this node contains a position, get the corresponding value for this
222a76ee58fSStanislav Funiak // block.
223a76ee58fSStanislav Funiak Block *currentBlock = block;
224a76ee58fSStanislav Funiak Position *position = node.getPosition();
225a76ee58fSStanislav Funiak Value val = position ? getValueAt(currentBlock, position) : Value();
226a76ee58fSStanislav Funiak
2278a1ca2cdSRiver Riddle // If this value corresponds to an operation, record that we are going to use
2288a1ca2cdSRiver Riddle // its location as part of a fused location.
2295550c821STres Popp bool isOperationValue = val && isa<pdl::OperationType>(val.getType());
2308a1ca2cdSRiver Riddle if (isOperationValue)
2318a1ca2cdSRiver Riddle locOps.insert(val);
2328a1ca2cdSRiver Riddle
233a76ee58fSStanislav Funiak // Dispatch to the correct method based on derived node type.
234a76ee58fSStanislav Funiak TypeSwitch<MatcherNode *>(&node)
235a19e1635SStanislav Funiak .Case<BoolNode, SwitchNode>([&](auto *derivedNode) {
236a19e1635SStanislav Funiak this->generate(derivedNode, currentBlock, val);
237a19e1635SStanislav Funiak })
238a76ee58fSStanislav Funiak .Case([&](SuccessNode *successNode) {
239a76ee58fSStanislav Funiak generate(successNode, currentBlock);
240a76ee58fSStanislav Funiak });
2418a1ca2cdSRiver Riddle
242a76ee58fSStanislav Funiak // Pop all the failure blocks that were inserted due to nesting of
243a76ee58fSStanislav Funiak // pdl_interp.iterate.
244a76ee58fSStanislav Funiak while (failureBlockStack.back() != failureBlock) {
245a76ee58fSStanislav Funiak failureBlockStack.pop_back();
246a76ee58fSStanislav Funiak assert(!failureBlockStack.empty() && "unable to locate failure block");
2478a1ca2cdSRiver Riddle }
2488a1ca2cdSRiver Riddle
249a76ee58fSStanislav Funiak // Pop the new failure block.
2508a1ca2cdSRiver Riddle if (failureNode)
2518a1ca2cdSRiver Riddle failureBlockStack.pop_back();
252a76ee58fSStanislav Funiak
2538a1ca2cdSRiver Riddle if (isOperationValue)
2548a1ca2cdSRiver Riddle locOps.remove(val);
255a76ee58fSStanislav Funiak
2568a1ca2cdSRiver Riddle return block;
2578a1ca2cdSRiver Riddle }
2588a1ca2cdSRiver Riddle
getValueAt(Block * & currentBlock,Position * pos)259a76ee58fSStanislav Funiak Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) {
2608a1ca2cdSRiver Riddle if (Value val = values.lookup(pos))
2618a1ca2cdSRiver Riddle return val;
2628a1ca2cdSRiver Riddle
2638a1ca2cdSRiver Riddle // Get the value for the parent position.
264233e9476SRiver Riddle Value parentVal;
265233e9476SRiver Riddle if (Position *parent = pos->getParent())
26680b3f08eSUday Bondhugula parentVal = getValueAt(currentBlock, parent);
2678a1ca2cdSRiver Riddle
2688a1ca2cdSRiver Riddle // TODO: Use a location from the position.
269233e9476SRiver Riddle Location loc = parentVal ? parentVal.getLoc() : builder.getUnknownLoc();
270a76ee58fSStanislav Funiak builder.setInsertionPointToEnd(currentBlock);
2718a1ca2cdSRiver Riddle Value value;
2728a1ca2cdSRiver Riddle switch (pos->getKind()) {
273a76ee58fSStanislav Funiak case Predicates::OperationPos: {
274a76ee58fSStanislav Funiak auto *operationPos = cast<OperationPosition>(pos);
2752692eae5SStanislav Funiak if (operationPos->isOperandDefiningOp())
276a76ee58fSStanislav Funiak // Standard (downward) traversal which directly follows the defining op.
2778a1ca2cdSRiver Riddle value = builder.create<pdl_interp::GetDefiningOpOp>(
2788a1ca2cdSRiver Riddle loc, builder.getType<pdl::OperationType>(), parentVal);
2792692eae5SStanislav Funiak else
2802692eae5SStanislav Funiak // A passthrough operation position.
2812692eae5SStanislav Funiak value = parentVal;
2828a1ca2cdSRiver Riddle break;
283a76ee58fSStanislav Funiak }
2842692eae5SStanislav Funiak case Predicates::UsersPos: {
2852692eae5SStanislav Funiak auto *usersPos = cast<UsersPosition>(pos);
286a76ee58fSStanislav Funiak
287a76ee58fSStanislav Funiak // The first operation retrieves the representative value of a range.
2882692eae5SStanislav Funiak // This applies only when the parent is a range of values and we were
2892692eae5SStanislav Funiak // requested to use a representative value (e.g., upward traversal).
2905550c821STres Popp if (isa<pdl::RangeType>(parentVal.getType()) &&
2912692eae5SStanislav Funiak usersPos->useRepresentative())
292a76ee58fSStanislav Funiak value = builder.create<pdl_interp::ExtractOp>(loc, parentVal, 0);
293a76ee58fSStanislav Funiak else
294a76ee58fSStanislav Funiak value = parentVal;
295a76ee58fSStanislav Funiak
296a76ee58fSStanislav Funiak // The second operation retrieves the users.
297a76ee58fSStanislav Funiak value = builder.create<pdl_interp::GetUsersOp>(loc, value);
2982692eae5SStanislav Funiak break;
2992692eae5SStanislav Funiak }
3002692eae5SStanislav Funiak case Predicates::ForEachPos: {
301a76ee58fSStanislav Funiak assert(!failureBlockStack.empty() && "expected valid failure block");
302a76ee58fSStanislav Funiak auto foreach = builder.create<pdl_interp::ForEachOp>(
3032692eae5SStanislav Funiak loc, parentVal, failureBlockStack.back(), /*initLoop=*/true);
304a76ee58fSStanislav Funiak value = foreach.getLoopVariable();
305a76ee58fSStanislav Funiak
3062692eae5SStanislav Funiak // Create the continuation block.
3073c405c3bSRiver Riddle Block *continueBlock = builder.createBlock(&foreach.getRegion());
308a76ee58fSStanislav Funiak builder.create<pdl_interp::ContinueOp>(loc);
309a76ee58fSStanislav Funiak failureBlockStack.push_back(continueBlock);
310a76ee58fSStanislav Funiak
3113c405c3bSRiver Riddle currentBlock = &foreach.getRegion().front();
312a76ee58fSStanislav Funiak break;
313a76ee58fSStanislav Funiak }
3148a1ca2cdSRiver Riddle case Predicates::OperandPos: {
3158a1ca2cdSRiver Riddle auto *operandPos = cast<OperandPosition>(pos);
3168a1ca2cdSRiver Riddle value = builder.create<pdl_interp::GetOperandOp>(
3178a1ca2cdSRiver Riddle loc, builder.getType<pdl::ValueType>(), parentVal,
3188a1ca2cdSRiver Riddle operandPos->getOperandNumber());
3198a1ca2cdSRiver Riddle break;
3208a1ca2cdSRiver Riddle }
3213a833a0eSRiver Riddle case Predicates::OperandGroupPos: {
3223a833a0eSRiver Riddle auto *operandPos = cast<OperandGroupPosition>(pos);
3233a833a0eSRiver Riddle Type valueTy = builder.getType<pdl::ValueType>();
3243a833a0eSRiver Riddle value = builder.create<pdl_interp::GetOperandsOp>(
3253a833a0eSRiver Riddle loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
3263a833a0eSRiver Riddle parentVal, operandPos->getOperandGroupNumber());
3273a833a0eSRiver Riddle break;
3283a833a0eSRiver Riddle }
3298a1ca2cdSRiver Riddle case Predicates::AttributePos: {
3308a1ca2cdSRiver Riddle auto *attrPos = cast<AttributePosition>(pos);
3318a1ca2cdSRiver Riddle value = builder.create<pdl_interp::GetAttributeOp>(
3328a1ca2cdSRiver Riddle loc, builder.getType<pdl::AttributeType>(), parentVal,
3338a1ca2cdSRiver Riddle attrPos->getName().strref());
3348a1ca2cdSRiver Riddle break;
3358a1ca2cdSRiver Riddle }
3368a1ca2cdSRiver Riddle case Predicates::TypePos: {
3375550c821STres Popp if (isa<pdl::AttributeType>(parentVal.getType()))
3388a1ca2cdSRiver Riddle value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal);
3393a833a0eSRiver Riddle else
3403a833a0eSRiver Riddle value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal);
3418a1ca2cdSRiver Riddle break;
3428a1ca2cdSRiver Riddle }
3438a1ca2cdSRiver Riddle case Predicates::ResultPos: {
3448a1ca2cdSRiver Riddle auto *resPos = cast<ResultPosition>(pos);
3458a1ca2cdSRiver Riddle value = builder.create<pdl_interp::GetResultOp>(
3468a1ca2cdSRiver Riddle loc, builder.getType<pdl::ValueType>(), parentVal,
3478a1ca2cdSRiver Riddle resPos->getResultNumber());
3488a1ca2cdSRiver Riddle break;
3498a1ca2cdSRiver Riddle }
3503a833a0eSRiver Riddle case Predicates::ResultGroupPos: {
3513a833a0eSRiver Riddle auto *resPos = cast<ResultGroupPosition>(pos);
3523a833a0eSRiver Riddle Type valueTy = builder.getType<pdl::ValueType>();
3533a833a0eSRiver Riddle value = builder.create<pdl_interp::GetResultsOp>(
3543a833a0eSRiver Riddle loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
3553a833a0eSRiver Riddle parentVal, resPos->getResultGroupNumber());
3563a833a0eSRiver Riddle break;
3573a833a0eSRiver Riddle }
358233e9476SRiver Riddle case Predicates::AttributeLiteralPos: {
359233e9476SRiver Riddle auto *attrPos = cast<AttributeLiteralPosition>(pos);
360233e9476SRiver Riddle value =
361233e9476SRiver Riddle builder.create<pdl_interp::CreateAttributeOp>(loc, attrPos->getValue());
362233e9476SRiver Riddle break;
363233e9476SRiver Riddle }
364233e9476SRiver Riddle case Predicates::TypeLiteralPos: {
365233e9476SRiver Riddle auto *typePos = cast<TypeLiteralPosition>(pos);
366233e9476SRiver Riddle Attribute rawTypeAttr = typePos->getValue();
3675550c821STres Popp if (TypeAttr typeAttr = dyn_cast<TypeAttr>(rawTypeAttr))
368233e9476SRiver Riddle value = builder.create<pdl_interp::CreateTypeOp>(loc, typeAttr);
369233e9476SRiver Riddle else
370233e9476SRiver Riddle value = builder.create<pdl_interp::CreateTypesOp>(
3715550c821STres Popp loc, cast<ArrayAttr>(rawTypeAttr));
372233e9476SRiver Riddle break;
373233e9476SRiver Riddle }
374*8ec28af8SMatthias Gehre case Predicates::ConstraintResultPos: {
375*8ec28af8SMatthias Gehre // Due to the order of traversal, the ApplyConstraintOp has already been
376*8ec28af8SMatthias Gehre // created and we can find it in constraintOpMap.
377*8ec28af8SMatthias Gehre auto *constrResPos = cast<ConstraintPosition>(pos);
378*8ec28af8SMatthias Gehre auto i = constraintOpMap.find(constrResPos->getQuestion());
379*8ec28af8SMatthias Gehre assert(i != constraintOpMap.end());
380*8ec28af8SMatthias Gehre value = i->second->getResult(constrResPos->getIndex());
381*8ec28af8SMatthias Gehre break;
382*8ec28af8SMatthias Gehre }
3838a1ca2cdSRiver Riddle default:
3848a1ca2cdSRiver Riddle llvm_unreachable("Generating unknown Position getter");
3858a1ca2cdSRiver Riddle break;
3868a1ca2cdSRiver Riddle }
387a76ee58fSStanislav Funiak
3888a1ca2cdSRiver Riddle values.insert(pos, value);
3898a1ca2cdSRiver Riddle return value;
3908a1ca2cdSRiver Riddle }
3918a1ca2cdSRiver Riddle
generate(BoolNode * boolNode,Block * & currentBlock,Value val)392a76ee58fSStanislav Funiak void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock,
393a76ee58fSStanislav Funiak Value val) {
3948a1ca2cdSRiver Riddle Location loc = val.getLoc();
395a76ee58fSStanislav Funiak Qualifier *question = boolNode->getQuestion();
396a76ee58fSStanislav Funiak Qualifier *answer = boolNode->getAnswer();
397a76ee58fSStanislav Funiak Region *region = currentBlock->getParent();
398a76ee58fSStanislav Funiak
399a76ee58fSStanislav Funiak // Execute the getValue queries first, so that we create success
400a76ee58fSStanislav Funiak // matcher in the correct (possibly nested) region.
401a76ee58fSStanislav Funiak SmallVector<Value> args;
402a76ee58fSStanislav Funiak if (auto *equalToQuestion = dyn_cast<EqualToQuestion>(question)) {
403a76ee58fSStanislav Funiak args = {getValueAt(currentBlock, equalToQuestion->getValue())};
404a76ee58fSStanislav Funiak } else if (auto *cstQuestion = dyn_cast<ConstraintQuestion>(question)) {
405233e9476SRiver Riddle for (Position *position : cstQuestion->getArgs())
406a76ee58fSStanislav Funiak args.push_back(getValueAt(currentBlock, position));
407a76ee58fSStanislav Funiak }
408a76ee58fSStanislav Funiak
409*8ec28af8SMatthias Gehre // Generate a new block as success successor and get the failure successor.
410*8ec28af8SMatthias Gehre Block *success = ®ion->emplaceBlock();
411a76ee58fSStanislav Funiak Block *failure = failureBlockStack.back();
412a76ee58fSStanislav Funiak
413*8ec28af8SMatthias Gehre // Create the predicate.
414a76ee58fSStanislav Funiak builder.setInsertionPointToEnd(currentBlock);
4153a833a0eSRiver Riddle Predicates::Kind kind = question->getKind();
4163a833a0eSRiver Riddle switch (kind) {
4178a1ca2cdSRiver Riddle case Predicates::IsNotNullQuestion:
418a76ee58fSStanislav Funiak builder.create<pdl_interp::IsNotNullOp>(loc, val, success, failure);
4198a1ca2cdSRiver Riddle break;
4208a1ca2cdSRiver Riddle case Predicates::OperationNameQuestion: {
4218a1ca2cdSRiver Riddle auto *opNameAnswer = cast<OperationNameAnswer>(answer);
4228a1ca2cdSRiver Riddle builder.create<pdl_interp::CheckOperationNameOp>(
423a76ee58fSStanislav Funiak loc, val, opNameAnswer->getValue().getStringRef(), success, failure);
4248a1ca2cdSRiver Riddle break;
4258a1ca2cdSRiver Riddle }
4268a1ca2cdSRiver Riddle case Predicates::TypeQuestion: {
4278a1ca2cdSRiver Riddle auto *ans = cast<TypeAnswer>(answer);
4285550c821STres Popp if (isa<pdl::RangeType>(val.getType()))
4293a833a0eSRiver Riddle builder.create<pdl_interp::CheckTypesOp>(
43068f58812STres Popp loc, val, llvm::cast<ArrayAttr>(ans->getValue()), success, failure);
4313a833a0eSRiver Riddle else
4328a1ca2cdSRiver Riddle builder.create<pdl_interp::CheckTypeOp>(
43368f58812STres Popp loc, val, llvm::cast<TypeAttr>(ans->getValue()), success, failure);
4348a1ca2cdSRiver Riddle break;
4358a1ca2cdSRiver Riddle }
4368a1ca2cdSRiver Riddle case Predicates::AttributeQuestion: {
4378a1ca2cdSRiver Riddle auto *ans = cast<AttributeAnswer>(answer);
4388a1ca2cdSRiver Riddle builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(),
439a76ee58fSStanislav Funiak success, failure);
4408a1ca2cdSRiver Riddle break;
4418a1ca2cdSRiver Riddle }
4423a833a0eSRiver Riddle case Predicates::OperandCountAtLeastQuestion:
4433a833a0eSRiver Riddle case Predicates::OperandCountQuestion:
4448a1ca2cdSRiver Riddle builder.create<pdl_interp::CheckOperandCountOp>(
4453a833a0eSRiver Riddle loc, val, cast<UnsignedAnswer>(answer)->getValue(),
4463a833a0eSRiver Riddle /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion,
447a76ee58fSStanislav Funiak success, failure);
4488a1ca2cdSRiver Riddle break;
4493a833a0eSRiver Riddle case Predicates::ResultCountAtLeastQuestion:
4503a833a0eSRiver Riddle case Predicates::ResultCountQuestion:
4518a1ca2cdSRiver Riddle builder.create<pdl_interp::CheckResultCountOp>(
4523a833a0eSRiver Riddle loc, val, cast<UnsignedAnswer>(answer)->getValue(),
4533a833a0eSRiver Riddle /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion,
454a76ee58fSStanislav Funiak success, failure);
4558a1ca2cdSRiver Riddle break;
4568a1ca2cdSRiver Riddle case Predicates::EqualToQuestion: {
457a76ee58fSStanislav Funiak bool trueAnswer = isa<TrueAnswer>(answer);
458a76ee58fSStanislav Funiak builder.create<pdl_interp::AreEqualOp>(loc, val, args.front(),
459a76ee58fSStanislav Funiak trueAnswer ? success : failure,
460a76ee58fSStanislav Funiak trueAnswer ? failure : success);
4618a1ca2cdSRiver Riddle break;
4628a1ca2cdSRiver Riddle }
4638a1ca2cdSRiver Riddle case Predicates::ConstraintQuestion: {
464233e9476SRiver Riddle auto *cstQuestion = cast<ConstraintQuestion>(question);
465*8ec28af8SMatthias Gehre auto applyConstraintOp = builder.create<pdl_interp::ApplyConstraintOp>(
466*8ec28af8SMatthias Gehre loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args,
467*8ec28af8SMatthias Gehre cstQuestion->getIsNegated(), success, failure);
468*8ec28af8SMatthias Gehre
469*8ec28af8SMatthias Gehre constraintOpMap.insert({cstQuestion, applyConstraintOp});
4708a1ca2cdSRiver Riddle break;
4718a1ca2cdSRiver Riddle }
4728a1ca2cdSRiver Riddle default:
4738a1ca2cdSRiver Riddle llvm_unreachable("Generating unknown Predicate operation");
4748a1ca2cdSRiver Riddle }
475*8ec28af8SMatthias Gehre
476*8ec28af8SMatthias Gehre // Generate the matcher in the current (potentially nested) region.
477*8ec28af8SMatthias Gehre // This might use the results of the current predicate.
478*8ec28af8SMatthias Gehre generateMatcher(*boolNode->getSuccessNode(), *region, success);
4798a1ca2cdSRiver Riddle }
4808a1ca2cdSRiver Riddle
4818a1ca2cdSRiver Riddle template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>
createSwitchOp(Value val,Block * defaultDest,OpBuilder & builder,llvm::MapVector<Qualifier *,Block * > & dests)4828a1ca2cdSRiver Riddle static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder,
4833a833a0eSRiver Riddle llvm::MapVector<Qualifier *, Block *> &dests) {
4848a1ca2cdSRiver Riddle std::vector<ValT> values;
4858a1ca2cdSRiver Riddle std::vector<Block *> blocks;
4868a1ca2cdSRiver Riddle values.reserve(dests.size());
4878a1ca2cdSRiver Riddle blocks.reserve(dests.size());
4888a1ca2cdSRiver Riddle for (const auto &it : dests) {
4898a1ca2cdSRiver Riddle blocks.push_back(it.second);
4908a1ca2cdSRiver Riddle values.push_back(cast<PredT>(it.first)->getValue());
4918a1ca2cdSRiver Riddle }
4928a1ca2cdSRiver Riddle builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks);
4938a1ca2cdSRiver Riddle }
4948a1ca2cdSRiver Riddle
generate(SwitchNode * switchNode,Block * currentBlock,Value val)495a76ee58fSStanislav Funiak void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock,
496a76ee58fSStanislav Funiak Value val) {
497a76ee58fSStanislav Funiak Qualifier *question = switchNode->getQuestion();
498a76ee58fSStanislav Funiak Region *region = currentBlock->getParent();
499a76ee58fSStanislav Funiak Block *defaultDest = failureBlockStack.back();
500a76ee58fSStanislav Funiak
5013a833a0eSRiver Riddle // If the switch question is not an exact answer, i.e. for the `at_least`
5023a833a0eSRiver Riddle // cases, we generate a special block sequence.
5033a833a0eSRiver Riddle Predicates::Kind kind = question->getKind();
5043a833a0eSRiver Riddle if (kind == Predicates::OperandCountAtLeastQuestion ||
5053a833a0eSRiver Riddle kind == Predicates::ResultCountAtLeastQuestion) {
5063a833a0eSRiver Riddle // Order the children such that the cases are in reverse numerical order.
5071d49e535SGuillaume Chatelet SmallVector<unsigned> sortedChildren = llvm::to_vector<16>(
5081d49e535SGuillaume Chatelet llvm::seq<unsigned>(0, switchNode->getChildren().size()));
5093a833a0eSRiver Riddle llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) {
5103a833a0eSRiver Riddle return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() >
5113a833a0eSRiver Riddle cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue();
5123a833a0eSRiver Riddle });
5133a833a0eSRiver Riddle
5143a833a0eSRiver Riddle // Build the destination for each child using the next highest child as a
5153a833a0eSRiver Riddle // a failure destination. This essentially creates the following control
5163a833a0eSRiver Riddle // flow:
5173a833a0eSRiver Riddle //
5183a833a0eSRiver Riddle // if (operand_count < 1)
5193a833a0eSRiver Riddle // goto failure
5203a833a0eSRiver Riddle // if (child1.match())
5213a833a0eSRiver Riddle // ...
5223a833a0eSRiver Riddle //
5233a833a0eSRiver Riddle // if (operand_count < 2)
5243a833a0eSRiver Riddle // goto failure
5253a833a0eSRiver Riddle // if (child2.match())
5263a833a0eSRiver Riddle // ...
5273a833a0eSRiver Riddle //
5283a833a0eSRiver Riddle // failure:
5293a833a0eSRiver Riddle // ...
5303a833a0eSRiver Riddle //
5313a833a0eSRiver Riddle failureBlockStack.push_back(defaultDest);
532a76ee58fSStanislav Funiak Location loc = val.getLoc();
5333a833a0eSRiver Riddle for (unsigned idx : sortedChildren) {
5343a833a0eSRiver Riddle auto &child = switchNode->getChild(idx);
535a76ee58fSStanislav Funiak Block *childBlock = generateMatcher(*child.second, *region);
5363a833a0eSRiver Riddle Block *predicateBlock = builder.createBlock(childBlock);
537a76ee58fSStanislav Funiak builder.setInsertionPointToEnd(predicateBlock);
538a76ee58fSStanislav Funiak unsigned ans = cast<UnsignedAnswer>(child.first)->getValue();
539a76ee58fSStanislav Funiak switch (kind) {
540a76ee58fSStanislav Funiak case Predicates::OperandCountAtLeastQuestion:
541a76ee58fSStanislav Funiak builder.create<pdl_interp::CheckOperandCountOp>(
542a76ee58fSStanislav Funiak loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest);
543a76ee58fSStanislav Funiak break;
544a76ee58fSStanislav Funiak case Predicates::ResultCountAtLeastQuestion:
545a76ee58fSStanislav Funiak builder.create<pdl_interp::CheckResultCountOp>(
546a76ee58fSStanislav Funiak loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest);
547a76ee58fSStanislav Funiak break;
548a76ee58fSStanislav Funiak default:
549a76ee58fSStanislav Funiak llvm_unreachable("Generating invalid AtLeast operation");
550a76ee58fSStanislav Funiak }
5513a833a0eSRiver Riddle failureBlockStack.back() = predicateBlock;
5523a833a0eSRiver Riddle }
5533a833a0eSRiver Riddle Block *firstPredicateBlock = failureBlockStack.pop_back_val();
5543a833a0eSRiver Riddle currentBlock->getOperations().splice(currentBlock->end(),
5553a833a0eSRiver Riddle firstPredicateBlock->getOperations());
5563a833a0eSRiver Riddle firstPredicateBlock->erase();
5573a833a0eSRiver Riddle return;
5583a833a0eSRiver Riddle }
5593a833a0eSRiver Riddle
5603a833a0eSRiver Riddle // Otherwise, generate each of the children and generate an interpreter
5613a833a0eSRiver Riddle // switch.
5623a833a0eSRiver Riddle llvm::MapVector<Qualifier *, Block *> children;
5633a833a0eSRiver Riddle for (auto &it : switchNode->getChildren())
564a76ee58fSStanislav Funiak children.insert({it.first, generateMatcher(*it.second, *region)});
5658a1ca2cdSRiver Riddle builder.setInsertionPointToEnd(currentBlock);
5663a833a0eSRiver Riddle
5678a1ca2cdSRiver Riddle switch (question->getKind()) {
5688a1ca2cdSRiver Riddle case Predicates::OperandCountQuestion:
5698a1ca2cdSRiver Riddle return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer,
5703a833a0eSRiver Riddle int32_t>(val, defaultDest, builder, children);
5718a1ca2cdSRiver Riddle case Predicates::ResultCountQuestion:
5728a1ca2cdSRiver Riddle return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer,
5733a833a0eSRiver Riddle int32_t>(val, defaultDest, builder, children);
5748a1ca2cdSRiver Riddle case Predicates::OperationNameQuestion:
5758a1ca2cdSRiver Riddle return createSwitchOp<pdl_interp::SwitchOperationNameOp,
5768a1ca2cdSRiver Riddle OperationNameAnswer>(val, defaultDest, builder,
5773a833a0eSRiver Riddle children);
5788a1ca2cdSRiver Riddle case Predicates::TypeQuestion:
5795550c821STres Popp if (isa<pdl::RangeType>(val.getType())) {
5803a833a0eSRiver Riddle return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>(
5813a833a0eSRiver Riddle val, defaultDest, builder, children);
5823a833a0eSRiver Riddle }
5838a1ca2cdSRiver Riddle return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>(
5843a833a0eSRiver Riddle val, defaultDest, builder, children);
5858a1ca2cdSRiver Riddle case Predicates::AttributeQuestion:
5868a1ca2cdSRiver Riddle return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>(
5873a833a0eSRiver Riddle val, defaultDest, builder, children);
5888a1ca2cdSRiver Riddle default:
5898a1ca2cdSRiver Riddle llvm_unreachable("Generating unknown switch predicate.");
5908a1ca2cdSRiver Riddle }
5918a1ca2cdSRiver Riddle }
5928a1ca2cdSRiver Riddle
generate(SuccessNode * successNode,Block * & currentBlock)593a76ee58fSStanislav Funiak void PatternLowering::generate(SuccessNode *successNode, Block *¤tBlock) {
594a76ee58fSStanislav Funiak pdl::PatternOp pattern = successNode->getPattern();
595a76ee58fSStanislav Funiak Value root = successNode->getRoot();
596a76ee58fSStanislav Funiak
5978a1ca2cdSRiver Riddle // Generate a rewriter for the pattern this success node represents, and track
5988a1ca2cdSRiver Riddle // any values used from the match region.
5998a1ca2cdSRiver Riddle SmallVector<Position *, 8> usedMatchValues;
6008a1ca2cdSRiver Riddle SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues);
6018a1ca2cdSRiver Riddle
6028a1ca2cdSRiver Riddle // Process any values used in the rewrite that are defined in the match.
6038a1ca2cdSRiver Riddle std::vector<Value> mappedMatchValues;
6048a1ca2cdSRiver Riddle mappedMatchValues.reserve(usedMatchValues.size());
6058a1ca2cdSRiver Riddle for (Position *position : usedMatchValues)
6068a1ca2cdSRiver Riddle mappedMatchValues.push_back(getValueAt(currentBlock, position));
6078a1ca2cdSRiver Riddle
6088a1ca2cdSRiver Riddle // Collect the set of operations generated by the rewriter.
6098a1ca2cdSRiver Riddle SmallVector<StringRef, 4> generatedOps;
61072fddfb5SRiver Riddle for (auto op :
61172fddfb5SRiver Riddle pattern.getRewriter().getBodyRegion().getOps<pdl::OperationOp>())
61272fddfb5SRiver Riddle generatedOps.push_back(*op.getOpName());
6138a1ca2cdSRiver Riddle ArrayAttr generatedOpsAttr;
6148a1ca2cdSRiver Riddle if (!generatedOps.empty())
6158a1ca2cdSRiver Riddle generatedOpsAttr = builder.getStrArrayAttr(generatedOps);
6168a1ca2cdSRiver Riddle
6178a1ca2cdSRiver Riddle // Grab the root kind if present.
6188a1ca2cdSRiver Riddle StringAttr rootKindAttr;
619a76ee58fSStanislav Funiak if (pdl::OperationOp rootOp = root.getDefiningOp<pdl::OperationOp>())
62022426110SRamkumar Ramachandra if (std::optional<StringRef> rootKind = rootOp.getOpName())
6218a1ca2cdSRiver Riddle rootKindAttr = builder.getStringAttr(*rootKind);
6228a1ca2cdSRiver Riddle
6238a1ca2cdSRiver Riddle builder.setInsertionPointToEnd(currentBlock);
6248c66344eSRiver Riddle auto matchOp = builder.create<pdl_interp::RecordMatchOp>(
6258a1ca2cdSRiver Riddle pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
626310c3ee4SRiver Riddle rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.getBenefitAttr(),
627a76ee58fSStanislav Funiak failureBlockStack.back());
6288c66344eSRiver Riddle
6298c66344eSRiver Riddle // Set the config of the lowered match to the parent pattern.
6308c66344eSRiver Riddle if (configMap)
6318c66344eSRiver Riddle configMap->try_emplace(matchOp, configMap->lookup(pattern));
6328a1ca2cdSRiver Riddle }
6338a1ca2cdSRiver Riddle
generateRewriter(pdl::PatternOp pattern,SmallVectorImpl<Position * > & usedMatchValues)6348a1ca2cdSRiver Riddle SymbolRefAttr PatternLowering::generateRewriter(
6358a1ca2cdSRiver Riddle pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) {
636f96a8675SRiver Riddle builder.setInsertionPointToEnd(rewriterModule.getBody());
637f96a8675SRiver Riddle auto rewriterFunc = builder.create<pdl_interp::FuncOp>(
638f96a8675SRiver Riddle pattern.getLoc(), "pdl_generated_rewriter",
6391a36588eSKazu Hirata builder.getFunctionType(std::nullopt, std::nullopt));
6408a1ca2cdSRiver Riddle rewriterSymbolTable.insert(rewriterFunc);
6418a1ca2cdSRiver Riddle
6428a1ca2cdSRiver Riddle // Generate the rewriter function body.
643f96a8675SRiver Riddle builder.setInsertionPointToEnd(&rewriterFunc.front());
6448a1ca2cdSRiver Riddle
6458a1ca2cdSRiver Riddle // Map an input operand of the pattern to a generated interpreter value.
6468a1ca2cdSRiver Riddle DenseMap<Value, Value> rewriteValues;
6478a1ca2cdSRiver Riddle auto mapRewriteValue = [&](Value oldValue) {
6488a1ca2cdSRiver Riddle Value &newValue = rewriteValues[oldValue];
6498a1ca2cdSRiver Riddle if (newValue)
6508a1ca2cdSRiver Riddle return newValue;
6518a1ca2cdSRiver Riddle
6528a1ca2cdSRiver Riddle // Prefer materializing constants directly when possible.
6538a1ca2cdSRiver Riddle Operation *oldOp = oldValue.getDefiningOp();
6548a1ca2cdSRiver Riddle if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) {
655310c3ee4SRiver Riddle if (Attribute value = attrOp.getValueAttr()) {
6568a1ca2cdSRiver Riddle return newValue = builder.create<pdl_interp::CreateAttributeOp>(
6578a1ca2cdSRiver Riddle attrOp.getLoc(), value);
6588a1ca2cdSRiver Riddle }
6598a1ca2cdSRiver Riddle } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) {
66072fddfb5SRiver Riddle if (TypeAttr type = typeOp.getConstantTypeAttr()) {
6618a1ca2cdSRiver Riddle return newValue = builder.create<pdl_interp::CreateTypeOp>(
6628a1ca2cdSRiver Riddle typeOp.getLoc(), type);
6638a1ca2cdSRiver Riddle }
6643a833a0eSRiver Riddle } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) {
66572fddfb5SRiver Riddle if (ArrayAttr type = typeOp.getConstantTypesAttr()) {
6663a833a0eSRiver Riddle return newValue = builder.create<pdl_interp::CreateTypesOp>(
6673a833a0eSRiver Riddle typeOp.getLoc(), typeOp.getType(), type);
6683a833a0eSRiver Riddle }
6698a1ca2cdSRiver Riddle }
6708a1ca2cdSRiver Riddle
6718a1ca2cdSRiver Riddle // Otherwise, add this as an input to the rewriter.
6728a1ca2cdSRiver Riddle Position *inputPos = valueToPosition.lookup(oldValue);
6738a1ca2cdSRiver Riddle assert(inputPos && "expected value to be a pattern input");
6748a1ca2cdSRiver Riddle usedMatchValues.push_back(inputPos);
675e084679fSRiver Riddle return newValue = rewriterFunc.front().addArgument(oldValue.getType(),
676e084679fSRiver Riddle oldValue.getLoc());
6778a1ca2cdSRiver Riddle };
6788a1ca2cdSRiver Riddle
6798a1ca2cdSRiver Riddle // If this is a custom rewriter, simply dispatch to the registered rewrite
6808a1ca2cdSRiver Riddle // method.
6818a1ca2cdSRiver Riddle pdl::RewriteOp rewriter = pattern.getRewriter();
682310c3ee4SRiver Riddle if (StringAttr rewriteName = rewriter.getNameAttr()) {
683a76ee58fSStanislav Funiak SmallVector<Value> args;
684310c3ee4SRiver Riddle if (rewriter.getRoot())
685310c3ee4SRiver Riddle args.push_back(mapRewriteValue(rewriter.getRoot()));
686310c3ee4SRiver Riddle auto mappedArgs =
687310c3ee4SRiver Riddle llvm::map_range(rewriter.getExternalArgs(), mapRewriteValue);
68802c4c0d5SRiver Riddle args.append(mappedArgs.begin(), mappedArgs.end());
6898a1ca2cdSRiver Riddle builder.create<pdl_interp::ApplyRewriteOp>(
6909595f356SRiver Riddle rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args);
6918a1ca2cdSRiver Riddle } else {
6928a1ca2cdSRiver Riddle // Otherwise this is a dag rewriter defined using PDL operations.
6938a1ca2cdSRiver Riddle for (Operation &rewriteOp : *rewriter.getBody()) {
6948a1ca2cdSRiver Riddle llvm::TypeSwitch<Operation *>(&rewriteOp)
69502c4c0d5SRiver Riddle .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp,
696ce57789dSRiver Riddle pdl::OperationOp, pdl::RangeOp, pdl::ReplaceOp, pdl::ResultOp,
697ce57789dSRiver Riddle pdl::ResultsOp, pdl::TypeOp, pdl::TypesOp>([&](auto op) {
6988a1ca2cdSRiver Riddle this->generateRewriter(op, rewriteValues, mapRewriteValue);
6998a1ca2cdSRiver Riddle });
7008a1ca2cdSRiver Riddle }
7018a1ca2cdSRiver Riddle }
7028a1ca2cdSRiver Riddle
7038a1ca2cdSRiver Riddle // Update the signature of the rewrite function.
7048a1ca2cdSRiver Riddle rewriterFunc.setType(builder.getFunctionType(
7058a1ca2cdSRiver Riddle llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()),
7061a36588eSKazu Hirata /*results=*/std::nullopt));
7078a1ca2cdSRiver Riddle
7088a1ca2cdSRiver Riddle builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc());
709faf1c224SChris Lattner return SymbolRefAttr::get(
710faf1c224SChris Lattner builder.getContext(),
7118a1ca2cdSRiver Riddle pdl_interp::PDLInterpDialect::getRewriterModuleName(),
712faf1c224SChris Lattner SymbolRefAttr::get(rewriterFunc));
7138a1ca2cdSRiver Riddle }
7148a1ca2cdSRiver Riddle
generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)7158a1ca2cdSRiver Riddle void PatternLowering::generateRewriter(
71602c4c0d5SRiver Riddle pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues,
71702c4c0d5SRiver Riddle function_ref<Value(Value)> mapRewriteValue) {
71802c4c0d5SRiver Riddle SmallVector<Value, 2> arguments;
719310c3ee4SRiver Riddle for (Value argument : rewriteOp.getArgs())
72002c4c0d5SRiver Riddle arguments.push_back(mapRewriteValue(argument));
72102c4c0d5SRiver Riddle auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>(
722310c3ee4SRiver Riddle rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.getNameAttr(),
7239595f356SRiver Riddle arguments);
7249595f356SRiver Riddle for (auto it : llvm::zip(rewriteOp.getResults(), interpOp.getResults()))
72502c4c0d5SRiver Riddle rewriteValues[std::get<0>(it)] = std::get<1>(it);
72602c4c0d5SRiver Riddle }
72702c4c0d5SRiver Riddle
generateRewriter(pdl::AttributeOp attrOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)72802c4c0d5SRiver Riddle void PatternLowering::generateRewriter(
7298a1ca2cdSRiver Riddle pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues,
7308a1ca2cdSRiver Riddle function_ref<Value(Value)> mapRewriteValue) {
7318a1ca2cdSRiver Riddle Value newAttr = builder.create<pdl_interp::CreateAttributeOp>(
732310c3ee4SRiver Riddle attrOp.getLoc(), attrOp.getValueAttr());
7338a1ca2cdSRiver Riddle rewriteValues[attrOp] = newAttr;
7348a1ca2cdSRiver Riddle }
7358a1ca2cdSRiver Riddle
generateRewriter(pdl::EraseOp eraseOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)7368a1ca2cdSRiver Riddle void PatternLowering::generateRewriter(
7378a1ca2cdSRiver Riddle pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues,
7388a1ca2cdSRiver Riddle function_ref<Value(Value)> mapRewriteValue) {
7398a1ca2cdSRiver Riddle builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(),
74072fddfb5SRiver Riddle mapRewriteValue(eraseOp.getOpValue()));
7418a1ca2cdSRiver Riddle }
7428a1ca2cdSRiver Riddle
generateRewriter(pdl::OperationOp operationOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)7438a1ca2cdSRiver Riddle void PatternLowering::generateRewriter(
7448a1ca2cdSRiver Riddle pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues,
7458a1ca2cdSRiver Riddle function_ref<Value(Value)> mapRewriteValue) {
7468a1ca2cdSRiver Riddle SmallVector<Value, 4> operands;
74772fddfb5SRiver Riddle for (Value operand : operationOp.getOperandValues())
7488a1ca2cdSRiver Riddle operands.push_back(mapRewriteValue(operand));
7498a1ca2cdSRiver Riddle
7508a1ca2cdSRiver Riddle SmallVector<Value, 4> attributes;
75172fddfb5SRiver Riddle for (Value attr : operationOp.getAttributeValues())
7528a1ca2cdSRiver Riddle attributes.push_back(mapRewriteValue(attr));
7538a1ca2cdSRiver Riddle
7543c752289SRiver Riddle bool hasInferredResultTypes = false;
7558a1ca2cdSRiver Riddle SmallVector<Value, 2> types;
7563c752289SRiver Riddle generateOperationResultTypeRewriter(operationOp, mapRewriteValue, types,
7573c752289SRiver Riddle rewriteValues, hasInferredResultTypes);
7588a1ca2cdSRiver Riddle
7598a1ca2cdSRiver Riddle // Create the new operation.
7608a1ca2cdSRiver Riddle Location loc = operationOp.getLoc();
7618a1ca2cdSRiver Riddle Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
76272fddfb5SRiver Riddle loc, *operationOp.getOpName(), types, hasInferredResultTypes, operands,
76372fddfb5SRiver Riddle attributes, operationOp.getAttributeValueNames());
764310c3ee4SRiver Riddle rewriteValues[operationOp.getOp()] = createdOp;
7658a1ca2cdSRiver Riddle
766242762c9SRiver Riddle // Generate accesses for any results that have their types constrained.
7673a833a0eSRiver Riddle // Handle the case where there is a single range representing all of the
7683a833a0eSRiver Riddle // result types.
76972fddfb5SRiver Riddle OperandRange resultTys = operationOp.getTypeValues();
7705550c821STres Popp if (resultTys.size() == 1 && isa<pdl::RangeType>(resultTys[0].getType())) {
7713a833a0eSRiver Riddle Value &type = rewriteValues[resultTys[0]];
7723a833a0eSRiver Riddle if (!type) {
7733a833a0eSRiver Riddle auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp);
7743a833a0eSRiver Riddle type = builder.create<pdl_interp::GetValueTypeOp>(loc, results);
7753a833a0eSRiver Riddle }
7763a833a0eSRiver Riddle return;
7773a833a0eSRiver Riddle }
7783a833a0eSRiver Riddle
7793a833a0eSRiver Riddle // Otherwise, populate the individual results.
7803a833a0eSRiver Riddle bool seenVariableLength = false;
7813a833a0eSRiver Riddle Type valueTy = builder.getType<pdl::ValueType>();
7823a833a0eSRiver Riddle Type valueRangeTy = pdl::RangeType::get(valueTy);
783e4853be2SMehdi Amini for (const auto &it : llvm::enumerate(resultTys)) {
784242762c9SRiver Riddle Value &type = rewriteValues[it.value()];
785242762c9SRiver Riddle if (type)
786242762c9SRiver Riddle continue;
7875550c821STres Popp bool isVariadic = isa<pdl::RangeType>(it.value().getType());
7883a833a0eSRiver Riddle seenVariableLength |= isVariadic;
789242762c9SRiver Riddle
7903a833a0eSRiver Riddle // After a variable length result has been seen, we need to use result
7913a833a0eSRiver Riddle // groups because the exact index of the result is not statically known.
7923a833a0eSRiver Riddle Value resultVal;
7933a833a0eSRiver Riddle if (seenVariableLength)
7943a833a0eSRiver Riddle resultVal = builder.create<pdl_interp::GetResultsOp>(
7953a833a0eSRiver Riddle loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index());
7963a833a0eSRiver Riddle else
7973a833a0eSRiver Riddle resultVal = builder.create<pdl_interp::GetResultOp>(
7983a833a0eSRiver Riddle loc, valueTy, createdOp, it.index());
7993a833a0eSRiver Riddle type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal);
8008a1ca2cdSRiver Riddle }
8018a1ca2cdSRiver Riddle }
8028a1ca2cdSRiver Riddle
generateRewriter(pdl::RangeOp rangeOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)8038a1ca2cdSRiver Riddle void PatternLowering::generateRewriter(
804ce57789dSRiver Riddle pdl::RangeOp rangeOp, DenseMap<Value, Value> &rewriteValues,
805ce57789dSRiver Riddle function_ref<Value(Value)> mapRewriteValue) {
806ce57789dSRiver Riddle SmallVector<Value, 4> replOperands;
807ce57789dSRiver Riddle for (Value operand : rangeOp.getArguments())
808ce57789dSRiver Riddle replOperands.push_back(mapRewriteValue(operand));
809ce57789dSRiver Riddle rewriteValues[rangeOp] = builder.create<pdl_interp::CreateRangeOp>(
810ce57789dSRiver Riddle rangeOp.getLoc(), rangeOp.getType(), replOperands);
811ce57789dSRiver Riddle }
812ce57789dSRiver Riddle
generateRewriter(pdl::ReplaceOp replaceOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)813ce57789dSRiver Riddle void PatternLowering::generateRewriter(
8148a1ca2cdSRiver Riddle pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
8158a1ca2cdSRiver Riddle function_ref<Value(Value)> mapRewriteValue) {
816242762c9SRiver Riddle SmallVector<Value, 4> replOperands;
817242762c9SRiver Riddle
8188a1ca2cdSRiver Riddle // If the replacement was another operation, get its results. `pdl` allows
8198a1ca2cdSRiver Riddle // for using an operation for simplicitly, but the interpreter isn't as
8208a1ca2cdSRiver Riddle // user facing.
821310c3ee4SRiver Riddle if (Value replOp = replaceOp.getReplOperation()) {
8223a833a0eSRiver Riddle // Don't use replace if we know the replaced operation has no results.
82372fddfb5SRiver Riddle auto opOp = replaceOp.getOpValue().getDefiningOp<pdl::OperationOp>();
82472fddfb5SRiver Riddle if (!opOp || !opOp.getTypeValues().empty()) {
8253a833a0eSRiver Riddle replOperands.push_back(builder.create<pdl_interp::GetResultsOp>(
8263a833a0eSRiver Riddle replOp.getLoc(), mapRewriteValue(replOp)));
8273a833a0eSRiver Riddle }
828242762c9SRiver Riddle } else {
829310c3ee4SRiver Riddle for (Value operand : replaceOp.getReplValues())
830242762c9SRiver Riddle replOperands.push_back(mapRewriteValue(operand));
831242762c9SRiver Riddle }
8328a1ca2cdSRiver Riddle
8338a1ca2cdSRiver Riddle // If there are no replacement values, just create an erase instead.
834242762c9SRiver Riddle if (replOperands.empty()) {
83572fddfb5SRiver Riddle builder.create<pdl_interp::EraseOp>(
83672fddfb5SRiver Riddle replaceOp.getLoc(), mapRewriteValue(replaceOp.getOpValue()));
8378a1ca2cdSRiver Riddle return;
8388a1ca2cdSRiver Riddle }
8398a1ca2cdSRiver Riddle
84072fddfb5SRiver Riddle builder.create<pdl_interp::ReplaceOp>(replaceOp.getLoc(),
84172fddfb5SRiver Riddle mapRewriteValue(replaceOp.getOpValue()),
84272fddfb5SRiver Riddle replOperands);
8438a1ca2cdSRiver Riddle }
8448a1ca2cdSRiver Riddle
generateRewriter(pdl::ResultOp resultOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)8458a1ca2cdSRiver Riddle void PatternLowering::generateRewriter(
846242762c9SRiver Riddle pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues,
847242762c9SRiver Riddle function_ref<Value(Value)> mapRewriteValue) {
848242762c9SRiver Riddle rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>(
849242762c9SRiver Riddle resultOp.getLoc(), builder.getType<pdl::ValueType>(),
850310c3ee4SRiver Riddle mapRewriteValue(resultOp.getParent()), resultOp.getIndex());
851242762c9SRiver Riddle }
852242762c9SRiver Riddle
generateRewriter(pdl::ResultsOp resultOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)853242762c9SRiver Riddle void PatternLowering::generateRewriter(
8543a833a0eSRiver Riddle pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues,
8553a833a0eSRiver Riddle function_ref<Value(Value)> mapRewriteValue) {
8563a833a0eSRiver Riddle rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>(
857310c3ee4SRiver Riddle resultOp.getLoc(), resultOp.getType(),
858310c3ee4SRiver Riddle mapRewriteValue(resultOp.getParent()), resultOp.getIndex());
8593a833a0eSRiver Riddle }
8603a833a0eSRiver Riddle
generateRewriter(pdl::TypeOp typeOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)8613a833a0eSRiver Riddle void PatternLowering::generateRewriter(
8628a1ca2cdSRiver Riddle pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues,
8638a1ca2cdSRiver Riddle function_ref<Value(Value)> mapRewriteValue) {
8648a1ca2cdSRiver Riddle // If the type isn't constant, the users (e.g. OperationOp) will resolve this
8658a1ca2cdSRiver Riddle // type.
86672fddfb5SRiver Riddle if (TypeAttr typeAttr = typeOp.getConstantTypeAttr()) {
8673a833a0eSRiver Riddle rewriteValues[typeOp] =
8688a1ca2cdSRiver Riddle builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr);
8693a833a0eSRiver Riddle }
8703a833a0eSRiver Riddle }
8713a833a0eSRiver Riddle
generateRewriter(pdl::TypesOp typeOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)8723a833a0eSRiver Riddle void PatternLowering::generateRewriter(
8733a833a0eSRiver Riddle pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues,
8743a833a0eSRiver Riddle function_ref<Value(Value)> mapRewriteValue) {
8753a833a0eSRiver Riddle // If the type isn't constant, the users (e.g. OperationOp) will resolve this
8763a833a0eSRiver Riddle // type.
87772fddfb5SRiver Riddle if (ArrayAttr typeAttr = typeOp.getConstantTypesAttr()) {
8783a833a0eSRiver Riddle rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>(
8793a833a0eSRiver Riddle typeOp.getLoc(), typeOp.getType(), typeAttr);
8808a1ca2cdSRiver Riddle }
8818a1ca2cdSRiver Riddle }
8828a1ca2cdSRiver Riddle
generateOperationResultTypeRewriter(pdl::OperationOp op,function_ref<Value (Value)> mapRewriteValue,SmallVectorImpl<Value> & types,DenseMap<Value,Value> & rewriteValues,bool & hasInferredResultTypes)8838a1ca2cdSRiver Riddle void PatternLowering::generateOperationResultTypeRewriter(
8843c752289SRiver Riddle pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
8853c752289SRiver Riddle SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
8863c752289SRiver Riddle bool &hasInferredResultTypes) {
887c4a04059SChristian Sigg Block *rewriterBlock = op->getBlock();
8883a833a0eSRiver Riddle
8893c752289SRiver Riddle // Try to handle resolution for each of the result types individually. This is
8903c752289SRiver Riddle // preferred over type inferrence because it will allow for us to use existing
8913c752289SRiver Riddle // types directly, as opposed to trying to rebuild the type list.
89272fddfb5SRiver Riddle OperandRange resultTypeValues = op.getTypeValues();
8933c752289SRiver Riddle auto tryResolveResultTypes = [&] {
8948a1ca2cdSRiver Riddle types.reserve(resultTypeValues.size());
895e4853be2SMehdi Amini for (const auto &it : llvm::enumerate(resultTypeValues)) {
896242762c9SRiver Riddle Value resultType = it.value();
8978a1ca2cdSRiver Riddle
8988a1ca2cdSRiver Riddle // Check for an already translated value.
8998a1ca2cdSRiver Riddle if (Value existingRewriteValue = rewriteValues.lookup(resultType)) {
9008a1ca2cdSRiver Riddle types.push_back(existingRewriteValue);
9018a1ca2cdSRiver Riddle continue;
9028a1ca2cdSRiver Riddle }
9038a1ca2cdSRiver Riddle
9048a1ca2cdSRiver Riddle // Check for an input from the matcher.
9058a1ca2cdSRiver Riddle if (resultType.getDefiningOp()->getBlock() != rewriterBlock) {
9068a1ca2cdSRiver Riddle types.push_back(mapRewriteValue(resultType));
9078a1ca2cdSRiver Riddle continue;
9088a1ca2cdSRiver Riddle }
9098a1ca2cdSRiver Riddle
9103c752289SRiver Riddle // Otherwise, we couldn't infer the result types. Bail out here to see if
9113c752289SRiver Riddle // we can infer the types for this operation from another way.
9123c752289SRiver Riddle types.clear();
9133c752289SRiver Riddle return failure();
9143c752289SRiver Riddle }
9153c752289SRiver Riddle return success();
9163c752289SRiver Riddle };
9173c752289SRiver Riddle if (!resultTypeValues.empty() && succeeded(tryResolveResultTypes()))
9183c752289SRiver Riddle return;
9193c752289SRiver Riddle
9203c752289SRiver Riddle // Otherwise, check if the operation has type inference support itself.
9213c752289SRiver Riddle if (op.hasTypeInference()) {
9223c752289SRiver Riddle hasInferredResultTypes = true;
9233c752289SRiver Riddle return;
9243c752289SRiver Riddle }
9253c752289SRiver Riddle
926b69f10f5SRiver Riddle // Look for an operation that was replaced by `op`. The result types will be
927b69f10f5SRiver Riddle // inferred from the results that were replaced.
928310c3ee4SRiver Riddle for (OpOperand &use : op.getOp().getUses()) {
929b69f10f5SRiver Riddle // Check that the use corresponds to a ReplaceOp and that it is the
930b69f10f5SRiver Riddle // replacement value, not the operation being replaced.
931b69f10f5SRiver Riddle pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner());
932b69f10f5SRiver Riddle if (!replOpUser || use.getOperandNumber() == 0)
933b69f10f5SRiver Riddle continue;
934b69f10f5SRiver Riddle // Make sure the replaced operation was defined before this one. PDL
935b69f10f5SRiver Riddle // rewrites only have single block regions, so if the op isn't in the
936b69f10f5SRiver Riddle // rewriter block (i.e. the current block of the operation) we already know
937b69f10f5SRiver Riddle // it dominates (i.e. it's in the matcher).
93872fddfb5SRiver Riddle Value replOpVal = replOpUser.getOpValue();
939b69f10f5SRiver Riddle Operation *replacedOp = replOpVal.getDefiningOp();
940b69f10f5SRiver Riddle if (replacedOp->getBlock() == rewriterBlock &&
941b69f10f5SRiver Riddle !replacedOp->isBeforeInBlock(op))
942b69f10f5SRiver Riddle continue;
943b69f10f5SRiver Riddle
944b69f10f5SRiver Riddle Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>(
945b69f10f5SRiver Riddle replacedOp->getLoc(), mapRewriteValue(replOpVal));
946b69f10f5SRiver Riddle types.push_back(builder.create<pdl_interp::GetValueTypeOp>(
947b69f10f5SRiver Riddle replacedOp->getLoc(), replacedOpResults));
948b69f10f5SRiver Riddle return;
949b69f10f5SRiver Riddle }
950b69f10f5SRiver Riddle
9513c752289SRiver Riddle // If the types could not be inferred from any context and there weren't any
9523c752289SRiver Riddle // explicit result types, assume the user actually meant for the operation to
9533c752289SRiver Riddle // have no results.
9543c752289SRiver Riddle if (resultTypeValues.empty())
9553c752289SRiver Riddle return;
9563c752289SRiver Riddle
957310c3ee4SRiver Riddle // The verifier asserts that the result types of each pdl.getOperation can be
9583a833a0eSRiver Riddle // inferred. If we reach here, there is a bug either in the logic above or
959310c3ee4SRiver Riddle // in the verifier for pdl.getOperation.
9603a833a0eSRiver Riddle op->emitOpError() << "unable to infer result type for operation";
9613a833a0eSRiver Riddle llvm_unreachable("unable to infer result type for operation");
9628a1ca2cdSRiver Riddle }
9638a1ca2cdSRiver Riddle
9648a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
9658a1ca2cdSRiver Riddle // Conversion Pass
9668a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
9678a1ca2cdSRiver Riddle
9688a1ca2cdSRiver Riddle namespace {
969039b969bSMichele Scuttari struct PDLToPDLInterpPass
97067d0d7acSMichele Scuttari : public impl::ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> {
9718c66344eSRiver Riddle PDLToPDLInterpPass() = default;
9728c66344eSRiver Riddle PDLToPDLInterpPass(const PDLToPDLInterpPass &rhs) = default;
PDLToPDLInterpPass__anon4598d56a0811::PDLToPDLInterpPass9738c66344eSRiver Riddle PDLToPDLInterpPass(DenseMap<Operation *, PDLPatternConfigSet *> &configMap)
9748c66344eSRiver Riddle : configMap(&configMap) {}
9758a1ca2cdSRiver Riddle void runOnOperation() final;
9768c66344eSRiver Riddle
9778c66344eSRiver Riddle /// A map containing the configuration for each pattern.
9788c66344eSRiver Riddle DenseMap<Operation *, PDLPatternConfigSet *> *configMap = nullptr;
9798a1ca2cdSRiver Riddle };
9808a1ca2cdSRiver Riddle } // namespace
9818a1ca2cdSRiver Riddle
9828a1ca2cdSRiver Riddle /// Convert the given module containing PDL pattern operations into a PDL
9838a1ca2cdSRiver Riddle /// Interpreter operations.
runOnOperation()984039b969bSMichele Scuttari void PDLToPDLInterpPass::runOnOperation() {
9858a1ca2cdSRiver Riddle ModuleOp module = getOperation();
9868a1ca2cdSRiver Riddle
9878a1ca2cdSRiver Riddle // Create the main matcher function This function contains all of the match
9888a1ca2cdSRiver Riddle // related functionality from patterns in the module.
9898a1ca2cdSRiver Riddle OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
990f96a8675SRiver Riddle auto matcherFunc = builder.create<pdl_interp::FuncOp>(
9918a1ca2cdSRiver Riddle module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
9928a1ca2cdSRiver Riddle builder.getFunctionType(builder.getType<pdl::OperationType>(),
9931a36588eSKazu Hirata /*results=*/std::nullopt),
9941a36588eSKazu Hirata /*attrs=*/std::nullopt);
9958a1ca2cdSRiver Riddle
9968a1ca2cdSRiver Riddle // Create a nested module to hold the functions invoked for rewriting the IR
9978a1ca2cdSRiver Riddle // after a successful match.
9988a1ca2cdSRiver Riddle ModuleOp rewriterModule = builder.create<ModuleOp>(
9998a1ca2cdSRiver Riddle module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName());
10008a1ca2cdSRiver Riddle
10018a1ca2cdSRiver Riddle // Generate the code for the patterns within the module.
10028c66344eSRiver Riddle PatternLowering generator(matcherFunc, rewriterModule, configMap);
10038a1ca2cdSRiver Riddle generator.lower(module);
10048a1ca2cdSRiver Riddle
10058a1ca2cdSRiver Riddle // After generation, delete all of the pattern operations.
10068a1ca2cdSRiver Riddle for (pdl::PatternOp pattern :
10078c66344eSRiver Riddle llvm::make_early_inc_range(module.getOps<pdl::PatternOp>())) {
10088c66344eSRiver Riddle // Drop the now dead config mappings.
10098c66344eSRiver Riddle if (configMap)
10108c66344eSRiver Riddle configMap->erase(pattern);
10118c66344eSRiver Riddle
10128a1ca2cdSRiver Riddle pattern.erase();
10138a1ca2cdSRiver Riddle }
10148c66344eSRiver Riddle }
10158a1ca2cdSRiver Riddle
createPDLToPDLInterpPass()10168a1ca2cdSRiver Riddle std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() {
1017039b969bSMichele Scuttari return std::make_unique<PDLToPDLInterpPass>();
10188a1ca2cdSRiver Riddle }
createPDLToPDLInterpPass(DenseMap<Operation *,PDLPatternConfigSet * > & configMap)10198c66344eSRiver Riddle std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass(
10208c66344eSRiver Riddle DenseMap<Operation *, PDLPatternConfigSet *> &configMap) {
10218c66344eSRiver Riddle return std::make_unique<PDLToPDLInterpPass>(configMap);
10228c66344eSRiver Riddle }
1023