1 //===- Traits.cpp - Common op traits shared by dialects -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Traits.h" 10 #include "mlir/IR/BuiltinTypes.h" 11 #include "mlir/IR/TypeUtilities.h" 12 #include "llvm/Support/FormatVariadic.h" 13 #include <optional> 14 15 using namespace mlir; 16 17 bool OpTrait::util::staticallyKnownBroadcastable(ArrayRef<int64_t> shape1, 18 ArrayRef<int64_t> shape2) { 19 SmallVector<SmallVector<int64_t, 6>, 2> extents; 20 extents.emplace_back(shape1.begin(), shape1.end()); 21 extents.emplace_back(shape2.begin(), shape2.end()); 22 return staticallyKnownBroadcastable(extents); 23 } 24 25 bool OpTrait::util::staticallyKnownBroadcastable( 26 ArrayRef<SmallVector<int64_t, 6>> shapes) { 27 assert(!shapes.empty() && "Expected at least one shape"); 28 size_t maxRank = shapes[0].size(); 29 for (size_t i = 1; i != shapes.size(); ++i) 30 maxRank = std::max(maxRank, shapes[i].size()); 31 32 // We look backwards through every column of `shapes`. 33 for (size_t i = 0; i != maxRank; ++i) { 34 bool seenDynamic = false; 35 std::optional<int64_t> nonOneDim; 36 for (ArrayRef<int64_t> extent : shapes) { 37 int64_t dim = i >= extent.size() ? 1 : extent[extent.size() - i - 1]; 38 39 if (dim == 1) 40 continue; 41 42 // Dimensions are compatible when 43 //. 1. One is dynamic, the rest are 1 44 if (ShapedType::isDynamic(dim)) { 45 if (seenDynamic || nonOneDim) 46 return false; 47 seenDynamic = true; 48 } 49 50 // 2. All are 1 or a specific constant. 51 if (nonOneDim && dim != *nonOneDim) 52 return false; 53 54 nonOneDim = dim; 55 } 56 } 57 return true; 58 } 59 60 bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1, 61 ArrayRef<int64_t> shape2, 62 SmallVectorImpl<int64_t> &resultShape) { 63 // To compute the result broadcasted shape, we compare operand shapes 64 // element-wise: starting with the trailing dimensions, and working the 65 // way backward. Two dimensions are compatible when 66 // 1. they are equal, or 67 // 2. one of them is 1 68 // The result shape has the maximum among the two inputs at every 69 // dimension index. 70 71 resultShape.clear(); 72 if (shape1.size() > shape2.size()) { 73 std::copy(shape1.begin(), shape1.end(), std::back_inserter(resultShape)); 74 } else { 75 std::copy(shape2.begin(), shape2.end(), std::back_inserter(resultShape)); 76 } 77 78 auto i1 = shape1.rbegin(), e1 = shape1.rend(); 79 auto i2 = shape2.rbegin(), e2 = shape2.rend(); 80 auto iR = resultShape.rbegin(); 81 82 // Check each dimension is consistent. 83 for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) { 84 if (ShapedType::isDynamic(*i1) || ShapedType::isDynamic(*i2)) { 85 // One or both dimensions is unknown. Follow TensorFlow behavior: 86 // - If either dimension is greater than 1, we assume that the program is 87 // correct, and the other dimension will be broadcast to match it. 88 // - If either dimension is 1, the other dimension is the output. 89 if (*i1 > 1) { 90 *iR = *i1; 91 } else if (*i2 > 1) { 92 *iR = *i2; 93 } else if (*i1 == 1) { 94 *iR = *i2; 95 } else if (*i2 == 1) { 96 *iR = *i1; 97 } else { 98 *iR = ShapedType::kDynamic; 99 } 100 } else { 101 if (*i1 == *i2 || *i2 == 1) { 102 *iR = *i1; 103 } else if (*i1 == 1) { 104 *iR = *i2; 105 } else { 106 // This dimension of the two operand types is incompatible. 107 resultShape.clear(); 108 return false; 109 } 110 } 111 } 112 113 return true; 114 } 115 116 /// Returns the shape of the given type. Scalars will be considered as having a 117 /// shape with zero dimensions. 118 static ArrayRef<int64_t> getShape(Type type) { 119 if (auto sType = dyn_cast<ShapedType>(type)) 120 return sType.getShape(); 121 return {}; 122 } 123 124 /// Returns the result broadcast composition type from the two given types by 125 /// following NumPy broadcast semantics. Returned type may have dynamic shape if 126 /// either of the input types has dynamic shape. Returns null type if the two 127 /// given types are not broadcast-compatible. 128 /// 129 /// elementType, if specified, will be used as the element type of the 130 /// broadcasted result type. Otherwise it is required that the element type of 131 /// type1 and type2 is the same and this element type will be used as the 132 /// resultant element type. 133 Type OpTrait::util::getBroadcastedType(Type type1, Type type2, 134 Type elementType) { 135 // If the elementType is not specified, then the use the common element type 136 // of the inputs or fail if there is no common element type. 137 if (!elementType) { 138 elementType = getElementTypeOrSelf(type1); 139 if (elementType != getElementTypeOrSelf(type2)) 140 return {}; 141 } 142 143 // If one of the types is unranked tensor, then the other type shouldn't be 144 // vector and the result should have unranked tensor type. 145 if (isa<UnrankedTensorType>(type1) || isa<UnrankedTensorType>(type2)) { 146 if (isa<VectorType>(type1) || isa<VectorType>(type2)) 147 return {}; 148 return UnrankedTensorType::get(elementType); 149 } 150 151 // Returns the type kind if the given type is a vector or ranked tensor type. 152 // Returns std::nullopt otherwise. 153 auto getCompositeTypeKind = [](Type type) -> std::optional<TypeID> { 154 if (isa<VectorType, RankedTensorType>(type)) 155 return type.getTypeID(); 156 return std::nullopt; 157 }; 158 159 // Make sure the composite type, if has, is consistent. 160 std::optional<TypeID> compositeKind1 = getCompositeTypeKind(type1); 161 std::optional<TypeID> compositeKind2 = getCompositeTypeKind(type2); 162 std::optional<TypeID> resultCompositeKind; 163 164 if (compositeKind1 && compositeKind2) { 165 // Disallow mixing vector and tensor. 166 if (compositeKind1 != compositeKind2) 167 return {}; 168 resultCompositeKind = compositeKind1; 169 } else if (compositeKind1) { 170 resultCompositeKind = compositeKind1; 171 } else if (compositeKind2) { 172 resultCompositeKind = compositeKind2; 173 } 174 175 // Get the shape of each type. 176 SmallVector<int64_t, 4> resultShape; 177 if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape)) 178 return {}; 179 180 // Compose the final broadcasted type 181 if (resultCompositeKind == VectorType::getTypeID()) 182 return VectorType::get(resultShape, elementType); 183 if (resultCompositeKind == RankedTensorType::getTypeID()) 184 return RankedTensorType::get(resultShape, elementType); 185 return elementType; 186 } 187 188 /// Returns a tuple corresponding to whether range has tensor or vector type. 189 template <typename iterator_range> 190 static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) { 191 return {llvm::any_of(types, llvm::IsaPred<TensorType>), 192 llvm::any_of(types, llvm::IsaPred<VectorType>)}; 193 } 194 195 static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred, 196 ArrayRef<int64_t> existing) { 197 // If both interred and existing dimensions are static, they must be equal. 198 auto isCompatible = [](int64_t inferredDim, int64_t existingDim) { 199 return ShapedType::isDynamic(existingDim) || 200 ShapedType::isDynamic(inferredDim) || inferredDim == existingDim; 201 }; 202 if (inferred.size() != existing.size()) 203 return false; 204 for (auto [inferredDim, existingDim] : llvm::zip_equal(inferred, existing)) 205 if (!isCompatible(inferredDim, existingDim)) 206 return false; 207 return true; 208 } 209 210 static std::string getShapeString(ArrayRef<int64_t> shape) { 211 // TODO: should replace with printing shape more uniformly across here and 212 // when in type. 213 std::string ret; 214 llvm::raw_string_ostream ss(ret); 215 ss << '\''; 216 llvm::interleave( 217 shape, ss, 218 [&](int64_t dim) { 219 if (ShapedType::isDynamic(dim)) 220 ss << '?'; 221 else 222 ss << dim; 223 }, 224 "x"); 225 ss << '\''; 226 return ret; 227 } 228 229 LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) { 230 // Ensure broadcasting only tensor or only vector types. 231 auto operandsHasTensorVectorType = 232 hasTensorOrVectorType(op->getOperandTypes()); 233 auto resultsHasTensorVectorType = hasTensorOrVectorType(op->getResultTypes()); 234 if ((std::get<0>(operandsHasTensorVectorType) || 235 std::get<0>(resultsHasTensorVectorType)) && 236 (std::get<1>(operandsHasTensorVectorType) || 237 std::get<1>(resultsHasTensorVectorType))) 238 return op->emitError("cannot broadcast vector with tensor"); 239 240 auto rankedOperands = 241 make_filter_range(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>); 242 243 // If all operands are unranked, then all result shapes are possible. 244 if (rankedOperands.empty()) 245 return success(); 246 247 // Compute broadcasted shape of operands (which requires that operands are 248 // broadcast compatible). The results need to be broadcast compatible with 249 // this result shape. 250 SmallVector<int64_t, 4> resultShape; 251 (void)util::getBroadcastedShape(getShape(*rankedOperands.begin()), {}, 252 resultShape); 253 for (auto other : make_early_inc_range(rankedOperands)) { 254 SmallVector<int64_t, 4> temp = resultShape; 255 if (!util::getBroadcastedShape(temp, getShape(other), resultShape)) 256 return op->emitOpError("operands don't have broadcast-compatible shapes"); 257 } 258 259 auto rankedResults = 260 make_filter_range(op->getResultTypes(), llvm::IsaPred<RankedTensorType>); 261 262 // If all of the results are unranked then no further verification. 263 if (rankedResults.empty()) 264 return success(); 265 266 for (auto type : rankedResults) { 267 ArrayRef<int64_t> actualSuffix = 268 getShape(type).take_back(resultShape.size()); 269 if (!isCompatibleInferredReturnShape(resultShape, actualSuffix)) 270 return op->emitOpError() 271 << "result type " << getShapeString(getShape(type)) 272 << " not broadcast compatible with broadcasted operands's shapes " 273 << getShapeString(resultShape); 274 } 275 return success(); 276 } 277