# Resharding Spmdization Examples

Reshard `2x3` tensor from sharding `[[0, 1]]` to sharding `[[1, 0]]` on a `2x3` mesh.

unsharded `2x3` tensor
```
11 12 13
21 22 23
```

sharded on a `2x3` mesh

sharding = `[[0, 1]]`

mesh contents:

```
mesh axis 1
----------->
+----+----+----+ mesh axis 0 |
| 11 | 12 | 13 |             |
+----+----+----+             |
| 21 | 22 | 23 |             |
+----+----+----+             ↓
```

Transform into
sharding = `[[1, 0]]`
```
mesh axis 1
----------->
+----+----+----+ mesh axis 0 |
| 11 | 13 | 22 |             |
+----+----+----+             |
| 12 | 21 | 23 |             |
+----+----+----+             ↓
```
Algorithm:
Swap contents on devices that have the same linear index in the 2 shardings.

--------------------------------------------------------------

Reshard `2x3` tensor from sharding `[[0, 1]]` to sharding `[[1]]` on a `2x3` mesh.

unsharded `2x3` tensor
```
11 12 13
21 22 23
```

sharded on a `2x3` mesh

sharding = `[[0, 1]]`

mesh contents:
```
mesh axis 1
----------->
+----+----+----+ mesh axis 0 |
| 11 | 12 | 13 |             |
+----+----+----+             |
| 21 | 22 | 23 |             |
+----+----+----+             ↓
```

Transform into
sharding = `[[1]]`
```
mesh axis 1
----------->
+----+----+----+ mesh axis 0 |
| 11 | 12 | 13 |             |
| 21 | 22 | 23 |             |
+----+----+----+             |
| 11 | 12 | 13 |             |
| 21 | 22 | 23 |             |
+----+----+----+             ↓
```
Algorithm:
All-gather along mesh axis 0.

--------------------------------------------------------------

Reshard `2x6` tensor from sharding `[[], [0, 1]]` to sharding `[[], [0]]` on a `2x3` mesh.

unsharded `2x6` tensor
```
11 12 13 14 15 16
21 22 23 24 25 26
```

sharded on a `2x3` mesh

sharding = `[[], [0, 1]]`

mesh contents:
```
mesh axis 1
----------->
+----+----+----+ mesh axis 0 |
| 11 | 12 | 13 |             |
| 21 | 22 | 23 |             |
+----+----+----+             |
| 14 | 15 | 16 |             |
| 24 | 25 | 26 |             |
+----+----+----+             ↓
```
Transform into
sharding = `[[], [0]]`
```
mesh axis 1
----------->
+----------+----------+----------+ mesh axis 0 |
| 11 12 13 | 11 12 13 | 11 12 13 |             |
| 21 22 23 | 21 22 23 | 21 22 23 |             |
+----------+----------+----------+             |
| 14 15 16 | 14 15 16 | 14 15 16 |             |
| 24 25 26 | 24 25 26 | 24 25 26 |             |
+----------+----------+----------+             ↓
```
Algorithm:
All-gather along mesh axis 1.

--------------------------------------------------------------

Reshard `4x8` tensor from sharding `[[0], [1, 2]]` to sharding `[[0], [2]]` on a `2x2x2` mesh.

unsharded `4x8` tensor
```
11 12 13 14 15 16 17 18
21 22 23 24 25 26 27 28
31 32 33 34 35 36 37 38
41 42 43 44 45 46 47 48
```
sharded on a `2x2x2` mesh

sharding = `[[0], [1, 2]]`

mesh contents:
```
mesh axis 2
----------->
+-------+-------+ mesh axis 1 | mesh axis 0 |
| 11 12 | 13 14 |             |             |
| 21 22 | 23 24 |             |             |
+-------+-------+             |             |
| 15 16 | 17 18 |             |             |
| 25 26 | 27 28 |             |             |
+-------+-------+             ↓             |
+-------+-------+                           |
| 31 32 | 33 34 |                           |
| 41 42 | 43 44 |                           |
+-------+-------+                           |
| 35 36 | 37 38 |                           |
| 45 46 | 47 48 |                           |
+-------+-------+                           ↓
```
Transform into
sharding = `[[0], [2]]`
```
mesh axis 2
----------->
+-------------+-------------+ mesh axis 1 | mesh axis 0 |
| 11 12 13 14 | 15 16 17 18 |             |             |
| 21 22 23 24 | 25 26 27 28 |             |             |
+-------------+-------------+             |             |
| 11 12 13 14 | 15 16 17 18 |             |             |
| 21 22 23 24 | 25 26 27 28 |             |             |
+-------------+-------------+             ↓             |
+-------------+-------------+                           |
| 31 32 33 34 | 35 36 37 38 |                           |
| 41 42 43 44 | 45 46 47 48 |                           |
+-------------+-------------+                           |
| 31 32 33 34 | 35 36 37 38 |                           |
| 41 42 43 44 | 45 46 47 48 |                           |
+-------------+-------------+                           ↓
```
Algorithm:

Can't be done with just an all-gather along mesh axis 1.
Can be handled by multiple resharding transformations
`[[0], [1, 2]] -> [[0], [2, 1]] -> [[0], [2]]`

--------------------------------------------------------------

Reshard `6x6` tensor from sharding `[[0], [1]]` to sharding `[[1], [0]]` on a `2x3` mesh.

unsharded `6x6` tensor
```
11 12 13 14 15 16
21 22 23 24 25 26
31 32 33 34 35 36
41 42 43 44 45 46
51 52 53 54 55 56
61 62 63 64 65 66
```
sharded on a `2x3` mesh

sharding = `[[0], [1]]`
```
mesh axis 1
----------->
+-------+-------+-------+ mesh axis 0 |
| 11 12 | 13 14 | 15 16 |             |
| 21 22 | 23 24 | 25 26 |             |
| 31 32 | 33 34 | 35 36 |             |
+-------+-------+-------+             |
| 41 42 | 43 44 | 45 46 |             |
| 51 52 | 53 54 | 55 56 |             |
| 61 62 | 63 64 | 65 66 |             |
+-------+-------+-------+             ↓
```
transform to
sharding = `[[1], [0]]`
```
mesh axis 1
----------->
+----------+----------+----------+ mesh axis 0 |
| 11 12 13 | 31 32 33 | 51 52 53 |             |
| 21 22 23 | 41 42 43 | 61 62 63 |             |
+----------+----------+----------+             |
| 14 15 16 | 34 35 36 | 54 55 56 |             |
| 24 25 26 | 44 45 46 | 64 65 66 |             |
+----------+----------+----------+             ↓

mesh axis 0
----------->
+----------+----------+ mesh axis 1 |
| 11 12 13 | 14 15 16 |             |
| 21 22 23 | 24 25 26 |             |
+----------+----------+             |
| 31 32 33 | 34 35 36 |             |
| 41 42 43 | 44 45 46 |             |
+----------+----------+             |
| 51 52 53 | 54 55 56 |             |
| 61 62 63 | 64 65 66 |             |
+----------+----------+             ↓
```
Algorithm: TODO

--------------------------------------------------------------

Reshard `6x6` tensor from sharding `[[0], [1]]` to sharding `[[1], [0]]` on a `2x6` mesh.

unsharded `6x6` tensor
```
11 12 13 14 15 16
21 22 23 24 25 26
31 32 33 34 35 36
41 42 43 44 45 46
51 52 53 54 55 56
61 62 63 64 65 66
```
shard on `2x6` mesh

sharding = `[[0], [1]]`
```
mesh axis 1
----------->
+----+----+----+----+----+----+ mesh axis 0 |
| 11 | 12 | 13 ‖ 14 | 15 | 16 |             |
| 21 | 22 | 23 ‖ 24 | 25 | 26 |             |
| 31 | 32 | 33 ‖ 34 | 35 | 36 |             |
+----+----+----+----+----+----+             |
| 41 | 42 | 43 ‖ 44 | 45 | 46 |             |
| 51 | 52 | 53 ‖ 54 | 55 | 56 |             |
| 61 | 62 | 63 ‖ 64 | 65 | 66 |             |
+----+----+----+----+----+----+             ↓
```
transform to
sharding = `[[1], [0]]`
```
mesh axis 0
----------->
+----------+----------+ mesh axis 1 |
| 11 12 13 | 14 15 16 |             |
+----------+----------+             |
| 21 22 23 | 24 25 26 |             |
+----------+----------+             |
| 31 32 33 | 34 35 36 |             |
+==========+==========+             |
| 41 42 43 | 44 45 46 |             |
+----------+----------+             |
| 51 52 53 | 54 55 56 |             |
+----------+----------+             |
| 61 62 63 | 64 65 66 |             |
+----------+----------+             ↓
```
Algorithm: TODO

--------------------------------------------------------------

Reshard `KxL` tensor from `[[0], [1]]` to `[[1], [0]]` on `MxN` mesh.

`M x N` mesh.
`K x L` tensor `t`.
`d(m, n)` the tensor on device `(m, n)`.

sharding = `[[0], [1]]`
Tensor shard `s` on each device has size `(K ceildiv M, L ceildiv N)`.
```
d(m, n)[k, l] -> t[m * (K ceildiv M) + k, n * (L ceildiv N) + l]
```
substitute
```
i <- m * (K ceildiv M) + k
j <- n * (L ceildiv N) + l
```
```
m -> i floordiv (K ceildiv M)
n -> j floordiv (L ceildiv N)
k -> i % (K ceildiv M)
l -> j % (L ceildiv N)
```
For the inverse map we get
```
t[i, j] -> d(
    i floordiv (K ceildiv M), j floordiv (L ceildiv N)
)[
    i % (K ceildiv M), j % (L ceildiv N)
]
```
Check:
```
i = 13, j = 17, M = 3, N = 4, K = 16, L = 23
t[13, 17] = d(
    13 floordiv (16 ceildiv 3),
    17 floordiv (23 ceildiv 4)
)[
    13 % (16 ceildiv 3),
    17 % (23 ceildiv 4)
]
= d(
    13 floordiv 6,
    17 floordiv 6
)[
    13 % 6,
    17 % 6
]
= d(2, 2)[1, 5]
= t[
    2 * (16 ceildiv 3) + 1,
    2 * (23 ceildiv 4) + 5
]
= t[
    2 * 6 + 1,
    2 * 6 + 5
]
= t[13, 17]
```

sharding = `[[1], [0]]`
Tensor shard `s` on each device has size `(K ceildiv N, L ceildiv M)`.
```
d(m, n)[k, l] -> t[n * (K ceildiv N) + k, m * (L ceildiv M) + l]
```
substitute
```
i <- n * (K ceildiv N) + k
j <- m * (L ceildiv M) + l
```
```
m -> j floordiv (L ceildiv M)
n -> i floordiv (K ceildiv N)
k -> i % (K ceildiv N)
l -> j % (L ceildiv M)
```
For the inverse map we get
```
t[i, j] -> d(
    j floordiv (L ceildiv M), i floordiv (K ceildiv N)
)[
    i % (K ceildiv N), j % (L ceildiv M)
]
```
Check:
```
i = 9, j = 11, M = 5, N = 2, K = 27, L = 14
t[9, 11] = d(
    11 floordiv (14 ceildiv 5),
    9 floordiv (27 ceildiv 2)
)[
    9 % (27 ceildiv 2),
    11 % (14 ceildiv 5)
]
= d(
    11 floordiv 3,
    9 floordiv 14
)[
    9 % 14,
    11 % 3
]
= d(3, 0)[9, 2]
= t[
    0 * (27 ceildiv 2) + 9,
    3 * (14 ceildiv 5) + 2
]
= t[
    0 * 14 + 9,
    3 * 3 + 2
]
= t[9, 11]
```
sharding = `[[0], [1]]`
```
d(m, n)[k, l] -> t[m * (K ceildiv M) + k, n * (L ceildiv N) + l]
t[i, j] -> d(i floordiv (K ceildiv M), j floordiv (L ceildiv N))[i % (K ceildiv M), j % (L ceildiv N)]
```
sharding = `[[1], [0]]`
```
d(m, n)[k, l] -> t[n * (K ceildiv N) + k, m * (L ceildiv M) + l]
t[i, j] -> d(j floordiv (L ceildiv M), i floordiv (K ceildiv N))[i % (K ceildiv N), j % (L ceildiv M)]
```
sharding `[[0], [1]] -> [[1], [0]]`
`d1(m, n)` the tensor on device `(m, n)` for sharding `[[0], [1]]`.
`d2(m, n)` the tensor on device `(m, n)` for sharding `[[1], [0]]`.
```
d1(m, n)[k, l] ->
t[m * (K ceildiv M) + k, n * (L ceildiv N) + l] ->
d2(
    (n * (L ceildiv N) + l) floordiv (L ceildiv M),
    (m * (K ceildiv M) + k) floordiv (K ceildiv N)
)[
    (m * (K ceildiv M) + k) % (K ceildiv N),
    (n * (L ceildiv N) + l) % (L ceildiv M)
]
= d2(p, q)[u, v]
```
We want to copy the data between devices in slices/tiles.
What are the source/target tile coordinates?
For a fixed `(m, n, p, q)` what is the range of `(k, l, u, v)`?
TODO

--------------------------------------------------------------

Reshard `KxL` tensor from sharding `[[0], [1]]` to sharding `[[1], [0]]` on a `2x3` mesh.

Device placement on a `2x3` mesh
```
11 12 13 <- devices
21 22 23
```
sharding `[[0], [1]]`
```
tensor axis 1
----------->
+----+----+----+ tensor axis 0 |
| 11 | 12 | 13 |               |
+----+----+----+               |
| 21 | 22 | 23 |               |
+----+----+----+               ↓
```
transform to
sharding `[[1], [0]]`
```
tensor axis 1
----------->
+----+----+ tensor axis 0 |
| 11 | 21 |               |
+----+----+               |
| 12 | 22 |               |
+----+----+               |
| 13 | 23 |               |
+----+----+               ↓
```
```
+-----------------+--------+--------+-----------------+
|                 |                 |                 |
+                 +                 +                 +
|       11        |       12        |       13        |
+                 +                 +                 +
|                 |                 |                 |
+-----------------+--------+--------+-----------------+
|                 |                 |                 |
+                 +                 +                 +
|       21        |       22        |       23        |
+                 +                 +                 +
|                 |                 |                 |
+-----------------+--------+--------+-----------------+

+-----------------+--------+--------+-----------------+
|                          |                          |
+            11            +            21            +
|                          |                          |
+-----------------+--------+--------+-----------------+
|                          |                          |
+            12            +            22            +
|                          |                          |
+-----------------+--------+--------+-----------------+
|                          |                          |
+            13            +            23            +
|                          |                          |
+-----------------+--------+--------+-----------------+

+-----------------+--------+--------+-----------------+
|                 |        |        |                 |
+      11 11      + 12 11  + 12 21  +      13 21      +
|                 |        |        |                 |
+-----------------+--------+--------+-----------------+
|      11 12      | 12 12  | 12 22  |      13 22      |
+-----------------+--------+--------+-----------------+
|      21 12      | 22 12  | 22 22  |      23 22      |
+-----------------+--------+--------+-----------------+
|                 |        |        |                 |
+      21 13      + 22 13  + 22 23  +      23 23      +
|                 |        |        |                 |
+-----------------+--------+--------+-----------------+
```
If `S` and `T` are the source and target shard sizes along some tensor axis,
the cut pattern repeats with a period of `lcm(S, T) = (S*T)/gcd(S, T)`.
TODO

--------------------------------------------------------------

Reshard `6x6` tensor from sharding `[[0], []]` to sharding `[[], [0]]` on a `3` mesh.

unsharded `6x6` tensor
```
11 12 13 14 15 16
21 22 23 24 25 26
31 32 33 34 35 36
41 42 43 44 45 46
51 52 53 54 55 56
61 62 63 64 65 66
```
sharded on a `3` mesh

sharding = `[[0], []]`
```
+-------------------+ mesh axis 0 |
| 11 12 13 14 15 16 |             |
| 21 22 23 24 25 26 |             |
+-------------------+             |
| 31 32 33 34 35 36 |             |
| 41 42 43 44 45 46 |             |
+-------------------+             |
| 51 52 53 54 55 56 |             |
| 61 62 63 64 65 66 |             |
+-------------------+             ↓
```
transform to
sharding = `[[], [0]]`
```
mesh axis 0
----------->
+-------+-------+-------+
| 11 12 | 13 14 | 15 16 |
| 21 22 | 23 24 | 25 26 |
| 31 32 | 33 34 | 35 36 |
| 41 42 | 43 44 | 45 46 |
| 51 52 | 53 54 | 55 56 |
| 61 62 | 63 64 | 65 66 |
+-------+-------+-------+
```
Algorithm:
```mlir
%1 = all_to_all %0 on @mesh mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<2x6xi8> -> tensor<6x2xi8>
```
--------------------------------------------------------------

Reshard `4x4` tensor from sharding `[[0], [1, 2]]` to sharding `[[0, 1], [2]]` on a `2x2x2` mesh.

unsharded `4x4` tensor
```
11 12 13 14
21 22 23 24
31 32 33 34
41 42 43 44
```
sharded on a `2x2x2` mesh

sharding = `[[0], [1, 2]]`
```
mesh axis 2
----------->
+----+----+ mesh axis 1 | mesh axis 0 |
| 11 | 12 |             |             |
| 21 | 22 |             |             |
+----+----+             |             |
| 13 | 14 |             |             |
| 23 | 24 |             |             |
+----+----+             ↓             |
+----+----+                           |
| 31 | 32 |                           |
| 41 | 42 |                           |
+----+----+                           |
| 33 | 34 |                           |
| 43 | 44 |                           |
+----+----+                           ↓
```
transform to
sharding = `[[0, 1], [2]]`
```
mesh axis 2
----------->
+-------+-------+ mesh axis 1 | mesh axis 0 |
| 11 12 | 13 14 |             |             |
+-------+-------+             |             |
| 21 22 | 23 24 |             |             |
+-------+-------+             ↓             |
+-------+-------+                           |
| 31 32 | 33 34 |                           |
+-------+-------+                           |
| 41 42 | 43 44 |                           |
+-------+-------+                           ↓
```
Algorithm:
```mlir
%1 = all_to_all %0 on @mesh mesh_axes = [2] split_axis = 1 concat_axis = 0 : tensor<2x1xi8> -> tensor<1x2xi8>
```
is not enough.

Can be decomposed into
```
[[0], [1, 2]] -> [[0], [2, 1]] -> [[0, 1], [2]]
```

## Decomposition into basis of reshardings

We can decompose each resharding into a sequence of basis reshardings.
This is not communication efficient in terms of minimizing the data
communicated between devices; an efficient approach would be more complicated
to implement.
Each device has to receive at most as much data as the size of its target
shard.

--------------------------------------------------------------

Basis:

* From replicate to split.
  ```
  [[]] -> [[1]]
  ```
  Extract slices without communication.

* From split to replicate.
  ```
  [[0]] -> [[]]
  [[0, 1]] -> [[1]]
  ```
  All-gather along mesh axis 0.

* Swap mesh axes order when assigned to the same tensor axis.
  ```
  [[0, 1]] -> [[1, 0]]
  ```
  Swap contents on devices with the same linear index.

* Move mesh axis to different tensor dimension.
  ```
  [[0], []] -> [[], [0]]
  ```
  All-to-all.

--------------------------------------------------------------

Example decomposition of
```
[[0], [1]] -> [[1], [0]]
```
into
```
[[0], [1]] -> all-gather along mesh axis 1 ->
[[0], []] -> all-to-all along mesh axis 0 ->
[[], [0]] -> extract slice along mesh axis 1 ->
[[1], [0]]
```

--------------------------------------------------------------

Example decomposition of
```
[[3, 2], [], [0, 1]] -> [[0], [1, 2], []]
```
into
```
[[3, 2], [], [0, 1]] -> all-to-all along mesh axis 1 ->
[[3, 2], [1], [0]] -> all-to-all along mesh axis 2 ->
[[3], [1, 2], [0]] -> all-gather along mesh axis 3 ->
[[], [1, 2], [0]] -> all-to-all along mesh axis 0 ->
[[0], [1, 2], []]
```
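--------------------------------------------------------------

The `floordiv`/`ceildiv` index maps derived above for the `[[0], [1]]` and `[[1], [0]]` shardings can be spot-checked mechanically. Below is a minimal Python sketch; the helper names (`ceildiv`, `locate_01`, `locate_10`, `roundtrip_01`) are made up for illustration and are not part of any existing API.

```python
from itertools import product

def ceildiv(a, b):
    return -(-a // b)

# sharding [[0], [1]] on an M x N mesh, K x L tensor:
#   d(m, n)[k, l] = t[m * (K ceildiv M) + k, n * (L ceildiv N) + l]
# Inverse: t[i, j] lives on device (i floordiv sk, j floordiv sl)
# at local index (i % sk, j % sl), with sk = K ceildiv M, sl = L ceildiv N.
def locate_01(i, j, K, L, M, N):
    sk, sl = ceildiv(K, M), ceildiv(L, N)
    return (i // sk, j // sl), (i % sk, j % sl)

# sharding [[1], [0]]: the mesh axes swap roles, so the device
# coordinate is (j floordiv (L ceildiv M), i floordiv (K ceildiv N)).
def locate_10(i, j, K, L, M, N):
    sk, sl = ceildiv(K, N), ceildiv(L, M)
    return (j // sl, i // sk), (i % sk, j % sl)

# Round trip: pushing the inverse map back through the forward map
# must reproduce (i, j) for every element.
def roundtrip_01(i, j, K, L, M, N):
    (m, n), (k, l) = locate_01(i, j, K, L, M, N)
    return (m * ceildiv(K, M) + k, n * ceildiv(L, N) + l)

K, L, M, N = 16, 23, 3, 4
assert all(roundtrip_01(i, j, K, L, M, N) == (i, j)
           for i, j in product(range(K), range(L)))
# Matches the worked check in the text: t[13, 17] -> d(2, 2)[1, 5].
assert locate_01(13, 17, K, L, M, N) == ((2, 2), (1, 5))
# Spot check of the [[1], [0]] map on a different mesh/tensor size.
assert locate_10(9, 11, K=27, L=14, M=5, N=2) == ((3, 0), (9, 2))
print("all index-map checks passed")
```

The exhaustive round trip confirms that the substitution used to derive the inverse map is a true inverse for every tensor element, not just the sampled indices.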