//===- DecomposeMemRefs.cpp - Decompose memrefs pass implementation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the decompose memrefs pass.
//
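// The patterns below rewrite memref.load, memref.store and memref.subview ops
// that appear inside a gpu.launch body: the base buffer, offset and strides of
// the source memref are materialized with memref.extract_strided_metadata
// (created next to the memref definition, i.e. typically outside the launch
// body), the index arithmetic is emitted as folded affine.apply ops, and the
// access is redone through a memref.reinterpret_cast. Roughly (illustrative
// sketch, not the exact output; the affine maps are whatever
// computeLinearIndex and makeComposedFoldedAffineApply produce):
//
//   %v = memref.load %mem[%i, %j] : memref<?x?xf32>
//
// becomes
//
//   %base, %off, %sizes:2, %strides:2 = memref.extract_strided_metadata %mem
//   %idx = affine.apply #map()[%off, %i, %strides#0, %j, %strides#1]
//   %flat = memref.reinterpret_cast %base to offset: [%idx], sizes: [],
//             strides: [] : memref<f32> to memref<f32, strided<[], offset: ?>>
//   %v = memref.load %flat[] : memref<f32, strided<[], offset: ?>>
//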
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
#define GEN_PASS_DEF_GPUDECOMPOSEMEMREFSPASS
#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;

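// Build the 0-d memref type produced by flattening an access into `source`:
// same element type and memory space, empty shape, and a strided layout that
// carries only the (possibly dynamic) flattened offset.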
static MemRefType inferCastResultType(Value source, OpFoldResult offset) {
  auto sourceType = cast<BaseMemRefType>(source.getType());
  SmallVector<int64_t> staticOffsets;
  SmallVector<Value> dynamicOffsets;
  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
  auto stridedLayout =
      StridedLayoutAttr::get(source.getContext(), staticOffsets.front(), {});
  return MemRefType::get({}, sourceType.getElementType(), stridedLayout,
                         sourceType.getMemorySpace());
}

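// Position `builder` right after the op defining `val`, or at the start of the
// block if `val` is a block argument, so that ops materializing metadata for
// `val` are created next to its definition (typically outside the gpu.launch
// body).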
static void setInsertionPointToStart(OpBuilder &builder, Value val) {
  if (auto *parentOp = val.getDefiningOp()) {
    builder.setInsertionPointAfter(parentOp);
  } else {
    builder.setInsertionPointToStart(val.getParentBlock());
  }
}

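// The patterns below only fire on memref ops nested inside a gpu.launch body.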
static bool isInsideLaunch(Operation *op) {
  return op->getParentOfType<gpu::LaunchOp>();
}

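// Extract the base buffer of `source` via memref.extract_strided_metadata
// (hoisted to the definition of `source`) and linearize the access described
// by `subOffsets` (and, optionally, `subStrides`):
//   finalOffset = srcOffset + sum_i(subOffsets[i] * srcStrides[i])
//   strides[i]  = subStrides[i] * srcStrides[i]   (only when subStrides given)
// All arithmetic is emitted as folded affine.apply ops.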
static std::tuple<Value, OpFoldResult, SmallVector<OpFoldResult>>
getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
                        ArrayRef<OpFoldResult> subOffsets,
                        ArrayRef<OpFoldResult> subStrides = std::nullopt) {
  auto sourceType = cast<MemRefType>(source.getType());
  auto sourceRank = static_cast<unsigned>(sourceType.getRank());

  memref::ExtractStridedMetadataOp newExtractStridedMetadata;
  {
    OpBuilder::InsertionGuard g(rewriter);
    setInsertionPointToStart(rewriter, source);
    newExtractStridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
  }

  auto &&[sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();

  auto getDim = [&](int64_t dim, Value dimVal) -> OpFoldResult {
    return ShapedType::isDynamic(dim) ? getAsOpFoldResult(dimVal)
                                      : rewriter.getIndexAttr(dim);
  };

  OpFoldResult origOffset =
      getDim(sourceOffset, newExtractStridedMetadata.getOffset());
  ValueRange sourceStridesVals = newExtractStridedMetadata.getStrides();

  SmallVector<OpFoldResult> origStrides;
  origStrides.reserve(sourceRank);

  SmallVector<OpFoldResult> strides;
  strides.reserve(sourceRank);

  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
  for (auto i : llvm::seq(0u, sourceRank)) {
    OpFoldResult origStride = getDim(sourceStrides[i], sourceStridesVals[i]);

    if (!subStrides.empty()) {
      strides.push_back(affine::makeComposedFoldedAffineApply(
          rewriter, loc, s0 * s1, {subStrides[i], origStride}));
    }

    origStrides.emplace_back(origStride);
  }

  auto &&[expr, values] =
      computeLinearIndex(origOffset, origStrides, subOffsets);
  OpFoldResult finalOffset =
      affine::makeComposedFoldedAffineApply(rewriter, loc, expr, values);
  return {newExtractStridedMetadata.getBaseBuffer(), finalOffset, strides};
}

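// Flatten `source` for an access at `offsets`: reinterpret the extracted base
// buffer as a 0-d strided memref whose layout carries the linearized offset.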
static Value getFlatMemref(OpBuilder &rewriter, Location loc, Value source,
                           ValueRange offsets) {
  SmallVector<OpFoldResult> offsetsTemp = getAsOpFoldResult(offsets);
  auto &&[base, offset, ignore] =
      getFlatOffsetAndStrides(rewriter, loc, source, offsetsTemp);
  MemRefType retType = inferCastResultType(base, offset);
  return rewriter.create<memref::ReinterpretCastOp>(loc, retType, base, offset,
                                                    std::nullopt, std::nullopt);
}

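// Only memrefs of non-zero rank need flattening.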
static bool needFlatten(Value val) {
  auto type = cast<MemRefType>(val.getType());
  return type.getRank() != 0;
}

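// Only memrefs with an identity or strided layout are supported; other layouts
// are not handled by the offset/stride arithmetic above.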
static bool checkLayout(Value val) {
  auto type = cast<MemRefType>(val.getType());
  return type.getLayout().isIdentity() ||
         isa<StridedLayoutAttr>(type.getLayout());
}

namespace {
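// Rewrite memref.load inside gpu.launch into a 0-d load from the flattened
// memref produced by getFlatMemref.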
struct FlattenLoad : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp op,
                                PatternRewriter &rewriter) const override {
    if (!isInsideLaunch(op))
      return rewriter.notifyMatchFailure(op, "not inside gpu.launch");

    Value memref = op.getMemref();
    if (!needFlatten(memref))
      return rewriter.notifyMatchFailure(op, "nothing to do");

    if (!checkLayout(memref))
      return rewriter.notifyMatchFailure(op, "unsupported layout");

    Location loc = op.getLoc();
    Value flatMemref = getFlatMemref(rewriter, loc, memref, op.getIndices());
    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, flatMemref);
    return success();
  }
};

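// Same as FlattenLoad, but for memref.store: the stored value is kept and the
// store is redirected to the flattened 0-d memref.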
struct FlattenStore : public OpRewritePattern<memref::StoreOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::StoreOp op,
                                PatternRewriter &rewriter) const override {
    if (!isInsideLaunch(op))
      return rewriter.notifyMatchFailure(op, "not inside gpu.launch");

    Value memref = op.getMemref();
    if (!needFlatten(memref))
      return rewriter.notifyMatchFailure(op, "nothing to do");

    if (!checkLayout(memref))
      return rewriter.notifyMatchFailure(op, "unsupported layout");

    Location loc = op.getLoc();
    Value flatMemref = getFlatMemref(rewriter, loc, memref, op.getIndices());
    Value value = op.getValue();
    rewriter.replaceOpWithNewOp<memref::StoreOp>(op, value, flatMemref);
    return success();
  }
};

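// Rewrite memref.subview inside gpu.launch into a memref.reinterpret_cast of
// the extracted base buffer: the subview offsets are folded into a single
// linearized offset and the subview strides are multiplied into the source
// strides. Dimensions dropped by a rank-reducing subview are skipped.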
struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp op,
                                PatternRewriter &rewriter) const override {
    if (!isInsideLaunch(op))
      return rewriter.notifyMatchFailure(op, "not inside gpu.launch");

    Value memref = op.getSource();
    if (!needFlatten(memref))
      return rewriter.notifyMatchFailure(op, "nothing to do");

    if (!checkLayout(memref))
      return rewriter.notifyMatchFailure(op, "unsupported layout");

    Location loc = op.getLoc();
    SmallVector<OpFoldResult> subOffsets = op.getMixedOffsets();
    SmallVector<OpFoldResult> subSizes = op.getMixedSizes();
    SmallVector<OpFoldResult> subStrides = op.getMixedStrides();
    auto &&[base, finalOffset, strides] =
        getFlatOffsetAndStrides(rewriter, loc, memref, subOffsets, subStrides);

    auto srcType = cast<MemRefType>(memref.getType());
    auto resultType = cast<MemRefType>(op.getType());
    unsigned subRank = static_cast<unsigned>(resultType.getRank());

    llvm::SmallBitVector droppedDims = op.getDroppedDims();

    SmallVector<OpFoldResult> finalSizes;
    finalSizes.reserve(subRank);

    SmallVector<OpFoldResult> finalStrides;
    finalStrides.reserve(subRank);

    for (auto i : llvm::seq(0u, static_cast<unsigned>(srcType.getRank()))) {
      if (droppedDims.test(i))
        continue;

      finalSizes.push_back(subSizes[i]);
      finalStrides.push_back(strides[i]);
    }

    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
        op, resultType, base, finalOffset, finalSizes, finalStrides);
    return success();
  }
};

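// Pass entry point: greedily applies the flattening patterns above to the
// whole operation.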
struct GpuDecomposeMemrefsPass
    : public impl::GpuDecomposeMemrefsPassBase<GpuDecomposeMemrefsPass> {

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());

    populateGpuDecomposeMemrefsPatterns(patterns);

    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void mlir::populateGpuDecomposeMemrefsPatterns(RewritePatternSet &patterns) {
  patterns.insert<FlattenLoad, FlattenStore, FlattenSubview>(
      patterns.getContext());
}

std::unique_ptr<Pass> mlir::createGpuDecomposeMemrefsPass() {
  return std::make_unique<GpuDecomposeMemrefsPass>();
}