xref: /llvm-project/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp (revision 94750af83640cd702d80c53ab99d5bb303e55796)
1 //===- PWMAFunctionTest.cpp - Tests for PWMAFunction ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains tests for PWMAFunction.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Parser.h"
14 
15 #include "mlir/Analysis/Presburger/PWMAFunction.h"
16 #include "mlir/Analysis/Presburger/PresburgerRelation.h"
17 #include "mlir/IR/MLIRContext.h"
18 
19 #include <gmock/gmock.h>
20 #include <gtest/gtest.h>
21 
22 using namespace mlir;
23 using namespace presburger;
24 
25 using testing::ElementsAre;
26 
TEST(PWAFunctionTest,isEqual)27 TEST(PWAFunctionTest, isEqual) {
28   // The output expressions are different but it doesn't matter because they are
29   // equal in this domain.
30   PWMAFunction idAtZeros =
31       parsePWMAF({{"(x, y) : (y == 0)", "(x, y) -> (x, y)"},
32                   {"(x, y) : (y - 1 >= 0, x == 0)", "(x, y) -> (x, y)"},
33                   {"(x, y) : (-y - 1 >= 0, x == 0)", "(x, y) -> (x, y)"}});
34   PWMAFunction idAtZeros2 =
35       parsePWMAF({{"(x, y) : (y == 0)", "(x, y) -> (x, 20*y)"},
36                   {"(x, y) : (y - 1 >= 0, x == 0)", "(x, y) -> (30*x, y)"},
37                   {"(x, y) : (-y - 1 > =0, x == 0)", "(x, y) -> (30*x, y)"}});
38   EXPECT_TRUE(idAtZeros.isEqual(idAtZeros2));
39 
40   PWMAFunction notIdAtZeros = parsePWMAF({
41       {"(x, y) : (y == 0)", "(x, y) -> (x, y)"},
42       {"(x, y) : (y - 1 >= 0, x == 0)", "(x, y) -> (x, 2*y)"},
43       {"(x, y) : (-y - 1 >= 0, x == 0)", "(x, y) -> (x, 2*y)"},
44   });
45   EXPECT_FALSE(idAtZeros.isEqual(notIdAtZeros));
46 
47   // These match at their intersection but one has a bigger domain.
48   PWMAFunction idNoNegNegQuadrant =
49       parsePWMAF({{"(x, y) : (x >= 0)", "(x, y) -> (x, y)"},
50                   {"(x, y) : (-x - 1 >= 0, y >= 0)", "(x, y) -> (x, y)"}});
51   PWMAFunction idOnlyPosX = parsePWMAF({
52       {"(x, y) : (x >= 0)", "(x, y) -> (x, y)"},
53   });
54   EXPECT_FALSE(idNoNegNegQuadrant.isEqual(idOnlyPosX));
55 
56   // Different representations of the same domain.
57   PWMAFunction sumPlusOne = parsePWMAF({
58       {"(x, y) : (x >= 0)", "(x, y) -> (x + y + 1)"},
59       {"(x, y) : (-x - 1 >= 0, -y - 1 >= 0)", "(x, y) -> (x + y + 1)"},
60       {"(x, y) : (-x - 1 >= 0, y >= 0)", "(x, y) -> (x + y + 1)"},
61   });
62   PWMAFunction sumPlusOne2 = parsePWMAF({
63       {"(x, y) : ()", "(x, y) -> (x + y + 1)"},
64   });
65   EXPECT_TRUE(sumPlusOne.isEqual(sumPlusOne2));
66 
67   // Functions with zero input dimensions.
68   PWMAFunction noInputs1 = parsePWMAF({
69       {"() : ()", "() -> (1)"},
70   });
71   PWMAFunction noInputs2 = parsePWMAF({
72       {"() : ()", "() -> (2)"},
73   });
74   EXPECT_TRUE(noInputs1.isEqual(noInputs1));
75   EXPECT_FALSE(noInputs1.isEqual(noInputs2));
76 
77   // Mismatched dimensionalities.
78   EXPECT_FALSE(noInputs1.isEqual(sumPlusOne));
79   EXPECT_FALSE(idOnlyPosX.isEqual(sumPlusOne));
80 
81   // Divisions.
82   // Domain is only multiples of 6; x = 6k for some k.
83   // x + 4(x/2) + 4(x/3) == 26k.
84   PWMAFunction mul2AndMul3 = parsePWMAF({
85       {"(x) : (x - 2*(x floordiv 2) == 0, x - 3*(x floordiv 3) == 0)",
86        "(x) -> (x + 4 * (x floordiv 2) + 4 * (x floordiv 3))"},
87   });
88   PWMAFunction mul6 = parsePWMAF({
89       {"(x) : (x - 6*(x floordiv 6) == 0)", "(x) -> (26 * (x floordiv 6))"},
90   });
91   EXPECT_TRUE(mul2AndMul3.isEqual(mul6));
92 
93   PWMAFunction mul6diff = parsePWMAF({
94       {"(x) : (x - 5*(x floordiv 5) == 0)", "(x) -> (52 * (x floordiv 6))"},
95   });
96   EXPECT_FALSE(mul2AndMul3.isEqual(mul6diff));
97 
98   PWMAFunction mul5 = parsePWMAF({
99       {"(x) : (x - 5*(x floordiv 5) == 0)", "(x) -> (26 * (x floordiv 5))"},
100   });
101   EXPECT_FALSE(mul2AndMul3.isEqual(mul5));
102 }
103 
// Checks PWMAFunction::valueAt: points inside a piece's domain yield that
// piece's output, and points outside every piece yield no value.
TEST(PWMAFunction, valueAt) {
  // Piece one covers x >= 0; piece two covers x <= -1 with y >= 0. The quadrant
  // x <= -1, y <= -1 is not covered by either piece.
  PWMAFunction twoPiece = parsePWMAF(
      {{"(x, y) : (x >= 0)", "(x, y) -> (x + 2*y + 3, 3*x + 4*y + 5)"},
       {"(x, y) : (y >= 0, -x - 1 >= 0)",
        "(x, y) -> (-x + 2*y + 3, -3*x + 4*y + 5)"}});

  // (2, 3) and (2, -3) hit the first piece, (-2, 3) hits the second.
  EXPECT_THAT(*twoPiece.valueAt({2, 3}), ElementsAre(11, 23));
  EXPECT_THAT(*twoPiece.valueAt({-2, 3}), ElementsAre(11, 23));
  EXPECT_THAT(*twoPiece.valueAt({2, -3}), ElementsAre(-1, -1));
  // Outside both pieces.
  EXPECT_FALSE(twoPiece.valueAt({-2, -3}).has_value());

  // Same layout, but the first piece is restricted to even x and its outputs
  // use x floordiv 2.
  PWMAFunction twoPieceWithDiv = parsePWMAF(
      {{"(x, y) : (x >= 0, x - 2*(x floordiv 2) == 0)",
        "(x, y) -> (2*y + (x floordiv 2) + 3, 4*y + 3*(x floordiv 2) + 5)"},
       {"(x, y) : (y >= 0, -x - 1 >= 0)",
        "(x, y) -> (-x + 2*y + 3, -3*x + 4*y + 5)"}});

  // Even non-negative x: first piece applies.
  EXPECT_THAT(*twoPieceWithDiv.valueAt({4, 3}), ElementsAre(11, 23));
  EXPECT_THAT(*twoPieceWithDiv.valueAt({4, -3}), ElementsAre(-1, -1));
  // Odd non-negative x: excluded from the first piece, and the second piece
  // needs x <= -1.
  EXPECT_FALSE(twoPieceWithDiv.valueAt({3, 3}).has_value());
  EXPECT_FALSE(twoPieceWithDiv.valueAt({3, -3}).has_value());

  // Negative x: only the second piece can apply, and it requires y >= 0.
  EXPECT_THAT(*twoPieceWithDiv.valueAt({-2, 3}), ElementsAre(11, 23));
  EXPECT_FALSE(twoPieceWithDiv.valueAt({-2, -3}).has_value());
}
127 
// Regression test: two functions whose domains are written differently but
// both reduce to the single point (0, 0), each carrying redundant evenness
// constraints expressed with floordiv, must compare equal.
TEST(PWMAFunction, removeIdRangeRegressionTest) {
  // Domain given directly as x == 0, y == 0.
  PWMAFunction originExplicit = parsePWMAF({
      {"(x, y) : (x == 0, y == 0, x - 2*(x floordiv 2) == 0, "
       "y - 2*(y floordiv 2) == 0)",
       "(x, y) -> (0, 0)"},
  });
  // x == 11*y together with y == 11*x only admits x == y == 0.
  PWMAFunction originImplicit = parsePWMAF({
      {"(x, y) : (x - 11*y == 0, 11*x - y == 0, "
       "x - 2*(x floordiv 2) == 0, y - 2*(y floordiv 2) == 0)",
       "(x, y) -> (0, 0)"},
  });
  EXPECT_TRUE(originExplicit.isEqual(originImplicit));
}
141 
// Regression test: on a domain where x == 2*y (and x is constrained to be even
// through a floordiv), the outputs `y` and `x - y` coincide, so the two
// functions must compare equal.
TEST(PWMAFunction, eliminateRedundantLocalIdRegressionTest) {
  PWMAFunction outputIsY = parsePWMAF({
      {"(x, y) : (x - 2*(x floordiv 2) == 0, x - 2*y == 0)", "(x, y) -> (y)"},
  });
  PWMAFunction outputIsXMinusY = parsePWMAF({
      {"(x, y) : (x - 2*(x floordiv 2) == 0, x - 2*y == 0)",
       "(x, y) -> (x - y)"},
  });
  EXPECT_TRUE(outputIsY.isEqual(outputIsXMinusY));
}
152 
TEST(PWMAFunction,unionLexMaxSimple)153 TEST(PWMAFunction, unionLexMaxSimple) {
154   // func2 is better than func1, but func2's domain is empty.
155   {
156     PWMAFunction func1 = parsePWMAF({
157         {"(x) : ()", "(x) -> (1)"},
158     });
159 
160     PWMAFunction func2 = parsePWMAF({
161         {"(x) : (1 == 0)", "(x) -> (2)"},
162     });
163 
164     EXPECT_TRUE(func1.unionLexMax(func2).isEqual(func1));
165     EXPECT_TRUE(func2.unionLexMax(func1).isEqual(func1));
166   }
167 
168   // func2 is better than func1 on a subset of func1.
169   {
170     PWMAFunction func1 = parsePWMAF({
171         {"(x) : ()", "(x) -> (1)"},
172     });
173 
174     PWMAFunction func2 = parsePWMAF({
175         {"(x) : (x >= 0, 10 - x >= 0)", "(x) -> (2)"},
176     });
177 
178     PWMAFunction result = parsePWMAF({
179         {"(x) : (-1 - x >= 0)", "(x) -> (1)"},
180         {"(x) : (x >= 0, 10 - x >= 0)", "(x) -> (2)"},
181         {"(x) : (x - 11 >= 0)", "(x) -> (1)"},
182     });
183 
184     EXPECT_TRUE(func1.unionLexMax(func2).isEqual(result));
185     EXPECT_TRUE(func2.unionLexMax(func1).isEqual(result));
186   }
187 
188   // func1 and func2 are defined over the whole domain with different outputs.
189   {
190     PWMAFunction func1 = parsePWMAF({
191         {"(x) : ()", "(x) -> (x)"},
192     });
193 
194     PWMAFunction func2 = parsePWMAF({
195         {"(x) : ()", "(x) -> (-x)"},
196     });
197 
198     PWMAFunction result = parsePWMAF({
199         {"(x) : (x >= 0)", "(x) -> (x)"},
200         {"(x) : (-1 - x >= 0)", "(x) -> (-x)"},
201     });
202 
203     EXPECT_TRUE(func1.unionLexMax(func2).isEqual(result));
204     EXPECT_TRUE(func2.unionLexMax(func1).isEqual(result));
205   }
206 
207   // func1 and func2 have disjoint domains.
208   {
209     PWMAFunction func1 = parsePWMAF({
210         {"(x) : (x >= 0, 10 - x >= 0)", "(x) -> (1)"},
211         {"(x) : (x - 71 >= 0, 80 - x >= 0)", "(x) -> (1)"},
212     });
213 
214     PWMAFunction func2 = parsePWMAF({
215         {"(x) : (x - 20 >= 0, 41 - x >= 0)", "(x) -> (2)"},
216         {"(x) : (x - 101 >= 0, 120 - x >= 0)", "(x) -> (2)"},
217     });
218 
219     PWMAFunction result = parsePWMAF({
220         {"(x) : (x >= 0, 10 - x >= 0)", "(x) -> (1)"},
221         {"(x) : (x - 71 >= 0, 80 - x >= 0)", "(x) -> (1)"},
222         {"(x) : (x - 20 >= 0, 41 - x >= 0)", "(x) -> (2)"},
223         {"(x) : (x - 101 >= 0, 120 - x >= 0)", "(x) -> (2)"},
224     });
225 
226     EXPECT_TRUE(func1.unionLexMin(func2).isEqual(result));
227     EXPECT_TRUE(func2.unionLexMin(func1).isEqual(result));
228   }
229 }
230 
// Basic checks for PWMAFunction::unionLexMin on single-output functions. Each
// case checks the union in both argument orders.
TEST(PWMAFunction, unionLexMinSimple) {
  // `rhs` has the smaller output, but its domain (1 == 0) is empty, so the
  // union is just `lhs`.
  {
    PWMAFunction lhs = parsePWMAF({
        {"(x) : ()", "(x) -> (-1)"},
    });
    PWMAFunction rhs = parsePWMAF({
        {"(x) : (1 == 0)", "(x) -> (-2)"},
    });

    EXPECT_TRUE(lhs.unionLexMin(rhs).isEqual(lhs));
    EXPECT_TRUE(rhs.unionLexMin(lhs).isEqual(lhs));
  }

  // `rhs` has the smaller output, but only on 0 <= x <= 10; `lhs` applies
  // everywhere else.
  {
    PWMAFunction lhs = parsePWMAF({
        {"(x) : ()", "(x) -> (-1)"},
    });
    PWMAFunction rhs = parsePWMAF({
        {"(x) : (x >= 0, 10 - x >= 0)", "(x) -> (-2)"},
    });
    PWMAFunction expected = parsePWMAF({
        {"(x) : (-1 - x >= 0)", "(x) -> (-1)"},
        {"(x) : (x >= 0, 10 - x >= 0)", "(x) -> (-2)"},
        {"(x) : (x - 11 >= 0)", "(x) -> (-1)"},
    });

    EXPECT_TRUE(lhs.unionLexMin(rhs).isEqual(expected));
    EXPECT_TRUE(rhs.unionLexMin(lhs).isEqual(expected));
  }

  // Both functions cover every x but disagree: -x is the smaller one for
  // x >= 0 and x is the smaller one for x <= -1.
  {
    PWMAFunction lhs = parsePWMAF({
        {"(x) : ()", "(x) -> (-x)"},
    });
    PWMAFunction rhs = parsePWMAF({
        {"(x) : ()", "(x) -> (x)"},
    });
    PWMAFunction expected = parsePWMAF({
        {"(x) : (x >= 0)", "(x) -> (-x)"},
        {"(x) : (-1 - x >= 0)", "(x) -> (x)"},
    });

    EXPECT_TRUE(lhs.unionLexMin(rhs).isEqual(expected));
    EXPECT_TRUE(rhs.unionLexMin(lhs).isEqual(expected));
  }
}
285 
TEST(PWMAFunction,unionLexMaxComplex)286 TEST(PWMAFunction, unionLexMaxComplex) {
287   // Union of function containing 4 different pieces of output.
288   //
289   // x >= 21               --> func1 (func2 not defined)
290   // x <= 0                --> func2 (func1 not defined)
291   // 10 <= x <= 20, y >  0 --> func1 (x + y  > x - y for y >  0)
292   // 10 <= x <= 20, y <= 0 --> func2 (x + y <= x - y for y <= 0)
293   {
294     PWMAFunction func1 = parsePWMAF({
295         {"(x, y) : (x >= 10)", "(x, y) -> (x + y)"},
296     });
297 
298     PWMAFunction func2 = parsePWMAF({
299         {"(x, y) : (x <= 20)", "(x, y) -> (x - y)"},
300     });
301 
302     PWMAFunction result = parsePWMAF({
303         {"(x, y) : (x >= 10, x <= 20, y >= 1)", "(x, y) -> (x + y)"},
304         {"(x, y) : (x >= 21)", "(x, y) -> (x + y)"},
305         {"(x, y) : (x <= 9)", "(x, y) -> (x - y)"},
306         {"(x, y) : (x >= 10, x <= 20, y <= 0)", "(x, y) -> (x - y)"},
307     });
308 
309     EXPECT_TRUE(func1.unionLexMax(func2).isEqual(result));
310   }
311 
312   // Functions with more than one output, with contribution from both functions.
313   //
314   // If y >= 1, func1 is better because in the first output,
315   // x + y (func1) > x (func2), when y >= 1
316   //
317   // If y == 0, the first output is same for both functions, so we look at the
318   // second output. -2x + 4 (func1) > 2x - 2 (func2) when 0 <= x <= 1, so we
319   // take func1 for this domain and func2 for the remaining.
320   {
321     PWMAFunction func1 = parsePWMAF({
322         {"(x, y) : (x >= 0, y >= 0)", "(x, y) -> (x + y, -2*x + 4)"},
323     });
324 
325     PWMAFunction func2 = parsePWMAF({
326         {"(x, y) : (x >= 0, y >= 0)", "(x, y) -> (x, 2*x - 2)"},
327     });
328 
329     PWMAFunction result = parsePWMAF({
330         {"(x, y) : (x >= 0, y >= 1)", "(x, y) -> (x + y, -2*x + 4)"},
331         {"(x, y) : (x >= 0, x <= 1, y == 0)", "(x, y) -> (x + y, -2*x + 4)"},
332         {"(x, y) : (x >= 2, y == 0)", "(x, y) -> (x, 2*x - 2)"},
333     });
334 
335     EXPECT_TRUE(func1.unionLexMax(func2).isEqual(result));
336     EXPECT_TRUE(func2.unionLexMax(func1).isEqual(result));
337   }
338 
339   // Function with three boolean variables `a, b, c` used to control which
340   // output will be taken lexicographically.
341   //
342   // a == 1                 --> Take func2
343   // a == 0, b == 1         --> Take func1
344   // a == 0, b == 0, c == 1 --> Take func2
345   {
346     PWMAFunction func1 = parsePWMAF({
347         {"(a, b, c) : (a >= 0, 1 - a >= 0, b >= 0, 1 - b >= 0, c "
348          ">= 0, 1 - c >= 0)",
349          "(a, b, c) -> (0, b, 0)"},
350     });
351 
352     PWMAFunction func2 = parsePWMAF({
353         {"(a, b, c) : (a >= 0, 1 - a >= 0, b >= 0, 1 - b >= 0, c >= 0, 1 - "
354          "c >= 0)",
355          "(a, b, c) -> (a, 0, c)"},
356     });
357 
358     PWMAFunction result = parsePWMAF({
359         {"(a, b, c) : (a - 1 == 0, b >= 0, 1 - b >= 0, c >= 0, 1 - c >= 0)",
360          "(a, b, c) -> (a, 0, c)"},
361         {"(a, b, c) : (a == 0, b - 1 == 0, c >= 0, 1 - c >= 0)",
362          "(a, b, c) -> (0, b, 0)"},
363         {"(a, b, c) : (a == 0, b == 0, c >= 0, 1 - c >= 0)",
364          "(a, b, c) -> (a, 0, c)"},
365     });
366 
367     EXPECT_TRUE(func1.unionLexMax(func2).isEqual(result));
368     EXPECT_TRUE(func2.unionLexMax(func1).isEqual(result));
369   }
370 }
371 
// Regression test checking that the lexicographic tiebreak in
// PWMAFunction::unionLexMin produces disjoint domains.
TEST(PWMAFunction, unionLexMinComplex) {
  // Both functions live on the unit square 0 <= x, y <= 1.
  //
  // x == 1: the first output decides; -x (lhs) < 0 (rhs), so lhs is taken.
  //
  // x == 0: the first outputs are both 0, so the second output decides;
  // y - 1 (rhs) < y (lhs), so rhs is taken.
  PWMAFunction lhs = parsePWMAF({
      {"(x, y) : (x >= 0, x <= 1, y >= 0, y <= 1)", "(x, y) -> (-x, y)"},
  });
  PWMAFunction rhs = parsePWMAF({
      {"(x, y) : (x >= 0, x <= 1, y >= 0, y <= 1)", "(x, y) -> (0, y - 1)"},
  });
  PWMAFunction expected = parsePWMAF({
      {"(x, y) : (x == 1, y >= 0, y <= 1)", "(x, y) -> (-x, y)"},
      {"(x, y) : (x == 0, y >= 0, y <= 1)", "(x, y) -> (0, y - 1)"},
  });

  EXPECT_TRUE(lhs.unionLexMin(rhs).isEqual(expected));
  EXPECT_TRUE(rhs.unionLexMin(lhs).isEqual(expected));
}
398 
// Checks PWMAFunction::unionLexMin when domains and outputs involve `mod` and
// `floordiv` expressions.
TEST(PWMAFunction, unionLexMinWithDivs) {
  // Domains are the multiples of 5 and the multiples of 7. Where only one
  // function is defined, it is taken. Where both are defined, the first
  // outputs x and x + y are compared: lhs is taken for y >= 0 (at y == 0 the
  // first outputs tie and 1 < 2 decides), rhs for y <= -1.
  {
    PWMAFunction lhs = parsePWMAF({
        {"(x, y) : (x mod 5 == 0)", "(x, y) -> (x, 1)"},
    });
    PWMAFunction rhs = parsePWMAF({
        {"(x, y) : (x mod 7 == 0)", "(x, y) -> (x + y, 2)"},
    });
    PWMAFunction expected = parsePWMAF({
        {"(x, y) : (x mod 5 == 0, x mod 7 >= 1)", "(x, y) -> (x, 1)"},
        {"(x, y) : (x mod 7 == 0, x mod 5 >= 1)", "(x, y) -> (x + y, 2)"},
        {"(x, y) : (x mod 5 == 0, x mod 7 == 0, y >= 0)", "(x, y) -> (x, 1)"},
        {"(x, y) : (x mod 7 == 0, x mod 5 == 0, y <= -1)",
         "(x, y) -> (x + y, 2)"},
    });

    EXPECT_TRUE(lhs.unionLexMin(rhs).isEqual(expected));
  }

  // Same bounded domain for both functions, with floordiv outputs. The
  // expected pieces split the domain by which floordiv expression is smaller,
  // with lhs taken when they are equal.
  {
    PWMAFunction lhs = parsePWMAF({
        {"(x) : (x >= 0, x <= 1000)", "(x) -> (x floordiv 16)"},
    });
    PWMAFunction rhs = parsePWMAF({
        {"(x) : (x >= 0, x <= 1000)", "(x) -> ((x + 10) floordiv 17)"},
    });
    PWMAFunction expected = parsePWMAF({
        {"(x) : (x >= 0, x <= 1000, x floordiv 16 <= (x + 10) floordiv 17)",
         "(x) -> (x floordiv 16)"},
        {"(x) : (x >= 0, x <= 1000, x floordiv 16 >= (x + 10) floordiv 17 + 1)",
         "(x) -> ((x + 10) floordiv 17)"},
    });

    EXPECT_TRUE(lhs.unionLexMin(rhs).isEqual(expected));
  }
}
439