//===- LinalgStructuredOps.td - Linalg dialect library ops -*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the operation definition file for structured operations on buffers
// that correspond to underlying library calls (e.g. BLAS).
//
//===----------------------------------------------------------------------===//

#ifndef LINALG_STRUCTURED_OPS
#define LINALG_STRUCTURED_OPS

include "mlir/Dialect/Linalg/IR/LinalgBase.td"
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
// Base Tablegen class for Linalg ops.
// Linalg ops that correspond to library calls operate on ShapedType as their
// first operands. These may be optionally followed by non-view operands
// depending on the specific Linalg op.
class LinalgStructuredBase_Op<string mnemonic, list<Trait> props>
  : Op<Linalg_Dialect, mnemonic, !listconcat([
       SingleBlockImplicitTerminator<"YieldOp">,
       DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
       DeclareOpInterfaceMethods<ConditionallySpeculatable>,
       RecursiveMemoryEffects,
       DestinationStyleOpInterface,
       LinalgStructuredInterface,
       ReifyRankedShapedTypeOpInterface], props)> {
  // C++ declarations shared by all structured ops; spliced into each derived
  // op's extraClassDeclaration.
  code structuredOpsBaseDecls = [{
    // Return whether the op accesses the iteration indices.
    bool hasIndexSemantics() {
      return !this->getBody()->getOps<IndexOp>().empty();
    }

    LogicalResult reifyResultShapes(OpBuilder &b,
        ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
      return llvm::cast<LinalgOp>(getOperation()).reifyResultShapes(b,
          reifiedReturnShapes);
    }
  }];
}

//===----------------------------------------------------------------------===//
// Generic Linalg ops.
//===----------------------------------------------------------------------===//

def GenericOp : LinalgStructuredBase_Op<"generic", [
    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>,
    AttrSizedOperandSegments]> {
  let description = [{
    Generic Linalg op form where the key properties of the computation are
    specified as attributes. In pretty form, a `linalg.generic` op is written
    as:

      ```mlir
      linalg.generic #trait_attribute
          ins(%A, %B : memref<?x?xf32, stride_specification>,
                       memref<?x?xf32, stride_specification>)
          outs(%C : memref<?x?xf32, stride_specification>)
          attrs = {other-optional-attributes}
          {region}
      ```

    Where #trait_attributes is an alias of a dictionary attribute containing:
      - doc [optional]: a documentation string
      - indexing_maps: a list of AffineMapAttr, one AffineMapAttr per each input
        and output view. Such AffineMapAttr specifies the mapping between the
        loops and the indexing within each view.
      - library_call [optional]: a StringAttr containing the name of an
        external library function that the linalg.generic operation maps to.
        The external library is assumed to be dynamically linked and no strong
        compile-time guarantees are provided. In the absence of such a library
        call, linalg.generic will always lower to loops.
      - iterator_types: an ArrayAttr specifying the type of the enclosing loops.
        Each element of the list represents an iterator of one of the following
        types:
          parallel, reduction, window

    Example:
    Defining a #matmul_trait attribute in MLIR can be done as follows:
      ```mlir
      #matmul_accesses = [
        (m, n, k) -> (m, k),
        (m, n, k) -> (k, n),
        (m, n, k) -> (m, n)
      ]
      #matmul_trait = {
        doc = "C(m, n) += A(m, k) * B(k, n)",
        indexing_maps = #matmul_accesses,
        library_call = "linalg_matmul",
        iterator_types = ["parallel", "parallel", "reduction"]
      }
      ```

    And can be reused in multiple places as:
      ```mlir
      linalg.generic #matmul_trait
        ins(%A, %B : memref<?x?xf32, stride_specification>,
                     memref<?x?xf32, stride_specification>)
        outs(%C : memref<?x?xf32, stride_specification>)
        {other-optional-attributes} {
        ^bb0(%a: f32, %b: f32, %c: f32) :
          %d = arith.mulf %a, %b: f32
          %e = arith.addf %c, %d: f32
          linalg.yield %e : f32
      }
      ```

    This may lower to either:
      ```mlir
      call @linalg_matmul(%A, %B, %C) :
        (memref<?x?xf32, stride_specification>,
         memref<?x?xf32, stride_specification>,
         memref<?x?xf32, stride_specification>)
        -> ()
      ```

    or IR resembling:
    ```mlir
    scf.for %m = %c0 to %M step %c1 {
      scf.for %n = %c0 to %N step %c1 {
        scf.for %k = %c0 to %K step %c1 {
          %a = load %A[%m, %k] : memref<?x?xf32, stride_specification>
          %b = load %B[%k, %n] : memref<?x?xf32, stride_specification>
          %c = load %C[%m, %n] : memref<?x?xf32, stride_specification>
          %d = arith.mulf %a, %b: f32
          %e = arith.addf %c, %d: f32
          store %e, %C[%m, %n] : memref<?x?xf32, stride_specification>
        }
      }
    }
    ```

    To allow progressive lowering from the value world (a.k.a tensor values) to
    the buffer world (a.k.a memref values), a `linalg.generic` op allows mixing
    tensors and buffers operands and tensor results.

    ```mlir
    %C = linalg.generic #trait_attribute
      ins(%A, %B : tensor<?x?xf32>, memref<?x?xf32, stride_specification>)
      outs(%C : tensor<?x?xf32>)
      {other-optional-attributes}
      {region}
      -> (tensor<?x?xf32>)
    ```
  }];

  let arguments = (ins Variadic<AnyType>:$inputs,
                       Variadic<AnyShaped>:$outputs,
                       AffineMapArrayAttr:$indexing_maps,
                       IteratorTypeArrayAttr:$iterator_types,
                       OptionalAttr<StrAttr>:$doc,
                       OptionalAttr<StrAttr>:$library_call);
  let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
  let regions = (region AnyRegion:$region);

  let builders = [
    OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
      "ValueRange":$outputs, "ArrayAttr":$indexingMaps,
      "ArrayAttr":$iteratorTypes, "StringAttr":$doc,
      "StringAttr":$libraryCall,
      "function_ref<void(OpBuilder &, Location, ValueRange)>",
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
    OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
      "ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
      "ArrayRef<utils::IteratorType>":$iteratorTypes, "StringRef":$doc,
      "StringRef":$libraryCall,
      CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
    OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers,
      "ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<utils::IteratorType>":$iteratorTypes,
      "StringRef":$doc, "StringRef":$libraryCall,
      CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
    OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
      "ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
      "ArrayRef<utils::IteratorType>":$iteratorTypes,
      CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
    OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers,
      "ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<utils::IteratorType>":$iteratorTypes,
      CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
  ];

  let extraClassDeclaration = structuredOpsBaseDecls # [{
    SmallVector<StringRef, 8> linalgTraitAttrNames() {
      return SmallVector<StringRef, 8>{
        getDocAttrName(),
        getIndexingMapsAttrName(), getLibraryCallAttrName(),
        getIteratorTypesAttrName(),
      };
    }
    std::string getLibraryCallName() {
      return getLibraryCall() ?
        getLibraryCall()->str() : "op_has_no_registered_library_name";
    }

    static std::function<void(ImplicitLocOpBuilder &,
                              Block &, ArrayRef<NamedAttribute>)>
    getRegionBuilder() {
      return nullptr;
    }

    MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }

    // Return true only if GenericOp has a single input and single
    // output, and the body is a single yieldOp that yields the input.
    // This check is useful when trying to determine if the op is
    // essentially a transpose, broadcast, copy or something like that.
    bool isSingleYieldOp() {
      if (!isSingleInputOutput())
        return false;
      Block *body = getBody();
      if (body->getOperations().size() != 1)
        return false;

      auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
      if (!yieldOp || yieldOp.getNumOperands() != 1 ||
          yieldOp->getOperand(0) != body->getArgument(0))
        return false;
      return true;
    }
  }];

  let hasCanonicalizer = 1;
  let hasCustomAssemblyFormat = 1;
  let hasFolder = 1;
  let hasVerifier = 1;
}


//===----------------------------------------------------------------------===//
// Map op.
//===----------------------------------------------------------------------===//

// Constraint accepting any ranked tensor or any memref; the common operand
// type for the destination-style named ops below.
def TensorOrMemref :
  AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;

def MapOp : LinalgStructuredBase_Op<"map", [
    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>,
    SingleBlockImplicitTerminator<"YieldOp">]> {
  let summary = "Elementwise operations";
  let description = [{
    Models elementwise operations on tensors in terms of arithmetic operations
    on the corresponding elements.

    Example:
    ```
      %add = linalg.map
          ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
          outs(%init: tensor<64xf32>)
          (%lhs_elem: f32, %rhs_elem: f32) {
            %0 = arith.addf %lhs_elem, %rhs_elem: f32
            linalg.yield %0: f32
          }
    ```

    Shortened print form is available. Applies to simple maps with one
    non-yield operation inside the body.

    The example above will be printed as:
    ```
      %add = linalg.map { arith.addf }
          ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
          outs(%init: tensor<64xf32>)
    ```
  }];

  let arguments = (ins
    // Input args
    Variadic<TensorOrMemref>:$inputs,

    // Output arg
    TensorOrMemref:$init
  );
  let results = (outs Variadic<AnyTensor>:$result);
  let regions = (region SizedRegion<1>:$mapper);

  let builders = [
    OpBuilder<(ins "ValueRange":$inputs, "Value":$init,
      "function_ref<void(OpBuilder &, Location, ValueRange)>",
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
  ];

  let extraClassDeclaration = structuredOpsBaseDecls # [{
    // Implement functions necessary for LinalgStructuredInterface.
    SmallVector<utils::IteratorType> getIteratorTypesArray();
    ArrayAttr getIndexingMaps();
    std::string getLibraryCallName() {
      return "op_has_no_registered_library_name";
    }

    // Implement functions necessary for DestinationStyleOpInterface.
    MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }

    SmallVector<OpOperand *> getOpOperandsMatchingBBargs() {
      return getDpsInputOperands();
    }

    bool payloadUsesValueFromOperand(OpOperand * opOperand) {
      if (isDpsInit(opOperand)) return false;
      return !getMatchingBlockArgument(opOperand).use_empty();
    }

    static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
                              mlir::ArrayRef<mlir::NamedAttribute>)>
    getRegionBuilder() {
      return nullptr;
    }
  }];

  let hasCustomAssemblyFormat = 1;
  let hasVerifier = 1;
}


//===----------------------------------------------------------------------===//
// Reduce op.
//===----------------------------------------------------------------------===//

def ReduceOp : LinalgStructuredBase_Op<"reduce", [
    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>,
    SameVariadicOperandSize,
    SingleBlockImplicitTerminator<"YieldOp">]> {
  let summary = "Reduce operator";
  let description = [{
    Executes `combiner` on the `dimensions` of `inputs` and returns the
    reduced result. The `dimensions` attribute needs to list the reduction
    dimensions in increasing order.

    Example:
    ```
      %reduce = linalg.reduce
          ins(%input:tensor<16x32x64xf32>)
          outs(%init:tensor<16x64xf32>)
          dimensions = [1]
          (%in: f32, %out: f32) {
            %0 = arith.addf %out, %in: f32
            linalg.yield %0: f32
          }
    ```

    Shortened print form is available. Applies to simple (not variadic) reduces
    with one non-yield operation inside the body. Applies only if the operation
    takes `%out` as the first argument.

    The example above will be printed as:
    ```
          %reduce = linalg.reduce { arith.addf }
          ins(%input:tensor<16x32x64xf32>)
          outs(%init:tensor<16x64xf32>)
          dimensions = [1]
    ```
  }];

  let arguments = (ins
    // Input arg
    Variadic<TensorOrMemref>:$inputs,
    // Output arg
    Variadic<TensorOrMemref>:$inits,

    ConfinedAttr<DenseI64ArrayAttr,
                 [DenseArrayStrictlySorted<DenseI64ArrayAttr>]>:$dimensions
  );
  let results = (outs Variadic<AnyTensor>);
  let regions = (region SizedRegion<1>:$combiner);

  let builders = [
    OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$inits,
      "ArrayRef<int64_t>":$dimensions,
      "function_ref<void(OpBuilder &, Location, ValueRange)>",
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
  ];

  let extraClassDeclaration = structuredOpsBaseDecls # [{
    // Declare functions necessary for LinalgStructuredInterface.
    SmallVector<utils::IteratorType> getIteratorTypesArray();
    ArrayAttr getIndexingMaps();
    std::string getLibraryCallName() {
      return "op_has_no_registered_library_name";
    }

    // Implement functions necessary for DestinationStyleOpInterface.
    static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
                              mlir::ArrayRef<mlir::NamedAttribute>)>
    getRegionBuilder() {
      return nullptr;
    }
    MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
  }];

  let hasCustomAssemblyFormat = 1;
  let hasVerifier = 1;
}


//===----------------------------------------------------------------------===//
// Transpose op.
//===----------------------------------------------------------------------===//

def TransposeOp : LinalgStructuredBase_Op<"transpose", [
    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
    SingleBlockImplicitTerminator<"YieldOp">]> {
  let summary = "Transpose operator";
  let description = [{
    Permutes the dimensions of `input` according to the given `permutation`.
      `dim(result, i) = dim(input, permutation[i])`

    This op actually moves data, unlike `memref.transpose` which is a metadata
    operation only that produces a transposed "view".

    Example:
    ```
      %transpose = linalg.transpose
          ins(%input:tensor<16x64xf32>)
          outs(%init:tensor<64x16xf32>)
          permutation = [1, 0]
    ```
  }];

  let arguments = (ins
    // Input arg
    TensorOrMemref:$input,
    // Output arg
    TensorOrMemref:$init,

    DenseI64ArrayAttr:$permutation
  );
  let results = (outs Variadic<AnyTensor>:$result);
  let regions = (region SizedRegion<1>:$region);

  let skipDefaultBuilders = 1;
  let builders = [
    OpBuilder<(ins "Value":$input, "Value":$init,
        "DenseI64ArrayAttr":$permutation, CArg<"ArrayRef<NamedAttribute>",
        "{}">:$attributes)>,
    OpBuilder<(ins "Value":$input, "Value":$init,
        "ArrayRef<int64_t>":$permutation, CArg<"ArrayRef<NamedAttribute>",
        "{}">:$attributes)>,
  ];

  let extraClassDeclaration = structuredOpsBaseDecls # [{
    // Declare functions necessary for LinalgStructuredInterface.
    SmallVector<utils::IteratorType> getIteratorTypesArray();
    ArrayAttr getIndexingMaps();
    std::string getLibraryCallName() {
      return "op_has_no_registered_library_name";
    }

    // Implement functions necessary for DestinationStyleOpInterface.
    MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }

    // The body only forwards the single block argument to the yield.
    static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
        mlir::ArrayRef<mlir::NamedAttribute>) {
      OpBuilder::InsertionGuard guard(b);
      b.create<linalg::YieldOp>(b.getLoc(), block.getArgument(0));
    }

    static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
        mlir::ArrayRef<mlir::NamedAttribute>)>
      getRegionBuilder() {
      return regionBuilder;
    }
  }];

  let hasFolder = 1;
  let hasCanonicalizer = 1;
  let hasCustomAssemblyFormat = 1;
  let hasVerifier = 1;
}


//===----------------------------------------------------------------------===//
// Broadcast op.
//===----------------------------------------------------------------------===//

def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
    SingleBlockImplicitTerminator<"YieldOp">]> {
  let summary = "Static broadcast operator";
  let description = [{
    Broadcast the input into the given shape by adding `dimensions`.

    Example:
    ```
      %bcast = linalg.broadcast
          ins(%input:tensor<16xf32>)
          outs(%init:tensor<16x64xf32>)
          dimensions = [1]
    ```
  }];

  let arguments = (ins
    // Input arg
    TensorOrMemref:$input,
    // Output arg
    TensorOrMemref:$init,

    DenseI64ArrayAttr:$dimensions
  );
  let results = (outs Variadic<AnyTensor>:$result);
  let regions = (region SizedRegion<1>:$region);

  let skipDefaultBuilders = 1;
  let builders = [
    OpBuilder<(ins "Value":$input, "Value":$init,
        "DenseI64ArrayAttr":$dimensions, CArg<"ArrayRef<NamedAttribute>",
        "{}">:$attributes)>,
    OpBuilder<(ins "Value":$input, "Value":$init,
        "ArrayRef<int64_t>":$dimensions, CArg<"ArrayRef<NamedAttribute>",
        "{}">:$attributes)>,
  ];

  let extraClassDeclaration = structuredOpsBaseDecls # [{
    // Declare functions necessary for LinalgStructuredInterface.
    SmallVector<utils::IteratorType> getIteratorTypesArray();
    ArrayAttr getIndexingMaps();
    std::string getLibraryCallName() {
      return "op_has_no_registered_library_name";
    }

    // Implement functions necessary for DestinationStyleOpInterface.
    MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }

    // The body only forwards the single block argument to the yield.
    static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
        mlir::ArrayRef<mlir::NamedAttribute>) {
      OpBuilder::InsertionGuard guard(b);
      b.create<linalg::YieldOp>(b.getLoc(), block.getArgument(0));
    }

    static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
        mlir::ArrayRef<mlir::NamedAttribute>)>
      getRegionBuilder() {
      return regionBuilder;
    }
  }];

  let hasCustomAssemblyFormat = 1;
  let hasVerifier = 1;
  let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// Op definition for MatmulOp
//===----------------------------------------------------------------------===//

def MatmulOp : LinalgStructuredBase_Op<"matmul", [
               AttrSizedOperandSegments,
               LinalgContractionOpInterface]> {

  let summary = [{
    Performs a matrix multiplication of two 2D inputs without broadcast or transpose.
  }];
  let description = [{
    Numeric casting is performed on the operands to the inner multiply,
    promoting them to the same data type as the accumulator/output.

    Broadcast and Transpose semantics can be applied by specifying the explicit attribute
    'indexing_maps' as shown below. This is a list attribute, so the list must include all
    the maps if specified.

    Example Transpose:
    ```
    linalg.matmul indexing_maps = [
                   affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
                   affine_map<(d0, d1, d2) -> (d2, d1)>,
                   affine_map<(d0, d1, d2) -> (d0, d1)>
                   ]
                   ins(%arg0, %arg1 : memref<5x3xf32>,memref<5x7xf32>)
                   outs(%arg2: memref<3x7xf32>)
    ```

    Example Broadcast:
    ```
    linalg.matmul indexing_maps = [
                   affine_map<(d0, d1, d2) -> (d2)>,     // broadcast
                   affine_map<(d0, d1, d2) -> (d2, d1)>,
                   affine_map<(d0, d1, d2) -> (d0, d1)>
                  ]
                  ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>)
                  outs(%arg2: memref<3x7xf32>)
    ```

    Example Broadcast and transpose:
    ```
    linalg.matmul indexing_maps = [
                      affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
                      affine_map<(d0, d1, d2) -> (d2)>,     // broadcast
                      affine_map<(d0, d1, d2) -> (d0, d1)>
                    ]
                    ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>)
    ```
  }];

  let arguments = (ins
    Variadic<AnyType>:$inputs,
    Variadic<AnyShaped>:$outputs,
    DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
    DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
  );
  let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
  let regions = (region AnyRegion:$region);

  let skipDefaultBuilders = 1;
  let builders = [
    OpBuilder<
    (ins "ValueRange":$inputs, "ValueRange":$outputs,
          CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
    [{
      buildMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
        attributes, MatmulOp::getRegionBuilder(),
        MatmulOp::getDefaultIndexingMaps($_builder.getContext()));
    }]>,
    OpBuilder<
    (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
          "ValueRange":$outputs,
          CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
    [{
      buildMatmulOp($_builder, $_state, resultTensorTypes,
        inputs, outputs, attributes, MatmulOp::getRegionBuilder(),
        MatmulOp::getDefaultIndexingMaps($_builder.getContext()));
    }]>,
    OpBuilder<
    (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
     "ValueRange":$outputs,
     "Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
    [{
      $_state.addAttribute("cast", cast);
      buildMatmulOp($_builder, $_state, resultTensorTypes, inputs, outputs,
        attributes, MatmulOp::getRegionBuilder(),
        MatmulOp::getDefaultIndexingMaps($_builder.getContext()));
    }]>
  ];
  let hasCustomAssemblyFormat = 1;
  let hasFolder = 1;
  let hasVerifier = 1;

  let extraClassDeclaration = structuredOpsBaseDecls # [{
    SmallVector<utils::IteratorType> getIteratorTypesArray();

    /// Implements the block region builder.
    static void regionBuilder(ImplicitLocOpBuilder &b,
                              Block &block, ArrayRef<NamedAttribute> attrs);

    /// Returns a list of AffineMap with the typical matmul indexing characteristic.
    static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);

    /// Returns true if the given broadcast map \p bcastMap is valid for this op.
    bool isValidLhsRhsBroadcastMap(AffineMap bcastMap);

    static std::function<void(ImplicitLocOpBuilder &,
                              Block &, ArrayRef<NamedAttribute>)>
    getRegionBuilder() {
      return regionBuilder;
    }

    ::mlir::MutableOperandRange getDpsInitsMutable() {
      return getOutputsMutable();
    }

    // Generic methods.
    static unsigned getNumRegionArgs();
    std::string getLibraryCallName();
    bool hasDynamicIndexingMaps();
    /// Check if the op has broadcast and/or transpose semantic. Returns true if the
    /// user defined indexing maps are not equal to default map.
    bool hasUserDefinedMaps();
  }];
}

//===----------------------------------------------------------------------===//
// Contract op.
//===----------------------------------------------------------------------===//

def ContractOp : LinalgStructuredBase_Op<"contract", [
               AttrSizedOperandSegments,
               LinalgContractionOpInterface]> {
  let summary = [{
    Perform a contraction on two inputs, accumulating into the third.
  }];
  let description = [{
    The semantics of contracting inputs `A` and `B` on top of `C` to produce
    output `D` is given by

      `D[H] = (SUM_{(I ∪ J) \ H} A[I] * B[J]) + C[H]`

    where `I`, `J`, and `H` are tuples of (pairwise distinct) dimension
    identifiers - meant to range over valid indices - corresponding to the
    results of the mandatory (projected permutation) `indexing_maps` for `A`,
    `B` and `C`. `SUM_{dims}` means reduce over all valid indices for the
    dimensions in the set `dims` (with `I`, `J`, and `H` treated as _sets_ of
    dim identifiers).

    The iteration space consists of all dimensions in `I`, `J` and `H`, i.e. the
    domain of each of the `affine_map`s. Like for einsums, the iteration type of
    each dim is inferred and is either:

    - reduction: the dim is used to index into `A` and `B` but not `C`. Per the
      above semantics, these dims will be contracted, i.e. reduced over.

    - parallel: the dim is used to index into `C` and at least one of `A` and
      `B`, and - deriving from matmul terminology - is either an "M-like" dim
      (if used on `A` and `C`), an "N-like" dim (if used on `B` and `C`) or a
      "batch"-dim (if used to index into `A`, `B`, and `C`).

    For example, batch-matmul is given by `I = ⟨ b, m, k ⟩`, `J = ⟨ b, k, n ⟩`,
    `H = ⟨ b, m, n ⟩` (with `k` as a contracting reduction-dimension while `m`,
    `n` and `b` have parallel iteration-type) and gets represented as:

    ```
    %D = linalg.contract
        indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
                         affine_map<(batch, m, n, k) -> (batch, k, n)>,
                         affine_map<(batch, m, n, k) -> (batch, m, n)>]
        ins(%A, %B: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
        outs(%C: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
    ```

    Note that by permuting dims in the `affine_map`s' results, accesses to
    the inputs and output can be arbitrarily transposed. Similarly, arbitrary
    broadcasts can be achieved through leaving out dims on either input operand.
    For example, the following is a variant of batch-matmul with a transposition
    applied to `A` while `B`'s 2D-matrix gets broadcasted along the batch dim:

    ```
    linalg.contract
        indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k, m)>,
                         affine_map<(batch, m, n, k) -> (k, n)>,
                         affine_map<(batch, m, n, k) -> (batch, m, n)>]
        ins(%A, %B: memref<?x?x?xf32>, memref<?x?xf32>)
        outs(%C: memref<?x?x?xf32>)
    ```

    Numeric casting is performed on the operands to the inner multiplication,
    promoting/truncating them to the same data type as the accumulator/output.

    TODO: Allow control over the combining/accumulating op and possibly the
          multiplication op.
  }];

  let arguments = (ins
    Variadic<AnyType>:$inputs,
    Variadic<AnyShaped>:$outputs,
    AffineMapArrayAttr:$indexing_maps
  );
  let results = (outs Variadic<AnyShaped>:$result_tensors);
  // NB: The only reason this op has a region - and it get populated at op build
  //     time - is that currently the LinalgOp interface exposes methods that
  //     assume a relevant region is available to be queried at any time.
  let regions = (region SizedRegion<1>:$combiner);

  let skipDefaultBuilders = 1;
  let builders = [
    OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
      "ValueRange":$outputs, "ArrayAttr":$indexingMaps,
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
      [{
        $_state.addAttribute("indexing_maps", indexingMaps);
        buildStructuredOp($_builder, $_state, resultTensorTypes, inputs,
                          outputs, attributes, regionBuilder);
      }]>,
    OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
      "ArrayAttr":$indexingMaps,
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
      [{
        $_state.addAttribute("indexing_maps", indexingMaps);
        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
                          attributes, regionBuilder);
      }]>
  ];
  let hasCustomAssemblyFormat = 1;
  let hasFolder = 1;
  let hasVerifier = 1;

  let extraClassDeclaration = structuredOpsBaseDecls # [{
    // Declare/implement functions necessary for LinalgStructuredInterface.

    /// Infer iterator types for each dim in the domain of IndexingMaps.
    SmallVector<utils::IteratorType> getIteratorTypesArray();

    /// IndexingMaps always depends on attr associated to current Op instance.
    bool hasDynamicIndexingMaps() { return true; };
    bool hasUserDefinedMaps() { return true; };

    static unsigned getNumRegionArgs();

    static void regionBuilder(ImplicitLocOpBuilder &b,
                              Block &block, ArrayRef<NamedAttribute> attrs);

    static std::function<void(ImplicitLocOpBuilder &,
                              Block &, ArrayRef<NamedAttribute>)>
    getRegionBuilder() {
      return regionBuilder;
    }

    std::string getLibraryCallName() {
      return "op_has_no_registered_library_name";
    }

    // Implement function necessary for DestinationStyleOpInterface.
    ::mlir::MutableOperandRange getDpsInitsMutable() {
      return getOutputsMutable();
    }
  }];
}

//===----------------------------------------------------------------------===//
// Named Linalg ops, implemented as declarative configurations of generic ops.
//===----------------------------------------------------------------------===//

include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.td"

#endif // LINALG_STRUCTURED_OPS