xref: /llvm-project/mlir/lib/Bindings/Python/IRAffine.cpp (revision 436c6c9c20cc522c92a923440a5fc509c342a7db)
1*436c6c9cSStella Laurenzo //===- IRAffine.cpp - Exports 'ir' module affine related bindings ---------===//
2*436c6c9cSStella Laurenzo //
3*436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5*436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*436c6c9cSStella Laurenzo //
7*436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===//
8*436c6c9cSStella Laurenzo 
9*436c6c9cSStella Laurenzo #include "IRModule.h"
10*436c6c9cSStella Laurenzo 
11*436c6c9cSStella Laurenzo #include "PybindUtils.h"
12*436c6c9cSStella Laurenzo 
13*436c6c9cSStella Laurenzo #include "mlir-c/AffineMap.h"
14*436c6c9cSStella Laurenzo #include "mlir-c/Bindings/Python/Interop.h"
15*436c6c9cSStella Laurenzo #include "mlir-c/IntegerSet.h"
16*436c6c9cSStella Laurenzo 
17*436c6c9cSStella Laurenzo namespace py = pybind11;
18*436c6c9cSStella Laurenzo using namespace mlir;
19*436c6c9cSStella Laurenzo using namespace mlir::python;
20*436c6c9cSStella Laurenzo 
21*436c6c9cSStella Laurenzo using llvm::SmallVector;
22*436c6c9cSStella Laurenzo using llvm::StringRef;
23*436c6c9cSStella Laurenzo using llvm::Twine;
24*436c6c9cSStella Laurenzo 
25*436c6c9cSStella Laurenzo static const char kDumpDocstring[] =
26*436c6c9cSStella Laurenzo     R"(Dumps a debug representation of the object to stderr.)";
27*436c6c9cSStella Laurenzo 
28*436c6c9cSStella Laurenzo /// Attempts to populate `result` with the content of `list` casted to the
29*436c6c9cSStella Laurenzo /// appropriate type (Python and C types are provided as template arguments).
30*436c6c9cSStella Laurenzo /// Throws errors in case of failure, using "action" to describe what the caller
31*436c6c9cSStella Laurenzo /// was attempting to do.
32*436c6c9cSStella Laurenzo template <typename PyType, typename CType>
33*436c6c9cSStella Laurenzo static void pyListToVector(py::list list, llvm::SmallVectorImpl<CType> &result,
34*436c6c9cSStella Laurenzo                            StringRef action) {
35*436c6c9cSStella Laurenzo   result.reserve(py::len(list));
36*436c6c9cSStella Laurenzo   for (py::handle item : list) {
37*436c6c9cSStella Laurenzo     try {
38*436c6c9cSStella Laurenzo       result.push_back(item.cast<PyType>());
39*436c6c9cSStella Laurenzo     } catch (py::cast_error &err) {
40*436c6c9cSStella Laurenzo       std::string msg = (llvm::Twine("Invalid expression when ") + action +
41*436c6c9cSStella Laurenzo                          " (" + err.what() + ")")
42*436c6c9cSStella Laurenzo                             .str();
43*436c6c9cSStella Laurenzo       throw py::cast_error(msg);
44*436c6c9cSStella Laurenzo     } catch (py::reference_cast_error &err) {
45*436c6c9cSStella Laurenzo       std::string msg = (llvm::Twine("Invalid expression (None?) when ") +
46*436c6c9cSStella Laurenzo                          action + " (" + err.what() + ")")
47*436c6c9cSStella Laurenzo                             .str();
48*436c6c9cSStella Laurenzo       throw py::cast_error(msg);
49*436c6c9cSStella Laurenzo     }
50*436c6c9cSStella Laurenzo   }
51*436c6c9cSStella Laurenzo }
52*436c6c9cSStella Laurenzo 
53*436c6c9cSStella Laurenzo template <typename PermutationTy>
54*436c6c9cSStella Laurenzo static bool isPermutation(std::vector<PermutationTy> permutation) {
55*436c6c9cSStella Laurenzo   llvm::SmallVector<bool, 8> seen(permutation.size(), false);
56*436c6c9cSStella Laurenzo   for (auto val : permutation) {
57*436c6c9cSStella Laurenzo     if (val < permutation.size()) {
58*436c6c9cSStella Laurenzo       if (seen[val])
59*436c6c9cSStella Laurenzo         return false;
60*436c6c9cSStella Laurenzo       seen[val] = true;
61*436c6c9cSStella Laurenzo       continue;
62*436c6c9cSStella Laurenzo     }
63*436c6c9cSStella Laurenzo     return false;
64*436c6c9cSStella Laurenzo   }
65*436c6c9cSStella Laurenzo   return true;
66*436c6c9cSStella Laurenzo }
67*436c6c9cSStella Laurenzo 
68*436c6c9cSStella Laurenzo namespace {
69*436c6c9cSStella Laurenzo 
70*436c6c9cSStella Laurenzo /// CRTP base class for Python MLIR affine expressions that subclass AffineExpr
71*436c6c9cSStella Laurenzo /// and should be castable from it. Intermediate hierarchy classes can be
72*436c6c9cSStella Laurenzo /// modeled by specifying BaseTy.
73*436c6c9cSStella Laurenzo template <typename DerivedTy, typename BaseTy = PyAffineExpr>
74*436c6c9cSStella Laurenzo class PyConcreteAffineExpr : public BaseTy {
75*436c6c9cSStella Laurenzo public:
76*436c6c9cSStella Laurenzo   // Derived classes must define statics for:
77*436c6c9cSStella Laurenzo   //   IsAFunctionTy isaFunction
78*436c6c9cSStella Laurenzo   //   const char *pyClassName
79*436c6c9cSStella Laurenzo   // and redefine bindDerived.
80*436c6c9cSStella Laurenzo   using ClassTy = py::class_<DerivedTy, BaseTy>;
81*436c6c9cSStella Laurenzo   using IsAFunctionTy = bool (*)(MlirAffineExpr);
82*436c6c9cSStella Laurenzo 
83*436c6c9cSStella Laurenzo   PyConcreteAffineExpr() = default;
84*436c6c9cSStella Laurenzo   PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
85*436c6c9cSStella Laurenzo       : BaseTy(std::move(contextRef), affineExpr) {}
86*436c6c9cSStella Laurenzo   PyConcreteAffineExpr(PyAffineExpr &orig)
87*436c6c9cSStella Laurenzo       : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {}
88*436c6c9cSStella Laurenzo 
89*436c6c9cSStella Laurenzo   static MlirAffineExpr castFrom(PyAffineExpr &orig) {
90*436c6c9cSStella Laurenzo     if (!DerivedTy::isaFunction(orig)) {
91*436c6c9cSStella Laurenzo       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
92*436c6c9cSStella Laurenzo       throw SetPyError(PyExc_ValueError,
93*436c6c9cSStella Laurenzo                        Twine("Cannot cast affine expression to ") +
94*436c6c9cSStella Laurenzo                            DerivedTy::pyClassName + " (from " + origRepr + ")");
95*436c6c9cSStella Laurenzo     }
96*436c6c9cSStella Laurenzo     return orig;
97*436c6c9cSStella Laurenzo   }
98*436c6c9cSStella Laurenzo 
99*436c6c9cSStella Laurenzo   static void bind(py::module &m) {
100*436c6c9cSStella Laurenzo     auto cls = ClassTy(m, DerivedTy::pyClassName);
101*436c6c9cSStella Laurenzo     cls.def(py::init<PyAffineExpr &>());
102*436c6c9cSStella Laurenzo     DerivedTy::bindDerived(cls);
103*436c6c9cSStella Laurenzo   }
104*436c6c9cSStella Laurenzo 
105*436c6c9cSStella Laurenzo   /// Implemented by derived classes to add methods to the Python subclass.
106*436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &m) {}
107*436c6c9cSStella Laurenzo };
108*436c6c9cSStella Laurenzo 
109*436c6c9cSStella Laurenzo class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> {
110*436c6c9cSStella Laurenzo public:
111*436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant;
112*436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineConstantExpr";
113*436c6c9cSStella Laurenzo   using PyConcreteAffineExpr::PyConcreteAffineExpr;
114*436c6c9cSStella Laurenzo 
115*436c6c9cSStella Laurenzo   static PyAffineConstantExpr get(intptr_t value,
116*436c6c9cSStella Laurenzo                                   DefaultingPyMlirContext context) {
117*436c6c9cSStella Laurenzo     MlirAffineExpr affineExpr =
118*436c6c9cSStella Laurenzo         mlirAffineConstantExprGet(context->get(), static_cast<int64_t>(value));
119*436c6c9cSStella Laurenzo     return PyAffineConstantExpr(context->getRef(), affineExpr);
120*436c6c9cSStella Laurenzo   }
121*436c6c9cSStella Laurenzo 
122*436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
123*436c6c9cSStella Laurenzo     c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"),
124*436c6c9cSStella Laurenzo                  py::arg("context") = py::none());
125*436c6c9cSStella Laurenzo     c.def_property_readonly("value", [](PyAffineConstantExpr &self) {
126*436c6c9cSStella Laurenzo       return mlirAffineConstantExprGetValue(self);
127*436c6c9cSStella Laurenzo     });
128*436c6c9cSStella Laurenzo   }
129*436c6c9cSStella Laurenzo };
130*436c6c9cSStella Laurenzo 
131*436c6c9cSStella Laurenzo class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> {
132*436c6c9cSStella Laurenzo public:
133*436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim;
134*436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineDimExpr";
135*436c6c9cSStella Laurenzo   using PyConcreteAffineExpr::PyConcreteAffineExpr;
136*436c6c9cSStella Laurenzo 
137*436c6c9cSStella Laurenzo   static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) {
138*436c6c9cSStella Laurenzo     MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos);
139*436c6c9cSStella Laurenzo     return PyAffineDimExpr(context->getRef(), affineExpr);
140*436c6c9cSStella Laurenzo   }
141*436c6c9cSStella Laurenzo 
142*436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
143*436c6c9cSStella Laurenzo     c.def_static("get", &PyAffineDimExpr::get, py::arg("position"),
144*436c6c9cSStella Laurenzo                  py::arg("context") = py::none());
145*436c6c9cSStella Laurenzo     c.def_property_readonly("position", [](PyAffineDimExpr &self) {
146*436c6c9cSStella Laurenzo       return mlirAffineDimExprGetPosition(self);
147*436c6c9cSStella Laurenzo     });
148*436c6c9cSStella Laurenzo   }
149*436c6c9cSStella Laurenzo };
150*436c6c9cSStella Laurenzo 
151*436c6c9cSStella Laurenzo class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> {
152*436c6c9cSStella Laurenzo public:
153*436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol;
154*436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineSymbolExpr";
155*436c6c9cSStella Laurenzo   using PyConcreteAffineExpr::PyConcreteAffineExpr;
156*436c6c9cSStella Laurenzo 
157*436c6c9cSStella Laurenzo   static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) {
158*436c6c9cSStella Laurenzo     MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos);
159*436c6c9cSStella Laurenzo     return PyAffineSymbolExpr(context->getRef(), affineExpr);
160*436c6c9cSStella Laurenzo   }
161*436c6c9cSStella Laurenzo 
162*436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
163*436c6c9cSStella Laurenzo     c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"),
164*436c6c9cSStella Laurenzo                  py::arg("context") = py::none());
165*436c6c9cSStella Laurenzo     c.def_property_readonly("position", [](PyAffineSymbolExpr &self) {
166*436c6c9cSStella Laurenzo       return mlirAffineSymbolExprGetPosition(self);
167*436c6c9cSStella Laurenzo     });
168*436c6c9cSStella Laurenzo   }
169*436c6c9cSStella Laurenzo };
170*436c6c9cSStella Laurenzo 
171*436c6c9cSStella Laurenzo class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> {
172*436c6c9cSStella Laurenzo public:
173*436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary;
174*436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineBinaryExpr";
175*436c6c9cSStella Laurenzo   using PyConcreteAffineExpr::PyConcreteAffineExpr;
176*436c6c9cSStella Laurenzo 
177*436c6c9cSStella Laurenzo   PyAffineExpr lhs() {
178*436c6c9cSStella Laurenzo     MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get());
179*436c6c9cSStella Laurenzo     return PyAffineExpr(getContext(), lhsExpr);
180*436c6c9cSStella Laurenzo   }
181*436c6c9cSStella Laurenzo 
182*436c6c9cSStella Laurenzo   PyAffineExpr rhs() {
183*436c6c9cSStella Laurenzo     MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get());
184*436c6c9cSStella Laurenzo     return PyAffineExpr(getContext(), rhsExpr);
185*436c6c9cSStella Laurenzo   }
186*436c6c9cSStella Laurenzo 
187*436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
188*436c6c9cSStella Laurenzo     c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs);
189*436c6c9cSStella Laurenzo     c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs);
190*436c6c9cSStella Laurenzo   }
191*436c6c9cSStella Laurenzo };
192*436c6c9cSStella Laurenzo 
193*436c6c9cSStella Laurenzo class PyAffineAddExpr
194*436c6c9cSStella Laurenzo     : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> {
195*436c6c9cSStella Laurenzo public:
196*436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd;
197*436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineAddExpr";
198*436c6c9cSStella Laurenzo   using PyConcreteAffineExpr::PyConcreteAffineExpr;
199*436c6c9cSStella Laurenzo 
200*436c6c9cSStella Laurenzo   static PyAffineAddExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
201*436c6c9cSStella Laurenzo     MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs);
202*436c6c9cSStella Laurenzo     return PyAffineAddExpr(lhs.getContext(), expr);
203*436c6c9cSStella Laurenzo   }
204*436c6c9cSStella Laurenzo 
205*436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
206*436c6c9cSStella Laurenzo     c.def_static("get", &PyAffineAddExpr::get);
207*436c6c9cSStella Laurenzo   }
208*436c6c9cSStella Laurenzo };
209*436c6c9cSStella Laurenzo 
210*436c6c9cSStella Laurenzo class PyAffineMulExpr
211*436c6c9cSStella Laurenzo     : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> {
212*436c6c9cSStella Laurenzo public:
213*436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul;
214*436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineMulExpr";
215*436c6c9cSStella Laurenzo   using PyConcreteAffineExpr::PyConcreteAffineExpr;
216*436c6c9cSStella Laurenzo 
217*436c6c9cSStella Laurenzo   static PyAffineMulExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
218*436c6c9cSStella Laurenzo     MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs);
219*436c6c9cSStella Laurenzo     return PyAffineMulExpr(lhs.getContext(), expr);
220*436c6c9cSStella Laurenzo   }
221*436c6c9cSStella Laurenzo 
222*436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
223*436c6c9cSStella Laurenzo     c.def_static("get", &PyAffineMulExpr::get);
224*436c6c9cSStella Laurenzo   }
225*436c6c9cSStella Laurenzo };
226*436c6c9cSStella Laurenzo 
227*436c6c9cSStella Laurenzo class PyAffineModExpr
228*436c6c9cSStella Laurenzo     : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> {
229*436c6c9cSStella Laurenzo public:
230*436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod;
231*436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineModExpr";
232*436c6c9cSStella Laurenzo   using PyConcreteAffineExpr::PyConcreteAffineExpr;
233*436c6c9cSStella Laurenzo 
234*436c6c9cSStella Laurenzo   static PyAffineModExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
235*436c6c9cSStella Laurenzo     MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs);
236*436c6c9cSStella Laurenzo     return PyAffineModExpr(lhs.getContext(), expr);
237*436c6c9cSStella Laurenzo   }
238*436c6c9cSStella Laurenzo 
239*436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
240*436c6c9cSStella Laurenzo     c.def_static("get", &PyAffineModExpr::get);
241*436c6c9cSStella Laurenzo   }
242*436c6c9cSStella Laurenzo };
243*436c6c9cSStella Laurenzo 
244*436c6c9cSStella Laurenzo class PyAffineFloorDivExpr
245*436c6c9cSStella Laurenzo     : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> {
246*436c6c9cSStella Laurenzo public:
247*436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv;
248*436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineFloorDivExpr";
249*436c6c9cSStella Laurenzo   using PyConcreteAffineExpr::PyConcreteAffineExpr;
250*436c6c9cSStella Laurenzo 
251*436c6c9cSStella Laurenzo   static PyAffineFloorDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
252*436c6c9cSStella Laurenzo     MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs);
253*436c6c9cSStella Laurenzo     return PyAffineFloorDivExpr(lhs.getContext(), expr);
254*436c6c9cSStella Laurenzo   }
255*436c6c9cSStella Laurenzo 
256*436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
257*436c6c9cSStella Laurenzo     c.def_static("get", &PyAffineFloorDivExpr::get);
258*436c6c9cSStella Laurenzo   }
259*436c6c9cSStella Laurenzo };
260*436c6c9cSStella Laurenzo 
261*436c6c9cSStella Laurenzo class PyAffineCeilDivExpr
262*436c6c9cSStella Laurenzo     : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> {
263*436c6c9cSStella Laurenzo public:
264*436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv;
265*436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineCeilDivExpr";
266*436c6c9cSStella Laurenzo   using PyConcreteAffineExpr::PyConcreteAffineExpr;
267*436c6c9cSStella Laurenzo 
268*436c6c9cSStella Laurenzo   static PyAffineCeilDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
269*436c6c9cSStella Laurenzo     MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs);
270*436c6c9cSStella Laurenzo     return PyAffineCeilDivExpr(lhs.getContext(), expr);
271*436c6c9cSStella Laurenzo   }
272*436c6c9cSStella Laurenzo 
273*436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
274*436c6c9cSStella Laurenzo     c.def_static("get", &PyAffineCeilDivExpr::get);
275*436c6c9cSStella Laurenzo   }
276*436c6c9cSStella Laurenzo };
277*436c6c9cSStella Laurenzo 
278*436c6c9cSStella Laurenzo } // namespace
279*436c6c9cSStella Laurenzo 
280*436c6c9cSStella Laurenzo bool PyAffineExpr::operator==(const PyAffineExpr &other) {
281*436c6c9cSStella Laurenzo   return mlirAffineExprEqual(affineExpr, other.affineExpr);
282*436c6c9cSStella Laurenzo }
283*436c6c9cSStella Laurenzo 
284*436c6c9cSStella Laurenzo py::object PyAffineExpr::getCapsule() {
285*436c6c9cSStella Laurenzo   return py::reinterpret_steal<py::object>(
286*436c6c9cSStella Laurenzo       mlirPythonAffineExprToCapsule(*this));
287*436c6c9cSStella Laurenzo }
288*436c6c9cSStella Laurenzo 
289*436c6c9cSStella Laurenzo PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) {
290*436c6c9cSStella Laurenzo   MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr());
291*436c6c9cSStella Laurenzo   if (mlirAffineExprIsNull(rawAffineExpr))
292*436c6c9cSStella Laurenzo     throw py::error_already_set();
293*436c6c9cSStella Laurenzo   return PyAffineExpr(
294*436c6c9cSStella Laurenzo       PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)),
295*436c6c9cSStella Laurenzo       rawAffineExpr);
296*436c6c9cSStella Laurenzo }
297*436c6c9cSStella Laurenzo 
298*436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
299*436c6c9cSStella Laurenzo // PyAffineMap and utilities.
300*436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
301*436c6c9cSStella Laurenzo namespace {
302*436c6c9cSStella Laurenzo 
303*436c6c9cSStella Laurenzo /// A list of expressions contained in an affine map. Internally these are
304*436c6c9cSStella Laurenzo /// stored as a consecutive array leading to inexpensive random access. Both
305*436c6c9cSStella Laurenzo /// the map and the expression are owned by the context so we need not bother
306*436c6c9cSStella Laurenzo /// with lifetime extension.
307*436c6c9cSStella Laurenzo class PyAffineMapExprList
308*436c6c9cSStella Laurenzo     : public Sliceable<PyAffineMapExprList, PyAffineExpr> {
309*436c6c9cSStella Laurenzo public:
310*436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineExprList";
311*436c6c9cSStella Laurenzo 
312*436c6c9cSStella Laurenzo   PyAffineMapExprList(PyAffineMap map, intptr_t startIndex = 0,
313*436c6c9cSStella Laurenzo                       intptr_t length = -1, intptr_t step = 1)
314*436c6c9cSStella Laurenzo       : Sliceable(startIndex,
315*436c6c9cSStella Laurenzo                   length == -1 ? mlirAffineMapGetNumResults(map) : length,
316*436c6c9cSStella Laurenzo                   step),
317*436c6c9cSStella Laurenzo         affineMap(map) {}
318*436c6c9cSStella Laurenzo 
319*436c6c9cSStella Laurenzo   intptr_t getNumElements() { return mlirAffineMapGetNumResults(affineMap); }
320*436c6c9cSStella Laurenzo 
321*436c6c9cSStella Laurenzo   PyAffineExpr getElement(intptr_t pos) {
322*436c6c9cSStella Laurenzo     return PyAffineExpr(affineMap.getContext(),
323*436c6c9cSStella Laurenzo                         mlirAffineMapGetResult(affineMap, pos));
324*436c6c9cSStella Laurenzo   }
325*436c6c9cSStella Laurenzo 
326*436c6c9cSStella Laurenzo   PyAffineMapExprList slice(intptr_t startIndex, intptr_t length,
327*436c6c9cSStella Laurenzo                             intptr_t step) {
328*436c6c9cSStella Laurenzo     return PyAffineMapExprList(affineMap, startIndex, length, step);
329*436c6c9cSStella Laurenzo   }
330*436c6c9cSStella Laurenzo 
331*436c6c9cSStella Laurenzo private:
332*436c6c9cSStella Laurenzo   PyAffineMap affineMap;
333*436c6c9cSStella Laurenzo };
334*436c6c9cSStella Laurenzo } // end namespace
335*436c6c9cSStella Laurenzo 
336*436c6c9cSStella Laurenzo bool PyAffineMap::operator==(const PyAffineMap &other) {
337*436c6c9cSStella Laurenzo   return mlirAffineMapEqual(affineMap, other.affineMap);
338*436c6c9cSStella Laurenzo }
339*436c6c9cSStella Laurenzo 
340*436c6c9cSStella Laurenzo py::object PyAffineMap::getCapsule() {
341*436c6c9cSStella Laurenzo   return py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(*this));
342*436c6c9cSStella Laurenzo }
343*436c6c9cSStella Laurenzo 
344*436c6c9cSStella Laurenzo PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) {
345*436c6c9cSStella Laurenzo   MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr());
346*436c6c9cSStella Laurenzo   if (mlirAffineMapIsNull(rawAffineMap))
347*436c6c9cSStella Laurenzo     throw py::error_already_set();
348*436c6c9cSStella Laurenzo   return PyAffineMap(
349*436c6c9cSStella Laurenzo       PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)),
350*436c6c9cSStella Laurenzo       rawAffineMap);
351*436c6c9cSStella Laurenzo }
352*436c6c9cSStella Laurenzo 
353*436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
354*436c6c9cSStella Laurenzo // PyIntegerSet and utilities.
355*436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
356*436c6c9cSStella Laurenzo namespace {
357*436c6c9cSStella Laurenzo 
358*436c6c9cSStella Laurenzo class PyIntegerSetConstraint {
359*436c6c9cSStella Laurenzo public:
360*436c6c9cSStella Laurenzo   PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) : set(set), pos(pos) {}
361*436c6c9cSStella Laurenzo 
362*436c6c9cSStella Laurenzo   PyAffineExpr getExpr() {
363*436c6c9cSStella Laurenzo     return PyAffineExpr(set.getContext(),
364*436c6c9cSStella Laurenzo                         mlirIntegerSetGetConstraint(set, pos));
365*436c6c9cSStella Laurenzo   }
366*436c6c9cSStella Laurenzo 
367*436c6c9cSStella Laurenzo   bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); }
368*436c6c9cSStella Laurenzo 
369*436c6c9cSStella Laurenzo   static void bind(py::module &m) {
370*436c6c9cSStella Laurenzo     py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint")
371*436c6c9cSStella Laurenzo         .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr)
372*436c6c9cSStella Laurenzo         .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq);
373*436c6c9cSStella Laurenzo   }
374*436c6c9cSStella Laurenzo 
375*436c6c9cSStella Laurenzo private:
376*436c6c9cSStella Laurenzo   PyIntegerSet set;
377*436c6c9cSStella Laurenzo   intptr_t pos;
378*436c6c9cSStella Laurenzo };
379*436c6c9cSStella Laurenzo 
380*436c6c9cSStella Laurenzo class PyIntegerSetConstraintList
381*436c6c9cSStella Laurenzo     : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> {
382*436c6c9cSStella Laurenzo public:
383*436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "IntegerSetConstraintList";
384*436c6c9cSStella Laurenzo 
385*436c6c9cSStella Laurenzo   PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0,
386*436c6c9cSStella Laurenzo                              intptr_t length = -1, intptr_t step = 1)
387*436c6c9cSStella Laurenzo       : Sliceable(startIndex,
388*436c6c9cSStella Laurenzo                   length == -1 ? mlirIntegerSetGetNumConstraints(set) : length,
389*436c6c9cSStella Laurenzo                   step),
390*436c6c9cSStella Laurenzo         set(set) {}
391*436c6c9cSStella Laurenzo 
392*436c6c9cSStella Laurenzo   intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); }
393*436c6c9cSStella Laurenzo 
394*436c6c9cSStella Laurenzo   PyIntegerSetConstraint getElement(intptr_t pos) {
395*436c6c9cSStella Laurenzo     return PyIntegerSetConstraint(set, pos);
396*436c6c9cSStella Laurenzo   }
397*436c6c9cSStella Laurenzo 
398*436c6c9cSStella Laurenzo   PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length,
399*436c6c9cSStella Laurenzo                                    intptr_t step) {
400*436c6c9cSStella Laurenzo     return PyIntegerSetConstraintList(set, startIndex, length, step);
401*436c6c9cSStella Laurenzo   }
402*436c6c9cSStella Laurenzo 
403*436c6c9cSStella Laurenzo private:
404*436c6c9cSStella Laurenzo   PyIntegerSet set;
405*436c6c9cSStella Laurenzo };
406*436c6c9cSStella Laurenzo } // namespace
407*436c6c9cSStella Laurenzo 
408*436c6c9cSStella Laurenzo bool PyIntegerSet::operator==(const PyIntegerSet &other) {
409*436c6c9cSStella Laurenzo   return mlirIntegerSetEqual(integerSet, other.integerSet);
410*436c6c9cSStella Laurenzo }
411*436c6c9cSStella Laurenzo 
412*436c6c9cSStella Laurenzo py::object PyIntegerSet::getCapsule() {
413*436c6c9cSStella Laurenzo   return py::reinterpret_steal<py::object>(
414*436c6c9cSStella Laurenzo       mlirPythonIntegerSetToCapsule(*this));
415*436c6c9cSStella Laurenzo }
416*436c6c9cSStella Laurenzo 
417*436c6c9cSStella Laurenzo PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) {
418*436c6c9cSStella Laurenzo   MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr());
419*436c6c9cSStella Laurenzo   if (mlirIntegerSetIsNull(rawIntegerSet))
420*436c6c9cSStella Laurenzo     throw py::error_already_set();
421*436c6c9cSStella Laurenzo   return PyIntegerSet(
422*436c6c9cSStella Laurenzo       PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)),
423*436c6c9cSStella Laurenzo       rawIntegerSet);
424*436c6c9cSStella Laurenzo }
425*436c6c9cSStella Laurenzo 
426*436c6c9cSStella Laurenzo void mlir::python::populateIRAffine(py::module &m) {
427*436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
428*436c6c9cSStella Laurenzo   // Mapping of PyAffineExpr and derived classes.
429*436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
430*436c6c9cSStella Laurenzo   py::class_<PyAffineExpr>(m, "AffineExpr")
431*436c6c9cSStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
432*436c6c9cSStella Laurenzo                              &PyAffineExpr::getCapsule)
433*436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule)
434*436c6c9cSStella Laurenzo       .def("__add__",
435*436c6c9cSStella Laurenzo            [](PyAffineExpr &self, PyAffineExpr &other) {
436*436c6c9cSStella Laurenzo              return PyAffineAddExpr::get(self, other);
437*436c6c9cSStella Laurenzo            })
438*436c6c9cSStella Laurenzo       .def("__mul__",
439*436c6c9cSStella Laurenzo            [](PyAffineExpr &self, PyAffineExpr &other) {
440*436c6c9cSStella Laurenzo              return PyAffineMulExpr::get(self, other);
441*436c6c9cSStella Laurenzo            })
442*436c6c9cSStella Laurenzo       .def("__mod__",
443*436c6c9cSStella Laurenzo            [](PyAffineExpr &self, PyAffineExpr &other) {
444*436c6c9cSStella Laurenzo              return PyAffineModExpr::get(self, other);
445*436c6c9cSStella Laurenzo            })
446*436c6c9cSStella Laurenzo       .def("__sub__",
447*436c6c9cSStella Laurenzo            [](PyAffineExpr &self, PyAffineExpr &other) {
448*436c6c9cSStella Laurenzo              auto negOne =
449*436c6c9cSStella Laurenzo                  PyAffineConstantExpr::get(-1, *self.getContext().get());
450*436c6c9cSStella Laurenzo              return PyAffineAddExpr::get(self,
451*436c6c9cSStella Laurenzo                                          PyAffineMulExpr::get(negOne, other));
452*436c6c9cSStella Laurenzo            })
453*436c6c9cSStella Laurenzo       .def("__eq__", [](PyAffineExpr &self,
454*436c6c9cSStella Laurenzo                         PyAffineExpr &other) { return self == other; })
455*436c6c9cSStella Laurenzo       .def("__eq__",
456*436c6c9cSStella Laurenzo            [](PyAffineExpr &self, py::object &other) { return false; })
457*436c6c9cSStella Laurenzo       .def("__str__",
458*436c6c9cSStella Laurenzo            [](PyAffineExpr &self) {
459*436c6c9cSStella Laurenzo              PyPrintAccumulator printAccum;
460*436c6c9cSStella Laurenzo              mlirAffineExprPrint(self, printAccum.getCallback(),
461*436c6c9cSStella Laurenzo                                  printAccum.getUserData());
462*436c6c9cSStella Laurenzo              return printAccum.join();
463*436c6c9cSStella Laurenzo            })
464*436c6c9cSStella Laurenzo       .def("__repr__",
465*436c6c9cSStella Laurenzo            [](PyAffineExpr &self) {
466*436c6c9cSStella Laurenzo              PyPrintAccumulator printAccum;
467*436c6c9cSStella Laurenzo              printAccum.parts.append("AffineExpr(");
468*436c6c9cSStella Laurenzo              mlirAffineExprPrint(self, printAccum.getCallback(),
469*436c6c9cSStella Laurenzo                                  printAccum.getUserData());
470*436c6c9cSStella Laurenzo              printAccum.parts.append(")");
471*436c6c9cSStella Laurenzo              return printAccum.join();
472*436c6c9cSStella Laurenzo            })
473*436c6c9cSStella Laurenzo       .def_property_readonly(
474*436c6c9cSStella Laurenzo           "context",
475*436c6c9cSStella Laurenzo           [](PyAffineExpr &self) { return self.getContext().getObject(); })
476*436c6c9cSStella Laurenzo       .def_static(
477*436c6c9cSStella Laurenzo           "get_add", &PyAffineAddExpr::get,
478*436c6c9cSStella Laurenzo           "Gets an affine expression containing a sum of two expressions.")
479*436c6c9cSStella Laurenzo       .def_static(
480*436c6c9cSStella Laurenzo           "get_mul", &PyAffineMulExpr::get,
481*436c6c9cSStella Laurenzo           "Gets an affine expression containing a product of two expressions.")
482*436c6c9cSStella Laurenzo       .def_static("get_mod", &PyAffineModExpr::get,
483*436c6c9cSStella Laurenzo                   "Gets an affine expression containing the modulo of dividing "
484*436c6c9cSStella Laurenzo                   "one expression by another.")
485*436c6c9cSStella Laurenzo       .def_static("get_floor_div", &PyAffineFloorDivExpr::get,
486*436c6c9cSStella Laurenzo                   "Gets an affine expression containing the rounded-down "
487*436c6c9cSStella Laurenzo                   "result of dividing one expression by another.")
488*436c6c9cSStella Laurenzo       .def_static("get_ceil_div", &PyAffineCeilDivExpr::get,
489*436c6c9cSStella Laurenzo                   "Gets an affine expression containing the rounded-up result "
490*436c6c9cSStella Laurenzo                   "of dividing one expression by another.")
491*436c6c9cSStella Laurenzo       .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"),
492*436c6c9cSStella Laurenzo                   py::arg("context") = py::none(),
493*436c6c9cSStella Laurenzo                   "Gets a constant affine expression with the given value.")
494*436c6c9cSStella Laurenzo       .def_static(
495*436c6c9cSStella Laurenzo           "get_dim", &PyAffineDimExpr::get, py::arg("position"),
496*436c6c9cSStella Laurenzo           py::arg("context") = py::none(),
497*436c6c9cSStella Laurenzo           "Gets an affine expression of a dimension at the given position.")
498*436c6c9cSStella Laurenzo       .def_static(
499*436c6c9cSStella Laurenzo           "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"),
500*436c6c9cSStella Laurenzo           py::arg("context") = py::none(),
501*436c6c9cSStella Laurenzo           "Gets an affine expression of a symbol at the given position.")
502*436c6c9cSStella Laurenzo       .def(
503*436c6c9cSStella Laurenzo           "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); },
504*436c6c9cSStella Laurenzo           kDumpDocstring);
505*436c6c9cSStella Laurenzo   PyAffineConstantExpr::bind(m);
506*436c6c9cSStella Laurenzo   PyAffineDimExpr::bind(m);
507*436c6c9cSStella Laurenzo   PyAffineSymbolExpr::bind(m);
508*436c6c9cSStella Laurenzo   PyAffineBinaryExpr::bind(m);
509*436c6c9cSStella Laurenzo   PyAffineAddExpr::bind(m);
510*436c6c9cSStella Laurenzo   PyAffineMulExpr::bind(m);
511*436c6c9cSStella Laurenzo   PyAffineModExpr::bind(m);
512*436c6c9cSStella Laurenzo   PyAffineFloorDivExpr::bind(m);
513*436c6c9cSStella Laurenzo   PyAffineCeilDivExpr::bind(m);
514*436c6c9cSStella Laurenzo 
515*436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
516*436c6c9cSStella Laurenzo   // Mapping of PyAffineMap.
517*436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
518*436c6c9cSStella Laurenzo   py::class_<PyAffineMap>(m, "AffineMap")
519*436c6c9cSStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
520*436c6c9cSStella Laurenzo                              &PyAffineMap::getCapsule)
521*436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule)
522*436c6c9cSStella Laurenzo       .def("__eq__",
523*436c6c9cSStella Laurenzo            [](PyAffineMap &self, PyAffineMap &other) { return self == other; })
524*436c6c9cSStella Laurenzo       .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; })
525*436c6c9cSStella Laurenzo       .def("__str__",
526*436c6c9cSStella Laurenzo            [](PyAffineMap &self) {
527*436c6c9cSStella Laurenzo              PyPrintAccumulator printAccum;
528*436c6c9cSStella Laurenzo              mlirAffineMapPrint(self, printAccum.getCallback(),
529*436c6c9cSStella Laurenzo                                 printAccum.getUserData());
530*436c6c9cSStella Laurenzo              return printAccum.join();
531*436c6c9cSStella Laurenzo            })
532*436c6c9cSStella Laurenzo       .def("__repr__",
533*436c6c9cSStella Laurenzo            [](PyAffineMap &self) {
534*436c6c9cSStella Laurenzo              PyPrintAccumulator printAccum;
535*436c6c9cSStella Laurenzo              printAccum.parts.append("AffineMap(");
536*436c6c9cSStella Laurenzo              mlirAffineMapPrint(self, printAccum.getCallback(),
537*436c6c9cSStella Laurenzo                                 printAccum.getUserData());
538*436c6c9cSStella Laurenzo              printAccum.parts.append(")");
539*436c6c9cSStella Laurenzo              return printAccum.join();
540*436c6c9cSStella Laurenzo            })
541*436c6c9cSStella Laurenzo       .def_property_readonly(
542*436c6c9cSStella Laurenzo           "context",
543*436c6c9cSStella Laurenzo           [](PyAffineMap &self) { return self.getContext().getObject(); },
544*436c6c9cSStella Laurenzo           "Context that owns the Affine Map")
545*436c6c9cSStella Laurenzo       .def(
546*436c6c9cSStella Laurenzo           "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); },
547*436c6c9cSStella Laurenzo           kDumpDocstring)
548*436c6c9cSStella Laurenzo       .def_static(
549*436c6c9cSStella Laurenzo           "get",
550*436c6c9cSStella Laurenzo           [](intptr_t dimCount, intptr_t symbolCount, py::list exprs,
551*436c6c9cSStella Laurenzo              DefaultingPyMlirContext context) {
552*436c6c9cSStella Laurenzo             SmallVector<MlirAffineExpr> affineExprs;
553*436c6c9cSStella Laurenzo             pyListToVector<PyAffineExpr, MlirAffineExpr>(
554*436c6c9cSStella Laurenzo                 exprs, affineExprs, "attempting to create an AffineMap");
555*436c6c9cSStella Laurenzo             MlirAffineMap map =
556*436c6c9cSStella Laurenzo                 mlirAffineMapGet(context->get(), dimCount, symbolCount,
557*436c6c9cSStella Laurenzo                                  affineExprs.size(), affineExprs.data());
558*436c6c9cSStella Laurenzo             return PyAffineMap(context->getRef(), map);
559*436c6c9cSStella Laurenzo           },
560*436c6c9cSStella Laurenzo           py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"),
561*436c6c9cSStella Laurenzo           py::arg("context") = py::none(),
562*436c6c9cSStella Laurenzo           "Gets a map with the given expressions as results.")
563*436c6c9cSStella Laurenzo       .def_static(
564*436c6c9cSStella Laurenzo           "get_constant",
565*436c6c9cSStella Laurenzo           [](intptr_t value, DefaultingPyMlirContext context) {
566*436c6c9cSStella Laurenzo             MlirAffineMap affineMap =
567*436c6c9cSStella Laurenzo                 mlirAffineMapConstantGet(context->get(), value);
568*436c6c9cSStella Laurenzo             return PyAffineMap(context->getRef(), affineMap);
569*436c6c9cSStella Laurenzo           },
570*436c6c9cSStella Laurenzo           py::arg("value"), py::arg("context") = py::none(),
571*436c6c9cSStella Laurenzo           "Gets an affine map with a single constant result")
572*436c6c9cSStella Laurenzo       .def_static(
573*436c6c9cSStella Laurenzo           "get_empty",
574*436c6c9cSStella Laurenzo           [](DefaultingPyMlirContext context) {
575*436c6c9cSStella Laurenzo             MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get());
576*436c6c9cSStella Laurenzo             return PyAffineMap(context->getRef(), affineMap);
577*436c6c9cSStella Laurenzo           },
578*436c6c9cSStella Laurenzo           py::arg("context") = py::none(), "Gets an empty affine map.")
579*436c6c9cSStella Laurenzo       .def_static(
580*436c6c9cSStella Laurenzo           "get_identity",
581*436c6c9cSStella Laurenzo           [](intptr_t nDims, DefaultingPyMlirContext context) {
582*436c6c9cSStella Laurenzo             MlirAffineMap affineMap =
583*436c6c9cSStella Laurenzo                 mlirAffineMapMultiDimIdentityGet(context->get(), nDims);
584*436c6c9cSStella Laurenzo             return PyAffineMap(context->getRef(), affineMap);
585*436c6c9cSStella Laurenzo           },
586*436c6c9cSStella Laurenzo           py::arg("n_dims"), py::arg("context") = py::none(),
587*436c6c9cSStella Laurenzo           "Gets an identity map with the given number of dimensions.")
588*436c6c9cSStella Laurenzo       .def_static(
589*436c6c9cSStella Laurenzo           "get_minor_identity",
590*436c6c9cSStella Laurenzo           [](intptr_t nDims, intptr_t nResults,
591*436c6c9cSStella Laurenzo              DefaultingPyMlirContext context) {
592*436c6c9cSStella Laurenzo             MlirAffineMap affineMap =
593*436c6c9cSStella Laurenzo                 mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults);
594*436c6c9cSStella Laurenzo             return PyAffineMap(context->getRef(), affineMap);
595*436c6c9cSStella Laurenzo           },
596*436c6c9cSStella Laurenzo           py::arg("n_dims"), py::arg("n_results"),
597*436c6c9cSStella Laurenzo           py::arg("context") = py::none(),
598*436c6c9cSStella Laurenzo           "Gets a minor identity map with the given number of dimensions and "
599*436c6c9cSStella Laurenzo           "results.")
600*436c6c9cSStella Laurenzo       .def_static(
601*436c6c9cSStella Laurenzo           "get_permutation",
602*436c6c9cSStella Laurenzo           [](std::vector<unsigned> permutation,
603*436c6c9cSStella Laurenzo              DefaultingPyMlirContext context) {
604*436c6c9cSStella Laurenzo             if (!isPermutation(permutation))
605*436c6c9cSStella Laurenzo               throw py::cast_error("Invalid permutation when attempting to "
606*436c6c9cSStella Laurenzo                                    "create an AffineMap");
607*436c6c9cSStella Laurenzo             MlirAffineMap affineMap = mlirAffineMapPermutationGet(
608*436c6c9cSStella Laurenzo                 context->get(), permutation.size(), permutation.data());
609*436c6c9cSStella Laurenzo             return PyAffineMap(context->getRef(), affineMap);
610*436c6c9cSStella Laurenzo           },
611*436c6c9cSStella Laurenzo           py::arg("permutation"), py::arg("context") = py::none(),
612*436c6c9cSStella Laurenzo           "Gets an affine map that permutes its inputs.")
613*436c6c9cSStella Laurenzo       .def("get_submap",
614*436c6c9cSStella Laurenzo            [](PyAffineMap &self, std::vector<intptr_t> &resultPos) {
615*436c6c9cSStella Laurenzo              intptr_t numResults = mlirAffineMapGetNumResults(self);
616*436c6c9cSStella Laurenzo              for (intptr_t pos : resultPos) {
617*436c6c9cSStella Laurenzo                if (pos < 0 || pos >= numResults)
618*436c6c9cSStella Laurenzo                  throw py::value_error("result position out of bounds");
619*436c6c9cSStella Laurenzo              }
620*436c6c9cSStella Laurenzo              MlirAffineMap affineMap = mlirAffineMapGetSubMap(
621*436c6c9cSStella Laurenzo                  self, resultPos.size(), resultPos.data());
622*436c6c9cSStella Laurenzo              return PyAffineMap(self.getContext(), affineMap);
623*436c6c9cSStella Laurenzo            })
624*436c6c9cSStella Laurenzo       .def("get_major_submap",
625*436c6c9cSStella Laurenzo            [](PyAffineMap &self, intptr_t nResults) {
626*436c6c9cSStella Laurenzo              if (nResults >= mlirAffineMapGetNumResults(self))
627*436c6c9cSStella Laurenzo                throw py::value_error("number of results out of bounds");
628*436c6c9cSStella Laurenzo              MlirAffineMap affineMap =
629*436c6c9cSStella Laurenzo                  mlirAffineMapGetMajorSubMap(self, nResults);
630*436c6c9cSStella Laurenzo              return PyAffineMap(self.getContext(), affineMap);
631*436c6c9cSStella Laurenzo            })
632*436c6c9cSStella Laurenzo       .def("get_minor_submap",
633*436c6c9cSStella Laurenzo            [](PyAffineMap &self, intptr_t nResults) {
634*436c6c9cSStella Laurenzo              if (nResults >= mlirAffineMapGetNumResults(self))
635*436c6c9cSStella Laurenzo                throw py::value_error("number of results out of bounds");
636*436c6c9cSStella Laurenzo              MlirAffineMap affineMap =
637*436c6c9cSStella Laurenzo                  mlirAffineMapGetMinorSubMap(self, nResults);
638*436c6c9cSStella Laurenzo              return PyAffineMap(self.getContext(), affineMap);
639*436c6c9cSStella Laurenzo            })
640*436c6c9cSStella Laurenzo       .def_property_readonly(
641*436c6c9cSStella Laurenzo           "is_permutation",
642*436c6c9cSStella Laurenzo           [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); })
643*436c6c9cSStella Laurenzo       .def_property_readonly("is_projected_permutation",
644*436c6c9cSStella Laurenzo                              [](PyAffineMap &self) {
645*436c6c9cSStella Laurenzo                                return mlirAffineMapIsProjectedPermutation(self);
646*436c6c9cSStella Laurenzo                              })
647*436c6c9cSStella Laurenzo       .def_property_readonly(
648*436c6c9cSStella Laurenzo           "n_dims",
649*436c6c9cSStella Laurenzo           [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); })
650*436c6c9cSStella Laurenzo       .def_property_readonly(
651*436c6c9cSStella Laurenzo           "n_inputs",
652*436c6c9cSStella Laurenzo           [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); })
653*436c6c9cSStella Laurenzo       .def_property_readonly(
654*436c6c9cSStella Laurenzo           "n_symbols",
655*436c6c9cSStella Laurenzo           [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); })
656*436c6c9cSStella Laurenzo       .def_property_readonly("results", [](PyAffineMap &self) {
657*436c6c9cSStella Laurenzo         return PyAffineMapExprList(self);
658*436c6c9cSStella Laurenzo       });
659*436c6c9cSStella Laurenzo   PyAffineMapExprList::bind(m);
660*436c6c9cSStella Laurenzo 
661*436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
662*436c6c9cSStella Laurenzo   // Mapping of PyIntegerSet.
663*436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
664*436c6c9cSStella Laurenzo   py::class_<PyIntegerSet>(m, "IntegerSet")
665*436c6c9cSStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
666*436c6c9cSStella Laurenzo                              &PyIntegerSet::getCapsule)
667*436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule)
668*436c6c9cSStella Laurenzo       .def("__eq__", [](PyIntegerSet &self,
669*436c6c9cSStella Laurenzo                         PyIntegerSet &other) { return self == other; })
670*436c6c9cSStella Laurenzo       .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; })
671*436c6c9cSStella Laurenzo       .def("__str__",
672*436c6c9cSStella Laurenzo            [](PyIntegerSet &self) {
673*436c6c9cSStella Laurenzo              PyPrintAccumulator printAccum;
674*436c6c9cSStella Laurenzo              mlirIntegerSetPrint(self, printAccum.getCallback(),
675*436c6c9cSStella Laurenzo                                  printAccum.getUserData());
676*436c6c9cSStella Laurenzo              return printAccum.join();
677*436c6c9cSStella Laurenzo            })
678*436c6c9cSStella Laurenzo       .def("__repr__",
679*436c6c9cSStella Laurenzo            [](PyIntegerSet &self) {
680*436c6c9cSStella Laurenzo              PyPrintAccumulator printAccum;
681*436c6c9cSStella Laurenzo              printAccum.parts.append("IntegerSet(");
682*436c6c9cSStella Laurenzo              mlirIntegerSetPrint(self, printAccum.getCallback(),
683*436c6c9cSStella Laurenzo                                  printAccum.getUserData());
684*436c6c9cSStella Laurenzo              printAccum.parts.append(")");
685*436c6c9cSStella Laurenzo              return printAccum.join();
686*436c6c9cSStella Laurenzo            })
687*436c6c9cSStella Laurenzo       .def_property_readonly(
688*436c6c9cSStella Laurenzo           "context",
689*436c6c9cSStella Laurenzo           [](PyIntegerSet &self) { return self.getContext().getObject(); })
690*436c6c9cSStella Laurenzo       .def(
691*436c6c9cSStella Laurenzo           "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); },
692*436c6c9cSStella Laurenzo           kDumpDocstring)
693*436c6c9cSStella Laurenzo       .def_static(
694*436c6c9cSStella Laurenzo           "get",
695*436c6c9cSStella Laurenzo           [](intptr_t numDims, intptr_t numSymbols, py::list exprs,
696*436c6c9cSStella Laurenzo              std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
697*436c6c9cSStella Laurenzo             if (exprs.size() != eqFlags.size())
698*436c6c9cSStella Laurenzo               throw py::value_error(
699*436c6c9cSStella Laurenzo                   "Expected the number of constraints to match "
700*436c6c9cSStella Laurenzo                   "that of equality flags");
701*436c6c9cSStella Laurenzo             if (exprs.empty())
702*436c6c9cSStella Laurenzo               throw py::value_error("Expected non-empty list of constraints");
703*436c6c9cSStella Laurenzo 
704*436c6c9cSStella Laurenzo             // Copy over to a SmallVector because std::vector has a
705*436c6c9cSStella Laurenzo             // specialization for booleans that packs data and does not
706*436c6c9cSStella Laurenzo             // expose a `bool *`.
707*436c6c9cSStella Laurenzo             SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end());
708*436c6c9cSStella Laurenzo 
709*436c6c9cSStella Laurenzo             SmallVector<MlirAffineExpr> affineExprs;
710*436c6c9cSStella Laurenzo             pyListToVector<PyAffineExpr>(exprs, affineExprs,
711*436c6c9cSStella Laurenzo                                          "attempting to create an IntegerSet");
712*436c6c9cSStella Laurenzo             MlirIntegerSet set = mlirIntegerSetGet(
713*436c6c9cSStella Laurenzo                 context->get(), numDims, numSymbols, exprs.size(),
714*436c6c9cSStella Laurenzo                 affineExprs.data(), flags.data());
715*436c6c9cSStella Laurenzo             return PyIntegerSet(context->getRef(), set);
716*436c6c9cSStella Laurenzo           },
717*436c6c9cSStella Laurenzo           py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"),
718*436c6c9cSStella Laurenzo           py::arg("eq_flags"), py::arg("context") = py::none())
719*436c6c9cSStella Laurenzo       .def_static(
720*436c6c9cSStella Laurenzo           "get_empty",
721*436c6c9cSStella Laurenzo           [](intptr_t numDims, intptr_t numSymbols,
722*436c6c9cSStella Laurenzo              DefaultingPyMlirContext context) {
723*436c6c9cSStella Laurenzo             MlirIntegerSet set =
724*436c6c9cSStella Laurenzo                 mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols);
725*436c6c9cSStella Laurenzo             return PyIntegerSet(context->getRef(), set);
726*436c6c9cSStella Laurenzo           },
727*436c6c9cSStella Laurenzo           py::arg("num_dims"), py::arg("num_symbols"),
728*436c6c9cSStella Laurenzo           py::arg("context") = py::none())
729*436c6c9cSStella Laurenzo       .def("get_replaced",
730*436c6c9cSStella Laurenzo            [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs,
731*436c6c9cSStella Laurenzo               intptr_t numResultDims, intptr_t numResultSymbols) {
732*436c6c9cSStella Laurenzo              if (static_cast<intptr_t>(dimExprs.size()) !=
733*436c6c9cSStella Laurenzo                  mlirIntegerSetGetNumDims(self))
734*436c6c9cSStella Laurenzo                throw py::value_error(
735*436c6c9cSStella Laurenzo                    "Expected the number of dimension replacement expressions "
736*436c6c9cSStella Laurenzo                    "to match that of dimensions");
737*436c6c9cSStella Laurenzo              if (static_cast<intptr_t>(symbolExprs.size()) !=
738*436c6c9cSStella Laurenzo                  mlirIntegerSetGetNumSymbols(self))
739*436c6c9cSStella Laurenzo                throw py::value_error(
740*436c6c9cSStella Laurenzo                    "Expected the number of symbol replacement expressions "
741*436c6c9cSStella Laurenzo                    "to match that of symbols");
742*436c6c9cSStella Laurenzo 
743*436c6c9cSStella Laurenzo              SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs;
744*436c6c9cSStella Laurenzo              pyListToVector<PyAffineExpr>(
745*436c6c9cSStella Laurenzo                  dimExprs, dimAffineExprs,
746*436c6c9cSStella Laurenzo                  "attempting to create an IntegerSet by replacing dimensions");
747*436c6c9cSStella Laurenzo              pyListToVector<PyAffineExpr>(
748*436c6c9cSStella Laurenzo                  symbolExprs, symbolAffineExprs,
749*436c6c9cSStella Laurenzo                  "attempting to create an IntegerSet by replacing symbols");
750*436c6c9cSStella Laurenzo              MlirIntegerSet set = mlirIntegerSetReplaceGet(
751*436c6c9cSStella Laurenzo                  self, dimAffineExprs.data(), symbolAffineExprs.data(),
752*436c6c9cSStella Laurenzo                  numResultDims, numResultSymbols);
753*436c6c9cSStella Laurenzo              return PyIntegerSet(self.getContext(), set);
754*436c6c9cSStella Laurenzo            })
755*436c6c9cSStella Laurenzo       .def_property_readonly("is_canonical_empty",
756*436c6c9cSStella Laurenzo                              [](PyIntegerSet &self) {
757*436c6c9cSStella Laurenzo                                return mlirIntegerSetIsCanonicalEmpty(self);
758*436c6c9cSStella Laurenzo                              })
759*436c6c9cSStella Laurenzo       .def_property_readonly(
760*436c6c9cSStella Laurenzo           "n_dims",
761*436c6c9cSStella Laurenzo           [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); })
762*436c6c9cSStella Laurenzo       .def_property_readonly(
763*436c6c9cSStella Laurenzo           "n_symbols",
764*436c6c9cSStella Laurenzo           [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); })
765*436c6c9cSStella Laurenzo       .def_property_readonly(
766*436c6c9cSStella Laurenzo           "n_inputs",
767*436c6c9cSStella Laurenzo           [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); })
768*436c6c9cSStella Laurenzo       .def_property_readonly("n_equalities",
769*436c6c9cSStella Laurenzo                              [](PyIntegerSet &self) {
770*436c6c9cSStella Laurenzo                                return mlirIntegerSetGetNumEqualities(self);
771*436c6c9cSStella Laurenzo                              })
772*436c6c9cSStella Laurenzo       .def_property_readonly("n_inequalities",
773*436c6c9cSStella Laurenzo                              [](PyIntegerSet &self) {
774*436c6c9cSStella Laurenzo                                return mlirIntegerSetGetNumInequalities(self);
775*436c6c9cSStella Laurenzo                              })
776*436c6c9cSStella Laurenzo       .def_property_readonly("constraints", [](PyIntegerSet &self) {
777*436c6c9cSStella Laurenzo         return PyIntegerSetConstraintList(self);
778*436c6c9cSStella Laurenzo       });
779*436c6c9cSStella Laurenzo   PyIntegerSetConstraint::bind(m);
780*436c6c9cSStella Laurenzo   PyIntegerSetConstraintList::bind(m);
781*436c6c9cSStella Laurenzo }
782