xref: /llvm-project/mlir/unittests/IR/SymbolTableTest.cpp (revision 479057887fbc8bfef17c86694f78496c54550f21)
1*47905788SIngo Müller //===- SymbolTableTest.cpp - SymbolTable unit tests -----------------------===//
2*47905788SIngo Müller //
3*47905788SIngo Müller // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*47905788SIngo Müller // See https://llvm.org/LICENSE.txt for license information.
5*47905788SIngo Müller // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*47905788SIngo Müller //
7*47905788SIngo Müller //===----------------------------------------------------------------------===//
8*47905788SIngo Müller #include "mlir/IR/SymbolTable.h"
9*47905788SIngo Müller #include "mlir/IR/BuiltinOps.h"
10*47905788SIngo Müller #include "mlir/IR/Verifier.h"
11*47905788SIngo Müller #include "mlir/Interfaces/CallInterfaces.h"
12*47905788SIngo Müller #include "mlir/Interfaces/FunctionInterfaces.h"
13*47905788SIngo Müller #include "mlir/Parser/Parser.h"
14*47905788SIngo Müller 
15*47905788SIngo Müller #include "gtest/gtest.h"
16*47905788SIngo Müller 
17*47905788SIngo Müller using namespace mlir;
18*47905788SIngo Müller 
19*47905788SIngo Müller namespace test {
20*47905788SIngo Müller void registerTestDialect(DialectRegistry &);
21*47905788SIngo Müller } // namespace test
22*47905788SIngo Müller 
23*47905788SIngo Müller class ReplaceAllSymbolUsesTest : public ::testing::Test {
24*47905788SIngo Müller protected:
25*47905788SIngo Müller   using ReplaceFnType = llvm::function_ref<LogicalResult(
26*47905788SIngo Müller       SymbolTable, ModuleOp, Operation *, Operation *)>;
27*47905788SIngo Müller 
SetUp()28*47905788SIngo Müller   void SetUp() override {
29*47905788SIngo Müller     ::test::registerTestDialect(registry);
30*47905788SIngo Müller     context = std::make_unique<MLIRContext>(registry);
31*47905788SIngo Müller   }
32*47905788SIngo Müller 
testReplaceAllSymbolUses(ReplaceFnType replaceFn)33*47905788SIngo Müller   void testReplaceAllSymbolUses(ReplaceFnType replaceFn) {
34*47905788SIngo Müller     // Set up IR and find func ops.
35*47905788SIngo Müller     OwningOpRef<ModuleOp> module =
36*47905788SIngo Müller         parseSourceString<ModuleOp>(kInput, context.get());
37*47905788SIngo Müller     SymbolTable symbolTable(module.get());
38*47905788SIngo Müller     auto opIterator = module->getBody(0)->getOperations().begin();
39*47905788SIngo Müller     auto fooOp = cast<FunctionOpInterface>(opIterator++);
40*47905788SIngo Müller     auto barOp = cast<FunctionOpInterface>(opIterator++);
41*47905788SIngo Müller     ASSERT_EQ(fooOp.getNameAttr(), "foo");
42*47905788SIngo Müller     ASSERT_EQ(barOp.getNameAttr(), "bar");
43*47905788SIngo Müller 
44*47905788SIngo Müller     // Call test function that does symbol replacement.
45*47905788SIngo Müller     LogicalResult res = replaceFn(symbolTable, module.get(), fooOp, barOp);
46*47905788SIngo Müller     ASSERT_TRUE(succeeded(res));
47*47905788SIngo Müller     ASSERT_TRUE(succeeded(verify(module.get())));
48*47905788SIngo Müller 
49*47905788SIngo Müller     // Check that it got renamed.
50*47905788SIngo Müller     bool calleeFound = false;
51*47905788SIngo Müller     fooOp->walk([&](CallOpInterface callOp) {
52*47905788SIngo Müller       StringAttr callee = callOp.getCallableForCallee()
53*47905788SIngo Müller                               .dyn_cast<SymbolRefAttr>()
54*47905788SIngo Müller                               .getLeafReference();
55*47905788SIngo Müller       EXPECT_EQ(callee, "baz");
56*47905788SIngo Müller       calleeFound = true;
57*47905788SIngo Müller     });
58*47905788SIngo Müller     EXPECT_TRUE(calleeFound);
59*47905788SIngo Müller   }
60*47905788SIngo Müller 
61*47905788SIngo Müller   std::unique_ptr<MLIRContext> context;
62*47905788SIngo Müller 
63*47905788SIngo Müller private:
64*47905788SIngo Müller   constexpr static llvm::StringLiteral kInput = R"MLIR(
65*47905788SIngo Müller       module {
66*47905788SIngo Müller         test.conversion_func_op private @foo() {
67*47905788SIngo Müller           "test.conversion_call_op"() { callee=@bar } : () -> ()
68*47905788SIngo Müller           "test.return"() : () -> ()
69*47905788SIngo Müller         }
70*47905788SIngo Müller         test.conversion_func_op private @bar()
71*47905788SIngo Müller       }
72*47905788SIngo Müller     )MLIR";
73*47905788SIngo Müller 
74*47905788SIngo Müller   DialectRegistry registry;
75*47905788SIngo Müller };
76*47905788SIngo Müller 
77*47905788SIngo Müller namespace {
78*47905788SIngo Müller 
TEST_F(ReplaceAllSymbolUsesTest,OperationInModuleOp)79*47905788SIngo Müller TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleOp) {
80*47905788SIngo Müller   // Symbol as `Operation *`, rename within module.
81*47905788SIngo Müller   testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
82*47905788SIngo Müller                                auto barOp) -> LogicalResult {
83*47905788SIngo Müller     return symbolTable.replaceAllSymbolUses(
84*47905788SIngo Müller         barOp, StringAttr::get(context.get(), "baz"), module);
85*47905788SIngo Müller   });
86*47905788SIngo Müller }
87*47905788SIngo Müller 
TEST_F(ReplaceAllSymbolUsesTest,StringAttrInModuleOp)88*47905788SIngo Müller TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleOp) {
89*47905788SIngo Müller   // Symbol as `StringAttr`, rename within module.
90*47905788SIngo Müller   testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
91*47905788SIngo Müller                                auto barOp) -> LogicalResult {
92*47905788SIngo Müller     return symbolTable.replaceAllSymbolUses(
93*47905788SIngo Müller         StringAttr::get(context.get(), "bar"),
94*47905788SIngo Müller         StringAttr::get(context.get(), "baz"), module);
95*47905788SIngo Müller   });
96*47905788SIngo Müller }
97*47905788SIngo Müller 
TEST_F(ReplaceAllSymbolUsesTest,OperationInModuleBody)98*47905788SIngo Müller TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleBody) {
99*47905788SIngo Müller   // Symbol as `Operation *`, rename within module body.
100*47905788SIngo Müller   testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
101*47905788SIngo Müller                                auto barOp) -> LogicalResult {
102*47905788SIngo Müller     return symbolTable.replaceAllSymbolUses(
103*47905788SIngo Müller         barOp, StringAttr::get(context.get(), "baz"), &module->getRegion(0));
104*47905788SIngo Müller   });
105*47905788SIngo Müller }
106*47905788SIngo Müller 
TEST_F(ReplaceAllSymbolUsesTest,StringAttrInModuleBody)107*47905788SIngo Müller TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleBody) {
108*47905788SIngo Müller   // Symbol as `StringAttr`, rename within module body.
109*47905788SIngo Müller   testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
110*47905788SIngo Müller                                auto barOp) -> LogicalResult {
111*47905788SIngo Müller     return symbolTable.replaceAllSymbolUses(
112*47905788SIngo Müller         StringAttr::get(context.get(), "bar"),
113*47905788SIngo Müller         StringAttr::get(context.get(), "baz"), &module->getRegion(0));
114*47905788SIngo Müller   });
115*47905788SIngo Müller }
116*47905788SIngo Müller 
TEST_F(ReplaceAllSymbolUsesTest,OperationInFuncOp)117*47905788SIngo Müller TEST_F(ReplaceAllSymbolUsesTest, OperationInFuncOp) {
118*47905788SIngo Müller   // Symbol as `Operation *`, rename within function.
119*47905788SIngo Müller   testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
120*47905788SIngo Müller                                auto barOp) -> LogicalResult {
121*47905788SIngo Müller     return symbolTable.replaceAllSymbolUses(
122*47905788SIngo Müller         barOp, StringAttr::get(context.get(), "baz"), fooOp);
123*47905788SIngo Müller   });
124*47905788SIngo Müller }
125*47905788SIngo Müller 
TEST_F(ReplaceAllSymbolUsesTest,StringAttrInFuncOp)126*47905788SIngo Müller TEST_F(ReplaceAllSymbolUsesTest, StringAttrInFuncOp) {
127*47905788SIngo Müller   // Symbol as `StringAttr`, rename within function.
128*47905788SIngo Müller   testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
129*47905788SIngo Müller                                auto barOp) -> LogicalResult {
130*47905788SIngo Müller     return symbolTable.replaceAllSymbolUses(
131*47905788SIngo Müller         StringAttr::get(context.get(), "bar"),
132*47905788SIngo Müller         StringAttr::get(context.get(), "baz"), fooOp);
133*47905788SIngo Müller   });
134*47905788SIngo Müller }
135*47905788SIngo Müller 
136*47905788SIngo Müller } // namespace
137