xref: /llvm-project/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h (revision cafb6284d18bbdb952ae6d5e4aa97912d57dbfb8)
1 //===-- MMAUtils.h - MLIR NVGPU dialect utilities for MMA operations-------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file provides utilities to assist in the lowering of other dialects
10 // (e.g. Vector) to `nvgpu.mma.*` dialect operations.
11 //
12 //===----------------------------------------------------------------------===//
13 #ifndef MLIR_DIALECT_NVGPU_UTILS_MMAUTILS_H
14 #define MLIR_DIALECT_NVGPU_UTILS_MMAUTILS_H
15 
16 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
17 #include "mlir/Dialect/Vector/IR/VectorOps.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/Types.h"
20 
21 namespace mlir {
22 namespace nvgpu {
23 
24 /// Identifies the role an operand plays in an MMA instruction, where the
25 /// instruction computes `result := matmul(A, B) + C`.
26 enum class MatMulOperandRole : int32_t { A = 0, B, C };
27 
28 /// Returns the first user of `op` that is a `vector.contract` operation. If
29 /// no `vector.contract` user exists, returns failure.
30 FailureOr<vector::ContractionOp> getUserContract(Operation *op);
31 
32 /// Collects information about a warp-level matrix operand represented by a
33 /// VectorType.
34 struct WarpMatrixInfo {
35   VectorType vectorType;         ///< Vector type representing the operand.
36   MatMulOperandRole operandRole; ///< Role of the operand: A, B, or C.
37 };
38 
39 /// If `op` is a `vector.transfer_write`, returns the `WarpMatrixInfo` for the
40 /// vector operand. If `op` is a `vector.transfer_read`, `vector.contract`, or
41 /// `arith.constant`, returns the `WarpMatrixInfo` corresponding to the result.
42 /// Otherwise, returns failure.
43 FailureOr<WarpMatrixInfo> getWarpMatrixInfo(Operation *op);
44 
45 /// Returns the number of bits in a single tile row. It is either 128, 256, or
46 /// 512 bits depending on the data type and whether the operand is an
47 /// accumulator/result operand.
48 int64_t inferTileWidthInBits(const WarpMatrixInfo &type);
49 
50 /// Specifies information about the registers that compose a matrix fragment,
51 /// according to the PTX documentation.
52 struct FragmentElementInfo {
53   Type registerLLVMType;
54   int64_t elementsPerRegister;
55   int64_t registerWidthBits;
56   int64_t numRegistersPerFragment;
57 };
58 
59 /// Returns a `FragmentElementInfo` struct describing the register types for
60 /// the matrix fragment type given by `type`.
61 FailureOr<FragmentElementInfo>
62 getMmaSyncRegisterType(const WarpMatrixInfo &type);
63 
64 /// Returns an AffineMap which maps two dimensions representing (laneId,
65 /// logicalValueId) to two results representing offsets within a
66 /// matrix operand. The offsets point to the values the thread is responsible
67 /// for (AKA the matrix fragment values) during a warp-collective matrix
68 /// operation. For a visual reference of this LaneId -> (row, col) mapping,
69 /// please see NVIDIA's PTX documentation:
70 /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma
71 FailureOr<AffineMap>
72 getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc,
73                                   const WarpMatrixInfo &fragmentType);
74 
75 /// Encapsulates the parameters needed to lower an `nvgpu.ldmatrix` operation
76 /// to `nvvm.ldmatrix`.
77 struct LdMatrixParams {
78   VectorType fragmentType;
79   bool isAccum;
80   int64_t numTiles;
81   vector::IteratorType contiguousDimType;
82   NVVM::MMALayout targetLayout;
83 };
84 
85 /// Given `type` describing a warp-matrix operand and `transpose` indicating
86 /// whether or not the load is a transposed load, returns the LdMatrixParams.
87 FailureOr<LdMatrixParams> getLdMatrixParams(const WarpMatrixInfo &type,
88                                             bool transpose);
89 /// Returns an AffineMap which maps a single dimension representing the laneId
90 /// to two results representing offsets within the matrix operand. The offsets
91 /// are the pointer locations a thread should pass to the ldmatrix instruction.
92 FailureOr<AffineMap>
93 getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc,
94                                const LdMatrixParams &params);
95 
96 /// Returns whether the `vector.transfer_read` operation `op` can be
97 /// interpreted as a warp-level cooperative matrix load operation. This
98 /// function is meant to be used to establish whether `op` is part of a chain
99 /// of such warp-level operations.
100 bool canLowerToWarpMatrixOperation(vector::TransferReadOp op);
101 
102 /// Returns whether the `vector.transfer_write` operation `op` can be
103 /// interpreted as a warp-level cooperative matrix store operation. This
104 /// function is meant to be used to establish whether `op` is part of a chain
105 /// of such warp-level operations.
106 bool canLowerToWarpMatrixOperation(vector::TransferWriteOp op);
107 
108 } // namespace nvgpu
109 } // namespace mlir
110 
111 #endif // MLIR_DIALECT_NVGPU_UTILS_MMAUTILS_H
112