xref: /llvm-project/mlir/lib/Bindings/Python/IRAffine.cpp (revision 78f2dae00d32504d1f645f74c67bf4340ebcda82)
1436c6c9cSStella Laurenzo //===- IRAffine.cpp - Exports 'ir' module affine related bindings ---------===//
2436c6c9cSStella Laurenzo //
3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6436c6c9cSStella Laurenzo //
7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===//
8436c6c9cSStella Laurenzo 
9436c6c9cSStella Laurenzo #include "IRModule.h"
10436c6c9cSStella Laurenzo 
11436c6c9cSStella Laurenzo #include "PybindUtils.h"
12436c6c9cSStella Laurenzo 
13436c6c9cSStella Laurenzo #include "mlir-c/AffineMap.h"
14436c6c9cSStella Laurenzo #include "mlir-c/Bindings/Python/Interop.h"
15436c6c9cSStella Laurenzo #include "mlir-c/IntegerSet.h"
16436c6c9cSStella Laurenzo 
17436c6c9cSStella Laurenzo namespace py = pybind11;
18436c6c9cSStella Laurenzo using namespace mlir;
19436c6c9cSStella Laurenzo using namespace mlir::python;
20436c6c9cSStella Laurenzo 
21436c6c9cSStella Laurenzo using llvm::SmallVector;
22436c6c9cSStella Laurenzo using llvm::StringRef;
23436c6c9cSStella Laurenzo using llvm::Twine;
24436c6c9cSStella Laurenzo 
// Docstring shared by the `dump` methods registered on the Python classes.
static const char kDumpDocstring[] =
    "Dumps a debug representation of the object to stderr.";
27436c6c9cSStella Laurenzo 
28436c6c9cSStella Laurenzo /// Attempts to populate `result` with the content of `list` casted to the
29436c6c9cSStella Laurenzo /// appropriate type (Python and C types are provided as template arguments).
30436c6c9cSStella Laurenzo /// Throws errors in case of failure, using "action" to describe what the caller
31436c6c9cSStella Laurenzo /// was attempting to do.
32436c6c9cSStella Laurenzo template <typename PyType, typename CType>
33436c6c9cSStella Laurenzo static void pyListToVector(py::list list, llvm::SmallVectorImpl<CType> &result,
34436c6c9cSStella Laurenzo                            StringRef action) {
35436c6c9cSStella Laurenzo   result.reserve(py::len(list));
36436c6c9cSStella Laurenzo   for (py::handle item : list) {
37436c6c9cSStella Laurenzo     try {
38436c6c9cSStella Laurenzo       result.push_back(item.cast<PyType>());
39436c6c9cSStella Laurenzo     } catch (py::cast_error &err) {
40436c6c9cSStella Laurenzo       std::string msg = (llvm::Twine("Invalid expression when ") + action +
41436c6c9cSStella Laurenzo                          " (" + err.what() + ")")
42436c6c9cSStella Laurenzo                             .str();
43436c6c9cSStella Laurenzo       throw py::cast_error(msg);
44436c6c9cSStella Laurenzo     } catch (py::reference_cast_error &err) {
45436c6c9cSStella Laurenzo       std::string msg = (llvm::Twine("Invalid expression (None?) when ") +
46436c6c9cSStella Laurenzo                          action + " (" + err.what() + ")")
47436c6c9cSStella Laurenzo                             .str();
48436c6c9cSStella Laurenzo       throw py::cast_error(msg);
49436c6c9cSStella Laurenzo     }
50436c6c9cSStella Laurenzo   }
51436c6c9cSStella Laurenzo }
52436c6c9cSStella Laurenzo 
/// Returns true if `permutation` contains every integer in
/// [0, permutation.size()) exactly once. An empty vector is trivially a
/// permutation.
///
/// The vector is taken by const reference to avoid copying it on each call.
template <typename PermutationTy>
static bool isPermutation(const std::vector<PermutationTy> &permutation) {
  const size_t size = permutation.size();
  std::vector<bool> seen(size, false);
  for (const PermutationTy &val : permutation) {
    // Explicit conversion (same as the usual arithmetic conversion): a
    // negative signed value wraps to a huge index and fails the bound check.
    const size_t index = static_cast<size_t>(val);
    if (index >= size || seen[index])
      return false;
    seen[index] = true;
  }
  return true;
}
67436c6c9cSStella Laurenzo 
68436c6c9cSStella Laurenzo namespace {
69436c6c9cSStella Laurenzo 
70436c6c9cSStella Laurenzo /// CRTP base class for Python MLIR affine expressions that subclass AffineExpr
71436c6c9cSStella Laurenzo /// and should be castable from it. Intermediate hierarchy classes can be
72436c6c9cSStella Laurenzo /// modeled by specifying BaseTy.
73436c6c9cSStella Laurenzo template <typename DerivedTy, typename BaseTy = PyAffineExpr>
74436c6c9cSStella Laurenzo class PyConcreteAffineExpr : public BaseTy {
75436c6c9cSStella Laurenzo public:
76436c6c9cSStella Laurenzo   // Derived classes must define statics for:
77436c6c9cSStella Laurenzo   //   IsAFunctionTy isaFunction
78436c6c9cSStella Laurenzo   //   const char *pyClassName
79436c6c9cSStella Laurenzo   // and redefine bindDerived.
80436c6c9cSStella Laurenzo   using ClassTy = py::class_<DerivedTy, BaseTy>;
81436c6c9cSStella Laurenzo   using IsAFunctionTy = bool (*)(MlirAffineExpr);
82436c6c9cSStella Laurenzo 
83436c6c9cSStella Laurenzo   PyConcreteAffineExpr() = default;
84436c6c9cSStella Laurenzo   PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
85436c6c9cSStella Laurenzo       : BaseTy(std::move(contextRef), affineExpr) {}
86436c6c9cSStella Laurenzo   PyConcreteAffineExpr(PyAffineExpr &orig)
87436c6c9cSStella Laurenzo       : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {}
88436c6c9cSStella Laurenzo 
89436c6c9cSStella Laurenzo   static MlirAffineExpr castFrom(PyAffineExpr &orig) {
90436c6c9cSStella Laurenzo     if (!DerivedTy::isaFunction(orig)) {
91436c6c9cSStella Laurenzo       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
92436c6c9cSStella Laurenzo       throw SetPyError(PyExc_ValueError,
93436c6c9cSStella Laurenzo                        Twine("Cannot cast affine expression to ") +
94436c6c9cSStella Laurenzo                            DerivedTy::pyClassName + " (from " + origRepr + ")");
95436c6c9cSStella Laurenzo     }
96436c6c9cSStella Laurenzo     return orig;
97436c6c9cSStella Laurenzo   }
98436c6c9cSStella Laurenzo 
99436c6c9cSStella Laurenzo   static void bind(py::module &m) {
100f05ff4f7SStella Laurenzo     auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
101436c6c9cSStella Laurenzo     cls.def(py::init<PyAffineExpr &>());
102*78f2dae0SAlex Zinenko     cls.def_static("isinstance", [](PyAffineExpr &otherAffineExpr) -> bool {
103*78f2dae0SAlex Zinenko       return DerivedTy::isaFunction(otherAffineExpr);
104*78f2dae0SAlex Zinenko     });
105436c6c9cSStella Laurenzo     DerivedTy::bindDerived(cls);
106436c6c9cSStella Laurenzo   }
107436c6c9cSStella Laurenzo 
108436c6c9cSStella Laurenzo   /// Implemented by derived classes to add methods to the Python subclass.
109436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &m) {}
110436c6c9cSStella Laurenzo };
111436c6c9cSStella Laurenzo 
112436c6c9cSStella Laurenzo class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> {
113436c6c9cSStella Laurenzo public:
114436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant;
115436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineConstantExpr";
116436c6c9cSStella Laurenzo   using PyConcreteAffineExpr::PyConcreteAffineExpr;
117436c6c9cSStella Laurenzo 
118436c6c9cSStella Laurenzo   static PyAffineConstantExpr get(intptr_t value,
119436c6c9cSStella Laurenzo                                   DefaultingPyMlirContext context) {
120436c6c9cSStella Laurenzo     MlirAffineExpr affineExpr =
121436c6c9cSStella Laurenzo         mlirAffineConstantExprGet(context->get(), static_cast<int64_t>(value));
122436c6c9cSStella Laurenzo     return PyAffineConstantExpr(context->getRef(), affineExpr);
123436c6c9cSStella Laurenzo   }
124436c6c9cSStella Laurenzo 
125436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
126436c6c9cSStella Laurenzo     c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"),
127436c6c9cSStella Laurenzo                  py::arg("context") = py::none());
128436c6c9cSStella Laurenzo     c.def_property_readonly("value", [](PyAffineConstantExpr &self) {
129436c6c9cSStella Laurenzo       return mlirAffineConstantExprGetValue(self);
130436c6c9cSStella Laurenzo     });
131436c6c9cSStella Laurenzo   }
132436c6c9cSStella Laurenzo };
133436c6c9cSStella Laurenzo 
134436c6c9cSStella Laurenzo class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> {
135436c6c9cSStella Laurenzo public:
136436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim;
137436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineDimExpr";
138436c6c9cSStella Laurenzo   using PyConcreteAffineExpr::PyConcreteAffineExpr;
139436c6c9cSStella Laurenzo 
140436c6c9cSStella Laurenzo   static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) {
141436c6c9cSStella Laurenzo     MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos);
142436c6c9cSStella Laurenzo     return PyAffineDimExpr(context->getRef(), affineExpr);
143436c6c9cSStella Laurenzo   }
144436c6c9cSStella Laurenzo 
145436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
146436c6c9cSStella Laurenzo     c.def_static("get", &PyAffineDimExpr::get, py::arg("position"),
147436c6c9cSStella Laurenzo                  py::arg("context") = py::none());
148436c6c9cSStella Laurenzo     c.def_property_readonly("position", [](PyAffineDimExpr &self) {
149436c6c9cSStella Laurenzo       return mlirAffineDimExprGetPosition(self);
150436c6c9cSStella Laurenzo     });
151436c6c9cSStella Laurenzo   }
152436c6c9cSStella Laurenzo };
153436c6c9cSStella Laurenzo 
154436c6c9cSStella Laurenzo class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> {
155436c6c9cSStella Laurenzo public:
156436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol;
157436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineSymbolExpr";
158436c6c9cSStella Laurenzo   using PyConcreteAffineExpr::PyConcreteAffineExpr;
159436c6c9cSStella Laurenzo 
160436c6c9cSStella Laurenzo   static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) {
161436c6c9cSStella Laurenzo     MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos);
162436c6c9cSStella Laurenzo     return PyAffineSymbolExpr(context->getRef(), affineExpr);
163436c6c9cSStella Laurenzo   }
164436c6c9cSStella Laurenzo 
165436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
166436c6c9cSStella Laurenzo     c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"),
167436c6c9cSStella Laurenzo                  py::arg("context") = py::none());
168436c6c9cSStella Laurenzo     c.def_property_readonly("position", [](PyAffineSymbolExpr &self) {
169436c6c9cSStella Laurenzo       return mlirAffineSymbolExprGetPosition(self);
170436c6c9cSStella Laurenzo     });
171436c6c9cSStella Laurenzo   }
172436c6c9cSStella Laurenzo };
173436c6c9cSStella Laurenzo 
174436c6c9cSStella Laurenzo class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> {
175436c6c9cSStella Laurenzo public:
176436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary;
177436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineBinaryExpr";
178436c6c9cSStella Laurenzo   using PyConcreteAffineExpr::PyConcreteAffineExpr;
179436c6c9cSStella Laurenzo 
180436c6c9cSStella Laurenzo   PyAffineExpr lhs() {
181436c6c9cSStella Laurenzo     MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get());
182436c6c9cSStella Laurenzo     return PyAffineExpr(getContext(), lhsExpr);
183436c6c9cSStella Laurenzo   }
184436c6c9cSStella Laurenzo 
185436c6c9cSStella Laurenzo   PyAffineExpr rhs() {
186436c6c9cSStella Laurenzo     MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get());
187436c6c9cSStella Laurenzo     return PyAffineExpr(getContext(), rhsExpr);
188436c6c9cSStella Laurenzo   }
189436c6c9cSStella Laurenzo 
190436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
191436c6c9cSStella Laurenzo     c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs);
192436c6c9cSStella Laurenzo     c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs);
193436c6c9cSStella Laurenzo   }
194436c6c9cSStella Laurenzo };
195436c6c9cSStella Laurenzo 
196436c6c9cSStella Laurenzo class PyAffineAddExpr
197436c6c9cSStella Laurenzo     : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> {
198436c6c9cSStella Laurenzo public:
199436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd;
200436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineAddExpr";
201436c6c9cSStella Laurenzo   using PyConcreteAffineExpr::PyConcreteAffineExpr;
202436c6c9cSStella Laurenzo 
203436c6c9cSStella Laurenzo   static PyAffineAddExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
204436c6c9cSStella Laurenzo     MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs);
205436c6c9cSStella Laurenzo     return PyAffineAddExpr(lhs.getContext(), expr);
206436c6c9cSStella Laurenzo   }
207436c6c9cSStella Laurenzo 
208436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
209436c6c9cSStella Laurenzo     c.def_static("get", &PyAffineAddExpr::get);
210436c6c9cSStella Laurenzo   }
211436c6c9cSStella Laurenzo };
212436c6c9cSStella Laurenzo 
213436c6c9cSStella Laurenzo class PyAffineMulExpr
214436c6c9cSStella Laurenzo     : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> {
215436c6c9cSStella Laurenzo public:
216436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul;
217436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineMulExpr";
218436c6c9cSStella Laurenzo   using PyConcreteAffineExpr::PyConcreteAffineExpr;
219436c6c9cSStella Laurenzo 
220436c6c9cSStella Laurenzo   static PyAffineMulExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
221436c6c9cSStella Laurenzo     MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs);
222436c6c9cSStella Laurenzo     return PyAffineMulExpr(lhs.getContext(), expr);
223436c6c9cSStella Laurenzo   }
224436c6c9cSStella Laurenzo 
225436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
226436c6c9cSStella Laurenzo     c.def_static("get", &PyAffineMulExpr::get);
227436c6c9cSStella Laurenzo   }
228436c6c9cSStella Laurenzo };
229436c6c9cSStella Laurenzo 
230436c6c9cSStella Laurenzo class PyAffineModExpr
231436c6c9cSStella Laurenzo     : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> {
232436c6c9cSStella Laurenzo public:
233436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod;
234436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineModExpr";
235436c6c9cSStella Laurenzo   using PyConcreteAffineExpr::PyConcreteAffineExpr;
236436c6c9cSStella Laurenzo 
237436c6c9cSStella Laurenzo   static PyAffineModExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
238436c6c9cSStella Laurenzo     MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs);
239436c6c9cSStella Laurenzo     return PyAffineModExpr(lhs.getContext(), expr);
240436c6c9cSStella Laurenzo   }
241436c6c9cSStella Laurenzo 
242436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
243436c6c9cSStella Laurenzo     c.def_static("get", &PyAffineModExpr::get);
244436c6c9cSStella Laurenzo   }
245436c6c9cSStella Laurenzo };
246436c6c9cSStella Laurenzo 
247436c6c9cSStella Laurenzo class PyAffineFloorDivExpr
248436c6c9cSStella Laurenzo     : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> {
249436c6c9cSStella Laurenzo public:
250436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv;
251436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineFloorDivExpr";
252436c6c9cSStella Laurenzo   using PyConcreteAffineExpr::PyConcreteAffineExpr;
253436c6c9cSStella Laurenzo 
254436c6c9cSStella Laurenzo   static PyAffineFloorDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
255436c6c9cSStella Laurenzo     MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs);
256436c6c9cSStella Laurenzo     return PyAffineFloorDivExpr(lhs.getContext(), expr);
257436c6c9cSStella Laurenzo   }
258436c6c9cSStella Laurenzo 
259436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
260436c6c9cSStella Laurenzo     c.def_static("get", &PyAffineFloorDivExpr::get);
261436c6c9cSStella Laurenzo   }
262436c6c9cSStella Laurenzo };
263436c6c9cSStella Laurenzo 
264436c6c9cSStella Laurenzo class PyAffineCeilDivExpr
265436c6c9cSStella Laurenzo     : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> {
266436c6c9cSStella Laurenzo public:
267436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv;
268436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineCeilDivExpr";
269436c6c9cSStella Laurenzo   using PyConcreteAffineExpr::PyConcreteAffineExpr;
270436c6c9cSStella Laurenzo 
271436c6c9cSStella Laurenzo   static PyAffineCeilDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
272436c6c9cSStella Laurenzo     MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs);
273436c6c9cSStella Laurenzo     return PyAffineCeilDivExpr(lhs.getContext(), expr);
274436c6c9cSStella Laurenzo   }
275436c6c9cSStella Laurenzo 
276436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
277436c6c9cSStella Laurenzo     c.def_static("get", &PyAffineCeilDivExpr::get);
278436c6c9cSStella Laurenzo   }
279436c6c9cSStella Laurenzo };
280436c6c9cSStella Laurenzo 
281436c6c9cSStella Laurenzo } // namespace
282436c6c9cSStella Laurenzo 
283436c6c9cSStella Laurenzo bool PyAffineExpr::operator==(const PyAffineExpr &other) {
284436c6c9cSStella Laurenzo   return mlirAffineExprEqual(affineExpr, other.affineExpr);
285436c6c9cSStella Laurenzo }
286436c6c9cSStella Laurenzo 
287436c6c9cSStella Laurenzo py::object PyAffineExpr::getCapsule() {
288436c6c9cSStella Laurenzo   return py::reinterpret_steal<py::object>(
289436c6c9cSStella Laurenzo       mlirPythonAffineExprToCapsule(*this));
290436c6c9cSStella Laurenzo }
291436c6c9cSStella Laurenzo 
292436c6c9cSStella Laurenzo PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) {
293436c6c9cSStella Laurenzo   MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr());
294436c6c9cSStella Laurenzo   if (mlirAffineExprIsNull(rawAffineExpr))
295436c6c9cSStella Laurenzo     throw py::error_already_set();
296436c6c9cSStella Laurenzo   return PyAffineExpr(
297436c6c9cSStella Laurenzo       PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)),
298436c6c9cSStella Laurenzo       rawAffineExpr);
299436c6c9cSStella Laurenzo }
300436c6c9cSStella Laurenzo 
301436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
302436c6c9cSStella Laurenzo // PyAffineMap and utilities.
303436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
304436c6c9cSStella Laurenzo namespace {
305436c6c9cSStella Laurenzo 
306436c6c9cSStella Laurenzo /// A list of expressions contained in an affine map. Internally these are
307436c6c9cSStella Laurenzo /// stored as a consecutive array leading to inexpensive random access. Both
308436c6c9cSStella Laurenzo /// the map and the expression are owned by the context so we need not bother
309436c6c9cSStella Laurenzo /// with lifetime extension.
310436c6c9cSStella Laurenzo class PyAffineMapExprList
311436c6c9cSStella Laurenzo     : public Sliceable<PyAffineMapExprList, PyAffineExpr> {
312436c6c9cSStella Laurenzo public:
313436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineExprList";
314436c6c9cSStella Laurenzo 
315436c6c9cSStella Laurenzo   PyAffineMapExprList(PyAffineMap map, intptr_t startIndex = 0,
316436c6c9cSStella Laurenzo                       intptr_t length = -1, intptr_t step = 1)
317436c6c9cSStella Laurenzo       : Sliceable(startIndex,
318436c6c9cSStella Laurenzo                   length == -1 ? mlirAffineMapGetNumResults(map) : length,
319436c6c9cSStella Laurenzo                   step),
320436c6c9cSStella Laurenzo         affineMap(map) {}
321436c6c9cSStella Laurenzo 
322436c6c9cSStella Laurenzo   intptr_t getNumElements() { return mlirAffineMapGetNumResults(affineMap); }
323436c6c9cSStella Laurenzo 
324436c6c9cSStella Laurenzo   PyAffineExpr getElement(intptr_t pos) {
325436c6c9cSStella Laurenzo     return PyAffineExpr(affineMap.getContext(),
326436c6c9cSStella Laurenzo                         mlirAffineMapGetResult(affineMap, pos));
327436c6c9cSStella Laurenzo   }
328436c6c9cSStella Laurenzo 
329436c6c9cSStella Laurenzo   PyAffineMapExprList slice(intptr_t startIndex, intptr_t length,
330436c6c9cSStella Laurenzo                             intptr_t step) {
331436c6c9cSStella Laurenzo     return PyAffineMapExprList(affineMap, startIndex, length, step);
332436c6c9cSStella Laurenzo   }
333436c6c9cSStella Laurenzo 
334436c6c9cSStella Laurenzo private:
335436c6c9cSStella Laurenzo   PyAffineMap affineMap;
336436c6c9cSStella Laurenzo };
337436c6c9cSStella Laurenzo } // end namespace
338436c6c9cSStella Laurenzo 
339436c6c9cSStella Laurenzo bool PyAffineMap::operator==(const PyAffineMap &other) {
340436c6c9cSStella Laurenzo   return mlirAffineMapEqual(affineMap, other.affineMap);
341436c6c9cSStella Laurenzo }
342436c6c9cSStella Laurenzo 
343436c6c9cSStella Laurenzo py::object PyAffineMap::getCapsule() {
344436c6c9cSStella Laurenzo   return py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(*this));
345436c6c9cSStella Laurenzo }
346436c6c9cSStella Laurenzo 
347436c6c9cSStella Laurenzo PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) {
348436c6c9cSStella Laurenzo   MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr());
349436c6c9cSStella Laurenzo   if (mlirAffineMapIsNull(rawAffineMap))
350436c6c9cSStella Laurenzo     throw py::error_already_set();
351436c6c9cSStella Laurenzo   return PyAffineMap(
352436c6c9cSStella Laurenzo       PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)),
353436c6c9cSStella Laurenzo       rawAffineMap);
354436c6c9cSStella Laurenzo }
355436c6c9cSStella Laurenzo 
356436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
357436c6c9cSStella Laurenzo // PyIntegerSet and utilities.
358436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
359436c6c9cSStella Laurenzo namespace {
360436c6c9cSStella Laurenzo 
361436c6c9cSStella Laurenzo class PyIntegerSetConstraint {
362436c6c9cSStella Laurenzo public:
363436c6c9cSStella Laurenzo   PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) : set(set), pos(pos) {}
364436c6c9cSStella Laurenzo 
365436c6c9cSStella Laurenzo   PyAffineExpr getExpr() {
366436c6c9cSStella Laurenzo     return PyAffineExpr(set.getContext(),
367436c6c9cSStella Laurenzo                         mlirIntegerSetGetConstraint(set, pos));
368436c6c9cSStella Laurenzo   }
369436c6c9cSStella Laurenzo 
370436c6c9cSStella Laurenzo   bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); }
371436c6c9cSStella Laurenzo 
372436c6c9cSStella Laurenzo   static void bind(py::module &m) {
373f05ff4f7SStella Laurenzo     py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint",
374f05ff4f7SStella Laurenzo                                        py::module_local())
375436c6c9cSStella Laurenzo         .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr)
376436c6c9cSStella Laurenzo         .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq);
377436c6c9cSStella Laurenzo   }
378436c6c9cSStella Laurenzo 
379436c6c9cSStella Laurenzo private:
380436c6c9cSStella Laurenzo   PyIntegerSet set;
381436c6c9cSStella Laurenzo   intptr_t pos;
382436c6c9cSStella Laurenzo };
383436c6c9cSStella Laurenzo 
384436c6c9cSStella Laurenzo class PyIntegerSetConstraintList
385436c6c9cSStella Laurenzo     : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> {
386436c6c9cSStella Laurenzo public:
387436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "IntegerSetConstraintList";
388436c6c9cSStella Laurenzo 
389436c6c9cSStella Laurenzo   PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0,
390436c6c9cSStella Laurenzo                              intptr_t length = -1, intptr_t step = 1)
391436c6c9cSStella Laurenzo       : Sliceable(startIndex,
392436c6c9cSStella Laurenzo                   length == -1 ? mlirIntegerSetGetNumConstraints(set) : length,
393436c6c9cSStella Laurenzo                   step),
394436c6c9cSStella Laurenzo         set(set) {}
395436c6c9cSStella Laurenzo 
396436c6c9cSStella Laurenzo   intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); }
397436c6c9cSStella Laurenzo 
398436c6c9cSStella Laurenzo   PyIntegerSetConstraint getElement(intptr_t pos) {
399436c6c9cSStella Laurenzo     return PyIntegerSetConstraint(set, pos);
400436c6c9cSStella Laurenzo   }
401436c6c9cSStella Laurenzo 
402436c6c9cSStella Laurenzo   PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length,
403436c6c9cSStella Laurenzo                                    intptr_t step) {
404436c6c9cSStella Laurenzo     return PyIntegerSetConstraintList(set, startIndex, length, step);
405436c6c9cSStella Laurenzo   }
406436c6c9cSStella Laurenzo 
407436c6c9cSStella Laurenzo private:
408436c6c9cSStella Laurenzo   PyIntegerSet set;
409436c6c9cSStella Laurenzo };
410436c6c9cSStella Laurenzo } // namespace
411436c6c9cSStella Laurenzo 
412436c6c9cSStella Laurenzo bool PyIntegerSet::operator==(const PyIntegerSet &other) {
413436c6c9cSStella Laurenzo   return mlirIntegerSetEqual(integerSet, other.integerSet);
414436c6c9cSStella Laurenzo }
415436c6c9cSStella Laurenzo 
416436c6c9cSStella Laurenzo py::object PyIntegerSet::getCapsule() {
417436c6c9cSStella Laurenzo   return py::reinterpret_steal<py::object>(
418436c6c9cSStella Laurenzo       mlirPythonIntegerSetToCapsule(*this));
419436c6c9cSStella Laurenzo }
420436c6c9cSStella Laurenzo 
421436c6c9cSStella Laurenzo PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) {
422436c6c9cSStella Laurenzo   MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr());
423436c6c9cSStella Laurenzo   if (mlirIntegerSetIsNull(rawIntegerSet))
424436c6c9cSStella Laurenzo     throw py::error_already_set();
425436c6c9cSStella Laurenzo   return PyIntegerSet(
426436c6c9cSStella Laurenzo       PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)),
427436c6c9cSStella Laurenzo       rawIntegerSet);
428436c6c9cSStella Laurenzo }
429436c6c9cSStella Laurenzo 
430436c6c9cSStella Laurenzo void mlir::python::populateIRAffine(py::module &m) {
431436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
432436c6c9cSStella Laurenzo   // Mapping of PyAffineExpr and derived classes.
433436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
434f05ff4f7SStella Laurenzo   py::class_<PyAffineExpr>(m, "AffineExpr", py::module_local())
435436c6c9cSStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
436436c6c9cSStella Laurenzo                              &PyAffineExpr::getCapsule)
437436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule)
438436c6c9cSStella Laurenzo       .def("__add__",
439436c6c9cSStella Laurenzo            [](PyAffineExpr &self, PyAffineExpr &other) {
440436c6c9cSStella Laurenzo              return PyAffineAddExpr::get(self, other);
441436c6c9cSStella Laurenzo            })
442436c6c9cSStella Laurenzo       .def("__mul__",
443436c6c9cSStella Laurenzo            [](PyAffineExpr &self, PyAffineExpr &other) {
444436c6c9cSStella Laurenzo              return PyAffineMulExpr::get(self, other);
445436c6c9cSStella Laurenzo            })
446436c6c9cSStella Laurenzo       .def("__mod__",
447436c6c9cSStella Laurenzo            [](PyAffineExpr &self, PyAffineExpr &other) {
448436c6c9cSStella Laurenzo              return PyAffineModExpr::get(self, other);
449436c6c9cSStella Laurenzo            })
450436c6c9cSStella Laurenzo       .def("__sub__",
451436c6c9cSStella Laurenzo            [](PyAffineExpr &self, PyAffineExpr &other) {
452436c6c9cSStella Laurenzo              auto negOne =
453436c6c9cSStella Laurenzo                  PyAffineConstantExpr::get(-1, *self.getContext().get());
454436c6c9cSStella Laurenzo              return PyAffineAddExpr::get(self,
455436c6c9cSStella Laurenzo                                          PyAffineMulExpr::get(negOne, other));
456436c6c9cSStella Laurenzo            })
457436c6c9cSStella Laurenzo       .def("__eq__", [](PyAffineExpr &self,
458436c6c9cSStella Laurenzo                         PyAffineExpr &other) { return self == other; })
459436c6c9cSStella Laurenzo       .def("__eq__",
460436c6c9cSStella Laurenzo            [](PyAffineExpr &self, py::object &other) { return false; })
461436c6c9cSStella Laurenzo       .def("__str__",
462436c6c9cSStella Laurenzo            [](PyAffineExpr &self) {
463436c6c9cSStella Laurenzo              PyPrintAccumulator printAccum;
464436c6c9cSStella Laurenzo              mlirAffineExprPrint(self, printAccum.getCallback(),
465436c6c9cSStella Laurenzo                                  printAccum.getUserData());
466436c6c9cSStella Laurenzo              return printAccum.join();
467436c6c9cSStella Laurenzo            })
468436c6c9cSStella Laurenzo       .def("__repr__",
469436c6c9cSStella Laurenzo            [](PyAffineExpr &self) {
470436c6c9cSStella Laurenzo              PyPrintAccumulator printAccum;
471436c6c9cSStella Laurenzo              printAccum.parts.append("AffineExpr(");
472436c6c9cSStella Laurenzo              mlirAffineExprPrint(self, printAccum.getCallback(),
473436c6c9cSStella Laurenzo                                  printAccum.getUserData());
474436c6c9cSStella Laurenzo              printAccum.parts.append(")");
475436c6c9cSStella Laurenzo              return printAccum.join();
476436c6c9cSStella Laurenzo            })
477436c6c9cSStella Laurenzo       .def_property_readonly(
478436c6c9cSStella Laurenzo           "context",
479436c6c9cSStella Laurenzo           [](PyAffineExpr &self) { return self.getContext().getObject(); })
480436c6c9cSStella Laurenzo       .def_static(
481436c6c9cSStella Laurenzo           "get_add", &PyAffineAddExpr::get,
482436c6c9cSStella Laurenzo           "Gets an affine expression containing a sum of two expressions.")
483436c6c9cSStella Laurenzo       .def_static(
484436c6c9cSStella Laurenzo           "get_mul", &PyAffineMulExpr::get,
485436c6c9cSStella Laurenzo           "Gets an affine expression containing a product of two expressions.")
486436c6c9cSStella Laurenzo       .def_static("get_mod", &PyAffineModExpr::get,
487436c6c9cSStella Laurenzo                   "Gets an affine expression containing the modulo of dividing "
488436c6c9cSStella Laurenzo                   "one expression by another.")
489436c6c9cSStella Laurenzo       .def_static("get_floor_div", &PyAffineFloorDivExpr::get,
490436c6c9cSStella Laurenzo                   "Gets an affine expression containing the rounded-down "
491436c6c9cSStella Laurenzo                   "result of dividing one expression by another.")
492436c6c9cSStella Laurenzo       .def_static("get_ceil_div", &PyAffineCeilDivExpr::get,
493436c6c9cSStella Laurenzo                   "Gets an affine expression containing the rounded-up result "
494436c6c9cSStella Laurenzo                   "of dividing one expression by another.")
495436c6c9cSStella Laurenzo       .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"),
496436c6c9cSStella Laurenzo                   py::arg("context") = py::none(),
497436c6c9cSStella Laurenzo                   "Gets a constant affine expression with the given value.")
498436c6c9cSStella Laurenzo       .def_static(
499436c6c9cSStella Laurenzo           "get_dim", &PyAffineDimExpr::get, py::arg("position"),
500436c6c9cSStella Laurenzo           py::arg("context") = py::none(),
501436c6c9cSStella Laurenzo           "Gets an affine expression of a dimension at the given position.")
502436c6c9cSStella Laurenzo       .def_static(
503436c6c9cSStella Laurenzo           "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"),
504436c6c9cSStella Laurenzo           py::arg("context") = py::none(),
505436c6c9cSStella Laurenzo           "Gets an affine expression of a symbol at the given position.")
506436c6c9cSStella Laurenzo       .def(
507436c6c9cSStella Laurenzo           "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); },
508436c6c9cSStella Laurenzo           kDumpDocstring);
  // Register the concrete AffineExpr subclasses on module `m`.
  // NOTE(review): PyAffineBinaryExpr is registered before the
  // Add/Mul/Mod/FloorDiv/CeilDiv classes. Presumably those derive from it, and
  // pybind11 requires a base class to be registered before its subclasses, so
  // keep this order unless the class hierarchy is confirmed otherwise.
  PyAffineConstantExpr::bind(m);
  PyAffineDimExpr::bind(m);
  PyAffineSymbolExpr::bind(m);
  PyAffineBinaryExpr::bind(m);
  PyAffineAddExpr::bind(m);
  PyAffineMulExpr::bind(m);
  PyAffineModExpr::bind(m);
  PyAffineFloorDivExpr::bind(m);
  PyAffineCeilDivExpr::bind(m);
518436c6c9cSStella Laurenzo 
519436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
520436c6c9cSStella Laurenzo   // Mapping of PyAffineMap.
521436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
522f05ff4f7SStella Laurenzo   py::class_<PyAffineMap>(m, "AffineMap", py::module_local())
523436c6c9cSStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
524436c6c9cSStella Laurenzo                              &PyAffineMap::getCapsule)
525436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule)
526436c6c9cSStella Laurenzo       .def("__eq__",
527436c6c9cSStella Laurenzo            [](PyAffineMap &self, PyAffineMap &other) { return self == other; })
528436c6c9cSStella Laurenzo       .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; })
529436c6c9cSStella Laurenzo       .def("__str__",
530436c6c9cSStella Laurenzo            [](PyAffineMap &self) {
531436c6c9cSStella Laurenzo              PyPrintAccumulator printAccum;
532436c6c9cSStella Laurenzo              mlirAffineMapPrint(self, printAccum.getCallback(),
533436c6c9cSStella Laurenzo                                 printAccum.getUserData());
534436c6c9cSStella Laurenzo              return printAccum.join();
535436c6c9cSStella Laurenzo            })
536436c6c9cSStella Laurenzo       .def("__repr__",
537436c6c9cSStella Laurenzo            [](PyAffineMap &self) {
538436c6c9cSStella Laurenzo              PyPrintAccumulator printAccum;
539436c6c9cSStella Laurenzo              printAccum.parts.append("AffineMap(");
540436c6c9cSStella Laurenzo              mlirAffineMapPrint(self, printAccum.getCallback(),
541436c6c9cSStella Laurenzo                                 printAccum.getUserData());
542436c6c9cSStella Laurenzo              printAccum.parts.append(")");
543436c6c9cSStella Laurenzo              return printAccum.join();
544436c6c9cSStella Laurenzo            })
545335d2df5SNicolas Vasilache       .def_static("compress_unused_symbols",
546335d2df5SNicolas Vasilache                   [](py::list affineMaps, DefaultingPyMlirContext context) {
547335d2df5SNicolas Vasilache                     SmallVector<MlirAffineMap> maps;
548335d2df5SNicolas Vasilache                     pyListToVector<PyAffineMap, MlirAffineMap>(
549335d2df5SNicolas Vasilache                         affineMaps, maps, "attempting to create an AffineMap");
550335d2df5SNicolas Vasilache                     std::vector<MlirAffineMap> compressed(affineMaps.size());
551335d2df5SNicolas Vasilache                     auto populate = [](void *result, intptr_t idx,
552335d2df5SNicolas Vasilache                                        MlirAffineMap m) {
553335d2df5SNicolas Vasilache                       static_cast<MlirAffineMap *>(result)[idx] = (m);
554335d2df5SNicolas Vasilache                     };
555335d2df5SNicolas Vasilache                     mlirAffineMapCompressUnusedSymbols(
556335d2df5SNicolas Vasilache                         maps.data(), maps.size(), compressed.data(), populate);
557335d2df5SNicolas Vasilache                     std::vector<PyAffineMap> res;
558335d2df5SNicolas Vasilache                     for (auto m : compressed)
559335d2df5SNicolas Vasilache                       res.push_back(PyAffineMap(context->getRef(), m));
560335d2df5SNicolas Vasilache                     return res;
561335d2df5SNicolas Vasilache                   })
562436c6c9cSStella Laurenzo       .def_property_readonly(
563436c6c9cSStella Laurenzo           "context",
564436c6c9cSStella Laurenzo           [](PyAffineMap &self) { return self.getContext().getObject(); },
565436c6c9cSStella Laurenzo           "Context that owns the Affine Map")
566436c6c9cSStella Laurenzo       .def(
567436c6c9cSStella Laurenzo           "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); },
568436c6c9cSStella Laurenzo           kDumpDocstring)
569436c6c9cSStella Laurenzo       .def_static(
570436c6c9cSStella Laurenzo           "get",
571436c6c9cSStella Laurenzo           [](intptr_t dimCount, intptr_t symbolCount, py::list exprs,
572436c6c9cSStella Laurenzo              DefaultingPyMlirContext context) {
573436c6c9cSStella Laurenzo             SmallVector<MlirAffineExpr> affineExprs;
574436c6c9cSStella Laurenzo             pyListToVector<PyAffineExpr, MlirAffineExpr>(
575436c6c9cSStella Laurenzo                 exprs, affineExprs, "attempting to create an AffineMap");
576436c6c9cSStella Laurenzo             MlirAffineMap map =
577436c6c9cSStella Laurenzo                 mlirAffineMapGet(context->get(), dimCount, symbolCount,
578436c6c9cSStella Laurenzo                                  affineExprs.size(), affineExprs.data());
579436c6c9cSStella Laurenzo             return PyAffineMap(context->getRef(), map);
580436c6c9cSStella Laurenzo           },
581436c6c9cSStella Laurenzo           py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"),
582436c6c9cSStella Laurenzo           py::arg("context") = py::none(),
583436c6c9cSStella Laurenzo           "Gets a map with the given expressions as results.")
584436c6c9cSStella Laurenzo       .def_static(
585436c6c9cSStella Laurenzo           "get_constant",
586436c6c9cSStella Laurenzo           [](intptr_t value, DefaultingPyMlirContext context) {
587436c6c9cSStella Laurenzo             MlirAffineMap affineMap =
588436c6c9cSStella Laurenzo                 mlirAffineMapConstantGet(context->get(), value);
589436c6c9cSStella Laurenzo             return PyAffineMap(context->getRef(), affineMap);
590436c6c9cSStella Laurenzo           },
591436c6c9cSStella Laurenzo           py::arg("value"), py::arg("context") = py::none(),
592436c6c9cSStella Laurenzo           "Gets an affine map with a single constant result")
593436c6c9cSStella Laurenzo       .def_static(
594436c6c9cSStella Laurenzo           "get_empty",
595436c6c9cSStella Laurenzo           [](DefaultingPyMlirContext context) {
596436c6c9cSStella Laurenzo             MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get());
597436c6c9cSStella Laurenzo             return PyAffineMap(context->getRef(), affineMap);
598436c6c9cSStella Laurenzo           },
599436c6c9cSStella Laurenzo           py::arg("context") = py::none(), "Gets an empty affine map.")
600436c6c9cSStella Laurenzo       .def_static(
601436c6c9cSStella Laurenzo           "get_identity",
602436c6c9cSStella Laurenzo           [](intptr_t nDims, DefaultingPyMlirContext context) {
603436c6c9cSStella Laurenzo             MlirAffineMap affineMap =
604436c6c9cSStella Laurenzo                 mlirAffineMapMultiDimIdentityGet(context->get(), nDims);
605436c6c9cSStella Laurenzo             return PyAffineMap(context->getRef(), affineMap);
606436c6c9cSStella Laurenzo           },
607436c6c9cSStella Laurenzo           py::arg("n_dims"), py::arg("context") = py::none(),
608436c6c9cSStella Laurenzo           "Gets an identity map with the given number of dimensions.")
609436c6c9cSStella Laurenzo       .def_static(
610436c6c9cSStella Laurenzo           "get_minor_identity",
611436c6c9cSStella Laurenzo           [](intptr_t nDims, intptr_t nResults,
612436c6c9cSStella Laurenzo              DefaultingPyMlirContext context) {
613436c6c9cSStella Laurenzo             MlirAffineMap affineMap =
614436c6c9cSStella Laurenzo                 mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults);
615436c6c9cSStella Laurenzo             return PyAffineMap(context->getRef(), affineMap);
616436c6c9cSStella Laurenzo           },
617436c6c9cSStella Laurenzo           py::arg("n_dims"), py::arg("n_results"),
618436c6c9cSStella Laurenzo           py::arg("context") = py::none(),
619436c6c9cSStella Laurenzo           "Gets a minor identity map with the given number of dimensions and "
620436c6c9cSStella Laurenzo           "results.")
621436c6c9cSStella Laurenzo       .def_static(
622436c6c9cSStella Laurenzo           "get_permutation",
623436c6c9cSStella Laurenzo           [](std::vector<unsigned> permutation,
624436c6c9cSStella Laurenzo              DefaultingPyMlirContext context) {
625436c6c9cSStella Laurenzo             if (!isPermutation(permutation))
626436c6c9cSStella Laurenzo               throw py::cast_error("Invalid permutation when attempting to "
627436c6c9cSStella Laurenzo                                    "create an AffineMap");
628436c6c9cSStella Laurenzo             MlirAffineMap affineMap = mlirAffineMapPermutationGet(
629436c6c9cSStella Laurenzo                 context->get(), permutation.size(), permutation.data());
630436c6c9cSStella Laurenzo             return PyAffineMap(context->getRef(), affineMap);
631436c6c9cSStella Laurenzo           },
632436c6c9cSStella Laurenzo           py::arg("permutation"), py::arg("context") = py::none(),
633436c6c9cSStella Laurenzo           "Gets an affine map that permutes its inputs.")
634436c6c9cSStella Laurenzo       .def("get_submap",
635436c6c9cSStella Laurenzo            [](PyAffineMap &self, std::vector<intptr_t> &resultPos) {
636436c6c9cSStella Laurenzo              intptr_t numResults = mlirAffineMapGetNumResults(self);
637436c6c9cSStella Laurenzo              for (intptr_t pos : resultPos) {
638436c6c9cSStella Laurenzo                if (pos < 0 || pos >= numResults)
639436c6c9cSStella Laurenzo                  throw py::value_error("result position out of bounds");
640436c6c9cSStella Laurenzo              }
641436c6c9cSStella Laurenzo              MlirAffineMap affineMap = mlirAffineMapGetSubMap(
642436c6c9cSStella Laurenzo                  self, resultPos.size(), resultPos.data());
643436c6c9cSStella Laurenzo              return PyAffineMap(self.getContext(), affineMap);
644436c6c9cSStella Laurenzo            })
645436c6c9cSStella Laurenzo       .def("get_major_submap",
646436c6c9cSStella Laurenzo            [](PyAffineMap &self, intptr_t nResults) {
647436c6c9cSStella Laurenzo              if (nResults >= mlirAffineMapGetNumResults(self))
648436c6c9cSStella Laurenzo                throw py::value_error("number of results out of bounds");
649436c6c9cSStella Laurenzo              MlirAffineMap affineMap =
650436c6c9cSStella Laurenzo                  mlirAffineMapGetMajorSubMap(self, nResults);
651436c6c9cSStella Laurenzo              return PyAffineMap(self.getContext(), affineMap);
652436c6c9cSStella Laurenzo            })
653436c6c9cSStella Laurenzo       .def("get_minor_submap",
654436c6c9cSStella Laurenzo            [](PyAffineMap &self, intptr_t nResults) {
655436c6c9cSStella Laurenzo              if (nResults >= mlirAffineMapGetNumResults(self))
656436c6c9cSStella Laurenzo                throw py::value_error("number of results out of bounds");
657436c6c9cSStella Laurenzo              MlirAffineMap affineMap =
658436c6c9cSStella Laurenzo                  mlirAffineMapGetMinorSubMap(self, nResults);
659436c6c9cSStella Laurenzo              return PyAffineMap(self.getContext(), affineMap);
660436c6c9cSStella Laurenzo            })
66131f888eaSTobias Gysi       .def("replace",
66231f888eaSTobias Gysi            [](PyAffineMap &self, PyAffineExpr &expression,
66331f888eaSTobias Gysi               PyAffineExpr &replacement, intptr_t numResultDims,
66431f888eaSTobias Gysi               intptr_t numResultSyms) {
66531f888eaSTobias Gysi              MlirAffineMap affineMap = mlirAffineMapReplace(
66631f888eaSTobias Gysi                  self, expression, replacement, numResultDims, numResultSyms);
66731f888eaSTobias Gysi              return PyAffineMap(self.getContext(), affineMap);
66831f888eaSTobias Gysi            })
669436c6c9cSStella Laurenzo       .def_property_readonly(
670436c6c9cSStella Laurenzo           "is_permutation",
671436c6c9cSStella Laurenzo           [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); })
672436c6c9cSStella Laurenzo       .def_property_readonly("is_projected_permutation",
673436c6c9cSStella Laurenzo                              [](PyAffineMap &self) {
674436c6c9cSStella Laurenzo                                return mlirAffineMapIsProjectedPermutation(self);
675436c6c9cSStella Laurenzo                              })
676436c6c9cSStella Laurenzo       .def_property_readonly(
677436c6c9cSStella Laurenzo           "n_dims",
678436c6c9cSStella Laurenzo           [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); })
679436c6c9cSStella Laurenzo       .def_property_readonly(
680436c6c9cSStella Laurenzo           "n_inputs",
681436c6c9cSStella Laurenzo           [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); })
682436c6c9cSStella Laurenzo       .def_property_readonly(
683436c6c9cSStella Laurenzo           "n_symbols",
684436c6c9cSStella Laurenzo           [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); })
685436c6c9cSStella Laurenzo       .def_property_readonly("results", [](PyAffineMap &self) {
686436c6c9cSStella Laurenzo         return PyAffineMapExprList(self);
687436c6c9cSStella Laurenzo       });
688436c6c9cSStella Laurenzo   PyAffineMapExprList::bind(m);
689436c6c9cSStella Laurenzo 
690436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
691436c6c9cSStella Laurenzo   // Mapping of PyIntegerSet.
692436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
693f05ff4f7SStella Laurenzo   py::class_<PyIntegerSet>(m, "IntegerSet", py::module_local())
694436c6c9cSStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
695436c6c9cSStella Laurenzo                              &PyIntegerSet::getCapsule)
696436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule)
697436c6c9cSStella Laurenzo       .def("__eq__", [](PyIntegerSet &self,
698436c6c9cSStella Laurenzo                         PyIntegerSet &other) { return self == other; })
699436c6c9cSStella Laurenzo       .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; })
700436c6c9cSStella Laurenzo       .def("__str__",
701436c6c9cSStella Laurenzo            [](PyIntegerSet &self) {
702436c6c9cSStella Laurenzo              PyPrintAccumulator printAccum;
703436c6c9cSStella Laurenzo              mlirIntegerSetPrint(self, printAccum.getCallback(),
704436c6c9cSStella Laurenzo                                  printAccum.getUserData());
705436c6c9cSStella Laurenzo              return printAccum.join();
706436c6c9cSStella Laurenzo            })
707436c6c9cSStella Laurenzo       .def("__repr__",
708436c6c9cSStella Laurenzo            [](PyIntegerSet &self) {
709436c6c9cSStella Laurenzo              PyPrintAccumulator printAccum;
710436c6c9cSStella Laurenzo              printAccum.parts.append("IntegerSet(");
711436c6c9cSStella Laurenzo              mlirIntegerSetPrint(self, printAccum.getCallback(),
712436c6c9cSStella Laurenzo                                  printAccum.getUserData());
713436c6c9cSStella Laurenzo              printAccum.parts.append(")");
714436c6c9cSStella Laurenzo              return printAccum.join();
715436c6c9cSStella Laurenzo            })
716436c6c9cSStella Laurenzo       .def_property_readonly(
717436c6c9cSStella Laurenzo           "context",
718436c6c9cSStella Laurenzo           [](PyIntegerSet &self) { return self.getContext().getObject(); })
719436c6c9cSStella Laurenzo       .def(
720436c6c9cSStella Laurenzo           "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); },
721436c6c9cSStella Laurenzo           kDumpDocstring)
722436c6c9cSStella Laurenzo       .def_static(
723436c6c9cSStella Laurenzo           "get",
724436c6c9cSStella Laurenzo           [](intptr_t numDims, intptr_t numSymbols, py::list exprs,
725436c6c9cSStella Laurenzo              std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
726436c6c9cSStella Laurenzo             if (exprs.size() != eqFlags.size())
727436c6c9cSStella Laurenzo               throw py::value_error(
728436c6c9cSStella Laurenzo                   "Expected the number of constraints to match "
729436c6c9cSStella Laurenzo                   "that of equality flags");
730436c6c9cSStella Laurenzo             if (exprs.empty())
731436c6c9cSStella Laurenzo               throw py::value_error("Expected non-empty list of constraints");
732436c6c9cSStella Laurenzo 
733436c6c9cSStella Laurenzo             // Copy over to a SmallVector because std::vector has a
734436c6c9cSStella Laurenzo             // specialization for booleans that packs data and does not
735436c6c9cSStella Laurenzo             // expose a `bool *`.
736436c6c9cSStella Laurenzo             SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end());
737436c6c9cSStella Laurenzo 
738436c6c9cSStella Laurenzo             SmallVector<MlirAffineExpr> affineExprs;
739436c6c9cSStella Laurenzo             pyListToVector<PyAffineExpr>(exprs, affineExprs,
740436c6c9cSStella Laurenzo                                          "attempting to create an IntegerSet");
741436c6c9cSStella Laurenzo             MlirIntegerSet set = mlirIntegerSetGet(
742436c6c9cSStella Laurenzo                 context->get(), numDims, numSymbols, exprs.size(),
743436c6c9cSStella Laurenzo                 affineExprs.data(), flags.data());
744436c6c9cSStella Laurenzo             return PyIntegerSet(context->getRef(), set);
745436c6c9cSStella Laurenzo           },
746436c6c9cSStella Laurenzo           py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"),
747436c6c9cSStella Laurenzo           py::arg("eq_flags"), py::arg("context") = py::none())
748436c6c9cSStella Laurenzo       .def_static(
749436c6c9cSStella Laurenzo           "get_empty",
750436c6c9cSStella Laurenzo           [](intptr_t numDims, intptr_t numSymbols,
751436c6c9cSStella Laurenzo              DefaultingPyMlirContext context) {
752436c6c9cSStella Laurenzo             MlirIntegerSet set =
753436c6c9cSStella Laurenzo                 mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols);
754436c6c9cSStella Laurenzo             return PyIntegerSet(context->getRef(), set);
755436c6c9cSStella Laurenzo           },
756436c6c9cSStella Laurenzo           py::arg("num_dims"), py::arg("num_symbols"),
757436c6c9cSStella Laurenzo           py::arg("context") = py::none())
758436c6c9cSStella Laurenzo       .def("get_replaced",
759436c6c9cSStella Laurenzo            [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs,
760436c6c9cSStella Laurenzo               intptr_t numResultDims, intptr_t numResultSymbols) {
761436c6c9cSStella Laurenzo              if (static_cast<intptr_t>(dimExprs.size()) !=
762436c6c9cSStella Laurenzo                  mlirIntegerSetGetNumDims(self))
763436c6c9cSStella Laurenzo                throw py::value_error(
764436c6c9cSStella Laurenzo                    "Expected the number of dimension replacement expressions "
765436c6c9cSStella Laurenzo                    "to match that of dimensions");
766436c6c9cSStella Laurenzo              if (static_cast<intptr_t>(symbolExprs.size()) !=
767436c6c9cSStella Laurenzo                  mlirIntegerSetGetNumSymbols(self))
768436c6c9cSStella Laurenzo                throw py::value_error(
769436c6c9cSStella Laurenzo                    "Expected the number of symbol replacement expressions "
770436c6c9cSStella Laurenzo                    "to match that of symbols");
771436c6c9cSStella Laurenzo 
772436c6c9cSStella Laurenzo              SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs;
773436c6c9cSStella Laurenzo              pyListToVector<PyAffineExpr>(
774436c6c9cSStella Laurenzo                  dimExprs, dimAffineExprs,
775436c6c9cSStella Laurenzo                  "attempting to create an IntegerSet by replacing dimensions");
776436c6c9cSStella Laurenzo              pyListToVector<PyAffineExpr>(
777436c6c9cSStella Laurenzo                  symbolExprs, symbolAffineExprs,
778436c6c9cSStella Laurenzo                  "attempting to create an IntegerSet by replacing symbols");
779436c6c9cSStella Laurenzo              MlirIntegerSet set = mlirIntegerSetReplaceGet(
780436c6c9cSStella Laurenzo                  self, dimAffineExprs.data(), symbolAffineExprs.data(),
781436c6c9cSStella Laurenzo                  numResultDims, numResultSymbols);
782436c6c9cSStella Laurenzo              return PyIntegerSet(self.getContext(), set);
783436c6c9cSStella Laurenzo            })
784436c6c9cSStella Laurenzo       .def_property_readonly("is_canonical_empty",
785436c6c9cSStella Laurenzo                              [](PyIntegerSet &self) {
786436c6c9cSStella Laurenzo                                return mlirIntegerSetIsCanonicalEmpty(self);
787436c6c9cSStella Laurenzo                              })
788436c6c9cSStella Laurenzo       .def_property_readonly(
789436c6c9cSStella Laurenzo           "n_dims",
790436c6c9cSStella Laurenzo           [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); })
791436c6c9cSStella Laurenzo       .def_property_readonly(
792436c6c9cSStella Laurenzo           "n_symbols",
793436c6c9cSStella Laurenzo           [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); })
794436c6c9cSStella Laurenzo       .def_property_readonly(
795436c6c9cSStella Laurenzo           "n_inputs",
796436c6c9cSStella Laurenzo           [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); })
797436c6c9cSStella Laurenzo       .def_property_readonly("n_equalities",
798436c6c9cSStella Laurenzo                              [](PyIntegerSet &self) {
799436c6c9cSStella Laurenzo                                return mlirIntegerSetGetNumEqualities(self);
800436c6c9cSStella Laurenzo                              })
801436c6c9cSStella Laurenzo       .def_property_readonly("n_inequalities",
802436c6c9cSStella Laurenzo                              [](PyIntegerSet &self) {
803436c6c9cSStella Laurenzo                                return mlirIntegerSetGetNumInequalities(self);
804436c6c9cSStella Laurenzo                              })
805436c6c9cSStella Laurenzo       .def_property_readonly("constraints", [](PyIntegerSet &self) {
806436c6c9cSStella Laurenzo         return PyIntegerSetConstraintList(self);
807436c6c9cSStella Laurenzo       });
808436c6c9cSStella Laurenzo   PyIntegerSetConstraint::bind(m);
809436c6c9cSStella Laurenzo   PyIntegerSetConstraintList::bind(m);
810436c6c9cSStella Laurenzo }
811