//===- Traits.cpp - Common op traits shared by dialects -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Traits.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/Support/FormatVariadic.h"
#include <optional>

using namespace mlir;

bool OpTrait::util::staticallyKnownBroadcastable(ArrayRef<int64_t> shape1,
                                                 ArrayRef<int64_t> shape2) {
  SmallVector<SmallVector<int64_t, 6>, 2> extents;
  extents.emplace_back(shape1.begin(), shape1.end());
  extents.emplace_back(shape2.begin(), shape2.end());
  return staticallyKnownBroadcastable(extents);
}
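
// A minimal illustrative sketch, assuming hypothetical inputs (these calls do
// not appear in the library itself):
//
//   staticallyKnownBroadcastable({3, 1}, {1, 4});              // true
//   staticallyKnownBroadcastable({3, 2}, {4});                 // false
//   staticallyKnownBroadcastable({ShapedType::kDynamic}, {4}); // false
//
// The last call fails because the dynamic extent may turn out to be neither 1
// nor 4 at runtime, so broadcastability is not statically known.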

bool OpTrait::util::staticallyKnownBroadcastable(
    ArrayRef<SmallVector<int64_t, 6>> shapes) {
  assert(!shapes.empty() && "Expected at least one shape");
  size_t maxRank = shapes[0].size();
  for (size_t i = 1; i != shapes.size(); ++i)
    maxRank = std::max(maxRank, shapes[i].size());

  // We look backwards through every column of `shapes`.
  for (size_t i = 0; i != maxRank; ++i) {
    bool seenDynamic = false;
    std::optional<int64_t> nonOneDim;
    for (ArrayRef<int64_t> extent : shapes) {
      int64_t dim = i >= extent.size() ? 1 : extent[extent.size() - i - 1];

      if (dim == 1)
        continue;

      // Dimensions are compatible when
      //   1. One is dynamic, the rest are 1, or
      if (ShapedType::isDynamic(dim)) {
        if (seenDynamic || nonOneDim)
          return false;
        seenDynamic = true;
      }

      //   2. All are 1 or a specific constant.
      if (nonOneDim && dim != *nonOneDim)
        return false;

      nonOneDim = dim;
    }
  }
  return true;
}
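
// For example (an illustrative sketch with assumed values), the column scan
// accepts {2, 1, 4}, {1, 4}, and {4} together, since every column contains at
// most one distinct non-one extent. It rejects {ShapedType::kDynamic, 4}
// against {3, 4}: the dynamic extent shares a column with the constant 3, so
// rule 1 above is violated.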

bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
                                        ArrayRef<int64_t> shape2,
                                        SmallVectorImpl<int64_t> &resultShape) {
  // To compute the result broadcasted shape, we compare operand shapes
  // element-wise: starting with the trailing dimensions, and working our
  // way backward. Two dimensions are compatible when
  //   1. they are equal, or
  //   2. one of them is 1.
  // The result shape has the maximum of the two inputs at every
  // dimension index.

  resultShape.clear();
  if (shape1.size() > shape2.size()) {
    std::copy(shape1.begin(), shape1.end(), std::back_inserter(resultShape));
  } else {
    std::copy(shape2.begin(), shape2.end(), std::back_inserter(resultShape));
  }

  auto i1 = shape1.rbegin(), e1 = shape1.rend();
  auto i2 = shape2.rbegin(), e2 = shape2.rend();
  auto iR = resultShape.rbegin();

  // Check each dimension is consistent.
  for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) {
    if (ShapedType::isDynamic(*i1) || ShapedType::isDynamic(*i2)) {
      // One or both dimensions are unknown. Follow TensorFlow behavior:
      // - If either dimension is greater than 1, we assume that the program is
      //   correct, and the other dimension will be broadcast to match it.
      // - If either dimension is 1, the other dimension is the output.
      if (*i1 > 1) {
        *iR = *i1;
      } else if (*i2 > 1) {
        *iR = *i2;
      } else if (*i1 == 1) {
        *iR = *i2;
      } else if (*i2 == 1) {
        *iR = *i1;
      } else {
        *iR = ShapedType::kDynamic;
      }
    } else {
      if (*i1 == *i2 || *i2 == 1) {
        *iR = *i1;
      } else if (*i1 == 1) {
        *iR = *i2;
      } else {
        // This dimension of the two operand types is incompatible.
        resultShape.clear();
        return false;
      }
    }
  }

  return true;
}
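
// Illustrative examples (assumed values, not exercised in this file; '?'
// denotes a dynamic extent, i.e. ShapedType::kDynamic):
//
//   getBroadcastedShape({2, 1, 4}, {3, 4}, result); // true, result = {2, 3, 4}
//   getBroadcastedShape({?, 4}, {3, 4}, result);    // true, result = {3, 4}
//   getBroadcastedShape({2, 4}, {3, 4}, result);    // false (2 vs 3)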

/// Returns the shape of the given type. Scalars will be considered as having a
/// shape with zero dimensions.
static ArrayRef<int64_t> getShape(Type type) {
  if (auto sType = dyn_cast<ShapedType>(type))
    return sType.getShape();
  return {};
}

/// Returns the broadcasted result type composed from the two given types,
/// following NumPy broadcast semantics. The returned type may have a dynamic
/// shape if either input type has a dynamic shape. Returns a null type if the
/// two given types are not broadcast-compatible.
///
/// elementType, if specified, will be used as the element type of the
/// broadcasted result type. Otherwise, type1 and type2 are required to have
/// the same element type, and that element type is used for the result.
Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
                                       Type elementType) {
  // If elementType is not specified, use the common element type of the
  // inputs, or fail if there is no common element type.
  if (!elementType) {
    elementType = getElementTypeOrSelf(type1);
    if (elementType != getElementTypeOrSelf(type2))
      return {};
  }

  // If one of the types is an unranked tensor, then the other type must not be
  // a vector, and the result is an unranked tensor type.
  if (isa<UnrankedTensorType>(type1) || isa<UnrankedTensorType>(type2)) {
    if (isa<VectorType>(type1) || isa<VectorType>(type2))
      return {};
    return UnrankedTensorType::get(elementType);
  }

  // Returns the type kind if the given type is a vector or ranked tensor type.
  // Returns std::nullopt otherwise.
  auto getCompositeTypeKind = [](Type type) -> std::optional<TypeID> {
    if (isa<VectorType, RankedTensorType>(type))
      return type.getTypeID();
    return std::nullopt;
  };

  // Make sure the composite type, if any, is consistent.
  std::optional<TypeID> compositeKind1 = getCompositeTypeKind(type1);
  std::optional<TypeID> compositeKind2 = getCompositeTypeKind(type2);
  std::optional<TypeID> resultCompositeKind;

  if (compositeKind1 && compositeKind2) {
    // Disallow mixing vector and tensor.
    if (compositeKind1 != compositeKind2)
      return {};
    resultCompositeKind = compositeKind1;
  } else if (compositeKind1) {
    resultCompositeKind = compositeKind1;
  } else if (compositeKind2) {
    resultCompositeKind = compositeKind2;
  }

  // Compute the broadcasted shape from the shapes of the two types.
  SmallVector<int64_t, 4> resultShape;
  if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
    return {};

  // Compose the final broadcasted type.
  if (resultCompositeKind == VectorType::getTypeID())
    return VectorType::get(resultShape, elementType);
  if (resultCompositeKind == RankedTensorType::getTypeID())
    return RankedTensorType::get(resultShape, elementType);
  return elementType;
}
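
// A sketch of the composition rules with assumed inputs (illustrative only,
// shown in MLIR type syntax):
//
//   getBroadcastedType(tensor<2x1xf32>, tensor<3xf32>) -> tensor<2x3xf32>
//   getBroadcastedType(vector<4xf32>, tensor<4xf32>)   -> null (mixed kinds)
//   getBroadcastedType(f32, tensor<*xf32>)             -> tensor<*xf32>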

/// Returns a tuple of flags indicating whether the given range contains a
/// tensor type and whether it contains a vector type.
template <typename iterator_range>
static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) {
  return {llvm::any_of(types, llvm::IsaPred<TensorType>),
          llvm::any_of(types, llvm::IsaPred<VectorType>)};
}

static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred,
                                            ArrayRef<int64_t> existing) {
  // If both inferred and existing dimensions are static, they must be equal.
  auto isCompatible = [](int64_t inferredDim, int64_t existingDim) {
    return ShapedType::isDynamic(existingDim) ||
           ShapedType::isDynamic(inferredDim) || inferredDim == existingDim;
  };
  if (inferred.size() != existing.size())
    return false;
  for (auto [inferredDim, existingDim] : llvm::zip_equal(inferred, existing))
    if (!isCompatible(inferredDim, existingDim))
      return false;
  return true;
}
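
// For instance (assumed values): an inferred shape {2, ?} is compatible with
// an existing {2, 4}, since a dynamic dimension matches anything, while an
// inferred {2, 4} is incompatible with an existing {3, 4} or with any shape
// of a different rank.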

static std::string getShapeString(ArrayRef<int64_t> shape) {
  // TODO: Replace with a more uniform way of printing shapes, shared between
  // here and type printing.
  std::string ret;
  llvm::raw_string_ostream ss(ret);
  ss << '\'';
  llvm::interleave(
      shape, ss,
      [&](int64_t dim) {
        if (ShapedType::isDynamic(dim))
          ss << '?';
        else
          ss << dim;
      },
      "x");
  ss << '\'';
  return ret;
}
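
// For example (illustrative), the shape {2, ShapedType::kDynamic, 4} renders
// as '2x?x4', matching how dynamic dimensions appear in type names.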

LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
  // Ensure that we are broadcasting either only tensors or only vectors.
  auto operandsHasTensorVectorType =
      hasTensorOrVectorType(op->getOperandTypes());
  auto resultsHasTensorVectorType = hasTensorOrVectorType(op->getResultTypes());
  if ((std::get<0>(operandsHasTensorVectorType) ||
       std::get<0>(resultsHasTensorVectorType)) &&
      (std::get<1>(operandsHasTensorVectorType) ||
       std::get<1>(resultsHasTensorVectorType)))
    return op->emitError("cannot broadcast vector with tensor");

  auto rankedOperands =
      make_filter_range(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);

  // If all operands are unranked, then all result shapes are possible.
  if (rankedOperands.empty())
    return success();

  // Compute the broadcasted shape of the operands (which requires that the
  // operands be broadcast compatible). The results then need to be broadcast
  // compatible with this shape.
  SmallVector<int64_t, 4> resultShape;
  (void)util::getBroadcastedShape(getShape(*rankedOperands.begin()), {},
                                  resultShape);
  for (auto other : make_early_inc_range(rankedOperands)) {
    SmallVector<int64_t, 4> temp = resultShape;
    if (!util::getBroadcastedShape(temp, getShape(other), resultShape))
      return op->emitOpError("operands don't have broadcast-compatible shapes");
  }

  auto rankedResults =
      make_filter_range(op->getResultTypes(), llvm::IsaPred<RankedTensorType>);

  // If all of the results are unranked, there is nothing further to verify.
  if (rankedResults.empty())
    return success();

  for (auto type : rankedResults) {
    ArrayRef<int64_t> actualSuffix =
        getShape(type).take_back(resultShape.size());
    if (!isCompatibleInferredReturnShape(resultShape, actualSuffix))
      return op->emitOpError()
             << "result type " << getShapeString(getShape(type))
             << " not broadcast compatible with broadcasted operands' shapes "
             << getShapeString(resultShape);
  }
  return success();
}
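
// As an illustrative sketch, consider a hypothetical op carrying this trait
// (the op name "test.broadcast" is assumed for the example):
//
//   %0 = "test.broadcast"(%a, %b)
//       : (tensor<2x1xf32>, tensor<3xf32>) -> tensor<2x3xf32>
//
// verifies, while declaring the result as tensor<2x4xf32> would be rejected
// as not broadcast compatible with the operands' broadcasted shape '2x3'.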
277