1 //===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <cstdint> 10 #include <optional> 11 #include <string> 12 #include <utility> 13 #include <vector> 14 15 #include "IRModule.h" 16 #include "mlir-c/BuiltinAttributes.h" 17 #include "mlir-c/IR.h" 18 #include "mlir-c/Interfaces.h" 19 #include "mlir-c/Support.h" 20 #include "mlir/Bindings/Python/Nanobind.h" 21 #include "llvm/ADT/STLExtras.h" 22 #include "llvm/ADT/SmallVector.h" 23 24 namespace nb = nanobind; 25 26 namespace mlir { 27 namespace python { 28 29 constexpr static const char *constructorDoc = 30 R"(Creates an interface from a given operation/opview object or from a 31 subclass of OpView. Raises ValueError if the operation does not implement the 32 interface.)"; 33 34 constexpr static const char *operationDoc = 35 R"(Returns an Operation for which the interface was constructed.)"; 36 37 constexpr static const char *opviewDoc = 38 R"(Returns an OpView subclass _instance_ for which the interface was 39 constructed)"; 40 41 constexpr static const char *inferReturnTypesDoc = 42 R"(Given the arguments required to build an operation, attempts to infer 43 its return types. Raises ValueError on failure.)"; 44 45 constexpr static const char *inferReturnTypeComponentsDoc = 46 R"(Given the arguments required to build an operation, attempts to infer 47 its return shaped type components. Raises ValueError on failure.)"; 48 49 namespace { 50 51 /// Takes in an optional ist of operands and converts them into a SmallVector 52 /// of MlirVlaues. Returns an empty SmallVector if the list is empty. 53 llvm::SmallVector<MlirValue> wrapOperands(std::optional<nb::list> operandList) { 54 llvm::SmallVector<MlirValue> mlirOperands; 55 56 if (!operandList || operandList->size() == 0) { 57 return mlirOperands; 58 } 59 60 // Note: as the list may contain other lists this may not be final size. 61 mlirOperands.reserve(operandList->size()); 62 for (const auto &&it : llvm::enumerate(*operandList)) { 63 if (it.value().is_none()) 64 continue; 65 66 PyValue *val; 67 try { 68 val = nb::cast<PyValue *>(it.value()); 69 if (!val) 70 throw nb::cast_error(); 71 mlirOperands.push_back(val->get()); 72 continue; 73 } catch (nb::cast_error &err) { 74 // Intentionally unhandled to try sequence below first. 75 (void)err; 76 } 77 78 try { 79 auto vals = nb::cast<nb::sequence>(it.value()); 80 for (nb::handle v : vals) { 81 try { 82 val = nb::cast<PyValue *>(v); 83 if (!val) 84 throw nb::cast_error(); 85 mlirOperands.push_back(val->get()); 86 } catch (nb::cast_error &err) { 87 throw nb::value_error( 88 (llvm::Twine("Operand ") + llvm::Twine(it.index()) + 89 " must be a Value or Sequence of Values (" + err.what() + ")") 90 .str() 91 .c_str()); 92 } 93 } 94 continue; 95 } catch (nb::cast_error &err) { 96 throw nb::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) + 97 " must be a Value or Sequence of Values (" + 98 err.what() + ")") 99 .str() 100 .c_str()); 101 } 102 103 throw nb::cast_error(); 104 } 105 106 return mlirOperands; 107 } 108 109 /// Takes in an optional vector of PyRegions and returns a SmallVector of 110 /// MlirRegion. Returns an empty SmallVector if the list is empty. 111 llvm::SmallVector<MlirRegion> 112 wrapRegions(std::optional<std::vector<PyRegion>> regions) { 113 llvm::SmallVector<MlirRegion> mlirRegions; 114 115 if (regions) { 116 mlirRegions.reserve(regions->size()); 117 for (PyRegion ®ion : *regions) { 118 mlirRegions.push_back(region); 119 } 120 } 121 122 return mlirRegions; 123 } 124 125 } // namespace 126 127 /// CRTP base class for Python classes representing MLIR Op interfaces. 128 /// Interface hierarchies are flat so no base class is expected here. The 129 /// derived class is expected to define the following static fields: 130 /// - `const char *pyClassName` - the name of the Python class to create; 131 /// - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID 132 /// of the interface. 133 /// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind 134 /// interface-specific methods. 135 /// 136 /// An interface class may be constructed from either an Operation/OpView object 137 /// or from a subclass of OpView. In the latter case, only the static interface 138 /// methods are available, similarly to calling ConcereteOp::staticMethod on the 139 /// C++ side. Implementations of concrete interfaces can use the `isStatic` 140 /// method to check whether the interface object was constructed from a class or 141 /// an operation/opview instance. The `getOpName` always succeeds and returns a 142 /// canonical name of the operation suitable for lookups. 143 template <typename ConcreteIface> 144 class PyConcreteOpInterface { 145 protected: 146 using ClassTy = nb::class_<ConcreteIface>; 147 using GetTypeIDFunctionTy = MlirTypeID (*)(); 148 149 public: 150 /// Constructs an interface instance from an object that is either an 151 /// operation or a subclass of OpView. In the latter case, only the static 152 /// methods of the interface are accessible to the caller. 153 PyConcreteOpInterface(nb::object object, DefaultingPyMlirContext context) 154 : obj(std::move(object)) { 155 try { 156 operation = &nb::cast<PyOperation &>(obj); 157 } catch (nb::cast_error &) { 158 // Do nothing. 159 } 160 161 try { 162 operation = &nb::cast<PyOpView &>(obj).getOperation(); 163 } catch (nb::cast_error &) { 164 // Do nothing. 165 } 166 167 if (operation != nullptr) { 168 if (!mlirOperationImplementsInterface(*operation, 169 ConcreteIface::getInterfaceID())) { 170 std::string msg = "the operation does not implement "; 171 throw nb::value_error((msg + ConcreteIface::pyClassName).c_str()); 172 } 173 174 MlirIdentifier identifier = mlirOperationGetName(*operation); 175 MlirStringRef stringRef = mlirIdentifierStr(identifier); 176 opName = std::string(stringRef.data, stringRef.length); 177 } else { 178 try { 179 opName = nb::cast<std::string>(obj.attr("OPERATION_NAME")); 180 } catch (nb::cast_error &) { 181 throw nb::type_error( 182 "Op interface does not refer to an operation or OpView class"); 183 } 184 185 if (!mlirOperationImplementsInterfaceStatic( 186 mlirStringRefCreate(opName.data(), opName.length()), 187 context.resolve().get(), ConcreteIface::getInterfaceID())) { 188 std::string msg = "the operation does not implement "; 189 throw nb::value_error((msg + ConcreteIface::pyClassName).c_str()); 190 } 191 } 192 } 193 194 /// Creates the Python bindings for this class in the given module. 195 static void bind(nb::module_ &m) { 196 nb::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName); 197 cls.def(nb::init<nb::object, DefaultingPyMlirContext>(), nb::arg("object"), 198 nb::arg("context").none() = nb::none(), constructorDoc) 199 .def_prop_ro("operation", &PyConcreteOpInterface::getOperationObject, 200 operationDoc) 201 .def_prop_ro("opview", &PyConcreteOpInterface::getOpView, opviewDoc); 202 ConcreteIface::bindDerived(cls); 203 } 204 205 /// Hook for derived classes to add class-specific bindings. 206 static void bindDerived(ClassTy &cls) {} 207 208 /// Returns `true` if this object was constructed from a subclass of OpView 209 /// rather than from an operation instance. 210 bool isStatic() { return operation == nullptr; } 211 212 /// Returns the operation instance from which this object was constructed. 213 /// Throws a type error if this object was constructed from a subclass of 214 /// OpView. 215 nb::object getOperationObject() { 216 if (operation == nullptr) { 217 throw nb::type_error("Cannot get an operation from a static interface"); 218 } 219 220 return operation->getRef().releaseObject(); 221 } 222 223 /// Returns the opview of the operation instance from which this object was 224 /// constructed. Throws a type error if this object was constructed form a 225 /// subclass of OpView. 226 nb::object getOpView() { 227 if (operation == nullptr) { 228 throw nb::type_error("Cannot get an opview from a static interface"); 229 } 230 231 return operation->createOpView(); 232 } 233 234 /// Returns the canonical name of the operation this interface is constructed 235 /// from. 236 const std::string &getOpName() { return opName; } 237 238 private: 239 PyOperation *operation = nullptr; 240 std::string opName; 241 nb::object obj; 242 }; 243 244 /// Python wrapper for InferTypeOpInterface. This interface has only static 245 /// methods. 246 class PyInferTypeOpInterface 247 : public PyConcreteOpInterface<PyInferTypeOpInterface> { 248 public: 249 using PyConcreteOpInterface<PyInferTypeOpInterface>::PyConcreteOpInterface; 250 251 constexpr static const char *pyClassName = "InferTypeOpInterface"; 252 constexpr static GetTypeIDFunctionTy getInterfaceID = 253 &mlirInferTypeOpInterfaceTypeID; 254 255 /// C-style user-data structure for type appending callback. 256 struct AppendResultsCallbackData { 257 std::vector<PyType> &inferredTypes; 258 PyMlirContext &pyMlirContext; 259 }; 260 261 /// Appends the types provided as the two first arguments to the user-data 262 /// structure (expects AppendResultsCallbackData). 263 static void appendResultsCallback(intptr_t nTypes, MlirType *types, 264 void *userData) { 265 auto *data = static_cast<AppendResultsCallbackData *>(userData); 266 data->inferredTypes.reserve(data->inferredTypes.size() + nTypes); 267 for (intptr_t i = 0; i < nTypes; ++i) { 268 data->inferredTypes.emplace_back(data->pyMlirContext.getRef(), types[i]); 269 } 270 } 271 272 /// Given the arguments required to build an operation, attempts to infer its 273 /// return types. Throws value_error on failure. 274 std::vector<PyType> 275 inferReturnTypes(std::optional<nb::list> operandList, 276 std::optional<PyAttribute> attributes, void *properties, 277 std::optional<std::vector<PyRegion>> regions, 278 DefaultingPyMlirContext context, 279 DefaultingPyLocation location) { 280 llvm::SmallVector<MlirValue> mlirOperands = 281 wrapOperands(std::move(operandList)); 282 llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(std::move(regions)); 283 284 std::vector<PyType> inferredTypes; 285 PyMlirContext &pyContext = context.resolve(); 286 AppendResultsCallbackData data{inferredTypes, pyContext}; 287 MlirStringRef opNameRef = 288 mlirStringRefCreate(getOpName().data(), getOpName().length()); 289 MlirAttribute attributeDict = 290 attributes ? attributes->get() : mlirAttributeGetNull(); 291 292 MlirLogicalResult result = mlirInferTypeOpInterfaceInferReturnTypes( 293 opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(), 294 mlirOperands.data(), attributeDict, properties, mlirRegions.size(), 295 mlirRegions.data(), &appendResultsCallback, &data); 296 297 if (mlirLogicalResultIsFailure(result)) { 298 throw nb::value_error("Failed to infer result types"); 299 } 300 301 return inferredTypes; 302 } 303 304 static void bindDerived(ClassTy &cls) { 305 cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes, 306 nb::arg("operands").none() = nb::none(), 307 nb::arg("attributes").none() = nb::none(), 308 nb::arg("properties").none() = nb::none(), 309 nb::arg("regions").none() = nb::none(), 310 nb::arg("context").none() = nb::none(), 311 nb::arg("loc").none() = nb::none(), inferReturnTypesDoc); 312 } 313 }; 314 315 /// Wrapper around an shaped type components. 316 class PyShapedTypeComponents { 317 public: 318 PyShapedTypeComponents(MlirType elementType) : elementType(elementType) {} 319 PyShapedTypeComponents(nb::list shape, MlirType elementType) 320 : shape(std::move(shape)), elementType(elementType), ranked(true) {} 321 PyShapedTypeComponents(nb::list shape, MlirType elementType, 322 MlirAttribute attribute) 323 : shape(std::move(shape)), elementType(elementType), attribute(attribute), 324 ranked(true) {} 325 PyShapedTypeComponents(PyShapedTypeComponents &) = delete; 326 PyShapedTypeComponents(PyShapedTypeComponents &&other) noexcept 327 : shape(other.shape), elementType(other.elementType), 328 attribute(other.attribute), ranked(other.ranked) {} 329 330 static void bind(nb::module_ &m) { 331 nb::class_<PyShapedTypeComponents>(m, "ShapedTypeComponents") 332 .def_prop_ro( 333 "element_type", 334 [](PyShapedTypeComponents &self) { return self.elementType; }, 335 "Returns the element type of the shaped type components.") 336 .def_static( 337 "get", 338 [](PyType &elementType) { 339 return PyShapedTypeComponents(elementType); 340 }, 341 nb::arg("element_type"), 342 "Create an shaped type components object with only the element " 343 "type.") 344 .def_static( 345 "get", 346 [](nb::list shape, PyType &elementType) { 347 return PyShapedTypeComponents(std::move(shape), elementType); 348 }, 349 nb::arg("shape"), nb::arg("element_type"), 350 "Create a ranked shaped type components object.") 351 .def_static( 352 "get", 353 [](nb::list shape, PyType &elementType, PyAttribute &attribute) { 354 return PyShapedTypeComponents(std::move(shape), elementType, 355 attribute); 356 }, 357 nb::arg("shape"), nb::arg("element_type"), nb::arg("attribute"), 358 "Create a ranked shaped type components object with attribute.") 359 .def_prop_ro( 360 "has_rank", 361 [](PyShapedTypeComponents &self) -> bool { return self.ranked; }, 362 "Returns whether the given shaped type component is ranked.") 363 .def_prop_ro( 364 "rank", 365 [](PyShapedTypeComponents &self) -> nb::object { 366 if (!self.ranked) { 367 return nb::none(); 368 } 369 return nb::int_(self.shape.size()); 370 }, 371 "Returns the rank of the given ranked shaped type components. If " 372 "the shaped type components does not have a rank, None is " 373 "returned.") 374 .def_prop_ro( 375 "shape", 376 [](PyShapedTypeComponents &self) -> nb::object { 377 if (!self.ranked) { 378 return nb::none(); 379 } 380 return nb::list(self.shape); 381 }, 382 "Returns the shape of the ranked shaped type components as a list " 383 "of integers. Returns none if the shaped type component does not " 384 "have a rank."); 385 } 386 387 nb::object getCapsule(); 388 static PyShapedTypeComponents createFromCapsule(nb::object capsule); 389 390 private: 391 nb::list shape; 392 MlirType elementType; 393 MlirAttribute attribute; 394 bool ranked{false}; 395 }; 396 397 /// Python wrapper for InferShapedTypeOpInterface. This interface has only 398 /// static methods. 399 class PyInferShapedTypeOpInterface 400 : public PyConcreteOpInterface<PyInferShapedTypeOpInterface> { 401 public: 402 using PyConcreteOpInterface< 403 PyInferShapedTypeOpInterface>::PyConcreteOpInterface; 404 405 constexpr static const char *pyClassName = "InferShapedTypeOpInterface"; 406 constexpr static GetTypeIDFunctionTy getInterfaceID = 407 &mlirInferShapedTypeOpInterfaceTypeID; 408 409 /// C-style user-data structure for type appending callback. 410 struct AppendResultsCallbackData { 411 std::vector<PyShapedTypeComponents> &inferredShapedTypeComponents; 412 }; 413 414 /// Appends the shaped type components provided as unpacked shape, element 415 /// type, attribute to the user-data. 416 static void appendResultsCallback(bool hasRank, intptr_t rank, 417 const int64_t *shape, MlirType elementType, 418 MlirAttribute attribute, void *userData) { 419 auto *data = static_cast<AppendResultsCallbackData *>(userData); 420 if (!hasRank) { 421 data->inferredShapedTypeComponents.emplace_back(elementType); 422 } else { 423 nb::list shapeList; 424 for (intptr_t i = 0; i < rank; ++i) { 425 shapeList.append(shape[i]); 426 } 427 data->inferredShapedTypeComponents.emplace_back(shapeList, elementType, 428 attribute); 429 } 430 } 431 432 /// Given the arguments required to build an operation, attempts to infer the 433 /// shaped type components. Throws value_error on failure. 434 std::vector<PyShapedTypeComponents> inferReturnTypeComponents( 435 std::optional<nb::list> operandList, 436 std::optional<PyAttribute> attributes, void *properties, 437 std::optional<std::vector<PyRegion>> regions, 438 DefaultingPyMlirContext context, DefaultingPyLocation location) { 439 llvm::SmallVector<MlirValue> mlirOperands = 440 wrapOperands(std::move(operandList)); 441 llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(std::move(regions)); 442 443 std::vector<PyShapedTypeComponents> inferredShapedTypeComponents; 444 PyMlirContext &pyContext = context.resolve(); 445 AppendResultsCallbackData data{inferredShapedTypeComponents}; 446 MlirStringRef opNameRef = 447 mlirStringRefCreate(getOpName().data(), getOpName().length()); 448 MlirAttribute attributeDict = 449 attributes ? attributes->get() : mlirAttributeGetNull(); 450 451 MlirLogicalResult result = mlirInferShapedTypeOpInterfaceInferReturnTypes( 452 opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(), 453 mlirOperands.data(), attributeDict, properties, mlirRegions.size(), 454 mlirRegions.data(), &appendResultsCallback, &data); 455 456 if (mlirLogicalResultIsFailure(result)) { 457 throw nb::value_error("Failed to infer result shape type components"); 458 } 459 460 return inferredShapedTypeComponents; 461 } 462 463 static void bindDerived(ClassTy &cls) { 464 cls.def("inferReturnTypeComponents", 465 &PyInferShapedTypeOpInterface::inferReturnTypeComponents, 466 nb::arg("operands").none() = nb::none(), 467 nb::arg("attributes").none() = nb::none(), 468 nb::arg("regions").none() = nb::none(), 469 nb::arg("properties").none() = nb::none(), 470 nb::arg("context").none() = nb::none(), 471 nb::arg("loc").none() = nb::none(), inferReturnTypeComponentsDoc); 472 } 473 }; 474 475 void populateIRInterfaces(nb::module_ &m) { 476 PyInferTypeOpInterface::bind(m); 477 PyShapedTypeComponents::bind(m); 478 PyInferShapedTypeOpInterface::bind(m); 479 } 480 481 } // namespace python 482 } // namespace mlir 483