xref: /llvm-project/mlir/unittests/Analysis/Presburger/MatrixTest.cpp (revision 1a0e67d73023e7ad9e7e79f66afb43a6f2561d04)
//===- MatrixTest.cpp - Tests for Matrix ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

9 #include "mlir/Analysis/Presburger/Matrix.h"
10 #include "./Utils.h"
11 #include "mlir/Analysis/Presburger/Fraction.h"
12 #include <gmock/gmock.h>
13 #include <gtest/gtest.h>
14 
15 using namespace mlir;
16 using namespace presburger;
17 
TEST(MatrixTest, ReadWrite) {
  // Every cell gets a distinct value so that a mix-up between two cells
  // (e.g. transposed indexing) is caught when the values are read back.
  const auto expected = [](unsigned r, unsigned c) { return int(10 * r + c); };

  IntMatrix mat(5, 5);
  for (unsigned r = 0; r < 5; ++r)
    for (unsigned c = 0; c < 5; ++c)
      mat(r, c) = expected(r, c);

  for (unsigned r = 0; r < 5; ++r)
    for (unsigned c = 0; c < 5; ++c)
      EXPECT_EQ(mat(r, c), expected(r, c));
}

TEST(MatrixTest, SwapColumns) {
  // Expects column `hot` of the 5x5 matrix `m` to hold ones and every other
  // entry to be zero.
  auto expectOnesOnlyInColumn = [](const IntMatrix &m, unsigned hot) {
    for (unsigned r = 0; r < 5; ++r)
      for (unsigned c = 0; c < 5; ++c)
        EXPECT_EQ(m(r, c), c == hot ? 1 : 0);
  };

  IntMatrix mat(5, 5);
  for (unsigned r = 0; r < 5; ++r)
    for (unsigned c = 0; c < 5; ++c)
      mat(r, c) = c == 3 ? 1 : 0;

  mat.swapColumns(3, 1);
  expectOnesOnlyInColumn(mat, 1);

  // Shuffle the columns around. The pair (1, 3) is swapped twice in this
  // sequence, which cancels out, and swapping a column with itself is a
  // no-op, so the ones must still be in column 1 afterwards.
  mat.swapColumns(3, 1);
  mat.swapColumns(2, 4);
  mat.swapColumns(1, 3);
  mat.swapColumns(0, 4);
  mat.swapColumns(2, 2);
  expectOnesOnlyInColumn(mat, 1);
}

TEST(MatrixTest, SwapRows) {
  // Expects row `hot` of the 5x5 matrix `m` to hold ones and every other
  // entry to be zero.
  auto expectOnesOnlyInRow = [](const IntMatrix &m, unsigned hot) {
    for (unsigned r = 0; r < 5; ++r)
      for (unsigned c = 0; c < 5; ++c)
        EXPECT_EQ(m(r, c), r == hot ? 1 : 0);
  };

  IntMatrix mat(5, 5);
  for (unsigned r = 0; r < 5; ++r)
    for (unsigned c = 0; c < 5; ++c)
      mat(r, c) = r == 2 ? 1 : 0;

  mat.swapRows(2, 0);
  expectOnesOnlyInRow(mat, 0);

  // Shuffle the rows around. The pair (2, 0) is swapped twice in this
  // sequence, which cancels out, and swapping a row with itself is a no-op,
  // so the ones must still be in row 0 afterwards.
  mat.swapRows(3, 4);
  mat.swapRows(1, 4);
  mat.swapRows(2, 0);
  mat.swapRows(1, 1);
  mat.swapRows(0, 2);
  expectOnesOnlyInRow(mat, 0);
}

TEST(MatrixTest, resizeVertically) {
  IntMatrix mat(5, 5);
  EXPECT_EQ(mat.getNumRows(), 5u);
  EXPECT_EQ(mat.getNumColumns(), 5u);
  for (unsigned r = 0; r < 5; ++r)
    for (unsigned c = 0; c < 5; ++c)
      mat(r, c) = 10 * r + c;

  // Shrinking keeps the leading rows untouched.
  mat.resizeVertically(3);
  ASSERT_TRUE(mat.hasConsistentState());
  EXPECT_EQ(mat.getNumRows(), 3u);
  EXPECT_EQ(mat.getNumColumns(), 5u);
  for (unsigned r = 0; r < 3; ++r)
    for (unsigned c = 0; c < 5; ++c)
      EXPECT_EQ(mat(r, c), int(10 * r + c));

  // Growing again must zero-fill the new rows instead of bringing back the
  // values that the shrink dropped.
  mat.resizeVertically(5);
  ASSERT_TRUE(mat.hasConsistentState());
  EXPECT_EQ(mat.getNumRows(), 5u);
  EXPECT_EQ(mat.getNumColumns(), 5u);
  for (unsigned r = 0; r < 5; ++r)
    for (unsigned c = 0; c < 5; ++c)
      EXPECT_EQ(mat(r, c), r < 3 ? int(10 * r + c) : 0);
}

TEST(MatrixTest, insertColumns) {
  // The two trailing constructor arguments are presumably the reserved row
  // and column counts (TODO confirm against Matrix.h). Inserting 100 columns
  // below then grows well past that reservation.
  IntMatrix mat(5, 5, 5, 10);
  EXPECT_EQ(mat.getNumRows(), 5u);
  EXPECT_EQ(mat.getNumColumns(), 5u);
  for (unsigned row = 0; row < 5; ++row)
    for (unsigned col = 0; col < 5; ++col)
      mat(row, col) = 10 * row + col;

  // Insert 100 zero columns at index 3: columns [3, 102] are new, and the
  // original columns 3 and 4 move to 103 and 104.
  mat.insertColumns(3, 100);
  ASSERT_TRUE(mat.hasConsistentState());
  EXPECT_EQ(mat.getNumRows(), 5u);
  EXPECT_EQ(mat.getNumColumns(), 105u);
  for (unsigned row = 0; row < 5; ++row) {
    for (unsigned col = 0; col < 105; ++col) {
      if (col < 3)
        EXPECT_EQ(mat(row, col), int(10 * row + col));
      else if (col <= 102)
        EXPECT_EQ(mat(row, col), 0);
      else
        EXPECT_EQ(mat(row, col), int(10 * row + col - 100));
    }
  }

  // Undo the insertion, insert zero columns (must leave the matrix
  // unchanged), then append a single zero column at the end.
  mat.removeColumns(3, 100);
  ASSERT_TRUE(mat.hasConsistentState());
  mat.insertColumns(0, 0);
  ASSERT_TRUE(mat.hasConsistentState());
  mat.insertColumn(5);
  ASSERT_TRUE(mat.hasConsistentState());

  EXPECT_EQ(mat.getNumRows(), 5u);
  EXPECT_EQ(mat.getNumColumns(), 6u);
  for (unsigned row = 0; row < 5; ++row)
    for (unsigned col = 0; col < 6; ++col)
      // Both ternary arms are signed so the expected value has the same type
      // as in the other checks of this file.
      EXPECT_EQ(mat(row, col), col == 5 ? 0 : int(10 * row + col));
}

TEST(MatrixTest, insertRows) {
  // The two trailing constructor arguments are presumably the reserved row
  // and column counts (TODO confirm against Matrix.h). Inserting 100 rows
  // below then grows well past that reservation.
  IntMatrix mat(5, 5, 5, 10);
  ASSERT_TRUE(mat.hasConsistentState());
  EXPECT_EQ(mat.getNumRows(), 5u);
  EXPECT_EQ(mat.getNumColumns(), 5u);
  for (unsigned row = 0; row < 5; ++row)
    for (unsigned col = 0; col < 5; ++col)
      mat(row, col) = 10 * row + col;

  // Insert 100 zero rows at index 3: rows [3, 102] are new, and the original
  // rows 3 and 4 move to 103 and 104.
  mat.insertRows(3, 100);
  ASSERT_TRUE(mat.hasConsistentState());
  EXPECT_EQ(mat.getNumRows(), 105u);
  EXPECT_EQ(mat.getNumColumns(), 5u);
  for (unsigned row = 0; row < 105; ++row) {
    for (unsigned col = 0; col < 5; ++col) {
      if (row < 3)
        EXPECT_EQ(mat(row, col), int(10 * row + col));
      else if (row <= 102)
        EXPECT_EQ(mat(row, col), 0);
      else
        EXPECT_EQ(mat(row, col), int(10 * (row - 100) + col));
    }
  }

  // Undo the insertion, insert zero rows (must leave the matrix unchanged),
  // then append a single zero row at the end.
  mat.removeRows(3, 100);
  ASSERT_TRUE(mat.hasConsistentState());
  mat.insertRows(0, 0);
  ASSERT_TRUE(mat.hasConsistentState());
  mat.insertRow(5);
  ASSERT_TRUE(mat.hasConsistentState());

  EXPECT_EQ(mat.getNumRows(), 6u);
  EXPECT_EQ(mat.getNumColumns(), 5u);
  for (unsigned row = 0; row < 6; ++row)
    for (unsigned col = 0; col < 5; ++col)
      // Both ternary arms are signed so the expected value has the same type
      // as in the other checks of this file.
      EXPECT_EQ(mat(row, col), row == 5 ? 0 : int(10 * row + col));
}

TEST(MatrixTest, resize) {
  IntMatrix mat(5, 5);
  EXPECT_EQ(mat.getNumRows(), 5u);
  EXPECT_EQ(mat.getNumColumns(), 5u);
  for (unsigned r = 0; r < 5; ++r)
    for (unsigned c = 0; c < 5; ++c)
      mat(r, c) = 10 * r + c;

  // Shrink both dimensions: the top-left 3x3 corner must survive intact.
  mat.resize(3, 3);
  ASSERT_TRUE(mat.hasConsistentState());
  EXPECT_EQ(mat.getNumRows(), 3u);
  EXPECT_EQ(mat.getNumColumns(), 3u);
  for (unsigned r = 0; r < 3; ++r)
    for (unsigned c = 0; c < 3; ++c)
      EXPECT_EQ(mat(r, c), int(10 * r + c));

  // Grow past the original size: only the surviving 3x3 corner is non-zero.
  mat.resize(7, 7);
  ASSERT_TRUE(mat.hasConsistentState());
  EXPECT_EQ(mat.getNumRows(), 7u);
  EXPECT_EQ(mat.getNumColumns(), 7u);
  for (unsigned r = 0; r < 7; ++r) {
    for (unsigned c = 0; c < 7; ++c) {
      bool inKeptCorner = r < 3 && c < 3;
      EXPECT_EQ(mat(r, c), inKeptCorner ? int(10 * r + c) : 0);
    }
  }
}

197 template <typename T>
checkMatEqual(const Matrix<T> m1,const Matrix<T> m2)198 static void checkMatEqual(const Matrix<T> m1, const Matrix<T> m2) {
199   EXPECT_EQ(m1.getNumRows(), m2.getNumRows());
200   EXPECT_EQ(m1.getNumColumns(), m2.getNumColumns());
201 
202   for (unsigned row = 0, rows = m1.getNumRows(); row < rows; ++row)
203     for (unsigned col = 0, cols = m1.getNumColumns(); col < cols; ++col)
204       EXPECT_EQ(m1(row, col), m2(row, col));
205 }
206 
checkHermiteNormalForm(const IntMatrix & mat,const IntMatrix & hermiteForm)207 static void checkHermiteNormalForm(const IntMatrix &mat,
208                                    const IntMatrix &hermiteForm) {
209   auto [h, u] = mat.computeHermiteNormalForm();
210 
211   checkMatEqual(h, hermiteForm);
212 }
213 
TEST(MatrixTest, computeHermiteNormalForm) {
  // TODO: Add a check to test the original statement of hermite normal form
  // instead of using a precomputed result.

  {
    // Hermite form of a unimodular matrix is the identity matrix.
    IntMatrix mat = makeIntMatrix(3, 3, {{2, 3, 6}, {3, 2, 3}, {17, 11, 16}});
    IntMatrix hermiteForm =
        makeIntMatrix(3, 3, {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}});
    checkHermiteNormalForm(mat, hermiteForm);
  }

  {
    // Hermite form of a unimodular matrix is the identity matrix.
    IntMatrix mat = makeIntMatrix(
        4, 4,
        {{-6, -1, -19, -20}, {0, 1, 0, 0}, {-5, 0, -15, -16}, {6, 0, 18, 19}});
    IntMatrix hermiteForm = makeIntMatrix(
        4, 4, {{1, 0, 0, 0}, {0, 1, 0, 0}, {0, 0, 1, 0}, {0, 0, 0, 1}});
    checkHermiteNormalForm(mat, hermiteForm);
  }

  {
    // Non-unimodular upper-triangular input (determinant 3 * 1 * 19 * 3 =
    // 171). The expected Hermite form is lower triangular with the same
    // determinant (1 * 1 * 3 * 57 = 171). This case previously appeared
    // twice verbatim; the duplicate was removed.
    IntMatrix mat = makeIntMatrix(
        4, 4, {{3, 3, 1, 4}, {0, 1, 0, 0}, {0, 0, 19, 16}, {0, 0, 0, 3}});
    IntMatrix hermiteForm = makeIntMatrix(
        4, 4, {{1, 0, 0, 0}, {0, 1, 0, 0}, {1, 0, 3, 0}, {18, 0, 54, 57}});
    checkHermiteNormalForm(mat, hermiteForm);
  }

  {
    // Wide matrix whose columns generate all of Z^3, so the Hermite form is
    // the identity followed by zero columns.
    IntMatrix mat = makeIntMatrix(
        3, 5, {{0, 2, 0, 7, 1}, {-1, 0, 0, -3, 0}, {0, 4, 1, 0, 8}});
    IntMatrix hermiteForm = makeIntMatrix(
        3, 5, {{1, 0, 0, 0, 0}, {0, 1, 0, 0, 0}, {0, 0, 1, 0, 0}});
    checkHermiteNormalForm(mat, hermiteForm);
  }
}

TEST(MatrixTest, inverse) {
  // Integer determinant of a small non-singular matrix.
  IntMatrix intMat = makeIntMatrix(2, 2, {{2, 1}, {7, 0}});
  EXPECT_EQ(intMat.determinant(), -7);

  // The same matrix over the rationals, together with its exact inverse.
  FracMatrix mat = makeFracMatrix(
      2, 2, {{Fraction(2), Fraction(1)}, {Fraction(7), Fraction(0)}});
  FracMatrix expectedInv = makeFracMatrix(
      2, 2, {{Fraction(0), Fraction(1, 7)}, {Fraction(1), Fraction(-2, 7)}});

  FracMatrix inv(2, 2);
  mat.determinant(&inv);
  EXPECT_EQ_FRAC_MATRIX(inv, expectedInv);

  // A matrix whose first column is all zeros is singular.
  mat = makeFracMatrix(
      2, 2, {{Fraction(0), Fraction(1)}, {Fraction(0), Fraction(2)}});
  EXPECT_EQ(mat.determinant(nullptr), Fraction(0));

  // 3x3 case in which the second pivot vanishes after the first column has
  // been cleared.
  mat = makeFracMatrix(3, 3,
                       {{Fraction(1), Fraction(2), Fraction(3)},
                        {Fraction(4), Fraction(8), Fraction(6)},
                        {Fraction(7), Fraction(8), Fraction(6)}});
  expectedInv =
      makeFracMatrix(3, 3,
                     {{Fraction(0), Fraction(-1, 3), Fraction(1, 3)},
                      {Fraction(-1, 2), Fraction(5, 12), Fraction(-1, 6)},
                      {Fraction(2, 3), Fraction(-1, 6), Fraction(0)}});
  mat.determinant(&inv);
  EXPECT_EQ_FRAC_MATRIX(inv, expectedInv);

  // Smoke test: the empty 0x0 matrix must not crash.
  mat = makeFracMatrix(0, 0, {});
  mat.determinant(&inv);
}

TEST(MatrixTest, intInverse) {
  // For integer matrices the expected "inverse" is the true inverse scaled
  // by the determinant, which keeps every entry integral. For the 2x2 case
  // below, det = -7 and the expected matrix equals -7 * mat^{-1}.
  IntMatrix inv(2, 2);

  IntMatrix mat = makeIntMatrix(2, 2, {{2, 1}, {7, 0}});
  IntMatrix expected = makeIntMatrix(2, 2, {{0, -1}, {-7, 2}});
  mat.determinant(&inv);
  EXPECT_EQ_INT_MATRIX(inv, expected);

  mat = makeIntMatrix(
      4, 4, {{4, 14, 11, 3}, {13, 5, 14, 12}, {13, 9, 7, 14}, {2, 3, 12, 7}});
  expected = makeIntMatrix(4, 4,
                           {{155, 1636, -579, -1713},
                            {725, -743, 537, -111},
                            {210, 735, -855, 360},
                            {-715, -1409, 1401, 1482}});
  mat.determinant(&inv);
  EXPECT_EQ_INT_MATRIX(inv, expected);

  // A matrix with a zero row is singular.
  mat = makeIntMatrix(2, 2, {{0, 0}, {1, 2}});
  EXPECT_EQ(mat.determinant(&inv), 0);
}

TEST(MatrixTest,gramSchmidt)325 TEST(MatrixTest, gramSchmidt) {
326   FracMatrix mat =
327       makeFracMatrix(3, 5,
328                      {{Fraction(3, 1), Fraction(4, 1), Fraction(5, 1),
329                        Fraction(12, 1), Fraction(19, 1)},
330                       {Fraction(4, 1), Fraction(5, 1), Fraction(6, 1),
331                        Fraction(13, 1), Fraction(20, 1)},
332                       {Fraction(7, 1), Fraction(8, 1), Fraction(9, 1),
333                        Fraction(16, 1), Fraction(24, 1)}});
334 
335   FracMatrix gramSchmidt = makeFracMatrix(
336       3, 5,
337       {{Fraction(3, 1), Fraction(4, 1), Fraction(5, 1), Fraction(12, 1),
338         Fraction(19, 1)},
339        {Fraction(142, 185), Fraction(383, 555), Fraction(68, 111),
340         Fraction(13, 185), Fraction(-262, 555)},
341        {Fraction(53, 463), Fraction(27, 463), Fraction(1, 463),
342         Fraction(-181, 463), Fraction(100, 463)}});
343 
344   FracMatrix gs = mat.gramSchmidt();
345 
346   EXPECT_EQ_FRAC_MATRIX(gs, gramSchmidt);
347   for (unsigned i = 0; i < 3u; i++)
348     for (unsigned j = i + 1; j < 3u; j++)
349       EXPECT_EQ(dotProduct(gs.getRow(i), gs.getRow(j)), 0);
350 
351   mat = makeFracMatrix(3, 3,
352                        {{Fraction(20, 1), Fraction(17, 1), Fraction(10, 1)},
353                         {Fraction(20, 1), Fraction(18, 1), Fraction(6, 1)},
354                         {Fraction(15, 1), Fraction(14, 1), Fraction(10, 1)}});
355 
356   gramSchmidt = makeFracMatrix(
357       3, 3,
358       {{Fraction(20, 1), Fraction(17, 1), Fraction(10, 1)},
359        {Fraction(460, 789), Fraction(1180, 789), Fraction(-2926, 789)},
360        {Fraction(-2925, 3221), Fraction(3000, 3221), Fraction(750, 3221)}});
361 
362   gs = mat.gramSchmidt();
363 
364   EXPECT_EQ_FRAC_MATRIX(gs, gramSchmidt);
365   for (unsigned i = 0; i < 3u; i++)
366     for (unsigned j = i + 1; j < 3u; j++)
367       EXPECT_EQ(dotProduct(gs.getRow(i), gs.getRow(j)), 0);
368 
369   mat = makeFracMatrix(
370       4, 4,
371       {{Fraction(1, 26), Fraction(13, 12), Fraction(34, 13), Fraction(7, 10)},
372        {Fraction(40, 23), Fraction(34, 1), Fraction(11, 19), Fraction(15, 1)},
373        {Fraction(21, 22), Fraction(10, 9), Fraction(4, 11), Fraction(14, 11)},
374        {Fraction(35, 22), Fraction(1, 15), Fraction(5, 8), Fraction(30, 1)}});
375 
376   gs = mat.gramSchmidt();
377 
378   // The integers involved are too big to construct the actual matrix.
379   // but we can check that the result is linearly independent.
380   ASSERT_FALSE(mat.determinant(nullptr) == 0);
381 
382   for (unsigned i = 0; i < 4u; i++)
383     for (unsigned j = i + 1; j < 4u; j++)
384       EXPECT_EQ(dotProduct(gs.getRow(i), gs.getRow(j)), 0);
385 
386   mat = FracMatrix::identity(/*dim=*/10);
387 
388   gs = mat.gramSchmidt();
389 
390   EXPECT_EQ_FRAC_MATRIX(gs, FracMatrix::identity(10));
391 }
392 
checkReducedBasis(FracMatrix mat,Fraction delta)393 void checkReducedBasis(FracMatrix mat, Fraction delta) {
394   FracMatrix gsOrth = mat.gramSchmidt();
395 
396   // Size-reduced check.
397   for (unsigned i = 0, e = mat.getNumRows(); i < e; i++) {
398     for (unsigned j = 0; j < i; j++) {
399       Fraction mu = dotProduct(mat.getRow(i), gsOrth.getRow(j)) /
400                     dotProduct(gsOrth.getRow(j), gsOrth.getRow(j));
401       EXPECT_TRUE(abs(mu) <= Fraction(1, 2));
402     }
403   }
404 
405   // Lovasz condition check.
406   for (unsigned i = 1, e = mat.getNumRows(); i < e; i++) {
407     Fraction mu = dotProduct(mat.getRow(i), gsOrth.getRow(i - 1)) /
408                   dotProduct(gsOrth.getRow(i - 1), gsOrth.getRow(i - 1));
409     EXPECT_TRUE(dotProduct(mat.getRow(i), mat.getRow(i)) >
410                 (delta - mu * mu) *
411                     dotProduct(gsOrth.getRow(i - 1), gsOrth.getRow(i - 1)));
412   }
413 }
414 
TEST(MatrixTest, LLL) {
  const Fraction delta(3, 4);

  // Runs LLL on `basis` in place and verifies the result is reduced.
  auto reduceAndCheck = [&delta](FracMatrix basis) {
    basis.LLL(delta);
    checkReducedBasis(basis, delta);
  };

  reduceAndCheck(
      makeFracMatrix(3, 3,
                     {{Fraction(1, 1), Fraction(1, 1), Fraction(1, 1)},
                      {Fraction(-1, 1), Fraction(0, 1), Fraction(2, 1)},
                      {Fraction(3, 1), Fraction(5, 1), Fraction(6, 1)}}));

  reduceAndCheck(makeFracMatrix(
      2, 2,
      {{Fraction(12, 1), Fraction(2, 1)}, {Fraction(13, 1), Fraction(4, 1)}}));

  // A basis with non-integral entries.
  reduceAndCheck(
      makeFracMatrix(3, 3,
                     {{Fraction(1, 1), Fraction(0, 1), Fraction(2, 1)},
                      {Fraction(0, 1), Fraction(1, 3), -Fraction(5, 3)},
                      {Fraction(0, 1), Fraction(0, 1), Fraction(1, 1)}}));
}

TEST(MatrixTest, moveColumns) {
  IntMatrix mat =
      makeIntMatrix(3, 4, {{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 4, 2}});

  // Each case starts from a column permutation of `mat` and checks that the
  // given moveColumns(srcPos, num, dstPos) call restores the original order.
  auto expectRestores = [&mat](IntMatrix permuted, unsigned srcPos,
                               unsigned num, unsigned dstPos) {
    permuted.moveColumns(srcPos, num, dstPos);
    checkMatEqual(mat, permuted);
  };

  // Move a block of two columns to the left.
  expectRestores(
      makeIntMatrix(3, 4, {{0, 3, 1, 2}, {4, 7, 5, 6}, {8, 2, 9, 4}}),
      /*srcPos=*/2, /*num=*/2, /*dstPos=*/1);

  // Move a single column to the right end.
  expectRestores(
      makeIntMatrix(3, 4, {{0, 3, 1, 2}, {4, 7, 5, 6}, {8, 2, 9, 4}}),
      /*srcPos=*/1, /*num=*/1, /*dstPos=*/3);

  // Move a block of two columns to the right.
  expectRestores(
      makeIntMatrix(3, 4, {{1, 2, 0, 3}, {5, 6, 4, 7}, {9, 4, 8, 2}}),
      /*srcPos=*/0, /*num=*/2, /*dstPos=*/1);

  // Move a single column one step to the right.
  expectRestores(
      makeIntMatrix(3, 4, {{1, 0, 2, 3}, {5, 4, 6, 7}, {9, 8, 4, 2}}),
      /*srcPos=*/0, /*num=*/1, /*dstPos=*/1);
}