1 //===- FoldMemRefAliasOps.cpp - Fold memref alias ops -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transformation pass folds loads/stores on memref aliasing ops
10 // (subview, expand_shape, collapse_shape) into loads/stores on the original memref.
11 //
12 //===----------------------------------------------------------------------===//
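
// An illustrative sketch (not taken from an actual test) of the kind of
// rewrite the patterns in this file perform: a load through a subview such as
//
//   %view = memref.subview %src[%o0, %o1] [4, 4] [1, 1]
//       : memref<128x128xf32> to memref<4x4xf32, strided<[128, 1], offset: ?>>
//   %v = memref.load %view[%i, %j]
//
// is rewritten to load directly from %src, conceptually at
// %src[%o0 + %i, %o1 + %j]; the patterns materialize the resolved indices
// with affine.apply ops rather than with inline arithmetic.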
13 
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Arith/Utils/Utils.h"
18 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
21 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
22 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
23 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
24 #include "mlir/Dialect/Utils/IndexingUtils.h"
25 #include "mlir/Dialect/Vector/IR/VectorOps.h"
26 #include "mlir/IR/AffineMap.h"
27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/SmallBitVector.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 #include "llvm/Support/Debug.h"
32 
33 #define DEBUG_TYPE "fold-memref-alias-ops"
34 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
35 
36 namespace mlir {
37 namespace memref {
38 #define GEN_PASS_DEF_FOLDMEMREFALIASOPS
39 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
40 } // namespace memref
41 } // namespace mlir
42 
43 using namespace mlir;
44 
45 //===----------------------------------------------------------------------===//
46 // Utility functions
47 //===----------------------------------------------------------------------===//
48 
49 /// Given the 'indices' of a load/store operation where the memref is the result
50 /// of an expand_shape op, returns the indices w.r.t. the source memref of the
51 /// expand_shape op. For example
52 ///
53 /// %0 = ... : memref<12x42xf32>
54 /// %1 = memref.expand_shape %0 [[0, 1], [2]]
55 ///    : memref<12x42xf32> into memref<2x6x42xf32>
56 /// %2 = load %1[%i1, %i2, %i3] : memref<2x6x42xf32>
57 ///
58 /// could be folded into
59 ///
60 /// %2 = load %0[6 * %i1 + %i2, %i3] :
61 ///          memref<12x42xf32>
62 static LogicalResult
63 resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
64                                 memref::ExpandShapeOp expandShapeOp,
65                                 ValueRange indices,
66                                 SmallVectorImpl<Value> &sourceIndices) {
67   // Record the rewriter context for constructing ops later.
68   MLIRContext *ctx = rewriter.getContext();
69 
70   // Capture expand_shape's input dimensions as `SmallVector<OpFoldResult>`.
71   // This is done to infer the output shape via
72   // `inferExpandShapeOutputShape`, which is in turn used for the suffix
73   // product calculation later.
74   SmallVector<OpFoldResult> srcShape;
75   MemRefType srcType = expandShapeOp.getSrcType();
76 
77   for (int64_t i = 0, e = srcType.getRank(); i < e; ++i) {
78     if (srcType.isDynamicDim(i)) {
79       srcShape.push_back(
80           rewriter.create<memref::DimOp>(loc, expandShapeOp.getSrc(), i)
81               .getResult());
82     } else {
83       srcShape.push_back(rewriter.getIndexAttr(srcType.getShape()[i]));
84     }
85   }
86 
87   auto outputShape = inferExpandShapeOutputShape(
88       rewriter, loc, expandShapeOp.getResultType(),
89       expandShapeOp.getReassociationIndices(), srcShape);
90   if (!outputShape.has_value())
91     return failure();
92 
93   // Traverse all reassociation groups to determine the appropriate indices
94   // corresponding to each one of them post op folding.
95   for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
96     assert(!groups.empty() && "association indices groups cannot be empty");
97     // Number of output dimensions that belong to the current reassociation
98     // group.
99     int64_t groupSize = groups.size();
100 
101     // Collect the sizes of the output dimensions used by this reassociation
102     // group for the suffix product calculation.
103     SmallVector<OpFoldResult> sizesVal(groupSize);
104     for (int64_t i = 0; i < groupSize; ++i) {
105       sizesVal[i] = (*outputShape)[groups[i]];
106     }
107 
108     // Calculate suffix product of relevant output dimension sizes.
109     SmallVector<OpFoldResult> suffixProduct =
110         memref::computeSuffixProductIRBlock(loc, rewriter, sizesVal);
111 
112     // Create affine expression variables for dimensions and symbols in the
113     // newly constructed affine map.
114     SmallVector<AffineExpr> dims(groupSize), symbols(groupSize);
115     bindDimsList<AffineExpr>(ctx, dims);
116     bindSymbolsList<AffineExpr>(ctx, symbols);
117 
118     // Linearize the bound dimensions and symbols to construct the resulting
119     // affine expression for this index.
120     AffineExpr srcIndexExpr = linearize(ctx, dims, symbols);
121 
122     // Record the load index corresponding to each dimension in the
123     // reassociation group. These are later supplied as operands to the affine
124     // map used for calculating the relevant index after op folding.
125     SmallVector<OpFoldResult> dynamicIndices(groupSize);
126     for (int64_t i = 0; i < groupSize; i++)
127       dynamicIndices[i] = indices[groups[i]];
128 
129     // Supply suffix product results followed by load op indices as operands
130     // to the map.
131     SmallVector<OpFoldResult> mapOperands;
132     llvm::append_range(mapOperands, suffixProduct);
133     llvm::append_range(mapOperands, dynamicIndices);
134 
135     // Creating a maximally folded and composed affine.apply op composes
136     // better with other transformations without interleaving
137     // canonicalization passes.
138     OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
139         rewriter, loc,
140         AffineMap::get(/*numDims=*/groupSize,
141                        /*numSymbols=*/groupSize, /*expression=*/srcIndexExpr),
142         mapOperands);
143 
144     // Push the index value corresponding to this reassociation group after
145     // folding the op.
146     sourceIndices.push_back(
147         getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
148   }
149   return success();
150 }
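
// A small worked example of the linearization above, using the shapes from
// the doc comment (a sketch, not generated output): for the reassociation
// group [0, 1] of memref<2x6x42xf32>, the output sizes are [2, 6] and their
// suffix product is [6, 1]. Linearizing the group's load indices (%i1, %i2)
// against that basis yields 6 * %i1 + %i2, which becomes the index into the
// corresponding dimension of the memref<12x42xf32> source.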
151 
152 /// Given the 'indices' of a load/store operation where the memref is the result
153 /// of a collapse_shape op, returns the indices w.r.t. the source memref of
154 /// the collapse_shape op. For example
155 ///
156 /// %0 = ... : memref<2x6x42xf32>
157 /// %1 = memref.collapse_shape %0 [[0, 1], [2]]
158 ///    : memref<2x6x42xf32> into memref<12x42xf32>
159 /// %2 = load %1[%i1, %i2] : memref<12x42xf32>
160 ///
161 /// could be folded into
162 ///
163 /// %2 = load %0[%i1 / 6, %i1 % 6, %i2] :
164 ///          memref<2x6x42xf32>
165 static LogicalResult
166 resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
167                                   memref::CollapseShapeOp collapseShapeOp,
168                                   ValueRange indices,
169                                   SmallVectorImpl<Value> &sourceIndices) {
170   int64_t cnt = 0;
171   SmallVector<Value> tmp(indices.size());
172   SmallVector<OpFoldResult> dynamicIndices;
173   for (ArrayRef<int64_t> groups : collapseShapeOp.getReassociationIndices()) {
174     assert(!groups.empty() && "association indices groups cannot be empty");
175     dynamicIndices.push_back(indices[cnt++]);
176     int64_t groupSize = groups.size();
177 
178     // Calculate the suffix product of all collapse op source dimension sizes
179     // except the outermost (most major) one of each group.
180     // The outermost source dimension may be dynamic, but all others must be
181     // statically known.
182     SmallVector<int64_t> sizes(groupSize, 1);
183     for (int64_t i = 1; i < groupSize; ++i) {
184       sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]);
185       if (sizes[i] == ShapedType::kDynamic)
186         return failure();
187     }
188     SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
189 
190     // Derive the index values along all dimensions of the source corresponding
191     // to the index w.r.t. the collapse_shape op's output.
192     auto d0 = rewriter.getAffineDimExpr(0);
193     SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, suffixProduct);
194 
195     // Construct the AffineApplyOp for each delinearizingExpr.
196     for (int64_t i = 0; i < groupSize; i++) {
197       OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
198           rewriter, loc,
199           AffineMap::get(/*numDims=*/1, /*numSymbols=*/0,
200                          delinearizingExprs[i]),
201           dynamicIndices);
202       sourceIndices.push_back(
203           getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
204     }
205     dynamicIndices.clear();
206   }
207   if (collapseShapeOp.getReassociationIndices().empty()) {
208     auto zeroAffineMap = rewriter.getConstantAffineMap(0);
209     int64_t srcRank =
210         cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
211     for (int64_t i = 0; i < srcRank; i++) {
212       OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
213           rewriter, loc, zeroAffineMap, dynamicIndices);
214       sourceIndices.push_back(
215           getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
216     }
217   }
218   return success();
219 }
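
// A small worked example of the delinearization above, using the shapes from
// the doc comment (a sketch, not generated output): for the group [0, 1]
// collapsing memref<2x6x42xf32> into memref<12x42xf32>, the per-group sizes
// are [1, 6] (the outermost size is left as 1) and their suffix product is
// [6, 1]. Delinearizing the collapsed index %i1 against that basis yields
// (%i1 floordiv 6, %i1 mod 6), matching the indices in the example above.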
220 
221 /// Helpers to access the memref operand for each op.
222 template <typename LoadOrStoreOpTy>
223 static Value getMemRefOperand(LoadOrStoreOpTy op) {
224   return op.getMemref();
225 }
226 
227 static Value getMemRefOperand(vector::TransferReadOp op) {
228   return op.getSource();
229 }
230 
231 static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
232   return op.getSrcMemref();
233 }
234 
235 static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); }
236 
237 static Value getMemRefOperand(vector::StoreOp op) { return op.getBase(); }
238 
239 static Value getMemRefOperand(vector::MaskedLoadOp op) { return op.getBase(); }
240 
241 static Value getMemRefOperand(vector::MaskedStoreOp op) { return op.getBase(); }
242 
243 static Value getMemRefOperand(vector::TransferWriteOp op) {
244   return op.getSource();
245 }
246 
247 static Value getMemRefOperand(gpu::SubgroupMmaLoadMatrixOp op) {
248   return op.getSrcMemref();
249 }
250 
251 static Value getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) {
252   return op.getDstMemref();
253 }
254 
255 //===----------------------------------------------------------------------===//
256 // Patterns
257 //===----------------------------------------------------------------------===//
258 
259 namespace {
260 /// Merges subview operation with load/transferRead operation.
261 template <typename OpTy>
262 class LoadOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
263 public:
264   using OpRewritePattern<OpTy>::OpRewritePattern;
265 
266   LogicalResult matchAndRewrite(OpTy loadOp,
267                                 PatternRewriter &rewriter) const override;
268 };
269 
270 /// Merges expand_shape operation with load/transferRead operation.
271 template <typename OpTy>
272 class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
273 public:
274   using OpRewritePattern<OpTy>::OpRewritePattern;
275 
276   LogicalResult matchAndRewrite(OpTy loadOp,
277                                 PatternRewriter &rewriter) const override;
278 };
279 
280 /// Merges collapse_shape operation with load/transferRead operation.
281 template <typename OpTy>
282 class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
283 public:
284   using OpRewritePattern<OpTy>::OpRewritePattern;
285 
286   LogicalResult matchAndRewrite(OpTy loadOp,
287                                 PatternRewriter &rewriter) const override;
288 };
289 
290 /// Merges subview operation with store/transferWriteOp operation.
291 template <typename OpTy>
292 class StoreOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
293 public:
294   using OpRewritePattern<OpTy>::OpRewritePattern;
295 
296   LogicalResult matchAndRewrite(OpTy storeOp,
297                                 PatternRewriter &rewriter) const override;
298 };
299 
300 /// Merges expand_shape operation with store/transferWriteOp operation.
301 template <typename OpTy>
302 class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
303 public:
304   using OpRewritePattern<OpTy>::OpRewritePattern;
305 
306   LogicalResult matchAndRewrite(OpTy storeOp,
307                                 PatternRewriter &rewriter) const override;
308 };
309 
310 /// Merges collapse_shape operation with store/transferWriteOp operation.
311 template <typename OpTy>
312 class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
313 public:
314   using OpRewritePattern<OpTy>::OpRewritePattern;
315 
316   LogicalResult matchAndRewrite(OpTy storeOp,
317                                 PatternRewriter &rewriter) const override;
318 };
319 
320 /// Folds subview(subview(x)) to a single subview(x).
321 class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
322 public:
323   using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;
324 
325   LogicalResult matchAndRewrite(memref::SubViewOp subView,
326                                 PatternRewriter &rewriter) const override {
327     auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
328     if (!srcSubView)
329       return failure();
330 
331     // TODO: relax unit stride assumption.
332     if (!subView.hasUnitStride()) {
333       return rewriter.notifyMatchFailure(subView, "requires unit strides");
334     }
335     if (!srcSubView.hasUnitStride()) {
336       return rewriter.notifyMatchFailure(srcSubView, "requires unit strides");
337     }
338 
339     // Resolve sizes according to dropped dims.
340     SmallVector<OpFoldResult> resolvedSizes;
341     llvm::SmallBitVector srcDroppedDims = srcSubView.getDroppedDims();
342     affine::resolveSizesIntoOpWithSizes(srcSubView.getMixedSizes(),
343                                         subView.getMixedSizes(), srcDroppedDims,
344                                         resolvedSizes);
345 
346     // Resolve offsets according to source offsets and strides.
347     SmallVector<Value> resolvedOffsets;
348     affine::resolveIndicesIntoOpWithOffsetsAndStrides(
349         rewriter, subView.getLoc(), srcSubView.getMixedOffsets(),
350         srcSubView.getMixedStrides(), srcDroppedDims, subView.getMixedOffsets(),
351         resolvedOffsets);
352 
353     // Replace original op.
354     rewriter.replaceOpWithNewOp<memref::SubViewOp>(
355         subView, subView.getType(), srcSubView.getSource(),
356         getAsOpFoldResult(resolvedOffsets), resolvedSizes,
357         srcSubView.getMixedStrides());
358 
359     return success();
360   }
361 };
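
// A sketch of the subview-of-subview composition with illustrative static
// values (unit strides, as required above):
//
//   %a = memref.subview %src[4, 8] [16, 16] [1, 1] : memref<64x64xf32> to ...
//   %b = memref.subview %a[2, 3] [8, 8] [1, 1] : ... to ...
//
// folds to a single subview of %src with offsets [6, 11] (offsets compose
// additively under unit strides) and the sizes of %b, i.e. conceptually
// memref.subview %src[6, 11] [8, 8] [1, 1].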
362 
363 /// Folds subviews feeding an nvgpu.device_async_copy into the copy itself.
364 /// This pattern folds subviews on both the src and dst memrefs of the copy.
365 class NVGPUAsyncCopyOpSubViewOpFolder final
366     : public OpRewritePattern<nvgpu::DeviceAsyncCopyOp> {
367 public:
368   using OpRewritePattern<nvgpu::DeviceAsyncCopyOp>::OpRewritePattern;
369 
370   LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
371                                 PatternRewriter &rewriter) const override;
372 };
373 } // namespace
374 
375 static SmallVector<Value>
376 calculateExpandedAccessIndices(AffineMap affineMap,
377                                const SmallVector<Value> &indices, Location loc,
378                                PatternRewriter &rewriter) {
379   SmallVector<OpFoldResult> indicesOfr(llvm::to_vector(
380       llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; })));
381   SmallVector<Value> expandedIndices;
382   for (unsigned i = 0, e = affineMap.getNumResults(); i < e; i++) {
383     OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
384         rewriter, loc, affineMap.getSubMap({i}), indicesOfr);
385     expandedIndices.push_back(
386         getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
387   }
388   return expandedIndices;
389 }
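
// For instance, with a hypothetical affine map (d0, d1) -> (d0 + 4, d1) and
// operands (%i, %j), the expanded access indices computed here are
// [%i + 4, %j], each materialized as a maximally folded affine.apply over the
// original operands.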
390 
391 template <typename XferOp>
392 static LogicalResult
393 preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
394                                memref::SubViewOp subviewOp) {
395   static_assert(
396       llvm::is_one_of<XferOp, vector::TransferReadOp, vector::TransferWriteOp>::value,
397       "must be a vector transfer op");
398   if (xferOp.hasOutOfBoundsDim())
399     return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
400   if (!subviewOp.hasUnitStride()) {
401     return rewriter.notifyMatchFailure(
402         xferOp, "non-1 stride subview, need to track strides in folded memref");
403   }
404   return success();
405 }
406 
407 static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
408                                                 Operation *op,
409                                                 memref::SubViewOp subviewOp) {
410   return success();
411 }
412 
413 static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
414                                                 vector::TransferReadOp readOp,
415                                                 memref::SubViewOp subviewOp) {
416   return preconditionsFoldSubViewOpImpl(rewriter, readOp, subviewOp);
417 }
418 
419 static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
420                                                 vector::TransferWriteOp writeOp,
421                                                 memref::SubViewOp subviewOp) {
422   return preconditionsFoldSubViewOpImpl(rewriter, writeOp, subviewOp);
423 }
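
// For example, a subview with a non-unit stride such as
//   memref.subview %src[0, 0] [8, 8] [2, 1] : memref<64x64xf32> to ...
// is rejected by the preconditions above for vector transfers: folding it
// would require tracking the subview's strides in the resulting transfer,
// which these patterns do not attempt (illustrative values only).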
424 
425 template <typename OpTy>
426 LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
427     OpTy loadOp, PatternRewriter &rewriter) const {
428   auto subViewOp =
429       getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
430 
431   if (!subViewOp)
432     return rewriter.notifyMatchFailure(loadOp, "not a subview producer");
433 
434   LogicalResult preconditionResult =
435       preconditionsFoldSubViewOp(rewriter, loadOp, subViewOp);
436   if (failed(preconditionResult))
437     return preconditionResult;
438 
439   SmallVector<Value> indices(loadOp.getIndices().begin(),
440                              loadOp.getIndices().end());
441   // For affine ops, we need to apply the affine map to the operands to get
442   // the "actual" indices.
443   if (auto affineLoadOp =
444           dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
445     AffineMap affineMap = affineLoadOp.getAffineMap();
446     auto expandedIndices = calculateExpandedAccessIndices(
447         affineMap, indices, loadOp.getLoc(), rewriter);
448     indices.assign(expandedIndices.begin(), expandedIndices.end());
449   }
450   SmallVector<Value> sourceIndices;
451   affine::resolveIndicesIntoOpWithOffsetsAndStrides(
452       rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
453       subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
454       sourceIndices);
455 
456   llvm::TypeSwitch<Operation *, void>(loadOp)
457       .Case([&](affine::AffineLoadOp op) {
458         rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
459             loadOp, subViewOp.getSource(), sourceIndices);
460       })
461       .Case([&](memref::LoadOp op) {
462         rewriter.replaceOpWithNewOp<memref::LoadOp>(
463             loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
464       })
465       .Case([&](vector::LoadOp op) {
466         rewriter.replaceOpWithNewOp<vector::LoadOp>(
467             op, op.getType(), subViewOp.getSource(), sourceIndices);
468       })
469       .Case([&](vector::MaskedLoadOp op) {
470         rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
471             op, op.getType(), subViewOp.getSource(), sourceIndices,
472             op.getMask(), op.getPassThru());
473       })
474       .Case([&](vector::TransferReadOp op) {
475         rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
476             op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
477             AffineMapAttr::get(expandDimsToRank(
478                 op.getPermutationMap(), subViewOp.getSourceType().getRank(),
479                 subViewOp.getDroppedDims())),
480             op.getPadding(), op.getMask(), op.getInBoundsAttr());
481       })
482       .Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
483         rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
484             op, op.getType(), subViewOp.getSource(), sourceIndices,
485             op.getLeadDimension(), op.getTransposeAttr());
486       })
487       .Case([&](nvgpu::LdMatrixOp op) {
488         rewriter.replaceOpWithNewOp<nvgpu::LdMatrixOp>(
489             op, op.getType(), subViewOp.getSource(), sourceIndices,
490             op.getTranspose(), op.getNumTiles());
491       })
492       .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
493   return success();
494 }
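
// A note on the transfer_read case above: when the subview drops dimensions,
// the permutation map is expanded back to the rank of the original source.
// As a sketch with hypothetical shapes, for a rank-2 subview of a rank-3
// source where dimension 1 was dropped, an identity map
// (d0, d1) -> (d0, d1) is rewritten to (d0, d1, d2) -> (d0, d2) so that it
// indexes the source dimensions directly.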
495 
496 template <typename OpTy>
497 LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
498     OpTy loadOp, PatternRewriter &rewriter) const {
499   auto expandShapeOp =
500       getMemRefOperand(loadOp).template getDefiningOp<memref::ExpandShapeOp>();
501 
502   if (!expandShapeOp)
503     return failure();
504 
505   SmallVector<Value> indices(loadOp.getIndices().begin(),
506                              loadOp.getIndices().end());
507   // For affine ops, we need to apply the affine map to the operands to get
508   // the "actual" indices.
509   if (auto affineLoadOp =
510           dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
511     AffineMap affineMap = affineLoadOp.getAffineMap();
512     auto expandedIndices = calculateExpandedAccessIndices(
513         affineMap, indices, loadOp.getLoc(), rewriter);
514     indices.assign(expandedIndices.begin(), expandedIndices.end());
515   }
516   SmallVector<Value> sourceIndices;
517   if (failed(resolveSourceIndicesExpandShape(
518           loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
519     return failure();
520   llvm::TypeSwitch<Operation *, void>(loadOp)
521       .Case([&](affine::AffineLoadOp op) {
522         rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
523             loadOp, expandShapeOp.getViewSource(), sourceIndices);
524       })
525       .Case([&](memref::LoadOp op) {
526         rewriter.replaceOpWithNewOp<memref::LoadOp>(
527             loadOp, expandShapeOp.getViewSource(), sourceIndices,
528             op.getNontemporal());
529       })
530       .Case([&](vector::LoadOp op) {
531         rewriter.replaceOpWithNewOp<vector::LoadOp>(
532             op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
533             op.getNontemporal());
534       })
535       .Case([&](vector::MaskedLoadOp op) {
536         rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
537             op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
538             op.getMask(), op.getPassThru());
539       })
540       .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
541   return success();
542 }
543 
544 template <typename OpTy>
545 LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
546     OpTy loadOp, PatternRewriter &rewriter) const {
547   auto collapseShapeOp = getMemRefOperand(loadOp)
548                              .template getDefiningOp<memref::CollapseShapeOp>();
549 
550   if (!collapseShapeOp)
551     return failure();
552 
553   SmallVector<Value> indices(loadOp.getIndices().begin(),
554                              loadOp.getIndices().end());
555   // For affine ops, we need to apply the affine map to the operands to get
556   // the "actual" indices.
557   if (auto affineLoadOp =
558           dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
559     AffineMap affineMap = affineLoadOp.getAffineMap();
560     auto expandedIndices = calculateExpandedAccessIndices(
561         affineMap, indices, loadOp.getLoc(), rewriter);
562     indices.assign(expandedIndices.begin(), expandedIndices.end());
563   }
564   SmallVector<Value> sourceIndices;
565   if (failed(resolveSourceIndicesCollapseShape(
566           loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
567     return failure();
568   llvm::TypeSwitch<Operation *, void>(loadOp)
569       .Case([&](affine::AffineLoadOp op) {
570         rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
571             loadOp, collapseShapeOp.getViewSource(), sourceIndices);
572       })
573       .Case([&](memref::LoadOp op) {
574         rewriter.replaceOpWithNewOp<memref::LoadOp>(
575             loadOp, collapseShapeOp.getViewSource(), sourceIndices,
576             op.getNontemporal());
577       })
578       .Case([&](vector::LoadOp op) {
579         rewriter.replaceOpWithNewOp<vector::LoadOp>(
580             op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
581             op.getNontemporal());
582       })
583       .Case([&](vector::MaskedLoadOp op) {
584         rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
585             op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
586             op.getMask(), op.getPassThru());
587       })
588       .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
589   return success();
590 }
591 
592 template <typename OpTy>
593 LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
594     OpTy storeOp, PatternRewriter &rewriter) const {
595   auto subViewOp =
596       getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
597 
598   if (!subViewOp)
599     return rewriter.notifyMatchFailure(storeOp, "not a subview producer");
600 
601   LogicalResult preconditionResult =
602       preconditionsFoldSubViewOp(rewriter, storeOp, subViewOp);
603   if (failed(preconditionResult))
604     return preconditionResult;
605 
606   SmallVector<Value> indices(storeOp.getIndices().begin(),
607                              storeOp.getIndices().end());
608   // For affine ops, we need to apply the affine map to the operands to get
609   // the "actual" indices.
610   if (auto affineStoreOp =
611           dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
612     AffineMap affineMap = affineStoreOp.getAffineMap();
613     auto expandedIndices = calculateExpandedAccessIndices(
614         affineMap, indices, storeOp.getLoc(), rewriter);
615     indices.assign(expandedIndices.begin(), expandedIndices.end());
616   }
617   SmallVector<Value> sourceIndices;
618   affine::resolveIndicesIntoOpWithOffsetsAndStrides(
619       rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
620       subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
621       sourceIndices);
622 
623   llvm::TypeSwitch<Operation *, void>(storeOp)
624       .Case([&](affine::AffineStoreOp op) {
625         rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
626             op, op.getValue(), subViewOp.getSource(), sourceIndices);
627       })
628       .Case([&](memref::StoreOp op) {
629         rewriter.replaceOpWithNewOp<memref::StoreOp>(
630             op, op.getValue(), subViewOp.getSource(), sourceIndices,
631             op.getNontemporal());
632       })
633       .Case([&](vector::TransferWriteOp op) {
634         rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
635             op, op.getValue(), subViewOp.getSource(), sourceIndices,
636             AffineMapAttr::get(expandDimsToRank(
637                 op.getPermutationMap(), subViewOp.getSourceType().getRank(),
638                 subViewOp.getDroppedDims())),
639             op.getMask(), op.getInBoundsAttr());
640       })
641       .Case([&](vector::StoreOp op) {
642         rewriter.replaceOpWithNewOp<vector::StoreOp>(
643             op, op.getValueToStore(), subViewOp.getSource(), sourceIndices);
644       })
645       .Case([&](vector::MaskedStoreOp op) {
646         rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
647             op, subViewOp.getSource(), sourceIndices, op.getMask(),
648             op.getValueToStore());
649       })
650       .Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
651         rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
652             op, op.getSrc(), subViewOp.getSource(), sourceIndices,
653             op.getLeadDimension(), op.getTransposeAttr());
654       })
655       .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
656   return success();
657 }
658 
659 template <typename OpTy>
660 LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
661     OpTy storeOp, PatternRewriter &rewriter) const {
662   auto expandShapeOp =
663       getMemRefOperand(storeOp).template getDefiningOp<memref::ExpandShapeOp>();
664 
665   if (!expandShapeOp)
666     return failure();
667 
668   SmallVector<Value> indices(storeOp.getIndices().begin(),
669                              storeOp.getIndices().end());
670   // For affine ops, we need to apply the affine map to the operands to get
671   // the "actual" indices.
672   if (auto affineStoreOp =
673           dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
674     AffineMap affineMap = affineStoreOp.getAffineMap();
675     auto expandedIndices = calculateExpandedAccessIndices(
676         affineMap, indices, storeOp.getLoc(), rewriter);
677     indices.assign(expandedIndices.begin(), expandedIndices.end());
678   }
679   SmallVector<Value> sourceIndices;
680   if (failed(resolveSourceIndicesExpandShape(
681           storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
682     return failure();
683   llvm::TypeSwitch<Operation *, void>(storeOp)
684       .Case([&](affine::AffineStoreOp op) {
685         rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
686             storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
687             sourceIndices);
688       })
689       .Case([&](memref::StoreOp op) {
690         rewriter.replaceOpWithNewOp<memref::StoreOp>(
691             storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
692             sourceIndices, op.getNontemporal());
693       })
694       .Case([&](vector::StoreOp op) {
695         rewriter.replaceOpWithNewOp<vector::StoreOp>(
696             op, op.getValueToStore(), expandShapeOp.getViewSource(),
697             sourceIndices, op.getNontemporal());
698       })
699       .Case([&](vector::MaskedStoreOp op) {
700         rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
701             op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
702             op.getValueToStore());
703       })
704       .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
705   return success();
706 }
707 
708 template <typename OpTy>
709 LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
710     OpTy storeOp, PatternRewriter &rewriter) const {
711   auto collapseShapeOp = getMemRefOperand(storeOp)
712                              .template getDefiningOp<memref::CollapseShapeOp>();
713 
714   if (!collapseShapeOp)
715     return failure();
716 
717   SmallVector<Value> indices(storeOp.getIndices().begin(),
718                              storeOp.getIndices().end());
719   // For affine ops, we need to apply the affine map to the operands to get
720   // the "actual" indices.
721   if (auto affineStoreOp =
722           dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
723     AffineMap affineMap = affineStoreOp.getAffineMap();
724     auto expandedIndices = calculateExpandedAccessIndices(
725         affineMap, indices, storeOp.getLoc(), rewriter);
726     indices.assign(expandedIndices.begin(), expandedIndices.end());
727   }
728   SmallVector<Value> sourceIndices;
729   if (failed(resolveSourceIndicesCollapseShape(
730           storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
731     return failure();
732   llvm::TypeSwitch<Operation *, void>(storeOp)
733       .Case([&](affine::AffineStoreOp op) {
734         rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
735             storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
736             sourceIndices);
737       })
738       .Case([&](memref::StoreOp op) {
739         rewriter.replaceOpWithNewOp<memref::StoreOp>(
740             storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
741             sourceIndices, op.getNontemporal());
742       })
743       .Case([&](vector::StoreOp op) {
744         rewriter.replaceOpWithNewOp<vector::StoreOp>(
745             op, op.getValueToStore(), collapseShapeOp.getViewSource(),
746             sourceIndices, op.getNontemporal());
747       })
748       .Case([&](vector::MaskedStoreOp op) {
749         rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
750             op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
751             op.getValueToStore());
752       })
753       .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
754   return success();
755 }
756 
757 LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
758     nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const {
759 
760   LLVM_DEBUG(DBGS() << "copyOp       : " << copyOp << "\n");
761 
762   auto srcSubViewOp =
763       copyOp.getSrc().template getDefiningOp<memref::SubViewOp>();
764   auto dstSubViewOp =
765       copyOp.getDst().template getDefiningOp<memref::SubViewOp>();
766 
767   if (!(srcSubViewOp || dstSubViewOp))
768     return rewriter.notifyMatchFailure(copyOp, "does not use subview ops for "
769                                                "source or destination");
770 
771   // If the source is a subview, we need to resolve the indices.
772   SmallVector<Value> srcindices(copyOp.getSrcIndices().begin(),
773                                 copyOp.getSrcIndices().end());
774   SmallVector<Value> foldedSrcIndices(srcindices);
775 
776   if (srcSubViewOp) {
777     LLVM_DEBUG(DBGS() << "srcSubViewOp : " << srcSubViewOp << "\n");
778     affine::resolveIndicesIntoOpWithOffsetsAndStrides(
779         rewriter, copyOp.getLoc(), srcSubViewOp.getMixedOffsets(),
780         srcSubViewOp.getMixedStrides(), srcSubViewOp.getDroppedDims(),
781         srcindices, foldedSrcIndices);
782   }
783 
784   // If the destination is a subview, we need to resolve the indices.
785   SmallVector<Value> dstindices(copyOp.getDstIndices().begin(),
786                                 copyOp.getDstIndices().end());
787   SmallVector<Value> foldedDstIndices(dstindices);
788 
789   if (dstSubViewOp) {
790     LLVM_DEBUG(DBGS() << "dstSubViewOp : " << dstSubViewOp << "\n");
791     affine::resolveIndicesIntoOpWithOffsetsAndStrides(
792         rewriter, copyOp.getLoc(), dstSubViewOp.getMixedOffsets(),
793         dstSubViewOp.getMixedStrides(), dstSubViewOp.getDroppedDims(),
794         dstindices, foldedDstIndices);
795   }
796 
797   // Replace the copy op with a new copy op that uses the sources of the
798   // folded subviews (or the original operands when no subview is present).
799   rewriter.replaceOpWithNewOp<nvgpu::DeviceAsyncCopyOp>(
800       copyOp, nvgpu::DeviceAsyncTokenType::get(copyOp.getContext()),
801       (dstSubViewOp ? dstSubViewOp.getSource() : copyOp.getDst()),
802       foldedDstIndices,
803       (srcSubViewOp ? srcSubViewOp.getSource() : copyOp.getSrc()),
804       foldedSrcIndices, copyOp.getDstElements(), copyOp.getSrcElements(),
805       copyOp.getBypassL1Attr());
806 
807   return success();
808 }
809 
810 void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
811   patterns.add<LoadOpOfSubViewOpFolder<affine::AffineLoadOp>,
812                LoadOpOfSubViewOpFolder<memref::LoadOp>,
813                LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
814                LoadOpOfSubViewOpFolder<vector::LoadOp>,
815                LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
816                LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
817                LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
818                StoreOpOfSubViewOpFolder<affine::AffineStoreOp>,
819                StoreOpOfSubViewOpFolder<memref::StoreOp>,
820                StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
821                StoreOpOfSubViewOpFolder<vector::StoreOp>,
822                StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
823                StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
824                LoadOpOfExpandShapeOpFolder<affine::AffineLoadOp>,
825                LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
826                LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
827                LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
828                StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>,
829                StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
830                StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
831                StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
832                LoadOpOfCollapseShapeOpFolder<affine::AffineLoadOp>,
833                LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
834                LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
835                LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
836                StoreOpOfCollapseShapeOpFolder<affine::AffineStoreOp>,
837                StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
838                StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
839                StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
840                SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
841       patterns.getContext());
842 }
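
// A minimal usage sketch for clients that want to apply these patterns from
// their own pass, mirroring FoldMemRefAliasOpsPass below (`someOp` is a
// placeholder for the operation being transformed):
//
//   RewritePatternSet patterns(someOp->getContext());
//   memref::populateFoldMemRefAliasOpPatterns(patterns);
//   (void)applyPatternsGreedily(someOp, std::move(patterns));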
843 
844 //===----------------------------------------------------------------------===//
845 // Pass registration
846 //===----------------------------------------------------------------------===//
847 
848 namespace {
849 
850 struct FoldMemRefAliasOpsPass final
851     : public memref::impl::FoldMemRefAliasOpsBase<FoldMemRefAliasOpsPass> {
852   void runOnOperation() override;
853 };
854 
855 } // namespace
856 
857 void FoldMemRefAliasOpsPass::runOnOperation() {
858   RewritePatternSet patterns(&getContext());
859   memref::populateFoldMemRefAliasOpPatterns(patterns);
860   (void)applyPatternsGreedily(getOperation(), std::move(patterns));
861 }
862 
863 std::unique_ptr<Pass> memref::createFoldMemRefAliasOpsPass() {
864   return std::make_unique<FoldMemRefAliasOpsPass>();
865 }
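
// The pass can also be exercised from the command line, e.g. (assuming the
// registered pass argument matches the name used in this file):
//
//   mlir-opt --fold-memref-alias-ops input.mlir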
866