//===- ShardingInterfaces.td -------------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD

include "mlir/IR/OpBase.td"

// Op interface through which an operation describes its iteration space
// (loop iterator types, reduction kinds, indexing maps) and hooks into
// sharding propagation and spmdization. Every method has a default
// implementation, so ops only override what they need.
def ShardingInterface : OpInterface<"ShardingInterface"> {
  let description = [{
    Interface for allowing operations to expose information needed to
    shard them.
  }];
  let cppNamespace = "::mlir::mesh";

  let methods = [
    InterfaceMethod<
      /*desc=*/[{
        Returns the iterator types of the operation's loops, one entry per
        loop. The iterator types determine how the operation traverses its
        input and output tensors.
        The default implementation returns an empty list.

        Example 1: A gemm op has 3 loops, M, N and K. Their loop iterator
        types are parallel, parallel, reduction. This indicates that M and
        N are traversed in parallel, while the K dimension is used for
        reduction.
      }],
      /*retTy=*/"SmallVector<mlir::utils::IteratorType>",
      /*methodName=*/"getLoopIteratorTypes",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/"return {};"
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the kind of each reduction loop iterator.
        The order is the same as the order of the reduction iterators in
        the result of `getLoopIteratorTypes`.
        The default implementation returns an empty list.

        Example 1:
          iterator types =  (parallel, reduction, parallel, reduction)
                                          ||                   ||
          reduction kinds = (             sum,                 max)

        Example 2:
          A softmax op's loop iterator types are parallel and
          reduction.
          The reduction iterator will be of kind `generic`, since it is none
          of the available presets.
      }],
      /*retTy=*/"SmallVector<ReductionKind>",
      /*methodName=*/"getReductionLoopIteratorKinds",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/"return {};"
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the indexing maps of the current operation.
        Indexing maps determine how indices in the iteration space map to
        tensor indices. They are specified using `affine_map` in MLIR, which
        provides an affine transformation of indices.
        The default implementation returns an empty list.
      }],
      /*retTy=*/"SmallVector<AffineMap>",
      /*methodName=*/"getIndexingMaps",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/"return {};"
    >,
    InterfaceMethod<
      /*desc=*/[{
        Given that certain operands or results of the operation may have
        sharding annotations, this method leverages this information to
        deduce how the operation should be sharded.
        The passed shardings may be incomplete; this gives the op the
        freedom to select the most appropriate shardings for all of its
        operands and results, and for the op itself.
      }],
      /*retTy=*/"FailureOr<ShardingOption>",
      /*methodName=*/"getShardingOption",
      /*args=*/(ins
        "ArrayRef<MeshSharding>":$operandShardings,
        "ArrayRef<MeshSharding>":$resultShardings
      ),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return detail::defaultGetShardingOption(
          $_op.getOperation(), operandShardings, resultShardings);
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Based on a given ShardingOption, compute the shardings that the
        operands and results need to have in order to shard the op
        according to `shardingOption`.
      }],
      /*retTy=*/"FailureOr<std::vector<MeshSharding>>",
      /*methodName=*/"getShardingAnnotations",
      /*args=*/(ins
        "const ShardingOption &":$shardingOption
      ),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return detail::defaultGetShardingAnnotations(
          $_op.getOperation(), shardingOption);
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Based on a given ShardingOption, this method adds `mesh.shard`
        operations for the operands and results that previously lacked
        sharding annotations.
      }],
      /*retTy=*/"LogicalResult",
      /*methodName=*/"addShardingAnnotations",
      /*args=*/(ins
        "OpBuilder &":$b,
        "const ShardingOption &":$shardingOption
      ),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return detail::defaultAddShardingAnnotations(
          $_op.getOperation(), b, shardingOption);
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Convert self to SPMD form.
        This method is used during the spmdization pass of a program fully
        annotated with shardings.

        The spmdization algorithm reads the surrounding sharding
        annotations from the IR for each argument/result and prepares
        `operandShardings` and `resultShardings`.
        Values that are not ranked tensors do not have sharding annotations.
        In this case their corresponding MeshSharding is null.

        For convenience it also prepares `spmdizedOperands`, although
        they can be retrieved from the `spmdizationMap`.

        The `spmdizationMap` contains a mapping from unsharded to
        sharded/spmdized values that are constructed during the spmdization
        pass. The interface implementation must populate `spmdizationMap`
        with the mapping for this op's results.

        `builder` is set to insert new operations at the appropriate point.
        The implementation should not restore the builder to the original
        insertion point; it should leave it as is after all insertions are
        done.

        The default implementation does full replication and always
        succeeds. This assumes that all sharding annotations are for full
        replication.
      }],
      /*retTy=*/"LogicalResult",
      /*methodName=*/"spmdize",
      /*args=*/(ins
        "ArrayRef<Value>":$spmdizedOperands,
        "ArrayRef<MeshSharding>":$operandShardings,
        "ArrayRef<MeshSharding>":$resultShardings,
        "IRMapping&":$spmdizationMap,
        "SymbolTableCollection &":$symbolTableCollection,
        "OpBuilder &":$builder
      ),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        spmdizeFullyReplicatedOperation(
          *$_op.getOperation(), spmdizedOperands, operandShardings,
          resultShardings, spmdizationMap, symbolTableCollection, builder);
        return success();
      }]>
  ];

  let extraClassDeclaration = [{
    LogicalResult verifyShardingInterfaceImpl();

    void printLoopTypesAndIndexingMaps(raw_ostream &os);
  }];
}

#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD