1 //===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "IRModule.h" 12 #include "mlir-c/BuiltinAttributes.h" 13 #include "mlir-c/Interfaces.h" 14 15 namespace py = pybind11; 16 17 namespace mlir { 18 namespace python { 19 20 constexpr static const char *constructorDoc = 21 R"(Creates an interface from a given operation/opview object or from a 22 subclass of OpView. Raises ValueError if the operation does not implement the 23 interface.)"; 24 25 constexpr static const char *operationDoc = 26 R"(Returns an Operation for which the interface was constructed.)"; 27 28 constexpr static const char *opviewDoc = 29 R"(Returns an OpView subclass _instance_ for which the interface was 30 constructed)"; 31 32 constexpr static const char *inferReturnTypesDoc = 33 R"(Given the arguments required to build an operation, attempts to infer 34 its return types. Raises ValueError on failure.)"; 35 36 /// CRTP base class for Python classes representing MLIR Op interfaces. 37 /// Interface hierarchies are flat so no base class is expected here. The 38 /// derived class is expected to define the following static fields: 39 /// - `const char *pyClassName` - the name of the Python class to create; 40 /// - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID 41 /// of the interface. 42 /// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind 43 /// interface-specific methods. 44 /// 45 /// An interface class may be constructed from either an Operation/OpView object 46 /// or from a subclass of OpView. In the latter case, only the static interface 47 /// methods are available, similarly to calling ConcereteOp::staticMethod on the 48 /// C++ side. Implementations of concrete interfaces can use the `isStatic` 49 /// method to check whether the interface object was constructed from a class or 50 /// an operation/opview instance. The `getOpName` always succeeds and returns a 51 /// canonical name of the operation suitable for lookups. 52 template <typename ConcreteIface> 53 class PyConcreteOpInterface { 54 protected: 55 using ClassTy = py::class_<ConcreteIface>; 56 using GetTypeIDFunctionTy = MlirTypeID (*)(); 57 58 public: 59 /// Constructs an interface instance from an object that is either an 60 /// operation or a subclass of OpView. In the latter case, only the static 61 /// methods of the interface are accessible to the caller. 62 PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context) 63 : obj(std::move(object)) { 64 try { 65 operation = &py::cast<PyOperation &>(obj); 66 } catch (py::cast_error &) { 67 // Do nothing. 68 } 69 70 try { 71 operation = &py::cast<PyOpView &>(obj).getOperation(); 72 } catch (py::cast_error &) { 73 // Do nothing. 74 } 75 76 if (operation != nullptr) { 77 if (!mlirOperationImplementsInterface(*operation, 78 ConcreteIface::getInterfaceID())) { 79 std::string msg = "the operation does not implement "; 80 throw py::value_error(msg + ConcreteIface::pyClassName); 81 } 82 83 MlirIdentifier identifier = mlirOperationGetName(*operation); 84 MlirStringRef stringRef = mlirIdentifierStr(identifier); 85 opName = std::string(stringRef.data, stringRef.length); 86 } else { 87 try { 88 opName = obj.attr("OPERATION_NAME").template cast<std::string>(); 89 } catch (py::cast_error &) { 90 throw py::type_error( 91 "Op interface does not refer to an operation or OpView class"); 92 } 93 94 if (!mlirOperationImplementsInterfaceStatic( 95 mlirStringRefCreate(opName.data(), opName.length()), 96 context.resolve().get(), ConcreteIface::getInterfaceID())) { 97 std::string msg = "the operation does not implement "; 98 throw py::value_error(msg + ConcreteIface::pyClassName); 99 } 100 } 101 } 102 103 /// Creates the Python bindings for this class in the given module. 104 static void bind(py::module &m) { 105 py::class_<ConcreteIface> cls(m, "InferTypeOpInterface", 106 py::module_local()); 107 cls.def(py::init<py::object, DefaultingPyMlirContext>(), py::arg("object"), 108 py::arg("context") = py::none(), constructorDoc) 109 .def_property_readonly("operation", 110 &PyConcreteOpInterface::getOperationObject, 111 operationDoc) 112 .def_property_readonly("opview", &PyConcreteOpInterface::getOpView, 113 opviewDoc); 114 ConcreteIface::bindDerived(cls); 115 } 116 117 /// Hook for derived classes to add class-specific bindings. 118 static void bindDerived(ClassTy &cls) {} 119 120 /// Returns `true` if this object was constructed from a subclass of OpView 121 /// rather than from an operation instance. 122 bool isStatic() { return operation == nullptr; } 123 124 /// Returns the operation instance from which this object was constructed. 125 /// Throws a type error if this object was constructed from a subclass of 126 /// OpView. 127 py::object getOperationObject() { 128 if (operation == nullptr) { 129 throw py::type_error("Cannot get an operation from a static interface"); 130 } 131 132 return operation->getRef().releaseObject(); 133 } 134 135 /// Returns the opview of the operation instance from which this object was 136 /// constructed. Throws a type error if this object was constructed form a 137 /// subclass of OpView. 138 py::object getOpView() { 139 if (operation == nullptr) { 140 throw py::type_error("Cannot get an opview from a static interface"); 141 } 142 143 return operation->createOpView(); 144 } 145 146 /// Returns the canonical name of the operation this interface is constructed 147 /// from. 148 const std::string &getOpName() { return opName; } 149 150 private: 151 PyOperation *operation = nullptr; 152 std::string opName; 153 py::object obj; 154 }; 155 156 /// Python wrapper for InterTypeOpInterface. This interface has only static 157 /// methods. 158 class PyInferTypeOpInterface 159 : public PyConcreteOpInterface<PyInferTypeOpInterface> { 160 public: 161 using PyConcreteOpInterface<PyInferTypeOpInterface>::PyConcreteOpInterface; 162 163 constexpr static const char *pyClassName = "InferTypeOpInterface"; 164 constexpr static GetTypeIDFunctionTy getInterfaceID = 165 &mlirInferTypeOpInterfaceTypeID; 166 167 /// C-style user-data structure for type appending callback. 168 struct AppendResultsCallbackData { 169 std::vector<PyType> &inferredTypes; 170 PyMlirContext &pyMlirContext; 171 }; 172 173 /// Appends the types provided as the two first arguments to the user-data 174 /// structure (expects AppendResultsCallbackData). 175 static void appendResultsCallback(intptr_t nTypes, MlirType *types, 176 void *userData) { 177 auto *data = static_cast<AppendResultsCallbackData *>(userData); 178 data->inferredTypes.reserve(data->inferredTypes.size() + nTypes); 179 for (intptr_t i = 0; i < nTypes; ++i) { 180 data->inferredTypes.emplace_back(data->pyMlirContext.getRef(), types[i]); 181 } 182 } 183 184 /// Given the arguments required to build an operation, attempts to infer its 185 /// return types. Throws value_error on faliure. 186 std::vector<PyType> 187 inferReturnTypes(llvm::Optional<std::vector<PyValue>> operands, 188 llvm::Optional<PyAttribute> attributes, 189 llvm::Optional<std::vector<PyRegion>> regions, 190 DefaultingPyMlirContext context, 191 DefaultingPyLocation location) { 192 llvm::SmallVector<MlirValue> mlirOperands; 193 llvm::SmallVector<MlirRegion> mlirRegions; 194 195 if (operands) { 196 mlirOperands.reserve(operands->size()); 197 for (PyValue &value : *operands) { 198 mlirOperands.push_back(value); 199 } 200 } 201 202 if (regions) { 203 mlirRegions.reserve(regions->size()); 204 for (PyRegion ®ion : *regions) { 205 mlirRegions.push_back(region); 206 } 207 } 208 209 std::vector<PyType> inferredTypes; 210 PyMlirContext &pyContext = context.resolve(); 211 AppendResultsCallbackData data{inferredTypes, pyContext}; 212 MlirStringRef opNameRef = 213 mlirStringRefCreate(getOpName().data(), getOpName().length()); 214 MlirAttribute attributeDict = 215 attributes ? attributes->get() : mlirAttributeGetNull(); 216 217 MlirLogicalResult result = mlirInferTypeOpInterfaceInferReturnTypes( 218 opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(), 219 mlirOperands.data(), attributeDict, mlirRegions.size(), 220 mlirRegions.data(), &appendResultsCallback, &data); 221 222 if (mlirLogicalResultIsFailure(result)) { 223 throw py::value_error("Failed to infer result types"); 224 } 225 226 return inferredTypes; 227 } 228 229 static void bindDerived(ClassTy &cls) { 230 cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes, 231 py::arg("operands") = py::none(), 232 py::arg("attributes") = py::none(), py::arg("regions") = py::none(), 233 py::arg("context") = py::none(), py::arg("loc") = py::none(), 234 inferReturnTypesDoc); 235 } 236 }; 237 238 void populateIRInterfaces(py::module &m) { PyInferTypeOpInterface::bind(m); } 239 240 } // namespace python 241 } // namespace mlir 242