xref: /llvm-project/mlir/test/lib/Dialect/Test/TestTypes.cpp (revision 3c64f86314fbf9a3cd578419f16e621a4de57eaa)
1 //===- TestTypes.cpp - MLIR Test Dialect Types ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains types defined by the TestDialect for testing various
10 // features of MLIR.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "TestTypes.h"
15 #include "TestDialect.h"
16 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/DialectImplementation.h"
19 #include "mlir/IR/ExtensibleDialect.h"
20 #include "mlir/IR/Types.h"
21 #include "llvm/ADT/Hashing.h"
22 #include "llvm/ADT/SetVector.h"
23 #include "llvm/ADT/TypeSwitch.h"
24 #include "llvm/Support/TypeSize.h"
25 #include <optional>
26 
27 using namespace mlir;
28 using namespace test;
29 
30 // Custom parser for SignednessSemantics.
31 static ParseResult
32 parseSignedness(AsmParser &parser,
33                 TestIntegerType::SignednessSemantics &result) {
34   StringRef signStr;
35   auto loc = parser.getCurrentLocation();
36   if (parser.parseKeyword(&signStr))
37     return failure();
38   if (signStr.equals_insensitive("u") || signStr.equals_insensitive("unsigned"))
39     result = TestIntegerType::SignednessSemantics::Unsigned;
40   else if (signStr.equals_insensitive("s") ||
41            signStr.equals_insensitive("signed"))
42     result = TestIntegerType::SignednessSemantics::Signed;
43   else if (signStr.equals_insensitive("n") ||
44            signStr.equals_insensitive("none"))
45     result = TestIntegerType::SignednessSemantics::Signless;
46   else
47     return parser.emitError(loc, "expected signed, unsigned, or none");
48   return success();
49 }
50 
51 // Custom printer for SignednessSemantics.
52 static void printSignedness(AsmPrinter &printer,
53                             const TestIntegerType::SignednessSemantics &ss) {
54   switch (ss) {
55   case TestIntegerType::SignednessSemantics::Unsigned:
56     printer << "unsigned";
57     break;
58   case TestIntegerType::SignednessSemantics::Signed:
59     printer << "signed";
60     break;
61   case TestIntegerType::SignednessSemantics::Signless:
62     printer << "none";
63     break;
64   }
65 }
66 
67 // The functions don't need to be in the header file, but need to be in the mlir
68 // namespace. Declare them here, then define them immediately below. Separating
69 // the declaration and definition adheres to the LLVM coding standards.
70 namespace test {
71 // FieldInfo is used as part of a parameter, so equality comparison is
72 // compulsory.
73 static bool operator==(const FieldInfo &a, const FieldInfo &b);
74 // FieldInfo is used as part of a parameter, so a hash will be computed.
75 static llvm::hash_code hash_value(const FieldInfo &fi); // NOLINT
76 } // namespace test
77 
78 // FieldInfo is used as part of a parameter, so equality comparison is
79 // compulsory.
80 static bool test::operator==(const FieldInfo &a, const FieldInfo &b) {
81   return a.name == b.name && a.type == b.type;
82 }
83 
84 // FieldInfo is used as part of a parameter, so a hash will be computed.
85 static llvm::hash_code test::hash_value(const FieldInfo &fi) { // NOLINT
86   return llvm::hash_combine(fi.name, fi.type);
87 }
88 
89 //===----------------------------------------------------------------------===//
90 // TestCustomType
91 //===----------------------------------------------------------------------===//
92 
93 static ParseResult parseCustomTypeA(AsmParser &parser, int &aResult) {
94   return parser.parseInteger(aResult);
95 }
96 
97 static void printCustomTypeA(AsmPrinter &printer, int a) { printer << a; }
98 
99 static ParseResult parseCustomTypeB(AsmParser &parser, int a,
100                                     std::optional<int> &bResult) {
101   if (a < 0)
102     return success();
103   for (int i : llvm::seq(0, a))
104     if (failed(parser.parseInteger(i)))
105       return failure();
106   bResult.emplace(0);
107   return parser.parseInteger(*bResult);
108 }
109 
110 static void printCustomTypeB(AsmPrinter &printer, int a, std::optional<int> b) {
111   if (a < 0)
112     return;
113   printer << ' ';
114   for (int i : llvm::seq(0, a))
115     printer << i << ' ';
116   printer << *b;
117 }
118 
119 static ParseResult parseFooString(AsmParser &parser, std::string &foo) {
120   std::string result;
121   if (parser.parseString(&result))
122     return failure();
123   foo = std::move(result);
124   return success();
125 }
126 
127 static void printFooString(AsmPrinter &printer, StringRef foo) {
128   printer << '"' << foo << '"';
129 }
130 
131 static ParseResult parseBarString(AsmParser &parser, StringRef foo) {
132   return parser.parseKeyword(foo);
133 }
134 
135 static void printBarString(AsmPrinter &printer, StringRef foo) {
136   printer << foo;
137 }
138 //===----------------------------------------------------------------------===//
139 // Tablegen Generated Definitions
140 //===----------------------------------------------------------------------===//
141 
142 #include "TestTypeInterfaces.cpp.inc"
143 #define GET_TYPEDEF_CLASSES
144 #include "TestTypeDefs.cpp.inc"
145 
146 //===----------------------------------------------------------------------===//
147 // CompoundAType
148 //===----------------------------------------------------------------------===//
149 
150 Type CompoundAType::parse(AsmParser &parser) {
151   int widthOfSomething;
152   Type oneType;
153   SmallVector<int, 4> arrayOfInts;
154   if (parser.parseLess() || parser.parseInteger(widthOfSomething) ||
155       parser.parseComma() || parser.parseType(oneType) || parser.parseComma() ||
156       parser.parseLSquare())
157     return Type();
158 
159   int i;
160   while (!*parser.parseOptionalInteger(i)) {
161     arrayOfInts.push_back(i);
162     if (parser.parseOptionalComma())
163       break;
164   }
165 
166   if (parser.parseRSquare() || parser.parseGreater())
167     return Type();
168 
169   return get(parser.getContext(), widthOfSomething, oneType, arrayOfInts);
170 }
171 void CompoundAType::print(AsmPrinter &printer) const {
172   printer << "<" << getWidthOfSomething() << ", " << getOneType() << ", [";
173   auto intArray = getArrayOfInts();
174   llvm::interleaveComma(intArray, printer);
175   printer << "]>";
176 }
177 
178 //===----------------------------------------------------------------------===//
179 // TestIntegerType
180 //===----------------------------------------------------------------------===//
181 
182 // Example type validity checker.
183 LogicalResult
184 TestIntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
185                         unsigned width,
186                         TestIntegerType::SignednessSemantics ss) {
187   if (width > 8)
188     return failure();
189   return success();
190 }
191 
192 Type TestIntegerType::parse(AsmParser &parser) {
193   SignednessSemantics signedness;
194   int width;
195   if (parser.parseLess() || parseSignedness(parser, signedness) ||
196       parser.parseComma() || parser.parseInteger(width) ||
197       parser.parseGreater())
198     return Type();
199   Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
200   return getChecked(loc, loc.getContext(), width, signedness);
201 }
202 
203 void TestIntegerType::print(AsmPrinter &p) const {
204   p << "<";
205   printSignedness(p, getSignedness());
206   p << ", " << getWidth() << ">";
207 }
208 
209 //===----------------------------------------------------------------------===//
210 // TestStructType
211 //===----------------------------------------------------------------------===//
212 
213 Type StructType::parse(AsmParser &p) {
214   SmallVector<FieldInfo, 4> parameters;
215   if (p.parseLess())
216     return Type();
217   while (succeeded(p.parseOptionalLBrace())) {
218     Type type;
219     StringRef name;
220     if (p.parseKeyword(&name) || p.parseComma() || p.parseType(type) ||
221         p.parseRBrace())
222       return Type();
223     parameters.push_back(FieldInfo{name, type});
224     if (p.parseOptionalComma())
225       break;
226   }
227   if (p.parseGreater())
228     return Type();
229   return get(p.getContext(), parameters);
230 }
231 
232 void StructType::print(AsmPrinter &p) const {
233   p << "<";
234   llvm::interleaveComma(getFields(), p, [&](const FieldInfo &field) {
235     p << "{" << field.name << "," << field.type << "}";
236   });
237   p << ">";
238 }
239 
240 //===----------------------------------------------------------------------===//
241 // TestType
242 //===----------------------------------------------------------------------===//
243 
244 void TestType::printTypeC(Location loc) const {
245   emitRemark(loc) << *this << " - TestC";
246 }
247 
248 //===----------------------------------------------------------------------===//
249 // TestTypeWithLayout
250 //===----------------------------------------------------------------------===//
251 
252 Type TestTypeWithLayoutType::parse(AsmParser &parser) {
253   unsigned val;
254   if (parser.parseLess() || parser.parseInteger(val) || parser.parseGreater())
255     return Type();
256   return TestTypeWithLayoutType::get(parser.getContext(), val);
257 }
258 
259 void TestTypeWithLayoutType::print(AsmPrinter &printer) const {
260   printer << "<" << getKey() << ">";
261 }
262 
263 llvm::TypeSize
264 TestTypeWithLayoutType::getTypeSizeInBits(const DataLayout &dataLayout,
265                                           DataLayoutEntryListRef params) const {
266   return llvm::TypeSize::getFixed(extractKind(params, "size"));
267 }
268 
269 uint64_t
270 TestTypeWithLayoutType::getABIAlignment(const DataLayout &dataLayout,
271                                         DataLayoutEntryListRef params) const {
272   return extractKind(params, "alignment");
273 }
274 
275 uint64_t TestTypeWithLayoutType::getPreferredAlignment(
276     const DataLayout &dataLayout, DataLayoutEntryListRef params) const {
277   return extractKind(params, "preferred");
278 }
279 
280 std::optional<uint64_t>
281 TestTypeWithLayoutType::getIndexBitwidth(const DataLayout &dataLayout,
282                                          DataLayoutEntryListRef params) const {
283   return extractKind(params, "index");
284 }
285 
286 bool TestTypeWithLayoutType::areCompatible(
287     DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout) const {
288   unsigned old = extractKind(oldLayout, "alignment");
289   return old == 1 || extractKind(newLayout, "alignment") <= old;
290 }
291 
292 LogicalResult
293 TestTypeWithLayoutType::verifyEntries(DataLayoutEntryListRef params,
294                                       Location loc) const {
295   for (DataLayoutEntryInterface entry : params) {
296     // This is for testing purposes only, so assert well-formedness.
297     assert(entry.isTypeEntry() && "unexpected identifier entry");
298     assert(
299         llvm::isa<TestTypeWithLayoutType>(llvm::cast<Type>(entry.getKey())) &&
300         "wrong type passed in");
301     auto array = llvm::dyn_cast<ArrayAttr>(entry.getValue());
302     assert(array && array.getValue().size() == 2 &&
303            "expected array of two elements");
304     auto kind = llvm::dyn_cast<StringAttr>(array.getValue().front());
305     (void)kind;
306     assert(kind &&
307            (kind.getValue() == "size" || kind.getValue() == "alignment" ||
308             kind.getValue() == "preferred" || kind.getValue() == "index") &&
309            "unexpected kind");
310     assert(llvm::isa<IntegerAttr>(array.getValue().back()));
311   }
312   return success();
313 }
314 
315 uint64_t TestTypeWithLayoutType::extractKind(DataLayoutEntryListRef params,
316                                              StringRef expectedKind) const {
317   for (DataLayoutEntryInterface entry : params) {
318     ArrayRef<Attribute> pair =
319         llvm::cast<ArrayAttr>(entry.getValue()).getValue();
320     StringRef kind = llvm::cast<StringAttr>(pair.front()).getValue();
321     if (kind == expectedKind)
322       return llvm::cast<IntegerAttr>(pair.back()).getValue().getZExtValue();
323   }
324   return 1;
325 }
326 
327 //===----------------------------------------------------------------------===//
328 // Dynamic Types
329 //===----------------------------------------------------------------------===//
330 
331 /// Define a singleton dynamic type.
332 static std::unique_ptr<DynamicTypeDefinition>
333 getSingletonDynamicType(TestDialect *testDialect) {
334   return DynamicTypeDefinition::get(
335       "dynamic_singleton", testDialect,
336       [](function_ref<InFlightDiagnostic()> emitError,
337          ArrayRef<Attribute> args) {
338         if (!args.empty()) {
339           emitError() << "expected 0 type arguments, but had " << args.size();
340           return failure();
341         }
342         return success();
343       });
344 }
345 
346 /// Define a dynamic type representing a pair.
347 static std::unique_ptr<DynamicTypeDefinition>
348 getPairDynamicType(TestDialect *testDialect) {
349   return DynamicTypeDefinition::get(
350       "dynamic_pair", testDialect,
351       [](function_ref<InFlightDiagnostic()> emitError,
352          ArrayRef<Attribute> args) {
353         if (args.size() != 2) {
354           emitError() << "expected 2 type arguments, but had " << args.size();
355           return failure();
356         }
357         return success();
358       });
359 }
360 
361 static std::unique_ptr<DynamicTypeDefinition>
362 getCustomAssemblyFormatDynamicType(TestDialect *testDialect) {
363   auto verifier = [](function_ref<InFlightDiagnostic()> emitError,
364                      ArrayRef<Attribute> args) {
365     if (args.size() != 2) {
366       emitError() << "expected 2 type arguments, but had " << args.size();
367       return failure();
368     }
369     return success();
370   };
371 
372   auto parser = [](AsmParser &parser,
373                    llvm::SmallVectorImpl<Attribute> &parsedParams) {
374     Attribute leftAttr, rightAttr;
375     if (parser.parseLess() || parser.parseAttribute(leftAttr) ||
376         parser.parseColon() || parser.parseAttribute(rightAttr) ||
377         parser.parseGreater())
378       return failure();
379     parsedParams.push_back(leftAttr);
380     parsedParams.push_back(rightAttr);
381     return success();
382   };
383 
384   auto printer = [](AsmPrinter &printer, ArrayRef<Attribute> params) {
385     printer << "<" << params[0] << ":" << params[1] << ">";
386   };
387 
388   return DynamicTypeDefinition::get("dynamic_custom_assembly_format",
389                                     testDialect, std::move(verifier),
390                                     std::move(parser), std::move(printer));
391 }
392 
393 //===----------------------------------------------------------------------===//
394 // TestDialect
395 //===----------------------------------------------------------------------===//
396 
397 namespace {
398 
399 struct PtrElementModel
400     : public LLVM::PointerElementTypeInterface::ExternalModel<PtrElementModel,
401                                                               SimpleAType> {};
402 } // namespace
403 
404 void TestDialect::registerTypes() {
405   addTypes<TestRecursiveType,
406 #define GET_TYPEDEF_LIST
407 #include "TestTypeDefs.cpp.inc"
408            >();
409   SimpleAType::attachInterface<PtrElementModel>(*getContext());
410 
411   registerDynamicType(getSingletonDynamicType(this));
412   registerDynamicType(getPairDynamicType(this));
413   registerDynamicType(getCustomAssemblyFormatDynamicType(this));
414 }
415 
416 Type TestDialect::parseType(DialectAsmParser &parser) const {
417   StringRef typeTag;
418   {
419     Type genType;
420     auto parseResult = generatedTypeParser(parser, &typeTag, genType);
421     if (parseResult.has_value())
422       return genType;
423   }
424 
425   {
426     Type dynType;
427     auto parseResult = parseOptionalDynamicType(typeTag, parser, dynType);
428     if (parseResult.has_value()) {
429       if (succeeded(parseResult.value()))
430         return dynType;
431       return Type();
432     }
433   }
434 
435   if (typeTag != "test_rec") {
436     parser.emitError(parser.getNameLoc()) << "unknown type!";
437     return Type();
438   }
439 
440   StringRef name;
441   if (parser.parseLess() || parser.parseKeyword(&name))
442     return Type();
443   auto rec = TestRecursiveType::get(parser.getContext(), name);
444 
445   FailureOr<AsmParser::CyclicParseReset> cyclicParse =
446       parser.tryStartCyclicParse(rec);
447 
448   // If this type already has been parsed above in the stack, expect just the
449   // name.
450   if (failed(cyclicParse)) {
451     if (failed(parser.parseGreater()))
452       return Type();
453     return rec;
454   }
455 
456   // Otherwise, parse the body and update the type.
457   if (failed(parser.parseComma()))
458     return Type();
459   Type subtype = parseType(parser);
460   if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype)))
461     return Type();
462 
463   return rec;
464 }
465 
466 void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
467   if (succeeded(generatedTypePrinter(type, printer)))
468     return;
469 
470   if (succeeded(printIfDynamicType(type, printer)))
471     return;
472 
473   auto rec = llvm::cast<TestRecursiveType>(type);
474 
475   FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint =
476       printer.tryStartCyclicPrint(rec);
477 
478   printer << "test_rec<" << rec.getName();
479   if (succeeded(cyclicPrint)) {
480     printer << ", ";
481     printType(rec.getBody(), printer);
482   }
483   printer << ">";
484 }
485 
486 Type TestRecursiveAliasType::getBody() const { return getImpl()->body; }
487 
488 void TestRecursiveAliasType::setBody(Type type) { (void)Base::mutate(type); }
489 
490 StringRef TestRecursiveAliasType::getName() const { return getImpl()->name; }
491 
492 Type TestRecursiveAliasType::parse(AsmParser &parser) {
493   StringRef name;
494   if (parser.parseLess() || parser.parseKeyword(&name))
495     return Type();
496   auto rec = TestRecursiveAliasType::get(parser.getContext(), name);
497 
498   FailureOr<AsmParser::CyclicParseReset> cyclicParse =
499       parser.tryStartCyclicParse(rec);
500 
501   // If this type already has been parsed above in the stack, expect just the
502   // name.
503   if (failed(cyclicParse)) {
504     if (failed(parser.parseGreater()))
505       return Type();
506     return rec;
507   }
508 
509   // Otherwise, parse the body and update the type.
510   if (failed(parser.parseComma()))
511     return Type();
512   Type subtype;
513   if (parser.parseType(subtype))
514     return nullptr;
515   if (!subtype || failed(parser.parseGreater()))
516     return Type();
517 
518   rec.setBody(subtype);
519 
520   return rec;
521 }
522 
523 void TestRecursiveAliasType::print(AsmPrinter &printer) const {
524 
525   FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint =
526       printer.tryStartCyclicPrint(*this);
527 
528   printer << "<" << getName();
529   if (succeeded(cyclicPrint)) {
530     printer << ", ";
531     printer << getBody();
532   }
533   printer << ">";
534 }
535 
536 void TestTypeOpAsmTypeInterfaceType::getAsmName(
537     OpAsmSetNameFn setNameFn) const {
538   setNameFn("op_asm_type_interface");
539 }
540