1 //===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 #include "mlir-c/BuiltinAttributes.h" 11 #include "mlir-c/Interfaces.h" 12 13 namespace py = pybind11; 14 15 namespace mlir { 16 namespace python { 17 18 constexpr static const char *constructorDoc = 19 R"(Creates an interface from a given operation/opview object or from a 20 subclass of OpView. Raises ValueError if the operation does not implement the 21 interface.)"; 22 23 constexpr static const char *operationDoc = 24 R"(Returns an Operation for which the interface was constructed.)"; 25 26 constexpr static const char *opviewDoc = 27 R"(Returns an OpView subclass _instance_ for which the interface was 28 constructed)"; 29 30 constexpr static const char *inferReturnTypesDoc = 31 R"(Given the arguments required to build an operation, attempts to infer 32 its return types. Raises ValueError on failure.)"; 33 34 /// CRTP base class for Python classes representing MLIR Op interfaces. 35 /// Interface hierarchies are flat so no base class is expected here. The 36 /// derived class is expected to define the following static fields: 37 /// - `const char *pyClassName` - the name of the Python class to create; 38 /// - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID 39 /// of the interface. 40 /// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind 41 /// interface-specific methods. 42 /// 43 /// An interface class may be constructed from either an Operation/OpView object 44 /// or from a subclass of OpView. In the latter case, only the static interface 45 /// methods are available, similarly to calling ConcereteOp::staticMethod on the 46 /// C++ side. Implementations of concrete interfaces can use the `isStatic` 47 /// method to check whether the interface object was constructed from a class or 48 /// an operation/opview instance. The `getOpName` always succeeds and returns a 49 /// canonical name of the operation suitable for lookups. 50 template <typename ConcreteIface> 51 class PyConcreteOpInterface { 52 protected: 53 using ClassTy = py::class_<ConcreteIface>; 54 using GetTypeIDFunctionTy = MlirTypeID (*)(); 55 56 public: 57 /// Constructs an interface instance from an object that is either an 58 /// operation or a subclass of OpView. In the latter case, only the static 59 /// methods of the interface are accessible to the caller. 60 PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context) 61 : obj(object) { 62 try { 63 operation = &py::cast<PyOperation &>(obj); 64 } catch (py::cast_error &err) { 65 // Do nothing. 66 } 67 68 try { 69 operation = &py::cast<PyOpView &>(obj).getOperation(); 70 } catch (py::cast_error &err) { 71 // Do nothing. 72 } 73 74 if (operation != nullptr) { 75 if (!mlirOperationImplementsInterface(*operation, 76 ConcreteIface::getInterfaceID())) { 77 std::string msg = "the operation does not implement "; 78 throw py::value_error(msg + ConcreteIface::pyClassName); 79 } 80 81 MlirIdentifier identifier = mlirOperationGetName(*operation); 82 MlirStringRef stringRef = mlirIdentifierStr(identifier); 83 opName = std::string(stringRef.data, stringRef.length); 84 } else { 85 try { 86 opName = obj.attr("OPERATION_NAME").template cast<std::string>(); 87 } catch (py::cast_error &err) { 88 throw py::type_error( 89 "Op interface does not refer to an operation or OpView class"); 90 } 91 92 if (!mlirOperationImplementsInterfaceStatic( 93 mlirStringRefCreate(opName.data(), opName.length()), 94 context.resolve().get(), ConcreteIface::getInterfaceID())) { 95 std::string msg = "the operation does not implement "; 96 throw py::value_error(msg + ConcreteIface::pyClassName); 97 } 98 } 99 } 100 101 /// Creates the Python bindings for this class in the given module. 102 static void bind(py::module &m) { 103 py::class_<ConcreteIface> cls(m, "InferTypeOpInterface", 104 py::module_local()); 105 cls.def(py::init<py::object, DefaultingPyMlirContext>(), py::arg("object"), 106 py::arg("context") = py::none(), constructorDoc) 107 .def_property_readonly("operation", 108 &PyConcreteOpInterface::getOperationObject, 109 operationDoc) 110 .def_property_readonly("opview", &PyConcreteOpInterface::getOpView, 111 opviewDoc); 112 ConcreteIface::bindDerived(cls); 113 } 114 115 /// Hook for derived classes to add class-specific bindings. 116 static void bindDerived(ClassTy &cls) {} 117 118 /// Returns `true` if this object was constructed from a subclass of OpView 119 /// rather than from an operation instance. 120 bool isStatic() { return operation == nullptr; } 121 122 /// Returns the operation instance from which this object was constructed. 123 /// Throws a type error if this object was constructed from a subclass of 124 /// OpView. 125 py::object getOperationObject() { 126 if (operation == nullptr) { 127 throw py::type_error("Cannot get an operation from a static interface"); 128 } 129 130 return operation->getRef().releaseObject(); 131 } 132 133 /// Returns the opview of the operation instance from which this object was 134 /// constructed. Throws a type error if this object was constructed form a 135 /// subclass of OpView. 136 py::object getOpView() { 137 if (operation == nullptr) { 138 throw py::type_error("Cannot get an opview from a static interface"); 139 } 140 141 return operation->createOpView(); 142 } 143 144 /// Returns the canonical name of the operation this interface is constructed 145 /// from. 146 const std::string &getOpName() { return opName; } 147 148 private: 149 PyOperation *operation = nullptr; 150 std::string opName; 151 py::object obj; 152 }; 153 154 /// Python wrapper for InterTypeOpInterface. This interface has only static 155 /// methods. 156 class PyInferTypeOpInterface 157 : public PyConcreteOpInterface<PyInferTypeOpInterface> { 158 public: 159 using PyConcreteOpInterface<PyInferTypeOpInterface>::PyConcreteOpInterface; 160 161 constexpr static const char *pyClassName = "InferTypeOpInterface"; 162 constexpr static GetTypeIDFunctionTy getInterfaceID = 163 &mlirInferTypeOpInterfaceTypeID; 164 165 /// C-style user-data structure for type appending callback. 166 struct AppendResultsCallbackData { 167 std::vector<PyType> &inferredTypes; 168 PyMlirContext &pyMlirContext; 169 }; 170 171 /// Appends the types provided as the two first arguments to the user-data 172 /// structure (expects AppendResultsCallbackData). 173 static void appendResultsCallback(intptr_t nTypes, MlirType *types, 174 void *userData) { 175 auto *data = static_cast<AppendResultsCallbackData *>(userData); 176 data->inferredTypes.reserve(data->inferredTypes.size() + nTypes); 177 for (intptr_t i = 0; i < nTypes; ++i) { 178 data->inferredTypes.push_back( 179 PyType(data->pyMlirContext.getRef(), types[i])); 180 } 181 } 182 183 /// Given the arguments required to build an operation, attempts to infer its 184 /// return types. Throws value_error on faliure. 185 std::vector<PyType> 186 inferReturnTypes(llvm::Optional<std::vector<PyValue>> operands, 187 llvm::Optional<PyAttribute> attributes, 188 llvm::Optional<std::vector<PyRegion>> regions, 189 DefaultingPyMlirContext context, 190 DefaultingPyLocation location) { 191 llvm::SmallVector<MlirValue> mlirOperands; 192 llvm::SmallVector<MlirRegion> mlirRegions; 193 194 if (operands) { 195 mlirOperands.reserve(operands->size()); 196 for (PyValue &value : *operands) { 197 mlirOperands.push_back(value); 198 } 199 } 200 201 if (regions) { 202 mlirRegions.reserve(regions->size()); 203 for (PyRegion ®ion : *regions) { 204 mlirRegions.push_back(region); 205 } 206 } 207 208 std::vector<PyType> inferredTypes; 209 PyMlirContext &pyContext = context.resolve(); 210 AppendResultsCallbackData data{inferredTypes, pyContext}; 211 MlirStringRef opNameRef = 212 mlirStringRefCreate(getOpName().data(), getOpName().length()); 213 MlirAttribute attributeDict = 214 attributes ? attributes->get() : mlirAttributeGetNull(); 215 216 MlirLogicalResult result = mlirInferTypeOpInterfaceInferReturnTypes( 217 opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(), 218 mlirOperands.data(), attributeDict, mlirRegions.size(), 219 mlirRegions.data(), &appendResultsCallback, &data); 220 221 if (mlirLogicalResultIsFailure(result)) { 222 throw py::value_error("Failed to infer result types"); 223 } 224 225 return inferredTypes; 226 } 227 228 static void bindDerived(ClassTy &cls) { 229 cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes, 230 py::arg("operands") = py::none(), 231 py::arg("attributes") = py::none(), py::arg("regions") = py::none(), 232 py::arg("context") = py::none(), py::arg("loc") = py::none(), 233 inferReturnTypesDoc); 234 } 235 }; 236 237 void populateIRInterfaces(py::module &m) { PyInferTypeOpInterface::bind(m); } 238 239 } // namespace python 240 } // namespace mlir 241