1 //===- PresburgerSpace.cpp - MLIR PresburgerSpace Class -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Analysis/Presburger/PresburgerSpace.h" 10 #include "llvm/Support/ErrorHandling.h" 11 #include "llvm/Support/raw_ostream.h" 12 #include <algorithm> 13 #include <cassert> 14 15 using namespace mlir; 16 using namespace presburger; 17 18 bool Identifier::isEqual(const Identifier &other) const { 19 if (value == nullptr || other.value == nullptr) 20 return false; 21 #if LLVM_ENABLE_ABI_BREAKING_CHECKS 22 assert(value != other.value || 23 (value == other.value && idType == other.idType && 24 "Values of Identifiers are equal but their types do not match.")); 25 #endif 26 return value == other.value; 27 } 28 29 void Identifier::print(llvm::raw_ostream &os) const { 30 os << "Id<" << value << ">"; 31 } 32 33 void Identifier::dump() const { 34 print(llvm::errs()); 35 llvm::errs() << "\n"; 36 } 37 38 PresburgerSpace PresburgerSpace::getDomainSpace() const { 39 PresburgerSpace newSpace = *this; 40 newSpace.removeVarRange(VarKind::Range, 0, getNumRangeVars()); 41 newSpace.convertVarKind(VarKind::Domain, 0, getNumDomainVars(), 42 VarKind::SetDim, 0); 43 return newSpace; 44 } 45 46 PresburgerSpace PresburgerSpace::getRangeSpace() const { 47 PresburgerSpace newSpace = *this; 48 newSpace.removeVarRange(VarKind::Domain, 0, getNumDomainVars()); 49 return newSpace; 50 } 51 52 PresburgerSpace PresburgerSpace::getSpaceWithoutLocals() const { 53 PresburgerSpace space = *this; 54 space.removeVarRange(VarKind::Local, 0, getNumLocalVars()); 55 return space; 56 } 57 58 unsigned PresburgerSpace::getNumVarKind(VarKind kind) const { 59 if (kind == VarKind::Domain) 60 return getNumDomainVars(); 61 if (kind == VarKind::Range) 62 return getNumRangeVars(); 63 if (kind == VarKind::Symbol) 64 return getNumSymbolVars(); 65 if (kind == VarKind::Local) 66 return getNumLocalVars(); 67 llvm_unreachable("VarKind does not exist!"); 68 } 69 70 unsigned PresburgerSpace::getVarKindOffset(VarKind kind) const { 71 if (kind == VarKind::Domain) 72 return 0; 73 if (kind == VarKind::Range) 74 return getNumDomainVars(); 75 if (kind == VarKind::Symbol) 76 return getNumDimVars(); 77 if (kind == VarKind::Local) 78 return getNumDimAndSymbolVars(); 79 llvm_unreachable("VarKind does not exist!"); 80 } 81 82 unsigned PresburgerSpace::getVarKindEnd(VarKind kind) const { 83 return getVarKindOffset(kind) + getNumVarKind(kind); 84 } 85 86 unsigned PresburgerSpace::getVarKindOverlap(VarKind kind, unsigned varStart, 87 unsigned varLimit) const { 88 unsigned varRangeStart = getVarKindOffset(kind); 89 unsigned varRangeEnd = getVarKindEnd(kind); 90 91 // Compute number of elements in intersection of the ranges [varStart, 92 // varLimit) and [varRangeStart, varRangeEnd). 93 unsigned overlapStart = std::max(varStart, varRangeStart); 94 unsigned overlapEnd = std::min(varLimit, varRangeEnd); 95 96 if (overlapStart > overlapEnd) 97 return 0; 98 return overlapEnd - overlapStart; 99 } 100 101 VarKind PresburgerSpace::getVarKindAt(unsigned pos) const { 102 assert(pos < getNumVars() && "`pos` should represent a valid var position"); 103 if (pos < getVarKindEnd(VarKind::Domain)) 104 return VarKind::Domain; 105 if (pos < getVarKindEnd(VarKind::Range)) 106 return VarKind::Range; 107 if (pos < getVarKindEnd(VarKind::Symbol)) 108 return VarKind::Symbol; 109 if (pos < getVarKindEnd(VarKind::Local)) 110 return VarKind::Local; 111 llvm_unreachable("`pos` should represent a valid var position"); 112 } 113 114 unsigned PresburgerSpace::insertVar(VarKind kind, unsigned pos, unsigned num) { 115 assert(pos <= getNumVarKind(kind)); 116 117 unsigned absolutePos = getVarKindOffset(kind) + pos; 118 119 if (kind == VarKind::Domain) 120 numDomain += num; 121 else if (kind == VarKind::Range) 122 numRange += num; 123 else if (kind == VarKind::Symbol) 124 numSymbols += num; 125 else 126 numLocals += num; 127 128 // Insert NULL identifiers if `usingIds` and variables inserted are 129 // not locals. 130 if (usingIds && kind != VarKind::Local) 131 identifiers.insert(identifiers.begin() + absolutePos, num, Identifier()); 132 133 return absolutePos; 134 } 135 136 void PresburgerSpace::removeVarRange(VarKind kind, unsigned varStart, 137 unsigned varLimit) { 138 assert(varLimit <= getNumVarKind(kind) && "invalid var limit"); 139 140 if (varStart >= varLimit) 141 return; 142 143 unsigned numVarsEliminated = varLimit - varStart; 144 if (kind == VarKind::Domain) 145 numDomain -= numVarsEliminated; 146 else if (kind == VarKind::Range) 147 numRange -= numVarsEliminated; 148 else if (kind == VarKind::Symbol) 149 numSymbols -= numVarsEliminated; 150 else 151 numLocals -= numVarsEliminated; 152 153 // Remove identifiers if `usingIds` and variables removed are not 154 // locals. 155 if (usingIds && kind != VarKind::Local) 156 identifiers.erase(identifiers.begin() + getVarKindOffset(kind) + varStart, 157 identifiers.begin() + getVarKindOffset(kind) + varLimit); 158 } 159 160 void PresburgerSpace::convertVarKind(VarKind srcKind, unsigned srcPos, 161 unsigned num, VarKind dstKind, 162 unsigned dstPos) { 163 assert(srcKind != dstKind && "cannot convert variables to the same kind"); 164 assert(srcPos + num <= getNumVarKind(srcKind) && 165 "invalid range for source variables"); 166 assert(dstPos <= getNumVarKind(dstKind) && 167 "invalid position for destination variables"); 168 169 // Move identifiers if `usingIds` and variables moved are not locals. 170 unsigned srcOffset = getVarKindOffset(srcKind) + srcPos; 171 unsigned dstOffset = getVarKindOffset(dstKind) + dstPos; 172 if (isUsingIds() && srcKind != VarKind::Local && dstKind != VarKind::Local) { 173 identifiers.insert(identifiers.begin() + dstOffset, num, Identifier()); 174 // Update srcOffset if insertion of new elements invalidates it. 175 if (dstOffset < srcOffset) 176 srcOffset += num; 177 std::move(identifiers.begin() + srcOffset, 178 identifiers.begin() + srcOffset + num, 179 identifiers.begin() + dstOffset); 180 identifiers.erase(identifiers.begin() + srcOffset, 181 identifiers.begin() + srcOffset + num); 182 } else if (isUsingIds() && srcKind != VarKind::Local) { 183 identifiers.erase(identifiers.begin() + srcOffset, 184 identifiers.begin() + srcOffset + num); 185 } else if (isUsingIds() && dstKind != VarKind::Local) { 186 identifiers.insert(identifiers.begin() + dstOffset, num, Identifier()); 187 } 188 189 auto addVars = [&](VarKind kind, int num) { 190 switch (kind) { 191 case VarKind::Domain: 192 numDomain += num; 193 break; 194 case VarKind::Range: 195 numRange += num; 196 break; 197 case VarKind::Symbol: 198 numSymbols += num; 199 break; 200 case VarKind::Local: 201 numLocals += num; 202 break; 203 } 204 }; 205 206 addVars(srcKind, -(signed)num); 207 addVars(dstKind, num); 208 } 209 210 void PresburgerSpace::swapVar(VarKind kindA, VarKind kindB, unsigned posA, 211 unsigned posB) { 212 if (!isUsingIds()) 213 return; 214 215 if (kindA == VarKind::Local && kindB == VarKind::Local) 216 return; 217 218 if (kindA == VarKind::Local) { 219 setId(kindB, posB, Identifier()); 220 return; 221 } 222 223 if (kindB == VarKind::Local) { 224 setId(kindA, posA, Identifier()); 225 return; 226 } 227 228 std::swap(identifiers[getVarKindOffset(kindA) + posA], 229 identifiers[getVarKindOffset(kindB) + posB]); 230 } 231 232 bool PresburgerSpace::isCompatible(const PresburgerSpace &other) const { 233 return getNumDomainVars() == other.getNumDomainVars() && 234 getNumRangeVars() == other.getNumRangeVars() && 235 getNumSymbolVars() == other.getNumSymbolVars(); 236 } 237 238 bool PresburgerSpace::isEqual(const PresburgerSpace &other) const { 239 return isCompatible(other) && getNumLocalVars() == other.getNumLocalVars(); 240 } 241 242 /// Checks if the number of ids of the given kind in the two spaces are 243 /// equal and if the ids are equal. Assumes that both spaces are using 244 /// ids. 245 static bool areIdsEqual(const PresburgerSpace &spaceA, 246 const PresburgerSpace &spaceB, VarKind kind) { 247 assert(spaceA.isUsingIds() && spaceB.isUsingIds() && 248 "Both spaces should be using ids"); 249 if (spaceA.getNumVarKind(kind) != spaceB.getNumVarKind(kind)) 250 return false; 251 if (kind == VarKind::Local) 252 return true; // No ids. 253 return spaceA.getIds(kind) == spaceB.getIds(kind); 254 } 255 256 bool PresburgerSpace::isAligned(const PresburgerSpace &other) const { 257 // If only one of the spaces is using identifiers, then they are 258 // not aligned. 259 if (isUsingIds() != other.isUsingIds()) 260 return false; 261 // If both spaces are using identifiers, then they are aligned if 262 // their identifiers are equal. Identifiers being equal implies 263 // that the number of variables of each kind is same, which implies 264 // compatiblity, so we do not check for that. 265 if (isUsingIds()) 266 return areIdsEqual(*this, other, VarKind::Domain) && 267 areIdsEqual(*this, other, VarKind::Range) && 268 areIdsEqual(*this, other, VarKind::Symbol); 269 // If neither space is using identifiers, then they are aligned if 270 // they are compatible. 271 return isCompatible(other); 272 } 273 274 bool PresburgerSpace::isAligned(const PresburgerSpace &other, 275 VarKind kind) const { 276 // If only one of the spaces is using identifiers, then they are 277 // not aligned. 278 if (isUsingIds() != other.isUsingIds()) 279 return false; 280 // If both spaces are using identifiers, then they are aligned if 281 // their identifiers are equal. Identifiers being equal implies 282 // that the number of variables of each kind is same, which implies 283 // compatiblity, so we do not check for that 284 if (isUsingIds()) 285 return areIdsEqual(*this, other, kind); 286 // If neither space is using identifiers, then they are aligned if 287 // the number of variable kind is equal. 288 return getNumVarKind(kind) == other.getNumVarKind(kind); 289 } 290 291 void PresburgerSpace::setVarSymbolSeparation(unsigned newSymbolCount) { 292 assert(newSymbolCount <= getNumDimAndSymbolVars() && 293 "invalid separation position"); 294 numRange = numRange + numSymbols - newSymbolCount; 295 numSymbols = newSymbolCount; 296 // We do not need to change `identifiers` since the ordering of 297 // `identifiers` remains same. 298 } 299 300 void PresburgerSpace::mergeAndAlignSymbols(PresburgerSpace &other) { 301 assert(usingIds && other.usingIds && 302 "Both spaces need to have identifers to merge & align"); 303 304 // First merge & align identifiers into `other` from `this`. 305 unsigned i = 0; 306 for (const Identifier identifier : getIds(VarKind::Symbol)) { 307 // If the identifier exists in `other`, then align it; otherwise insert it 308 // assuming it is a new identifier. Search in `other` starting at position 309 // `i` since the left of `i` is aligned. 310 auto *findBegin = other.getIds(VarKind::Symbol).begin() + i; 311 auto *findEnd = other.getIds(VarKind::Symbol).end(); 312 auto *itr = std::find(findBegin, findEnd, identifier); 313 if (itr != findEnd) { 314 std::swap(findBegin, itr); 315 } else { 316 other.insertVar(VarKind::Symbol, i); 317 other.setId(VarKind::Symbol, i, identifier); 318 } 319 ++i; 320 } 321 322 // Finally add identifiers that are in `other`, but not in `this` to `this`. 323 for (unsigned e = other.getNumVarKind(VarKind::Symbol); i < e; ++i) { 324 insertVar(VarKind::Symbol, i); 325 setId(VarKind::Symbol, i, other.getId(VarKind::Symbol, i)); 326 } 327 } 328 329 void PresburgerSpace::print(llvm::raw_ostream &os) const { 330 os << "Domain: " << getNumDomainVars() << ", " 331 << "Range: " << getNumRangeVars() << ", " 332 << "Symbols: " << getNumSymbolVars() << ", " 333 << "Locals: " << getNumLocalVars() << "\n"; 334 335 if (isUsingIds()) { 336 auto printIds = [&](VarKind kind) { 337 os << " "; 338 for (Identifier id : getIds(kind)) { 339 if (id.hasValue()) 340 id.print(os); 341 else 342 os << "None"; 343 os << " "; 344 } 345 }; 346 347 os << "("; 348 printIds(VarKind::Domain); 349 os << ") -> ("; 350 printIds(VarKind::Range); 351 os << ") : ["; 352 printIds(VarKind::Symbol); 353 os << "]"; 354 } 355 } 356 357 void PresburgerSpace::dump() const { 358 print(llvm::errs()); 359 llvm::errs() << "\n"; 360 } 361