xref: /llvm-project/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp (revision 5550c821897ab77e664977121a0e90ad5be1ff59)
1 //===- ExtractAddressComputations.cpp - Extract address computations  -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// This transformation pass rewrites loading/storing from/to a memref with
10 /// offsets into loading/storing from/to a subview and without any offset on
11 /// the instruction itself.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/MemRef/IR/MemRef.h"
18 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
19 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
20 #include "mlir/Dialect/Utils/StaticValueUtils.h"
21 #include "mlir/Dialect/Vector/IR/VectorOps.h"
22 #include "mlir/IR/PatternMatch.h"
23 
24 using namespace mlir;
25 
26 namespace {
27 
28 //===----------------------------------------------------------------------===//
29 // Helper functions for the `load base[off0...]`
30 //  => `load (subview base[off0...])[0...]` pattern.
31 //===----------------------------------------------------------------------===//
32 
33 // Matches getFailureOrSrcMemRef specs for LoadOp.
34 // \see LoadStoreLikeOpRewriter.
getLoadOpSrcMemRef(memref::LoadOp loadOp)35 static FailureOr<Value> getLoadOpSrcMemRef(memref::LoadOp loadOp) {
36   return loadOp.getMemRef();
37 }
38 
39 // Matches rebuildOpFromAddressAndIndices specs for LoadOp.
40 // \see LoadStoreLikeOpRewriter.
rebuildLoadOp(RewriterBase & rewriter,memref::LoadOp loadOp,Value srcMemRef,ArrayRef<Value> indices)41 static memref::LoadOp rebuildLoadOp(RewriterBase &rewriter,
42                                     memref::LoadOp loadOp, Value srcMemRef,
43                                     ArrayRef<Value> indices) {
44   Location loc = loadOp.getLoc();
45   return rewriter.create<memref::LoadOp>(loc, srcMemRef, indices,
46                                          loadOp.getNontemporal());
47 }
48 
49 // Matches getViewSizeForEachDim specs for LoadOp.
50 // \see LoadStoreLikeOpRewriter.
51 static SmallVector<OpFoldResult>
getLoadOpViewSizeForEachDim(RewriterBase & rewriter,memref::LoadOp loadOp)52 getLoadOpViewSizeForEachDim(RewriterBase &rewriter, memref::LoadOp loadOp) {
53   MemRefType ldTy = loadOp.getMemRefType();
54   unsigned loadRank = ldTy.getRank();
55   return SmallVector<OpFoldResult>(loadRank, rewriter.getIndexAttr(1));
56 }
57 
58 //===----------------------------------------------------------------------===//
59 // Helper functions for the `store val, base[off0...]`
60 //  => `store val, (subview base[off0...])[0...]` pattern.
61 //===----------------------------------------------------------------------===//
62 
63 // Matches getFailureOrSrcMemRef specs for StoreOp.
64 // \see LoadStoreLikeOpRewriter.
getStoreOpSrcMemRef(memref::StoreOp storeOp)65 static FailureOr<Value> getStoreOpSrcMemRef(memref::StoreOp storeOp) {
66   return storeOp.getMemRef();
67 }
68 
69 // Matches rebuildOpFromAddressAndIndices specs for StoreOp.
70 // \see LoadStoreLikeOpRewriter.
rebuildStoreOp(RewriterBase & rewriter,memref::StoreOp storeOp,Value srcMemRef,ArrayRef<Value> indices)71 static memref::StoreOp rebuildStoreOp(RewriterBase &rewriter,
72                                       memref::StoreOp storeOp, Value srcMemRef,
73                                       ArrayRef<Value> indices) {
74   Location loc = storeOp.getLoc();
75   return rewriter.create<memref::StoreOp>(loc, storeOp.getValueToStore(),
76                                           srcMemRef, indices,
77                                           storeOp.getNontemporal());
78 }
79 
80 // Matches getViewSizeForEachDim specs for StoreOp.
81 // \see LoadStoreLikeOpRewriter.
82 static SmallVector<OpFoldResult>
getStoreOpViewSizeForEachDim(RewriterBase & rewriter,memref::StoreOp storeOp)83 getStoreOpViewSizeForEachDim(RewriterBase &rewriter, memref::StoreOp storeOp) {
84   MemRefType ldTy = storeOp.getMemRefType();
85   unsigned loadRank = ldTy.getRank();
86   return SmallVector<OpFoldResult>(loadRank, rewriter.getIndexAttr(1));
87 }
88 
89 //===----------------------------------------------------------------------===//
90 // Helper functions for the `ldmatrix base[off0...]`
91 //  => `ldmatrix (subview base[off0...])[0...]` pattern.
92 //===----------------------------------------------------------------------===//
93 
94 // Matches getFailureOrSrcMemRef specs for LdMatrixOp.
95 // \see LoadStoreLikeOpRewriter.
getLdMatrixOpSrcMemRef(nvgpu::LdMatrixOp ldMatrixOp)96 static FailureOr<Value> getLdMatrixOpSrcMemRef(nvgpu::LdMatrixOp ldMatrixOp) {
97   return ldMatrixOp.getSrcMemref();
98 }
99 
100 // Matches rebuildOpFromAddressAndIndices specs for LdMatrixOp.
101 // \see LoadStoreLikeOpRewriter.
rebuildLdMatrixOp(RewriterBase & rewriter,nvgpu::LdMatrixOp ldMatrixOp,Value srcMemRef,ArrayRef<Value> indices)102 static nvgpu::LdMatrixOp rebuildLdMatrixOp(RewriterBase &rewriter,
103                                            nvgpu::LdMatrixOp ldMatrixOp,
104                                            Value srcMemRef,
105                                            ArrayRef<Value> indices) {
106   Location loc = ldMatrixOp.getLoc();
107   return rewriter.create<nvgpu::LdMatrixOp>(
108       loc, ldMatrixOp.getResult().getType(), srcMemRef, indices,
109       ldMatrixOp.getTranspose(), ldMatrixOp.getNumTiles());
110 }
111 
112 //===----------------------------------------------------------------------===//
113 // Helper functions for the `transfer_read base[off0...]`
114 //  => `transfer_read (subview base[off0...])[0...]` pattern.
115 //===----------------------------------------------------------------------===//
116 
117 // Matches getFailureOrSrcMemRef specs for TransferReadOp.
118 // \see LoadStoreLikeOpRewriter.
119 template <typename TransferLikeOp>
120 static FailureOr<Value>
getTransferLikeOpSrcMemRef(TransferLikeOp transferLikeOp)121 getTransferLikeOpSrcMemRef(TransferLikeOp transferLikeOp) {
122   Value src = transferLikeOp.getSource();
123   if (isa<MemRefType>(src.getType()))
124     return src;
125   return failure();
126 }
127 
128 // Matches rebuildOpFromAddressAndIndices specs for TransferReadOp.
129 // \see LoadStoreLikeOpRewriter.
130 static vector::TransferReadOp
rebuildTransferReadOp(RewriterBase & rewriter,vector::TransferReadOp transferReadOp,Value srcMemRef,ArrayRef<Value> indices)131 rebuildTransferReadOp(RewriterBase &rewriter,
132                       vector::TransferReadOp transferReadOp, Value srcMemRef,
133                       ArrayRef<Value> indices) {
134   Location loc = transferReadOp.getLoc();
135   return rewriter.create<vector::TransferReadOp>(
136       loc, transferReadOp.getResult().getType(), srcMemRef, indices,
137       transferReadOp.getPermutationMap(), transferReadOp.getPadding(),
138       transferReadOp.getMask(), transferReadOp.getInBoundsAttr());
139 }
140 
141 //===----------------------------------------------------------------------===//
142 // Helper functions for the `transfer_write base[off0...]`
143 //  => `transfer_write (subview base[off0...])[0...]` pattern.
144 //===----------------------------------------------------------------------===//
145 
146 // Matches rebuildOpFromAddressAndIndices specs for TransferWriteOp.
147 // \see LoadStoreLikeOpRewriter.
148 static vector::TransferWriteOp
rebuildTransferWriteOp(RewriterBase & rewriter,vector::TransferWriteOp transferWriteOp,Value srcMemRef,ArrayRef<Value> indices)149 rebuildTransferWriteOp(RewriterBase &rewriter,
150                        vector::TransferWriteOp transferWriteOp, Value srcMemRef,
151                        ArrayRef<Value> indices) {
152   Location loc = transferWriteOp.getLoc();
153   return rewriter.create<vector::TransferWriteOp>(
154       loc, transferWriteOp.getValue(), srcMemRef, indices,
155       transferWriteOp.getPermutationMapAttr(), transferWriteOp.getMask(),
156       transferWriteOp.getInBoundsAttr());
157 }
158 
159 //===----------------------------------------------------------------------===//
160 // Generic helper functions used as default implementation in
161 // LoadStoreLikeOpRewriter.
162 //===----------------------------------------------------------------------===//
163 
164 /// Helper function to get the src memref.
165 /// It uses the already defined getFailureOrSrcMemRef but asserts
166 /// that the source is a memref.
167 template <typename LoadStoreLikeOp,
168           FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp)>
getSrcMemRef(LoadStoreLikeOp loadStoreLikeOp)169 static Value getSrcMemRef(LoadStoreLikeOp loadStoreLikeOp) {
170   FailureOr<Value> failureOrSrcMemRef = getFailureOrSrcMemRef(loadStoreLikeOp);
171   assert(!failed(failureOrSrcMemRef) && "Generic getSrcMemRef cannot be used");
172   return *failureOrSrcMemRef;
173 }
174 
175 /// Helper function to get the sizes of the resulting view.
176 /// This function gets the sizes of the source memref then substracts the
177 /// offsets used within \p loadStoreLikeOp. This gives the maximal (for
178 /// inbound) sizes for the view.
179 /// The source memref is retrieved using getSrcMemRef on \p loadStoreLikeOp.
180 template <typename LoadStoreLikeOp, Value (*getSrcMemRef)(LoadStoreLikeOp)>
181 static SmallVector<OpFoldResult>
getGenericOpViewSizeForEachDim(RewriterBase & rewriter,LoadStoreLikeOp loadStoreLikeOp)182 getGenericOpViewSizeForEachDim(RewriterBase &rewriter,
183                                LoadStoreLikeOp loadStoreLikeOp) {
184   Location loc = loadStoreLikeOp.getLoc();
185   auto extractStridedMetadataOp =
186       rewriter.create<memref::ExtractStridedMetadataOp>(
187           loc, getSrcMemRef(loadStoreLikeOp));
188   SmallVector<OpFoldResult> srcSizes =
189       extractStridedMetadataOp.getConstifiedMixedSizes();
190   SmallVector<OpFoldResult> indices =
191       getAsOpFoldResult(loadStoreLikeOp.getIndices());
192   SmallVector<OpFoldResult> finalSizes;
193 
194   AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
195   AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
196 
197   for (auto [srcSize, indice] : llvm::zip(srcSizes, indices)) {
198     finalSizes.push_back(affine::makeComposedFoldedAffineApply(
199         rewriter, loc, s0 - s1, {srcSize, indice}));
200   }
201   return finalSizes;
202 }
203 
204 /// Rewrite a store/load-like op so that all its indices are zeros.
205 /// E.g., %ld = memref.load %base[%off0]...[%offN]
206 /// =>
207 /// %new_base = subview %base[%off0,.., %offN][1,..,1][1,..,1]
208 /// %ld = memref.load %new_base[0,..,0] :
209 ///    memref<1x..x1xTy, strided<[1,..,1], offset: ?>>
210 ///
211 /// `getSrcMemRef` returns the source memref for the given load-like operation.
212 ///
213 /// `getViewSizeForEachDim` returns the sizes of view that is going to feed
214 /// new operation. This must return one size per dimension of the view.
215 /// The sizes of the view needs to be at least as big as what is actually
216 /// going to be accessed. Use the provided `loadStoreOp` to get the right
217 /// sizes.
218 ///
219 /// Using the given rewriter, `rebuildOpFromAddressAndIndices` creates a new
220 /// LoadStoreLikeOp that reads from srcMemRef[indices].
221 /// The returned operation will be used to replace loadStoreOp.
222 template <typename LoadStoreLikeOp,
223           FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp),
224           LoadStoreLikeOp (*rebuildOpFromAddressAndIndices)(
225               RewriterBase & /*rewriter*/, LoadStoreLikeOp /*loadStoreOp*/,
226               Value /*srcMemRef*/, ArrayRef<Value> /*indices*/),
227           SmallVector<OpFoldResult> (*getViewSizeForEachDim)(
228               RewriterBase & /*rewriter*/, LoadStoreLikeOp /*loadStoreOp*/) =
229               getGenericOpViewSizeForEachDim<
230                   LoadStoreLikeOp,
231                   getSrcMemRef<LoadStoreLikeOp, getFailureOrSrcMemRef>>>
232 struct LoadStoreLikeOpRewriter : public OpRewritePattern<LoadStoreLikeOp> {
233   using OpRewritePattern<LoadStoreLikeOp>::OpRewritePattern;
234 
matchAndRewrite__anon69e44af10111::LoadStoreLikeOpRewriter235   LogicalResult matchAndRewrite(LoadStoreLikeOp loadStoreLikeOp,
236                                 PatternRewriter &rewriter) const override {
237     FailureOr<Value> failureOrSrcMemRef =
238         getFailureOrSrcMemRef(loadStoreLikeOp);
239     if (failed(failureOrSrcMemRef))
240       return rewriter.notifyMatchFailure(loadStoreLikeOp,
241                                          "source is not a memref");
242     Value srcMemRef = *failureOrSrcMemRef;
243     auto ldStTy = cast<MemRefType>(srcMemRef.getType());
244     unsigned loadStoreRank = ldStTy.getRank();
245     // Don't waste compile time if there is nothing to rewrite.
246     if (loadStoreRank == 0)
247       return rewriter.notifyMatchFailure(loadStoreLikeOp,
248                                          "0-D accesses don't need rewriting");
249 
250     // If our load already has only zeros as indices there is nothing
251     // to do.
252     SmallVector<OpFoldResult> indices =
253         getAsOpFoldResult(loadStoreLikeOp.getIndices());
254     if (std::all_of(indices.begin(), indices.end(),
255                     [](const OpFoldResult &opFold) {
256                       return isConstantIntValue(opFold, 0);
257                     })) {
258       return rewriter.notifyMatchFailure(
259           loadStoreLikeOp, "no computation to extract: offsets are 0s");
260     }
261 
262     // Create the array of ones of the right size.
263     SmallVector<OpFoldResult> ones(loadStoreRank, rewriter.getIndexAttr(1));
264     SmallVector<OpFoldResult> sizes =
265         getViewSizeForEachDim(rewriter, loadStoreLikeOp);
266     assert(sizes.size() == loadStoreRank &&
267            "Expected one size per load dimension");
268     Location loc = loadStoreLikeOp.getLoc();
269     // The subview inherits its strides from the original memref and will
270     // apply them properly to the input indices.
271     // Therefore the strides multipliers are simply ones.
272     auto subview =
273         rewriter.create<memref::SubViewOp>(loc, /*source=*/srcMemRef,
274                                            /*offsets=*/indices,
275                                            /*sizes=*/sizes, /*strides=*/ones);
276     // Rewrite the load/store with the subview as the base pointer.
277     SmallVector<Value> zeros(loadStoreRank,
278                              rewriter.create<arith::ConstantIndexOp>(loc, 0));
279     LoadStoreLikeOp newLoadStore = rebuildOpFromAddressAndIndices(
280         rewriter, loadStoreLikeOp, subview.getResult(), zeros);
281     rewriter.replaceOp(loadStoreLikeOp, newLoadStore->getResults());
282     return success();
283   }
284 };
285 } // namespace
286 
populateExtractAddressComputationsPatterns(RewritePatternSet & patterns)287 void memref::populateExtractAddressComputationsPatterns(
288     RewritePatternSet &patterns) {
289   patterns.add<
290       LoadStoreLikeOpRewriter<
291           memref::LoadOp,
292           /*getSrcMemRef=*/getLoadOpSrcMemRef,
293           /*rebuildOpFromAddressAndIndices=*/rebuildLoadOp,
294           /*getViewSizeForEachDim=*/getLoadOpViewSizeForEachDim>,
295       LoadStoreLikeOpRewriter<
296           memref::StoreOp,
297           /*getSrcMemRef=*/getStoreOpSrcMemRef,
298           /*rebuildOpFromAddressAndIndices=*/rebuildStoreOp,
299           /*getViewSizeForEachDim=*/getStoreOpViewSizeForEachDim>,
300       LoadStoreLikeOpRewriter<
301           nvgpu::LdMatrixOp,
302           /*getSrcMemRef=*/getLdMatrixOpSrcMemRef,
303           /*rebuildOpFromAddressAndIndices=*/rebuildLdMatrixOp>,
304       LoadStoreLikeOpRewriter<
305           vector::TransferReadOp,
306           /*getSrcMemRef=*/getTransferLikeOpSrcMemRef<vector::TransferReadOp>,
307           /*rebuildOpFromAddressAndIndices=*/rebuildTransferReadOp>,
308       LoadStoreLikeOpRewriter<
309           vector::TransferWriteOp,
310           /*getSrcMemRef=*/getTransferLikeOpSrcMemRef<vector::TransferWriteOp>,
311           /*rebuildOpFromAddressAndIndices=*/rebuildTransferWriteOp>>(
312       patterns.getContext());
313 }
314