xref: /llvm-project/mlir/unittests/Conversion/PDLToPDLInterp/RootOrderingTest.cpp (revision abc362a1077b9cb4186e3e53a616589c7fed4387)
1 //===- RootOrderingTest.cpp - unit tests for optimal branching ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 // Exceptions. See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "../lib/Conversion/PDLToPDLInterp/RootOrdering.h"
10 #include "mlir/Dialect/Arith/IR/Arith.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/MLIRContext.h"
13 #include "gtest/gtest.h"
14 
15 using namespace mlir;
16 using namespace mlir::arith;
17 using namespace mlir::pdl_to_pdl_interp;
18 
19 namespace {
20 
21 //===----------------------------------------------------------------------===//
22 // Test Fixture
23 //===----------------------------------------------------------------------===//
24 
25 /// The test fixture for constructing root ordering tests and verifying results.
26 /// This fixture constructs the test values v. The test populates the graph
27 /// with the desired costs and then calls check(), passing the expected optimal
28 /// cost and the list of edges in the preorder traversal of the optimal
29 /// branching.
30 class RootOrderingTest : public ::testing::Test {
31 protected:
RootOrderingTest()32   RootOrderingTest() {
33     context.loadDialect<ArithDialect>();
34     createValues();
35   }
36 
37   /// Creates the test values. These values simply act as vertices / vertex IDs
38   /// in the cost graph, rather than being a part of an IR.
createValues()39   void createValues() {
40     OpBuilder builder(&context);
41     builder.setInsertionPointToStart(&block);
42     for (int i = 0; i < 4; ++i)
43       // Ops will be deleted when `block` is destroyed.
44       v[i] = builder.create<ConstantIntOp>(builder.getUnknownLoc(), i, 32);
45   }
46 
47   /// Checks that optimal branching on graph has the given cost and
48   /// its preorder traversal results in the specified edges.
check(unsigned cost,const OptimalBranching::EdgeList & edges)49   void check(unsigned cost, const OptimalBranching::EdgeList &edges) {
50     OptimalBranching opt(graph, v[0]);
51     EXPECT_EQ(opt.solve(), cost);
52     EXPECT_EQ(opt.preOrderTraversal({v, v + edges.size()}), edges);
53     for (std::pair<Value, Value> edge : edges)
54       EXPECT_EQ(opt.getRootOrderingParents().lookup(edge.first), edge.second);
55   }
56 
57 protected:
58   /// The context for creating the values.
59   MLIRContext context;
60 
61   /// Block holding all the operations.
62   Block block;
63 
64   /// Values used in the graph definition. We always use leading `n` values.
65   Value v[4];
66 
67   /// The graph being tested on.
68   RootOrderingGraph graph;
69 };
70 
71 //===----------------------------------------------------------------------===//
72 // Simple 3-node graphs
73 //===----------------------------------------------------------------------===//
74 
/// Both non-root vertices are cheapest to attach directly to the root. The
/// optimal branching is therefore the star rooted at `v[0]`, with total cost
/// 1 + 1 = 2.
TEST_F(RootOrderingTest, simpleA) {
  Value root = v[0];
  Value a = v[1];
  Value b = v[2];

  // Entries are graph[child][parent]; the expected total is the sum of the
  // first cost components of the chosen edges.
  graph[a][root].cost = {1, 10};
  graph[b][root].cost = {1, 11};
  graph[a][b].cost = {2, 12};
  graph[b][a].cost = {2, 13};

  const OptimalBranching::EdgeList expected = {
      {root, {}}, {a, root}, {b, root}};
  check(2, expected);
}
82 
/// `v[2]` is cheaper to reach through `v[1]` (cost 1) than directly from the
/// root (cost 2). The result is the chain v[0] -> v[1] -> v[2], with total
/// cost 1 + 1 = 2.
TEST_F(RootOrderingTest, simpleB) {
  Value root = v[0];
  Value a = v[1];
  Value b = v[2];

  // Entries are graph[child][parent].
  graph[a][root].cost = {1, 10};
  graph[b][root].cost = {2, 11};
  graph[a][b].cost = {1, 12};
  graph[b][a].cost = {1, 13};

  const OptimalBranching::EdgeList expected = {
      {root, {}}, {a, root}, {b, a}};
  check(2, expected);
}
90 
/// Two branchings cost 3 here:
/// - v[0] -> v[1] -> v[2];
/// - v[0] -> v[2] -> v[1].
///
/// The test expects the one that attaches `v[1]` to the root. That choice is
/// presumably decided by the second cost component (10 < 11); confirm against
/// the tie-breaking rule in RootOrdering.
TEST_F(RootOrderingTest, simpleC) {
  Value root = v[0];
  Value a = v[1];
  Value b = v[2];

  // Entries are graph[child][parent].
  graph[a][root].cost = {2, 10};
  graph[b][root].cost = {2, 11};
  graph[a][b].cost = {1, 12};
  graph[b][a].cost = {1, 13};

  const OptimalBranching::EdgeList expected = {
      {root, {}}, {a, root}, {b, a}};
  check(3, expected);
}
98 
99 //===----------------------------------------------------------------------===//
100 // Graph for testing contraction
101 //===----------------------------------------------------------------------===//
102 
/// The cheapest parent of each non-root vertex forms a cycle:
/// - v[1] <- v[3] (cost 3);
/// - v[2] <- v[1] (cost 1);
/// - v[3] <- v[2] (cost 2).
///
/// The cycle must be broken by entering it from the root:
/// - entering at v[2] costs 5 + 2 + 3 = 10;
/// - entering at v[1] would cost 10 + 1 + 2 = 13.
TEST_F(RootOrderingTest, contraction) {
  Value root = v[0];
  Value a = v[1];
  Value b = v[2];
  Value c = v[3];

  // Entries are graph[child][parent].
  graph[a][root].cost = {10, 0};
  graph[b][root].cost = {5, 0};
  graph[b][a].cost = {1, 0};
  graph[c][b].cost = {2, 0};
  graph[a][c].cost = {3, 0};

  const OptimalBranching::EdgeList expected = {
      {root, {}}, {b, root}, {c, b}, {a, c}};
  check(10, expected);
}
111 
112 } // namespace
113