xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (revision 6aaa8f25b66dc1fef4e465f274ee40b82d632988)
199ef9eebSMatthias Springer //===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
299ef9eebSMatthias Springer //
399ef9eebSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
499ef9eebSMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
599ef9eebSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
699ef9eebSMatthias Springer //
799ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
899ef9eebSMatthias Springer //
999ef9eebSMatthias Springer // This file implements functions concerned with optimizing transfer_read and
1099ef9eebSMatthias Springer // transfer_write ops.
1199ef9eebSMatthias Springer //
1299ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
136a8ba318SRiver Riddle 
142ec98ffbSMatthias Springer #include "mlir/Dialect/Affine/IR/AffineOps.h"
15abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
1699ef9eebSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
1790d2f8c6SBenjamin Maxwell #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
182ec98ffbSMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h"
19847048f4SDiego Caballero #include "mlir/Dialect/Utils/IndexingUtils.h"
2099ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h"
212bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
2299ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
2399ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
2499ef9eebSMatthias Springer #include "mlir/IR/Dominance.h"
25fc367dfaSMahesh Ravishankar #include "mlir/Interfaces/SideEffectInterfaces.h"
2699ef9eebSMatthias Springer #include "llvm/ADT/STLExtras.h"
2799ef9eebSMatthias Springer #include "llvm/ADT/StringRef.h"
2899ef9eebSMatthias Springer #include "llvm/Support/Debug.h"
2999ef9eebSMatthias Springer 
3099ef9eebSMatthias Springer #define DEBUG_TYPE "vector-transfer-opt"
3199ef9eebSMatthias Springer 
3299ef9eebSMatthias Springer #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
3399ef9eebSMatthias Springer 
3499ef9eebSMatthias Springer using namespace mlir;
3599ef9eebSMatthias Springer 
3699ef9eebSMatthias Springer /// Return the ancestor op in the region or nullptr if the region is not
3799ef9eebSMatthias Springer /// an ancestor of the op.
3899ef9eebSMatthias Springer static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
3999ef9eebSMatthias Springer   for (; op != nullptr && op->getParentRegion() != region;
4099ef9eebSMatthias Springer        op = op->getParentOp())
4199ef9eebSMatthias Springer     ;
4299ef9eebSMatthias Springer   return op;
4399ef9eebSMatthias Springer }
4499ef9eebSMatthias Springer 
namespace {

/// Driver for store-to-load forwarding and dead-store elimination on
/// vector.transfer_read / vector.transfer_write ops. Ops selected for removal
/// are queued in `opToErase` and only erased in `removeDeadOp()`.
class TransferOptimization {
public:
  TransferOptimization(RewriterBase &rewriter, Operation *op)
      : rewriter(rewriter), dominators(op), postDominators(op) {}
  /// Queue `write` for erasure if it is proven to be fully overwritten before
  /// any aliasing read can observe it.
  void deadStoreOp(vector::TransferWriteOp);
  /// Replace the uses of `read` with the vector of a dominating
  /// transfer_write when that is provably the value read.
  void storeToLoadForwarding(vector::TransferReadOp);
  /// Erase all ops queued by the two transforms above and reset the queue.
  void removeDeadOp() {
    for (Operation *op : opToErase)
      rewriter.eraseOp(op);
    opToErase.clear();
  }

private:
  RewriterBase &rewriter;
  /// Return true if there is a path from `start` to `dest`; both ops must be
  /// in the same region.
  bool isReachable(Operation *start, Operation *dest);
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  // Ops queued for deletion; erased in removeDeadOp().
  std::vector<Operation *> opToErase;
};

} // namespace
6899ef9eebSMatthias Springer /// Return true if there is a path from start operation to dest operation,
6999ef9eebSMatthias Springer /// otherwise return false. The operations have to be in the same region.
7099ef9eebSMatthias Springer bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
7199ef9eebSMatthias Springer   assert(start->getParentRegion() == dest->getParentRegion() &&
7299ef9eebSMatthias Springer          "This function only works for ops i the same region");
7399ef9eebSMatthias Springer   // Simple case where the start op dominate the destination.
7499ef9eebSMatthias Springer   if (dominators.dominates(start, dest))
7599ef9eebSMatthias Springer     return true;
76804d3c4cSMatthias Springer   return start->getBlock()->isReachable(dest->getBlock());
7799ef9eebSMatthias Springer }
7899ef9eebSMatthias Springer 
/// For transfer_write to overwrite fully another transfer_write must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the other transfer_write operation.
/// If several candidates are available, one must be post-dominated by all the
/// others since they are all post-dominating the same transfer_write. We only
/// consider the transfer_write post-dominated by all the other candidates as
/// this will be the first transfer_write executed after the potentially dead
/// transfer_write.
/// If we found such an overwriting transfer_write we know that the original
/// transfer_write is dead if all reads that can be reached from the potentially
/// dead transfer_write are dominated by the overwriting transfer_write.
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                    << "\n");
  // Accesses that may observe the written memory before it is overwritten.
  llvm::SmallVector<Operation *, 8> blockingAccesses;
  Operation *firstOverwriteCandidate = nullptr;
  // Strip view-like ops so all aliases of the underlying memref are visited.
  Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getSource()));
  // Worklist over the transitive users of the source memref.
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed skip.
    if (!processed.insert(user).second)
      continue;
    // View-like ops alias the source; follow their users as well.
    if (isa<ViewLikeOpInterface>(user)) {
      users.append(user->getUsers().begin(), user->getUsers().end());
      continue;
    }
    // Ops without memory effects cannot observe the stored value.
    if (isMemoryEffectFree(user))
      continue;
    if (user == write.getOperation())
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check candidate that can override the store.
      if (memref::isSameViewOrTrivialAlias(
              cast<MemrefValue>(nextWrite.getSource()),
              cast<MemrefValue>(write.getSource())) &&
          checkSameValueWAW(nextWrite, write) &&
          postDominators.postDominates(nextWrite, write)) {
        // Keep the candidate post-dominated by all the others, i.e. the first
        // overwrite executed after `write`.
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
          firstOverwriteCandidate = nextWrite;
        else
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
        continue;
      }
    }
    if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
      // Don't need to consider disjoint accesses.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(transferOp.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
    }
    blockingAccesses.push_back(user);
  }
  if (firstOverwriteCandidate == nullptr)
    return;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");

  // `write` is dead only if every blocking access reachable from it is
  // dominated by the overwriting transfer_write.
  for (Operation *access : blockingAccesses) {
    Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
    // TODO: if the access and write have the same ancestor we could recurse in
    // the region to know if the access is reachable with more precision.
    if (accessAncestor == nullptr ||
        !isReachable(writeAncestor, accessAncestor))
      continue;
    if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
                        << *accessAncestor << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
  opToErase.push_back(write.getOperation());
}
16299ef9eebSMatthias Springer 
/// A transfer_write candidate to storeToLoad forwarding must:
/// 1. Access the same memref with the same indices and vector type as the
/// transfer_read.
/// 2. Dominate the transfer_read operation.
/// If several candidates are available, one must be dominated by all the others
/// since they are all dominating the same transfer_read. We only consider the
/// transfer_write dominated by all the other candidates as this will be the
/// last transfer_write executed before the transfer_read.
/// If we found such a candidate we can do the forwarding if all the other
/// potentially aliasing ops that may reach the transfer_read are post-dominated
/// by the transfer_write.
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  // Out-of-bounds reads mix written data with padding; don't forward.
  if (read.hasOutOfBoundsDim())
    return;
  LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
                    << "\n");
  // Writes that may define the memory observed by `read` other than the
  // forwarding candidate.
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  // Strip view-like ops so all aliases of the underlying memref are visited.
  Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getSource()));
  // Worklist over the transitive users of the source memref.
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed skip.
    if (!processed.insert(user).second)
      continue;
    // View-like ops alias the source; follow their users as well.
    if (isa<ViewLikeOpInterface>(user)) {
      users.append(user->getUsers().begin(), user->getUsers().end());
      continue;
    }
    // Effect-free ops and other reads cannot clobber the value being read.
    if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
      continue;
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      // If there is a write, but we can prove that it is disjoint we can ignore
      // the write.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(read.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
      if (memref::isSameViewOrTrivialAlias(
              cast<MemrefValue>(read.getSource()),
              cast<MemrefValue>(write.getSource())) &&
          dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
        // Keep the candidate dominated by all the others, i.e. the last write
        // executed before `read`.
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
          lastwrite = write;
        else
          assert(dominators.dominates(write, lastwrite));
        continue;
      }
    }
    blockingWrites.push_back(user);
  }

  if (lastwrite == nullptr)
    return;

  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");

  // Forwarding is legal only if every blocking write that can reach the read
  // is post-dominated by the forwarded write (i.e. executes before it).
  for (Operation *write : blockingWrites) {
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    // TODO: if the store and read have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!postDominators.postDominates(lastwrite, write)) {
      LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
                        << *write << "\n");
      return;
    }
  }

  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
                    << " to: " << *read.getOperation() << "\n");
  read.replaceAllUsesWith(lastwrite.getVector());
  opToErase.push_back(read.getOperation());
}
24499ef9eebSMatthias Springer 
245bf897d5dSCullen Rhodes /// Converts OpFoldResults to int64_t shape without unit dims.
246bf897d5dSCullen Rhodes static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
247bf897d5dSCullen Rhodes   SmallVector<int64_t> reducedShape;
248bf897d5dSCullen Rhodes   for (const auto size : mixedSizes) {
249bf897d5dSCullen Rhodes     if (llvm::dyn_cast_if_present<Value>(size)) {
250bf897d5dSCullen Rhodes       reducedShape.push_back(ShapedType::kDynamic);
251bf897d5dSCullen Rhodes       continue;
252bf897d5dSCullen Rhodes     }
253bf897d5dSCullen Rhodes 
2546e41483bSKazu Hirata     auto value = cast<IntegerAttr>(cast<Attribute>(size)).getValue();
255bf897d5dSCullen Rhodes     if (value == 1)
256bf897d5dSCullen Rhodes       continue;
257bf897d5dSCullen Rhodes     reducedShape.push_back(value.getSExtValue());
258bf897d5dSCullen Rhodes   }
259bf897d5dSCullen Rhodes   return reducedShape;
260bf897d5dSCullen Rhodes }
261bf897d5dSCullen Rhodes 
26299ef9eebSMatthias Springer /// Drops unit dimensions from the input MemRefType.
263bf897d5dSCullen Rhodes static MemRefType dropUnitDims(MemRefType inputType,
264bf897d5dSCullen Rhodes                                ArrayRef<OpFoldResult> offsets,
265bf897d5dSCullen Rhodes                                ArrayRef<OpFoldResult> sizes,
266bf897d5dSCullen Rhodes                                ArrayRef<OpFoldResult> strides) {
267bf897d5dSCullen Rhodes   auto targetShape = getReducedShape(sizes);
26899ef9eebSMatthias Springer   Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
2696c3c5f80SMatthias Springer       targetShape, inputType, offsets, sizes, strides);
270*6aaa8f25SMatthias Springer   return cast<MemRefType>(rankReducedType).canonicalizeStridedLayout();
27199ef9eebSMatthias Springer }
27299ef9eebSMatthias Springer 
27399ef9eebSMatthias Springer /// Creates a rank-reducing memref.subview op that drops unit dims from its
27499ef9eebSMatthias Springer /// input. Or just returns the input if it was already without unit dims.
27599ef9eebSMatthias Springer static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
27699ef9eebSMatthias Springer                                                  mlir::Location loc,
27799ef9eebSMatthias Springer                                                  Value input) {
2785550c821STres Popp   MemRefType inputType = cast<MemRefType>(input.getType());
279bf897d5dSCullen Rhodes   SmallVector<OpFoldResult> offsets(inputType.getRank(),
280bf897d5dSCullen Rhodes                                     rewriter.getIndexAttr(0));
281bf897d5dSCullen Rhodes   SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
282bf897d5dSCullen Rhodes   SmallVector<OpFoldResult> strides(inputType.getRank(),
283bf897d5dSCullen Rhodes                                     rewriter.getIndexAttr(1));
284bf897d5dSCullen Rhodes   MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);
285bf897d5dSCullen Rhodes 
286*6aaa8f25SMatthias Springer   if (resultType.canonicalizeStridedLayout() ==
287*6aaa8f25SMatthias Springer       inputType.canonicalizeStridedLayout())
28899ef9eebSMatthias Springer     return input;
289bf897d5dSCullen Rhodes   return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets,
290bf897d5dSCullen Rhodes                                             sizes, strides);
29199ef9eebSMatthias Springer }
29299ef9eebSMatthias Springer 
29399ef9eebSMatthias Springer /// Returns the number of dims that aren't unit dims.
29499ef9eebSMatthias Springer static int getReducedRank(ArrayRef<int64_t> shape) {
29599ef9eebSMatthias Springer   return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
29699ef9eebSMatthias Springer }
29799ef9eebSMatthias Springer 
298bf897d5dSCullen Rhodes /// Trims non-scalable one dimensions from `oldType` and returns the result
299bf897d5dSCullen Rhodes /// type.
300bf897d5dSCullen Rhodes static VectorType trimNonScalableUnitDims(VectorType oldType) {
301bf897d5dSCullen Rhodes   SmallVector<int64_t> newShape;
302bf897d5dSCullen Rhodes   SmallVector<bool> newScalableDims;
303bf897d5dSCullen Rhodes   for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
304bf897d5dSCullen Rhodes     if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
305bf897d5dSCullen Rhodes       continue;
306bf897d5dSCullen Rhodes     newShape.push_back(dimSize);
307bf897d5dSCullen Rhodes     newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
308bf897d5dSCullen Rhodes   }
309bf897d5dSCullen Rhodes   return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
310bf897d5dSCullen Rhodes }
311bf897d5dSCullen Rhodes 
312bf897d5dSCullen Rhodes // Rewrites vector.create_mask 'op' to drop non-scalable one dimensions.
313bf897d5dSCullen Rhodes static FailureOr<Value>
314bf897d5dSCullen Rhodes createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
315bf897d5dSCullen Rhodes                                   vector::CreateMaskOp op) {
316bf897d5dSCullen Rhodes   auto type = op.getType();
317fdf84cbfSQuinn Dawkins   VectorType reducedType = trimNonScalableUnitDims(type);
318bf897d5dSCullen Rhodes   if (reducedType.getRank() == type.getRank())
319bf897d5dSCullen Rhodes     return failure();
320bf897d5dSCullen Rhodes 
321bf897d5dSCullen Rhodes   SmallVector<Value> reducedOperands;
322bf897d5dSCullen Rhodes   for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
323bf897d5dSCullen Rhodes            type.getShape(), type.getScalableDims(), op.getOperands())) {
324bf897d5dSCullen Rhodes     if (dim == 1 && !dimIsScalable) {
325bf897d5dSCullen Rhodes       // If the mask for the unit dim is not a constant of 1, do nothing.
326bf897d5dSCullen Rhodes       auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
327bf897d5dSCullen Rhodes       if (!constant || (constant.value() != 1))
328bf897d5dSCullen Rhodes         return failure();
329bf897d5dSCullen Rhodes       continue;
330bf897d5dSCullen Rhodes     }
331bf897d5dSCullen Rhodes     reducedOperands.push_back(operand);
332bf897d5dSCullen Rhodes   }
333bf897d5dSCullen Rhodes   return rewriter
334bf897d5dSCullen Rhodes       .create<vector::CreateMaskOp>(loc, reducedType, reducedOperands)
335bf897d5dSCullen Rhodes       .getResult();
336834fcfedSDiego Caballero }
337834fcfedSDiego Caballero 
namespace {

/// Rewrites `vector.transfer_read` ops where the source has unit dims, by
/// inserting a memref.subview dropping those unit dims. The vector shapes are
/// also reduced accordingly.
class TransferReadDropUnitDimsPattern
    : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp,
                            vector::MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferReadOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor types.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the source shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
    // out.
    if (reducedRank == 0 && maskingOp)
      return failure();
    // Check if the reduced vector shape matches the reduced source shape.
    // Otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    // All indices must be statically zero, since the rewrite reads the
    // reduced subview from offset zero.
    if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    // If the read carries an inline mask, it must come from a
    // vector.create_mask whose unit dims can be dropped in lockstep.
    Value maskOp = transferReadOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
                            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }

    // Build the rank-reduced read: subview source, zero indices, minor
    // identity map, all dims in-bounds.
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    Operation *newTransferReadOp = rewriter.create<vector::TransferReadOp>(
        loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
        transferReadOp.getPadding(), maskOp,
        rewriter.getBoolArrayAttr(inBounds));

    if (maskingOp) {
      // Shape-cast the enclosing vector.mask's mask to the reduced shape and
      // re-wrap the new read in a vector.mask op.
      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
          maskingOp.getMask());
      newTransferReadOp = mlir::vector::maskOperation(
          rewriter, newTransferReadOp, shapeCastMask);
    }

    // Restore the original vector type for the read's users.
    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, vectorType, newTransferReadOp->getResults()[0]);

    return shapeCast;
  }
};
42199ef9eebSMatthias Springer 
422834fcfedSDiego Caballero /// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
423834fcfedSDiego Caballero /// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
424834fcfedSDiego Caballero /// vector shapes are also reduced accordingly.
class TransferWriteDropUnitDimsPattern
    : public vector::MaskableOpRewritePattern<vector::TransferWriteOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp,
                            vector::MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor type.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the destination shape can be further reduced, i.e. whether it
    // contains unit dims that a rank-reducing subview could drop.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
    // out.
    if (reducedRank == 0 && maskingOp)
      return failure();
    // Check if the reduced vector shape matches the reduced destination shape.
    // Otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    // All write indices must be statically zero: the rewrite below replaces
    // them with fresh zero constants on the rank-reduced subview.
    if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    // If the op carries an explicit mask operand, rank-reduce it as well.
    // Only masks produced by vector.create_mask can be rewritten here.
    Value maskOp = transferWriteOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferWriteOp,
            "unsupported mask op, only 'vector.create_mask' is "
            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }
    // Build the rank-reduced write: subview without unit dims, zero indices,
    // identity map, all-in-bounds, and a shape-cast of the source vector.
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, reducedVectorType, vector);
    Operation *newXferWrite = rewriter.create<vector::TransferWriteOp>(
        loc, Type(), shapeCastSrc, reducedShapeSource, zeros, identityMap,
        maskOp, rewriter.getBoolArrayAttr(inBounds));

    // If the original write was wrapped in a vector.mask, re-wrap the new
    // write with a shape-cast of the masking vector.
    if (maskingOp) {
      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
          maskingOp.getMask());
      newXferWrite =
          mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask);
    }

    if (transferWriteOp.hasPureTensorSemantics())
      return newXferWrite->getResults()[0];

    // With Memref semantics, there's no return value. Use empty value to signal
    // success.
    return Value();
  }
};
50699ef9eebSMatthias Springer 
507834fcfedSDiego Caballero } // namespace
508834fcfedSDiego Caballero 
509f4ac9509SBenoit Jacob /// Creates a memref.collapse_shape collapsing all inner dimensions of the
510f4ac9509SBenoit Jacob /// input starting at `firstDimToCollapse`.
511f4ac9509SBenoit Jacob static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
512f4ac9509SBenoit Jacob                                Value input, int64_t firstDimToCollapse) {
5135550c821STres Popp   ShapedType inputType = cast<ShapedType>(input.getType());
514f4ac9509SBenoit Jacob   if (inputType.getRank() == 1)
515f4ac9509SBenoit Jacob     return input;
516f4ac9509SBenoit Jacob   SmallVector<ReassociationIndices> reassociation;
517f4ac9509SBenoit Jacob   for (int64_t i = 0; i < firstDimToCollapse; ++i)
518f4ac9509SBenoit Jacob     reassociation.push_back(ReassociationIndices{i});
519f4ac9509SBenoit Jacob   ReassociationIndices collapsedIndices;
520f4ac9509SBenoit Jacob   for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
521f4ac9509SBenoit Jacob     collapsedIndices.push_back(i);
522f4ac9509SBenoit Jacob   reassociation.push_back(collapsedIndices);
523f4ac9509SBenoit Jacob   return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
524f4ac9509SBenoit Jacob }
525f4ac9509SBenoit Jacob 
52653ddc874SHan-Chung Wang /// Returns the new indices that collapses the inner dimensions starting from
52753ddc874SHan-Chung Wang /// the `firstDimToCollapse` dimension.
static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
                                              Location loc,
                                              ArrayRef<int64_t> shape,
                                              ValueRange indices,
                                              int64_t firstDimToCollapse) {
  assert(firstDimToCollapse < static_cast<int64_t>(indices.size()));

  // If all the collapsed indices are zero then no extra logic is needed.
  // Otherwise, a new offset/index has to be computed.
  SmallVector<Value> indicesAfterCollapsing(
      indices.begin(), indices.begin() + firstDimToCollapse);
  SmallVector<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
                                       indices.end());
  if (llvm::all_of(indicesToCollapse, isZeroIndex)) {
    // All collapsed indices are zero, so any one of them (the first) can
    // serve as the single trailing index into the collapsed memref.
    indicesAfterCollapsing.push_back(indicesToCollapse[0]);
    return indicesAfterCollapsing;
  }

  // Compute the remaining trailing index/offset required for reading from
  // the collapsed memref:
  //
  //    offset = 0
  //    for (i = firstDimToCollapse; i < outputRank; ++i)
  //      offset += sourceType.getDimSize(i) * transferReadOp.indices[i]
  //
  // For this example:
  //   %2 = vector.transfer_read/write %arg4[%c0, %arg0, %c0] (...) :
  //      memref<1x43x2xi32>, vector<1x2xi32>
  // which would be collapsed to:
  //   %1 = vector.transfer_read/write %collapse_shape[%c0, %offset] (...) :
  //      memref<1x86xi32>, vector<2xi32>
  // one would get the following offset:
  //    %offset = %arg0 * 43
  OpFoldResult collapsedOffset =
      rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();

  // Row-major strides of the collapsed dims (suffix products of the shape).
  auto collapsedStrides = computeSuffixProduct(
      ArrayRef<int64_t>(shape.begin() + firstDimToCollapse, shape.end()));

  // Compute the collapsed offset as an affine expression and fold it.
  auto &&[collapsedExpr, collapsedVals] =
      computeLinearIndex(collapsedOffset, collapsedStrides, indicesToCollapse);
  collapsedOffset = affine::makeComposedFoldedAffineApply(
      rewriter, loc, collapsedExpr, collapsedVals);

  // The fold result is either an SSA value usable directly, or a constant
  // attribute that must be materialized as an index constant.
  if (auto value = dyn_cast<Value>(collapsedOffset)) {
    indicesAfterCollapsing.push_back(value);
  } else {
    indicesAfterCollapsing.push_back(rewriter.create<arith::ConstantIndexOp>(
        loc, *getConstantIntValue(collapsedOffset)));
  }

  return indicesAfterCollapsing;
}
58299ef9eebSMatthias Springer 
583834fcfedSDiego Caballero namespace {
584834fcfedSDiego Caballero 
58599ef9eebSMatthias Springer /// Rewrites contiguous row-major vector.transfer_read ops by inserting
58699ef9eebSMatthias Springer /// memref.collapse_shape on the source so that the resulting
58799ef9eebSMatthias Springer /// vector.transfer_read has a 1D source. Requires the source shape to be
58899ef9eebSMatthias Springer /// already reduced i.e. without unit dims.
58934de7fd4SAndrzej Warzyński ///
59071441ed1SDiego Caballero /// If `targetVectorBitwidth` is provided, the flattening will only happen if
59171441ed1SDiego Caballero /// the trailing dimension of the vector read is smaller than the provided
59271441ed1SDiego Caballero /// bitwidth.
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
public:
  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
                                               unsigned vectorBitwidth,
                                               PatternBenefit benefit)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    auto source = transferReadOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check pre-conditions
    // The contiguity check below only applies to memrefs; bail out on tensor
    // sources.
    if (!sourceType)
      return failure();
    // If this is already 0D/1D, there's nothing to do.
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    // Only flatten when the trailing vector dim is narrower than the target
    // bitwidth, i.e. when flattening can enable wider loads.
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferReadOp.getMask())
      return failure();

    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();

    // 1. Collapse the source memref
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // 2. Generate input args for a new vector.transfer_read that will read
    // from the collapsed memref.
    // 2.1. New dim exprs + affine map (minor identity over the single
    // collapsed trailing dim).
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());

    // 2.2 New indices
    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
                            transferReadOp.getIndices(), firstDimToCollapse);

    // 3. Create new vector.transfer_read that reads from the collapsed memref
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
        loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));

    // 4. Replace the old transfer_read with the new one reading from the
    // collapsed shape, shape-casting the flat result back to the original
    // vector type.
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, cast<VectorType>(vector.getType()), flatRead);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};
67599ef9eebSMatthias Springer 
67699ef9eebSMatthias Springer /// Rewrites contiguous row-major vector.transfer_write ops by inserting
67799ef9eebSMatthias Springer /// memref.collapse_shape on the source so that the resulting
67899ef9eebSMatthias Springer /// vector.transfer_write has a 1D source. Requires the source shape to be
67999ef9eebSMatthias Springer /// already reduced i.e. without unit dims.
68034de7fd4SAndrzej Warzyński ///
68134de7fd4SAndrzej Warzyński /// If `targetVectorBitwidth` is provided, the flattening will only happen if
68234de7fd4SAndrzej Warzyński /// the trailing dimension of the vector read is smaller than the provided
68334de7fd4SAndrzej Warzyński /// bitwidth.
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
public:
  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
                                                unsigned vectorBitwidth,
                                                PatternBenefit benefit)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check pre-conditions
    // The contiguity check below only applies to memrefs; bail out on tensor
    // sources.
    if (!sourceType)
      return failure();
    // If this is already 0D/1D, there's nothing to do.
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    // Only flatten when the trailing vector dim is narrower than the target
    // bitwidth, i.e. when flattening can enable wider stores.
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();

    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();

    // 1. Collapse the source memref
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // 2. Generate input args for a new vector.transfer_write that will write
    // to the collapsed memref.
    // 2.1. New dim exprs + affine map (minor identity over the single
    // collapsed trailing dim).
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());

    // 2.2 New indices
    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
                            transferWriteOp.getIndices(), firstDimToCollapse);

    // 3. Create new vector.transfer_write that writes to the collapsed memref
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    Value flatVector =
        rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
    vector::TransferWriteOp flatWrite =
        rewriter.create<vector::TransferWriteOp>(
            loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));

    // 4. Replace the old transfer_write with the new one writing the
    // collapsed shape
    rewriter.eraseOp(transferWriteOp);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};
76999ef9eebSMatthias Springer 
77014726cd6SDiego Caballero /// Base class for `vector.extract/vector.extract_element(vector.transfer_read)`
77114726cd6SDiego Caballero /// to `memref.load` patterns. The `match` method is shared for both
77214726cd6SDiego Caballero /// `vector.extract` and `vector.extract_element`.
77314726cd6SDiego Caballero template <class VectorExtractOp>
77414726cd6SDiego Caballero class RewriteScalarExtractOfTransferReadBase
77514726cd6SDiego Caballero     : public OpRewritePattern<VectorExtractOp> {
77614726cd6SDiego Caballero   using Base = OpRewritePattern<VectorExtractOp>;
7772ec98ffbSMatthias Springer 
77814726cd6SDiego Caballero public:
77914726cd6SDiego Caballero   RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
78014726cd6SDiego Caballero                                          PatternBenefit benefit,
78114726cd6SDiego Caballero                                          bool allowMultipleUses)
78214726cd6SDiego Caballero       : Base::OpRewritePattern(context, benefit),
78314726cd6SDiego Caballero         allowMultipleUses(allowMultipleUses) {}
78414726cd6SDiego Caballero 
78514726cd6SDiego Caballero   LogicalResult match(VectorExtractOp extractOp) const override {
78614726cd6SDiego Caballero     auto xferOp =
78714726cd6SDiego Caballero         extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
7882ec98ffbSMatthias Springer     if (!xferOp)
7892ec98ffbSMatthias Springer       return failure();
790d3e1398bSDiego Caballero     // Check that we are extracting a scalar and not a sub-vector.
791d3e1398bSDiego Caballero     if (isa<VectorType>(extractOp.getResult().getType()))
792d3e1398bSDiego Caballero       return failure();
79314726cd6SDiego Caballero     // If multiple uses are not allowed, check if xfer has a single use.
79414726cd6SDiego Caballero     if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
79514726cd6SDiego Caballero       return failure();
79614726cd6SDiego Caballero     // If multiple uses are allowed, check if all the xfer uses are extract ops.
79714726cd6SDiego Caballero     if (allowMultipleUses &&
79814726cd6SDiego Caballero         !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
79914726cd6SDiego Caballero           return isa<vector::ExtractOp, vector::ExtractElementOp>(
80014726cd6SDiego Caballero               use.getOwner());
80114726cd6SDiego Caballero         }))
8022ec98ffbSMatthias Springer       return failure();
8032ec98ffbSMatthias Springer     // Mask not supported.
8042ec98ffbSMatthias Springer     if (xferOp.getMask())
8052ec98ffbSMatthias Springer       return failure();
8062ec98ffbSMatthias Springer     // Map not supported.
8072ec98ffbSMatthias Springer     if (!xferOp.getPermutationMap().isMinorIdentity())
8082ec98ffbSMatthias Springer       return failure();
80914726cd6SDiego Caballero     // Cannot rewrite if the indices may be out of bounds.
81014726cd6SDiego Caballero     if (xferOp.hasOutOfBoundsDim())
8112ec98ffbSMatthias Springer       return failure();
81214726cd6SDiego Caballero     return success();
81314726cd6SDiego Caballero   }
81414726cd6SDiego Caballero 
81514726cd6SDiego Caballero private:
81614726cd6SDiego Caballero   bool allowMultipleUses;
81714726cd6SDiego Caballero };
81814726cd6SDiego Caballero 
81914726cd6SDiego Caballero /// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
82014726cd6SDiego Caballero ///
82114726cd6SDiego Caballero /// All the users of the transfer op must be either `vector.extractelement` or
82214726cd6SDiego Caballero /// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
82314726cd6SDiego Caballero /// transfer ops with any number of users. Otherwise, rewrite only if the
82414726cd6SDiego Caballero /// extract op is the single user of the transfer op. Rewriting a single
82514726cd6SDiego Caballero /// vector load with multiple scalar loads may negatively affect performance.
class RewriteScalarExtractElementOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractElementOp extractOp,
               PatternRewriter &rewriter) const override {
    // Construct scalar load.
    auto loc = extractOp.getLoc();
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    // A 0-d extract_element has no position operand; only fold the extract
    // position into the last index when one is present.
    if (extractOp.getPosition()) {
      // New last index = old last index + extract position, folded via an
      // affine apply (sym0 + sym1).
      AffineExpr sym0, sym1;
      bindSymbols(extractOp.getContext(), sym0, sym1);
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, loc, sym0 + sym1,
          {newIndices[newIndices.size() - 1], extractOp.getPosition()});
      // The fold result is either an SSA value usable directly, or a constant
      // that must be materialized as an index constant.
      if (auto value = dyn_cast<Value>(ofr)) {
        newIndices[newIndices.size() - 1] = value;
      } else {
        newIndices[newIndices.size() - 1] =
            rewriter.create<arith::ConstantIndexOp>(loc,
                                                    *getConstantIntValue(ofr));
      }
    }
    // Memref source -> memref.load; tensor source -> tensor.extract.
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};
8612ec98ffbSMatthias Springer 
86214726cd6SDiego Caballero /// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
86314726cd6SDiego Caballero /// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
86441e731f2SMatthias Springer ///
86514726cd6SDiego Caballero /// All the users of the transfer op must be either `vector.extractelement` or
86614726cd6SDiego Caballero /// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
86714726cd6SDiego Caballero /// transfer ops with any number of users. Otherwise, rewrite only if the
86814726cd6SDiego Caballero /// extract op is the single user of the transfer op. Rewriting a single
86914726cd6SDiego Caballero /// vector load with multiple scalar loads may negatively affect performance.
class RewriteScalarExtractOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractOp extractOp,
               PatternRewriter &rewriter) const override {
    // Construct scalar load.
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    // Fold each (static) extract position into the corresponding trailing
    // transfer index.
    for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
      assert(isa<Attribute>(pos) && "Unexpected non-constant index");
      int64_t offset = cast<IntegerAttr>(cast<Attribute>(pos)).getInt();
      // Extract positions align with the minor dims of the transfer indices.
      int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
      // New index = old index + constant offset, folded via an affine apply.
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(),
          rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
      // The fold result is either an SSA value usable directly, or a constant
      // that must be materialized as an index constant.
      if (auto value = dyn_cast<Value>(ofr)) {
        newIndices[idx] = value;
      } else {
        newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
            extractOp.getLoc(), *getConstantIntValue(ofr));
      }
    }
    // Memref source -> memref.load; tensor source -> tensor.extract.
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};
90441e731f2SMatthias Springer 
90541e731f2SMatthias Springer /// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
90641e731f2SMatthias Springer /// to memref.store.
90741e731f2SMatthias Springer class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
90841e731f2SMatthias Springer   using OpRewritePattern::OpRewritePattern;
90941e731f2SMatthias Springer 
91041e731f2SMatthias Springer   LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
91141e731f2SMatthias Springer                                 PatternRewriter &rewriter) const override {
91241e731f2SMatthias Springer     // Must be a scalar write.
91341e731f2SMatthias Springer     auto vecType = xferOp.getVectorType();
91441e731f2SMatthias Springer     if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
91541e731f2SMatthias Springer       return failure();
91641e731f2SMatthias Springer     // Mask not supported.
91741e731f2SMatthias Springer     if (xferOp.getMask())
91841e731f2SMatthias Springer       return failure();
91941e731f2SMatthias Springer     // Map not supported.
92041e731f2SMatthias Springer     if (!xferOp.getPermutationMap().isMinorIdentity())
92141e731f2SMatthias Springer       return failure();
92241e731f2SMatthias Springer     // Only float and integer element types are supported.
92341e731f2SMatthias Springer     Value scalar;
92441e731f2SMatthias Springer     if (vecType.getRank() == 0) {
92541e731f2SMatthias Springer       // vector.extract does not support vector<f32> etc., so use
92641e731f2SMatthias Springer       // vector.extractelement instead.
92741e731f2SMatthias Springer       scalar = rewriter.create<vector::ExtractElementOp>(xferOp.getLoc(),
92841e731f2SMatthias Springer                                                          xferOp.getVector());
92941e731f2SMatthias Springer     } else {
93041e731f2SMatthias Springer       SmallVector<int64_t> pos(vecType.getRank(), 0);
93141e731f2SMatthias Springer       scalar = rewriter.create<vector::ExtractOp>(xferOp.getLoc(),
93241e731f2SMatthias Springer                                                   xferOp.getVector(), pos);
93341e731f2SMatthias Springer     }
9342ec98ffbSMatthias Springer     // Construct a scalar store.
9355550c821STres Popp     if (isa<MemRefType>(xferOp.getSource().getType())) {
9362ec98ffbSMatthias Springer       rewriter.replaceOpWithNewOp<memref::StoreOp>(
93741e731f2SMatthias Springer           xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
9382ec98ffbSMatthias Springer     } else {
9392ec98ffbSMatthias Springer       rewriter.replaceOpWithNewOp<tensor::InsertOp>(
94041e731f2SMatthias Springer           xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
9412ec98ffbSMatthias Springer     }
9422ec98ffbSMatthias Springer     return success();
9432ec98ffbSMatthias Springer   }
9442ec98ffbSMatthias Springer };
945834fcfedSDiego Caballero 
94699ef9eebSMatthias Springer } // namespace
94799ef9eebSMatthias Springer 
948553cebdeSNicolas Vasilache void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
949553cebdeSNicolas Vasilache                                      Operation *rootOp) {
950553cebdeSNicolas Vasilache   TransferOptimization opt(rewriter, rootOp);
95199ef9eebSMatthias Springer   // Run store to load forwarding first since it can expose more dead store
95299ef9eebSMatthias Springer   // opportunity.
953171850c5SRiver Riddle   rootOp->walk([&](vector::TransferReadOp read) {
9545550c821STres Popp     if (isa<MemRefType>(read.getShapedType()))
95599ef9eebSMatthias Springer       opt.storeToLoadForwarding(read);
95699ef9eebSMatthias Springer   });
95799ef9eebSMatthias Springer   opt.removeDeadOp();
958171850c5SRiver Riddle   rootOp->walk([&](vector::TransferWriteOp write) {
9595550c821STres Popp     if (isa<MemRefType>(write.getShapedType()))
96099ef9eebSMatthias Springer       opt.deadStoreOp(write);
96199ef9eebSMatthias Springer   });
96299ef9eebSMatthias Springer   opt.removeDeadOp();
96399ef9eebSMatthias Springer }
96499ef9eebSMatthias Springer 
9652ec98ffbSMatthias Springer void mlir::vector::populateScalarVectorTransferLoweringPatterns(
96614726cd6SDiego Caballero     RewritePatternSet &patterns, PatternBenefit benefit,
96714726cd6SDiego Caballero     bool allowMultipleUses) {
96841e731f2SMatthias Springer   patterns.add<RewriteScalarExtractElementOfTransferRead,
96914726cd6SDiego Caballero                RewriteScalarExtractOfTransferRead>(patterns.getContext(),
97014726cd6SDiego Caballero                                                    benefit, allowMultipleUses);
97114726cd6SDiego Caballero   patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
9722ec98ffbSMatthias Springer }
9732ec98ffbSMatthias Springer 
97499ef9eebSMatthias Springer void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
97527cc31b6SNicolas Vasilache     RewritePatternSet &patterns, PatternBenefit benefit) {
97699ef9eebSMatthias Springer   patterns
97799ef9eebSMatthias Springer       .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
97827cc31b6SNicolas Vasilache           patterns.getContext(), benefit);
97999ef9eebSMatthias Springer   populateShapeCastFoldingPatterns(patterns);
98099ef9eebSMatthias Springer }
98199ef9eebSMatthias Springer 
98299ef9eebSMatthias Springer void mlir::vector::populateFlattenVectorTransferPatterns(
98371441ed1SDiego Caballero     RewritePatternSet &patterns, unsigned targetVectorBitwidth,
98471441ed1SDiego Caballero     PatternBenefit benefit) {
98599ef9eebSMatthias Springer   patterns.add<FlattenContiguousRowMajorTransferReadPattern,
98699ef9eebSMatthias Springer                FlattenContiguousRowMajorTransferWritePattern>(
98771441ed1SDiego Caballero       patterns.getContext(), targetVectorBitwidth, benefit);
98827cc31b6SNicolas Vasilache   populateShapeCastFoldingPatterns(patterns, benefit);
989c02d07fdSAndrzej Warzyński   populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
99099ef9eebSMatthias Springer }
991