//===- TestFormatUtils.cpp - MLIR Test Dialect Assembly Format Utilities --===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "TestFormatUtils.h" #include "mlir/IR/Builders.h" using namespace mlir; using namespace test; //===----------------------------------------------------------------------===// // CustomDirectiveOperands //===----------------------------------------------------------------------===// ParseResult test::parseCustomDirectiveOperands( OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand, std::optional &optOperand, SmallVectorImpl &varOperands) { if (parser.parseOperand(operand)) return failure(); if (succeeded(parser.parseOptionalComma())) { optOperand.emplace(); if (parser.parseOperand(*optOperand)) return failure(); } if (parser.parseArrow() || parser.parseLParen() || parser.parseOperandList(varOperands) || parser.parseRParen()) return failure(); return success(); } void test::printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, Value operand, Value optOperand, OperandRange varOperands) { printer << operand; if (optOperand) printer << ", " << optOperand; printer << " -> (" << varOperands << ")"; } //===----------------------------------------------------------------------===// // CustomDirectiveResults //===----------------------------------------------------------------------===// ParseResult test::parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, Type &optOperandType, SmallVectorImpl &varOperandTypes) { if (parser.parseColon()) return failure(); if (parser.parseType(operandType)) return failure(); if (succeeded(parser.parseOptionalComma())) if (parser.parseType(optOperandType)) return failure(); if (parser.parseArrow() || parser.parseLParen() || parser.parseTypeList(varOperandTypes) || parser.parseRParen()) return failure(); return success(); } void test::printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, Type operandType, Type optOperandType, TypeRange varOperandTypes) { printer << " : " << operandType; if (optOperandType) printer << ", " << optOperandType; printer << " -> (" << varOperandTypes << ")"; } //===----------------------------------------------------------------------===// // CustomDirectiveWithTypeRefs //===----------------------------------------------------------------------===// ParseResult test::parseCustomDirectiveWithTypeRefs( OpAsmParser &parser, Type operandType, Type optOperandType, const SmallVectorImpl &varOperandTypes) { if (parser.parseKeyword("type_refs_capture")) return failure(); Type operandType2, optOperandType2; SmallVector varOperandTypes2; if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, varOperandTypes2)) return failure(); if (operandType != operandType2 || optOperandType != optOperandType2 || varOperandTypes != varOperandTypes2) return failure(); return success(); } void test::printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, Operation *op, Type operandType, Type optOperandType, TypeRange varOperandTypes) { printer << " type_refs_capture "; printCustomDirectiveResults(printer, op, operandType, optOperandType, varOperandTypes); } //===----------------------------------------------------------------------===// // CustomDirectiveOperandsAndTypes //===----------------------------------------------------------------------===// ParseResult test::parseCustomDirectiveOperandsAndTypes( OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand, std::optional &optOperand, SmallVectorImpl &varOperands, Type &operandType, Type &optOperandType, SmallVectorImpl &varOperandTypes) { if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || parseCustomDirectiveResults(parser, operandType, optOperandType, varOperandTypes)) return failure(); return success(); } void test::printCustomDirectiveOperandsAndTypes( OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, OperandRange varOperands, Type operandType, Type optOperandType, TypeRange varOperandTypes) { printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); printCustomDirectiveResults(printer, op, operandType, optOperandType, varOperandTypes); } //===----------------------------------------------------------------------===// // CustomDirectiveRegions //===----------------------------------------------------------------------===// ParseResult test::parseCustomDirectiveRegions( OpAsmParser &parser, Region ®ion, SmallVectorImpl> &varRegions) { if (parser.parseRegion(region)) return failure(); if (failed(parser.parseOptionalComma())) return success(); std::unique_ptr varRegion = std::make_unique(); if (parser.parseRegion(*varRegion)) return failure(); varRegions.emplace_back(std::move(varRegion)); return success(); } void test::printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, Region ®ion, MutableArrayRef varRegions) { printer.printRegion(region); if (!varRegions.empty()) { printer << ", "; for (Region ®ion : varRegions) printer.printRegion(region); } } //===----------------------------------------------------------------------===// // CustomDirectiveSuccessors //===----------------------------------------------------------------------===// ParseResult test::parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, SmallVectorImpl &varSuccessors) { if (parser.parseSuccessor(successor)) return failure(); if (failed(parser.parseOptionalComma())) return success(); Block *varSuccessor; if (parser.parseSuccessor(varSuccessor)) return failure(); varSuccessors.append(2, varSuccessor); return success(); } void test::printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, Block *successor, SuccessorRange varSuccessors) { printer << successor; if (!varSuccessors.empty()) printer << ", " << varSuccessors.front(); } //===----------------------------------------------------------------------===// // CustomDirectiveAttributes //===----------------------------------------------------------------------===// ParseResult test::parseCustomDirectiveAttributes(OpAsmParser &parser, IntegerAttr &attr, IntegerAttr &optAttr) { if (parser.parseAttribute(attr)) return failure(); if (succeeded(parser.parseOptionalComma())) { if (parser.parseAttribute(optAttr)) return failure(); } return success(); } void test::printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, Attribute attribute, Attribute optAttribute) { printer << attribute; if (optAttribute) printer << ", " << optAttribute; } //===----------------------------------------------------------------------===// // CustomDirectiveAttrDict //===----------------------------------------------------------------------===// ParseResult test::parseCustomDirectiveAttrDict(OpAsmParser &parser, NamedAttrList &attrs) { return parser.parseOptionalAttrDict(attrs); } void test::printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, DictionaryAttr attrs) { printer.printOptionalAttrDict(attrs.getValue()); } //===----------------------------------------------------------------------===// // CustomDirectiveOptionalOperandRef //===----------------------------------------------------------------------===// ParseResult test::parseCustomDirectiveOptionalOperandRef( OpAsmParser &parser, std::optional &optOperand) { int64_t operandCount = 0; if (parser.parseInteger(operandCount)) return failure(); bool expectedOptionalOperand = operandCount == 0; return success(expectedOptionalOperand != !!optOperand); } void test::printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, Operation *op, Value optOperand) { printer << (optOperand ? "1" : "0"); } //===----------------------------------------------------------------------===// // CustomDirectiveOptionalOperand //===----------------------------------------------------------------------===// ParseResult test::parseCustomOptionalOperand( OpAsmParser &parser, std::optional &optOperand) { if (succeeded(parser.parseOptionalLParen())) { optOperand.emplace(); if (parser.parseOperand(*optOperand) || parser.parseRParen()) return failure(); } return success(); } void test::printCustomOptionalOperand(OpAsmPrinter &printer, Operation *, Value optOperand) { if (optOperand) printer << "(" << optOperand << ") "; } //===----------------------------------------------------------------------===// // CustomDirectiveSwitchCases //===----------------------------------------------------------------------===// ParseResult test::parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, SmallVectorImpl> &caseRegions) { SmallVector caseValues; while (succeeded(p.parseOptionalKeyword("case"))) { int64_t value; Region ®ion = *caseRegions.emplace_back(std::make_unique()); if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{})) return failure(); caseValues.push_back(value); } cases = p.getBuilder().getDenseI64ArrayAttr(caseValues); return success(); } void test::printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions) { for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) { p.printNewline(); p << "case " << value << ' '; p.printRegion(*region, /*printEntryBlockArgs=*/false); } } //===----------------------------------------------------------------------===// // CustomUsingPropertyInCustom //===----------------------------------------------------------------------===// bool test::parseUsingPropertyInCustom(OpAsmParser &parser, SmallVector &value) { auto elemParser = [&]() { int64_t v = 0; if (failed(parser.parseInteger(v))) return failure(); value.push_back(v); return success(); }; return failed(parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Square, elemParser)); } void test::printUsingPropertyInCustom(OpAsmPrinter &printer, Operation *op, ArrayRef value) { printer << '[' << value << ']'; } //===----------------------------------------------------------------------===// // CustomDirectiveIntProperty //===----------------------------------------------------------------------===// bool test::parseIntProperty(OpAsmParser &parser, int64_t &value) { return failed(parser.parseInteger(value)); } void test::printIntProperty(OpAsmPrinter &printer, Operation *op, int64_t value) { printer << value; } //===----------------------------------------------------------------------===// // CustomDirectiveSumProperty //===----------------------------------------------------------------------===// bool test::parseSumProperty(OpAsmParser &parser, int64_t &second, int64_t first) { int64_t sum; auto loc = parser.getCurrentLocation(); if (parser.parseInteger(second) || parser.parseEqual() || parser.parseInteger(sum)) return true; if (sum != second + first) { parser.emitError(loc, "Expected sum to equal first + second"); return true; } return false; } void test::printSumProperty(OpAsmPrinter &printer, Operation *op, int64_t second, int64_t first) { printer << second << " = " << (second + first); } //===----------------------------------------------------------------------===// // CustomDirectiveOptionalCustomParser //===----------------------------------------------------------------------===// OptionalParseResult test::parseOptionalCustomParser(AsmParser &p, IntegerAttr &result) { if (succeeded(p.parseOptionalKeyword("foo"))) return p.parseAttribute(result); return {}; } void test::printOptionalCustomParser(AsmPrinter &p, Operation *, IntegerAttr result) { p << "foo "; p.printAttribute(result); } //===----------------------------------------------------------------------===// // CustomDirectiveAttrElideType //===----------------------------------------------------------------------===// ParseResult test::parseAttrElideType(AsmParser &parser, TypeAttr type, Attribute &attr) { return parser.parseAttribute(attr, type.getValue()); } void test::printAttrElideType(AsmPrinter &printer, Operation *op, TypeAttr type, Attribute attr) { printer.printAttributeWithoutType(attr); }