1 //===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <cstdint> 10 #include <optional> 11 #include <pybind11/cast.h> 12 #include <pybind11/detail/common.h> 13 #include <pybind11/pybind11.h> 14 #include <pybind11/pytypes.h> 15 #include <string> 16 #include <utility> 17 #include <vector> 18 19 #include "IRModule.h" 20 #include "mlir-c/BuiltinAttributes.h" 21 #include "mlir-c/IR.h" 22 #include "mlir-c/Interfaces.h" 23 #include "mlir-c/Support.h" 24 #include "llvm/ADT/STLExtras.h" 25 #include "llvm/ADT/SmallVector.h" 26 27 namespace py = pybind11; 28 29 namespace mlir { 30 namespace python { 31 32 constexpr static const char *constructorDoc = 33 R"(Creates an interface from a given operation/opview object or from a 34 subclass of OpView. Raises ValueError if the operation does not implement the 35 interface.)"; 36 37 constexpr static const char *operationDoc = 38 R"(Returns an Operation for which the interface was constructed.)"; 39 40 constexpr static const char *opviewDoc = 41 R"(Returns an OpView subclass _instance_ for which the interface was 42 constructed)"; 43 44 constexpr static const char *inferReturnTypesDoc = 45 R"(Given the arguments required to build an operation, attempts to infer 46 its return types. Raises ValueError on failure.)"; 47 48 constexpr static const char *inferReturnTypeComponentsDoc = 49 R"(Given the arguments required to build an operation, attempts to infer 50 its return shaped type components. Raises ValueError on failure.)"; 51 52 namespace { 53 54 /// Takes in an optional ist of operands and converts them into a SmallVector 55 /// of MlirVlaues. Returns an empty SmallVector if the list is empty. 56 llvm::SmallVector<MlirValue> wrapOperands(std::optional<py::list> operandList) { 57 llvm::SmallVector<MlirValue> mlirOperands; 58 59 if (!operandList || operandList->empty()) { 60 return mlirOperands; 61 } 62 63 // Note: as the list may contain other lists this may not be final size. 64 mlirOperands.reserve(operandList->size()); 65 for (const auto &&it : llvm::enumerate(*operandList)) { 66 if (it.value().is_none()) 67 continue; 68 69 PyValue *val; 70 try { 71 val = py::cast<PyValue *>(it.value()); 72 if (!val) 73 throw py::cast_error(); 74 mlirOperands.push_back(val->get()); 75 continue; 76 } catch (py::cast_error &err) { 77 // Intentionally unhandled to try sequence below first. 78 (void)err; 79 } 80 81 try { 82 auto vals = py::cast<py::sequence>(it.value()); 83 for (py::object v : vals) { 84 try { 85 val = py::cast<PyValue *>(v); 86 if (!val) 87 throw py::cast_error(); 88 mlirOperands.push_back(val->get()); 89 } catch (py::cast_error &err) { 90 throw py::value_error( 91 (llvm::Twine("Operand ") + llvm::Twine(it.index()) + 92 " must be a Value or Sequence of Values (" + err.what() + ")") 93 .str()); 94 } 95 } 96 continue; 97 } catch (py::cast_error &err) { 98 throw py::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) + 99 " must be a Value or Sequence of Values (" + 100 err.what() + ")") 101 .str()); 102 } 103 104 throw py::cast_error(); 105 } 106 107 return mlirOperands; 108 } 109 110 /// Takes in an optional vector of PyRegions and returns a SmallVector of 111 /// MlirRegion. Returns an empty SmallVector if the list is empty. 112 llvm::SmallVector<MlirRegion> 113 wrapRegions(std::optional<std::vector<PyRegion>> regions) { 114 llvm::SmallVector<MlirRegion> mlirRegions; 115 116 if (regions) { 117 mlirRegions.reserve(regions->size()); 118 for (PyRegion ®ion : *regions) { 119 mlirRegions.push_back(region); 120 } 121 } 122 123 return mlirRegions; 124 } 125 126 } // namespace 127 128 /// CRTP base class for Python classes representing MLIR Op interfaces. 129 /// Interface hierarchies are flat so no base class is expected here. The 130 /// derived class is expected to define the following static fields: 131 /// - `const char *pyClassName` - the name of the Python class to create; 132 /// - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID 133 /// of the interface. 134 /// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind 135 /// interface-specific methods. 136 /// 137 /// An interface class may be constructed from either an Operation/OpView object 138 /// or from a subclass of OpView. In the latter case, only the static interface 139 /// methods are available, similarly to calling ConcereteOp::staticMethod on the 140 /// C++ side. Implementations of concrete interfaces can use the `isStatic` 141 /// method to check whether the interface object was constructed from a class or 142 /// an operation/opview instance. The `getOpName` always succeeds and returns a 143 /// canonical name of the operation suitable for lookups. 144 template <typename ConcreteIface> 145 class PyConcreteOpInterface { 146 protected: 147 using ClassTy = py::class_<ConcreteIface>; 148 using GetTypeIDFunctionTy = MlirTypeID (*)(); 149 150 public: 151 /// Constructs an interface instance from an object that is either an 152 /// operation or a subclass of OpView. In the latter case, only the static 153 /// methods of the interface are accessible to the caller. 154 PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context) 155 : obj(std::move(object)) { 156 try { 157 operation = &py::cast<PyOperation &>(obj); 158 } catch (py::cast_error &) { 159 // Do nothing. 160 } 161 162 try { 163 operation = &py::cast<PyOpView &>(obj).getOperation(); 164 } catch (py::cast_error &) { 165 // Do nothing. 166 } 167 168 if (operation != nullptr) { 169 if (!mlirOperationImplementsInterface(*operation, 170 ConcreteIface::getInterfaceID())) { 171 std::string msg = "the operation does not implement "; 172 throw py::value_error(msg + ConcreteIface::pyClassName); 173 } 174 175 MlirIdentifier identifier = mlirOperationGetName(*operation); 176 MlirStringRef stringRef = mlirIdentifierStr(identifier); 177 opName = std::string(stringRef.data, stringRef.length); 178 } else { 179 try { 180 opName = obj.attr("OPERATION_NAME").template cast<std::string>(); 181 } catch (py::cast_error &) { 182 throw py::type_error( 183 "Op interface does not refer to an operation or OpView class"); 184 } 185 186 if (!mlirOperationImplementsInterfaceStatic( 187 mlirStringRefCreate(opName.data(), opName.length()), 188 context.resolve().get(), ConcreteIface::getInterfaceID())) { 189 std::string msg = "the operation does not implement "; 190 throw py::value_error(msg + ConcreteIface::pyClassName); 191 } 192 } 193 } 194 195 /// Creates the Python bindings for this class in the given module. 196 static void bind(py::module &m) { 197 py::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName, 198 py::module_local()); 199 cls.def(py::init<py::object, DefaultingPyMlirContext>(), py::arg("object"), 200 py::arg("context") = py::none(), constructorDoc) 201 .def_property_readonly("operation", 202 &PyConcreteOpInterface::getOperationObject, 203 operationDoc) 204 .def_property_readonly("opview", &PyConcreteOpInterface::getOpView, 205 opviewDoc); 206 ConcreteIface::bindDerived(cls); 207 } 208 209 /// Hook for derived classes to add class-specific bindings. 210 static void bindDerived(ClassTy &cls) {} 211 212 /// Returns `true` if this object was constructed from a subclass of OpView 213 /// rather than from an operation instance. 214 bool isStatic() { return operation == nullptr; } 215 216 /// Returns the operation instance from which this object was constructed. 217 /// Throws a type error if this object was constructed from a subclass of 218 /// OpView. 219 py::object getOperationObject() { 220 if (operation == nullptr) { 221 throw py::type_error("Cannot get an operation from a static interface"); 222 } 223 224 return operation->getRef().releaseObject(); 225 } 226 227 /// Returns the opview of the operation instance from which this object was 228 /// constructed. Throws a type error if this object was constructed form a 229 /// subclass of OpView. 230 py::object getOpView() { 231 if (operation == nullptr) { 232 throw py::type_error("Cannot get an opview from a static interface"); 233 } 234 235 return operation->createOpView(); 236 } 237 238 /// Returns the canonical name of the operation this interface is constructed 239 /// from. 240 const std::string &getOpName() { return opName; } 241 242 private: 243 PyOperation *operation = nullptr; 244 std::string opName; 245 py::object obj; 246 }; 247 248 /// Python wrapper for InferTypeOpInterface. This interface has only static 249 /// methods. 250 class PyInferTypeOpInterface 251 : public PyConcreteOpInterface<PyInferTypeOpInterface> { 252 public: 253 using PyConcreteOpInterface<PyInferTypeOpInterface>::PyConcreteOpInterface; 254 255 constexpr static const char *pyClassName = "InferTypeOpInterface"; 256 constexpr static GetTypeIDFunctionTy getInterfaceID = 257 &mlirInferTypeOpInterfaceTypeID; 258 259 /// C-style user-data structure for type appending callback. 260 struct AppendResultsCallbackData { 261 std::vector<PyType> &inferredTypes; 262 PyMlirContext &pyMlirContext; 263 }; 264 265 /// Appends the types provided as the two first arguments to the user-data 266 /// structure (expects AppendResultsCallbackData). 267 static void appendResultsCallback(intptr_t nTypes, MlirType *types, 268 void *userData) { 269 auto *data = static_cast<AppendResultsCallbackData *>(userData); 270 data->inferredTypes.reserve(data->inferredTypes.size() + nTypes); 271 for (intptr_t i = 0; i < nTypes; ++i) { 272 data->inferredTypes.emplace_back(data->pyMlirContext.getRef(), types[i]); 273 } 274 } 275 276 /// Given the arguments required to build an operation, attempts to infer its 277 /// return types. Throws value_error on failure. 278 std::vector<PyType> 279 inferReturnTypes(std::optional<py::list> operandList, 280 std::optional<PyAttribute> attributes, void *properties, 281 std::optional<std::vector<PyRegion>> regions, 282 DefaultingPyMlirContext context, 283 DefaultingPyLocation location) { 284 llvm::SmallVector<MlirValue> mlirOperands = 285 wrapOperands(std::move(operandList)); 286 llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(std::move(regions)); 287 288 std::vector<PyType> inferredTypes; 289 PyMlirContext &pyContext = context.resolve(); 290 AppendResultsCallbackData data{inferredTypes, pyContext}; 291 MlirStringRef opNameRef = 292 mlirStringRefCreate(getOpName().data(), getOpName().length()); 293 MlirAttribute attributeDict = 294 attributes ? attributes->get() : mlirAttributeGetNull(); 295 296 MlirLogicalResult result = mlirInferTypeOpInterfaceInferReturnTypes( 297 opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(), 298 mlirOperands.data(), attributeDict, properties, mlirRegions.size(), 299 mlirRegions.data(), &appendResultsCallback, &data); 300 301 if (mlirLogicalResultIsFailure(result)) { 302 throw py::value_error("Failed to infer result types"); 303 } 304 305 return inferredTypes; 306 } 307 308 static void bindDerived(ClassTy &cls) { 309 cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes, 310 py::arg("operands") = py::none(), 311 py::arg("attributes") = py::none(), 312 py::arg("properties") = py::none(), py::arg("regions") = py::none(), 313 py::arg("context") = py::none(), py::arg("loc") = py::none(), 314 inferReturnTypesDoc); 315 } 316 }; 317 318 /// Wrapper around an shaped type components. 319 class PyShapedTypeComponents { 320 public: 321 PyShapedTypeComponents(MlirType elementType) : elementType(elementType) {} 322 PyShapedTypeComponents(py::list shape, MlirType elementType) 323 : shape(std::move(shape)), elementType(elementType), ranked(true) {} 324 PyShapedTypeComponents(py::list shape, MlirType elementType, 325 MlirAttribute attribute) 326 : shape(std::move(shape)), elementType(elementType), attribute(attribute), 327 ranked(true) {} 328 PyShapedTypeComponents(PyShapedTypeComponents &) = delete; 329 PyShapedTypeComponents(PyShapedTypeComponents &&other) 330 : shape(other.shape), elementType(other.elementType), 331 attribute(other.attribute), ranked(other.ranked) {} 332 333 static void bind(py::module &m) { 334 py::class_<PyShapedTypeComponents>(m, "ShapedTypeComponents", 335 py::module_local()) 336 .def_property_readonly( 337 "element_type", 338 [](PyShapedTypeComponents &self) { return self.elementType; }, 339 "Returns the element type of the shaped type components.") 340 .def_static( 341 "get", 342 [](PyType &elementType) { 343 return PyShapedTypeComponents(elementType); 344 }, 345 py::arg("element_type"), 346 "Create an shaped type components object with only the element " 347 "type.") 348 .def_static( 349 "get", 350 [](py::list shape, PyType &elementType) { 351 return PyShapedTypeComponents(std::move(shape), elementType); 352 }, 353 py::arg("shape"), py::arg("element_type"), 354 "Create a ranked shaped type components object.") 355 .def_static( 356 "get", 357 [](py::list shape, PyType &elementType, PyAttribute &attribute) { 358 return PyShapedTypeComponents(std::move(shape), elementType, 359 attribute); 360 }, 361 py::arg("shape"), py::arg("element_type"), py::arg("attribute"), 362 "Create a ranked shaped type components object with attribute.") 363 .def_property_readonly( 364 "has_rank", 365 [](PyShapedTypeComponents &self) -> bool { return self.ranked; }, 366 "Returns whether the given shaped type component is ranked.") 367 .def_property_readonly( 368 "rank", 369 [](PyShapedTypeComponents &self) -> py::object { 370 if (!self.ranked) { 371 return py::none(); 372 } 373 return py::int_(self.shape.size()); 374 }, 375 "Returns the rank of the given ranked shaped type components. If " 376 "the shaped type components does not have a rank, None is " 377 "returned.") 378 .def_property_readonly( 379 "shape", 380 [](PyShapedTypeComponents &self) -> py::object { 381 if (!self.ranked) { 382 return py::none(); 383 } 384 return py::list(self.shape); 385 }, 386 "Returns the shape of the ranked shaped type components as a list " 387 "of integers. Returns none if the shaped type component does not " 388 "have a rank."); 389 } 390 391 pybind11::object getCapsule(); 392 static PyShapedTypeComponents createFromCapsule(pybind11::object capsule); 393 394 private: 395 py::list shape; 396 MlirType elementType; 397 MlirAttribute attribute; 398 bool ranked{false}; 399 }; 400 401 /// Python wrapper for InferShapedTypeOpInterface. This interface has only 402 /// static methods. 403 class PyInferShapedTypeOpInterface 404 : public PyConcreteOpInterface<PyInferShapedTypeOpInterface> { 405 public: 406 using PyConcreteOpInterface< 407 PyInferShapedTypeOpInterface>::PyConcreteOpInterface; 408 409 constexpr static const char *pyClassName = "InferShapedTypeOpInterface"; 410 constexpr static GetTypeIDFunctionTy getInterfaceID = 411 &mlirInferShapedTypeOpInterfaceTypeID; 412 413 /// C-style user-data structure for type appending callback. 414 struct AppendResultsCallbackData { 415 std::vector<PyShapedTypeComponents> &inferredShapedTypeComponents; 416 }; 417 418 /// Appends the shaped type components provided as unpacked shape, element 419 /// type, attribute to the user-data. 420 static void appendResultsCallback(bool hasRank, intptr_t rank, 421 const int64_t *shape, MlirType elementType, 422 MlirAttribute attribute, void *userData) { 423 auto *data = static_cast<AppendResultsCallbackData *>(userData); 424 if (!hasRank) { 425 data->inferredShapedTypeComponents.emplace_back(elementType); 426 } else { 427 py::list shapeList; 428 for (intptr_t i = 0; i < rank; ++i) { 429 shapeList.append(shape[i]); 430 } 431 data->inferredShapedTypeComponents.emplace_back(shapeList, elementType, 432 attribute); 433 } 434 } 435 436 /// Given the arguments required to build an operation, attempts to infer the 437 /// shaped type components. Throws value_error on failure. 438 std::vector<PyShapedTypeComponents> inferReturnTypeComponents( 439 std::optional<py::list> operandList, 440 std::optional<PyAttribute> attributes, void *properties, 441 std::optional<std::vector<PyRegion>> regions, 442 DefaultingPyMlirContext context, DefaultingPyLocation location) { 443 llvm::SmallVector<MlirValue> mlirOperands = 444 wrapOperands(std::move(operandList)); 445 llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(std::move(regions)); 446 447 std::vector<PyShapedTypeComponents> inferredShapedTypeComponents; 448 PyMlirContext &pyContext = context.resolve(); 449 AppendResultsCallbackData data{inferredShapedTypeComponents}; 450 MlirStringRef opNameRef = 451 mlirStringRefCreate(getOpName().data(), getOpName().length()); 452 MlirAttribute attributeDict = 453 attributes ? attributes->get() : mlirAttributeGetNull(); 454 455 MlirLogicalResult result = mlirInferShapedTypeOpInterfaceInferReturnTypes( 456 opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(), 457 mlirOperands.data(), attributeDict, properties, mlirRegions.size(), 458 mlirRegions.data(), &appendResultsCallback, &data); 459 460 if (mlirLogicalResultIsFailure(result)) { 461 throw py::value_error("Failed to infer result shape type components"); 462 } 463 464 return inferredShapedTypeComponents; 465 } 466 467 static void bindDerived(ClassTy &cls) { 468 cls.def("inferReturnTypeComponents", 469 &PyInferShapedTypeOpInterface::inferReturnTypeComponents, 470 py::arg("operands") = py::none(), 471 py::arg("attributes") = py::none(), py::arg("regions") = py::none(), 472 py::arg("properties") = py::none(), py::arg("context") = py::none(), 473 py::arg("loc") = py::none(), inferReturnTypeComponentsDoc); 474 } 475 }; 476 477 void populateIRInterfaces(py::module &m) { 478 PyInferTypeOpInterface::bind(m); 479 PyShapedTypeComponents::bind(m); 480 PyInferShapedTypeOpInterface::bind(m); 481 } 482 483 } // namespace python 484 } // namespace mlir 485