xref: /llvm-project/mlir/lib/Bindings/Python/IRAffine.cpp (revision 5cd427477218d8bdb659c6c53a7758f741c3990a)
1436c6c9cSStella Laurenzo //===- IRAffine.cpp - Exports 'ir' module affine related bindings ---------===//
2436c6c9cSStella Laurenzo //
3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6436c6c9cSStella Laurenzo //
7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===//
8436c6c9cSStella Laurenzo 
9285a229fSMehdi Amini #include <cstddef>
10285a229fSMehdi Amini #include <cstdint>
11b56d1ec6SPeter Hawkins #include <stdexcept>
12285a229fSMehdi Amini #include <string>
131fc096afSMehdi Amini #include <utility>
14285a229fSMehdi Amini #include <vector>
151fc096afSMehdi Amini 
16436c6c9cSStella Laurenzo #include "IRModule.h"
17b56d1ec6SPeter Hawkins #include "NanobindUtils.h"
18285a229fSMehdi Amini #include "mlir-c/AffineExpr.h"
19436c6c9cSStella Laurenzo #include "mlir-c/AffineMap.h"
20436c6c9cSStella Laurenzo #include "mlir-c/IntegerSet.h"
21*5cd42747SPeter Hawkins #include "mlir/Bindings/Python/Nanobind.h"
22*5cd42747SPeter Hawkins #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
23285a229fSMehdi Amini #include "mlir/Support/LLVM.h"
242233c4dcSBenjamin Kramer #include "llvm/ADT/Hashing.h"
25285a229fSMehdi Amini #include "llvm/ADT/SmallVector.h"
26285a229fSMehdi Amini #include "llvm/ADT/StringRef.h"
27285a229fSMehdi Amini #include "llvm/ADT/Twine.h"
28436c6c9cSStella Laurenzo 
29b56d1ec6SPeter Hawkins namespace nb = nanobind;
30436c6c9cSStella Laurenzo using namespace mlir;
31436c6c9cSStella Laurenzo using namespace mlir::python;
32436c6c9cSStella Laurenzo 
33436c6c9cSStella Laurenzo using llvm::SmallVector;
34436c6c9cSStella Laurenzo using llvm::StringRef;
35436c6c9cSStella Laurenzo using llvm::Twine;
36436c6c9cSStella Laurenzo 
/// Python docstring text describing a debug "dump" method.
static constexpr char kDumpDocstring[] =
    "Dumps a debug representation of the object to stderr.";
39436c6c9cSStella Laurenzo 
40436c6c9cSStella Laurenzo /// Attempts to populate `result` with the content of `list` casted to the
41436c6c9cSStella Laurenzo /// appropriate type (Python and C types are provided as template arguments).
42436c6c9cSStella Laurenzo /// Throws errors in case of failure, using "action" to describe what the caller
43436c6c9cSStella Laurenzo /// was attempting to do.
44436c6c9cSStella Laurenzo template <typename PyType, typename CType>
45b56d1ec6SPeter Hawkins static void pyListToVector(const nb::list &list,
461fc096afSMehdi Amini                            llvm::SmallVectorImpl<CType> &result,
47436c6c9cSStella Laurenzo                            StringRef action) {
48b56d1ec6SPeter Hawkins   result.reserve(nb::len(list));
49b56d1ec6SPeter Hawkins   for (nb::handle item : list) {
50436c6c9cSStella Laurenzo     try {
51b56d1ec6SPeter Hawkins       result.push_back(nb::cast<PyType>(item));
52b56d1ec6SPeter Hawkins     } catch (nb::cast_error &err) {
53436c6c9cSStella Laurenzo       std::string msg = (llvm::Twine("Invalid expression when ") + action +
54436c6c9cSStella Laurenzo                          " (" + err.what() + ")")
55436c6c9cSStella Laurenzo                             .str();
56b56d1ec6SPeter Hawkins       throw std::runtime_error(msg.c_str());
57b56d1ec6SPeter Hawkins     } catch (std::runtime_error &err) {
58436c6c9cSStella Laurenzo       std::string msg = (llvm::Twine("Invalid expression (None?) when ") +
59436c6c9cSStella Laurenzo                          action + " (" + err.what() + ")")
60436c6c9cSStella Laurenzo                             .str();
61b56d1ec6SPeter Hawkins       throw std::runtime_error(msg.c_str());
62436c6c9cSStella Laurenzo     }
63436c6c9cSStella Laurenzo   }
64436c6c9cSStella Laurenzo }
65436c6c9cSStella Laurenzo 
/// Returns true if `permutation` is a permutation of [0, permutation.size()):
/// every value is in range and no value occurs twice. An empty vector is
/// trivially a permutation.
///
/// The vector is only read, so it is taken by const reference rather than by
/// value (which copied it on every call).
template <typename PermutationTy>
static bool isPermutation(const std::vector<PermutationTy> &permutation) {
  const size_t size = permutation.size();
  // One byte per slot; avoids the bit-proxy std::vector<bool> specialization.
  std::vector<char> seen(size, 0);
  for (const PermutationTy &val : permutation) {
    // Explicit conversion instead of an implicit signed/unsigned comparison:
    // for signed element types a negative value converts to a very large
    // index and is rejected by the bound check.
    const size_t index = static_cast<size_t>(val);
    if (index >= size || seen[index])
      return false;
    seen[index] = 1;
  }
  return true;
}
80436c6c9cSStella Laurenzo 
81436c6c9cSStella Laurenzo namespace {
82436c6c9cSStella Laurenzo 
83436c6c9cSStella Laurenzo /// CRTP base class for Python MLIR affine expressions that subclass AffineExpr
84436c6c9cSStella Laurenzo /// and should be castable from it. Intermediate hierarchy classes can be
85436c6c9cSStella Laurenzo /// modeled by specifying BaseTy.
86436c6c9cSStella Laurenzo template <typename DerivedTy, typename BaseTy = PyAffineExpr>
87436c6c9cSStella Laurenzo class PyConcreteAffineExpr : public BaseTy {
88436c6c9cSStella Laurenzo public:
89436c6c9cSStella Laurenzo   // Derived classes must define statics for:
90436c6c9cSStella Laurenzo   //   IsAFunctionTy isaFunction
91436c6c9cSStella Laurenzo   //   const char *pyClassName
92436c6c9cSStella Laurenzo   // and redefine bindDerived.
93b56d1ec6SPeter Hawkins   using ClassTy = nb::class_<DerivedTy, BaseTy>;
94436c6c9cSStella Laurenzo   using IsAFunctionTy = bool (*)(MlirAffineExpr);
95436c6c9cSStella Laurenzo 
96436c6c9cSStella Laurenzo   PyConcreteAffineExpr() = default;
97436c6c9cSStella Laurenzo   PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
98436c6c9cSStella Laurenzo       : BaseTy(std::move(contextRef), affineExpr) {}
99436c6c9cSStella Laurenzo   PyConcreteAffineExpr(PyAffineExpr &orig)
100436c6c9cSStella Laurenzo       : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {}
101436c6c9cSStella Laurenzo 
102436c6c9cSStella Laurenzo   static MlirAffineExpr castFrom(PyAffineExpr &orig) {
103436c6c9cSStella Laurenzo     if (!DerivedTy::isaFunction(orig)) {
104b56d1ec6SPeter Hawkins       auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
105b56d1ec6SPeter Hawkins       throw nb::value_error((Twine("Cannot cast affine expression to ") +
1064811270bSmax                              DerivedTy::pyClassName + " (from " + origRepr +
1074811270bSmax                              ")")
108b56d1ec6SPeter Hawkins                                 .str()
109b56d1ec6SPeter Hawkins                                 .c_str());
110436c6c9cSStella Laurenzo     }
111436c6c9cSStella Laurenzo     return orig;
112436c6c9cSStella Laurenzo   }
113436c6c9cSStella Laurenzo 
114b56d1ec6SPeter Hawkins   static void bind(nb::module_ &m) {
115b56d1ec6SPeter Hawkins     auto cls = ClassTy(m, DerivedTy::pyClassName);
116b56d1ec6SPeter Hawkins     cls.def(nb::init<PyAffineExpr &>(), nb::arg("expr"));
117a6e7d024SStella Laurenzo     cls.def_static(
118a6e7d024SStella Laurenzo         "isinstance",
119a6e7d024SStella Laurenzo         [](PyAffineExpr &otherAffineExpr) -> bool {
12078f2dae0SAlex Zinenko           return DerivedTy::isaFunction(otherAffineExpr);
121a6e7d024SStella Laurenzo         },
122b56d1ec6SPeter Hawkins         nb::arg("other"));
123436c6c9cSStella Laurenzo     DerivedTy::bindDerived(cls);
124436c6c9cSStella Laurenzo   }
125436c6c9cSStella Laurenzo 
126436c6c9cSStella Laurenzo   /// Implemented by derived classes to add methods to the Python subclass.
127436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &m) {}
128436c6c9cSStella Laurenzo };
129436c6c9cSStella Laurenzo 
130436c6c9cSStella Laurenzo class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> {
131436c6c9cSStella Laurenzo public:
132436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant;
133436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineConstantExpr";
134436c6c9cSStella Laurenzo   using PyConcreteAffineExpr::PyConcreteAffineExpr;
135436c6c9cSStella Laurenzo 
136436c6c9cSStella Laurenzo   static PyAffineConstantExpr get(intptr_t value,
137436c6c9cSStella Laurenzo                                   DefaultingPyMlirContext context) {
138436c6c9cSStella Laurenzo     MlirAffineExpr affineExpr =
139436c6c9cSStella Laurenzo         mlirAffineConstantExprGet(context->get(), static_cast<int64_t>(value));
140436c6c9cSStella Laurenzo     return PyAffineConstantExpr(context->getRef(), affineExpr);
141436c6c9cSStella Laurenzo   }
142436c6c9cSStella Laurenzo 
143436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
144b56d1ec6SPeter Hawkins     c.def_static("get", &PyAffineConstantExpr::get, nb::arg("value"),
145b56d1ec6SPeter Hawkins                  nb::arg("context").none() = nb::none());
146b56d1ec6SPeter Hawkins     c.def_prop_ro("value", [](PyAffineConstantExpr &self) {
147436c6c9cSStella Laurenzo       return mlirAffineConstantExprGetValue(self);
148436c6c9cSStella Laurenzo     });
149436c6c9cSStella Laurenzo   }
150436c6c9cSStella Laurenzo };
151436c6c9cSStella Laurenzo 
152436c6c9cSStella Laurenzo class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> {
153436c6c9cSStella Laurenzo public:
154436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim;
155436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineDimExpr";
156436c6c9cSStella Laurenzo   using PyConcreteAffineExpr::PyConcreteAffineExpr;
157436c6c9cSStella Laurenzo 
158436c6c9cSStella Laurenzo   static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) {
159436c6c9cSStella Laurenzo     MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos);
160436c6c9cSStella Laurenzo     return PyAffineDimExpr(context->getRef(), affineExpr);
161436c6c9cSStella Laurenzo   }
162436c6c9cSStella Laurenzo 
163436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
164b56d1ec6SPeter Hawkins     c.def_static("get", &PyAffineDimExpr::get, nb::arg("position"),
165b56d1ec6SPeter Hawkins                  nb::arg("context").none() = nb::none());
166b56d1ec6SPeter Hawkins     c.def_prop_ro("position", [](PyAffineDimExpr &self) {
167436c6c9cSStella Laurenzo       return mlirAffineDimExprGetPosition(self);
168436c6c9cSStella Laurenzo     });
169436c6c9cSStella Laurenzo   }
170436c6c9cSStella Laurenzo };
171436c6c9cSStella Laurenzo 
172436c6c9cSStella Laurenzo class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> {
173436c6c9cSStella Laurenzo public:
174436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol;
175436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineSymbolExpr";
176436c6c9cSStella Laurenzo   using PyConcreteAffineExpr::PyConcreteAffineExpr;
177436c6c9cSStella Laurenzo 
178436c6c9cSStella Laurenzo   static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) {
179436c6c9cSStella Laurenzo     MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos);
180436c6c9cSStella Laurenzo     return PyAffineSymbolExpr(context->getRef(), affineExpr);
181436c6c9cSStella Laurenzo   }
182436c6c9cSStella Laurenzo 
183436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
184b56d1ec6SPeter Hawkins     c.def_static("get", &PyAffineSymbolExpr::get, nb::arg("position"),
185b56d1ec6SPeter Hawkins                  nb::arg("context").none() = nb::none());
186b56d1ec6SPeter Hawkins     c.def_prop_ro("position", [](PyAffineSymbolExpr &self) {
187436c6c9cSStella Laurenzo       return mlirAffineSymbolExprGetPosition(self);
188436c6c9cSStella Laurenzo     });
189436c6c9cSStella Laurenzo   }
190436c6c9cSStella Laurenzo };
191436c6c9cSStella Laurenzo 
192436c6c9cSStella Laurenzo class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> {
193436c6c9cSStella Laurenzo public:
194436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary;
195436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineBinaryExpr";
196436c6c9cSStella Laurenzo   using PyConcreteAffineExpr::PyConcreteAffineExpr;
197436c6c9cSStella Laurenzo 
198436c6c9cSStella Laurenzo   PyAffineExpr lhs() {
199436c6c9cSStella Laurenzo     MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get());
200436c6c9cSStella Laurenzo     return PyAffineExpr(getContext(), lhsExpr);
201436c6c9cSStella Laurenzo   }
202436c6c9cSStella Laurenzo 
203436c6c9cSStella Laurenzo   PyAffineExpr rhs() {
204436c6c9cSStella Laurenzo     MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get());
205436c6c9cSStella Laurenzo     return PyAffineExpr(getContext(), rhsExpr);
206436c6c9cSStella Laurenzo   }
207436c6c9cSStella Laurenzo 
208436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
209b56d1ec6SPeter Hawkins     c.def_prop_ro("lhs", &PyAffineBinaryExpr::lhs);
210b56d1ec6SPeter Hawkins     c.def_prop_ro("rhs", &PyAffineBinaryExpr::rhs);
211436c6c9cSStella Laurenzo   }
212436c6c9cSStella Laurenzo };
213436c6c9cSStella Laurenzo 
214436c6c9cSStella Laurenzo class PyAffineAddExpr
215436c6c9cSStella Laurenzo     : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> {
216436c6c9cSStella Laurenzo public:
217436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd;
218436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineAddExpr";
219436c6c9cSStella Laurenzo   using PyConcreteAffineExpr::PyConcreteAffineExpr;
220436c6c9cSStella Laurenzo 
2211fc096afSMehdi Amini   static PyAffineAddExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
222436c6c9cSStella Laurenzo     MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs);
223436c6c9cSStella Laurenzo     return PyAffineAddExpr(lhs.getContext(), expr);
224436c6c9cSStella Laurenzo   }
225436c6c9cSStella Laurenzo 
226fc7594ccSAlex Zinenko   static PyAffineAddExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
227fc7594ccSAlex Zinenko     MlirAffineExpr expr = mlirAffineAddExprGet(
228fc7594ccSAlex Zinenko         lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
229fc7594ccSAlex Zinenko     return PyAffineAddExpr(lhs.getContext(), expr);
230fc7594ccSAlex Zinenko   }
231fc7594ccSAlex Zinenko 
232fc7594ccSAlex Zinenko   static PyAffineAddExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
233fc7594ccSAlex Zinenko     MlirAffineExpr expr = mlirAffineAddExprGet(
234fc7594ccSAlex Zinenko         mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
235fc7594ccSAlex Zinenko     return PyAffineAddExpr(rhs.getContext(), expr);
236fc7594ccSAlex Zinenko   }
237fc7594ccSAlex Zinenko 
238436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
239436c6c9cSStella Laurenzo     c.def_static("get", &PyAffineAddExpr::get);
240436c6c9cSStella Laurenzo   }
241436c6c9cSStella Laurenzo };
242436c6c9cSStella Laurenzo 
243436c6c9cSStella Laurenzo class PyAffineMulExpr
244436c6c9cSStella Laurenzo     : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> {
245436c6c9cSStella Laurenzo public:
246436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul;
247436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineMulExpr";
248436c6c9cSStella Laurenzo   using PyConcreteAffineExpr::PyConcreteAffineExpr;
249436c6c9cSStella Laurenzo 
2501fc096afSMehdi Amini   static PyAffineMulExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
251436c6c9cSStella Laurenzo     MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs);
252436c6c9cSStella Laurenzo     return PyAffineMulExpr(lhs.getContext(), expr);
253436c6c9cSStella Laurenzo   }
254436c6c9cSStella Laurenzo 
255fc7594ccSAlex Zinenko   static PyAffineMulExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
256fc7594ccSAlex Zinenko     MlirAffineExpr expr = mlirAffineMulExprGet(
257fc7594ccSAlex Zinenko         lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
258fc7594ccSAlex Zinenko     return PyAffineMulExpr(lhs.getContext(), expr);
259fc7594ccSAlex Zinenko   }
260fc7594ccSAlex Zinenko 
261fc7594ccSAlex Zinenko   static PyAffineMulExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
262fc7594ccSAlex Zinenko     MlirAffineExpr expr = mlirAffineMulExprGet(
263fc7594ccSAlex Zinenko         mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
264fc7594ccSAlex Zinenko     return PyAffineMulExpr(rhs.getContext(), expr);
265fc7594ccSAlex Zinenko   }
266fc7594ccSAlex Zinenko 
267436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
268436c6c9cSStella Laurenzo     c.def_static("get", &PyAffineMulExpr::get);
269436c6c9cSStella Laurenzo   }
270436c6c9cSStella Laurenzo };
271436c6c9cSStella Laurenzo 
272436c6c9cSStella Laurenzo class PyAffineModExpr
273436c6c9cSStella Laurenzo     : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> {
274436c6c9cSStella Laurenzo public:
275436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod;
276436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineModExpr";
277436c6c9cSStella Laurenzo   using PyConcreteAffineExpr::PyConcreteAffineExpr;
278436c6c9cSStella Laurenzo 
2791fc096afSMehdi Amini   static PyAffineModExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
280436c6c9cSStella Laurenzo     MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs);
281436c6c9cSStella Laurenzo     return PyAffineModExpr(lhs.getContext(), expr);
282436c6c9cSStella Laurenzo   }
283436c6c9cSStella Laurenzo 
284fc7594ccSAlex Zinenko   static PyAffineModExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
285fc7594ccSAlex Zinenko     MlirAffineExpr expr = mlirAffineModExprGet(
286fc7594ccSAlex Zinenko         lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
287fc7594ccSAlex Zinenko     return PyAffineModExpr(lhs.getContext(), expr);
288fc7594ccSAlex Zinenko   }
289fc7594ccSAlex Zinenko 
290fc7594ccSAlex Zinenko   static PyAffineModExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
291fc7594ccSAlex Zinenko     MlirAffineExpr expr = mlirAffineModExprGet(
292fc7594ccSAlex Zinenko         mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
293fc7594ccSAlex Zinenko     return PyAffineModExpr(rhs.getContext(), expr);
294fc7594ccSAlex Zinenko   }
295fc7594ccSAlex Zinenko 
296436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
297436c6c9cSStella Laurenzo     c.def_static("get", &PyAffineModExpr::get);
298436c6c9cSStella Laurenzo   }
299436c6c9cSStella Laurenzo };
300436c6c9cSStella Laurenzo 
301436c6c9cSStella Laurenzo class PyAffineFloorDivExpr
302436c6c9cSStella Laurenzo     : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> {
303436c6c9cSStella Laurenzo public:
304436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv;
305436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineFloorDivExpr";
306436c6c9cSStella Laurenzo   using PyConcreteAffineExpr::PyConcreteAffineExpr;
307436c6c9cSStella Laurenzo 
3081fc096afSMehdi Amini   static PyAffineFloorDivExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
309436c6c9cSStella Laurenzo     MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs);
310436c6c9cSStella Laurenzo     return PyAffineFloorDivExpr(lhs.getContext(), expr);
311436c6c9cSStella Laurenzo   }
312436c6c9cSStella Laurenzo 
313fc7594ccSAlex Zinenko   static PyAffineFloorDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
314fc7594ccSAlex Zinenko     MlirAffineExpr expr = mlirAffineFloorDivExprGet(
315fc7594ccSAlex Zinenko         lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
316fc7594ccSAlex Zinenko     return PyAffineFloorDivExpr(lhs.getContext(), expr);
317fc7594ccSAlex Zinenko   }
318fc7594ccSAlex Zinenko 
319fc7594ccSAlex Zinenko   static PyAffineFloorDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
320fc7594ccSAlex Zinenko     MlirAffineExpr expr = mlirAffineFloorDivExprGet(
321fc7594ccSAlex Zinenko         mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
322fc7594ccSAlex Zinenko     return PyAffineFloorDivExpr(rhs.getContext(), expr);
323fc7594ccSAlex Zinenko   }
324fc7594ccSAlex Zinenko 
325436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
326436c6c9cSStella Laurenzo     c.def_static("get", &PyAffineFloorDivExpr::get);
327436c6c9cSStella Laurenzo   }
328436c6c9cSStella Laurenzo };
329436c6c9cSStella Laurenzo 
330436c6c9cSStella Laurenzo class PyAffineCeilDivExpr
331436c6c9cSStella Laurenzo     : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> {
332436c6c9cSStella Laurenzo public:
333436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv;
334436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineCeilDivExpr";
335436c6c9cSStella Laurenzo   using PyConcreteAffineExpr::PyConcreteAffineExpr;
336436c6c9cSStella Laurenzo 
3371fc096afSMehdi Amini   static PyAffineCeilDivExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
338436c6c9cSStella Laurenzo     MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs);
339436c6c9cSStella Laurenzo     return PyAffineCeilDivExpr(lhs.getContext(), expr);
340436c6c9cSStella Laurenzo   }
341436c6c9cSStella Laurenzo 
342fc7594ccSAlex Zinenko   static PyAffineCeilDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
343fc7594ccSAlex Zinenko     MlirAffineExpr expr = mlirAffineCeilDivExprGet(
344fc7594ccSAlex Zinenko         lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
345fc7594ccSAlex Zinenko     return PyAffineCeilDivExpr(lhs.getContext(), expr);
346fc7594ccSAlex Zinenko   }
347fc7594ccSAlex Zinenko 
348fc7594ccSAlex Zinenko   static PyAffineCeilDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
349fc7594ccSAlex Zinenko     MlirAffineExpr expr = mlirAffineCeilDivExprGet(
350fc7594ccSAlex Zinenko         mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
351fc7594ccSAlex Zinenko     return PyAffineCeilDivExpr(rhs.getContext(), expr);
352fc7594ccSAlex Zinenko   }
353fc7594ccSAlex Zinenko 
354436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
355436c6c9cSStella Laurenzo     c.def_static("get", &PyAffineCeilDivExpr::get);
356436c6c9cSStella Laurenzo   }
357436c6c9cSStella Laurenzo };
358436c6c9cSStella Laurenzo 
359436c6c9cSStella Laurenzo } // namespace
360436c6c9cSStella Laurenzo 
361e6d738e0SRahul Kayaith bool PyAffineExpr::operator==(const PyAffineExpr &other) const {
362436c6c9cSStella Laurenzo   return mlirAffineExprEqual(affineExpr, other.affineExpr);
363436c6c9cSStella Laurenzo }
364436c6c9cSStella Laurenzo 
365b56d1ec6SPeter Hawkins nb::object PyAffineExpr::getCapsule() {
366b56d1ec6SPeter Hawkins   return nb::steal<nb::object>(mlirPythonAffineExprToCapsule(*this));
367436c6c9cSStella Laurenzo }
368436c6c9cSStella Laurenzo 
369b56d1ec6SPeter Hawkins PyAffineExpr PyAffineExpr::createFromCapsule(nb::object capsule) {
370436c6c9cSStella Laurenzo   MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr());
371436c6c9cSStella Laurenzo   if (mlirAffineExprIsNull(rawAffineExpr))
372b56d1ec6SPeter Hawkins     throw nb::python_error();
373436c6c9cSStella Laurenzo   return PyAffineExpr(
374436c6c9cSStella Laurenzo       PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)),
375436c6c9cSStella Laurenzo       rawAffineExpr);
376436c6c9cSStella Laurenzo }
377436c6c9cSStella Laurenzo 
378436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
379436c6c9cSStella Laurenzo // PyAffineMap and utilities.
380436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
381436c6c9cSStella Laurenzo namespace {
382436c6c9cSStella Laurenzo 
383436c6c9cSStella Laurenzo /// A list of expressions contained in an affine map. Internally these are
384436c6c9cSStella Laurenzo /// stored as a consecutive array leading to inexpensive random access. Both
385436c6c9cSStella Laurenzo /// the map and the expression are owned by the context so we need not bother
386436c6c9cSStella Laurenzo /// with lifetime extension.
387436c6c9cSStella Laurenzo class PyAffineMapExprList
388436c6c9cSStella Laurenzo     : public Sliceable<PyAffineMapExprList, PyAffineExpr> {
389436c6c9cSStella Laurenzo public:
390436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineExprList";
391436c6c9cSStella Laurenzo 
3921fc096afSMehdi Amini   PyAffineMapExprList(const PyAffineMap &map, intptr_t startIndex = 0,
393436c6c9cSStella Laurenzo                       intptr_t length = -1, intptr_t step = 1)
394436c6c9cSStella Laurenzo       : Sliceable(startIndex,
395436c6c9cSStella Laurenzo                   length == -1 ? mlirAffineMapGetNumResults(map) : length,
396436c6c9cSStella Laurenzo                   step),
397436c6c9cSStella Laurenzo         affineMap(map) {}
398436c6c9cSStella Laurenzo 
399ee168fb9SAlex Zinenko private:
400ee168fb9SAlex Zinenko   /// Give the parent CRTP class access to hook implementations below.
401ee168fb9SAlex Zinenko   friend class Sliceable<PyAffineMapExprList, PyAffineExpr>;
402436c6c9cSStella Laurenzo 
403ee168fb9SAlex Zinenko   intptr_t getRawNumElements() { return mlirAffineMapGetNumResults(affineMap); }
404ee168fb9SAlex Zinenko 
405ee168fb9SAlex Zinenko   PyAffineExpr getRawElement(intptr_t pos) {
406436c6c9cSStella Laurenzo     return PyAffineExpr(affineMap.getContext(),
407436c6c9cSStella Laurenzo                         mlirAffineMapGetResult(affineMap, pos));
408436c6c9cSStella Laurenzo   }
409436c6c9cSStella Laurenzo 
410436c6c9cSStella Laurenzo   PyAffineMapExprList slice(intptr_t startIndex, intptr_t length,
411436c6c9cSStella Laurenzo                             intptr_t step) {
412436c6c9cSStella Laurenzo     return PyAffineMapExprList(affineMap, startIndex, length, step);
413436c6c9cSStella Laurenzo   }
414436c6c9cSStella Laurenzo 
415436c6c9cSStella Laurenzo   PyAffineMap affineMap;
416436c6c9cSStella Laurenzo };
417be0a7e9fSMehdi Amini } // namespace
418436c6c9cSStella Laurenzo 
419e6d738e0SRahul Kayaith bool PyAffineMap::operator==(const PyAffineMap &other) const {
420436c6c9cSStella Laurenzo   return mlirAffineMapEqual(affineMap, other.affineMap);
421436c6c9cSStella Laurenzo }
422436c6c9cSStella Laurenzo 
423b56d1ec6SPeter Hawkins nb::object PyAffineMap::getCapsule() {
424b56d1ec6SPeter Hawkins   return nb::steal<nb::object>(mlirPythonAffineMapToCapsule(*this));
425436c6c9cSStella Laurenzo }
426436c6c9cSStella Laurenzo 
427b56d1ec6SPeter Hawkins PyAffineMap PyAffineMap::createFromCapsule(nb::object capsule) {
428436c6c9cSStella Laurenzo   MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr());
429436c6c9cSStella Laurenzo   if (mlirAffineMapIsNull(rawAffineMap))
430b56d1ec6SPeter Hawkins     throw nb::python_error();
431436c6c9cSStella Laurenzo   return PyAffineMap(
432436c6c9cSStella Laurenzo       PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)),
433436c6c9cSStella Laurenzo       rawAffineMap);
434436c6c9cSStella Laurenzo }
435436c6c9cSStella Laurenzo 
436436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
437436c6c9cSStella Laurenzo // PyIntegerSet and utilities.
438436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
439436c6c9cSStella Laurenzo namespace {
440436c6c9cSStella Laurenzo 
441436c6c9cSStella Laurenzo class PyIntegerSetConstraint {
442436c6c9cSStella Laurenzo public:
4431fc096afSMehdi Amini   PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos)
4441fc096afSMehdi Amini       : set(std::move(set)), pos(pos) {}
445436c6c9cSStella Laurenzo 
446436c6c9cSStella Laurenzo   PyAffineExpr getExpr() {
447436c6c9cSStella Laurenzo     return PyAffineExpr(set.getContext(),
448436c6c9cSStella Laurenzo                         mlirIntegerSetGetConstraint(set, pos));
449436c6c9cSStella Laurenzo   }
450436c6c9cSStella Laurenzo 
451436c6c9cSStella Laurenzo   bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); }
452436c6c9cSStella Laurenzo 
453b56d1ec6SPeter Hawkins   static void bind(nb::module_ &m) {
454b56d1ec6SPeter Hawkins     nb::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint")
455b56d1ec6SPeter Hawkins         .def_prop_ro("expr", &PyIntegerSetConstraint::getExpr)
456b56d1ec6SPeter Hawkins         .def_prop_ro("is_eq", &PyIntegerSetConstraint::isEq);
457436c6c9cSStella Laurenzo   }
458436c6c9cSStella Laurenzo 
459436c6c9cSStella Laurenzo private:
460436c6c9cSStella Laurenzo   PyIntegerSet set;
461436c6c9cSStella Laurenzo   intptr_t pos;
462436c6c9cSStella Laurenzo };
463436c6c9cSStella Laurenzo 
464436c6c9cSStella Laurenzo class PyIntegerSetConstraintList
465436c6c9cSStella Laurenzo     : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> {
466436c6c9cSStella Laurenzo public:
467436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "IntegerSetConstraintList";
468436c6c9cSStella Laurenzo 
4691fc096afSMehdi Amini   PyIntegerSetConstraintList(const PyIntegerSet &set, intptr_t startIndex = 0,
470436c6c9cSStella Laurenzo                              intptr_t length = -1, intptr_t step = 1)
471436c6c9cSStella Laurenzo       : Sliceable(startIndex,
472436c6c9cSStella Laurenzo                   length == -1 ? mlirIntegerSetGetNumConstraints(set) : length,
473436c6c9cSStella Laurenzo                   step),
474436c6c9cSStella Laurenzo         set(set) {}
475436c6c9cSStella Laurenzo 
476ee168fb9SAlex Zinenko private:
477ee168fb9SAlex Zinenko   /// Give the parent CRTP class access to hook implementations below.
478ee168fb9SAlex Zinenko   friend class Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint>;
479436c6c9cSStella Laurenzo 
480ee168fb9SAlex Zinenko   intptr_t getRawNumElements() { return mlirIntegerSetGetNumConstraints(set); }
481ee168fb9SAlex Zinenko 
482ee168fb9SAlex Zinenko   PyIntegerSetConstraint getRawElement(intptr_t pos) {
483436c6c9cSStella Laurenzo     return PyIntegerSetConstraint(set, pos);
484436c6c9cSStella Laurenzo   }
485436c6c9cSStella Laurenzo 
486436c6c9cSStella Laurenzo   PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length,
487436c6c9cSStella Laurenzo                                    intptr_t step) {
488436c6c9cSStella Laurenzo     return PyIntegerSetConstraintList(set, startIndex, length, step);
489436c6c9cSStella Laurenzo   }
490436c6c9cSStella Laurenzo 
491436c6c9cSStella Laurenzo   PyIntegerSet set;
492436c6c9cSStella Laurenzo };
493436c6c9cSStella Laurenzo } // namespace
494436c6c9cSStella Laurenzo 
495e6d738e0SRahul Kayaith bool PyIntegerSet::operator==(const PyIntegerSet &other) const {
496436c6c9cSStella Laurenzo   return mlirIntegerSetEqual(integerSet, other.integerSet);
497436c6c9cSStella Laurenzo }
498436c6c9cSStella Laurenzo 
499b56d1ec6SPeter Hawkins nb::object PyIntegerSet::getCapsule() {
500b56d1ec6SPeter Hawkins   return nb::steal<nb::object>(mlirPythonIntegerSetToCapsule(*this));
501436c6c9cSStella Laurenzo }
502436c6c9cSStella Laurenzo 
503b56d1ec6SPeter Hawkins PyIntegerSet PyIntegerSet::createFromCapsule(nb::object capsule) {
504436c6c9cSStella Laurenzo   MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr());
505436c6c9cSStella Laurenzo   if (mlirIntegerSetIsNull(rawIntegerSet))
506b56d1ec6SPeter Hawkins     throw nb::python_error();
507436c6c9cSStella Laurenzo   return PyIntegerSet(
508436c6c9cSStella Laurenzo       PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)),
509436c6c9cSStella Laurenzo       rawIntegerSet);
510436c6c9cSStella Laurenzo }
511436c6c9cSStella Laurenzo 
512b56d1ec6SPeter Hawkins void mlir::python::populateIRAffine(nb::module_ &m) {
513436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
514436c6c9cSStella Laurenzo   // Mapping of PyAffineExpr and derived classes.
515436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
516b56d1ec6SPeter Hawkins   nb::class_<PyAffineExpr>(m, "AffineExpr")
517b56d1ec6SPeter Hawkins       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAffineExpr::getCapsule)
518436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule)
519fc7594ccSAlex Zinenko       .def("__add__", &PyAffineAddExpr::get)
520fc7594ccSAlex Zinenko       .def("__add__", &PyAffineAddExpr::getRHSConstant)
521fc7594ccSAlex Zinenko       .def("__radd__", &PyAffineAddExpr::getRHSConstant)
522fc7594ccSAlex Zinenko       .def("__mul__", &PyAffineMulExpr::get)
523fc7594ccSAlex Zinenko       .def("__mul__", &PyAffineMulExpr::getRHSConstant)
524fc7594ccSAlex Zinenko       .def("__rmul__", &PyAffineMulExpr::getRHSConstant)
525fc7594ccSAlex Zinenko       .def("__mod__", &PyAffineModExpr::get)
526fc7594ccSAlex Zinenko       .def("__mod__", &PyAffineModExpr::getRHSConstant)
527fc7594ccSAlex Zinenko       .def("__rmod__",
528fc7594ccSAlex Zinenko            [](PyAffineExpr &self, intptr_t other) {
529fc7594ccSAlex Zinenko              return PyAffineModExpr::get(
530fc7594ccSAlex Zinenko                  PyAffineConstantExpr::get(other, *self.getContext().get()),
531fc7594ccSAlex Zinenko                  self);
532436c6c9cSStella Laurenzo            })
533436c6c9cSStella Laurenzo       .def("__sub__",
534436c6c9cSStella Laurenzo            [](PyAffineExpr &self, PyAffineExpr &other) {
535436c6c9cSStella Laurenzo              auto negOne =
536436c6c9cSStella Laurenzo                  PyAffineConstantExpr::get(-1, *self.getContext().get());
537436c6c9cSStella Laurenzo              return PyAffineAddExpr::get(self,
538436c6c9cSStella Laurenzo                                          PyAffineMulExpr::get(negOne, other));
539436c6c9cSStella Laurenzo            })
540fc7594ccSAlex Zinenko       .def("__sub__",
541fc7594ccSAlex Zinenko            [](PyAffineExpr &self, intptr_t other) {
542fc7594ccSAlex Zinenko              return PyAffineAddExpr::get(
543fc7594ccSAlex Zinenko                  self,
544fc7594ccSAlex Zinenko                  PyAffineConstantExpr::get(-other, *self.getContext().get()));
545fc7594ccSAlex Zinenko            })
546fc7594ccSAlex Zinenko       .def("__rsub__",
547fc7594ccSAlex Zinenko            [](PyAffineExpr &self, intptr_t other) {
548fc7594ccSAlex Zinenko              return PyAffineAddExpr::getLHSConstant(
549fc7594ccSAlex Zinenko                  other, PyAffineMulExpr::getLHSConstant(-1, self));
550fc7594ccSAlex Zinenko            })
551436c6c9cSStella Laurenzo       .def("__eq__", [](PyAffineExpr &self,
552436c6c9cSStella Laurenzo                         PyAffineExpr &other) { return self == other; })
553436c6c9cSStella Laurenzo       .def("__eq__",
554b56d1ec6SPeter Hawkins            [](PyAffineExpr &self, nb::object &other) { return false; })
555436c6c9cSStella Laurenzo       .def("__str__",
556436c6c9cSStella Laurenzo            [](PyAffineExpr &self) {
557436c6c9cSStella Laurenzo              PyPrintAccumulator printAccum;
558436c6c9cSStella Laurenzo              mlirAffineExprPrint(self, printAccum.getCallback(),
559436c6c9cSStella Laurenzo                                  printAccum.getUserData());
560436c6c9cSStella Laurenzo              return printAccum.join();
561436c6c9cSStella Laurenzo            })
562436c6c9cSStella Laurenzo       .def("__repr__",
563436c6c9cSStella Laurenzo            [](PyAffineExpr &self) {
564436c6c9cSStella Laurenzo              PyPrintAccumulator printAccum;
565436c6c9cSStella Laurenzo              printAccum.parts.append("AffineExpr(");
566436c6c9cSStella Laurenzo              mlirAffineExprPrint(self, printAccum.getCallback(),
567436c6c9cSStella Laurenzo                                  printAccum.getUserData());
568436c6c9cSStella Laurenzo              printAccum.parts.append(")");
569436c6c9cSStella Laurenzo              return printAccum.join();
570436c6c9cSStella Laurenzo            })
571fc7594ccSAlex Zinenko       .def("__hash__",
572fc7594ccSAlex Zinenko            [](PyAffineExpr &self) {
573fc7594ccSAlex Zinenko              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
574fc7594ccSAlex Zinenko            })
575b56d1ec6SPeter Hawkins       .def_prop_ro(
576436c6c9cSStella Laurenzo           "context",
577436c6c9cSStella Laurenzo           [](PyAffineExpr &self) { return self.getContext().getObject(); })
578fc7594ccSAlex Zinenko       .def("compose",
579fc7594ccSAlex Zinenko            [](PyAffineExpr &self, PyAffineMap &other) {
580fc7594ccSAlex Zinenko              return PyAffineExpr(self.getContext(),
581fc7594ccSAlex Zinenko                                  mlirAffineExprCompose(self, other));
582fc7594ccSAlex Zinenko            })
583436c6c9cSStella Laurenzo       .def_static(
584436c6c9cSStella Laurenzo           "get_add", &PyAffineAddExpr::get,
585436c6c9cSStella Laurenzo           "Gets an affine expression containing a sum of two expressions.")
586fc7594ccSAlex Zinenko       .def_static("get_add", &PyAffineAddExpr::getLHSConstant,
587fc7594ccSAlex Zinenko                   "Gets an affine expression containing a sum of a constant "
588fc7594ccSAlex Zinenko                   "and another expression.")
589fc7594ccSAlex Zinenko       .def_static("get_add", &PyAffineAddExpr::getRHSConstant,
590fc7594ccSAlex Zinenko                   "Gets an affine expression containing a sum of an expression "
591fc7594ccSAlex Zinenko                   "and a constant.")
592436c6c9cSStella Laurenzo       .def_static(
593436c6c9cSStella Laurenzo           "get_mul", &PyAffineMulExpr::get,
594436c6c9cSStella Laurenzo           "Gets an affine expression containing a product of two expressions.")
595fc7594ccSAlex Zinenko       .def_static("get_mul", &PyAffineMulExpr::getLHSConstant,
596fc7594ccSAlex Zinenko                   "Gets an affine expression containing a product of a "
597fc7594ccSAlex Zinenko                   "constant and another expression.")
598fc7594ccSAlex Zinenko       .def_static("get_mul", &PyAffineMulExpr::getRHSConstant,
599fc7594ccSAlex Zinenko                   "Gets an affine expression containing a product of an "
600fc7594ccSAlex Zinenko                   "expression and a constant.")
601436c6c9cSStella Laurenzo       .def_static("get_mod", &PyAffineModExpr::get,
602436c6c9cSStella Laurenzo                   "Gets an affine expression containing the modulo of dividing "
603436c6c9cSStella Laurenzo                   "one expression by another.")
604fc7594ccSAlex Zinenko       .def_static("get_mod", &PyAffineModExpr::getLHSConstant,
605fc7594ccSAlex Zinenko                   "Gets a semi-affine expression containing the modulo of "
606fc7594ccSAlex Zinenko                   "dividing a constant by an expression.")
607fc7594ccSAlex Zinenko       .def_static("get_mod", &PyAffineModExpr::getRHSConstant,
608fc7594ccSAlex Zinenko                   "Gets an affine expression containing the module of dividing"
609fc7594ccSAlex Zinenko                   "an expression by a constant.")
610436c6c9cSStella Laurenzo       .def_static("get_floor_div", &PyAffineFloorDivExpr::get,
611436c6c9cSStella Laurenzo                   "Gets an affine expression containing the rounded-down "
612436c6c9cSStella Laurenzo                   "result of dividing one expression by another.")
613fc7594ccSAlex Zinenko       .def_static("get_floor_div", &PyAffineFloorDivExpr::getLHSConstant,
614fc7594ccSAlex Zinenko                   "Gets a semi-affine expression containing the rounded-down "
615fc7594ccSAlex Zinenko                   "result of dividing a constant by an expression.")
616fc7594ccSAlex Zinenko       .def_static("get_floor_div", &PyAffineFloorDivExpr::getRHSConstant,
617fc7594ccSAlex Zinenko                   "Gets an affine expression containing the rounded-down "
618fc7594ccSAlex Zinenko                   "result of dividing an expression by a constant.")
619436c6c9cSStella Laurenzo       .def_static("get_ceil_div", &PyAffineCeilDivExpr::get,
620436c6c9cSStella Laurenzo                   "Gets an affine expression containing the rounded-up result "
621436c6c9cSStella Laurenzo                   "of dividing one expression by another.")
622fc7594ccSAlex Zinenko       .def_static("get_ceil_div", &PyAffineCeilDivExpr::getLHSConstant,
623fc7594ccSAlex Zinenko                   "Gets a semi-affine expression containing the rounded-up "
624fc7594ccSAlex Zinenko                   "result of dividing a constant by an expression.")
625fc7594ccSAlex Zinenko       .def_static("get_ceil_div", &PyAffineCeilDivExpr::getRHSConstant,
626fc7594ccSAlex Zinenko                   "Gets an affine expression containing the rounded-up result "
627fc7594ccSAlex Zinenko                   "of dividing an expression by a constant.")
628b56d1ec6SPeter Hawkins       .def_static("get_constant", &PyAffineConstantExpr::get, nb::arg("value"),
629b56d1ec6SPeter Hawkins                   nb::arg("context").none() = nb::none(),
630436c6c9cSStella Laurenzo                   "Gets a constant affine expression with the given value.")
631436c6c9cSStella Laurenzo       .def_static(
632b56d1ec6SPeter Hawkins           "get_dim", &PyAffineDimExpr::get, nb::arg("position"),
633b56d1ec6SPeter Hawkins           nb::arg("context").none() = nb::none(),
634436c6c9cSStella Laurenzo           "Gets an affine expression of a dimension at the given position.")
635436c6c9cSStella Laurenzo       .def_static(
636b56d1ec6SPeter Hawkins           "get_symbol", &PyAffineSymbolExpr::get, nb::arg("position"),
637b56d1ec6SPeter Hawkins           nb::arg("context").none() = nb::none(),
638436c6c9cSStella Laurenzo           "Gets an affine expression of a symbol at the given position.")
639436c6c9cSStella Laurenzo       .def(
640436c6c9cSStella Laurenzo           "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); },
641436c6c9cSStella Laurenzo           kDumpDocstring);
642436c6c9cSStella Laurenzo   PyAffineConstantExpr::bind(m);
643436c6c9cSStella Laurenzo   PyAffineDimExpr::bind(m);
644436c6c9cSStella Laurenzo   PyAffineSymbolExpr::bind(m);
645436c6c9cSStella Laurenzo   PyAffineBinaryExpr::bind(m);
646436c6c9cSStella Laurenzo   PyAffineAddExpr::bind(m);
647436c6c9cSStella Laurenzo   PyAffineMulExpr::bind(m);
648436c6c9cSStella Laurenzo   PyAffineModExpr::bind(m);
649436c6c9cSStella Laurenzo   PyAffineFloorDivExpr::bind(m);
650436c6c9cSStella Laurenzo   PyAffineCeilDivExpr::bind(m);
651436c6c9cSStella Laurenzo 
652436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
653436c6c9cSStella Laurenzo   // Mapping of PyAffineMap.
654436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
655b56d1ec6SPeter Hawkins   nb::class_<PyAffineMap>(m, "AffineMap")
656b56d1ec6SPeter Hawkins       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAffineMap::getCapsule)
657436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule)
658436c6c9cSStella Laurenzo       .def("__eq__",
659436c6c9cSStella Laurenzo            [](PyAffineMap &self, PyAffineMap &other) { return self == other; })
660b56d1ec6SPeter Hawkins       .def("__eq__", [](PyAffineMap &self, nb::object &other) { return false; })
661436c6c9cSStella Laurenzo       .def("__str__",
662436c6c9cSStella Laurenzo            [](PyAffineMap &self) {
663436c6c9cSStella Laurenzo              PyPrintAccumulator printAccum;
664436c6c9cSStella Laurenzo              mlirAffineMapPrint(self, printAccum.getCallback(),
665436c6c9cSStella Laurenzo                                 printAccum.getUserData());
666436c6c9cSStella Laurenzo              return printAccum.join();
667436c6c9cSStella Laurenzo            })
668436c6c9cSStella Laurenzo       .def("__repr__",
669436c6c9cSStella Laurenzo            [](PyAffineMap &self) {
670436c6c9cSStella Laurenzo              PyPrintAccumulator printAccum;
671436c6c9cSStella Laurenzo              printAccum.parts.append("AffineMap(");
672436c6c9cSStella Laurenzo              mlirAffineMapPrint(self, printAccum.getCallback(),
673436c6c9cSStella Laurenzo                                 printAccum.getUserData());
674436c6c9cSStella Laurenzo              printAccum.parts.append(")");
675436c6c9cSStella Laurenzo              return printAccum.join();
676436c6c9cSStella Laurenzo            })
677fc7594ccSAlex Zinenko       .def("__hash__",
678fc7594ccSAlex Zinenko            [](PyAffineMap &self) {
679fc7594ccSAlex Zinenko              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
680fc7594ccSAlex Zinenko            })
681335d2df5SNicolas Vasilache       .def_static("compress_unused_symbols",
682b56d1ec6SPeter Hawkins                   [](nb::list affineMaps, DefaultingPyMlirContext context) {
683335d2df5SNicolas Vasilache                     SmallVector<MlirAffineMap> maps;
684335d2df5SNicolas Vasilache                     pyListToVector<PyAffineMap, MlirAffineMap>(
685335d2df5SNicolas Vasilache                         affineMaps, maps, "attempting to create an AffineMap");
686335d2df5SNicolas Vasilache                     std::vector<MlirAffineMap> compressed(affineMaps.size());
687335d2df5SNicolas Vasilache                     auto populate = [](void *result, intptr_t idx,
688335d2df5SNicolas Vasilache                                        MlirAffineMap m) {
689335d2df5SNicolas Vasilache                       static_cast<MlirAffineMap *>(result)[idx] = (m);
690335d2df5SNicolas Vasilache                     };
691335d2df5SNicolas Vasilache                     mlirAffineMapCompressUnusedSymbols(
692335d2df5SNicolas Vasilache                         maps.data(), maps.size(), compressed.data(), populate);
693335d2df5SNicolas Vasilache                     std::vector<PyAffineMap> res;
694e2f16be5SMehdi Amini                     res.reserve(compressed.size());
695335d2df5SNicolas Vasilache                     for (auto m : compressed)
696e5639b3fSMehdi Amini                       res.emplace_back(context->getRef(), m);
697335d2df5SNicolas Vasilache                     return res;
698335d2df5SNicolas Vasilache                   })
699b56d1ec6SPeter Hawkins       .def_prop_ro(
700436c6c9cSStella Laurenzo           "context",
701436c6c9cSStella Laurenzo           [](PyAffineMap &self) { return self.getContext().getObject(); },
702436c6c9cSStella Laurenzo           "Context that owns the Affine Map")
703436c6c9cSStella Laurenzo       .def(
704436c6c9cSStella Laurenzo           "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); },
705436c6c9cSStella Laurenzo           kDumpDocstring)
706436c6c9cSStella Laurenzo       .def_static(
707436c6c9cSStella Laurenzo           "get",
708b56d1ec6SPeter Hawkins           [](intptr_t dimCount, intptr_t symbolCount, nb::list exprs,
709436c6c9cSStella Laurenzo              DefaultingPyMlirContext context) {
710436c6c9cSStella Laurenzo             SmallVector<MlirAffineExpr> affineExprs;
711436c6c9cSStella Laurenzo             pyListToVector<PyAffineExpr, MlirAffineExpr>(
712337c937dSMehdi Amini                 exprs, affineExprs, "attempting to create an AffineMap");
713436c6c9cSStella Laurenzo             MlirAffineMap map =
714436c6c9cSStella Laurenzo                 mlirAffineMapGet(context->get(), dimCount, symbolCount,
715436c6c9cSStella Laurenzo                                  affineExprs.size(), affineExprs.data());
716436c6c9cSStella Laurenzo             return PyAffineMap(context->getRef(), map);
717436c6c9cSStella Laurenzo           },
718b56d1ec6SPeter Hawkins           nb::arg("dim_count"), nb::arg("symbol_count"), nb::arg("exprs"),
719b56d1ec6SPeter Hawkins           nb::arg("context").none() = nb::none(),
720436c6c9cSStella Laurenzo           "Gets a map with the given expressions as results.")
721436c6c9cSStella Laurenzo       .def_static(
722436c6c9cSStella Laurenzo           "get_constant",
723436c6c9cSStella Laurenzo           [](intptr_t value, DefaultingPyMlirContext context) {
724436c6c9cSStella Laurenzo             MlirAffineMap affineMap =
725436c6c9cSStella Laurenzo                 mlirAffineMapConstantGet(context->get(), value);
726436c6c9cSStella Laurenzo             return PyAffineMap(context->getRef(), affineMap);
727436c6c9cSStella Laurenzo           },
728b56d1ec6SPeter Hawkins           nb::arg("value"), nb::arg("context").none() = nb::none(),
729436c6c9cSStella Laurenzo           "Gets an affine map with a single constant result")
730436c6c9cSStella Laurenzo       .def_static(
731436c6c9cSStella Laurenzo           "get_empty",
732436c6c9cSStella Laurenzo           [](DefaultingPyMlirContext context) {
733436c6c9cSStella Laurenzo             MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get());
734436c6c9cSStella Laurenzo             return PyAffineMap(context->getRef(), affineMap);
735436c6c9cSStella Laurenzo           },
736b56d1ec6SPeter Hawkins           nb::arg("context").none() = nb::none(), "Gets an empty affine map.")
737436c6c9cSStella Laurenzo       .def_static(
738436c6c9cSStella Laurenzo           "get_identity",
739436c6c9cSStella Laurenzo           [](intptr_t nDims, DefaultingPyMlirContext context) {
740436c6c9cSStella Laurenzo             MlirAffineMap affineMap =
741436c6c9cSStella Laurenzo                 mlirAffineMapMultiDimIdentityGet(context->get(), nDims);
742436c6c9cSStella Laurenzo             return PyAffineMap(context->getRef(), affineMap);
743436c6c9cSStella Laurenzo           },
744b56d1ec6SPeter Hawkins           nb::arg("n_dims"), nb::arg("context").none() = nb::none(),
745436c6c9cSStella Laurenzo           "Gets an identity map with the given number of dimensions.")
746436c6c9cSStella Laurenzo       .def_static(
747436c6c9cSStella Laurenzo           "get_minor_identity",
748436c6c9cSStella Laurenzo           [](intptr_t nDims, intptr_t nResults,
749436c6c9cSStella Laurenzo              DefaultingPyMlirContext context) {
750436c6c9cSStella Laurenzo             MlirAffineMap affineMap =
751436c6c9cSStella Laurenzo                 mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults);
752436c6c9cSStella Laurenzo             return PyAffineMap(context->getRef(), affineMap);
753436c6c9cSStella Laurenzo           },
754b56d1ec6SPeter Hawkins           nb::arg("n_dims"), nb::arg("n_results"),
755b56d1ec6SPeter Hawkins           nb::arg("context").none() = nb::none(),
756436c6c9cSStella Laurenzo           "Gets a minor identity map with the given number of dimensions and "
757436c6c9cSStella Laurenzo           "results.")
758436c6c9cSStella Laurenzo       .def_static(
759436c6c9cSStella Laurenzo           "get_permutation",
760436c6c9cSStella Laurenzo           [](std::vector<unsigned> permutation,
761436c6c9cSStella Laurenzo              DefaultingPyMlirContext context) {
762436c6c9cSStella Laurenzo             if (!isPermutation(permutation))
763b56d1ec6SPeter Hawkins               throw std::runtime_error("Invalid permutation when attempting to "
764436c6c9cSStella Laurenzo                                        "create an AffineMap");
765436c6c9cSStella Laurenzo             MlirAffineMap affineMap = mlirAffineMapPermutationGet(
766436c6c9cSStella Laurenzo                 context->get(), permutation.size(), permutation.data());
767436c6c9cSStella Laurenzo             return PyAffineMap(context->getRef(), affineMap);
768436c6c9cSStella Laurenzo           },
769b56d1ec6SPeter Hawkins           nb::arg("permutation"), nb::arg("context").none() = nb::none(),
770436c6c9cSStella Laurenzo           "Gets an affine map that permutes its inputs.")
771a6e7d024SStella Laurenzo       .def(
772a6e7d024SStella Laurenzo           "get_submap",
773436c6c9cSStella Laurenzo           [](PyAffineMap &self, std::vector<intptr_t> &resultPos) {
774436c6c9cSStella Laurenzo             intptr_t numResults = mlirAffineMapGetNumResults(self);
775436c6c9cSStella Laurenzo             for (intptr_t pos : resultPos) {
776436c6c9cSStella Laurenzo               if (pos < 0 || pos >= numResults)
777b56d1ec6SPeter Hawkins                 throw nb::value_error("result position out of bounds");
778436c6c9cSStella Laurenzo             }
779436c6c9cSStella Laurenzo             MlirAffineMap affineMap = mlirAffineMapGetSubMap(
780436c6c9cSStella Laurenzo                 self, resultPos.size(), resultPos.data());
781436c6c9cSStella Laurenzo             return PyAffineMap(self.getContext(), affineMap);
782a6e7d024SStella Laurenzo           },
783b56d1ec6SPeter Hawkins           nb::arg("result_positions"))
784a6e7d024SStella Laurenzo       .def(
785a6e7d024SStella Laurenzo           "get_major_submap",
786436c6c9cSStella Laurenzo           [](PyAffineMap &self, intptr_t nResults) {
787436c6c9cSStella Laurenzo             if (nResults >= mlirAffineMapGetNumResults(self))
788b56d1ec6SPeter Hawkins               throw nb::value_error("number of results out of bounds");
789436c6c9cSStella Laurenzo             MlirAffineMap affineMap =
790436c6c9cSStella Laurenzo                 mlirAffineMapGetMajorSubMap(self, nResults);
791436c6c9cSStella Laurenzo             return PyAffineMap(self.getContext(), affineMap);
792a6e7d024SStella Laurenzo           },
793b56d1ec6SPeter Hawkins           nb::arg("n_results"))
794a6e7d024SStella Laurenzo       .def(
795a6e7d024SStella Laurenzo           "get_minor_submap",
796436c6c9cSStella Laurenzo           [](PyAffineMap &self, intptr_t nResults) {
797436c6c9cSStella Laurenzo             if (nResults >= mlirAffineMapGetNumResults(self))
798b56d1ec6SPeter Hawkins               throw nb::value_error("number of results out of bounds");
799436c6c9cSStella Laurenzo             MlirAffineMap affineMap =
800436c6c9cSStella Laurenzo                 mlirAffineMapGetMinorSubMap(self, nResults);
801436c6c9cSStella Laurenzo             return PyAffineMap(self.getContext(), affineMap);
802a6e7d024SStella Laurenzo           },
803b56d1ec6SPeter Hawkins           nb::arg("n_results"))
804a6e7d024SStella Laurenzo       .def(
805a6e7d024SStella Laurenzo           "replace",
80631f888eaSTobias Gysi           [](PyAffineMap &self, PyAffineExpr &expression,
80731f888eaSTobias Gysi              PyAffineExpr &replacement, intptr_t numResultDims,
80831f888eaSTobias Gysi              intptr_t numResultSyms) {
80931f888eaSTobias Gysi             MlirAffineMap affineMap = mlirAffineMapReplace(
81031f888eaSTobias Gysi                 self, expression, replacement, numResultDims, numResultSyms);
81131f888eaSTobias Gysi             return PyAffineMap(self.getContext(), affineMap);
812a6e7d024SStella Laurenzo           },
813b56d1ec6SPeter Hawkins           nb::arg("expr"), nb::arg("replacement"), nb::arg("n_result_dims"),
814b56d1ec6SPeter Hawkins           nb::arg("n_result_syms"))
815b56d1ec6SPeter Hawkins       .def_prop_ro(
816436c6c9cSStella Laurenzo           "is_permutation",
817436c6c9cSStella Laurenzo           [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); })
818b56d1ec6SPeter Hawkins       .def_prop_ro("is_projected_permutation",
819436c6c9cSStella Laurenzo                    [](PyAffineMap &self) {
820436c6c9cSStella Laurenzo                      return mlirAffineMapIsProjectedPermutation(self);
821436c6c9cSStella Laurenzo                    })
822b56d1ec6SPeter Hawkins       .def_prop_ro(
823436c6c9cSStella Laurenzo           "n_dims",
824436c6c9cSStella Laurenzo           [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); })
825b56d1ec6SPeter Hawkins       .def_prop_ro(
826436c6c9cSStella Laurenzo           "n_inputs",
827436c6c9cSStella Laurenzo           [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); })
828b56d1ec6SPeter Hawkins       .def_prop_ro(
829436c6c9cSStella Laurenzo           "n_symbols",
830436c6c9cSStella Laurenzo           [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); })
831b56d1ec6SPeter Hawkins       .def_prop_ro("results",
832b56d1ec6SPeter Hawkins                    [](PyAffineMap &self) { return PyAffineMapExprList(self); });
833436c6c9cSStella Laurenzo   PyAffineMapExprList::bind(m);
834436c6c9cSStella Laurenzo 
835436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
836436c6c9cSStella Laurenzo   // Mapping of PyIntegerSet.
837436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
838b56d1ec6SPeter Hawkins   nb::class_<PyIntegerSet>(m, "IntegerSet")
839b56d1ec6SPeter Hawkins       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyIntegerSet::getCapsule)
840436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule)
841436c6c9cSStella Laurenzo       .def("__eq__", [](PyIntegerSet &self,
842436c6c9cSStella Laurenzo                         PyIntegerSet &other) { return self == other; })
843b56d1ec6SPeter Hawkins       .def("__eq__", [](PyIntegerSet &self, nb::object other) { return false; })
844436c6c9cSStella Laurenzo       .def("__str__",
845436c6c9cSStella Laurenzo            [](PyIntegerSet &self) {
846436c6c9cSStella Laurenzo              PyPrintAccumulator printAccum;
847436c6c9cSStella Laurenzo              mlirIntegerSetPrint(self, printAccum.getCallback(),
848436c6c9cSStella Laurenzo                                  printAccum.getUserData());
849436c6c9cSStella Laurenzo              return printAccum.join();
850436c6c9cSStella Laurenzo            })
851436c6c9cSStella Laurenzo       .def("__repr__",
852436c6c9cSStella Laurenzo            [](PyIntegerSet &self) {
853436c6c9cSStella Laurenzo              PyPrintAccumulator printAccum;
854436c6c9cSStella Laurenzo              printAccum.parts.append("IntegerSet(");
855436c6c9cSStella Laurenzo              mlirIntegerSetPrint(self, printAccum.getCallback(),
856436c6c9cSStella Laurenzo                                  printAccum.getUserData());
857436c6c9cSStella Laurenzo              printAccum.parts.append(")");
858436c6c9cSStella Laurenzo              return printAccum.join();
859436c6c9cSStella Laurenzo            })
860fc7594ccSAlex Zinenko       .def("__hash__",
861fc7594ccSAlex Zinenko            [](PyIntegerSet &self) {
862fc7594ccSAlex Zinenko              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
863fc7594ccSAlex Zinenko            })
864b56d1ec6SPeter Hawkins       .def_prop_ro(
865436c6c9cSStella Laurenzo           "context",
866436c6c9cSStella Laurenzo           [](PyIntegerSet &self) { return self.getContext().getObject(); })
867436c6c9cSStella Laurenzo       .def(
868436c6c9cSStella Laurenzo           "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); },
869436c6c9cSStella Laurenzo           kDumpDocstring)
870436c6c9cSStella Laurenzo       .def_static(
871436c6c9cSStella Laurenzo           "get",
872b56d1ec6SPeter Hawkins           [](intptr_t numDims, intptr_t numSymbols, nb::list exprs,
873436c6c9cSStella Laurenzo              std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
874436c6c9cSStella Laurenzo             if (exprs.size() != eqFlags.size())
875b56d1ec6SPeter Hawkins               throw nb::value_error(
876436c6c9cSStella Laurenzo                   "Expected the number of constraints to match "
877436c6c9cSStella Laurenzo                   "that of equality flags");
878b56d1ec6SPeter Hawkins             if (exprs.size() == 0)
879b56d1ec6SPeter Hawkins               throw nb::value_error("Expected non-empty list of constraints");
880436c6c9cSStella Laurenzo 
881436c6c9cSStella Laurenzo             // Copy over to a SmallVector because std::vector has a
882436c6c9cSStella Laurenzo             // specialization for booleans that packs data and does not
883436c6c9cSStella Laurenzo             // expose a `bool *`.
884436c6c9cSStella Laurenzo             SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end());
885436c6c9cSStella Laurenzo 
886436c6c9cSStella Laurenzo             SmallVector<MlirAffineExpr> affineExprs;
887436c6c9cSStella Laurenzo             pyListToVector<PyAffineExpr>(exprs, affineExprs,
888436c6c9cSStella Laurenzo                                          "attempting to create an IntegerSet");
889436c6c9cSStella Laurenzo             MlirIntegerSet set = mlirIntegerSetGet(
890436c6c9cSStella Laurenzo                 context->get(), numDims, numSymbols, exprs.size(),
891436c6c9cSStella Laurenzo                 affineExprs.data(), flags.data());
892436c6c9cSStella Laurenzo             return PyIntegerSet(context->getRef(), set);
893436c6c9cSStella Laurenzo           },
894b56d1ec6SPeter Hawkins           nb::arg("num_dims"), nb::arg("num_symbols"), nb::arg("exprs"),
895b56d1ec6SPeter Hawkins           nb::arg("eq_flags"), nb::arg("context").none() = nb::none())
896436c6c9cSStella Laurenzo       .def_static(
897436c6c9cSStella Laurenzo           "get_empty",
898436c6c9cSStella Laurenzo           [](intptr_t numDims, intptr_t numSymbols,
899436c6c9cSStella Laurenzo              DefaultingPyMlirContext context) {
900436c6c9cSStella Laurenzo             MlirIntegerSet set =
901436c6c9cSStella Laurenzo                 mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols);
902436c6c9cSStella Laurenzo             return PyIntegerSet(context->getRef(), set);
903436c6c9cSStella Laurenzo           },
904b56d1ec6SPeter Hawkins           nb::arg("num_dims"), nb::arg("num_symbols"),
905b56d1ec6SPeter Hawkins           nb::arg("context").none() = nb::none())
906a6e7d024SStella Laurenzo       .def(
907a6e7d024SStella Laurenzo           "get_replaced",
908b56d1ec6SPeter Hawkins           [](PyIntegerSet &self, nb::list dimExprs, nb::list symbolExprs,
909436c6c9cSStella Laurenzo              intptr_t numResultDims, intptr_t numResultSymbols) {
910436c6c9cSStella Laurenzo             if (static_cast<intptr_t>(dimExprs.size()) !=
911436c6c9cSStella Laurenzo                 mlirIntegerSetGetNumDims(self))
912b56d1ec6SPeter Hawkins               throw nb::value_error(
913436c6c9cSStella Laurenzo                   "Expected the number of dimension replacement expressions "
914436c6c9cSStella Laurenzo                   "to match that of dimensions");
915436c6c9cSStella Laurenzo             if (static_cast<intptr_t>(symbolExprs.size()) !=
916436c6c9cSStella Laurenzo                 mlirIntegerSetGetNumSymbols(self))
917b56d1ec6SPeter Hawkins               throw nb::value_error(
918436c6c9cSStella Laurenzo                   "Expected the number of symbol replacement expressions "
919436c6c9cSStella Laurenzo                   "to match that of symbols");
920436c6c9cSStella Laurenzo 
921436c6c9cSStella Laurenzo             SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs;
922436c6c9cSStella Laurenzo             pyListToVector<PyAffineExpr>(
923436c6c9cSStella Laurenzo                 dimExprs, dimAffineExprs,
924436c6c9cSStella Laurenzo                 "attempting to create an IntegerSet by replacing dimensions");
925436c6c9cSStella Laurenzo             pyListToVector<PyAffineExpr>(
926436c6c9cSStella Laurenzo                 symbolExprs, symbolAffineExprs,
927436c6c9cSStella Laurenzo                 "attempting to create an IntegerSet by replacing symbols");
928436c6c9cSStella Laurenzo             MlirIntegerSet set = mlirIntegerSetReplaceGet(
929436c6c9cSStella Laurenzo                 self, dimAffineExprs.data(), symbolAffineExprs.data(),
930436c6c9cSStella Laurenzo                 numResultDims, numResultSymbols);
931436c6c9cSStella Laurenzo             return PyIntegerSet(self.getContext(), set);
932a6e7d024SStella Laurenzo           },
933b56d1ec6SPeter Hawkins           nb::arg("dim_exprs"), nb::arg("symbol_exprs"),
934b56d1ec6SPeter Hawkins           nb::arg("num_result_dims"), nb::arg("num_result_symbols"))
935b56d1ec6SPeter Hawkins       .def_prop_ro("is_canonical_empty",
936436c6c9cSStella Laurenzo                    [](PyIntegerSet &self) {
937436c6c9cSStella Laurenzo                      return mlirIntegerSetIsCanonicalEmpty(self);
938436c6c9cSStella Laurenzo                    })
939b56d1ec6SPeter Hawkins       .def_prop_ro(
940436c6c9cSStella Laurenzo           "n_dims",
941436c6c9cSStella Laurenzo           [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); })
942b56d1ec6SPeter Hawkins       .def_prop_ro(
943436c6c9cSStella Laurenzo           "n_symbols",
944436c6c9cSStella Laurenzo           [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); })
945b56d1ec6SPeter Hawkins       .def_prop_ro(
946436c6c9cSStella Laurenzo           "n_inputs",
947436c6c9cSStella Laurenzo           [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); })
948b56d1ec6SPeter Hawkins       .def_prop_ro("n_equalities",
949436c6c9cSStella Laurenzo                    [](PyIntegerSet &self) {
950436c6c9cSStella Laurenzo                      return mlirIntegerSetGetNumEqualities(self);
951436c6c9cSStella Laurenzo                    })
952b56d1ec6SPeter Hawkins       .def_prop_ro("n_inequalities",
953436c6c9cSStella Laurenzo                    [](PyIntegerSet &self) {
954436c6c9cSStella Laurenzo                      return mlirIntegerSetGetNumInequalities(self);
955436c6c9cSStella Laurenzo                    })
956b56d1ec6SPeter Hawkins       .def_prop_ro("constraints", [](PyIntegerSet &self) {
957436c6c9cSStella Laurenzo         return PyIntegerSetConstraintList(self);
958436c6c9cSStella Laurenzo       });
959436c6c9cSStella Laurenzo   PyIntegerSetConstraint::bind(m);
960436c6c9cSStella Laurenzo   PyIntegerSetConstraintList::bind(m);
961436c6c9cSStella Laurenzo }
962