159001277SLei Zhang //===- Traits.cpp - Common op traits shared by dialects -------------------===// 259001277SLei Zhang // 330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information. 556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 659001277SLei Zhang // 756222a06SMehdi Amini //===----------------------------------------------------------------------===// 859001277SLei Zhang 959001277SLei Zhang #include "mlir/Dialect/Traits.h" 1009f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h" 11b70e4efbSJacques Pienaar #include "mlir/IR/TypeUtilities.h" 1278972572SJacques Pienaar #include "llvm/Support/FormatVariadic.h" 13a1fe1f5fSKazu Hirata #include <optional> 1459001277SLei Zhang 1559001277SLei Zhang using namespace mlir; 1659001277SLei Zhang 172ef71cb7STres Popp bool OpTrait::util::staticallyKnownBroadcastable(ArrayRef<int64_t> shape1, 182ef71cb7STres Popp ArrayRef<int64_t> shape2) { 1963a35f35SBenjamin Kramer SmallVector<SmallVector<int64_t, 6>, 2> extents; 2063a35f35SBenjamin Kramer extents.emplace_back(shape1.begin(), shape1.end()); 2163a35f35SBenjamin Kramer extents.emplace_back(shape2.begin(), shape2.end()); 2263a35f35SBenjamin Kramer return staticallyKnownBroadcastable(extents); 2363a35f35SBenjamin Kramer } 2463a35f35SBenjamin Kramer 2563a35f35SBenjamin Kramer bool OpTrait::util::staticallyKnownBroadcastable( 2663a35f35SBenjamin Kramer ArrayRef<SmallVector<int64_t, 6>> shapes) { 2763a35f35SBenjamin Kramer assert(!shapes.empty() && "Expected at least one shape"); 2863a35f35SBenjamin Kramer size_t maxRank = shapes[0].size(); 2963a35f35SBenjamin Kramer for (size_t i = 1; i != shapes.size(); ++i) 3063a35f35SBenjamin Kramer maxRank = std::max(maxRank, shapes[i].size()); 3163a35f35SBenjamin Kramer 3263a35f35SBenjamin Kramer // We look backwards through every column of `shapes`. 3363a35f35SBenjamin Kramer for (size_t i = 0; i != maxRank; ++i) { 3463a35f35SBenjamin Kramer bool seenDynamic = false; 350a81ace0SKazu Hirata std::optional<int64_t> nonOneDim; 3663a35f35SBenjamin Kramer for (ArrayRef<int64_t> extent : shapes) { 3763a35f35SBenjamin Kramer int64_t dim = i >= extent.size() ? 1 : extent[extent.size() - i - 1]; 3863a35f35SBenjamin Kramer 3963a35f35SBenjamin Kramer if (dim == 1) 4063a35f35SBenjamin Kramer continue; 4163a35f35SBenjamin Kramer 4263a35f35SBenjamin Kramer // Dimensions are compatible when 4363a35f35SBenjamin Kramer //. 1. One is dynamic, the rest are 1 4463a35f35SBenjamin Kramer if (ShapedType::isDynamic(dim)) { 4563a35f35SBenjamin Kramer if (seenDynamic || nonOneDim) 462ef71cb7STres Popp return false; 4763a35f35SBenjamin Kramer seenDynamic = true; 4863a35f35SBenjamin Kramer } 4963a35f35SBenjamin Kramer 5063a35f35SBenjamin Kramer // 2. All are 1 or a specific constant. 5163a35f35SBenjamin Kramer if (nonOneDim && dim != *nonOneDim) 5263a35f35SBenjamin Kramer return false; 5363a35f35SBenjamin Kramer 5463a35f35SBenjamin Kramer nonOneDim = dim; 5563a35f35SBenjamin Kramer } 5663a35f35SBenjamin Kramer } 5763a35f35SBenjamin Kramer return true; 582ef71cb7STres Popp } 592ef71cb7STres Popp 60eeadfbc1SLei Zhang bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1, 61eeadfbc1SLei Zhang ArrayRef<int64_t> shape2, 62eeadfbc1SLei Zhang SmallVectorImpl<int64_t> &resultShape) { 637972dcefSLei Zhang // To compute the result broadcasted shape, we compare operand shapes 647972dcefSLei Zhang // element-wise: starting with the trailing dimensions, and working the 657972dcefSLei Zhang // way backward. Two dimensions are compatible when 667972dcefSLei Zhang // 1. they are equal, or 677972dcefSLei Zhang // 2. one of them is 1 687972dcefSLei Zhang // The result shape has the maximum among the two inputs at every 697972dcefSLei Zhang // dimension index. 707972dcefSLei Zhang 71eeadfbc1SLei Zhang resultShape.clear(); 727972dcefSLei Zhang if (shape1.size() > shape2.size()) { 737972dcefSLei Zhang std::copy(shape1.begin(), shape1.end(), std::back_inserter(resultShape)); 747972dcefSLei Zhang } else { 757972dcefSLei Zhang std::copy(shape2.begin(), shape2.end(), std::back_inserter(resultShape)); 767972dcefSLei Zhang } 777972dcefSLei Zhang 787972dcefSLei Zhang auto i1 = shape1.rbegin(), e1 = shape1.rend(); 797972dcefSLei Zhang auto i2 = shape2.rbegin(), e2 = shape2.rend(); 807972dcefSLei Zhang auto iR = resultShape.rbegin(); 817972dcefSLei Zhang 827972dcefSLei Zhang // Check each dimension is consistent. 837972dcefSLei Zhang for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) { 84fb4cedccSAliia Khasanova if (ShapedType::isDynamic(*i1) || ShapedType::isDynamic(*i2)) { 857972dcefSLei Zhang // One or both dimensions is unknown. Follow TensorFlow behavior: 867972dcefSLei Zhang // - If either dimension is greater than 1, we assume that the program is 877972dcefSLei Zhang // correct, and the other dimension will be broadcast to match it. 887972dcefSLei Zhang // - If either dimension is 1, the other dimension is the output. 897972dcefSLei Zhang if (*i1 > 1) { 907972dcefSLei Zhang *iR = *i1; 917972dcefSLei Zhang } else if (*i2 > 1) { 927972dcefSLei Zhang *iR = *i2; 937972dcefSLei Zhang } else if (*i1 == 1) { 947972dcefSLei Zhang *iR = *i2; 957972dcefSLei Zhang } else if (*i2 == 1) { 967972dcefSLei Zhang *iR = *i1; 977972dcefSLei Zhang } else { 98399638f9SAliia Khasanova *iR = ShapedType::kDynamic; 997972dcefSLei Zhang } 1007972dcefSLei Zhang } else { 1017972dcefSLei Zhang if (*i1 == *i2 || *i2 == 1) { 1027972dcefSLei Zhang *iR = *i1; 1037972dcefSLei Zhang } else if (*i1 == 1) { 1047972dcefSLei Zhang *iR = *i2; 1057972dcefSLei Zhang } else { 1067972dcefSLei Zhang // This dimension of the two operand types is incompatible. 107eeadfbc1SLei Zhang resultShape.clear(); 108eeadfbc1SLei Zhang return false; 1097972dcefSLei Zhang } 1107972dcefSLei Zhang } 1117972dcefSLei Zhang } 1127972dcefSLei Zhang 113eeadfbc1SLei Zhang return true; 1147972dcefSLei Zhang } 1157972dcefSLei Zhang 116b0be00c7SLei Zhang /// Returns the shape of the given type. Scalars will be considered as having a 117b0be00c7SLei Zhang /// shape with zero dimensions. 118b0be00c7SLei Zhang static ArrayRef<int64_t> getShape(Type type) { 1195550c821STres Popp if (auto sType = dyn_cast<ShapedType>(type)) 120090662c5SGeoffrey Martin-Noble return sType.getShape(); 121b0be00c7SLei Zhang return {}; 122b0be00c7SLei Zhang } 123b0be00c7SLei Zhang 12459001277SLei Zhang /// Returns the result broadcast composition type from the two given types by 125c201e6efSSmit Hinsu /// following NumPy broadcast semantics. Returned type may have dynamic shape if 126c201e6efSSmit Hinsu /// either of the input types has dynamic shape. Returns null type if the two 127c201e6efSSmit Hinsu /// given types are not broadcast-compatible. 128b70e4efbSJacques Pienaar /// 129b70e4efbSJacques Pienaar /// elementType, if specified, will be used as the element type of the 130b70e4efbSJacques Pienaar /// broadcasted result type. Otherwise it is required that the element type of 131b70e4efbSJacques Pienaar /// type1 and type2 is the same and this element type will be used as the 132b70e4efbSJacques Pienaar /// resultant element type. 133b70e4efbSJacques Pienaar Type OpTrait::util::getBroadcastedType(Type type1, Type type2, 134b70e4efbSJacques Pienaar Type elementType) { 135b70e4efbSJacques Pienaar // If the elementType is not specified, then the use the common element type 136b70e4efbSJacques Pienaar // of the inputs or fail if there is no common element type. 137b70e4efbSJacques Pienaar if (!elementType) { 138b70e4efbSJacques Pienaar elementType = getElementTypeOrSelf(type1); 139b70e4efbSJacques Pienaar if (elementType != getElementTypeOrSelf(type2)) 14059001277SLei Zhang return {}; 141b70e4efbSJacques Pienaar } 14259001277SLei Zhang 143c201e6efSSmit Hinsu // If one of the types is unranked tensor, then the other type shouldn't be 144c201e6efSSmit Hinsu // vector and the result should have unranked tensor type. 1455550c821STres Popp if (isa<UnrankedTensorType>(type1) || isa<UnrankedTensorType>(type2)) { 1465550c821STres Popp if (isa<VectorType>(type1) || isa<VectorType>(type2)) 147c201e6efSSmit Hinsu return {}; 148b70e4efbSJacques Pienaar return UnrankedTensorType::get(elementType); 149c201e6efSSmit Hinsu } 150c201e6efSSmit Hinsu 15159001277SLei Zhang // Returns the type kind if the given type is a vector or ranked tensor type. 15270c73d1bSKazu Hirata // Returns std::nullopt otherwise. 1530a81ace0SKazu Hirata auto getCompositeTypeKind = [](Type type) -> std::optional<TypeID> { 1545550c821STres Popp if (isa<VectorType, RankedTensorType>(type)) 155c8c45985SRiver Riddle return type.getTypeID(); 1561a36588eSKazu Hirata return std::nullopt; 15759001277SLei Zhang }; 15859001277SLei Zhang 15959001277SLei Zhang // Make sure the composite type, if has, is consistent. 1600a81ace0SKazu Hirata std::optional<TypeID> compositeKind1 = getCompositeTypeKind(type1); 1610a81ace0SKazu Hirata std::optional<TypeID> compositeKind2 = getCompositeTypeKind(type2); 1620a81ace0SKazu Hirata std::optional<TypeID> resultCompositeKind; 16359001277SLei Zhang 16459001277SLei Zhang if (compositeKind1 && compositeKind2) { 16559001277SLei Zhang // Disallow mixing vector and tensor. 16659001277SLei Zhang if (compositeKind1 != compositeKind2) 16759001277SLei Zhang return {}; 16859001277SLei Zhang resultCompositeKind = compositeKind1; 16959001277SLei Zhang } else if (compositeKind1) { 17059001277SLei Zhang resultCompositeKind = compositeKind1; 17159001277SLei Zhang } else if (compositeKind2) { 17259001277SLei Zhang resultCompositeKind = compositeKind2; 17359001277SLei Zhang } 17459001277SLei Zhang 17559001277SLei Zhang // Get the shape of each type. 176eeadfbc1SLei Zhang SmallVector<int64_t, 4> resultShape; 177eeadfbc1SLei Zhang if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape)) 17859001277SLei Zhang return {}; 17959001277SLei Zhang 18059001277SLei Zhang // Compose the final broadcasted type 181c8c45985SRiver Riddle if (resultCompositeKind == VectorType::getTypeID()) 182b70e4efbSJacques Pienaar return VectorType::get(resultShape, elementType); 183c8c45985SRiver Riddle if (resultCompositeKind == RankedTensorType::getTypeID()) 184b70e4efbSJacques Pienaar return RankedTensorType::get(resultShape, elementType); 185b70e4efbSJacques Pienaar return elementType; 18659001277SLei Zhang } 18759001277SLei Zhang 188b70e4efbSJacques Pienaar /// Returns a tuple corresponding to whether range has tensor or vector type. 189b70e4efbSJacques Pienaar template <typename iterator_range> 190b70e4efbSJacques Pienaar static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) { 191971b8525SJakub Kuderski return {llvm::any_of(types, llvm::IsaPred<TensorType>), 192971b8525SJakub Kuderski llvm::any_of(types, llvm::IsaPred<VectorType>)}; 193e1595df1SLei Zhang } 194e1595df1SLei Zhang 19551cbe4e5SJacques Pienaar static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred, 19651cbe4e5SJacques Pienaar ArrayRef<int64_t> existing) { 197bd077e98SRafael Ubal Tena // If both interred and existing dimensions are static, they must be equal. 198b2d76a06SRafael Ubal Tena auto isCompatible = [](int64_t inferredDim, int64_t existingDim) { 199bd077e98SRafael Ubal Tena return ShapedType::isDynamic(existingDim) || 200bd077e98SRafael Ubal Tena ShapedType::isDynamic(inferredDim) || inferredDim == existingDim; 201cce2f4c4SSmit Hinsu }; 20251cbe4e5SJacques Pienaar if (inferred.size() != existing.size()) 203c253c6ebSJacques Pienaar return false; 204971b8525SJakub Kuderski for (auto [inferredDim, existingDim] : llvm::zip_equal(inferred, existing)) 205b2d76a06SRafael Ubal Tena if (!isCompatible(inferredDim, existingDim)) 206c253c6ebSJacques Pienaar return false; 207c253c6ebSJacques Pienaar return true; 208cce2f4c4SSmit Hinsu } 209cce2f4c4SSmit Hinsu 210b70e4efbSJacques Pienaar static std::string getShapeString(ArrayRef<int64_t> shape) { 211b70e4efbSJacques Pienaar // TODO: should replace with printing shape more uniformly across here and 212b70e4efbSJacques Pienaar // when in type. 21351cbe4e5SJacques Pienaar std::string ret; 21451cbe4e5SJacques Pienaar llvm::raw_string_ostream ss(ret); 21551cbe4e5SJacques Pienaar ss << '\''; 21651cbe4e5SJacques Pienaar llvm::interleave( 21751cbe4e5SJacques Pienaar shape, ss, 21851cbe4e5SJacques Pienaar [&](int64_t dim) { 21951cbe4e5SJacques Pienaar if (ShapedType::isDynamic(dim)) 22051cbe4e5SJacques Pienaar ss << '?'; 22151cbe4e5SJacques Pienaar else 22251cbe4e5SJacques Pienaar ss << dim; 22351cbe4e5SJacques Pienaar }, 22451cbe4e5SJacques Pienaar "x"); 22551cbe4e5SJacques Pienaar ss << '\''; 226*884221edSJOE1994 return ret; 227b0be00c7SLei Zhang } 228e1595df1SLei Zhang 229b70e4efbSJacques Pienaar LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) { 230b70e4efbSJacques Pienaar // Ensure broadcasting only tensor or only vector types. 231b70e4efbSJacques Pienaar auto operandsHasTensorVectorType = 232b70e4efbSJacques Pienaar hasTensorOrVectorType(op->getOperandTypes()); 233b70e4efbSJacques Pienaar auto resultsHasTensorVectorType = hasTensorOrVectorType(op->getResultTypes()); 234b70e4efbSJacques Pienaar if ((std::get<0>(operandsHasTensorVectorType) || 235b70e4efbSJacques Pienaar std::get<0>(resultsHasTensorVectorType)) && 236b70e4efbSJacques Pienaar (std::get<1>(operandsHasTensorVectorType) || 237b70e4efbSJacques Pienaar std::get<1>(resultsHasTensorVectorType))) 238b70e4efbSJacques Pienaar return op->emitError("cannot broadcast vector with tensor"); 239b70e4efbSJacques Pienaar 240971b8525SJakub Kuderski auto rankedOperands = 241971b8525SJakub Kuderski make_filter_range(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>); 242b70e4efbSJacques Pienaar 243b70e4efbSJacques Pienaar // If all operands are unranked, then all result shapes are possible. 244b70e4efbSJacques Pienaar if (rankedOperands.empty()) 245b70e4efbSJacques Pienaar return success(); 246b70e4efbSJacques Pienaar 247b70e4efbSJacques Pienaar // Compute broadcasted shape of operands (which requires that operands are 248b70e4efbSJacques Pienaar // broadcast compatible). The results need to be broadcast compatible with 249b70e4efbSJacques Pienaar // this result shape. 250b0be00c7SLei Zhang SmallVector<int64_t, 4> resultShape; 251b70e4efbSJacques Pienaar (void)util::getBroadcastedShape(getShape(*rankedOperands.begin()), {}, 252b70e4efbSJacques Pienaar resultShape); 253b70e4efbSJacques Pienaar for (auto other : make_early_inc_range(rankedOperands)) { 254b70e4efbSJacques Pienaar SmallVector<int64_t, 4> temp = resultShape; 255b70e4efbSJacques Pienaar if (!util::getBroadcastedShape(temp, getShape(other), resultShape)) 256b0be00c7SLei Zhang return op->emitOpError("operands don't have broadcast-compatible shapes"); 257b70e4efbSJacques Pienaar } 258b0be00c7SLei Zhang 259971b8525SJakub Kuderski auto rankedResults = 260971b8525SJakub Kuderski make_filter_range(op->getResultTypes(), llvm::IsaPred<RankedTensorType>); 26159001277SLei Zhang 262e5a85126SKazuaki Ishizaki // If all of the results are unranked then no further verification. 263b70e4efbSJacques Pienaar if (rankedResults.empty()) 264b70e4efbSJacques Pienaar return success(); 265b70e4efbSJacques Pienaar 266b70e4efbSJacques Pienaar for (auto type : rankedResults) { 267b70e4efbSJacques Pienaar ArrayRef<int64_t> actualSuffix = 268b70e4efbSJacques Pienaar getShape(type).take_back(resultShape.size()); 26951cbe4e5SJacques Pienaar if (!isCompatibleInferredReturnShape(resultShape, actualSuffix)) 270b70e4efbSJacques Pienaar return op->emitOpError() 271b70e4efbSJacques Pienaar << "result type " << getShapeString(getShape(type)) 272b70e4efbSJacques Pienaar << " not broadcast compatible with broadcasted operands's shapes " 273b70e4efbSJacques Pienaar << getShapeString(resultShape); 274b70e4efbSJacques Pienaar } 27567a52c44SRiver Riddle return success(); 27659001277SLei Zhang } 277