//===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/vector.h>

#include <cstdint>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "IRModule.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/IR.h"
#include "mlir-c/Interfaces.h"
#include "mlir-c/Support.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

namespace nb = nanobind;

namespace mlir {
namespace python {

// Python-visible docstrings for the interface bindings in this file. They are
// passed as the trailing doc argument of the corresponding `def` /
// `def_prop_ro` calls.

/// Docstring for the interface class constructor.
constexpr static const char *constructorDoc =
    R"(Creates an interface from a given operation/opview object or from a
subclass of OpView. Raises ValueError if the operation does not implement the
interface.)";

/// Docstring for the read-only `operation` property.
constexpr static const char *operationDoc =
    R"(Returns an Operation for which the interface was constructed.)";

/// Docstring for the read-only `opview` property.
constexpr static const char *opviewDoc =
    R"(Returns an OpView subclass _instance_ for which the interface was
constructed)";

/// Docstring for `InferTypeOpInterface.inferReturnTypes`.
constexpr static const char *inferReturnTypesDoc =
    R"(Given the arguments required to build an operation, attempts to infer
its return types. Raises ValueError on failure.)";

/// Docstring for `InferShapedTypeOpInterface.inferReturnTypeComponents`.
constexpr static const char *inferReturnTypeComponentsDoc =
    R"(Given the arguments required to build an operation, attempts to infer
its return shaped type components. Raises ValueError on failure.)";

namespace {

/// Takes in an optional list of operands and converts them into a SmallVector
/// of MlirValues. Returns an empty SmallVector if the list is absent or empty.
56 llvm::SmallVector<MlirValue> wrapOperands(std::optional<nb::list> operandList) { 57 llvm::SmallVector<MlirValue> mlirOperands; 58 59 if (!operandList || operandList->size() == 0) { 60 return mlirOperands; 61 } 62 63 // Note: as the list may contain other lists this may not be final size. 64 mlirOperands.reserve(operandList->size()); 65 for (const auto &&it : llvm::enumerate(*operandList)) { 66 if (it.value().is_none()) 67 continue; 68 69 PyValue *val; 70 try { 71 val = nb::cast<PyValue *>(it.value()); 72 if (!val) 73 throw nb::cast_error(); 74 mlirOperands.push_back(val->get()); 75 continue; 76 } catch (nb::cast_error &err) { 77 // Intentionally unhandled to try sequence below first. 78 (void)err; 79 } 80 81 try { 82 auto vals = nb::cast<nb::sequence>(it.value()); 83 for (nb::handle v : vals) { 84 try { 85 val = nb::cast<PyValue *>(v); 86 if (!val) 87 throw nb::cast_error(); 88 mlirOperands.push_back(val->get()); 89 } catch (nb::cast_error &err) { 90 throw nb::value_error( 91 (llvm::Twine("Operand ") + llvm::Twine(it.index()) + 92 " must be a Value or Sequence of Values (" + err.what() + ")") 93 .str() 94 .c_str()); 95 } 96 } 97 continue; 98 } catch (nb::cast_error &err) { 99 throw nb::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) + 100 " must be a Value or Sequence of Values (" + 101 err.what() + ")") 102 .str() 103 .c_str()); 104 } 105 106 throw nb::cast_error(); 107 } 108 109 return mlirOperands; 110 } 111 112 /// Takes in an optional vector of PyRegions and returns a SmallVector of 113 /// MlirRegion. Returns an empty SmallVector if the list is empty. 
114 llvm::SmallVector<MlirRegion> 115 wrapRegions(std::optional<std::vector<PyRegion>> regions) { 116 llvm::SmallVector<MlirRegion> mlirRegions; 117 118 if (regions) { 119 mlirRegions.reserve(regions->size()); 120 for (PyRegion ®ion : *regions) { 121 mlirRegions.push_back(region); 122 } 123 } 124 125 return mlirRegions; 126 } 127 128 } // namespace 129 130 /// CRTP base class for Python classes representing MLIR Op interfaces. 131 /// Interface hierarchies are flat so no base class is expected here. The 132 /// derived class is expected to define the following static fields: 133 /// - `const char *pyClassName` - the name of the Python class to create; 134 /// - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID 135 /// of the interface. 136 /// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind 137 /// interface-specific methods. 138 /// 139 /// An interface class may be constructed from either an Operation/OpView object 140 /// or from a subclass of OpView. In the latter case, only the static interface 141 /// methods are available, similarly to calling ConcereteOp::staticMethod on the 142 /// C++ side. Implementations of concrete interfaces can use the `isStatic` 143 /// method to check whether the interface object was constructed from a class or 144 /// an operation/opview instance. The `getOpName` always succeeds and returns a 145 /// canonical name of the operation suitable for lookups. 146 template <typename ConcreteIface> 147 class PyConcreteOpInterface { 148 protected: 149 using ClassTy = nb::class_<ConcreteIface>; 150 using GetTypeIDFunctionTy = MlirTypeID (*)(); 151 152 public: 153 /// Constructs an interface instance from an object that is either an 154 /// operation or a subclass of OpView. In the latter case, only the static 155 /// methods of the interface are accessible to the caller. 
156 PyConcreteOpInterface(nb::object object, DefaultingPyMlirContext context) 157 : obj(std::move(object)) { 158 try { 159 operation = &nb::cast<PyOperation &>(obj); 160 } catch (nb::cast_error &) { 161 // Do nothing. 162 } 163 164 try { 165 operation = &nb::cast<PyOpView &>(obj).getOperation(); 166 } catch (nb::cast_error &) { 167 // Do nothing. 168 } 169 170 if (operation != nullptr) { 171 if (!mlirOperationImplementsInterface(*operation, 172 ConcreteIface::getInterfaceID())) { 173 std::string msg = "the operation does not implement "; 174 throw nb::value_error((msg + ConcreteIface::pyClassName).c_str()); 175 } 176 177 MlirIdentifier identifier = mlirOperationGetName(*operation); 178 MlirStringRef stringRef = mlirIdentifierStr(identifier); 179 opName = std::string(stringRef.data, stringRef.length); 180 } else { 181 try { 182 opName = nb::cast<std::string>(obj.attr("OPERATION_NAME")); 183 } catch (nb::cast_error &) { 184 throw nb::type_error( 185 "Op interface does not refer to an operation or OpView class"); 186 } 187 188 if (!mlirOperationImplementsInterfaceStatic( 189 mlirStringRefCreate(opName.data(), opName.length()), 190 context.resolve().get(), ConcreteIface::getInterfaceID())) { 191 std::string msg = "the operation does not implement "; 192 throw nb::value_error((msg + ConcreteIface::pyClassName).c_str()); 193 } 194 } 195 } 196 197 /// Creates the Python bindings for this class in the given module. 198 static void bind(nb::module_ &m) { 199 nb::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName); 200 cls.def(nb::init<nb::object, DefaultingPyMlirContext>(), nb::arg("object"), 201 nb::arg("context").none() = nb::none(), constructorDoc) 202 .def_prop_ro("operation", &PyConcreteOpInterface::getOperationObject, 203 operationDoc) 204 .def_prop_ro("opview", &PyConcreteOpInterface::getOpView, opviewDoc); 205 ConcreteIface::bindDerived(cls); 206 } 207 208 /// Hook for derived classes to add class-specific bindings. 
209 static void bindDerived(ClassTy &cls) {} 210 211 /// Returns `true` if this object was constructed from a subclass of OpView 212 /// rather than from an operation instance. 213 bool isStatic() { return operation == nullptr; } 214 215 /// Returns the operation instance from which this object was constructed. 216 /// Throws a type error if this object was constructed from a subclass of 217 /// OpView. 218 nb::object getOperationObject() { 219 if (operation == nullptr) { 220 throw nb::type_error("Cannot get an operation from a static interface"); 221 } 222 223 return operation->getRef().releaseObject(); 224 } 225 226 /// Returns the opview of the operation instance from which this object was 227 /// constructed. Throws a type error if this object was constructed form a 228 /// subclass of OpView. 229 nb::object getOpView() { 230 if (operation == nullptr) { 231 throw nb::type_error("Cannot get an opview from a static interface"); 232 } 233 234 return operation->createOpView(); 235 } 236 237 /// Returns the canonical name of the operation this interface is constructed 238 /// from. 239 const std::string &getOpName() { return opName; } 240 241 private: 242 PyOperation *operation = nullptr; 243 std::string opName; 244 nb::object obj; 245 }; 246 247 /// Python wrapper for InferTypeOpInterface. This interface has only static 248 /// methods. 249 class PyInferTypeOpInterface 250 : public PyConcreteOpInterface<PyInferTypeOpInterface> { 251 public: 252 using PyConcreteOpInterface<PyInferTypeOpInterface>::PyConcreteOpInterface; 253 254 constexpr static const char *pyClassName = "InferTypeOpInterface"; 255 constexpr static GetTypeIDFunctionTy getInterfaceID = 256 &mlirInferTypeOpInterfaceTypeID; 257 258 /// C-style user-data structure for type appending callback. 
259 struct AppendResultsCallbackData { 260 std::vector<PyType> &inferredTypes; 261 PyMlirContext &pyMlirContext; 262 }; 263 264 /// Appends the types provided as the two first arguments to the user-data 265 /// structure (expects AppendResultsCallbackData). 266 static void appendResultsCallback(intptr_t nTypes, MlirType *types, 267 void *userData) { 268 auto *data = static_cast<AppendResultsCallbackData *>(userData); 269 data->inferredTypes.reserve(data->inferredTypes.size() + nTypes); 270 for (intptr_t i = 0; i < nTypes; ++i) { 271 data->inferredTypes.emplace_back(data->pyMlirContext.getRef(), types[i]); 272 } 273 } 274 275 /// Given the arguments required to build an operation, attempts to infer its 276 /// return types. Throws value_error on failure. 277 std::vector<PyType> 278 inferReturnTypes(std::optional<nb::list> operandList, 279 std::optional<PyAttribute> attributes, void *properties, 280 std::optional<std::vector<PyRegion>> regions, 281 DefaultingPyMlirContext context, 282 DefaultingPyLocation location) { 283 llvm::SmallVector<MlirValue> mlirOperands = 284 wrapOperands(std::move(operandList)); 285 llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(std::move(regions)); 286 287 std::vector<PyType> inferredTypes; 288 PyMlirContext &pyContext = context.resolve(); 289 AppendResultsCallbackData data{inferredTypes, pyContext}; 290 MlirStringRef opNameRef = 291 mlirStringRefCreate(getOpName().data(), getOpName().length()); 292 MlirAttribute attributeDict = 293 attributes ? 
attributes->get() : mlirAttributeGetNull(); 294 295 MlirLogicalResult result = mlirInferTypeOpInterfaceInferReturnTypes( 296 opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(), 297 mlirOperands.data(), attributeDict, properties, mlirRegions.size(), 298 mlirRegions.data(), &appendResultsCallback, &data); 299 300 if (mlirLogicalResultIsFailure(result)) { 301 throw nb::value_error("Failed to infer result types"); 302 } 303 304 return inferredTypes; 305 } 306 307 static void bindDerived(ClassTy &cls) { 308 cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes, 309 nb::arg("operands").none() = nb::none(), 310 nb::arg("attributes").none() = nb::none(), 311 nb::arg("properties").none() = nb::none(), 312 nb::arg("regions").none() = nb::none(), 313 nb::arg("context").none() = nb::none(), 314 nb::arg("loc").none() = nb::none(), inferReturnTypesDoc); 315 } 316 }; 317 318 /// Wrapper around an shaped type components. 319 class PyShapedTypeComponents { 320 public: 321 PyShapedTypeComponents(MlirType elementType) : elementType(elementType) {} 322 PyShapedTypeComponents(nb::list shape, MlirType elementType) 323 : shape(std::move(shape)), elementType(elementType), ranked(true) {} 324 PyShapedTypeComponents(nb::list shape, MlirType elementType, 325 MlirAttribute attribute) 326 : shape(std::move(shape)), elementType(elementType), attribute(attribute), 327 ranked(true) {} 328 PyShapedTypeComponents(PyShapedTypeComponents &) = delete; 329 PyShapedTypeComponents(PyShapedTypeComponents &&other) noexcept 330 : shape(other.shape), elementType(other.elementType), 331 attribute(other.attribute), ranked(other.ranked) {} 332 333 static void bind(nb::module_ &m) { 334 nb::class_<PyShapedTypeComponents>(m, "ShapedTypeComponents") 335 .def_prop_ro( 336 "element_type", 337 [](PyShapedTypeComponents &self) { return self.elementType; }, 338 "Returns the element type of the shaped type components.") 339 .def_static( 340 "get", 341 [](PyType &elementType) { 342 return 
PyShapedTypeComponents(elementType); 343 }, 344 nb::arg("element_type"), 345 "Create an shaped type components object with only the element " 346 "type.") 347 .def_static( 348 "get", 349 [](nb::list shape, PyType &elementType) { 350 return PyShapedTypeComponents(std::move(shape), elementType); 351 }, 352 nb::arg("shape"), nb::arg("element_type"), 353 "Create a ranked shaped type components object.") 354 .def_static( 355 "get", 356 [](nb::list shape, PyType &elementType, PyAttribute &attribute) { 357 return PyShapedTypeComponents(std::move(shape), elementType, 358 attribute); 359 }, 360 nb::arg("shape"), nb::arg("element_type"), nb::arg("attribute"), 361 "Create a ranked shaped type components object with attribute.") 362 .def_prop_ro( 363 "has_rank", 364 [](PyShapedTypeComponents &self) -> bool { return self.ranked; }, 365 "Returns whether the given shaped type component is ranked.") 366 .def_prop_ro( 367 "rank", 368 [](PyShapedTypeComponents &self) -> nb::object { 369 if (!self.ranked) { 370 return nb::none(); 371 } 372 return nb::int_(self.shape.size()); 373 }, 374 "Returns the rank of the given ranked shaped type components. If " 375 "the shaped type components does not have a rank, None is " 376 "returned.") 377 .def_prop_ro( 378 "shape", 379 [](PyShapedTypeComponents &self) -> nb::object { 380 if (!self.ranked) { 381 return nb::none(); 382 } 383 return nb::list(self.shape); 384 }, 385 "Returns the shape of the ranked shaped type components as a list " 386 "of integers. Returns none if the shaped type component does not " 387 "have a rank."); 388 } 389 390 nb::object getCapsule(); 391 static PyShapedTypeComponents createFromCapsule(nb::object capsule); 392 393 private: 394 nb::list shape; 395 MlirType elementType; 396 MlirAttribute attribute; 397 bool ranked{false}; 398 }; 399 400 /// Python wrapper for InferShapedTypeOpInterface. This interface has only 401 /// static methods. 
402 class PyInferShapedTypeOpInterface 403 : public PyConcreteOpInterface<PyInferShapedTypeOpInterface> { 404 public: 405 using PyConcreteOpInterface< 406 PyInferShapedTypeOpInterface>::PyConcreteOpInterface; 407 408 constexpr static const char *pyClassName = "InferShapedTypeOpInterface"; 409 constexpr static GetTypeIDFunctionTy getInterfaceID = 410 &mlirInferShapedTypeOpInterfaceTypeID; 411 412 /// C-style user-data structure for type appending callback. 413 struct AppendResultsCallbackData { 414 std::vector<PyShapedTypeComponents> &inferredShapedTypeComponents; 415 }; 416 417 /// Appends the shaped type components provided as unpacked shape, element 418 /// type, attribute to the user-data. 419 static void appendResultsCallback(bool hasRank, intptr_t rank, 420 const int64_t *shape, MlirType elementType, 421 MlirAttribute attribute, void *userData) { 422 auto *data = static_cast<AppendResultsCallbackData *>(userData); 423 if (!hasRank) { 424 data->inferredShapedTypeComponents.emplace_back(elementType); 425 } else { 426 nb::list shapeList; 427 for (intptr_t i = 0; i < rank; ++i) { 428 shapeList.append(shape[i]); 429 } 430 data->inferredShapedTypeComponents.emplace_back(shapeList, elementType, 431 attribute); 432 } 433 } 434 435 /// Given the arguments required to build an operation, attempts to infer the 436 /// shaped type components. Throws value_error on failure. 
437 std::vector<PyShapedTypeComponents> inferReturnTypeComponents( 438 std::optional<nb::list> operandList, 439 std::optional<PyAttribute> attributes, void *properties, 440 std::optional<std::vector<PyRegion>> regions, 441 DefaultingPyMlirContext context, DefaultingPyLocation location) { 442 llvm::SmallVector<MlirValue> mlirOperands = 443 wrapOperands(std::move(operandList)); 444 llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(std::move(regions)); 445 446 std::vector<PyShapedTypeComponents> inferredShapedTypeComponents; 447 PyMlirContext &pyContext = context.resolve(); 448 AppendResultsCallbackData data{inferredShapedTypeComponents}; 449 MlirStringRef opNameRef = 450 mlirStringRefCreate(getOpName().data(), getOpName().length()); 451 MlirAttribute attributeDict = 452 attributes ? attributes->get() : mlirAttributeGetNull(); 453 454 MlirLogicalResult result = mlirInferShapedTypeOpInterfaceInferReturnTypes( 455 opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(), 456 mlirOperands.data(), attributeDict, properties, mlirRegions.size(), 457 mlirRegions.data(), &appendResultsCallback, &data); 458 459 if (mlirLogicalResultIsFailure(result)) { 460 throw nb::value_error("Failed to infer result shape type components"); 461 } 462 463 return inferredShapedTypeComponents; 464 } 465 466 static void bindDerived(ClassTy &cls) { 467 cls.def("inferReturnTypeComponents", 468 &PyInferShapedTypeOpInterface::inferReturnTypeComponents, 469 nb::arg("operands").none() = nb::none(), 470 nb::arg("attributes").none() = nb::none(), 471 nb::arg("regions").none() = nb::none(), 472 nb::arg("properties").none() = nb::none(), 473 nb::arg("context").none() = nb::none(), 474 nb::arg("loc").none() = nb::none(), inferReturnTypeComponentsDoc); 475 } 476 }; 477 478 void populateIRInterfaces(nb::module_ &m) { 479 PyInferTypeOpInterface::bind(m); 480 PyShapedTypeComponents::bind(m); 481 PyInferShapedTypeOpInterface::bind(m); 482 } 483 484 } // namespace python 485 } // namespace mlir 
486