xref: /llvm-project/mlir/lib/Bindings/Python/DialectLLVM.cpp (revision bd8fcf75df11406527de423daa63e21c3ec8609b)
1 //===- DialectLLVM.cpp - Pybind module for LLVM dialect API support -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir-c/Diagnostics.h"
#include "mlir-c/Dialect/LLVM.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include <cassert>
#include <optional>
#include <string>
#include <utility>
#include <vector>
15 
16 namespace py = pybind11;
17 using namespace llvm;
18 using namespace mlir;
19 using namespace mlir::python;
20 using namespace mlir::python::adaptors;
21 
22 /// RAII scope intercepting all diagnostics into a string. The message must be
23 /// checked before this goes out of scope.
24 class CollectDiagnosticsToStringScope {
25 public:
26   explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) {
27     handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage,
28                                                    /*deleteUserData=*/nullptr);
29   }
30   ~CollectDiagnosticsToStringScope() {
31     assert(errorMessage.empty() && "unchecked error message");
32     mlirContextDetachDiagnosticHandler(context, handlerID);
33   }
34 
35   [[nodiscard]] std::string takeMessage() { return std::move(errorMessage); }
36 
37 private:
38   static MlirLogicalResult handler(MlirDiagnostic diag, void *data) {
39     auto printer = +[](MlirStringRef message, void *data) {
40       *static_cast<std::string *>(data) +=
41           StringRef(message.data, message.length);
42     };
43     mlirDiagnosticPrint(diag, printer, data);
44     return mlirLogicalResultSuccess();
45   }
46 
47   MlirContext context;
48   MlirDiagnosticHandlerID handlerID;
49   std::string errorMessage = "";
50 };
51 
52 void populateDialectLLVMSubmodule(const pybind11::module &m) {
53   auto llvmStructType =
54       mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType);
55 
56   llvmStructType.def_classmethod(
57       "get_literal",
58       [](py::object cls, const std::vector<MlirType> &elements, bool packed,
59          MlirLocation loc) {
60         CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));
61 
62         MlirType type = mlirLLVMStructTypeLiteralGetChecked(
63             loc, elements.size(), elements.data(), packed);
64         if (mlirTypeIsNull(type)) {
65           throw py::value_error(scope.takeMessage());
66         }
67         return cls(type);
68       },
69       py::arg("cls"), py::arg("elements"), py::kw_only(),
70       py::arg("packed") = false, py::arg("loc") = py::none());
71 
72   llvmStructType.def_classmethod(
73       "get_identified",
74       [](py::object cls, const std::string &name, MlirContext context) {
75         return cls(mlirLLVMStructTypeIdentifiedGet(
76             context, mlirStringRefCreate(name.data(), name.size())));
77       },
78       py::arg("cls"), py::arg("name"), py::kw_only(),
79       py::arg("context") = py::none());
80 
81   llvmStructType.def_classmethod(
82       "get_opaque",
83       [](py::object cls, const std::string &name, MlirContext context) {
84         return cls(mlirLLVMStructTypeOpaqueGet(
85             context, mlirStringRefCreate(name.data(), name.size())));
86       },
87       py::arg("cls"), py::arg("name"), py::arg("context") = py::none());
88 
89   llvmStructType.def(
90       "set_body",
91       [](MlirType self, const std::vector<MlirType> &elements, bool packed) {
92         MlirLogicalResult result = mlirLLVMStructTypeSetBody(
93             self, elements.size(), elements.data(), packed);
94         if (!mlirLogicalResultIsSuccess(result)) {
95           throw py::value_error(
96               "Struct body already set to different content.");
97         }
98       },
99       py::arg("elements"), py::kw_only(), py::arg("packed") = false);
100 
101   llvmStructType.def_classmethod(
102       "new_identified",
103       [](py::object cls, const std::string &name,
104          const std::vector<MlirType> &elements, bool packed, MlirContext ctx) {
105         return cls(mlirLLVMStructTypeIdentifiedNewGet(
106             ctx, mlirStringRefCreate(name.data(), name.length()),
107             elements.size(), elements.data(), packed));
108       },
109       py::arg("cls"), py::arg("name"), py::arg("elements"), py::kw_only(),
110       py::arg("packed") = false, py::arg("context") = py::none());
111 
112   llvmStructType.def_property_readonly(
113       "name", [](MlirType type) -> std::optional<std::string> {
114         if (mlirLLVMStructTypeIsLiteral(type))
115           return std::nullopt;
116 
117         MlirStringRef stringRef = mlirLLVMStructTypeGetIdentifier(type);
118         return StringRef(stringRef.data, stringRef.length).str();
119       });
120 
121   llvmStructType.def_property_readonly("body", [](MlirType type) -> py::object {
122     // Don't crash in absence of a body.
123     if (mlirLLVMStructTypeIsOpaque(type))
124       return py::none();
125 
126     py::list body;
127     for (intptr_t i = 0, e = mlirLLVMStructTypeGetNumElementTypes(type); i < e;
128          ++i) {
129       body.append(mlirLLVMStructTypeGetElementType(type, i));
130     }
131     return body;
132   });
133 
134   llvmStructType.def_property_readonly(
135       "packed", [](MlirType type) { return mlirLLVMStructTypeIsPacked(type); });
136 
137   llvmStructType.def_property_readonly(
138       "opaque", [](MlirType type) { return mlirLLVMStructTypeIsOpaque(type); });
139 }
140 
141 PYBIND11_MODULE(_mlirDialectsLLVM, m) {
142   m.doc() = "MLIR LLVM Dialect";
143 
144   populateDialectLLVMSubmodule(m);
145 }
146