1 //===- CyclicReplacerCacheTest.cpp ----------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Support/CyclicReplacerCache.h" 10 #include "mlir/Support/LLVM.h" 11 #include "llvm/ADT/SetVector.h" 12 #include "gmock/gmock.h" 13 #include <map> 14 #include <set> 15 16 using namespace mlir; 17 18 TEST(CachedCyclicReplacerTest, testNoRecursion) { 19 CachedCyclicReplacer<int, bool> replacer( 20 /*replacer=*/[](int n) { return static_cast<bool>(n); }, 21 /*cycleBreaker=*/[](int n) { return std::nullopt; }); 22 23 EXPECT_EQ(replacer(3), true); 24 EXPECT_EQ(replacer(0), false); 25 } 26 27 TEST(CachedCyclicReplacerTest, testInPlaceRecursionPruneAnywhere) { 28 // Replacer cycles through ints 0 -> 1 -> 2 -> 0 -> ... 29 std::optional<CachedCyclicReplacer<int, int>> replacer; 30 replacer.emplace( 31 /*replacer=*/[&](int n) { return (*replacer)((n + 1) % 3); }, 32 /*cycleBreaker=*/[&](int n) { return -1; }); 33 34 // Starting at 0. 35 EXPECT_EQ((*replacer)(0), -1); 36 // Starting at 2. 37 EXPECT_EQ((*replacer)(2), -1); 38 } 39 40 //===----------------------------------------------------------------------===// 41 // CachedCyclicReplacer: ChainRecursion 42 //===----------------------------------------------------------------------===// 43 44 /// This set of tests uses a replacer function that maps ints into vectors of 45 /// ints. 46 /// 47 /// The replacement result for input `n` is the replacement result of `(n+1)%3` 48 /// appended with an element `42`. Theoretically, this will produce an 49 /// infinitely long vector. The cycle-breaker function prunes this infinite 50 /// recursion in the replacer logic by returning an empty vector upon the first 51 /// re-occurrence of an input value. 52 namespace { 53 class CachedCyclicReplacerChainRecursionPruningTest : public ::testing::Test { 54 public: 55 // N ==> (N+1) % 3 56 // This will create a chain of infinite length without recursion pruning. 57 CachedCyclicReplacerChainRecursionPruningTest() 58 : replacer( 59 [&](int n) { 60 ++invokeCount; 61 std::vector<int> result = replacer((n + 1) % 3); 62 result.push_back(42); 63 return result; 64 }, 65 [&](int n) -> std::optional<std::vector<int>> { 66 return baseCase.value_or(n) == n 67 ? std::make_optional(std::vector<int>{}) 68 : std::nullopt; 69 }) {} 70 71 std::vector<int> getChain(unsigned N) { return std::vector<int>(N, 42); }; 72 73 CachedCyclicReplacer<int, std::vector<int>> replacer; 74 int invokeCount = 0; 75 std::optional<int> baseCase = std::nullopt; 76 }; 77 } // namespace 78 79 TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneAnywhere0) { 80 // Starting at 0. Cycle length is 3. 81 EXPECT_EQ(replacer(0), getChain(3)); 82 EXPECT_EQ(invokeCount, 3); 83 84 // Starting at 1. Cycle length is 5 now because of a cached replacement at 0. 85 invokeCount = 0; 86 EXPECT_EQ(replacer(1), getChain(5)); 87 EXPECT_EQ(invokeCount, 2); 88 89 // Starting at 2. Cycle length is 4. Entire result is cached. 90 invokeCount = 0; 91 EXPECT_EQ(replacer(2), getChain(4)); 92 EXPECT_EQ(invokeCount, 0); 93 } 94 95 TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneAnywhere1) { 96 // Starting at 1. Cycle length is 3. 97 EXPECT_EQ(replacer(1), getChain(3)); 98 EXPECT_EQ(invokeCount, 3); 99 } 100 101 TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneSpecific0) { 102 baseCase = 0; 103 104 // Starting at 0. Cycle length is 3. 105 EXPECT_EQ(replacer(0), getChain(3)); 106 EXPECT_EQ(invokeCount, 3); 107 } 108 109 TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneSpecific1) { 110 baseCase = 0; 111 112 // Starting at 1. Cycle length is 5 (1 -> 2 -> 0 -> 1 -> 2 -> Prune). 113 EXPECT_EQ(replacer(1), getChain(5)); 114 EXPECT_EQ(invokeCount, 5); 115 116 // Starting at 0. Cycle length is 3. Entire result is cached. 117 invokeCount = 0; 118 EXPECT_EQ(replacer(0), getChain(3)); 119 EXPECT_EQ(invokeCount, 0); 120 } 121 122 //===----------------------------------------------------------------------===// 123 // CachedCyclicReplacer: GraphReplacement 124 //===----------------------------------------------------------------------===// 125 126 /// This set of tests uses a replacer function that maps from cyclic graphs to 127 /// trees, pruning out cycles in the process. 128 /// 129 /// It consists of two helper classes: 130 /// - Graph 131 /// - A directed graph where nodes are non-negative integers. 132 /// - PrunedGraph 133 /// - A Graph where edges that used to cause cycles are now represented with 134 /// an indirection (a recursionId). 135 namespace { 136 class CachedCyclicReplacerGraphReplacement : public ::testing::Test { 137 public: 138 /// A directed graph where nodes are non-negative integers. 139 struct Graph { 140 using Node = int64_t; 141 142 /// Use ordered containers for deterministic output. 143 /// Nodes without outgoing edges are considered nonexistent. 144 std::map<Node, std::set<Node>> edges; 145 146 void addEdge(Node src, Node sink) { edges[src].insert(sink); } 147 148 bool isCyclic() const { 149 DenseSet<Node> visited; 150 for (Node root : llvm::make_first_range(edges)) { 151 if (visited.contains(root)) 152 continue; 153 154 SetVector<Node> path; 155 SmallVector<Node> workstack; 156 workstack.push_back(root); 157 while (!workstack.empty()) { 158 Node curr = workstack.back(); 159 workstack.pop_back(); 160 161 if (curr < 0) { 162 // A negative node signals the end of processing all of this node's 163 // children. Remove self from path. 164 assert(path.back() == -curr && "internal inconsistency"); 165 path.pop_back(); 166 continue; 167 } 168 169 if (path.contains(curr)) 170 return true; 171 172 visited.insert(curr); 173 auto edgesIter = edges.find(curr); 174 if (edgesIter == edges.end() || edgesIter->second.empty()) 175 continue; 176 177 path.insert(curr); 178 // Push negative node to signify recursion return. 179 workstack.push_back(-curr); 180 workstack.insert(workstack.end(), edgesIter->second.begin(), 181 edgesIter->second.end()); 182 } 183 } 184 return false; 185 } 186 187 /// Deterministic output for testing. 188 std::string serialize() const { 189 std::ostringstream oss; 190 for (const auto &[src, neighbors] : edges) { 191 oss << src << ":"; 192 for (Graph::Node neighbor : neighbors) 193 oss << " " << neighbor; 194 oss << "\n"; 195 } 196 return oss.str(); 197 } 198 }; 199 200 /// A Graph where edges that used to cause cycles (back-edges) are now 201 /// represented with an indirection (a recursionId). 202 /// 203 /// In addition to each node having an integer ID, each node also tracks the 204 /// original integer ID it had in the original graph. This way for every 205 /// back-edge, we can represent it as pointing to a new instance of the 206 /// original node. Then we mark the original node and the new instance with 207 /// a new unique recursionId to indicate that they're supposed to be the same 208 /// node. 209 struct PrunedGraph { 210 using Node = Graph::Node; 211 struct NodeInfo { 212 Graph::Node originalId; 213 /// A negative recursive index means not recursive. Otherwise nodes with 214 /// the same originalId & recursionId are the same node in the original 215 /// graph. 216 int64_t recursionId; 217 }; 218 219 /// Add a regular non-recursive-self node. 220 Node addNode(Graph::Node originalId, int64_t recursionIndex = -1) { 221 Node id = nextConnectionId++; 222 info[id] = {originalId, recursionIndex}; 223 return id; 224 } 225 /// Add a recursive-self-node, i.e. a duplicate of the original node that is 226 /// meant to represent an indirection to it. 227 std::pair<Node, int64_t> addRecursiveSelfNode(Graph::Node originalId) { 228 auto node = addNode(originalId, nextRecursionId); 229 return {node, nextRecursionId++}; 230 } 231 void addEdge(Node src, Node sink) { connections.addEdge(src, sink); } 232 233 /// Deterministic output for testing. 234 std::string serialize() const { 235 std::ostringstream oss; 236 oss << "nodes\n"; 237 for (const auto &[nodeId, nodeInfo] : info) { 238 oss << nodeId << ": n" << nodeInfo.originalId; 239 if (nodeInfo.recursionId >= 0) 240 oss << '<' << nodeInfo.recursionId << '>'; 241 oss << "\n"; 242 } 243 oss << "edges\n"; 244 oss << connections.serialize(); 245 return oss.str(); 246 } 247 248 bool isCyclic() const { return connections.isCyclic(); } 249 250 private: 251 Graph connections; 252 int64_t nextRecursionId = 0; 253 int64_t nextConnectionId = 0; 254 /// Use ordered map for deterministic output. 255 std::map<Graph::Node, NodeInfo> info; 256 }; 257 258 PrunedGraph breakCycles(const Graph &input) { 259 assert(input.isCyclic() && "input graph is not cyclic"); 260 261 PrunedGraph output; 262 263 DenseMap<Graph::Node, int64_t> recMap; 264 auto cycleBreaker = [&](Graph::Node inNode) -> std::optional<Graph::Node> { 265 auto [node, recId] = output.addRecursiveSelfNode(inNode); 266 recMap[inNode] = recId; 267 return node; 268 }; 269 270 CyclicReplacerCache<Graph::Node, Graph::Node> cache(cycleBreaker); 271 272 std::function<Graph::Node(Graph::Node)> replaceNode = 273 [&](Graph::Node inNode) { 274 auto cacheEntry = cache.lookupOrInit(inNode); 275 if (std::optional<Graph::Node> result = cacheEntry.get()) 276 return *result; 277 278 // Recursively replace its neighbors. 279 SmallVector<Graph::Node> neighbors; 280 if (auto it = input.edges.find(inNode); it != input.edges.end()) 281 neighbors = SmallVector<Graph::Node>( 282 llvm::map_range(it->second, replaceNode)); 283 284 // Create a new node in the output graph. 285 int64_t recursionIndex = 286 cacheEntry.wasRepeated() ? recMap.lookup(inNode) : -1; 287 Graph::Node result = output.addNode(inNode, recursionIndex); 288 289 for (Graph::Node neighbor : neighbors) 290 output.addEdge(result, neighbor); 291 292 cacheEntry.resolve(result); 293 return result; 294 }; 295 296 /// Translate starting from each node. 297 for (Graph::Node root : llvm::make_first_range(input.edges)) 298 replaceNode(root); 299 300 return output; 301 } 302 303 /// Helper for serialization tests that allow putting comments in the 304 /// serialized format. Every line that begins with a `;` is considered a 305 /// comment. The entire line, incl. the terminating `\n` is removed. 306 std::string trimComments(StringRef input) { 307 std::ostringstream oss; 308 bool isNewLine = false; 309 bool isComment = false; 310 for (char c : input) { 311 // Lines beginning with ';' are comments. 312 if (isNewLine && c == ';') 313 isComment = true; 314 315 if (!isComment) 316 oss << c; 317 318 if (c == '\n') { 319 isNewLine = true; 320 isComment = false; 321 } 322 } 323 return oss.str(); 324 } 325 }; 326 } // namespace 327 328 TEST_F(CachedCyclicReplacerGraphReplacement, testSingleLoop) { 329 // 0 -> 1 -> 2 330 // ^ | 331 // +---------+ 332 Graph input = {{{0, {1}}, {1, {2}}, {2, {0}}}}; 333 PrunedGraph output = breakCycles(input); 334 ASSERT_FALSE(output.isCyclic()) << output.serialize(); 335 EXPECT_EQ(output.serialize(), trimComments(R"(nodes 336 ; root 0 337 0: n0<0> 338 1: n2 339 2: n1 340 3: n0<0> 341 ; root 1 342 4: n2 343 ; root 2 344 5: n1 345 edges 346 1: 0 347 2: 1 348 3: 2 349 4: 3 350 5: 4 351 )")); 352 } 353 354 TEST_F(CachedCyclicReplacerGraphReplacement, testDualLoop) { 355 // +----> 1 -----+ 356 // | v 357 // 0 <---------- 3 358 // | ^ 359 // +----> 2 -----+ 360 // 361 // Two loops: 362 // 0 -> 1 -> 3 -> 0 363 // 0 -> 2 -> 3 -> 0 364 Graph input = {{{0, {1, 2}}, {1, {3}}, {2, {3}}, {3, {0}}}}; 365 PrunedGraph output = breakCycles(input); 366 ASSERT_FALSE(output.isCyclic()) << output.serialize(); 367 EXPECT_EQ(output.serialize(), trimComments(R"(nodes 368 ; root 0 369 0: n0<0> 370 1: n3 371 2: n1 372 3: n2 373 4: n0<0> 374 ; root 1 375 5: n3 376 6: n1 377 ; root 2 378 7: n2 379 edges 380 1: 0 381 2: 1 382 3: 1 383 4: 2 3 384 5: 4 385 6: 5 386 7: 5 387 )")); 388 } 389 390 TEST_F(CachedCyclicReplacerGraphReplacement, testNestedLoops) { 391 // +----> 1 -----+ 392 // | ^ v 393 // 0 <----+----- 2 394 // 395 // Two nested loops: 396 // 0 -> 1 -> 2 -> 0 397 // 1 -> 2 -> 1 398 Graph input = {{{0, {1}}, {1, {2}}, {2, {0, 1}}}}; 399 PrunedGraph output = breakCycles(input); 400 ASSERT_FALSE(output.isCyclic()) << output.serialize(); 401 EXPECT_EQ(output.serialize(), trimComments(R"(nodes 402 ; root 0 403 0: n0<0> 404 1: n1<1> 405 2: n2 406 3: n1<1> 407 4: n0<0> 408 ; root 1 409 5: n1<2> 410 6: n2 411 7: n1<2> 412 ; root 2 413 8: n2 414 edges 415 2: 0 1 416 3: 2 417 4: 3 418 6: 4 5 419 7: 6 420 8: 4 7 421 )")); 422 } 423 424 TEST_F(CachedCyclicReplacerGraphReplacement, testDualNestedLoops) { 425 // +----> 1 -----+ 426 // | ^ v 427 // 0 <----+----- 3 428 // | v ^ 429 // +----> 2 -----+ 430 // 431 // Two sets of nested loops: 432 // 0 -> 1 -> 3 -> 0 433 // 1 -> 3 -> 1 434 // 0 -> 2 -> 3 -> 0 435 // 2 -> 3 -> 2 436 Graph input = {{{0, {1, 2}}, {1, {3}}, {2, {3}}, {3, {0, 1, 2}}}}; 437 PrunedGraph output = breakCycles(input); 438 ASSERT_FALSE(output.isCyclic()) << output.serialize(); 439 EXPECT_EQ(output.serialize(), trimComments(R"(nodes 440 ; root 0 441 0: n0<0> 442 1: n1<1> 443 2: n3<2> 444 3: n2 445 4: n3<2> 446 5: n1<1> 447 6: n2<3> 448 7: n3 449 8: n2<3> 450 9: n0<0> 451 ; root 1 452 10: n1<4> 453 11: n3<5> 454 12: n2 455 13: n3<5> 456 14: n1<4> 457 ; root 2 458 15: n2<6> 459 16: n3 460 17: n2<6> 461 ; root 3 462 18: n3 463 edges 464 ; root 0 465 3: 2 466 4: 0 1 3 467 5: 4 468 7: 0 5 6 469 8: 7 470 9: 5 8 471 ; root 1 472 12: 11 473 13: 9 10 12 474 14: 13 475 ; root 2 476 16: 9 14 15 477 17: 16 478 ; root 3 479 18: 9 14 17 480 )")); 481 } 482