xref: /llvm-project/mlir/lib/Analysis/Presburger/PWMAFunction.cpp (revision 832ccfe55275b1561b2548bfac075447037d6663)
1 //===- PWMAFunction.cpp - MLIR PWMAFunction Class -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Analysis/Presburger/PWMAFunction.h"
#include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "mlir/Analysis/Presburger/PresburgerRelation.h"
#include "mlir/Analysis/Presburger/PresburgerSpace.h"
#include "mlir/Analysis/Presburger/Utils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cassert>
#include <optional>
21 
22 using namespace mlir;
23 using namespace presburger;
24 
25 void MultiAffineFunction::assertIsConsistent() const {
26   assert(space.getNumVars() - space.getNumRangeVars() + 1 ==
27              output.getNumColumns() &&
28          "Inconsistent number of output columns");
29   assert(space.getNumDomainVars() + space.getNumSymbolVars() ==
30              divs.getNumNonDivs() &&
31          "Inconsistent number of non-division variables in divs");
32   assert(space.getNumRangeVars() == output.getNumRows() &&
33          "Inconsistent number of output rows");
34   assert(space.getNumLocalVars() == divs.getNumDivs() &&
35          "Inconsistent number of divisions.");
36   assert(divs.hasAllReprs() && "All divisions should have a representation");
37 }
38 
39 // Return the result of subtracting the two given vectors pointwise.
40 // The vectors must be of the same size.
41 // e.g., [3, 4, 6] - [2, 5, 1] = [1, -1, 5].
42 static SmallVector<DynamicAPInt, 8> subtractExprs(ArrayRef<DynamicAPInt> vecA,
43                                                   ArrayRef<DynamicAPInt> vecB) {
44   assert(vecA.size() == vecB.size() &&
45          "Cannot subtract vectors of differing lengths!");
46   SmallVector<DynamicAPInt, 8> result;
47   result.reserve(vecA.size());
48   for (unsigned i = 0, e = vecA.size(); i < e; ++i)
49     result.emplace_back(vecA[i] - vecB[i]);
50   return result;
51 }
52 
53 PresburgerSet PWMAFunction::getDomain() const {
54   PresburgerSet domain = PresburgerSet::getEmpty(getDomainSpace());
55   for (const Piece &piece : pieces)
56     domain.unionInPlace(piece.domain);
57   return domain;
58 }
59 
60 void MultiAffineFunction::print(raw_ostream &os) const {
61   space.print(os);
62   os << "Division Representation:\n";
63   divs.print(os);
64   os << "Output:\n";
65   output.print(os);
66 }
67 
68 void MultiAffineFunction::dump() const { print(llvm::errs()); }
69 
70 SmallVector<DynamicAPInt, 8>
71 MultiAffineFunction::valueAt(ArrayRef<DynamicAPInt> point) const {
72   assert(point.size() == getNumDomainVars() + getNumSymbolVars() &&
73          "Point has incorrect dimensionality!");
74 
75   SmallVector<DynamicAPInt, 8> pointHomogenous{llvm::to_vector(point)};
76   // Get the division values at this point.
77   SmallVector<std::optional<DynamicAPInt>, 8> divValues =
78       divs.divValuesAt(point);
79   // The given point didn't include the values of the divs which the output is a
80   // function of; we have computed one possible set of values and use them here.
81   pointHomogenous.reserve(pointHomogenous.size() + divValues.size());
82   for (const std::optional<DynamicAPInt> &divVal : divValues)
83     pointHomogenous.emplace_back(*divVal);
84   // The matrix `output` has an affine expression in the ith row, corresponding
85   // to the expression for the ith value in the output vector. The last column
86   // of the matrix contains the constant term. Let v be the input point with
87   // a 1 appended at the end. We can see that output * v gives the desired
88   // output vector.
89   pointHomogenous.emplace_back(1);
90   SmallVector<DynamicAPInt, 8> result =
91       output.postMultiplyWithColumn(pointHomogenous);
92   assert(result.size() == getNumOutputs());
93   return result;
94 }
95 
96 bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const {
97   assert(space.isCompatible(other.space) &&
98          "Spaces should be compatible for equality check.");
99   return getAsRelation().isEqual(other.getAsRelation());
100 }
101 
102 bool MultiAffineFunction::isEqual(const MultiAffineFunction &other,
103                                   const IntegerPolyhedron &domain) const {
104   assert(space.isCompatible(other.space) &&
105          "Spaces should be compatible for equality check.");
106   IntegerRelation restrictedThis = getAsRelation();
107   restrictedThis.intersectDomain(domain);
108 
109   IntegerRelation restrictedOther = other.getAsRelation();
110   restrictedOther.intersectDomain(domain);
111 
112   return restrictedThis.isEqual(restrictedOther);
113 }
114 
115 bool MultiAffineFunction::isEqual(const MultiAffineFunction &other,
116                                   const PresburgerSet &domain) const {
117   assert(space.isCompatible(other.space) &&
118          "Spaces should be compatible for equality check.");
119   return llvm::all_of(domain.getAllDisjuncts(),
120                       [&](const IntegerRelation &disjunct) {
121                         return isEqual(other, IntegerPolyhedron(disjunct));
122                       });
123 }
124 
125 void MultiAffineFunction::removeOutputs(unsigned start, unsigned end) {
126   assert(end <= getNumOutputs() && "Invalid range");
127 
128   if (start >= end)
129     return;
130 
131   space.removeVarRange(VarKind::Range, start, end);
132   output.removeRows(start, end - start);
133 }
134 
135 void MultiAffineFunction::mergeDivs(MultiAffineFunction &other) {
136   assert(space.isCompatible(other.space) && "Functions should be compatible");
137 
138   unsigned nDivs = getNumDivs();
139   unsigned divOffset = divs.getDivOffset();
140 
141   other.divs.insertDiv(0, nDivs);
142 
143   SmallVector<DynamicAPInt, 8> div(other.divs.getNumVars() + 1);
144   for (unsigned i = 0; i < nDivs; ++i) {
145     // Zero fill.
146     std::fill(div.begin(), div.end(), 0);
147     // Fill div with dividend from `divs`. Do not fill the constant.
148     std::copy(divs.getDividend(i).begin(), divs.getDividend(i).end() - 1,
149               div.begin());
150     // Fill constant.
151     div.back() = divs.getDividend(i).back();
152     other.divs.setDiv(i, div, divs.getDenom(i));
153   }
154 
155   other.space.insertVar(VarKind::Local, 0, nDivs);
156   other.output.insertColumns(divOffset, nDivs);
157 
158   auto merge = [&](unsigned i, unsigned j) {
159     // We only merge from local at pos j to local at pos i, where j > i.
160     if (i >= j)
161       return false;
162 
163     // If i < nDivs, we are trying to merge duplicate divs in `this`. Since we
164     // do not want to merge duplicates in `this`, we ignore this call.
165     if (j < nDivs)
166       return false;
167 
168     // Merge things in space and output.
169     other.space.removeVarRange(VarKind::Local, j, j + 1);
170     other.output.addToColumn(divOffset + i, divOffset + j, 1);
171     other.output.removeColumn(divOffset + j);
172     return true;
173   };
174 
175   other.divs.removeDuplicateDivs(merge);
176 
177   unsigned newDivs = other.divs.getNumDivs() - nDivs;
178 
179   space.insertVar(VarKind::Local, nDivs, newDivs);
180   output.insertColumns(divOffset + nDivs, newDivs);
181   divs = other.divs;
182 
183   // Check consistency.
184   assertIsConsistent();
185   other.assertIsConsistent();
186 }
187 
188 PresburgerSet
189 MultiAffineFunction::getLexSet(OrderingKind comp,
190                                const MultiAffineFunction &other) const {
191   assert(getSpace().isCompatible(other.getSpace()) &&
192          "Output space of funcs should be compatible");
193 
194   // Create copies of functions and merge their local space.
195   MultiAffineFunction funcA = *this;
196   MultiAffineFunction funcB = other;
197   funcA.mergeDivs(funcB);
198 
199   // We first create the set `result`, corresponding to the set where output
200   // of funcA is lexicographically larger/smaller than funcB. This is done by
201   // creating a PresburgerSet with the following constraints:
202   //
203   //    (outA[0] > outB[0]) U
204   //    (outA[0] = outB[0], outA[1] > outA[1]) U
205   //    (outA[0] = outB[0], outA[1] = outA[1], outA[2] > outA[2]) U
206   //    ...
207   //    (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] > outB[n-1])
208   //
209   // where `n` is the number of outputs.
210   // If `lexMin` is set, the complement inequality is used:
211   //
212   //    (outA[0] < outB[0]) U
213   //    (outA[0] = outB[0], outA[1] < outA[1]) U
214   //    (outA[0] = outB[0], outA[1] = outA[1], outA[2] < outA[2]) U
215   //    ...
216   //    (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] < outB[n-1])
217   PresburgerSpace resultSpace = funcA.getDomainSpace();
218   PresburgerSet result =
219       PresburgerSet::getEmpty(resultSpace.getSpaceWithoutLocals());
220   IntegerPolyhedron levelSet(
221       /*numReservedInequalities=*/1 + 2 * resultSpace.getNumLocalVars(),
222       /*numReservedEqualities=*/funcA.getNumOutputs(),
223       /*numReservedCols=*/resultSpace.getNumVars() + 1, resultSpace);
224 
225   // Add division inequalities to `levelSet`.
226   for (unsigned i = 0, e = funcA.getNumDivs(); i < e; ++i) {
227     levelSet.addInequality(getDivUpperBound(funcA.divs.getDividend(i),
228                                             funcA.divs.getDenom(i),
229                                             funcA.divs.getDivOffset() + i));
230     levelSet.addInequality(getDivLowerBound(funcA.divs.getDividend(i),
231                                             funcA.divs.getDenom(i),
232                                             funcA.divs.getDivOffset() + i));
233   }
234 
235   for (unsigned level = 0; level < funcA.getNumOutputs(); ++level) {
236     // Create the expression `outA - outB` for this level.
237     SmallVector<DynamicAPInt, 8> subExpr =
238         subtractExprs(funcA.getOutputExpr(level), funcB.getOutputExpr(level));
239 
240     // TODO: Implement all comparison cases.
241     switch (comp) {
242     case OrderingKind::LT:
243       // For less than, we add an upper bound of -1:
244       //        outA - outB <= -1
245       //        outA <= outB - 1
246       //        outA < outB
247       levelSet.addBound(BoundType::UB, subExpr, DynamicAPInt(-1));
248       break;
249     case OrderingKind::GT:
250       // For greater than, we add a lower bound of 1:
251       //        outA - outB >= 1
252       //        outA > outB + 1
253       //        outA > outB
254       levelSet.addBound(BoundType::LB, subExpr, DynamicAPInt(1));
255       break;
256     case OrderingKind::GE:
257     case OrderingKind::LE:
258     case OrderingKind::EQ:
259     case OrderingKind::NE:
260       assert(false && "Not implemented case");
261     }
262 
263     // Union the set with the result.
264     result.unionInPlace(levelSet);
265     // The last inequality in `levelSet` is the bound we inserted. We remove
266     // that for next iteration.
267     levelSet.removeInequality(levelSet.getNumInequalities() - 1);
268     // Add equality `outA - outB == 0` for this level for next iteration.
269     levelSet.addEquality(subExpr);
270   }
271 
272   return result;
273 }
274 
275 /// Two PWMAFunctions are equal if they have the same dimensionalities,
276 /// the same domain, and take the same value at every point in the domain.
277 bool PWMAFunction::isEqual(const PWMAFunction &other) const {
278   if (!space.isCompatible(other.space))
279     return false;
280 
281   if (!this->getDomain().isEqual(other.getDomain()))
282     return false;
283 
284   // Check if, whenever the domains of a piece of `this` and a piece of `other`
285   // overlap, they take the same output value. If `this` and `other` have the
286   // same domain (checked above), then this check passes iff the two functions
287   // have the same output at every point in the domain.
288   return llvm::all_of(this->pieces, [&other](const Piece &pieceA) {
289     return llvm::all_of(other.pieces, [&pieceA](const Piece &pieceB) {
290       PresburgerSet commonDomain = pieceA.domain.intersect(pieceB.domain);
291       return pieceA.output.isEqual(pieceB.output, commonDomain);
292     });
293   });
294 }
295 
296 void PWMAFunction::addPiece(const Piece &piece) {
297   assert(piece.isConsistent() && "Piece should be consistent");
298   assert(piece.domain.intersect(getDomain()).isIntegerEmpty() &&
299          "Piece should be disjoint from the function");
300   pieces.emplace_back(piece);
301 }
302 
303 void PWMAFunction::print(raw_ostream &os) const {
304   space.print(os);
305   os << getNumPieces() << " pieces:\n";
306   for (const Piece &piece : pieces) {
307     os << "Domain of piece:\n";
308     piece.domain.print(os);
309     os << "Output of piece\n";
310     piece.output.print(os);
311   }
312 }
313 
314 void PWMAFunction::dump() const { print(llvm::errs()); }
315 
316 PWMAFunction PWMAFunction::unionFunction(
317     const PWMAFunction &func,
318     llvm::function_ref<PresburgerSet(Piece maf1, Piece maf2)> tiebreak) const {
319   assert(getNumOutputs() == func.getNumOutputs() &&
320          "Ranges of functions should be same.");
321   assert(getSpace().isCompatible(func.getSpace()) &&
322          "Space is not compatible.");
323 
324   // The algorithm used here is as follows:
325   // - Add the output of pieceB for the part of the domain where both pieceA and
326   //   pieceB are defined, and `tiebreak` chooses the output of pieceB.
327   // - Add the output of pieceA, where pieceB is not defined or `tiebreak`
328   // chooses
329   //   pieceA over pieceB.
330   // - Add the output of pieceB, where pieceA is not defined.
331 
332   // Add parts of the common domain where pieceB's output is used. Also
333   // add all the parts where pieceA's output is used, both common and
334   // non-common.
335   PWMAFunction result(getSpace());
336   for (const Piece &pieceA : pieces) {
337     PresburgerSet dom(pieceA.domain);
338     for (const Piece &pieceB : func.pieces) {
339       PresburgerSet better = tiebreak(pieceB, pieceA);
340       // Add the output of pieceB, where it is better than output of pieceA.
341       // The disjuncts in "better" will be disjoint as tiebreak should gurantee
342       // that.
343       result.addPiece({better, pieceB.output});
344       dom = dom.subtract(better);
345     }
346     // Add output of pieceA, where it is better than pieceB, or pieceB is not
347     // defined.
348     //
349     // `dom` here is guranteed to be disjoint from already added pieces
350     // because the pieces added before are either:
351     // - Subsets of the domain of other MAFs in `this`, which are guranteed
352     //   to be disjoint from `dom`, or
353     // - They are one of the pieces added for `pieceB`, and we have been
354     //   subtracting all such pieces from `dom`, so `dom` is disjoint from those
355     //   pieces as well.
356     result.addPiece({dom, pieceA.output});
357   }
358 
359   // Add parts of pieceB which are not shared with pieceA.
360   PresburgerSet dom = getDomain();
361   for (const Piece &pieceB : func.pieces)
362     result.addPiece({pieceB.domain.subtract(dom), pieceB.output});
363 
364   return result;
365 }
366 
367 /// A tiebreak function which breaks ties by comparing the outputs
368 /// lexicographically based on the given comparison operator.
369 /// This is templated since it is passed as a lambda.
370 template <OrderingKind comp>
371 static PresburgerSet tiebreakLex(const PWMAFunction::Piece &pieceA,
372                                  const PWMAFunction::Piece &pieceB) {
373   PresburgerSet result = pieceA.output.getLexSet(comp, pieceB.output);
374   result = result.intersect(pieceA.domain).intersect(pieceB.domain);
375 
376   return result;
377 }
378 
379 PWMAFunction PWMAFunction::unionLexMin(const PWMAFunction &func) {
380   return unionFunction(func, tiebreakLex</*comp=*/OrderingKind::LT>);
381 }
382 
383 PWMAFunction PWMAFunction::unionLexMax(const PWMAFunction &func) {
384   return unionFunction(func, tiebreakLex</*comp=*/OrderingKind::GT>);
385 }
386 
387 void MultiAffineFunction::subtract(const MultiAffineFunction &other) {
388   assert(space.isCompatible(other.space) &&
389          "Spaces should be compatible for subtraction.");
390 
391   MultiAffineFunction copyOther = other;
392   mergeDivs(copyOther);
393   for (unsigned i = 0, e = getNumOutputs(); i < e; ++i)
394     output.addToRow(i, copyOther.getOutputExpr(i), DynamicAPInt(-1));
395 
396   // Check consistency.
397   assertIsConsistent();
398 }
399 
400 /// Adds division constraints corresponding to local variables, given a
401 /// relation and division representations of the local variables in the
402 /// relation.
403 static void addDivisionConstraints(IntegerRelation &rel,
404                                    const DivisionRepr &divs) {
405   assert(divs.hasAllReprs() &&
406          "All divisions in divs should have a representation");
407   assert(rel.getNumVars() == divs.getNumVars() &&
408          "Relation and divs should have the same number of vars");
409   assert(rel.getNumLocalVars() == divs.getNumDivs() &&
410          "Relation and divs should have the same number of local vars");
411 
412   for (unsigned i = 0, e = divs.getNumDivs(); i < e; ++i) {
413     rel.addInequality(getDivUpperBound(divs.getDividend(i), divs.getDenom(i),
414                                        divs.getDivOffset() + i));
415     rel.addInequality(getDivLowerBound(divs.getDividend(i), divs.getDenom(i),
416                                        divs.getDivOffset() + i));
417   }
418 }
419 
420 IntegerRelation MultiAffineFunction::getAsRelation() const {
421   // Create a relation corressponding to the input space plus the divisions
422   // used in outputs.
423   IntegerRelation result(PresburgerSpace::getRelationSpace(
424       space.getNumDomainVars(), 0, space.getNumSymbolVars(),
425       space.getNumLocalVars()));
426   // Add division constraints corresponding to divisions used in outputs.
427   addDivisionConstraints(result, divs);
428   // The outputs are represented as range variables in the relation. We add
429   // range variables for the outputs.
430   result.insertVar(VarKind::Range, 0, getNumOutputs());
431 
432   // Add equalities such that the i^th range variable is equal to the i^th
433   // output expression.
434   SmallVector<DynamicAPInt, 8> eq(result.getNumCols());
435   for (unsigned i = 0, e = getNumOutputs(); i < e; ++i) {
436     // TODO: Add functions to get VarKind offsets in output in MAF and use them
437     // here.
438     // The output expression does not contain range variables, while the
439     // equality does. So, we need to copy all variables and mark all range
440     // variables as 0 in the equality.
441     ArrayRef<DynamicAPInt> expr = getOutputExpr(i);
442     // Copy domain variables in `expr` to domain variables in `eq`.
443     std::copy(expr.begin(), expr.begin() + getNumDomainVars(), eq.begin());
444     // Fill the range variables in `eq` as zero.
445     std::fill(eq.begin() + result.getVarKindOffset(VarKind::Range),
446               eq.begin() + result.getVarKindEnd(VarKind::Range), 0);
447     // Copy remaining variables in `expr` to the remaining variables in `eq`.
448     std::copy(expr.begin() + getNumDomainVars(), expr.end(),
449               eq.begin() + result.getVarKindEnd(VarKind::Range));
450 
451     // Set the i^th range var to -1 in `eq` to equate the output expression to
452     // this range var.
453     eq[result.getVarKindOffset(VarKind::Range) + i] = -1;
454     // Add the equality `rangeVar_i = output[i]`.
455     result.addEquality(eq);
456   }
457 
458   return result;
459 }
460 
461 void PWMAFunction::removeOutputs(unsigned start, unsigned end) {
462   space.removeVarRange(VarKind::Range, start, end);
463   for (Piece &piece : pieces)
464     piece.output.removeOutputs(start, end);
465 }
466 
467 std::optional<SmallVector<DynamicAPInt, 8>>
468 PWMAFunction::valueAt(ArrayRef<DynamicAPInt> point) const {
469   assert(point.size() == getNumDomainVars() + getNumSymbolVars());
470 
471   for (const Piece &piece : pieces)
472     if (piece.domain.containsPoint(point))
473       return piece.output.valueAt(point);
474   return std::nullopt;
475 }
476