1 //===- AbstractResult.cpp - Conversion of Abstract Function Result --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "flang/Optimizer/Builder/Todo.h" 11 #include "flang/Optimizer/Dialect/FIRDialect.h" 12 #include "flang/Optimizer/Dialect/FIROps.h" 13 #include "flang/Optimizer/Dialect/FIRType.h" 14 #include "flang/Optimizer/Transforms/Passes.h" 15 #include "mlir/Dialect/Func/IR/FuncOps.h" 16 #include "mlir/IR/Diagnostics.h" 17 #include "mlir/Pass/Pass.h" 18 #include "mlir/Transforms/DialectConversion.h" 19 #include "mlir/Transforms/Passes.h" 20 #include "llvm/ADT/TypeSwitch.h" 21 22 #define DEBUG_TYPE "flang-abstract-result-opt" 23 24 namespace fir { 25 namespace { 26 27 struct AbstractResultOptions { 28 // Always pass result as a fir.box argument. 29 bool boxResult = false; 30 // New function block argument for the result if the current FuncOp had 31 // an abstract result. 32 mlir::Value newArg; 33 }; 34 35 static mlir::Type getResultArgumentType(mlir::Type resultType, 36 const AbstractResultOptions &options) { 37 return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType) 38 .Case<fir::SequenceType, fir::RecordType>( 39 [&](mlir::Type type) -> mlir::Type { 40 if (options.boxResult) 41 return fir::BoxType::get(type); 42 return fir::ReferenceType::get(type); 43 }) 44 .Case<fir::BoxType>([](mlir::Type type) -> mlir::Type { 45 return fir::ReferenceType::get(type); 46 }) 47 .Default([](mlir::Type) -> mlir::Type { 48 llvm_unreachable("bad abstract result type"); 49 }); 50 } 51 52 static mlir::FunctionType 53 getNewFunctionType(mlir::FunctionType funcTy, 54 const AbstractResultOptions &options) { 55 auto resultType = funcTy.getResult(0); 56 auto argTy = getResultArgumentType(resultType, options); 57 llvm::SmallVector<mlir::Type> newInputTypes = {argTy}; 58 newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end()); 59 return mlir::FunctionType::get(funcTy.getContext(), newInputTypes, 60 /*resultTypes=*/{}); 61 } 62 63 static bool mustEmboxResult(mlir::Type resultType, 64 const AbstractResultOptions &options) { 65 return resultType.isa<fir::SequenceType, fir::RecordType>() && 66 options.boxResult; 67 } 68 69 class CallOpConversion : public mlir::OpRewritePattern<fir::CallOp> { 70 public: 71 using OpRewritePattern::OpRewritePattern; 72 CallOpConversion(mlir::MLIRContext *context, const AbstractResultOptions &opt) 73 : OpRewritePattern(context), options{opt} {} 74 mlir::LogicalResult 75 matchAndRewrite(fir::CallOp callOp, 76 mlir::PatternRewriter &rewriter) const override { 77 auto loc = callOp.getLoc(); 78 auto result = callOp->getResult(0); 79 if (!result.hasOneUse()) { 80 mlir::emitError(loc, 81 "calls with abstract result must have exactly one user"); 82 return mlir::failure(); 83 } 84 auto saveResult = 85 mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser()); 86 if (!saveResult) { 87 mlir::emitError( 88 loc, "calls with abstract result must be used in fir.save_result"); 89 return mlir::failure(); 90 } 91 auto argType = getResultArgumentType(result.getType(), options); 92 auto buffer = saveResult.getMemref(); 93 mlir::Value arg = buffer; 94 if (mustEmboxResult(result.getType(), options)) 95 arg = rewriter.create<fir::EmboxOp>( 96 loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{}, 97 saveResult.getTypeparams()); 98 99 llvm::SmallVector<mlir::Type> newResultTypes; 100 if (callOp.getCallee()) { 101 llvm::SmallVector<mlir::Value> newOperands = {arg}; 102 newOperands.append(callOp.getOperands().begin(), 103 callOp.getOperands().end()); 104 rewriter.create<fir::CallOp>(loc, callOp.getCallee().value(), 105 newResultTypes, newOperands); 106 } else { 107 // Indirect calls. 108 llvm::SmallVector<mlir::Type> newInputTypes = {argType}; 109 for (auto operand : callOp.getOperands().drop_front()) 110 newInputTypes.push_back(operand.getType()); 111 auto funTy = mlir::FunctionType::get(callOp.getContext(), newInputTypes, 112 newResultTypes); 113 114 llvm::SmallVector<mlir::Value> newOperands; 115 newOperands.push_back( 116 rewriter.create<fir::ConvertOp>(loc, funTy, callOp.getOperand(0))); 117 newOperands.push_back(arg); 118 newOperands.append(callOp.getOperands().begin() + 1, 119 callOp.getOperands().end()); 120 rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{}, newResultTypes, 121 newOperands); 122 } 123 callOp->dropAllReferences(); 124 rewriter.eraseOp(callOp); 125 return mlir::success(); 126 } 127 128 private: 129 const AbstractResultOptions &options; 130 }; 131 132 class SaveResultOpConversion 133 : public mlir::OpRewritePattern<fir::SaveResultOp> { 134 public: 135 using OpRewritePattern::OpRewritePattern; 136 SaveResultOpConversion(mlir::MLIRContext *context) 137 : OpRewritePattern(context) {} 138 mlir::LogicalResult 139 matchAndRewrite(fir::SaveResultOp op, 140 mlir::PatternRewriter &rewriter) const override { 141 rewriter.eraseOp(op); 142 return mlir::success(); 143 } 144 }; 145 146 class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> { 147 public: 148 using OpRewritePattern::OpRewritePattern; 149 ReturnOpConversion(mlir::MLIRContext *context, 150 const AbstractResultOptions &opt) 151 : OpRewritePattern(context), options{opt} {} 152 mlir::LogicalResult 153 matchAndRewrite(mlir::func::ReturnOp ret, 154 mlir::PatternRewriter &rewriter) const override { 155 rewriter.setInsertionPoint(ret); 156 auto returnedValue = ret.getOperand(0); 157 bool replacedStorage = false; 158 if (auto *op = returnedValue.getDefiningOp()) 159 if (auto load = mlir::dyn_cast<fir::LoadOp>(op)) { 160 auto resultStorage = load.getMemref(); 161 load.getMemref().replaceAllUsesWith(options.newArg); 162 replacedStorage = true; 163 if (auto *alloc = resultStorage.getDefiningOp()) 164 if (alloc->use_empty()) 165 rewriter.eraseOp(alloc); 166 } 167 // The result storage may have been optimized out by a memory to 168 // register pass, this is possible for fir.box results, or fir.record 169 // with no length parameters. Simply store the result in the result storage. 170 // at the return point. 171 if (!replacedStorage) 172 rewriter.create<fir::StoreOp>(ret.getLoc(), returnedValue, 173 options.newArg); 174 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret); 175 return mlir::success(); 176 } 177 178 private: 179 const AbstractResultOptions &options; 180 }; 181 182 class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> { 183 public: 184 using OpRewritePattern::OpRewritePattern; 185 AddrOfOpConversion(mlir::MLIRContext *context, 186 const AbstractResultOptions &opt) 187 : OpRewritePattern(context), options{opt} {} 188 mlir::LogicalResult 189 matchAndRewrite(fir::AddrOfOp addrOf, 190 mlir::PatternRewriter &rewriter) const override { 191 auto oldFuncTy = addrOf.getType().cast<mlir::FunctionType>(); 192 auto newFuncTy = getNewFunctionType(oldFuncTy, options); 193 auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy, 194 addrOf.getSymbol()); 195 // Rather than converting all op a function pointer might transit through 196 // (e.g calls, stores, loads, converts...), cast new type to the abstract 197 // type. A conversion will be added when calling indirect calls of abstract 198 // types. 199 rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf); 200 return mlir::success(); 201 } 202 203 private: 204 const AbstractResultOptions &options; 205 }; 206 207 class AbstractResultOpt : public fir::AbstractResultOptBase<AbstractResultOpt> { 208 public: 209 void runOnOperation() override { 210 auto *context = &getContext(); 211 auto func = getOperation(); 212 auto loc = func.getLoc(); 213 mlir::RewritePatternSet patterns(context); 214 mlir::ConversionTarget target = *context; 215 AbstractResultOptions options{passResultAsBox.getValue(), 216 /*newArg=*/{}}; 217 218 // Convert function type itself if it has an abstract result 219 auto funcTy = func.getFunctionType().cast<mlir::FunctionType>(); 220 if (hasAbstractResult(funcTy)) { 221 func.setType(getNewFunctionType(funcTy, options)); 222 unsigned zero = 0; 223 if (!func.empty()) { 224 // Insert new argument 225 mlir::OpBuilder rewriter(context); 226 auto resultType = funcTy.getResult(0); 227 auto argTy = getResultArgumentType(resultType, options); 228 options.newArg = func.front().insertArgument(zero, argTy, loc); 229 if (mustEmboxResult(resultType, options)) { 230 auto bufferType = fir::ReferenceType::get(resultType); 231 rewriter.setInsertionPointToStart(&func.front()); 232 options.newArg = 233 rewriter.create<fir::BoxAddrOp>(loc, bufferType, options.newArg); 234 } 235 patterns.insert<ReturnOpConversion>(context, options); 236 target.addDynamicallyLegalOp<mlir::func::ReturnOp>( 237 [](mlir::func::ReturnOp ret) { return ret.operands().empty(); }); 238 } 239 } 240 241 if (func.empty()) 242 return; 243 244 // Convert the calls and, if needed, the ReturnOp in the function body. 245 target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithmeticDialect, 246 mlir::func::FuncDialect>(); 247 target.addIllegalOp<fir::SaveResultOp>(); 248 target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) { 249 return !hasAbstractResult(call.getFunctionType()); 250 }); 251 target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) { 252 if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>()) 253 return !hasAbstractResult(funTy); 254 return true; 255 }); 256 target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) { 257 if (dispatch->getNumResults() != 1) 258 return true; 259 auto resultType = dispatch->getResult(0).getType(); 260 if (resultType.isa<fir::SequenceType, fir::BoxType, fir::RecordType>()) { 261 TODO(dispatch.getLoc(), "dispatchOp with abstract results"); 262 return false; 263 } 264 return true; 265 }); 266 267 patterns.insert<CallOpConversion>(context, options); 268 patterns.insert<SaveResultOpConversion>(context); 269 patterns.insert<AddrOfOpConversion>(context, options); 270 if (mlir::failed( 271 mlir::applyPartialConversion(func, target, std::move(patterns)))) { 272 mlir::emitError(func.getLoc(), "error in converting abstract results\n"); 273 signalPassFailure(); 274 } 275 } 276 }; 277 } // end anonymous namespace 278 } // namespace fir 279 280 std::unique_ptr<mlir::Pass> fir::createAbstractResultOptPass() { 281 return std::make_unique<AbstractResultOpt>(); 282 } 283