# Shape Inference

Shape inference as discussed here is considered a specific instance of type
inference for [ShapedType][ShapedType]. Type constraints are along (at least)
three axes: 1) elemental type, 2) rank (including static or dynamic), and 3)
dimensions. While some operations have no compile-time fixed shape (e.g., the
output shape is dictated by data), we could still have some knowledge of
constraints/bounds in the system for that operation (e.g., the output of a
`tf.where` is at most the size of the input data). That is, there are additional
valuable constraints that could be captured even without full knowledge of the
shape.

Type inference is currently modelled executionally (i.e., computed when an
operation is created) using the [`InferTypeOpInterface`][InferTypeOpInterface],
while `InferShapedTypeOpInterface` is used to implement the shape and element
type inference. The return type can often be derived from the inferred return
shape and elemental type (queryable from `InferShapedTypeOpInterface`), and so
type inference for tensor types can be implemented with
`InferShapedTypeOpInterface`.

[TOC]

## Shape functions

The C++ interfaces are the base mechanism whereby shape inference is queried and
executed, but they are not the intended way to specify shape constraints in
general.
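
The kind of imperative shape function these interfaces execute can be shown outside MLIR with the following minimal Python sketch. It is not the MLIR C++ API. `Dim`, `ShapedInfo`, `infer_add` and `infer_where` are hypothetical names. The sketch also assumes that a `tf.where`-like op returns the coordinates of the true elements as an `[num_true, rank]` tensor of `i64`. It covers two cases:

*   an equality constraint (add);
*   a data-dependent dimension that still carries an upper bound (where).

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class Dim:
    """One dimension: `size` is None when dynamic; `max_size` optionally bounds it."""
    size: Optional[int] = None
    max_size: Optional[int] = None


@dataclass(frozen=True)
class ShapedInfo:
    """Hypothetical stand-in for a shaped type along the three axes: element
    type, rank (`dims` is None when unranked) and the dimensions themselves."""
    element_type: str
    dims: Optional[Tuple[Dim, ...]]

    def num_elements(self) -> Optional[int]:
        # Only computable when the type is ranked and every dimension is static.
        if self.dims is None or any(d.size is None for d in self.dims):
            return None
        total = 1
        for d in self.dims:
            total *= d.size
        return total


def infer_add(lhs: ShapedInfo, rhs: ShapedInfo) -> ShapedInfo:
    """Equality constraint: the result type of an add is its operands' type.
    This sketch demands exact equality; it does not merge dynamic dimensions."""
    if lhs != rhs:
        raise ValueError(f"add operands must have the same type: {lhs} vs {rhs}")
    return lhs


def infer_where(cond: ShapedInfo) -> ShapedInfo:
    """A `tf.where`-like op returning the coordinates of true elements as
    [num_true, rank]. num_true is data dependent, hence dynamic, but it can
    never exceed the number of input elements when that is statically known."""
    if cond.dims is None:
        # Unranked input: neither the count nor the coordinate width is known.
        return ShapedInfo("i64", (Dim(), Dim()))
    num_true = Dim(size=None, max_size=cond.num_elements())
    return ShapedInfo("i64", (num_true, Dim(size=len(cond.dims))))
```

Take a `4x3` condition as an example. `infer_where` returns a dynamic first dimension bounded by 12 and a static second dimension of 2. The output shape is not fixed at compile time, but it is not unconstrained either.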

Initially, shape inference will be specified declaratively using:

*   Constraints on the operands of an operation directly. For example,
    constraining the input type to be tensor/vector elements or that the
    elemental type be of a specific type (e.g., output of computing the size
    of a value is of elemental type `i1`) or class (e.g., float-like).
*   Constraints across operands and results of an operation.

    *   For example, specifying equality constraints on type/constituents of a
        type (shape and elemental type) between operands and results (e.g., the
        output type of an add is the same as that of the input operands).

NOTE: The C++ shape functions are an intermediate step until the shape dialect
is more full-fledged, at which point the C++ functions should become the
exceptional case.

## Testing

Shape inference is currently tested alongside type inference by
`TestReturnTypeDriver` in the test dialect. This driver performs two checks:

1.  Verification that the specified return types match the inferred types. This
    explicit check will be removed and made part of Op verification instead.
2.  Testing the creation of Ops without specifying the return type explicitly in
    function `testCreateFunctions` by creating new binary Ops (Op classes
    specified in `TestReturnTypeDriver`) using 1) each operand of
    `testCreateFunctions` as both operands, and 2) combinations of the
    function's input operands.

## Shape dialect

This section details the shape type inference dialect (`shape`). The initial
focus will be on shape functions that can be used by both the runtime and the
compiler (for construction of ops/refinement of shapes, and reification of
dynamic allocations, for dialects including TF, TFLite, XLA & the tensor compute
dialect under discussion).

This will focus on the shape functions (e.g., determining the rank and
dimensions of the output shape). As shown in the shaped container type, shape
will be one of 3 components, the others being elemental type and attribute
(which is currently left open with the intention of supporting extensions such
as layouts or bounded shapes at a later point).
This allows these to be decoupled:

*   Not all the information is needed for all analyses;
*   Not all shape functions need to provide all the information (e.g., one could
    define a base class function that only populates element type but composes
    with the others);
*   It allows reusing the constraints between, say, Tensor and Memref
    representations of an operation.

An argument could be made that these are metadata functions instead of shape
functions, with some considering shape and elemental types different and some
considering them both as part of shape. But `shape function` is IMHO
descriptive, and metadata can span too large a range of potential uses/values.

### Requirements

The requirements for the shape inference functions are determined by the
requirements of shape inference, but we believe the requirements below still
allow freedom to consider different shape inference approaches, and so we do
not impose a particular shape inference approach here.

#### Shape inference functions

*   **Expressiveness**: Shape functions need to support programs where tensors
    have shapes that are not known statically (for example, `tensor<16x?xf32>`
    or `tensor<*xf32>`);
*   **Shape error detection**: Many operations will have constraints on their
    operands. If the constraints are not satisfied, or cannot be statically
    determined to be satisfied, then a runtime check/assertion could be
    generated.

    *   This also aligns with the requirement that the shape function
        description should be usable by both the compiler and runtime.
    *   Shape errors should be easy to understand, at least indicating which
        constraint of the operation is violated. This also requires that shape
        function error messages be configurable by the author of the shape
        function (e.g., the author would be able to give the semantic
        constraint invalidated rather than the low-level check that failed).
    *   The static analysis may be used to eliminate runtime checks that are
        guaranteed to pass.
        *   Ideally all would eventually (see section
            [Inline shape inference checks](#inline)) be elided.
    *   Only report errors that are guaranteed to occur at runtime. If an error
        is only possible (rather than guaranteed), then we use a runtime
        assertion to fail and produce an error message stating the violated
        invariant.

*   Shape functions are usable by both the compiler and the runtime.

    *   This does not mean the exact same C++ function, but rather the
        description should be consumable by either.
    *   The shape function description should not be constrained by either the
        runtime's or the compiler's type system to handle types only used for
        analysis. That is, these two type systems differ and both should be
        supported, but the intersection of the two should not be required. As a
        particular example, if a compiler only wants to differentiate exact
        shapes vs dynamic shapes, then it need not consider a more generic shape
        lattice even though the shape description supports it.

*   Declarative (e.g., analyzable at compile time, possible to generate
    different versions for different use cases)

    *   This may not strictly be a requirement, but a way to handle the former:
        a declarative specification could be reused by both while avoiding a
        need to map to or from a 3rd representation, given that these two
        systems have, and will continue to have, different types.

*   Shape inference functions are expressible at runtime

    *   Users can define a shape function for a new operation dynamically at
        runtime; this allows vendors to describe an operation and shape
        function dynamically.

    This requirement is on the wishlist.

*   Does not require graph-wide shape information (i.e., only requires local
    information)

    *   Shape functions should be cheap to invoke on each kernel launch.
    *   The shape function can be dictated by arguments (operands, attributes
        and regions) only (e.g., the same operands the corresponding operation
        could be constructed & invoked with).
    *   Shape information that needs higher-level/graph information should use
        richer types (e.g., `TensorList<F32>`).
    *   The function should be invocable before/while constructing an op (e.g.,
        it can't rely on the op being constructed).

*   Shape functions should be pure functions.

*   Should support functions whose type is only known dynamically (e.g.,
    `read_from_file` op)

    *   Without needing to invoke the op (e.g., reading a file once to
        determine the shape and then again to actually consume the file's
        output).

*   The shape function operation dialect should be interoperable with
    operations from other (non-shape-function) dialects.

    *   There may be a common set of operations that satisfy most uses (e.g.,
        merge, equal_type, arithmetic expressions, slice, concat, pattern
        matching on attributes such as padding etc.) that will be discovered
        and could cover a large percentage of the use cases. Among these, some
        will carry extra semantic info that could be used for symbolic
        constraints (e.g., checking equality of two dimensions resulting in
        setting an equality constraint) and higher-order interpretation for
        constraint solving.

        It is therefore beneficial (but not required) to reuse operations,
        especially as, for statically known shapes, arbitrary arithmetic
        computations could still be performed. This means that the
        computations performed statically may or may not be supported by an
        arbitrary solver, but would still be allowed.

*   The shape function should be extensible such that symbolic equality and
    upper bound constraints (say) could be represented and may be propagated
    by shape inference.

    *   E.g., the shape functions may contain more information that is only
        useful when used from shape inference.

*   Shape functions are allowed to fail and report an error. The error reporting
    should report the location of the operation that failed with, where
    possible, a user-actionable error message.

    *   These failures could become inlined and become runtime failures with
        runtime values and error messages.
    *   Reporting errors should be optional. E.g., the same function may be
        used to query validity without reporting an error.
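
Several of these requirements can be illustrated in one function:

*   shape functions use only local information and are pure;
*   error reporting is optional;
*   only guaranteed errors are reported statically;
*   undecidable constraints are deferred to runtime assertions.

The following Python sketch does this for a hypothetical rank-2 matmul. `matmul_shape`, the `emit_error` callback and the string form of the runtime checks are assumptions for illustration, not part of the shape dialect.

```python
from typing import Callable, List, Optional, Tuple

Shape = Tuple[Optional[int], ...]  # None marks a dynamic dimension


def matmul_shape(
    lhs: Shape,
    rhs: Shape,
    emit_error: Optional[Callable[[str], None]] = None,
) -> Tuple[Optional[Shape], List[str]]:
    """Local shape function for a hypothetical rank-2 matmul.

    Uses only the operand shapes and is pure apart from the optional error
    callback. Returns (result shape or None on failure, runtime checks still
    required). Errors are reported only if `emit_error` is given, so the same
    function can simply be used to query validity.
    """
    if len(lhs) != 2 or len(rhs) != 2:
        if emit_error:
            emit_error(f"matmul expects rank-2 operands, got {len(lhs)} and {len(rhs)}")
        return None, []
    (a, b), (c, d) = lhs, rhs
    checks: List[str] = []
    if b is not None and c is not None:
        if b != c:
            # Statically known mismatch: the failure is guaranteed, so report it.
            if emit_error:
                emit_error(f"contracting dimensions must match: {b} vs {c}")
            return None, []
    else:
        # Undecidable statically: defer to a runtime assertion instead of failing.
        checks.append("lhs.dim(1) == rhs.dim(0)")
    return (a, d), checks
```

For example, `matmul_shape((2, None), (3, 4))` returns `((2, 4), ["lhs.dim(1) == rhs.dim(0)"])`. The result shape is still known, and the single constraint that cannot be decided statically is left for a runtime assertion. Without the callback, `matmul_shape((2, 3), (5, 4))` just returns `(None, [])` and reports nothing.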

#### Non-goals

1.  The shape dialect is an IR representation and not a programming language;
    *   While the functions should be readable, they do not carry the
        conveniences of a programming language. Deciding how people write these
        things, e.g. a mini DSL, a C++ API that generates them, extracting them
        programmatically from `SetShapeFn` calls, etc., is still TBD.
1.  Describe the shape inference approach that will use the shape functions;
    *   The goal is that the shape functions and the constraints one could
        obtain from them are general enough that they would be useful for
        various analyses. But where we land on the spectrum from very simple
        (e.g., only fully static information is used for shape output, unranked
        for everything else) to very advanced (e.g., expression trees of
        symbolic constants) can be evaluated independently of this proposal and
        with concrete benefit analysis.
1.  Describe the approach whereby error messages will be generated;
    *   While the shape functions will be able to emit errors optionally, it
        will be possible to dictate when they emit an error. This enables
        deciding whether or which error to emit: there have been proposals in
        the literature that the iteration order for shape inference affects the
        quality of the error message produced, and the shape functions do not
        mandate a particular order.
1.  Flow-sensitive shape functions;
    *   To enable scalable/cheap shape inference, the shape functions do not
        intend to provide flow-sensitive information. This facility could
        potentially be built as part of some higher-order analysis that reuses
        the shape functions/constraints derived from the shape functions.
1.  All static functions are usable for dynamic/unknown shapes;
    *   More involved computations can be performed with statically known shapes
        than what can be sensibly analyzed with unknown/symbolic variables.

### Discussion

#### Inline shape inference checks {#inline}

Shape functions should be lowerable to runtime checks for validity. E.g., verify
as much as possible statically, but enable generating instructions to compute
the shape dynamically and/or falling back to runtime checks for attributes not
verifiable at compile time. The inserted checks should ideally only check what
could not have been verified statically.

These inlined calls could interfere with optimization patterns/passes (e.g.,
shape inference should not insert constructs that interfere with optimization
patterns), and so their insertion could be delayed until later (with another
round of optimizations, constant folding, CSE, etc., that should remove
redundant runtime operations).
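
The idea of emitting runtime checks only for what could not be verified statically can be sketched as follows. This is a hypothetical Python illustration, not the shape dialect's actual lowering. In `lower_equality_checks`, integers stand for static sizes and strings for named dynamic values. Its deduplication step stands in for the later CSE round mentioned above.

```python
from typing import List, Optional, Tuple, Union

DimValue = Union[int, str]  # int: static size; str: name of a dynamic value


def lower_equality_checks(
    constraints: List[Tuple[DimValue, DimValue]],
) -> Tuple[List[Tuple[DimValue, DimValue]], Optional[str]]:
    """Split dimension-equality constraints by what is known statically.

    Returns (residual runtime checks, error message or None). Constraints that
    trivially hold are dropped, a statically false one is reported as an
    error, and only the undecidable remainder is kept as runtime checks.
    """
    residual: List[Tuple[DimValue, DimValue]] = []
    seen = set()
    for lhs, rhs in constraints:
        if lhs == rhs:
            continue  # same static size or same dynamic value: always holds
        if isinstance(lhs, int) and isinstance(rhs, int):
            return [], f"shape constraint violated: {lhs} != {rhs}"
        # Unordered key so that `a == b` and `b == a` count as one check,
        # mimicking the redundancy a later CSE round would remove.
        key = frozenset((lhs, rhs))
        if key not in seen:
            seen.add(key)
            residual.append((lhs, rhs))
    return residual, None
```

For the constraints `[(3, 3), ("n", 4), (4, "n")]`, only `("n", 4)` survives as a runtime check. The first constraint is verified statically and the third duplicates the second.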

### Possibly Asked Questions

#### What about ODS specifications of operations?

In ODS, we have been recording the constraints for the operands & attributes of
an operation. Where these are sufficient to constrain the output shape (e.g.,
`SameOperandsAndResultType` or broadcastable), we should generate the shape
function from those. Where not, an explicit shape function should be specified
(spelling TBD, but we are currently considering using the MLIR textual form as
the serialization approach).

#### Why not extract the shape function from the reference implementation?

This could be done in the future! The extracted shape function would use the
shape inference dialect, so we are starting there. Especially for operations
described in a structured way, one could autogenerate the shape function.

#### How/in what language will the shape functions be authored?

TBD. We are open to many approaches and suggestions. The priority of this
proposal is the IR, whatever language is used to produce it.

#### What shape inference approach is being suggested here?

None. There are multiple different shape inference approaches that we could
layer on top of these, ranging from the most basic (always return unranked), to
the more useful (return a fixed shape for constant inputs/arguments), to the
more advanced (create logical conjunctions of algebraic statements between
symbolic named values).

### Open points

1.  Should shape functions that produce dynamic outputs given all statically
    shaped inputs be marked specially? E.g., read from file.

TODO: Add examples here.

## WIP/Future considerations

Shape functions are determined by attributes and could be arbitrarily
complicated, with a wide range of specification possibilities. Equality
relationships are common (e.g., the elemental type of the output matches the
primitive type of the inputs, both inputs have exactly the same type [primitive
type and shape]) and so these should be easy to specify. Algebraic relationships
would also be common (e.g., a concat of two `[n,m]` matrices along axis 0 is an
`[n+n, m]` matrix), while some ops only have defined shapes in certain cases
(e.g., matrix multiplication of `[a,b]` and `[c,d]` is only defined if
`b == c`).

Instead of adding a separate mechanism to specify a shape transfer function, the
reference implementation of the operation will be used to derive the shape
function. The reference implementation is general and can support the arbitrary
computations needed to specify output shapes.

[InferTypeOpInterface]: https://github.com/llvm/llvm-project/tree/main/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
[ShapedType]: https://github.com/llvm/llvm-project/tree/main/mlir/include/mlir/IR/BuiltinTypes.h
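
The concat example above combines both kinds of relationship. The concatenated axis follows an algebraic rule, and the remaining axes follow an equality rule. The following is a hypothetical Python sketch of such a shape function over symbolic dimensions. Its conventions are assumptions for illustration, not how the shape dialect or a reference implementation represents shapes:

*   symbolic sizes are plain strings;
*   sums are left unsimplified (`n+n` is not rewritten to `2*n`);
*   `SymDim`, `add_dims` and `concat_shape` are hypothetical names.

```python
from typing import List, Sequence, Tuple, Union

SymDim = Union[int, str]  # int: static size; str: symbolic size such as "n"


def add_dims(x: SymDim, y: SymDim) -> SymDim:
    """Sum two dimensions, folding to an int when both are static.
    Symbolic sums are kept as unsimplified strings (e.g. "n+n")."""
    if isinstance(x, int) and isinstance(y, int):
        return x + y
    return f"{x}+{y}"


def concat_shape(shapes: Sequence[Tuple[SymDim, ...]], axis: int) -> Tuple[SymDim, ...]:
    """Concat: sizes along `axis` add up (algebraic relationship), while all
    other dimensions must be equal across operands (equality relationship)."""
    if not shapes:
        raise ValueError("concat needs at least one operand")
    rank = len(shapes[0])
    if not 0 <= axis < rank:
        raise ValueError(f"axis {axis} is out of range for rank {rank}")
    result: List[SymDim] = list(shapes[0])
    for shape in shapes[1:]:
        if len(shape) != rank:
            raise ValueError("concat operands must have the same rank")
        for i in range(rank):
            if i == axis:
                result[i] = add_dims(result[i], shape[i])
            elif result[i] != shape[i]:
                raise ValueError(f"dimension {i} differs: {result[i]} vs {shape[i]}")
    return tuple(result)
```

`concat_shape([("n", "m"), ("n", "m")], axis=0)` gives `("n+n", "m")`, matching the `[n+n, m]` example. With static sizes, the sum folds to a number, e.g. `(6, 3)` for `(2, 3)` and `(4, 3)`.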