xref: /llvm-project/flang/lib/Optimizer/Transforms/CompilerGeneratedNames.cpp (revision e93d226664d7012d1bb017f0cda24ad1b75f37fc)
1 //=== CompilerGeneratedNames.cpp - convert special symbols in global names ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Optimizer/Dialect/FIRDialect.h"
10 #include "flang/Optimizer/Dialect/FIROps.h"
11 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
12 #include "flang/Optimizer/Support/InternalNames.h"
13 #include "flang/Optimizer/Transforms/Passes.h"
14 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
15 #include "mlir/IR/Attributes.h"
16 #include "mlir/IR/SymbolTable.h"
17 #include "mlir/Pass/Pass.h"
18 
19 namespace fir {
20 #define GEN_PASS_DEF_COMPILERGENERATEDNAMESCONVERSION
21 #include "flang/Optimizer/Transforms/Passes.h.inc"
22 } // namespace fir
23 
24 using namespace mlir;
25 
26 namespace {
27 
28 class CompilerGeneratedNamesConversionPass
29     : public fir::impl::CompilerGeneratedNamesConversionBase<
30           CompilerGeneratedNamesConversionPass> {
31 public:
32   using CompilerGeneratedNamesConversionBase<
33       CompilerGeneratedNamesConversionPass>::
34       CompilerGeneratedNamesConversionBase;
35 
36   mlir::ModuleOp getModule() { return getOperation(); }
37   void runOnOperation() override;
38 };
39 } // namespace
40 
41 void CompilerGeneratedNamesConversionPass::runOnOperation() {
42   auto op = getOperation();
43   auto *context = &getContext();
44 
45   llvm::DenseMap<mlir::StringAttr, mlir::FlatSymbolRefAttr> remappings;
46 
47   auto processOp = [&](mlir::Operation &op) {
48     auto symName = op.getAttrOfType<mlir::StringAttr>(
49         mlir::SymbolTable::getSymbolAttrName());
50     auto deconstructedName = fir::NameUniquer::deconstruct(symName);
51     if (deconstructedName.first != fir::NameUniquer::NameKind::NOT_UNIQUED &&
52         !fir::NameUniquer::isExternalFacingUniquedName(deconstructedName)) {
53       std::string newName =
54           fir::NameUniquer::replaceSpecialSymbols(symName.getValue().str());
55       if (newName != symName) {
56         auto newAttr = mlir::StringAttr::get(context, newName);
57         mlir::SymbolTable::setSymbolName(&op, newAttr);
58         auto newSymRef = mlir::FlatSymbolRefAttr::get(newAttr);
59         remappings.try_emplace(symName, newSymRef);
60       }
61     }
62   };
63   for (auto &op : op->getRegion(0).front()) {
64     if (llvm::isa<mlir::func::FuncOp>(op) || llvm::isa<fir::GlobalOp>(op))
65       processOp(op);
66     else if (auto gpuMod = mlir::dyn_cast<mlir::gpu::GPUModuleOp>(&op))
67       for (auto &op : gpuMod->getRegion(0).front())
68         if (llvm::isa<mlir::func::FuncOp>(op) || llvm::isa<fir::GlobalOp>(op) ||
69             llvm::isa<mlir::gpu::GPUFuncOp>(op))
70           processOp(op);
71   }
72 
73   if (remappings.empty())
74     return;
75 
76   // Update all uses of the functions and globals that have been renamed.
77   op.walk([&remappings](mlir::Operation *nestedOp) {
78     llvm::SmallVector<std::pair<mlir::StringAttr, mlir::SymbolRefAttr>> updates;
79     for (const mlir::NamedAttribute &attr : nestedOp->getAttrDictionary())
80       if (auto symRef = llvm::dyn_cast<mlir::SymbolRefAttr>(attr.getValue()))
81         if (auto remap = remappings.find(symRef.getRootReference());
82             remap != remappings.end())
83           updates.emplace_back(std::pair<mlir::StringAttr, mlir::SymbolRefAttr>{
84               attr.getName(), mlir::SymbolRefAttr(remap->second)});
85     for (auto update : updates)
86       nestedOp->setAttr(update.first, update.second);
87   });
88 }
89