# The `Broadcastable` Trait

[TOC]

## Description

The `Broadcastable` trait enforces the following properties on an operation:

- The operation has at least one input operand.

- The operation has exactly one result.

- All input operands and the result are of type `tensor` or `vector`.

- A shape inference mechanism can compute the result shape solely from the input operand shapes.

- Input operands have broadcast-compatible shapes, according to the verification rules presented below.

- The operation's result shape is compatible with, though not necessarily identical to, the shape inferred from its input operands, according to the verification rules presented below.

## Dimension inference

Given an operation with two input operands, the size of dimension `i` of its result can be inferred from dimension `i` of the operands according to the table below. Here, `dim0` and `dim1` represent dimension `i` of the input operands in an interchangeable order, while `inferredDim` represents the inferred size for dimension `i` of the operation result. Dimensions are classified into three categories: dynamic ("?"), static equal to 1 ("1"), and static greater than 1 (">1").

| `dim0` | `dim1` | `inferredDim` | Notes |
| -------- | -------- | ------------- | ----- |
| ? | ? | ? | If `RuntimeSize(dim0)` is 1, dimension `dim0` is broadcast to `RuntimeSize(dim1)`. If `RuntimeSize(dim1)` is 1, dimension `dim1` is broadcast to `RuntimeSize(dim0)`. The operation has undefined behavior if both runtime sizes are greater than 1 and not equal. |
| ? | 1 | ? | Dimension `dim1` is broadcast to `RuntimeSize(dim0)`. |
| ? | >1 | `dim1` | If `RuntimeSize(dim0)` is 1, `dim0` is broadcast to `dim1`. The operation has undefined behavior if `RuntimeSize(dim0)` is greater than 1 and not equal to `dim1`. |
| 1 | 1 | 1 | |
| 1 | >1 | `dim1` | Dimension `dim0` is broadcast to `dim1`. |
| >1 | >1 | `dim0` | The operation verifier produces a compile-time error if `dim0` != `dim1`. |

The following pseudo-function is a formal representation of the dimension inference process:

```python
InferDim(dim0, dim1):
    switch (dim0, dim1):
        case (?, ?):
        case (?, 1):
        case (1, 1):
        case (>1, ?):
        case (>1, 1):
            return dim0
        case (?, >1):
        case (1, ?):
        case (1, >1):
            return dim1
        case (>1, >1):
            ERROR_IF(dim0 != dim1)
            return dim0
```
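
The pseudo-function above can be cross-checked with the following minimal Python sketch, which mirrors its case table directly. The sketch makes some assumptions that are not part of the trait. A static size is a Python `int`, a dynamic size (`?`) is `None`, and every static size other than 1 falls into the `>1` category. The names `category`, `infer_dim`, `TAKE_DIM0`, and `TAKE_DIM1` are hypothetical, and `ERROR_IF` is modeled as raising `ValueError`.

```python
from typing import Optional

# Hypothetical encoding: int for a static size, None for a dynamic size ("?").
Dim = Optional[int]

# Category pairs from the InferDim pseudo-function, keyed as (dim0, dim1).
TAKE_DIM0 = {("?", "?"), ("?", "1"), ("1", "1"), (">1", "?"), (">1", "1")}
TAKE_DIM1 = {("?", ">1"), ("1", "?"), ("1", ">1")}


def category(dim: Dim) -> str:
    """Classify a dimension as dynamic ("?"), static 1 ("1"), or static >1 (">1")."""
    if dim is None:
        return "?"
    # Assumption: static sizes are positive, so anything other than 1 is ">1".
    return "1" if dim == 1 else ">1"


def infer_dim(dim0: Dim, dim1: Dim) -> Dim:
    """Infer one result dimension from a pair of operand dimensions."""
    key = (category(dim0), category(dim1))
    if key in TAKE_DIM0:
        return dim0
    if key in TAKE_DIM1:
        return dim1
    # Only (>1, >1) remains: mismatched static sizes are a compile-time error.
    if dim0 != dim1:
        raise ValueError(f"dimensions {dim0} and {dim1} are not broadcast-compatible")
    return dim0
```

For example, `infer_dim(1, 4)` returns `4`. `infer_dim(None, 1)` returns `None`, so the result dimension stays dynamic. `infer_dim(3, 2)` raises `ValueError`.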

## Shape inference

The shape inference process begins by correcting rank differences between input operands. A shape is expanded by prepending dimensions of size 1 until the desired rank is reached, as shown here:

```python
ExpandRank(shape, rank):
    while len(shape) < rank:
        shape.prepend(1)
```
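
A Python equivalent of `ExpandRank` could look like the sketch below. It assumes a static size is an `int` and a dynamic size is `None`. Unlike the pseudocode, which prepends in place, this hypothetical `expand_rank` returns a new list and leaves its argument untouched.

```python
from typing import List, Optional

Dim = Optional[int]  # hypothetical encoding: None stands for a dynamic size ("?")


def expand_rank(shape: List[Dim], rank: int) -> List[Dim]:
    """Return `shape` left-padded with size-1 dimensions until it has `rank` dims."""
    # If len(shape) >= rank, the padding list is empty and the shape is unchanged.
    return [1] * (rank - len(shape)) + list(shape)
```

For instance, `expand_rank([4], 3)` yields `[1, 1, 4]`. `expand_rank([None, 3], 2)` returns `[None, 3]` unchanged.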

Given the shapes of two ranked input operands, the result's shape is inferred by equalizing input ranks and inferring individual dimensions, as shown here:

```python
InferShape(shape0, shape1):

    # Equalize ranks
    rank = max(GetRank(shape0), GetRank(shape1))
    ExpandRank(shape0, rank)
    ExpandRank(shape1, rank)

    # Infer shape
    inferredShape = []
    for (dim0, dim1) in zip(shape0, shape1):
        inferredDim = InferDim(dim0, dim1)
        inferredShape.append(inferredDim)
    return inferredShape
```
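
Two-operand shape inference can be sketched in Python as follows, again assuming `int` for static sizes and `None` for dynamic sizes. To keep the block self-contained, `infer_dim` is written in a compact form intended to be equivalent to the `InferDim` case table. All names are hypothetical.

```python
from typing import List, Optional

Dim = Optional[int]  # hypothetical encoding: None stands for a dynamic size ("?")


def infer_dim(dim0: Dim, dim1: Dim) -> Dim:
    # (>1, >1): static sizes other than 1 must agree.
    if dim0 is not None and dim1 is not None and dim0 != 1 and dim1 != 1:
        if dim0 != dim1:
            raise ValueError(f"dimensions {dim0} and {dim1} are not broadcast-compatible")
        return dim0
    if dim0 == 1:  # (1, ?), (1, 1), (1, >1)
        return dim1
    if dim1 == 1:  # (?, 1), (>1, 1)
        return dim0
    # (?, ?), (?, >1), (>1, ?): keep the static size if there is one.
    return dim0 if dim0 is not None else dim1


def infer_shape(shape0: List[Dim], shape1: List[Dim]) -> List[Dim]:
    # Equalize ranks by prepending size-1 dimensions to the shorter shape.
    rank = max(len(shape0), len(shape1))
    shape0 = [1] * (rank - len(shape0)) + list(shape0)
    shape1 = [1] * (rank - len(shape1)) + list(shape1)
    # Infer each result dimension independently.
    return [infer_dim(d0, d1) for d0, d1 in zip(shape0, shape1)]
```

With this sketch, `infer_shape([4], [2, 3, 4])` gives `[2, 3, 4]`. `infer_shape([None, 1], [3])` gives `[None, 3]`.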

The result shape for an operation with an arbitrary number of input operands is then inferred by discarding unranked operands, applying shape inference on the first ranked operand pair, and updating the inferred shape with each additional ranked operand. If the operation has no ranked operands, the result shape cannot be inferred. If the operation has exactly one ranked operand, its shape is directly provided as the inferred result shape. Formally:

```python
InferResultShape(op):

    # Filter ranked operands
    rankedOperands = filter(op.operands, IsRanked)
    if len(rankedOperands) == 0:
        return None

    # Infer result shape
    inferredShape = GetShape(rankedOperands[0])
    for operand in rankedOperands[1:]:
        inferredShape = InferShape(inferredShape, GetShape(operand))
    return inferredShape
```
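
The `InferResultShape` pseudo-function can be sketched in Python as shown here. Dimensions use the same assumed encoding: `int` for static sizes and `None` for dynamic sizes. In addition, each operand is assumed to be described only by its shape, with `None` in place of a shape for an unranked operand. Element types and the `tensor`/`vector` distinction are not modeled, and all names are hypothetical.

```python
from typing import List, Optional, Sequence

Dim = Optional[int]          # None stands for a dynamic size ("?")
Shape = Optional[List[Dim]]  # None stands for an unranked operand


def infer_dim(dim0: Dim, dim1: Dim) -> Dim:
    # Compact form of the InferDim case table.
    if dim0 is not None and dim1 is not None and dim0 != 1 and dim1 != 1:
        if dim0 != dim1:
            raise ValueError(f"dimensions {dim0} and {dim1} are not broadcast-compatible")
        return dim0
    if dim0 == 1:
        return dim1
    if dim1 == 1:
        return dim0
    return dim0 if dim0 is not None else dim1


def infer_shape(shape0: List[Dim], shape1: List[Dim]) -> List[Dim]:
    rank = max(len(shape0), len(shape1))
    shape0 = [1] * (rank - len(shape0)) + list(shape0)
    shape1 = [1] * (rank - len(shape1)) + list(shape1)
    return [infer_dim(d0, d1) for d0, d1 in zip(shape0, shape1)]


def infer_result_shape(operand_shapes: Sequence[Shape]) -> Shape:
    # Discard unranked operands; with none left, the result shape is unknown.
    ranked = [shape for shape in operand_shapes if shape is not None]
    if not ranked:
        return None
    # Start from the first ranked shape and fold in the rest pairwise.
    inferred = list(ranked[0])
    for shape in ranked[1:]:
        inferred = infer_shape(inferred, shape)
    return inferred
```

For example, `infer_result_shape([None, [4], [2, 3, 1]])` returns `[2, 3, 4]`. `infer_result_shape([None, None])` returns `None`.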

## Verification

The legality of an operation with the `Broadcastable` trait is verified by first running the shape inference process. If shape inference fails, the input operands are not broadcast-compatible, and verification fails. If shape inference succeeds, verification continues.

If either the result is unranked or all input operands are unranked, no further verification steps are needed, and the process ends here successfully. Otherwise, both the result and at least one input operand are ranked, and verification continues by checking that the previously inferred shape and the result have the same rank.

Once the ranks are known to match, each dimension of the inferred shape is compared with the corresponding dimension of the actual result shape according to the following table:

| `inferredDim` | `actualDim` | Verification outcome |
| ------------- | ----------- | -------------------- |
| ? | ? | **OK** |
| ? | static | **OK** <br> If the runtime size of the result dimension is not equal to `actualDim`, the behavior is undefined. While unusual, this implicit dynamic-to-static cast is convenient in certain scenarios, such as an intermediate state of a shape inference pass. Ultimately, a static dimension in the result implies that all input dimension sizes are also known at compile time, so the input dimensions may, and preferably should, be made static as well. |
| static | ? | **OK** <br> The actual result dimension may be dynamic even when a static size can be inferred at compile time. The programmer may choose to relax the specificity of the result dimension for forward compatibility of the result type. |
| static | static | **OK if equal** <br> When both the inferred and actual dimensions are static, they must be set to the same size. |

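The verification table reduces to a single rule: a pair of dimensions can only conflict when both are static. The following minimal Python sketch implements that rule together with the preceding rank check. It assumes `None` encodes a dynamic size, and both function names are hypothetical.

```python
from typing import List, Optional

Dim = Optional[int]  # None stands for a dynamic size ("?")


def result_dim_ok(inferred_dim: Dim, actual_dim: Dim) -> bool:
    # Any dynamic side is accepted; two static sizes must be equal.
    if inferred_dim is None or actual_dim is None:
        return True
    return inferred_dim == actual_dim


def result_shape_ok(inferred: List[Dim], actual: List[Dim]) -> bool:
    # Ranks must match before dimensions are compared pairwise.
    if len(inferred) != len(actual):
        return False
    return all(result_dim_ok(i, a) for i, a in zip(inferred, actual))
```

Under this rule, `result_shape_ok([None], [4])` and `result_shape_ok([4], [None])` are both `True`. `result_shape_ok([2], [4])` and `result_shape_ok([3], [1, 3])` are both `False`.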

The full verification process can be formally specified as follows:

```python
Verify(op):

    # Run shape inference
    inferredShape = InferResultShape(op)

    # Done if result is unranked or all operands are unranked
    if not IsRanked(op.result) or inferredShape is None:
        return

    # Rank must match
    actualShape = GetShape(op.result)
    ERROR_IF(len(inferredShape) != len(actualShape))

    # Verify: only two static dimensions can conflict
    for (inferredDim, actualDim) in zip(inferredShape, actualShape):
        ERROR_IF(IsStatic(inferredDim) and IsStatic(actualDim) and inferredDim != actualDim)
```
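
The shape inference and verification steps described in this document can be combined into a single self-contained Python sketch of `Verify`. The sketch only models shapes. Each operand or result is given as a list of dimensions (`int` for static, `None` for dynamic), or as `None` when unranked, and `ERROR_IF` is modeled as raising `ValueError`. The single-result, tensor-or-vector, and element-type aspects of the trait are not checked, and all names are hypothetical. The final loop follows the verification table, so an inferred dynamic dimension paired with a static result dimension is accepted.

```python
from typing import List, Optional, Sequence

Dim = Optional[int]          # None stands for a dynamic size ("?")
Shape = Optional[List[Dim]]  # None stands for an unranked type


def infer_dim(dim0: Dim, dim1: Dim) -> Dim:
    # Compact form of the InferDim case table.
    if dim0 is not None and dim1 is not None and dim0 != 1 and dim1 != 1:
        if dim0 != dim1:
            raise ValueError(f"operand dimensions {dim0} and {dim1} are not broadcast-compatible")
        return dim0
    if dim0 == 1:
        return dim1
    if dim1 == 1:
        return dim0
    return dim0 if dim0 is not None else dim1


def infer_result_shape(operand_shapes: Sequence[Shape]) -> Shape:
    ranked = [shape for shape in operand_shapes if shape is not None]
    if not ranked:
        return None
    inferred = list(ranked[0])
    for shape in ranked[1:]:
        # Equalize ranks by prepending 1s, then infer each dimension.
        rank = max(len(inferred), len(shape))
        lhs = [1] * (rank - len(inferred)) + inferred
        rhs = [1] * (rank - len(shape)) + list(shape)
        inferred = [infer_dim(d0, d1) for d0, d1 in zip(lhs, rhs)]
    return inferred


def verify(operand_shapes: Sequence[Shape], result_shape: Shape) -> None:
    if not operand_shapes:
        raise ValueError("at least one input operand is required")
    # A failure here means the operands are not broadcast-compatible.
    inferred = infer_result_shape(operand_shapes)
    # Done if the result is unranked or all operands are unranked.
    if result_shape is None or inferred is None:
        return
    if len(inferred) != len(result_shape):
        raise ValueError(f"rank mismatch: inferred {inferred}, actual {result_shape}")
    for inferred_dim, actual_dim in zip(inferred, result_shape):
        # Only two static sizes can conflict.
        if inferred_dim is not None and actual_dim is not None and inferred_dim != actual_dim:
            raise ValueError(f"inferred {inferred} incompatible with actual {result_shape}")
```

Applied to shapes taken from the examples below, `verify([[1], [4]], [4])` and `verify([None, None], [2])` return without error. `verify([[3], [3]], [1, 3])` and `verify([[1], [1]], [4])` raise `ValueError`.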

## Examples

The following are correct uses of broadcastable ops:

```mlir
// Exact match of static sizes.
%result = "test.broadcastable"(%arg0, %arg1) : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>

// Dynamic sizes match. The programmer must guarantee that the sizes of %arg0
// and %arg1 are equal at runtime.
%result = "test.broadcastable"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>

// The shape of %arg0 is broadcast from tensor<1xi32> to tensor<4xi32>.
%result = "test.broadcastable"(%arg0, %arg1) : (tensor<1xi32>, tensor<4xi32>) -> tensor<4xi32>

// The shape of %result is inferred as tensor<4xi32>, while the actual result
// type is tensor<?xi32>. The inferred shape is compatible with the actual shape.
%result = "test.broadcastable"(%arg0) : (tensor<4xi32>) -> tensor<?xi32>

// The shape of %result is inferred as tensor<?xi32>, while the actual result
// type is tensor<4xi32>. This implicit dynamic-to-static cast is accepted, but
// the behavior is undefined if the runtime size of the result is not 4.
%result = "test.broadcastable"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<4xi32>

// The shape of %arg0 is first expanded to tensor<1x1x4xi32> and then broadcast
// to tensor<2x3x4xi32>.
%result = "test.broadcastable"(%arg0, %arg1) : (tensor<4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32>

// Input and result tensors have different element types (i1, i32, i64). The
// 'Broadcastable' trait has no restrictions on element types.
%result = "test.broadcastable"(%arg0, %arg1) : (tensor<2xi1>, tensor<2xi32>) -> tensor<2xi64>

// No result shape verification is needed when the result is unranked.
%result = "test.broadcastable"(%arg0) : (tensor<2xi32>) -> tensor<*xi32>

// No result shape verification is needed when all inputs are unranked.
%result = "test.broadcastable"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<2xi32>
```

The following are incorrect uses of broadcastable ops:

```mlir
// Dimension 0 of the input operands is static, but the sizes are not equal.
%result = "test.broadcastable"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<?xi32>

// The inferred result shape is tensor<3xi32>, but the actual result shape is
// tensor<1x3xi32>. Inferred and actual shapes differ in rank.
%result = "test.broadcastable"(%arg0, %arg1) : (tensor<3xi32>, tensor<3xi32>) -> tensor<1x3xi32>

// The inferred result shape is tensor<2xi32>, but the actual result shape is
// tensor<4xi32>, which is not compatible.
%result = "test.broadcastable"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<4xi32>

// The inferred result shape is tensor<1xi32>, but the actual result shape is
// tensor<4xi32>. Broadcast semantics do not apply to results.
%result = "test.broadcastable"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
```