xref: /llvm-project/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h (revision be0a7e9f27083ada6072fcc0711ffa5630daa5ec)
//===- PredicateTree.h - Predicate tree node definitions --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains definitions for nodes of a tree structure for representing
// the general control flow within a pattern match.
//
//===----------------------------------------------------------------------===//
138a1ca2cdSRiver Riddle 
148a1ca2cdSRiver Riddle #ifndef MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATETREE_H_
158a1ca2cdSRiver Riddle #define MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATETREE_H_
168a1ca2cdSRiver Riddle 
178a1ca2cdSRiver Riddle #include "Predicate.h"
18e45840f4SRiver Riddle #include "mlir/Dialect/PDL/IR/PDLOps.h"
198a1ca2cdSRiver Riddle #include "llvm/ADT/MapVector.h"
208a1ca2cdSRiver Riddle 
218a1ca2cdSRiver Riddle namespace mlir {
// Forward declaration: in this header ModuleOp is only named as the parameter
// type of MatcherNode::generateMatcherTree.
class ModuleOp;
23672cc75cSRiver Riddle 
248a1ca2cdSRiver Riddle namespace pdl_to_pdl_interp {
258a1ca2cdSRiver Riddle 
268a1ca2cdSRiver Riddle class MatcherNode;
278a1ca2cdSRiver Riddle 
288a1ca2cdSRiver Riddle /// A PositionalPredicate is a predicate that is associated with a specific
298a1ca2cdSRiver Riddle /// positional value.
308a1ca2cdSRiver Riddle struct PositionalPredicate {
PositionalPredicatePositionalPredicate318a1ca2cdSRiver Riddle   PositionalPredicate(Position *pos,
328a1ca2cdSRiver Riddle                       const PredicateBuilder::Predicate &predicate)
338a1ca2cdSRiver Riddle       : position(pos), question(predicate.first), answer(predicate.second) {}
348a1ca2cdSRiver Riddle 
358a1ca2cdSRiver Riddle   /// The position the predicate is applied to.
368a1ca2cdSRiver Riddle   Position *position;
378a1ca2cdSRiver Riddle 
388a1ca2cdSRiver Riddle   /// The question that the predicate applies.
398a1ca2cdSRiver Riddle   Qualifier *question;
408a1ca2cdSRiver Riddle 
418a1ca2cdSRiver Riddle   /// The expected answer of the predicate.
428a1ca2cdSRiver Riddle   Qualifier *answer;
438a1ca2cdSRiver Riddle };
448a1ca2cdSRiver Riddle 
458a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
468a1ca2cdSRiver Riddle // MatcherNode
478a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
488a1ca2cdSRiver Riddle 
498a1ca2cdSRiver Riddle /// This class represents the base of a predicate matcher node.
508a1ca2cdSRiver Riddle class MatcherNode {
518a1ca2cdSRiver Riddle public:
528a1ca2cdSRiver Riddle   virtual ~MatcherNode() = default;
538a1ca2cdSRiver Riddle 
548a1ca2cdSRiver Riddle   /// Given a module containing PDL pattern operations, generate a matcher tree
558a1ca2cdSRiver Riddle   /// using the patterns within the given module and return the root matcher
568a1ca2cdSRiver Riddle   /// node. `valueToPosition` is a map that is populated with the original
578a1ca2cdSRiver Riddle   /// pdl values and their corresponding positions in the matcher tree.
588a1ca2cdSRiver Riddle   static std::unique_ptr<MatcherNode>
598a1ca2cdSRiver Riddle   generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
608a1ca2cdSRiver Riddle                       DenseMap<Value, Position *> &valueToPosition);
618a1ca2cdSRiver Riddle 
628a1ca2cdSRiver Riddle   /// Returns the position on which the question predicate should be checked.
getPosition()638a1ca2cdSRiver Riddle   Position *getPosition() const { return position; }
648a1ca2cdSRiver Riddle 
658a1ca2cdSRiver Riddle   /// Returns the predicate checked on this node.
getQuestion()668a1ca2cdSRiver Riddle   Qualifier *getQuestion() const { return question; }
678a1ca2cdSRiver Riddle 
688a1ca2cdSRiver Riddle   /// Returns the node that should be visited if this, or a subsequent node
698a1ca2cdSRiver Riddle   /// fails.
getFailureNode()708a1ca2cdSRiver Riddle   std::unique_ptr<MatcherNode> &getFailureNode() { return failureNode; }
718a1ca2cdSRiver Riddle 
728a1ca2cdSRiver Riddle   /// Sets the node that should be visited if this, or a subsequent node fails.
setFailureNode(std::unique_ptr<MatcherNode> node)738a1ca2cdSRiver Riddle   void setFailureNode(std::unique_ptr<MatcherNode> node) {
748a1ca2cdSRiver Riddle     failureNode = std::move(node);
758a1ca2cdSRiver Riddle   }
768a1ca2cdSRiver Riddle 
778a1ca2cdSRiver Riddle   /// Returns the unique type ID of this matcher instance. This should not be
788a1ca2cdSRiver Riddle   /// used directly, and is provided to support type casting.
getMatcherTypeID()798a1ca2cdSRiver Riddle   TypeID getMatcherTypeID() const { return matcherTypeID; }
808a1ca2cdSRiver Riddle 
818a1ca2cdSRiver Riddle protected:
828a1ca2cdSRiver Riddle   MatcherNode(TypeID matcherTypeID, Position *position = nullptr,
838a1ca2cdSRiver Riddle               Qualifier *question = nullptr,
848a1ca2cdSRiver Riddle               std::unique_ptr<MatcherNode> failureNode = nullptr);
858a1ca2cdSRiver Riddle 
868a1ca2cdSRiver Riddle private:
878a1ca2cdSRiver Riddle   /// The position on which the predicate should be checked.
888a1ca2cdSRiver Riddle   Position *position;
898a1ca2cdSRiver Riddle 
908a1ca2cdSRiver Riddle   /// The predicate that is checked on the given position.
918a1ca2cdSRiver Riddle   Qualifier *question;
928a1ca2cdSRiver Riddle 
938a1ca2cdSRiver Riddle   /// The node to visit if this node fails.
948a1ca2cdSRiver Riddle   std::unique_ptr<MatcherNode> failureNode;
958a1ca2cdSRiver Riddle 
968a1ca2cdSRiver Riddle   /// An owning store for the failure node if it is owned by this node.
978a1ca2cdSRiver Riddle   std::unique_ptr<MatcherNode> failureNodeStorage;
988a1ca2cdSRiver Riddle 
998a1ca2cdSRiver Riddle   /// A unique identifier for the derived matcher node, used for type casting.
1008a1ca2cdSRiver Riddle   TypeID matcherTypeID;
1018a1ca2cdSRiver Riddle };
1028a1ca2cdSRiver Riddle 
1038a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
1048a1ca2cdSRiver Riddle // BoolNode
1058a1ca2cdSRiver Riddle 
1068a1ca2cdSRiver Riddle /// A BoolNode denotes a question with a boolean-like result. These nodes branch
1078a1ca2cdSRiver Riddle /// to a single node on a successful result, otherwise defaulting to the failure
1088a1ca2cdSRiver Riddle /// node.
1098a1ca2cdSRiver Riddle struct BoolNode : public MatcherNode {
1108a1ca2cdSRiver Riddle   BoolNode(Position *position, Qualifier *question, Qualifier *answer,
1118a1ca2cdSRiver Riddle            std::unique_ptr<MatcherNode> successNode,
1128a1ca2cdSRiver Riddle            std::unique_ptr<MatcherNode> failureNode = nullptr);
1138a1ca2cdSRiver Riddle 
1148a1ca2cdSRiver Riddle   /// Returns if the given matcher node is an instance of this class, used to
1158a1ca2cdSRiver Riddle   /// support type casting.
classofBoolNode1168a1ca2cdSRiver Riddle   static bool classof(const MatcherNode *node) {
1178a1ca2cdSRiver Riddle     return node->getMatcherTypeID() == TypeID::get<BoolNode>();
1188a1ca2cdSRiver Riddle   }
1198a1ca2cdSRiver Riddle 
1208a1ca2cdSRiver Riddle   /// Returns the expected answer of this boolean node.
getAnswerBoolNode1218a1ca2cdSRiver Riddle   Qualifier *getAnswer() const { return answer; }
1228a1ca2cdSRiver Riddle 
1238a1ca2cdSRiver Riddle   /// Returns the node that should be visited on success.
getSuccessNodeBoolNode1248a1ca2cdSRiver Riddle   std::unique_ptr<MatcherNode> &getSuccessNode() { return successNode; }
1258a1ca2cdSRiver Riddle 
1268a1ca2cdSRiver Riddle private:
1278a1ca2cdSRiver Riddle   /// The expected answer of this boolean node.
1288a1ca2cdSRiver Riddle   Qualifier *answer;
1298a1ca2cdSRiver Riddle 
1308a1ca2cdSRiver Riddle   /// The next node if this node succeeds. Otherwise, go to the failure node.
1318a1ca2cdSRiver Riddle   std::unique_ptr<MatcherNode> successNode;
1328a1ca2cdSRiver Riddle };
1338a1ca2cdSRiver Riddle 
1348a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
1358a1ca2cdSRiver Riddle // ExitNode
1368a1ca2cdSRiver Riddle 
1378a1ca2cdSRiver Riddle /// An ExitNode is a special sentinel node that denotes the end of matcher.
1388a1ca2cdSRiver Riddle struct ExitNode : public MatcherNode {
ExitNodeExitNode1398a1ca2cdSRiver Riddle   ExitNode() : MatcherNode(TypeID::get<ExitNode>()) {}
1408a1ca2cdSRiver Riddle 
1418a1ca2cdSRiver Riddle   /// Returns if the given matcher node is an instance of this class, used to
1428a1ca2cdSRiver Riddle   /// support type casting.
classofExitNode1438a1ca2cdSRiver Riddle   static bool classof(const MatcherNode *node) {
1448a1ca2cdSRiver Riddle     return node->getMatcherTypeID() == TypeID::get<ExitNode>();
1458a1ca2cdSRiver Riddle   }
1468a1ca2cdSRiver Riddle };
1478a1ca2cdSRiver Riddle 
1488a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
1498a1ca2cdSRiver Riddle // SuccessNode
1508a1ca2cdSRiver Riddle 
1518a1ca2cdSRiver Riddle /// A SuccessNode denotes that a given high level pattern has successfully been
1528a1ca2cdSRiver Riddle /// matched. This does not terminate the matcher, as there may be multiple
1538a1ca2cdSRiver Riddle /// successful matches.
1548a1ca2cdSRiver Riddle struct SuccessNode : public MatcherNode {
155a76ee58fSStanislav Funiak   explicit SuccessNode(pdl::PatternOp pattern, Value root,
1568a1ca2cdSRiver Riddle                        std::unique_ptr<MatcherNode> failureNode);
1578a1ca2cdSRiver Riddle 
1588a1ca2cdSRiver Riddle   /// Returns if the given matcher node is an instance of this class, used to
1598a1ca2cdSRiver Riddle   /// support type casting.
classofSuccessNode1608a1ca2cdSRiver Riddle   static bool classof(const MatcherNode *node) {
1618a1ca2cdSRiver Riddle     return node->getMatcherTypeID() == TypeID::get<SuccessNode>();
1628a1ca2cdSRiver Riddle   }
1638a1ca2cdSRiver Riddle 
1648a1ca2cdSRiver Riddle   /// Return the high level pattern operation that is matched with this node.
getPatternSuccessNode1658a1ca2cdSRiver Riddle   pdl::PatternOp getPattern() const { return pattern; }
1668a1ca2cdSRiver Riddle 
167a76ee58fSStanislav Funiak   /// Return the chosen root of the pattern.
getRootSuccessNode168a76ee58fSStanislav Funiak   Value getRoot() const { return root; }
169a76ee58fSStanislav Funiak 
1708a1ca2cdSRiver Riddle private:
1718a1ca2cdSRiver Riddle   /// The high level pattern operation that was successfully matched with this
1728a1ca2cdSRiver Riddle   /// node.
1738a1ca2cdSRiver Riddle   pdl::PatternOp pattern;
174a76ee58fSStanislav Funiak 
175a76ee58fSStanislav Funiak   /// The chosen root of the pattern.
176a76ee58fSStanislav Funiak   Value root;
1778a1ca2cdSRiver Riddle };
1788a1ca2cdSRiver Riddle 
1798a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
1808a1ca2cdSRiver Riddle // SwitchNode
1818a1ca2cdSRiver Riddle 
1828a1ca2cdSRiver Riddle /// A SwitchNode denotes a question with multiple potential results. These nodes
1838a1ca2cdSRiver Riddle /// branch to a specific node based on the result of the question.
1848a1ca2cdSRiver Riddle struct SwitchNode : public MatcherNode {
1858a1ca2cdSRiver Riddle   SwitchNode(Position *position, Qualifier *question);
1868a1ca2cdSRiver Riddle 
1878a1ca2cdSRiver Riddle   /// Returns if the given matcher node is an instance of this class, used to
1888a1ca2cdSRiver Riddle   /// support type casting.
classofSwitchNode1898a1ca2cdSRiver Riddle   static bool classof(const MatcherNode *node) {
1908a1ca2cdSRiver Riddle     return node->getMatcherTypeID() == TypeID::get<SwitchNode>();
1918a1ca2cdSRiver Riddle   }
1928a1ca2cdSRiver Riddle 
1938a1ca2cdSRiver Riddle   /// Returns the children of this switch node. The children are contained
1948a1ca2cdSRiver Riddle   /// within a mapping between the various case answers to destination matcher
1958a1ca2cdSRiver Riddle   /// nodes.
1968a1ca2cdSRiver Riddle   using ChildMapT = llvm::MapVector<Qualifier *, std::unique_ptr<MatcherNode>>;
getChildrenSwitchNode1978a1ca2cdSRiver Riddle   ChildMapT &getChildren() { return children; }
1988a1ca2cdSRiver Riddle 
1993a833a0eSRiver Riddle   /// Returns the child at the given index.
getChildSwitchNode2003a833a0eSRiver Riddle   std::pair<Qualifier *, std::unique_ptr<MatcherNode>> &getChild(unsigned i) {
2013a833a0eSRiver Riddle     assert(i < children.size() && "invalid child index");
2023a833a0eSRiver Riddle     return *std::next(children.begin(), i);
2033a833a0eSRiver Riddle   }
2043a833a0eSRiver Riddle 
2058a1ca2cdSRiver Riddle private:
2068a1ca2cdSRiver Riddle   /// Switch predicate "answers" select the child. Answers that are not found
2078a1ca2cdSRiver Riddle   /// default to the failure node.
2088a1ca2cdSRiver Riddle   ChildMapT children;
2098a1ca2cdSRiver Riddle };
2108a1ca2cdSRiver Riddle 
211*be0a7e9fSMehdi Amini } // namespace pdl_to_pdl_interp
212*be0a7e9fSMehdi Amini } // namespace mlir
2138a1ca2cdSRiver Riddle 
#endif // MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATETREE_H_
215