1d5a29442SArjun P //===- PWMAFunction.cpp - MLIR PWMAFunction Class -------------------------===// 2d5a29442SArjun P // 3d5a29442SArjun P // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4d5a29442SArjun P // See https://llvm.org/LICENSE.txt for license information. 5d5a29442SArjun P // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6d5a29442SArjun P // 7d5a29442SArjun P //===----------------------------------------------------------------------===// 8d5a29442SArjun P 9d5a29442SArjun P #include "mlir/Analysis/Presburger/PWMAFunction.h" 102ee87cd6SMehdi Amini #include "mlir/Analysis/Presburger/IntegerRelation.h" 112ee87cd6SMehdi Amini #include "mlir/Analysis/Presburger/PresburgerRelation.h" 122ee87cd6SMehdi Amini #include "mlir/Analysis/Presburger/PresburgerSpace.h" 132ee87cd6SMehdi Amini #include "mlir/Analysis/Presburger/Utils.h" 142ee87cd6SMehdi Amini #include "llvm/ADT/STLExtras.h" 152ee87cd6SMehdi Amini #include "llvm/ADT/STLFunctionalExtras.h" 162ee87cd6SMehdi Amini #include "llvm/ADT/SmallVector.h" 172ee87cd6SMehdi Amini #include "llvm/Support/raw_ostream.h" 182ee87cd6SMehdi Amini #include <algorithm> 192ee87cd6SMehdi Amini #include <cassert> 20a1fe1f5fSKazu Hirata #include <optional> 21d5a29442SArjun P 22d5a29442SArjun P using namespace mlir; 230c1f6865SGroverkss using namespace presburger; 24d5a29442SArjun P 25bb2226acSGroverkss void MultiAffineFunction::assertIsConsistent() const { 26bb2226acSGroverkss assert(space.getNumVars() - space.getNumRangeVars() + 1 == 27bb2226acSGroverkss output.getNumColumns() && 28bb2226acSGroverkss "Inconsistent number of output columns"); 29bb2226acSGroverkss assert(space.getNumDomainVars() + space.getNumSymbolVars() == 30bb2226acSGroverkss divs.getNumNonDivs() && 31bb2226acSGroverkss "Inconsistent number of non-division variables in divs"); 32bb2226acSGroverkss assert(space.getNumRangeVars() == output.getNumRows() && 33bb2226acSGroverkss "Inconsistent number of output rows"); 34bb2226acSGroverkss assert(space.getNumLocalVars() == divs.getNumDivs() && 35bb2226acSGroverkss "Inconsistent number of divisions."); 36bb2226acSGroverkss assert(divs.hasAllReprs() && "All divisions should have a representation"); 37bb2226acSGroverkss } 38bb2226acSGroverkss 39d5a29442SArjun P // Return the result of subtracting the two given vectors pointwise. 40d5a29442SArjun P // The vectors must be of the same size. 41d5a29442SArjun P // e.g., [3, 4, 6] - [2, 5, 1] = [1, -1, 5]. 421a0e67d7SRamkumar Ramachandra static SmallVector<DynamicAPInt, 8> subtractExprs(ArrayRef<DynamicAPInt> vecA, 431a0e67d7SRamkumar Ramachandra ArrayRef<DynamicAPInt> vecB) { 44d5a29442SArjun P assert(vecA.size() == vecB.size() && 45d5a29442SArjun P "Cannot subtract vectors of differing lengths!"); 461a0e67d7SRamkumar Ramachandra SmallVector<DynamicAPInt, 8> result; 47d5a29442SArjun P result.reserve(vecA.size()); 48d5a29442SArjun P for (unsigned i = 0, e = vecA.size(); i < e; ++i) 49266a5a9cSRamkumar Ramachandra result.emplace_back(vecA[i] - vecB[i]); 50d5a29442SArjun P return result; 51d5a29442SArjun P } 52d5a29442SArjun P 53d5a29442SArjun P PresburgerSet PWMAFunction::getDomain() const { 54bb2226acSGroverkss PresburgerSet domain = PresburgerSet::getEmpty(getDomainSpace()); 55bb2226acSGroverkss for (const Piece &piece : pieces) 56bb2226acSGroverkss domain.unionInPlace(piece.domain); 57d5a29442SArjun P return domain; 58d5a29442SArjun P } 59d5a29442SArjun P 60bb2226acSGroverkss void MultiAffineFunction::print(raw_ostream &os) const { 61bb2226acSGroverkss space.print(os); 62bb2226acSGroverkss os << "Division Representation:\n"; 63bb2226acSGroverkss divs.print(os); 64bb2226acSGroverkss os << "Output:\n"; 65bb2226acSGroverkss output.print(os); 66bb2226acSGroverkss } 67bb2226acSGroverkss 68*832ccfe5SChristopher Bate void MultiAffineFunction::dump() const { print(llvm::errs()); } 69*832ccfe5SChristopher Bate 701a0e67d7SRamkumar Ramachandra SmallVector<DynamicAPInt, 8> 711a0e67d7SRamkumar Ramachandra MultiAffineFunction::valueAt(ArrayRef<DynamicAPInt> point) const { 72bb2226acSGroverkss assert(point.size() == getNumDomainVars() + getNumSymbolVars() && 734418669fSArjun P "Point has incorrect dimensionality!"); 74d5a29442SArjun P 751a0e67d7SRamkumar Ramachandra SmallVector<DynamicAPInt, 8> pointHomogenous{llvm::to_vector(point)}; 76bb2226acSGroverkss // Get the division values at this point. 771a0e67d7SRamkumar Ramachandra SmallVector<std::optional<DynamicAPInt>, 8> divValues = 781a0e67d7SRamkumar Ramachandra divs.divValuesAt(point); 79bb2226acSGroverkss // The given point didn't include the values of the divs which the output is a 80bb2226acSGroverkss // function of; we have computed one possible set of values and use them here. 81bb2226acSGroverkss pointHomogenous.reserve(pointHomogenous.size() + divValues.size()); 821a0e67d7SRamkumar Ramachandra for (const std::optional<DynamicAPInt> &divVal : divValues) 83266a5a9cSRamkumar Ramachandra pointHomogenous.emplace_back(*divVal); 84d5a29442SArjun P // The matrix `output` has an affine expression in the ith row, corresponding 85d5a29442SArjun P // to the expression for the ith value in the output vector. The last column 86d5a29442SArjun P // of the matrix contains the constant term. Let v be the input point with 87d5a29442SArjun P // a 1 appended at the end. We can see that output * v gives the desired 88d5a29442SArjun P // output vector. 896d6f6c4dSArjun P pointHomogenous.emplace_back(1); 901a0e67d7SRamkumar Ramachandra SmallVector<DynamicAPInt, 8> result = 911a0e67d7SRamkumar Ramachandra output.postMultiplyWithColumn(pointHomogenous); 92d5a29442SArjun P assert(result.size() == getNumOutputs()); 93d5a29442SArjun P return result; 94d5a29442SArjun P } 95d5a29442SArjun P 96d5a29442SArjun P bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const { 97bb2226acSGroverkss assert(space.isCompatible(other.space) && 98bb2226acSGroverkss "Spaces should be compatible for equality check."); 99bb2226acSGroverkss return getAsRelation().isEqual(other.getAsRelation()); 100d5a29442SArjun P } 101d5a29442SArjun P 102bb2226acSGroverkss bool MultiAffineFunction::isEqual(const MultiAffineFunction &other, 103bb2226acSGroverkss const IntegerPolyhedron &domain) const { 104bb2226acSGroverkss assert(space.isCompatible(other.space) && 105bb2226acSGroverkss "Spaces should be compatible for equality check."); 106bb2226acSGroverkss IntegerRelation restrictedThis = getAsRelation(); 107bb2226acSGroverkss restrictedThis.intersectDomain(domain); 108bb2226acSGroverkss 109bb2226acSGroverkss IntegerRelation restrictedOther = other.getAsRelation(); 110bb2226acSGroverkss restrictedOther.intersectDomain(domain); 111bb2226acSGroverkss 112bb2226acSGroverkss return restrictedThis.isEqual(restrictedOther); 113d5a29442SArjun P } 114d5a29442SArjun P 115bb2226acSGroverkss bool MultiAffineFunction::isEqual(const MultiAffineFunction &other, 116bb2226acSGroverkss const PresburgerSet &domain) const { 117bb2226acSGroverkss assert(space.isCompatible(other.space) && 118bb2226acSGroverkss "Spaces should be compatible for equality check."); 119bb2226acSGroverkss return llvm::all_of(domain.getAllDisjuncts(), 120bb2226acSGroverkss [&](const IntegerRelation &disjunct) { 121bb2226acSGroverkss return isEqual(other, IntegerPolyhedron(disjunct)); 122bb2226acSGroverkss }); 123d5a29442SArjun P } 124d5a29442SArjun P 125bb2226acSGroverkss void MultiAffineFunction::removeOutputs(unsigned start, unsigned end) { 126bb2226acSGroverkss assert(end <= getNumOutputs() && "Invalid range"); 127bb2226acSGroverkss 128bb2226acSGroverkss if (start >= end) 129bb2226acSGroverkss return; 130bb2226acSGroverkss 131bb2226acSGroverkss space.removeVarRange(VarKind::Range, start, end); 132bb2226acSGroverkss output.removeRows(start, end - start); 13379ad5fb2SArjun P } 13479ad5fb2SArjun P 135bb2226acSGroverkss void MultiAffineFunction::mergeDivs(MultiAffineFunction &other) { 136bb2226acSGroverkss assert(space.isCompatible(other.space) && "Functions should be compatible"); 137bb2226acSGroverkss 138bb2226acSGroverkss unsigned nDivs = getNumDivs(); 139bb2226acSGroverkss unsigned divOffset = divs.getDivOffset(); 140bb2226acSGroverkss 141bb2226acSGroverkss other.divs.insertDiv(0, nDivs); 142bb2226acSGroverkss 1431a0e67d7SRamkumar Ramachandra SmallVector<DynamicAPInt, 8> div(other.divs.getNumVars() + 1); 144bb2226acSGroverkss for (unsigned i = 0; i < nDivs; ++i) { 145bb2226acSGroverkss // Zero fill. 146bb2226acSGroverkss std::fill(div.begin(), div.end(), 0); 147bb2226acSGroverkss // Fill div with dividend from `divs`. Do not fill the constant. 148bb2226acSGroverkss std::copy(divs.getDividend(i).begin(), divs.getDividend(i).end() - 1, 149bb2226acSGroverkss div.begin()); 150bb2226acSGroverkss // Fill constant. 151bb2226acSGroverkss div.back() = divs.getDividend(i).back(); 152bb2226acSGroverkss other.divs.setDiv(i, div, divs.getDenom(i)); 15379ad5fb2SArjun P } 15479ad5fb2SArjun P 155bb2226acSGroverkss other.space.insertVar(VarKind::Local, 0, nDivs); 156bb2226acSGroverkss other.output.insertColumns(divOffset, nDivs); 15715650b32SGroverkss 158bb2226acSGroverkss auto merge = [&](unsigned i, unsigned j) { 159bb2226acSGroverkss // We only merge from local at pos j to local at pos i, where j > i. 160bb2226acSGroverkss if (i >= j) 161bb2226acSGroverkss return false; 16215650b32SGroverkss 163bb2226acSGroverkss // If i < nDivs, we are trying to merge duplicate divs in `this`. Since we 164bb2226acSGroverkss // do not want to merge duplicates in `this`, we ignore this call. 165bb2226acSGroverkss if (j < nDivs) 166bb2226acSGroverkss return false; 16715650b32SGroverkss 168bb2226acSGroverkss // Merge things in space and output. 169bb2226acSGroverkss other.space.removeVarRange(VarKind::Local, j, j + 1); 170bb2226acSGroverkss other.output.addToColumn(divOffset + i, divOffset + j, 1); 171bb2226acSGroverkss other.output.removeColumn(divOffset + j); 17215650b32SGroverkss return true; 17315650b32SGroverkss }; 17415650b32SGroverkss 175bb2226acSGroverkss other.divs.removeDuplicateDivs(merge); 17615650b32SGroverkss 177bb2226acSGroverkss unsigned newDivs = other.divs.getNumDivs() - nDivs; 178d5a29442SArjun P 179bb2226acSGroverkss space.insertVar(VarKind::Local, nDivs, newDivs); 180bb2226acSGroverkss output.insertColumns(divOffset + nDivs, newDivs); 181bb2226acSGroverkss divs = other.divs; 182d5a29442SArjun P 183bb2226acSGroverkss // Check consistency. 184bb2226acSGroverkss assertIsConsistent(); 185bb2226acSGroverkss other.assertIsConsistent(); 186d5a29442SArjun P } 187d5a29442SArjun P 18894750af8SGroverkss PresburgerSet 18994750af8SGroverkss MultiAffineFunction::getLexSet(OrderingKind comp, 19094750af8SGroverkss const MultiAffineFunction &other) const { 19194750af8SGroverkss assert(getSpace().isCompatible(other.getSpace()) && 19294750af8SGroverkss "Output space of funcs should be compatible"); 19394750af8SGroverkss 19494750af8SGroverkss // Create copies of functions and merge their local space. 19594750af8SGroverkss MultiAffineFunction funcA = *this; 19694750af8SGroverkss MultiAffineFunction funcB = other; 19794750af8SGroverkss funcA.mergeDivs(funcB); 19894750af8SGroverkss 19994750af8SGroverkss // We first create the set `result`, corresponding to the set where output 20094750af8SGroverkss // of funcA is lexicographically larger/smaller than funcB. This is done by 20194750af8SGroverkss // creating a PresburgerSet with the following constraints: 20294750af8SGroverkss // 20394750af8SGroverkss // (outA[0] > outB[0]) U 20494750af8SGroverkss // (outA[0] = outB[0], outA[1] > outA[1]) U 20594750af8SGroverkss // (outA[0] = outB[0], outA[1] = outA[1], outA[2] > outA[2]) U 20694750af8SGroverkss // ... 20794750af8SGroverkss // (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] > outB[n-1]) 20894750af8SGroverkss // 20994750af8SGroverkss // where `n` is the number of outputs. 21094750af8SGroverkss // If `lexMin` is set, the complement inequality is used: 21194750af8SGroverkss // 21294750af8SGroverkss // (outA[0] < outB[0]) U 21394750af8SGroverkss // (outA[0] = outB[0], outA[1] < outA[1]) U 21494750af8SGroverkss // (outA[0] = outB[0], outA[1] = outA[1], outA[2] < outA[2]) U 21594750af8SGroverkss // ... 21694750af8SGroverkss // (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] < outB[n-1]) 21794750af8SGroverkss PresburgerSpace resultSpace = funcA.getDomainSpace(); 21894750af8SGroverkss PresburgerSet result = 21994750af8SGroverkss PresburgerSet::getEmpty(resultSpace.getSpaceWithoutLocals()); 22094750af8SGroverkss IntegerPolyhedron levelSet( 22194750af8SGroverkss /*numReservedInequalities=*/1 + 2 * resultSpace.getNumLocalVars(), 22294750af8SGroverkss /*numReservedEqualities=*/funcA.getNumOutputs(), 22394750af8SGroverkss /*numReservedCols=*/resultSpace.getNumVars() + 1, resultSpace); 22494750af8SGroverkss 22594750af8SGroverkss // Add division inequalities to `levelSet`. 22694750af8SGroverkss for (unsigned i = 0, e = funcA.getNumDivs(); i < e; ++i) { 22794750af8SGroverkss levelSet.addInequality(getDivUpperBound(funcA.divs.getDividend(i), 22894750af8SGroverkss funcA.divs.getDenom(i), 22994750af8SGroverkss funcA.divs.getDivOffset() + i)); 23094750af8SGroverkss levelSet.addInequality(getDivLowerBound(funcA.divs.getDividend(i), 23194750af8SGroverkss funcA.divs.getDenom(i), 23294750af8SGroverkss funcA.divs.getDivOffset() + i)); 23394750af8SGroverkss } 23494750af8SGroverkss 23594750af8SGroverkss for (unsigned level = 0; level < funcA.getNumOutputs(); ++level) { 23694750af8SGroverkss // Create the expression `outA - outB` for this level. 2371a0e67d7SRamkumar Ramachandra SmallVector<DynamicAPInt, 8> subExpr = 23894750af8SGroverkss subtractExprs(funcA.getOutputExpr(level), funcB.getOutputExpr(level)); 23994750af8SGroverkss 24094750af8SGroverkss // TODO: Implement all comparison cases. 24194750af8SGroverkss switch (comp) { 24294750af8SGroverkss case OrderingKind::LT: 24394750af8SGroverkss // For less than, we add an upper bound of -1: 24494750af8SGroverkss // outA - outB <= -1 24594750af8SGroverkss // outA <= outB - 1 24694750af8SGroverkss // outA < outB 2471a0e67d7SRamkumar Ramachandra levelSet.addBound(BoundType::UB, subExpr, DynamicAPInt(-1)); 24894750af8SGroverkss break; 24994750af8SGroverkss case OrderingKind::GT: 25094750af8SGroverkss // For greater than, we add a lower bound of 1: 25194750af8SGroverkss // outA - outB >= 1 25294750af8SGroverkss // outA > outB + 1 25394750af8SGroverkss // outA > outB 2541a0e67d7SRamkumar Ramachandra levelSet.addBound(BoundType::LB, subExpr, DynamicAPInt(1)); 25594750af8SGroverkss break; 25694750af8SGroverkss case OrderingKind::GE: 25794750af8SGroverkss case OrderingKind::LE: 25894750af8SGroverkss case OrderingKind::EQ: 25994750af8SGroverkss case OrderingKind::NE: 26094750af8SGroverkss assert(false && "Not implemented case"); 26194750af8SGroverkss } 26294750af8SGroverkss 26394750af8SGroverkss // Union the set with the result. 26494750af8SGroverkss result.unionInPlace(levelSet); 26594750af8SGroverkss // The last inequality in `levelSet` is the bound we inserted. We remove 26694750af8SGroverkss // that for next iteration. 26794750af8SGroverkss levelSet.removeInequality(levelSet.getNumInequalities() - 1); 26894750af8SGroverkss // Add equality `outA - outB == 0` for this level for next iteration. 26994750af8SGroverkss levelSet.addEquality(subExpr); 27094750af8SGroverkss } 27194750af8SGroverkss 27294750af8SGroverkss return result; 27394750af8SGroverkss } 27494750af8SGroverkss 275d5a29442SArjun P /// Two PWMAFunctions are equal if they have the same dimensionalities, 276d5a29442SArjun P /// the same domain, and take the same value at every point in the domain. 277d5a29442SArjun P bool PWMAFunction::isEqual(const PWMAFunction &other) const { 27820aedb14SGroverkss if (!space.isCompatible(other.space)) 279d5a29442SArjun P return false; 280d5a29442SArjun P 281d5a29442SArjun P if (!this->getDomain().isEqual(other.getDomain())) 282d5a29442SArjun P return false; 283d5a29442SArjun P 284d5a29442SArjun P // Check if, whenever the domains of a piece of `this` and a piece of `other` 285d5a29442SArjun P // overlap, they take the same output value. If `this` and `other` have the 286d5a29442SArjun P // same domain (checked above), then this check passes iff the two functions 287d5a29442SArjun P // have the same output at every point in the domain. 288bb2226acSGroverkss return llvm::all_of(this->pieces, [&other](const Piece &pieceA) { 289bb2226acSGroverkss return llvm::all_of(other.pieces, [&pieceA](const Piece &pieceB) { 290bb2226acSGroverkss PresburgerSet commonDomain = pieceA.domain.intersect(pieceB.domain); 291bb2226acSGroverkss return pieceA.output.isEqual(pieceB.output, commonDomain); 292bb2226acSGroverkss }); 293bb2226acSGroverkss }); 294d5a29442SArjun P } 295d5a29442SArjun P 296bb2226acSGroverkss void PWMAFunction::addPiece(const Piece &piece) { 297bb2226acSGroverkss assert(piece.isConsistent() && "Piece should be consistent"); 29894750af8SGroverkss assert(piece.domain.intersect(getDomain()).isIntegerEmpty() && 29994750af8SGroverkss "Piece should be disjoint from the function"); 300266a5a9cSRamkumar Ramachandra pieces.emplace_back(piece); 301d5a29442SArjun P } 302d5a29442SArjun P 303d5a29442SArjun P void PWMAFunction::print(raw_ostream &os) const { 304bb2226acSGroverkss space.print(os); 305bb2226acSGroverkss os << getNumPieces() << " pieces:\n"; 306bb2226acSGroverkss for (const Piece &piece : pieces) { 307bb2226acSGroverkss os << "Domain of piece:\n"; 308bb2226acSGroverkss piece.domain.print(os); 309bb2226acSGroverkss os << "Output of piece\n"; 310bb2226acSGroverkss piece.output.print(os); 311bb2226acSGroverkss } 312d5a29442SArjun P } 313888894b6SArjun P 314888894b6SArjun P void PWMAFunction::dump() const { print(llvm::errs()); } 315a18f843fSGroverkss 316a18f843fSGroverkss PWMAFunction PWMAFunction::unionFunction( 317a18f843fSGroverkss const PWMAFunction &func, 318bb2226acSGroverkss llvm::function_ref<PresburgerSet(Piece maf1, Piece maf2)> tiebreak) const { 319a18f843fSGroverkss assert(getNumOutputs() == func.getNumOutputs() && 320bb2226acSGroverkss "Ranges of functions should be same."); 321a18f843fSGroverkss assert(getSpace().isCompatible(func.getSpace()) && 322a18f843fSGroverkss "Space is not compatible."); 323a18f843fSGroverkss 324a18f843fSGroverkss // The algorithm used here is as follows: 325bb2226acSGroverkss // - Add the output of pieceB for the part of the domain where both pieceA and 326bb2226acSGroverkss // pieceB are defined, and `tiebreak` chooses the output of pieceB. 327bb2226acSGroverkss // - Add the output of pieceA, where pieceB is not defined or `tiebreak` 328bb2226acSGroverkss // chooses 329bb2226acSGroverkss // pieceA over pieceB. 330bb2226acSGroverkss // - Add the output of pieceB, where pieceA is not defined. 331a18f843fSGroverkss 332bb2226acSGroverkss // Add parts of the common domain where pieceB's output is used. Also 333bb2226acSGroverkss // add all the parts where pieceA's output is used, both common and 334bb2226acSGroverkss // non-common. 335bb2226acSGroverkss PWMAFunction result(getSpace()); 336bb2226acSGroverkss for (const Piece &pieceA : pieces) { 337bb2226acSGroverkss PresburgerSet dom(pieceA.domain); 338bb2226acSGroverkss for (const Piece &pieceB : func.pieces) { 339bb2226acSGroverkss PresburgerSet better = tiebreak(pieceB, pieceA); 340bb2226acSGroverkss // Add the output of pieceB, where it is better than output of pieceA. 341a18f843fSGroverkss // The disjuncts in "better" will be disjoint as tiebreak should gurantee 342a18f843fSGroverkss // that. 343bb2226acSGroverkss result.addPiece({better, pieceB.output}); 344a18f843fSGroverkss dom = dom.subtract(better); 345a18f843fSGroverkss } 346bb2226acSGroverkss // Add output of pieceA, where it is better than pieceB, or pieceB is not 347a18f843fSGroverkss // defined. 348a18f843fSGroverkss // 349a18f843fSGroverkss // `dom` here is guranteed to be disjoint from already added pieces 3507557530fSFangrui Song // because the pieces added before are either: 351a18f843fSGroverkss // - Subsets of the domain of other MAFs in `this`, which are guranteed 352a18f843fSGroverkss // to be disjoint from `dom`, or 353bb2226acSGroverkss // - They are one of the pieces added for `pieceB`, and we have been 354a18f843fSGroverkss // subtracting all such pieces from `dom`, so `dom` is disjoint from those 355a18f843fSGroverkss // pieces as well. 356bb2226acSGroverkss result.addPiece({dom, pieceA.output}); 357a18f843fSGroverkss } 358a18f843fSGroverkss 359bb2226acSGroverkss // Add parts of pieceB which are not shared with pieceA. 360a18f843fSGroverkss PresburgerSet dom = getDomain(); 361bb2226acSGroverkss for (const Piece &pieceB : func.pieces) 362bb2226acSGroverkss result.addPiece({pieceB.domain.subtract(dom), pieceB.output}); 363a18f843fSGroverkss 364a18f843fSGroverkss return result; 365a18f843fSGroverkss } 366a18f843fSGroverkss 367a18f843fSGroverkss /// A tiebreak function which breaks ties by comparing the outputs 36894750af8SGroverkss /// lexicographically based on the given comparison operator. 36994750af8SGroverkss /// This is templated since it is passed as a lambda. 37094750af8SGroverkss template <OrderingKind comp> 371bb2226acSGroverkss static PresburgerSet tiebreakLex(const PWMAFunction::Piece &pieceA, 372bb2226acSGroverkss const PWMAFunction::Piece &pieceB) { 37394750af8SGroverkss PresburgerSet result = pieceA.output.getLexSet(comp, pieceB.output); 374bb2226acSGroverkss result = result.intersect(pieceA.domain).intersect(pieceB.domain); 375a18f843fSGroverkss 376a18f843fSGroverkss return result; 377a18f843fSGroverkss } 378a18f843fSGroverkss 379a18f843fSGroverkss PWMAFunction PWMAFunction::unionLexMin(const PWMAFunction &func) { 38094750af8SGroverkss return unionFunction(func, tiebreakLex</*comp=*/OrderingKind::LT>); 381a18f843fSGroverkss } 382a18f843fSGroverkss 383a18f843fSGroverkss PWMAFunction PWMAFunction::unionLexMax(const PWMAFunction &func) { 38494750af8SGroverkss return unionFunction(func, tiebreakLex</*comp=*/OrderingKind::GT>); 385a18f843fSGroverkss } 386bb2226acSGroverkss 387bb2226acSGroverkss void MultiAffineFunction::subtract(const MultiAffineFunction &other) { 388bb2226acSGroverkss assert(space.isCompatible(other.space) && 389bb2226acSGroverkss "Spaces should be compatible for subtraction."); 390bb2226acSGroverkss 391bb2226acSGroverkss MultiAffineFunction copyOther = other; 392bb2226acSGroverkss mergeDivs(copyOther); 393bb2226acSGroverkss for (unsigned i = 0, e = getNumOutputs(); i < e; ++i) 3941a0e67d7SRamkumar Ramachandra output.addToRow(i, copyOther.getOutputExpr(i), DynamicAPInt(-1)); 395bb2226acSGroverkss 396bb2226acSGroverkss // Check consistency. 397bb2226acSGroverkss assertIsConsistent(); 398bb2226acSGroverkss } 399bb2226acSGroverkss 400bb2226acSGroverkss /// Adds division constraints corresponding to local variables, given a 401bb2226acSGroverkss /// relation and division representations of the local variables in the 402bb2226acSGroverkss /// relation. 403bb2226acSGroverkss static void addDivisionConstraints(IntegerRelation &rel, 404bb2226acSGroverkss const DivisionRepr &divs) { 405bb2226acSGroverkss assert(divs.hasAllReprs() && 406bb2226acSGroverkss "All divisions in divs should have a representation"); 407bb2226acSGroverkss assert(rel.getNumVars() == divs.getNumVars() && 408bb2226acSGroverkss "Relation and divs should have the same number of vars"); 409bb2226acSGroverkss assert(rel.getNumLocalVars() == divs.getNumDivs() && 410bb2226acSGroverkss "Relation and divs should have the same number of local vars"); 411bb2226acSGroverkss 412bb2226acSGroverkss for (unsigned i = 0, e = divs.getNumDivs(); i < e; ++i) { 413bb2226acSGroverkss rel.addInequality(getDivUpperBound(divs.getDividend(i), divs.getDenom(i), 414bb2226acSGroverkss divs.getDivOffset() + i)); 415bb2226acSGroverkss rel.addInequality(getDivLowerBound(divs.getDividend(i), divs.getDenom(i), 416bb2226acSGroverkss divs.getDivOffset() + i)); 417bb2226acSGroverkss } 418bb2226acSGroverkss } 419bb2226acSGroverkss 420bb2226acSGroverkss IntegerRelation MultiAffineFunction::getAsRelation() const { 421bb2226acSGroverkss // Create a relation corressponding to the input space plus the divisions 422bb2226acSGroverkss // used in outputs. 423bb2226acSGroverkss IntegerRelation result(PresburgerSpace::getRelationSpace( 424bb2226acSGroverkss space.getNumDomainVars(), 0, space.getNumSymbolVars(), 425bb2226acSGroverkss space.getNumLocalVars())); 426bb2226acSGroverkss // Add division constraints corresponding to divisions used in outputs. 427bb2226acSGroverkss addDivisionConstraints(result, divs); 428bb2226acSGroverkss // The outputs are represented as range variables in the relation. We add 429bb2226acSGroverkss // range variables for the outputs. 430bb2226acSGroverkss result.insertVar(VarKind::Range, 0, getNumOutputs()); 431bb2226acSGroverkss 432bb2226acSGroverkss // Add equalities such that the i^th range variable is equal to the i^th 433bb2226acSGroverkss // output expression. 4341a0e67d7SRamkumar Ramachandra SmallVector<DynamicAPInt, 8> eq(result.getNumCols()); 435bb2226acSGroverkss for (unsigned i = 0, e = getNumOutputs(); i < e; ++i) { 436bb2226acSGroverkss // TODO: Add functions to get VarKind offsets in output in MAF and use them 437bb2226acSGroverkss // here. 438bb2226acSGroverkss // The output expression does not contain range variables, while the 439bb2226acSGroverkss // equality does. So, we need to copy all variables and mark all range 440bb2226acSGroverkss // variables as 0 in the equality. 4411a0e67d7SRamkumar Ramachandra ArrayRef<DynamicAPInt> expr = getOutputExpr(i); 442bb2226acSGroverkss // Copy domain variables in `expr` to domain variables in `eq`. 443bb2226acSGroverkss std::copy(expr.begin(), expr.begin() + getNumDomainVars(), eq.begin()); 444bb2226acSGroverkss // Fill the range variables in `eq` as zero. 445bb2226acSGroverkss std::fill(eq.begin() + result.getVarKindOffset(VarKind::Range), 446bb2226acSGroverkss eq.begin() + result.getVarKindEnd(VarKind::Range), 0); 447bb2226acSGroverkss // Copy remaining variables in `expr` to the remaining variables in `eq`. 448bb2226acSGroverkss std::copy(expr.begin() + getNumDomainVars(), expr.end(), 449bb2226acSGroverkss eq.begin() + result.getVarKindEnd(VarKind::Range)); 450bb2226acSGroverkss 451bb2226acSGroverkss // Set the i^th range var to -1 in `eq` to equate the output expression to 452bb2226acSGroverkss // this range var. 453bb2226acSGroverkss eq[result.getVarKindOffset(VarKind::Range) + i] = -1; 454bb2226acSGroverkss // Add the equality `rangeVar_i = output[i]`. 455bb2226acSGroverkss result.addEquality(eq); 456bb2226acSGroverkss } 457bb2226acSGroverkss 458bb2226acSGroverkss return result; 459bb2226acSGroverkss } 460bb2226acSGroverkss 461bb2226acSGroverkss void PWMAFunction::removeOutputs(unsigned start, unsigned end) { 462bb2226acSGroverkss space.removeVarRange(VarKind::Range, start, end); 463bb2226acSGroverkss for (Piece &piece : pieces) 464bb2226acSGroverkss piece.output.removeOutputs(start, end); 465bb2226acSGroverkss } 466bb2226acSGroverkss 4671a0e67d7SRamkumar Ramachandra std::optional<SmallVector<DynamicAPInt, 8>> 4681a0e67d7SRamkumar Ramachandra PWMAFunction::valueAt(ArrayRef<DynamicAPInt> point) const { 469bb2226acSGroverkss assert(point.size() == getNumDomainVars() + getNumSymbolVars()); 470bb2226acSGroverkss 471bb2226acSGroverkss for (const Piece &piece : pieces) 472bb2226acSGroverkss if (piece.domain.containsPoint(point)) 473bb2226acSGroverkss return piece.output.valueAt(point); 4741a36588eSKazu Hirata return std::nullopt; 475bb2226acSGroverkss } 476