xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (revision 847048f497bcdfcfe52f36cba49f07bdbd63cd24)
1 //===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements functions concerned with optimizing transfer_read and
10 // transfer_write ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/MemRef/IR/MemRef.h"
17 #include "mlir/Dialect/Tensor/IR/Tensor.h"
18 #include "mlir/Dialect/Utils/IndexingUtils.h"
19 #include "mlir/Dialect/Vector/IR/VectorOps.h"
20 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
21 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
22 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
23 #include "mlir/IR/Dominance.h"
24 #include "mlir/Interfaces/SideEffectInterfaces.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/StringRef.h"
27 #include "llvm/Support/Debug.h"
28 
29 #define DEBUG_TYPE "vector-transfer-opt"
30 
31 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
32 
33 using namespace mlir;
34 
35 /// Return the ancestor op in the region or nullptr if the region is not
36 /// an ancestor of the op.
37 static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
38   for (; op != nullptr && op->getParentRegion() != region;
39        op = op->getParentOp())
40     ;
41   return op;
42 }
43 
44 namespace {
45 
46 class TransferOptimization {
47 public:
48   TransferOptimization(RewriterBase &rewriter, Operation *op)
49       : rewriter(rewriter), dominators(op), postDominators(op) {}
50   void deadStoreOp(vector::TransferWriteOp);
51   void storeToLoadForwarding(vector::TransferReadOp);
52   void removeDeadOp() {
53     for (Operation *op : opToErase)
54       rewriter.eraseOp(op);
55     opToErase.clear();
56   }
57 
58 private:
59   RewriterBase &rewriter;
60   bool isReachable(Operation *start, Operation *dest);
61   DominanceInfo dominators;
62   PostDominanceInfo postDominators;
63   std::vector<Operation *> opToErase;
64 };
65 
66 } // namespace
67 /// Return true if there is a path from start operation to dest operation,
68 /// otherwise return false. The operations have to be in the same region.
69 bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
70   assert(start->getParentRegion() == dest->getParentRegion() &&
71          "This function only works for ops i the same region");
72   // Simple case where the start op dominate the destination.
73   if (dominators.dominates(start, dest))
74     return true;
75   Block *startBlock = start->getBlock();
76   Block *destBlock = dest->getBlock();
77   SmallVector<Block *, 32> worklist(startBlock->succ_begin(),
78                                     startBlock->succ_end());
79   SmallPtrSet<Block *, 32> visited;
80   while (!worklist.empty()) {
81     Block *bb = worklist.pop_back_val();
82     if (!visited.insert(bb).second)
83       continue;
84     if (dominators.dominates(bb, destBlock))
85       return true;
86     worklist.append(bb->succ_begin(), bb->succ_end());
87   }
88   return false;
89 }
90 
91 /// For transfer_write to overwrite fully another transfer_write must:
92 /// 1. Access the same memref with the same indices and vector type.
93 /// 2. Post-dominate the other transfer_write operation.
94 /// If several candidates are available, one must be post-dominated by all the
95 /// others since they are all post-dominating the same transfer_write. We only
96 /// consider the transfer_write post-dominated by all the other candidates as
97 /// this will be the first transfer_write executed after the potentially dead
98 /// transfer_write.
99 /// If we found such an overwriting transfer_write we know that the original
100 /// transfer_write is dead if all reads that can be reached from the potentially
101 /// dead transfer_write are dominated by the overwriting transfer_write.
102 void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
103   LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
104                     << "\n");
105   llvm::SmallVector<Operation *, 8> blockingAccesses;
106   Operation *firstOverwriteCandidate = nullptr;
107   Value source = write.getSource();
108   // Skip subview ops.
109   while (auto subView = source.getDefiningOp<memref::SubViewOp>())
110     source = subView.getSource();
111   llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
112                                            source.getUsers().end());
113   llvm::SmallDenseSet<Operation *, 32> processed;
114   while (!users.empty()) {
115     Operation *user = users.pop_back_val();
116     // If the user has already been processed skip.
117     if (!processed.insert(user).second)
118       continue;
119     if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
120       users.append(subView->getUsers().begin(), subView->getUsers().end());
121       continue;
122     }
123     if (isMemoryEffectFree(user))
124       continue;
125     if (user == write.getOperation())
126       continue;
127     if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
128       // Check candidate that can override the store.
129       if (write.getSource() == nextWrite.getSource() &&
130           checkSameValueWAW(nextWrite, write) &&
131           postDominators.postDominates(nextWrite, write)) {
132         if (firstOverwriteCandidate == nullptr ||
133             postDominators.postDominates(firstOverwriteCandidate, nextWrite))
134           firstOverwriteCandidate = nextWrite;
135         else
136           assert(
137               postDominators.postDominates(nextWrite, firstOverwriteCandidate));
138         continue;
139       }
140     }
141     if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
142       // Don't need to consider disjoint accesses.
143       if (vector::isDisjointTransferSet(
144               cast<VectorTransferOpInterface>(write.getOperation()),
145               cast<VectorTransferOpInterface>(transferOp.getOperation()),
146               /*testDynamicValueUsingBounds=*/true))
147         continue;
148     }
149     blockingAccesses.push_back(user);
150   }
151   if (firstOverwriteCandidate == nullptr)
152     return;
153   Region *topRegion = firstOverwriteCandidate->getParentRegion();
154   Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
155   assert(writeAncestor &&
156          "write op should be recursively part of the top region");
157 
158   for (Operation *access : blockingAccesses) {
159     Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
160     // TODO: if the access and write have the same ancestor we could recurse in
161     // the region to know if the access is reachable with more precision.
162     if (accessAncestor == nullptr ||
163         !isReachable(writeAncestor, accessAncestor))
164       continue;
165     if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
166       LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
167                         << *accessAncestor << "\n");
168       return;
169     }
170   }
171   LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
172                     << " overwritten by: " << *firstOverwriteCandidate << "\n");
173   opToErase.push_back(write.getOperation());
174 }
175 
176 /// A transfer_write candidate to storeToLoad forwarding must:
177 /// 1. Access the same memref with the same indices and vector type as the
178 /// transfer_read.
179 /// 2. Dominate the transfer_read operation.
180 /// If several candidates are available, one must be dominated by all the others
181 /// since they are all dominating the same transfer_read. We only consider the
182 /// transfer_write dominated by all the other candidates as this will be the
183 /// last transfer_write executed before the transfer_read.
184 /// If we found such a candidate we can do the forwarding if all the other
185 /// potentially aliasing ops that may reach the transfer_read are post-dominated
186 /// by the transfer_write.
187 void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
188   if (read.hasOutOfBoundsDim())
189     return;
190   LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
191                     << "\n");
192   SmallVector<Operation *, 8> blockingWrites;
193   vector::TransferWriteOp lastwrite = nullptr;
194   Value source = read.getSource();
195   // Skip subview ops.
196   while (auto subView = source.getDefiningOp<memref::SubViewOp>())
197     source = subView.getSource();
198   llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
199                                            source.getUsers().end());
200   llvm::SmallDenseSet<Operation *, 32> processed;
201   while (!users.empty()) {
202     Operation *user = users.pop_back_val();
203     // If the user has already been processed skip.
204     if (!processed.insert(user).second)
205       continue;
206     if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
207       users.append(subView->getUsers().begin(), subView->getUsers().end());
208       continue;
209     }
210     if (auto collapsed = dyn_cast<memref::CollapseShapeOp>(user)) {
211       users.append(collapsed->getUsers().begin(), collapsed->getUsers().end());
212       continue;
213     }
214     if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
215       continue;
216     if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
217       // If there is a write, but we can prove that it is disjoint we can ignore
218       // the write.
219       if (vector::isDisjointTransferSet(
220               cast<VectorTransferOpInterface>(write.getOperation()),
221               cast<VectorTransferOpInterface>(read.getOperation()),
222               /*testDynamicValueUsingBounds=*/true))
223         continue;
224       if (write.getSource() == read.getSource() &&
225           dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
226         if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
227           lastwrite = write;
228         else
229           assert(dominators.dominates(write, lastwrite));
230         continue;
231       }
232     }
233     blockingWrites.push_back(user);
234   }
235 
236   if (lastwrite == nullptr)
237     return;
238 
239   Region *topRegion = lastwrite->getParentRegion();
240   Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
241   assert(readAncestor &&
242          "read op should be recursively part of the top region");
243 
244   for (Operation *write : blockingWrites) {
245     Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
246     // TODO: if the store and read have the same ancestor we could recurse in
247     // the region to know if the read is reachable with more precision.
248     if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
249       continue;
250     if (!postDominators.postDominates(lastwrite, write)) {
251       LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
252                         << *write << "\n");
253       return;
254     }
255   }
256 
257   LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
258                     << " to: " << *read.getOperation() << "\n");
259   read.replaceAllUsesWith(lastwrite.getVector());
260   opToErase.push_back(read.getOperation());
261 }
262 
263 /// Converts OpFoldResults to int64_t shape without unit dims.
264 static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
265   SmallVector<int64_t> reducedShape;
266   for (const auto size : mixedSizes) {
267     if (llvm::dyn_cast_if_present<Value>(size)) {
268       reducedShape.push_back(ShapedType::kDynamic);
269       continue;
270     }
271 
272     auto value = cast<IntegerAttr>(size.get<Attribute>()).getValue();
273     if (value == 1)
274       continue;
275     reducedShape.push_back(value.getSExtValue());
276   }
277   return reducedShape;
278 }
279 
280 /// Drops unit dimensions from the input MemRefType.
281 static MemRefType dropUnitDims(MemRefType inputType,
282                                ArrayRef<OpFoldResult> offsets,
283                                ArrayRef<OpFoldResult> sizes,
284                                ArrayRef<OpFoldResult> strides) {
285   auto targetShape = getReducedShape(sizes);
286   Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
287       targetShape, inputType, offsets, sizes, strides);
288   return canonicalizeStridedLayout(cast<MemRefType>(rankReducedType));
289 }
290 
291 /// Creates a rank-reducing memref.subview op that drops unit dims from its
292 /// input. Or just returns the input if it was already without unit dims.
293 static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
294                                                  mlir::Location loc,
295                                                  Value input) {
296   MemRefType inputType = cast<MemRefType>(input.getType());
297   SmallVector<OpFoldResult> offsets(inputType.getRank(),
298                                     rewriter.getIndexAttr(0));
299   SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
300   SmallVector<OpFoldResult> strides(inputType.getRank(),
301                                     rewriter.getIndexAttr(1));
302   MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);
303 
304   if (canonicalizeStridedLayout(resultType) ==
305       canonicalizeStridedLayout(inputType))
306     return input;
307   return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets,
308                                             sizes, strides);
309 }
310 
311 /// Returns the number of dims that aren't unit dims.
312 static int getReducedRank(ArrayRef<int64_t> shape) {
313   return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
314 }
315 
316 /// Trims non-scalable one dimensions from `oldType` and returns the result
317 /// type.
318 static VectorType trimNonScalableUnitDims(VectorType oldType) {
319   SmallVector<int64_t> newShape;
320   SmallVector<bool> newScalableDims;
321   for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
322     if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
323       continue;
324     newShape.push_back(dimSize);
325     newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
326   }
327   return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
328 }
329 
330 // Rewrites vector.create_mask 'op' to drop non-scalable one dimensions.
331 static FailureOr<Value>
332 createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
333                                   vector::CreateMaskOp op) {
334   auto type = op.getType();
335   VectorType reducedType = trimNonScalableUnitDims(type);
336   if (reducedType.getRank() == type.getRank())
337     return failure();
338 
339   SmallVector<Value> reducedOperands;
340   for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
341            type.getShape(), type.getScalableDims(), op.getOperands())) {
342     if (dim == 1 && !dimIsScalable) {
343       // If the mask for the unit dim is not a constant of 1, do nothing.
344       auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
345       if (!constant || (constant.value() != 1))
346         return failure();
347       continue;
348     }
349     reducedOperands.push_back(operand);
350   }
351   return rewriter
352       .create<vector::CreateMaskOp>(loc, reducedType, reducedOperands)
353       .getResult();
354 }
355 
356 namespace {
357 
358 /// Rewrites `vector.transfer_read` ops where the source has unit dims, by
359 /// inserting a memref.subview dropping those unit dims. The vector shapes are
360 /// also reduced accordingly.
361 class TransferReadDropUnitDimsPattern
362     : public OpRewritePattern<vector::TransferReadOp> {
363   using OpRewritePattern::OpRewritePattern;
364 
365   LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
366                                 PatternRewriter &rewriter) const override {
367     auto loc = transferReadOp.getLoc();
368     Value vector = transferReadOp.getVector();
369     VectorType vectorType = cast<VectorType>(vector.getType());
370     Value source = transferReadOp.getSource();
371     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
372     // TODO: support tensor types.
373     if (!sourceType)
374       return failure();
375     // TODO: generalize this pattern, relax the requirements here.
376     if (transferReadOp.hasOutOfBoundsDim())
377       return failure();
378     if (!transferReadOp.getPermutationMap().isMinorIdentity())
379       return failure();
380     // Check if the source shape can be further reduced.
381     int reducedRank = getReducedRank(sourceType.getShape());
382     if (reducedRank == sourceType.getRank())
383       return failure();
384     // Check if the reduced vector shape matches the reduced source shape.
385     // Otherwise, this case is not supported yet.
386     VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
387     if (reducedRank != reducedVectorType.getRank())
388       return failure();
389     if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
390           return getConstantIntValue(v) != static_cast<int64_t>(0);
391         }))
392       return failure();
393 
394     Value maskOp = transferReadOp.getMask();
395     if (maskOp) {
396       auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
397       if (!createMaskOp)
398         return rewriter.notifyMatchFailure(
399             transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
400                             "currently supported");
401       FailureOr<Value> rankReducedCreateMask =
402           createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
403       if (failed(rankReducedCreateMask))
404         return failure();
405       maskOp = *rankReducedCreateMask;
406     }
407 
408     Value reducedShapeSource =
409         rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
410     Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
411     SmallVector<Value> zeros(reducedRank, c0);
412     auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
413     SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
414     auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
415         loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
416         transferReadOp.getPadding(), maskOp,
417         rewriter.getBoolArrayAttr(inBounds));
418     auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
419         loc, vectorType, newTransferReadOp);
420     rewriter.replaceOp(transferReadOp, shapeCast);
421 
422     return success();
423   }
424 };
425 
426 /// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
427 /// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
428 /// vector shapes are also reduced accordingly.
429 class TransferWriteDropUnitDimsPattern
430     : public OpRewritePattern<vector::TransferWriteOp> {
431   using OpRewritePattern::OpRewritePattern;
432 
433   LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
434                                 PatternRewriter &rewriter) const override {
435     auto loc = transferWriteOp.getLoc();
436     Value vector = transferWriteOp.getVector();
437     VectorType vectorType = cast<VectorType>(vector.getType());
438     Value source = transferWriteOp.getSource();
439     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
440     // TODO: support tensor type.
441     if (!sourceType)
442       return failure();
443     // TODO: generalize this pattern, relax the requirements here.
444     if (transferWriteOp.hasOutOfBoundsDim())
445       return failure();
446     if (!transferWriteOp.getPermutationMap().isMinorIdentity())
447       return failure();
448     // Check if the destination shape can be further reduced.
449     int reducedRank = getReducedRank(sourceType.getShape());
450     if (reducedRank == sourceType.getRank())
451       return failure();
452     // Check if the reduced vector shape matches the reduced destination shape.
453     // Otherwise, this case is not supported yet.
454     VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
455     if (reducedRank != reducedVectorType.getRank())
456       return failure();
457     if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
458           return getConstantIntValue(v) != static_cast<int64_t>(0);
459         }))
460       return failure();
461 
462     Value maskOp = transferWriteOp.getMask();
463     if (maskOp) {
464       auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
465       if (!createMaskOp)
466         return rewriter.notifyMatchFailure(
467             transferWriteOp,
468             "unsupported mask op, only 'vector.create_mask' is "
469             "currently supported");
470       FailureOr<Value> rankReducedCreateMask =
471           createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
472       if (failed(rankReducedCreateMask))
473         return failure();
474       maskOp = *rankReducedCreateMask;
475     }
476     Value reducedShapeSource =
477         rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
478     Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
479     SmallVector<Value> zeros(reducedRank, c0);
480     auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
481     SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
482     auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
483         loc, reducedVectorType, vector);
484     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
485         transferWriteOp, Type(), shapeCast, reducedShapeSource, zeros,
486         identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));
487 
488     return success();
489   }
490 };
491 
492 } // namespace
493 
494 /// Creates a memref.collapse_shape collapsing all inner dimensions of the
495 /// input starting at `firstDimToCollapse`.
496 static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
497                                Value input, int64_t firstDimToCollapse) {
498   ShapedType inputType = cast<ShapedType>(input.getType());
499   if (inputType.getRank() == 1)
500     return input;
501   SmallVector<ReassociationIndices> reassociation;
502   for (int64_t i = 0; i < firstDimToCollapse; ++i)
503     reassociation.push_back(ReassociationIndices{i});
504   ReassociationIndices collapsedIndices;
505   for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
506     collapsedIndices.push_back(i);
507   reassociation.push_back(collapsedIndices);
508   return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
509 }
510 
511 /// Checks that the indices corresponding to dimensions starting at
512 /// `firstDimToCollapse` are constant 0, and writes to `outIndices`
513 /// the truncated indices where `firstDimToCollapse` is now the innermost dim.
514 /// TODO: Extract the logic that writes to outIndices so that this method
515 /// simply checks one pre-condition.
516 static LogicalResult
517 checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
518                                  SmallVector<Value> &outIndices) {
519   int64_t rank = indices.size();
520   if (firstDimToCollapse >= rank)
521     return failure();
522   for (int64_t i = firstDimToCollapse; i < rank; ++i) {
523     std::optional<int64_t> cst = getConstantIntValue(indices[i]);
524     if (!cst || cst.value() != 0)
525       return failure();
526   }
527   outIndices = indices;
528   outIndices.resize(firstDimToCollapse + 1);
529   return success();
530 }
531 
532 namespace {
533 
534 /// Rewrites contiguous row-major vector.transfer_read ops by inserting
535 /// memref.collapse_shape on the source so that the resulting
536 /// vector.transfer_read has a 1D source. Requires the source shape to be
537 /// already reduced i.e. without unit dims.
538 /// If `targetVectorBitwidth` is provided, the flattening will only happen if
539 /// the trailing dimension of the vector read is smaller than the provided
540 /// bitwidth.
541 class FlattenContiguousRowMajorTransferReadPattern
542     : public OpRewritePattern<vector::TransferReadOp> {
543 public:
544   FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
545                                                unsigned vectorBitwidth,
546                                                PatternBenefit benefit)
547       : OpRewritePattern<vector::TransferReadOp>(context, benefit),
548         targetVectorBitwidth(vectorBitwidth) {}
549 
550   LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
551                                 PatternRewriter &rewriter) const override {
552     auto loc = transferReadOp.getLoc();
553     Value vector = transferReadOp.getVector();
554     VectorType vectorType = cast<VectorType>(vector.getType());
555     auto source = transferReadOp.getSource();
556     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
557 
558     // 0. Check pre-conditions
559     // Contiguity check is valid on tensors only.
560     if (!sourceType)
561       return failure();
562     // If this is already 0D/1D, there's nothing to do.
563     if (vectorType.getRank() <= 1)
564       return failure();
565     if (!vectorType.getElementType().isSignlessIntOrFloat())
566       return failure();
567     unsigned trailingVectorDimBitwidth =
568         vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
569     if (trailingVectorDimBitwidth >= targetVectorBitwidth)
570       return failure();
571     if (!vector::isContiguousSlice(sourceType, vectorType))
572       return failure();
573     // TODO: generalize this pattern, relax the requirements here.
574     if (transferReadOp.hasOutOfBoundsDim())
575       return failure();
576     if (!transferReadOp.getPermutationMap().isMinorIdentity())
577       return failure();
578     if (transferReadOp.getMask())
579       return failure();
580 
581     int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
582 
583     // 1. Collapse the source memref
584     Value collapsedSource =
585         collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
586     MemRefType collapsedSourceType =
587         dyn_cast<MemRefType>(collapsedSource.getType());
588     int64_t collapsedRank = collapsedSourceType.getRank();
589     assert(collapsedRank == firstDimToCollapse + 1);
590 
591     // 2. Generate input args for a new vector.transfer_read that will read
592     // from the collapsed memref.
593     // 2.1. New dim exprs + affine map
594     SmallVector<AffineExpr, 1> dimExprs{
595         getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
596     auto collapsedMap =
597         AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
598 
599     // 2.2 New indices
600     // If all the collapsed indices are zero then no extra logic is needed.
601     // Otherwise, a new offset/index has to be computed.
602     SmallVector<Value> collapsedIndices;
603     if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
604                                                 firstDimToCollapse,
605                                                 collapsedIndices))) {
606       // Copy all the leading indices.
607       SmallVector<Value> indices = transferReadOp.getIndices();
608       collapsedIndices.append(indices.begin(),
609                               indices.begin() + firstDimToCollapse);
610 
611       // Compute the remaining trailing index/offset required for reading from
612       // the collapsed memref:
613       //
614       //    offset = 0
615       //    for (i = firstDimToCollapse; i < outputRank; ++i)
616       //      offset += sourceType.getDimSize(i) * transferReadOp.indices[i]
617       //
618       // For this example:
619       //   %2 = vector.transfer_read %arg4[%c0, %arg0, %c0] (...) :
620       //      memref<1x43x2xi32>, vector<1x2xi32>
621       // which would be collapsed to:
622       //   %1 = vector.transfer_read %collapse_shape[%c0, %offset] (...) :
623       //      memref<1x86xi32>, vector<2xi32>
624       // one would get the following offset:
625       //    %offset = %arg0 * 43
626       OpFoldResult collapsedOffset =
627           rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
628 
629       auto sourceShape = sourceType.getShape();
630       auto collapsedStrides = computeSuffixProduct(ArrayRef<int64_t>(
631           sourceShape.begin() + firstDimToCollapse, sourceShape.end()));
632 
633       // Compute the collapsed offset.
634       ArrayRef<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
635                                         indices.end());
636       auto &&[collapsedExpr, collapsedVals] = computeLinearIndex(
637           collapsedOffset, collapsedStrides, indicesToCollapse);
638       collapsedOffset = affine::makeComposedFoldedAffineApply(
639           rewriter, loc, collapsedExpr, collapsedVals);
640 
641       if (collapsedOffset.is<Value>()) {
642         collapsedIndices.push_back(collapsedOffset.get<Value>());
643       } else {
644         collapsedIndices.push_back(rewriter.create<arith::ConstantIndexOp>(
645             loc, *getConstantIntValue(collapsedOffset)));
646       }
647     }
648 
649     // 3. Create new vector.transfer_read that reads from the collapsed memref
650     VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
651                                                 vectorType.getElementType());
652     vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
653         loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
654     flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
655 
656     // 4. Replace the old transfer_read with the new one reading from the
657     // collapsed shape
658     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
659         transferReadOp, cast<VectorType>(vector.getType()), flatRead);
660     return success();
661   }
662 
663 private:
664   // Minimum bitwidth that the trailing vector dimension should have after
665   // flattening.
666   unsigned targetVectorBitwidth;
667 };
668 
669 /// Rewrites contiguous row-major vector.transfer_write ops by inserting
670 /// memref.collapse_shape on the source so that the resulting
671 /// vector.transfer_write has a 1D source. Requires the source shape to be
672 /// already reduced i.e. without unit dims.
673 class FlattenContiguousRowMajorTransferWritePattern
674     : public OpRewritePattern<vector::TransferWriteOp> {
675 public:
676   FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
677                                                 unsigned vectorBitwidth,
678                                                 PatternBenefit benefit)
679       : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
680         targetVectorBitwidth(vectorBitwidth) {}
681 
682   LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
683                                 PatternRewriter &rewriter) const override {
684     auto loc = transferWriteOp.getLoc();
685     Value vector = transferWriteOp.getVector();
686     VectorType vectorType = cast<VectorType>(vector.getType());
687     Value source = transferWriteOp.getSource();
688     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
689     // Contiguity check is valid on tensors only.
690     if (!sourceType)
691       return failure();
692     if (vectorType.getRank() <= 1)
693       // Already 0D/1D, nothing to do.
694       return failure();
695     if (!vectorType.getElementType().isSignlessIntOrFloat())
696       return failure();
697     unsigned trailingVectorDimBitwidth =
698         vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
699     if (trailingVectorDimBitwidth >= targetVectorBitwidth)
700       return failure();
701     if (!vector::isContiguousSlice(sourceType, vectorType))
702       return failure();
703     int64_t firstContiguousInnerDim =
704         sourceType.getRank() - vectorType.getRank();
705     // TODO: generalize this pattern, relax the requirements here.
706     if (transferWriteOp.hasOutOfBoundsDim())
707       return failure();
708     if (!transferWriteOp.getPermutationMap().isMinorIdentity())
709       return failure();
710     if (transferWriteOp.getMask())
711       return failure();
712     SmallVector<Value> collapsedIndices;
713     if (failed(checkAndCollapseInnerZeroIndices(transferWriteOp.getIndices(),
714                                                 firstContiguousInnerDim,
715                                                 collapsedIndices)))
716       return failure();
717 
718     Value collapsedSource =
719         collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
720     MemRefType collapsedSourceType =
721         cast<MemRefType>(collapsedSource.getType());
722     int64_t collapsedRank = collapsedSourceType.getRank();
723     assert(collapsedRank == firstContiguousInnerDim + 1);
724     SmallVector<AffineExpr, 1> dimExprs{
725         getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
726     auto collapsedMap =
727         AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
728     VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
729                                                 vectorType.getElementType());
730     Value flatVector =
731         rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
732     vector::TransferWriteOp flatWrite =
733         rewriter.create<vector::TransferWriteOp>(
734             loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
735     flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
736     rewriter.eraseOp(transferWriteOp);
737     return success();
738   }
739 
740 private:
741   // Minimum bitwidth that the trailing vector dimension should have after
742   // flattening.
743   unsigned targetVectorBitwidth;
744 };
745 
746 /// Base class for `vector.extract/vector.extract_element(vector.transfer_read)`
747 /// to `memref.load` patterns. The `match` method is shared for both
748 /// `vector.extract` and `vector.extract_element`.
749 template <class VectorExtractOp>
750 class RewriteScalarExtractOfTransferReadBase
751     : public OpRewritePattern<VectorExtractOp> {
752   using Base = OpRewritePattern<VectorExtractOp>;
753 
754 public:
755   RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
756                                          PatternBenefit benefit,
757                                          bool allowMultipleUses)
758       : Base::OpRewritePattern(context, benefit),
759         allowMultipleUses(allowMultipleUses) {}
760 
761   LogicalResult match(VectorExtractOp extractOp) const override {
762     auto xferOp =
763         extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
764     if (!xferOp)
765       return failure();
766     // Check that we are extracting a scalar and not a sub-vector.
767     if (isa<VectorType>(extractOp.getResult().getType()))
768       return failure();
769     // If multiple uses are not allowed, check if xfer has a single use.
770     if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
771       return failure();
772     // If multiple uses are allowed, check if all the xfer uses are extract ops.
773     if (allowMultipleUses &&
774         !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
775           return isa<vector::ExtractOp, vector::ExtractElementOp>(
776               use.getOwner());
777         }))
778       return failure();
779     // Mask not supported.
780     if (xferOp.getMask())
781       return failure();
782     // Map not supported.
783     if (!xferOp.getPermutationMap().isMinorIdentity())
784       return failure();
785     // Cannot rewrite if the indices may be out of bounds.
786     if (xferOp.hasOutOfBoundsDim())
787       return failure();
788     return success();
789   }
790 
791 private:
792   bool allowMultipleUses;
793 };
794 
795 /// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
796 ///
797 /// All the users of the transfer op must be either `vector.extractelement` or
798 /// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
799 /// transfer ops with any number of users. Otherwise, rewrite only if the
800 /// extract op is the single user of the transfer op. Rewriting a single
801 /// vector load with multiple scalar loads may negatively affect performance.
802 class RewriteScalarExtractElementOfTransferRead
803     : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
804   using RewriteScalarExtractOfTransferReadBase::
805       RewriteScalarExtractOfTransferReadBase;
806 
807   void rewrite(vector::ExtractElementOp extractOp,
808                PatternRewriter &rewriter) const override {
809     // Construct scalar load.
810     auto loc = extractOp.getLoc();
811     auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
812     SmallVector<Value> newIndices(xferOp.getIndices().begin(),
813                                   xferOp.getIndices().end());
814     if (extractOp.getPosition()) {
815       AffineExpr sym0, sym1;
816       bindSymbols(extractOp.getContext(), sym0, sym1);
817       OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
818           rewriter, loc, sym0 + sym1,
819           {newIndices[newIndices.size() - 1], extractOp.getPosition()});
820       if (ofr.is<Value>()) {
821         newIndices[newIndices.size() - 1] = ofr.get<Value>();
822       } else {
823         newIndices[newIndices.size() - 1] =
824             rewriter.create<arith::ConstantIndexOp>(loc,
825                                                     *getConstantIntValue(ofr));
826       }
827     }
828     if (isa<MemRefType>(xferOp.getSource().getType())) {
829       rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
830                                                   newIndices);
831     } else {
832       rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
833           extractOp, xferOp.getSource(), newIndices);
834     }
835   }
836 };
837 
838 /// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
839 /// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
840 ///
841 /// All the users of the transfer op must be either `vector.extractelement` or
842 /// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
843 /// transfer ops with any number of users. Otherwise, rewrite only if the
844 /// extract op is the single user of the transfer op. Rewriting a single
845 /// vector load with multiple scalar loads may negatively affect performance.
846 class RewriteScalarExtractOfTransferRead
847     : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
848   using RewriteScalarExtractOfTransferReadBase::
849       RewriteScalarExtractOfTransferReadBase;
850 
851   void rewrite(vector::ExtractOp extractOp,
852                PatternRewriter &rewriter) const override {
853     // Construct scalar load.
854     auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
855     SmallVector<Value> newIndices(xferOp.getIndices().begin(),
856                                   xferOp.getIndices().end());
857     for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
858       assert(pos.is<Attribute>() && "Unexpected non-constant index");
859       int64_t offset = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
860       int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
861       OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
862           rewriter, extractOp.getLoc(),
863           rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
864       if (ofr.is<Value>()) {
865         newIndices[idx] = ofr.get<Value>();
866       } else {
867         newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
868             extractOp.getLoc(), *getConstantIntValue(ofr));
869       }
870     }
871     if (isa<MemRefType>(xferOp.getSource().getType())) {
872       rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
873                                                   newIndices);
874     } else {
875       rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
876           extractOp, xferOp.getSource(), newIndices);
877     }
878   }
879 };
880 
881 /// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
882 /// to memref.store.
883 class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
884   using OpRewritePattern::OpRewritePattern;
885 
886   LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
887                                 PatternRewriter &rewriter) const override {
888     // Must be a scalar write.
889     auto vecType = xferOp.getVectorType();
890     if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
891       return failure();
892     // Mask not supported.
893     if (xferOp.getMask())
894       return failure();
895     // Map not supported.
896     if (!xferOp.getPermutationMap().isMinorIdentity())
897       return failure();
898     // Only float and integer element types are supported.
899     Value scalar;
900     if (vecType.getRank() == 0) {
901       // vector.extract does not support vector<f32> etc., so use
902       // vector.extractelement instead.
903       scalar = rewriter.create<vector::ExtractElementOp>(xferOp.getLoc(),
904                                                          xferOp.getVector());
905     } else {
906       SmallVector<int64_t> pos(vecType.getRank(), 0);
907       scalar = rewriter.create<vector::ExtractOp>(xferOp.getLoc(),
908                                                   xferOp.getVector(), pos);
909     }
910     // Construct a scalar store.
911     if (isa<MemRefType>(xferOp.getSource().getType())) {
912       rewriter.replaceOpWithNewOp<memref::StoreOp>(
913           xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
914     } else {
915       rewriter.replaceOpWithNewOp<tensor::InsertOp>(
916           xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
917     }
918     return success();
919   }
920 };
921 
922 } // namespace
923 
924 void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
925                                      Operation *rootOp) {
926   TransferOptimization opt(rewriter, rootOp);
927   // Run store to load forwarding first since it can expose more dead store
928   // opportunity.
929   rootOp->walk([&](vector::TransferReadOp read) {
930     if (isa<MemRefType>(read.getShapedType()))
931       opt.storeToLoadForwarding(read);
932   });
933   opt.removeDeadOp();
934   rootOp->walk([&](vector::TransferWriteOp write) {
935     if (isa<MemRefType>(write.getShapedType()))
936       opt.deadStoreOp(write);
937   });
938   opt.removeDeadOp();
939 }
940 
941 void mlir::vector::populateScalarVectorTransferLoweringPatterns(
942     RewritePatternSet &patterns, PatternBenefit benefit,
943     bool allowMultipleUses) {
944   patterns.add<RewriteScalarExtractElementOfTransferRead,
945                RewriteScalarExtractOfTransferRead>(patterns.getContext(),
946                                                    benefit, allowMultipleUses);
947   patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
948 }
949 
950 void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
951     RewritePatternSet &patterns, PatternBenefit benefit) {
952   patterns
953       .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
954           patterns.getContext(), benefit);
955   populateShapeCastFoldingPatterns(patterns);
956 }
957 
958 void mlir::vector::populateFlattenVectorTransferPatterns(
959     RewritePatternSet &patterns, unsigned targetVectorBitwidth,
960     PatternBenefit benefit) {
961   patterns.add<FlattenContiguousRowMajorTransferReadPattern,
962                FlattenContiguousRowMajorTransferWritePattern>(
963       patterns.getContext(), targetVectorBitwidth, benefit);
964   populateShapeCastFoldingPatterns(patterns, benefit);
965   populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
966 }
967