1 //===- RegionGraphTraits.h - llvm::GraphTraits for CFGs ---------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements specializations of llvm::GraphTraits for various MLIR 10 // CFG data types. This allows the generic LLVM graph algorithms to be applied 11 // to CFGs. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #ifndef MLIR_IR_REGIONGRAPHTRAITS_H 16 #define MLIR_IR_REGIONGRAPHTRAITS_H 17 18 #include "mlir/IR/Region.h" 19 #include "llvm/ADT/GraphTraits.h" 20 21 namespace llvm { 22 template <> 23 struct GraphTraits<mlir::Block *> { 24 using ChildIteratorType = mlir::Block::succ_iterator; 25 using Node = mlir::Block; 26 using NodeRef = Node *; 27 28 static NodeRef getEntryNode(NodeRef bb) { return bb; } 29 30 static ChildIteratorType child_begin(NodeRef node) { 31 return node->succ_begin(); 32 } 33 static ChildIteratorType child_end(NodeRef node) { return node->succ_end(); } 34 }; 35 36 template <> 37 struct GraphTraits<Inverse<mlir::Block *>> { 38 using ChildIteratorType = mlir::Block::pred_iterator; 39 using Node = mlir::Block; 40 using NodeRef = Node *; 41 static NodeRef getEntryNode(Inverse<NodeRef> inverseGraph) { 42 return inverseGraph.Graph; 43 } 44 static inline ChildIteratorType child_begin(NodeRef node) { 45 return node->pred_begin(); 46 } 47 static inline ChildIteratorType child_end(NodeRef node) { 48 return node->pred_end(); 49 } 50 }; 51 52 template <> 53 struct GraphTraits<const mlir::Block *> { 54 using ChildIteratorType = mlir::Block::succ_iterator; 55 using Node = const mlir::Block; 56 using NodeRef = Node *; 57 58 static NodeRef getEntryNode(NodeRef node) { return node; } 59 60 static ChildIteratorType child_begin(NodeRef node) { 61 return const_cast<mlir::Block *>(node)->succ_begin(); 62 } 63 static ChildIteratorType child_end(NodeRef node) { 64 return const_cast<mlir::Block *>(node)->succ_end(); 65 } 66 }; 67 68 template <> 69 struct GraphTraits<Inverse<const mlir::Block *>> { 70 using ChildIteratorType = mlir::Block::pred_iterator; 71 using Node = const mlir::Block; 72 using NodeRef = Node *; 73 74 static NodeRef getEntryNode(Inverse<NodeRef> inverseGraph) { 75 return inverseGraph.Graph; 76 } 77 78 static ChildIteratorType child_begin(NodeRef node) { 79 return const_cast<mlir::Block *>(node)->pred_begin(); 80 } 81 static ChildIteratorType child_end(NodeRef node) { 82 return const_cast<mlir::Block *>(node)->pred_end(); 83 } 84 }; 85 86 template <> 87 struct GraphTraits<mlir::Region *> : public GraphTraits<mlir::Block *> { 88 using GraphType = mlir::Region *; 89 using NodeRef = mlir::Block *; 90 91 static NodeRef getEntryNode(GraphType fn) { return &fn->front(); } 92 93 using nodes_iterator = pointer_iterator<mlir::Region::iterator>; 94 static nodes_iterator nodes_begin(GraphType fn) { 95 return nodes_iterator(fn->begin()); 96 } 97 static nodes_iterator nodes_end(GraphType fn) { 98 return nodes_iterator(fn->end()); 99 } 100 }; 101 102 template <> 103 struct GraphTraits<Inverse<mlir::Region *>> 104 : public GraphTraits<Inverse<mlir::Block *>> { 105 using GraphType = Inverse<mlir::Region *>; 106 using NodeRef = NodeRef; 107 108 static NodeRef getEntryNode(GraphType fn) { return &fn.Graph->front(); } 109 110 using nodes_iterator = pointer_iterator<mlir::Region::iterator>; 111 static nodes_iterator nodes_begin(GraphType fn) { 112 return nodes_iterator(fn.Graph->begin()); 113 } 114 static nodes_iterator nodes_end(GraphType fn) { 115 return nodes_iterator(fn.Graph->end()); 116 } 117 }; 118 119 } // namespace llvm 120 121 #endif 122