1 //===- RootOrderingTest.cpp - unit tests for optimal branching ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v[1].0 with LLVM 4 // Exceptions. See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "../lib/Conversion/PDLToPDLInterp/RootOrdering.h" 10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 11 #include "mlir/IR/Builders.h" 12 #include "mlir/IR/MLIRContext.h" 13 #include "gtest/gtest.h" 14 15 using namespace mlir; 16 using namespace mlir::arith; 17 using namespace mlir::pdl_to_pdl_interp; 18 19 namespace { 20 21 //===----------------------------------------------------------------------===// 22 // Test Fixture 23 //===----------------------------------------------------------------------===// 24 25 /// The test fixture for constructing root ordering tests and verifying results. 26 /// This fixture constructs the test values v. The test populates the graph 27 /// with the desired costs and then calls check(), passing the expected optimal 28 /// cost and the list of edges in the preorder traversal of the optimal 29 /// branching. 30 class RootOrderingTest : public ::testing::Test { 31 protected: 32 RootOrderingTest() { 33 context.loadDialect<ArithmeticDialect>(); 34 createValues(); 35 } 36 37 /// Creates the test values. These values simply act as vertices / vertex IDs 38 /// in the cost graph, rather than being a part of an IR. 39 void createValues() { 40 OpBuilder builder(&context); 41 builder.setInsertionPointToStart(&block); 42 for (int i = 0; i < 4; ++i) 43 // Ops will be deleted when `block` is destroyed. 44 v[i] = builder.create<ConstantIntOp>(builder.getUnknownLoc(), i, 32); 45 } 46 47 /// Checks that optimal branching on graph has the given cost and 48 /// its preorder traversal results in the specified edges. 49 void check(unsigned cost, OptimalBranching::EdgeList edges) { 50 OptimalBranching opt(graph, v[0]); 51 EXPECT_EQ(opt.solve(), cost); 52 EXPECT_EQ(opt.preOrderTraversal({v, v + edges.size()}), edges); 53 for (std::pair<Value, Value> edge : edges) 54 EXPECT_EQ(opt.getRootOrderingParents().lookup(edge.first), edge.second); 55 } 56 57 protected: 58 /// The context for creating the values. 59 MLIRContext context; 60 61 /// Block holding all the operations. 62 Block block; 63 64 /// Values used in the graph definition. We always use leading `n` values. 65 Value v[4]; 66 67 /// The graph being tested on. 68 RootOrderingGraph graph; 69 }; 70 71 //===----------------------------------------------------------------------===// 72 // Simple 3-node graphs 73 //===----------------------------------------------------------------------===// 74 75 TEST_F(RootOrderingTest, simpleA) { 76 graph[v[1]][v[0]].cost = {1, 10}; 77 graph[v[2]][v[0]].cost = {1, 11}; 78 graph[v[1]][v[2]].cost = {2, 12}; 79 graph[v[2]][v[1]].cost = {2, 13}; 80 check(2, {{v[0], {}}, {v[1], v[0]}, {v[2], v[0]}}); 81 } 82 83 TEST_F(RootOrderingTest, simpleB) { 84 graph[v[1]][v[0]].cost = {1, 10}; 85 graph[v[2]][v[0]].cost = {2, 11}; 86 graph[v[1]][v[2]].cost = {1, 12}; 87 graph[v[2]][v[1]].cost = {1, 13}; 88 check(2, {{v[0], {}}, {v[1], v[0]}, {v[2], v[1]}}); 89 } 90 91 TEST_F(RootOrderingTest, simpleC) { 92 graph[v[1]][v[0]].cost = {2, 10}; 93 graph[v[2]][v[0]].cost = {2, 11}; 94 graph[v[1]][v[2]].cost = {1, 12}; 95 graph[v[2]][v[1]].cost = {1, 13}; 96 check(3, {{v[0], {}}, {v[1], v[0]}, {v[2], v[1]}}); 97 } 98 99 //===----------------------------------------------------------------------===// 100 // Graph for testing contraction 101 //===----------------------------------------------------------------------===// 102 103 TEST_F(RootOrderingTest, contraction) { 104 graph[v[1]][v[0]].cost = {10, 0}; 105 graph[v[2]][v[0]].cost = {5, 0}; 106 graph[v[2]][v[1]].cost = {1, 0}; 107 graph[v[3]][v[2]].cost = {2, 0}; 108 graph[v[1]][v[3]].cost = {3, 0}; 109 check(10, {{v[0], {}}, {v[2], v[0]}, {v[3], v[2]}, {v[1], v[3]}}); 110 } 111 112 } // namespace 113