1 //===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements functions concerned with optimizing transfer_read and
10 // transfer_write ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/MemRef/IR/MemRef.h"
17 #include "mlir/Dialect/Tensor/IR/Tensor.h"
18 #include "mlir/Dialect/Vector/IR/VectorOps.h"
19 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
20 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
21 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/Dominance.h"
24 #include "mlir/Interfaces/SideEffectInterfaces.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/StringRef.h"
27 #include "llvm/Support/Debug.h"
28 
29 #define DEBUG_TYPE "vector-transfer-opt"
30 
31 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
32 
33 using namespace mlir;
34 
35 /// Return the ancestor op in the region or nullptr if the region is not
36 /// an ancestor of the op.
37 static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
38   for (; op != nullptr && op->getParentRegion() != region;
39        op = op->getParentOp())
40     ;
41   return op;
42 }
43 
44 namespace {
45 
46 class TransferOptimization {
47 public:
48   TransferOptimization(RewriterBase &rewriter, Operation *op)
49       : rewriter(rewriter), dominators(op), postDominators(op) {}
50   void deadStoreOp(vector::TransferWriteOp);
51   void storeToLoadForwarding(vector::TransferReadOp);
52   void removeDeadOp() {
53     for (Operation *op : opToErase)
54       rewriter.eraseOp(op);
55     opToErase.clear();
56   }
57 
58 private:
59   RewriterBase &rewriter;
60   bool isReachable(Operation *start, Operation *dest);
61   DominanceInfo dominators;
62   PostDominanceInfo postDominators;
63   std::vector<Operation *> opToErase;
64 };
65 
66 } // namespace
67 /// Return true if there is a path from start operation to dest operation,
68 /// otherwise return false. The operations have to be in the same region.
69 bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
70   assert(start->getParentRegion() == dest->getParentRegion() &&
71          "This function only works for ops in the same region");
72   // Simple case where the start op dominates the destination.
73   if (dominators.dominates(start, dest))
74     return true;
75   Block *startBlock = start->getBlock();
76   Block *destBlock = dest->getBlock();
77   SmallVector<Block *, 32> worklist(startBlock->succ_begin(),
78                                     startBlock->succ_end());
79   SmallPtrSet<Block *, 32> visited;
80   while (!worklist.empty()) {
81     Block *bb = worklist.pop_back_val();
82     if (!visited.insert(bb).second)
83       continue;
84     if (dominators.dominates(bb, destBlock))
85       return true;
86     worklist.append(bb->succ_begin(), bb->succ_end());
87   }
88   return false;
89 }
90 
91 /// For a transfer_write to fully overwrite another transfer_write, it must:
92 /// 1. Access the same memref with the same indices and vector type.
93 /// 2. Post-dominate the other transfer_write operation.
94 /// If several candidates are available, one must be post-dominated by all the
95 /// others since they are all post-dominating the same transfer_write. We only
96 /// consider the transfer_write post-dominated by all the other candidates as
97 /// this will be the first transfer_write executed after the potentially dead
98 /// transfer_write.
99 /// If we find such an overwriting transfer_write, we know that the original
100 /// transfer_write is dead if all reads that can be reached from the potentially
101 /// dead transfer_write are dominated by the overwriting transfer_write.
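/// For illustration, a minimal sketch of the kind of IR handled here (values
/// and types are hypothetical):
///
///   vector.transfer_write %v0, %A[%c0, %c0] {in_bounds = [true]}
///       : vector<4xf32>, memref<4x4xf32>   // candidate dead store
///   vector.transfer_write %v1, %A[%c0, %c0] {in_bounds = [true]}
///       : vector<4xf32>, memref<4x4xf32>   // post-dominating overwrite
///
/// If every read of %A reachable from the first write is dominated by the
/// second write, the first write is erased.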
102 void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
103   LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
104                     << "\n");
105   llvm::SmallVector<Operation *, 8> blockingAccesses;
106   Operation *firstOverwriteCandidate = nullptr;
107   Value source = write.getSource();
108   // Skip subview ops.
109   while (auto subView = source.getDefiningOp<memref::SubViewOp>())
110     source = subView.getSource();
111   llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
112                                            source.getUsers().end());
113   llvm::SmallDenseSet<Operation *, 32> processed;
114   while (!users.empty()) {
115     Operation *user = users.pop_back_val();
116     // If the user has already been processed skip.
117     if (!processed.insert(user).second)
118       continue;
119     if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
120       users.append(subView->getUsers().begin(), subView->getUsers().end());
121       continue;
122     }
123     if (isMemoryEffectFree(user))
124       continue;
125     if (user == write.getOperation())
126       continue;
127     if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
128       // Check whether this candidate can overwrite the store.
129       if (write.getSource() == nextWrite.getSource() &&
130           checkSameValueWAW(nextWrite, write) &&
131           postDominators.postDominates(nextWrite, write)) {
132         if (firstOverwriteCandidate == nullptr ||
133             postDominators.postDominates(firstOverwriteCandidate, nextWrite))
134           firstOverwriteCandidate = nextWrite;
135         else
136           assert(
137               postDominators.postDominates(nextWrite, firstOverwriteCandidate));
138         continue;
139       }
140     }
141     if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
142       // Don't need to consider disjoint accesses.
143       if (vector::isDisjointTransferSet(
144               cast<VectorTransferOpInterface>(write.getOperation()),
145               cast<VectorTransferOpInterface>(transferOp.getOperation())))
146         continue;
147     }
148     blockingAccesses.push_back(user);
149   }
150   if (firstOverwriteCandidate == nullptr)
151     return;
152   Region *topRegion = firstOverwriteCandidate->getParentRegion();
153   Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
154   assert(writeAncestor &&
155          "write op should be recursively part of the top region");
156 
157   for (Operation *access : blockingAccesses) {
158     Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
159     // TODO: if the access and write have the same ancestor we could recurse in
160     // the region to know if the access is reachable with more precision.
161     if (accessAncestor == nullptr ||
162         !isReachable(writeAncestor, accessAncestor))
163       continue;
164     if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
165       LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
166                         << *accessAncestor << "\n");
167       return;
168     }
169   }
170   LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
171                     << " overwritten by: " << *firstOverwriteCandidate << "\n");
172   opToErase.push_back(write.getOperation());
173 }
174 
175 /// A transfer_write candidate for store-to-load forwarding must:
176 /// 1. Access the same memref with the same indices and vector type as the
177 /// transfer_read.
178 /// 2. Dominate the transfer_read operation.
179 /// If several candidates are available, one must be dominated by all the others
180 /// since they are all dominating the same transfer_read. We only consider the
181 /// transfer_write dominated by all the other candidates as this will be the
182 /// last transfer_write executed before the transfer_read.
183 /// If we find such a candidate, we can do the forwarding if all the other
184 /// potentially aliasing ops that may reach the transfer_read are post-dominated
185 /// by the transfer_write.
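/// For illustration, a minimal sketch (values and types are hypothetical):
///
///   vector.transfer_write %v, %A[%c0, %c0] {in_bounds = [true]}
///       : vector<4xf32>, memref<4x4xf32>
///   %r = vector.transfer_read %A[%c0, %c0], %pad {in_bounds = [true]}
///       : memref<4x4xf32>, vector<4xf32>
///
/// With no interfering write in between, %r is replaced by %v and the read is
/// erased.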
186 void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
187   if (read.hasOutOfBoundsDim())
188     return;
189   LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
190                     << "\n");
191   SmallVector<Operation *, 8> blockingWrites;
192   vector::TransferWriteOp lastwrite = nullptr;
193   Value source = read.getSource();
194   // Skip subview ops.
195   while (auto subView = source.getDefiningOp<memref::SubViewOp>())
196     source = subView.getSource();
197   llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
198                                            source.getUsers().end());
199   llvm::SmallDenseSet<Operation *, 32> processed;
200   while (!users.empty()) {
201     Operation *user = users.pop_back_val();
202     // If the user has already been processed skip.
203     if (!processed.insert(user).second)
204       continue;
205     if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
206       users.append(subView->getUsers().begin(), subView->getUsers().end());
207       continue;
208     }
209     if (auto collapsed = dyn_cast<memref::CollapseShapeOp>(user)) {
210       users.append(collapsed->getUsers().begin(), collapsed->getUsers().end());
211       continue;
212     }
213     if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
214       continue;
215     if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
216       // If there is a write, but we can prove that it is disjoint we can ignore
217       // the write.
218       if (vector::isDisjointTransferSet(
219               cast<VectorTransferOpInterface>(write.getOperation()),
220               cast<VectorTransferOpInterface>(read.getOperation())))
221         continue;
222       if (write.getSource() == read.getSource() &&
223           dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
224         if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
225           lastwrite = write;
226         else
227           assert(dominators.dominates(write, lastwrite));
228         continue;
229       }
230     }
231     blockingWrites.push_back(user);
232   }
233 
234   if (lastwrite == nullptr)
235     return;
236 
237   Region *topRegion = lastwrite->getParentRegion();
238   Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
239   assert(readAncestor &&
240          "read op should be recursively part of the top region");
241 
242   for (Operation *write : blockingWrites) {
243     Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
244     // TODO: if the store and read have the same ancestor we could recurse in
245     // the region to know if the read is reachable with more precision.
246     if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
247       continue;
248     if (!postDominators.postDominates(lastwrite, write)) {
249       LLVM_DEBUG(DBGS() << "Failed to do write-to-read forwarding due to op: "
250                         << *write << "\n");
251       return;
252     }
253   }
254 
255   LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
256                     << " to: " << *read.getOperation() << "\n");
257   read.replaceAllUsesWith(lastwrite.getVector());
258   opToErase.push_back(read.getOperation());
259 }
260 
261 /// Drops unit dimensions from the input MemRefType.
262 static MemRefType dropUnitDims(MemRefType inputType, ArrayRef<int64_t> offsets,
263                                ArrayRef<int64_t> sizes,
264                                ArrayRef<int64_t> strides) {
265   SmallVector<int64_t> targetShape = llvm::to_vector(
266       llvm::make_filter_range(sizes, [](int64_t sz) { return sz != 1; }));
267   Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
268       targetShape, inputType, offsets, sizes, strides);
269   return canonicalizeStridedLayout(cast<MemRefType>(rankReducedType));
270 }
271 
272 /// Creates a rank-reducing memref.subview op that drops unit dims from its
273 /// input, or just returns the input if it already has no unit dims.
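/// For illustration, with a memref<1x4x1xf32> input this creates roughly the
/// following (the exact result type may carry a strided layout attribute):
///
///   %sv = memref.subview %input[0, 0, 0] [1, 4, 1] [1, 1, 1]
///       : memref<1x4x1xf32> to memref<4xf32>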
274 static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
275                                                  mlir::Location loc,
276                                                  Value input) {
277   MemRefType inputType = cast<MemRefType>(input.getType());
278   assert(inputType.hasStaticShape());
279   SmallVector<int64_t> subViewOffsets(inputType.getRank(), 0);
280   SmallVector<int64_t> subViewStrides(inputType.getRank(), 1);
281   ArrayRef<int64_t> subViewSizes = inputType.getShape();
282   MemRefType resultType =
283       dropUnitDims(inputType, subViewOffsets, subViewSizes, subViewStrides);
284   if (canonicalizeStridedLayout(resultType) ==
285       canonicalizeStridedLayout(inputType))
286     return input;
287   return rewriter.create<memref::SubViewOp>(
288       loc, resultType, input, subViewOffsets, subViewSizes, subViewStrides);
289 }
290 
291 /// Returns the number of dims that aren't unit dims.
292 static int getReducedRank(ArrayRef<int64_t> shape) {
293   return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
294 }
295 
296 /// Returns a copy of `shape` without unit dims.
297 static SmallVector<int64_t> getReducedShape(ArrayRef<int64_t> shape) {
298   SmallVector<int64_t> reducedShape;
299   llvm::copy_if(shape, std::back_inserter(reducedShape),
300                 [](int64_t dimSize) { return dimSize != 1; });
301   return reducedShape;
302 }
303 
304 namespace {
305 
306 /// Rewrites `vector.transfer_read` ops where the source has unit dims, by
307 /// inserting a memref.subview dropping those unit dims. The vector shapes are
308 /// also reduced accordingly.
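/// For illustration, a sketch of the rewrite (values and types are
/// hypothetical, result types shown in canonicalized form):
///
///   %res = vector.transfer_read %A[%c0, %c0, %c0], %pad
///       {in_bounds = [true, true, true]}
///       : memref<1x4x1xf32>, vector<1x4x1xf32>
///
/// becomes, roughly:
///
///   %sv = memref.subview %A[0, 0, 0] [1, 4, 1] [1, 1, 1]
///       : memref<1x4x1xf32> to memref<4xf32>
///   %r = vector.transfer_read %sv[%c0], %pad : memref<4xf32>, vector<4xf32>
///   %res = vector.shape_cast %r : vector<4xf32> to vector<1x4x1xf32>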
309 class TransferReadDropUnitDimsPattern
310     : public OpRewritePattern<vector::TransferReadOp> {
311   using OpRewritePattern::OpRewritePattern;
312 
313   LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
314                                 PatternRewriter &rewriter) const override {
315     auto loc = transferReadOp.getLoc();
316     Value vector = transferReadOp.getVector();
317     VectorType vectorType = cast<VectorType>(vector.getType());
318     Value source = transferReadOp.getSource();
319     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
320     // TODO: support tensor types.
321     if (!sourceType || !sourceType.hasStaticShape())
322       return failure();
323     if (sourceType.getNumElements() != vectorType.getNumElements())
324       return failure();
325     // TODO: generalize this pattern, relax the requirements here.
326     if (transferReadOp.hasOutOfBoundsDim())
327       return failure();
328     if (!transferReadOp.getPermutationMap().isMinorIdentity())
329       return failure();
330     // Check if the source shape can be further reduced.
331     int reducedRank = getReducedRank(sourceType.getShape());
332     if (reducedRank == sourceType.getRank())
333       return failure();
334     // Check if the reduced vector shape matches the reduced source shape.
335     // Otherwise, this case is not supported yet.
336     int vectorReducedRank = getReducedRank(vectorType.getShape());
337     if (reducedRank != vectorReducedRank)
338       return failure();
339     if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
340           return getConstantIntValue(v) != static_cast<int64_t>(0);
341         }))
342       return failure();
343     Value reducedShapeSource =
344         rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
345     Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
346     SmallVector<Value> zeros(reducedRank, c0);
347     auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
348     auto reducedVectorType = VectorType::get(
349         getReducedShape(vectorType.getShape()), vectorType.getElementType());
350 
351     auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
352         loc, reducedVectorType, reducedShapeSource, zeros, identityMap);
353     auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
354         loc, vectorType, newTransferReadOp);
355     rewriter.replaceOp(transferReadOp, shapeCast);
356 
357     return success();
358   }
359 };
360 
361 /// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
362 /// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
363 /// vector shapes are also reduced accordingly.
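/// For illustration, a sketch of the rewrite (values and types are
/// hypothetical, result types shown in canonicalized form):
///
///   vector.transfer_write %vec, %A[%c0, %c0, %c0]
///       {in_bounds = [true, true, true]}
///       : vector<1x4x1xf32>, memref<1x4x1xf32>
///
/// becomes, roughly:
///
///   %sv = memref.subview %A[0, 0, 0] [1, 4, 1] [1, 1, 1]
///       : memref<1x4x1xf32> to memref<4xf32>
///   %v = vector.shape_cast %vec : vector<1x4x1xf32> to vector<4xf32>
///   vector.transfer_write %v, %sv[%c0] : vector<4xf32>, memref<4xf32>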
364 class TransferWriteDropUnitDimsPattern
365     : public OpRewritePattern<vector::TransferWriteOp> {
366   using OpRewritePattern::OpRewritePattern;
367 
368   LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
369                                 PatternRewriter &rewriter) const override {
370     auto loc = transferWriteOp.getLoc();
371     Value vector = transferWriteOp.getVector();
372     VectorType vectorType = cast<VectorType>(vector.getType());
373     Value source = transferWriteOp.getSource();
374     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
375     // TODO: support tensor type.
376     if (!sourceType || !sourceType.hasStaticShape())
377       return failure();
378     if (sourceType.getNumElements() != vectorType.getNumElements())
379       return failure();
380     // TODO: generalize this pattern, relax the requirements here.
381     if (transferWriteOp.hasOutOfBoundsDim())
382       return failure();
383     if (!transferWriteOp.getPermutationMap().isMinorIdentity())
384       return failure();
385     // Check if the destination shape can be further reduced.
386     int reducedRank = getReducedRank(sourceType.getShape());
387     if (reducedRank == sourceType.getRank())
388       return failure();
389     // Check if the reduced vector shape matches the reduced destination shape.
390     // Otherwise, this case is not supported yet.
391     int vectorReducedRank = getReducedRank(vectorType.getShape());
392     if (reducedRank != vectorReducedRank)
393       return failure();
394     if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
395           return getConstantIntValue(v) != static_cast<int64_t>(0);
396         }))
397       return failure();
398     Value reducedShapeSource =
399         rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
400     Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
401     SmallVector<Value> zeros(reducedRank, c0);
402     auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
403     VectorType reducedVectorType = VectorType::get(
404         getReducedShape(vectorType.getShape()), vectorType.getElementType());
405 
406     auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
407         loc, reducedVectorType, vector);
408     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
409         transferWriteOp, shapeCast, reducedShapeSource, zeros, identityMap);
410 
411     return success();
412   }
413 };
414 
415 } // namespace
416 
417 /// Return true if the trailing dimensions of the memref type match the given
418 /// shape and are contiguous (row-major, innermost stride one), otherwise false.
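/// For illustration: a memref<2x3x4xf32> with the default identity layout has
/// strides [12, 4, 1]. For target shape [3, 4] the innermost stride is 1, the
/// running product of the trailing sizes (4, then 4 * 3 = 12) matches the
/// remaining strides, and the trailing sizes equal the target shape, so the
/// check succeeds. With strides [16, 4, 1] (padded outer rows) it fails, since
/// 12 != 16.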
419 static bool hasMatchingInnerContiguousShape(MemRefType memrefType,
420                                             ArrayRef<int64_t> targetShape) {
421   auto shape = memrefType.getShape();
422   SmallVector<int64_t> strides;
423   int64_t offset;
424   if (!succeeded(getStridesAndOffset(memrefType, strides, offset)))
425     return false;
426   if (strides.back() != 1)
427     return false;
428   strides.pop_back();
429   int64_t flatDim = 1;
430   for (auto [targetDim, memrefDim, memrefStride] :
431        llvm::reverse(llvm::zip(targetShape, shape, strides))) {
432     flatDim *= memrefDim;
433     if (flatDim != memrefStride || targetDim != memrefDim)
434       return false;
435   }
436   return true;
437 }
438 
439 /// Creates a memref.collapse_shape collapsing all inner dimensions of the
440 /// input starting at `firstDimToCollapse`.
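/// For illustration, with `firstDimToCollapse` equal to 1 a memref<2x3x4xf32>
/// input is collapsed with reassociation [[0], [1, 2]]:
///
///   %c = memref.collapse_shape %input [[0], [1, 2]]
///       : memref<2x3x4xf32> into memref<2x12xf32>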
441 static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
442                                Value input, int64_t firstDimToCollapse) {
443   ShapedType inputType = cast<ShapedType>(input.getType());
444   if (inputType.getRank() == 1)
445     return input;
446   SmallVector<ReassociationIndices> reassociation;
447   for (int64_t i = 0; i < firstDimToCollapse; ++i)
448     reassociation.push_back(ReassociationIndices{i});
449   ReassociationIndices collapsedIndices;
450   for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
451     collapsedIndices.push_back(i);
452   reassociation.push_back(collapsedIndices);
453   return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
454 }
455 
456 /// Checks that the indices corresponding to dimensions starting at
457 /// `firstDimToCollapse` are constant 0, and writes to `outIndices`
458 /// the truncated indices where `firstDimToCollapse` is now the innermost dim.
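/// For illustration: with indices (%i, %c0, %c0) and `firstDimToCollapse` = 1,
/// this succeeds and `outIndices` becomes (%i, %c0); with indices
/// (%i, %j, %c0) it fails because %j is not a constant 0.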
459 static LogicalResult
460 checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
461                                  SmallVector<Value> &outIndices) {
462   int64_t rank = indices.size();
463   if (firstDimToCollapse >= rank)
464     return failure();
465   for (int64_t i = firstDimToCollapse; i < rank; ++i) {
466     std::optional<int64_t> cst = getConstantIntValue(indices[i]);
467     if (!cst || cst.value() != 0)
468       return failure();
469   }
470   outIndices.assign(indices.begin(), indices.end());
471   outIndices.resize(firstDimToCollapse + 1);
472   return success();
473 }
474 
475 namespace {
476 
477 /// Rewrites contiguous row-major vector.transfer_read ops by inserting
478 /// memref.collapse_shape on the source so that the resulting
479 /// vector.transfer_read has a 1D source. Requires the source shape to be
480 /// already reduced i.e. without unit dims.
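/// For illustration, a sketch of the rewrite (values and types are
/// hypothetical):
///
///   %res = vector.transfer_read %A[%c0, %c0, %c0], %pad
///       {in_bounds = [true, true]} : memref<8x2x4xf32>, vector<2x4xf32>
///
/// becomes, roughly:
///
///   %c = memref.collapse_shape %A [[0], [1, 2]]
///       : memref<8x2x4xf32> into memref<8x8xf32>
///   %f = vector.transfer_read %c[%c0, %c0], %pad {in_bounds = [true]}
///       : memref<8x8xf32>, vector<8xf32>
///   %res = vector.shape_cast %f : vector<8xf32> to vector<2x4xf32>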
481 class FlattenContiguousRowMajorTransferReadPattern
482     : public OpRewritePattern<vector::TransferReadOp> {
483   using OpRewritePattern::OpRewritePattern;
484 
485   LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
486                                 PatternRewriter &rewriter) const override {
487     auto loc = transferReadOp.getLoc();
488     Value vector = transferReadOp.getVector();
489     VectorType vectorType = cast<VectorType>(vector.getType());
490     Value source = transferReadOp.getSource();
491     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
492     // Contiguity can only be checked for memref sources; bail out on tensors.
493     if (!sourceType)
494       return failure();
495     if (vectorType.getRank() <= 1)
496       // Already 0D/1D, nothing to do.
497       return failure();
498     if (!hasMatchingInnerContiguousShape(
499             sourceType,
500             vectorType.getShape().take_back(vectorType.getRank() - 1)))
501       return failure();
502     int64_t firstContiguousInnerDim =
503         sourceType.getRank() - vectorType.getRank();
504     // TODO: generalize this pattern, relax the requirements here.
505     if (transferReadOp.hasOutOfBoundsDim())
506       return failure();
507     if (!transferReadOp.getPermutationMap().isMinorIdentity())
508       return failure();
509     if (transferReadOp.getMask())
510       return failure();
511     SmallVector<Value> collapsedIndices;
512     if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
513                                                 firstContiguousInnerDim,
514                                                 collapsedIndices)))
515       return failure();
516     Value collapsedSource =
517         collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
518     MemRefType collapsedSourceType =
519         dyn_cast<MemRefType>(collapsedSource.getType());
520     int64_t collapsedRank = collapsedSourceType.getRank();
521     assert(collapsedRank == firstContiguousInnerDim + 1);
522     SmallVector<AffineExpr, 1> dimExprs{
523         getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
524     auto collapsedMap =
525         AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
526     VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
527                                                 vectorType.getElementType());
528     vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
529         loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
530     flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
531     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
532         transferReadOp, cast<VectorType>(vector.getType()), flatRead);
533     return success();
534   }
535 };
536 
537 /// Rewrites contiguous row-major vector.transfer_write ops by inserting
538 /// memref.collapse_shape on the source so that the resulting
539 /// vector.transfer_write has a 1D source. Requires the source shape to be
540 /// already reduced i.e. without unit dims.
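/// For illustration, a sketch of the rewrite (values and types are
/// hypothetical):
///
///   vector.transfer_write %vec, %A[%c0, %c0, %c0]
///       {in_bounds = [true, true]} : vector<2x4xf32>, memref<8x2x4xf32>
///
/// becomes, roughly:
///
///   %c = memref.collapse_shape %A [[0], [1, 2]]
///       : memref<8x2x4xf32> into memref<8x8xf32>
///   %f = vector.shape_cast %vec : vector<2x4xf32> to vector<8xf32>
///   vector.transfer_write %f, %c[%c0, %c0] {in_bounds = [true]}
///       : vector<8xf32>, memref<8x8xf32>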
541 class FlattenContiguousRowMajorTransferWritePattern
542     : public OpRewritePattern<vector::TransferWriteOp> {
543   using OpRewritePattern::OpRewritePattern;
544 
545   LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
546                                 PatternRewriter &rewriter) const override {
547     auto loc = transferWriteOp.getLoc();
548     Value vector = transferWriteOp.getVector();
549     VectorType vectorType = cast<VectorType>(vector.getType());
550     Value source = transferWriteOp.getSource();
551     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
552     // Contiguity can only be checked for memref sources; bail out on tensors.
553     if (!sourceType)
554       return failure();
555     if (vectorType.getRank() <= 1)
556       // Already 0D/1D, nothing to do.
557       return failure();
558     if (!hasMatchingInnerContiguousShape(
559             sourceType,
560             vectorType.getShape().take_back(vectorType.getRank() - 1)))
561       return failure();
562     int64_t firstContiguousInnerDim =
563         sourceType.getRank() - vectorType.getRank();
564     // TODO: generalize this pattern, relax the requirements here.
565     if (transferWriteOp.hasOutOfBoundsDim())
566       return failure();
567     if (!transferWriteOp.getPermutationMap().isMinorIdentity())
568       return failure();
569     if (transferWriteOp.getMask())
570       return failure();
571     SmallVector<Value> collapsedIndices;
572     if (failed(checkAndCollapseInnerZeroIndices(transferWriteOp.getIndices(),
573                                                 firstContiguousInnerDim,
574                                                 collapsedIndices)))
575       return failure();
576     Value collapsedSource =
577         collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
578     MemRefType collapsedSourceType =
579         cast<MemRefType>(collapsedSource.getType());
580     int64_t collapsedRank = collapsedSourceType.getRank();
581     assert(collapsedRank == firstContiguousInnerDim + 1);
582     SmallVector<AffineExpr, 1> dimExprs{
583         getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
584     auto collapsedMap =
585         AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
586     VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
587                                                 vectorType.getElementType());
588     Value flatVector =
589         rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
590     vector::TransferWriteOp flatWrite =
591         rewriter.create<vector::TransferWriteOp>(
592             loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
593     flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
594     rewriter.eraseOp(transferWriteOp);
595     return success();
596   }
597 };
598 
599 /// Base class for `vector.extract/vector.extractelement(vector.transfer_read)`
600 /// to `memref.load` patterns. The `match` method is shared for both
601 /// `vector.extract` and `vector.extractelement`.
602 template <class VectorExtractOp>
603 class RewriteScalarExtractOfTransferReadBase
604     : public OpRewritePattern<VectorExtractOp> {
605   using Base = OpRewritePattern<VectorExtractOp>;
606 
607 public:
608   RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
609                                          PatternBenefit benefit,
610                                          bool allowMultipleUses)
611       : Base::OpRewritePattern(context, benefit),
612         allowMultipleUses(allowMultipleUses) {}
613 
614   LogicalResult match(VectorExtractOp extractOp) const override {
615     auto xferOp =
616         extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
617     if (!xferOp)
618       return failure();
619     // Check that we are extracting a scalar and not a sub-vector.
620     if (isa<VectorType>(extractOp.getResult().getType()))
621       return failure();
622     // If multiple uses are not allowed, check if xfer has a single use.
623     if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
624       return failure();
625     // If multiple uses are allowed, check if all the xfer uses are extract ops.
626     if (allowMultipleUses &&
627         !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
628           return isa<vector::ExtractOp, vector::ExtractElementOp>(
629               use.getOwner());
630         }))
631       return failure();
632     // Mask not supported.
633     if (xferOp.getMask())
634       return failure();
635     // Map not supported.
636     if (!xferOp.getPermutationMap().isMinorIdentity())
637       return failure();
638     // Cannot rewrite if the indices may be out of bounds.
639     if (xferOp.hasOutOfBoundsDim())
640       return failure();
641     return success();
642   }
643 
644 private:
645   bool allowMultipleUses;
646 };
647 
648 /// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
649 ///
650 /// All the users of the transfer op must be either `vector.extractelement` or
651 /// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
652 /// transfer ops with any number of users. Otherwise, rewrite only if the
653 /// extract op is the single user of the transfer op. Rewriting a single
654 /// vector load with multiple scalar loads may negatively affect performance.
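/// For illustration, a sketch of the rewrite (values and types are
/// hypothetical):
///
///   %v = vector.transfer_read %A[%i], %pad {in_bounds = [true]}
///       : memref<16xf32>, vector<4xf32>
///   %e = vector.extractelement %v[%j : index] : vector<4xf32>
///
/// becomes, roughly:
///
///   %idx = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%i, %j]
///   %e = memref.load %A[%idx] : memref<16xf32>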
655 class RewriteScalarExtractElementOfTransferRead
656     : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
657   using RewriteScalarExtractOfTransferReadBase::
658       RewriteScalarExtractOfTransferReadBase;
659 
660   void rewrite(vector::ExtractElementOp extractOp,
661                PatternRewriter &rewriter) const override {
662     // Construct scalar load.
663     auto loc = extractOp.getLoc();
664     auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
665     SmallVector<Value> newIndices(xferOp.getIndices().begin(),
666                                   xferOp.getIndices().end());
667     if (extractOp.getPosition()) {
668       AffineExpr sym0, sym1;
669       bindSymbols(extractOp.getContext(), sym0, sym1);
670       OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
671           rewriter, loc, sym0 + sym1,
672           {newIndices[newIndices.size() - 1], extractOp.getPosition()});
673       if (ofr.is<Value>()) {
674         newIndices[newIndices.size() - 1] = ofr.get<Value>();
675       } else {
676         newIndices[newIndices.size() - 1] =
677             rewriter.create<arith::ConstantIndexOp>(loc,
678                                                     *getConstantIntValue(ofr));
679       }
680     }
681     if (isa<MemRefType>(xferOp.getSource().getType())) {
682       rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
683                                                   newIndices);
684     } else {
685       rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
686           extractOp, xferOp.getSource(), newIndices);
687     }
688   }
689 };
690 
692 /// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
693 ///
694 /// All the users of the transfer op must be either `vector.extractelement` or
695 /// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
696 /// transfer ops with any number of users. Otherwise, rewrite only if the
697 /// extract op is the single user of the transfer op. Rewriting a single
698 /// vector load with multiple scalar loads may negatively affect performance.
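/// For illustration, a sketch of the rewrite (values and types are
/// hypothetical):
///
///   %v = vector.transfer_read %A[%i, %j], %pad {in_bounds = [true]}
///       : memref<8x8xf32>, vector<4xf32>
///   %e = vector.extract %v[2] : vector<4xf32>
///
/// becomes, roughly:
///
///   %idx = affine.apply affine_map<()[s0] -> (s0 + 2)>()[%j]
///   %e = memref.load %A[%i, %idx] : memref<8x8xf32>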
699 class RewriteScalarExtractOfTransferRead
700     : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
701   using RewriteScalarExtractOfTransferReadBase::
702       RewriteScalarExtractOfTransferReadBase;
703 
704   void rewrite(vector::ExtractOp extractOp,
705                PatternRewriter &rewriter) const override {
706     // Construct scalar load.
707     auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
708     SmallVector<Value> newIndices(xferOp.getIndices().begin(),
709                                   xferOp.getIndices().end());
710     for (const auto &it : llvm::enumerate(extractOp.getPosition())) {
711       int64_t offset = it.value();
712       int64_t idx =
713           newIndices.size() - extractOp.getPosition().size() + it.index();
714       OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
715           rewriter, extractOp.getLoc(),
716           rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
717       if (ofr.is<Value>()) {
718         newIndices[idx] = ofr.get<Value>();
719       } else {
720         newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
721             extractOp.getLoc(), *getConstantIntValue(ofr));
722       }
723     }
724     if (isa<MemRefType>(xferOp.getSource().getType())) {
725       rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
726                                                   newIndices);
727     } else {
728       rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
729           extractOp, xferOp.getSource(), newIndices);
730     }
731   }
732 };
733 
734 /// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
735 /// to memref.store.
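/// For illustration, a sketch of the rewrite (values and types are
/// hypothetical):
///
///   vector.transfer_write %v, %A[%i, %j] {in_bounds = [true, true]}
///       : vector<1x1xf32>, memref<8x8xf32>
///
/// becomes, roughly:
///
///   %s = vector.extract %v[0, 0] : vector<1x1xf32>
///   memref.store %s, %A[%i, %j] : memref<8x8xf32>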
736 class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
737   using OpRewritePattern::OpRewritePattern;
738 
739   LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
740                                 PatternRewriter &rewriter) const override {
741     // Must be a scalar write.
742     auto vecType = xferOp.getVectorType();
743     if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
744       return failure();
745     // Mask not supported.
746     if (xferOp.getMask())
747       return failure();
748     // Map not supported.
749     if (!xferOp.getPermutationMap().isMinorIdentity())
750       return failure();
751     // Extract the scalar value from the single-element vector.
752     Value scalar;
753     if (vecType.getRank() == 0) {
754       // vector.extract does not support vector<f32> etc., so use
755       // vector.extractelement instead.
756       scalar = rewriter.create<vector::ExtractElementOp>(xferOp.getLoc(),
757                                                          xferOp.getVector());
758     } else {
759       SmallVector<int64_t> pos(vecType.getRank(), 0);
760       scalar = rewriter.create<vector::ExtractOp>(xferOp.getLoc(),
761                                                   xferOp.getVector(), pos);
762     }
763     // Construct a scalar store.
764     if (isa<MemRefType>(xferOp.getSource().getType())) {
765       rewriter.replaceOpWithNewOp<memref::StoreOp>(
766           xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
767     } else {
768       rewriter.replaceOpWithNewOp<tensor::InsertOp>(
769           xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
770     }
771     return success();
772   }
773 };
774 
775 } // namespace
776 
777 void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
778                                      Operation *rootOp) {
779   TransferOptimization opt(rewriter, rootOp);
780   // Run store to load forwarding first since it can expose more dead store
781   // opportunities.
782   rootOp->walk([&](vector::TransferReadOp read) {
783     if (isa<MemRefType>(read.getShapedType()))
784       opt.storeToLoadForwarding(read);
785   });
786   opt.removeDeadOp();
787   rootOp->walk([&](vector::TransferWriteOp write) {
788     if (isa<MemRefType>(write.getShapedType()))
789       opt.deadStoreOp(write);
790   });
791   opt.removeDeadOp();
792 }
793 
794 void mlir::vector::populateScalarVectorTransferLoweringPatterns(
795     RewritePatternSet &patterns, PatternBenefit benefit,
796     bool allowMultipleUses) {
797   patterns.add<RewriteScalarExtractElementOfTransferRead,
798                RewriteScalarExtractOfTransferRead>(patterns.getContext(),
799                                                    benefit, allowMultipleUses);
800   patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
801 }
802 
803 void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
804     RewritePatternSet &patterns, PatternBenefit benefit) {
805   patterns
806       .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
807           patterns.getContext(), benefit);
808   populateShapeCastFoldingPatterns(patterns);
809 }
810 
811 void mlir::vector::populateFlattenVectorTransferPatterns(
812     RewritePatternSet &patterns, PatternBenefit benefit) {
813   patterns.add<FlattenContiguousRowMajorTransferReadPattern,
814                FlattenContiguousRowMajorTransferWritePattern>(
815       patterns.getContext(), benefit);
816   populateShapeCastFoldingPatterns(patterns, benefit);
817 }
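
// A minimal usage sketch of the populate functions above, driven by the greedy
// pattern rewrite driver. The function name is hypothetical, and this assumes
// "mlir/Transforms/GreedyPatternRewriteDriver.h" has been included.
static LogicalResult exampleSimplifyTransfers(Operation *rootOp) {
  RewritePatternSet patterns(rootOp->getContext());
  // Drop unit dims first so the flattening patterns see already-reduced shapes.
  vector::populateVectorTransferDropUnitDimsPatterns(patterns);
  vector::populateFlattenVectorTransferPatterns(patterns);
  return applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}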
818