1 //===- RuntimeCallTestBase.cpp -- Base for runtime call generation tests --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef FORTRAN_OPTIMIZER_BUILDER_RUNTIME_RUNTIMECALLTESTBASE_H 10 #define FORTRAN_OPTIMIZER_BUILDER_RUNTIME_RUNTIMECALLTESTBASE_H 11 12 #include "gtest/gtest.h" 13 #include "flang/Optimizer/Builder/FIRBuilder.h" 14 #include "flang/Optimizer/Dialect/Support/KindMapping.h" 15 #include "flang/Optimizer/Support/InitFIR.h" 16 17 struct RuntimeCallTest : public testing::Test { 18 public: 19 void SetUp() override { 20 fir::support::loadDialects(context); 21 22 mlir::OpBuilder builder(&context); 23 auto loc = builder.getUnknownLoc(); 24 25 // Set up a Module with a dummy function operation inside. 26 // Set the insertion point in the function entry block. 27 moduleOp = builder.create<mlir::ModuleOp>(loc); 28 builder.setInsertionPointToStart(moduleOp->getBody()); 29 mlir::func::FuncOp func = 30 builder.create<mlir::func::FuncOp>(loc, "runtime_unit_tests_func", 31 builder.getFunctionType(std::nullopt, std::nullopt)); 32 auto *entryBlock = func.addEntryBlock(); 33 builder.setInsertionPointToStart(entryBlock); 34 35 kindMap = std::make_unique<fir::KindMapping>(&context); 36 firBuilder = std::make_unique<fir::FirOpBuilder>(builder, *kindMap); 37 38 i1Ty = firBuilder->getI1Type(); 39 i8Ty = firBuilder->getI8Type(); 40 i16Ty = firBuilder->getIntegerType(16); 41 i32Ty = firBuilder->getI32Type(); 42 i64Ty = firBuilder->getI64Type(); 43 i128Ty = firBuilder->getIntegerType(128); 44 45 f32Ty = firBuilder->getF32Type(); 46 f64Ty = firBuilder->getF64Type(); 47 f80Ty = firBuilder->getF80Type(); 48 f128Ty = firBuilder->getF128Type(); 49 50 c4Ty = mlir::ComplexType::get(f32Ty); 51 c8Ty = mlir::ComplexType::get(f64Ty); 52 c10Ty = mlir::ComplexType::get(f80Ty); 53 c16Ty = mlir::ComplexType::get(f128Ty); 54 55 seqTy10 = fir::SequenceType::get(fir::SequenceType::Shape(1, 10), i32Ty); 56 boxTy = fir::BoxType::get(mlir::NoneType::get(firBuilder->getContext())); 57 58 char1Ty = fir::CharacterType::getSingleton(builder.getContext(), 1); 59 char2Ty = fir::CharacterType::getSingleton(builder.getContext(), 2); 60 char4Ty = fir::CharacterType::getSingleton(builder.getContext(), 4); 61 62 logical1Ty = fir::LogicalType::get(builder.getContext(), 1); 63 logical2Ty = fir::LogicalType::get(builder.getContext(), 2); 64 logical4Ty = fir::LogicalType::get(builder.getContext(), 4); 65 logical8Ty = fir::LogicalType::get(builder.getContext(), 8); 66 } 67 68 mlir::MLIRContext context; 69 mlir::OwningOpRef<mlir::ModuleOp> moduleOp; 70 std::unique_ptr<fir::KindMapping> kindMap; 71 std::unique_ptr<fir::FirOpBuilder> firBuilder; 72 73 // Commonly used types 74 mlir::Type i1Ty; 75 mlir::Type i8Ty; 76 mlir::Type i16Ty; 77 mlir::Type i32Ty; 78 mlir::Type i64Ty; 79 mlir::Type i128Ty; 80 mlir::Type f32Ty; 81 mlir::Type f64Ty; 82 mlir::Type f80Ty; 83 mlir::Type f128Ty; 84 mlir::Type c4Ty; 85 mlir::Type c8Ty; 86 mlir::Type c10Ty; 87 mlir::Type c16Ty; 88 mlir::Type seqTy10; 89 mlir::Type boxTy; 90 mlir::Type char1Ty; 91 mlir::Type char2Ty; 92 mlir::Type char4Ty; 93 mlir::Type logical1Ty; 94 mlir::Type logical2Ty; 95 mlir::Type logical4Ty; 96 mlir::Type logical8Ty; 97 }; 98 99 /// Check that the \p op is a `fir::CallOp` operation and its name matches 100 /// \p fctName and the number of arguments is equal to \p nbArgs. 101 /// Most runtime calls have two additional location arguments added. These are 102 /// added in this check when \p addLocArgs is true. 103 static inline void checkCallOp(mlir::Operation *op, llvm::StringRef fctName, 104 unsigned nbArgs, bool addLocArgs = true) { 105 EXPECT_TRUE(mlir::isa<fir::CallOp>(*op)); 106 auto callOp = mlir::dyn_cast<fir::CallOp>(*op); 107 EXPECT_TRUE(callOp.getCallee().has_value()); 108 mlir::SymbolRefAttr callee = *callOp.getCallee(); 109 EXPECT_EQ(fctName, callee.getRootReference().getValue()); 110 // sourceFile and sourceLine are added arguments. 111 if (addLocArgs) 112 nbArgs += 2; 113 EXPECT_EQ(nbArgs, callOp.getArgs().size()); 114 } 115 116 /// Check the call operation from the \p result value. In some cases the 117 /// value is directly used in the call and sometimes there is an indirection 118 /// through a `fir.convert` operation. Once the `fir.call` operation is 119 /// retrieved the check is made by `checkCallOp`. 120 /// 121 /// Directly used in `fir.call`. 122 /// ``` 123 /// %result = arith.constant 1 : i32 124 /// %0 = fir.call @foo(%result) : (i32) -> i1 125 /// ``` 126 /// 127 /// Value used in `fir.call` through `fir.convert` indirection. 128 /// ``` 129 /// %result = arith.constant 1 : i32 130 /// %arg = fir.convert %result : (i32) -> i16 131 /// %0 = fir.call @foo(%arg) : (i16) -> i1 132 /// ``` 133 static inline void checkCallOpFromResultBox(mlir::Value result, 134 llvm::StringRef fctName, unsigned nbArgs, bool addLocArgs = true) { 135 EXPECT_TRUE(result.hasOneUse()); 136 const auto &u = result.user_begin(); 137 if (mlir::isa<fir::CallOp>(*u)) 138 return checkCallOp(*u, fctName, nbArgs, addLocArgs); 139 auto convOp = mlir::dyn_cast<fir::ConvertOp>(*u); 140 EXPECT_NE(nullptr, convOp); 141 checkCallOpFromResultBox(convOp.getResult(), fctName, nbArgs, addLocArgs); 142 } 143 144 /// Check the operations in \p block for a `fir::CallOp` operation where the 145 /// function being called shares its function name with \p fctName and the 146 /// number of arguments is equal to \p nbArgs. Note that this check only cares 147 /// if the operation exists, and not the order in when the operation is called. 148 /// This results in exiting the test as soon as the first correct instance of 149 /// `fir::CallOp` is found). 150 static inline void checkBlockForCallOp( 151 mlir::Block *block, llvm::StringRef fctName, unsigned nbArgs) { 152 assert(block && "mlir::Block given is a nullptr"); 153 for (auto &op : block->getOperations()) { 154 if (auto callOp = mlir::dyn_cast<fir::CallOp>(op)) { 155 if (fctName == callOp.getCallee()->getRootReference().getValue()) { 156 EXPECT_EQ(nbArgs, callOp.getArgs().size()); 157 return; 158 } 159 } 160 } 161 FAIL() << "No calls to " << fctName << " were found!"; 162 } 163 164 #endif // FORTRAN_OPTIMIZER_BUILDER_RUNTIME_RUNTIMECALLTESTBASE_H 165