1 //===- AbstractResult.cpp - Conversion of Abstract Function Result --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Optimizer/Builder/FIRBuilder.h" 10 #include "flang/Optimizer/Builder/Todo.h" 11 #include "flang/Optimizer/Dialect/FIRDialect.h" 12 #include "flang/Optimizer/Dialect/FIROps.h" 13 #include "flang/Optimizer/Dialect/FIRType.h" 14 #include "flang/Optimizer/Support/FIRContext.h" 15 #include "flang/Optimizer/Transforms/Passes.h" 16 #include "mlir/Dialect/Func/IR/FuncOps.h" 17 #include "mlir/IR/Diagnostics.h" 18 #include "mlir/Pass/Pass.h" 19 #include "mlir/Transforms/DialectConversion.h" 20 #include "mlir/Transforms/Passes.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 23 namespace fir { 24 #define GEN_PASS_DEF_ABSTRACTRESULTONFUNCOPT 25 #define GEN_PASS_DEF_ABSTRACTRESULTONGLOBALOPT 26 #include "flang/Optimizer/Transforms/Passes.h.inc" 27 } // namespace fir 28 29 #define DEBUG_TYPE "flang-abstract-result-opt" 30 31 namespace fir { 32 namespace { 33 34 static mlir::Type getResultArgumentType(mlir::Type resultType, 35 bool shouldBoxResult) { 36 return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType) 37 .Case<fir::SequenceType, fir::RecordType>( 38 [&](mlir::Type type) -> mlir::Type { 39 if (shouldBoxResult) 40 return fir::BoxType::get(type); 41 return fir::ReferenceType::get(type); 42 }) 43 .Case<fir::BoxType>([](mlir::Type type) -> mlir::Type { 44 return fir::ReferenceType::get(type); 45 }) 46 .Default([](mlir::Type) -> mlir::Type { 47 llvm_unreachable("bad abstract result type"); 48 }); 49 } 50 51 static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy, 52 bool shouldBoxResult) { 53 auto resultType = funcTy.getResult(0); 54 auto argTy = getResultArgumentType(resultType, shouldBoxResult); 55 llvm::SmallVector<mlir::Type> newInputTypes = {argTy}; 56 newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end()); 57 return mlir::FunctionType::get(funcTy.getContext(), newInputTypes, 58 /*resultTypes=*/{}); 59 } 60 61 /// This is for function result types that are of type C_PTR from ISO_C_BINDING. 62 /// Follow the ABI for interoperability with C. 63 static mlir::FunctionType getCPtrFunctionType(mlir::FunctionType funcTy) { 64 auto resultType = funcTy.getResult(0); 65 assert(fir::isa_builtin_cptr_type(resultType)); 66 llvm::SmallVector<mlir::Type> outputTypes; 67 auto recTy = resultType.dyn_cast<fir::RecordType>(); 68 outputTypes.emplace_back(recTy.getTypeList()[0].second); 69 return mlir::FunctionType::get(funcTy.getContext(), funcTy.getInputs(), 70 outputTypes); 71 } 72 73 static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) { 74 return resultType.isa<fir::SequenceType, fir::RecordType>() && 75 shouldBoxResult; 76 } 77 78 class CallOpConversion : public mlir::OpRewritePattern<fir::CallOp> { 79 public: 80 using OpRewritePattern::OpRewritePattern; 81 CallOpConversion(mlir::MLIRContext *context, bool shouldBoxResult) 82 : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {} 83 mlir::LogicalResult 84 matchAndRewrite(fir::CallOp callOp, 85 mlir::PatternRewriter &rewriter) const override { 86 auto loc = callOp.getLoc(); 87 auto result = callOp->getResult(0); 88 if (!result.hasOneUse()) { 89 mlir::emitError(loc, 90 "calls with abstract result must have exactly one user"); 91 return mlir::failure(); 92 } 93 auto saveResult = 94 mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser()); 95 if (!saveResult) { 96 mlir::emitError( 97 loc, "calls with abstract result must be used in fir.save_result"); 98 return mlir::failure(); 99 } 100 auto argType = getResultArgumentType(result.getType(), shouldBoxResult); 101 auto buffer = saveResult.getMemref(); 102 mlir::Value arg = buffer; 103 if (mustEmboxResult(result.getType(), shouldBoxResult)) 104 arg = rewriter.create<fir::EmboxOp>( 105 loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{}, 106 saveResult.getTypeparams()); 107 108 llvm::SmallVector<mlir::Type> newResultTypes; 109 // TODO: This should be generalized for derived types, and it is 110 // architecture and OS dependent. 111 bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType()); 112 fir::CallOp newCallOp; 113 if (isResultBuiltinCPtr) { 114 auto recTy = result.getType().dyn_cast<fir::RecordType>(); 115 newResultTypes.emplace_back(recTy.getTypeList()[0].second); 116 } 117 if (callOp.getCallee()) { 118 llvm::SmallVector<mlir::Value> newOperands; 119 if (!isResultBuiltinCPtr) 120 newOperands.emplace_back(arg); 121 newOperands.append(callOp.getOperands().begin(), 122 callOp.getOperands().end()); 123 newCallOp = rewriter.create<fir::CallOp>(loc, *callOp.getCallee(), 124 newResultTypes, newOperands); 125 } else { 126 // Indirect calls. 127 llvm::SmallVector<mlir::Type> newInputTypes; 128 if (!isResultBuiltinCPtr) 129 newInputTypes.emplace_back(argType); 130 for (auto operand : callOp.getOperands().drop_front()) 131 newInputTypes.push_back(operand.getType()); 132 auto newFuncTy = mlir::FunctionType::get(callOp.getContext(), 133 newInputTypes, newResultTypes); 134 135 llvm::SmallVector<mlir::Value> newOperands; 136 newOperands.push_back(rewriter.create<fir::ConvertOp>( 137 loc, newFuncTy, callOp.getOperand(0))); 138 if (!isResultBuiltinCPtr) 139 newOperands.push_back(arg); 140 newOperands.append(callOp.getOperands().begin() + 1, 141 callOp.getOperands().end()); 142 newCallOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{}, 143 newResultTypes, newOperands); 144 } 145 if (isResultBuiltinCPtr) { 146 mlir::Value save = saveResult.getMemref(); 147 auto module = callOp->getParentOfType<mlir::ModuleOp>(); 148 fir::KindMapping kindMap = fir::getKindMapping(module); 149 FirOpBuilder builder(rewriter, kindMap); 150 mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr( 151 builder, loc, save, result.getType()); 152 rewriter.create<fir::StoreOp>(loc, newCallOp->getResult(0), saveAddr); 153 } 154 callOp->dropAllReferences(); 155 rewriter.eraseOp(callOp); 156 return mlir::success(); 157 } 158 159 private: 160 bool shouldBoxResult; 161 }; 162 163 class SaveResultOpConversion 164 : public mlir::OpRewritePattern<fir::SaveResultOp> { 165 public: 166 using OpRewritePattern::OpRewritePattern; 167 SaveResultOpConversion(mlir::MLIRContext *context) 168 : OpRewritePattern(context) {} 169 mlir::LogicalResult 170 matchAndRewrite(fir::SaveResultOp op, 171 mlir::PatternRewriter &rewriter) const override { 172 rewriter.eraseOp(op); 173 return mlir::success(); 174 } 175 }; 176 177 class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> { 178 public: 179 using OpRewritePattern::OpRewritePattern; 180 ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg) 181 : OpRewritePattern(context), newArg{newArg} {} 182 mlir::LogicalResult 183 matchAndRewrite(mlir::func::ReturnOp ret, 184 mlir::PatternRewriter &rewriter) const override { 185 auto loc = ret.getLoc(); 186 rewriter.setInsertionPoint(ret); 187 auto returnedValue = ret.getOperand(0); 188 bool replacedStorage = false; 189 if (auto *op = returnedValue.getDefiningOp()) 190 if (auto load = mlir::dyn_cast<fir::LoadOp>(op)) { 191 auto resultStorage = load.getMemref(); 192 // TODO: This should be generalized for derived types, and it is 193 // architecture and OS dependent. 194 if (fir::isa_builtin_cptr_type(returnedValue.getType())) { 195 rewriter.eraseOp(load); 196 auto module = ret->getParentOfType<mlir::ModuleOp>(); 197 fir::KindMapping kindMap = fir::getKindMapping(module); 198 FirOpBuilder builder(rewriter, kindMap); 199 mlir::Value retAddr = fir::factory::genCPtrOrCFunptrAddr( 200 builder, loc, resultStorage, returnedValue.getType()); 201 mlir::Value retValue = rewriter.create<fir::LoadOp>( 202 loc, fir::unwrapRefType(retAddr.getType()), retAddr); 203 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>( 204 ret, mlir::ValueRange{retValue}); 205 return mlir::success(); 206 } 207 load.getMemref().replaceAllUsesWith(newArg); 208 replacedStorage = true; 209 if (auto *alloc = resultStorage.getDefiningOp()) 210 if (alloc->use_empty()) 211 rewriter.eraseOp(alloc); 212 } 213 // The result storage may have been optimized out by a memory to 214 // register pass, this is possible for fir.box results, or fir.record 215 // with no length parameters. Simply store the result in the result storage. 216 // at the return point. 217 if (!replacedStorage) 218 rewriter.create<fir::StoreOp>(loc, returnedValue, newArg); 219 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret); 220 return mlir::success(); 221 } 222 223 private: 224 mlir::Value newArg; 225 }; 226 227 class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> { 228 public: 229 using OpRewritePattern::OpRewritePattern; 230 AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult) 231 : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {} 232 mlir::LogicalResult 233 matchAndRewrite(fir::AddrOfOp addrOf, 234 mlir::PatternRewriter &rewriter) const override { 235 auto oldFuncTy = addrOf.getType().cast<mlir::FunctionType>(); 236 mlir::FunctionType newFuncTy; 237 // TODO: This should be generalized for derived types, and it is 238 // architecture and OS dependent. 239 if (oldFuncTy.getNumResults() != 0 && 240 fir::isa_builtin_cptr_type(oldFuncTy.getResult(0))) 241 newFuncTy = getCPtrFunctionType(oldFuncTy); 242 else 243 newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult); 244 auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy, 245 addrOf.getSymbol()); 246 // Rather than converting all op a function pointer might transit through 247 // (e.g calls, stores, loads, converts...), cast new type to the abstract 248 // type. A conversion will be added when calling indirect calls of abstract 249 // types. 250 rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf); 251 return mlir::success(); 252 } 253 254 private: 255 bool shouldBoxResult; 256 }; 257 258 /// @brief Base CRTP class for AbstractResult pass family. 259 /// Contains common logic for abstract result conversion in a reusable fashion. 260 /// @tparam Pass target class that implements operation-specific logic. 261 /// @tparam PassBase base class template for the pass generated by TableGen. 262 /// The `Pass` class must define runOnSpecificOperation(OpTy, bool, 263 /// mlir::RewritePatternSet&, mlir::ConversionTarget&) member function. 264 /// This function should implement operation-specific functionality. 265 template <typename Pass, template <typename> class PassBase> 266 class AbstractResultOptTemplate : public PassBase<Pass> { 267 public: 268 void runOnOperation() override { 269 auto *context = &this->getContext(); 270 auto op = this->getOperation(); 271 272 mlir::RewritePatternSet patterns(context); 273 mlir::ConversionTarget target = *context; 274 const bool shouldBoxResult = this->passResultAsBox.getValue(); 275 276 auto &self = static_cast<Pass &>(*this); 277 self.runOnSpecificOperation(op, shouldBoxResult, patterns, target); 278 279 // Convert the calls and, if needed, the ReturnOp in the function body. 280 target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect, 281 mlir::func::FuncDialect>(); 282 target.addIllegalOp<fir::SaveResultOp>(); 283 target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) { 284 return !hasAbstractResult(call.getFunctionType()); 285 }); 286 target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) { 287 if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>()) 288 return !hasAbstractResult(funTy); 289 return true; 290 }); 291 target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) { 292 if (dispatch->getNumResults() != 1) 293 return true; 294 auto resultType = dispatch->getResult(0).getType(); 295 if (resultType.isa<fir::SequenceType, fir::BoxType, fir::RecordType>()) { 296 TODO(dispatch.getLoc(), "dispatchOp with abstract results"); 297 return false; 298 } 299 return true; 300 }); 301 302 patterns.insert<CallOpConversion>(context, shouldBoxResult); 303 patterns.insert<SaveResultOpConversion>(context); 304 patterns.insert<AddrOfOpConversion>(context, shouldBoxResult); 305 if (mlir::failed( 306 mlir::applyPartialConversion(op, target, std::move(patterns)))) { 307 mlir::emitError(op.getLoc(), "error in converting abstract results\n"); 308 this->signalPassFailure(); 309 } 310 } 311 }; 312 313 class AbstractResultOnFuncOpt 314 : public AbstractResultOptTemplate<AbstractResultOnFuncOpt, 315 fir::impl::AbstractResultOnFuncOptBase> { 316 public: 317 void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult, 318 mlir::RewritePatternSet &patterns, 319 mlir::ConversionTarget &target) { 320 auto loc = func.getLoc(); 321 auto *context = &getContext(); 322 // Convert function type itself if it has an abstract result. 323 auto funcTy = func.getFunctionType().cast<mlir::FunctionType>(); 324 if (hasAbstractResult(funcTy)) { 325 // TODO: This should be generalized for derived types, and it is 326 // architecture and OS dependent. 327 if (fir::isa_builtin_cptr_type(funcTy.getResult(0))) { 328 func.setType(getCPtrFunctionType(funcTy)); 329 patterns.insert<ReturnOpConversion>(context, mlir::Value{}); 330 target.addDynamicallyLegalOp<mlir::func::ReturnOp>( 331 [](mlir::func::ReturnOp ret) { 332 mlir::Type retTy = ret.getOperand(0).getType(); 333 return !fir::isa_builtin_cptr_type(retTy); 334 }); 335 return; 336 } 337 if (!func.empty()) { 338 // Insert new argument. 339 mlir::OpBuilder rewriter(context); 340 auto resultType = funcTy.getResult(0); 341 auto argTy = getResultArgumentType(resultType, shouldBoxResult); 342 func.insertArgument(0u, argTy, {}, loc); 343 func.eraseResult(0u); 344 mlir::Value newArg = func.getArgument(0u); 345 if (mustEmboxResult(resultType, shouldBoxResult)) { 346 auto bufferType = fir::ReferenceType::get(resultType); 347 rewriter.setInsertionPointToStart(&func.front()); 348 newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg); 349 } 350 patterns.insert<ReturnOpConversion>(context, newArg); 351 target.addDynamicallyLegalOp<mlir::func::ReturnOp>( 352 [](mlir::func::ReturnOp ret) { return ret.operands().empty(); }); 353 assert(func.getFunctionType() == 354 getNewFunctionType(funcTy, shouldBoxResult)); 355 } else { 356 func.setType(getNewFunctionType(funcTy, shouldBoxResult)); 357 } 358 } 359 } 360 }; 361 362 inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) { 363 return mlir::TypeSwitch<mlir::Type, bool>(type) 364 .Case([](fir::BoxProcType boxProc) { 365 return fir::hasAbstractResult( 366 boxProc.getEleTy().cast<mlir::FunctionType>()); 367 }) 368 .Case([](fir::PointerType pointer) { 369 return fir::hasAbstractResult( 370 pointer.getEleTy().cast<mlir::FunctionType>()); 371 }) 372 .Default([](auto &&) { return false; }); 373 } 374 375 class AbstractResultOnGlobalOpt 376 : public AbstractResultOptTemplate< 377 AbstractResultOnGlobalOpt, fir::impl::AbstractResultOnGlobalOptBase> { 378 public: 379 void runOnSpecificOperation(fir::GlobalOp global, bool, 380 mlir::RewritePatternSet &, 381 mlir::ConversionTarget &) { 382 if (containsFunctionTypeWithAbstractResult(global.getType())) { 383 TODO(global->getLoc(), "support for procedure pointers"); 384 } 385 } 386 }; 387 } // end anonymous namespace 388 } // namespace fir 389 390 std::unique_ptr<mlir::Pass> fir::createAbstractResultOnFuncOptPass() { 391 return std::make_unique<AbstractResultOnFuncOpt>(); 392 } 393 394 std::unique_ptr<mlir::Pass> fir::createAbstractResultOnGlobalOptPass() { 395 return std::make_unique<AbstractResultOnGlobalOpt>(); 396 } 397