//===- VectorUtils.cpp - MLIR Utilities for VectorOps ------------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utility methods for working with the Vector dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/Utils/VectorUtils.h"

#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"

#define DEBUG_TYPE "vector-utils"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

using namespace mlir;

/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
Value mlir::vector::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
                                      int64_t dim) {
  if (isa<UnrankedMemRefType, MemRefType>(source.getType()))
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (isa<UnrankedTensorType, RankedTensorType>(source.getType()))
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}
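
// Usage sketch for the helper above (illustrative only; `builder`, `loc` and
// `src` are hypothetical values, not taken from this file):
//
//   // For `src : tensor<?x4xf32>`, query the dynamic leading dimension.
//   Value d0 = vector::createOrFoldDimOp(builder, loc, src, /*dim=*/0);
//   // For the static trailing dimension, the result may fold away to a
//   // constant index instead of materializing a tensor.dim/memref.dim op.
//   Value d1 = vector::createOrFoldDimOp(builder, loc, src, /*dim=*/1);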

/// Given the n-D transpose pattern 'transp', return true if 'dim0' and 'dim1'
/// should be transposed with each other within the context of their 2D
/// transposition slice.
///
/// Example 1: dim0 = 0, dim1 = 2, transp = [2, 1, 0]
/// Return true: dim0 and dim1 are transposed within the context of their 2D
/// transposition slice ([1, 0]).
///
/// Example 2: dim0 = 0, dim1 = 1, transp = [2, 1, 0]
/// Return true: dim0 and dim1 are transposed within the context of their 2D
/// transposition slice ([1, 0]). Paradoxically, note how dim1 (1) is *not*
/// transposed within the full context of the transposition.
///
/// Example 3: dim0 = 0, dim1 = 1, transp = [2, 0, 1]
/// Return false: dim0 and dim1 are *not* transposed within the context of
/// their 2D transposition slice ([0, 1]). Paradoxically, note how dim0 (0)
/// and dim1 (1) are transposed within the full context of the transposition.
static bool areDimsTransposedIn2DSlice(int64_t dim0, int64_t dim1,
                                       ArrayRef<int64_t> transp) {
  // Perform a linear scan along the dimensions of the transpose pattern. If
  // dim0 is found first, dim0 and dim1 are not transposed within the context
  // of their 2D slice. Otherwise, 'dim1' is found first and they are
  // transposed.
  for (int64_t permDim : transp) {
    if (permDim == dim0)
      return false;
    if (permDim == dim1)
      return true;
  }

  llvm_unreachable("Ill-formed transpose pattern");
}

FailureOr<std::pair<int, int>>
mlir::vector::isTranspose2DSlice(vector::TransposeOp op) {
  VectorType srcType = op.getSourceVectorType();
  SmallVector<int64_t> srcGtOneDims;
  for (auto [index, size] : llvm::enumerate(srcType.getShape()))
    if (size > 1)
      srcGtOneDims.push_back(index);

  if (srcGtOneDims.size() != 2)
    return failure();

  // Check whether the two source vector dimensions that are greater than one
  // must be transposed with each other so that we can apply one of the 2-D
  // transpose patterns. Otherwise, these patterns are not applicable.
  if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1],
                                  op.getPermutation()))
    return failure();

  return std::pair<int, int>(srcGtOneDims[0], srcGtOneDims[1]);
}
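
// Illustrative query for the function above (the IR below is a sketch, not
// taken from an existing test):
//
//   %t = vector.transpose %v, [0, 2, 1]
//       : vector<1x4x8xf32> to vector<1x8x4xf32>
//
// The only two source dimensions greater than one are 1 and 2, and the
// permutation swaps them within their 2-D slice, so `isTranspose2DSlice`
// returns the pair (1, 2). If the permutation left those two dimensions in
// their original relative order, the function would return failure().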

/// Constructs a permutation map from memref indices to vector dimensions.
///
/// The implementation uses the knowledge of the mapping of enclosing loops to
/// vector dimensions. `enclosingLoopToVectorDim` carries this information as a
/// map with:
///   - keys representing "vectorized enclosing loops";
///   - values representing the corresponding vector dimension.
/// The algorithm traverses "vectorized enclosing loops" and extracts the
/// at-most-one MemRef index that varies along said loop. This index is
/// guaranteed to be at most one by construction: otherwise the MemRef is not
/// vectorizable.
/// If such a varying index is found, it is added to the permutation_map at the
/// proper vector dimension.
/// If no index is found to vary, 0 is added to the permutation_map and
/// corresponds to a vector broadcast along that dimension.
///
/// Returns an empty AffineMap if `enclosingLoopToVectorDim` is empty,
/// signalling that no permutation map can be constructed given
/// `enclosingLoopToVectorDim`.
///
/// Examples can be found in the documentation of `makePermutationMap`, in the
/// header file.
static AffineMap makePermutationMap(
    ArrayRef<Value> indices,
    const DenseMap<Operation *, unsigned> &enclosingLoopToVectorDim) {
  if (enclosingLoopToVectorDim.empty())
    return AffineMap();
  MLIRContext *context =
      enclosingLoopToVectorDim.begin()->getFirst()->getContext();
  SmallVector<AffineExpr> perm(enclosingLoopToVectorDim.size(),
                               getAffineConstantExpr(0, context));

  for (auto kvp : enclosingLoopToVectorDim) {
    assert(kvp.second < perm.size());
    auto invariants = affine::getInvariantAccesses(
        cast<affine::AffineForOp>(kvp.first).getInductionVar(), indices);
    unsigned numIndices = indices.size();
    unsigned countInvariantIndices = 0;
    for (unsigned dim = 0; dim < numIndices; ++dim) {
      if (!invariants.count(indices[dim])) {
        assert(perm[kvp.second] == getAffineConstantExpr(0, context) &&
               "permutationMap already has an entry along dim");
        perm[kvp.second] = getAffineDimExpr(dim, context);
      } else {
        ++countInvariantIndices;
      }
    }
    assert((countInvariantIndices == numIndices ||
            countInvariantIndices == numIndices - 1) &&
           "Vectorization prerequisite violated: at most 1 index may vary "
           "wrt a vectorized loop");
    (void)countInvariantIndices;
  }
  return AffineMap::get(indices.size(), 0, perm, context);
}
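
// Sketch of the construction above on a hypothetical access (not from a
// test): for memref indices (%i0, %c0, %i1) with
// enclosingLoopToVectorDim = {forOp(%i0) -> 0, forOp(%i1) -> 1}, the index
// varying with %i0 is d0 and the index varying with %i1 is d2, so the
// resulting permutation map is (d0, d1, d2) -> (d0, d2). If every index were
// invariant w.r.t. the loop mapped to vector dimension 1, that result would
// instead be the constant 0, i.e. a broadcast.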

/// Implementation detail that walks up the parents and records the ones with
/// the specified type.
/// TODO: could also be implemented as a collect parents followed by a
/// filter and made available outside this file.
template <typename T>
static SetVector<Operation *> getParentsOfType(Block *block) {
  SetVector<Operation *> res;
  auto *current = block->getParentOp();
  while (current) {
    if ([[maybe_unused]] auto typedParent = dyn_cast<T>(current)) {
      assert(res.count(current) == 0 && "Already inserted");
      res.insert(current);
    }
    current = current->getParentOp();
  }
  return res;
}

/// Returns the enclosing AffineForOps, from closest to farthest.
static SetVector<Operation *> getEnclosingforOps(Block *block) {
  return getParentsOfType<affine::AffineForOp>(block);
}

AffineMap mlir::makePermutationMap(
    Block *insertPoint, ArrayRef<Value> indices,
    const DenseMap<Operation *, unsigned> &loopToVectorDim) {
  DenseMap<Operation *, unsigned> enclosingLoopToVectorDim;
  auto enclosingLoops = getEnclosingforOps(insertPoint);
  for (auto *forInst : enclosingLoops) {
    auto it = loopToVectorDim.find(forInst);
    if (it != loopToVectorDim.end()) {
      enclosingLoopToVectorDim.insert(*it);
    }
  }
  return ::makePermutationMap(indices, enclosingLoopToVectorDim);
}

AffineMap mlir::makePermutationMap(
    Operation *op, ArrayRef<Value> indices,
    const DenseMap<Operation *, unsigned> &loopToVectorDim) {
  return makePermutationMap(op->getBlock(), indices, loopToVectorDim);
}
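
// Usage sketch for the public overloads above (variable names are
// hypothetical):
//
//   DenseMap<Operation *, unsigned> loopToVectorDim = ...; // from vectorizer
//   SmallVector<Value> indices = ...; // memref indices of the access
//   AffineMap permMap = makePermutationMap(accessOp, indices, loopToVectorDim);
//   if (!permMap)
//     return; // none of the enclosing loops is vectorized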

bool matcher::operatesOnSuperVectorsOf(Operation &op,
                                       VectorType subVectorType) {
  // First, extract the vector type and distinguish between:
  //   a. ops that *must* lower a super-vector (i.e. vector.transfer_read,
  //      vector.transfer_write); and
  //   b. ops that *may* lower a super-vector (all other ops).
  // The ops that *may* lower a super-vector only do so if the super-vector to
  // sub-vector ratio exists. The ops that *must* lower a super-vector are
  // explicitly checked for this property.
  /// TODO: there should be a single function for all ops to do this so we
  /// do not have to special case. Maybe a trait, or just a method, unclear atm.
  bool mustDivide = false;
  (void)mustDivide;
  VectorType superVectorType;
  if (auto transfer = dyn_cast<VectorTransferOpInterface>(op)) {
    superVectorType = transfer.getVectorType();
    mustDivide = true;
  } else if (op.getNumResults() == 0) {
    if (!isa<func::ReturnOp>(op)) {
      op.emitError("NYI: assuming only return operations can have 0 "
                   "results at this point");
    }
    return false;
  } else if (op.getNumResults() == 1) {
    if (auto v = dyn_cast<VectorType>(op.getResult(0).getType())) {
      superVectorType = v;
    } else {
      // Not a vector type.
      return false;
    }
  } else {
    // Not a vector.transfer and has more than 1 result: fail hard for now so
    // we notice when something changes.
    op.emitError("NYI: operation has more than 1 result");
    return false;
  }

  // Get the ratio.
  auto ratio =
      computeShapeRatio(superVectorType.getShape(), subVectorType.getShape());

  // Sanity check.
  assert((ratio || !mustDivide) &&
         "vector.transfer operation in which super-vector size is not an"
         " integer multiple of sub-vector size");

  // This also catches the cases where divisibility is not strictly required
  // but the shape still is not divisible by the sub-vector shape.
  // This could be useful information if we wanted to reshape at the level of
  // the vector type (but we would have to look at the compute and distinguish
  // between parallel, reduction and possibly other cases).
  return ratio.has_value();
}
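
// Illustrative shape-ratio check for the matcher above (types are
// hypothetical): an op producing a vector<4x8x16xf32> operates on
// super-vectors of vector<8x16xf32>, since
// computeShapeRatio({4, 8, 16}, {8, 16}) yields {4, 1, 1}. Against a
// vector<8x12xf32> sub-vector the ratio does not exist, so the matcher
// returns false.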

bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
  if (vectorType.isScalable())
    return false;

  ArrayRef<int64_t> vectorShape = vectorType.getShape();
  auto vecRank = vectorType.getRank();

  if (!memrefType.areTrailingDimsContiguous(vecRank))
    return false;

  // Extract the trailing dims of the input memref.
  auto memrefShape = memrefType.getShape().take_back(vecRank);

  // Compare the dims of `vectorType` against `memrefType` (in reverse).
  // In the most basic case, all dims will match.
  auto firstNonMatchingDim =
      std::mismatch(vectorShape.rbegin(), vectorShape.rend(),
                    memrefShape.rbegin(), memrefShape.rend());
  if (firstNonMatchingDim.first == vectorShape.rend())
    return true;

  // One non-matching dim is still fine, however the remaining leading dims of
  // `vectorType` need to be 1.
  SmallVector<int64_t> leadingDims(++firstNonMatchingDim.first,
                                   vectorShape.rend());

  return llvm::all_of(leadingDims, [](auto x) { return x == 1; });
}
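
// Worked example for the contiguity check above (hypothetical types): reading
// a vector<1x4x8xf32> from a memref<2x4x8xf32> with the default row-major
// layout is a contiguous slice: the trailing memref dims (4, 8) match the
// trailing vector dims and the remaining leading vector dim is 1. Reading a
// vector<2x2x8xf32> from the same memref is not contiguous: after the first
// mismatch (2 vs. 4), a leading vector dim of 2 remains, which is not 1.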

std::optional<StaticTileOffsetRange>
vector::createUnrollIterator(VectorType vType, int64_t targetRank) {
  if (vType.getRank() <= targetRank)
    return {};
  // Attempt to unroll until targetRank or the first scalable dimension (which
  // cannot be unrolled).
  auto shapeToUnroll = vType.getShape().drop_back(targetRank);
  auto scalableDimsToUnroll = vType.getScalableDims().drop_back(targetRank);
  auto it =
      std::find(scalableDimsToUnroll.begin(), scalableDimsToUnroll.end(), true);
  auto firstScalableDim = it - scalableDimsToUnroll.begin();
  if (firstScalableDim == 0)
    return {};
  // All scalable dimensions should be removed now.
  scalableDimsToUnroll = scalableDimsToUnroll.slice(0, firstScalableDim);
  assert(!llvm::is_contained(scalableDimsToUnroll, true) &&
         "unexpected leading scalable dimension");
  // Create an unroll iterator for leading dimensions.
  shapeToUnroll = shapeToUnroll.slice(0, firstScalableDim);
  return StaticTileOffsetRange(shapeToUnroll, /*unrollStep=*/1);
}
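
// Behavior sketch for the unroll iterator above (hypothetical types): for
// vType = vector<3x2x[4]x5xf32> and targetRank = 2, the leading fixed dims
// (3, 2) are unrolled and the returned range enumerates the offsets
// (0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1), while the scalable [4] and
// the trailing 5 stay inside the target shape. For vector<[4]x5xf32> with
// targetRank = 1, the leading dim is scalable and std::nullopt is returned.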

SmallVector<OpFoldResult> vector::getMixedSizesXfer(bool hasTensorSemantics,
                                                    Operation *xfer,
                                                    RewriterBase &rewriter) {
  auto loc = xfer->getLoc();

  Value base = TypeSwitch<Operation *, Value>(xfer)
                   .Case<vector::TransferReadOp>(
                       [&](auto readOp) { return readOp.getSource(); })
                   .Case<vector::TransferWriteOp>(
                       [&](auto writeOp) { return writeOp.getOperand(1); });

  SmallVector<OpFoldResult> mixedSourceDims =
      hasTensorSemantics ? tensor::getMixedSizes(rewriter, loc, base)
                         : memref::getMixedSizes(rewriter, loc, base);
  return mixedSourceDims;
}

bool vector::isLinearizableVector(VectorType type) {
  return (type.getRank() > 1) && (type.getNumScalableDims() <= 1);
}

Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
                                     Value source, ArrayRef<int64_t> readShape,
                                     Value padValue,
                                     bool useInBoundsInsteadOfMasking) {
  assert(llvm::none_of(readShape,
                       [](int64_t s) { return s == ShapedType::kDynamic; }) &&
         "expected static shape");
  auto sourceShapedType = cast<ShapedType>(source.getType());
  auto sourceShape = sourceShapedType.getShape();
  assert(sourceShape.size() == readShape.size() && "expected same ranks.");
  auto maskType = VectorType::get(readShape, builder.getI1Type());
  auto vectorType = VectorType::get(readShape, padValue.getType());
  assert(padValue.getType() == sourceShapedType.getElementType() &&
         "expected same pad element type to match source element type");
  int64_t readRank = readShape.size();
  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  SmallVector<bool> inBoundsVal(readRank, true);
  if (useInBoundsInsteadOfMasking) {
    // Update the inBounds attribute.
    for (unsigned i = 0; i < readRank; i++)
      inBoundsVal[i] = (sourceShape[i] == readShape[i]) &&
                       !ShapedType::isDynamic(sourceShape[i]);
  }
  auto transferReadOp = builder.create<vector::TransferReadOp>(
      loc,
      /*vectorType=*/vectorType,
      /*source=*/source,
      /*indices=*/SmallVector<Value>(readRank, zero),
      /*padding=*/padValue,
      /*inBounds=*/inBoundsVal);

  if (llvm::equal(readShape, sourceShape) || useInBoundsInsteadOfMasking)
    return transferReadOp;
  SmallVector<OpFoldResult> mixedSourceDims =
      tensor::getMixedSizes(builder, loc, source);
  Value mask =
      builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
  return mlir::vector::maskOperation(builder, transferReadOp, mask)
      ->getResult(0);
}

LogicalResult
vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
                                 ArrayRef<int64_t> inputVectorSizes) {
  LDBG("Iteration space static sizes:");
  LLVM_DEBUG(llvm::interleaveComma(shape, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (inputVectorSizes.size() != shape.size()) {
    LDBG("Input vector sizes don't match the number of loops");
    return failure();
  }
  if (ShapedType::isDynamicShape(inputVectorSizes)) {
    LDBG("Input vector sizes can't have dynamic dimensions");
    return failure();
  }
  if (!llvm::all_of(llvm::zip(shape, inputVectorSizes),
                    [](std::tuple<int64_t, int64_t> sizePair) {
                      int64_t staticSize = std::get<0>(sizePair);
                      int64_t inputSize = std::get<1>(sizePair);
                      return ShapedType::isDynamic(staticSize) ||
                             staticSize <= inputSize;
                    })) {
    LDBG("Input vector sizes must be greater than or equal to iteration space "
         "static sizes");
    return failure();
  }
  return success();
}
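
// Illustrative check for the validation above (sizes are hypothetical): for an
// iteration space (8, ?) and inputVectorSizes = (8, 16), validation succeeds:
// the ranks match, the vector sizes are static, and each static
// iteration-space size fits in the corresponding vector size. With
// inputVectorSizes = (4, 16), the static size 8 exceeds 4 and failure() is
// returned.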