1436c6c9cSStella Laurenzo //===- IRAffine.cpp - Exports 'ir' module affine related bindings ---------===// 2436c6c9cSStella Laurenzo // 3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information. 5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6436c6c9cSStella Laurenzo // 7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===// 8436c6c9cSStella Laurenzo 9436c6c9cSStella Laurenzo #include "IRModule.h" 10436c6c9cSStella Laurenzo 11436c6c9cSStella Laurenzo #include "PybindUtils.h" 12436c6c9cSStella Laurenzo 13436c6c9cSStella Laurenzo #include "mlir-c/AffineMap.h" 14436c6c9cSStella Laurenzo #include "mlir-c/Bindings/Python/Interop.h" 15436c6c9cSStella Laurenzo #include "mlir-c/IntegerSet.h" 16436c6c9cSStella Laurenzo 17436c6c9cSStella Laurenzo namespace py = pybind11; 18436c6c9cSStella Laurenzo using namespace mlir; 19436c6c9cSStella Laurenzo using namespace mlir::python; 20436c6c9cSStella Laurenzo 21436c6c9cSStella Laurenzo using llvm::SmallVector; 22436c6c9cSStella Laurenzo using llvm::StringRef; 23436c6c9cSStella Laurenzo using llvm::Twine; 24436c6c9cSStella Laurenzo 25436c6c9cSStella Laurenzo static const char kDumpDocstring[] = 26436c6c9cSStella Laurenzo R"(Dumps a debug representation of the object to stderr.)"; 27436c6c9cSStella Laurenzo 28436c6c9cSStella Laurenzo /// Attempts to populate `result` with the content of `list` casted to the 29436c6c9cSStella Laurenzo /// appropriate type (Python and C types are provided as template arguments). 30436c6c9cSStella Laurenzo /// Throws errors in case of failure, using "action" to describe what the caller 31436c6c9cSStella Laurenzo /// was attempting to do. 32436c6c9cSStella Laurenzo template <typename PyType, typename CType> 33436c6c9cSStella Laurenzo static void pyListToVector(py::list list, llvm::SmallVectorImpl<CType> &result, 34436c6c9cSStella Laurenzo StringRef action) { 35436c6c9cSStella Laurenzo result.reserve(py::len(list)); 36436c6c9cSStella Laurenzo for (py::handle item : list) { 37436c6c9cSStella Laurenzo try { 38436c6c9cSStella Laurenzo result.push_back(item.cast<PyType>()); 39436c6c9cSStella Laurenzo } catch (py::cast_error &err) { 40436c6c9cSStella Laurenzo std::string msg = (llvm::Twine("Invalid expression when ") + action + 41436c6c9cSStella Laurenzo " (" + err.what() + ")") 42436c6c9cSStella Laurenzo .str(); 43436c6c9cSStella Laurenzo throw py::cast_error(msg); 44436c6c9cSStella Laurenzo } catch (py::reference_cast_error &err) { 45436c6c9cSStella Laurenzo std::string msg = (llvm::Twine("Invalid expression (None?) when ") + 46436c6c9cSStella Laurenzo action + " (" + err.what() + ")") 47436c6c9cSStella Laurenzo .str(); 48436c6c9cSStella Laurenzo throw py::cast_error(msg); 49436c6c9cSStella Laurenzo } 50436c6c9cSStella Laurenzo } 51436c6c9cSStella Laurenzo } 52436c6c9cSStella Laurenzo 53436c6c9cSStella Laurenzo template <typename PermutationTy> 54436c6c9cSStella Laurenzo static bool isPermutation(std::vector<PermutationTy> permutation) { 55436c6c9cSStella Laurenzo llvm::SmallVector<bool, 8> seen(permutation.size(), false); 56436c6c9cSStella Laurenzo for (auto val : permutation) { 57436c6c9cSStella Laurenzo if (val < permutation.size()) { 58436c6c9cSStella Laurenzo if (seen[val]) 59436c6c9cSStella Laurenzo return false; 60436c6c9cSStella Laurenzo seen[val] = true; 61436c6c9cSStella Laurenzo continue; 62436c6c9cSStella Laurenzo } 63436c6c9cSStella Laurenzo return false; 64436c6c9cSStella Laurenzo } 65436c6c9cSStella Laurenzo return true; 66436c6c9cSStella Laurenzo } 67436c6c9cSStella Laurenzo 68436c6c9cSStella Laurenzo namespace { 69436c6c9cSStella Laurenzo 70436c6c9cSStella Laurenzo /// CRTP base class for Python MLIR affine expressions that subclass AffineExpr 71436c6c9cSStella Laurenzo /// and should be castable from it. Intermediate hierarchy classes can be 72436c6c9cSStella Laurenzo /// modeled by specifying BaseTy. 73436c6c9cSStella Laurenzo template <typename DerivedTy, typename BaseTy = PyAffineExpr> 74436c6c9cSStella Laurenzo class PyConcreteAffineExpr : public BaseTy { 75436c6c9cSStella Laurenzo public: 76436c6c9cSStella Laurenzo // Derived classes must define statics for: 77436c6c9cSStella Laurenzo // IsAFunctionTy isaFunction 78436c6c9cSStella Laurenzo // const char *pyClassName 79436c6c9cSStella Laurenzo // and redefine bindDerived. 80436c6c9cSStella Laurenzo using ClassTy = py::class_<DerivedTy, BaseTy>; 81436c6c9cSStella Laurenzo using IsAFunctionTy = bool (*)(MlirAffineExpr); 82436c6c9cSStella Laurenzo 83436c6c9cSStella Laurenzo PyConcreteAffineExpr() = default; 84436c6c9cSStella Laurenzo PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr) 85436c6c9cSStella Laurenzo : BaseTy(std::move(contextRef), affineExpr) {} 86436c6c9cSStella Laurenzo PyConcreteAffineExpr(PyAffineExpr &orig) 87436c6c9cSStella Laurenzo : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {} 88436c6c9cSStella Laurenzo 89436c6c9cSStella Laurenzo static MlirAffineExpr castFrom(PyAffineExpr &orig) { 90436c6c9cSStella Laurenzo if (!DerivedTy::isaFunction(orig)) { 91436c6c9cSStella Laurenzo auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 92436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, 93436c6c9cSStella Laurenzo Twine("Cannot cast affine expression to ") + 94436c6c9cSStella Laurenzo DerivedTy::pyClassName + " (from " + origRepr + ")"); 95436c6c9cSStella Laurenzo } 96436c6c9cSStella Laurenzo return orig; 97436c6c9cSStella Laurenzo } 98436c6c9cSStella Laurenzo 99436c6c9cSStella Laurenzo static void bind(py::module &m) { 100f05ff4f7SStella Laurenzo auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); 101436c6c9cSStella Laurenzo cls.def(py::init<PyAffineExpr &>()); 102*78f2dae0SAlex Zinenko cls.def_static("isinstance", [](PyAffineExpr &otherAffineExpr) -> bool { 103*78f2dae0SAlex Zinenko return DerivedTy::isaFunction(otherAffineExpr); 104*78f2dae0SAlex Zinenko }); 105436c6c9cSStella Laurenzo DerivedTy::bindDerived(cls); 106436c6c9cSStella Laurenzo } 107436c6c9cSStella Laurenzo 108436c6c9cSStella Laurenzo /// Implemented by derived classes to add methods to the Python subclass. 109436c6c9cSStella Laurenzo static void bindDerived(ClassTy &m) {} 110436c6c9cSStella Laurenzo }; 111436c6c9cSStella Laurenzo 112436c6c9cSStella Laurenzo class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> { 113436c6c9cSStella Laurenzo public: 114436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant; 115436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "AffineConstantExpr"; 116436c6c9cSStella Laurenzo using PyConcreteAffineExpr::PyConcreteAffineExpr; 117436c6c9cSStella Laurenzo 118436c6c9cSStella Laurenzo static PyAffineConstantExpr get(intptr_t value, 119436c6c9cSStella Laurenzo DefaultingPyMlirContext context) { 120436c6c9cSStella Laurenzo MlirAffineExpr affineExpr = 121436c6c9cSStella Laurenzo mlirAffineConstantExprGet(context->get(), static_cast<int64_t>(value)); 122436c6c9cSStella Laurenzo return PyAffineConstantExpr(context->getRef(), affineExpr); 123436c6c9cSStella Laurenzo } 124436c6c9cSStella Laurenzo 125436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 126436c6c9cSStella Laurenzo c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"), 127436c6c9cSStella Laurenzo py::arg("context") = py::none()); 128436c6c9cSStella Laurenzo c.def_property_readonly("value", [](PyAffineConstantExpr &self) { 129436c6c9cSStella Laurenzo return mlirAffineConstantExprGetValue(self); 130436c6c9cSStella Laurenzo }); 131436c6c9cSStella Laurenzo } 132436c6c9cSStella Laurenzo }; 133436c6c9cSStella Laurenzo 134436c6c9cSStella Laurenzo class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> { 135436c6c9cSStella Laurenzo public: 136436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim; 137436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "AffineDimExpr"; 138436c6c9cSStella Laurenzo using PyConcreteAffineExpr::PyConcreteAffineExpr; 139436c6c9cSStella Laurenzo 140436c6c9cSStella Laurenzo static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) { 141436c6c9cSStella Laurenzo MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos); 142436c6c9cSStella Laurenzo return PyAffineDimExpr(context->getRef(), affineExpr); 143436c6c9cSStella Laurenzo } 144436c6c9cSStella Laurenzo 145436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 146436c6c9cSStella Laurenzo c.def_static("get", &PyAffineDimExpr::get, py::arg("position"), 147436c6c9cSStella Laurenzo py::arg("context") = py::none()); 148436c6c9cSStella Laurenzo c.def_property_readonly("position", [](PyAffineDimExpr &self) { 149436c6c9cSStella Laurenzo return mlirAffineDimExprGetPosition(self); 150436c6c9cSStella Laurenzo }); 151436c6c9cSStella Laurenzo } 152436c6c9cSStella Laurenzo }; 153436c6c9cSStella Laurenzo 154436c6c9cSStella Laurenzo class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> { 155436c6c9cSStella Laurenzo public: 156436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol; 157436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "AffineSymbolExpr"; 158436c6c9cSStella Laurenzo using PyConcreteAffineExpr::PyConcreteAffineExpr; 159436c6c9cSStella Laurenzo 160436c6c9cSStella Laurenzo static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) { 161436c6c9cSStella Laurenzo MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos); 162436c6c9cSStella Laurenzo return PyAffineSymbolExpr(context->getRef(), affineExpr); 163436c6c9cSStella Laurenzo } 164436c6c9cSStella Laurenzo 165436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 166436c6c9cSStella Laurenzo c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"), 167436c6c9cSStella Laurenzo py::arg("context") = py::none()); 168436c6c9cSStella Laurenzo c.def_property_readonly("position", [](PyAffineSymbolExpr &self) { 169436c6c9cSStella Laurenzo return mlirAffineSymbolExprGetPosition(self); 170436c6c9cSStella Laurenzo }); 171436c6c9cSStella Laurenzo } 172436c6c9cSStella Laurenzo }; 173436c6c9cSStella Laurenzo 174436c6c9cSStella Laurenzo class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> { 175436c6c9cSStella Laurenzo public: 176436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary; 177436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "AffineBinaryExpr"; 178436c6c9cSStella Laurenzo using PyConcreteAffineExpr::PyConcreteAffineExpr; 179436c6c9cSStella Laurenzo 180436c6c9cSStella Laurenzo PyAffineExpr lhs() { 181436c6c9cSStella Laurenzo MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get()); 182436c6c9cSStella Laurenzo return PyAffineExpr(getContext(), lhsExpr); 183436c6c9cSStella Laurenzo } 184436c6c9cSStella Laurenzo 185436c6c9cSStella Laurenzo PyAffineExpr rhs() { 186436c6c9cSStella Laurenzo MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get()); 187436c6c9cSStella Laurenzo return PyAffineExpr(getContext(), rhsExpr); 188436c6c9cSStella Laurenzo } 189436c6c9cSStella Laurenzo 190436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 191436c6c9cSStella Laurenzo c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs); 192436c6c9cSStella Laurenzo c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs); 193436c6c9cSStella Laurenzo } 194436c6c9cSStella Laurenzo }; 195436c6c9cSStella Laurenzo 196436c6c9cSStella Laurenzo class PyAffineAddExpr 197436c6c9cSStella Laurenzo : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> { 198436c6c9cSStella Laurenzo public: 199436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd; 200436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "AffineAddExpr"; 201436c6c9cSStella Laurenzo using PyConcreteAffineExpr::PyConcreteAffineExpr; 202436c6c9cSStella Laurenzo 203436c6c9cSStella Laurenzo static PyAffineAddExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { 204436c6c9cSStella Laurenzo MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs); 205436c6c9cSStella Laurenzo return PyAffineAddExpr(lhs.getContext(), expr); 206436c6c9cSStella Laurenzo } 207436c6c9cSStella Laurenzo 208436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 209436c6c9cSStella Laurenzo c.def_static("get", &PyAffineAddExpr::get); 210436c6c9cSStella Laurenzo } 211436c6c9cSStella Laurenzo }; 212436c6c9cSStella Laurenzo 213436c6c9cSStella Laurenzo class PyAffineMulExpr 214436c6c9cSStella Laurenzo : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> { 215436c6c9cSStella Laurenzo public: 216436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul; 217436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "AffineMulExpr"; 218436c6c9cSStella Laurenzo using PyConcreteAffineExpr::PyConcreteAffineExpr; 219436c6c9cSStella Laurenzo 220436c6c9cSStella Laurenzo static PyAffineMulExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { 221436c6c9cSStella Laurenzo MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs); 222436c6c9cSStella Laurenzo return PyAffineMulExpr(lhs.getContext(), expr); 223436c6c9cSStella Laurenzo } 224436c6c9cSStella Laurenzo 225436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 226436c6c9cSStella Laurenzo c.def_static("get", &PyAffineMulExpr::get); 227436c6c9cSStella Laurenzo } 228436c6c9cSStella Laurenzo }; 229436c6c9cSStella Laurenzo 230436c6c9cSStella Laurenzo class PyAffineModExpr 231436c6c9cSStella Laurenzo : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> { 232436c6c9cSStella Laurenzo public: 233436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod; 234436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "AffineModExpr"; 235436c6c9cSStella Laurenzo using PyConcreteAffineExpr::PyConcreteAffineExpr; 236436c6c9cSStella Laurenzo 237436c6c9cSStella Laurenzo static PyAffineModExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { 238436c6c9cSStella Laurenzo MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs); 239436c6c9cSStella Laurenzo return PyAffineModExpr(lhs.getContext(), expr); 240436c6c9cSStella Laurenzo } 241436c6c9cSStella Laurenzo 242436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 243436c6c9cSStella Laurenzo c.def_static("get", &PyAffineModExpr::get); 244436c6c9cSStella Laurenzo } 245436c6c9cSStella Laurenzo }; 246436c6c9cSStella Laurenzo 247436c6c9cSStella Laurenzo class PyAffineFloorDivExpr 248436c6c9cSStella Laurenzo : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> { 249436c6c9cSStella Laurenzo public: 250436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv; 251436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "AffineFloorDivExpr"; 252436c6c9cSStella Laurenzo using PyConcreteAffineExpr::PyConcreteAffineExpr; 253436c6c9cSStella Laurenzo 254436c6c9cSStella Laurenzo static PyAffineFloorDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { 255436c6c9cSStella Laurenzo MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs); 256436c6c9cSStella Laurenzo return PyAffineFloorDivExpr(lhs.getContext(), expr); 257436c6c9cSStella Laurenzo } 258436c6c9cSStella Laurenzo 259436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 260436c6c9cSStella Laurenzo c.def_static("get", &PyAffineFloorDivExpr::get); 261436c6c9cSStella Laurenzo } 262436c6c9cSStella Laurenzo }; 263436c6c9cSStella Laurenzo 264436c6c9cSStella Laurenzo class PyAffineCeilDivExpr 265436c6c9cSStella Laurenzo : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> { 266436c6c9cSStella Laurenzo public: 267436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv; 268436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "AffineCeilDivExpr"; 269436c6c9cSStella Laurenzo using PyConcreteAffineExpr::PyConcreteAffineExpr; 270436c6c9cSStella Laurenzo 271436c6c9cSStella Laurenzo static PyAffineCeilDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { 272436c6c9cSStella Laurenzo MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs); 273436c6c9cSStella Laurenzo return PyAffineCeilDivExpr(lhs.getContext(), expr); 274436c6c9cSStella Laurenzo } 275436c6c9cSStella Laurenzo 276436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 277436c6c9cSStella Laurenzo c.def_static("get", &PyAffineCeilDivExpr::get); 278436c6c9cSStella Laurenzo } 279436c6c9cSStella Laurenzo }; 280436c6c9cSStella Laurenzo 281436c6c9cSStella Laurenzo } // namespace 282436c6c9cSStella Laurenzo 283436c6c9cSStella Laurenzo bool PyAffineExpr::operator==(const PyAffineExpr &other) { 284436c6c9cSStella Laurenzo return mlirAffineExprEqual(affineExpr, other.affineExpr); 285436c6c9cSStella Laurenzo } 286436c6c9cSStella Laurenzo 287436c6c9cSStella Laurenzo py::object PyAffineExpr::getCapsule() { 288436c6c9cSStella Laurenzo return py::reinterpret_steal<py::object>( 289436c6c9cSStella Laurenzo mlirPythonAffineExprToCapsule(*this)); 290436c6c9cSStella Laurenzo } 291436c6c9cSStella Laurenzo 292436c6c9cSStella Laurenzo PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) { 293436c6c9cSStella Laurenzo MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr()); 294436c6c9cSStella Laurenzo if (mlirAffineExprIsNull(rawAffineExpr)) 295436c6c9cSStella Laurenzo throw py::error_already_set(); 296436c6c9cSStella Laurenzo return PyAffineExpr( 297436c6c9cSStella Laurenzo PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)), 298436c6c9cSStella Laurenzo rawAffineExpr); 299436c6c9cSStella Laurenzo } 300436c6c9cSStella Laurenzo 301436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 302436c6c9cSStella Laurenzo // PyAffineMap and utilities. 303436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 304436c6c9cSStella Laurenzo namespace { 305436c6c9cSStella Laurenzo 306436c6c9cSStella Laurenzo /// A list of expressions contained in an affine map. Internally these are 307436c6c9cSStella Laurenzo /// stored as a consecutive array leading to inexpensive random access. Both 308436c6c9cSStella Laurenzo /// the map and the expression are owned by the context so we need not bother 309436c6c9cSStella Laurenzo /// with lifetime extension. 310436c6c9cSStella Laurenzo class PyAffineMapExprList 311436c6c9cSStella Laurenzo : public Sliceable<PyAffineMapExprList, PyAffineExpr> { 312436c6c9cSStella Laurenzo public: 313436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "AffineExprList"; 314436c6c9cSStella Laurenzo 315436c6c9cSStella Laurenzo PyAffineMapExprList(PyAffineMap map, intptr_t startIndex = 0, 316436c6c9cSStella Laurenzo intptr_t length = -1, intptr_t step = 1) 317436c6c9cSStella Laurenzo : Sliceable(startIndex, 318436c6c9cSStella Laurenzo length == -1 ? mlirAffineMapGetNumResults(map) : length, 319436c6c9cSStella Laurenzo step), 320436c6c9cSStella Laurenzo affineMap(map) {} 321436c6c9cSStella Laurenzo 322436c6c9cSStella Laurenzo intptr_t getNumElements() { return mlirAffineMapGetNumResults(affineMap); } 323436c6c9cSStella Laurenzo 324436c6c9cSStella Laurenzo PyAffineExpr getElement(intptr_t pos) { 325436c6c9cSStella Laurenzo return PyAffineExpr(affineMap.getContext(), 326436c6c9cSStella Laurenzo mlirAffineMapGetResult(affineMap, pos)); 327436c6c9cSStella Laurenzo } 328436c6c9cSStella Laurenzo 329436c6c9cSStella Laurenzo PyAffineMapExprList slice(intptr_t startIndex, intptr_t length, 330436c6c9cSStella Laurenzo intptr_t step) { 331436c6c9cSStella Laurenzo return PyAffineMapExprList(affineMap, startIndex, length, step); 332436c6c9cSStella Laurenzo } 333436c6c9cSStella Laurenzo 334436c6c9cSStella Laurenzo private: 335436c6c9cSStella Laurenzo PyAffineMap affineMap; 336436c6c9cSStella Laurenzo }; 337436c6c9cSStella Laurenzo } // end namespace 338436c6c9cSStella Laurenzo 339436c6c9cSStella Laurenzo bool PyAffineMap::operator==(const PyAffineMap &other) { 340436c6c9cSStella Laurenzo return mlirAffineMapEqual(affineMap, other.affineMap); 341436c6c9cSStella Laurenzo } 342436c6c9cSStella Laurenzo 343436c6c9cSStella Laurenzo py::object PyAffineMap::getCapsule() { 344436c6c9cSStella Laurenzo return py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(*this)); 345436c6c9cSStella Laurenzo } 346436c6c9cSStella Laurenzo 347436c6c9cSStella Laurenzo PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) { 348436c6c9cSStella Laurenzo MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr()); 349436c6c9cSStella Laurenzo if (mlirAffineMapIsNull(rawAffineMap)) 350436c6c9cSStella Laurenzo throw py::error_already_set(); 351436c6c9cSStella Laurenzo return PyAffineMap( 352436c6c9cSStella Laurenzo PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)), 353436c6c9cSStella Laurenzo rawAffineMap); 354436c6c9cSStella Laurenzo } 355436c6c9cSStella Laurenzo 356436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 357436c6c9cSStella Laurenzo // PyIntegerSet and utilities. 358436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 359436c6c9cSStella Laurenzo namespace { 360436c6c9cSStella Laurenzo 361436c6c9cSStella Laurenzo class PyIntegerSetConstraint { 362436c6c9cSStella Laurenzo public: 363436c6c9cSStella Laurenzo PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) : set(set), pos(pos) {} 364436c6c9cSStella Laurenzo 365436c6c9cSStella Laurenzo PyAffineExpr getExpr() { 366436c6c9cSStella Laurenzo return PyAffineExpr(set.getContext(), 367436c6c9cSStella Laurenzo mlirIntegerSetGetConstraint(set, pos)); 368436c6c9cSStella Laurenzo } 369436c6c9cSStella Laurenzo 370436c6c9cSStella Laurenzo bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); } 371436c6c9cSStella Laurenzo 372436c6c9cSStella Laurenzo static void bind(py::module &m) { 373f05ff4f7SStella Laurenzo py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint", 374f05ff4f7SStella Laurenzo py::module_local()) 375436c6c9cSStella Laurenzo .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr) 376436c6c9cSStella Laurenzo .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq); 377436c6c9cSStella Laurenzo } 378436c6c9cSStella Laurenzo 379436c6c9cSStella Laurenzo private: 380436c6c9cSStella Laurenzo PyIntegerSet set; 381436c6c9cSStella Laurenzo intptr_t pos; 382436c6c9cSStella Laurenzo }; 383436c6c9cSStella Laurenzo 384436c6c9cSStella Laurenzo class PyIntegerSetConstraintList 385436c6c9cSStella Laurenzo : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> { 386436c6c9cSStella Laurenzo public: 387436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "IntegerSetConstraintList"; 388436c6c9cSStella Laurenzo 389436c6c9cSStella Laurenzo PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0, 390436c6c9cSStella Laurenzo intptr_t length = -1, intptr_t step = 1) 391436c6c9cSStella Laurenzo : Sliceable(startIndex, 392436c6c9cSStella Laurenzo length == -1 ? mlirIntegerSetGetNumConstraints(set) : length, 393436c6c9cSStella Laurenzo step), 394436c6c9cSStella Laurenzo set(set) {} 395436c6c9cSStella Laurenzo 396436c6c9cSStella Laurenzo intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); } 397436c6c9cSStella Laurenzo 398436c6c9cSStella Laurenzo PyIntegerSetConstraint getElement(intptr_t pos) { 399436c6c9cSStella Laurenzo return PyIntegerSetConstraint(set, pos); 400436c6c9cSStella Laurenzo } 401436c6c9cSStella Laurenzo 402436c6c9cSStella Laurenzo PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length, 403436c6c9cSStella Laurenzo intptr_t step) { 404436c6c9cSStella Laurenzo return PyIntegerSetConstraintList(set, startIndex, length, step); 405436c6c9cSStella Laurenzo } 406436c6c9cSStella Laurenzo 407436c6c9cSStella Laurenzo private: 408436c6c9cSStella Laurenzo PyIntegerSet set; 409436c6c9cSStella Laurenzo }; 410436c6c9cSStella Laurenzo } // namespace 411436c6c9cSStella Laurenzo 412436c6c9cSStella Laurenzo bool PyIntegerSet::operator==(const PyIntegerSet &other) { 413436c6c9cSStella Laurenzo return mlirIntegerSetEqual(integerSet, other.integerSet); 414436c6c9cSStella Laurenzo } 415436c6c9cSStella Laurenzo 416436c6c9cSStella Laurenzo py::object PyIntegerSet::getCapsule() { 417436c6c9cSStella Laurenzo return py::reinterpret_steal<py::object>( 418436c6c9cSStella Laurenzo mlirPythonIntegerSetToCapsule(*this)); 419436c6c9cSStella Laurenzo } 420436c6c9cSStella Laurenzo 421436c6c9cSStella Laurenzo PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) { 422436c6c9cSStella Laurenzo MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr()); 423436c6c9cSStella Laurenzo if (mlirIntegerSetIsNull(rawIntegerSet)) 424436c6c9cSStella Laurenzo throw py::error_already_set(); 425436c6c9cSStella Laurenzo return PyIntegerSet( 426436c6c9cSStella Laurenzo PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)), 427436c6c9cSStella Laurenzo rawIntegerSet); 428436c6c9cSStella Laurenzo } 429436c6c9cSStella Laurenzo 430436c6c9cSStella Laurenzo void mlir::python::populateIRAffine(py::module &m) { 431436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 432436c6c9cSStella Laurenzo // Mapping of PyAffineExpr and derived classes. 433436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 434f05ff4f7SStella Laurenzo py::class_<PyAffineExpr>(m, "AffineExpr", py::module_local()) 435436c6c9cSStella Laurenzo .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 436436c6c9cSStella Laurenzo &PyAffineExpr::getCapsule) 437436c6c9cSStella Laurenzo .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule) 438436c6c9cSStella Laurenzo .def("__add__", 439436c6c9cSStella Laurenzo [](PyAffineExpr &self, PyAffineExpr &other) { 440436c6c9cSStella Laurenzo return PyAffineAddExpr::get(self, other); 441436c6c9cSStella Laurenzo }) 442436c6c9cSStella Laurenzo .def("__mul__", 443436c6c9cSStella Laurenzo [](PyAffineExpr &self, PyAffineExpr &other) { 444436c6c9cSStella Laurenzo return PyAffineMulExpr::get(self, other); 445436c6c9cSStella Laurenzo }) 446436c6c9cSStella Laurenzo .def("__mod__", 447436c6c9cSStella Laurenzo [](PyAffineExpr &self, PyAffineExpr &other) { 448436c6c9cSStella Laurenzo return PyAffineModExpr::get(self, other); 449436c6c9cSStella Laurenzo }) 450436c6c9cSStella Laurenzo .def("__sub__", 451436c6c9cSStella Laurenzo [](PyAffineExpr &self, PyAffineExpr &other) { 452436c6c9cSStella Laurenzo auto negOne = 453436c6c9cSStella Laurenzo PyAffineConstantExpr::get(-1, *self.getContext().get()); 454436c6c9cSStella Laurenzo return PyAffineAddExpr::get(self, 455436c6c9cSStella Laurenzo PyAffineMulExpr::get(negOne, other)); 456436c6c9cSStella Laurenzo }) 457436c6c9cSStella Laurenzo .def("__eq__", [](PyAffineExpr &self, 458436c6c9cSStella Laurenzo PyAffineExpr &other) { return self == other; }) 459436c6c9cSStella Laurenzo .def("__eq__", 460436c6c9cSStella Laurenzo [](PyAffineExpr &self, py::object &other) { return false; }) 461436c6c9cSStella Laurenzo .def("__str__", 462436c6c9cSStella Laurenzo [](PyAffineExpr &self) { 463436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 464436c6c9cSStella Laurenzo mlirAffineExprPrint(self, printAccum.getCallback(), 465436c6c9cSStella Laurenzo printAccum.getUserData()); 466436c6c9cSStella Laurenzo return printAccum.join(); 467436c6c9cSStella Laurenzo }) 468436c6c9cSStella Laurenzo .def("__repr__", 469436c6c9cSStella Laurenzo [](PyAffineExpr &self) { 470436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 471436c6c9cSStella Laurenzo printAccum.parts.append("AffineExpr("); 472436c6c9cSStella Laurenzo mlirAffineExprPrint(self, printAccum.getCallback(), 473436c6c9cSStella Laurenzo printAccum.getUserData()); 474436c6c9cSStella Laurenzo printAccum.parts.append(")"); 475436c6c9cSStella Laurenzo return printAccum.join(); 476436c6c9cSStella Laurenzo }) 477436c6c9cSStella Laurenzo .def_property_readonly( 478436c6c9cSStella Laurenzo "context", 479436c6c9cSStella Laurenzo [](PyAffineExpr &self) { return self.getContext().getObject(); }) 480436c6c9cSStella Laurenzo .def_static( 481436c6c9cSStella Laurenzo "get_add", &PyAffineAddExpr::get, 482436c6c9cSStella Laurenzo "Gets an affine expression containing a sum of two expressions.") 483436c6c9cSStella Laurenzo .def_static( 484436c6c9cSStella Laurenzo "get_mul", &PyAffineMulExpr::get, 485436c6c9cSStella Laurenzo "Gets an affine expression containing a product of two expressions.") 486436c6c9cSStella Laurenzo .def_static("get_mod", &PyAffineModExpr::get, 487436c6c9cSStella Laurenzo "Gets an affine expression containing the modulo of dividing " 488436c6c9cSStella Laurenzo "one expression by another.") 489436c6c9cSStella Laurenzo .def_static("get_floor_div", &PyAffineFloorDivExpr::get, 490436c6c9cSStella Laurenzo "Gets an affine expression containing the rounded-down " 491436c6c9cSStella Laurenzo "result of dividing one expression by another.") 492436c6c9cSStella Laurenzo .def_static("get_ceil_div", &PyAffineCeilDivExpr::get, 493436c6c9cSStella Laurenzo "Gets an affine expression containing the rounded-up result " 494436c6c9cSStella Laurenzo "of dividing one expression by another.") 495436c6c9cSStella Laurenzo .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"), 496436c6c9cSStella Laurenzo py::arg("context") = py::none(), 497436c6c9cSStella Laurenzo "Gets a constant affine expression with the given value.") 498436c6c9cSStella Laurenzo .def_static( 499436c6c9cSStella Laurenzo "get_dim", &PyAffineDimExpr::get, py::arg("position"), 500436c6c9cSStella Laurenzo py::arg("context") = py::none(), 501436c6c9cSStella Laurenzo "Gets an affine expression of a dimension at the given position.") 502436c6c9cSStella Laurenzo .def_static( 503436c6c9cSStella Laurenzo "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"), 504436c6c9cSStella Laurenzo py::arg("context") = py::none(), 505436c6c9cSStella Laurenzo "Gets an affine expression of a symbol at the given position.") 506436c6c9cSStella Laurenzo .def( 507436c6c9cSStella Laurenzo "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); }, 508436c6c9cSStella Laurenzo kDumpDocstring); 509436c6c9cSStella Laurenzo PyAffineConstantExpr::bind(m); 510436c6c9cSStella Laurenzo PyAffineDimExpr::bind(m); 511436c6c9cSStella Laurenzo PyAffineSymbolExpr::bind(m); 512436c6c9cSStella Laurenzo PyAffineBinaryExpr::bind(m); 513436c6c9cSStella Laurenzo PyAffineAddExpr::bind(m); 514436c6c9cSStella Laurenzo PyAffineMulExpr::bind(m); 515436c6c9cSStella Laurenzo PyAffineModExpr::bind(m); 516436c6c9cSStella Laurenzo PyAffineFloorDivExpr::bind(m); 517436c6c9cSStella Laurenzo PyAffineCeilDivExpr::bind(m); 518436c6c9cSStella Laurenzo 519436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 520436c6c9cSStella Laurenzo // Mapping of PyAffineMap. 521436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 522f05ff4f7SStella Laurenzo py::class_<PyAffineMap>(m, "AffineMap", py::module_local()) 523436c6c9cSStella Laurenzo .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 524436c6c9cSStella Laurenzo &PyAffineMap::getCapsule) 525436c6c9cSStella Laurenzo .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule) 526436c6c9cSStella Laurenzo .def("__eq__", 527436c6c9cSStella Laurenzo [](PyAffineMap &self, PyAffineMap &other) { return self == other; }) 528436c6c9cSStella Laurenzo .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; }) 529436c6c9cSStella Laurenzo .def("__str__", 530436c6c9cSStella Laurenzo [](PyAffineMap &self) { 531436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 532436c6c9cSStella Laurenzo mlirAffineMapPrint(self, printAccum.getCallback(), 533436c6c9cSStella Laurenzo printAccum.getUserData()); 534436c6c9cSStella Laurenzo return printAccum.join(); 535436c6c9cSStella Laurenzo }) 536436c6c9cSStella Laurenzo .def("__repr__", 537436c6c9cSStella Laurenzo [](PyAffineMap &self) { 538436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 539436c6c9cSStella Laurenzo printAccum.parts.append("AffineMap("); 540436c6c9cSStella Laurenzo mlirAffineMapPrint(self, printAccum.getCallback(), 541436c6c9cSStella Laurenzo printAccum.getUserData()); 542436c6c9cSStella Laurenzo printAccum.parts.append(")"); 543436c6c9cSStella Laurenzo return printAccum.join(); 544436c6c9cSStella Laurenzo }) 545335d2df5SNicolas Vasilache .def_static("compress_unused_symbols", 546335d2df5SNicolas Vasilache [](py::list affineMaps, DefaultingPyMlirContext context) { 547335d2df5SNicolas Vasilache SmallVector<MlirAffineMap> maps; 548335d2df5SNicolas Vasilache pyListToVector<PyAffineMap, MlirAffineMap>( 549335d2df5SNicolas Vasilache affineMaps, maps, "attempting to create an AffineMap"); 550335d2df5SNicolas Vasilache std::vector<MlirAffineMap> compressed(affineMaps.size()); 551335d2df5SNicolas Vasilache auto populate = [](void *result, intptr_t idx, 552335d2df5SNicolas Vasilache MlirAffineMap m) { 553335d2df5SNicolas Vasilache static_cast<MlirAffineMap *>(result)[idx] = (m); 554335d2df5SNicolas Vasilache }; 555335d2df5SNicolas Vasilache mlirAffineMapCompressUnusedSymbols( 556335d2df5SNicolas Vasilache maps.data(), maps.size(), compressed.data(), populate); 557335d2df5SNicolas Vasilache std::vector<PyAffineMap> res; 558335d2df5SNicolas Vasilache for (auto m : compressed) 559335d2df5SNicolas Vasilache res.push_back(PyAffineMap(context->getRef(), m)); 560335d2df5SNicolas Vasilache return res; 561335d2df5SNicolas Vasilache }) 562436c6c9cSStella Laurenzo .def_property_readonly( 563436c6c9cSStella Laurenzo "context", 564436c6c9cSStella Laurenzo [](PyAffineMap &self) { return self.getContext().getObject(); }, 565436c6c9cSStella Laurenzo "Context that owns the Affine Map") 566436c6c9cSStella Laurenzo .def( 567436c6c9cSStella Laurenzo "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); }, 568436c6c9cSStella Laurenzo kDumpDocstring) 569436c6c9cSStella Laurenzo .def_static( 570436c6c9cSStella Laurenzo "get", 571436c6c9cSStella Laurenzo [](intptr_t dimCount, intptr_t symbolCount, py::list exprs, 572436c6c9cSStella Laurenzo DefaultingPyMlirContext context) { 573436c6c9cSStella Laurenzo SmallVector<MlirAffineExpr> affineExprs; 574436c6c9cSStella Laurenzo pyListToVector<PyAffineExpr, MlirAffineExpr>( 575436c6c9cSStella Laurenzo exprs, affineExprs, "attempting to create an AffineMap"); 576436c6c9cSStella Laurenzo MlirAffineMap map = 577436c6c9cSStella Laurenzo mlirAffineMapGet(context->get(), dimCount, symbolCount, 578436c6c9cSStella Laurenzo affineExprs.size(), affineExprs.data()); 579436c6c9cSStella Laurenzo return PyAffineMap(context->getRef(), map); 580436c6c9cSStella Laurenzo }, 581436c6c9cSStella Laurenzo py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"), 582436c6c9cSStella Laurenzo py::arg("context") = py::none(), 583436c6c9cSStella Laurenzo "Gets a map with the given expressions as results.") 584436c6c9cSStella Laurenzo .def_static( 585436c6c9cSStella Laurenzo "get_constant", 586436c6c9cSStella Laurenzo [](intptr_t value, DefaultingPyMlirContext context) { 587436c6c9cSStella Laurenzo MlirAffineMap affineMap = 588436c6c9cSStella Laurenzo mlirAffineMapConstantGet(context->get(), value); 589436c6c9cSStella Laurenzo return PyAffineMap(context->getRef(), affineMap); 590436c6c9cSStella Laurenzo }, 591436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 592436c6c9cSStella Laurenzo "Gets an affine map with a single constant result") 593436c6c9cSStella Laurenzo .def_static( 594436c6c9cSStella Laurenzo "get_empty", 595436c6c9cSStella Laurenzo [](DefaultingPyMlirContext context) { 596436c6c9cSStella Laurenzo MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get()); 597436c6c9cSStella Laurenzo return PyAffineMap(context->getRef(), affineMap); 598436c6c9cSStella Laurenzo }, 599436c6c9cSStella Laurenzo py::arg("context") = py::none(), "Gets an empty affine map.") 600436c6c9cSStella Laurenzo .def_static( 601436c6c9cSStella Laurenzo "get_identity", 602436c6c9cSStella Laurenzo [](intptr_t nDims, DefaultingPyMlirContext context) { 603436c6c9cSStella Laurenzo MlirAffineMap affineMap = 604436c6c9cSStella Laurenzo mlirAffineMapMultiDimIdentityGet(context->get(), nDims); 605436c6c9cSStella Laurenzo return PyAffineMap(context->getRef(), affineMap); 606436c6c9cSStella Laurenzo }, 607436c6c9cSStella Laurenzo py::arg("n_dims"), py::arg("context") = py::none(), 608436c6c9cSStella Laurenzo "Gets an identity map with the given number of dimensions.") 609436c6c9cSStella Laurenzo .def_static( 610436c6c9cSStella Laurenzo "get_minor_identity", 611436c6c9cSStella Laurenzo [](intptr_t nDims, intptr_t nResults, 612436c6c9cSStella Laurenzo DefaultingPyMlirContext context) { 613436c6c9cSStella Laurenzo MlirAffineMap affineMap = 614436c6c9cSStella Laurenzo mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults); 615436c6c9cSStella Laurenzo return PyAffineMap(context->getRef(), affineMap); 616436c6c9cSStella Laurenzo }, 617436c6c9cSStella Laurenzo py::arg("n_dims"), py::arg("n_results"), 618436c6c9cSStella Laurenzo py::arg("context") = py::none(), 619436c6c9cSStella Laurenzo "Gets a minor identity map with the given number of dimensions and " 620436c6c9cSStella Laurenzo "results.") 621436c6c9cSStella Laurenzo .def_static( 622436c6c9cSStella Laurenzo "get_permutation", 623436c6c9cSStella Laurenzo [](std::vector<unsigned> permutation, 624436c6c9cSStella Laurenzo DefaultingPyMlirContext context) { 625436c6c9cSStella Laurenzo if (!isPermutation(permutation)) 626436c6c9cSStella Laurenzo throw py::cast_error("Invalid permutation when attempting to " 627436c6c9cSStella Laurenzo "create an AffineMap"); 628436c6c9cSStella Laurenzo MlirAffineMap affineMap = mlirAffineMapPermutationGet( 629436c6c9cSStella Laurenzo context->get(), permutation.size(), permutation.data()); 630436c6c9cSStella Laurenzo return PyAffineMap(context->getRef(), affineMap); 631436c6c9cSStella Laurenzo }, 632436c6c9cSStella Laurenzo py::arg("permutation"), py::arg("context") = py::none(), 633436c6c9cSStella Laurenzo "Gets an affine map that permutes its inputs.") 634436c6c9cSStella Laurenzo .def("get_submap", 635436c6c9cSStella Laurenzo [](PyAffineMap &self, std::vector<intptr_t> &resultPos) { 636436c6c9cSStella Laurenzo intptr_t numResults = mlirAffineMapGetNumResults(self); 637436c6c9cSStella Laurenzo for (intptr_t pos : resultPos) { 638436c6c9cSStella Laurenzo if (pos < 0 || pos >= numResults) 639436c6c9cSStella Laurenzo throw py::value_error("result position out of bounds"); 640436c6c9cSStella Laurenzo } 641436c6c9cSStella Laurenzo MlirAffineMap affineMap = mlirAffineMapGetSubMap( 642436c6c9cSStella Laurenzo self, resultPos.size(), resultPos.data()); 643436c6c9cSStella Laurenzo return PyAffineMap(self.getContext(), affineMap); 644436c6c9cSStella Laurenzo }) 645436c6c9cSStella Laurenzo .def("get_major_submap", 646436c6c9cSStella Laurenzo [](PyAffineMap &self, intptr_t nResults) { 647436c6c9cSStella Laurenzo if (nResults >= mlirAffineMapGetNumResults(self)) 648436c6c9cSStella Laurenzo throw py::value_error("number of results out of bounds"); 649436c6c9cSStella Laurenzo MlirAffineMap affineMap = 650436c6c9cSStella Laurenzo mlirAffineMapGetMajorSubMap(self, nResults); 651436c6c9cSStella Laurenzo return PyAffineMap(self.getContext(), affineMap); 652436c6c9cSStella Laurenzo }) 653436c6c9cSStella Laurenzo .def("get_minor_submap", 654436c6c9cSStella Laurenzo [](PyAffineMap &self, intptr_t nResults) { 655436c6c9cSStella Laurenzo if (nResults >= mlirAffineMapGetNumResults(self)) 656436c6c9cSStella Laurenzo throw py::value_error("number of results out of bounds"); 657436c6c9cSStella Laurenzo MlirAffineMap affineMap = 658436c6c9cSStella Laurenzo mlirAffineMapGetMinorSubMap(self, nResults); 659436c6c9cSStella Laurenzo return PyAffineMap(self.getContext(), affineMap); 660436c6c9cSStella Laurenzo }) 66131f888eaSTobias Gysi .def("replace", 66231f888eaSTobias Gysi [](PyAffineMap &self, PyAffineExpr &expression, 66331f888eaSTobias Gysi PyAffineExpr &replacement, intptr_t numResultDims, 66431f888eaSTobias Gysi intptr_t numResultSyms) { 66531f888eaSTobias Gysi MlirAffineMap affineMap = mlirAffineMapReplace( 66631f888eaSTobias Gysi self, expression, replacement, numResultDims, numResultSyms); 66731f888eaSTobias Gysi return PyAffineMap(self.getContext(), affineMap); 66831f888eaSTobias Gysi }) 669436c6c9cSStella Laurenzo .def_property_readonly( 670436c6c9cSStella Laurenzo "is_permutation", 671436c6c9cSStella Laurenzo [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); }) 672436c6c9cSStella Laurenzo .def_property_readonly("is_projected_permutation", 673436c6c9cSStella Laurenzo [](PyAffineMap &self) { 674436c6c9cSStella Laurenzo return mlirAffineMapIsProjectedPermutation(self); 675436c6c9cSStella Laurenzo }) 676436c6c9cSStella Laurenzo .def_property_readonly( 677436c6c9cSStella Laurenzo "n_dims", 678436c6c9cSStella Laurenzo [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); }) 679436c6c9cSStella Laurenzo .def_property_readonly( 680436c6c9cSStella Laurenzo "n_inputs", 681436c6c9cSStella Laurenzo [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); }) 682436c6c9cSStella Laurenzo .def_property_readonly( 683436c6c9cSStella Laurenzo "n_symbols", 684436c6c9cSStella Laurenzo [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); }) 685436c6c9cSStella Laurenzo .def_property_readonly("results", [](PyAffineMap &self) { 686436c6c9cSStella Laurenzo return PyAffineMapExprList(self); 687436c6c9cSStella Laurenzo }); 688436c6c9cSStella Laurenzo PyAffineMapExprList::bind(m); 689436c6c9cSStella Laurenzo 690436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 691436c6c9cSStella Laurenzo // Mapping of PyIntegerSet. 692436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 693f05ff4f7SStella Laurenzo py::class_<PyIntegerSet>(m, "IntegerSet", py::module_local()) 694436c6c9cSStella Laurenzo .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 695436c6c9cSStella Laurenzo &PyIntegerSet::getCapsule) 696436c6c9cSStella Laurenzo .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule) 697436c6c9cSStella Laurenzo .def("__eq__", [](PyIntegerSet &self, 698436c6c9cSStella Laurenzo PyIntegerSet &other) { return self == other; }) 699436c6c9cSStella Laurenzo .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; }) 700436c6c9cSStella Laurenzo .def("__str__", 701436c6c9cSStella Laurenzo [](PyIntegerSet &self) { 702436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 703436c6c9cSStella Laurenzo mlirIntegerSetPrint(self, printAccum.getCallback(), 704436c6c9cSStella Laurenzo printAccum.getUserData()); 705436c6c9cSStella Laurenzo return printAccum.join(); 706436c6c9cSStella Laurenzo }) 707436c6c9cSStella Laurenzo .def("__repr__", 708436c6c9cSStella Laurenzo [](PyIntegerSet &self) { 709436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 710436c6c9cSStella Laurenzo printAccum.parts.append("IntegerSet("); 711436c6c9cSStella Laurenzo mlirIntegerSetPrint(self, printAccum.getCallback(), 712436c6c9cSStella Laurenzo printAccum.getUserData()); 713436c6c9cSStella Laurenzo printAccum.parts.append(")"); 714436c6c9cSStella Laurenzo return printAccum.join(); 715436c6c9cSStella Laurenzo }) 716436c6c9cSStella Laurenzo .def_property_readonly( 717436c6c9cSStella Laurenzo "context", 718436c6c9cSStella Laurenzo [](PyIntegerSet &self) { return self.getContext().getObject(); }) 719436c6c9cSStella Laurenzo .def( 720436c6c9cSStella Laurenzo "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); }, 721436c6c9cSStella Laurenzo kDumpDocstring) 722436c6c9cSStella Laurenzo .def_static( 723436c6c9cSStella Laurenzo "get", 724436c6c9cSStella Laurenzo [](intptr_t numDims, intptr_t numSymbols, py::list exprs, 725436c6c9cSStella Laurenzo std::vector<bool> eqFlags, DefaultingPyMlirContext context) { 726436c6c9cSStella Laurenzo if (exprs.size() != eqFlags.size()) 727436c6c9cSStella Laurenzo throw py::value_error( 728436c6c9cSStella Laurenzo "Expected the number of constraints to match " 729436c6c9cSStella Laurenzo "that of equality flags"); 730436c6c9cSStella Laurenzo if (exprs.empty()) 731436c6c9cSStella Laurenzo throw py::value_error("Expected non-empty list of constraints"); 732436c6c9cSStella Laurenzo 733436c6c9cSStella Laurenzo // Copy over to a SmallVector because std::vector has a 734436c6c9cSStella Laurenzo // specialization for booleans that packs data and does not 735436c6c9cSStella Laurenzo // expose a `bool *`. 736436c6c9cSStella Laurenzo SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end()); 737436c6c9cSStella Laurenzo 738436c6c9cSStella Laurenzo SmallVector<MlirAffineExpr> affineExprs; 739436c6c9cSStella Laurenzo pyListToVector<PyAffineExpr>(exprs, affineExprs, 740436c6c9cSStella Laurenzo "attempting to create an IntegerSet"); 741436c6c9cSStella Laurenzo MlirIntegerSet set = mlirIntegerSetGet( 742436c6c9cSStella Laurenzo context->get(), numDims, numSymbols, exprs.size(), 743436c6c9cSStella Laurenzo affineExprs.data(), flags.data()); 744436c6c9cSStella Laurenzo return PyIntegerSet(context->getRef(), set); 745436c6c9cSStella Laurenzo }, 746436c6c9cSStella Laurenzo py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"), 747436c6c9cSStella Laurenzo py::arg("eq_flags"), py::arg("context") = py::none()) 748436c6c9cSStella Laurenzo .def_static( 749436c6c9cSStella Laurenzo "get_empty", 750436c6c9cSStella Laurenzo [](intptr_t numDims, intptr_t numSymbols, 751436c6c9cSStella Laurenzo DefaultingPyMlirContext context) { 752436c6c9cSStella Laurenzo MlirIntegerSet set = 753436c6c9cSStella Laurenzo mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols); 754436c6c9cSStella Laurenzo return PyIntegerSet(context->getRef(), set); 755436c6c9cSStella Laurenzo }, 756436c6c9cSStella Laurenzo py::arg("num_dims"), py::arg("num_symbols"), 757436c6c9cSStella Laurenzo py::arg("context") = py::none()) 758436c6c9cSStella Laurenzo .def("get_replaced", 759436c6c9cSStella Laurenzo [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs, 760436c6c9cSStella Laurenzo intptr_t numResultDims, intptr_t numResultSymbols) { 761436c6c9cSStella Laurenzo if (static_cast<intptr_t>(dimExprs.size()) != 762436c6c9cSStella Laurenzo mlirIntegerSetGetNumDims(self)) 763436c6c9cSStella Laurenzo throw py::value_error( 764436c6c9cSStella Laurenzo "Expected the number of dimension replacement expressions " 765436c6c9cSStella Laurenzo "to match that of dimensions"); 766436c6c9cSStella Laurenzo if (static_cast<intptr_t>(symbolExprs.size()) != 767436c6c9cSStella Laurenzo mlirIntegerSetGetNumSymbols(self)) 768436c6c9cSStella Laurenzo throw py::value_error( 769436c6c9cSStella Laurenzo "Expected the number of symbol replacement expressions " 770436c6c9cSStella Laurenzo "to match that of symbols"); 771436c6c9cSStella Laurenzo 772436c6c9cSStella Laurenzo SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs; 773436c6c9cSStella Laurenzo pyListToVector<PyAffineExpr>( 774436c6c9cSStella Laurenzo dimExprs, dimAffineExprs, 775436c6c9cSStella Laurenzo "attempting to create an IntegerSet by replacing dimensions"); 776436c6c9cSStella Laurenzo pyListToVector<PyAffineExpr>( 777436c6c9cSStella Laurenzo symbolExprs, symbolAffineExprs, 778436c6c9cSStella Laurenzo "attempting to create an IntegerSet by replacing symbols"); 779436c6c9cSStella Laurenzo MlirIntegerSet set = mlirIntegerSetReplaceGet( 780436c6c9cSStella Laurenzo self, dimAffineExprs.data(), symbolAffineExprs.data(), 781436c6c9cSStella Laurenzo numResultDims, numResultSymbols); 782436c6c9cSStella Laurenzo return PyIntegerSet(self.getContext(), set); 783436c6c9cSStella Laurenzo }) 784436c6c9cSStella Laurenzo .def_property_readonly("is_canonical_empty", 785436c6c9cSStella Laurenzo [](PyIntegerSet &self) { 786436c6c9cSStella Laurenzo return mlirIntegerSetIsCanonicalEmpty(self); 787436c6c9cSStella Laurenzo }) 788436c6c9cSStella Laurenzo .def_property_readonly( 789436c6c9cSStella Laurenzo "n_dims", 790436c6c9cSStella Laurenzo [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); }) 791436c6c9cSStella Laurenzo .def_property_readonly( 792436c6c9cSStella Laurenzo "n_symbols", 793436c6c9cSStella Laurenzo [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); }) 794436c6c9cSStella Laurenzo .def_property_readonly( 795436c6c9cSStella Laurenzo "n_inputs", 796436c6c9cSStella Laurenzo [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); }) 797436c6c9cSStella Laurenzo .def_property_readonly("n_equalities", 798436c6c9cSStella Laurenzo [](PyIntegerSet &self) { 799436c6c9cSStella Laurenzo return mlirIntegerSetGetNumEqualities(self); 800436c6c9cSStella Laurenzo }) 801436c6c9cSStella Laurenzo .def_property_readonly("n_inequalities", 802436c6c9cSStella Laurenzo [](PyIntegerSet &self) { 803436c6c9cSStella Laurenzo return mlirIntegerSetGetNumInequalities(self); 804436c6c9cSStella Laurenzo }) 805436c6c9cSStella Laurenzo .def_property_readonly("constraints", [](PyIntegerSet &self) { 806436c6c9cSStella Laurenzo return PyIntegerSetConstraintList(self); 807436c6c9cSStella Laurenzo }); 808436c6c9cSStella Laurenzo PyIntegerSetConstraint::bind(m); 809436c6c9cSStella Laurenzo PyIntegerSetConstraintList::bind(m); 810436c6c9cSStella Laurenzo } 811