xref: /llvm-project/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (revision 79eb406a67fe08458548289da72cda18248a9313)
1//===-- MeshOps.td - Mesh dialect operation definitions ----*- tablegen -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_MESH_IR_MESHOPS_TD
10#define MLIR_DIALECT_MESH_IR_MESHOPS_TD
11
12include "mlir/Dialect/Mesh/IR/MeshBase.td"
13include "mlir/Dialect/Shape/IR/ShapeBase.td"
14include "mlir/Interfaces/DestinationStyleOpInterface.td"
15include "mlir/Interfaces/InferTypeOpInterface.td"
16include "mlir/Interfaces/SideEffectInterfaces.td"
17include "mlir/IR/BuiltinTypes.td"
18include "mlir/IR/CommonAttrConstraints.td"
19include "mlir/IR/CommonTypeConstraints.td"
20include "mlir/IR/OpAsmInterface.td"
21include "mlir/IR/SymbolInterfaces.td"
22
23//===----------------------------------------------------------------------===//
24// Mesh operations.
25//===----------------------------------------------------------------------===//
26
// Base class for all operations in the Mesh dialect.
class Mesh_Op<string mnemonic, list<Trait> traits = []> :
    Op<Mesh_Dialect, mnemonic, traits> {
}
30
def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol]> {
  let summary = "Description of a device/process mesh.";
  let description = [{
    The mesh.mesh operation is a symbol operation that identifies a specific
    mesh. The operation has two attributes:

    1. `sym_name`: This attribute uniquely identifies the name of the mesh.
    This name serves as a symbolic reference to the mesh throughout
    the MLIR module, allowing for consistent referencing and easier debugging.

    2. `shape`: This attribute represents the shape of the device mesh.
    It uses the same notation as a tensor shape, also allowing for dynamic
    dimensions.
    This flexibility allows for dynamic device assignment or configurations
    where the exact number of devices might not be determined during compile
    time.
    For example `2x?x4`.

    Example:
    ```
    // A device mesh with 3 axes, the total device number is 4 * 8 * 12
    // The dimension sizes are 4, 8, 12
    mesh.mesh @mesh0(shape = 4x8x12)

    // A device mesh with 2 axes, the total device number is unknown
    // The first dimension size is 4 and the second is unknown
    mesh.mesh @mesh1(shape = 4x?)

    // A device mesh with 2 axes, the total device number is unknown
    // The first dimension size is unknown and the second is 4
    mesh.mesh @mesh2(shape = ?x4)

    // A device mesh with 2 axes, the number of devices along both axes
    // is unknown
    mesh.mesh @mesh3(shape = ?x?)
    ```
  }];
  let arguments = (ins
    SymbolNameAttr:$sym_name,
    DenseI64ArrayAttr:$shape
  );
  let assemblyFormat = [{
    $sym_name `(` `shape` `=` custom<DimensionList>($shape) `)`
      attr-dict
  }];
  let extraClassDeclaration = [{
    int64_t getRank() { return getShape().size(); }
  }];
  let hasVerifier = 1;
}
81
def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
    Pure,
    DeclareOpInterfaceMethods<SymbolUserOpInterface>,
    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
  ]> {
  let summary = "Get the shape of the mesh.";
  let description = [{
    Returns the size of the referenced mesh along each mesh axis listed in
    `axes`, one `index` result per requested axis.
    NOTE(review): presumably an empty `axes` selects all mesh axes, mirroring
    `mesh.process_multi_index` — confirm against the verifier/builders.
  }];
  let arguments = (ins
    FlatSymbolRefAttr:$mesh,
    DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
  );

  let results = (outs
    Variadic<Index>:$result
  );

  let assemblyFormat = [{
    $mesh (`axes` `=` $axes^)?
    attr-dict `:` type($result)
  }];

  let builders = [
    OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
    OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh, "ArrayRef<MeshAxis>":$axes)>,
    OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
  ];
}
108
def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
  Pure,
  DeclareOpInterfaceMethods<SymbolUserOpInterface>,
  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
]> {
  let summary = "Get the multi index of current device along specified mesh axes.";
  let description = [{
    It is used in the SPMD format of IR.
    The `axes` must be non-negative and less than the total number of mesh axes.
    If the axes are empty then get the index along all axes.
  }];
  let arguments = (ins
    FlatSymbolRefAttr:$mesh,
    DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
  );
  let results = (outs
    Variadic<Index>:$result
  );
  let assemblyFormat = [{
    `on` $mesh (`axes` `=` $axes^)?
    attr-dict `:` type($result)
  }];
  let builders = [
    OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
    OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
  ];
}
136
def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
  Pure,
  DeclareOpInterfaceMethods<SymbolUserOpInterface>,
  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
]> {
  let summary = "Get the linear index of the current device.";
  let description = [{
    Example:
    ```
    %idx = mesh.process_linear_index on @mesh : index
    ```
    if `@mesh` has shape `(10, 20, 30)`, a device with multi
    index `(1, 2, 3)` will have linear index `3 + 30*2 + 20*30*1`.
    That is, linearization is row-major: the last mesh axis varies fastest.
  }];
  let arguments = (ins FlatSymbolRefAttr:$mesh);
  let results = (outs Index:$result);
  let assemblyFormat = "`on` $mesh attr-dict `:` type($result)";
  let builders = [
    OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>
  ];
}
158
def Mesh_NeighborsLinearIndicesOp : Mesh_Op<"neighbors_linear_indices", [
  Pure,
  DeclareOpInterfaceMethods<SymbolUserOpInterface>,
  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
]> {
  let summary =
      "For given mesh index get the linear indices of the direct neighbor processes along the given split.";
  let description = [{
    Example:
    ```
    mesh.mesh @mesh0(shape = 10x20x30)
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %c3 = arith.constant 3 : index
    %idx:2 = mesh.neighbors_linear_indices on @mesh0[%c1, %c2, %c3] split_axes = [1] : index, index
    ```
    The above returns two indices, `633` and `693`, which correspond to the
    linear index of the previous process `(1, 1, 3)`, and the next process
    `(1, 3, 3)` along the split axis `1`.

    A negative value is returned if there is no neighbor in the respective
    direction along the given `split_axes`.
  }];
  let arguments = (ins FlatSymbolRefAttr:$mesh,
                       Variadic<Index>:$device,
                       Mesh_MeshAxesAttr:$split_axes);
  let results = (outs Index:$neighbor_down, Index:$neighbor_up);
  let assemblyFormat =  [{
      `on` $mesh `[` $device `]`
      `split_axes` `=` $split_axes
      attr-dict `:` type(results)
  }];
}
192
193//===----------------------------------------------------------------------===//
194// Sharding operations.
195//===----------------------------------------------------------------------===//
196
def Mesh_ShardingOp : Mesh_Op<"sharding", [
    Pure,
    AttrSizedOperandSegments,
    DeclareOpInterfaceMethods<SymbolUserOpInterface>,
    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
  ]> {
  let summary = "Define a sharding of a tensor.";
  let description = [{
    The MeshSharding specifies how a tensor is sharded and distributed across the
    process mesh. It is typically used in a `mesh.shard` operation.
    The operation has the following attributes and operands:

    1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device
    mesh where the distributed tensor is placed. The symbol must resolve to a
    `mesh.mesh` operation.

    2. `split_axes`: is an array composed of int64_t sub-arrays. The outer array's
    maximum size is the `rank` of the related tensor. For the i-th sub-array, if
    its value is [x, y], it indicates that the tensor's i-th dimension is split
    along the x and y axes of the device mesh.

    3. [Optional] `partial_axes`: if not empty, this signifies that the tensor is partial
    one along the specified mesh axes. An all-reduce should be applied to obtain
    the complete tensor, with reduction type being specified by `partial_type`.

    4. [Optional] `partial_type`: indicates the reduction type of the possible
    all-reduce op; its values are those of the `ReductionKind` enum, except that
    `generic` is not an allowed value inside a shard attribute.

    5. [Optional] Sizes of halos to be added for each sharded tensor dimension.
    `halo_sizes` is provided as a flattened 1d array of i64s, 2 values for each
    sharded dimension. `halo_sizes = [1, 2]` means that the first sharded dimension
    gets an additional halo of size 1 at its start and a halo of size 2 at its
    end. `halo_sizes = [1, 2, 2, 3]` defines halos for the first 2
    sharded dimensions e.g. the first sharded dimension gets `[1,2]` halos and the
    seconds gets `[2,3]` halos. `?` indicates dynamic halo sizes.

    6. [Optional] Offsets for each shard and sharded tensor dimension.
    `sharded_dims_offsets` is provided as a flattened 1d array of i64s. For each
    sharded tensor dimension the offsets (starting index) of all shards in that
    dimension and an additional value for the end of the last shard are provided.
    For a 1d sharding this means that position `i` has the exclusive prefix sum for
    shard `i`, and since only contiguous sharding is supported, its inclusive prefix
    sum is at position `i+1`.

    Assuming a 3d-tensor of shape 32x32x32 with the first 2 dimensions being sharded,
    `sharded_dims_offsets` = [0, 24, 32, 0, 20, 32] means that the first device of
    the device-mesh will get a shard of shape 24x20x32 and the second device will get
    a shard of shape 8x12x32. `?` indicates dynamic shard dimensions.

    `halo_sizes` and `sharded_dims_offsets` are mutually exclusive.

    Examples:

    ```
    mesh.mesh @mesh0(shape = 2x2x4)
    mesh.mesh @mesh1d_4(shape = 4)

    // The tensor is fully replicated on @mesh0.
    // Currently, there must be at least one sub-array present in axes, even
    // if it's empty. Otherwise, a parsing error will occur.
    %sharding0 = mesh.sharding @mesh0 split_axes = [[]]

    // The tensor is sharded on the first dimension along axis 0 of @mesh0
    %sharding1 = mesh.sharding @mesh0 split_axes = [[0]]

    // The tensor is sharded on its first dimension along axis 0 of @mesh0 and
    // it is also a partial_sum along mesh axis 1.
    %sharding2 = mesh.sharding @mesh0 split_axes = [[0]] partial = sum[1]

    // The tensor is sharded on its first dimension along axis 0 of @mesh0 and
    // it is also a partial_max along mesh axis 1.
    %sharding3 = mesh.sharding @mesh0 split_axes = [[0]] partial = max[1]

    // Could be used for a mesh.shard op
    %sharded0 = mesh.shard %arg0 to %sharding3 : tensor<4x8xf32>

    // The tensor is sharded on its first dimension along axis 0 of @mesh0
    // and it has halo-sizes of 1 and 2 on the sharded dim.
    %halo_sharding = mesh.sharding @mesh0 split_axes = [[0]] halo_sizes = [1, 2]
    %sharded1 = mesh.shard %arg0 to %halo_sharding : tensor<4x8xf32>

    // The tensor is sharded on its second dimension along axis 0 of @mesh1d_4
    // and it has pre-defined shard sizes. The shards of the devices will have
    // the following shapes: [4x2, 4x3, 4x4, 4x5]
    %sharding4 = mesh.sharding @mesh1d_4 split_axes = [[], [0]] sharded_dims_offsets = [0, 2, 5, 9, 14]
    %sharded2 = mesh.shard %arg0 to %sharding4 : tensor<4x14xf32>
    ```
  }];

  let arguments = (ins
    FlatSymbolRefAttr:$mesh,
    Mesh_MeshAxesArrayAttr:$split_axes,
    OptionalAttr<Mesh_MeshAxesAttr>:$partial_axes,
    OptionalAttr<Mesh_ReductionKindAttr>:$partial_type,
    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_sharded_dims_offsets,
    Variadic<I64>:$dynamic_sharded_dims_offsets,
    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes,
    Variadic<I64>:$dynamic_halo_sizes
  );
  let results = (outs
    Mesh_Sharding:$result
  );
  let assemblyFormat = [{
    $mesh
    `split_axes` `=` $split_axes
    (`partial` `=` $partial_type $partial_axes^)?
    (`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
    (`sharded_dims_offsets` `=` custom<DynamicIndexList>($dynamic_sharded_dims_offsets, $static_sharded_dims_offsets)^)?
    attr-dict `:` type($result)
  }];
  let builders = [
    OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
                   "ArrayRef<MeshAxesAttr>":$split_axes,
                   "ArrayRef<MeshAxis>":$partial_axes,
                   "mesh::ReductionKind":$partial_type,
                   CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes,
                   CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_offsets)>,
    OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
                   "ArrayRef<MeshAxesAttr>":$split_axes)>,
    OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
                   "ArrayRef<MeshAxesAttr>":$split_axes,
                   "::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes,
                   "::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_offsets)>,
    OpBuilder<(ins "mlir::mesh::MeshSharding":$from)>
  ];
  let hasVerifier = 1;
  let hasCanonicalizer = 1;
}
326
def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
  let summary = "Get the shard shape of a given process/device.";
  let description = [{
    The device/process id is a linearized id of the device/process in the mesh.
    This operation might be used during spmdization when the shard shape depends
    on (non-constant) values used in `mesh.sharding`.

    Operands and attributes:
    - `shape`: the static shape of the tensor being sharded.
    - `sharding`: the sharding to apply, typically produced by a
      `mesh.sharding` op.
    - `device`: the linearized id of the device/process whose shard shape is
      computed.
  }];
  let arguments = (ins
    DenseI64ArrayAttr:$shape,
    Mesh_Sharding:$sharding,
    Index:$device
  );
  let results = (outs Variadic<Index>:$result);
  let assemblyFormat = [{
      custom<DimensionList>($shape) $sharding $device attr-dict `:` type($result)
  }];
  let builders = [
    OpBuilder<(ins "ArrayRef<int64_t>":$shape, "Value":$sharding, "Value":$device)>
  ];
}
347
def Mesh_ShardOp : Mesh_Op<"shard", [
    Pure,
    AllTypesMatch<["result", "src"]>,
    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
  ]> {
  let summary = "Annotate on how a tensor is sharded across a mesh.";
  let description = [{
    The mesh.shard operation is designed to specify and guide the sharding
    behavior of a tensor value across a mesh topology. This operation has two
    operands and one optional attribute:

    1. `src`: This operand represents the tensor value that needs to be
    annotated for sharding.

    2. `sharding`: This operand is of the sharding type, which is the core data
    structure to represent distribution of a tensor on a mesh. It is typically
    defined by a `mesh.sharding` operation.

    3. `annotate_for_users`: A unit attribute addressing the scenario when a
    tensor's sharding annotation differs based on its context of use (either as
    a result or an operand). If specified, the sharding pertains to specific
    users of the tensor value, indicating how it should be considered when used
    as an operand in subsequent operations. If not, the sharding applies to the
    operation that defines the tensor value.

    Example:
    ```
    func.func @only_result_annotated(%arg0 : tensor<4x8xf32>) -> () {
      %sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
      %0 = mesh.shard %arg0 to %sharding : tensor<4x8xf32>
      ...
    }

    func.func @only_operand_annotated(%arg0 : tensor<4x8xf32>) -> () {
      %sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
      %0 = mesh.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
      ...
    }

    func.func @two_operands_annotated(%arg0 : tensor<4x8xf32>, %arg1 : tensor<16x8xf32>) -> () {
      %sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
      %0 = mesh.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
      %1 = mesh.shard %arg1 to %sharding annotate_for_users : tensor<16x8xf32>
      ...
    }

    // The first mesh.shard op applies to %arg0, the second mesh.shard op
    // applies for the operand of op0, the third mesh.shard op applies for the
    // operand of op2
    func.func @both_result_and_multi_operands_annotated(
        %arg0 : tensor<4x8xf32>) -> () {
      %sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
      %0 = mesh.shard %arg0 to %sharding : tensor<4x8xf32>
      %sharding1 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
      %1 = mesh.shard %0 to %sharding1 annotate_for_users : tensor<4x8xf32>
      %sharding2 = mesh.sharding @mesh0 split_axes = [[2]] : !mesh.sharding
      %2 = mesh.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32>
      "op0"(%1) : ...
      "op1"(%2) : ...
      ...
    }
    ```

    The following usages are undefined:
    ```
    func.func @annotate_on_same_result_with_different_sharding(
        %arg0 : tensor<4x8xf32>) -> () {
      %sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
      %sharding2 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
      %0 = mesh.shard %arg0 to %sharding1 : tensor<4x8xf32>
      %1 = mesh.shard %0 to %sharding2 : tensor<4x8xf32>
      ...
    }

    func.func @annotate_on_same_result_same_value_with_different_sharding(
        %arg0 : tensor<4x8xf32>) -> () {
      %sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
      %sharding2 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
      %0 = mesh.shard %arg0 to %sharding1 : tensor<4x8xf32>
      %1 = mesh.shard %arg0 to %sharding2 : tensor<4x8xf32>
      ...
    }

    func.func @annotate_on_same_operand_with_different_sharding(
        %arg0 : tensor<4x8xf32>) -> () {
      %sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
      %sharding2 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
      %0 = mesh.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32>
      %1 = mesh.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32>
      ...
    }

    func.func @result_annotated_after_operand(
        %arg0 : tensor<4x8xf32>) -> () {
      %sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
      %sharding2 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
      %0 = mesh.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32>
      %1 = mesh.shard %0 to %sharding2 : tensor<4x8xf32>
      ...
    }
    ```
  }];
  let arguments = (ins
    AnyRankedTensor:$src,
    Mesh_Sharding:$sharding,
    UnitAttr:$annotate_for_users
  );
  let results = (outs
    AnyRankedTensor:$result
  );
  let assemblyFormat = [{
    $src `to` $sharding
      (`annotate_for_users` $annotate_for_users^)?
      attr-dict `:` type($result)
  }];
}
464
465//===----------------------------------------------------------------------===//
466// collective communication ops
467//===----------------------------------------------------------------------===//
468
// Base class for the collective communication ops below. It appends the
// symbol-user and asm-result-name interfaces to the op's traits and provides
// `commonArgs` (the referenced `mesh` symbol plus the optional `mesh_axes`
// the collective operates along) for derived ops to concatenate into their
// argument lists.
class Mesh_CollectiveCommunicationOpBase<
    string mnemonic, list<Trait> traits = []> :
    Mesh_Op<mnemonic,
      !listconcat(traits,
        [
          DeclareOpInterfaceMethods<SymbolUserOpInterface>,
          DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
        ])> {
  dag commonArgs = (ins
    FlatSymbolRefAttr:$mesh,
    DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$mesh_axes
  );
}
482
def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
    Pure,
    SameOperandsAndResultElementType,
    SameOperandsAndResultRank
  ]> {
  let summary = "All-gather over a device mesh.";
  let description = [{
    Gathers along the `gather_axis` tensor axis.

    Example:
    ```mlir
    mesh.mesh @mesh0(shape = 2x2)
    ...
    %1 = mesh.all_gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1
      : tensor<2x2xi8> -> tensor<2x4xi8>
    ```
    Input:
    ```
                     +-------+-------+
    device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
                     |  3  4 |  7  8 |
                     +-------+-------+
    device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
                     | 11 12 | 15 16 |
                     +-------+-------+
    ```
    Result:
    ```
    gather tensor
    axis 1
    ------------>
    +-------------+
    |  1  2  5  6 | <- devices (0, 0) and (0, 1)
    |  3  4  7  8 |
    +-------------+
    |  9 10 13 14 | <- devices (1, 0) and (1, 1)
    | 11 12 15 16 |
    +-------------+
    ```
  }];
  let arguments = !con(commonArgs, (ins
    AnyNon0RankedTensor:$input,
    IndexAttr:$gather_axis
  ));
  let results = (outs
    AnyNon0RankedTensor:$result
  );
  let assemblyFormat = [{
    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `gather_axis` `=` $gather_axis
    attr-dict `:` type($input) `->` type($result)
  }];
  let hasCanonicalizer = 1;
}
536
def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
    Pure,
    SameOperandsAndResultShape]> {
  let summary = "All-reduce over a device mesh.";
  let description = [{
    The accumulation element type is specified by the result type and
    it does not need to match the input element type.
    The input element is converted to the result element type before
    performing the reduction.

    Attributes:
    `reduction`: Indicates the reduction method.

    Example:
    ```
    %1 = mesh.all_reduce %0 on @mesh0 mesh_axes = [1, 0] reduction = <max>
      : tensor<3x4xf32> -> tensor<3x4xf64>
    ```
  }];
  let arguments = !con(commonArgs, (ins
    AnyRankedTensor:$input,
    DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction
  ));
  let results = (outs
    AnyRankedTensor:$result
  );
  let assemblyFormat = [{
    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? (`reduction` `=` $reduction^)?
    attr-dict `:` type($input) `->` type($result)
  }];
  let hasCanonicalizer = 1;
  let builders = [
    OpBuilder<(ins "Value":$input, "StringRef":$mesh,
      "ArrayRef<MeshAxis>":$meshAxes, "ReductionKind":$reduction)>
  ];
}
573
def Mesh_AllSliceOp : Mesh_CollectiveCommunicationOpBase<"all_slice", [
    Pure,
    SameOperandsAndResultElementType,
    SameOperandsAndResultRank
  ]> {
  let summary = "All-slice over a device mesh. This is the inverse of all-gather.";
  let description = [{
    Slice along the `slice_axis` tensor axis.
    This operation can be thought of as the inverse of all-gather.
    Technically, it is not required that all processes have the same input tensor.
    Each process will slice a piece of its local tensor based on its in-group device index.
    The operation does not communicate data between devices.

    Example:
    ```mlir
    mesh.mesh @mesh0(shape = 2x2)
    ...
    %1 = mesh.all_slice %0 on @mesh0 mesh_axes = [1] slice_axis = 1
      : tensor<2x4xi8> -> tensor<2x2xi8>
    ```
    Input:
    ```
    +-------------+
    |  1  2  5  6 | <- devices (0, 0) and (0, 1)
    |  3  4  7  8 |
    +-------------+
    |  9 10 13 14 | <- devices (1, 0) and (1, 1)
    | 11 12 15 16 |
    +-------------+
    ```
    Result:
    ```
    slice tensor
    axis 1
    ------------>
                     +-------+-------+
    device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
                     |  3  4 |  7  8 |
                     +-------+-------+
    device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
                     | 11 12 | 15 16 |
                     +-------+-------+
    ```
  }];
  let arguments = !con(commonArgs, (ins
    AnyNon0RankedTensor:$input,
    IndexAttr:$slice_axis
  ));
  let results = (outs
    AnyNon0RankedTensor:$result
  );
  let assemblyFormat = [{
    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `slice_axis` `=` $slice_axis
    attr-dict `:` type($input) `->` type($result)
  }];
  let hasCanonicalizer = 1;
  let builders = [
    OpBuilder<(ins "Value":$input, "MeshOp":$mesh, "ArrayRef<MeshAxis>":$meshAxes, "int64_t":$sliceAxis)>,
    OpBuilder<(ins "Type":$result_type, "Value":$input, "StringRef":$mesh, "ArrayRef<MeshAxis>":$meshAxes, "int64_t":$sliceAxis)>
  ];
}
635
def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
    Pure,
    SameOperandsAndResultElementType,
    SameOperandsAndResultRank]> {
  let summary = "All-to-all over a device mesh.";
  let description = [{
    Performs an all-to-all on tensor pieces split along `split_axis`.
    The resulting pieces are concatenated along `concat_axis` on each device.

    Example:
    ```
    mesh.mesh @mesh0(shape = 3)
    ...
    %1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0]
      split_axis = 0 concat_axis = 0
      : tensor<3x2xi8> -> tensor<3x2xi8>
    ```
    Input:
    ```
     device  device  device
     (0)     (1)     (2)
    +-------+-------+-------+  | split and concat along
    | 11 12 | 21 22 | 31 32 |  | tensor axis 0
    | 13 14 | 23 24 | 33 34 |  ↓
    | 15 16 | 25 26 | 35 36 |
    +-------+-------+-------+
    ```
    Result:
    ```
     device  device  device
     (0)     (1)     (2)
    +-------+-------+-------+
    | 11 12 | 13 14 | 15 16 |
    | 21 22 | 23 24 | 25 26 |
    | 31 32 | 33 34 | 35 36 |
    +-------+-------+-------+
    ```
  }];
  let arguments = !con(commonArgs, (ins
    AnyNon0RankedTensor:$input,
    IndexAttr:$split_axis,
    IndexAttr:$concat_axis
  ));
  let results = (outs
    AnyNon0RankedTensor:$result
  );
  let assemblyFormat = [{
    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
    `split_axis` `=` $split_axis
    `concat_axis` `=` $concat_axis
    attr-dict `:` type($input) `->` type($result)
  }];
  let hasCanonicalizer = 1;
}
690
def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
    Pure,
    AllShapesMatch<["input", "result"]>,
    AllElementTypesMatch<["input", "result"]>
  ]> {
  let summary = "Broadcast over a device mesh.";
  let description = [{
    Broadcast the tensor on `root` to all devices in each respective group.
    The operation broadcasts along mesh axes `mesh_axes`.
    The `root` device specifies the in-group multi-index that is broadcast to
    all other devices in the group.

    Example:
    ```
    mesh.mesh @mesh0(shape = 2x2)

    %1 = mesh.broadcast %0 on @mesh0
      mesh_axes = [0]
      root = [0]
      : (tensor<2xi8>) -> tensor<2xi8>
    ```

    Input:
    ```
                     +-------+-------+                   | broadcast
    device (0, 0) -> |  1  2 |  3  4 | <- device (0, 1)  | along axis 0
                     +-------+-------+                   ↓
    device (1, 0) -> |       |       | <- device (1, 1)
                     +-------+-------+
    ```

    Result:
    ```
                     +-------+-------+
    device (0, 0) -> |  1  2 |  3  4 | <- device (0, 1)
                     +-------+-------+
    device (1, 0) -> |  1  2 |  3  4 | <- device (1, 1)
                     +-------+-------+
    ```
  }];
  let arguments = !con(commonArgs, (ins
    AnyRankedTensor:$input,
    DenseI64ArrayAttr:$root,
    Variadic<Index>:$root_dynamic
  ));
  let results = (outs
    AnyRankedTensor:$result
  );
  let assemblyFormat = [{
    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
    `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
    attr-dict `:` functional-type(operands, results)
  }];
  let hasCanonicalizer = 1;
}
746
def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
    Pure,
    AllRanksMatch<["input", "result"]>,
    AllElementTypesMatch<["input", "result"]>
  ]> {
  let summary = "Gather over a device mesh.";
  let description = [{
    Gathers on device `root` along the `gather_axis` tensor axis.
    `root` specifies the coordinates of a device along `mesh_axes`.
    It uniquely identifies the root device for each device group.
    The result tensor on non-root devices is undefined.
    Using it will result in undefined behavior.

    Example:
    ```mlir
    mesh.mesh @mesh0(shape = 2x2)
    ...
    %1 = mesh.gather %0 on @mesh0 mesh_axes = [1]
      gather_axis = 1 root = [1]
      : (tensor<2x2xi8>) -> tensor<2x4xi8>
    ```
    Input:
    ```
                      gather tensor
                      axis 1
                      ------------>
                     +-------+-------+
    device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
                     |  3  4 |  7  8 |
                     +-------+-------+
    device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
                     | 11 12 | 15 16 |
                     +-------+-------+
    ```
    Result:
    ```
    +-------------+
    |  1  2  5  6 | <- device (0, 1)
    |  3  4  7  8 |
    +-------------+
    |  9 10 13 14 | <- device (1, 1)
    | 11 12 15 16 |
    +-------------+
    ```
    Devices `(0, 0)` and `(1, 0)` have undefined result.
  }];
  let arguments = !con(commonArgs, (ins
    AnyNon0RankedTensor:$input,
    IndexAttr:$gather_axis,
    DenseI64ArrayAttr:$root,
    Variadic<Index>:$root_dynamic
  ));
  let results = (outs
    AnyNon0RankedTensor:$result
  );
  let assemblyFormat = [{
    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
    `gather_axis` `=` $gather_axis
    `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
    attr-dict `:` functional-type(operands, results)
  }];
  let hasCanonicalizer = 1;
}
810
def Mesh_RecvOp : Mesh_CollectiveCommunicationOpBase<"recv", [
    AllShapesMatch<["input", "result"]>,
    AllElementTypesMatch<["input", "result"]>
  ]> {
  let summary = "Receive over a device mesh.";
  let description = [{
    Receive from a device within a device group.
    The optional `source` gives the in-group multi-index of the sending
    device; dynamic coordinates are passed via `source_dynamic`.
  }];
  let arguments = !con(commonArgs, (ins
    AnyNon0RankedTensor:$input,
    OptionalAttr<DenseI64ArrayAttr>:$source,
    Variadic<Index>:$source_dynamic
  ));
  let results = (outs
    AnyRankedTensor:$result
  );
  let assemblyFormat = [{
    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
    (`source` `=` custom<DynamicIndexList>($source_dynamic, $source)^)?
    attr-dict `:` functional-type(operands, results)
  }];
  let hasCanonicalizer = 1;
}
834
def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
    Pure,
    // Only shapes must match; element types may differ because the result
    // element type is the accumulation type (see description below).
    AllShapesMatch<["input", "result"]>
  ]> {
  let summary = "Reduce over a device mesh.";
  let description = [{
    Reduces on device `root` within each device group.
    `root` specifies the coordinates of a device along `mesh_axes`.
    It uniquely identifies the root device within its device group.
    The accumulation element type is specified by the result type and
    it does not need to match the input element type.
    The input element is converted to the result element type before
    performing the reduction.

    Attributes:
    `reduction`: Indicates the reduction method.

    Example:
    ```
    %1 = mesh.reduce %0 on @mesh0 mesh_axes = [1, 0]
      reduction = <max> root = [2, 3]
      : (tensor<3x4xf32>) -> tensor<3x4xf64>
    ```
  }];
  let arguments = !con(commonArgs, (ins
    AnyRankedTensor:$input,
    // When the `reduction` attribute is omitted, sum is used.
    DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction,
    // Static root coordinates; dynamic entries are supplied via $root_dynamic
    // (printed/parsed together through DynamicIndexList below).
    DenseI64ArrayAttr:$root,
    Variadic<Index>:$root_dynamic
  ));
  let results = (outs
    AnyRankedTensor:$result
  );
  let assemblyFormat = [{
    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
    (`reduction` `=` $reduction^)?
    `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
    attr-dict `:` functional-type(operands, results)
  }];
  let hasCanonicalizer = 1;
}
876
def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
    Pure,
    SameOperandsAndResultRank]> {
  let summary = "Reduce-scatter over a device mesh.";
  // The example below is kept consistent with the diagram: the result values
  // shown are element-wise sums over mesh axis 1 of 2x2 per-device inputs,
  // and the scatter-axis dimension (2) divides evenly by the group size (2).
  let description = [{
    After the reduction, the result is scattered within each device group.
    The tensor is split along `scatter_axis` and the pieces distributed
    across the device group.
    Example:
    ```
    mesh.mesh @mesh0(shape = 2x2)
    ...
    %1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1]
      reduction = <sum> scatter_axis = 0
      : tensor<2x2xi8> -> tensor<1x2xi64>
    ```
    Input:
    ```
                              device
                              (0, 1)
                                 ↓
                     +-------+-------+  | scatter tensor
    device (0, 0) -> |  1  2 |  5  6 |  | axis 0
                     |  3  4 |  7  8 |  ↓
                     +-------+-------+
    device (1, 0) -> |  9 10 | 13 14 |
                     | 11 12 | 15 16 |
                     +-------+-------+
                                 ↑
                              device
                              (1, 1)
    ```
    Result:
    ```
    +-------+
    |  6  8 | <- devices (0, 0)
    +-------+
    | 10 12 | <- devices (0, 1)
    +-------+
    | 22 24 | <- devices (1, 0)
    +-------+
    | 26 28 | <- devices (1, 1)
    +-------+
    ```
  }];
  let arguments = !con(commonArgs, (ins
    AnyNon0RankedTensor:$input,
    // When the `reduction` attribute is omitted, sum is used.
    DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction,
    IndexAttr:$scatter_axis
  ));
  let results = (outs
    AnyRankedTensor:$result
  );
  let assemblyFormat = [{
    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
    (`reduction` `=` $reduction^)?
    `scatter_axis` `=` $scatter_axis
    attr-dict `:` type($input) `->` type($result)
  }];
  let hasCanonicalizer = 1;
}
938
def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
    Pure,
    AllRanksMatch<["input", "result"]>,
    AllElementTypesMatch<["input", "result"]>
  ]> {
  let summary = "Scatter over a device mesh.";
  // The ASCII diagrams below had their device-pointer arrow lines corrupted;
  // they are restored here (↓/↑ mark the columns of devices (0, 1)/(1, 1)).
  let description = [{
    For each device group split the input tensor on the `root` device along
    axis `scatter_axis` and scatter the parts across the group devices.

    Example:
    ```
    mesh.mesh @mesh0(shape = 2x2)
    %1 = mesh.scatter %0 on @mesh0 mesh_axes = [0]
      scatter_axis = 0
      root = [1]
      : (tensor<2x2xi8>) -> tensor<1x2xi8>
    ```

    Input:
    ```
                              device
                              (0, 1)
                                 ↓
                     +-------+-------+  | scatter tensor
    device (0, 0) -> |       |       |  | axis 0
                     |       |       |  ↓
                     +-------+-------+
    device (1, 0) -> |  1  2 |  5  6 |
                     |  3  4 |  7  8 |
                     +-------+-------+
                                 ↑
                              device
                              (1, 1)
    ```

    Result:
    ```
                              device
                              (0, 1)
                                 ↓
                     +-------+-------+
    device (0, 0) -> |  1  2 |  5  6 |
                     +-------+-------+
    device (1, 0) -> |  3  4 |  7  8 |
                     +-------+-------+
                                 ↑
                              device
                              (1, 1)
    ```
  }];
  let arguments = !con(commonArgs, (ins
    AnyNon0RankedTensor:$input,
    IndexAttr:$scatter_axis,
    // Static root coordinates; dynamic entries are supplied via $root_dynamic
    // (printed/parsed together through DynamicIndexList below).
    DenseI64ArrayAttr:$root,
    Variadic<Index>:$root_dynamic
  ));
  let results = (outs
    AnyRankedTensor:$result
  );
  let assemblyFormat = [{
    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
    `scatter_axis` `=` $scatter_axis
    `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
    attr-dict `:` functional-type(operands, results)
  }];
  let hasCanonicalizer = 1;
}
1007
def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [
    AllShapesMatch<["input", "result"]>,
    AllElementTypesMatch<["input", "result"]>
  ]> {
  let summary = "Send over a device mesh.";
  let description = [{
    Send from one device to another within a device group.
  }];
  let arguments = !con(commonArgs, (ins
    AnyNon0RankedTensor:$input,
    // Static coordinates of the receiving device — presumably along
    // `mesh_axes`, mirroring `root` in mesh.reduce (TODO confirm). Dynamic
    // entries are supplied via $destination_dynamic and printed/parsed
    // together through DynamicIndexList below.
    DenseI64ArrayAttr:$destination,
    Variadic<Index>:$destination_dynamic
  ));
  let results = (outs
    AnyRankedTensor:$result
  );
  let assemblyFormat = [{
    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
    `destination` `=` custom<DynamicIndexList>($destination_dynamic, $destination)
    attr-dict `:` functional-type(operands, results)
  }];
  let hasCanonicalizer = 1;
}
1031
def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
    Pure,
    SameOperandsAndResultElementType,
    SameOperandsAndResultShape
  ]> {
  let summary = "Shift over a device mesh.";
  // Fixed example: it previously omitted the `%0` input operand, which is
  // inconsistent with this op's assemblyFormat (it begins with `$input`).
  let description = [{
    Within each device group shift along mesh axis `shift_axis` by an offset
    `offset`.
    The result on devices that do not have a corresponding source is undefined.
    `shift_axis` must be one of `mesh_axes`.
    If the `rotate` attribute is present,
    instead of a shift a rotation is done.

    Example:
    ```
    mesh.mesh @mesh0(shape = 2x4)
    %1 = mesh.shift %0 on @mesh0 mesh_axes = [1]
      shift_axis = 1 offset = 2 rotate
      : tensor<2xi8> -> tensor<2xi8>
    ```

    Input:
    ```
    mesh axis 1
    ----------->

    +----+----+----+----+
    |  1 |  2 |  3 |  4 |
    +----+----+----+----+
    |  5 |  6 |  7 |  8 |
    +----+----+----+----+
    ```

    Result:
    ```
    +----+----+----+----+
    |  3 |  4 |  1 |  2 |
    +----+----+----+----+
    |  7 |  8 |  5 |  6 |
    +----+----+----+----+
    ```
  }];
  let arguments = !con(commonArgs, (ins
    AnyNon0RankedTensor:$input,
    IndexAttr:$shift_axis,
    I64Attr:$offset,
    // Unit attribute: when present the shift wraps around (rotation).
    UnitAttr:$rotate
  ));
  let results = (outs
    AnyRankedTensor:$result
  );
  let assemblyFormat = [{
    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
    `shift_axis` `=` $shift_axis
    `offset` `=` $offset
    (`rotate` $rotate^)?
    attr-dict `:` type($input) `->` type($result)
  }];
  let hasCanonicalizer = 1;
}
1093
def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
  Pure,
  DestinationStyleOpInterface,
  TypesMatchWith<
    "result has same type as destination",
    "result", "destination", "$_self">,
  DeclareOpInterfaceMethods<SymbolUserOpInterface>
]> {
  let summary = "Update halo data.";
  // Fixed stale description: it referred to `source_halo_sizes/
  // static_source_halo_sizes` and `destination_halo_sizes/
  // static_destination_halo_sizes`, attributes this op does not declare.
  // The actual arguments are `halo_sizes`/`static_halo_sizes`.
  let description = [{
    This operation updates halo regions of shards, e.g. if their sharding
    specified halos and the actual tensor/memref data might have changed
    on the remote devices. Changes might be caused by mutating operations
    and/or if the new halo regions are larger than the existing ones.

    Destination is supposed to be initialized with the local data (not halos).

    Assumes all devices hold tensors with same-sized halo data as specified
    by `halo_sizes`/`static_halo_sizes` in the source shard and the
    destination/result shard.

    `split_axes` specifies for each tensor axis along which mesh axes its halo
    data is updated.
  }];
  let arguments = (ins
    AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$destination,
    FlatSymbolRefAttr:$mesh,
    Mesh_MeshAxesArrayAttr:$split_axes,
    // Halo sizes: static values live in $static_halo_sizes, dynamic ones in
    // $halo_sizes; printed/parsed together through DynamicIndexList below.
    Variadic<I64>:$halo_sizes,
    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes
  );
  let results = (outs
    AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$result
  );
  let assemblyFormat = [{
    $destination
    `on` $mesh
    `split_axes` `=` $split_axes
    (`halo_sizes` `=` custom<DynamicIndexList>($halo_sizes, $static_halo_sizes)^)?
    attr-dict `:` type($result)
  }];
  let extraClassDeclaration = [{
    MutableOperandRange getDpsInitsMutable() { return getDestinationMutable(); }
  }];
}
1141#endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
1142