xref: /llvm-project/mlir/lib/IR/TypeUtilities.cpp (revision f4d758634305304c0deb49a4ed3f99180a2488ea)
1 //===- TypeUtilities.cpp - Helper function for type queries ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines generic type utilities.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/IR/TypeUtilities.h"
14 #include "mlir/IR/Attributes.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Types.h"
17 #include "mlir/IR/Value.h"
18 #include "llvm/ADT/SmallVectorExtras.h"
19 #include <numeric>
20 
21 using namespace mlir;
22 
23 Type mlir::getElementTypeOrSelf(Type type) {
24   if (auto st = llvm::dyn_cast<ShapedType>(type))
25     return st.getElementType();
26   return type;
27 }
28 
29 Type mlir::getElementTypeOrSelf(Value val) {
30   return getElementTypeOrSelf(val.getType());
31 }
32 
33 Type mlir::getElementTypeOrSelf(Attribute attr) {
34   if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr))
35     return getElementTypeOrSelf(typedAttr.getType());
36   return {};
37 }
38 
39 SmallVector<Type, 10> mlir::getFlattenedTypes(TupleType t) {
40   SmallVector<Type, 10> fTypes;
41   t.getFlattenedTypes(fTypes);
42   return fTypes;
43 }
44 
45 /// Return true if the specified type is an opaque type with the specified
46 /// dialect and typeData.
47 bool mlir::isOpaqueTypeWithName(Type type, StringRef dialect,
48                                 StringRef typeData) {
49   if (auto opaque = llvm::dyn_cast<mlir::OpaqueType>(type))
50     return opaque.getDialectNamespace() == dialect &&
51            opaque.getTypeData() == typeData;
52   return false;
53 }
54 
55 /// Returns success if the given two shapes are compatible. That is, they have
56 /// the same size and each pair of the elements are equal or one of them is
57 /// dynamic.
58 LogicalResult mlir::verifyCompatibleShape(ArrayRef<int64_t> shape1,
59                                           ArrayRef<int64_t> shape2) {
60   if (shape1.size() != shape2.size())
61     return failure();
62   for (auto dims : llvm::zip(shape1, shape2)) {
63     int64_t dim1 = std::get<0>(dims);
64     int64_t dim2 = std::get<1>(dims);
65     if (!ShapedType::isDynamic(dim1) && !ShapedType::isDynamic(dim2) &&
66         dim1 != dim2)
67       return failure();
68   }
69   return success();
70 }
71 
72 /// Returns success if the given two types have compatible shape. That is,
73 /// they are both scalars (not shaped), or they are both shaped types and at
74 /// least one is unranked or they have compatible dimensions. Dimensions are
75 /// compatible if at least one is dynamic or both are equal. The element type
76 /// does not matter.
77 LogicalResult mlir::verifyCompatibleShape(Type type1, Type type2) {
78   auto sType1 = llvm::dyn_cast<ShapedType>(type1);
79   auto sType2 = llvm::dyn_cast<ShapedType>(type2);
80 
81   // Either both or neither type should be shaped.
82   if (!sType1)
83     return success(!sType2);
84   if (!sType2)
85     return failure();
86 
87   if (!sType1.hasRank() || !sType2.hasRank())
88     return success();
89 
90   return verifyCompatibleShape(sType1.getShape(), sType2.getShape());
91 }
92 
93 /// Returns success if the given two arrays have the same number of elements and
94 /// each pair wise entries have compatible shape.
95 LogicalResult mlir::verifyCompatibleShapes(TypeRange types1, TypeRange types2) {
96   if (types1.size() != types2.size())
97     return failure();
98   for (auto it : llvm::zip_first(types1, types2))
99     if (failed(verifyCompatibleShape(std::get<0>(it), std::get<1>(it))))
100       return failure();
101   return success();
102 }
103 
104 LogicalResult mlir::verifyCompatibleDims(ArrayRef<int64_t> dims) {
105   if (dims.empty())
106     return success();
107   auto staticDim = std::accumulate(
108       dims.begin(), dims.end(), dims.front(), [](auto fold, auto dim) {
109         return ShapedType::isDynamic(dim) ? fold : dim;
110       });
111   return success(llvm::all_of(dims, [&](auto dim) {
112     return ShapedType::isDynamic(dim) || dim == staticDim;
113   }));
114 }
115 
116 /// Returns success if all given types have compatible shapes. That is, they are
117 /// all scalars (not shaped), or they are all shaped types and any ranked shapes
118 /// have compatible dimensions. Dimensions are compatible if all non-dynamic
119 /// dims are equal. The element type does not matter.
120 LogicalResult mlir::verifyCompatibleShapes(TypeRange types) {
121   auto shapedTypes = llvm::map_to_vector<8>(
122       types, [](auto type) { return llvm::dyn_cast<ShapedType>(type); });
123   // Return failure if some, but not all are not shaped. Return early if none
124   // are shaped also.
125   if (llvm::none_of(shapedTypes, [](auto t) { return t; }))
126     return success();
127   if (!llvm::all_of(shapedTypes, [](auto t) { return t; }))
128     return failure();
129 
130   // Return failure if some, but not all, are scalable vectors.
131   bool hasScalableVecTypes = false;
132   bool hasNonScalableVecTypes = false;
133   for (Type t : types) {
134     auto vType = llvm::dyn_cast<VectorType>(t);
135     if (vType && vType.isScalable())
136       hasScalableVecTypes = true;
137     else
138       hasNonScalableVecTypes = true;
139     if (hasScalableVecTypes && hasNonScalableVecTypes)
140       return failure();
141   }
142 
143   // Remove all unranked shapes
144   auto shapes = llvm::filter_to_vector<8>(
145       shapedTypes, [](auto shapedType) { return shapedType.hasRank(); });
146   if (shapes.empty())
147     return success();
148 
149   // All ranks should be equal
150   auto firstRank = shapes.front().getRank();
151   if (llvm::any_of(shapes,
152                    [&](auto shape) { return firstRank != shape.getRank(); }))
153     return failure();
154 
155   for (unsigned i = 0; i < firstRank; ++i) {
156     // Retrieve all ranked dimensions
157     auto dims = llvm::map_to_vector<8>(
158         llvm::make_filter_range(
159             shapes, [&](auto shape) { return shape.getRank() >= i; }),
160         [&](auto shape) { return shape.getDimSize(i); });
161     if (verifyCompatibleDims(dims).failed())
162       return failure();
163   }
164 
165   return success();
166 }
167 
168 Type OperandElementTypeIterator::mapElement(Value value) const {
169   return llvm::cast<ShapedType>(value.getType()).getElementType();
170 }
171 
172 Type ResultElementTypeIterator::mapElement(Value value) const {
173   return llvm::cast<ShapedType>(value.getType()).getElementType();
174 }
175 
176 TypeRange mlir::insertTypesInto(TypeRange oldTypes, ArrayRef<unsigned> indices,
177                                 TypeRange newTypes,
178                                 SmallVectorImpl<Type> &storage) {
179   assert(indices.size() == newTypes.size() &&
180          "mismatch between indice and type count");
181   if (indices.empty())
182     return oldTypes;
183 
184   auto fromIt = oldTypes.begin();
185   for (auto it : llvm::zip(indices, newTypes)) {
186     const auto toIt = oldTypes.begin() + std::get<0>(it);
187     storage.append(fromIt, toIt);
188     storage.push_back(std::get<1>(it));
189     fromIt = toIt;
190   }
191   storage.append(fromIt, oldTypes.end());
192   return storage;
193 }
194 
195 TypeRange mlir::filterTypesOut(TypeRange types, const BitVector &indices,
196                                SmallVectorImpl<Type> &storage) {
197   if (indices.none())
198     return types;
199 
200   for (unsigned i = 0, e = types.size(); i < e; ++i)
201     if (!indices[i])
202       storage.emplace_back(types[i]);
203   return storage;
204 }
205