xref: /llvm-project/mlir/unittests/Analysis/Presburger/MatrixTest.cpp (revision cfd51fbadd19a7b83805a84cb5745f1350c798c4)
1 //===- MatrixTest.cpp - Tests for Matrix ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/Presburger/Matrix.h"
10 #include "./Utils.h"
11 #include "mlir/Analysis/Presburger/Fraction.h"
12 #include <gmock/gmock.h>
13 #include <gtest/gtest.h>
14 
15 using namespace mlir;
16 using namespace presburger;
17 
18 TEST(MatrixTest, ReadWrite) {
19   IntMatrix mat(5, 5);
20   for (unsigned row = 0; row < 5; ++row)
21     for (unsigned col = 0; col < 5; ++col)
22       mat(row, col) = 10 * row + col;
23   for (unsigned row = 0; row < 5; ++row)
24     for (unsigned col = 0; col < 5; ++col)
25       EXPECT_EQ(mat(row, col), int(10 * row + col));
26 }
27 
28 TEST(MatrixTest, SwapColumns) {
29   IntMatrix mat(5, 5);
30   for (unsigned row = 0; row < 5; ++row)
31     for (unsigned col = 0; col < 5; ++col)
32       mat(row, col) = col == 3 ? 1 : 0;
33   mat.swapColumns(3, 1);
34   for (unsigned row = 0; row < 5; ++row)
35     for (unsigned col = 0; col < 5; ++col)
36       EXPECT_EQ(mat(row, col), col == 1 ? 1 : 0);
37 
38   // swap around all the other columns, swap (1, 3) twice for no effect.
39   mat.swapColumns(3, 1);
40   mat.swapColumns(2, 4);
41   mat.swapColumns(1, 3);
42   mat.swapColumns(0, 4);
43   mat.swapColumns(2, 2);
44 
45   for (unsigned row = 0; row < 5; ++row)
46     for (unsigned col = 0; col < 5; ++col)
47       EXPECT_EQ(mat(row, col), col == 1 ? 1 : 0);
48 }
49 
50 TEST(MatrixTest, SwapRows) {
51   IntMatrix mat(5, 5);
52   for (unsigned row = 0; row < 5; ++row)
53     for (unsigned col = 0; col < 5; ++col)
54       mat(row, col) = row == 2 ? 1 : 0;
55   mat.swapRows(2, 0);
56   for (unsigned row = 0; row < 5; ++row)
57     for (unsigned col = 0; col < 5; ++col)
58       EXPECT_EQ(mat(row, col), row == 0 ? 1 : 0);
59 
60   // swap around all the other rows, swap (2, 0) twice for no effect.
61   mat.swapRows(3, 4);
62   mat.swapRows(1, 4);
63   mat.swapRows(2, 0);
64   mat.swapRows(1, 1);
65   mat.swapRows(0, 2);
66 
67   for (unsigned row = 0; row < 5; ++row)
68     for (unsigned col = 0; col < 5; ++col)
69       EXPECT_EQ(mat(row, col), row == 0 ? 1 : 0);
70 }
71 
72 TEST(MatrixTest, resizeVertically) {
73   IntMatrix mat(5, 5);
74   EXPECT_EQ(mat.getNumRows(), 5u);
75   EXPECT_EQ(mat.getNumColumns(), 5u);
76   for (unsigned row = 0; row < 5; ++row)
77     for (unsigned col = 0; col < 5; ++col)
78       mat(row, col) = 10 * row + col;
79 
80   mat.resizeVertically(3);
81   ASSERT_TRUE(mat.hasConsistentState());
82   EXPECT_EQ(mat.getNumRows(), 3u);
83   EXPECT_EQ(mat.getNumColumns(), 5u);
84   for (unsigned row = 0; row < 3; ++row)
85     for (unsigned col = 0; col < 5; ++col)
86       EXPECT_EQ(mat(row, col), int(10 * row + col));
87 
88   mat.resizeVertically(5);
89   ASSERT_TRUE(mat.hasConsistentState());
90   EXPECT_EQ(mat.getNumRows(), 5u);
91   EXPECT_EQ(mat.getNumColumns(), 5u);
92   for (unsigned row = 0; row < 5; ++row)
93     for (unsigned col = 0; col < 5; ++col)
94       EXPECT_EQ(mat(row, col), row >= 3 ? 0 : int(10 * row + col));
95 }
96 
97 TEST(MatrixTest, insertColumns) {
98   IntMatrix mat(5, 5, 5, 10);
99   EXPECT_EQ(mat.getNumRows(), 5u);
100   EXPECT_EQ(mat.getNumColumns(), 5u);
101   for (unsigned row = 0; row < 5; ++row)
102     for (unsigned col = 0; col < 5; ++col)
103       mat(row, col) = 10 * row + col;
104 
105   mat.insertColumns(3, 100);
106   ASSERT_TRUE(mat.hasConsistentState());
107   EXPECT_EQ(mat.getNumRows(), 5u);
108   EXPECT_EQ(mat.getNumColumns(), 105u);
109   for (unsigned row = 0; row < 5; ++row) {
110     for (unsigned col = 0; col < 105; ++col) {
111       if (col < 3)
112         EXPECT_EQ(mat(row, col), int(10 * row + col));
113       else if (3 <= col && col <= 102)
114         EXPECT_EQ(mat(row, col), 0);
115       else
116         EXPECT_EQ(mat(row, col), int(10 * row + col - 100));
117     }
118   }
119 
120   mat.removeColumns(3, 100);
121   ASSERT_TRUE(mat.hasConsistentState());
122   mat.insertColumns(0, 0);
123   ASSERT_TRUE(mat.hasConsistentState());
124   mat.insertColumn(5);
125   ASSERT_TRUE(mat.hasConsistentState());
126 
127   EXPECT_EQ(mat.getNumRows(), 5u);
128   EXPECT_EQ(mat.getNumColumns(), 6u);
129   for (unsigned row = 0; row < 5; ++row)
130     for (unsigned col = 0; col < 6; ++col)
131       EXPECT_EQ(mat(row, col), col == 5 ? 0 : 10 * row + col);
132 }
133 
134 TEST(MatrixTest, insertRows) {
135   IntMatrix mat(5, 5, 5, 10);
136   ASSERT_TRUE(mat.hasConsistentState());
137   EXPECT_EQ(mat.getNumRows(), 5u);
138   EXPECT_EQ(mat.getNumColumns(), 5u);
139   for (unsigned row = 0; row < 5; ++row)
140     for (unsigned col = 0; col < 5; ++col)
141       mat(row, col) = 10 * row + col;
142 
143   mat.insertRows(3, 100);
144   ASSERT_TRUE(mat.hasConsistentState());
145   EXPECT_EQ(mat.getNumRows(), 105u);
146   EXPECT_EQ(mat.getNumColumns(), 5u);
147   for (unsigned row = 0; row < 105; ++row) {
148     for (unsigned col = 0; col < 5; ++col) {
149       if (row < 3)
150         EXPECT_EQ(mat(row, col), int(10 * row + col));
151       else if (3 <= row && row <= 102)
152         EXPECT_EQ(mat(row, col), 0);
153       else
154         EXPECT_EQ(mat(row, col), int(10 * (row - 100) + col));
155     }
156   }
157 
158   mat.removeRows(3, 100);
159   ASSERT_TRUE(mat.hasConsistentState());
160   mat.insertRows(0, 0);
161   ASSERT_TRUE(mat.hasConsistentState());
162   mat.insertRow(5);
163   ASSERT_TRUE(mat.hasConsistentState());
164 
165   EXPECT_EQ(mat.getNumRows(), 6u);
166   EXPECT_EQ(mat.getNumColumns(), 5u);
167   for (unsigned row = 0; row < 6; ++row)
168     for (unsigned col = 0; col < 5; ++col)
169       EXPECT_EQ(mat(row, col), row == 5 ? 0 : 10 * row + col);
170 }
171 
172 TEST(MatrixTest, resize) {
173   IntMatrix mat(5, 5);
174   EXPECT_EQ(mat.getNumRows(), 5u);
175   EXPECT_EQ(mat.getNumColumns(), 5u);
176   for (unsigned row = 0; row < 5; ++row)
177     for (unsigned col = 0; col < 5; ++col)
178       mat(row, col) = 10 * row + col;
179 
180   mat.resize(3, 3);
181   ASSERT_TRUE(mat.hasConsistentState());
182   EXPECT_EQ(mat.getNumRows(), 3u);
183   EXPECT_EQ(mat.getNumColumns(), 3u);
184   for (unsigned row = 0; row < 3; ++row)
185     for (unsigned col = 0; col < 3; ++col)
186       EXPECT_EQ(mat(row, col), int(10 * row + col));
187 
188   mat.resize(7, 7);
189   ASSERT_TRUE(mat.hasConsistentState());
190   EXPECT_EQ(mat.getNumRows(), 7u);
191   EXPECT_EQ(mat.getNumColumns(), 7u);
192   for (unsigned row = 0; row < 7; ++row)
193     for (unsigned col = 0; col < 7; ++col)
194       EXPECT_EQ(mat(row, col), row >= 3 || col >= 3 ? 0 : int(10 * row + col));
195 }
196 
197 static void checkHermiteNormalForm(const IntMatrix &mat,
198                                    const IntMatrix &hermiteForm) {
199   auto [h, u] = mat.computeHermiteNormalForm();
200 
201   for (unsigned row = 0; row < mat.getNumRows(); row++)
202     for (unsigned col = 0; col < mat.getNumColumns(); col++)
203       EXPECT_EQ(h(row, col), hermiteForm(row, col));
204 }
205 
206 TEST(MatrixTest, computeHermiteNormalForm) {
207   // TODO: Add a check to test the original statement of hermite normal form
208   // instead of using a precomputed result.
209 
210   {
211     // Hermite form of a unimodular matrix is the identity matrix.
212     IntMatrix mat = makeIntMatrix(3, 3, {{2, 3, 6}, {3, 2, 3}, {17, 11, 16}});
213     IntMatrix hermiteForm =
214         makeIntMatrix(3, 3, {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}});
215     checkHermiteNormalForm(mat, hermiteForm);
216   }
217 
218   {
219     // Hermite form of a unimodular is the identity matrix.
220     IntMatrix mat = makeIntMatrix(
221         4, 4,
222         {{-6, -1, -19, -20}, {0, 1, 0, 0}, {-5, 0, -15, -16}, {6, 0, 18, 19}});
223     IntMatrix hermiteForm = makeIntMatrix(
224         4, 4, {{1, 0, 0, 0}, {0, 1, 0, 0}, {0, 0, 1, 0}, {0, 0, 0, 1}});
225     checkHermiteNormalForm(mat, hermiteForm);
226   }
227 
228   {
229     IntMatrix mat = makeIntMatrix(
230         4, 4, {{3, 3, 1, 4}, {0, 1, 0, 0}, {0, 0, 19, 16}, {0, 0, 0, 3}});
231     IntMatrix hermiteForm = makeIntMatrix(
232         4, 4, {{1, 0, 0, 0}, {0, 1, 0, 0}, {1, 0, 3, 0}, {18, 0, 54, 57}});
233     checkHermiteNormalForm(mat, hermiteForm);
234   }
235 
236   {
237     IntMatrix mat = makeIntMatrix(
238         4, 4, {{3, 3, 1, 4}, {0, 1, 0, 0}, {0, 0, 19, 16}, {0, 0, 0, 3}});
239     IntMatrix hermiteForm = makeIntMatrix(
240         4, 4, {{1, 0, 0, 0}, {0, 1, 0, 0}, {1, 0, 3, 0}, {18, 0, 54, 57}});
241     checkHermiteNormalForm(mat, hermiteForm);
242   }
243 
244   {
245     IntMatrix mat = makeIntMatrix(
246         3, 5, {{0, 2, 0, 7, 1}, {-1, 0, 0, -3, 0}, {0, 4, 1, 0, 8}});
247     IntMatrix hermiteForm = makeIntMatrix(
248         3, 5, {{1, 0, 0, 0, 0}, {0, 1, 0, 0, 0}, {0, 0, 1, 0, 0}});
249     checkHermiteNormalForm(mat, hermiteForm);
250   }
251 }
252 
253 TEST(MatrixTest, inverse) {
254   FracMatrix mat = makeFracMatrix(
255       2, 2, {{Fraction(2), Fraction(1)}, {Fraction(7), Fraction(0)}});
256   FracMatrix inverse = makeFracMatrix(
257       2, 2, {{Fraction(0), Fraction(1, 7)}, {Fraction(1), Fraction(-2, 7)}});
258 
259   FracMatrix inv(2, 2);
260   mat.determinant(&inv);
261 
262   EXPECT_EQ_FRAC_MATRIX(inv, inverse);
263 
264   mat = makeFracMatrix(
265       2, 2, {{Fraction(0), Fraction(1)}, {Fraction(0), Fraction(2)}});
266   Fraction det = mat.determinant(nullptr);
267 
268   EXPECT_EQ(det, Fraction(0));
269 
270   mat = makeFracMatrix(3, 3,
271                        {{Fraction(1), Fraction(2), Fraction(3)},
272                         {Fraction(4), Fraction(8), Fraction(6)},
273                         {Fraction(7), Fraction(8), Fraction(6)}});
274   inverse = makeFracMatrix(3, 3,
275                            {{Fraction(0), Fraction(-1, 3), Fraction(1, 3)},
276                             {Fraction(-1, 2), Fraction(5, 12), Fraction(-1, 6)},
277                             {Fraction(2, 3), Fraction(-1, 6), Fraction(0)}});
278 
279   mat.determinant(&inv);
280   EXPECT_EQ_FRAC_MATRIX(inv, inverse);
281 
282   mat = makeFracMatrix(0, 0, {});
283   mat.determinant(&inv);
284 }
285 
286 TEST(MatrixTest, intInverse) {
287   IntMatrix mat = makeIntMatrix(2, 2, {{2, 1}, {7, 0}});
288   IntMatrix inverse = makeIntMatrix(2, 2, {{0, -1}, {-7, 2}});
289 
290   IntMatrix inv(2, 2);
291   mat.determinant(&inv);
292 
293   EXPECT_EQ_INT_MATRIX(inv, inverse);
294 
295   mat = makeIntMatrix(
296       4, 4, {{4, 14, 11, 3}, {13, 5, 14, 12}, {13, 9, 7, 14}, {2, 3, 12, 7}});
297   inverse = makeIntMatrix(4, 4,
298                           {{155, 1636, -579, -1713},
299                            {725, -743, 537, -111},
300                            {210, 735, -855, 360},
301                            {-715, -1409, 1401, 1482}});
302 
303   mat.determinant(&inv);
304 
305   EXPECT_EQ_INT_MATRIX(inv, inverse);
306 
307   mat = makeIntMatrix(2, 2, {{0, 0}, {1, 2}});
308 
309   MPInt det = mat.determinant(&inv);
310 
311   EXPECT_EQ(det, 0);
312 }
313 
314 TEST(MatrixTest, gramSchmidt) {
315   FracMatrix mat =
316       makeFracMatrix(3, 5,
317                      {{Fraction(3, 1), Fraction(4, 1), Fraction(5, 1),
318                        Fraction(12, 1), Fraction(19, 1)},
319                       {Fraction(4, 1), Fraction(5, 1), Fraction(6, 1),
320                        Fraction(13, 1), Fraction(20, 1)},
321                       {Fraction(7, 1), Fraction(8, 1), Fraction(9, 1),
322                        Fraction(16, 1), Fraction(24, 1)}});
323 
324   FracMatrix gramSchmidt = makeFracMatrix(
325       3, 5,
326       {{Fraction(3, 1), Fraction(4, 1), Fraction(5, 1), Fraction(12, 1),
327         Fraction(19, 1)},
328        {Fraction(142, 185), Fraction(383, 555), Fraction(68, 111),
329         Fraction(13, 185), Fraction(-262, 555)},
330        {Fraction(53, 463), Fraction(27, 463), Fraction(1, 463),
331         Fraction(-181, 463), Fraction(100, 463)}});
332 
333   FracMatrix gs = mat.gramSchmidt();
334 
335   EXPECT_EQ_FRAC_MATRIX(gs, gramSchmidt);
336   for (unsigned i = 0; i < 3u; i++)
337     for (unsigned j = i + 1; j < 3u; j++)
338       EXPECT_EQ(dotProduct(gs.getRow(i), gs.getRow(j)), 0);
339 
340   mat = makeFracMatrix(3, 3,
341                        {{Fraction(20, 1), Fraction(17, 1), Fraction(10, 1)},
342                         {Fraction(20, 1), Fraction(18, 1), Fraction(6, 1)},
343                         {Fraction(15, 1), Fraction(14, 1), Fraction(10, 1)}});
344 
345   gramSchmidt = makeFracMatrix(
346       3, 3,
347       {{Fraction(20, 1), Fraction(17, 1), Fraction(10, 1)},
348        {Fraction(460, 789), Fraction(1180, 789), Fraction(-2926, 789)},
349        {Fraction(-2925, 3221), Fraction(3000, 3221), Fraction(750, 3221)}});
350 
351   gs = mat.gramSchmidt();
352 
353   EXPECT_EQ_FRAC_MATRIX(gs, gramSchmidt);
354   for (unsigned i = 0; i < 3u; i++)
355     for (unsigned j = i + 1; j < 3u; j++)
356       EXPECT_EQ(dotProduct(gs.getRow(i), gs.getRow(j)), 0);
357 
358   mat = makeFracMatrix(
359       4, 4,
360       {{Fraction(1, 26), Fraction(13, 12), Fraction(34, 13), Fraction(7, 10)},
361        {Fraction(40, 23), Fraction(34, 1), Fraction(11, 19), Fraction(15, 1)},
362        {Fraction(21, 22), Fraction(10, 9), Fraction(4, 11), Fraction(14, 11)},
363        {Fraction(35, 22), Fraction(1, 15), Fraction(5, 8), Fraction(30, 1)}});
364 
365   gs = mat.gramSchmidt();
366 
367   // The integers involved are too big to construct the actual matrix.
368   // but we can check that the result is linearly independent.
369   ASSERT_FALSE(mat.determinant(nullptr) == 0);
370 
371   for (unsigned i = 0; i < 4u; i++)
372     for (unsigned j = i + 1; j < 4u; j++)
373       EXPECT_EQ(dotProduct(gs.getRow(i), gs.getRow(j)), 0);
374 
375   mat = FracMatrix::identity(/*dim=*/10);
376 
377   gs = mat.gramSchmidt();
378 
379   EXPECT_EQ_FRAC_MATRIX(gs, FracMatrix::identity(10));
380 }
381 
382 void checkReducedBasis(FracMatrix mat, Fraction delta) {
383   FracMatrix gsOrth = mat.gramSchmidt();
384 
385   // Size-reduced check.
386   for (unsigned i = 0, e = mat.getNumRows(); i < e; i++) {
387     for (unsigned j = 0; j < i; j++) {
388       Fraction mu = dotProduct(mat.getRow(i), gsOrth.getRow(j)) /
389                     dotProduct(gsOrth.getRow(j), gsOrth.getRow(j));
390       EXPECT_TRUE(abs(mu) <= Fraction(1, 2));
391     }
392   }
393 
394   // Lovasz condition check.
395   for (unsigned i = 1, e = mat.getNumRows(); i < e; i++) {
396     Fraction mu = dotProduct(mat.getRow(i), gsOrth.getRow(i - 1)) /
397                   dotProduct(gsOrth.getRow(i - 1), gsOrth.getRow(i - 1));
398     EXPECT_TRUE(dotProduct(mat.getRow(i), mat.getRow(i)) >
399                 (delta - mu * mu) *
400                     dotProduct(gsOrth.getRow(i - 1), gsOrth.getRow(i - 1)));
401   }
402 }
403 
404 TEST(MatrixTest, LLL) {
405   FracMatrix mat =
406       makeFracMatrix(3, 3,
407                      {{Fraction(1, 1), Fraction(1, 1), Fraction(1, 1)},
408                       {Fraction(-1, 1), Fraction(0, 1), Fraction(2, 1)},
409                       {Fraction(3, 1), Fraction(5, 1), Fraction(6, 1)}});
410   mat.LLL(Fraction(3, 4));
411 
412   checkReducedBasis(mat, Fraction(3, 4));
413 
414   mat = makeFracMatrix(
415       2, 2,
416       {{Fraction(12, 1), Fraction(2, 1)}, {Fraction(13, 1), Fraction(4, 1)}});
417   mat.LLL(Fraction(3, 4));
418 
419   checkReducedBasis(mat, Fraction(3, 4));
420 
421   mat = makeFracMatrix(3, 3,
422                        {{Fraction(1, 1), Fraction(0, 1), Fraction(2, 1)},
423                         {Fraction(0, 1), Fraction(1, 3), -Fraction(5, 3)},
424                         {Fraction(0, 1), Fraction(0, 1), Fraction(1, 1)}});
425   mat.LLL(Fraction(3, 4));
426 
427   checkReducedBasis(mat, Fraction(3, 4));
428 }