xref: /llvm-project/mlir/lib/Bindings/Python/DialectLLVM.cpp (revision afe75b4d5fcebd6fdd292ca1797de1b35cb687b0)
1 //===- DialectLLVM.cpp - Pybind module for LLVM dialect API support -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include <optional>
#include <string>
#include <vector>

#include "mlir-c/Dialect/LLVM.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/Diagnostics.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
16 
17 namespace py = pybind11;
18 using namespace llvm;
19 using namespace mlir;
20 using namespace mlir::python;
21 using namespace mlir::python::adaptors;
22 
23 void populateDialectLLVMSubmodule(const pybind11::module &m) {
24 
25   //===--------------------------------------------------------------------===//
26   // StructType
27   //===--------------------------------------------------------------------===//
28 
29   auto llvmStructType =
30       mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType);
31 
32   llvmStructType.def_classmethod(
33       "get_literal",
34       [](py::object cls, const std::vector<MlirType> &elements, bool packed,
35          MlirLocation loc) {
36         CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));
37 
38         MlirType type = mlirLLVMStructTypeLiteralGetChecked(
39             loc, elements.size(), elements.data(), packed);
40         if (mlirTypeIsNull(type)) {
41           throw py::value_error(scope.takeMessage());
42         }
43         return cls(type);
44       },
45       "cls"_a, "elements"_a, py::kw_only(), "packed"_a = false,
46       "loc"_a = py::none());
47 
48   llvmStructType.def_classmethod(
49       "get_identified",
50       [](py::object cls, const std::string &name, MlirContext context) {
51         return cls(mlirLLVMStructTypeIdentifiedGet(
52             context, mlirStringRefCreate(name.data(), name.size())));
53       },
54       "cls"_a, "name"_a, py::kw_only(), "context"_a = py::none());
55 
56   llvmStructType.def_classmethod(
57       "get_opaque",
58       [](py::object cls, const std::string &name, MlirContext context) {
59         return cls(mlirLLVMStructTypeOpaqueGet(
60             context, mlirStringRefCreate(name.data(), name.size())));
61       },
62       "cls"_a, "name"_a, "context"_a = py::none());
63 
64   llvmStructType.def(
65       "set_body",
66       [](MlirType self, const std::vector<MlirType> &elements, bool packed) {
67         MlirLogicalResult result = mlirLLVMStructTypeSetBody(
68             self, elements.size(), elements.data(), packed);
69         if (!mlirLogicalResultIsSuccess(result)) {
70           throw py::value_error(
71               "Struct body already set to different content.");
72         }
73       },
74       "elements"_a, py::kw_only(), "packed"_a = false);
75 
76   llvmStructType.def_classmethod(
77       "new_identified",
78       [](py::object cls, const std::string &name,
79          const std::vector<MlirType> &elements, bool packed, MlirContext ctx) {
80         return cls(mlirLLVMStructTypeIdentifiedNewGet(
81             ctx, mlirStringRefCreate(name.data(), name.length()),
82             elements.size(), elements.data(), packed));
83       },
84       "cls"_a, "name"_a, "elements"_a, py::kw_only(), "packed"_a = false,
85       "context"_a = py::none());
86 
87   llvmStructType.def_property_readonly(
88       "name", [](MlirType type) -> std::optional<std::string> {
89         if (mlirLLVMStructTypeIsLiteral(type))
90           return std::nullopt;
91 
92         MlirStringRef stringRef = mlirLLVMStructTypeGetIdentifier(type);
93         return StringRef(stringRef.data, stringRef.length).str();
94       });
95 
96   llvmStructType.def_property_readonly("body", [](MlirType type) -> py::object {
97     // Don't crash in absence of a body.
98     if (mlirLLVMStructTypeIsOpaque(type))
99       return py::none();
100 
101     py::list body;
102     for (intptr_t i = 0, e = mlirLLVMStructTypeGetNumElementTypes(type); i < e;
103          ++i) {
104       body.append(mlirLLVMStructTypeGetElementType(type, i));
105     }
106     return body;
107   });
108 
109   llvmStructType.def_property_readonly(
110       "packed", [](MlirType type) { return mlirLLVMStructTypeIsPacked(type); });
111 
112   llvmStructType.def_property_readonly(
113       "opaque", [](MlirType type) { return mlirLLVMStructTypeIsOpaque(type); });
114 
115   //===--------------------------------------------------------------------===//
116   // PointerType
117   //===--------------------------------------------------------------------===//
118 
119   mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType)
120       .def_classmethod(
121           "get",
122           [](py::object cls, std::optional<unsigned> addressSpace,
123              MlirContext context) {
124             CollectDiagnosticsToStringScope scope(context);
125             MlirType type = mlirLLVMPointerTypeGet(
126                 context, addressSpace.has_value() ? *addressSpace : 0);
127             if (mlirTypeIsNull(type)) {
128               throw py::value_error(scope.takeMessage());
129             }
130             return cls(type);
131           },
132           "cls"_a, "address_space"_a = py::none(), py::kw_only(),
133           "context"_a = py::none())
134       .def_property_readonly("address_space", [](MlirType type) {
135         return mlirLLVMPointerTypeGetAddressSpace(type);
136       });
137 }
138 
139 PYBIND11_MODULE(_mlirDialectsLLVM, m) {
140   m.doc() = "MLIR LLVM Dialect";
141 
142   populateDialectLLVMSubmodule(m);
143 }
144