xref: /llvm-project/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp (revision 5550c821897ab77e664977121a0e90ad5be1ff59)
//===- ExtractAddressComputations.cpp - Extract address computations  -----===//
254cda2ecSQuentin Colombet //
354cda2ecSQuentin Colombet // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
454cda2ecSQuentin Colombet // See https://llvm.org/LICENSE.txt for license information.
554cda2ecSQuentin Colombet // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
654cda2ecSQuentin Colombet //
754cda2ecSQuentin Colombet //===----------------------------------------------------------------------===//
854cda2ecSQuentin Colombet //
954cda2ecSQuentin Colombet /// This transformation pass rewrites loading/storing from/to a memref with
1054cda2ecSQuentin Colombet /// offsets into loading/storing from/to a subview and without any offset on
1154cda2ecSQuentin Colombet /// the instruction itself.
1254cda2ecSQuentin Colombet //
1354cda2ecSQuentin Colombet //===----------------------------------------------------------------------===//
1454cda2ecSQuentin Colombet 
1554cda2ecSQuentin Colombet #include "mlir/Dialect/Affine/IR/AffineOps.h"
1654cda2ecSQuentin Colombet #include "mlir/Dialect/Arith/IR/Arith.h"
1754cda2ecSQuentin Colombet #include "mlir/Dialect/MemRef/IR/MemRef.h"
1854cda2ecSQuentin Colombet #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
1954cda2ecSQuentin Colombet #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
2054cda2ecSQuentin Colombet #include "mlir/Dialect/Utils/StaticValueUtils.h"
2154cda2ecSQuentin Colombet #include "mlir/Dialect/Vector/IR/VectorOps.h"
2254cda2ecSQuentin Colombet #include "mlir/IR/PatternMatch.h"
2354cda2ecSQuentin Colombet 
2454cda2ecSQuentin Colombet using namespace mlir;
2554cda2ecSQuentin Colombet 
2654cda2ecSQuentin Colombet namespace {
2754cda2ecSQuentin Colombet 
2854cda2ecSQuentin Colombet //===----------------------------------------------------------------------===//
2954cda2ecSQuentin Colombet // Helper functions for the `load base[off0...]`
3054cda2ecSQuentin Colombet //  => `load (subview base[off0...])[0...]` pattern.
3154cda2ecSQuentin Colombet //===----------------------------------------------------------------------===//
3254cda2ecSQuentin Colombet 
3354cda2ecSQuentin Colombet // Matches getFailureOrSrcMemRef specs for LoadOp.
3454cda2ecSQuentin Colombet // \see LoadStoreLikeOpRewriter.
getLoadOpSrcMemRef(memref::LoadOp loadOp)3554cda2ecSQuentin Colombet static FailureOr<Value> getLoadOpSrcMemRef(memref::LoadOp loadOp) {
3654cda2ecSQuentin Colombet   return loadOp.getMemRef();
3754cda2ecSQuentin Colombet }
3854cda2ecSQuentin Colombet 
3954cda2ecSQuentin Colombet // Matches rebuildOpFromAddressAndIndices specs for LoadOp.
4054cda2ecSQuentin Colombet // \see LoadStoreLikeOpRewriter.
rebuildLoadOp(RewriterBase & rewriter,memref::LoadOp loadOp,Value srcMemRef,ArrayRef<Value> indices)4154cda2ecSQuentin Colombet static memref::LoadOp rebuildLoadOp(RewriterBase &rewriter,
4254cda2ecSQuentin Colombet                                     memref::LoadOp loadOp, Value srcMemRef,
4354cda2ecSQuentin Colombet                                     ArrayRef<Value> indices) {
4454cda2ecSQuentin Colombet   Location loc = loadOp.getLoc();
4554cda2ecSQuentin Colombet   return rewriter.create<memref::LoadOp>(loc, srcMemRef, indices,
4654cda2ecSQuentin Colombet                                          loadOp.getNontemporal());
4754cda2ecSQuentin Colombet }
4854cda2ecSQuentin Colombet 
4954cda2ecSQuentin Colombet // Matches getViewSizeForEachDim specs for LoadOp.
5054cda2ecSQuentin Colombet // \see LoadStoreLikeOpRewriter.
5154cda2ecSQuentin Colombet static SmallVector<OpFoldResult>
getLoadOpViewSizeForEachDim(RewriterBase & rewriter,memref::LoadOp loadOp)5254cda2ecSQuentin Colombet getLoadOpViewSizeForEachDim(RewriterBase &rewriter, memref::LoadOp loadOp) {
5354cda2ecSQuentin Colombet   MemRefType ldTy = loadOp.getMemRefType();
5454cda2ecSQuentin Colombet   unsigned loadRank = ldTy.getRank();
5554cda2ecSQuentin Colombet   return SmallVector<OpFoldResult>(loadRank, rewriter.getIndexAttr(1));
5654cda2ecSQuentin Colombet }
5754cda2ecSQuentin Colombet 
5854cda2ecSQuentin Colombet //===----------------------------------------------------------------------===//
5954cda2ecSQuentin Colombet // Helper functions for the `store val, base[off0...]`
6054cda2ecSQuentin Colombet //  => `store val, (subview base[off0...])[0...]` pattern.
6154cda2ecSQuentin Colombet //===----------------------------------------------------------------------===//
6254cda2ecSQuentin Colombet 
6354cda2ecSQuentin Colombet // Matches getFailureOrSrcMemRef specs for StoreOp.
6454cda2ecSQuentin Colombet // \see LoadStoreLikeOpRewriter.
getStoreOpSrcMemRef(memref::StoreOp storeOp)6554cda2ecSQuentin Colombet static FailureOr<Value> getStoreOpSrcMemRef(memref::StoreOp storeOp) {
6654cda2ecSQuentin Colombet   return storeOp.getMemRef();
6754cda2ecSQuentin Colombet }
6854cda2ecSQuentin Colombet 
6954cda2ecSQuentin Colombet // Matches rebuildOpFromAddressAndIndices specs for StoreOp.
7054cda2ecSQuentin Colombet // \see LoadStoreLikeOpRewriter.
rebuildStoreOp(RewriterBase & rewriter,memref::StoreOp storeOp,Value srcMemRef,ArrayRef<Value> indices)7154cda2ecSQuentin Colombet static memref::StoreOp rebuildStoreOp(RewriterBase &rewriter,
7254cda2ecSQuentin Colombet                                       memref::StoreOp storeOp, Value srcMemRef,
7354cda2ecSQuentin Colombet                                       ArrayRef<Value> indices) {
7454cda2ecSQuentin Colombet   Location loc = storeOp.getLoc();
7554cda2ecSQuentin Colombet   return rewriter.create<memref::StoreOp>(loc, storeOp.getValueToStore(),
7654cda2ecSQuentin Colombet                                           srcMemRef, indices,
7754cda2ecSQuentin Colombet                                           storeOp.getNontemporal());
7854cda2ecSQuentin Colombet }
7954cda2ecSQuentin Colombet 
8054cda2ecSQuentin Colombet // Matches getViewSizeForEachDim specs for StoreOp.
8154cda2ecSQuentin Colombet // \see LoadStoreLikeOpRewriter.
8254cda2ecSQuentin Colombet static SmallVector<OpFoldResult>
getStoreOpViewSizeForEachDim(RewriterBase & rewriter,memref::StoreOp storeOp)8354cda2ecSQuentin Colombet getStoreOpViewSizeForEachDim(RewriterBase &rewriter, memref::StoreOp storeOp) {
8454cda2ecSQuentin Colombet   MemRefType ldTy = storeOp.getMemRefType();
8554cda2ecSQuentin Colombet   unsigned loadRank = ldTy.getRank();
8654cda2ecSQuentin Colombet   return SmallVector<OpFoldResult>(loadRank, rewriter.getIndexAttr(1));
8754cda2ecSQuentin Colombet }
8854cda2ecSQuentin Colombet 
8954cda2ecSQuentin Colombet //===----------------------------------------------------------------------===//
9054cda2ecSQuentin Colombet // Helper functions for the `ldmatrix base[off0...]`
9154cda2ecSQuentin Colombet //  => `ldmatrix (subview base[off0...])[0...]` pattern.
9254cda2ecSQuentin Colombet //===----------------------------------------------------------------------===//
9354cda2ecSQuentin Colombet 
9454cda2ecSQuentin Colombet // Matches getFailureOrSrcMemRef specs for LdMatrixOp.
9554cda2ecSQuentin Colombet // \see LoadStoreLikeOpRewriter.
getLdMatrixOpSrcMemRef(nvgpu::LdMatrixOp ldMatrixOp)9654cda2ecSQuentin Colombet static FailureOr<Value> getLdMatrixOpSrcMemRef(nvgpu::LdMatrixOp ldMatrixOp) {
9754cda2ecSQuentin Colombet   return ldMatrixOp.getSrcMemref();
9854cda2ecSQuentin Colombet }
9954cda2ecSQuentin Colombet 
10054cda2ecSQuentin Colombet // Matches rebuildOpFromAddressAndIndices specs for LdMatrixOp.
10154cda2ecSQuentin Colombet // \see LoadStoreLikeOpRewriter.
rebuildLdMatrixOp(RewriterBase & rewriter,nvgpu::LdMatrixOp ldMatrixOp,Value srcMemRef,ArrayRef<Value> indices)10254cda2ecSQuentin Colombet static nvgpu::LdMatrixOp rebuildLdMatrixOp(RewriterBase &rewriter,
10354cda2ecSQuentin Colombet                                            nvgpu::LdMatrixOp ldMatrixOp,
10454cda2ecSQuentin Colombet                                            Value srcMemRef,
10554cda2ecSQuentin Colombet                                            ArrayRef<Value> indices) {
10654cda2ecSQuentin Colombet   Location loc = ldMatrixOp.getLoc();
10754cda2ecSQuentin Colombet   return rewriter.create<nvgpu::LdMatrixOp>(
10854cda2ecSQuentin Colombet       loc, ldMatrixOp.getResult().getType(), srcMemRef, indices,
10954cda2ecSQuentin Colombet       ldMatrixOp.getTranspose(), ldMatrixOp.getNumTiles());
11054cda2ecSQuentin Colombet }
11154cda2ecSQuentin Colombet 
11254cda2ecSQuentin Colombet //===----------------------------------------------------------------------===//
11354cda2ecSQuentin Colombet // Helper functions for the `transfer_read base[off0...]`
11454cda2ecSQuentin Colombet //  => `transfer_read (subview base[off0...])[0...]` pattern.
11554cda2ecSQuentin Colombet //===----------------------------------------------------------------------===//
11654cda2ecSQuentin Colombet 
11754cda2ecSQuentin Colombet // Matches getFailureOrSrcMemRef specs for TransferReadOp.
11854cda2ecSQuentin Colombet // \see LoadStoreLikeOpRewriter.
11954cda2ecSQuentin Colombet template <typename TransferLikeOp>
12054cda2ecSQuentin Colombet static FailureOr<Value>
getTransferLikeOpSrcMemRef(TransferLikeOp transferLikeOp)12154cda2ecSQuentin Colombet getTransferLikeOpSrcMemRef(TransferLikeOp transferLikeOp) {
12254cda2ecSQuentin Colombet   Value src = transferLikeOp.getSource();
123*5550c821STres Popp   if (isa<MemRefType>(src.getType()))
12454cda2ecSQuentin Colombet     return src;
12554cda2ecSQuentin Colombet   return failure();
12654cda2ecSQuentin Colombet }
12754cda2ecSQuentin Colombet 
12854cda2ecSQuentin Colombet // Matches rebuildOpFromAddressAndIndices specs for TransferReadOp.
12954cda2ecSQuentin Colombet // \see LoadStoreLikeOpRewriter.
13054cda2ecSQuentin Colombet static vector::TransferReadOp
rebuildTransferReadOp(RewriterBase & rewriter,vector::TransferReadOp transferReadOp,Value srcMemRef,ArrayRef<Value> indices)13154cda2ecSQuentin Colombet rebuildTransferReadOp(RewriterBase &rewriter,
13254cda2ecSQuentin Colombet                       vector::TransferReadOp transferReadOp, Value srcMemRef,
13354cda2ecSQuentin Colombet                       ArrayRef<Value> indices) {
13454cda2ecSQuentin Colombet   Location loc = transferReadOp.getLoc();
13554cda2ecSQuentin Colombet   return rewriter.create<vector::TransferReadOp>(
13654cda2ecSQuentin Colombet       loc, transferReadOp.getResult().getType(), srcMemRef, indices,
13754cda2ecSQuentin Colombet       transferReadOp.getPermutationMap(), transferReadOp.getPadding(),
13854cda2ecSQuentin Colombet       transferReadOp.getMask(), transferReadOp.getInBoundsAttr());
13954cda2ecSQuentin Colombet }
14054cda2ecSQuentin Colombet 
14154cda2ecSQuentin Colombet //===----------------------------------------------------------------------===//
14254cda2ecSQuentin Colombet // Helper functions for the `transfer_write base[off0...]`
14354cda2ecSQuentin Colombet //  => `transfer_write (subview base[off0...])[0...]` pattern.
14454cda2ecSQuentin Colombet //===----------------------------------------------------------------------===//
14554cda2ecSQuentin Colombet 
14654cda2ecSQuentin Colombet // Matches rebuildOpFromAddressAndIndices specs for TransferWriteOp.
14754cda2ecSQuentin Colombet // \see LoadStoreLikeOpRewriter.
14854cda2ecSQuentin Colombet static vector::TransferWriteOp
rebuildTransferWriteOp(RewriterBase & rewriter,vector::TransferWriteOp transferWriteOp,Value srcMemRef,ArrayRef<Value> indices)14954cda2ecSQuentin Colombet rebuildTransferWriteOp(RewriterBase &rewriter,
15054cda2ecSQuentin Colombet                        vector::TransferWriteOp transferWriteOp, Value srcMemRef,
15154cda2ecSQuentin Colombet                        ArrayRef<Value> indices) {
15254cda2ecSQuentin Colombet   Location loc = transferWriteOp.getLoc();
15354cda2ecSQuentin Colombet   return rewriter.create<vector::TransferWriteOp>(
15454cda2ecSQuentin Colombet       loc, transferWriteOp.getValue(), srcMemRef, indices,
15554cda2ecSQuentin Colombet       transferWriteOp.getPermutationMapAttr(), transferWriteOp.getMask(),
15654cda2ecSQuentin Colombet       transferWriteOp.getInBoundsAttr());
15754cda2ecSQuentin Colombet }
15854cda2ecSQuentin Colombet 
15954cda2ecSQuentin Colombet //===----------------------------------------------------------------------===//
16054cda2ecSQuentin Colombet // Generic helper functions used as default implementation in
16154cda2ecSQuentin Colombet // LoadStoreLikeOpRewriter.
16254cda2ecSQuentin Colombet //===----------------------------------------------------------------------===//
16354cda2ecSQuentin Colombet 
16454cda2ecSQuentin Colombet /// Helper function to get the src memref.
16554cda2ecSQuentin Colombet /// It uses the already defined getFailureOrSrcMemRef but asserts
16654cda2ecSQuentin Colombet /// that the source is a memref.
16754cda2ecSQuentin Colombet template <typename LoadStoreLikeOp,
16854cda2ecSQuentin Colombet           FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp)>
getSrcMemRef(LoadStoreLikeOp loadStoreLikeOp)16954cda2ecSQuentin Colombet static Value getSrcMemRef(LoadStoreLikeOp loadStoreLikeOp) {
17054cda2ecSQuentin Colombet   FailureOr<Value> failureOrSrcMemRef = getFailureOrSrcMemRef(loadStoreLikeOp);
17154cda2ecSQuentin Colombet   assert(!failed(failureOrSrcMemRef) && "Generic getSrcMemRef cannot be used");
17254cda2ecSQuentin Colombet   return *failureOrSrcMemRef;
17354cda2ecSQuentin Colombet }
17454cda2ecSQuentin Colombet 
17554cda2ecSQuentin Colombet /// Helper function to get the sizes of the resulting view.
17654cda2ecSQuentin Colombet /// This function gets the sizes of the source memref then substracts the
17754cda2ecSQuentin Colombet /// offsets used within \p loadStoreLikeOp. This gives the maximal (for
17854cda2ecSQuentin Colombet /// inbound) sizes for the view.
17954cda2ecSQuentin Colombet /// The source memref is retrieved using getSrcMemRef on \p loadStoreLikeOp.
18054cda2ecSQuentin Colombet template <typename LoadStoreLikeOp, Value (*getSrcMemRef)(LoadStoreLikeOp)>
18154cda2ecSQuentin Colombet static SmallVector<OpFoldResult>
getGenericOpViewSizeForEachDim(RewriterBase & rewriter,LoadStoreLikeOp loadStoreLikeOp)18254cda2ecSQuentin Colombet getGenericOpViewSizeForEachDim(RewriterBase &rewriter,
18354cda2ecSQuentin Colombet                                LoadStoreLikeOp loadStoreLikeOp) {
18454cda2ecSQuentin Colombet   Location loc = loadStoreLikeOp.getLoc();
18554cda2ecSQuentin Colombet   auto extractStridedMetadataOp =
18654cda2ecSQuentin Colombet       rewriter.create<memref::ExtractStridedMetadataOp>(
18754cda2ecSQuentin Colombet           loc, getSrcMemRef(loadStoreLikeOp));
18854cda2ecSQuentin Colombet   SmallVector<OpFoldResult> srcSizes =
18954cda2ecSQuentin Colombet       extractStridedMetadataOp.getConstifiedMixedSizes();
19054cda2ecSQuentin Colombet   SmallVector<OpFoldResult> indices =
19154cda2ecSQuentin Colombet       getAsOpFoldResult(loadStoreLikeOp.getIndices());
19254cda2ecSQuentin Colombet   SmallVector<OpFoldResult> finalSizes;
19354cda2ecSQuentin Colombet 
19454cda2ecSQuentin Colombet   AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
19554cda2ecSQuentin Colombet   AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
19654cda2ecSQuentin Colombet 
19754cda2ecSQuentin Colombet   for (auto [srcSize, indice] : llvm::zip(srcSizes, indices)) {
1984c48f016SMatthias Springer     finalSizes.push_back(affine::makeComposedFoldedAffineApply(
1994c48f016SMatthias Springer         rewriter, loc, s0 - s1, {srcSize, indice}));
20054cda2ecSQuentin Colombet   }
20154cda2ecSQuentin Colombet   return finalSizes;
20254cda2ecSQuentin Colombet }
20354cda2ecSQuentin Colombet 
20454cda2ecSQuentin Colombet /// Rewrite a store/load-like op so that all its indices are zeros.
20554cda2ecSQuentin Colombet /// E.g., %ld = memref.load %base[%off0]...[%offN]
20654cda2ecSQuentin Colombet /// =>
20754cda2ecSQuentin Colombet /// %new_base = subview %base[%off0,.., %offN][1,..,1][1,..,1]
20854cda2ecSQuentin Colombet /// %ld = memref.load %new_base[0,..,0] :
20954cda2ecSQuentin Colombet ///    memref<1x..x1xTy, strided<[1,..,1], offset: ?>>
21054cda2ecSQuentin Colombet ///
21154cda2ecSQuentin Colombet /// `getSrcMemRef` returns the source memref for the given load-like operation.
21254cda2ecSQuentin Colombet ///
21354cda2ecSQuentin Colombet /// `getViewSizeForEachDim` returns the sizes of view that is going to feed
21454cda2ecSQuentin Colombet /// new operation. This must return one size per dimension of the view.
21554cda2ecSQuentin Colombet /// The sizes of the view needs to be at least as big as what is actually
21654cda2ecSQuentin Colombet /// going to be accessed. Use the provided `loadStoreOp` to get the right
21754cda2ecSQuentin Colombet /// sizes.
21854cda2ecSQuentin Colombet ///
21954cda2ecSQuentin Colombet /// Using the given rewriter, `rebuildOpFromAddressAndIndices` creates a new
22054cda2ecSQuentin Colombet /// LoadStoreLikeOp that reads from srcMemRef[indices].
22154cda2ecSQuentin Colombet /// The returned operation will be used to replace loadStoreOp.
22254cda2ecSQuentin Colombet template <typename LoadStoreLikeOp,
22354cda2ecSQuentin Colombet           FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp),
22454cda2ecSQuentin Colombet           LoadStoreLikeOp (*rebuildOpFromAddressAndIndices)(
22554cda2ecSQuentin Colombet               RewriterBase & /*rewriter*/, LoadStoreLikeOp /*loadStoreOp*/,
22654cda2ecSQuentin Colombet               Value /*srcMemRef*/, ArrayRef<Value> /*indices*/),
22754cda2ecSQuentin Colombet           SmallVector<OpFoldResult> (*getViewSizeForEachDim)(
22854cda2ecSQuentin Colombet               RewriterBase & /*rewriter*/, LoadStoreLikeOp /*loadStoreOp*/) =
22954cda2ecSQuentin Colombet               getGenericOpViewSizeForEachDim<
23054cda2ecSQuentin Colombet                   LoadStoreLikeOp,
23154cda2ecSQuentin Colombet                   getSrcMemRef<LoadStoreLikeOp, getFailureOrSrcMemRef>>>
23254cda2ecSQuentin Colombet struct LoadStoreLikeOpRewriter : public OpRewritePattern<LoadStoreLikeOp> {
23354cda2ecSQuentin Colombet   using OpRewritePattern<LoadStoreLikeOp>::OpRewritePattern;
23454cda2ecSQuentin Colombet 
matchAndRewrite__anon69e44af10111::LoadStoreLikeOpRewriter23554cda2ecSQuentin Colombet   LogicalResult matchAndRewrite(LoadStoreLikeOp loadStoreLikeOp,
23654cda2ecSQuentin Colombet                                 PatternRewriter &rewriter) const override {
23754cda2ecSQuentin Colombet     FailureOr<Value> failureOrSrcMemRef =
23854cda2ecSQuentin Colombet         getFailureOrSrcMemRef(loadStoreLikeOp);
23954cda2ecSQuentin Colombet     if (failed(failureOrSrcMemRef))
24054cda2ecSQuentin Colombet       return rewriter.notifyMatchFailure(loadStoreLikeOp,
24154cda2ecSQuentin Colombet                                          "source is not a memref");
24254cda2ecSQuentin Colombet     Value srcMemRef = *failureOrSrcMemRef;
243*5550c821STres Popp     auto ldStTy = cast<MemRefType>(srcMemRef.getType());
24454cda2ecSQuentin Colombet     unsigned loadStoreRank = ldStTy.getRank();
24554cda2ecSQuentin Colombet     // Don't waste compile time if there is nothing to rewrite.
24654cda2ecSQuentin Colombet     if (loadStoreRank == 0)
24754cda2ecSQuentin Colombet       return rewriter.notifyMatchFailure(loadStoreLikeOp,
24854cda2ecSQuentin Colombet                                          "0-D accesses don't need rewriting");
24954cda2ecSQuentin Colombet 
25054cda2ecSQuentin Colombet     // If our load already has only zeros as indices there is nothing
25154cda2ecSQuentin Colombet     // to do.
25254cda2ecSQuentin Colombet     SmallVector<OpFoldResult> indices =
25354cda2ecSQuentin Colombet         getAsOpFoldResult(loadStoreLikeOp.getIndices());
25454cda2ecSQuentin Colombet     if (std::all_of(indices.begin(), indices.end(),
25554cda2ecSQuentin Colombet                     [](const OpFoldResult &opFold) {
25654cda2ecSQuentin Colombet                       return isConstantIntValue(opFold, 0);
25754cda2ecSQuentin Colombet                     })) {
25854cda2ecSQuentin Colombet       return rewriter.notifyMatchFailure(
25954cda2ecSQuentin Colombet           loadStoreLikeOp, "no computation to extract: offsets are 0s");
26054cda2ecSQuentin Colombet     }
26154cda2ecSQuentin Colombet 
26254cda2ecSQuentin Colombet     // Create the array of ones of the right size.
26354cda2ecSQuentin Colombet     SmallVector<OpFoldResult> ones(loadStoreRank, rewriter.getIndexAttr(1));
26454cda2ecSQuentin Colombet     SmallVector<OpFoldResult> sizes =
26554cda2ecSQuentin Colombet         getViewSizeForEachDim(rewriter, loadStoreLikeOp);
26654cda2ecSQuentin Colombet     assert(sizes.size() == loadStoreRank &&
26754cda2ecSQuentin Colombet            "Expected one size per load dimension");
26854cda2ecSQuentin Colombet     Location loc = loadStoreLikeOp.getLoc();
26954cda2ecSQuentin Colombet     // The subview inherits its strides from the original memref and will
27054cda2ecSQuentin Colombet     // apply them properly to the input indices.
27154cda2ecSQuentin Colombet     // Therefore the strides multipliers are simply ones.
27254cda2ecSQuentin Colombet     auto subview =
27354cda2ecSQuentin Colombet         rewriter.create<memref::SubViewOp>(loc, /*source=*/srcMemRef,
27454cda2ecSQuentin Colombet                                            /*offsets=*/indices,
27554cda2ecSQuentin Colombet                                            /*sizes=*/sizes, /*strides=*/ones);
27654cda2ecSQuentin Colombet     // Rewrite the load/store with the subview as the base pointer.
27754cda2ecSQuentin Colombet     SmallVector<Value> zeros(loadStoreRank,
27854cda2ecSQuentin Colombet                              rewriter.create<arith::ConstantIndexOp>(loc, 0));
27954cda2ecSQuentin Colombet     LoadStoreLikeOp newLoadStore = rebuildOpFromAddressAndIndices(
28054cda2ecSQuentin Colombet         rewriter, loadStoreLikeOp, subview.getResult(), zeros);
28154cda2ecSQuentin Colombet     rewriter.replaceOp(loadStoreLikeOp, newLoadStore->getResults());
28254cda2ecSQuentin Colombet     return success();
28354cda2ecSQuentin Colombet   }
28454cda2ecSQuentin Colombet };
28554cda2ecSQuentin Colombet } // namespace
28654cda2ecSQuentin Colombet 
populateExtractAddressComputationsPatterns(RewritePatternSet & patterns)28754cda2ecSQuentin Colombet void memref::populateExtractAddressComputationsPatterns(
28854cda2ecSQuentin Colombet     RewritePatternSet &patterns) {
28954cda2ecSQuentin Colombet   patterns.add<
29054cda2ecSQuentin Colombet       LoadStoreLikeOpRewriter<
29154cda2ecSQuentin Colombet           memref::LoadOp,
29254cda2ecSQuentin Colombet           /*getSrcMemRef=*/getLoadOpSrcMemRef,
29354cda2ecSQuentin Colombet           /*rebuildOpFromAddressAndIndices=*/rebuildLoadOp,
29454cda2ecSQuentin Colombet           /*getViewSizeForEachDim=*/getLoadOpViewSizeForEachDim>,
29554cda2ecSQuentin Colombet       LoadStoreLikeOpRewriter<
29654cda2ecSQuentin Colombet           memref::StoreOp,
29754cda2ecSQuentin Colombet           /*getSrcMemRef=*/getStoreOpSrcMemRef,
29854cda2ecSQuentin Colombet           /*rebuildOpFromAddressAndIndices=*/rebuildStoreOp,
29954cda2ecSQuentin Colombet           /*getViewSizeForEachDim=*/getStoreOpViewSizeForEachDim>,
30054cda2ecSQuentin Colombet       LoadStoreLikeOpRewriter<
30154cda2ecSQuentin Colombet           nvgpu::LdMatrixOp,
30254cda2ecSQuentin Colombet           /*getSrcMemRef=*/getLdMatrixOpSrcMemRef,
30354cda2ecSQuentin Colombet           /*rebuildOpFromAddressAndIndices=*/rebuildLdMatrixOp>,
30454cda2ecSQuentin Colombet       LoadStoreLikeOpRewriter<
30554cda2ecSQuentin Colombet           vector::TransferReadOp,
30654cda2ecSQuentin Colombet           /*getSrcMemRef=*/getTransferLikeOpSrcMemRef<vector::TransferReadOp>,
30754cda2ecSQuentin Colombet           /*rebuildOpFromAddressAndIndices=*/rebuildTransferReadOp>,
30854cda2ecSQuentin Colombet       LoadStoreLikeOpRewriter<
30954cda2ecSQuentin Colombet           vector::TransferWriteOp,
31054cda2ecSQuentin Colombet           /*getSrcMemRef=*/getTransferLikeOpSrcMemRef<vector::TransferWriteOp>,
31154cda2ecSQuentin Colombet           /*rebuildOpFromAddressAndIndices=*/rebuildTransferWriteOp>>(
31254cda2ecSQuentin Colombet       patterns.getContext());
31354cda2ecSQuentin Colombet }
314