xref: /llvm-project/mlir/lib/Analysis/Presburger/PWMAFunction.cpp (revision 832ccfe55275b1561b2548bfac075447037d6663)
1d5a29442SArjun P //===- PWMAFunction.cpp - MLIR PWMAFunction Class -------------------------===//
2d5a29442SArjun P //
3d5a29442SArjun P // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4d5a29442SArjun P // See https://llvm.org/LICENSE.txt for license information.
5d5a29442SArjun P // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6d5a29442SArjun P //
7d5a29442SArjun P //===----------------------------------------------------------------------===//
8d5a29442SArjun P 
9d5a29442SArjun P #include "mlir/Analysis/Presburger/PWMAFunction.h"
102ee87cd6SMehdi Amini #include "mlir/Analysis/Presburger/IntegerRelation.h"
112ee87cd6SMehdi Amini #include "mlir/Analysis/Presburger/PresburgerRelation.h"
122ee87cd6SMehdi Amini #include "mlir/Analysis/Presburger/PresburgerSpace.h"
132ee87cd6SMehdi Amini #include "mlir/Analysis/Presburger/Utils.h"
142ee87cd6SMehdi Amini #include "llvm/ADT/STLExtras.h"
152ee87cd6SMehdi Amini #include "llvm/ADT/STLFunctionalExtras.h"
162ee87cd6SMehdi Amini #include "llvm/ADT/SmallVector.h"
172ee87cd6SMehdi Amini #include "llvm/Support/raw_ostream.h"
182ee87cd6SMehdi Amini #include <algorithm>
192ee87cd6SMehdi Amini #include <cassert>
20a1fe1f5fSKazu Hirata #include <optional>
21d5a29442SArjun P 
22d5a29442SArjun P using namespace mlir;
230c1f6865SGroverkss using namespace presburger;
24d5a29442SArjun P 
25bb2226acSGroverkss void MultiAffineFunction::assertIsConsistent() const {
26bb2226acSGroverkss   assert(space.getNumVars() - space.getNumRangeVars() + 1 ==
27bb2226acSGroverkss              output.getNumColumns() &&
28bb2226acSGroverkss          "Inconsistent number of output columns");
29bb2226acSGroverkss   assert(space.getNumDomainVars() + space.getNumSymbolVars() ==
30bb2226acSGroverkss              divs.getNumNonDivs() &&
31bb2226acSGroverkss          "Inconsistent number of non-division variables in divs");
32bb2226acSGroverkss   assert(space.getNumRangeVars() == output.getNumRows() &&
33bb2226acSGroverkss          "Inconsistent number of output rows");
34bb2226acSGroverkss   assert(space.getNumLocalVars() == divs.getNumDivs() &&
35bb2226acSGroverkss          "Inconsistent number of divisions.");
36bb2226acSGroverkss   assert(divs.hasAllReprs() && "All divisions should have a representation");
37bb2226acSGroverkss }
38bb2226acSGroverkss 
39d5a29442SArjun P // Return the result of subtracting the two given vectors pointwise.
40d5a29442SArjun P // The vectors must be of the same size.
41d5a29442SArjun P // e.g., [3, 4, 6] - [2, 5, 1] = [1, -1, 5].
421a0e67d7SRamkumar Ramachandra static SmallVector<DynamicAPInt, 8> subtractExprs(ArrayRef<DynamicAPInt> vecA,
431a0e67d7SRamkumar Ramachandra                                                   ArrayRef<DynamicAPInt> vecB) {
44d5a29442SArjun P   assert(vecA.size() == vecB.size() &&
45d5a29442SArjun P          "Cannot subtract vectors of differing lengths!");
461a0e67d7SRamkumar Ramachandra   SmallVector<DynamicAPInt, 8> result;
47d5a29442SArjun P   result.reserve(vecA.size());
48d5a29442SArjun P   for (unsigned i = 0, e = vecA.size(); i < e; ++i)
49266a5a9cSRamkumar Ramachandra     result.emplace_back(vecA[i] - vecB[i]);
50d5a29442SArjun P   return result;
51d5a29442SArjun P }
52d5a29442SArjun P 
53d5a29442SArjun P PresburgerSet PWMAFunction::getDomain() const {
54bb2226acSGroverkss   PresburgerSet domain = PresburgerSet::getEmpty(getDomainSpace());
55bb2226acSGroverkss   for (const Piece &piece : pieces)
56bb2226acSGroverkss     domain.unionInPlace(piece.domain);
57d5a29442SArjun P   return domain;
58d5a29442SArjun P }
59d5a29442SArjun P 
60bb2226acSGroverkss void MultiAffineFunction::print(raw_ostream &os) const {
61bb2226acSGroverkss   space.print(os);
62bb2226acSGroverkss   os << "Division Representation:\n";
63bb2226acSGroverkss   divs.print(os);
64bb2226acSGroverkss   os << "Output:\n";
65bb2226acSGroverkss   output.print(os);
66bb2226acSGroverkss }
67bb2226acSGroverkss 
68*832ccfe5SChristopher Bate void MultiAffineFunction::dump() const { print(llvm::errs()); }
69*832ccfe5SChristopher Bate 
701a0e67d7SRamkumar Ramachandra SmallVector<DynamicAPInt, 8>
711a0e67d7SRamkumar Ramachandra MultiAffineFunction::valueAt(ArrayRef<DynamicAPInt> point) const {
72bb2226acSGroverkss   assert(point.size() == getNumDomainVars() + getNumSymbolVars() &&
734418669fSArjun P          "Point has incorrect dimensionality!");
74d5a29442SArjun P 
751a0e67d7SRamkumar Ramachandra   SmallVector<DynamicAPInt, 8> pointHomogenous{llvm::to_vector(point)};
76bb2226acSGroverkss   // Get the division values at this point.
771a0e67d7SRamkumar Ramachandra   SmallVector<std::optional<DynamicAPInt>, 8> divValues =
781a0e67d7SRamkumar Ramachandra       divs.divValuesAt(point);
79bb2226acSGroverkss   // The given point didn't include the values of the divs which the output is a
80bb2226acSGroverkss   // function of; we have computed one possible set of values and use them here.
81bb2226acSGroverkss   pointHomogenous.reserve(pointHomogenous.size() + divValues.size());
821a0e67d7SRamkumar Ramachandra   for (const std::optional<DynamicAPInt> &divVal : divValues)
83266a5a9cSRamkumar Ramachandra     pointHomogenous.emplace_back(*divVal);
84d5a29442SArjun P   // The matrix `output` has an affine expression in the ith row, corresponding
85d5a29442SArjun P   // to the expression for the ith value in the output vector. The last column
86d5a29442SArjun P   // of the matrix contains the constant term. Let v be the input point with
87d5a29442SArjun P   // a 1 appended at the end. We can see that output * v gives the desired
88d5a29442SArjun P   // output vector.
896d6f6c4dSArjun P   pointHomogenous.emplace_back(1);
901a0e67d7SRamkumar Ramachandra   SmallVector<DynamicAPInt, 8> result =
911a0e67d7SRamkumar Ramachandra       output.postMultiplyWithColumn(pointHomogenous);
92d5a29442SArjun P   assert(result.size() == getNumOutputs());
93d5a29442SArjun P   return result;
94d5a29442SArjun P }
95d5a29442SArjun P 
96d5a29442SArjun P bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const {
97bb2226acSGroverkss   assert(space.isCompatible(other.space) &&
98bb2226acSGroverkss          "Spaces should be compatible for equality check.");
99bb2226acSGroverkss   return getAsRelation().isEqual(other.getAsRelation());
100d5a29442SArjun P }
101d5a29442SArjun P 
102bb2226acSGroverkss bool MultiAffineFunction::isEqual(const MultiAffineFunction &other,
103bb2226acSGroverkss                                   const IntegerPolyhedron &domain) const {
104bb2226acSGroverkss   assert(space.isCompatible(other.space) &&
105bb2226acSGroverkss          "Spaces should be compatible for equality check.");
106bb2226acSGroverkss   IntegerRelation restrictedThis = getAsRelation();
107bb2226acSGroverkss   restrictedThis.intersectDomain(domain);
108bb2226acSGroverkss 
109bb2226acSGroverkss   IntegerRelation restrictedOther = other.getAsRelation();
110bb2226acSGroverkss   restrictedOther.intersectDomain(domain);
111bb2226acSGroverkss 
112bb2226acSGroverkss   return restrictedThis.isEqual(restrictedOther);
113d5a29442SArjun P }
114d5a29442SArjun P 
115bb2226acSGroverkss bool MultiAffineFunction::isEqual(const MultiAffineFunction &other,
116bb2226acSGroverkss                                   const PresburgerSet &domain) const {
117bb2226acSGroverkss   assert(space.isCompatible(other.space) &&
118bb2226acSGroverkss          "Spaces should be compatible for equality check.");
119bb2226acSGroverkss   return llvm::all_of(domain.getAllDisjuncts(),
120bb2226acSGroverkss                       [&](const IntegerRelation &disjunct) {
121bb2226acSGroverkss                         return isEqual(other, IntegerPolyhedron(disjunct));
122bb2226acSGroverkss                       });
123d5a29442SArjun P }
124d5a29442SArjun P 
125bb2226acSGroverkss void MultiAffineFunction::removeOutputs(unsigned start, unsigned end) {
126bb2226acSGroverkss   assert(end <= getNumOutputs() && "Invalid range");
127bb2226acSGroverkss 
128bb2226acSGroverkss   if (start >= end)
129bb2226acSGroverkss     return;
130bb2226acSGroverkss 
131bb2226acSGroverkss   space.removeVarRange(VarKind::Range, start, end);
132bb2226acSGroverkss   output.removeRows(start, end - start);
13379ad5fb2SArjun P }
13479ad5fb2SArjun P 
135bb2226acSGroverkss void MultiAffineFunction::mergeDivs(MultiAffineFunction &other) {
136bb2226acSGroverkss   assert(space.isCompatible(other.space) && "Functions should be compatible");
137bb2226acSGroverkss 
138bb2226acSGroverkss   unsigned nDivs = getNumDivs();
139bb2226acSGroverkss   unsigned divOffset = divs.getDivOffset();
140bb2226acSGroverkss 
141bb2226acSGroverkss   other.divs.insertDiv(0, nDivs);
142bb2226acSGroverkss 
1431a0e67d7SRamkumar Ramachandra   SmallVector<DynamicAPInt, 8> div(other.divs.getNumVars() + 1);
144bb2226acSGroverkss   for (unsigned i = 0; i < nDivs; ++i) {
145bb2226acSGroverkss     // Zero fill.
146bb2226acSGroverkss     std::fill(div.begin(), div.end(), 0);
147bb2226acSGroverkss     // Fill div with dividend from `divs`. Do not fill the constant.
148bb2226acSGroverkss     std::copy(divs.getDividend(i).begin(), divs.getDividend(i).end() - 1,
149bb2226acSGroverkss               div.begin());
150bb2226acSGroverkss     // Fill constant.
151bb2226acSGroverkss     div.back() = divs.getDividend(i).back();
152bb2226acSGroverkss     other.divs.setDiv(i, div, divs.getDenom(i));
15379ad5fb2SArjun P   }
15479ad5fb2SArjun P 
155bb2226acSGroverkss   other.space.insertVar(VarKind::Local, 0, nDivs);
156bb2226acSGroverkss   other.output.insertColumns(divOffset, nDivs);
15715650b32SGroverkss 
158bb2226acSGroverkss   auto merge = [&](unsigned i, unsigned j) {
159bb2226acSGroverkss     // We only merge from local at pos j to local at pos i, where j > i.
160bb2226acSGroverkss     if (i >= j)
161bb2226acSGroverkss       return false;
16215650b32SGroverkss 
163bb2226acSGroverkss     // If i < nDivs, we are trying to merge duplicate divs in `this`. Since we
164bb2226acSGroverkss     // do not want to merge duplicates in `this`, we ignore this call.
165bb2226acSGroverkss     if (j < nDivs)
166bb2226acSGroverkss       return false;
16715650b32SGroverkss 
168bb2226acSGroverkss     // Merge things in space and output.
169bb2226acSGroverkss     other.space.removeVarRange(VarKind::Local, j, j + 1);
170bb2226acSGroverkss     other.output.addToColumn(divOffset + i, divOffset + j, 1);
171bb2226acSGroverkss     other.output.removeColumn(divOffset + j);
17215650b32SGroverkss     return true;
17315650b32SGroverkss   };
17415650b32SGroverkss 
175bb2226acSGroverkss   other.divs.removeDuplicateDivs(merge);
17615650b32SGroverkss 
177bb2226acSGroverkss   unsigned newDivs = other.divs.getNumDivs() - nDivs;
178d5a29442SArjun P 
179bb2226acSGroverkss   space.insertVar(VarKind::Local, nDivs, newDivs);
180bb2226acSGroverkss   output.insertColumns(divOffset + nDivs, newDivs);
181bb2226acSGroverkss   divs = other.divs;
182d5a29442SArjun P 
183bb2226acSGroverkss   // Check consistency.
184bb2226acSGroverkss   assertIsConsistent();
185bb2226acSGroverkss   other.assertIsConsistent();
186d5a29442SArjun P }
187d5a29442SArjun P 
18894750af8SGroverkss PresburgerSet
18994750af8SGroverkss MultiAffineFunction::getLexSet(OrderingKind comp,
19094750af8SGroverkss                                const MultiAffineFunction &other) const {
19194750af8SGroverkss   assert(getSpace().isCompatible(other.getSpace()) &&
19294750af8SGroverkss          "Output space of funcs should be compatible");
19394750af8SGroverkss 
19494750af8SGroverkss   // Create copies of functions and merge their local space.
19594750af8SGroverkss   MultiAffineFunction funcA = *this;
19694750af8SGroverkss   MultiAffineFunction funcB = other;
19794750af8SGroverkss   funcA.mergeDivs(funcB);
19894750af8SGroverkss 
19994750af8SGroverkss   // We first create the set `result`, corresponding to the set where output
20094750af8SGroverkss   // of funcA is lexicographically larger/smaller than funcB. This is done by
20194750af8SGroverkss   // creating a PresburgerSet with the following constraints:
20294750af8SGroverkss   //
20394750af8SGroverkss   //    (outA[0] > outB[0]) U
20494750af8SGroverkss   //    (outA[0] = outB[0], outA[1] > outA[1]) U
20594750af8SGroverkss   //    (outA[0] = outB[0], outA[1] = outA[1], outA[2] > outA[2]) U
20694750af8SGroverkss   //    ...
20794750af8SGroverkss   //    (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] > outB[n-1])
20894750af8SGroverkss   //
20994750af8SGroverkss   // where `n` is the number of outputs.
21094750af8SGroverkss   // If `lexMin` is set, the complement inequality is used:
21194750af8SGroverkss   //
21294750af8SGroverkss   //    (outA[0] < outB[0]) U
21394750af8SGroverkss   //    (outA[0] = outB[0], outA[1] < outA[1]) U
21494750af8SGroverkss   //    (outA[0] = outB[0], outA[1] = outA[1], outA[2] < outA[2]) U
21594750af8SGroverkss   //    ...
21694750af8SGroverkss   //    (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] < outB[n-1])
21794750af8SGroverkss   PresburgerSpace resultSpace = funcA.getDomainSpace();
21894750af8SGroverkss   PresburgerSet result =
21994750af8SGroverkss       PresburgerSet::getEmpty(resultSpace.getSpaceWithoutLocals());
22094750af8SGroverkss   IntegerPolyhedron levelSet(
22194750af8SGroverkss       /*numReservedInequalities=*/1 + 2 * resultSpace.getNumLocalVars(),
22294750af8SGroverkss       /*numReservedEqualities=*/funcA.getNumOutputs(),
22394750af8SGroverkss       /*numReservedCols=*/resultSpace.getNumVars() + 1, resultSpace);
22494750af8SGroverkss 
22594750af8SGroverkss   // Add division inequalities to `levelSet`.
22694750af8SGroverkss   for (unsigned i = 0, e = funcA.getNumDivs(); i < e; ++i) {
22794750af8SGroverkss     levelSet.addInequality(getDivUpperBound(funcA.divs.getDividend(i),
22894750af8SGroverkss                                             funcA.divs.getDenom(i),
22994750af8SGroverkss                                             funcA.divs.getDivOffset() + i));
23094750af8SGroverkss     levelSet.addInequality(getDivLowerBound(funcA.divs.getDividend(i),
23194750af8SGroverkss                                             funcA.divs.getDenom(i),
23294750af8SGroverkss                                             funcA.divs.getDivOffset() + i));
23394750af8SGroverkss   }
23494750af8SGroverkss 
23594750af8SGroverkss   for (unsigned level = 0; level < funcA.getNumOutputs(); ++level) {
23694750af8SGroverkss     // Create the expression `outA - outB` for this level.
2371a0e67d7SRamkumar Ramachandra     SmallVector<DynamicAPInt, 8> subExpr =
23894750af8SGroverkss         subtractExprs(funcA.getOutputExpr(level), funcB.getOutputExpr(level));
23994750af8SGroverkss 
24094750af8SGroverkss     // TODO: Implement all comparison cases.
24194750af8SGroverkss     switch (comp) {
24294750af8SGroverkss     case OrderingKind::LT:
24394750af8SGroverkss       // For less than, we add an upper bound of -1:
24494750af8SGroverkss       //        outA - outB <= -1
24594750af8SGroverkss       //        outA <= outB - 1
24694750af8SGroverkss       //        outA < outB
2471a0e67d7SRamkumar Ramachandra       levelSet.addBound(BoundType::UB, subExpr, DynamicAPInt(-1));
24894750af8SGroverkss       break;
24994750af8SGroverkss     case OrderingKind::GT:
25094750af8SGroverkss       // For greater than, we add a lower bound of 1:
25194750af8SGroverkss       //        outA - outB >= 1
25294750af8SGroverkss       //        outA > outB + 1
25394750af8SGroverkss       //        outA > outB
2541a0e67d7SRamkumar Ramachandra       levelSet.addBound(BoundType::LB, subExpr, DynamicAPInt(1));
25594750af8SGroverkss       break;
25694750af8SGroverkss     case OrderingKind::GE:
25794750af8SGroverkss     case OrderingKind::LE:
25894750af8SGroverkss     case OrderingKind::EQ:
25994750af8SGroverkss     case OrderingKind::NE:
26094750af8SGroverkss       assert(false && "Not implemented case");
26194750af8SGroverkss     }
26294750af8SGroverkss 
26394750af8SGroverkss     // Union the set with the result.
26494750af8SGroverkss     result.unionInPlace(levelSet);
26594750af8SGroverkss     // The last inequality in `levelSet` is the bound we inserted. We remove
26694750af8SGroverkss     // that for next iteration.
26794750af8SGroverkss     levelSet.removeInequality(levelSet.getNumInequalities() - 1);
26894750af8SGroverkss     // Add equality `outA - outB == 0` for this level for next iteration.
26994750af8SGroverkss     levelSet.addEquality(subExpr);
27094750af8SGroverkss   }
27194750af8SGroverkss 
27294750af8SGroverkss   return result;
27394750af8SGroverkss }
27494750af8SGroverkss 
275d5a29442SArjun P /// Two PWMAFunctions are equal if they have the same dimensionalities,
276d5a29442SArjun P /// the same domain, and take the same value at every point in the domain.
277d5a29442SArjun P bool PWMAFunction::isEqual(const PWMAFunction &other) const {
27820aedb14SGroverkss   if (!space.isCompatible(other.space))
279d5a29442SArjun P     return false;
280d5a29442SArjun P 
281d5a29442SArjun P   if (!this->getDomain().isEqual(other.getDomain()))
282d5a29442SArjun P     return false;
283d5a29442SArjun P 
284d5a29442SArjun P   // Check if, whenever the domains of a piece of `this` and a piece of `other`
285d5a29442SArjun P   // overlap, they take the same output value. If `this` and `other` have the
286d5a29442SArjun P   // same domain (checked above), then this check passes iff the two functions
287d5a29442SArjun P   // have the same output at every point in the domain.
288bb2226acSGroverkss   return llvm::all_of(this->pieces, [&other](const Piece &pieceA) {
289bb2226acSGroverkss     return llvm::all_of(other.pieces, [&pieceA](const Piece &pieceB) {
290bb2226acSGroverkss       PresburgerSet commonDomain = pieceA.domain.intersect(pieceB.domain);
291bb2226acSGroverkss       return pieceA.output.isEqual(pieceB.output, commonDomain);
292bb2226acSGroverkss     });
293bb2226acSGroverkss   });
294d5a29442SArjun P }
295d5a29442SArjun P 
296bb2226acSGroverkss void PWMAFunction::addPiece(const Piece &piece) {
297bb2226acSGroverkss   assert(piece.isConsistent() && "Piece should be consistent");
29894750af8SGroverkss   assert(piece.domain.intersect(getDomain()).isIntegerEmpty() &&
29994750af8SGroverkss          "Piece should be disjoint from the function");
300266a5a9cSRamkumar Ramachandra   pieces.emplace_back(piece);
301d5a29442SArjun P }
302d5a29442SArjun P 
303d5a29442SArjun P void PWMAFunction::print(raw_ostream &os) const {
304bb2226acSGroverkss   space.print(os);
305bb2226acSGroverkss   os << getNumPieces() << " pieces:\n";
306bb2226acSGroverkss   for (const Piece &piece : pieces) {
307bb2226acSGroverkss     os << "Domain of piece:\n";
308bb2226acSGroverkss     piece.domain.print(os);
309bb2226acSGroverkss     os << "Output of piece\n";
310bb2226acSGroverkss     piece.output.print(os);
311bb2226acSGroverkss   }
312d5a29442SArjun P }
313888894b6SArjun P 
314888894b6SArjun P void PWMAFunction::dump() const { print(llvm::errs()); }
315a18f843fSGroverkss 
316a18f843fSGroverkss PWMAFunction PWMAFunction::unionFunction(
317a18f843fSGroverkss     const PWMAFunction &func,
318bb2226acSGroverkss     llvm::function_ref<PresburgerSet(Piece maf1, Piece maf2)> tiebreak) const {
319a18f843fSGroverkss   assert(getNumOutputs() == func.getNumOutputs() &&
320bb2226acSGroverkss          "Ranges of functions should be same.");
321a18f843fSGroverkss   assert(getSpace().isCompatible(func.getSpace()) &&
322a18f843fSGroverkss          "Space is not compatible.");
323a18f843fSGroverkss 
324a18f843fSGroverkss   // The algorithm used here is as follows:
325bb2226acSGroverkss   // - Add the output of pieceB for the part of the domain where both pieceA and
326bb2226acSGroverkss   //   pieceB are defined, and `tiebreak` chooses the output of pieceB.
327bb2226acSGroverkss   // - Add the output of pieceA, where pieceB is not defined or `tiebreak`
328bb2226acSGroverkss   // chooses
329bb2226acSGroverkss   //   pieceA over pieceB.
330bb2226acSGroverkss   // - Add the output of pieceB, where pieceA is not defined.
331a18f843fSGroverkss 
332bb2226acSGroverkss   // Add parts of the common domain where pieceB's output is used. Also
333bb2226acSGroverkss   // add all the parts where pieceA's output is used, both common and
334bb2226acSGroverkss   // non-common.
335bb2226acSGroverkss   PWMAFunction result(getSpace());
336bb2226acSGroverkss   for (const Piece &pieceA : pieces) {
337bb2226acSGroverkss     PresburgerSet dom(pieceA.domain);
338bb2226acSGroverkss     for (const Piece &pieceB : func.pieces) {
339bb2226acSGroverkss       PresburgerSet better = tiebreak(pieceB, pieceA);
340bb2226acSGroverkss       // Add the output of pieceB, where it is better than output of pieceA.
341a18f843fSGroverkss       // The disjuncts in "better" will be disjoint as tiebreak should gurantee
342a18f843fSGroverkss       // that.
343bb2226acSGroverkss       result.addPiece({better, pieceB.output});
344a18f843fSGroverkss       dom = dom.subtract(better);
345a18f843fSGroverkss     }
346bb2226acSGroverkss     // Add output of pieceA, where it is better than pieceB, or pieceB is not
347a18f843fSGroverkss     // defined.
348a18f843fSGroverkss     //
349a18f843fSGroverkss     // `dom` here is guranteed to be disjoint from already added pieces
3507557530fSFangrui Song     // because the pieces added before are either:
351a18f843fSGroverkss     // - Subsets of the domain of other MAFs in `this`, which are guranteed
352a18f843fSGroverkss     //   to be disjoint from `dom`, or
353bb2226acSGroverkss     // - They are one of the pieces added for `pieceB`, and we have been
354a18f843fSGroverkss     //   subtracting all such pieces from `dom`, so `dom` is disjoint from those
355a18f843fSGroverkss     //   pieces as well.
356bb2226acSGroverkss     result.addPiece({dom, pieceA.output});
357a18f843fSGroverkss   }
358a18f843fSGroverkss 
359bb2226acSGroverkss   // Add parts of pieceB which are not shared with pieceA.
360a18f843fSGroverkss   PresburgerSet dom = getDomain();
361bb2226acSGroverkss   for (const Piece &pieceB : func.pieces)
362bb2226acSGroverkss     result.addPiece({pieceB.domain.subtract(dom), pieceB.output});
363a18f843fSGroverkss 
364a18f843fSGroverkss   return result;
365a18f843fSGroverkss }
366a18f843fSGroverkss 
367a18f843fSGroverkss /// A tiebreak function which breaks ties by comparing the outputs
36894750af8SGroverkss /// lexicographically based on the given comparison operator.
36994750af8SGroverkss /// This is templated since it is passed as a lambda.
37094750af8SGroverkss template <OrderingKind comp>
371bb2226acSGroverkss static PresburgerSet tiebreakLex(const PWMAFunction::Piece &pieceA,
372bb2226acSGroverkss                                  const PWMAFunction::Piece &pieceB) {
37394750af8SGroverkss   PresburgerSet result = pieceA.output.getLexSet(comp, pieceB.output);
374bb2226acSGroverkss   result = result.intersect(pieceA.domain).intersect(pieceB.domain);
375a18f843fSGroverkss 
376a18f843fSGroverkss   return result;
377a18f843fSGroverkss }
378a18f843fSGroverkss 
379a18f843fSGroverkss PWMAFunction PWMAFunction::unionLexMin(const PWMAFunction &func) {
38094750af8SGroverkss   return unionFunction(func, tiebreakLex</*comp=*/OrderingKind::LT>);
381a18f843fSGroverkss }
382a18f843fSGroverkss 
383a18f843fSGroverkss PWMAFunction PWMAFunction::unionLexMax(const PWMAFunction &func) {
38494750af8SGroverkss   return unionFunction(func, tiebreakLex</*comp=*/OrderingKind::GT>);
385a18f843fSGroverkss }
386bb2226acSGroverkss 
387bb2226acSGroverkss void MultiAffineFunction::subtract(const MultiAffineFunction &other) {
388bb2226acSGroverkss   assert(space.isCompatible(other.space) &&
389bb2226acSGroverkss          "Spaces should be compatible for subtraction.");
390bb2226acSGroverkss 
391bb2226acSGroverkss   MultiAffineFunction copyOther = other;
392bb2226acSGroverkss   mergeDivs(copyOther);
393bb2226acSGroverkss   for (unsigned i = 0, e = getNumOutputs(); i < e; ++i)
3941a0e67d7SRamkumar Ramachandra     output.addToRow(i, copyOther.getOutputExpr(i), DynamicAPInt(-1));
395bb2226acSGroverkss 
396bb2226acSGroverkss   // Check consistency.
397bb2226acSGroverkss   assertIsConsistent();
398bb2226acSGroverkss }
399bb2226acSGroverkss 
400bb2226acSGroverkss /// Adds division constraints corresponding to local variables, given a
401bb2226acSGroverkss /// relation and division representations of the local variables in the
402bb2226acSGroverkss /// relation.
403bb2226acSGroverkss static void addDivisionConstraints(IntegerRelation &rel,
404bb2226acSGroverkss                                    const DivisionRepr &divs) {
405bb2226acSGroverkss   assert(divs.hasAllReprs() &&
406bb2226acSGroverkss          "All divisions in divs should have a representation");
407bb2226acSGroverkss   assert(rel.getNumVars() == divs.getNumVars() &&
408bb2226acSGroverkss          "Relation and divs should have the same number of vars");
409bb2226acSGroverkss   assert(rel.getNumLocalVars() == divs.getNumDivs() &&
410bb2226acSGroverkss          "Relation and divs should have the same number of local vars");
411bb2226acSGroverkss 
412bb2226acSGroverkss   for (unsigned i = 0, e = divs.getNumDivs(); i < e; ++i) {
413bb2226acSGroverkss     rel.addInequality(getDivUpperBound(divs.getDividend(i), divs.getDenom(i),
414bb2226acSGroverkss                                        divs.getDivOffset() + i));
415bb2226acSGroverkss     rel.addInequality(getDivLowerBound(divs.getDividend(i), divs.getDenom(i),
416bb2226acSGroverkss                                        divs.getDivOffset() + i));
417bb2226acSGroverkss   }
418bb2226acSGroverkss }
419bb2226acSGroverkss 
420bb2226acSGroverkss IntegerRelation MultiAffineFunction::getAsRelation() const {
421bb2226acSGroverkss   // Create a relation corressponding to the input space plus the divisions
422bb2226acSGroverkss   // used in outputs.
423bb2226acSGroverkss   IntegerRelation result(PresburgerSpace::getRelationSpace(
424bb2226acSGroverkss       space.getNumDomainVars(), 0, space.getNumSymbolVars(),
425bb2226acSGroverkss       space.getNumLocalVars()));
426bb2226acSGroverkss   // Add division constraints corresponding to divisions used in outputs.
427bb2226acSGroverkss   addDivisionConstraints(result, divs);
428bb2226acSGroverkss   // The outputs are represented as range variables in the relation. We add
429bb2226acSGroverkss   // range variables for the outputs.
430bb2226acSGroverkss   result.insertVar(VarKind::Range, 0, getNumOutputs());
431bb2226acSGroverkss 
432bb2226acSGroverkss   // Add equalities such that the i^th range variable is equal to the i^th
433bb2226acSGroverkss   // output expression.
4341a0e67d7SRamkumar Ramachandra   SmallVector<DynamicAPInt, 8> eq(result.getNumCols());
435bb2226acSGroverkss   for (unsigned i = 0, e = getNumOutputs(); i < e; ++i) {
436bb2226acSGroverkss     // TODO: Add functions to get VarKind offsets in output in MAF and use them
437bb2226acSGroverkss     // here.
438bb2226acSGroverkss     // The output expression does not contain range variables, while the
439bb2226acSGroverkss     // equality does. So, we need to copy all variables and mark all range
440bb2226acSGroverkss     // variables as 0 in the equality.
4411a0e67d7SRamkumar Ramachandra     ArrayRef<DynamicAPInt> expr = getOutputExpr(i);
442bb2226acSGroverkss     // Copy domain variables in `expr` to domain variables in `eq`.
443bb2226acSGroverkss     std::copy(expr.begin(), expr.begin() + getNumDomainVars(), eq.begin());
444bb2226acSGroverkss     // Fill the range variables in `eq` as zero.
445bb2226acSGroverkss     std::fill(eq.begin() + result.getVarKindOffset(VarKind::Range),
446bb2226acSGroverkss               eq.begin() + result.getVarKindEnd(VarKind::Range), 0);
447bb2226acSGroverkss     // Copy remaining variables in `expr` to the remaining variables in `eq`.
448bb2226acSGroverkss     std::copy(expr.begin() + getNumDomainVars(), expr.end(),
449bb2226acSGroverkss               eq.begin() + result.getVarKindEnd(VarKind::Range));
450bb2226acSGroverkss 
451bb2226acSGroverkss     // Set the i^th range var to -1 in `eq` to equate the output expression to
452bb2226acSGroverkss     // this range var.
453bb2226acSGroverkss     eq[result.getVarKindOffset(VarKind::Range) + i] = -1;
454bb2226acSGroverkss     // Add the equality `rangeVar_i = output[i]`.
455bb2226acSGroverkss     result.addEquality(eq);
456bb2226acSGroverkss   }
457bb2226acSGroverkss 
458bb2226acSGroverkss   return result;
459bb2226acSGroverkss }
460bb2226acSGroverkss 
461bb2226acSGroverkss void PWMAFunction::removeOutputs(unsigned start, unsigned end) {
462bb2226acSGroverkss   space.removeVarRange(VarKind::Range, start, end);
463bb2226acSGroverkss   for (Piece &piece : pieces)
464bb2226acSGroverkss     piece.output.removeOutputs(start, end);
465bb2226acSGroverkss }
466bb2226acSGroverkss 
4671a0e67d7SRamkumar Ramachandra std::optional<SmallVector<DynamicAPInt, 8>>
4681a0e67d7SRamkumar Ramachandra PWMAFunction::valueAt(ArrayRef<DynamicAPInt> point) const {
469bb2226acSGroverkss   assert(point.size() == getNumDomainVars() + getNumSymbolVars());
470bb2226acSGroverkss 
471bb2226acSGroverkss   for (const Piece &piece : pieces)
472bb2226acSGroverkss     if (piece.domain.containsPoint(point))
473bb2226acSGroverkss       return piece.output.valueAt(point);
4741a36588eSKazu Hirata   return std::nullopt;
475bb2226acSGroverkss }
476