xref: /llvm-project/mlir/lib/Bindings/Python/DialectNVGPU.cpp (revision 5cd427477218d8bdb659c6c53a7758f741c3990a)
1 //===--- DialectNVGPU.cpp - Pybind module for NVGPU dialect API support ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/Dialect/NVGPU.h"
10 #include "mlir-c/IR.h"
11 #include "mlir/Bindings/Python/NanobindAdaptors.h"
12 #include "mlir/Bindings/Python/Nanobind.h"
13 
14 namespace nb = nanobind;
15 using namespace llvm;
16 using namespace mlir;
17 using namespace mlir::python;
18 using namespace mlir::python::nanobind_adaptors;
19 
20 static void populateDialectNVGPUSubmodule(const nb::module_ &m) {
21   auto nvgpuTensorMapDescriptorType = mlir_type_subclass(
22       m, "TensorMapDescriptorType", mlirTypeIsANVGPUTensorMapDescriptorType);
23 
24   nvgpuTensorMapDescriptorType.def_classmethod(
25       "get",
26       [](nb::object cls, MlirType tensorMemrefType, int swizzle, int l2promo,
27          int oobFill, int interleave, MlirContext ctx) {
28         return cls(mlirNVGPUTensorMapDescriptorTypeGet(
29             ctx, tensorMemrefType, swizzle, l2promo, oobFill, interleave));
30       },
31       "Gets an instance of TensorMapDescriptorType in the same context",
32       nb::arg("cls"), nb::arg("tensor_type"), nb::arg("swizzle"),
33       nb::arg("l2promo"), nb::arg("oob_fill"), nb::arg("interleave"),
34       nb::arg("ctx").none() = nb::none());
35 }
36 
37 NB_MODULE(_mlirDialectsNVGPU, m) {
38   m.doc() = "MLIR NVGPU dialect.";
39 
40   populateDialectNVGPUSubmodule(m);
41 }
42