//===- VectorUtils.cpp - MLIR Utilities for VectorOps   ------------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utility methods for working with the Vector dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/Utils/VectorUtils.h"

#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"

#define DEBUG_TYPE "vector-utils"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

using namespace mlir;

/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
Value mlir::vector::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
                                      int64_t dim) {
  if (isa<UnrankedMemRefType, MemRefType>(source.getType()))
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (isa<UnrankedTensorType, RankedTensorType>(source.getType()))
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}

/// Given the n-D transpose pattern 'transp', return true if 'dim0' and 'dim1'
/// should be transposed with each other within the context of their 2D
/// transposition slice.
///
/// Example 1: dim0 = 0, dim1 = 2, transp = [2, 1, 0]
///   Return true: dim0 and dim1 are transposed within the context of their 2D
///   transposition slice ([1, 0]).
///
/// Example 2: dim0 = 0, dim1 = 1, transp = [2, 1, 0]
///   Return true: dim0 and dim1 are transposed within the context of their 2D
///   transposition slice ([1, 0]). Paradoxically, note how dim1 (1) is *not*
///   transposed within the full context of the transposition.
///
/// Example 3: dim0 = 0, dim1 = 1, transp = [2, 0, 1]
///   Return false: dim0 and dim1 are *not* transposed within the context of
///   their 2D transposition slice ([0, 1]). Paradoxically, note how dim0 (0)
///   and dim1 (1) are transposed within the full context of the transposition.
static bool areDimsTransposedIn2DSlice(int64_t dim0, int64_t dim1,
                                       ArrayRef<int64_t> transp) {
  // Perform a linear scan along the dimensions of the transposed pattern. If
  // 'dim0' is found first, 'dim0' and 'dim1' are not transposed within the
  // context of their 2D slice. Otherwise, 'dim1' is found first and they are
  // transposed.
  for (int64_t permDim : transp) {
    if (permDim == dim0)
      return false;
    if (permDim == dim1)
      return true;
  }

  llvm_unreachable("Ill-formed transpose pattern");
}

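/// Illustrative example:
///   vector.transpose %v, [2, 1, 0] : vector<4x1x8xf32> to vector<8x1x4xf32>
/// has exactly two source dims of size > 1 (0 and 2), and they are swapped
/// within their 2D slice, so this returns the pair (0, 2). By contrast,
///   vector.transpose %v, [0, 2, 1] : vector<4x8x1xf32> to vector<4x1x8xf32>
/// returns failure(): dims 0 and 1 keep their relative order.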
FailureOr<std::pair<int, int>>
mlir::vector::isTranspose2DSlice(vector::TransposeOp op) {
  VectorType srcType = op.getSourceVectorType();
  SmallVector<int64_t> srcGtOneDims;
  for (auto [index, size] : llvm::enumerate(srcType.getShape()))
    if (size > 1)
      srcGtOneDims.push_back(index);

  if (srcGtOneDims.size() != 2)
    return failure();

  // Check whether the two source vector dimensions that are greater than one
  // must be transposed with each other so that we can apply one of the 2-D
  // transpose patterns. Otherwise, these patterns are not applicable.
  if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1],
                                  op.getPermutation()))
    return failure();

  return std::pair<int, int>(srcGtOneDims[0], srcGtOneDims[1]);
}

/// Constructs a permutation map from memref indices to vector dimensions.
///
/// The implementation uses the knowledge of the mapping of enclosing loops to
/// vector dimensions. `enclosingLoopToVectorDim` carries this information as a
/// map with:
///   - keys representing "vectorized enclosing loops";
///   - values representing the corresponding vector dimension.
/// The algorithm traverses the "vectorized enclosing loops" and extracts, for
/// each loop, the at-most-one MemRef index that varies along (i.e. is not
/// invariant to) said loop. That at most one such index exists is guaranteed
/// by construction: otherwise the MemRef is not vectorizable.
/// If such a varying index is found, it is added to the permutation_map at the
/// proper vector dimension.
/// If all indices are invariant along the loop, 0 is added to the
/// permutation_map instead, which corresponds to a vector broadcast along that
/// dimension.
///
/// Returns an empty AffineMap if `enclosingLoopToVectorDim` is empty,
/// signalling that no permutation map can be constructed given
/// `enclosingLoopToVectorDim`.
///
/// Examples can be found in the documentation of `makePermutationMap`, in the
/// header file.
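/// Illustrative example: assuming loops %i and %j were vectorized to vector
/// dimensions 0 and 1 respectively, the indices (%j, %c0, %i) into a 3-D
/// memref yield the permutation map (d0, d1, d2) -> (d2, d0): index 2 is the
/// only index varying with %i, and index 0 the only one varying with %j.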
static AffineMap makePermutationMap(
    ArrayRef<Value> indices,
    const DenseMap<Operation *, unsigned> &enclosingLoopToVectorDim) {
  if (enclosingLoopToVectorDim.empty())
    return AffineMap();
  MLIRContext *context =
      enclosingLoopToVectorDim.begin()->getFirst()->getContext();
  SmallVector<AffineExpr> perm(enclosingLoopToVectorDim.size(),
                               getAffineConstantExpr(0, context));

  for (auto kvp : enclosingLoopToVectorDim) {
    assert(kvp.second < perm.size());
    auto invariants = affine::getInvariantAccesses(
        cast<affine::AffineForOp>(kvp.first).getInductionVar(), indices);
    unsigned numIndices = indices.size();
    unsigned countInvariantIndices = 0;
    for (unsigned dim = 0; dim < numIndices; ++dim) {
      if (!invariants.count(indices[dim])) {
        assert(perm[kvp.second] == getAffineConstantExpr(0, context) &&
               "permutationMap already has an entry along dim");
        perm[kvp.second] = getAffineDimExpr(dim, context);
      } else {
        ++countInvariantIndices;
      }
    }
    assert((countInvariantIndices == numIndices ||
            countInvariantIndices == numIndices - 1) &&
           "Vectorization prerequisite violated: at most 1 index may vary "
           "wrt a vectorized loop");
    (void)countInvariantIndices;
  }
  return AffineMap::get(indices.size(), 0, perm, context);
}

/// Implementation detail that walks up the parents and records the ones with
/// the specified type.
/// TODO: could also be implemented as a collect parents followed by a
/// filter and made available outside this file.
template <typename T>
static SetVector<Operation *> getParentsOfType(Block *block) {
  SetVector<Operation *> res;
  auto *current = block->getParentOp();
  while (current) {
    if ([[maybe_unused]] auto typedParent = dyn_cast<T>(current)) {
      assert(res.count(current) == 0 && "Already inserted");
      res.insert(current);
    }
    current = current->getParentOp();
  }
  return res;
}

/// Returns the enclosing AffineForOps, from closest to farthest.
static SetVector<Operation *> getEnclosingForOps(Block *block) {
  return getParentsOfType<affine::AffineForOp>(block);
}

AffineMap mlir::makePermutationMap(
    Block *insertPoint, ArrayRef<Value> indices,
    const DenseMap<Operation *, unsigned> &loopToVectorDim) {
  DenseMap<Operation *, unsigned> enclosingLoopToVectorDim;
  auto enclosingLoops = getEnclosingForOps(insertPoint);
  for (auto *forInst : enclosingLoops) {
    auto it = loopToVectorDim.find(forInst);
    if (it != loopToVectorDim.end()) {
      enclosingLoopToVectorDim.insert(*it);
    }
  }
  return ::makePermutationMap(indices, enclosingLoopToVectorDim);
}

AffineMap mlir::makePermutationMap(
    Operation *op, ArrayRef<Value> indices,
    const DenseMap<Operation *, unsigned> &loopToVectorDim) {
  return makePermutationMap(op->getBlock(), indices, loopToVectorDim);
}

bool matcher::operatesOnSuperVectorsOf(Operation &op,
                                       VectorType subVectorType) {
  // First, extract the vector type and distinguish between:
  //   a. ops that *must* lower a super-vector (i.e. vector.transfer_read,
  //      vector.transfer_write); and
  //   b. ops that *may* lower a super-vector (all other ops).
  // The ops that *may* lower a super-vector only do so if the super-vector to
  // sub-vector ratio exists. The ops that *must* lower a super-vector are
  // explicitly checked for this property.
  // TODO: there should be a single function for all ops to do this so we
  // do not have to special case. Maybe a trait, or just a method, unclear atm.
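  // Illustrative example: an op producing vector<4x8xf32> operates on
  // super-vectors of vector<8xf32>, because the shape ratio (4, 1) exists.
  // With sub-vector type vector<3xf32>, no ratio exists (8 is not a multiple
  // of 3), so this returns false (and would assert below for vector.transfer
  // ops, which must divide).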
  bool mustDivide = false;
  (void)mustDivide;
  VectorType superVectorType;
  if (auto transfer = dyn_cast<VectorTransferOpInterface>(op)) {
    superVectorType = transfer.getVectorType();
    mustDivide = true;
  } else if (op.getNumResults() == 0) {
    if (!isa<func::ReturnOp>(op)) {
      op.emitError("NYI: assuming only return operations can have 0 "
                   "results at this point");
    }
    return false;
  } else if (op.getNumResults() == 1) {
    if (auto v = dyn_cast<VectorType>(op.getResult(0).getType())) {
      superVectorType = v;
    } else {
      // Not a vector type.
      return false;
    }
  } else {
    // Not a vector.transfer and has more than 1 result, fail hard for now to
    // wake us up when something changes.
    op.emitError("NYI: operation has more than 1 result");
    return false;
  }

  // Get the ratio.
  auto ratio =
      computeShapeRatio(superVectorType.getShape(), subVectorType.getShape());

  // Sanity check.
  assert((ratio || !mustDivide) &&
         "vector.transfer operation in which super-vector size is not an"
         " integer multiple of sub-vector size");

  // This catches cases where the multiplicity is not strictly required, but
  // the super-vector shape is still not divisible by the sub-vector shape.
  // This could be useful information if we wanted to reshape at the level of
  // the vector type (but we would have to look at the compute and distinguish
  // between parallel, reduction and possibly other cases).
  return ratio.has_value();
}

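/// Illustrative example: given a contiguous memref<2x3x4xf32>, the vector
/// types vector<4xf32>, vector<3x4xf32> and vector<1x3x4xf32> are all
/// contiguous slices (trailing dims match, extra leading dims are 1), and so
/// is vector<2x4xf32> (a single non-matching dim is allowed). vector<2x2xf32>
/// is not: its leading dim would have to be 1.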
bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
  if (vectorType.isScalable())
    return false;

  ArrayRef<int64_t> vectorShape = vectorType.getShape();
  auto vecRank = vectorType.getRank();

  if (!memrefType.areTrailingDimsContiguous(vecRank))
    return false;

  // Extract the trailing dims of the input memref.
  auto memrefShape = memrefType.getShape().take_back(vecRank);

  // Compare the dims of `vectorType` against `memrefType` (in reverse).
  // In the most basic case, all dims will match.
  auto firstNonMatchingDim =
      std::mismatch(vectorShape.rbegin(), vectorShape.rend(),
                    memrefShape.rbegin(), memrefShape.rend());
  if (firstNonMatchingDim.first == vectorShape.rend())
    return true;

  // One non-matching dim is still fine, however the remaining leading dims of
  // `vectorType` need to be 1.
  SmallVector<int64_t> leadingDims(++firstNonMatchingDim.first,
                                   vectorShape.rend());

  return llvm::all_of(leadingDims, [](auto x) { return x == 1; });
}

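/// Illustrative examples, with targetRank = 1:
///   - vector<3x4x5xf32>:   iterates over all 3x4 leading tile offsets;
///   - vector<3x[4]x5xf32>: iterates over the leading dim of size 3 only,
///     since unrolling stops at the first scalable dim;
///   - vector<[4]x5xf32>:   returns std::nullopt (nothing can be unrolled).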
std::optional<StaticTileOffsetRange>
vector::createUnrollIterator(VectorType vType, int64_t targetRank) {
  if (vType.getRank() <= targetRank)
    return {};
  // Attempt to unroll until targetRank or the first scalable dimension (which
  // cannot be unrolled).
  auto shapeToUnroll = vType.getShape().drop_back(targetRank);
  auto scalableDimsToUnroll = vType.getScalableDims().drop_back(targetRank);
  auto it =
      std::find(scalableDimsToUnroll.begin(), scalableDimsToUnroll.end(), true);
  auto firstScalableDim = it - scalableDimsToUnroll.begin();
  if (firstScalableDim == 0)
    return {};
  // All scalable dimensions should be removed now.
  scalableDimsToUnroll = scalableDimsToUnroll.slice(0, firstScalableDim);
  assert(!llvm::is_contained(scalableDimsToUnroll, true) &&
         "unexpected leading scalable dimension");
  // Create an unroll iterator for leading dimensions.
  shapeToUnroll = shapeToUnroll.slice(0, firstScalableDim);
  return StaticTileOffsetRange(shapeToUnroll, /*unrollStep=*/1);
}

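/// Illustrative example: for a vector.transfer_write into memref<?x8xf32>,
/// this returns {%dim, 8}, where %dim is the result of a memref.dim op on
/// the dynamic dimension; with tensor semantics, tensor.dim is used instead.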
SmallVector<OpFoldResult> vector::getMixedSizesXfer(bool hasTensorSemantics,
                                                    Operation *xfer,
                                                    RewriterBase &rewriter) {
  auto loc = xfer->getLoc();

  Value base = TypeSwitch<Operation *, Value>(xfer)
                   .Case<vector::TransferReadOp>(
                       [&](auto readOp) { return readOp.getSource(); })
                   .Case<vector::TransferWriteOp>(
                       [&](auto writeOp) { return writeOp.getOperand(1); });

  SmallVector<OpFoldResult> mixedSourceDims =
      hasTensorSemantics ? tensor::getMixedSizes(rewriter, loc, base)
                         : memref::getMixedSizes(rewriter, loc, base);
  return mixedSourceDims;
}

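/// Illustrative example: vector<2x2xf32> and vector<2x[4]xf32> are
/// linearizable (rank > 1, at most one scalable dim), whereas vector<4xf32>
/// (rank 1) and vector<[2]x[2]xf32> (two scalable dims) are not.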
bool vector::isLinearizableVector(VectorType type) {
  return (type.getRank() > 1) && (type.getNumScalableDims() <= 1);
}

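/// Illustrative example: reading a tensor<4x3xf32> source with
/// readShape = {8, 4} and useInBoundsInsteadOfMasking = false emits a
/// vector.transfer_read of vector<8x4xf32> and, because the shapes differ,
/// wraps it in a vector.mask whose mask is a vector.create_mask of sizes
/// (4, 3).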
Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
                                     Value source, ArrayRef<int64_t> readShape,
                                     Value padValue,
                                     bool useInBoundsInsteadOfMasking) {
  assert(llvm::none_of(readShape,
                       [](int64_t s) { return s == ShapedType::kDynamic; }) &&
         "expected static shape");
  auto sourceShapedType = cast<ShapedType>(source.getType());
  auto sourceShape = sourceShapedType.getShape();
  assert(sourceShape.size() == readShape.size() && "expected same ranks.");
  auto maskType = VectorType::get(readShape, builder.getI1Type());
  auto vectorType = VectorType::get(readShape, padValue.getType());
  assert(padValue.getType() == sourceShapedType.getElementType() &&
         "expected pad element type to match source element type");
  int64_t readRank = readShape.size();
  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  SmallVector<bool> inBoundsVal(readRank, true);
  if (useInBoundsInsteadOfMasking) {
    // Update the inBounds attribute.
    for (unsigned i = 0; i < readRank; i++)
      inBoundsVal[i] = (sourceShape[i] == readShape[i]) &&
                       !ShapedType::isDynamic(sourceShape[i]);
  }
  auto transferReadOp = builder.create<vector::TransferReadOp>(
      loc,
      /*vectorType=*/vectorType,
      /*source=*/source,
      /*indices=*/SmallVector<Value>(readRank, zero),
      /*padding=*/padValue,
      /*inBounds=*/inBoundsVal);

  if (llvm::equal(readShape, sourceShape) || useInBoundsInsteadOfMasking)
    return transferReadOp;
  SmallVector<OpFoldResult> mixedSourceDims =
      tensor::getMixedSizes(builder, loc, source);
  Value mask =
      builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
  return mlir::vector::maskOperation(builder, transferReadOp, mask)
      ->getResult(0);
}

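/// Illustrative example: shape = (?, 16) with inputVectorSizes = (8, 16)
/// succeeds; inputVectorSizes = (8, 8) fails (the static size 16 exceeds 8),
/// and inputVectorSizes = (8) fails (rank mismatch).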
LogicalResult
vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
                                 ArrayRef<int64_t> inputVectorSizes) {
  LDBG("Iteration space static sizes:");
  LLVM_DEBUG(llvm::interleaveComma(shape, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (inputVectorSizes.size() != shape.size()) {
    LDBG("Input vector sizes don't match the number of loops");
    return failure();
  }
  if (ShapedType::isDynamicShape(inputVectorSizes)) {
    LDBG("Input vector sizes can't have dynamic dimensions");
    return failure();
  }
  if (!llvm::all_of(llvm::zip(shape, inputVectorSizes),
                    [](std::tuple<int64_t, int64_t> sizePair) {
                      int64_t staticSize = std::get<0>(sizePair);
                      int64_t inputSize = std::get<1>(sizePair);
                      return ShapedType::isDynamic(staticSize) ||
                             staticSize <= inputSize;
                    })) {
    LDBG("Input vector sizes must be greater than or equal to iteration space "
         "static sizes");
    return failure();
  }
  return success();
}
398