1 //===- PWMAFunction.cpp - MLIR PWMAFunction Class -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Analysis/Presburger/PWMAFunction.h" 10 #include "mlir/Analysis/Presburger/IntegerRelation.h" 11 #include "mlir/Analysis/Presburger/PresburgerRelation.h" 12 #include "mlir/Analysis/Presburger/PresburgerSpace.h" 13 #include "mlir/Analysis/Presburger/Utils.h" 14 #include "llvm/ADT/STLExtras.h" 15 #include "llvm/ADT/STLFunctionalExtras.h" 16 #include "llvm/ADT/SmallVector.h" 17 #include "llvm/Support/raw_ostream.h" 18 #include <algorithm> 19 #include <cassert> 20 #include <optional> 21 22 using namespace mlir; 23 using namespace presburger; 24 25 void MultiAffineFunction::assertIsConsistent() const { 26 assert(space.getNumVars() - space.getNumRangeVars() + 1 == 27 output.getNumColumns() && 28 "Inconsistent number of output columns"); 29 assert(space.getNumDomainVars() + space.getNumSymbolVars() == 30 divs.getNumNonDivs() && 31 "Inconsistent number of non-division variables in divs"); 32 assert(space.getNumRangeVars() == output.getNumRows() && 33 "Inconsistent number of output rows"); 34 assert(space.getNumLocalVars() == divs.getNumDivs() && 35 "Inconsistent number of divisions."); 36 assert(divs.hasAllReprs() && "All divisions should have a representation"); 37 } 38 39 // Return the result of subtracting the two given vectors pointwise. 40 // The vectors must be of the same size. 41 // e.g., [3, 4, 6] - [2, 5, 1] = [1, -1, 5]. 42 static SmallVector<DynamicAPInt, 8> subtractExprs(ArrayRef<DynamicAPInt> vecA, 43 ArrayRef<DynamicAPInt> vecB) { 44 assert(vecA.size() == vecB.size() && 45 "Cannot subtract vectors of differing lengths!"); 46 SmallVector<DynamicAPInt, 8> result; 47 result.reserve(vecA.size()); 48 for (unsigned i = 0, e = vecA.size(); i < e; ++i) 49 result.emplace_back(vecA[i] - vecB[i]); 50 return result; 51 } 52 53 PresburgerSet PWMAFunction::getDomain() const { 54 PresburgerSet domain = PresburgerSet::getEmpty(getDomainSpace()); 55 for (const Piece &piece : pieces) 56 domain.unionInPlace(piece.domain); 57 return domain; 58 } 59 60 void MultiAffineFunction::print(raw_ostream &os) const { 61 space.print(os); 62 os << "Division Representation:\n"; 63 divs.print(os); 64 os << "Output:\n"; 65 output.print(os); 66 } 67 68 void MultiAffineFunction::dump() const { print(llvm::errs()); } 69 70 SmallVector<DynamicAPInt, 8> 71 MultiAffineFunction::valueAt(ArrayRef<DynamicAPInt> point) const { 72 assert(point.size() == getNumDomainVars() + getNumSymbolVars() && 73 "Point has incorrect dimensionality!"); 74 75 SmallVector<DynamicAPInt, 8> pointHomogenous{llvm::to_vector(point)}; 76 // Get the division values at this point. 77 SmallVector<std::optional<DynamicAPInt>, 8> divValues = 78 divs.divValuesAt(point); 79 // The given point didn't include the values of the divs which the output is a 80 // function of; we have computed one possible set of values and use them here. 81 pointHomogenous.reserve(pointHomogenous.size() + divValues.size()); 82 for (const std::optional<DynamicAPInt> &divVal : divValues) 83 pointHomogenous.emplace_back(*divVal); 84 // The matrix `output` has an affine expression in the ith row, corresponding 85 // to the expression for the ith value in the output vector. The last column 86 // of the matrix contains the constant term. Let v be the input point with 87 // a 1 appended at the end. We can see that output * v gives the desired 88 // output vector. 89 pointHomogenous.emplace_back(1); 90 SmallVector<DynamicAPInt, 8> result = 91 output.postMultiplyWithColumn(pointHomogenous); 92 assert(result.size() == getNumOutputs()); 93 return result; 94 } 95 96 bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const { 97 assert(space.isCompatible(other.space) && 98 "Spaces should be compatible for equality check."); 99 return getAsRelation().isEqual(other.getAsRelation()); 100 } 101 102 bool MultiAffineFunction::isEqual(const MultiAffineFunction &other, 103 const IntegerPolyhedron &domain) const { 104 assert(space.isCompatible(other.space) && 105 "Spaces should be compatible for equality check."); 106 IntegerRelation restrictedThis = getAsRelation(); 107 restrictedThis.intersectDomain(domain); 108 109 IntegerRelation restrictedOther = other.getAsRelation(); 110 restrictedOther.intersectDomain(domain); 111 112 return restrictedThis.isEqual(restrictedOther); 113 } 114 115 bool MultiAffineFunction::isEqual(const MultiAffineFunction &other, 116 const PresburgerSet &domain) const { 117 assert(space.isCompatible(other.space) && 118 "Spaces should be compatible for equality check."); 119 return llvm::all_of(domain.getAllDisjuncts(), 120 [&](const IntegerRelation &disjunct) { 121 return isEqual(other, IntegerPolyhedron(disjunct)); 122 }); 123 } 124 125 void MultiAffineFunction::removeOutputs(unsigned start, unsigned end) { 126 assert(end <= getNumOutputs() && "Invalid range"); 127 128 if (start >= end) 129 return; 130 131 space.removeVarRange(VarKind::Range, start, end); 132 output.removeRows(start, end - start); 133 } 134 135 void MultiAffineFunction::mergeDivs(MultiAffineFunction &other) { 136 assert(space.isCompatible(other.space) && "Functions should be compatible"); 137 138 unsigned nDivs = getNumDivs(); 139 unsigned divOffset = divs.getDivOffset(); 140 141 other.divs.insertDiv(0, nDivs); 142 143 SmallVector<DynamicAPInt, 8> div(other.divs.getNumVars() + 1); 144 for (unsigned i = 0; i < nDivs; ++i) { 145 // Zero fill. 146 std::fill(div.begin(), div.end(), 0); 147 // Fill div with dividend from `divs`. Do not fill the constant. 148 std::copy(divs.getDividend(i).begin(), divs.getDividend(i).end() - 1, 149 div.begin()); 150 // Fill constant. 151 div.back() = divs.getDividend(i).back(); 152 other.divs.setDiv(i, div, divs.getDenom(i)); 153 } 154 155 other.space.insertVar(VarKind::Local, 0, nDivs); 156 other.output.insertColumns(divOffset, nDivs); 157 158 auto merge = [&](unsigned i, unsigned j) { 159 // We only merge from local at pos j to local at pos i, where j > i. 160 if (i >= j) 161 return false; 162 163 // If i < nDivs, we are trying to merge duplicate divs in `this`. Since we 164 // do not want to merge duplicates in `this`, we ignore this call. 165 if (j < nDivs) 166 return false; 167 168 // Merge things in space and output. 169 other.space.removeVarRange(VarKind::Local, j, j + 1); 170 other.output.addToColumn(divOffset + i, divOffset + j, 1); 171 other.output.removeColumn(divOffset + j); 172 return true; 173 }; 174 175 other.divs.removeDuplicateDivs(merge); 176 177 unsigned newDivs = other.divs.getNumDivs() - nDivs; 178 179 space.insertVar(VarKind::Local, nDivs, newDivs); 180 output.insertColumns(divOffset + nDivs, newDivs); 181 divs = other.divs; 182 183 // Check consistency. 184 assertIsConsistent(); 185 other.assertIsConsistent(); 186 } 187 188 PresburgerSet 189 MultiAffineFunction::getLexSet(OrderingKind comp, 190 const MultiAffineFunction &other) const { 191 assert(getSpace().isCompatible(other.getSpace()) && 192 "Output space of funcs should be compatible"); 193 194 // Create copies of functions and merge their local space. 195 MultiAffineFunction funcA = *this; 196 MultiAffineFunction funcB = other; 197 funcA.mergeDivs(funcB); 198 199 // We first create the set `result`, corresponding to the set where output 200 // of funcA is lexicographically larger/smaller than funcB. This is done by 201 // creating a PresburgerSet with the following constraints: 202 // 203 // (outA[0] > outB[0]) U 204 // (outA[0] = outB[0], outA[1] > outA[1]) U 205 // (outA[0] = outB[0], outA[1] = outA[1], outA[2] > outA[2]) U 206 // ... 207 // (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] > outB[n-1]) 208 // 209 // where `n` is the number of outputs. 210 // If `lexMin` is set, the complement inequality is used: 211 // 212 // (outA[0] < outB[0]) U 213 // (outA[0] = outB[0], outA[1] < outA[1]) U 214 // (outA[0] = outB[0], outA[1] = outA[1], outA[2] < outA[2]) U 215 // ... 216 // (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] < outB[n-1]) 217 PresburgerSpace resultSpace = funcA.getDomainSpace(); 218 PresburgerSet result = 219 PresburgerSet::getEmpty(resultSpace.getSpaceWithoutLocals()); 220 IntegerPolyhedron levelSet( 221 /*numReservedInequalities=*/1 + 2 * resultSpace.getNumLocalVars(), 222 /*numReservedEqualities=*/funcA.getNumOutputs(), 223 /*numReservedCols=*/resultSpace.getNumVars() + 1, resultSpace); 224 225 // Add division inequalities to `levelSet`. 226 for (unsigned i = 0, e = funcA.getNumDivs(); i < e; ++i) { 227 levelSet.addInequality(getDivUpperBound(funcA.divs.getDividend(i), 228 funcA.divs.getDenom(i), 229 funcA.divs.getDivOffset() + i)); 230 levelSet.addInequality(getDivLowerBound(funcA.divs.getDividend(i), 231 funcA.divs.getDenom(i), 232 funcA.divs.getDivOffset() + i)); 233 } 234 235 for (unsigned level = 0; level < funcA.getNumOutputs(); ++level) { 236 // Create the expression `outA - outB` for this level. 237 SmallVector<DynamicAPInt, 8> subExpr = 238 subtractExprs(funcA.getOutputExpr(level), funcB.getOutputExpr(level)); 239 240 // TODO: Implement all comparison cases. 241 switch (comp) { 242 case OrderingKind::LT: 243 // For less than, we add an upper bound of -1: 244 // outA - outB <= -1 245 // outA <= outB - 1 246 // outA < outB 247 levelSet.addBound(BoundType::UB, subExpr, DynamicAPInt(-1)); 248 break; 249 case OrderingKind::GT: 250 // For greater than, we add a lower bound of 1: 251 // outA - outB >= 1 252 // outA > outB + 1 253 // outA > outB 254 levelSet.addBound(BoundType::LB, subExpr, DynamicAPInt(1)); 255 break; 256 case OrderingKind::GE: 257 case OrderingKind::LE: 258 case OrderingKind::EQ: 259 case OrderingKind::NE: 260 assert(false && "Not implemented case"); 261 } 262 263 // Union the set with the result. 264 result.unionInPlace(levelSet); 265 // The last inequality in `levelSet` is the bound we inserted. We remove 266 // that for next iteration. 267 levelSet.removeInequality(levelSet.getNumInequalities() - 1); 268 // Add equality `outA - outB == 0` for this level for next iteration. 269 levelSet.addEquality(subExpr); 270 } 271 272 return result; 273 } 274 275 /// Two PWMAFunctions are equal if they have the same dimensionalities, 276 /// the same domain, and take the same value at every point in the domain. 277 bool PWMAFunction::isEqual(const PWMAFunction &other) const { 278 if (!space.isCompatible(other.space)) 279 return false; 280 281 if (!this->getDomain().isEqual(other.getDomain())) 282 return false; 283 284 // Check if, whenever the domains of a piece of `this` and a piece of `other` 285 // overlap, they take the same output value. If `this` and `other` have the 286 // same domain (checked above), then this check passes iff the two functions 287 // have the same output at every point in the domain. 288 return llvm::all_of(this->pieces, [&other](const Piece &pieceA) { 289 return llvm::all_of(other.pieces, [&pieceA](const Piece &pieceB) { 290 PresburgerSet commonDomain = pieceA.domain.intersect(pieceB.domain); 291 return pieceA.output.isEqual(pieceB.output, commonDomain); 292 }); 293 }); 294 } 295 296 void PWMAFunction::addPiece(const Piece &piece) { 297 assert(piece.isConsistent() && "Piece should be consistent"); 298 assert(piece.domain.intersect(getDomain()).isIntegerEmpty() && 299 "Piece should be disjoint from the function"); 300 pieces.emplace_back(piece); 301 } 302 303 void PWMAFunction::print(raw_ostream &os) const { 304 space.print(os); 305 os << getNumPieces() << " pieces:\n"; 306 for (const Piece &piece : pieces) { 307 os << "Domain of piece:\n"; 308 piece.domain.print(os); 309 os << "Output of piece\n"; 310 piece.output.print(os); 311 } 312 } 313 314 void PWMAFunction::dump() const { print(llvm::errs()); } 315 316 PWMAFunction PWMAFunction::unionFunction( 317 const PWMAFunction &func, 318 llvm::function_ref<PresburgerSet(Piece maf1, Piece maf2)> tiebreak) const { 319 assert(getNumOutputs() == func.getNumOutputs() && 320 "Ranges of functions should be same."); 321 assert(getSpace().isCompatible(func.getSpace()) && 322 "Space is not compatible."); 323 324 // The algorithm used here is as follows: 325 // - Add the output of pieceB for the part of the domain where both pieceA and 326 // pieceB are defined, and `tiebreak` chooses the output of pieceB. 327 // - Add the output of pieceA, where pieceB is not defined or `tiebreak` 328 // chooses 329 // pieceA over pieceB. 330 // - Add the output of pieceB, where pieceA is not defined. 331 332 // Add parts of the common domain where pieceB's output is used. Also 333 // add all the parts where pieceA's output is used, both common and 334 // non-common. 335 PWMAFunction result(getSpace()); 336 for (const Piece &pieceA : pieces) { 337 PresburgerSet dom(pieceA.domain); 338 for (const Piece &pieceB : func.pieces) { 339 PresburgerSet better = tiebreak(pieceB, pieceA); 340 // Add the output of pieceB, where it is better than output of pieceA. 341 // The disjuncts in "better" will be disjoint as tiebreak should gurantee 342 // that. 343 result.addPiece({better, pieceB.output}); 344 dom = dom.subtract(better); 345 } 346 // Add output of pieceA, where it is better than pieceB, or pieceB is not 347 // defined. 348 // 349 // `dom` here is guranteed to be disjoint from already added pieces 350 // because the pieces added before are either: 351 // - Subsets of the domain of other MAFs in `this`, which are guranteed 352 // to be disjoint from `dom`, or 353 // - They are one of the pieces added for `pieceB`, and we have been 354 // subtracting all such pieces from `dom`, so `dom` is disjoint from those 355 // pieces as well. 356 result.addPiece({dom, pieceA.output}); 357 } 358 359 // Add parts of pieceB which are not shared with pieceA. 360 PresburgerSet dom = getDomain(); 361 for (const Piece &pieceB : func.pieces) 362 result.addPiece({pieceB.domain.subtract(dom), pieceB.output}); 363 364 return result; 365 } 366 367 /// A tiebreak function which breaks ties by comparing the outputs 368 /// lexicographically based on the given comparison operator. 369 /// This is templated since it is passed as a lambda. 370 template <OrderingKind comp> 371 static PresburgerSet tiebreakLex(const PWMAFunction::Piece &pieceA, 372 const PWMAFunction::Piece &pieceB) { 373 PresburgerSet result = pieceA.output.getLexSet(comp, pieceB.output); 374 result = result.intersect(pieceA.domain).intersect(pieceB.domain); 375 376 return result; 377 } 378 379 PWMAFunction PWMAFunction::unionLexMin(const PWMAFunction &func) { 380 return unionFunction(func, tiebreakLex</*comp=*/OrderingKind::LT>); 381 } 382 383 PWMAFunction PWMAFunction::unionLexMax(const PWMAFunction &func) { 384 return unionFunction(func, tiebreakLex</*comp=*/OrderingKind::GT>); 385 } 386 387 void MultiAffineFunction::subtract(const MultiAffineFunction &other) { 388 assert(space.isCompatible(other.space) && 389 "Spaces should be compatible for subtraction."); 390 391 MultiAffineFunction copyOther = other; 392 mergeDivs(copyOther); 393 for (unsigned i = 0, e = getNumOutputs(); i < e; ++i) 394 output.addToRow(i, copyOther.getOutputExpr(i), DynamicAPInt(-1)); 395 396 // Check consistency. 397 assertIsConsistent(); 398 } 399 400 /// Adds division constraints corresponding to local variables, given a 401 /// relation and division representations of the local variables in the 402 /// relation. 403 static void addDivisionConstraints(IntegerRelation &rel, 404 const DivisionRepr &divs) { 405 assert(divs.hasAllReprs() && 406 "All divisions in divs should have a representation"); 407 assert(rel.getNumVars() == divs.getNumVars() && 408 "Relation and divs should have the same number of vars"); 409 assert(rel.getNumLocalVars() == divs.getNumDivs() && 410 "Relation and divs should have the same number of local vars"); 411 412 for (unsigned i = 0, e = divs.getNumDivs(); i < e; ++i) { 413 rel.addInequality(getDivUpperBound(divs.getDividend(i), divs.getDenom(i), 414 divs.getDivOffset() + i)); 415 rel.addInequality(getDivLowerBound(divs.getDividend(i), divs.getDenom(i), 416 divs.getDivOffset() + i)); 417 } 418 } 419 420 IntegerRelation MultiAffineFunction::getAsRelation() const { 421 // Create a relation corressponding to the input space plus the divisions 422 // used in outputs. 423 IntegerRelation result(PresburgerSpace::getRelationSpace( 424 space.getNumDomainVars(), 0, space.getNumSymbolVars(), 425 space.getNumLocalVars())); 426 // Add division constraints corresponding to divisions used in outputs. 427 addDivisionConstraints(result, divs); 428 // The outputs are represented as range variables in the relation. We add 429 // range variables for the outputs. 430 result.insertVar(VarKind::Range, 0, getNumOutputs()); 431 432 // Add equalities such that the i^th range variable is equal to the i^th 433 // output expression. 434 SmallVector<DynamicAPInt, 8> eq(result.getNumCols()); 435 for (unsigned i = 0, e = getNumOutputs(); i < e; ++i) { 436 // TODO: Add functions to get VarKind offsets in output in MAF and use them 437 // here. 438 // The output expression does not contain range variables, while the 439 // equality does. So, we need to copy all variables and mark all range 440 // variables as 0 in the equality. 441 ArrayRef<DynamicAPInt> expr = getOutputExpr(i); 442 // Copy domain variables in `expr` to domain variables in `eq`. 443 std::copy(expr.begin(), expr.begin() + getNumDomainVars(), eq.begin()); 444 // Fill the range variables in `eq` as zero. 445 std::fill(eq.begin() + result.getVarKindOffset(VarKind::Range), 446 eq.begin() + result.getVarKindEnd(VarKind::Range), 0); 447 // Copy remaining variables in `expr` to the remaining variables in `eq`. 448 std::copy(expr.begin() + getNumDomainVars(), expr.end(), 449 eq.begin() + result.getVarKindEnd(VarKind::Range)); 450 451 // Set the i^th range var to -1 in `eq` to equate the output expression to 452 // this range var. 453 eq[result.getVarKindOffset(VarKind::Range) + i] = -1; 454 // Add the equality `rangeVar_i = output[i]`. 455 result.addEquality(eq); 456 } 457 458 return result; 459 } 460 461 void PWMAFunction::removeOutputs(unsigned start, unsigned end) { 462 space.removeVarRange(VarKind::Range, start, end); 463 for (Piece &piece : pieces) 464 piece.output.removeOutputs(start, end); 465 } 466 467 std::optional<SmallVector<DynamicAPInt, 8>> 468 PWMAFunction::valueAt(ArrayRef<DynamicAPInt> point) const { 469 assert(point.size() == getNumDomainVars() + getNumSymbolVars()); 470 471 for (const Piece &piece : pieces) 472 if (piece.domain.containsPoint(point)) 473 return piece.output.valueAt(point); 474 return std::nullopt; 475 } 476