1 //===- AbstractResult.cpp - Conversion of Abstract Function Result --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Optimizer/Builder/FIRBuilder.h" 10 #include "flang/Optimizer/Builder/Todo.h" 11 #include "flang/Optimizer/Dialect/FIRDialect.h" 12 #include "flang/Optimizer/Dialect/FIROps.h" 13 #include "flang/Optimizer/Dialect/FIRType.h" 14 #include "flang/Optimizer/Dialect/Support/FIRContext.h" 15 #include "flang/Optimizer/Transforms/Passes.h" 16 #include "mlir/Dialect/Func/IR/FuncOps.h" 17 #include "mlir/IR/Diagnostics.h" 18 #include "mlir/Pass/Pass.h" 19 #include "mlir/Transforms/DialectConversion.h" 20 #include "mlir/Transforms/Passes.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 23 namespace fir { 24 #define GEN_PASS_DEF_ABSTRACTRESULTONFUNCOPT 25 #define GEN_PASS_DEF_ABSTRACTRESULTONGLOBALOPT 26 #include "flang/Optimizer/Transforms/Passes.h.inc" 27 } // namespace fir 28 29 #define DEBUG_TYPE "flang-abstract-result-opt" 30 31 using namespace mlir; 32 33 namespace fir { 34 namespace { 35 36 static mlir::Type getResultArgumentType(mlir::Type resultType, 37 bool shouldBoxResult) { 38 return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType) 39 .Case<fir::SequenceType, fir::RecordType>( 40 [&](mlir::Type type) -> mlir::Type { 41 if (shouldBoxResult) 42 return fir::BoxType::get(type); 43 return fir::ReferenceType::get(type); 44 }) 45 .Case<fir::BaseBoxType>([](mlir::Type type) -> mlir::Type { 46 return fir::ReferenceType::get(type); 47 }) 48 .Default([](mlir::Type) -> mlir::Type { 49 llvm_unreachable("bad abstract result type"); 50 }); 51 } 52 53 static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy, 54 bool shouldBoxResult) { 55 auto resultType = funcTy.getResult(0); 56 auto argTy = getResultArgumentType(resultType, shouldBoxResult); 57 llvm::SmallVector<mlir::Type> newInputTypes = {argTy}; 58 newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end()); 59 return mlir::FunctionType::get(funcTy.getContext(), newInputTypes, 60 /*resultTypes=*/{}); 61 } 62 63 /// This is for function result types that are of type C_PTR from ISO_C_BINDING. 64 /// Follow the ABI for interoperability with C. 65 static mlir::FunctionType getCPtrFunctionType(mlir::FunctionType funcTy) { 66 auto resultType = funcTy.getResult(0); 67 assert(fir::isa_builtin_cptr_type(resultType)); 68 llvm::SmallVector<mlir::Type> outputTypes; 69 auto recTy = resultType.dyn_cast<fir::RecordType>(); 70 outputTypes.emplace_back(recTy.getTypeList()[0].second); 71 return mlir::FunctionType::get(funcTy.getContext(), funcTy.getInputs(), 72 outputTypes); 73 } 74 75 static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) { 76 return resultType.isa<fir::SequenceType, fir::RecordType>() && 77 shouldBoxResult; 78 } 79 80 template <typename Op> 81 class CallConversion : public mlir::OpRewritePattern<Op> { 82 public: 83 using mlir::OpRewritePattern<Op>::OpRewritePattern; 84 85 CallConversion(mlir::MLIRContext *context, bool shouldBoxResult) 86 : OpRewritePattern<Op>(context, 1), shouldBoxResult{shouldBoxResult} {} 87 88 mlir::LogicalResult 89 matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override { 90 auto loc = op.getLoc(); 91 auto result = op->getResult(0); 92 if (!result.hasOneUse()) { 93 mlir::emitError(loc, 94 "calls with abstract result must have exactly one user"); 95 return mlir::failure(); 96 } 97 auto saveResult = 98 mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser()); 99 if (!saveResult) { 100 mlir::emitError( 101 loc, "calls with abstract result must be used in fir.save_result"); 102 return mlir::failure(); 103 } 104 auto argType = getResultArgumentType(result.getType(), shouldBoxResult); 105 auto buffer = saveResult.getMemref(); 106 mlir::Value arg = buffer; 107 if (mustEmboxResult(result.getType(), shouldBoxResult)) 108 arg = rewriter.create<fir::EmboxOp>( 109 loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{}, 110 saveResult.getTypeparams()); 111 112 llvm::SmallVector<mlir::Type> newResultTypes; 113 // TODO: This should be generalized for derived types, and it is 114 // architecture and OS dependent. 115 bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType()); 116 Op newOp; 117 if (isResultBuiltinCPtr) { 118 auto recTy = result.getType().template dyn_cast<fir::RecordType>(); 119 newResultTypes.emplace_back(recTy.getTypeList()[0].second); 120 } 121 122 // fir::CallOp specific handling. 123 if constexpr (std::is_same_v<Op, fir::CallOp>) { 124 if (op.getCallee()) { 125 llvm::SmallVector<mlir::Value> newOperands; 126 if (!isResultBuiltinCPtr) 127 newOperands.emplace_back(arg); 128 newOperands.append(op.getOperands().begin(), op.getOperands().end()); 129 newOp = rewriter.create<fir::CallOp>(loc, *op.getCallee(), 130 newResultTypes, newOperands); 131 } else { 132 // Indirect calls. 133 llvm::SmallVector<mlir::Type> newInputTypes; 134 if (!isResultBuiltinCPtr) 135 newInputTypes.emplace_back(argType); 136 for (auto operand : op.getOperands().drop_front()) 137 newInputTypes.push_back(operand.getType()); 138 auto newFuncTy = mlir::FunctionType::get(op.getContext(), newInputTypes, 139 newResultTypes); 140 141 llvm::SmallVector<mlir::Value> newOperands; 142 newOperands.push_back( 143 rewriter.create<fir::ConvertOp>(loc, newFuncTy, op.getOperand(0))); 144 if (!isResultBuiltinCPtr) 145 newOperands.push_back(arg); 146 newOperands.append(op.getOperands().begin() + 1, 147 op.getOperands().end()); 148 newOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{}, 149 newResultTypes, newOperands); 150 } 151 } 152 153 // fir::DispatchOp specific handling. 154 if constexpr (std::is_same_v<Op, fir::DispatchOp>) { 155 llvm::SmallVector<mlir::Value> newOperands; 156 if (!isResultBuiltinCPtr) 157 newOperands.emplace_back(arg); 158 unsigned passArgShift = newOperands.size(); 159 newOperands.append(op.getOperands().begin() + 1, op.getOperands().end()); 160 161 fir::DispatchOp newDispatchOp; 162 if (op.getPassArgPos()) 163 newOp = rewriter.create<fir::DispatchOp>( 164 loc, newResultTypes, rewriter.getStringAttr(op.getMethod()), 165 op.getOperands()[0], newOperands, 166 rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift)); 167 else 168 newOp = rewriter.create<fir::DispatchOp>( 169 loc, newResultTypes, rewriter.getStringAttr(op.getMethod()), 170 op.getOperands()[0], newOperands, nullptr); 171 } 172 173 if (isResultBuiltinCPtr) { 174 mlir::Value save = saveResult.getMemref(); 175 auto module = op->template getParentOfType<mlir::ModuleOp>(); 176 fir::KindMapping kindMap = fir::getKindMapping(module); 177 FirOpBuilder builder(rewriter, kindMap); 178 mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr( 179 builder, loc, save, result.getType()); 180 rewriter.create<fir::StoreOp>(loc, newOp->getResult(0), saveAddr); 181 } 182 op->dropAllReferences(); 183 rewriter.eraseOp(op); 184 return mlir::success(); 185 } 186 187 private: 188 bool shouldBoxResult; 189 }; 190 191 class SaveResultOpConversion 192 : public mlir::OpRewritePattern<fir::SaveResultOp> { 193 public: 194 using OpRewritePattern::OpRewritePattern; 195 SaveResultOpConversion(mlir::MLIRContext *context) 196 : OpRewritePattern(context) {} 197 mlir::LogicalResult 198 matchAndRewrite(fir::SaveResultOp op, 199 mlir::PatternRewriter &rewriter) const override { 200 rewriter.eraseOp(op); 201 return mlir::success(); 202 } 203 }; 204 205 class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> { 206 public: 207 using OpRewritePattern::OpRewritePattern; 208 ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg) 209 : OpRewritePattern(context), newArg{newArg} {} 210 mlir::LogicalResult 211 matchAndRewrite(mlir::func::ReturnOp ret, 212 mlir::PatternRewriter &rewriter) const override { 213 auto loc = ret.getLoc(); 214 rewriter.setInsertionPoint(ret); 215 auto returnedValue = ret.getOperand(0); 216 bool replacedStorage = false; 217 if (auto *op = returnedValue.getDefiningOp()) 218 if (auto load = mlir::dyn_cast<fir::LoadOp>(op)) { 219 auto resultStorage = load.getMemref(); 220 // The result alloca may be behind a fir.declare, if any. 221 if (auto declare = mlir::dyn_cast_or_null<fir::DeclareOp>( 222 resultStorage.getDefiningOp())) 223 resultStorage = declare.getMemref(); 224 // TODO: This should be generalized for derived types, and it is 225 // architecture and OS dependent. 226 if (fir::isa_builtin_cptr_type(returnedValue.getType())) { 227 rewriter.eraseOp(load); 228 auto module = ret->getParentOfType<mlir::ModuleOp>(); 229 fir::KindMapping kindMap = fir::getKindMapping(module); 230 FirOpBuilder builder(rewriter, kindMap); 231 mlir::Value retAddr = fir::factory::genCPtrOrCFunptrAddr( 232 builder, loc, resultStorage, returnedValue.getType()); 233 mlir::Value retValue = rewriter.create<fir::LoadOp>( 234 loc, fir::unwrapRefType(retAddr.getType()), retAddr); 235 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>( 236 ret, mlir::ValueRange{retValue}); 237 return mlir::success(); 238 } 239 resultStorage.replaceAllUsesWith(newArg); 240 replacedStorage = true; 241 if (auto *alloc = resultStorage.getDefiningOp()) 242 if (alloc->use_empty()) 243 rewriter.eraseOp(alloc); 244 } 245 // The result storage may have been optimized out by a memory to 246 // register pass, this is possible for fir.box results, or fir.record 247 // with no length parameters. Simply store the result in the result storage. 248 // at the return point. 249 if (!replacedStorage) 250 rewriter.create<fir::StoreOp>(loc, returnedValue, newArg); 251 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret); 252 return mlir::success(); 253 } 254 255 private: 256 mlir::Value newArg; 257 }; 258 259 class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> { 260 public: 261 using OpRewritePattern::OpRewritePattern; 262 AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult) 263 : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {} 264 mlir::LogicalResult 265 matchAndRewrite(fir::AddrOfOp addrOf, 266 mlir::PatternRewriter &rewriter) const override { 267 auto oldFuncTy = addrOf.getType().cast<mlir::FunctionType>(); 268 mlir::FunctionType newFuncTy; 269 // TODO: This should be generalized for derived types, and it is 270 // architecture and OS dependent. 271 if (oldFuncTy.getNumResults() != 0 && 272 fir::isa_builtin_cptr_type(oldFuncTy.getResult(0))) 273 newFuncTy = getCPtrFunctionType(oldFuncTy); 274 else 275 newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult); 276 auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy, 277 addrOf.getSymbol()); 278 // Rather than converting all op a function pointer might transit through 279 // (e.g calls, stores, loads, converts...), cast new type to the abstract 280 // type. A conversion will be added when calling indirect calls of abstract 281 // types. 282 rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf); 283 return mlir::success(); 284 } 285 286 private: 287 bool shouldBoxResult; 288 }; 289 290 /// @brief Base CRTP class for AbstractResult pass family. 291 /// Contains common logic for abstract result conversion in a reusable fashion. 292 /// @tparam Pass target class that implements operation-specific logic. 293 /// @tparam PassBase base class template for the pass generated by TableGen. 294 /// The `Pass` class must define runOnSpecificOperation(OpTy, bool, 295 /// mlir::RewritePatternSet&, mlir::ConversionTarget&) member function. 296 /// This function should implement operation-specific functionality. 297 template <typename Pass, template <typename> class PassBase> 298 class AbstractResultOptTemplate : public PassBase<Pass> { 299 public: 300 void runOnOperation() override { 301 auto *context = &this->getContext(); 302 auto op = this->getOperation(); 303 304 mlir::RewritePatternSet patterns(context); 305 mlir::ConversionTarget target = *context; 306 const bool shouldBoxResult = this->passResultAsBox.getValue(); 307 308 auto &self = static_cast<Pass &>(*this); 309 self.runOnSpecificOperation(op, shouldBoxResult, patterns, target); 310 311 // Convert the calls and, if needed, the ReturnOp in the function body. 312 target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect, 313 mlir::func::FuncDialect>(); 314 target.addIllegalOp<fir::SaveResultOp>(); 315 target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) { 316 return !hasAbstractResult(call.getFunctionType()); 317 }); 318 target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) { 319 if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>()) 320 return !hasAbstractResult(funTy); 321 return true; 322 }); 323 target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) { 324 return !hasAbstractResult(dispatch.getFunctionType()); 325 }); 326 327 patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult); 328 patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult); 329 patterns.insert<SaveResultOpConversion>(context); 330 patterns.insert<AddrOfOpConversion>(context, shouldBoxResult); 331 if (mlir::failed( 332 mlir::applyPartialConversion(op, target, std::move(patterns)))) { 333 mlir::emitError(op.getLoc(), "error in converting abstract results\n"); 334 this->signalPassFailure(); 335 } 336 } 337 }; 338 339 class AbstractResultOnFuncOpt 340 : public AbstractResultOptTemplate<AbstractResultOnFuncOpt, 341 fir::impl::AbstractResultOnFuncOptBase> { 342 public: 343 void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult, 344 mlir::RewritePatternSet &patterns, 345 mlir::ConversionTarget &target) { 346 auto loc = func.getLoc(); 347 auto *context = &getContext(); 348 // Convert function type itself if it has an abstract result. 349 auto funcTy = func.getFunctionType().cast<mlir::FunctionType>(); 350 if (hasAbstractResult(funcTy)) { 351 // TODO: This should be generalized for derived types, and it is 352 // architecture and OS dependent. 353 if (fir::isa_builtin_cptr_type(funcTy.getResult(0))) { 354 func.setType(getCPtrFunctionType(funcTy)); 355 patterns.insert<ReturnOpConversion>(context, mlir::Value{}); 356 target.addDynamicallyLegalOp<mlir::func::ReturnOp>( 357 [](mlir::func::ReturnOp ret) { 358 mlir::Type retTy = ret.getOperand(0).getType(); 359 return !fir::isa_builtin_cptr_type(retTy); 360 }); 361 return; 362 } 363 if (!func.empty()) { 364 // Insert new argument. 365 mlir::OpBuilder rewriter(context); 366 auto resultType = funcTy.getResult(0); 367 auto argTy = getResultArgumentType(resultType, shouldBoxResult); 368 func.insertArgument(0u, argTy, {}, loc); 369 func.eraseResult(0u); 370 mlir::Value newArg = func.getArgument(0u); 371 if (mustEmboxResult(resultType, shouldBoxResult)) { 372 auto bufferType = fir::ReferenceType::get(resultType); 373 rewriter.setInsertionPointToStart(&func.front()); 374 newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg); 375 } 376 patterns.insert<ReturnOpConversion>(context, newArg); 377 target.addDynamicallyLegalOp<mlir::func::ReturnOp>( 378 [](mlir::func::ReturnOp ret) { return ret.getOperands().empty(); }); 379 assert(func.getFunctionType() == 380 getNewFunctionType(funcTy, shouldBoxResult)); 381 } else { 382 llvm::SmallVector<mlir::DictionaryAttr> allArgs; 383 func.getAllArgAttrs(allArgs); 384 allArgs.insert(allArgs.begin(), 385 mlir::DictionaryAttr::get(func->getContext())); 386 func.setType(getNewFunctionType(funcTy, shouldBoxResult)); 387 func.setAllArgAttrs(allArgs); 388 } 389 } 390 } 391 }; 392 393 inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) { 394 return mlir::TypeSwitch<mlir::Type, bool>(type) 395 .Case([](fir::BoxProcType boxProc) { 396 return fir::hasAbstractResult( 397 boxProc.getEleTy().cast<mlir::FunctionType>()); 398 }) 399 .Case([](fir::PointerType pointer) { 400 return fir::hasAbstractResult( 401 pointer.getEleTy().cast<mlir::FunctionType>()); 402 }) 403 .Default([](auto &&) { return false; }); 404 } 405 406 class AbstractResultOnGlobalOpt 407 : public AbstractResultOptTemplate< 408 AbstractResultOnGlobalOpt, fir::impl::AbstractResultOnGlobalOptBase> { 409 public: 410 void runOnSpecificOperation(fir::GlobalOp global, bool, 411 mlir::RewritePatternSet &, 412 mlir::ConversionTarget &) { 413 if (containsFunctionTypeWithAbstractResult(global.getType())) { 414 TODO(global->getLoc(), "support for procedure pointers"); 415 } 416 } 417 }; 418 } // end anonymous namespace 419 } // namespace fir 420 421 std::unique_ptr<mlir::Pass> fir::createAbstractResultOnFuncOptPass() { 422 return std::make_unique<AbstractResultOnFuncOpt>(); 423 } 424 425 std::unique_ptr<mlir::Pass> fir::createAbstractResultOnGlobalOptPass() { 426 return std::make_unique<AbstractResultOnGlobalOpt>(); 427 } 428