xref: /llvm-project/flang/include/flang/Optimizer/Support/Utils.h (revision fc97d2e68b03bc2979395e84b645e5b3ba35aecd)
1 //===-- Optimizer/Support/Utils.h -------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef FORTRAN_OPTIMIZER_SUPPORT_UTILS_H
14 #define FORTRAN_OPTIMIZER_SUPPORT_UTILS_H
15 
#include "flang/Common/default-kinds.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Support/FatalError.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringRef.h"
#include <cstdint>
#include <optional>
#include <string>
#include <utility>
#include <vector>

29 
30 namespace fir {
31 /// Return the integer value of a arith::ConstantOp.
32 inline std::int64_t toInt(mlir::arith::ConstantOp cop) {
33   return mlir::cast<mlir::IntegerAttr>(cop.getValue())
34       .getValue()
35       .getSExtValue();
36 }
37 
38 // Reconstruct binding tables for dynamic dispatch.
39 using BindingTable = llvm::DenseMap<llvm::StringRef, unsigned>;
40 using BindingTables = llvm::DenseMap<llvm::StringRef, BindingTable>;
41 
42 inline void buildBindingTables(BindingTables &bindingTables,
43                                mlir::ModuleOp mod) {
44 
45   // The binding tables are defined in FIR after lowering inside fir.type_info
46   // operations. Go through each binding tables and store the procedure name and
47   // binding index for later use by the fir.dispatch conversion pattern.
48   for (auto typeInfo : mod.getOps<fir::TypeInfoOp>()) {
49     unsigned bindingIdx = 0;
50     BindingTable bindings;
51     if (typeInfo.getDispatchTable().empty()) {
52       bindingTables[typeInfo.getSymName()] = bindings;
53       continue;
54     }
55     for (auto dtEntry :
56          typeInfo.getDispatchTable().front().getOps<fir::DTEntryOp>()) {
57       bindings[dtEntry.getMethod()] = bindingIdx;
58       ++bindingIdx;
59     }
60     bindingTables[typeInfo.getSymName()] = bindings;
61   }
62 }
63 
64 // Translate front-end KINDs for use in the IR and code gen.
65 inline std::vector<fir::KindTy>
66 fromDefaultKinds(const Fortran::common::IntrinsicTypeDefaultKinds &defKinds) {
67   return {static_cast<fir::KindTy>(defKinds.GetDefaultKind(
68               Fortran::common::TypeCategory::Character)),
69           static_cast<fir::KindTy>(
70               defKinds.GetDefaultKind(Fortran::common::TypeCategory::Complex)),
71           static_cast<fir::KindTy>(defKinds.doublePrecisionKind()),
72           static_cast<fir::KindTy>(
73               defKinds.GetDefaultKind(Fortran::common::TypeCategory::Integer)),
74           static_cast<fir::KindTy>(
75               defKinds.GetDefaultKind(Fortran::common::TypeCategory::Logical)),
76           static_cast<fir::KindTy>(
77               defKinds.GetDefaultKind(Fortran::common::TypeCategory::Real))};
78 }
79 
80 inline std::string mlirTypeToString(mlir::Type type) {
81   std::string result{};
82   llvm::raw_string_ostream sstream(result);
83   sstream << type;
84   return result;
85 }
86 
87 inline std::optional<int> mlirFloatTypeToKind(mlir::Type type) {
88   if (type.isF16())
89     return 2;
90   else if (type.isBF16())
91     return 3;
92   else if (type.isF32())
93     return 4;
94   else if (type.isF64())
95     return 8;
96   else if (type.isF80())
97     return 10;
98   else if (type.isF128())
99     return 16;
100   return std::nullopt;
101 }
102 
103 inline std::string mlirTypeToIntrinsicFortran(fir::FirOpBuilder &builder,
104                                               mlir::Type type,
105                                               mlir::Location loc,
106                                               const llvm::Twine &name) {
107   if (auto floatTy = mlir::dyn_cast<mlir::FloatType>(type)) {
108     if (std::optional<int> kind = mlirFloatTypeToKind(type))
109       return "REAL(KIND="s + std::to_string(*kind) + ")";
110   } else if (auto cplxTy = mlir::dyn_cast<mlir::ComplexType>(type)) {
111     if (std::optional<int> kind = mlirFloatTypeToKind(cplxTy.getElementType()))
112       return "COMPLEX(KIND+"s + std::to_string(*kind) + ")";
113   } else if (type.isUnsignedInteger()) {
114     if (type.isInteger(8))
115       return "UNSIGNED(KIND=1)";
116     else if (type.isInteger(16))
117       return "UNSIGNED(KIND=2)";
118     else if (type.isInteger(32))
119       return "UNSIGNED(KIND=4)";
120     else if (type.isInteger(64))
121       return "UNSIGNED(KIND=8)";
122     else if (type.isInteger(128))
123       return "UNSIGNED(KIND=16)";
124   } else if (type.isInteger(8))
125     return "INTEGER(KIND=1)";
126   else if (type.isInteger(16))
127     return "INTEGER(KIND=2)";
128   else if (type.isInteger(32))
129     return "INTEGER(KIND=4)";
130   else if (type.isInteger(64))
131     return "INTEGER(KIND=8)";
132   else if (type.isInteger(128))
133     return "INTEGER(KIND=16)";
134   else if (type == fir::LogicalType::get(builder.getContext(), 1))
135     return "LOGICAL(KIND=1)";
136   else if (type == fir::LogicalType::get(builder.getContext(), 2))
137     return "LOGICAL(KIND=2)";
138   else if (type == fir::LogicalType::get(builder.getContext(), 4))
139     return "LOGICAL(KIND=4)";
140   else if (type == fir::LogicalType::get(builder.getContext(), 8))
141     return "LOGICAL(KIND=8)";
142 
143   fir::emitFatalError(loc, "unsupported type in " + name + ": " +
144                                fir::mlirTypeToString(type));
145 }
146 
147 inline void intrinsicTypeTODO(fir::FirOpBuilder &builder, mlir::Type type,
148                               mlir::Location loc,
149                               const llvm::Twine &intrinsicName) {
150   TODO(loc,
151        "intrinsic: " +
152            fir::mlirTypeToIntrinsicFortran(builder, type, loc, intrinsicName) +
153            " in " + intrinsicName);
154 }
155 
156 inline void intrinsicTypeTODO2(fir::FirOpBuilder &builder, mlir::Type type1,
157                                mlir::Type type2, mlir::Location loc,
158                                const llvm::Twine &intrinsicName) {
159   TODO(loc,
160        "intrinsic: {" +
161            fir::mlirTypeToIntrinsicFortran(builder, type2, loc, intrinsicName) +
162            ", " +
163            fir::mlirTypeToIntrinsicFortran(builder, type2, loc, intrinsicName) +
164            "} in " + intrinsicName);
165 }
166 
167 inline std::pair<Fortran::common::TypeCategory, KindMapping::KindTy>
168 mlirTypeToCategoryKind(mlir::Location loc, mlir::Type type) {
169   if (auto floatTy = mlir::dyn_cast<mlir::FloatType>(type)) {
170     if (std::optional<int> kind = mlirFloatTypeToKind(type))
171       return {Fortran::common::TypeCategory::Real, *kind};
172   } else if (auto cplxTy = mlir::dyn_cast<mlir::ComplexType>(type)) {
173     if (std::optional<int> kind = mlirFloatTypeToKind(cplxTy.getElementType()))
174       return {Fortran::common::TypeCategory::Complex, *kind};
175   } else if (type.isInteger(8))
176     return {type.isUnsignedInteger() ? Fortran::common::TypeCategory::Unsigned
177                                      : Fortran::common::TypeCategory::Integer,
178             1};
179   else if (type.isInteger(16))
180     return {type.isUnsignedInteger() ? Fortran::common::TypeCategory::Unsigned
181                                      : Fortran::common::TypeCategory::Integer,
182             2};
183   else if (type.isInteger(32))
184     return {type.isUnsignedInteger() ? Fortran::common::TypeCategory::Unsigned
185                                      : Fortran::common::TypeCategory::Integer,
186             4};
187   else if (type.isInteger(64))
188     return {type.isUnsignedInteger() ? Fortran::common::TypeCategory::Unsigned
189                                      : Fortran::common::TypeCategory::Integer,
190             8};
191   else if (type.isInteger(128))
192     return {type.isUnsignedInteger() ? Fortran::common::TypeCategory::Unsigned
193                                      : Fortran::common::TypeCategory::Integer,
194             16};
195   else if (auto logicalType = mlir::dyn_cast<fir::LogicalType>(type))
196     return {Fortran::common::TypeCategory::Logical, logicalType.getFKind()};
197   else if (auto charType = mlir::dyn_cast<fir::CharacterType>(type))
198     return {Fortran::common::TypeCategory::Character, charType.getFKind()};
199   else if (mlir::isa<fir::RecordType>(type))
200     return {Fortran::common::TypeCategory::Derived, 0};
201   fir::emitFatalError(loc, "unsupported type: " + fir::mlirTypeToString(type));
202 }
203 
204 /// Find the fir.type_info that was created for this \p recordType in \p module,
205 /// if any. \p  symbolTable can be provided to speed-up the lookup. This tool
206 /// will match record type even if they have been "altered" in type conversion
207 /// passes.
208 fir::TypeInfoOp
209 lookupTypeInfoOp(fir::RecordType recordType, mlir::ModuleOp module,
210                  const mlir::SymbolTable *symbolTable = nullptr);
211 
212 /// Find the fir.type_info named \p name in \p module, if any. \p  symbolTable
213 /// can be provided to speed-up the lookup. Prefer using the equivalent with a
214 /// RecordType argument  unless it is certain \p name has not been altered by a
215 /// pass rewriting fir.type (see NameUniquer::dropTypeConversionMarkers).
216 fir::TypeInfoOp
217 lookupTypeInfoOp(llvm::StringRef name, mlir::ModuleOp module,
218                  const mlir::SymbolTable *symbolTable = nullptr);
219 
220 /// Returns all lower bounds of \p component if it is an array component of \p
221 /// recordType with non default lower bounds. Returns nullopt if this is not an
222 /// array componnet of \p recordType or if its lower bounds are all ones.
223 std::optional<llvm::ArrayRef<int64_t>> getComponentLowerBoundsIfNonDefault(
224     fir::RecordType recordType, llvm::StringRef component,
225     mlir::ModuleOp module, const mlir::SymbolTable *symbolTable = nullptr);
226 
227 } // namespace fir
228 
229 #endif // FORTRAN_OPTIMIZER_SUPPORT_UTILS_H
230