1 //===- RootOrderingTest.cpp - unit tests for optimal branching ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v[1].0 with LLVM 4 // Exceptions. See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "../lib/Conversion/PDLToPDLInterp/RootOrdering.h" 10 #include "mlir/Dialect/StandardOps/IR/Ops.h" 11 #include "mlir/IR/Builders.h" 12 #include "mlir/IR/MLIRContext.h" 13 #include "gtest/gtest.h" 14 15 using namespace mlir; 16 using namespace mlir::pdl_to_pdl_interp; 17 18 namespace { 19 20 //===----------------------------------------------------------------------===// 21 // Test Fixture 22 //===----------------------------------------------------------------------===// 23 24 /// The test fixture for constructing root ordering tests and verifying results. 25 /// This fixture constructs the test values v. The test populates the graph 26 /// with the desired costs and then calls check(), passing the expeted optimal 27 /// cost and the list of edges in the preorder traversal of the optimal 28 /// branching. 29 class RootOrderingTest : public ::testing::Test { 30 protected: 31 RootOrderingTest() { 32 context.loadDialect<StandardOpsDialect>(); 33 createValues(); 34 } 35 36 /// Creates the test values. 37 void createValues() { 38 OpBuilder builder(&context); 39 for (int i = 0; i < 4; ++i) 40 v[i] = builder.create<ConstantOp>(builder.getUnknownLoc(), 41 builder.getI32IntegerAttr(i)); 42 } 43 44 /// Checks that optimal branching on graph has the given cost and 45 /// its preorder traversal results in the specified edges. 46 void check(unsigned cost, OptimalBranching::EdgeList edges) { 47 OptimalBranching opt(graph, v[0]); 48 EXPECT_EQ(opt.solve(), cost); 49 EXPECT_EQ(opt.preOrderTraversal({v, v + edges.size()}), edges); 50 for (std::pair<Value, Value> edge : edges) 51 EXPECT_EQ(opt.getRootOrderingParents().lookup(edge.first), edge.second); 52 } 53 54 protected: 55 /// The context for creating the values. 56 MLIRContext context; 57 58 /// Values used in the graph definition. We always use leading `n` values. 59 Value v[4]; 60 61 /// The graph being tested on. 62 RootOrderingGraph graph; 63 }; 64 65 //===----------------------------------------------------------------------===// 66 // Simple 3-node graphs 67 //===----------------------------------------------------------------------===// 68 69 TEST_F(RootOrderingTest, simpleA) { 70 graph[v[1]][v[0]].cost = {1, 10}; 71 graph[v[2]][v[0]].cost = {1, 11}; 72 graph[v[1]][v[2]].cost = {2, 12}; 73 graph[v[2]][v[1]].cost = {2, 13}; 74 check(2, {{v[0], {}}, {v[1], v[0]}, {v[2], v[0]}}); 75 } 76 77 TEST_F(RootOrderingTest, simpleB) { 78 graph[v[1]][v[0]].cost = {1, 10}; 79 graph[v[2]][v[0]].cost = {2, 11}; 80 graph[v[1]][v[2]].cost = {1, 12}; 81 graph[v[2]][v[1]].cost = {1, 13}; 82 check(2, {{v[0], {}}, {v[1], v[0]}, {v[2], v[1]}}); 83 } 84 85 TEST_F(RootOrderingTest, simpleC) { 86 graph[v[1]][v[0]].cost = {2, 10}; 87 graph[v[2]][v[0]].cost = {2, 11}; 88 graph[v[1]][v[2]].cost = {1, 12}; 89 graph[v[2]][v[1]].cost = {1, 13}; 90 check(3, {{v[0], {}}, {v[1], v[0]}, {v[2], v[1]}}); 91 } 92 93 //===----------------------------------------------------------------------===// 94 // Graph for testing contraction 95 //===----------------------------------------------------------------------===// 96 97 TEST_F(RootOrderingTest, contraction) { 98 graph[v[1]][v[0]].cost = {10, 0}; 99 graph[v[2]][v[0]].cost = {5, 0}; 100 graph[v[2]][v[1]].cost = {1, 0}; 101 graph[v[3]][v[2]].cost = {2, 0}; 102 graph[v[1]][v[3]].cost = {3, 0}; 103 check(10, {{v[0], {}}, {v[2], v[0]}, {v[3], v[2]}, {v[1], v[3]}}); 104 } 105 106 } // end namespace 107