xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (revision 6aaa8f25b66dc1fef4e465f274ee40b82d632988)
1 //===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements functions concerned with optimizing transfer_read and
10 // transfer_write ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/MemRef/IR/MemRef.h"
17 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
18 #include "mlir/Dialect/Tensor/IR/Tensor.h"
19 #include "mlir/Dialect/Utils/IndexingUtils.h"
20 #include "mlir/Dialect/Vector/IR/VectorOps.h"
21 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
22 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
23 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
24 #include "mlir/IR/Dominance.h"
25 #include "mlir/Interfaces/SideEffectInterfaces.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/StringRef.h"
28 #include "llvm/Support/Debug.h"
29 
30 #define DEBUG_TYPE "vector-transfer-opt"
31 
32 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
33 
34 using namespace mlir;
35 
36 /// Return the ancestor op in the region or nullptr if the region is not
37 /// an ancestor of the op.
38 static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
39   for (; op != nullptr && op->getParentRegion() != region;
40        op = op->getParentOp())
41     ;
42   return op;
43 }
44 
45 namespace {
46 
47 class TransferOptimization {
48 public:
49   TransferOptimization(RewriterBase &rewriter, Operation *op)
50       : rewriter(rewriter), dominators(op), postDominators(op) {}
51   void deadStoreOp(vector::TransferWriteOp);
52   void storeToLoadForwarding(vector::TransferReadOp);
53   void removeDeadOp() {
54     for (Operation *op : opToErase)
55       rewriter.eraseOp(op);
56     opToErase.clear();
57   }
58 
59 private:
60   RewriterBase &rewriter;
61   bool isReachable(Operation *start, Operation *dest);
62   DominanceInfo dominators;
63   PostDominanceInfo postDominators;
64   std::vector<Operation *> opToErase;
65 };
66 
67 } // namespace
68 /// Return true if there is a path from start operation to dest operation,
69 /// otherwise return false. The operations have to be in the same region.
70 bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
71   assert(start->getParentRegion() == dest->getParentRegion() &&
72          "This function only works for ops in the same region");
73   // Simple case where the start op dominates the destination.
74   if (dominators.dominates(start, dest))
75     return true;
76   return start->getBlock()->isReachable(dest->getBlock());
77 }
78 
79 /// For a transfer_write to fully overwrite another transfer_write, it must:
80 /// 1. Access the same memref with the same indices and vector type.
81 /// 2. Post-dominate the other transfer_write operation.
82 /// If several candidates are available, one must be post-dominated by all the
83 /// others since they are all post-dominating the same transfer_write. We only
84 /// consider the transfer_write post-dominated by all the other candidates as
85 /// this will be the first transfer_write executed after the potentially dead
86 /// transfer_write.
87 /// If we find such an overwriting transfer_write, we know that the original
88 /// transfer_write is dead if all reads that can be reached from the potentially
89 /// dead transfer_write are dominated by the overwriting transfer_write.
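///
/// For example, in a sketch like the following (shapes and value names are
/// illustrative), the first transfer_write is dead: the second one accesses
/// the same memref with the same indices and vector type, post-dominates it,
/// and dominates the only read that could observe the store:
///
///   vector.transfer_write %v0, %buf[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x4xf32>, memref<4x4xf32>
///   vector.transfer_write %v1, %buf[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x4xf32>, memref<4x4xf32>
///   %r = vector.transfer_read %buf[%c0, %c0], %pad {in_bounds = [true, true]}
///       : memref<4x4xf32>, vector<4x4xf32>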
90 void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
91   LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
92                     << "\n");
93   llvm::SmallVector<Operation *, 8> blockingAccesses;
94   Operation *firstOverwriteCandidate = nullptr;
95   Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getSource()));
96   llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
97                                            source.getUsers().end());
98   llvm::SmallDenseSet<Operation *, 32> processed;
99   while (!users.empty()) {
100     Operation *user = users.pop_back_val();
101     // If the user has already been processed, skip it.
102     if (!processed.insert(user).second)
103       continue;
104     if (isa<ViewLikeOpInterface>(user)) {
105       users.append(user->getUsers().begin(), user->getUsers().end());
106       continue;
107     }
108     if (isMemoryEffectFree(user))
109       continue;
110     if (user == write.getOperation())
111       continue;
112     if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
113       // Check for a candidate that can overwrite the store.
114       if (memref::isSameViewOrTrivialAlias(
115               cast<MemrefValue>(nextWrite.getSource()),
116               cast<MemrefValue>(write.getSource())) &&
117           checkSameValueWAW(nextWrite, write) &&
118           postDominators.postDominates(nextWrite, write)) {
119         if (firstOverwriteCandidate == nullptr ||
120             postDominators.postDominates(firstOverwriteCandidate, nextWrite))
121           firstOverwriteCandidate = nextWrite;
122         else
123           assert(
124               postDominators.postDominates(nextWrite, firstOverwriteCandidate));
125         continue;
126       }
127     }
128     if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
129       // Don't need to consider disjoint accesses.
130       if (vector::isDisjointTransferSet(
131               cast<VectorTransferOpInterface>(write.getOperation()),
132               cast<VectorTransferOpInterface>(transferOp.getOperation()),
133               /*testDynamicValueUsingBounds=*/true))
134         continue;
135     }
136     blockingAccesses.push_back(user);
137   }
138   if (firstOverwriteCandidate == nullptr)
139     return;
140   Region *topRegion = firstOverwriteCandidate->getParentRegion();
141   Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
142   assert(writeAncestor &&
143          "write op should be recursively part of the top region");
144 
145   for (Operation *access : blockingAccesses) {
146     Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
147     // TODO: if the access and write have the same ancestor we could recurse in
148     // the region to know if the access is reachable with more precision.
149     if (accessAncestor == nullptr ||
150         !isReachable(writeAncestor, accessAncestor))
151       continue;
152     if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
153       LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
154                         << *accessAncestor << "\n");
155       return;
156     }
157   }
158   LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
159                     << " overwritten by: " << *firstOverwriteCandidate << "\n");
160   opToErase.push_back(write.getOperation());
161 }
162 
163 /// A transfer_write candidate for store-to-load forwarding must:
164 /// 1. Access the same memref with the same indices and vector type as the
165 /// transfer_read.
166 /// 2. Dominate the transfer_read operation.
167 /// If several candidates are available, one must be dominated by all the others
168 /// since they are all dominating the same transfer_read. We only consider the
169 /// transfer_write dominated by all the other candidates as this will be the
170 /// last transfer_write executed before the transfer_read.
171 /// If we find such a candidate, we can do the forwarding if all the other
172 /// potentially aliasing ops that may reach the transfer_read are post-dominated
173 /// by the transfer_write.
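///
/// For example, in a sketch like the following (shapes and value names are
/// illustrative), the transfer_read can be replaced by %v directly, since the
/// write dominates the read, accesses the same memref with the same indices
/// and vector type, and no other write executes in between:
///
///   vector.transfer_write %v, %buf[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x4xf32>, memref<4x4xf32>
///   %r = vector.transfer_read %buf[%c0, %c0], %pad {in_bounds = [true, true]}
///       : memref<4x4xf32>, vector<4x4xf32>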
174 void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
175   if (read.hasOutOfBoundsDim())
176     return;
177   LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
178                     << "\n");
179   SmallVector<Operation *, 8> blockingWrites;
180   vector::TransferWriteOp lastwrite = nullptr;
181   Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getSource()));
182   llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
183                                            source.getUsers().end());
184   llvm::SmallDenseSet<Operation *, 32> processed;
185   while (!users.empty()) {
186     Operation *user = users.pop_back_val();
187     // If the user has already been processed, skip it.
188     if (!processed.insert(user).second)
189       continue;
190     if (isa<ViewLikeOpInterface>(user)) {
191       users.append(user->getUsers().begin(), user->getUsers().end());
192       continue;
193     }
194     if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
195       continue;
196     if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
197       // If there is a write, but we can prove that it is disjoint, we can
198       // ignore it.
199       if (vector::isDisjointTransferSet(
200               cast<VectorTransferOpInterface>(write.getOperation()),
201               cast<VectorTransferOpInterface>(read.getOperation()),
202               /*testDynamicValueUsingBounds=*/true))
203         continue;
204       if (memref::isSameViewOrTrivialAlias(
205               cast<MemrefValue>(read.getSource()),
206               cast<MemrefValue>(write.getSource())) &&
207           dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
208         if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
209           lastwrite = write;
210         else
211           assert(dominators.dominates(write, lastwrite));
212         continue;
213       }
214     }
215     blockingWrites.push_back(user);
216   }
217 
218   if (lastwrite == nullptr)
219     return;
220 
221   Region *topRegion = lastwrite->getParentRegion();
222   Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
223   assert(readAncestor &&
224          "read op should be recursively part of the top region");
225 
226   for (Operation *write : blockingWrites) {
227     Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
228     // TODO: if the store and read have the same ancestor we could recurse in
229     // the region to know if the read is reachable with more precision.
230     if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
231       continue;
232     if (!postDominators.postDominates(lastwrite, write)) {
233       LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
234                         << *write << "\n");
235       return;
236     }
237   }
238 
239   LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
240                     << " to: " << *read.getOperation() << "\n");
241   read.replaceAllUsesWith(lastwrite.getVector());
242   opToErase.push_back(read.getOperation());
243 }
244 
245 /// Converts OpFoldResults to an int64_t shape, dropping unit dims.
246 static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
247   SmallVector<int64_t> reducedShape;
248   for (const auto size : mixedSizes) {
249     if (llvm::dyn_cast_if_present<Value>(size)) {
250       reducedShape.push_back(ShapedType::kDynamic);
251       continue;
252     }
253 
254     auto value = cast<IntegerAttr>(cast<Attribute>(size)).getValue();
255     if (value == 1)
256       continue;
257     reducedShape.push_back(value.getSExtValue());
258   }
259   return reducedShape;
260 }
261 
262 /// Drops unit dimensions from the input MemRefType.
263 static MemRefType dropUnitDims(MemRefType inputType,
264                                ArrayRef<OpFoldResult> offsets,
265                                ArrayRef<OpFoldResult> sizes,
266                                ArrayRef<OpFoldResult> strides) {
267   auto targetShape = getReducedShape(sizes);
268   Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
269       targetShape, inputType, offsets, sizes, strides);
270   return cast<MemRefType>(rankReducedType).canonicalizeStridedLayout();
271 }
272 
273 /// Creates a rank-reducing memref.subview op that drops unit dims from its
274 /// input, or just returns the input if it already has no unit dims.
275 static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
276                                                  mlir::Location loc,
277                                                  Value input) {
278   MemRefType inputType = cast<MemRefType>(input.getType());
279   SmallVector<OpFoldResult> offsets(inputType.getRank(),
280                                     rewriter.getIndexAttr(0));
281   SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
282   SmallVector<OpFoldResult> strides(inputType.getRank(),
283                                     rewriter.getIndexAttr(1));
284   MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);
285 
286   if (resultType.canonicalizeStridedLayout() ==
287       inputType.canonicalizeStridedLayout())
288     return input;
289   return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets,
290                                             sizes, strides);
291 }
292 
293 /// Returns the number of dims that aren't unit dims.
294 static int getReducedRank(ArrayRef<int64_t> shape) {
295   return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
296 }
297 
298 /// Trims non-scalable unit dimensions from `oldType` and returns the resulting
299 /// type.
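///
/// For example (illustrative types), vector<1x[4]x1x4xf32> becomes
/// vector<[4]x4xf32>, while a scalable unit dim such as [1] would be kept.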
300 static VectorType trimNonScalableUnitDims(VectorType oldType) {
301   SmallVector<int64_t> newShape;
302   SmallVector<bool> newScalableDims;
303   for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
304     if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
305       continue;
306     newShape.push_back(dimSize);
307     newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
308   }
309   return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
310 }
311 
312 // Rewrites vector.create_mask 'op' to drop non-scalable unit dimensions.
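// For example (illustrative values), assuming the operands for the unit dims
// are constant 1s, something like:
//   %m = vector.create_mask %c1, %dimA, %c1, %dimB : vector<1x[4]x1x4xi1>
// is rewritten, roughly, into:
//   %m = vector.create_mask %dimA, %dimB : vector<[4]x4xi1>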
313 static FailureOr<Value>
314 createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
315                                   vector::CreateMaskOp op) {
316   auto type = op.getType();
317   VectorType reducedType = trimNonScalableUnitDims(type);
318   if (reducedType.getRank() == type.getRank())
319     return failure();
320 
321   SmallVector<Value> reducedOperands;
322   for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
323            type.getShape(), type.getScalableDims(), op.getOperands())) {
324     if (dim == 1 && !dimIsScalable) {
325       // If the mask for the unit dim is not a constant of 1, do nothing.
326       auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
327       if (!constant || (constant.value() != 1))
328         return failure();
329       continue;
330     }
331     reducedOperands.push_back(operand);
332   }
333   return rewriter
334       .create<vector::CreateMaskOp>(loc, reducedType, reducedOperands)
335       .getResult();
336 }
337 
338 namespace {
339 
340 /// Rewrites `vector.transfer_read` ops where the source has unit dims, by
341 /// inserting a memref.subview dropping those unit dims. The vector shapes are
342 /// also reduced accordingly.
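///
/// For example (illustrative shapes and names), a read such as
///
///   %v = vector.transfer_read %src[%c0, %c0, %c0], %pad
///       {in_bounds = [true, true, true]}
///       : memref<1x8x1xf32>, vector<1x8x1xf32>
///
/// is rewritten, roughly, into
///
///   %sv = memref.subview %src[0, 0, 0] [1, 8, 1] [1, 1, 1]
///       : memref<1x8x1xf32> to memref<8xf32>
///   %r = vector.transfer_read %sv[%c0], %pad {in_bounds = [true]}
///       : memref<8xf32>, vector<8xf32>
///   %v = vector.shape_cast %r : vector<8xf32> to vector<1x8x1xf32>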
343 class TransferReadDropUnitDimsPattern
344     : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
345   using MaskableOpRewritePattern::MaskableOpRewritePattern;
346 
347   FailureOr<Value>
348   matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp,
349                             vector::MaskingOpInterface maskingOp,
350                             PatternRewriter &rewriter) const override {
351     auto loc = transferReadOp.getLoc();
352     Value vector = transferReadOp.getVector();
353     VectorType vectorType = cast<VectorType>(vector.getType());
354     Value source = transferReadOp.getSource();
355     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
356     // TODO: support tensor types.
357     if (!sourceType)
358       return failure();
359     // TODO: generalize this pattern, relax the requirements here.
360     if (transferReadOp.hasOutOfBoundsDim())
361       return failure();
362     if (!transferReadOp.getPermutationMap().isMinorIdentity())
363       return failure();
364     // Check if the source shape can be further reduced.
365     int reducedRank = getReducedRank(sourceType.getShape());
366     if (reducedRank == sourceType.getRank())
367       return failure();
368     // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
369     // out.
370     if (reducedRank == 0 && maskingOp)
371       return failure();
372     // Check if the reduced vector shape matches the reduced source shape.
373     // Otherwise, this case is not supported yet.
374     VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
375     if (reducedRank != reducedVectorType.getRank())
376       return failure();
377     if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
378           return getConstantIntValue(v) != static_cast<int64_t>(0);
379         }))
380       return failure();
381 
382     Value maskOp = transferReadOp.getMask();
383     if (maskOp) {
384       auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
385       if (!createMaskOp)
386         return rewriter.notifyMatchFailure(
387             transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
388                             "currently supported");
389       FailureOr<Value> rankReducedCreateMask =
390           createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
391       if (failed(rankReducedCreateMask))
392         return failure();
393       maskOp = *rankReducedCreateMask;
394     }
395 
396     Value reducedShapeSource =
397         rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
398     Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
399     SmallVector<Value> zeros(reducedRank, c0);
400     auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
401     SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
402     Operation *newTransferReadOp = rewriter.create<vector::TransferReadOp>(
403         loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
404         transferReadOp.getPadding(), maskOp,
405         rewriter.getBoolArrayAttr(inBounds));
406 
407     if (maskingOp) {
408       auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
409           loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
410           maskingOp.getMask());
411       newTransferReadOp = mlir::vector::maskOperation(
412           rewriter, newTransferReadOp, shapeCastMask);
413     }
414 
415     auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
416         loc, vectorType, newTransferReadOp->getResults()[0]);
417 
418     return shapeCast;
419   }
420 };
421 
422 /// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
423 /// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
424 /// vector shapes are also reduced accordingly.
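///
/// For example (illustrative shapes and names), a write such as
///
///   vector.transfer_write %v, %dst[%c0, %c0, %c0]
///       {in_bounds = [true, true, true]}
///       : vector<1x8x1xf32>, memref<1x8x1xf32>
///
/// is rewritten, roughly, into
///
///   %sv = memref.subview %dst[0, 0, 0] [1, 8, 1] [1, 1, 1]
///       : memref<1x8x1xf32> to memref<8xf32>
///   %vc = vector.shape_cast %v : vector<1x8x1xf32> to vector<8xf32>
///   vector.transfer_write %vc, %sv[%c0] {in_bounds = [true]}
///       : vector<8xf32>, memref<8xf32>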
425 class TransferWriteDropUnitDimsPattern
426     : public vector::MaskableOpRewritePattern<vector::TransferWriteOp> {
427   using MaskableOpRewritePattern::MaskableOpRewritePattern;
428 
429   FailureOr<Value>
430   matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp,
431                             vector::MaskingOpInterface maskingOp,
432                             PatternRewriter &rewriter) const override {
433     auto loc = transferWriteOp.getLoc();
434     Value vector = transferWriteOp.getVector();
435     VectorType vectorType = cast<VectorType>(vector.getType());
436     Value source = transferWriteOp.getSource();
437     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
438     // TODO: support tensor type.
439     if (!sourceType)
440       return failure();
441     // TODO: generalize this pattern, relax the requirements here.
442     if (transferWriteOp.hasOutOfBoundsDim())
443       return failure();
444     if (!transferWriteOp.getPermutationMap().isMinorIdentity())
445       return failure();
446     // Check if the destination shape can be further reduced.
447     int reducedRank = getReducedRank(sourceType.getShape());
448     if (reducedRank == sourceType.getRank())
449       return failure();
450     // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
451     // out.
452     if (reducedRank == 0 && maskingOp)
453       return failure();
454     // Check if the reduced vector shape matches the reduced destination shape.
455     // Otherwise, this case is not supported yet.
456     VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
457     if (reducedRank != reducedVectorType.getRank())
458       return failure();
459     if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
460           return getConstantIntValue(v) != static_cast<int64_t>(0);
461         }))
462       return failure();
463 
464     Value maskOp = transferWriteOp.getMask();
465     if (maskOp) {
466       auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
467       if (!createMaskOp)
468         return rewriter.notifyMatchFailure(
469             transferWriteOp,
470             "unsupported mask op, only 'vector.create_mask' is "
471             "currently supported");
472       FailureOr<Value> rankReducedCreateMask =
473           createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
474       if (failed(rankReducedCreateMask))
475         return failure();
476       maskOp = *rankReducedCreateMask;
477     }
478     Value reducedShapeSource =
479         rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
480     Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
481     SmallVector<Value> zeros(reducedRank, c0);
482     auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
483     SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
484     auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
485         loc, reducedVectorType, vector);
486     Operation *newXferWrite = rewriter.create<vector::TransferWriteOp>(
487         loc, Type(), shapeCastSrc, reducedShapeSource, zeros, identityMap,
488         maskOp, rewriter.getBoolArrayAttr(inBounds));
489 
490     if (maskingOp) {
491       auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
492           loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
493           maskingOp.getMask());
494       newXferWrite =
495           mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask);
496     }
497 
498     if (transferWriteOp.hasPureTensorSemantics())
499       return newXferWrite->getResults()[0];
500 
501     // With memref semantics, there's no return value. Use an empty value to
502     // signal success.
503     return Value();
504   }
505 };
506 
507 } // namespace
508 
509 /// Creates a memref.collapse_shape collapsing all inner dimensions of the
510 /// input starting at `firstDimToCollapse`.
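///
/// For example, with `firstDimToCollapse` = 1 and an illustrative input of
/// type memref<2x3x4xf32>, this produces roughly:
///
///   %collapsed = memref.collapse_shape %input [[0], [1, 2]]
///       : memref<2x3x4xf32> into memref<2x12xf32>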
511 static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
512                                Value input, int64_t firstDimToCollapse) {
513   ShapedType inputType = cast<ShapedType>(input.getType());
514   if (inputType.getRank() == 1)
515     return input;
516   SmallVector<ReassociationIndices> reassociation;
517   for (int64_t i = 0; i < firstDimToCollapse; ++i)
518     reassociation.push_back(ReassociationIndices{i});
519   ReassociationIndices collapsedIndices;
520   for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
521     collapsedIndices.push_back(i);
522   reassociation.push_back(collapsedIndices);
523   return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
524 }
525 
526 /// Returns the new indices that collapse the inner dimensions starting from
527 /// the `firstDimToCollapse` dimension.
528 static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
529                                               Location loc,
530                                               ArrayRef<int64_t> shape,
531                                               ValueRange indices,
532                                               int64_t firstDimToCollapse) {
533   assert(firstDimToCollapse < static_cast<int64_t>(indices.size()));
534 
535   // If all the collapsed indices are zero then no extra logic is needed.
536   // Otherwise, a new offset/index has to be computed.
537   SmallVector<Value> indicesAfterCollapsing(
538       indices.begin(), indices.begin() + firstDimToCollapse);
539   SmallVector<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
540                                        indices.end());
541   if (llvm::all_of(indicesToCollapse, isZeroIndex)) {
542     indicesAfterCollapsing.push_back(indicesToCollapse[0]);
543     return indicesAfterCollapsing;
544   }
545 
546   // Compute the remaining trailing index/offset required for reading from
547   // the collapsed memref:
548   //
549   //    offset = 0
550   //    for (i = firstDimToCollapse; i < outputRank; ++i)
551   //      offset += sourceType.getDimSize(i) * transferReadOp.indices[i]
552   //
553   // For this example:
554   //   %2 = vector.transfer_read/write %arg4[%c0, %arg0, %c0] (...) :
555   //      memref<1x43x2xi32>, vector<1x2xi32>
556   // which would be collapsed to:
557   //   %1 = vector.transfer_read/write %collapse_shape[%c0, %offset] (...) :
558   //      memref<1x86xi32>, vector<2xi32>
559   // one would get the following offset:
560   //    %offset = %arg0 * 43
561   OpFoldResult collapsedOffset =
562       rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
563 
564   auto collapsedStrides = computeSuffixProduct(
565       ArrayRef<int64_t>(shape.begin() + firstDimToCollapse, shape.end()));
566 
567   // Compute the collapsed offset.
568   auto &&[collapsedExpr, collapsedVals] =
569       computeLinearIndex(collapsedOffset, collapsedStrides, indicesToCollapse);
570   collapsedOffset = affine::makeComposedFoldedAffineApply(
571       rewriter, loc, collapsedExpr, collapsedVals);
572 
573   if (auto value = dyn_cast<Value>(collapsedOffset)) {
574     indicesAfterCollapsing.push_back(value);
575   } else {
576     indicesAfterCollapsing.push_back(rewriter.create<arith::ConstantIndexOp>(
577         loc, *getConstantIntValue(collapsedOffset)));
578   }
579 
580   return indicesAfterCollapsing;
581 }
582 
583 namespace {
584 
585 /// Rewrites contiguous row-major vector.transfer_read ops by inserting
586 /// memref.collapse_shape on the source so that the resulting
587 /// vector.transfer_read has a 1D source. Requires the source shape to be
588 /// already reduced, i.e. without unit dims.
589 ///
590 /// If `targetVectorBitwidth` is provided, the flattening will only happen if
591 /// the trailing dimension of the vector read has a bitwidth smaller than the
592 /// provided bitwidth.
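///
/// For example (illustrative shapes and names, assuming the trailing dimension
/// is below `targetVectorBitwidth`), a read such as
///
///   %v = vector.transfer_read %src[%c0, %c0, %c0], %pad
///       {in_bounds = [true, true]}
///       : memref<5x4x3xi8>, vector<4x3xi8>
///
/// is rewritten, roughly, into
///
///   %cs = memref.collapse_shape %src [[0], [1, 2]]
///       : memref<5x4x3xi8> into memref<5x12xi8>
///   %flat = vector.transfer_read %cs[%c0, %c0], %pad {in_bounds = [true]}
///       : memref<5x12xi8>, vector<12xi8>
///   %v = vector.shape_cast %flat : vector<12xi8> to vector<4x3xi8>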
593 class FlattenContiguousRowMajorTransferReadPattern
594     : public OpRewritePattern<vector::TransferReadOp> {
595 public:
596   FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
597                                                unsigned vectorBitwidth,
598                                                PatternBenefit benefit)
599       : OpRewritePattern<vector::TransferReadOp>(context, benefit),
600         targetVectorBitwidth(vectorBitwidth) {}
601 
602   LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
603                                 PatternRewriter &rewriter) const override {
604     auto loc = transferReadOp.getLoc();
605     Value vector = transferReadOp.getVector();
606     VectorType vectorType = cast<VectorType>(vector.getType());
607     auto source = transferReadOp.getSource();
608     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
609 
610     // 0. Check pre-conditions.
611     // Contiguity check is valid on memrefs only; bail out on tensor sources.
612     if (!sourceType)
613       return failure();
614     // If this is already 0D/1D, there's nothing to do.
615     if (vectorType.getRank() <= 1)
616       return failure();
617     if (!vectorType.getElementType().isSignlessIntOrFloat())
618       return failure();
619     unsigned trailingVectorDimBitwidth =
620         vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
621     if (trailingVectorDimBitwidth >= targetVectorBitwidth)
622       return failure();
623     if (!vector::isContiguousSlice(sourceType, vectorType))
624       return failure();
625     // TODO: generalize this pattern, relax the requirements here.
626     if (transferReadOp.hasOutOfBoundsDim())
627       return failure();
628     if (!transferReadOp.getPermutationMap().isMinorIdentity())
629       return failure();
630     if (transferReadOp.getMask())
631       return failure();
632 
633     int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
634 
635     // 1. Collapse the source memref
636     Value collapsedSource =
637         collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
638     MemRefType collapsedSourceType =
639         cast<MemRefType>(collapsedSource.getType());
640     int64_t collapsedRank = collapsedSourceType.getRank();
641     assert(collapsedRank == firstDimToCollapse + 1);
642 
643     // 2. Generate input args for a new vector.transfer_read that will read
644     // from the collapsed memref.
645     // 2.1. New dim exprs + affine map
646     SmallVector<AffineExpr, 1> dimExprs{
647         getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
648     auto collapsedMap =
649         AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
650 
651     // 2.2 New indices
652     SmallVector<Value> collapsedIndices =
653         getCollapsedIndices(rewriter, loc, sourceType.getShape(),
654                             transferReadOp.getIndices(), firstDimToCollapse);
655 
656     // 3. Create new vector.transfer_read that reads from the collapsed memref
657     VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
658                                                 vectorType.getElementType());
659     vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
660         loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
661     flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
662 
663     // 4. Replace the old transfer_read with the new one reading from the
664     // collapsed shape
665     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
666         transferReadOp, cast<VectorType>(vector.getType()), flatRead);
667     return success();
668   }
669 
670 private:
671   // Minimum bitwidth that the trailing vector dimension should have after
672   // flattening.
673   unsigned targetVectorBitwidth;
674 };
675 
676 /// Rewrites contiguous row-major vector.transfer_write ops by inserting
677 /// memref.collapse_shape on the source so that the resulting
678 /// vector.transfer_write has a 1D source. Requires the source shape to be
679 /// already reduced, i.e. without unit dims.
680 ///
681 /// If `targetVectorBitwidth` is provided, the flattening will only happen if
682 /// the trailing dimension of the vector to be written has a bitwidth smaller
683 /// than the provided bitwidth.
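///
/// For example (illustrative shapes and names, assuming the trailing dimension
/// is below `targetVectorBitwidth`), a write such as
///
///   vector.transfer_write %v, %dst[%c0, %c0, %c0]
///       {in_bounds = [true, true]}
///       : vector<4x3xi8>, memref<5x4x3xi8>
///
/// is rewritten, roughly, into
///
///   %cs = memref.collapse_shape %dst [[0], [1, 2]]
///       : memref<5x4x3xi8> into memref<5x12xi8>
///   %flat = vector.shape_cast %v : vector<4x3xi8> to vector<12xi8>
///   vector.transfer_write %flat, %cs[%c0, %c0] {in_bounds = [true]}
///       : vector<12xi8>, memref<5x12xi8>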
684 class FlattenContiguousRowMajorTransferWritePattern
685     : public OpRewritePattern<vector::TransferWriteOp> {
686 public:
687   FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
688                                                 unsigned vectorBitwidth,
689                                                 PatternBenefit benefit)
690       : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
691         targetVectorBitwidth(vectorBitwidth) {}
692 
693   LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
694                                 PatternRewriter &rewriter) const override {
695     auto loc = transferWriteOp.getLoc();
696     Value vector = transferWriteOp.getVector();
697     VectorType vectorType = cast<VectorType>(vector.getType());
698     Value source = transferWriteOp.getSource();
699     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
700 
701     // 0. Check pre-conditions.
702     // Contiguity check is valid on memrefs only; bail out on tensor sources.
703     if (!sourceType)
704       return failure();
705     // If this is already 0D/1D, there's nothing to do.
706     if (vectorType.getRank() <= 1)
708       return failure();
709     if (!vectorType.getElementType().isSignlessIntOrFloat())
710       return failure();
711     unsigned trailingVectorDimBitwidth =
712         vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
713     if (trailingVectorDimBitwidth >= targetVectorBitwidth)
714       return failure();
715     if (!vector::isContiguousSlice(sourceType, vectorType))
716       return failure();
717     // TODO: generalize this pattern, relax the requirements here.
718     if (transferWriteOp.hasOutOfBoundsDim())
719       return failure();
720     if (!transferWriteOp.getPermutationMap().isMinorIdentity())
721       return failure();
722     if (transferWriteOp.getMask())
723       return failure();
724 
725     int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
726 
727     // 1. Collapse the source memref
728     Value collapsedSource =
729         collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
730     MemRefType collapsedSourceType =
731         cast<MemRefType>(collapsedSource.getType());
732     int64_t collapsedRank = collapsedSourceType.getRank();
733     assert(collapsedRank == firstDimToCollapse + 1);
734 
735     // 2. Generate input args for a new vector.transfer_write that will write
736     // to the collapsed memref.
737     // 2.1. New dim exprs + affine map
738     SmallVector<AffineExpr, 1> dimExprs{
739         getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
740     auto collapsedMap =
741         AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
742 
743     // 2.2 New indices
744     SmallVector<Value> collapsedIndices =
745         getCollapsedIndices(rewriter, loc, sourceType.getShape(),
746                             transferWriteOp.getIndices(), firstDimToCollapse);
747 
748     // 3. Create new vector.transfer_write that writes to the collapsed memref
749     VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
750                                                 vectorType.getElementType());
751     Value flatVector =
752         rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
753     vector::TransferWriteOp flatWrite =
754         rewriter.create<vector::TransferWriteOp>(
755             loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
756     flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
757 
758     // 4. Replace the old transfer_write with the new one writing to the
759     // collapsed shape.
760     rewriter.eraseOp(transferWriteOp);
761     return success();
762   }
763 
764 private:
765   // Minimum bitwidth that the trailing vector dimension should have after
766   // flattening.
767   unsigned targetVectorBitwidth;
768 };
769 
770 /// Base class for `vector.extract/vector.extractelement(vector.transfer_read)`
771 /// to `memref.load` patterns. The `match` method is shared for both
772 /// `vector.extract` and `vector.extractelement`.
773 template <class VectorExtractOp>
774 class RewriteScalarExtractOfTransferReadBase
775     : public OpRewritePattern<VectorExtractOp> {
776   using Base = OpRewritePattern<VectorExtractOp>;
777 
778 public:
779   RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
780                                          PatternBenefit benefit,
781                                          bool allowMultipleUses)
782       : Base::OpRewritePattern(context, benefit),
783         allowMultipleUses(allowMultipleUses) {}
784 
785   LogicalResult match(VectorExtractOp extractOp) const override {
786     auto xferOp =
787         extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
788     if (!xferOp)
789       return failure();
790     // Check that we are extracting a scalar and not a sub-vector.
791     if (isa<VectorType>(extractOp.getResult().getType()))
792       return failure();
793     // If multiple uses are not allowed, check if xfer has a single use.
794     if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
795       return failure();
796     // If multiple uses are allowed, check if all the xfer uses are extract ops.
797     if (allowMultipleUses &&
798         !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
799           return isa<vector::ExtractOp, vector::ExtractElementOp>(
800               use.getOwner());
801         }))
802       return failure();
803     // Mask not supported.
804     if (xferOp.getMask())
805       return failure();
806     // Map not supported.
807     if (!xferOp.getPermutationMap().isMinorIdentity())
808       return failure();
809     // Cannot rewrite if the indices may be out of bounds.
810     if (xferOp.hasOutOfBoundsDim())
811       return failure();
812     return success();
813   }
814 
815 private:
816   bool allowMultipleUses;
817 };
818 
819 /// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
820 ///
821 /// All the users of the transfer op must be either `vector.extractelement` or
822 /// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
823 /// transfer ops with any number of users. Otherwise, rewrite only if the
824 /// extract op is the single user of the transfer op. Rewriting a single
825 /// vector load with multiple scalar loads may negatively affect performance.
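///
/// For example (illustrative shapes and names), assuming %src is a memref:
///
///   %v = vector.transfer_read %src[%i], %pad {in_bounds = [true]}
///       : memref<?xf32>, vector<4xf32>
///   %e = vector.extractelement %v[%j : index] : vector<4xf32>
///
/// becomes, roughly:
///
///   %idx = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%i, %j]
///   %e = memref.load %src[%idx] : memref<?xf32>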
826 class RewriteScalarExtractElementOfTransferRead
827     : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
828   using RewriteScalarExtractOfTransferReadBase::
829       RewriteScalarExtractOfTransferReadBase;
830 
831   void rewrite(vector::ExtractElementOp extractOp,
832                PatternRewriter &rewriter) const override {
833     // Construct scalar load.
834     auto loc = extractOp.getLoc();
835     auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
836     SmallVector<Value> newIndices(xferOp.getIndices().begin(),
837                                   xferOp.getIndices().end());
838     if (extractOp.getPosition()) {
839       AffineExpr sym0, sym1;
840       bindSymbols(extractOp.getContext(), sym0, sym1);
841       OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
842           rewriter, loc, sym0 + sym1,
843           {newIndices[newIndices.size() - 1], extractOp.getPosition()});
844       if (auto value = dyn_cast<Value>(ofr)) {
845         newIndices[newIndices.size() - 1] = value;
846       } else {
847         newIndices[newIndices.size() - 1] =
848             rewriter.create<arith::ConstantIndexOp>(loc,
849                                                     *getConstantIntValue(ofr));
850       }
851     }
852     if (isa<MemRefType>(xferOp.getSource().getType())) {
853       rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
854                                                   newIndices);
855     } else {
856       rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
857           extractOp, xferOp.getSource(), newIndices);
858     }
859   }
860 };
861 
863 /// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
864 ///
865 /// All the users of the transfer op must be either `vector.extractelement` or
866 /// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
867 /// transfer ops with any number of users. Otherwise, rewrite only if the
868 /// extract op is the single user of the transfer op. Rewriting a single
869 /// vector load with multiple scalar loads may negatively affect performance.
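///
/// For example (illustrative shapes and names), assuming %src is a memref:
///
///   %v = vector.transfer_read %src[%i, %j], %pad {in_bounds = [true, true]}
///       : memref<?x?xf32>, vector<2x4xf32>
///   %e = vector.extract %v[0, 1] : f32 from vector<2x4xf32>
///
/// becomes, roughly:
///
///   %jj = affine.apply affine_map<()[s0] -> (s0 + 1)>()[%j]
///   %e = memref.load %src[%i, %jj] : memref<?x?xf32>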
870 class RewriteScalarExtractOfTransferRead
871     : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
872   using RewriteScalarExtractOfTransferReadBase::
873       RewriteScalarExtractOfTransferReadBase;
874 
875   void rewrite(vector::ExtractOp extractOp,
876                PatternRewriter &rewriter) const override {
877     // Construct scalar load.
878     auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
879     SmallVector<Value> newIndices(xferOp.getIndices().begin(),
880                                   xferOp.getIndices().end());
881     for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
882       assert(isa<Attribute>(pos) && "Unexpected non-constant index");
883       int64_t offset = cast<IntegerAttr>(cast<Attribute>(pos)).getInt();
884       int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
885       OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
886           rewriter, extractOp.getLoc(),
887           rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
888       if (auto value = dyn_cast<Value>(ofr)) {
889         newIndices[idx] = value;
890       } else {
891         newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
892             extractOp.getLoc(), *getConstantIntValue(ofr));
893       }
894     }
895     if (isa<MemRefType>(xferOp.getSource().getType())) {
896       rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
897                                                   newIndices);
898     } else {
899       rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
900           extractOp, xferOp.getSource(), newIndices);
901     }
902   }
903 };
904 
905 /// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
906 /// to memref.store.
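///
/// For example (illustrative shapes and names), assuming %dst is a memref:
///
///   vector.transfer_write %v, %dst[%i, %j] {in_bounds = [true, true]}
///       : vector<1x1xf32>, memref<?x?xf32>
///
/// becomes, roughly:
///
///   %s = vector.extract %v[0, 0] : f32 from vector<1x1xf32>
///   memref.store %s, %dst[%i, %j] : memref<?x?xf32>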
907 class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
908   using OpRewritePattern::OpRewritePattern;
909 
910   LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
911                                 PatternRewriter &rewriter) const override {
912     // Must be a scalar write.
913     auto vecType = xferOp.getVectorType();
914     if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
915       return failure();
916     // Mask not supported.
917     if (xferOp.getMask())
918       return failure();
919     // Map not supported.
920     if (!xferOp.getPermutationMap().isMinorIdentity())
921       return failure();
922     // Only float and integer element types are supported.
923     Value scalar;
924     if (vecType.getRank() == 0) {
925       // vector.extract does not support vector<f32> etc., so use
926       // vector.extractelement instead.
927       scalar = rewriter.create<vector::ExtractElementOp>(xferOp.getLoc(),
928                                                          xferOp.getVector());
929     } else {
930       SmallVector<int64_t> pos(vecType.getRank(), 0);
931       scalar = rewriter.create<vector::ExtractOp>(xferOp.getLoc(),
932                                                   xferOp.getVector(), pos);
933     }
934     // Construct a scalar store.
935     if (isa<MemRefType>(xferOp.getSource().getType())) {
936       rewriter.replaceOpWithNewOp<memref::StoreOp>(
937           xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
938     } else {
939       rewriter.replaceOpWithNewOp<tensor::InsertOp>(
940           xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
941     }
942     return success();
943   }
944 };
945 
946 } // namespace
947 
948 void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
949                                      Operation *rootOp) {
950   TransferOptimization opt(rewriter, rootOp);
951   // Run store-to-load forwarding first since it can expose more dead store
952   // opportunities.
953   rootOp->walk([&](vector::TransferReadOp read) {
954     if (isa<MemRefType>(read.getShapedType()))
955       opt.storeToLoadForwarding(read);
956   });
957   opt.removeDeadOp();
958   rootOp->walk([&](vector::TransferWriteOp write) {
959     if (isa<MemRefType>(write.getShapedType()))
960       opt.deadStoreOp(write);
961   });
962   opt.removeDeadOp();
963 }
964 
965 void mlir::vector::populateScalarVectorTransferLoweringPatterns(
966     RewritePatternSet &patterns, PatternBenefit benefit,
967     bool allowMultipleUses) {
968   patterns.add<RewriteScalarExtractElementOfTransferRead,
969                RewriteScalarExtractOfTransferRead>(patterns.getContext(),
970                                                    benefit, allowMultipleUses);
971   patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
972 }
973 
974 void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
975     RewritePatternSet &patterns, PatternBenefit benefit) {
976   patterns
977       .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
978           patterns.getContext(), benefit);
979   populateShapeCastFoldingPatterns(patterns);
980 }
981 
982 void mlir::vector::populateFlattenVectorTransferPatterns(
983     RewritePatternSet &patterns, unsigned targetVectorBitwidth,
984     PatternBenefit benefit) {
985   patterns.add<FlattenContiguousRowMajorTransferReadPattern,
986                FlattenContiguousRowMajorTransferWritePattern>(
987       patterns.getContext(), targetVectorBitwidth, benefit);
988   populateShapeCastFoldingPatterns(patterns, benefit);
989   populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
990 }
991