//===- OptimizeSharedMemory.cpp - MLIR NVGPU pass implementation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements transforms to optimize accesses to shared memory.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
#include "mlir/Dialect/NVGPU/Transforms/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"

namespace mlir {
namespace nvgpu {
#define GEN_PASS_DEF_OPTIMIZESHAREDMEMORY
#include "mlir/Dialect/NVGPU/Transforms/Passes.h.inc"
} // namespace nvgpu
} // namespace mlir

using namespace mlir;
using namespace mlir::nvgpu;

/// The size of a shared memory line according to NV documentation.
constexpr int64_t kSharedMemoryLineSizeBytes = 128;
/// We optimize for 128-bit accesses, but this can be made an argument in the
/// future.
constexpr int64_t kDefaultVectorSizeBits = 128;

/// Uses `srcIndexValue` to permute `tgtIndexValue` via
/// `result = xor(floordiv(srcIdxVal, permuteEveryN),
///               floordiv(tgtIdxVal, vectorSize))
///           + tgtIdxVal % vectorSize`
/// This is done using an optimized sequence of `arith` operations.
static Value permuteVectorOffset(OpBuilder &b, Location loc,
                                 ArrayRef<Value> indices, MemRefType memrefTy,
                                 int64_t srcDim, int64_t tgtDim) {
  // Adjust the src index to change how often the permutation changes
  // if necessary.
  Value src = indices[srcDim];

  // We only want to permute every N iterations of the target dim where N is
  // ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
  const int64_t permuteEveryN = std::max<int64_t>(
      1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
                                        memrefTy.getElementTypeBitWidth()) /
                                       8));

  // clang-format off
  // Index bit representation (b0 = least significant bit) for dim(1)
  // of a `memref<?x?xDT>` is as follows:
  //   N := log2(128/elementSizeBits)
  //   M := log2(dimSize(1))
  //   then
  //     bits[0:N] = sub-vector element offset
  //     bits[N:M] = vector index
  // clang-format on
  int64_t n =
      llvm::Log2_64(kDefaultVectorSizeBits / memrefTy.getElementTypeBitWidth());
  int64_t m = llvm::Log2_64(memrefTy.getDimSize(tgtDim));

  // Capture bits[0:(M-N)] of src by first creating a (M-N) mask.
  int64_t mask = (1LL << (m - n)) - 1;
  if (permuteEveryN > 1)
    mask = mask << llvm::Log2_64(permuteEveryN);
  Value srcBits = b.create<arith::ConstantIndexOp>(loc, mask);
  srcBits = b.create<arith::AndIOp>(loc, src, srcBits);

  // Use the src bits to permute the target bits b[N:M] containing the
  // vector offset.
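  // The masked src bits currently occupy positions
  // [log2(permuteEveryN), log2(permuteEveryN) + (M-N)) (or [0, M-N) when
  // permuteEveryN == 1); shift them so they line up with the vector-index
  // bits [N, M) of the target index before XOR-ing below.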
  if (permuteEveryN > 1) {
    int64_t shlBits = n - llvm::Log2_64(permuteEveryN);
    if (shlBits > 0) {
      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, shlBits);
      srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
    } else if (shlBits < 0) {
      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, -1 * shlBits);
      srcBits = b.createOrFold<arith::ShRUIOp>(loc, srcBits, finalShiftVal);
    }
  } else {
    Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, n);
    srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
  }

  Value permutedVectorIdx =
      b.create<arith::XOrIOp>(loc, indices[tgtDim], srcBits);
  return permutedVectorIdx;
}

static void transformIndices(OpBuilder &builder, Location loc,
                             SmallVector<Value, 4> &indices,
                             MemRefType memrefTy, int64_t srcDim,
                             int64_t tgtDim) {
  indices[tgtDim] =
      permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim);
}

/// Return all operations within `parentOp` that read from or write to
/// `shmMemRef`.
static LogicalResult
getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
                      SmallVector<Operation *, 16> &readOps,
                      SmallVector<Operation *, 16> &writeOps) {
  parentOp->walk([&](Operation *op) {
    MemoryEffectOpInterface iface = dyn_cast<MemoryEffectOpInterface>(op);
    if (!iface)
      return;
    std::optional<MemoryEffects::EffectInstance> effect =
        iface.getEffectOnValue<MemoryEffects::Read>(shmMemRef);
    if (effect) {
      readOps.push_back(op);
      return;
    }
    effect = iface.getEffectOnValue<MemoryEffects::Write>(shmMemRef);
    if (effect)
      writeOps.push_back(op);
  });

  // Restrict to a supported set of ops. We also require at least 2D access,
  // although this could be relaxed.
  if (llvm::any_of(readOps, [](Operation *op) {
        return !isa<memref::LoadOp, vector::LoadOp, nvgpu::LdMatrixOp>(op) ||
               getIndices(op).size() < 2;
      }))
    return failure();
  if (llvm::any_of(writeOps, [](Operation *op) {
        return !isa<memref::StoreOp, vector::StoreOp, nvgpu::DeviceAsyncCopyOp>(
                   op) ||
               getIndices(op).size() < 2;
      }))
    return failure();

  return success();
}

llvm::LogicalResult
mlir::nvgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
                                                Value memrefValue) {
  auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
  if (!memRefType || !NVGPUDialect::hasSharedMemoryAddressSpace(memRefType))
    return failure();

  // 0-D memrefs are not supported.
  if (memRefType.getRank() == 0)
    return failure();

  // Abort if the given value has any sub-views; we do not do any alias
  // analysis.
  bool hasSubView = false;
  parentOp->walk([&](memref::SubViewOp subView) { hasSubView = true; });
  if (hasSubView)
    return failure();

  // Check if this is necessary given the assumption of 128b accesses:
  // If dim[rank-1] is small enough to fit 8 rows in a 128B line.
  const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
  const int64_t rowsPerLine =
      (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
      rowSize;
  const int64_t threadGroupSize =
      1LL << (7 - llvm::Log2_64(kDefaultVectorSizeBits / 8));
  if (rowsPerLine >= threadGroupSize)
    return failure();

  // Get sets of operations within the function that read/write to shared
  // memory.
  SmallVector<Operation *, 16> shmReadOps;
  SmallVector<Operation *, 16> shmWriteOps;
  if (failed(getShmReadAndWriteOps(parentOp, memrefValue, shmReadOps,
                                   shmWriteOps)))
    return failure();

  if (shmReadOps.empty() || shmWriteOps.empty())
    return failure();

  OpBuilder builder(parentOp->getContext());

  int64_t tgtDim = memRefType.getRank() - 1;
  int64_t srcDim = memRefType.getRank() - 2;

  // Transform indices for the ops writing to shared memory.
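  // The read loop below applies the exact same permutation, so every value is
  // stored to and later loaded from the same swizzled location; only the
  // in-memory layout (and hence the bank assignment) changes.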
  while (!shmWriteOps.empty()) {
    Operation *shmWriteOp = shmWriteOps.back();
    shmWriteOps.pop_back();
    builder.setInsertionPoint(shmWriteOp);

    auto indices = getIndices(shmWriteOp);
    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
    transformIndices(builder, shmWriteOp->getLoc(), transformedIndices,
                     memRefType, srcDim, tgtDim);
    setIndices(shmWriteOp, transformedIndices);
  }

  // Transform indices for the ops reading from shared memory.
  while (!shmReadOps.empty()) {
    Operation *shmReadOp = shmReadOps.back();
    shmReadOps.pop_back();
    builder.setInsertionPoint(shmReadOp);

    auto indices = getIndices(shmReadOp);
    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
    transformIndices(builder, shmReadOp->getLoc(), transformedIndices,
                     memRefType, srcDim, tgtDim);
    setIndices(shmReadOp, transformedIndices);
  }

  return success();
}

namespace {
class OptimizeSharedMemoryPass
    : public nvgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
public:
  OptimizeSharedMemoryPass() = default;

  void runOnOperation() override {
    Operation *op = getOperation();
    SmallVector<memref::AllocOp> shmAllocOps;
    op->walk([&](memref::AllocOp allocOp) {
      if (!NVGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
        return;
      shmAllocOps.push_back(allocOp);
    });
    for (auto allocOp : shmAllocOps) {
      if (failed(optimizeSharedMemoryReadsAndWrites(getOperation(),
                                                    allocOp.getMemref())))
        return;
    }
  }
};
} // namespace

std::unique_ptr<Pass> mlir::nvgpu::createOptimizeSharedMemoryPass() {
  return std::make_unique<OptimizeSharedMemoryPass>();
}