//===- IRDL.cpp - IRDL dialect ----------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/IRDL/IR/IRDL.h"
#include "mlir/Dialect/IRDL/IRDLSymbols.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/ExtensibleDialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/Metadata.h"
#include "llvm/Support/Casting.h"

using namespace mlir;
using namespace mlir::irdl;

//===----------------------------------------------------------------------===//
// IRDL dialect.
33 //===----------------------------------------------------------------------===// 34 35 #include "mlir/Dialect/IRDL/IR/IRDL.cpp.inc" 36 37 #include "mlir/Dialect/IRDL/IR/IRDLDialect.cpp.inc" 38 39 void IRDLDialect::initialize() { 40 addOperations< 41 #define GET_OP_LIST 42 #include "mlir/Dialect/IRDL/IR/IRDLOps.cpp.inc" 43 >(); 44 addTypes< 45 #define GET_TYPEDEF_LIST 46 #include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc" 47 >(); 48 addAttributes< 49 #define GET_ATTRDEF_LIST 50 #include "mlir/Dialect/IRDL/IR/IRDLAttributes.cpp.inc" 51 >(); 52 } 53 54 //===----------------------------------------------------------------------===// 55 // Parsing/Printing/Verifying 56 //===----------------------------------------------------------------------===// 57 58 /// Parse a region, and add a single block if the region is empty. 59 /// If no region is parsed, create a new region with a single empty block. 60 static ParseResult parseSingleBlockRegion(OpAsmParser &p, Region ®ion) { 61 auto regionParseRes = p.parseOptionalRegion(region); 62 if (regionParseRes.has_value() && failed(regionParseRes.value())) 63 return failure(); 64 65 // If the region is empty, add a single empty block. 66 if (region.empty()) 67 region.push_back(new Block()); 68 69 return success(); 70 } 71 72 static void printSingleBlockRegion(OpAsmPrinter &p, Operation *op, 73 Region ®ion) { 74 if (!region.getBlocks().front().empty()) 75 p.printRegion(region); 76 } 77 78 LogicalResult DialectOp::verify() { 79 if (!Dialect::isValidNamespace(getName())) 80 return emitOpError("invalid dialect name"); 81 return success(); 82 } 83 84 LogicalResult OperationOp::verifyRegions() { 85 // Stores pairs of value kinds and the list of names of values of this kind in 86 // the operation. 
87 SmallVector<std::tuple<StringRef, llvm::SmallDenseSet<StringRef>>> valueNames; 88 89 auto insertNames = [&](StringRef kind, ArrayAttr names) { 90 llvm::SmallDenseSet<StringRef> nameSet; 91 nameSet.reserve(names.size()); 92 for (auto name : names) 93 nameSet.insert(llvm::cast<StringAttr>(name).getValue()); 94 valueNames.emplace_back(kind, std::move(nameSet)); 95 }; 96 97 for (Operation &op : getBody().getOps()) { 98 TypeSwitch<Operation *>(&op) 99 .Case<OperandsOp>( 100 [&](OperandsOp op) { insertNames("operands", op.getNames()); }) 101 .Case<ResultsOp>( 102 [&](ResultsOp op) { insertNames("results", op.getNames()); }) 103 .Case<RegionsOp>( 104 [&](RegionsOp op) { insertNames("regions", op.getNames()); }); 105 } 106 107 // Verify that no two operand, result or region share the same name. 108 // The absence of duplicates within each value kind is checked by the 109 // associated operation's verifier. 110 for (size_t i : llvm::seq(valueNames.size())) { 111 for (size_t j : llvm::seq(i + 1, valueNames.size())) { 112 auto [lhs, lhsSet] = valueNames[i]; 113 auto &[rhs, rhsSet] = valueNames[j]; 114 llvm::set_intersect(lhsSet, rhsSet); 115 if (!lhsSet.empty()) 116 return emitOpError("contains a value named '") 117 << *lhsSet.begin() << "' for both its " << lhs << " and " << rhs; 118 } 119 } 120 121 return success(); 122 } 123 124 static LogicalResult verifyNames(Operation *op, StringRef kindName, 125 ArrayAttr names, size_t numOperands) { 126 if (numOperands != names.size()) 127 return op->emitOpError() 128 << "the number of " << kindName 129 << "s and their names must be " 130 "the same, but got " 131 << numOperands << " and " << names.size() << " respectively"; 132 133 DenseMap<StringRef, size_t> nameMap; 134 for (auto [i, name] : llvm::enumerate(names)) { 135 StringRef nameRef = llvm::cast<StringAttr>(name).getValue(); 136 if (nameRef.empty()) 137 return op->emitOpError() 138 << "name of " << kindName << " #" << i << " is empty"; 139 if (!llvm::isAlpha(nameRef[0]) && 
nameRef[0] != '_') 140 return op->emitOpError() 141 << "name of " << kindName << " #" << i 142 << " must start with either a letter or an underscore"; 143 if (llvm::any_of(nameRef, 144 [](char c) { return !llvm::isAlnum(c) && c != '_'; })) 145 return op->emitOpError() 146 << "name of " << kindName << " #" << i 147 << " must contain only letters, digits and underscores"; 148 if (nameMap.contains(nameRef)) 149 return op->emitOpError() << "name of " << kindName << " #" << i 150 << " is a duplicate of the name of " << kindName 151 << " #" << nameMap[nameRef]; 152 nameMap.insert({nameRef, i}); 153 } 154 155 return success(); 156 } 157 158 LogicalResult ParametersOp::verify() { 159 return verifyNames(*this, "parameter", getNames(), getNumOperands()); 160 } 161 162 template <typename ValueListOp> 163 static LogicalResult verifyOperandsResultsCommon(ValueListOp op, 164 StringRef kindName) { 165 size_t numVariadicities = op.getVariadicity().size(); 166 size_t numOperands = op.getNumOperands(); 167 168 if (numOperands != numVariadicities) 169 return op.emitOpError() 170 << "the number of " << kindName 171 << "s and their variadicities must be " 172 "the same, but got " 173 << numOperands << " and " << numVariadicities << " respectively"; 174 175 return verifyNames(op, kindName, op.getNames(), numOperands); 176 } 177 178 LogicalResult OperandsOp::verify() { 179 return verifyOperandsResultsCommon(*this, "operand"); 180 } 181 182 LogicalResult ResultsOp::verify() { 183 return verifyOperandsResultsCommon(*this, "result"); 184 } 185 186 LogicalResult AttributesOp::verify() { 187 size_t namesSize = getAttributeValueNames().size(); 188 size_t valuesSize = getAttributeValues().size(); 189 190 if (namesSize != valuesSize) 191 return emitOpError() 192 << "the number of attribute names and their constraints must be " 193 "the same but got " 194 << namesSize << " and " << valuesSize << " respectively"; 195 196 return success(); 197 } 198 199 LogicalResult BaseOp::verify() { 200 
std::optional<StringRef> baseName = getBaseName(); 201 std::optional<SymbolRefAttr> baseRef = getBaseRef(); 202 if (baseName.has_value() == baseRef.has_value()) 203 return emitOpError() << "the base type or attribute should be specified by " 204 "either a name or a reference"; 205 206 if (baseName && 207 (baseName->empty() || ((*baseName)[0] != '!' && (*baseName)[0] != '#'))) 208 return emitOpError() << "the base type or attribute name should start with " 209 "'!' or '#'"; 210 211 return success(); 212 } 213 214 /// Finds whether the provided symbol is an IRDL type or attribute definition. 215 /// The source operation must be within a DialectOp. 216 static LogicalResult 217 checkSymbolIsTypeOrAttribute(SymbolTableCollection &symbolTable, 218 Operation *source, SymbolRefAttr symbol) { 219 Operation *targetOp = 220 irdl::lookupSymbolNearDialect(symbolTable, source, symbol); 221 222 if (!targetOp) 223 return source->emitOpError() << "symbol '" << symbol << "' not found"; 224 225 if (!isa<TypeOp, AttributeOp>(targetOp)) 226 return source->emitOpError() << "symbol '" << symbol 227 << "' does not refer to a type or attribute " 228 "definition (refers to '" 229 << targetOp->getName() << "')"; 230 231 return success(); 232 } 233 234 LogicalResult BaseOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 235 std::optional<SymbolRefAttr> baseRef = getBaseRef(); 236 if (!baseRef) 237 return success(); 238 239 return checkSymbolIsTypeOrAttribute(symbolTable, *this, *baseRef); 240 } 241 242 LogicalResult 243 ParametricOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 244 std::optional<SymbolRefAttr> baseRef = getBaseType(); 245 if (!baseRef) 246 return success(); 247 248 return checkSymbolIsTypeOrAttribute(symbolTable, *this, *baseRef); 249 } 250 251 /// Parse a value with its variadicity first. By default, the variadicity is 252 /// single. 253 /// 254 /// value-with-variadicity ::= ("single" | "optional" | "variadic")? 
ssa-value 255 static ParseResult 256 parseValueWithVariadicity(OpAsmParser &p, 257 OpAsmParser::UnresolvedOperand &operand, 258 VariadicityAttr &variadicityAttr) { 259 MLIRContext *ctx = p.getBuilder().getContext(); 260 261 // Parse the variadicity, if present 262 if (p.parseOptionalKeyword("single").succeeded()) { 263 variadicityAttr = VariadicityAttr::get(ctx, Variadicity::single); 264 } else if (p.parseOptionalKeyword("optional").succeeded()) { 265 variadicityAttr = VariadicityAttr::get(ctx, Variadicity::optional); 266 } else if (p.parseOptionalKeyword("variadic").succeeded()) { 267 variadicityAttr = VariadicityAttr::get(ctx, Variadicity::variadic); 268 } else { 269 variadicityAttr = VariadicityAttr::get(ctx, Variadicity::single); 270 } 271 272 // Parse the value 273 if (p.parseOperand(operand)) 274 return failure(); 275 return success(); 276 } 277 278 static ParseResult parseNamedValueListImpl( 279 OpAsmParser &p, SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands, 280 ArrayAttr &valueNamesAttr, VariadicityArrayAttr *variadicityAttr) { 281 Builder &builder = p.getBuilder(); 282 MLIRContext *ctx = builder.getContext(); 283 SmallVector<Attribute> valueNames; 284 SmallVector<VariadicityAttr> variadicities; 285 286 // Parse a single value with its variadicity 287 auto parseOne = [&] { 288 StringRef name; 289 OpAsmParser::UnresolvedOperand operand; 290 VariadicityAttr variadicity; 291 if (p.parseKeyword(&name) || p.parseColon()) 292 return failure(); 293 294 if (variadicityAttr) { 295 if (parseValueWithVariadicity(p, operand, variadicity)) 296 return failure(); 297 variadicities.push_back(variadicity); 298 } else { 299 if (p.parseOperand(operand)) 300 return failure(); 301 } 302 303 valueNames.push_back(StringAttr::get(ctx, name)); 304 operands.push_back(operand); 305 return success(); 306 }; 307 308 if (p.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, parseOne)) 309 return failure(); 310 valueNamesAttr = ArrayAttr::get(ctx, valueNames); 311 if 
(variadicityAttr) 312 *variadicityAttr = VariadicityArrayAttr::get(ctx, variadicities); 313 return success(); 314 } 315 316 /// Parse a list of named values. 317 /// 318 /// values ::= 319 /// `(` (named-value (`,` named-value)*)? `)` 320 /// named-value := bare-id `:` ssa-value 321 static ParseResult 322 parseNamedValueList(OpAsmParser &p, 323 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands, 324 ArrayAttr &valueNamesAttr) { 325 return parseNamedValueListImpl(p, operands, valueNamesAttr, nullptr); 326 } 327 328 /// Parse a list of named values with their variadicities first. By default, the 329 /// variadicity is single. 330 /// 331 /// values-with-variadicity ::= 332 /// `(` (value-with-variadicity (`,` value-with-variadicity)*)? `)` 333 /// value-with-variadicity 334 /// ::= bare-id `:` ("single" | "optional" | "variadic")? ssa-value 335 static ParseResult parseNamedValueListWithVariadicity( 336 OpAsmParser &p, SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands, 337 ArrayAttr &valueNamesAttr, VariadicityArrayAttr &variadicityAttr) { 338 return parseNamedValueListImpl(p, operands, valueNamesAttr, &variadicityAttr); 339 } 340 341 static void printNamedValueListImpl(OpAsmPrinter &p, Operation *op, 342 OperandRange operands, 343 ArrayAttr valueNamesAttr, 344 VariadicityArrayAttr variadicityAttr) { 345 p << "("; 346 interleaveComma(llvm::seq<int>(0, operands.size()), p, [&](int i) { 347 p << llvm::cast<StringAttr>(valueNamesAttr[i]).getValue() << ": "; 348 if (variadicityAttr) { 349 Variadicity variadicity = variadicityAttr[i].getValue(); 350 if (variadicity != Variadicity::single) { 351 p << stringifyVariadicity(variadicity) << " "; 352 } 353 } 354 p << operands[i]; 355 }); 356 p << ")"; 357 } 358 359 /// Print a list of named values. 360 /// 361 /// values ::= 362 /// `(` (named-value (`,` named-value)*)? 
`)` 363 /// named-value := bare-id `:` ssa-value 364 static void printNamedValueList(OpAsmPrinter &p, Operation *op, 365 OperandRange operands, 366 ArrayAttr valueNamesAttr) { 367 printNamedValueListImpl(p, op, operands, valueNamesAttr, nullptr); 368 } 369 370 /// Print a list of named values with their variadicities first. By default, the 371 /// variadicity is single. 372 /// 373 /// values-with-variadicity ::= 374 /// `(` (value-with-variadicity (`,` value-with-variadicity)*)? `)` 375 /// value-with-variadicity ::= 376 /// bare-id `:` ("single" | "optional" | "variadic")? ssa-value 377 static void printNamedValueListWithVariadicity( 378 OpAsmPrinter &p, Operation *op, OperandRange operands, 379 ArrayAttr valueNamesAttr, VariadicityArrayAttr variadicityAttr) { 380 printNamedValueListImpl(p, op, operands, valueNamesAttr, variadicityAttr); 381 } 382 383 static ParseResult 384 parseAttributesOp(OpAsmParser &p, 385 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands, 386 ArrayAttr &attrNamesAttr) { 387 Builder &builder = p.getBuilder(); 388 SmallVector<Attribute> attrNames; 389 if (succeeded(p.parseOptionalLBrace())) { 390 auto parseOperands = [&]() { 391 if (p.parseAttribute(attrNames.emplace_back()) || p.parseEqual() || 392 p.parseOperand(attrOperands.emplace_back())) 393 return failure(); 394 return success(); 395 }; 396 if (p.parseCommaSeparatedList(parseOperands) || p.parseRBrace()) 397 return failure(); 398 } 399 attrNamesAttr = builder.getArrayAttr(attrNames); 400 return success(); 401 } 402 403 static void printAttributesOp(OpAsmPrinter &p, AttributesOp op, 404 OperandRange attrArgs, ArrayAttr attrNames) { 405 if (attrNames.empty()) 406 return; 407 p << "{"; 408 interleaveComma(llvm::seq<int>(0, attrNames.size()), p, 409 [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; }); 410 p << '}'; 411 } 412 413 LogicalResult RegionOp::verify() { 414 if (IntegerAttr numberOfBlocks = getNumberOfBlocksAttr()) 415 if (int64_t number = 
numberOfBlocks.getInt(); number <= 0) { 416 return emitOpError("the number of blocks is expected to be >= 1 but got ") 417 << number; 418 } 419 return success(); 420 } 421 422 LogicalResult RegionsOp::verify() { 423 return verifyNames(*this, "region", getNames(), getNumOperands()); 424 } 425 426 #include "mlir/Dialect/IRDL/IR/IRDLInterfaces.cpp.inc" 427 428 #define GET_TYPEDEF_CLASSES 429 #include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc" 430 431 #include "mlir/Dialect/IRDL/IR/IRDLEnums.cpp.inc" 432 433 #define GET_ATTRDEF_CLASSES 434 #include "mlir/Dialect/IRDL/IR/IRDLAttributes.cpp.inc" 435 436 #define GET_OP_CLASSES 437 #include "mlir/Dialect/IRDL/IR/IRDLOps.cpp.inc" 438