xref: /llvm-project/mlir/unittests/Support/CyclicReplacerCacheTest.cpp (revision d5746d73cedcf7a593dc4b4f2ce2465e2d45750b)
1ec50f582SBilly Zhu //===- CyclicReplacerCacheTest.cpp ----------------------------------------===//
2ec50f582SBilly Zhu //
3ec50f582SBilly Zhu // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4ec50f582SBilly Zhu // See https://llvm.org/LICENSE.txt for license information.
5ec50f582SBilly Zhu // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6ec50f582SBilly Zhu //
7ec50f582SBilly Zhu //===----------------------------------------------------------------------===//
8ec50f582SBilly Zhu 
#include "mlir/Support/CyclicReplacerCache.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SetVector.h"
#include "gmock/gmock.h"
#include <cassert>
#include <functional>
#include <map>
#include <optional>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
15ec50f582SBilly Zhu 
16ec50f582SBilly Zhu using namespace mlir;
17ec50f582SBilly Zhu 
18ec50f582SBilly Zhu TEST(CachedCyclicReplacerTest, testNoRecursion) {
19ec50f582SBilly Zhu   CachedCyclicReplacer<int, bool> replacer(
20ec50f582SBilly Zhu       /*replacer=*/[](int n) { return static_cast<bool>(n); },
21ec50f582SBilly Zhu       /*cycleBreaker=*/[](int n) { return std::nullopt; });
22ec50f582SBilly Zhu 
23ec50f582SBilly Zhu   EXPECT_EQ(replacer(3), true);
24ec50f582SBilly Zhu   EXPECT_EQ(replacer(0), false);
25ec50f582SBilly Zhu }
26ec50f582SBilly Zhu 
27ec50f582SBilly Zhu TEST(CachedCyclicReplacerTest, testInPlaceRecursionPruneAnywhere) {
28ec50f582SBilly Zhu   // Replacer cycles through ints 0 -> 1 -> 2 -> 0 -> ...
2980ff3911SRyan Holt   std::optional<CachedCyclicReplacer<int, int>> replacer;
3080ff3911SRyan Holt   replacer.emplace(
3180ff3911SRyan Holt       /*replacer=*/[&](int n) { return (*replacer)((n + 1) % 3); },
32ec50f582SBilly Zhu       /*cycleBreaker=*/[&](int n) { return -1; });
33ec50f582SBilly Zhu 
34ec50f582SBilly Zhu   // Starting at 0.
3580ff3911SRyan Holt   EXPECT_EQ((*replacer)(0), -1);
36ec50f582SBilly Zhu   // Starting at 2.
3780ff3911SRyan Holt   EXPECT_EQ((*replacer)(2), -1);
38ec50f582SBilly Zhu }
39ec50f582SBilly Zhu 
40ec50f582SBilly Zhu //===----------------------------------------------------------------------===//
41ec50f582SBilly Zhu // CachedCyclicReplacer: ChainRecursion
42ec50f582SBilly Zhu //===----------------------------------------------------------------------===//
43ec50f582SBilly Zhu 
44ec50f582SBilly Zhu /// This set of tests uses a replacer function that maps ints into vectors of
45ec50f582SBilly Zhu /// ints.
46ec50f582SBilly Zhu ///
47ec50f582SBilly Zhu /// The replacement result for input `n` is the replacement result of `(n+1)%3`
48ec50f582SBilly Zhu /// appended with an element `42`. Theoretically, this will produce an
49ec50f582SBilly Zhu /// infinitely long vector. The cycle-breaker function prunes this infinite
50ec50f582SBilly Zhu /// recursion in the replacer logic by returning an empty vector upon the first
51ec50f582SBilly Zhu /// re-occurrence of an input value.
52ec50f582SBilly Zhu namespace {
53ec50f582SBilly Zhu class CachedCyclicReplacerChainRecursionPruningTest : public ::testing::Test {
54ec50f582SBilly Zhu public:
55ec50f582SBilly Zhu   // N ==> (N+1) % 3
56ec50f582SBilly Zhu   // This will create a chain of infinite length without recursion pruning.
57ec50f582SBilly Zhu   CachedCyclicReplacerChainRecursionPruningTest()
58ec50f582SBilly Zhu       : replacer(
59ec50f582SBilly Zhu             [&](int n) {
60ec50f582SBilly Zhu               ++invokeCount;
61ec50f582SBilly Zhu               std::vector<int> result = replacer((n + 1) % 3);
62ec50f582SBilly Zhu               result.push_back(42);
63ec50f582SBilly Zhu               return result;
64ec50f582SBilly Zhu             },
65ec50f582SBilly Zhu             [&](int n) -> std::optional<std::vector<int>> {
66ec50f582SBilly Zhu               return baseCase.value_or(n) == n
67ec50f582SBilly Zhu                          ? std::make_optional(std::vector<int>{})
68ec50f582SBilly Zhu                          : std::nullopt;
69ec50f582SBilly Zhu             }) {}
70ec50f582SBilly Zhu 
71ec50f582SBilly Zhu   std::vector<int> getChain(unsigned N) { return std::vector<int>(N, 42); };
72ec50f582SBilly Zhu 
73ec50f582SBilly Zhu   CachedCyclicReplacer<int, std::vector<int>> replacer;
74ec50f582SBilly Zhu   int invokeCount = 0;
75ec50f582SBilly Zhu   std::optional<int> baseCase = std::nullopt;
76ec50f582SBilly Zhu };
77ec50f582SBilly Zhu } // namespace
78ec50f582SBilly Zhu 
79ec50f582SBilly Zhu TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneAnywhere0) {
80ec50f582SBilly Zhu   // Starting at 0. Cycle length is 3.
81ec50f582SBilly Zhu   EXPECT_EQ(replacer(0), getChain(3));
82ec50f582SBilly Zhu   EXPECT_EQ(invokeCount, 3);
83ec50f582SBilly Zhu 
84ec50f582SBilly Zhu   // Starting at 1. Cycle length is 5 now because of a cached replacement at 0.
85ec50f582SBilly Zhu   invokeCount = 0;
86ec50f582SBilly Zhu   EXPECT_EQ(replacer(1), getChain(5));
87ec50f582SBilly Zhu   EXPECT_EQ(invokeCount, 2);
88ec50f582SBilly Zhu 
89ec50f582SBilly Zhu   // Starting at 2. Cycle length is 4. Entire result is cached.
90ec50f582SBilly Zhu   invokeCount = 0;
91ec50f582SBilly Zhu   EXPECT_EQ(replacer(2), getChain(4));
92ec50f582SBilly Zhu   EXPECT_EQ(invokeCount, 0);
93ec50f582SBilly Zhu }
94ec50f582SBilly Zhu 
95ec50f582SBilly Zhu TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneAnywhere1) {
96ec50f582SBilly Zhu   // Starting at 1. Cycle length is 3.
97ec50f582SBilly Zhu   EXPECT_EQ(replacer(1), getChain(3));
98ec50f582SBilly Zhu   EXPECT_EQ(invokeCount, 3);
99ec50f582SBilly Zhu }
100ec50f582SBilly Zhu 
101ec50f582SBilly Zhu TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneSpecific0) {
102ec50f582SBilly Zhu   baseCase = 0;
103ec50f582SBilly Zhu 
104ec50f582SBilly Zhu   // Starting at 0. Cycle length is 3.
105ec50f582SBilly Zhu   EXPECT_EQ(replacer(0), getChain(3));
106ec50f582SBilly Zhu   EXPECT_EQ(invokeCount, 3);
107ec50f582SBilly Zhu }
108ec50f582SBilly Zhu 
109ec50f582SBilly Zhu TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneSpecific1) {
110ec50f582SBilly Zhu   baseCase = 0;
111ec50f582SBilly Zhu 
112ec50f582SBilly Zhu   // Starting at 1. Cycle length is 5 (1 -> 2 -> 0 -> 1 -> 2 -> Prune).
113ec50f582SBilly Zhu   EXPECT_EQ(replacer(1), getChain(5));
114ec50f582SBilly Zhu   EXPECT_EQ(invokeCount, 5);
115ec50f582SBilly Zhu 
116ec50f582SBilly Zhu   // Starting at 0. Cycle length is 3. Entire result is cached.
117ec50f582SBilly Zhu   invokeCount = 0;
118ec50f582SBilly Zhu   EXPECT_EQ(replacer(0), getChain(3));
119ec50f582SBilly Zhu   EXPECT_EQ(invokeCount, 0);
120ec50f582SBilly Zhu }
121ec50f582SBilly Zhu 
122ec50f582SBilly Zhu //===----------------------------------------------------------------------===//
123ec50f582SBilly Zhu // CachedCyclicReplacer: GraphReplacement
124ec50f582SBilly Zhu //===----------------------------------------------------------------------===//
125ec50f582SBilly Zhu 
126ec50f582SBilly Zhu /// This set of tests uses a replacer function that maps from cyclic graphs to
127ec50f582SBilly Zhu /// trees, pruning out cycles in the process.
128ec50f582SBilly Zhu ///
129ec50f582SBilly Zhu /// It consists of two helper classes:
130ec50f582SBilly Zhu /// - Graph
131ec50f582SBilly Zhu ///   - A directed graph where nodes are non-negative integers.
132ec50f582SBilly Zhu /// - PrunedGraph
133ec50f582SBilly Zhu ///   - A Graph where edges that used to cause cycles are now represented with
134ec50f582SBilly Zhu ///     an indirection (a recursionId).
135ec50f582SBilly Zhu namespace {
136ec50f582SBilly Zhu class CachedCyclicReplacerGraphReplacement : public ::testing::Test {
137ec50f582SBilly Zhu public:
138ec50f582SBilly Zhu   /// A directed graph where nodes are non-negative integers.
139ec50f582SBilly Zhu   struct Graph {
140ec50f582SBilly Zhu     using Node = int64_t;
141ec50f582SBilly Zhu 
142ec50f582SBilly Zhu     /// Use ordered containers for deterministic output.
143ec50f582SBilly Zhu     /// Nodes without outgoing edges are considered nonexistent.
144ec50f582SBilly Zhu     std::map<Node, std::set<Node>> edges;
145ec50f582SBilly Zhu 
146ec50f582SBilly Zhu     void addEdge(Node src, Node sink) { edges[src].insert(sink); }
147ec50f582SBilly Zhu 
148ec50f582SBilly Zhu     bool isCyclic() const {
149ec50f582SBilly Zhu       DenseSet<Node> visited;
150ec50f582SBilly Zhu       for (Node root : llvm::make_first_range(edges)) {
151ec50f582SBilly Zhu         if (visited.contains(root))
152ec50f582SBilly Zhu           continue;
153ec50f582SBilly Zhu 
154ec50f582SBilly Zhu         SetVector<Node> path;
155ec50f582SBilly Zhu         SmallVector<Node> workstack;
156ec50f582SBilly Zhu         workstack.push_back(root);
157ec50f582SBilly Zhu         while (!workstack.empty()) {
158ec50f582SBilly Zhu           Node curr = workstack.back();
159ec50f582SBilly Zhu           workstack.pop_back();
160ec50f582SBilly Zhu 
161ec50f582SBilly Zhu           if (curr < 0) {
162ec50f582SBilly Zhu             // A negative node signals the end of processing all of this node's
163ec50f582SBilly Zhu             // children. Remove self from path.
164ec50f582SBilly Zhu             assert(path.back() == -curr && "internal inconsistency");
165ec50f582SBilly Zhu             path.pop_back();
166ec50f582SBilly Zhu             continue;
167ec50f582SBilly Zhu           }
168ec50f582SBilly Zhu 
169ec50f582SBilly Zhu           if (path.contains(curr))
170ec50f582SBilly Zhu             return true;
171ec50f582SBilly Zhu 
172ec50f582SBilly Zhu           visited.insert(curr);
173ec50f582SBilly Zhu           auto edgesIter = edges.find(curr);
174ec50f582SBilly Zhu           if (edgesIter == edges.end() || edgesIter->second.empty())
175ec50f582SBilly Zhu             continue;
176ec50f582SBilly Zhu 
177ec50f582SBilly Zhu           path.insert(curr);
178ec50f582SBilly Zhu           // Push negative node to signify recursion return.
179ec50f582SBilly Zhu           workstack.push_back(-curr);
180ec50f582SBilly Zhu           workstack.insert(workstack.end(), edgesIter->second.begin(),
181ec50f582SBilly Zhu                            edgesIter->second.end());
182ec50f582SBilly Zhu         }
183ec50f582SBilly Zhu       }
184ec50f582SBilly Zhu       return false;
185ec50f582SBilly Zhu     }
186ec50f582SBilly Zhu 
187ec50f582SBilly Zhu     /// Deterministic output for testing.
188ec50f582SBilly Zhu     std::string serialize() const {
189ec50f582SBilly Zhu       std::ostringstream oss;
190ec50f582SBilly Zhu       for (const auto &[src, neighbors] : edges) {
191ec50f582SBilly Zhu         oss << src << ":";
192ec50f582SBilly Zhu         for (Graph::Node neighbor : neighbors)
193ec50f582SBilly Zhu           oss << " " << neighbor;
194ec50f582SBilly Zhu         oss << "\n";
195ec50f582SBilly Zhu       }
196ec50f582SBilly Zhu       return oss.str();
197ec50f582SBilly Zhu     }
198ec50f582SBilly Zhu   };
199ec50f582SBilly Zhu 
200ec50f582SBilly Zhu   /// A Graph where edges that used to cause cycles (back-edges) are now
201ec50f582SBilly Zhu   /// represented with an indirection (a recursionId).
202ec50f582SBilly Zhu   ///
203ec50f582SBilly Zhu   /// In addition to each node having an integer ID, each node also tracks the
204ec50f582SBilly Zhu   /// original integer ID it had in the original graph. This way for every
205ec50f582SBilly Zhu   /// back-edge, we can represent it as pointing to a new instance of the
206ec50f582SBilly Zhu   /// original node. Then we mark the original node and the new instance with
207ec50f582SBilly Zhu   /// a new unique recursionId to indicate that they're supposed to be the same
208ec50f582SBilly Zhu   /// node.
209ec50f582SBilly Zhu   struct PrunedGraph {
210ec50f582SBilly Zhu     using Node = Graph::Node;
211ec50f582SBilly Zhu     struct NodeInfo {
212ec50f582SBilly Zhu       Graph::Node originalId;
213ec50f582SBilly Zhu       /// A negative recursive index means not recursive. Otherwise nodes with
214ec50f582SBilly Zhu       /// the same originalId & recursionId are the same node in the original
215ec50f582SBilly Zhu       /// graph.
216ec50f582SBilly Zhu       int64_t recursionId;
217ec50f582SBilly Zhu     };
218ec50f582SBilly Zhu 
219ec50f582SBilly Zhu     /// Add a regular non-recursive-self node.
220ec50f582SBilly Zhu     Node addNode(Graph::Node originalId, int64_t recursionIndex = -1) {
221ec50f582SBilly Zhu       Node id = nextConnectionId++;
222ec50f582SBilly Zhu       info[id] = {originalId, recursionIndex};
223ec50f582SBilly Zhu       return id;
224ec50f582SBilly Zhu     }
225ec50f582SBilly Zhu     /// Add a recursive-self-node, i.e. a duplicate of the original node that is
226ec50f582SBilly Zhu     /// meant to represent an indirection to it.
227ec50f582SBilly Zhu     std::pair<Node, int64_t> addRecursiveSelfNode(Graph::Node originalId) {
228*d5746d73SFrank Schlimbach       auto node = addNode(originalId, nextRecursionId);
229*d5746d73SFrank Schlimbach       return {node, nextRecursionId++};
230ec50f582SBilly Zhu     }
231ec50f582SBilly Zhu     void addEdge(Node src, Node sink) { connections.addEdge(src, sink); }
232ec50f582SBilly Zhu 
233ec50f582SBilly Zhu     /// Deterministic output for testing.
234ec50f582SBilly Zhu     std::string serialize() const {
235ec50f582SBilly Zhu       std::ostringstream oss;
236ec50f582SBilly Zhu       oss << "nodes\n";
237ec50f582SBilly Zhu       for (const auto &[nodeId, nodeInfo] : info) {
238ec50f582SBilly Zhu         oss << nodeId << ": n" << nodeInfo.originalId;
239ec50f582SBilly Zhu         if (nodeInfo.recursionId >= 0)
240ec50f582SBilly Zhu           oss << '<' << nodeInfo.recursionId << '>';
241ec50f582SBilly Zhu         oss << "\n";
242ec50f582SBilly Zhu       }
243ec50f582SBilly Zhu       oss << "edges\n";
244ec50f582SBilly Zhu       oss << connections.serialize();
245ec50f582SBilly Zhu       return oss.str();
246ec50f582SBilly Zhu     }
247ec50f582SBilly Zhu 
248ec50f582SBilly Zhu     bool isCyclic() const { return connections.isCyclic(); }
249ec50f582SBilly Zhu 
250ec50f582SBilly Zhu   private:
251ec50f582SBilly Zhu     Graph connections;
252ec50f582SBilly Zhu     int64_t nextRecursionId = 0;
253ec50f582SBilly Zhu     int64_t nextConnectionId = 0;
254ec50f582SBilly Zhu     /// Use ordered map for deterministic output.
255ec50f582SBilly Zhu     std::map<Graph::Node, NodeInfo> info;
256ec50f582SBilly Zhu   };
257ec50f582SBilly Zhu 
258ec50f582SBilly Zhu   PrunedGraph breakCycles(const Graph &input) {
259ec50f582SBilly Zhu     assert(input.isCyclic() && "input graph is not cyclic");
260ec50f582SBilly Zhu 
261ec50f582SBilly Zhu     PrunedGraph output;
262ec50f582SBilly Zhu 
263ec50f582SBilly Zhu     DenseMap<Graph::Node, int64_t> recMap;
264ec50f582SBilly Zhu     auto cycleBreaker = [&](Graph::Node inNode) -> std::optional<Graph::Node> {
265ec50f582SBilly Zhu       auto [node, recId] = output.addRecursiveSelfNode(inNode);
266ec50f582SBilly Zhu       recMap[inNode] = recId;
267ec50f582SBilly Zhu       return node;
268ec50f582SBilly Zhu     };
269ec50f582SBilly Zhu 
270ec50f582SBilly Zhu     CyclicReplacerCache<Graph::Node, Graph::Node> cache(cycleBreaker);
271ec50f582SBilly Zhu 
272ec50f582SBilly Zhu     std::function<Graph::Node(Graph::Node)> replaceNode =
273ec50f582SBilly Zhu         [&](Graph::Node inNode) {
274ec50f582SBilly Zhu           auto cacheEntry = cache.lookupOrInit(inNode);
275ec50f582SBilly Zhu           if (std::optional<Graph::Node> result = cacheEntry.get())
276ec50f582SBilly Zhu             return *result;
277ec50f582SBilly Zhu 
278ec50f582SBilly Zhu           // Recursively replace its neighbors.
279ec50f582SBilly Zhu           SmallVector<Graph::Node> neighbors;
280ec50f582SBilly Zhu           if (auto it = input.edges.find(inNode); it != input.edges.end())
281ec50f582SBilly Zhu             neighbors = SmallVector<Graph::Node>(
282ec50f582SBilly Zhu                 llvm::map_range(it->second, replaceNode));
283ec50f582SBilly Zhu 
284ec50f582SBilly Zhu           // Create a new node in the output graph.
285ec50f582SBilly Zhu           int64_t recursionIndex =
286ec50f582SBilly Zhu               cacheEntry.wasRepeated() ? recMap.lookup(inNode) : -1;
287ec50f582SBilly Zhu           Graph::Node result = output.addNode(inNode, recursionIndex);
288ec50f582SBilly Zhu 
289ec50f582SBilly Zhu           for (Graph::Node neighbor : neighbors)
290ec50f582SBilly Zhu             output.addEdge(result, neighbor);
291ec50f582SBilly Zhu 
292ec50f582SBilly Zhu           cacheEntry.resolve(result);
293ec50f582SBilly Zhu           return result;
294ec50f582SBilly Zhu         };
295ec50f582SBilly Zhu 
296ec50f582SBilly Zhu     /// Translate starting from each node.
297ec50f582SBilly Zhu     for (Graph::Node root : llvm::make_first_range(input.edges))
298ec50f582SBilly Zhu       replaceNode(root);
299ec50f582SBilly Zhu 
300ec50f582SBilly Zhu     return output;
301ec50f582SBilly Zhu   }
302ec50f582SBilly Zhu 
303ec50f582SBilly Zhu   /// Helper for serialization tests that allow putting comments in the
304ec50f582SBilly Zhu   /// serialized format. Every line that begins with a `;` is considered a
305ec50f582SBilly Zhu   /// comment. The entire line, incl. the terminating `\n` is removed.
306ec50f582SBilly Zhu   std::string trimComments(StringRef input) {
307ec50f582SBilly Zhu     std::ostringstream oss;
308ec50f582SBilly Zhu     bool isNewLine = false;
309ec50f582SBilly Zhu     bool isComment = false;
310ec50f582SBilly Zhu     for (char c : input) {
311ec50f582SBilly Zhu       // Lines beginning with ';' are comments.
312ec50f582SBilly Zhu       if (isNewLine && c == ';')
313ec50f582SBilly Zhu         isComment = true;
314ec50f582SBilly Zhu 
315ec50f582SBilly Zhu       if (!isComment)
316ec50f582SBilly Zhu         oss << c;
317ec50f582SBilly Zhu 
318ec50f582SBilly Zhu       if (c == '\n') {
319ec50f582SBilly Zhu         isNewLine = true;
320ec50f582SBilly Zhu         isComment = false;
321ec50f582SBilly Zhu       }
322ec50f582SBilly Zhu     }
323ec50f582SBilly Zhu     return oss.str();
324ec50f582SBilly Zhu   }
325ec50f582SBilly Zhu };
326ec50f582SBilly Zhu } // namespace
327ec50f582SBilly Zhu 
328ec50f582SBilly Zhu TEST_F(CachedCyclicReplacerGraphReplacement, testSingleLoop) {
329ec50f582SBilly Zhu   // 0 -> 1 -> 2
330ec50f582SBilly Zhu   // ^         |
331ec50f582SBilly Zhu   // +---------+
332ec50f582SBilly Zhu   Graph input = {{{0, {1}}, {1, {2}}, {2, {0}}}};
333ec50f582SBilly Zhu   PrunedGraph output = breakCycles(input);
334ec50f582SBilly Zhu   ASSERT_FALSE(output.isCyclic()) << output.serialize();
335ec50f582SBilly Zhu   EXPECT_EQ(output.serialize(), trimComments(R"(nodes
336ec50f582SBilly Zhu ; root 0
337ec50f582SBilly Zhu 0: n0<0>
338ec50f582SBilly Zhu 1: n2
339ec50f582SBilly Zhu 2: n1
340ec50f582SBilly Zhu 3: n0<0>
341ec50f582SBilly Zhu ; root 1
342ec50f582SBilly Zhu 4: n2
343ec50f582SBilly Zhu ; root 2
344ec50f582SBilly Zhu 5: n1
345ec50f582SBilly Zhu edges
346ec50f582SBilly Zhu 1: 0
347ec50f582SBilly Zhu 2: 1
348ec50f582SBilly Zhu 3: 2
349ec50f582SBilly Zhu 4: 3
350ec50f582SBilly Zhu 5: 4
351ec50f582SBilly Zhu )"));
352ec50f582SBilly Zhu }
353ec50f582SBilly Zhu 
354ec50f582SBilly Zhu TEST_F(CachedCyclicReplacerGraphReplacement, testDualLoop) {
355ec50f582SBilly Zhu   // +----> 1 -----+
356ec50f582SBilly Zhu   // |             v
357ec50f582SBilly Zhu   // 0 <---------- 3
358ec50f582SBilly Zhu   // |             ^
359ec50f582SBilly Zhu   // +----> 2 -----+
360ec50f582SBilly Zhu   //
361ec50f582SBilly Zhu   // Two loops:
362ec50f582SBilly Zhu   // 0 -> 1 -> 3 -> 0
363ec50f582SBilly Zhu   // 0 -> 2 -> 3 -> 0
364ec50f582SBilly Zhu   Graph input = {{{0, {1, 2}}, {1, {3}}, {2, {3}}, {3, {0}}}};
365ec50f582SBilly Zhu   PrunedGraph output = breakCycles(input);
366ec50f582SBilly Zhu   ASSERT_FALSE(output.isCyclic()) << output.serialize();
367ec50f582SBilly Zhu   EXPECT_EQ(output.serialize(), trimComments(R"(nodes
368ec50f582SBilly Zhu ; root 0
369ec50f582SBilly Zhu 0: n0<0>
370ec50f582SBilly Zhu 1: n3
371ec50f582SBilly Zhu 2: n1
372ec50f582SBilly Zhu 3: n2
373ec50f582SBilly Zhu 4: n0<0>
374ec50f582SBilly Zhu ; root 1
375ec50f582SBilly Zhu 5: n3
376ec50f582SBilly Zhu 6: n1
377ec50f582SBilly Zhu ; root 2
378ec50f582SBilly Zhu 7: n2
379ec50f582SBilly Zhu edges
380ec50f582SBilly Zhu 1: 0
381ec50f582SBilly Zhu 2: 1
382ec50f582SBilly Zhu 3: 1
383ec50f582SBilly Zhu 4: 2 3
384ec50f582SBilly Zhu 5: 4
385ec50f582SBilly Zhu 6: 5
386ec50f582SBilly Zhu 7: 5
387ec50f582SBilly Zhu )"));
388ec50f582SBilly Zhu }
389ec50f582SBilly Zhu 
390ec50f582SBilly Zhu TEST_F(CachedCyclicReplacerGraphReplacement, testNestedLoops) {
391ec50f582SBilly Zhu   // +----> 1 -----+
392ec50f582SBilly Zhu   // |      ^      v
393ec50f582SBilly Zhu   // 0 <----+----- 2
394ec50f582SBilly Zhu   //
395ec50f582SBilly Zhu   // Two nested loops:
396ec50f582SBilly Zhu   // 0 -> 1 -> 2 -> 0
397ec50f582SBilly Zhu   //      1 -> 2 -> 1
398ec50f582SBilly Zhu   Graph input = {{{0, {1}}, {1, {2}}, {2, {0, 1}}}};
399ec50f582SBilly Zhu   PrunedGraph output = breakCycles(input);
400ec50f582SBilly Zhu   ASSERT_FALSE(output.isCyclic()) << output.serialize();
401ec50f582SBilly Zhu   EXPECT_EQ(output.serialize(), trimComments(R"(nodes
402ec50f582SBilly Zhu ; root 0
403ec50f582SBilly Zhu 0: n0<0>
404ec50f582SBilly Zhu 1: n1<1>
405ec50f582SBilly Zhu 2: n2
406ec50f582SBilly Zhu 3: n1<1>
407ec50f582SBilly Zhu 4: n0<0>
408ec50f582SBilly Zhu ; root 1
409ec50f582SBilly Zhu 5: n1<2>
410ec50f582SBilly Zhu 6: n2
411ec50f582SBilly Zhu 7: n1<2>
412ec50f582SBilly Zhu ; root 2
413ec50f582SBilly Zhu 8: n2
414ec50f582SBilly Zhu edges
415ec50f582SBilly Zhu 2: 0 1
416ec50f582SBilly Zhu 3: 2
417ec50f582SBilly Zhu 4: 3
418ec50f582SBilly Zhu 6: 4 5
419ec50f582SBilly Zhu 7: 6
420ec50f582SBilly Zhu 8: 4 7
421ec50f582SBilly Zhu )"));
422ec50f582SBilly Zhu }
423ec50f582SBilly Zhu 
424ec50f582SBilly Zhu TEST_F(CachedCyclicReplacerGraphReplacement, testDualNestedLoops) {
425ec50f582SBilly Zhu   // +----> 1 -----+
426ec50f582SBilly Zhu   // |      ^      v
427ec50f582SBilly Zhu   // 0 <----+----- 3
428ec50f582SBilly Zhu   // |      v      ^
429ec50f582SBilly Zhu   // +----> 2 -----+
430ec50f582SBilly Zhu   //
431ec50f582SBilly Zhu   // Two sets of nested loops:
432ec50f582SBilly Zhu   // 0 -> 1 -> 3 -> 0
433ec50f582SBilly Zhu   //      1 -> 3 -> 1
434ec50f582SBilly Zhu   // 0 -> 2 -> 3 -> 0
435ec50f582SBilly Zhu   //      2 -> 3 -> 2
436ec50f582SBilly Zhu   Graph input = {{{0, {1, 2}}, {1, {3}}, {2, {3}}, {3, {0, 1, 2}}}};
437ec50f582SBilly Zhu   PrunedGraph output = breakCycles(input);
438ec50f582SBilly Zhu   ASSERT_FALSE(output.isCyclic()) << output.serialize();
439ec50f582SBilly Zhu   EXPECT_EQ(output.serialize(), trimComments(R"(nodes
440ec50f582SBilly Zhu ; root 0
441ec50f582SBilly Zhu 0: n0<0>
442ec50f582SBilly Zhu 1: n1<1>
443ec50f582SBilly Zhu 2: n3<2>
444ec50f582SBilly Zhu 3: n2
445ec50f582SBilly Zhu 4: n3<2>
446ec50f582SBilly Zhu 5: n1<1>
447ec50f582SBilly Zhu 6: n2<3>
448ec50f582SBilly Zhu 7: n3
449ec50f582SBilly Zhu 8: n2<3>
450ec50f582SBilly Zhu 9: n0<0>
451ec50f582SBilly Zhu ; root 1
452ec50f582SBilly Zhu 10: n1<4>
453ec50f582SBilly Zhu 11: n3<5>
454ec50f582SBilly Zhu 12: n2
455ec50f582SBilly Zhu 13: n3<5>
456ec50f582SBilly Zhu 14: n1<4>
457ec50f582SBilly Zhu ; root 2
458ec50f582SBilly Zhu 15: n2<6>
459ec50f582SBilly Zhu 16: n3
460ec50f582SBilly Zhu 17: n2<6>
461ec50f582SBilly Zhu ; root 3
462ec50f582SBilly Zhu 18: n3
463ec50f582SBilly Zhu edges
464ec50f582SBilly Zhu ; root 0
465ec50f582SBilly Zhu 3: 2
466ec50f582SBilly Zhu 4: 0 1 3
467ec50f582SBilly Zhu 5: 4
468ec50f582SBilly Zhu 7: 0 5 6
469ec50f582SBilly Zhu 8: 7
470ec50f582SBilly Zhu 9: 5 8
471ec50f582SBilly Zhu ; root 1
472ec50f582SBilly Zhu 12: 11
473ec50f582SBilly Zhu 13: 9 10 12
474ec50f582SBilly Zhu 14: 13
475ec50f582SBilly Zhu ; root 2
476ec50f582SBilly Zhu 16: 9 14 15
477ec50f582SBilly Zhu 17: 16
478ec50f582SBilly Zhu ; root 3
479ec50f582SBilly Zhu 18: 9 14 17
480ec50f582SBilly Zhu )"));
481ec50f582SBilly Zhu }
482