xref: /llvm-project/mlir/include/mlir/Analysis/Presburger/GeneratingFunction.h (revision 562790f371f230d8f67a1a8fb4b54e02e8d1e31f)
1 //===- GeneratingFunction.h - Generating Functions over Q^d -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Definition of the GeneratingFunction class for Barvinok's algorithm,
10 // which represents a function over Q^n, parameterized by d parameters.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_ANALYSIS_PRESBURGER_GENERATINGFUNCTION_H
15 #define MLIR_ANALYSIS_PRESBURGER_GENERATINGFUNCTION_H
16 
#include "mlir/Analysis/Presburger/Fraction.h"
#include "mlir/Analysis/Presburger/Matrix.h"

#include <utility>
#include <vector>
19 
20 namespace mlir {
21 namespace presburger {
22 namespace detail {
23 
24 // A parametric point is a vector, each of whose elements
25 // is an affine function of n parameters. Each column
26 // in the matrix represents the affine function and
27 // has n+1 elements.
28 using ParamPoint = FracMatrix;
29 
30 // A point is simply a vector.
31 using Point = SmallVector<Fraction>;
32 
33 // A class to describe the type of generating function
34 // used to enumerate the integer points in a polytope.
35 // Consists of a set of terms, where the ith term has
36 // * a sign, ±1, stored in `signs[i]`
37 // * a numerator, of the form x^{n},
38 //      where n, stored in `numerators[i]`,
39 //      is a parametric point.
40 // * a denominator, of the form (1 - x^{d1})...(1 - x^{dn}),
41 //      where each dj, stored in `denominators[i][j]`,
42 //      is a vector.
43 //
44 // Represents functions f_p : Q^n -> Q of the form
45 //
46 // f_p(x) = \sum_i s_i * (x^n_i(p)) / (\prod_j (1 - x^d_{ij})
47 //
48 // where s_i is ±1,
49 // n_i \in Q^d -> Q^n is an n-vector of affine functions on d parameters, and
50 // g_{ij} \in Q^n are vectors.
51 class GeneratingFunction {
52 public:
GeneratingFunction(unsigned numParam,SmallVector<int> signs,std::vector<ParamPoint> nums,std::vector<std::vector<Point>> dens)53   GeneratingFunction(unsigned numParam, SmallVector<int> signs,
54                      std::vector<ParamPoint> nums,
55                      std::vector<std::vector<Point>> dens)
56       : numParam(numParam), signs(signs), numerators(nums), denominators(dens) {
57 #ifndef NDEBUG
58     for (const ParamPoint &term : numerators)
59       assert(term.getNumRows() == numParam + 1 &&
60              "dimensionality of numerator exponents does not match number of "
61              "parameters!");
62 #endif // NDEBUG
63   }
64 
getNumParams()65   unsigned getNumParams() const { return numParam; }
66 
getSigns()67   SmallVector<int> getSigns() const { return signs; }
68 
getNumerators()69   std::vector<ParamPoint> getNumerators() const { return numerators; }
70 
getDenominators()71   std::vector<std::vector<Point>> getDenominators() const {
72     return denominators;
73   }
74 
75   GeneratingFunction operator+(const GeneratingFunction &gf) const {
76     assert(numParam == gf.getNumParams() &&
77            "two generating functions with different numbers of parameters "
78            "cannot be added!");
79     SmallVector<int> sumSigns = signs;
80     sumSigns.append(gf.signs);
81 
82     std::vector<ParamPoint> sumNumerators = numerators;
83     sumNumerators.insert(sumNumerators.end(), gf.numerators.begin(),
84                          gf.numerators.end());
85 
86     std::vector<std::vector<Point>> sumDenominators = denominators;
87     sumDenominators.insert(sumDenominators.end(), gf.denominators.begin(),
88                            gf.denominators.end());
89     return GeneratingFunction(numParam, sumSigns, sumNumerators,
90                               sumDenominators);
91   }
92 
print(llvm::raw_ostream & os)93   llvm::raw_ostream &print(llvm::raw_ostream &os) const {
94     for (unsigned i = 0, e = signs.size(); i < e; i++) {
95       if (i == 0) {
96         if (signs[i] == -1)
97           os << "- ";
98       } else {
99         if (signs[i] == 1)
100           os << " + ";
101         else
102           os << " - ";
103       }
104 
105       os << "x^[";
106       unsigned r = numerators[i].getNumRows();
107       for (unsigned j = 0; j < r - 1; j++) {
108         os << "[";
109         for (unsigned k = 0, c = numerators[i].getNumColumns(); k < c - 1; k++)
110           os << numerators[i].at(j, k) << ",";
111         os << numerators[i].getRow(j).back() << "],";
112       }
113       os << "[";
114       for (unsigned k = 0, c = numerators[i].getNumColumns(); k < c - 1; k++)
115         os << numerators[i].at(r - 1, k) << ",";
116       os << numerators[i].getRow(r - 1).back() << "]]/";
117 
118       for (const Point &den : denominators[i]) {
119         os << "(x^[";
120         for (unsigned j = 0, e = den.size(); j < e - 1; j++)
121           os << den[j] << ",";
122         os << den.back() << "])";
123       }
124     }
125     return os;
126   }
127 
128 private:
129   unsigned numParam;
130   SmallVector<int> signs;
131   std::vector<ParamPoint> numerators;
132   std::vector<std::vector<Point>> denominators;
133 };
134 
135 } // namespace detail
136 } // namespace presburger
137 } // namespace mlir
138 
139 #endif // MLIR_ANALYSIS_PRESBURGER_GENERATINGFUNCTION_H
140