18a1ca2cdSRiver Riddle //===- PredicateTree.h - Predicate tree node definitions --------*- C++ -*-===// 28a1ca2cdSRiver Riddle // 38a1ca2cdSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 48a1ca2cdSRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 58a1ca2cdSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 68a1ca2cdSRiver Riddle // 78a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 88a1ca2cdSRiver Riddle // 98a1ca2cdSRiver Riddle // This file contains definitions for nodes of a tree structure for representing 108a1ca2cdSRiver Riddle // the general control flow within a pattern match. 118a1ca2cdSRiver Riddle // 128a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 138a1ca2cdSRiver Riddle 148a1ca2cdSRiver Riddle #ifndef MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATETREE_H_ 158a1ca2cdSRiver Riddle #define MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATETREE_H_ 168a1ca2cdSRiver Riddle 178a1ca2cdSRiver Riddle #include "Predicate.h" 18e45840f4SRiver Riddle #include "mlir/Dialect/PDL/IR/PDLOps.h" 198a1ca2cdSRiver Riddle #include "llvm/ADT/MapVector.h" 208a1ca2cdSRiver Riddle 218a1ca2cdSRiver Riddle namespace mlir { 22672cc75cSRiver Riddle class ModuleOp; 23672cc75cSRiver Riddle 248a1ca2cdSRiver Riddle namespace pdl_to_pdl_interp { 258a1ca2cdSRiver Riddle 268a1ca2cdSRiver Riddle class MatcherNode; 278a1ca2cdSRiver Riddle 288a1ca2cdSRiver Riddle /// A PositionalPredicate is a predicate that is associated with a specific 298a1ca2cdSRiver Riddle /// positional value. 308a1ca2cdSRiver Riddle struct PositionalPredicate { PositionalPredicatePositionalPredicate318a1ca2cdSRiver Riddle PositionalPredicate(Position *pos, 328a1ca2cdSRiver Riddle const PredicateBuilder::Predicate &predicate) 338a1ca2cdSRiver Riddle : position(pos), question(predicate.first), answer(predicate.second) {} 348a1ca2cdSRiver Riddle 358a1ca2cdSRiver Riddle /// The position the predicate is applied to. 368a1ca2cdSRiver Riddle Position *position; 378a1ca2cdSRiver Riddle 388a1ca2cdSRiver Riddle /// The question that the predicate applies. 398a1ca2cdSRiver Riddle Qualifier *question; 408a1ca2cdSRiver Riddle 418a1ca2cdSRiver Riddle /// The expected answer of the predicate. 428a1ca2cdSRiver Riddle Qualifier *answer; 438a1ca2cdSRiver Riddle }; 448a1ca2cdSRiver Riddle 458a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 468a1ca2cdSRiver Riddle // MatcherNode 478a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 488a1ca2cdSRiver Riddle 498a1ca2cdSRiver Riddle /// This class represents the base of a predicate matcher node. 508a1ca2cdSRiver Riddle class MatcherNode { 518a1ca2cdSRiver Riddle public: 528a1ca2cdSRiver Riddle virtual ~MatcherNode() = default; 538a1ca2cdSRiver Riddle 548a1ca2cdSRiver Riddle /// Given a module containing PDL pattern operations, generate a matcher tree 558a1ca2cdSRiver Riddle /// using the patterns within the given module and return the root matcher 568a1ca2cdSRiver Riddle /// node. `valueToPosition` is a map that is populated with the original 578a1ca2cdSRiver Riddle /// pdl values and their corresponding positions in the matcher tree. 588a1ca2cdSRiver Riddle static std::unique_ptr<MatcherNode> 598a1ca2cdSRiver Riddle generateMatcherTree(ModuleOp module, PredicateBuilder &builder, 608a1ca2cdSRiver Riddle DenseMap<Value, Position *> &valueToPosition); 618a1ca2cdSRiver Riddle 628a1ca2cdSRiver Riddle /// Returns the position on which the question predicate should be checked. getPosition()638a1ca2cdSRiver Riddle Position *getPosition() const { return position; } 648a1ca2cdSRiver Riddle 658a1ca2cdSRiver Riddle /// Returns the predicate checked on this node. getQuestion()668a1ca2cdSRiver Riddle Qualifier *getQuestion() const { return question; } 678a1ca2cdSRiver Riddle 688a1ca2cdSRiver Riddle /// Returns the node that should be visited if this, or a subsequent node 698a1ca2cdSRiver Riddle /// fails. getFailureNode()708a1ca2cdSRiver Riddle std::unique_ptr<MatcherNode> &getFailureNode() { return failureNode; } 718a1ca2cdSRiver Riddle 728a1ca2cdSRiver Riddle /// Sets the node that should be visited if this, or a subsequent node fails. setFailureNode(std::unique_ptr<MatcherNode> node)738a1ca2cdSRiver Riddle void setFailureNode(std::unique_ptr<MatcherNode> node) { 748a1ca2cdSRiver Riddle failureNode = std::move(node); 758a1ca2cdSRiver Riddle } 768a1ca2cdSRiver Riddle 778a1ca2cdSRiver Riddle /// Returns the unique type ID of this matcher instance. This should not be 788a1ca2cdSRiver Riddle /// used directly, and is provided to support type casting. getMatcherTypeID()798a1ca2cdSRiver Riddle TypeID getMatcherTypeID() const { return matcherTypeID; } 808a1ca2cdSRiver Riddle 818a1ca2cdSRiver Riddle protected: 828a1ca2cdSRiver Riddle MatcherNode(TypeID matcherTypeID, Position *position = nullptr, 838a1ca2cdSRiver Riddle Qualifier *question = nullptr, 848a1ca2cdSRiver Riddle std::unique_ptr<MatcherNode> failureNode = nullptr); 858a1ca2cdSRiver Riddle 868a1ca2cdSRiver Riddle private: 878a1ca2cdSRiver Riddle /// The position on which the predicate should be checked. 888a1ca2cdSRiver Riddle Position *position; 898a1ca2cdSRiver Riddle 908a1ca2cdSRiver Riddle /// The predicate that is checked on the given position. 918a1ca2cdSRiver Riddle Qualifier *question; 928a1ca2cdSRiver Riddle 938a1ca2cdSRiver Riddle /// The node to visit if this node fails. 948a1ca2cdSRiver Riddle std::unique_ptr<MatcherNode> failureNode; 958a1ca2cdSRiver Riddle 968a1ca2cdSRiver Riddle /// An owning store for the failure node if it is owned by this node. 978a1ca2cdSRiver Riddle std::unique_ptr<MatcherNode> failureNodeStorage; 988a1ca2cdSRiver Riddle 998a1ca2cdSRiver Riddle /// A unique identifier for the derived matcher node, used for type casting. 1008a1ca2cdSRiver Riddle TypeID matcherTypeID; 1018a1ca2cdSRiver Riddle }; 1028a1ca2cdSRiver Riddle 1038a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 1048a1ca2cdSRiver Riddle // BoolNode 1058a1ca2cdSRiver Riddle 1068a1ca2cdSRiver Riddle /// A BoolNode denotes a question with a boolean-like result. These nodes branch 1078a1ca2cdSRiver Riddle /// to a single node on a successful result, otherwise defaulting to the failure 1088a1ca2cdSRiver Riddle /// node. 1098a1ca2cdSRiver Riddle struct BoolNode : public MatcherNode { 1108a1ca2cdSRiver Riddle BoolNode(Position *position, Qualifier *question, Qualifier *answer, 1118a1ca2cdSRiver Riddle std::unique_ptr<MatcherNode> successNode, 1128a1ca2cdSRiver Riddle std::unique_ptr<MatcherNode> failureNode = nullptr); 1138a1ca2cdSRiver Riddle 1148a1ca2cdSRiver Riddle /// Returns if the given matcher node is an instance of this class, used to 1158a1ca2cdSRiver Riddle /// support type casting. classofBoolNode1168a1ca2cdSRiver Riddle static bool classof(const MatcherNode *node) { 1178a1ca2cdSRiver Riddle return node->getMatcherTypeID() == TypeID::get<BoolNode>(); 1188a1ca2cdSRiver Riddle } 1198a1ca2cdSRiver Riddle 1208a1ca2cdSRiver Riddle /// Returns the expected answer of this boolean node. getAnswerBoolNode1218a1ca2cdSRiver Riddle Qualifier *getAnswer() const { return answer; } 1228a1ca2cdSRiver Riddle 1238a1ca2cdSRiver Riddle /// Returns the node that should be visited on success. getSuccessNodeBoolNode1248a1ca2cdSRiver Riddle std::unique_ptr<MatcherNode> &getSuccessNode() { return successNode; } 1258a1ca2cdSRiver Riddle 1268a1ca2cdSRiver Riddle private: 1278a1ca2cdSRiver Riddle /// The expected answer of this boolean node. 1288a1ca2cdSRiver Riddle Qualifier *answer; 1298a1ca2cdSRiver Riddle 1308a1ca2cdSRiver Riddle /// The next node if this node succeeds. Otherwise, go to the failure node. 1318a1ca2cdSRiver Riddle std::unique_ptr<MatcherNode> successNode; 1328a1ca2cdSRiver Riddle }; 1338a1ca2cdSRiver Riddle 1348a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 1358a1ca2cdSRiver Riddle // ExitNode 1368a1ca2cdSRiver Riddle 1378a1ca2cdSRiver Riddle /// An ExitNode is a special sentinel node that denotes the end of matcher. 1388a1ca2cdSRiver Riddle struct ExitNode : public MatcherNode { ExitNodeExitNode1398a1ca2cdSRiver Riddle ExitNode() : MatcherNode(TypeID::get<ExitNode>()) {} 1408a1ca2cdSRiver Riddle 1418a1ca2cdSRiver Riddle /// Returns if the given matcher node is an instance of this class, used to 1428a1ca2cdSRiver Riddle /// support type casting. classofExitNode1438a1ca2cdSRiver Riddle static bool classof(const MatcherNode *node) { 1448a1ca2cdSRiver Riddle return node->getMatcherTypeID() == TypeID::get<ExitNode>(); 1458a1ca2cdSRiver Riddle } 1468a1ca2cdSRiver Riddle }; 1478a1ca2cdSRiver Riddle 1488a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 1498a1ca2cdSRiver Riddle // SuccessNode 1508a1ca2cdSRiver Riddle 1518a1ca2cdSRiver Riddle /// A SuccessNode denotes that a given high level pattern has successfully been 1528a1ca2cdSRiver Riddle /// matched. This does not terminate the matcher, as there may be multiple 1538a1ca2cdSRiver Riddle /// successful matches. 1548a1ca2cdSRiver Riddle struct SuccessNode : public MatcherNode { 155a76ee58fSStanislav Funiak explicit SuccessNode(pdl::PatternOp pattern, Value root, 1568a1ca2cdSRiver Riddle std::unique_ptr<MatcherNode> failureNode); 1578a1ca2cdSRiver Riddle 1588a1ca2cdSRiver Riddle /// Returns if the given matcher node is an instance of this class, used to 1598a1ca2cdSRiver Riddle /// support type casting. classofSuccessNode1608a1ca2cdSRiver Riddle static bool classof(const MatcherNode *node) { 1618a1ca2cdSRiver Riddle return node->getMatcherTypeID() == TypeID::get<SuccessNode>(); 1628a1ca2cdSRiver Riddle } 1638a1ca2cdSRiver Riddle 1648a1ca2cdSRiver Riddle /// Return the high level pattern operation that is matched with this node. getPatternSuccessNode1658a1ca2cdSRiver Riddle pdl::PatternOp getPattern() const { return pattern; } 1668a1ca2cdSRiver Riddle 167a76ee58fSStanislav Funiak /// Return the chosen root of the pattern. getRootSuccessNode168a76ee58fSStanislav Funiak Value getRoot() const { return root; } 169a76ee58fSStanislav Funiak 1708a1ca2cdSRiver Riddle private: 1718a1ca2cdSRiver Riddle /// The high level pattern operation that was successfully matched with this 1728a1ca2cdSRiver Riddle /// node. 1738a1ca2cdSRiver Riddle pdl::PatternOp pattern; 174a76ee58fSStanislav Funiak 175a76ee58fSStanislav Funiak /// The chosen root of the pattern. 176a76ee58fSStanislav Funiak Value root; 1778a1ca2cdSRiver Riddle }; 1788a1ca2cdSRiver Riddle 1798a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 1808a1ca2cdSRiver Riddle // SwitchNode 1818a1ca2cdSRiver Riddle 1828a1ca2cdSRiver Riddle /// A SwitchNode denotes a question with multiple potential results. These nodes 1838a1ca2cdSRiver Riddle /// branch to a specific node based on the result of the question. 1848a1ca2cdSRiver Riddle struct SwitchNode : public MatcherNode { 1858a1ca2cdSRiver Riddle SwitchNode(Position *position, Qualifier *question); 1868a1ca2cdSRiver Riddle 1878a1ca2cdSRiver Riddle /// Returns if the given matcher node is an instance of this class, used to 1888a1ca2cdSRiver Riddle /// support type casting. classofSwitchNode1898a1ca2cdSRiver Riddle static bool classof(const MatcherNode *node) { 1908a1ca2cdSRiver Riddle return node->getMatcherTypeID() == TypeID::get<SwitchNode>(); 1918a1ca2cdSRiver Riddle } 1928a1ca2cdSRiver Riddle 1938a1ca2cdSRiver Riddle /// Returns the children of this switch node. The children are contained 1948a1ca2cdSRiver Riddle /// within a mapping between the various case answers to destination matcher 1958a1ca2cdSRiver Riddle /// nodes. 1968a1ca2cdSRiver Riddle using ChildMapT = llvm::MapVector<Qualifier *, std::unique_ptr<MatcherNode>>; getChildrenSwitchNode1978a1ca2cdSRiver Riddle ChildMapT &getChildren() { return children; } 1988a1ca2cdSRiver Riddle 1993a833a0eSRiver Riddle /// Returns the child at the given index. getChildSwitchNode2003a833a0eSRiver Riddle std::pair<Qualifier *, std::unique_ptr<MatcherNode>> &getChild(unsigned i) { 2013a833a0eSRiver Riddle assert(i < children.size() && "invalid child index"); 2023a833a0eSRiver Riddle return *std::next(children.begin(), i); 2033a833a0eSRiver Riddle } 2043a833a0eSRiver Riddle 2058a1ca2cdSRiver Riddle private: 2068a1ca2cdSRiver Riddle /// Switch predicate "answers" select the child. Answers that are not found 2078a1ca2cdSRiver Riddle /// default to the failure node. 2088a1ca2cdSRiver Riddle ChildMapT children; 2098a1ca2cdSRiver Riddle }; 2108a1ca2cdSRiver Riddle 211*be0a7e9fSMehdi Amini } // namespace pdl_to_pdl_interp 212*be0a7e9fSMehdi Amini } // namespace mlir 2138a1ca2cdSRiver Riddle 2148a1ca2cdSRiver Riddle #endif // MLIR_CONVERSION_PDLTOPDLINTERP_PREDICATETREE_H_ 215