//===- LinalgStructuredOps.td - Linalg dialect library ops -*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the operation definition file for structured operations on buffers
// that correspond to underlying library calls (e.g. BLAS).
//
//===----------------------------------------------------------------------===//

#ifndef LINALG_STRUCTURED_OPS
#define LINALG_STRUCTURED_OPS

include "mlir/Dialect/Linalg/IR/LinalgBase.td"
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpAsmInterface.td"

// Base Tablegen class for Linalg ops.
// Linalg ops that correspond to library calls operate on ShapedType as their
// first operands. These may be optionally followed by non-view operands
// depending on the specific Linalg op.
class LinalgStructuredBase_Op<string mnemonic, list<Trait> props>
  : Op<Linalg_Dialect, mnemonic, !listconcat([
       SingleBlockImplicitTerminator<"YieldOp">,
       DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
       DeclareOpInterfaceMethods<ConditionallySpeculatable>,
       RecursiveMemoryEffects,
       DestinationStyleOpInterface,
       LinalgStructuredInterface,
       ReifyRankedShapedTypeOpInterface], props)> {
  // C++ declarations shared by all structured ops; spliced into each op's
  // extraClassDeclaration below.
  code structuredOpsBaseDecls = [{
    // Return whether the op accesses the iteration indices.
    bool hasIndexSemantics() {
      return !this->getBody()->getOps<IndexOp>().empty();
    }

    // Delegate shape reification to the LinalgOp interface implementation.
    LogicalResult reifyResultShapes(OpBuilder &b,
        ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
      return llvm::cast<LinalgOp>(getOperation()).reifyResultShapes(b,
          reifiedReturnShapes);
    }
  }];
}

//===----------------------------------------------------------------------===//
// Generic Linalg ops.
//===----------------------------------------------------------------------===//

def GenericOp : LinalgStructuredBase_Op<"generic", [
    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>,
    AttrSizedOperandSegments]> {
  let description = [{
    Generic Linalg op form where the key properties of the computation are
    specified as attributes. In pretty form, a `linalg.generic` op is written
    as:

    ```mlir
    linalg.generic #trait_attribute
        ins(%A, %B : memref<?x?xf32, stride_specification>,
                     memref<?x?xf32, stride_specification>)
        outs(%C : memref<?x?xf32, stride_specification>)
        attrs = {other-optional-attributes}
        {region}
    ```

    Where #trait_attributes is an alias of a dictionary attribute containing:
      - doc [optional]: a documentation string
      - indexing_maps: a list of AffineMapAttr, one AffineMapAttr per each input
        and output view. Such AffineMapAttr specifies the mapping between the
        loops and the indexing within each view.
      - library_call [optional]: a StringAttr containing the name of an
        external library function that the linalg.generic operation maps to.
        The external library is assumed to be dynamically linked and no strong
        compile-time guarantees are provided. In the absence of such a library
        call, linalg.generic will always lower to loops.
      - iterator_types: an ArrayAttr specifying the type of the enclosing loops.
        Each element of the list represents an iterator of one of the following
        types:
          parallel, reduction, window

    Example:
    Defining a #matmul_trait attribute in MLIR can be done as follows:
    ```mlir
    #matmul_accesses = [
      (m, n, k) -> (m, k),
      (m, n, k) -> (k, n),
      (m, n, k) -> (m, n)
    ]
    #matmul_trait = {
      doc = "C(m, n) += A(m, k) * B(k, n)",
      indexing_maps = #matmul_accesses,
      library_call = "linalg_matmul",
      iterator_types = ["parallel", "parallel", "reduction"]
    }
    ```

    And can be reused in multiple places as:
    ```mlir
    linalg.generic #matmul_trait
      ins(%A, %B : memref<?x?xf32, stride_specification>,
                   memref<?x?xf32, stride_specification>)
      outs(%C : memref<?x?xf32, stride_specification>)
      {other-optional-attributes} {
      ^bb0(%a: f32, %b: f32, %c: f32) :
        %d = arith.mulf %a, %b: f32
        %e = arith.addf %c, %d: f32
        linalg.yield %e : f32
    }
    ```

    This may lower to either:
    ```mlir
    call @linalg_matmul(%A, %B, %C) :
      (memref<?x?xf32, stride_specification>,
       memref<?x?xf32, stride_specification>,
       memref<?x?xf32, stride_specification>)
      -> ()
    ```

    or IR resembling:
    ```mlir
    scf.for %m = %c0 to %M step %c1 {
      scf.for %n = %c0 to %N step %c1 {
        scf.for %k = %c0 to %K step %c1 {
          %a = load %A[%m, %k] : memref<?x?xf32, stride_specification>
          %b = load %B[%k, %n] : memref<?x?xf32, stride_specification>
          %c = load %C[%m, %n] : memref<?x?xf32, stride_specification>
          %d = arith.mulf %a, %b: f32
          %e = arith.addf %c, %d: f32
          store %e, %C[%m, %n] : memref<?x?xf32, stride_specification>
        }
      }
    }
    ```

    To allow progressive lowering from the value world (a.k.a tensor values) to
    the buffer world (a.k.a memref values), a `linalg.generic` op allows mixing
    tensors and buffers operands and tensor results.

    ```mlir
    %C = linalg.generic #trait_attribute
      ins(%A, %B : tensor<?x?xf32>, memref<?x?xf32, stride_specification>)
      outs(%C : tensor<?x?xf32>)
      {other-optional-attributes}
      {region}
      -> (tensor<?x?xf32>)
    ```
  }];

  let arguments = (ins Variadic<AnyType>:$inputs,
                       Variadic<AnyShaped>:$outputs,
                       AffineMapArrayAttr:$indexing_maps,
                       IteratorTypeArrayAttr:$iterator_types,
                       OptionalAttr<StrAttr>:$doc,
                       OptionalAttr<StrAttr>:$library_call);
  let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
  let regions = (region AnyRegion:$region);

  let builders = [
    OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
      "ValueRange":$outputs, "ArrayAttr":$indexingMaps,
      "ArrayAttr":$iteratorTypes, "StringAttr":$doc,
      "StringAttr":$libraryCall,
      "function_ref<void(OpBuilder &, Location, ValueRange)>",
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
    OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
      "ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
      "ArrayRef<utils::IteratorType>":$iteratorTypes, "StringRef":$doc,
      "StringRef":$libraryCall,
      CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
    OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers,
      "ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<utils::IteratorType>":$iteratorTypes,
      "StringRef":$doc, "StringRef":$libraryCall,
      CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
    OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
      "ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
      "ArrayRef<utils::IteratorType>":$iteratorTypes,
      CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
    OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers,
      "ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<utils::IteratorType>":$iteratorTypes,
      CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
  ];

  let extraClassDeclaration = structuredOpsBaseDecls # [{
    SmallVector<StringRef, 8> linalgTraitAttrNames() {
      return SmallVector<StringRef, 8>{
        getDocAttrName(),
        getIndexingMapsAttrName(), getLibraryCallAttrName(),
        getIteratorTypesAttrName(),
      };
    }
    std::string getLibraryCallName() {
      return getLibraryCall() ?
        getLibraryCall()->str() : "op_has_no_registered_library_name";
    }

    static std::function<void(ImplicitLocOpBuilder &,
                              Block &, ArrayRef<NamedAttribute>)>
    getRegionBuilder() {
      return nullptr;
    }

    MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }

    // Return true only if GenericOp has a single input and single
    // output, and the body is a single yieldOp that yields the input.
    // This check is useful when trying to determine if the op is
    // essentially a transpose, broadcast, copy or something like that.
    bool isSingleYieldOp() {
      if (!isSingleInputOutput())
        return false;
      Block *body = getBody();
      if (body->getOperations().size() != 1)
        return false;

      auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
      if (!yieldOp || yieldOp.getNumOperands() != 1 ||
          yieldOp->getOperand(0) != body->getArgument(0))
        return false;
      return true;
    }
  }];

  let hasCanonicalizer = 1;
  let hasCustomAssemblyFormat = 1;
  let hasFolder = 1;
  let hasVerifier = 1;
}


//===----------------------------------------------------------------------===//
// Map op.
243//===----------------------------------------------------------------------===// 244 245def TensorOrMemref : 246 AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">; 247 248def MapOp : LinalgStructuredBase_Op<"map", [ 249 DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>, 250 DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>, 251 SingleBlockImplicitTerminator<"YieldOp">]> { 252 let summary = "Elementwise operations"; 253 let description = [{ 254 Models elementwise operations on tensors in terms of arithmetic operations 255 on the corresponding elements. 256 257 Example: 258 ``` 259 %add = linalg.map 260 ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>) 261 outs(%init: tensor<64xf32>) 262 (%lhs_elem: f32, %rhs_elem: f32) { 263 %0 = arith.addf %lhs_elem, %rhs_elem: f32 264 linalg.yield %0: f32 265 } 266 ``` 267 268 Shortened print form is available. Applies to simple maps with one 269 non-yield operation inside the body. 270 271 The example above will be printed as: 272 ``` 273 %add = linalg.map { arith.addf } 274 ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>) 275 outs(%init: tensor<64xf32>) 276 ``` 277 }]; 278 279 let arguments = (ins 280 // Input args 281 Variadic<TensorOrMemref>:$inputs, 282 283 // Output arg 284 TensorOrMemref:$init 285 ); 286 let results = (outs Variadic<AnyTensor>:$result); 287 let regions = (region SizedRegion<1>:$mapper); 288 289 let builders = [ 290 OpBuilder<(ins "ValueRange":$inputs, "Value":$init, 291 "function_ref<void(OpBuilder &, Location, ValueRange)>", 292 CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)> 293 ]; 294 295 let extraClassDeclaration = structuredOpsBaseDecls # [{ 296 // Implement functions necessary for LinalgStructuredInterface. 
297 SmallVector<utils::IteratorType> getIteratorTypesArray(); 298 ArrayAttr getIndexingMaps(); 299 std::string getLibraryCallName() { 300 return "op_has_no_registered_library_name"; 301 } 302 303 // Implement functions necessary for DestinationStyleOpInterface. 304 MutableOperandRange getDpsInitsMutable() { return getInitMutable(); } 305 306 SmallVector<OpOperand *> getOpOperandsMatchingBBargs() { 307 return getDpsInputOperands(); 308 } 309 310 bool payloadUsesValueFromOperand(OpOperand * opOperand) { 311 if (isDpsInit(opOperand)) return false; 312 return !getMatchingBlockArgument(opOperand).use_empty(); 313 } 314 315 static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &, 316 mlir::ArrayRef<mlir::NamedAttribute>)> 317 getRegionBuilder() { 318 return nullptr; 319 } 320 }]; 321 322 let hasCustomAssemblyFormat = 1; 323 let hasVerifier = 1; 324} 325 326 327//===----------------------------------------------------------------------===// 328// Reduce op. 329//===----------------------------------------------------------------------===// 330 331def ReduceOp : LinalgStructuredBase_Op<"reduce", [ 332 DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>, 333 DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>, 334 SameVariadicOperandSize, 335 SingleBlockImplicitTerminator<"YieldOp">]> { 336 let summary = "Reduce operator"; 337 let description = [{ 338 Executes `combiner` on the `dimensions` of `inputs` and returns the 339 reduced result. The `dimensions` attribute needs to list the reduction 340 dimensions in increasing order. 341 342 Example: 343 ``` 344 %reduce = linalg.reduce 345 ins(%input:tensor<16x32x64xf32>) 346 outs(%init:tensor<16x64xf32>) 347 dimensions = [1] 348 (%in: f32, %out: f32) { 349 %0 = arith.addf %out, %in: f32 350 linalg.yield %0: f32 351 } 352 ``` 353 354 Shortened print form is available. Applies to simple (not variadic) reduces 355 with one non-yield operation inside the body. 
Applies only if the operation 356 takes `%out` as the first argument. 357 358 The example above will be printed as: 359 ``` 360 %reduce = linalg.reduce { arith.addf } 361 ins(%input:tensor<16x32x64xf32>) 362 outs(%init:tensor<16x64xf32>) 363 dimensions = [1] 364 ``` 365 }]; 366 367 let arguments = (ins 368 // Input arg 369 Variadic<TensorOrMemref>:$inputs, 370 // Output arg 371 Variadic<TensorOrMemref>:$inits, 372 373 ConfinedAttr<DenseI64ArrayAttr, 374 [DenseArrayStrictlySorted<DenseI64ArrayAttr>]>:$dimensions 375 ); 376 let results = (outs Variadic<AnyTensor>); 377 let regions = (region SizedRegion<1>:$combiner); 378 379 let builders = [ 380 OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$inits, 381 "ArrayRef<int64_t>":$dimensions, 382 "function_ref<void(OpBuilder &, Location, ValueRange)>", 383 CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)> 384 ]; 385 386 let extraClassDeclaration = structuredOpsBaseDecls # [{ 387 // Declare functions necessary for LinalgStructuredInterface. 388 SmallVector<utils::IteratorType> getIteratorTypesArray(); 389 ArrayAttr getIndexingMaps(); 390 std::string getLibraryCallName() { 391 return "op_has_no_registered_library_name"; 392 } 393 394 // Implement functions necessary for DestinationStyleOpInterface. 395 static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &, 396 mlir::ArrayRef<mlir::NamedAttribute>)> 397 getRegionBuilder() { 398 return nullptr; 399 } 400 MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); } 401 }]; 402 403 let hasCustomAssemblyFormat = 1; 404 let hasVerifier = 1; 405} 406 407 408//===----------------------------------------------------------------------===// 409// Transpose op. 
410//===----------------------------------------------------------------------===// 411 412def TransposeOp : LinalgStructuredBase_Op<"transpose", [ 413 DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>, 414 SingleBlockImplicitTerminator<"YieldOp">]> { 415 let summary = "Transpose operator"; 416 let description = [{ 417 Permutes the dimensions of `input` according to the given `permutation`. 418 `dim(result, i) = dim(input, permutation[i])` 419 420 This op actually moves data, unlike `memref.transpose` which is a metadata 421 operation only that produces a transposed "view". 422 423 Example: 424 ``` 425 %transpose = linalg.transpose 426 ins(%input:tensor<16x64xf32>) 427 outs(%init:tensor<64x16xf32>) 428 permutation = [1, 0] 429 ``` 430 }]; 431 432 let arguments = (ins 433 // Input arg 434 TensorOrMemref:$input, 435 // Output arg 436 TensorOrMemref:$init, 437 438 DenseI64ArrayAttr:$permutation 439 ); 440 let results = (outs Variadic<AnyTensor>:$result); 441 let regions = (region SizedRegion<1>:$region); 442 443 let skipDefaultBuilders = 1; 444 let builders = [ 445 OpBuilder<(ins "Value":$input, "Value":$init, 446 "DenseI64ArrayAttr":$permutation, CArg<"ArrayRef<NamedAttribute>", 447 "{}">:$attributes)>, 448 OpBuilder<(ins "Value":$input, "Value":$init, 449 "ArrayRef<int64_t>":$permutation, CArg<"ArrayRef<NamedAttribute>", 450 "{}">:$attributes)>, 451 ]; 452 453 let extraClassDeclaration = structuredOpsBaseDecls # [{ 454 // Declare functions necessary for LinalgStructuredInterface. 455 SmallVector<utils::IteratorType> getIteratorTypesArray(); 456 ArrayAttr getIndexingMaps(); 457 std::string getLibraryCallName() { 458 return "op_has_no_registered_library_name"; 459 } 460 461 // Implement functions necessary for DestinationStyleOpInterface. 
462 MutableOperandRange getDpsInitsMutable() { return getInitMutable(); } 463 464 static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block, 465 mlir::ArrayRef<mlir::NamedAttribute>) { 466 OpBuilder::InsertionGuard guard(b); 467 b.create<linalg::YieldOp>(b.getLoc(), block.getArgument(0)); 468 } 469 470 static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &, 471 mlir::ArrayRef<mlir::NamedAttribute>)> 472 getRegionBuilder() { 473 return regionBuilder; 474 } 475 }]; 476 477 let hasFolder = 1; 478 let hasCanonicalizer = 1; 479 let hasCustomAssemblyFormat = 1; 480 let hasVerifier = 1; 481} 482 483 484//===----------------------------------------------------------------------===// 485// Broadcast op. 486//===----------------------------------------------------------------------===// 487 488def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [ 489 DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>, 490 SingleBlockImplicitTerminator<"YieldOp">]> { 491 let summary = "Static broadcast operator"; 492 let description = [{ 493 Broadcast the input into the given shape by adding `dimensions`. 
494 495 Example: 496 ``` 497 %bcast = linalg.broadcast 498 ins(%input:tensor<16xf32>) 499 outs(%init:tensor<16x64xf32>) 500 dimensions = [1] 501 ``` 502 }]; 503 504 let arguments = (ins 505 // Input arg 506 TensorOrMemref:$input, 507 // Output arg 508 TensorOrMemref:$init, 509 510 DenseI64ArrayAttr:$dimensions 511 ); 512 let results = (outs Variadic<AnyTensor>:$result); 513 let regions = (region SizedRegion<1>:$region); 514 515 let skipDefaultBuilders = 1; 516 let builders = [ 517 OpBuilder<(ins "Value":$input, "Value":$init, 518 "DenseI64ArrayAttr":$dimensions, CArg<"ArrayRef<NamedAttribute>", 519 "{}">:$attributes)>, 520 OpBuilder<(ins "Value":$input, "Value":$init, 521 "ArrayRef<int64_t>":$dimensions, CArg<"ArrayRef<NamedAttribute>", 522 "{}">:$attributes)>, 523 ]; 524 525 let extraClassDeclaration = structuredOpsBaseDecls # [{ 526 // Declare functions necessary for LinalgStructuredInterface. 527 SmallVector<utils::IteratorType> getIteratorTypesArray(); 528 ArrayAttr getIndexingMaps(); 529 std::string getLibraryCallName() { 530 return "op_has_no_registered_library_name"; 531 } 532 533 // Implement functions necessary for DestinationStyleOpInterface. 
534 MutableOperandRange getDpsInitsMutable() { return getInitMutable(); } 535 536 static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block, 537 mlir::ArrayRef<mlir::NamedAttribute>) { 538 OpBuilder::InsertionGuard guard(b); 539 b.create<linalg::YieldOp>(b.getLoc(), block.getArgument(0)); 540 } 541 542 static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &, 543 mlir::ArrayRef<mlir::NamedAttribute>)> 544 getRegionBuilder() { 545 return regionBuilder; 546 } 547 }]; 548 549 let hasCustomAssemblyFormat = 1; 550 let hasVerifier = 1; 551 let hasCanonicalizer = 1; 552} 553 554//===----------------------------------------------------------------------===// 555// Op definition for MatmulOp 556//===----------------------------------------------------------------------===// 557 558def MatmulOp : LinalgStructuredBase_Op<"matmul", [ 559 AttrSizedOperandSegments, 560 LinalgContractionOpInterface]> { 561 562 let summary = [{ 563 Performs a matrix multiplication of two 2D inputs without broadcast or transpose. 564 }]; 565 let description = [{ 566 Numeric casting is performed on the operands to the inner multiply, 567 promoting them to the same data type as the accumulator/output. 568 569 Broadcast and Transpose semantics can be appiled by specifying the explicit attribute 570 'indexing_maps' as shown below.This is a list attribute, so the list must include all 571 the maps if specified. 
572 573 Example Transpose: 574 ``` 575 linalg.matmul indexing_maps = [ 576 affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose 577 affine_map<(d0, d1, d2) -> (d2, d1)>, 578 affine_map<(d0, d1, d2) -> (d0, d1)> 579 ] 580 ins(%arg0, %arg1 : memref<5x3xf32>,memref<5x7xf32>) 581 outs(%arg2: memref<3x7xf32>) 582 ``` 583 584 Example Broadcast: 585 ``` 586 linalg.matmul indexing_maps = [ 587 affine_map<(d0, d1, d2) -> (d2)>, // broadcast 588 affine_map<(d0, d1, d2) -> (d2, d1)>, 589 affine_map<(d0, d1, d2) -> (d0, d1)> 590 ] 591 ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>) 592 outs(%arg2: memref<3x7xf32>) 593 ``` 594 595 Example Broadcast and transpose: 596 ``` 597 linalg.matmul indexing_maps = [ 598 affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose 599 affine_map<(d0, d1, d2) -> (d2)>, // broadcast 600 affine_map<(d0, d1, d2) -> (d0, d1)> 601 ] 602 ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>) 603 ``` 604 }]; 605 606 let arguments = (ins 607 Variadic<AnyType>:$inputs, 608 Variadic<AnyShaped>:$outputs, 609 DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps, 610 DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast 611 ); 612 let results = (outs Variadic<AnyRankedTensor>:$result_tensors); 613 let regions = (region AnyRegion:$region); 614 615 let skipDefaultBuilders = 1; 616 let builders = [ 617 OpBuilder< 618 (ins "ValueRange":$inputs, "ValueRange":$outputs, 619 CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes), 620 [{ 621 buildMatmulOp($_builder, $_state, std::nullopt, inputs, outputs, 622 attributes, MatmulOp::getRegionBuilder(), 623 MatmulOp::getDefaultIndexingMaps($_builder.getContext())); 624 }]>, 625 OpBuilder< 626 (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, 627 "ValueRange":$outputs, 628 CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes), 629 [{ 630 buildMatmulOp($_builder, $_state, resultTensorTypes, 631 inputs, outputs, attributes, MatmulOp::getRegionBuilder(), 632 
MatmulOp::getDefaultIndexingMaps($_builder.getContext())); 633 }]>, 634 OpBuilder< 635 (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, 636 "ValueRange":$outputs, 637 "Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes), 638 [{ 639 $_state.addAttribute("cast", cast); 640 buildMatmulOp($_builder, $_state, resultTensorTypes, inputs, outputs, 641 attributes, MatmulOp::getRegionBuilder(), 642 MatmulOp::getDefaultIndexingMaps($_builder.getContext())); 643 }]> 644 645 ]; 646 let hasCustomAssemblyFormat = 1; 647 let hasFolder = 1; 648 let hasVerifier = 1; 649 650 let extraClassDeclaration = structuredOpsBaseDecls # [{ 651 SmallVector<utils::IteratorType> getIteratorTypesArray(); 652 653 /// Implements the block region builder. 654 static void regionBuilder(ImplicitLocOpBuilder &b, 655 Block &block, ArrayRef<NamedAttribute> attrs); 656 657 /// Returns a list of AffineMap with the typical matmul indexing charactristic. 658 static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context); 659 660 /// Returns true if the given broadcast map \p bcastMap is valid for this op. 661 bool isValidLhsRhsBroadcastMap(AffineMap bcastMap); 662 663 static std::function<void(ImplicitLocOpBuilder &, 664 Block &, ArrayRef<NamedAttribute>)> 665 getRegionBuilder() { 666 return regionBuilder; 667 } 668 669 ::mlir::MutableOperandRange getDpsInitsMutable() { 670 return getOutputsMutable(); 671 } 672 673 // Generic methods. 674 static unsigned getNumRegionArgs(); 675 std::string getLibraryCallName(); 676 bool hasDynamicIndexingMaps(); 677 /// Check if the op has broadcast and/or transpose semantic. Returns true if the 678 /// user defined indexing maps are not equal to default map. 679 bool hasUserDefinedMaps(); 680 }]; 681} 682 683//===----------------------------------------------------------------------===// 684// Contract op. 
685//===----------------------------------------------------------------------===// 686 687def ContractOp : LinalgStructuredBase_Op<"contract", [ 688 AttrSizedOperandSegments, 689 LinalgContractionOpInterface]> { 690 let summary = [{ 691 Perform a contraction on two inputs, accumulating into the third. 692 }]; 693 let description = [{ 694 The semantics of contracting inputs `A` and `B` on top of `C` to produce 695 output `D` is given by 696 697 `D[H] = (SUM_{(I ∪ J) \ H} A[I] * B[J]) + C[H]` 698 699 where `I`, `J`, and `H` are tuples of (pairwise distinct) dimension 700 identifiers - meant to range over valid indices - corresponding to the 701 results of the mandatory (projected permutation) `indexing_maps` for `A`, 702 `B` and `C`. `SUM_{dims}` means reduce over all valid indices for the 703 dimensions in the set `dims` (with `I`, `J`, and `K` treated as _sets_ of 704 dim identifiers). 705 706 The iteration space consists of all dimensions in `I`, `J` and `H`, i.e. the 707 domain of each of the `affine_map`s. Like for einsums, the iteration type of 708 each dim is inferred and is either: 709 710 - reduction: the dim is used to index into `A` and `B` but not `C`. Per the 711 above semantics, these dims will be contracted, i.e. reduced over. 712 713 - parallel: the dim is used to index into `C` and at least one of `A` and 714 `B`, and - deriving from matmul terminology - is either an "M-like" dim 715 (if used on `A` and `C`), an "N-like" dim (if used on `B` and `C`) or a 716 "batch"-dim (if used to index into `A`, `B`, and `C`). 
717 718 For example, batch-matmul is given by `I = ⟨ b, m, k ⟩`, `J = ⟨ b, k, n ⟩`, 719 `H = ⟨ b, m, n ⟩` (with `k` as a contracting reduction-dimension while `m`, 720 `n` and `b` have parallel iteration-type) and gets represented as: 721 722 ``` 723 %D = linalg.contract 724 indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>, 725 affine_map<(batch, m, n, k) -> (batch, k, n)>, 726 affine_map<(batch, m, n, k) -> (batch, m, n)>] 727 ins(%A, %B: tensor<?x?x?xf32>, tensor<?x?x?xf32>) 728 outs(%C: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> 729 ``` 730 731 Note that by permuting dims in the `affine_map`s' results, accesses to 732 to the inputs and output can be arbitrarily transposed. Similarly, arbitrary 733 broadcasts can be achieved through leaving out dims on either input operand. 734 For example, the following is a variant of batch-matmul with a transposition 735 applied to `A` while `B`'s 2D-matrix gets broadcasted along the batch dim: 736 737 ``` 738 linalg.contract 739 indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k, m)>, 740 affine_map<(batch, m, n, k) -> (k, n)>, 741 affine_map<(batch, m, n, k) -> (batch, m, n)>] 742 ins(%A, %B: memref<?x?x?xf32>, memref<?x?xf32>) 743 outs(%C: memref<?x?x?xf32>) 744 ``` 745 746 Numeric casting is performed on the operands to the inner multiplication, 747 promoting/truncating them to the same data type as the accumulator/output. 748 749 TODO: Allow control over the combining/accumulating op and possibly the 750 multiplication op. 751 }]; 752 753 let arguments = (ins 754 Variadic<AnyType>:$inputs, 755 Variadic<AnyShaped>:$outputs, 756 AffineMapArrayAttr:$indexing_maps 757 ); 758 let results = (outs Variadic<AnyShaped>:$result_tensors); 759 // NB: The only reason this op has a region - and it get populated at op build 760 // time - is that currently the LinalgOp interface exposes methods that 761 // assume a relevant region is available to be queried at any time. 
762 let regions = (region SizedRegion<1>:$combiner); 763 764 let skipDefaultBuilders = 1; 765 let builders = [ 766 OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, 767 "ValueRange":$outputs, "ArrayAttr":$indexingMaps, 768 CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes), 769 [{ 770 $_state.addAttribute("indexing_maps", indexingMaps); 771 buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, 772 outputs, attributes, regionBuilder); 773 }]>, 774 OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs, 775 "ArrayAttr":$indexingMaps, 776 CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes), 777 [{ 778 $_state.addAttribute("indexing_maps", indexingMaps); 779 buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs, 780 attributes, regionBuilder); 781 }]> 782 ]; 783 let hasCustomAssemblyFormat = 1; 784 let hasFolder = 1; 785 let hasVerifier = 1; 786 787 let extraClassDeclaration = structuredOpsBaseDecls # [{ 788 // Declare/implement functions necessary for LinalgStructuredInterface. 789 790 /// Infer iterator types for each dim in the domain of IndexingMaps. 791 SmallVector<utils::IteratorType> getIteratorTypesArray(); 792 793 /// IndexingMaps always depends on attr associated to current Op instance. 794 bool hasDynamicIndexingMaps() { return true; }; 795 bool hasUserDefinedMaps() { return true; }; 796 797 static unsigned getNumRegionArgs(); 798 799 static void regionBuilder(ImplicitLocOpBuilder &b, 800 Block &block, ArrayRef<NamedAttribute> attrs); 801 802 static std::function<void(ImplicitLocOpBuilder &, 803 Block &, ArrayRef<NamedAttribute>)> 804 getRegionBuilder() { 805 return regionBuilder; 806 } 807 808 std::string getLibraryCallName() { 809 return "op_has_no_registered_library_name"; 810 } 811 812 // Implement function necessary for DestinationStyleOpInterface. 
813 ::mlir::MutableOperandRange getDpsInitsMutable() { 814 return getOutputsMutable(); 815 } 816 }]; 817} 818 819//===----------------------------------------------------------------------===// 820// Named Linalg ops, implemented as a declarative configurations of generic ops. 821//===----------------------------------------------------------------------===// 822 823include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.td" 824 825#endif // LINALG_STRUCTURED_OPS 826