xref: /llvm-project/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
1 //===- TestOpsSyntax.cpp - Operations for testing syntax ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestOpsSyntax.h"
10 #include "TestDialect.h"
11 #include "TestOps.h"
12 #include "mlir/IR/OpImplementation.h"
13 #include "llvm/Support/Base64.h"
14 
15 using namespace mlir;
16 using namespace test;
17 
18 //===----------------------------------------------------------------------===//
19 // Test Format* operations
20 //===----------------------------------------------------------------------===//
21 
22 //===----------------------------------------------------------------------===//
23 // Parsing
24 
parseCustomOptionalOperand(OpAsmParser & parser,std::optional<OpAsmParser::UnresolvedOperand> & optOperand)25 static ParseResult parseCustomOptionalOperand(
26     OpAsmParser &parser,
27     std::optional<OpAsmParser::UnresolvedOperand> &optOperand) {
28   if (succeeded(parser.parseOptionalLParen())) {
29     optOperand.emplace();
30     if (parser.parseOperand(*optOperand) || parser.parseRParen())
31       return failure();
32   }
33   return success();
34 }
35 
parseCustomDirectiveOperands(OpAsmParser & parser,OpAsmParser::UnresolvedOperand & operand,std::optional<OpAsmParser::UnresolvedOperand> & optOperand,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & varOperands)36 static ParseResult parseCustomDirectiveOperands(
37     OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
38     std::optional<OpAsmParser::UnresolvedOperand> &optOperand,
39     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands) {
40   if (parser.parseOperand(operand))
41     return failure();
42   if (succeeded(parser.parseOptionalComma())) {
43     optOperand.emplace();
44     if (parser.parseOperand(*optOperand))
45       return failure();
46   }
47   if (parser.parseArrow() || parser.parseLParen() ||
48       parser.parseOperandList(varOperands) || parser.parseRParen())
49     return failure();
50   return success();
51 }
52 static ParseResult
parseCustomDirectiveResults(OpAsmParser & parser,Type & operandType,Type & optOperandType,SmallVectorImpl<Type> & varOperandTypes)53 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
54                             Type &optOperandType,
55                             SmallVectorImpl<Type> &varOperandTypes) {
56   if (parser.parseColon())
57     return failure();
58 
59   if (parser.parseType(operandType))
60     return failure();
61   if (succeeded(parser.parseOptionalComma())) {
62     if (parser.parseType(optOperandType))
63       return failure();
64   }
65   if (parser.parseArrow() || parser.parseLParen() ||
66       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
67     return failure();
68   return success();
69 }
70 static ParseResult
parseCustomDirectiveWithTypeRefs(OpAsmParser & parser,Type operandType,Type optOperandType,const SmallVectorImpl<Type> & varOperandTypes)71 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
72                                  Type optOperandType,
73                                  const SmallVectorImpl<Type> &varOperandTypes) {
74   if (parser.parseKeyword("type_refs_capture"))
75     return failure();
76 
77   Type operandType2, optOperandType2;
78   SmallVector<Type, 1> varOperandTypes2;
79   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
80                                   varOperandTypes2))
81     return failure();
82 
83   if (operandType != operandType2 || optOperandType != optOperandType2 ||
84       varOperandTypes != varOperandTypes2)
85     return failure();
86 
87   return success();
88 }
parseCustomDirectiveOperandsAndTypes(OpAsmParser & parser,OpAsmParser::UnresolvedOperand & operand,std::optional<OpAsmParser::UnresolvedOperand> & optOperand,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & varOperands,Type & operandType,Type & optOperandType,SmallVectorImpl<Type> & varOperandTypes)89 static ParseResult parseCustomDirectiveOperandsAndTypes(
90     OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
91     std::optional<OpAsmParser::UnresolvedOperand> &optOperand,
92     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands,
93     Type &operandType, Type &optOperandType,
94     SmallVectorImpl<Type> &varOperandTypes) {
95   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
96       parseCustomDirectiveResults(parser, operandType, optOperandType,
97                                   varOperandTypes))
98     return failure();
99   return success();
100 }
parseCustomDirectiveRegions(OpAsmParser & parser,Region & region,SmallVectorImpl<std::unique_ptr<Region>> & varRegions)101 static ParseResult parseCustomDirectiveRegions(
102     OpAsmParser &parser, Region &region,
103     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
104   if (parser.parseRegion(region))
105     return failure();
106   if (failed(parser.parseOptionalComma()))
107     return success();
108   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
109   if (parser.parseRegion(*varRegion))
110     return failure();
111   varRegions.emplace_back(std::move(varRegion));
112   return success();
113 }
114 static ParseResult
parseCustomDirectiveSuccessors(OpAsmParser & parser,Block * & successor,SmallVectorImpl<Block * > & varSuccessors)115 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
116                                SmallVectorImpl<Block *> &varSuccessors) {
117   if (parser.parseSuccessor(successor))
118     return failure();
119   if (failed(parser.parseOptionalComma()))
120     return success();
121   Block *varSuccessor;
122   if (parser.parseSuccessor(varSuccessor))
123     return failure();
124   varSuccessors.append(2, varSuccessor);
125   return success();
126 }
parseCustomDirectiveAttributes(OpAsmParser & parser,IntegerAttr & attr,IntegerAttr & optAttr)127 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
128                                                   IntegerAttr &attr,
129                                                   IntegerAttr &optAttr) {
130   if (parser.parseAttribute(attr))
131     return failure();
132   if (succeeded(parser.parseOptionalComma())) {
133     if (parser.parseAttribute(optAttr))
134       return failure();
135   }
136   return success();
137 }
parseCustomDirectiveSpacing(OpAsmParser & parser,mlir::StringAttr & attr)138 static ParseResult parseCustomDirectiveSpacing(OpAsmParser &parser,
139                                                mlir::StringAttr &attr) {
140   return parser.parseAttribute(attr);
141 }
parseCustomDirectiveAttrDict(OpAsmParser & parser,NamedAttrList & attrs)142 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
143                                                 NamedAttrList &attrs) {
144   return parser.parseOptionalAttrDict(attrs);
145 }
parseCustomDirectiveOptionalOperandRef(OpAsmParser & parser,std::optional<OpAsmParser::UnresolvedOperand> & optOperand)146 static ParseResult parseCustomDirectiveOptionalOperandRef(
147     OpAsmParser &parser,
148     std::optional<OpAsmParser::UnresolvedOperand> &optOperand) {
149   int64_t operandCount = 0;
150   if (parser.parseInteger(operandCount))
151     return failure();
152   bool expectedOptionalOperand = operandCount == 0;
153   return success(expectedOptionalOperand != optOperand.has_value());
154 }
155 
156 //===----------------------------------------------------------------------===//
157 // Printing
158 
printCustomOptionalOperand(OpAsmPrinter & printer,Operation *,Value optOperand)159 static void printCustomOptionalOperand(OpAsmPrinter &printer, Operation *,
160                                        Value optOperand) {
161   if (optOperand)
162     printer << "(" << optOperand << ") ";
163 }
164 
printCustomDirectiveOperands(OpAsmPrinter & printer,Operation *,Value operand,Value optOperand,OperandRange varOperands)165 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
166                                          Value operand, Value optOperand,
167                                          OperandRange varOperands) {
168   printer << operand;
169   if (optOperand)
170     printer << ", " << optOperand;
171   printer << " -> (" << varOperands << ")";
172 }
printCustomDirectiveResults(OpAsmPrinter & printer,Operation *,Type operandType,Type optOperandType,TypeRange varOperandTypes)173 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
174                                         Type operandType, Type optOperandType,
175                                         TypeRange varOperandTypes) {
176   printer << " : " << operandType;
177   if (optOperandType)
178     printer << ", " << optOperandType;
179   printer << " -> (" << varOperandTypes << ")";
180 }
printCustomDirectiveWithTypeRefs(OpAsmPrinter & printer,Operation * op,Type operandType,Type optOperandType,TypeRange varOperandTypes)181 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
182                                              Operation *op, Type operandType,
183                                              Type optOperandType,
184                                              TypeRange varOperandTypes) {
185   printer << " type_refs_capture ";
186   printCustomDirectiveResults(printer, op, operandType, optOperandType,
187                               varOperandTypes);
188 }
printCustomDirectiveOperandsAndTypes(OpAsmPrinter & printer,Operation * op,Value operand,Value optOperand,OperandRange varOperands,Type operandType,Type optOperandType,TypeRange varOperandTypes)189 static void printCustomDirectiveOperandsAndTypes(
190     OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
191     OperandRange varOperands, Type operandType, Type optOperandType,
192     TypeRange varOperandTypes) {
193   printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
194   printCustomDirectiveResults(printer, op, operandType, optOperandType,
195                               varOperandTypes);
196 }
printCustomDirectiveRegions(OpAsmPrinter & printer,Operation *,Region & region,MutableArrayRef<Region> varRegions)197 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
198                                         Region &region,
199                                         MutableArrayRef<Region> varRegions) {
200   printer.printRegion(region);
201   if (!varRegions.empty()) {
202     printer << ", ";
203     for (Region &region : varRegions)
204       printer.printRegion(region);
205   }
206 }
printCustomDirectiveSuccessors(OpAsmPrinter & printer,Operation *,Block * successor,SuccessorRange varSuccessors)207 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
208                                            Block *successor,
209                                            SuccessorRange varSuccessors) {
210   printer << successor;
211   if (!varSuccessors.empty())
212     printer << ", " << varSuccessors.front();
213 }
printCustomDirectiveAttributes(OpAsmPrinter & printer,Operation *,Attribute attribute,Attribute optAttribute)214 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
215                                            Attribute attribute,
216                                            Attribute optAttribute) {
217   printer << attribute;
218   if (optAttribute)
219     printer << ", " << optAttribute;
220 }
printCustomDirectiveSpacing(OpAsmPrinter & printer,Operation * op,Attribute attribute)221 static void printCustomDirectiveSpacing(OpAsmPrinter &printer, Operation *op,
222                                         Attribute attribute) {
223   printer << attribute;
224 }
printCustomDirectiveAttrDict(OpAsmPrinter & printer,Operation * op,DictionaryAttr attrs)225 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
226                                          DictionaryAttr attrs) {
227   printer.printOptionalAttrDict(attrs.getValue());
228 }
229 
printCustomDirectiveOptionalOperandRef(OpAsmPrinter & printer,Operation * op,Value optOperand)230 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
231                                                    Operation *op,
232                                                    Value optOperand) {
233   printer << (optOperand ? "1" : "0");
234 }
235 //===----------------------------------------------------------------------===//
236 // Test parser.
237 //===----------------------------------------------------------------------===//
238 
parse(OpAsmParser & parser,OperationState & result)239 ParseResult ParseIntegerLiteralOp::parse(OpAsmParser &parser,
240                                          OperationState &result) {
241   if (parser.parseOptionalColon())
242     return success();
243   uint64_t numResults;
244   if (parser.parseInteger(numResults))
245     return failure();
246 
247   IndexType type = parser.getBuilder().getIndexType();
248   for (unsigned i = 0; i < numResults; ++i)
249     result.addTypes(type);
250   return success();
251 }
252 
print(OpAsmPrinter & p)253 void ParseIntegerLiteralOp::print(OpAsmPrinter &p) {
254   if (unsigned numResults = getNumResults())
255     p << " : " << numResults;
256 }
257 
parse(OpAsmParser & parser,OperationState & result)258 ParseResult ParseWrappedKeywordOp::parse(OpAsmParser &parser,
259                                          OperationState &result) {
260   StringRef keyword;
261   if (parser.parseKeyword(&keyword))
262     return failure();
263   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
264   return success();
265 }
266 
print(OpAsmPrinter & p)267 void ParseWrappedKeywordOp::print(OpAsmPrinter &p) { p << " " << getKeyword(); }
268 
parse(OpAsmParser & parser,OperationState & result)269 ParseResult ParseB64BytesOp::parse(OpAsmParser &parser,
270                                    OperationState &result) {
271   std::vector<char> bytes;
272   if (parser.parseBase64Bytes(&bytes))
273     return failure();
274   result.addAttribute("b64", parser.getBuilder().getStringAttr(
275                                  StringRef(&bytes.front(), bytes.size())));
276   return success();
277 }
278 
print(OpAsmPrinter & p)279 void ParseB64BytesOp::print(OpAsmPrinter &p) {
280   p << " \"" << llvm::encodeBase64(getB64()) << "\"";
281 }
282 
inferReturnTypes(::mlir::MLIRContext * context,::std::optional<::mlir::Location> location,::mlir::ValueRange operands,::mlir::DictionaryAttr attributes,OpaqueProperties properties,::mlir::RegionRange regions,::llvm::SmallVectorImpl<::mlir::Type> & inferredReturnTypes)283 ::llvm::LogicalResult FormatInferType2Op::inferReturnTypes(
284     ::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location,
285     ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
286     OpaqueProperties properties, ::mlir::RegionRange regions,
287     ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
288   inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)});
289   return ::mlir::success();
290 }
291 
292 //===----------------------------------------------------------------------===//
293 // Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
294 
parse(OpAsmParser & parser,OperationState & result)295 ParseResult WrappingRegionOp::parse(OpAsmParser &parser,
296                                     OperationState &result) {
297   if (parser.parseKeyword("wraps"))
298     return failure();
299 
300   // Parse the wrapped op in a region
301   Region &body = *result.addRegion();
302   body.push_back(new Block);
303   Block &block = body.back();
304   Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
305   if (!wrappedOp)
306     return failure();
307 
308   // Create a return terminator in the inner region, pass as operand to the
309   // terminator the returned values from the wrapped operation.
310   SmallVector<Value, 8> returnOperands(wrappedOp->getResults());
311   OpBuilder builder(parser.getContext());
312   builder.setInsertionPointToEnd(&block);
313   builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands);
314 
315   // Get the results type for the wrapping op from the terminator operands.
316   Operation &returnOp = body.back().back();
317   result.types.append(returnOp.operand_type_begin(),
318                       returnOp.operand_type_end());
319 
320   // Use the location of the wrapped op for the "test.wrapping_region" op.
321   result.location = wrappedOp->getLoc();
322 
323   return success();
324 }
325 
print(OpAsmPrinter & p)326 void WrappingRegionOp::print(OpAsmPrinter &p) {
327   p << " wraps ";
328   p.printGenericOp(&getRegion().front().front());
329 }
330 
331 //===----------------------------------------------------------------------===//
332 // Test PrettyPrintedRegionOp -  exercising the following parser APIs
333 //   parseGenericOperationAfterOpName
334 //   parseCustomOperationName
335 //===----------------------------------------------------------------------===//
336 
parse(OpAsmParser & parser,OperationState & result)337 ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser,
338                                          OperationState &result) {
339 
340   SMLoc loc = parser.getCurrentLocation();
341   Location currLocation = parser.getEncodedSourceLoc(loc);
342 
343   // Parse the operands.
344   SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
345   if (parser.parseOperandList(operands))
346     return failure();
347 
348   // Check if we are parsing the pretty-printed version
349   //  test.pretty_printed_region start <inner-op> end : <functional-type>
350   // Else fallback to parsing the "non pretty-printed" version.
351   if (!succeeded(parser.parseOptionalKeyword("start")))
352     return parser.parseGenericOperationAfterOpName(result,
353                                                    llvm::ArrayRef(operands));
354 
355   FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName();
356   if (failed(parseOpNameInfo))
357     return failure();
358 
359   StringAttr innerOpName = parseOpNameInfo->getIdentifier();
360 
361   FunctionType opFntype;
362   std::optional<Location> explicitLoc;
363   if (parser.parseKeyword("end") || parser.parseColon() ||
364       parser.parseType(opFntype) ||
365       parser.parseOptionalLocationSpecifier(explicitLoc))
366     return failure();
367 
368   // If location of the op is explicitly provided, then use it; Else use
369   // the parser's current location.
370   Location opLoc = explicitLoc.value_or(currLocation);
371 
372   // Derive the SSA-values for op's operands.
373   if (parser.resolveOperands(operands, opFntype.getInputs(), loc,
374                              result.operands))
375     return failure();
376 
377   // Add a region for op.
378   Region &region = *result.addRegion();
379 
380   // Create a basic-block inside op's region.
381   Block &block = region.emplaceBlock();
382 
383   // Create and insert an "inner-op" operation in the block.
384   // Just for testing purposes, we can assume that inner op is a binary op with
385   // result and operand types all same as the test-op's first operand.
386   Type innerOpType = opFntype.getInput(0);
387   Value lhs = block.addArgument(innerOpType, opLoc);
388   Value rhs = block.addArgument(innerOpType, opLoc);
389 
390   OpBuilder builder(parser.getBuilder().getContext());
391   builder.setInsertionPointToStart(&block);
392 
393   Operation *innerOp =
394       builder.create(opLoc, innerOpName, /*operands=*/{lhs, rhs}, innerOpType);
395 
396   // Insert a return statement in the block returning the inner-op's result.
397   builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults());
398 
399   // Populate the op operation-state with result-type and location.
400   result.addTypes(opFntype.getResults());
401   result.location = innerOp->getLoc();
402 
403   return success();
404 }
405 
print(OpAsmPrinter & p)406 void PrettyPrintedRegionOp::print(OpAsmPrinter &p) {
407   p << ' ';
408   p.printOperands(getOperands());
409 
410   Operation &innerOp = getRegion().front().front();
411   // Assuming that region has a single non-terminator inner-op, if the inner-op
412   // meets some criteria (which in this case is a simple one  based on the name
413   // of inner-op), then we can print the entire region in a succinct way.
414   // Here we assume that the prototype of "test.special.op" can be trivially
415   // derived while parsing it back.
416   if (innerOp.getName().getStringRef() == "test.special.op") {
417     p << " start test.special.op end";
418   } else {
419     p << " (";
420     p.printRegion(getRegion());
421     p << ")";
422   }
423 
424   p << " : ";
425   p.printFunctionalType(*this);
426 }
427 
428 //===----------------------------------------------------------------------===//
429 // Test PolyForOp - parse list of region arguments.
430 //===----------------------------------------------------------------------===//
431 
parse(OpAsmParser & parser,OperationState & result)432 ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) {
433   SmallVector<OpAsmParser::Argument, 4> ivsInfo;
434   // Parse list of region arguments without a delimiter.
435   if (parser.parseArgumentList(ivsInfo, OpAsmParser::Delimiter::None))
436     return failure();
437 
438   // Parse the body region.
439   Region *body = result.addRegion();
440   for (auto &iv : ivsInfo)
441     iv.type = parser.getBuilder().getIndexType();
442   return parser.parseRegion(*body, ivsInfo);
443 }
444 
print(OpAsmPrinter & p)445 void PolyForOp::print(OpAsmPrinter &p) {
446   p << " ";
447   llvm::interleaveComma(getRegion().getArguments(), p, [&](auto arg) {
448     p.printRegionArgument(arg, /*argAttrs =*/{}, /*omitType=*/true);
449   });
450   p << " ";
451   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
452 }
453 
getAsmBlockArgumentNames(Region & region,OpAsmSetValueNameFn setNameFn)454 void PolyForOp::getAsmBlockArgumentNames(Region &region,
455                                          OpAsmSetValueNameFn setNameFn) {
456   auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names");
457   if (!arrayAttr)
458     return;
459   auto args = getRegion().front().getArguments();
460   auto e = std::min(arrayAttr.size(), args.size());
461   for (unsigned i = 0; i < e; ++i) {
462     if (auto strAttr = dyn_cast<StringAttr>(arrayAttr[i]))
463       setNameFn(args[i], strAttr.getValue());
464   }
465 }
466 
467 //===----------------------------------------------------------------------===//
468 // TestAttrWithLoc - parse/printOptionalLocationSpecifier
469 //===----------------------------------------------------------------------===//
470 
parseOptionalLoc(OpAsmParser & p,Attribute & loc)471 static ParseResult parseOptionalLoc(OpAsmParser &p, Attribute &loc) {
472   std::optional<Location> result;
473   SMLoc sourceLoc = p.getCurrentLocation();
474   if (p.parseOptionalLocationSpecifier(result))
475     return failure();
476   if (result)
477     loc = *result;
478   else
479     loc = p.getEncodedSourceLoc(sourceLoc);
480   return success();
481 }
482 
printOptionalLoc(OpAsmPrinter & p,Operation * op,Attribute loc)483 static void printOptionalLoc(OpAsmPrinter &p, Operation *op, Attribute loc) {
484   p.printOptionalLocationSpecifier(cast<LocationAttr>(loc));
485 }
486 
487 #define GET_OP_CLASSES
488 #include "TestOpsSyntax.cpp.inc"
489 
registerOpsSyntax()490 void TestDialect::registerOpsSyntax() {
491   addOperations<
492 #define GET_OP_LIST
493 #include "TestOpsSyntax.cpp.inc"
494       >();
495 }
496