1 //===-- BoxedProcedure.cpp ------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Optimizer/CodeGen/CodeGen.h" 10 11 #include "flang/Optimizer/Builder/FIRBuilder.h" 12 #include "flang/Optimizer/Builder/LowLevelIntrinsics.h" 13 #include "flang/Optimizer/Dialect/FIRDialect.h" 14 #include "flang/Optimizer/Dialect/FIROps.h" 15 #include "flang/Optimizer/Dialect/FIRType.h" 16 #include "flang/Optimizer/Support/FIRContext.h" 17 #include "flang/Optimizer/Support/FatalError.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/Pass/Pass.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 22 namespace fir { 23 #define GEN_PASS_DEF_BOXEDPROCEDUREPASS 24 #include "flang/Optimizer/CodeGen/CGPasses.h.inc" 25 } // namespace fir 26 27 #define DEBUG_TYPE "flang-procedure-pointer" 28 29 using namespace fir; 30 31 namespace { 32 /// Options to the procedure pointer pass. 33 struct BoxedProcedureOptions { 34 // Lower the boxproc abstraction to function pointers and thunks where 35 // required. 36 bool useThunks = true; 37 }; 38 39 /// This type converter rewrites all `!fir.boxproc<Func>` types to `Func` types. 40 class BoxprocTypeRewriter : public mlir::TypeConverter { 41 public: 42 using mlir::TypeConverter::convertType; 43 44 /// Does the type \p ty need to be converted? 45 /// Any type that is a `!fir.boxproc` in whole or in part will need to be 46 /// converted to a function type to lower the IR to function pointer form in 47 /// the default implementation performed in this pass. Other implementations 48 /// are possible, so those may convert `!fir.boxproc` to some other type or 49 /// not at all depending on the implementation target's characteristics and 50 /// preference. 51 bool needsConversion(mlir::Type ty) { 52 if (ty.isa<BoxProcType>()) 53 return true; 54 if (auto funcTy = ty.dyn_cast<mlir::FunctionType>()) { 55 for (auto t : funcTy.getInputs()) 56 if (needsConversion(t)) 57 return true; 58 for (auto t : funcTy.getResults()) 59 if (needsConversion(t)) 60 return true; 61 return false; 62 } 63 if (auto tupleTy = ty.dyn_cast<mlir::TupleType>()) { 64 for (auto t : tupleTy.getTypes()) 65 if (needsConversion(t)) 66 return true; 67 return false; 68 } 69 if (auto recTy = ty.dyn_cast<RecordType>()) { 70 if (llvm::is_contained(visitedTypes, recTy)) 71 return false; 72 bool result = false; 73 visitedTypes.push_back(recTy); 74 for (auto t : recTy.getTypeList()) { 75 if (needsConversion(t.second)) { 76 result = true; 77 break; 78 } 79 } 80 visitedTypes.pop_back(); 81 return result; 82 } 83 if (auto boxTy = ty.dyn_cast<BoxType>()) 84 return needsConversion(boxTy.getEleTy()); 85 if (isa_ref_type(ty)) 86 return needsConversion(unwrapRefType(ty)); 87 if (auto t = ty.dyn_cast<SequenceType>()) 88 return needsConversion(unwrapSequenceType(ty)); 89 return false; 90 } 91 92 BoxprocTypeRewriter(mlir::Location location) : loc{location} { 93 addConversion([](mlir::Type ty) { return ty; }); 94 addConversion( 95 [&](BoxProcType boxproc) { return convertType(boxproc.getEleTy()); }); 96 addConversion([&](mlir::TupleType tupTy) { 97 llvm::SmallVector<mlir::Type> memTys; 98 for (auto ty : tupTy.getTypes()) 99 memTys.push_back(convertType(ty)); 100 return mlir::TupleType::get(tupTy.getContext(), memTys); 101 }); 102 addConversion([&](mlir::FunctionType funcTy) { 103 llvm::SmallVector<mlir::Type> inTys; 104 llvm::SmallVector<mlir::Type> resTys; 105 for (auto ty : funcTy.getInputs()) 106 inTys.push_back(convertType(ty)); 107 for (auto ty : funcTy.getResults()) 108 resTys.push_back(convertType(ty)); 109 return mlir::FunctionType::get(funcTy.getContext(), inTys, resTys); 110 }); 111 addConversion([&](ReferenceType ty) { 112 return ReferenceType::get(convertType(ty.getEleTy())); 113 }); 114 addConversion([&](PointerType ty) { 115 return PointerType::get(convertType(ty.getEleTy())); 116 }); 117 addConversion( 118 [&](HeapType ty) { return HeapType::get(convertType(ty.getEleTy())); }); 119 addConversion( 120 [&](BoxType ty) { return BoxType::get(convertType(ty.getEleTy())); }); 121 addConversion([&](SequenceType ty) { 122 // TODO: add ty.getLayoutMap() as needed. 123 return SequenceType::get(ty.getShape(), convertType(ty.getEleTy())); 124 }); 125 addConversion([&](RecordType ty) -> mlir::Type { 126 if (!needsConversion(ty)) 127 return ty; 128 // FIR record types can have recursive references, so conversion is a bit 129 // more complex than the other types. This conversion is not needed 130 // presently, so just emit a TODO message. Need to consider the uniqued 131 // name of the record, etc. Also, fir::RecordType::get returns the 132 // existing type being translated. So finalize() will not change it, and 133 // the translation would not do anything. So the type needs to be mutated, 134 // and this might require special care to comply with MLIR infrastructure. 135 136 // TODO: this will be needed to support derived type containing procedure 137 // pointer components. 138 fir::emitFatalError( 139 loc, "not yet implemented: record type with a boxproc type"); 140 return RecordType::get(ty.getContext(), "*fixme*"); 141 }); 142 addArgumentMaterialization(materializeProcedure); 143 addSourceMaterialization(materializeProcedure); 144 addTargetMaterialization(materializeProcedure); 145 } 146 147 static mlir::Value materializeProcedure(mlir::OpBuilder &builder, 148 BoxProcType type, 149 mlir::ValueRange inputs, 150 mlir::Location loc) { 151 assert(inputs.size() == 1); 152 return builder.create<ConvertOp>(loc, unwrapRefType(type.getEleTy()), 153 inputs[0]); 154 } 155 156 void setLocation(mlir::Location location) { loc = location; } 157 158 private: 159 llvm::SmallVector<mlir::Type> visitedTypes; 160 mlir::Location loc; 161 }; 162 163 /// A `boxproc` is an abstraction for a Fortran procedure reference. Typically, 164 /// Fortran procedures can be referenced directly through a function pointer. 165 /// However, Fortran has one-level dynamic scoping between a host procedure and 166 /// its internal procedures. This allows internal procedures to directly access 167 /// and modify the state of the host procedure's variables. 168 /// 169 /// There are any number of possible implementations possible. 170 /// 171 /// The implementation used here is to convert `boxproc` values to function 172 /// pointers everywhere. If a `boxproc` value includes a frame pointer to the 173 /// host procedure's data, then a thunk will be created at runtime to capture 174 /// the frame pointer during execution. In LLVM IR, the frame pointer is 175 /// designated with the `nest` attribute. The thunk's address will then be used 176 /// as the call target instead of the original function's address directly. 177 class BoxedProcedurePass 178 : public fir::impl::BoxedProcedurePassBase<BoxedProcedurePass> { 179 public: 180 BoxedProcedurePass() { options = {true}; } 181 BoxedProcedurePass(bool useThunks) { options = {useThunks}; } 182 183 inline mlir::ModuleOp getModule() { return getOperation(); } 184 185 void runOnOperation() override final { 186 if (options.useThunks) { 187 auto *context = &getContext(); 188 mlir::IRRewriter rewriter(context); 189 BoxprocTypeRewriter typeConverter(mlir::UnknownLoc::get(context)); 190 mlir::Dialect *firDialect = context->getLoadedDialect("fir"); 191 getModule().walk([&](mlir::Operation *op) { 192 typeConverter.setLocation(op->getLoc()); 193 if (auto addr = mlir::dyn_cast<BoxAddrOp>(op)) { 194 auto ty = addr.getVal().getType(); 195 if (typeConverter.needsConversion(ty) || 196 ty.isa<mlir::FunctionType>()) { 197 // Rewrite all `fir.box_addr` ops on values of type `!fir.boxproc` 198 // or function type to be `fir.convert` ops. 199 rewriter.setInsertionPoint(addr); 200 rewriter.replaceOpWithNewOp<ConvertOp>( 201 addr, typeConverter.convertType(addr.getType()), addr.getVal()); 202 } 203 } else if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op)) { 204 mlir::FunctionType ty = func.getFunctionType(); 205 if (typeConverter.needsConversion(ty)) { 206 rewriter.startRootUpdate(func); 207 auto toTy = 208 typeConverter.convertType(ty).cast<mlir::FunctionType>(); 209 if (!func.empty()) 210 for (auto e : llvm::enumerate(toTy.getInputs())) { 211 unsigned i = e.index(); 212 auto &block = func.front(); 213 block.insertArgument(i, e.value(), func.getLoc()); 214 block.getArgument(i + 1).replaceAllUsesWith( 215 block.getArgument(i)); 216 block.eraseArgument(i + 1); 217 } 218 func.setType(toTy); 219 rewriter.finalizeRootUpdate(func); 220 } 221 } else if (auto embox = mlir::dyn_cast<EmboxProcOp>(op)) { 222 // Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk 223 // as required. 224 mlir::Type toTy = embox.getType().cast<BoxProcType>().getEleTy(); 225 rewriter.setInsertionPoint(embox); 226 if (embox.getHost()) { 227 // Create the thunk. 228 auto module = embox->getParentOfType<mlir::ModuleOp>(); 229 fir::KindMapping kindMap = getKindMapping(module); 230 FirOpBuilder builder(rewriter, kindMap); 231 auto loc = embox.getLoc(); 232 mlir::Type i8Ty = builder.getI8Type(); 233 mlir::Type i8Ptr = builder.getRefType(i8Ty); 234 mlir::Type buffTy = SequenceType::get({32}, i8Ty); 235 auto buffer = builder.create<AllocaOp>(loc, buffTy); 236 mlir::Value closure = 237 builder.createConvert(loc, i8Ptr, embox.getHost()); 238 mlir::Value tramp = builder.createConvert(loc, i8Ptr, buffer); 239 mlir::Value func = 240 builder.createConvert(loc, i8Ptr, embox.getFunc()); 241 builder.create<fir::CallOp>( 242 loc, factory::getLlvmInitTrampoline(builder), 243 llvm::ArrayRef<mlir::Value>{tramp, func, closure}); 244 auto adjustCall = builder.create<fir::CallOp>( 245 loc, factory::getLlvmAdjustTrampoline(builder), 246 llvm::ArrayRef<mlir::Value>{tramp}); 247 rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy, 248 adjustCall.getResult(0)); 249 } else { 250 // Just forward the function as a pointer. 251 rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy, 252 embox.getFunc()); 253 } 254 } else if (auto mem = mlir::dyn_cast<AllocaOp>(op)) { 255 auto ty = mem.getType(); 256 if (typeConverter.needsConversion(ty)) { 257 rewriter.setInsertionPoint(mem); 258 auto toTy = typeConverter.convertType(unwrapRefType(ty)); 259 bool isPinned = mem.getPinned(); 260 llvm::StringRef uniqName = 261 mem.getUniqName().value_or(llvm::StringRef()); 262 llvm::StringRef bindcName = 263 mem.getBindcName().value_or(llvm::StringRef()); 264 rewriter.replaceOpWithNewOp<AllocaOp>( 265 mem, toTy, uniqName, bindcName, isPinned, mem.getTypeparams(), 266 mem.getShape()); 267 } 268 } else if (auto mem = mlir::dyn_cast<AllocMemOp>(op)) { 269 auto ty = mem.getType(); 270 if (typeConverter.needsConversion(ty)) { 271 rewriter.setInsertionPoint(mem); 272 auto toTy = typeConverter.convertType(unwrapRefType(ty)); 273 llvm::StringRef uniqName = 274 mem.getUniqName().value_or(llvm::StringRef()); 275 llvm::StringRef bindcName = 276 mem.getBindcName().value_or(llvm::StringRef()); 277 rewriter.replaceOpWithNewOp<AllocMemOp>( 278 mem, toTy, uniqName, bindcName, mem.getTypeparams(), 279 mem.getShape()); 280 } 281 } else if (auto coor = mlir::dyn_cast<CoordinateOp>(op)) { 282 auto ty = coor.getType(); 283 mlir::Type baseTy = coor.getBaseType(); 284 if (typeConverter.needsConversion(ty) || 285 typeConverter.needsConversion(baseTy)) { 286 rewriter.setInsertionPoint(coor); 287 auto toTy = typeConverter.convertType(ty); 288 auto toBaseTy = typeConverter.convertType(baseTy); 289 rewriter.replaceOpWithNewOp<CoordinateOp>(coor, toTy, coor.getRef(), 290 coor.getCoor(), toBaseTy); 291 } 292 } else if (auto index = mlir::dyn_cast<FieldIndexOp>(op)) { 293 auto ty = index.getType(); 294 mlir::Type onTy = index.getOnType(); 295 if (typeConverter.needsConversion(ty) || 296 typeConverter.needsConversion(onTy)) { 297 rewriter.setInsertionPoint(index); 298 auto toTy = typeConverter.convertType(ty); 299 auto toOnTy = typeConverter.convertType(onTy); 300 rewriter.replaceOpWithNewOp<FieldIndexOp>( 301 index, toTy, index.getFieldId(), toOnTy, index.getTypeparams()); 302 } 303 } else if (auto index = mlir::dyn_cast<LenParamIndexOp>(op)) { 304 auto ty = index.getType(); 305 mlir::Type onTy = index.getOnType(); 306 if (typeConverter.needsConversion(ty) || 307 typeConverter.needsConversion(onTy)) { 308 rewriter.setInsertionPoint(index); 309 auto toTy = typeConverter.convertType(ty); 310 auto toOnTy = typeConverter.convertType(onTy); 311 rewriter.replaceOpWithNewOp<LenParamIndexOp>( 312 mem, toTy, index.getFieldId(), toOnTy, index.getTypeparams()); 313 } 314 } else if (op->getDialect() == firDialect) { 315 rewriter.startRootUpdate(op); 316 for (auto i : llvm::enumerate(op->getResultTypes())) 317 if (typeConverter.needsConversion(i.value())) { 318 auto toTy = typeConverter.convertType(i.value()); 319 op->getResult(i.index()).setType(toTy); 320 } 321 rewriter.finalizeRootUpdate(op); 322 } 323 }); 324 } 325 // TODO: any alternative implementation. Note: currently, the default code 326 // gen will not be able to handle boxproc and will give an error. 327 } 328 329 private: 330 BoxedProcedureOptions options; 331 }; 332 } // namespace 333 334 std::unique_ptr<mlir::Pass> fir::createBoxedProcedurePass() { 335 return std::make_unique<BoxedProcedurePass>(); 336 } 337 338 std::unique_ptr<mlir::Pass> fir::createBoxedProcedurePass(bool useThunks) { 339 return std::make_unique<BoxedProcedurePass>(useThunks); 340 } 341