//===- VectorUtils.h - Vector Utilities -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
#define MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Support/LLVM.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

namespace mlir {

// Forward declarations.
class AffineMap;
class Block;
class Location;
class OpBuilder;
class Operation;
class ShapedType;
class Value;
class VectorType;
class VectorTransferOpInterface;

namespace affine {
class AffineApplyOp;
class AffineForOp;
} // namespace affine

namespace vector {
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);

/// Returns two dims that are greater than one if the transposition is applied
/// on a 2D slice. Otherwise, returns a failure.
FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);

/// Return true if `vectorType` is a contiguous slice of `memrefType`.
///
/// Only the N = vectorType.getRank() trailing dims of `memrefType` are
/// checked (the other dims are not relevant). Note that for `vectorType` to be
/// a contiguous slice of `memrefType`, the trailing dims of the latter have
/// to be contiguous - this is checked by looking at the corresponding strides.
///
/// There might be some restriction on the leading dim of `VectorType`:
///
/// Case 1. If all the trailing dims of `vectorType` match the trailing dims
///         of `memrefType` then the leading dim of `vectorType` can be
///         arbitrary.
///
///         Ex. 1.1 contiguous slice, perfect match
///           vector<4x3x2xi32> from memref<5x4x3x2xi32>
///         Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4)
///           vector<2x3x2xi32> from memref<5x4x3x2xi32>
///
/// Case 2. If an "internal" dim of `vectorType` does not match the
///         corresponding trailing dim in `memrefType` then the remaining
///         leading dims of `vectorType` have to be 1 (the first non-matching
///         dim can be arbitrary).
///
///         Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1>
///           vector<2x2x2xi32> from memref<5x4x3x2xi32>
///         Ex. 2.2 contiguous slice, 2 != 3 and the leading dim == <1>
///           vector<1x2x2xi32> from memref<5x4x3x2xi32>
///         Ex. 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1>
///           vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
///         Ex. 2.4. non-contiguous slice, 2 != 3 and the leading dims != <1x1>
///           vector<2x1x2x2xi32> from memref<5x4x3x2xi32>
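///
/// A possible use, as a sketch only (`xferOp` is a hypothetical
/// vector.transfer_read/write op introduced purely for illustration):
///
/// ```c++
/// auto memrefType = dyn_cast<MemRefType>(xferOp.getShapedType());
/// if (memrefType &&
///     isContiguousSlice(memrefType, xferOp.getVectorType())) {
///   // The accessed slice is contiguous, so e.g. the transfer could be
///   // rewritten to operate on a flattened (rank-1) view of the memref.
/// }
/// ```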
bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);

/// Returns an iterator for all positions in the leading dimensions of `vType`
/// up to the `targetRank`. If any leading dimension before the `targetRank` is
/// scalable (so cannot be unrolled), it will return an iterator for positions
/// up to the first scalable dimension.
///
/// If no leading dimensions can be unrolled, an empty optional will be
/// returned.
///
/// Examples:
///
///   For vType = vector<2x3x4> and targetRank = 1
///
///   The resulting iterator will yield:
///     [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]
///
///   For vType = vector<3x[4]x5> and targetRank = 0
///
///   The scalable dimension blocks unrolling so the iterator yields only:
///     [0], [1], [2]
///
std::optional<StaticTileOffsetRange>
createUnrollIterator(VectorType vType, int64_t targetRank = 1);

/// Returns a functor (int64_t -> Value) which returns a constant vscale
/// multiple.
///
/// Example:
/// ```c++
/// auto createVscaleMultiple = makeVscaleConstantBuilder(rewriter, loc);
/// auto c4Vscale = createVscaleMultiple(4); // 4 * vector.vscale
/// ```
inline auto makeVscaleConstantBuilder(PatternRewriter &rewriter, Location loc) {
  Value vscale = nullptr;
  return [loc, vscale, &rewriter](int64_t multiplier) mutable {
    if (!vscale)
      vscale = rewriter.create<vector::VectorScaleOp>(loc);
    return rewriter.create<arith::MulIOp>(
        loc, vscale, rewriter.create<arith::ConstantIndexOp>(loc, multiplier));
  };
}

/// Returns a range over the dims (size and scalability) of a VectorType.
inline auto getDims(VectorType vType) {
  return llvm::zip_equal(vType.getShape(), vType.getScalableDims());
}

/// A wrapper for getMixedSizes for vector.transfer_read and
/// vector.transfer_write Ops (for source and destination, respectively).
///
/// Tensor and MemRef types implement their own, very similar version of
/// getMixedSizes. This method will call the appropriate version (depending on
/// `hasTensorSemantics`). It will also automatically extract the operand on
/// which to call it (source for "read" and destination for "write" ops).
SmallVector<OpFoldResult> getMixedSizesXfer(bool hasTensorSemantics,
                                            Operation *xfer,
                                            RewriterBase &rewriter);

/// A pattern for ops that implement `MaskableOpInterface` and that _might_ be
/// masked (i.e. inside a `vector.mask` Op region). In particular:
///   1.   Matches `SourceOp` operation, Op.
///   2.1. If Op is masked, retrieves the masking Op, maskOp, and updates the
///        insertion point to avoid inserting new ops into the `vector.mask` Op
///        region (which only allows one Op).
///   2.2. If Op is not masked, this step is skipped.
///   3.   Invokes `matchAndRewriteMaskableOp` on Op and optionally maskOp if
///        found in step 2.1.
///
/// This wrapper frees patterns from re-implementing the logic to update the
/// insertion point when a maskable Op is masked. Such patterns are still
/// responsible for providing an updated ("rewritten") version of:
///   a. the source Op when the mask _is not_ present,
///   b. the source Op and the masking Op when the mask _is_ present.
/// To use this pattern, implement `matchAndRewriteMaskableOp`. Note that
/// the return value will depend on the case above.
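///
/// A minimal usage sketch (the ops `MyMaskableOp` and `MyLoweredOp` are
/// hypothetical and serve only as an illustration):
///
/// ```c++
/// struct MyMaskableOpLowering : MaskableOpRewritePattern<MyMaskableOp> {
///   using MaskableOpRewritePattern::MaskableOpRewritePattern;
///
///   FailureOr<Value>
///   matchAndRewriteMaskableOp(MyMaskableOp op, MaskingOpInterface maskingOp,
///                             PatternRewriter &rewriter) const override {
///     // The insertion point has already been moved outside the
///     // `vector.mask` region (if any), so new ops can be created here.
///     Operation *newOp = rewriter.create<MyLoweredOp>(
///         op.getLoc(), op.getType(), op.getOperands());
///     // Case b.: when masked, wrap the new op and return a replacement for
///     // the masking op; case a.: otherwise return a replacement for `op`.
///     if (maskingOp)
///       newOp = vector::maskOperation(rewriter, newOp, maskingOp.getMask());
///     return newOp->getResult(0);
///   }
/// };
/// ```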
template <class SourceOp>
struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
  using OpRewritePattern<SourceOp>::OpRewritePattern;

private:
  LogicalResult matchAndRewrite(SourceOp sourceOp,
                                PatternRewriter &rewriter) const final {
    auto maskableOp = dyn_cast<MaskableOpInterface>(sourceOp.getOperation());
    if (!maskableOp)
      return failure();

    Operation *rootOp = sourceOp;

    // If this Op is masked, update the insertion point to avoid inserting into
    // the vector.mask Op region.
    OpBuilder::InsertionGuard guard(rewriter);
    MaskingOpInterface maskOp;
    if (maskableOp.isMasked()) {
      maskOp = maskableOp.getMaskingOp();
      rewriter.setInsertionPoint(maskOp);
      rootOp = maskOp;
    }

    FailureOr<Value> newOp =
        matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
    if (failed(newOp))
      return failure();

    // Rewriting succeeded but there are no values to replace.
    if (rootOp->getNumResults() == 0) {
      rewriter.eraseOp(rootOp);
    } else {
      assert(*newOp != Value() &&
             "Cannot replace an op's use with an empty value.");
      rewriter.replaceOp(rootOp, *newOp);
    }
    return success();
  }

public:
  // Matches `sourceOp` that can potentially be masked with `maskingOp`. If the
  // latter is present, returns a replacement for `maskingOp`. Otherwise,
  // returns a replacement for `sourceOp`.
  virtual FailureOr<Value>
  matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const = 0;
};

/// Returns true if the input Vector type can be linearized.
///
/// Linearization is meant in the sense of flattening vectors, e.g.:
///   * vector<NxMxKxi32> -> vector<N*M*Kxi32>
/// In this sense, Vectors that are either:
///   * already linearized, or
///   * contain more than one scalable dimension,
/// are not linearizable.
bool isLinearizableVector(VectorType type);

/// Create a TransferReadOp from `source` with static shape `readShape`. If the
/// vector type for the read is not the same as the type of `source`, then a
/// mask is created on the read, if use of mask is specified or the bounds on a
/// dimension are different.
///
/// If `useInBoundsInsteadOfMasking` is false, the in-bounds values are set
/// properly, based on the dimensions of the source and destination tensors;
/// that is what determines whether masking is done.
///
/// Note that the internal `vector::TransferReadOp` always reads at indices
/// zero for each dimension of the passed-in tensor.
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
                             ArrayRef<int64_t> readShape, Value padValue,
                             bool useInBoundsInsteadOfMasking);

/// Returns success if `inputVectorSizes` is a valid masking configuration for
/// the given `shape`, i.e., it meets:
///   1. The numbers of elements in both arrays are equal.
///   2. `inputVectorSizes` does not have dynamic dimensions.
///   3. All the values in `inputVectorSizes` are greater than or equal to
///      static sizes in `shape`.
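///
/// For example (illustrative values only, following the conditions above;
/// "?" denotes a dynamic dimension):
///   * shape = [4, ?], inputVectorSizes = [8, 16] -> success
///   * shape = [4, 8], inputVectorSizes = [2, 8]  -> failure (2 < 4)
///   * shape = [4, 8], inputVectorSizes = [8]     -> failure (size mismatch)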
LogicalResult isValidMaskedInputVector(ArrayRef<int64_t> shape,
                                       ArrayRef<int64_t> inputVectorSizes);
} // namespace vector

/// Constructs a permutation map of invariant memref indices to vector
/// dimension.
///
/// If no index is found to be invariant, 0 is added to the permutation_map and
/// corresponds to a vector broadcast along that dimension.
///
/// The implementation uses the knowledge of the mapping of loops to
/// vector dimensions. `loopToVectorDim` carries this information as a map with:
///   - keys representing "vectorized enclosing loops";
///   - values representing the corresponding vector dimension.
/// Note that loopToVectorDim is a whole function map from which only enclosing
/// loop information is extracted.
///
/// Prerequisites: `indices` belong to a vectorizable load or store operation
/// (i.e. at most one invariant index along each AffineForOp of
/// `loopToVectorDim`). `insertPoint` is the insertion point for the vectorized
/// load or store operation.
///
/// Example 1:
/// The following MLIR snippet:
///
/// ```mlir
///    affine.for %i3 = 0 to %0 {
///      affine.for %i4 = 0 to %1 {
///        affine.for %i5 = 0 to %2 {
///          %a5 = load %arg0[%i4, %i5, %i3] : memref<?x?x?xf32>
///    }}}
/// ```
///
/// may vectorize with {permutation_map: (d0, d1, d2) -> (d2, d1)} into:
///
/// ```mlir
///    affine.for %i3 = 0 to %0 step 32 {
///      affine.for %i4 = 0 to %1 {
///        affine.for %i5 = 0 to %2 step 256 {
///          %4 = vector.transfer_read %arg0, %i4, %i5, %i3
///               {permutation_map: (d0, d1, d2) -> (d2, d1)} :
///               (memref<?x?x?xf32>, index, index) -> vector<32x256xf32>
///    }}}
/// ```
///
/// Meaning that vector.transfer_read will be responsible for reading the slice
/// `%arg0[%i4, %i5:%i5+256, %i3:%i3+32]` into vector<32x256xf32>.
///
/// Example 2:
/// The following MLIR snippet:
///
/// ```mlir
///    %cst0 = arith.constant 0 : index
///    affine.for %i0 = 0 to %0 {
///      %a0 = load %arg0[%cst0, %cst0] : memref<?x?xf32>
///    }
/// ```
///
/// may vectorize with {permutation_map: (d0) -> (0)} into:
///
/// ```mlir
///    affine.for %i0 = 0 to %0 step 128 {
///      %3 = vector.transfer_read %arg0, %c0_0, %c0_0
///           {permutation_map: (d0, d1) -> (0)} :
///           (memref<?x?xf32>, index, index) -> vector<128xf32>
///    }
/// ```
///
/// Meaning that vector.transfer_read will be responsible for reading the slice
/// `%arg0[%c0, %c0]` into vector<128xf32> which needs a 1-D vector broadcast.
///
AffineMap
makePermutationMap(Block *insertPoint, ArrayRef<Value> indices,
                   const DenseMap<Operation *, unsigned> &loopToVectorDim);
AffineMap
makePermutationMap(Operation *insertPoint, ArrayRef<Value> indices,
                   const DenseMap<Operation *, unsigned> &loopToVectorDim);

namespace matcher {

/// Matches vector.transfer_read, vector.transfer_write and ops that return a
/// vector type that is a multiple of the sub-vector type. This allows passing
/// over other smaller vector types in the function and avoids interfering with
/// operations on those.
/// This is a first approximation; it can easily be extended in the future.
/// TODO: this could all be much simpler if we added a bit to vector types to
/// mark that a vector is a strict super-vector, but it still does not warrant
/// adding even 1 extra bit in the IR for now.
bool operatesOnSuperVectorsOf(Operation &op, VectorType subVectorType);

} // namespace matcher
} // namespace mlir

#endif // MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_