//===- AbstractResult.cpp - Conversion of Abstract Function Result --------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "PassDetail.h" #include "flang/Optimizer/Builder/Todo.h" #include "flang/Optimizer/Dialect/FIRDialect.h" #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/Dialect/FIRType.h" #include "flang/Optimizer/Transforms/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Diagnostics.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" #include "llvm/ADT/TypeSwitch.h" #define DEBUG_TYPE "flang-abstract-result-opt" namespace fir { namespace { struct AbstractResultOptions { // Always pass result as a fir.box argument. bool boxResult = false; // New function block argument for the result if the current FuncOp had // an abstract result. mlir::Value newArg; }; static mlir::Type getResultArgumentType(mlir::Type resultType, const AbstractResultOptions &options) { return llvm::TypeSwitch(resultType) .Case( [&](mlir::Type type) -> mlir::Type { if (options.boxResult) return fir::BoxType::get(type); return fir::ReferenceType::get(type); }) .Case([](mlir::Type type) -> mlir::Type { return fir::ReferenceType::get(type); }) .Default([](mlir::Type) -> mlir::Type { llvm_unreachable("bad abstract result type"); }); } static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy, const AbstractResultOptions &options) { auto resultType = funcTy.getResult(0); auto argTy = getResultArgumentType(resultType, options); llvm::SmallVector newInputTypes = {argTy}; newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end()); return mlir::FunctionType::get(funcTy.getContext(), newInputTypes, /*resultTypes=*/{}); } static bool mustEmboxResult(mlir::Type resultType, const AbstractResultOptions &options) { return resultType.isa() && options.boxResult; } class CallOpConversion : public mlir::OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; CallOpConversion(mlir::MLIRContext *context, const AbstractResultOptions &opt) : OpRewritePattern(context), options{opt} {} mlir::LogicalResult matchAndRewrite(fir::CallOp callOp, mlir::PatternRewriter &rewriter) const override { auto loc = callOp.getLoc(); auto result = callOp->getResult(0); if (!result.hasOneUse()) { mlir::emitError(loc, "calls with abstract result must have exactly one user"); return mlir::failure(); } auto saveResult = mlir::dyn_cast(result.use_begin().getUser()); if (!saveResult) { mlir::emitError( loc, "calls with abstract result must be used in fir.save_result"); return mlir::failure(); } auto argType = getResultArgumentType(result.getType(), options); auto buffer = saveResult.getMemref(); mlir::Value arg = buffer; if (mustEmboxResult(result.getType(), options)) arg = rewriter.create( loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{}, saveResult.getTypeparams()); llvm::SmallVector newResultTypes; if (callOp.getCallee()) { llvm::SmallVector newOperands = {arg}; newOperands.append(callOp.getOperands().begin(), callOp.getOperands().end()); rewriter.create(loc, *callOp.getCallee(), newResultTypes, newOperands); } else { // Indirect calls. llvm::SmallVector newInputTypes = {argType}; for (auto operand : callOp.getOperands().drop_front()) newInputTypes.push_back(operand.getType()); auto funTy = mlir::FunctionType::get(callOp.getContext(), newInputTypes, newResultTypes); llvm::SmallVector newOperands; newOperands.push_back( rewriter.create(loc, funTy, callOp.getOperand(0))); newOperands.push_back(arg); newOperands.append(callOp.getOperands().begin() + 1, callOp.getOperands().end()); rewriter.create(loc, mlir::SymbolRefAttr{}, newResultTypes, newOperands); } callOp->dropAllReferences(); rewriter.eraseOp(callOp); return mlir::success(); } private: const AbstractResultOptions &options; }; class SaveResultOpConversion : public mlir::OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; SaveResultOpConversion(mlir::MLIRContext *context) : OpRewritePattern(context) {} mlir::LogicalResult matchAndRewrite(fir::SaveResultOp op, mlir::PatternRewriter &rewriter) const override { rewriter.eraseOp(op); return mlir::success(); } }; class ReturnOpConversion : public mlir::OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; ReturnOpConversion(mlir::MLIRContext *context, const AbstractResultOptions &opt) : OpRewritePattern(context), options{opt} {} mlir::LogicalResult matchAndRewrite(mlir::func::ReturnOp ret, mlir::PatternRewriter &rewriter) const override { rewriter.setInsertionPoint(ret); auto returnedValue = ret.getOperand(0); bool replacedStorage = false; if (auto *op = returnedValue.getDefiningOp()) if (auto load = mlir::dyn_cast(op)) { auto resultStorage = load.getMemref(); load.getMemref().replaceAllUsesWith(options.newArg); replacedStorage = true; if (auto *alloc = resultStorage.getDefiningOp()) if (alloc->use_empty()) rewriter.eraseOp(alloc); } // The result storage may have been optimized out by a memory to // register pass, this is possible for fir.box results, or fir.record // with no length parameters. Simply store the result in the result storage. // at the return point. if (!replacedStorage) rewriter.create(ret.getLoc(), returnedValue, options.newArg); rewriter.replaceOpWithNewOp(ret); return mlir::success(); } private: const AbstractResultOptions &options; }; class AddrOfOpConversion : public mlir::OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; AddrOfOpConversion(mlir::MLIRContext *context, const AbstractResultOptions &opt) : OpRewritePattern(context), options{opt} {} mlir::LogicalResult matchAndRewrite(fir::AddrOfOp addrOf, mlir::PatternRewriter &rewriter) const override { auto oldFuncTy = addrOf.getType().cast(); auto newFuncTy = getNewFunctionType(oldFuncTy, options); auto newAddrOf = rewriter.create(addrOf.getLoc(), newFuncTy, addrOf.getSymbol()); // Rather than converting all op a function pointer might transit through // (e.g calls, stores, loads, converts...), cast new type to the abstract // type. A conversion will be added when calling indirect calls of abstract // types. rewriter.replaceOpWithNewOp(addrOf, oldFuncTy, newAddrOf); return mlir::success(); } private: const AbstractResultOptions &options; }; class AbstractResultOpt : public fir::AbstractResultOptBase { public: void runOnOperation() override { auto *context = &getContext(); auto func = getOperation(); auto loc = func.getLoc(); mlir::RewritePatternSet patterns(context); mlir::ConversionTarget target = *context; AbstractResultOptions options{passResultAsBox.getValue(), /*newArg=*/{}}; // Convert function type itself if it has an abstract result auto funcTy = func.getFunctionType().cast(); if (hasAbstractResult(funcTy)) { func.setType(getNewFunctionType(funcTy, options)); unsigned zero = 0; if (!func.empty()) { // Insert new argument mlir::OpBuilder rewriter(context); auto resultType = funcTy.getResult(0); auto argTy = getResultArgumentType(resultType, options); options.newArg = func.front().insertArgument(zero, argTy, loc); if (mustEmboxResult(resultType, options)) { auto bufferType = fir::ReferenceType::get(resultType); rewriter.setInsertionPointToStart(&func.front()); options.newArg = rewriter.create(loc, bufferType, options.newArg); } patterns.insert(context, options); target.addDynamicallyLegalOp( [](mlir::func::ReturnOp ret) { return ret.operands().empty(); }); } } if (func.empty()) return; // Convert the calls and, if needed, the ReturnOp in the function body. target.addLegalDialect(); target.addIllegalOp(); target.addDynamicallyLegalOp([](fir::CallOp call) { return !hasAbstractResult(call.getFunctionType()); }); target.addDynamicallyLegalOp([](fir::AddrOfOp addrOf) { if (auto funTy = addrOf.getType().dyn_cast()) return !hasAbstractResult(funTy); return true; }); target.addDynamicallyLegalOp([](fir::DispatchOp dispatch) { if (dispatch->getNumResults() != 1) return true; auto resultType = dispatch->getResult(0).getType(); if (resultType.isa()) { TODO(dispatch.getLoc(), "dispatchOp with abstract results"); return false; } return true; }); patterns.insert(context, options); patterns.insert(context); patterns.insert(context, options); if (mlir::failed( mlir::applyPartialConversion(func, target, std::move(patterns)))) { mlir::emitError(func.getLoc(), "error in converting abstract results\n"); signalPassFailure(); } } }; } // end anonymous namespace } // namespace fir std::unique_ptr fir::createAbstractResultOptPass() { return std::make_unique(); }