xref: /llvm-project/mlir/include/mlir/IR/BuiltinTypes.h (revision 6aaa8f25b66dc1fef4e465f274ee40b82d632988)
1 //===- BuiltinTypes.h - MLIR Builtin Type Classes ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_BUILTINTYPES_H
10 #define MLIR_IR_BUILTINTYPES_H
11 
12 #include "mlir/IR/BuiltinAttributeInterfaces.h"
13 #include "mlir/IR/BuiltinTypeInterfaces.h"
14 #include "mlir/Support/ADTExtras.h"
15 
16 namespace llvm {
17 class BitVector;
18 struct fltSemantics;
19 } // namespace llvm
20 
21 //===----------------------------------------------------------------------===//
22 // Tablegen Interface Declarations
23 //===----------------------------------------------------------------------===//
24 
25 namespace mlir {
26 class AffineExpr;
27 class AffineMap;
28 class IndexType;
29 class IntegerType;
30 class MemRefType;
31 class RankedTensorType;
32 class StringAttr;
33 class TypeRange;
34 
35 namespace detail {
36 struct FunctionTypeStorage;
37 struct IntegerTypeStorage;
38 struct TupleTypeStorage;
39 } // namespace detail
40 
41 /// Type trait indicating that the type has value semantics.
42 template <typename ConcreteType>
43 class ValueSemantics
44     : public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};
45 
46 //===----------------------------------------------------------------------===//
47 // TensorType
48 //===----------------------------------------------------------------------===//
49 
50 /// Tensor types represent multi-dimensional arrays, and have two variants:
51 /// RankedTensorType and UnrankedTensorType.
52 /// Note: This class attaches the ShapedType trait to act as a mixin to
53 ///       provide many useful utility functions. This inheritance has no effect
54 ///       on derived tensor types.
55 class TensorType : public Type, public ShapedType::Trait<TensorType> {
56 public:
57   using Type::Type;
58 
59   /// Returns the element type of this tensor type.
60   Type getElementType() const;
61 
62   /// Returns if this type is ranked, i.e. it has a known number of dimensions.
63   bool hasRank() const;
64 
65   /// Returns the shape of this tensor type.
66   ArrayRef<int64_t> getShape() const;
67 
68   /// Clone this type with the given shape and element type. If the
69   /// provided shape is `std::nullopt`, the current shape of the type is used.
70   TensorType cloneWith(std::optional<ArrayRef<int64_t>> shape,
71                        Type elementType) const;
72 
73   // Make sure that base class overloads are visible.
74   using ShapedType::Trait<TensorType>::clone;
75 
76   /// Return a clone of this type with the given new shape and element type.
77   /// The returned type is ranked, even if this type is unranked.
78   RankedTensorType clone(ArrayRef<int64_t> shape, Type elementType) const;
79 
80   /// Return a clone of this type with the given new shape. The returned type
81   /// is ranked, even if this type is unranked.
82   RankedTensorType clone(ArrayRef<int64_t> shape) const;
83 
84   /// Return true if the specified element type is ok in a tensor.
85   static bool isValidElementType(Type type);
86 
87   /// Methods for support type inquiry through isa, cast, and dyn_cast.
88   static bool classof(Type type);
89 
90   /// Allow implicit conversion to ShapedType.
91   operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
92 };
93 
94 //===----------------------------------------------------------------------===//
95 // BaseMemRefType
96 //===----------------------------------------------------------------------===//
97 
98 /// This class provides a shared interface for ranked and unranked memref types.
99 /// Note: This class attaches the ShapedType trait to act as a mixin to
100 ///       provide many useful utility functions. This inheritance has no effect
101 ///       on derived memref types.
102 class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
103 public:
104   using Type::Type;
105 
106   /// Returns the element type of this memref type.
107   Type getElementType() const;
108 
109   /// Returns if this type is ranked, i.e. it has a known number of dimensions.
110   bool hasRank() const;
111 
112   /// Returns the shape of this memref type.
113   ArrayRef<int64_t> getShape() const;
114 
115   /// Clone this type with the given shape and element type. If the
116   /// provided shape is `std::nullopt`, the current shape of the type is used.
117   BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape,
118                            Type elementType) const;
119 
120   // Make sure that base class overloads are visible.
121   using ShapedType::Trait<BaseMemRefType>::clone;
122 
123   /// Return a clone of this type with the given new shape and element type.
124   /// The returned type is ranked, even if this type is unranked.
125   MemRefType clone(ArrayRef<int64_t> shape, Type elementType) const;
126 
127   /// Return a clone of this type with the given new shape. The returned type
128   /// is ranked, even if this type is unranked.
129   MemRefType clone(ArrayRef<int64_t> shape) const;
130 
131   /// Return true if the specified element type is ok in a memref.
132   static bool isValidElementType(Type type);
133 
134   /// Methods for support type inquiry through isa, cast, and dyn_cast.
135   static bool classof(Type type);
136 
137   /// Returns the memory space in which data referred to by this memref resides.
138   Attribute getMemorySpace() const;
139 
140   /// [deprecated] Returns the memory space in old raw integer representation.
141   /// New `Attribute getMemorySpace()` method should be used instead.
142   unsigned getMemorySpaceAsInt() const;
143 
144   /// Allow implicit conversion to ShapedType.
145   operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
146 };
147 
148 } // namespace mlir
149 
150 //===----------------------------------------------------------------------===//
151 // Tablegen Type Declarations
152 //===----------------------------------------------------------------------===//
153 
154 #define GET_TYPEDEF_CLASSES
155 #include "mlir/IR/BuiltinTypes.h.inc"
156 
157 namespace mlir {
158 #include "mlir/IR/BuiltinTypeConstraints.h.inc"
159 
160 //===----------------------------------------------------------------------===//
161 // MemRefType
162 //===----------------------------------------------------------------------===//
163 
164 /// This is a builder type that keeps local references to arguments. Arguments
165 /// that are passed into the builder must outlive the builder.
166 class MemRefType::Builder {
167 public:
168   // Build from another MemRefType.
169   explicit Builder(MemRefType other)
170       : shape(other.getShape()), elementType(other.getElementType()),
171         layout(other.getLayout()), memorySpace(other.getMemorySpace()) {}
172 
173   // Build from scratch.
174   Builder(ArrayRef<int64_t> shape, Type elementType)
175       : shape(shape), elementType(elementType) {}
176 
177   Builder &setShape(ArrayRef<int64_t> newShape) {
178     shape = newShape;
179     return *this;
180   }
181 
182   Builder &setElementType(Type newElementType) {
183     elementType = newElementType;
184     return *this;
185   }
186 
187   Builder &setLayout(MemRefLayoutAttrInterface newLayout) {
188     layout = newLayout;
189     return *this;
190   }
191 
192   Builder &setMemorySpace(Attribute newMemorySpace) {
193     memorySpace = newMemorySpace;
194     return *this;
195   }
196 
197   operator MemRefType() {
198     return MemRefType::get(shape, elementType, layout, memorySpace);
199   }
200 
201 private:
202   ArrayRef<int64_t> shape;
203   Type elementType;
204   MemRefLayoutAttrInterface layout;
205   Attribute memorySpace;
206 };
207 
208 //===----------------------------------------------------------------------===//
209 // RankedTensorType
210 //===----------------------------------------------------------------------===//
211 
212 /// This is a builder type that keeps local references to arguments. Arguments
213 /// that are passed into the builder must outlive the builder.
214 class RankedTensorType::Builder {
215 public:
216   /// Build from another RankedTensorType.
217   explicit Builder(RankedTensorType other)
218       : shape(other.getShape()), elementType(other.getElementType()),
219         encoding(other.getEncoding()) {}
220 
221   /// Build from scratch.
222   Builder(ArrayRef<int64_t> shape, Type elementType, Attribute encoding)
223       : shape(shape), elementType(elementType), encoding(encoding) {}
224 
225   Builder &setShape(ArrayRef<int64_t> newShape) {
226     shape = newShape;
227     return *this;
228   }
229 
230   Builder &setElementType(Type newElementType) {
231     elementType = newElementType;
232     return *this;
233   }
234 
235   Builder &setEncoding(Attribute newEncoding) {
236     encoding = newEncoding;
237     return *this;
238   }
239 
240   /// Erase a dim from shape @pos.
241   Builder &dropDim(unsigned pos) {
242     assert(pos < shape.size() && "overflow");
243     shape.erase(pos);
244     return *this;
245   }
246 
247   /// Insert a val into shape @pos.
248   Builder &insertDim(int64_t val, unsigned pos) {
249     assert(pos <= shape.size() && "overflow");
250     shape.insert(pos, val);
251     return *this;
252   }
253 
254   operator RankedTensorType() {
255     return RankedTensorType::get(shape, elementType, encoding);
256   }
257 
258 private:
259   CopyOnWriteArrayRef<int64_t> shape;
260   Type elementType;
261   Attribute encoding;
262 };
263 
264 //===----------------------------------------------------------------------===//
265 // VectorType
266 //===----------------------------------------------------------------------===//
267 
268 /// This is a builder type that keeps local references to arguments. Arguments
269 /// that are passed into the builder must outlive the builder.
270 class VectorType::Builder {
271 public:
272   /// Build from another VectorType.
273   explicit Builder(VectorType other)
274       : elementType(other.getElementType()), shape(other.getShape()),
275         scalableDims(other.getScalableDims()) {}
276 
277   /// Build from scratch.
278   Builder(ArrayRef<int64_t> shape, Type elementType,
279           ArrayRef<bool> scalableDims = {})
280       : elementType(elementType), shape(shape), scalableDims(scalableDims) {}
281 
282   Builder &setShape(ArrayRef<int64_t> newShape,
283                     ArrayRef<bool> newIsScalableDim = {}) {
284     shape = newShape;
285     scalableDims = newIsScalableDim;
286     return *this;
287   }
288 
289   Builder &setElementType(Type newElementType) {
290     elementType = newElementType;
291     return *this;
292   }
293 
294   /// Erase a dim from shape @pos.
295   Builder &dropDim(unsigned pos) {
296     assert(pos < shape.size() && "overflow");
297     shape.erase(pos);
298     if (!scalableDims.empty())
299       scalableDims.erase(pos);
300     return *this;
301   }
302 
303   /// Set a dim in shape @pos to val.
304   Builder &setDim(unsigned pos, int64_t val) {
305     assert(pos < shape.size() && "overflow");
306     shape.set(pos, val);
307     return *this;
308   }
309 
310   operator VectorType() {
311     return VectorType::get(shape, elementType, scalableDims);
312   }
313 
314 private:
315   Type elementType;
316   CopyOnWriteArrayRef<int64_t> shape;
317   CopyOnWriteArrayRef<bool> scalableDims;
318 };
319 
/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
/// `originalShape` with some `1` entries erased, return the set of indices
/// that specifies which of the entries of `originalShape` are dropped to obtain
/// `reducedShape`. The returned mask can be applied as a projection to
/// `originalShape` to obtain the `reducedShape`. This mask is useful to track
/// which dimensions must be kept when e.g. computing MemRef strides under
/// rank-reducing operations. Return std::nullopt if reducedShape cannot be
/// obtained by dropping only `1` entries in `originalShape`.
/// If `matchDynamic` is true, then dynamic dims in `originalShape` and
/// `reducedShape` will be considered matching with non-dynamic dims, unless
/// the non-dynamic dim is from `originalShape` and equal to 1. For example,
/// in ([1, 3, ?], [?, 5]), only dimension 0 of `originalShape` is dropped:
/// written as a bitmask over `originalShape` that is {1, 0, 0}, so the
/// returned set of indices is {0}. The 3 and 5 are matched with the
/// corresponding dynamic dims.
std::optional<llvm::SmallDenseSet<unsigned>>
computeRankReductionMask(ArrayRef<int64_t> originalShape,
                         ArrayRef<int64_t> reducedShape,
                         bool matchDynamic = false);
337 
/// Enum that captures information related to verifier error conditions on
/// slice insert/extract type of ops.
enum class SliceVerificationResult {
  /// No error: verification succeeded.
  Success,
  // NOTE(review): the meanings of the following codes are inferred from their
  // names (rank too large / size mismatch / element type mismatch between the
  // compared types); confirm against the verifiers that produce them.
  RankTooLarge,
  SizeMismatch,
  ElemTypeMismatch,
  // Error codes for ops with a memory space and a layout annotation.
  MemSpaceMismatch,
  LayoutMismatch
};
349 
/// Check if `originalType` can be rank reduced to `candidateReducedType` type
/// by dropping some dimensions with static size `1`.
/// Return `SliceVerificationResult::Success` on success or an appropriate error
/// code (one of the other `SliceVerificationResult` enumerators) otherwise.
SliceVerificationResult isRankReducedType(ShapedType originalType,
                                          ShapedType candidateReducedType);
356 
357 //===----------------------------------------------------------------------===//
358 // Convenience wrappers for VectorType
359 //
360 // These are provided to allow idiomatic code like:
//  * isa<ScalableVectorType>(type)
362 //===----------------------------------------------------------------------===//
363 /// A vector type containing at least one scalable dimension.
364 class ScalableVectorType : public VectorType {
365 public:
366   using VectorType::VectorType;
367 
368   static bool classof(Type type) {
369     auto vecTy = llvm::dyn_cast<VectorType>(type);
370     if (!vecTy)
371       return false;
372     return vecTy.isScalable();
373   }
374 };
375 
376 /// A vector type with no scalable dimensions.
377 class FixedVectorType : public VectorType {
378 public:
379   using VectorType::VectorType;
380 
381   static bool classof(Type type) {
382     auto vecTy = llvm::dyn_cast<VectorType>(type);
383     if (!vecTy)
384       return false;
385     return !vecTy.isScalable();
386   }
387 };
388 
389 //===----------------------------------------------------------------------===//
390 // Deferred Method Definitions
391 //===----------------------------------------------------------------------===//
392 
393 inline bool BaseMemRefType::classof(Type type) {
394   return llvm::isa<MemRefType, UnrankedMemRefType>(type);
395 }
396 
397 inline bool BaseMemRefType::isValidElementType(Type type) {
398   return type.isIntOrIndexOrFloat() ||
399          llvm::isa<ComplexType, MemRefType, VectorType, UnrankedMemRefType>(
400              type) ||
401          llvm::isa<MemRefElementTypeInterface>(type);
402 }
403 
404 inline bool TensorType::classof(Type type) {
405   return llvm::isa<RankedTensorType, UnrankedTensorType>(type);
406 }
407 
408 //===----------------------------------------------------------------------===//
409 // Type Utilities
410 //===----------------------------------------------------------------------===//
411 
/// Given MemRef `sizes` that are either static or dynamic, returns the
/// canonical "contiguous" strides AffineExpr. Strides are multiplicative and
/// once a dynamic dimension is encountered (scanning from the innermost,
/// fastest-varying dimension outwards, as the examples below show), all
/// remaining canonical strides become dynamic and need to be encoded with a
/// different symbol.
/// For canonical stride expressions, the offset is always 0 and the fastest
/// varying stride is always `1`.
///
/// `exprs` supplies the expression multiplied by each stride; as in the
/// examples, exprs[i] pairs with sizes[i] (so the two presumably must have the
/// same length -- TODO confirm against the definition). `context` is the
/// MLIRContext used to build the result.
///
/// Examples:
///   - memref<3x4x5xf32> has canonical stride expression
///         `20*exprs[0] + 5*exprs[1] + exprs[2]`.
///   - memref<3x?x5xf32> has canonical stride expression
///         `s0*exprs[0] + 5*exprs[1] + exprs[2]`.
///   - memref<3x4x?xf32> has canonical stride expression
///         `s1*exprs[0] + s0*exprs[1] + exprs[2]`.
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                          ArrayRef<AffineExpr> exprs,
                                          MLIRContext *context);
429 
/// Return the result of makeCanonicalStridedLayoutExpr for the common case
/// where `exprs` is {d0, d1, .., d_(sizes.size()-1)}.
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                          MLIRContext *context);
434 } // namespace mlir
435 
436 #endif // MLIR_IR_BUILTINTYPES_H
437