xref: /llvm-project/mlir/lib/Dialect/Linalg/TransformOps/Syntax.cpp (revision a07b422e90174430213201d0b4b307f5ed089d3f)
1 //===- Syntax.cpp - Custom syntax for Linalg transform ops ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/TransformOps/Syntax.h"
10 #include "mlir/IR/OpImplementation.h"
11 
12 using namespace mlir;
13 
14 ParseResult mlir::parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
15                                         Type &resultType, bool resultOptional) {
16   argumentType = resultType = nullptr;
17 
18   bool hasLParen = resultOptional ? parser.parseOptionalLParen().succeeded()
19                                   : parser.parseLParen().succeeded();
20   if (!resultOptional && !hasLParen)
21     return failure();
22   if (parser.parseType(argumentType).failed())
23     return failure();
24   if (!hasLParen)
25     return success();
26 
27   return failure(parser.parseRParen().failed() ||
28                  parser.parseArrow().failed() ||
29                  parser.parseType(resultType).failed());
30 }
31 
32 ParseResult mlir::parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
33                                         SmallVectorImpl<Type> &resultTypes) {
34   argumentType = nullptr;
35   bool hasLParen = parser.parseOptionalLParen().succeeded();
36   if (parser.parseType(argumentType).failed())
37     return failure();
38   if (!hasLParen)
39     return success();
40 
41   if (parser.parseRParen().failed() || parser.parseArrow().failed())
42     return failure();
43 
44   if (parser.parseOptionalLParen().failed()) {
45     Type type;
46     if (parser.parseType(type).failed())
47       return failure();
48     resultTypes.push_back(type);
49     return success();
50   }
51   if (parser.parseTypeList(resultTypes).failed() ||
52       parser.parseRParen().failed()) {
53     resultTypes.clear();
54     return failure();
55   }
56   return success();
57 }
58 
59 void mlir::printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
60                                  Type argumentType, TypeRange resultType) {
61   if (!resultType.empty())
62     printer << "(";
63   printer << argumentType;
64   if (resultType.empty())
65     return;
66   printer << ") -> ";
67 
68   if (resultType.size() > 1)
69     printer << "(";
70   llvm::interleaveComma(resultType, printer.getStream());
71   if (resultType.size() > 1)
72     printer << ")";
73 }
74 
75 void mlir::printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
76                                  Type argumentType, Type resultType,
77                                  bool resultOptional) {
78   assert(resultOptional || resultType != nullptr);
79   return printSemiFunctionType(printer, op, argumentType,
80                                resultType ? TypeRange(resultType)
81                                           : TypeRange());
82 }
83