1*47905788SIngo Müller //===- SymbolTableTest.cpp - SymbolTable unit tests -----------------------===//
2*47905788SIngo Müller //
3*47905788SIngo Müller // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*47905788SIngo Müller // See https://llvm.org/LICENSE.txt for license information.
5*47905788SIngo Müller // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*47905788SIngo Müller //
7*47905788SIngo Müller //===----------------------------------------------------------------------===//
8*47905788SIngo Müller #include "mlir/IR/SymbolTable.h"
9*47905788SIngo Müller #include "mlir/IR/BuiltinOps.h"
10*47905788SIngo Müller #include "mlir/IR/Verifier.h"
11*47905788SIngo Müller #include "mlir/Interfaces/CallInterfaces.h"
12*47905788SIngo Müller #include "mlir/Interfaces/FunctionInterfaces.h"
13*47905788SIngo Müller #include "mlir/Parser/Parser.h"
14*47905788SIngo Müller
15*47905788SIngo Müller #include "gtest/gtest.h"
16*47905788SIngo Müller
17*47905788SIngo Müller using namespace mlir;
18*47905788SIngo Müller
19*47905788SIngo Müller namespace test {
20*47905788SIngo Müller void registerTestDialect(DialectRegistry &);
21*47905788SIngo Müller } // namespace test
22*47905788SIngo Müller
23*47905788SIngo Müller class ReplaceAllSymbolUsesTest : public ::testing::Test {
24*47905788SIngo Müller protected:
25*47905788SIngo Müller using ReplaceFnType = llvm::function_ref<LogicalResult(
26*47905788SIngo Müller SymbolTable, ModuleOp, Operation *, Operation *)>;
27*47905788SIngo Müller
SetUp()28*47905788SIngo Müller void SetUp() override {
29*47905788SIngo Müller ::test::registerTestDialect(registry);
30*47905788SIngo Müller context = std::make_unique<MLIRContext>(registry);
31*47905788SIngo Müller }
32*47905788SIngo Müller
testReplaceAllSymbolUses(ReplaceFnType replaceFn)33*47905788SIngo Müller void testReplaceAllSymbolUses(ReplaceFnType replaceFn) {
34*47905788SIngo Müller // Set up IR and find func ops.
35*47905788SIngo Müller OwningOpRef<ModuleOp> module =
36*47905788SIngo Müller parseSourceString<ModuleOp>(kInput, context.get());
37*47905788SIngo Müller SymbolTable symbolTable(module.get());
38*47905788SIngo Müller auto opIterator = module->getBody(0)->getOperations().begin();
39*47905788SIngo Müller auto fooOp = cast<FunctionOpInterface>(opIterator++);
40*47905788SIngo Müller auto barOp = cast<FunctionOpInterface>(opIterator++);
41*47905788SIngo Müller ASSERT_EQ(fooOp.getNameAttr(), "foo");
42*47905788SIngo Müller ASSERT_EQ(barOp.getNameAttr(), "bar");
43*47905788SIngo Müller
44*47905788SIngo Müller // Call test function that does symbol replacement.
45*47905788SIngo Müller LogicalResult res = replaceFn(symbolTable, module.get(), fooOp, barOp);
46*47905788SIngo Müller ASSERT_TRUE(succeeded(res));
47*47905788SIngo Müller ASSERT_TRUE(succeeded(verify(module.get())));
48*47905788SIngo Müller
49*47905788SIngo Müller // Check that it got renamed.
50*47905788SIngo Müller bool calleeFound = false;
51*47905788SIngo Müller fooOp->walk([&](CallOpInterface callOp) {
52*47905788SIngo Müller StringAttr callee = callOp.getCallableForCallee()
53*47905788SIngo Müller .dyn_cast<SymbolRefAttr>()
54*47905788SIngo Müller .getLeafReference();
55*47905788SIngo Müller EXPECT_EQ(callee, "baz");
56*47905788SIngo Müller calleeFound = true;
57*47905788SIngo Müller });
58*47905788SIngo Müller EXPECT_TRUE(calleeFound);
59*47905788SIngo Müller }
60*47905788SIngo Müller
61*47905788SIngo Müller std::unique_ptr<MLIRContext> context;
62*47905788SIngo Müller
63*47905788SIngo Müller private:
64*47905788SIngo Müller constexpr static llvm::StringLiteral kInput = R"MLIR(
65*47905788SIngo Müller module {
66*47905788SIngo Müller test.conversion_func_op private @foo() {
67*47905788SIngo Müller "test.conversion_call_op"() { callee=@bar } : () -> ()
68*47905788SIngo Müller "test.return"() : () -> ()
69*47905788SIngo Müller }
70*47905788SIngo Müller test.conversion_func_op private @bar()
71*47905788SIngo Müller }
72*47905788SIngo Müller )MLIR";
73*47905788SIngo Müller
74*47905788SIngo Müller DialectRegistry registry;
75*47905788SIngo Müller };
76*47905788SIngo Müller
77*47905788SIngo Müller namespace {
78*47905788SIngo Müller
TEST_F(ReplaceAllSymbolUsesTest,OperationInModuleOp)79*47905788SIngo Müller TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleOp) {
80*47905788SIngo Müller // Symbol as `Operation *`, rename within module.
81*47905788SIngo Müller testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
82*47905788SIngo Müller auto barOp) -> LogicalResult {
83*47905788SIngo Müller return symbolTable.replaceAllSymbolUses(
84*47905788SIngo Müller barOp, StringAttr::get(context.get(), "baz"), module);
85*47905788SIngo Müller });
86*47905788SIngo Müller }
87*47905788SIngo Müller
TEST_F(ReplaceAllSymbolUsesTest,StringAttrInModuleOp)88*47905788SIngo Müller TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleOp) {
89*47905788SIngo Müller // Symbol as `StringAttr`, rename within module.
90*47905788SIngo Müller testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
91*47905788SIngo Müller auto barOp) -> LogicalResult {
92*47905788SIngo Müller return symbolTable.replaceAllSymbolUses(
93*47905788SIngo Müller StringAttr::get(context.get(), "bar"),
94*47905788SIngo Müller StringAttr::get(context.get(), "baz"), module);
95*47905788SIngo Müller });
96*47905788SIngo Müller }
97*47905788SIngo Müller
TEST_F(ReplaceAllSymbolUsesTest,OperationInModuleBody)98*47905788SIngo Müller TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleBody) {
99*47905788SIngo Müller // Symbol as `Operation *`, rename within module body.
100*47905788SIngo Müller testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
101*47905788SIngo Müller auto barOp) -> LogicalResult {
102*47905788SIngo Müller return symbolTable.replaceAllSymbolUses(
103*47905788SIngo Müller barOp, StringAttr::get(context.get(), "baz"), &module->getRegion(0));
104*47905788SIngo Müller });
105*47905788SIngo Müller }
106*47905788SIngo Müller
TEST_F(ReplaceAllSymbolUsesTest,StringAttrInModuleBody)107*47905788SIngo Müller TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleBody) {
108*47905788SIngo Müller // Symbol as `StringAttr`, rename within module body.
109*47905788SIngo Müller testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
110*47905788SIngo Müller auto barOp) -> LogicalResult {
111*47905788SIngo Müller return symbolTable.replaceAllSymbolUses(
112*47905788SIngo Müller StringAttr::get(context.get(), "bar"),
113*47905788SIngo Müller StringAttr::get(context.get(), "baz"), &module->getRegion(0));
114*47905788SIngo Müller });
115*47905788SIngo Müller }
116*47905788SIngo Müller
TEST_F(ReplaceAllSymbolUsesTest,OperationInFuncOp)117*47905788SIngo Müller TEST_F(ReplaceAllSymbolUsesTest, OperationInFuncOp) {
118*47905788SIngo Müller // Symbol as `Operation *`, rename within function.
119*47905788SIngo Müller testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
120*47905788SIngo Müller auto barOp) -> LogicalResult {
121*47905788SIngo Müller return symbolTable.replaceAllSymbolUses(
122*47905788SIngo Müller barOp, StringAttr::get(context.get(), "baz"), fooOp);
123*47905788SIngo Müller });
124*47905788SIngo Müller }
125*47905788SIngo Müller
TEST_F(ReplaceAllSymbolUsesTest,StringAttrInFuncOp)126*47905788SIngo Müller TEST_F(ReplaceAllSymbolUsesTest, StringAttrInFuncOp) {
127*47905788SIngo Müller // Symbol as `StringAttr`, rename within function.
128*47905788SIngo Müller testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
129*47905788SIngo Müller auto barOp) -> LogicalResult {
130*47905788SIngo Müller return symbolTable.replaceAllSymbolUses(
131*47905788SIngo Müller StringAttr::get(context.get(), "bar"),
132*47905788SIngo Müller StringAttr::get(context.get(), "baz"), fooOp);
133*47905788SIngo Müller });
134*47905788SIngo Müller }
135*47905788SIngo Müller
136*47905788SIngo Müller } // namespace
137