xref: /llvm-project/mlir/test/lib/Dialect/Test/TestAttributes.cpp (revision 690dc4eff19c85d0afaa9e189cf7e40fe3d1ff76)
1 //===- TestAttributes.cpp - MLIR Test Dialect Attributes --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains attributes defined by the TestDialect for testing various
10 // features of MLIR.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "TestAttributes.h"
15 #include "TestDialect.h"
16 #include "TestTypes.h"
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/DialectImplementation.h"
20 #include "mlir/IR/ExtensibleDialect.h"
21 #include "mlir/IR/OpImplementation.h"
22 #include "mlir/IR/Types.h"
23 #include "llvm/ADT/APFloat.h"
24 #include "llvm/ADT/Hashing.h"
25 #include "llvm/ADT/StringExtras.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include "llvm/ADT/bit.h"
28 #include "llvm/Support/ErrorHandling.h"
29 #include "llvm/Support/raw_ostream.h"
30 
31 using namespace mlir;
32 using namespace test;
33 
34 //===----------------------------------------------------------------------===//
35 // CompoundAAttr
36 //===----------------------------------------------------------------------===//
37 
38 Attribute CompoundAAttr::parse(AsmParser &parser, Type type) {
39   int widthOfSomething;
40   Type oneType;
41   SmallVector<int, 4> arrayOfInts;
42   if (parser.parseLess() || parser.parseInteger(widthOfSomething) ||
43       parser.parseComma() || parser.parseType(oneType) || parser.parseComma() ||
44       parser.parseLSquare())
45     return Attribute();
46 
47   int intVal;
48   while (!*parser.parseOptionalInteger(intVal)) {
49     arrayOfInts.push_back(intVal);
50     if (parser.parseOptionalComma())
51       break;
52   }
53 
54   if (parser.parseRSquare() || parser.parseGreater())
55     return Attribute();
56   return get(parser.getContext(), widthOfSomething, oneType, arrayOfInts);
57 }
58 
59 void CompoundAAttr::print(AsmPrinter &printer) const {
60   printer << "<" << getWidthOfSomething() << ", " << getOneType() << ", [";
61   llvm::interleaveComma(getArrayOfInts(), printer);
62   printer << "]>";
63 }
64 
65 //===----------------------------------------------------------------------===//
// TestDecimalShapeAttr
67 //===----------------------------------------------------------------------===//
68 
69 Attribute TestDecimalShapeAttr::parse(AsmParser &parser, Type type) {
70   if (parser.parseLess()){
71     return Attribute();
72   }
73   SmallVector<int64_t> shape;
74   if (parser.parseOptionalGreater()) {
75     auto parseDecimal = [&]() {
76       shape.emplace_back();
77       auto parseResult = parser.parseOptionalDecimalInteger(shape.back());
78       if (!parseResult.has_value() || failed(*parseResult)) {
79         parser.emitError(parser.getCurrentLocation()) << "expected an integer";
80         return failure();
81       }
82       return success();
83     };
84     if (failed(parseDecimal())) {
85       return Attribute();
86     }
87     while (failed(parser.parseOptionalGreater())) {
88       if (failed(parser.parseXInDimensionList()) || failed(parseDecimal())) {
89         return Attribute();
90       }
91     }
92   }
93   return get(parser.getContext(), shape);
94 }
95 
96 void TestDecimalShapeAttr::print(AsmPrinter &printer) const {
97   printer << "<";
98   llvm::interleave(getShape(), printer, "x");
99   printer << ">";
100 }
101 
102 Attribute TestI64ElementsAttr::parse(AsmParser &parser, Type type) {
103   SmallVector<uint64_t> elements;
104   if (parser.parseLess() || parser.parseLSquare())
105     return Attribute();
106   uint64_t intVal;
107   while (succeeded(*parser.parseOptionalInteger(intVal))) {
108     elements.push_back(intVal);
109     if (parser.parseOptionalComma())
110       break;
111   }
112 
113   if (parser.parseRSquare() || parser.parseGreater())
114     return Attribute();
115   return parser.getChecked<TestI64ElementsAttr>(
116       parser.getContext(), llvm::cast<ShapedType>(type), elements);
117 }
118 
119 void TestI64ElementsAttr::print(AsmPrinter &printer) const {
120   printer << "<[";
121   llvm::interleaveComma(getElements(), printer);
122   printer << "]>";
123 }
124 
125 LogicalResult
126 TestI64ElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
127                             ShapedType type, ArrayRef<uint64_t> elements) {
128   if (type.getNumElements() != static_cast<int64_t>(elements.size())) {
129     return emitError()
130            << "number of elements does not match the provided shape type, got: "
131            << elements.size() << ", but expected: " << type.getNumElements();
132   }
133   if (type.getRank() != 1 || !type.getElementType().isSignlessInteger(64))
134     return emitError() << "expected single rank 64-bit shape type, but got: "
135                        << type;
136   return success();
137 }
138 
139 LogicalResult TestAttrWithFormatAttr::verify(
140     function_ref<InFlightDiagnostic()> emitError, int64_t one, std::string two,
141     IntegerAttr three, ArrayRef<int> four, uint64_t five, ArrayRef<int> six,
142     ArrayRef<AttrWithTypeBuilderAttr> arrayOfAttrs) {
143   if (four.size() != static_cast<unsigned>(one))
144     return emitError() << "expected 'one' to equal 'four.size()'";
145   return success();
146 }
147 
148 //===----------------------------------------------------------------------===//
149 // Utility Functions for Generated Attributes
150 //===----------------------------------------------------------------------===//
151 
152 static FailureOr<SmallVector<int>> parseIntArray(AsmParser &parser) {
153   SmallVector<int> ints;
154   if (parser.parseLSquare() || parser.parseCommaSeparatedList([&]() {
155         ints.push_back(0);
156         return parser.parseInteger(ints.back());
157       }) ||
158       parser.parseRSquare())
159     return failure();
160   return ints;
161 }
162 
163 static void printIntArray(AsmPrinter &printer, ArrayRef<int> ints) {
164   printer << '[';
165   llvm::interleaveComma(ints, printer);
166   printer << ']';
167 }
168 
169 //===----------------------------------------------------------------------===//
170 // TestSubElementsAccessAttr
171 //===----------------------------------------------------------------------===//
172 
173 Attribute TestSubElementsAccessAttr::parse(::mlir::AsmParser &parser,
174                                            ::mlir::Type type) {
175   Attribute first, second, third;
176   if (parser.parseLess() || parser.parseAttribute(first) ||
177       parser.parseComma() || parser.parseAttribute(second) ||
178       parser.parseComma() || parser.parseAttribute(third) ||
179       parser.parseGreater()) {
180     return {};
181   }
182   return get(parser.getContext(), first, second, third);
183 }
184 
185 void TestSubElementsAccessAttr::print(::mlir::AsmPrinter &printer) const {
186   printer << "<" << getFirst() << ", " << getSecond() << ", " << getThird()
187           << ">";
188 }
189 
190 //===----------------------------------------------------------------------===//
191 // TestExtern1DI64ElementsAttr
192 //===----------------------------------------------------------------------===//
193 
194 ArrayRef<uint64_t> TestExtern1DI64ElementsAttr::getElements() const {
195   if (auto *blob = getHandle().getBlob())
196     return blob->getDataAs<uint64_t>();
197   return std::nullopt;
198 }
199 
200 //===----------------------------------------------------------------------===//
201 // TestCustomAnchorAttr
202 //===----------------------------------------------------------------------===//
203 
204 static ParseResult parseTrueFalse(AsmParser &p, std::optional<int> &result) {
205   bool b;
206   if (p.parseInteger(b))
207     return failure();
208   result = b;
209   return success();
210 }
211 
212 static void printTrueFalse(AsmPrinter &p, std::optional<int> result) {
213   p << (*result ? "true" : "false");
214 }
215 
216 //===----------------------------------------------------------------------===//
217 // CopyCountAttr Implementation
218 //===----------------------------------------------------------------------===//
219 
220 CopyCount::CopyCount(const CopyCount &rhs) : value(rhs.value) {
221   CopyCount::counter++;
222 }
223 
224 CopyCount &CopyCount::operator=(const CopyCount &rhs) {
225   CopyCount::counter++;
226   value = rhs.value;
227   return *this;
228 }
229 
230 int CopyCount::counter;
231 
232 static bool operator==(const test::CopyCount &lhs, const test::CopyCount &rhs) {
233   return lhs.value == rhs.value;
234 }
235 
236 llvm::raw_ostream &test::operator<<(llvm::raw_ostream &os,
237                                     const test::CopyCount &value) {
238   return os << value.value;
239 }
240 
241 template <>
242 struct mlir::FieldParser<test::CopyCount> {
243   static FailureOr<test::CopyCount> parse(AsmParser &parser) {
244     std::string value;
245     if (parser.parseKeyword(value))
246       return failure();
247     return test::CopyCount(value);
248   }
249 };
250 namespace test {
251 llvm::hash_code hash_value(const test::CopyCount &copyCount) {
252   return llvm::hash_value(copyCount.value);
253 }
254 } // namespace test
255 
256 //===----------------------------------------------------------------------===//
257 // TestConditionalAliasAttr
258 //===----------------------------------------------------------------------===//
259 
260 /// Attempt to parse the conditionally-aliased string attribute as a keyword or
261 /// string, else try to parse an alias.
262 static ParseResult parseConditionalAlias(AsmParser &p, StringAttr &value) {
263   std::string str;
264   if (succeeded(p.parseOptionalKeywordOrString(&str))) {
265     value = StringAttr::get(p.getContext(), str);
266     return success();
267   }
268   return p.parseAttribute(value);
269 }
270 
271 /// Print the string attribute as an alias if it has one, otherwise print it as
272 /// a keyword if possible.
273 static void printConditionalAlias(AsmPrinter &p, StringAttr value) {
274   if (succeeded(p.printAlias(value)))
275     return;
276   p.printKeywordOrString(value);
277 }
278 
279 //===----------------------------------------------------------------------===//
280 // Custom Float Attribute
281 //===----------------------------------------------------------------------===//
282 
283 static void printCustomFloatAttr(AsmPrinter &p, StringAttr typeStrAttr,
284                                  APFloat value) {
285   p << typeStrAttr << " : " << value;
286 }
287 
288 static ParseResult parseCustomFloatAttr(AsmParser &p, StringAttr &typeStrAttr,
289                                         FailureOr<APFloat> &value) {
290 
291   std::string str;
292   if (p.parseString(&str))
293     return failure();
294 
295   typeStrAttr = StringAttr::get(p.getContext(), str);
296 
297   if (p.parseColon())
298     return failure();
299 
300   const llvm::fltSemantics *semantics;
301   if (str == "float")
302     semantics = &llvm::APFloat::IEEEsingle();
303   else if (str == "double")
304     semantics = &llvm::APFloat::IEEEdouble();
305   else if (str == "fp80")
306     semantics = &llvm::APFloat::x87DoubleExtended();
307   else
308     return p.emitError(p.getCurrentLocation(), "unknown float type, expected "
309                                                "'float', 'double' or 'fp80'");
310 
311   APFloat parsedValue(0.0);
312   if (p.parseFloat(*semantics, parsedValue))
313     return failure();
314 
315   value.emplace(parsedValue);
316   return success();
317 }
318 
319 //===----------------------------------------------------------------------===//
320 // Tablegen Generated Definitions
321 //===----------------------------------------------------------------------===//
322 
323 #include "TestAttrInterfaces.cpp.inc"
324 #include "TestOpEnums.cpp.inc"
325 #define GET_ATTRDEF_CLASSES
326 #include "TestAttrDefs.cpp.inc"
327 
328 //===----------------------------------------------------------------------===//
329 // Dynamic Attributes
330 //===----------------------------------------------------------------------===//
331 
332 /// Define a singleton dynamic attribute.
333 static std::unique_ptr<DynamicAttrDefinition>
334 getDynamicSingletonAttr(TestDialect *testDialect) {
335   return DynamicAttrDefinition::get(
336       "dynamic_singleton", testDialect,
337       [](function_ref<InFlightDiagnostic()> emitError,
338          ArrayRef<Attribute> args) {
339         if (!args.empty()) {
340           emitError() << "expected 0 attribute arguments, but had "
341                       << args.size();
342           return failure();
343         }
344         return success();
345       });
346 }
347 
348 /// Define a dynamic attribute representing a pair or attributes.
349 static std::unique_ptr<DynamicAttrDefinition>
350 getDynamicPairAttr(TestDialect *testDialect) {
351   return DynamicAttrDefinition::get(
352       "dynamic_pair", testDialect,
353       [](function_ref<InFlightDiagnostic()> emitError,
354          ArrayRef<Attribute> args) {
355         if (args.size() != 2) {
356           emitError() << "expected 2 attribute arguments, but had "
357                       << args.size();
358           return failure();
359         }
360         return success();
361       });
362 }
363 
364 static std::unique_ptr<DynamicAttrDefinition>
365 getDynamicCustomAssemblyFormatAttr(TestDialect *testDialect) {
366   auto verifier = [](function_ref<InFlightDiagnostic()> emitError,
367                      ArrayRef<Attribute> args) {
368     if (args.size() != 2) {
369       emitError() << "expected 2 attribute arguments, but had " << args.size();
370       return failure();
371     }
372     return success();
373   };
374 
375   auto parser = [](AsmParser &parser,
376                    llvm::SmallVectorImpl<Attribute> &parsedParams) {
377     Attribute leftAttr, rightAttr;
378     if (parser.parseLess() || parser.parseAttribute(leftAttr) ||
379         parser.parseColon() || parser.parseAttribute(rightAttr) ||
380         parser.parseGreater())
381       return failure();
382     parsedParams.push_back(leftAttr);
383     parsedParams.push_back(rightAttr);
384     return success();
385   };
386 
387   auto printer = [](AsmPrinter &printer, ArrayRef<Attribute> params) {
388     printer << "<" << params[0] << ":" << params[1] << ">";
389   };
390 
391   return DynamicAttrDefinition::get("dynamic_custom_assembly_format",
392                                     testDialect, std::move(verifier),
393                                     std::move(parser), std::move(printer));
394 }
395 
396 //===----------------------------------------------------------------------===//
397 // TestDialect
398 //===----------------------------------------------------------------------===//
399 
/// Registers the test dialect's attributes: the ODS-generated attribute
/// classes selected by `GET_ATTRDEF_LIST` from TestAttrDefs.cpp.inc, followed
/// by the dynamic attributes `dynamic_singleton`, `dynamic_pair` and
/// `dynamic_custom_assembly_format`.
void TestDialect::registerAttributes() {
  // The included list presumably expands to the comma-separated generated
  // attribute class names used as `addAttributes` template arguments.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "TestAttrDefs.cpp.inc"
      >();
  registerDynamicAttr(getDynamicSingletonAttr(this));
  registerDynamicAttr(getDynamicPairAttr(this));
  registerDynamicAttr(getDynamicCustomAssemblyFormatAttr(this));
}
409