199ef9eebSMatthias Springer //===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===// 299ef9eebSMatthias Springer // 399ef9eebSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 499ef9eebSMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 599ef9eebSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 699ef9eebSMatthias Springer // 799ef9eebSMatthias Springer //===----------------------------------------------------------------------===// 899ef9eebSMatthias Springer // 999ef9eebSMatthias Springer // This file implements functions concerned with optimizing transfer_read and 1099ef9eebSMatthias Springer // transfer_write ops. 1199ef9eebSMatthias Springer // 1299ef9eebSMatthias Springer //===----------------------------------------------------------------------===// 136a8ba318SRiver Riddle 142ec98ffbSMatthias Springer #include "mlir/Dialect/Affine/IR/AffineOps.h" 15abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 1699ef9eebSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h" 1790d2f8c6SBenjamin Maxwell #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" 182ec98ffbSMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h" 19847048f4SDiego Caballero #include "mlir/Dialect/Utils/IndexingUtils.h" 2099ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h" 212bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 2299ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 2399ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 2499ef9eebSMatthias Springer #include "mlir/IR/Dominance.h" 25fc367dfaSMahesh Ravishankar #include "mlir/Interfaces/SideEffectInterfaces.h" 2699ef9eebSMatthias Springer #include "llvm/ADT/STLExtras.h" 2799ef9eebSMatthias Springer #include "llvm/ADT/StringRef.h" 2899ef9eebSMatthias Springer #include 
"llvm/Support/Debug.h" 2999ef9eebSMatthias Springer 3099ef9eebSMatthias Springer #define DEBUG_TYPE "vector-transfer-opt" 3199ef9eebSMatthias Springer 3299ef9eebSMatthias Springer #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") 3399ef9eebSMatthias Springer 3499ef9eebSMatthias Springer using namespace mlir; 3599ef9eebSMatthias Springer 3699ef9eebSMatthias Springer /// Return the ancestor op in the region or nullptr if the region is not 3799ef9eebSMatthias Springer /// an ancestor of the op. 3899ef9eebSMatthias Springer static Operation *findAncestorOpInRegion(Region *region, Operation *op) { 3999ef9eebSMatthias Springer for (; op != nullptr && op->getParentRegion() != region; 4099ef9eebSMatthias Springer op = op->getParentOp()) 4199ef9eebSMatthias Springer ; 4299ef9eebSMatthias Springer return op; 4399ef9eebSMatthias Springer } 4499ef9eebSMatthias Springer 4599ef9eebSMatthias Springer namespace { 4699ef9eebSMatthias Springer 4799ef9eebSMatthias Springer class TransferOptimization { 4899ef9eebSMatthias Springer public: 49553cebdeSNicolas Vasilache TransferOptimization(RewriterBase &rewriter, Operation *op) 50553cebdeSNicolas Vasilache : rewriter(rewriter), dominators(op), postDominators(op) {} 5199ef9eebSMatthias Springer void deadStoreOp(vector::TransferWriteOp); 5299ef9eebSMatthias Springer void storeToLoadForwarding(vector::TransferReadOp); 5399ef9eebSMatthias Springer void removeDeadOp() { 5499ef9eebSMatthias Springer for (Operation *op : opToErase) 55553cebdeSNicolas Vasilache rewriter.eraseOp(op); 5699ef9eebSMatthias Springer opToErase.clear(); 5799ef9eebSMatthias Springer } 5899ef9eebSMatthias Springer 5999ef9eebSMatthias Springer private: 60553cebdeSNicolas Vasilache RewriterBase &rewriter; 6199ef9eebSMatthias Springer bool isReachable(Operation *start, Operation *dest); 6299ef9eebSMatthias Springer DominanceInfo dominators; 6399ef9eebSMatthias Springer PostDominanceInfo postDominators; 6499ef9eebSMatthias Springer std::vector<Operation *> 
opToErase; 6599ef9eebSMatthias Springer }; 6699ef9eebSMatthias Springer 67834fcfedSDiego Caballero } // namespace 6899ef9eebSMatthias Springer /// Return true if there is a path from start operation to dest operation, 6999ef9eebSMatthias Springer /// otherwise return false. The operations have to be in the same region. 7099ef9eebSMatthias Springer bool TransferOptimization::isReachable(Operation *start, Operation *dest) { 7199ef9eebSMatthias Springer assert(start->getParentRegion() == dest->getParentRegion() && 7299ef9eebSMatthias Springer "This function only works for ops i the same region"); 7399ef9eebSMatthias Springer // Simple case where the start op dominate the destination. 7499ef9eebSMatthias Springer if (dominators.dominates(start, dest)) 7599ef9eebSMatthias Springer return true; 76804d3c4cSMatthias Springer return start->getBlock()->isReachable(dest->getBlock()); 7799ef9eebSMatthias Springer } 7899ef9eebSMatthias Springer 7999ef9eebSMatthias Springer /// For transfer_write to overwrite fully another transfer_write must: 8099ef9eebSMatthias Springer /// 1. Access the same memref with the same indices and vector type. 8199ef9eebSMatthias Springer /// 2. Post-dominate the other transfer_write operation. 8299ef9eebSMatthias Springer /// If several candidates are available, one must be post-dominated by all the 8399ef9eebSMatthias Springer /// others since they are all post-dominating the same transfer_write. We only 8499ef9eebSMatthias Springer /// consider the transfer_write post-dominated by all the other candidates as 8599ef9eebSMatthias Springer /// this will be the first transfer_write executed after the potentially dead 8699ef9eebSMatthias Springer /// transfer_write. 
8799ef9eebSMatthias Springer /// If we found such an overwriting transfer_write we know that the original 8899ef9eebSMatthias Springer /// transfer_write is dead if all reads that can be reached from the potentially 8999ef9eebSMatthias Springer /// dead transfer_write are dominated by the overwriting transfer_write. 9099ef9eebSMatthias Springer void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) { 9199ef9eebSMatthias Springer LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation() 9299ef9eebSMatthias Springer << "\n"); 93bc13437bSThomas Raoux llvm::SmallVector<Operation *, 8> blockingAccesses; 9499ef9eebSMatthias Springer Operation *firstOverwriteCandidate = nullptr; 954e2efea5SQuinn Dawkins Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getSource())); 96bc13437bSThomas Raoux llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(), 97bc13437bSThomas Raoux source.getUsers().end()); 98bc13437bSThomas Raoux llvm::SmallDenseSet<Operation *, 32> processed; 99bc13437bSThomas Raoux while (!users.empty()) { 100bc13437bSThomas Raoux Operation *user = users.pop_back_val(); 101bc13437bSThomas Raoux // If the user has already been processed skip. 102bc13437bSThomas Raoux if (!processed.insert(user).second) 103bc13437bSThomas Raoux continue; 1044e2efea5SQuinn Dawkins if (isa<ViewLikeOpInterface>(user)) { 10590d2f8c6SBenjamin Maxwell users.append(user->getUsers().begin(), user->getUsers().end()); 106bc13437bSThomas Raoux continue; 107bc13437bSThomas Raoux } 10886771d0bSSanjoy Das if (isMemoryEffectFree(user)) 109bc13437bSThomas Raoux continue; 11099ef9eebSMatthias Springer if (user == write.getOperation()) 11199ef9eebSMatthias Springer continue; 11299ef9eebSMatthias Springer if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) { 11399ef9eebSMatthias Springer // Check candidate that can override the store. 
11490d2f8c6SBenjamin Maxwell if (memref::isSameViewOrTrivialAlias( 11590d2f8c6SBenjamin Maxwell cast<MemrefValue>(nextWrite.getSource()), 11690d2f8c6SBenjamin Maxwell cast<MemrefValue>(write.getSource())) && 117bc13437bSThomas Raoux checkSameValueWAW(nextWrite, write) && 11899ef9eebSMatthias Springer postDominators.postDominates(nextWrite, write)) { 11999ef9eebSMatthias Springer if (firstOverwriteCandidate == nullptr || 12099ef9eebSMatthias Springer postDominators.postDominates(firstOverwriteCandidate, nextWrite)) 12199ef9eebSMatthias Springer firstOverwriteCandidate = nextWrite; 12299ef9eebSMatthias Springer else 12399ef9eebSMatthias Springer assert( 12499ef9eebSMatthias Springer postDominators.postDominates(nextWrite, firstOverwriteCandidate)); 12599ef9eebSMatthias Springer continue; 12699ef9eebSMatthias Springer } 12799ef9eebSMatthias Springer } 128bc13437bSThomas Raoux if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) { 129bc13437bSThomas Raoux // Don't need to consider disjoint accesses. 
130bc13437bSThomas Raoux if (vector::isDisjointTransferSet( 131bc13437bSThomas Raoux cast<VectorTransferOpInterface>(write.getOperation()), 1323049ac44SLei Zhang cast<VectorTransferOpInterface>(transferOp.getOperation()), 1333049ac44SLei Zhang /*testDynamicValueUsingBounds=*/true)) 134bc13437bSThomas Raoux continue; 135bc13437bSThomas Raoux } 136bc13437bSThomas Raoux blockingAccesses.push_back(user); 13799ef9eebSMatthias Springer } 13899ef9eebSMatthias Springer if (firstOverwriteCandidate == nullptr) 13999ef9eebSMatthias Springer return; 14099ef9eebSMatthias Springer Region *topRegion = firstOverwriteCandidate->getParentRegion(); 14199ef9eebSMatthias Springer Operation *writeAncestor = findAncestorOpInRegion(topRegion, write); 14299ef9eebSMatthias Springer assert(writeAncestor && 14399ef9eebSMatthias Springer "write op should be recursively part of the top region"); 14499ef9eebSMatthias Springer 145bc13437bSThomas Raoux for (Operation *access : blockingAccesses) { 146bc13437bSThomas Raoux Operation *accessAncestor = findAncestorOpInRegion(topRegion, access); 147bc13437bSThomas Raoux // TODO: if the access and write have the same ancestor we could recurse in 148bc13437bSThomas Raoux // the region to know if the access is reachable with more precision. 
149bc13437bSThomas Raoux if (accessAncestor == nullptr || 150bc13437bSThomas Raoux !isReachable(writeAncestor, accessAncestor)) 15199ef9eebSMatthias Springer continue; 152bc13437bSThomas Raoux if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) { 153bc13437bSThomas Raoux LLVM_DEBUG(DBGS() << "Store may not be dead due to op: " 154bc13437bSThomas Raoux << *accessAncestor << "\n"); 15599ef9eebSMatthias Springer return; 15699ef9eebSMatthias Springer } 15799ef9eebSMatthias Springer } 15899ef9eebSMatthias Springer LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation() 15999ef9eebSMatthias Springer << " overwritten by: " << *firstOverwriteCandidate << "\n"); 16099ef9eebSMatthias Springer opToErase.push_back(write.getOperation()); 16199ef9eebSMatthias Springer } 16299ef9eebSMatthias Springer 16399ef9eebSMatthias Springer /// A transfer_write candidate to storeToLoad forwarding must: 16499ef9eebSMatthias Springer /// 1. Access the same memref with the same indices and vector type as the 16599ef9eebSMatthias Springer /// transfer_read. 16699ef9eebSMatthias Springer /// 2. Dominate the transfer_read operation. 16799ef9eebSMatthias Springer /// If several candidates are available, one must be dominated by all the others 16899ef9eebSMatthias Springer /// since they are all dominating the same transfer_read. We only consider the 16999ef9eebSMatthias Springer /// transfer_write dominated by all the other candidates as this will be the 17099ef9eebSMatthias Springer /// last transfer_write executed before the transfer_read. 17199ef9eebSMatthias Springer /// If we found such a candidate we can do the forwarding if all the other 17299ef9eebSMatthias Springer /// potentially aliasing ops that may reach the transfer_read are post-dominated 17399ef9eebSMatthias Springer /// by the transfer_write. 
17499ef9eebSMatthias Springer void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) { 17599ef9eebSMatthias Springer if (read.hasOutOfBoundsDim()) 17699ef9eebSMatthias Springer return; 17799ef9eebSMatthias Springer LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation() 17899ef9eebSMatthias Springer << "\n"); 17999ef9eebSMatthias Springer SmallVector<Operation *, 8> blockingWrites; 18099ef9eebSMatthias Springer vector::TransferWriteOp lastwrite = nullptr; 1814e2efea5SQuinn Dawkins Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getSource())); 182bc13437bSThomas Raoux llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(), 183bc13437bSThomas Raoux source.getUsers().end()); 184bc13437bSThomas Raoux llvm::SmallDenseSet<Operation *, 32> processed; 185bc13437bSThomas Raoux while (!users.empty()) { 186bc13437bSThomas Raoux Operation *user = users.pop_back_val(); 187bc13437bSThomas Raoux // If the user has already been processed skip. 188bc13437bSThomas Raoux if (!processed.insert(user).second) 189bc13437bSThomas Raoux continue; 1904e2efea5SQuinn Dawkins if (isa<ViewLikeOpInterface>(user)) { 19190d2f8c6SBenjamin Maxwell users.append(user->getUsers().begin(), user->getUsers().end()); 19222f96ab6SAndrzej Warzyński continue; 19322f96ab6SAndrzej Warzyński } 19486771d0bSSanjoy Das if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user)) 19599ef9eebSMatthias Springer continue; 19699ef9eebSMatthias Springer if (auto write = dyn_cast<vector::TransferWriteOp>(user)) { 19799ef9eebSMatthias Springer // If there is a write, but we can prove that it is disjoint we can ignore 19899ef9eebSMatthias Springer // the write. 
19999ef9eebSMatthias Springer if (vector::isDisjointTransferSet( 20099ef9eebSMatthias Springer cast<VectorTransferOpInterface>(write.getOperation()), 2013049ac44SLei Zhang cast<VectorTransferOpInterface>(read.getOperation()), 2023049ac44SLei Zhang /*testDynamicValueUsingBounds=*/true)) 20399ef9eebSMatthias Springer continue; 20490d2f8c6SBenjamin Maxwell if (memref::isSameViewOrTrivialAlias( 20590d2f8c6SBenjamin Maxwell cast<MemrefValue>(read.getSource()), 20690d2f8c6SBenjamin Maxwell cast<MemrefValue>(write.getSource())) && 207bc13437bSThomas Raoux dominators.dominates(write, read) && checkSameValueRAW(write, read)) { 20899ef9eebSMatthias Springer if (lastwrite == nullptr || dominators.dominates(lastwrite, write)) 20999ef9eebSMatthias Springer lastwrite = write; 21099ef9eebSMatthias Springer else 21199ef9eebSMatthias Springer assert(dominators.dominates(write, lastwrite)); 21299ef9eebSMatthias Springer continue; 21399ef9eebSMatthias Springer } 21499ef9eebSMatthias Springer } 21599ef9eebSMatthias Springer blockingWrites.push_back(user); 21699ef9eebSMatthias Springer } 21799ef9eebSMatthias Springer 21899ef9eebSMatthias Springer if (lastwrite == nullptr) 21999ef9eebSMatthias Springer return; 22099ef9eebSMatthias Springer 22199ef9eebSMatthias Springer Region *topRegion = lastwrite->getParentRegion(); 22299ef9eebSMatthias Springer Operation *readAncestor = findAncestorOpInRegion(topRegion, read); 22399ef9eebSMatthias Springer assert(readAncestor && 22499ef9eebSMatthias Springer "read op should be recursively part of the top region"); 22599ef9eebSMatthias Springer 22699ef9eebSMatthias Springer for (Operation *write : blockingWrites) { 22799ef9eebSMatthias Springer Operation *writeAncestor = findAncestorOpInRegion(topRegion, write); 22899ef9eebSMatthias Springer // TODO: if the store and read have the same ancestor we could recurse in 22999ef9eebSMatthias Springer // the region to know if the read is reachable with more precision. 
23099ef9eebSMatthias Springer if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor)) 23199ef9eebSMatthias Springer continue; 23299ef9eebSMatthias Springer if (!postDominators.postDominates(lastwrite, write)) { 23399ef9eebSMatthias Springer LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: " 23499ef9eebSMatthias Springer << *write << "\n"); 23599ef9eebSMatthias Springer return; 23699ef9eebSMatthias Springer } 23799ef9eebSMatthias Springer } 23899ef9eebSMatthias Springer 23999ef9eebSMatthias Springer LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation() 24099ef9eebSMatthias Springer << " to: " << *read.getOperation() << "\n"); 2417c38fd60SJacques Pienaar read.replaceAllUsesWith(lastwrite.getVector()); 24299ef9eebSMatthias Springer opToErase.push_back(read.getOperation()); 24399ef9eebSMatthias Springer } 24499ef9eebSMatthias Springer 245bf897d5dSCullen Rhodes /// Converts OpFoldResults to int64_t shape without unit dims. 246bf897d5dSCullen Rhodes static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) { 247bf897d5dSCullen Rhodes SmallVector<int64_t> reducedShape; 248bf897d5dSCullen Rhodes for (const auto size : mixedSizes) { 249bf897d5dSCullen Rhodes if (llvm::dyn_cast_if_present<Value>(size)) { 250bf897d5dSCullen Rhodes reducedShape.push_back(ShapedType::kDynamic); 251bf897d5dSCullen Rhodes continue; 252bf897d5dSCullen Rhodes } 253bf897d5dSCullen Rhodes 2546e41483bSKazu Hirata auto value = cast<IntegerAttr>(cast<Attribute>(size)).getValue(); 255bf897d5dSCullen Rhodes if (value == 1) 256bf897d5dSCullen Rhodes continue; 257bf897d5dSCullen Rhodes reducedShape.push_back(value.getSExtValue()); 258bf897d5dSCullen Rhodes } 259bf897d5dSCullen Rhodes return reducedShape; 260bf897d5dSCullen Rhodes } 261bf897d5dSCullen Rhodes 26299ef9eebSMatthias Springer /// Drops unit dimensions from the input MemRefType. 
263bf897d5dSCullen Rhodes static MemRefType dropUnitDims(MemRefType inputType, 264bf897d5dSCullen Rhodes ArrayRef<OpFoldResult> offsets, 265bf897d5dSCullen Rhodes ArrayRef<OpFoldResult> sizes, 266bf897d5dSCullen Rhodes ArrayRef<OpFoldResult> strides) { 267bf897d5dSCullen Rhodes auto targetShape = getReducedShape(sizes); 26899ef9eebSMatthias Springer Type rankReducedType = memref::SubViewOp::inferRankReducedResultType( 2696c3c5f80SMatthias Springer targetShape, inputType, offsets, sizes, strides); 270*6aaa8f25SMatthias Springer return cast<MemRefType>(rankReducedType).canonicalizeStridedLayout(); 27199ef9eebSMatthias Springer } 27299ef9eebSMatthias Springer 27399ef9eebSMatthias Springer /// Creates a rank-reducing memref.subview op that drops unit dims from its 27499ef9eebSMatthias Springer /// input. Or just returns the input if it was already without unit dims. 27599ef9eebSMatthias Springer static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter, 27699ef9eebSMatthias Springer mlir::Location loc, 27799ef9eebSMatthias Springer Value input) { 2785550c821STres Popp MemRefType inputType = cast<MemRefType>(input.getType()); 279bf897d5dSCullen Rhodes SmallVector<OpFoldResult> offsets(inputType.getRank(), 280bf897d5dSCullen Rhodes rewriter.getIndexAttr(0)); 281bf897d5dSCullen Rhodes SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input); 282bf897d5dSCullen Rhodes SmallVector<OpFoldResult> strides(inputType.getRank(), 283bf897d5dSCullen Rhodes rewriter.getIndexAttr(1)); 284bf897d5dSCullen Rhodes MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides); 285bf897d5dSCullen Rhodes 286*6aaa8f25SMatthias Springer if (resultType.canonicalizeStridedLayout() == 287*6aaa8f25SMatthias Springer inputType.canonicalizeStridedLayout()) 28899ef9eebSMatthias Springer return input; 289bf897d5dSCullen Rhodes return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets, 290bf897d5dSCullen Rhodes sizes, strides); 
29199ef9eebSMatthias Springer } 29299ef9eebSMatthias Springer 29399ef9eebSMatthias Springer /// Returns the number of dims that aren't unit dims. 29499ef9eebSMatthias Springer static int getReducedRank(ArrayRef<int64_t> shape) { 29599ef9eebSMatthias Springer return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; }); 29699ef9eebSMatthias Springer } 29799ef9eebSMatthias Springer 298bf897d5dSCullen Rhodes /// Trims non-scalable one dimensions from `oldType` and returns the result 299bf897d5dSCullen Rhodes /// type. 300bf897d5dSCullen Rhodes static VectorType trimNonScalableUnitDims(VectorType oldType) { 301bf897d5dSCullen Rhodes SmallVector<int64_t> newShape; 302bf897d5dSCullen Rhodes SmallVector<bool> newScalableDims; 303bf897d5dSCullen Rhodes for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) { 304bf897d5dSCullen Rhodes if (dimSize == 1 && !oldType.getScalableDims()[dimIdx]) 305bf897d5dSCullen Rhodes continue; 306bf897d5dSCullen Rhodes newShape.push_back(dimSize); 307bf897d5dSCullen Rhodes newScalableDims.push_back(oldType.getScalableDims()[dimIdx]); 308bf897d5dSCullen Rhodes } 309bf897d5dSCullen Rhodes return VectorType::get(newShape, oldType.getElementType(), newScalableDims); 310bf897d5dSCullen Rhodes } 311bf897d5dSCullen Rhodes 312bf897d5dSCullen Rhodes // Rewrites vector.create_mask 'op' to drop non-scalable one dimensions. 
313bf897d5dSCullen Rhodes static FailureOr<Value> 314bf897d5dSCullen Rhodes createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc, 315bf897d5dSCullen Rhodes vector::CreateMaskOp op) { 316bf897d5dSCullen Rhodes auto type = op.getType(); 317fdf84cbfSQuinn Dawkins VectorType reducedType = trimNonScalableUnitDims(type); 318bf897d5dSCullen Rhodes if (reducedType.getRank() == type.getRank()) 319bf897d5dSCullen Rhodes return failure(); 320bf897d5dSCullen Rhodes 321bf897d5dSCullen Rhodes SmallVector<Value> reducedOperands; 322bf897d5dSCullen Rhodes for (auto [dim, dimIsScalable, operand] : llvm::zip_equal( 323bf897d5dSCullen Rhodes type.getShape(), type.getScalableDims(), op.getOperands())) { 324bf897d5dSCullen Rhodes if (dim == 1 && !dimIsScalable) { 325bf897d5dSCullen Rhodes // If the mask for the unit dim is not a constant of 1, do nothing. 326bf897d5dSCullen Rhodes auto constant = operand.getDefiningOp<arith::ConstantIndexOp>(); 327bf897d5dSCullen Rhodes if (!constant || (constant.value() != 1)) 328bf897d5dSCullen Rhodes return failure(); 329bf897d5dSCullen Rhodes continue; 330bf897d5dSCullen Rhodes } 331bf897d5dSCullen Rhodes reducedOperands.push_back(operand); 332bf897d5dSCullen Rhodes } 333bf897d5dSCullen Rhodes return rewriter 334bf897d5dSCullen Rhodes .create<vector::CreateMaskOp>(loc, reducedType, reducedOperands) 335bf897d5dSCullen Rhodes .getResult(); 336834fcfedSDiego Caballero } 337834fcfedSDiego Caballero 338834fcfedSDiego Caballero namespace { 339834fcfedSDiego Caballero 340834fcfedSDiego Caballero /// Rewrites `vector.transfer_read` ops where the source has unit dims, by 341834fcfedSDiego Caballero /// inserting a memref.subview dropping those unit dims. The vector shapes are 342834fcfedSDiego Caballero /// also reduced accordingly. 
class TransferReadDropUnitDimsPattern
    : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  /// Rewrites a (possibly masked) transfer_read from a memref with unit dims
  /// into a rank-reduced read from a memref.subview, then shape_casts the
  /// result back to the original vector type.
  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp,
                            vector::MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferReadOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor types.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the source shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
    // out.
    if (reducedRank == 0 && maskingOp)
      return failure();
    // Check if the reduced vector shape matches the reduced source shape.
    // Otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    // All indices must be statically zero; non-zero indices into unit dims
    // are not handled.
    if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    // Rank-reduce an explicit mask operand, if present. Only masks produced
    // by vector.create_mask are supported.
    Value maskOp = transferReadOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
                            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }

    // Build the rank-reduced read: subview source, zero indices, identity map.
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    Operation *newTransferReadOp = rewriter.create<vector::TransferReadOp>(
        loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
        transferReadOp.getPadding(), maskOp,
        rewriter.getBoolArrayAttr(inBounds));

    // If the original read was wrapped in vector.mask, re-wrap the new read
    // with a shape-cast of the enclosing mask.
    if (maskingOp) {
      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
          maskingOp.getMask());
      newTransferReadOp = mlir::vector::maskOperation(
          rewriter, newTransferReadOp, shapeCastMask);
    }

    // Restore the original vector type for the pattern's users.
    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, vectorType, newTransferReadOp->getResults()[0]);

    return shapeCast;
  }
};

/// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
/// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
/// vector shapes are also reduced accordingly.
class TransferWriteDropUnitDimsPattern
    : public vector::MaskableOpRewritePattern<vector::TransferWriteOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  /// Rewrites a (possibly masked) transfer_write into a memref with unit dims
  /// into a rank-reduced write of a shape-cast vector into a memref.subview.
  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp,
                            vector::MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor type.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the destination shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
    // out.
    if (reducedRank == 0 && maskingOp)
      return failure();
    // Check if the reduced vector shape matches the reduced destination shape.
    // Otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    // All indices must be statically zero; non-zero indices into unit dims
    // are not handled.
    if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    // Rank-reduce an explicit mask operand, if present. Only masks produced
    // by vector.create_mask are supported.
    Value maskOp = transferWriteOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferWriteOp,
            "unsupported mask op, only 'vector.create_mask' is "
            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }
    // Build the rank-reduced write: shape-cast the vector, write into the
    // subview with zero indices and an identity map.
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, reducedVectorType, vector);
    Operation *newXferWrite = rewriter.create<vector::TransferWriteOp>(
        loc, Type(), shapeCastSrc, reducedShapeSource, zeros, identityMap,
        maskOp, rewriter.getBoolArrayAttr(inBounds));

    // If the original write was wrapped in vector.mask, re-wrap the new write
    // with a shape-cast of the enclosing mask.
    if (maskingOp) {
      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
          maskingOp.getMask());
      newXferWrite =
          mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask);
    }

    if (transferWriteOp.hasPureTensorSemantics())
      return newXferWrite->getResults()[0];

    // With Memref semantics, there's no return value. Use empty value to signal
    // success.
    return Value();
  }
};

} // namespace

/// Creates a memref.collapse_shape collapsing all inner dimensions of the
/// input starting at `firstDimToCollapse`.
511f4ac9509SBenoit Jacob static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc, 512f4ac9509SBenoit Jacob Value input, int64_t firstDimToCollapse) { 5135550c821STres Popp ShapedType inputType = cast<ShapedType>(input.getType()); 514f4ac9509SBenoit Jacob if (inputType.getRank() == 1) 515f4ac9509SBenoit Jacob return input; 516f4ac9509SBenoit Jacob SmallVector<ReassociationIndices> reassociation; 517f4ac9509SBenoit Jacob for (int64_t i = 0; i < firstDimToCollapse; ++i) 518f4ac9509SBenoit Jacob reassociation.push_back(ReassociationIndices{i}); 519f4ac9509SBenoit Jacob ReassociationIndices collapsedIndices; 520f4ac9509SBenoit Jacob for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i) 521f4ac9509SBenoit Jacob collapsedIndices.push_back(i); 522f4ac9509SBenoit Jacob reassociation.push_back(collapsedIndices); 523f4ac9509SBenoit Jacob return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation); 524f4ac9509SBenoit Jacob } 525f4ac9509SBenoit Jacob 52653ddc874SHan-Chung Wang /// Returns the new indices that collapses the inner dimensions starting from 52753ddc874SHan-Chung Wang /// the `firstDimToCollapse` dimension. 52853ddc874SHan-Chung Wang static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter, 52953ddc874SHan-Chung Wang Location loc, 53053ddc874SHan-Chung Wang ArrayRef<int64_t> shape, 53153ddc874SHan-Chung Wang ValueRange indices, 53253ddc874SHan-Chung Wang int64_t firstDimToCollapse) { 53353ddc874SHan-Chung Wang assert(firstDimToCollapse < static_cast<int64_t>(indices.size())); 53453ddc874SHan-Chung Wang 53553ddc874SHan-Chung Wang // If all the collapsed indices are zero then no extra logic is needed. 53653ddc874SHan-Chung Wang // Otherwise, a new offset/index has to be computed. 
53753ddc874SHan-Chung Wang SmallVector<Value> indicesAfterCollapsing( 53853ddc874SHan-Chung Wang indices.begin(), indices.begin() + firstDimToCollapse); 53953ddc874SHan-Chung Wang SmallVector<Value> indicesToCollapse(indices.begin() + firstDimToCollapse, 54053ddc874SHan-Chung Wang indices.end()); 54153ddc874SHan-Chung Wang if (llvm::all_of(indicesToCollapse, isZeroIndex)) { 54253ddc874SHan-Chung Wang indicesAfterCollapsing.push_back(indicesToCollapse[0]); 54353ddc874SHan-Chung Wang return indicesAfterCollapsing; 544f4ac9509SBenoit Jacob } 54553ddc874SHan-Chung Wang 54653ddc874SHan-Chung Wang // Compute the remaining trailing index/offset required for reading from 54753ddc874SHan-Chung Wang // the collapsed memref: 54853ddc874SHan-Chung Wang // 54953ddc874SHan-Chung Wang // offset = 0 55053ddc874SHan-Chung Wang // for (i = firstDimToCollapse; i < outputRank; ++i) 55153ddc874SHan-Chung Wang // offset += sourceType.getDimSize(i) * transferReadOp.indices[i] 55253ddc874SHan-Chung Wang // 55353ddc874SHan-Chung Wang // For this example: 55453ddc874SHan-Chung Wang // %2 = vector.transfer_read/write %arg4[%c0, %arg0, %c0] (...) : 55553ddc874SHan-Chung Wang // memref<1x43x2xi32>, vector<1x2xi32> 55653ddc874SHan-Chung Wang // which would be collapsed to: 55753ddc874SHan-Chung Wang // %1 = vector.transfer_read/write %collapse_shape[%c0, %offset] (...) : 55853ddc874SHan-Chung Wang // memref<1x86xi32>, vector<2xi32> 55953ddc874SHan-Chung Wang // one would get the following offset: 56053ddc874SHan-Chung Wang // %offset = %arg0 * 43 56153ddc874SHan-Chung Wang OpFoldResult collapsedOffset = 56253ddc874SHan-Chung Wang rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult(); 56353ddc874SHan-Chung Wang 56453ddc874SHan-Chung Wang auto collapsedStrides = computeSuffixProduct( 56553ddc874SHan-Chung Wang ArrayRef<int64_t>(shape.begin() + firstDimToCollapse, shape.end())); 56653ddc874SHan-Chung Wang 56753ddc874SHan-Chung Wang // Compute the collapsed offset. 
56853ddc874SHan-Chung Wang auto &&[collapsedExpr, collapsedVals] = 56953ddc874SHan-Chung Wang computeLinearIndex(collapsedOffset, collapsedStrides, indicesToCollapse); 57053ddc874SHan-Chung Wang collapsedOffset = affine::makeComposedFoldedAffineApply( 57153ddc874SHan-Chung Wang rewriter, loc, collapsedExpr, collapsedVals); 57253ddc874SHan-Chung Wang 5736e41483bSKazu Hirata if (auto value = dyn_cast<Value>(collapsedOffset)) { 5746e41483bSKazu Hirata indicesAfterCollapsing.push_back(value); 57553ddc874SHan-Chung Wang } else { 57653ddc874SHan-Chung Wang indicesAfterCollapsing.push_back(rewriter.create<arith::ConstantIndexOp>( 57753ddc874SHan-Chung Wang loc, *getConstantIntValue(collapsedOffset))); 57853ddc874SHan-Chung Wang } 57953ddc874SHan-Chung Wang 58053ddc874SHan-Chung Wang return indicesAfterCollapsing; 58199ef9eebSMatthias Springer } 58299ef9eebSMatthias Springer 583834fcfedSDiego Caballero namespace { 584834fcfedSDiego Caballero 58599ef9eebSMatthias Springer /// Rewrites contiguous row-major vector.transfer_read ops by inserting 58699ef9eebSMatthias Springer /// memref.collapse_shape on the source so that the resulting 58799ef9eebSMatthias Springer /// vector.transfer_read has a 1D source. Requires the source shape to be 58899ef9eebSMatthias Springer /// already reduced i.e. without unit dims. 58934de7fd4SAndrzej Warzyński /// 59071441ed1SDiego Caballero /// If `targetVectorBitwidth` is provided, the flattening will only happen if 59171441ed1SDiego Caballero /// the trailing dimension of the vector read is smaller than the provided 59271441ed1SDiego Caballero /// bitwidth. 
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
public:
  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
                                               unsigned vectorBitwidth,
                                               PatternBenefit benefit)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    auto source = transferReadOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check pre-conditions
    // The contiguity check below only applies to memrefs; bail out on
    // (tensor) sources, for which `sourceType` is null.
    if (!sourceType)
      return failure();
    // If this is already 0D/1D, there's nothing to do.
    if (vectorType.getRank() <= 1)
      return failure();
    // The bitwidth gate below only makes sense for signless int/float
    // element types.
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    // Only flatten when the trailing vector dim is narrower than the target
    // vector bitwidth, i.e. when flattening can actually widen the transfer.
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    // The read slice must be contiguous in the source memref for the
    // collapse to be a pure re-indexing.
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferReadOp.getMask())
      return failure();

    // All source dims beyond this one get collapsed into a single trailing
    // dim (source rank >= vector rank thanks to the minor-identity map).
    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();

    // 1. Collapse the source memref
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // 2. Generate input args for a new vector.transfer_read that will read
    // from the collapsed memref.
    // 2.1. New dim exprs + affine map: a minor-identity map reading the
    // single (collapsed) trailing dimension.
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());

    // 2.2 New indices: leading indices preserved, trailing ones linearized.
    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
                            transferReadOp.getIndices(), firstDimToCollapse);

    // 3. Create new vector.transfer_read that reads from the collapsed memref
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
        loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
    // The flattened read is in-bounds by construction (out-of-bounds dims
    // were rejected above).
    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));

    // 4. Replace the old transfer_read with the new one reading from the
    // collapsed shape, restoring the original vector type via shape_cast.
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, cast<VectorType>(vector.getType()), flatRead);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};

/// Rewrites contiguous row-major vector.transfer_write ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_write has a 1D source. Requires the source shape to be
/// already reduced i.e. without unit dims.
///
/// If `targetVectorBitwidth` is provided, the flattening will only happen if
/// the trailing dimension of the vector to be written is smaller than the
/// provided bitwidth.
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
public:
  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
                                                unsigned vectorBitwidth,
                                                PatternBenefit benefit)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check pre-conditions
    // The contiguity check below only applies to memrefs; bail out on
    // (tensor) destinations, for which `sourceType` is null.
    if (!sourceType)
      return failure();
    // If this is already 0D/1D, there's nothing to do.
    if (vectorType.getRank() <= 1)
      return failure();
    // The bitwidth gate below only makes sense for signless int/float
    // element types.
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    // Only flatten when the trailing vector dim is narrower than the target
    // vector bitwidth, i.e. when flattening can actually widen the transfer.
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    // The written slice must be contiguous in the destination memref.
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();

    // All destination dims beyond this one get collapsed into a single
    // trailing dim.
    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();

    // 1. Collapse the destination memref.
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // 2. Generate input args for a new vector.transfer_write that will write
    // to the collapsed memref.
    // 2.1. New dim exprs + affine map: a minor-identity map writing the
    // single (collapsed) trailing dimension.
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());

    // 2.2 New indices: leading indices preserved, trailing ones linearized.
    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
                            transferWriteOp.getIndices(), firstDimToCollapse);

    // 3. Create new vector.transfer_write that writes to the collapsed memref,
    // flattening the written vector via shape_cast first.
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    Value flatVector =
        rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
    vector::TransferWriteOp flatWrite =
        rewriter.create<vector::TransferWriteOp>(
            loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
    // The flattened write is in-bounds by construction (out-of-bounds dims
    // were rejected above).
    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));

    // 4. Replace the old transfer_write with the new one writing the
    // collapsed shape (the old op has no results to replace, so just erase).
    rewriter.eraseOp(transferWriteOp);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};

/// Base class for `vector.extract/vector.extract_element(vector.transfer_read)`
/// to `memref.load` patterns. The `match` method is shared for both
/// `vector.extract` and `vector.extract_element`.
template <class VectorExtractOp>
class RewriteScalarExtractOfTransferReadBase
    : public OpRewritePattern<VectorExtractOp> {
  using Base = OpRewritePattern<VectorExtractOp>;

public:
  /// `allowMultipleUses` controls whether the transfer_read may have users
  /// other than `extractOp` (they must all be extract-like ops, see below).
  RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
                                         PatternBenefit benefit,
                                         bool allowMultipleUses)
      : Base::OpRewritePattern(context, benefit),
        allowMultipleUses(allowMultipleUses) {}

  /// Shared legality check; the actual rewrite is provided by subclasses.
  LogicalResult match(VectorExtractOp extractOp) const override {
    auto xferOp =
        extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
    if (!xferOp)
      return failure();
    // Check that we are extracting a scalar and not a sub-vector.
    if (isa<VectorType>(extractOp.getResult().getType()))
      return failure();
    // If multiple uses are not allowed, check if xfer has a single use.
    if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
      return failure();
    // If multiple uses are allowed, check if all the xfer uses are extract ops.
    if (allowMultipleUses &&
        !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
          return isa<vector::ExtractOp, vector::ExtractElementOp>(
              use.getOwner());
        }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Cannot rewrite if the indices may be out of bounds.
    if (xferOp.hasOutOfBoundsDim())
      return failure();
    return success();
  }

private:
  // Whether the transfer_read may have users besides the matched extract op.
  bool allowMultipleUses;
};

/// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load with multiple scalar loads may negatively affect performance.
class RewriteScalarExtractElementOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractElementOp extractOp,
               PatternRewriter &rewriter) const override {
    // Construct scalar load.
    auto loc = extractOp.getLoc();
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    // For a non-0D extract, fold the dynamic extract position into the
    // innermost transfer index: lastIndex + position.
    if (extractOp.getPosition()) {
      AffineExpr sym0, sym1;
      bindSymbols(extractOp.getContext(), sym0, sym1);
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, loc, sym0 + sym1,
          {newIndices[newIndices.size() - 1], extractOp.getPosition()});
      // Use the folded result directly if dynamic, otherwise materialize it
      // as a constant index.
      if (auto value = dyn_cast<Value>(ofr)) {
        newIndices[newIndices.size() - 1] = value;
      } else {
        newIndices[newIndices.size() - 1] =
            rewriter.create<arith::ConstantIndexOp>(loc,
                                                    *getConstantIntValue(ofr));
      }
    }
    // memref source -> memref.load; otherwise (tensor) -> tensor.extract.
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};

/// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load with multiple scalar loads may negatively affect performance.
class RewriteScalarExtractOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractOp extractOp,
               PatternRewriter &rewriter) const override {
    // Construct scalar load.
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    // Fold each (static) extract position into the corresponding trailing
    // transfer index: newIndex = oldIndex + position.
    for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
      assert(isa<Attribute>(pos) && "Unexpected non-constant index");
      int64_t offset = cast<IntegerAttr>(cast<Attribute>(pos)).getInt();
      // The extract positions map onto the innermost transfer indices.
      int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(),
          rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
      // Use the folded result directly if dynamic, otherwise materialize it
      // as a constant index.
      if (auto value = dyn_cast<Value>(ofr)) {
        newIndices[idx] = value;
      } else {
        newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
            extractOp.getLoc(), *getConstantIntValue(ofr));
      }
    }
    // memref source -> memref.load; otherwise (tensor) -> tensor.extract.
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};

/// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
/// to memref.store.
class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    // Must be a scalar write: every vector dimension has size 1.
    auto vecType = xferOp.getVectorType();
    if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Extract the single scalar from the written vector.
    // NOTE(review): the original comment here claimed "only float and integer
    // element types are supported", but no such check is performed in this
    // pattern — confirm whether one is needed.
    Value scalar;
    if (vecType.getRank() == 0) {
      // vector.extract does not support vector<f32> etc., so use
      // vector.extractelement instead.
      scalar = rewriter.create<vector::ExtractElementOp>(xferOp.getLoc(),
                                                         xferOp.getVector());
    } else {
      // Extract from the all-zeros position of the size-1 vector.
      SmallVector<int64_t> pos(vecType.getRank(), 0);
      scalar = rewriter.create<vector::ExtractOp>(xferOp.getLoc(),
                                                  xferOp.getVector(), pos);
    }
    // Construct a scalar store: memref.store for memref destinations,
    // tensor.insert otherwise.
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::StoreOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    } else {
      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    }
    return success();
  }
};

} // namespace

/// Runs store-to-load forwarding and dead transfer_write elimination on all
/// transfer ops (with memref sources) nested under `rootOp`.
void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
                                     Operation *rootOp) {
  TransferOptimization opt(rewriter, rootOp);
  // Run store to load forwarding first since it can expose more dead store
  // opportunity.
  rootOp->walk([&](vector::TransferReadOp read) {
    if (isa<MemRefType>(read.getShapedType()))
      opt.storeToLoadForwarding(read);
  });
  opt.removeDeadOp();
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (isa<MemRefType>(write.getShapedType()))
      opt.deadStoreOp(write);
  });
  opt.removeDeadOp();
}

/// Populates patterns that rewrite scalar transfer_read/transfer_write ops to
/// scalar load/store ops. `allowMultipleUses` is forwarded to the extract
/// patterns (see RewriteScalarExtractOfTransferReadBase).
void mlir::vector::populateScalarVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit,
    bool allowMultipleUses) {
  patterns.add<RewriteScalarExtractElementOfTransferRead,
               RewriteScalarExtractOfTransferRead>(patterns.getContext(),
                                                   benefit, allowMultipleUses);
  patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
}

/// Populates patterns that drop unit dims from transfer ops, plus the
/// shape_cast folding patterns that clean up after them.
void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
          patterns.getContext(), benefit);
  populateShapeCastFoldingPatterns(patterns);
}

/// Populates patterns that flatten contiguous multi-dim transfer ops into 1-D
/// transfers (gated by `targetVectorBitwidth`), plus cleanup patterns.
void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns, unsigned targetVectorBitwidth,
    PatternBenefit benefit) {
  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
      patterns.getContext(), targetVectorBitwidth, benefit);
  populateShapeCastFoldingPatterns(patterns, benefit);
  populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
}