# 'mesh' Dialect

The `mesh` dialect contains a set of attributes, operations and interfaces that
are useful for representing sharding and communication on a device mesh
cluster.

[TOC]

## Collective Communication Operations
There are a number of operations in the Mesh dialect to facilitate
communication between devices in a mesh.
It is assumed that the user is familiar with collective operations.
[Wikipedia](https://en.wikipedia.org/wiki/Collective_operation) has a good
explanation.
The main addition is that the collectives in this dialect have mesh
semantics.

### Device groups
The operation attributes `mesh` and `mesh_axes` specify a list of device mesh
axes that partition the devices into disjoint groups.
The collective operation is performed between devices in the same group.
Devices that have the same coordinates outside of axes `mesh_axes` are in the
same group.
A group is described by its multi-index along the axes outside of `mesh_axes`.
For example, if we have a device mesh of size `2x3x4x5` and the partition mesh
axes list is `[0, 1]`, then devices are partitioned into the groups
`{ { (i, j, k, m) | 0<=i<2, 0<=j<3 } | 0<=k<4, 0<=m<5 }`.
The device groups are indexed by `{ (k, m) | 0<=k<4, 0<=m<5 }`.
Devices (1, 0, 2, 3) and (1, 1, 2, 3) will be in the same group.
Device (1, 0, 2, 4) will be in another group.
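
The grouping rule above can be sketched in plain Python. This is only an
illustrative model, not part of the dialect: the helper names `group_index`
and `device_groups` are hypothetical, and devices are modeled as coordinate
tuples.

```python
from itertools import product


def group_index(device, mesh_axes):
    """Coordinates of `device` along the axes NOT listed in `mesh_axes`."""
    return tuple(c for axis, c in enumerate(device) if axis not in mesh_axes)


def device_groups(mesh_shape, mesh_axes):
    """Map each group index to the list of devices in that group."""
    groups = {}
    for device in product(*(range(n) for n in mesh_shape)):
        groups.setdefault(group_index(device, mesh_axes), []).append(device)
    return groups


# The 2x3x4x5 mesh from the text, partitioned along axes [0, 1].
groups = device_groups((2, 3, 4, 5), mesh_axes=[0, 1])
print(len(groups))          # 20 groups, one per (k, m)
print(len(groups[(2, 3)]))  # 6 devices, one per (i, j)
```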
Some collective operations like all-to-all and all-gather care about the
order of devices.
The order of devices in a device group is induced by the order of axes in
`mesh_axes`.
The axes are ordered from outer to inner.
If we have an axis list `[3, 1]` then device `(i, 1, k, 0)` will precede
both devices `(i, 0, k, 1)` and `(i, 2, k, 0)`.
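
As a rough illustration of the outer-to-inner ordering, the following Python
sketch sorts devices of one group by their coordinates along `mesh_axes`. The
helper `in_group_order_key` is a hypothetical name used only for this example,
and `i = 0`, `k = 0` are arbitrary fixed values.

```python
def in_group_order_key(device, mesh_axes):
    """Sort key: coordinates along `mesh_axes`, first listed axis outermost."""
    return tuple(device[axis] for axis in mesh_axes)


# i = 0 and k = 0 are fixed so all three devices lie in the same group.
devices = [(0, 0, 0, 1), (0, 2, 0, 0), (0, 1, 0, 0)]
ordered = sorted(devices, key=lambda d: in_group_order_key(d, [3, 1]))
print(ordered)  # [(0, 1, 0, 0), (0, 2, 0, 0), (0, 0, 0, 1)]
```

Because axis 3 is listed first, it varies slowest: every device with
coordinate 0 on axis 3 comes before any device with coordinate 1 on axis 3.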

### In-group Device
Some operations like `broadcast`, `scatter` and `send` specify devices in each
device group.
These devices are represented with their multi-index over the mesh axes that
are not constant within a device group.
These are the axes specified by the `mesh_axes` attribute.

For example, on a 3D mesh an operation with `mesh_axes = [0, 2]` would specify
an in-group device with `(i, j)`. Then for each group with index `g` on the
second axis, the in-group device would be `(i, g, j)`.
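
The mapping from an in-group multi-index to full mesh coordinates might be
sketched as below. The function `full_device_index` is hypothetical, and the
sketch assumes that the components of the in-group index follow the order of
`mesh_axes` while the group index lists the remaining axes in increasing order.

```python
def full_device_index(in_group_index, group_idx, mesh_axes, mesh_rank):
    """Combine an in-group multi-index with a group index into mesh coordinates."""
    # Coordinates on the axes listed in `mesh_axes`.
    in_group = dict(zip(mesh_axes, in_group_index))
    # Coordinates on the remaining axes, consumed in increasing axis order.
    outside = iter(group_idx)
    return tuple(
        in_group[axis] if axis in in_group else next(outside)
        for axis in range(mesh_rank)
    )


# 3D mesh, mesh_axes = [0, 2], in-group device (i, j) = (1, 3), group g = 2.
print(full_device_index((1, 3), (2,), mesh_axes=[0, 2], mesh_rank=3))  # (1, 2, 3)
```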

### Purity
Collectives that involve the whole device group to perform a single operation
are pure. The exceptions are `send` and `recv`.

There is an assumption that the execution is SPMD.
Not only does each process run the same program, but at the point of
execution of a collective operation, all processes are in a coherent state.
All compiler transformations must be consistent.
Collective operations in the IR that may correspond to the same runtime
collective operation must be transformed in a consistent manner.
For example, if a collective operation is optimized out, then it must also
not appear in any path of execution on any process.
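
One way to picture the consistency requirement is a toy model in which each
process records the sequence of collectives it executes. The sketch below is
an assumption made purely for illustration (the trace representation and the
helper `collectives_consistent` are not part of MLIR): a transformation is
acceptable only if all processes still agree on that sequence.

```python
def collectives_consistent(traces):
    """True if every process issues the same sequence of collectives."""
    first = traces[0]
    return all(trace == first for trace in traces)


# Hypothetical per-process traces of collective operations, in execution order.
consistent = [["all_reduce", "all_gather"], ["all_reduce", "all_gather"]]
# Process 1 had its all_gather removed while process 0 kept it.
inconsistent = [["all_reduce", "all_gather"], ["all_reduce"]]

print(collectives_consistent(consistent))    # True
print(collectives_consistent(inconsistent))  # False
```

In the second case process 0 would wait in `all_gather` for a peer that never
participates, which is the situation the requirement rules out.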

Having the operations as `Pure` implies that if an interpreter is to execute
the IR containing the `mesh` collectives, all processes would execute the same
line when they reach a pure collective operation.
This requirement stems from the need to be compatible with general optimization
passes like dead code elimination and common sub-expression elimination.

## Operations

[include "Dialects/MeshOps.md"]

## Attributes

[include "Dialects/MeshAttrs.md"]