1 //===- TypeUtilities.cpp - Helper function for type queries ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines generic type utilities. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/IR/TypeUtilities.h" 14 #include "mlir/IR/Attributes.h" 15 #include "mlir/IR/BuiltinTypes.h" 16 #include "mlir/IR/Types.h" 17 #include "mlir/IR/Value.h" 18 #include "llvm/ADT/SmallVectorExtras.h" 19 #include <numeric> 20 21 using namespace mlir; 22 23 Type mlir::getElementTypeOrSelf(Type type) { 24 if (auto st = llvm::dyn_cast<ShapedType>(type)) 25 return st.getElementType(); 26 return type; 27 } 28 29 Type mlir::getElementTypeOrSelf(Value val) { 30 return getElementTypeOrSelf(val.getType()); 31 } 32 33 Type mlir::getElementTypeOrSelf(Attribute attr) { 34 if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr)) 35 return getElementTypeOrSelf(typedAttr.getType()); 36 return {}; 37 } 38 39 SmallVector<Type, 10> mlir::getFlattenedTypes(TupleType t) { 40 SmallVector<Type, 10> fTypes; 41 t.getFlattenedTypes(fTypes); 42 return fTypes; 43 } 44 45 /// Return true if the specified type is an opaque type with the specified 46 /// dialect and typeData. 47 bool mlir::isOpaqueTypeWithName(Type type, StringRef dialect, 48 StringRef typeData) { 49 if (auto opaque = llvm::dyn_cast<mlir::OpaqueType>(type)) 50 return opaque.getDialectNamespace() == dialect && 51 opaque.getTypeData() == typeData; 52 return false; 53 } 54 55 /// Returns success if the given two shapes are compatible. That is, they have 56 /// the same size and each pair of the elements are equal or one of them is 57 /// dynamic. 58 LogicalResult mlir::verifyCompatibleShape(ArrayRef<int64_t> shape1, 59 ArrayRef<int64_t> shape2) { 60 if (shape1.size() != shape2.size()) 61 return failure(); 62 for (auto dims : llvm::zip(shape1, shape2)) { 63 int64_t dim1 = std::get<0>(dims); 64 int64_t dim2 = std::get<1>(dims); 65 if (!ShapedType::isDynamic(dim1) && !ShapedType::isDynamic(dim2) && 66 dim1 != dim2) 67 return failure(); 68 } 69 return success(); 70 } 71 72 /// Returns success if the given two types have compatible shape. That is, 73 /// they are both scalars (not shaped), or they are both shaped types and at 74 /// least one is unranked or they have compatible dimensions. Dimensions are 75 /// compatible if at least one is dynamic or both are equal. The element type 76 /// does not matter. 77 LogicalResult mlir::verifyCompatibleShape(Type type1, Type type2) { 78 auto sType1 = llvm::dyn_cast<ShapedType>(type1); 79 auto sType2 = llvm::dyn_cast<ShapedType>(type2); 80 81 // Either both or neither type should be shaped. 82 if (!sType1) 83 return success(!sType2); 84 if (!sType2) 85 return failure(); 86 87 if (!sType1.hasRank() || !sType2.hasRank()) 88 return success(); 89 90 return verifyCompatibleShape(sType1.getShape(), sType2.getShape()); 91 } 92 93 /// Returns success if the given two arrays have the same number of elements and 94 /// each pair wise entries have compatible shape. 95 LogicalResult mlir::verifyCompatibleShapes(TypeRange types1, TypeRange types2) { 96 if (types1.size() != types2.size()) 97 return failure(); 98 for (auto it : llvm::zip_first(types1, types2)) 99 if (failed(verifyCompatibleShape(std::get<0>(it), std::get<1>(it)))) 100 return failure(); 101 return success(); 102 } 103 104 LogicalResult mlir::verifyCompatibleDims(ArrayRef<int64_t> dims) { 105 if (dims.empty()) 106 return success(); 107 auto staticDim = std::accumulate( 108 dims.begin(), dims.end(), dims.front(), [](auto fold, auto dim) { 109 return ShapedType::isDynamic(dim) ? fold : dim; 110 }); 111 return success(llvm::all_of(dims, [&](auto dim) { 112 return ShapedType::isDynamic(dim) || dim == staticDim; 113 })); 114 } 115 116 /// Returns success if all given types have compatible shapes. That is, they are 117 /// all scalars (not shaped), or they are all shaped types and any ranked shapes 118 /// have compatible dimensions. Dimensions are compatible if all non-dynamic 119 /// dims are equal. The element type does not matter. 120 LogicalResult mlir::verifyCompatibleShapes(TypeRange types) { 121 auto shapedTypes = llvm::map_to_vector<8>( 122 types, [](auto type) { return llvm::dyn_cast<ShapedType>(type); }); 123 // Return failure if some, but not all are not shaped. Return early if none 124 // are shaped also. 125 if (llvm::none_of(shapedTypes, [](auto t) { return t; })) 126 return success(); 127 if (!llvm::all_of(shapedTypes, [](auto t) { return t; })) 128 return failure(); 129 130 // Return failure if some, but not all, are scalable vectors. 131 bool hasScalableVecTypes = false; 132 bool hasNonScalableVecTypes = false; 133 for (Type t : types) { 134 auto vType = llvm::dyn_cast<VectorType>(t); 135 if (vType && vType.isScalable()) 136 hasScalableVecTypes = true; 137 else 138 hasNonScalableVecTypes = true; 139 if (hasScalableVecTypes && hasNonScalableVecTypes) 140 return failure(); 141 } 142 143 // Remove all unranked shapes 144 auto shapes = llvm::filter_to_vector<8>( 145 shapedTypes, [](auto shapedType) { return shapedType.hasRank(); }); 146 if (shapes.empty()) 147 return success(); 148 149 // All ranks should be equal 150 auto firstRank = shapes.front().getRank(); 151 if (llvm::any_of(shapes, 152 [&](auto shape) { return firstRank != shape.getRank(); })) 153 return failure(); 154 155 for (unsigned i = 0; i < firstRank; ++i) { 156 // Retrieve all ranked dimensions 157 auto dims = llvm::map_to_vector<8>( 158 llvm::make_filter_range( 159 shapes, [&](auto shape) { return shape.getRank() >= i; }), 160 [&](auto shape) { return shape.getDimSize(i); }); 161 if (verifyCompatibleDims(dims).failed()) 162 return failure(); 163 } 164 165 return success(); 166 } 167 168 Type OperandElementTypeIterator::mapElement(Value value) const { 169 return llvm::cast<ShapedType>(value.getType()).getElementType(); 170 } 171 172 Type ResultElementTypeIterator::mapElement(Value value) const { 173 return llvm::cast<ShapedType>(value.getType()).getElementType(); 174 } 175 176 TypeRange mlir::insertTypesInto(TypeRange oldTypes, ArrayRef<unsigned> indices, 177 TypeRange newTypes, 178 SmallVectorImpl<Type> &storage) { 179 assert(indices.size() == newTypes.size() && 180 "mismatch between indice and type count"); 181 if (indices.empty()) 182 return oldTypes; 183 184 auto fromIt = oldTypes.begin(); 185 for (auto it : llvm::zip(indices, newTypes)) { 186 const auto toIt = oldTypes.begin() + std::get<0>(it); 187 storage.append(fromIt, toIt); 188 storage.push_back(std::get<1>(it)); 189 fromIt = toIt; 190 } 191 storage.append(fromIt, oldTypes.end()); 192 return storage; 193 } 194 195 TypeRange mlir::filterTypesOut(TypeRange types, const BitVector &indices, 196 SmallVectorImpl<Type> &storage) { 197 if (indices.none()) 198 return types; 199 200 for (unsigned i = 0, e = types.size(); i < e; ++i) 201 if (!indices[i]) 202 storage.emplace_back(types[i]); 203 return storage; 204 } 205