xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (revision 0cf7aaf30067c4be2886a8c9127a27dcbfd63b92)
//===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with optimizing transfer_read and
// transfer_write ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "vector-transfer-opt"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;

/// Return the ancestor op in the region or nullptr if the region is not
/// an ancestor of the op.
static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
  for (; op != nullptr && op->getParentRegion() != region;
       op = op->getParentOp())
    ;
  return op;
}

namespace {

class TransferOptimization {
public:
  TransferOptimization(RewriterBase &rewriter, Operation *op)
      : rewriter(rewriter), dominators(op), postDominators(op) {}
  void deadStoreOp(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);
  void removeDeadOp() {
    for (Operation *op : opToErase)
      rewriter.eraseOp(op);
    opToErase.clear();
  }

private:
  RewriterBase &rewriter;
  bool isReachable(Operation *start, Operation *dest);
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  std::vector<Operation *> opToErase;
};

} // namespace

/// Return true if there is a path from start operation to dest operation,
/// otherwise return false. The operations have to be in the same region.
bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
  assert(start->getParentRegion() == dest->getParentRegion() &&
         "This function only works for ops in the same region");
  // Simple case where the start op dominates the destination.
  if (dominators.dominates(start, dest))
    return true;
  Block *startBlock = start->getBlock();
  Block *destBlock = dest->getBlock();
  SmallVector<Block *, 32> worklist(startBlock->succ_begin(),
                                    startBlock->succ_end());
  SmallPtrSet<Block *, 32> visited;
  while (!worklist.empty()) {
    Block *bb = worklist.pop_back_val();
    if (!visited.insert(bb).second)
      continue;
    if (dominators.dominates(bb, destBlock))
      return true;
    worklist.append(bb->succ_begin(), bb->succ_end());
  }
  return false;
}

/// For a transfer_write to fully overwrite another transfer_write, it must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the other transfer_write operation.
/// If several candidates are available, one must be post-dominated by all the
/// others since they all post-dominate the same transfer_write. We only
/// consider the transfer_write post-dominated by all the other candidates, as
/// this will be the first transfer_write executed after the potentially dead
/// transfer_write.
/// If we find such an overwriting transfer_write, we know that the original
/// transfer_write is dead if all reads that can be reached from the
/// potentially dead transfer_write are dominated by the overwriting
/// transfer_write.
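///
/// Illustrative sketch (not from the original source; %buf, %v0, %v1 and the
/// shapes are hypothetical):
///
///   vector.transfer_write %v0, %buf[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x4xf32>, memref<4x4xf32>   // Dead: fully overwritten ...
///   vector.transfer_write %v1, %buf[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x4xf32>, memref<4x4xf32>   // ... by this write, with no
///                                            // read of %buf in between.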
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> blockingAccesses;
  Operation *firstOverwriteCandidate = nullptr;
  Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getSource()));
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed skip.
    if (!processed.insert(user).second)
      continue;
    if (isa<ViewLikeOpInterface>(user)) {
      users.append(user->getUsers().begin(), user->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user))
      continue;
    if (user == write.getOperation())
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check candidate that can overwrite the store.
      if (memref::isSameViewOrTrivialAlias(
              cast<MemrefValue>(nextWrite.getSource()),
              cast<MemrefValue>(write.getSource())) &&
          checkSameValueWAW(nextWrite, write) &&
          postDominators.postDominates(nextWrite, write)) {
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
          firstOverwriteCandidate = nextWrite;
        else
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
        continue;
      }
    }
    if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
      // Don't need to consider disjoint accesses.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(transferOp.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
    }
    blockingAccesses.push_back(user);
  }
  if (firstOverwriteCandidate == nullptr)
    return;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");

  for (Operation *access : blockingAccesses) {
    Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
    // TODO: if the access and write have the same ancestor we could recurse in
    // the region to know if the access is reachable with more precision.
    if (accessAncestor == nullptr ||
        !isReachable(writeAncestor, accessAncestor))
      continue;
    if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
                        << *accessAncestor << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
  opToErase.push_back(write.getOperation());
}

/// A transfer_write candidate for store-to-load forwarding must:
/// 1. Access the same memref with the same indices and vector type as the
/// transfer_read.
/// 2. Dominate the transfer_read operation.
/// If several candidates are available, one must be dominated by all the
/// others since they all dominate the same transfer_read. We only consider the
/// transfer_write dominated by all the other candidates, as this will be the
/// last transfer_write executed before the transfer_read.
/// If we find such a candidate, we can do the forwarding if all the other
/// potentially aliasing ops that may reach the transfer_read are post-dominated
/// by the transfer_write.
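///
/// Illustrative sketch (not from the original source; values and shapes are
/// hypothetical):
///
///   vector.transfer_write %v, %buf[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x4xf32>, memref<4x4xf32>
///   %r = vector.transfer_read %buf[%c0, %c0], %pad {in_bounds = [true, true]}
///       : memref<4x4xf32>, vector<4x4xf32>
///
/// With no aliasing write reaching the read in between, %r is replaced by %v.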
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  if (read.hasOutOfBoundsDim())
    return;
  LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
                    << "\n");
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getSource()));
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed skip.
    if (!processed.insert(user).second)
      continue;
    if (isa<ViewLikeOpInterface>(user)) {
      users.append(user->getUsers().begin(), user->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
      continue;
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      // If there is a write, but we can prove that it is disjoint we can ignore
      // the write.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(read.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
      if (memref::isSameViewOrTrivialAlias(
              cast<MemrefValue>(read.getSource()),
              cast<MemrefValue>(write.getSource())) &&
          dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
          lastwrite = write;
        else
          assert(dominators.dominates(write, lastwrite));
        continue;
      }
    }
    blockingWrites.push_back(user);
  }

  if (lastwrite == nullptr)
    return;

  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");

  for (Operation *write : blockingWrites) {
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    // TODO: if the store and read have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!postDominators.postDominates(lastwrite, write)) {
      LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
                        << *write << "\n");
      return;
    }
  }

  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
                    << " to: " << *read.getOperation() << "\n");
  read.replaceAllUsesWith(lastwrite.getVector());
  opToErase.push_back(read.getOperation());
}

/// Converts OpFoldResults to int64_t shape without unit dims.
static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<int64_t> reducedShape;
  for (const auto size : mixedSizes) {
    if (llvm::dyn_cast_if_present<Value>(size)) {
      reducedShape.push_back(ShapedType::kDynamic);
      continue;
    }

    auto value = cast<IntegerAttr>(size.get<Attribute>()).getValue();
    if (value == 1)
      continue;
    reducedShape.push_back(value.getSExtValue());
  }
  return reducedShape;
}

/// Drops unit dimensions from the input MemRefType.
static MemRefType dropUnitDims(MemRefType inputType,
                               ArrayRef<OpFoldResult> offsets,
                               ArrayRef<OpFoldResult> sizes,
                               ArrayRef<OpFoldResult> strides) {
  auto targetShape = getReducedShape(sizes);
  Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
      targetShape, inputType, offsets, sizes, strides);
  return canonicalizeStridedLayout(cast<MemRefType>(rankReducedType));
}

/// Creates a rank-reducing memref.subview op that drops unit dims from its
/// input, or just returns the input if it already has no unit dims.
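///
/// For instance (an illustrative sketch, assuming a static, contiguous input),
/// memref<1x4x1x8xf32> would be reduced to memref<4x8xf32> through a
/// rank-reducing memref.subview with zero offsets and unit strides.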
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                 mlir::Location loc,
                                                 Value input) {
  MemRefType inputType = cast<MemRefType>(input.getType());
  SmallVector<OpFoldResult> offsets(inputType.getRank(),
                                    rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
  SmallVector<OpFoldResult> strides(inputType.getRank(),
                                    rewriter.getIndexAttr(1));
  MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);

  if (canonicalizeStridedLayout(resultType) ==
      canonicalizeStridedLayout(inputType))
    return input;
  return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets,
                                            sizes, strides);
}

/// Returns the number of dims that aren't unit dims.
static int getReducedRank(ArrayRef<int64_t> shape) {
  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}

/// Trims non-scalable unit dimensions from `oldType` and returns the result
/// type.
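///
/// For example (a sketch): vector<1x[4]x1xf32> is trimmed to vector<[4]xf32>,
/// while a scalable unit dim such as the one in vector<[1]x4xf32> is kept.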
static VectorType trimNonScalableUnitDims(VectorType oldType) {
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableDims;
  for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
    if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
      continue;
    newShape.push_back(dimSize);
    newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
  }
  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}

/// Rewrites vector.create_mask 'op' to drop non-scalable unit dimensions.
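///
/// For example (a sketch; operands are hypothetical and the unit-dim operand
/// must be a constant 1):
///
///   %m = vector.create_mask %c1, %dim : vector<1x[8]xi1>
///
/// is reduced to:
///
///   %m0 = vector.create_mask %dim : vector<[8]xi1>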
static FailureOr<Value>
createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
                                  vector::CreateMaskOp op) {
  auto type = op.getType();
  VectorType reducedType = trimNonScalableUnitDims(type);
  if (reducedType.getRank() == type.getRank())
    return failure();

  SmallVector<Value> reducedOperands;
  for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
           type.getShape(), type.getScalableDims(), op.getOperands())) {
    if (dim == 1 && !dimIsScalable) {
      // If the mask for the unit dim is not a constant of 1, do nothing.
      auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
      if (!constant || (constant.value() != 1))
        return failure();
      continue;
    }
    reducedOperands.push_back(operand);
  }
  return rewriter
      .create<vector::CreateMaskOp>(loc, reducedType, reducedOperands)
      .getResult();
}

namespace {

/// Rewrites `vector.transfer_read` ops where the source has unit dims, by
/// inserting a memref.subview dropping those unit dims. The vector shapes are
/// also reduced accordingly.
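///
/// Illustrative before/after sketch (hypothetical values; assumes a
/// contiguous, statically shaped source):
///
///   %r = vector.transfer_read %src[%c0, %c0, %c0], %pad
///       {in_bounds = [true, true, true]}
///       : memref<1x8x1xf32>, vector<1x8x1xf32>
///
/// is rewritten to roughly:
///
///   %sv = memref.subview %src[0, 0, 0] [1, 8, 1] [1, 1, 1]
///       : memref<1x8x1xf32> to memref<8xf32>
///   %r0 = vector.transfer_read %sv[%c0], %pad {in_bounds = [true]}
///       : memref<8xf32>, vector<8xf32>
///   %r  = vector.shape_cast %r0 : vector<8xf32> to vector<1x8x1xf32>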
class TransferReadDropUnitDimsPattern
    : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp,
                            vector::MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferReadOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor types.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the source shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
    // out.
    if (reducedRank == 0 && maskingOp)
      return failure();
    // Check if the reduced vector shape matches the reduced source shape.
    // Otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    Value maskOp = transferReadOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
                            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }

    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    Operation *newTransferReadOp = rewriter.create<vector::TransferReadOp>(
        loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
        transferReadOp.getPadding(), maskOp,
        rewriter.getBoolArrayAttr(inBounds));

    if (maskingOp) {
      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
          maskingOp.getMask());
      newTransferReadOp = mlir::vector::maskOperation(
          rewriter, newTransferReadOp, shapeCastMask);
    }

    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, vectorType, newTransferReadOp->getResults()[0]);

    return shapeCast;
  }
};

/// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
/// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
/// vector shapes are also reduced accordingly.
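///
/// Illustrative before/after sketch (hypothetical values; static shapes):
///
///   vector.transfer_write %v, %dst[%c0, %c0, %c0]
///       {in_bounds = [true, true, true]}
///       : vector<1x8x1xf32>, memref<1x8x1xf32>
///
/// becomes roughly:
///
///   %sv = memref.subview %dst[0, 0, 0] [1, 8, 1] [1, 1, 1]
///       : memref<1x8x1xf32> to memref<8xf32>
///   %v0 = vector.shape_cast %v : vector<1x8x1xf32> to vector<8xf32>
///   vector.transfer_write %v0, %sv[%c0] {in_bounds = [true]}
///       : vector<8xf32>, memref<8xf32>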
class TransferWriteDropUnitDimsPattern
    : public vector::MaskableOpRewritePattern<vector::TransferWriteOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp,
                            vector::MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor type.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the destination shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
    // out.
    if (reducedRank == 0 && maskingOp)
      return failure();
    // Check if the reduced vector shape matches the reduced destination shape.
    // Otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    Value maskOp = transferWriteOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferWriteOp,
            "unsupported mask op, only 'vector.create_mask' is "
            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, reducedVectorType, vector);
    Operation *newXferWrite = rewriter.create<vector::TransferWriteOp>(
        loc, Type(), shapeCastSrc, reducedShapeSource, zeros, identityMap,
        maskOp, rewriter.getBoolArrayAttr(inBounds));

    if (maskingOp) {
      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
          maskingOp.getMask());
      newXferWrite =
          mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask);
    }

    if (transferWriteOp.hasPureTensorSemantics())
      return newXferWrite->getResults()[0];

    // With Memref semantics, there's no return value. Use empty value to signal
    // success.
    return Value();
  }
};

} // namespace

/// Creates a memref.collapse_shape collapsing all inner dimensions of the
/// input starting at `firstDimToCollapse`.
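///
/// For example (a sketch), with `firstDimToCollapse` = 1, memref<2x3x4xf32>
/// would be collapsed to memref<2x12xf32> using the reassociation
/// [[0], [1, 2]].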
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
                               Value input, int64_t firstDimToCollapse) {
  ShapedType inputType = cast<ShapedType>(input.getType());
  if (inputType.getRank() == 1)
    return input;
  SmallVector<ReassociationIndices> reassociation;
  for (int64_t i = 0; i < firstDimToCollapse; ++i)
    reassociation.push_back(ReassociationIndices{i});
  ReassociationIndices collapsedIndices;
  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
    collapsedIndices.push_back(i);
  reassociation.push_back(collapsedIndices);
  return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
}

/// Returns the new indices that collapse the inner dimensions starting from
/// the `firstDimToCollapse` dimension.
static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
                                              Location loc,
                                              ArrayRef<int64_t> shape,
                                              ValueRange indices,
                                              int64_t firstDimToCollapse) {
  assert(firstDimToCollapse < static_cast<int64_t>(indices.size()));

  // If all the collapsed indices are zero then no extra logic is needed.
  // Otherwise, a new offset/index has to be computed.
  SmallVector<Value> indicesAfterCollapsing(
      indices.begin(), indices.begin() + firstDimToCollapse);
  SmallVector<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
                                       indices.end());
  if (llvm::all_of(indicesToCollapse, isZeroIndex)) {
    indicesAfterCollapsing.push_back(indicesToCollapse[0]);
    return indicesAfterCollapsing;
  }

  // Compute the remaining trailing index/offset required for reading from
  // the collapsed memref:
  //
  //    offset = 0
  //    for (i = firstDimToCollapse; i < sourceRank; ++i)
  //      offset = offset * sourceType.getDimSize(i) + transferReadOp.indices[i]
  //
  // For this example:
  //   %2 = vector.transfer_read/write %arg4[%c0, %arg0, %c0] (...) :
  //      memref<1x43x2xi32>, vector<1x2xi32>
  // which would be collapsed to:
  //   %1 = vector.transfer_read/write %collapse_shape[%c0, %offset] (...) :
  //      memref<1x86xi32>, vector<2xi32>
  // one would get the following offset:
  //    %offset = %arg0 * 2
  OpFoldResult collapsedOffset =
      rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();

  auto collapsedStrides = computeSuffixProduct(
      ArrayRef<int64_t>(shape.begin() + firstDimToCollapse, shape.end()));

  // Compute the collapsed offset.
  auto &&[collapsedExpr, collapsedVals] =
      computeLinearIndex(collapsedOffset, collapsedStrides, indicesToCollapse);
  collapsedOffset = affine::makeComposedFoldedAffineApply(
      rewriter, loc, collapsedExpr, collapsedVals);

  if (collapsedOffset.is<Value>()) {
    indicesAfterCollapsing.push_back(collapsedOffset.get<Value>());
  } else {
    indicesAfterCollapsing.push_back(rewriter.create<arith::ConstantIndexOp>(
        loc, *getConstantIntValue(collapsedOffset)));
  }

  return indicesAfterCollapsing;
}

namespace {

/// Rewrites contiguous row-major vector.transfer_read ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
///
/// If `targetVectorBitwidth` is provided, the flattening will only happen if
/// the bitwidth of the trailing dimension of the vector read is smaller than
/// the provided bitwidth.
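///
/// Illustrative sketch (hypothetical values; assumes the source is contiguous
/// and the bitwidth precondition holds):
///
///   %r = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]}
///       : memref<4x8xf32>, vector<4x8xf32>
///
/// becomes roughly:
///
///   %cs = memref.collapse_shape %src [[0, 1]]
///       : memref<4x8xf32> into memref<32xf32>
///   %r0 = vector.transfer_read %cs[%c0], %pad {in_bounds = [true]}
///       : memref<32xf32>, vector<32xf32>
///   %r  = vector.shape_cast %r0 : vector<32xf32> to vector<4x8xf32>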
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
public:
  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
                                               unsigned vectorBitwidth,
                                               PatternBenefit benefit)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    auto source = transferReadOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check pre-conditions
    // The contiguity check below only applies to memrefs; bail out on tensors.
    if (!sourceType)
      return failure();
    // If this is already 0D/1D, there's nothing to do.
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferReadOp.getMask())
      return failure();

    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();

    // 1. Collapse the source memref
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // 2. Generate input args for a new vector.transfer_read that will read
    // from the collapsed memref.
    // 2.1. New dim exprs + affine map
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());

    // 2.2 New indices
    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
                            transferReadOp.getIndices(), firstDimToCollapse);

    // 3. Create new vector.transfer_read that reads from the collapsed memref
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
        loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));

    // 4. Replace the old transfer_read with the new one reading from the
    // collapsed shape
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, cast<VectorType>(vector.getType()), flatRead);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};

/// Rewrites contiguous row-major vector.transfer_write ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_write has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
///
/// If `targetVectorBitwidth` is provided, the flattening will only happen if
/// the bitwidth of the trailing dimension of the written vector is smaller
/// than the provided bitwidth.
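///
/// Illustrative sketch (hypothetical values; same preconditions as the read
/// pattern above):
///
///   vector.transfer_write %v, %dst[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x8xf32>, memref<4x8xf32>
///
/// becomes roughly:
///
///   %cs = memref.collapse_shape %dst [[0, 1]]
///       : memref<4x8xf32> into memref<32xf32>
///   %v0 = vector.shape_cast %v : vector<4x8xf32> to vector<32xf32>
///   vector.transfer_write %v0, %cs[%c0] {in_bounds = [true]}
///       : vector<32xf32>, memref<32xf32>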
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
public:
  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
                                                unsigned vectorBitwidth,
                                                PatternBenefit benefit)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check pre-conditions
    // The contiguity check below only applies to memrefs; bail out on tensors.
    if (!sourceType)
      return failure();
    // If this is already 0D/1D, there's nothing to do.
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();

    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();

    // 1. Collapse the source memref
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // 2. Generate input args for a new vector.transfer_write that will write
    // to the collapsed memref.
    // 2.1. New dim exprs + affine map
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());

    // 2.2 New indices
    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
                            transferWriteOp.getIndices(), firstDimToCollapse);

    // 3. Create new vector.transfer_write that writes to the collapsed memref
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    Value flatVector =
        rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
    vector::TransferWriteOp flatWrite =
        rewriter.create<vector::TransferWriteOp>(
            loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));

    // 4. Replace the old transfer_write with the new one writing the
    // collapsed shape
    rewriter.eraseOp(transferWriteOp);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};

/// Base class for `vector.extract/vector.extractelement(vector.transfer_read)`
/// to `memref.load` patterns. The `match` method is shared for both
/// `vector.extract` and `vector.extractelement`.
template <class VectorExtractOp>
class RewriteScalarExtractOfTransferReadBase
    : public OpRewritePattern<VectorExtractOp> {
  using Base = OpRewritePattern<VectorExtractOp>;

public:
  RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
                                         PatternBenefit benefit,
                                         bool allowMultipleUses)
      : Base::OpRewritePattern(context, benefit),
        allowMultipleUses(allowMultipleUses) {}

  LogicalResult match(VectorExtractOp extractOp) const override {
    auto xferOp =
        extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
    if (!xferOp)
      return failure();
    // Check that we are extracting a scalar and not a sub-vector.
    if (isa<VectorType>(extractOp.getResult().getType()))
      return failure();
    // If multiple uses are not allowed, check if xfer has a single use.
    if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
      return failure();
    // If multiple uses are allowed, check if all the xfer uses are extract ops.
    if (allowMultipleUses &&
        !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
          return isa<vector::ExtractOp, vector::ExtractElementOp>(
              use.getOwner());
        }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Cannot rewrite if the indices may be out of bounds.
    if (xferOp.hasOutOfBoundsDim())
      return failure();
    return success();
  }

private:
  bool allowMultipleUses;
};

/// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load with multiple scalar loads may negatively affect performance.
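///
/// Illustrative sketch (hypothetical values):
///
///   %v = vector.transfer_read %mem[%i], %pad {in_bounds = [true]}
///       : memref<?xf32>, vector<4xf32>
///   %e = vector.extractelement %v[%j : index] : vector<4xf32>
///
/// becomes roughly (with the index offset %k = %i + %j folded via
/// affine.apply):
///
///   %e = memref.load %mem[%k] : memref<?xf32>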
class RewriteScalarExtractElementOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractElementOp extractOp,
               PatternRewriter &rewriter) const override {
    // Construct scalar load.
    auto loc = extractOp.getLoc();
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    if (extractOp.getPosition()) {
      AffineExpr sym0, sym1;
      bindSymbols(extractOp.getContext(), sym0, sym1);
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, loc, sym0 + sym1,
          {newIndices[newIndices.size() - 1], extractOp.getPosition()});
      if (ofr.is<Value>()) {
        newIndices[newIndices.size() - 1] = ofr.get<Value>();
      } else {
        newIndices[newIndices.size() - 1] =
            rewriter.create<arith::ConstantIndexOp>(loc,
                                                    *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};

/// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load with multiple scalar loads may negatively affect performance.
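///
/// Illustrative sketch (hypothetical values; extract positions are constant):
///
///   %v = vector.transfer_read %mem[%i, %j], %pad {in_bounds = [true, true]}
///       : memref<?x?xf32>, vector<2x4xf32>
///   %e = vector.extract %v[1, 2] : f32 from vector<2x4xf32>
///
/// becomes roughly (with %i1 = %i + 1 and %j2 = %j + 2 folded via
/// affine.apply):
///
///   %e = memref.load %mem[%i1, %j2] : memref<?x?xf32>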
class RewriteScalarExtractOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractOp extractOp,
               PatternRewriter &rewriter) const override {
    // Construct scalar load.
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
      assert(pos.is<Attribute>() && "Unexpected non-constant index");
      int64_t offset = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
      int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(),
          rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
      if (ofr.is<Value>()) {
        newIndices[idx] = ofr.get<Value>();
      } else {
        newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
            extractOp.getLoc(), *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};

/// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
/// to memref.store.
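///
/// Illustrative sketch (hypothetical values):
///
///   vector.transfer_write %v, %mem[%i, %j] {in_bounds = [true, true]}
///       : vector<1x1xf32>, memref<?x?xf32>
///
/// becomes roughly:
///
///   %s = vector.extract %v[0, 0] : f32 from vector<1x1xf32>
///   memref.store %s, %mem[%i, %j] : memref<?x?xf32>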
class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    // Must be a scalar write.
    auto vecType = xferOp.getVectorType();
    if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Only float and integer element types are supported.
    Value scalar;
    if (vecType.getRank() == 0) {
      // vector.extract does not support vector<f32> etc., so use
      // vector.extractelement instead.
      scalar = rewriter.create<vector::ExtractElementOp>(xferOp.getLoc(),
                                                         xferOp.getVector());
    } else {
      SmallVector<int64_t> pos(vecType.getRank(), 0);
      scalar = rewriter.create<vector::ExtractOp>(xferOp.getLoc(),
                                                  xferOp.getVector(), pos);
    }
    // Construct a scalar store.
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::StoreOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    } else {
      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    }
    return success();
  }
};

} // namespace

void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
                                     Operation *rootOp) {
  TransferOptimization opt(rewriter, rootOp);
  // Run store-to-load forwarding first since it can expose more dead store
  // opportunities.
  rootOp->walk([&](vector::TransferReadOp read) {
    if (isa<MemRefType>(read.getShapedType()))
      opt.storeToLoadForwarding(read);
  });
  opt.removeDeadOp();
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (isa<MemRefType>(write.getShapedType()))
      opt.deadStoreOp(write);
  });
  opt.removeDeadOp();
}

void mlir::vector::populateScalarVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit,
    bool allowMultipleUses) {
  patterns.add<RewriteScalarExtractElementOfTransferRead,
               RewriteScalarExtractOfTransferRead>(patterns.getContext(),
                                                   benefit, allowMultipleUses);
  patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
}

void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
          patterns.getContext(), benefit);
  populateShapeCastFoldingPatterns(patterns);
}

void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns, unsigned targetVectorBitwidth,
    PatternBenefit benefit) {
  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
      patterns.getContext(), targetVectorBitwidth, benefit);
  populateShapeCastFoldingPatterns(patterns, benefit);
  populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
}