// xref: /llvm-project/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h (revision baabcb28983edf8f20e39b89e2b1745412073b44)
//===- ShardingInterface.h --------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
10 #define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
11 
12 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
13 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
14 #include "mlir/IR/Value.h"
15 #include "mlir/Support/LLVM.h"
16 
17 namespace mlir {
18 
// Forward declarations: the declarations below only take these types by
// pointer or reference, so their full definitions are not needed here.
class IRMapping;
class Operation;
class SymbolTableCollection;
22 
23 namespace mesh {
24 
25 using ShardingArray = SmallVector<SmallVector<MeshAxis>>;
26 using ShardingArrayRef = ArrayRef<SmallVector<MeshAxis>>;
27 
28 struct ShardingOption {
29   // An array of int array. The sub-array at the i-th position signifies the
30   // mesh axes the i-th loop will be sharded on.
31   ShardingArray shardingArray = {};
32   FlatSymbolRefAttr mesh = nullptr;
33   // `empty` being true indicates that no sharding information can be inferred
34   // at present. Note that it is different from the case where an operation is
35   // not sharded.
36   bool empty = false;
37   ShardingOption() = default;
38   ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr mesh)
39       : shardingArray(std::move(shardingArray)), mesh(mesh) {}
40   static ShardingOption makeEmpty() {
41     auto res = ShardingOption();
42     res.empty = true;
43     return res;
44   }
45 };
46 
47 // This method retrieves the 'MeshSharding' from a given operation
48 // result and includes the 'annotate_for_users' information.
49 FailureOr<std::pair<bool, MeshSharding>> getMeshSharding(OpResult result);
50 
51 // This method retrieves the 'MeshSharding' from a given operation
52 // operand and includes the 'annotate_for_users' information.
53 FailureOr<std::pair<bool, MeshSharding>> getMeshSharding(OpOperand &opOperand);
54 
55 namespace detail {
56 
57 FailureOr<ShardingOption>
58 defaultGetShardingOption(Operation *op, ArrayRef<MeshSharding> operandShardings,
59                          ArrayRef<MeshSharding> resultShardings);
60 
61 FailureOr<std::vector<MeshSharding>>
62 defaultGetShardingAnnotations(Operation *op,
63                               const ShardingOption &shardingOption);
64 
65 LogicalResult
66 defaultAddShardingAnnotations(Operation *op, OpBuilder &b,
67                               const ShardingOption &shardingOption);
68 
69 } // namespace detail
70 
71 // Assumes full replication on all ranked tensor arguments and results.
72 void spmdizeFullyReplicatedOperation(Operation &op,
73                                      ArrayRef<Value> spmdizedOperands,
74                                      ArrayRef<MeshSharding> operandShardings,
75                                      ArrayRef<MeshSharding> resultShardings,
76                                      IRMapping &spmdizationMap,
77                                      SymbolTableCollection &symbolTable,
78                                      OpBuilder &builder);
79 
80 } // namespace mesh
81 } // namespace mlir
82 
/// Include the ODS-generated interface header file.
84 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h.inc"
85 
86 #endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
87