1 //===- ConstantArgumentGlobalisation.cpp ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Optimizer/Builder/FIRBuilder.h" 10 #include "flang/Optimizer/Dialect/FIRDialect.h" 11 #include "flang/Optimizer/Dialect/FIROps.h" 12 #include "flang/Optimizer/Dialect/FIRType.h" 13 #include "flang/Optimizer/Transforms/Passes.h" 14 #include "mlir/Dialect/Func/IR/FuncOps.h" 15 #include "mlir/IR/Diagnostics.h" 16 #include "mlir/IR/Dominance.h" 17 #include "mlir/Pass/Pass.h" 18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 19 20 namespace fir { 21 #define GEN_PASS_DEF_CONSTANTARGUMENTGLOBALISATIONOPT 22 #include "flang/Optimizer/Transforms/Passes.h.inc" 23 } // namespace fir 24 25 #define DEBUG_TYPE "flang-constant-argument-globalisation-opt" 26 27 namespace { 28 unsigned uniqueLitId = 1; 29 30 class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> { 31 protected: 32 const mlir::DominanceInfo &di; 33 34 public: 35 using OpRewritePattern::OpRewritePattern; 36 37 CallOpRewriter(mlir::MLIRContext *ctx, const mlir::DominanceInfo &_di) 38 : OpRewritePattern(ctx), di(_di) {} 39 40 llvm::LogicalResult 41 matchAndRewrite(fir::CallOp callOp, 42 mlir::PatternRewriter &rewriter) const override { 43 LLVM_DEBUG(llvm::dbgs() << "Processing call op: " << callOp << "\n"); 44 auto module = callOp->getParentOfType<mlir::ModuleOp>(); 45 bool needUpdate = false; 46 fir::FirOpBuilder builder(rewriter, module); 47 llvm::SmallVector<mlir::Value> newOperands; 48 llvm::SmallVector<std::pair<mlir::Operation *, mlir::Operation *>> allocas; 49 for (const mlir::Value &a : callOp.getArgs()) { 50 auto alloca = mlir::dyn_cast_or_null<fir::AllocaOp>(a.getDefiningOp()); 51 // We can convert arguments that are alloca, and that has 52 // the value by reference attribute. All else is just added 53 // to the argument list. 54 if (!alloca || !alloca->hasAttr(fir::getAdaptToByRefAttrName())) { 55 newOperands.push_back(a); 56 continue; 57 } 58 59 mlir::Type varTy = alloca.getInType(); 60 assert(!fir::hasDynamicSize(varTy) && 61 "only expect statically sized scalars to be by value"); 62 63 // Find immediate store with const argument 64 mlir::Operation *store = nullptr; 65 for (mlir::Operation *s : alloca->getUsers()) { 66 if (mlir::isa<fir::StoreOp>(s) && di.dominates(s, callOp)) { 67 // We can only deal with ONE store - if already found one, 68 // set to nullptr and exit the loop. 69 if (store) { 70 store = nullptr; 71 break; 72 } 73 store = s; 74 } 75 } 76 77 // If we didn't find any store, or multiple stores, add argument as is 78 // and move on. 79 if (!store) { 80 newOperands.push_back(a); 81 continue; 82 } 83 84 LLVM_DEBUG(llvm::dbgs() << " found store " << *store << "\n"); 85 86 mlir::Operation *definingOp = store->getOperand(0).getDefiningOp(); 87 // If not a constant, add to operands and move on. 88 if (!mlir::isa<mlir::arith::ConstantOp>(definingOp)) { 89 // Unable to remove alloca arg 90 newOperands.push_back(a); 91 continue; 92 } 93 94 LLVM_DEBUG(llvm::dbgs() << " found define " << *definingOp << "\n"); 95 96 std::string globalName = 97 "_global_const_." + std::to_string(uniqueLitId++); 98 assert(!builder.getNamedGlobal(globalName) && 99 "We should have a unique name here"); 100 101 if (llvm::none_of(allocas, 102 [alloca](auto x) { return x.first == alloca; })) { 103 allocas.push_back(std::make_pair(alloca, store)); 104 } 105 106 auto loc = callOp.getLoc(); 107 fir::GlobalOp global = builder.createGlobalConstant( 108 loc, varTy, globalName, 109 [&](fir::FirOpBuilder &builder) { 110 mlir::Operation *cln = definingOp->clone(); 111 builder.insert(cln); 112 mlir::Value val = 113 builder.createConvert(loc, varTy, cln->getResult(0)); 114 builder.create<fir::HasValueOp>(loc, val); 115 }, 116 builder.createInternalLinkage()); 117 mlir::Value addr = builder.create<fir::AddrOfOp>(loc, global.resultType(), 118 global.getSymbol()); 119 newOperands.push_back(addr); 120 needUpdate = true; 121 } 122 123 if (needUpdate) { 124 auto loc = callOp.getLoc(); 125 llvm::SmallVector<mlir::Type> newResultTypes; 126 newResultTypes.append(callOp.getResultTypes().begin(), 127 callOp.getResultTypes().end()); 128 fir::CallOp newOp = builder.create<fir::CallOp>( 129 loc, 130 callOp.getCallee().has_value() ? callOp.getCallee().value() 131 : mlir::SymbolRefAttr{}, 132 newResultTypes, newOperands); 133 // Copy all the attributes from the old to new op. 134 newOp->setAttrs(callOp->getAttrs()); 135 rewriter.replaceOp(callOp, newOp); 136 137 for (auto a : allocas) { 138 if (a.first->hasOneUse()) { 139 // If the alloca is only used for a store and the call operand, the 140 // store is no longer required. 141 rewriter.eraseOp(a.second); 142 rewriter.eraseOp(a.first); 143 } 144 } 145 LLVM_DEBUG(llvm::dbgs() << "global constant for " << callOp << " as " 146 << newOp << '\n'); 147 return mlir::success(); 148 } 149 150 // Failure here just means "we couldn't do the conversion", which is 151 // perfectly acceptable to the upper layers of this function. 152 return mlir::failure(); 153 } 154 }; 155 156 // this pass attempts to convert immediate scalar literals in function calls 157 // to global constants to allow transformations such as Dead Argument 158 // Elimination 159 class ConstantArgumentGlobalisationOpt 160 : public fir::impl::ConstantArgumentGlobalisationOptBase< 161 ConstantArgumentGlobalisationOpt> { 162 public: 163 ConstantArgumentGlobalisationOpt() = default; 164 165 void runOnOperation() override { 166 mlir::ModuleOp mod = getOperation(); 167 mlir::DominanceInfo *di = &getAnalysis<mlir::DominanceInfo>(); 168 auto *context = &getContext(); 169 mlir::RewritePatternSet patterns(context); 170 mlir::GreedyRewriteConfig config; 171 config.enableRegionSimplification = 172 mlir::GreedySimplifyRegionLevel::Disabled; 173 config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps; 174 175 patterns.insert<CallOpRewriter>(context, *di); 176 if (mlir::failed( 177 mlir::applyPatternsGreedily(mod, std::move(patterns), config))) { 178 mlir::emitError(mod.getLoc(), 179 "error in constant globalisation optimization\n"); 180 signalPassFailure(); 181 } 182 } 183 }; 184 } // namespace 185