1 //===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <optional> 10 #include <utility> 11 12 #include "IRModule.h" 13 #include "mlir-c/BuiltinAttributes.h" 14 #include "mlir-c/Interfaces.h" 15 #include "llvm/ADT/STLExtras.h" 16 17 namespace py = pybind11; 18 19 namespace mlir { 20 namespace python { 21 22 constexpr static const char *constructorDoc = 23 R"(Creates an interface from a given operation/opview object or from a 24 subclass of OpView. Raises ValueError if the operation does not implement the 25 interface.)"; 26 27 constexpr static const char *operationDoc = 28 R"(Returns an Operation for which the interface was constructed.)"; 29 30 constexpr static const char *opviewDoc = 31 R"(Returns an OpView subclass _instance_ for which the interface was 32 constructed)"; 33 34 constexpr static const char *inferReturnTypesDoc = 35 R"(Given the arguments required to build an operation, attempts to infer 36 its return types. Raises ValueError on failure.)"; 37 38 constexpr static const char *inferReturnTypeComponentsDoc = 39 R"(Given the arguments required to build an operation, attempts to infer 40 its return shaped type components. Raises ValueError on failure.)"; 41 42 namespace { 43 44 /// Takes in an optional ist of operands and converts them into a SmallVector 45 /// of MlirVlaues. Returns an empty SmallVector if the list is empty. 46 llvm::SmallVector<MlirValue> wrapOperands(std::optional<py::list> operandList) { 47 llvm::SmallVector<MlirValue> mlirOperands; 48 49 if (!operandList || operandList->empty()) { 50 return mlirOperands; 51 } 52 53 // Note: as the list may contain other lists this may not be final size. 54 mlirOperands.reserve(operandList->size()); 55 for (const auto &&it : llvm::enumerate(*operandList)) { 56 PyValue *val; 57 try { 58 val = py::cast<PyValue *>(it.value()); 59 if (!val) 60 throw py::cast_error(); 61 mlirOperands.push_back(val->get()); 62 continue; 63 } catch (py::cast_error &err) { 64 // Intentionally unhandled to try sequence below first. 65 (void)err; 66 } 67 68 try { 69 auto vals = py::cast<py::sequence>(it.value()); 70 for (py::object v : vals) { 71 try { 72 val = py::cast<PyValue *>(v); 73 if (!val) 74 throw py::cast_error(); 75 mlirOperands.push_back(val->get()); 76 } catch (py::cast_error &err) { 77 throw py::value_error( 78 (llvm::Twine("Operand ") + llvm::Twine(it.index()) + 79 " must be a Value or Sequence of Values (" + err.what() + ")") 80 .str()); 81 } 82 } 83 continue; 84 } catch (py::cast_error &err) { 85 throw py::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) + 86 " must be a Value or Sequence of Values (" + 87 err.what() + ")") 88 .str()); 89 } 90 91 throw py::cast_error(); 92 } 93 94 return mlirOperands; 95 } 96 97 /// Takes in an optional vector of PyRegions and returns a SmallVector of 98 /// MlirRegion. Returns an empty SmallVector if the list is empty. 99 llvm::SmallVector<MlirRegion> 100 wrapRegions(std::optional<std::vector<PyRegion>> regions) { 101 llvm::SmallVector<MlirRegion> mlirRegions; 102 103 if (regions) { 104 mlirRegions.reserve(regions->size()); 105 for (PyRegion ®ion : *regions) { 106 mlirRegions.push_back(region); 107 } 108 } 109 110 return mlirRegions; 111 } 112 113 } // namespace 114 115 /// CRTP base class for Python classes representing MLIR Op interfaces. 116 /// Interface hierarchies are flat so no base class is expected here. The 117 /// derived class is expected to define the following static fields: 118 /// - `const char *pyClassName` - the name of the Python class to create; 119 /// - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID 120 /// of the interface. 121 /// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind 122 /// interface-specific methods. 123 /// 124 /// An interface class may be constructed from either an Operation/OpView object 125 /// or from a subclass of OpView. In the latter case, only the static interface 126 /// methods are available, similarly to calling ConcereteOp::staticMethod on the 127 /// C++ side. Implementations of concrete interfaces can use the `isStatic` 128 /// method to check whether the interface object was constructed from a class or 129 /// an operation/opview instance. The `getOpName` always succeeds and returns a 130 /// canonical name of the operation suitable for lookups. 131 template <typename ConcreteIface> 132 class PyConcreteOpInterface { 133 protected: 134 using ClassTy = py::class_<ConcreteIface>; 135 using GetTypeIDFunctionTy = MlirTypeID (*)(); 136 137 public: 138 /// Constructs an interface instance from an object that is either an 139 /// operation or a subclass of OpView. In the latter case, only the static 140 /// methods of the interface are accessible to the caller. 141 PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context) 142 : obj(std::move(object)) { 143 try { 144 operation = &py::cast<PyOperation &>(obj); 145 } catch (py::cast_error &) { 146 // Do nothing. 147 } 148 149 try { 150 operation = &py::cast<PyOpView &>(obj).getOperation(); 151 } catch (py::cast_error &) { 152 // Do nothing. 153 } 154 155 if (operation != nullptr) { 156 if (!mlirOperationImplementsInterface(*operation, 157 ConcreteIface::getInterfaceID())) { 158 std::string msg = "the operation does not implement "; 159 throw py::value_error(msg + ConcreteIface::pyClassName); 160 } 161 162 MlirIdentifier identifier = mlirOperationGetName(*operation); 163 MlirStringRef stringRef = mlirIdentifierStr(identifier); 164 opName = std::string(stringRef.data, stringRef.length); 165 } else { 166 try { 167 opName = obj.attr("OPERATION_NAME").template cast<std::string>(); 168 } catch (py::cast_error &) { 169 throw py::type_error( 170 "Op interface does not refer to an operation or OpView class"); 171 } 172 173 if (!mlirOperationImplementsInterfaceStatic( 174 mlirStringRefCreate(opName.data(), opName.length()), 175 context.resolve().get(), ConcreteIface::getInterfaceID())) { 176 std::string msg = "the operation does not implement "; 177 throw py::value_error(msg + ConcreteIface::pyClassName); 178 } 179 } 180 } 181 182 /// Creates the Python bindings for this class in the given module. 183 static void bind(py::module &m) { 184 py::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName, 185 py::module_local()); 186 cls.def(py::init<py::object, DefaultingPyMlirContext>(), py::arg("object"), 187 py::arg("context") = py::none(), constructorDoc) 188 .def_property_readonly("operation", 189 &PyConcreteOpInterface::getOperationObject, 190 operationDoc) 191 .def_property_readonly("opview", &PyConcreteOpInterface::getOpView, 192 opviewDoc); 193 ConcreteIface::bindDerived(cls); 194 } 195 196 /// Hook for derived classes to add class-specific bindings. 197 static void bindDerived(ClassTy &cls) {} 198 199 /// Returns `true` if this object was constructed from a subclass of OpView 200 /// rather than from an operation instance. 201 bool isStatic() { return operation == nullptr; } 202 203 /// Returns the operation instance from which this object was constructed. 204 /// Throws a type error if this object was constructed from a subclass of 205 /// OpView. 206 py::object getOperationObject() { 207 if (operation == nullptr) { 208 throw py::type_error("Cannot get an operation from a static interface"); 209 } 210 211 return operation->getRef().releaseObject(); 212 } 213 214 /// Returns the opview of the operation instance from which this object was 215 /// constructed. Throws a type error if this object was constructed form a 216 /// subclass of OpView. 217 py::object getOpView() { 218 if (operation == nullptr) { 219 throw py::type_error("Cannot get an opview from a static interface"); 220 } 221 222 return operation->createOpView(); 223 } 224 225 /// Returns the canonical name of the operation this interface is constructed 226 /// from. 227 const std::string &getOpName() { return opName; } 228 229 private: 230 PyOperation *operation = nullptr; 231 std::string opName; 232 py::object obj; 233 }; 234 235 /// Python wrapper for InferTypeOpInterface. This interface has only static 236 /// methods. 237 class PyInferTypeOpInterface 238 : public PyConcreteOpInterface<PyInferTypeOpInterface> { 239 public: 240 using PyConcreteOpInterface<PyInferTypeOpInterface>::PyConcreteOpInterface; 241 242 constexpr static const char *pyClassName = "InferTypeOpInterface"; 243 constexpr static GetTypeIDFunctionTy getInterfaceID = 244 &mlirInferTypeOpInterfaceTypeID; 245 246 /// C-style user-data structure for type appending callback. 247 struct AppendResultsCallbackData { 248 std::vector<PyType> &inferredTypes; 249 PyMlirContext &pyMlirContext; 250 }; 251 252 /// Appends the types provided as the two first arguments to the user-data 253 /// structure (expects AppendResultsCallbackData). 254 static void appendResultsCallback(intptr_t nTypes, MlirType *types, 255 void *userData) { 256 auto *data = static_cast<AppendResultsCallbackData *>(userData); 257 data->inferredTypes.reserve(data->inferredTypes.size() + nTypes); 258 for (intptr_t i = 0; i < nTypes; ++i) { 259 data->inferredTypes.emplace_back(data->pyMlirContext.getRef(), types[i]); 260 } 261 } 262 263 /// Given the arguments required to build an operation, attempts to infer its 264 /// return types. Throws value_error on failure. 265 std::vector<PyType> 266 inferReturnTypes(std::optional<py::list> operandList, 267 std::optional<PyAttribute> attributes, void *properties, 268 std::optional<std::vector<PyRegion>> regions, 269 DefaultingPyMlirContext context, 270 DefaultingPyLocation location) { 271 llvm::SmallVector<MlirValue> mlirOperands = wrapOperands(operandList); 272 llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(regions); 273 274 std::vector<PyType> inferredTypes; 275 PyMlirContext &pyContext = context.resolve(); 276 AppendResultsCallbackData data{inferredTypes, pyContext}; 277 MlirStringRef opNameRef = 278 mlirStringRefCreate(getOpName().data(), getOpName().length()); 279 MlirAttribute attributeDict = 280 attributes ? attributes->get() : mlirAttributeGetNull(); 281 282 MlirLogicalResult result = mlirInferTypeOpInterfaceInferReturnTypes( 283 opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(), 284 mlirOperands.data(), attributeDict, properties, mlirRegions.size(), 285 mlirRegions.data(), &appendResultsCallback, &data); 286 287 if (mlirLogicalResultIsFailure(result)) { 288 throw py::value_error("Failed to infer result types"); 289 } 290 291 return inferredTypes; 292 } 293 294 static void bindDerived(ClassTy &cls) { 295 cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes, 296 py::arg("operands") = py::none(), 297 py::arg("attributes") = py::none(), 298 py::arg("properties") = py::none(), py::arg("regions") = py::none(), 299 py::arg("context") = py::none(), py::arg("loc") = py::none(), 300 inferReturnTypesDoc); 301 } 302 }; 303 304 /// Wrapper around an shaped type components. 305 class PyShapedTypeComponents { 306 public: 307 PyShapedTypeComponents(MlirType elementType) : elementType(elementType) {} 308 PyShapedTypeComponents(py::list shape, MlirType elementType) 309 : shape(shape), elementType(elementType), ranked(true) {} 310 PyShapedTypeComponents(py::list shape, MlirType elementType, 311 MlirAttribute attribute) 312 : shape(shape), elementType(elementType), attribute(attribute), 313 ranked(true) {} 314 PyShapedTypeComponents(PyShapedTypeComponents &) = delete; 315 PyShapedTypeComponents(PyShapedTypeComponents &&other) 316 : shape(other.shape), elementType(other.elementType), 317 attribute(other.attribute), ranked(other.ranked) {} 318 319 static void bind(py::module &m) { 320 py::class_<PyShapedTypeComponents>(m, "ShapedTypeComponents", 321 py::module_local()) 322 .def_property_readonly( 323 "element_type", 324 [](PyShapedTypeComponents &self) { 325 return PyType(PyMlirContext::forContext( 326 mlirTypeGetContext(self.elementType)), 327 self.elementType); 328 }, 329 "Returns the element type of the shaped type components.") 330 .def_static( 331 "get", 332 [](PyType &elementType) { 333 return PyShapedTypeComponents(elementType); 334 }, 335 py::arg("element_type"), 336 "Create an shaped type components object with only the element " 337 "type.") 338 .def_static( 339 "get", 340 [](py::list shape, PyType &elementType) { 341 return PyShapedTypeComponents(shape, elementType); 342 }, 343 py::arg("shape"), py::arg("element_type"), 344 "Create a ranked shaped type components object.") 345 .def_static( 346 "get", 347 [](py::list shape, PyType &elementType, PyAttribute &attribute) { 348 return PyShapedTypeComponents(shape, elementType, attribute); 349 }, 350 py::arg("shape"), py::arg("element_type"), py::arg("attribute"), 351 "Create a ranked shaped type components object with attribute.") 352 .def_property_readonly( 353 "has_rank", 354 [](PyShapedTypeComponents &self) -> bool { return self.ranked; }, 355 "Returns whether the given shaped type component is ranked.") 356 .def_property_readonly( 357 "rank", 358 [](PyShapedTypeComponents &self) -> py::object { 359 if (!self.ranked) { 360 return py::none(); 361 } 362 return py::int_(self.shape.size()); 363 }, 364 "Returns the rank of the given ranked shaped type components. If " 365 "the shaped type components does not have a rank, None is " 366 "returned.") 367 .def_property_readonly( 368 "shape", 369 [](PyShapedTypeComponents &self) -> py::object { 370 if (!self.ranked) { 371 return py::none(); 372 } 373 return py::list(self.shape); 374 }, 375 "Returns the shape of the ranked shaped type components as a list " 376 "of integers. Returns none if the shaped type component does not " 377 "have a rank."); 378 } 379 380 pybind11::object getCapsule(); 381 static PyShapedTypeComponents createFromCapsule(pybind11::object capsule); 382 383 private: 384 py::list shape; 385 MlirType elementType; 386 MlirAttribute attribute; 387 bool ranked{false}; 388 }; 389 390 /// Python wrapper for InferShapedTypeOpInterface. This interface has only 391 /// static methods. 392 class PyInferShapedTypeOpInterface 393 : public PyConcreteOpInterface<PyInferShapedTypeOpInterface> { 394 public: 395 using PyConcreteOpInterface< 396 PyInferShapedTypeOpInterface>::PyConcreteOpInterface; 397 398 constexpr static const char *pyClassName = "InferShapedTypeOpInterface"; 399 constexpr static GetTypeIDFunctionTy getInterfaceID = 400 &mlirInferShapedTypeOpInterfaceTypeID; 401 402 /// C-style user-data structure for type appending callback. 403 struct AppendResultsCallbackData { 404 std::vector<PyShapedTypeComponents> &inferredShapedTypeComponents; 405 }; 406 407 /// Appends the shaped type components provided as unpacked shape, element 408 /// type, attribute to the user-data. 409 static void appendResultsCallback(bool hasRank, intptr_t rank, 410 const int64_t *shape, MlirType elementType, 411 MlirAttribute attribute, void *userData) { 412 auto *data = static_cast<AppendResultsCallbackData *>(userData); 413 if (!hasRank) { 414 data->inferredShapedTypeComponents.emplace_back(elementType); 415 } else { 416 py::list shapeList; 417 for (intptr_t i = 0; i < rank; ++i) { 418 shapeList.append(shape[i]); 419 } 420 data->inferredShapedTypeComponents.emplace_back(shapeList, elementType, 421 attribute); 422 } 423 } 424 425 /// Given the arguments required to build an operation, attempts to infer the 426 /// shaped type components. Throws value_error on failure. 427 std::vector<PyShapedTypeComponents> inferReturnTypeComponents( 428 std::optional<py::list> operandList, 429 std::optional<PyAttribute> attributes, void *properties, 430 std::optional<std::vector<PyRegion>> regions, 431 DefaultingPyMlirContext context, DefaultingPyLocation location) { 432 llvm::SmallVector<MlirValue> mlirOperands = wrapOperands(operandList); 433 llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(regions); 434 435 std::vector<PyShapedTypeComponents> inferredShapedTypeComponents; 436 PyMlirContext &pyContext = context.resolve(); 437 AppendResultsCallbackData data{inferredShapedTypeComponents}; 438 MlirStringRef opNameRef = 439 mlirStringRefCreate(getOpName().data(), getOpName().length()); 440 MlirAttribute attributeDict = 441 attributes ? attributes->get() : mlirAttributeGetNull(); 442 443 MlirLogicalResult result = mlirInferShapedTypeOpInterfaceInferReturnTypes( 444 opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(), 445 mlirOperands.data(), attributeDict, properties, mlirRegions.size(), 446 mlirRegions.data(), &appendResultsCallback, &data); 447 448 if (mlirLogicalResultIsFailure(result)) { 449 throw py::value_error("Failed to infer result shape type components"); 450 } 451 452 return inferredShapedTypeComponents; 453 } 454 455 static void bindDerived(ClassTy &cls) { 456 cls.def("inferReturnTypeComponents", 457 &PyInferShapedTypeOpInterface::inferReturnTypeComponents, 458 py::arg("operands") = py::none(), 459 py::arg("attributes") = py::none(), py::arg("regions") = py::none(), 460 py::arg("properties") = py::none(), py::arg("context") = py::none(), 461 py::arg("loc") = py::none(), inferReturnTypeComponentsDoc); 462 } 463 }; 464 465 void populateIRInterfaces(py::module &m) { 466 PyInferTypeOpInterface::bind(m); 467 PyShapedTypeComponents::bind(m); 468 PyInferShapedTypeOpInterface::bind(m); 469 } 470 471 } // namespace python 472 } // namespace mlir 473