//===- BuiltinTypes.h - MLIR Builtin Type Classes ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_IR_BUILTINTYPES_H
#define MLIR_IR_BUILTINTYPES_H

#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/Support/ADTExtras.h"

// Only the names of these LLVM types are needed here.
// NOTE(review): neither is used in the hand-written declarations of this
// header; presumably they are referenced by the tablegen-generated
// declarations included further below -- confirm before removing.
namespace llvm {
class BitVector;
struct fltSemantics;
} // namespace llvm

//===----------------------------------------------------------------------===//
// Tablegen Interface Declarations
//===----------------------------------------------------------------------===//

namespace mlir {
// Forward declarations. These let the hand-written classes below name types
// (e.g. RankedTensorType and MemRefType as `clone` return types) before the
// full definitions from the generated BuiltinTypes.h.inc are included.
class AffineExpr;
class AffineMap;
class IndexType;
class IntegerType;
class MemRefType;
class RankedTensorType;
class StringAttr;
class TypeRange;

// Type storage classes referenced by the builtin types.
// NOTE(review): presumably defined in the implementation files, not in this
// header -- confirm.
namespace detail {
struct FunctionTypeStorage;
struct IntegerTypeStorage;
struct TupleTypeStorage;
} // namespace detail

/// Type trait indicating that the type has value semantics.
/// The trait body is empty: it adds no methods and serves purely as a marker
/// that can be attached to a concrete type.
template <typename ConcreteType>
class ValueSemantics
    : public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};

//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//

/// Tensor types represent multi-dimensional arrays, and have two variants:
/// RankedTensorType and UnrankedTensorType.
/// Note: This class attaches the ShapedType trait to act as a mixin to
/// provide many useful utility functions. This inheritance has no effect
/// on derived tensor types.
55 class TensorType : public Type, public ShapedType::Trait<TensorType> { 56 public: 57 using Type::Type; 58 59 /// Returns the element type of this tensor type. 60 Type getElementType() const; 61 62 /// Returns if this type is ranked, i.e. it has a known number of dimensions. 63 bool hasRank() const; 64 65 /// Returns the shape of this tensor type. 66 ArrayRef<int64_t> getShape() const; 67 68 /// Clone this type with the given shape and element type. If the 69 /// provided shape is `std::nullopt`, the current shape of the type is used. 70 TensorType cloneWith(std::optional<ArrayRef<int64_t>> shape, 71 Type elementType) const; 72 73 // Make sure that base class overloads are visible. 74 using ShapedType::Trait<TensorType>::clone; 75 76 /// Return a clone of this type with the given new shape and element type. 77 /// The returned type is ranked, even if this type is unranked. 78 RankedTensorType clone(ArrayRef<int64_t> shape, Type elementType) const; 79 80 /// Return a clone of this type with the given new shape. The returned type 81 /// is ranked, even if this type is unranked. 82 RankedTensorType clone(ArrayRef<int64_t> shape) const; 83 84 /// Return true if the specified element type is ok in a tensor. 85 static bool isValidElementType(Type type); 86 87 /// Methods for support type inquiry through isa, cast, and dyn_cast. 88 static bool classof(Type type); 89 90 /// Allow implicit conversion to ShapedType. 91 operator ShapedType() const { return llvm::cast<ShapedType>(*this); } 92 }; 93 94 //===----------------------------------------------------------------------===// 95 // BaseMemRefType 96 //===----------------------------------------------------------------------===// 97 98 /// This class provides a shared interface for ranked and unranked memref types. 99 /// Note: This class attaches the ShapedType trait to act as a mixin to 100 /// provide many useful utility functions. This inheritance has no effect 101 /// on derived memref types. 
102 class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> { 103 public: 104 using Type::Type; 105 106 /// Returns the element type of this memref type. 107 Type getElementType() const; 108 109 /// Returns if this type is ranked, i.e. it has a known number of dimensions. 110 bool hasRank() const; 111 112 /// Returns the shape of this memref type. 113 ArrayRef<int64_t> getShape() const; 114 115 /// Clone this type with the given shape and element type. If the 116 /// provided shape is `std::nullopt`, the current shape of the type is used. 117 BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape, 118 Type elementType) const; 119 120 // Make sure that base class overloads are visible. 121 using ShapedType::Trait<BaseMemRefType>::clone; 122 123 /// Return a clone of this type with the given new shape and element type. 124 /// The returned type is ranked, even if this type is unranked. 125 MemRefType clone(ArrayRef<int64_t> shape, Type elementType) const; 126 127 /// Return a clone of this type with the given new shape. The returned type 128 /// is ranked, even if this type is unranked. 129 MemRefType clone(ArrayRef<int64_t> shape) const; 130 131 /// Return true if the specified element type is ok in a memref. 132 static bool isValidElementType(Type type); 133 134 /// Methods for support type inquiry through isa, cast, and dyn_cast. 135 static bool classof(Type type); 136 137 /// Returns the memory space in which data referred to by this memref resides. 138 Attribute getMemorySpace() const; 139 140 /// [deprecated] Returns the memory space in old raw integer representation. 141 /// New `Attribute getMemorySpace()` method should be used instead. 142 unsigned getMemorySpaceAsInt() const; 143 144 /// Allow implicit conversion to ShapedType. 
145 operator ShapedType() const { return llvm::cast<ShapedType>(*this); } 146 }; 147 148 } // namespace mlir 149 150 //===----------------------------------------------------------------------===// 151 // Tablegen Type Declarations 152 //===----------------------------------------------------------------------===// 153 154 #define GET_TYPEDEF_CLASSES 155 #include "mlir/IR/BuiltinTypes.h.inc" 156 157 namespace mlir { 158 #include "mlir/IR/BuiltinTypeConstraints.h.inc" 159 160 //===----------------------------------------------------------------------===// 161 // MemRefType 162 //===----------------------------------------------------------------------===// 163 164 /// This is a builder type that keeps local references to arguments. Arguments 165 /// that are passed into the builder must outlive the builder. 166 class MemRefType::Builder { 167 public: 168 // Build from another MemRefType. 169 explicit Builder(MemRefType other) 170 : shape(other.getShape()), elementType(other.getElementType()), 171 layout(other.getLayout()), memorySpace(other.getMemorySpace()) {} 172 173 // Build from scratch. 
174 Builder(ArrayRef<int64_t> shape, Type elementType) 175 : shape(shape), elementType(elementType) {} 176 177 Builder &setShape(ArrayRef<int64_t> newShape) { 178 shape = newShape; 179 return *this; 180 } 181 182 Builder &setElementType(Type newElementType) { 183 elementType = newElementType; 184 return *this; 185 } 186 187 Builder &setLayout(MemRefLayoutAttrInterface newLayout) { 188 layout = newLayout; 189 return *this; 190 } 191 192 Builder &setMemorySpace(Attribute newMemorySpace) { 193 memorySpace = newMemorySpace; 194 return *this; 195 } 196 197 operator MemRefType() { 198 return MemRefType::get(shape, elementType, layout, memorySpace); 199 } 200 201 private: 202 ArrayRef<int64_t> shape; 203 Type elementType; 204 MemRefLayoutAttrInterface layout; 205 Attribute memorySpace; 206 }; 207 208 //===----------------------------------------------------------------------===// 209 // RankedTensorType 210 //===----------------------------------------------------------------------===// 211 212 /// This is a builder type that keeps local references to arguments. Arguments 213 /// that are passed into the builder must outlive the builder. 214 class RankedTensorType::Builder { 215 public: 216 /// Build from another RankedTensorType. 217 explicit Builder(RankedTensorType other) 218 : shape(other.getShape()), elementType(other.getElementType()), 219 encoding(other.getEncoding()) {} 220 221 /// Build from scratch. 222 Builder(ArrayRef<int64_t> shape, Type elementType, Attribute encoding) 223 : shape(shape), elementType(elementType), encoding(encoding) {} 224 225 Builder &setShape(ArrayRef<int64_t> newShape) { 226 shape = newShape; 227 return *this; 228 } 229 230 Builder &setElementType(Type newElementType) { 231 elementType = newElementType; 232 return *this; 233 } 234 235 Builder &setEncoding(Attribute newEncoding) { 236 encoding = newEncoding; 237 return *this; 238 } 239 240 /// Erase a dim from shape @pos. 
241 Builder &dropDim(unsigned pos) { 242 assert(pos < shape.size() && "overflow"); 243 shape.erase(pos); 244 return *this; 245 } 246 247 /// Insert a val into shape @pos. 248 Builder &insertDim(int64_t val, unsigned pos) { 249 assert(pos <= shape.size() && "overflow"); 250 shape.insert(pos, val); 251 return *this; 252 } 253 254 operator RankedTensorType() { 255 return RankedTensorType::get(shape, elementType, encoding); 256 } 257 258 private: 259 CopyOnWriteArrayRef<int64_t> shape; 260 Type elementType; 261 Attribute encoding; 262 }; 263 264 //===----------------------------------------------------------------------===// 265 // VectorType 266 //===----------------------------------------------------------------------===// 267 268 /// This is a builder type that keeps local references to arguments. Arguments 269 /// that are passed into the builder must outlive the builder. 270 class VectorType::Builder { 271 public: 272 /// Build from another VectorType. 273 explicit Builder(VectorType other) 274 : elementType(other.getElementType()), shape(other.getShape()), 275 scalableDims(other.getScalableDims()) {} 276 277 /// Build from scratch. 278 Builder(ArrayRef<int64_t> shape, Type elementType, 279 ArrayRef<bool> scalableDims = {}) 280 : elementType(elementType), shape(shape), scalableDims(scalableDims) {} 281 282 Builder &setShape(ArrayRef<int64_t> newShape, 283 ArrayRef<bool> newIsScalableDim = {}) { 284 shape = newShape; 285 scalableDims = newIsScalableDim; 286 return *this; 287 } 288 289 Builder &setElementType(Type newElementType) { 290 elementType = newElementType; 291 return *this; 292 } 293 294 /// Erase a dim from shape @pos. 295 Builder &dropDim(unsigned pos) { 296 assert(pos < shape.size() && "overflow"); 297 shape.erase(pos); 298 if (!scalableDims.empty()) 299 scalableDims.erase(pos); 300 return *this; 301 } 302 303 /// Set a dim in shape @pos to val. 
304 Builder &setDim(unsigned pos, int64_t val) { 305 assert(pos < shape.size() && "overflow"); 306 shape.set(pos, val); 307 return *this; 308 } 309 310 operator VectorType() { 311 return VectorType::get(shape, elementType, scalableDims); 312 } 313 314 private: 315 Type elementType; 316 CopyOnWriteArrayRef<int64_t> shape; 317 CopyOnWriteArrayRef<bool> scalableDims; 318 }; 319 320 /// Given an `originalShape` and a `reducedShape` assumed to be a subset of 321 /// `originalShape` with some `1` entries erased, return the set of indices 322 /// that specifies which of the entries of `originalShape` are dropped to obtain 323 /// `reducedShape`. The returned mask can be applied as a projection to 324 /// `originalShape` to obtain the `reducedShape`. This mask is useful to track 325 /// which dimensions must be kept when e.g. compute MemRef strides under 326 /// rank-reducing operations. Return std::nullopt if reducedShape cannot be 327 /// obtained by dropping only `1` entries in `originalShape`. 328 /// If `matchDynamic` is true, then dynamic dims in `originalShape` and 329 /// `reducedShape` will be considered matching with non-dynamic dims, unless 330 /// the non-dynamic dim is from `originalShape` and equal to 1. For example, 331 /// in ([1, 3, ?], [?, 5]), the mask would be {1, 0, 0}, since 3 and 5 will 332 /// match with the corresponding dynamic dims. 333 std::optional<llvm::SmallDenseSet<unsigned>> 334 computeRankReductionMask(ArrayRef<int64_t> originalShape, 335 ArrayRef<int64_t> reducedShape, 336 bool matchDynamic = false); 337 338 /// Enum that captures information related to verifier error conditions on 339 /// slice insert/extract type of ops. 340 enum class SliceVerificationResult { 341 Success, 342 RankTooLarge, 343 SizeMismatch, 344 ElemTypeMismatch, 345 // Error codes to ops with a memory space and a layout annotation. 
346 MemSpaceMismatch, 347 LayoutMismatch 348 }; 349 350 /// Check if `originalType` can be rank reduced to `candidateReducedType` type 351 /// by dropping some dimensions with static size `1`. 352 /// Return `SliceVerificationResult::Success` on success or an appropriate error 353 /// code. 354 SliceVerificationResult isRankReducedType(ShapedType originalType, 355 ShapedType candidateReducedType); 356 357 //===----------------------------------------------------------------------===// 358 // Convenience wrappers for VectorType 359 // 360 // These are provided to allow idiomatic code like: 361 // * isa<vector::ScalableVectorType>(type) 362 //===----------------------------------------------------------------------===// 363 /// A vector type containing at least one scalable dimension. 364 class ScalableVectorType : public VectorType { 365 public: 366 using VectorType::VectorType; 367 368 static bool classof(Type type) { 369 auto vecTy = llvm::dyn_cast<VectorType>(type); 370 if (!vecTy) 371 return false; 372 return vecTy.isScalable(); 373 } 374 }; 375 376 /// A vector type with no scalable dimensions. 
377 class FixedVectorType : public VectorType { 378 public: 379 using VectorType::VectorType; 380 381 static bool classof(Type type) { 382 auto vecTy = llvm::dyn_cast<VectorType>(type); 383 if (!vecTy) 384 return false; 385 return !vecTy.isScalable(); 386 } 387 }; 388 389 //===----------------------------------------------------------------------===// 390 // Deferred Method Definitions 391 //===----------------------------------------------------------------------===// 392 393 inline bool BaseMemRefType::classof(Type type) { 394 return llvm::isa<MemRefType, UnrankedMemRefType>(type); 395 } 396 397 inline bool BaseMemRefType::isValidElementType(Type type) { 398 return type.isIntOrIndexOrFloat() || 399 llvm::isa<ComplexType, MemRefType, VectorType, UnrankedMemRefType>( 400 type) || 401 llvm::isa<MemRefElementTypeInterface>(type); 402 } 403 404 inline bool TensorType::classof(Type type) { 405 return llvm::isa<RankedTensorType, UnrankedTensorType>(type); 406 } 407 408 //===----------------------------------------------------------------------===// 409 // Type Utilities 410 //===----------------------------------------------------------------------===// 411 412 /// Given MemRef `sizes` that are either static or dynamic, returns the 413 /// canonical "contiguous" strides AffineExpr. Strides are multiplicative and 414 /// once a dynamic dimension is encountered, all canonical strides become 415 /// dynamic and need to be encoded with a different symbol. 416 /// For canonical strides expressions, the offset is always 0 and the fastest 417 /// varying stride is always `1`. 418 /// 419 /// Examples: 420 /// - memref<3x4x5xf32> has canonical stride expression 421 /// `20*exprs[0] + 5*exprs[1] + exprs[2]`. 422 /// - memref<3x?x5xf32> has canonical stride expression 423 /// `s0*exprs[0] + 5*exprs[1] + exprs[2]`. 424 /// - memref<3x4x?xf32> has canonical stride expression 425 /// `s1*exprs[0] + s0*exprs[1] + exprs[2]`. 
426 AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 427 ArrayRef<AffineExpr> exprs, 428 MLIRContext *context); 429 430 /// Return the result of makeCanonicalStrudedLayoutExpr for the common case 431 /// where `exprs` is {d0, d1, .., d_(sizes.size()-1)} 432 AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 433 MLIRContext *context); 434 } // namespace mlir 435 436 #endif // MLIR_IR_BUILTINTYPES_H 437