//===- FoldMemRefAliasOps.cpp - Fold memref alias ops -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass folds loading/storing from/to memref aliasing ops
// (subview, expand_shape, collapse_shape) into loading/storing from/to the
// original memref.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "fold-memref-alias-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_FOLDMEMREFALIASOPS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Given the 'indices' of a load/store operation where the memref is a result
/// of an expand_shape op, returns the indices w.r.t. the source memref of the
/// expand_shape op. For example
///
/// %0 = ... : memref<12x42xf32>
/// %1 = memref.expand_shape %0 [[0, 1], [2]]
///    : memref<12x42xf32> into memref<2x6x42xf32>
/// %2 = load %1[%i1, %i2, %i3] : memref<2x6x42xf32>
///
/// could be folded into
///
/// %2 = load %0[6 * %i1 + %i2, %i3] :
///          memref<12x42xf32>
static LogicalResult
resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
                                memref::ExpandShapeOp expandShapeOp,
                                ValueRange indices,
                                SmallVectorImpl<Value> &sourceIndices) {
  // Record the rewriter context for constructing ops later.
  MLIRContext *ctx = rewriter.getContext();

  // Capture expand_shape's input dimensions as `SmallVector<OpFoldResult>`.
  // These are used to infer the output shape via
  // `inferExpandShapeOutputShape`, which in turn feeds the suffix product
  // calculation below.
  SmallVector<OpFoldResult> srcShape;
  MemRefType srcType = expandShapeOp.getSrcType();

  for (int64_t i = 0, e = srcType.getRank(); i < e; ++i) {
    if (srcType.isDynamicDim(i)) {
      srcShape.push_back(
          rewriter.create<memref::DimOp>(loc, expandShapeOp.getSrc(), i)
              .getResult());
    } else {
      srcShape.push_back(rewriter.getIndexAttr(srcType.getShape()[i]));
    }
  }

  auto outputShape = inferExpandShapeOutputShape(
      rewriter, loc, expandShapeOp.getResultType(),
      expandShapeOp.getReassociationIndices(), srcShape);
  if (!outputShape.has_value())
    return failure();

  // Traverse all reassociation groups to determine the appropriate source
  // index corresponding to each group after folding.
  for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
    assert(!groups.empty() && "association indices groups cannot be empty");
    // Number of output dimensions in the current reassociation group.
    int64_t groupSize = groups.size();

    // Collect the sizes of the output dimensions in this reassociation group
    // for the suffix product calculation.
    SmallVector<OpFoldResult> sizesVal(groupSize);
    for (int64_t i = 0; i < groupSize; ++i) {
      sizesVal[i] = (*outputShape)[groups[i]];
    }

    // Calculate suffix product of relevant output dimension sizes.
    SmallVector<OpFoldResult> suffixProduct =
        memref::computeSuffixProductIRBlock(loc, rewriter, sizesVal);

    // Create affine expression variables for dimensions and symbols in the
    // newly constructed affine map.
    SmallVector<AffineExpr> dims(groupSize), symbols(groupSize);
    bindDimsList<AffineExpr>(ctx, dims);
    bindSymbolsList<AffineExpr>(ctx, symbols);

    // Linearize the bound dimensions and symbols to construct the resulting
    // affine expression for this index.
    AffineExpr srcIndexExpr = linearize(ctx, dims, symbols);

    // Record the load index corresponding to each dimension in the
    // reassociation group. These are later supplied as operands to the affine
    // map used for calculating the relevant index after folding.
    SmallVector<OpFoldResult> dynamicIndices(groupSize);
    for (int64_t i = 0; i < groupSize; i++)
      dynamicIndices[i] = indices[groups[i]];

    // Supply suffix product results followed by load op indices as operands
    // to the map.
    SmallVector<OpFoldResult> mapOperands;
    llvm::append_range(mapOperands, suffixProduct);
    llvm::append_range(mapOperands, dynamicIndices);

    // Creating a maximally folded and composed affine.apply composes better
    // with other transformations without interleaving canonicalization
    // passes.
    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
        rewriter, loc,
        AffineMap::get(/*numDims=*/groupSize,
                       /*numSymbols=*/groupSize, /*expression=*/srcIndexExpr),
        mapOperands);

    // Record the source index corresponding to this reassociation group in
    // the folded op.
    sourceIndices.push_back(
        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
  }
  return success();
}

/// Given the 'indices' of a load/store operation where the memref is a result
/// of a collapse_shape op, returns the indices w.r.t. the source memref of
/// the collapse_shape op. For example
///
/// %0 = ... : memref<2x6x42xf32>
/// %1 = memref.collapse_shape %0 [[0, 1], [2]]
///    : memref<2x6x42xf32> into memref<12x42xf32>
/// %2 = load %1[%i1, %i2] : memref<12x42xf32>
///
/// could be folded into
///
/// %2 = load %0[%i1 / 6, %i1 % 6, %i2] :
///          memref<2x6x42xf32>
static LogicalResult
resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
                                  memref::CollapseShapeOp collapseShapeOp,
                                  ValueRange indices,
                                  SmallVectorImpl<Value> &sourceIndices) {
  int64_t cnt = 0;
  SmallVector<Value> tmp(indices.size());
  SmallVector<OpFoldResult> dynamicIndices;
  for (ArrayRef<int64_t> groups : collapseShapeOp.getReassociationIndices()) {
    assert(!groups.empty() && "association indices groups cannot be empty");
    dynamicIndices.push_back(indices[cnt++]);
    int64_t groupSize = groups.size();

    // Calculate suffix product for all collapse op source dimension sizes
    // except the most major one of each group.
    // We allow the most major source dimension to be dynamic but enforce all
    // others to be known statically.
    SmallVector<int64_t> sizes(groupSize, 1);
    for (int64_t i = 1; i < groupSize; ++i) {
      sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]);
      if (sizes[i] == ShapedType::kDynamic)
        return failure();
    }
    SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);

    // Derive the index values along all dimensions of the source
    // corresponding to the index w.r.t. the collapse_shape op's output.
    auto d0 = rewriter.getAffineDimExpr(0);
    SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, suffixProduct);

    // Construct the AffineApplyOp for each delinearizingExpr.
    for (int64_t i = 0; i < groupSize; i++) {
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, loc,
          AffineMap::get(/*numDims=*/1, /*numSymbols=*/0,
                         delinearizingExprs[i]),
          dynamicIndices);
      sourceIndices.push_back(
          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
    }
    dynamicIndices.clear();
  }
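  // A collapse_shape that collapses to a 0-d memref has no reassociation
  // groups; in that case every source index folds to the constant 0.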
  if (collapseShapeOp.getReassociationIndices().empty()) {
    auto zeroAffineMap = rewriter.getConstantAffineMap(0);
    int64_t srcRank =
        cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
    for (int64_t i = 0; i < srcRank; i++) {
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, loc, zeroAffineMap, dynamicIndices);
      sourceIndices.push_back(
          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
    }
  }
  return success();
}

/// Helpers to access the memref operand for each op.
template <typename LoadOrStoreOpTy>
static Value getMemRefOperand(LoadOrStoreOpTy op) {
  return op.getMemref();
}

static Value getMemRefOperand(vector::TransferReadOp op) {
  return op.getSource();
}

static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::StoreOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::MaskedLoadOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::MaskedStoreOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::TransferWriteOp op) {
  return op.getSource();
}

static Value getMemRefOperand(gpu::SubgroupMmaLoadMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) {
  return op.getDstMemref();
}

//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//

namespace {
/// Merges subview operation with load/transferRead operation.
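/// A rough sketch of the rewrite (illustrative IR; layout attributes are
/// approximate):
///
///   %0 = memref.subview %src[%i, 0] [4, 4] [1, 1]
///       : memref<8x8xf32> to memref<4x4xf32, strided<[8, 1], offset: ?>>
///   %1 = memref.load %0[%j, %k] : memref<4x4xf32, strided<[8, 1], offset: ?>>
///
/// is rewritten so that the load reads directly from %src at indices
/// (%i + %j, %k).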
template <typename OpTy>
class LoadOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges expand_shape operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges collapse_shape operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges subview operation with store/transferWrite operation.
template <typename OpTy>
class StoreOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges expand_shape operation with store/transferWrite operation.
template <typename OpTy>
class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges collapse_shape operation with store/transferWrite operation.
template <typename OpTy>
class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Folds subview(subview(x)) to a single subview(x).
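/// A rough sketch of the fold, assuming unit strides (illustrative IR; layout
/// attributes are approximate):
///
///   %0 = memref.subview %src[4] [16] [1]
///       : memref<64xf32> to memref<16xf32, strided<[1], offset: 4>>
///   %1 = memref.subview %0[2] [8] [1]
///       : memref<16xf32, strided<[1], offset: 4>>
///         to memref<8xf32, strided<[1], offset: 6>>
///
/// becomes a single subview of %src with offset 4 + 2 = 6 and size 8.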
class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
public:
  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subView,
                                PatternRewriter &rewriter) const override {
    auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
    if (!srcSubView)
      return failure();

    // TODO: relax unit stride assumption.
    if (!subView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(subView, "requires unit strides");
    }
    if (!srcSubView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(srcSubView, "requires unit strides");
    }

    // Resolve sizes according to dropped dims.
    SmallVector<OpFoldResult> resolvedSizes;
    llvm::SmallBitVector srcDroppedDims = srcSubView.getDroppedDims();
    affine::resolveSizesIntoOpWithSizes(srcSubView.getMixedSizes(),
                                        subView.getMixedSizes(), srcDroppedDims,
                                        resolvedSizes);

    // Resolve offsets according to source offsets and strides.
    SmallVector<Value> resolvedOffsets;
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, subView.getLoc(), srcSubView.getMixedOffsets(),
        srcSubView.getMixedStrides(), srcDroppedDims, subView.getMixedOffsets(),
        resolvedOffsets);

    // Replace original op.
    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        subView, subView.getType(), srcSubView.getSource(),
        getAsOpFoldResult(resolvedOffsets), resolvedSizes,
        srcSubView.getMixedStrides());

    return success();
  }
};

/// Folds memref.subview ops on the src and dst memrefs of an
/// nvgpu.device_async_copy op into the copy itself.
class NVGPUAsyncCopyOpSubViewOpFolder final
    : public OpRewritePattern<nvgpu::DeviceAsyncCopyOp> {
public:
  using OpRewritePattern<nvgpu::DeviceAsyncCopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
                                PatternRewriter &rewriter) const override;
};
} // namespace

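/// For an affine load/store, apply `affineMap` to `indices` to materialize
/// the "actual" access indices as values, one per map result.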
static SmallVector<Value>
calculateExpandedAccessIndices(AffineMap affineMap,
                               const SmallVector<Value> &indices, Location loc,
                               PatternRewriter &rewriter) {
  SmallVector<OpFoldResult> indicesOfr(llvm::to_vector(
      llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; })));
  SmallVector<Value> expandedIndices;
  for (unsigned i = 0, e = affineMap.getNumResults(); i < e; i++) {
    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
        rewriter, loc, affineMap.getSubMap({i}), indicesOfr);
    expandedIndices.push_back(
        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
  }
  return expandedIndices;
}

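/// Common preconditions for folding a subview into a vector transfer op: the
/// transfer must not have out-of-bounds dimensions and the subview must have
/// unit strides.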
template <typename XferOp>
static LogicalResult
preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
                               memref::SubViewOp subviewOp) {
  static_assert(llvm::is_one_of<XferOp, vector::TransferReadOp,
                                vector::TransferWriteOp>::value,
                "must be a vector transfer op");
  if (xferOp.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
  if (!subviewOp.hasUnitStride()) {
    return rewriter.notifyMatchFailure(
        xferOp, "non-1 stride subview, need to track strides in folded memref");
  }
  return success();
}

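/// Preconditions for folding a subview into `op`: non-transfer ops have no
/// additional requirements; vector transfer ops defer to the shared
/// implementation above.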
static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                Operation *op,
                                                memref::SubViewOp subviewOp) {
  return success();
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferReadOp readOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, readOp, subviewOp);
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferWriteOp writeOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, writeOp, subviewOp);
}

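// The matchAndRewrite implementations below all follow the same structure:
// find the aliasing producer of the memref operand, check preconditions,
// rewrite the access indices in terms of the producer's source, and then
// recreate the access op directly on that source via a TypeSwitch.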
template <typename OpTy>
LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();

  if (!subViewOp)
    return rewriter.notifyMatchFailure(loadOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, loadOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, apply the affine map to the operands to get the "actual"
  // indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
      sourceIndices);

  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Case([&](vector::TransferReadOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
            op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getPadding(), op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .Case([&](nvgpu::LdMatrixOp op) {
        rewriter.replaceOpWithNewOp<nvgpu::LdMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getTranspose(), op.getNumTiles());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::ExpandShapeOp>();

  if (!expandShapeOp)
    return failure();

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, apply the affine map to the operands to get the "actual"
  // indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesExpandShape(
          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, expandShapeOp.getViewSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(loadOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();

  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, apply the affine map to the operands to get the "actual"
  // indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesCollapseShape(
          loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, collapseShapeOp.getViewSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, collapseShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();

  if (!subViewOp)
    return rewriter.notifyMatchFailure(storeOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, storeOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, apply the affine map to the operands to get the "actual"
  // indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
      sourceIndices);

  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::TransferWriteOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, subViewOp.getSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
            op, op.getSrc(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::ExpandShapeOp>();

  if (!expandShapeOp)
    return failure();

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, apply the affine map to the operands to get the "actual"
  // indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesExpandShape(
          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(storeOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();

  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, apply the affine map to the operands to get the "actual"
  // indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesCollapseShape(
          storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
    nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const {

  LLVM_DEBUG(DBGS() << "copyOp       : " << copyOp << "\n");

  auto srcSubViewOp =
      copyOp.getSrc().template getDefiningOp<memref::SubViewOp>();
  auto dstSubViewOp =
      copyOp.getDst().template getDefiningOp<memref::SubViewOp>();

  if (!(srcSubViewOp || dstSubViewOp))
    return rewriter.notifyMatchFailure(copyOp, "does not use subview ops for "
                                               "source or destination");

  // If the source is a subview, we need to resolve the indices.
  SmallVector<Value> srcindices(copyOp.getSrcIndices().begin(),
                                copyOp.getSrcIndices().end());
  SmallVector<Value> foldedSrcIndices(srcindices);

  if (srcSubViewOp) {
    LLVM_DEBUG(DBGS() << "srcSubViewOp : " << srcSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), srcSubViewOp.getMixedOffsets(),
        srcSubViewOp.getMixedStrides(), srcSubViewOp.getDroppedDims(),
        srcindices, foldedSrcIndices);
  }

  // If the destination is a subview, we need to resolve the indices.
  SmallVector<Value> dstindices(copyOp.getDstIndices().begin(),
                                copyOp.getDstIndices().end());
  SmallVector<Value> foldedDstIndices(dstindices);

  if (dstSubViewOp) {
    LLVM_DEBUG(DBGS() << "dstSubViewOp : " << dstSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), dstSubViewOp.getMixedOffsets(),
        dstSubViewOp.getMixedStrides(), dstSubViewOp.getDroppedDims(),
        dstindices, foldedDstIndices);
  }

  // Replace the copy op with a new copy op that reads from / writes to the
  // subview sources directly, using the resolved indices.
  rewriter.replaceOpWithNewOp<nvgpu::DeviceAsyncCopyOp>(
      copyOp, nvgpu::DeviceAsyncTokenType::get(copyOp.getContext()),
      (dstSubViewOp ? dstSubViewOp.getSource() : copyOp.getDst()),
      foldedDstIndices,
      (srcSubViewOp ? srcSubViewOp.getSource() : copyOp.getSrc()),
      foldedSrcIndices, copyOp.getDstElements(), copyOp.getSrcElements(),
      copyOp.getBypassL1Attr());

  return success();
}

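/// Collect all patterns that fold memref alias ops (subview, expand_shape,
/// collapse_shape) into their users.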
void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
  patterns.add<LoadOpOfSubViewOpFolder<affine::AffineLoadOp>,
               LoadOpOfSubViewOpFolder<memref::LoadOp>,
               LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
               LoadOpOfSubViewOpFolder<vector::LoadOp>,
               LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
               LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
               LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
               StoreOpOfSubViewOpFolder<affine::AffineStoreOp>,
               StoreOpOfSubViewOpFolder<memref::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
               StoreOpOfSubViewOpFolder<vector::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
               StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
               LoadOpOfExpandShapeOpFolder<affine::AffineLoadOp>,
               LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
               StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>,
               StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
               StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
               StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
               LoadOpOfCollapseShapeOpFolder<affine::AffineLoadOp>,
               LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
               LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
               LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
               StoreOpOfCollapseShapeOpFolder<affine::AffineStoreOp>,
               StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
               StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
               StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
               SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
      patterns.getContext());
}

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {

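/// Pass that greedily applies the alias-folding patterns above to the
/// operation it is run on.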
struct FoldMemRefAliasOpsPass final
    : public memref::impl::FoldMemRefAliasOpsBase<FoldMemRefAliasOpsPass> {
  void runOnOperation() override;
};

} // namespace

void FoldMemRefAliasOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateFoldMemRefAliasOpPatterns(patterns);
  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
}

std::unique_ptr<Pass> memref::createFoldMemRefAliasOpsPass() {
  return std::make_unique<FoldMemRefAliasOpsPass>();
}