xref: /llvm-project/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp (revision 332719561000dcac94384234ace1fa959362ad8e)
1 //===- LLVMTypeSyntax.cpp - Parsing/printing for MLIR LLVM Dialect types --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
15 
16 using namespace mlir;
17 using namespace mlir::LLVM;
18 
19 //===----------------------------------------------------------------------===//
20 // Printing.
21 //===----------------------------------------------------------------------===//
22 
23 /// If the given type is compatible with the LLVM dialect, prints it using
24 /// internal functions to avoid getting a verbose `!llvm` prefix. Otherwise
25 /// prints it as usual.
26 static void dispatchPrint(AsmPrinter &printer, Type type) {
27   if (isCompatibleType(type) &&
28       !llvm::isa<IntegerType, FloatType, VectorType>(type))
29     return mlir::LLVM::detail::printType(type, printer);
30   printer.printType(type);
31 }
32 
33 /// Returns the keyword to use for the given type.
34 static StringRef getTypeKeyword(Type type) {
35   return TypeSwitch<Type, StringRef>(type)
36       .Case<LLVMVoidType>([&](Type) { return "void"; })
37       .Case<LLVMPPCFP128Type>([&](Type) { return "ppc_fp128"; })
38       .Case<LLVMTokenType>([&](Type) { return "token"; })
39       .Case<LLVMLabelType>([&](Type) { return "label"; })
40       .Case<LLVMMetadataType>([&](Type) { return "metadata"; })
41       .Case<LLVMFunctionType>([&](Type) { return "func"; })
42       .Case<LLVMPointerType>([&](Type) { return "ptr"; })
43       .Case<LLVMFixedVectorType, LLVMScalableVectorType>(
44           [&](Type) { return "vec"; })
45       .Case<LLVMArrayType>([&](Type) { return "array"; })
46       .Case<LLVMStructType>([&](Type) { return "struct"; })
47       .Case<LLVMTargetExtType>([&](Type) { return "target"; })
48       .Case<LLVMX86AMXType>([&](Type) { return "x86_amx"; })
49       .Default([](Type) -> StringRef {
50         llvm_unreachable("unexpected 'llvm' type kind");
51       });
52 }
53 
54 /// Prints a structure type. Keeps track of known struct names to handle self-
55 /// or mutually-referring structs without falling into infinite recursion.
56 void LLVMStructType::print(AsmPrinter &printer) const {
57   FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint;
58 
59   printer << "<";
60   if (isIdentified()) {
61     cyclicPrint = printer.tryStartCyclicPrint(*this);
62 
63     printer << '"' << getName() << '"';
64     // If we are printing a reference to one of the enclosing structs, just
65     // print the name and stop to avoid infinitely long output.
66     if (failed(cyclicPrint)) {
67       printer << '>';
68       return;
69     }
70     printer << ", ";
71   }
72 
73   if (isIdentified() && isOpaque()) {
74     printer << "opaque>";
75     return;
76   }
77 
78   if (isPacked())
79     printer << "packed ";
80 
81   // Put the current type on stack to avoid infinite recursion.
82   printer << '(';
83   llvm::interleaveComma(getBody(), printer.getStream(),
84                         [&](Type subtype) { dispatchPrint(printer, subtype); });
85   printer << ')';
86   printer << '>';
87 }
88 
89 /// Prints the given LLVM dialect type recursively. This leverages closedness of
90 /// the LLVM dialect type system to avoid printing the dialect prefix
91 /// repeatedly. For recursive structures, only prints the name of the structure
92 /// when printing a self-reference. Note that this does not apply to sibling
93 /// references. For example,
94 ///   struct<"a", (ptr<struct<"a">>)>
95 ///   struct<"c", (ptr<struct<"b", (ptr<struct<"c">>)>>,
96 ///                ptr<struct<"b", (ptr<struct<"c">>)>>)>
97 /// note that "b" is printed twice.
98 void mlir::LLVM::detail::printType(Type type, AsmPrinter &printer) {
99   if (!type) {
100     printer << "<<NULL-TYPE>>";
101     return;
102   }
103 
104   printer << getTypeKeyword(type);
105 
106   llvm::TypeSwitch<Type>(type)
107       .Case<LLVMPointerType, LLVMArrayType, LLVMFixedVectorType,
108             LLVMScalableVectorType, LLVMFunctionType, LLVMTargetExtType,
109             LLVMStructType>([&](auto type) { type.print(printer); });
110 }
111 
112 //===----------------------------------------------------------------------===//
113 // Parsing.
114 //===----------------------------------------------------------------------===//
115 
/// Forward declaration of the list-friendly parsing helper defined further
/// down in this file; the vector and struct parsers below call it for their
/// nested element types.
static ParseResult dispatchParse(AsmParser &parser, Type &type);
117 
118 /// Parses an LLVM dialect vector type.
119 ///   llvm-type ::= `vec<` `? x`? integer `x` llvm-type `>`
120 /// Supports both fixed and scalable vectors.
121 static Type parseVectorType(AsmParser &parser) {
122   SmallVector<int64_t, 2> dims;
123   SMLoc dimPos, typePos;
124   Type elementType;
125   SMLoc loc = parser.getCurrentLocation();
126   if (parser.parseLess() || parser.getCurrentLocation(&dimPos) ||
127       parser.parseDimensionList(dims, /*allowDynamic=*/true) ||
128       parser.getCurrentLocation(&typePos) ||
129       dispatchParse(parser, elementType) || parser.parseGreater())
130     return Type();
131 
132   // We parsed a generic dimension list, but vectors only support two forms:
133   //  - single non-dynamic entry in the list (fixed vector);
134   //  - two elements, the first dynamic (indicated by ShapedType::kDynamic)
135   //  and the second
136   //    non-dynamic (scalable vector).
137   if (dims.empty() || dims.size() > 2 ||
138       ((dims.size() == 2) ^ (ShapedType::isDynamic(dims[0]))) ||
139       (dims.size() == 2 && ShapedType::isDynamic(dims[1]))) {
140     parser.emitError(dimPos)
141         << "expected '? x <integer> x <type>' or '<integer> x <type>'";
142     return Type();
143   }
144 
145   bool isScalable = dims.size() == 2;
146   if (isScalable)
147     return parser.getChecked<LLVMScalableVectorType>(loc, elementType, dims[1]);
148   if (elementType.isSignlessIntOrFloat()) {
149     parser.emitError(typePos)
150         << "cannot use !llvm.vec for built-in primitives, use 'vector' instead";
151     return Type();
152   }
153   return parser.getChecked<LLVMFixedVectorType>(loc, elementType, dims[0]);
154 }
155 
156 /// Attempts to set the body of an identified structure type. Reports a parsing
157 /// error at `subtypesLoc` in case of failure.
158 static LLVMStructType trySetStructBody(LLVMStructType type,
159                                        ArrayRef<Type> subtypes, bool isPacked,
160                                        AsmParser &parser, SMLoc subtypesLoc) {
161   for (Type t : subtypes) {
162     if (!LLVMStructType::isValidElementType(t)) {
163       parser.emitError(subtypesLoc)
164           << "invalid LLVM structure element type: " << t;
165       return LLVMStructType();
166     }
167   }
168 
169   if (succeeded(type.setBody(subtypes, isPacked)))
170     return type;
171 
172   parser.emitError(subtypesLoc)
173       << "identified type already used with a different body";
174   return LLVMStructType();
175 }
176 
177 /// Parses an LLVM dialect structure type.
178 ///   llvm-type ::= `struct<` (string-literal `,`)? `packed`?
179 ///                 `(` llvm-type-list `)` `>`
180 ///               | `struct<` string-literal `>`
181 ///               | `struct<` string-literal `, opaque>`
182 Type LLVMStructType::parse(AsmParser &parser) {
183   Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
184 
185   if (failed(parser.parseLess()))
186     return LLVMStructType();
187 
188   // If we are parsing a self-reference to a recursive struct, i.e. the parsing
189   // stack already contains a struct with the same identifier, bail out after
190   // the name.
191   std::string name;
192   bool isIdentified = succeeded(parser.parseOptionalString(&name));
193   if (isIdentified) {
194     SMLoc greaterLoc = parser.getCurrentLocation();
195     if (succeeded(parser.parseOptionalGreater())) {
196       auto type = LLVMStructType::getIdentifiedChecked(
197           [loc] { return emitError(loc); }, loc.getContext(), name);
198       if (succeeded(parser.tryStartCyclicParse(type))) {
199         parser.emitError(
200             greaterLoc,
201             "struct without a body only allowed in a recursive struct");
202         return nullptr;
203       }
204 
205       return type;
206     }
207     if (failed(parser.parseComma()))
208       return LLVMStructType();
209   }
210 
211   // Handle intentionally opaque structs.
212   SMLoc kwLoc = parser.getCurrentLocation();
213   if (succeeded(parser.parseOptionalKeyword("opaque"))) {
214     if (!isIdentified)
215       return parser.emitError(kwLoc, "only identified structs can be opaque"),
216              LLVMStructType();
217     if (failed(parser.parseGreater()))
218       return LLVMStructType();
219     auto type = LLVMStructType::getOpaqueChecked(
220         [loc] { return emitError(loc); }, loc.getContext(), name);
221     if (!type.isOpaque()) {
222       parser.emitError(kwLoc, "redeclaring defined struct as opaque");
223       return LLVMStructType();
224     }
225     return type;
226   }
227 
228   FailureOr<AsmParser::CyclicParseReset> cyclicParse;
229   if (isIdentified) {
230     cyclicParse =
231         parser.tryStartCyclicParse(LLVMStructType::getIdentifiedChecked(
232             [loc] { return emitError(loc); }, loc.getContext(), name));
233     if (failed(cyclicParse)) {
234       parser.emitError(kwLoc,
235                        "identifier already used for an enclosing struct");
236       return nullptr;
237     }
238   }
239 
240   // Check for packedness.
241   bool isPacked = succeeded(parser.parseOptionalKeyword("packed"));
242   if (failed(parser.parseLParen()))
243     return LLVMStructType();
244 
245   // Fast pass for structs with zero subtypes.
246   if (succeeded(parser.parseOptionalRParen())) {
247     if (failed(parser.parseGreater()))
248       return LLVMStructType();
249     if (!isIdentified)
250       return LLVMStructType::getLiteralChecked([loc] { return emitError(loc); },
251                                                loc.getContext(), {}, isPacked);
252     auto type = LLVMStructType::getIdentifiedChecked(
253         [loc] { return emitError(loc); }, loc.getContext(), name);
254     return trySetStructBody(type, {}, isPacked, parser, kwLoc);
255   }
256 
257   // Parse subtypes. For identified structs, put the identifier of the struct on
258   // the stack to support self-references in the recursive calls.
259   SmallVector<Type, 4> subtypes;
260   SMLoc subtypesLoc = parser.getCurrentLocation();
261   do {
262     Type type;
263     if (dispatchParse(parser, type))
264       return LLVMStructType();
265     subtypes.push_back(type);
266   } while (succeeded(parser.parseOptionalComma()));
267 
268   if (parser.parseRParen() || parser.parseGreater())
269     return LLVMStructType();
270 
271   // Construct the struct with body.
272   if (!isIdentified)
273     return LLVMStructType::getLiteralChecked(
274         [loc] { return emitError(loc); }, loc.getContext(), subtypes, isPacked);
275   auto type = LLVMStructType::getIdentifiedChecked(
276       [loc] { return emitError(loc); }, loc.getContext(), name);
277   return trySetStructBody(type, subtypes, isPacked, parser, subtypesLoc);
278 }
279 
280 /// Parses a type appearing inside another LLVM dialect-compatible type. This
281 /// will try to parse any type in full form (including types with the `!llvm`
282 /// prefix), and on failure fall back to parsing the short-hand version of the
283 /// LLVM dialect types without the `!llvm` prefix.
284 static Type dispatchParse(AsmParser &parser, bool allowAny = true) {
285   SMLoc keyLoc = parser.getCurrentLocation();
286 
287   // Try parsing any MLIR type.
288   Type type;
289   OptionalParseResult result = parser.parseOptionalType(type);
290   if (result.has_value()) {
291     if (failed(result.value()))
292       return nullptr;
293     if (!allowAny) {
294       parser.emitError(keyLoc) << "unexpected type, expected keyword";
295       return nullptr;
296     }
297     return type;
298   }
299 
300   // If no type found, fallback to the shorthand form.
301   StringRef key;
302   if (failed(parser.parseKeyword(&key)))
303     return Type();
304 
305   MLIRContext *ctx = parser.getContext();
306   return StringSwitch<function_ref<Type()>>(key)
307       .Case("void", [&] { return LLVMVoidType::get(ctx); })
308       .Case("ppc_fp128", [&] { return LLVMPPCFP128Type::get(ctx); })
309       .Case("token", [&] { return LLVMTokenType::get(ctx); })
310       .Case("label", [&] { return LLVMLabelType::get(ctx); })
311       .Case("metadata", [&] { return LLVMMetadataType::get(ctx); })
312       .Case("func", [&] { return LLVMFunctionType::parse(parser); })
313       .Case("ptr", [&] { return LLVMPointerType::parse(parser); })
314       .Case("vec", [&] { return parseVectorType(parser); })
315       .Case("array", [&] { return LLVMArrayType::parse(parser); })
316       .Case("struct", [&] { return LLVMStructType::parse(parser); })
317       .Case("target", [&] { return LLVMTargetExtType::parse(parser); })
318       .Case("x86_amx", [&] { return LLVMX86AMXType::get(ctx); })
319       .Default([&] {
320         parser.emitError(keyLoc) << "unknown LLVM type: " << key;
321         return Type();
322       })();
323 }
324 
325 /// Helper to use in parse lists.
326 static ParseResult dispatchParse(AsmParser &parser, Type &type) {
327   type = dispatchParse(parser);
328   return success(type != nullptr);
329 }
330 
331 /// Parses one of the LLVM dialect types.
332 Type mlir::LLVM::detail::parseType(DialectAsmParser &parser) {
333   SMLoc loc = parser.getCurrentLocation();
334   Type type = dispatchParse(parser, /*allowAny=*/false);
335   if (!type)
336     return type;
337   if (!isCompatibleOuterType(type)) {
338     parser.emitError(loc) << "unexpected type, expected keyword";
339     return nullptr;
340   }
341   return type;
342 }
343 
344 ParseResult LLVM::parsePrettyLLVMType(AsmParser &p, Type &type) {
345   return dispatchParse(p, type);
346 }
347 
348 void LLVM::printPrettyLLVMType(AsmPrinter &p, Type type) {
349   return dispatchPrint(p, type);
350 }
351