xref: /llvm-project/mlir/include/mlir/IR/EnumAttr.td (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
1//===-- EnumAttr.td - Enum attributes ----------------------*- tablegen -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef ENUMATTR_TD
10#define ENUMATTR_TD
11
12include "mlir/IR/AttrTypeBase.td"
13
14//===----------------------------------------------------------------------===//
15// Enum attribute kinds
16
// Metadata carried by every enum attribute case: the C++ enumerant name, its
// integer value, and its textual spelling.
class EnumAttrCaseInfo<string sym, int intVal, string strVal> {
  // Name of the enumerant in the generated C++ enum class.
  string symbol = sym;

  // Integer value of the enumerant. A negative value means the generated enum
  // class gets no explicit discriminator values for its enumerators.
  int value = intVal;

  // Textual form of the enumerant; may coincide with `symbol`.
  string str = strVal;
}
30
// A single case of an integer enum, backed by an IntegerAttr. It has an
// integer value, a string spelling, and a C++ symbol name that may differ from
// the spelling.
class IntEnumAttrCaseBase<I intType, string sym, string strVal, int intVal>
    : EnumAttrCaseInfo<sym, intVal, strVal>,
      SignlessIntegerAttrBase<intType, "case " # strVal> {
  // Only an IntegerAttr holding exactly this case's value satisfies the case.
  let predicate = CPred<
    "::llvm::cast<::mlir::IntegerAttr>($_self).getInt() == " # intVal>;
}
39
// Integer enum cases with a fixed storage width. When `str` is omitted, the
// textual spelling falls back to the C++ symbol name.
class I32EnumAttrCase<string sym, int val, string str = sym>
    : IntEnumAttrCaseBase<I32, sym, str, val>;
class I64EnumAttrCase<string sym, int val, string str = sym>
    : IntEnumAttrCaseBase<I64, sym, str, val>;
46
// A bit-enum case stored in an IntegerAttr. `val` is the complete integer with
// this case's bits set; it is *not* the index of a single bit.
class BitEnumAttrCaseBase<I intType, string sym, int val, string str = sym>
    : EnumAttrCaseInfo<sym, val, str>,
      SignlessIntegerAttrBase<intType, "case " # str>;
53
// Bit-enum cases for each supported storage width; `val` is the case's full
// bit pattern and `str` defaults to the symbol name.
class I8BitEnumAttrCase<string sym, int val, string str = sym>
    : BitEnumAttrCaseBase<I8, sym, val, str>;
class I16BitEnumAttrCase<string sym, int val, string str = sym>
    : BitEnumAttrCaseBase<I16, sym, val, str>;
class I32BitEnumAttrCase<string sym, int val, string str = sym>
    : BitEnumAttrCaseBase<I32, sym, val, str>;
class I64BitEnumAttrCase<string sym, int val, string str = sym>
    : BitEnumAttrCaseBase<I64, sym, val, str>;
62
// The distinguished "no bits set" case of a bit enum, whose value is always 0;
// one variant per storage width.
class I8BitEnumAttrCaseNone<string sym, string str = sym>
    : I8BitEnumAttrCase<sym, 0, str>;
class I16BitEnumAttrCaseNone<string sym, string str = sym>
    : I16BitEnumAttrCase<sym, 0, str>;
class I32BitEnumAttrCaseNone<string sym, string str = sym>
    : I32BitEnumAttrCase<sym, 0, str>;
class I64BitEnumAttrCaseNone<string sym, string str = sym>
    : I64BitEnumAttrCase<sym, 0, str>;
72
// A bit enum case for a single bit, specified by a bit position rather than a
// value. `pos` is the index of the bit and must lie in [0, bitwidth) of the
// underlying storage type; the resulting case value is `1 << pos`.
class BitEnumAttrCaseBit<I intType, string sym, int pos, string str = sym>
    : BitEnumAttrCaseBase<intType, sym, !shl(1, pos), str> {
  // Rejects both negative positions and positions past the storage width, so
  // the message reports the offending position and the valid range.
  assert !and(!ge(pos, 0), !lt(pos, intType.bitwidth)),
      "bit position " # pos # " is out of range [0, " # intType.bitwidth #
      ") of the underlying storage";
}
81
// Width-specific shorthands for BitEnumAttrCaseBit; `pos` is a bit index that
// is range-checked against the 8/16/32/64-bit storage respectively.
class I8BitEnumAttrCaseBit<string sym, int pos, string str = sym>
    : BitEnumAttrCaseBit<I8, sym, pos, str>;
class I16BitEnumAttrCaseBit<string sym, int pos, string str = sym>
    : BitEnumAttrCaseBit<I16, sym, pos, str>;
class I32BitEnumAttrCaseBit<string sym, int pos, string str = sym>
    : BitEnumAttrCaseBit<I32, sym, pos, str>;
class I64BitEnumAttrCaseBit<string sym, int pos, string str = sym>
    : BitEnumAttrCaseBit<I64, sym, pos, str>;
90
// An alias case naming a group of previously declared cases. Its value is the
// bitwise OR of all member cases' values.
class BitEnumAttrCaseGroup<I intType, string sym,
                           list<BitEnumAttrCaseBase> cases, string str = sym>
    : BitEnumAttrCaseBase<intType, sym,
          !foldl(0, cases, acc, groupCase, !or(acc, groupCase.value)),
          str>;
98
// Width-specific shorthands for BitEnumAttrCaseGroup.
class I8BitEnumAttrCaseGroup<string sym, list<BitEnumAttrCaseBase> cases,
                             string str = sym>
    : BitEnumAttrCaseGroup<I8, sym, cases, str>;
class I16BitEnumAttrCaseGroup<string sym, list<BitEnumAttrCaseBase> cases,
                              string str = sym>
    : BitEnumAttrCaseGroup<I16, sym, cases, str>;
class I32BitEnumAttrCaseGroup<string sym, list<BitEnumAttrCaseBase> cases,
                              string str = sym>
    : BitEnumAttrCaseGroup<I32, sym, cases, str>;
class I64BitEnumAttrCaseGroup<string sym, list<BitEnumAttrCaseBase> cases,
                              string str = sym>
    : BitEnumAttrCaseGroup<I64, sym, cases, str>;
111
// Additional information for an enum attribute.
class EnumAttrInfo<
    string name, list<EnumAttrCaseInfo> cases, Attr baseClass> :
      Attr<baseClass.predicate, baseClass.summary> {

  // Generate a description of this enum's members for the MLIR docs: one
  // bullet per case giving its string form and its C++ symbol.
  let description =
        "Enum cases:\n" # !interleave(
          !foreach(case, cases,
              "* " # case.str  # " (`" # case.symbol # "`)"), "\n");

  // The C++ enum class name.
  string className = name;

  // List of all accepted cases.
  list<EnumAttrCaseInfo> enumerants = cases;

  // The following fields are only used by the EnumsGen backend to generate
  // an enum class definition and conversion utility functions.

  // The underlying type for the C++ enum class. An empty string means the
  // underlying type is not explicitly specified.
  string underlyingType = "";

  // The name of the utility function that converts a value of the underlying
  // type to the corresponding symbol. It will have the following signature:
  //
  // ```c++
  // std::optional<<qualified-enum-class-name>> <fn-name>(<underlying-type>);
  // ```
  string underlyingToSymbolFnName = "symbolize" # name;

  // The name of the utility function that converts a string to the
  // corresponding symbol. It will have the following signature:
  //
  // ```c++
  // std::optional<<qualified-enum-class-name>> <fn-name>(llvm::StringRef);
  // ```
  string stringToSymbolFnName = "symbolize" # name;

  // The name of the utility function that converts a symbol to the
  // corresponding string. It will have the following signature:
  //
  // ```c++
  // <return-type> <fn-name>(<qualified-enum-class-name>);
  // ```
  string symbolToStringFnName = "stringify" # name;
  string symbolToStringFnRetType = "::llvm::StringRef";

  // The name of the utility function that returns the max enum value used
  // within the enum class. It will have the following signature:
  //
  // ```c++
  // static constexpr unsigned <fn-name>();
  // ```
  string maxEnumValFnName = "getMaxEnumValFor" # name;

  // Generate a specialized Attribute class for this enum.
  bit genSpecializedAttr = 1;
  // The underlying Attribute class, which holds the enum value.
  Attr baseAttrClass = baseClass;
  // The name of the specialized enum Attribute class. The suffix is written as
  // an explicit string literal instead of relying on TableGen converting the
  // bare identifier `Attr` to a string inside a paste expression.
  string specializedAttrClassName = name # "Attr";

  // Override Attr class fields for the specialized class. Records that set
  // `genSpecializedAttr = 0` keep the base attribute's predicate, storage
  // type, return type and constant builder instead.
  let predicate = !if(genSpecializedAttr,
    CPred<"::llvm::isa<" # cppNamespace # "::" # specializedAttrClassName # ">($_self)">,
    baseAttrClass.predicate);
  let storageType = !if(genSpecializedAttr,
    cppNamespace # "::" # specializedAttrClassName,
    baseAttrClass.storageType);
  let returnType = !if(genSpecializedAttr,
    cppNamespace # "::" # className,
    baseAttrClass.returnType);
  let constBuilderCall = !if(genSpecializedAttr,
    cppNamespace # "::" # specializedAttrClassName # "::get($_builder.getContext(), $0)",
    baseAttrClass.constBuilderCall);
  let valueType = baseAttrClass.valueType;

  // C++ type wrapped by the attribute.
  string cppType = cppNamespace # "::" # className;

  // Parser and printer code used by the EnumParameter class, to be provided by
  // derived classes.
  string parameterParser = ?;
  string parameterPrinter = ?;
}
199
// An enum attribute whose storage is an IntegerAttr.
//
// Op attributes of this kind are stored as IntegerAttr. Besides the signless
// integer type check, the generated verification accepts only integer values
// equal to one of the allowed cases.
class IntEnumAttrBase<I intType, list<IntEnumAttrCaseBase> cases,
                      string summary>
    : SignlessIntegerAttrBase<intType, summary> {
  defvar isSignlessInt = SignlessIntegerAttrBase<intType, summary>.predicate;
  defvar matchesSomeCase = Or<!foreach(c, cases, c.predicate)>;
  let predicate = And<[isSignlessInt, matchesSomeCase]>;
}
211
// Enum attribute info for integer enums where each case is a single integer
// value. If `summary` is empty, a summary listing the allowed integer values
// is synthesized.
class IntEnumAttr<I intType, string name, string summary,
                  list<IntEnumAttrCaseBase> cases> :
  EnumAttrInfo<name, cases,
    IntEnumAttrBase<intType, cases,
      !if(!empty(summary), "allowed " # intType.summary # " cases: " #
          !interleave(!foreach(case, cases, case.value), ", "),
          summary)>> {
  // Parse a keyword and pass it to `stringToSymbol`. If the keyword does not
  // name a case, emit an error listing the valid case spellings.
  let parameterParser = [{[&]() -> ::mlir::FailureOr<}] # cppType # [{> {
    auto loc = $_parser.getCurrentLocation();
    ::llvm::StringRef enumKeyword;
    if (::mlir::failed($_parser.parseKeyword(&enumKeyword)))
      return ::mlir::failure();
    auto maybeEnum = }] # cppNamespace # "::" #
                          stringToSymbolFnName # [{(enumKeyword);
    if (maybeEnum)
      return *maybeEnum;
    return {(::llvm::LogicalResult)($_parser.emitError(loc) << "expected " }] #
    [{<< "}] # cppType # [{" << " to be one of: " << }] #
    !interleave(!foreach(enum, enumerants, "\"" # enum.str # "\""),
                [{ << ", " << }]) # [{)};
  }()}];
  // Print the enum by calling `symbolToString`.
  let parameterPrinter = "$_printer << " # symbolToStringFnName # "($_self)";
}
238
// Integer enums with 32- and 64-bit storage. The C++ enum class generated for
// each uses the matching unsigned integer as its underlying type.
class I32EnumAttr<string name, string summary, list<I32EnumAttrCase> cases>
    : IntEnumAttr<I32, name, summary, cases> {
  let underlyingType = "uint32_t";
}
class I64EnumAttr<string name, string summary, list<I64EnumAttrCase> cases>
    : IntEnumAttr<I64, name, summary, cases> {
  let underlyingType = "uint64_t";
}
247
// A bit enum stored with an IntegerAttr.
//
// Op attributes of this kind are stored as IntegerAttr. Extra verification will
// be generated on the integer to make sure only allowed bits are set. Besides,
// helper methods are generated to parse a string separated with a specified
// delimiter to a symbol and vice versa.
class BitEnumAttrBase<I intType, list<BitEnumAttrCaseBase> cases,
                      string summary>
    : SignlessIntegerAttrBase<intType, summary> {
  // C++ expression for the OR of every case's bit pattern.
  // - Literals are spelled `ull` so that the complement taken below is
  //   64 bits wide. With plain `u` literals, a mask that fits in 32 bits is
  //   complemented as a 32-bit `unsigned`, zero-extended against the 64-bit
  //   `getZExtValue()`, and bits 32-63 of an i64 value go unchecked.
  // - The leading `0ull` keeps the expression well-formed for an empty case
  //   list; then only the value 0 is accepted.
  defvar knownBitsMask =
    !foldl("0ull", cases, acc, c, acc # "|" # c.value # "ull");
  let predicate = And<[
    SignlessIntegerAttrBase<intType, summary>.predicate,
    // Make sure we don't have unknown bit set.
    CPred<"!(::llvm::cast<::mlir::IntegerAttr>($_self).getValue().getZExtValue() & (~("
          # knownBitsMask #
          ")))">
  ]>;
}
265
// Enum attribute info for bit enums. A value is a bitwise OR of cases and is
// written as case spellings joined by `separator`.
class BitEnumAttr<I intType, string name, string summary,
                  list<BitEnumAttrCaseBase> cases>
    : EnumAttrInfo<name, cases, BitEnumAttrBase<intType, cases, summary>> {
  // Determine "valid" bits from enum cases for error checking.
  int validBits = !foldl(0, cases, value, bitcase, !or(value, bitcase.value));

  // We need to return a string because we may concatenate symbols for multiple
  // bits together.
  let symbolToStringFnRetType = "std::string";

  // The delimiter used to separate bit enum cases in strings. Only "|" and
  // "," (along with optional spaces) are supported due to the use of the
  // parseSeparatorFn in parameterParser below.
  // Spaces in the separator string are used for printing, but will be optional
  // for parsing.
  string separator = "|";
  assert !or(!ge(!find(separator, "|"), 0), !ge(!find(separator, ","), 0)),
      "separator must contain '|' or ',' for parameter parsing";

  // Parsing function that corresponds to the enum separator. Only
  // "," and "|" are supported by this definition.
  string parseSeparatorFn = !if(!ge(!find(separator, "|"), 0),
                                "parseOptionalVerticalBar",
                                "parseOptionalComma");

  // Parse one or more keywords separated by the separator, passing each to
  // `stringToSymbol` and OR-ing the results together. If a keyword does not
  // name a case, emit an error at that keyword's location listing the valid
  // case spellings. The location is taken per keyword, so an invalid case
  // after a separator is reported where it appears, not at the first keyword.
  let parameterParser = [{[&]() -> ::mlir::FailureOr<}] # cppType # [{> {
    }] # cppType # [{ flags = {};
    ::llvm::StringRef enumKeyword;
    do {
      auto loc = $_parser.getCurrentLocation();
      if (::mlir::failed($_parser.parseKeyword(&enumKeyword)))
        return ::mlir::failure();
      auto maybeEnum = }] # cppNamespace # "::" #
                            stringToSymbolFnName # [{(enumKeyword);
      if (!maybeEnum) {
          return {(::llvm::LogicalResult)($_parser.emitError(loc) << }] #
              [{"expected " << "}] # cppType # [{" << " to be one of: " << }] #
              !interleave(!foreach(enum, enumerants, "\"" # enum.str # "\""),
                          [{ << ", " << }]) # [{)};
      }
      flags = flags | *maybeEnum;
    } while(::mlir::succeeded($_parser.}] # parseSeparatorFn # [{()));
    return flags;
  }()}];
  // Print the enum by calling `symbolToString`.
  let parameterPrinter = "$_printer << " # symbolToStringFnName # "($_self)";

  // Controls whether only the "primary group" is printed for bits that are
  // members of case groups with all bits present.
  // - When this flag is 0, printing displays both the individual bit case
  //   names AND the names of all groups that each bit is contained in.
  // - When it is 1, each bit that is set AND is a member of a group with all
  //   bits set is printed only as its "primary group", for conciseness. The
  //   primary group is the first group with all bits set, in reverse
  //   declaration order.
  bit printBitEnumPrimaryGroups = 0;
}
324
// Bit enums for each supported storage width. The generated C++ enum class
// uses the matching unsigned integer as its underlying type.
class I8BitEnumAttr<string name, string summary,
                    list<BitEnumAttrCaseBase> cases>
    : BitEnumAttr<I8, name, summary, cases> {
  let underlyingType = "uint8_t";
}

class I16BitEnumAttr<string name, string summary,
                     list<BitEnumAttrCaseBase> cases>
    : BitEnumAttr<I16, name, summary, cases> {
  let underlyingType = "uint16_t";
}

class I32BitEnumAttr<string name, string summary,
                     list<BitEnumAttrCaseBase> cases>
    : BitEnumAttr<I32, name, summary, cases> {
  let underlyingType = "uint32_t";
}

class I64BitEnumAttr<string name, string summary,
                     list<BitEnumAttrCaseBase> cases>
    : BitEnumAttr<I64, name, summary, cases> {
  let underlyingType = "uint64_t";
}
348
// Exposes a C++ enum as an attribute parameter. Parsing and printing delegate
// to the enum info's `parameterParser` / `parameterPrinter`, which dispatch to
// the generated `stringToSymbol` and `symbolToString` utilities.
class EnumParameter<EnumAttrInfo enumInfo>
    : AttrParameter<enumInfo.cppNamespace # "::" # enumInfo.className,
                    "an enum of type " # enumInfo.className> {
  let printer = enumInfo.parameterPrinter;
  let parser = enumInfo.parameterParser;
}
358
// An attribute backed by a C++ enum. The attribute contains a single
// parameter `value` whose type is the C++ enum class.
//
// Example:
//
// ```
// def MyEnum : I32EnumAttr<"MyEnum", "a simple enum", [
//                            I32EnumAttrCase<"First", 0, "first">,
//                            I32EnumAttrCase<"Second", 1, "second">]> {
//   let genSpecializedAttr = 0;
// }
//
// def MyEnumAttr : EnumAttr<MyDialect, MyEnum, "enum">;
// ```
//
// By default, the assembly format of the attribute works best with operation
// assembly formats. For example:
//
// ```
// def MyOp : Op<MyDialect, "my_op"> {
//   let arguments = (ins MyEnumAttr:$enum);
//   let assemblyFormat = "$enum attr-dict";
// }
// ```
//
// The op will appear in the IR as `my_dialect.my_op first`. However, the
// generic format of the attribute will be `#my_dialect<"enum first">`. Override
// the attribute's assembly format as required.
class EnumAttr<Dialect dialect, EnumAttrInfo enumInfo, string name = "",
               list<Trait> traits = []>
    : AttrDef<dialect, enumInfo.className, traits> {
  let summary = enumInfo.summary;
  let description = enumInfo.description;

  // The backing enumeration.
  EnumAttrInfo enum = enumInfo;

  // Inherit the C++ namespace from the enum.
  let cppNamespace = enumInfo.cppNamespace;

  // Define a constant builder for the attribute to convert from C++ enums.
  let constBuilderCall = cppNamespace # "::" # cppClassName #
                         "::get($_builder.getContext(), $0)";

  // Op attribute getters should return the underlying C++ enum type.
  let returnType = enumInfo.cppNamespace # "::" # enumInfo.className;

  // Convert from attribute to the underlying C++ type in op getters.
  let convertFromStorage = "$_self.getValue()";

  // The enum attribute has one parameter: the C++ enum value.
  let parameters = (ins EnumParameter<enumInfo>:$value);

  // If a mnemonic was provided, use it to generate a custom assembly format.
  let mnemonic = name;

  // The default assembly format for enum attributes. Selected to best work with
  // operation assembly formats.
  let assemblyFormat = "$value";
}
419
// Resolves the string spelling `case` of an enum case to the C++ expression
// `<cppType>::<symbol>` naming its enumerant. Asserts if no case has that
// spelling.
class _symbolToValue<EnumAttrInfo enumAttrInfo, string case> {
  defvar matches =
    !filter(candidate, enumAttrInfo.enumerants, !eq(candidate.str, case));

  assert !not(!empty(matches)), "failed to find enum-case '" # case # "'";

  // Guard `!head` with `!empty` so a missing case does not trigger a second,
  // less helpful error; the assertion above reports it properly.
  defvar foundSymbol = !if(!empty(matches), "", !head(matches).symbol);
  string value = enumAttrInfo.cppType # "::" # foundSymbol;
}
431
// Splits `case` at each `|` and resolves every piece to its C++ enumerant via
// `_symbolToValue`, re-joining the resolved pieces with `|`.
class _bitSymbolsToValue<BitEnumAttr bitEnumAttr, string case> {
  defvar sepIdx = !find(case, "|");
  defvar isLastPiece = !eq(sepIdx, -1);

  // Resolve the piece before the first `|`, then recurse on the remainder.
  string value = !if(isLastPiece,
    _symbolToValue<bitEnumAttr, case>.value,
    _symbolToValue<bitEnumAttr, !substr(case, 0, sepIdx)>.value # "|" #
      _bitSymbolsToValue<bitEnumAttr, !substr(case, !add(sepIdx, 1))>.value);
}
443
// Shared implementation of ConstantEnumCase: a ConstantAttr whose C++ value is
// the enumerant named by `case`. For bit enums, `case` may list several
// `|`-separated spellings; for other enums it must be a single spelling.
class ConstantEnumCaseBase<Attr attribute, EnumAttrInfo enumAttrInfo,
                           string case>
    : ConstantAttr<attribute,
        !if(!isa<BitEnumAttr>(enumAttrInfo),
            _bitSymbolsToValue<!cast<BitEnumAttr>(enumAttrInfo), case>.value,
            _symbolToValue<enumAttrInfo, case>.value)>;
452
/// Attribute constraint matching a constant enum case. `attribute` should be
/// one of `EnumAttrInfo` or `EnumAttr` and `case` the string representation
/// of an enum case. Multiple enum values of a bit-enum can be combined using
/// `|` as a separator. Note that there mustn't be any whitespace around the
/// separator.
/// This attribute constraint is additionally buildable, making it possible to
/// use it in result patterns.
///
/// Examples:
/// * ConstantEnumCase<Arith_IntegerOverflowAttr, "nsw|nuw">
/// * ConstantEnumCase<Arith_CmpIPredicateAttr, "slt">
class ConstantEnumCase<Attr attribute, string case>
  : ConstantEnumCaseBase<attribute,
    !if(!isa<EnumAttrInfo>(attribute), !cast<EnumAttrInfo>(attribute),
          !cast<EnumAttr>(attribute).enum), case> {
  assert !or(!isa<EnumAttr>(attribute), !isa<EnumAttrInfo>(attribute)),
    "attribute must be one of 'EnumAttr' or 'EnumAttrInfo'";
}
471
472#endif // ENUMATTR_TD
473