//===- Transforms.h - NVGPU Dialect transformations -------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares functions that assist transformations for the nvgpu
// dialect.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_NVGPU_TRANSFORMS_TRANSFORMS_H_
#define MLIR_DIALECT_NVGPU_TRANSFORMS_TRANSFORMS_H_

#include "mlir/IR/Operation.h"

namespace mlir {
class RewriterBase;
class RewritePatternSet;

namespace nvgpu {

///
/// Passes
///

/// Optimizes vectorized accesses to a shared memory buffer specified by
/// `memrefValue`. This transformation assumes the following:
/// 1) All relevant accesses to `memrefValue` are contained within `parentOp`.
/// 2) The function will fail precondition checks if any subviews are taken of
///    `memrefValue`. All reads/writes to `memrefValue` should occur through
///    `memrefValue` directly.
///
/// Shared memory bank conflicts occur when multiple threads attempt to read or
/// write locations assigned to the same shared memory bank. For `2^N`-byte
/// vectorized accesses, we need to be concerned with conflicts among threads
/// identified as `(tid) -> tid.floordiv(2^{7-N})`. As such, this transformation
/// changes any indexed memory access (vector.load, memref.load, nvgpu.ldmatrix,
/// etc.) so that the final dimension's index value is permuted as
/// `newColIndex = oldColIndex % vectorSize +
/// perm[rowIndex](oldColIndex / vectorSize, rowIndex)`, where `rowIndex` is the
/// index of the second-to-last dimension and `perm[rowIndex]` is a permutation
/// function that depends on the row index. The permutation function is chosen
/// to ensure that sequential distributed+vectorized reads/writes down a single
/// dimension of the memref have minimal conflicts.
llvm::LogicalResult optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
                                                       Value memrefValue);

///
/// Rewrite patterns
///

//===----------------------------------------------------------------------===//
// NVGPU transformation options exposed as auxiliary structs.
//===----------------------------------------------------------------------===//
/// Enum to control the lowering of `nvgpu.mma.sync`.
enum class MmaSyncF32Lowering { TF32 = 0, TF32x3 = 1, Unkown = 2 };

/// Collect patterns to convert `mma.sync` on f32 input and rewrite it to use
/// tensor cores with the user-provided level of accuracy:
/// (a) tf32   (1 mma.sync per warp-level matrix-multiply-accumulate)
/// (b) tf32x3 (3 mma.sync per warp-level matrix-multiply-accumulate)
/// Typically, tf32 tensor core acceleration comes at a cost in accuracy from
/// missing precision bits: while f32 has 23 precision bits, tf32 has only 10.
/// tf32x3 aims to recover the lost precision by splitting each operand into
/// two tf32 values and issuing three mma.sync tensor core operations.
void populateMmaSyncF32ToTF32Patterns(
    RewritePatternSet &patterns,
    nvgpu::MmaSyncF32Lowering precision = nvgpu::MmaSyncF32Lowering::TF32);

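// Example usage (a minimal sketch, not part of this header's API): one way the
// two entry points above are typically driven from inside a pass. The pass
// scaffolding, `funcOp` (an `Operation *`), and the use of
// `applyPatternsAndFoldGreedily` are illustrative assumptions.
//
//   // Swizzle every shared-memory allocation found under `funcOp`.
//   funcOp->walk([&](memref::AllocOp allocOp) {
//     if (NVGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
//       (void)optimizeSharedMemoryReadsAndWrites(funcOp, allocOp.getMemref());
//   });
//
//   // Lower f32 mma.sync ops to their tf32 tensor core form.
//   RewritePatternSet patterns(funcOp->getContext());
//   populateMmaSyncF32ToTF32Patterns(patterns, MmaSyncF32Lowering::TF32);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
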
/// Convert global->shared vector transfers to async device copies. This
/// function looks for suitable vector transfers within the specified op and
/// converts them to "nvgpu.device_async_copy" ops. Consecutive copies are put
/// into the same sync group. If `bypassL1` is set, the "bypassL1" attribute is
/// set on the transfers that qualify for it (i.e., 16-byte transfers).
void createAsyncGroups(RewriterBase &rewriter, Operation *op, bool bypassL1);

} // namespace nvgpu
} // namespace mlir

#endif // MLIR_DIALECT_NVGPU_TRANSFORMS_TRANSFORMS_H_
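// Example usage of `createAsyncGroups` (a minimal sketch; the `IRRewriter`
// setup and `funcOp` below are illustrative assumptions, not part of this
// header):
//
//   IRRewriter rewriter(funcOp->getContext());
//   nvgpu::createAsyncGroups(rewriter, funcOp, /*bypassL1=*/true);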