//===- BuiltinAttributeInterfaces.td - Attr interfaces -----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the definitions of the builtin attribute interfaces:
// TypedAttr, BlobAttr, ElementsAttr, and MemRefLayoutAttrInterface.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_IR_BUILTINATTRIBUTEINTERFACES_TD_
#define MLIR_IR_BUILTINATTRIBUTEINTERFACES_TD_

include "mlir/IR/OpBase.td"

//===----------------------------------------------------------------------===//
// TypedAttrInterface
//===----------------------------------------------------------------------===//

def TypedAttrInterface : AttrInterface<"TypedAttr"> {
  let cppNamespace = "::mlir";

  let description = [{
    This interface is used for attributes that have a type. The type of an
    attribute is understood to represent the type of the data contained in the
    attribute and is often used as the type of a value with this data.
  }];

  let methods = [InterfaceMethod<
    "Get the attribute's type",
    "::mlir::Type", "getType"
  >];
}

//===----------------------------------------------------------------------===//
// BlobAttrInterface
//===----------------------------------------------------------------------===//

def BlobAttrInterface : AttrInterface<"BlobAttr"> {
  let cppNamespace = "::mlir";
  let description = [{
    This interface allows an attribute to expose a blob of data without more
    information. The data must be stored so that it can be accessed as a
    contiguous ArrayRef.
  }];

  let methods = [InterfaceMethod<
    "Get the attribute's data",
    "::llvm::ArrayRef<char>", "getData"
  >];
}

//===----------------------------------------------------------------------===//
// ElementsAttrInterface
//===----------------------------------------------------------------------===//

def ElementsAttrInterface : AttrInterface<"ElementsAttr", [TypedAttrInterface]> {
  let cppNamespace = "::mlir";
  let description = [{
    This interface is used for attributes that contain the constant elements of
    a tensor or vector type. It allows for opaquely interacting with the
    elements of the underlying attribute, and most importantly allows for
    accessing the element values (including iteration) in any of the C++ data
    types supported by the underlying attribute.

    An attribute implementing this interface can expose the supported data types
    in two steps:

    * Define the set of iterable C++ data types:

      An attribute may define the set of iterable types by providing a definition
      of tuples `ContiguousIterableTypesT` and/or `NonContiguousIterableTypesT`.

      - `ContiguousIterableTypesT` should contain types which can be iterated
        contiguously. A contiguous range is an array-like range, such as
        ArrayRef, where all of the elements are laid out sequentially in memory.

      - `NonContiguousIterableTypesT` should contain types which cannot be
        iterated contiguously. A non-contiguous range implies no contiguity,
        whose elements may even be materialized when indexing, such as the case
        for a mapped_range.

      As an example, consider an attribute that only contains i64 elements, with
      the elements being stored within an ArrayRef. This attribute could
      potentially define the iterable types as follows:

    ```c++
    using ContiguousIterableTypesT = std::tuple<uint64_t>;
    using NonContiguousIterableTypesT = std::tuple<APInt, Attribute>;
    ```

    * Provide a `FailureOr<iterator> try_value_begin_impl(OverloadToken<T>) const`
      overload for each iterable type

      These overloads should return an iterator to the start of the range for the
      respective iterable type or fail if the type cannot be iterated. Consider
      the example i64 elements attribute described in the previous section. This
      attribute may define the try_value_begin_impl overloads like so:

    ```c++
    /// Provide begin iterators for the various iterable types.
    /// * uint64_t
    FailureOr<const uint64_t *>
    try_value_begin_impl(OverloadToken<uint64_t>) const {
      return getElements().begin();
    }
    /// * APInt
    auto try_value_begin_impl(OverloadToken<llvm::APInt>) const {
      auto it = llvm::map_range(getElements(), [=](uint64_t value) {
        return llvm::APInt(/*numBits=*/64, value);
      }).begin();
      return FailureOr<decltype(it)>(std::move(it));
    }
    /// * Attribute
    auto try_value_begin_impl(OverloadToken<mlir::Attribute>) const {
      mlir::Type elementType = getShapedType().getElementType();
      auto it = llvm::map_range(getElements(), [=](uint64_t value) {
        return mlir::IntegerAttr::get(elementType,
                                      llvm::APInt(/*numBits=*/64, value));
      }).begin();
      return FailureOr<decltype(it)>(std::move(it));
    }
    ```

    After the above, ElementsAttr will now be able to iterate over elements
    using each of the registered iterable data types:

    ```c++
    ElementsAttr attr = myI64ElementsAttr;

    // We can access value ranges for the data types via `getValues<T>`.
    for (uint64_t value : attr.getValues<uint64_t>())
      ...;
    for (llvm::APInt value : attr.getValues<llvm::APInt>())
      ...;
    for (mlir::IntegerAttr value : attr.getValues<mlir::IntegerAttr>())
      ...;

    // We can also access the value iterators directly.
    auto it = attr.value_begin<uint64_t>(), e = attr.value_end<uint64_t>();
    for (; it != e; ++it) {
      uint64_t value = *it;
      ...
    }
    ```

    ElementsAttr also supports failable access to iterators and ranges. This
    allows for safely checking if the attribute supports the data type, and can
    also allow for code to have fast paths for native data types.

    ```c++
    // Using `tryGetValues<T>`, we can also safely handle when the attribute
    // doesn't support the data type.
    if (auto range = attr.tryGetValues<uint64_t>()) {
      for (uint64_t value : *range)
        ...;
      return;
    }

    // We can also access the begin iterator safely, by using `try_value_begin`.
    if (auto safeIt = attr.try_value_begin<uint64_t>()) {
      auto it = *safeIt, e = attr.value_end<uint64_t>();
      for (; it != e; ++it) {
        uint64_t value = *it;
        ...
      }
      return;
    }
    ```
  }];
  let methods = [
    InterfaceMethod<[{
      This method returns an opaque range indexer for the given elementID, which
      corresponds to a desired C++ element data type. Returns the indexer if the
      attribute supports the given data type, failure otherwise.
    }],
    "::mlir::FailureOr<::mlir::detail::ElementsAttrIndexer>", "getValuesImpl",
    (ins "::mlir::TypeID":$elementID), [{}], /*defaultImplementation=*/[{
      auto result = getValueImpl(
        (typename ConcreteAttr::ContiguousIterableTypesT *)nullptr, elementID,
        /*isContiguous=*/std::true_type());
      if (succeeded(result))
        return std::move(result);

      return getValueImpl(
        (typename ConcreteAttr::NonContiguousIterableTypesT *)nullptr,
        elementID, /*isContiguous=*/std::false_type());
    }]>,
    InterfaceMethod<[{
      Returns true if the attribute elements correspond to a splat, i.e. that
      all elements of the attribute are the same value.
    }], "bool", "isSplat", (ins), [{}], /*defaultImplementation=*/[{
      // By default, only check for a single element splat.
      return $_attr.getNumElements() == 1;
    }]>,
    InterfaceMethod<[{
      Returns the shaped type of the elements attribute.
    }], "::mlir::ShapedType", "getShapedType", (ins), [{}], /*defaultImplementation=*/[{
      return $_attr.getType();
    }]>
  ];

  string ElementsAttrInterfaceAccessors = [{
    /// Return the number of elements held by this attribute.
    int64_t size() const { return getNumElements(); }

    /// Return if the attribute holds no elements.
    bool empty() const { return size() == 0; }
  }];

  let extraTraitClassDeclaration = [{
    // By default, no types are iterable.
    using ContiguousIterableTypesT = std::tuple<>;
    using NonContiguousIterableTypesT = std::tuple<>;

    //===------------------------------------------------------------------===//
    // Accessors
    //===------------------------------------------------------------------===//

    /// Return the element type of this ElementsAttr.
    Type getElementType() const {
      return ::mlir::ElementsAttr::getElementType($_attr);
    }

    /// Returns the number of elements held by this attribute.
    int64_t getNumElements() const {
      return ::mlir::ElementsAttr::getNumElements($_attr);
    }

    /// Return if the given 'index' refers to a valid element in this attribute.
    bool isValidIndex(ArrayRef<uint64_t> index) const {
      return ::mlir::ElementsAttr::isValidIndex($_attr, index);
    }

  protected:
    /// Returns the 1-dimensional flattened row-major index from the given
    /// multi-dimensional index.
    uint64_t getFlattenedIndex(ArrayRef<uint64_t> index) const {
      return ::mlir::ElementsAttr::getFlattenedIndex($_attr, index);
    }

    //===------------------------------------------------------------------===//
    // Value Iteration Internals
    //===------------------------------------------------------------------===//
  protected:
    /// This class is used to allow specifying function overloads for different
    /// types, without actually taking the types as parameters. This avoids the
    /// need to build complicated SFINAE to select specific overloads.
    template <typename T>
    struct OverloadToken {};

  private:
    /// This function unpacks the types within a given tuple and then forwards
    /// on to the unwrapped variant.
    template <typename... Ts, typename IsContiguousT>
    auto getValueImpl(std::tuple<Ts...> *, ::mlir::TypeID elementID,
                      IsContiguousT isContiguous) const {
      return getValueImpl<Ts...>(elementID, isContiguous);
    }
    /// Check to see if the given `elementID` matches the current type `T`. If
    /// it does, build a value result using the current type. If it doesn't,
    /// keep looking for the desired type.
    template <typename T, typename... Ts, typename IsContiguousT>
    auto getValueImpl(::mlir::TypeID elementID,
                      IsContiguousT isContiguous) const {
      if (::mlir::TypeID::get<T>() == elementID)
        return buildValueResult<T>(isContiguous);
      return getValueImpl<Ts...>(elementID, isContiguous);
    }
    /// Bottom out case for no matching type.
    template <typename IsContiguousT>
    ::mlir::FailureOr<::mlir::detail::ElementsAttrIndexer>
    getValueImpl(::mlir::TypeID, IsContiguousT) const {
      return ::mlir::failure();
    }

    /// Build an indexer for the given type `T`, which is represented via a
    /// contiguous range.
    template <typename T>
    ::mlir::FailureOr<::mlir::detail::ElementsAttrIndexer> buildValueResult(
        /*isContiguous*/std::true_type) const {
      if ($_attr.empty()) {
        return ::mlir::detail::ElementsAttrIndexer::contiguous<T>(
          /*isSplat=*/false, nullptr);
      }

      auto valueIt = $_attr.try_value_begin_impl(OverloadToken<T>());
      if (::mlir::failed(valueIt))
        return ::mlir::failure();
      return ::mlir::detail::ElementsAttrIndexer::contiguous(
        $_attr.isSplat(), &**valueIt);
    }
    /// Build an indexer for the given type `T`, which is represented via a
    /// non-contiguous range.
    template <typename T>
    ::mlir::FailureOr<::mlir::detail::ElementsAttrIndexer> buildValueResult(
        /*isContiguous*/std::false_type) const {
      auto valueIt = $_attr.try_value_begin_impl(OverloadToken<T>());
      if (::mlir::failed(valueIt))
        return ::mlir::failure();
      return ::mlir::detail::ElementsAttrIndexer::nonContiguous(
        $_attr.isSplat(), *valueIt);
    }

  public:
    //===------------------------------------------------------------------===//
    // Value Iteration
    //===------------------------------------------------------------------===//

    /// The iterator for the given element type T.
    template <typename T, typename AttrT = ConcreteAttr>
    using iterator = decltype(std::declval<AttrT>().template value_begin<T>());
    /// The iterator range over the given element T.
    template <typename T, typename AttrT = ConcreteAttr>
    using iterator_range =
        decltype(std::declval<AttrT>().template getValues<T>());

    /// Return an iterator to the first element of this attribute as a value of
    /// type `T`. The attribute must support iterating over `T` (i.e. its
    /// `try_value_begin_impl` overload for `T` must succeed); this is asserted.
    template <typename T>
    auto value_begin() const {
      auto valueIt = $_attr.try_value_begin_impl(OverloadToken<T>());
      // Dereferencing a failed FailureOr is undefined behavior; catch misuse
      // with an unsupported element type in debug builds.
      assert(::mlir::succeeded(valueIt) &&
             "attribute does not support iterating the requested element type");
      return std::move(*valueIt);
    }

    /// Return the elements of this attribute as a value of type 'T'.
    template <typename T>
    auto getValues() const {
      auto beginIt = $_attr.template value_begin<T>();
      return detail::ElementsAttrRange<decltype(beginIt)>(
        $_attr.getType(), beginIt, std::next(beginIt, size()));
    }
  }] # ElementsAttrInterfaceAccessors;

  let extraClassDeclaration = [{
    template <typename T>
    using iterator = detail::ElementsAttrIterator<T>;
    template <typename T>
    using iterator_range = detail::ElementsAttrRange<iterator<T>>;

    //===------------------------------------------------------------------===//
    // Accessors
    //===------------------------------------------------------------------===//

    /// Return the element type of this ElementsAttr.
    Type getElementType() const { return getElementType(*this); }
    static Type getElementType(ElementsAttr elementsAttr);

    /// Return if the given 'index' refers to a valid element in this attribute.
    bool isValidIndex(ArrayRef<uint64_t> index) const {
      return isValidIndex(*this, index);
    }
    static bool isValidIndex(ShapedType type, ArrayRef<uint64_t> index);
    static bool isValidIndex(ElementsAttr elementsAttr,
                             ArrayRef<uint64_t> index);

    /// Return the 1-dimensional flattened row-major index from the given
    /// multi-dimensional index.
    uint64_t getFlattenedIndex(ArrayRef<uint64_t> index) const {
      return getFlattenedIndex(*this, index);
    }
    static uint64_t getFlattenedIndex(Type type,
                                      ArrayRef<uint64_t> index);
    static uint64_t getFlattenedIndex(ElementsAttr elementsAttr,
                                      ArrayRef<uint64_t> index) {
      return getFlattenedIndex(elementsAttr.getShapedType(), index);
    }

    /// Returns the number of elements held by this attribute.
    int64_t getNumElements() const { return getNumElements(*this); }
    static int64_t getNumElements(ElementsAttr elementsAttr);

    //===------------------------------------------------------------------===//
    // Value Iteration
    //===------------------------------------------------------------------===//

    template <typename T>
    using DerivedAttrValueCheckT =
        std::enable_if_t<std::is_base_of<Attribute, T>::value &&
                         !std::is_same<Attribute, T>::value>;
    template <typename T, typename ResultT>
    using DefaultValueCheckT =
        std::enable_if_t<std::is_same<Attribute, T>::value ||
                             !std::is_base_of<Attribute, T>::value,
                         ResultT>;

    /// Return the splat value for this attribute. This asserts that the
    /// attribute corresponds to a splat.
    template <typename T>
    T getSplatValue() const {
      assert(isSplat() && "expected splat attribute");
      return *value_begin<T>();
    }

    /// Return the elements of this attribute as a value of type 'T'.
    template <typename T>
    DefaultValueCheckT<T, iterator_range<T>> getValues() const {
      return {getShapedType(), value_begin<T>(), value_end<T>()};
    }
    template <typename T>
    DefaultValueCheckT<T, iterator<T>> value_begin() const;
    template <typename T>
    DefaultValueCheckT<T, iterator<T>> value_end() const {
      return iterator<T>({}, size());
    }

    /// Return the held element values as a range of T, where T is a derived
    /// attribute type.
    template <typename T>
    using DerivedAttrValueIterator =
        llvm::mapped_iterator<iterator<Attribute>, T (*)(Attribute)>;
    template <typename T>
    using DerivedAttrValueIteratorRange =
        detail::ElementsAttrRange<DerivedAttrValueIterator<T>>;
    template <typename T, typename = DerivedAttrValueCheckT<T>>
    DerivedAttrValueIteratorRange<T> getValues() const {
      auto castFn = [](Attribute attr) { return ::llvm::cast<T>(attr); };
      return {getShapedType(), llvm::map_range(getValues<Attribute>(),
              static_cast<T (*)(Attribute)>(castFn))};
    }
    template <typename T, typename = DerivedAttrValueCheckT<T>>
    DerivedAttrValueIterator<T> value_begin() const {
      return getValues<T>().begin();
    }
    template <typename T, typename = DerivedAttrValueCheckT<T>>
    DerivedAttrValueIterator<T> value_end() const {
      return {value_end<Attribute>(), nullptr};
    }

    //===------------------------------------------------------------------===//
    // Failable Value Iteration

    /// If this attribute supports iterating over element values of type `T`,
    /// return the iterable range. Otherwise, return std::nullopt.
    template <typename T>
    DefaultValueCheckT<T, std::optional<iterator_range<T>>> tryGetValues() const {
      if (std::optional<iterator<T>> beginIt = try_value_begin<T>())
        return iterator_range<T>(getShapedType(), *beginIt, value_end<T>());
      return std::nullopt;
    }
    template <typename T>
    DefaultValueCheckT<T, std::optional<iterator<T>>> try_value_begin() const;

    /// If this attribute supports iterating over element values of type `T`,
    /// return the iterable range. Otherwise, return std::nullopt.
    template <typename T, typename = DerivedAttrValueCheckT<T>>
    std::optional<DerivedAttrValueIteratorRange<T>> tryGetValues() const {
      auto values = tryGetValues<Attribute>();
      if (!values)
        return std::nullopt;

      auto castFn = [](Attribute attr) { return ::llvm::cast<T>(attr); };
      return DerivedAttrValueIteratorRange<T>(
        getShapedType(),
        llvm::map_range(*values, static_cast<T (*)(Attribute)>(castFn))
      );
    }
    template <typename T, typename = DerivedAttrValueCheckT<T>>
    std::optional<DerivedAttrValueIterator<T>> try_value_begin() const {
      if (auto values = tryGetValues<T>())
        return values->begin();
      return std::nullopt;
    }
  }] # ElementsAttrInterfaceAccessors;
}

//===----------------------------------------------------------------------===//
// MemRefLayoutAttrInterface
//===----------------------------------------------------------------------===//

def MemRefLayoutAttrInterface : AttrInterface<"MemRefLayoutAttrInterface"> {
  let cppNamespace = "::mlir";

  let description = [{
    This interface is used for attributes that can represent the MemRef type's
    layout semantics, such as dimension order in the memory, strides and offsets.
    Such a layout attribute should be representable as a
    [semi-affine map](Affine.md/#semi-affine-maps).

    Note: the MemRef type's layout is assumed to represent simple strided buffer
    layout. For more complicated cases, like sparse storage buffers,
    it is preferable to use a separate type with a more specific layout, rather
    than introducing extra complexity to the builtin MemRef type.
  }];

  let methods = [
    InterfaceMethod<
      "Get the MemRef layout as an AffineMap, the method must not return NULL",
      "::mlir::AffineMap", "getAffineMap", (ins)
    >,

    InterfaceMethod<
      "Return true if this attribute represents the identity layout",
      "bool", "isIdentity", (ins),
      [{}],
      [{
        return $_attr.getAffineMap().isIdentity();
      }]
    >,

    InterfaceMethod<
      "Check if the current layout is applicable to the provided shape",
      "::llvm::LogicalResult", "verifyLayout",
      (ins "::llvm::ArrayRef<int64_t>":$shape,
           "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError),
      [{}],
      [{
        return ::mlir::detail::verifyAffineMapAsLayout($_attr.getAffineMap(),
                                                       shape, emitError);
      }]
    >
  ];
}

#endif // MLIR_IR_BUILTINATTRIBUTEINTERFACES_TD_