xref: /llvm-project/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h (revision 1a0e67d73023e7ad9e7e79f66afb43a6f2561d04)
1 //===- PWMAFunction.h - MLIR PWMAFunction Class------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Support for piece-wise multi-affine functions. These are functions that are
10 // defined on a domain that is a union of IntegerPolyhedrons, and on each domain
11 // the value of the function is a tuple of integers, with each value in the
12 // tuple being an affine expression in the vars of the IntegerPolyhedron.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #ifndef MLIR_ANALYSIS_PRESBURGER_PWMAFUNCTION_H
17 #define MLIR_ANALYSIS_PRESBURGER_PWMAFUNCTION_H
18 
19 #include "mlir/Analysis/Presburger/IntegerRelation.h"
20 #include "mlir/Analysis/Presburger/PresburgerRelation.h"
21 #include <optional>
22 
23 namespace mlir {
24 namespace presburger {
25 
/// Kind of binary comparison between two values. The enumerators, in order,
/// are: equal, not equal, less than, less than or equal, greater than, and
/// greater than or equal.
enum class OrderingKind : int {
  EQ = 0, ///< Equal.
  NE = 1, ///< Not equal.
  LT = 2, ///< Less than.
  LE = 3, ///< Less than or equal.
  GT = 4, ///< Greater than.
  GE = 5  ///< Greater than or equal.
};
29 
30 /// This class represents a multi-affine function with the domain as Z^d, where
31 /// `d` is the number of domain variables of the function. For example:
32 ///
33 /// (x, y) -> (x + 2, 2*x - 3y + 5, 2*x + y).
34 ///
35 /// The output expressions are represented as a matrix with one row for every
36 /// output, one column for each var including division variables, and an extra
37 /// column at the end for the constant term.
38 ///
39 /// Checking equality of two such functions is supported, as well as finding the
40 /// value of the function at a specified point.
41 class MultiAffineFunction {
42 public:
MultiAffineFunction(const PresburgerSpace & space,const IntMatrix & output)43   MultiAffineFunction(const PresburgerSpace &space, const IntMatrix &output)
44       : space(space), output(output),
45         divs(space.getNumVars() - space.getNumRangeVars()) {
46     assertIsConsistent();
47   }
48 
MultiAffineFunction(const PresburgerSpace & space,const IntMatrix & output,const DivisionRepr & divs)49   MultiAffineFunction(const PresburgerSpace &space, const IntMatrix &output,
50                       const DivisionRepr &divs)
51       : space(space), output(output), divs(divs) {
52     assertIsConsistent();
53   }
54 
getNumDomainVars()55   unsigned getNumDomainVars() const { return space.getNumDomainVars(); }
getNumSymbolVars()56   unsigned getNumSymbolVars() const { return space.getNumSymbolVars(); }
getNumOutputs()57   unsigned getNumOutputs() const { return space.getNumRangeVars(); }
getNumDivs()58   unsigned getNumDivs() const { return space.getNumLocalVars(); }
59 
60   /// Get the space of this function.
getSpace()61   const PresburgerSpace &getSpace() const { return space; }
62   /// Get the domain/output space of the function. The returned space is a set
63   /// space.
getDomainSpace()64   PresburgerSpace getDomainSpace() const { return space.getDomainSpace(); }
getOutputSpace()65   PresburgerSpace getOutputSpace() const { return space.getRangeSpace(); }
66 
67   /// Get a matrix with each row representing row^th output expression.
getOutputMatrix()68   const IntMatrix &getOutputMatrix() const { return output; }
69   /// Get the `i^th` output expression.
getOutputExpr(unsigned i)70   ArrayRef<DynamicAPInt> getOutputExpr(unsigned i) const {
71     return output.getRow(i);
72   }
73 
74   /// Get the divisions used in this function.
getDivs()75   const DivisionRepr &getDivs() const { return divs; }
76 
77   /// Remove the specified range of outputs.
78   void removeOutputs(unsigned start, unsigned end);
79 
80   /// Given a MAF `other`, merges division variables such that both functions
81   /// have the union of the division vars that exist in the functions.
82   void mergeDivs(MultiAffineFunction &other);
83 
84   //// Return the output of the function at the given point.
85   SmallVector<DynamicAPInt, 8> valueAt(ArrayRef<DynamicAPInt> point) const;
valueAt(ArrayRef<int64_t> point)86   SmallVector<DynamicAPInt, 8> valueAt(ArrayRef<int64_t> point) const {
87     return valueAt(getDynamicAPIntVec(point));
88   }
89 
90   /// Return whether the `this` and `other` are equal when the domain is
91   /// restricted to `domain`. This is the case if they lie in the same space,
92   /// and their outputs are equal for every point in `domain`.
93   bool isEqual(const MultiAffineFunction &other) const;
94   bool isEqual(const MultiAffineFunction &other,
95                const IntegerPolyhedron &domain) const;
96   bool isEqual(const MultiAffineFunction &other,
97                const PresburgerSet &domain) const;
98 
99   void subtract(const MultiAffineFunction &other);
100 
101   /// Return the set of domain points where the output of `this` and `other`
102   /// are ordered lexicographically according to the given ordering.
103   /// For example, if the given comparison is `LT`, then the returned set
104   /// contains all points where the first output of `this` is lexicographically
105   /// less than `other`.
106   PresburgerSet getLexSet(OrderingKind comp,
107                           const MultiAffineFunction &other) const;
108 
109   /// Get this function as a relation.
110   IntegerRelation getAsRelation() const;
111 
112   void print(raw_ostream &os) const;
113   void dump() const;
114 
115 private:
116   /// Assert that the MAF is consistent.
117   void assertIsConsistent() const;
118 
119   /// The space of this function. The domain variables are considered as the
120   /// input variables of the function. The range variables are considered as
121   /// the outputs. The symbols parametrize the function and locals are used to
122   /// represent divisions. Each local variable has a corressponding division
123   /// representation stored in `divs`.
124   PresburgerSpace space;
125 
126   /// The function's output is a tuple of integers, with the ith element of the
127   /// tuple defined by the affine expression given by the ith row of this output
128   /// matrix.
129   IntMatrix output;
130 
131   /// Storage for division representation for each local variable in space.
132   DivisionRepr divs;
133 };
134 
135 /// This class represents a piece-wise MultiAffineFunction. This can be thought
136 /// of as a list of MultiAffineFunction with disjoint domains, with each having
137 /// their own affine expressions for their output tuples. For example, we could
138 /// have a function with two input variables (x, y), defined as
139 ///
140 /// f(x, y) = (2*x + y, y - 4)  if x >= 0, y >= 0
141 ///         = (-2*x + y, y + 4) if x < 0,  y < 0
142 ///         = (4, 1)            if x < 0,  y >= 0
143 ///
144 /// Note that the domains all have to be *disjoint*. Otherwise, the behaviour of
145 /// this class is undefined. The domains need not cover all possible points;
146 /// this represents a partial function and so could be undefined at some points.
147 ///
148 /// As in PresburgerSets, the input vars are partitioned into dimension vars and
149 /// symbolic vars.
150 ///
151 /// Support is provided to compare equality of two such functions as well as
152 /// finding the value of the function at a point.
153 class PWMAFunction {
154 public:
155   struct Piece {
156     PresburgerSet domain;
157     MultiAffineFunction output;
158 
isConsistentPiece159     bool isConsistent() const {
160       return domain.getSpace().isCompatible(output.getDomainSpace());
161     }
162   };
163 
PWMAFunction(const PresburgerSpace & space)164   PWMAFunction(const PresburgerSpace &space) : space(space) {
165     assert(space.getNumLocalVars() == 0 &&
166            "PWMAFunction cannot have local vars.");
167   }
168 
169   // Get the space of this function.
getSpace()170   const PresburgerSpace &getSpace() const { return space; }
171 
172   // Add a piece ([domain, output] pair) to this function.
173   void addPiece(const Piece &piece);
174 
getNumPieces()175   unsigned getNumPieces() const { return pieces.size(); }
getNumVarKind(VarKind kind)176   unsigned getNumVarKind(VarKind kind) const {
177     return space.getNumVarKind(kind);
178   }
getNumDomainVars()179   unsigned getNumDomainVars() const { return space.getNumDomainVars(); }
getNumOutputs()180   unsigned getNumOutputs() const { return space.getNumRangeVars(); }
getNumSymbolVars()181   unsigned getNumSymbolVars() const { return space.getNumSymbolVars(); }
182 
183   /// Remove the specified range of outputs.
184   void removeOutputs(unsigned start, unsigned end);
185 
186   /// Get the domain/output space of the function. The returned space is a set
187   /// space.
getDomainSpace()188   PresburgerSpace getDomainSpace() const { return space.getDomainSpace(); }
getOutputSpace()189   PresburgerSpace getOutputSpace() const { return space.getDomainSpace(); }
190 
191   /// Return the domain of this piece-wise MultiAffineFunction. This is the
192   /// union of the domains of all the pieces.
193   PresburgerSet getDomain() const;
194 
195   /// Return the output of the function at the given point.
196   std::optional<SmallVector<DynamicAPInt, 8>>
197   valueAt(ArrayRef<DynamicAPInt> point) const;
198   std::optional<SmallVector<DynamicAPInt, 8>>
valueAt(ArrayRef<int64_t> point)199   valueAt(ArrayRef<int64_t> point) const {
200     return valueAt(getDynamicAPIntVec(point));
201   }
202 
203   /// Return all the pieces of this piece-wise function.
getAllPieces()204   ArrayRef<Piece> getAllPieces() const { return pieces; }
205 
206   /// Return whether `this` and `other` are equal as PWMAFunctions, i.e. whether
207   /// they have the same dimensions, the same domain and they take the same
208   /// value at every point in the domain.
209   bool isEqual(const PWMAFunction &other) const;
210 
211   /// Return a function defined on the union of the domains of this and func,
212   /// such that when only one of the functions is defined, it outputs the same
213   /// as that function, and if both are defined, it outputs the lexmax/lexmin of
214   /// the two outputs. On points where neither function is defined, the returned
215   /// function is not defined either.
216   ///
217   /// Currently this does not support PWMAFunctions which have pieces containing
218   /// divisions.
219   /// TODO: Support division in pieces.
220   PWMAFunction unionLexMin(const PWMAFunction &func);
221   PWMAFunction unionLexMax(const PWMAFunction &func);
222 
223   void print(raw_ostream &os) const;
224   void dump() const;
225 
226 private:
227   /// Return a function defined on the union of the domains of `this` and
228   /// `func`, such that when only one of the functions is defined, it outputs
229   /// the same as that function, and if neither is defined, the returned
230   /// function is not defined either.
231   ///
232   /// The provided `tiebreak` function determines which of the two functions'
233   /// output should be used on inputs where both the functions are defined. More
234   /// precisely, given two `MultiAffineFunction`s `mafA` and `mafB`, `tiebreak`
235   /// returns the subset of the intersection of the two functions' domains where
236   /// the output of `mafA` should be used.
237   ///
238   /// The PresburgerSet returned by `tiebreak` should be disjoint.
239   /// TODO: Remove this constraint of returning disjoint set.
240   PWMAFunction unionFunction(
241       const PWMAFunction &func,
242       llvm::function_ref<PresburgerSet(Piece mafA, Piece mafB)> tiebreak) const;
243 
244   /// The space of this function. The domain variables are considered as the
245   /// input variables of the function. The range variables are considered as
246   /// the outputs. The symbols paramterize the function.
247   PresburgerSpace space;
248 
249   // The pieces of the PWMAFunction.
250   SmallVector<Piece, 4> pieces;
251 };
252 
253 } // namespace presburger
254 } // namespace mlir
255 
256 #endif // MLIR_ANALYSIS_PRESBURGER_PWMAFUNCTION_H
257