xref: /llvm-project/mlir/lib/Bindings/Python/DialectLLVM.cpp (revision 5cd427477218d8bdb659c6c53a7758f741c3990a)
1 //===- DialectLLVM.cpp - Pybind module for LLVM dialect API support -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <string>
10 
11 #include "mlir-c/Dialect/LLVM.h"
12 #include "mlir-c/IR.h"
13 #include "mlir-c/Support.h"
14 #include "mlir/Bindings/Python/Diagnostics.h"
15 #include "mlir/Bindings/Python/NanobindAdaptors.h"
16 #include "mlir/Bindings/Python/Nanobind.h"
17 
18 namespace nb = nanobind;
19 
20 using namespace nanobind::literals;
21 
22 using namespace llvm;
23 using namespace mlir;
24 using namespace mlir::python;
25 using namespace mlir::python::nanobind_adaptors;
26 
27 void populateDialectLLVMSubmodule(const nanobind::module_ &m) {
28 
29   //===--------------------------------------------------------------------===//
30   // StructType
31   //===--------------------------------------------------------------------===//
32 
33   auto llvmStructType =
34       mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType);
35 
36   llvmStructType.def_classmethod(
37       "get_literal",
38       [](nb::object cls, const std::vector<MlirType> &elements, bool packed,
39          MlirLocation loc) {
40         CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));
41 
42         MlirType type = mlirLLVMStructTypeLiteralGetChecked(
43             loc, elements.size(), elements.data(), packed);
44         if (mlirTypeIsNull(type)) {
45           throw nb::value_error(scope.takeMessage().c_str());
46         }
47         return cls(type);
48       },
49       "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
50       "loc"_a.none() = nb::none());
51 
52   llvmStructType.def_classmethod(
53       "get_identified",
54       [](nb::object cls, const std::string &name, MlirContext context) {
55         return cls(mlirLLVMStructTypeIdentifiedGet(
56             context, mlirStringRefCreate(name.data(), name.size())));
57       },
58       "cls"_a, "name"_a, nb::kw_only(), "context"_a.none() = nb::none());
59 
60   llvmStructType.def_classmethod(
61       "get_opaque",
62       [](nb::object cls, const std::string &name, MlirContext context) {
63         return cls(mlirLLVMStructTypeOpaqueGet(
64             context, mlirStringRefCreate(name.data(), name.size())));
65       },
66       "cls"_a, "name"_a, "context"_a.none() = nb::none());
67 
68   llvmStructType.def(
69       "set_body",
70       [](MlirType self, const std::vector<MlirType> &elements, bool packed) {
71         MlirLogicalResult result = mlirLLVMStructTypeSetBody(
72             self, elements.size(), elements.data(), packed);
73         if (!mlirLogicalResultIsSuccess(result)) {
74           throw nb::value_error(
75               "Struct body already set to different content.");
76         }
77       },
78       "elements"_a, nb::kw_only(), "packed"_a = false);
79 
80   llvmStructType.def_classmethod(
81       "new_identified",
82       [](nb::object cls, const std::string &name,
83          const std::vector<MlirType> &elements, bool packed, MlirContext ctx) {
84         return cls(mlirLLVMStructTypeIdentifiedNewGet(
85             ctx, mlirStringRefCreate(name.data(), name.length()),
86             elements.size(), elements.data(), packed));
87       },
88       "cls"_a, "name"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
89       "context"_a.none() = nb::none());
90 
91   llvmStructType.def_property_readonly(
92       "name", [](MlirType type) -> std::optional<std::string> {
93         if (mlirLLVMStructTypeIsLiteral(type))
94           return std::nullopt;
95 
96         MlirStringRef stringRef = mlirLLVMStructTypeGetIdentifier(type);
97         return StringRef(stringRef.data, stringRef.length).str();
98       });
99 
100   llvmStructType.def_property_readonly("body", [](MlirType type) -> nb::object {
101     // Don't crash in absence of a body.
102     if (mlirLLVMStructTypeIsOpaque(type))
103       return nb::none();
104 
105     nb::list body;
106     for (intptr_t i = 0, e = mlirLLVMStructTypeGetNumElementTypes(type); i < e;
107          ++i) {
108       body.append(mlirLLVMStructTypeGetElementType(type, i));
109     }
110     return body;
111   });
112 
113   llvmStructType.def_property_readonly(
114       "packed", [](MlirType type) { return mlirLLVMStructTypeIsPacked(type); });
115 
116   llvmStructType.def_property_readonly(
117       "opaque", [](MlirType type) { return mlirLLVMStructTypeIsOpaque(type); });
118 
119   //===--------------------------------------------------------------------===//
120   // PointerType
121   //===--------------------------------------------------------------------===//
122 
123   mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType)
124       .def_classmethod(
125           "get",
126           [](nb::object cls, std::optional<unsigned> addressSpace,
127              MlirContext context) {
128             CollectDiagnosticsToStringScope scope(context);
129             MlirType type = mlirLLVMPointerTypeGet(
130                 context, addressSpace.has_value() ? *addressSpace : 0);
131             if (mlirTypeIsNull(type)) {
132               throw nb::value_error(scope.takeMessage().c_str());
133             }
134             return cls(type);
135           },
136           "cls"_a, "address_space"_a.none() = nb::none(), nb::kw_only(),
137           "context"_a.none() = nb::none())
138       .def_property_readonly("address_space", [](MlirType type) {
139         return mlirLLVMPointerTypeGetAddressSpace(type);
140       });
141 }
142 
143 NB_MODULE(_mlirDialectsLLVM, m) {
144   m.doc() = "MLIR LLVM Dialect";
145 
146   populateDialectLLVMSubmodule(m);
147 }
148