199ef9eebSMatthias Springer //===- VectorTransferSplitRewritePatterns.cpp - Transfer Split Rewrites ---===// 299ef9eebSMatthias Springer // 399ef9eebSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 499ef9eebSMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 599ef9eebSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 699ef9eebSMatthias Springer // 799ef9eebSMatthias Springer //===----------------------------------------------------------------------===// 899ef9eebSMatthias Springer // 999ef9eebSMatthias Springer // This file implements target-independent patterns to rewrite a vector.transfer 1099ef9eebSMatthias Springer // op into a fully in-bounds part and a partial part. 1199ef9eebSMatthias Springer // 1299ef9eebSMatthias Springer //===----------------------------------------------------------------------===// 1399ef9eebSMatthias Springer 14a1fe1f5fSKazu Hirata #include <optional> 152bc4c3e9SNicolas Vasilache #include <type_traits> 1699ef9eebSMatthias Springer 1799ef9eebSMatthias Springer #include "mlir/Dialect/Affine/IR/AffineOps.h" 18abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 1999ef9eebSMatthias Springer #include "mlir/Dialect/Linalg/IR/Linalg.h" 2099ef9eebSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h" 218b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h" 2299ef9eebSMatthias Springer #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 2399ef9eebSMatthias Springer 2499ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 2599ef9eebSMatthias Springer #include "mlir/IR/Matchers.h" 2699ef9eebSMatthias Springer #include "mlir/IR/PatternMatch.h" 2799ef9eebSMatthias Springer #include "mlir/Interfaces/VectorInterfaces.h" 2899ef9eebSMatthias Springer 2999ef9eebSMatthias Springer #include "llvm/ADT/DenseSet.h" 3099ef9eebSMatthias Springer #include "llvm/ADT/MapVector.h" 3199ef9eebSMatthias 
Springer #include "llvm/ADT/STLExtras.h" 3299ef9eebSMatthias Springer #include "llvm/Support/CommandLine.h" 3399ef9eebSMatthias Springer #include "llvm/Support/Debug.h" 3499ef9eebSMatthias Springer #include "llvm/Support/raw_ostream.h" 3599ef9eebSMatthias Springer 3699ef9eebSMatthias Springer #define DEBUG_TYPE "vector-transfer-split" 3799ef9eebSMatthias Springer 3899ef9eebSMatthias Springer using namespace mlir; 3999ef9eebSMatthias Springer using namespace mlir::vector; 4099ef9eebSMatthias Springer 4199ef9eebSMatthias Springer /// Build the condition to ensure that a particular VectorTransferOpInterface 4299ef9eebSMatthias Springer /// is in-bounds. 4399ef9eebSMatthias Springer static Value createInBoundsCond(RewriterBase &b, 4499ef9eebSMatthias Springer VectorTransferOpInterface xferOp) { 45fd5cda33SMatthias Springer assert(xferOp.getPermutationMap().isMinorIdentity() && 4699ef9eebSMatthias Springer "Expected minor identity map"); 4799ef9eebSMatthias Springer Value inBoundsCond; 4899ef9eebSMatthias Springer xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) { 4999ef9eebSMatthias Springer // Zip over the resulting vector shape and memref indices. 5099ef9eebSMatthias Springer // If the dimension is known to be in-bounds, it does not participate in 5199ef9eebSMatthias Springer // the construction of `inBoundsCond`. 5299ef9eebSMatthias Springer if (xferOp.isDimInBounds(resultIdx)) 5399ef9eebSMatthias Springer return; 5499ef9eebSMatthias Springer // Fold or create the check that `index + vector_size` <= `memref_size`. 
5599ef9eebSMatthias Springer Location loc = xferOp.getLoc(); 5699ef9eebSMatthias Springer int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx); 57030b18feSMatthias Springer OpFoldResult sum = affine::makeComposedFoldedAffineApply( 58030b18feSMatthias Springer b, loc, b.getAffineDimExpr(0) + b.getAffineConstantExpr(vectorSize), 598a40fcafSMatthias Springer {xferOp.getIndices()[indicesIdx]}); 60030b18feSMatthias Springer OpFoldResult dimSz = 618a40fcafSMatthias Springer memref::getMixedSize(b, loc, xferOp.getSource(), indicesIdx); 62030b18feSMatthias Springer auto maybeCstSum = getConstantIntValue(sum); 63030b18feSMatthias Springer auto maybeCstDimSz = getConstantIntValue(dimSz); 64030b18feSMatthias Springer if (maybeCstSum && maybeCstDimSz && *maybeCstSum <= *maybeCstDimSz) 6599ef9eebSMatthias Springer return; 66030b18feSMatthias Springer Value cond = 67030b18feSMatthias Springer b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle, 68030b18feSMatthias Springer getValueOrCreateConstantIndexOp(b, loc, sum), 69030b18feSMatthias Springer getValueOrCreateConstantIndexOp(b, loc, dimSz)); 7099ef9eebSMatthias Springer // Conjunction over all dims for which we are in-bounds. 7199ef9eebSMatthias Springer if (inBoundsCond) 7299ef9eebSMatthias Springer inBoundsCond = b.create<arith::AndIOp>(loc, inBoundsCond, cond); 7399ef9eebSMatthias Springer else 7499ef9eebSMatthias Springer inBoundsCond = cond; 7599ef9eebSMatthias Springer }); 7699ef9eebSMatthias Springer return inBoundsCond; 7799ef9eebSMatthias Springer } 7899ef9eebSMatthias Springer 7999ef9eebSMatthias Springer /// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds 8099ef9eebSMatthias Springer /// masking) fast path and a slow path. 8199ef9eebSMatthias Springer /// If `ifOp` is not null and the result is `success, the `ifOp` points to the 8299ef9eebSMatthias Springer /// newly created conditional upon function return. 
832bc4c3e9SNicolas Vasilache /// To accommodate for the fact that the original vector.transfer indexing may 842bc4c3e9SNicolas Vasilache /// be arbitrary and the slow path indexes @[0...0] in the temporary buffer, the 8599ef9eebSMatthias Springer /// scf.if op returns a view and values of type index. 8699ef9eebSMatthias Springer /// At this time, only vector.transfer_read case is implemented. 8799ef9eebSMatthias Springer /// 8899ef9eebSMatthias Springer /// Example (a 2-D vector.transfer_read): 8999ef9eebSMatthias Springer /// ``` 9099ef9eebSMatthias Springer /// %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...> 9199ef9eebSMatthias Springer /// ``` 9299ef9eebSMatthias Springer /// is transformed into: 9399ef9eebSMatthias Springer /// ``` 9499ef9eebSMatthias Springer /// %1:3 = scf.if (%inBounds) { 9599ef9eebSMatthias Springer /// // fast path, direct cast 9699ef9eebSMatthias Springer /// memref.cast %A: memref<A...> to compatibleMemRefType 9799ef9eebSMatthias Springer /// scf.yield %view : compatibleMemRefType, index, index 9899ef9eebSMatthias Springer /// } else { 9999ef9eebSMatthias Springer /// // slow path, not in-bounds vector.transfer or linalg.copy. 10099ef9eebSMatthias Springer /// memref.cast %alloc: memref<B...> to compatibleMemRefType 10199ef9eebSMatthias Springer /// scf.yield %4 : compatibleMemRefType, index, index 10299ef9eebSMatthias Springer // } 10399ef9eebSMatthias Springer /// %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]} 10499ef9eebSMatthias Springer /// ``` 10599ef9eebSMatthias Springer /// where `alloc` is a top of the function alloca'ed buffer of one vector. 10699ef9eebSMatthias Springer /// 10799ef9eebSMatthias Springer /// Preconditions: 108fd5cda33SMatthias Springer /// 1. `xferOp.getPermutationMap()` must be a minor identity map 1098a40fcafSMatthias Springer /// 2. the rank of the `xferOp.memref()` and the rank of the 1108a40fcafSMatthias Springer /// `xferOp.getVector()` must be equal. 
This will be relaxed in the future 1118a40fcafSMatthias Springer /// but requires rank-reducing subviews. 11299ef9eebSMatthias Springer static LogicalResult 11399ef9eebSMatthias Springer splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) { 11499ef9eebSMatthias Springer // TODO: support 0-d corner case. 11599ef9eebSMatthias Springer if (xferOp.getTransferRank() == 0) 11699ef9eebSMatthias Springer return failure(); 11799ef9eebSMatthias Springer 11899ef9eebSMatthias Springer // TODO: expand support to these 2 cases. 119fd5cda33SMatthias Springer if (!xferOp.getPermutationMap().isMinorIdentity()) 12099ef9eebSMatthias Springer return failure(); 12199ef9eebSMatthias Springer // Must have some out-of-bounds dimension to be a candidate for splitting. 12299ef9eebSMatthias Springer if (!xferOp.hasOutOfBoundsDim()) 12399ef9eebSMatthias Springer return failure(); 12499ef9eebSMatthias Springer // Don't split transfer operations directly under IfOp, this avoids applying 12599ef9eebSMatthias Springer // the pattern recursively. 12699ef9eebSMatthias Springer // TODO: improve the filtering condition to make it more applicable. 12799ef9eebSMatthias Springer if (isa<scf::IfOp>(xferOp->getParentOp())) 12899ef9eebSMatthias Springer return failure(); 12999ef9eebSMatthias Springer return success(); 13099ef9eebSMatthias Springer } 13199ef9eebSMatthias Springer 13299ef9eebSMatthias Springer /// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can 13399ef9eebSMatthias Springer /// be cast. If the MemRefTypes don't have the same rank or are not strided, 13499ef9eebSMatthias Springer /// return null; otherwise: 13599ef9eebSMatthias Springer /// 1. if `aT` and `bT` are cast-compatible, return `aT`. 13699ef9eebSMatthias Springer /// 2. else return a new MemRefType obtained by iterating over the shape and 13799ef9eebSMatthias Springer /// strides and: 13899ef9eebSMatthias Springer /// a. keeping the ones that are static and equal across `aT` and `bT`. 
13999ef9eebSMatthias Springer /// b. using a dynamic shape and/or stride for the dimensions that don't 14099ef9eebSMatthias Springer /// agree. 14199ef9eebSMatthias Springer static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) { 14299ef9eebSMatthias Springer if (memref::CastOp::areCastCompatible(aT, bT)) 14399ef9eebSMatthias Springer return aT; 14499ef9eebSMatthias Springer if (aT.getRank() != bT.getRank()) 14599ef9eebSMatthias Springer return MemRefType(); 14699ef9eebSMatthias Springer int64_t aOffset, bOffset; 14799ef9eebSMatthias Springer SmallVector<int64_t, 4> aStrides, bStrides; 148*6aaa8f25SMatthias Springer if (failed(aT.getStridesAndOffset(aStrides, aOffset)) || 149*6aaa8f25SMatthias Springer failed(bT.getStridesAndOffset(bStrides, bOffset)) || 15099ef9eebSMatthias Springer aStrides.size() != bStrides.size()) 15199ef9eebSMatthias Springer return MemRefType(); 15299ef9eebSMatthias Springer 15399ef9eebSMatthias Springer ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape(); 15499ef9eebSMatthias Springer int64_t resOffset; 15599ef9eebSMatthias Springer SmallVector<int64_t, 4> resShape(aT.getRank(), 0), 15699ef9eebSMatthias Springer resStrides(bT.getRank(), 0); 15799ef9eebSMatthias Springer for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) { 15899ef9eebSMatthias Springer resShape[idx] = 159399638f9SAliia Khasanova (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamic; 1602bc4c3e9SNicolas Vasilache resStrides[idx] = 1612bc4c3e9SNicolas Vasilache (aStrides[idx] == bStrides[idx]) ? aStrides[idx] : ShapedType::kDynamic; 16299ef9eebSMatthias Springer } 1632bc4c3e9SNicolas Vasilache resOffset = (aOffset == bOffset) ? 
aOffset : ShapedType::kDynamic; 16499ef9eebSMatthias Springer return MemRefType::get( 16599ef9eebSMatthias Springer resShape, aT.getElementType(), 166f096e72cSAlex Zinenko StridedLayoutAttr::get(aT.getContext(), resOffset, resStrides)); 16799ef9eebSMatthias Springer } 16899ef9eebSMatthias Springer 1695b6b2cafSQuinn Dawkins /// Casts the given memref to a compatible memref type. If the source memref has 1705b6b2cafSQuinn Dawkins /// a different address space than the target type, a `memref.memory_space_cast` 1715b6b2cafSQuinn Dawkins /// is first inserted, followed by a `memref.cast`. 1725b6b2cafSQuinn Dawkins static Value castToCompatibleMemRefType(OpBuilder &b, Value memref, 1735b6b2cafSQuinn Dawkins MemRefType compatibleMemRefType) { 174a5757c5bSChristian Sigg MemRefType sourceType = cast<MemRefType>(memref.getType()); 1755b6b2cafSQuinn Dawkins Value res = memref; 1765b6b2cafSQuinn Dawkins if (sourceType.getMemorySpace() != compatibleMemRefType.getMemorySpace()) { 1775b6b2cafSQuinn Dawkins sourceType = MemRefType::get( 1785b6b2cafSQuinn Dawkins sourceType.getShape(), sourceType.getElementType(), 1795b6b2cafSQuinn Dawkins sourceType.getLayout(), compatibleMemRefType.getMemorySpace()); 1805b6b2cafSQuinn Dawkins res = b.create<memref::MemorySpaceCastOp>(memref.getLoc(), sourceType, res); 1815b6b2cafSQuinn Dawkins } 1825b6b2cafSQuinn Dawkins if (sourceType == compatibleMemRefType) 1835b6b2cafSQuinn Dawkins return res; 1845b6b2cafSQuinn Dawkins return b.create<memref::CastOp>(memref.getLoc(), compatibleMemRefType, res); 1855b6b2cafSQuinn Dawkins } 1865b6b2cafSQuinn Dawkins 18799ef9eebSMatthias Springer /// Operates under a scoped context to build the intersection between the 1888a40fcafSMatthias Springer /// view `xferOp.getSource()` @ `xferOp.getIndices()` and the view `alloc`. 18999ef9eebSMatthias Springer // TODO: view intersection/union/differences should be a proper std op. 
static std::pair<Value, Value>
createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
                          Value alloc) {
  Location loc = xferOp.getLoc();
  int64_t memrefRank = xferOp.getShapedType().getRank();
  // TODO: relax this precondition, will require rank-reducing subviews.
  assert(memrefRank == cast<MemRefType>(alloc.getType()).getRank() &&
         "Expected memref rank to match the alloc rank");
  // Leading (non-transferred) dims contribute their index directly as a size
  // of 1-element extent in the intersection computation below.
  ValueRange leadingIndices =
      xferOp.getIndices().take_front(xferOp.getLeadingShapedRank());
  SmallVector<OpFoldResult, 4> sizes;
  sizes.append(leadingIndices.begin(), leadingIndices.end());
  auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    // For each transferred dim, the intersection size is
    // min(memref_dim - index, alloc_dim).
    Value dimMemRef = b.create<memref::DimOp>(xferOp.getLoc(),
                                              xferOp.getSource(), indicesIdx);
    Value dimAlloc = b.create<memref::DimOp>(loc, alloc, resultIdx);
    Value index = xferOp.getIndices()[indicesIdx];
    AffineExpr i, j, k;
    bindDims(xferOp.getContext(), i, j, k);
    SmallVector<AffineMap, 4> maps =
        AffineMap::inferFromExprList(MapList{{i - j, k}}, b.getContext());
    // affine_min(%dimMemRef - %index, %dimAlloc)
    Value affineMin = b.create<affine::AffineMinOp>(
        loc, index.getType(), maps[0], ValueRange{dimMemRef, index, dimAlloc});
    sizes.push_back(affineMin);
  });

  // The source side is addressed at the transfer indices; the destination
  // (staging) side is always addressed at [0...0] with unit strides.
  SmallVector<OpFoldResult> srcIndices = llvm::to_vector<4>(llvm::map_range(
      xferOp.getIndices(), [](Value idx) -> OpFoldResult { return idx; }));
  SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
  // For a write, data flows alloc -> memref; for a read, memref -> alloc.
  auto copySrc = b.create<memref::SubViewOp>(
      loc, isaWrite ? alloc : xferOp.getSource(), srcIndices, sizes, strides);
  auto copyDest = b.create<memref::SubViewOp>(
      loc, isaWrite ? xferOp.getSource() : alloc, destIndices, sizes, strides);
  return std::make_pair(copySrc, copyDest);
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = linalg.fill(%pad, %alloc)
///      %3 = subview %view [...][...][...]
///      %4 = subview %alloc [0, 0] [...] [...]
///      linalg.copy(%3, %4)
///      %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %5, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp
createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
                            TypeRange returnTypes, Value inBoundsCond,
                            MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b.create<scf::IfOp>(
      loc, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        // Fast path: the whole transfer is in-bounds; yield the (casted)
        // original memref and the original indices.
        Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
        scf::ValueVector viewAndIndices{res};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
                              xferOp.getIndices().end());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        // Slow path: fill the staging buffer with the padding value, then
        // copy in the valid (intersection) region from the source.
        b.create<linalg::FillOp>(loc, ValueRange{xferOp.getPadding()},
                                 ValueRange{alloc});
        // Take partial subview of memref which guarantees no dimension
        // overflows.
        IRRewriter rewriter(b);
        std::pair<Value, Value> copyArgs = createSubViewIntersection(
            rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
            alloc);
        b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
        // Yield the staging buffer with all-zero indices.
        Value casted =
            castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
: compatibleMemRefType, index, index 29399ef9eebSMatthias Springer /// } else { 29499ef9eebSMatthias Springer /// %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...> 29599ef9eebSMatthias Springer /// %3 = vector.type_cast %extra_alloc : 29699ef9eebSMatthias Springer /// memref<...> to memref<vector<...>> 29799ef9eebSMatthias Springer /// store %2, %3[] : memref<vector<...>> 29899ef9eebSMatthias Springer /// %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType 29999ef9eebSMatthias Springer /// scf.yield %4, ... : compatibleMemRefType, index, index 30099ef9eebSMatthias Springer /// } 30199ef9eebSMatthias Springer /// ``` 30299ef9eebSMatthias Springer /// Return the produced scf::IfOp. 30399ef9eebSMatthias Springer static scf::IfOp createFullPartialVectorTransferRead( 30499ef9eebSMatthias Springer RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes, 30599ef9eebSMatthias Springer Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) { 30699ef9eebSMatthias Springer Location loc = xferOp.getLoc(); 30799ef9eebSMatthias Springer scf::IfOp fullPartialIfOp; 30899ef9eebSMatthias Springer Value zero = b.create<arith::ConstantIndexOp>(loc, 0); 3097c38fd60SJacques Pienaar Value memref = xferOp.getSource(); 31099ef9eebSMatthias Springer return b.create<scf::IfOp>( 3111125c5c0SFrederik Gossen loc, inBoundsCond, 31299ef9eebSMatthias Springer [&](OpBuilder &b, Location loc) { 3135b6b2cafSQuinn Dawkins Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType); 31499ef9eebSMatthias Springer scf::ValueVector viewAndIndices{res}; 3157c38fd60SJacques Pienaar viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(), 3167c38fd60SJacques Pienaar xferOp.getIndices().end()); 31799ef9eebSMatthias Springer b.create<scf::YieldOp>(loc, viewAndIndices); 31899ef9eebSMatthias Springer }, 31999ef9eebSMatthias Springer [&](OpBuilder &b, Location loc) { 32099ef9eebSMatthias Springer Operation *newXfer = 
b.clone(*xferOp.getOperation()); 3218a40fcafSMatthias Springer Value vector = cast<VectorTransferOpInterface>(newXfer).getVector(); 32299ef9eebSMatthias Springer b.create<memref::StoreOp>( 32399ef9eebSMatthias Springer loc, vector, 32499ef9eebSMatthias Springer b.create<vector::TypeCastOp>( 32599ef9eebSMatthias Springer loc, MemRefType::get({}, vector.getType()), alloc)); 32699ef9eebSMatthias Springer 32799ef9eebSMatthias Springer Value casted = 3285b6b2cafSQuinn Dawkins castToCompatibleMemRefType(b, alloc, compatibleMemRefType); 32999ef9eebSMatthias Springer scf::ValueVector viewAndIndices{casted}; 33099ef9eebSMatthias Springer viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(), 33199ef9eebSMatthias Springer zero); 33299ef9eebSMatthias Springer b.create<scf::YieldOp>(loc, viewAndIndices); 33399ef9eebSMatthias Springer }); 33499ef9eebSMatthias Springer } 33599ef9eebSMatthias Springer 33699ef9eebSMatthias Springer /// Given an `xferOp` for which: 33799ef9eebSMatthias Springer /// 1. `inBoundsCond` and a `compatibleMemRefType` have been computed. 33899ef9eebSMatthias Springer /// 2. a memref of single vector `alloc` has been allocated. 33999ef9eebSMatthias Springer /// Produce IR resembling: 34099ef9eebSMatthias Springer /// ``` 34199ef9eebSMatthias Springer /// %1:3 = scf.if (%inBounds) { 34299ef9eebSMatthias Springer /// memref.cast %A: memref<A...> to compatibleMemRefType 34399ef9eebSMatthias Springer /// scf.yield %view, ... : compatibleMemRefType, index, index 34499ef9eebSMatthias Springer /// } else { 34599ef9eebSMatthias Springer /// %3 = vector.type_cast %extra_alloc : 34699ef9eebSMatthias Springer /// memref<...> to memref<vector<...>> 34799ef9eebSMatthias Springer /// %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType 34899ef9eebSMatthias Springer /// scf.yield %4, ... 
: compatibleMemRefType, index, index 34999ef9eebSMatthias Springer /// } 35099ef9eebSMatthias Springer /// ``` 35199ef9eebSMatthias Springer static ValueRange 35299ef9eebSMatthias Springer getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp, 35399ef9eebSMatthias Springer TypeRange returnTypes, Value inBoundsCond, 35499ef9eebSMatthias Springer MemRefType compatibleMemRefType, Value alloc) { 35599ef9eebSMatthias Springer Location loc = xferOp.getLoc(); 35699ef9eebSMatthias Springer Value zero = b.create<arith::ConstantIndexOp>(loc, 0); 3577c38fd60SJacques Pienaar Value memref = xferOp.getSource(); 35899ef9eebSMatthias Springer return b 35999ef9eebSMatthias Springer .create<scf::IfOp>( 3601125c5c0SFrederik Gossen loc, inBoundsCond, 36199ef9eebSMatthias Springer [&](OpBuilder &b, Location loc) { 3625b6b2cafSQuinn Dawkins Value res = 3635b6b2cafSQuinn Dawkins castToCompatibleMemRefType(b, memref, compatibleMemRefType); 36499ef9eebSMatthias Springer scf::ValueVector viewAndIndices{res}; 36599ef9eebSMatthias Springer viewAndIndices.insert(viewAndIndices.end(), 3667c38fd60SJacques Pienaar xferOp.getIndices().begin(), 3677c38fd60SJacques Pienaar xferOp.getIndices().end()); 36899ef9eebSMatthias Springer b.create<scf::YieldOp>(loc, viewAndIndices); 36999ef9eebSMatthias Springer }, 37099ef9eebSMatthias Springer [&](OpBuilder &b, Location loc) { 37199ef9eebSMatthias Springer Value casted = 3725b6b2cafSQuinn Dawkins castToCompatibleMemRefType(b, alloc, compatibleMemRefType); 37399ef9eebSMatthias Springer scf::ValueVector viewAndIndices{casted}; 37499ef9eebSMatthias Springer viewAndIndices.insert(viewAndIndices.end(), 37599ef9eebSMatthias Springer xferOp.getTransferRank(), zero); 37699ef9eebSMatthias Springer b.create<scf::YieldOp>(loc, viewAndIndices); 37799ef9eebSMatthias Springer }) 37899ef9eebSMatthias Springer ->getResults(); 37999ef9eebSMatthias Springer } 38099ef9eebSMatthias Springer 38199ef9eebSMatthias Springer /// Given an `xferOp` for which: 
38299ef9eebSMatthias Springer /// 1. `inBoundsCond` has been computed. 38399ef9eebSMatthias Springer /// 2. a memref of single vector `alloc` has been allocated. 38499ef9eebSMatthias Springer /// 3. it originally wrote to %view 38599ef9eebSMatthias Springer /// Produce IR resembling: 38699ef9eebSMatthias Springer /// ``` 38799ef9eebSMatthias Springer /// %notInBounds = arith.xori %inBounds, %true 38899ef9eebSMatthias Springer /// scf.if (%notInBounds) { 38999ef9eebSMatthias Springer /// %3 = subview %alloc [...][...][...] 39099ef9eebSMatthias Springer /// %4 = subview %view [0, 0][...][...] 39199ef9eebSMatthias Springer /// linalg.copy(%3, %4) 39299ef9eebSMatthias Springer /// } 39399ef9eebSMatthias Springer /// ``` 39499ef9eebSMatthias Springer static void createFullPartialLinalgCopy(RewriterBase &b, 39599ef9eebSMatthias Springer vector::TransferWriteOp xferOp, 39699ef9eebSMatthias Springer Value inBoundsCond, Value alloc) { 39799ef9eebSMatthias Springer Location loc = xferOp.getLoc(); 39899ef9eebSMatthias Springer auto notInBounds = b.create<arith::XOrIOp>( 39999ef9eebSMatthias Springer loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1)); 40099ef9eebSMatthias Springer b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) { 40199ef9eebSMatthias Springer IRRewriter rewriter(b); 40299ef9eebSMatthias Springer std::pair<Value, Value> copyArgs = createSubViewIntersection( 40399ef9eebSMatthias Springer rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()), 40499ef9eebSMatthias Springer alloc); 405ebc81537SAlexander Belyaev b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second); 40699ef9eebSMatthias Springer b.create<scf::YieldOp>(loc, ValueRange{}); 40799ef9eebSMatthias Springer }); 40899ef9eebSMatthias Springer } 40999ef9eebSMatthias Springer 41099ef9eebSMatthias Springer /// Given an `xferOp` for which: 41199ef9eebSMatthias Springer /// 1. `inBoundsCond` has been computed. 41299ef9eebSMatthias Springer /// 2. 
a memref of single vector `alloc` has been allocated. 41399ef9eebSMatthias Springer /// 3. it originally wrote to %view 41499ef9eebSMatthias Springer /// Produce IR resembling: 41599ef9eebSMatthias Springer /// ``` 41699ef9eebSMatthias Springer /// %notInBounds = arith.xori %inBounds, %true 41799ef9eebSMatthias Springer /// scf.if (%notInBounds) { 41899ef9eebSMatthias Springer /// %2 = load %alloc : memref<vector<...>> 41999ef9eebSMatthias Springer /// vector.transfer_write %2, %view[...] : memref<A...>, vector<...> 42099ef9eebSMatthias Springer /// } 42199ef9eebSMatthias Springer /// ``` 42299ef9eebSMatthias Springer static void createFullPartialVectorTransferWrite(RewriterBase &b, 42399ef9eebSMatthias Springer vector::TransferWriteOp xferOp, 42499ef9eebSMatthias Springer Value inBoundsCond, 42599ef9eebSMatthias Springer Value alloc) { 42699ef9eebSMatthias Springer Location loc = xferOp.getLoc(); 42799ef9eebSMatthias Springer auto notInBounds = b.create<arith::XOrIOp>( 42899ef9eebSMatthias Springer loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1)); 42999ef9eebSMatthias Springer b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) { 4304d67b278SJeff Niu IRMapping mapping; 43199ef9eebSMatthias Springer Value load = b.create<memref::LoadOp>( 4327c38fd60SJacques Pienaar loc, 4337c38fd60SJacques Pienaar b.create<vector::TypeCastOp>( 4341b60f0d7SJeff Niu loc, MemRefType::get({}, xferOp.getVector().getType()), alloc), 4351b60f0d7SJeff Niu ValueRange()); 4367c38fd60SJacques Pienaar mapping.map(xferOp.getVector(), load); 43799ef9eebSMatthias Springer b.clone(*xferOp.getOperation(), mapping); 43899ef9eebSMatthias Springer b.create<scf::YieldOp>(loc, ValueRange{}); 43999ef9eebSMatthias Springer }); 44099ef9eebSMatthias Springer } 44199ef9eebSMatthias Springer 4423c3810e7SNicolas Vasilache // TODO: Parallelism and threadlocal considerations with a ParallelScope trait. 
4433c3810e7SNicolas Vasilache static Operation *getAutomaticAllocationScope(Operation *op) { 4444c807f2fSAlex Zinenko // Find the closest surrounding allocation scope that is not a known looping 4454c807f2fSAlex Zinenko // construct (putting alloca's in loops doesn't always lower to deallocation 4464c807f2fSAlex Zinenko // until the end of the loop). 4474c807f2fSAlex Zinenko Operation *scope = nullptr; 4484c807f2fSAlex Zinenko for (Operation *parent = op->getParentOp(); parent != nullptr; 4494c807f2fSAlex Zinenko parent = parent->getParentOp()) { 4504c807f2fSAlex Zinenko if (parent->hasTrait<OpTrait::AutomaticAllocationScope>()) 4514c807f2fSAlex Zinenko scope = parent; 4524c48f016SMatthias Springer if (!isa<scf::ForOp, affine::AffineForOp>(parent)) 4534c807f2fSAlex Zinenko break; 4544c807f2fSAlex Zinenko } 4553c3810e7SNicolas Vasilache assert(scope && "Expected op to be inside automatic allocation scope"); 4563c3810e7SNicolas Vasilache return scope; 4573c3810e7SNicolas Vasilache } 4583c3810e7SNicolas Vasilache 45999ef9eebSMatthias Springer /// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds 46099ef9eebSMatthias Springer /// masking) fastpath and a slowpath. 46199ef9eebSMatthias Springer /// 46299ef9eebSMatthias Springer /// For vector.transfer_read: 46399ef9eebSMatthias Springer /// If `ifOp` is not null and the result is `success, the `ifOp` points to the 46499ef9eebSMatthias Springer /// newly created conditional upon function return. 46599ef9eebSMatthias Springer /// To accomodate for the fact that the original vector.transfer indexing may be 46699ef9eebSMatthias Springer /// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the 46799ef9eebSMatthias Springer /// scf.if op returns a view and values of type index. 
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fastpath, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slowpath, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a top of the function alloca'ed buffer of one vector.
///
/// For vector.transfer_write:
/// There are 2 conditional blocks. First a block to decide which memref and
/// indices to use for an unmasked, inbounds write. Then a conditional block to
/// further copy a partial buffer into the final result in the slow path case.
///
/// Example (a 2-D vector.transfer_write):
/// ```
///    vector.transfer_write %arg, %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_write %arg, %1#0[%1#1, %1#2] {in_bounds = [true ...
///                                                                    true]}
///    scf.if (%notInBounds) {
///      // slowpath: not in-bounds vector.transfer or linalg.copy.
///    }
/// ```
/// where `alloc` is a top of the function alloca'ed buffer of one vector.
///
/// Preconditions:
///  1. `xferOp.getPermutationMap()` must be a minor identity map
///  2. the rank of the `xferOp.getSource()` and the rank of the
///     `xferOp.getVector()` must be equal. This will be relaxed in the future
///     but requires rank-reducing subviews.
LogicalResult mlir::vector::splitFullAndPartialTransfer(
    RewriterBase &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options, scf::IfOp *ifOp) {
  if (options.vectorTransferSplit == VectorTransferSplit::None)
    return failure();

  // An all-true `in_bounds` attribute, attached below to whichever transfer
  // ends up operating on the guaranteed-full buffer.
  SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
  auto inBoundsAttr = b.getBoolArrayAttr(bools);
  if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
    // Force-mark the existing transfer in-bounds without splitting any IR.
    b.modifyOpInPlace(xferOp, [&]() {
      xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
    });
    return success();
  }

  // Assert preconditions. Additionally, keep the variables in an inner scope to
  // ensure they aren't used in the wrong scopes further down.
  {
    assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
           "Expected splitFullAndPartialTransferPrecondition to hold");

    auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
    auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());

    // Only unmasked transfer_read/transfer_write ops are supported.
    if (!(xferReadOp || xferWriteOp))
      return failure();
    if (xferWriteOp && xferWriteOp.getMask())
      return failure();
    if (xferReadOp && xferReadOp.getMask())
      return failure();
  }

  RewriterBase::InsertionGuard guard(b);
  b.setInsertionPoint(xferOp);
  // Build the runtime condition `index + vector_size <= memref_size` for every
  // dimension that is not statically known in-bounds.
  Value inBoundsCond = createInBoundsCond(
      b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
  if (!inBoundsCond)
    return failure();

  // Top of the function `alloc` for transient storage.
  Value alloc;
  {
    RewriterBase::InsertionGuard guard(b);
    // Hoist the alloca to the closest enclosing AutomaticAllocationScope so it
    // is not re-allocated on every loop iteration.
    Operation *scope = getAutomaticAllocationScope(xferOp);
    assert(scope->getNumRegions() == 1 &&
           "AutomaticAllocationScope with >1 regions");
    b.setInsertionPointToStart(&scope->getRegion(0).front());
    auto shape = xferOp.getVectorType().getShape();
    Type elementType = xferOp.getVectorType().getElementType();
    // 32-byte alignment for the single-vector buffer.
    alloc = b.create<memref::AllocaOp>(scope->getLoc(),
                                       MemRefType::get(shape, elementType),
                                       ValueRange{}, b.getI64IntegerAttr(32));
  }

  // Common memref type both the original memref and `alloc` can be cast to,
  // so the scf.if can yield either one.
  MemRefType compatibleMemRefType =
      getCastCompatibleMemRefType(cast<MemRefType>(xferOp.getShapedType()),
                                  cast<MemRefType>(alloc.getType()));
  if (!compatibleMemRefType)
    return failure();

  // scf.if result types: the memref to transfer from/to, plus one index per
  // transfer dimension.
  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
                                   b.getIndexType());
  returnTypes[0] = compatibleMemRefType;

  if (auto xferReadOp =
          dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
    // Read case: full fill + partial copy -> in-bounds vector.xfer_read.
    scf::IfOp fullPartialIfOp =
        options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
            ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
                                                  inBoundsCond,
                                                  compatibleMemRefType, alloc)
            : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
                                          inBoundsCond, compatibleMemRefType,
                                          alloc);
    if (ifOp)
      *ifOp = fullPartialIfOp;

    // Set existing read op to in-bounds, it always reads from a full buffer.
    // Redirect its memref and index operands to the scf.if results.
    for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
      xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));

    b.modifyOpInPlace(xferOp, [&]() {
      xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
    });

    return success();
  }

  auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());

  // Decide which location to write the entire vector to.
  auto memrefAndIndices = getLocationToWriteFullVec(
      b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);

  // Do an in bounds write to either the output or the extra allocated buffer.
  // The operation is cloned to prevent deleting information needed for the
  // later IR creation.
  IRMapping mapping;
  mapping.map(xferWriteOp.getSource(), memrefAndIndices.front());
  mapping.map(xferWriteOp.getIndices(), memrefAndIndices.drop_front());
  auto *clone = b.clone(*xferWriteOp, mapping);
  clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);

  // Create a potential copy from the allocated buffer to the final output in
  // the slow path case.
  if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
    createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
  else
    createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);

  // The original transfer_write has been fully replaced by the clone + the
  // conditional slow-path copy.
  b.eraseOp(xferOp);

  return success();
}

namespace {
/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
/// may take an extra filter to perform selection at a finer granularity.
6382bc4c3e9SNicolas Vasilache struct VectorTransferFullPartialRewriter : public RewritePattern { 6392bc4c3e9SNicolas Vasilache using FilterConstraintType = 6402bc4c3e9SNicolas Vasilache std::function<LogicalResult(VectorTransferOpInterface op)>; 6412bc4c3e9SNicolas Vasilache 6422bc4c3e9SNicolas Vasilache explicit VectorTransferFullPartialRewriter( 6432bc4c3e9SNicolas Vasilache MLIRContext *context, 6442bc4c3e9SNicolas Vasilache VectorTransformsOptions options = VectorTransformsOptions(), 6452bc4c3e9SNicolas Vasilache FilterConstraintType filter = 6462bc4c3e9SNicolas Vasilache [](VectorTransferOpInterface op) { return success(); }, 6472bc4c3e9SNicolas Vasilache PatternBenefit benefit = 1) 6482bc4c3e9SNicolas Vasilache : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options), 6492bc4c3e9SNicolas Vasilache filter(std::move(filter)) {} 6502bc4c3e9SNicolas Vasilache 6512bc4c3e9SNicolas Vasilache /// Performs the rewrite. 6522bc4c3e9SNicolas Vasilache LogicalResult matchAndRewrite(Operation *op, 6532bc4c3e9SNicolas Vasilache PatternRewriter &rewriter) const override; 6542bc4c3e9SNicolas Vasilache 6552bc4c3e9SNicolas Vasilache private: 6562bc4c3e9SNicolas Vasilache VectorTransformsOptions options; 6572bc4c3e9SNicolas Vasilache FilterConstraintType filter; 6582bc4c3e9SNicolas Vasilache }; 6592bc4c3e9SNicolas Vasilache 6602bc4c3e9SNicolas Vasilache } // namespace 6612bc4c3e9SNicolas Vasilache 6622bc4c3e9SNicolas Vasilache LogicalResult VectorTransferFullPartialRewriter::matchAndRewrite( 66399ef9eebSMatthias Springer Operation *op, PatternRewriter &rewriter) const { 66499ef9eebSMatthias Springer auto xferOp = dyn_cast<VectorTransferOpInterface>(op); 66599ef9eebSMatthias Springer if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) || 66699ef9eebSMatthias Springer failed(filter(xferOp))) 66799ef9eebSMatthias Springer return failure(); 6682fea658aSMatthias Springer return splitFullAndPartialTransfer(rewriter, xferOp, options); 
66999ef9eebSMatthias Springer } 6702bc4c3e9SNicolas Vasilache 6712bc4c3e9SNicolas Vasilache void mlir::vector::populateVectorTransferFullPartialPatterns( 6722bc4c3e9SNicolas Vasilache RewritePatternSet &patterns, const VectorTransformsOptions &options) { 6732bc4c3e9SNicolas Vasilache patterns.add<VectorTransferFullPartialRewriter>(patterns.getContext(), 6742bc4c3e9SNicolas Vasilache options); 6752bc4c3e9SNicolas Vasilache } 676