xref: /llvm-project/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp (revision 6aaa8f25b66dc1fef4e465f274ee40b82d632988)
199ef9eebSMatthias Springer //===- VectorUtils.cpp - MLIR Utilities for VectorOps   ------------------===//
299ef9eebSMatthias Springer //
399ef9eebSMatthias Springer // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
499ef9eebSMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
599ef9eebSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
699ef9eebSMatthias Springer //
799ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
899ef9eebSMatthias Springer //
999ef9eebSMatthias Springer // This file implements utility methods for working with the Vector dialect.
1099ef9eebSMatthias Springer //
1199ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
1299ef9eebSMatthias Springer 
1399ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
1499ef9eebSMatthias Springer 
1599ef9eebSMatthias Springer #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
1699ef9eebSMatthias Springer #include "mlir/Dialect/Affine/IR/AffineOps.h"
17abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
1823aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
1999ef9eebSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
2099ef9eebSMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h"
217a69a9d7SNicolas Vasilache #include "mlir/Dialect/Utils/IndexingUtils.h"
2299ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h"
2399ef9eebSMatthias Springer #include "mlir/IR/Builders.h"
2499ef9eebSMatthias Springer #include "mlir/IR/IntegerSet.h"
2599ef9eebSMatthias Springer #include "mlir/IR/Operation.h"
269b5a3d14SMatthias Springer #include "mlir/IR/TypeUtilities.h"
2799ef9eebSMatthias Springer #include "mlir/Support/LLVM.h"
2899ef9eebSMatthias Springer 
2999ef9eebSMatthias Springer #include "llvm/ADT/DenseSet.h"
3099ef9eebSMatthias Springer #include "llvm/ADT/SetVector.h"
3199ef9eebSMatthias Springer 
3230d4f6afSLubomir Litchev #define DEBUG_TYPE "vector-utils"
3330d4f6afSLubomir Litchev 
3430d4f6afSLubomir Litchev #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
3530d4f6afSLubomir Litchev #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
3630d4f6afSLubomir Litchev 
3799ef9eebSMatthias Springer using namespace mlir;
3899ef9eebSMatthias Springer 
3999ef9eebSMatthias Springer /// Helper function that creates a memref::DimOp or tensor::DimOp depending on
4099ef9eebSMatthias Springer /// the type of `source`.
4199ef9eebSMatthias Springer Value mlir::vector::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
4299ef9eebSMatthias Springer                                       int64_t dim) {
435550c821STres Popp   if (isa<UnrankedMemRefType, MemRefType>(source.getType()))
4499ef9eebSMatthias Springer     return b.createOrFold<memref::DimOp>(loc, source, dim);
455550c821STres Popp   if (isa<UnrankedTensorType, RankedTensorType>(source.getType()))
4699ef9eebSMatthias Springer     return b.createOrFold<tensor::DimOp>(loc, source, dim);
4799ef9eebSMatthias Springer   llvm_unreachable("Expected MemRefType or TensorType");
4899ef9eebSMatthias Springer }
4999ef9eebSMatthias Springer 
5025cc5a71SHanhan Wang /// Given the n-D transpose pattern 'transp', return true if 'dim0' and 'dim1'
5125cc5a71SHanhan Wang /// should be transposed with each other within the context of their 2D
5225cc5a71SHanhan Wang /// transposition slice.
5325cc5a71SHanhan Wang ///
5425cc5a71SHanhan Wang /// Example 1: dim0 = 0, dim1 = 2, transp = [2, 1, 0]
5525cc5a71SHanhan Wang ///   Return true: dim0 and dim1 are transposed within the context of their 2D
5625cc5a71SHanhan Wang ///   transposition slice ([1, 0]).
5725cc5a71SHanhan Wang ///
5825cc5a71SHanhan Wang /// Example 2: dim0 = 0, dim1 = 1, transp = [2, 1, 0]
5925cc5a71SHanhan Wang ///   Return true: dim0 and dim1 are transposed within the context of their 2D
6025cc5a71SHanhan Wang ///   transposition slice ([1, 0]). Paradoxically, note how dim1 (1) is *not*
6125cc5a71SHanhan Wang ///   transposed within the full context of the transposition.
6225cc5a71SHanhan Wang ///
6325cc5a71SHanhan Wang /// Example 3: dim0 = 0, dim1 = 1, transp = [2, 0, 1]
6425cc5a71SHanhan Wang ///   Return false: dim0 and dim1 are *not* transposed within the context of
6525cc5a71SHanhan Wang ///   their 2D transposition slice ([0, 1]). Paradoxically, note how dim0 (0)
6625cc5a71SHanhan Wang ///   and dim1 (1) are transposed within the full context of the of the
6725cc5a71SHanhan Wang ///   transposition.
6825cc5a71SHanhan Wang static bool areDimsTransposedIn2DSlice(int64_t dim0, int64_t dim1,
6925cc5a71SHanhan Wang                                        ArrayRef<int64_t> transp) {
7025cc5a71SHanhan Wang   // Perform a linear scan along the dimensions of the transposed pattern. If
7125cc5a71SHanhan Wang   // dim0 is found first, dim0 and dim1 are not transposed within the context of
7225cc5a71SHanhan Wang   // their 2D slice. Otherwise, 'dim1' is found first and they are transposed.
7325cc5a71SHanhan Wang   for (int64_t permDim : transp) {
7425cc5a71SHanhan Wang     if (permDim == dim0)
7525cc5a71SHanhan Wang       return false;
7625cc5a71SHanhan Wang     if (permDim == dim1)
7725cc5a71SHanhan Wang       return true;
7825cc5a71SHanhan Wang   }
7925cc5a71SHanhan Wang 
8025cc5a71SHanhan Wang   llvm_unreachable("Ill-formed transpose pattern");
8125cc5a71SHanhan Wang }
8225cc5a71SHanhan Wang 
8325cc5a71SHanhan Wang FailureOr<std::pair<int, int>>
8425cc5a71SHanhan Wang mlir::vector::isTranspose2DSlice(vector::TransposeOp op) {
8525cc5a71SHanhan Wang   VectorType srcType = op.getSourceVectorType();
8625cc5a71SHanhan Wang   SmallVector<int64_t> srcGtOneDims;
8725cc5a71SHanhan Wang   for (auto [index, size] : llvm::enumerate(srcType.getShape()))
8825cc5a71SHanhan Wang     if (size > 1)
8925cc5a71SHanhan Wang       srcGtOneDims.push_back(index);
9025cc5a71SHanhan Wang 
9125cc5a71SHanhan Wang   if (srcGtOneDims.size() != 2)
9225cc5a71SHanhan Wang     return failure();
9325cc5a71SHanhan Wang 
9425cc5a71SHanhan Wang   // Check whether the two source vector dimensions that are greater than one
9525cc5a71SHanhan Wang   // must be transposed with each other so that we can apply one of the 2-D
9625cc5a71SHanhan Wang   // transpose pattens. Otherwise, these patterns are not applicable.
9732c3decbSMatthias Springer   if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1],
9832c3decbSMatthias Springer                                   op.getPermutation()))
9925cc5a71SHanhan Wang     return failure();
10025cc5a71SHanhan Wang 
10125cc5a71SHanhan Wang   return std::pair<int, int>(srcGtOneDims[0], srcGtOneDims[1]);
10225cc5a71SHanhan Wang }
10325cc5a71SHanhan Wang 
10499ef9eebSMatthias Springer /// Constructs a permutation map from memref indices to vector dimension.
10599ef9eebSMatthias Springer ///
10699ef9eebSMatthias Springer /// The implementation uses the knowledge of the mapping of enclosing loop to
10799ef9eebSMatthias Springer /// vector dimension. `enclosingLoopToVectorDim` carries this information as a
10899ef9eebSMatthias Springer /// map with:
10999ef9eebSMatthias Springer ///   - keys representing "vectorized enclosing loops";
11099ef9eebSMatthias Springer ///   - values representing the corresponding vector dimension.
11199ef9eebSMatthias Springer /// The algorithm traverses "vectorized enclosing loops" and extracts the
11299ef9eebSMatthias Springer /// at-most-one MemRef index that is invariant along said loop. This index is
11399ef9eebSMatthias Springer /// guaranteed to be at most one by construction: otherwise the MemRef is not
11499ef9eebSMatthias Springer /// vectorizable.
11599ef9eebSMatthias Springer /// If this invariant index is found, it is added to the permutation_map at the
11699ef9eebSMatthias Springer /// proper vector dimension.
11799ef9eebSMatthias Springer /// If no index is found to be invariant, 0 is added to the permutation_map and
11899ef9eebSMatthias Springer /// corresponds to a vector broadcast along that dimension.
11999ef9eebSMatthias Springer ///
12099ef9eebSMatthias Springer /// Returns an empty AffineMap if `enclosingLoopToVectorDim` is empty,
12199ef9eebSMatthias Springer /// signalling that no permutation map can be constructed given
12299ef9eebSMatthias Springer /// `enclosingLoopToVectorDim`.
12399ef9eebSMatthias Springer ///
12499ef9eebSMatthias Springer /// Examples can be found in the documentation of `makePermutationMap`, in the
12599ef9eebSMatthias Springer /// header file.
12699ef9eebSMatthias Springer static AffineMap makePermutationMap(
12799ef9eebSMatthias Springer     ArrayRef<Value> indices,
12899ef9eebSMatthias Springer     const DenseMap<Operation *, unsigned> &enclosingLoopToVectorDim) {
12999ef9eebSMatthias Springer   if (enclosingLoopToVectorDim.empty())
13099ef9eebSMatthias Springer     return AffineMap();
13199ef9eebSMatthias Springer   MLIRContext *context =
13299ef9eebSMatthias Springer       enclosingLoopToVectorDim.begin()->getFirst()->getContext();
1337a69a9d7SNicolas Vasilache   SmallVector<AffineExpr> perm(enclosingLoopToVectorDim.size(),
13499ef9eebSMatthias Springer                                getAffineConstantExpr(0, context));
13599ef9eebSMatthias Springer 
13699ef9eebSMatthias Springer   for (auto kvp : enclosingLoopToVectorDim) {
13799ef9eebSMatthias Springer     assert(kvp.second < perm.size());
1384c48f016SMatthias Springer     auto invariants = affine::getInvariantAccesses(
1394c48f016SMatthias Springer         cast<affine::AffineForOp>(kvp.first).getInductionVar(), indices);
14099ef9eebSMatthias Springer     unsigned numIndices = indices.size();
14199ef9eebSMatthias Springer     unsigned countInvariantIndices = 0;
14299ef9eebSMatthias Springer     for (unsigned dim = 0; dim < numIndices; ++dim) {
14399ef9eebSMatthias Springer       if (!invariants.count(indices[dim])) {
14499ef9eebSMatthias Springer         assert(perm[kvp.second] == getAffineConstantExpr(0, context) &&
14599ef9eebSMatthias Springer                "permutationMap already has an entry along dim");
14699ef9eebSMatthias Springer         perm[kvp.second] = getAffineDimExpr(dim, context);
14799ef9eebSMatthias Springer       } else {
14899ef9eebSMatthias Springer         ++countInvariantIndices;
14999ef9eebSMatthias Springer       }
15099ef9eebSMatthias Springer     }
15199ef9eebSMatthias Springer     assert((countInvariantIndices == numIndices ||
15299ef9eebSMatthias Springer             countInvariantIndices == numIndices - 1) &&
15399ef9eebSMatthias Springer            "Vectorization prerequisite violated: at most 1 index may be "
15499ef9eebSMatthias Springer            "invariant wrt a vectorized loop");
15522a4b336SFangrui Song     (void)countInvariantIndices;
15699ef9eebSMatthias Springer   }
15799ef9eebSMatthias Springer   return AffineMap::get(indices.size(), 0, perm, context);
15899ef9eebSMatthias Springer }
15999ef9eebSMatthias Springer 
16099ef9eebSMatthias Springer /// Implementation detail that walks up the parents and records the ones with
16199ef9eebSMatthias Springer /// the specified type.
16299ef9eebSMatthias Springer /// TODO: could also be implemented as a collect parents followed by a
16399ef9eebSMatthias Springer /// filter and made available outside this file.
16499ef9eebSMatthias Springer template <typename T>
16599ef9eebSMatthias Springer static SetVector<Operation *> getParentsOfType(Block *block) {
16699ef9eebSMatthias Springer   SetVector<Operation *> res;
16799ef9eebSMatthias Springer   auto *current = block->getParentOp();
16899ef9eebSMatthias Springer   while (current) {
1690a0aff2dSMikhail Goncharov     if ([[maybe_unused]] auto typedParent = dyn_cast<T>(current)) {
17099ef9eebSMatthias Springer       assert(res.count(current) == 0 && "Already inserted");
17199ef9eebSMatthias Springer       res.insert(current);
17299ef9eebSMatthias Springer     }
17399ef9eebSMatthias Springer     current = current->getParentOp();
17499ef9eebSMatthias Springer   }
17599ef9eebSMatthias Springer   return res;
17699ef9eebSMatthias Springer }
17799ef9eebSMatthias Springer 
17899ef9eebSMatthias Springer /// Returns the enclosing AffineForOp, from closest to farthest.
17999ef9eebSMatthias Springer static SetVector<Operation *> getEnclosingforOps(Block *block) {
1804c48f016SMatthias Springer   return getParentsOfType<affine::AffineForOp>(block);
18199ef9eebSMatthias Springer }
18299ef9eebSMatthias Springer 
18399ef9eebSMatthias Springer AffineMap mlir::makePermutationMap(
18499ef9eebSMatthias Springer     Block *insertPoint, ArrayRef<Value> indices,
18599ef9eebSMatthias Springer     const DenseMap<Operation *, unsigned> &loopToVectorDim) {
18699ef9eebSMatthias Springer   DenseMap<Operation *, unsigned> enclosingLoopToVectorDim;
18799ef9eebSMatthias Springer   auto enclosingLoops = getEnclosingforOps(insertPoint);
18899ef9eebSMatthias Springer   for (auto *forInst : enclosingLoops) {
18999ef9eebSMatthias Springer     auto it = loopToVectorDim.find(forInst);
19099ef9eebSMatthias Springer     if (it != loopToVectorDim.end()) {
19199ef9eebSMatthias Springer       enclosingLoopToVectorDim.insert(*it);
19299ef9eebSMatthias Springer     }
19399ef9eebSMatthias Springer   }
19499ef9eebSMatthias Springer   return ::makePermutationMap(indices, enclosingLoopToVectorDim);
19599ef9eebSMatthias Springer }
19699ef9eebSMatthias Springer 
19799ef9eebSMatthias Springer AffineMap mlir::makePermutationMap(
19899ef9eebSMatthias Springer     Operation *op, ArrayRef<Value> indices,
19999ef9eebSMatthias Springer     const DenseMap<Operation *, unsigned> &loopToVectorDim) {
20099ef9eebSMatthias Springer   return makePermutationMap(op->getBlock(), indices, loopToVectorDim);
20199ef9eebSMatthias Springer }
20299ef9eebSMatthias Springer 
20399ef9eebSMatthias Springer bool matcher::operatesOnSuperVectorsOf(Operation &op,
20499ef9eebSMatthias Springer                                        VectorType subVectorType) {
20599ef9eebSMatthias Springer   // First, extract the vector type and distinguish between:
20699ef9eebSMatthias Springer   //   a. ops that *must* lower a super-vector (i.e. vector.transfer_read,
20799ef9eebSMatthias Springer   //      vector.transfer_write); and
20899ef9eebSMatthias Springer   //   b. ops that *may* lower a super-vector (all other ops).
20999ef9eebSMatthias Springer   // The ops that *may* lower a super-vector only do so if the super-vector to
21099ef9eebSMatthias Springer   // sub-vector ratio exists. The ops that *must* lower a super-vector are
21199ef9eebSMatthias Springer   // explicitly checked for this property.
21299ef9eebSMatthias Springer   /// TODO: there should be a single function for all ops to do this so we
21399ef9eebSMatthias Springer   /// do not have to special case. Maybe a trait, or just a method, unclear atm.
21499ef9eebSMatthias Springer   bool mustDivide = false;
21599ef9eebSMatthias Springer   (void)mustDivide;
21699ef9eebSMatthias Springer   VectorType superVectorType;
21799ef9eebSMatthias Springer   if (auto transfer = dyn_cast<VectorTransferOpInterface>(op)) {
21899ef9eebSMatthias Springer     superVectorType = transfer.getVectorType();
21999ef9eebSMatthias Springer     mustDivide = true;
22099ef9eebSMatthias Springer   } else if (op.getNumResults() == 0) {
22123aa5a74SRiver Riddle     if (!isa<func::ReturnOp>(op)) {
22299ef9eebSMatthias Springer       op.emitError("NYI: assuming only return operations can have 0 "
22399ef9eebSMatthias Springer                    " results at this point");
22499ef9eebSMatthias Springer     }
22599ef9eebSMatthias Springer     return false;
22699ef9eebSMatthias Springer   } else if (op.getNumResults() == 1) {
2275550c821STres Popp     if (auto v = dyn_cast<VectorType>(op.getResult(0).getType())) {
22899ef9eebSMatthias Springer       superVectorType = v;
22999ef9eebSMatthias Springer     } else {
23099ef9eebSMatthias Springer       // Not a vector type.
23199ef9eebSMatthias Springer       return false;
23299ef9eebSMatthias Springer     }
23399ef9eebSMatthias Springer   } else {
23499ef9eebSMatthias Springer     // Not a vector.transfer and has more than 1 result, fail hard for now to
23599ef9eebSMatthias Springer     // wake us up when something changes.
23699ef9eebSMatthias Springer     op.emitError("NYI: operation has more than 1 result");
23799ef9eebSMatthias Springer     return false;
23899ef9eebSMatthias Springer   }
23999ef9eebSMatthias Springer 
24099ef9eebSMatthias Springer   // Get the ratio.
2417a69a9d7SNicolas Vasilache   auto ratio =
2427a69a9d7SNicolas Vasilache       computeShapeRatio(superVectorType.getShape(), subVectorType.getShape());
24399ef9eebSMatthias Springer 
24499ef9eebSMatthias Springer   // Sanity check.
2455413bf1bSKazu Hirata   assert((ratio || !mustDivide) &&
24699ef9eebSMatthias Springer          "vector.transfer operation in which super-vector size is not an"
24799ef9eebSMatthias Springer          " integer multiple of sub-vector size");
24899ef9eebSMatthias Springer 
24999ef9eebSMatthias Springer   // This catches cases that are not strictly necessary to have multiplicity but
25099ef9eebSMatthias Springer   // still aren't divisible by the sub-vector shape.
25199ef9eebSMatthias Springer   // This could be useful information if we wanted to reshape at the level of
25299ef9eebSMatthias Springer   // the vector type (but we would have to look at the compute and distinguish
25399ef9eebSMatthias Springer   // between parallel, reduction and possibly other cases.
254064a08cdSKazu Hirata   return ratio.has_value();
25599ef9eebSMatthias Springer }
2568171eac2SAndrzej Warzyński 
2578171eac2SAndrzej Warzyński bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
2588171eac2SAndrzej Warzyński   if (vectorType.isScalable())
2598171eac2SAndrzej Warzyński     return false;
2608171eac2SAndrzej Warzyński 
2618171eac2SAndrzej Warzyński   ArrayRef<int64_t> vectorShape = vectorType.getShape();
2628171eac2SAndrzej Warzyński   auto vecRank = vectorType.getRank();
2638171eac2SAndrzej Warzyński 
264*6aaa8f25SMatthias Springer   if (!memrefType.areTrailingDimsContiguous(vecRank))
2659478bf0cSAndrzej Warzyński     return false;
2669478bf0cSAndrzej Warzyński 
2678171eac2SAndrzej Warzyński   // Extract the trailing dims and strides of the input memref
2688171eac2SAndrzej Warzyński   auto memrefShape = memrefType.getShape().take_back(vecRank);
2698171eac2SAndrzej Warzyński 
2709478bf0cSAndrzej Warzyński   // Compare the dims of `vectorType` against `memrefType` (in reverse).
2718171eac2SAndrzej Warzyński   // In the most basic case, all dims will match.
2728171eac2SAndrzej Warzyński   auto firstNonMatchingDim =
2738171eac2SAndrzej Warzyński       std::mismatch(vectorShape.rbegin(), vectorShape.rend(),
2748171eac2SAndrzej Warzyński                     memrefShape.rbegin(), memrefShape.rend());
2758171eac2SAndrzej Warzyński   if (firstNonMatchingDim.first == vectorShape.rend())
2768171eac2SAndrzej Warzyński     return true;
2778171eac2SAndrzej Warzyński 
2788171eac2SAndrzej Warzyński   // One non-matching dim is still fine, however the remaining leading dims of
2798171eac2SAndrzej Warzyński   // `vectorType` need to be 1.
2808171eac2SAndrzej Warzyński   SmallVector<int64_t> leadingDims(++firstNonMatchingDim.first,
2818171eac2SAndrzej Warzyński                                    vectorShape.rend());
2828171eac2SAndrzej Warzyński 
2838171eac2SAndrzej Warzyński   return llvm::all_of(leadingDims, [](auto x) { return x == 1; });
2848171eac2SAndrzej Warzyński }
285a1a68603SBenjamin Maxwell 
286a1a68603SBenjamin Maxwell std::optional<StaticTileOffsetRange>
287a1a68603SBenjamin Maxwell vector::createUnrollIterator(VectorType vType, int64_t targetRank) {
288a1a68603SBenjamin Maxwell   if (vType.getRank() <= targetRank)
289a1a68603SBenjamin Maxwell     return {};
290a1a68603SBenjamin Maxwell   // Attempt to unroll until targetRank or the first scalable dimension (which
291a1a68603SBenjamin Maxwell   // cannot be unrolled).
292a1a68603SBenjamin Maxwell   auto shapeToUnroll = vType.getShape().drop_back(targetRank);
293a1a68603SBenjamin Maxwell   auto scalableDimsToUnroll = vType.getScalableDims().drop_back(targetRank);
294a1a68603SBenjamin Maxwell   auto it =
295a1a68603SBenjamin Maxwell       std::find(scalableDimsToUnroll.begin(), scalableDimsToUnroll.end(), true);
296a1a68603SBenjamin Maxwell   auto firstScalableDim = it - scalableDimsToUnroll.begin();
297a1a68603SBenjamin Maxwell   if (firstScalableDim == 0)
298a1a68603SBenjamin Maxwell     return {};
299a1a68603SBenjamin Maxwell   // All scalable dimensions should be removed now.
300a1a68603SBenjamin Maxwell   scalableDimsToUnroll = scalableDimsToUnroll.slice(0, firstScalableDim);
301a1a68603SBenjamin Maxwell   assert(!llvm::is_contained(scalableDimsToUnroll, true) &&
302a1a68603SBenjamin Maxwell          "unexpected leading scalable dimension");
303a1a68603SBenjamin Maxwell   // Create an unroll iterator for leading dimensions.
304a1a68603SBenjamin Maxwell   shapeToUnroll = shapeToUnroll.slice(0, firstScalableDim);
305a1a68603SBenjamin Maxwell   return StaticTileOffsetRange(shapeToUnroll, /*unrollStep=*/1);
306a1a68603SBenjamin Maxwell }
307c56bd7abSAndrzej Warzyński 
308c56bd7abSAndrzej Warzyński SmallVector<OpFoldResult> vector::getMixedSizesXfer(bool hasTensorSemantics,
309c56bd7abSAndrzej Warzyński                                                     Operation *xfer,
310c56bd7abSAndrzej Warzyński                                                     RewriterBase &rewriter) {
311c56bd7abSAndrzej Warzyński   auto loc = xfer->getLoc();
312c56bd7abSAndrzej Warzyński 
313c56bd7abSAndrzej Warzyński   Value base = TypeSwitch<Operation *, Value>(xfer)
314c56bd7abSAndrzej Warzyński                    .Case<vector::TransferReadOp>(
315c56bd7abSAndrzej Warzyński                        [&](auto readOp) { return readOp.getSource(); })
316c56bd7abSAndrzej Warzyński                    .Case<vector::TransferWriteOp>(
317c56bd7abSAndrzej Warzyński                        [&](auto writeOp) { return writeOp.getOperand(1); });
318c56bd7abSAndrzej Warzyński 
319c56bd7abSAndrzej Warzyński   SmallVector<OpFoldResult> mixedSourceDims =
320c56bd7abSAndrzej Warzyński       hasTensorSemantics ? tensor::getMixedSizes(rewriter, loc, base)
321c56bd7abSAndrzej Warzyński                          : memref::getMixedSizes(rewriter, loc, base);
322c56bd7abSAndrzej Warzyński   return mixedSourceDims;
323c56bd7abSAndrzej Warzyński }
324d3aa92edSAndrzej Warzyński 
325d3aa92edSAndrzej Warzyński bool vector::isLinearizableVector(VectorType type) {
326fe07d9aaSAndrzej Warzyński   return (type.getRank() > 1) && (type.getNumScalableDims() <= 1);
327d3aa92edSAndrzej Warzyński }
32830d4f6afSLubomir Litchev 
32930d4f6afSLubomir Litchev Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
33030d4f6afSLubomir Litchev                                      Value source, ArrayRef<int64_t> readShape,
33130d4f6afSLubomir Litchev                                      Value padValue,
33230d4f6afSLubomir Litchev                                      bool useInBoundsInsteadOfMasking) {
33330d4f6afSLubomir Litchev   assert(llvm::none_of(readShape,
33430d4f6afSLubomir Litchev                        [](int64_t s) { return s == ShapedType::kDynamic; }) &&
33530d4f6afSLubomir Litchev          "expected static shape");
33630d4f6afSLubomir Litchev   auto sourceShapedType = cast<ShapedType>(source.getType());
33730d4f6afSLubomir Litchev   auto sourceShape = sourceShapedType.getShape();
33830d4f6afSLubomir Litchev   assert(sourceShape.size() == readShape.size() && "expected same ranks.");
33930d4f6afSLubomir Litchev   auto maskType = VectorType::get(readShape, builder.getI1Type());
34030d4f6afSLubomir Litchev   auto vectorType = VectorType::get(readShape, padValue.getType());
34130d4f6afSLubomir Litchev   assert(padValue.getType() == sourceShapedType.getElementType() &&
34230d4f6afSLubomir Litchev          "expected same pad element type to match source element type");
34330d4f6afSLubomir Litchev   int64_t readRank = readShape.size();
34430d4f6afSLubomir Litchev   auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
34530d4f6afSLubomir Litchev   SmallVector<bool> inBoundsVal(readRank, true);
3468feedd5eSPrashant Kumar   if (useInBoundsInsteadOfMasking) {
34730d4f6afSLubomir Litchev     // Update the inBounds attribute.
34830d4f6afSLubomir Litchev     for (unsigned i = 0; i < readRank; i++)
34930d4f6afSLubomir Litchev       inBoundsVal[i] = (sourceShape[i] == readShape[i]) &&
35030d4f6afSLubomir Litchev                        !ShapedType::isDynamic(sourceShape[i]);
35130d4f6afSLubomir Litchev   }
35230d4f6afSLubomir Litchev   auto transferReadOp = builder.create<vector::TransferReadOp>(
35330d4f6afSLubomir Litchev       loc,
35430d4f6afSLubomir Litchev       /*vectorType=*/vectorType,
35530d4f6afSLubomir Litchev       /*source=*/source,
35630d4f6afSLubomir Litchev       /*indices=*/SmallVector<Value>(readRank, zero),
35730d4f6afSLubomir Litchev       /*padding=*/padValue,
35830d4f6afSLubomir Litchev       /*inBounds=*/inBoundsVal);
35930d4f6afSLubomir Litchev 
3608feedd5eSPrashant Kumar   if (llvm::equal(readShape, sourceShape) || useInBoundsInsteadOfMasking)
36130d4f6afSLubomir Litchev     return transferReadOp;
36230d4f6afSLubomir Litchev   SmallVector<OpFoldResult> mixedSourceDims =
36330d4f6afSLubomir Litchev       tensor::getMixedSizes(builder, loc, source);
36430d4f6afSLubomir Litchev   Value mask =
36530d4f6afSLubomir Litchev       builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
36630d4f6afSLubomir Litchev   return mlir::vector::maskOperation(builder, transferReadOp, mask)
36730d4f6afSLubomir Litchev       ->getResult(0);
36830d4f6afSLubomir Litchev }
36930d4f6afSLubomir Litchev 
37030d4f6afSLubomir Litchev LogicalResult
37130d4f6afSLubomir Litchev vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
37230d4f6afSLubomir Litchev                                  ArrayRef<int64_t> inputVectorSizes) {
37330d4f6afSLubomir Litchev   LDBG("Iteration space static sizes:");
37430d4f6afSLubomir Litchev   LLVM_DEBUG(llvm::interleaveComma(shape, llvm::dbgs()));
37530d4f6afSLubomir Litchev   LLVM_DEBUG(llvm::dbgs() << "\n");
37630d4f6afSLubomir Litchev 
37730d4f6afSLubomir Litchev   if (inputVectorSizes.size() != shape.size()) {
37830d4f6afSLubomir Litchev     LDBG("Input vector sizes don't match the number of loops");
37930d4f6afSLubomir Litchev     return failure();
38030d4f6afSLubomir Litchev   }
38130d4f6afSLubomir Litchev   if (ShapedType::isDynamicShape(inputVectorSizes)) {
38230d4f6afSLubomir Litchev     LDBG("Input vector sizes can't have dynamic dimensions");
38330d4f6afSLubomir Litchev     return failure();
38430d4f6afSLubomir Litchev   }
38530d4f6afSLubomir Litchev   if (!llvm::all_of(llvm::zip(shape, inputVectorSizes),
38630d4f6afSLubomir Litchev                     [](std::tuple<int64_t, int64_t> sizePair) {
38730d4f6afSLubomir Litchev                       int64_t staticSize = std::get<0>(sizePair);
38830d4f6afSLubomir Litchev                       int64_t inputSize = std::get<1>(sizePair);
38930d4f6afSLubomir Litchev                       return ShapedType::isDynamic(staticSize) ||
39030d4f6afSLubomir Litchev                              staticSize <= inputSize;
39130d4f6afSLubomir Litchev                     })) {
39230d4f6afSLubomir Litchev     LDBG("Input vector sizes must be greater than or equal to iteration space "
39330d4f6afSLubomir Litchev          "static sizes");
39430d4f6afSLubomir Litchev     return failure();
39530d4f6afSLubomir Litchev   }
39630d4f6afSLubomir Litchev   return success();
39730d4f6afSLubomir Litchev }
398