1 //===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 #include <optional> 11 12 #include "IRModule.h" 13 #include "mlir-c/BuiltinAttributes.h" 14 #include "mlir-c/Interfaces.h" 15 #include "llvm/ADT/STLExtras.h" 16 17 namespace py = pybind11; 18 19 namespace mlir { 20 namespace python { 21 22 constexpr static const char *constructorDoc = 23 R"(Creates an interface from a given operation/opview object or from a 24 subclass of OpView. Raises ValueError if the operation does not implement the 25 interface.)"; 26 27 constexpr static const char *operationDoc = 28 R"(Returns an Operation for which the interface was constructed.)"; 29 30 constexpr static const char *opviewDoc = 31 R"(Returns an OpView subclass _instance_ for which the interface was 32 constructed)"; 33 34 constexpr static const char *inferReturnTypesDoc = 35 R"(Given the arguments required to build an operation, attempts to infer 36 its return types. Raises ValueError on failure.)"; 37 38 /// CRTP base class for Python classes representing MLIR Op interfaces. 39 /// Interface hierarchies are flat so no base class is expected here. The 40 /// derived class is expected to define the following static fields: 41 /// - `const char *pyClassName` - the name of the Python class to create; 42 /// - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID 43 /// of the interface. 44 /// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind 45 /// interface-specific methods. 46 /// 47 /// An interface class may be constructed from either an Operation/OpView object 48 /// or from a subclass of OpView. In the latter case, only the static interface 49 /// methods are available, similarly to calling ConcereteOp::staticMethod on the 50 /// C++ side. Implementations of concrete interfaces can use the `isStatic` 51 /// method to check whether the interface object was constructed from a class or 52 /// an operation/opview instance. The `getOpName` always succeeds and returns a 53 /// canonical name of the operation suitable for lookups. 54 template <typename ConcreteIface> 55 class PyConcreteOpInterface { 56 protected: 57 using ClassTy = py::class_<ConcreteIface>; 58 using GetTypeIDFunctionTy = MlirTypeID (*)(); 59 60 public: 61 /// Constructs an interface instance from an object that is either an 62 /// operation or a subclass of OpView. In the latter case, only the static 63 /// methods of the interface are accessible to the caller. 64 PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context) 65 : obj(std::move(object)) { 66 try { 67 operation = &py::cast<PyOperation &>(obj); 68 } catch (py::cast_error &) { 69 // Do nothing. 70 } 71 72 try { 73 operation = &py::cast<PyOpView &>(obj).getOperation(); 74 } catch (py::cast_error &) { 75 // Do nothing. 76 } 77 78 if (operation != nullptr) { 79 if (!mlirOperationImplementsInterface(*operation, 80 ConcreteIface::getInterfaceID())) { 81 std::string msg = "the operation does not implement "; 82 throw py::value_error(msg + ConcreteIface::pyClassName); 83 } 84 85 MlirIdentifier identifier = mlirOperationGetName(*operation); 86 MlirStringRef stringRef = mlirIdentifierStr(identifier); 87 opName = std::string(stringRef.data, stringRef.length); 88 } else { 89 try { 90 opName = obj.attr("OPERATION_NAME").template cast<std::string>(); 91 } catch (py::cast_error &) { 92 throw py::type_error( 93 "Op interface does not refer to an operation or OpView class"); 94 } 95 96 if (!mlirOperationImplementsInterfaceStatic( 97 mlirStringRefCreate(opName.data(), opName.length()), 98 context.resolve().get(), ConcreteIface::getInterfaceID())) { 99 std::string msg = "the operation does not implement "; 100 throw py::value_error(msg + ConcreteIface::pyClassName); 101 } 102 } 103 } 104 105 /// Creates the Python bindings for this class in the given module. 106 static void bind(py::module &m) { 107 py::class_<ConcreteIface> cls(m, "InferTypeOpInterface", 108 py::module_local()); 109 cls.def(py::init<py::object, DefaultingPyMlirContext>(), py::arg("object"), 110 py::arg("context") = py::none(), constructorDoc) 111 .def_property_readonly("operation", 112 &PyConcreteOpInterface::getOperationObject, 113 operationDoc) 114 .def_property_readonly("opview", &PyConcreteOpInterface::getOpView, 115 opviewDoc); 116 ConcreteIface::bindDerived(cls); 117 } 118 119 /// Hook for derived classes to add class-specific bindings. 120 static void bindDerived(ClassTy &cls) {} 121 122 /// Returns `true` if this object was constructed from a subclass of OpView 123 /// rather than from an operation instance. 124 bool isStatic() { return operation == nullptr; } 125 126 /// Returns the operation instance from which this object was constructed. 127 /// Throws a type error if this object was constructed from a subclass of 128 /// OpView. 129 py::object getOperationObject() { 130 if (operation == nullptr) { 131 throw py::type_error("Cannot get an operation from a static interface"); 132 } 133 134 return operation->getRef().releaseObject(); 135 } 136 137 /// Returns the opview of the operation instance from which this object was 138 /// constructed. Throws a type error if this object was constructed form a 139 /// subclass of OpView. 140 py::object getOpView() { 141 if (operation == nullptr) { 142 throw py::type_error("Cannot get an opview from a static interface"); 143 } 144 145 return operation->createOpView(); 146 } 147 148 /// Returns the canonical name of the operation this interface is constructed 149 /// from. 150 const std::string &getOpName() { return opName; } 151 152 private: 153 PyOperation *operation = nullptr; 154 std::string opName; 155 py::object obj; 156 }; 157 158 /// Python wrapper for InterTypeOpInterface. This interface has only static 159 /// methods. 160 class PyInferTypeOpInterface 161 : public PyConcreteOpInterface<PyInferTypeOpInterface> { 162 public: 163 using PyConcreteOpInterface<PyInferTypeOpInterface>::PyConcreteOpInterface; 164 165 constexpr static const char *pyClassName = "InferTypeOpInterface"; 166 constexpr static GetTypeIDFunctionTy getInterfaceID = 167 &mlirInferTypeOpInterfaceTypeID; 168 169 /// C-style user-data structure for type appending callback. 170 struct AppendResultsCallbackData { 171 std::vector<PyType> &inferredTypes; 172 PyMlirContext &pyMlirContext; 173 }; 174 175 /// Appends the types provided as the two first arguments to the user-data 176 /// structure (expects AppendResultsCallbackData). 177 static void appendResultsCallback(intptr_t nTypes, MlirType *types, 178 void *userData) { 179 auto *data = static_cast<AppendResultsCallbackData *>(userData); 180 data->inferredTypes.reserve(data->inferredTypes.size() + nTypes); 181 for (intptr_t i = 0; i < nTypes; ++i) { 182 data->inferredTypes.emplace_back(data->pyMlirContext.getRef(), types[i]); 183 } 184 } 185 186 /// Given the arguments required to build an operation, attempts to infer its 187 /// return types. Throws value_error on failure. 188 std::vector<PyType> 189 inferReturnTypes(std::optional<py::list> operandList, 190 std::optional<PyAttribute> attributes, 191 std::optional<std::vector<PyRegion>> regions, 192 DefaultingPyMlirContext context, 193 DefaultingPyLocation location) { 194 llvm::SmallVector<MlirValue> mlirOperands; 195 llvm::SmallVector<MlirRegion> mlirRegions; 196 197 if (operandList && !operandList->empty()) { 198 // Note: as the list may contain other lists this may not be final size. 199 mlirOperands.reserve(operandList->size()); 200 for (const auto& it : llvm::enumerate(*operandList)) { 201 PyValue* val; 202 try { 203 val = py::cast<PyValue *>(it.value()); 204 if (!val) 205 throw py::cast_error(); 206 mlirOperands.push_back(val->get()); 207 continue; 208 } catch (py::cast_error &err) { 209 } 210 211 try { 212 auto vals = py::cast<py::sequence>(it.value()); 213 for (py::object v : vals) { 214 try { 215 val = py::cast<PyValue *>(v); 216 if (!val) 217 throw py::cast_error(); 218 mlirOperands.push_back(val->get()); 219 } catch (py::cast_error &err) { 220 throw py::value_error( 221 (llvm::Twine("Operand ") + llvm::Twine(it.index()) + 222 " must be a Value or Sequence of Values (" + err.what() + 223 ")") 224 .str()); 225 } 226 } 227 continue; 228 } catch (py::cast_error &err) { 229 throw py::value_error( 230 (llvm::Twine("Operand ") + llvm::Twine(it.index()) + 231 " must be a Value or Sequence of Values (" + err.what() + ")") 232 .str()); 233 } 234 235 throw py::cast_error(); 236 } 237 } 238 239 if (regions) { 240 mlirRegions.reserve(regions->size()); 241 for (PyRegion ®ion : *regions) { 242 mlirRegions.push_back(region); 243 } 244 } 245 246 std::vector<PyType> inferredTypes; 247 PyMlirContext &pyContext = context.resolve(); 248 AppendResultsCallbackData data{inferredTypes, pyContext}; 249 MlirStringRef opNameRef = 250 mlirStringRefCreate(getOpName().data(), getOpName().length()); 251 MlirAttribute attributeDict = 252 attributes ? attributes->get() : mlirAttributeGetNull(); 253 254 MlirLogicalResult result = mlirInferTypeOpInterfaceInferReturnTypes( 255 opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(), 256 mlirOperands.data(), attributeDict, mlirRegions.size(), 257 mlirRegions.data(), &appendResultsCallback, &data); 258 259 if (mlirLogicalResultIsFailure(result)) { 260 throw py::value_error("Failed to infer result types"); 261 } 262 263 return inferredTypes; 264 } 265 266 static void bindDerived(ClassTy &cls) { 267 cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes, 268 py::arg("operands") = py::none(), 269 py::arg("attributes") = py::none(), py::arg("regions") = py::none(), 270 py::arg("context") = py::none(), py::arg("loc") = py::none(), 271 inferReturnTypesDoc); 272 } 273 }; 274 275 void populateIRInterfaces(py::module &m) { PyInferTypeOpInterface::bind(m); } 276 277 } // namespace python 278 } // namespace mlir 279