xref: /llvm-project/mlir/test/lib/Dialect/Test/TestFormatUtils.cpp (revision 8955e285e1ac48bfcd9e030a055e66aec37785cc)
1 //===- TestFormatUtils.cpp - MLIR Test Dialect Assembly Format Utilities --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestFormatUtils.h"
10 #include "mlir/IR/Builders.h"
11 
12 using namespace mlir;
13 using namespace test;
14 
15 //===----------------------------------------------------------------------===//
16 // CustomDirectiveOperands
17 //===----------------------------------------------------------------------===//
18 
19 ParseResult test::parseCustomDirectiveOperands(
20     OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
21     std::optional<OpAsmParser::UnresolvedOperand> &optOperand,
22     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands) {
23   if (parser.parseOperand(operand))
24     return failure();
25   if (succeeded(parser.parseOptionalComma())) {
26     optOperand.emplace();
27     if (parser.parseOperand(*optOperand))
28       return failure();
29   }
30   if (parser.parseArrow() || parser.parseLParen() ||
31       parser.parseOperandList(varOperands) || parser.parseRParen())
32     return failure();
33   return success();
34 }
35 
36 void test::printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
37                                         Value operand, Value optOperand,
38                                         OperandRange varOperands) {
39   printer << operand;
40   if (optOperand)
41     printer << ", " << optOperand;
42   printer << " -> (" << varOperands << ")";
43 }
44 
45 //===----------------------------------------------------------------------===//
46 // CustomDirectiveResults
47 //===----------------------------------------------------------------------===//
48 
49 ParseResult
50 test::parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
51                                   Type &optOperandType,
52                                   SmallVectorImpl<Type> &varOperandTypes) {
53   if (parser.parseColon())
54     return failure();
55 
56   if (parser.parseType(operandType))
57     return failure();
58   if (succeeded(parser.parseOptionalComma()))
59     if (parser.parseType(optOperandType))
60       return failure();
61   if (parser.parseArrow() || parser.parseLParen() ||
62       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
63     return failure();
64   return success();
65 }
66 
67 void test::printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
68                                        Type operandType, Type optOperandType,
69                                        TypeRange varOperandTypes) {
70   printer << " : " << operandType;
71   if (optOperandType)
72     printer << ", " << optOperandType;
73   printer << " -> (" << varOperandTypes << ")";
74 }
75 
76 //===----------------------------------------------------------------------===//
77 // CustomDirectiveWithTypeRefs
78 //===----------------------------------------------------------------------===//
79 
80 ParseResult test::parseCustomDirectiveWithTypeRefs(
81     OpAsmParser &parser, Type operandType, Type optOperandType,
82     const SmallVectorImpl<Type> &varOperandTypes) {
83   if (parser.parseKeyword("type_refs_capture"))
84     return failure();
85 
86   Type operandType2, optOperandType2;
87   SmallVector<Type, 1> varOperandTypes2;
88   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
89                                   varOperandTypes2))
90     return failure();
91 
92   if (operandType != operandType2 || optOperandType != optOperandType2 ||
93       varOperandTypes != varOperandTypes2)
94     return failure();
95 
96   return success();
97 }
98 
99 void test::printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
100                                             Operation *op, Type operandType,
101                                             Type optOperandType,
102                                             TypeRange varOperandTypes) {
103   printer << " type_refs_capture ";
104   printCustomDirectiveResults(printer, op, operandType, optOperandType,
105                               varOperandTypes);
106 }
107 
108 //===----------------------------------------------------------------------===//
109 // CustomDirectiveOperandsAndTypes
110 //===----------------------------------------------------------------------===//
111 
112 ParseResult test::parseCustomDirectiveOperandsAndTypes(
113     OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
114     std::optional<OpAsmParser::UnresolvedOperand> &optOperand,
115     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands,
116     Type &operandType, Type &optOperandType,
117     SmallVectorImpl<Type> &varOperandTypes) {
118   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
119       parseCustomDirectiveResults(parser, operandType, optOperandType,
120                                   varOperandTypes))
121     return failure();
122   return success();
123 }
124 
125 void test::printCustomDirectiveOperandsAndTypes(
126     OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
127     OperandRange varOperands, Type operandType, Type optOperandType,
128     TypeRange varOperandTypes) {
129   printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
130   printCustomDirectiveResults(printer, op, operandType, optOperandType,
131                               varOperandTypes);
132 }
133 
134 //===----------------------------------------------------------------------===//
135 // CustomDirectiveRegions
136 //===----------------------------------------------------------------------===//
137 
138 ParseResult test::parseCustomDirectiveRegions(
139     OpAsmParser &parser, Region &region,
140     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
141   if (parser.parseRegion(region))
142     return failure();
143   if (failed(parser.parseOptionalComma()))
144     return success();
145   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
146   if (parser.parseRegion(*varRegion))
147     return failure();
148   varRegions.emplace_back(std::move(varRegion));
149   return success();
150 }
151 
152 void test::printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
153                                        Region &region,
154                                        MutableArrayRef<Region> varRegions) {
155   printer.printRegion(region);
156   if (!varRegions.empty()) {
157     printer << ", ";
158     for (Region &region : varRegions)
159       printer.printRegion(region);
160   }
161 }
162 
163 //===----------------------------------------------------------------------===//
164 // CustomDirectiveSuccessors
165 //===----------------------------------------------------------------------===//
166 
167 ParseResult
168 test::parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
169                                      SmallVectorImpl<Block *> &varSuccessors) {
170   if (parser.parseSuccessor(successor))
171     return failure();
172   if (failed(parser.parseOptionalComma()))
173     return success();
174   Block *varSuccessor;
175   if (parser.parseSuccessor(varSuccessor))
176     return failure();
177   varSuccessors.append(2, varSuccessor);
178   return success();
179 }
180 
181 void test::printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
182                                           Block *successor,
183                                           SuccessorRange varSuccessors) {
184   printer << successor;
185   if (!varSuccessors.empty())
186     printer << ", " << varSuccessors.front();
187 }
188 
189 //===----------------------------------------------------------------------===//
190 // CustomDirectiveAttributes
191 //===----------------------------------------------------------------------===//
192 
193 ParseResult test::parseCustomDirectiveAttributes(OpAsmParser &parser,
194                                                  IntegerAttr &attr,
195                                                  IntegerAttr &optAttr) {
196   if (parser.parseAttribute(attr))
197     return failure();
198   if (succeeded(parser.parseOptionalComma())) {
199     if (parser.parseAttribute(optAttr))
200       return failure();
201   }
202   return success();
203 }
204 
205 void test::printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
206                                           Attribute attribute,
207                                           Attribute optAttribute) {
208   printer << attribute;
209   if (optAttribute)
210     printer << ", " << optAttribute;
211 }
212 
213 //===----------------------------------------------------------------------===//
214 // CustomDirectiveAttrDict
215 //===----------------------------------------------------------------------===//
216 
217 ParseResult test::parseCustomDirectiveAttrDict(OpAsmParser &parser,
218                                                NamedAttrList &attrs) {
219   return parser.parseOptionalAttrDict(attrs);
220 }
221 
222 void test::printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
223                                         DictionaryAttr attrs) {
224   printer.printOptionalAttrDict(attrs.getValue());
225 }
226 
227 //===----------------------------------------------------------------------===//
228 // CustomDirectiveOptionalOperandRef
229 //===----------------------------------------------------------------------===//
230 
231 ParseResult test::parseCustomDirectiveOptionalOperandRef(
232     OpAsmParser &parser,
233     std::optional<OpAsmParser::UnresolvedOperand> &optOperand) {
234   int64_t operandCount = 0;
235   if (parser.parseInteger(operandCount))
236     return failure();
237   bool expectedOptionalOperand = operandCount == 0;
238   return success(expectedOptionalOperand != !!optOperand);
239 }
240 
241 void test::printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
242                                                   Operation *op,
243                                                   Value optOperand) {
244   printer << (optOperand ? "1" : "0");
245 }
246 
247 //===----------------------------------------------------------------------===//
248 // CustomDirectiveOptionalOperand
249 //===----------------------------------------------------------------------===//
250 
251 ParseResult test::parseCustomOptionalOperand(
252     OpAsmParser &parser,
253     std::optional<OpAsmParser::UnresolvedOperand> &optOperand) {
254   if (succeeded(parser.parseOptionalLParen())) {
255     optOperand.emplace();
256     if (parser.parseOperand(*optOperand) || parser.parseRParen())
257       return failure();
258   }
259   return success();
260 }
261 
262 void test::printCustomOptionalOperand(OpAsmPrinter &printer, Operation *,
263                                       Value optOperand) {
264   if (optOperand)
265     printer << "(" << optOperand << ") ";
266 }
267 
268 //===----------------------------------------------------------------------===//
269 // CustomDirectiveSwitchCases
270 //===----------------------------------------------------------------------===//
271 
272 ParseResult
273 test::parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
274                        SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
275   SmallVector<int64_t> caseValues;
276   while (succeeded(p.parseOptionalKeyword("case"))) {
277     int64_t value;
278     Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
279     if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
280       return failure();
281     caseValues.push_back(value);
282   }
283   cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
284   return success();
285 }
286 
287 void test::printSwitchCases(OpAsmPrinter &p, Operation *op,
288                             DenseI64ArrayAttr cases, RegionRange caseRegions) {
289   for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
290     p.printNewline();
291     p << "case " << value << ' ';
292     p.printRegion(*region, /*printEntryBlockArgs=*/false);
293   }
294 }
295 
296 //===----------------------------------------------------------------------===//
297 // CustomUsingPropertyInCustom
298 //===----------------------------------------------------------------------===//
299 
300 bool test::parseUsingPropertyInCustom(OpAsmParser &parser,
301                                       SmallVector<int64_t> &value) {
302   auto elemParser = [&]() {
303     int64_t v = 0;
304     if (failed(parser.parseInteger(v)))
305       return failure();
306     value.push_back(v);
307     return success();
308   };
309   return failed(parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Square,
310                                                elemParser));
311 }
312 
313 void test::printUsingPropertyInCustom(OpAsmPrinter &printer, Operation *op,
314                                       ArrayRef<int64_t> value) {
315   printer << '[' << value << ']';
316 }
317 
318 //===----------------------------------------------------------------------===//
319 // CustomDirectiveIntProperty
320 //===----------------------------------------------------------------------===//
321 
322 bool test::parseIntProperty(OpAsmParser &parser, int64_t &value) {
323   return failed(parser.parseInteger(value));
324 }
325 
326 void test::printIntProperty(OpAsmPrinter &printer, Operation *op,
327                             int64_t value) {
328   printer << value;
329 }
330 
331 //===----------------------------------------------------------------------===//
332 // CustomDirectiveSumProperty
333 //===----------------------------------------------------------------------===//
334 
335 bool test::parseSumProperty(OpAsmParser &parser, int64_t &second,
336                             int64_t first) {
337   int64_t sum;
338   auto loc = parser.getCurrentLocation();
339   if (parser.parseInteger(second) || parser.parseEqual() ||
340       parser.parseInteger(sum))
341     return true;
342   if (sum != second + first) {
343     parser.emitError(loc, "Expected sum to equal first + second");
344     return true;
345   }
346   return false;
347 }
348 
349 void test::printSumProperty(OpAsmPrinter &printer, Operation *op,
350                             int64_t second, int64_t first) {
351   printer << second << " = " << (second + first);
352 }
353 
354 //===----------------------------------------------------------------------===//
355 // CustomDirectiveOptionalCustomParser
356 //===----------------------------------------------------------------------===//
357 
358 OptionalParseResult test::parseOptionalCustomParser(AsmParser &p,
359                                                     IntegerAttr &result) {
360   if (succeeded(p.parseOptionalKeyword("foo")))
361     return p.parseAttribute(result);
362   return {};
363 }
364 
365 void test::printOptionalCustomParser(AsmPrinter &p, Operation *,
366                                      IntegerAttr result) {
367   p << "foo ";
368   p.printAttribute(result);
369 }
370 
371 //===----------------------------------------------------------------------===//
372 // CustomDirectiveAttrElideType
373 //===----------------------------------------------------------------------===//
374 
375 ParseResult test::parseAttrElideType(AsmParser &parser, TypeAttr type,
376                                      Attribute &attr) {
377   return parser.parseAttribute(attr, type.getValue());
378 }
379 
380 void test::printAttrElideType(AsmPrinter &printer, Operation *op, TypeAttr type,
381                               Attribute attr) {
382   printer.printAttributeWithoutType(attr);
383 }
384