# Shape Inference

Shape inference as discussed here is considered a specific instance of type
inference for [ShapedType][ShapedType]. Type constraints are along (at least)
three axes: 1) elemental type, 2) rank (including static or dynamic), 3)
dimensions. While some operations have no compile-time fixed shape (e.g., the
output shape is dictated by data), we could still have some knowledge of
constraints/bounds in the system for that operation (e.g., the output of a
`tf.where` is at most the size of the input data). That is, there are additional
valuable constraints that could be captured even without full knowledge of the
shape.
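
As a rough illustration of capturing a bound without knowing the exact shape,
the sketch below models a `where`-like operation in plain Python. It assumes the
single-argument form of `tf.where`, whose result lists the coordinates of the
selected elements. `BoundedDim` and `where_result_shape` are hypothetical names
introduced only for this example; they are not part of MLIR or TensorFlow.

```python
from dataclasses import dataclass
from math import prod
from typing import Tuple, Union


@dataclass(frozen=True)
class BoundedDim:
    """Hypothetical dimension with unknown extent but a known upper bound."""
    upper_bound: int


def where_result_shape(
    input_shape: Tuple[int, ...]
) -> Tuple[Union[int, BoundedDim], int]:
    """Constraint on the result of a `where`-like op over a static input shape.

    The result holds one row of coordinates per selected element. The row
    count depends on the data, but it cannot exceed the number of input
    elements; each row has one entry per input dimension.
    """
    return (BoundedDim(prod(input_shape)), len(input_shape))


# A 2x3 input yields at most 6 rows of 2 coordinates each.
print(where_result_shape((2, 3)))
```

Even though the first dimension stays unknown until runtime, a consumer of this
constraint could, for example, size a buffer from the upper bound.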

Type inference is currently modelled executionally for operation creation using
the [`InferTypeOpInterface`][InferTypeOpInterface], while
`InferShapedTypeOpInterface` is used to implement the shape and element type
inference. The return type can often be derived from the inferred return shape
and elemental type (queryable from `InferShapedTypeOpInterface`) and so type
inference for tensor types can be implemented with `InferShapedTypeOpInterface`.

[TOC]

## Shape functions

The C++ interfaces are the base mechanism whereby shape inference is queried and
executed, but not the intended way to specify shape constraints in general.

Initially the shape inference will be declaratively specified using:

*   Constraints on the operands of an operation directly. For example,
    constraining the input type to be tensor/vector elements or that the
    elemental type be of a specific type (e.g., output of computing the size
    of a value is of elemental type `i1`) or class (e.g., float-like).
*   Constraints across operands and results of an operation.

    - For example, specifying equality constraints on type/constituents of a
      type (shape and elemental type) between operands and results (e.g., the
      output type of an add is the same as that of the input operands).
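
An equality constraint between operands and results can be sketched in a few
lines of plain Python. This is only an illustration of the idea, not the ODS or
C++ mechanism; `TensorType` and `infer_same_operands_and_result_type` are
hypothetical names.

```python
from typing import NamedTuple, Tuple


class TensorType(NamedTuple):
    """Hypothetical stand-in for a statically shaped tensor type."""
    shape: Tuple[int, ...]
    element_type: str


def infer_same_operands_and_result_type(*operands: TensorType) -> TensorType:
    """Infer the result type of an op whose operands and result must all agree."""
    if not operands:
        raise ValueError("at least one operand is required")
    first = operands[0]
    for index, operand in enumerate(operands[1:], start=1):
        # Both constituents (shape and element type) take part in the equality.
        if operand != first:
            raise ValueError(
                f"operand {index} has type {operand}, expected {first}")
    return first


lhs = TensorType((2, 3), "f32")
rhs = TensorType((2, 3), "f32")
print(infer_same_operands_and_result_type(lhs, rhs))
```

Because the constraint is a plain equality, the same description serves both to
verify an existing op and to infer the result type when building a new one.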

NOTE: The C++ shape functions are an intermediate step until the shape dialect
is more full-fledged, at which point the C++ functions should become the
exceptional case.

## Testing

Shape inference is currently tested alongside type inference by
`TestReturnTypeDriver` in the test dialect. This driver performs two checks:

1.  Verification that the return types specified match the inferred types. This
    explicit check will be removed and made part of Op verification instead.
2.  Test the creation of Ops without specifying the return type explicitly in
    function `testCreateFunctions` by creating new binary Ops (Op classes
    specified in `TestReturnTypeDriver`) using 1) all operands to
    `testCreateFunctions` as both operands, and 2) using combinations of input
    operands of the function.

## Shape dialect

This section details the shape type inference dialect (`shape`). The initial
focus will be on shape functions that can be used by both the runtime and the
compiler (for construction of ops/refinement of shapes, reification of dynamic
allocations for dialects including TF, TFLite, XLA & the tensor compute dialect
under discussion).

This will focus on the shape functions (e.g., determine the rank and dimensions
of the output shape). As shown in the shaped container type, shape will be one
of three components, the others being elemental type and attribute (which is
currently left open with the intention of supporting extensions such as layouts
or bounded shapes at a later point). This allows for decoupling of these:

*   Not all the information is needed for all analysis;
*   Not all shape functions need to provide all the information (e.g., one could
    define a base class function that only populates element type but composes
    with the others);
*   It allows reusing the constraints between, say, Tensor and Memref
    representation of an operation;

An argument could be made that these are metadata functions instead of shape
functions, with some considering shape and elemental types different and some
considering them both as part of shape. But `shape function` is IMHO descriptive
and metadata can span too large a range of potential uses/values.

### Requirements

The requirements for the shape inference functions are determined by the
requirements of shape inference, but we believe the requirements below still
allow freedom to consider different shape inference approaches and so we do not
impose a particular shape inference approach here.

#### Shape inference functions

*   **Expressiveness** shape functions need to support programs where tensors
    have shapes that are not known statically (for example, `tensor<16x?xf32>`
    or `tensor<*xf32>`);
*   **Shape error detection** Many operations will have constraints on their
    operands. If the constraints are not satisfied, or it cannot be determined
    statically whether they are satisfied, then a runtime check/assertion could
    be generated.

    *   This also aligns with the requirement that the shape function description
        should be usable by both the compiler and runtime.
    *   Shape errors should be easy to understand, at least in terms of which
        constraint of the operation is violated. This also requires that shape
        function error messages should be configurable by the author of the
        shape function (e.g., the author would be able to give the semantic
        constraint invalidated rather than the low-level check that failed).
    *   The static analysis may be used to eliminate run-time checks that are
        guaranteed to pass.
        *   Ideally all would eventually (see section
            [Inline shape inference checks](#inline)) be elided.
    *   Only report errors which are guaranteed to occur at runtime. If an
        error is only possible (rather than guaranteed) then we use a runtime
        assertion to fail and produce an error message with the invariant
        violated.

*   Shape functions usable by compiler and runtime.

    *   This does not mean the exact same C++ function, but rather the
        description should be consumable by either.
    *   Shape function description should not be constrained by either runtime
        or compiler's type system to handle types only used for analysis. That
        is, these two type systems differ and both should be supported, but the
        intersection of the two should not be required. As a particular example,
        if a compiler only wants to differentiate exact shapes vs dynamic
        shapes, then it need not consider a more generic shape lattice even
        though the shape description supports it.

*   Declarative (e.g., analyzable at compile time, possible to generate
    different versions for different use cases)

    *   This may not strictly be a requirement, but a way to handle the former:
        a declarative specification could be reused by both while avoiding a
        need to map to or from a third representation given these two systems
        have, and will continue to have, different types.

*   Shape inference functions are expressible at runtime

    *   Users can define a shape function for a new operation dynamically at
        runtime; this allows vendors to describe an operation and its shape
        function dynamically.

        This requirement is on the wishlist.

*   Doesn't require graph-wide shape information (e.g., only require local
    information)

    *   Shape functions should be cheap to invoke on each kernel launch.
    *   Shape function can be dictated by arguments (operands, attributes and
        regions) only (e.g., same operands as the corresponding operation could
        be constructed & invoked with).
    *   Shape information that needs higher-level/graph information should use
        richer types (e.g., `TensorList<F32>`);
    *   The function should be invocable before/while constructing an op (e.g.,
        can't rely on the op being constructed).

*   Shape functions should be pure functions.

*   Should support functions whose type is only known dynamically (e.g.,
    `read_from_file` op)

    *   Without needing to invoke the op (e.g., without reading a file once to
        determine the shape and then again afterwards to actually consume the
        output of the file).

*   The shape function operation dialect should be interoperable with non-shape
    function dialect operations.

    *   There may be a common set of operations that satisfy most uses (e.g.,
        merge, equal_type, arithmetic expressions, slice, concat, pattern
        matching on attributes such as padding, etc.) that will be discovered
        and could cover a large percentage of the use cases. Among these there
        will be some which carry extra semantic info that could be used for
        symbolic constraints (e.g., checking equality of two dimensions
        resulting in setting an equality constraint) and higher-order
        interpretation for constraint solving.

        It is therefore beneficial (but not required) to reuse operations,
        especially as for statically known shapes, arbitrary arithmetic
        computations could still be performed. This means that the computations
        performed statically may or may not be supported by an arbitrary solver,
        but would still be allowed.

*   The shape function should be expandable such that symbolic equality and
    upper bound constraints (say) could be represented and may be propagated by
    shape inference.

    *   E.g., the shape functions may contain more information that is only
        useful when used from shape inference;

*   Shape functions are allowed to fail and report an error. The error reporting
    should report the location of the operation that failed with, where
    possible, a user actionable error message.

    *   These failures could become inlined and become runtime failures with
        runtime values and error messages.
    *   Reporting errors should be optional. E.g., the same function
        may be used to query validity without reporting an error.
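
Two of the requirements above (report only errors that are guaranteed to occur,
and make error reporting optional) can be illustrated with a small pure
function. The sketch below is plain Python under stated assumptions: `DYNAMIC`,
`Verdict` and `check_dims_equal` are hypothetical names, and a dynamic dimension
is modelled as `None`.

```python
from enum import Enum
from typing import Callable, Optional

DYNAMIC = None  # hypothetical marker for a dimension unknown at compile time


class Verdict(Enum):
    OK = "ok"                                    # constraint holds statically
    ERROR = "error"                              # guaranteed to fail at runtime
    NEEDS_RUNTIME_CHECK = "needs_runtime_check"  # only possibly violated


def check_dims_equal(
    lhs: Optional[int],
    rhs: Optional[int],
    what: str,
    emit_error: Optional[Callable[[str], None]] = None,
) -> Verdict:
    """Classify the constraint `lhs == rhs` on two dimensions.

    The result depends only on the arguments. Passing `emit_error=None`
    turns the same function into a silent validity query.
    """
    if lhs is DYNAMIC or rhs is DYNAMIC:
        # The violation is only possible, so nothing is reported statically.
        return Verdict.NEEDS_RUNTIME_CHECK
    if lhs != rhs:
        if emit_error is not None:
            # The author-chosen `what` names the semantic constraint violated.
            emit_error(f"{what}: {lhs} != {rhs}")
        return Verdict.ERROR
    return Verdict.OK


messages = []
print(check_dims_equal(3, 4, "matmul inner dimensions must match", messages.append))
print(check_dims_equal(3, DYNAMIC, "matmul inner dimensions must match"))
```

A caller that receives `NEEDS_RUNTIME_CHECK` would be the one to materialize a
runtime assertion; the shape function itself stays free of side effects.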

#### Non-goals

1.  The shape dialect is an IR representation and not a programming language;
    *   While the functions should be readable, the dialect doesn't carry the
        conveniences of a programming language. Deciding how people write these
        things, e.g. a mini DSL, a C++ API that generates them, extracting them
        programmatically from `SetShapeFn` calls, etc., is still TBD.
1.  Describe the shape inference approach that will use the shape functions;
    *   The goal is that the shape functions and the constraints one could
        obtain from them are general enough that they would be useful for
        various analyses. But whether we follow a very simple approach (e.g.,
        only fully static information is used for shape output, unranked for
        everything else) or a very advanced one (e.g., expression trees of
        symbolic constants) can be evaluated independently of this proposal and
        with concrete benefit analysis.
1.  Describe the approach whereby error messages will be generated;
    *   While the shape functions will be able to emit errors optionally, it
        will be possible to dictate when they emit an error. This enables
        deciding whether or which error to emit: there have been proposals in
        the literature that the iteration order for shape inference affects the
        quality of the error message produced, and the shape functions do not
        mandate that.
1.  Flow sensitive shape functions;
    *   To enable scalable/cheap shape inference, the shape functions do not
        intend to provide flow sensitive information. This facility could
        potentially be built as part of some higher order analysis that reuses
        the shape functions/constraints due to the shape functions.
1.  All static functions are usable for dynamic/unknown shapes;
    *   More involved computations can be performed with statically known shapes
        than what can be sensibly analyzed with unknown/symbolic variables.

### Discussion

#### Inline shape inference checks {#inline}

Shape functions should be lowerable to runtime checks for validity. E.g., verify
as much as possible statically, but enable generating instructions to compute the
shape dynamically and/or falling back to runtime checks for attributes not
verifiable at compile time. The checks inserted should ideally only check that
which could not have been verified statically.

These inlined calls could interfere with optimization patterns/passes (e.g.,
shape inference should not insert constructs that interfere with optimization
patterns) and so could be delayed until later (with another round of
optimizations, constant folding, CSE, etc., that should remove redundant runtime
operations).
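
The idea of inserting only the checks that could not be discharged statically
can be sketched as follows. This is a plain Python illustration, not how the
shape dialect lowers anything: `residual_runtime_checks` is a hypothetical
helper, dynamic dimensions are modelled as `None`, and the emitted strings
merely stand in for whatever runtime assertion a lowering would generate.

```python
from typing import List, Optional, Sequence

DYNAMIC = None  # hypothetical marker for a dimension unknown at compile time


def residual_runtime_checks(
    lhs: Sequence[Optional[int]], rhs: Sequence[Optional[int]]
) -> List[str]:
    """Return the equality checks that must remain at runtime.

    Models an op that requires both operands to have identical shapes.
    Statically provable dimensions produce no check, statically violated
    ones raise immediately, and only dimensions involving a dynamic extent
    are left for runtime.
    """
    if len(lhs) != len(rhs):
        raise ValueError(f"rank mismatch: {len(lhs)} vs {len(rhs)}")
    checks = []
    for index, (l, r) in enumerate(zip(lhs, rhs)):
        if l is DYNAMIC or r is DYNAMIC:
            checks.append(f"dim(lhs, {index}) == dim(rhs, {index})")
        elif l != r:
            raise ValueError(f"dimension {index} mismatch: {l} vs {r}")
    return checks


# Only the middle dimension needs a runtime check here.
print(residual_runtime_checks((2, DYNAMIC, 4), (2, 8, 4)))
```

If later passes refine the dynamic dimension to a constant, rerunning the same
analysis would return an empty list, which matches the goal of eliding checks
that are guaranteed to pass.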

### Possibly Asked Questions

#### What about ODS specifications of operations?

In ODS we have been recording the constraints for the operands & attributes of
an operation. Where these are sufficient to constrain the output shape (e.g.,
`SameOperandsAndResultType` or broadcastable) we should generate the shape
function from those. Where not, an explicit shape function should be specified
(spelling TBD but currently considering using the MLIR textual form as
serialization approach).
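
To make the "broadcastable" case concrete, the sketch below shows the kind of
shape function such a constraint implies. It is a plain Python model that
assumes NumPy-style broadcasting (trailing dimensions are aligned and an extent
of 1 stretches); `broadcast_shapes` is a hypothetical helper, dynamic dimensions
are modelled as `None`, and nothing here is generated by ODS.

```python
from itertools import zip_longest
from typing import Optional, Tuple

Shape = Tuple[Optional[int], ...]  # None models a dynamic dimension


def broadcast_shapes(lhs: Shape, rhs: Shape) -> Shape:
    """Compute the result shape of broadcasting `lhs` against `rhs`."""
    result = []
    # Walk both shapes from the trailing dimension; missing leading
    # dimensions behave like an extent of 1.
    for l, r in zip_longest(reversed(lhs), reversed(rhs), fillvalue=1):
        if l == 1:
            result.append(r)
        elif r == 1:
            result.append(l)
        elif l is None or r is None:
            # A static extent wins over a dynamic one; compatibility of the
            # dynamic side can only be confirmed at runtime.
            result.append(r if l is None else l)
        elif l == r:
            result.append(l)
        else:
            raise ValueError(f"cannot broadcast extent {l} against {r}")
    return tuple(reversed(result))


print(broadcast_shapes((2, 1, 3), (4, 3)))
```

Since the rule depends only on the operand shapes, it is the sort of constraint
from which a shape function could plausibly be derived automatically.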

#### Why not extract the shape function from the reference implementation?

This could be done in the future! The extracted shape function would use the
shape inference dialect, so we are starting there. Especially for operations
described in a structured way, one could autogenerate the shape function.

#### How/in what language will the shape functions be authored?

TBD. We are open to many approaches and suggestions; settling the IR produced by
whatever language is chosen is the priority of this proposal.

#### What shape inference approach is being suggested here?

None. There are multiple different shape inference approaches that we could
layer on top of these, from the most basic (always return unranked), to the more
useful (return fixed shape for constant inputs/arguments), to the more advanced
(create logical conjunctions of algebraic statements between symbolic named
values).
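
The first two approaches can be sketched as interchangeable drivers over the
same shape function. This is a plain Python illustration only: `UNRANKED`,
`transpose_shape`, `infer_always_unranked` and `infer_static_only` are
hypothetical names, and the symbolic approach is left out because it would need
a constraint representation that this document does not define.

```python
from typing import Callable, Optional, Tuple

Shape = Tuple[Optional[int], ...]  # None models a dynamic dimension
UNRANKED = None                    # hypothetical marker for an unranked result


def transpose_shape(shape: Shape) -> Shape:
    """Example shape function: reverse the dimensions."""
    return tuple(reversed(shape))


def infer_always_unranked(shape_fn: Callable, *operands) -> Optional[Shape]:
    """Most basic approach: never consult the shape function."""
    return UNRANKED


def infer_static_only(shape_fn: Callable, *operands) -> Optional[Shape]:
    """Use the shape function only when every operand is fully static."""
    for shape in operands:
        if shape is UNRANKED or any(dim is None for dim in shape):
            return UNRANKED
    return shape_fn(*operands)


print(infer_always_unranked(transpose_shape, (2, 3)))
print(infer_static_only(transpose_shape, (2, 3)))
print(infer_static_only(transpose_shape, (None, 3)))
```

The shape function is identical in both cases; only the driver that decides how
much of its information to use changes.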

### Open points

1.  Should shape functions that produce dynamic outputs given all statically
    shaped inputs be marked specially? E.g., read from file.

TODO: Add examples here.

## WIP/Future considerations

Shape functions are determined by attributes and could be arbitrarily
complicated with a wide range of specification possibilities. Equality
relationships are common (e.g., the elemental type of the output matches the
primitive type of the inputs, both inputs have exactly the same type [primitive
type and shape]) and so these should be easy to specify. Algebraic relationships
would also be common (e.g., a concat of `[n,m]` and `[n,m]` matrix along axis 0
is `[n+n, m]` matrix), while some ops only have defined shapes under certain
cases (e.g., matrix multiplication of `[a,b]` and `[c,d]` is only defined if `b
== c`).
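
The two examples in the paragraph above, an algebraic relationship for concat
and a conditionally defined shape for matrix multiplication, can be written
down directly for statically known shapes. The Python sketch below is only an
illustration; `concat_shape` and `matmul_shape` are hypothetical helper names.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def concat_shape(lhs: Shape, rhs: Shape, axis: int) -> Shape:
    """Shape of concatenating `lhs` and `rhs` along `axis`."""
    if len(lhs) != len(rhs):
        raise ValueError(f"rank mismatch: {len(lhs)} vs {len(rhs)}")
    if not 0 <= axis < len(lhs):
        raise ValueError(f"axis {axis} out of range for rank {len(lhs)}")
    for index, (l, r) in enumerate(zip(lhs, rhs)):
        # Every dimension except the concatenation axis must agree.
        if index != axis and l != r:
            raise ValueError(f"dimension {index} mismatch: {l} vs {r}")
    result = list(lhs)
    result[axis] = lhs[axis] + rhs[axis]  # the algebraic part: n + n
    return tuple(result)


def matmul_shape(lhs: Shape, rhs: Shape) -> Shape:
    """Shape of multiplying an [a, b] matrix by a [c, d] matrix."""
    (a, b), (c, d) = lhs, rhs
    if b != c:
        # The result shape is only defined when the inner dimensions match.
        raise ValueError(f"inner dimensions differ: {b} vs {c}")
    return (a, d)


print(concat_shape((2, 3), (2, 3), axis=0))
print(matmul_shape((2, 3), (3, 5)))
```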

Instead of specifying an additional mechanism to specify a shape transfer
function, the reference implementation of the operation will be used to derive
the shape function. The reference implementation is general and can support the
arbitrary computations needed to specify output shapes.

[InferTypeOpInterface]: https://github.com/llvm/llvm-project/tree/main/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
[ShapedType]: https://github.com/llvm/llvm-project/tree/main/mlir/include/mlir/IR/BuiltinTypes.h