xref: /llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td (revision 06514c550105b3111c23751421265c318bd69ac6)
1//===- LinalgInterfaces.td - Linalg Interfaces Declaration -*- tablegen -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This is the definition file for the structured interfaces for Linalg ops.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef LINALG_IR_LINALGINTERFACES
14#define LINALG_IR_LINALGINTERFACES
15
16include "mlir/Interfaces/DestinationStyleOpInterface.td"
17include "mlir/IR/OpBase.td"
18
19// TableGen definition of the C++ 'ContractionOpInterface', declared with
20// cppNamespace '::mlir::linalg'; 'description' below states the op contract.
21def LinalgContractionOpInterface : OpInterface<"ContractionOpInterface"> {
22  let description = [{
23   A Linalg contraction is defined in general terms:
24     1. Has 2 input and 1 output shapes.
25     2. Has at least one reduction dimension.
26     3. Has only projected permutation indexing maps.
27     4. its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
28     (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary
29     operations that may change the type (e.g. for mixed-precision).
30   As a consequence, when vectorization of such an op occurs, the only special
31   behavior is that the (unique) MulOpType is vectorized into a
32   `vector.contract`. All other ops are handled in a generic fashion.
33   In the future, we may wish to allow more input arguments and elementwise and
34   constant operations that do not involve the reduction dimension(s).
35  }];
36  let cppNamespace = "::mlir::linalg";
37  let verify = [{ return detail::verifyContractionInterface($_op); }];
38  let verifyWithRegions = 1;
39  let methods = [ // lhs()/rhs() return operands 0/1; is*() pass getIndexingMaps() to mlir:: helpers.
40    InterfaceMethod<
41    /*desc=*/"Returns the left-hand side operand.",
42    /*retTy=*/"Value",
43    /*methodName=*/"lhs",
44    /*args=*/(ins),
45    /*methodBody=*/[{
46      return $_op.getOperation()->getOperand(0);
47    }]>,
48    InterfaceMethod<
49    /*desc=*/"Returns the right-hand side operand.",
50    /*retTy=*/"Value",
51    /*methodName=*/"rhs",
52    /*args=*/(ins),
53    /*methodBody=*/[{
54      return $_op.getOperation()->getOperand(1);
55    }]>,
56    InterfaceMethod<
57    /*desc=*/[{
58      Returns whether the given op has indexing maps that correspond to a
59      row-major matmul operation.
60    }],
61    /*retTy=*/"bool",
62    /*methodName=*/"isRowMajorMatmul",
63    /*args=*/(ins),
64    /*methodBody=*/[{
65        return mlir::isRowMajorMatmul($_op.getIndexingMaps());
66    }]>,
67    InterfaceMethod<
68    /*desc=*/[{
69      Returns whether the given op has indexing maps that correspond to a
70      column-major matmul operation.
71    }],
72    /*retTy=*/"bool",
73    /*methodName=*/"isColumnMajorMatmul",
74    /*args=*/(ins),
75    /*methodBody=*/[{
76        return mlir::isColumnMajorMatmul($_op.getIndexingMaps());
77    }]>,
78    InterfaceMethod<
79    /*desc=*/[{
80      Returns whether the given op has indexing maps that correspond to a
81      row-major batch matmul operation.
82    }],
83    /*retTy=*/"bool",
84    /*methodName=*/"isRowMajorBatchMatmul",
85    /*args=*/(ins),
86    /*methodBody=*/[{
87        return mlir::isRowMajorBatchMatmul($_op.getIndexingMaps());
88    }]>,
89    InterfaceMethod<
90    /*desc=*/[{
91      Returns whether the given op has indexing maps that correspond to a
92      vector-matrix multiplication.
93    }],
94    /*retTy=*/"bool",
95    /*methodName=*/"isVecmat",
96    /*args=*/(ins),
97    /*methodBody=*/[{
98        return mlir::isVecmat($_op.getIndexingMaps());
99    }]>,
100    InterfaceMethod<
101    /*desc=*/[{
102      Returns whether the given op has indexing maps that correspond to a
103      batched vector-matrix multiplication.
104    }],
105    /*retTy=*/"bool",
106    /*methodName=*/"isBatchVecmat",
107    /*args=*/(ins),
108    /*methodBody=*/[{
109        return mlir::isBatchVecmat($_op.getIndexingMaps());
110    }]>,
111    InterfaceMethod<
112    /*desc=*/[{
113      Returns whether the given op has indexing maps that correspond to a
114      matrix-vector multiplication.
115    }],
116    /*retTy=*/"bool",
117    /*methodName=*/"isMatvec",
118    /*args=*/(ins),
119    /*methodBody=*/[{
120        return mlir::isMatvec($_op.getIndexingMaps());
121    }]>,
122    InterfaceMethod<
123    /*desc=*/[{
124      Returns whether the given op has indexing maps that correspond to a
125      batched matrix-vector multiplication.
126    }],
127    /*retTy=*/"bool",
128    /*methodName=*/"isBatchMatvec",
129    /*args=*/(ins),
130    /*methodBody=*/[{
131        return mlir::isBatchMatvec($_op.getIndexingMaps());
132    }]>,
133  ];
134}
135
136def LinalgConvolutionOpInterface : OpInterface<"ConvolutionOpInterface"> {
137  let description = [{
138    A convolution is defined in general terms:
139    1. Has an `image` and a `filter` operand.
140    2. Has one `output` operand.
141    3. The indexing maps of the input have expressions that satisfy
142    ```
143       AffineExpr ::== AffineDimExpr | ConvolvedExpr
144       ConvolvedExpr ::== MulExpr (`+` MulExpr)+
145       MulExpr ::== AffineDimExpr (`*` (AffineConstantExpr | AffineSymbolExpr))?
146    ```
147    4. The filter and the output have projected permutation maps.
148    5. Each of the loops can be qualified as one of,
149       - Loop over batch dimension,
150       - Loop over output image dimensions,
151       - Loop over output channel dimensions,
152       - Loop over convolved filter dimensions,
153       - Loop over input channel dimension.
154  }];
155  let cppNamespace = "::mlir::linalg";
156  let verify = [{ return detail::verifyConvolutionInterface($_op); }];
157  let methods = [ // image()/filter() default to operands 0 and 1 respectively.
158    InterfaceMethod<
159      /*desc=*/"Return the image operand.",
160      /*retTy=*/"Value",
161      /*methodName=*/"image",
162      /*args=*/(ins),
163      /*methodBody=*/"",
164      /*defaultImplementation=*/[{
165        return $_op.getOperation()->getOperand(0);
166      }]
167    >,
168    InterfaceMethod<
169      /*desc=*/"Return the filter operand.",
170      /*retTy=*/"Value",
171      /*methodName=*/"filter",
172      /*args=*/(ins),
173      /*methodBody=*/"",
174      /*defaultImplementation=*/[{
175        return $_op.getOperation()->getOperand(1);
176      }]
177    >,
178  ];
179}
180
181def LinalgFillOpInterface : OpInterface<"FillOpInterface"> {
182  let description = [{
183    A fill operation is defined in general terms:
184    1. Has a scalar `value` operand.
185    2. Has one `output` operand.
186  }];
187  let cppNamespace = "::mlir::linalg";
188  let verify = [{ return detail::verifyFillInterface($_op); }];
189  let methods = [ // value()/output() default to operands 0/1; result() is null when the op has no results.
190    InterfaceMethod<
191      /*desc=*/"Return the fill value.",
192      /*retTy=*/"Value",
193      /*methodName=*/"value",
194      /*args=*/(ins),
195      /*methodBody=*/"",
196      /*defaultImplementation=*/[{
197        return $_op.getOperation()->getOperand(0);
198      }]
199    >,
200    InterfaceMethod<
201      /*desc=*/"Return the output operand.",
202      /*retTy=*/"Value",
203      /*methodName=*/"output",
204      /*args=*/(ins),
205      /*methodBody=*/"",
206      /*defaultImplementation=*/[{
207        return $_op.getOperation()->getOperand(1);
208      }]
209    >,
210    InterfaceMethod<
211      /*desc=*/"Return the result.",
212      /*retTy=*/"Value",
213      /*methodName=*/"result",
214      /*args=*/(ins),
215      /*methodBody=*/"",
216      /*defaultImplementation=*/[{
217        if ($_op.getOperation()->getResults().empty())
218          return nullptr;
219        return $_op.getOperation()->getResults().front();
220      }]
221    >,
222  ];
223}
224
225// The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface.
226def LinalgStructuredInterface
227    : OpInterface<"LinalgOp", [DestinationStyleOpInterface]> {
228  let cppNamespace = "::mlir::linalg";
229  let methods = [
230    //===------------------------------------------------------------------===//
231    // Loop types handling.
232    //===------------------------------------------------------------------===//
233    InterfaceMethod<
234      /*desc=*/[{
235        Return the number of parallel loops.
236      }],
237      /*retTy=*/"unsigned",
238      /*methodName=*/"getNumParallelLoops",
239      /*args=*/(ins),
240      /*methodBody=*/"",
241      /*defaultImplementation=*/[{
242        return llvm::count($_op.getIteratorTypesArray(),
243                           utils::IteratorType::parallel);
244      }]
245    >,
246    InterfaceMethod<
247      /*desc=*/[{
248        Return true if all loops are parallel.
249      }],
250      /*retTy=*/"bool",
251      /*methodName=*/"isAllParallelLoops",
252      /*args=*/(ins),
253      /*methodBody=*/"",
254      /*defaultImplementation=*/[{
255        return getNumParallelLoops() == getNumLoops();
256      }]
257    >,
258    InterfaceMethod<
259      /*desc=*/[{
260        Return the dims that are parallel loops.
261      }],
262      /*retTy=*/"void",
263      /*methodName=*/"getParallelDims",
264      /*args=*/(ins "SmallVectorImpl<unsigned> &":$res),
265      /*methodBody=*/"",
266      /*defaultImplementation=*/[{
267        return findPositionsOfType($_op.getIteratorTypesArray(),
268                                   utils::IteratorType::parallel, res);
269      }]
270    >,
271    InterfaceMethod<
272      /*desc=*/[{
273        Return the number of reduction loops.
274      }],
275      /*retTy=*/"unsigned",
276      /*methodName=*/"getNumReductionLoops",
277      /*args=*/(ins),
278      /*methodBody=*/"",
279      /*defaultImplementation=*/[{
280        return llvm::count($_op.getIteratorTypesArray(),
281                           utils::IteratorType::reduction);
282      }]
283    >,
284    InterfaceMethod<
285      /*desc=*/[{
286        Return the dims that are reduction loops.
287      }],
288      /*retTy=*/"void",
289      /*methodName=*/"getReductionDims",
290      /*args=*/(ins "SmallVectorImpl<unsigned> &":$res),
291      /*methodBody=*/"",
292      /*defaultImplementation=*/[{
293        return findPositionsOfType($_op.getIteratorTypesArray(),
294                                   utils::IteratorType::reduction, res);
295      }]
296    >,
297    InterfaceMethod<
298      /*desc=*/[{
299        Return the total number of loops within the current operation.
300      }],
301      /*retTy=*/"unsigned",
302      /*methodName=*/"getNumLoops",
303      /*args=*/(ins),
304      /*methodBody=*/"",
305      /*defaultImplementation=*/[{
306        return $_op.getIteratorTypesArray().size();
307      }]
308    >,
309    InterfaceMethod<
310      /*desc=*/[{
311        Returns true if the current operation has only one loop and it's a
312        reduction loop.
313      }],
314      /*retTy=*/"bool",
315      /*methodName=*/"hasSingleReductionLoop",
316      /*args=*/(ins),
317      /*methodBody=*/"",
318      /*defaultImplementation=*/[{
319        auto iters = $_op.getIteratorTypesArray();
320        return iters.size() == 1 &&
321               llvm::count(iters, utils::IteratorType::reduction) == 1;
322      }]>,
323    //===------------------------------------------------------------------===//
324    // Input and Init arguments handling.
325    //===------------------------------------------------------------------===//
326    InterfaceMethod<
327      /*desc=*/[{
328        Return true if the payload uses the value loaded from `opOperand`. This
329        is useful to avoid loading from "write-only" memory that may be
330        uninitialized, as well as properly cloning "read-write" operands.
331      }],
332      /*retTy=*/"bool",
333      /*methodName=*/"payloadUsesValueFromOperand",
334      /*args=*/(ins "OpOperand *":$opOperand),
335      /*methodBody=*/"",
336      /*defaultImplementation=*/[{
337        unsigned bbArgNumber = opOperand->getOperandNumber();
338        // Init tensors have uses.
339        return !getBlock()->getArgument(bbArgNumber).use_empty();
340      }]
341    >,
342    InterfaceMethod<
343      /*desc=*/[{
344        Returns true only if linalgOp takes one input and produces one result.
345      }],
346      /*retTy=*/"bool",
347      /*methodName=*/"isSingleInputOutput",
348      /*args=*/(ins),
349      /*methodBody=*/"",
350      /*defaultImplementation=*/[{
351        return $_op.getNumDpsInputs() == 1 && $_op.getNumDpsInits() == 1;
352      }]
353    >,
354    InterfaceMethod<
355      /*desc=*/[{
356        Return true if `opOperand` is an init tensor. This is true when it is
357        an output tensor operand whose value is used in the payload region.
358      }],
359      /*retTy=*/"bool",
360      /*methodName=*/"isInitTensor",
361      /*args=*/(ins "OpOperand *":$opOperand),
362      /*methodBody=*/"",
363      /*defaultImplementation=*/[{
364        if (!$_op.isDpsInit(opOperand))
365          return false;
366        return payloadUsesValueFromOperand(opOperand);
367      }]
368    >,
369    InterfaceMethod<
370      /*desc=*/[{
371        Return the `opOperand` rank or zero for scalars or vectors not wrapped within a tensor or a memref.
372      }],
373      /*retTy=*/"int64_t",
374      /*methodName=*/"getRank",
375      /*args=*/(ins "OpOperand*":$opOperand),
376      /*methodBody=*/"",
377      /*defaultImplementation=*/[{
378        assert(opOperand->getOwner() == this->getOperation());
379        Type t = opOperand->get().getType();
380        // A VectorType is an elemental type, do not consider its rank for the operand.
381        if (isa<VectorType>(t))
382          return 0;
383        // Tensor and Memref container types have a rank.
384        if (auto shapedType = ::llvm::dyn_cast<ShapedType>(t)) {
385          // Failsafe.
386          assert((isa<MemRefType>(t) || isa<RankedTensorType>(t)) &&
387                 "expected a ranked tensor or memref in LinalgInterface::getRank");
388          return shapedType.getRank();
389        }
390        return 0;
391      }]
392    >,
393    InterfaceMethod<
394      /*desc=*/[{
395        Return the input block arguments of the region.
396      }],
397      /*retTy=*/"Block::BlockArgListType",
398      /*methodName=*/"getRegionInputArgs",
399      /*args=*/(ins),
400      /*methodBody=*/"",
401      /*defaultImplementation=*/[{
402        return getBlock()->getArguments().take_front($_op.getNumDpsInputs());
403      }]
404    >,
405    InterfaceMethod<
406      /*desc=*/[{
407        Return the output block arguments of the region.
408      }],
409      /*retTy=*/"Block::BlockArgListType",
410      /*methodName=*/"getRegionOutputArgs",
411      /*args=*/(ins),
412      /*methodBody=*/"",
413      /*defaultImplementation=*/[{
414        return getBlock()->getArguments().take_back($_op.getNumDpsInits());
415      }]
416    >,
417    InterfaceMethod<
418      /*desc=*/[{
419        Return the `opOperand` shape or an empty vector for scalars or vectors
420        not wrapped within a tensor or a memref.
421      }],
422      /*retTy=*/"ArrayRef<int64_t>",
423      /*methodName=*/"getShape",
424      /*args=*/(ins "OpOperand*":$opOperand),
425      /*methodBody=*/"",
426      /*defaultImplementation=*/[{
427        assert(opOperand->getOwner() == this->getOperation());
428        Type t = opOperand->get().getType();
429        // A VectorType is an elemental type, do not consider its shape for the operand.
430        if (isa<VectorType>(t))
431          return {};
432        if (auto shapedType = ::llvm::dyn_cast<ShapedType>(t)) {
433          // Failsafe.
434          assert((isa<MemRefType>(t) || isa<RankedTensorType>(t)) &&
435                 "expected a ranked tensor or memref in LinalgInterface::getShape");
436          return shapedType.getShape();
437        }
438        return {};
439      }]
440    >,
441    InterfaceMethod<
442      /*desc=*/[{
443        Return the block argument for an `opOperand`.
444      }],
445      /*retTy=*/"BlockArgument",
446      /*methodName=*/"getMatchingBlockArgument",
447      /*args=*/(ins "OpOperand *":$opOperand),
448      /*methodBody=*/"",
449      /*defaultImplementation=*/[{
450        assert(opOperand->getOwner() == this->getOperation());
451        return getBlock()->getArgument(opOperand->getOperandNumber());
452      }]
453    >,
454    InterfaceMethod<
455      /*desc=*/[{
456        Return the operand for a `blockArgument`.
457      }],
458      /*retTy=*/"OpOperand *",
459      /*methodName=*/"getMatchingOpOperand",
460      /*args=*/(ins "BlockArgument":$blockArgument),
461      /*methodBody=*/"",
462      /*defaultImplementation=*/[{
463        assert(blockArgument.getOwner() == getBlock());
464        return &this->getOperation()->getOpOperand(
465            blockArgument.getArgNumber());
466      }]
467    >,
468    InterfaceMethod<
469      /*desc=*/[{
470        Return the input or output indexing map for `opOperand`.
471      }],
472      /*retTy=*/"AffineMap",
473      /*methodName=*/"getMatchingIndexingMap",
474      /*args=*/(ins "OpOperand*":$opOperand),
475      /*methodBody=*/"",
476      /*defaultImplementation=*/[{
477        assert(opOperand->getOwner() == this->getOperation());
478        auto indexingMaps =
479          $_op.getIndexingMaps().template getAsValueRange<AffineMapAttr>();
480        return *(indexingMaps.begin() + opOperand->getOperandNumber());
481      }]
482    >,
483    InterfaceMethod<
484      /*desc=*/[{
485        Return the indexing map for a `result`.
486      }],
487      /*retTy=*/"AffineMap",
488      /*methodName=*/"getIndexingMapMatchingResult",
489      /*args=*/(ins "OpResult":$result),
490      /*methodBody=*/"",
491      /*defaultImplementation=*/[{
492        assert(result.getOwner() == this->getOperation());
493        auto indexingMaps =
494          $_op.getIndexingMaps().template getAsValueRange<AffineMapAttr>();
495        return *(indexingMaps.begin() + $_op.getNumDpsInputs() +
496                 result.getResultNumber());
497      }]
498    >,
499    InterfaceMethod<
500      /*desc=*/[{
501        Return the value yielded by the region corresponding to an output
502        `opOperand`.
503      }],
504      /*retTy=*/"OpOperand *",
505      /*methodName=*/"getMatchingYieldValue",
506      /*args=*/(ins "OpOperand*":$opOperand),
507      /*methodBody=*/"",
508      /*defaultImplementation=*/[{
509        assert(opOperand->getOwner() == this->getOperation());
510        int64_t resultIndex =
511            opOperand->getOperandNumber() - $_op.getNumDpsInputs();
512        assert(resultIndex >= 0 &&
513               resultIndex < this->getOperation()->getNumResults());
514        Operation *yieldOp = getBlock()->getTerminator();
515        return &yieldOp->getOpOperand(resultIndex);
516      }]
517    >,
518    //===------------------------------------------------------------------===//
519    // Other interface methods.
520    //===------------------------------------------------------------------===//
521    InterfaceMethod<
522      /*desc=*/[{
523        Return the single block constituting the body of the operation by
524        calling the getBody method on the concrete operation.
525      }],
526      /*retTy=*/"Block*",
527      /*methodName=*/"getBlock",
528      /*args=*/(ins),
529      /*methodBody=*/"",
530      /*defaultImplementation=*/[{
531        // Assume the concrete operation implements the
532        // SingleBlockImplicitTerminator trait.
533        return $_op.getBody();
534      }]
535    >,
536    InterfaceMethod<
537      /*desc=*/[{
538        Return iterator types in the current operation.
539
540        Default implementation assumes that the operation has an attribute
541        `iterator_types`, but it's not always the case. Sometimes iterator types
542        can be inferred from other parameters and in such cases default
543        getIteratorTypesArray should be overridden.
544      }],
545      /*retTy=*/"SmallVector<utils::IteratorType>",
546      /*methodName=*/"getIteratorTypesArray",
547      /*args=*/(ins),
548      /*methodBody=*/"",
549      /*defaultImplementation=*/[{
550        auto range = $_op.getIteratorTypes()
551                         .template getAsValueRange<IteratorTypeAttr,
552                                                   utils::IteratorType>();
553        return {range.begin(), range.end()};
554      }]
555    >,
556    InterfaceMethod<
557      /*desc=*/[{
558        Return true if the indexing map is depending on the current op instance.
559        This means that the indexing map is dynamically synthesized by using the
560        op instance's concrete attributes, instead of being static for all
561        instances of the same op kind.
562      }],
563      /*retTy=*/"bool",
564      /*methodName=*/"hasDynamicIndexingMaps",
565      /*args=*/(ins),
566      /*methodBody=*/"",
567      /*defaultImplementation=*/[{ return false; }]
568    >,
569    InterfaceMethod<
570      /*desc=*/[{
571        Verify all attributes used by indexing maps are valid.
572      }],
573      /*retTy=*/"LogicalResult",
574      /*methodName=*/"verifyIndexingMapRequiredAttributes",
575      /*args=*/(ins),
576      /*methodBody=*/"",
577      /*defaultImplementation=*/[{ return success(); }]
578    >,
579    InterfaceMethod<
580      /*desc=*/[{
581        Return the indexing maps attribute within the current operation.
582      }],
583      /*retTy=*/"ArrayAttr",
584      /*methodName=*/"getIndexingMaps"
585    >,
586    InterfaceMethod<
587      /*desc=*/[{
588        Return the indexing maps within the current operation.
589      }],
590      /*retTy=*/"SmallVector<AffineMap>",
591      /*methodName=*/"getIndexingMapsArray",
592      /*args=*/(ins),
593      /*methodBody=*/"",
594      /*defaultImplementation=*/[{
595        auto range = $_op.getIndexingMaps()
596          .template getAsValueRange<AffineMapAttr>();
597        return {range.begin(), range.end()};
598      }]
599    >,
600    InterfaceMethod<
601      /*desc=*/[{
602        Return true if any of the operands has a dynamic shape.
603      }],
604      /*retTy=*/"bool",
605      /*methodName=*/"hasDynamicShape",
606      /*args=*/(ins),
607      /*methodBody=*/"",
608      /*defaultImplementation=*/[{
609        return llvm::any_of(getStaticShape(), ShapedType::isDynamic);
610      }]
611    >,
612    InterfaceMethod<
613      /*desc=*/[{
614        Return the name registered for this op when lowering to an external
615        library call.
616      }],
617      /*retTy=*/"std::string",
618      /*methodName=*/"getLibraryCallName",
619      /*args=*/(ins),
620      /*methodBody=*/"",
621      /*defaultImplementation=*/[{
622        return $_op.getLibraryCallName();
623      }]
624    >,
625    InterfaceMethod<
626      /*desc=*/[{
627         Return whether the op accesses the iteration indices.
628      }],
629      /*retTy=*/"bool",
630      /*methodName=*/"hasIndexSemantics",
631      /*args=*/(ins),
632      /*methodBody=*/"",
633      /*defaultImplementation=*/""
634    >,
635    InterfaceMethod<
636      /*desc=*/[{
637        Return op operands that have a corresponding argument in the basic block.
638        By default, the block should have an argument for each operand, but there
639        are exceptions. For example, in `map` output operand isn't used in
640        the block.
641      }],
642      /*retTy=*/"::llvm::SmallVector<OpOperand *>",
643      /*methodName=*/"getOpOperandsMatchingBBargs",
644      /*args=*/(ins),
645      /*methodBody=*/"",
646      /*defaultImplementation=*/[{
647        ::llvm::SmallVector<OpOperand *> result;
648        result.reserve($_op->getNumOperands());
649        llvm::transform(
650          this->getOperation()->getOpOperands(),
651          std::back_inserter(result),
652          [](OpOperand &opOperand) { return &opOperand; });
653        return result;
654      }]
655    >,
656    InterfaceMethod<
657      /*desc=*/[{
658        Given a dimension of the iteration space of a Linalg operation, finds an
659        operand in the operation that is defined on such dimension. Returns
660        whether such operand was found or not. If found, also returns the
661        operand value and the dimension position within the operand.
662      }],
663      /*retTy=*/"LogicalResult",
664      /*methodName=*/"mapIterationSpaceDimToOperandDim",
665      /*args=*/(ins "unsigned":$dimPos,
666                    "::mlir::Value &":$operand,
667                    "unsigned &":$operandDimPos),
668      /*methodBody=*/"",
669      /*defaultImplementation=*/[{
670        // Retrieve the operand and its dimension position from the first
671        // operand with a permutation map that is defined on such dimension.
672        for (auto [i, idxMap] : llvm::enumerate($_op.getIndexingMapsArray())) {
673          if (idxMap.isProjectedPermutation()) {
674            if (auto mayOperandDim = idxMap.getResultPosition(
675                getAffineDimExpr(dimPos, idxMap.getContext()))) {
676              operand = $_op->getOperand(i);
677              operandDimPos = *mayOperandDim;
678              return success();
679            }
680          }
681        }
682
683        return failure();
684      }]
685    >,
686    InterfaceMethod<
687      /*desc=*/[{
688        Given a dimension of the iteration space of a Linalg operation, finds
689        all the operands in the operation that are defined on such dimension.
690        Returns all the operand values found and their dimension positions in
691        `operandDimPairs`.
692      }],
693      /*retTy=*/"void",
694      /*methodName=*/"mapIterationSpaceDimToAllOperandDims",
695      /*args=*/(ins "unsigned":$dimPos,
696                    "mlir::SmallVectorImpl<std::pair<Value, unsigned>>&":$operandDimPairs),
697      /*methodBody=*/"",
698      /*defaultImplementation=*/[{
699        for (auto [i, idxMap] : llvm::enumerate($_op.getIndexingMapsArray())) {
700          if (idxMap.isProjectedPermutation()) {
701            if (auto mayOperandDim = idxMap.getResultPosition(
702                getAffineDimExpr(dimPos, idxMap.getContext()))) {
703              operandDimPairs.push_back({$_op->getOperand(i), *mayOperandDim});
704            }
705          }
706        }
707
708        return;
709      }]
710    >,
711    InterfaceMethod<
712      /*desc=*/[{
713        Return true if the user has supplied explicit indexing maps for this op.
714      }],
715      /*retTy=*/"bool",
716      /*methodName=*/"hasUserDefinedMaps",
717      /*args=*/(ins),
718      /*methodBody=*/"",
719      /*defaultImplementation=*/[{ return false; }]
720    >,
721    //===------------------------------------------------------------------===//
722    // Linalg generalization hooks.
723    //===------------------------------------------------------------------===//
724    InterfaceMethod<
725      /*desc=*/[{
726        Hook to provide a custom AffineMap used to compute all the operand
727        subshapes given loop bounds. This is used to answer the question: "given
728        an iteration space over the codomain, what are the subshapes of the
729        operands involved in the computation".
730        The default behavior is to just concatenate all the indexing maps.
731        A custom AffineMap allows providing a map that can be used to
732        compute subshapes even in cases where the concatenation of indexing maps
733        (i.e. the data traversal order) is not a simple permutation of the loop
734        traversal order. It is then possible to define ops with skewed data
735        traversal order for which we can still easily compute hyperrectangular
736        loop bounds and subviews.
737      }],
738      /*retTy=*/"AffineMap",
739      /*methodName=*/"getLoopsToShapesMap",
740      /*args=*/(ins),
741      /*methodBody=*/"",
742      /*defaultImplementation=*/[{
743        auto maps = $_op.getIndexingMapsArray();
744        return concatAffineMaps(maps, $_op.getContext());
745      }]
746    >,
747    InterfaceMethod<
748      /*desc=*/[{
749        Hook to provide a custom AffineMap used to construct the
750        hyperrectangular loop iteration space given all the operand subshapes.
751        This is used to answer the question:
752        "Given a list of operand ranges, what is the subportion of the iteration
753        space involved in the computation".
754        This is the inverse problem of `getLoopsToShapesMap`.
755        Return the empty AffineMap when such an AffineMap cannot be constructed.
756        The default behavior is based on a very simple inference procedure that
757        only works with permutation affine maps.
758        A more advanced Tensor-Comprehension like inference is possible but has
759        proven to be ambiguous in unfavorable cases.
760        A safer and more robust alternative is to allow each op to define
761        its own AffineMap.
762      }],
763      /*retTy=*/"AffineMap",
764      /*methodName=*/"getShapesToLoopsMap",
765      /*args=*/(ins),
766      /*methodBody=*/"",
767      /*defaultImplementation=*/[{
768        return inversePermutation(getLoopsToShapesMap());
769      }]
770    >,
771    InterfaceMethod<
772      /*desc=*/[{
773        Checks if the given operands can be dropped, and the remaining
774        operands can still compute the bounds of the op.
775      }],
776      /*retTy=*/"bool",
777      /*methodName=*/"canOpOperandsBeDropped",
778      /*args=*/(ins "ArrayRef<OpOperand *>":$droppedOperands),
779      /*methodBody=*/"",
780      /*defaultImplementation=*/[{
781        return detail::canOpOperandsBeDroppedImpl($_op, droppedOperands);
782      }]
783    >,
784    InterfaceMethod<
785      /*desc=*/[{
786        Like `getShape`, but only returns statically-known information, without
787        generating any new IR. For each shape dimension, returns >=0 if that
788        dimension is statically known, or ShapedType::kDynamic otherwise.
789      }],
790      /*retTy=*/"SmallVector<int64_t>",
791      /*methodName=*/"getStaticShape",
792      /*args=*/(ins),
793      /*methodBody=*/"",
794      /*defaultImplementation=*/[{
795        SmallVector<int64_t> res;
796        for (OpOperand &opOperand : this->getOperation()->getOpOperands())
797          llvm::append_range(res, getShape(&opOperand));
798        return res;
799      }]
800    >,
801    InterfaceMethod<
802      /*desc=*/[{
803        Returns the statically-known loop ranges. Composes
804        `getShapesToLoopsMap()` with the result of `getStaticShape`.
805        Returns ShapedType::kDynamic for non-statically-known loop ranges.
806        This is expected to be called by a valid Linalg op.
807      }],
808      /*retTy=*/"SmallVector<int64_t, 4>",
809      /*methodName=*/"getStaticLoopRanges",
810      /*args=*/(ins),
811      /*methodBody=*/"",
812      /*defaultImplementation=*/[{
813        SmallVector<int64_t> viewSizes = getStaticShape();
814        AffineMap invertedMap = getShapesToLoopsMap();
815        assert(invertedMap && "expected a valid Linalg op to call the method");
816        return invertedMap.compose(viewSizes);
817      }]
818    >,
819    //===------------------------------------------------------------------===//
820    // Other static interface methods.
821    //===------------------------------------------------------------------===//
822    StaticInterfaceMethod<
823      /*desc=*/[{
824        Returns the region builder for constructing the body for linalg.generic.
825        Returns a null function if this named op does not define a region
826        builder.
827      }],
828      /*retTy=*/"std::function<void(ImplicitLocOpBuilder &, Block &, ArrayRef<NamedAttribute>)>",
829      /*methodName=*/"getRegionBuilder",
830      /*args=*/(ins),
831      /*methodBody=*/[{ return ConcreteOp::getRegionBuilder(); }]
832    >,
833    InterfaceMethod<
834      /*desc=*/[{
835        Return true if all the indexing maps are projected permutations.
836        Otherwise return false.
837      }],
838      /*retTy=*/"bool",
839      /*methodName=*/"hasOnlyProjectedPermutations",
840      /*args=*/(ins),
841      /*methodBody=*/[{
842        return llvm::all_of($_op.getIndexingMapsArray(),
843                            [](AffineMap map) { return map.isProjectedPermutation(); });
844      }]
845    >
846  ];
847
848  let extraClassDeclaration = [{
849    /// Return the flat list of all operand dimension sizes in the order they
850    /// appear in the operands.
851    SmallVector<OpFoldResult> createFlatListOfOperandDims(OpBuilder &, Location);
852
853    /// Return the flat list of all operands' static dimension sizes in the
854    /// order they appear in the operands. All operand dimension sizes have to
855    /// be statically known.
856    SmallVector<int64_t, 4> createFlatListOfOperandStaticDims();
857
858    /// Create the loop ranges to materialize the computation over the current
859    /// operands. This is done by applying `getShapesToLoopsMap` to
860    /// `createFlatListOfOperandDims`.
861    SmallVector<Range, 4> createLoopRanges(OpBuilder &b, Location loc);
862
863    /// Compute the static loop sizes necessary to vectorize the computation.
864    /// This is done by applying `getShapesToLoopsMap` to
865    /// `createFlatListOfOperandStaticDims`.
866    SmallVector<int64_t, 4> computeStaticLoopSizes();
867
868    /// Returns the value that expresses the shape of the output in terms of
869    /// shape of the input operands where possible
870    LogicalResult reifyResultShapes(OpBuilder &b,
871        ReifiedRankedShapedTypeDims &reifiedReturnShapes);
872
873    /// Return the index in the indexingMaps vector that corresponds to this `opOperand`
874    int64_t getIndexingMapIndex(OpOperand *opOperand);
875  }];
876
877  let verify = [{ return detail::verifyStructuredOpInterface($_op); }];
878  let verifyWithRegions = 1;
879}
880
881def AggregatedOpInterface : OpInterface<"AggregatedOpInterface"> {
882  let description = [{
883    Interface for decomposing aggregated operations into a sequence of simpler
884    ops.
885  }];
886  let cppNamespace = "::mlir::linalg";
887  let methods = [ // Single hook; its default implementation body is just `return {};`.
888      InterfaceMethod<
889        /*desc=*/[{
890          Method to decompose the operation into simpler operations.
891
892          On success, this method returns one `Value` per result in the
893          original operation.
894          The order of the returned values must match the order of the
895          original values.
896          In other words, the returned vector can be used directly with
897          `RewriterBase::replaceOp(this, returnedValues)`.
898        }],
899        /*retType=*/"FailureOr<SmallVector<Value>>",
900        /*methodName=*/"decomposeOperation",
901        /*args=*/(ins
902            "OpBuilder &":$b),
903        /*methodBody=*/"",
904        /*defaultImplementation=*/[{
905          return {};
906        }]
907      >
908  ];
909}
910
911#endif // LINALG_IR_LINALGINTERFACES
912