xref: /llvm-project/mlir/lib/Analysis/Presburger/Matrix.cpp (revision 2740273505ab27c0d8531d35948f0647309842cd)
110a898b3SArjun P //===- Matrix.cpp - MLIR Matrix Class -------------------------------------===//
210a898b3SArjun P //
310a898b3SArjun P // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
410a898b3SArjun P // See https://llvm.org/LICENSE.txt for license information.
510a898b3SArjun P // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
610a898b3SArjun P //
710a898b3SArjun P //===----------------------------------------------------------------------===//
810a898b3SArjun P 
910a898b3SArjun P #include "mlir/Analysis/Presburger/Matrix.h"
10c1b99746SAbhinav271828 #include "mlir/Analysis/Presburger/Fraction.h"
11aafb4282SArjun P #include "mlir/Analysis/Presburger/Utils.h"
12c605dfcfSArjun P #include "llvm/Support/MathExtras.h"
132ee87cd6SMehdi Amini #include "llvm/Support/raw_ostream.h"
142ee87cd6SMehdi Amini #include <algorithm>
152ee87cd6SMehdi Amini #include <cassert>
162ee87cd6SMehdi Amini #include <utility>
1710a898b3SArjun P 
180c1f6865SGroverkss using namespace mlir;
190c1f6865SGroverkss using namespace presburger;
2010a898b3SArjun P 
217025ff6fSArjun P template <typename T>
227025ff6fSArjun P Matrix<T>::Matrix(unsigned rows, unsigned columns, unsigned reservedRows,
23c605dfcfSArjun P                   unsigned reservedColumns)
24c605dfcfSArjun P     : nRows(rows), nColumns(columns),
25c605dfcfSArjun P       nReservedColumns(std::max(nColumns, reservedColumns)),
26c605dfcfSArjun P       data(nRows * nReservedColumns) {
27c605dfcfSArjun P   data.reserve(std::max(nRows, reservedRows) * nReservedColumns);
28c605dfcfSArjun P }
2910a898b3SArjun P 
30562790f3SAbhinav271828 /// We cannot use the default implementation of operator== as it compares
31562790f3SAbhinav271828 /// fields like `reservedColumns` etc., which are not part of the data.
32562790f3SAbhinav271828 template <typename T>
33562790f3SAbhinav271828 bool Matrix<T>::operator==(const Matrix<T> &m) const {
34562790f3SAbhinav271828   if (nRows != m.getNumRows())
35562790f3SAbhinav271828     return false;
36562790f3SAbhinav271828   if (nColumns != m.getNumColumns())
37562790f3SAbhinav271828     return false;
38562790f3SAbhinav271828 
39562790f3SAbhinav271828   for (unsigned i = 0; i < nRows; i++)
40562790f3SAbhinav271828     if (getRow(i) != m.getRow(i))
41562790f3SAbhinav271828       return false;
42562790f3SAbhinav271828 
43562790f3SAbhinav271828   return true;
44562790f3SAbhinav271828 }
45562790f3SAbhinav271828 
467025ff6fSArjun P template <typename T>
477025ff6fSArjun P Matrix<T> Matrix<T>::identity(unsigned dimension) {
4810a898b3SArjun P   Matrix matrix(dimension, dimension);
4910a898b3SArjun P   for (unsigned i = 0; i < dimension; ++i)
5010a898b3SArjun P     matrix(i, i) = 1;
5110a898b3SArjun P   return matrix;
5210a898b3SArjun P }
5310a898b3SArjun P 
547025ff6fSArjun P template <typename T>
557025ff6fSArjun P unsigned Matrix<T>::getNumReservedRows() const {
56c605dfcfSArjun P   return data.capacity() / nReservedColumns;
57c605dfcfSArjun P }
58c605dfcfSArjun P 
597025ff6fSArjun P template <typename T>
607025ff6fSArjun P void Matrix<T>::reserveRows(unsigned rows) {
61c605dfcfSArjun P   data.reserve(rows * nReservedColumns);
62c605dfcfSArjun P }
63c605dfcfSArjun P 
647025ff6fSArjun P template <typename T>
657025ff6fSArjun P unsigned Matrix<T>::appendExtraRow() {
66c605dfcfSArjun P   resizeVertically(nRows + 1);
67c605dfcfSArjun P   return nRows - 1;
68c605dfcfSArjun P }
69c605dfcfSArjun P 
707025ff6fSArjun P template <typename T>
717025ff6fSArjun P unsigned Matrix<T>::appendExtraRow(ArrayRef<T> elems) {
7279ad5fb2SArjun P   assert(elems.size() == nColumns && "elems must match row length!");
7379ad5fb2SArjun P   unsigned row = appendExtraRow();
7479ad5fb2SArjun P   for (unsigned col = 0; col < nColumns; ++col)
7579ad5fb2SArjun P     at(row, col) = elems[col];
7679ad5fb2SArjun P   return row;
7779ad5fb2SArjun P }
7879ad5fb2SArjun P 
797025ff6fSArjun P template <typename T>
802dde029dSAbhinav271828 Matrix<T> Matrix<T>::transpose() const {
812dde029dSAbhinav271828   Matrix<T> transp(nColumns, nRows);
822dde029dSAbhinav271828   for (unsigned row = 0; row < nRows; ++row)
832dde029dSAbhinav271828     for (unsigned col = 0; col < nColumns; ++col)
842dde029dSAbhinav271828       transp(col, row) = at(row, col);
852dde029dSAbhinav271828 
862dde029dSAbhinav271828   return transp;
872dde029dSAbhinav271828 }
882dde029dSAbhinav271828 
892dde029dSAbhinav271828 template <typename T>
907025ff6fSArjun P void Matrix<T>::resizeHorizontally(unsigned newNColumns) {
91f263ea15SArjun P   if (newNColumns < nColumns)
92f263ea15SArjun P     removeColumns(newNColumns, nColumns - newNColumns);
93f263ea15SArjun P   if (newNColumns > nColumns)
94f263ea15SArjun P     insertColumns(nColumns, newNColumns - nColumns);
95f263ea15SArjun P }
96f263ea15SArjun P 
977025ff6fSArjun P template <typename T>
987025ff6fSArjun P void Matrix<T>::resize(unsigned newNRows, unsigned newNColumns) {
99f263ea15SArjun P   resizeHorizontally(newNColumns);
100f263ea15SArjun P   resizeVertically(newNRows);
101f263ea15SArjun P }
102f263ea15SArjun P 
1037025ff6fSArjun P template <typename T>
1047025ff6fSArjun P void Matrix<T>::resizeVertically(unsigned newNRows) {
10510a898b3SArjun P   nRows = newNRows;
106c605dfcfSArjun P   data.resize(nRows * nReservedColumns);
10710a898b3SArjun P }
10810a898b3SArjun P 
1097025ff6fSArjun P template <typename T>
1107025ff6fSArjun P void Matrix<T>::swapRows(unsigned row, unsigned otherRow) {
11110a898b3SArjun P   assert((row < getNumRows() && otherRow < getNumRows()) &&
11210a898b3SArjun P          "Given row out of bounds");
11310a898b3SArjun P   if (row == otherRow)
11410a898b3SArjun P     return;
11510a898b3SArjun P   for (unsigned col = 0; col < nColumns; col++)
11610a898b3SArjun P     std::swap(at(row, col), at(otherRow, col));
11710a898b3SArjun P }
11810a898b3SArjun P 
1197025ff6fSArjun P template <typename T>
1207025ff6fSArjun P void Matrix<T>::swapColumns(unsigned column, unsigned otherColumn) {
12110a898b3SArjun P   assert((column < getNumColumns() && otherColumn < getNumColumns()) &&
12210a898b3SArjun P          "Given column out of bounds");
12310a898b3SArjun P   if (column == otherColumn)
12410a898b3SArjun P     return;
12510a898b3SArjun P   for (unsigned row = 0; row < nRows; row++)
12610a898b3SArjun P     std::swap(at(row, column), at(row, otherColumn));
12710a898b3SArjun P }
12810a898b3SArjun P 
1297025ff6fSArjun P template <typename T>
1307025ff6fSArjun P MutableArrayRef<T> Matrix<T>::getRow(unsigned row) {
131d98dfdeaSArjun P   return {&data[row * nReservedColumns], nColumns};
132d98dfdeaSArjun P }
133d98dfdeaSArjun P 
1347025ff6fSArjun P template <typename T>
1357025ff6fSArjun P ArrayRef<T> Matrix<T>::getRow(unsigned row) const {
136c605dfcfSArjun P   return {&data[row * nReservedColumns], nColumns};
137c605dfcfSArjun P }
138c605dfcfSArjun P 
1397025ff6fSArjun P template <typename T>
1407025ff6fSArjun P void Matrix<T>::setRow(unsigned row, ArrayRef<T> elems) {
141479c4f64SGroverkss   assert(elems.size() == getNumColumns() &&
142479c4f64SGroverkss          "elems size must match row length!");
143479c4f64SGroverkss   for (unsigned i = 0, e = getNumColumns(); i < e; ++i)
144479c4f64SGroverkss     at(row, i) = elems[i];
145479c4f64SGroverkss }
146479c4f64SGroverkss 
1477025ff6fSArjun P template <typename T>
1487025ff6fSArjun P void Matrix<T>::insertColumn(unsigned pos) {
1497025ff6fSArjun P   insertColumns(pos, 1);
1507025ff6fSArjun P }
1517025ff6fSArjun P template <typename T>
1527025ff6fSArjun P void Matrix<T>::insertColumns(unsigned pos, unsigned count) {
153c605dfcfSArjun P   if (count == 0)
154c605dfcfSArjun P     return;
155c605dfcfSArjun P   assert(pos <= nColumns);
156c605dfcfSArjun P   unsigned oldNReservedColumns = nReservedColumns;
157c605dfcfSArjun P   if (nColumns + count > nReservedColumns) {
158c605dfcfSArjun P     nReservedColumns = llvm::NextPowerOf2(nColumns + count);
159c605dfcfSArjun P     data.resize(nRows * nReservedColumns);
160c605dfcfSArjun P   }
161c605dfcfSArjun P   nColumns += count;
162c605dfcfSArjun P 
163c605dfcfSArjun P   for (int ri = nRows - 1; ri >= 0; --ri) {
164c605dfcfSArjun P     for (int ci = nReservedColumns - 1; ci >= 0; --ci) {
165c605dfcfSArjun P       unsigned r = ri;
166c605dfcfSArjun P       unsigned c = ci;
167c1b99746SAbhinav271828       T &dest = data[r * nReservedColumns + c];
16830c0a148SArjun P       if (c >= nColumns) { // NOLINT
16930c0a148SArjun P         // Out of bounds columns are zero-initialized. NOLINT because clang-tidy
17030c0a148SArjun P         // complains about this branch being the same as the c >= pos one.
17130c0a148SArjun P         //
17230c0a148SArjun P         // TODO: this case can be skipped if the number of reserved columns
17330c0a148SArjun P         // didn't change.
174c605dfcfSArjun P         dest = 0;
17530c0a148SArjun P       } else if (c >= pos + count) {
17630c0a148SArjun P         // Shift the data occuring after the inserted columns.
177c605dfcfSArjun P         dest = data[r * oldNReservedColumns + c - count];
17830c0a148SArjun P       } else if (c >= pos) {
17930c0a148SArjun P         // The inserted columns are also zero-initialized.
180c605dfcfSArjun P         dest = 0;
18130c0a148SArjun P       } else {
18230c0a148SArjun P         // The columns before the inserted columns stay at the same (row, col)
18330c0a148SArjun P         // but this corresponds to a different location in the linearized array
18430c0a148SArjun P         // if the number of reserved columns changed.
18530c0a148SArjun P         if (nReservedColumns == oldNReservedColumns)
18630c0a148SArjun P           break;
187c605dfcfSArjun P         dest = data[r * oldNReservedColumns + c];
188c605dfcfSArjun P       }
189c605dfcfSArjun P     }
190c605dfcfSArjun P   }
19130c0a148SArjun P }
192c605dfcfSArjun P 
1937025ff6fSArjun P template <typename T>
1947025ff6fSArjun P void Matrix<T>::removeColumn(unsigned pos) {
1957025ff6fSArjun P   removeColumns(pos, 1);
1967025ff6fSArjun P }
1977025ff6fSArjun P template <typename T>
1987025ff6fSArjun P void Matrix<T>::removeColumns(unsigned pos, unsigned count) {
199c605dfcfSArjun P   if (count == 0)
200c605dfcfSArjun P     return;
201c605dfcfSArjun P   assert(pos + count - 1 < nColumns);
202c605dfcfSArjun P   for (unsigned r = 0; r < nRows; ++r) {
203c605dfcfSArjun P     for (unsigned c = pos; c < nColumns - count; ++c)
204c605dfcfSArjun P       at(r, c) = at(r, c + count);
205c605dfcfSArjun P     for (unsigned c = nColumns - count; c < nColumns; ++c)
206c605dfcfSArjun P       at(r, c) = 0;
207c605dfcfSArjun P   }
208c605dfcfSArjun P   nColumns -= count;
209c605dfcfSArjun P }
210c605dfcfSArjun P 
2117025ff6fSArjun P template <typename T>
2127025ff6fSArjun P void Matrix<T>::insertRow(unsigned pos) {
2137025ff6fSArjun P   insertRows(pos, 1);
2147025ff6fSArjun P }
2157025ff6fSArjun P template <typename T>
2167025ff6fSArjun P void Matrix<T>::insertRows(unsigned pos, unsigned count) {
217c605dfcfSArjun P   if (count == 0)
218c605dfcfSArjun P     return;
219c605dfcfSArjun P 
220c605dfcfSArjun P   assert(pos <= nRows);
221c605dfcfSArjun P   resizeVertically(nRows + count);
222c605dfcfSArjun P   for (int r = nRows - 1; r >= int(pos + count); --r)
223c605dfcfSArjun P     copyRow(r - count, r);
224c605dfcfSArjun P   for (int r = pos + count - 1; r >= int(pos); --r)
225c605dfcfSArjun P     for (unsigned c = 0; c < nColumns; ++c)
226c605dfcfSArjun P       at(r, c) = 0;
227c605dfcfSArjun P }
228c605dfcfSArjun P 
2297025ff6fSArjun P template <typename T>
2307025ff6fSArjun P void Matrix<T>::removeRow(unsigned pos) {
2317025ff6fSArjun P   removeRows(pos, 1);
2327025ff6fSArjun P }
2337025ff6fSArjun P template <typename T>
2347025ff6fSArjun P void Matrix<T>::removeRows(unsigned pos, unsigned count) {
235c605dfcfSArjun P   if (count == 0)
236c605dfcfSArjun P     return;
237c605dfcfSArjun P   assert(pos + count - 1 <= nRows);
238c605dfcfSArjun P   for (unsigned r = pos; r + count < nRows; ++r)
239c605dfcfSArjun P     copyRow(r + count, r);
240c605dfcfSArjun P   resizeVertically(nRows - count);
241c605dfcfSArjun P }
242c605dfcfSArjun P 
2437025ff6fSArjun P template <typename T>
2447025ff6fSArjun P void Matrix<T>::copyRow(unsigned sourceRow, unsigned targetRow) {
245c605dfcfSArjun P   if (sourceRow == targetRow)
246c605dfcfSArjun P     return;
247c605dfcfSArjun P   for (unsigned c = 0; c < nColumns; ++c)
248c605dfcfSArjun P     at(targetRow, c) = at(sourceRow, c);
24910a898b3SArjun P }
25010a898b3SArjun P 
2517025ff6fSArjun P template <typename T>
2527025ff6fSArjun P void Matrix<T>::fillRow(unsigned row, const T &value) {
2536db01958SArjun P   for (unsigned col = 0; col < nColumns; ++col)
2546db01958SArjun P     at(row, col) = value;
2556db01958SArjun P }
2566db01958SArjun P 
25766786a79SBharathi Ramana Joshi // moveColumns is implemented by moving the columns adjacent to the source range
25866786a79SBharathi Ramana Joshi // to their final position. When moving right (i.e. dstPos > srcPos), the range
25966786a79SBharathi Ramana Joshi // of the adjacent columns is [srcPos + num, dstPos + num). When moving left
26066786a79SBharathi Ramana Joshi // (i.e. dstPos < srcPos) the range of the adjacent columns is [dstPos, srcPos).
26166786a79SBharathi Ramana Joshi // First, zeroed out columns are inserted in the final positions of the adjacent
26266786a79SBharathi Ramana Joshi // columns. Then, the adjacent columns are moved to their final positions by
26366786a79SBharathi Ramana Joshi // swapping them with the zeroed columns. Finally, the now zeroed adjacent
26466786a79SBharathi Ramana Joshi // columns are deleted.
26566786a79SBharathi Ramana Joshi template <typename T>
26666786a79SBharathi Ramana Joshi void Matrix<T>::moveColumns(unsigned srcPos, unsigned num, unsigned dstPos) {
26766786a79SBharathi Ramana Joshi   if (num == 0)
26866786a79SBharathi Ramana Joshi     return;
26966786a79SBharathi Ramana Joshi 
27066786a79SBharathi Ramana Joshi   int offset = dstPos - srcPos;
27166786a79SBharathi Ramana Joshi   if (offset == 0)
27266786a79SBharathi Ramana Joshi     return;
27366786a79SBharathi Ramana Joshi 
27466786a79SBharathi Ramana Joshi   assert(srcPos + num <= getNumColumns() &&
27566786a79SBharathi Ramana Joshi          "move source range exceeds matrix columns");
27666786a79SBharathi Ramana Joshi   assert(dstPos + num <= getNumColumns() &&
27766786a79SBharathi Ramana Joshi          "move destination range exceeds matrix columns");
27866786a79SBharathi Ramana Joshi 
27966786a79SBharathi Ramana Joshi   unsigned insertCount = offset > 0 ? offset : -offset;
28066786a79SBharathi Ramana Joshi   unsigned finalAdjStart = offset > 0 ? srcPos : srcPos + num;
28166786a79SBharathi Ramana Joshi   unsigned curAdjStart = offset > 0 ? srcPos + num : dstPos;
28266786a79SBharathi Ramana Joshi   // TODO: This can be done using std::rotate.
28366786a79SBharathi Ramana Joshi   // Insert new zero columns in the positions where the adjacent columns are to
28466786a79SBharathi Ramana Joshi   // be moved.
28566786a79SBharathi Ramana Joshi   insertColumns(finalAdjStart, insertCount);
28666786a79SBharathi Ramana Joshi   // Update curAdjStart if insertion of new columns invalidates it.
28766786a79SBharathi Ramana Joshi   if (finalAdjStart < curAdjStart)
28866786a79SBharathi Ramana Joshi     curAdjStart += insertCount;
28966786a79SBharathi Ramana Joshi 
29066786a79SBharathi Ramana Joshi   // Swap the adjacent columns with inserted zero columns.
29166786a79SBharathi Ramana Joshi   for (unsigned i = 0; i < insertCount; ++i)
29266786a79SBharathi Ramana Joshi     swapColumns(finalAdjStart + i, curAdjStart + i);
29366786a79SBharathi Ramana Joshi 
29466786a79SBharathi Ramana Joshi   // Delete the now redundant zero columns.
29566786a79SBharathi Ramana Joshi   removeColumns(curAdjStart, insertCount);
29666786a79SBharathi Ramana Joshi }
29766786a79SBharathi Ramana Joshi 
2987025ff6fSArjun P template <typename T>
2997025ff6fSArjun P void Matrix<T>::addToRow(unsigned sourceRow, unsigned targetRow,
300c1b99746SAbhinav271828                          const T &scale) {
301bb2226acSGroverkss   addToRow(targetRow, getRow(sourceRow), scale);
302bb2226acSGroverkss }
303bb2226acSGroverkss 
3047025ff6fSArjun P template <typename T>
3057025ff6fSArjun P void Matrix<T>::addToRow(unsigned row, ArrayRef<T> rowVec, const T &scale) {
30610a898b3SArjun P   if (scale == 0)
30710a898b3SArjun P     return;
30810a898b3SArjun P   for (unsigned col = 0; col < nColumns; ++col)
309bb2226acSGroverkss     at(row, col) += scale * rowVec[col];
31010a898b3SArjun P }
31110a898b3SArjun P 
3127025ff6fSArjun P template <typename T>
313562790f3SAbhinav271828 void Matrix<T>::scaleRow(unsigned row, const T &scale) {
314562790f3SAbhinav271828   for (unsigned col = 0; col < nColumns; ++col)
315562790f3SAbhinav271828     at(row, col) *= scale;
316562790f3SAbhinav271828 }
317562790f3SAbhinav271828 
318562790f3SAbhinav271828 template <typename T>
3197025ff6fSArjun P void Matrix<T>::addToColumn(unsigned sourceColumn, unsigned targetColumn,
320c1b99746SAbhinav271828                             const T &scale) {
3216ebeba88SArjun P   if (scale == 0)
3226ebeba88SArjun P     return;
3236ebeba88SArjun P   for (unsigned row = 0, e = getNumRows(); row < e; ++row)
3246ebeba88SArjun P     at(row, targetColumn) += scale * at(row, sourceColumn);
3256ebeba88SArjun P }
3266ebeba88SArjun P 
3277025ff6fSArjun P template <typename T>
3287025ff6fSArjun P void Matrix<T>::negateColumn(unsigned column) {
3296ebeba88SArjun P   for (unsigned row = 0, e = getNumRows(); row < e; ++row)
3306ebeba88SArjun P     at(row, column) = -at(row, column);
3316ebeba88SArjun P }
3326ebeba88SArjun P 
3337025ff6fSArjun P template <typename T>
3347025ff6fSArjun P void Matrix<T>::negateRow(unsigned row) {
335eff51cf9SGroverkss   for (unsigned column = 0, e = getNumColumns(); column < e; ++column)
336eff51cf9SGroverkss     at(row, column) = -at(row, column);
337eff51cf9SGroverkss }
338eff51cf9SGroverkss 
3397025ff6fSArjun P template <typename T>
340562790f3SAbhinav271828 void Matrix<T>::negateMatrix() {
341562790f3SAbhinav271828   for (unsigned row = 0; row < nRows; ++row)
342562790f3SAbhinav271828     negateRow(row);
343562790f3SAbhinav271828 }
344562790f3SAbhinav271828 
345562790f3SAbhinav271828 template <typename T>
3467025ff6fSArjun P SmallVector<T, 8> Matrix<T>::preMultiplyWithRow(ArrayRef<T> rowVec) const {
34725549414SArjun P   assert(rowVec.size() == getNumRows() && "Invalid row vector dimension!");
34825549414SArjun P 
349c1b99746SAbhinav271828   SmallVector<T, 8> result(getNumColumns(), T(0));
35025549414SArjun P   for (unsigned col = 0, e = getNumColumns(); col < e; ++col)
35125549414SArjun P     for (unsigned i = 0, e = getNumRows(); i < e; ++i)
35225549414SArjun P       result[col] += rowVec[i] * at(i, col);
35325549414SArjun P   return result;
35425549414SArjun P }
35525549414SArjun P 
3567025ff6fSArjun P template <typename T>
3577025ff6fSArjun P SmallVector<T, 8> Matrix<T>::postMultiplyWithColumn(ArrayRef<T> colVec) const {
35825549414SArjun P   assert(getNumColumns() == colVec.size() &&
35925549414SArjun P          "Invalid column vector dimension!");
36025549414SArjun P 
361c1b99746SAbhinav271828   SmallVector<T, 8> result(getNumRows(), T(0));
36225549414SArjun P   for (unsigned row = 0, e = getNumRows(); row < e; row++)
36325549414SArjun P     for (unsigned i = 0, e = getNumColumns(); i < e; i++)
36425549414SArjun P       result[row] += at(row, i) * colVec[i];
36525549414SArjun P   return result;
36625549414SArjun P }
36725549414SArjun P 
368b696e25aSGroverkss /// Set M(row, targetCol) to its remainder on division by M(row, sourceCol)
369b696e25aSGroverkss /// by subtracting from column targetCol an appropriate integer multiple of
370b696e25aSGroverkss /// sourceCol. This brings M(row, targetCol) to the range [0, M(row,
371b696e25aSGroverkss /// sourceCol)). Apply the same column operation to otherMatrix, with the same
372b696e25aSGroverkss /// integer multiple.
3731a0e67d7SRamkumar Ramachandra static void modEntryColumnOperation(Matrix<DynamicAPInt> &m, unsigned row,
3747025ff6fSArjun P                                     unsigned sourceCol, unsigned targetCol,
3751a0e67d7SRamkumar Ramachandra                                     Matrix<DynamicAPInt> &otherMatrix) {
376b696e25aSGroverkss   assert(m(row, sourceCol) != 0 && "Cannot divide by zero!");
377b696e25aSGroverkss   assert(m(row, sourceCol) > 0 && "Source must be positive!");
3781a0e67d7SRamkumar Ramachandra   DynamicAPInt ratio = -floorDiv(m(row, targetCol), m(row, sourceCol));
379b696e25aSGroverkss   m.addToColumn(sourceCol, targetCol, ratio);
380b696e25aSGroverkss   otherMatrix.addToColumn(sourceCol, targetCol, ratio);
381b696e25aSGroverkss }
382b696e25aSGroverkss 
3837025ff6fSArjun P template <typename T>
384562790f3SAbhinav271828 Matrix<T> Matrix<T>::getSubMatrix(unsigned fromRow, unsigned toRow,
385562790f3SAbhinav271828                                   unsigned fromColumn,
386562790f3SAbhinav271828                                   unsigned toColumn) const {
387562790f3SAbhinav271828   assert(fromRow <= toRow && "end of row range must be after beginning!");
388562790f3SAbhinav271828   assert(toRow < nRows && "end of row range out of bounds!");
389562790f3SAbhinav271828   assert(fromColumn <= toColumn &&
390562790f3SAbhinav271828          "end of column range must be after beginning!");
391562790f3SAbhinav271828   assert(toColumn < nColumns && "end of column range out of bounds!");
392562790f3SAbhinav271828   Matrix<T> subMatrix(toRow - fromRow + 1, toColumn - fromColumn + 1);
393562790f3SAbhinav271828   for (unsigned i = fromRow; i <= toRow; ++i)
394562790f3SAbhinav271828     for (unsigned j = fromColumn; j <= toColumn; ++j)
395562790f3SAbhinav271828       subMatrix(i - fromRow, j - fromColumn) = at(i, j);
396562790f3SAbhinav271828   return subMatrix;
397562790f3SAbhinav271828 }
398562790f3SAbhinav271828 
399562790f3SAbhinav271828 template <typename T>
4007025ff6fSArjun P void Matrix<T>::print(raw_ostream &os) const {
401*27402735SAmy Wang   PrintTableMetrics ptm = {0, 0, "-"};
402*27402735SAmy Wang   for (unsigned row = 0; row < nRows; ++row)
403c1b99746SAbhinav271828     for (unsigned column = 0; column < nColumns; ++column)
404*27402735SAmy Wang       updatePrintMetrics<T>(at(row, column), ptm);
405*27402735SAmy Wang   unsigned MIN_SPACING = 1;
406*27402735SAmy Wang   for (unsigned row = 0; row < nRows; ++row) {
407*27402735SAmy Wang     for (unsigned column = 0; column < nColumns; ++column) {
408*27402735SAmy Wang       printWithPrintMetrics<T>(os, at(row, column), MIN_SPACING, ptm);
409*27402735SAmy Wang     }
410*27402735SAmy Wang     os << "\n";
411c1b99746SAbhinav271828   }
412c1b99746SAbhinav271828 }
413c1b99746SAbhinav271828 
414562790f3SAbhinav271828 /// We iterate over the `indicator` bitset, checking each bit. If a bit is 1,
415562790f3SAbhinav271828 /// we append it to one matrix, and if it is zero, we append it to the other.
416562790f3SAbhinav271828 template <typename T>
417562790f3SAbhinav271828 std::pair<Matrix<T>, Matrix<T>>
418562790f3SAbhinav271828 Matrix<T>::splitByBitset(ArrayRef<int> indicator) {
419562790f3SAbhinav271828   Matrix<T> rowsForOne(0, nColumns), rowsForZero(0, nColumns);
420562790f3SAbhinav271828   for (unsigned i = 0; i < nRows; i++) {
421562790f3SAbhinav271828     if (indicator[i] == 1)
422562790f3SAbhinav271828       rowsForOne.appendExtraRow(getRow(i));
423562790f3SAbhinav271828     else
424562790f3SAbhinav271828       rowsForZero.appendExtraRow(getRow(i));
425562790f3SAbhinav271828   }
426562790f3SAbhinav271828   return {rowsForOne, rowsForZero};
427562790f3SAbhinav271828 }
428562790f3SAbhinav271828 
4297025ff6fSArjun P template <typename T>
4307025ff6fSArjun P void Matrix<T>::dump() const {
4317025ff6fSArjun P   print(llvm::errs());
4327025ff6fSArjun P }
433c1b99746SAbhinav271828 
4347025ff6fSArjun P template <typename T>
4357025ff6fSArjun P bool Matrix<T>::hasConsistentState() const {
436c1b99746SAbhinav271828   if (data.size() != nRows * nReservedColumns)
437c1b99746SAbhinav271828     return false;
438c1b99746SAbhinav271828   if (nColumns > nReservedColumns)
439c1b99746SAbhinav271828     return false;
440c1b99746SAbhinav271828 #ifdef EXPENSIVE_CHECKS
441c1b99746SAbhinav271828   for (unsigned r = 0; r < nRows; ++r)
442c1b99746SAbhinav271828     for (unsigned c = nColumns; c < nReservedColumns; ++c)
443c1b99746SAbhinav271828       if (data[r * nReservedColumns + c] != 0)
444c1b99746SAbhinav271828         return false;
445c1b99746SAbhinav271828 #endif
446c1b99746SAbhinav271828   return true;
447c1b99746SAbhinav271828 }
448c1b99746SAbhinav271828 
449c1b99746SAbhinav271828 namespace mlir {
450c1b99746SAbhinav271828 namespace presburger {
4511a0e67d7SRamkumar Ramachandra template class Matrix<DynamicAPInt>;
452c1b99746SAbhinav271828 template class Matrix<Fraction>;
4537025ff6fSArjun P } // namespace presburger
4547025ff6fSArjun P } // namespace mlir
455c1b99746SAbhinav271828 
456c1b99746SAbhinav271828 IntMatrix IntMatrix::identity(unsigned dimension) {
457c1b99746SAbhinav271828   IntMatrix matrix(dimension, dimension);
458c1b99746SAbhinav271828   for (unsigned i = 0; i < dimension; ++i)
459c1b99746SAbhinav271828     matrix(i, i) = 1;
460c1b99746SAbhinav271828   return matrix;
461c1b99746SAbhinav271828 }
462c1b99746SAbhinav271828 
463c1b99746SAbhinav271828 std::pair<IntMatrix, IntMatrix> IntMatrix::computeHermiteNormalForm() const {
464b696e25aSGroverkss   // We start with u as an identity matrix and perform operations on h until h
465b696e25aSGroverkss   // is in hermite normal form. We apply the same sequence of operations on u to
466b696e25aSGroverkss   // obtain a transform that takes h to hermite normal form.
467c1b99746SAbhinav271828   IntMatrix h = *this;
468c1b99746SAbhinav271828   IntMatrix u = IntMatrix::identity(h.getNumColumns());
469b696e25aSGroverkss 
470b696e25aSGroverkss   unsigned echelonCol = 0;
471b696e25aSGroverkss   // Invariant: in all rows above row, all columns from echelonCol onwards
472b696e25aSGroverkss   // are all zero elements. In an iteration, if the curent row has any non-zero
473b696e25aSGroverkss   // elements echelonCol onwards, we bring one to echelonCol and use it to
474b696e25aSGroverkss   // make all elements echelonCol + 1 onwards zero.
475b696e25aSGroverkss   for (unsigned row = 0; row < h.getNumRows(); ++row) {
476b696e25aSGroverkss     // Search row for a non-empty entry, starting at echelonCol.
477b696e25aSGroverkss     unsigned nonZeroCol = echelonCol;
478b696e25aSGroverkss     for (unsigned e = h.getNumColumns(); nonZeroCol < e; ++nonZeroCol) {
479b696e25aSGroverkss       if (h(row, nonZeroCol) == 0)
480b696e25aSGroverkss         continue;
481b696e25aSGroverkss       break;
482b696e25aSGroverkss     }
483b696e25aSGroverkss 
484b696e25aSGroverkss     // Continue to the next row with the same echelonCol if this row is all
485b696e25aSGroverkss     // zeros from echelonCol onwards.
486b696e25aSGroverkss     if (nonZeroCol == h.getNumColumns())
487b696e25aSGroverkss       continue;
488b696e25aSGroverkss 
489b696e25aSGroverkss     // Bring the non-zero column to echelonCol. This doesn't affect rows
490b696e25aSGroverkss     // above since they are all zero at these columns.
491b696e25aSGroverkss     if (nonZeroCol != echelonCol) {
492b696e25aSGroverkss       h.swapColumns(nonZeroCol, echelonCol);
493b696e25aSGroverkss       u.swapColumns(nonZeroCol, echelonCol);
494b696e25aSGroverkss     }
495b696e25aSGroverkss 
496b696e25aSGroverkss     // Make h(row, echelonCol) non-negative.
497b696e25aSGroverkss     if (h(row, echelonCol) < 0) {
498b696e25aSGroverkss       h.negateColumn(echelonCol);
499b696e25aSGroverkss       u.negateColumn(echelonCol);
500b696e25aSGroverkss     }
501b696e25aSGroverkss 
502b696e25aSGroverkss     // Make all the entries in row after echelonCol zero.
503b696e25aSGroverkss     for (unsigned i = echelonCol + 1, e = h.getNumColumns(); i < e; ++i) {
504b696e25aSGroverkss       // We make h(row, i) non-negative, and then apply the Euclidean GCD
505b696e25aSGroverkss       // algorithm to (row, i) and (row, echelonCol). At the end, one of them
506b696e25aSGroverkss       // has value equal to the gcd of the two entries, and the other is zero.
507b696e25aSGroverkss 
508b696e25aSGroverkss       if (h(row, i) < 0) {
509b696e25aSGroverkss         h.negateColumn(i);
510b696e25aSGroverkss         u.negateColumn(i);
511b696e25aSGroverkss       }
512b696e25aSGroverkss 
513b696e25aSGroverkss       unsigned targetCol = i, sourceCol = echelonCol;
514b696e25aSGroverkss       // At every step, we set h(row, targetCol) %= h(row, sourceCol), and
515b696e25aSGroverkss       // swap the indices sourceCol and targetCol. (not the columns themselves)
516b696e25aSGroverkss       // This modulo is implemented as a subtraction
517b696e25aSGroverkss       // h(row, targetCol) -= quotient * h(row, sourceCol),
518b696e25aSGroverkss       // where quotient = floor(h(row, targetCol) / h(row, sourceCol)),
519b696e25aSGroverkss       // which brings h(row, targetCol) to the range [0, h(row, sourceCol)).
520b696e25aSGroverkss       //
521b696e25aSGroverkss       // We are only allowed column operations; we perform the above
522b696e25aSGroverkss       // for every row, i.e., the above subtraction is done as a column
523b696e25aSGroverkss       // operation. This does not affect any rows above us since they are
524b696e25aSGroverkss       // guaranteed to be zero at these columns.
525b696e25aSGroverkss       while (h(row, targetCol) != 0 && h(row, sourceCol) != 0) {
526b696e25aSGroverkss         modEntryColumnOperation(h, row, sourceCol, targetCol, u);
527b696e25aSGroverkss         std::swap(targetCol, sourceCol);
528b696e25aSGroverkss       }
529b696e25aSGroverkss 
530b696e25aSGroverkss       // One of (row, echelonCol) and (row, i) is zero and the other is the gcd.
531b696e25aSGroverkss       // Make it so that (row, echelonCol) holds the non-zero value.
532b696e25aSGroverkss       if (h(row, echelonCol) == 0) {
533b696e25aSGroverkss         h.swapColumns(i, echelonCol);
534b696e25aSGroverkss         u.swapColumns(i, echelonCol);
535b696e25aSGroverkss       }
536b696e25aSGroverkss     }
537b696e25aSGroverkss 
538b696e25aSGroverkss     // Make all entries before echelonCol non-negative and strictly smaller
539b696e25aSGroverkss     // than the pivot entry.
540b696e25aSGroverkss     for (unsigned i = 0; i < echelonCol; ++i)
541b696e25aSGroverkss       modEntryColumnOperation(h, row, echelonCol, i, u);
542b696e25aSGroverkss 
543b696e25aSGroverkss     ++echelonCol;
544b696e25aSGroverkss   }
545b696e25aSGroverkss 
546b696e25aSGroverkss   return {h, u};
547b696e25aSGroverkss }
548b696e25aSGroverkss 
5491a0e67d7SRamkumar Ramachandra DynamicAPInt IntMatrix::normalizeRow(unsigned row, unsigned cols) {
550c1b99746SAbhinav271828   return normalizeRange(getRow(row).slice(0, cols));
55110a898b3SArjun P }
55210a898b3SArjun P 
5531a0e67d7SRamkumar Ramachandra DynamicAPInt IntMatrix::normalizeRow(unsigned row) {
554c1b99746SAbhinav271828   return normalizeRow(row, getNumColumns());
555c605dfcfSArjun P }
556f08fe1f1SAbhinav271828 
5571a0e67d7SRamkumar Ramachandra DynamicAPInt IntMatrix::determinant(IntMatrix *inverse) const {
558f08fe1f1SAbhinav271828   assert(nRows == nColumns &&
559f08fe1f1SAbhinav271828          "determinant can only be calculated for square matrices!");
560f08fe1f1SAbhinav271828 
561f08fe1f1SAbhinav271828   FracMatrix m(*this);
562f08fe1f1SAbhinav271828 
563f08fe1f1SAbhinav271828   FracMatrix fracInverse(nRows, nColumns);
5641a0e67d7SRamkumar Ramachandra   DynamicAPInt detM = m.determinant(&fracInverse).getAsInteger();
565f08fe1f1SAbhinav271828 
566f08fe1f1SAbhinav271828   if (detM == 0)
5671a0e67d7SRamkumar Ramachandra     return DynamicAPInt(0);
568f08fe1f1SAbhinav271828 
569e213af78SAbhinav271828   if (!inverse)
570e213af78SAbhinav271828     return detM;
571e213af78SAbhinav271828 
572f08fe1f1SAbhinav271828   *inverse = IntMatrix(nRows, nColumns);
573f08fe1f1SAbhinav271828   for (unsigned i = 0; i < nRows; i++)
574f08fe1f1SAbhinav271828     for (unsigned j = 0; j < nColumns; j++)
575f08fe1f1SAbhinav271828       inverse->at(i, j) = (fracInverse.at(i, j) * detM).getAsInteger();
576f08fe1f1SAbhinav271828 
577f08fe1f1SAbhinav271828   return detM;
578f08fe1f1SAbhinav271828 }
579f08fe1f1SAbhinav271828 
580f08fe1f1SAbhinav271828 FracMatrix FracMatrix::identity(unsigned dimension) {
581f08fe1f1SAbhinav271828   return Matrix::identity(dimension);
582f08fe1f1SAbhinav271828 }
583f08fe1f1SAbhinav271828 
584f08fe1f1SAbhinav271828 FracMatrix::FracMatrix(IntMatrix m)
585f08fe1f1SAbhinav271828     : FracMatrix(m.getNumRows(), m.getNumColumns()) {
58684ab06baSAbhinav271828   for (unsigned i = 0, r = m.getNumRows(); i < r; i++)
58784ab06baSAbhinav271828     for (unsigned j = 0, c = m.getNumColumns(); j < c; j++)
588f08fe1f1SAbhinav271828       this->at(i, j) = m.at(i, j);
589f08fe1f1SAbhinav271828 }
590f08fe1f1SAbhinav271828 
591f08fe1f1SAbhinav271828 Fraction FracMatrix::determinant(FracMatrix *inverse) const {
592f08fe1f1SAbhinav271828   assert(nRows == nColumns &&
593f08fe1f1SAbhinav271828          "determinant can only be calculated for square matrices!");
594f08fe1f1SAbhinav271828 
595f08fe1f1SAbhinav271828   FracMatrix m(*this);
596f08fe1f1SAbhinav271828   FracMatrix tempInv(nRows, nColumns);
597f08fe1f1SAbhinav271828   if (inverse)
598f08fe1f1SAbhinav271828     tempInv = FracMatrix::identity(nRows);
599f08fe1f1SAbhinav271828 
600f08fe1f1SAbhinav271828   Fraction a, b;
601f08fe1f1SAbhinav271828   // Make the matrix into upper triangular form using
602f08fe1f1SAbhinav271828   // gaussian elimination with row operations.
603f08fe1f1SAbhinav271828   // If inverse is required, we apply more operations
604f08fe1f1SAbhinav271828   // to turn the matrix into diagonal form. We apply
605f08fe1f1SAbhinav271828   // the same operations to the inverse matrix,
606f08fe1f1SAbhinav271828   // which is initially identity.
607f08fe1f1SAbhinav271828   // Either way, the product of the diagonal elements
608f08fe1f1SAbhinav271828   // is then the determinant.
609f08fe1f1SAbhinav271828   for (unsigned i = 0; i < nRows; i++) {
610f08fe1f1SAbhinav271828     if (m(i, i) == 0)
611f08fe1f1SAbhinav271828       // First ensure that the diagonal
612f08fe1f1SAbhinav271828       // element is nonzero, by swapping
613f08fe1f1SAbhinav271828       // it with a nonzero row.
614f08fe1f1SAbhinav271828       for (unsigned j = i + 1; j < nRows; j++) {
615f08fe1f1SAbhinav271828         if (m(j, i) != 0) {
616f08fe1f1SAbhinav271828           m.swapRows(j, i);
617f08fe1f1SAbhinav271828           if (inverse)
618f08fe1f1SAbhinav271828             tempInv.swapRows(j, i);
619f08fe1f1SAbhinav271828           break;
620f08fe1f1SAbhinav271828         }
621f08fe1f1SAbhinav271828       }
622f08fe1f1SAbhinav271828 
623f08fe1f1SAbhinav271828     b = m.at(i, i);
624f08fe1f1SAbhinav271828     if (b == 0)
625f08fe1f1SAbhinav271828       return 0;
626f08fe1f1SAbhinav271828 
627f08fe1f1SAbhinav271828     // Set all elements above the
628f08fe1f1SAbhinav271828     // diagonal to zero.
629f08fe1f1SAbhinav271828     if (inverse) {
630f08fe1f1SAbhinav271828       for (unsigned j = 0; j < i; j++) {
631f08fe1f1SAbhinav271828         if (m.at(j, i) == 0)
632f08fe1f1SAbhinav271828           continue;
633f08fe1f1SAbhinav271828         a = m.at(j, i);
634f08fe1f1SAbhinav271828         // Set element (j, i) to zero
635f08fe1f1SAbhinav271828         // by subtracting the ith row,
636f08fe1f1SAbhinav271828         // appropriately scaled.
637f08fe1f1SAbhinav271828         m.addToRow(i, j, -a / b);
638f08fe1f1SAbhinav271828         tempInv.addToRow(i, j, -a / b);
639f08fe1f1SAbhinav271828       }
640f08fe1f1SAbhinav271828     }
641f08fe1f1SAbhinav271828 
642f08fe1f1SAbhinav271828     // Set all elements below the
643f08fe1f1SAbhinav271828     // diagonal to zero.
644f08fe1f1SAbhinav271828     for (unsigned j = i + 1; j < nRows; j++) {
645f08fe1f1SAbhinav271828       if (m.at(j, i) == 0)
646f08fe1f1SAbhinav271828         continue;
647f08fe1f1SAbhinav271828       a = m.at(j, i);
648f08fe1f1SAbhinav271828       // Set element (j, i) to zero
649f08fe1f1SAbhinav271828       // by subtracting the ith row,
650f08fe1f1SAbhinav271828       // appropriately scaled.
651f08fe1f1SAbhinav271828       m.addToRow(i, j, -a / b);
652f08fe1f1SAbhinav271828       if (inverse)
653f08fe1f1SAbhinav271828         tempInv.addToRow(i, j, -a / b);
654f08fe1f1SAbhinav271828     }
655f08fe1f1SAbhinav271828   }
656f08fe1f1SAbhinav271828 
657f08fe1f1SAbhinav271828   // Now only diagonal elements of m are nonzero, but they are
658f08fe1f1SAbhinav271828   // not necessarily 1. To get the true inverse, we should
659f08fe1f1SAbhinav271828   // normalize them and apply the same scale to the inverse matrix.
660f08fe1f1SAbhinav271828   // For efficiency we skip scaling m and just scale tempInv appropriately.
661f08fe1f1SAbhinav271828   if (inverse) {
662f08fe1f1SAbhinav271828     for (unsigned i = 0; i < nRows; i++)
663f08fe1f1SAbhinav271828       for (unsigned j = 0; j < nRows; j++)
664f08fe1f1SAbhinav271828         tempInv.at(i, j) = tempInv.at(i, j) / m(i, i);
665f08fe1f1SAbhinav271828 
666f08fe1f1SAbhinav271828     *inverse = std::move(tempInv);
667f08fe1f1SAbhinav271828   }
668f08fe1f1SAbhinav271828 
669f08fe1f1SAbhinav271828   Fraction determinant = 1;
670f08fe1f1SAbhinav271828   for (unsigned i = 0; i < nRows; i++)
671f08fe1f1SAbhinav271828     determinant *= m.at(i, i);
672f08fe1f1SAbhinav271828 
673f08fe1f1SAbhinav271828   return determinant;
674f08fe1f1SAbhinav271828 }
67584ab06baSAbhinav271828 
67684ab06baSAbhinav271828 FracMatrix FracMatrix::gramSchmidt() const {
67784ab06baSAbhinav271828   // Create a copy of the argument to store
67884ab06baSAbhinav271828   // the orthogonalised version.
67984ab06baSAbhinav271828   FracMatrix orth(*this);
68084ab06baSAbhinav271828 
68184ab06baSAbhinav271828   // For each vector (row) in the matrix, subtract its unit
68284ab06baSAbhinav271828   // projection along each of the previous vectors.
68384ab06baSAbhinav271828   // This ensures that it has no component in the direction
68484ab06baSAbhinav271828   // of any of the previous vectors.
68584ab06baSAbhinav271828   for (unsigned i = 1, e = getNumRows(); i < e; i++) {
68684ab06baSAbhinav271828     for (unsigned j = 0; j < i; j++) {
68784ab06baSAbhinav271828       Fraction jNormSquared = dotProduct(orth.getRow(j), orth.getRow(j));
68884ab06baSAbhinav271828       assert(jNormSquared != 0 && "some row became zero! Inputs to this "
68984ab06baSAbhinav271828                                   "function must be linearly independent.");
69084ab06baSAbhinav271828       Fraction projectionScale =
69184ab06baSAbhinav271828           dotProduct(orth.getRow(i), orth.getRow(j)) / jNormSquared;
69284ab06baSAbhinav271828       orth.addToRow(j, i, -projectionScale);
69384ab06baSAbhinav271828     }
69484ab06baSAbhinav271828   }
69584ab06baSAbhinav271828   return orth;
69684ab06baSAbhinav271828 }
697cfd51fbaSAbhinav271828 
698cfd51fbaSAbhinav271828 // Convert the matrix, interpreted (row-wise) as a basis
699cfd51fbaSAbhinav271828 // to an LLL-reduced basis.
700cfd51fbaSAbhinav271828 //
701cfd51fbaSAbhinav271828 // This is an implementation of the algorithm described in
702cfd51fbaSAbhinav271828 // "Factoring polynomials with rational coefficients" by
703cfd51fbaSAbhinav271828 // A. K. Lenstra, H. W. Lenstra Jr., L. Lovasz.
704cfd51fbaSAbhinav271828 //
705cfd51fbaSAbhinav271828 // Let {b_1,  ..., b_n}  be the current basis and
706cfd51fbaSAbhinav271828 //     {b_1*, ..., b_n*} be the Gram-Schmidt orthogonalised
707cfd51fbaSAbhinav271828 //                          basis (unnormalized).
708cfd51fbaSAbhinav271828 // Define the Gram-Schmidt coefficients μ_ij as
709cfd51fbaSAbhinav271828 // (b_i • b_j*) / (b_j* • b_j*), where (•) represents the inner product.
710cfd51fbaSAbhinav271828 //
711cfd51fbaSAbhinav271828 // We iterate starting from the second row to the last row.
712cfd51fbaSAbhinav271828 //
713cfd51fbaSAbhinav271828 // For the kth row, we first check μ_kj for all rows j < k.
714cfd51fbaSAbhinav271828 // We subtract b_j (scaled by the integer nearest to μ_kj)
715cfd51fbaSAbhinav271828 // from b_k.
716cfd51fbaSAbhinav271828 //
717cfd51fbaSAbhinav271828 // Now, we update k.
718cfd51fbaSAbhinav271828 // If b_k and b_{k-1} satisfy the Lovasz condition
719cfd51fbaSAbhinav271828 //    |b_k|^2 ≥ (δ - μ_k{k-1}^2) |b_{k-1}|^2,
720cfd51fbaSAbhinav271828 // we are done and we increment k.
721cfd51fbaSAbhinav271828 // Otherwise, we swap b_k and b_{k-1} and decrement k.
722cfd51fbaSAbhinav271828 //
723cfd51fbaSAbhinav271828 // We repeat this until k = n and return.
724cfd51fbaSAbhinav271828 void FracMatrix::LLL(Fraction delta) {
7251a0e67d7SRamkumar Ramachandra   DynamicAPInt nearest;
726cfd51fbaSAbhinav271828   Fraction mu;
727cfd51fbaSAbhinav271828 
728cfd51fbaSAbhinav271828   // `gsOrth` holds the Gram-Schmidt orthogonalisation
729cfd51fbaSAbhinav271828   // of the matrix at all times. It is recomputed every
730cfd51fbaSAbhinav271828   // time the matrix is modified during the algorithm.
731cfd51fbaSAbhinav271828   // This is naive and can be optimised.
732cfd51fbaSAbhinav271828   FracMatrix gsOrth = gramSchmidt();
733cfd51fbaSAbhinav271828 
734cfd51fbaSAbhinav271828   // We start from the second row.
735cfd51fbaSAbhinav271828   unsigned k = 1;
736cfd51fbaSAbhinav271828   while (k < getNumRows()) {
737cfd51fbaSAbhinav271828     for (unsigned j = k - 1; j < k; j--) {
738cfd51fbaSAbhinav271828       // Compute the Gram-Schmidt coefficient μ_jk.
739cfd51fbaSAbhinav271828       mu = dotProduct(getRow(k), gsOrth.getRow(j)) /
740cfd51fbaSAbhinav271828            dotProduct(gsOrth.getRow(j), gsOrth.getRow(j));
741cfd51fbaSAbhinav271828       nearest = round(mu);
742cfd51fbaSAbhinav271828       // Subtract b_j scaled by the integer nearest to μ_jk from b_k.
743cfd51fbaSAbhinav271828       addToRow(k, getRow(j), -Fraction(nearest, 1));
744cfd51fbaSAbhinav271828       gsOrth = gramSchmidt(); // Update orthogonalization.
745cfd51fbaSAbhinav271828     }
746cfd51fbaSAbhinav271828     mu = dotProduct(getRow(k), gsOrth.getRow(k - 1)) /
747cfd51fbaSAbhinav271828          dotProduct(gsOrth.getRow(k - 1), gsOrth.getRow(k - 1));
748cfd51fbaSAbhinav271828     // Check the Lovasz condition for b_k and b_{k-1}.
749cfd51fbaSAbhinav271828     if (dotProduct(gsOrth.getRow(k), gsOrth.getRow(k)) >
750cfd51fbaSAbhinav271828         (delta - mu * mu) *
751cfd51fbaSAbhinav271828             dotProduct(gsOrth.getRow(k - 1), gsOrth.getRow(k - 1))) {
752cfd51fbaSAbhinav271828       // If it is satisfied, proceed to the next k.
753cfd51fbaSAbhinav271828       k += 1;
754cfd51fbaSAbhinav271828     } else {
755cfd51fbaSAbhinav271828       // If it is not satisfied, decrement k (without
756cfd51fbaSAbhinav271828       // going beyond the second row).
757cfd51fbaSAbhinav271828       swapRows(k, k - 1);
758cfd51fbaSAbhinav271828       gsOrth = gramSchmidt(); // Update orthogonalization.
759cfd51fbaSAbhinav271828       k = k > 1 ? k - 1 : 1;
760cfd51fbaSAbhinav271828     }
761cfd51fbaSAbhinav271828   }
762cfd51fbaSAbhinav271828 }
763562790f3SAbhinav271828 
764562790f3SAbhinav271828 IntMatrix FracMatrix::normalizeRows() const {
765562790f3SAbhinav271828   unsigned numRows = getNumRows();
766562790f3SAbhinav271828   unsigned numColumns = getNumColumns();
767562790f3SAbhinav271828   IntMatrix normalized(numRows, numColumns);
768562790f3SAbhinav271828 
7691a0e67d7SRamkumar Ramachandra   DynamicAPInt lcmDenoms = DynamicAPInt(1);
770562790f3SAbhinav271828   for (unsigned i = 0; i < numRows; i++) {
771562790f3SAbhinav271828     // For a row, first compute the LCM of the denominators.
772562790f3SAbhinav271828     for (unsigned j = 0; j < numColumns; j++)
773562790f3SAbhinav271828       lcmDenoms = lcm(lcmDenoms, at(i, j).den);
774562790f3SAbhinav271828     // Then, multiply by it throughout and convert to integers.
775562790f3SAbhinav271828     for (unsigned j = 0; j < numColumns; j++)
776562790f3SAbhinav271828       normalized(i, j) = (at(i, j) * lcmDenoms).getAsInteger();
777562790f3SAbhinav271828   }
778562790f3SAbhinav271828   return normalized;
779562790f3SAbhinav271828 }
780