# 'shape' Dialect

Description of operations & types within the Shape dialect as well as their
[usage](#different-stages-of-lowering-shape-dialect).

[include "Dialects/ShapeDialectOps.md"]

## Different stages of lowering Shape dialect

In this section we shall give a brief overview of the different uses of the
shape dialect and the lowering between these uses. Currently we have 3 worlds /
stages of lowering of shape functions:

1.  _Error monadic/error carrying/user specification_:
    This "input" form carries both the shape and the error state as values.
    Hence at this level all operations are pure operations producing and
    consuming values, where the values could represent an error.

2.  _Constrained_:
    This form uses a variant of explicit evidence passing to allow leveraging
    existing compiler infrastructure to preserve safety information during
    optimization.

3.  _Side-effecting/asserting_:
    This final lowered form is an imperative form with side-effecting ops
    (e.g., assert) for final codegen.
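
To make the three stages concrete before the matmul walkthrough, the sketch
below shows a single shape-equality constraint in each of the three forms. It
is illustrative only (not one of the in-tree examples) and follows the op
spellings used in the rest of this document:

```mlir
// 1. Error monadic: a pure op whose result carries either the shape or the
//    error.
func.func @stage1(%a: !shape.shape, %b: !shape.shape) -> !shape.shape {
  %s = shape.meet %a, %b, error="shapes must match" :
    !shape.shape, !shape.shape -> !shape.shape
  return %s : !shape.shape
}

// 2. Constrained: the constraint produces a witness that guards the
//    computation consuming it.
func.func @stage2(%a: !shape.shape, %b: !shape.shape) -> !shape.shape {
  %w = shape.cstr_eq %a, %b : !shape.witness
  %s = shape.assuming %w -> !shape.shape {
    %any = shape.any %a, %b : (!shape.shape, !shape.shape) -> !shape.shape
    shape.assuming_yield %any : !shape.shape
  }
  return %s : !shape.shape
}

// 3. Side-effecting: the constraint becomes an assertion ahead of the plain
//    computation on extent tensors.
func.func @stage3(%a: tensor<?xindex>, %b: tensor<?xindex>) -> tensor<?xindex> {
  %eq = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
  cf.assert %eq, "shapes must match"
  %s = shape.any %a, %b : (tensor<?xindex>, tensor<?xindex>) -> tensor<?xindex>
  return %s : tensor<?xindex>
}
```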

We are going to do a quick step through of the lowering using the example of
a matmul.

Starting from the shape function of matmul in the error monadic form
below[^wip_form1]:

```mlir
shape.function_library @shplib {

func.func @matmul(%lhs: !shape.value_shape, %rhs: !shape.value_shape) -> !shape.shape {
  %c1 = shape.const_size 1
  %c2 = shape.const_size 2
  // We could also allow rank etc. operations directly on value_shape; that
  // would make it nicer as an "input" language, but we keep it explicit inside
  // the IR and can instead provide helper methods in the front-end language.
  %lhs_shape = shape.shape_of %lhs : !shape.value_shape -> !shape.shape
  %rhs_shape = shape.shape_of %rhs : !shape.value_shape -> !shape.shape
  %lhs_rank = shape.rank %lhs_shape : !shape.shape -> !shape.size
  %rhs_rank = shape.rank %rhs_shape : !shape.shape -> !shape.size
  // This is not minimal, as one could ensure the ranks are the same below;
  // also a variadic meet would make it more concise.
  %r = "shape.meet"(%lhs_rank, %rhs_rank) : (!shape.size, !shape.size) -> !shape.size
  %rank = shape.meet %c2, %r, error="requires rank 2 operands" :
    !shape.size, !shape.size -> !shape.size
  %l0, %l1 = "shape.split_at"(%lhs_shape, %c1) :
    (!shape.shape, !shape.size) -> (!shape.shape, !shape.shape)
  %r0, %r1 = "shape.split_at"(%rhs_shape, %c1) :
    (!shape.shape, !shape.size) -> (!shape.shape, !shape.shape)
  %c = shape.meet %l1, %r0, error="inner dimensions required to match" :
    !shape.shape, !shape.shape -> !shape.shape
  %res = shape.concat %l0, %r1
  // Should have `shape.return %res requires %c, %rank` to enable returning
  // the constraints along with the result (see the notes below).
  return %res : !shape.shape
}

} mapping {
  foo.matmul = @matmul
}
```
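
The trailing `mapping` attribute associates an op with the shape function that
computes its result shape. With the mapping above, the result shape of a
(hypothetical) `foo.matmul` op such as the one below would be described by
invoking `@matmul` on its operands:

```mlir
// Hypothetical op whose result shape is described by @shplib's @matmul.
%res = "foo.matmul"(%a, %b) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
```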

*   We are using the default builtin func and return here. Preferably we'd use
    `shape_func` as a special function op that allows passing multiple results
    back that affect correct execution (e.g., serves as an error join).
    *   This would also mean one can't reify it inside a regular function
        without handling the shape.return - that is a feature here as these are
        more of a template.
    *   Currently we also have not marked `meet` as having no side effects to
        avoid DCE until we have `shape.return`, at which point computing the
        meet could be treated as purely computational, returning an error.
*   Meet represents a constraint that should hold, so it should not be used to
    see *if* something is equal. E.g., this means `meet` can't be used to
    represent

    ```
       either(meet(x, y), meet(y, z))
    ```

*   This could have been written more concisely as something like

    ```
      concat(lhs[0], rhs[1]) if rank(lhs) == 2 &&
        rank(rhs) == 2 && lhs[1] == rhs[0]
    ```

    but we are not focusing on the front-end proper here.

We are going to lower to the "most nested" form directly (see
[test](https://github.com/tensorflow/tensorflow/blob/64062b5c51e04e370df26551d247496787d3f5c2/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir#L3088)
for an example reification along with legalization). In the above this was in a
separate shape function library; here we would normally reify it as part of
lowering, but for simplicity we will show it as a standalone shape function.

```mlir
func.func @matmul_shape1(%lhs: tensor<*xf32>, %rhs: tensor<*xf32>) -> tensor<?xindex> {
  %c1 = shape.const_size 1
  %c2 = shape.const_size 2
  // We allow `shape.shape_of` to return either a `!shape.shape` or a
  // `tensor<?xindex>`; in the case where the input is a tensor, the most
  // refined type is a tensor of `index`, but that is not required.
  %lhs_shape = shape.shape_of %lhs : tensor<*xf32> -> !shape.shape
  %rhs_shape = shape.shape_of %rhs : tensor<*xf32> -> !shape.shape
  %lhs_rank = shape.rank %lhs_shape : !shape.shape -> !shape.size
  %rhs_rank = shape.rank %rhs_shape : !shape.shape -> !shape.size
  %w1 = shape.cstr_eq %lhs_rank, %rhs_rank : !shape.witness
  %res = shape.assuming %w1 -> tensor<?xindex> {
    %r = shape.any %lhs_rank, %rhs_rank : (!shape.size, !shape.size) -> !shape.size
    // An error message is still needed here; currently only cstr_require
    // supports one.
    %w2 = shape.cstr_eq %c2, %r, error="requires rank 2 operands"
    %res_1 = shape.assuming %w2 -> tensor<?xindex> {
      // Here the lowered
      //   %rank = shape.any %c2, %r : (!shape.size, !shape.size) -> !shape.size
      // is dead and so elided. But if `%rank` were actually consumed, then it
      // could have been folded into the `shape.any`.
      %l0, %l1 = "shape.split_at"(%lhs_shape, %c1) :
        (!shape.shape, !shape.size) -> (!shape.shape, !shape.shape)
      %r0, %r1 = "shape.split_at"(%rhs_shape, %c1) :
        (!shape.shape, !shape.size) -> (!shape.shape, !shape.shape)
      %c = shape.meet %l1, %r0, error="inner dimensions required to match" :
        !shape.shape, !shape.shape -> !shape.shape
      %res_2 = shape.concat %l0, %r1
      shape.assuming_yield %res_2
    }
    shape.assuming_yield %res_1
  }
  return %res : tensor<?xindex>
}
```

We can now hoist computation of the constraints where possible (which in the
case below is not much, as we need to verify the ranks before we can split):

```mlir
func.func @matmul_shape2(%lhs: tensor<*xf32>, %rhs: tensor<*xf32>) -> tensor<?xindex> {
  %c1 = shape.const_size 1
  %c2 = shape.const_size 2
  %lhs_shape = shape.shape_of %lhs : tensor<*xf32> -> tensor<?xindex>
  %rhs_shape = shape.shape_of %rhs : tensor<*xf32> -> tensor<?xindex>
  %lhs_rank = shape.rank %lhs_shape : tensor<?xindex> -> tensor<index>
  %rhs_rank = shape.rank %rhs_shape : tensor<?xindex> -> tensor<index>
  %w1 = shape.cstr_eq %c2, %lhs_rank, error="requires rank 2 operands"
  %w2 = shape.cstr_eq %c2, %rhs_rank, error="requires rank 2 operands"
  %w = shape.assuming_all %w1, %w2
  %res = shape.assuming %w -> tensor<?xindex> {
    %l0, %l1 = "shape.split_at"(%lhs_shape, %c1) :
      (tensor<?xindex>, !shape.size) -> (tensor<?xindex>, tensor<?xindex>)
    %r0, %r1 = "shape.split_at"(%rhs_shape, %c1) :
      (tensor<?xindex>, !shape.size) -> (tensor<?xindex>, tensor<?xindex>)
    %w3 = shape.cstr_eq %l1, %r0, error="inner dimensions required to match"
    %res_2 = shape.assuming %w3 -> tensor<?xindex> {
      %res_3 = shape.concat %l0, %r1
      shape.assuming_yield %res_3
    }
    shape.assuming_yield %res_2
  }
  return %res : tensor<?xindex>
}
```

The above form can now be lowered to the fully imperative form (see
[test](https://github.com/tensorflow/mlir-hlo/blob/af14e1ded33c3164d4418c5d234b5b346b6d017c/tests/rank-specialization.mlir#L22)
for an example).

```mlir
func.func @matmul_shape3(%lhs: tensor<*xf32>, %rhs: tensor<*xf32>) -> tensor<?xindex> {
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %lhs_shape = shape.shape_of %lhs : tensor<*xf32> -> tensor<?xindex>
  %rhs_shape = shape.shape_of %rhs : tensor<*xf32> -> tensor<?xindex>
  %lhs_rank = shape.rank %lhs_shape : tensor<?xindex> -> tensor<index>
  %rhs_rank = shape.rank %rhs_shape : tensor<?xindex> -> tensor<index>
  %w1 = shape.shape_eq %lhs_rank, %rhs_rank
  %w2 = shape.shape_eq %c2, %lhs_rank
  %w3 = arith.andi %w1, %w2 : i1
  cf.assert %w3, "requires rank 2 operands"
  %l0, %l1 = shape.split_at(%lhs_shape, %c1) : tensor<?xindex>
  %r0, %r1 = shape.split_at(%rhs_shape, %c1) : tensor<?xindex>
  %w4 = shape.shape_eq %l1, %r0
  cf.assert %w4, "inner dimensions required to match"
  %res = shape.concat %l0, %r1
  return %res : tensor<?xindex>
}
```

*   In this case form 3 is as easy to follow as, and closer to, form 1 (but
    only because no reordering was required). So it is a good question whether
    the frontend authoring language could be more similar to the imperative
    form (under discussion).
*   The form presented here is an intermediate form during a lowering pass. If
    used as input we would need to restrict the optimizations on it, as the
    `shape` dialect operations are no longer connected by producer-consumer
    relationships that enforce the guard checking.

The above could be further lowered by using `tensor.dim`, `tensor.from_elements`,
etc. (or one could even lower these by way of, say, the MHLO or TOSA dialects).
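
For instance, for a ranked operand (here a hypothetical `%operand` of type
`tensor<?x?xf32>`) the shape computation could be materialized with tensor ops
roughly as in the sketch below; this is illustrative and not the exact output
of the existing lowering passes:

```mlir
// Materializing the shape of a ranked tensor as a tensor<2xindex>.
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%d0 = tensor.dim %operand, %c0 : tensor<?x?xf32>
%d1 = tensor.dim %operand, %c1 : tensor<?x?xf32>
%shape = tensor.from_elements %d0, %d1 : tensor<2xindex>
```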

[^wip_form1]: This form is the least used inside the current workflows and needs more work. In particular, in the example we use `shape_func` whereas in the code we instead use the standard func, as form 1 isn't used explicitly yet.