xref: /llvm-project/mlir/unittests/Support/CyclicReplacerCacheTest.cpp (revision d5746d73cedcf7a593dc4b4f2ce2465e2d45750b)
1 //===- CyclicReplacerCacheTest.cpp ----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Support/CyclicReplacerCache.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SetVector.h"
#include "gmock/gmock.h"
#include <cassert>
#include <cstdint>
#include <functional>
#include <map>
#include <optional>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
15 
16 using namespace mlir;
17 
18 TEST(CachedCyclicReplacerTest, testNoRecursion) {
19   CachedCyclicReplacer<int, bool> replacer(
20       /*replacer=*/[](int n) { return static_cast<bool>(n); },
21       /*cycleBreaker=*/[](int n) { return std::nullopt; });
22 
23   EXPECT_EQ(replacer(3), true);
24   EXPECT_EQ(replacer(0), false);
25 }
26 
27 TEST(CachedCyclicReplacerTest, testInPlaceRecursionPruneAnywhere) {
28   // Replacer cycles through ints 0 -> 1 -> 2 -> 0 -> ...
29   std::optional<CachedCyclicReplacer<int, int>> replacer;
30   replacer.emplace(
31       /*replacer=*/[&](int n) { return (*replacer)((n + 1) % 3); },
32       /*cycleBreaker=*/[&](int n) { return -1; });
33 
34   // Starting at 0.
35   EXPECT_EQ((*replacer)(0), -1);
36   // Starting at 2.
37   EXPECT_EQ((*replacer)(2), -1);
38 }
39 
40 //===----------------------------------------------------------------------===//
41 // CachedCyclicReplacer: ChainRecursion
42 //===----------------------------------------------------------------------===//
43 
44 /// This set of tests uses a replacer function that maps ints into vectors of
45 /// ints.
46 ///
/// The replacement result for input `n` is the replacement result of `(n+1)%3`
/// appended with an element `42`. Theoretically, this will produce an
/// infinitely long vector. The cycle-breaker function prunes this infinite
/// recursion in the replacer logic by returning an empty vector upon the first
/// re-occurrence of an input value or, when `baseCase` is set, only upon a
/// re-occurrence of that particular value.
52 namespace {
53 class CachedCyclicReplacerChainRecursionPruningTest : public ::testing::Test {
54 public:
55   // N ==> (N+1) % 3
56   // This will create a chain of infinite length without recursion pruning.
57   CachedCyclicReplacerChainRecursionPruningTest()
58       : replacer(
59             [&](int n) {
60               ++invokeCount;
61               std::vector<int> result = replacer((n + 1) % 3);
62               result.push_back(42);
63               return result;
64             },
65             [&](int n) -> std::optional<std::vector<int>> {
66               return baseCase.value_or(n) == n
67                          ? std::make_optional(std::vector<int>{})
68                          : std::nullopt;
69             }) {}
70 
71   std::vector<int> getChain(unsigned N) { return std::vector<int>(N, 42); };
72 
73   CachedCyclicReplacer<int, std::vector<int>> replacer;
74   int invokeCount = 0;
75   std::optional<int> baseCase = std::nullopt;
76 };
77 } // namespace
78 
79 TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneAnywhere0) {
80   // Starting at 0. Cycle length is 3.
81   EXPECT_EQ(replacer(0), getChain(3));
82   EXPECT_EQ(invokeCount, 3);
83 
84   // Starting at 1. Cycle length is 5 now because of a cached replacement at 0.
85   invokeCount = 0;
86   EXPECT_EQ(replacer(1), getChain(5));
87   EXPECT_EQ(invokeCount, 2);
88 
89   // Starting at 2. Cycle length is 4. Entire result is cached.
90   invokeCount = 0;
91   EXPECT_EQ(replacer(2), getChain(4));
92   EXPECT_EQ(invokeCount, 0);
93 }
94 
95 TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneAnywhere1) {
96   // Starting at 1. Cycle length is 3.
97   EXPECT_EQ(replacer(1), getChain(3));
98   EXPECT_EQ(invokeCount, 3);
99 }
100 
101 TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneSpecific0) {
102   baseCase = 0;
103 
104   // Starting at 0. Cycle length is 3.
105   EXPECT_EQ(replacer(0), getChain(3));
106   EXPECT_EQ(invokeCount, 3);
107 }
108 
109 TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneSpecific1) {
110   baseCase = 0;
111 
112   // Starting at 1. Cycle length is 5 (1 -> 2 -> 0 -> 1 -> 2 -> Prune).
113   EXPECT_EQ(replacer(1), getChain(5));
114   EXPECT_EQ(invokeCount, 5);
115 
116   // Starting at 0. Cycle length is 3. Entire result is cached.
117   invokeCount = 0;
118   EXPECT_EQ(replacer(0), getChain(3));
119   EXPECT_EQ(invokeCount, 0);
120 }
121 
122 //===----------------------------------------------------------------------===//
123 // CachedCyclicReplacer: GraphReplacement
124 //===----------------------------------------------------------------------===//
125 
126 /// This set of tests uses a replacer function that maps from cyclic graphs to
127 /// trees, pruning out cycles in the process.
128 ///
129 /// It consists of two helper classes:
130 /// - Graph
131 ///   - A directed graph where nodes are non-negative integers.
132 /// - PrunedGraph
133 ///   - A Graph where edges that used to cause cycles are now represented with
134 ///     an indirection (a recursionId).
135 namespace {
136 class CachedCyclicReplacerGraphReplacement : public ::testing::Test {
137 public:
138   /// A directed graph where nodes are non-negative integers.
139   struct Graph {
140     using Node = int64_t;
141 
142     /// Use ordered containers for deterministic output.
143     /// Nodes without outgoing edges are considered nonexistent.
144     std::map<Node, std::set<Node>> edges;
145 
146     void addEdge(Node src, Node sink) { edges[src].insert(sink); }
147 
148     bool isCyclic() const {
149       DenseSet<Node> visited;
150       for (Node root : llvm::make_first_range(edges)) {
151         if (visited.contains(root))
152           continue;
153 
154         SetVector<Node> path;
155         SmallVector<Node> workstack;
156         workstack.push_back(root);
157         while (!workstack.empty()) {
158           Node curr = workstack.back();
159           workstack.pop_back();
160 
161           if (curr < 0) {
162             // A negative node signals the end of processing all of this node's
163             // children. Remove self from path.
164             assert(path.back() == -curr && "internal inconsistency");
165             path.pop_back();
166             continue;
167           }
168 
169           if (path.contains(curr))
170             return true;
171 
172           visited.insert(curr);
173           auto edgesIter = edges.find(curr);
174           if (edgesIter == edges.end() || edgesIter->second.empty())
175             continue;
176 
177           path.insert(curr);
178           // Push negative node to signify recursion return.
179           workstack.push_back(-curr);
180           workstack.insert(workstack.end(), edgesIter->second.begin(),
181                            edgesIter->second.end());
182         }
183       }
184       return false;
185     }
186 
187     /// Deterministic output for testing.
188     std::string serialize() const {
189       std::ostringstream oss;
190       for (const auto &[src, neighbors] : edges) {
191         oss << src << ":";
192         for (Graph::Node neighbor : neighbors)
193           oss << " " << neighbor;
194         oss << "\n";
195       }
196       return oss.str();
197     }
198   };
199 
200   /// A Graph where edges that used to cause cycles (back-edges) are now
201   /// represented with an indirection (a recursionId).
202   ///
203   /// In addition to each node having an integer ID, each node also tracks the
204   /// original integer ID it had in the original graph. This way for every
205   /// back-edge, we can represent it as pointing to a new instance of the
206   /// original node. Then we mark the original node and the new instance with
207   /// a new unique recursionId to indicate that they're supposed to be the same
208   /// node.
209   struct PrunedGraph {
210     using Node = Graph::Node;
211     struct NodeInfo {
212       Graph::Node originalId;
213       /// A negative recursive index means not recursive. Otherwise nodes with
214       /// the same originalId & recursionId are the same node in the original
215       /// graph.
216       int64_t recursionId;
217     };
218 
219     /// Add a regular non-recursive-self node.
220     Node addNode(Graph::Node originalId, int64_t recursionIndex = -1) {
221       Node id = nextConnectionId++;
222       info[id] = {originalId, recursionIndex};
223       return id;
224     }
225     /// Add a recursive-self-node, i.e. a duplicate of the original node that is
226     /// meant to represent an indirection to it.
227     std::pair<Node, int64_t> addRecursiveSelfNode(Graph::Node originalId) {
228       auto node = addNode(originalId, nextRecursionId);
229       return {node, nextRecursionId++};
230     }
231     void addEdge(Node src, Node sink) { connections.addEdge(src, sink); }
232 
233     /// Deterministic output for testing.
234     std::string serialize() const {
235       std::ostringstream oss;
236       oss << "nodes\n";
237       for (const auto &[nodeId, nodeInfo] : info) {
238         oss << nodeId << ": n" << nodeInfo.originalId;
239         if (nodeInfo.recursionId >= 0)
240           oss << '<' << nodeInfo.recursionId << '>';
241         oss << "\n";
242       }
243       oss << "edges\n";
244       oss << connections.serialize();
245       return oss.str();
246     }
247 
248     bool isCyclic() const { return connections.isCyclic(); }
249 
250   private:
251     Graph connections;
252     int64_t nextRecursionId = 0;
253     int64_t nextConnectionId = 0;
254     /// Use ordered map for deterministic output.
255     std::map<Graph::Node, NodeInfo> info;
256   };
257 
258   PrunedGraph breakCycles(const Graph &input) {
259     assert(input.isCyclic() && "input graph is not cyclic");
260 
261     PrunedGraph output;
262 
263     DenseMap<Graph::Node, int64_t> recMap;
264     auto cycleBreaker = [&](Graph::Node inNode) -> std::optional<Graph::Node> {
265       auto [node, recId] = output.addRecursiveSelfNode(inNode);
266       recMap[inNode] = recId;
267       return node;
268     };
269 
270     CyclicReplacerCache<Graph::Node, Graph::Node> cache(cycleBreaker);
271 
272     std::function<Graph::Node(Graph::Node)> replaceNode =
273         [&](Graph::Node inNode) {
274           auto cacheEntry = cache.lookupOrInit(inNode);
275           if (std::optional<Graph::Node> result = cacheEntry.get())
276             return *result;
277 
278           // Recursively replace its neighbors.
279           SmallVector<Graph::Node> neighbors;
280           if (auto it = input.edges.find(inNode); it != input.edges.end())
281             neighbors = SmallVector<Graph::Node>(
282                 llvm::map_range(it->second, replaceNode));
283 
284           // Create a new node in the output graph.
285           int64_t recursionIndex =
286               cacheEntry.wasRepeated() ? recMap.lookup(inNode) : -1;
287           Graph::Node result = output.addNode(inNode, recursionIndex);
288 
289           for (Graph::Node neighbor : neighbors)
290             output.addEdge(result, neighbor);
291 
292           cacheEntry.resolve(result);
293           return result;
294         };
295 
296     /// Translate starting from each node.
297     for (Graph::Node root : llvm::make_first_range(input.edges))
298       replaceNode(root);
299 
300     return output;
301   }
302 
303   /// Helper for serialization tests that allow putting comments in the
304   /// serialized format. Every line that begins with a `;` is considered a
305   /// comment. The entire line, incl. the terminating `\n` is removed.
306   std::string trimComments(StringRef input) {
307     std::ostringstream oss;
308     bool isNewLine = false;
309     bool isComment = false;
310     for (char c : input) {
311       // Lines beginning with ';' are comments.
312       if (isNewLine && c == ';')
313         isComment = true;
314 
315       if (!isComment)
316         oss << c;
317 
318       if (c == '\n') {
319         isNewLine = true;
320         isComment = false;
321       }
322     }
323     return oss.str();
324   }
325 };
326 } // namespace
327 
328 TEST_F(CachedCyclicReplacerGraphReplacement, testSingleLoop) {
329   // 0 -> 1 -> 2
330   // ^         |
331   // +---------+
332   Graph input = {{{0, {1}}, {1, {2}}, {2, {0}}}};
333   PrunedGraph output = breakCycles(input);
334   ASSERT_FALSE(output.isCyclic()) << output.serialize();
335   EXPECT_EQ(output.serialize(), trimComments(R"(nodes
336 ; root 0
337 0: n0<0>
338 1: n2
339 2: n1
340 3: n0<0>
341 ; root 1
342 4: n2
343 ; root 2
344 5: n1
345 edges
346 1: 0
347 2: 1
348 3: 2
349 4: 3
350 5: 4
351 )"));
352 }
353 
354 TEST_F(CachedCyclicReplacerGraphReplacement, testDualLoop) {
355   // +----> 1 -----+
356   // |             v
357   // 0 <---------- 3
358   // |             ^
359   // +----> 2 -----+
360   //
361   // Two loops:
362   // 0 -> 1 -> 3 -> 0
363   // 0 -> 2 -> 3 -> 0
364   Graph input = {{{0, {1, 2}}, {1, {3}}, {2, {3}}, {3, {0}}}};
365   PrunedGraph output = breakCycles(input);
366   ASSERT_FALSE(output.isCyclic()) << output.serialize();
367   EXPECT_EQ(output.serialize(), trimComments(R"(nodes
368 ; root 0
369 0: n0<0>
370 1: n3
371 2: n1
372 3: n2
373 4: n0<0>
374 ; root 1
375 5: n3
376 6: n1
377 ; root 2
378 7: n2
379 edges
380 1: 0
381 2: 1
382 3: 1
383 4: 2 3
384 5: 4
385 6: 5
386 7: 5
387 )"));
388 }
389 
390 TEST_F(CachedCyclicReplacerGraphReplacement, testNestedLoops) {
391   // +----> 1 -----+
392   // |      ^      v
393   // 0 <----+----- 2
394   //
395   // Two nested loops:
396   // 0 -> 1 -> 2 -> 0
397   //      1 -> 2 -> 1
398   Graph input = {{{0, {1}}, {1, {2}}, {2, {0, 1}}}};
399   PrunedGraph output = breakCycles(input);
400   ASSERT_FALSE(output.isCyclic()) << output.serialize();
401   EXPECT_EQ(output.serialize(), trimComments(R"(nodes
402 ; root 0
403 0: n0<0>
404 1: n1<1>
405 2: n2
406 3: n1<1>
407 4: n0<0>
408 ; root 1
409 5: n1<2>
410 6: n2
411 7: n1<2>
412 ; root 2
413 8: n2
414 edges
415 2: 0 1
416 3: 2
417 4: 3
418 6: 4 5
419 7: 6
420 8: 4 7
421 )"));
422 }
423 
424 TEST_F(CachedCyclicReplacerGraphReplacement, testDualNestedLoops) {
425   // +----> 1 -----+
426   // |      ^      v
427   // 0 <----+----- 3
428   // |      v      ^
429   // +----> 2 -----+
430   //
431   // Two sets of nested loops:
432   // 0 -> 1 -> 3 -> 0
433   //      1 -> 3 -> 1
434   // 0 -> 2 -> 3 -> 0
435   //      2 -> 3 -> 2
436   Graph input = {{{0, {1, 2}}, {1, {3}}, {2, {3}}, {3, {0, 1, 2}}}};
437   PrunedGraph output = breakCycles(input);
438   ASSERT_FALSE(output.isCyclic()) << output.serialize();
439   EXPECT_EQ(output.serialize(), trimComments(R"(nodes
440 ; root 0
441 0: n0<0>
442 1: n1<1>
443 2: n3<2>
444 3: n2
445 4: n3<2>
446 5: n1<1>
447 6: n2<3>
448 7: n3
449 8: n2<3>
450 9: n0<0>
451 ; root 1
452 10: n1<4>
453 11: n3<5>
454 12: n2
455 13: n3<5>
456 14: n1<4>
457 ; root 2
458 15: n2<6>
459 16: n3
460 17: n2<6>
461 ; root 3
462 18: n3
463 edges
464 ; root 0
465 3: 2
466 4: 0 1 3
467 5: 4
468 7: 0 5 6
469 8: 7
470 9: 5 8
471 ; root 1
472 12: 11
473 13: 9 10 12
474 14: 13
475 ; root 2
476 16: 9 14 15
477 17: 16
478 ; root 3
479 18: 9 14 17
480 )"));
481 }
482