//===- ShardingInterface.td --------------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD

include "mlir/IR/OpBase.td"

def ShardingInterface : OpInterface<"ShardingInterface"> {
    let description = [{
        Interface for allowing operations to expose information needed to
        shard them.
    }];
    let cppNamespace = "::mlir::mesh";

    let methods = [
      InterfaceMethod<
        /*desc=*/[{
          Returns a list of iterator types, one for each loop of the
          operation. The iterator types determine how the operation traverses
          its input and output tensors.

          Example 1: A gemm op has 3 loops, M, N and K. Their iterator
          types are parallel, parallel and reduction. This indicates that M
          and N are traversed in parallel, while the K dimension is used for
          reduction.
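
          For illustration, a hypothetical `MatmulOp` implementing this
          interface could return the gemm iterator types above roughly as
          follows (a sketch, not an actual implementation):

            SmallVector<mlir::utils::IteratorType>
            MatmulOp::getLoopIteratorTypes() {
              // Loops are (M, N, K): M and N are parallel, K is a reduction.
              return {mlir::utils::IteratorType::parallel,
                      mlir::utils::IteratorType::parallel,
                      mlir::utils::IteratorType::reduction};
            }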
        }],
        /*retTy=*/"SmallVector<mlir::utils::IteratorType>",
        /*methodName=*/"getLoopIteratorTypes",
        /*args=*/(ins),
        /*methodBody=*/"",
        /*defaultImplementation=*/"return {};"
      >,
      InterfaceMethod<
        /*desc=*/[{
          Return the kinds of all reduction loop iterators.
          The kinds are listed in the same order as the reduction iterators
          appear in the result of `getLoopIteratorTypes`.

          Example 1:
          iterator types =  (parallel, reduction, parallel, reduction)
                                             ||                   ||
          reduction kinds = (                sum,                 max)

          Example 2:
          A softmax op's loop iterator types are parallel and reduction.
          The reduction iterator will be of kind `generic`, since it matches
          none of the available presets.
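
          As a sketch, an op whose iterator types match Example 1 above could
          implement this method as follows (the op name is hypothetical):

            SmallVector<ReductionKind> MyOp::getReductionLoopIteratorKinds() {
              // One entry per reduction iterator, in the order the reduction
              // iterators appear in getLoopIteratorTypes().
              return {ReductionKind::Sum, ReductionKind::Max};
            }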
        }],
        /*retTy=*/"SmallVector<ReductionKind>",
        /*methodName=*/"getReductionLoopIteratorKinds",
        /*args=*/(ins),
        /*methodBody=*/"",
        /*defaultImplementation=*/"return {};"
      >,
      InterfaceMethod<
        /*desc=*/[{
          Return the indexing maps of the current operation.
          Indexing maps determine how indices in the iteration space map to
          tensor indices. They are specified using `affine_map` in MLIR, which
          provides an affine transformation of indices.
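
          For example, a hypothetical matmul-like op with iteration space
          (M, N, K), operands A:(M, K) and B:(K, N), and result C:(M, N)
          could build its maps as follows (sketch only):

            SmallVector<AffineMap> MatmulOp::getIndexingMaps() {
              MLIRContext *ctx = getContext();
              AffineExpr m, n, k;
              bindDims(ctx, m, n, k); // d0 = M, d1 = N, d2 = K
              return {AffineMap::get(3, 0, {m, k}, ctx),  // A: (M, K)
                      AffineMap::get(3, 0, {k, n}, ctx),  // B: (K, N)
                      AffineMap::get(3, 0, {m, n}, ctx)}; // C: (M, N)
            }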
        }],
        /*retTy=*/"SmallVector<AffineMap>",
        /*methodName=*/"getIndexingMaps",
        /*args=*/(ins),
        /*methodBody=*/"",
        /*defaultImplementation=*/"return {};"
      >,
      InterfaceMethod<
        /*desc=*/[{
          Given that certain operands or results of the operation may have
          sharding annotations, this method uses that information to deduce
          how the operation should be sharded.
          The passed shardings may be incomplete; this gives the op freedom
          to select the most appropriate shardings for all of its operands
          and results, and for the op itself.
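
          A typical caller in a sharding-propagation style pass might look
          like the hypothetical sketch below, where `op`, `operandShardings`
          and `resultShardings` are assumed to have been collected from the
          surrounding IR beforehand:

            auto shardingOp = mlir::dyn_cast<ShardingInterface>(op);
            if (!shardingOp)
              return failure();
            FailureOr<ShardingOption> option =
                shardingOp.getShardingOption(operandShardings, resultShardings);
            if (failed(option))
              return failure();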
        }],
        /*retTy=*/"FailureOr<ShardingOption>",
        /*methodName=*/"getShardingOption",
        /*args=*/(ins
          "ArrayRef<MeshSharding>": $operandShardings,
          "ArrayRef<MeshSharding>": $resultShardings
        ),
        /*methodBody=*/"",
        /*defaultImplementation=*/[{
          return detail::defaultGetShardingOption(
            $_op.getOperation(), operandShardings, resultShardings);
        }]
      >,
      InterfaceMethod<
        /*desc=*/[{
          Based on a given ShardingOption, return the sharding annotations
          for the operands and results of the operation.
          These are the shardings the operands and results need to have in
          order to shard the op according to `shardingOption`.
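
          Sketch of a call site, continuing the hypothetical snippet from
          `getShardingOption` above:

            FailureOr<std::vector<MeshSharding>> annotations =
                shardingOp.getShardingAnnotations(*option);
            if (failed(annotations))
              return failure();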
        }],
        /*retTy=*/"FailureOr<std::vector<MeshSharding>>",
        /*methodName=*/"getShardingAnnotations",
        /*args=*/(ins
          "const ShardingOption &": $shardingOption
        ),
        /*methodBody=*/"",
        /*defaultImplementation=*/[{
          return detail::defaultGetShardingAnnotations(
            $_op.getOperation(), shardingOption);
        }]
      >,
      InterfaceMethod<
        /*desc=*/[{
          Based on a given ShardingOption, this method adds `mesh.shard`
          operations for the operands and results that previously lacked
          sharding annotations.
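
          Illustrative call site (a sketch continuing the hypothetical
          snippet above, reusing `op`, `shardingOp` and `option`):

            OpBuilder builder(op);
            if (failed(shardingOp.addShardingAnnotations(builder, *option)))
              return failure();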
        }],
        /*retTy=*/"LogicalResult",
        /*methodName=*/"addShardingAnnotations",
        /*args=*/(ins
          "OpBuilder &": $b,
          "const ShardingOption &": $shardingOption
        ),
        /*methodBody=*/"",
        /*defaultImplementation=*/[{
          return detail::defaultAddShardingAnnotations(
            $_op.getOperation(), b, shardingOption);
        }]
      >,
      InterfaceMethod<
        /*desc=*/[{
          Convert the operation to SPMD form.
          This method is used during the spmdization pass of a program fully
          annotated with shardings.

          The spmdization algorithm reads the surrounding sharding
          annotations from the IR for each operand/result and prepares
          `operandShardings` and `resultShardings`.
          Values that are not ranked tensors do not have sharding annotations;
          in this case their corresponding MeshSharding is null.

          For convenience it also prepares `spmdizedOperands`, although they
          can be retrieved from the `spmdizationMap`.

          The `spmdizationMap` contains a mapping from unsharded to
          sharded/spmdized values that is constructed during the spmdization
          pass. The interface implementation must populate `spmdizationMap`
          with the mapping for this op's results.

          `builder` is set to insert new operations at the appropriate point.
          The implementation should not reset the builder to its original
          insertion point; it should leave the builder where it is after all
          insertions are done.

          The default implementation does full replication.
          It assumes that all sharding annotations are for full replication.
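
          A minimal implementation for an op whose spmdized form is identical
          to the original op applied to the spmdized operands (e.g. under
          fully replicated shardings, which is also what the default
          implementation assumes) could look like this hypothetical sketch:

            LogicalResult MyElementwiseOp::spmdize(
                ArrayRef<Value> spmdizedOperands,
                ArrayRef<MeshSharding> operandShardings,
                ArrayRef<MeshSharding> resultShardings,
                IRMapping &spmdizationMap,
                SymbolTableCollection &symbolTableCollection,
                OpBuilder &builder) {
              // The IRMapping already maps the original operands to
              // `spmdizedOperands`, so cloning through it rebuilds the op on
              // the spmdized values and records the mapping from the original
              // results to the cloned results.
              builder.clone(*getOperation(), spmdizationMap);
              return success();
            }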
        }],
        /*retTy=*/"LogicalResult",
        /*methodName=*/"spmdize",
        /*args=*/(ins
          "ArrayRef<Value>": $spmdizedOperands,
          "ArrayRef<MeshSharding>": $operandShardings,
          "ArrayRef<MeshSharding>": $resultShardings,
          "IRMapping &": $spmdizationMap,
          "SymbolTableCollection &": $symbolTableCollection,
          "OpBuilder &": $builder
        ),
        /*methodBody=*/"",
        /*defaultImplementation=*/[{
          spmdizeFullyReplicatedOperation(
            *$_op.getOperation(), spmdizedOperands, operandShardings,
            resultShardings, spmdizationMap, symbolTableCollection, builder);
          return success();
        }]>
    ];

    let extraClassDeclaration = [{
      LogicalResult verifyShardingInterfaceImpl();

      void printLoopTypesAndIndexingMaps(raw_ostream &os);
    }];
}

#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD