xref: /llvm-project/mlir/include/mlir/IR/TypeUtilities.h (revision 34a35a8b244243f5a4ad5d531007bccfeaa0b02e)
//===- TypeUtilities.h - Helper functions for type queries ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines generic type utilities.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_TYPEUTILITIES_H
14 #define MLIR_IR_TYPEUTILITIES_H
15 
16 #include "mlir/IR/Operation.h"
17 #include "llvm/ADT/STLExtras.h"
18 
19 namespace mlir {
20 
// Forward declarations of the MLIR classes named by the declarations below.
class Attribute;
class TupleType;
class Type;
class TypeRange;
class Value;
26 
27 //===----------------------------------------------------------------------===//
28 // Utility Functions
29 //===----------------------------------------------------------------------===//
30 
31 /// Return the element type or return the type itself.
32 Type getElementTypeOrSelf(Type type);
33 
34 /// Return the element type or return the type itself.
35 Type getElementTypeOrSelf(Attribute attr);
36 Type getElementTypeOrSelf(Value val);
37 
38 /// Get the types within a nested Tuple. A helper for the class method that
39 /// handles storage concerns, which is tricky to do in tablegen.
40 SmallVector<Type, 10> getFlattenedTypes(TupleType t);
41 
42 /// Return true if the specified type is an opaque type with the specified
43 /// dialect and typeData.
44 bool isOpaqueTypeWithName(Type type, StringRef dialect, StringRef typeData);
45 
46 /// Returns success if the given two shapes are compatible. That is, they have
47 /// the same size and each pair of the elements are equal or one of them is
48 /// dynamic.
49 LogicalResult verifyCompatibleShape(ArrayRef<int64_t> shape1,
50                                     ArrayRef<int64_t> shape2);
51 
52 /// Returns success if the given two types have compatible shape. That is,
53 /// they are both scalars (not shaped), or they are both shaped types and at
54 /// least one is unranked or they have compatible dimensions. Dimensions are
55 /// compatible if at least one is dynamic or both are equal. The element type
56 /// does not matter.
57 LogicalResult verifyCompatibleShape(Type type1, Type type2);
58 
59 /// Returns success if the given two arrays have the same number of elements and
60 /// each pair wise entries have compatible shape.
61 LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2);
62 
63 /// Returns success if all given types have compatible shapes. That is, they are
64 /// all scalars (not shaped), or they are all shaped types and any ranked shapes
65 /// have compatible dimensions. The element type does not matter.
66 LogicalResult verifyCompatibleShapes(TypeRange types);
67 
68 /// Dimensions are compatible if all non-dynamic dims are equal.
69 LogicalResult verifyCompatibleDims(ArrayRef<int64_t> dims);
70 
71 /// Insert a set of `newTypes` into `oldTypes` at the given `indices`. If any
72 /// types are inserted, `storage` is used to hold the new type list. The new
73 /// type list is returned. `indices` must be sorted by increasing index.
74 TypeRange insertTypesInto(TypeRange oldTypes, ArrayRef<unsigned> indices,
75                           TypeRange newTypes, SmallVectorImpl<Type> &storage);
76 
77 /// Filters out any elements referenced by `indices`. If any types are removed,
78 /// `storage` is used to hold the new type list. Returns the new type list.
79 TypeRange filterTypesOut(TypeRange types, const BitVector &indices,
80                          SmallVectorImpl<Type> &storage);
81 
82 //===----------------------------------------------------------------------===//
83 // Utility Iterators
84 //===----------------------------------------------------------------------===//
85 
86 // An iterator for the element types of an op's operands of shaped types.
87 class OperandElementTypeIterator final
88     : public llvm::mapped_iterator_base<OperandElementTypeIterator,
89                                         Operation::operand_iterator, Type> {
90 public:
91   using BaseT::BaseT;
92 
93   /// Map the element to the iterator result type.
94   Type mapElement(Value value) const;
95 };
96 
97 using OperandElementTypeRange = iterator_range<OperandElementTypeIterator>;
98 
99 // An iterator for the tensor element types of an op's results of shaped types.
100 class ResultElementTypeIterator final
101     : public llvm::mapped_iterator_base<ResultElementTypeIterator,
102                                         Operation::result_iterator, Type> {
103 public:
104   using BaseT::BaseT;
105 
106   /// Map the element to the iterator result type.
107   Type mapElement(Value value) const;
108 };
109 
110 using ResultElementTypeRange = iterator_range<ResultElementTypeIterator>;
111 
112 } // namespace mlir
113 
114 #endif // MLIR_IR_TYPEUTILITIES_H
115