1 //===-- BoxedProcedure.cpp ------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Optimizer/CodeGen/CodeGen.h" 10 11 #include "flang/Optimizer/Builder/FIRBuilder.h" 12 #include "flang/Optimizer/Builder/LowLevelIntrinsics.h" 13 #include "flang/Optimizer/Dialect/FIRDialect.h" 14 #include "flang/Optimizer/Dialect/FIROps.h" 15 #include "flang/Optimizer/Dialect/FIRType.h" 16 #include "flang/Optimizer/Dialect/Support/FIRContext.h" 17 #include "flang/Optimizer/Support/FatalError.h" 18 #include "flang/Optimizer/Support/InternalNames.h" 19 #include "mlir/IR/PatternMatch.h" 20 #include "mlir/Pass/Pass.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 #include "llvm/ADT/MapVector.h" 23 24 namespace fir { 25 #define GEN_PASS_DEF_BOXEDPROCEDUREPASS 26 #include "flang/Optimizer/CodeGen/CGPasses.h.inc" 27 } // namespace fir 28 29 #define DEBUG_TYPE "flang-procedure-pointer" 30 31 using namespace fir; 32 33 namespace { 34 /// Options to the procedure pointer pass. 35 struct BoxedProcedureOptions { 36 // Lower the boxproc abstraction to function pointers and thunks where 37 // required. 38 bool useThunks = true; 39 }; 40 41 /// This type converter rewrites all `!fir.boxproc<Func>` types to `Func` types. 42 class BoxprocTypeRewriter : public mlir::TypeConverter { 43 public: 44 using mlir::TypeConverter::convertType; 45 46 /// Does the type \p ty need to be converted? 47 /// Any type that is a `!fir.boxproc` in whole or in part will need to be 48 /// converted to a function type to lower the IR to function pointer form in 49 /// the default implementation performed in this pass. Other implementations 50 /// are possible, so those may convert `!fir.boxproc` to some other type or 51 /// not at all depending on the implementation target's characteristics and 52 /// preference. 53 bool needsConversion(mlir::Type ty) { 54 if (ty.isa<BoxProcType>()) 55 return true; 56 if (auto funcTy = ty.dyn_cast<mlir::FunctionType>()) { 57 for (auto t : funcTy.getInputs()) 58 if (needsConversion(t)) 59 return true; 60 for (auto t : funcTy.getResults()) 61 if (needsConversion(t)) 62 return true; 63 return false; 64 } 65 if (auto tupleTy = ty.dyn_cast<mlir::TupleType>()) { 66 for (auto t : tupleTy.getTypes()) 67 if (needsConversion(t)) 68 return true; 69 return false; 70 } 71 if (auto recTy = ty.dyn_cast<RecordType>()) { 72 if (llvm::is_contained(visitedTypes, recTy)) 73 return false; 74 bool result = false; 75 visitedTypes.push_back(recTy); 76 for (auto t : recTy.getTypeList()) { 77 if (needsConversion(t.second)) { 78 result = true; 79 break; 80 } 81 } 82 visitedTypes.pop_back(); 83 return result; 84 } 85 if (auto boxTy = ty.dyn_cast<BaseBoxType>()) 86 return needsConversion(boxTy.getEleTy()); 87 if (isa_ref_type(ty)) 88 return needsConversion(unwrapRefType(ty)); 89 if (auto t = ty.dyn_cast<SequenceType>()) 90 return needsConversion(unwrapSequenceType(ty)); 91 return false; 92 } 93 94 BoxprocTypeRewriter(mlir::Location location) : loc{location} { 95 addConversion([](mlir::Type ty) { return ty; }); 96 addConversion( 97 [&](BoxProcType boxproc) { return convertType(boxproc.getEleTy()); }); 98 addConversion([&](mlir::TupleType tupTy) { 99 llvm::SmallVector<mlir::Type> memTys; 100 for (auto ty : tupTy.getTypes()) 101 memTys.push_back(convertType(ty)); 102 return mlir::TupleType::get(tupTy.getContext(), memTys); 103 }); 104 addConversion([&](mlir::FunctionType funcTy) { 105 llvm::SmallVector<mlir::Type> inTys; 106 llvm::SmallVector<mlir::Type> resTys; 107 for (auto ty : funcTy.getInputs()) 108 inTys.push_back(convertType(ty)); 109 for (auto ty : funcTy.getResults()) 110 resTys.push_back(convertType(ty)); 111 return mlir::FunctionType::get(funcTy.getContext(), inTys, resTys); 112 }); 113 addConversion([&](ReferenceType ty) { 114 return ReferenceType::get(convertType(ty.getEleTy())); 115 }); 116 addConversion([&](PointerType ty) { 117 return PointerType::get(convertType(ty.getEleTy())); 118 }); 119 addConversion( 120 [&](HeapType ty) { return HeapType::get(convertType(ty.getEleTy())); }); 121 addConversion([&](fir::LLVMPointerType ty) { 122 return fir::LLVMPointerType::get(convertType(ty.getEleTy())); 123 }); 124 addConversion( 125 [&](BoxType ty) { return BoxType::get(convertType(ty.getEleTy())); }); 126 addConversion([&](ClassType ty) { 127 return ClassType::get(convertType(ty.getEleTy())); 128 }); 129 addConversion([&](SequenceType ty) { 130 // TODO: add ty.getLayoutMap() as needed. 131 return SequenceType::get(ty.getShape(), convertType(ty.getEleTy())); 132 }); 133 addConversion([&](RecordType ty) -> mlir::Type { 134 if (!needsConversion(ty)) 135 return ty; 136 if (auto converted = typeInConversion.lookup(ty)) 137 return converted; 138 auto rec = RecordType::get(ty.getContext(), 139 ty.getName().str() + boxprocSuffix.str()); 140 if (rec.isFinalized()) 141 return rec; 142 auto it = typeInConversion.try_emplace(ty, rec); 143 std::vector<RecordType::TypePair> ps = ty.getLenParamList(); 144 std::vector<RecordType::TypePair> cs; 145 for (auto t : ty.getTypeList()) { 146 if (needsConversion(t.second)) 147 cs.emplace_back(t.first, convertType(t.second)); 148 else 149 cs.emplace_back(t.first, t.second); 150 } 151 rec.finalize(ps, cs); 152 typeInConversion.erase(it.first); 153 return rec; 154 }); 155 addArgumentMaterialization(materializeProcedure); 156 addSourceMaterialization(materializeProcedure); 157 addTargetMaterialization(materializeProcedure); 158 } 159 160 static mlir::Value materializeProcedure(mlir::OpBuilder &builder, 161 BoxProcType type, 162 mlir::ValueRange inputs, 163 mlir::Location loc) { 164 assert(inputs.size() == 1); 165 return builder.create<ConvertOp>(loc, unwrapRefType(type.getEleTy()), 166 inputs[0]); 167 } 168 169 void setLocation(mlir::Location location) { loc = location; } 170 171 private: 172 llvm::SmallVector<mlir::Type> visitedTypes; 173 llvm::SmallMapVector<mlir::Type, mlir::Type, 8> typeInConversion; 174 mlir::Location loc; 175 }; 176 177 /// A `boxproc` is an abstraction for a Fortran procedure reference. Typically, 178 /// Fortran procedures can be referenced directly through a function pointer. 179 /// However, Fortran has one-level dynamic scoping between a host procedure and 180 /// its internal procedures. This allows internal procedures to directly access 181 /// and modify the state of the host procedure's variables. 182 /// 183 /// There are any number of possible implementations possible. 184 /// 185 /// The implementation used here is to convert `boxproc` values to function 186 /// pointers everywhere. If a `boxproc` value includes a frame pointer to the 187 /// host procedure's data, then a thunk will be created at runtime to capture 188 /// the frame pointer during execution. In LLVM IR, the frame pointer is 189 /// designated with the `nest` attribute. The thunk's address will then be used 190 /// as the call target instead of the original function's address directly. 191 class BoxedProcedurePass 192 : public fir::impl::BoxedProcedurePassBase<BoxedProcedurePass> { 193 public: 194 BoxedProcedurePass() { options = {true}; } 195 BoxedProcedurePass(bool useThunks) { options = {useThunks}; } 196 197 inline mlir::ModuleOp getModule() { return getOperation(); } 198 199 void runOnOperation() override final { 200 if (options.useThunks) { 201 auto *context = &getContext(); 202 mlir::IRRewriter rewriter(context); 203 BoxprocTypeRewriter typeConverter(mlir::UnknownLoc::get(context)); 204 mlir::Dialect *firDialect = context->getLoadedDialect("fir"); 205 getModule().walk([&](mlir::Operation *op) { 206 typeConverter.setLocation(op->getLoc()); 207 if (auto addr = mlir::dyn_cast<BoxAddrOp>(op)) { 208 mlir::Type ty = addr.getVal().getType(); 209 mlir::Type resTy = addr.getResult().getType(); 210 if (typeConverter.needsConversion(ty) || 211 ty.isa<mlir::FunctionType>()) { 212 // Rewrite all `fir.box_addr` ops on values of type `!fir.boxproc` 213 // or function type to be `fir.convert` ops. 214 rewriter.setInsertionPoint(addr); 215 rewriter.replaceOpWithNewOp<ConvertOp>( 216 addr, typeConverter.convertType(addr.getType()), addr.getVal()); 217 } else if (typeConverter.needsConversion(resTy)) { 218 rewriter.startRootUpdate(op); 219 op->getResult(0).setType(typeConverter.convertType(resTy)); 220 rewriter.finalizeRootUpdate(op); 221 } 222 } else if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op)) { 223 mlir::FunctionType ty = func.getFunctionType(); 224 if (typeConverter.needsConversion(ty)) { 225 rewriter.startRootUpdate(func); 226 auto toTy = 227 typeConverter.convertType(ty).cast<mlir::FunctionType>(); 228 if (!func.empty()) 229 for (auto e : llvm::enumerate(toTy.getInputs())) { 230 unsigned i = e.index(); 231 auto &block = func.front(); 232 block.insertArgument(i, e.value(), func.getLoc()); 233 block.getArgument(i + 1).replaceAllUsesWith( 234 block.getArgument(i)); 235 block.eraseArgument(i + 1); 236 } 237 func.setType(toTy); 238 rewriter.finalizeRootUpdate(func); 239 } 240 } else if (auto embox = mlir::dyn_cast<EmboxProcOp>(op)) { 241 // Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk 242 // as required. 243 mlir::Type toTy = typeConverter.convertType( 244 embox.getType().cast<BoxProcType>().getEleTy()); 245 rewriter.setInsertionPoint(embox); 246 if (embox.getHost()) { 247 // Create the thunk. 248 auto module = embox->getParentOfType<mlir::ModuleOp>(); 249 FirOpBuilder builder(rewriter, module); 250 auto loc = embox.getLoc(); 251 mlir::Type i8Ty = builder.getI8Type(); 252 mlir::Type i8Ptr = builder.getRefType(i8Ty); 253 mlir::Type buffTy = SequenceType::get({32}, i8Ty); 254 auto buffer = builder.create<AllocaOp>(loc, buffTy); 255 mlir::Value closure = 256 builder.createConvert(loc, i8Ptr, embox.getHost()); 257 mlir::Value tramp = builder.createConvert(loc, i8Ptr, buffer); 258 mlir::Value func = 259 builder.createConvert(loc, i8Ptr, embox.getFunc()); 260 builder.create<fir::CallOp>( 261 loc, factory::getLlvmInitTrampoline(builder), 262 llvm::ArrayRef<mlir::Value>{tramp, func, closure}); 263 auto adjustCall = builder.create<fir::CallOp>( 264 loc, factory::getLlvmAdjustTrampoline(builder), 265 llvm::ArrayRef<mlir::Value>{tramp}); 266 rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy, 267 adjustCall.getResult(0)); 268 } else { 269 // Just forward the function as a pointer. 270 rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy, 271 embox.getFunc()); 272 } 273 } else if (auto global = mlir::dyn_cast<GlobalOp>(op)) { 274 auto ty = global.getType(); 275 if (typeConverter.needsConversion(ty)) { 276 rewriter.startRootUpdate(global); 277 auto toTy = typeConverter.convertType(ty); 278 global.setType(toTy); 279 rewriter.finalizeRootUpdate(global); 280 } 281 } else if (auto mem = mlir::dyn_cast<AllocaOp>(op)) { 282 auto ty = mem.getType(); 283 if (typeConverter.needsConversion(ty)) { 284 rewriter.setInsertionPoint(mem); 285 auto toTy = typeConverter.convertType(unwrapRefType(ty)); 286 bool isPinned = mem.getPinned(); 287 llvm::StringRef uniqName = 288 mem.getUniqName().value_or(llvm::StringRef()); 289 llvm::StringRef bindcName = 290 mem.getBindcName().value_or(llvm::StringRef()); 291 rewriter.replaceOpWithNewOp<AllocaOp>( 292 mem, toTy, uniqName, bindcName, isPinned, mem.getTypeparams(), 293 mem.getShape()); 294 } 295 } else if (auto mem = mlir::dyn_cast<AllocMemOp>(op)) { 296 auto ty = mem.getType(); 297 if (typeConverter.needsConversion(ty)) { 298 rewriter.setInsertionPoint(mem); 299 auto toTy = typeConverter.convertType(unwrapRefType(ty)); 300 llvm::StringRef uniqName = 301 mem.getUniqName().value_or(llvm::StringRef()); 302 llvm::StringRef bindcName = 303 mem.getBindcName().value_or(llvm::StringRef()); 304 rewriter.replaceOpWithNewOp<AllocMemOp>( 305 mem, toTy, uniqName, bindcName, mem.getTypeparams(), 306 mem.getShape()); 307 } 308 } else if (auto coor = mlir::dyn_cast<CoordinateOp>(op)) { 309 auto ty = coor.getType(); 310 mlir::Type baseTy = coor.getBaseType(); 311 if (typeConverter.needsConversion(ty) || 312 typeConverter.needsConversion(baseTy)) { 313 rewriter.setInsertionPoint(coor); 314 auto toTy = typeConverter.convertType(ty); 315 auto toBaseTy = typeConverter.convertType(baseTy); 316 rewriter.replaceOpWithNewOp<CoordinateOp>(coor, toTy, coor.getRef(), 317 coor.getCoor(), toBaseTy); 318 } 319 } else if (auto index = mlir::dyn_cast<FieldIndexOp>(op)) { 320 auto ty = index.getType(); 321 mlir::Type onTy = index.getOnType(); 322 if (typeConverter.needsConversion(ty) || 323 typeConverter.needsConversion(onTy)) { 324 rewriter.setInsertionPoint(index); 325 auto toTy = typeConverter.convertType(ty); 326 auto toOnTy = typeConverter.convertType(onTy); 327 rewriter.replaceOpWithNewOp<FieldIndexOp>( 328 index, toTy, index.getFieldId(), toOnTy, index.getTypeparams()); 329 } 330 } else if (auto index = mlir::dyn_cast<LenParamIndexOp>(op)) { 331 auto ty = index.getType(); 332 mlir::Type onTy = index.getOnType(); 333 if (typeConverter.needsConversion(ty) || 334 typeConverter.needsConversion(onTy)) { 335 rewriter.setInsertionPoint(index); 336 auto toTy = typeConverter.convertType(ty); 337 auto toOnTy = typeConverter.convertType(onTy); 338 rewriter.replaceOpWithNewOp<LenParamIndexOp>( 339 mem, toTy, index.getFieldId(), toOnTy, index.getTypeparams()); 340 } 341 } else if (op->getDialect() == firDialect) { 342 rewriter.startRootUpdate(op); 343 for (auto i : llvm::enumerate(op->getResultTypes())) 344 if (typeConverter.needsConversion(i.value())) { 345 auto toTy = typeConverter.convertType(i.value()); 346 op->getResult(i.index()).setType(toTy); 347 } 348 rewriter.finalizeRootUpdate(op); 349 } 350 }); 351 } 352 } 353 354 private: 355 BoxedProcedureOptions options; 356 }; 357 } // namespace 358 359 std::unique_ptr<mlir::Pass> fir::createBoxedProcedurePass() { 360 return std::make_unique<BoxedProcedurePass>(); 361 } 362 363 std::unique_ptr<mlir::Pass> fir::createBoxedProcedurePass(bool useThunks) { 364 return std::make_unique<BoxedProcedurePass>(useThunks); 365 } 366