1 //===-- BoxedProcedure.cpp ------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Optimizer/CodeGen/CodeGen.h" 10 11 #include "flang/Optimizer/Builder/FIRBuilder.h" 12 #include "flang/Optimizer/Builder/LowLevelIntrinsics.h" 13 #include "flang/Optimizer/Dialect/FIRDialect.h" 14 #include "flang/Optimizer/Dialect/FIROps.h" 15 #include "flang/Optimizer/Dialect/FIRType.h" 16 #include "flang/Optimizer/Dialect/Support/FIRContext.h" 17 #include "flang/Optimizer/Support/FatalError.h" 18 #include "flang/Optimizer/Support/InternalNames.h" 19 #include "mlir/IR/PatternMatch.h" 20 #include "mlir/Pass/Pass.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 #include "llvm/ADT/DenseMap.h" 23 24 namespace fir { 25 #define GEN_PASS_DEF_BOXEDPROCEDUREPASS 26 #include "flang/Optimizer/CodeGen/CGPasses.h.inc" 27 } // namespace fir 28 29 #define DEBUG_TYPE "flang-procedure-pointer" 30 31 using namespace fir; 32 33 namespace { 34 /// Options to the procedure pointer pass. 35 struct BoxedProcedureOptions { 36 // Lower the boxproc abstraction to function pointers and thunks where 37 // required. 38 bool useThunks = true; 39 }; 40 41 /// This type converter rewrites all `!fir.boxproc<Func>` types to `Func` types. 42 class BoxprocTypeRewriter : public mlir::TypeConverter { 43 public: 44 using mlir::TypeConverter::convertType; 45 46 /// Does the type \p ty need to be converted? 47 /// Any type that is a `!fir.boxproc` in whole or in part will need to be 48 /// converted to a function type to lower the IR to function pointer form in 49 /// the default implementation performed in this pass. Other implementations 50 /// are possible, so those may convert `!fir.boxproc` to some other type or 51 /// not at all depending on the implementation target's characteristics and 52 /// preference. 53 bool needsConversion(mlir::Type ty) { 54 if (mlir::isa<BoxProcType>(ty)) 55 return true; 56 if (auto funcTy = mlir::dyn_cast<mlir::FunctionType>(ty)) { 57 for (auto t : funcTy.getInputs()) 58 if (needsConversion(t)) 59 return true; 60 for (auto t : funcTy.getResults()) 61 if (needsConversion(t)) 62 return true; 63 return false; 64 } 65 if (auto tupleTy = mlir::dyn_cast<mlir::TupleType>(ty)) { 66 for (auto t : tupleTy.getTypes()) 67 if (needsConversion(t)) 68 return true; 69 return false; 70 } 71 if (auto recTy = mlir::dyn_cast<RecordType>(ty)) { 72 auto [visited, inserted] = visitedTypes.try_emplace(ty, false); 73 if (!inserted) 74 return visited->second; 75 bool wasAlreadyVisitingRecordType = needConversionIsVisitingRecordType; 76 needConversionIsVisitingRecordType = true; 77 bool result = false; 78 for (auto t : recTy.getTypeList()) { 79 if (needsConversion(t.second)) { 80 result = true; 81 break; 82 } 83 } 84 // Only keep the result cached if the fir.type visited was a "top-level 85 // type". Nested types with a recursive reference to the "top-level type" 86 // may incorrectly have been resolved as not needed conversions because it 87 // had not been determined yet if the "top-level type" needed conversion. 88 // This is not an issue to determine the "top-level type" need of 89 // conversion, but the result should not be kept and later used in other 90 // contexts. 91 needConversionIsVisitingRecordType = wasAlreadyVisitingRecordType; 92 if (needConversionIsVisitingRecordType) 93 visitedTypes.erase(ty); 94 else 95 visitedTypes.find(ty)->second = result; 96 return result; 97 } 98 if (auto boxTy = mlir::dyn_cast<BaseBoxType>(ty)) 99 return needsConversion(boxTy.getEleTy()); 100 if (isa_ref_type(ty)) 101 return needsConversion(unwrapRefType(ty)); 102 if (auto t = mlir::dyn_cast<SequenceType>(ty)) 103 return needsConversion(unwrapSequenceType(ty)); 104 if (auto t = mlir::dyn_cast<TypeDescType>(ty)) 105 return needsConversion(t.getOfTy()); 106 return false; 107 } 108 109 BoxprocTypeRewriter(mlir::Location location) : loc{location} { 110 addConversion([](mlir::Type ty) { return ty; }); 111 addConversion( 112 [&](BoxProcType boxproc) { return convertType(boxproc.getEleTy()); }); 113 addConversion([&](mlir::TupleType tupTy) { 114 llvm::SmallVector<mlir::Type> memTys; 115 for (auto ty : tupTy.getTypes()) 116 memTys.push_back(convertType(ty)); 117 return mlir::TupleType::get(tupTy.getContext(), memTys); 118 }); 119 addConversion([&](mlir::FunctionType funcTy) { 120 llvm::SmallVector<mlir::Type> inTys; 121 llvm::SmallVector<mlir::Type> resTys; 122 for (auto ty : funcTy.getInputs()) 123 inTys.push_back(convertType(ty)); 124 for (auto ty : funcTy.getResults()) 125 resTys.push_back(convertType(ty)); 126 return mlir::FunctionType::get(funcTy.getContext(), inTys, resTys); 127 }); 128 addConversion([&](ReferenceType ty) { 129 return ReferenceType::get(convertType(ty.getEleTy())); 130 }); 131 addConversion([&](PointerType ty) { 132 return PointerType::get(convertType(ty.getEleTy())); 133 }); 134 addConversion( 135 [&](HeapType ty) { return HeapType::get(convertType(ty.getEleTy())); }); 136 addConversion([&](fir::LLVMPointerType ty) { 137 return fir::LLVMPointerType::get(convertType(ty.getEleTy())); 138 }); 139 addConversion( 140 [&](BoxType ty) { return BoxType::get(convertType(ty.getEleTy())); }); 141 addConversion([&](ClassType ty) { 142 return ClassType::get(convertType(ty.getEleTy())); 143 }); 144 addConversion([&](SequenceType ty) { 145 // TODO: add ty.getLayoutMap() as needed. 146 return SequenceType::get(ty.getShape(), convertType(ty.getEleTy())); 147 }); 148 addConversion([&](RecordType ty) -> mlir::Type { 149 if (!needsConversion(ty)) 150 return ty; 151 if (auto converted = convertedTypes.lookup(ty)) 152 return converted; 153 auto rec = RecordType::get(ty.getContext(), 154 ty.getName().str() + boxprocSuffix.str()); 155 if (rec.isFinalized()) 156 return rec; 157 [[maybe_unused]] auto it = convertedTypes.try_emplace(ty, rec); 158 assert(it.second && "expected ty to not be in the map"); 159 std::vector<RecordType::TypePair> ps = ty.getLenParamList(); 160 std::vector<RecordType::TypePair> cs; 161 for (auto t : ty.getTypeList()) { 162 if (needsConversion(t.second)) 163 cs.emplace_back(t.first, convertType(t.second)); 164 else 165 cs.emplace_back(t.first, t.second); 166 } 167 rec.finalize(ps, cs); 168 rec.pack(ty.isPacked()); 169 return rec; 170 }); 171 addConversion([&](TypeDescType ty) { 172 return TypeDescType::get(convertType(ty.getOfTy())); 173 }); 174 addSourceMaterialization(materializeProcedure); 175 addTargetMaterialization(materializeProcedure); 176 } 177 178 static mlir::Value materializeProcedure(mlir::OpBuilder &builder, 179 BoxProcType type, 180 mlir::ValueRange inputs, 181 mlir::Location loc) { 182 assert(inputs.size() == 1); 183 return builder.create<ConvertOp>(loc, unwrapRefType(type.getEleTy()), 184 inputs[0]); 185 } 186 187 void setLocation(mlir::Location location) { loc = location; } 188 189 private: 190 // Maps to deal with recursive derived types (avoid infinite loops). 191 // Caching is also beneficial for apps with big types (dozens of 192 // components and or parent types), so the lifetime of the cache 193 // is the whole pass. 194 llvm::DenseMap<mlir::Type, bool> visitedTypes; 195 bool needConversionIsVisitingRecordType = false; 196 llvm::DenseMap<mlir::Type, mlir::Type> convertedTypes; 197 mlir::Location loc; 198 }; 199 200 /// A `boxproc` is an abstraction for a Fortran procedure reference. Typically, 201 /// Fortran procedures can be referenced directly through a function pointer. 202 /// However, Fortran has one-level dynamic scoping between a host procedure and 203 /// its internal procedures. This allows internal procedures to directly access 204 /// and modify the state of the host procedure's variables. 205 /// 206 /// There are any number of possible implementations possible. 207 /// 208 /// The implementation used here is to convert `boxproc` values to function 209 /// pointers everywhere. If a `boxproc` value includes a frame pointer to the 210 /// host procedure's data, then a thunk will be created at runtime to capture 211 /// the frame pointer during execution. In LLVM IR, the frame pointer is 212 /// designated with the `nest` attribute. The thunk's address will then be used 213 /// as the call target instead of the original function's address directly. 214 class BoxedProcedurePass 215 : public fir::impl::BoxedProcedurePassBase<BoxedProcedurePass> { 216 public: 217 using BoxedProcedurePassBase<BoxedProcedurePass>::BoxedProcedurePassBase; 218 219 inline mlir::ModuleOp getModule() { return getOperation(); } 220 221 void runOnOperation() override final { 222 if (options.useThunks) { 223 auto *context = &getContext(); 224 mlir::IRRewriter rewriter(context); 225 BoxprocTypeRewriter typeConverter(mlir::UnknownLoc::get(context)); 226 getModule().walk([&](mlir::Operation *op) { 227 bool opIsValid = true; 228 typeConverter.setLocation(op->getLoc()); 229 if (auto addr = mlir::dyn_cast<BoxAddrOp>(op)) { 230 mlir::Type ty = addr.getVal().getType(); 231 mlir::Type resTy = addr.getResult().getType(); 232 if (llvm::isa<mlir::FunctionType>(ty) || 233 llvm::isa<fir::BoxProcType>(ty)) { 234 // Rewrite all `fir.box_addr` ops on values of type `!fir.boxproc` 235 // or function type to be `fir.convert` ops. 236 rewriter.setInsertionPoint(addr); 237 rewriter.replaceOpWithNewOp<ConvertOp>( 238 addr, typeConverter.convertType(addr.getType()), addr.getVal()); 239 opIsValid = false; 240 } else if (typeConverter.needsConversion(resTy)) { 241 rewriter.startOpModification(op); 242 op->getResult(0).setType(typeConverter.convertType(resTy)); 243 rewriter.finalizeOpModification(op); 244 } 245 } else if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op)) { 246 mlir::FunctionType ty = func.getFunctionType(); 247 if (typeConverter.needsConversion(ty)) { 248 rewriter.startOpModification(func); 249 auto toTy = 250 mlir::cast<mlir::FunctionType>(typeConverter.convertType(ty)); 251 if (!func.empty()) 252 for (auto e : llvm::enumerate(toTy.getInputs())) { 253 unsigned i = e.index(); 254 auto &block = func.front(); 255 block.insertArgument(i, e.value(), func.getLoc()); 256 block.getArgument(i + 1).replaceAllUsesWith( 257 block.getArgument(i)); 258 block.eraseArgument(i + 1); 259 } 260 func.setType(toTy); 261 rewriter.finalizeOpModification(func); 262 } 263 } else if (auto embox = mlir::dyn_cast<EmboxProcOp>(op)) { 264 // Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk 265 // as required. 266 mlir::Type toTy = typeConverter.convertType( 267 mlir::cast<BoxProcType>(embox.getType()).getEleTy()); 268 rewriter.setInsertionPoint(embox); 269 if (embox.getHost()) { 270 // Create the thunk. 271 auto module = embox->getParentOfType<mlir::ModuleOp>(); 272 FirOpBuilder builder(rewriter, module); 273 const auto triple{fir::getTargetTriple(module)}; 274 auto loc = embox.getLoc(); 275 mlir::Type i8Ty = builder.getI8Type(); 276 mlir::Type i8Ptr = builder.getRefType(i8Ty); 277 // For AArch64, PPC32 and PPC64, the thunk is populated by a call to 278 // __trampoline_setup, which is defined in 279 // compiler-rt/lib/builtins/trampoline_setup.c and requires the 280 // thunk size greater than 32 bytes. For RISCV and x86_64, the 281 // thunk setup doesn't go through __trampoline_setup and fits in 32 282 // bytes. 283 fir::SequenceType::Extent thunkSize = triple.getTrampolineSize(); 284 mlir::Type buffTy = SequenceType::get({thunkSize}, i8Ty); 285 auto buffer = builder.create<AllocaOp>(loc, buffTy); 286 mlir::Value closure = 287 builder.createConvert(loc, i8Ptr, embox.getHost()); 288 mlir::Value tramp = builder.createConvert(loc, i8Ptr, buffer); 289 mlir::Value func = 290 builder.createConvert(loc, i8Ptr, embox.getFunc()); 291 builder.create<fir::CallOp>( 292 loc, factory::getLlvmInitTrampoline(builder), 293 llvm::ArrayRef<mlir::Value>{tramp, func, closure}); 294 auto adjustCall = builder.create<fir::CallOp>( 295 loc, factory::getLlvmAdjustTrampoline(builder), 296 llvm::ArrayRef<mlir::Value>{tramp}); 297 rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy, 298 adjustCall.getResult(0)); 299 opIsValid = false; 300 } else { 301 // Just forward the function as a pointer. 302 rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy, 303 embox.getFunc()); 304 opIsValid = false; 305 } 306 } else if (auto global = mlir::dyn_cast<GlobalOp>(op)) { 307 auto ty = global.getType(); 308 if (typeConverter.needsConversion(ty)) { 309 rewriter.startOpModification(global); 310 auto toTy = typeConverter.convertType(ty); 311 global.setType(toTy); 312 rewriter.finalizeOpModification(global); 313 } 314 } else if (auto mem = mlir::dyn_cast<AllocaOp>(op)) { 315 auto ty = mem.getType(); 316 if (typeConverter.needsConversion(ty)) { 317 rewriter.setInsertionPoint(mem); 318 auto toTy = typeConverter.convertType(unwrapRefType(ty)); 319 bool isPinned = mem.getPinned(); 320 llvm::StringRef uniqName = 321 mem.getUniqName().value_or(llvm::StringRef()); 322 llvm::StringRef bindcName = 323 mem.getBindcName().value_or(llvm::StringRef()); 324 rewriter.replaceOpWithNewOp<AllocaOp>( 325 mem, toTy, uniqName, bindcName, isPinned, mem.getTypeparams(), 326 mem.getShape()); 327 opIsValid = false; 328 } 329 } else if (auto mem = mlir::dyn_cast<AllocMemOp>(op)) { 330 auto ty = mem.getType(); 331 if (typeConverter.needsConversion(ty)) { 332 rewriter.setInsertionPoint(mem); 333 auto toTy = typeConverter.convertType(unwrapRefType(ty)); 334 llvm::StringRef uniqName = 335 mem.getUniqName().value_or(llvm::StringRef()); 336 llvm::StringRef bindcName = 337 mem.getBindcName().value_or(llvm::StringRef()); 338 rewriter.replaceOpWithNewOp<AllocMemOp>( 339 mem, toTy, uniqName, bindcName, mem.getTypeparams(), 340 mem.getShape()); 341 opIsValid = false; 342 } 343 } else if (auto coor = mlir::dyn_cast<CoordinateOp>(op)) { 344 auto ty = coor.getType(); 345 mlir::Type baseTy = coor.getBaseType(); 346 if (typeConverter.needsConversion(ty) || 347 typeConverter.needsConversion(baseTy)) { 348 rewriter.setInsertionPoint(coor); 349 auto toTy = typeConverter.convertType(ty); 350 auto toBaseTy = typeConverter.convertType(baseTy); 351 rewriter.replaceOpWithNewOp<CoordinateOp>(coor, toTy, coor.getRef(), 352 coor.getCoor(), toBaseTy); 353 opIsValid = false; 354 } 355 } else if (auto index = mlir::dyn_cast<FieldIndexOp>(op)) { 356 auto ty = index.getType(); 357 mlir::Type onTy = index.getOnType(); 358 if (typeConverter.needsConversion(ty) || 359 typeConverter.needsConversion(onTy)) { 360 rewriter.setInsertionPoint(index); 361 auto toTy = typeConverter.convertType(ty); 362 auto toOnTy = typeConverter.convertType(onTy); 363 rewriter.replaceOpWithNewOp<FieldIndexOp>( 364 index, toTy, index.getFieldId(), toOnTy, index.getTypeparams()); 365 opIsValid = false; 366 } 367 } else if (auto index = mlir::dyn_cast<LenParamIndexOp>(op)) { 368 auto ty = index.getType(); 369 mlir::Type onTy = index.getOnType(); 370 if (typeConverter.needsConversion(ty) || 371 typeConverter.needsConversion(onTy)) { 372 rewriter.setInsertionPoint(index); 373 auto toTy = typeConverter.convertType(ty); 374 auto toOnTy = typeConverter.convertType(onTy); 375 rewriter.replaceOpWithNewOp<LenParamIndexOp>( 376 index, toTy, index.getFieldId(), toOnTy, index.getTypeparams()); 377 opIsValid = false; 378 } 379 } else { 380 rewriter.startOpModification(op); 381 // Convert the operands if needed 382 for (auto i : llvm::enumerate(op->getResultTypes())) 383 if (typeConverter.needsConversion(i.value())) { 384 auto toTy = typeConverter.convertType(i.value()); 385 op->getResult(i.index()).setType(toTy); 386 } 387 388 // Convert the type attributes if needed 389 for (const mlir::NamedAttribute &attr : op->getAttrDictionary()) 390 if (auto tyAttr = llvm::dyn_cast<mlir::TypeAttr>(attr.getValue())) 391 if (typeConverter.needsConversion(tyAttr.getValue())) { 392 auto toTy = typeConverter.convertType(tyAttr.getValue()); 393 op->setAttr(attr.getName(), mlir::TypeAttr::get(toTy)); 394 } 395 rewriter.finalizeOpModification(op); 396 } 397 // Ensure block arguments are updated if needed. 398 if (opIsValid && op->getNumRegions() != 0) { 399 rewriter.startOpModification(op); 400 for (mlir::Region ®ion : op->getRegions()) 401 for (mlir::Block &block : region.getBlocks()) 402 for (mlir::BlockArgument blockArg : block.getArguments()) 403 if (typeConverter.needsConversion(blockArg.getType())) { 404 mlir::Type toTy = 405 typeConverter.convertType(blockArg.getType()); 406 blockArg.setType(toTy); 407 } 408 rewriter.finalizeOpModification(op); 409 } 410 }); 411 } 412 } 413 414 private: 415 BoxedProcedureOptions options; 416 }; 417 } // namespace 418