//===- VectorTransferSplitRewritePatterns.cpp - Transfer Split Rewrites ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent patterns to rewrite a vector.transfer
// op into a fully in-bounds part and a partial part.
//
//===----------------------------------------------------------------------===//

#include <optional>
#include <type_traits>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-transfer-split"

using namespace mlir;
using namespace mlir::vector;

/// Build the condition to ensure that a particular VectorTransferOpInterface
/// is in-bounds.
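///
/// As an illustrative sketch (the SSA names and 1-D shapes below are
/// hypothetical, not taken from a test): for a transfer of vector<4xf32> at
/// index %i into %A : memref<?xf32>, the generated condition resembles:
/// ```
///    %end = affine.apply affine_map<(d0) -> (d0 + 4)>(%i)
///    %dim = memref.dim %A, %c0 : memref<?xf32>
///    %cond = arith.cmpi sle, %end, %dim : index
/// ```
/// Per-dimension conditions are combined with arith.andi.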
static Value createInBoundsCond(RewriterBase &b,
                                VectorTransferOpInterface xferOp) {
  assert(xferOp.getPermutationMap().isMinorIdentity() &&
         "Expected minor identity map");
  Value inBoundsCond;
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    // Zip over the resulting vector shape and memref indices.
    // If the dimension is known to be in-bounds, it does not participate in
    // the construction of `inBoundsCond`.
    if (xferOp.isDimInBounds(resultIdx))
      return;
    // Fold or create the check that `index + vector_size` <= `memref_size`.
    Location loc = xferOp.getLoc();
    int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
    OpFoldResult sum = affine::makeComposedFoldedAffineApply(
        b, loc, b.getAffineDimExpr(0) + b.getAffineConstantExpr(vectorSize),
        {xferOp.getIndices()[indicesIdx]});
    OpFoldResult dimSz =
        memref::getMixedSize(b, loc, xferOp.getSource(), indicesIdx);
    auto maybeCstSum = getConstantIntValue(sum);
    auto maybeCstDimSz = getConstantIntValue(dimSz);
    if (maybeCstSum && maybeCstDimSz && *maybeCstSum <= *maybeCstDimSz)
      return;
    Value cond =
        b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle,
                                getValueOrCreateConstantIndexOp(b, loc, sum),
                                getValueOrCreateConstantIndexOp(b, loc, dimSz));
    // Conjunction over all dims for which we are in-bounds.
    if (inBoundsCond)
      inBoundsCond = b.create<arith::AndIOp>(loc, inBoundsCond, cond);
    else
      inBoundsCond = cond;
  });
  return inBoundsCond;
}

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fast path and a slow path.
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
/// At this time, only the vector.transfer_read case is implemented.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fast path, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slow path, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a buffer of one vector, alloca'ed at the top of the
/// function.
///
/// Preconditions:
///  1. `xferOp.getPermutationMap()` must be a minor identity map
///  2. the rank of the `xferOp.getSource()` and the rank of the
///     `xferOp.getVector()` must be equal. This will be relaxed in the future
///     but requires rank-reducing subviews.
static LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
  // TODO: support 0-d corner case.
  if (xferOp.getTransferRank() == 0)
    return failure();

  // TODO: expand support to these 2 cases.
  if (!xferOp.getPermutationMap().isMinorIdentity())
    return failure();
  // Must have some out-of-bounds dimension to be a candidate for splitting.
  if (!xferOp.hasOutOfBoundsDim())
    return failure();
  // Don't split transfer operations directly under an IfOp; this avoids
  // applying the pattern recursively.
  // TODO: improve the filtering condition to make it more applicable.
  if (isa<scf::IfOp>(xferOp->getParentOp()))
    return failure();
  return success();
}

/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
/// be cast. If the MemRefTypes don't have the same rank or are not strided,
/// return null; otherwise:
///   1. if `aT` and `bT` are cast-compatible, return `aT`.
///   2. else return a new MemRefType obtained by iterating over the shape and
///   strides and:
///     a. keeping the ones that are static and equal across `aT` and `bT`.
///     b. using a dynamic shape and/or stride for the dimensions that don't
///        agree.
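///
/// As an illustrative sketch (types invented for this example, not taken from
/// a test), case 2 above yields:
/// ```
///   aT     = memref<4x8xf32, strided<[8, 1], offset: 2>>
///   bT     = memref<4x?xf32, strided<[?, 1], offset: 0>>
///   result = memref<4x?xf32, strided<[?, 1], offset: ?>>
/// ```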
static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
  if (memref::CastOp::areCastCompatible(aT, bT))
    return aT;
  if (aT.getRank() != bT.getRank())
    return MemRefType();
  int64_t aOffset, bOffset;
  SmallVector<int64_t, 4> aStrides, bStrides;
  if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
      failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
      aStrides.size() != bStrides.size())
    return MemRefType();

  ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
  int64_t resOffset;
  SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
      resStrides(bT.getRank(), 0);
  for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
    resShape[idx] =
        (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamic;
    resStrides[idx] =
        (aStrides[idx] == bStrides[idx]) ? aStrides[idx] : ShapedType::kDynamic;
  }
  resOffset = (aOffset == bOffset) ? aOffset : ShapedType::kDynamic;
  return MemRefType::get(
      resShape, aT.getElementType(),
      StridedLayoutAttr::get(aT.getContext(), resOffset, resStrides));
}

/// Casts the given memref to a compatible memref type. If the source memref
/// has a different address space than the target type, a
/// `memref.memory_space_cast` is first inserted, followed by a `memref.cast`.
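///
/// As an illustrative sketch (the types and SSA names are hypothetical), for a
/// source in address space 3 and a compatible type in the default space:
/// ```
///   %0 = memref.memory_space_cast %src : memref<4x8xf32, 3> to memref<4x8xf32>
///   %1 = memref.cast %0 : memref<4x8xf32> to memref<?x?xf32>
/// ```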
static Value castToCompatibleMemRefType(OpBuilder &b, Value memref,
                                        MemRefType compatibleMemRefType) {
  MemRefType sourceType = cast<MemRefType>(memref.getType());
  Value res = memref;
  if (sourceType.getMemorySpace() != compatibleMemRefType.getMemorySpace()) {
    sourceType = MemRefType::get(
        sourceType.getShape(), sourceType.getElementType(),
        sourceType.getLayout(), compatibleMemRefType.getMemorySpace());
    res = b.create<memref::MemorySpaceCastOp>(memref.getLoc(), sourceType, res);
  }
  if (sourceType == compatibleMemRefType)
    return res;
  return b.create<memref::CastOp>(memref.getLoc(), compatibleMemRefType, res);
}

/// Operates under a scoped context to build the intersection between the
/// view `xferOp.getSource()` @ `xferOp.getIndices()` and the view `alloc`.
// TODO: view intersection/union/differences should be a proper std op.
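///
/// As an illustrative sketch (SSA names are hypothetical and result types are
/// elided, as in the other examples in this file), for a 2-D transfer read at
/// `[%i, %j]` from `%A : memref<?x?xf32>` with a vector-sized `alloc`, the
/// computed sizes and subviews resemble:
/// ```
///   %d0 = memref.dim %A, %c0 : memref<?x?xf32>
///   %a0 = memref.dim %alloc, %c0 : memref<4x4xf32>
///   %sz0 = affine.min affine_map<(d0, d1, d2) -> (d0 - d1, d2)>(%d0, %i, %a0)
///   ... (same for the second dimension) ...
///   %src = memref.subview %A[%i, %j] [%sz0, %sz1] [1, 1]
///   %dst = memref.subview %alloc[0, 0] [%sz0, %sz1] [1, 1]
/// ```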
static std::pair<Value, Value>
createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
                          Value alloc) {
  Location loc = xferOp.getLoc();
  int64_t memrefRank = xferOp.getShapedType().getRank();
  // TODO: relax this precondition, will require rank-reducing subviews.
  assert(memrefRank == cast<MemRefType>(alloc.getType()).getRank() &&
         "Expected memref rank to match the alloc rank");
  ValueRange leadingIndices =
      xferOp.getIndices().take_front(xferOp.getLeadingShapedRank());
  SmallVector<OpFoldResult, 4> sizes;
  sizes.append(leadingIndices.begin(), leadingIndices.end());
  auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    Value dimMemRef = b.create<memref::DimOp>(xferOp.getLoc(),
                                              xferOp.getSource(), indicesIdx);
    Value dimAlloc = b.create<memref::DimOp>(loc, alloc, resultIdx);
    Value index = xferOp.getIndices()[indicesIdx];
    AffineExpr i, j, k;
    bindDims(xferOp.getContext(), i, j, k);
    SmallVector<AffineMap, 4> maps =
        AffineMap::inferFromExprList(MapList{{i - j, k}}, b.getContext());
    // affine_min(%dimMemRef - %index, %dimAlloc)
    Value affineMin = b.create<affine::AffineMinOp>(
        loc, index.getType(), maps[0], ValueRange{dimMemRef, index, dimAlloc});
    sizes.push_back(affineMin);
  });

  SmallVector<OpFoldResult> srcIndices = llvm::to_vector<4>(llvm::map_range(
      xferOp.getIndices(), [](Value idx) -> OpFoldResult { return idx; }));
  SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
  auto copySrc = b.create<memref::SubViewOp>(
      loc, isaWrite ? alloc : xferOp.getSource(), srcIndices, sizes, strides);
  auto copyDest = b.create<memref::SubViewOp>(
      loc, isaWrite ? xferOp.getSource() : alloc, destIndices, sizes, strides);
  return std::make_pair(copySrc, copyDest);
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = linalg.fill(%pad, %alloc)
///      %3 = subview %view [...][...][...]
///      %4 = subview %alloc [0, 0] [...] [...]
///      linalg.copy(%3, %4)
///      %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %5, ... : compatibleMemRefType, index, index
///   }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp
createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
                            TypeRange returnTypes, Value inBoundsCond,
                            MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b.create<scf::IfOp>(
      loc, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
        scf::ValueVector viewAndIndices{res};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
                              xferOp.getIndices().end());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        b.create<linalg::FillOp>(loc, ValueRange{xferOp.getPadding()},
                                 ValueRange{alloc});
        // Take partial subview of memref which guarantees no dimension
        // overflows.
        IRRewriter rewriter(b);
        std::pair<Value, Value> copyArgs = createSubViewIntersection(
            rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
            alloc);
        b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
        Value casted =
            castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
///      %3 = vector.type_cast %extra_alloc :
///        memref<...> to memref<vector<...>>
///      store %2, %3[] : memref<vector<...>>
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///   }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp createFullPartialVectorTransferRead(
    RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
    Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  scf::IfOp fullPartialIfOp;
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b.create<scf::IfOp>(
      loc, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
        scf::ValueVector viewAndIndices{res};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
                              xferOp.getIndices().end());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        Operation *newXfer = b.clone(*xferOp.getOperation());
        Value vector = cast<VectorTransferOpInterface>(newXfer).getVector();
        b.create<memref::StoreOp>(
            loc, vector,
            b.create<vector::TypeCastOp>(
                loc, MemRefType::get({}, vector.getType()), alloc));

        Value casted =
            castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %3 = vector.type_cast %extra_alloc :
///        memref<...> to memref<vector<...>>
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///   }
/// ```
static ValueRange
getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
                          TypeRange returnTypes, Value inBoundsCond,
                          MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b
      .create<scf::IfOp>(
          loc, inBoundsCond,
          [&](OpBuilder &b, Location loc) {
            Value res =
                castToCompatibleMemRefType(b, memref, compatibleMemRefType);
            scf::ValueVector viewAndIndices{res};
            viewAndIndices.insert(viewAndIndices.end(),
                                  xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
            b.create<scf::YieldOp>(loc, viewAndIndices);
          },
          [&](OpBuilder &b, Location loc) {
            Value casted =
                castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
            scf::ValueVector viewAndIndices{casted};
            viewAndIndices.insert(viewAndIndices.end(),
                                  xferOp.getTransferRank(), zero);
            b.create<scf::YieldOp>(loc, viewAndIndices);
          })
      ->getResults();
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of single vector `alloc` has been allocated.
///   3. it originally wrote to %view
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %3 = subview %alloc [...][...][...]
///      %4 = subview %view [0, 0][...][...]
///      linalg.copy(%3, %4)
///   }
/// ```
static void createFullPartialLinalgCopy(RewriterBase &b,
                                        vector::TransferWriteOp xferOp,
                                        Value inBoundsCond, Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
    IRRewriter rewriter(b);
    std::pair<Value, Value> copyArgs = createSubViewIntersection(
        rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
        alloc);
    b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of single vector `alloc` has been allocated.
///   3. it originally wrote to %view
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %2 = load %alloc : memref<vector<...>>
///      vector.transfer_write %2, %view[...] : memref<A...>, vector<...>
///   }
/// ```
static void createFullPartialVectorTransferWrite(RewriterBase &b,
                                                 vector::TransferWriteOp xferOp,
                                                 Value inBoundsCond,
                                                 Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
    IRMapping mapping;
    Value load = b.create<memref::LoadOp>(
        loc,
        b.create<vector::TypeCastOp>(
            loc, MemRefType::get({}, xferOp.getVector().getType()), alloc),
        ValueRange());
    mapping.map(xferOp.getVector(), load);
    b.clone(*xferOp.getOperation(), mapping);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}

// TODO: Parallelism and thread-local considerations with a ParallelScope
// trait.
static Operation *getAutomaticAllocationScope(Operation *op) {
  // Find the closest surrounding allocation scope that is not a known looping
  // construct (an alloca placed inside a loop does not always lower to a
  // deallocation at the end of the loop).
  Operation *scope = nullptr;
  for (Operation *parent = op->getParentOp(); parent != nullptr;
       parent = parent->getParentOp()) {
    if (parent->hasTrait<OpTrait::AutomaticAllocationScope>())
      scope = parent;
    if (!isa<scf::ForOp, affine::AffineForOp>(parent))
      break;
  }
  assert(scope && "Expected op to be inside automatic allocation scope");
  return scope;
}

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fast path and a slow path.
///
/// For vector.transfer_read:
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fast path, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slow path, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a buffer of one vector, alloca'ed at the top of the
/// function.
///
/// For vector.transfer_write:
/// There are 2 conditional blocks. First a block to decide which memref and
/// indices to use for an unmasked, in-bounds write. Then a conditional block to
/// further copy a partial buffer into the final result in the slow path case.
///
/// Example (a 2-D vector.transfer_write):
/// ```
///    vector.transfer_write %arg, %0[...] : vector<...>, memref<A...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///     }
///    vector.transfer_write %arg, %1#0[%1#1, %1#2] {in_bounds = [true ...
///                                                               true]}
///    scf.if (%notInBounds) {
///      // slow path: not in-bounds vector.transfer or linalg.copy.
///    }
/// ```
/// where `alloc` is a buffer of one vector, alloca'ed at the top of the
/// function.
///
/// Preconditions:
///  1. `xferOp.getPermutationMap()` must be a minor identity map
///  2. the rank of the `xferOp.getSource()` and the rank of the
///     `xferOp.getVector()` must be equal. This will be relaxed in the future
///     but requires rank-reducing subviews.
LogicalResult mlir::vector::splitFullAndPartialTransfer(
    RewriterBase &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options, scf::IfOp *ifOp) {
  if (options.vectorTransferSplit == VectorTransferSplit::None)
    return failure();

  SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
  auto inBoundsAttr = b.getBoolArrayAttr(bools);
  if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
    b.modifyOpInPlace(xferOp, [&]() {
      xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
    });
    return success();
  }

  // Assert preconditions. Additionally, keep the variables in an inner scope
  // to ensure they aren't used in the wrong scopes further down.
  {
    assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
           "Expected splitFullAndPartialTransferPrecondition to hold");

    auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
    auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());

    if (!(xferReadOp || xferWriteOp))
      return failure();
    if (xferWriteOp && xferWriteOp.getMask())
      return failure();
    if (xferReadOp && xferReadOp.getMask())
      return failure();
  }

  RewriterBase::InsertionGuard guard(b);
  b.setInsertionPoint(xferOp);
  Value inBoundsCond = createInBoundsCond(
      b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
  if (!inBoundsCond)
    return failure();

  // Top of the function `alloc` for transient storage.
  Value alloc;
  {
    RewriterBase::InsertionGuard guard(b);
    Operation *scope = getAutomaticAllocationScope(xferOp);
    assert(scope->getNumRegions() == 1 &&
           "AutomaticAllocationScope with >1 regions");
    b.setInsertionPointToStart(&scope->getRegion(0).front());
    auto shape = xferOp.getVectorType().getShape();
    Type elementType = xferOp.getVectorType().getElementType();
    alloc = b.create<memref::AllocaOp>(scope->getLoc(),
                                       MemRefType::get(shape, elementType),
                                       ValueRange{}, b.getI64IntegerAttr(32));
  }

  MemRefType compatibleMemRefType =
      getCastCompatibleMemRefType(cast<MemRefType>(xferOp.getShapedType()),
                                  cast<MemRefType>(alloc.getType()));
  if (!compatibleMemRefType)
    return failure();

  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
                                   b.getIndexType());
  returnTypes[0] = compatibleMemRefType;

  if (auto xferReadOp =
          dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
    // Read case: full fill + partial copy -> in-bounds vector.xfer_read.
    scf::IfOp fullPartialIfOp =
        options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
            ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
                                                  inBoundsCond,
                                                  compatibleMemRefType, alloc)
            : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
                                          inBoundsCond, compatibleMemRefType,
                                          alloc);
    if (ifOp)
      *ifOp = fullPartialIfOp;

    // Set existing read op to in-bounds; it always reads from a full buffer.
    for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
      xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));

    b.modifyOpInPlace(xferOp, [&]() {
      xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
    });

    return success();
  }

  auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());

  // Decide which location to write the entire vector to.
  auto memrefAndIndices = getLocationToWriteFullVec(
      b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);

  // Do an in-bounds write to either the output or the extra allocated buffer.
  // The operation is cloned to prevent deleting information needed for the
  // later IR creation.
  IRMapping mapping;
  mapping.map(xferWriteOp.getSource(), memrefAndIndices.front());
  mapping.map(xferWriteOp.getIndices(), memrefAndIndices.drop_front());
  auto *clone = b.clone(*xferWriteOp, mapping);
  clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);

  // Create a potential copy from the allocated buffer to the final output in
  // the slow path case.
  if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
    createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
  else
    createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);

  b.eraseOp(xferOp);

  return success();
}

namespace {
/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
/// may take an extra filter to perform selection at a finer granularity.
struct VectorTransferFullPartialRewriter : public RewritePattern {
  using FilterConstraintType =
      std::function<LogicalResult(VectorTransferOpInterface op)>;

  explicit VectorTransferFullPartialRewriter(
      MLIRContext *context,
      VectorTransformsOptions options = VectorTransformsOptions(),
      FilterConstraintType filter =
          [](VectorTransferOpInterface op) { return success(); },
      PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options),
        filter(std::move(filter)) {}

  /// Performs the rewrite.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override;

private:
  VectorTransformsOptions options;
  FilterConstraintType filter;
};

} // namespace

LogicalResult VectorTransferFullPartialRewriter::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
  if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
      failed(filter(xferOp)))
    return failure();
  return splitFullAndPartialTransfer(rewriter, xferOp, options);
}

void mlir::vector::populateVectorTransferFullPartialPatterns(
    RewritePatternSet &patterns, const VectorTransformsOptions &options) {
  patterns.add<VectorTransferFullPartialRewriter>(patterns.getContext(),
                                                  options);
}
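
// Example of wiring the above patterns into a greedy rewrite (an illustrative
// sketch, not part of this file: `funcOp` is a hypothetical enclosing
// operation and the greedy driver lives in
// "mlir/Transforms/GreedyPatternRewriteDriver.h"):
//
//   RewritePatternSet patterns(funcOp.getContext());
//   vector::populateVectorTransferFullPartialPatterns(
//       patterns, VectorTransformsOptions().setVectorTransferSplit(
//                     VectorTransferSplit::LinalgCopy));
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));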