xref: /llvm-project/mlir/unittests/Conversion/PDLToPDLInterp/RootOrderingTest.cpp (revision 6df7cc7f47d280d550f41fc167bdd75fea726a06)
//===- RootOrderingTest.cpp - unit tests for optimal branching ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
9 #include "../lib/Conversion/PDLToPDLInterp/RootOrdering.h"
10 #include "mlir/Dialect/StandardOps/IR/Ops.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/MLIRContext.h"
13 #include "gtest/gtest.h"
14 
15 using namespace mlir;
16 using namespace mlir::pdl_to_pdl_interp;
17 
18 namespace {
19 
20 //===----------------------------------------------------------------------===//
21 // Test Fixture
22 //===----------------------------------------------------------------------===//
23 
24 /// The test fixture for constructing root ordering tests and verifying results.
25 /// This fixture constructs the test values v. The test populates the graph
26 /// with the desired costs and then calls check(), passing the expeted optimal
27 /// cost and the list of edges in the preorder traversal of the optimal
28 /// branching.
29 class RootOrderingTest : public ::testing::Test {
30 protected:
31   RootOrderingTest() {
32     context.loadDialect<StandardOpsDialect>();
33     createValues();
34   }
35 
36   /// Creates the test values.
37   void createValues() {
38     OpBuilder builder(&context);
39     for (int i = 0; i < 4; ++i)
40       v[i] = builder.create<ConstantOp>(builder.getUnknownLoc(),
41                                         builder.getI32IntegerAttr(i));
42   }
43 
44   /// Checks that optimal branching on graph has the given cost and
45   /// its preorder traversal results in the specified edges.
46   void check(unsigned cost, OptimalBranching::EdgeList edges) {
47     OptimalBranching opt(graph, v[0]);
48     EXPECT_EQ(opt.solve(), cost);
49     EXPECT_EQ(opt.preOrderTraversal({v, v + edges.size()}), edges);
50     for (std::pair<Value, Value> edge : edges)
51       EXPECT_EQ(opt.getRootOrderingParents().lookup(edge.first), edge.second);
52   }
53 
54 protected:
55   /// The context for creating the values.
56   MLIRContext context;
57 
58   /// Values used in the graph definition. We always use leading `n` values.
59   Value v[4];
60 
61   /// The graph being tested on.
62   RootOrderingGraph graph;
63 };
64 
65 //===----------------------------------------------------------------------===//
66 // Simple 3-node graphs
67 //===----------------------------------------------------------------------===//
68 
69 TEST_F(RootOrderingTest, simpleA) {
70   graph[v[1]][v[0]].cost = {1, 10};
71   graph[v[2]][v[0]].cost = {1, 11};
72   graph[v[1]][v[2]].cost = {2, 12};
73   graph[v[2]][v[1]].cost = {2, 13};
74   check(2, {{v[0], {}}, {v[1], v[0]}, {v[2], v[0]}});
75 }
76 
77 TEST_F(RootOrderingTest, simpleB) {
78   graph[v[1]][v[0]].cost = {1, 10};
79   graph[v[2]][v[0]].cost = {2, 11};
80   graph[v[1]][v[2]].cost = {1, 12};
81   graph[v[2]][v[1]].cost = {1, 13};
82   check(2, {{v[0], {}}, {v[1], v[0]}, {v[2], v[1]}});
83 }
84 
85 TEST_F(RootOrderingTest, simpleC) {
86   graph[v[1]][v[0]].cost = {2, 10};
87   graph[v[2]][v[0]].cost = {2, 11};
88   graph[v[1]][v[2]].cost = {1, 12};
89   graph[v[2]][v[1]].cost = {1, 13};
90   check(3, {{v[0], {}}, {v[1], v[0]}, {v[2], v[1]}});
91 }
92 
93 //===----------------------------------------------------------------------===//
94 // Graph for testing contraction
95 //===----------------------------------------------------------------------===//
96 
97 TEST_F(RootOrderingTest, contraction) {
98   graph[v[1]][v[0]].cost = {10, 0};
99   graph[v[2]][v[0]].cost = {5, 0};
100   graph[v[2]][v[1]].cost = {1, 0};
101   graph[v[3]][v[2]].cost = {2, 0};
102   graph[v[1]][v[3]].cost = {3, 0};
103   check(10, {{v[0], {}}, {v[2], v[0]}, {v[3], v[2]}, {v[1], v[3]}});
104 }
105 
106 } // end namespace
107