xref: /llvm-project/flang/unittests/Optimizer/Builder/Runtime/RuntimeCallTestBase.h (revision c870632ef6162fbdccaad8cd09420728220ad344)
//===- RuntimeCallTestBase.h -- Base for runtime call generation tests ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef FORTRAN_OPTIMIZER_BUILDER_RUNTIME_RUNTIMECALLTESTBASE_H
10 #define FORTRAN_OPTIMIZER_BUILDER_RUNTIME_RUNTIMECALLTESTBASE_H
11 
12 #include "gtest/gtest.h"
13 #include "flang/Optimizer/Builder/FIRBuilder.h"
14 #include "flang/Optimizer/Dialect/Support/KindMapping.h"
15 #include "flang/Optimizer/Support/InitFIR.h"
16 
17 struct RuntimeCallTest : public testing::Test {
18 public:
19   void SetUp() override {
20     fir::support::loadDialects(context);
21 
22     mlir::OpBuilder builder(&context);
23     auto loc = builder.getUnknownLoc();
24 
25     // Set up a Module with a dummy function operation inside.
26     // Set the insertion point in the function entry block.
27     moduleOp = builder.create<mlir::ModuleOp>(loc);
28     builder.setInsertionPointToStart(moduleOp->getBody());
29     mlir::func::FuncOp func =
30         builder.create<mlir::func::FuncOp>(loc, "runtime_unit_tests_func",
31             builder.getFunctionType(std::nullopt, std::nullopt));
32     auto *entryBlock = func.addEntryBlock();
33     builder.setInsertionPointToStart(entryBlock);
34 
35     kindMap = std::make_unique<fir::KindMapping>(&context);
36     firBuilder = std::make_unique<fir::FirOpBuilder>(builder, *kindMap);
37 
38     i1Ty = firBuilder->getI1Type();
39     i8Ty = firBuilder->getI8Type();
40     i16Ty = firBuilder->getIntegerType(16);
41     i32Ty = firBuilder->getI32Type();
42     i64Ty = firBuilder->getI64Type();
43     i128Ty = firBuilder->getIntegerType(128);
44 
45     f32Ty = firBuilder->getF32Type();
46     f64Ty = firBuilder->getF64Type();
47     f80Ty = firBuilder->getF80Type();
48     f128Ty = firBuilder->getF128Type();
49 
50     c4Ty = mlir::ComplexType::get(f32Ty);
51     c8Ty = mlir::ComplexType::get(f64Ty);
52     c10Ty = mlir::ComplexType::get(f80Ty);
53     c16Ty = mlir::ComplexType::get(f128Ty);
54 
55     seqTy10 = fir::SequenceType::get(fir::SequenceType::Shape(1, 10), i32Ty);
56     boxTy = fir::BoxType::get(mlir::NoneType::get(firBuilder->getContext()));
57 
58     char1Ty = fir::CharacterType::getSingleton(builder.getContext(), 1);
59     char2Ty = fir::CharacterType::getSingleton(builder.getContext(), 2);
60     char4Ty = fir::CharacterType::getSingleton(builder.getContext(), 4);
61 
62     logical1Ty = fir::LogicalType::get(builder.getContext(), 1);
63     logical2Ty = fir::LogicalType::get(builder.getContext(), 2);
64     logical4Ty = fir::LogicalType::get(builder.getContext(), 4);
65     logical8Ty = fir::LogicalType::get(builder.getContext(), 8);
66   }
67 
68   mlir::MLIRContext context;
69   mlir::OwningOpRef<mlir::ModuleOp> moduleOp;
70   std::unique_ptr<fir::KindMapping> kindMap;
71   std::unique_ptr<fir::FirOpBuilder> firBuilder;
72 
73   // Commonly used types
74   mlir::Type i1Ty;
75   mlir::Type i8Ty;
76   mlir::Type i16Ty;
77   mlir::Type i32Ty;
78   mlir::Type i64Ty;
79   mlir::Type i128Ty;
80   mlir::Type f32Ty;
81   mlir::Type f64Ty;
82   mlir::Type f80Ty;
83   mlir::Type f128Ty;
84   mlir::Type c4Ty;
85   mlir::Type c8Ty;
86   mlir::Type c10Ty;
87   mlir::Type c16Ty;
88   mlir::Type seqTy10;
89   mlir::Type boxTy;
90   mlir::Type char1Ty;
91   mlir::Type char2Ty;
92   mlir::Type char4Ty;
93   mlir::Type logical1Ty;
94   mlir::Type logical2Ty;
95   mlir::Type logical4Ty;
96   mlir::Type logical8Ty;
97 };
98 
99 /// Check that the \p op is a `fir::CallOp` operation and its name matches
100 /// \p fctName and the number of arguments is equal to \p nbArgs.
101 /// Most runtime calls have two additional location arguments added. These are
102 /// added in this check when \p addLocArgs is true.
103 static inline void checkCallOp(mlir::Operation *op, llvm::StringRef fctName,
104     unsigned nbArgs, bool addLocArgs = true) {
105   EXPECT_TRUE(mlir::isa<fir::CallOp>(*op));
106   auto callOp = mlir::dyn_cast<fir::CallOp>(*op);
107   EXPECT_TRUE(callOp.getCallee().has_value());
108   mlir::SymbolRefAttr callee = *callOp.getCallee();
109   EXPECT_EQ(fctName, callee.getRootReference().getValue());
110   // sourceFile and sourceLine are added arguments.
111   if (addLocArgs)
112     nbArgs += 2;
113   EXPECT_EQ(nbArgs, callOp.getArgs().size());
114 }
115 
116 /// Check the call operation from the \p result value. In some cases the
117 /// value is directly used in the call and sometimes there is an indirection
118 /// through a `fir.convert` operation. Once the `fir.call` operation is
119 /// retrieved the check is made by `checkCallOp`.
120 ///
121 /// Directly used in `fir.call`.
122 /// ```
123 /// %result = arith.constant 1 : i32
124 /// %0 = fir.call @foo(%result) : (i32) -> i1
125 /// ```
126 ///
127 /// Value used in `fir.call` through `fir.convert` indirection.
128 /// ```
129 /// %result = arith.constant 1 : i32
130 /// %arg = fir.convert %result : (i32) -> i16
131 /// %0 = fir.call @foo(%arg) : (i16) -> i1
132 /// ```
133 static inline void checkCallOpFromResultBox(mlir::Value result,
134     llvm::StringRef fctName, unsigned nbArgs, bool addLocArgs = true) {
135   EXPECT_TRUE(result.hasOneUse());
136   const auto &u = result.user_begin();
137   if (mlir::isa<fir::CallOp>(*u))
138     return checkCallOp(*u, fctName, nbArgs, addLocArgs);
139   auto convOp = mlir::dyn_cast<fir::ConvertOp>(*u);
140   EXPECT_NE(nullptr, convOp);
141   checkCallOpFromResultBox(convOp.getResult(), fctName, nbArgs, addLocArgs);
142 }
143 
144 /// Check the operations in \p block for a `fir::CallOp` operation where the
145 /// function being called shares its function name with \p fctName and the
146 /// number of arguments is equal to \p nbArgs. Note that this check only cares
147 /// if the operation exists, and not the order in when the operation is called.
148 /// This results in exiting the test as soon as the first correct instance of
149 /// `fir::CallOp` is found).
150 static inline void checkBlockForCallOp(
151     mlir::Block *block, llvm::StringRef fctName, unsigned nbArgs) {
152   assert(block && "mlir::Block given is a nullptr");
153   for (auto &op : block->getOperations()) {
154     if (auto callOp = mlir::dyn_cast<fir::CallOp>(op)) {
155       if (fctName == callOp.getCallee()->getRootReference().getValue()) {
156         EXPECT_EQ(nbArgs, callOp.getArgs().size());
157         return;
158       }
159     }
160   }
161   FAIL() << "No calls to " << fctName << " were found!";
162 }
163 
164 #endif // FORTRAN_OPTIMIZER_BUILDER_RUNTIME_RUNTIMECALLTESTBASE_H
165