//===- LinalgOps.td - Linalg dialect ops -------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the operation definition file for linear algebra operations.
//
//===----------------------------------------------------------------------===//

#ifndef LINALG_OPS
#define LINALG_OPS

include "mlir/Dialect/Linalg/IR/LinalgBase.td"
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/TilingInterface.td"
include "mlir/Interfaces/ViewLikeInterface.td"

// Base class for Linalg dialect ops that do not correspond to library calls.
class Linalg_Op<string mnemonic, list<Trait> traits = []> :
    Op<Linalg_Dialect, mnemonic, traits>;

//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//

def Linalg_YieldOp : Linalg_Op<"yield", [Pure, ReturnLike, Terminator]>,
    Arguments<(ins Variadic<AnyType>:$values)> {
  let summary = "Linalg yield operation";
  let description = [{
    `linalg.yield` is a special terminator operation for blocks inside regions
    in `linalg` generic ops. It returns values to the immediately enclosing
    `linalg` generic op.

    Example:

    ```mlir
    linalg.yield %f0, %f1 : f32, f32
    ```
  }];
  // Argument-less builder: creates a yield that returns no values.
  let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];
  // The parser/printer and the verifier are hand-written in C++.
  let hasCustomAssemblyFormat = 1;
  let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// IndexOp
//===----------------------------------------------------------------------===//

// `dim` is constrained to be non-negative at the attribute level.
def Linalg_IndexOp : Linalg_Op<"index", [Pure]>,
    Arguments<(ins ConfinedAttr<I64Attr, [IntMinValue<0>]>:$dim)>,
    Results<(outs Index:$result)> {
  let summary = "linalg index operation";
  let description = [{
    The `linalg.index` operation returns the iteration index of the immediately
    enclosing linalg structured operation for the iteration dimension `dim`. The
    `dim` attribute specifies the position of the accessed dimension in the
    indexing map domain.

    Example:

    ```mlir
    #map = affine_map<(i, j) -> (i, j)>
    linalg.generic {indexing_maps = [#map, #map],
                    iterator_types = ["parallel", "parallel"]}
      outs(%I, %J : memref<?x?xindex>, memref<?x?xindex>) {
      ^bb0(%arg0 : index, %arg1 : index):
      // Access the outer iteration dimension i
      %i = linalg.index 0 : index
      // Access the inner iteration dimension j
      %j = linalg.index 1 : index
      linalg.yield %i, %j : index, index
    }
    ```

    This may lower to IR resembling:

    ```mlir
    %0 = memref.dim %I, %c0 : memref<?x?xindex>
    %1 = memref.dim %I, %c1 : memref<?x?xindex>
    scf.for %i = %c0 to %0 step %c1 {
      scf.for %j = %c0 to %1 step %c1 {
        memref.store %i, %I[%i, %j] : memref<?x?xindex>
        memref.store %j, %J[%i, %j] : memref<?x?xindex>
      }
    }
    ```
  }];

  let assemblyFormat = [{ $dim attr-dict `:` type($result) }];
  // Any validation beyond the attribute constraint is done in the C++ verifier.
  let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// SoftmaxOp
//===----------------------------------------------------------------------===//

def Linalg_SoftmaxOp : Linalg_Op<"softmax",
    [DestinationStyleOpInterface,
     PredOpTrait<"input and output have same element type", TCopVTEtIsSameAs<0, 1>>,
     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
     DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
     DeclareOpInterfaceMethods<TilingInterface,
      ["getIterationDomain",
       "getLoopIteratorTypes",
       "getResultTilePosition",
       "getTiledImplementation"]>]> {
  let summary = "Softmax operator";
  let description = [{
    `linalg.softmax` computes a numerically stable version of softmax.

    For a given input tensor and a specified dimension `d`, compute:
      1. the max `m` along dimension `d`;
      2. f(x) = exp(x - m);
      3. the sum of f(x) along dimension `d`, giving l(x);
      4. the final result f(x) / l(x).

    This is an aggregate linalg operation that further reduces to a small DAG of
    structured operations.

    Warning: Regarding the tiling capabilities, the implementation doesn't
    check that the provided dimensions make sense. It is the responsibility
    of the transformation calling the tiling to ensure that the provided
    sizes for each dimension make sense with respect to the semantics of
    softmax.
  }];

  let arguments = (ins AnyShaped:$input,
                       AnyShaped:$output,
                       I64Attr:$dimension
  );

  // Results are optional (variadic): the type list is only printed when
  // present, see the `(`->` type($result)^)?` group below.
  let results = (outs Variadic<AnyRankedTensor>:$result);
  let hasFolder = 1;
  let assemblyFormat = [{
    attr-dict
    `dimension` `(` $dimension `)`
    `ins` `(` $input `:` type($input) `)`
    `outs` `(` $output `:` type($output) `)`
    (`->` type($result)^)?
  }];

  let extraClassDeclaration = [{
    /// Returns the type of the `input` operand as a ShapedType.
    ShapedType getInputOperandType() {
      return cast<ShapedType>(getInput().getType());
    }
    /// Returns the type of the `output` operand as a ShapedType.
    ShapedType getOutputOperandType() {
      return cast<ShapedType>(getOutput().getType());
    }
    /// Returns the rank of the `input` operand.
    int64_t getInputOperandRank() {
      return getInputOperandType().getRank();
    }
    /// Returns the rank of the `output` operand.
    int64_t getOutputOperandRank() {
      return getOutputOperandType().getRank();
    }
    /// DestinationStyleOpInterface hook: the `output` operand is the init.
    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
  }];
  let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// WinogradFilterTransformOp
//===----------------------------------------------------------------------===//

def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
    [AllElementTypesMatch<["filter", "output"]>, DestinationStyleOpInterface,
     DeclareOpInterfaceMethods<TilingInterface,
      ["getIterationDomain",
       "getLoopIteratorTypes",
       "getResultTilePosition",
       "getTiledImplementation"]>]> {
  let summary = "Winograd filter transform operator";
  let description = [{
    The Winograd Conv2D algorithm converts a linalg Conv2D operator into a
    batched matrix multiply. Before the matrix multiply, it converts the filter
    and the input into a format suitable for batched matrix multiply. After the
    matrix multiply, it converts the output to the final result tensor.

    The algorithm F(m x m, r x r) is

      Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A

    The size of output Y is m x m. The size of filter g is r x r. The size of
    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
    transformation matrices.

    This operator is defined to represent the high-level concept of filter
    transformation (G x g x G^T) in the Winograd Conv2D algorithm.
  }];

  let arguments = (ins TensorRankOf<[AnyType], [4]>:$filter,
                       TensorRankOf<[AnyType], [4]>:$output,
                       I64Attr:$m,
                       I64Attr:$r
  );

  let results = (outs TensorRankOf<[AnyType], [4]>:$result);
  let assemblyFormat = [{
    attr-dict
    `m` `(` $m `)`
    `r` `(` $r `)`
    `ins` `(` $filter `:` type($filter) `)`
    `outs` `(` $output `:` type($output) `)`
    `->` type($result)
  }];
  let extraClassDeclaration = [{
    ShapedType getFilterOperandType() {
      return cast<ShapedType>(getFilter().getType());
    }
    ShapedType getOutputOperandType() {
      return cast<ShapedType>(getOutput().getType());
    }
    int64_t getFilterOperandRank() {
      return getFilterOperandType().getRank();
    }
    int64_t getOutputOperandRank() {
      return getOutputOperandType().getRank();
    }
    /// Positions of the F, H, W and C dimensions in the rank-4 `filter`
    /// operand (i.e. an FHWC layout).
    int64_t getFilterFDim() {
      return 0;
    }
    int64_t getFilterHDim() {
      return 1;
    }
    int64_t getFilterWDim() {
      return 2;
    }
    int64_t getFilterCDim() {
      return 3;
    }
    /// DestinationStyleOpInterface hook: the `output` operand is the init.
    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
  }];
  let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// WinogradInputTransformOp
//===----------------------------------------------------------------------===//

def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
    [AllElementTypesMatch<["input", "output"]>, DestinationStyleOpInterface,
     DeclareOpInterfaceMethods<TilingInterface,
      ["getIterationDomain",
       "getLoopIteratorTypes",
       "getResultTilePosition",
       "getTiledImplementation"]>]> {
  let summary = "Winograd input transform operator";
  let description = [{
    The Winograd Conv2D algorithm converts a linalg Conv2D operator into a
    batched matrix multiply. Before the matrix multiply, it converts the filter
    and the input into a format suitable for batched matrix multiply. After the
    matrix multiply, it converts the output to the final result tensor.

    The algorithm F(m x m, r x r) is

      Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A

    The size of output Y is m x m. The size of filter g is r x r. The size of
    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
    transformation matrices.

    This operator is defined to represent the high-level concept of input
    transformation (B^T x d x B) in the Winograd Conv2D algorithm.
  }];

  let arguments = (ins TensorRankOf<[AnyType], [4]>:$input,
                       TensorRankOf<[AnyType], [6]>:$output,
                       I64Attr:$m,
                       I64Attr:$r
  );

  let results = (outs TensorRankOf<[AnyType], [6]>:$result);
  let assemblyFormat = [{
    attr-dict
    `m` `(` $m `)`
    `r` `(` $r `)`
    `ins` `(` $input `:` type($input) `)`
    `outs` `(` $output `:` type($output) `)`
    `->` type($result)
  }];
  let extraClassDeclaration = [{
    ShapedType getInputOperandType() {
      return cast<ShapedType>(getInput().getType());
    }
    ShapedType getOutputOperandType() {
      return cast<ShapedType>(getOutput().getType());
    }
    int64_t getInputOperandRank() {
      return getInputOperandType().getRank();
    }
    int64_t getOutputOperandRank() {
      return getOutputOperandType().getRank();
    }
    /// Positions of the N, H, W and C dimensions in the rank-4 `input`
    /// operand (i.e. an NHWC layout).
    int64_t getInputNDim() {
      return 0;
    }
    int64_t getInputHDim() {
      return 1;
    }
    int64_t getInputWDim() {
      return 2;
    }
    int64_t getInputCDim() {
      return 3;
    }
    /// Positions of the dimensions in the rank-6 `output` operand, in order:
    /// alphaH, alphaW, tileH, tileW, N, C.
    int64_t getOutputAlphaHDim() {
      return 0;
    }
    int64_t getOutputAlphaWDim() {
      return 1;
    }
    int64_t getOutputTileHDim() {
      return 2;
    }
    int64_t getOutputTileWDim() {
      return 3;
    }
    int64_t getOutputNDim() {
      return 4;
    }
    int64_t getOutputCDim() {
      return 5;
    }
    /// DestinationStyleOpInterface hook: the `output` operand is the init.
    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
  }];
  let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// WinogradOutputTransformOp
//===----------------------------------------------------------------------===//

def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
    [AllElementTypesMatch<["value", "output"]>, DestinationStyleOpInterface,
     DeclareOpInterfaceMethods<TilingInterface,
      ["getIterationDomain",
       "getLoopIteratorTypes",
       "getResultTilePosition",
       "getTiledImplementation"]>]> {
  let summary = "Winograd output transform operator";
  let description = [{
    The Winograd Conv2D algorithm converts a linalg Conv2D operator into a
    batched matrix multiply. Before the matrix multiply, it converts the filter
    and the input into a format suitable for batched matrix multiply. After the
    matrix multiply, it converts the output to the final result tensor.

    The algorithm F(m x m, r x r) is

      Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A

    The size of output Y is m x m. The size of filter g is r x r. The size of
    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
    transformation matrices.

    This operator is defined to represent the high-level concept of output
    transformation (A^T x y x A) in the Winograd Conv2D algorithm.
  }];

  let arguments = (ins TensorRankOf<[AnyType], [6]>:$value,
                       TensorRankOf<[AnyType], [4]>:$output,
                       I64Attr:$m,
                       I64Attr:$r
  );

  let results = (outs TensorRankOf<[AnyType], [4]>:$result);
  let assemblyFormat = [{
    attr-dict
    `m` `(` $m `)`
    `r` `(` $r `)`
    `ins` `(` $value `:` type($value) `)`
    `outs` `(` $output `:` type($output) `)`
    `->` type($result)
  }];
  let extraClassDeclaration = [{
    ShapedType getValueOperandType() {
      return cast<ShapedType>(getValue().getType());
    }
    ShapedType getOutputOperandType() {
      return cast<ShapedType>(getOutput().getType());
    }
    int64_t getValueOperandRank() {
      return getValueOperandType().getRank();
    }
    int64_t getOutputOperandRank() {
      return getOutputOperandType().getRank();
    }
    /// Positions of the dimensions in the rank-6 `value` operand, in order:
    /// alphaH, alphaW, tileH, tileW, N, F.
    int64_t getValueAlphaHDim() {
      return 0;
    }
    int64_t getValueAlphaWDim() {
      return 1;
    }
    int64_t getValueTileHDim() {
      return 2;
    }
    int64_t getValueTileWDim() {
      return 3;
    }
    int64_t getValueNDim() {
      return 4;
    }
    int64_t getValueFDim() {
      return 5;
    }
    /// Positions of the N, H, W and F dimensions in the rank-4 `output`
    /// operand (i.e. an NHWF layout).
    int64_t getOutputNDim() {
      return 0;
    }
    int64_t getOutputHDim() {
      return 1;
    }
    int64_t getOutputWDim() {
      return 2;
    }
    int64_t getOutputFDim() {
      return 3;
    }
    /// DestinationStyleOpInterface hook: the `output` operand is the init.
    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
  }];
  let hasVerifier = 1;
}

#endif // LINALG_OPS