xref: /llvm-project/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h (revision 2e6cc79f816d942ab09d6a310cd925c1da148aa9)
1 //===- NVVMDialect.h - MLIR NVVM IR dialect ---------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the NVVM IR dialect in MLIR, containing NVVM operations and
10 // NVVM-specific extensions to the LLVM type system.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_DIALECT_LLVMIR_NVVMDIALECT_H_
15 #define MLIR_DIALECT_LLVMIR_NVVMDIALECT_H_
16 
17 #include "mlir/Bytecode/BytecodeOpInterface.h"
18 #include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/IR/Dialect.h"
21 #include "mlir/IR/OpDefinition.h"
22 #include "mlir/Interfaces/InferIntRangeInterface.h"
23 #include "mlir/Interfaces/SideEffectInterfaces.h"
24 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
25 #include "llvm/IR/IntrinsicsNVPTX.h"
26 
27 #include "mlir/Dialect/LLVMIR/NVVMOpsEnums.h.inc"
28 
29 namespace mlir {
30 namespace NVVM {
31 
32 // Shared memory has 128-bit alignment
33 constexpr int kSharedMemoryAlignmentBit = 128;
34 
35 /// NVVM memory space identifiers.
36 enum NVVMMemorySpace {
37   /// Global memory space identifier.
38   kGlobalMemorySpace = 1,
39   /// Shared memory space identifier.
40   kSharedMemorySpace = 3,
41   /// Constant memory space identifier.
42   kConstantMemorySpace = 4
43 };
44 
45 /// Return the element type and number of elements associated with a wmma matrix
46 /// of given chracteristics. This matches the logic in IntrinsicsNVVM.td
47 /// WMMA_REGS structure.
48 std::pair<mlir::Type, unsigned> inferMMAType(mlir::NVVM::MMATypes type,
49                                              mlir::NVVM::MMAFrag frag, int nRow,
50                                              int nCol,
51                                              mlir::MLIRContext *context);
52 } // namespace NVVM
53 } // namespace mlir
54 
55 ///// Ops /////
56 #define GET_ATTRDEF_CLASSES
57 #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.h.inc"
58 
59 #define GET_OP_CLASSES
60 #include "mlir/Dialect/LLVMIR/NVVMOps.h.inc"
61 
62 #include "mlir/Dialect/LLVMIR/NVVMOpsDialect.h.inc"
63 
64 #endif /* MLIR_DIALECT_LLVMIR_NVVMDIALECT_H_ */
65