1 //===-- FIRAttr.cpp -------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "flang/Optimizer/Dialect/FIRAttr.h" 14 #include "flang/Optimizer/Dialect/FIRDialect.h" 15 #include "flang/Optimizer/Dialect/Support/KindMapping.h" 16 #include "mlir/IR/AttributeSupport.h" 17 #include "mlir/IR/Builders.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/DialectImplementation.h" 20 #include "llvm/ADT/SmallString.h" 21 #include "llvm/ADT/StringExtras.h" 22 #include "llvm/ADT/TypeSwitch.h" 23 24 #include "flang/Optimizer/Dialect/FIREnumAttr.cpp.inc" 25 #define GET_ATTRDEF_CLASSES 26 #include "flang/Optimizer/Dialect/FIRAttr.cpp.inc" 27 28 using namespace fir; 29 30 namespace fir::detail { 31 32 struct RealAttributeStorage : public mlir::AttributeStorage { 33 using KeyTy = std::pair<int, llvm::APFloat>; 34 35 RealAttributeStorage(int kind, const llvm::APFloat &value) 36 : kind(kind), value(value) {} 37 RealAttributeStorage(const KeyTy &key) 38 : RealAttributeStorage(key.first, key.second) {} 39 40 static unsigned hashKey(const KeyTy &key) { return llvm::hash_value(key); } 41 42 bool operator==(const KeyTy &key) const { 43 return key.first == kind && 44 key.second.compare(value) == llvm::APFloatBase::cmpEqual; 45 } 46 47 static RealAttributeStorage * 48 construct(mlir::AttributeStorageAllocator &allocator, const KeyTy &key) { 49 return new (allocator.allocate<RealAttributeStorage>()) 50 RealAttributeStorage(key); 51 } 52 53 KindTy getFKind() const { return kind; } 54 llvm::APFloat getValue() const { return value; } 55 56 private: 57 int kind; 58 llvm::APFloat value; 59 }; 60 61 /// An attribute representing a reference to a type. 62 struct TypeAttributeStorage : public mlir::AttributeStorage { 63 using KeyTy = mlir::Type; 64 65 TypeAttributeStorage(mlir::Type value) : value(value) { 66 assert(value && "must not be of Type null"); 67 } 68 69 /// Key equality function. 70 bool operator==(const KeyTy &key) const { return key == value; } 71 72 /// Construct a new storage instance. 73 static TypeAttributeStorage * 74 construct(mlir::AttributeStorageAllocator &allocator, KeyTy key) { 75 return new (allocator.allocate<TypeAttributeStorage>()) 76 TypeAttributeStorage(key); 77 } 78 79 mlir::Type getType() const { return value; } 80 81 private: 82 mlir::Type value; 83 }; 84 } // namespace fir::detail 85 86 //===----------------------------------------------------------------------===// 87 // Attributes for SELECT TYPE 88 //===----------------------------------------------------------------------===// 89 90 ExactTypeAttr fir::ExactTypeAttr::get(mlir::Type value) { 91 return Base::get(value.getContext(), value); 92 } 93 94 mlir::Type fir::ExactTypeAttr::getType() const { return getImpl()->getType(); } 95 96 SubclassAttr fir::SubclassAttr::get(mlir::Type value) { 97 return Base::get(value.getContext(), value); 98 } 99 100 mlir::Type fir::SubclassAttr::getType() const { return getImpl()->getType(); } 101 102 //===----------------------------------------------------------------------===// 103 // Attributes for SELECT CASE 104 //===----------------------------------------------------------------------===// 105 106 using AttributeUniquer = mlir::detail::AttributeUniquer; 107 108 ClosedIntervalAttr fir::ClosedIntervalAttr::get(mlir::MLIRContext *ctxt) { 109 return AttributeUniquer::get<ClosedIntervalAttr>(ctxt); 110 } 111 112 UpperBoundAttr fir::UpperBoundAttr::get(mlir::MLIRContext *ctxt) { 113 return AttributeUniquer::get<UpperBoundAttr>(ctxt); 114 } 115 116 LowerBoundAttr fir::LowerBoundAttr::get(mlir::MLIRContext *ctxt) { 117 return AttributeUniquer::get<LowerBoundAttr>(ctxt); 118 } 119 120 PointIntervalAttr fir::PointIntervalAttr::get(mlir::MLIRContext *ctxt) { 121 return AttributeUniquer::get<PointIntervalAttr>(ctxt); 122 } 123 124 //===----------------------------------------------------------------------===// 125 // RealAttr 126 //===----------------------------------------------------------------------===// 127 128 RealAttr fir::RealAttr::get(mlir::MLIRContext *ctxt, 129 const RealAttr::ValueType &key) { 130 return Base::get(ctxt, key); 131 } 132 133 KindTy fir::RealAttr::getFKind() const { return getImpl()->getFKind(); } 134 135 llvm::APFloat fir::RealAttr::getValue() const { return getImpl()->getValue(); } 136 137 //===----------------------------------------------------------------------===// 138 // FIR attribute parsing 139 //===----------------------------------------------------------------------===// 140 141 static mlir::Attribute parseFirRealAttr(FIROpsDialect *dialect, 142 mlir::DialectAsmParser &parser, 143 mlir::Type type) { 144 int kind = 0; 145 if (parser.parseLess() || parser.parseInteger(kind) || parser.parseComma()) { 146 parser.emitError(parser.getNameLoc(), "expected '<' kind ','"); 147 return {}; 148 } 149 KindMapping kindMap(dialect->getContext()); 150 llvm::APFloat value(0.); 151 if (parser.parseOptionalKeyword("i")) { 152 // `i` not present, so literal float must be present 153 double dontCare; 154 if (parser.parseFloat(dontCare) || parser.parseGreater()) { 155 parser.emitError(parser.getNameLoc(), "expected real constant '>'"); 156 return {}; 157 } 158 auto fltStr = parser.getFullSymbolSpec() 159 .drop_until([](char c) { return c == ','; }) 160 .drop_front() 161 .drop_while([](char c) { return c == ' ' || c == '\t'; }) 162 .take_until([](char c) { 163 return c == '>' || c == ' ' || c == '\t'; 164 }); 165 value = llvm::APFloat(kindMap.getFloatSemantics(kind), fltStr); 166 } else { 167 // `i` is present, so literal bitstring (hex) must be present 168 llvm::StringRef hex; 169 if (parser.parseKeyword(&hex) || parser.parseGreater()) { 170 parser.emitError(parser.getNameLoc(), "expected real constant '>'"); 171 return {}; 172 } 173 const llvm::fltSemantics &sem = kindMap.getFloatSemantics(kind); 174 unsigned int numBits = llvm::APFloat::semanticsSizeInBits(sem); 175 auto bits = llvm::APInt(numBits, hex.drop_front(), 16); 176 value = llvm::APFloat(sem, bits); 177 } 178 return RealAttr::get(dialect->getContext(), {kind, value}); 179 } 180 181 mlir::Attribute fir::FortranVariableFlagsAttr::parse(mlir::AsmParser &parser, 182 mlir::Type type) { 183 if (mlir::failed(parser.parseLess())) 184 return {}; 185 186 fir::FortranVariableFlagsEnum flags = {}; 187 if (mlir::failed(parser.parseOptionalGreater())) { 188 auto parseFlags = [&]() -> mlir::ParseResult { 189 llvm::StringRef elemName; 190 if (mlir::failed(parser.parseKeyword(&elemName))) 191 return mlir::failure(); 192 193 auto elem = fir::symbolizeFortranVariableFlagsEnum(elemName); 194 if (!elem) 195 return parser.emitError(parser.getNameLoc(), 196 "Unknown fortran variable attribute: ") 197 << elemName; 198 199 flags = flags | *elem; 200 return mlir::success(); 201 }; 202 if (mlir::failed(parser.parseCommaSeparatedList(parseFlags)) || 203 parser.parseGreater()) 204 return {}; 205 } 206 207 return FortranVariableFlagsAttr::get(parser.getContext(), flags); 208 } 209 210 mlir::Attribute fir::parseFirAttribute(FIROpsDialect *dialect, 211 mlir::DialectAsmParser &parser, 212 mlir::Type type) { 213 auto loc = parser.getNameLoc(); 214 llvm::StringRef attrName; 215 mlir::Attribute attr; 216 mlir::OptionalParseResult result = 217 generatedAttributeParser(parser, &attrName, type, attr); 218 if (result.has_value()) 219 return attr; 220 if (attrName.empty()) 221 return {}; // error reported by generatedAttributeParser 222 223 if (attrName == ExactTypeAttr::getAttrName()) { 224 mlir::Type type; 225 if (parser.parseLess() || parser.parseType(type) || parser.parseGreater()) { 226 parser.emitError(loc, "expected a type"); 227 return {}; 228 } 229 return ExactTypeAttr::get(type); 230 } 231 if (attrName == SubclassAttr::getAttrName()) { 232 mlir::Type type; 233 if (parser.parseLess() || parser.parseType(type) || parser.parseGreater()) { 234 parser.emitError(loc, "expected a subtype"); 235 return {}; 236 } 237 return SubclassAttr::get(type); 238 } 239 if (attrName == PointIntervalAttr::getAttrName()) 240 return PointIntervalAttr::get(dialect->getContext()); 241 if (attrName == LowerBoundAttr::getAttrName()) 242 return LowerBoundAttr::get(dialect->getContext()); 243 if (attrName == UpperBoundAttr::getAttrName()) 244 return UpperBoundAttr::get(dialect->getContext()); 245 if (attrName == ClosedIntervalAttr::getAttrName()) 246 return ClosedIntervalAttr::get(dialect->getContext()); 247 if (attrName == RealAttr::getAttrName()) 248 return parseFirRealAttr(dialect, parser, type); 249 250 parser.emitError(loc, "unknown FIR attribute: ") << attrName; 251 return {}; 252 } 253 254 //===----------------------------------------------------------------------===// 255 // FIR attribute pretty printer 256 //===----------------------------------------------------------------------===// 257 258 void fir::FortranVariableFlagsAttr::print(mlir::AsmPrinter &printer) const { 259 printer << "<"; 260 printer << fir::stringifyFortranVariableFlagsEnum(this->getFlags()); 261 printer << ">"; 262 } 263 264 void fir::printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr, 265 mlir::DialectAsmPrinter &p) { 266 auto &os = p.getStream(); 267 if (auto exact = mlir::dyn_cast<fir::ExactTypeAttr>(attr)) { 268 os << fir::ExactTypeAttr::getAttrName() << '<'; 269 p.printType(exact.getType()); 270 os << '>'; 271 } else if (auto sub = mlir::dyn_cast<fir::SubclassAttr>(attr)) { 272 os << fir::SubclassAttr::getAttrName() << '<'; 273 p.printType(sub.getType()); 274 os << '>'; 275 } else if (mlir::dyn_cast_or_null<fir::PointIntervalAttr>(attr)) { 276 os << fir::PointIntervalAttr::getAttrName(); 277 } else if (mlir::dyn_cast_or_null<fir::ClosedIntervalAttr>(attr)) { 278 os << fir::ClosedIntervalAttr::getAttrName(); 279 } else if (mlir::dyn_cast_or_null<fir::LowerBoundAttr>(attr)) { 280 os << fir::LowerBoundAttr::getAttrName(); 281 } else if (mlir::dyn_cast_or_null<fir::UpperBoundAttr>(attr)) { 282 os << fir::UpperBoundAttr::getAttrName(); 283 } else if (auto a = mlir::dyn_cast_or_null<fir::RealAttr>(attr)) { 284 os << fir::RealAttr::getAttrName() << '<' << a.getFKind() << ", i x"; 285 llvm::SmallString<40> ss; 286 a.getValue().bitcastToAPInt().toStringUnsigned(ss, 16); 287 os << ss << '>'; 288 } else if (mlir::failed(generatedAttributePrinter(attr, p))) { 289 // don't know how to print the attribute, so use a default 290 os << "<(unknown attribute)>"; 291 } 292 } 293 294 //===----------------------------------------------------------------------===// 295 // FIROpsDialect 296 //===----------------------------------------------------------------------===// 297 298 void FIROpsDialect::registerAttributes() { 299 addAttributes<ClosedIntervalAttr, ExactTypeAttr, 300 FortranProcedureFlagsEnumAttr, FortranVariableFlagsAttr, 301 LowerBoundAttr, PointIntervalAttr, RealAttr, ReduceAttr, 302 SubclassAttr, UpperBoundAttr, LocationKindAttr, 303 LocationKindArrayAttr>(); 304 } 305