1 //===- LinalgInterface.h - Linalg operations interfaces -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the operation interfaces for Linalg operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_DIALECT_LINALG_IR_LINALGINTERFACES_H_ 14 #define MLIR_DIALECT_LINALG_IR_LINALGINTERFACES_H_ 15 16 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 17 #include "mlir/IR/AffineMap.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/IRMapping.h" 20 #include "mlir/IR/ImplicitLocOpBuilder.h" 21 #include "mlir/IR/OpDefinition.h" 22 #include "mlir/Interfaces/DestinationStyleOpInterface.h" 23 #include "mlir/Interfaces/InferTypeOpInterface.h" 24 #include "mlir/Interfaces/ViewLikeInterface.h" 25 #include "mlir/Support/RawOstreamExtras.h" 26 27 namespace mlir { 28 namespace linalg { 29 class IteratorTypeAttr; 30 class LinalgOp; 31 class GenericOp; 32 33 namespace detail { 34 /// Implementation of the method that check if given operands 35 /// can be dropped, i.e. the remaining operands can compute the loop 36 /// bounds of the op. 37 bool canOpOperandsBeDroppedImpl(linalg::LinalgOp linalgOp, 38 ArrayRef<OpOperand *> droppedOperands); 39 } // namespace detail 40 41 /// Positions of a Linalg op loops that correspond to different kinds of a 42 /// contraction dimension. 43 struct ContractionDimensions { 44 SmallVector<unsigned, 2> batch; 45 SmallVector<unsigned, 2> m; 46 SmallVector<unsigned, 2> n; 47 SmallVector<unsigned, 2> k; 48 }; 49 50 /// Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates 51 /// that form a matmul subcomputation within `linalgOp`. 52 /// These dimensions are such that: 53 /// 1. The m dimension is involved in an outer-product along LHS 54 /// (i.e. it is a permutation on RES and LHS and does not appear in RHS). 55 /// 2. The n dimension is involved in an outer-product along RHS 56 /// (i.e. it is a permutation on RES and RHS and does not appear in LHS). 57 /// 3. The k dimension appears as a permutation on LHS and RHS. 58 /// 4. m, n and k appear only once in any given indexing. 59 /// 5. Optional batch dimensions that appear in all operands are captured. 60 /// This allows e.g. detecting that some contraction is embedded within 61 /// `linalgOp` with some orthogonal heuristic. 62 /// When multiple dimension occurrences exist that match `batch`, `m`, `n`, or 63 /// `k`, indices are returned in sorted order. 64 /// Returns a failure if any of `m`, `n` or `k` is empty. 65 FailureOr<ContractionDimensions> inferContractionDims(LinalgOp linalgOp); 66 FailureOr<ContractionDimensions> 67 inferContractionDims(ArrayRef<AffineMap> indexingMaps); 68 69 /// Checks whether `linalgOp` conforms to ContractionOpInterface. 70 // TODO: embed within `isa<ContractionOpInterface>` if possible / natural. 71 bool isaContractionOpInterface(LinalgOp linalgOp); 72 73 /// Positions of a Linalg op loops that correspond to different kinds of a 74 /// convolution dimension. 75 struct ConvolutionDimensions { 76 SmallVector<unsigned, 2> batch; 77 SmallVector<unsigned, 2> outputImage; 78 SmallVector<unsigned, 2> outputChannel; 79 SmallVector<unsigned, 2> filterLoop; 80 SmallVector<unsigned, 2> inputChannel; 81 SmallVector<unsigned, 2> depth; 82 SmallVector<int64_t, 2> strides; 83 SmallVector<int64_t, 2> dilations; 84 }; 85 86 /// Find at least 1 parallel (output_image) and reduction (filter_loop) 87 /// dimension candidates that form a convolution subcomputation within 88 /// `linalgOp`. The LHS is assumed to be the convolution input while the 89 /// RHS is assumed as the filter. 90 /// These dimensions are such that: 91 /// 1. Optional batch dimensions that appear in the input and filter. 92 /// 2. The output_image dimension is involved in a cross-correlation along LHS 93 /// (i.e. it is a permutation on RES and LHS and has an associated 94 /// filter_loop in RHS). 95 /// 3. Optional output_channel dimension is involved in an outer-product along 96 /// RHS (i.e. it is a permutation on RES and RHS and does not appear in 97 /// LHS). 98 /// 4. Optional input_channel dimension appears as a permutation on LHS and 99 /// RHS. 100 /// 5. The filter_loop dimension appears as a permutation on the RHS and 101 /// represents the shape of the kernel cross-correlated along a 102 /// corresponding output_image dim. 103 /// 6. The input_channel dimension appears as a permutation on LHS and RHS. 104 /// 7. All dimensions appear only once in any given indexing map. 105 /// This allows e.g. detecting that some convolution is embedded within 106 /// `linalgOp` with some orthogonal heuristic. 107 /// When multiple dimension occurrences exist that match any classification 108 /// indices are returned in sorted order. 109 /// Returns a failure if `output_image` (and implicitly `filter_loop`) is empty. 110 FailureOr<ConvolutionDimensions> inferConvolutionDims(LinalgOp linalgOp); 111 112 /// Checks whether `linalgOp` conforms to ConvolutionOpInterface. 113 /// By default, we require the `linalgOp` to have non-empty convolved dims 114 /// (implicitly non-empty `output_image` and `filter_loop`). 115 /// Users can loosen the constraint by setting `allowEmptyConvolvedDims` to true 116 // TODO: embed within `isa<ConvolutionOpInterface>` if possible / natural. 117 bool isaConvolutionOpInterface(LinalgOp linalgOp, 118 bool allowEmptyConvolvedDims = false); 119 120 /// Checks whether `linalgOp` is semantically equivalent to a `linalg.copyOp`. 121 bool isaCopyOpInterface(LinalgOp linalgOp); 122 123 /// Checks whether `genericOp` is semantically equivalent to a 124 /// `linalg.broadcast`. Returns broadcast dimensions if true. 125 std::optional<SmallVector<int64_t>> 126 isaBroadcastOpInterface(GenericOp genericOp); 127 128 /// Checks whether `genericOp` is semantically equivalent to a 129 /// `linalg.transpose`. Returns permuted dimensions if true. 130 std::optional<SmallVector<int64_t>> 131 isaTransposeOpInterface(GenericOp genericOp); 132 133 /// Checks whether a given `genericOp` is semantically equivalent to a single 134 /// linalgelementwise unary op. e.g. linalg.exp. 135 /// A linalg.generic body could be a series of unary elementwise ops e.g. 136 /// `exp(neg(x))`, such as formed by linalg op fusion. Here we restrict it to 137 /// detecting cases where body is is a single computation op. 138 bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp); 139 140 /// Checks whether `genericOp` is semantically equivalent to a single linalg 141 /// elementwise binary op e.g. linalg.sub. 142 bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp); 143 144 /// Checks whether `genericOp` is semantically equivalent to a `linalg.fill`. 145 /// Returns the scalar fill value if true. 146 std::optional<Value> isaFillOpInterface(GenericOp genericOp); 147 148 namespace detail { 149 150 /// Returns true if the block contains a contraction of the following form: 151 /// 152 /// %0 = <elemwise>(permutation-of(cu(block-argument-0), 153 /// cu(block-argument-1))) 154 /// %1 = <reduce>(permutation-of(cu(%0), cu(block-argument-2))) 155 /// return-like cu(%1) 156 /// 157 /// where <elemwise> and <reduce> are binary operations constituting a 158 /// contraction (in the canonical case, <elemwise> is a multiplication and 159 /// <reduce> is an addition). The name and other properties of these operations 160 /// are checked by `isaPair`. All operands of all operations may be supplied 161 /// through a chain of side effect-free unary operations, such as casts, which 162 /// is denoted as `cu` above. 163 /// 164 /// When the body does not contain a contraction, a more precise description of 165 /// the failed precondition is send to the `errs` stream, if provided. 166 bool isContractionBody(Block &block, 167 function_ref<bool(Operation *, Operation *)> isaPair, 168 llvm::raw_ostream &errs = mlir::thread_safe_nulls()); 169 170 /// Result of matching a Linalg generic against the predicates of it being a 171 /// contraction. 172 enum class MatchContractionResult; 173 174 /// Checks whether `op` conforms to ContractionOpInterface and populates 175 /// `dimensions` with indexes of the different kinds of dimensions when 176 /// present. 177 MatchContractionResult 178 isContractionInterfaceImpl(Operation *op, 179 ContractionDimensions *dimensions = nullptr); 180 181 /// Returns the error message corresponding to the contraction checking return 182 /// code. 183 StringRef getMatchContractionMessage(MatchContractionResult res); 184 185 /// Result of matching a Linalg generic against the predicates of it being a 186 /// convolution. 187 enum class MatchConvolutionResult; 188 189 /// Checks whether `op` conforms to ConvolutionOpInterface and populates 190 /// `dimensions` with indexes of the different kinds of dimensions when 191 /// present. 192 /// If `allowEmptyConvolvedDims` is not set, we further checks whether the `op` 193 /// contains convolved dims. 194 MatchConvolutionResult 195 isConvolutionInterfaceImpl(Operation *op, 196 ConvolutionDimensions *dimensions = nullptr, 197 bool allowEmptyConvolvedDims = false); 198 199 /// Returns the error message corresponding to the convolution checking return 200 /// code. 201 StringRef getMatchConvolutionMessage(MatchConvolutionResult res); 202 203 /// Verify that `op` conforms to ContractionOpInterface. 204 LogicalResult verifyContractionInterface(Operation *op); 205 206 /// Verify that `op` conforms to the ConvolutionOpInterface. 207 LogicalResult verifyConvolutionInterface(Operation *op); 208 209 /// Verify that `op` conforms to the FillOpInterface. 210 LogicalResult verifyFillInterface(Operation *op); 211 212 /// Verify that `op` conforms to the invariants of StructuredOpInterface 213 LogicalResult verifyStructuredOpInterface(Operation *op); 214 215 } // namespace detail 216 } // namespace linalg 217 } // namespace mlir 218 219 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.h.inc" 220 221 /// Include the generated interface declarations. 222 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h.inc" 223 224 #endif // MLIR_DIALECT_LINALG_IR_LINALGINTERFACES_H_ 225