1 //===-- BoxedProcedure.cpp ------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "flang/Optimizer/Builder/FIRBuilder.h" 11 #include "flang/Optimizer/Builder/LowLevelIntrinsics.h" 12 #include "flang/Optimizer/CodeGen/CodeGen.h" 13 #include "flang/Optimizer/Dialect/FIRDialect.h" 14 #include "flang/Optimizer/Dialect/FIROps.h" 15 #include "flang/Optimizer/Dialect/FIRType.h" 16 #include "flang/Optimizer/Support/FIRContext.h" 17 #include "flang/Optimizer/Support/FatalError.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/Pass/Pass.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 22 #define DEBUG_TYPE "flang-procedure-pointer" 23 24 using namespace fir; 25 26 namespace { 27 /// Options to the procedure pointer pass. 28 struct BoxedProcedureOptions { 29 // Lower the boxproc abstraction to function pointers and thunks where 30 // required. 31 bool useThunks = true; 32 }; 33 34 /// This type converter rewrites all `!fir.boxproc<Func>` types to `Func` types. 35 class BoxprocTypeRewriter : public mlir::TypeConverter { 36 public: 37 using mlir::TypeConverter::convertType; 38 39 /// Does the type \p ty need to be converted? 40 /// Any type that is a `!fir.boxproc` in whole or in part will need to be 41 /// converted to a function type to lower the IR to function pointer form in 42 /// the default implementation performed in this pass. Other implementations 43 /// are possible, so those may convert `!fir.boxproc` to some other type or 44 /// not at all depending on the implementation target's characteristics and 45 /// preference. 46 bool needsConversion(mlir::Type ty) { 47 if (ty.isa<BoxProcType>()) 48 return true; 49 if (auto funcTy = ty.dyn_cast<mlir::FunctionType>()) { 50 for (auto t : funcTy.getInputs()) 51 if (needsConversion(t)) 52 return true; 53 for (auto t : funcTy.getResults()) 54 if (needsConversion(t)) 55 return true; 56 return false; 57 } 58 if (auto tupleTy = ty.dyn_cast<mlir::TupleType>()) { 59 for (auto t : tupleTy.getTypes()) 60 if (needsConversion(t)) 61 return true; 62 return false; 63 } 64 if (auto recTy = ty.dyn_cast<RecordType>()) { 65 if (llvm::any_of(visitedTypes, 66 [&](mlir::Type rt) { return rt == recTy; })) 67 return false; 68 bool result = false; 69 visitedTypes.push_back(recTy); 70 for (auto t : recTy.getTypeList()) { 71 if (needsConversion(t.second)) { 72 result = true; 73 break; 74 } 75 } 76 visitedTypes.pop_back(); 77 return result; 78 } 79 if (auto boxTy = ty.dyn_cast<BoxType>()) 80 return needsConversion(boxTy.getEleTy()); 81 if (isa_ref_type(ty)) 82 return needsConversion(unwrapRefType(ty)); 83 if (auto t = ty.dyn_cast<SequenceType>()) 84 return needsConversion(unwrapSequenceType(ty)); 85 return false; 86 } 87 88 BoxprocTypeRewriter(mlir::Location location) : loc{location} { 89 addConversion([](mlir::Type ty) { return ty; }); 90 addConversion( 91 [&](BoxProcType boxproc) { return convertType(boxproc.getEleTy()); }); 92 addConversion([&](mlir::TupleType tupTy) { 93 llvm::SmallVector<mlir::Type> memTys; 94 for (auto ty : tupTy.getTypes()) 95 memTys.push_back(convertType(ty)); 96 return mlir::TupleType::get(tupTy.getContext(), memTys); 97 }); 98 addConversion([&](mlir::FunctionType funcTy) { 99 llvm::SmallVector<mlir::Type> inTys; 100 llvm::SmallVector<mlir::Type> resTys; 101 for (auto ty : funcTy.getInputs()) 102 inTys.push_back(convertType(ty)); 103 for (auto ty : funcTy.getResults()) 104 resTys.push_back(convertType(ty)); 105 return mlir::FunctionType::get(funcTy.getContext(), inTys, resTys); 106 }); 107 addConversion([&](ReferenceType ty) { 108 return ReferenceType::get(convertType(ty.getEleTy())); 109 }); 110 addConversion([&](PointerType ty) { 111 return PointerType::get(convertType(ty.getEleTy())); 112 }); 113 addConversion( 114 [&](HeapType ty) { return HeapType::get(convertType(ty.getEleTy())); }); 115 addConversion( 116 [&](BoxType ty) { return BoxType::get(convertType(ty.getEleTy())); }); 117 addConversion([&](SequenceType ty) { 118 // TODO: add ty.getLayoutMap() as needed. 119 return SequenceType::get(ty.getShape(), convertType(ty.getEleTy())); 120 }); 121 addConversion([&](RecordType ty) -> mlir::Type { 122 if (!needsConversion(ty)) 123 return ty; 124 // FIR record types can have recursive references, so conversion is a bit 125 // more complex than the other types. This conversion is not needed 126 // presently, so just emit a TODO message. Need to consider the uniqued 127 // name of the record, etc. Also, fir::RecordType::get returns the 128 // existing type being translated. So finalize() will not change it, and 129 // the translation would not do anything. So the type needs to be mutated, 130 // and this might require special care to comply with MLIR infrastructure. 131 132 // TODO: this will be needed to support derived type containing procedure 133 // pointer components. 134 fir::emitFatalError( 135 loc, "not yet implemented: record type with a boxproc type"); 136 return RecordType::get(ty.getContext(), "*fixme*"); 137 }); 138 addArgumentMaterialization(materializeProcedure); 139 addSourceMaterialization(materializeProcedure); 140 addTargetMaterialization(materializeProcedure); 141 } 142 143 static mlir::Value materializeProcedure(mlir::OpBuilder &builder, 144 BoxProcType type, 145 mlir::ValueRange inputs, 146 mlir::Location loc) { 147 assert(inputs.size() == 1); 148 return builder.create<ConvertOp>(loc, unwrapRefType(type.getEleTy()), 149 inputs[0]); 150 } 151 152 void setLocation(mlir::Location location) { loc = location; } 153 154 private: 155 llvm::SmallVector<mlir::Type> visitedTypes; 156 mlir::Location loc; 157 }; 158 159 /// A `boxproc` is an abstraction for a Fortran procedure reference. Typically, 160 /// Fortran procedures can be referenced directly through a function pointer. 161 /// However, Fortran has one-level dynamic scoping between a host procedure and 162 /// its internal procedures. This allows internal procedures to directly access 163 /// and modify the state of the host procedure's variables. 164 /// 165 /// There are any number of possible implementations possible. 166 /// 167 /// The implementation used here is to convert `boxproc` values to function 168 /// pointers everywhere. If a `boxproc` value includes a frame pointer to the 169 /// host procedure's data, then a thunk will be created at runtime to capture 170 /// the frame pointer during execution. In LLVM IR, the frame pointer is 171 /// designated with the `nest` attribute. The thunk's address will then be used 172 /// as the call target instead of the original function's address directly. 173 class BoxedProcedurePass : public BoxedProcedurePassBase<BoxedProcedurePass> { 174 public: 175 BoxedProcedurePass() { options = {true}; } 176 BoxedProcedurePass(bool useThunks) { options = {useThunks}; } 177 178 inline mlir::ModuleOp getModule() { return getOperation(); } 179 180 void runOnOperation() override final { 181 if (options.useThunks) { 182 auto *context = &getContext(); 183 mlir::IRRewriter rewriter(context); 184 BoxprocTypeRewriter typeConverter(mlir::UnknownLoc::get(context)); 185 mlir::Dialect *firDialect = context->getLoadedDialect("fir"); 186 getModule().walk([&](mlir::Operation *op) { 187 typeConverter.setLocation(op->getLoc()); 188 if (auto addr = mlir::dyn_cast<BoxAddrOp>(op)) { 189 auto ty = addr.getVal().getType(); 190 if (typeConverter.needsConversion(ty) || 191 ty.isa<mlir::FunctionType>()) { 192 // Rewrite all `fir.box_addr` ops on values of type `!fir.boxproc` 193 // or function type to be `fir.convert` ops. 194 rewriter.setInsertionPoint(addr); 195 rewriter.replaceOpWithNewOp<ConvertOp>( 196 addr, typeConverter.convertType(addr.getType()), addr.getVal()); 197 } 198 } else if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op)) { 199 mlir::FunctionType ty = func.getFunctionType(); 200 if (typeConverter.needsConversion(ty)) { 201 rewriter.startRootUpdate(func); 202 auto toTy = 203 typeConverter.convertType(ty).cast<mlir::FunctionType>(); 204 if (!func.empty()) 205 for (auto e : llvm::enumerate(toTy.getInputs())) { 206 unsigned i = e.index(); 207 auto &block = func.front(); 208 block.insertArgument(i, e.value(), func.getLoc()); 209 block.getArgument(i + 1).replaceAllUsesWith( 210 block.getArgument(i)); 211 block.eraseArgument(i + 1); 212 } 213 func.setType(toTy); 214 rewriter.finalizeRootUpdate(func); 215 } 216 } else if (auto embox = mlir::dyn_cast<EmboxProcOp>(op)) { 217 // Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk 218 // as required. 219 mlir::Type toTy = embox.getType().cast<BoxProcType>().getEleTy(); 220 rewriter.setInsertionPoint(embox); 221 if (embox.getHost()) { 222 // Create the thunk. 223 auto module = embox->getParentOfType<mlir::ModuleOp>(); 224 fir::KindMapping kindMap = getKindMapping(module); 225 FirOpBuilder builder(rewriter, kindMap); 226 auto loc = embox.getLoc(); 227 mlir::Type i8Ty = builder.getI8Type(); 228 mlir::Type i8Ptr = builder.getRefType(i8Ty); 229 mlir::Type buffTy = SequenceType::get({32}, i8Ty); 230 auto buffer = builder.create<AllocaOp>(loc, buffTy); 231 mlir::Value closure = 232 builder.createConvert(loc, i8Ptr, embox.getHost()); 233 mlir::Value tramp = builder.createConvert(loc, i8Ptr, buffer); 234 mlir::Value func = 235 builder.createConvert(loc, i8Ptr, embox.getFunc()); 236 builder.create<fir::CallOp>( 237 loc, factory::getLlvmInitTrampoline(builder), 238 llvm::ArrayRef<mlir::Value>{tramp, func, closure}); 239 auto adjustCall = builder.create<fir::CallOp>( 240 loc, factory::getLlvmAdjustTrampoline(builder), 241 llvm::ArrayRef<mlir::Value>{tramp}); 242 rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy, 243 adjustCall.getResult(0)); 244 } else { 245 // Just forward the function as a pointer. 246 rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy, 247 embox.getFunc()); 248 } 249 } else if (auto mem = mlir::dyn_cast<AllocaOp>(op)) { 250 auto ty = mem.getType(); 251 if (typeConverter.needsConversion(ty)) { 252 rewriter.setInsertionPoint(mem); 253 auto toTy = typeConverter.convertType(unwrapRefType(ty)); 254 bool isPinned = mem.getPinned(); 255 llvm::StringRef uniqName = 256 mem.getUniqName().value_or(llvm::StringRef()); 257 llvm::StringRef bindcName = 258 mem.getBindcName().value_or(llvm::StringRef()); 259 rewriter.replaceOpWithNewOp<AllocaOp>( 260 mem, toTy, uniqName, bindcName, isPinned, mem.getTypeparams(), 261 mem.getShape()); 262 } 263 } else if (auto mem = mlir::dyn_cast<AllocMemOp>(op)) { 264 auto ty = mem.getType(); 265 if (typeConverter.needsConversion(ty)) { 266 rewriter.setInsertionPoint(mem); 267 auto toTy = typeConverter.convertType(unwrapRefType(ty)); 268 llvm::StringRef uniqName = 269 mem.getUniqName().value_or(llvm::StringRef()); 270 llvm::StringRef bindcName = 271 mem.getBindcName().value_or(llvm::StringRef()); 272 rewriter.replaceOpWithNewOp<AllocMemOp>( 273 mem, toTy, uniqName, bindcName, mem.getTypeparams(), 274 mem.getShape()); 275 } 276 } else if (auto coor = mlir::dyn_cast<CoordinateOp>(op)) { 277 auto ty = coor.getType(); 278 mlir::Type baseTy = coor.getBaseType(); 279 if (typeConverter.needsConversion(ty) || 280 typeConverter.needsConversion(baseTy)) { 281 rewriter.setInsertionPoint(coor); 282 auto toTy = typeConverter.convertType(ty); 283 auto toBaseTy = typeConverter.convertType(baseTy); 284 rewriter.replaceOpWithNewOp<CoordinateOp>(coor, toTy, coor.getRef(), 285 coor.getCoor(), toBaseTy); 286 } 287 } else if (auto index = mlir::dyn_cast<FieldIndexOp>(op)) { 288 auto ty = index.getType(); 289 mlir::Type onTy = index.getOnType(); 290 if (typeConverter.needsConversion(ty) || 291 typeConverter.needsConversion(onTy)) { 292 rewriter.setInsertionPoint(index); 293 auto toTy = typeConverter.convertType(ty); 294 auto toOnTy = typeConverter.convertType(onTy); 295 rewriter.replaceOpWithNewOp<FieldIndexOp>( 296 index, toTy, index.getFieldId(), toOnTy, index.getTypeparams()); 297 } 298 } else if (auto index = mlir::dyn_cast<LenParamIndexOp>(op)) { 299 auto ty = index.getType(); 300 mlir::Type onTy = index.getOnType(); 301 if (typeConverter.needsConversion(ty) || 302 typeConverter.needsConversion(onTy)) { 303 rewriter.setInsertionPoint(index); 304 auto toTy = typeConverter.convertType(ty); 305 auto toOnTy = typeConverter.convertType(onTy); 306 rewriter.replaceOpWithNewOp<LenParamIndexOp>( 307 mem, toTy, index.getFieldId(), toOnTy, index.getTypeparams()); 308 } 309 } else if (op->getDialect() == firDialect) { 310 rewriter.startRootUpdate(op); 311 for (auto i : llvm::enumerate(op->getResultTypes())) 312 if (typeConverter.needsConversion(i.value())) { 313 auto toTy = typeConverter.convertType(i.value()); 314 op->getResult(i.index()).setType(toTy); 315 } 316 rewriter.finalizeRootUpdate(op); 317 } 318 }); 319 } 320 // TODO: any alternative implementation. Note: currently, the default code 321 // gen will not be able to handle boxproc and will give an error. 322 } 323 324 private: 325 BoxedProcedureOptions options; 326 }; 327 } // namespace 328 329 std::unique_ptr<mlir::Pass> fir::createBoxedProcedurePass() { 330 return std::make_unique<BoxedProcedurePass>(); 331 } 332 333 std::unique_ptr<mlir::Pass> fir::createBoxedProcedurePass(bool useThunks) { 334 return std::make_unique<BoxedProcedurePass>(useThunks); 335 } 336