//===- OptimizeSharedMemory.cpp - MLIR NVGPU pass implementation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements transforms to optimize accesses to shared memory.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
#include "mlir/Dialect/NVGPU/Transforms/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"

namespace mlir {
namespace nvgpu {
#define GEN_PASS_DEF_OPTIMIZESHAREDMEMORY
#include "mlir/Dialect/NVGPU/Transforms/Passes.h.inc"
} // namespace nvgpu
} // namespace mlir

using namespace mlir;
using namespace mlir::nvgpu;

/// The size of a shared memory line according to NV documentation.
constexpr int64_t kSharedMemoryLineSizeBytes = 128;
/// We optimize for 128-bit accesses, but this can be made an argument in the
/// future.
constexpr int64_t kDefaultVectorSizeBits = 128;

/// Uses `srcIndexValue` to permute `tgtIndexValue` via
/// `result = xor(floordiv(srcIdxVal, permuteEveryN),
///               floordiv(tgtIdxVal, vectorSize))
///           + tgtIdxVal % vectorSize`
/// This is done using an optimized sequence of `arith` operations.
static Value permuteVectorOffset(OpBuilder &b, Location loc,
                                 ArrayRef<Value> indices, MemRefType memrefTy,
                                 int64_t srcDim, int64_t tgtDim) {
  // Adjust the src index to change how often the permutation changes
  // if necessary.
  Value src = indices[srcDim];

  // We only want to permute every N iterations of the target dim where N is
  // ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
  const int64_t permuteEveryN = std::max<int64_t>(
      1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
                                        memrefTy.getElementTypeBitWidth()) /
                                       8));

  // clang-format off
  // Index bit representation (b0 = least significant bit) for dim(1)
  // of a `memref<?x?xDT>` is as follows:
  //   N := log2(128/elementSizeBits)
  //   M := log2(dimSize(1))
  // then
  //   bits[0:N] = sub-vector element offset
  //   bits[N:M] = vector index
  // clang-format on
  int64_t n =
      llvm::Log2_64(kDefaultVectorSizeBits / memrefTy.getElementTypeBitWidth());
  int64_t m = llvm::Log2_64(memrefTy.getDimSize(tgtDim));

  // Capture bits[0:(M-N)] of src by first creating a (M-N) mask.
  int64_t mask = (1LL << (m - n)) - 1;
  if (permuteEveryN > 1)
    mask = mask << llvm::Log2_64(permuteEveryN);
  Value srcBits = b.create<arith::ConstantIndexOp>(loc, mask);
  srcBits = b.create<arith::AndIOp>(loc, src, srcBits);

  // Use the src bits to permute the target bits b[N:M] containing the
  // vector offset.
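  // Align the captured src bits with bit N of the target index before the
  // XOR below. When permuting only every `permuteEveryN` iterations, the mask
  // above already starts at bit log2(permuteEveryN), so the shift amount is
  // n - log2(permuteEveryN) (a right shift if that quantity is negative).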
  if (permuteEveryN > 1) {
    int64_t shlBits = n - llvm::Log2_64(permuteEveryN);
    if (shlBits > 0) {
      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, shlBits);
      srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
    } else if (shlBits < 0) {
      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, -1 * shlBits);
      srcBits = b.createOrFold<arith::ShRUIOp>(loc, srcBits, finalShiftVal);
    }
  } else {
    Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, n);
    srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
  }

  Value permutedVectorIdx =
      b.create<arith::XOrIOp>(loc, indices[tgtDim], srcBits);
  return permutedVectorIdx;
}

static void transformIndices(OpBuilder &builder, Location loc,
                             SmallVector<Value, 4> &indices,
                             MemRefType memrefTy, int64_t srcDim,
                             int64_t tgtDim) {
  indices[tgtDim] =
      permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim);
}

/// Return all operations within `parentOp` that read from or write to
/// `shmMemRef`.
static LogicalResult
getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
                      SmallVector<Operation *, 16> &readOps,
                      SmallVector<Operation *, 16> &writeOps) {
  parentOp->walk([&](Operation *op) {
    MemoryEffectOpInterface iface = dyn_cast<MemoryEffectOpInterface>(op);
    if (!iface)
      return;
    std::optional<MemoryEffects::EffectInstance> effect =
        iface.getEffectOnValue<MemoryEffects::Read>(shmMemRef);
    if (effect) {
      readOps.push_back(op);
      return;
    }
    effect = iface.getEffectOnValue<MemoryEffects::Write>(shmMemRef);
    if (effect)
      writeOps.push_back(op);
  });

  // Restrict to a supported set of ops. We also require at least 2D access,
  // although this could be relaxed.
  if (llvm::any_of(readOps, [](Operation *op) {
        return !isa<memref::LoadOp, vector::LoadOp, nvgpu::LdMatrixOp>(op) ||
               getIndices(op).size() < 2;
      }))
    return failure();
  if (llvm::any_of(writeOps, [](Operation *op) {
        return !isa<memref::StoreOp, vector::StoreOp, nvgpu::DeviceAsyncCopyOp>(
                   op) ||
               getIndices(op).size() < 2;
      }))
    return failure();

  return success();
}

llvm::LogicalResult
mlir::nvgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
                                                Value memrefValue) {
  auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
  if (!memRefType || !NVGPUDialect::hasSharedMemoryAddressSpace(memRefType))
    return failure();

  // 0D MemRefs are not supported.
  if (memRefType.getRank() == 0)
    return failure();

  // Abort if the given value has any sub-views; we do not do any alias
  // analysis.
  bool hasSubView = false;
  parentOp->walk([&](memref::SubViewOp subView) { hasSubView = true; });
  if (hasSubView)
    return failure();

  // Check if the swizzle is necessary given the assumption of 128b accesses:
  // if dim[rank-1] is small enough to fit 8 rows in a 128B line, bail out.
  const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
  const int64_t rowsPerLine =
      (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
      rowSize;
  const int64_t threadGroupSize =
      1LL << (7 - llvm::Log2_64(kDefaultVectorSizeBits / 8));
  if (rowsPerLine >= threadGroupSize)
    return failure();

  // Get sets of operations within the function that read/write to shared
  // memory.
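  // Both sets are rewritten with the same permutation so that reads observe
  // the swizzled layout produced by the writes.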
  SmallVector<Operation *, 16> shmReadOps;
  SmallVector<Operation *, 16> shmWriteOps;
  if (failed(getShmReadAndWriteOps(parentOp, memrefValue, shmReadOps,
                                   shmWriteOps)))
    return failure();

  if (shmReadOps.empty() || shmWriteOps.empty())
    return failure();

  OpBuilder builder(parentOp->getContext());

  int64_t tgtDim = memRefType.getRank() - 1;
  int64_t srcDim = memRefType.getRank() - 2;

  // Transform indices for the ops writing to shared memory.
  while (!shmWriteOps.empty()) {
    Operation *shmWriteOp = shmWriteOps.back();
    shmWriteOps.pop_back();
    builder.setInsertionPoint(shmWriteOp);

    auto indices = getIndices(shmWriteOp);
    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
    transformIndices(builder, shmWriteOp->getLoc(), transformedIndices,
                     memRefType, srcDim, tgtDim);
    setIndices(shmWriteOp, transformedIndices);
  }

  // Transform indices for the ops reading from shared memory.
  while (!shmReadOps.empty()) {
    Operation *shmReadOp = shmReadOps.back();
    shmReadOps.pop_back();
    builder.setInsertionPoint(shmReadOp);

    auto indices = getIndices(shmReadOp);
    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
    transformIndices(builder, shmReadOp->getLoc(), transformedIndices,
                     memRefType, srcDim, tgtDim);
    setIndices(shmReadOp, transformedIndices);
  }

  return success();
}

namespace {
class OptimizeSharedMemoryPass
    : public nvgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
public:
  OptimizeSharedMemoryPass() = default;

  void runOnOperation() override {
    Operation *op = getOperation();
    SmallVector<memref::AllocOp> shmAllocOps;
    op->walk([&](memref::AllocOp allocOp) {
      if (!NVGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
        return;
      shmAllocOps.push_back(allocOp);
    });
    for (auto allocOp : shmAllocOps) {
      if (failed(optimizeSharedMemoryReadsAndWrites(getOperation(),
                                                    allocOp.getMemref())))
        return;
    }
  }
};
} // namespace

std::unique_ptr<Pass> mlir::nvgpu::createOptimizeSharedMemoryPass() {
  return std::make_unique<OptimizeSharedMemoryPass>();
}
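
// Usage sketch: assuming the pass is registered under the
// `nvgpu-optimize-shared-memory` argument in Passes.td, it can be exercised
// with `mlir-opt input.mlir -nvgpu-optimize-shared-memory`, or added to a
// pipeline programmatically:
//   mlir::PassManager pm(&context);
//   pm.addPass(mlir::nvgpu::createOptimizeSharedMemoryPass());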