xref: /llvm-project/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h (revision 7030280329c3a723a42304e92f9c207acb8ea731)
1 //===- GPUDialect.h - MLIR Dialect for GPU Kernels --------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the GPU kernel-related operations and puts them in the
10 // corresponding dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_DIALECT_GPU_IR_GPUDIALECT_H
15 #define MLIR_DIALECT_GPU_IR_GPUDIALECT_H
16 
17 #include "mlir/Bytecode/BytecodeOpInterface.h"
18 #include "mlir/Dialect/DLTI/Traits.h"
19 #include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/Dialect.h"
23 #include "mlir/IR/OpDefinition.h"
24 #include "mlir/IR/OpImplementation.h"
25 #include "mlir/IR/RegionKindInterface.h"
26 #include "mlir/IR/SymbolTable.h"
27 #include "mlir/Interfaces/ControlFlowInterfaces.h"
28 #include "mlir/Interfaces/FunctionInterfaces.h"
29 #include "mlir/Interfaces/InferIntRangeInterface.h"
30 #include "mlir/Interfaces/InferTypeOpInterface.h"
31 #include "mlir/Interfaces/SideEffectInterfaces.h"
32 #include "llvm/ADT/STLExtras.h"
33 
34 namespace mlir {
35 namespace gpu {
36 
37 /// Utility class for the GPU dialect to represent triples of `Value`s
38 /// accessible through `.x`, `.y`, and `.z` similarly to CUDA notation.
39 struct KernelDim3 {
40   Value x;
41   Value y;
42   Value z;
43 };
44 
45 class AsyncTokenType
46     : public Type::TypeBase<AsyncTokenType, Type, TypeStorage> {
47 public:
48   // Used for generic hooks in TypeBase.
49   using Base::Base;
50 
51   static constexpr StringLiteral name = "gpu.async_token";
52 };
53 
54 /// MMAMatrixType storage and uniquing. Array is uniqued based on its shape
55 /// and type.
56 struct MMAMatrixStorageType : public TypeStorage {
57   MMAMatrixStorageType(unsigned numDims, const int64_t *dimShapes,
58                        Type elementType, StringRef operand)
59       : dimShapes(dimShapes), numDims(numDims), elementType(elementType),
60         operand(operand) {}
61 
62   /// The hash key for uniquing.
63   using KeyTy = std::tuple<ArrayRef<int64_t>, Type, StringRef>;
64   bool operator==(const KeyTy &key) const {
65     return key == KeyTy(getShape(), elementType, operand);
66   }
67 
68   /// Construction.
69   static MMAMatrixStorageType *construct(TypeStorageAllocator &allocator,
70                                          const KeyTy &key) {
71     ArrayRef<int64_t> shape = allocator.copyInto(std::get<0>(key));
72     StringRef operand = allocator.copyInto(std::get<2>(key));
73 
74     return new (allocator.allocate<MMAMatrixStorageType>())
75         MMAMatrixStorageType(shape.size(), shape.data(), std::get<1>(key),
76                              operand);
77   }
78 
79   ArrayRef<int64_t> getShape() const {
80     return ArrayRef<int64_t>(dimShapes, numDims);
81   }
82 
83   StringRef getOperand() const { return operand; }
84 
85   /// Reference to the shape of the MMA matrix.
86   const int64_t *dimShapes;
87 
88   /// Number of dimensions in the MMA matrix.
89   unsigned numDims;
90 
91   /// Element type of elements held in the MMA matrix.
92   Type elementType;
93 
94   /// MMA operand that this MMAMatrix holds. The general form of operation this
95   /// type supports is given by the equation C += A*B. This field specifies
96   /// which operand in the given equation is held by this type. The valid values
97   /// are "AOp", "BOp" and "COp".
98   StringRef operand;
99 };
100 
101 /// MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply
102 /// accumulate operations. MMAMatrices are taken as direct operands by these
103 /// operations and are also produced as results. These matrices are meant to
104 /// reside in the registers. A limited number of pointwise operations can be
105 /// performed on these matrices, i.e., operations which operate uniformly on
106 /// all the elements in the matrix and do not change the order of matrix
107 /// elements. The above conditions exist because the layout of matrix elements
108 /// inside the matrix is opaque i.e., the elements may be present in the
109 /// matrix in any order. The general usage of this type is shown as follows:-
110 ///
111 ///   %0 = gpu.subgroup_mma_load_matrix %arg0[%c0, %c0] {leadDimension = 16 :
112 ///           index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
113 ///
114 /// The MMAMatrixType describes the shape of the matrix being loaded and the
115 /// operand being loaded too. The operand needs to be specified to aid the
116 /// lowering of this type to dialects such as NVVM where each workitem may
117 /// hold different amount of elements depending on the elementType of the
118 /// matrix. For e.g., Each workitem holds 4 vector<2xf16>s for f16 data type
119 /// and 8 f32s for f32 data type of MMAMatrix. Some other instances of usage
120 /// are:-
121 ///
122 ///   %3 = gpu.subgroup_mma_compute %0, %1, %2 :
123 ///   !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">
124 ///    -> !gpu.mma_matrix<16x16xf32, "COp">
125 ///
126 ///
127 ///   gpu.subgroup_mma_store_matrix %3, %arg22[%c0, %c0] {leadDimension = 16
128 ///           : index}: !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32>
129 // TODO: consider moving this to ODS.
130 class MMAMatrixType
131     : public Type::TypeBase<MMAMatrixType, Type, MMAMatrixStorageType> {
132 public:
133   using Base::Base;
134 
135   static constexpr StringLiteral name = "gpu.mma_matrix";
136 
137   /// Get MMAMatrixType and verify construction Invariants.
138   static MMAMatrixType get(ArrayRef<int64_t> shape, Type elementType,
139                            StringRef operand);
140 
141   /// Get MMAMatrixType at a particular location and verify construction
142   /// Invariants.
143   static MMAMatrixType getChecked(function_ref<InFlightDiagnostic()> emitError,
144                                   ArrayRef<int64_t> shape, Type elementType,
145                                   StringRef operand);
146 
147   /// Check if a type is valid a MMAMatrixType elementType.
148   static bool isValidElementType(Type elementType);
149 
150   /// Verify that shape and elementType are actually allowed for the
151   /// MMAMatrixType.
152   static LogicalResult
153   verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
154                    ArrayRef<int64_t> shape, Type elementType,
155                    StringRef operand);
156 
157   /// Get number of dims.
158   unsigned getNumDims() const;
159 
160   /// Get shape of the matrix.
161   ArrayRef<int64_t> getShape() const;
162 
163   /// Get elementType of a single element.
164   Type getElementType() const;
165 
166   /// The general form of operation this type supports is given by the equation
167   /// C += A*B. This function returns which operand in the given equation is
168   /// held by this type. String returned can be one of"AOp", "BOp" and "COp".
169   StringRef getOperand() const;
170 };
171 
/// Adds a `gpu.async.token` to the front of the argument list of `op`.
/// NOTE(review): `token` is presumably expected to be of AsyncTokenType and
/// `op` an operation that accepts async dependencies — confirm against the
/// definition.
void addAsyncDependency(Operation *op, Value token);
174 
/// Kinds of opaque handles used by the GPU dialect's sparse operations. Each
/// kind appears to correspond to one of the Sparse*HandleType classes.
enum class SparseHandleKind {
  SpMat,    ///< Sparse matrix handle (cf. SparseSpMatHandleType).
  DnTensor, ///< Dense tensor handle (cf. SparseDnTensorHandleType).
  SpGEMMOp  ///< SpGEMM operation handle (cf. SparseSpGEMMOpHandleType).
};
177 
178 class SparseDnTensorHandleType
179     : public Type::TypeBase<SparseDnTensorHandleType, Type, TypeStorage> {
180 public:
181   using Base = typename Type::TypeBase<SparseDnTensorHandleType, Type,
182                                        TypeStorage>::Base;
183   using Base::Base;
184 
185   static constexpr StringLiteral name = "gpu.sparse.dntensor_handle";
186 };
187 
188 class SparseSpMatHandleType
189     : public Type::TypeBase<SparseSpMatHandleType, Type, TypeStorage> {
190 public:
191   using Base =
192       typename Type::TypeBase<SparseSpMatHandleType, Type, TypeStorage>::Base;
193   using Base::Base;
194 
195   static constexpr StringLiteral name = "gpu.sparse.spmat_handle";
196 };
197 
198 class SparseSpGEMMOpHandleType
199     : public Type::TypeBase<SparseSpGEMMOpHandleType, Type, TypeStorage> {
200 public:
201   using Base = typename Type::TypeBase<SparseSpGEMMOpHandleType, Type,
202                                        TypeStorage>::Base;
203   using Base::Base;
204 
205   static constexpr StringLiteral name = "gpu.sparse.spgemmop_handle";
206 };
207 
208 } // namespace gpu
209 } // namespace mlir
210 
211 #include "mlir/Dialect/GPU/IR/GPUOpsEnums.h.inc"
212 
213 #include "mlir/Dialect/GPU/IR/GPUOpsDialect.h.inc"
214 
215 #include "mlir/Dialect/GPU/IR/GPUOpInterfaces.h.inc"
216 
217 #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
218 
219 #define GET_ATTRDEF_CLASSES
220 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.h.inc"
221 
222 #define GET_OP_CLASSES
223 #include "mlir/Dialect/GPU/IR/GPUOps.h.inc"
224 
225 #endif // MLIR_DIALECT_GPU_IR_GPUDIALECT_H
226