1 //=== CompilerGeneratedNames.cpp - convert special symbols in global names ===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Optimizer/Dialect/FIRDialect.h" 10 #include "flang/Optimizer/Dialect/FIROps.h" 11 #include "flang/Optimizer/Dialect/FIROpsSupport.h" 12 #include "flang/Optimizer/Support/InternalNames.h" 13 #include "flang/Optimizer/Transforms/Passes.h" 14 #include "mlir/Dialect/GPU/IR/GPUDialect.h" 15 #include "mlir/IR/Attributes.h" 16 #include "mlir/IR/SymbolTable.h" 17 #include "mlir/Pass/Pass.h" 18 19 namespace fir { 20 #define GEN_PASS_DEF_COMPILERGENERATEDNAMESCONVERSION 21 #include "flang/Optimizer/Transforms/Passes.h.inc" 22 } // namespace fir 23 24 using namespace mlir; 25 26 namespace { 27 28 class CompilerGeneratedNamesConversionPass 29 : public fir::impl::CompilerGeneratedNamesConversionBase< 30 CompilerGeneratedNamesConversionPass> { 31 public: 32 using CompilerGeneratedNamesConversionBase< 33 CompilerGeneratedNamesConversionPass>:: 34 CompilerGeneratedNamesConversionBase; 35 36 mlir::ModuleOp getModule() { return getOperation(); } 37 void runOnOperation() override; 38 }; 39 } // namespace 40 41 void CompilerGeneratedNamesConversionPass::runOnOperation() { 42 auto op = getOperation(); 43 auto *context = &getContext(); 44 45 llvm::DenseMap<mlir::StringAttr, mlir::FlatSymbolRefAttr> remappings; 46 47 auto processOp = [&](mlir::Operation &op) { 48 auto symName = op.getAttrOfType<mlir::StringAttr>( 49 mlir::SymbolTable::getSymbolAttrName()); 50 auto deconstructedName = fir::NameUniquer::deconstruct(symName); 51 if (deconstructedName.first != fir::NameUniquer::NameKind::NOT_UNIQUED && 52 !fir::NameUniquer::isExternalFacingUniquedName(deconstructedName)) { 53 std::string newName = 54 fir::NameUniquer::replaceSpecialSymbols(symName.getValue().str()); 55 if (newName != symName) { 56 auto newAttr = mlir::StringAttr::get(context, newName); 57 mlir::SymbolTable::setSymbolName(&op, newAttr); 58 auto newSymRef = mlir::FlatSymbolRefAttr::get(newAttr); 59 remappings.try_emplace(symName, newSymRef); 60 } 61 } 62 }; 63 for (auto &op : op->getRegion(0).front()) { 64 if (llvm::isa<mlir::func::FuncOp>(op) || llvm::isa<fir::GlobalOp>(op)) 65 processOp(op); 66 else if (auto gpuMod = mlir::dyn_cast<mlir::gpu::GPUModuleOp>(&op)) 67 for (auto &op : gpuMod->getRegion(0).front()) 68 if (llvm::isa<mlir::func::FuncOp>(op) || llvm::isa<fir::GlobalOp>(op) || 69 llvm::isa<mlir::gpu::GPUFuncOp>(op)) 70 processOp(op); 71 } 72 73 if (remappings.empty()) 74 return; 75 76 // Update all uses of the functions and globals that have been renamed. 77 op.walk([&remappings](mlir::Operation *nestedOp) { 78 llvm::SmallVector<std::pair<mlir::StringAttr, mlir::SymbolRefAttr>> updates; 79 for (const mlir::NamedAttribute &attr : nestedOp->getAttrDictionary()) 80 if (auto symRef = llvm::dyn_cast<mlir::SymbolRefAttr>(attr.getValue())) 81 if (auto remap = remappings.find(symRef.getRootReference()); 82 remap != remappings.end()) 83 updates.emplace_back(std::pair<mlir::StringAttr, mlir::SymbolRefAttr>{ 84 attr.getName(), mlir::SymbolRefAttr(remap->second)}); 85 for (auto update : updates) 86 nestedOp->setAttr(update.first, update.second); 87 }); 88 } 89