# Shape Inference

Shape inference as discussed here is considered a specific instance of type
inference for [ShapedType][ShapedType]. Type constraints run along (at least)
three axes: 1) elemental type, 2) rank (including static or dynamic), 3)
dimensions. While some operations have no compile-time fixed shape (e.g., the
output shape is dictated by data), we could still have some knowledge of
constraints/bounds in the system for that operation (e.g., the output of a
`tf.where` is at most the size of the input data). That is, there are
additional valuable constraints that could be captured even without full
knowledge of the shape.

Type inference is currently modelled executionally for operation creation
using the [`InferTypeOpInterface`][InferTypeOpInterface], while
`InferShapedTypeOpInterface` is used to implement the shape and element type
inference. The return type can often be deduced from the inferred return shape
and elemental type (queryable from `InferShapedTypeOpInterface`), and so type
inference for tensor types can be implemented with
`InferShapedTypeOpInterface`.
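For illustration, the following is a rough sketch of what implementing the
`InferTypeOpInterface` hook can look like for a hypothetical `MyAddOp` whose
result type matches that of its first operand. The exact hook signature has
evolved across MLIR revisions (e.g., a properties parameter was added at some
point), so this approximates one form rather than being a drop-in
implementation; see [`InferTypeOpInterface`][InferTypeOpInterface] for the
current definition.

```c++
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"

using namespace mlir;

// Hypothetical element-wise op implementing InferTypeOpInterface. The hook is
// static so that it can be queried before the op itself has been created.
// Signature approximated; consult InferTypeOpInterface.td for the exact form.
LogicalResult MyAddOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ValueRange operands, DictionaryAttr attributes,
    OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // The result of an element-wise add has the same shape and element type as
  // its (shape-compatible) operands.
  inferredReturnTypes.push_back(operands.front().getType());
  return success();
}
```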
[TOC]

## Shape functions

The C++ interfaces are the base mechanism whereby shape inference is queried
and executed, but they are not the intended way to specify shape constraints
in general.

Initially the shape inference will be declaratively specified using:

*   Constraints on the operands of an operation directly. For example,
    constraining the input type to be tensor/vector elements, or requiring
    that the elemental type be of a specific type (e.g., the output of
    computing the size of a value is of elemental type `i1`) or class (e.g.,
    float-like).
*   Constraints across operands and results of an operation.

    -   For example, specifying equality constraints on the type/constituents
        of a type (shape and elemental type) between operands and results
        (e.g., the output type of an add is the same as those of the input
        operands).

NOTE: The C++ shape functions are an intermediate step until the shape dialect
is more full-fledged, at which point the C++ functions should become the
exceptional case.

## Testing

Shape inference is currently tested alongside type inference by
`TestReturnTypeDriver` in the test dialect. This driver performs two checks:

1.  Verification that the specified return types match the inferred types.
    This explicit check will be removed and made part of Op verification
    instead.
2.  Testing the creation of Ops without explicitly specifying the return type
    in function `testCreateFunctions`, by creating new binary Ops (Op classes
    specified in `TestReturnTypeDriver`) using 1) all operands to
    `testCreateFunctions` as both operands, and 2) combinations of input
    operands of the function.

## Shape dialect

This section details the shape type inference dialect (`shape`). The initial
focus will be on shape functions that can be used in both the runtime and the
compiler (for construction of ops, refinement of shapes, and reification of
dynamic allocations) for dialects including TF, TFLite, XLA and the tensor
compute dialect under discussion.

This will focus on the shape functions (e.g., determining the rank and
dimensions of the output shape). As shown in the shaped container type, shape
will be one of three components, the others being elemental type and attribute
(which is currently left open, with the intention of supporting extensions
such as layouts or bounded shapes at a later point). This allows these to be
decoupled:

*   Not all the information is needed for all analyses;
*   Not all shape functions need to provide all the information (e.g., one
    could define a base class function that only populates the element type
    but composes with the others);
*   It allows reusing the constraints between, say, the Tensor and Memref
    representations of an operation.

An argument could be made that these are metadata functions rather than shape
functions, with some considering shape and elemental type to be different
things and others considering both part of shape. But `shape function` is IMHO
descriptive, and metadata can span too large a range of potential uses/values.

### Requirements

The requirements for the shape inference functions are determined by the
requirements of shape inference, but we believe the requirements below still
allow freedom to consider different shape inference approaches, and so we do
not impose a particular shape inference approach here.

#### Shape inference functions

*   **Expressiveness**: shape functions need to support programs where tensors
    have shapes that are not known statically (for example, `tensor<16x?xf32>`
    or `tensor<*xf32>`);
*   **Shape error detection**: many operations will have constraints on their
    operands. If the constraints are not satisfied, or it cannot be determined
    statically whether they are satisfied, then a runtime check/assertion
    could be generated.

    *   This also aligns with the requirement that the shape function
        description should be usable by both the compiler and the runtime.
    *   Shape errors should be easy to understand, at least as regards which
        constraint of the operation is violated. This also requires that
        shape function error messages be configurable by the author of the
        shape function (e.g., the author would be able to state the semantic
        constraint invalidated rather than the low-level check that failed).
    *   Static analysis may be used to eliminate runtime checks that are
        guaranteed to pass.
        *   Ideally all would eventually (see section
            [Inlining shape checking](#inline)) be elided.
    *   Only errors that are guaranteed to occur at runtime are reported
        statically. If an error is only possible (rather than guaranteed),
        then a runtime assertion is used to fail and produce an error message
        naming the invariant violated.

*   Shape functions usable by compiler and runtime (see the sketch following
    this list).

    *   This does not mean the exact same C++ function, but rather that the
        description should be consumable by either.
    *   The shape function description should not be constrained by either the
        runtime's or the compiler's type system to handle types only used for
        analysis. That is, these two type systems differ and both should be
        supported, but the intersection of the two should not be required. As
        a particular example, if a compiler only wants to differentiate exact
        shapes vs dynamic shapes, then it need not consider a more generic
        shape lattice even though the shape description supports it.

*   Declarative (e.g., analyzable at compile time, possible to generate
    different versions for different use cases)

    *   This may not strictly be a requirement, but rather a way to handle the
        former: a declarative specification could be reused by both while
        avoiding the need to map to or from a third representation, given
        that these two systems have (and will have) different types.

*   Shape inference functions are expressible at runtime

    *   A user can define a shape function for a new operation dynamically at
        runtime; this allows vendors to describe an operation and its shape
        function dynamically.

        This requirement is on the wishlist.

*   Doesn't require graph-wide shape information (e.g., requires only local
    information)

    *   Shape functions should be cheap to invoke on each kernel launch.
    *   A shape function can be dictated by its arguments (operands,
        attributes and regions) only (e.g., it could be constructed & invoked
        with the same operands as the corresponding operation).
    *   Shape information that needs higher-level/graph information should use
        richer types (e.g., `TensorList<F32>`);
    *   The function should be invocable before/while constructing an op
        (e.g., it can't rely on the op being constructed).

*   Shape functions should be pure functions.

*   Should support operations whose type is only known dynamically (e.g., a
    `read_from_file` op)

    *   Without needing to invoke the op (e.g., without reading a file once to
        determine the shape and then again to actually consume the output of
        the file).

*   The shape function operation dialect should be interoperable with
    non-shape-function dialect operations.

    *   There may be a common set of operations that satisfy most uses (e.g.,
        merge, equal_type, arithmetic expressions, slice, concat, pattern
        matching on attributes such as padding, etc.) that will be discovered
        and could cover a large percentage of the use cases. Among these
        there will be some that carry extra semantic information which could
        be used for symbolic constraints (e.g., checking equality of two
        dimensions resulting in setting an equality constraint) and
        higher-order interpretation for constraint solving.

        It is therefore beneficial (but not required) to reuse operations,
        especially as, for statically known shapes, arbitrary arithmetic
        computations could still be performed. This means that the
        computations performed statically may or may not be supported by an
        arbitrary solver, but would still be allowed.

*   The shape function should be expandable such that symbolic equality and
    (say) upper-bound constraints could be represented and propagated by
    shape inference.

    *   E.g., the shape functions may contain more information that is only
        useful when used from shape inference;

*   Shape functions are allowed to fail and report an error. The error report
    should include the location of the operation that failed with, where
    possible, a user-actionable error message.

    *   These failures could become inlined and become runtime failures with
        runtime values and error messages.
    *   Reporting errors should be optional. E.g., the same function may be
        used to query validity without reporting an error.
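To make some of these requirements concrete, below is a minimal standalone
sketch of a shape function for an element-wise binary op. All names here are
illustrative (not part of any MLIR or runtime API): the function is a pure
function of local operand information only, it is not tied to a particular
compiler or runtime type system, and error reporting is optional so the same
function can be used purely to query validity.

```c++
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Illustrative convention: kDynamic marks a dimension unknown at compile time.
constexpr int64_t kDynamic = -1;

// Pure function of local information; pass `error == nullptr` to only query
// validity without producing an error message.
std::optional<std::vector<int64_t>>
inferElementwiseShape(const std::vector<int64_t> &lhs,
                      const std::vector<int64_t> &rhs,
                      std::string *error = nullptr) {
  if (lhs.size() != rhs.size()) {
    if (error)
      *error = "operands must have the same rank";
    return std::nullopt;
  }
  std::vector<int64_t> result;
  for (std::size_t i = 0; i < lhs.size(); ++i) {
    // Two static sizes must agree; a dynamic size is compatible with any
    // size. A comparison that cannot be resolved here is exactly what would
    // become an inlined runtime assertion (see "Inline shape inference
    // checks" below).
    if (lhs[i] != kDynamic && rhs[i] != kDynamic && lhs[i] != rhs[i]) {
      if (error)
        *error = "dimension " + std::to_string(i) + " of the operands differs";
      return std::nullopt;
    }
    result.push_back(lhs[i] == kDynamic ? rhs[i] : lhs[i]);
  }
  return result;
}
```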
#### Non-goals

1.  The shape dialect is an IR representation and not a programming language;
    *   While the functions should be readable, they don't carry the
        conveniences of a programming language. Deciding how people write
        these things, e.g., a mini DSL, a C++ API that generates them,
        extracting them programmatically from `SetShapeFn` calls, etc., is
        still TBD.
1.  Describing the shape inference approach that will use the shape functions;
    *   The goal is that the shape functions, and the constraints one could
        obtain from them, are general enough to be useful for various
        analyses. But whether we follow a very simple approach (e.g., only
        fully static information is used for the output shape, unranked for
        everything else) or a very advanced one (e.g., expression trees of
        symbolic constants) can be evaluated independently of this proposal
        and with concrete benefit analysis.
1.  Describing the approach whereby error messages will be generated;
    *   While the shape functions will be able to emit errors optionally, it
        will be possible to dictate when they emit an error. This enables
        deciding whether or which error to emit: there have been proposals in
        the literature that the iteration order for shape inference affects
        the quality of the error message produced, and the shape functions do
        not mandate a particular order.
1.  Flow-sensitive shape functions;
    *   To enable scalable/cheap shape inference, the shape functions do not
        intend to provide flow-sensitive information. This facility could
        potentially be built as part of some higher-order analysis that
        reuses the shape functions and the constraints they produce.
1.  Making all static functions usable for dynamic/unknown shapes;
    *   More involved computations can be performed with statically known
        shapes than can be sensibly analyzed with unknown/symbolic variables.

### Discussion

#### Inline shape inference checks {#inline}

Shape functions should be lowerable to runtime checks for validity: verify as
much as possible statically, but enable generating instructions to compute the
shape dynamically and/or fall back to runtime checks for attributes not
verifiable at compile time. The inserted checks should ideally verify only
what could not have been verified statically.

These inlined calls could interfere with optimization patterns/passes (e.g.,
shape inference should not insert constructs that interfere with optimization
patterns) and so could be delayed until later (with another round of
optimizations, constant folding, CSE, etc., that should remove redundant
runtime operations).
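As a hypothetical illustration (the names and error plumbing here are invented
for this sketch, not prescribed by the proposal), the residual runtime check
for a matmul of `[a,b]` and `[c,d]` whose inner dimensions could not be
compared statically amounts to a single comparison:

```c++
#include <cstdint>
#include <stdexcept>
#include <string>

// Residual runtime check for matmul([a,b], [c,d]) when the inner dimensions
// could not be compared at compile time. If b == c had been proven
// statically, this check would be elided entirely.
void checkMatmulInnerDims(int64_t b, int64_t c) {
  if (b != c)
    throw std::runtime_error("matmul shape mismatch: inner dimensions " +
                             std::to_string(b) + " and " + std::to_string(c) +
                             " must be equal");
}
```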
### Possibly Asked Questions

#### What about ODS specifications of operations?

In ODS we have been recording the constraints for the operands & attributes of
an operation. Where these are sufficient to constrain the output shape (e.g.,
`SameOperandsAndResultType` or broadcastable), we should generate the shape
function from them. Where they are not, an explicit shape function should be
specified (spelling TBD, but currently the MLIR textual form is being
considered as the serialization approach).
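For the first case, a constraint such as `SameOperandsAndResultType` already
pins down the shape function completely. The sketch below is illustrative only
(it is not the code that would actually be generated): every result type is
simply the shared operand type, so nothing op-specific has to be written by
hand.

```c++
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Illustrative only, not what mlir-tblgen emits: the shape function implied
// by a SameOperandsAndResultType-style constraint. Every result type is the
// (shared) operand type.
LogicalResult inferFromSameOperandsAndResultType(
    ValueRange operands, unsigned numResults,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands.empty())
    return failure();
  inferredReturnTypes.assign(numResults, operands.front().getType());
  return success();
}
```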
#### Why not extract the shape function from the reference implementation?

This could be done in the future! The extracted shape function would use the
shape inference dialect, so we are starting there. Especially for operations
described in a structured way, one could autogenerate the shape function.

#### How/in what language will the shape functions be authored?

TBD. We are open to many approaches and suggestions; this proposal prioritizes
the IR that whatever authoring language is chosen would produce.

#### What shape inference approach is being suggested here?

None. There are multiple different shape inference approaches that we could
layer on top of these, from the most basic (always return unranked), to the
more useful (return a fixed shape for constant inputs/arguments), to the more
advanced (create logical conjunctions of algebraic statements between symbolic
named values).

### Open points

1.  Should shape functions that produce dynamic outputs given all statically
    shaped inputs be marked specially? E.g., read from file.

TODO: Add examples here.

## WIP/Future considerations

Shape functions are determined by attributes and could be arbitrarily
complicated, with a wide range of specification possibilities. Equality
relationships are common (e.g., the elemental type of the output matches the
primitive type of the inputs, or both inputs have exactly the same type
[primitive type and shape]) and so should be easy to specify. Algebraic
relationships are also common (e.g., the concat of two `[n,m]` matrices along
axis 0 is an `[n+n, m]` matrix), while some ops only have defined shapes under
certain cases (e.g., matrix multiplication of `[a,b]` and `[c,d]` is only
defined if `b == c`).

Instead of specifying an additional mechanism to specify a shape transfer
function, the reference implementation of the operation will be used to derive
the shape function. The reference implementation is general and can support
the arbitrary computations needed to specify output shapes.

[InferTypeOpInterface]: https://github.com/llvm/llvm-project/tree/main/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
[ShapedType]: https://github.com/llvm/llvm-project/tree/main/mlir/include/mlir/IR/BuiltinTypes.h