xref: /llvm-project/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp (revision 70e227a404e51f9248c7ad5d79953805b2afacb4)
1 //===- MergerTest.cpp - Tests for the sparsifier's merger -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
10 #include "llvm/Support/Compiler.h"
11 #include "gmock/gmock.h"
12 #include "gtest/gtest.h"
13 
14 #include <memory>
15 
16 using namespace mlir;
17 using namespace mlir::sparse_tensor;
18 
19 namespace {
20 
///
/// Macros that iterate over the binary operations and over pairs of binary
/// operations, used to stamp out helpers and parameterized tests.
///

/// Invokes `APPLY(name, kind)` once per supported binary operation, where
/// `name` is the short identifier used to build helper names (e.g. `addf`)
/// and `kind` is the matching `TensorExp::Kind` enumerator.
#define FOREVERY_BINOP(APPLY)                                                  \
  APPLY(mulf, TensorExp::Kind::kMulF)                                          \
  APPLY(mulc, TensorExp::Kind::kMulC)                                          \
  APPLY(muli, TensorExp::Kind::kMulI)                                          \
  APPLY(addf, TensorExp::Kind::kAddF)                                          \
  APPLY(addc, TensorExp::Kind::kAddC)                                          \
  APPLY(addi, TensorExp::Kind::kAddI)                                          \
  APPLY(subf, TensorExp::Kind::kSubF)                                          \
  APPLY(subc, TensorExp::Kind::kSubC)                                          \
  APPLY(subi, TensorExp::Kind::kSubI)                                          \
  APPLY(andi, TensorExp::Kind::kAndI)                                          \
  APPLY(xori, TensorExp::Kind::kXorI)                                          \
  APPLY(ori, TensorExp::Kind::kOrI)                                            \
  APPLY(cmpf, TensorExp::Kind::kCmpF)                                          \
  APPLY(cmpi, TensorExp::Kind::kCmpI)

/// Invokes `FN(op, ARG)` for every common disjunctive (union-like) binary op.
#define FOREVERY_COMMON_DISJ_BINOP_EXTRA(FN, ARG)                              \
  FN(addf, ARG)                                                                \
  FN(addc, ARG)                                                                \
  FN(addi, ARG)                                                                \
  FN(xori, ARG)                                                                \
  FN(ori, ARG)

/// Invokes `FN(op, ARG)` for every common conjunctive (intersection-like)
/// binary op.
#define FOREVERY_COMMON_CONJ_BINOP_EXTRA(FN, ARG)                              \
  FN(mulf, ARG)                                                                \
  FN(mulc, ARG)                                                                \
  FN(muli, ARG)                                                                \
  FN(andi, ARG)

/// Single-op variants; the extra argument is an unused empty string.
#define FOREVERY_COMMON_DISJ_BINOP(FN)                                         \
  FOREVERY_COMMON_DISJ_BINOP_EXTRA(FN, "")

#define FOREVERY_COMMON_CONJ_BINOP(FN)                                         \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(FN, "")

/// Invokes `FN(conj, disj)` for every (conjunctive, disjunctive) pair.
#define FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(FN)                            \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(FN, addf)                                   \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(FN, addc)                                   \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(FN, addi)                                   \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(FN, xori)                                   \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(FN, ori)

/// Invokes `FN(conj1, conj2)` for every pair of conjunctive ops.
#define FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(FN)                            \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(FN, mulf)                                   \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(FN, mulc)                                   \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(FN, muli)                                   \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(FN, andi)

/// Invokes `FN(disj1, disj2)` for every pair of disjunctive ops.
#define FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(FN)                            \
  FOREVERY_COMMON_DISJ_BINOP_EXTRA(FN, addf)                                   \
  FOREVERY_COMMON_DISJ_BINOP_EXTRA(FN, addc)                                   \
  FOREVERY_COMMON_DISJ_BINOP_EXTRA(FN, addi)                                   \
  FOREVERY_COMMON_DISJ_BINOP_EXTRA(FN, ori)                                    \
  FOREVERY_COMMON_DISJ_BINOP_EXTRA(FN, xori)
79 
80 ///
81 /// Helper classes/functions for testing Merger.
82 ///
83 
84 /// Simple recursive data structure used to match expressions in `Merger`,
85 /// which uses const references into the short-lived data strucutures.
86 struct Match {
87   struct Children {
Children__anon93a878f70111::Match::Children88     Children(const Match &e0, const Match &e1) : e0(e0), e1(e1) {}
89     const Match &e0;
90     const Match &e1;
91   };
92 
Match__anon93a878f70111::Match93   Match() : kind(TensorExp::Kind::kSynZero) {}
Match__anon93a878f70111::Match94   Match(TensorId tid) : kind(TensorExp::Kind::kTensor), tid(tid) {}
Match__anon93a878f70111::Match95   Match(TensorExp::Kind kind, const Match &e0, const Match &e1)
96       : kind(kind), children(e0, e1) {
97     assert(kind >= TensorExp::Kind::kMulF);
98   }
99 
100   TensorExp::Kind kind;
101   union {
102     TensorId tid;
103     Children children;
104   };
105 };
106 
107 ///
108 /// Readable Match builder functions.
109 /// These should be preferred over the actual constructors.
110 ///
111 
tensorMatch(TensorId tid)112 static Match tensorMatch(TensorId tid) { return Match(tid); }
synZeroMatch()113 static Match synZeroMatch() { return Match(); }
114 
115 #define IMPL_BINOP_PATTERN(OP, KIND)                                           \
116   LLVM_ATTRIBUTE_UNUSED static Match OP##Match(const Match &e0,                \
117                                                const Match &e1) {              \
118     return Match(KIND, e0, e1);                                                \
119   }
120 FOREVERY_BINOP(IMPL_BINOP_PATTERN)
121 #undef IMPL_BINOP_PATTERN
122 
123 // Parameterize LevelFormat to test both Dense and Batch LevelFormat.
124 class MergerTestBase : public ::testing::TestWithParam<LevelFormat> {
125 protected:
MergerTestBase(unsigned numTensors,unsigned numLoops)126   MergerTestBase(unsigned numTensors, unsigned numLoops)
127       : merger(numTensors, numLoops, /*maxRank=*/numLoops) {
128     tensors.reserve(numTensors);
129     for (unsigned t = 0; t < numTensors; t++)
130       tensors.push_back(merger.addTensorExp(tid(t)));
131   }
132 
133   ///
134   /// Expression construction helpers.
135   ///
136 
tid(unsigned t) const137   TensorId tid(unsigned t) const { return merger.makeTensorId(t); }
lid(unsigned i) const138   LoopId lid(unsigned i) const { return merger.makeLoopId(i); }
tensor(unsigned t) const139   ExprId tensor(unsigned t) const {
140     assert(t < tensors.size());
141     return tensors[t];
142   }
143 
144 #define IMPL_BINOP_EXPR(OP, KIND)                                              \
145   LLVM_ATTRIBUTE_UNUSED ExprId OP##Expr(ExprId e0, ExprId e1) {                \
146     return merger.addExp(KIND, e0, e1);                                        \
147   }
FOREVERY_BINOP(IMPL_BINOP_EXPR)148   FOREVERY_BINOP(IMPL_BINOP_EXPR)
149 #undef IMPL_BINOP_EXPR
150 
151   ///
152   /// Comparison helpers.
153   ///
154 
155   /// Returns true if any lattice point with an expression matching
156   /// the given `pattern` and bits matching the given `bits` is present
157   /// in the `[lo, lo+n)` slice of the lattice set `s`.  This is useful
158   /// for testing partial ordering constraints between lattice points.
159   /// We generally know how contiguous groups of lattice points should
160   /// be ordered with respect to other groups, but there is no required
161   /// ordering within groups.  If `simple` is true, then compare the
162   /// `lat.simple` field instead to test the result after optimization.
163   bool latPointWithinRange(LatSetId s, unsigned lo, unsigned n,
164                            const Match &pattern, const BitVector &bits,
165                            bool simple) {
166     for (unsigned k = lo, hi = lo + n; k < hi; ++k) {
167       if (compareExpression(merger.lat(merger.set(s)[k]).exp, pattern) &&
168           compareBits(s, k, bits, simple))
169         return true;
170     }
171     return false;
172   }
173 
174   /// Wrapper over latPointWithinRange for readability of tests.
expectLatPointWithinRange(LatSetId s,unsigned lo,unsigned n,const Match & pattern,const BitVector & bits,bool simple=false)175   void expectLatPointWithinRange(LatSetId s, unsigned lo, unsigned n,
176                                  const Match &pattern, const BitVector &bits,
177                                  bool simple = false) {
178     EXPECT_TRUE(latPointWithinRange(s, lo, n, pattern, bits, simple));
179   }
180 
181   /// Wrapper over expectLatPointWithinRange for a single lat point.
expectLatPoint(LatSetId s,unsigned lo,const Match & pattern,const BitVector & bits,bool simple=false)182   void expectLatPoint(LatSetId s, unsigned lo, const Match &pattern,
183                       const BitVector &bits, bool simple = false) {
184     EXPECT_TRUE(latPointWithinRange(s, lo, 1, pattern, bits, simple));
185   }
186 
187   /// Converts a vector of (loop, tensor) pairs to a bitvector with the
188   /// corresponding bits set.
loopsToBits(const std::vector<std::pair<LoopId,TensorId>> & loops)189   BitVector loopsToBits(const std::vector<std::pair<LoopId, TensorId>> &loops) {
190     BitVector testBits = BitVector(merger.getNumTensors(), false);
191     for (auto [loop, tensor] : loops)
192       testBits.set(merger.makeTensorLoopId(tensor, loop));
193     return testBits;
194   }
195 
196   /// Returns true if the bits of the `k`th point in set `s` matches
197   /// the given `bits`.  If `simple` is true, then compares the `lat.simple`
198   /// field instead, to test the result after optimization
compareBits(LatSetId s,unsigned k,const BitVector & bits,bool simple)199   bool compareBits(LatSetId s, unsigned k, const BitVector &bits, bool simple) {
200     const auto &point = merger.lat(merger.set(s)[k]);
201     return (simple ? point.simple : point.bits) == bits;
202   }
203 
204   /// Check that there are n lattice points in set s.
expectNumLatPoints(LatSetId s,unsigned n)205   void expectNumLatPoints(LatSetId s, unsigned n) {
206     EXPECT_THAT(merger.set(s).size(), n);
207   }
208 
209   /// Compares expressions for equality. Equality is defined recursively as:
210   /// - Operations are equal if they have the same kind and children.
211   /// - Leaf tensors are equal if they refer to the same tensor.
compareExpression(ExprId e,const Match & pattern)212   bool compareExpression(ExprId e, const Match &pattern) {
213     const auto &tensorExp = merger.exp(e);
214     if (tensorExp.kind != pattern.kind)
215       return false;
216     switch (tensorExp.kind) {
217     // Leaf.
218     case TensorExp::Kind::kTensor:
219       return tensorExp.tensor == pattern.tid;
220     case TensorExp::Kind::kSynZero:
221       // Already checked kind equivalence @L233
222       return true;
223     case TensorExp::Kind::kInvariant:
224       llvm_unreachable("invariant not handled yet");
225     case TensorExp::Kind::kLoopVar:
226       llvm_unreachable("loop-variables not handled yet");
227     // Unary operations.
228     case TensorExp::Kind::kAbsF:
229     case TensorExp::Kind::kAbsC:
230     case TensorExp::Kind::kAbsI:
231     case TensorExp::Kind::kCeilF:
232     case TensorExp::Kind::kFloorF:
233     case TensorExp::Kind::kSqrtF:
234     case TensorExp::Kind::kSqrtC:
235     case TensorExp::Kind::kExpm1F:
236     case TensorExp::Kind::kExpm1C:
237     case TensorExp::Kind::kLog1pF:
238     case TensorExp::Kind::kLog1pC:
239     case TensorExp::Kind::kRelu:
240     case TensorExp::Kind::kSinF:
241     case TensorExp::Kind::kSinC:
242     case TensorExp::Kind::kTanhF:
243     case TensorExp::Kind::kTanhC:
244     case TensorExp::Kind::kNegF:
245     case TensorExp::Kind::kNegC:
246     case TensorExp::Kind::kNegI:
247     case TensorExp::Kind::kTruncF:
248     case TensorExp::Kind::kExtF:
249     case TensorExp::Kind::kCastFS:
250     case TensorExp::Kind::kCastFU:
251     case TensorExp::Kind::kCastSF:
252     case TensorExp::Kind::kCastUF:
253     case TensorExp::Kind::kCastS:
254     case TensorExp::Kind::kCastU:
255     case TensorExp::Kind::kCastIdx:
256     case TensorExp::Kind::kTruncI:
257     case TensorExp::Kind::kCIm:
258     case TensorExp::Kind::kCRe:
259     case TensorExp::Kind::kBitCast:
260     case TensorExp::Kind::kSelect:
261     case TensorExp::Kind::kBinaryBranch:
262     case TensorExp::Kind::kUnary:
263       return compareExpression(tensorExp.children.e0, pattern.children.e0);
264     // Binary operations.
265     case TensorExp::Kind::kMulF:
266     case TensorExp::Kind::kMulC:
267     case TensorExp::Kind::kMulI:
268     case TensorExp::Kind::kDivF:
269     case TensorExp::Kind::kDivC:
270     case TensorExp::Kind::kDivS:
271     case TensorExp::Kind::kDivU:
272     case TensorExp::Kind::kAddF:
273     case TensorExp::Kind::kAddC:
274     case TensorExp::Kind::kAddI:
275     case TensorExp::Kind::kSubF:
276     case TensorExp::Kind::kSubC:
277     case TensorExp::Kind::kSubI:
278     case TensorExp::Kind::kAndI:
279     case TensorExp::Kind::kOrI:
280     case TensorExp::Kind::kXorI:
281     case TensorExp::Kind::kCmpF:
282     case TensorExp::Kind::kCmpI:
283     case TensorExp::Kind::kShrS:
284     case TensorExp::Kind::kShrU:
285     case TensorExp::Kind::kShlI:
286     case TensorExp::Kind::kBinary:
287     case TensorExp::Kind::kReduce:
288       return compareExpression(tensorExp.children.e0, pattern.children.e0) &&
289              compareExpression(tensorExp.children.e1, pattern.children.e1);
290     case TensorExp::Kind::kDenseOp: {
291       bool eq = compareExpression(tensorExp.children.e0, pattern.children.e0);
292       if (eq && tensorExp.children.e1 != sparse_tensor::detail::kInvalidId)
293         return compareExpression(tensorExp.children.e1, pattern.children.e1);
294       return eq;
295     }
296     }
297     llvm_unreachable("unexpected kind");
298   }
299 
300   // This field is public for convenience.
301   Merger merger;
302 
303 private:
304   // This field is private to prevent mutation after the ctor.
305   SmallVector<ExprId> tensors;
306 };
307 
308 ///
309 /// Tests with all sparse inputs.
310 ///
311 
312 /// Three tensors (two inputs, one output); and a single loop.
313 class MergerTest3T1L : public MergerTestBase {
314 protected:
MergerTest3T1L()315   MergerTest3T1L() : MergerTestBase(3, 1) {
316     EXPECT_TRUE(merger.getOutTensorID() == tid(2));
317     // Tensor 0: sparse input vector.
318     merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
319     // Tensor 1: sparse input vector.
320     merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Compressed);
321     // Tensor 2: dense output vector.
322     merger.setLevelAndType(tid(2), lid(0), 0, GetParam());
323   }
324 };
325 
326 INSTANTIATE_TEST_SUITE_P(Test3T1L, MergerTest3T1L,
327                          ::testing::Values(LevelFormat::Dense,
328                                            LevelFormat::Batch));
329 
330 /// Four tensors (three inputs, one output); and a single loop.
331 class MergerTest4T1L : public MergerTestBase {
332 protected:
MergerTest4T1L()333   MergerTest4T1L() : MergerTestBase(4, 1) {
334     EXPECT_TRUE(merger.getOutTensorID() == tid(3));
335     // Tensor 0: sparse input vector.
336     merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
337     // Tensor 1: sparse input vector.
338     merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Compressed);
339     // Tensor 2: sparse input vector
340     merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Compressed);
341     // Tensor 3: dense output vector
342     merger.setLevelAndType(tid(3), lid(0), 0, GetParam());
343   }
344 };
345 
346 INSTANTIATE_TEST_SUITE_P(Test4T1L, MergerTest4T1L,
347                          ::testing::Values(LevelFormat::Dense,
348                                            LevelFormat::Batch));
349 
350 ///
351 /// Tests with both sparse and dense input.
352 ///
353 
354 /// Three tensors (two inputs, one output); and a single loop.
355 class MergerTest3T1LD : public MergerTestBase {
356 protected:
MergerTest3T1LD()357   MergerTest3T1LD() : MergerTestBase(3, 1) {
358     EXPECT_TRUE(merger.getOutTensorID() == tid(2));
359     // Tensor 0: sparse input vector.
360     merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
361     // Tensor 1: dense input vector.
362     merger.setLevelAndType(tid(1), lid(0), 0, GetParam());
363     // Tensor 2: dense output vector.
364     merger.setLevelAndType(tid(2), lid(0), 0, GetParam());
365   }
366 };
367 
368 INSTANTIATE_TEST_SUITE_P(Test3T1LD, MergerTest3T1LD,
369                          ::testing::Values(LevelFormat::Dense,
370                                            LevelFormat::Batch));
371 
372 ///
373 /// Tests with both undef and dense input.
374 ///
375 
376 /// Three tensors (three inputs, one output); and a single loop.
377 class MergerTest4T1LU : public MergerTestBase {
378 protected:
MergerTest4T1LU()379   MergerTest4T1LU() : MergerTestBase(4, 1) {
380     EXPECT_TRUE(merger.getOutTensorID() == tid(3));
381     // Tensor 0: undef input vector.
382     merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Undef);
383     // Tensor 1: dense input vector.
384     merger.setLevelAndType(tid(1), lid(0), 0, GetParam());
385     // Tensor 2: undef input vector.
386     merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Undef);
387     // Tensor 3: dense output vector.
388     merger.setLevelAndType(tid(3), lid(0), 0, GetParam());
389   }
390 };
391 
392 INSTANTIATE_TEST_SUITE_P(Test4T1LU, MergerTest4T1LU,
393                          ::testing::Values(LevelFormat::Dense,
394                                            LevelFormat::Batch));
395 
396 ///
397 /// Tests with operation on sparse output.
398 ///
399 
400 /// Three tensors (two inputs, one output, one synthetic); and a single loop.
401 class MergerTest3T1LSo : public MergerTestBase {
402 protected:
MergerTest3T1LSo()403   MergerTest3T1LSo() : MergerTestBase(3, 1) {
404     EXPECT_TRUE(merger.getOutTensorID() == tid(2));
405     EXPECT_TRUE(merger.getSynTensorID() == tid(3));
406     merger.setHasSparseOut(true);
407     // Tensor 0: undef input vector.
408     merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Undef);
409     // Tensor 1: undef input vector.
410     merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Undef);
411     // Tensor 2: sparse output vector.
412     merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Compressed);
413   }
414 };
415 
416 // This testsuite does not use any dense-like format, just one of {Dense, Batch}
417 // is enough.
418 INSTANTIATE_TEST_SUITE_P(Test3T1LSo, MergerTest3T1LSo,
419                          ::testing::Values(LevelFormat::Dense));
420 
421 } // namespace
422 
423 /// Vector multiplication (conjunction) of 3 vectors, i.e.;
424 ///   a(i) = b(i) * c(i) * d(i)
425 /// which should form the single lattice point
426 /// {
427 ///   lat( i_00_U i_01_D i_02_U / (tensor_0 * tensor_1 * tensor2) )
428 /// }
429 /// after optimization, the dense dimesion should be kept, despite it appears
430 /// in the middle
431 /// {
432 ///   lat( i_01_D / (tensor_0 * tensor_1 * tensor2) )
433 /// }
434 #define IMPL_MERGER_TEST_CONJ_CONJ_UNDEF(CONJ1, CONJ2)                         \
435   TEST_P(MergerTest4T1LU, vector_##CONJ1##_##CONJ2) {                          \
436     const auto em = CONJ1##Expr(tensor(0), tensor(1));                         \
437     const auto e = CONJ2##Expr(em, tensor(2));                                 \
438     const auto l0 = lid(0);                                                    \
439     const auto t0 = tid(0);                                                    \
440     const auto t1 = tid(1);                                                    \
441     const auto t2 = tid(2);                                                    \
442     const Match &p0 = tensorMatch(t0);                                         \
443     const Match &p1 = tensorMatch(t1);                                         \
444     const Match &p2 = tensorMatch(t2);                                         \
445     auto s = merger.buildLattices(e, l0);                                      \
446     expectNumLatPoints(s, 1);                                                  \
447     expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2),               \
448                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
449     s = merger.optimizeSet(s);                                                 \
450     expectNumLatPoints(s, 1);                                                  \
451     expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2),               \
452                    loopsToBits({{l0, t1}}), true);                             \
453   }
454 FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_UNDEF)
455 #undef IMPL_MERGER_TEST_CONJ_CONJ_UNDEF
456 
/// Vector multiplication (conjunction) of 2 input vectors and the sparse
/// output, i.e.;
///   o(i) = b(i) * c(i) * o(i)
/// which should form the single lattice point (note how the synthetic tensor
/// i_03_U is used for the sparse output)
/// {
///   lat( i_00_U i_01_U i_03_U / (tensor_0 * tensor_1 * output_tensor_2) )
/// }
/// After optimization, only the synthetic tensor is preserved:
/// {
///   lat( i_03_U / (tensor_0 * tensor_1 * output_tensor_2) )
/// }
#define IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT(INNER, OUTER)                    \
  TEST_P(MergerTest3T1LSo, vector_##INNER##_##OUTER) {                         \
    const auto inner = INNER##Expr(tensor(0), tensor(1));                      \
    const auto root = OUTER##Expr(inner, tensor(2));                           \
    const auto loop = lid(0);                                                  \
    const auto ta = tid(0);                                                    \
    const auto tb = tid(1);                                                    \
    const auto tout = tid(2);                                                  \
    const auto tsyn = tid(3);                                                  \
    const Match &ma = tensorMatch(ta);                                         \
    const Match &mb = tensorMatch(tb);                                         \
    const Match &mout = tensorMatch(tout);                                     \
    auto lats = merger.buildLattices(root, loop);                              \
    expectNumLatPoints(lats, 1);                                               \
    expectLatPoint(lats, 0, OUTER##Match(INNER##Match(ma, mb), mout),          \
                   loopsToBits({{loop, ta}, {loop, tb}, {loop, tsyn}}));       \
    lats = merger.optimizeSet(lats);                                           \
    expectNumLatPoints(lats, 1);                                               \
    expectLatPoint(lats, 0, OUTER##Match(INNER##Match(ma, mb), mout),          \
                   loopsToBits({{loop, tsyn}}), /*simple=*/true);              \
  }
FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT)
#undef IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT
491 
492 /// Vector addition (disjunction) of 2 vectors. i.e.;
493 ///   a(i) = b(i) + c(i)
494 /// which should form the 3 lattice points
495 /// {
496 ///   lat( i_00 i_01 / (tensor_0 + tensor_1) )
497 ///   lat( i_00 / tensor_0 )
498 ///   lat( i_01 / tensor_1 )
499 /// }
500 /// and after optimization, the lattice points do not change (as there is no
501 /// duplicated point and all input vectors are sparse vector).
502 /// {
503 ///   lat( i_00 i_01 / (tensor_0 + tensor_1) )
504 ///   lat( i_00 / tensor_0 )
505 ///   lat( i_01 / tensor_1 )
506 /// }
507 #define IMPL_MERGER_TEST_DISJ(OP, UNUSED)                                      \
508   TEST_P(MergerTest3T1L, vector_##OP) {                                        \
509     const auto e = OP##Expr(tensor(0), tensor(1));                             \
510     const auto l0 = lid(0);                                                    \
511     const auto t0 = tid(0);                                                    \
512     const auto t1 = tid(1);                                                    \
513     const Match &p0 = tensorMatch(t0);                                         \
514     const Match &p1 = tensorMatch(t1);                                         \
515     auto s = merger.buildLattices(e, l0);                                      \
516                                                                                \
517     expectNumLatPoints(s, 3);                                                  \
518     expectLatPoint(s, 0, OP##Match(p0, p1),                                    \
519                    loopsToBits({{l0, t0}, {l0, t1}}));                         \
520     expectLatPointWithinRange(s, 1, 2, p0, loopsToBits({{l0, t0}}));           \
521     expectLatPointWithinRange(s, 1, 2, p1, loopsToBits({{l0, t1}}));           \
522                                                                                \
523     s = merger.optimizeSet(s);                                                 \
524     expectNumLatPoints(s, 3);                                                  \
525     expectLatPoint(s, 0, OP##Match(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}), \
526                    true);                                                      \
527     expectLatPointWithinRange(s, 1, 2, p0, loopsToBits({{l0, t0}}), true);     \
528     expectLatPointWithinRange(s, 1, 2, p1, loopsToBits({{l0, t1}}), true);     \
529   }
530 FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_DISJ)
531 #undef IMPL_MERGER_TEST_DISJ
532 
533 /// Vector multiplication (conjunction) of 2 vectors, i.e.;
534 ///   a(i) = b(i) * c(i)
535 /// which should form the single lattice point
536 /// {
537 ///   lat( i_00 i_01 / (tensor_0 * tensor_1) )
538 /// }
539 #define IMPL_MERGER_TEST_CONJ(OP, UNUSED)                                      \
540   TEST_P(MergerTest3T1L, vector_##OP) {                                        \
541     const auto e = OP##Expr(tensor(0), tensor(1));                             \
542     const auto l0 = lid(0);                                                    \
543     const auto t0 = tid(0);                                                    \
544     const auto t1 = tid(1);                                                    \
545     const Match &p0 = tensorMatch(t0);                                         \
546     const Match &p1 = tensorMatch(t1);                                         \
547     auto s = merger.buildLattices(e, l0);                                      \
548                                                                                \
549     expectNumLatPoints(s, 1);                                                  \
550     expectLatPoint(s, 0, OP##Match(p0, p1),                                    \
551                    loopsToBits({{l0, t0}, {l0, t1}}));                         \
552                                                                                \
553     s = merger.optimizeSet(s);                                                 \
554     expectNumLatPoints(s, 1);                                                  \
555     expectLatPoint(s, 0, OP##Match(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}), \
556                    true);                                                      \
557   }
558 FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ)
559 #undef IMPL_MERGER_TEST_CONJ
560 
561 /// Vector multiplication (conjunction) then addition (disjunction), i.e.;
562 ///   a(i) = b(i) * c(i) + d(i);
563 /// which should form
564 /// {
565 ///    lat( i_00 i_01 i_02 / (tensor_0 * tensor_1) + tensor_2 )
566 ///    lat( i_00 i_01 / tensor_0 * tensor_1
567 ///    lat( i_02 / tensor_2 )
568 /// }
569 #define IMPL_MERGER_TEST_CONJ_DISJ(CONJ, DISJ)                                 \
570   TEST_P(MergerTest4T1L, vector_##CONJ##_##DISJ) {                             \
571     const auto em = CONJ##Expr(tensor(0), tensor(1));                          \
572     const auto e = DISJ##Expr(em, tensor(2));                                  \
573     const auto l0 = lid(0);                                                    \
574     const auto t0 = tid(0);                                                    \
575     const auto t1 = tid(1);                                                    \
576     const auto t2 = tid(2);                                                    \
577     const Match &p0 = tensorMatch(t0);                                         \
578     const Match &p1 = tensorMatch(t1);                                         \
579     const Match &p2 = tensorMatch(t2);                                         \
580     auto s = merger.buildLattices(e, l0);                                      \
581                                                                                \
582     expectNumLatPoints(s, 3);                                                  \
583     expectLatPoint(s, 0, DISJ##Match(CONJ##Match(p0, p1), p2),                 \
584                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
585     expectLatPointWithinRange(s, 1, 2, CONJ##Match(p0, p1),                    \
586                               loopsToBits({{l0, t0}, {l0, t1}}));              \
587     expectLatPointWithinRange(s, 1, 2, p2, loopsToBits({{l0, t2}}));           \
588                                                                                \
589     s = merger.optimizeSet(s);                                                 \
590     expectNumLatPoints(s, 3);                                                  \
591     expectLatPoint(s, 0, DISJ##Match(CONJ##Match(p0, p1), p2),                 \
592                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
593     expectLatPointWithinRange(s, 1, 2, CONJ##Match(p0, p1),                    \
594                               loopsToBits({{l0, t0}, {l0, t1}}));              \
595     expectLatPointWithinRange(s, 1, 2, p2, loopsToBits({{l0, t2}}));           \
596   }
597 FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ)
598 #undef IMPL_MERGER_TEST_CONJ_DISJ
599 
600 /// Vector addition (disjunction) then addition (disjunction), i.e.;
601 ///   a(i) = b(i) + c(i) + d(i)
602 /// which should form
603 /// {
604 ///   lat( i_00 i_01 i_02 / (tensor_0 + tensor_1) + tensor_2 )
605 ///   lat( i_02 i_01 / tensor_2 + tensor_1 )
606 ///   lat( i_02 i_00 / tensor_2 + tensor_0 )
607 ///   lat( i_01 i_00 / tensor_1 + tensor_0 )
608 ///   lat( i_02 / tensor_2 )
609 ///   lat( i_01 / tensor_1 )
610 ///   lat( i_00 / tensor_0 )
611 /// }
612 #define IMPL_MERGER_TEST_DISJ_DISJ(DISJ1, DISJ2)                               \
613   TEST_P(MergerTest4T1L, Vector_##DISJ1##_##DISJ2) {                           \
614     const auto em = DISJ1##Expr(tensor(0), tensor(1));                         \
615     const auto e = DISJ2##Expr(em, tensor(2));                                 \
616     const auto l0 = lid(0);                                                    \
617     const auto t0 = tid(0);                                                    \
618     const auto t1 = tid(1);                                                    \
619     const auto t2 = tid(2);                                                    \
620     const Match &p0 = tensorMatch(t0);                                         \
621     const Match &p1 = tensorMatch(t1);                                         \
622     const Match &p2 = tensorMatch(t2);                                         \
623     auto s = merger.buildLattices(e, l0);                                      \
624                                                                                \
625     expectNumLatPoints(s, 7);                                                  \
626     expectLatPoint(s, 0, DISJ2##Match(DISJ1##Match(p0, p1), p2),               \
627                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
628     expectLatPointWithinRange(s, 1, 6, DISJ2##Match(p1, p2),                   \
629                               loopsToBits({{l0, t1}, {l0, t2}}));              \
630     expectLatPointWithinRange(s, 1, 6, DISJ2##Match(p0, p2),                   \
631                               loopsToBits({{l0, t0}, {l0, t2}}));              \
632     expectLatPointWithinRange(s, 1, 6, DISJ1##Match(p0, p1),                   \
633                               loopsToBits({{l0, t0}, {l0, t1}}));              \
634     expectLatPointWithinRange(s, 1, 6, p2, loopsToBits({{l0, t2}}));           \
635     expectLatPointWithinRange(s, 1, 6, p1, loopsToBits({{l0, t1}}));           \
636     expectLatPointWithinRange(s, 1, 6, p0, loopsToBits({{l0, t0}}));           \
637                                                                                \
638     s = merger.optimizeSet(s);                                                 \
639     expectNumLatPoints(s, 7);                                                  \
640     expectLatPoint(s, 0, DISJ2##Match(DISJ1##Match(p0, p1), p2),               \
641                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
642     expectLatPointWithinRange(s, 1, 6, DISJ2##Match(p1, p2),                   \
643                               loopsToBits({{l0, t1}, {l0, t2}}));              \
644     expectLatPointWithinRange(s, 1, 6, DISJ2##Match(p0, p2),                   \
645                               loopsToBits({{l0, t0}, {l0, t2}}));              \
646     expectLatPointWithinRange(s, 1, 6, DISJ1##Match(p0, p1),                   \
647                               loopsToBits({{l0, t0}, {l0, t1}}));              \
648     expectLatPointWithinRange(s, 1, 6, p2, loopsToBits({{l0, t2}}));           \
649     expectLatPointWithinRange(s, 1, 6, p1, loopsToBits({{l0, t1}}));           \
650     expectLatPointWithinRange(s, 1, 6, p0, loopsToBits({{l0, t0}}));           \
651   }
652 FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(IMPL_MERGER_TEST_DISJ_DISJ)
653 #undef IMPL_MERGER_TEST_DISJ_DISJ
654 
655 /// Vector multiplication (conjunction) then multiplication (conjunction), i.e.;
656 ///   a(i) = b(i) * c(i) * d(i);
657 /// which should form
658 /// {
659 ///    lat( i_00 i_01 i_02 / tensor_0 * tensor_1 * tensor_2 )
660 /// }
661 #define IMPL_MERGER_TEST_CONJ_CONJ(CONJ1, CONJ2)                               \
662   TEST_P(MergerTest4T1L, vector_##CONJ1##_##CONJ2) {                           \
663     const auto em = CONJ1##Expr(tensor(0), tensor(1));                         \
664     const auto e = CONJ2##Expr(em, tensor(2));                                 \
665     const auto l0 = lid(0);                                                    \
666     const auto t0 = tid(0);                                                    \
667     const auto t1 = tid(1);                                                    \
668     const auto t2 = tid(2);                                                    \
669     const Match &p0 = tensorMatch(t0);                                         \
670     const Match &p1 = tensorMatch(t1);                                         \
671     const Match &p2 = tensorMatch(t2);                                         \
672     auto s = merger.buildLattices(e, l0);                                      \
673     expectNumLatPoints(s, 1);                                                  \
674     expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2),               \
675                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
676     s = merger.optimizeSet(s);                                                 \
677     expectNumLatPoints(s, 1);                                                  \
678     expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2),               \
679                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}), true);         \
680   }
681 FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ)
682 #undef IMPL_MERGER_TEST_CONJ_CONJ
683 
684 /// Vector addition (disjunction) of 2 vectors, i.e.;
685 ///   a(i) = b(i) + c(i)
686 /// which should form the 3 lattice points
687 /// {
688 ///   lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) )
689 ///   lat( i_00 / sparse_tensor_0 )
690 ///   lat( i_01 / dense_tensor_1 )
691 /// }
692 /// which should be optimized to
693 /// {
694 ///   lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ) (not singleton)
695 ///   lat( i_01 / dense_tensor_0 ) (no sparse dimension)
696 /// }
697 ///
698 /// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff
699 /// with lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ).
700 #define IMPL_MERGER_TEST_OPTIMIZED_DISJ(OP, UNUSED)                            \
701   TEST_P(MergerTest3T1LD, vector_opted_##OP) {                                 \
702     const auto e = OP##Expr(tensor(0), tensor(1));                             \
703     const auto l0 = lid(0);                                                    \
704     const auto t0 = tid(0);                                                    \
705     const auto t1 = tid(1);                                                    \
706     const Match &p0 = tensorMatch(t0);                                         \
707     const Match &p1 = tensorMatch(t1);                                         \
708     auto s = merger.buildLattices(e, l0);                                      \
709                                                                                \
710     expectNumLatPoints(s, 3);                                                  \
711     expectLatPoint(s, 0, OP##Match(p0, p1),                                    \
712                    loopsToBits({{l0, t0}, {l0, t1}}));                         \
713     expectLatPointWithinRange(s, 1, 2, p0, loopsToBits({{l0, t0}}));           \
714     expectLatPointWithinRange(s, 1, 2, p1, loopsToBits({{l0, t1}}));           \
715                                                                                \
716     s = merger.optimizeSet(s);                                                 \
717     expectNumLatPoints(s, 2);                                                  \
718     expectLatPoint(s, 0, OP##Match(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}), \
719                    true);                                                      \
720     expectLatPoint(s, 1, p1, loopsToBits({{l0, t1}}), true);                   \
721   }
722 FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_DISJ)
723 #undef IMPL_MERGER_TEST_OPTIMIZED_CONJ
724 
725 /// Vector multiplication (conjunction) of 2 vectors, i.e.:
726 ///   a(i) = b(i) * c(i)
727 /// which should form the single lattice point
728 /// {
729 ///   lat( i_00 i_01 / (sparse_tensor_0 * dense_tensor_1) )
730 /// }
731 /// it should be optimized to
732 /// {
733 ///   lat( i_00 / (sparse_tensor_0 * dense_tensor_1) )
734 /// }
735 /// since i_01 is a dense dimension.
736 #define IMPL_MERGER_TEST_OPTIMIZED_CONJ(OP, UNUSED)                            \
737   TEST_P(MergerTest3T1LD, vector_opted_##OP) {                                 \
738     const auto e = OP##Expr(tensor(0), tensor(1));                             \
739     const auto l0 = lid(0);                                                    \
740     const auto t0 = tid(0);                                                    \
741     const auto t1 = tid(1);                                                    \
742     const Match &p0 = tensorMatch(t0);                                         \
743     const Match &p1 = tensorMatch(t1);                                         \
744     auto s = merger.buildLattices(e, l0);                                      \
745                                                                                \
746     expectNumLatPoints(s, 1);                                                  \
747     expectLatPoint(s, 0, OP##Match(p0, p1),                                    \
748                    loopsToBits({{l0, t0}, {l0, t1}}));                         \
749                                                                                \
750     s = merger.optimizeSet(s);                                                 \
751     expectNumLatPoints(s, 1);                                                  \
752     expectLatPoint(s, 0, OP##Match(p0, p1), loopsToBits({{l0, t0}}), true);    \
753   }
754 FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_CONJ)
755 #undef IMPL_MERGER_TEST_OPTIMIZED_CONJ
756 
757 /// Vector element-wise comparison (disjunction) of 2 vectors. i.e.;
758 ///   a(i) = b(i) + c(i)
759 /// which should form the 3 lattice points
760 /// {
761 ///   lat( i_00 i_01 / (tensor_0 cmp tensor_1) )
762 ///   lat( i_00 / tensor_0 cmp 0 )
763 ///   lat( i_01 / 0 cmp tensor_1 )
764 /// }
765 /// and after optimization, the lattice points do not change (as there is no
766 /// duplicated point and all input vectors are sparse vector).
767 /// {
768 ///   lat( i_00 i_01 / (tensor_0 cmp tensor_1) )
769 ///   lat( i_00 / tensor_0 cmp 0 )
770 ///   lat( i_01 / 0 cmp tensor_1 )
771 /// }
772 TEST_P(MergerTest3T1L, vector_cmp) {
773   const auto e = cmpiExpr(tensor(0), tensor(1));
774   const auto l0 = lid(0);
775   const auto t0 = tid(0);
776   const auto t1 = tid(1);
777   const Match &zero = synZeroMatch();
778   const Match &p0 = tensorMatch(t0);
779   const Match &p1 = tensorMatch(t1);
780   auto s = merger.buildLattices(e, l0);
781   expectLatPoint(s, 0, cmpiMatch(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
782   expectLatPointWithinRange(s, 1, 2, cmpiMatch(p0, zero),
783                             loopsToBits({{l0, t0}}));
784   expectLatPointWithinRange(s, 1, 2, cmpiMatch(zero, p1),
785                             loopsToBits({{l0, t1}}));
786   s = merger.optimizeSet(s);
787   expectLatPoint(s, 0, cmpiMatch(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
788   expectLatPointWithinRange(s, 1, 2, cmpiMatch(p0, zero),
789                             loopsToBits({{l0, t0}}));
790   expectLatPointWithinRange(s, 1, 2, cmpiMatch(zero, p1),
791                             loopsToBits({{l0, t1}}));
792 }
793 
794 /// Vector element-wise comparsion (disjunction) of 2 vectors, i.e.;
795 ///   a(i) = b(i) cmp c(i)
796 /// which should form the 3 lattice points
797 /// {
798 ///   lat( i_00 i_01 / (sparse_tensor_0 cmp dense_tensor_1) )
799 ///   lat( i_00 / sparse_tensor_0 cmp 0)
800 ///   lat( i_01 / 0 cmp dense_tensor_1 )
801 /// }
802 /// which should be optimized to
803 /// {
804 ///   lat( i_00 i_01 / (sparse_tensor_0 cmp dense_tensor_1) ) (not singleton)
805 ///   lat( i_01 / 0 cmp dense_tensor_0 ) ()
806 /// }
807 ///
808 /// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff
809 /// with lat( i_00 i_01 / (sparse_tensor_0 cmp dense_tensor_1) ).
TEST_P(MergerTest3T1LD,vector_cmp)810 TEST_P(MergerTest3T1LD, vector_cmp) {
811   const auto e = cmpiExpr(tensor(0), tensor(1));
812   const auto l0 = lid(0);
813   const auto t0 = tid(0);
814   const auto t1 = tid(1);
815   const Match &zero = synZeroMatch();
816   const Match &p0 = tensorMatch(t0);
817   const Match &p1 = tensorMatch(t1);
818   auto s = merger.buildLattices(e, l0);
819   expectLatPoint(s, 0, cmpiMatch(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
820   expectLatPointWithinRange(s, 1, 2, cmpiMatch(p0, zero),
821                             loopsToBits({{l0, t0}}));
822   expectLatPointWithinRange(s, 1, 2, cmpiMatch(zero, p1),
823                             loopsToBits({{l0, t1}}));
824   s = merger.optimizeSet(s);
825   expectLatPoint(s, 0, cmpiMatch(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
826   expectLatPointWithinRange(s, 1, 2, cmpiMatch(zero, p1),
827                             loopsToBits({{l0, t1}}));
828 }
829