1 //===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <optional> 10 #include <utility> 11 12 #include "IRModule.h" 13 #include "mlir-c/BuiltinAttributes.h" 14 #include "mlir-c/Interfaces.h" 15 #include "llvm/ADT/STLExtras.h" 16 17 namespace py = pybind11; 18 19 namespace mlir { 20 namespace python { 21 22 constexpr static const char *constructorDoc = 23 R"(Creates an interface from a given operation/opview object or from a 24 subclass of OpView. Raises ValueError if the operation does not implement the 25 interface.)"; 26 27 constexpr static const char *operationDoc = 28 R"(Returns an Operation for which the interface was constructed.)"; 29 30 constexpr static const char *opviewDoc = 31 R"(Returns an OpView subclass _instance_ for which the interface was 32 constructed)"; 33 34 constexpr static const char *inferReturnTypesDoc = 35 R"(Given the arguments required to build an operation, attempts to infer 36 its return types. Raises ValueError on failure.)"; 37 38 constexpr static const char *inferReturnTypeComponentsDoc = 39 R"(Given the arguments required to build an operation, attempts to infer 40 its return shaped type components. Raises ValueError on failure.)"; 41 42 namespace { 43 44 /// Takes in an optional ist of operands and converts them into a SmallVector 45 /// of MlirVlaues. Returns an empty SmallVector if the list is empty. 46 llvm::SmallVector<MlirValue> wrapOperands(std::optional<py::list> operandList) { 47 llvm::SmallVector<MlirValue> mlirOperands; 48 49 if (!operandList || operandList->empty()) { 50 return mlirOperands; 51 } 52 53 // Note: as the list may contain other lists this may not be final size. 54 mlirOperands.reserve(operandList->size()); 55 for (const auto &&it : llvm::enumerate(*operandList)) { 56 PyValue *val; 57 try { 58 val = py::cast<PyValue *>(it.value()); 59 if (!val) 60 throw py::cast_error(); 61 mlirOperands.push_back(val->get()); 62 continue; 63 } catch (py::cast_error &err) { 64 // Intentionally unhandled to try sequence below first. 65 (void)err; 66 } 67 68 try { 69 auto vals = py::cast<py::sequence>(it.value()); 70 for (py::object v : vals) { 71 try { 72 val = py::cast<PyValue *>(v); 73 if (!val) 74 throw py::cast_error(); 75 mlirOperands.push_back(val->get()); 76 } catch (py::cast_error &err) { 77 throw py::value_error( 78 (llvm::Twine("Operand ") + llvm::Twine(it.index()) + 79 " must be a Value or Sequence of Values (" + err.what() + ")") 80 .str()); 81 } 82 } 83 continue; 84 } catch (py::cast_error &err) { 85 throw py::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) + 86 " must be a Value or Sequence of Values (" + 87 err.what() + ")") 88 .str()); 89 } 90 91 throw py::cast_error(); 92 } 93 94 return mlirOperands; 95 } 96 97 /// Takes in an optional vector of PyRegions and returns a SmallVector of 98 /// MlirRegion. Returns an empty SmallVector if the list is empty. 99 llvm::SmallVector<MlirRegion> 100 wrapRegions(std::optional<std::vector<PyRegion>> regions) { 101 llvm::SmallVector<MlirRegion> mlirRegions; 102 103 if (regions) { 104 mlirRegions.reserve(regions->size()); 105 for (PyRegion ®ion : *regions) { 106 mlirRegions.push_back(region); 107 } 108 } 109 110 return mlirRegions; 111 } 112 113 } // namespace 114 115 /// CRTP base class for Python classes representing MLIR Op interfaces. 116 /// Interface hierarchies are flat so no base class is expected here. The 117 /// derived class is expected to define the following static fields: 118 /// - `const char *pyClassName` - the name of the Python class to create; 119 /// - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID 120 /// of the interface. 121 /// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind 122 /// interface-specific methods. 123 /// 124 /// An interface class may be constructed from either an Operation/OpView object 125 /// or from a subclass of OpView. In the latter case, only the static interface 126 /// methods are available, similarly to calling ConcereteOp::staticMethod on the 127 /// C++ side. Implementations of concrete interfaces can use the `isStatic` 128 /// method to check whether the interface object was constructed from a class or 129 /// an operation/opview instance. The `getOpName` always succeeds and returns a 130 /// canonical name of the operation suitable for lookups. 131 template <typename ConcreteIface> 132 class PyConcreteOpInterface { 133 protected: 134 using ClassTy = py::class_<ConcreteIface>; 135 using GetTypeIDFunctionTy = MlirTypeID (*)(); 136 137 public: 138 /// Constructs an interface instance from an object that is either an 139 /// operation or a subclass of OpView. In the latter case, only the static 140 /// methods of the interface are accessible to the caller. 141 PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context) 142 : obj(std::move(object)) { 143 try { 144 operation = &py::cast<PyOperation &>(obj); 145 } catch (py::cast_error &) { 146 // Do nothing. 147 } 148 149 try { 150 operation = &py::cast<PyOpView &>(obj).getOperation(); 151 } catch (py::cast_error &) { 152 // Do nothing. 153 } 154 155 if (operation != nullptr) { 156 if (!mlirOperationImplementsInterface(*operation, 157 ConcreteIface::getInterfaceID())) { 158 std::string msg = "the operation does not implement "; 159 throw py::value_error(msg + ConcreteIface::pyClassName); 160 } 161 162 MlirIdentifier identifier = mlirOperationGetName(*operation); 163 MlirStringRef stringRef = mlirIdentifierStr(identifier); 164 opName = std::string(stringRef.data, stringRef.length); 165 } else { 166 try { 167 opName = obj.attr("OPERATION_NAME").template cast<std::string>(); 168 } catch (py::cast_error &) { 169 throw py::type_error( 170 "Op interface does not refer to an operation or OpView class"); 171 } 172 173 if (!mlirOperationImplementsInterfaceStatic( 174 mlirStringRefCreate(opName.data(), opName.length()), 175 context.resolve().get(), ConcreteIface::getInterfaceID())) { 176 std::string msg = "the operation does not implement "; 177 throw py::value_error(msg + ConcreteIface::pyClassName); 178 } 179 } 180 } 181 182 /// Creates the Python bindings for this class in the given module. 183 static void bind(py::module &m) { 184 py::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName, 185 py::module_local()); 186 cls.def(py::init<py::object, DefaultingPyMlirContext>(), py::arg("object"), 187 py::arg("context") = py::none(), constructorDoc) 188 .def_property_readonly("operation", 189 &PyConcreteOpInterface::getOperationObject, 190 operationDoc) 191 .def_property_readonly("opview", &PyConcreteOpInterface::getOpView, 192 opviewDoc); 193 ConcreteIface::bindDerived(cls); 194 } 195 196 /// Hook for derived classes to add class-specific bindings. 197 static void bindDerived(ClassTy &cls) {} 198 199 /// Returns `true` if this object was constructed from a subclass of OpView 200 /// rather than from an operation instance. 201 bool isStatic() { return operation == nullptr; } 202 203 /// Returns the operation instance from which this object was constructed. 204 /// Throws a type error if this object was constructed from a subclass of 205 /// OpView. 206 py::object getOperationObject() { 207 if (operation == nullptr) { 208 throw py::type_error("Cannot get an operation from a static interface"); 209 } 210 211 return operation->getRef().releaseObject(); 212 } 213 214 /// Returns the opview of the operation instance from which this object was 215 /// constructed. Throws a type error if this object was constructed form a 216 /// subclass of OpView. 217 py::object getOpView() { 218 if (operation == nullptr) { 219 throw py::type_error("Cannot get an opview from a static interface"); 220 } 221 222 return operation->createOpView(); 223 } 224 225 /// Returns the canonical name of the operation this interface is constructed 226 /// from. 227 const std::string &getOpName() { return opName; } 228 229 private: 230 PyOperation *operation = nullptr; 231 std::string opName; 232 py::object obj; 233 }; 234 235 /// Python wrapper for InferTypeOpInterface. This interface has only static 236 /// methods. 237 class PyInferTypeOpInterface 238 : public PyConcreteOpInterface<PyInferTypeOpInterface> { 239 public: 240 using PyConcreteOpInterface<PyInferTypeOpInterface>::PyConcreteOpInterface; 241 242 constexpr static const char *pyClassName = "InferTypeOpInterface"; 243 constexpr static GetTypeIDFunctionTy getInterfaceID = 244 &mlirInferTypeOpInterfaceTypeID; 245 246 /// C-style user-data structure for type appending callback. 247 struct AppendResultsCallbackData { 248 std::vector<PyType> &inferredTypes; 249 PyMlirContext &pyMlirContext; 250 }; 251 252 /// Appends the types provided as the two first arguments to the user-data 253 /// structure (expects AppendResultsCallbackData). 254 static void appendResultsCallback(intptr_t nTypes, MlirType *types, 255 void *userData) { 256 auto *data = static_cast<AppendResultsCallbackData *>(userData); 257 data->inferredTypes.reserve(data->inferredTypes.size() + nTypes); 258 for (intptr_t i = 0; i < nTypes; ++i) { 259 data->inferredTypes.emplace_back(data->pyMlirContext.getRef(), types[i]); 260 } 261 } 262 263 /// Given the arguments required to build an operation, attempts to infer its 264 /// return types. Throws value_error on failure. 265 std::vector<PyType> 266 inferReturnTypes(std::optional<py::list> operandList, 267 std::optional<PyAttribute> attributes, void *properties, 268 std::optional<std::vector<PyRegion>> regions, 269 DefaultingPyMlirContext context, 270 DefaultingPyLocation location) { 271 llvm::SmallVector<MlirValue> mlirOperands = wrapOperands(operandList); 272 llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(regions); 273 274 std::vector<PyType> inferredTypes; 275 PyMlirContext &pyContext = context.resolve(); 276 AppendResultsCallbackData data{inferredTypes, pyContext}; 277 MlirStringRef opNameRef = 278 mlirStringRefCreate(getOpName().data(), getOpName().length()); 279 MlirAttribute attributeDict = 280 attributes ? attributes->get() : mlirAttributeGetNull(); 281 282 MlirLogicalResult result = mlirInferTypeOpInterfaceInferReturnTypes( 283 opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(), 284 mlirOperands.data(), attributeDict, properties, mlirRegions.size(), 285 mlirRegions.data(), &appendResultsCallback, &data); 286 287 if (mlirLogicalResultIsFailure(result)) { 288 throw py::value_error("Failed to infer result types"); 289 } 290 291 return inferredTypes; 292 } 293 294 static void bindDerived(ClassTy &cls) { 295 cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes, 296 py::arg("operands") = py::none(), 297 py::arg("attributes") = py::none(), 298 py::arg("properties") = py::none(), py::arg("regions") = py::none(), 299 py::arg("context") = py::none(), py::arg("loc") = py::none(), 300 inferReturnTypesDoc); 301 } 302 }; 303 304 /// Wrapper around an shaped type components. 305 class PyShapedTypeComponents { 306 public: 307 PyShapedTypeComponents(MlirType elementType) : elementType(elementType) {} 308 PyShapedTypeComponents(py::list shape, MlirType elementType) 309 : shape(shape), elementType(elementType), ranked(true) {} 310 PyShapedTypeComponents(py::list shape, MlirType elementType, 311 MlirAttribute attribute) 312 : shape(shape), elementType(elementType), attribute(attribute), 313 ranked(true) {} 314 PyShapedTypeComponents(PyShapedTypeComponents &) = delete; 315 PyShapedTypeComponents(PyShapedTypeComponents &&other) 316 : shape(other.shape), elementType(other.elementType), 317 attribute(other.attribute), ranked(other.ranked) {} 318 319 static void bind(py::module &m) { 320 py::class_<PyShapedTypeComponents>(m, "ShapedTypeComponents", 321 py::module_local()) 322 .def_property_readonly( 323 "element_type", 324 [](PyShapedTypeComponents &self) { return self.elementType; }, 325 "Returns the element type of the shaped type components.") 326 .def_static( 327 "get", 328 [](PyType &elementType) { 329 return PyShapedTypeComponents(elementType); 330 }, 331 py::arg("element_type"), 332 "Create an shaped type components object with only the element " 333 "type.") 334 .def_static( 335 "get", 336 [](py::list shape, PyType &elementType) { 337 return PyShapedTypeComponents(shape, elementType); 338 }, 339 py::arg("shape"), py::arg("element_type"), 340 "Create a ranked shaped type components object.") 341 .def_static( 342 "get", 343 [](py::list shape, PyType &elementType, PyAttribute &attribute) { 344 return PyShapedTypeComponents(shape, elementType, attribute); 345 }, 346 py::arg("shape"), py::arg("element_type"), py::arg("attribute"), 347 "Create a ranked shaped type components object with attribute.") 348 .def_property_readonly( 349 "has_rank", 350 [](PyShapedTypeComponents &self) -> bool { return self.ranked; }, 351 "Returns whether the given shaped type component is ranked.") 352 .def_property_readonly( 353 "rank", 354 [](PyShapedTypeComponents &self) -> py::object { 355 if (!self.ranked) { 356 return py::none(); 357 } 358 return py::int_(self.shape.size()); 359 }, 360 "Returns the rank of the given ranked shaped type components. If " 361 "the shaped type components does not have a rank, None is " 362 "returned.") 363 .def_property_readonly( 364 "shape", 365 [](PyShapedTypeComponents &self) -> py::object { 366 if (!self.ranked) { 367 return py::none(); 368 } 369 return py::list(self.shape); 370 }, 371 "Returns the shape of the ranked shaped type components as a list " 372 "of integers. Returns none if the shaped type component does not " 373 "have a rank."); 374 } 375 376 pybind11::object getCapsule(); 377 static PyShapedTypeComponents createFromCapsule(pybind11::object capsule); 378 379 private: 380 py::list shape; 381 MlirType elementType; 382 MlirAttribute attribute; 383 bool ranked{false}; 384 }; 385 386 /// Python wrapper for InferShapedTypeOpInterface. This interface has only 387 /// static methods. 388 class PyInferShapedTypeOpInterface 389 : public PyConcreteOpInterface<PyInferShapedTypeOpInterface> { 390 public: 391 using PyConcreteOpInterface< 392 PyInferShapedTypeOpInterface>::PyConcreteOpInterface; 393 394 constexpr static const char *pyClassName = "InferShapedTypeOpInterface"; 395 constexpr static GetTypeIDFunctionTy getInterfaceID = 396 &mlirInferShapedTypeOpInterfaceTypeID; 397 398 /// C-style user-data structure for type appending callback. 399 struct AppendResultsCallbackData { 400 std::vector<PyShapedTypeComponents> &inferredShapedTypeComponents; 401 }; 402 403 /// Appends the shaped type components provided as unpacked shape, element 404 /// type, attribute to the user-data. 405 static void appendResultsCallback(bool hasRank, intptr_t rank, 406 const int64_t *shape, MlirType elementType, 407 MlirAttribute attribute, void *userData) { 408 auto *data = static_cast<AppendResultsCallbackData *>(userData); 409 if (!hasRank) { 410 data->inferredShapedTypeComponents.emplace_back(elementType); 411 } else { 412 py::list shapeList; 413 for (intptr_t i = 0; i < rank; ++i) { 414 shapeList.append(shape[i]); 415 } 416 data->inferredShapedTypeComponents.emplace_back(shapeList, elementType, 417 attribute); 418 } 419 } 420 421 /// Given the arguments required to build an operation, attempts to infer the 422 /// shaped type components. Throws value_error on failure. 423 std::vector<PyShapedTypeComponents> inferReturnTypeComponents( 424 std::optional<py::list> operandList, 425 std::optional<PyAttribute> attributes, void *properties, 426 std::optional<std::vector<PyRegion>> regions, 427 DefaultingPyMlirContext context, DefaultingPyLocation location) { 428 llvm::SmallVector<MlirValue> mlirOperands = wrapOperands(operandList); 429 llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(regions); 430 431 std::vector<PyShapedTypeComponents> inferredShapedTypeComponents; 432 PyMlirContext &pyContext = context.resolve(); 433 AppendResultsCallbackData data{inferredShapedTypeComponents}; 434 MlirStringRef opNameRef = 435 mlirStringRefCreate(getOpName().data(), getOpName().length()); 436 MlirAttribute attributeDict = 437 attributes ? attributes->get() : mlirAttributeGetNull(); 438 439 MlirLogicalResult result = mlirInferShapedTypeOpInterfaceInferReturnTypes( 440 opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(), 441 mlirOperands.data(), attributeDict, properties, mlirRegions.size(), 442 mlirRegions.data(), &appendResultsCallback, &data); 443 444 if (mlirLogicalResultIsFailure(result)) { 445 throw py::value_error("Failed to infer result shape type components"); 446 } 447 448 return inferredShapedTypeComponents; 449 } 450 451 static void bindDerived(ClassTy &cls) { 452 cls.def("inferReturnTypeComponents", 453 &PyInferShapedTypeOpInterface::inferReturnTypeComponents, 454 py::arg("operands") = py::none(), 455 py::arg("attributes") = py::none(), py::arg("regions") = py::none(), 456 py::arg("properties") = py::none(), py::arg("context") = py::none(), 457 py::arg("loc") = py::none(), inferReturnTypeComponentsDoc); 458 } 459 }; 460 461 void populateIRInterfaces(py::module &m) { 462 PyInferTypeOpInterface::bind(m); 463 PyShapedTypeComponents::bind(m); 464 PyInferShapedTypeOpInterface::bind(m); 465 } 466 467 } // namespace python 468 } // namespace mlir 469