1 //===- InferTypeOpInterface.h - Infer Type Interfaces -----------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains the definitions of the infer op interfaces defined in 10 // `InferTypeOpInterface.td`. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_ 15 #define MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_ 16 17 #include "mlir/IR/Attributes.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/IR/BuiltinTypes.h" 20 #include "mlir/IR/Location.h" 21 #include "mlir/IR/OpDefinition.h" 22 #include "mlir/Support/LLVM.h" 23 #include "llvm/ADT/PointerUnion.h" 24 #include "llvm/ADT/SmallVector.h" 25 26 namespace mlir { 27 28 class ShapedTypeComponents; 29 using ReifiedRankedShapedTypeDims = SmallVector<SmallVector<OpFoldResult>>; 30 31 /// Reify the shape of the result of an operation (typically in terms of the 32 /// shape of its operands). 33 LogicalResult 34 reifyResultShapes(OpBuilder &b, Operation *op, 35 ReifiedRankedShapedTypeDims &reifiedReturnShapes); 36 37 /// Adaptor class to abstract the differences between whether value is from 38 /// a ShapedType or ShapedTypeComponents or DenseIntElementsAttribute. 39 class ShapeAdaptor { 40 public: 41 ShapeAdaptor(Type t) { 42 if (auto st = dyn_cast<ShapedType>(t)) 43 val = st; 44 } 45 ShapeAdaptor(Attribute t) { 46 if (auto da = dyn_cast<DenseIntElementsAttr>(t)) 47 val = da; 48 } 49 ShapeAdaptor(ShapedTypeComponents *components) : val(components) {} 50 ShapeAdaptor(ShapedTypeComponents &components) : val(&components) {} 51 52 /// Returns whether the shape has a rank. 53 bool hasRank() const; 54 55 /// Returns the element type. 56 Type getElementType() const; 57 58 /// Populates the dimensions from shape referenced. 59 /// Requires: shape is ranked. 60 void getDims(SmallVectorImpl<int64_t> &res) const; 61 62 /// Populates the dimensions of the ShapeTypeComponents. 63 /// Requires: shape is ranked. 64 void getDims(ShapedTypeComponents &res) const; 65 66 /// Returns the size of the index'th dimension. 67 /// Requires: shape is ranked. 68 int64_t getDimSize(int index) const; 69 70 /// Returns whether the index'th dimension is dynamic. 71 /// Requires: shape is ranked. 72 bool isDynamicDim(int index) const { 73 return ShapedType::isDynamic(getDimSize(index)); 74 } 75 76 /// Returns whether the shape is fully static. 77 bool hasStaticShape() const; 78 79 /// Returns the rank of the shape. 80 /// Requires: shape is ranked. 81 int64_t getRank() const; 82 83 /// Returns the number of elements in the shape. 84 /// Requires: hasStaticShape 85 int64_t getNumElements() const; 86 87 /// Returns whether valid (non-null) shape. 88 explicit operator bool() const { return !val.isNull(); } 89 90 /// Dumps textual repesentation to stderr. 91 void dump() const; 92 93 private: 94 // Union storing either ShapedTypeComponents, ShapedType (stored as Type and 95 // casted), or DenseIntElementsAttribute (stored as Atrtribute). 96 PointerUnion<ShapedTypeComponents *, Type, Attribute> val = nullptr; 97 }; 98 99 /// ShapedTypeComponents that represents the components of a ShapedType. 100 /// The components consist of 101 /// - A ranked or unranked shape with the dimension specification match those 102 /// of ShapeType's getShape() (e.g., dynamic dimension represented using 103 /// ShapedType::kDynamic) 104 /// - A element type, may be unset (nullptr) 105 /// - A attribute, may be unset (nullptr) 106 /// Used by ShapedType type inferences. 107 class ShapedTypeComponents { 108 /// Internal storage type for shape. 109 using ShapeStorageT = SmallVector<int64_t, 3>; 110 111 public: 112 /// Default construction is an unranked shape. 113 ShapedTypeComponents() : elementType(nullptr), attr(nullptr) {}; 114 ShapedTypeComponents(Type elementType) 115 : elementType(elementType), attr(nullptr), ranked(false) {} 116 ShapedTypeComponents(ShapedType shapedType) : attr(nullptr) { 117 ranked = shapedType.hasRank(); 118 elementType = shapedType.getElementType(); 119 if (ranked) 120 dims = llvm::to_vector<4>(shapedType.getShape()); 121 } 122 ShapedTypeComponents(ShapeAdaptor adaptor) : attr(nullptr) { 123 ranked = adaptor.hasRank(); 124 elementType = adaptor.getElementType(); 125 if (ranked) 126 adaptor.getDims(*this); 127 } 128 template <typename Arg, typename = std::enable_if_t< 129 std::is_constructible<ShapeStorageT, Arg>::value>> 130 ShapedTypeComponents(Arg &&arg, Type elementType = nullptr, 131 Attribute attr = nullptr) 132 : dims(std::forward<Arg>(arg)), elementType(elementType), attr(attr), 133 ranked(true) {} 134 ShapedTypeComponents(ArrayRef<int64_t> vec, Type elementType = nullptr, 135 Attribute attr = nullptr) 136 : dims(vec.begin(), vec.end()), elementType(elementType), attr(attr), 137 ranked(true) {} 138 139 /// Return the dimensions of the shape. 140 /// Requires: shape is ranked. 141 ArrayRef<int64_t> getDims() const { 142 assert(ranked && "requires ranked shape"); 143 return dims; 144 } 145 146 /// Return whether the shape has a rank. 147 bool hasRank() const { return ranked; }; 148 149 /// Return the element type component. 150 Type getElementType() const { return elementType; }; 151 152 /// Return the raw attribute component. 153 Attribute getAttribute() const { return attr; }; 154 155 private: 156 friend class ShapeAdaptor; 157 158 ShapeStorageT dims; 159 Type elementType; 160 Attribute attr; 161 bool ranked{false}; 162 }; 163 164 /// Range of values and shapes (corresponding effectively to Shapes dialect's 165 /// ValueShape type concept). 166 // Currently this exposes the Value (of operands) and Type of the Value. This is 167 // not ideal as then one can accidentally reference an out of date shape. This 168 // is done to both enable gradual switch and also as OpAdaptor doesn't currently 169 // allow returning anything other than Value. 170 class ValueShapeRange : public ValueRange::RangeBaseT { 171 public: 172 using ValueShapeMapFn = function_ref<ShapeAdaptor(Value)>; 173 174 ValueShapeRange(ValueRange values, ValueShapeMapFn operandShape = nullptr, 175 ValueShapeMapFn valueToShape = nullptr) 176 : RangeBaseT(values), operandShape(operandShape), 177 valueToShape(valueToShape) {} 178 ValueShapeRange(const std::initializer_list<Value> &values) 179 : ValueShapeRange(ValueRange(values)) {} 180 181 ValueShapeRange(const ValueShapeRange &) = default; 182 183 /// Sets the Value to ShapeAdaptor mapping function and returns this. 184 ValueShapeRange &setValueToShapeMapping(ValueShapeMapFn fn) { 185 valueToShape = fn; 186 return *this; 187 } 188 189 ValueShapeRange &setOperandShapeMapping(ValueShapeMapFn fn) { 190 operandShape = fn; 191 return *this; 192 } 193 194 /// Returns the set Value to ShapeAdaptor mapping function. 195 ValueShapeMapFn getValueToShapeMapping() const { return valueToShape; } 196 ValueShapeMapFn getOperandShapeMapping() const { return operandShape; } 197 198 // Accessors. 199 200 /// Returns the types of the values within this range. 201 /// Note: This returns only the types of Values in the ValueRange and not a 202 /// more refined type. 203 using type_iterator = ValueTypeIterator<iterator>; 204 using type_range = ValueTypeRange<ValueRange>; 205 type_range getTypes() const { return {begin(), end()}; } 206 auto getType() const { return getTypes(); } 207 208 /// Returns the Values in the ValueRange. 209 /// To query the most up to date shape of a Value, query the shape 210 /// using getShape below rather than using the type of the Value. 211 ValueRange getValues() const { return ValueRange(begin(), end()); }; 212 213 /// Returns an argument as shape. If the argument is not constant or not a 214 /// shape, then the function returns a nullptr. 215 /// This will first query the valueToShape mapping (if set), before querying 216 /// the ValueRange. 217 ShapeAdaptor getValueAsShape(int index); 218 219 /// Returns the shape of index'th operand. 220 // TODO: Update so that operator[] references these instead to avoid 221 // accidentally refering to less refined shape. 222 ShapeAdaptor getShape(int index) const; 223 224 /// Returns the shape of the given Value. 225 ShapeAdaptor getShape(Value val) const; 226 227 private: 228 // Mapping from Value to ShapedTypeComponents corresponding to shape of type 229 // of Value. 230 ValueShapeMapFn operandShape; 231 232 // Mapping from Value to ShapedTypeComponents corresponding to constant Value 233 // if interpreted as shape. 234 ValueShapeMapFn valueToShape; 235 }; 236 237 namespace detail { 238 // Helper function to infer return tensor returns types given element and 239 // shape inference function. 240 LogicalResult 241 inferReturnTensorTypes(ArrayRef<ShapedTypeComponents> retComponents, 242 SmallVectorImpl<Type> &inferredReturnTypes); 243 244 /// Verifies that the inferred result types match the actual result types for 245 /// the op. Precondition: op implements InferTypeOpInterface. 246 LogicalResult verifyInferredResultTypes(Operation *op); 247 248 /// Report a fatal error indicating that the result types could not be 249 /// inferred. 250 void reportFatalInferReturnTypesError(OperationState &state); 251 } // namespace detail 252 253 namespace OpTrait { 254 template <typename ConcreteType> 255 class InferTensorType; 256 } // namespace OpTrait 257 } // namespace mlir 258 259 /// Include the generated interface declarations. 260 #include "mlir/Interfaces/InferTypeOpInterface.h.inc" 261 262 namespace mlir { 263 namespace OpTrait { 264 265 template <typename ConcreteType> 266 class InferTypeOpAdaptor : public TraitBase<ConcreteType, InferTypeOpAdaptor> { 267 }; 268 269 template <typename ConcreteType> 270 class InferShapedTypeOpAdaptor 271 : public TraitBase<ConcreteType, InferShapedTypeOpAdaptor> {}; 272 273 /// Tensor type inference trait that constructs a tensor from the inferred 274 /// shape and elemental types. 275 /// Requires: Op implements InferShapedTypeOpInterface and InferTypeOpInterface. 276 /// Less strict is possible (e.g., implements inferReturnTypeComponents and 277 /// these always populates all element types and shapes or fails, but this 278 /// trait is currently only used where the interfaces are, so keep it 279 /// restricted for now). 280 template <typename ConcreteType> 281 class InferTensorType : public TraitBase<ConcreteType, InferTensorType> {}; 282 283 } // namespace OpTrait 284 } // namespace mlir 285 286 #endif // MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_ 287