1 //===- LLVMTypeSyntax.cpp - Parsing/printing for MLIR LLVM Dialect types --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 10 #include "mlir/IR/Builders.h" 11 #include "mlir/IR/DialectImplementation.h" 12 #include "llvm/ADT/ScopeExit.h" 13 #include "llvm/ADT/SetVector.h" 14 #include "llvm/ADT/TypeSwitch.h" 15 16 using namespace mlir; 17 using namespace mlir::LLVM; 18 19 //===----------------------------------------------------------------------===// 20 // Printing. 21 //===----------------------------------------------------------------------===// 22 23 /// If the given type is compatible with the LLVM dialect, prints it using 24 /// internal functions to avoid getting a verbose `!llvm` prefix. Otherwise 25 /// prints it as usual. 26 static void dispatchPrint(AsmPrinter &printer, Type type) { 27 if (isCompatibleType(type) && 28 !llvm::isa<IntegerType, FloatType, VectorType>(type)) 29 return mlir::LLVM::detail::printType(type, printer); 30 printer.printType(type); 31 } 32 33 /// Returns the keyword to use for the given type. 34 static StringRef getTypeKeyword(Type type) { 35 return TypeSwitch<Type, StringRef>(type) 36 .Case<LLVMVoidType>([&](Type) { return "void"; }) 37 .Case<LLVMPPCFP128Type>([&](Type) { return "ppc_fp128"; }) 38 .Case<LLVMTokenType>([&](Type) { return "token"; }) 39 .Case<LLVMLabelType>([&](Type) { return "label"; }) 40 .Case<LLVMMetadataType>([&](Type) { return "metadata"; }) 41 .Case<LLVMFunctionType>([&](Type) { return "func"; }) 42 .Case<LLVMPointerType>([&](Type) { return "ptr"; }) 43 .Case<LLVMFixedVectorType, LLVMScalableVectorType>( 44 [&](Type) { return "vec"; }) 45 .Case<LLVMArrayType>([&](Type) { return "array"; }) 46 .Case<LLVMStructType>([&](Type) { return "struct"; }) 47 .Case<LLVMTargetExtType>([&](Type) { return "target"; }) 48 .Case<LLVMX86AMXType>([&](Type) { return "x86_amx"; }) 49 .Default([](Type) -> StringRef { 50 llvm_unreachable("unexpected 'llvm' type kind"); 51 }); 52 } 53 54 /// Prints a structure type. Keeps track of known struct names to handle self- 55 /// or mutually-referring structs without falling into infinite recursion. 56 void LLVMStructType::print(AsmPrinter &printer) const { 57 FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint; 58 59 printer << "<"; 60 if (isIdentified()) { 61 cyclicPrint = printer.tryStartCyclicPrint(*this); 62 63 printer << '"' << getName() << '"'; 64 // If we are printing a reference to one of the enclosing structs, just 65 // print the name and stop to avoid infinitely long output. 66 if (failed(cyclicPrint)) { 67 printer << '>'; 68 return; 69 } 70 printer << ", "; 71 } 72 73 if (isIdentified() && isOpaque()) { 74 printer << "opaque>"; 75 return; 76 } 77 78 if (isPacked()) 79 printer << "packed "; 80 81 // Put the current type on stack to avoid infinite recursion. 82 printer << '('; 83 llvm::interleaveComma(getBody(), printer.getStream(), 84 [&](Type subtype) { dispatchPrint(printer, subtype); }); 85 printer << ')'; 86 printer << '>'; 87 } 88 89 /// Prints the given LLVM dialect type recursively. This leverages closedness of 90 /// the LLVM dialect type system to avoid printing the dialect prefix 91 /// repeatedly. For recursive structures, only prints the name of the structure 92 /// when printing a self-reference. Note that this does not apply to sibling 93 /// references. For example, 94 /// struct<"a", (ptr<struct<"a">>)> 95 /// struct<"c", (ptr<struct<"b", (ptr<struct<"c">>)>>, 96 /// ptr<struct<"b", (ptr<struct<"c">>)>>)> 97 /// note that "b" is printed twice. 98 void mlir::LLVM::detail::printType(Type type, AsmPrinter &printer) { 99 if (!type) { 100 printer << "<<NULL-TYPE>>"; 101 return; 102 } 103 104 printer << getTypeKeyword(type); 105 106 llvm::TypeSwitch<Type>(type) 107 .Case<LLVMPointerType, LLVMArrayType, LLVMFixedVectorType, 108 LLVMScalableVectorType, LLVMFunctionType, LLVMTargetExtType, 109 LLVMStructType>([&](auto type) { type.print(printer); }); 110 } 111 112 //===----------------------------------------------------------------------===// 113 // Parsing. 114 //===----------------------------------------------------------------------===// 115 116 static ParseResult dispatchParse(AsmParser &parser, Type &type); 117 118 /// Parses an LLVM dialect vector type. 119 /// llvm-type ::= `vec<` `? x`? integer `x` llvm-type `>` 120 /// Supports both fixed and scalable vectors. 121 static Type parseVectorType(AsmParser &parser) { 122 SmallVector<int64_t, 2> dims; 123 SMLoc dimPos, typePos; 124 Type elementType; 125 SMLoc loc = parser.getCurrentLocation(); 126 if (parser.parseLess() || parser.getCurrentLocation(&dimPos) || 127 parser.parseDimensionList(dims, /*allowDynamic=*/true) || 128 parser.getCurrentLocation(&typePos) || 129 dispatchParse(parser, elementType) || parser.parseGreater()) 130 return Type(); 131 132 // We parsed a generic dimension list, but vectors only support two forms: 133 // - single non-dynamic entry in the list (fixed vector); 134 // - two elements, the first dynamic (indicated by ShapedType::kDynamic) 135 // and the second 136 // non-dynamic (scalable vector). 137 if (dims.empty() || dims.size() > 2 || 138 ((dims.size() == 2) ^ (ShapedType::isDynamic(dims[0]))) || 139 (dims.size() == 2 && ShapedType::isDynamic(dims[1]))) { 140 parser.emitError(dimPos) 141 << "expected '? x <integer> x <type>' or '<integer> x <type>'"; 142 return Type(); 143 } 144 145 bool isScalable = dims.size() == 2; 146 if (isScalable) 147 return parser.getChecked<LLVMScalableVectorType>(loc, elementType, dims[1]); 148 if (elementType.isSignlessIntOrFloat()) { 149 parser.emitError(typePos) 150 << "cannot use !llvm.vec for built-in primitives, use 'vector' instead"; 151 return Type(); 152 } 153 return parser.getChecked<LLVMFixedVectorType>(loc, elementType, dims[0]); 154 } 155 156 /// Attempts to set the body of an identified structure type. Reports a parsing 157 /// error at `subtypesLoc` in case of failure. 158 static LLVMStructType trySetStructBody(LLVMStructType type, 159 ArrayRef<Type> subtypes, bool isPacked, 160 AsmParser &parser, SMLoc subtypesLoc) { 161 for (Type t : subtypes) { 162 if (!LLVMStructType::isValidElementType(t)) { 163 parser.emitError(subtypesLoc) 164 << "invalid LLVM structure element type: " << t; 165 return LLVMStructType(); 166 } 167 } 168 169 if (succeeded(type.setBody(subtypes, isPacked))) 170 return type; 171 172 parser.emitError(subtypesLoc) 173 << "identified type already used with a different body"; 174 return LLVMStructType(); 175 } 176 177 /// Parses an LLVM dialect structure type. 178 /// llvm-type ::= `struct<` (string-literal `,`)? `packed`? 179 /// `(` llvm-type-list `)` `>` 180 /// | `struct<` string-literal `>` 181 /// | `struct<` string-literal `, opaque>` 182 Type LLVMStructType::parse(AsmParser &parser) { 183 Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation()); 184 185 if (failed(parser.parseLess())) 186 return LLVMStructType(); 187 188 // If we are parsing a self-reference to a recursive struct, i.e. the parsing 189 // stack already contains a struct with the same identifier, bail out after 190 // the name. 191 std::string name; 192 bool isIdentified = succeeded(parser.parseOptionalString(&name)); 193 if (isIdentified) { 194 SMLoc greaterLoc = parser.getCurrentLocation(); 195 if (succeeded(parser.parseOptionalGreater())) { 196 auto type = LLVMStructType::getIdentifiedChecked( 197 [loc] { return emitError(loc); }, loc.getContext(), name); 198 if (succeeded(parser.tryStartCyclicParse(type))) { 199 parser.emitError( 200 greaterLoc, 201 "struct without a body only allowed in a recursive struct"); 202 return nullptr; 203 } 204 205 return type; 206 } 207 if (failed(parser.parseComma())) 208 return LLVMStructType(); 209 } 210 211 // Handle intentionally opaque structs. 212 SMLoc kwLoc = parser.getCurrentLocation(); 213 if (succeeded(parser.parseOptionalKeyword("opaque"))) { 214 if (!isIdentified) 215 return parser.emitError(kwLoc, "only identified structs can be opaque"), 216 LLVMStructType(); 217 if (failed(parser.parseGreater())) 218 return LLVMStructType(); 219 auto type = LLVMStructType::getOpaqueChecked( 220 [loc] { return emitError(loc); }, loc.getContext(), name); 221 if (!type.isOpaque()) { 222 parser.emitError(kwLoc, "redeclaring defined struct as opaque"); 223 return LLVMStructType(); 224 } 225 return type; 226 } 227 228 FailureOr<AsmParser::CyclicParseReset> cyclicParse; 229 if (isIdentified) { 230 cyclicParse = 231 parser.tryStartCyclicParse(LLVMStructType::getIdentifiedChecked( 232 [loc] { return emitError(loc); }, loc.getContext(), name)); 233 if (failed(cyclicParse)) { 234 parser.emitError(kwLoc, 235 "identifier already used for an enclosing struct"); 236 return nullptr; 237 } 238 } 239 240 // Check for packedness. 241 bool isPacked = succeeded(parser.parseOptionalKeyword("packed")); 242 if (failed(parser.parseLParen())) 243 return LLVMStructType(); 244 245 // Fast pass for structs with zero subtypes. 246 if (succeeded(parser.parseOptionalRParen())) { 247 if (failed(parser.parseGreater())) 248 return LLVMStructType(); 249 if (!isIdentified) 250 return LLVMStructType::getLiteralChecked([loc] { return emitError(loc); }, 251 loc.getContext(), {}, isPacked); 252 auto type = LLVMStructType::getIdentifiedChecked( 253 [loc] { return emitError(loc); }, loc.getContext(), name); 254 return trySetStructBody(type, {}, isPacked, parser, kwLoc); 255 } 256 257 // Parse subtypes. For identified structs, put the identifier of the struct on 258 // the stack to support self-references in the recursive calls. 259 SmallVector<Type, 4> subtypes; 260 SMLoc subtypesLoc = parser.getCurrentLocation(); 261 do { 262 Type type; 263 if (dispatchParse(parser, type)) 264 return LLVMStructType(); 265 subtypes.push_back(type); 266 } while (succeeded(parser.parseOptionalComma())); 267 268 if (parser.parseRParen() || parser.parseGreater()) 269 return LLVMStructType(); 270 271 // Construct the struct with body. 272 if (!isIdentified) 273 return LLVMStructType::getLiteralChecked( 274 [loc] { return emitError(loc); }, loc.getContext(), subtypes, isPacked); 275 auto type = LLVMStructType::getIdentifiedChecked( 276 [loc] { return emitError(loc); }, loc.getContext(), name); 277 return trySetStructBody(type, subtypes, isPacked, parser, subtypesLoc); 278 } 279 280 /// Parses a type appearing inside another LLVM dialect-compatible type. This 281 /// will try to parse any type in full form (including types with the `!llvm` 282 /// prefix), and on failure fall back to parsing the short-hand version of the 283 /// LLVM dialect types without the `!llvm` prefix. 284 static Type dispatchParse(AsmParser &parser, bool allowAny = true) { 285 SMLoc keyLoc = parser.getCurrentLocation(); 286 287 // Try parsing any MLIR type. 288 Type type; 289 OptionalParseResult result = parser.parseOptionalType(type); 290 if (result.has_value()) { 291 if (failed(result.value())) 292 return nullptr; 293 if (!allowAny) { 294 parser.emitError(keyLoc) << "unexpected type, expected keyword"; 295 return nullptr; 296 } 297 return type; 298 } 299 300 // If no type found, fallback to the shorthand form. 301 StringRef key; 302 if (failed(parser.parseKeyword(&key))) 303 return Type(); 304 305 MLIRContext *ctx = parser.getContext(); 306 return StringSwitch<function_ref<Type()>>(key) 307 .Case("void", [&] { return LLVMVoidType::get(ctx); }) 308 .Case("ppc_fp128", [&] { return LLVMPPCFP128Type::get(ctx); }) 309 .Case("token", [&] { return LLVMTokenType::get(ctx); }) 310 .Case("label", [&] { return LLVMLabelType::get(ctx); }) 311 .Case("metadata", [&] { return LLVMMetadataType::get(ctx); }) 312 .Case("func", [&] { return LLVMFunctionType::parse(parser); }) 313 .Case("ptr", [&] { return LLVMPointerType::parse(parser); }) 314 .Case("vec", [&] { return parseVectorType(parser); }) 315 .Case("array", [&] { return LLVMArrayType::parse(parser); }) 316 .Case("struct", [&] { return LLVMStructType::parse(parser); }) 317 .Case("target", [&] { return LLVMTargetExtType::parse(parser); }) 318 .Case("x86_amx", [&] { return LLVMX86AMXType::get(ctx); }) 319 .Default([&] { 320 parser.emitError(keyLoc) << "unknown LLVM type: " << key; 321 return Type(); 322 })(); 323 } 324 325 /// Helper to use in parse lists. 326 static ParseResult dispatchParse(AsmParser &parser, Type &type) { 327 type = dispatchParse(parser); 328 return success(type != nullptr); 329 } 330 331 /// Parses one of the LLVM dialect types. 332 Type mlir::LLVM::detail::parseType(DialectAsmParser &parser) { 333 SMLoc loc = parser.getCurrentLocation(); 334 Type type = dispatchParse(parser, /*allowAny=*/false); 335 if (!type) 336 return type; 337 if (!isCompatibleOuterType(type)) { 338 parser.emitError(loc) << "unexpected type, expected keyword"; 339 return nullptr; 340 } 341 return type; 342 } 343 344 ParseResult LLVM::parsePrettyLLVMType(AsmParser &p, Type &type) { 345 return dispatchParse(p, type); 346 } 347 348 void LLVM::printPrettyLLVMType(AsmPrinter &p, Type type) { 349 return dispatchPrint(p, type); 350 } 351