xref: /llvm-project/mlir/lib/Dialect/IRDL/IR/IRDL.cpp (revision 69d3ba3db922fca8cfc47b5f115b6bea6a737aab)
1 //===- IRDL.cpp - IRDL dialect ----------------------------------*- C++ -*-===//
2 //
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/IRDL/IR/IRDL.h"
10 #include "mlir/Dialect/IRDL/IRDLSymbols.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/BuiltinAttributes.h"
13 #include "mlir/IR/Diagnostics.h"
14 #include "mlir/IR/DialectImplementation.h"
15 #include "mlir/IR/ExtensibleDialect.h"
16 #include "mlir/IR/OpDefinition.h"
17 #include "mlir/IR/OpImplementation.h"
18 #include "mlir/IR/Operation.h"
19 #include "mlir/Support/LLVM.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SetOperations.h"
22 #include "llvm/ADT/SmallString.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/IR/Metadata.h"
26 #include "llvm/Support/Casting.h"
27 
28 using namespace mlir;
29 using namespace mlir::irdl;
30 
31 //===----------------------------------------------------------------------===//
32 // IRDL dialect.
33 //===----------------------------------------------------------------------===//
34 
35 #include "mlir/Dialect/IRDL/IR/IRDL.cpp.inc"
36 
37 #include "mlir/Dialect/IRDL/IR/IRDLDialect.cpp.inc"
38 
39 void IRDLDialect::initialize() {
40   addOperations<
41 #define GET_OP_LIST
42 #include "mlir/Dialect/IRDL/IR/IRDLOps.cpp.inc"
43       >();
44   addTypes<
45 #define GET_TYPEDEF_LIST
46 #include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc"
47       >();
48   addAttributes<
49 #define GET_ATTRDEF_LIST
50 #include "mlir/Dialect/IRDL/IR/IRDLAttributes.cpp.inc"
51       >();
52 }
53 
54 //===----------------------------------------------------------------------===//
55 // Parsing/Printing/Verifying
56 //===----------------------------------------------------------------------===//
57 
58 /// Parse a region, and add a single block if the region is empty.
59 /// If no region is parsed, create a new region with a single empty block.
60 static ParseResult parseSingleBlockRegion(OpAsmParser &p, Region &region) {
61   auto regionParseRes = p.parseOptionalRegion(region);
62   if (regionParseRes.has_value() && failed(regionParseRes.value()))
63     return failure();
64 
65   // If the region is empty, add a single empty block.
66   if (region.empty())
67     region.push_back(new Block());
68 
69   return success();
70 }
71 
72 static void printSingleBlockRegion(OpAsmPrinter &p, Operation *op,
73                                    Region &region) {
74   if (!region.getBlocks().front().empty())
75     p.printRegion(region);
76 }
77 
78 LogicalResult DialectOp::verify() {
79   if (!Dialect::isValidNamespace(getName()))
80     return emitOpError("invalid dialect name");
81   return success();
82 }
83 
84 LogicalResult OperationOp::verifyRegions() {
85   // Stores pairs of value kinds and the list of names of values of this kind in
86   // the operation.
87   SmallVector<std::tuple<StringRef, llvm::SmallDenseSet<StringRef>>> valueNames;
88 
89   auto insertNames = [&](StringRef kind, ArrayAttr names) {
90     llvm::SmallDenseSet<StringRef> nameSet;
91     nameSet.reserve(names.size());
92     for (auto name : names)
93       nameSet.insert(llvm::cast<StringAttr>(name).getValue());
94     valueNames.emplace_back(kind, std::move(nameSet));
95   };
96 
97   for (Operation &op : getBody().getOps()) {
98     TypeSwitch<Operation *>(&op)
99         .Case<OperandsOp>(
100             [&](OperandsOp op) { insertNames("operands", op.getNames()); })
101         .Case<ResultsOp>(
102             [&](ResultsOp op) { insertNames("results", op.getNames()); })
103         .Case<RegionsOp>(
104             [&](RegionsOp op) { insertNames("regions", op.getNames()); });
105   }
106 
107   // Verify that no two operand, result or region share the same name.
108   // The absence of duplicates within each value kind is checked by the
109   // associated operation's verifier.
110   for (size_t i : llvm::seq(valueNames.size())) {
111     for (size_t j : llvm::seq(i + 1, valueNames.size())) {
112       auto [lhs, lhsSet] = valueNames[i];
113       auto &[rhs, rhsSet] = valueNames[j];
114       llvm::set_intersect(lhsSet, rhsSet);
115       if (!lhsSet.empty())
116         return emitOpError("contains a value named '")
117                << *lhsSet.begin() << "' for both its " << lhs << " and " << rhs;
118     }
119   }
120 
121   return success();
122 }
123 
124 static LogicalResult verifyNames(Operation *op, StringRef kindName,
125                                  ArrayAttr names, size_t numOperands) {
126   if (numOperands != names.size())
127     return op->emitOpError()
128            << "the number of " << kindName
129            << "s and their names must be "
130               "the same, but got "
131            << numOperands << " and " << names.size() << " respectively";
132 
133   DenseMap<StringRef, size_t> nameMap;
134   for (auto [i, name] : llvm::enumerate(names)) {
135     StringRef nameRef = llvm::cast<StringAttr>(name).getValue();
136     if (nameRef.empty())
137       return op->emitOpError()
138              << "name of " << kindName << " #" << i << " is empty";
139     if (!llvm::isAlpha(nameRef[0]) && nameRef[0] != '_')
140       return op->emitOpError()
141              << "name of " << kindName << " #" << i
142              << " must start with either a letter or an underscore";
143     if (llvm::any_of(nameRef,
144                      [](char c) { return !llvm::isAlnum(c) && c != '_'; }))
145       return op->emitOpError()
146              << "name of " << kindName << " #" << i
147              << " must contain only letters, digits and underscores";
148     if (nameMap.contains(nameRef))
149       return op->emitOpError() << "name of " << kindName << " #" << i
150                                << " is a duplicate of the name of " << kindName
151                                << " #" << nameMap[nameRef];
152     nameMap.insert({nameRef, i});
153   }
154 
155   return success();
156 }
157 
158 LogicalResult ParametersOp::verify() {
159   return verifyNames(*this, "parameter", getNames(), getNumOperands());
160 }
161 
162 template <typename ValueListOp>
163 static LogicalResult verifyOperandsResultsCommon(ValueListOp op,
164                                                  StringRef kindName) {
165   size_t numVariadicities = op.getVariadicity().size();
166   size_t numOperands = op.getNumOperands();
167 
168   if (numOperands != numVariadicities)
169     return op.emitOpError()
170            << "the number of " << kindName
171            << "s and their variadicities must be "
172               "the same, but got "
173            << numOperands << " and " << numVariadicities << " respectively";
174 
175   return verifyNames(op, kindName, op.getNames(), numOperands);
176 }
177 
178 LogicalResult OperandsOp::verify() {
179   return verifyOperandsResultsCommon(*this, "operand");
180 }
181 
182 LogicalResult ResultsOp::verify() {
183   return verifyOperandsResultsCommon(*this, "result");
184 }
185 
186 LogicalResult AttributesOp::verify() {
187   size_t namesSize = getAttributeValueNames().size();
188   size_t valuesSize = getAttributeValues().size();
189 
190   if (namesSize != valuesSize)
191     return emitOpError()
192            << "the number of attribute names and their constraints must be "
193               "the same but got "
194            << namesSize << " and " << valuesSize << " respectively";
195 
196   return success();
197 }
198 
199 LogicalResult BaseOp::verify() {
200   std::optional<StringRef> baseName = getBaseName();
201   std::optional<SymbolRefAttr> baseRef = getBaseRef();
202   if (baseName.has_value() == baseRef.has_value())
203     return emitOpError() << "the base type or attribute should be specified by "
204                             "either a name or a reference";
205 
206   if (baseName &&
207       (baseName->empty() || ((*baseName)[0] != '!' && (*baseName)[0] != '#')))
208     return emitOpError() << "the base type or attribute name should start with "
209                             "'!' or '#'";
210 
211   return success();
212 }
213 
214 /// Finds whether the provided symbol is an IRDL type or attribute definition.
215 /// The source operation must be within a DialectOp.
216 static LogicalResult
217 checkSymbolIsTypeOrAttribute(SymbolTableCollection &symbolTable,
218                              Operation *source, SymbolRefAttr symbol) {
219   Operation *targetOp =
220       irdl::lookupSymbolNearDialect(symbolTable, source, symbol);
221 
222   if (!targetOp)
223     return source->emitOpError() << "symbol '" << symbol << "' not found";
224 
225   if (!isa<TypeOp, AttributeOp>(targetOp))
226     return source->emitOpError() << "symbol '" << symbol
227                                  << "' does not refer to a type or attribute "
228                                     "definition (refers to '"
229                                  << targetOp->getName() << "')";
230 
231   return success();
232 }
233 
234 LogicalResult BaseOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
235   std::optional<SymbolRefAttr> baseRef = getBaseRef();
236   if (!baseRef)
237     return success();
238 
239   return checkSymbolIsTypeOrAttribute(symbolTable, *this, *baseRef);
240 }
241 
242 LogicalResult
243 ParametricOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
244   std::optional<SymbolRefAttr> baseRef = getBaseType();
245   if (!baseRef)
246     return success();
247 
248   return checkSymbolIsTypeOrAttribute(symbolTable, *this, *baseRef);
249 }
250 
251 /// Parse a value with its variadicity first. By default, the variadicity is
252 /// single.
253 ///
254 /// value-with-variadicity ::= ("single" | "optional" | "variadic")? ssa-value
255 static ParseResult
256 parseValueWithVariadicity(OpAsmParser &p,
257                           OpAsmParser::UnresolvedOperand &operand,
258                           VariadicityAttr &variadicityAttr) {
259   MLIRContext *ctx = p.getBuilder().getContext();
260 
261   // Parse the variadicity, if present
262   if (p.parseOptionalKeyword("single").succeeded()) {
263     variadicityAttr = VariadicityAttr::get(ctx, Variadicity::single);
264   } else if (p.parseOptionalKeyword("optional").succeeded()) {
265     variadicityAttr = VariadicityAttr::get(ctx, Variadicity::optional);
266   } else if (p.parseOptionalKeyword("variadic").succeeded()) {
267     variadicityAttr = VariadicityAttr::get(ctx, Variadicity::variadic);
268   } else {
269     variadicityAttr = VariadicityAttr::get(ctx, Variadicity::single);
270   }
271 
272   // Parse the value
273   if (p.parseOperand(operand))
274     return failure();
275   return success();
276 }
277 
278 static ParseResult parseNamedValueListImpl(
279     OpAsmParser &p, SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
280     ArrayAttr &valueNamesAttr, VariadicityArrayAttr *variadicityAttr) {
281   Builder &builder = p.getBuilder();
282   MLIRContext *ctx = builder.getContext();
283   SmallVector<Attribute> valueNames;
284   SmallVector<VariadicityAttr> variadicities;
285 
286   // Parse a single value with its variadicity
287   auto parseOne = [&] {
288     StringRef name;
289     OpAsmParser::UnresolvedOperand operand;
290     VariadicityAttr variadicity;
291     if (p.parseKeyword(&name) || p.parseColon())
292       return failure();
293 
294     if (variadicityAttr) {
295       if (parseValueWithVariadicity(p, operand, variadicity))
296         return failure();
297       variadicities.push_back(variadicity);
298     } else {
299       if (p.parseOperand(operand))
300         return failure();
301     }
302 
303     valueNames.push_back(StringAttr::get(ctx, name));
304     operands.push_back(operand);
305     return success();
306   };
307 
308   if (p.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, parseOne))
309     return failure();
310   valueNamesAttr = ArrayAttr::get(ctx, valueNames);
311   if (variadicityAttr)
312     *variadicityAttr = VariadicityArrayAttr::get(ctx, variadicities);
313   return success();
314 }
315 
316 /// Parse a list of named values.
317 ///
318 /// values ::=
319 ///   `(` (named-value (`,` named-value)*)? `)`
320 /// named-value := bare-id `:` ssa-value
321 static ParseResult
322 parseNamedValueList(OpAsmParser &p,
323                     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
324                     ArrayAttr &valueNamesAttr) {
325   return parseNamedValueListImpl(p, operands, valueNamesAttr, nullptr);
326 }
327 
328 /// Parse a list of named values with their variadicities first. By default, the
329 /// variadicity is single.
330 ///
331 /// values-with-variadicity ::=
332 ///   `(` (value-with-variadicity (`,` value-with-variadicity)*)? `)`
333 /// value-with-variadicity
334 ///   ::= bare-id `:` ("single" | "optional" | "variadic")? ssa-value
335 static ParseResult parseNamedValueListWithVariadicity(
336     OpAsmParser &p, SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
337     ArrayAttr &valueNamesAttr, VariadicityArrayAttr &variadicityAttr) {
338   return parseNamedValueListImpl(p, operands, valueNamesAttr, &variadicityAttr);
339 }
340 
341 static void printNamedValueListImpl(OpAsmPrinter &p, Operation *op,
342                                     OperandRange operands,
343                                     ArrayAttr valueNamesAttr,
344                                     VariadicityArrayAttr variadicityAttr) {
345   p << "(";
346   interleaveComma(llvm::seq<int>(0, operands.size()), p, [&](int i) {
347     p << llvm::cast<StringAttr>(valueNamesAttr[i]).getValue() << ": ";
348     if (variadicityAttr) {
349       Variadicity variadicity = variadicityAttr[i].getValue();
350       if (variadicity != Variadicity::single) {
351         p << stringifyVariadicity(variadicity) << " ";
352       }
353     }
354     p << operands[i];
355   });
356   p << ")";
357 }
358 
359 /// Print a list of named values.
360 ///
361 /// values ::=
362 ///   `(` (named-value (`,` named-value)*)? `)`
363 /// named-value := bare-id `:` ssa-value
364 static void printNamedValueList(OpAsmPrinter &p, Operation *op,
365                                 OperandRange operands,
366                                 ArrayAttr valueNamesAttr) {
367   printNamedValueListImpl(p, op, operands, valueNamesAttr, nullptr);
368 }
369 
370 /// Print a list of named values with their variadicities first. By default, the
371 /// variadicity is single.
372 ///
373 /// values-with-variadicity ::=
374 ///   `(` (value-with-variadicity (`,` value-with-variadicity)*)? `)`
375 /// value-with-variadicity ::=
376 ///   bare-id `:` ("single" | "optional" | "variadic")? ssa-value
377 static void printNamedValueListWithVariadicity(
378     OpAsmPrinter &p, Operation *op, OperandRange operands,
379     ArrayAttr valueNamesAttr, VariadicityArrayAttr variadicityAttr) {
380   printNamedValueListImpl(p, op, operands, valueNamesAttr, variadicityAttr);
381 }
382 
383 static ParseResult
384 parseAttributesOp(OpAsmParser &p,
385                   SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands,
386                   ArrayAttr &attrNamesAttr) {
387   Builder &builder = p.getBuilder();
388   SmallVector<Attribute> attrNames;
389   if (succeeded(p.parseOptionalLBrace())) {
390     auto parseOperands = [&]() {
391       if (p.parseAttribute(attrNames.emplace_back()) || p.parseEqual() ||
392           p.parseOperand(attrOperands.emplace_back()))
393         return failure();
394       return success();
395     };
396     if (p.parseCommaSeparatedList(parseOperands) || p.parseRBrace())
397       return failure();
398   }
399   attrNamesAttr = builder.getArrayAttr(attrNames);
400   return success();
401 }
402 
403 static void printAttributesOp(OpAsmPrinter &p, AttributesOp op,
404                               OperandRange attrArgs, ArrayAttr attrNames) {
405   if (attrNames.empty())
406     return;
407   p << "{";
408   interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
409                   [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
410   p << '}';
411 }
412 
413 LogicalResult RegionOp::verify() {
414   if (IntegerAttr numberOfBlocks = getNumberOfBlocksAttr())
415     if (int64_t number = numberOfBlocks.getInt(); number <= 0) {
416       return emitOpError("the number of blocks is expected to be >= 1 but got ")
417              << number;
418     }
419   return success();
420 }
421 
422 LogicalResult RegionsOp::verify() {
423   return verifyNames(*this, "region", getNames(), getNumOperands());
424 }
425 
426 #include "mlir/Dialect/IRDL/IR/IRDLInterfaces.cpp.inc"
427 
428 #define GET_TYPEDEF_CLASSES
429 #include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc"
430 
431 #include "mlir/Dialect/IRDL/IR/IRDLEnums.cpp.inc"
432 
433 #define GET_ATTRDEF_CLASSES
434 #include "mlir/Dialect/IRDL/IR/IRDLAttributes.cpp.inc"
435 
436 #define GET_OP_CLASSES
437 #include "mlir/Dialect/IRDL/IR/IRDLOps.cpp.inc"
438