xref: /llvm-project/mlir/include/mlir/IR/RegionGraphTraits.h (revision 1ef51e0452a473f404edc635412685fce6f61004)
1 //===- RegionGraphTraits.h - llvm::GraphTraits for CFGs ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements specializations of llvm::GraphTraits for various MLIR
10 // CFG data types.  This allows the generic LLVM graph algorithms to be applied
11 // to CFGs.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef MLIR_IR_REGIONGRAPHTRAITS_H
16 #define MLIR_IR_REGIONGRAPHTRAITS_H
17 
18 #include "mlir/IR/Region.h"
19 #include "llvm/ADT/GraphTraits.h"
20 
21 namespace llvm {
22 template <>
23 struct GraphTraits<mlir::Block *> {
24   using ChildIteratorType = mlir::Block::succ_iterator;
25   using Node = mlir::Block;
26   using NodeRef = Node *;
27 
28   static NodeRef getEntryNode(NodeRef bb) { return bb; }
29 
30   static ChildIteratorType child_begin(NodeRef node) {
31     return node->succ_begin();
32   }
33   static ChildIteratorType child_end(NodeRef node) { return node->succ_end(); }
34 };
35 
36 template <>
37 struct GraphTraits<Inverse<mlir::Block *>> {
38   using ChildIteratorType = mlir::Block::pred_iterator;
39   using Node = mlir::Block;
40   using NodeRef = Node *;
41   static NodeRef getEntryNode(Inverse<NodeRef> inverseGraph) {
42     return inverseGraph.Graph;
43   }
44   static inline ChildIteratorType child_begin(NodeRef node) {
45     return node->pred_begin();
46   }
47   static inline ChildIteratorType child_end(NodeRef node) {
48     return node->pred_end();
49   }
50 };
51 
52 template <>
53 struct GraphTraits<const mlir::Block *> {
54   using ChildIteratorType = mlir::Block::succ_iterator;
55   using Node = const mlir::Block;
56   using NodeRef = Node *;
57 
58   static NodeRef getEntryNode(NodeRef node) { return node; }
59 
60   static ChildIteratorType child_begin(NodeRef node) {
61     return const_cast<mlir::Block *>(node)->succ_begin();
62   }
63   static ChildIteratorType child_end(NodeRef node) {
64     return const_cast<mlir::Block *>(node)->succ_end();
65   }
66 };
67 
68 template <>
69 struct GraphTraits<Inverse<const mlir::Block *>> {
70   using ChildIteratorType = mlir::Block::pred_iterator;
71   using Node = const mlir::Block;
72   using NodeRef = Node *;
73 
74   static NodeRef getEntryNode(Inverse<NodeRef> inverseGraph) {
75     return inverseGraph.Graph;
76   }
77 
78   static ChildIteratorType child_begin(NodeRef node) {
79     return const_cast<mlir::Block *>(node)->pred_begin();
80   }
81   static ChildIteratorType child_end(NodeRef node) {
82     return const_cast<mlir::Block *>(node)->pred_end();
83   }
84 };
85 
86 template <>
87 struct GraphTraits<mlir::Region *> : public GraphTraits<mlir::Block *> {
88   using GraphType = mlir::Region *;
89   using NodeRef = mlir::Block *;
90 
91   static NodeRef getEntryNode(GraphType fn) { return &fn->front(); }
92 
93   using nodes_iterator = pointer_iterator<mlir::Region::iterator>;
94   static nodes_iterator nodes_begin(GraphType fn) {
95     return nodes_iterator(fn->begin());
96   }
97   static nodes_iterator nodes_end(GraphType fn) {
98     return nodes_iterator(fn->end());
99   }
100 };
101 
102 template <>
103 struct GraphTraits<Inverse<mlir::Region *>>
104     : public GraphTraits<Inverse<mlir::Block *>> {
105   using GraphType = Inverse<mlir::Region *>;
106   using NodeRef = NodeRef;
107 
108   static NodeRef getEntryNode(GraphType fn) { return &fn.Graph->front(); }
109 
110   using nodes_iterator = pointer_iterator<mlir::Region::iterator>;
111   static nodes_iterator nodes_begin(GraphType fn) {
112     return nodes_iterator(fn.Graph->begin());
113   }
114   static nodes_iterator nodes_end(GraphType fn) {
115     return nodes_iterator(fn.Graph->end());
116   }
117 };
118 
119 } // namespace llvm
120 
121 #endif
122