1 //===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 #include <optional> 11 12 #include "IRModule.h" 13 #include "mlir-c/BuiltinAttributes.h" 14 #include "mlir-c/Interfaces.h" 15 #include "llvm/ADT/STLExtras.h" 16 17 namespace py = pybind11; 18 19 namespace mlir { 20 namespace python { 21 22 constexpr static const char *constructorDoc = 23 R"(Creates an interface from a given operation/opview object or from a 24 subclass of OpView. Raises ValueError if the operation does not implement the 25 interface.)"; 26 27 constexpr static const char *operationDoc = 28 R"(Returns an Operation for which the interface was constructed.)"; 29 30 constexpr static const char *opviewDoc = 31 R"(Returns an OpView subclass _instance_ for which the interface was 32 constructed)"; 33 34 constexpr static const char *inferReturnTypesDoc = 35 R"(Given the arguments required to build an operation, attempts to infer 36 its return types. Raises ValueError on failure.)"; 37 38 /// CRTP base class for Python classes representing MLIR Op interfaces. 39 /// Interface hierarchies are flat so no base class is expected here. The 40 /// derived class is expected to define the following static fields: 41 /// - `const char *pyClassName` - the name of the Python class to create; 42 /// - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID 43 /// of the interface. 44 /// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind 45 /// interface-specific methods. 46 /// 47 /// An interface class may be constructed from either an Operation/OpView object 48 /// or from a subclass of OpView. In the latter case, only the static interface 49 /// methods are available, similarly to calling ConcereteOp::staticMethod on the 50 /// C++ side. Implementations of concrete interfaces can use the `isStatic` 51 /// method to check whether the interface object was constructed from a class or 52 /// an operation/opview instance. The `getOpName` always succeeds and returns a 53 /// canonical name of the operation suitable for lookups. 54 template <typename ConcreteIface> 55 class PyConcreteOpInterface { 56 protected: 57 using ClassTy = py::class_<ConcreteIface>; 58 using GetTypeIDFunctionTy = MlirTypeID (*)(); 59 60 public: 61 /// Constructs an interface instance from an object that is either an 62 /// operation or a subclass of OpView. In the latter case, only the static 63 /// methods of the interface are accessible to the caller. 64 PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context) 65 : obj(std::move(object)) { 66 try { 67 operation = &py::cast<PyOperation &>(obj); 68 } catch (py::cast_error &) { 69 // Do nothing. 70 } 71 72 try { 73 operation = &py::cast<PyOpView &>(obj).getOperation(); 74 } catch (py::cast_error &) { 75 // Do nothing. 76 } 77 78 if (operation != nullptr) { 79 if (!mlirOperationImplementsInterface(*operation, 80 ConcreteIface::getInterfaceID())) { 81 std::string msg = "the operation does not implement "; 82 throw py::value_error(msg + ConcreteIface::pyClassName); 83 } 84 85 MlirIdentifier identifier = mlirOperationGetName(*operation); 86 MlirStringRef stringRef = mlirIdentifierStr(identifier); 87 opName = std::string(stringRef.data, stringRef.length); 88 } else { 89 try { 90 opName = obj.attr("OPERATION_NAME").template cast<std::string>(); 91 } catch (py::cast_error &) { 92 throw py::type_error( 93 "Op interface does not refer to an operation or OpView class"); 94 } 95 96 if (!mlirOperationImplementsInterfaceStatic( 97 mlirStringRefCreate(opName.data(), opName.length()), 98 context.resolve().get(), ConcreteIface::getInterfaceID())) { 99 std::string msg = "the operation does not implement "; 100 throw py::value_error(msg + ConcreteIface::pyClassName); 101 } 102 } 103 } 104 105 /// Creates the Python bindings for this class in the given module. 106 static void bind(py::module &m) { 107 py::class_<ConcreteIface> cls(m, "InferTypeOpInterface", 108 py::module_local()); 109 cls.def(py::init<py::object, DefaultingPyMlirContext>(), py::arg("object"), 110 py::arg("context") = py::none(), constructorDoc) 111 .def_property_readonly("operation", 112 &PyConcreteOpInterface::getOperationObject, 113 operationDoc) 114 .def_property_readonly("opview", &PyConcreteOpInterface::getOpView, 115 opviewDoc); 116 ConcreteIface::bindDerived(cls); 117 } 118 119 /// Hook for derived classes to add class-specific bindings. 120 static void bindDerived(ClassTy &cls) {} 121 122 /// Returns `true` if this object was constructed from a subclass of OpView 123 /// rather than from an operation instance. 124 bool isStatic() { return operation == nullptr; } 125 126 /// Returns the operation instance from which this object was constructed. 127 /// Throws a type error if this object was constructed from a subclass of 128 /// OpView. 129 py::object getOperationObject() { 130 if (operation == nullptr) { 131 throw py::type_error("Cannot get an operation from a static interface"); 132 } 133 134 return operation->getRef().releaseObject(); 135 } 136 137 /// Returns the opview of the operation instance from which this object was 138 /// constructed. Throws a type error if this object was constructed form a 139 /// subclass of OpView. 140 py::object getOpView() { 141 if (operation == nullptr) { 142 throw py::type_error("Cannot get an opview from a static interface"); 143 } 144 145 return operation->createOpView(); 146 } 147 148 /// Returns the canonical name of the operation this interface is constructed 149 /// from. 150 const std::string &getOpName() { return opName; } 151 152 private: 153 PyOperation *operation = nullptr; 154 std::string opName; 155 py::object obj; 156 }; 157 158 /// Python wrapper for InterTypeOpInterface. This interface has only static 159 /// methods. 160 class PyInferTypeOpInterface 161 : public PyConcreteOpInterface<PyInferTypeOpInterface> { 162 public: 163 using PyConcreteOpInterface<PyInferTypeOpInterface>::PyConcreteOpInterface; 164 165 constexpr static const char *pyClassName = "InferTypeOpInterface"; 166 constexpr static GetTypeIDFunctionTy getInterfaceID = 167 &mlirInferTypeOpInterfaceTypeID; 168 169 /// C-style user-data structure for type appending callback. 170 struct AppendResultsCallbackData { 171 std::vector<PyType> &inferredTypes; 172 PyMlirContext &pyMlirContext; 173 }; 174 175 /// Appends the types provided as the two first arguments to the user-data 176 /// structure (expects AppendResultsCallbackData). 177 static void appendResultsCallback(intptr_t nTypes, MlirType *types, 178 void *userData) { 179 auto *data = static_cast<AppendResultsCallbackData *>(userData); 180 data->inferredTypes.reserve(data->inferredTypes.size() + nTypes); 181 for (intptr_t i = 0; i < nTypes; ++i) { 182 data->inferredTypes.emplace_back(data->pyMlirContext.getRef(), types[i]); 183 } 184 } 185 186 /// Given the arguments required to build an operation, attempts to infer its 187 /// return types. Throws value_error on failure. 188 std::vector<PyType> 189 inferReturnTypes(std::optional<py::list> operandList, 190 std::optional<PyAttribute> attributes, void *properties, 191 std::optional<std::vector<PyRegion>> regions, 192 DefaultingPyMlirContext context, 193 DefaultingPyLocation location) { 194 llvm::SmallVector<MlirValue> mlirOperands; 195 llvm::SmallVector<MlirRegion> mlirRegions; 196 197 if (operandList && !operandList->empty()) { 198 // Note: as the list may contain other lists this may not be final size. 199 mlirOperands.reserve(operandList->size()); 200 for (const auto& it : llvm::enumerate(*operandList)) { 201 PyValue* val; 202 try { 203 val = py::cast<PyValue *>(it.value()); 204 if (!val) 205 throw py::cast_error(); 206 mlirOperands.push_back(val->get()); 207 continue; 208 } catch (py::cast_error &err) { 209 // Intentionally unhandled to try sequence below first. 210 (void)err; 211 } 212 213 try { 214 auto vals = py::cast<py::sequence>(it.value()); 215 for (py::object v : vals) { 216 try { 217 val = py::cast<PyValue *>(v); 218 if (!val) 219 throw py::cast_error(); 220 mlirOperands.push_back(val->get()); 221 } catch (py::cast_error &err) { 222 throw py::value_error( 223 (llvm::Twine("Operand ") + llvm::Twine(it.index()) + 224 " must be a Value or Sequence of Values (" + err.what() + 225 ")") 226 .str()); 227 } 228 } 229 continue; 230 } catch (py::cast_error &err) { 231 throw py::value_error( 232 (llvm::Twine("Operand ") + llvm::Twine(it.index()) + 233 " must be a Value or Sequence of Values (" + err.what() + ")") 234 .str()); 235 } 236 237 throw py::cast_error(); 238 } 239 } 240 241 if (regions) { 242 mlirRegions.reserve(regions->size()); 243 for (PyRegion ®ion : *regions) { 244 mlirRegions.push_back(region); 245 } 246 } 247 248 std::vector<PyType> inferredTypes; 249 PyMlirContext &pyContext = context.resolve(); 250 AppendResultsCallbackData data{inferredTypes, pyContext}; 251 MlirStringRef opNameRef = 252 mlirStringRefCreate(getOpName().data(), getOpName().length()); 253 MlirAttribute attributeDict = 254 attributes ? attributes->get() : mlirAttributeGetNull(); 255 256 MlirLogicalResult result = mlirInferTypeOpInterfaceInferReturnTypes( 257 opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(), 258 mlirOperands.data(), attributeDict, properties, mlirRegions.size(), 259 mlirRegions.data(), &appendResultsCallback, &data); 260 261 if (mlirLogicalResultIsFailure(result)) { 262 throw py::value_error("Failed to infer result types"); 263 } 264 265 return inferredTypes; 266 } 267 268 static void bindDerived(ClassTy &cls) { 269 cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes, 270 py::arg("operands") = py::none(), 271 py::arg("attributes") = py::none(), 272 py::arg("properties") = py::none(), py::arg("regions") = py::none(), 273 py::arg("context") = py::none(), py::arg("loc") = py::none(), 274 inferReturnTypesDoc); 275 } 276 }; 277 278 void populateIRInterfaces(py::module &m) { PyInferTypeOpInterface::bind(m); } 279 280 } // namespace python 281 } // namespace mlir 282