xref: /llvm-project/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td (revision 72e8b9aeaa3f584f223bc59924812df69a09a48b)
1//===- BuiltinAttributeInterfaces.td - Attr interfaces -----*- tablegen -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the builtin attribute interfaces, e.g. ElementsAttr.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_IR_BUILTINATTRIBUTEINTERFACES_TD_
14#define MLIR_IR_BUILTINATTRIBUTEINTERFACES_TD_
15
16include "mlir/IR/OpBase.td"
17
18//===----------------------------------------------------------------------===//
19// TypedAttrInterface
20//===----------------------------------------------------------------------===//
21
// ODS record for the `TypedAttr` attribute interface, placed in C++ namespace
// `::mlir`. It declares a single argument-less method, `getType`, returning
// `::mlir::Type`; no method body or default implementation is supplied here.
22def TypedAttrInterface : AttrInterface<"TypedAttr"> {
23  let cppNamespace = "::mlir";
24
25  let description = [{
26    This interface is used for attributes that have a type. The type of an
27    attribute is understood to represent the type of the data contained in the
28    attribute and is often used as the type of a value with this data.
29  }];
30
31  let methods = [InterfaceMethod<
32    "Get the attribute's type",
33    "::mlir::Type", "getType"
34  >];
35}
36
37//===----------------------------------------------------------------------===//
38// BlobAttrInterface
39//===----------------------------------------------------------------------===//
40
// ODS record for the `BlobAttr` attribute interface, placed in C++ namespace
// `::mlir`. It declares a single argument-less method, `getData`, returning the
// attribute's bytes as `::llvm::ArrayRef<char>`; no method body or default
// implementation is supplied here.
41def BlobAttrInterface : AttrInterface<"BlobAttr"> {
42  let cppNamespace = "::mlir";
43  let description = [{
44    This interface allows an attribute to expose a blob of data without more
45    information. The data must be stored so that it can be accessed as a
46    contiguous ArrayRef.
47  }];
48
49  let methods = [InterfaceMethod<
50    "Get the attribute's data",
51    "::llvm::ArrayRef<char>", "getData"
52  >];
53}
54
55//===----------------------------------------------------------------------===//
56// ElementsAttrInterface
57//===----------------------------------------------------------------------===//
58
59def ElementsAttrInterface : AttrInterface<"ElementsAttr", [TypedAttrInterface]> {
60  let cppNamespace = "::mlir";
61  let description = [{
62    This interface is used for attributes that contain the constant elements of
63    a tensor or vector type. It allows for opaquely interacting with the
64    elements of the underlying attribute, and most importantly allows for
65    accessing the element values (including iteration) in any of the C++ data
66    types supported by the underlying attribute.
67
68    An attribute implementing this interface can expose the supported data types
69    in two steps:
70
71    * Define the set of iterable C++ data types:
72
73    An attribute may define the set of iterable types by providing a definition
74    of tuples `ContiguousIterableTypesT` and/or `NonContiguousIterableTypesT`.
75
76    -  `ContiguousIterableTypesT` should contain types which can be iterated
77       contiguously. A contiguous range is an array-like range, such as
78       ArrayRef, where all of the elements are laid out sequentially in memory.
79
80    -  `NonContiguousIterableTypesT` should contain types which can not be
81       iterated contiguously. A non-contiguous range implies no contiguity,
82       whose elements may even be materialized when indexing, such as the case
83       for a mapped_range.
84
85    As an example, consider an attribute that only contains i64 elements, with
86    the elements being stored within an ArrayRef. This attribute could
87    potentially define the iterable types as so:
88
89    ```c++
90    using ContiguousIterableTypesT = std::tuple<uint64_t>;
91    using NonContiguousIterableTypesT = std::tuple<APInt, Attribute>;
92    ```
93
94    * Provide a `FailureOr<iterator> try_value_begin_impl(OverloadToken<T>) const`
95      overload for each iterable type
96
97    These overloads should return an iterator to the start of the range for the
98    respective iterable type or fail if the type cannot be iterated. Consider
99    the example i64 elements attribute described in the previous section. This
100    attribute may define the try_value_begin_impl overloads like so:
101
102    ```c++
103    /// Provide begin iterators for the various iterable types.
104    /// * uint64_t
105    FailureOr<const uint64_t *>
106    try_value_begin_impl(OverloadToken<uint64_t>) const {
107      return getElements().begin();
108    }
109    /// * APInt
110    auto try_value_begin_impl(OverloadToken<llvm::APInt>) const {
111      auto it = llvm::map_range(getElements(), [=](uint64_t value) {
112        return llvm::APInt(/*numBits=*/64, value);
113      }).begin();
114      return FailureOr<decltype(it)>(std::move(it));
115    }
116    /// * Attribute
117    auto try_value_begin_impl(OverloadToken<mlir::Attribute>) const {
118      mlir::Type elementType = getShapedType().getElementType();
119      auto it = llvm::map_range(getElements(), [=](uint64_t value) {
120        return mlir::IntegerAttr::get(elementType,
121                                      llvm::APInt(/*numBits=*/64, value));
122      }).begin();
123      return FailureOr<decltype(it)>(std::move(it));
124    }
125    ```
126
127    After the above, ElementsAttr will now be able to iterate over elements
128    using each of the registered iterable data types:
129
130    ```c++
131    ElementsAttr attr = myI64ElementsAttr;
132
133    // We can access value ranges for the data types via `getValues<T>`.
134    for (uint64_t value : attr.getValues<uint64_t>())
135      ...;
136    for (llvm::APInt value : attr.getValues<llvm::APInt>())
137      ...;
138    for (mlir::IntegerAttr value : attr.getValues<mlir::IntegerAttr>())
139      ...;
140
141    // We can also access the value iterators directly.
142    auto it = attr.value_begin<uint64_t>(), e = attr.value_end<uint64_t>();
143    for (; it != e; ++it) {
144      uint64_t value = *it;
145      ...
146    }
147    ```
148
149    ElementsAttr also supports failable access to iterators and ranges. This
150    allows for safely checking if the attribute supports the data type, and can
151    also allow for code to have fast paths for native data types.
152
153    ```c++
154    // Using `tryGetValues<T>`, we can also safely handle when the attribute
155    // doesn't support the data type.
156    if (auto range = attr.tryGetValues<uint64_t>()) {
157      for (uint64_t value : *range)
158        ...;
159      return;
160    }
161
162    // We can also access the begin iterator safely, by using `try_value_begin`.
163    if (auto safeIt = attr.try_value_begin<uint64_t>()) {
164      auto it = *safeIt, e = attr.value_end<uint64_t>();
165      for (; it != e; ++it) {
166        uint64_t value = *it;
167        ...
168      }
169      return;
170    }
171    ```
172  }];
173  let methods = [
174    InterfaceMethod<[{
175      This method returns an opaque range indexer for the given elementID, which
176      corresponds to a desired C++ element data type. Returns the indexer if the
177      attribute supports the given data type, failure otherwise.
178    }],
179    "::mlir::FailureOr<::mlir::detail::ElementsAttrIndexer>", "getValuesImpl",
180    (ins "::mlir::TypeID":$elementID), [{}], /*defaultImplementation=*/[{
      // Search the types listed in `ContiguousIterableTypesT` first. The null
      // tuple pointer only carries the type list; it is never dereferenced.
181      auto result = getValueImpl(
182        (typename ConcreteAttr::ContiguousIterableTypesT *)nullptr, elementID,
183        /*isContiguous=*/std::true_type());
      // Return the local by name: wrapping it in std::move would inhibit copy
      // elision without enabling anything the implicit move does not.
184      if (succeeded(result))
185        return result;
186
      // Otherwise fall back to the types in `NonContiguousIterableTypesT`.
187      return getValueImpl(
188        (typename ConcreteAttr::NonContiguousIterableTypesT *)nullptr,
189        elementID, /*isContiguous=*/std::false_type());
190    }]>,
191    InterfaceMethod<[{
192      Returns true if the attribute elements correspond to a splat, i.e. that
193      all elements of the attribute are the same value.
194    }], "bool", "isSplat", (ins), [{}], /*defaultImplementation=*/[{
195        // By default, only check for a single element splat.
196        return $_attr.getNumElements() == 1;
197    }]>,
198    InterfaceMethod<[{
199      Returns the shaped type of the elements attribute.
200    }], "::mlir::ShapedType", "getShapedType", (ins), [{}], /*defaultImplementation=*/[{
201        return $_attr.getType();
202    }]>
203  ];
204
205  string ElementsAttrInterfaceAccessors = [{
206    /// Number of elements stored in this attribute.
207    int64_t size() const { return getNumElements(); }
208
209    /// True when this attribute stores zero elements.
210    bool empty() const { return getNumElements() == 0; }
211  }];
212
213  let extraTraitClassDeclaration = [{
214    // By default, no types are iterable. Both aliases are empty tuples here;
    // the interface description explains how an attribute supplies its own
    // definitions to advertise iterable element types.
215    using ContiguousIterableTypesT = std::tuple<>;
216    using NonContiguousIterableTypesT = std::tuple<>;
217
218    //===------------------------------------------------------------------===//
219    // Accessors
220    //===------------------------------------------------------------------===//
    // Each accessor below forwards to the static helper of the same name on
    // `::mlir::ElementsAttr`, passing the concrete attribute.
221
222    /// Return the element type of this ElementsAttr.
223    Type getElementType() const {
224      return ::mlir::ElementsAttr::getElementType($_attr);
225    }
226
227    /// Return the number of elements held by this attribute.
228    int64_t getNumElements() const {
229      return ::mlir::ElementsAttr::getNumElements($_attr);
230    }
231
232    /// Return if the given 'index' refers to a valid element in this attribute.
233    bool isValidIndex(ArrayRef<uint64_t> index) const {
234      return ::mlir::ElementsAttr::isValidIndex($_attr, index);
235    }
236
237  protected:
238    /// Return the 1-dimensional flattened row-major index from the given
239    /// multi-dimensional index.
240    uint64_t getFlattenedIndex(ArrayRef<uint64_t> index) const {
241      return ::mlir::ElementsAttr::getFlattenedIndex($_attr, index);
242    }
243
244    //===------------------------------------------------------------------===//
245    // Value Iteration Internals
246    //===------------------------------------------------------------------===//
247  protected:
248    /// This class is used to allow specifying function overloads for different
249    /// types, without actually taking the types as parameters. This avoids the
250    /// need to build complicated SFINAE to select specific overloads.
251    template <typename T>
252    struct OverloadToken {};
253
254  private:
255    /// This function unpacks the types within a given tuple and then forwards
256    /// on to the unwrapped variant. The tuple pointer is used only for template
    /// argument deduction; its value is never read.
257    template <typename... Ts, typename IsContiguousT>
258    auto getValueImpl(std::tuple<Ts...> *, ::mlir::TypeID elementID,
259                      IsContiguousT isContiguous) const {
260      return getValueImpl<Ts...>(elementID, isContiguous);
261    }
262    /// Check to see if the given `elementID` matches the current type `T`. If
263    /// it does, build a value result using the current type. If it doesn't,
264    /// keep looking for the desired type among the remaining `Ts...`.
265    template <typename T, typename... Ts, typename IsContiguousT>
266    auto getValueImpl(::mlir::TypeID elementID,
267                      IsContiguousT isContiguous) const {
268      if (::mlir::TypeID::get<T>() == elementID)
269        return buildValueResult<T>(isContiguous);
270      return getValueImpl<Ts...>(elementID, isContiguous);
271    }
272    /// Bottom out case for no matching type: the type list is exhausted, so
    /// report failure.
273    template <typename IsContiguousT>
274    ::mlir::FailureOr<::mlir::detail::ElementsAttrIndexer>
275    getValueImpl(::mlir::TypeID, IsContiguousT) const {
276      return failure();
277    }
278
279    /// Build an indexer for the given type `T`, which is represented via a
280    /// contiguous range. Returns failure if the attribute's
    /// `try_value_begin_impl` overload for `T` fails.
281    template <typename T>
282    ::mlir::FailureOr<::mlir::detail::ElementsAttrIndexer> buildValueResult(
283        /*isContiguous*/std::true_type) const {
      // With no elements, skip the begin-iterator query entirely and hand back
      // a non-splat indexer over a null pointer.
284      if ($_attr.empty()) {
285        return ::mlir::detail::ElementsAttrIndexer::contiguous<T>(
286          /*isSplat=*/false, nullptr);
287      }
288
289      auto valueIt = $_attr.try_value_begin_impl(OverloadToken<T>());
290      if (::mlir::failed(valueIt))
291        return ::mlir::failure();
      // `*valueIt` unwraps the successful result to the begin iterator; the
      // second `*` reaches the first element, whose address seeds the indexer.
292      return ::mlir::detail::ElementsAttrIndexer::contiguous(
293        $_attr.isSplat(), &**valueIt);
294    }
295    /// Build an indexer for the given type `T`, which is represented via a
296    /// non-contiguous range. Returns failure if the attribute's
    /// `try_value_begin_impl` overload for `T` fails. Unlike the contiguous
    /// variant, the begin iterator itself (not an element address) is stored.
297    template <typename T>
298    ::mlir::FailureOr<::mlir::detail::ElementsAttrIndexer> buildValueResult(
299        /*isContiguous*/std::false_type) const {
300      auto valueIt = $_attr.try_value_begin_impl(OverloadToken<T>());
301      if (::mlir::failed(valueIt))
302        return ::mlir::failure();
303      return ::mlir::detail::ElementsAttrIndexer::nonContiguous(
304        $_attr.isSplat(), *valueIt);
305    }
306
307  public:
308    //===------------------------------------------------------------------===//
309    // Value Iteration
310    //===------------------------------------------------------------------===//
311
312    /// The iterator for the given element type T, i.e. whatever
    /// `AttrT::value_begin<T>()` returns.
313    template <typename T, typename AttrT = ConcreteAttr>
314    using iterator = decltype(std::declval<AttrT>().template value_begin<T>());
315    /// The iterator range over the given element T, i.e. whatever
    /// `AttrT::getValues<T>()` returns.
316    template <typename T, typename AttrT = ConcreteAttr>
317    using iterator_range =
318        decltype(std::declval<AttrT>().template getValues<T>());
319
320    /// Return an iterator to the first element of this attribute as a value of
321    /// type `T`.
    /// NOTE(review): the result of `try_value_begin_impl` is dereferenced
    /// without a failure check, so this presumably must only be called for a
    /// `T` the attribute can actually iterate -- confirm with callers.
322    template <typename T>
323    auto value_begin() const {
324      return *$_attr.try_value_begin_impl(OverloadToken<T>());
325    }
326
327    /// Return the elements of this attribute as a value of type 'T'. The range
    /// spans from `value_begin<T>()` to that iterator advanced by `size()`.
328    template <typename T>
329    auto getValues() const {
330      auto beginIt = $_attr.template value_begin<T>();
331      return detail::ElementsAttrRange<decltype(beginIt)>(
332        $_attr.getType(), beginIt, std::next(beginIt, size()));
333    }
334  }] # ElementsAttrInterfaceAccessors;
335
336  let extraClassDeclaration = [{
    /// Type-erased iterator and range over element values of type `T`.
337    template <typename T>
338    using iterator = detail::ElementsAttrIterator<T>;
339    template <typename T>
340    using iterator_range = detail::ElementsAttrRange<iterator<T>>;
341
342    //===------------------------------------------------------------------===//
343    // Accessors
344    //===------------------------------------------------------------------===//
    // Each member accessor forwards to a static overload taking the attribute
    // explicitly. Static overloads that are only declared here have no
    // definition in this block.
345
346    /// Return the element type of this ElementsAttr.
347    Type getElementType() const { return getElementType(*this); }
348    static Type getElementType(ElementsAttr elementsAttr);
349
350    /// Return if the given 'index' refers to a valid element in this attribute.
351    bool isValidIndex(ArrayRef<uint64_t> index) const {
352      return isValidIndex(*this, index);
353    }
354    static bool isValidIndex(ShapedType type, ArrayRef<uint64_t> index);
355    static bool isValidIndex(ElementsAttr elementsAttr,
356                             ArrayRef<uint64_t> index);
357
358    /// Return the 1 dimensional flattened row-major index from the given
359    /// multi-dimensional index.
360    uint64_t getFlattenedIndex(ArrayRef<uint64_t> index) const {
361      return getFlattenedIndex(*this, index);
362    }
363    static uint64_t getFlattenedIndex(Type type,
364                                      ArrayRef<uint64_t> index);
    /// Attribute overload: flattens against the attribute's shaped type.
365    static uint64_t getFlattenedIndex(ElementsAttr elementsAttr,
366                                      ArrayRef<uint64_t> index) {
367      return getFlattenedIndex(elementsAttr.getShapedType(), index);
368    }
369
370    /// Return the number of elements held by this attribute.
371    int64_t getNumElements() const { return getNumElements(*this); }
372    static int64_t getNumElements(ElementsAttr elementsAttr);
373
374    //===------------------------------------------------------------------===//
375    // Value Iteration
376    //===------------------------------------------------------------------===//
377
    /// SFINAE helper: well-formed only when `T` derives from Attribute and is
    /// not Attribute itself. Selects the "derived attribute" overloads below.
378    template <typename T>
379    using DerivedAttrValueCheckT =
380        std::enable_if_t<std::is_base_of<Attribute, T>::value &&
381                         !std::is_same<Attribute, T>::value>;
    /// SFINAE helper: yields `ResultT` when `T` is Attribute itself or is not
    /// derived from Attribute, i.e. the exact complement of the check above.
382    template <typename T, typename ResultT>
383    using DefaultValueCheckT =
384        std::enable_if_t<std::is_same<Attribute, T>::value ||
385                             !std::is_base_of<Attribute, T>::value,
386                         ResultT>;
387
388    /// Return the splat value for this attribute. This asserts that the
389    /// attribute corresponds to a splat, then reads the first element.
390    template <typename T>
391    T getSplatValue() const {
392      assert(isSplat() && "expected splat attribute");
393      return *value_begin<T>();
394    }
395
396    /// Return the elements of this attribute as a value of type 'T'.
397    template <typename T>
398    DefaultValueCheckT<T, iterator_range<T>> getValues() const {
399      return {getShapedType(), value_begin<T>(), value_end<T>()};
400    }
    /// Begin iterator for type `T`; declared only, no definition in this block.
401    template <typename T>
402    DefaultValueCheckT<T, iterator<T>> value_begin() const;
    /// End iterator for type `T`: built from an empty indexer (`{}`) and the
    /// index `size()`.
403    template <typename T>
404    DefaultValueCheckT<T, iterator<T>> value_end() const {
405      return iterator<T>({}, size());
406    }
407
408    /// Return the held element values as a range of T, where T is a derived
409    /// attribute type. Elements are iterated as Attribute and each one is
    /// converted with `llvm::cast<T>`.
410    template <typename T>
411    using DerivedAttrValueIterator =
412      llvm::mapped_iterator<iterator<Attribute>, T (*)(Attribute)>;
413    template <typename T>
414    using DerivedAttrValueIteratorRange =
415      detail::ElementsAttrRange<DerivedAttrValueIterator<T>>;
416    template <typename T, typename = DerivedAttrValueCheckT<T>>
417    DerivedAttrValueIteratorRange<T> getValues() const {
      // The lambda is converted to a plain function pointer so the iterator
      // type matches `DerivedAttrValueIterator<T>`.
418      auto castFn = [](Attribute attr) { return ::llvm::cast<T>(attr); };
419      return {getShapedType(), llvm::map_range(getValues<Attribute>(),
420              static_cast<T (*)(Attribute)>(castFn))};
421    }
422    template <typename T, typename = DerivedAttrValueCheckT<T>>
423    DerivedAttrValueIterator<T> value_begin() const {
424      return getValues<T>().begin();
425    }
    /// NOTE(review): the end iterator is built with a null mapping function;
    /// presumably it is only ever compared against, never dereferenced.
426    template <typename T, typename = DerivedAttrValueCheckT<T>>
427    DerivedAttrValueIterator<T> value_end() const {
428      return {value_end<Attribute>(), nullptr};
429    }
430
431    //===------------------------------------------------------------------===//
432    // Failable Value Iteration
    //===------------------------------------------------------------------===//
433
434    /// If this attribute supports iterating over element values of type `T`,
435    /// return the iterable range. Otherwise, return std::nullopt.
436    template <typename T>
437    DefaultValueCheckT<T, std::optional<iterator_range<T>>> tryGetValues() const {
438      if (std::optional<iterator<T>> beginIt = try_value_begin<T>())
439        return iterator_range<T>(getShapedType(), *beginIt, value_end<T>());
440      return std::nullopt;
441    }
    /// Failable begin iterator for type `T`; declared only, no definition in
    /// this block.
442    template <typename T>
443    DefaultValueCheckT<T, std::optional<iterator<T>>> try_value_begin() const;
444
445    /// Derived attribute variant: if the elements can be iterated as Attribute,
446    /// return a range that casts each element to `T` with `llvm::cast`.
    /// Otherwise, return std::nullopt.
447    template <typename T, typename = DerivedAttrValueCheckT<T>>
448    std::optional<DerivedAttrValueIteratorRange<T>> tryGetValues() const {
449      auto values = tryGetValues<Attribute>();
450      if (!values)
451        return std::nullopt;
452
453      auto castFn = [](Attribute attr) { return ::llvm::cast<T>(attr); };
454      return DerivedAttrValueIteratorRange<T>(
455        getShapedType(),
456        llvm::map_range(*values, static_cast<T (*)(Attribute)>(castFn))
457      );
458    }
    /// Derived attribute variant of `try_value_begin`: the begin iterator of
    /// `tryGetValues<T>()` when that succeeds, std::nullopt otherwise.
459    template <typename T, typename = DerivedAttrValueCheckT<T>>
460    std::optional<DerivedAttrValueIterator<T>> try_value_begin() const {
461      if (auto values = tryGetValues<T>())
462        return values->begin();
463      return std::nullopt;
464    }
465  }] # ElementsAttrInterfaceAccessors;
466}
467
468//===----------------------------------------------------------------------===//
469// MemRefLayoutAttrInterface
470//===----------------------------------------------------------------------===//
471
// ODS record for the MemRef layout attribute interface (C++ namespace
// `::mlir`). `getAffineMap` has no default implementation here; `isIdentity`
// and `verifyLayout` default to implementations derived from that affine map.
471def MemRefLayoutAttrInterface : AttrInterface<"MemRefLayoutAttrInterface"> {
472  let cppNamespace = "::mlir";
473
474  let description = [{
475    This interface is used for attributes that can represent the MemRef type's
476    layout semantics, such as dimension order in the memory, strides and offsets.
477    Such a layout attribute should be representable as a
478    [semi-affine map](Affine.md/#semi-affine-maps).
479
480    Note: the MemRef type's layout is assumed to represent a simple strided
481    buffer layout. For more complicated cases, like sparse storage buffers,
482    it is preferable to use a separate type with a more specific layout, rather
483    than introducing extra complexity to the builtin MemRef type.
484  }];
485
486  let methods = [
487    InterfaceMethod<
488      "Get the MemRef layout as an AffineMap, the method must not return NULL",
489      "::mlir::AffineMap", "getAffineMap", (ins)
490    >,
491
    // Default: the layout is the identity exactly when its affine map is.
492    InterfaceMethod<
493      "Return true if this attribute represents the identity layout",
494      "bool", "isIdentity", (ins),
495      [{}],
496      [{
497        return $_attr.getAffineMap().isIdentity();
498      }]
499    >,
500
    // Default: delegate to `verifyAffineMapAsLayout` on the attribute's affine
    // map, forwarding `shape` and the `emitError` diagnostic callback.
501    InterfaceMethod<
502      "Check if the current layout is applicable to the provided shape",
503      "::llvm::LogicalResult", "verifyLayout",
504      (ins "::llvm::ArrayRef<int64_t>":$shape,
505           "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError),
506      [{}],
507      [{
508        return ::mlir::detail::verifyAffineMapAsLayout($_attr.getAffineMap(),
509                                                       shape, emitError);
510      }]
511    >
512  ];
513}
515
516#endif // MLIR_IR_BUILTINATTRIBUTEINTERFACES_TD_
517