1 //===- GeneratingFunction.h - Generating Functions over Q^d -----*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Definition of the GeneratingFunction class for Barvinok's algorithm, 10 // which represents a function over Q^n, parameterized by d parameters. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef MLIR_ANALYSIS_PRESBURGER_GENERATINGFUNCTION_H 15 #define MLIR_ANALYSIS_PRESBURGER_GENERATINGFUNCTION_H 16 17 #include "mlir/Analysis/Presburger/Fraction.h" 18 #include "mlir/Analysis/Presburger/Matrix.h" 19 20 namespace mlir { 21 namespace presburger { 22 namespace detail { 23 24 // A parametric point is a vector, each of whose elements 25 // is an affine function of n parameters. Each column 26 // in the matrix represents the affine function and 27 // has n+1 elements. 28 using ParamPoint = FracMatrix; 29 30 // A point is simply a vector. 31 using Point = SmallVector<Fraction>; 32 33 // A class to describe the type of generating function 34 // used to enumerate the integer points in a polytope. 35 // Consists of a set of terms, where the ith term has 36 // * a sign, ±1, stored in `signs[i]` 37 // * a numerator, of the form x^{n}, 38 // where n, stored in `numerators[i]`, 39 // is a parametric point. 40 // * a denominator, of the form (1 - x^{d1})...(1 - x^{dn}), 41 // where each dj, stored in `denominators[i][j]`, 42 // is a vector. 43 // 44 // Represents functions f_p : Q^n -> Q of the form 45 // 46 // f_p(x) = \sum_i s_i * (x^n_i(p)) / (\prod_j (1 - x^d_{ij}) 47 // 48 // where s_i is ±1, 49 // n_i \in Q^d -> Q^n is an n-vector of affine functions on d parameters, and 50 // g_{ij} \in Q^n are vectors. 51 class GeneratingFunction { 52 public: GeneratingFunction(unsigned numParam,SmallVector<int> signs,std::vector<ParamPoint> nums,std::vector<std::vector<Point>> dens)53 GeneratingFunction(unsigned numParam, SmallVector<int> signs, 54 std::vector<ParamPoint> nums, 55 std::vector<std::vector<Point>> dens) 56 : numParam(numParam), signs(signs), numerators(nums), denominators(dens) { 57 #ifndef NDEBUG 58 for (const ParamPoint &term : numerators) 59 assert(term.getNumRows() == numParam + 1 && 60 "dimensionality of numerator exponents does not match number of " 61 "parameters!"); 62 #endif // NDEBUG 63 } 64 getNumParams()65 unsigned getNumParams() const { return numParam; } 66 getSigns()67 SmallVector<int> getSigns() const { return signs; } 68 getNumerators()69 std::vector<ParamPoint> getNumerators() const { return numerators; } 70 getDenominators()71 std::vector<std::vector<Point>> getDenominators() const { 72 return denominators; 73 } 74 75 GeneratingFunction operator+(const GeneratingFunction &gf) const { 76 assert(numParam == gf.getNumParams() && 77 "two generating functions with different numbers of parameters " 78 "cannot be added!"); 79 SmallVector<int> sumSigns = signs; 80 sumSigns.append(gf.signs); 81 82 std::vector<ParamPoint> sumNumerators = numerators; 83 sumNumerators.insert(sumNumerators.end(), gf.numerators.begin(), 84 gf.numerators.end()); 85 86 std::vector<std::vector<Point>> sumDenominators = denominators; 87 sumDenominators.insert(sumDenominators.end(), gf.denominators.begin(), 88 gf.denominators.end()); 89 return GeneratingFunction(numParam, sumSigns, sumNumerators, 90 sumDenominators); 91 } 92 print(llvm::raw_ostream & os)93 llvm::raw_ostream &print(llvm::raw_ostream &os) const { 94 for (unsigned i = 0, e = signs.size(); i < e; i++) { 95 if (i == 0) { 96 if (signs[i] == -1) 97 os << "- "; 98 } else { 99 if (signs[i] == 1) 100 os << " + "; 101 else 102 os << " - "; 103 } 104 105 os << "x^["; 106 unsigned r = numerators[i].getNumRows(); 107 for (unsigned j = 0; j < r - 1; j++) { 108 os << "["; 109 for (unsigned k = 0, c = numerators[i].getNumColumns(); k < c - 1; k++) 110 os << numerators[i].at(j, k) << ","; 111 os << numerators[i].getRow(j).back() << "],"; 112 } 113 os << "["; 114 for (unsigned k = 0, c = numerators[i].getNumColumns(); k < c - 1; k++) 115 os << numerators[i].at(r - 1, k) << ","; 116 os << numerators[i].getRow(r - 1).back() << "]]/"; 117 118 for (const Point &den : denominators[i]) { 119 os << "(x^["; 120 for (unsigned j = 0, e = den.size(); j < e - 1; j++) 121 os << den[j] << ","; 122 os << den.back() << "])"; 123 } 124 } 125 return os; 126 } 127 128 private: 129 unsigned numParam; 130 SmallVector<int> signs; 131 std::vector<ParamPoint> numerators; 132 std::vector<std::vector<Point>> denominators; 133 }; 134 135 } // namespace detail 136 } // namespace presburger 137 } // namespace mlir 138 139 #endif // MLIR_ANALYSIS_PRESBURGER_GENERATINGFUNCTION_H 140