1 //===- DialectLLVM.cpp - Pybind module for LLVM dialect API support -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/Diagnostics.h" 10 #include "mlir-c/Dialect/LLVM.h" 11 #include "mlir-c/IR.h" 12 #include "mlir-c/Support.h" 13 #include "mlir/Bindings/Python/PybindAdaptors.h" 14 #include <string> 15 16 namespace py = pybind11; 17 using namespace llvm; 18 using namespace mlir; 19 using namespace mlir::python; 20 using namespace mlir::python::adaptors; 21 22 /// RAII scope intercepting all diagnostics into a string. The message must be 23 /// checked before this goes out of scope. 24 class CollectDiagnosticsToStringScope { 25 public: 26 explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) { 27 handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage, 28 /*deleteUserData=*/nullptr); 29 } 30 ~CollectDiagnosticsToStringScope() { 31 assert(errorMessage.empty() && "unchecked error message"); 32 mlirContextDetachDiagnosticHandler(context, handlerID); 33 } 34 35 [[nodiscard]] std::string takeMessage() { return std::move(errorMessage); } 36 37 private: 38 static MlirLogicalResult handler(MlirDiagnostic diag, void *data) { 39 auto printer = +[](MlirStringRef message, void *data) { 40 *static_cast<std::string *>(data) += 41 StringRef(message.data, message.length); 42 }; 43 mlirDiagnosticPrint(diag, printer, data); 44 return mlirLogicalResultSuccess(); 45 } 46 47 MlirContext context; 48 MlirDiagnosticHandlerID handlerID; 49 std::string errorMessage = ""; 50 }; 51 52 void populateDialectLLVMSubmodule(const pybind11::module &m) { 53 auto llvmStructType = 54 mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType); 55 56 llvmStructType.def_classmethod( 57 "get_literal", 58 [](py::object cls, const std::vector<MlirType> &elements, bool packed, 59 MlirLocation loc) { 60 CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc)); 61 62 MlirType type = mlirLLVMStructTypeLiteralGetChecked( 63 loc, elements.size(), elements.data(), packed); 64 if (mlirTypeIsNull(type)) { 65 throw py::value_error(scope.takeMessage()); 66 } 67 return cls(type); 68 }, 69 py::arg("cls"), py::arg("elements"), py::kw_only(), 70 py::arg("packed") = false, py::arg("loc") = py::none()); 71 72 llvmStructType.def_classmethod( 73 "get_identified", 74 [](py::object cls, const std::string &name, MlirContext context) { 75 return cls(mlirLLVMStructTypeIdentifiedGet( 76 context, mlirStringRefCreate(name.data(), name.size()))); 77 }, 78 py::arg("cls"), py::arg("name"), py::kw_only(), 79 py::arg("context") = py::none()); 80 81 llvmStructType.def_classmethod( 82 "get_opaque", 83 [](py::object cls, const std::string &name, MlirContext context) { 84 return cls(mlirLLVMStructTypeOpaqueGet( 85 context, mlirStringRefCreate(name.data(), name.size()))); 86 }, 87 py::arg("cls"), py::arg("name"), py::arg("context") = py::none()); 88 89 llvmStructType.def( 90 "set_body", 91 [](MlirType self, const std::vector<MlirType> &elements, bool packed) { 92 MlirLogicalResult result = mlirLLVMStructTypeSetBody( 93 self, elements.size(), elements.data(), packed); 94 if (!mlirLogicalResultIsSuccess(result)) { 95 throw py::value_error( 96 "Struct body already set to different content."); 97 } 98 }, 99 py::arg("elements"), py::kw_only(), py::arg("packed") = false); 100 101 llvmStructType.def_classmethod( 102 "new_identified", 103 [](py::object cls, const std::string &name, 104 const std::vector<MlirType> &elements, bool packed, MlirContext ctx) { 105 return cls(mlirLLVMStructTypeIdentifiedNewGet( 106 ctx, mlirStringRefCreate(name.data(), name.length()), 107 elements.size(), elements.data(), packed)); 108 }, 109 py::arg("cls"), py::arg("name"), py::arg("elements"), py::kw_only(), 110 py::arg("packed") = false, py::arg("context") = py::none()); 111 112 llvmStructType.def_property_readonly( 113 "name", [](MlirType type) -> std::optional<std::string> { 114 if (mlirLLVMStructTypeIsLiteral(type)) 115 return std::nullopt; 116 117 MlirStringRef stringRef = mlirLLVMStructTypeGetIdentifier(type); 118 return StringRef(stringRef.data, stringRef.length).str(); 119 }); 120 121 llvmStructType.def_property_readonly("body", [](MlirType type) -> py::object { 122 // Don't crash in absence of a body. 123 if (mlirLLVMStructTypeIsOpaque(type)) 124 return py::none(); 125 126 py::list body; 127 for (intptr_t i = 0, e = mlirLLVMStructTypeGetNumElementTypes(type); i < e; 128 ++i) { 129 body.append(mlirLLVMStructTypeGetElementType(type, i)); 130 } 131 return body; 132 }); 133 134 llvmStructType.def_property_readonly( 135 "packed", [](MlirType type) { return mlirLLVMStructTypeIsPacked(type); }); 136 137 llvmStructType.def_property_readonly( 138 "opaque", [](MlirType type) { return mlirLLVMStructTypeIsOpaque(type); }); 139 } 140 141 PYBIND11_MODULE(_mlirDialectsLLVM, m) { 142 m.doc() = "MLIR LLVM Dialect"; 143 144 populateDialectLLVMSubmodule(m); 145 } 146