xref: /llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h (revision c13f806f17ac61961015e38b69c8b39ba7d454ac)
//===- LinalgInterfaces.h - Linalg operations interfaces ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file declares the operation interfaces for Linalg operations, together
// with helpers used to match and verify them.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_DIALECT_LINALG_IR_LINALGINTERFACES_H_
14 #define MLIR_DIALECT_LINALG_IR_LINALGINTERFACES_H_
15 
16 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
17 #include "mlir/IR/AffineMap.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/IRMapping.h"
20 #include "mlir/IR/ImplicitLocOpBuilder.h"
21 #include "mlir/IR/OpDefinition.h"
22 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
23 #include "mlir/Interfaces/InferTypeOpInterface.h"
24 #include "mlir/Interfaces/ViewLikeInterface.h"
25 #include "mlir/Support/RawOstreamExtras.h"
26 
27 namespace mlir {
28 namespace linalg {
// Forward declarations so the helpers declared below can name these types
// before their full declarations are available. LinalgOp and GenericOp are
// presumably declared by the generated .inc headers included at the end of
// this file; where IteratorTypeAttr is declared is not visible here (TODO
// confirm).
class IteratorTypeAttr;
class LinalgOp;
class GenericOp;
32 
namespace detail {
/// Implementation of the method that checks whether the given operands can be
/// dropped, i.e. whether the remaining operands can still compute the loop
/// bounds of the op.
///
/// \param linalgOp        The op whose operands are being considered.
/// \param droppedOperands The operands proposed for removal (presumably
///                        operands of `linalgOp`).
/// \returns true if the operands can be dropped.
bool canOpOperandsBeDroppedImpl(linalg::LinalgOp linalgOp,
                                ArrayRef<OpOperand *> droppedOperands);
} // namespace detail
40 
/// Positions of a Linalg op's loops that correspond to the different kinds of
/// contraction dimensions. Each field holds loop positions; see
/// `inferContractionDims` for how each kind is classified.
struct ContractionDimensions {
  /// Optional batch dimensions: appear in all operands.
  SmallVector<unsigned, 2> batch;
  /// Dimensions of the outer-product along LHS: a permutation on RES and LHS
  /// that does not appear in RHS.
  SmallVector<unsigned, 2> m;
  /// Dimensions of the outer-product along RHS: a permutation on RES and RHS
  /// that does not appear in LHS.
  SmallVector<unsigned, 2> n;
  /// Reduction dimensions: a permutation on LHS and RHS.
  SmallVector<unsigned, 2> k;
};
49 
/// Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates
/// that form a matmul subcomputation within `linalgOp`.
/// These dimensions are such that:
///   1. The m dimension is involved in an outer-product along LHS
///      (i.e. it is a permutation on RES and LHS and does not appear in RHS).
///   2. The n dimension is involved in an outer-product along RHS
///      (i.e. it is a permutation on RES and RHS and does not appear in LHS).
///   3. The k dimension appears as a permutation on LHS and RHS.
///   4. m, n and k appear only once in any given indexing map.
///   5. Optional batch dimensions that appear in all operands are captured.
/// This allows e.g. detecting that some contraction is embedded within
/// `linalgOp` with some orthogonal heuristic.
/// When multiple dimension occurrences exist that match `batch`, `m`, `n`, or
/// `k`, indices are returned in sorted order.
/// Returns a failure if any of `m`, `n` or `k` is empty.
FailureOr<ContractionDimensions> inferContractionDims(LinalgOp linalgOp);

/// Overload of the above that takes the indexing maps directly instead of a
/// LinalgOp, presumably applying the same classification rules.
/// NOTE(review): the expected number/order of `indexingMaps` is not stated
/// here; presumably LHS, RHS, RES as in the description above — confirm
/// against the implementation.
FailureOr<ContractionDimensions>
inferContractionDims(ArrayRef<AffineMap> indexingMaps);
68 
/// Checks whether `linalgOp` conforms to ContractionOpInterface.
/// Returns true if it does, false otherwise.
// TODO: embed within `isa<ContractionOpInterface>` if possible / natural.
bool isaContractionOpInterface(LinalgOp linalgOp);
72 
/// Positions of a Linalg op's loops that correspond to the different kinds of
/// convolution dimensions, together with stride and dilation values. See
/// `inferConvolutionDims` for how each dimension kind is classified.
struct ConvolutionDimensions {
  /// Optional batch dimensions that appear in the input and filter.
  SmallVector<unsigned, 2> batch;
  /// Parallel dimensions cross-correlated along LHS (the input); each has an
  /// associated `filterLoop` dimension in RHS.
  SmallVector<unsigned, 2> outputImage;
  /// Optional dimensions of the outer-product along RHS (the filter): a
  /// permutation on RES and RHS that does not appear in LHS.
  SmallVector<unsigned, 2> outputChannel;
  /// Reduction dimensions that appear as a permutation on RHS and represent
  /// the shape of the kernel cross-correlated along an `outputImage` dim.
  SmallVector<unsigned, 2> filterLoop;
  /// Optional dimensions that appear as a permutation on LHS and RHS.
  SmallVector<unsigned, 2> inputChannel;
  /// NOTE(review): how `depth` dims are classified is not described in the
  /// `inferConvolutionDims` contract (presumably depthwise-convolution
  /// dimensions) — confirm and document.
  SmallVector<unsigned, 2> depth;
  /// Stride values. NOTE(review): presumably one entry per `outputImage`
  /// dimension — confirm.
  SmallVector<int64_t, 2> strides;
  /// Dilation values. NOTE(review): presumably one entry per `outputImage`
  /// dimension — confirm.
  SmallVector<int64_t, 2> dilations;
};
85 
/// Find at least 1 parallel (output_image) and 1 reduction (filter_loop)
/// dimension candidates that form a convolution subcomputation within
/// `linalgOp`. The LHS is assumed to be the convolution input while the
/// RHS is assumed to be the filter.
/// These dimensions are such that:
///   1. Optional batch dimensions that appear in the input and filter.
///   2. The output_image dimension is involved in a cross-correlation along LHS
///      (i.e. it is a permutation on RES and LHS and has an associated
///      filter_loop in RHS).
///   3. Optional output_channel dimension is involved in an outer-product along
///      RHS (i.e. it is a permutation on RES and RHS and does not appear in
///      LHS).
///   4. Optional input_channel dimension appears as a permutation on LHS and
///      RHS.
///   5. The filter_loop dimension appears as a permutation on the RHS and
///      represents the shape of the kernel cross-correlated along a
///      corresponding output_image dim.
///   6. The input_channel dimension appears as a permutation on LHS and RHS.
///   7. All dimensions appear only once in any given indexing map.
/// NOTE(review): item 6 restates item 4 (without "Optional"), and the `depth`
/// field of ConvolutionDimensions is not covered by this list — confirm the
/// intended contract.
/// This allows e.g. detecting that some convolution is embedded within
/// `linalgOp` with some orthogonal heuristic.
/// When multiple dimension occurrences exist that match any classification,
/// indices are returned in sorted order.
/// Returns a failure if `output_image` (and implicitly `filter_loop`) is empty.
FailureOr<ConvolutionDimensions> inferConvolutionDims(LinalgOp linalgOp);
111 
/// Checks whether `linalgOp` conforms to ConvolutionOpInterface.
/// By default, we require the `linalgOp` to have non-empty convolved dims
/// (implicitly non-empty `output_image` and `filter_loop`).
/// Users can loosen the constraint by setting `allowEmptyConvolvedDims` to
/// true.
// TODO: embed within `isa<ConvolutionOpInterface>` if possible / natural.
bool isaConvolutionOpInterface(LinalgOp linalgOp,
                               bool allowEmptyConvolvedDims = false);
119 
/// Checks whether `linalgOp` is semantically equivalent to a `linalg.copy` op.
bool isaCopyOpInterface(LinalgOp linalgOp);
122 
/// Checks whether `genericOp` is semantically equivalent to a
/// `linalg.broadcast`. Returns the broadcast dimensions if true, and
/// std::nullopt otherwise.
std::optional<SmallVector<int64_t>>
isaBroadcastOpInterface(GenericOp genericOp);
127 
/// Checks whether `genericOp` is semantically equivalent to a
/// `linalg.transpose`. Returns the permuted dimensions if true, and
/// std::nullopt otherwise.
std::optional<SmallVector<int64_t>>
isaTransposeOpInterface(GenericOp genericOp);
132 
/// Checks whether a given `genericOp` is semantically equivalent to a single
/// linalg elementwise unary op, e.g. linalg.exp.
/// A linalg.generic body could be a series of unary elementwise ops, e.g.
/// `exp(neg(x))`, such as formed by linalg op fusion. Here we restrict it to
/// detecting cases where the body is a single computation op.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp);
139 
/// Checks whether `genericOp` is semantically equivalent to a single linalg
/// elementwise binary op, e.g. linalg.sub.
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp);
143 
/// Checks whether `genericOp` is semantically equivalent to a `linalg.fill`.
/// Returns the scalar fill value if true, and std::nullopt otherwise.
std::optional<Value> isaFillOpInterface(GenericOp genericOp);
147 
namespace detail {

/// Returns true if the block contains a contraction of the following form:
///
///   %0 = <elemwise>(permutation-of(cu(block-argument-0),
///                                  cu(block-argument-1)))
///   %1 = <reduce>(permutation-of(cu(%0), cu(block-argument-2)))
///   return-like cu(%1)
///
/// where <elemwise> and <reduce> are binary operations constituting a
/// contraction (in the canonical case, <elemwise> is a multiplication and
/// <reduce> is an addition). The name and other properties of these operations
/// are checked by `isaPair`. All operands of all operations may be supplied
/// through a chain of side effect-free unary operations, such as casts, which
/// is denoted as `cu` above.
///
/// When the body does not contain a contraction, a more precise description of
/// the failed precondition is sent to the `errs` stream, if provided. By
/// default, `errs` is `mlir::thread_safe_nulls()`.
bool isContractionBody(Block &block,
                       function_ref<bool(Operation *, Operation *)> isaPair,
                       llvm::raw_ostream &errs = mlir::thread_safe_nulls());

/// Result of matching a Linalg generic against the predicates of it being a
/// contraction. Only an opaque declaration is visible here; the enumerators
/// are defined elsewhere (presumably in the implementation file).
enum class MatchContractionResult;

/// Checks whether `op` conforms to ContractionOpInterface and populates
/// `dimensions` with indexes of the different kinds of dimensions when
/// present.
MatchContractionResult
isContractionInterfaceImpl(Operation *op,
                           ContractionDimensions *dimensions = nullptr);

/// Returns the error message corresponding to the contraction checking return
/// code.
StringRef getMatchContractionMessage(MatchContractionResult res);

/// Result of matching a Linalg generic against the predicates of it being a
/// convolution. Only an opaque declaration is visible here; the enumerators
/// are defined elsewhere (presumably in the implementation file).
enum class MatchConvolutionResult;

/// Checks whether `op` conforms to ConvolutionOpInterface and populates
/// `dimensions` with indexes of the different kinds of dimensions when
/// present.
/// If `allowEmptyConvolvedDims` is not set, we further check whether the `op`
/// contains convolved dims.
MatchConvolutionResult
isConvolutionInterfaceImpl(Operation *op,
                           ConvolutionDimensions *dimensions = nullptr,
                           bool allowEmptyConvolvedDims = false);

/// Returns the error message corresponding to the convolution checking return
/// code.
StringRef getMatchConvolutionMessage(MatchConvolutionResult res);

/// Verify that `op` conforms to ContractionOpInterface.
LogicalResult verifyContractionInterface(Operation *op);

/// Verify that `op` conforms to the ConvolutionOpInterface.
LogicalResult verifyConvolutionInterface(Operation *op);

/// Verify that `op` conforms to the FillOpInterface.
LogicalResult verifyFillInterface(Operation *op);

/// Verify that `op` conforms to the invariants of StructuredOpInterface.
LogicalResult verifyStructuredOpInterface(Operation *op);

} // namespace detail
216 } // namespace linalg
217 } // namespace mlir
218 
219 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.h.inc"
220 
221 /// Include the generated interface declarations.
222 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h.inc"
223 
224 #endif // MLIR_DIALECT_LINALG_IR_LINALGINTERFACES_H_
225