xref: /llvm-project/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td (revision abfac563f5b5a123e4bf773c3a09777e6fc4f50c)
//===-- Passes.td - Mesh transformation definition file ----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9
10#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
11#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
12
13include "mlir/Pass/PassBase.td"
14
//===----------------------------------------------------------------------===//
// ShardingPropagation
//===----------------------------------------------------------------------===//
18
def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionOpInterface"> {
  let summary = "sharding propagation";
  let description = [{
    Propagates sharding information throughout the function. After this pass,
    each of the operations' operands and results is annotated with a
    `mesh.shard` operation, and the operations themselves are annotated with
    sharding option attributes.

    The annotated IR can then be partitioned into SPMD form with the
    `mesh-spmdization` pass.
  }];
  let dependentDialects = [
    "mesh::MeshDialect"
  ];
}
31
//===----------------------------------------------------------------------===//
// Spmdization
//===----------------------------------------------------------------------===//

def Spmdization : InterfacePass<"mesh-spmdization", "mlir::FunctionOpInterface"> {
  let summary = "Partition a function into SPMD form.";
  let description = [{
    This pass is meant to run right after a pass that annotates the function
    with shardings, such as the `ShardingPropagation` pass.
    It operates on fully annotated IR.

    Fully annotated IR requires that all ranked tensor operands, results and
    block arguments are annotated with the `mesh.shard` operation.

    All direct descendant operations in the function must implement the
    `ShardingInterface` interface or all their ranked tensor operands and
    results must have full replication sharding.

    The input IR must have sharding annotations such that each operation
    that implements `ShardingInterface` can be handled during spmdization by
    its `spmdize` method.
    This can be achieved with the `ShardingPropagation` pass.

    If the function has multiple terminating blocks,
    it is the responsibility of whoever annotates the function with
    shardings to make sure that all returns are consistent, that is,
    that they have the same sharding.

    Example:
    ```mlir
    mesh.mesh @mesh_1d(shape = 2)

    func.func @f(
      %arg0: tensor<2xi8>
    ) -> tensor<2xi8> {
      %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<2xi8>
      %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users : tensor<2xi8>
      %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
      %3 = mesh.shard %2 to <@mesh_1d, [[0]]> : tensor<2xi8>
      %4 = mesh.shard %3 to <@mesh_1d, [[]]> annotate_for_users : tensor<2xi8>
      return %4 : tensor<2xi8>
    }
    ```
    Spmdizing the above would result in:
    * Performing the element-wise `abs` operation on each device.
    * Resharding to full replication with an all-gather.

    ```mlir
    mesh.mesh @mesh_1d(shape = 2)

    func.func @f(%arg0: tensor<1xi8>) -> tensor<2xi8> {
      %0 = tosa.abs %arg0 : (tensor<1xi8>) -> tensor<1xi8>
      %1 = mesh.all_gather %0 on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
      return %1 : tensor<2xi8>
    }
    ```
  }];
  let dependentDialects = [
    "mesh::MeshDialect"
  ];
}
89
90#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
91