1 //===-- MMAUtils.h - MLIR NVGPU dialect utilities for MMA operations-------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file provides utilities to assist in the lowering of other dialects 10 // (e.g. Vector) to `nvgpu.mma.*` dialect operations. 11 // 12 //===----------------------------------------------------------------------===// 13 #ifndef MLIR_DIALECT_NVGPU_UTILS_MMAUTILS_H 14 #define MLIR_DIALECT_NVGPU_UTILS_MMAUTILS_H 15 16 #include "mlir/Dialect/LLVMIR/NVVMDialect.h" 17 #include "mlir/Dialect/Vector/IR/VectorOps.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/IR/Types.h" 20 21 namespace mlir { 22 namespace nvgpu { 23 24 /// Represents the role of an operand in an MMA instruction: 25 /// `result := matmul(A, B) + C` 26 enum class MatMulOperandRole : int32_t { A = 0, B, C }; 27 28 /// Returns the first user of the `op` that is vector.contract. If no 29 /// vector.contract user exists, return failure. 30 FailureOr<vector::ContractionOp> getUserContract(Operation *op); 31 32 /// Collects information about a warp-level matrix operand represented by a 33 /// VectorType. 34 struct WarpMatrixInfo { 35 VectorType vectorType; 36 MatMulOperandRole operandRole; 37 }; 38 39 /// If `op` is a `vector.transfer_write`, return the `WarpMatrixInfo` for the 40 /// vector operand. If op is a `vector.transfer_read`, `vector.contraction`, or 41 /// `arith.constant`, return the `WarpMatrixInfo` corresponding to the result. 42 /// Otherwise, return failure. 43 FailureOr<WarpMatrixInfo> getWarpMatrixInfo(Operation *op); 44 45 /// Returns the number of bits in a single tile row. It is either 128, 256, or 46 /// 512 bits depending on the data type and` whether the operand is an 47 /// accumulator/result operand 48 int64_t inferTileWidthInBits(const WarpMatrixInfo &type); 49 50 /// Specifies information about the registers which compose a matrix fragment 51 /// according to the PTX documentation. 52 struct FragmentElementInfo { 53 Type registerLLVMType; 54 int64_t elementsPerRegister; 55 int64_t registerWidthBits; 56 int64_t numRegistersPerFragment; 57 }; 58 59 /// Returns a FragmentElementInfo struct describing the register types for the 60 /// given matrix fragment type. 61 FailureOr<FragmentElementInfo> 62 getMmaSyncRegisterType(const WarpMatrixInfo &type); 63 64 /// Returns an AffineMap which maps a two dimensions representing (laneId, 65 /// logicalValueId) and returns two results representing offsets within a 66 /// matrix operand. The offsets point to the values the thread is responsible 67 /// for (AKA the matrix fragment values) during a warp-collective matrix 68 /// operation. For a visual reference of this LaneId -> (row, col) mapping, 69 /// please see NVIDIA's PTX documentation: 70 /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma 71 FailureOr<AffineMap> 72 getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc, 73 const WarpMatrixInfo &fragmentType); 74 75 /// Encapsulates the parameters needed to lower a `nvgpu.ldmatrix` operation to 76 /// `nvvm.ldmatrix`. 77 struct LdMatrixParams { 78 VectorType fragmentType; 79 bool isAccum; 80 int64_t numTiles; 81 vector::IteratorType contiguousDimType; 82 NVVM::MMALayout targetLayout; 83 }; 84 85 /// Given `type` that contains info for a warp-matrix operand and whether or not 86 /// the load is a transposed load, return the LdMatrixParams. 87 FailureOr<LdMatrixParams> getLdMatrixParams(const WarpMatrixInfo &type, 88 bool transpose); 89 /// Returns an AffineMap which maps a single dimension representing the laneId 90 /// to two results representing offsets within the matrix operand that should 91 /// be the pointer locations a thread should pass to the ldmatrix instruction. 92 FailureOr<AffineMap> 93 getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc, 94 const LdMatrixParams ¶ms); 95 96 /// Returns whether the `vector.transfer_read` instruction can be interpreted 97 /// as a warp-level cooperative matrix load operation. This function is meant to 98 /// be used to establish whether `op` is part of a chain of such warp-level 99 /// operations. 100 bool canLowerToWarpMatrixOperation(vector::TransferReadOp op); 101 102 /// Returns whether the `vector.transfer_write` instruction can be interpreted 103 /// as a warp-level cooperative matrix store operation. This function is meant 104 /// to be used to establish whether `op` is part of a chain of such warp-level 105 /// operations. 106 bool canLowerToWarpMatrixOperation(vector::TransferWriteOp op); 107 108 } // namespace nvgpu 109 } // namespace mlir 110 111 #endif // MLIR_DIALECT_NVGPU_UTILS_MMAUTILS_H 112