xref: /llvm-project/mlir/lib/Dialect/Linalg/TransformOps/Syntax.cpp (revision a07b422e90174430213201d0b4b307f5ed089d3f)
12fe4d90cSAlex Zinenko //===- Syntax.cpp - Custom syntax for Linalg transform ops ----------------===//
22fe4d90cSAlex Zinenko //
32fe4d90cSAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42fe4d90cSAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
52fe4d90cSAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
62fe4d90cSAlex Zinenko //
72fe4d90cSAlex Zinenko //===----------------------------------------------------------------------===//
82fe4d90cSAlex Zinenko 
92fe4d90cSAlex Zinenko #include "mlir/Dialect/Linalg/TransformOps/Syntax.h"
102fe4d90cSAlex Zinenko #include "mlir/IR/OpImplementation.h"
112fe4d90cSAlex Zinenko 
122fe4d90cSAlex Zinenko using namespace mlir;
132fe4d90cSAlex Zinenko 
142fe4d90cSAlex Zinenko ParseResult mlir::parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
15*a07b422eSFelix Schneider                                         Type &resultType, bool resultOptional) {
162fe4d90cSAlex Zinenko   argumentType = resultType = nullptr;
17*a07b422eSFelix Schneider 
18*a07b422eSFelix Schneider   bool hasLParen = resultOptional ? parser.parseOptionalLParen().succeeded()
19*a07b422eSFelix Schneider                                   : parser.parseLParen().succeeded();
20*a07b422eSFelix Schneider   if (!resultOptional && !hasLParen)
21*a07b422eSFelix Schneider     return failure();
222fe4d90cSAlex Zinenko   if (parser.parseType(argumentType).failed())
232fe4d90cSAlex Zinenko     return failure();
242fe4d90cSAlex Zinenko   if (!hasLParen)
252fe4d90cSAlex Zinenko     return success();
262fe4d90cSAlex Zinenko 
272fe4d90cSAlex Zinenko   return failure(parser.parseRParen().failed() ||
282fe4d90cSAlex Zinenko                  parser.parseArrow().failed() ||
292fe4d90cSAlex Zinenko                  parser.parseType(resultType).failed());
302fe4d90cSAlex Zinenko }
312fe4d90cSAlex Zinenko 
322fe4d90cSAlex Zinenko ParseResult mlir::parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
332fe4d90cSAlex Zinenko                                         SmallVectorImpl<Type> &resultTypes) {
342fe4d90cSAlex Zinenko   argumentType = nullptr;
352fe4d90cSAlex Zinenko   bool hasLParen = parser.parseOptionalLParen().succeeded();
362fe4d90cSAlex Zinenko   if (parser.parseType(argumentType).failed())
372fe4d90cSAlex Zinenko     return failure();
382fe4d90cSAlex Zinenko   if (!hasLParen)
392fe4d90cSAlex Zinenko     return success();
402fe4d90cSAlex Zinenko 
412fe4d90cSAlex Zinenko   if (parser.parseRParen().failed() || parser.parseArrow().failed())
422fe4d90cSAlex Zinenko     return failure();
432fe4d90cSAlex Zinenko 
442fe4d90cSAlex Zinenko   if (parser.parseOptionalLParen().failed()) {
452fe4d90cSAlex Zinenko     Type type;
462fe4d90cSAlex Zinenko     if (parser.parseType(type).failed())
472fe4d90cSAlex Zinenko       return failure();
482fe4d90cSAlex Zinenko     resultTypes.push_back(type);
492fe4d90cSAlex Zinenko     return success();
502fe4d90cSAlex Zinenko   }
512fe4d90cSAlex Zinenko   if (parser.parseTypeList(resultTypes).failed() ||
522fe4d90cSAlex Zinenko       parser.parseRParen().failed()) {
532fe4d90cSAlex Zinenko     resultTypes.clear();
542fe4d90cSAlex Zinenko     return failure();
552fe4d90cSAlex Zinenko   }
562fe4d90cSAlex Zinenko   return success();
572fe4d90cSAlex Zinenko }
582fe4d90cSAlex Zinenko 
592fe4d90cSAlex Zinenko void mlir::printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
602fe4d90cSAlex Zinenko                                  Type argumentType, TypeRange resultType) {
612fe4d90cSAlex Zinenko   if (!resultType.empty())
622fe4d90cSAlex Zinenko     printer << "(";
632fe4d90cSAlex Zinenko   printer << argumentType;
642fe4d90cSAlex Zinenko   if (resultType.empty())
652fe4d90cSAlex Zinenko     return;
662fe4d90cSAlex Zinenko   printer << ") -> ";
672fe4d90cSAlex Zinenko 
682fe4d90cSAlex Zinenko   if (resultType.size() > 1)
692fe4d90cSAlex Zinenko     printer << "(";
702fe4d90cSAlex Zinenko   llvm::interleaveComma(resultType, printer.getStream());
712fe4d90cSAlex Zinenko   if (resultType.size() > 1)
722fe4d90cSAlex Zinenko     printer << ")";
732fe4d90cSAlex Zinenko }
742fe4d90cSAlex Zinenko 
752fe4d90cSAlex Zinenko void mlir::printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
76*a07b422eSFelix Schneider                                  Type argumentType, Type resultType,
77*a07b422eSFelix Schneider                                  bool resultOptional) {
78*a07b422eSFelix Schneider   assert(resultOptional || resultType != nullptr);
792fe4d90cSAlex Zinenko   return printSemiFunctionType(printer, op, argumentType,
802fe4d90cSAlex Zinenko                                resultType ? TypeRange(resultType)
812fe4d90cSAlex Zinenko                                           : TypeRange());
822fe4d90cSAlex Zinenko }
83