xref: /llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.h (revision 36d936a2d057ddbd7822614edf01e39a0c21d654)
1 //===- InferTypeOpInterface.h - Infer Type Interfaces -----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the definitions of the infer op interfaces defined in
10 // `InferTypeOpInterface.td`.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_
15 #define MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_
16 
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/Location.h"
21 #include "mlir/IR/OpDefinition.h"
22 #include "mlir/Support/LLVM.h"
23 #include "llvm/ADT/PointerUnion.h"
24 #include "llvm/ADT/SmallVector.h"
25 
26 namespace mlir {
27 
28 class ShapedTypeComponents;
29 using ReifiedRankedShapedTypeDims = SmallVector<SmallVector<OpFoldResult>>;
30 
31 /// Reify the shape of the result of an operation (typically in terms of the
32 /// shape of its operands).
33 LogicalResult
34 reifyResultShapes(OpBuilder &b, Operation *op,
35                   ReifiedRankedShapedTypeDims &reifiedReturnShapes);
36 
37 /// Adaptor class to abstract the differences between whether value is from
38 /// a ShapedType or ShapedTypeComponents or DenseIntElementsAttribute.
39 class ShapeAdaptor {
40 public:
41   ShapeAdaptor(Type t) {
42     if (auto st = dyn_cast<ShapedType>(t))
43       val = st;
44   }
45   ShapeAdaptor(Attribute t) {
46     if (auto da = dyn_cast<DenseIntElementsAttr>(t))
47       val = da;
48   }
49   ShapeAdaptor(ShapedTypeComponents *components) : val(components) {}
50   ShapeAdaptor(ShapedTypeComponents &components) : val(&components) {}
51 
52   /// Returns whether the shape has a rank.
53   bool hasRank() const;
54 
55   /// Returns the element type.
56   Type getElementType() const;
57 
58   /// Populates the dimensions from shape referenced.
59   /// Requires: shape is ranked.
60   void getDims(SmallVectorImpl<int64_t> &res) const;
61 
62   /// Populates the dimensions of the ShapeTypeComponents.
63   /// Requires: shape is ranked.
64   void getDims(ShapedTypeComponents &res) const;
65 
66   /// Returns the size of the index'th dimension.
67   /// Requires: shape is ranked.
68   int64_t getDimSize(int index) const;
69 
70   /// Returns whether the index'th dimension is dynamic.
71   /// Requires: shape is ranked.
72   bool isDynamicDim(int index) const {
73     return ShapedType::isDynamic(getDimSize(index));
74   }
75 
76   /// Returns whether the shape is fully static.
77   bool hasStaticShape() const;
78 
79   /// Returns the rank of the shape.
80   /// Requires: shape is ranked.
81   int64_t getRank() const;
82 
83   /// Returns the number of elements in the shape.
84   /// Requires: hasStaticShape
85   int64_t getNumElements() const;
86 
87   /// Returns whether valid (non-null) shape.
88   explicit operator bool() const { return !val.isNull(); }
89 
90   /// Dumps textual repesentation to stderr.
91   void dump() const;
92 
93 private:
94   // Union storing either ShapedTypeComponents, ShapedType (stored as Type and
95   // casted), or DenseIntElementsAttribute (stored as Atrtribute).
96   PointerUnion<ShapedTypeComponents *, Type, Attribute> val = nullptr;
97 };
98 
99 /// ShapedTypeComponents that represents the components of a ShapedType.
100 /// The components consist of
101 ///  - A ranked or unranked shape with the dimension specification match those
102 ///    of ShapeType's getShape() (e.g., dynamic dimension represented using
103 ///    ShapedType::kDynamic)
104 ///  - A element type, may be unset (nullptr)
105 ///  - A attribute, may be unset (nullptr)
106 /// Used by ShapedType type inferences.
107 class ShapedTypeComponents {
108   /// Internal storage type for shape.
109   using ShapeStorageT = SmallVector<int64_t, 3>;
110 
111 public:
112   /// Default construction is an unranked shape.
113   ShapedTypeComponents() : elementType(nullptr), attr(nullptr) {};
114   ShapedTypeComponents(Type elementType)
115       : elementType(elementType), attr(nullptr), ranked(false) {}
116   ShapedTypeComponents(ShapedType shapedType) : attr(nullptr) {
117     ranked = shapedType.hasRank();
118     elementType = shapedType.getElementType();
119     if (ranked)
120       dims = llvm::to_vector<4>(shapedType.getShape());
121   }
122   ShapedTypeComponents(ShapeAdaptor adaptor) : attr(nullptr) {
123     ranked = adaptor.hasRank();
124     elementType = adaptor.getElementType();
125     if (ranked)
126       adaptor.getDims(*this);
127   }
128   template <typename Arg, typename = std::enable_if_t<
129                               std::is_constructible<ShapeStorageT, Arg>::value>>
130   ShapedTypeComponents(Arg &&arg, Type elementType = nullptr,
131                        Attribute attr = nullptr)
132       : dims(std::forward<Arg>(arg)), elementType(elementType), attr(attr),
133         ranked(true) {}
134   ShapedTypeComponents(ArrayRef<int64_t> vec, Type elementType = nullptr,
135                        Attribute attr = nullptr)
136       : dims(vec.begin(), vec.end()), elementType(elementType), attr(attr),
137         ranked(true) {}
138 
139   /// Return the dimensions of the shape.
140   /// Requires: shape is ranked.
141   ArrayRef<int64_t> getDims() const {
142     assert(ranked && "requires ranked shape");
143     return dims;
144   }
145 
146   /// Return whether the shape has a rank.
147   bool hasRank() const { return ranked; };
148 
149   /// Return the element type component.
150   Type getElementType() const { return elementType; };
151 
152   /// Return the raw attribute component.
153   Attribute getAttribute() const { return attr; };
154 
155 private:
156   friend class ShapeAdaptor;
157 
158   ShapeStorageT dims;
159   Type elementType;
160   Attribute attr;
161   bool ranked{false};
162 };
163 
164 /// Range of values and shapes (corresponding effectively to Shapes dialect's
165 /// ValueShape type concept).
166 // Currently this exposes the Value (of operands) and Type of the Value. This is
167 // not ideal as then one can accidentally reference an out of date shape. This
168 // is done to both enable gradual switch and also as OpAdaptor doesn't currently
169 // allow returning anything other than Value.
170 class ValueShapeRange : public ValueRange::RangeBaseT {
171 public:
172   using ValueShapeMapFn = function_ref<ShapeAdaptor(Value)>;
173 
174   ValueShapeRange(ValueRange values, ValueShapeMapFn operandShape = nullptr,
175                   ValueShapeMapFn valueToShape = nullptr)
176       : RangeBaseT(values), operandShape(operandShape),
177         valueToShape(valueToShape) {}
178   ValueShapeRange(const std::initializer_list<Value> &values)
179       : ValueShapeRange(ValueRange(values)) {}
180 
181   ValueShapeRange(const ValueShapeRange &) = default;
182 
183   /// Sets the Value to ShapeAdaptor mapping function and returns this.
184   ValueShapeRange &setValueToShapeMapping(ValueShapeMapFn fn) {
185     valueToShape = fn;
186     return *this;
187   }
188 
189   ValueShapeRange &setOperandShapeMapping(ValueShapeMapFn fn) {
190     operandShape = fn;
191     return *this;
192   }
193 
194   /// Returns the set Value to ShapeAdaptor mapping function.
195   ValueShapeMapFn getValueToShapeMapping() const { return valueToShape; }
196   ValueShapeMapFn getOperandShapeMapping() const { return operandShape; }
197 
198   // Accessors.
199 
200   /// Returns the types of the values within this range.
201   /// Note: This returns only the types of Values in the ValueRange and not a
202   /// more refined type.
203   using type_iterator = ValueTypeIterator<iterator>;
204   using type_range = ValueTypeRange<ValueRange>;
205   type_range getTypes() const { return {begin(), end()}; }
206   auto getType() const { return getTypes(); }
207 
208   /// Returns the Values in the ValueRange.
209   /// To query the most up to date shape of a Value, query the shape
210   /// using getShape below rather than using the type of the Value.
211   ValueRange getValues() const { return ValueRange(begin(), end()); };
212 
213   /// Returns an argument as shape. If the argument is not constant or not a
214   /// shape, then the function returns a nullptr.
215   /// This will first query the valueToShape mapping (if set), before querying
216   /// the ValueRange.
217   ShapeAdaptor getValueAsShape(int index);
218 
219   /// Returns the shape of index'th operand.
220   // TODO: Update so that operator[] references these instead to avoid
221   // accidentally refering to less refined shape.
222   ShapeAdaptor getShape(int index) const;
223 
224   /// Returns the shape of the given Value.
225   ShapeAdaptor getShape(Value val) const;
226 
227 private:
228   // Mapping from Value to ShapedTypeComponents corresponding to shape of type
229   // of Value.
230   ValueShapeMapFn operandShape;
231 
232   // Mapping from Value to ShapedTypeComponents corresponding to constant Value
233   // if interpreted as shape.
234   ValueShapeMapFn valueToShape;
235 };
236 
237 namespace detail {
238 // Helper function to infer return tensor returns types given element and
239 // shape inference function.
240 LogicalResult
241 inferReturnTensorTypes(ArrayRef<ShapedTypeComponents> retComponents,
242                        SmallVectorImpl<Type> &inferredReturnTypes);
243 
244 /// Verifies that the inferred result types match the actual result types for
245 /// the op. Precondition: op implements InferTypeOpInterface.
246 LogicalResult verifyInferredResultTypes(Operation *op);
247 
248 /// Report a fatal error indicating that the result types could not be
249 /// inferred.
250 void reportFatalInferReturnTypesError(OperationState &state);
251 } // namespace detail
252 
namespace OpTrait {
/// Forward declaration of the InferTensorType trait; its definition follows
/// the generated interface declarations further down in this header.
template <class ConcreteType> class InferTensorType;
} // namespace OpTrait
256 } // namespace OpTrait
257 } // namespace mlir
258 
259 /// Include the generated interface declarations.
260 #include "mlir/Interfaces/InferTypeOpInterface.h.inc"
261 
262 namespace mlir {
263 namespace OpTrait {
264 
265 template <typename ConcreteType>
266 class InferTypeOpAdaptor : public TraitBase<ConcreteType, InferTypeOpAdaptor> {
267 };
268 
269 template <typename ConcreteType>
270 class InferShapedTypeOpAdaptor
271     : public TraitBase<ConcreteType, InferShapedTypeOpAdaptor> {};
272 
273 /// Tensor type inference trait that constructs a tensor from the inferred
274 /// shape and elemental types.
275 /// Requires: Op implements InferShapedTypeOpInterface and InferTypeOpInterface.
276 ///   Less strict is possible (e.g., implements inferReturnTypeComponents and
277 ///   these always populates all element types and shapes or fails, but this
278 ///   trait is currently only used where the interfaces are, so keep it
279 ///   restricted for now).
280 template <typename ConcreteType>
281 class InferTensorType : public TraitBase<ConcreteType, InferTensorType> {};
282 
283 } // namespace OpTrait
284 } // namespace mlir
285 
286 #endif // MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_
287