1*436c6c9cSStella Laurenzo //===- IRAffine.cpp - Exports 'ir' module affine related bindings ---------===// 2*436c6c9cSStella Laurenzo // 3*436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information. 5*436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*436c6c9cSStella Laurenzo // 7*436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===// 8*436c6c9cSStella Laurenzo 9*436c6c9cSStella Laurenzo #include "IRModule.h" 10*436c6c9cSStella Laurenzo 11*436c6c9cSStella Laurenzo #include "PybindUtils.h" 12*436c6c9cSStella Laurenzo 13*436c6c9cSStella Laurenzo #include "mlir-c/AffineMap.h" 14*436c6c9cSStella Laurenzo #include "mlir-c/Bindings/Python/Interop.h" 15*436c6c9cSStella Laurenzo #include "mlir-c/IntegerSet.h" 16*436c6c9cSStella Laurenzo 17*436c6c9cSStella Laurenzo namespace py = pybind11; 18*436c6c9cSStella Laurenzo using namespace mlir; 19*436c6c9cSStella Laurenzo using namespace mlir::python; 20*436c6c9cSStella Laurenzo 21*436c6c9cSStella Laurenzo using llvm::SmallVector; 22*436c6c9cSStella Laurenzo using llvm::StringRef; 23*436c6c9cSStella Laurenzo using llvm::Twine; 24*436c6c9cSStella Laurenzo 25*436c6c9cSStella Laurenzo static const char kDumpDocstring[] = 26*436c6c9cSStella Laurenzo R"(Dumps a debug representation of the object to stderr.)"; 27*436c6c9cSStella Laurenzo 28*436c6c9cSStella Laurenzo /// Attempts to populate `result` with the content of `list` casted to the 29*436c6c9cSStella Laurenzo /// appropriate type (Python and C types are provided as template arguments). 30*436c6c9cSStella Laurenzo /// Throws errors in case of failure, using "action" to describe what the caller 31*436c6c9cSStella Laurenzo /// was attempting to do. 32*436c6c9cSStella Laurenzo template <typename PyType, typename CType> 33*436c6c9cSStella Laurenzo static void pyListToVector(py::list list, llvm::SmallVectorImpl<CType> &result, 34*436c6c9cSStella Laurenzo StringRef action) { 35*436c6c9cSStella Laurenzo result.reserve(py::len(list)); 36*436c6c9cSStella Laurenzo for (py::handle item : list) { 37*436c6c9cSStella Laurenzo try { 38*436c6c9cSStella Laurenzo result.push_back(item.cast<PyType>()); 39*436c6c9cSStella Laurenzo } catch (py::cast_error &err) { 40*436c6c9cSStella Laurenzo std::string msg = (llvm::Twine("Invalid expression when ") + action + 41*436c6c9cSStella Laurenzo " (" + err.what() + ")") 42*436c6c9cSStella Laurenzo .str(); 43*436c6c9cSStella Laurenzo throw py::cast_error(msg); 44*436c6c9cSStella Laurenzo } catch (py::reference_cast_error &err) { 45*436c6c9cSStella Laurenzo std::string msg = (llvm::Twine("Invalid expression (None?) when ") + 46*436c6c9cSStella Laurenzo action + " (" + err.what() + ")") 47*436c6c9cSStella Laurenzo .str(); 48*436c6c9cSStella Laurenzo throw py::cast_error(msg); 49*436c6c9cSStella Laurenzo } 50*436c6c9cSStella Laurenzo } 51*436c6c9cSStella Laurenzo } 52*436c6c9cSStella Laurenzo 53*436c6c9cSStella Laurenzo template <typename PermutationTy> 54*436c6c9cSStella Laurenzo static bool isPermutation(std::vector<PermutationTy> permutation) { 55*436c6c9cSStella Laurenzo llvm::SmallVector<bool, 8> seen(permutation.size(), false); 56*436c6c9cSStella Laurenzo for (auto val : permutation) { 57*436c6c9cSStella Laurenzo if (val < permutation.size()) { 58*436c6c9cSStella Laurenzo if (seen[val]) 59*436c6c9cSStella Laurenzo return false; 60*436c6c9cSStella Laurenzo seen[val] = true; 61*436c6c9cSStella Laurenzo continue; 62*436c6c9cSStella Laurenzo } 63*436c6c9cSStella Laurenzo return false; 64*436c6c9cSStella Laurenzo } 65*436c6c9cSStella Laurenzo return true; 66*436c6c9cSStella Laurenzo } 67*436c6c9cSStella Laurenzo 68*436c6c9cSStella Laurenzo namespace { 69*436c6c9cSStella Laurenzo 70*436c6c9cSStella Laurenzo /// CRTP base class for Python MLIR affine expressions that subclass AffineExpr 71*436c6c9cSStella Laurenzo /// and should be castable from it. Intermediate hierarchy classes can be 72*436c6c9cSStella Laurenzo /// modeled by specifying BaseTy. 73*436c6c9cSStella Laurenzo template <typename DerivedTy, typename BaseTy = PyAffineExpr> 74*436c6c9cSStella Laurenzo class PyConcreteAffineExpr : public BaseTy { 75*436c6c9cSStella Laurenzo public: 76*436c6c9cSStella Laurenzo // Derived classes must define statics for: 77*436c6c9cSStella Laurenzo // IsAFunctionTy isaFunction 78*436c6c9cSStella Laurenzo // const char *pyClassName 79*436c6c9cSStella Laurenzo // and redefine bindDerived. 80*436c6c9cSStella Laurenzo using ClassTy = py::class_<DerivedTy, BaseTy>; 81*436c6c9cSStella Laurenzo using IsAFunctionTy = bool (*)(MlirAffineExpr); 82*436c6c9cSStella Laurenzo 83*436c6c9cSStella Laurenzo PyConcreteAffineExpr() = default; 84*436c6c9cSStella Laurenzo PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr) 85*436c6c9cSStella Laurenzo : BaseTy(std::move(contextRef), affineExpr) {} 86*436c6c9cSStella Laurenzo PyConcreteAffineExpr(PyAffineExpr &orig) 87*436c6c9cSStella Laurenzo : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {} 88*436c6c9cSStella Laurenzo 89*436c6c9cSStella Laurenzo static MlirAffineExpr castFrom(PyAffineExpr &orig) { 90*436c6c9cSStella Laurenzo if (!DerivedTy::isaFunction(orig)) { 91*436c6c9cSStella Laurenzo auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 92*436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, 93*436c6c9cSStella Laurenzo Twine("Cannot cast affine expression to ") + 94*436c6c9cSStella Laurenzo DerivedTy::pyClassName + " (from " + origRepr + ")"); 95*436c6c9cSStella Laurenzo } 96*436c6c9cSStella Laurenzo return orig; 97*436c6c9cSStella Laurenzo } 98*436c6c9cSStella Laurenzo 99*436c6c9cSStella Laurenzo static void bind(py::module &m) { 100*436c6c9cSStella Laurenzo auto cls = ClassTy(m, DerivedTy::pyClassName); 101*436c6c9cSStella Laurenzo cls.def(py::init<PyAffineExpr &>()); 102*436c6c9cSStella Laurenzo DerivedTy::bindDerived(cls); 103*436c6c9cSStella Laurenzo } 104*436c6c9cSStella Laurenzo 105*436c6c9cSStella Laurenzo /// Implemented by derived classes to add methods to the Python subclass. 106*436c6c9cSStella Laurenzo static void bindDerived(ClassTy &m) {} 107*436c6c9cSStella Laurenzo }; 108*436c6c9cSStella Laurenzo 109*436c6c9cSStella Laurenzo class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> { 110*436c6c9cSStella Laurenzo public: 111*436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant; 112*436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "AffineConstantExpr"; 113*436c6c9cSStella Laurenzo using PyConcreteAffineExpr::PyConcreteAffineExpr; 114*436c6c9cSStella Laurenzo 115*436c6c9cSStella Laurenzo static PyAffineConstantExpr get(intptr_t value, 116*436c6c9cSStella Laurenzo DefaultingPyMlirContext context) { 117*436c6c9cSStella Laurenzo MlirAffineExpr affineExpr = 118*436c6c9cSStella Laurenzo mlirAffineConstantExprGet(context->get(), static_cast<int64_t>(value)); 119*436c6c9cSStella Laurenzo return PyAffineConstantExpr(context->getRef(), affineExpr); 120*436c6c9cSStella Laurenzo } 121*436c6c9cSStella Laurenzo 122*436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 123*436c6c9cSStella Laurenzo c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"), 124*436c6c9cSStella Laurenzo py::arg("context") = py::none()); 125*436c6c9cSStella Laurenzo c.def_property_readonly("value", [](PyAffineConstantExpr &self) { 126*436c6c9cSStella Laurenzo return mlirAffineConstantExprGetValue(self); 127*436c6c9cSStella Laurenzo }); 128*436c6c9cSStella Laurenzo } 129*436c6c9cSStella Laurenzo }; 130*436c6c9cSStella Laurenzo 131*436c6c9cSStella Laurenzo class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> { 132*436c6c9cSStella Laurenzo public: 133*436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim; 134*436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "AffineDimExpr"; 135*436c6c9cSStella Laurenzo using PyConcreteAffineExpr::PyConcreteAffineExpr; 136*436c6c9cSStella Laurenzo 137*436c6c9cSStella Laurenzo static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) { 138*436c6c9cSStella Laurenzo MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos); 139*436c6c9cSStella Laurenzo return PyAffineDimExpr(context->getRef(), affineExpr); 140*436c6c9cSStella Laurenzo } 141*436c6c9cSStella Laurenzo 142*436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 143*436c6c9cSStella Laurenzo c.def_static("get", &PyAffineDimExpr::get, py::arg("position"), 144*436c6c9cSStella Laurenzo py::arg("context") = py::none()); 145*436c6c9cSStella Laurenzo c.def_property_readonly("position", [](PyAffineDimExpr &self) { 146*436c6c9cSStella Laurenzo return mlirAffineDimExprGetPosition(self); 147*436c6c9cSStella Laurenzo }); 148*436c6c9cSStella Laurenzo } 149*436c6c9cSStella Laurenzo }; 150*436c6c9cSStella Laurenzo 151*436c6c9cSStella Laurenzo class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> { 152*436c6c9cSStella Laurenzo public: 153*436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol; 154*436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "AffineSymbolExpr"; 155*436c6c9cSStella Laurenzo using PyConcreteAffineExpr::PyConcreteAffineExpr; 156*436c6c9cSStella Laurenzo 157*436c6c9cSStella Laurenzo static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) { 158*436c6c9cSStella Laurenzo MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos); 159*436c6c9cSStella Laurenzo return PyAffineSymbolExpr(context->getRef(), affineExpr); 160*436c6c9cSStella Laurenzo } 161*436c6c9cSStella Laurenzo 162*436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 163*436c6c9cSStella Laurenzo c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"), 164*436c6c9cSStella Laurenzo py::arg("context") = py::none()); 165*436c6c9cSStella Laurenzo c.def_property_readonly("position", [](PyAffineSymbolExpr &self) { 166*436c6c9cSStella Laurenzo return mlirAffineSymbolExprGetPosition(self); 167*436c6c9cSStella Laurenzo }); 168*436c6c9cSStella Laurenzo } 169*436c6c9cSStella Laurenzo }; 170*436c6c9cSStella Laurenzo 171*436c6c9cSStella Laurenzo class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> { 172*436c6c9cSStella Laurenzo public: 173*436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary; 174*436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "AffineBinaryExpr"; 175*436c6c9cSStella Laurenzo using PyConcreteAffineExpr::PyConcreteAffineExpr; 176*436c6c9cSStella Laurenzo 177*436c6c9cSStella Laurenzo PyAffineExpr lhs() { 178*436c6c9cSStella Laurenzo MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get()); 179*436c6c9cSStella Laurenzo return PyAffineExpr(getContext(), lhsExpr); 180*436c6c9cSStella Laurenzo } 181*436c6c9cSStella Laurenzo 182*436c6c9cSStella Laurenzo PyAffineExpr rhs() { 183*436c6c9cSStella Laurenzo MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get()); 184*436c6c9cSStella Laurenzo return PyAffineExpr(getContext(), rhsExpr); 185*436c6c9cSStella Laurenzo } 186*436c6c9cSStella Laurenzo 187*436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 188*436c6c9cSStella Laurenzo c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs); 189*436c6c9cSStella Laurenzo c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs); 190*436c6c9cSStella Laurenzo } 191*436c6c9cSStella Laurenzo }; 192*436c6c9cSStella Laurenzo 193*436c6c9cSStella Laurenzo class PyAffineAddExpr 194*436c6c9cSStella Laurenzo : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> { 195*436c6c9cSStella Laurenzo public: 196*436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd; 197*436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "AffineAddExpr"; 198*436c6c9cSStella Laurenzo using PyConcreteAffineExpr::PyConcreteAffineExpr; 199*436c6c9cSStella Laurenzo 200*436c6c9cSStella Laurenzo static PyAffineAddExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { 201*436c6c9cSStella Laurenzo MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs); 202*436c6c9cSStella Laurenzo return PyAffineAddExpr(lhs.getContext(), expr); 203*436c6c9cSStella Laurenzo } 204*436c6c9cSStella Laurenzo 205*436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 206*436c6c9cSStella Laurenzo c.def_static("get", &PyAffineAddExpr::get); 207*436c6c9cSStella Laurenzo } 208*436c6c9cSStella Laurenzo }; 209*436c6c9cSStella Laurenzo 210*436c6c9cSStella Laurenzo class PyAffineMulExpr 211*436c6c9cSStella Laurenzo : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> { 212*436c6c9cSStella Laurenzo public: 213*436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul; 214*436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "AffineMulExpr"; 215*436c6c9cSStella Laurenzo using PyConcreteAffineExpr::PyConcreteAffineExpr; 216*436c6c9cSStella Laurenzo 217*436c6c9cSStella Laurenzo static PyAffineMulExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { 218*436c6c9cSStella Laurenzo MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs); 219*436c6c9cSStella Laurenzo return PyAffineMulExpr(lhs.getContext(), expr); 220*436c6c9cSStella Laurenzo } 221*436c6c9cSStella Laurenzo 222*436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 223*436c6c9cSStella Laurenzo c.def_static("get", &PyAffineMulExpr::get); 224*436c6c9cSStella Laurenzo } 225*436c6c9cSStella Laurenzo }; 226*436c6c9cSStella Laurenzo 227*436c6c9cSStella Laurenzo class PyAffineModExpr 228*436c6c9cSStella Laurenzo : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> { 229*436c6c9cSStella Laurenzo public: 230*436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod; 231*436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "AffineModExpr"; 232*436c6c9cSStella Laurenzo using PyConcreteAffineExpr::PyConcreteAffineExpr; 233*436c6c9cSStella Laurenzo 234*436c6c9cSStella Laurenzo static PyAffineModExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { 235*436c6c9cSStella Laurenzo MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs); 236*436c6c9cSStella Laurenzo return PyAffineModExpr(lhs.getContext(), expr); 237*436c6c9cSStella Laurenzo } 238*436c6c9cSStella Laurenzo 239*436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 240*436c6c9cSStella Laurenzo c.def_static("get", &PyAffineModExpr::get); 241*436c6c9cSStella Laurenzo } 242*436c6c9cSStella Laurenzo }; 243*436c6c9cSStella Laurenzo 244*436c6c9cSStella Laurenzo class PyAffineFloorDivExpr 245*436c6c9cSStella Laurenzo : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> { 246*436c6c9cSStella Laurenzo public: 247*436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv; 248*436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "AffineFloorDivExpr"; 249*436c6c9cSStella Laurenzo using PyConcreteAffineExpr::PyConcreteAffineExpr; 250*436c6c9cSStella Laurenzo 251*436c6c9cSStella Laurenzo static PyAffineFloorDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { 252*436c6c9cSStella Laurenzo MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs); 253*436c6c9cSStella Laurenzo return PyAffineFloorDivExpr(lhs.getContext(), expr); 254*436c6c9cSStella Laurenzo } 255*436c6c9cSStella Laurenzo 256*436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 257*436c6c9cSStella Laurenzo c.def_static("get", &PyAffineFloorDivExpr::get); 258*436c6c9cSStella Laurenzo } 259*436c6c9cSStella Laurenzo }; 260*436c6c9cSStella Laurenzo 261*436c6c9cSStella Laurenzo class PyAffineCeilDivExpr 262*436c6c9cSStella Laurenzo : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> { 263*436c6c9cSStella Laurenzo public: 264*436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv; 265*436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "AffineCeilDivExpr"; 266*436c6c9cSStella Laurenzo using PyConcreteAffineExpr::PyConcreteAffineExpr; 267*436c6c9cSStella Laurenzo 268*436c6c9cSStella Laurenzo static PyAffineCeilDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { 269*436c6c9cSStella Laurenzo MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs); 270*436c6c9cSStella Laurenzo return PyAffineCeilDivExpr(lhs.getContext(), expr); 271*436c6c9cSStella Laurenzo } 272*436c6c9cSStella Laurenzo 273*436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 274*436c6c9cSStella Laurenzo c.def_static("get", &PyAffineCeilDivExpr::get); 275*436c6c9cSStella Laurenzo } 276*436c6c9cSStella Laurenzo }; 277*436c6c9cSStella Laurenzo 278*436c6c9cSStella Laurenzo } // namespace 279*436c6c9cSStella Laurenzo 280*436c6c9cSStella Laurenzo bool PyAffineExpr::operator==(const PyAffineExpr &other) { 281*436c6c9cSStella Laurenzo return mlirAffineExprEqual(affineExpr, other.affineExpr); 282*436c6c9cSStella Laurenzo } 283*436c6c9cSStella Laurenzo 284*436c6c9cSStella Laurenzo py::object PyAffineExpr::getCapsule() { 285*436c6c9cSStella Laurenzo return py::reinterpret_steal<py::object>( 286*436c6c9cSStella Laurenzo mlirPythonAffineExprToCapsule(*this)); 287*436c6c9cSStella Laurenzo } 288*436c6c9cSStella Laurenzo 289*436c6c9cSStella Laurenzo PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) { 290*436c6c9cSStella Laurenzo MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr()); 291*436c6c9cSStella Laurenzo if (mlirAffineExprIsNull(rawAffineExpr)) 292*436c6c9cSStella Laurenzo throw py::error_already_set(); 293*436c6c9cSStella Laurenzo return PyAffineExpr( 294*436c6c9cSStella Laurenzo PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)), 295*436c6c9cSStella Laurenzo rawAffineExpr); 296*436c6c9cSStella Laurenzo } 297*436c6c9cSStella Laurenzo 298*436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 299*436c6c9cSStella Laurenzo // PyAffineMap and utilities. 300*436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 301*436c6c9cSStella Laurenzo namespace { 302*436c6c9cSStella Laurenzo 303*436c6c9cSStella Laurenzo /// A list of expressions contained in an affine map. Internally these are 304*436c6c9cSStella Laurenzo /// stored as a consecutive array leading to inexpensive random access. Both 305*436c6c9cSStella Laurenzo /// the map and the expression are owned by the context so we need not bother 306*436c6c9cSStella Laurenzo /// with lifetime extension. 307*436c6c9cSStella Laurenzo class PyAffineMapExprList 308*436c6c9cSStella Laurenzo : public Sliceable<PyAffineMapExprList, PyAffineExpr> { 309*436c6c9cSStella Laurenzo public: 310*436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "AffineExprList"; 311*436c6c9cSStella Laurenzo 312*436c6c9cSStella Laurenzo PyAffineMapExprList(PyAffineMap map, intptr_t startIndex = 0, 313*436c6c9cSStella Laurenzo intptr_t length = -1, intptr_t step = 1) 314*436c6c9cSStella Laurenzo : Sliceable(startIndex, 315*436c6c9cSStella Laurenzo length == -1 ? mlirAffineMapGetNumResults(map) : length, 316*436c6c9cSStella Laurenzo step), 317*436c6c9cSStella Laurenzo affineMap(map) {} 318*436c6c9cSStella Laurenzo 319*436c6c9cSStella Laurenzo intptr_t getNumElements() { return mlirAffineMapGetNumResults(affineMap); } 320*436c6c9cSStella Laurenzo 321*436c6c9cSStella Laurenzo PyAffineExpr getElement(intptr_t pos) { 322*436c6c9cSStella Laurenzo return PyAffineExpr(affineMap.getContext(), 323*436c6c9cSStella Laurenzo mlirAffineMapGetResult(affineMap, pos)); 324*436c6c9cSStella Laurenzo } 325*436c6c9cSStella Laurenzo 326*436c6c9cSStella Laurenzo PyAffineMapExprList slice(intptr_t startIndex, intptr_t length, 327*436c6c9cSStella Laurenzo intptr_t step) { 328*436c6c9cSStella Laurenzo return PyAffineMapExprList(affineMap, startIndex, length, step); 329*436c6c9cSStella Laurenzo } 330*436c6c9cSStella Laurenzo 331*436c6c9cSStella Laurenzo private: 332*436c6c9cSStella Laurenzo PyAffineMap affineMap; 333*436c6c9cSStella Laurenzo }; 334*436c6c9cSStella Laurenzo } // end namespace 335*436c6c9cSStella Laurenzo 336*436c6c9cSStella Laurenzo bool PyAffineMap::operator==(const PyAffineMap &other) { 337*436c6c9cSStella Laurenzo return mlirAffineMapEqual(affineMap, other.affineMap); 338*436c6c9cSStella Laurenzo } 339*436c6c9cSStella Laurenzo 340*436c6c9cSStella Laurenzo py::object PyAffineMap::getCapsule() { 341*436c6c9cSStella Laurenzo return py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(*this)); 342*436c6c9cSStella Laurenzo } 343*436c6c9cSStella Laurenzo 344*436c6c9cSStella Laurenzo PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) { 345*436c6c9cSStella Laurenzo MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr()); 346*436c6c9cSStella Laurenzo if (mlirAffineMapIsNull(rawAffineMap)) 347*436c6c9cSStella Laurenzo throw py::error_already_set(); 348*436c6c9cSStella Laurenzo return PyAffineMap( 349*436c6c9cSStella Laurenzo PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)), 350*436c6c9cSStella Laurenzo rawAffineMap); 351*436c6c9cSStella Laurenzo } 352*436c6c9cSStella Laurenzo 353*436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 354*436c6c9cSStella Laurenzo // PyIntegerSet and utilities. 355*436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 356*436c6c9cSStella Laurenzo namespace { 357*436c6c9cSStella Laurenzo 358*436c6c9cSStella Laurenzo class PyIntegerSetConstraint { 359*436c6c9cSStella Laurenzo public: 360*436c6c9cSStella Laurenzo PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) : set(set), pos(pos) {} 361*436c6c9cSStella Laurenzo 362*436c6c9cSStella Laurenzo PyAffineExpr getExpr() { 363*436c6c9cSStella Laurenzo return PyAffineExpr(set.getContext(), 364*436c6c9cSStella Laurenzo mlirIntegerSetGetConstraint(set, pos)); 365*436c6c9cSStella Laurenzo } 366*436c6c9cSStella Laurenzo 367*436c6c9cSStella Laurenzo bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); } 368*436c6c9cSStella Laurenzo 369*436c6c9cSStella Laurenzo static void bind(py::module &m) { 370*436c6c9cSStella Laurenzo py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint") 371*436c6c9cSStella Laurenzo .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr) 372*436c6c9cSStella Laurenzo .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq); 373*436c6c9cSStella Laurenzo } 374*436c6c9cSStella Laurenzo 375*436c6c9cSStella Laurenzo private: 376*436c6c9cSStella Laurenzo PyIntegerSet set; 377*436c6c9cSStella Laurenzo intptr_t pos; 378*436c6c9cSStella Laurenzo }; 379*436c6c9cSStella Laurenzo 380*436c6c9cSStella Laurenzo class PyIntegerSetConstraintList 381*436c6c9cSStella Laurenzo : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> { 382*436c6c9cSStella Laurenzo public: 383*436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "IntegerSetConstraintList"; 384*436c6c9cSStella Laurenzo 385*436c6c9cSStella Laurenzo PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0, 386*436c6c9cSStella Laurenzo intptr_t length = -1, intptr_t step = 1) 387*436c6c9cSStella Laurenzo : Sliceable(startIndex, 388*436c6c9cSStella Laurenzo length == -1 ? mlirIntegerSetGetNumConstraints(set) : length, 389*436c6c9cSStella Laurenzo step), 390*436c6c9cSStella Laurenzo set(set) {} 391*436c6c9cSStella Laurenzo 392*436c6c9cSStella Laurenzo intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); } 393*436c6c9cSStella Laurenzo 394*436c6c9cSStella Laurenzo PyIntegerSetConstraint getElement(intptr_t pos) { 395*436c6c9cSStella Laurenzo return PyIntegerSetConstraint(set, pos); 396*436c6c9cSStella Laurenzo } 397*436c6c9cSStella Laurenzo 398*436c6c9cSStella Laurenzo PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length, 399*436c6c9cSStella Laurenzo intptr_t step) { 400*436c6c9cSStella Laurenzo return PyIntegerSetConstraintList(set, startIndex, length, step); 401*436c6c9cSStella Laurenzo } 402*436c6c9cSStella Laurenzo 403*436c6c9cSStella Laurenzo private: 404*436c6c9cSStella Laurenzo PyIntegerSet set; 405*436c6c9cSStella Laurenzo }; 406*436c6c9cSStella Laurenzo } // namespace 407*436c6c9cSStella Laurenzo 408*436c6c9cSStella Laurenzo bool PyIntegerSet::operator==(const PyIntegerSet &other) { 409*436c6c9cSStella Laurenzo return mlirIntegerSetEqual(integerSet, other.integerSet); 410*436c6c9cSStella Laurenzo } 411*436c6c9cSStella Laurenzo 412*436c6c9cSStella Laurenzo py::object PyIntegerSet::getCapsule() { 413*436c6c9cSStella Laurenzo return py::reinterpret_steal<py::object>( 414*436c6c9cSStella Laurenzo mlirPythonIntegerSetToCapsule(*this)); 415*436c6c9cSStella Laurenzo } 416*436c6c9cSStella Laurenzo 417*436c6c9cSStella Laurenzo PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) { 418*436c6c9cSStella Laurenzo MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr()); 419*436c6c9cSStella Laurenzo if (mlirIntegerSetIsNull(rawIntegerSet)) 420*436c6c9cSStella Laurenzo throw py::error_already_set(); 421*436c6c9cSStella Laurenzo return PyIntegerSet( 422*436c6c9cSStella Laurenzo PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)), 423*436c6c9cSStella Laurenzo rawIntegerSet); 424*436c6c9cSStella Laurenzo } 425*436c6c9cSStella Laurenzo 426*436c6c9cSStella Laurenzo void mlir::python::populateIRAffine(py::module &m) { 427*436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 428*436c6c9cSStella Laurenzo // Mapping of PyAffineExpr and derived classes. 429*436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 430*436c6c9cSStella Laurenzo py::class_<PyAffineExpr>(m, "AffineExpr") 431*436c6c9cSStella Laurenzo .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 432*436c6c9cSStella Laurenzo &PyAffineExpr::getCapsule) 433*436c6c9cSStella Laurenzo .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule) 434*436c6c9cSStella Laurenzo .def("__add__", 435*436c6c9cSStella Laurenzo [](PyAffineExpr &self, PyAffineExpr &other) { 436*436c6c9cSStella Laurenzo return PyAffineAddExpr::get(self, other); 437*436c6c9cSStella Laurenzo }) 438*436c6c9cSStella Laurenzo .def("__mul__", 439*436c6c9cSStella Laurenzo [](PyAffineExpr &self, PyAffineExpr &other) { 440*436c6c9cSStella Laurenzo return PyAffineMulExpr::get(self, other); 441*436c6c9cSStella Laurenzo }) 442*436c6c9cSStella Laurenzo .def("__mod__", 443*436c6c9cSStella Laurenzo [](PyAffineExpr &self, PyAffineExpr &other) { 444*436c6c9cSStella Laurenzo return PyAffineModExpr::get(self, other); 445*436c6c9cSStella Laurenzo }) 446*436c6c9cSStella Laurenzo .def("__sub__", 447*436c6c9cSStella Laurenzo [](PyAffineExpr &self, PyAffineExpr &other) { 448*436c6c9cSStella Laurenzo auto negOne = 449*436c6c9cSStella Laurenzo PyAffineConstantExpr::get(-1, *self.getContext().get()); 450*436c6c9cSStella Laurenzo return PyAffineAddExpr::get(self, 451*436c6c9cSStella Laurenzo PyAffineMulExpr::get(negOne, other)); 452*436c6c9cSStella Laurenzo }) 453*436c6c9cSStella Laurenzo .def("__eq__", [](PyAffineExpr &self, 454*436c6c9cSStella Laurenzo PyAffineExpr &other) { return self == other; }) 455*436c6c9cSStella Laurenzo .def("__eq__", 456*436c6c9cSStella Laurenzo [](PyAffineExpr &self, py::object &other) { return false; }) 457*436c6c9cSStella Laurenzo .def("__str__", 458*436c6c9cSStella Laurenzo [](PyAffineExpr &self) { 459*436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 460*436c6c9cSStella Laurenzo mlirAffineExprPrint(self, printAccum.getCallback(), 461*436c6c9cSStella Laurenzo printAccum.getUserData()); 462*436c6c9cSStella Laurenzo return printAccum.join(); 463*436c6c9cSStella Laurenzo }) 464*436c6c9cSStella Laurenzo .def("__repr__", 465*436c6c9cSStella Laurenzo [](PyAffineExpr &self) { 466*436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 467*436c6c9cSStella Laurenzo printAccum.parts.append("AffineExpr("); 468*436c6c9cSStella Laurenzo mlirAffineExprPrint(self, printAccum.getCallback(), 469*436c6c9cSStella Laurenzo printAccum.getUserData()); 470*436c6c9cSStella Laurenzo printAccum.parts.append(")"); 471*436c6c9cSStella Laurenzo return printAccum.join(); 472*436c6c9cSStella Laurenzo }) 473*436c6c9cSStella Laurenzo .def_property_readonly( 474*436c6c9cSStella Laurenzo "context", 475*436c6c9cSStella Laurenzo [](PyAffineExpr &self) { return self.getContext().getObject(); }) 476*436c6c9cSStella Laurenzo .def_static( 477*436c6c9cSStella Laurenzo "get_add", &PyAffineAddExpr::get, 478*436c6c9cSStella Laurenzo "Gets an affine expression containing a sum of two expressions.") 479*436c6c9cSStella Laurenzo .def_static( 480*436c6c9cSStella Laurenzo "get_mul", &PyAffineMulExpr::get, 481*436c6c9cSStella Laurenzo "Gets an affine expression containing a product of two expressions.") 482*436c6c9cSStella Laurenzo .def_static("get_mod", &PyAffineModExpr::get, 483*436c6c9cSStella Laurenzo "Gets an affine expression containing the modulo of dividing " 484*436c6c9cSStella Laurenzo "one expression by another.") 485*436c6c9cSStella Laurenzo .def_static("get_floor_div", &PyAffineFloorDivExpr::get, 486*436c6c9cSStella Laurenzo "Gets an affine expression containing the rounded-down " 487*436c6c9cSStella Laurenzo "result of dividing one expression by another.") 488*436c6c9cSStella Laurenzo .def_static("get_ceil_div", &PyAffineCeilDivExpr::get, 489*436c6c9cSStella Laurenzo "Gets an affine expression containing the rounded-up result " 490*436c6c9cSStella Laurenzo "of dividing one expression by another.") 491*436c6c9cSStella Laurenzo .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"), 492*436c6c9cSStella Laurenzo py::arg("context") = py::none(), 493*436c6c9cSStella Laurenzo "Gets a constant affine expression with the given value.") 494*436c6c9cSStella Laurenzo .def_static( 495*436c6c9cSStella Laurenzo "get_dim", &PyAffineDimExpr::get, py::arg("position"), 496*436c6c9cSStella Laurenzo py::arg("context") = py::none(), 497*436c6c9cSStella Laurenzo "Gets an affine expression of a dimension at the given position.") 498*436c6c9cSStella Laurenzo .def_static( 499*436c6c9cSStella Laurenzo "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"), 500*436c6c9cSStella Laurenzo py::arg("context") = py::none(), 501*436c6c9cSStella Laurenzo "Gets an affine expression of a symbol at the given position.") 502*436c6c9cSStella Laurenzo .def( 503*436c6c9cSStella Laurenzo "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); }, 504*436c6c9cSStella Laurenzo kDumpDocstring); 505*436c6c9cSStella Laurenzo PyAffineConstantExpr::bind(m); 506*436c6c9cSStella Laurenzo PyAffineDimExpr::bind(m); 507*436c6c9cSStella Laurenzo PyAffineSymbolExpr::bind(m); 508*436c6c9cSStella Laurenzo PyAffineBinaryExpr::bind(m); 509*436c6c9cSStella Laurenzo PyAffineAddExpr::bind(m); 510*436c6c9cSStella Laurenzo PyAffineMulExpr::bind(m); 511*436c6c9cSStella Laurenzo PyAffineModExpr::bind(m); 512*436c6c9cSStella Laurenzo PyAffineFloorDivExpr::bind(m); 513*436c6c9cSStella Laurenzo PyAffineCeilDivExpr::bind(m); 514*436c6c9cSStella Laurenzo 515*436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 516*436c6c9cSStella Laurenzo // Mapping of PyAffineMap. 517*436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 518*436c6c9cSStella Laurenzo py::class_<PyAffineMap>(m, "AffineMap") 519*436c6c9cSStella Laurenzo .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 520*436c6c9cSStella Laurenzo &PyAffineMap::getCapsule) 521*436c6c9cSStella Laurenzo .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule) 522*436c6c9cSStella Laurenzo .def("__eq__", 523*436c6c9cSStella Laurenzo [](PyAffineMap &self, PyAffineMap &other) { return self == other; }) 524*436c6c9cSStella Laurenzo .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; }) 525*436c6c9cSStella Laurenzo .def("__str__", 526*436c6c9cSStella Laurenzo [](PyAffineMap &self) { 527*436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 528*436c6c9cSStella Laurenzo mlirAffineMapPrint(self, printAccum.getCallback(), 529*436c6c9cSStella Laurenzo printAccum.getUserData()); 530*436c6c9cSStella Laurenzo return printAccum.join(); 531*436c6c9cSStella Laurenzo }) 532*436c6c9cSStella Laurenzo .def("__repr__", 533*436c6c9cSStella Laurenzo [](PyAffineMap &self) { 534*436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 535*436c6c9cSStella Laurenzo printAccum.parts.append("AffineMap("); 536*436c6c9cSStella Laurenzo mlirAffineMapPrint(self, printAccum.getCallback(), 537*436c6c9cSStella Laurenzo printAccum.getUserData()); 538*436c6c9cSStella Laurenzo printAccum.parts.append(")"); 539*436c6c9cSStella Laurenzo return printAccum.join(); 540*436c6c9cSStella Laurenzo }) 541*436c6c9cSStella Laurenzo .def_property_readonly( 542*436c6c9cSStella Laurenzo "context", 543*436c6c9cSStella Laurenzo [](PyAffineMap &self) { return self.getContext().getObject(); }, 544*436c6c9cSStella Laurenzo "Context that owns the Affine Map") 545*436c6c9cSStella Laurenzo .def( 546*436c6c9cSStella Laurenzo "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); }, 547*436c6c9cSStella Laurenzo kDumpDocstring) 548*436c6c9cSStella Laurenzo .def_static( 549*436c6c9cSStella Laurenzo "get", 550*436c6c9cSStella Laurenzo [](intptr_t dimCount, intptr_t symbolCount, py::list exprs, 551*436c6c9cSStella Laurenzo DefaultingPyMlirContext context) { 552*436c6c9cSStella Laurenzo SmallVector<MlirAffineExpr> affineExprs; 553*436c6c9cSStella Laurenzo pyListToVector<PyAffineExpr, MlirAffineExpr>( 554*436c6c9cSStella Laurenzo exprs, affineExprs, "attempting to create an AffineMap"); 555*436c6c9cSStella Laurenzo MlirAffineMap map = 556*436c6c9cSStella Laurenzo mlirAffineMapGet(context->get(), dimCount, symbolCount, 557*436c6c9cSStella Laurenzo affineExprs.size(), affineExprs.data()); 558*436c6c9cSStella Laurenzo return PyAffineMap(context->getRef(), map); 559*436c6c9cSStella Laurenzo }, 560*436c6c9cSStella Laurenzo py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"), 561*436c6c9cSStella Laurenzo py::arg("context") = py::none(), 562*436c6c9cSStella Laurenzo "Gets a map with the given expressions as results.") 563*436c6c9cSStella Laurenzo .def_static( 564*436c6c9cSStella Laurenzo "get_constant", 565*436c6c9cSStella Laurenzo [](intptr_t value, DefaultingPyMlirContext context) { 566*436c6c9cSStella Laurenzo MlirAffineMap affineMap = 567*436c6c9cSStella Laurenzo mlirAffineMapConstantGet(context->get(), value); 568*436c6c9cSStella Laurenzo return PyAffineMap(context->getRef(), affineMap); 569*436c6c9cSStella Laurenzo }, 570*436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 571*436c6c9cSStella Laurenzo "Gets an affine map with a single constant result") 572*436c6c9cSStella Laurenzo .def_static( 573*436c6c9cSStella Laurenzo "get_empty", 574*436c6c9cSStella Laurenzo [](DefaultingPyMlirContext context) { 575*436c6c9cSStella Laurenzo MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get()); 576*436c6c9cSStella Laurenzo return PyAffineMap(context->getRef(), affineMap); 577*436c6c9cSStella Laurenzo }, 578*436c6c9cSStella Laurenzo py::arg("context") = py::none(), "Gets an empty affine map.") 579*436c6c9cSStella Laurenzo .def_static( 580*436c6c9cSStella Laurenzo "get_identity", 581*436c6c9cSStella Laurenzo [](intptr_t nDims, DefaultingPyMlirContext context) { 582*436c6c9cSStella Laurenzo MlirAffineMap affineMap = 583*436c6c9cSStella Laurenzo mlirAffineMapMultiDimIdentityGet(context->get(), nDims); 584*436c6c9cSStella Laurenzo return PyAffineMap(context->getRef(), affineMap); 585*436c6c9cSStella Laurenzo }, 586*436c6c9cSStella Laurenzo py::arg("n_dims"), py::arg("context") = py::none(), 587*436c6c9cSStella Laurenzo "Gets an identity map with the given number of dimensions.") 588*436c6c9cSStella Laurenzo .def_static( 589*436c6c9cSStella Laurenzo "get_minor_identity", 590*436c6c9cSStella Laurenzo [](intptr_t nDims, intptr_t nResults, 591*436c6c9cSStella Laurenzo DefaultingPyMlirContext context) { 592*436c6c9cSStella Laurenzo MlirAffineMap affineMap = 593*436c6c9cSStella Laurenzo mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults); 594*436c6c9cSStella Laurenzo return PyAffineMap(context->getRef(), affineMap); 595*436c6c9cSStella Laurenzo }, 596*436c6c9cSStella Laurenzo py::arg("n_dims"), py::arg("n_results"), 597*436c6c9cSStella Laurenzo py::arg("context") = py::none(), 598*436c6c9cSStella Laurenzo "Gets a minor identity map with the given number of dimensions and " 599*436c6c9cSStella Laurenzo "results.") 600*436c6c9cSStella Laurenzo .def_static( 601*436c6c9cSStella Laurenzo "get_permutation", 602*436c6c9cSStella Laurenzo [](std::vector<unsigned> permutation, 603*436c6c9cSStella Laurenzo DefaultingPyMlirContext context) { 604*436c6c9cSStella Laurenzo if (!isPermutation(permutation)) 605*436c6c9cSStella Laurenzo throw py::cast_error("Invalid permutation when attempting to " 606*436c6c9cSStella Laurenzo "create an AffineMap"); 607*436c6c9cSStella Laurenzo MlirAffineMap affineMap = mlirAffineMapPermutationGet( 608*436c6c9cSStella Laurenzo context->get(), permutation.size(), permutation.data()); 609*436c6c9cSStella Laurenzo return PyAffineMap(context->getRef(), affineMap); 610*436c6c9cSStella Laurenzo }, 611*436c6c9cSStella Laurenzo py::arg("permutation"), py::arg("context") = py::none(), 612*436c6c9cSStella Laurenzo "Gets an affine map that permutes its inputs.") 613*436c6c9cSStella Laurenzo .def("get_submap", 614*436c6c9cSStella Laurenzo [](PyAffineMap &self, std::vector<intptr_t> &resultPos) { 615*436c6c9cSStella Laurenzo intptr_t numResults = mlirAffineMapGetNumResults(self); 616*436c6c9cSStella Laurenzo for (intptr_t pos : resultPos) { 617*436c6c9cSStella Laurenzo if (pos < 0 || pos >= numResults) 618*436c6c9cSStella Laurenzo throw py::value_error("result position out of bounds"); 619*436c6c9cSStella Laurenzo } 620*436c6c9cSStella Laurenzo MlirAffineMap affineMap = mlirAffineMapGetSubMap( 621*436c6c9cSStella Laurenzo self, resultPos.size(), resultPos.data()); 622*436c6c9cSStella Laurenzo return PyAffineMap(self.getContext(), affineMap); 623*436c6c9cSStella Laurenzo }) 624*436c6c9cSStella Laurenzo .def("get_major_submap", 625*436c6c9cSStella Laurenzo [](PyAffineMap &self, intptr_t nResults) { 626*436c6c9cSStella Laurenzo if (nResults >= mlirAffineMapGetNumResults(self)) 627*436c6c9cSStella Laurenzo throw py::value_error("number of results out of bounds"); 628*436c6c9cSStella Laurenzo MlirAffineMap affineMap = 629*436c6c9cSStella Laurenzo mlirAffineMapGetMajorSubMap(self, nResults); 630*436c6c9cSStella Laurenzo return PyAffineMap(self.getContext(), affineMap); 631*436c6c9cSStella Laurenzo }) 632*436c6c9cSStella Laurenzo .def("get_minor_submap", 633*436c6c9cSStella Laurenzo [](PyAffineMap &self, intptr_t nResults) { 634*436c6c9cSStella Laurenzo if (nResults >= mlirAffineMapGetNumResults(self)) 635*436c6c9cSStella Laurenzo throw py::value_error("number of results out of bounds"); 636*436c6c9cSStella Laurenzo MlirAffineMap affineMap = 637*436c6c9cSStella Laurenzo mlirAffineMapGetMinorSubMap(self, nResults); 638*436c6c9cSStella Laurenzo return PyAffineMap(self.getContext(), affineMap); 639*436c6c9cSStella Laurenzo }) 640*436c6c9cSStella Laurenzo .def_property_readonly( 641*436c6c9cSStella Laurenzo "is_permutation", 642*436c6c9cSStella Laurenzo [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); }) 643*436c6c9cSStella Laurenzo .def_property_readonly("is_projected_permutation", 644*436c6c9cSStella Laurenzo [](PyAffineMap &self) { 645*436c6c9cSStella Laurenzo return mlirAffineMapIsProjectedPermutation(self); 646*436c6c9cSStella Laurenzo }) 647*436c6c9cSStella Laurenzo .def_property_readonly( 648*436c6c9cSStella Laurenzo "n_dims", 649*436c6c9cSStella Laurenzo [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); }) 650*436c6c9cSStella Laurenzo .def_property_readonly( 651*436c6c9cSStella Laurenzo "n_inputs", 652*436c6c9cSStella Laurenzo [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); }) 653*436c6c9cSStella Laurenzo .def_property_readonly( 654*436c6c9cSStella Laurenzo "n_symbols", 655*436c6c9cSStella Laurenzo [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); }) 656*436c6c9cSStella Laurenzo .def_property_readonly("results", [](PyAffineMap &self) { 657*436c6c9cSStella Laurenzo return PyAffineMapExprList(self); 658*436c6c9cSStella Laurenzo }); 659*436c6c9cSStella Laurenzo PyAffineMapExprList::bind(m); 660*436c6c9cSStella Laurenzo 661*436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 662*436c6c9cSStella Laurenzo // Mapping of PyIntegerSet. 663*436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 664*436c6c9cSStella Laurenzo py::class_<PyIntegerSet>(m, "IntegerSet") 665*436c6c9cSStella Laurenzo .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 666*436c6c9cSStella Laurenzo &PyIntegerSet::getCapsule) 667*436c6c9cSStella Laurenzo .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule) 668*436c6c9cSStella Laurenzo .def("__eq__", [](PyIntegerSet &self, 669*436c6c9cSStella Laurenzo PyIntegerSet &other) { return self == other; }) 670*436c6c9cSStella Laurenzo .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; }) 671*436c6c9cSStella Laurenzo .def("__str__", 672*436c6c9cSStella Laurenzo [](PyIntegerSet &self) { 673*436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 674*436c6c9cSStella Laurenzo mlirIntegerSetPrint(self, printAccum.getCallback(), 675*436c6c9cSStella Laurenzo printAccum.getUserData()); 676*436c6c9cSStella Laurenzo return printAccum.join(); 677*436c6c9cSStella Laurenzo }) 678*436c6c9cSStella Laurenzo .def("__repr__", 679*436c6c9cSStella Laurenzo [](PyIntegerSet &self) { 680*436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 681*436c6c9cSStella Laurenzo printAccum.parts.append("IntegerSet("); 682*436c6c9cSStella Laurenzo mlirIntegerSetPrint(self, printAccum.getCallback(), 683*436c6c9cSStella Laurenzo printAccum.getUserData()); 684*436c6c9cSStella Laurenzo printAccum.parts.append(")"); 685*436c6c9cSStella Laurenzo return printAccum.join(); 686*436c6c9cSStella Laurenzo }) 687*436c6c9cSStella Laurenzo .def_property_readonly( 688*436c6c9cSStella Laurenzo "context", 689*436c6c9cSStella Laurenzo [](PyIntegerSet &self) { return self.getContext().getObject(); }) 690*436c6c9cSStella Laurenzo .def( 691*436c6c9cSStella Laurenzo "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); }, 692*436c6c9cSStella Laurenzo kDumpDocstring) 693*436c6c9cSStella Laurenzo .def_static( 694*436c6c9cSStella Laurenzo "get", 695*436c6c9cSStella Laurenzo [](intptr_t numDims, intptr_t numSymbols, py::list exprs, 696*436c6c9cSStella Laurenzo std::vector<bool> eqFlags, DefaultingPyMlirContext context) { 697*436c6c9cSStella Laurenzo if (exprs.size() != eqFlags.size()) 698*436c6c9cSStella Laurenzo throw py::value_error( 699*436c6c9cSStella Laurenzo "Expected the number of constraints to match " 700*436c6c9cSStella Laurenzo "that of equality flags"); 701*436c6c9cSStella Laurenzo if (exprs.empty()) 702*436c6c9cSStella Laurenzo throw py::value_error("Expected non-empty list of constraints"); 703*436c6c9cSStella Laurenzo 704*436c6c9cSStella Laurenzo // Copy over to a SmallVector because std::vector has a 705*436c6c9cSStella Laurenzo // specialization for booleans that packs data and does not 706*436c6c9cSStella Laurenzo // expose a `bool *`. 707*436c6c9cSStella Laurenzo SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end()); 708*436c6c9cSStella Laurenzo 709*436c6c9cSStella Laurenzo SmallVector<MlirAffineExpr> affineExprs; 710*436c6c9cSStella Laurenzo pyListToVector<PyAffineExpr>(exprs, affineExprs, 711*436c6c9cSStella Laurenzo "attempting to create an IntegerSet"); 712*436c6c9cSStella Laurenzo MlirIntegerSet set = mlirIntegerSetGet( 713*436c6c9cSStella Laurenzo context->get(), numDims, numSymbols, exprs.size(), 714*436c6c9cSStella Laurenzo affineExprs.data(), flags.data()); 715*436c6c9cSStella Laurenzo return PyIntegerSet(context->getRef(), set); 716*436c6c9cSStella Laurenzo }, 717*436c6c9cSStella Laurenzo py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"), 718*436c6c9cSStella Laurenzo py::arg("eq_flags"), py::arg("context") = py::none()) 719*436c6c9cSStella Laurenzo .def_static( 720*436c6c9cSStella Laurenzo "get_empty", 721*436c6c9cSStella Laurenzo [](intptr_t numDims, intptr_t numSymbols, 722*436c6c9cSStella Laurenzo DefaultingPyMlirContext context) { 723*436c6c9cSStella Laurenzo MlirIntegerSet set = 724*436c6c9cSStella Laurenzo mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols); 725*436c6c9cSStella Laurenzo return PyIntegerSet(context->getRef(), set); 726*436c6c9cSStella Laurenzo }, 727*436c6c9cSStella Laurenzo py::arg("num_dims"), py::arg("num_symbols"), 728*436c6c9cSStella Laurenzo py::arg("context") = py::none()) 729*436c6c9cSStella Laurenzo .def("get_replaced", 730*436c6c9cSStella Laurenzo [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs, 731*436c6c9cSStella Laurenzo intptr_t numResultDims, intptr_t numResultSymbols) { 732*436c6c9cSStella Laurenzo if (static_cast<intptr_t>(dimExprs.size()) != 733*436c6c9cSStella Laurenzo mlirIntegerSetGetNumDims(self)) 734*436c6c9cSStella Laurenzo throw py::value_error( 735*436c6c9cSStella Laurenzo "Expected the number of dimension replacement expressions " 736*436c6c9cSStella Laurenzo "to match that of dimensions"); 737*436c6c9cSStella Laurenzo if (static_cast<intptr_t>(symbolExprs.size()) != 738*436c6c9cSStella Laurenzo mlirIntegerSetGetNumSymbols(self)) 739*436c6c9cSStella Laurenzo throw py::value_error( 740*436c6c9cSStella Laurenzo "Expected the number of symbol replacement expressions " 741*436c6c9cSStella Laurenzo "to match that of symbols"); 742*436c6c9cSStella Laurenzo 743*436c6c9cSStella Laurenzo SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs; 744*436c6c9cSStella Laurenzo pyListToVector<PyAffineExpr>( 745*436c6c9cSStella Laurenzo dimExprs, dimAffineExprs, 746*436c6c9cSStella Laurenzo "attempting to create an IntegerSet by replacing dimensions"); 747*436c6c9cSStella Laurenzo pyListToVector<PyAffineExpr>( 748*436c6c9cSStella Laurenzo symbolExprs, symbolAffineExprs, 749*436c6c9cSStella Laurenzo "attempting to create an IntegerSet by replacing symbols"); 750*436c6c9cSStella Laurenzo MlirIntegerSet set = mlirIntegerSetReplaceGet( 751*436c6c9cSStella Laurenzo self, dimAffineExprs.data(), symbolAffineExprs.data(), 752*436c6c9cSStella Laurenzo numResultDims, numResultSymbols); 753*436c6c9cSStella Laurenzo return PyIntegerSet(self.getContext(), set); 754*436c6c9cSStella Laurenzo }) 755*436c6c9cSStella Laurenzo .def_property_readonly("is_canonical_empty", 756*436c6c9cSStella Laurenzo [](PyIntegerSet &self) { 757*436c6c9cSStella Laurenzo return mlirIntegerSetIsCanonicalEmpty(self); 758*436c6c9cSStella Laurenzo }) 759*436c6c9cSStella Laurenzo .def_property_readonly( 760*436c6c9cSStella Laurenzo "n_dims", 761*436c6c9cSStella Laurenzo [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); }) 762*436c6c9cSStella Laurenzo .def_property_readonly( 763*436c6c9cSStella Laurenzo "n_symbols", 764*436c6c9cSStella Laurenzo [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); }) 765*436c6c9cSStella Laurenzo .def_property_readonly( 766*436c6c9cSStella Laurenzo "n_inputs", 767*436c6c9cSStella Laurenzo [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); }) 768*436c6c9cSStella Laurenzo .def_property_readonly("n_equalities", 769*436c6c9cSStella Laurenzo [](PyIntegerSet &self) { 770*436c6c9cSStella Laurenzo return mlirIntegerSetGetNumEqualities(self); 771*436c6c9cSStella Laurenzo }) 772*436c6c9cSStella Laurenzo .def_property_readonly("n_inequalities", 773*436c6c9cSStella Laurenzo [](PyIntegerSet &self) { 774*436c6c9cSStella Laurenzo return mlirIntegerSetGetNumInequalities(self); 775*436c6c9cSStella Laurenzo }) 776*436c6c9cSStella Laurenzo .def_property_readonly("constraints", [](PyIntegerSet &self) { 777*436c6c9cSStella Laurenzo return PyIntegerSetConstraintList(self); 778*436c6c9cSStella Laurenzo }); 779*436c6c9cSStella Laurenzo PyIntegerSetConstraint::bind(m); 780*436c6c9cSStella Laurenzo PyIntegerSetConstraintList::bind(m); 781*436c6c9cSStella Laurenzo } 782