xref: /llvm-project/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc.md (revision 1a8fb887197caf709710bedf88ce95ffb0605c56)
# Resharding Spmdization Examples

Reshard `2x3` tensor from sharding `[[0, 1]]` to sharding `[[1, 0]]` on a `2x3` mesh.

unsharded `2x3` tensor
```
11 12 13
21 22 23
```

sharded on a `2x3` mesh

sharding = `[[0, 1]]`

mesh contents:

```
mesh axis 1
----------->
+----+----+----+ mesh axis 0 |
| 11 | 12 | 13 |             |
+----+----+----+             |
| 21 | 22 | 23 |             |
+----+----+----+             ↓
```

Transform into
sharding = `[[1, 0]]`
```
mesh axis 1
----------->
+----+----+----+ mesh axis 0 |
| 11 | 13 | 22 |             |
+----+----+----+             |
| 12 | 21 | 23 |             |
+----+----+----+             ↓
```
Algorithm:
Swap contents on devices that have the same linear index in the 2 shardings.

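The swap can be sketched in plain Python (a sketch of the device permutation, not the MLIR implementation; the `2x3` mesh and per-element shards are taken from the figures above). A device's shard has linear index `m * N + n` under axis order `[0, 1]` and must end up on the device with the same linear index under `[1, 0]`:

```python
M, N = 2, 3  # mesh shape

# Sharding [[0, 1]]: device (m, n) holds the shard with linear index
# m * N + n; here each shard is the single element t[m][n].
t = [[11, 12, 13], [21, 22, 23]]
src = {(m, n): t[m][n] for m in range(M) for n in range(N)}

# Sharding [[1, 0]]: shard s belongs on the device (m, n) with
# n * M + m == s, so every device forwards its shard accordingly.
dst = {}
for (m, n), shard in src.items():
    s = m * N + n                 # linear index under axis order [0, 1]
    dst[(s % M, s // M)] = shard  # device with the same index under [1, 0]

# Reproduces the target figure: rows [11, 13, 22] and [12, 21, 23].
for m in range(M):
    print([dst[(m, n)] for n in range(N)])
```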
--------------------------------------------------------------

Reshard `2x3` tensor from sharding `[[0, 1]]` to sharding `[[1]]` on a `2x3` mesh.

unsharded `2x3` tensor
```
11 12 13
21 22 23
```

sharded on a `2x3` mesh

sharding = `[[0, 1]]`

mesh contents:
```
mesh axis 1
----------->
+----+----+----+ mesh axis 0 |
| 11 | 12 | 13 |             |
+----+----+----+             |
| 21 | 22 | 23 |             |
+----+----+----+             ↓
```

Transform into
sharding = `[[1]]`
```
mesh axis 1
----------->
+----+----+----+ mesh axis 0 |
| 11 | 12 | 13 |             |
| 21 | 22 | 23 |             |
+----+----+----+             |
| 11 | 12 | 13 |             |
| 21 | 22 | 23 |             |
+----+----+----+             ↓
```
Algorithm:
All-gather along mesh axis 0.

--------------------------------------------------------------

Reshard `2x6` tensor from sharding `[[], [0, 1]]` to sharding `[[], [0]]` on a `2x3` mesh.

unsharded `2x6` tensor
```
11 12 13 14 15 16
21 22 23 24 25 26
```

sharded on a `2x3` mesh

sharding = `[[], [0, 1]]`

mesh contents:
```
mesh axis 1
----------->
+----+----+----+ mesh axis 0 |
| 11 | 12 | 13 |             |
| 21 | 22 | 23 |             |
+----+----+----+             |
| 14 | 15 | 16 |             |
| 24 | 25 | 26 |             |
+----+----+----+             ↓
```
Transform into
sharding = `[[], [0]]`
```
mesh axis 1
----------->
+----------+----------+ mesh axis 0 |
| 11 12 13 | 11 12 13 |             |
| 21 22 23 | 21 22 23 |             |
+----------+----------+             |
| 14 15 16 | 14 15 16 |             |
| 24 25 26 | 24 25 26 |             |
+----------+----------+             ↓
```
Algorithm:
All-gather along mesh axis 1.

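This all-gather can be simulated with numpy (a sketch, not the MLIR implementation; the tensor and mesh are the ones from the figures above):

```python
import numpy as np

M, N = 2, 3  # mesh shape
t = np.arange(1, 3)[:, None] * 10 + np.arange(1, 7)[None, :]  # the 2x6 tensor

# [[], [0, 1]]: tensor axis 1 is split across mesh axes 0 (major) and
# 1 (minor), so device (m, n) holds the single column N*m + n.
shard = {(m, n): t[:, [N * m + n]] for m in range(M) for n in range(N)}

# All-gather along mesh axis 1: every device in mesh row m ends up with
# the concatenation of that row's shards, i.e. sharding [[], [0]].
gathered = {(m, n): np.concatenate([shard[(m, k)] for k in range(N)], axis=1)
            for m in range(M) for n in range(N)}

# Matches the target figure: both devices in mesh row 0 hold columns 1-3.
print(gathered[(0, 0)].tolist())
```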
--------------------------------------------------------------

Reshard `4x8` tensor from sharding `[[0], [1, 2]]` to sharding `[[0], [2]]` on a `2x2x2` mesh.

unsharded `4x8` tensor
```
11 12 13 14 15 16 17 18
21 22 23 24 25 26 27 28
31 32 33 34 35 36 37 38
41 42 43 44 45 46 47 48
```
sharded on a `2x2x2` mesh

sharding = `[[0], [1, 2]]`

mesh contents:
```
mesh axis 2
----------->
+-------+-------+ mesh axis 1 | mesh axis 0 |
| 11 12 | 13 14 |             |             |
| 21 22 | 23 24 |             |             |
+-------+-------+             |             |
| 15 16 | 17 18 |             |             |
| 25 26 | 27 28 |             |             |
+-------+-------+             ↓             |
+-------+-------+                           |
| 31 32 | 33 34 |                           |
| 41 42 | 43 44 |                           |
+-------+-------+                           |
| 35 36 | 37 38 |                           |
| 45 46 | 47 48 |                           |
+-------+-------+                           ↓
```
Transform into
sharding = `[[0], [2]]`
```
mesh axis 2
----------->
+-------------+-------------+ mesh axis 1 | mesh axis 0 |
| 11 12 13 14 | 15 16 17 18 |             |             |
| 21 22 23 24 | 25 26 27 28 |             |             |
+-------------+-------------+             |             |
| 11 12 13 14 | 15 16 17 18 |             |             |
| 21 22 23 24 | 25 26 27 28 |             |             |
+-------------+-------------+             ↓             |
+-------------+-------------+                           |
| 31 32 33 34 | 35 36 37 38 |                           |
| 41 42 43 44 | 45 46 47 48 |                           |
+-------------+-------------+                           |
| 31 32 33 34 | 35 36 37 38 |                           |
| 41 42 43 44 | 45 46 47 48 |                           |
+-------------+-------------+                           ↓
```
Algorithm:

This can't be done with just an all-gather along mesh axis 1.
It can be handled by a sequence of resharding transformations:
`[[0], [1, 2]] -> [[0], [2, 1]] -> [[0], [2]]`

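The two-step route can be replayed with numpy (a sketch under the figures' conventions, not the MLIR implementation): first the mesh-axis swap on tensor axis 1, then an all-gather along mesh axis 1, now the minor split axis.

```python
import numpy as np

t = np.arange(1, 5)[:, None] * 10 + np.arange(1, 9)[None, :]  # the 4x8 tensor

# [[0], [1, 2]]: device (a, b, c) holds column shard number 2*b + c.
shard = {(a, b, c): t[2*a:2*a+2, 2*(2*b + c):2*(2*b + c)+2]
         for a in range(2) for b in range(2) for c in range(2)}

# Swap mesh axes 1 and 2 on tensor axis 1 -> [[0], [2, 1]]:
# device (a, b, c) must now hold column shard number 2*c + b.
shard = {(a, b, c): t[2*a:2*a+2, 2*(2*c + b):2*(2*c + b)+2]
         for a in range(2) for b in range(2) for c in range(2)}

# All-gather along mesh axis 1 -> [[0], [2]].
shard = {(a, b, c): np.concatenate([shard[(a, k, c)] for k in range(2)], axis=1)
         for a in range(2) for b in range(2) for c in range(2)}

# Matches the target figure: device (a, *, c) holds t[2a:2a+2, 4c:4c+4].
for (a, b, c), s in shard.items():
    assert (s == t[2*a:2*a+2, 4*c:4*c+4]).all()
print("decomposition verified")
```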
--------------------------------------------------------------

Reshard `6x6` tensor from sharding `[[0], [1]]` to sharding `[[1], [0]]` on a `2x3` mesh.

unsharded `6x6` tensor
```
11 12 13 14 15 16
21 22 23 24 25 26
31 32 33 34 35 36
41 42 43 44 45 46
51 52 53 54 55 56
61 62 63 64 65 66
```
sharded on a `2x3` mesh

sharding = `[[0], [1]]`
```
mesh axis 1
----------->
+-------+-------+-------+ mesh axis 0 |
| 11 12 | 13 14 | 15 16 |             |
| 21 22 | 23 24 | 25 26 |             |
| 31 32 | 33 34 | 35 36 |             |
+-------+-------+-------+             |
| 41 42 | 43 44 | 45 46 |             |
| 51 52 | 53 54 | 55 56 |             |
| 61 62 | 63 64 | 65 66 |             |
+-------+-------+-------+             ↓
```
transform to
sharding = `[[1], [0]]`
```
mesh axis 1
----------->
+----------+----------+----------+ mesh axis 0 |
| 11 12 13 | 31 32 33 | 51 52 53 |             |
| 21 22 23 | 41 42 43 | 61 62 63 |             |
+----------+----------+----------+             |
| 14 15 16 | 34 35 36 | 54 55 56 |             |
| 24 25 26 | 44 45 46 | 64 65 66 |             |
+----------+----------+----------+             ↓

mesh axis 0
----------->
+----------+----------+ mesh axis 1 |
| 11 12 13 | 14 15 16 |             |
| 21 22 23 | 24 25 26 |             |
+----------+----------+             |
| 31 32 33 | 34 35 36 |             |
| 41 42 43 | 44 45 46 |             |
+----------+----------+             |
| 51 52 53 | 54 55 56 |             |
| 61 62 63 | 64 65 66 |             |
+----------+----------+             ↓
```
Algorithm: TODO

--------------------------------------------------------------

Reshard `6x6` tensor from sharding `[[0], [1]]` to sharding `[[1], [0]]` on a `2x6` mesh.

unsharded `6x6` tensor
```
11 12 13 14 15 16
21 22 23 24 25 26
31 32 33 34 35 36
41 42 43 44 45 46
51 52 53 54 55 56
61 62 63 64 65 66
```
sharded on a `2x6` mesh

sharding = `[[0], [1]]`
```
mesh axis 1
----------->
+----+----+----+----+----+----+ mesh axis 0 |
| 11 | 12 | 13 ‖ 14 | 15 | 16 |             |
| 21 | 22 | 23 ‖ 24 | 25 | 26 |             |
| 31 | 32 | 33 ‖ 34 | 35 | 36 |             |
+----+----+----+----+----+----+             |
| 41 | 42 | 43 ‖ 44 | 45 | 46 |             |
| 51 | 52 | 53 ‖ 54 | 55 | 56 |             |
| 61 | 62 | 63 ‖ 64 | 65 | 66 |             |
+----+----+----+----+----+----+             ↓
```
transform to
sharding = `[[1], [0]]`
```
mesh axis 0
----------->
+----------+----------+ mesh axis 1 |
| 11 12 13 | 14 15 16 |             |
+----------+----------+             |
| 21 22 23 | 24 25 26 |             |
+----------+----------+             |
| 31 32 33 | 34 35 36 |             |
+==========+==========+             |
| 41 42 43 | 44 45 46 |             |
+----------+----------+             |
| 51 52 53 | 54 55 56 |             |
+----------+----------+             |
| 61 62 63 | 64 65 66 |             |
+----------+----------+             ↓
```
Algorithm: TODO

--------------------------------------------------------------

Reshard `KxL` tensor from `[[0], [1]]` to `[[1], [0]]` on an `MxN` mesh.

`M x N` mesh.
`K x L` tensor `t`.
`d(m, n)` the tensor shard on device `(m, n)`.

sharding = `[[0], [1]]`
The tensor shard on each device has size `(K ceildiv M, L ceildiv N)`.
```
d(m, n)[k, l] -> t[m * (K ceildiv M) + k, n * (L ceildiv N) + l]
```
substitute
```
i <- m * (K ceildiv M) + k
j <- n * (L ceildiv N) + l
```
```
m -> i floordiv (K ceildiv M)
n -> j floordiv (L ceildiv N)
k -> i % (K ceildiv M)
l -> j % (L ceildiv N)
```
For the inverse map we get
```
t[i, j] -> d(
    i floordiv (K ceildiv M), j floordiv (L ceildiv N)
)[
    i % (K ceildiv M), j % (L ceildiv N)
]
```
Check:
```
i = 13, j = 17, M = 3, N = 4, K = 16, L = 23
t[13, 17] = d(
    13 floordiv (16 ceildiv 3),
    17 floordiv (23 ceildiv 4)
)[
    13 % (16 ceildiv 3),
    17 % (23 ceildiv 4)
]
= d(
    13 floordiv 6,
    17 floordiv 6
)[
    13 % 6,
    17 % 6
]
= d(2, 2)[1, 5]
= t[
    2 * (16 ceildiv 3) + 1,
    2 * (23 ceildiv 4) + 5
]
= t[
    2 * 6 + 1,
    2 * 6 + 5
]
= t[13, 17]
```

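Instead of one spot check, the forward and inverse maps can be checked exhaustively with a short script (a sketch; the sizes are the ones from the check above):

```python
def ceildiv(a, b):
    return -(-a // b)

M, N, K, L = 3, 4, 16, 23
sK, sL = ceildiv(K, M), ceildiv(L, N)  # shard size (6, 6)

for i in range(K):
    for j in range(L):
        # inverse map: t[i, j] -> d(i floordiv sK, j floordiv sL)[i % sK, j % sL]
        m, n, k, l = i // sK, j // sL, i % sK, j % sL
        # forward map must take that shard slot back to t[i, j]
        assert (m * sK + k, n * sL + l) == (i, j)
print("round-trip verified for all", K * L, "elements")
```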
sharding = `[[1], [0]]`
The tensor shard on each device has size `(K ceildiv N, L ceildiv M)`.
```
d(m, n)[k, l] -> t[n * (K ceildiv N) + k, m * (L ceildiv M) + l]
```
substitute
```
i <- n * (K ceildiv N) + k
j <- m * (L ceildiv M) + l
```
```
m -> j floordiv (L ceildiv M)
n -> i floordiv (K ceildiv N)
k -> i % (K ceildiv N)
l -> j % (L ceildiv M)
```
For the inverse map we get
```
t[i, j] -> d(
    j floordiv (L ceildiv M), i floordiv (K ceildiv N)
)[
    i % (K ceildiv N), j % (L ceildiv M)
]
```
Check:
```
i = 9, j = 19, M = 5, N = 2, K = 27, L = 14
t[9, 19] = d(
    19 floordiv (14 ceildiv 5),
    9 floordiv (27 ceildiv 2)
)[
    9 % (27 ceildiv 2),
    19 % (14 ceildiv 5)
]
= d(
    19 floordiv 3,
    9 floordiv 14
)[
    9 % 14,
    19 % 3
]
= d(6, 0)[9, 1]
= t[
    0 * (27 ceildiv 2) + 9,
    6 * (14 ceildiv 5) + 1
]
= t[
    0 * 14 + 9,
    6 * 3 + 1
]
= t[9, 19]
```
sharding = `[[0], [1]]`
```
d(m, n)[k, l] -> t[m * (K ceildiv M) + k, n * (L ceildiv N) + l]
t[i, j] -> d(i floordiv (K ceildiv M), j floordiv (L ceildiv N))[i % (K ceildiv M), j % (L ceildiv N)]
```
sharding = `[[1], [0]]`
```
d(m, n)[k, l] -> t[n * (K ceildiv N) + k, m * (L ceildiv M) + l]
t[i, j] -> d(j floordiv (L ceildiv M), i floordiv (K ceildiv N))[i % (K ceildiv N), j % (L ceildiv M)]
```
sharding `[[0], [1]] -> [[1], [0]]`
`d1(m, n)` the tensor shard on device `(m, n)` for sharding `[[0], [1]]`.
`d2(m, n)` the tensor shard on device `(m, n)` for sharding `[[1], [0]]`.
```
d1(m, n)[k, l] ->
t[m * (K ceildiv M) + k, n * (L ceildiv N) + l] ->
d2(
    (n * (L ceildiv N) + l) floordiv (L ceildiv M),
    (m * (K ceildiv M) + k) floordiv (K ceildiv N)
)[
    (m * (K ceildiv M) + k) % (K ceildiv N),
    (n * (L ceildiv N) + l) % (L ceildiv M)
]
= d2(p, q)[u, v]
```
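The composed map, obtained by substituting the global indices `i = m * (K ceildiv M) + k` and `j = n * (L ceildiv N) + l` into the inverse map of `[[1], [0]]`, can be sanity-checked numerically (a sketch; the sizes are arbitrary and the padding slots of the last shards are skipped):

```python
def ceildiv(a, b):
    return -(-a // b)

M, N, K, L = 3, 4, 16, 23
sK1, sL1 = ceildiv(K, M), ceildiv(L, N)  # d1 shard size
sK2, sL2 = ceildiv(K, N), ceildiv(L, M)  # d2 shard size

for m in range(M):
    for n in range(N):
        for k in range(sK1):
            for l in range(sL1):
                i, j = m * sK1 + k, n * sL1 + l  # global element index
                if i >= K or j >= L:
                    continue  # out-of-bounds padding of the last shards
                # composed map d1(m, n)[k, l] -> d2(p, q)[u, v]
                p, q = j // sL2, i // sK2
                u, v = i % sK2, j % sL2
                # d2(p, q)[u, v] must address the same element t[i, j]
                assert (q * sK2 + u, p * sL2 + v) == (i, j)
print("composed map verified")
```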
We want to copy the data between devices in slices/tiles.
What are the source/target tile coordinates?
For a fixed `(m, n, p, q)`, what is the range of `(k, l, u, v)`?
TODO

--------------------------------------------------------------

Reshard `KxL` tensor from sharding `[[0], [1]]` to sharding `[[1], [0]]` on a `2x3` mesh.

Device placement on a `2x3` mesh
```
11 12 13  <- devices
21 22 23
```
sharding `[[0], [1]]`
```
tensor axis 1
----------->
+----+----+----+ tensor axis 0 |
| 11 | 12 | 13 |               |
+----+----+----+               |
| 21 | 22 | 23 |               |
+----+----+----+               ↓
```
transform to
sharding `[[1], [0]]`
```
tensor axis 1
----------->
+----+----+ tensor axis 0 |
| 11 | 21 |               |
+----+----+               |
| 12 | 22 |               |
+----+----+               |
| 13 | 23 |               |
+----+----+               ↓
```
```
+-----------------+--------+--------+-----------------+
|                 |                 |                 |
+                 +                 +                 +
|       11        |        12       |        13       |
+                 +                 +                 +
|                 |                 |                 |
+-----------------+--------+--------+-----------------+
|                 |                 |                 |
+                 +                 +                 +
|       21        |        22       |        23       |
+                 +                 +                 +
|                 |                 |                 |
+-----------------+--------+--------+-----------------+

+-----------------+--------+--------+-----------------+
|                          |                          |
+          11              +               21         +
|                          |                          |
+-----------------+--------+--------+-----------------+
|                          |                          |
+          12              +               22         +
|                          |                          |
+-----------------+--------+--------+-----------------+
|                          |                          |
+          13              +               23         +
|                          |                          |
+-----------------+--------+--------+-----------------+

+-----------------+--------+--------+-----------------+
|                 |        |        |                 |
+     11  11      + 12  11 + 12  21 +     13  21      +
|                 |        |        |                 |
+-----------------+--------+--------+-----------------+
|     11  12      | 12  12 | 12  22 |     13  22      |
+-----------------+--------+--------+-----------------+
|     21  12      | 22  12 | 22  22 |     23  22      |
+-----------------+--------+--------+-----------------+
|                 |        |        |                 |
+     21  13      + 22  13 + 22  23 +     23  23      +
|                 |        |        |                 |
+-----------------+--------+--------+-----------------+
```
If `S` and `T` are the source and target shard sizes along some tensor axis,
then the cut pattern repeats with a period of `lcm(S, T) = (S*T)/gcd(S, T)`.
TODO

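A quick sketch of this periodicity (hypothetical shard sizes `S = 3`, `T = 2`): collecting the cut points made by either sharding along one tensor axis shows the pattern repeating after `lcm(S, T)`.

```python
from math import gcd

S, T = 3, 2  # source/target shard sizes along one tensor axis (hypothetical)
period = S * T // gcd(S, T)  # lcm(S, T)

# cut points made by either sharding, collected over two periods
cuts = sorted({x for x in range(0, 2 * period, S)} |
              {x for x in range(0, 2 * period, T)})
first = [c for c in cuts if c < period]
second = [c - period for c in cuts if c >= period]
assert first == second  # the cut pattern repeats with the stated period
print("period:", period, "cuts per period:", first)
```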
--------------------------------------------------------------

Reshard `6x6` tensor from sharding `[[0], []]` to sharding `[[], [0]]` on a `3` mesh.

unsharded `6x6` tensor
```
11 12 13 14 15 16
21 22 23 24 25 26
31 32 33 34 35 36
41 42 43 44 45 46
51 52 53 54 55 56
61 62 63 64 65 66
```
sharded on a `3` mesh

sharding = `[[0], []]`
```
+-------------------+ mesh axis 0 |
| 11 12 13 14 15 16 |             |
| 21 22 23 24 25 26 |             |
+-------------------+             |
| 31 32 33 34 35 36 |             |
| 41 42 43 44 45 46 |             |
+-------------------+             |
| 51 52 53 54 55 56 |             |
| 61 62 63 64 65 66 |             |
+-------------------+             ↓
```
transform to
sharding = `[[], [0]]`
```
mesh axis 0
----------->
+-------+-------+-------+
| 11 12 | 13 14 | 15 16 |
| 21 22 | 23 24 | 25 26 |
| 31 32 | 33 34 | 35 36 |
| 41 42 | 43 44 | 45 46 |
| 51 52 | 53 54 | 55 56 |
| 61 62 | 63 64 | 65 66 |
+-------+-------+-------+
```
Algorithm:
```mlir
%1 = all_to_all %0 on @mesh mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<2x6xi8> -> tensor<6x2xi8>
```
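The effect of this `all_to_all` can be simulated with numpy (a sketch of the collective's semantics, not the MLIR implementation): each device splits its `2x6` shard along `split_axis = 1`, sends piece `q` to device `q`, and concatenates what it receives along `concat_axis = 0`.

```python
import numpy as np

P = 3  # mesh size
t = np.arange(1, 7)[:, None] * 10 + np.arange(1, 7)[None, :]  # the 6x6 tensor
shards = np.split(t, P, axis=0)  # [[0], []]: device p holds rows 2p..2p+1

# all_to_all, mesh_axes = [0], split_axis = 1, concat_axis = 0
pieces = [np.split(s, P, axis=1) for s in shards]
result = [np.concatenate([pieces[p][q] for p in range(P)], axis=0)
          for q in range(P)]

# Matches the target figure: device q now holds columns 2q..2q+1.
for q in range(P):
    assert (result[q] == t[:, 2 * q:2 * q + 2]).all()
print([r.shape for r in result])
```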
--------------------------------------------------------------

Reshard `4x4` tensor from sharding `[[0], [1, 2]]` to sharding `[[0, 1], [2]]` on a `2x2x2` mesh.

unsharded `4x4` tensor
```
11 12 13 14
21 22 23 24
31 32 33 34
41 42 43 44
```
sharded on a `2x2x2` mesh

sharding = `[[0], [1, 2]]`
```
mesh axis 2
----------->
+----+----+ mesh axis 1 | mesh axis 0 |
| 11 | 12 |             |             |
| 21 | 22 |             |             |
+----+----+             |             |
| 13 | 14 |             |             |
| 23 | 24 |             |             |
+----+----+             ↓             |
+----+----+                           |
| 31 | 32 |                           |
| 41 | 42 |                           |
+----+----+                           |
| 33 | 34 |                           |
| 43 | 44 |                           |
+----+----+                           ↓
```
transform to
sharding = `[[0, 1], [2]]`
```
mesh axis 2
----------->
+-------+-------+ mesh axis 1 | mesh axis 0 |
| 11 12 | 13 14 |             |             |
+-------+-------+             |             |
| 21 22 | 23 24 |             |             |
+-------+-------+             ↓             |
+-------+-------+                           |
| 31 32 | 33 34 |                           |
+-------+-------+                           |
| 41 42 | 43 44 |                           |
+-------+-------+                           ↓
```
Algorithm:
```mlir
%1 = all_to_all %0 on @mesh mesh_axes = [2] split_axis = 1 concat_axis = 0 : tensor<2x1xi8> -> tensor<1x2xi8>
```
is not enough.

It can be decomposed into
```
[[0], [1, 2]] -> [[0], [2, 1]] -> [[0, 1], [2]]
```

## Decomposition into basis of reshardings

We can decompose each resharding into a sequence of basis reshardings.
This decomposition is not communication-efficient in terms of minimizing the
data communicated between devices; an efficient approach would be more
complicated to implement.
Each device has to receive at most as much data as the size of its target
shard.

--------------------------------------------------------------

Basis:

*   From replicate to split.
    ```
    [[]] -> [[1]]
    ```
    Extract slices without communication.

*   From split to replicate.
    ```
    [[0]] -> [[]]
    [[0, 1]] -> [[1]]
    ```
    All-gather along mesh axis 0.

*   Swap mesh axes order when assigned to the same tensor axis.
    ```
    [[0, 1]] -> [[1, 0]]
    ```
    Swap contents on devices with the same linear index.

*   Move mesh axis to a different tensor dimension.
    ```
    [[0], []] -> [[], [0]]
    ```
    All-to-all.

--------------------------------------------------------------

Example decomposition of
```
[[0], [1]] -> [[1], [0]]
```
into
```
[[0], [1]] -> all-gather along mesh axis 1    ->
[[0], []]  -> all-to-all along mesh axis 0    ->
[[], [0]]  -> extract slice along mesh axis 1 ->
[[1], [0]]
```

--------------------------------------------------------------

Example decomposition of
```
[[3, 2], [], [0, 1]] -> [[0], [1, 2], []]
```
into
```
[[3, 2], [], [0, 1]] -> all-to-all along mesh axis 1 ->
[[3, 2], [1], [0]]   -> all-to-all along mesh axis 2 ->
[[3], [1, 2], [0]]   -> all-gather along mesh axis 3 ->
[[], [1, 2], [0]]    -> all-to-all along mesh axis 0 ->
[[0], [1, 2], []]
```
