xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (revision 16b75cd2bb439633d29c99a7663f2586e4068ecf)
1 //===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements functions concerned with optimizing transfer_read and
10 // transfer_write ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/MemRef/IR/MemRef.h"
17 #include "mlir/Dialect/Tensor/IR/Tensor.h"
18 #include "mlir/Dialect/Vector/IR/VectorOps.h"
19 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
20 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
21 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/Dominance.h"
24 #include "mlir/Interfaces/SideEffectInterfaces.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/StringRef.h"
27 #include "llvm/Support/Debug.h"
28 
29 #define DEBUG_TYPE "vector-transfer-opt"
30 
31 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
32 
33 using namespace mlir;
34 
35 /// Return the ancestor op in the region or nullptr if the region is not
36 /// an ancestor of the op.
37 static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
38   for (; op != nullptr && op->getParentRegion() != region;
39        op = op->getParentOp())
40     ;
41   return op;
42 }
43 
44 namespace {
45 
46 class TransferOptimization {
47 public:
48   TransferOptimization(RewriterBase &rewriter, Operation *op)
49       : rewriter(rewriter), dominators(op), postDominators(op) {}
50   void deadStoreOp(vector::TransferWriteOp);
51   void storeToLoadForwarding(vector::TransferReadOp);
52   void removeDeadOp() {
53     for (Operation *op : opToErase)
54       rewriter.eraseOp(op);
55     opToErase.clear();
56   }
57 
58 private:
59   RewriterBase &rewriter;
60   bool isReachable(Operation *start, Operation *dest);
61   DominanceInfo dominators;
62   PostDominanceInfo postDominators;
63   std::vector<Operation *> opToErase;
64 };
65 
66 } // namespace
67 /// Return true if there is a path from start operation to dest operation,
68 /// otherwise return false. The operations have to be in the same region.
69 bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
70   assert(start->getParentRegion() == dest->getParentRegion() &&
71          "This function only works for ops in the same region");
72   // Simple case where the start op dominates the destination.
73   if (dominators.dominates(start, dest))
74     return true;
75   Block *startBlock = start->getBlock();
76   Block *destBlock = dest->getBlock();
77   SmallVector<Block *, 32> worklist(startBlock->succ_begin(),
78                                     startBlock->succ_end());
79   SmallPtrSet<Block *, 32> visited;
80   while (!worklist.empty()) {
81     Block *bb = worklist.pop_back_val();
82     if (!visited.insert(bb).second)
83       continue;
84     if (dominators.dominates(bb, destBlock))
85       return true;
86     worklist.append(bb->succ_begin(), bb->succ_end());
87   }
88   return false;
89 }
90 
91 /// For a transfer_write to fully overwrite another transfer_write, it must:
92 /// 1. Access the same memref with the same indices and vector type.
93 /// 2. Post-dominate the other transfer_write operation.
94 /// If several candidates are available, one must be post-dominated by all the
95 /// others since they are all post-dominating the same transfer_write. We only
96 /// consider the transfer_write post-dominated by all the other candidates as
97 /// this will be the first transfer_write executed after the potentially dead
98 /// transfer_write.
99 /// If we find such an overwriting transfer_write, we know that the original
100 /// transfer_write is dead if all reads that can be reached from the potentially
101 /// dead transfer_write are dominated by the overwriting transfer_write.
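///
/// For illustration only, a minimal sketch of the IR this targets (hypothetical
/// values %buf, %v0, %v1, %c0 and shapes chosen for the example):
///
///   vector.transfer_write %v0, %buf[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x8xf32>, memref<4x8xf32>
///   // ... no read of %buf is reachable between the two writes ...
///   vector.transfer_write %v1, %buf[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x8xf32>, memref<4x8xf32>
///
/// The first transfer_write is dead: the second one post-dominates it, accesses
/// the same memref with the same indices and vector type, and no reachable read
/// observes the first value, so the first write can be erased.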
102 void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
103   LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
104                     << "\n");
105   llvm::SmallVector<Operation *, 8> blockingAccesses;
106   Operation *firstOverwriteCandidate = nullptr;
107   Value source = write.getSource();
108   // Skip subview ops.
109   while (auto subView = source.getDefiningOp<memref::SubViewOp>())
110     source = subView.getSource();
111   llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
112                                            source.getUsers().end());
113   llvm::SmallDenseSet<Operation *, 32> processed;
114   while (!users.empty()) {
115     Operation *user = users.pop_back_val();
116     // If the user has already been processed, skip it.
117     if (!processed.insert(user).second)
118       continue;
119     if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
120       users.append(subView->getUsers().begin(), subView->getUsers().end());
121       continue;
122     }
123     if (isMemoryEffectFree(user))
124       continue;
125     if (user == write.getOperation())
126       continue;
127     if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
128       // Check whether this candidate can overwrite the store.
129       if (write.getSource() == nextWrite.getSource() &&
130           checkSameValueWAW(nextWrite, write) &&
131           postDominators.postDominates(nextWrite, write)) {
132         if (firstOverwriteCandidate == nullptr ||
133             postDominators.postDominates(firstOverwriteCandidate, nextWrite))
134           firstOverwriteCandidate = nextWrite;
135         else
136           assert(
137               postDominators.postDominates(nextWrite, firstOverwriteCandidate));
138         continue;
139       }
140     }
141     if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
142       // Don't need to consider disjoint accesses.
143       if (vector::isDisjointTransferSet(
144               cast<VectorTransferOpInterface>(write.getOperation()),
145               cast<VectorTransferOpInterface>(transferOp.getOperation())))
146         continue;
147     }
148     blockingAccesses.push_back(user);
149   }
150   if (firstOverwriteCandidate == nullptr)
151     return;
152   Region *topRegion = firstOverwriteCandidate->getParentRegion();
153   Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
154   assert(writeAncestor &&
155          "write op should be recursively part of the top region");
156 
157   for (Operation *access : blockingAccesses) {
158     Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
159     // TODO: if the access and write have the same ancestor we could recurse in
160     // the region to know if the access is reachable with more precision.
161     if (accessAncestor == nullptr ||
162         !isReachable(writeAncestor, accessAncestor))
163       continue;
164     if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
165       LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
166                         << *accessAncestor << "\n");
167       return;
168     }
169   }
170   LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
171                     << " overwritten by: " << *firstOverwriteCandidate << "\n");
172   opToErase.push_back(write.getOperation());
173 }
174 
175 /// A transfer_write candidate for store-to-load forwarding must:
176 /// 1. Access the same memref with the same indices and vector type as the
177 /// transfer_read.
178 /// 2. Dominate the transfer_read operation.
179 /// If several candidates are available, one must be dominated by all the others
180 /// since they are all dominating the same transfer_read. We only consider the
181 /// transfer_write dominated by all the other candidates as this will be the
182 /// last transfer_write executed before the transfer_read.
183 /// If we find such a candidate, we can do the forwarding if all the other
184 /// potentially aliasing ops that may reach the transfer_read are post-dominated
185 /// by the transfer_write.
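///
/// For illustration only, a minimal sketch (hypothetical values %buf, %v0,
/// %cst, %c0 and shapes chosen for the example):
///
///   vector.transfer_write %v0, %buf[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x8xf32>, memref<4x8xf32>
///   // ... no other write to %buf can reach the read below ...
///   %v1 = vector.transfer_read %buf[%c0, %c0], %cst {in_bounds = [true, true]}
///       : memref<4x8xf32>, vector<4x8xf32>
///
/// Here all uses of %v1 can be replaced by %v0 and the transfer_read erased.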
186 void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
187   if (read.hasOutOfBoundsDim())
188     return;
189   LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
190                     << "\n");
191   SmallVector<Operation *, 8> blockingWrites;
192   vector::TransferWriteOp lastwrite = nullptr;
193   Value source = read.getSource();
194   // Skip subview ops.
195   while (auto subView = source.getDefiningOp<memref::SubViewOp>())
196     source = subView.getSource();
197   llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
198                                            source.getUsers().end());
199   llvm::SmallDenseSet<Operation *, 32> processed;
200   while (!users.empty()) {
201     Operation *user = users.pop_back_val();
202     // If the user has already been processed skip.
203     if (!processed.insert(user).second)
204       continue;
205     if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
206       users.append(subView->getUsers().begin(), subView->getUsers().end());
207       continue;
208     }
209     if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
210       continue;
211     if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
212       // If there is a write, but we can prove that it is disjoint, we can
213       // ignore the write.
214       if (vector::isDisjointTransferSet(
215               cast<VectorTransferOpInterface>(write.getOperation()),
216               cast<VectorTransferOpInterface>(read.getOperation())))
217         continue;
218       if (write.getSource() == read.getSource() &&
219           dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
220         if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
221           lastwrite = write;
222         else
223           assert(dominators.dominates(write, lastwrite));
224         continue;
225       }
226     }
227     blockingWrites.push_back(user);
228   }
229 
230   if (lastwrite == nullptr)
231     return;
232 
233   Region *topRegion = lastwrite->getParentRegion();
234   Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
235   assert(readAncestor &&
236          "read op should be recursively part of the top region");
237 
238   for (Operation *write : blockingWrites) {
239     Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
240     // TODO: if the store and read have the same ancestor we could recurse in
241     // the region to know if the read is reachable with more precision.
242     if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
243       continue;
244     if (!postDominators.postDominates(lastwrite, write)) {
245       LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
246                         << *write << "\n");
247       return;
248     }
249   }
250 
251   LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
252                     << " to: " << *read.getOperation() << "\n");
253   read.replaceAllUsesWith(lastwrite.getVector());
254   opToErase.push_back(read.getOperation());
255 }
256 
257 /// Drops unit dimensions from the input MemRefType.
258 static MemRefType dropUnitDims(MemRefType inputType, ArrayRef<int64_t> offsets,
259                                ArrayRef<int64_t> sizes,
260                                ArrayRef<int64_t> strides) {
261   SmallVector<int64_t> targetShape = llvm::to_vector(
262       llvm::make_filter_range(sizes, [](int64_t sz) { return sz != 1; }));
263   Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
264       targetShape, inputType, offsets, sizes, strides);
265   return canonicalizeStridedLayout(cast<MemRefType>(rankReducedType));
266 }
267 
268 /// Creates a rank-reducing memref.subview op that drops unit dims from its
269 /// input, or returns the input unchanged if it has no unit dims.
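///
/// For example (a sketch with a hypothetical value %src of type
/// memref<1x8x1x4xf32>), this builds:
///
///   memref.subview %src[0, 0, 0, 0] [1, 8, 1, 4] [1, 1, 1, 1]
///       : memref<1x8x1x4xf32> to memref<8x4xf32>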
270 static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
271                                                  mlir::Location loc,
272                                                  Value input) {
273   MemRefType inputType = cast<MemRefType>(input.getType());
274   assert(inputType.hasStaticShape());
275   SmallVector<int64_t> subViewOffsets(inputType.getRank(), 0);
276   SmallVector<int64_t> subViewStrides(inputType.getRank(), 1);
277   ArrayRef<int64_t> subViewSizes = inputType.getShape();
278   MemRefType resultType =
279       dropUnitDims(inputType, subViewOffsets, subViewSizes, subViewStrides);
280   if (canonicalizeStridedLayout(resultType) ==
281       canonicalizeStridedLayout(inputType))
282     return input;
283   return rewriter.create<memref::SubViewOp>(
284       loc, resultType, input, subViewOffsets, subViewSizes, subViewStrides);
285 }
286 
287 /// Returns the number of dims that aren't unit dims.
288 static int getReducedRank(ArrayRef<int64_t> shape) {
289   return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
290 }
291 
292 /// Returns a copy of `shape` without unit dims.
293 static SmallVector<int64_t> getReducedShape(ArrayRef<int64_t> shape) {
294   SmallVector<int64_t> reducedShape;
295   llvm::copy_if(shape, std::back_inserter(reducedShape),
296                 [](int64_t dimSize) { return dimSize != 1; });
297   return reducedShape;
298 }
299 
300 namespace {
301 
302 /// Rewrites `vector.transfer_read` ops where the source has unit dims, by
303 /// inserting a memref.subview dropping those unit dims. The vector shapes are
304 /// also reduced accordingly.
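///
/// For illustration only, a minimal before/after sketch (hypothetical values,
/// shapes chosen for the example):
///
///   %v = vector.transfer_read %src[%c0, %c0, %c0], %cst
///       {in_bounds = [true, true, true]}
///       : memref<1x1x8xf32>, vector<1x1x8xf32>
///
/// is rewritten to roughly:
///
///   %0 = memref.subview %src[0, 0, 0] [1, 1, 8] [1, 1, 1]
///       : memref<1x1x8xf32> to memref<8xf32>
///   %1 = vector.transfer_read %0[%c0], %cst : memref<8xf32>, vector<8xf32>
///   %2 = vector.shape_cast %1 : vector<8xf32> to vector<1x1x8xf32>
///
/// and %2 replaces all uses of %v.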
305 class TransferReadDropUnitDimsPattern
306     : public OpRewritePattern<vector::TransferReadOp> {
307   using OpRewritePattern::OpRewritePattern;
308 
309   LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
310                                 PatternRewriter &rewriter) const override {
311     auto loc = transferReadOp.getLoc();
312     Value vector = transferReadOp.getVector();
313     VectorType vectorType = cast<VectorType>(vector.getType());
314     Value source = transferReadOp.getSource();
315     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
316     // TODO: support tensor types.
317     if (!sourceType || !sourceType.hasStaticShape())
318       return failure();
319     if (sourceType.getNumElements() != vectorType.getNumElements())
320       return failure();
321     // TODO: generalize this pattern, relax the requirements here.
322     if (transferReadOp.hasOutOfBoundsDim())
323       return failure();
324     if (!transferReadOp.getPermutationMap().isMinorIdentity())
325       return failure();
326     // Check if the source shape can be further reduced.
327     int reducedRank = getReducedRank(sourceType.getShape());
328     if (reducedRank == sourceType.getRank())
329       return failure();
330     // Check if the reduced vector shape matches the reduced source shape.
331     // Otherwise, this case is not supported yet.
332     int vectorReducedRank = getReducedRank(vectorType.getShape());
333     if (reducedRank != vectorReducedRank)
334       return failure();
335     if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
336           return getConstantIntValue(v) != static_cast<int64_t>(0);
337         }))
338       return failure();
339     Value reducedShapeSource =
340         rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
341     Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
342     SmallVector<Value> zeros(reducedRank, c0);
343     auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
344     auto reducedVectorType = VectorType::get(
345         getReducedShape(vectorType.getShape()), vectorType.getElementType());
346 
347     auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
348         loc, reducedVectorType, reducedShapeSource, zeros, identityMap);
349     auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
350         loc, vectorType, newTransferReadOp);
351     rewriter.replaceOp(transferReadOp, shapeCast);
352 
353     return success();
354   }
355 };
356 
357 /// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
358 /// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
359 /// vector shapes are also reduced accordingly.
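///
/// For illustration only, a minimal before/after sketch (hypothetical values,
/// shapes chosen for the example):
///
///   vector.transfer_write %v, %dst[%c0, %c0, %c0]
///       {in_bounds = [true, true, true]}
///       : vector<1x1x8xf32>, memref<1x1x8xf32>
///
/// is rewritten to roughly:
///
///   %0 = memref.subview %dst[0, 0, 0] [1, 1, 8] [1, 1, 1]
///       : memref<1x1x8xf32> to memref<8xf32>
///   %1 = vector.shape_cast %v : vector<1x1x8xf32> to vector<8xf32>
///   vector.transfer_write %1, %0[%c0] : vector<8xf32>, memref<8xf32>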
360 class TransferWriteDropUnitDimsPattern
361     : public OpRewritePattern<vector::TransferWriteOp> {
362   using OpRewritePattern::OpRewritePattern;
363 
364   LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
365                                 PatternRewriter &rewriter) const override {
366     auto loc = transferWriteOp.getLoc();
367     Value vector = transferWriteOp.getVector();
368     VectorType vectorType = cast<VectorType>(vector.getType());
369     Value source = transferWriteOp.getSource();
370     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
371     // TODO: support tensor type.
372     if (!sourceType || !sourceType.hasStaticShape())
373       return failure();
374     if (sourceType.getNumElements() != vectorType.getNumElements())
375       return failure();
376     // TODO: generalize this pattern, relax the requirements here.
377     if (transferWriteOp.hasOutOfBoundsDim())
378       return failure();
379     if (!transferWriteOp.getPermutationMap().isMinorIdentity())
380       return failure();
381     // Check if the destination shape can be further reduced.
382     int reducedRank = getReducedRank(sourceType.getShape());
383     if (reducedRank == sourceType.getRank())
384       return failure();
385     // Check if the reduced vector shape matches the reduced destination shape.
386     // Otherwise, this case is not supported yet.
387     int vectorReducedRank = getReducedRank(vectorType.getShape());
388     if (reducedRank != vectorReducedRank)
389       return failure();
390     if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
391           return getConstantIntValue(v) != static_cast<int64_t>(0);
392         }))
393       return failure();
394     Value reducedShapeSource =
395         rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
396     Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
397     SmallVector<Value> zeros(reducedRank, c0);
398     auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
399     VectorType reducedVectorType = VectorType::get(
400         getReducedShape(vectorType.getShape()), vectorType.getElementType());
401 
402     auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
403         loc, reducedVectorType, vector);
404     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
405         transferWriteOp, shapeCast, reducedShapeSource, zeros, identityMap);
406 
407     return success();
408   }
409 };
410 
411 } // namespace
412 
413 /// Return true if the trailing dimensions of the memref match the given shape
414 /// and are contiguous (row-major). Otherwise return false.
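/// For example (a sketch): a row-major memref<2x3x4xf32> has strides
/// [12, 4, 1], so a target shape of [3, 4] matches (the trailing dims are 3
/// and 4 and their strides are contiguous), whereas [2, 4] does not.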
415 static bool hasMatchingInnerContiguousShape(MemRefType memrefType,
416                                             ArrayRef<int64_t> targetShape) {
417   auto shape = memrefType.getShape();
418   SmallVector<int64_t> strides;
419   int64_t offset;
420   if (failed(getStridesAndOffset(memrefType, strides, offset)))
421     return false;
422   if (strides.back() != 1)
423     return false;
424   strides.pop_back();
425   int64_t flatDim = 1;
426   for (auto [targetDim, memrefDim, memrefStride] :
427        llvm::reverse(llvm::zip(targetShape, shape, strides))) {
428     flatDim *= memrefDim;
429     if (flatDim != memrefStride || targetDim != memrefDim)
430       return false;
431   }
432   return true;
433 }
434 
435 /// Creates a memref.collapse_shape collapsing all inner dimensions of the
436 /// input starting at `firstDimToCollapse`.
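///
/// For example (a sketch with a hypothetical value %m of type
/// memref<2x3x4xf32> and firstDimToCollapse = 1), this builds:
///
///   memref.collapse_shape %m [[0], [1, 2]]
///       : memref<2x3x4xf32> into memref<2x12xf32>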
437 static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
438                                Value input, int64_t firstDimToCollapse) {
439   ShapedType inputType = cast<ShapedType>(input.getType());
440   if (inputType.getRank() == 1)
441     return input;
442   SmallVector<ReassociationIndices> reassociation;
443   for (int64_t i = 0; i < firstDimToCollapse; ++i)
444     reassociation.push_back(ReassociationIndices{i});
445   ReassociationIndices collapsedIndices;
446   for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
447     collapsedIndices.push_back(i);
448   reassociation.push_back(collapsedIndices);
449   return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
450 }
451 
452 /// Checks that the indices corresponding to dimensions starting at
453 /// `firstDimToCollapse` are constant 0, and writes to `outIndices`
454 /// the truncated indices where `firstDimToCollapse` is now the innermost dim.
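///
/// For example (a sketch): with indices (%i, %c0, %c0) and
/// firstDimToCollapse = 1, the check succeeds and `outIndices` becomes
/// (%i, %c0); with indices (%i, %j, %c0) it fails because %j is not a known
/// constant 0.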
455 static LogicalResult
456 checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
457                                  SmallVector<Value> &outIndices) {
458   int64_t rank = indices.size();
459   if (firstDimToCollapse >= rank)
460     return failure();
461   for (int64_t i = firstDimToCollapse; i < rank; ++i) {
462     std::optional<int64_t> cst = getConstantIntValue(indices[i]);
463     if (!cst || cst.value() != 0)
464       return failure();
465   }
466   outIndices = indices;
467   outIndices.resize(firstDimToCollapse + 1);
468   return success();
469 }
470 
471 namespace {
472 
473 /// Rewrites contiguous row-major vector.transfer_read ops by inserting
474 /// memref.collapse_shape on the source so that the resulting
475 /// vector.transfer_read has a 1D source. Requires the source shape to be
476 /// already reduced, i.e., without unit dims.
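///
/// For illustration only, a minimal before/after sketch (hypothetical values,
/// shapes chosen for the example):
///
///   %v = vector.transfer_read %src[%c0, %c0], %cst {in_bounds = [true, true]}
///       : memref<8x4xf32>, vector<8x4xf32>
///
/// is rewritten to roughly:
///
///   %0 = memref.collapse_shape %src [[0, 1]]
///       : memref<8x4xf32> into memref<32xf32>
///   %1 = vector.transfer_read %0[%c0], %cst {in_bounds = [true]}
///       : memref<32xf32>, vector<32xf32>
///   %2 = vector.shape_cast %1 : vector<32xf32> to vector<8x4xf32>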
477 class FlattenContiguousRowMajorTransferReadPattern
478     : public OpRewritePattern<vector::TransferReadOp> {
479   using OpRewritePattern::OpRewritePattern;
480 
481   LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
482                                 PatternRewriter &rewriter) const override {
483     auto loc = transferReadOp.getLoc();
484     Value vector = transferReadOp.getVector();
485     VectorType vectorType = cast<VectorType>(vector.getType());
486     Value source = transferReadOp.getSource();
487     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
488     // The contiguity check below only applies to memrefs; bail on tensors.
489     if (!sourceType)
490       return failure();
491     if (vectorType.getRank() <= 1)
492       // Already 0D/1D, nothing to do.
493       return failure();
494     if (!hasMatchingInnerContiguousShape(
495             sourceType,
496             vectorType.getShape().take_back(vectorType.getRank() - 1)))
497       return failure();
498     int64_t firstContiguousInnerDim =
499         sourceType.getRank() - vectorType.getRank();
500     // TODO: generalize this pattern, relax the requirements here.
501     if (transferReadOp.hasOutOfBoundsDim())
502       return failure();
503     if (!transferReadOp.getPermutationMap().isMinorIdentity())
504       return failure();
505     if (transferReadOp.getMask())
506       return failure();
507     SmallVector<Value> collapsedIndices;
508     if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
509                                                 firstContiguousInnerDim,
510                                                 collapsedIndices)))
511       return failure();
512     Value collapsedSource =
513         collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
514     MemRefType collapsedSourceType =
515         dyn_cast<MemRefType>(collapsedSource.getType());
516     int64_t collapsedRank = collapsedSourceType.getRank();
517     assert(collapsedRank == firstContiguousInnerDim + 1);
518     SmallVector<AffineExpr, 1> dimExprs{
519         getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
520     auto collapsedMap =
521         AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
522     VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
523                                                 vectorType.getElementType());
524     vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
525         loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
526     flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
527     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
528         transferReadOp, cast<VectorType>(vector.getType()), flatRead);
529     return success();
530   }
531 };
532 
533 /// Rewrites contiguous row-major vector.transfer_write ops by inserting
534 /// memref.collapse_shape on the source so that the resulting
535 /// vector.transfer_write has a 1D source. Requires the source shape to be
536 /// already reduced, i.e., without unit dims.
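///
/// For illustration only, a minimal before/after sketch (hypothetical values,
/// shapes chosen for the example):
///
///   vector.transfer_write %v, %dst[%c0, %c0] {in_bounds = [true, true]}
///       : vector<8x4xf32>, memref<8x4xf32>
///
/// is rewritten to roughly:
///
///   %0 = memref.collapse_shape %dst [[0, 1]]
///       : memref<8x4xf32> into memref<32xf32>
///   %1 = vector.shape_cast %v : vector<8x4xf32> to vector<32xf32>
///   vector.transfer_write %1, %0[%c0] {in_bounds = [true]}
///       : vector<32xf32>, memref<32xf32>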
537 class FlattenContiguousRowMajorTransferWritePattern
538     : public OpRewritePattern<vector::TransferWriteOp> {
539   using OpRewritePattern::OpRewritePattern;
540 
541   LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
542                                 PatternRewriter &rewriter) const override {
543     auto loc = transferWriteOp.getLoc();
544     Value vector = transferWriteOp.getVector();
545     VectorType vectorType = cast<VectorType>(vector.getType());
546     Value source = transferWriteOp.getSource();
547     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
548     // The contiguity check below only applies to memrefs; bail on tensors.
549     if (!sourceType)
550       return failure();
551     if (vectorType.getRank() <= 1)
552       // Already 0D/1D, nothing to do.
553       return failure();
554     if (!hasMatchingInnerContiguousShape(
555             sourceType,
556             vectorType.getShape().take_back(vectorType.getRank() - 1)))
557       return failure();
558     int64_t firstContiguousInnerDim =
559         sourceType.getRank() - vectorType.getRank();
560     // TODO: generalize this pattern, relax the requirements here.
561     if (transferWriteOp.hasOutOfBoundsDim())
562       return failure();
563     if (!transferWriteOp.getPermutationMap().isMinorIdentity())
564       return failure();
565     if (transferWriteOp.getMask())
566       return failure();
567     SmallVector<Value> collapsedIndices;
568     if (failed(checkAndCollapseInnerZeroIndices(transferWriteOp.getIndices(),
569                                                 firstContiguousInnerDim,
570                                                 collapsedIndices)))
571       return failure();
572     Value collapsedSource =
573         collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
574     MemRefType collapsedSourceType =
575         cast<MemRefType>(collapsedSource.getType());
576     int64_t collapsedRank = collapsedSourceType.getRank();
577     assert(collapsedRank == firstContiguousInnerDim + 1);
578     SmallVector<AffineExpr, 1> dimExprs{
579         getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
580     auto collapsedMap =
581         AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
582     VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
583                                                 vectorType.getElementType());
584     Value flatVector =
585         rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
586     vector::TransferWriteOp flatWrite =
587         rewriter.create<vector::TransferWriteOp>(
588             loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
589     flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
590     rewriter.eraseOp(transferWriteOp);
591     return success();
592   }
593 };
594 
595 /// Base class for `vector.extract/vector.extractelement(vector.transfer_read)`
596 /// to `memref.load` patterns. The `match` method is shared for both
597 /// `vector.extract` and `vector.extractelement`.
598 template <class VectorExtractOp>
599 class RewriteScalarExtractOfTransferReadBase
600     : public OpRewritePattern<VectorExtractOp> {
601   using Base = OpRewritePattern<VectorExtractOp>;
602 
603 public:
604   RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
605                                          PatternBenefit benefit,
606                                          bool allowMultipleUses)
607       : Base::OpRewritePattern(context, benefit),
608         allowMultipleUses(allowMultipleUses) {}
609 
610   LogicalResult match(VectorExtractOp extractOp) const override {
611     auto xferOp =
612         extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
613     if (!xferOp)
614       return failure();
615     // Check that we are extracting a scalar and not a sub-vector.
616     if (isa<VectorType>(extractOp.getResult().getType()))
617       return failure();
618     // If multiple uses are not allowed, check if xfer has a single use.
619     if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
620       return failure();
621     // If multiple uses are allowed, check if all the xfer uses are extract ops.
622     if (allowMultipleUses &&
623         !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
624           return isa<vector::ExtractOp, vector::ExtractElementOp>(
625               use.getOwner());
626         }))
627       return failure();
628     // Mask not supported.
629     if (xferOp.getMask())
630       return failure();
631     // Map not supported.
632     if (!xferOp.getPermutationMap().isMinorIdentity())
633       return failure();
634     // Cannot rewrite if the indices may be out of bounds.
635     if (xferOp.hasOutOfBoundsDim())
636       return failure();
637     return success();
638   }
639 
640 private:
641   bool allowMultipleUses;
642 };
643 
644 /// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
645 ///
646 /// All the users of the transfer op must be either `vector.extractelement` or
647 /// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
648 /// transfer ops with any number of users. Otherwise, rewrite only if the
649 /// extract op is the single user of the transfer op. Rewriting a single
650 /// vector load with multiple scalar loads may negatively affect performance.
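///
/// For illustration only, a minimal sketch (hypothetical values, memref case):
///
///   %v = vector.transfer_read %src[%i], %cst {in_bounds = [true]}
///       : memref<16xf32>, vector<4xf32>
///   %e = vector.extractelement %v[%j : index] : vector<4xf32>
///
/// Here %e is replaced by roughly:
///
///   %idx = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%i, %j]
///   %0 = memref.load %src[%idx] : memref<16xf32>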
651 class RewriteScalarExtractElementOfTransferRead
652     : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
653   using RewriteScalarExtractOfTransferReadBase::
654       RewriteScalarExtractOfTransferReadBase;
655 
656   void rewrite(vector::ExtractElementOp extractOp,
657                PatternRewriter &rewriter) const override {
658     // Construct scalar load.
659     auto loc = extractOp.getLoc();
660     auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
661     SmallVector<Value> newIndices(xferOp.getIndices().begin(),
662                                   xferOp.getIndices().end());
663     if (extractOp.getPosition()) {
664       AffineExpr sym0, sym1;
665       bindSymbols(extractOp.getContext(), sym0, sym1);
666       OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
667           rewriter, loc, sym0 + sym1,
668           {newIndices[newIndices.size() - 1], extractOp.getPosition()});
669       if (ofr.is<Value>()) {
670         newIndices[newIndices.size() - 1] = ofr.get<Value>();
671       } else {
672         newIndices[newIndices.size() - 1] =
673             rewriter.create<arith::ConstantIndexOp>(loc,
674                                                     *getConstantIntValue(ofr));
675       }
676     }
677     if (isa<MemRefType>(xferOp.getSource().getType())) {
678       rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
679                                                   newIndices);
680     } else {
681       rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
682           extractOp, xferOp.getSource(), newIndices);
683     }
684   }
685 };
686 
687 /// Rewrite `vector.extract(vector.transfer_read)` to `memref.load` (or
688 /// `tensor.extract` when the transfer reads from a tensor).
689 ///
690 /// All the users of the transfer op must be either `vector.extractelement` or
691 /// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
692 /// transfer ops with any number of users. Otherwise, rewrite only if the
693 /// extract op is the single user of the transfer op. Rewriting a single
694 /// vector load with multiple scalar loads may negatively affect performance.
695 class RewriteScalarExtractOfTransferRead
696     : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
697   using RewriteScalarExtractOfTransferReadBase::
698       RewriteScalarExtractOfTransferReadBase;
699 
700   void rewrite(vector::ExtractOp extractOp,
701                PatternRewriter &rewriter) const override {
702     // Construct scalar load.
703     auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
704     SmallVector<Value> newIndices(xferOp.getIndices().begin(),
705                                   xferOp.getIndices().end());
706     for (const auto &it : llvm::enumerate(extractOp.getPosition())) {
707       int64_t offset = it.value();
708       int64_t idx =
709           newIndices.size() - extractOp.getPosition().size() + it.index();
710       OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
711           rewriter, extractOp.getLoc(),
712           rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
713       if (ofr.is<Value>()) {
714         newIndices[idx] = ofr.get<Value>();
715       } else {
716         newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
717             extractOp.getLoc(), *getConstantIntValue(ofr));
718       }
719     }
720     if (isa<MemRefType>(xferOp.getSource().getType())) {
721       rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
722                                                   newIndices);
723     } else {
724       rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
725           extractOp, xferOp.getSource(), newIndices);
726     }
727   }
728 };
729 
730 /// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
731 /// to memref.store.
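///
/// For illustration only, a minimal sketch (hypothetical values, memref case):
///
///   vector.transfer_write %v, %dst[%i, %j] {in_bounds = [true, true]}
///       : vector<1x1xf32>, memref<8x8xf32>
///
/// becomes roughly:
///
///   %s = vector.extract %v[0, 0] : vector<1x1xf32>
///   memref.store %s, %dst[%i, %j] : memref<8x8xf32>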
732 class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
733   using OpRewritePattern::OpRewritePattern;
734 
735   LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
736                                 PatternRewriter &rewriter) const override {
737     // Must be a scalar write.
738     auto vecType = xferOp.getVectorType();
739     if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
740       return failure();
741     // Mask not supported.
742     if (xferOp.getMask())
743       return failure();
744     // Map not supported.
745     if (!xferOp.getPermutationMap().isMinorIdentity())
746       return failure();
747     // Only float and integer element types are supported.
748     Value scalar;
749     if (vecType.getRank() == 0) {
750       // vector.extract does not support vector<f32> etc., so use
751       // vector.extractelement instead.
752       scalar = rewriter.create<vector::ExtractElementOp>(xferOp.getLoc(),
753                                                          xferOp.getVector());
754     } else {
755       SmallVector<int64_t> pos(vecType.getRank(), 0);
756       scalar = rewriter.create<vector::ExtractOp>(xferOp.getLoc(),
757                                                   xferOp.getVector(), pos);
758     }
759     // Construct a scalar store.
760     if (isa<MemRefType>(xferOp.getSource().getType())) {
761       rewriter.replaceOpWithNewOp<memref::StoreOp>(
762           xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
763     } else {
764       rewriter.replaceOpWithNewOp<tensor::InsertOp>(
765           xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
766     }
767     return success();
768   }
769 };
770 
771 } // namespace
772 
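// For illustration only, a usage sketch (hypothetical caller; `funcOp` is an
// assumed func::FuncOp handle provided by the surrounding pass):
//
//   IRRewriter rewriter(funcOp.getContext());
//   vector::transferOpflowOpt(rewriter, funcOp);
//
// This runs store-to-load forwarding and dead-store elimination on all
// vector.transfer_read/transfer_write ops on memrefs nested under funcOp.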
773 void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
774                                      Operation *rootOp) {
775   TransferOptimization opt(rewriter, rootOp);
776   // Run store-to-load forwarding first since it can expose more dead store
777   // opportunities.
778   rootOp->walk([&](vector::TransferReadOp read) {
779     if (isa<MemRefType>(read.getShapedType()))
780       opt.storeToLoadForwarding(read);
781   });
782   opt.removeDeadOp();
783   rootOp->walk([&](vector::TransferWriteOp write) {
784     if (isa<MemRefType>(write.getShapedType()))
785       opt.deadStoreOp(write);
786   });
787   opt.removeDeadOp();
788 }
789 
790 void mlir::vector::populateScalarVectorTransferLoweringPatterns(
791     RewritePatternSet &patterns, PatternBenefit benefit,
792     bool allowMultipleUses) {
793   patterns.add<RewriteScalarExtractElementOfTransferRead,
794                RewriteScalarExtractOfTransferRead>(patterns.getContext(),
795                                                    benefit, allowMultipleUses);
796   patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
797 }
798 
799 void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
800     RewritePatternSet &patterns, PatternBenefit benefit) {
801   patterns
802       .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
803           patterns.getContext(), benefit);
804   populateShapeCastFoldingPatterns(patterns);
805 }
806 
807 void mlir::vector::populateFlattenVectorTransferPatterns(
808     RewritePatternSet &patterns, PatternBenefit benefit) {
809   patterns.add<FlattenContiguousRowMajorTransferReadPattern,
810                FlattenContiguousRowMajorTransferWritePattern>(
811       patterns.getContext(), benefit);
812   populateShapeCastFoldingPatterns(patterns, benefit);
813 }
814