1 //===- AbstractResult.cpp - Conversion of Abstract Function Result --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "flang/Optimizer/Builder/Todo.h" 11 #include "flang/Optimizer/Dialect/FIRDialect.h" 12 #include "flang/Optimizer/Dialect/FIROps.h" 13 #include "flang/Optimizer/Dialect/FIRType.h" 14 #include "flang/Optimizer/Transforms/Passes.h" 15 #include "mlir/Dialect/Func/IR/FuncOps.h" 16 #include "mlir/IR/Diagnostics.h" 17 #include "mlir/Pass/Pass.h" 18 #include "mlir/Transforms/DialectConversion.h" 19 #include "mlir/Transforms/Passes.h" 20 #include "llvm/ADT/TypeSwitch.h" 21 22 #define DEBUG_TYPE "flang-abstract-result-opt" 23 24 namespace fir { 25 namespace { 26 27 static mlir::Type getResultArgumentType(mlir::Type resultType, 28 bool shouldBoxResult) { 29 return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType) 30 .Case<fir::SequenceType, fir::RecordType>( 31 [&](mlir::Type type) -> mlir::Type { 32 if (shouldBoxResult) 33 return fir::BoxType::get(type); 34 return fir::ReferenceType::get(type); 35 }) 36 .Case<fir::BoxType>([](mlir::Type type) -> mlir::Type { 37 return fir::ReferenceType::get(type); 38 }) 39 .Default([](mlir::Type) -> mlir::Type { 40 llvm_unreachable("bad abstract result type"); 41 }); 42 } 43 44 static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy, 45 bool shouldBoxResult) { 46 auto resultType = funcTy.getResult(0); 47 auto argTy = getResultArgumentType(resultType, shouldBoxResult); 48 llvm::SmallVector<mlir::Type> newInputTypes = {argTy}; 49 newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end()); 50 return mlir::FunctionType::get(funcTy.getContext(), newInputTypes, 51 /*resultTypes=*/{}); 52 } 53 54 static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) { 55 return resultType.isa<fir::SequenceType, fir::RecordType>() && 56 shouldBoxResult; 57 } 58 59 class CallOpConversion : public mlir::OpRewritePattern<fir::CallOp> { 60 public: 61 using OpRewritePattern::OpRewritePattern; 62 CallOpConversion(mlir::MLIRContext *context, bool shouldBoxResult) 63 : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {} 64 mlir::LogicalResult 65 matchAndRewrite(fir::CallOp callOp, 66 mlir::PatternRewriter &rewriter) const override { 67 auto loc = callOp.getLoc(); 68 auto result = callOp->getResult(0); 69 if (!result.hasOneUse()) { 70 mlir::emitError(loc, 71 "calls with abstract result must have exactly one user"); 72 return mlir::failure(); 73 } 74 auto saveResult = 75 mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser()); 76 if (!saveResult) { 77 mlir::emitError( 78 loc, "calls with abstract result must be used in fir.save_result"); 79 return mlir::failure(); 80 } 81 auto argType = getResultArgumentType(result.getType(), shouldBoxResult); 82 auto buffer = saveResult.getMemref(); 83 mlir::Value arg = buffer; 84 if (mustEmboxResult(result.getType(), shouldBoxResult)) 85 arg = rewriter.create<fir::EmboxOp>( 86 loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{}, 87 saveResult.getTypeparams()); 88 89 llvm::SmallVector<mlir::Type> newResultTypes; 90 if (callOp.getCallee()) { 91 llvm::SmallVector<mlir::Value> newOperands = {arg}; 92 newOperands.append(callOp.getOperands().begin(), 93 callOp.getOperands().end()); 94 rewriter.create<fir::CallOp>(loc, *callOp.getCallee(), newResultTypes, 95 newOperands); 96 } else { 97 // Indirect calls. 98 llvm::SmallVector<mlir::Type> newInputTypes = {argType}; 99 for (auto operand : callOp.getOperands().drop_front()) 100 newInputTypes.push_back(operand.getType()); 101 auto funTy = mlir::FunctionType::get(callOp.getContext(), newInputTypes, 102 newResultTypes); 103 104 llvm::SmallVector<mlir::Value> newOperands; 105 newOperands.push_back( 106 rewriter.create<fir::ConvertOp>(loc, funTy, callOp.getOperand(0))); 107 newOperands.push_back(arg); 108 newOperands.append(callOp.getOperands().begin() + 1, 109 callOp.getOperands().end()); 110 rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{}, newResultTypes, 111 newOperands); 112 } 113 callOp->dropAllReferences(); 114 rewriter.eraseOp(callOp); 115 return mlir::success(); 116 } 117 118 private: 119 bool shouldBoxResult; 120 }; 121 122 class SaveResultOpConversion 123 : public mlir::OpRewritePattern<fir::SaveResultOp> { 124 public: 125 using OpRewritePattern::OpRewritePattern; 126 SaveResultOpConversion(mlir::MLIRContext *context) 127 : OpRewritePattern(context) {} 128 mlir::LogicalResult 129 matchAndRewrite(fir::SaveResultOp op, 130 mlir::PatternRewriter &rewriter) const override { 131 rewriter.eraseOp(op); 132 return mlir::success(); 133 } 134 }; 135 136 class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> { 137 public: 138 using OpRewritePattern::OpRewritePattern; 139 ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg) 140 : OpRewritePattern(context), newArg{newArg} {} 141 mlir::LogicalResult 142 matchAndRewrite(mlir::func::ReturnOp ret, 143 mlir::PatternRewriter &rewriter) const override { 144 rewriter.setInsertionPoint(ret); 145 auto returnedValue = ret.getOperand(0); 146 bool replacedStorage = false; 147 if (auto *op = returnedValue.getDefiningOp()) 148 if (auto load = mlir::dyn_cast<fir::LoadOp>(op)) { 149 auto resultStorage = load.getMemref(); 150 load.getMemref().replaceAllUsesWith(newArg); 151 replacedStorage = true; 152 if (auto *alloc = resultStorage.getDefiningOp()) 153 if (alloc->use_empty()) 154 rewriter.eraseOp(alloc); 155 } 156 // The result storage may have been optimized out by a memory to 157 // register pass, this is possible for fir.box results, or fir.record 158 // with no length parameters. Simply store the result in the result storage. 159 // at the return point. 160 if (!replacedStorage) 161 rewriter.create<fir::StoreOp>(ret.getLoc(), returnedValue, newArg); 162 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret); 163 return mlir::success(); 164 } 165 166 private: 167 mlir::Value newArg; 168 }; 169 170 class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> { 171 public: 172 using OpRewritePattern::OpRewritePattern; 173 AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult) 174 : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {} 175 mlir::LogicalResult 176 matchAndRewrite(fir::AddrOfOp addrOf, 177 mlir::PatternRewriter &rewriter) const override { 178 auto oldFuncTy = addrOf.getType().cast<mlir::FunctionType>(); 179 auto newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult); 180 auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy, 181 addrOf.getSymbol()); 182 // Rather than converting all op a function pointer might transit through 183 // (e.g calls, stores, loads, converts...), cast new type to the abstract 184 // type. A conversion will be added when calling indirect calls of abstract 185 // types. 186 rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf); 187 return mlir::success(); 188 } 189 190 private: 191 bool shouldBoxResult; 192 }; 193 194 /// @brief Base CRTP class for AbstractResult pass family. 195 /// Contains common logic for abstract result conversion in a reusable fashion. 196 /// @tparam Pass target class that implements operation-specific logic. 197 /// @tparam PassBase base class template for the pass generated by TableGen. 198 /// The `Pass` class must define runOnSpecificOperation(OpTy, bool, 199 /// mlir::RewritePatternSet&, mlir::ConversionTarget&) member function. 200 /// This function should implement operation-specific functionality. 201 template <typename Pass, template <typename> class PassBase> 202 class AbstractResultOptTemplate : public PassBase<Pass> { 203 public: 204 void runOnOperation() override { 205 auto *context = &this->getContext(); 206 auto op = this->getOperation(); 207 208 mlir::RewritePatternSet patterns(context); 209 mlir::ConversionTarget target = *context; 210 const bool shouldBoxResult = this->passResultAsBox.getValue(); 211 212 auto &self = static_cast<Pass &>(*this); 213 self.runOnSpecificOperation(op, shouldBoxResult, patterns, target); 214 215 // Convert the calls and, if needed, the ReturnOp in the function body. 216 target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithmeticDialect, 217 mlir::func::FuncDialect>(); 218 target.addIllegalOp<fir::SaveResultOp>(); 219 target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) { 220 return !hasAbstractResult(call.getFunctionType()); 221 }); 222 target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) { 223 if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>()) 224 return !hasAbstractResult(funTy); 225 return true; 226 }); 227 target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) { 228 if (dispatch->getNumResults() != 1) 229 return true; 230 auto resultType = dispatch->getResult(0).getType(); 231 if (resultType.isa<fir::SequenceType, fir::BoxType, fir::RecordType>()) { 232 TODO(dispatch.getLoc(), "dispatchOp with abstract results"); 233 return false; 234 } 235 return true; 236 }); 237 238 patterns.insert<CallOpConversion>(context, shouldBoxResult); 239 patterns.insert<SaveResultOpConversion>(context); 240 patterns.insert<AddrOfOpConversion>(context, shouldBoxResult); 241 if (mlir::failed( 242 mlir::applyPartialConversion(op, target, std::move(patterns)))) { 243 mlir::emitError(op.getLoc(), "error in converting abstract results\n"); 244 this->signalPassFailure(); 245 } 246 } 247 }; 248 249 class AbstractResultOnFuncOpt 250 : public AbstractResultOptTemplate<AbstractResultOnFuncOpt, 251 fir::AbstractResultOnFuncOptBase> { 252 public: 253 void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult, 254 mlir::RewritePatternSet &patterns, 255 mlir::ConversionTarget &target) { 256 auto loc = func.getLoc(); 257 auto *context = &getContext(); 258 // Convert function type itself if it has an abstract result. 259 auto funcTy = func.getFunctionType().cast<mlir::FunctionType>(); 260 if (hasAbstractResult(funcTy)) { 261 func.setType(getNewFunctionType(funcTy, shouldBoxResult)); 262 if (!func.empty()) { 263 // Insert new argument. 264 mlir::OpBuilder rewriter(context); 265 auto resultType = funcTy.getResult(0); 266 auto argTy = getResultArgumentType(resultType, shouldBoxResult); 267 mlir::Value newArg = func.front().insertArgument(0u, argTy, loc); 268 if (mustEmboxResult(resultType, shouldBoxResult)) { 269 auto bufferType = fir::ReferenceType::get(resultType); 270 rewriter.setInsertionPointToStart(&func.front()); 271 newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg); 272 } 273 patterns.insert<ReturnOpConversion>(context, newArg); 274 target.addDynamicallyLegalOp<mlir::func::ReturnOp>( 275 [](mlir::func::ReturnOp ret) { return ret.operands().empty(); }); 276 } 277 } 278 } 279 }; 280 281 inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) { 282 return mlir::TypeSwitch<mlir::Type, bool>(type) 283 .Case([](fir::BoxProcType boxProc) { 284 return fir::hasAbstractResult( 285 boxProc.getEleTy().cast<mlir::FunctionType>()); 286 }) 287 .Case([](fir::PointerType pointer) { 288 return fir::hasAbstractResult( 289 pointer.getEleTy().cast<mlir::FunctionType>()); 290 }) 291 .Default([](auto &&) { return false; }); 292 } 293 294 class AbstractResultOnGlobalOpt 295 : public AbstractResultOptTemplate<AbstractResultOnGlobalOpt, 296 fir::AbstractResultOnGlobalOptBase> { 297 public: 298 void runOnSpecificOperation(fir::GlobalOp global, bool, 299 mlir::RewritePatternSet &, 300 mlir::ConversionTarget &) { 301 if (containsFunctionTypeWithAbstractResult(global.getType())) { 302 TODO(global->getLoc(), "support for procedure pointers"); 303 } 304 } 305 }; 306 } // end anonymous namespace 307 } // namespace fir 308 309 std::unique_ptr<mlir::Pass> fir::createAbstractResultOnFuncOptPass() { 310 return std::make_unique<AbstractResultOnFuncOpt>(); 311 } 312 313 std::unique_ptr<mlir::Pass> fir::createAbstractResultOnGlobalOptPass() { 314 return std::make_unique<AbstractResultOnGlobalOpt>(); 315 } 316