xref: /llvm-project/mlir/lib/Bindings/Python/IRAffine.cpp (revision 285a229f205ae67dca48c8eac8206a115320c677)
1436c6c9cSStella Laurenzo //===- IRAffine.cpp - Exports 'ir' module affine related bindings ---------===//
2436c6c9cSStella Laurenzo //
3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6436c6c9cSStella Laurenzo //
7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===//
8436c6c9cSStella Laurenzo 
9*285a229fSMehdi Amini #include <cstddef>
10*285a229fSMehdi Amini #include <cstdint>
11*285a229fSMehdi Amini #include <pybind11/cast.h>
12*285a229fSMehdi Amini #include <pybind11/detail/common.h>
13*285a229fSMehdi Amini #include <pybind11/pybind11.h>
14*285a229fSMehdi Amini #include <pybind11/pytypes.h>
15*285a229fSMehdi Amini #include <string>
161fc096afSMehdi Amini #include <utility>
17*285a229fSMehdi Amini #include <vector>
181fc096afSMehdi Amini 
19436c6c9cSStella Laurenzo #include "IRModule.h"
20436c6c9cSStella Laurenzo 
21436c6c9cSStella Laurenzo #include "PybindUtils.h"
22436c6c9cSStella Laurenzo 
23*285a229fSMehdi Amini #include "mlir-c/AffineExpr.h"
24436c6c9cSStella Laurenzo #include "mlir-c/AffineMap.h"
25436c6c9cSStella Laurenzo #include "mlir-c/Bindings/Python/Interop.h"
26436c6c9cSStella Laurenzo #include "mlir-c/IntegerSet.h"
27*285a229fSMehdi Amini #include "mlir/Support/LLVM.h"
282233c4dcSBenjamin Kramer #include "llvm/ADT/Hashing.h"
29*285a229fSMehdi Amini #include "llvm/ADT/SmallVector.h"
30*285a229fSMehdi Amini #include "llvm/ADT/StringRef.h"
31*285a229fSMehdi Amini #include "llvm/ADT/Twine.h"
32436c6c9cSStella Laurenzo 
33436c6c9cSStella Laurenzo namespace py = pybind11;
34436c6c9cSStella Laurenzo using namespace mlir;
35436c6c9cSStella Laurenzo using namespace mlir::python;
36436c6c9cSStella Laurenzo 
37436c6c9cSStella Laurenzo using llvm::SmallVector;
38436c6c9cSStella Laurenzo using llvm::StringRef;
39436c6c9cSStella Laurenzo using llvm::Twine;
40436c6c9cSStella Laurenzo 
// Docstring for Python-facing `dump` methods: describes dumping a debug
// representation of the receiver to stderr.
static constexpr const char kDumpDocstring[] =
    "Dumps a debug representation of the object to stderr.";
43436c6c9cSStella Laurenzo 
44436c6c9cSStella Laurenzo /// Attempts to populate `result` with the content of `list` casted to the
45436c6c9cSStella Laurenzo /// appropriate type (Python and C types are provided as template arguments).
46436c6c9cSStella Laurenzo /// Throws errors in case of failure, using "action" to describe what the caller
47436c6c9cSStella Laurenzo /// was attempting to do.
48436c6c9cSStella Laurenzo template <typename PyType, typename CType>
491fc096afSMehdi Amini static void pyListToVector(const py::list &list,
501fc096afSMehdi Amini                            llvm::SmallVectorImpl<CType> &result,
51436c6c9cSStella Laurenzo                            StringRef action) {
52436c6c9cSStella Laurenzo   result.reserve(py::len(list));
53436c6c9cSStella Laurenzo   for (py::handle item : list) {
54436c6c9cSStella Laurenzo     try {
55436c6c9cSStella Laurenzo       result.push_back(item.cast<PyType>());
56436c6c9cSStella Laurenzo     } catch (py::cast_error &err) {
57436c6c9cSStella Laurenzo       std::string msg = (llvm::Twine("Invalid expression when ") + action +
58436c6c9cSStella Laurenzo                          " (" + err.what() + ")")
59436c6c9cSStella Laurenzo                             .str();
60436c6c9cSStella Laurenzo       throw py::cast_error(msg);
61436c6c9cSStella Laurenzo     } catch (py::reference_cast_error &err) {
62436c6c9cSStella Laurenzo       std::string msg = (llvm::Twine("Invalid expression (None?) when ") +
63436c6c9cSStella Laurenzo                          action + " (" + err.what() + ")")
64436c6c9cSStella Laurenzo                             .str();
65436c6c9cSStella Laurenzo       throw py::cast_error(msg);
66436c6c9cSStella Laurenzo     }
67436c6c9cSStella Laurenzo   }
68436c6c9cSStella Laurenzo }
69436c6c9cSStella Laurenzo 
70436c6c9cSStella Laurenzo template <typename PermutationTy>
71436c6c9cSStella Laurenzo static bool isPermutation(std::vector<PermutationTy> permutation) {
72436c6c9cSStella Laurenzo   llvm::SmallVector<bool, 8> seen(permutation.size(), false);
73436c6c9cSStella Laurenzo   for (auto val : permutation) {
74436c6c9cSStella Laurenzo     if (val < permutation.size()) {
75436c6c9cSStella Laurenzo       if (seen[val])
76436c6c9cSStella Laurenzo         return false;
77436c6c9cSStella Laurenzo       seen[val] = true;
78436c6c9cSStella Laurenzo       continue;
79436c6c9cSStella Laurenzo     }
80436c6c9cSStella Laurenzo     return false;
81436c6c9cSStella Laurenzo   }
82436c6c9cSStella Laurenzo   return true;
83436c6c9cSStella Laurenzo }
84436c6c9cSStella Laurenzo 
85436c6c9cSStella Laurenzo namespace {
86436c6c9cSStella Laurenzo 
87436c6c9cSStella Laurenzo /// CRTP base class for Python MLIR affine expressions that subclass AffineExpr
88436c6c9cSStella Laurenzo /// and should be castable from it. Intermediate hierarchy classes can be
89436c6c9cSStella Laurenzo /// modeled by specifying BaseTy.
90436c6c9cSStella Laurenzo template <typename DerivedTy, typename BaseTy = PyAffineExpr>
91436c6c9cSStella Laurenzo class PyConcreteAffineExpr : public BaseTy {
92436c6c9cSStella Laurenzo public:
93436c6c9cSStella Laurenzo   // Derived classes must define statics for:
94436c6c9cSStella Laurenzo   //   IsAFunctionTy isaFunction
95436c6c9cSStella Laurenzo   //   const char *pyClassName
96436c6c9cSStella Laurenzo   // and redefine bindDerived.
97436c6c9cSStella Laurenzo   using ClassTy = py::class_<DerivedTy, BaseTy>;
98436c6c9cSStella Laurenzo   using IsAFunctionTy = bool (*)(MlirAffineExpr);
99436c6c9cSStella Laurenzo 
100436c6c9cSStella Laurenzo   PyConcreteAffineExpr() = default;
101436c6c9cSStella Laurenzo   PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
102436c6c9cSStella Laurenzo       : BaseTy(std::move(contextRef), affineExpr) {}
103436c6c9cSStella Laurenzo   PyConcreteAffineExpr(PyAffineExpr &orig)
104436c6c9cSStella Laurenzo       : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {}
105436c6c9cSStella Laurenzo 
106436c6c9cSStella Laurenzo   static MlirAffineExpr castFrom(PyAffineExpr &orig) {
107436c6c9cSStella Laurenzo     if (!DerivedTy::isaFunction(orig)) {
108436c6c9cSStella Laurenzo       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1094811270bSmax       throw py::value_error((Twine("Cannot cast affine expression to ") +
1104811270bSmax                              DerivedTy::pyClassName + " (from " + origRepr +
1114811270bSmax                              ")")
1124811270bSmax                                 .str());
113436c6c9cSStella Laurenzo     }
114436c6c9cSStella Laurenzo     return orig;
115436c6c9cSStella Laurenzo   }
116436c6c9cSStella Laurenzo 
117436c6c9cSStella Laurenzo   static void bind(py::module &m) {
118f05ff4f7SStella Laurenzo     auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
119a6e7d024SStella Laurenzo     cls.def(py::init<PyAffineExpr &>(), py::arg("expr"));
120a6e7d024SStella Laurenzo     cls.def_static(
121a6e7d024SStella Laurenzo         "isinstance",
122a6e7d024SStella Laurenzo         [](PyAffineExpr &otherAffineExpr) -> bool {
12378f2dae0SAlex Zinenko           return DerivedTy::isaFunction(otherAffineExpr);
124a6e7d024SStella Laurenzo         },
125a6e7d024SStella Laurenzo         py::arg("other"));
126436c6c9cSStella Laurenzo     DerivedTy::bindDerived(cls);
127436c6c9cSStella Laurenzo   }
128436c6c9cSStella Laurenzo 
129436c6c9cSStella Laurenzo   /// Implemented by derived classes to add methods to the Python subclass.
130436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &m) {}
131436c6c9cSStella Laurenzo };
132436c6c9cSStella Laurenzo 
133436c6c9cSStella Laurenzo class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> {
134436c6c9cSStella Laurenzo public:
135436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant;
136436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineConstantExpr";
137436c6c9cSStella Laurenzo   using PyConcreteAffineExpr::PyConcreteAffineExpr;
138436c6c9cSStella Laurenzo 
139436c6c9cSStella Laurenzo   static PyAffineConstantExpr get(intptr_t value,
140436c6c9cSStella Laurenzo                                   DefaultingPyMlirContext context) {
141436c6c9cSStella Laurenzo     MlirAffineExpr affineExpr =
142436c6c9cSStella Laurenzo         mlirAffineConstantExprGet(context->get(), static_cast<int64_t>(value));
143436c6c9cSStella Laurenzo     return PyAffineConstantExpr(context->getRef(), affineExpr);
144436c6c9cSStella Laurenzo   }
145436c6c9cSStella Laurenzo 
146436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
147436c6c9cSStella Laurenzo     c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"),
148436c6c9cSStella Laurenzo                  py::arg("context") = py::none());
149436c6c9cSStella Laurenzo     c.def_property_readonly("value", [](PyAffineConstantExpr &self) {
150436c6c9cSStella Laurenzo       return mlirAffineConstantExprGetValue(self);
151436c6c9cSStella Laurenzo     });
152436c6c9cSStella Laurenzo   }
153436c6c9cSStella Laurenzo };
154436c6c9cSStella Laurenzo 
155436c6c9cSStella Laurenzo class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> {
156436c6c9cSStella Laurenzo public:
157436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim;
158436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineDimExpr";
159436c6c9cSStella Laurenzo   using PyConcreteAffineExpr::PyConcreteAffineExpr;
160436c6c9cSStella Laurenzo 
161436c6c9cSStella Laurenzo   static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) {
162436c6c9cSStella Laurenzo     MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos);
163436c6c9cSStella Laurenzo     return PyAffineDimExpr(context->getRef(), affineExpr);
164436c6c9cSStella Laurenzo   }
165436c6c9cSStella Laurenzo 
166436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
167436c6c9cSStella Laurenzo     c.def_static("get", &PyAffineDimExpr::get, py::arg("position"),
168436c6c9cSStella Laurenzo                  py::arg("context") = py::none());
169436c6c9cSStella Laurenzo     c.def_property_readonly("position", [](PyAffineDimExpr &self) {
170436c6c9cSStella Laurenzo       return mlirAffineDimExprGetPosition(self);
171436c6c9cSStella Laurenzo     });
172436c6c9cSStella Laurenzo   }
173436c6c9cSStella Laurenzo };
174436c6c9cSStella Laurenzo 
175436c6c9cSStella Laurenzo class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> {
176436c6c9cSStella Laurenzo public:
177436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol;
178436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineSymbolExpr";
179436c6c9cSStella Laurenzo   using PyConcreteAffineExpr::PyConcreteAffineExpr;
180436c6c9cSStella Laurenzo 
181436c6c9cSStella Laurenzo   static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) {
182436c6c9cSStella Laurenzo     MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos);
183436c6c9cSStella Laurenzo     return PyAffineSymbolExpr(context->getRef(), affineExpr);
184436c6c9cSStella Laurenzo   }
185436c6c9cSStella Laurenzo 
186436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
187436c6c9cSStella Laurenzo     c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"),
188436c6c9cSStella Laurenzo                  py::arg("context") = py::none());
189436c6c9cSStella Laurenzo     c.def_property_readonly("position", [](PyAffineSymbolExpr &self) {
190436c6c9cSStella Laurenzo       return mlirAffineSymbolExprGetPosition(self);
191436c6c9cSStella Laurenzo     });
192436c6c9cSStella Laurenzo   }
193436c6c9cSStella Laurenzo };
194436c6c9cSStella Laurenzo 
195436c6c9cSStella Laurenzo class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> {
196436c6c9cSStella Laurenzo public:
197436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary;
198436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineBinaryExpr";
199436c6c9cSStella Laurenzo   using PyConcreteAffineExpr::PyConcreteAffineExpr;
200436c6c9cSStella Laurenzo 
201436c6c9cSStella Laurenzo   PyAffineExpr lhs() {
202436c6c9cSStella Laurenzo     MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get());
203436c6c9cSStella Laurenzo     return PyAffineExpr(getContext(), lhsExpr);
204436c6c9cSStella Laurenzo   }
205436c6c9cSStella Laurenzo 
206436c6c9cSStella Laurenzo   PyAffineExpr rhs() {
207436c6c9cSStella Laurenzo     MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get());
208436c6c9cSStella Laurenzo     return PyAffineExpr(getContext(), rhsExpr);
209436c6c9cSStella Laurenzo   }
210436c6c9cSStella Laurenzo 
211436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
212436c6c9cSStella Laurenzo     c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs);
213436c6c9cSStella Laurenzo     c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs);
214436c6c9cSStella Laurenzo   }
215436c6c9cSStella Laurenzo };
216436c6c9cSStella Laurenzo 
217436c6c9cSStella Laurenzo class PyAffineAddExpr
218436c6c9cSStella Laurenzo     : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> {
219436c6c9cSStella Laurenzo public:
220436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd;
221436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineAddExpr";
222436c6c9cSStella Laurenzo   using PyConcreteAffineExpr::PyConcreteAffineExpr;
223436c6c9cSStella Laurenzo 
2241fc096afSMehdi Amini   static PyAffineAddExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
225436c6c9cSStella Laurenzo     MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs);
226436c6c9cSStella Laurenzo     return PyAffineAddExpr(lhs.getContext(), expr);
227436c6c9cSStella Laurenzo   }
228436c6c9cSStella Laurenzo 
229fc7594ccSAlex Zinenko   static PyAffineAddExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
230fc7594ccSAlex Zinenko     MlirAffineExpr expr = mlirAffineAddExprGet(
231fc7594ccSAlex Zinenko         lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
232fc7594ccSAlex Zinenko     return PyAffineAddExpr(lhs.getContext(), expr);
233fc7594ccSAlex Zinenko   }
234fc7594ccSAlex Zinenko 
235fc7594ccSAlex Zinenko   static PyAffineAddExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
236fc7594ccSAlex Zinenko     MlirAffineExpr expr = mlirAffineAddExprGet(
237fc7594ccSAlex Zinenko         mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
238fc7594ccSAlex Zinenko     return PyAffineAddExpr(rhs.getContext(), expr);
239fc7594ccSAlex Zinenko   }
240fc7594ccSAlex Zinenko 
241436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
242436c6c9cSStella Laurenzo     c.def_static("get", &PyAffineAddExpr::get);
243436c6c9cSStella Laurenzo   }
244436c6c9cSStella Laurenzo };
245436c6c9cSStella Laurenzo 
246436c6c9cSStella Laurenzo class PyAffineMulExpr
247436c6c9cSStella Laurenzo     : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> {
248436c6c9cSStella Laurenzo public:
249436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul;
250436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineMulExpr";
251436c6c9cSStella Laurenzo   using PyConcreteAffineExpr::PyConcreteAffineExpr;
252436c6c9cSStella Laurenzo 
2531fc096afSMehdi Amini   static PyAffineMulExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
254436c6c9cSStella Laurenzo     MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs);
255436c6c9cSStella Laurenzo     return PyAffineMulExpr(lhs.getContext(), expr);
256436c6c9cSStella Laurenzo   }
257436c6c9cSStella Laurenzo 
258fc7594ccSAlex Zinenko   static PyAffineMulExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
259fc7594ccSAlex Zinenko     MlirAffineExpr expr = mlirAffineMulExprGet(
260fc7594ccSAlex Zinenko         lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
261fc7594ccSAlex Zinenko     return PyAffineMulExpr(lhs.getContext(), expr);
262fc7594ccSAlex Zinenko   }
263fc7594ccSAlex Zinenko 
264fc7594ccSAlex Zinenko   static PyAffineMulExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
265fc7594ccSAlex Zinenko     MlirAffineExpr expr = mlirAffineMulExprGet(
266fc7594ccSAlex Zinenko         mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
267fc7594ccSAlex Zinenko     return PyAffineMulExpr(rhs.getContext(), expr);
268fc7594ccSAlex Zinenko   }
269fc7594ccSAlex Zinenko 
270436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
271436c6c9cSStella Laurenzo     c.def_static("get", &PyAffineMulExpr::get);
272436c6c9cSStella Laurenzo   }
273436c6c9cSStella Laurenzo };
274436c6c9cSStella Laurenzo 
275436c6c9cSStella Laurenzo class PyAffineModExpr
276436c6c9cSStella Laurenzo     : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> {
277436c6c9cSStella Laurenzo public:
278436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod;
279436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineModExpr";
280436c6c9cSStella Laurenzo   using PyConcreteAffineExpr::PyConcreteAffineExpr;
281436c6c9cSStella Laurenzo 
2821fc096afSMehdi Amini   static PyAffineModExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
283436c6c9cSStella Laurenzo     MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs);
284436c6c9cSStella Laurenzo     return PyAffineModExpr(lhs.getContext(), expr);
285436c6c9cSStella Laurenzo   }
286436c6c9cSStella Laurenzo 
287fc7594ccSAlex Zinenko   static PyAffineModExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
288fc7594ccSAlex Zinenko     MlirAffineExpr expr = mlirAffineModExprGet(
289fc7594ccSAlex Zinenko         lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
290fc7594ccSAlex Zinenko     return PyAffineModExpr(lhs.getContext(), expr);
291fc7594ccSAlex Zinenko   }
292fc7594ccSAlex Zinenko 
293fc7594ccSAlex Zinenko   static PyAffineModExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
294fc7594ccSAlex Zinenko     MlirAffineExpr expr = mlirAffineModExprGet(
295fc7594ccSAlex Zinenko         mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
296fc7594ccSAlex Zinenko     return PyAffineModExpr(rhs.getContext(), expr);
297fc7594ccSAlex Zinenko   }
298fc7594ccSAlex Zinenko 
299436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
300436c6c9cSStella Laurenzo     c.def_static("get", &PyAffineModExpr::get);
301436c6c9cSStella Laurenzo   }
302436c6c9cSStella Laurenzo };
303436c6c9cSStella Laurenzo 
304436c6c9cSStella Laurenzo class PyAffineFloorDivExpr
305436c6c9cSStella Laurenzo     : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> {
306436c6c9cSStella Laurenzo public:
307436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv;
308436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineFloorDivExpr";
309436c6c9cSStella Laurenzo   using PyConcreteAffineExpr::PyConcreteAffineExpr;
310436c6c9cSStella Laurenzo 
3111fc096afSMehdi Amini   static PyAffineFloorDivExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
312436c6c9cSStella Laurenzo     MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs);
313436c6c9cSStella Laurenzo     return PyAffineFloorDivExpr(lhs.getContext(), expr);
314436c6c9cSStella Laurenzo   }
315436c6c9cSStella Laurenzo 
316fc7594ccSAlex Zinenko   static PyAffineFloorDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
317fc7594ccSAlex Zinenko     MlirAffineExpr expr = mlirAffineFloorDivExprGet(
318fc7594ccSAlex Zinenko         lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
319fc7594ccSAlex Zinenko     return PyAffineFloorDivExpr(lhs.getContext(), expr);
320fc7594ccSAlex Zinenko   }
321fc7594ccSAlex Zinenko 
322fc7594ccSAlex Zinenko   static PyAffineFloorDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
323fc7594ccSAlex Zinenko     MlirAffineExpr expr = mlirAffineFloorDivExprGet(
324fc7594ccSAlex Zinenko         mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
325fc7594ccSAlex Zinenko     return PyAffineFloorDivExpr(rhs.getContext(), expr);
326fc7594ccSAlex Zinenko   }
327fc7594ccSAlex Zinenko 
328436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
329436c6c9cSStella Laurenzo     c.def_static("get", &PyAffineFloorDivExpr::get);
330436c6c9cSStella Laurenzo   }
331436c6c9cSStella Laurenzo };
332436c6c9cSStella Laurenzo 
333436c6c9cSStella Laurenzo class PyAffineCeilDivExpr
334436c6c9cSStella Laurenzo     : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> {
335436c6c9cSStella Laurenzo public:
336436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv;
337436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineCeilDivExpr";
338436c6c9cSStella Laurenzo   using PyConcreteAffineExpr::PyConcreteAffineExpr;
339436c6c9cSStella Laurenzo 
3401fc096afSMehdi Amini   static PyAffineCeilDivExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
341436c6c9cSStella Laurenzo     MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs);
342436c6c9cSStella Laurenzo     return PyAffineCeilDivExpr(lhs.getContext(), expr);
343436c6c9cSStella Laurenzo   }
344436c6c9cSStella Laurenzo 
345fc7594ccSAlex Zinenko   static PyAffineCeilDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
346fc7594ccSAlex Zinenko     MlirAffineExpr expr = mlirAffineCeilDivExprGet(
347fc7594ccSAlex Zinenko         lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
348fc7594ccSAlex Zinenko     return PyAffineCeilDivExpr(lhs.getContext(), expr);
349fc7594ccSAlex Zinenko   }
350fc7594ccSAlex Zinenko 
351fc7594ccSAlex Zinenko   static PyAffineCeilDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
352fc7594ccSAlex Zinenko     MlirAffineExpr expr = mlirAffineCeilDivExprGet(
353fc7594ccSAlex Zinenko         mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
354fc7594ccSAlex Zinenko     return PyAffineCeilDivExpr(rhs.getContext(), expr);
355fc7594ccSAlex Zinenko   }
356fc7594ccSAlex Zinenko 
357436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
358436c6c9cSStella Laurenzo     c.def_static("get", &PyAffineCeilDivExpr::get);
359436c6c9cSStella Laurenzo   }
360436c6c9cSStella Laurenzo };
361436c6c9cSStella Laurenzo 
362436c6c9cSStella Laurenzo } // namespace
363436c6c9cSStella Laurenzo 
364e6d738e0SRahul Kayaith bool PyAffineExpr::operator==(const PyAffineExpr &other) const {
365436c6c9cSStella Laurenzo   return mlirAffineExprEqual(affineExpr, other.affineExpr);
366436c6c9cSStella Laurenzo }
367436c6c9cSStella Laurenzo 
368436c6c9cSStella Laurenzo py::object PyAffineExpr::getCapsule() {
369436c6c9cSStella Laurenzo   return py::reinterpret_steal<py::object>(
370436c6c9cSStella Laurenzo       mlirPythonAffineExprToCapsule(*this));
371436c6c9cSStella Laurenzo }
372436c6c9cSStella Laurenzo 
373436c6c9cSStella Laurenzo PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) {
374436c6c9cSStella Laurenzo   MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr());
375436c6c9cSStella Laurenzo   if (mlirAffineExprIsNull(rawAffineExpr))
376436c6c9cSStella Laurenzo     throw py::error_already_set();
377436c6c9cSStella Laurenzo   return PyAffineExpr(
378436c6c9cSStella Laurenzo       PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)),
379436c6c9cSStella Laurenzo       rawAffineExpr);
380436c6c9cSStella Laurenzo }
381436c6c9cSStella Laurenzo 
382436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
383436c6c9cSStella Laurenzo // PyAffineMap and utilities.
384436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
385436c6c9cSStella Laurenzo namespace {
386436c6c9cSStella Laurenzo 
387436c6c9cSStella Laurenzo /// A list of expressions contained in an affine map. Internally these are
388436c6c9cSStella Laurenzo /// stored as a consecutive array leading to inexpensive random access. Both
389436c6c9cSStella Laurenzo /// the map and the expression are owned by the context so we need not bother
390436c6c9cSStella Laurenzo /// with lifetime extension.
391436c6c9cSStella Laurenzo class PyAffineMapExprList
392436c6c9cSStella Laurenzo     : public Sliceable<PyAffineMapExprList, PyAffineExpr> {
393436c6c9cSStella Laurenzo public:
394436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineExprList";
395436c6c9cSStella Laurenzo 
3961fc096afSMehdi Amini   PyAffineMapExprList(const PyAffineMap &map, intptr_t startIndex = 0,
397436c6c9cSStella Laurenzo                       intptr_t length = -1, intptr_t step = 1)
398436c6c9cSStella Laurenzo       : Sliceable(startIndex,
399436c6c9cSStella Laurenzo                   length == -1 ? mlirAffineMapGetNumResults(map) : length,
400436c6c9cSStella Laurenzo                   step),
401436c6c9cSStella Laurenzo         affineMap(map) {}
402436c6c9cSStella Laurenzo 
403ee168fb9SAlex Zinenko private:
404ee168fb9SAlex Zinenko   /// Give the parent CRTP class access to hook implementations below.
405ee168fb9SAlex Zinenko   friend class Sliceable<PyAffineMapExprList, PyAffineExpr>;
406436c6c9cSStella Laurenzo 
407ee168fb9SAlex Zinenko   intptr_t getRawNumElements() { return mlirAffineMapGetNumResults(affineMap); }
408ee168fb9SAlex Zinenko 
409ee168fb9SAlex Zinenko   PyAffineExpr getRawElement(intptr_t pos) {
410436c6c9cSStella Laurenzo     return PyAffineExpr(affineMap.getContext(),
411436c6c9cSStella Laurenzo                         mlirAffineMapGetResult(affineMap, pos));
412436c6c9cSStella Laurenzo   }
413436c6c9cSStella Laurenzo 
414436c6c9cSStella Laurenzo   PyAffineMapExprList slice(intptr_t startIndex, intptr_t length,
415436c6c9cSStella Laurenzo                             intptr_t step) {
416436c6c9cSStella Laurenzo     return PyAffineMapExprList(affineMap, startIndex, length, step);
417436c6c9cSStella Laurenzo   }
418436c6c9cSStella Laurenzo 
419436c6c9cSStella Laurenzo   PyAffineMap affineMap;
420436c6c9cSStella Laurenzo };
421be0a7e9fSMehdi Amini } // namespace
422436c6c9cSStella Laurenzo 
423e6d738e0SRahul Kayaith bool PyAffineMap::operator==(const PyAffineMap &other) const {
424436c6c9cSStella Laurenzo   return mlirAffineMapEqual(affineMap, other.affineMap);
425436c6c9cSStella Laurenzo }
426436c6c9cSStella Laurenzo 
427436c6c9cSStella Laurenzo py::object PyAffineMap::getCapsule() {
428436c6c9cSStella Laurenzo   return py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(*this));
429436c6c9cSStella Laurenzo }
430436c6c9cSStella Laurenzo 
431436c6c9cSStella Laurenzo PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) {
432436c6c9cSStella Laurenzo   MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr());
433436c6c9cSStella Laurenzo   if (mlirAffineMapIsNull(rawAffineMap))
434436c6c9cSStella Laurenzo     throw py::error_already_set();
435436c6c9cSStella Laurenzo   return PyAffineMap(
436436c6c9cSStella Laurenzo       PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)),
437436c6c9cSStella Laurenzo       rawAffineMap);
438436c6c9cSStella Laurenzo }
439436c6c9cSStella Laurenzo 
440436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
441436c6c9cSStella Laurenzo // PyIntegerSet and utilities.
442436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
443436c6c9cSStella Laurenzo namespace {
444436c6c9cSStella Laurenzo 
445436c6c9cSStella Laurenzo class PyIntegerSetConstraint {
446436c6c9cSStella Laurenzo public:
4471fc096afSMehdi Amini   PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos)
4481fc096afSMehdi Amini       : set(std::move(set)), pos(pos) {}
449436c6c9cSStella Laurenzo 
450436c6c9cSStella Laurenzo   PyAffineExpr getExpr() {
451436c6c9cSStella Laurenzo     return PyAffineExpr(set.getContext(),
452436c6c9cSStella Laurenzo                         mlirIntegerSetGetConstraint(set, pos));
453436c6c9cSStella Laurenzo   }
454436c6c9cSStella Laurenzo 
455436c6c9cSStella Laurenzo   bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); }
456436c6c9cSStella Laurenzo 
457436c6c9cSStella Laurenzo   static void bind(py::module &m) {
458f05ff4f7SStella Laurenzo     py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint",
459f05ff4f7SStella Laurenzo                                        py::module_local())
460436c6c9cSStella Laurenzo         .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr)
461436c6c9cSStella Laurenzo         .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq);
462436c6c9cSStella Laurenzo   }
463436c6c9cSStella Laurenzo 
464436c6c9cSStella Laurenzo private:
465436c6c9cSStella Laurenzo   PyIntegerSet set;
466436c6c9cSStella Laurenzo   intptr_t pos;
467436c6c9cSStella Laurenzo };
468436c6c9cSStella Laurenzo 
469436c6c9cSStella Laurenzo class PyIntegerSetConstraintList
470436c6c9cSStella Laurenzo     : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> {
471436c6c9cSStella Laurenzo public:
472436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "IntegerSetConstraintList";
473436c6c9cSStella Laurenzo 
4741fc096afSMehdi Amini   PyIntegerSetConstraintList(const PyIntegerSet &set, intptr_t startIndex = 0,
475436c6c9cSStella Laurenzo                              intptr_t length = -1, intptr_t step = 1)
476436c6c9cSStella Laurenzo       : Sliceable(startIndex,
477436c6c9cSStella Laurenzo                   length == -1 ? mlirIntegerSetGetNumConstraints(set) : length,
478436c6c9cSStella Laurenzo                   step),
479436c6c9cSStella Laurenzo         set(set) {}
480436c6c9cSStella Laurenzo 
481ee168fb9SAlex Zinenko private:
482ee168fb9SAlex Zinenko   /// Give the parent CRTP class access to hook implementations below.
483ee168fb9SAlex Zinenko   friend class Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint>;
484436c6c9cSStella Laurenzo 
485ee168fb9SAlex Zinenko   intptr_t getRawNumElements() { return mlirIntegerSetGetNumConstraints(set); }
486ee168fb9SAlex Zinenko 
487ee168fb9SAlex Zinenko   PyIntegerSetConstraint getRawElement(intptr_t pos) {
488436c6c9cSStella Laurenzo     return PyIntegerSetConstraint(set, pos);
489436c6c9cSStella Laurenzo   }
490436c6c9cSStella Laurenzo 
491436c6c9cSStella Laurenzo   PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length,
492436c6c9cSStella Laurenzo                                    intptr_t step) {
493436c6c9cSStella Laurenzo     return PyIntegerSetConstraintList(set, startIndex, length, step);
494436c6c9cSStella Laurenzo   }
495436c6c9cSStella Laurenzo 
496436c6c9cSStella Laurenzo   PyIntegerSet set;
497436c6c9cSStella Laurenzo };
498436c6c9cSStella Laurenzo } // namespace
499436c6c9cSStella Laurenzo 
500e6d738e0SRahul Kayaith bool PyIntegerSet::operator==(const PyIntegerSet &other) const {
501436c6c9cSStella Laurenzo   return mlirIntegerSetEqual(integerSet, other.integerSet);
502436c6c9cSStella Laurenzo }
503436c6c9cSStella Laurenzo 
504436c6c9cSStella Laurenzo py::object PyIntegerSet::getCapsule() {
505436c6c9cSStella Laurenzo   return py::reinterpret_steal<py::object>(
506436c6c9cSStella Laurenzo       mlirPythonIntegerSetToCapsule(*this));
507436c6c9cSStella Laurenzo }
508436c6c9cSStella Laurenzo 
509436c6c9cSStella Laurenzo PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) {
510436c6c9cSStella Laurenzo   MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr());
511436c6c9cSStella Laurenzo   if (mlirIntegerSetIsNull(rawIntegerSet))
512436c6c9cSStella Laurenzo     throw py::error_already_set();
513436c6c9cSStella Laurenzo   return PyIntegerSet(
514436c6c9cSStella Laurenzo       PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)),
515436c6c9cSStella Laurenzo       rawIntegerSet);
516436c6c9cSStella Laurenzo }
517436c6c9cSStella Laurenzo 
518436c6c9cSStella Laurenzo void mlir::python::populateIRAffine(py::module &m) {
519436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
520436c6c9cSStella Laurenzo   // Mapping of PyAffineExpr and derived classes.
521436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
522f05ff4f7SStella Laurenzo   py::class_<PyAffineExpr>(m, "AffineExpr", py::module_local())
523436c6c9cSStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
524436c6c9cSStella Laurenzo                              &PyAffineExpr::getCapsule)
525436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule)
526fc7594ccSAlex Zinenko       .def("__add__", &PyAffineAddExpr::get)
527fc7594ccSAlex Zinenko       .def("__add__", &PyAffineAddExpr::getRHSConstant)
528fc7594ccSAlex Zinenko       .def("__radd__", &PyAffineAddExpr::getRHSConstant)
529fc7594ccSAlex Zinenko       .def("__mul__", &PyAffineMulExpr::get)
530fc7594ccSAlex Zinenko       .def("__mul__", &PyAffineMulExpr::getRHSConstant)
531fc7594ccSAlex Zinenko       .def("__rmul__", &PyAffineMulExpr::getRHSConstant)
532fc7594ccSAlex Zinenko       .def("__mod__", &PyAffineModExpr::get)
533fc7594ccSAlex Zinenko       .def("__mod__", &PyAffineModExpr::getRHSConstant)
534fc7594ccSAlex Zinenko       .def("__rmod__",
535fc7594ccSAlex Zinenko            [](PyAffineExpr &self, intptr_t other) {
536fc7594ccSAlex Zinenko              return PyAffineModExpr::get(
537fc7594ccSAlex Zinenko                  PyAffineConstantExpr::get(other, *self.getContext().get()),
538fc7594ccSAlex Zinenko                  self);
539436c6c9cSStella Laurenzo            })
540436c6c9cSStella Laurenzo       .def("__sub__",
541436c6c9cSStella Laurenzo            [](PyAffineExpr &self, PyAffineExpr &other) {
542436c6c9cSStella Laurenzo              auto negOne =
543436c6c9cSStella Laurenzo                  PyAffineConstantExpr::get(-1, *self.getContext().get());
544436c6c9cSStella Laurenzo              return PyAffineAddExpr::get(self,
545436c6c9cSStella Laurenzo                                          PyAffineMulExpr::get(negOne, other));
546436c6c9cSStella Laurenzo            })
547fc7594ccSAlex Zinenko       .def("__sub__",
548fc7594ccSAlex Zinenko            [](PyAffineExpr &self, intptr_t other) {
549fc7594ccSAlex Zinenko              return PyAffineAddExpr::get(
550fc7594ccSAlex Zinenko                  self,
551fc7594ccSAlex Zinenko                  PyAffineConstantExpr::get(-other, *self.getContext().get()));
552fc7594ccSAlex Zinenko            })
553fc7594ccSAlex Zinenko       .def("__rsub__",
554fc7594ccSAlex Zinenko            [](PyAffineExpr &self, intptr_t other) {
555fc7594ccSAlex Zinenko              return PyAffineAddExpr::getLHSConstant(
556fc7594ccSAlex Zinenko                  other, PyAffineMulExpr::getLHSConstant(-1, self));
557fc7594ccSAlex Zinenko            })
558436c6c9cSStella Laurenzo       .def("__eq__", [](PyAffineExpr &self,
559436c6c9cSStella Laurenzo                         PyAffineExpr &other) { return self == other; })
560436c6c9cSStella Laurenzo       .def("__eq__",
561436c6c9cSStella Laurenzo            [](PyAffineExpr &self, py::object &other) { return false; })
562436c6c9cSStella Laurenzo       .def("__str__",
563436c6c9cSStella Laurenzo            [](PyAffineExpr &self) {
564436c6c9cSStella Laurenzo              PyPrintAccumulator printAccum;
565436c6c9cSStella Laurenzo              mlirAffineExprPrint(self, printAccum.getCallback(),
566436c6c9cSStella Laurenzo                                  printAccum.getUserData());
567436c6c9cSStella Laurenzo              return printAccum.join();
568436c6c9cSStella Laurenzo            })
569436c6c9cSStella Laurenzo       .def("__repr__",
570436c6c9cSStella Laurenzo            [](PyAffineExpr &self) {
571436c6c9cSStella Laurenzo              PyPrintAccumulator printAccum;
572436c6c9cSStella Laurenzo              printAccum.parts.append("AffineExpr(");
573436c6c9cSStella Laurenzo              mlirAffineExprPrint(self, printAccum.getCallback(),
574436c6c9cSStella Laurenzo                                  printAccum.getUserData());
575436c6c9cSStella Laurenzo              printAccum.parts.append(")");
576436c6c9cSStella Laurenzo              return printAccum.join();
577436c6c9cSStella Laurenzo            })
578fc7594ccSAlex Zinenko       .def("__hash__",
579fc7594ccSAlex Zinenko            [](PyAffineExpr &self) {
580fc7594ccSAlex Zinenko              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
581fc7594ccSAlex Zinenko            })
582436c6c9cSStella Laurenzo       .def_property_readonly(
583436c6c9cSStella Laurenzo           "context",
584436c6c9cSStella Laurenzo           [](PyAffineExpr &self) { return self.getContext().getObject(); })
585fc7594ccSAlex Zinenko       .def("compose",
586fc7594ccSAlex Zinenko            [](PyAffineExpr &self, PyAffineMap &other) {
587fc7594ccSAlex Zinenko              return PyAffineExpr(self.getContext(),
588fc7594ccSAlex Zinenko                                  mlirAffineExprCompose(self, other));
589fc7594ccSAlex Zinenko            })
590436c6c9cSStella Laurenzo       .def_static(
591436c6c9cSStella Laurenzo           "get_add", &PyAffineAddExpr::get,
592436c6c9cSStella Laurenzo           "Gets an affine expression containing a sum of two expressions.")
593fc7594ccSAlex Zinenko       .def_static("get_add", &PyAffineAddExpr::getLHSConstant,
594fc7594ccSAlex Zinenko                   "Gets an affine expression containing a sum of a constant "
595fc7594ccSAlex Zinenko                   "and another expression.")
596fc7594ccSAlex Zinenko       .def_static("get_add", &PyAffineAddExpr::getRHSConstant,
597fc7594ccSAlex Zinenko                   "Gets an affine expression containing a sum of an expression "
598fc7594ccSAlex Zinenko                   "and a constant.")
599436c6c9cSStella Laurenzo       .def_static(
600436c6c9cSStella Laurenzo           "get_mul", &PyAffineMulExpr::get,
601436c6c9cSStella Laurenzo           "Gets an affine expression containing a product of two expressions.")
602fc7594ccSAlex Zinenko       .def_static("get_mul", &PyAffineMulExpr::getLHSConstant,
603fc7594ccSAlex Zinenko                   "Gets an affine expression containing a product of a "
604fc7594ccSAlex Zinenko                   "constant and another expression.")
605fc7594ccSAlex Zinenko       .def_static("get_mul", &PyAffineMulExpr::getRHSConstant,
606fc7594ccSAlex Zinenko                   "Gets an affine expression containing a product of an "
607fc7594ccSAlex Zinenko                   "expression and a constant.")
608436c6c9cSStella Laurenzo       .def_static("get_mod", &PyAffineModExpr::get,
609436c6c9cSStella Laurenzo                   "Gets an affine expression containing the modulo of dividing "
610436c6c9cSStella Laurenzo                   "one expression by another.")
611fc7594ccSAlex Zinenko       .def_static("get_mod", &PyAffineModExpr::getLHSConstant,
612fc7594ccSAlex Zinenko                   "Gets a semi-affine expression containing the modulo of "
613fc7594ccSAlex Zinenko                   "dividing a constant by an expression.")
614fc7594ccSAlex Zinenko       .def_static("get_mod", &PyAffineModExpr::getRHSConstant,
615fc7594ccSAlex Zinenko                   "Gets an affine expression containing the module of dividing"
616fc7594ccSAlex Zinenko                   "an expression by a constant.")
617436c6c9cSStella Laurenzo       .def_static("get_floor_div", &PyAffineFloorDivExpr::get,
618436c6c9cSStella Laurenzo                   "Gets an affine expression containing the rounded-down "
619436c6c9cSStella Laurenzo                   "result of dividing one expression by another.")
620fc7594ccSAlex Zinenko       .def_static("get_floor_div", &PyAffineFloorDivExpr::getLHSConstant,
621fc7594ccSAlex Zinenko                   "Gets a semi-affine expression containing the rounded-down "
622fc7594ccSAlex Zinenko                   "result of dividing a constant by an expression.")
623fc7594ccSAlex Zinenko       .def_static("get_floor_div", &PyAffineFloorDivExpr::getRHSConstant,
624fc7594ccSAlex Zinenko                   "Gets an affine expression containing the rounded-down "
625fc7594ccSAlex Zinenko                   "result of dividing an expression by a constant.")
626436c6c9cSStella Laurenzo       .def_static("get_ceil_div", &PyAffineCeilDivExpr::get,
627436c6c9cSStella Laurenzo                   "Gets an affine expression containing the rounded-up result "
628436c6c9cSStella Laurenzo                   "of dividing one expression by another.")
629fc7594ccSAlex Zinenko       .def_static("get_ceil_div", &PyAffineCeilDivExpr::getLHSConstant,
630fc7594ccSAlex Zinenko                   "Gets a semi-affine expression containing the rounded-up "
631fc7594ccSAlex Zinenko                   "result of dividing a constant by an expression.")
632fc7594ccSAlex Zinenko       .def_static("get_ceil_div", &PyAffineCeilDivExpr::getRHSConstant,
633fc7594ccSAlex Zinenko                   "Gets an affine expression containing the rounded-up result "
634fc7594ccSAlex Zinenko                   "of dividing an expression by a constant.")
635436c6c9cSStella Laurenzo       .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"),
636436c6c9cSStella Laurenzo                   py::arg("context") = py::none(),
637436c6c9cSStella Laurenzo                   "Gets a constant affine expression with the given value.")
638436c6c9cSStella Laurenzo       .def_static(
639436c6c9cSStella Laurenzo           "get_dim", &PyAffineDimExpr::get, py::arg("position"),
640436c6c9cSStella Laurenzo           py::arg("context") = py::none(),
641436c6c9cSStella Laurenzo           "Gets an affine expression of a dimension at the given position.")
642436c6c9cSStella Laurenzo       .def_static(
643436c6c9cSStella Laurenzo           "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"),
644436c6c9cSStella Laurenzo           py::arg("context") = py::none(),
645436c6c9cSStella Laurenzo           "Gets an affine expression of a symbol at the given position.")
646436c6c9cSStella Laurenzo       .def(
647436c6c9cSStella Laurenzo           "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); },
648436c6c9cSStella Laurenzo           kDumpDocstring);
649436c6c9cSStella Laurenzo   PyAffineConstantExpr::bind(m);
650436c6c9cSStella Laurenzo   PyAffineDimExpr::bind(m);
651436c6c9cSStella Laurenzo   PyAffineSymbolExpr::bind(m);
652436c6c9cSStella Laurenzo   PyAffineBinaryExpr::bind(m);
653436c6c9cSStella Laurenzo   PyAffineAddExpr::bind(m);
654436c6c9cSStella Laurenzo   PyAffineMulExpr::bind(m);
655436c6c9cSStella Laurenzo   PyAffineModExpr::bind(m);
656436c6c9cSStella Laurenzo   PyAffineFloorDivExpr::bind(m);
657436c6c9cSStella Laurenzo   PyAffineCeilDivExpr::bind(m);
658436c6c9cSStella Laurenzo 
659436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
660436c6c9cSStella Laurenzo   // Mapping of PyAffineMap.
661436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
662f05ff4f7SStella Laurenzo   py::class_<PyAffineMap>(m, "AffineMap", py::module_local())
663436c6c9cSStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
664436c6c9cSStella Laurenzo                              &PyAffineMap::getCapsule)
665436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule)
666436c6c9cSStella Laurenzo       .def("__eq__",
667436c6c9cSStella Laurenzo            [](PyAffineMap &self, PyAffineMap &other) { return self == other; })
668436c6c9cSStella Laurenzo       .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; })
669436c6c9cSStella Laurenzo       .def("__str__",
670436c6c9cSStella Laurenzo            [](PyAffineMap &self) {
671436c6c9cSStella Laurenzo              PyPrintAccumulator printAccum;
672436c6c9cSStella Laurenzo              mlirAffineMapPrint(self, printAccum.getCallback(),
673436c6c9cSStella Laurenzo                                 printAccum.getUserData());
674436c6c9cSStella Laurenzo              return printAccum.join();
675436c6c9cSStella Laurenzo            })
676436c6c9cSStella Laurenzo       .def("__repr__",
677436c6c9cSStella Laurenzo            [](PyAffineMap &self) {
678436c6c9cSStella Laurenzo              PyPrintAccumulator printAccum;
679436c6c9cSStella Laurenzo              printAccum.parts.append("AffineMap(");
680436c6c9cSStella Laurenzo              mlirAffineMapPrint(self, printAccum.getCallback(),
681436c6c9cSStella Laurenzo                                 printAccum.getUserData());
682436c6c9cSStella Laurenzo              printAccum.parts.append(")");
683436c6c9cSStella Laurenzo              return printAccum.join();
684436c6c9cSStella Laurenzo            })
685fc7594ccSAlex Zinenko       .def("__hash__",
686fc7594ccSAlex Zinenko            [](PyAffineMap &self) {
687fc7594ccSAlex Zinenko              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
688fc7594ccSAlex Zinenko            })
689335d2df5SNicolas Vasilache       .def_static("compress_unused_symbols",
690335d2df5SNicolas Vasilache                   [](py::list affineMaps, DefaultingPyMlirContext context) {
691335d2df5SNicolas Vasilache                     SmallVector<MlirAffineMap> maps;
692335d2df5SNicolas Vasilache                     pyListToVector<PyAffineMap, MlirAffineMap>(
693335d2df5SNicolas Vasilache                         affineMaps, maps, "attempting to create an AffineMap");
694335d2df5SNicolas Vasilache                     std::vector<MlirAffineMap> compressed(affineMaps.size());
695335d2df5SNicolas Vasilache                     auto populate = [](void *result, intptr_t idx,
696335d2df5SNicolas Vasilache                                        MlirAffineMap m) {
697335d2df5SNicolas Vasilache                       static_cast<MlirAffineMap *>(result)[idx] = (m);
698335d2df5SNicolas Vasilache                     };
699335d2df5SNicolas Vasilache                     mlirAffineMapCompressUnusedSymbols(
700335d2df5SNicolas Vasilache                         maps.data(), maps.size(), compressed.data(), populate);
701335d2df5SNicolas Vasilache                     std::vector<PyAffineMap> res;
702e2f16be5SMehdi Amini                     res.reserve(compressed.size());
703335d2df5SNicolas Vasilache                     for (auto m : compressed)
704e5639b3fSMehdi Amini                       res.emplace_back(context->getRef(), m);
705335d2df5SNicolas Vasilache                     return res;
706335d2df5SNicolas Vasilache                   })
707436c6c9cSStella Laurenzo       .def_property_readonly(
708436c6c9cSStella Laurenzo           "context",
709436c6c9cSStella Laurenzo           [](PyAffineMap &self) { return self.getContext().getObject(); },
710436c6c9cSStella Laurenzo           "Context that owns the Affine Map")
711436c6c9cSStella Laurenzo       .def(
712436c6c9cSStella Laurenzo           "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); },
713436c6c9cSStella Laurenzo           kDumpDocstring)
714436c6c9cSStella Laurenzo       .def_static(
715436c6c9cSStella Laurenzo           "get",
716436c6c9cSStella Laurenzo           [](intptr_t dimCount, intptr_t symbolCount, py::list exprs,
717436c6c9cSStella Laurenzo              DefaultingPyMlirContext context) {
718436c6c9cSStella Laurenzo             SmallVector<MlirAffineExpr> affineExprs;
719436c6c9cSStella Laurenzo             pyListToVector<PyAffineExpr, MlirAffineExpr>(
720337c937dSMehdi Amini                 exprs, affineExprs, "attempting to create an AffineMap");
721436c6c9cSStella Laurenzo             MlirAffineMap map =
722436c6c9cSStella Laurenzo                 mlirAffineMapGet(context->get(), dimCount, symbolCount,
723436c6c9cSStella Laurenzo                                  affineExprs.size(), affineExprs.data());
724436c6c9cSStella Laurenzo             return PyAffineMap(context->getRef(), map);
725436c6c9cSStella Laurenzo           },
726436c6c9cSStella Laurenzo           py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"),
727436c6c9cSStella Laurenzo           py::arg("context") = py::none(),
728436c6c9cSStella Laurenzo           "Gets a map with the given expressions as results.")
729436c6c9cSStella Laurenzo       .def_static(
730436c6c9cSStella Laurenzo           "get_constant",
731436c6c9cSStella Laurenzo           [](intptr_t value, DefaultingPyMlirContext context) {
732436c6c9cSStella Laurenzo             MlirAffineMap affineMap =
733436c6c9cSStella Laurenzo                 mlirAffineMapConstantGet(context->get(), value);
734436c6c9cSStella Laurenzo             return PyAffineMap(context->getRef(), affineMap);
735436c6c9cSStella Laurenzo           },
736436c6c9cSStella Laurenzo           py::arg("value"), py::arg("context") = py::none(),
737436c6c9cSStella Laurenzo           "Gets an affine map with a single constant result")
738436c6c9cSStella Laurenzo       .def_static(
739436c6c9cSStella Laurenzo           "get_empty",
740436c6c9cSStella Laurenzo           [](DefaultingPyMlirContext context) {
741436c6c9cSStella Laurenzo             MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get());
742436c6c9cSStella Laurenzo             return PyAffineMap(context->getRef(), affineMap);
743436c6c9cSStella Laurenzo           },
744436c6c9cSStella Laurenzo           py::arg("context") = py::none(), "Gets an empty affine map.")
745436c6c9cSStella Laurenzo       .def_static(
746436c6c9cSStella Laurenzo           "get_identity",
747436c6c9cSStella Laurenzo           [](intptr_t nDims, DefaultingPyMlirContext context) {
748436c6c9cSStella Laurenzo             MlirAffineMap affineMap =
749436c6c9cSStella Laurenzo                 mlirAffineMapMultiDimIdentityGet(context->get(), nDims);
750436c6c9cSStella Laurenzo             return PyAffineMap(context->getRef(), affineMap);
751436c6c9cSStella Laurenzo           },
752436c6c9cSStella Laurenzo           py::arg("n_dims"), py::arg("context") = py::none(),
753436c6c9cSStella Laurenzo           "Gets an identity map with the given number of dimensions.")
754436c6c9cSStella Laurenzo       .def_static(
755436c6c9cSStella Laurenzo           "get_minor_identity",
756436c6c9cSStella Laurenzo           [](intptr_t nDims, intptr_t nResults,
757436c6c9cSStella Laurenzo              DefaultingPyMlirContext context) {
758436c6c9cSStella Laurenzo             MlirAffineMap affineMap =
759436c6c9cSStella Laurenzo                 mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults);
760436c6c9cSStella Laurenzo             return PyAffineMap(context->getRef(), affineMap);
761436c6c9cSStella Laurenzo           },
762436c6c9cSStella Laurenzo           py::arg("n_dims"), py::arg("n_results"),
763436c6c9cSStella Laurenzo           py::arg("context") = py::none(),
764436c6c9cSStella Laurenzo           "Gets a minor identity map with the given number of dimensions and "
765436c6c9cSStella Laurenzo           "results.")
766436c6c9cSStella Laurenzo       .def_static(
767436c6c9cSStella Laurenzo           "get_permutation",
768436c6c9cSStella Laurenzo           [](std::vector<unsigned> permutation,
769436c6c9cSStella Laurenzo              DefaultingPyMlirContext context) {
770436c6c9cSStella Laurenzo             if (!isPermutation(permutation))
771436c6c9cSStella Laurenzo               throw py::cast_error("Invalid permutation when attempting to "
772436c6c9cSStella Laurenzo                                    "create an AffineMap");
773436c6c9cSStella Laurenzo             MlirAffineMap affineMap = mlirAffineMapPermutationGet(
774436c6c9cSStella Laurenzo                 context->get(), permutation.size(), permutation.data());
775436c6c9cSStella Laurenzo             return PyAffineMap(context->getRef(), affineMap);
776436c6c9cSStella Laurenzo           },
777436c6c9cSStella Laurenzo           py::arg("permutation"), py::arg("context") = py::none(),
778436c6c9cSStella Laurenzo           "Gets an affine map that permutes its inputs.")
779a6e7d024SStella Laurenzo       .def(
780a6e7d024SStella Laurenzo           "get_submap",
781436c6c9cSStella Laurenzo           [](PyAffineMap &self, std::vector<intptr_t> &resultPos) {
782436c6c9cSStella Laurenzo             intptr_t numResults = mlirAffineMapGetNumResults(self);
783436c6c9cSStella Laurenzo             for (intptr_t pos : resultPos) {
784436c6c9cSStella Laurenzo               if (pos < 0 || pos >= numResults)
785436c6c9cSStella Laurenzo                 throw py::value_error("result position out of bounds");
786436c6c9cSStella Laurenzo             }
787436c6c9cSStella Laurenzo             MlirAffineMap affineMap = mlirAffineMapGetSubMap(
788436c6c9cSStella Laurenzo                 self, resultPos.size(), resultPos.data());
789436c6c9cSStella Laurenzo             return PyAffineMap(self.getContext(), affineMap);
790a6e7d024SStella Laurenzo           },
791a6e7d024SStella Laurenzo           py::arg("result_positions"))
792a6e7d024SStella Laurenzo       .def(
793a6e7d024SStella Laurenzo           "get_major_submap",
794436c6c9cSStella Laurenzo           [](PyAffineMap &self, intptr_t nResults) {
795436c6c9cSStella Laurenzo             if (nResults >= mlirAffineMapGetNumResults(self))
796436c6c9cSStella Laurenzo               throw py::value_error("number of results out of bounds");
797436c6c9cSStella Laurenzo             MlirAffineMap affineMap =
798436c6c9cSStella Laurenzo                 mlirAffineMapGetMajorSubMap(self, nResults);
799436c6c9cSStella Laurenzo             return PyAffineMap(self.getContext(), affineMap);
800a6e7d024SStella Laurenzo           },
801a6e7d024SStella Laurenzo           py::arg("n_results"))
802a6e7d024SStella Laurenzo       .def(
803a6e7d024SStella Laurenzo           "get_minor_submap",
804436c6c9cSStella Laurenzo           [](PyAffineMap &self, intptr_t nResults) {
805436c6c9cSStella Laurenzo             if (nResults >= mlirAffineMapGetNumResults(self))
806436c6c9cSStella Laurenzo               throw py::value_error("number of results out of bounds");
807436c6c9cSStella Laurenzo             MlirAffineMap affineMap =
808436c6c9cSStella Laurenzo                 mlirAffineMapGetMinorSubMap(self, nResults);
809436c6c9cSStella Laurenzo             return PyAffineMap(self.getContext(), affineMap);
810a6e7d024SStella Laurenzo           },
811a6e7d024SStella Laurenzo           py::arg("n_results"))
812a6e7d024SStella Laurenzo       .def(
813a6e7d024SStella Laurenzo           "replace",
81431f888eaSTobias Gysi           [](PyAffineMap &self, PyAffineExpr &expression,
81531f888eaSTobias Gysi              PyAffineExpr &replacement, intptr_t numResultDims,
81631f888eaSTobias Gysi              intptr_t numResultSyms) {
81731f888eaSTobias Gysi             MlirAffineMap affineMap = mlirAffineMapReplace(
81831f888eaSTobias Gysi                 self, expression, replacement, numResultDims, numResultSyms);
81931f888eaSTobias Gysi             return PyAffineMap(self.getContext(), affineMap);
820a6e7d024SStella Laurenzo           },
821a6e7d024SStella Laurenzo           py::arg("expr"), py::arg("replacement"), py::arg("n_result_dims"),
822a6e7d024SStella Laurenzo           py::arg("n_result_syms"))
823436c6c9cSStella Laurenzo       .def_property_readonly(
824436c6c9cSStella Laurenzo           "is_permutation",
825436c6c9cSStella Laurenzo           [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); })
826436c6c9cSStella Laurenzo       .def_property_readonly("is_projected_permutation",
827436c6c9cSStella Laurenzo                              [](PyAffineMap &self) {
828436c6c9cSStella Laurenzo                                return mlirAffineMapIsProjectedPermutation(self);
829436c6c9cSStella Laurenzo                              })
830436c6c9cSStella Laurenzo       .def_property_readonly(
831436c6c9cSStella Laurenzo           "n_dims",
832436c6c9cSStella Laurenzo           [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); })
833436c6c9cSStella Laurenzo       .def_property_readonly(
834436c6c9cSStella Laurenzo           "n_inputs",
835436c6c9cSStella Laurenzo           [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); })
836436c6c9cSStella Laurenzo       .def_property_readonly(
837436c6c9cSStella Laurenzo           "n_symbols",
838436c6c9cSStella Laurenzo           [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); })
839436c6c9cSStella Laurenzo       .def_property_readonly("results", [](PyAffineMap &self) {
840436c6c9cSStella Laurenzo         return PyAffineMapExprList(self);
841436c6c9cSStella Laurenzo       });
842436c6c9cSStella Laurenzo   PyAffineMapExprList::bind(m);
843436c6c9cSStella Laurenzo 
844436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
845436c6c9cSStella Laurenzo   // Mapping of PyIntegerSet.
846436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
847f05ff4f7SStella Laurenzo   py::class_<PyIntegerSet>(m, "IntegerSet", py::module_local())
848436c6c9cSStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
849436c6c9cSStella Laurenzo                              &PyIntegerSet::getCapsule)
850436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule)
851436c6c9cSStella Laurenzo       .def("__eq__", [](PyIntegerSet &self,
852436c6c9cSStella Laurenzo                         PyIntegerSet &other) { return self == other; })
853436c6c9cSStella Laurenzo       .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; })
854436c6c9cSStella Laurenzo       .def("__str__",
855436c6c9cSStella Laurenzo            [](PyIntegerSet &self) {
856436c6c9cSStella Laurenzo              PyPrintAccumulator printAccum;
857436c6c9cSStella Laurenzo              mlirIntegerSetPrint(self, printAccum.getCallback(),
858436c6c9cSStella Laurenzo                                  printAccum.getUserData());
859436c6c9cSStella Laurenzo              return printAccum.join();
860436c6c9cSStella Laurenzo            })
861436c6c9cSStella Laurenzo       .def("__repr__",
862436c6c9cSStella Laurenzo            [](PyIntegerSet &self) {
863436c6c9cSStella Laurenzo              PyPrintAccumulator printAccum;
864436c6c9cSStella Laurenzo              printAccum.parts.append("IntegerSet(");
865436c6c9cSStella Laurenzo              mlirIntegerSetPrint(self, printAccum.getCallback(),
866436c6c9cSStella Laurenzo                                  printAccum.getUserData());
867436c6c9cSStella Laurenzo              printAccum.parts.append(")");
868436c6c9cSStella Laurenzo              return printAccum.join();
869436c6c9cSStella Laurenzo            })
870fc7594ccSAlex Zinenko       .def("__hash__",
871fc7594ccSAlex Zinenko            [](PyIntegerSet &self) {
872fc7594ccSAlex Zinenko              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
873fc7594ccSAlex Zinenko            })
874436c6c9cSStella Laurenzo       .def_property_readonly(
875436c6c9cSStella Laurenzo           "context",
876436c6c9cSStella Laurenzo           [](PyIntegerSet &self) { return self.getContext().getObject(); })
877436c6c9cSStella Laurenzo       .def(
878436c6c9cSStella Laurenzo           "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); },
879436c6c9cSStella Laurenzo           kDumpDocstring)
880436c6c9cSStella Laurenzo       .def_static(
881436c6c9cSStella Laurenzo           "get",
882436c6c9cSStella Laurenzo           [](intptr_t numDims, intptr_t numSymbols, py::list exprs,
883436c6c9cSStella Laurenzo              std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
884436c6c9cSStella Laurenzo             if (exprs.size() != eqFlags.size())
885436c6c9cSStella Laurenzo               throw py::value_error(
886436c6c9cSStella Laurenzo                   "Expected the number of constraints to match "
887436c6c9cSStella Laurenzo                   "that of equality flags");
888436c6c9cSStella Laurenzo             if (exprs.empty())
889436c6c9cSStella Laurenzo               throw py::value_error("Expected non-empty list of constraints");
890436c6c9cSStella Laurenzo 
891436c6c9cSStella Laurenzo             // Copy over to a SmallVector because std::vector has a
892436c6c9cSStella Laurenzo             // specialization for booleans that packs data and does not
893436c6c9cSStella Laurenzo             // expose a `bool *`.
894436c6c9cSStella Laurenzo             SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end());
895436c6c9cSStella Laurenzo 
896436c6c9cSStella Laurenzo             SmallVector<MlirAffineExpr> affineExprs;
897436c6c9cSStella Laurenzo             pyListToVector<PyAffineExpr>(exprs, affineExprs,
898436c6c9cSStella Laurenzo                                          "attempting to create an IntegerSet");
899436c6c9cSStella Laurenzo             MlirIntegerSet set = mlirIntegerSetGet(
900436c6c9cSStella Laurenzo                 context->get(), numDims, numSymbols, exprs.size(),
901436c6c9cSStella Laurenzo                 affineExprs.data(), flags.data());
902436c6c9cSStella Laurenzo             return PyIntegerSet(context->getRef(), set);
903436c6c9cSStella Laurenzo           },
904436c6c9cSStella Laurenzo           py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"),
905436c6c9cSStella Laurenzo           py::arg("eq_flags"), py::arg("context") = py::none())
906436c6c9cSStella Laurenzo       .def_static(
907436c6c9cSStella Laurenzo           "get_empty",
908436c6c9cSStella Laurenzo           [](intptr_t numDims, intptr_t numSymbols,
909436c6c9cSStella Laurenzo              DefaultingPyMlirContext context) {
910436c6c9cSStella Laurenzo             MlirIntegerSet set =
911436c6c9cSStella Laurenzo                 mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols);
912436c6c9cSStella Laurenzo             return PyIntegerSet(context->getRef(), set);
913436c6c9cSStella Laurenzo           },
914436c6c9cSStella Laurenzo           py::arg("num_dims"), py::arg("num_symbols"),
915436c6c9cSStella Laurenzo           py::arg("context") = py::none())
916a6e7d024SStella Laurenzo       .def(
917a6e7d024SStella Laurenzo           "get_replaced",
918436c6c9cSStella Laurenzo           [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs,
919436c6c9cSStella Laurenzo              intptr_t numResultDims, intptr_t numResultSymbols) {
920436c6c9cSStella Laurenzo             if (static_cast<intptr_t>(dimExprs.size()) !=
921436c6c9cSStella Laurenzo                 mlirIntegerSetGetNumDims(self))
922436c6c9cSStella Laurenzo               throw py::value_error(
923436c6c9cSStella Laurenzo                   "Expected the number of dimension replacement expressions "
924436c6c9cSStella Laurenzo                   "to match that of dimensions");
925436c6c9cSStella Laurenzo             if (static_cast<intptr_t>(symbolExprs.size()) !=
926436c6c9cSStella Laurenzo                 mlirIntegerSetGetNumSymbols(self))
927436c6c9cSStella Laurenzo               throw py::value_error(
928436c6c9cSStella Laurenzo                   "Expected the number of symbol replacement expressions "
929436c6c9cSStella Laurenzo                   "to match that of symbols");
930436c6c9cSStella Laurenzo 
931436c6c9cSStella Laurenzo             SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs;
932436c6c9cSStella Laurenzo             pyListToVector<PyAffineExpr>(
933436c6c9cSStella Laurenzo                 dimExprs, dimAffineExprs,
934436c6c9cSStella Laurenzo                 "attempting to create an IntegerSet by replacing dimensions");
935436c6c9cSStella Laurenzo             pyListToVector<PyAffineExpr>(
936436c6c9cSStella Laurenzo                 symbolExprs, symbolAffineExprs,
937436c6c9cSStella Laurenzo                 "attempting to create an IntegerSet by replacing symbols");
938436c6c9cSStella Laurenzo             MlirIntegerSet set = mlirIntegerSetReplaceGet(
939436c6c9cSStella Laurenzo                 self, dimAffineExprs.data(), symbolAffineExprs.data(),
940436c6c9cSStella Laurenzo                 numResultDims, numResultSymbols);
941436c6c9cSStella Laurenzo             return PyIntegerSet(self.getContext(), set);
942a6e7d024SStella Laurenzo           },
943a6e7d024SStella Laurenzo           py::arg("dim_exprs"), py::arg("symbol_exprs"),
944a6e7d024SStella Laurenzo           py::arg("num_result_dims"), py::arg("num_result_symbols"))
945436c6c9cSStella Laurenzo       .def_property_readonly("is_canonical_empty",
946436c6c9cSStella Laurenzo                              [](PyIntegerSet &self) {
947436c6c9cSStella Laurenzo                                return mlirIntegerSetIsCanonicalEmpty(self);
948436c6c9cSStella Laurenzo                              })
949436c6c9cSStella Laurenzo       .def_property_readonly(
950436c6c9cSStella Laurenzo           "n_dims",
951436c6c9cSStella Laurenzo           [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); })
952436c6c9cSStella Laurenzo       .def_property_readonly(
953436c6c9cSStella Laurenzo           "n_symbols",
954436c6c9cSStella Laurenzo           [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); })
955436c6c9cSStella Laurenzo       .def_property_readonly(
956436c6c9cSStella Laurenzo           "n_inputs",
957436c6c9cSStella Laurenzo           [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); })
958436c6c9cSStella Laurenzo       .def_property_readonly("n_equalities",
959436c6c9cSStella Laurenzo                              [](PyIntegerSet &self) {
960436c6c9cSStella Laurenzo                                return mlirIntegerSetGetNumEqualities(self);
961436c6c9cSStella Laurenzo                              })
962436c6c9cSStella Laurenzo       .def_property_readonly("n_inequalities",
963436c6c9cSStella Laurenzo                              [](PyIntegerSet &self) {
964436c6c9cSStella Laurenzo                                return mlirIntegerSetGetNumInequalities(self);
965436c6c9cSStella Laurenzo                              })
966436c6c9cSStella Laurenzo       .def_property_readonly("constraints", [](PyIntegerSet &self) {
967436c6c9cSStella Laurenzo         return PyIntegerSetConstraintList(self);
968436c6c9cSStella Laurenzo       });
969436c6c9cSStella Laurenzo   PyIntegerSetConstraint::bind(m);
970436c6c9cSStella Laurenzo   PyIntegerSetConstraintList::bind(m);
971436c6c9cSStella Laurenzo }
972