xref: /llvm-project/mlir/unittests/IR/SymbolTableTest.cpp (revision 479057887fbc8bfef17c86694f78496c54550f21)
1 //===- SymbolTableTest.cpp - SymbolTable unit tests -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 #include "mlir/IR/SymbolTable.h"
9 #include "mlir/IR/BuiltinOps.h"
10 #include "mlir/IR/Verifier.h"
11 #include "mlir/Interfaces/CallInterfaces.h"
12 #include "mlir/Interfaces/FunctionInterfaces.h"
13 #include "mlir/Parser/Parser.h"
14 
15 #include "gtest/gtest.h"
16 
17 using namespace mlir;
18 
19 namespace test {
20 void registerTestDialect(DialectRegistry &);
21 } // namespace test
22 
23 class ReplaceAllSymbolUsesTest : public ::testing::Test {
24 protected:
25   using ReplaceFnType = llvm::function_ref<LogicalResult(
26       SymbolTable, ModuleOp, Operation *, Operation *)>;
27 
SetUp()28   void SetUp() override {
29     ::test::registerTestDialect(registry);
30     context = std::make_unique<MLIRContext>(registry);
31   }
32 
testReplaceAllSymbolUses(ReplaceFnType replaceFn)33   void testReplaceAllSymbolUses(ReplaceFnType replaceFn) {
34     // Set up IR and find func ops.
35     OwningOpRef<ModuleOp> module =
36         parseSourceString<ModuleOp>(kInput, context.get());
37     SymbolTable symbolTable(module.get());
38     auto opIterator = module->getBody(0)->getOperations().begin();
39     auto fooOp = cast<FunctionOpInterface>(opIterator++);
40     auto barOp = cast<FunctionOpInterface>(opIterator++);
41     ASSERT_EQ(fooOp.getNameAttr(), "foo");
42     ASSERT_EQ(barOp.getNameAttr(), "bar");
43 
44     // Call test function that does symbol replacement.
45     LogicalResult res = replaceFn(symbolTable, module.get(), fooOp, barOp);
46     ASSERT_TRUE(succeeded(res));
47     ASSERT_TRUE(succeeded(verify(module.get())));
48 
49     // Check that it got renamed.
50     bool calleeFound = false;
51     fooOp->walk([&](CallOpInterface callOp) {
52       StringAttr callee = callOp.getCallableForCallee()
53                               .dyn_cast<SymbolRefAttr>()
54                               .getLeafReference();
55       EXPECT_EQ(callee, "baz");
56       calleeFound = true;
57     });
58     EXPECT_TRUE(calleeFound);
59   }
60 
61   std::unique_ptr<MLIRContext> context;
62 
63 private:
64   constexpr static llvm::StringLiteral kInput = R"MLIR(
65       module {
66         test.conversion_func_op private @foo() {
67           "test.conversion_call_op"() { callee=@bar } : () -> ()
68           "test.return"() : () -> ()
69         }
70         test.conversion_func_op private @bar()
71       }
72     )MLIR";
73 
74   DialectRegistry registry;
75 };
76 
77 namespace {
78 
TEST_F(ReplaceAllSymbolUsesTest,OperationInModuleOp)79 TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleOp) {
80   // Symbol as `Operation *`, rename within module.
81   testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
82                                auto barOp) -> LogicalResult {
83     return symbolTable.replaceAllSymbolUses(
84         barOp, StringAttr::get(context.get(), "baz"), module);
85   });
86 }
87 
TEST_F(ReplaceAllSymbolUsesTest,StringAttrInModuleOp)88 TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleOp) {
89   // Symbol as `StringAttr`, rename within module.
90   testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
91                                auto barOp) -> LogicalResult {
92     return symbolTable.replaceAllSymbolUses(
93         StringAttr::get(context.get(), "bar"),
94         StringAttr::get(context.get(), "baz"), module);
95   });
96 }
97 
TEST_F(ReplaceAllSymbolUsesTest,OperationInModuleBody)98 TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleBody) {
99   // Symbol as `Operation *`, rename within module body.
100   testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
101                                auto barOp) -> LogicalResult {
102     return symbolTable.replaceAllSymbolUses(
103         barOp, StringAttr::get(context.get(), "baz"), &module->getRegion(0));
104   });
105 }
106 
TEST_F(ReplaceAllSymbolUsesTest,StringAttrInModuleBody)107 TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleBody) {
108   // Symbol as `StringAttr`, rename within module body.
109   testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
110                                auto barOp) -> LogicalResult {
111     return symbolTable.replaceAllSymbolUses(
112         StringAttr::get(context.get(), "bar"),
113         StringAttr::get(context.get(), "baz"), &module->getRegion(0));
114   });
115 }
116 
TEST_F(ReplaceAllSymbolUsesTest,OperationInFuncOp)117 TEST_F(ReplaceAllSymbolUsesTest, OperationInFuncOp) {
118   // Symbol as `Operation *`, rename within function.
119   testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
120                                auto barOp) -> LogicalResult {
121     return symbolTable.replaceAllSymbolUses(
122         barOp, StringAttr::get(context.get(), "baz"), fooOp);
123   });
124 }
125 
TEST_F(ReplaceAllSymbolUsesTest,StringAttrInFuncOp)126 TEST_F(ReplaceAllSymbolUsesTest, StringAttrInFuncOp) {
127   // Symbol as `StringAttr`, rename within function.
128   testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
129                                auto barOp) -> LogicalResult {
130     return symbolTable.replaceAllSymbolUses(
131         StringAttr::get(context.get(), "bar"),
132         StringAttr::get(context.get(), "baz"), fooOp);
133   });
134 }
135 
136 } // namespace
137