xref: /llvm-project/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp (revision 6c7a3f80e75de36f2642110a077664e948d9e7e3)
1 //===- PresburgerSpace.cpp - MLIR PresburgerSpace Class -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/Presburger/PresburgerSpace.h"
10 #include "llvm/Support/ErrorHandling.h"
11 #include "llvm/Support/raw_ostream.h"
12 #include <algorithm>
13 #include <cassert>
14 
15 using namespace mlir;
16 using namespace presburger;
17 
18 bool Identifier::isEqual(const Identifier &other) const {
19   if (value == nullptr || other.value == nullptr)
20     return false;
21 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
22   assert(value != other.value ||
23          (value == other.value && idType == other.idType &&
24           "Values of Identifiers are equal but their types do not match."));
25 #endif
26   return value == other.value;
27 }
28 
29 void Identifier::print(llvm::raw_ostream &os) const {
30   os << "Id<" << value << ">";
31 }
32 
33 void Identifier::dump() const {
34   print(llvm::errs());
35   llvm::errs() << "\n";
36 }
37 
38 PresburgerSpace PresburgerSpace::getDomainSpace() const {
39   PresburgerSpace newSpace = *this;
40   newSpace.removeVarRange(VarKind::Range, 0, getNumRangeVars());
41   newSpace.convertVarKind(VarKind::Domain, 0, getNumDomainVars(),
42                           VarKind::SetDim, 0);
43   return newSpace;
44 }
45 
46 PresburgerSpace PresburgerSpace::getRangeSpace() const {
47   PresburgerSpace newSpace = *this;
48   newSpace.removeVarRange(VarKind::Domain, 0, getNumDomainVars());
49   return newSpace;
50 }
51 
52 PresburgerSpace PresburgerSpace::getSpaceWithoutLocals() const {
53   PresburgerSpace space = *this;
54   space.removeVarRange(VarKind::Local, 0, getNumLocalVars());
55   return space;
56 }
57 
58 unsigned PresburgerSpace::getNumVarKind(VarKind kind) const {
59   if (kind == VarKind::Domain)
60     return getNumDomainVars();
61   if (kind == VarKind::Range)
62     return getNumRangeVars();
63   if (kind == VarKind::Symbol)
64     return getNumSymbolVars();
65   if (kind == VarKind::Local)
66     return getNumLocalVars();
67   llvm_unreachable("VarKind does not exist!");
68 }
69 
70 unsigned PresburgerSpace::getVarKindOffset(VarKind kind) const {
71   if (kind == VarKind::Domain)
72     return 0;
73   if (kind == VarKind::Range)
74     return getNumDomainVars();
75   if (kind == VarKind::Symbol)
76     return getNumDimVars();
77   if (kind == VarKind::Local)
78     return getNumDimAndSymbolVars();
79   llvm_unreachable("VarKind does not exist!");
80 }
81 
82 unsigned PresburgerSpace::getVarKindEnd(VarKind kind) const {
83   return getVarKindOffset(kind) + getNumVarKind(kind);
84 }
85 
86 unsigned PresburgerSpace::getVarKindOverlap(VarKind kind, unsigned varStart,
87                                             unsigned varLimit) const {
88   unsigned varRangeStart = getVarKindOffset(kind);
89   unsigned varRangeEnd = getVarKindEnd(kind);
90 
91   // Compute number of elements in intersection of the ranges [varStart,
92   // varLimit) and [varRangeStart, varRangeEnd).
93   unsigned overlapStart = std::max(varStart, varRangeStart);
94   unsigned overlapEnd = std::min(varLimit, varRangeEnd);
95 
96   if (overlapStart > overlapEnd)
97     return 0;
98   return overlapEnd - overlapStart;
99 }
100 
101 VarKind PresburgerSpace::getVarKindAt(unsigned pos) const {
102   assert(pos < getNumVars() && "`pos` should represent a valid var position");
103   if (pos < getVarKindEnd(VarKind::Domain))
104     return VarKind::Domain;
105   if (pos < getVarKindEnd(VarKind::Range))
106     return VarKind::Range;
107   if (pos < getVarKindEnd(VarKind::Symbol))
108     return VarKind::Symbol;
109   if (pos < getVarKindEnd(VarKind::Local))
110     return VarKind::Local;
111   llvm_unreachable("`pos` should represent a valid var position");
112 }
113 
114 unsigned PresburgerSpace::insertVar(VarKind kind, unsigned pos, unsigned num) {
115   assert(pos <= getNumVarKind(kind));
116 
117   unsigned absolutePos = getVarKindOffset(kind) + pos;
118 
119   if (kind == VarKind::Domain)
120     numDomain += num;
121   else if (kind == VarKind::Range)
122     numRange += num;
123   else if (kind == VarKind::Symbol)
124     numSymbols += num;
125   else
126     numLocals += num;
127 
128   // Insert NULL identifiers if `usingIds` and variables inserted are
129   // not locals.
130   if (usingIds && kind != VarKind::Local)
131     identifiers.insert(identifiers.begin() + absolutePos, num, Identifier());
132 
133   return absolutePos;
134 }
135 
136 void PresburgerSpace::removeVarRange(VarKind kind, unsigned varStart,
137                                      unsigned varLimit) {
138   assert(varLimit <= getNumVarKind(kind) && "invalid var limit");
139 
140   if (varStart >= varLimit)
141     return;
142 
143   unsigned numVarsEliminated = varLimit - varStart;
144   if (kind == VarKind::Domain)
145     numDomain -= numVarsEliminated;
146   else if (kind == VarKind::Range)
147     numRange -= numVarsEliminated;
148   else if (kind == VarKind::Symbol)
149     numSymbols -= numVarsEliminated;
150   else
151     numLocals -= numVarsEliminated;
152 
153   // Remove identifiers if `usingIds` and variables removed are not
154   // locals.
155   if (usingIds && kind != VarKind::Local)
156     identifiers.erase(identifiers.begin() + getVarKindOffset(kind) + varStart,
157                       identifiers.begin() + getVarKindOffset(kind) + varLimit);
158 }
159 
160 void PresburgerSpace::convertVarKind(VarKind srcKind, unsigned srcPos,
161                                      unsigned num, VarKind dstKind,
162                                      unsigned dstPos) {
163   assert(srcKind != dstKind && "cannot convert variables to the same kind");
164   assert(srcPos + num <= getNumVarKind(srcKind) &&
165          "invalid range for source variables");
166   assert(dstPos <= getNumVarKind(dstKind) &&
167          "invalid position for destination variables");
168 
169   // Move identifiers if `usingIds` and variables moved are not locals.
170   unsigned srcOffset = getVarKindOffset(srcKind) + srcPos;
171   unsigned dstOffset = getVarKindOffset(dstKind) + dstPos;
172   if (isUsingIds() && srcKind != VarKind::Local && dstKind != VarKind::Local) {
173     identifiers.insert(identifiers.begin() + dstOffset, num, Identifier());
174     // Update srcOffset if insertion of new elements invalidates it.
175     if (dstOffset < srcOffset)
176       srcOffset += num;
177     std::move(identifiers.begin() + srcOffset,
178               identifiers.begin() + srcOffset + num,
179               identifiers.begin() + dstOffset);
180     identifiers.erase(identifiers.begin() + srcOffset,
181                       identifiers.begin() + srcOffset + num);
182   } else if (isUsingIds() && srcKind != VarKind::Local) {
183     identifiers.erase(identifiers.begin() + srcOffset,
184                       identifiers.begin() + srcOffset + num);
185   } else if (isUsingIds() && dstKind != VarKind::Local) {
186     identifiers.insert(identifiers.begin() + dstOffset, num, Identifier());
187   }
188 
189   auto addVars = [&](VarKind kind, int num) {
190     switch (kind) {
191     case VarKind::Domain:
192       numDomain += num;
193       break;
194     case VarKind::Range:
195       numRange += num;
196       break;
197     case VarKind::Symbol:
198       numSymbols += num;
199       break;
200     case VarKind::Local:
201       numLocals += num;
202       break;
203     }
204   };
205 
206   addVars(srcKind, -(signed)num);
207   addVars(dstKind, num);
208 }
209 
210 void PresburgerSpace::swapVar(VarKind kindA, VarKind kindB, unsigned posA,
211                               unsigned posB) {
212   if (!isUsingIds())
213     return;
214 
215   if (kindA == VarKind::Local && kindB == VarKind::Local)
216     return;
217 
218   if (kindA == VarKind::Local) {
219     setId(kindB, posB, Identifier());
220     return;
221   }
222 
223   if (kindB == VarKind::Local) {
224     setId(kindA, posA, Identifier());
225     return;
226   }
227 
228   std::swap(identifiers[getVarKindOffset(kindA) + posA],
229             identifiers[getVarKindOffset(kindB) + posB]);
230 }
231 
232 bool PresburgerSpace::isCompatible(const PresburgerSpace &other) const {
233   return getNumDomainVars() == other.getNumDomainVars() &&
234          getNumRangeVars() == other.getNumRangeVars() &&
235          getNumSymbolVars() == other.getNumSymbolVars();
236 }
237 
238 bool PresburgerSpace::isEqual(const PresburgerSpace &other) const {
239   return isCompatible(other) && getNumLocalVars() == other.getNumLocalVars();
240 }
241 
242 /// Checks if the number of ids of the given kind in the two spaces are
243 /// equal and if the ids are equal. Assumes that both spaces are using
244 /// ids.
245 static bool areIdsEqual(const PresburgerSpace &spaceA,
246                         const PresburgerSpace &spaceB, VarKind kind) {
247   assert(spaceA.isUsingIds() && spaceB.isUsingIds() &&
248          "Both spaces should be using ids");
249   if (spaceA.getNumVarKind(kind) != spaceB.getNumVarKind(kind))
250     return false;
251   if (kind == VarKind::Local)
252     return true; // No ids.
253   return spaceA.getIds(kind) == spaceB.getIds(kind);
254 }
255 
256 bool PresburgerSpace::isAligned(const PresburgerSpace &other) const {
257   // If only one of the spaces is using identifiers, then they are
258   // not aligned.
259   if (isUsingIds() != other.isUsingIds())
260     return false;
261   // If both spaces are using identifiers, then they are aligned if
262   // their identifiers are equal. Identifiers being equal implies
263   // that the number of variables of each kind is same, which implies
264   // compatiblity, so we do not check for that.
265   if (isUsingIds())
266     return areIdsEqual(*this, other, VarKind::Domain) &&
267            areIdsEqual(*this, other, VarKind::Range) &&
268            areIdsEqual(*this, other, VarKind::Symbol);
269   // If neither space is using identifiers, then they are aligned if
270   // they are compatible.
271   return isCompatible(other);
272 }
273 
274 bool PresburgerSpace::isAligned(const PresburgerSpace &other,
275                                 VarKind kind) const {
276   // If only one of the spaces is using identifiers, then they are
277   // not aligned.
278   if (isUsingIds() != other.isUsingIds())
279     return false;
280   // If both spaces are using identifiers, then they are aligned if
281   // their identifiers are equal. Identifiers being equal implies
282   // that the number of variables of each kind is same, which implies
283   // compatiblity, so we do not check for that
284   if (isUsingIds())
285     return areIdsEqual(*this, other, kind);
286   // If neither space is using identifiers, then they are aligned if
287   // the number of variable kind is equal.
288   return getNumVarKind(kind) == other.getNumVarKind(kind);
289 }
290 
291 void PresburgerSpace::setVarSymbolSeparation(unsigned newSymbolCount) {
292   assert(newSymbolCount <= getNumDimAndSymbolVars() &&
293          "invalid separation position");
294   numRange = numRange + numSymbols - newSymbolCount;
295   numSymbols = newSymbolCount;
296   // We do not need to change `identifiers` since the ordering of
297   // `identifiers` remains same.
298 }
299 
300 void PresburgerSpace::mergeAndAlignSymbols(PresburgerSpace &other) {
301   assert(usingIds && other.usingIds &&
302          "Both spaces need to have identifers to merge & align");
303 
304   // First merge & align identifiers into `other` from `this`.
305   unsigned i = 0;
306   for (const Identifier identifier : getIds(VarKind::Symbol)) {
307     // If the identifier exists in `other`, then align it; otherwise insert it
308     // assuming it is a new identifier. Search in `other` starting at position
309     // `i` since the left of `i` is aligned.
310     auto *findBegin = other.getIds(VarKind::Symbol).begin() + i;
311     auto *findEnd = other.getIds(VarKind::Symbol).end();
312     auto *itr = std::find(findBegin, findEnd, identifier);
313     if (itr != findEnd) {
314       std::swap(findBegin, itr);
315     } else {
316       other.insertVar(VarKind::Symbol, i);
317       other.setId(VarKind::Symbol, i, identifier);
318     }
319     ++i;
320   }
321 
322   // Finally add identifiers that are in `other`, but not in `this` to `this`.
323   for (unsigned e = other.getNumVarKind(VarKind::Symbol); i < e; ++i) {
324     insertVar(VarKind::Symbol, i);
325     setId(VarKind::Symbol, i, other.getId(VarKind::Symbol, i));
326   }
327 }
328 
329 void PresburgerSpace::print(llvm::raw_ostream &os) const {
330   os << "Domain: " << getNumDomainVars() << ", "
331      << "Range: " << getNumRangeVars() << ", "
332      << "Symbols: " << getNumSymbolVars() << ", "
333      << "Locals: " << getNumLocalVars() << "\n";
334 
335   if (isUsingIds()) {
336     auto printIds = [&](VarKind kind) {
337       os << " ";
338       for (Identifier id : getIds(kind)) {
339         if (id.hasValue())
340           id.print(os);
341         else
342           os << "None";
343         os << " ";
344       }
345     };
346 
347     os << "(";
348     printIds(VarKind::Domain);
349     os << ") -> (";
350     printIds(VarKind::Range);
351     os << ") : [";
352     printIds(VarKind::Symbol);
353     os << "]";
354   }
355 }
356 
357 void PresburgerSpace::dump() const {
358   print(llvm::errs());
359   llvm::errs() << "\n";
360 }
361