//===-- MeshOps.td - Mesh dialect operation definitions ----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_MESH_IR_MESHOPS_TD
#define MLIR_DIALECT_MESH_IR_MESHOPS_TD

include "mlir/Dialect/Mesh/IR/MeshBase.td"
include "mlir/Dialect/Shape/IR/ShapeBase.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinTypes.td"
include "mlir/IR/CommonAttrConstraints.td"
include "mlir/IR/CommonTypeConstraints.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td"

//===----------------------------------------------------------------------===//
// Mesh operations.
//===----------------------------------------------------------------------===//

// Base class for all operations in the Mesh dialect.
class Mesh_Op<string mnemonic, list<Trait> traits = []> :
    Op<Mesh_Dialect, mnemonic, traits> {
}

def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol]> {
  let summary = "Description of a device/process mesh.";
  let description = [{
    The mesh.mesh operation is a symbol operation that identifies a specific
    mesh. The operation has three attributes:

    1. `sym_name`: This attribute uniquely identifies the name of the mesh.
    This name serves as a symbolic reference to the mesh throughout
    the MLIR module, allowing for consistent referencing and easier debugging.

    2. `shape`: This attribute represents the shape of the device mesh.
    It uses the same notation as a tensor shape. Also allowing for dynamic
    dimensions.
    This flexibility allows for dynamic device assignment or configurations
    where the exact number of devices might not be determined during compile
    time.
    For example `2x?x4`.

    Example:
    ```
    // A device mesh with 3 axes, the total device number is 4 * 8 * 12
    // The dimension sizes are 4, 8, 12
    mesh.mesh @mesh0(shape = 4x8x12)

    // A device mesh with 2 axes, the total device number is unknown
    // The first dimension size is 4 and the second is unknown
    mesh.mesh @mesh1(shape = 4x?)

    // A device mesh with 2 axes, the total device number is unknown
    // The first dimension size is unknown and the second is 4
    mesh.mesh @mesh2(shape = ?x4)

    // A device mesh with 2 axes, the number of devices along both axes
    // is unknown
    mesh.mesh @mesh3(shape = ?x?)
    ```
  }];
  let arguments = (ins
    SymbolNameAttr:$sym_name,
    DenseI64ArrayAttr:$shape
  );
  let assemblyFormat = [{
    $sym_name `(` `shape` `=` custom<DimensionList>($shape) `)`
      attr-dict
  }];
  let extraClassDeclaration = [{
    // The rank of the mesh, i.e. the number of mesh axes.
    int64_t getRank() { return getShape().size(); }
  }];
  let hasVerifier = 1;
}

def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
    Pure,
    DeclareOpInterfaceMethods<SymbolUserOpInterface>,
    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
  ]> {
  let summary = "Get the shape of the mesh.";
  let description = [{
    The results are the sizes of the referenced mesh along the requested
    `axes`. If `axes` is empty, the sizes along all mesh axes are returned.
  }];
  let arguments = (ins
    FlatSymbolRefAttr:$mesh,
    DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
  );

  let results = (outs
    Variadic<Index>:$result
  );

  let assemblyFormat = [{
    $mesh (`axes` `=` $axes^)?
      attr-dict `:` type($result)
  }];

  let builders = [
    OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
    OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh, "ArrayRef<MeshAxis>":$axes)>,
    OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
  ];
}

def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
  Pure,
  DeclareOpInterfaceMethods<SymbolUserOpInterface>,
  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
]> {
  let summary = "Get the multi index of current device along specified mesh axes.";
  let description = [{
    It is used in the SPMD format of IR.
    The `axes` must be non-negative and less than the total number of mesh axes.
    If the axes are empty then get the index along all axes.
  }];
  let arguments = (ins
    FlatSymbolRefAttr:$mesh,
    DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
  );
  let results = (outs
    Variadic<Index>:$result
  );
  let assemblyFormat = [{
    `on` $mesh (`axes` `=` $axes^)?
      attr-dict `:` type($result)
  }];
  let builders = [
    OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
    OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
  ];
}

def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
  Pure,
  DeclareOpInterfaceMethods<SymbolUserOpInterface>,
  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
]> {
  let summary = "Get the linear index of the current device.";
  let description = [{
    Example:
    ```
    %idx = mesh.process_linear_index on @mesh : index
    ```
    if `@mesh` has shape `(10, 20, 30)`, a device with multi
    index `(1, 2, 3)` will have linear index `3 + 30*2 + 20*30*1`.
  }];
  let arguments = (ins FlatSymbolRefAttr:$mesh);
  let results = (outs Index:$result);
  let assemblyFormat = "`on` $mesh attr-dict `:` type($result)";
  let builders = [
    OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>
  ];
}

def Mesh_NeighborsLinearIndicesOp : Mesh_Op<"neighbors_linear_indices", [
  Pure,
  DeclareOpInterfaceMethods<SymbolUserOpInterface>,
  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
]> {
  let summary =
      "For given mesh index get the linear indices of the direct neighbor processes along the given split.";
  let description = [{
    Example:
    ```
    mesh.mesh @mesh0(shape = 10x20x30)
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %c3 = arith.constant 3 : index
    %idx = mesh.neighbors_linear_indices on @mesh[%c1, %c2, %c3] split_axes = [1] : index
    ```
    The above returns two indices, `633` and `693`, which correspond to the
    index of the previous process `(1, 1, 3)`, and the next process
    `(1, 3, 3)` along the split axis `1`.

    A negative value is returned if there is no neighbor in the respective
    direction along the given `split_axes`.
  }];
  let arguments = (ins FlatSymbolRefAttr:$mesh,
                       Variadic<Index>:$device,
                       Mesh_MeshAxesAttr:$split_axes);
  let results = (outs Index:$neighbor_down, Index:$neighbor_up);
  let assemblyFormat = [{
      `on` $mesh `[` $device `]`
      `split_axes` `=` $split_axes
      attr-dict `:` type(results)
  }];
}

//===----------------------------------------------------------------------===//
// Sharding operations.
195//===----------------------------------------------------------------------===// 196 197def Mesh_ShardingOp : Mesh_Op<"sharding", [ 198 Pure, 199 AttrSizedOperandSegments, 200 DeclareOpInterfaceMethods<SymbolUserOpInterface>, 201 DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]> 202 ]> { 203 let summary = "Define a sharding of a tensor."; 204 let description = [{ 205 The MeshSharding specifies how a tensor is sharded and distributed across the 206 process mesh. It is typically used in a `mesh.shard` operation. 207 The operation has the follwing attributes and operands: 208 209 1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device 210 mesh where the distributed tensor is placed. The symbol must resolve to a 211 `mesh.mesh` operation. 212 213 2. `split_axes`: is an array composed of int64_t sub-arrays. The outer array's 214 maximum size is the `rank` of the related tensor. For the i-th sub-array, if 215 its value is [x, y], it indicates that the tensor's i-th dimension is splitted 216 along the x and y axes of the device mesh. 217 218 3. [Optional] `partial_axes`: if not empty, this signifies that the tensor is partial 219 one along the specified mesh axes. An all-reduce should be applied to obtain 220 the complete tensor, with reduction type being specified by `partial_type`. 221 222 4. [Optional] `partial_type`: indicates the reduction type of the possible all-reduce 223 op. It has 4 possible values: 224 `generic`: is not an allowed value inside a shard attribute. 225 226 5. [Optional] Sizes of halos to be added for each sharded tensor dimension. 227 `halo_sizes` is provided as a flattened 1d array of i64s, 2 values for each 228 sharded dimension. `halo_sizes = [1, 2]` means that the first sharded dimension 229 gets an additional halo of size 1 at the start of the first dimension and a halo 230 size is 2 at its end. `halo_sizes = [1, 2, 2, 3]` defines halos for the first 2 231 sharded dimensions e.g. 
the first sharded dimension gets `[1,2]` halos and the 232 seconds gets `[2,3]` halos. `?` indicates dynamic halo sizes. 233 234 6. [Optional] Offsets for each shard and sharded tensor dimension. 235 `sharded_dims_offsets` is provided as a flattened 1d array of i64s. For each 236 sharded tensor dimension the offsets (starting index) of all shards in that 237 dimension and an additional value for the end of the last shard are provided. 238 For a 1d sharding this means that position `i` has the exclusive prefix sum for 239 shard `i`, and since only contiguous sharding is supported, its inclusive prefix 240 sum is at position 'i+1'. 241 242 Assuming a 3d-tensor of shape 32x32x32 with the first 2 dimensions being sharded, 243 `sharded_dims_offsets` = [0, 24, 32, 0, 20, 32] means that the first device of 244 the device-mesh will get a shard of shape 24x20x32 and the second device will get 245 a shard of shape 8x12x32. `?` indicates dynamic shard dimensions. 246 247 `halo_sizes` and `sharded_dims_offsets` are mutually exclusive. 248 249 Examples: 250 251 ``` 252 mesh.mesh @mesh0(shape = 2x2x4) 253 mesh.mesh @mesh1d_4(shape = 4) 254 255 // The tensor is fully replicated on @mesh0. 256 // Currently, there must be at least one sub-array present in axes, even 257 // if it's empty. Otherwise, a parsing error will occur. 258 %sharding0 = mesh.sharding @mesh0 split_axes = [[]] 259 260 // The tensor is sharded on the first dimension along axis 0 of @mesh0 261 %sharding1 = mesh.sharding @mesh0 split_axes = [[0]] 262 263 // The tensor is sharded on its first dimension along axis 0 of @mesh0 and 264 // it is also a partial_sum along mesh axis 1. 265 %sharding2 = mesh.sharding @mesh0 split_axes = [[0] split_axes = []] partial = sum[1] 266 267 // The tensor is sharded on its first dimension along axis 0 of @mesh0 and 268 // it is also a partial_max along mesh axis 1. 
269 %sharding3 = mesh.sharding @mesh0 split_axes = [[0]] partial = max[1] 270 271 // Could be used for a mesh.shard op 272 %sharded0 = mesh.shard %arg0 to %sharding3 : tensor<4x8xf32> 273 274 // The tensor is sharded on its first dimension along axis 0 of @mesh0 and 275 // and it has halo-sizes of 1 and 2 on the sharded dim. 276 %halo_sharding = mesh.sharding @mesh0 split_axes = [[0]] halo_sizes = [1, 2] 277 %sharded1 = mesh.shard %arg0 to %halo_sharding : tensor<4x8xf32> 278 279 // The tensor is sharded on its second dimension along axis 0 of @mesh1d_4 280 // and it has pre-defined shard sizes. The shards of the devices will have 281 // the following shapes: [4x2, 4x3, 4x4, 4x5] 282 %sharding4 = mesh.sharding @mesh1d_4 split_axes = [[], [0]] sharded_dims_offsets = [0, 2, 5, 9, 14] 283 %sharded2 = mesh.shard %arg0 to %sharding4 : tensor<4x14xf32> 284 ``` 285 }]; 286 287 let arguments = (ins 288 FlatSymbolRefAttr:$mesh, 289 Mesh_MeshAxesArrayAttr:$split_axes, 290 OptionalAttr<Mesh_MeshAxesAttr>:$partial_axes, 291 OptionalAttr<Mesh_ReductionKindAttr>:$partial_type, 292 DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_sharded_dims_offsets, 293 Variadic<I64>:$dynamic_sharded_dims_offsets, 294 DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes, 295 Variadic<I64>:$dynamic_halo_sizes 296 ); 297 let results = (outs 298 Mesh_Sharding:$result 299 ); 300 let assemblyFormat = [{ 301 $mesh 302 `split_axes` `=` $split_axes 303 (`partial` `=` $partial_type $partial_axes^)? 304 (`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)? 305 (`sharded_dims_offsets` `=` custom<DynamicIndexList>($dynamic_sharded_dims_offsets, $static_sharded_dims_offsets)^)? 
306 attr-dict `:` type($result) 307 }]; 308 let builders = [ 309 OpBuilder<(ins "FlatSymbolRefAttr":$mesh, 310 "ArrayRef<MeshAxesAttr>":$split_axes, 311 "ArrayRef<MeshAxis>":$partial_axes, 312 "mesh::ReductionKind":$partial_type, 313 CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes, 314 CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_offsets)>, 315 OpBuilder<(ins "FlatSymbolRefAttr":$mesh, 316 "ArrayRef<MeshAxesAttr>":$split_axes)>, 317 OpBuilder<(ins "FlatSymbolRefAttr":$mesh, 318 "ArrayRef<MeshAxesAttr>":$split_axes, 319 "::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes, 320 "::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_offsets)>, 321 OpBuilder<(ins "mlir::mesh::MeshSharding":$from)> 322 ]; 323 let hasVerifier = 1; 324 let hasCanonicalizer = 1; 325} 326 327def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> { 328 let summary = "Get the shard shape of a given process/device."; 329 let description = [{ 330 The device/process id is a linearized id of the device/process in the mesh. 331 This operation might be used during spmdization when the shard shape depends 332 on (non-constant) values used in `mesh.sharding`. 333 }]; 334 let arguments = (ins 335 DenseI64ArrayAttr:$shape, 336 Mesh_Sharding:$sharding, 337 Index:$device 338 ); 339 let results = (outs Variadic<Index>:$result); 340 let assemblyFormat = [{ 341 custom<DimensionList>($shape) $sharding $device attr-dict `:` type($result) 342 }]; 343 let builders = [ 344 OpBuilder<(ins "ArrayRef<int64_t>":$shape, "Value":$sharding, "Value":$device)> 345 ]; 346} 347 348def Mesh_ShardOp : Mesh_Op<"shard", [ 349 Pure, 350 AllTypesMatch<["result", "src"]>, 351 DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]> 352 ]> { 353 let summary = "Annotate on how a tensor is sharded across a mesh."; 354 let description = [{ 355 The mesh.shard operation is designed to specify and guide the sharding 356 behavior of a tensor value across a mesh topology. 
This operation has two 357 operands and two optional attributes: 358 359 1. `input`: This operand represents the tensor value that needs to be 360 annotated for sharding. 361 362 2. `sharding`: This attribute is type of `MeshShardingType`, which is the core data 363 structure to represent distribution of a tensor on a mesh. it is typically defiend 364 by an `mesh.sharding` operation. 365 366 3. `annotate_for_users`: A unit attribute addressing the scenario when a 367 tensor's sharding annotation differs based on its context of use (either as 368 a result or an operand). If specified, the sharding pertains to specific 369 users of the tensor value, indicating how it should be considered when used 370 as an operand in subsequent operations. If not, the sharding applies to the 371 operation that defines the tensor value. 372 373 Example: 374 ``` 375 func.func @only_result_annotated(%arg0 : tensor<4x8xf32>) -> () { 376 %sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding 377 %0 = mesh.shard %arg0 to %sharding : tensor<4x8xf32> 378 ... 379 } 380 381 func.func @only_operand_annotated(%arg0 : tensor<4x8xf32>) -> () { 382 %sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding 383 %0 = mesh.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32> 384 ... 385 } 386 387 func.func @two_operands_annotated(%arg0 : tensor<4x8xf32>, %arg1 : tensor<16x8xf32>) -> () { 388 %sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding 389 %0 = mesh.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32> 390 %1 = mesh.shard %arg1 to %sharding annotate_for_users : tensor<16x8xf32> 391 ... 
392 } 393 394 // The first mesh.shard op applies to %arg0, the second mesh.shard op 395 // applies for the operand of op0, the third mesh.shard op applies for the 396 // operand of op2 397 func.func @both_result_and_multi_operands_annotated( 398 %arg0 : tensor<4x8xf32>) -> () { 399 %sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding 400 %0 = mesh.shard %arg0 to %sharding : tensor<4x8xf32> 401 %sharding1 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding 402 %1 = mesh.shard %0 to %sharding1 annotate_for_users : tensor<4x8xf32> 403 %sharding2 = mesh.sharding @mesh0 split_axes = [[2]] : !mesh.sharding 404 %2 = mesh.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32> 405 "op0"(%1) : ... 406 "op1"(%2) : ... 407 ... 408 } 409 ``` 410 411 The following usages are undefined: 412 ``` 413 func.func @annotate_on_same_result_with_different_sharding( 414 %arg0 : tensor<4x8xf32>) -> () { 415 %sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding 416 %sharding2 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding 417 %0 = mesh.shard %arg0 to $sharding1 : tensor<4x8xf32> 418 %1 = mesh.shard %0 to sharding2 : tensor<4x8xf32> 419 ... 420 } 421 422 func.func @annotate_on_same_result_same_value_with_different_sharding( 423 %arg0 : tensor<4x8xf32>) -> () { 424 %sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding 425 %sharding2 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding 426 %0 = mesh.shard %arg0 to %sharding1 : tensor<4x8xf32> 427 %1 = mesh.shard %arg0 to %sharding2 : tensor<4x8xf32> 428 ... 
429 } 430 431 func.func @annotate_on_same_operand_with_different_sharding( 432 %arg0 : tensor<4x8xf32>) -> () { 433 %sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding 434 %sharding2 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding 435 %0 = mesh.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32> 436 %1 = mesh.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32> 437 ... 438 } 439 440 func.func @result_annotated_after_operand( 441 %arg0 : tensor<4x8xf32>) -> () { 442 %sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding 443 %sharding2 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding 444 %0 = mesh.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32> 445 %1 = mesh.shard %0 to %sharding2 : tensor<4x8xf32> 446 ... 447 } 448 ``` 449 }]; 450 let arguments = (ins 451 AnyRankedTensor:$src, 452 Mesh_Sharding:$sharding, 453 UnitAttr:$annotate_for_users 454 ); 455 let results = (outs 456 AnyRankedTensor:$result 457 ); 458 let assemblyFormat = [{ 459 $src `to` $sharding 460 (`annotate_for_users` $annotate_for_users^)? 
461 attr-dict `:` type($result) 462 }]; 463} 464 465//===----------------------------------------------------------------------===// 466// collective communication ops 467//===----------------------------------------------------------------------===// 468 469class Mesh_CollectiveCommunicationOpBase< 470 string mnemonic, list<Trait> traits = []> : 471 Mesh_Op<mnemonic, 472 !listconcat(traits, 473 [ 474 DeclareOpInterfaceMethods<SymbolUserOpInterface>, 475 DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]> 476 ])> { 477 dag commonArgs = (ins 478 FlatSymbolRefAttr:$mesh, 479 DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$mesh_axes 480 ); 481} 482 483def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [ 484 Pure, 485 SameOperandsAndResultElementType, 486 SameOperandsAndResultRank, 487 ]> { 488 let summary = "All-gather over a device mesh."; 489 let description = [{ 490 Gathers along the `gather_axis` tensor axis. 491 492 Example: 493 ```mlir 494 mesh.mesh @mesh0(shape = 2x2) 495 ... 496 %1 = mesh.all_gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1 497 : tensor<2x2xi8> -> tensor<2x4xi8> 498 ``` 499 Input: 500 ``` 501 +-------+-------+ 502 device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1) 503 | 3 4 | 7 8 | 504 +-------+-------+ 505 device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1) 506 | 11 12 | 15 16 | 507 +-------+-------+ 508 ``` 509 Result: 510 ``` 511 gather tensor 512 axis 1 513 ------------> 514 +-------------+ 515 | 1 2 5 6 | <- devices (0, 0) and (0, 1) 516 | 3 4 7 8 | 517 +-------------+ 518 | 9 10 13 14 | <- devices (1, 0) and (1, 1) 519 | 11 12 15 16 | 520 +-------------+ 521 ``` 522 }]; 523 let arguments = !con(commonArgs, (ins 524 AnyNon0RankedTensor:$input, 525 IndexAttr:$gather_axis 526 )); 527 let results = (outs 528 AnyNon0RankedTensor:$result 529 ); 530 let assemblyFormat = [{ 531 $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? 
`gather_axis` `=` $gather_axis 532 attr-dict `:` type($input) `->` type($result) 533 }]; 534 let hasCanonicalizer = 1; 535} 536 537def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [ 538 Pure, 539 SameOperandsAndResultShape]> { 540 let summary = "All-reduce over a device mesh."; 541 let description = [{ 542 The accumulation element type is specified by the result type and 543 it does not need to match the input element type. 544 The input element is converted to the result element type before 545 performing the reduction. 546 547 Attributes: 548 `reduction`: Indicates the reduction method. 549 550 Example: 551 ``` 552 %1 = mesh.all_reduce %0 on @mesh0 mesh_axes = [1, 0] reduction = <max> 553 : tensor<3x4xf32> -> tensor<3x4xf64> 554 ``` 555 }]; 556 let arguments = !con(commonArgs, (ins 557 AnyRankedTensor:$input, 558 DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction 559 )); 560 let results = (outs 561 AnyRankedTensor:$result 562 ); 563 let assemblyFormat = [{ 564 $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? (`reduction` `=` $reduction^)? 565 attr-dict `:` type($input) `->` type($result) 566 }]; 567 let hasCanonicalizer = 1; 568 let builders = [ 569 OpBuilder<(ins "Value":$input, "StringRef":$mesh, 570 "ArrayRef<MeshAxis>":$meshAxes, "ReductionKind":$reduction)> 571 ]; 572} 573 574def Mesh_AllSliceOp : Mesh_CollectiveCommunicationOpBase<"all_slice", [ 575 Pure, 576 SameOperandsAndResultElementType, 577 SameOperandsAndResultRank 578 ]> { 579 let summary = "All-slice over a device mesh. This is the inverse of all-gather."; 580 let description = [{ 581 Slice along the `slice_axis` tensor axis. 582 This operation can be thought of as the inverse of all-gather. 583 Technically, it is not required that all processes have the same input tensor. 584 Each process will slice a piece of its local tensor based on its in-group device index. 585 The operation does not communicate data between devices. 
586 587 Example: 588 ```mlir 589 mesh.mesh @mesh0(shape = 2x2) 590 ... 591 %1 = mesh.all_slice %0 on @mesh0 mesh_axes = [1] slice_axis = 1 592 : tensor<2x4xi8> -> tensor<2x2xi8> 593 ``` 594 Input: 595 ``` 596 +-------------+ 597 | 1 2 5 6 | <- devices (0, 0) and (0, 1) 598 | 3 4 7 8 | 599 +-------------+ 600 | 9 10 13 14 | <- devices (1, 0) and (1, 1) 601 | 11 12 15 16 | 602 +-------------+ 603 ``` 604 Result: 605 ``` 606 gather tensor 607 axis 1 608 ------------> 609 +-------+-------+ 610 device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1) 611 | 3 4 | 7 8 | 612 +-------+-------+ 613 device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1) 614 | 11 12 | 15 16 | 615 +-------+-------+ 616 ``` 617 }]; 618 let arguments = !con(commonArgs, (ins 619 AnyNon0RankedTensor:$input, 620 IndexAttr:$slice_axis 621 )); 622 let results = (outs 623 AnyNon0RankedTensor:$result 624 ); 625 let assemblyFormat = [{ 626 $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `slice_axis` `=` $slice_axis 627 attr-dict `:` type($input) `->` type($result) 628 }]; 629 let hasCanonicalizer = 1; 630 let builders = [ 631 OpBuilder<(ins "Value":$input, "MeshOp":$mesh, "ArrayRef<MeshAxis>":$meshAxes, "int64_t":$sliceAxis)>, 632 OpBuilder<(ins "Type":$result_type, "Value":$input, "StringRef":$mesh, "ArrayRef<MeshAxis>":$meshAxes, "int64_t":$sliceAxis)> 633 ]; 634} 635 636def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [ 637 Pure, 638 SameOperandsAndResultElementType, 639 SameOperandsAndResultRank]> { 640 let summary = "All-to-all over a device mesh."; 641 let description = [{ 642 Performs an all-to-all on tensor pieces split along `split_axis`. 643 The resulting pieces are concatenated along `concat_axis` on ech device. 644 645 Example: 646 ``` 647 mesh.mesh @mesh0(shape = 3) 648 ... 
649 %1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0] 650 split_axis = 0 concat_axis = 0 651 : tensor<3x2xi8> -> tensor<3x2xi8> 652 ``` 653 Input: 654 ``` 655 device device device 656 (0) (1) (2) 657 +-------+-------+-------+ | split and concat along 658 | 11 12 | 21 22 | 31 32 | | tensor axis 0 659 | 13 14 | 23 24 | 33 34 | ↓ 660 | 15 16 | 25 26 | 35 36 | 661 +-------+-------+-------+ 662 ``` 663 Result: 664 ``` 665 device device device 666 (0) (1) (2) 667 +-------+-------+-------+ 668 | 11 12 | 13 14 | 15 16 | 669 | 21 22 | 23 24 | 25 26 | 670 | 31 32 | 33 34 | 35 36 | 671 +-------+-------+-------+ 672 ``` 673 }]; 674 let arguments = !con(commonArgs, (ins 675 AnyNon0RankedTensor:$input, 676 IndexAttr:$split_axis, 677 IndexAttr:$concat_axis 678 )); 679 let results = (outs 680 AnyNon0RankedTensor:$result 681 ); 682 let assemblyFormat = [{ 683 $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? 684 `split_axis` `=` $split_axis 685 `concat_axis` `=` $concat_axis 686 attr-dict `:` type($input) `->` type($result) 687 }]; 688 let hasCanonicalizer = 1; 689} 690 691def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [ 692 Pure, 693 AllShapesMatch<["input", "result"]>, 694 AllElementTypesMatch<["input", "result"]> 695 ]> { 696 let summary = "Broadcast over a device mesh."; 697 let description = [{ 698 Broadcast the tensor on `root` to all devices in each respective group. 699 The operation broadcasts along mesh axes `mesh_axes`. 700 The `root` device specifies the in-group multi-index that is broadcast to 701 all other devices in the group. 
702 703 Example: 704 ``` 705 mesh.mesh @mesh0(shape = 2x2) 706 707 %1 = mesh.broadcast %0 on @mesh0 708 mesh_axes = [0] 709 root = [0] 710 : (tensor<2xi8>) -> tensor<2xi8> 711 ``` 712 713 Input: 714 ``` 715 +-------+-------+ | broadcast 716 device (0, 0) -> | 1 2 | 3 4 | <- device (0, 1) | along axis 0 717 +-------+-------+ ↓ 718 device (1, 0) -> | | | <- device (1, 1) 719 +-------+-------+ 720 ``` 721 722 Output: 723 ``` 724 +-------+-------+ 725 device (0, 0) -> | 1 2 | 3 4 | <- device (0, 1) 726 +-------+-------+ 727 device (1, 0) -> | 1 2 | 3 4 | <- device (1, 1) 728 +-------+-------+ 729 ``` 730 }]; 731 let arguments = !con(commonArgs, (ins 732 AnyRankedTensor:$input, 733 DenseI64ArrayAttr:$root, 734 Variadic<Index>:$root_dynamic 735 )); 736 let results = (outs 737 AnyRankedTensor:$result 738 ); 739 let assemblyFormat = [{ 740 $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? 741 `root` `=` custom<DynamicIndexList>($root_dynamic, $root) 742 attr-dict `:` functional-type(operands, results) 743 }]; 744 let hasCanonicalizer = 1; 745} 746 747def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [ 748 Pure, 749 AllRanksMatch<["input", "result"]>, 750 AllElementTypesMatch<["input", "result"]> 751 ]> { 752 let summary = "Gather over a device mesh."; 753 let description = [{ 754 Gathers on device `root` along the `gather_axis` tensor axis. 755 `root` specifies the coordinates of a device along `mesh_axes`. 756 It uniquely identifies the root device for each device group. 757 The result tensor on non-root devices is undefined. 758 Using it will result in undefined behavior. 759 760 Example: 761 ```mlir 762 mesh.mesh @mesh0(shape = 2x2) 763 ... 
764 %1 = mesh.gather %0 on @mesh0 mesh_axes = [1] 765 gather_axis = 1 root = [1] 766 : (tensor<2x2xi8>) -> tensor<2x4xi8> 767 ``` 768 Input: 769 ``` 770 gather tensor 771 axis 1 772 ------------> 773 +-------+-------+ 774 device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1) 775 | 3 4 | 7 8 | 776 +-------+-------+ 777 device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1) 778 | 11 12 | 15 16 | 779 +-------+-------+ 780 ``` 781 Result: 782 ``` 783 +-------------+ 784 | 1 2 5 6 | <- devices (0, 1) 785 | 3 4 7 8 | 786 +-------------+ 787 | 9 10 13 14 | <- devices (1, 1) 788 | 11 12 15 16 | 789 +-------------+ 790 ``` 791 Devices `(0, 0)` and `(1, 0)` have undefined result. 792 }]; 793 let arguments = !con(commonArgs, (ins 794 AnyNon0RankedTensor:$input, 795 IndexAttr:$gather_axis, 796 DenseI64ArrayAttr:$root, 797 Variadic<Index>:$root_dynamic 798 )); 799 let results = (outs 800 AnyNon0RankedTensor:$result 801 ); 802 let assemblyFormat = [{ 803 $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? 804 `gather_axis` `=` $gather_axis 805 `root` `=` custom<DynamicIndexList>($root_dynamic, $root) 806 attr-dict `:` functional-type(operands, results) 807 }]; 808 let hasCanonicalizer = 1; 809} 810 811def Mesh_RecvOp : Mesh_CollectiveCommunicationOpBase<"recv", [ 812 AllShapesMatch<["input", "result"]>, 813 AllElementTypesMatch<["input", "result"]> 814 ]> { 815 let summary = "Send over a device mesh."; 816 let description = [{ 817 Receive from a device within a device group. 818 }]; 819 let arguments = !con(commonArgs, (ins 820 AnyNon0RankedTensor:$input, 821 OptionalAttr<DenseI64ArrayAttr>:$source, 822 Variadic<Index>:$source_dynamic 823 )); 824 let results = (outs 825 AnyRankedTensor:$result 826 ); 827 let assemblyFormat = [{ 828 $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? 829 (`source` `=` custom<DynamicIndexList>($source_dynamic, $source)^)? 
830 attr-dict `:` functional-type(operands, results) 831 }]; 832 let hasCanonicalizer = 1; 833} 834 835def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [ 836 Pure, 837 AllShapesMatch<["input", "result"]> 838 ]> { 839 let summary = "Reduce over a device mesh."; 840 let description = [{ 841 Reduces on device `root` within each device group. 842 `root` specifies the coordinates of a device along `mesh_axes`. 843 It uniquely identifies the root device within its device group. 844 The accumulation element type is specified by the result type and 845 it does not need to match the input element type. 846 The input element is converted to the result element type before 847 performing the reduction. 848 849 Attributes: 850 `reduction`: Indicates the reduction method. 851 852 Example: 853 ``` 854 %1 = mesh.reduce %0 on @mesh0 mesh_axes = [1, 0] 855 reduction = <max> root = [2, 3] 856 : (tensor<3x4xf32>) -> tensor<3x4xf64> 857 ``` 858 }]; 859 let arguments = !con(commonArgs, (ins 860 AnyRankedTensor:$input, 861 DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction, 862 DenseI64ArrayAttr:$root, 863 Variadic<Index>:$root_dynamic 864 )); 865 let results = (outs 866 AnyRankedTensor:$result 867 ); 868 let assemblyFormat = [{ 869 $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? 870 (`reduction` `=` $reduction^)? 871 `root` `=` custom<DynamicIndexList>($root_dynamic, $root) 872 attr-dict `:` functional-type(operands, results) 873 }]; 874 let hasCanonicalizer = 1; 875} 876 877def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [ 878 Pure, 879 SameOperandsAndResultRank]> { 880 let summary = "Reduce-scatter over a device mesh."; 881 let description = [{ 882 After the reduction, the result is scattered within each device group. 883 The tensor is split along `scatter_axis` and the pieces distributed 884 across the device group. 885 Example: 886 ``` 887 mesh.mesh @mesh0(shape = 2x2) 888 ... 
889 %1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1] 890 reduction = <max> scatter_axis = 0 891 : tensor<3x4xf32> -> tensor<1x4xf64> 892 ``` 893 Input: 894 ``` 895 device 896 (0, 1) 897 ↓ 898 +-------+-------+ | scatter tensor 899 device (0, 0) -> | 1 2 | 5 6 | | axis 0 900 | 3 4 | 7 8 | ↓ 901 +-------+-------+ 902 device (1, 0) -> | 9 10 | 13 14 | 903 | 11 12 | 15 16 | 904 +-------+-------+ 905 ↑ 906 device 907 (1, 1) 908 ``` 909 Result: 910 ``` 911 +-------+ 912 | 6 8 | <- devices (0, 0) 913 +-------+ 914 | 10 12 | <- devices (0, 1) 915 +-------+ 916 | 22 24 | <- devices (1, 0) 917 +-------+ 918 | 26 28 | <- devices (1, 1) 919 +-------+ 920 ``` 921 }]; 922 let arguments = !con(commonArgs, (ins 923 AnyNon0RankedTensor:$input, 924 DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction, 925 IndexAttr:$scatter_axis 926 )); 927 let results = (outs 928 AnyRankedTensor:$result 929 ); 930 let assemblyFormat = [{ 931 $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? 932 (`reduction` `=` $reduction^)? 933 `scatter_axis` `=` $scatter_axis 934 attr-dict `:` type($input) `->` type($result) 935 }]; 936 let hasCanonicalizer = 1; 937} 938 939def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [ 940 Pure, 941 AllRanksMatch<["input", "result"]>, 942 AllElementTypesMatch<["input", "result"]> 943 ]> { 944 let summary = "Scatter over a device mesh."; 945 let description = [{ 946 For each device group split the input tensor on the `root` device along 947 axis `scatter_axis` and scatter the parts across the group devices. 
948 949 Example: 950 ``` 951 mesh.mesh @mesh0(shape = 2x2) 952 %1 = mesh.scatter %0 on @mesh0 mesh_axes = [0] 953 scatter_axis = 0 954 root = [1] 955 : (tensor<2x2xi8>) -> tensor<1x2xi8> 956 ``` 957 958 Input: 959 ``` 960 device 961 (0, 1) 962 ↓ 963 +-------+-------+ | scatter tensor 964 device (0, 0) -> | | | | axis 0 965 | | | ↓ 966 +-------+-------+ 967 device (1, 0) -> | 1 2 | 5 6 | 968 | 3 4 | 7 8 | 969 +-------+-------+ 970 ↑ 971 device 972 (1, 1) 973 ``` 974 975 Result: 976 ``` 977 device 978 (0, 1) 979 ↓ 980 +-------+-------+ 981 device (0, 0) -> | 1 2 | 5 6 | 982 +-------+-------+ 983 device (1, 0) -> | 3 4 | 7 8 | 984 +-------+-------+ 985 ↑ 986 device 987 (1, 1) 988 ``` 989 }]; 990 let arguments = !con(commonArgs, (ins 991 AnyNon0RankedTensor:$input, 992 IndexAttr:$scatter_axis, 993 DenseI64ArrayAttr:$root, 994 Variadic<Index>:$root_dynamic 995 )); 996 let results = (outs 997 AnyRankedTensor:$result 998 ); 999 let assemblyFormat = [{ 1000 $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? 1001 `scatter_axis` `=` $scatter_axis 1002 `root` `=` custom<DynamicIndexList>($root_dynamic, $root) 1003 attr-dict `:` functional-type(operands, results) 1004 }]; 1005 let hasCanonicalizer = 1; 1006} 1007 1008def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [ 1009 AllShapesMatch<["input", "result"]>, 1010 AllElementTypesMatch<["input", "result"]> 1011 ]> { 1012 let summary = "Send over a device mesh."; 1013 let description = [{ 1014 Send from one device to another within a device group. 1015 }]; 1016 let arguments = !con(commonArgs, (ins 1017 AnyNon0RankedTensor:$input, 1018 DenseI64ArrayAttr:$destination, 1019 Variadic<Index>:$destination_dynamic 1020 )); 1021 let results = (outs 1022 AnyRankedTensor:$result 1023 ); 1024 let assemblyFormat = [{ 1025 $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? 
1026 `destination` `=` custom<DynamicIndexList>($destination_dynamic, $destination) 1027 attr-dict `:` functional-type(operands, results) 1028 }]; 1029 let hasCanonicalizer = 1; 1030} 1031 1032def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [ 1033 Pure, 1034 SameOperandsAndResultElementType, 1035 SameOperandsAndResultShape 1036 ]> { 1037 let summary = "Shift over a device mesh."; 1038 let description = [{ 1039 Within each device group shift along mesh axis `shift_axis` by an offset 1040 `offset`. 1041 The result on devices that do not have a corresponding source is undefined. 1042 `shift_axis` must be one of `mesh_axes`. 1043 If the `rotate` attribute is present, 1044 instead of a shift a rotation is done. 1045 1046 Example: 1047 ``` 1048 mesh.mesh @mesh0(shape = 2x4) 1049 %1 = mesh.shift on @mesh0 mesh_axes = [1] 1050 shift_axis = 1 offset = 2 rotate 1051 : tensor<2xi8> -> tensor<2xi8> 1052 ``` 1053 1054 Input: 1055 ``` 1056 mesh axis 1 1057 -----------> 1058 1059 +----+----+----+----+ 1060 | 1 | 2 | 3 | 4 | 1061 +----+----+----+----+ 1062 | 5 | 6 | 7 | 8 | 1063 +----+----+----+----+ 1064 ``` 1065 1066 Result: 1067 ``` 1068 +----+----+----+----+ 1069 | 3 | 4 | 1 | 2 | 1070 +----+----+----+----+ 1071 | 7 | 8 | 5 | 6 | 1072 +----+----+----+----+ 1073 ``` 1074 }]; 1075 let arguments = !con(commonArgs, (ins 1076 AnyNon0RankedTensor:$input, 1077 IndexAttr:$shift_axis, 1078 I64Attr:$offset, 1079 UnitAttr:$rotate 1080 )); 1081 let results = (outs 1082 AnyRankedTensor:$result 1083 ); 1084 let assemblyFormat = [{ 1085 $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? 1086 `shift_axis` `=` $shift_axis 1087 `offset` `=` $offset 1088 (`rotate` $rotate^)? 
1089 attr-dict `:` type($input) `->` type($result) 1090 }]; 1091 let hasCanonicalizer = 1; 1092} 1093 1094def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [ 1095 Pure, 1096 DestinationStyleOpInterface, 1097 TypesMatchWith< 1098 "result has same type as destination", 1099 "result", "destination", "$_self">, 1100 DeclareOpInterfaceMethods<SymbolUserOpInterface> 1101]> { 1102 let summary = "Update halo data."; 1103 let description = [{ 1104 This operation updates halo regions of shards, e.g. if their sharding 1105 specified halos and the actual tensor/memref data might have changed 1106 on the remote devices. Changes might be caused by mutating operations 1107 and/or if the new halo regions are larger than the existing ones. 1108 1109 Destination is supposed to be initialized with the local data (not halos). 1110 1111 Assumes all devices hold tensors with same-sized halo data as specified 1112 by `source_halo_sizes/static_source_halo_sizes` and 1113 `destination_halo_sizes/static_destination_halo_sizes` in source shard 1114 and destination/result shard. 1115 1116 `split_axes` specifies for each tensor axis along which mesh axes its halo 1117 data is updated. 1118 1119 }]; 1120 let arguments = (ins 1121 AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$destination, 1122 FlatSymbolRefAttr:$mesh, 1123 Mesh_MeshAxesArrayAttr:$split_axes, 1124 Variadic<I64>:$halo_sizes, 1125 DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes 1126 ); 1127 let results = (outs 1128 AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$result 1129 ); 1130 let assemblyFormat = [{ 1131 $destination 1132 `on` $mesh 1133 `split_axes` `=` $split_axes 1134 (`halo_sizes` `=` custom<DynamicIndexList>($halo_sizes, $static_halo_sizes)^)? 1135 attr-dict `:` type($result) 1136 }]; 1137 let extraClassDeclaration = [{ 1138 MutableOperandRange getDpsInitsMutable() { return getDestinationMutable(); } 1139 }]; 1140} 1141#endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD 1142