1 //===- GPUDialect.h - MLIR Dialect for GPU Kernels --------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the GPU kernel-related operations and puts them in the 10 // corresponding dialect. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef MLIR_DIALECT_GPU_IR_GPUDIALECT_H 15 #define MLIR_DIALECT_GPU_IR_GPUDIALECT_H 16 17 #include "mlir/Bytecode/BytecodeOpInterface.h" 18 #include "mlir/Dialect/DLTI/Traits.h" 19 #include "mlir/Dialect/GPU/IR/CompilationInterfaces.h" 20 #include "mlir/IR/Builders.h" 21 #include "mlir/IR/BuiltinTypes.h" 22 #include "mlir/IR/Dialect.h" 23 #include "mlir/IR/OpDefinition.h" 24 #include "mlir/IR/OpImplementation.h" 25 #include "mlir/IR/RegionKindInterface.h" 26 #include "mlir/IR/SymbolTable.h" 27 #include "mlir/Interfaces/ControlFlowInterfaces.h" 28 #include "mlir/Interfaces/FunctionInterfaces.h" 29 #include "mlir/Interfaces/InferIntRangeInterface.h" 30 #include "mlir/Interfaces/InferTypeOpInterface.h" 31 #include "mlir/Interfaces/SideEffectInterfaces.h" 32 #include "llvm/ADT/STLExtras.h" 33 34 namespace mlir { 35 namespace gpu { 36 37 /// Utility class for the GPU dialect to represent triples of `Value`s 38 /// accessible through `.x`, `.y`, and `.z` similarly to CUDA notation. 39 struct KernelDim3 { 40 Value x; 41 Value y; 42 Value z; 43 }; 44 45 class AsyncTokenType 46 : public Type::TypeBase<AsyncTokenType, Type, TypeStorage> { 47 public: 48 // Used for generic hooks in TypeBase. 49 using Base::Base; 50 51 static constexpr StringLiteral name = "gpu.async_token"; 52 }; 53 54 /// MMAMatrixType storage and uniquing. Array is uniqued based on its shape 55 /// and type. 56 struct MMAMatrixStorageType : public TypeStorage { 57 MMAMatrixStorageType(unsigned numDims, const int64_t *dimShapes, 58 Type elementType, StringRef operand) 59 : dimShapes(dimShapes), numDims(numDims), elementType(elementType), 60 operand(operand) {} 61 62 /// The hash key for uniquing. 63 using KeyTy = std::tuple<ArrayRef<int64_t>, Type, StringRef>; 64 bool operator==(const KeyTy &key) const { 65 return key == KeyTy(getShape(), elementType, operand); 66 } 67 68 /// Construction. 69 static MMAMatrixStorageType *construct(TypeStorageAllocator &allocator, 70 const KeyTy &key) { 71 ArrayRef<int64_t> shape = allocator.copyInto(std::get<0>(key)); 72 StringRef operand = allocator.copyInto(std::get<2>(key)); 73 74 return new (allocator.allocate<MMAMatrixStorageType>()) 75 MMAMatrixStorageType(shape.size(), shape.data(), std::get<1>(key), 76 operand); 77 } 78 79 ArrayRef<int64_t> getShape() const { 80 return ArrayRef<int64_t>(dimShapes, numDims); 81 } 82 83 StringRef getOperand() const { return operand; } 84 85 /// Reference to the shape of the MMA matrix. 86 const int64_t *dimShapes; 87 88 /// Number of dimensions in the MMA matrix. 89 unsigned numDims; 90 91 /// Element type of elements held in the MMA matrix. 92 Type elementType; 93 94 /// MMA operand that this MMAMatrix holds. The general form of operation this 95 /// type supports is given by the equation C += A*B. This field specifies 96 /// which operand in the given equation is held by this type. The valid values 97 /// are "AOp", "BOp" and "COp". 98 StringRef operand; 99 }; 100 101 /// MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply 102 /// accumulate operations. MMAMatrices are taken as direct operands by these 103 /// operations and are also produced as results. These matrices are meant to 104 /// reside in the registers. A limited number of pointwise operations can be 105 /// performed on these matrices, i.e., operations which operate uniformly on 106 /// all the elements in the matrix and do not change the order of matrix 107 /// elements. The above conditions exist because the layout of matrix elements 108 /// inside the matrix is opaque i.e., the elements may be present in the 109 /// matrix in any order. The general usage of this type is shown as follows:- 110 /// 111 /// %0 = gpu.subgroup_mma_load_matrix %arg0[%c0, %c0] {leadDimension = 16 : 112 /// index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp"> 113 /// 114 /// The MMAMatrixType describes the shape of the matrix being loaded and the 115 /// operand being loaded too. The operand needs to be specified to aid the 116 /// lowering of this type to dialects such as NVVM where each workitem may 117 /// hold different amount of elements depending on the elementType of the 118 /// matrix. For e.g., Each workitem holds 4 vector<2xf16>s for f16 data type 119 /// and 8 f32s for f32 data type of MMAMatrix. Some other instances of usage 120 /// are:- 121 /// 122 /// %3 = gpu.subgroup_mma_compute %0, %1, %2 : 123 /// !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> 124 /// -> !gpu.mma_matrix<16x16xf32, "COp"> 125 /// 126 /// 127 /// gpu.subgroup_mma_store_matrix %3, %arg22[%c0, %c0] {leadDimension = 16 128 /// : index}: !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32> 129 // TODO: consider moving this to ODS. 130 class MMAMatrixType 131 : public Type::TypeBase<MMAMatrixType, Type, MMAMatrixStorageType> { 132 public: 133 using Base::Base; 134 135 static constexpr StringLiteral name = "gpu.mma_matrix"; 136 137 /// Get MMAMatrixType and verify construction Invariants. 138 static MMAMatrixType get(ArrayRef<int64_t> shape, Type elementType, 139 StringRef operand); 140 141 /// Get MMAMatrixType at a particular location and verify construction 142 /// Invariants. 143 static MMAMatrixType getChecked(function_ref<InFlightDiagnostic()> emitError, 144 ArrayRef<int64_t> shape, Type elementType, 145 StringRef operand); 146 147 /// Check if a type is valid a MMAMatrixType elementType. 148 static bool isValidElementType(Type elementType); 149 150 /// Verify that shape and elementType are actually allowed for the 151 /// MMAMatrixType. 152 static LogicalResult 153 verifyInvariants(function_ref<InFlightDiagnostic()> emitError, 154 ArrayRef<int64_t> shape, Type elementType, 155 StringRef operand); 156 157 /// Get number of dims. 158 unsigned getNumDims() const; 159 160 /// Get shape of the matrix. 161 ArrayRef<int64_t> getShape() const; 162 163 /// Get elementType of a single element. 164 Type getElementType() const; 165 166 /// The general form of operation this type supports is given by the equation 167 /// C += A*B. This function returns which operand in the given equation is 168 /// held by this type. String returned can be one of"AOp", "BOp" and "COp". 169 StringRef getOperand() const; 170 }; 171 172 // Adds a `gpu.async.token` to the front of the argument list. 173 void addAsyncDependency(Operation *op, Value token); 174 175 // Handle types for sparse. 176 enum class SparseHandleKind { SpMat, DnTensor, SpGEMMOp }; 177 178 class SparseDnTensorHandleType 179 : public Type::TypeBase<SparseDnTensorHandleType, Type, TypeStorage> { 180 public: 181 using Base = typename Type::TypeBase<SparseDnTensorHandleType, Type, 182 TypeStorage>::Base; 183 using Base::Base; 184 185 static constexpr StringLiteral name = "gpu.sparse.dntensor_handle"; 186 }; 187 188 class SparseSpMatHandleType 189 : public Type::TypeBase<SparseSpMatHandleType, Type, TypeStorage> { 190 public: 191 using Base = 192 typename Type::TypeBase<SparseSpMatHandleType, Type, TypeStorage>::Base; 193 using Base::Base; 194 195 static constexpr StringLiteral name = "gpu.sparse.spmat_handle"; 196 }; 197 198 class SparseSpGEMMOpHandleType 199 : public Type::TypeBase<SparseSpGEMMOpHandleType, Type, TypeStorage> { 200 public: 201 using Base = typename Type::TypeBase<SparseSpGEMMOpHandleType, Type, 202 TypeStorage>::Base; 203 using Base::Base; 204 205 static constexpr StringLiteral name = "gpu.sparse.spgemmop_handle"; 206 }; 207 208 } // namespace gpu 209 } // namespace mlir 210 211 #include "mlir/Dialect/GPU/IR/GPUOpsEnums.h.inc" 212 213 #include "mlir/Dialect/GPU/IR/GPUOpsDialect.h.inc" 214 215 #include "mlir/Dialect/GPU/IR/GPUOpInterfaces.h.inc" 216 217 #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h" 218 219 #define GET_ATTRDEF_CLASSES 220 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.h.inc" 221 222 #define GET_OP_CLASSES 223 #include "mlir/Dialect/GPU/IR/GPUOps.h.inc" 224 225 #endif // MLIR_DIALECT_GPU_IR_GPUDIALECT_H 226