xref: /llvm-project/mlir/lib/Bindings/Python/IRAffine.cpp (revision 5cd427477218d8bdb659c6c53a7758f741c3990a)
1 //===- IRAffine.cpp - Exports 'ir' module affine related bindings ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <cstddef>
10 #include <cstdint>
11 #include <stdexcept>
12 #include <string>
13 #include <utility>
14 #include <vector>
15 
16 #include "IRModule.h"
17 #include "NanobindUtils.h"
18 #include "mlir-c/AffineExpr.h"
19 #include "mlir-c/AffineMap.h"
20 #include "mlir-c/IntegerSet.h"
21 #include "mlir/Bindings/Python/Nanobind.h"
22 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
23 #include "mlir/Support/LLVM.h"
24 #include "llvm/ADT/Hashing.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/StringRef.h"
27 #include "llvm/ADT/Twine.h"
28 
29 namespace nb = nanobind;
30 using namespace mlir;
31 using namespace mlir::python;
32 
33 using llvm::SmallVector;
34 using llvm::StringRef;
35 using llvm::Twine;
36 
// Python docstring attached to the `dump` methods bound in this file.
static constexpr char kDumpDocstring[] =
    "Dumps a debug representation of the object to stderr.";
39 
40 /// Attempts to populate `result` with the content of `list` casted to the
41 /// appropriate type (Python and C types are provided as template arguments).
42 /// Throws errors in case of failure, using "action" to describe what the caller
43 /// was attempting to do.
44 template <typename PyType, typename CType>
45 static void pyListToVector(const nb::list &list,
46                            llvm::SmallVectorImpl<CType> &result,
47                            StringRef action) {
48   result.reserve(nb::len(list));
49   for (nb::handle item : list) {
50     try {
51       result.push_back(nb::cast<PyType>(item));
52     } catch (nb::cast_error &err) {
53       std::string msg = (llvm::Twine("Invalid expression when ") + action +
54                          " (" + err.what() + ")")
55                             .str();
56       throw std::runtime_error(msg.c_str());
57     } catch (std::runtime_error &err) {
58       std::string msg = (llvm::Twine("Invalid expression (None?) when ") +
59                          action + " (" + err.what() + ")")
60                             .str();
61       throw std::runtime_error(msg.c_str());
62     }
63   }
64 }
65 
66 template <typename PermutationTy>
67 static bool isPermutation(std::vector<PermutationTy> permutation) {
68   llvm::SmallVector<bool, 8> seen(permutation.size(), false);
69   for (auto val : permutation) {
70     if (val < permutation.size()) {
71       if (seen[val])
72         return false;
73       seen[val] = true;
74       continue;
75     }
76     return false;
77   }
78   return true;
79 }
80 
81 namespace {
82 
83 /// CRTP base class for Python MLIR affine expressions that subclass AffineExpr
84 /// and should be castable from it. Intermediate hierarchy classes can be
85 /// modeled by specifying BaseTy.
86 template <typename DerivedTy, typename BaseTy = PyAffineExpr>
87 class PyConcreteAffineExpr : public BaseTy {
88 public:
89   // Derived classes must define statics for:
90   //   IsAFunctionTy isaFunction
91   //   const char *pyClassName
92   // and redefine bindDerived.
93   using ClassTy = nb::class_<DerivedTy, BaseTy>;
94   using IsAFunctionTy = bool (*)(MlirAffineExpr);
95 
96   PyConcreteAffineExpr() = default;
97   PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
98       : BaseTy(std::move(contextRef), affineExpr) {}
99   PyConcreteAffineExpr(PyAffineExpr &orig)
100       : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {}
101 
102   static MlirAffineExpr castFrom(PyAffineExpr &orig) {
103     if (!DerivedTy::isaFunction(orig)) {
104       auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
105       throw nb::value_error((Twine("Cannot cast affine expression to ") +
106                              DerivedTy::pyClassName + " (from " + origRepr +
107                              ")")
108                                 .str()
109                                 .c_str());
110     }
111     return orig;
112   }
113 
114   static void bind(nb::module_ &m) {
115     auto cls = ClassTy(m, DerivedTy::pyClassName);
116     cls.def(nb::init<PyAffineExpr &>(), nb::arg("expr"));
117     cls.def_static(
118         "isinstance",
119         [](PyAffineExpr &otherAffineExpr) -> bool {
120           return DerivedTy::isaFunction(otherAffineExpr);
121         },
122         nb::arg("other"));
123     DerivedTy::bindDerived(cls);
124   }
125 
126   /// Implemented by derived classes to add methods to the Python subclass.
127   static void bindDerived(ClassTy &m) {}
128 };
129 
130 class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> {
131 public:
132   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant;
133   static constexpr const char *pyClassName = "AffineConstantExpr";
134   using PyConcreteAffineExpr::PyConcreteAffineExpr;
135 
136   static PyAffineConstantExpr get(intptr_t value,
137                                   DefaultingPyMlirContext context) {
138     MlirAffineExpr affineExpr =
139         mlirAffineConstantExprGet(context->get(), static_cast<int64_t>(value));
140     return PyAffineConstantExpr(context->getRef(), affineExpr);
141   }
142 
143   static void bindDerived(ClassTy &c) {
144     c.def_static("get", &PyAffineConstantExpr::get, nb::arg("value"),
145                  nb::arg("context").none() = nb::none());
146     c.def_prop_ro("value", [](PyAffineConstantExpr &self) {
147       return mlirAffineConstantExprGetValue(self);
148     });
149   }
150 };
151 
152 class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> {
153 public:
154   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim;
155   static constexpr const char *pyClassName = "AffineDimExpr";
156   using PyConcreteAffineExpr::PyConcreteAffineExpr;
157 
158   static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) {
159     MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos);
160     return PyAffineDimExpr(context->getRef(), affineExpr);
161   }
162 
163   static void bindDerived(ClassTy &c) {
164     c.def_static("get", &PyAffineDimExpr::get, nb::arg("position"),
165                  nb::arg("context").none() = nb::none());
166     c.def_prop_ro("position", [](PyAffineDimExpr &self) {
167       return mlirAffineDimExprGetPosition(self);
168     });
169   }
170 };
171 
172 class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> {
173 public:
174   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol;
175   static constexpr const char *pyClassName = "AffineSymbolExpr";
176   using PyConcreteAffineExpr::PyConcreteAffineExpr;
177 
178   static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) {
179     MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos);
180     return PyAffineSymbolExpr(context->getRef(), affineExpr);
181   }
182 
183   static void bindDerived(ClassTy &c) {
184     c.def_static("get", &PyAffineSymbolExpr::get, nb::arg("position"),
185                  nb::arg("context").none() = nb::none());
186     c.def_prop_ro("position", [](PyAffineSymbolExpr &self) {
187       return mlirAffineSymbolExprGetPosition(self);
188     });
189   }
190 };
191 
192 class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> {
193 public:
194   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary;
195   static constexpr const char *pyClassName = "AffineBinaryExpr";
196   using PyConcreteAffineExpr::PyConcreteAffineExpr;
197 
198   PyAffineExpr lhs() {
199     MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get());
200     return PyAffineExpr(getContext(), lhsExpr);
201   }
202 
203   PyAffineExpr rhs() {
204     MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get());
205     return PyAffineExpr(getContext(), rhsExpr);
206   }
207 
208   static void bindDerived(ClassTy &c) {
209     c.def_prop_ro("lhs", &PyAffineBinaryExpr::lhs);
210     c.def_prop_ro("rhs", &PyAffineBinaryExpr::rhs);
211   }
212 };
213 
214 class PyAffineAddExpr
215     : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> {
216 public:
217   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd;
218   static constexpr const char *pyClassName = "AffineAddExpr";
219   using PyConcreteAffineExpr::PyConcreteAffineExpr;
220 
221   static PyAffineAddExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
222     MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs);
223     return PyAffineAddExpr(lhs.getContext(), expr);
224   }
225 
226   static PyAffineAddExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
227     MlirAffineExpr expr = mlirAffineAddExprGet(
228         lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
229     return PyAffineAddExpr(lhs.getContext(), expr);
230   }
231 
232   static PyAffineAddExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
233     MlirAffineExpr expr = mlirAffineAddExprGet(
234         mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
235     return PyAffineAddExpr(rhs.getContext(), expr);
236   }
237 
238   static void bindDerived(ClassTy &c) {
239     c.def_static("get", &PyAffineAddExpr::get);
240   }
241 };
242 
243 class PyAffineMulExpr
244     : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> {
245 public:
246   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul;
247   static constexpr const char *pyClassName = "AffineMulExpr";
248   using PyConcreteAffineExpr::PyConcreteAffineExpr;
249 
250   static PyAffineMulExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
251     MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs);
252     return PyAffineMulExpr(lhs.getContext(), expr);
253   }
254 
255   static PyAffineMulExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
256     MlirAffineExpr expr = mlirAffineMulExprGet(
257         lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
258     return PyAffineMulExpr(lhs.getContext(), expr);
259   }
260 
261   static PyAffineMulExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
262     MlirAffineExpr expr = mlirAffineMulExprGet(
263         mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
264     return PyAffineMulExpr(rhs.getContext(), expr);
265   }
266 
267   static void bindDerived(ClassTy &c) {
268     c.def_static("get", &PyAffineMulExpr::get);
269   }
270 };
271 
272 class PyAffineModExpr
273     : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> {
274 public:
275   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod;
276   static constexpr const char *pyClassName = "AffineModExpr";
277   using PyConcreteAffineExpr::PyConcreteAffineExpr;
278 
279   static PyAffineModExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
280     MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs);
281     return PyAffineModExpr(lhs.getContext(), expr);
282   }
283 
284   static PyAffineModExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
285     MlirAffineExpr expr = mlirAffineModExprGet(
286         lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
287     return PyAffineModExpr(lhs.getContext(), expr);
288   }
289 
290   static PyAffineModExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
291     MlirAffineExpr expr = mlirAffineModExprGet(
292         mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
293     return PyAffineModExpr(rhs.getContext(), expr);
294   }
295 
296   static void bindDerived(ClassTy &c) {
297     c.def_static("get", &PyAffineModExpr::get);
298   }
299 };
300 
301 class PyAffineFloorDivExpr
302     : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> {
303 public:
304   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv;
305   static constexpr const char *pyClassName = "AffineFloorDivExpr";
306   using PyConcreteAffineExpr::PyConcreteAffineExpr;
307 
308   static PyAffineFloorDivExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
309     MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs);
310     return PyAffineFloorDivExpr(lhs.getContext(), expr);
311   }
312 
313   static PyAffineFloorDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
314     MlirAffineExpr expr = mlirAffineFloorDivExprGet(
315         lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
316     return PyAffineFloorDivExpr(lhs.getContext(), expr);
317   }
318 
319   static PyAffineFloorDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
320     MlirAffineExpr expr = mlirAffineFloorDivExprGet(
321         mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
322     return PyAffineFloorDivExpr(rhs.getContext(), expr);
323   }
324 
325   static void bindDerived(ClassTy &c) {
326     c.def_static("get", &PyAffineFloorDivExpr::get);
327   }
328 };
329 
330 class PyAffineCeilDivExpr
331     : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> {
332 public:
333   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv;
334   static constexpr const char *pyClassName = "AffineCeilDivExpr";
335   using PyConcreteAffineExpr::PyConcreteAffineExpr;
336 
337   static PyAffineCeilDivExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
338     MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs);
339     return PyAffineCeilDivExpr(lhs.getContext(), expr);
340   }
341 
342   static PyAffineCeilDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
343     MlirAffineExpr expr = mlirAffineCeilDivExprGet(
344         lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
345     return PyAffineCeilDivExpr(lhs.getContext(), expr);
346   }
347 
348   static PyAffineCeilDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
349     MlirAffineExpr expr = mlirAffineCeilDivExprGet(
350         mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
351     return PyAffineCeilDivExpr(rhs.getContext(), expr);
352   }
353 
354   static void bindDerived(ClassTy &c) {
355     c.def_static("get", &PyAffineCeilDivExpr::get);
356   }
357 };
358 
359 } // namespace
360 
361 bool PyAffineExpr::operator==(const PyAffineExpr &other) const {
362   return mlirAffineExprEqual(affineExpr, other.affineExpr);
363 }
364 
365 nb::object PyAffineExpr::getCapsule() {
366   return nb::steal<nb::object>(mlirPythonAffineExprToCapsule(*this));
367 }
368 
369 PyAffineExpr PyAffineExpr::createFromCapsule(nb::object capsule) {
370   MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr());
371   if (mlirAffineExprIsNull(rawAffineExpr))
372     throw nb::python_error();
373   return PyAffineExpr(
374       PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)),
375       rawAffineExpr);
376 }
377 
378 //------------------------------------------------------------------------------
379 // PyAffineMap and utilities.
380 //------------------------------------------------------------------------------
381 namespace {
382 
383 /// A list of expressions contained in an affine map. Internally these are
384 /// stored as a consecutive array leading to inexpensive random access. Both
385 /// the map and the expression are owned by the context so we need not bother
386 /// with lifetime extension.
387 class PyAffineMapExprList
388     : public Sliceable<PyAffineMapExprList, PyAffineExpr> {
389 public:
390   static constexpr const char *pyClassName = "AffineExprList";
391 
392   PyAffineMapExprList(const PyAffineMap &map, intptr_t startIndex = 0,
393                       intptr_t length = -1, intptr_t step = 1)
394       : Sliceable(startIndex,
395                   length == -1 ? mlirAffineMapGetNumResults(map) : length,
396                   step),
397         affineMap(map) {}
398 
399 private:
400   /// Give the parent CRTP class access to hook implementations below.
401   friend class Sliceable<PyAffineMapExprList, PyAffineExpr>;
402 
403   intptr_t getRawNumElements() { return mlirAffineMapGetNumResults(affineMap); }
404 
405   PyAffineExpr getRawElement(intptr_t pos) {
406     return PyAffineExpr(affineMap.getContext(),
407                         mlirAffineMapGetResult(affineMap, pos));
408   }
409 
410   PyAffineMapExprList slice(intptr_t startIndex, intptr_t length,
411                             intptr_t step) {
412     return PyAffineMapExprList(affineMap, startIndex, length, step);
413   }
414 
415   PyAffineMap affineMap;
416 };
417 } // namespace
418 
419 bool PyAffineMap::operator==(const PyAffineMap &other) const {
420   return mlirAffineMapEqual(affineMap, other.affineMap);
421 }
422 
423 nb::object PyAffineMap::getCapsule() {
424   return nb::steal<nb::object>(mlirPythonAffineMapToCapsule(*this));
425 }
426 
427 PyAffineMap PyAffineMap::createFromCapsule(nb::object capsule) {
428   MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr());
429   if (mlirAffineMapIsNull(rawAffineMap))
430     throw nb::python_error();
431   return PyAffineMap(
432       PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)),
433       rawAffineMap);
434 }
435 
436 //------------------------------------------------------------------------------
437 // PyIntegerSet and utilities.
438 //------------------------------------------------------------------------------
439 namespace {
440 
441 class PyIntegerSetConstraint {
442 public:
443   PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos)
444       : set(std::move(set)), pos(pos) {}
445 
446   PyAffineExpr getExpr() {
447     return PyAffineExpr(set.getContext(),
448                         mlirIntegerSetGetConstraint(set, pos));
449   }
450 
451   bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); }
452 
453   static void bind(nb::module_ &m) {
454     nb::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint")
455         .def_prop_ro("expr", &PyIntegerSetConstraint::getExpr)
456         .def_prop_ro("is_eq", &PyIntegerSetConstraint::isEq);
457   }
458 
459 private:
460   PyIntegerSet set;
461   intptr_t pos;
462 };
463 
464 class PyIntegerSetConstraintList
465     : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> {
466 public:
467   static constexpr const char *pyClassName = "IntegerSetConstraintList";
468 
469   PyIntegerSetConstraintList(const PyIntegerSet &set, intptr_t startIndex = 0,
470                              intptr_t length = -1, intptr_t step = 1)
471       : Sliceable(startIndex,
472                   length == -1 ? mlirIntegerSetGetNumConstraints(set) : length,
473                   step),
474         set(set) {}
475 
476 private:
477   /// Give the parent CRTP class access to hook implementations below.
478   friend class Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint>;
479 
480   intptr_t getRawNumElements() { return mlirIntegerSetGetNumConstraints(set); }
481 
482   PyIntegerSetConstraint getRawElement(intptr_t pos) {
483     return PyIntegerSetConstraint(set, pos);
484   }
485 
486   PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length,
487                                    intptr_t step) {
488     return PyIntegerSetConstraintList(set, startIndex, length, step);
489   }
490 
491   PyIntegerSet set;
492 };
493 } // namespace
494 
495 bool PyIntegerSet::operator==(const PyIntegerSet &other) const {
496   return mlirIntegerSetEqual(integerSet, other.integerSet);
497 }
498 
499 nb::object PyIntegerSet::getCapsule() {
500   return nb::steal<nb::object>(mlirPythonIntegerSetToCapsule(*this));
501 }
502 
503 PyIntegerSet PyIntegerSet::createFromCapsule(nb::object capsule) {
504   MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr());
505   if (mlirIntegerSetIsNull(rawIntegerSet))
506     throw nb::python_error();
507   return PyIntegerSet(
508       PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)),
509       rawIntegerSet);
510 }
511 
512 void mlir::python::populateIRAffine(nb::module_ &m) {
513   //----------------------------------------------------------------------------
514   // Mapping of PyAffineExpr and derived classes.
515   //----------------------------------------------------------------------------
516   nb::class_<PyAffineExpr>(m, "AffineExpr")
517       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAffineExpr::getCapsule)
518       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule)
519       .def("__add__", &PyAffineAddExpr::get)
520       .def("__add__", &PyAffineAddExpr::getRHSConstant)
521       .def("__radd__", &PyAffineAddExpr::getRHSConstant)
522       .def("__mul__", &PyAffineMulExpr::get)
523       .def("__mul__", &PyAffineMulExpr::getRHSConstant)
524       .def("__rmul__", &PyAffineMulExpr::getRHSConstant)
525       .def("__mod__", &PyAffineModExpr::get)
526       .def("__mod__", &PyAffineModExpr::getRHSConstant)
527       .def("__rmod__",
528            [](PyAffineExpr &self, intptr_t other) {
529              return PyAffineModExpr::get(
530                  PyAffineConstantExpr::get(other, *self.getContext().get()),
531                  self);
532            })
533       .def("__sub__",
534            [](PyAffineExpr &self, PyAffineExpr &other) {
535              auto negOne =
536                  PyAffineConstantExpr::get(-1, *self.getContext().get());
537              return PyAffineAddExpr::get(self,
538                                          PyAffineMulExpr::get(negOne, other));
539            })
540       .def("__sub__",
541            [](PyAffineExpr &self, intptr_t other) {
542              return PyAffineAddExpr::get(
543                  self,
544                  PyAffineConstantExpr::get(-other, *self.getContext().get()));
545            })
546       .def("__rsub__",
547            [](PyAffineExpr &self, intptr_t other) {
548              return PyAffineAddExpr::getLHSConstant(
549                  other, PyAffineMulExpr::getLHSConstant(-1, self));
550            })
551       .def("__eq__", [](PyAffineExpr &self,
552                         PyAffineExpr &other) { return self == other; })
553       .def("__eq__",
554            [](PyAffineExpr &self, nb::object &other) { return false; })
555       .def("__str__",
556            [](PyAffineExpr &self) {
557              PyPrintAccumulator printAccum;
558              mlirAffineExprPrint(self, printAccum.getCallback(),
559                                  printAccum.getUserData());
560              return printAccum.join();
561            })
562       .def("__repr__",
563            [](PyAffineExpr &self) {
564              PyPrintAccumulator printAccum;
565              printAccum.parts.append("AffineExpr(");
566              mlirAffineExprPrint(self, printAccum.getCallback(),
567                                  printAccum.getUserData());
568              printAccum.parts.append(")");
569              return printAccum.join();
570            })
571       .def("__hash__",
572            [](PyAffineExpr &self) {
573              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
574            })
575       .def_prop_ro(
576           "context",
577           [](PyAffineExpr &self) { return self.getContext().getObject(); })
578       .def("compose",
579            [](PyAffineExpr &self, PyAffineMap &other) {
580              return PyAffineExpr(self.getContext(),
581                                  mlirAffineExprCompose(self, other));
582            })
583       .def_static(
584           "get_add", &PyAffineAddExpr::get,
585           "Gets an affine expression containing a sum of two expressions.")
586       .def_static("get_add", &PyAffineAddExpr::getLHSConstant,
587                   "Gets an affine expression containing a sum of a constant "
588                   "and another expression.")
589       .def_static("get_add", &PyAffineAddExpr::getRHSConstant,
590                   "Gets an affine expression containing a sum of an expression "
591                   "and a constant.")
592       .def_static(
593           "get_mul", &PyAffineMulExpr::get,
594           "Gets an affine expression containing a product of two expressions.")
595       .def_static("get_mul", &PyAffineMulExpr::getLHSConstant,
596                   "Gets an affine expression containing a product of a "
597                   "constant and another expression.")
598       .def_static("get_mul", &PyAffineMulExpr::getRHSConstant,
599                   "Gets an affine expression containing a product of an "
600                   "expression and a constant.")
601       .def_static("get_mod", &PyAffineModExpr::get,
602                   "Gets an affine expression containing the modulo of dividing "
603                   "one expression by another.")
604       .def_static("get_mod", &PyAffineModExpr::getLHSConstant,
605                   "Gets a semi-affine expression containing the modulo of "
606                   "dividing a constant by an expression.")
607       .def_static("get_mod", &PyAffineModExpr::getRHSConstant,
608                   "Gets an affine expression containing the module of dividing"
609                   "an expression by a constant.")
610       .def_static("get_floor_div", &PyAffineFloorDivExpr::get,
611                   "Gets an affine expression containing the rounded-down "
612                   "result of dividing one expression by another.")
613       .def_static("get_floor_div", &PyAffineFloorDivExpr::getLHSConstant,
614                   "Gets a semi-affine expression containing the rounded-down "
615                   "result of dividing a constant by an expression.")
616       .def_static("get_floor_div", &PyAffineFloorDivExpr::getRHSConstant,
617                   "Gets an affine expression containing the rounded-down "
618                   "result of dividing an expression by a constant.")
619       .def_static("get_ceil_div", &PyAffineCeilDivExpr::get,
620                   "Gets an affine expression containing the rounded-up result "
621                   "of dividing one expression by another.")
622       .def_static("get_ceil_div", &PyAffineCeilDivExpr::getLHSConstant,
623                   "Gets a semi-affine expression containing the rounded-up "
624                   "result of dividing a constant by an expression.")
625       .def_static("get_ceil_div", &PyAffineCeilDivExpr::getRHSConstant,
626                   "Gets an affine expression containing the rounded-up result "
627                   "of dividing an expression by a constant.")
628       .def_static("get_constant", &PyAffineConstantExpr::get, nb::arg("value"),
629                   nb::arg("context").none() = nb::none(),
630                   "Gets a constant affine expression with the given value.")
631       .def_static(
632           "get_dim", &PyAffineDimExpr::get, nb::arg("position"),
633           nb::arg("context").none() = nb::none(),
634           "Gets an affine expression of a dimension at the given position.")
635       .def_static(
636           "get_symbol", &PyAffineSymbolExpr::get, nb::arg("position"),
637           nb::arg("context").none() = nb::none(),
638           "Gets an affine expression of a symbol at the given position.")
639       .def(
640           "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); },
641           kDumpDocstring);
642   PyAffineConstantExpr::bind(m);
643   PyAffineDimExpr::bind(m);
644   PyAffineSymbolExpr::bind(m);
645   PyAffineBinaryExpr::bind(m);
646   PyAffineAddExpr::bind(m);
647   PyAffineMulExpr::bind(m);
648   PyAffineModExpr::bind(m);
649   PyAffineFloorDivExpr::bind(m);
650   PyAffineCeilDivExpr::bind(m);
651 
652   //----------------------------------------------------------------------------
653   // Mapping of PyAffineMap.
654   //----------------------------------------------------------------------------
655   nb::class_<PyAffineMap>(m, "AffineMap")
656       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAffineMap::getCapsule)
657       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule)
658       .def("__eq__",
659            [](PyAffineMap &self, PyAffineMap &other) { return self == other; })
660       .def("__eq__", [](PyAffineMap &self, nb::object &other) { return false; })
661       .def("__str__",
662            [](PyAffineMap &self) {
663              PyPrintAccumulator printAccum;
664              mlirAffineMapPrint(self, printAccum.getCallback(),
665                                 printAccum.getUserData());
666              return printAccum.join();
667            })
668       .def("__repr__",
669            [](PyAffineMap &self) {
670              PyPrintAccumulator printAccum;
671              printAccum.parts.append("AffineMap(");
672              mlirAffineMapPrint(self, printAccum.getCallback(),
673                                 printAccum.getUserData());
674              printAccum.parts.append(")");
675              return printAccum.join();
676            })
677       .def("__hash__",
678            [](PyAffineMap &self) {
679              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
680            })
681       .def_static("compress_unused_symbols",
682                   [](nb::list affineMaps, DefaultingPyMlirContext context) {
683                     SmallVector<MlirAffineMap> maps;
684                     pyListToVector<PyAffineMap, MlirAffineMap>(
685                         affineMaps, maps, "attempting to create an AffineMap");
686                     std::vector<MlirAffineMap> compressed(affineMaps.size());
687                     auto populate = [](void *result, intptr_t idx,
688                                        MlirAffineMap m) {
689                       static_cast<MlirAffineMap *>(result)[idx] = (m);
690                     };
691                     mlirAffineMapCompressUnusedSymbols(
692                         maps.data(), maps.size(), compressed.data(), populate);
693                     std::vector<PyAffineMap> res;
694                     res.reserve(compressed.size());
695                     for (auto m : compressed)
696                       res.emplace_back(context->getRef(), m);
697                     return res;
698                   })
699       .def_prop_ro(
700           "context",
701           [](PyAffineMap &self) { return self.getContext().getObject(); },
702           "Context that owns the Affine Map")
703       .def(
704           "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); },
705           kDumpDocstring)
706       .def_static(
707           "get",
708           [](intptr_t dimCount, intptr_t symbolCount, nb::list exprs,
709              DefaultingPyMlirContext context) {
710             SmallVector<MlirAffineExpr> affineExprs;
711             pyListToVector<PyAffineExpr, MlirAffineExpr>(
712                 exprs, affineExprs, "attempting to create an AffineMap");
713             MlirAffineMap map =
714                 mlirAffineMapGet(context->get(), dimCount, symbolCount,
715                                  affineExprs.size(), affineExprs.data());
716             return PyAffineMap(context->getRef(), map);
717           },
718           nb::arg("dim_count"), nb::arg("symbol_count"), nb::arg("exprs"),
719           nb::arg("context").none() = nb::none(),
720           "Gets a map with the given expressions as results.")
721       .def_static(
722           "get_constant",
723           [](intptr_t value, DefaultingPyMlirContext context) {
724             MlirAffineMap affineMap =
725                 mlirAffineMapConstantGet(context->get(), value);
726             return PyAffineMap(context->getRef(), affineMap);
727           },
728           nb::arg("value"), nb::arg("context").none() = nb::none(),
729           "Gets an affine map with a single constant result")
730       .def_static(
731           "get_empty",
732           [](DefaultingPyMlirContext context) {
733             MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get());
734             return PyAffineMap(context->getRef(), affineMap);
735           },
736           nb::arg("context").none() = nb::none(), "Gets an empty affine map.")
737       .def_static(
738           "get_identity",
739           [](intptr_t nDims, DefaultingPyMlirContext context) {
740             MlirAffineMap affineMap =
741                 mlirAffineMapMultiDimIdentityGet(context->get(), nDims);
742             return PyAffineMap(context->getRef(), affineMap);
743           },
744           nb::arg("n_dims"), nb::arg("context").none() = nb::none(),
745           "Gets an identity map with the given number of dimensions.")
746       .def_static(
747           "get_minor_identity",
748           [](intptr_t nDims, intptr_t nResults,
749              DefaultingPyMlirContext context) {
750             MlirAffineMap affineMap =
751                 mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults);
752             return PyAffineMap(context->getRef(), affineMap);
753           },
754           nb::arg("n_dims"), nb::arg("n_results"),
755           nb::arg("context").none() = nb::none(),
756           "Gets a minor identity map with the given number of dimensions and "
757           "results.")
758       .def_static(
759           "get_permutation",
760           [](std::vector<unsigned> permutation,
761              DefaultingPyMlirContext context) {
762             if (!isPermutation(permutation))
763               throw std::runtime_error("Invalid permutation when attempting to "
764                                        "create an AffineMap");
765             MlirAffineMap affineMap = mlirAffineMapPermutationGet(
766                 context->get(), permutation.size(), permutation.data());
767             return PyAffineMap(context->getRef(), affineMap);
768           },
769           nb::arg("permutation"), nb::arg("context").none() = nb::none(),
770           "Gets an affine map that permutes its inputs.")
771       .def(
772           "get_submap",
773           [](PyAffineMap &self, std::vector<intptr_t> &resultPos) {
774             intptr_t numResults = mlirAffineMapGetNumResults(self);
775             for (intptr_t pos : resultPos) {
776               if (pos < 0 || pos >= numResults)
777                 throw nb::value_error("result position out of bounds");
778             }
779             MlirAffineMap affineMap = mlirAffineMapGetSubMap(
780                 self, resultPos.size(), resultPos.data());
781             return PyAffineMap(self.getContext(), affineMap);
782           },
783           nb::arg("result_positions"))
784       .def(
785           "get_major_submap",
786           [](PyAffineMap &self, intptr_t nResults) {
787             if (nResults >= mlirAffineMapGetNumResults(self))
788               throw nb::value_error("number of results out of bounds");
789             MlirAffineMap affineMap =
790                 mlirAffineMapGetMajorSubMap(self, nResults);
791             return PyAffineMap(self.getContext(), affineMap);
792           },
793           nb::arg("n_results"))
794       .def(
795           "get_minor_submap",
796           [](PyAffineMap &self, intptr_t nResults) {
797             if (nResults >= mlirAffineMapGetNumResults(self))
798               throw nb::value_error("number of results out of bounds");
799             MlirAffineMap affineMap =
800                 mlirAffineMapGetMinorSubMap(self, nResults);
801             return PyAffineMap(self.getContext(), affineMap);
802           },
803           nb::arg("n_results"))
804       .def(
805           "replace",
806           [](PyAffineMap &self, PyAffineExpr &expression,
807              PyAffineExpr &replacement, intptr_t numResultDims,
808              intptr_t numResultSyms) {
809             MlirAffineMap affineMap = mlirAffineMapReplace(
810                 self, expression, replacement, numResultDims, numResultSyms);
811             return PyAffineMap(self.getContext(), affineMap);
812           },
813           nb::arg("expr"), nb::arg("replacement"), nb::arg("n_result_dims"),
814           nb::arg("n_result_syms"))
815       .def_prop_ro(
816           "is_permutation",
817           [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); })
818       .def_prop_ro("is_projected_permutation",
819                    [](PyAffineMap &self) {
820                      return mlirAffineMapIsProjectedPermutation(self);
821                    })
822       .def_prop_ro(
823           "n_dims",
824           [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); })
825       .def_prop_ro(
826           "n_inputs",
827           [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); })
828       .def_prop_ro(
829           "n_symbols",
830           [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); })
831       .def_prop_ro("results",
832                    [](PyAffineMap &self) { return PyAffineMapExprList(self); });
833   PyAffineMapExprList::bind(m);
834 
835   //----------------------------------------------------------------------------
836   // Mapping of PyIntegerSet.
837   //----------------------------------------------------------------------------
838   nb::class_<PyIntegerSet>(m, "IntegerSet")
839       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyIntegerSet::getCapsule)
840       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule)
841       .def("__eq__", [](PyIntegerSet &self,
842                         PyIntegerSet &other) { return self == other; })
843       .def("__eq__", [](PyIntegerSet &self, nb::object other) { return false; })
844       .def("__str__",
845            [](PyIntegerSet &self) {
846              PyPrintAccumulator printAccum;
847              mlirIntegerSetPrint(self, printAccum.getCallback(),
848                                  printAccum.getUserData());
849              return printAccum.join();
850            })
851       .def("__repr__",
852            [](PyIntegerSet &self) {
853              PyPrintAccumulator printAccum;
854              printAccum.parts.append("IntegerSet(");
855              mlirIntegerSetPrint(self, printAccum.getCallback(),
856                                  printAccum.getUserData());
857              printAccum.parts.append(")");
858              return printAccum.join();
859            })
860       .def("__hash__",
861            [](PyIntegerSet &self) {
862              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
863            })
864       .def_prop_ro(
865           "context",
866           [](PyIntegerSet &self) { return self.getContext().getObject(); })
867       .def(
868           "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); },
869           kDumpDocstring)
870       .def_static(
871           "get",
872           [](intptr_t numDims, intptr_t numSymbols, nb::list exprs,
873              std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
874             if (exprs.size() != eqFlags.size())
875               throw nb::value_error(
876                   "Expected the number of constraints to match "
877                   "that of equality flags");
878             if (exprs.size() == 0)
879               throw nb::value_error("Expected non-empty list of constraints");
880 
881             // Copy over to a SmallVector because std::vector has a
882             // specialization for booleans that packs data and does not
883             // expose a `bool *`.
884             SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end());
885 
886             SmallVector<MlirAffineExpr> affineExprs;
887             pyListToVector<PyAffineExpr>(exprs, affineExprs,
888                                          "attempting to create an IntegerSet");
889             MlirIntegerSet set = mlirIntegerSetGet(
890                 context->get(), numDims, numSymbols, exprs.size(),
891                 affineExprs.data(), flags.data());
892             return PyIntegerSet(context->getRef(), set);
893           },
894           nb::arg("num_dims"), nb::arg("num_symbols"), nb::arg("exprs"),
895           nb::arg("eq_flags"), nb::arg("context").none() = nb::none())
896       .def_static(
897           "get_empty",
898           [](intptr_t numDims, intptr_t numSymbols,
899              DefaultingPyMlirContext context) {
900             MlirIntegerSet set =
901                 mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols);
902             return PyIntegerSet(context->getRef(), set);
903           },
904           nb::arg("num_dims"), nb::arg("num_symbols"),
905           nb::arg("context").none() = nb::none())
906       .def(
907           "get_replaced",
908           [](PyIntegerSet &self, nb::list dimExprs, nb::list symbolExprs,
909              intptr_t numResultDims, intptr_t numResultSymbols) {
910             if (static_cast<intptr_t>(dimExprs.size()) !=
911                 mlirIntegerSetGetNumDims(self))
912               throw nb::value_error(
913                   "Expected the number of dimension replacement expressions "
914                   "to match that of dimensions");
915             if (static_cast<intptr_t>(symbolExprs.size()) !=
916                 mlirIntegerSetGetNumSymbols(self))
917               throw nb::value_error(
918                   "Expected the number of symbol replacement expressions "
919                   "to match that of symbols");
920 
921             SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs;
922             pyListToVector<PyAffineExpr>(
923                 dimExprs, dimAffineExprs,
924                 "attempting to create an IntegerSet by replacing dimensions");
925             pyListToVector<PyAffineExpr>(
926                 symbolExprs, symbolAffineExprs,
927                 "attempting to create an IntegerSet by replacing symbols");
928             MlirIntegerSet set = mlirIntegerSetReplaceGet(
929                 self, dimAffineExprs.data(), symbolAffineExprs.data(),
930                 numResultDims, numResultSymbols);
931             return PyIntegerSet(self.getContext(), set);
932           },
933           nb::arg("dim_exprs"), nb::arg("symbol_exprs"),
934           nb::arg("num_result_dims"), nb::arg("num_result_symbols"))
935       .def_prop_ro("is_canonical_empty",
936                    [](PyIntegerSet &self) {
937                      return mlirIntegerSetIsCanonicalEmpty(self);
938                    })
939       .def_prop_ro(
940           "n_dims",
941           [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); })
942       .def_prop_ro(
943           "n_symbols",
944           [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); })
945       .def_prop_ro(
946           "n_inputs",
947           [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); })
948       .def_prop_ro("n_equalities",
949                    [](PyIntegerSet &self) {
950                      return mlirIntegerSetGetNumEqualities(self);
951                    })
952       .def_prop_ro("n_inequalities",
953                    [](PyIntegerSet &self) {
954                      return mlirIntegerSetGetNumInequalities(self);
955                    })
956       .def_prop_ro("constraints", [](PyIntegerSet &self) {
957         return PyIntegerSetConstraintList(self);
958       });
959   PyIntegerSetConstraint::bind(m);
960   PyIntegerSetConstraintList::bind(m);
961 }
962