1 //===- Syntax.cpp - Custom syntax for Linalg transform ops ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Linalg/TransformOps/Syntax.h" 10 #include "mlir/IR/OpImplementation.h" 11 12 using namespace mlir; 13 14 ParseResult mlir::parseSemiFunctionType(OpAsmParser &parser, Type &argumentType, 15 Type &resultType, bool resultOptional) { 16 argumentType = resultType = nullptr; 17 18 bool hasLParen = resultOptional ? parser.parseOptionalLParen().succeeded() 19 : parser.parseLParen().succeeded(); 20 if (!resultOptional && !hasLParen) 21 return failure(); 22 if (parser.parseType(argumentType).failed()) 23 return failure(); 24 if (!hasLParen) 25 return success(); 26 27 return failure(parser.parseRParen().failed() || 28 parser.parseArrow().failed() || 29 parser.parseType(resultType).failed()); 30 } 31 32 ParseResult mlir::parseSemiFunctionType(OpAsmParser &parser, Type &argumentType, 33 SmallVectorImpl<Type> &resultTypes) { 34 argumentType = nullptr; 35 bool hasLParen = parser.parseOptionalLParen().succeeded(); 36 if (parser.parseType(argumentType).failed()) 37 return failure(); 38 if (!hasLParen) 39 return success(); 40 41 if (parser.parseRParen().failed() || parser.parseArrow().failed()) 42 return failure(); 43 44 if (parser.parseOptionalLParen().failed()) { 45 Type type; 46 if (parser.parseType(type).failed()) 47 return failure(); 48 resultTypes.push_back(type); 49 return success(); 50 } 51 if (parser.parseTypeList(resultTypes).failed() || 52 parser.parseRParen().failed()) { 53 resultTypes.clear(); 54 return failure(); 55 } 56 return success(); 57 } 58 59 void mlir::printSemiFunctionType(OpAsmPrinter &printer, Operation *op, 60 Type argumentType, TypeRange resultType) { 61 if (!resultType.empty()) 62 printer << "("; 63 printer << argumentType; 64 if (resultType.empty()) 65 return; 66 printer << ") -> "; 67 68 if (resultType.size() > 1) 69 printer << "("; 70 llvm::interleaveComma(resultType, printer.getStream()); 71 if (resultType.size() > 1) 72 printer << ")"; 73 } 74 75 void mlir::printSemiFunctionType(OpAsmPrinter &printer, Operation *op, 76 Type argumentType, Type resultType, 77 bool resultOptional) { 78 assert(resultOptional || resultType != nullptr); 79 return printSemiFunctionType(printer, op, argumentType, 80 resultType ? TypeRange(resultType) 81 : TypeRange()); 82 } 83