xref: /llvm-project/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp (revision 7557530f428a2f226d8d925c33d527dfcfdcb0c5)
//===- PresburgerSetTest.cpp - Tests for PresburgerSet --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file contains tests for PresburgerSet. The tests for union,
// intersection, subtract, and complement work by computing the operation on
// one or two sets and checking, for a set of points, that the resulting set
// contains the point iff the result is supposed to contain it. The test for
// isEqual just checks if the result for two sets matches the expected result.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "Parser.h"
18 #include "Utils.h"
19 #include "mlir/Analysis/Presburger/PresburgerRelation.h"
20 #include "mlir/IR/MLIRContext.h"
21 
22 #include <gmock/gmock.h>
23 #include <gtest/gtest.h>
24 #include <optional>
25 
26 using namespace mlir;
27 using namespace presburger;
28 
29 /// Compute the union of s and t, and check that each of the given points
30 /// belongs to the union iff it belongs to at least one of s and t.
testUnionAtPoints(const PresburgerSet & s,const PresburgerSet & t,ArrayRef<SmallVector<int64_t,4>> points)31 static void testUnionAtPoints(const PresburgerSet &s, const PresburgerSet &t,
32                               ArrayRef<SmallVector<int64_t, 4>> points) {
33   PresburgerSet unionSet = s.unionSet(t);
34   for (const SmallVector<int64_t, 4> &point : points) {
35     bool inS = s.containsPoint(point);
36     bool inT = t.containsPoint(point);
37     bool inUnion = unionSet.containsPoint(point);
38     EXPECT_EQ(inUnion, inS || inT);
39   }
40 }
41 
42 /// Compute the intersection of s and t, and check that each of the given points
43 /// belongs to the intersection iff it belongs to both s and t.
testIntersectAtPoints(const PresburgerSet & s,const PresburgerSet & t,ArrayRef<SmallVector<int64_t,4>> points)44 static void testIntersectAtPoints(const PresburgerSet &s,
45                                   const PresburgerSet &t,
46                                   ArrayRef<SmallVector<int64_t, 4>> points) {
47   PresburgerSet intersection = s.intersect(t);
48   for (const SmallVector<int64_t, 4> &point : points) {
49     bool inS = s.containsPoint(point);
50     bool inT = t.containsPoint(point);
51     bool inIntersection = intersection.containsPoint(point);
52     EXPECT_EQ(inIntersection, inS && inT);
53   }
54 }
55 
56 /// Compute the set difference s \ t, and check that each of the given points
57 /// belongs to the difference iff it belongs to s and does not belong to t.
testSubtractAtPoints(const PresburgerSet & s,const PresburgerSet & t,ArrayRef<SmallVector<int64_t,4>> points)58 static void testSubtractAtPoints(const PresburgerSet &s, const PresburgerSet &t,
59                                  ArrayRef<SmallVector<int64_t, 4>> points) {
60   PresburgerSet diff = s.subtract(t);
61   for (const SmallVector<int64_t, 4> &point : points) {
62     bool inS = s.containsPoint(point);
63     bool inT = t.containsPoint(point);
64     bool inDiff = diff.containsPoint(point);
65     if (inT)
66       EXPECT_FALSE(inDiff);
67     else
68       EXPECT_EQ(inDiff, inS);
69   }
70 }
71 
72 /// Compute the complement of s, and check that each of the given points
73 /// belongs to the complement iff it does not belong to s.
testComplementAtPoints(const PresburgerSet & s,ArrayRef<SmallVector<int64_t,4>> points)74 static void testComplementAtPoints(const PresburgerSet &s,
75                                    ArrayRef<SmallVector<int64_t, 4>> points) {
76   PresburgerSet complement = s.complement();
77   complement.complement();
78   for (const SmallVector<int64_t, 4> &point : points) {
79     bool inS = s.containsPoint(point);
80     bool inComplement = complement.containsPoint(point);
81     if (inS)
82       EXPECT_FALSE(inComplement);
83     else
84       EXPECT_TRUE(inComplement);
85   }
86 }
87 
88 /// Construct a PresburgerSet having `numDims` dimensions and no symbols from
89 /// the given list of IntegerPolyhedron. Each Poly in `polys` should also have
90 /// `numDims` dimensions and no symbols, although it can have any number of
91 /// local ids.
makeSetFromPoly(unsigned numDims,ArrayRef<IntegerPolyhedron> polys)92 static PresburgerSet makeSetFromPoly(unsigned numDims,
93                                      ArrayRef<IntegerPolyhedron> polys) {
94   PresburgerSet set =
95       PresburgerSet::getEmpty(PresburgerSpace::getSetSpace(numDims));
96   for (const IntegerPolyhedron &poly : polys)
97     set.unionInPlace(poly);
98   return set;
99 }
100 
TEST(SetTest, containsPoint) {
  // setA = [2, 8] U [10, 20].
  PresburgerSet setA = parsePresburgerSet(
      {"(x) : (x - 2 >= 0, -x + 8 >= 0)", "(x) : (x - 10 >= 0, -x + 20 >= 0)"});
  for (int64_t x = 0; x <= 21; ++x) {
    if ((2 <= x && x <= 8) || (10 <= x && x <= 20))
      EXPECT_TRUE(setA.containsPoint({x}));
    else
      EXPECT_FALSE(setA.containsPoint({x}));
  }

  // A parallelogram with vertices {(3, 1), (10, -6), (24, 8), (17, 15)} union
  // a square with opposite corners (2, 2) and (10, 10).
  PresburgerSet setB = parsePresburgerSet(
      {"(x,y) : (x + y - 4 >= 0, -x - y + 32 >= 0, "
       "x - y - 2 >= 0, -x + y + 16 >= 0)",
       "(x,y) : (x - 2 >= 0, y - 2 >= 0, -x + 10 >= 0, -y + 10 >= 0)"});

  // The loop variables must be signed: y starts at a negative value, and the
  // membership conditions below compute x - y. With `unsigned`, the
  // initializer -6 wrapped around and the inner loop never executed.
  for (int64_t x = 1; x <= 25; ++x) {
    for (int64_t y = -6; y <= 16; ++y) {
      if (4 <= x + y && x + y <= 32 && 2 <= x - y && x - y <= 16)
        EXPECT_TRUE(setB.containsPoint({x, y}));
      else if (2 <= x && x <= 10 && 2 <= y && y <= 10)
        EXPECT_TRUE(setB.containsPoint({x, y}));
      else
        EXPECT_FALSE(setB.containsPoint({x, y}));
    }
  }

  // The PresburgerSet has only one variable, x, so we supply one value.
  EXPECT_TRUE(
      PresburgerSet(parseIntegerPolyhedron("(x) : (x - 2*(x floordiv 2) == 0)"))
          .containsPoint({0}));
}
134 
TEST(SetTest, Union) {
  const PresburgerSpace space = PresburgerSpace::getSetSpace(1);
  const PresburgerSet universe = PresburgerSet::getUniverse(space);
  const PresburgerSet emptySet = PresburgerSet::getEmpty(space);
  // set = [2, 8] U [10, 20].
  const PresburgerSet set = parsePresburgerSet(
      {"(x) : (x - 2 >= 0, -x + 8 >= 0)", "(x) : (x - 10 >= 0, -x + 20 >= 0)"});

  // Universe union set.
  testUnionAtPoints(universe, set, {{1}, {2}, {8}, {9}, {10}, {20}, {21}});
  // Empty set union set.
  testUnionAtPoints(emptySet, set, {{1}, {2}, {8}, {9}, {10}, {20}, {21}});
  // Empty set union universe.
  testUnionAtPoints(emptySet, universe, {{1}, {2}, {0}, {-1}});
  // Universe union empty set.
  testUnionAtPoints(universe, emptySet, {{1}, {2}, {0}, {-1}});
  // Empty set union empty set.
  testUnionAtPoints(emptySet, emptySet, {{1}, {2}, {0}, {-1}});
}
162 
TEST(SetTest, Intersect) {
  const PresburgerSpace space = PresburgerSpace::getSetSpace(1);
  const PresburgerSet universe = PresburgerSet::getUniverse(space);
  const PresburgerSet emptySet = PresburgerSet::getEmpty(space);
  // set = [2, 8] U [10, 20].
  const PresburgerSet set = parsePresburgerSet(
      {"(x) : (x - 2 >= 0, -x + 8 >= 0)", "(x) : (x - 10 >= 0, -x + 20 >= 0)"});

  // Universe intersection set.
  testIntersectAtPoints(universe, set, {{1}, {2}, {8}, {9}, {10}, {20}, {21}});
  // Empty set intersection set.
  testIntersectAtPoints(emptySet, set, {{1}, {2}, {8}, {9}, {10}, {20}, {21}});
  // Empty set intersection universe.
  testIntersectAtPoints(emptySet, universe, {{1}, {2}, {0}, {-1}});
  // Universe intersection empty set.
  testIntersectAtPoints(universe, emptySet, {{1}, {2}, {0}, {-1}});
  // Universe intersection universe.
  testIntersectAtPoints(universe, universe, {{1}, {2}, {0}, {-1}});
}
195 
TEST(SetTest,Subtract)196 TEST(SetTest, Subtract) {
197   // The interval [2, 8] minus the interval [10, 20].
198   testSubtractAtPoints(
199       parsePresburgerSet({"(x) : (x - 2 >= 0, -x + 8 >= 0)"}),
200       parsePresburgerSet({"(x) : (x - 10 >= 0, -x + 20 >= 0)"}),
201       {{1}, {2}, {8}, {9}, {10}, {20}, {21}});
202 
203   // Universe minus [2, 8] U [10, 20]
204   testSubtractAtPoints(
205       parsePresburgerSet({"(x) : ()"}),
206       parsePresburgerSet({"(x) : (x - 2 >= 0, -x + 8 >= 0)",
207                           "(x) : (x - 10 >= 0, -x + 20 >= 0)"}),
208       {{1}, {2}, {8}, {9}, {10}, {20}, {21}});
209 
210   // ((-infinity, 0] U [3, 4] U [6, 7]) - ([2, 3] U [5, 6])
211   testSubtractAtPoints(
212       parsePresburgerSet({"(x) : (-x >= 0)", "(x) : (x - 3 >= 0, -x + 4 >= 0)",
213                           "(x) : (x - 6 >= 0, -x + 7 >= 0)"}),
214       parsePresburgerSet({"(x) : (x - 2 >= 0, -x + 3 >= 0)",
215                           "(x) : (x - 5 >= 0, -x + 6 >= 0)"}),
216       {{0}, {1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}});
217 
218   // Expected result is {[x, y] : x > y}, i.e., {[x, y] : x >= y + 1}.
219   testSubtractAtPoints(parsePresburgerSet({"(x, y) : (x - y >= 0)"}),
220                        parsePresburgerSet({"(x, y) : (x + y >= 0)"}),
221                        {{0, 1}, {1, 1}, {1, 0}, {1, -1}, {0, -1}});
222 
223   // A rectangle with corners at (2, 2) and (10, 10), minus
224   // a rectangle with corners at (5, -10) and (7, 100).
225   // This splits the former rectangle into two halves, (2, 2) to (5, 10) and
226   // (7, 2) to (10, 10).
227   testSubtractAtPoints(
228       parsePresburgerSet({
229           "(x, y) : (x - 2 >= 0, y - 2 >= 0, -x + 10 >= 0, -y + 10 >= 0)",
230       }),
231       parsePresburgerSet({
232           "(x, y) : (x - 5 >= 0, y + 10 >= 0, -x + 7 >= 0, -y + 100 >= 0)",
233       }),
234       {{1, 2},  {2, 2},  {4, 2},  {5, 2},  {7, 2},  {8, 2},  {11, 2},
235        {1, 1},  {2, 1},  {4, 1},  {5, 1},  {7, 1},  {8, 1},  {11, 1},
236        {1, 10}, {2, 10}, {4, 10}, {5, 10}, {7, 10}, {8, 10}, {11, 10},
237        {1, 11}, {2, 11}, {4, 11}, {5, 11}, {7, 11}, {8, 11}, {11, 11}});
238 
239   // A rectangle with corners at (2, 2) and (10, 10), minus
240   // a rectangle with corners at (5, 4) and (7, 8).
241   // This creates a hole in the middle of the former rectangle, and the
242   // resulting set can be represented as a union of four rectangles.
243   testSubtractAtPoints(
244       parsePresburgerSet(
245           {"(x, y) : (x - 2 >= 0, y -2 >= 0, -x + 10 >= 0, -y + 10 >= 0)"}),
246       parsePresburgerSet({
247           "(x, y) : (x - 5 >= 0, y - 4 >= 0, -x + 7 >= 0, -y + 8 >= 0)",
248       }),
249       {{1, 1},
250        {2, 2},
251        {10, 10},
252        {11, 11},
253        {5, 4},
254        {7, 4},
255        {5, 8},
256        {7, 8},
257        {4, 4},
258        {8, 4},
259        {4, 8},
260        {8, 8}});
261 
262   // The second set is a superset of the first one, since on the line x + y = 0,
263   // y <= 1 is equivalent to x >= -1. So the result is empty.
264   testSubtractAtPoints(
265       parsePresburgerSet({"(x, y) : (x >= 0, x + y == 0)"}),
266       parsePresburgerSet({"(x, y) : (-y + 1 >= 0, x + y == 0)"}),
267       {{0, 0},
268        {1, -1},
269        {2, -2},
270        {-1, 1},
271        {-2, 2},
272        {1, 1},
273        {-1, -1},
274        {-1, 1},
275        {1, -1}});
276 
277   // The result should be {0} U {2}.
278   testSubtractAtPoints(parsePresburgerSet({"(x) : (x >= 0, -x + 2 >= 0)"}),
279                        parsePresburgerSet({"(x) : (x - 1 == 0)"}),
280                        {{-1}, {0}, {1}, {2}, {3}});
281 
282   // Sets with lots of redundant inequalities to test the redundancy heuristic.
283   // (the heuristic is for the subtrahend, the second set which is the one being
284   // subtracted)
285 
286   // A parallelogram with vertices {(3, 1), (10, -6), (24, 8), (17, 15)} minus
287   // a triangle with vertices {(2, 2), (10, 2), (10, 10)}.
288   testSubtractAtPoints(
289       parsePresburgerSet({
290           "(x, y) : (x + y - 4 >= 0, -x - y + 32 >= 0, x - y - 2 >= 0, "
291           "-x + y + 16 >= 0)",
292       }),
293       parsePresburgerSet(
294           {"(x, y) : (x - 2 >= 0, y - 2 >= 0, -x + 10 >= 0, "
295            "-y + 10 >= 0, x + y - 2 >= 0, -x - y + 30 >= 0, x - y >= 0, "
296            "-x + y + 10 >= 0)"}),
297       {{1, 2},  {2, 2},   {3, 2},   {4, 2},  {1, 1},   {2, 1},   {3, 1},
298        {4, 1},  {2, 0},   {3, 0},   {4, 0},  {5, 0},   {10, 2},  {11, 2},
299        {10, 1}, {10, 10}, {10, 11}, {10, 9}, {11, 10}, {10, -6}, {11, -6},
300        {24, 8}, {24, 7},  {17, 15}, {16, 15}});
301 
302   // ((-infinity, -5] U [3, 3] U [4, 4] U [5, 5]) - ([-2, -10] U [3, 4] U [6,
303   // 7])
304   testSubtractAtPoints(
305       parsePresburgerSet({"(x) : (-x - 5 >= 0)", "(x) : (x - 3 == 0)",
306                           "(x) : (x - 4 == 0)", "(x) : (x - 5 == 0)"}),
307       parsePresburgerSet(
308           {"(x) : (-x - 2 >= 0, x - 10 >= 0, -x >= 0, -x + 10 >= 0, "
309            "x - 100 >= 0, x - 50 >= 0)",
310            "(x) : (x - 3 >= 0, -x + 4 >= 0, x + 1 >= 0, "
311            "x + 7 >= 0, -x + 10 >= 0)",
312            "(x) : (x - 6 >= 0, -x + 7 >= 0, x + 1 >= 0, x - 3 >= 0, "
313            "-x + 5 >= 0)"}),
314       {{-6},
315        {-5},
316        {-4},
317        {-9},
318        {-10},
319        {-11},
320        {0},
321        {1},
322        {2},
323        {3},
324        {4},
325        {5},
326        {6},
327        {7},
328        {8}});
329 }
330 
TEST(SetTest, Complement) {
  const PresburgerSpace space = PresburgerSpace::getSetSpace(1);

  // Complement of universe.
  testComplementAtPoints(
      PresburgerSet::getUniverse(space),
      {{-1}, {-2}, {-8}, {1}, {2}, {8}, {9}, {10}, {20}, {21}});

  // Complement of empty set.
  testComplementAtPoints(
      PresburgerSet::getEmpty(space),
      {{-1}, {-2}, {-8}, {1}, {2}, {8}, {9}, {10}, {20}, {21}});

  // Complement of the square with opposite corners (2, 2) and (10, 10);
  // the sample points straddle its corners and edges.
  const PresburgerSet square =
      parsePresburgerSet({"(x,y) : (x - 2 >= 0, y - 2 >= 0, "
                          "-x + 10 >= 0, -y + 10 >= 0)"});
  testComplementAtPoints(square, {{1, 1},
                                  {2, 1},
                                  {1, 2},
                                  {2, 2},
                                  {2, 3},
                                  {3, 2},
                                  {10, 10},
                                  {10, 11},
                                  {11, 10},
                                  {2, 10},
                                  {2, 11},
                                  {1, 10}});
}
357 
TEST(SetTest, isEqual) {
  PresburgerSet universe =
      PresburgerSet::getUniverse(PresburgerSpace::getSetSpace(1));
  PresburgerSet emptySet =
      PresburgerSet::getEmpty(PresburgerSpace::getSetSpace(1));
  // set = [2, 8] U [10, 20].
  PresburgerSet set = parsePresburgerSet(
      {"(x) : (x - 2 >= 0, -x + 8 >= 0)", "(x) : (x - 10 >= 0, -x + 20 >= 0)"});

  // universe != emptySet.
  EXPECT_FALSE(universe.isEqual(emptySet));
  // emptySet != universe.
  EXPECT_FALSE(emptySet.isEqual(universe));
  // emptySet == emptySet.
  EXPECT_TRUE(emptySet.isEqual(emptySet));
  // universe == universe.
  EXPECT_TRUE(universe.isEqual(universe));

  // universe U emptySet == universe.
  EXPECT_TRUE(universe.unionSet(emptySet).isEqual(universe));
  // universe U universe == universe.
  EXPECT_TRUE(universe.unionSet(universe).isEqual(universe));
  // emptySet U emptySet == emptySet.
  EXPECT_TRUE(emptySet.unionSet(emptySet).isEqual(emptySet));
  // universe U emptySet != emptySet.
  EXPECT_FALSE(universe.unionSet(emptySet).isEqual(emptySet));
  // universe U universe != emptySet.
  EXPECT_FALSE(universe.unionSet(universe).isEqual(emptySet));
  // emptySet U emptySet != universe.
  EXPECT_FALSE(emptySet.unionSet(emptySet).isEqual(universe));

  // set \ set == emptySet.
  EXPECT_TRUE(set.subtract(set).isEqual(emptySet));
  // set == set.
  EXPECT_TRUE(set.isEqual(set));
  // set U (universe \ set) == universe.
  EXPECT_TRUE(set.unionSet(set.complement()).isEqual(universe));
  // set U (universe \ set) != set.
  EXPECT_FALSE(set.unionSet(set.complement()).isEqual(set));
  // set != set U (universe \ set).
  EXPECT_FALSE(set.isEqual(set.unionSet(set.complement())));

  // square is one unit taller than rect.
  PresburgerSet square = parsePresburgerSet(
      {"(x, y) : (x - 2 >= 0, y - 2 >= 0, -x + 9 >= 0, -y + 9 >= 0)"});
  PresburgerSet rect = parsePresburgerSet(
      {"(x, y) : (x - 2 >= 0, y - 2 >= 0, -x + 9 >= 0, -y + 8 >= 0)"});
  EXPECT_FALSE(square.isEqual(rect));
  // Each shape united with its own complement is the universe. The names
  // record which shape each universe was built from (previously swapped).
  PresburgerSet universeSquare = square.unionSet(square.complement());
  PresburgerSet universeRect = rect.unionSet(rect.complement());
  EXPECT_TRUE(universeRect.isEqual(universeSquare));
  EXPECT_FALSE(universeRect.isEqual(rect));
  EXPECT_FALSE(universeSquare.isEqual(square));
  EXPECT_FALSE(rect.complement().isEqual(square.complement()));
}
413 
expectEqual(const PresburgerSet & s,const PresburgerSet & t)414 void expectEqual(const PresburgerSet &s, const PresburgerSet &t) {
415   EXPECT_TRUE(s.isEqual(t));
416 }
417 
expectEqual(const IntegerPolyhedron & s,const IntegerPolyhedron & t)418 void expectEqual(const IntegerPolyhedron &s, const IntegerPolyhedron &t) {
419   EXPECT_TRUE(s.isEqual(t));
420 }
421 
expectEmpty(const PresburgerSet & s)422 void expectEmpty(const PresburgerSet &s) { EXPECT_TRUE(s.isIntegerEmpty()); }
423 
TEST(SetTest, divisions) {
  // evens = {x : exists q, x = 2q}.
  PresburgerSet evens{
      parseIntegerPolyhedron("(x) : (x - 2 * (x floordiv 2) == 0)")};
  // odds = {x : exists q, x = 2q + 1}.
  PresburgerSet odds{
      parseIntegerPolyhedron("(x) : (x - 2 * (x floordiv 2) - 1 == 0)")};
  // multiples3 = {x : exists q, x = 3q}.
  PresburgerSet multiples3{
      parseIntegerPolyhedron("(x) : (x - 3 * (x floordiv 3) == 0)")};
  // multiples6 = {x : exists q, x = 6q}.
  PresburgerSet multiples6{
      parseIntegerPolyhedron("(x) : (x - 6 * (x floordiv 6) == 0)")};

  PresburgerSet universe =
      PresburgerSet::getUniverse(PresburgerSpace::getSetSpace(1));

  // Evens and odds are disjoint and together cover every integer, so each
  // is the complement of the other.
  expectEmpty(evens.intersect(odds));
  expectEqual(evens.unionSet(odds), universe);
  expectEqual(evens.complement(), odds);
  expectEqual(odds.complement(), evens);
  // Even multiples of 3 are exactly the multiples of 6.
  expectEqual(multiples3.intersect(evens), multiples6);

  // {x : x <= 0} and {x : x floordiv 2 >= 4}, i.e. x >= 8, share no point,
  // so the subtraction leaves the first set unchanged.
  PresburgerSet nonPositive{parseIntegerPolyhedron("(x) : (-x >= 0)")};
  PresburgerSet atLeastEight{
      parseIntegerPolyhedron("(x) : (x floordiv 2 - 4 >= 0)")};
  EXPECT_TRUE(nonPositive.subtract(atLeastEight).isEqual(nonPositive));
}
455 
convertSuffixDimsToLocals(IntegerPolyhedron & poly,unsigned numLocals)456 void convertSuffixDimsToLocals(IntegerPolyhedron &poly, unsigned numLocals) {
457   poly.convertVarKind(VarKind::SetDim, poly.getNumDimVars() - numLocals,
458                       poly.getNumDimVars(), VarKind::Local);
459 }
460 
461 inline IntegerPolyhedron
parseIntegerPolyhedronAndMakeLocals(StringRef str,unsigned numLocals)462 parseIntegerPolyhedronAndMakeLocals(StringRef str, unsigned numLocals) {
463   IntegerPolyhedron poly = parseIntegerPolyhedron(str);
464   convertSuffixDimsToLocals(poly, numLocals);
465   return poly;
466 }
467 
TEST(SetTest, divisionsDefByEq) {
  // Same checks as the `divisions` test. Here each quotient q is written as an
  // ordinary dimension y, which is then turned into a local variable, so it is
  // constrained by an equality instead of by a floordiv.

  // evens = {x : exists q, x = 2q}.
  PresburgerSet evens{parseIntegerPolyhedronAndMakeLocals(
      "(x, y) : (x - 2 * y == 0)", /*numLocals=*/1)};
  // odds = {x : exists q, x = 2q + 1}.
  PresburgerSet odds{parseIntegerPolyhedronAndMakeLocals(
      "(x, y) : (x - 2 * y - 1 == 0)", /*numLocals=*/1)};
  // multiples3 = {x : exists q, x = 3q}.
  PresburgerSet multiples3{parseIntegerPolyhedronAndMakeLocals(
      "(x, y) : (x - 3 * y == 0)", /*numLocals=*/1)};
  // multiples6 = {x : exists q, x = 6q}.
  PresburgerSet multiples6{parseIntegerPolyhedronAndMakeLocals(
      "(x, y) : (x - 6 * y == 0)", /*numLocals=*/1)};

  PresburgerSet universe =
      PresburgerSet::getUniverse(PresburgerSpace::getSetSpace(1));

  // Evens and odds are disjoint and together cover every integer.
  expectEmpty(evens.intersect(odds));
  expectEqual(evens.unionSet(odds), universe);
  expectEqual(evens.complement(), odds);
  expectEqual(odds.complement(), evens);
  // Even multiples of 3 are exactly the multiples of 6.
  expectEqual(multiples3.intersect(evens), multiples6);

  // The floordiv-based definition of the evens describes the same set.
  PresburgerSet evensDefByIneq{
      parseIntegerPolyhedron("(x) : (x - 2 * (x floordiv 2) == 0)")};
  expectEqual(evens, evensDefByIneq);
}
499 
TEST(SetTest, divisionNonDivLocals) {
  // A tetrahedron with vertices at
  // (1/3, 0, 0), (2/3, 0, 0), (2/3, 0, 1000), and (1000, 1000, 1000), whose
  // only integer point is (1000, 1000, 1000). Turning z into a local variable
  // projects it onto the xy plane.
  IntegerPolyhedron tetrahedron = parseIntegerPolyhedronAndMakeLocals(
      "(x, y, z) : (y >= 0, z - y >= 0, 3000*x - 2998*y "
      "- 1000 - z >= 0, -1500*x + 1499*y + 1000 >= 0)",
      /*numLocals=*/1);

  // A triangle with vertices at (1/3, 0), (2/3, 0) and (1000, 1000), whose
  // only integer point is (1000, 1000). It also happens to be the projection
  // of the tetrahedron above onto the xy plane.
  IntegerPolyhedron triangle =
      parseIntegerPolyhedron("(x,y) : (y >= 0, 3000 * x - 2999 * y - 1000 >= "
                             "0, -3000 * x + 2998 * y + 2000 >= 0)");
  EXPECT_TRUE(triangle.containsPoint({1000, 1000}));
  EXPECT_FALSE(triangle.containsPoint({1001, 1001}));
  expectEqual(triangle, tetrahedron);

  // Projecting the triangle further onto x leaves only x = 1000.
  convertSuffixDimsToLocals(triangle, 1);
  IntegerPolyhedron line = parseIntegerPolyhedron("(x) : (x - 1000 == 0)");
  expectEqual(line, triangle);

  // Triangle with vertices (0, 0), (5, 0), (15, 5). Projected on x, it becomes
  // [0, 13] U {15}: it is too narrow near the apex to contain an integer point
  // at x = 14, while at x = 15 the apex itself is an integer point.
  PresburgerSet narrowTriangle{
      parseIntegerPolyhedronAndMakeLocals("(x,y) : (y >= 0, "
                                          "x - 3*y >= 0, "
                                          "2*y - x + 5 >= 0)",
                                          /*numLocals=*/1)};
  PresburgerSet zeroToThirteen{
      parseIntegerPolyhedron("(x) : (13 - x >= 0, x >= 0)")};
  PresburgerSet fifteen{parseIntegerPolyhedron("(x) : (x - 15 == 0)")};
  expectEqual(narrowTriangle.subtract(zeroToThirteen), fifteen);
}
541 
TEST(SetTest, subtractDuplicateDivsRegression) {
  // Regression test: subtracting sets with duplicate divs used to crash.
  // Merging local variables could identify two existing divs as duplicates
  // for the first time and remove one of them.
  //
  // The dividend vectors appear to be laid out as [x, <existing locals>,
  // constant], so both locals below presumably define floor(x / 2), making
  // them duplicates. NOTE(review): isEqual is used here presumably because it
  // is implemented via subtraction; confirm.
  IntegerPolyhedron poly(PresburgerSpace::getSetSpace(1));
  poly.addLocalFloorDiv({1, 0}, 2);
  poly.addLocalFloorDiv({1, 0, 0}, 2);
  EXPECT_TRUE(poly.isEqual(poly));
}
551 
552 /// Coalesce `set` and check that the `newSet` is equal to `set` and that
553 /// `expectedNumPoly` matches the number of Poly in the coalesced set.
expectCoalesce(size_t expectedNumPoly,const PresburgerSet & set)554 void expectCoalesce(size_t expectedNumPoly, const PresburgerSet &set) {
555   PresburgerSet newSet = set.coalesce();
556   EXPECT_TRUE(set.isEqual(newSet));
557   EXPECT_TRUE(expectedNumPoly == newSet.getNumDisjuncts());
558 }
559 
TEST(SetTest, coalesceNoPoly) {
  // A set built from zero polyhedra keeps zero disjuncts after coalescing.
  expectCoalesce(/*expectedNumPoly=*/0, makeSetFromPoly(0, {}));
}
564 
TEST(SetTest, coalesceContainedOneDim) {
  // [1, 2] is contained in [0, 4], so a single disjunct should remain.
  PresburgerSet nested = parsePresburgerSet(
      {"(x) : (x >= 0, -x + 4 >= 0)", "(x) : (x - 1 >= 0, -x + 2 >= 0)"});
  expectCoalesce(/*expectedNumPoly=*/1, nested);
}
570 
TEST(SetTest, coalesceFirstEmpty) {
  // The first disjunct (x >= 0 and x <= -1) is empty; only [1, 2] remains.
  PresburgerSet withEmptyFirst = parsePresburgerSet(
      {"(x) : ( x >= 0, -x - 1 >= 0)", "(x) : ( x - 1 >= 0, -x + 2 >= 0)"});
  expectCoalesce(/*expectedNumPoly=*/1, withEmptyFirst);
}
576 
TEST(SetTest, coalesceSecondEmpty) {
  // The second disjunct (x >= 0 and x <= -1) is empty; only [1, 2] remains.
  PresburgerSet withEmptySecond = parsePresburgerSet(
      {"(x) : (x - 1 >= 0, -x + 2 >= 0)", "(x) : (x >= 0, -x - 1 >= 0)"});
  expectCoalesce(/*expectedNumPoly=*/1, withEmptySecond);
}
582 
TEST(SetTest, coalesceBothEmpty) {
  // Both disjuncts require x <= -1 alongside a non-negative lower bound, so
  // both are empty and no disjunct should remain.
  PresburgerSet allEmpty = parsePresburgerSet(
      {"(x) : (x - 3 >= 0, -x - 1 >= 0)", "(x) : (x >= 0, -x - 1 >= 0)"});
  expectCoalesce(/*expectedNumPoly=*/0, allEmpty);
}
588 
TEST(SetTest, coalesceFirstUniv) {
  // The unconstrained first disjunct absorbs [0, 1].
  PresburgerSet univFirst =
      parsePresburgerSet({"(x) : ()", "(x) : ( x >= 0, -x + 1 >= 0)"});
  expectCoalesce(/*expectedNumPoly=*/1, univFirst);
}
594 
TEST(SetTest, coalesceSecondUniv) {
  // The unconstrained second disjunct absorbs [0, 1].
  PresburgerSet univSecond =
      parsePresburgerSet({"(x) : ( x >= 0, -x + 1 >= 0)", "(x) : ()"});
  expectCoalesce(/*expectedNumPoly=*/1, univSecond);
}
600 
TEST(SetTest, coalesceBothUniv) {
  // Two copies of the universe collapse into one.
  expectCoalesce(/*expectedNumPoly=*/1,
                 parsePresburgerSet({"(x) : ()", "(x) : ()"}));
}
605 
TEST(SetTest, coalesceFirstUnivSecondEmpty) {
  // Universe followed by an empty disjunct (x >= 0 and x <= -1).
  PresburgerSet univThenEmpty =
      parsePresburgerSet({"(x) : ()", "(x) : ( x >= 0, -x - 1 >= 0)"});
  expectCoalesce(/*expectedNumPoly=*/1, univThenEmpty);
}
611 
TEST(SetTest, coalesceFirstEmptySecondUniv) {
  // An empty disjunct (x >= 0 and x <= -1) followed by the universe.
  PresburgerSet emptyThenUniv =
      parsePresburgerSet({"(x) : ( x >= 0, -x - 1 >= 0)", "(x) : ()"});
  expectCoalesce(/*expectedNumPoly=*/1, emptyThenUniv);
}
617 
TEST(SetTest, coalesceCutOneDim) {
  // [0, 3] and [2, 4] overlap, so they should merge into one disjunct.
  PresburgerSet overlapping = parsePresburgerSet(
      {"(x) : ( x >= 0, -x + 3 >= 0)", "(x) : ( x - 2 >= 0, -x + 4 >= 0)"});
  expectCoalesce(/*expectedNumPoly=*/1, overlapping);
}
625 
TEST(SetTest, coalesceSeparateOneDim) {
  // [0, 2] and [3, 4] do not overlap and are expected to stay separate.
  PresburgerSet disjoint = parsePresburgerSet(
      {"(x) : ( x >= 0, -x + 2 >= 0)", "(x) : ( x - 3 >= 0, -x + 4 >= 0)"});
  expectCoalesce(/*expectedNumPoly=*/2, disjoint);
}
631 
TEST(SetTest, coalesceAdjEq) {
  // The adjacent singletons {0} and {1} are expected to stay separate.
  PresburgerSet adjacentPoints =
      parsePresburgerSet({"(x) : ( x == 0)", "(x) : ( x - 1 == 0)"});
  expectCoalesce(/*expectedNumPoly=*/2, adjacentPoints);
}
637 
TEST(SetTest, coalesceContainedTwoDim) {
  // [0, 3] x [2, 3] lies inside [0, 3] x [0, 3].
  PresburgerSet nested = parsePresburgerSet({
      "(x,y) : (x >= 0, -x + 3 >= 0, y >= 0, -y + 3 >= 0)",
      "(x,y) : (x >= 0, -x + 3 >= 0, y - 2 >= 0, -y + 3 >= 0)",
  });
  expectCoalesce(/*expectedNumPoly=*/1, nested);
}
645 
TEST(SetTest, coalesceCutTwoDim) {
  // [0, 3] x [0, 2] and [0, 3] x [1, 3] overlap and merge into one box.
  PresburgerSet overlapping = parsePresburgerSet({
      "(x,y) : (x >= 0, -x + 3 >= 0, y >= 0, -y + 2 >= 0)",
      "(x,y) : (x >= 0, -x + 3 >= 0, y - 1 >= 0, -y + 3 >= 0)",
  });
  expectCoalesce(/*expectedNumPoly=*/1, overlapping);
}
653 
TEST(SetTest, coalesceEqStickingOut) {
  // The segment y = 1, 0 <= x <= 3 pokes out of the square [0, 2] x [0, 2]
  // at x = 3, so the two disjuncts are expected to remain.
  PresburgerSet squareAndSegment = parsePresburgerSet({
      "(x,y) : (x >= 0, -x + 2 >= 0, y >= 0, -y + 2 >= 0)",
      "(x,y) : (y - 1 == 0, x >= 0, -x + 3 >= 0)",
  });
  expectCoalesce(/*expectedNumPoly=*/2, squareAndSegment);
}
661 
TEST(SetTest, coalesceSeparateTwoDim) {
  // [0, 3] x [0, 1] and [0, 3] x [2, 3] do not overlap.
  PresburgerSet disjoint = parsePresburgerSet({
      "(x,y) : (x >= 0, -x + 3 >= 0, y >= 0, -y + 1 >= 0)",
      "(x,y) : (x >= 0, -x + 3 >= 0, y - 2 >= 0, -y + 3 >= 0)",
  });
  expectCoalesce(/*expectedNumPoly=*/2, disjoint);
}
669 
TEST(SetTest, coalesceContainedEq) {
  // On the diagonal x = y, the piece 1 <= x <= 2 lies inside 0 <= x <= 3.
  PresburgerSet nestedDiagonal = parsePresburgerSet({
      "(x,y) : (x >= 0, -x + 3 >= 0, x - y == 0)",
      "(x,y) : (x - 1 >= 0, -x + 2 >= 0, x - y == 0)",
  });
  expectCoalesce(/*expectedNumPoly=*/1, nestedDiagonal);
}
677 
TEST(SetTest, coalesceCuttingEq) {
  // On the diagonal x = y, the pieces -1 <= x <= 1 and 0 <= x <= 2 overlap.
  PresburgerSet overlappingDiagonal = parsePresburgerSet({
      "(x,y) : (x + 1 >= 0, -x + 1 >= 0, x - y == 0)",
      "(x,y) : (x >= 0, -x + 2 >= 0, x - y == 0)",
  });
  expectCoalesce(/*expectedNumPoly=*/1, overlappingDiagonal);
}
685 
TEST(SetTest, coalesceSeparateEq) {
  // On the diagonal x = y, the pieces 3 <= x <= 4 and 0 <= x <= 1 are
  // disjoint.
  PresburgerSet disjointDiagonal = parsePresburgerSet({
      "(x,y) : (x - 3 >= 0, -x + 4 >= 0, x - y == 0)",
      "(x,y) : (x >= 0, -x + 1 >= 0, x - y == 0)",
  });
  expectCoalesce(/*expectedNumPoly=*/2, disjointDiagonal);
}
693 
TEST(SetTest, coalesceContainedEqAsIneq) {
  // The first disjunct spells x = y as a pair of opposite inequalities; the
  // second (x = y, 1 <= x <= 2) lies inside it.
  PresburgerSet nestedMixed = parsePresburgerSet({
      "(x,y) : (x >= 0, -x + 3 >= 0, x - y >= 0, -x + y >= 0)",
      "(x,y) : (x - 1 >= 0, -x + 2 >= 0, x - y == 0)",
  });
  expectCoalesce(/*expectedNumPoly=*/1, nestedMixed);
}
701 
TEST(SetTest, coalesceContainedEqComplex) {
  // The single point (2, 2) lies on the segment x = y, 1 <= x <= 2.
  PresburgerSet pointOnSegment = parsePresburgerSet({
      "(x,y) : (x - 2 == 0, x - y == 0)",
      "(x,y) : (x - 1 >= 0, -x + 2 >= 0, x - y == 0)",
  });
  expectCoalesce(/*expectedNumPoly=*/1, pointOnSegment);
}
709 
TEST(SetTest, coalesceThreeContained) {
  // [0, 1], [0, 2] and [0, 3] form a chain of nested intervals.
  PresburgerSet chain = parsePresburgerSet({
      "(x) : (x >= 0, -x + 1 >= 0)",
      "(x) : (x >= 0, -x + 2 >= 0)",
      "(x) : (x >= 0, -x + 3 >= 0)",
  });
  expectCoalesce(/*expectedNumPoly=*/1, chain);
}
718 
TEST(SetTest, coalesceDoubleIncrement) {
  // {0}, {2}, {-2} and [2, 3]: only {2} is absorbed (by [2, 3]), leaving
  // three disjuncts.
  PresburgerSet points = parsePresburgerSet({
      "(x) : (x == 0)",
      "(x) : (x - 2 == 0)",
      "(x) : (x + 2 == 0)",
      "(x) : (x - 2 >= 0, -x + 3 >= 0)",
  });
  expectCoalesce(/*expectedNumPoly=*/3, points);
}
728 
TEST(SetTest, coalesceLastCoalesced) {
  // {0}, [1, 3], {-2} and [2, 4]: only the last interval overlaps an earlier
  // one ([1, 3]), leaving three disjuncts.
  PresburgerSet pieces = parsePresburgerSet({
      "(x) : (x == 0)",
      "(x) : (x - 1 >= 0, -x + 3 >= 0)",
      "(x) : (x + 2 == 0)",
      "(x) : (x - 2 >= 0, -x + 4 >= 0)",
  });
  expectCoalesce(/*expectedNumPoly=*/3, pieces);
}
738 
TEST(SetTest, coalesceDiv) {
  // x floordiv 2 == 0 gives {0, 1} and x floordiv 2 == 1 gives {2, 3}; the
  // two div-based disjuncts are expected to stay separate.
  PresburgerSet divPieces = parsePresburgerSet({
      "(x) : (x floordiv 2 == 0)",
      "(x) : (x floordiv 2 - 1 == 0)",
  });
  expectCoalesce(/*expectedNumPoly=*/2, divPieces);
}
746 
TEST(SetTest, coalesceDivOtherContained) {
  // A div-defined disjunct ({0, 1}) followed by {0} and the interval [0, 1].
  // The point {0} is contained in [0, 1]; two disjuncts are expected to
  // remain.
  PresburgerSet pieces = parsePresburgerSet({
      "(x) : (x floordiv 2 == 0)",
      "(x) : (x == 0)",
      "(x) : (x >= 0, -x + 1 >= 0)",
  });
  expectCoalesce(2, pieces);
}
755 
756 static void
expectComputedVolumeIsValidOverapprox(const PresburgerSet & set,std::optional<int64_t> trueVolume,std::optional<int64_t> resultBound)757 expectComputedVolumeIsValidOverapprox(const PresburgerSet &set,
758                                       std::optional<int64_t> trueVolume,
759                                       std::optional<int64_t> resultBound) {
760   expectComputedVolumeIsValidOverapprox(set.computeVolume(), trueVolume,
761                                         resultBound);
762 }
763 
TEST(SetTest, computeVolume) {
  // Diamond with vertices at (0, 0), (5, 5), (5, -5), (10, 0), i.e.
  // 0 <= x + y <= 10 and 0 <= x - y <= 10. With u = x + y and v = x - y,
  // integer points need u and v to have the same parity:
  // 6 * 6 (both even) + 5 * 5 (both odd) = 61. Its bounding box
  // [0, 10] x [-5, 5] has 11 * 11 = 121 integer points.
  PresburgerSet diamond(parseIntegerPolyhedron(
      "(x, y) : (x + y >= 0, -x - y + 10 >= 0, x - y >= 0, -x + y + "
      "10 >= 0)"));
  expectComputedVolumeIsValidOverapprox(diamond,
                                        /*trueVolume=*/61ull,
                                        /*resultBound=*/121ull);

  // Diamond with vertices at (-5, 0), (0, -5), (0, 5), (5, 0), i.e.
  // |x| + |y| <= 5: the same shape as `diamond`, shifted.
  PresburgerSet shiftedDiamond(parseIntegerPolyhedron(
      "(x, y) : (x + y + 5 >= 0, -x - y + 5 >= 0, x - y + 5 >= 0, -x + y + "
      "5 >= 0)"));
  expectComputedVolumeIsValidOverapprox(shiftedDiamond,
                                        /*trueVolume=*/61ull,
                                        /*resultBound=*/121ull);

  // Diamond with vertices at (-5, 0), (5, -10), (5, 10), (15, 0), i.e.
  // -5 <= x + y <= 15 and -5 <= x - y <= 15: 11 * 11 (both odd) +
  // 10 * 10 (both even) = 221 integer points; bounding box
  // [-5, 15] x [-10, 10] has 21 * 21 = 441.
  PresburgerSet biggerDiamond(parseIntegerPolyhedron(
      "(x, y) : (x + y + 5 >= 0, -x - y + 15 >= 0, x - y + 5 >= 0, -x + y + "
      "15 >= 0)"));
  expectComputedVolumeIsValidOverapprox(biggerDiamond,
                                        /*trueVolume=*/221ull,
                                        /*resultBound=*/441ull);

  // There is some overlap between diamond and shiftedDiamond: the overlap is
  // 0 <= x + y <= 5, 0 <= x - y <= 5, which has 18 integer points, so the
  // union has 61 + 61 - 18 = 104. The bound is the sum of the two bounds.
  expectComputedVolumeIsValidOverapprox(diamond.unionSet(shiftedDiamond),
                                        /*trueVolume=*/104ull,
                                        /*resultBound=*/242ull);

  // biggerDiamond subsumes both the small ones, so the true volume is that of
  // biggerDiamond; the bound is the sum 121 + 121 + 441.
  expectComputedVolumeIsValidOverapprox(
      diamond.unionSet(shiftedDiamond).unionSet(biggerDiamond),
      /*trueVolume=*/221ull,
      /*resultBound=*/683ull);

  // Unbounded polytope.
  PresburgerSet unbounded(
      parseIntegerPolyhedron("(x, y) : (2*x - y >= 0, y - 3*x >= 0)"));
  expectComputedVolumeIsValidOverapprox(unbounded, /*trueVolume=*/{},
                                        /*resultBound=*/{});

  // Union of unbounded with bounded is unbounded.
  expectComputedVolumeIsValidOverapprox(unbounded.unionSet(diamond),
                                        /*trueVolume=*/{},
                                        /*resultBound=*/{});
}
811 
812 // The last `numToProject` dims will be projected out, i.e., converted to
813 // locals.
testComputeReprAtPoints(IntegerPolyhedron poly,ArrayRef<SmallVector<int64_t,4>> points,unsigned numToProject)814 void testComputeReprAtPoints(IntegerPolyhedron poly,
815                              ArrayRef<SmallVector<int64_t, 4>> points,
816                              unsigned numToProject) {
817   poly.convertVarKind(VarKind::SetDim, poly.getNumDimVars() - numToProject,
818                       poly.getNumDimVars(), VarKind::Local);
819   PresburgerRelation repr = poly.computeReprWithOnlyDivLocals();
820   EXPECT_TRUE(repr.hasOnlyDivLocals());
821   EXPECT_TRUE(repr.getSpace().isCompatible(poly.getSpace()));
822   for (const SmallVector<int64_t, 4> &point : points) {
823     EXPECT_EQ(poly.containsPointNoLocal(point).has_value(),
824               repr.containsPoint(point));
825   }
826 }
827 
testComputeRepr(IntegerPolyhedron poly,const PresburgerSet & expected,unsigned numToProject)828 void testComputeRepr(IntegerPolyhedron poly, const PresburgerSet &expected,
829                      unsigned numToProject) {
830   poly.convertVarKind(VarKind::SetDim, poly.getNumDimVars() - numToProject,
831                       poly.getNumDimVars(), VarKind::Local);
832   PresburgerRelation repr = poly.computeReprWithOnlyDivLocals();
833   EXPECT_TRUE(repr.hasOnlyDivLocals());
834   EXPECT_TRUE(repr.getSpace().isCompatible(poly.getSpace()));
835   EXPECT_TRUE(repr.isEqual(expected));
836 }
837 
TEST(SetTest, computeReprWithOnlyDivLocals) {
  // Nothing is projected: the representation must agree pointwise with
  // x == 2y.
  testComputeReprAtPoints(parseIntegerPolyhedron("(x, y) : (x - 2*y == 0)"),
                          {{1, 0}, {2, 1}, {3, 0}, {4, 2}, {5, 3}},
                          /*numToProject=*/0);
  // Projecting e out of x == 2e leaves exactly the even values of x.
  testComputeReprAtPoints(parseIntegerPolyhedron("(x, e) : (x - 2*e == 0)"),
                          {{1}, {2}, {3}, {4}, {5}}, /*numToProject=*/1);

  // Tests to check that the space is preserved. No points are sampled; these
  // rely on the space-compatibility check done by testComputeReprAtPoints.
  testComputeReprAtPoints(parseIntegerPolyhedron("(x, y)[z, w] : ()"), {},
                          /*numToProject=*/1);
  testComputeReprAtPoints(
      parseIntegerPolyhedron("(x, y)[z, w] : (z - (w floordiv 2) == 0)"), {},
      /*numToProject=*/1);

  // Bezout's lemma: for integer constants a and b, the set of values that
  // ax + by takes over integer x, y is exactly the multiples of gcd(a, b).
  // Here gcd(15, 21) == 3.
  testComputeRepr(parseIntegerPolyhedron("(x, e, f) : (x - 15*e - 21*f == 0)"),
                  PresburgerSet(parseIntegerPolyhedron(
                      "(x) : (x - 3*(x floordiv 3) == 0)")),
                  /*numToProject=*/2);
}
859 
TEST(SetTest, subtractOutputSizeRegression) {
  // [0, 10] minus [5, +inf) must be exactly [0, 4].
  PresburgerSet range = parsePresburgerSet({"(i) : (i >= 0, 10 - i >= 0)"});
  PresburgerSet tail = parsePresburgerSet({"(i) : (i - 5 >= 0)"});
  PresburgerSet expected = parsePresburgerSet({"(i) : (i >= 0, 4 - i >= 0)"});

  PresburgerSet difference = range.subtract(tail);
  EXPECT_TRUE(difference.isEqual(expected));
  // An earlier implementation emitted an additional empty disjunct here. That
  // was semantically correct but inflated the output size.
  EXPECT_EQ(difference.getNumDisjuncts(), 1u);

  // Subtracting a set from itself must leave no disjuncts at all; an earlier
  // implementation left several (correct but wasteful) empty ones behind.
  PresburgerSet selfDifference = range.subtract(range);
  EXPECT_TRUE(selfDifference.isIntegerEmpty());
  EXPECT_EQ(selfDifference.getNumDisjuncts(), 0u);
}
880