# 'shape' Dialect

Description of operations & types within the Shape dialect as well as their
[usage](#different-stages-of-lowering-shape-dialect).

[include "Dialects/ShapeDialectOps.md"]

## Different stages of lowering Shape dialect

In this section we give a brief overview of the different uses of the shape
dialect and the lowering between these uses. Currently we have 3 worlds /
stages of lowering of shape functions:

1.  _Error monadic/error carrying/user specification_:
    This "input" form carries both the shape and whether it is in an error
    state as a value. Hence at this level all operations are pure operations
    producing and consuming values, where the values could represent an error.

2.  _Constrained_:
    This form uses a variant of explicit evidence passing to allow leveraging
    existing compiler infrastructure to preserve safety information during
    optimization.

3.  _Side-effecting/asserting_:
    This final lowered form is an imperative form with side-effecting ops
    (e.g., assert) for final codegen.

We will do a quick step-through of the lowering using the example of a matmul.

Starting from the shape function of matmul in the error monadic form
below[^wip_form1]:

```mlir
shape.function_library @shplib {

func.func @matmul(%lhs: !shape.value_shape, %rhs: !shape.value_shape) -> !shape.shape {
  %c1 = shape.const_size 1
  %c2 = shape.const_size 2
  // We could also allow rank etc. operations directly on value_shape too; that
  // would make it nicer as an "input" language, but we keep it explicit inside
  // the IR and could instead have helper methods in the front-end language.
  %lhs_shape = shape.shape_of %lhs : !shape.value_shape -> !shape.shape
  %rhs_shape = shape.shape_of %rhs : !shape.value_shape -> !shape.shape
  %lhs_rank = shape.rank %lhs_shape : !shape.shape -> !shape.size
  %rhs_rank = shape.rank %rhs_shape : !shape.shape -> !shape.size
  // This is not minimal, as one could ensure the ranks are the same below; a
  // variadic meet would also make it more concise.
  %r = "shape.meet"(%lhs_rank, %rhs_rank) : (!shape.size, !shape.size) -> !shape.size
  %rank = shape.meet %c2, %r, error="requires rank 2 operands" :
    !shape.size, !shape.size -> !shape.size
  %l0, %l1 = "shape.split_at"(%lhs_shape, %c1) :
    (!shape.shape, !shape.size) -> (!shape.shape, !shape.shape)
  %r0, %r1 = "shape.split_at"(%rhs_shape, %c1) :
    (!shape.shape, !shape.size) -> (!shape.shape, !shape.shape)
  %c = shape.meet %l1, %r0, error="inner dimensions required to match" :
    !shape.shape, !shape.shape -> !shape.shape
  %res = shape.concat %l0, %r1
  // Should have `shape.return %res requires %c, %rank` here to pass the
  // constraints back along with the result (see the discussion below).
  return %res : !shape.shape
}

} mapping {
  foo.matmul = @matmul
}
```

*   We are using the default `func` and `return` ops here. Preferably we'd use
    `shape_func` as a special function op that allows passing multiple results
    back that affect correct execution (e.g., it serves as an error join); see
    the sketch after this list.
    *   This would also mean one can't reify it inside a regular function
        without handling the `shape.return`; that is a feature here, as these
        are more of a template.
    *   Currently we also have not marked `meet` as having no side effects, to
        avoid DCE until we have `shape.return`, at which point computing the
        meet could be treated as a pure computation that returns an error.
*   Meet represents a constraint that should hold, so it should not be used to
    check *if* something is equal. E.g., this means `meet` can't be used to
    represent

    ```
    either(meet(x, y), meet(y, z))
    ```

*   This could have been written more concisely as something like

    ```
    concat(lhs[0], rhs[1]) if rank(lhs) == 2 &&
                              rank(rhs) == 2 && lhs[1] == rhs[0]
    ```

    but we are not focusing on the front end proper here.
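To make the `shape_func` point above concrete, here is a rough sketch of what
the tail of the shape function might look like with such a hypothetical
`shape_func` / `shape.return ... requires ...` pair (neither exists in the
dialect today; the syntax is purely illustrative):

```mlir
// Hypothetical sketch only: `shape_func` and `shape.return ... requires ...`
// are not ops in the dialect today. The intent is that the constraints %c and
// %rank are passed back alongside the result, serving as an error join, so
// that they stay live without relying on missing side-effect information.
shape_func @matmul(%lhs: !shape.value_shape, %rhs: !shape.value_shape) -> !shape.shape {
  ...
  %res = shape.concat %l0, %r1
  shape.return %res requires %c, %rank : !shape.shape
}
```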
We are going to lower directly to the most nested form (see
[test](https://github.com/tensorflow/tensorflow/blob/64062b5c51e04e370df26551d247496787d3f5c2/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir#L3088)
for an example of reification along with legalization). Above, the shape
function was in a separate shape function library, while here we would normally
reify it as part of lowering; for simplicity we show it as a standalone shape
function.

```mlir
func.func @matmul_shape1(%lhs: tensor<*xf32>, %rhs: tensor<*xf32>) -> tensor<?xindex> {
  %c1 = shape.const_size 1
  %c2 = shape.const_size 2
  // We allow `shape.shape_of` to return either a `!shape.shape` or a
  // `tensor<?xindex>` type; in the case where the input is a tensor the most
  // refined type is a tensor of `index`, but that is not required.
  %lhs_shape = shape.shape_of %lhs : tensor<*xf32> -> !shape.shape
  %rhs_shape = shape.shape_of %rhs : tensor<*xf32> -> !shape.shape
  %lhs_rank = shape.rank %lhs_shape : !shape.shape -> !shape.size
  %rhs_rank = shape.rank %rhs_shape : !shape.shape -> !shape.size
  %w1 = shape.cstr_eq %lhs_rank, %rhs_rank : !shape.witness
  %res = shape.assuming %w1 -> tensor<?xindex> {
    %r = shape.any %lhs_rank, %rhs_rank : (!shape.size, !shape.size) -> !shape.size
    // The error message needs an addition; currently it is only on cstr_require.
    %w2 = shape.cstr_eq %c2, %r, error="requires rank 2 operands"
    %res_1 = shape.assuming %w2 -> tensor<?xindex> {
      // Here the lowered
      //   %rank = shape.any %c2, %r : (!shape.size, !shape.size) -> !shape.size
      // is dead and so elided further. But if `%rank` was actually consumed,
      // then it could have been folded into the `shape.any`.
      %l0, %l1 = "shape.split_at"(%lhs_shape, %c1) :
        (!shape.shape, !shape.size) -> (!shape.shape, !shape.shape)
      %r0, %r1 = "shape.split_at"(%rhs_shape, %c1) :
        (!shape.shape, !shape.size) -> (!shape.shape, !shape.shape)
      %c = shape.meet %l1, %r0, error="inner dimensions required to match" :
        !shape.shape, !shape.shape -> !shape.shape
      %res_2 = shape.concat %l0, %r1
      shape.assuming_yield %res_2
    }
    shape.assuming_yield %res_1
  }
  return %res : tensor<?xindex>
}
```
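The key mechanism in this constrained form is that every guarded computation is
nested in a region that consumes a witness produced by a `shape.cstr_*` op, so
ordinary producer-consumer reasoning keeps the computation from being hoisted
above its check. As a minimal standalone illustration (not part of the matmul
walkthrough; `shape.cstr_broadcastable` and `shape.broadcast` are used here
only because they make for a compact example):

```mlir
func.func @guarded_broadcast(%a: tensor<?xindex>, %b: tensor<?xindex>) -> tensor<?xindex> {
  // The constraint op produces a witness rather than failing immediately.
  %w = shape.cstr_broadcastable %a, %b : tensor<?xindex>, tensor<?xindex>
  // The computation that relies on the constraint lives in a region guarded by
  // that witness, so it cannot be reordered above the check.
  %res = shape.assuming %w -> (tensor<?xindex>) {
    %bc = shape.broadcast %a, %b : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
    shape.assuming_yield %bc : tensor<?xindex>
  }
  return %res : tensor<?xindex>
}
```

This is the same pattern that `shape.assuming_all` builds on below when
multiple witnesses are combined into a single one.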
We can now hoist the computation of constraints where possible (which in the
case below is not much, as we need to verify the rank before we can split):

```mlir
func.func @matmul_shape2(%lhs: tensor<*xf32>, %rhs: tensor<*xf32>) -> tensor<?xindex> {
  %c1 = shape.const_size 1
  %c2 = shape.const_size 2
  %lhs_shape = shape.shape_of %lhs : tensor<*xf32> -> tensor<?xindex>
  %rhs_shape = shape.shape_of %rhs : tensor<*xf32> -> tensor<?xindex>
  %lhs_rank = shape.rank %lhs_shape : tensor<?xindex> -> tensor<index>
  %rhs_rank = shape.rank %rhs_shape : tensor<?xindex> -> tensor<index>
  %w1 = shape.cstr_eq %c2, %lhs_rank, error="requires rank 2 operands"
  %w2 = shape.cstr_eq %c2, %rhs_rank, error="requires rank 2 operands"
  %w = shape.assuming_all %w1, %w2
  %res = shape.assuming %w -> tensor<?xindex> {
    %l0, %l1 = "shape.split_at"(%lhs_shape, %c1) :
      (tensor<?xindex>, !shape.size) -> (tensor<?xindex>, tensor<?xindex>)
    %r0, %r1 = "shape.split_at"(%rhs_shape, %c1) :
      (tensor<?xindex>, !shape.size) -> (tensor<?xindex>, tensor<?xindex>)
    %w3 = shape.cstr_eq %l1, %r0, error="inner dimensions required to match"
    %res_2 = shape.assuming %w3 -> tensor<?xindex> {
      %res_1 = shape.concat %l0, %r1
      shape.assuming_yield %res_1
    }
    shape.assuming_yield %res_2
  }
  return %res : tensor<?xindex>
}
```

The above form can now be lowered to the fully imperative form (see
[test](https://github.com/tensorflow/mlir-hlo/blob/af14e1ded33c3164d4418c5d234b5b346b6d017c/tests/rank-specialization.mlir#L22)
for an example).

```mlir
func.func @matmul_shape3(%lhs: tensor<*xf32>, %rhs: tensor<*xf32>) -> tensor<?xindex> {
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %lhs_shape = shape.shape_of %lhs : tensor<*xf32> -> tensor<?xindex>
  %rhs_shape = shape.shape_of %rhs : tensor<*xf32> -> tensor<?xindex>
  %lhs_rank = shape.rank %lhs_shape : tensor<?xindex> -> tensor<index>
  %rhs_rank = shape.rank %rhs_shape : tensor<?xindex> -> tensor<index>
  %w1 = shape.shape_eq %lhs_rank, %rhs_rank
  %w2 = shape.shape_eq %c2, %lhs_rank
  %w3 = arith.andi %w1, %w2 : i1
  cf.assert %w3, "requires rank 2 operands"
  %l0, %l1 = shape.split_at(%lhs_shape, %c1) : tensor<?xindex>
  %r0, %r1 = shape.split_at(%rhs_shape, %c1) : tensor<?xindex>
  %w4 = shape.shape_eq %l1, %r0
  cf.assert %w4, "inner dimensions required to match"
  %res = shape.concat %l0, %r1
  return %res : tensor<?xindex>
}
```

*   In this case form 3 is as simple as, and quite close to, form 1 (but only
    because no reordering was required). So it is a fair question whether the
    front-end authoring language could be more similar to the imperative form
    (under discussion).
*   The form presented here is an intermediate form during a lowering pass. If
    used as input, we would need to restrict the optimizations on it, as the
    `shape` dialect operations are no longer connected by producer-consumer
    relationships that enforce the guard checks.

The above could be further lowered by using `tensor.dim`,
`tensor.from_elements`, etc. (or one could even lower these by way of, say, the
MHLO or TOSA dialects).
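For example, once the operands are known to be ranked, the piece that
materializes an operand's shape as an extent tensor might end up looking
roughly like the following (a sketch assuming a ranked `tensor<?x?xf32>`
operand; the function name is only for illustration, not a fixed lowering):

```mlir
func.func @shape_of_2d(%arg: tensor<?x?xf32>) -> tensor<2xindex> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  // Query the dynamic extents directly from the operand...
  %d0 = tensor.dim %arg, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %arg, %c1 : tensor<?x?xf32>
  // ...and materialize the shape as an extent tensor.
  %shape = tensor.from_elements %d0, %d1 : tensor<2xindex>
  return %shape : tensor<2xindex>
}
```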
[^wip_form1]: This form is the least used inside the current workflows and
    needs more work. In particular, the discussion refers to a `shape_func` op,
    whereas the code instead uses a standard func, as form 1 isn't used
    explicitly yet.