//===- CyclicReplacerCacheTest.cpp ----------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Support/CyclicReplacerCache.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/SetVector.h" #include "gmock/gmock.h" #include #include using namespace mlir; TEST(CachedCyclicReplacerTest, testNoRecursion) { CachedCyclicReplacer replacer( /*replacer=*/[](int n) { return static_cast(n); }, /*cycleBreaker=*/[](int n) { return std::nullopt; }); EXPECT_EQ(replacer(3), true); EXPECT_EQ(replacer(0), false); } TEST(CachedCyclicReplacerTest, testInPlaceRecursionPruneAnywhere) { // Replacer cycles through ints 0 -> 1 -> 2 -> 0 -> ... std::optional> replacer; replacer.emplace( /*replacer=*/[&](int n) { return (*replacer)((n + 1) % 3); }, /*cycleBreaker=*/[&](int n) { return -1; }); // Starting at 0. EXPECT_EQ((*replacer)(0), -1); // Starting at 2. EXPECT_EQ((*replacer)(2), -1); } //===----------------------------------------------------------------------===// // CachedCyclicReplacer: ChainRecursion //===----------------------------------------------------------------------===// /// This set of tests uses a replacer function that maps ints into vectors of /// ints. /// /// The replacement result for input `n` is the replacement result of `(n+1)%3` /// appended with an element `42`. Theoretically, this will produce an /// infinitely long vector. The cycle-breaker function prunes this infinite /// recursion in the replacer logic by returning an empty vector upon the first /// re-occurrence of an input value. namespace { class CachedCyclicReplacerChainRecursionPruningTest : public ::testing::Test { public: // N ==> (N+1) % 3 // This will create a chain of infinite length without recursion pruning. CachedCyclicReplacerChainRecursionPruningTest() : replacer( [&](int n) { ++invokeCount; std::vector result = replacer((n + 1) % 3); result.push_back(42); return result; }, [&](int n) -> std::optional> { return baseCase.value_or(n) == n ? std::make_optional(std::vector{}) : std::nullopt; }) {} std::vector getChain(unsigned N) { return std::vector(N, 42); }; CachedCyclicReplacer> replacer; int invokeCount = 0; std::optional baseCase = std::nullopt; }; } // namespace TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneAnywhere0) { // Starting at 0. Cycle length is 3. EXPECT_EQ(replacer(0), getChain(3)); EXPECT_EQ(invokeCount, 3); // Starting at 1. Cycle length is 5 now because of a cached replacement at 0. invokeCount = 0; EXPECT_EQ(replacer(1), getChain(5)); EXPECT_EQ(invokeCount, 2); // Starting at 2. Cycle length is 4. Entire result is cached. invokeCount = 0; EXPECT_EQ(replacer(2), getChain(4)); EXPECT_EQ(invokeCount, 0); } TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneAnywhere1) { // Starting at 1. Cycle length is 3. EXPECT_EQ(replacer(1), getChain(3)); EXPECT_EQ(invokeCount, 3); } TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneSpecific0) { baseCase = 0; // Starting at 0. Cycle length is 3. EXPECT_EQ(replacer(0), getChain(3)); EXPECT_EQ(invokeCount, 3); } TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneSpecific1) { baseCase = 0; // Starting at 1. Cycle length is 5 (1 -> 2 -> 0 -> 1 -> 2 -> Prune). EXPECT_EQ(replacer(1), getChain(5)); EXPECT_EQ(invokeCount, 5); // Starting at 0. Cycle length is 3. Entire result is cached. invokeCount = 0; EXPECT_EQ(replacer(0), getChain(3)); EXPECT_EQ(invokeCount, 0); } //===----------------------------------------------------------------------===// // CachedCyclicReplacer: GraphReplacement //===----------------------------------------------------------------------===// /// This set of tests uses a replacer function that maps from cyclic graphs to /// trees, pruning out cycles in the process. /// /// It consists of two helper classes: /// - Graph /// - A directed graph where nodes are non-negative integers. /// - PrunedGraph /// - A Graph where edges that used to cause cycles are now represented with /// an indirection (a recursionId). namespace { class CachedCyclicReplacerGraphReplacement : public ::testing::Test { public: /// A directed graph where nodes are non-negative integers. struct Graph { using Node = int64_t; /// Use ordered containers for deterministic output. /// Nodes without outgoing edges are considered nonexistent. std::map> edges; void addEdge(Node src, Node sink) { edges[src].insert(sink); } bool isCyclic() const { DenseSet visited; for (Node root : llvm::make_first_range(edges)) { if (visited.contains(root)) continue; SetVector path; SmallVector workstack; workstack.push_back(root); while (!workstack.empty()) { Node curr = workstack.back(); workstack.pop_back(); if (curr < 0) { // A negative node signals the end of processing all of this node's // children. Remove self from path. assert(path.back() == -curr && "internal inconsistency"); path.pop_back(); continue; } if (path.contains(curr)) return true; visited.insert(curr); auto edgesIter = edges.find(curr); if (edgesIter == edges.end() || edgesIter->second.empty()) continue; path.insert(curr); // Push negative node to signify recursion return. workstack.push_back(-curr); workstack.insert(workstack.end(), edgesIter->second.begin(), edgesIter->second.end()); } } return false; } /// Deterministic output for testing. std::string serialize() const { std::ostringstream oss; for (const auto &[src, neighbors] : edges) { oss << src << ":"; for (Graph::Node neighbor : neighbors) oss << " " << neighbor; oss << "\n"; } return oss.str(); } }; /// A Graph where edges that used to cause cycles (back-edges) are now /// represented with an indirection (a recursionId). /// /// In addition to each node having an integer ID, each node also tracks the /// original integer ID it had in the original graph. This way for every /// back-edge, we can represent it as pointing to a new instance of the /// original node. Then we mark the original node and the new instance with /// a new unique recursionId to indicate that they're supposed to be the same /// node. struct PrunedGraph { using Node = Graph::Node; struct NodeInfo { Graph::Node originalId; /// A negative recursive index means not recursive. Otherwise nodes with /// the same originalId & recursionId are the same node in the original /// graph. int64_t recursionId; }; /// Add a regular non-recursive-self node. Node addNode(Graph::Node originalId, int64_t recursionIndex = -1) { Node id = nextConnectionId++; info[id] = {originalId, recursionIndex}; return id; } /// Add a recursive-self-node, i.e. a duplicate of the original node that is /// meant to represent an indirection to it. std::pair addRecursiveSelfNode(Graph::Node originalId) { auto node = addNode(originalId, nextRecursionId); return {node, nextRecursionId++}; } void addEdge(Node src, Node sink) { connections.addEdge(src, sink); } /// Deterministic output for testing. std::string serialize() const { std::ostringstream oss; oss << "nodes\n"; for (const auto &[nodeId, nodeInfo] : info) { oss << nodeId << ": n" << nodeInfo.originalId; if (nodeInfo.recursionId >= 0) oss << '<' << nodeInfo.recursionId << '>'; oss << "\n"; } oss << "edges\n"; oss << connections.serialize(); return oss.str(); } bool isCyclic() const { return connections.isCyclic(); } private: Graph connections; int64_t nextRecursionId = 0; int64_t nextConnectionId = 0; /// Use ordered map for deterministic output. std::map info; }; PrunedGraph breakCycles(const Graph &input) { assert(input.isCyclic() && "input graph is not cyclic"); PrunedGraph output; DenseMap recMap; auto cycleBreaker = [&](Graph::Node inNode) -> std::optional { auto [node, recId] = output.addRecursiveSelfNode(inNode); recMap[inNode] = recId; return node; }; CyclicReplacerCache cache(cycleBreaker); std::function replaceNode = [&](Graph::Node inNode) { auto cacheEntry = cache.lookupOrInit(inNode); if (std::optional result = cacheEntry.get()) return *result; // Recursively replace its neighbors. SmallVector neighbors; if (auto it = input.edges.find(inNode); it != input.edges.end()) neighbors = SmallVector( llvm::map_range(it->second, replaceNode)); // Create a new node in the output graph. int64_t recursionIndex = cacheEntry.wasRepeated() ? recMap.lookup(inNode) : -1; Graph::Node result = output.addNode(inNode, recursionIndex); for (Graph::Node neighbor : neighbors) output.addEdge(result, neighbor); cacheEntry.resolve(result); return result; }; /// Translate starting from each node. for (Graph::Node root : llvm::make_first_range(input.edges)) replaceNode(root); return output; } /// Helper for serialization tests that allow putting comments in the /// serialized format. Every line that begins with a `;` is considered a /// comment. The entire line, incl. the terminating `\n` is removed. std::string trimComments(StringRef input) { std::ostringstream oss; bool isNewLine = false; bool isComment = false; for (char c : input) { // Lines beginning with ';' are comments. if (isNewLine && c == ';') isComment = true; if (!isComment) oss << c; if (c == '\n') { isNewLine = true; isComment = false; } } return oss.str(); } }; } // namespace TEST_F(CachedCyclicReplacerGraphReplacement, testSingleLoop) { // 0 -> 1 -> 2 // ^ | // +---------+ Graph input = {{{0, {1}}, {1, {2}}, {2, {0}}}}; PrunedGraph output = breakCycles(input); ASSERT_FALSE(output.isCyclic()) << output.serialize(); EXPECT_EQ(output.serialize(), trimComments(R"(nodes ; root 0 0: n0<0> 1: n2 2: n1 3: n0<0> ; root 1 4: n2 ; root 2 5: n1 edges 1: 0 2: 1 3: 2 4: 3 5: 4 )")); } TEST_F(CachedCyclicReplacerGraphReplacement, testDualLoop) { // +----> 1 -----+ // | v // 0 <---------- 3 // | ^ // +----> 2 -----+ // // Two loops: // 0 -> 1 -> 3 -> 0 // 0 -> 2 -> 3 -> 0 Graph input = {{{0, {1, 2}}, {1, {3}}, {2, {3}}, {3, {0}}}}; PrunedGraph output = breakCycles(input); ASSERT_FALSE(output.isCyclic()) << output.serialize(); EXPECT_EQ(output.serialize(), trimComments(R"(nodes ; root 0 0: n0<0> 1: n3 2: n1 3: n2 4: n0<0> ; root 1 5: n3 6: n1 ; root 2 7: n2 edges 1: 0 2: 1 3: 1 4: 2 3 5: 4 6: 5 7: 5 )")); } TEST_F(CachedCyclicReplacerGraphReplacement, testNestedLoops) { // +----> 1 -----+ // | ^ v // 0 <----+----- 2 // // Two nested loops: // 0 -> 1 -> 2 -> 0 // 1 -> 2 -> 1 Graph input = {{{0, {1}}, {1, {2}}, {2, {0, 1}}}}; PrunedGraph output = breakCycles(input); ASSERT_FALSE(output.isCyclic()) << output.serialize(); EXPECT_EQ(output.serialize(), trimComments(R"(nodes ; root 0 0: n0<0> 1: n1<1> 2: n2 3: n1<1> 4: n0<0> ; root 1 5: n1<2> 6: n2 7: n1<2> ; root 2 8: n2 edges 2: 0 1 3: 2 4: 3 6: 4 5 7: 6 8: 4 7 )")); } TEST_F(CachedCyclicReplacerGraphReplacement, testDualNestedLoops) { // +----> 1 -----+ // | ^ v // 0 <----+----- 3 // | v ^ // +----> 2 -----+ // // Two sets of nested loops: // 0 -> 1 -> 3 -> 0 // 1 -> 3 -> 1 // 0 -> 2 -> 3 -> 0 // 2 -> 3 -> 2 Graph input = {{{0, {1, 2}}, {1, {3}}, {2, {3}}, {3, {0, 1, 2}}}}; PrunedGraph output = breakCycles(input); ASSERT_FALSE(output.isCyclic()) << output.serialize(); EXPECT_EQ(output.serialize(), trimComments(R"(nodes ; root 0 0: n0<0> 1: n1<1> 2: n3<2> 3: n2 4: n3<2> 5: n1<1> 6: n2<3> 7: n3 8: n2<3> 9: n0<0> ; root 1 10: n1<4> 11: n3<5> 12: n2 13: n3<5> 14: n1<4> ; root 2 15: n2<6> 16: n3 17: n2<6> ; root 3 18: n3 edges ; root 0 3: 2 4: 0 1 3 5: 4 7: 0 5 6 8: 7 9: 5 8 ; root 1 12: 11 13: 9 10 12 14: 13 ; root 2 16: 9 14 15 17: 16 ; root 3 18: 9 14 17 )")); }