//===- Traits.cpp - Common op traits shared by dialects -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Traits.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/Support/FormatVariadic.h"
#include <optional>

using namespace mlir;

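// A quick sketch of the intended semantics, for illustration only (`d` below
// stands for ShapedType::kDynamic; these calls are not part of this file):
//   staticallyKnownBroadcastable({2, 1, 4}, {2, 3, 4}) -> true   (1 broadcasts to 3)
//   staticallyKnownBroadcastable({2, 5},    {2, 3})    -> false  (5 vs. 3 conflict)
//   staticallyKnownBroadcastable({d, 2},    {1, 2})    -> true   (dynamic vs. all 1s)
//   staticallyKnownBroadcastable({d, 2},    {3, 2})    -> false  (d may differ from 3)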
bool OpTrait::util::staticallyKnownBroadcastable(ArrayRef<int64_t> shape1,
                                                 ArrayRef<int64_t> shape2) {
  SmallVector<SmallVector<int64_t, 6>, 2> extents;
  extents.emplace_back(shape1.begin(), shape1.end());
  extents.emplace_back(shape2.begin(), shape2.end());
  return staticallyKnownBroadcastable(extents);
}

bool OpTrait::util::staticallyKnownBroadcastable(
    ArrayRef<SmallVector<int64_t, 6>> shapes) {
  assert(!shapes.empty() && "Expected at least one shape");
  size_t maxRank = shapes[0].size();
  for (size_t i = 1; i != shapes.size(); ++i)
    maxRank = std::max(maxRank, shapes[i].size());

  // We look backwards through every column of `shapes`.
  for (size_t i = 0; i != maxRank; ++i) {
    bool seenDynamic = false;
    std::optional<int64_t> nonOneDim;
    for (ArrayRef<int64_t> extent : shapes) {
      int64_t dim = i >= extent.size() ? 1 : extent[extent.size() - i - 1];

      if (dim == 1)
        continue;

      // Dimensions are compatible when
      //   1. One is dynamic, the rest are 1.
      if (ShapedType::isDynamic(dim)) {
        if (seenDynamic || nonOneDim)
          return false;
        seenDynamic = true;
      }

      //   2. All are 1 or a specific constant.
      if (nonOneDim && dim != *nonOneDim)
        return false;

      nonOneDim = dim;
    }
  }
  return true;
}

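// A quick sketch of the behavior, for illustration only (`d` stands for
// ShapedType::kDynamic, `r` for the output vector):
//   getBroadcastedShape({3, 1, 5}, {4, 5}, r) -> true,  r == {3, 4, 5}
//   getBroadcastedShape({d, 2},    {3, 2}, r) -> true,  r == {3, 2}
//   getBroadcastedShape({3, 5},    {4, 5}, r) -> false, r is left empty
// Note that, unlike staticallyKnownBroadcastable, this is optimistic: it
// assumes a dynamic dimension will broadcast to match a static one.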
bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
                                        ArrayRef<int64_t> shape2,
                                        SmallVectorImpl<int64_t> &resultShape) {
  // To compute the result broadcasted shape, we compare operand shapes
  // element-wise: starting with the trailing dimensions, and working our way
  // backward. Two dimensions are compatible when
  //   1. they are equal, or
  //   2. one of them is 1.
  // The result shape has the maximum of the two inputs at every dimension
  // index.

  resultShape.clear();
  if (shape1.size() > shape2.size()) {
    std::copy(shape1.begin(), shape1.end(), std::back_inserter(resultShape));
  } else {
    std::copy(shape2.begin(), shape2.end(), std::back_inserter(resultShape));
  }

  auto i1 = shape1.rbegin(), e1 = shape1.rend();
  auto i2 = shape2.rbegin(), e2 = shape2.rend();
  auto iR = resultShape.rbegin();

  // Check each dimension is consistent.
  for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) {
    if (ShapedType::isDynamic(*i1) || ShapedType::isDynamic(*i2)) {
      // One or both dimensions is unknown. Follow TensorFlow behavior:
      // - If either dimension is greater than 1, we assume that the program is
      //   correct, and the other dimension will be broadcast to match it.
      // - If either dimension is 1, the other dimension is the output.
      if (*i1 > 1) {
        *iR = *i1;
      } else if (*i2 > 1) {
        *iR = *i2;
      } else if (*i1 == 1) {
        *iR = *i2;
      } else if (*i2 == 1) {
        *iR = *i1;
      } else {
        *iR = ShapedType::kDynamic;
      }
    } else {
      if (*i1 == *i2 || *i2 == 1) {
        *iR = *i1;
      } else if (*i1 == 1) {
        *iR = *i2;
      } else {
        // This dimension of the two operand types is incompatible.
        resultShape.clear();
        return false;
      }
    }
  }

  return true;
}

/// Returns the shape of the given type. Scalars will be considered as having a
/// shape with zero dimensions.
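/// For example (a sketch, in MLIR type notation): getShape(tensor<2x3xf32>)
/// yields {2, 3}, while getShape(f32) yields the empty shape {}.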
static ArrayRef<int64_t> getShape(Type type) {
  if (auto sType = dyn_cast<ShapedType>(type))
    return sType.getShape();
  return {};
}

/// Returns the result broadcast composition type from the two given types by
/// following NumPy broadcast semantics. The returned type may have a dynamic
/// shape if either of the input types has a dynamic shape. Returns a null type
/// if the two given types are not broadcast-compatible.
///
/// elementType, if specified, will be used as the element type of the
/// broadcasted result type. Otherwise the element types of type1 and type2
/// must match, and that common element type is used as the resultant element
/// type.
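///
/// For example (a sketch, in MLIR type notation):
///   getBroadcastedType(tensor<1x4xf32>, tensor<3x1xf32>) -> tensor<3x4xf32>
///   getBroadcastedType(tensor<?x4xf32>, f32)             -> tensor<?x4xf32>
///   getBroadcastedType(vector<2xf32>,   tensor<2xf32>)   -> Type()  (mixing
///                                                          vector and tensor
///                                                          is rejected)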
Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
                                       Type elementType) {
  // If the elementType is not specified, then use the common element type of
  // the inputs or fail if there is no common element type.
  if (!elementType) {
    elementType = getElementTypeOrSelf(type1);
    if (elementType != getElementTypeOrSelf(type2))
      return {};
  }

  // If one of the types is an unranked tensor, then the other type shouldn't
  // be a vector and the result should have unranked tensor type.
  if (isa<UnrankedTensorType>(type1) || isa<UnrankedTensorType>(type2)) {
    if (isa<VectorType>(type1) || isa<VectorType>(type2))
      return {};
    return UnrankedTensorType::get(elementType);
  }

  // Returns the type kind if the given type is a vector or ranked tensor type.
  // Returns std::nullopt otherwise.
  auto getCompositeTypeKind = [](Type type) -> std::optional<TypeID> {
    if (isa<VectorType, RankedTensorType>(type))
      return type.getTypeID();
    return std::nullopt;
  };

  // Make sure the composite type, if any, is consistent.
  std::optional<TypeID> compositeKind1 = getCompositeTypeKind(type1);
  std::optional<TypeID> compositeKind2 = getCompositeTypeKind(type2);
  std::optional<TypeID> resultCompositeKind;

  if (compositeKind1 && compositeKind2) {
    // Disallow mixing vector and tensor.
    if (compositeKind1 != compositeKind2)
      return {};
    resultCompositeKind = compositeKind1;
  } else if (compositeKind1) {
    resultCompositeKind = compositeKind1;
  } else if (compositeKind2) {
    resultCompositeKind = compositeKind2;
  }

  // Get the shape of each type.
  SmallVector<int64_t, 4> resultShape;
  if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
    return {};

  // Compose the final broadcasted type.
  if (resultCompositeKind == VectorType::getTypeID())
    return VectorType::get(resultShape, elementType);
  if (resultCompositeKind == RankedTensorType::getTypeID())
    return RankedTensorType::get(resultShape, elementType);
  return elementType;
}

/// Returns a tuple of two bools: whether the given range contains any tensor
/// type, and whether it contains any vector type.
template <typename iterator_range>
static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) {
  return {llvm::any_of(types, llvm::IsaPred<TensorType>),
          llvm::any_of(types, llvm::IsaPred<VectorType>)};
}

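// A quick sketch of the compatibility rule below, for illustration only
// (`d` stands for ShapedType::kDynamic):
//   inferred {2, d} vs. existing {2, 5} -> compatible   (dynamic matches anything)
//   inferred {2, 4} vs. existing {2, 5} -> incompatible (static dims differ)
//   inferred {2}    vs. existing {2, 5} -> incompatible (ranks differ)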
static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred,
                                            ArrayRef<int64_t> existing) {
  // If both inferred and existing dimensions are static, they must be equal.
  auto isCompatible = [](int64_t inferredDim, int64_t existingDim) {
    return ShapedType::isDynamic(existingDim) ||
           ShapedType::isDynamic(inferredDim) || inferredDim == existingDim;
  };
  if (inferred.size() != existing.size())
    return false;
  for (auto [inferredDim, existingDim] : llvm::zip_equal(inferred, existing))
    if (!isCompatible(inferredDim, existingDim))
      return false;
  return true;
}

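// For example (a sketch): getShapeString({2, ShapedType::kDynamic, 4}) returns
// the string '2x?x4', including the surrounding single quotes.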
static std::string getShapeString(ArrayRef<int64_t> shape) {
  // TODO: Replace this with a more uniform way of printing shapes here and in
  // types.
  std::string ret;
  llvm::raw_string_ostream ss(ret);
  ss << '\'';
  llvm::interleave(
      shape, ss,
      [&](int64_t dim) {
        if (ShapedType::isDynamic(dim))
          ss << '?';
        else
          ss << dim;
      },
      "x");
  ss << '\'';
  return ret;
}

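// A quick sketch of what the verifier below accepts and rejects, using a
// hypothetical op "x.add" that carries this trait:
//   %0 = "x.add"(%a, %b)
//       : (tensor<1x4xf32>, tensor<3x1xf32>) -> tensor<3x4xf32>   // ok
// A result type of tensor<2x4xf32> would instead be rejected, since its
// trailing dimensions are not compatible with the broadcasted operand
// shape '3x4'.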
LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
  // Ensure broadcasting only tensor or only vector types.
  auto operandsHasTensorVectorType =
      hasTensorOrVectorType(op->getOperandTypes());
  auto resultsHasTensorVectorType = hasTensorOrVectorType(op->getResultTypes());
  if ((std::get<0>(operandsHasTensorVectorType) ||
       std::get<0>(resultsHasTensorVectorType)) &&
      (std::get<1>(operandsHasTensorVectorType) ||
       std::get<1>(resultsHasTensorVectorType)))
    return op->emitError("cannot broadcast vector with tensor");

  auto rankedOperands =
      make_filter_range(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);

  // If all operands are unranked, then all result shapes are possible.
  if (rankedOperands.empty())
    return success();

  // Compute the broadcasted shape of the operands (which requires that the
  // operands are broadcast compatible). The results need to be broadcast
  // compatible with this result shape.
  SmallVector<int64_t, 4> resultShape;
  (void)util::getBroadcastedShape(getShape(*rankedOperands.begin()), {},
                                  resultShape);
  for (auto other : make_early_inc_range(rankedOperands)) {
    SmallVector<int64_t, 4> temp = resultShape;
    if (!util::getBroadcastedShape(temp, getShape(other), resultShape))
      return op->emitOpError("operands don't have broadcast-compatible shapes");
  }

  auto rankedResults =
      make_filter_range(op->getResultTypes(), llvm::IsaPred<RankedTensorType>);

  // If all of the results are unranked, then no further verification is
  // needed.
  if (rankedResults.empty())
    return success();

  for (auto type : rankedResults) {
    ArrayRef<int64_t> actualSuffix =
        getShape(type).take_back(resultShape.size());
    if (!isCompatibleInferredReturnShape(resultShape, actualSuffix))
      return op->emitOpError()
             << "result type " << getShapeString(getShape(type))
             << " not broadcast compatible with broadcasted operands' shapes "
             << getShapeString(resultShape);
  }
  return success();
}