1 //===- TestFormatUtils.cpp - MLIR Test Dialect Assembly Format Utilities --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestFormatUtils.h" 10 #include "mlir/IR/Builders.h" 11 12 using namespace mlir; 13 using namespace test; 14 15 //===----------------------------------------------------------------------===// 16 // CustomDirectiveOperands 17 //===----------------------------------------------------------------------===// 18 19 ParseResult test::parseCustomDirectiveOperands( 20 OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand, 21 std::optional<OpAsmParser::UnresolvedOperand> &optOperand, 22 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands) { 23 if (parser.parseOperand(operand)) 24 return failure(); 25 if (succeeded(parser.parseOptionalComma())) { 26 optOperand.emplace(); 27 if (parser.parseOperand(*optOperand)) 28 return failure(); 29 } 30 if (parser.parseArrow() || parser.parseLParen() || 31 parser.parseOperandList(varOperands) || parser.parseRParen()) 32 return failure(); 33 return success(); 34 } 35 36 void test::printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 37 Value operand, Value optOperand, 38 OperandRange varOperands) { 39 printer << operand; 40 if (optOperand) 41 printer << ", " << optOperand; 42 printer << " -> (" << varOperands << ")"; 43 } 44 45 //===----------------------------------------------------------------------===// 46 // CustomDirectiveResults 47 //===----------------------------------------------------------------------===// 48 49 ParseResult 50 test::parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 51 Type &optOperandType, 52 SmallVectorImpl<Type> &varOperandTypes) { 53 if (parser.parseColon()) 54 return failure(); 55 56 if (parser.parseType(operandType)) 57 return failure(); 58 if (succeeded(parser.parseOptionalComma())) 59 if (parser.parseType(optOperandType)) 60 return failure(); 61 if (parser.parseArrow() || parser.parseLParen() || 62 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 63 return failure(); 64 return success(); 65 } 66 67 void test::printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 68 Type operandType, Type optOperandType, 69 TypeRange varOperandTypes) { 70 printer << " : " << operandType; 71 if (optOperandType) 72 printer << ", " << optOperandType; 73 printer << " -> (" << varOperandTypes << ")"; 74 } 75 76 //===----------------------------------------------------------------------===// 77 // CustomDirectiveWithTypeRefs 78 //===----------------------------------------------------------------------===// 79 80 ParseResult test::parseCustomDirectiveWithTypeRefs( 81 OpAsmParser &parser, Type operandType, Type optOperandType, 82 const SmallVectorImpl<Type> &varOperandTypes) { 83 if (parser.parseKeyword("type_refs_capture")) 84 return failure(); 85 86 Type operandType2, optOperandType2; 87 SmallVector<Type, 1> varOperandTypes2; 88 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 89 varOperandTypes2)) 90 return failure(); 91 92 if (operandType != operandType2 || optOperandType != optOperandType2 || 93 varOperandTypes != varOperandTypes2) 94 return failure(); 95 96 return success(); 97 } 98 99 void test::printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 100 Operation *op, Type operandType, 101 Type optOperandType, 102 TypeRange varOperandTypes) { 103 printer << " type_refs_capture "; 104 printCustomDirectiveResults(printer, op, operandType, optOperandType, 105 varOperandTypes); 106 } 107 108 //===----------------------------------------------------------------------===// 109 // CustomDirectiveOperandsAndTypes 110 //===----------------------------------------------------------------------===// 111 112 ParseResult test::parseCustomDirectiveOperandsAndTypes( 113 OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand, 114 std::optional<OpAsmParser::UnresolvedOperand> &optOperand, 115 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands, 116 Type &operandType, Type &optOperandType, 117 SmallVectorImpl<Type> &varOperandTypes) { 118 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 119 parseCustomDirectiveResults(parser, operandType, optOperandType, 120 varOperandTypes)) 121 return failure(); 122 return success(); 123 } 124 125 void test::printCustomDirectiveOperandsAndTypes( 126 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 127 OperandRange varOperands, Type operandType, Type optOperandType, 128 TypeRange varOperandTypes) { 129 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 130 printCustomDirectiveResults(printer, op, operandType, optOperandType, 131 varOperandTypes); 132 } 133 134 //===----------------------------------------------------------------------===// 135 // CustomDirectiveRegions 136 //===----------------------------------------------------------------------===// 137 138 ParseResult test::parseCustomDirectiveRegions( 139 OpAsmParser &parser, Region ®ion, 140 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 141 if (parser.parseRegion(region)) 142 return failure(); 143 if (failed(parser.parseOptionalComma())) 144 return success(); 145 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 146 if (parser.parseRegion(*varRegion)) 147 return failure(); 148 varRegions.emplace_back(std::move(varRegion)); 149 return success(); 150 } 151 152 void test::printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 153 Region ®ion, 154 MutableArrayRef<Region> varRegions) { 155 printer.printRegion(region); 156 if (!varRegions.empty()) { 157 printer << ", "; 158 for (Region ®ion : varRegions) 159 printer.printRegion(region); 160 } 161 } 162 163 //===----------------------------------------------------------------------===// 164 // CustomDirectiveSuccessors 165 //===----------------------------------------------------------------------===// 166 167 ParseResult 168 test::parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 169 SmallVectorImpl<Block *> &varSuccessors) { 170 if (parser.parseSuccessor(successor)) 171 return failure(); 172 if (failed(parser.parseOptionalComma())) 173 return success(); 174 Block *varSuccessor; 175 if (parser.parseSuccessor(varSuccessor)) 176 return failure(); 177 varSuccessors.append(2, varSuccessor); 178 return success(); 179 } 180 181 void test::printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 182 Block *successor, 183 SuccessorRange varSuccessors) { 184 printer << successor; 185 if (!varSuccessors.empty()) 186 printer << ", " << varSuccessors.front(); 187 } 188 189 //===----------------------------------------------------------------------===// 190 // CustomDirectiveAttributes 191 //===----------------------------------------------------------------------===// 192 193 ParseResult test::parseCustomDirectiveAttributes(OpAsmParser &parser, 194 IntegerAttr &attr, 195 IntegerAttr &optAttr) { 196 if (parser.parseAttribute(attr)) 197 return failure(); 198 if (succeeded(parser.parseOptionalComma())) { 199 if (parser.parseAttribute(optAttr)) 200 return failure(); 201 } 202 return success(); 203 } 204 205 void test::printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 206 Attribute attribute, 207 Attribute optAttribute) { 208 printer << attribute; 209 if (optAttribute) 210 printer << ", " << optAttribute; 211 } 212 213 //===----------------------------------------------------------------------===// 214 // CustomDirectiveAttrDict 215 //===----------------------------------------------------------------------===// 216 217 ParseResult test::parseCustomDirectiveAttrDict(OpAsmParser &parser, 218 NamedAttrList &attrs) { 219 return parser.parseOptionalAttrDict(attrs); 220 } 221 222 void test::printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 223 DictionaryAttr attrs) { 224 printer.printOptionalAttrDict(attrs.getValue()); 225 } 226 227 //===----------------------------------------------------------------------===// 228 // CustomDirectiveOptionalOperandRef 229 //===----------------------------------------------------------------------===// 230 231 ParseResult test::parseCustomDirectiveOptionalOperandRef( 232 OpAsmParser &parser, 233 std::optional<OpAsmParser::UnresolvedOperand> &optOperand) { 234 int64_t operandCount = 0; 235 if (parser.parseInteger(operandCount)) 236 return failure(); 237 bool expectedOptionalOperand = operandCount == 0; 238 return success(expectedOptionalOperand != !!optOperand); 239 } 240 241 void test::printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, 242 Operation *op, 243 Value optOperand) { 244 printer << (optOperand ? "1" : "0"); 245 } 246 247 //===----------------------------------------------------------------------===// 248 // CustomDirectiveOptionalOperand 249 //===----------------------------------------------------------------------===// 250 251 ParseResult test::parseCustomOptionalOperand( 252 OpAsmParser &parser, 253 std::optional<OpAsmParser::UnresolvedOperand> &optOperand) { 254 if (succeeded(parser.parseOptionalLParen())) { 255 optOperand.emplace(); 256 if (parser.parseOperand(*optOperand) || parser.parseRParen()) 257 return failure(); 258 } 259 return success(); 260 } 261 262 void test::printCustomOptionalOperand(OpAsmPrinter &printer, Operation *, 263 Value optOperand) { 264 if (optOperand) 265 printer << "(" << optOperand << ") "; 266 } 267 268 //===----------------------------------------------------------------------===// 269 // CustomDirectiveSwitchCases 270 //===----------------------------------------------------------------------===// 271 272 ParseResult 273 test::parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, 274 SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) { 275 SmallVector<int64_t> caseValues; 276 while (succeeded(p.parseOptionalKeyword("case"))) { 277 int64_t value; 278 Region ®ion = *caseRegions.emplace_back(std::make_unique<Region>()); 279 if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{})) 280 return failure(); 281 caseValues.push_back(value); 282 } 283 cases = p.getBuilder().getDenseI64ArrayAttr(caseValues); 284 return success(); 285 } 286 287 void test::printSwitchCases(OpAsmPrinter &p, Operation *op, 288 DenseI64ArrayAttr cases, RegionRange caseRegions) { 289 for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) { 290 p.printNewline(); 291 p << "case " << value << ' '; 292 p.printRegion(*region, /*printEntryBlockArgs=*/false); 293 } 294 } 295 296 //===----------------------------------------------------------------------===// 297 // CustomUsingPropertyInCustom 298 //===----------------------------------------------------------------------===// 299 300 bool test::parseUsingPropertyInCustom(OpAsmParser &parser, 301 SmallVector<int64_t> &value) { 302 auto elemParser = [&]() { 303 int64_t v = 0; 304 if (failed(parser.parseInteger(v))) 305 return failure(); 306 value.push_back(v); 307 return success(); 308 }; 309 return failed(parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Square, 310 elemParser)); 311 } 312 313 void test::printUsingPropertyInCustom(OpAsmPrinter &printer, Operation *op, 314 ArrayRef<int64_t> value) { 315 printer << '[' << value << ']'; 316 } 317 318 //===----------------------------------------------------------------------===// 319 // CustomDirectiveIntProperty 320 //===----------------------------------------------------------------------===// 321 322 bool test::parseIntProperty(OpAsmParser &parser, int64_t &value) { 323 return failed(parser.parseInteger(value)); 324 } 325 326 void test::printIntProperty(OpAsmPrinter &printer, Operation *op, 327 int64_t value) { 328 printer << value; 329 } 330 331 //===----------------------------------------------------------------------===// 332 // CustomDirectiveSumProperty 333 //===----------------------------------------------------------------------===// 334 335 bool test::parseSumProperty(OpAsmParser &parser, int64_t &second, 336 int64_t first) { 337 int64_t sum; 338 auto loc = parser.getCurrentLocation(); 339 if (parser.parseInteger(second) || parser.parseEqual() || 340 parser.parseInteger(sum)) 341 return true; 342 if (sum != second + first) { 343 parser.emitError(loc, "Expected sum to equal first + second"); 344 return true; 345 } 346 return false; 347 } 348 349 void test::printSumProperty(OpAsmPrinter &printer, Operation *op, 350 int64_t second, int64_t first) { 351 printer << second << " = " << (second + first); 352 } 353 354 //===----------------------------------------------------------------------===// 355 // CustomDirectiveOptionalCustomParser 356 //===----------------------------------------------------------------------===// 357 358 OptionalParseResult test::parseOptionalCustomParser(AsmParser &p, 359 IntegerAttr &result) { 360 if (succeeded(p.parseOptionalKeyword("foo"))) 361 return p.parseAttribute(result); 362 return {}; 363 } 364 365 void test::printOptionalCustomParser(AsmPrinter &p, Operation *, 366 IntegerAttr result) { 367 p << "foo "; 368 p.printAttribute(result); 369 } 370 371 //===----------------------------------------------------------------------===// 372 // CustomDirectiveAttrElideType 373 //===----------------------------------------------------------------------===// 374 375 ParseResult test::parseAttrElideType(AsmParser &parser, TypeAttr type, 376 Attribute &attr) { 377 return parser.parseAttribute(attr, type.getValue()); 378 } 379 380 void test::printAttrElideType(AsmPrinter &printer, Operation *op, TypeAttr type, 381 Attribute attr) { 382 printer.printAttributeWithoutType(attr); 383 } 384