//===-- Passes.td - Mesh transformation definition file ----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//


#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD

include "mlir/Pass/PassBase.td"

//===----------------------------------------------------------------------===//
// ShardingPropagation
//===----------------------------------------------------------------------===//

// Anchored on `mlir::FunctionOpInterface`, so the pass can be scheduled on any
// function-like operation rather than a single concrete op type.
def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionOpInterface"> {
  let summary = "sharding propagation";
  let description = [{
    Propagates sharding information throughout the graph. After this pass, each
    of the operations' operands and results is annotated with a `mesh.shard`
    operation, and the operations themselves are annotated with sharding option
    attributes.
  }];
  // The pass inserts `mesh.shard` operations, so the mesh dialect must be
  // loaded before it runs.
  let dependentDialects = [
    "mesh::MeshDialect"
  ];
}

//===----------------------------------------------------------------------===//
// Spmdization
//===----------------------------------------------------------------------===//

// Anchored on `mlir::FunctionOpInterface`, like `ShardingPropagation`, which is
// expected to run before it.
def Spmdization : InterfacePass<"mesh-spmdization", "mlir::FunctionOpInterface"> {
  let summary = "Partition a function into SPMD form.";
  let description = [{
    This pass fits in right after a pass that annotates the function with
    shardings like the `ShardingPropagation` pass.
    It operates on a fully annotated IR.

    A fully annotated IR requires that all ranked tensor operands, results, and
    block arguments are annotated with the `mesh.shard` operation.

    All direct descendant operations in the function must implement the
    `ShardingInterface` interface or all their ranked tensor operands and
    results must have full replication sharding.

    The input IR must have sharding annotations such that each operation
    that implements `ShardingInterface` can be handled during spmdization by
    its `spmdize` method.
    This can be achieved with the `ShardingPropagation` pass.

    If the function has multiple terminating blocks,
    it is the responsibility of whoever annotates the function with
    shardings to make sure that all returns are consistent, that is,
    have the same sharding.

    Example:
    ```mlir
    mesh.mesh @mesh_1d(shape = 2)

    func.func @f(
      %arg0: tensor<2xi8>
    ) -> tensor<2xi8> {
      %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<2xi8>
      %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
      %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
      %3 = mesh.shard %2 to <@mesh_1d, [[0]]> : tensor<2xi8>
      %4 = mesh.shard %3 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
      return %4 : tensor<2xi8>
    }
    ```
    Spmdizing the above would result in:
    * Performing the element-wise `abs` operation on each device.
    * Resharding to full replication with an all-gather.

    ```mlir
    mesh.mesh @mesh_1d(shape = 2)

    func.func @f(%arg0: tensor<1xi8>) -> tensor<2xi8> {
      %0 = tosa.abs %arg0 : (tensor<1xi8>) -> tensor<1xi8>
      %1 = mesh.all_gather %0 on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
      return %1 : tensor<2xi8>
    }
    ```
  }];
  // Spmdization emits mesh collectives (e.g. `mesh.all_gather` in the example
  // above), so the mesh dialect must be loaded.
  let dependentDialects = [
    "mesh::MeshDialect"
  ];
}

#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD