1 //===- IRAffine.cpp - Exports 'ir' module affine related bindings ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <cstddef> 10 #include <cstdint> 11 #include <stdexcept> 12 #include <string> 13 #include <utility> 14 #include <vector> 15 16 #include "IRModule.h" 17 #include "NanobindUtils.h" 18 #include "mlir-c/AffineExpr.h" 19 #include "mlir-c/AffineMap.h" 20 #include "mlir-c/IntegerSet.h" 21 #include "mlir/Bindings/Python/Nanobind.h" 22 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. 23 #include "mlir/Support/LLVM.h" 24 #include "llvm/ADT/Hashing.h" 25 #include "llvm/ADT/SmallVector.h" 26 #include "llvm/ADT/StringRef.h" 27 #include "llvm/ADT/Twine.h" 28 29 namespace nb = nanobind; 30 using namespace mlir; 31 using namespace mlir::python; 32 33 using llvm::SmallVector; 34 using llvm::StringRef; 35 using llvm::Twine; 36 37 static const char kDumpDocstring[] = 38 R"(Dumps a debug representation of the object to stderr.)"; 39 40 /// Attempts to populate `result` with the content of `list` casted to the 41 /// appropriate type (Python and C types are provided as template arguments). 42 /// Throws errors in case of failure, using "action" to describe what the caller 43 /// was attempting to do. 44 template <typename PyType, typename CType> 45 static void pyListToVector(const nb::list &list, 46 llvm::SmallVectorImpl<CType> &result, 47 StringRef action) { 48 result.reserve(nb::len(list)); 49 for (nb::handle item : list) { 50 try { 51 result.push_back(nb::cast<PyType>(item)); 52 } catch (nb::cast_error &err) { 53 std::string msg = (llvm::Twine("Invalid expression when ") + action + 54 " (" + err.what() + ")") 55 .str(); 56 throw std::runtime_error(msg.c_str()); 57 } catch (std::runtime_error &err) { 58 std::string msg = (llvm::Twine("Invalid expression (None?) when ") + 59 action + " (" + err.what() + ")") 60 .str(); 61 throw std::runtime_error(msg.c_str()); 62 } 63 } 64 } 65 66 template <typename PermutationTy> 67 static bool isPermutation(std::vector<PermutationTy> permutation) { 68 llvm::SmallVector<bool, 8> seen(permutation.size(), false); 69 for (auto val : permutation) { 70 if (val < permutation.size()) { 71 if (seen[val]) 72 return false; 73 seen[val] = true; 74 continue; 75 } 76 return false; 77 } 78 return true; 79 } 80 81 namespace { 82 83 /// CRTP base class for Python MLIR affine expressions that subclass AffineExpr 84 /// and should be castable from it. Intermediate hierarchy classes can be 85 /// modeled by specifying BaseTy. 86 template <typename DerivedTy, typename BaseTy = PyAffineExpr> 87 class PyConcreteAffineExpr : public BaseTy { 88 public: 89 // Derived classes must define statics for: 90 // IsAFunctionTy isaFunction 91 // const char *pyClassName 92 // and redefine bindDerived. 93 using ClassTy = nb::class_<DerivedTy, BaseTy>; 94 using IsAFunctionTy = bool (*)(MlirAffineExpr); 95 96 PyConcreteAffineExpr() = default; 97 PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr) 98 : BaseTy(std::move(contextRef), affineExpr) {} 99 PyConcreteAffineExpr(PyAffineExpr &orig) 100 : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {} 101 102 static MlirAffineExpr castFrom(PyAffineExpr &orig) { 103 if (!DerivedTy::isaFunction(orig)) { 104 auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig))); 105 throw nb::value_error((Twine("Cannot cast affine expression to ") + 106 DerivedTy::pyClassName + " (from " + origRepr + 107 ")") 108 .str() 109 .c_str()); 110 } 111 return orig; 112 } 113 114 static void bind(nb::module_ &m) { 115 auto cls = ClassTy(m, DerivedTy::pyClassName); 116 cls.def(nb::init<PyAffineExpr &>(), nb::arg("expr")); 117 cls.def_static( 118 "isinstance", 119 [](PyAffineExpr &otherAffineExpr) -> bool { 120 return DerivedTy::isaFunction(otherAffineExpr); 121 }, 122 nb::arg("other")); 123 DerivedTy::bindDerived(cls); 124 } 125 126 /// Implemented by derived classes to add methods to the Python subclass. 127 static void bindDerived(ClassTy &m) {} 128 }; 129 130 class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> { 131 public: 132 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant; 133 static constexpr const char *pyClassName = "AffineConstantExpr"; 134 using PyConcreteAffineExpr::PyConcreteAffineExpr; 135 136 static PyAffineConstantExpr get(intptr_t value, 137 DefaultingPyMlirContext context) { 138 MlirAffineExpr affineExpr = 139 mlirAffineConstantExprGet(context->get(), static_cast<int64_t>(value)); 140 return PyAffineConstantExpr(context->getRef(), affineExpr); 141 } 142 143 static void bindDerived(ClassTy &c) { 144 c.def_static("get", &PyAffineConstantExpr::get, nb::arg("value"), 145 nb::arg("context").none() = nb::none()); 146 c.def_prop_ro("value", [](PyAffineConstantExpr &self) { 147 return mlirAffineConstantExprGetValue(self); 148 }); 149 } 150 }; 151 152 class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> { 153 public: 154 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim; 155 static constexpr const char *pyClassName = "AffineDimExpr"; 156 using PyConcreteAffineExpr::PyConcreteAffineExpr; 157 158 static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) { 159 MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos); 160 return PyAffineDimExpr(context->getRef(), affineExpr); 161 } 162 163 static void bindDerived(ClassTy &c) { 164 c.def_static("get", &PyAffineDimExpr::get, nb::arg("position"), 165 nb::arg("context").none() = nb::none()); 166 c.def_prop_ro("position", [](PyAffineDimExpr &self) { 167 return mlirAffineDimExprGetPosition(self); 168 }); 169 } 170 }; 171 172 class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> { 173 public: 174 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol; 175 static constexpr const char *pyClassName = "AffineSymbolExpr"; 176 using PyConcreteAffineExpr::PyConcreteAffineExpr; 177 178 static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) { 179 MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos); 180 return PyAffineSymbolExpr(context->getRef(), affineExpr); 181 } 182 183 static void bindDerived(ClassTy &c) { 184 c.def_static("get", &PyAffineSymbolExpr::get, nb::arg("position"), 185 nb::arg("context").none() = nb::none()); 186 c.def_prop_ro("position", [](PyAffineSymbolExpr &self) { 187 return mlirAffineSymbolExprGetPosition(self); 188 }); 189 } 190 }; 191 192 class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> { 193 public: 194 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary; 195 static constexpr const char *pyClassName = "AffineBinaryExpr"; 196 using PyConcreteAffineExpr::PyConcreteAffineExpr; 197 198 PyAffineExpr lhs() { 199 MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get()); 200 return PyAffineExpr(getContext(), lhsExpr); 201 } 202 203 PyAffineExpr rhs() { 204 MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get()); 205 return PyAffineExpr(getContext(), rhsExpr); 206 } 207 208 static void bindDerived(ClassTy &c) { 209 c.def_prop_ro("lhs", &PyAffineBinaryExpr::lhs); 210 c.def_prop_ro("rhs", &PyAffineBinaryExpr::rhs); 211 } 212 }; 213 214 class PyAffineAddExpr 215 : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> { 216 public: 217 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd; 218 static constexpr const char *pyClassName = "AffineAddExpr"; 219 using PyConcreteAffineExpr::PyConcreteAffineExpr; 220 221 static PyAffineAddExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) { 222 MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs); 223 return PyAffineAddExpr(lhs.getContext(), expr); 224 } 225 226 static PyAffineAddExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { 227 MlirAffineExpr expr = mlirAffineAddExprGet( 228 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); 229 return PyAffineAddExpr(lhs.getContext(), expr); 230 } 231 232 static PyAffineAddExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { 233 MlirAffineExpr expr = mlirAffineAddExprGet( 234 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); 235 return PyAffineAddExpr(rhs.getContext(), expr); 236 } 237 238 static void bindDerived(ClassTy &c) { 239 c.def_static("get", &PyAffineAddExpr::get); 240 } 241 }; 242 243 class PyAffineMulExpr 244 : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> { 245 public: 246 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul; 247 static constexpr const char *pyClassName = "AffineMulExpr"; 248 using PyConcreteAffineExpr::PyConcreteAffineExpr; 249 250 static PyAffineMulExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) { 251 MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs); 252 return PyAffineMulExpr(lhs.getContext(), expr); 253 } 254 255 static PyAffineMulExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { 256 MlirAffineExpr expr = mlirAffineMulExprGet( 257 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); 258 return PyAffineMulExpr(lhs.getContext(), expr); 259 } 260 261 static PyAffineMulExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { 262 MlirAffineExpr expr = mlirAffineMulExprGet( 263 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); 264 return PyAffineMulExpr(rhs.getContext(), expr); 265 } 266 267 static void bindDerived(ClassTy &c) { 268 c.def_static("get", &PyAffineMulExpr::get); 269 } 270 }; 271 272 class PyAffineModExpr 273 : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> { 274 public: 275 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod; 276 static constexpr const char *pyClassName = "AffineModExpr"; 277 using PyConcreteAffineExpr::PyConcreteAffineExpr; 278 279 static PyAffineModExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) { 280 MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs); 281 return PyAffineModExpr(lhs.getContext(), expr); 282 } 283 284 static PyAffineModExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { 285 MlirAffineExpr expr = mlirAffineModExprGet( 286 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); 287 return PyAffineModExpr(lhs.getContext(), expr); 288 } 289 290 static PyAffineModExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { 291 MlirAffineExpr expr = mlirAffineModExprGet( 292 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); 293 return PyAffineModExpr(rhs.getContext(), expr); 294 } 295 296 static void bindDerived(ClassTy &c) { 297 c.def_static("get", &PyAffineModExpr::get); 298 } 299 }; 300 301 class PyAffineFloorDivExpr 302 : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> { 303 public: 304 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv; 305 static constexpr const char *pyClassName = "AffineFloorDivExpr"; 306 using PyConcreteAffineExpr::PyConcreteAffineExpr; 307 308 static PyAffineFloorDivExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) { 309 MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs); 310 return PyAffineFloorDivExpr(lhs.getContext(), expr); 311 } 312 313 static PyAffineFloorDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { 314 MlirAffineExpr expr = mlirAffineFloorDivExprGet( 315 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); 316 return PyAffineFloorDivExpr(lhs.getContext(), expr); 317 } 318 319 static PyAffineFloorDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { 320 MlirAffineExpr expr = mlirAffineFloorDivExprGet( 321 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); 322 return PyAffineFloorDivExpr(rhs.getContext(), expr); 323 } 324 325 static void bindDerived(ClassTy &c) { 326 c.def_static("get", &PyAffineFloorDivExpr::get); 327 } 328 }; 329 330 class PyAffineCeilDivExpr 331 : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> { 332 public: 333 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv; 334 static constexpr const char *pyClassName = "AffineCeilDivExpr"; 335 using PyConcreteAffineExpr::PyConcreteAffineExpr; 336 337 static PyAffineCeilDivExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) { 338 MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs); 339 return PyAffineCeilDivExpr(lhs.getContext(), expr); 340 } 341 342 static PyAffineCeilDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { 343 MlirAffineExpr expr = mlirAffineCeilDivExprGet( 344 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); 345 return PyAffineCeilDivExpr(lhs.getContext(), expr); 346 } 347 348 static PyAffineCeilDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { 349 MlirAffineExpr expr = mlirAffineCeilDivExprGet( 350 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); 351 return PyAffineCeilDivExpr(rhs.getContext(), expr); 352 } 353 354 static void bindDerived(ClassTy &c) { 355 c.def_static("get", &PyAffineCeilDivExpr::get); 356 } 357 }; 358 359 } // namespace 360 361 bool PyAffineExpr::operator==(const PyAffineExpr &other) const { 362 return mlirAffineExprEqual(affineExpr, other.affineExpr); 363 } 364 365 nb::object PyAffineExpr::getCapsule() { 366 return nb::steal<nb::object>(mlirPythonAffineExprToCapsule(*this)); 367 } 368 369 PyAffineExpr PyAffineExpr::createFromCapsule(nb::object capsule) { 370 MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr()); 371 if (mlirAffineExprIsNull(rawAffineExpr)) 372 throw nb::python_error(); 373 return PyAffineExpr( 374 PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)), 375 rawAffineExpr); 376 } 377 378 //------------------------------------------------------------------------------ 379 // PyAffineMap and utilities. 380 //------------------------------------------------------------------------------ 381 namespace { 382 383 /// A list of expressions contained in an affine map. Internally these are 384 /// stored as a consecutive array leading to inexpensive random access. Both 385 /// the map and the expression are owned by the context so we need not bother 386 /// with lifetime extension. 387 class PyAffineMapExprList 388 : public Sliceable<PyAffineMapExprList, PyAffineExpr> { 389 public: 390 static constexpr const char *pyClassName = "AffineExprList"; 391 392 PyAffineMapExprList(const PyAffineMap &map, intptr_t startIndex = 0, 393 intptr_t length = -1, intptr_t step = 1) 394 : Sliceable(startIndex, 395 length == -1 ? mlirAffineMapGetNumResults(map) : length, 396 step), 397 affineMap(map) {} 398 399 private: 400 /// Give the parent CRTP class access to hook implementations below. 401 friend class Sliceable<PyAffineMapExprList, PyAffineExpr>; 402 403 intptr_t getRawNumElements() { return mlirAffineMapGetNumResults(affineMap); } 404 405 PyAffineExpr getRawElement(intptr_t pos) { 406 return PyAffineExpr(affineMap.getContext(), 407 mlirAffineMapGetResult(affineMap, pos)); 408 } 409 410 PyAffineMapExprList slice(intptr_t startIndex, intptr_t length, 411 intptr_t step) { 412 return PyAffineMapExprList(affineMap, startIndex, length, step); 413 } 414 415 PyAffineMap affineMap; 416 }; 417 } // namespace 418 419 bool PyAffineMap::operator==(const PyAffineMap &other) const { 420 return mlirAffineMapEqual(affineMap, other.affineMap); 421 } 422 423 nb::object PyAffineMap::getCapsule() { 424 return nb::steal<nb::object>(mlirPythonAffineMapToCapsule(*this)); 425 } 426 427 PyAffineMap PyAffineMap::createFromCapsule(nb::object capsule) { 428 MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr()); 429 if (mlirAffineMapIsNull(rawAffineMap)) 430 throw nb::python_error(); 431 return PyAffineMap( 432 PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)), 433 rawAffineMap); 434 } 435 436 //------------------------------------------------------------------------------ 437 // PyIntegerSet and utilities. 438 //------------------------------------------------------------------------------ 439 namespace { 440 441 class PyIntegerSetConstraint { 442 public: 443 PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) 444 : set(std::move(set)), pos(pos) {} 445 446 PyAffineExpr getExpr() { 447 return PyAffineExpr(set.getContext(), 448 mlirIntegerSetGetConstraint(set, pos)); 449 } 450 451 bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); } 452 453 static void bind(nb::module_ &m) { 454 nb::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint") 455 .def_prop_ro("expr", &PyIntegerSetConstraint::getExpr) 456 .def_prop_ro("is_eq", &PyIntegerSetConstraint::isEq); 457 } 458 459 private: 460 PyIntegerSet set; 461 intptr_t pos; 462 }; 463 464 class PyIntegerSetConstraintList 465 : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> { 466 public: 467 static constexpr const char *pyClassName = "IntegerSetConstraintList"; 468 469 PyIntegerSetConstraintList(const PyIntegerSet &set, intptr_t startIndex = 0, 470 intptr_t length = -1, intptr_t step = 1) 471 : Sliceable(startIndex, 472 length == -1 ? mlirIntegerSetGetNumConstraints(set) : length, 473 step), 474 set(set) {} 475 476 private: 477 /// Give the parent CRTP class access to hook implementations below. 478 friend class Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint>; 479 480 intptr_t getRawNumElements() { return mlirIntegerSetGetNumConstraints(set); } 481 482 PyIntegerSetConstraint getRawElement(intptr_t pos) { 483 return PyIntegerSetConstraint(set, pos); 484 } 485 486 PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length, 487 intptr_t step) { 488 return PyIntegerSetConstraintList(set, startIndex, length, step); 489 } 490 491 PyIntegerSet set; 492 }; 493 } // namespace 494 495 bool PyIntegerSet::operator==(const PyIntegerSet &other) const { 496 return mlirIntegerSetEqual(integerSet, other.integerSet); 497 } 498 499 nb::object PyIntegerSet::getCapsule() { 500 return nb::steal<nb::object>(mlirPythonIntegerSetToCapsule(*this)); 501 } 502 503 PyIntegerSet PyIntegerSet::createFromCapsule(nb::object capsule) { 504 MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr()); 505 if (mlirIntegerSetIsNull(rawIntegerSet)) 506 throw nb::python_error(); 507 return PyIntegerSet( 508 PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)), 509 rawIntegerSet); 510 } 511 512 void mlir::python::populateIRAffine(nb::module_ &m) { 513 //---------------------------------------------------------------------------- 514 // Mapping of PyAffineExpr and derived classes. 515 //---------------------------------------------------------------------------- 516 nb::class_<PyAffineExpr>(m, "AffineExpr") 517 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAffineExpr::getCapsule) 518 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule) 519 .def("__add__", &PyAffineAddExpr::get) 520 .def("__add__", &PyAffineAddExpr::getRHSConstant) 521 .def("__radd__", &PyAffineAddExpr::getRHSConstant) 522 .def("__mul__", &PyAffineMulExpr::get) 523 .def("__mul__", &PyAffineMulExpr::getRHSConstant) 524 .def("__rmul__", &PyAffineMulExpr::getRHSConstant) 525 .def("__mod__", &PyAffineModExpr::get) 526 .def("__mod__", &PyAffineModExpr::getRHSConstant) 527 .def("__rmod__", 528 [](PyAffineExpr &self, intptr_t other) { 529 return PyAffineModExpr::get( 530 PyAffineConstantExpr::get(other, *self.getContext().get()), 531 self); 532 }) 533 .def("__sub__", 534 [](PyAffineExpr &self, PyAffineExpr &other) { 535 auto negOne = 536 PyAffineConstantExpr::get(-1, *self.getContext().get()); 537 return PyAffineAddExpr::get(self, 538 PyAffineMulExpr::get(negOne, other)); 539 }) 540 .def("__sub__", 541 [](PyAffineExpr &self, intptr_t other) { 542 return PyAffineAddExpr::get( 543 self, 544 PyAffineConstantExpr::get(-other, *self.getContext().get())); 545 }) 546 .def("__rsub__", 547 [](PyAffineExpr &self, intptr_t other) { 548 return PyAffineAddExpr::getLHSConstant( 549 other, PyAffineMulExpr::getLHSConstant(-1, self)); 550 }) 551 .def("__eq__", [](PyAffineExpr &self, 552 PyAffineExpr &other) { return self == other; }) 553 .def("__eq__", 554 [](PyAffineExpr &self, nb::object &other) { return false; }) 555 .def("__str__", 556 [](PyAffineExpr &self) { 557 PyPrintAccumulator printAccum; 558 mlirAffineExprPrint(self, printAccum.getCallback(), 559 printAccum.getUserData()); 560 return printAccum.join(); 561 }) 562 .def("__repr__", 563 [](PyAffineExpr &self) { 564 PyPrintAccumulator printAccum; 565 printAccum.parts.append("AffineExpr("); 566 mlirAffineExprPrint(self, printAccum.getCallback(), 567 printAccum.getUserData()); 568 printAccum.parts.append(")"); 569 return printAccum.join(); 570 }) 571 .def("__hash__", 572 [](PyAffineExpr &self) { 573 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 574 }) 575 .def_prop_ro( 576 "context", 577 [](PyAffineExpr &self) { return self.getContext().getObject(); }) 578 .def("compose", 579 [](PyAffineExpr &self, PyAffineMap &other) { 580 return PyAffineExpr(self.getContext(), 581 mlirAffineExprCompose(self, other)); 582 }) 583 .def_static( 584 "get_add", &PyAffineAddExpr::get, 585 "Gets an affine expression containing a sum of two expressions.") 586 .def_static("get_add", &PyAffineAddExpr::getLHSConstant, 587 "Gets an affine expression containing a sum of a constant " 588 "and another expression.") 589 .def_static("get_add", &PyAffineAddExpr::getRHSConstant, 590 "Gets an affine expression containing a sum of an expression " 591 "and a constant.") 592 .def_static( 593 "get_mul", &PyAffineMulExpr::get, 594 "Gets an affine expression containing a product of two expressions.") 595 .def_static("get_mul", &PyAffineMulExpr::getLHSConstant, 596 "Gets an affine expression containing a product of a " 597 "constant and another expression.") 598 .def_static("get_mul", &PyAffineMulExpr::getRHSConstant, 599 "Gets an affine expression containing a product of an " 600 "expression and a constant.") 601 .def_static("get_mod", &PyAffineModExpr::get, 602 "Gets an affine expression containing the modulo of dividing " 603 "one expression by another.") 604 .def_static("get_mod", &PyAffineModExpr::getLHSConstant, 605 "Gets a semi-affine expression containing the modulo of " 606 "dividing a constant by an expression.") 607 .def_static("get_mod", &PyAffineModExpr::getRHSConstant, 608 "Gets an affine expression containing the module of dividing" 609 "an expression by a constant.") 610 .def_static("get_floor_div", &PyAffineFloorDivExpr::get, 611 "Gets an affine expression containing the rounded-down " 612 "result of dividing one expression by another.") 613 .def_static("get_floor_div", &PyAffineFloorDivExpr::getLHSConstant, 614 "Gets a semi-affine expression containing the rounded-down " 615 "result of dividing a constant by an expression.") 616 .def_static("get_floor_div", &PyAffineFloorDivExpr::getRHSConstant, 617 "Gets an affine expression containing the rounded-down " 618 "result of dividing an expression by a constant.") 619 .def_static("get_ceil_div", &PyAffineCeilDivExpr::get, 620 "Gets an affine expression containing the rounded-up result " 621 "of dividing one expression by another.") 622 .def_static("get_ceil_div", &PyAffineCeilDivExpr::getLHSConstant, 623 "Gets a semi-affine expression containing the rounded-up " 624 "result of dividing a constant by an expression.") 625 .def_static("get_ceil_div", &PyAffineCeilDivExpr::getRHSConstant, 626 "Gets an affine expression containing the rounded-up result " 627 "of dividing an expression by a constant.") 628 .def_static("get_constant", &PyAffineConstantExpr::get, nb::arg("value"), 629 nb::arg("context").none() = nb::none(), 630 "Gets a constant affine expression with the given value.") 631 .def_static( 632 "get_dim", &PyAffineDimExpr::get, nb::arg("position"), 633 nb::arg("context").none() = nb::none(), 634 "Gets an affine expression of a dimension at the given position.") 635 .def_static( 636 "get_symbol", &PyAffineSymbolExpr::get, nb::arg("position"), 637 nb::arg("context").none() = nb::none(), 638 "Gets an affine expression of a symbol at the given position.") 639 .def( 640 "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); }, 641 kDumpDocstring); 642 PyAffineConstantExpr::bind(m); 643 PyAffineDimExpr::bind(m); 644 PyAffineSymbolExpr::bind(m); 645 PyAffineBinaryExpr::bind(m); 646 PyAffineAddExpr::bind(m); 647 PyAffineMulExpr::bind(m); 648 PyAffineModExpr::bind(m); 649 PyAffineFloorDivExpr::bind(m); 650 PyAffineCeilDivExpr::bind(m); 651 652 //---------------------------------------------------------------------------- 653 // Mapping of PyAffineMap. 654 //---------------------------------------------------------------------------- 655 nb::class_<PyAffineMap>(m, "AffineMap") 656 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAffineMap::getCapsule) 657 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule) 658 .def("__eq__", 659 [](PyAffineMap &self, PyAffineMap &other) { return self == other; }) 660 .def("__eq__", [](PyAffineMap &self, nb::object &other) { return false; }) 661 .def("__str__", 662 [](PyAffineMap &self) { 663 PyPrintAccumulator printAccum; 664 mlirAffineMapPrint(self, printAccum.getCallback(), 665 printAccum.getUserData()); 666 return printAccum.join(); 667 }) 668 .def("__repr__", 669 [](PyAffineMap &self) { 670 PyPrintAccumulator printAccum; 671 printAccum.parts.append("AffineMap("); 672 mlirAffineMapPrint(self, printAccum.getCallback(), 673 printAccum.getUserData()); 674 printAccum.parts.append(")"); 675 return printAccum.join(); 676 }) 677 .def("__hash__", 678 [](PyAffineMap &self) { 679 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 680 }) 681 .def_static("compress_unused_symbols", 682 [](nb::list affineMaps, DefaultingPyMlirContext context) { 683 SmallVector<MlirAffineMap> maps; 684 pyListToVector<PyAffineMap, MlirAffineMap>( 685 affineMaps, maps, "attempting to create an AffineMap"); 686 std::vector<MlirAffineMap> compressed(affineMaps.size()); 687 auto populate = [](void *result, intptr_t idx, 688 MlirAffineMap m) { 689 static_cast<MlirAffineMap *>(result)[idx] = (m); 690 }; 691 mlirAffineMapCompressUnusedSymbols( 692 maps.data(), maps.size(), compressed.data(), populate); 693 std::vector<PyAffineMap> res; 694 res.reserve(compressed.size()); 695 for (auto m : compressed) 696 res.emplace_back(context->getRef(), m); 697 return res; 698 }) 699 .def_prop_ro( 700 "context", 701 [](PyAffineMap &self) { return self.getContext().getObject(); }, 702 "Context that owns the Affine Map") 703 .def( 704 "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); }, 705 kDumpDocstring) 706 .def_static( 707 "get", 708 [](intptr_t dimCount, intptr_t symbolCount, nb::list exprs, 709 DefaultingPyMlirContext context) { 710 SmallVector<MlirAffineExpr> affineExprs; 711 pyListToVector<PyAffineExpr, MlirAffineExpr>( 712 exprs, affineExprs, "attempting to create an AffineMap"); 713 MlirAffineMap map = 714 mlirAffineMapGet(context->get(), dimCount, symbolCount, 715 affineExprs.size(), affineExprs.data()); 716 return PyAffineMap(context->getRef(), map); 717 }, 718 nb::arg("dim_count"), nb::arg("symbol_count"), nb::arg("exprs"), 719 nb::arg("context").none() = nb::none(), 720 "Gets a map with the given expressions as results.") 721 .def_static( 722 "get_constant", 723 [](intptr_t value, DefaultingPyMlirContext context) { 724 MlirAffineMap affineMap = 725 mlirAffineMapConstantGet(context->get(), value); 726 return PyAffineMap(context->getRef(), affineMap); 727 }, 728 nb::arg("value"), nb::arg("context").none() = nb::none(), 729 "Gets an affine map with a single constant result") 730 .def_static( 731 "get_empty", 732 [](DefaultingPyMlirContext context) { 733 MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get()); 734 return PyAffineMap(context->getRef(), affineMap); 735 }, 736 nb::arg("context").none() = nb::none(), "Gets an empty affine map.") 737 .def_static( 738 "get_identity", 739 [](intptr_t nDims, DefaultingPyMlirContext context) { 740 MlirAffineMap affineMap = 741 mlirAffineMapMultiDimIdentityGet(context->get(), nDims); 742 return PyAffineMap(context->getRef(), affineMap); 743 }, 744 nb::arg("n_dims"), nb::arg("context").none() = nb::none(), 745 "Gets an identity map with the given number of dimensions.") 746 .def_static( 747 "get_minor_identity", 748 [](intptr_t nDims, intptr_t nResults, 749 DefaultingPyMlirContext context) { 750 MlirAffineMap affineMap = 751 mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults); 752 return PyAffineMap(context->getRef(), affineMap); 753 }, 754 nb::arg("n_dims"), nb::arg("n_results"), 755 nb::arg("context").none() = nb::none(), 756 "Gets a minor identity map with the given number of dimensions and " 757 "results.") 758 .def_static( 759 "get_permutation", 760 [](std::vector<unsigned> permutation, 761 DefaultingPyMlirContext context) { 762 if (!isPermutation(permutation)) 763 throw std::runtime_error("Invalid permutation when attempting to " 764 "create an AffineMap"); 765 MlirAffineMap affineMap = mlirAffineMapPermutationGet( 766 context->get(), permutation.size(), permutation.data()); 767 return PyAffineMap(context->getRef(), affineMap); 768 }, 769 nb::arg("permutation"), nb::arg("context").none() = nb::none(), 770 "Gets an affine map that permutes its inputs.") 771 .def( 772 "get_submap", 773 [](PyAffineMap &self, std::vector<intptr_t> &resultPos) { 774 intptr_t numResults = mlirAffineMapGetNumResults(self); 775 for (intptr_t pos : resultPos) { 776 if (pos < 0 || pos >= numResults) 777 throw nb::value_error("result position out of bounds"); 778 } 779 MlirAffineMap affineMap = mlirAffineMapGetSubMap( 780 self, resultPos.size(), resultPos.data()); 781 return PyAffineMap(self.getContext(), affineMap); 782 }, 783 nb::arg("result_positions")) 784 .def( 785 "get_major_submap", 786 [](PyAffineMap &self, intptr_t nResults) { 787 if (nResults >= mlirAffineMapGetNumResults(self)) 788 throw nb::value_error("number of results out of bounds"); 789 MlirAffineMap affineMap = 790 mlirAffineMapGetMajorSubMap(self, nResults); 791 return PyAffineMap(self.getContext(), affineMap); 792 }, 793 nb::arg("n_results")) 794 .def( 795 "get_minor_submap", 796 [](PyAffineMap &self, intptr_t nResults) { 797 if (nResults >= mlirAffineMapGetNumResults(self)) 798 throw nb::value_error("number of results out of bounds"); 799 MlirAffineMap affineMap = 800 mlirAffineMapGetMinorSubMap(self, nResults); 801 return PyAffineMap(self.getContext(), affineMap); 802 }, 803 nb::arg("n_results")) 804 .def( 805 "replace", 806 [](PyAffineMap &self, PyAffineExpr &expression, 807 PyAffineExpr &replacement, intptr_t numResultDims, 808 intptr_t numResultSyms) { 809 MlirAffineMap affineMap = mlirAffineMapReplace( 810 self, expression, replacement, numResultDims, numResultSyms); 811 return PyAffineMap(self.getContext(), affineMap); 812 }, 813 nb::arg("expr"), nb::arg("replacement"), nb::arg("n_result_dims"), 814 nb::arg("n_result_syms")) 815 .def_prop_ro( 816 "is_permutation", 817 [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); }) 818 .def_prop_ro("is_projected_permutation", 819 [](PyAffineMap &self) { 820 return mlirAffineMapIsProjectedPermutation(self); 821 }) 822 .def_prop_ro( 823 "n_dims", 824 [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); }) 825 .def_prop_ro( 826 "n_inputs", 827 [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); }) 828 .def_prop_ro( 829 "n_symbols", 830 [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); }) 831 .def_prop_ro("results", 832 [](PyAffineMap &self) { return PyAffineMapExprList(self); }); 833 PyAffineMapExprList::bind(m); 834 835 //---------------------------------------------------------------------------- 836 // Mapping of PyIntegerSet. 837 //---------------------------------------------------------------------------- 838 nb::class_<PyIntegerSet>(m, "IntegerSet") 839 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyIntegerSet::getCapsule) 840 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule) 841 .def("__eq__", [](PyIntegerSet &self, 842 PyIntegerSet &other) { return self == other; }) 843 .def("__eq__", [](PyIntegerSet &self, nb::object other) { return false; }) 844 .def("__str__", 845 [](PyIntegerSet &self) { 846 PyPrintAccumulator printAccum; 847 mlirIntegerSetPrint(self, printAccum.getCallback(), 848 printAccum.getUserData()); 849 return printAccum.join(); 850 }) 851 .def("__repr__", 852 [](PyIntegerSet &self) { 853 PyPrintAccumulator printAccum; 854 printAccum.parts.append("IntegerSet("); 855 mlirIntegerSetPrint(self, printAccum.getCallback(), 856 printAccum.getUserData()); 857 printAccum.parts.append(")"); 858 return printAccum.join(); 859 }) 860 .def("__hash__", 861 [](PyIntegerSet &self) { 862 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 863 }) 864 .def_prop_ro( 865 "context", 866 [](PyIntegerSet &self) { return self.getContext().getObject(); }) 867 .def( 868 "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); }, 869 kDumpDocstring) 870 .def_static( 871 "get", 872 [](intptr_t numDims, intptr_t numSymbols, nb::list exprs, 873 std::vector<bool> eqFlags, DefaultingPyMlirContext context) { 874 if (exprs.size() != eqFlags.size()) 875 throw nb::value_error( 876 "Expected the number of constraints to match " 877 "that of equality flags"); 878 if (exprs.size() == 0) 879 throw nb::value_error("Expected non-empty list of constraints"); 880 881 // Copy over to a SmallVector because std::vector has a 882 // specialization for booleans that packs data and does not 883 // expose a `bool *`. 884 SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end()); 885 886 SmallVector<MlirAffineExpr> affineExprs; 887 pyListToVector<PyAffineExpr>(exprs, affineExprs, 888 "attempting to create an IntegerSet"); 889 MlirIntegerSet set = mlirIntegerSetGet( 890 context->get(), numDims, numSymbols, exprs.size(), 891 affineExprs.data(), flags.data()); 892 return PyIntegerSet(context->getRef(), set); 893 }, 894 nb::arg("num_dims"), nb::arg("num_symbols"), nb::arg("exprs"), 895 nb::arg("eq_flags"), nb::arg("context").none() = nb::none()) 896 .def_static( 897 "get_empty", 898 [](intptr_t numDims, intptr_t numSymbols, 899 DefaultingPyMlirContext context) { 900 MlirIntegerSet set = 901 mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols); 902 return PyIntegerSet(context->getRef(), set); 903 }, 904 nb::arg("num_dims"), nb::arg("num_symbols"), 905 nb::arg("context").none() = nb::none()) 906 .def( 907 "get_replaced", 908 [](PyIntegerSet &self, nb::list dimExprs, nb::list symbolExprs, 909 intptr_t numResultDims, intptr_t numResultSymbols) { 910 if (static_cast<intptr_t>(dimExprs.size()) != 911 mlirIntegerSetGetNumDims(self)) 912 throw nb::value_error( 913 "Expected the number of dimension replacement expressions " 914 "to match that of dimensions"); 915 if (static_cast<intptr_t>(symbolExprs.size()) != 916 mlirIntegerSetGetNumSymbols(self)) 917 throw nb::value_error( 918 "Expected the number of symbol replacement expressions " 919 "to match that of symbols"); 920 921 SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs; 922 pyListToVector<PyAffineExpr>( 923 dimExprs, dimAffineExprs, 924 "attempting to create an IntegerSet by replacing dimensions"); 925 pyListToVector<PyAffineExpr>( 926 symbolExprs, symbolAffineExprs, 927 "attempting to create an IntegerSet by replacing symbols"); 928 MlirIntegerSet set = mlirIntegerSetReplaceGet( 929 self, dimAffineExprs.data(), symbolAffineExprs.data(), 930 numResultDims, numResultSymbols); 931 return PyIntegerSet(self.getContext(), set); 932 }, 933 nb::arg("dim_exprs"), nb::arg("symbol_exprs"), 934 nb::arg("num_result_dims"), nb::arg("num_result_symbols")) 935 .def_prop_ro("is_canonical_empty", 936 [](PyIntegerSet &self) { 937 return mlirIntegerSetIsCanonicalEmpty(self); 938 }) 939 .def_prop_ro( 940 "n_dims", 941 [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); }) 942 .def_prop_ro( 943 "n_symbols", 944 [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); }) 945 .def_prop_ro( 946 "n_inputs", 947 [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); }) 948 .def_prop_ro("n_equalities", 949 [](PyIntegerSet &self) { 950 return mlirIntegerSetGetNumEqualities(self); 951 }) 952 .def_prop_ro("n_inequalities", 953 [](PyIntegerSet &self) { 954 return mlirIntegerSetGetNumInequalities(self); 955 }) 956 .def_prop_ro("constraints", [](PyIntegerSet &self) { 957 return PyIntegerSetConstraintList(self); 958 }); 959 PyIntegerSetConstraint::bind(m); 960 PyIntegerSetConstraintList::bind(m); 961 } 962