1 //===- AbstractResult.cpp - Conversion of Abstract Function Result --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Optimizer/Builder/Todo.h" 10 #include "flang/Optimizer/Dialect/FIRDialect.h" 11 #include "flang/Optimizer/Dialect/FIROps.h" 12 #include "flang/Optimizer/Dialect/FIRType.h" 13 #include "flang/Optimizer/Transforms/Passes.h" 14 #include "mlir/Dialect/Func/IR/FuncOps.h" 15 #include "mlir/IR/Diagnostics.h" 16 #include "mlir/Pass/Pass.h" 17 #include "mlir/Transforms/DialectConversion.h" 18 #include "mlir/Transforms/Passes.h" 19 #include "llvm/ADT/TypeSwitch.h" 20 21 namespace fir { 22 #define GEN_PASS_DEF_ABSTRACTRESULTONFUNCOPT 23 #define GEN_PASS_DEF_ABSTRACTRESULTONGLOBALOPT 24 #include "flang/Optimizer/Transforms/Passes.h.inc" 25 } // namespace fir 26 27 #define DEBUG_TYPE "flang-abstract-result-opt" 28 29 namespace fir { 30 namespace { 31 32 static mlir::Type getResultArgumentType(mlir::Type resultType, 33 bool shouldBoxResult) { 34 return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType) 35 .Case<fir::SequenceType, fir::RecordType>( 36 [&](mlir::Type type) -> mlir::Type { 37 if (shouldBoxResult) 38 return fir::BoxType::get(type); 39 return fir::ReferenceType::get(type); 40 }) 41 .Case<fir::BoxType>([](mlir::Type type) -> mlir::Type { 42 return fir::ReferenceType::get(type); 43 }) 44 .Default([](mlir::Type) -> mlir::Type { 45 llvm_unreachable("bad abstract result type"); 46 }); 47 } 48 49 static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy, 50 bool shouldBoxResult) { 51 auto resultType = funcTy.getResult(0); 52 auto argTy = getResultArgumentType(resultType, shouldBoxResult); 53 llvm::SmallVector<mlir::Type> newInputTypes = {argTy}; 54 newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end()); 55 return mlir::FunctionType::get(funcTy.getContext(), newInputTypes, 56 /*resultTypes=*/{}); 57 } 58 59 static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) { 60 return resultType.isa<fir::SequenceType, fir::RecordType>() && 61 shouldBoxResult; 62 } 63 64 class CallOpConversion : public mlir::OpRewritePattern<fir::CallOp> { 65 public: 66 using OpRewritePattern::OpRewritePattern; 67 CallOpConversion(mlir::MLIRContext *context, bool shouldBoxResult) 68 : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {} 69 mlir::LogicalResult 70 matchAndRewrite(fir::CallOp callOp, 71 mlir::PatternRewriter &rewriter) const override { 72 auto loc = callOp.getLoc(); 73 auto result = callOp->getResult(0); 74 if (!result.hasOneUse()) { 75 mlir::emitError(loc, 76 "calls with abstract result must have exactly one user"); 77 return mlir::failure(); 78 } 79 auto saveResult = 80 mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser()); 81 if (!saveResult) { 82 mlir::emitError( 83 loc, "calls with abstract result must be used in fir.save_result"); 84 return mlir::failure(); 85 } 86 auto argType = getResultArgumentType(result.getType(), shouldBoxResult); 87 auto buffer = saveResult.getMemref(); 88 mlir::Value arg = buffer; 89 if (mustEmboxResult(result.getType(), shouldBoxResult)) 90 arg = rewriter.create<fir::EmboxOp>( 91 loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{}, 92 saveResult.getTypeparams()); 93 94 llvm::SmallVector<mlir::Type> newResultTypes; 95 if (callOp.getCallee()) { 96 llvm::SmallVector<mlir::Value> newOperands = {arg}; 97 newOperands.append(callOp.getOperands().begin(), 98 callOp.getOperands().end()); 99 rewriter.create<fir::CallOp>(loc, *callOp.getCallee(), newResultTypes, 100 newOperands); 101 } else { 102 // Indirect calls. 103 llvm::SmallVector<mlir::Type> newInputTypes = {argType}; 104 for (auto operand : callOp.getOperands().drop_front()) 105 newInputTypes.push_back(operand.getType()); 106 auto funTy = mlir::FunctionType::get(callOp.getContext(), newInputTypes, 107 newResultTypes); 108 109 llvm::SmallVector<mlir::Value> newOperands; 110 newOperands.push_back( 111 rewriter.create<fir::ConvertOp>(loc, funTy, callOp.getOperand(0))); 112 newOperands.push_back(arg); 113 newOperands.append(callOp.getOperands().begin() + 1, 114 callOp.getOperands().end()); 115 rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{}, newResultTypes, 116 newOperands); 117 } 118 callOp->dropAllReferences(); 119 rewriter.eraseOp(callOp); 120 return mlir::success(); 121 } 122 123 private: 124 bool shouldBoxResult; 125 }; 126 127 class SaveResultOpConversion 128 : public mlir::OpRewritePattern<fir::SaveResultOp> { 129 public: 130 using OpRewritePattern::OpRewritePattern; 131 SaveResultOpConversion(mlir::MLIRContext *context) 132 : OpRewritePattern(context) {} 133 mlir::LogicalResult 134 matchAndRewrite(fir::SaveResultOp op, 135 mlir::PatternRewriter &rewriter) const override { 136 rewriter.eraseOp(op); 137 return mlir::success(); 138 } 139 }; 140 141 class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> { 142 public: 143 using OpRewritePattern::OpRewritePattern; 144 ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg) 145 : OpRewritePattern(context), newArg{newArg} {} 146 mlir::LogicalResult 147 matchAndRewrite(mlir::func::ReturnOp ret, 148 mlir::PatternRewriter &rewriter) const override { 149 rewriter.setInsertionPoint(ret); 150 auto returnedValue = ret.getOperand(0); 151 bool replacedStorage = false; 152 if (auto *op = returnedValue.getDefiningOp()) 153 if (auto load = mlir::dyn_cast<fir::LoadOp>(op)) { 154 auto resultStorage = load.getMemref(); 155 load.getMemref().replaceAllUsesWith(newArg); 156 replacedStorage = true; 157 if (auto *alloc = resultStorage.getDefiningOp()) 158 if (alloc->use_empty()) 159 rewriter.eraseOp(alloc); 160 } 161 // The result storage may have been optimized out by a memory to 162 // register pass, this is possible for fir.box results, or fir.record 163 // with no length parameters. Simply store the result in the result storage. 164 // at the return point. 165 if (!replacedStorage) 166 rewriter.create<fir::StoreOp>(ret.getLoc(), returnedValue, newArg); 167 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret); 168 return mlir::success(); 169 } 170 171 private: 172 mlir::Value newArg; 173 }; 174 175 class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> { 176 public: 177 using OpRewritePattern::OpRewritePattern; 178 AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult) 179 : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {} 180 mlir::LogicalResult 181 matchAndRewrite(fir::AddrOfOp addrOf, 182 mlir::PatternRewriter &rewriter) const override { 183 auto oldFuncTy = addrOf.getType().cast<mlir::FunctionType>(); 184 auto newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult); 185 auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy, 186 addrOf.getSymbol()); 187 // Rather than converting all op a function pointer might transit through 188 // (e.g calls, stores, loads, converts...), cast new type to the abstract 189 // type. A conversion will be added when calling indirect calls of abstract 190 // types. 191 rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf); 192 return mlir::success(); 193 } 194 195 private: 196 bool shouldBoxResult; 197 }; 198 199 /// @brief Base CRTP class for AbstractResult pass family. 200 /// Contains common logic for abstract result conversion in a reusable fashion. 201 /// @tparam Pass target class that implements operation-specific logic. 202 /// @tparam PassBase base class template for the pass generated by TableGen. 203 /// The `Pass` class must define runOnSpecificOperation(OpTy, bool, 204 /// mlir::RewritePatternSet&, mlir::ConversionTarget&) member function. 205 /// This function should implement operation-specific functionality. 206 template <typename Pass, template <typename> class PassBase> 207 class AbstractResultOptTemplate : public PassBase<Pass> { 208 public: 209 void runOnOperation() override { 210 auto *context = &this->getContext(); 211 auto op = this->getOperation(); 212 213 mlir::RewritePatternSet patterns(context); 214 mlir::ConversionTarget target = *context; 215 const bool shouldBoxResult = this->passResultAsBox.getValue(); 216 217 auto &self = static_cast<Pass &>(*this); 218 self.runOnSpecificOperation(op, shouldBoxResult, patterns, target); 219 220 // Convert the calls and, if needed, the ReturnOp in the function body. 221 target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect, 222 mlir::func::FuncDialect>(); 223 target.addIllegalOp<fir::SaveResultOp>(); 224 target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) { 225 return !hasAbstractResult(call.getFunctionType()); 226 }); 227 target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) { 228 if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>()) 229 return !hasAbstractResult(funTy); 230 return true; 231 }); 232 target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) { 233 if (dispatch->getNumResults() != 1) 234 return true; 235 auto resultType = dispatch->getResult(0).getType(); 236 if (resultType.isa<fir::SequenceType, fir::BoxType, fir::RecordType>()) { 237 TODO(dispatch.getLoc(), "dispatchOp with abstract results"); 238 return false; 239 } 240 return true; 241 }); 242 243 patterns.insert<CallOpConversion>(context, shouldBoxResult); 244 patterns.insert<SaveResultOpConversion>(context); 245 patterns.insert<AddrOfOpConversion>(context, shouldBoxResult); 246 if (mlir::failed( 247 mlir::applyPartialConversion(op, target, std::move(patterns)))) { 248 mlir::emitError(op.getLoc(), "error in converting abstract results\n"); 249 this->signalPassFailure(); 250 } 251 } 252 }; 253 254 class AbstractResultOnFuncOpt 255 : public AbstractResultOptTemplate<AbstractResultOnFuncOpt, 256 fir::impl::AbstractResultOnFuncOptBase> { 257 public: 258 void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult, 259 mlir::RewritePatternSet &patterns, 260 mlir::ConversionTarget &target) { 261 auto loc = func.getLoc(); 262 auto *context = &getContext(); 263 // Convert function type itself if it has an abstract result. 264 auto funcTy = func.getFunctionType().cast<mlir::FunctionType>(); 265 if (hasAbstractResult(funcTy)) { 266 func.setType(getNewFunctionType(funcTy, shouldBoxResult)); 267 if (!func.empty()) { 268 // Insert new argument. 269 mlir::OpBuilder rewriter(context); 270 auto resultType = funcTy.getResult(0); 271 auto argTy = getResultArgumentType(resultType, shouldBoxResult); 272 mlir::Value newArg = func.front().insertArgument(0u, argTy, loc); 273 if (mustEmboxResult(resultType, shouldBoxResult)) { 274 auto bufferType = fir::ReferenceType::get(resultType); 275 rewriter.setInsertionPointToStart(&func.front()); 276 newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg); 277 } 278 patterns.insert<ReturnOpConversion>(context, newArg); 279 target.addDynamicallyLegalOp<mlir::func::ReturnOp>( 280 [](mlir::func::ReturnOp ret) { return ret.operands().empty(); }); 281 } 282 } 283 } 284 }; 285 286 inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) { 287 return mlir::TypeSwitch<mlir::Type, bool>(type) 288 .Case([](fir::BoxProcType boxProc) { 289 return fir::hasAbstractResult( 290 boxProc.getEleTy().cast<mlir::FunctionType>()); 291 }) 292 .Case([](fir::PointerType pointer) { 293 return fir::hasAbstractResult( 294 pointer.getEleTy().cast<mlir::FunctionType>()); 295 }) 296 .Default([](auto &&) { return false; }); 297 } 298 299 class AbstractResultOnGlobalOpt 300 : public AbstractResultOptTemplate< 301 AbstractResultOnGlobalOpt, fir::impl::AbstractResultOnGlobalOptBase> { 302 public: 303 void runOnSpecificOperation(fir::GlobalOp global, bool, 304 mlir::RewritePatternSet &, 305 mlir::ConversionTarget &) { 306 if (containsFunctionTypeWithAbstractResult(global.getType())) { 307 TODO(global->getLoc(), "support for procedure pointers"); 308 } 309 } 310 }; 311 } // end anonymous namespace 312 } // namespace fir 313 314 std::unique_ptr<mlir::Pass> fir::createAbstractResultOnFuncOptPass() { 315 return std::make_unique<AbstractResultOnFuncOpt>(); 316 } 317 318 std::unique_ptr<mlir::Pass> fir::createAbstractResultOnGlobalOptPass() { 319 return std::make_unique<AbstractResultOnGlobalOpt>(); 320 } 321