//===- OptimizeSharedMemory.cpp - MLIR NVGPU pass implementation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements transforms to optimize accesses to shared memory.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
#include "mlir/Dialect/NVGPU/Transforms/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"

namespace mlir {
namespace nvgpu {
#define GEN_PASS_DEF_OPTIMIZESHAREDMEMORY
#include "mlir/Dialect/NVGPU/Transforms/Passes.h.inc"
} // namespace nvgpu
} // namespace mlir

using namespace mlir;
using namespace mlir::nvgpu;

/// The size of a shared memory line according to NV documentation.
constexpr int64_t kSharedMemoryLineSizeBytes = 128;
/// We optimize for 128-bit accesses, but this can be made an argument in the
/// future.
constexpr int64_t kDefaultVectorSizeBits = 128;

/// Uses `srcIndexValue` to permute `tgtIndexValue` via
/// `result = xor(floordiv(srcIdxVal, permuteEveryN),
///               floordiv(tgtIdxVal, vectorSize)) * vectorSize
///            + tgtIdxVal % vectorSize`
/// This is done using an optimized sequence of `arith` operations.
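/// For example (illustrative only, assuming a hypothetical memref<32x32xf16>):
/// vectorSize = 128 / 16 = 8 elements and permuteEveryN = 2, so for
/// srcIdxVal = 5 and tgtIdxVal = 8 the result is
/// xor(5 / 2, 8 / 8) * 8 + 8 % 8 = xor(2, 1) * 8 + 0 = 24.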
static Value permuteVectorOffset(OpBuilder &b, Location loc,
                                 ArrayRef<Value> indices, MemRefType memrefTy,
                                 int64_t srcDim, int64_t tgtDim) {
  // Adjust the src index to change how often the permutation changes
  // if necessary.
  Value src = indices[srcDim];

  // We only want to change the permutation every N iterations of the source
  // dim, where N is ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
  const int64_t permuteEveryN = std::max<int64_t>(
      1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
                                        memrefTy.getElementTypeBitWidth()) /
                                       8));
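  // For example (illustrative only): for a hypothetical memref<32x32xf16>, a
  // row of the target dim occupies 32 * 16 / 8 = 64 bytes, so permuteEveryN =
  // max(1, 128 / 64) = 2 and the permutation pattern changes only every other
  // value of the source index.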

  // clang-format off
  // Index bit representation (b0 = least significant bit) for dim(1)
  // of a `memref<?x?xDT>` is as follows:
  // N := log2(128/elementSizeBits)
  // M := log2(dimSize(1))
  // then
  // bits[0:N] = sub-vector element offset
  // bits[N:M] = vector index
  // clang-format on
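  // For example (illustrative only): for a hypothetical memref<32x32xf16>,
  // N = log2(128 / 16) = 3 and M = log2(32) = 5, so bits [0, 3) of the target
  // index select the element within a 128-bit vector and bits [3, 5) select
  // which vector within the row is accessed.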
  int64_t n =
      llvm::Log2_64(kDefaultVectorSizeBits / memrefTy.getElementTypeBitWidth());
  int64_t m = llvm::Log2_64(memrefTy.getDimSize(tgtDim));

  // Capture bits[0:(M-N)] of src by first creating an (M-N)-bit mask.
  int64_t mask = (1LL << (m - n)) - 1;
  if (permuteEveryN > 1)
    mask = mask << llvm::Log2_64(permuteEveryN);
  Value srcBits = b.create<arith::ConstantIndexOp>(loc, mask);
  srcBits = b.create<arith::AndIOp>(loc, src, srcBits);

  // Use the src bits to permute the target bits b[N:M] containing the
  // vector offset.
  if (permuteEveryN > 1) {
    int64_t shlBits = n - llvm::Log2_64(permuteEveryN);
    if (shlBits > 0) {
      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, shlBits);
      srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
    } else if (shlBits < 0) {
      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, -1 * shlBits);
      srcBits = b.createOrFold<arith::ShRUIOp>(loc, srcBits, finalShiftVal);
    }
  } else {
    Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, n);
    srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
  }

  Value permutedVectorIdx =
      b.create<arith::XOrIOp>(loc, indices[tgtDim], srcBits);
  return permutedVectorIdx;
}
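
// A plain-integer sketch of the index computation built above (illustrative
// only; `vectorSize`, `vecIdx`, and `newVecIdx` are hypothetical names, not
// symbols used elsewhere in this file):
//   int64_t vectorSize = kDefaultVectorSizeBits / elementBitWidth;
//   int64_t vecIdx = tgt / vectorSize;
//   int64_t newVecIdx = vecIdx ^ ((src / permuteEveryN) % (1LL << (m - n)));
//   int64_t result = newVecIdx * vectorSize + tgt % vectorSize;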

static void transformIndices(OpBuilder &builder, Location loc,
                             SmallVector<Value, 4> &indices,
                             MemRefType memrefTy, int64_t srcDim,
                             int64_t tgtDim) {
  indices[tgtDim] =
      permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim);
}

/// Collect all operations within `parentOp` that read from or write to
/// `shmMemRef` into `readOps`/`writeOps`. Fails if an unsupported access or
/// an access with fewer than two indices is found.
static LogicalResult
getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
                      SmallVector<Operation *, 16> &readOps,
                      SmallVector<Operation *, 16> &writeOps) {
  parentOp->walk([&](Operation *op) {
    MemoryEffectOpInterface iface = dyn_cast<MemoryEffectOpInterface>(op);
    if (!iface)
      return;
    std::optional<MemoryEffects::EffectInstance> effect =
        iface.getEffectOnValue<MemoryEffects::Read>(shmMemRef);
    if (effect) {
      readOps.push_back(op);
      return;
    }
    effect = iface.getEffectOnValue<MemoryEffects::Write>(shmMemRef);
    if (effect)
      writeOps.push_back(op);
  });

  // Restrict to a supported set of ops. We also require at least 2D access,
  // although this could be relaxed.
  if (llvm::any_of(readOps, [](Operation *op) {
        return !isa<memref::LoadOp, vector::LoadOp, nvgpu::LdMatrixOp>(op) ||
               getIndices(op).size() < 2;
      }))
    return failure();
  if (llvm::any_of(writeOps, [](Operation *op) {
        return !isa<memref::StoreOp, vector::StoreOp, nvgpu::DeviceAsyncCopyOp>(
                   op) ||
               getIndices(op).size() < 2;
      }))
    return failure();

  return success();
}

llvm::LogicalResult
mlir::nvgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
                                                Value memrefValue) {
  auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
  if (!memRefType || !NVGPUDialect::hasSharedMemoryAddressSpace(memRefType))
    return failure();

  // 0-D MemRefs are not supported.
  if (memRefType.getRank() == 0)
    return failure();

  // Abort if `parentOp` contains any sub-views; we do not do any alias
  // analysis.
  bool hasSubView = false;
  parentOp->walk([&](memref::SubViewOp subView) { hasSubView = true; });
  if (hasSubView)
    return failure();

  // Check if the transformation is necessary given the assumption of 128-bit
  // accesses: bail out if dim[rank-1] is small enough that `threadGroupSize`
  // or more rows fit into a single 128-byte shared memory line.
  const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
  const int64_t rowsPerLine =
      (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
      rowSize;
  const int64_t threadGroupSize =
      1LL << (7 - llvm::Log2_64(kDefaultVectorSizeBits / 8));
  if (rowsPerLine >= threadGroupSize)
    return failure();
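  // For example (illustrative only): for a hypothetical memref<64x8xf16>,
  // rowsPerLine = (8 * 128 / 16) / 8 = 8 and threadGroupSize =
  // 1 << (7 - log2(16)) = 8, so the pass bails out; for memref<64x32xf16>,
  // rowsPerLine = 2 < 8 and the transformation proceeds.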

  // Get the sets of operations within `parentOp` that read/write to shared
  // memory.
  SmallVector<Operation *, 16> shmReadOps;
  SmallVector<Operation *, 16> shmWriteOps;
  if (failed(getShmReadAndWriteOps(parentOp, memrefValue, shmReadOps,
                                   shmWriteOps)))
    return failure();

  if (shmReadOps.empty() || shmWriteOps.empty())
    return failure();

  OpBuilder builder(parentOp->getContext());

  int64_t tgtDim = memRefType.getRank() - 1;
  int64_t srcDim = memRefType.getRank() - 2;

  // Transform indices for the ops writing to shared memory.
  while (!shmWriteOps.empty()) {
    Operation *shmWriteOp = shmWriteOps.back();
    shmWriteOps.pop_back();
    builder.setInsertionPoint(shmWriteOp);

    auto indices = getIndices(shmWriteOp);
    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
    transformIndices(builder, shmWriteOp->getLoc(), transformedIndices,
                     memRefType, srcDim, tgtDim);
    setIndices(shmWriteOp, transformedIndices);
  }

  // Transform indices for the ops reading from shared memory.
  while (!shmReadOps.empty()) {
    Operation *shmReadOp = shmReadOps.back();
    shmReadOps.pop_back();
    builder.setInsertionPoint(shmReadOp);

    auto indices = getIndices(shmReadOp);
    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
    transformIndices(builder, shmReadOp->getLoc(), transformedIndices,
                     memRefType, srcDim, tgtDim);
    setIndices(shmReadOp, transformedIndices);
  }

  return success();
}

namespace {
class OptimizeSharedMemoryPass
    : public nvgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
public:
  OptimizeSharedMemoryPass() = default;

  void runOnOperation() override {
    Operation *op = getOperation();
    SmallVector<memref::AllocOp> shmAllocOps;
    op->walk([&](memref::AllocOp allocOp) {
      if (!NVGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
        return;
      shmAllocOps.push_back(allocOp);
    });
    for (auto allocOp : shmAllocOps) {
      if (failed(optimizeSharedMemoryReadsAndWrites(getOperation(),
                                                    allocOp.getMemref())))
        return;
    }
  }
};
} // namespace

std::unique_ptr<Pass> mlir::nvgpu::createOptimizeSharedMemoryPass() {
  return std::make_unique<OptimizeSharedMemoryPass>();
}