1 //===-- BoxedProcedure.cpp ------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Optimizer/CodeGen/CodeGen.h" 10 11 #include "flang/Optimizer/Builder/FIRBuilder.h" 12 #include "flang/Optimizer/Builder/LowLevelIntrinsics.h" 13 #include "flang/Optimizer/Dialect/FIRDialect.h" 14 #include "flang/Optimizer/Dialect/FIROps.h" 15 #include "flang/Optimizer/Dialect/FIRType.h" 16 #include "flang/Optimizer/Support/FIRContext.h" 17 #include "flang/Optimizer/Support/FatalError.h" 18 #include "flang/Optimizer/Support/InternalNames.h" 19 #include "mlir/IR/PatternMatch.h" 20 #include "mlir/Pass/Pass.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 23 namespace fir { 24 #define GEN_PASS_DEF_BOXEDPROCEDUREPASS 25 #include "flang/Optimizer/CodeGen/CGPasses.h.inc" 26 } // namespace fir 27 28 #define DEBUG_TYPE "flang-procedure-pointer" 29 30 using namespace fir; 31 32 namespace { 33 /// Options to the procedure pointer pass. 34 struct BoxedProcedureOptions { 35 // Lower the boxproc abstraction to function pointers and thunks where 36 // required. 37 bool useThunks = true; 38 }; 39 40 /// This type converter rewrites all `!fir.boxproc<Func>` types to `Func` types. 41 class BoxprocTypeRewriter : public mlir::TypeConverter { 42 public: 43 using mlir::TypeConverter::convertType; 44 45 /// Does the type \p ty need to be converted? 46 /// Any type that is a `!fir.boxproc` in whole or in part will need to be 47 /// converted to a function type to lower the IR to function pointer form in 48 /// the default implementation performed in this pass. Other implementations 49 /// are possible, so those may convert `!fir.boxproc` to some other type or 50 /// not at all depending on the implementation target's characteristics and 51 /// preference. 52 bool needsConversion(mlir::Type ty) { 53 if (ty.isa<BoxProcType>()) 54 return true; 55 if (auto funcTy = ty.dyn_cast<mlir::FunctionType>()) { 56 for (auto t : funcTy.getInputs()) 57 if (needsConversion(t)) 58 return true; 59 for (auto t : funcTy.getResults()) 60 if (needsConversion(t)) 61 return true; 62 return false; 63 } 64 if (auto tupleTy = ty.dyn_cast<mlir::TupleType>()) { 65 for (auto t : tupleTy.getTypes()) 66 if (needsConversion(t)) 67 return true; 68 return false; 69 } 70 if (auto recTy = ty.dyn_cast<RecordType>()) { 71 if (llvm::is_contained(visitedTypes, recTy)) 72 return false; 73 bool result = false; 74 visitedTypes.push_back(recTy); 75 for (auto t : recTy.getTypeList()) { 76 if (needsConversion(t.second)) { 77 result = true; 78 break; 79 } 80 } 81 visitedTypes.pop_back(); 82 return result; 83 } 84 if (auto boxTy = ty.dyn_cast<BoxType>()) 85 return needsConversion(boxTy.getEleTy()); 86 if (isa_ref_type(ty)) 87 return needsConversion(unwrapRefType(ty)); 88 if (auto t = ty.dyn_cast<SequenceType>()) 89 return needsConversion(unwrapSequenceType(ty)); 90 return false; 91 } 92 93 BoxprocTypeRewriter(mlir::Location location) : loc{location} { 94 addConversion([](mlir::Type ty) { return ty; }); 95 addConversion( 96 [&](BoxProcType boxproc) { return convertType(boxproc.getEleTy()); }); 97 addConversion([&](mlir::TupleType tupTy) { 98 llvm::SmallVector<mlir::Type> memTys; 99 for (auto ty : tupTy.getTypes()) 100 memTys.push_back(convertType(ty)); 101 return mlir::TupleType::get(tupTy.getContext(), memTys); 102 }); 103 addConversion([&](mlir::FunctionType funcTy) { 104 llvm::SmallVector<mlir::Type> inTys; 105 llvm::SmallVector<mlir::Type> resTys; 106 for (auto ty : funcTy.getInputs()) 107 inTys.push_back(convertType(ty)); 108 for (auto ty : funcTy.getResults()) 109 resTys.push_back(convertType(ty)); 110 return mlir::FunctionType::get(funcTy.getContext(), inTys, resTys); 111 }); 112 addConversion([&](ReferenceType ty) { 113 return ReferenceType::get(convertType(ty.getEleTy())); 114 }); 115 addConversion([&](PointerType ty) { 116 return PointerType::get(convertType(ty.getEleTy())); 117 }); 118 addConversion( 119 [&](HeapType ty) { return HeapType::get(convertType(ty.getEleTy())); }); 120 addConversion( 121 [&](BoxType ty) { return BoxType::get(convertType(ty.getEleTy())); }); 122 addConversion([&](SequenceType ty) { 123 // TODO: add ty.getLayoutMap() as needed. 124 return SequenceType::get(ty.getShape(), convertType(ty.getEleTy())); 125 }); 126 addConversion([&](RecordType ty) -> mlir::Type { 127 if (!needsConversion(ty)) 128 return ty; 129 auto rec = RecordType::get(ty.getContext(), 130 ty.getName().str() + boxprocSuffix.str()); 131 if (rec.isFinalized()) 132 return rec; 133 std::vector<RecordType::TypePair> ps = ty.getLenParamList(); 134 std::vector<RecordType::TypePair> cs; 135 for (auto t : ty.getTypeList()) { 136 if (needsConversion(t.second)) 137 cs.emplace_back(t.first, convertType(t.second)); 138 else 139 cs.emplace_back(t.first, t.second); 140 } 141 rec.finalize(ps, cs); 142 return rec; 143 }); 144 addArgumentMaterialization(materializeProcedure); 145 addSourceMaterialization(materializeProcedure); 146 addTargetMaterialization(materializeProcedure); 147 } 148 149 static mlir::Value materializeProcedure(mlir::OpBuilder &builder, 150 BoxProcType type, 151 mlir::ValueRange inputs, 152 mlir::Location loc) { 153 assert(inputs.size() == 1); 154 return builder.create<ConvertOp>(loc, unwrapRefType(type.getEleTy()), 155 inputs[0]); 156 } 157 158 void setLocation(mlir::Location location) { loc = location; } 159 160 private: 161 llvm::SmallVector<mlir::Type> visitedTypes; 162 mlir::Location loc; 163 }; 164 165 /// A `boxproc` is an abstraction for a Fortran procedure reference. Typically, 166 /// Fortran procedures can be referenced directly through a function pointer. 167 /// However, Fortran has one-level dynamic scoping between a host procedure and 168 /// its internal procedures. This allows internal procedures to directly access 169 /// and modify the state of the host procedure's variables. 170 /// 171 /// There are any number of possible implementations possible. 172 /// 173 /// The implementation used here is to convert `boxproc` values to function 174 /// pointers everywhere. If a `boxproc` value includes a frame pointer to the 175 /// host procedure's data, then a thunk will be created at runtime to capture 176 /// the frame pointer during execution. In LLVM IR, the frame pointer is 177 /// designated with the `nest` attribute. The thunk's address will then be used 178 /// as the call target instead of the original function's address directly. 179 class BoxedProcedurePass 180 : public fir::impl::BoxedProcedurePassBase<BoxedProcedurePass> { 181 public: 182 BoxedProcedurePass() { options = {true}; } 183 BoxedProcedurePass(bool useThunks) { options = {useThunks}; } 184 185 inline mlir::ModuleOp getModule() { return getOperation(); } 186 187 void runOnOperation() override final { 188 if (options.useThunks) { 189 auto *context = &getContext(); 190 mlir::IRRewriter rewriter(context); 191 BoxprocTypeRewriter typeConverter(mlir::UnknownLoc::get(context)); 192 mlir::Dialect *firDialect = context->getLoadedDialect("fir"); 193 getModule().walk([&](mlir::Operation *op) { 194 typeConverter.setLocation(op->getLoc()); 195 if (auto addr = mlir::dyn_cast<BoxAddrOp>(op)) { 196 auto ty = addr.getVal().getType(); 197 if (typeConverter.needsConversion(ty) || 198 ty.isa<mlir::FunctionType>()) { 199 // Rewrite all `fir.box_addr` ops on values of type `!fir.boxproc` 200 // or function type to be `fir.convert` ops. 201 rewriter.setInsertionPoint(addr); 202 rewriter.replaceOpWithNewOp<ConvertOp>( 203 addr, typeConverter.convertType(addr.getType()), addr.getVal()); 204 } 205 } else if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op)) { 206 mlir::FunctionType ty = func.getFunctionType(); 207 if (typeConverter.needsConversion(ty)) { 208 rewriter.startRootUpdate(func); 209 auto toTy = 210 typeConverter.convertType(ty).cast<mlir::FunctionType>(); 211 if (!func.empty()) 212 for (auto e : llvm::enumerate(toTy.getInputs())) { 213 unsigned i = e.index(); 214 auto &block = func.front(); 215 block.insertArgument(i, e.value(), func.getLoc()); 216 block.getArgument(i + 1).replaceAllUsesWith( 217 block.getArgument(i)); 218 block.eraseArgument(i + 1); 219 } 220 func.setType(toTy); 221 rewriter.finalizeRootUpdate(func); 222 } 223 } else if (auto embox = mlir::dyn_cast<EmboxProcOp>(op)) { 224 // Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk 225 // as required. 226 mlir::Type toTy = embox.getType().cast<BoxProcType>().getEleTy(); 227 rewriter.setInsertionPoint(embox); 228 if (embox.getHost()) { 229 // Create the thunk. 230 auto module = embox->getParentOfType<mlir::ModuleOp>(); 231 fir::KindMapping kindMap = getKindMapping(module); 232 FirOpBuilder builder(rewriter, kindMap); 233 auto loc = embox.getLoc(); 234 mlir::Type i8Ty = builder.getI8Type(); 235 mlir::Type i8Ptr = builder.getRefType(i8Ty); 236 mlir::Type buffTy = SequenceType::get({32}, i8Ty); 237 auto buffer = builder.create<AllocaOp>(loc, buffTy); 238 mlir::Value closure = 239 builder.createConvert(loc, i8Ptr, embox.getHost()); 240 mlir::Value tramp = builder.createConvert(loc, i8Ptr, buffer); 241 mlir::Value func = 242 builder.createConvert(loc, i8Ptr, embox.getFunc()); 243 builder.create<fir::CallOp>( 244 loc, factory::getLlvmInitTrampoline(builder), 245 llvm::ArrayRef<mlir::Value>{tramp, func, closure}); 246 auto adjustCall = builder.create<fir::CallOp>( 247 loc, factory::getLlvmAdjustTrampoline(builder), 248 llvm::ArrayRef<mlir::Value>{tramp}); 249 rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy, 250 adjustCall.getResult(0)); 251 } else { 252 // Just forward the function as a pointer. 253 rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy, 254 embox.getFunc()); 255 } 256 } else if (auto global = mlir::dyn_cast<GlobalOp>(op)) { 257 auto ty = global.getType(); 258 if (typeConverter.needsConversion(ty)) { 259 rewriter.startRootUpdate(global); 260 auto toTy = typeConverter.convertType(ty); 261 global.setType(toTy); 262 rewriter.finalizeRootUpdate(global); 263 } 264 } else if (auto mem = mlir::dyn_cast<AllocaOp>(op)) { 265 auto ty = mem.getType(); 266 if (typeConverter.needsConversion(ty)) { 267 rewriter.setInsertionPoint(mem); 268 auto toTy = typeConverter.convertType(unwrapRefType(ty)); 269 bool isPinned = mem.getPinned(); 270 llvm::StringRef uniqName = 271 mem.getUniqName().value_or(llvm::StringRef()); 272 llvm::StringRef bindcName = 273 mem.getBindcName().value_or(llvm::StringRef()); 274 rewriter.replaceOpWithNewOp<AllocaOp>( 275 mem, toTy, uniqName, bindcName, isPinned, mem.getTypeparams(), 276 mem.getShape()); 277 } 278 } else if (auto mem = mlir::dyn_cast<AllocMemOp>(op)) { 279 auto ty = mem.getType(); 280 if (typeConverter.needsConversion(ty)) { 281 rewriter.setInsertionPoint(mem); 282 auto toTy = typeConverter.convertType(unwrapRefType(ty)); 283 llvm::StringRef uniqName = 284 mem.getUniqName().value_or(llvm::StringRef()); 285 llvm::StringRef bindcName = 286 mem.getBindcName().value_or(llvm::StringRef()); 287 rewriter.replaceOpWithNewOp<AllocMemOp>( 288 mem, toTy, uniqName, bindcName, mem.getTypeparams(), 289 mem.getShape()); 290 } 291 } else if (auto coor = mlir::dyn_cast<CoordinateOp>(op)) { 292 auto ty = coor.getType(); 293 mlir::Type baseTy = coor.getBaseType(); 294 if (typeConverter.needsConversion(ty) || 295 typeConverter.needsConversion(baseTy)) { 296 rewriter.setInsertionPoint(coor); 297 auto toTy = typeConverter.convertType(ty); 298 auto toBaseTy = typeConverter.convertType(baseTy); 299 rewriter.replaceOpWithNewOp<CoordinateOp>(coor, toTy, coor.getRef(), 300 coor.getCoor(), toBaseTy); 301 } 302 } else if (auto index = mlir::dyn_cast<FieldIndexOp>(op)) { 303 auto ty = index.getType(); 304 mlir::Type onTy = index.getOnType(); 305 if (typeConverter.needsConversion(ty) || 306 typeConverter.needsConversion(onTy)) { 307 rewriter.setInsertionPoint(index); 308 auto toTy = typeConverter.convertType(ty); 309 auto toOnTy = typeConverter.convertType(onTy); 310 rewriter.replaceOpWithNewOp<FieldIndexOp>( 311 index, toTy, index.getFieldId(), toOnTy, index.getTypeparams()); 312 } 313 } else if (auto index = mlir::dyn_cast<LenParamIndexOp>(op)) { 314 auto ty = index.getType(); 315 mlir::Type onTy = index.getOnType(); 316 if (typeConverter.needsConversion(ty) || 317 typeConverter.needsConversion(onTy)) { 318 rewriter.setInsertionPoint(index); 319 auto toTy = typeConverter.convertType(ty); 320 auto toOnTy = typeConverter.convertType(onTy); 321 rewriter.replaceOpWithNewOp<LenParamIndexOp>( 322 mem, toTy, index.getFieldId(), toOnTy, index.getTypeparams()); 323 } 324 } else if (op->getDialect() == firDialect) { 325 rewriter.startRootUpdate(op); 326 for (auto i : llvm::enumerate(op->getResultTypes())) 327 if (typeConverter.needsConversion(i.value())) { 328 auto toTy = typeConverter.convertType(i.value()); 329 op->getResult(i.index()).setType(toTy); 330 } 331 rewriter.finalizeRootUpdate(op); 332 } 333 }); 334 } 335 } 336 337 private: 338 BoxedProcedureOptions options; 339 }; 340 } // namespace 341 342 std::unique_ptr<mlir::Pass> fir::createBoxedProcedurePass() { 343 return std::make_unique<BoxedProcedurePass>(); 344 } 345 346 std::unique_ptr<mlir::Pass> fir::createBoxedProcedurePass(bool useThunks) { 347 return std::make_unique<BoxedProcedurePass>(useThunks); 348 } 349