xref: /llvm-project/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //===- ConstantArgumentGlobalisation.cpp ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Optimizer/Builder/FIRBuilder.h"
10 #include "flang/Optimizer/Dialect/FIRDialect.h"
11 #include "flang/Optimizer/Dialect/FIROps.h"
12 #include "flang/Optimizer/Dialect/FIRType.h"
13 #include "flang/Optimizer/Transforms/Passes.h"
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/IR/Dominance.h"
17 #include "mlir/Pass/Pass.h"
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19 
20 namespace fir {
21 #define GEN_PASS_DEF_CONSTANTARGUMENTGLOBALISATIONOPT
22 #include "flang/Optimizer/Transforms/Passes.h.inc"
23 } // namespace fir
24 
25 #define DEBUG_TYPE "flang-constant-argument-globalisation-opt"
26 
27 namespace {
28 unsigned uniqueLitId = 1;
29 
30 class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
31 protected:
32   const mlir::DominanceInfo &di;
33 
34 public:
35   using OpRewritePattern::OpRewritePattern;
36 
37   CallOpRewriter(mlir::MLIRContext *ctx, const mlir::DominanceInfo &_di)
38       : OpRewritePattern(ctx), di(_di) {}
39 
40   llvm::LogicalResult
41   matchAndRewrite(fir::CallOp callOp,
42                   mlir::PatternRewriter &rewriter) const override {
43     LLVM_DEBUG(llvm::dbgs() << "Processing call op: " << callOp << "\n");
44     auto module = callOp->getParentOfType<mlir::ModuleOp>();
45     bool needUpdate = false;
46     fir::FirOpBuilder builder(rewriter, module);
47     llvm::SmallVector<mlir::Value> newOperands;
48     llvm::SmallVector<std::pair<mlir::Operation *, mlir::Operation *>> allocas;
49     for (const mlir::Value &a : callOp.getArgs()) {
50       auto alloca = mlir::dyn_cast_or_null<fir::AllocaOp>(a.getDefiningOp());
51       // We can convert arguments that are alloca, and that has
52       // the value by reference attribute. All else is just added
53       // to the argument list.
54       if (!alloca || !alloca->hasAttr(fir::getAdaptToByRefAttrName())) {
55         newOperands.push_back(a);
56         continue;
57       }
58 
59       mlir::Type varTy = alloca.getInType();
60       assert(!fir::hasDynamicSize(varTy) &&
61              "only expect statically sized scalars to be by value");
62 
63       // Find immediate store with const argument
64       mlir::Operation *store = nullptr;
65       for (mlir::Operation *s : alloca->getUsers()) {
66         if (mlir::isa<fir::StoreOp>(s) && di.dominates(s, callOp)) {
67           // We can only deal with ONE store - if already found one,
68           // set to nullptr and exit the loop.
69           if (store) {
70             store = nullptr;
71             break;
72           }
73           store = s;
74         }
75       }
76 
77       // If we didn't find any store, or multiple stores, add argument as is
78       // and move on.
79       if (!store) {
80         newOperands.push_back(a);
81         continue;
82       }
83 
84       LLVM_DEBUG(llvm::dbgs() << " found store " << *store << "\n");
85 
86       mlir::Operation *definingOp = store->getOperand(0).getDefiningOp();
87       // If not a constant, add to operands and move on.
88       if (!mlir::isa<mlir::arith::ConstantOp>(definingOp)) {
89         // Unable to remove alloca arg
90         newOperands.push_back(a);
91         continue;
92       }
93 
94       LLVM_DEBUG(llvm::dbgs() << " found define " << *definingOp << "\n");
95 
96       std::string globalName =
97           "_global_const_." + std::to_string(uniqueLitId++);
98       assert(!builder.getNamedGlobal(globalName) &&
99              "We should have a unique name here");
100 
101       if (llvm::none_of(allocas,
102                         [alloca](auto x) { return x.first == alloca; })) {
103         allocas.push_back(std::make_pair(alloca, store));
104       }
105 
106       auto loc = callOp.getLoc();
107       fir::GlobalOp global = builder.createGlobalConstant(
108           loc, varTy, globalName,
109           [&](fir::FirOpBuilder &builder) {
110             mlir::Operation *cln = definingOp->clone();
111             builder.insert(cln);
112             mlir::Value val =
113                 builder.createConvert(loc, varTy, cln->getResult(0));
114             builder.create<fir::HasValueOp>(loc, val);
115           },
116           builder.createInternalLinkage());
117       mlir::Value addr = builder.create<fir::AddrOfOp>(loc, global.resultType(),
118                                                        global.getSymbol());
119       newOperands.push_back(addr);
120       needUpdate = true;
121     }
122 
123     if (needUpdate) {
124       auto loc = callOp.getLoc();
125       llvm::SmallVector<mlir::Type> newResultTypes;
126       newResultTypes.append(callOp.getResultTypes().begin(),
127                             callOp.getResultTypes().end());
128       fir::CallOp newOp = builder.create<fir::CallOp>(
129           loc,
130           callOp.getCallee().has_value() ? callOp.getCallee().value()
131                                          : mlir::SymbolRefAttr{},
132           newResultTypes, newOperands);
133       // Copy all the attributes from the old to new op.
134       newOp->setAttrs(callOp->getAttrs());
135       rewriter.replaceOp(callOp, newOp);
136 
137       for (auto a : allocas) {
138         if (a.first->hasOneUse()) {
139           // If the alloca is only used for a store and the call operand, the
140           // store is no longer required.
141           rewriter.eraseOp(a.second);
142           rewriter.eraseOp(a.first);
143         }
144       }
145       LLVM_DEBUG(llvm::dbgs() << "global constant for " << callOp << " as "
146                               << newOp << '\n');
147       return mlir::success();
148     }
149 
150     // Failure here just means "we couldn't do the conversion", which is
151     // perfectly acceptable to the upper layers of this function.
152     return mlir::failure();
153   }
154 };
155 
156 // this pass attempts to convert immediate scalar literals in function calls
157 // to global constants to allow transformations such as Dead Argument
158 // Elimination
159 class ConstantArgumentGlobalisationOpt
160     : public fir::impl::ConstantArgumentGlobalisationOptBase<
161           ConstantArgumentGlobalisationOpt> {
162 public:
163   ConstantArgumentGlobalisationOpt() = default;
164 
165   void runOnOperation() override {
166     mlir::ModuleOp mod = getOperation();
167     mlir::DominanceInfo *di = &getAnalysis<mlir::DominanceInfo>();
168     auto *context = &getContext();
169     mlir::RewritePatternSet patterns(context);
170     mlir::GreedyRewriteConfig config;
171     config.enableRegionSimplification =
172         mlir::GreedySimplifyRegionLevel::Disabled;
173     config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
174 
175     patterns.insert<CallOpRewriter>(context, *di);
176     if (mlir::failed(
177             mlir::applyPatternsGreedily(mod, std::move(patterns), config))) {
178       mlir::emitError(mod.getLoc(),
179                       "error in constant globalisation optimization\n");
180       signalPassFailure();
181     }
182   }
183 };
184 } // namespace
185