1 //===- AbstractResult.cpp - Conversion of Abstract Function Result --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "flang/Optimizer/Dialect/FIRDialect.h" 11 #include "flang/Optimizer/Dialect/FIROps.h" 12 #include "flang/Optimizer/Dialect/FIRType.h" 13 #include "flang/Optimizer/Transforms/Passes.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/IR/Diagnostics.h" 16 #include "mlir/Pass/Pass.h" 17 #include "mlir/Transforms/DialectConversion.h" 18 #include "mlir/Transforms/Passes.h" 19 #include "llvm/ADT/TypeSwitch.h" 20 21 #define DEBUG_TYPE "flang-abstract-result-opt" 22 23 namespace fir { 24 namespace { 25 26 struct AbstractResultOptions { 27 // Always pass result as a fir.box argument. 28 bool boxResult = false; 29 // New function block argument for the result if the current FuncOp had 30 // an abstract result. 31 mlir::Value newArg; 32 }; 33 34 static bool mustConvertCallOrFunc(mlir::FunctionType type) { 35 if (type.getNumResults() == 0) 36 return false; 37 auto resultType = type.getResult(0); 38 return resultType.isa<fir::SequenceType, fir::BoxType, fir::RecordType>(); 39 } 40 41 static mlir::Type getResultArgumentType(mlir::Type resultType, 42 const AbstractResultOptions &options) { 43 return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType) 44 .Case<fir::SequenceType, fir::RecordType>( 45 [&](mlir::Type type) -> mlir::Type { 46 if (options.boxResult) 47 return fir::BoxType::get(type); 48 return fir::ReferenceType::get(type); 49 }) 50 .Case<fir::BoxType>([](mlir::Type type) -> mlir::Type { 51 return fir::ReferenceType::get(type); 52 }) 53 .Default([](mlir::Type) -> mlir::Type { 54 llvm_unreachable("bad abstract result type"); 55 }); 56 } 57 58 static mlir::FunctionType 59 getNewFunctionType(mlir::FunctionType funcTy, 60 const AbstractResultOptions &options) { 61 auto resultType = funcTy.getResult(0); 62 auto argTy = getResultArgumentType(resultType, options); 63 llvm::SmallVector<mlir::Type> newInputTypes = {argTy}; 64 newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end()); 65 return mlir::FunctionType::get(funcTy.getContext(), newInputTypes, 66 /*resultTypes=*/{}); 67 } 68 69 static bool mustEmboxResult(mlir::Type resultType, 70 const AbstractResultOptions &options) { 71 return resultType.isa<fir::SequenceType, fir::RecordType>() && 72 options.boxResult; 73 } 74 75 class CallOpConversion : public mlir::OpRewritePattern<fir::CallOp> { 76 public: 77 using OpRewritePattern::OpRewritePattern; 78 CallOpConversion(mlir::MLIRContext *context, const AbstractResultOptions &opt) 79 : OpRewritePattern(context), options{opt} {} 80 mlir::LogicalResult 81 matchAndRewrite(fir::CallOp callOp, 82 mlir::PatternRewriter &rewriter) const override { 83 auto loc = callOp.getLoc(); 84 auto result = callOp->getResult(0); 85 if (!result.hasOneUse()) { 86 mlir::emitError(loc, 87 "calls with abstract result must have exactly one user"); 88 return mlir::failure(); 89 } 90 auto saveResult = 91 mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser()); 92 if (!saveResult) { 93 mlir::emitError( 94 loc, "calls with abstract result must be used in fir.save_result"); 95 return mlir::failure(); 96 } 97 auto argType = getResultArgumentType(result.getType(), options); 98 auto buffer = saveResult.getMemref(); 99 mlir::Value arg = buffer; 100 if (mustEmboxResult(result.getType(), options)) 101 arg = rewriter.create<fir::EmboxOp>( 102 loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{}, 103 saveResult.getTypeparams()); 104 105 llvm::SmallVector<mlir::Type> newResultTypes; 106 if (callOp.getCallee()) { 107 llvm::SmallVector<mlir::Value> newOperands = {arg}; 108 newOperands.append(callOp.getOperands().begin(), 109 callOp.getOperands().end()); 110 rewriter.create<fir::CallOp>(loc, callOp.getCallee().getValue(), 111 newResultTypes, newOperands); 112 } else { 113 // Indirect calls. 114 llvm::SmallVector<mlir::Type> newInputTypes = {argType}; 115 for (auto operand : callOp.getOperands().drop_front()) 116 newInputTypes.push_back(operand.getType()); 117 auto funTy = mlir::FunctionType::get(callOp.getContext(), newInputTypes, 118 newResultTypes); 119 120 llvm::SmallVector<mlir::Value> newOperands; 121 newOperands.push_back( 122 rewriter.create<fir::ConvertOp>(loc, funTy, callOp.getOperand(0))); 123 newOperands.push_back(arg); 124 newOperands.append(callOp.getOperands().begin() + 1, 125 callOp.getOperands().end()); 126 rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{}, newResultTypes, 127 newOperands); 128 } 129 callOp->dropAllReferences(); 130 rewriter.eraseOp(callOp); 131 return mlir::success(); 132 } 133 134 private: 135 const AbstractResultOptions &options; 136 }; 137 138 class SaveResultOpConversion 139 : public mlir::OpRewritePattern<fir::SaveResultOp> { 140 public: 141 using OpRewritePattern::OpRewritePattern; 142 SaveResultOpConversion(mlir::MLIRContext *context) 143 : OpRewritePattern(context) {} 144 mlir::LogicalResult 145 matchAndRewrite(fir::SaveResultOp op, 146 mlir::PatternRewriter &rewriter) const override { 147 rewriter.eraseOp(op); 148 return mlir::success(); 149 } 150 }; 151 152 class ReturnOpConversion : public mlir::OpRewritePattern<mlir::ReturnOp> { 153 public: 154 using OpRewritePattern::OpRewritePattern; 155 ReturnOpConversion(mlir::MLIRContext *context, 156 const AbstractResultOptions &opt) 157 : OpRewritePattern(context), options{opt} {} 158 mlir::LogicalResult 159 matchAndRewrite(mlir::ReturnOp ret, 160 mlir::PatternRewriter &rewriter) const override { 161 rewriter.setInsertionPoint(ret); 162 auto returnedValue = ret.getOperand(0); 163 bool replacedStorage = false; 164 if (auto *op = returnedValue.getDefiningOp()) 165 if (auto load = mlir::dyn_cast<fir::LoadOp>(op)) { 166 auto resultStorage = load.getMemref(); 167 load.getMemref().replaceAllUsesWith(options.newArg); 168 replacedStorage = true; 169 if (auto *alloc = resultStorage.getDefiningOp()) 170 if (alloc->use_empty()) 171 rewriter.eraseOp(alloc); 172 } 173 // The result storage may have been optimized out by a memory to 174 // register pass, this is possible for fir.box results, or fir.record 175 // with no length parameters. Simply store the result in the result storage. 176 // at the return point. 177 if (!replacedStorage) 178 rewriter.create<fir::StoreOp>(ret.getLoc(), returnedValue, 179 options.newArg); 180 rewriter.replaceOpWithNewOp<mlir::ReturnOp>(ret); 181 return mlir::success(); 182 } 183 184 private: 185 const AbstractResultOptions &options; 186 }; 187 188 class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> { 189 public: 190 using OpRewritePattern::OpRewritePattern; 191 AddrOfOpConversion(mlir::MLIRContext *context, 192 const AbstractResultOptions &opt) 193 : OpRewritePattern(context), options{opt} {} 194 mlir::LogicalResult 195 matchAndRewrite(fir::AddrOfOp addrOf, 196 mlir::PatternRewriter &rewriter) const override { 197 auto oldFuncTy = addrOf.getType().cast<mlir::FunctionType>(); 198 auto newFuncTy = getNewFunctionType(oldFuncTy, options); 199 auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy, 200 addrOf.getSymbol()); 201 // Rather than converting all op a function pointer might transit through 202 // (e.g calls, stores, loads, converts...), cast new type to the abstract 203 // type. A conversion will be added when calling indirect calls of abstract 204 // types. 205 rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf); 206 return mlir::success(); 207 } 208 209 private: 210 const AbstractResultOptions &options; 211 }; 212 213 class AbstractResultOpt : public fir::AbstractResultOptBase<AbstractResultOpt> { 214 public: 215 void runOnOperation() override { 216 auto *context = &getContext(); 217 auto func = getOperation(); 218 auto loc = func.getLoc(); 219 mlir::RewritePatternSet patterns(context); 220 mlir::ConversionTarget target = *context; 221 AbstractResultOptions options{passResultAsBox.getValue(), 222 /*newArg=*/{}}; 223 224 // Convert function type itself if it has an abstract result 225 auto funcTy = func.getType().cast<mlir::FunctionType>(); 226 if (mustConvertCallOrFunc(funcTy)) { 227 func.setType(getNewFunctionType(funcTy, options)); 228 unsigned zero = 0; 229 if (!func.empty()) { 230 // Insert new argument 231 mlir::OpBuilder rewriter(context); 232 auto resultType = funcTy.getResult(0); 233 auto argTy = getResultArgumentType(resultType, options); 234 options.newArg = func.front().insertArgument(zero, argTy, loc); 235 if (mustEmboxResult(resultType, options)) { 236 auto bufferType = fir::ReferenceType::get(resultType); 237 rewriter.setInsertionPointToStart(&func.front()); 238 options.newArg = 239 rewriter.create<fir::BoxAddrOp>(loc, bufferType, options.newArg); 240 } 241 patterns.insert<ReturnOpConversion>(context, options); 242 target.addDynamicallyLegalOp<mlir::ReturnOp>( 243 [](mlir::ReturnOp ret) { return ret.operands().empty(); }); 244 } 245 } 246 247 if (func.empty()) 248 return; 249 250 // Convert the calls and, if needed, the ReturnOp in the function body. 251 target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithmeticDialect, 252 mlir::StandardOpsDialect>(); 253 target.addIllegalOp<fir::SaveResultOp>(); 254 target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) { 255 return !mustConvertCallOrFunc(call.getFunctionType()); 256 }); 257 target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) { 258 if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>()) 259 return !mustConvertCallOrFunc(funTy); 260 return true; 261 }); 262 target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) { 263 if (dispatch->getNumResults() != 1) 264 return true; 265 auto resultType = dispatch->getResult(0).getType(); 266 if (resultType.isa<fir::SequenceType, fir::BoxType, fir::RecordType>()) { 267 mlir::emitError(dispatch.getLoc(), 268 "TODO: dispatchOp with abstract results"); 269 return false; 270 } 271 return true; 272 }); 273 274 patterns.insert<CallOpConversion>(context, options); 275 patterns.insert<SaveResultOpConversion>(context); 276 patterns.insert<AddrOfOpConversion>(context, options); 277 if (mlir::failed( 278 mlir::applyPartialConversion(func, target, std::move(patterns)))) { 279 mlir::emitError(func.getLoc(), "error in converting abstract results\n"); 280 signalPassFailure(); 281 } 282 } 283 }; 284 } // end anonymous namespace 285 } // namespace fir 286 287 std::unique_ptr<mlir::Pass> fir::createAbstractResultOptPass() { 288 return std::make_unique<AbstractResultOpt>(); 289 } 290