xref: /llvm-project/flang/lib/Optimizer/Dialect/FIRAttr.cpp (revision 2051a7bcd3f375c063f803df3cfde9e6e6d724ad)
1 //===-- FIRAttr.cpp -------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "flang/Optimizer/Dialect/FIRAttr.h"
14 #include "flang/Optimizer/Dialect/FIRDialect.h"
15 #include "flang/Optimizer/Dialect/Support/KindMapping.h"
16 #include "mlir/IR/AttributeSupport.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/DialectImplementation.h"
20 #include "llvm/ADT/SmallString.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 
24 #include "flang/Optimizer/Dialect/FIREnumAttr.cpp.inc"
25 #define GET_ATTRDEF_CLASSES
26 #include "flang/Optimizer/Dialect/FIRAttr.cpp.inc"
27 
28 using namespace fir;
29 
30 namespace fir::detail {
31 
32 struct RealAttributeStorage : public mlir::AttributeStorage {
33   using KeyTy = std::pair<int, llvm::APFloat>;
34 
35   RealAttributeStorage(int kind, const llvm::APFloat &value)
36       : kind(kind), value(value) {}
37   RealAttributeStorage(const KeyTy &key)
38       : RealAttributeStorage(key.first, key.second) {}
39 
40   static unsigned hashKey(const KeyTy &key) { return llvm::hash_value(key); }
41 
42   bool operator==(const KeyTy &key) const {
43     return key.first == kind &&
44            key.second.compare(value) == llvm::APFloatBase::cmpEqual;
45   }
46 
47   static RealAttributeStorage *
48   construct(mlir::AttributeStorageAllocator &allocator, const KeyTy &key) {
49     return new (allocator.allocate<RealAttributeStorage>())
50         RealAttributeStorage(key);
51   }
52 
53   KindTy getFKind() const { return kind; }
54   llvm::APFloat getValue() const { return value; }
55 
56 private:
57   int kind;
58   llvm::APFloat value;
59 };
60 
61 /// An attribute representing a reference to a type.
62 struct TypeAttributeStorage : public mlir::AttributeStorage {
63   using KeyTy = mlir::Type;
64 
65   TypeAttributeStorage(mlir::Type value) : value(value) {
66     assert(value && "must not be of Type null");
67   }
68 
69   /// Key equality function.
70   bool operator==(const KeyTy &key) const { return key == value; }
71 
72   /// Construct a new storage instance.
73   static TypeAttributeStorage *
74   construct(mlir::AttributeStorageAllocator &allocator, KeyTy key) {
75     return new (allocator.allocate<TypeAttributeStorage>())
76         TypeAttributeStorage(key);
77   }
78 
79   mlir::Type getType() const { return value; }
80 
81 private:
82   mlir::Type value;
83 };
84 } // namespace fir::detail
85 
86 //===----------------------------------------------------------------------===//
87 // Attributes for SELECT TYPE
88 //===----------------------------------------------------------------------===//
89 
90 ExactTypeAttr fir::ExactTypeAttr::get(mlir::Type value) {
91   return Base::get(value.getContext(), value);
92 }
93 
94 mlir::Type fir::ExactTypeAttr::getType() const { return getImpl()->getType(); }
95 
96 SubclassAttr fir::SubclassAttr::get(mlir::Type value) {
97   return Base::get(value.getContext(), value);
98 }
99 
100 mlir::Type fir::SubclassAttr::getType() const { return getImpl()->getType(); }
101 
102 //===----------------------------------------------------------------------===//
103 // Attributes for SELECT CASE
104 //===----------------------------------------------------------------------===//
105 
106 using AttributeUniquer = mlir::detail::AttributeUniquer;
107 
108 ClosedIntervalAttr fir::ClosedIntervalAttr::get(mlir::MLIRContext *ctxt) {
109   return AttributeUniquer::get<ClosedIntervalAttr>(ctxt);
110 }
111 
112 UpperBoundAttr fir::UpperBoundAttr::get(mlir::MLIRContext *ctxt) {
113   return AttributeUniquer::get<UpperBoundAttr>(ctxt);
114 }
115 
116 LowerBoundAttr fir::LowerBoundAttr::get(mlir::MLIRContext *ctxt) {
117   return AttributeUniquer::get<LowerBoundAttr>(ctxt);
118 }
119 
120 PointIntervalAttr fir::PointIntervalAttr::get(mlir::MLIRContext *ctxt) {
121   return AttributeUniquer::get<PointIntervalAttr>(ctxt);
122 }
123 
124 //===----------------------------------------------------------------------===//
125 // RealAttr
126 //===----------------------------------------------------------------------===//
127 
128 RealAttr fir::RealAttr::get(mlir::MLIRContext *ctxt,
129                             const RealAttr::ValueType &key) {
130   return Base::get(ctxt, key);
131 }
132 
133 KindTy fir::RealAttr::getFKind() const { return getImpl()->getFKind(); }
134 
135 llvm::APFloat fir::RealAttr::getValue() const { return getImpl()->getValue(); }
136 
137 //===----------------------------------------------------------------------===//
138 // FIR attribute parsing
139 //===----------------------------------------------------------------------===//
140 
141 static mlir::Attribute parseFirRealAttr(FIROpsDialect *dialect,
142                                         mlir::DialectAsmParser &parser,
143                                         mlir::Type type) {
144   int kind = 0;
145   if (parser.parseLess() || parser.parseInteger(kind) || parser.parseComma()) {
146     parser.emitError(parser.getNameLoc(), "expected '<' kind ','");
147     return {};
148   }
149   KindMapping kindMap(dialect->getContext());
150   llvm::APFloat value(0.);
151   if (parser.parseOptionalKeyword("i")) {
152     // `i` not present, so literal float must be present
153     double dontCare;
154     if (parser.parseFloat(dontCare) || parser.parseGreater()) {
155       parser.emitError(parser.getNameLoc(), "expected real constant '>'");
156       return {};
157     }
158     auto fltStr = parser.getFullSymbolSpec()
159                       .drop_until([](char c) { return c == ','; })
160                       .drop_front()
161                       .drop_while([](char c) { return c == ' ' || c == '\t'; })
162                       .take_until([](char c) {
163                         return c == '>' || c == ' ' || c == '\t';
164                       });
165     value = llvm::APFloat(kindMap.getFloatSemantics(kind), fltStr);
166   } else {
167     // `i` is present, so literal bitstring (hex) must be present
168     llvm::StringRef hex;
169     if (parser.parseKeyword(&hex) || parser.parseGreater()) {
170       parser.emitError(parser.getNameLoc(), "expected real constant '>'");
171       return {};
172     }
173     const llvm::fltSemantics &sem = kindMap.getFloatSemantics(kind);
174     unsigned int numBits = llvm::APFloat::semanticsSizeInBits(sem);
175     auto bits = llvm::APInt(numBits, hex.drop_front(), 16);
176     value = llvm::APFloat(sem, bits);
177   }
178   return RealAttr::get(dialect->getContext(), {kind, value});
179 }
180 
181 mlir::Attribute fir::FortranVariableFlagsAttr::parse(mlir::AsmParser &parser,
182                                                      mlir::Type type) {
183   if (mlir::failed(parser.parseLess()))
184     return {};
185 
186   fir::FortranVariableFlagsEnum flags = {};
187   if (mlir::failed(parser.parseOptionalGreater())) {
188     auto parseFlags = [&]() -> mlir::ParseResult {
189       llvm::StringRef elemName;
190       if (mlir::failed(parser.parseKeyword(&elemName)))
191         return mlir::failure();
192 
193       auto elem = fir::symbolizeFortranVariableFlagsEnum(elemName);
194       if (!elem)
195         return parser.emitError(parser.getNameLoc(),
196                                 "Unknown fortran variable attribute: ")
197                << elemName;
198 
199       flags = flags | *elem;
200       return mlir::success();
201     };
202     if (mlir::failed(parser.parseCommaSeparatedList(parseFlags)) ||
203         parser.parseGreater())
204       return {};
205   }
206 
207   return FortranVariableFlagsAttr::get(parser.getContext(), flags);
208 }
209 
210 mlir::Attribute fir::parseFirAttribute(FIROpsDialect *dialect,
211                                        mlir::DialectAsmParser &parser,
212                                        mlir::Type type) {
213   auto loc = parser.getNameLoc();
214   llvm::StringRef attrName;
215   mlir::Attribute attr;
216   mlir::OptionalParseResult result =
217       generatedAttributeParser(parser, &attrName, type, attr);
218   if (result.has_value())
219     return attr;
220   if (attrName.empty())
221     return {}; // error reported by generatedAttributeParser
222 
223   if (attrName == ExactTypeAttr::getAttrName()) {
224     mlir::Type type;
225     if (parser.parseLess() || parser.parseType(type) || parser.parseGreater()) {
226       parser.emitError(loc, "expected a type");
227       return {};
228     }
229     return ExactTypeAttr::get(type);
230   }
231   if (attrName == SubclassAttr::getAttrName()) {
232     mlir::Type type;
233     if (parser.parseLess() || parser.parseType(type) || parser.parseGreater()) {
234       parser.emitError(loc, "expected a subtype");
235       return {};
236     }
237     return SubclassAttr::get(type);
238   }
239   if (attrName == PointIntervalAttr::getAttrName())
240     return PointIntervalAttr::get(dialect->getContext());
241   if (attrName == LowerBoundAttr::getAttrName())
242     return LowerBoundAttr::get(dialect->getContext());
243   if (attrName == UpperBoundAttr::getAttrName())
244     return UpperBoundAttr::get(dialect->getContext());
245   if (attrName == ClosedIntervalAttr::getAttrName())
246     return ClosedIntervalAttr::get(dialect->getContext());
247   if (attrName == RealAttr::getAttrName())
248     return parseFirRealAttr(dialect, parser, type);
249 
250   parser.emitError(loc, "unknown FIR attribute: ") << attrName;
251   return {};
252 }
253 
254 //===----------------------------------------------------------------------===//
255 // FIR attribute pretty printer
256 //===----------------------------------------------------------------------===//
257 
258 void fir::FortranVariableFlagsAttr::print(mlir::AsmPrinter &printer) const {
259   printer << "<";
260   printer << fir::stringifyFortranVariableFlagsEnum(this->getFlags());
261   printer << ">";
262 }
263 
264 void fir::printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr,
265                             mlir::DialectAsmPrinter &p) {
266   auto &os = p.getStream();
267   if (auto exact = mlir::dyn_cast<fir::ExactTypeAttr>(attr)) {
268     os << fir::ExactTypeAttr::getAttrName() << '<';
269     p.printType(exact.getType());
270     os << '>';
271   } else if (auto sub = mlir::dyn_cast<fir::SubclassAttr>(attr)) {
272     os << fir::SubclassAttr::getAttrName() << '<';
273     p.printType(sub.getType());
274     os << '>';
275   } else if (mlir::dyn_cast_or_null<fir::PointIntervalAttr>(attr)) {
276     os << fir::PointIntervalAttr::getAttrName();
277   } else if (mlir::dyn_cast_or_null<fir::ClosedIntervalAttr>(attr)) {
278     os << fir::ClosedIntervalAttr::getAttrName();
279   } else if (mlir::dyn_cast_or_null<fir::LowerBoundAttr>(attr)) {
280     os << fir::LowerBoundAttr::getAttrName();
281   } else if (mlir::dyn_cast_or_null<fir::UpperBoundAttr>(attr)) {
282     os << fir::UpperBoundAttr::getAttrName();
283   } else if (auto a = mlir::dyn_cast_or_null<fir::RealAttr>(attr)) {
284     os << fir::RealAttr::getAttrName() << '<' << a.getFKind() << ", i x";
285     llvm::SmallString<40> ss;
286     a.getValue().bitcastToAPInt().toStringUnsigned(ss, 16);
287     os << ss << '>';
288   } else if (mlir::failed(generatedAttributePrinter(attr, p))) {
289     // don't know how to print the attribute, so use a default
290     os << "<(unknown attribute)>";
291   }
292 }
293 
294 //===----------------------------------------------------------------------===//
295 // FIROpsDialect
296 //===----------------------------------------------------------------------===//
297 
298 void FIROpsDialect::registerAttributes() {
299   addAttributes<ClosedIntervalAttr, ExactTypeAttr,
300                 FortranProcedureFlagsEnumAttr, FortranVariableFlagsAttr,
301                 LowerBoundAttr, PointIntervalAttr, RealAttr, ReduceAttr,
302                 SubclassAttr, UpperBoundAttr, LocationKindAttr,
303                 LocationKindArrayAttr>();
304 }
305