1fd9b3e47SAart Bik //===- MergerTest.cpp - Tests for the sparsifier's merger -----------------===//
2fd9b3e47SAart Bik //
3fd9b3e47SAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4fd9b3e47SAart Bik // See https://llvm.org/LICENSE.txt for license information.
5fd9b3e47SAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6fd9b3e47SAart Bik //
7fd9b3e47SAart Bik //===----------------------------------------------------------------------===//
8fd9b3e47SAart Bik
940843347SGus Smith #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
1066ae1d60SPeiming Liu #include "llvm/Support/Compiler.h"
1140843347SGus Smith #include "gmock/gmock.h"
1240843347SGus Smith #include "gtest/gtest.h"
131e15d791SAart Bik
1440843347SGus Smith #include <memory>
1540843347SGus Smith
166842ec42SRiver Riddle using namespace mlir;
1740843347SGus Smith using namespace mlir::sparse_tensor;
1840843347SGus Smith
1940843347SGus Smith namespace {
2040843347SGus Smith
///
/// Defines macros that iterate over binary operations and over combinations
/// (pairs) of binary operations. Each macro invokes the callback macro it is
/// given once per listed entry, so that a single test template can be
/// stamped out for every operation (or pair of operations).
///

/// Invokes `DO(op, kind)` for every binary operation that gets an `<op>Match`
/// and an `<op>Expr` helper below; `kind` is the matching `TensorExp::Kind`.
#define FOREVERY_BINOP(DO) \
  DO(mulf, TensorExp::Kind::kMulF) \
  DO(mulc, TensorExp::Kind::kMulC) \
  DO(muli, TensorExp::Kind::kMulI) \
  DO(addf, TensorExp::Kind::kAddF) \
  DO(addc, TensorExp::Kind::kAddC) \
  DO(addi, TensorExp::Kind::kAddI) \
  DO(subf, TensorExp::Kind::kSubF) \
  DO(subc, TensorExp::Kind::kSubC) \
  DO(subi, TensorExp::Kind::kSubI) \
  DO(andi, TensorExp::Kind::kAndI) \
  DO(xori, TensorExp::Kind::kXorI) \
  DO(ori, TensorExp::Kind::kOrI) \
  DO(cmpf, TensorExp::Kind::kCmpF) \
  DO(cmpi, TensorExp::Kind::kCmpI)

/// Invokes `TEST(op, EXTRA)` for each common disjunctive (addition-like)
/// binary operation.
#define FOREVERY_COMMON_DISJ_BINOP_EXTRA(TEST, EXTRA) \
  TEST(addf, EXTRA) \
  TEST(addc, EXTRA) \
  TEST(addi, EXTRA) \
  TEST(xori, EXTRA) \
  TEST(ori, EXTRA)

/// Invokes `TEST(op, EXTRA)` for each common conjunctive
/// (multiplication-like) binary operation.
#define FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, EXTRA) \
  TEST(mulf, EXTRA) \
  TEST(mulc, EXTRA) \
  TEST(muli, EXTRA) \
  TEST(andi, EXTRA)

/// Single-operation variants: `""` is only a placeholder for the second
/// argument, which callbacks such as `IMPL_MERGER_TEST_DISJ(OP, UNUSED)`
/// ignore.
#define FOREVERY_COMMON_DISJ_BINOP(TEST) \
  FOREVERY_COMMON_DISJ_BINOP_EXTRA(TEST, "")

#define FOREVERY_COMMON_CONJ_BINOP(TEST) \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, "")

/// Invokes `TEST(conj, disj)` for every (conjunctive, disjunctive) pair.
#define FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(TEST) \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, addf) \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, addc) \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, addi) \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, xori) \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, ori)

/// Invokes `TEST(conj1, conj2)` for every pair of conjunctive operations.
#define FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(TEST) \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, mulf) \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, mulc) \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, muli) \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, andi)

/// Invokes `TEST(disj1, disj2)` for every pair of disjunctive operations.
/// (The outer list puts `ori` before `xori`, unlike the lists above; this
/// only affects the order in which the generated tests are defined.)
#define FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(TEST) \
  FOREVERY_COMMON_DISJ_BINOP_EXTRA(TEST, addf) \
  FOREVERY_COMMON_DISJ_BINOP_EXTRA(TEST, addc) \
  FOREVERY_COMMON_DISJ_BINOP_EXTRA(TEST, addi) \
  FOREVERY_COMMON_DISJ_BINOP_EXTRA(TEST, ori) \
  FOREVERY_COMMON_DISJ_BINOP_EXTRA(TEST, xori)
7966ae1d60SPeiming Liu
8066ae1d60SPeiming Liu ///
8166ae1d60SPeiming Liu /// Helper classes/functions for testing Merger.
8266ae1d60SPeiming Liu ///
8366ae1d60SPeiming Liu
/// Simple recursive data structure used to match expressions in `Merger`,
/// which uses const references into the short-lived data structures.
///
/// A `Match` is a tagged union discriminated by `kind`:
/// - default ctor: `kSynZero`; no union member is initialized, so neither
///   `tid` nor `children` may be read.
/// - `Match(TensorId)`: `kTensor`; `tid` is the active member.
/// - `Match(kind, e0, e1)`: an operation pattern; `children` is the active
///   member.
///
/// NOTE: `children` holds references, so the sub-patterns must outlive this
/// `Match` (e.g. temporaries within one full-expression, or patterns bound
/// to `const Match &` locals).
struct Match {
  /// The two sub-patterns of an operation pattern, held by reference.
  struct Children {
    Children(const Match &e0, const Match &e1) : e0(e0), e1(e1) {}
    const Match &e0;
    const Match &e1;
  };

  Match() : kind(TensorExp::Kind::kSynZero) {}
  Match(TensorId tid) : kind(TensorExp::Kind::kTensor), tid(tid) {}
  Match(TensorExp::Kind kind, const Match &e0, const Match &e1)
      : kind(kind), children(e0, e1) {
    // Presumably restricts this ctor to binary (and later) kinds, relying on
    // them being enumerated from kMulF onward in TensorExp::Kind -- confirm
    // against Merger.h.
    assert(kind >= TensorExp::Kind::kMulF);
  }

  /// The expression kind to match; also selects the active union member.
  TensorExp::Kind kind;
  union {
    TensorId tid;
    Children children;
  };
};
10640843347SGus Smith
10740843347SGus Smith ///
108fd9b3e47SAart Bik /// Readable Match builder functions.
10940843347SGus Smith /// These should be preferred over the actual constructors.
11040843347SGus Smith ///
11140843347SGus Smith
tensorMatch(TensorId tid)112fd9b3e47SAart Bik static Match tensorMatch(TensorId tid) { return Match(tid); }
synZeroMatch()113fd9b3e47SAart Bik static Match synZeroMatch() { return Match(); }
11440843347SGus Smith
11566ae1d60SPeiming Liu #define IMPL_BINOP_PATTERN(OP, KIND) \
116fd9b3e47SAart Bik LLVM_ATTRIBUTE_UNUSED static Match OP##Match(const Match &e0, \
117fd9b3e47SAart Bik const Match &e1) { \
118fd9b3e47SAart Bik return Match(KIND, e0, e1); \
11940843347SGus Smith }
12066ae1d60SPeiming Liu FOREVERY_BINOP(IMPL_BINOP_PATTERN)
12166ae1d60SPeiming Liu #undef IMPL_BINOP_PATTERN
12240843347SGus Smith
123d82e93e7SPeiming Liu // Parameterize LevelFormat to test both Dense and Batch LevelFormat.
124d82e93e7SPeiming Liu class MergerTestBase : public ::testing::TestWithParam<LevelFormat> {
12540843347SGus Smith protected:
MergerTestBase(unsigned numTensors,unsigned numLoops)12640843347SGus Smith MergerTestBase(unsigned numTensors, unsigned numLoops)
127ccd923e3SPeiming Liu : merger(numTensors, numLoops, /*maxRank=*/numLoops) {
128164b918dSwren romano tensors.reserve(numTensors);
129164b918dSwren romano for (unsigned t = 0; t < numTensors; t++)
13046a384dfSwren romano tensors.push_back(merger.addTensorExp(tid(t)));
131164b918dSwren romano }
13240843347SGus Smith
13340843347SGus Smith ///
13440843347SGus Smith /// Expression construction helpers.
13540843347SGus Smith ///
13640843347SGus Smith
tid(unsigned t) const13746a384dfSwren romano TensorId tid(unsigned t) const { return merger.makeTensorId(t); }
lid(unsigned i) const13846a384dfSwren romano LoopId lid(unsigned i) const { return merger.makeLoopId(i); }
tensor(unsigned t) const139164b918dSwren romano ExprId tensor(unsigned t) const {
140164b918dSwren romano assert(t < tensors.size());
141164b918dSwren romano return tensors[t];
14240843347SGus Smith }
14340843347SGus Smith
14466ae1d60SPeiming Liu #define IMPL_BINOP_EXPR(OP, KIND) \
145164b918dSwren romano LLVM_ATTRIBUTE_UNUSED ExprId OP##Expr(ExprId e0, ExprId e1) { \
14666ae1d60SPeiming Liu return merger.addExp(KIND, e0, e1); \
14740843347SGus Smith }
FOREVERY_BINOP(IMPL_BINOP_EXPR)14866ae1d60SPeiming Liu FOREVERY_BINOP(IMPL_BINOP_EXPR)
14966ae1d60SPeiming Liu #undef IMPL_BINOP_EXPR
15040843347SGus Smith
15140843347SGus Smith ///
15240843347SGus Smith /// Comparison helpers.
15340843347SGus Smith ///
15440843347SGus Smith
155164b918dSwren romano /// Returns true if any lattice point with an expression matching
156164b918dSwren romano /// the given `pattern` and bits matching the given `bits` is present
157164b918dSwren romano /// in the `[lo, lo+n)` slice of the lattice set `s`. This is useful
158164b918dSwren romano /// for testing partial ordering constraints between lattice points.
159164b918dSwren romano /// We generally know how contiguous groups of lattice points should
160164b918dSwren romano /// be ordered with respect to other groups, but there is no required
161164b918dSwren romano /// ordering within groups. If `simple` is true, then compare the
162164b918dSwren romano /// `lat.simple` field instead to test the result after optimization.
163164b918dSwren romano bool latPointWithinRange(LatSetId s, unsigned lo, unsigned n,
164fd9b3e47SAart Bik const Match &pattern, const BitVector &bits,
165164b918dSwren romano bool simple) {
166164b918dSwren romano for (unsigned k = lo, hi = lo + n; k < hi; ++k) {
167164b918dSwren romano if (compareExpression(merger.lat(merger.set(s)[k]).exp, pattern) &&
168164b918dSwren romano compareBits(s, k, bits, simple))
16940843347SGus Smith return true;
17040843347SGus Smith }
17140843347SGus Smith return false;
17240843347SGus Smith }
17340843347SGus Smith
17440843347SGus Smith /// Wrapper over latPointWithinRange for readability of tests.
expectLatPointWithinRange(LatSetId s,unsigned lo,unsigned n,const Match & pattern,const BitVector & bits,bool simple=false)175164b918dSwren romano void expectLatPointWithinRange(LatSetId s, unsigned lo, unsigned n,
176fd9b3e47SAart Bik const Match &pattern, const BitVector &bits,
177164b918dSwren romano bool simple = false) {
178164b918dSwren romano EXPECT_TRUE(latPointWithinRange(s, lo, n, pattern, bits, simple));
17940843347SGus Smith }
18040843347SGus Smith
18140843347SGus Smith /// Wrapper over expectLatPointWithinRange for a single lat point.
expectLatPoint(LatSetId s,unsigned lo,const Match & pattern,const BitVector & bits,bool simple=false)182fd9b3e47SAart Bik void expectLatPoint(LatSetId s, unsigned lo, const Match &pattern,
18366ae1d60SPeiming Liu const BitVector &bits, bool simple = false) {
184164b918dSwren romano EXPECT_TRUE(latPointWithinRange(s, lo, 1, pattern, bits, simple));
18540843347SGus Smith }
18640843347SGus Smith
18740843347SGus Smith /// Converts a vector of (loop, tensor) pairs to a bitvector with the
18840843347SGus Smith /// corresponding bits set.
loopsToBits(const std::vector<std::pair<LoopId,TensorId>> & loops)189164b918dSwren romano BitVector loopsToBits(const std::vector<std::pair<LoopId, TensorId>> &loops) {
19046a384dfSwren romano BitVector testBits = BitVector(merger.getNumTensors(), false);
191164b918dSwren romano for (auto [loop, tensor] : loops)
19246a384dfSwren romano testBits.set(merger.makeTensorLoopId(tensor, loop));
19340843347SGus Smith return testBits;
19440843347SGus Smith }
19540843347SGus Smith
196164b918dSwren romano /// Returns true if the bits of the `k`th point in set `s` matches
197164b918dSwren romano /// the given `bits`. If `simple` is true, then compares the `lat.simple`
198164b918dSwren romano /// field instead, to test the result after optimization
compareBits(LatSetId s,unsigned k,const BitVector & bits,bool simple)199164b918dSwren romano bool compareBits(LatSetId s, unsigned k, const BitVector &bits, bool simple) {
200164b918dSwren romano const auto &point = merger.lat(merger.set(s)[k]);
201164b918dSwren romano return (simple ? point.simple : point.bits) == bits;
20240843347SGus Smith }
20340843347SGus Smith
20440843347SGus Smith /// Check that there are n lattice points in set s.
expectNumLatPoints(LatSetId s,unsigned n)205164b918dSwren romano void expectNumLatPoints(LatSetId s, unsigned n) {
20640843347SGus Smith EXPECT_THAT(merger.set(s).size(), n);
20740843347SGus Smith }
20840843347SGus Smith
20940843347SGus Smith /// Compares expressions for equality. Equality is defined recursively as:
21006aa6ec8SAart Bik /// - Operations are equal if they have the same kind and children.
21106aa6ec8SAart Bik /// - Leaf tensors are equal if they refer to the same tensor.
compareExpression(ExprId e,const Match & pattern)212fd9b3e47SAart Bik bool compareExpression(ExprId e, const Match &pattern) {
213164b918dSwren romano const auto &tensorExp = merger.exp(e);
214164b918dSwren romano if (tensorExp.kind != pattern.kind)
21540843347SGus Smith return false;
21640843347SGus Smith switch (tensorExp.kind) {
21706aa6ec8SAart Bik // Leaf.
2181f58ae80Swren romano case TensorExp::Kind::kTensor:
219164b918dSwren romano return tensorExp.tensor == pattern.tid;
220faf7cd97SPeiming Liu case TensorExp::Kind::kSynZero:
221faf7cd97SPeiming Liu // Already checked kind equivalence @L233
222faf7cd97SPeiming Liu return true;
2231f58ae80Swren romano case TensorExp::Kind::kInvariant:
22406aa6ec8SAart Bik llvm_unreachable("invariant not handled yet");
225164b918dSwren romano case TensorExp::Kind::kLoopVar:
226164b918dSwren romano llvm_unreachable("loop-variables not handled yet");
22706aa6ec8SAart Bik // Unary operations.
2281f58ae80Swren romano case TensorExp::Kind::kAbsF:
2291f58ae80Swren romano case TensorExp::Kind::kAbsC:
2301f58ae80Swren romano case TensorExp::Kind::kAbsI:
2311f58ae80Swren romano case TensorExp::Kind::kCeilF:
2321f58ae80Swren romano case TensorExp::Kind::kFloorF:
2331f58ae80Swren romano case TensorExp::Kind::kSqrtF:
2341f58ae80Swren romano case TensorExp::Kind::kSqrtC:
2351f58ae80Swren romano case TensorExp::Kind::kExpm1F:
2361f58ae80Swren romano case TensorExp::Kind::kExpm1C:
2371f58ae80Swren romano case TensorExp::Kind::kLog1pF:
2381f58ae80Swren romano case TensorExp::Kind::kLog1pC:
239*70e227a4SAart Bik case TensorExp::Kind::kRelu:
2401f58ae80Swren romano case TensorExp::Kind::kSinF:
2411f58ae80Swren romano case TensorExp::Kind::kSinC:
2421f58ae80Swren romano case TensorExp::Kind::kTanhF:
2431f58ae80Swren romano case TensorExp::Kind::kTanhC:
2441f58ae80Swren romano case TensorExp::Kind::kNegF:
2451f58ae80Swren romano case TensorExp::Kind::kNegC:
2461f58ae80Swren romano case TensorExp::Kind::kNegI:
2471f58ae80Swren romano case TensorExp::Kind::kTruncF:
2481f58ae80Swren romano case TensorExp::Kind::kExtF:
2491f58ae80Swren romano case TensorExp::Kind::kCastFS:
2501f58ae80Swren romano case TensorExp::Kind::kCastFU:
2511f58ae80Swren romano case TensorExp::Kind::kCastSF:
2521f58ae80Swren romano case TensorExp::Kind::kCastUF:
2531f58ae80Swren romano case TensorExp::Kind::kCastS:
2541f58ae80Swren romano case TensorExp::Kind::kCastU:
2551f58ae80Swren romano case TensorExp::Kind::kCastIdx:
2561f58ae80Swren romano case TensorExp::Kind::kTruncI:
2571f58ae80Swren romano case TensorExp::Kind::kCIm:
2581f58ae80Swren romano case TensorExp::Kind::kCRe:
2591f58ae80Swren romano case TensorExp::Kind::kBitCast:
2601f58ae80Swren romano case TensorExp::Kind::kSelect:
2611f58ae80Swren romano case TensorExp::Kind::kBinaryBranch:
2621f58ae80Swren romano case TensorExp::Kind::kUnary:
263164b918dSwren romano return compareExpression(tensorExp.children.e0, pattern.children.e0);
26406aa6ec8SAart Bik // Binary operations.
2651f58ae80Swren romano case TensorExp::Kind::kMulF:
2661f58ae80Swren romano case TensorExp::Kind::kMulC:
2671f58ae80Swren romano case TensorExp::Kind::kMulI:
2681f58ae80Swren romano case TensorExp::Kind::kDivF:
2691f58ae80Swren romano case TensorExp::Kind::kDivC:
2701f58ae80Swren romano case TensorExp::Kind::kDivS:
2711f58ae80Swren romano case TensorExp::Kind::kDivU:
2721f58ae80Swren romano case TensorExp::Kind::kAddF:
2731f58ae80Swren romano case TensorExp::Kind::kAddC:
2741f58ae80Swren romano case TensorExp::Kind::kAddI:
2751f58ae80Swren romano case TensorExp::Kind::kSubF:
2761f58ae80Swren romano case TensorExp::Kind::kSubC:
2771f58ae80Swren romano case TensorExp::Kind::kSubI:
2781f58ae80Swren romano case TensorExp::Kind::kAndI:
2791f58ae80Swren romano case TensorExp::Kind::kOrI:
2801f58ae80Swren romano case TensorExp::Kind::kXorI:
281faf7cd97SPeiming Liu case TensorExp::Kind::kCmpF:
282faf7cd97SPeiming Liu case TensorExp::Kind::kCmpI:
2831f58ae80Swren romano case TensorExp::Kind::kShrS:
2841f58ae80Swren romano case TensorExp::Kind::kShrU:
2851f58ae80Swren romano case TensorExp::Kind::kShlI:
2861f58ae80Swren romano case TensorExp::Kind::kBinary:
2871f58ae80Swren romano case TensorExp::Kind::kReduce:
288164b918dSwren romano return compareExpression(tensorExp.children.e0, pattern.children.e0) &&
289164b918dSwren romano compareExpression(tensorExp.children.e1, pattern.children.e1);
290df11a2b4SPeiming Liu case TensorExp::Kind::kDenseOp: {
291df11a2b4SPeiming Liu bool eq = compareExpression(tensorExp.children.e0, pattern.children.e0);
292df11a2b4SPeiming Liu if (eq && tensorExp.children.e1 != sparse_tensor::detail::kInvalidId)
293df11a2b4SPeiming Liu return compareExpression(tensorExp.children.e1, pattern.children.e1);
294df11a2b4SPeiming Liu return eq;
295df11a2b4SPeiming Liu }
29640843347SGus Smith }
29706aa6ec8SAart Bik llvm_unreachable("unexpected kind");
29840843347SGus Smith }
29940843347SGus Smith
300164b918dSwren romano // This field is public for convenience.
30140843347SGus Smith Merger merger;
302164b918dSwren romano
303164b918dSwren romano private:
304164b918dSwren romano // This field is private to prevent mutation after the ctor.
305164b918dSwren romano SmallVector<ExprId> tensors;
30640843347SGus Smith };
30740843347SGus Smith
30866ae1d60SPeiming Liu ///
30966ae1d60SPeiming Liu /// Tests with all sparse inputs.
31066ae1d60SPeiming Liu ///
31166ae1d60SPeiming Liu
312164b918dSwren romano /// Three tensors (two inputs, one output); and a single loop.
31340843347SGus Smith class MergerTest3T1L : public MergerTestBase {
31440843347SGus Smith protected:
MergerTest3T1L()31540843347SGus Smith MergerTest3T1L() : MergerTestBase(3, 1) {
316164b918dSwren romano EXPECT_TRUE(merger.getOutTensorID() == tid(2));
31740843347SGus Smith // Tensor 0: sparse input vector.
318aaf91645SPeiming Liu merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
31940843347SGus Smith // Tensor 1: sparse input vector.
320aaf91645SPeiming Liu merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Compressed);
32140843347SGus Smith // Tensor 2: dense output vector.
322d82e93e7SPeiming Liu merger.setLevelAndType(tid(2), lid(0), 0, GetParam());
32340843347SGus Smith }
32440843347SGus Smith };
32540843347SGus Smith
326d82e93e7SPeiming Liu INSTANTIATE_TEST_SUITE_P(Test3T1L, MergerTest3T1L,
327d82e93e7SPeiming Liu ::testing::Values(LevelFormat::Dense,
328d82e93e7SPeiming Liu LevelFormat::Batch));
329d82e93e7SPeiming Liu
330164b918dSwren romano /// Four tensors (three inputs, one output); and a single loop.
33166ae1d60SPeiming Liu class MergerTest4T1L : public MergerTestBase {
33266ae1d60SPeiming Liu protected:
MergerTest4T1L()33366ae1d60SPeiming Liu MergerTest4T1L() : MergerTestBase(4, 1) {
334164b918dSwren romano EXPECT_TRUE(merger.getOutTensorID() == tid(3));
33566ae1d60SPeiming Liu // Tensor 0: sparse input vector.
336aaf91645SPeiming Liu merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
33766ae1d60SPeiming Liu // Tensor 1: sparse input vector.
338aaf91645SPeiming Liu merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Compressed);
33966ae1d60SPeiming Liu // Tensor 2: sparse input vector
340aaf91645SPeiming Liu merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Compressed);
34166ae1d60SPeiming Liu // Tensor 3: dense output vector
342d82e93e7SPeiming Liu merger.setLevelAndType(tid(3), lid(0), 0, GetParam());
34366ae1d60SPeiming Liu }
34466ae1d60SPeiming Liu };
34566ae1d60SPeiming Liu
346d82e93e7SPeiming Liu INSTANTIATE_TEST_SUITE_P(Test4T1L, MergerTest4T1L,
347d82e93e7SPeiming Liu ::testing::Values(LevelFormat::Dense,
348d82e93e7SPeiming Liu LevelFormat::Batch));
349d82e93e7SPeiming Liu
35066ae1d60SPeiming Liu ///
35166ae1d60SPeiming Liu /// Tests with both sparse and dense input.
35266ae1d60SPeiming Liu ///
35366ae1d60SPeiming Liu
354164b918dSwren romano /// Three tensors (two inputs, one output); and a single loop.
35566ae1d60SPeiming Liu class MergerTest3T1LD : public MergerTestBase {
35666ae1d60SPeiming Liu protected:
MergerTest3T1LD()35766ae1d60SPeiming Liu MergerTest3T1LD() : MergerTestBase(3, 1) {
358164b918dSwren romano EXPECT_TRUE(merger.getOutTensorID() == tid(2));
35966ae1d60SPeiming Liu // Tensor 0: sparse input vector.
360aaf91645SPeiming Liu merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
36166ae1d60SPeiming Liu // Tensor 1: dense input vector.
362d82e93e7SPeiming Liu merger.setLevelAndType(tid(1), lid(0), 0, GetParam());
36366ae1d60SPeiming Liu // Tensor 2: dense output vector.
364d82e93e7SPeiming Liu merger.setLevelAndType(tid(2), lid(0), 0, GetParam());
36566ae1d60SPeiming Liu }
36666ae1d60SPeiming Liu };
36766ae1d60SPeiming Liu
368d82e93e7SPeiming Liu INSTANTIATE_TEST_SUITE_P(Test3T1LD, MergerTest3T1LD,
369d82e93e7SPeiming Liu ::testing::Values(LevelFormat::Dense,
370d82e93e7SPeiming Liu LevelFormat::Batch));
371d82e93e7SPeiming Liu
37201dffc5aSPeiming Liu ///
37301dffc5aSPeiming Liu /// Tests with both undef and dense input.
37401dffc5aSPeiming Liu ///
375d30dccd2SPeiming Liu
376164b918dSwren romano /// Three tensors (three inputs, one output); and a single loop.
377d30dccd2SPeiming Liu class MergerTest4T1LU : public MergerTestBase {
37801dffc5aSPeiming Liu protected:
MergerTest4T1LU()379d30dccd2SPeiming Liu MergerTest4T1LU() : MergerTestBase(4, 1) {
380164b918dSwren romano EXPECT_TRUE(merger.getOutTensorID() == tid(3));
38101dffc5aSPeiming Liu // Tensor 0: undef input vector.
382aaf91645SPeiming Liu merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Undef);
38301dffc5aSPeiming Liu // Tensor 1: dense input vector.
384d82e93e7SPeiming Liu merger.setLevelAndType(tid(1), lid(0), 0, GetParam());
385d30dccd2SPeiming Liu // Tensor 2: undef input vector.
386aaf91645SPeiming Liu merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Undef);
387d30dccd2SPeiming Liu // Tensor 3: dense output vector.
388d82e93e7SPeiming Liu merger.setLevelAndType(tid(3), lid(0), 0, GetParam());
38901dffc5aSPeiming Liu }
39001dffc5aSPeiming Liu };
391d30dccd2SPeiming Liu
392d82e93e7SPeiming Liu INSTANTIATE_TEST_SUITE_P(Test4T1LU, MergerTest4T1LU,
393d82e93e7SPeiming Liu ::testing::Values(LevelFormat::Dense,
394d82e93e7SPeiming Liu LevelFormat::Batch));
395d82e93e7SPeiming Liu
396d30dccd2SPeiming Liu ///
397d30dccd2SPeiming Liu /// Tests with operation on sparse output.
398d30dccd2SPeiming Liu ///
399d30dccd2SPeiming Liu
400164b918dSwren romano /// Three tensors (two inputs, one output, one synthetic); and a single loop.
401f7f917a7SMehdi Amini class MergerTest3T1LSo : public MergerTestBase {
402d30dccd2SPeiming Liu protected:
MergerTest3T1LSo()403f7f917a7SMehdi Amini MergerTest3T1LSo() : MergerTestBase(3, 1) {
404164b918dSwren romano EXPECT_TRUE(merger.getOutTensorID() == tid(2));
405164b918dSwren romano EXPECT_TRUE(merger.getSynTensorID() == tid(3));
406d30dccd2SPeiming Liu merger.setHasSparseOut(true);
407d30dccd2SPeiming Liu // Tensor 0: undef input vector.
408aaf91645SPeiming Liu merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Undef);
409d30dccd2SPeiming Liu // Tensor 1: undef input vector.
410aaf91645SPeiming Liu merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Undef);
411d30dccd2SPeiming Liu // Tensor 2: sparse output vector.
412aaf91645SPeiming Liu merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Compressed);
413d30dccd2SPeiming Liu }
414d30dccd2SPeiming Liu };
415d30dccd2SPeiming Liu
416d82e93e7SPeiming Liu // This testsuite does not use any dense-like format, just one of {Dense, Batch}
417d82e93e7SPeiming Liu // is enough.
418d82e93e7SPeiming Liu INSTANTIATE_TEST_SUITE_P(Test3T1LSo, MergerTest3T1LSo,
419d82e93e7SPeiming Liu ::testing::Values(LevelFormat::Dense));
420d82e93e7SPeiming Liu
421be0a7e9fSMehdi Amini } // namespace
42240843347SGus Smith
423d30dccd2SPeiming Liu /// Vector multiplication (conjunction) of 3 vectors, i.e.;
424d30dccd2SPeiming Liu /// a(i) = b(i) * c(i) * d(i)
42501dffc5aSPeiming Liu /// which should form the single lattice point
42601dffc5aSPeiming Liu /// {
427d30dccd2SPeiming Liu /// lat( i_00_U i_01_D i_02_U / (tensor_0 * tensor_1 * tensor2) )
42801dffc5aSPeiming Liu /// }
42901dffc5aSPeiming Liu /// after optimization, the dense dimesion should be kept, despite it appears
430d30dccd2SPeiming Liu /// in the middle
43101dffc5aSPeiming Liu /// {
432d30dccd2SPeiming Liu /// lat( i_01_D / (tensor_0 * tensor_1 * tensor2) )
43301dffc5aSPeiming Liu /// }
434d30dccd2SPeiming Liu #define IMPL_MERGER_TEST_CONJ_CONJ_UNDEF(CONJ1, CONJ2) \
435d82e93e7SPeiming Liu TEST_P(MergerTest4T1LU, vector_##CONJ1##_##CONJ2) { \
436164b918dSwren romano const auto em = CONJ1##Expr(tensor(0), tensor(1)); \
437164b918dSwren romano const auto e = CONJ2##Expr(em, tensor(2)); \
438164b918dSwren romano const auto l0 = lid(0); \
439164b918dSwren romano const auto t0 = tid(0); \
440164b918dSwren romano const auto t1 = tid(1); \
441164b918dSwren romano const auto t2 = tid(2); \
442fd9b3e47SAart Bik const Match &p0 = tensorMatch(t0); \
443fd9b3e47SAart Bik const Match &p1 = tensorMatch(t1); \
444fd9b3e47SAart Bik const Match &p2 = tensorMatch(t2); \
44501dffc5aSPeiming Liu auto s = merger.buildLattices(e, l0); \
44601dffc5aSPeiming Liu expectNumLatPoints(s, 1); \
447fd9b3e47SAart Bik expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2), \
448d30dccd2SPeiming Liu loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
44901dffc5aSPeiming Liu s = merger.optimizeSet(s); \
45001dffc5aSPeiming Liu expectNumLatPoints(s, 1); \
451fd9b3e47SAart Bik expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2), \
452d30dccd2SPeiming Liu loopsToBits({{l0, t1}}), true); \
45301dffc5aSPeiming Liu }
454d30dccd2SPeiming Liu FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_UNDEF)
455d30dccd2SPeiming Liu #undef IMPL_MERGER_TEST_CONJ_CONJ_UNDEF
456d30dccd2SPeiming Liu
/// Vector multiplication (conjunction) of 2 input vectors with the sparse
/// output, i.e.;
///   o(i) = b(i) * c(i) * o(i)
/// which should form the single lattice point (note how a synthetic tensor
/// i_03_U is created for the sparse output)
/// {
///   lat( i_00_U i_01_U i_03_U / (tensor_0 * tensor_1 * output_tensor_2) )
/// }
/// After optimization, only the synthetic tensor should be preserved:
/// {
///   lat( i_03_U / (tensor_0 * tensor_1 * output_tensor_2) )
/// }
#define IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT(CONJ1, CONJ2) \
  TEST_P(MergerTest3T1LSo, vector_##CONJ1##_##CONJ2) { \
    const auto inner = CONJ1##Expr(tensor(0), tensor(1)); \
    const auto outer = CONJ2##Expr(inner, tensor(2)); \
    const auto loop = lid(0); \
    const auto in0 = tid(0); \
    const auto in1 = tid(1); \
    const auto out = tid(2); \
    const auto syn = tid(3); \
    const Match &m0 = tensorMatch(in0); \
    const Match &m1 = tensorMatch(in1); \
    const Match &mOut = tensorMatch(out); \
    auto lattice = merger.buildLattices(outer, loop); \
    expectNumLatPoints(lattice, 1); \
    expectLatPoint(lattice, 0, CONJ2##Match(CONJ1##Match(m0, m1), mOut), \
                   loopsToBits({{loop, in0}, {loop, in1}, {loop, syn}})); \
    lattice = merger.optimizeSet(lattice); \
    expectNumLatPoints(lattice, 1); \
    expectLatPoint(lattice, 0, CONJ2##Match(CONJ1##Match(m0, m1), mOut), \
                   loopsToBits({{loop, syn}}), /*simple=*/true); \
  }
FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT)
#undef IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT
49101dffc5aSPeiming Liu
49266ae1d60SPeiming Liu /// Vector addition (disjunction) of 2 vectors. i.e.;
49340843347SGus Smith /// a(i) = b(i) + c(i)
49440843347SGus Smith /// which should form the 3 lattice points
49540843347SGus Smith /// {
49640843347SGus Smith /// lat( i_00 i_01 / (tensor_0 + tensor_1) )
49740843347SGus Smith /// lat( i_00 / tensor_0 )
49840843347SGus Smith /// lat( i_01 / tensor_1 )
49940843347SGus Smith /// }
50066ae1d60SPeiming Liu /// and after optimization, the lattice points do not change (as there is no
50166ae1d60SPeiming Liu /// duplicated point and all input vectors are sparse vector).
50240843347SGus Smith /// {
50340843347SGus Smith /// lat( i_00 i_01 / (tensor_0 + tensor_1) )
50440843347SGus Smith /// lat( i_00 / tensor_0 )
50566ae1d60SPeiming Liu /// lat( i_01 / tensor_1 )
50640843347SGus Smith /// }
5071e15d791SAart Bik #define IMPL_MERGER_TEST_DISJ(OP, UNUSED) \
508d82e93e7SPeiming Liu TEST_P(MergerTest3T1L, vector_##OP) { \
509164b918dSwren romano const auto e = OP##Expr(tensor(0), tensor(1)); \
510164b918dSwren romano const auto l0 = lid(0); \
511164b918dSwren romano const auto t0 = tid(0); \
512164b918dSwren romano const auto t1 = tid(1); \
513fd9b3e47SAart Bik const Match &p0 = tensorMatch(t0); \
514fd9b3e47SAart Bik const Match &p1 = tensorMatch(t1); \
51566ae1d60SPeiming Liu auto s = merger.buildLattices(e, l0); \
51666ae1d60SPeiming Liu \
51766ae1d60SPeiming Liu expectNumLatPoints(s, 3); \
518fd9b3e47SAart Bik expectLatPoint(s, 0, OP##Match(p0, p1), \
51966ae1d60SPeiming Liu loopsToBits({{l0, t0}, {l0, t1}})); \
520164b918dSwren romano expectLatPointWithinRange(s, 1, 2, p0, loopsToBits({{l0, t0}})); \
521164b918dSwren romano expectLatPointWithinRange(s, 1, 2, p1, loopsToBits({{l0, t1}})); \
52266ae1d60SPeiming Liu \
52366ae1d60SPeiming Liu s = merger.optimizeSet(s); \
52466ae1d60SPeiming Liu expectNumLatPoints(s, 3); \
525fd9b3e47SAart Bik expectLatPoint(s, 0, OP##Match(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}), \
526fd9b3e47SAart Bik true); \
527164b918dSwren romano expectLatPointWithinRange(s, 1, 2, p0, loopsToBits({{l0, t0}}), true); \
528164b918dSwren romano expectLatPointWithinRange(s, 1, 2, p1, loopsToBits({{l0, t1}}), true); \
52940843347SGus Smith }
53066ae1d60SPeiming Liu FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_DISJ)
53166ae1d60SPeiming Liu #undef IMPL_MERGER_TEST_DISJ
53266ae1d60SPeiming Liu
53366ae1d60SPeiming Liu /// Vector multiplication (conjunction) of 2 vectors, i.e.;
53440843347SGus Smith /// a(i) = b(i) * c(i)
53540843347SGus Smith /// which should form the single lattice point
53640843347SGus Smith /// {
53740843347SGus Smith /// lat( i_00 i_01 / (tensor_0 * tensor_1) )
53840843347SGus Smith /// }
5391e15d791SAart Bik #define IMPL_MERGER_TEST_CONJ(OP, UNUSED) \
540d82e93e7SPeiming Liu TEST_P(MergerTest3T1L, vector_##OP) { \
541164b918dSwren romano const auto e = OP##Expr(tensor(0), tensor(1)); \
542164b918dSwren romano const auto l0 = lid(0); \
543164b918dSwren romano const auto t0 = tid(0); \
544164b918dSwren romano const auto t1 = tid(1); \
545fd9b3e47SAart Bik const Match &p0 = tensorMatch(t0); \
546fd9b3e47SAart Bik const Match &p1 = tensorMatch(t1); \
54766ae1d60SPeiming Liu auto s = merger.buildLattices(e, l0); \
54866ae1d60SPeiming Liu \
54966ae1d60SPeiming Liu expectNumLatPoints(s, 1); \
550fd9b3e47SAart Bik expectLatPoint(s, 0, OP##Match(p0, p1), \
55166ae1d60SPeiming Liu loopsToBits({{l0, t0}, {l0, t1}})); \
55266ae1d60SPeiming Liu \
55366ae1d60SPeiming Liu s = merger.optimizeSet(s); \
55466ae1d60SPeiming Liu expectNumLatPoints(s, 1); \
555fd9b3e47SAart Bik expectLatPoint(s, 0, OP##Match(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}), \
556fd9b3e47SAart Bik true); \
55740843347SGus Smith }
55866ae1d60SPeiming Liu FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ)
55966ae1d60SPeiming Liu #undef IMPL_MERGER_TEST_CONJ
56066ae1d60SPeiming Liu
56166ae1d60SPeiming Liu /// Vector multiplication (conjunction) then addition (disjunction), i.e.;
56266ae1d60SPeiming Liu /// a(i) = b(i) * c(i) + d(i);
56366ae1d60SPeiming Liu /// which should form
56466ae1d60SPeiming Liu /// {
56566ae1d60SPeiming Liu /// lat( i_00 i_01 i_02 / (tensor_0 * tensor_1) + tensor_2 )
56666ae1d60SPeiming Liu /// lat( i_00 i_01 / tensor_0 * tensor_1
56766ae1d60SPeiming Liu /// lat( i_02 / tensor_2 )
56866ae1d60SPeiming Liu /// }
56966ae1d60SPeiming Liu #define IMPL_MERGER_TEST_CONJ_DISJ(CONJ, DISJ) \
570d82e93e7SPeiming Liu TEST_P(MergerTest4T1L, vector_##CONJ##_##DISJ) { \
571164b918dSwren romano const auto em = CONJ##Expr(tensor(0), tensor(1)); \
572164b918dSwren romano const auto e = DISJ##Expr(em, tensor(2)); \
573164b918dSwren romano const auto l0 = lid(0); \
574164b918dSwren romano const auto t0 = tid(0); \
575164b918dSwren romano const auto t1 = tid(1); \
576164b918dSwren romano const auto t2 = tid(2); \
577fd9b3e47SAart Bik const Match &p0 = tensorMatch(t0); \
578fd9b3e47SAart Bik const Match &p1 = tensorMatch(t1); \
579fd9b3e47SAart Bik const Match &p2 = tensorMatch(t2); \
58066ae1d60SPeiming Liu auto s = merger.buildLattices(e, l0); \
58166ae1d60SPeiming Liu \
58266ae1d60SPeiming Liu expectNumLatPoints(s, 3); \
583fd9b3e47SAart Bik expectLatPoint(s, 0, DISJ##Match(CONJ##Match(p0, p1), p2), \
58466ae1d60SPeiming Liu loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
585fd9b3e47SAart Bik expectLatPointWithinRange(s, 1, 2, CONJ##Match(p0, p1), \
58666ae1d60SPeiming Liu loopsToBits({{l0, t0}, {l0, t1}})); \
587164b918dSwren romano expectLatPointWithinRange(s, 1, 2, p2, loopsToBits({{l0, t2}})); \
58866ae1d60SPeiming Liu \
58966ae1d60SPeiming Liu s = merger.optimizeSet(s); \
59066ae1d60SPeiming Liu expectNumLatPoints(s, 3); \
591fd9b3e47SAart Bik expectLatPoint(s, 0, DISJ##Match(CONJ##Match(p0, p1), p2), \
59266ae1d60SPeiming Liu loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
593fd9b3e47SAart Bik expectLatPointWithinRange(s, 1, 2, CONJ##Match(p0, p1), \
59466ae1d60SPeiming Liu loopsToBits({{l0, t0}, {l0, t1}})); \
595164b918dSwren romano expectLatPointWithinRange(s, 1, 2, p2, loopsToBits({{l0, t2}})); \
59666ae1d60SPeiming Liu }
59766ae1d60SPeiming Liu FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ)
59866ae1d60SPeiming Liu #undef IMPL_MERGER_TEST_CONJ_DISJ
59966ae1d60SPeiming Liu
60066ae1d60SPeiming Liu /// Vector addition (disjunction) then addition (disjunction), i.e.;
60166ae1d60SPeiming Liu /// a(i) = b(i) + c(i) + d(i)
60266ae1d60SPeiming Liu /// which should form
60366ae1d60SPeiming Liu /// {
60466ae1d60SPeiming Liu /// lat( i_00 i_01 i_02 / (tensor_0 + tensor_1) + tensor_2 )
60566ae1d60SPeiming Liu /// lat( i_02 i_01 / tensor_2 + tensor_1 )
60666ae1d60SPeiming Liu /// lat( i_02 i_00 / tensor_2 + tensor_0 )
60766ae1d60SPeiming Liu /// lat( i_01 i_00 / tensor_1 + tensor_0 )
60866ae1d60SPeiming Liu /// lat( i_02 / tensor_2 )
60966ae1d60SPeiming Liu /// lat( i_01 / tensor_1 )
61066ae1d60SPeiming Liu /// lat( i_00 / tensor_0 )
61166ae1d60SPeiming Liu /// }
61266ae1d60SPeiming Liu #define IMPL_MERGER_TEST_DISJ_DISJ(DISJ1, DISJ2) \
613d82e93e7SPeiming Liu TEST_P(MergerTest4T1L, Vector_##DISJ1##_##DISJ2) { \
614164b918dSwren romano const auto em = DISJ1##Expr(tensor(0), tensor(1)); \
615164b918dSwren romano const auto e = DISJ2##Expr(em, tensor(2)); \
616164b918dSwren romano const auto l0 = lid(0); \
617164b918dSwren romano const auto t0 = tid(0); \
618164b918dSwren romano const auto t1 = tid(1); \
619164b918dSwren romano const auto t2 = tid(2); \
620fd9b3e47SAart Bik const Match &p0 = tensorMatch(t0); \
621fd9b3e47SAart Bik const Match &p1 = tensorMatch(t1); \
622fd9b3e47SAart Bik const Match &p2 = tensorMatch(t2); \
62366ae1d60SPeiming Liu auto s = merger.buildLattices(e, l0); \
62466ae1d60SPeiming Liu \
62566ae1d60SPeiming Liu expectNumLatPoints(s, 7); \
626fd9b3e47SAart Bik expectLatPoint(s, 0, DISJ2##Match(DISJ1##Match(p0, p1), p2), \
62766ae1d60SPeiming Liu loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
628fd9b3e47SAart Bik expectLatPointWithinRange(s, 1, 6, DISJ2##Match(p1, p2), \
62966ae1d60SPeiming Liu loopsToBits({{l0, t1}, {l0, t2}})); \
630fd9b3e47SAart Bik expectLatPointWithinRange(s, 1, 6, DISJ2##Match(p0, p2), \
63166ae1d60SPeiming Liu loopsToBits({{l0, t0}, {l0, t2}})); \
632fd9b3e47SAart Bik expectLatPointWithinRange(s, 1, 6, DISJ1##Match(p0, p1), \
63366ae1d60SPeiming Liu loopsToBits({{l0, t0}, {l0, t1}})); \
634164b918dSwren romano expectLatPointWithinRange(s, 1, 6, p2, loopsToBits({{l0, t2}})); \
635164b918dSwren romano expectLatPointWithinRange(s, 1, 6, p1, loopsToBits({{l0, t1}})); \
636164b918dSwren romano expectLatPointWithinRange(s, 1, 6, p0, loopsToBits({{l0, t0}})); \
63766ae1d60SPeiming Liu \
63866ae1d60SPeiming Liu s = merger.optimizeSet(s); \
63966ae1d60SPeiming Liu expectNumLatPoints(s, 7); \
640fd9b3e47SAart Bik expectLatPoint(s, 0, DISJ2##Match(DISJ1##Match(p0, p1), p2), \
64166ae1d60SPeiming Liu loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
642fd9b3e47SAart Bik expectLatPointWithinRange(s, 1, 6, DISJ2##Match(p1, p2), \
64366ae1d60SPeiming Liu loopsToBits({{l0, t1}, {l0, t2}})); \
644fd9b3e47SAart Bik expectLatPointWithinRange(s, 1, 6, DISJ2##Match(p0, p2), \
64566ae1d60SPeiming Liu loopsToBits({{l0, t0}, {l0, t2}})); \
646fd9b3e47SAart Bik expectLatPointWithinRange(s, 1, 6, DISJ1##Match(p0, p1), \
64766ae1d60SPeiming Liu loopsToBits({{l0, t0}, {l0, t1}})); \
648164b918dSwren romano expectLatPointWithinRange(s, 1, 6, p2, loopsToBits({{l0, t2}})); \
649164b918dSwren romano expectLatPointWithinRange(s, 1, 6, p1, loopsToBits({{l0, t1}})); \
650164b918dSwren romano expectLatPointWithinRange(s, 1, 6, p0, loopsToBits({{l0, t0}})); \
65166ae1d60SPeiming Liu }
65266ae1d60SPeiming Liu FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(IMPL_MERGER_TEST_DISJ_DISJ)
65366ae1d60SPeiming Liu #undef IMPL_MERGER_TEST_DISJ_DISJ
65466ae1d60SPeiming Liu
65566ae1d60SPeiming Liu /// Vector multiplication (conjunction) then multiplication (conjunction), i.e.;
65666ae1d60SPeiming Liu /// a(i) = b(i) * c(i) * d(i);
65766ae1d60SPeiming Liu /// which should form
65866ae1d60SPeiming Liu /// {
65966ae1d60SPeiming Liu /// lat( i_00 i_01 i_02 / tensor_0 * tensor_1 * tensor_2 )
66066ae1d60SPeiming Liu /// }
66166ae1d60SPeiming Liu #define IMPL_MERGER_TEST_CONJ_CONJ(CONJ1, CONJ2) \
662d82e93e7SPeiming Liu TEST_P(MergerTest4T1L, vector_##CONJ1##_##CONJ2) { \
663164b918dSwren romano const auto em = CONJ1##Expr(tensor(0), tensor(1)); \
664164b918dSwren romano const auto e = CONJ2##Expr(em, tensor(2)); \
665164b918dSwren romano const auto l0 = lid(0); \
666164b918dSwren romano const auto t0 = tid(0); \
667164b918dSwren romano const auto t1 = tid(1); \
668164b918dSwren romano const auto t2 = tid(2); \
669fd9b3e47SAart Bik const Match &p0 = tensorMatch(t0); \
670fd9b3e47SAart Bik const Match &p1 = tensorMatch(t1); \
671fd9b3e47SAart Bik const Match &p2 = tensorMatch(t2); \
67266ae1d60SPeiming Liu auto s = merger.buildLattices(e, l0); \
67366ae1d60SPeiming Liu expectNumLatPoints(s, 1); \
674fd9b3e47SAart Bik expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2), \
67566ae1d60SPeiming Liu loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
67666ae1d60SPeiming Liu s = merger.optimizeSet(s); \
67766ae1d60SPeiming Liu expectNumLatPoints(s, 1); \
678fd9b3e47SAart Bik expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2), \
67966ae1d60SPeiming Liu loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}), true); \
68066ae1d60SPeiming Liu }
68166ae1d60SPeiming Liu FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ)
68266ae1d60SPeiming Liu #undef IMPL_MERGER_TEST_CONJ_CONJ
68366ae1d60SPeiming Liu
68466ae1d60SPeiming Liu /// Vector addition (disjunction) of 2 vectors, i.e.;
68566ae1d60SPeiming Liu /// a(i) = b(i) + c(i)
68666ae1d60SPeiming Liu /// which should form the 3 lattice points
68766ae1d60SPeiming Liu /// {
68866ae1d60SPeiming Liu /// lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) )
68966ae1d60SPeiming Liu /// lat( i_00 / sparse_tensor_0 )
69066ae1d60SPeiming Liu /// lat( i_01 / dense_tensor_1 )
69166ae1d60SPeiming Liu /// }
69266ae1d60SPeiming Liu /// which should be optimized to
69366ae1d60SPeiming Liu /// {
69466ae1d60SPeiming Liu /// lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ) (not singleton)
69566ae1d60SPeiming Liu /// lat( i_01 / dense_tensor_0 ) (no sparse dimension)
69666ae1d60SPeiming Liu /// }
69766ae1d60SPeiming Liu ///
69866ae1d60SPeiming Liu /// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff
69966ae1d60SPeiming Liu /// with lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ).
7001e15d791SAart Bik #define IMPL_MERGER_TEST_OPTIMIZED_DISJ(OP, UNUSED) \
701d82e93e7SPeiming Liu TEST_P(MergerTest3T1LD, vector_opted_##OP) { \
702164b918dSwren romano const auto e = OP##Expr(tensor(0), tensor(1)); \
703164b918dSwren romano const auto l0 = lid(0); \
704164b918dSwren romano const auto t0 = tid(0); \
705164b918dSwren romano const auto t1 = tid(1); \
706fd9b3e47SAart Bik const Match &p0 = tensorMatch(t0); \
707fd9b3e47SAart Bik const Match &p1 = tensorMatch(t1); \
70866ae1d60SPeiming Liu auto s = merger.buildLattices(e, l0); \
70966ae1d60SPeiming Liu \
71066ae1d60SPeiming Liu expectNumLatPoints(s, 3); \
711fd9b3e47SAart Bik expectLatPoint(s, 0, OP##Match(p0, p1), \
71266ae1d60SPeiming Liu loopsToBits({{l0, t0}, {l0, t1}})); \
713164b918dSwren romano expectLatPointWithinRange(s, 1, 2, p0, loopsToBits({{l0, t0}})); \
714164b918dSwren romano expectLatPointWithinRange(s, 1, 2, p1, loopsToBits({{l0, t1}})); \
71566ae1d60SPeiming Liu \
71666ae1d60SPeiming Liu s = merger.optimizeSet(s); \
71766ae1d60SPeiming Liu expectNumLatPoints(s, 2); \
718fd9b3e47SAart Bik expectLatPoint(s, 0, OP##Match(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}), \
719fd9b3e47SAart Bik true); \
720164b918dSwren romano expectLatPoint(s, 1, p1, loopsToBits({{l0, t1}}), true); \
72166ae1d60SPeiming Liu }
72266ae1d60SPeiming Liu FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_DISJ)
72366ae1d60SPeiming Liu #undef IMPL_MERGER_TEST_OPTIMIZED_CONJ
72466ae1d60SPeiming Liu
72566ae1d60SPeiming Liu /// Vector multiplication (conjunction) of 2 vectors, i.e.:
72666ae1d60SPeiming Liu /// a(i) = b(i) * c(i)
72766ae1d60SPeiming Liu /// which should form the single lattice point
72866ae1d60SPeiming Liu /// {
72966ae1d60SPeiming Liu /// lat( i_00 i_01 / (sparse_tensor_0 * dense_tensor_1) )
73066ae1d60SPeiming Liu /// }
73166ae1d60SPeiming Liu /// it should be optimized to
73266ae1d60SPeiming Liu /// {
73366ae1d60SPeiming Liu /// lat( i_00 / (sparse_tensor_0 * dense_tensor_1) )
73466ae1d60SPeiming Liu /// }
73566ae1d60SPeiming Liu /// since i_01 is a dense dimension.
7361e15d791SAart Bik #define IMPL_MERGER_TEST_OPTIMIZED_CONJ(OP, UNUSED) \
737d82e93e7SPeiming Liu TEST_P(MergerTest3T1LD, vector_opted_##OP) { \
738164b918dSwren romano const auto e = OP##Expr(tensor(0), tensor(1)); \
739164b918dSwren romano const auto l0 = lid(0); \
740164b918dSwren romano const auto t0 = tid(0); \
741164b918dSwren romano const auto t1 = tid(1); \
742fd9b3e47SAart Bik const Match &p0 = tensorMatch(t0); \
743fd9b3e47SAart Bik const Match &p1 = tensorMatch(t1); \
74466ae1d60SPeiming Liu auto s = merger.buildLattices(e, l0); \
74566ae1d60SPeiming Liu \
74666ae1d60SPeiming Liu expectNumLatPoints(s, 1); \
747fd9b3e47SAart Bik expectLatPoint(s, 0, OP##Match(p0, p1), \
74866ae1d60SPeiming Liu loopsToBits({{l0, t0}, {l0, t1}})); \
74966ae1d60SPeiming Liu \
75066ae1d60SPeiming Liu s = merger.optimizeSet(s); \
75166ae1d60SPeiming Liu expectNumLatPoints(s, 1); \
752fd9b3e47SAart Bik expectLatPoint(s, 0, OP##Match(p0, p1), loopsToBits({{l0, t0}}), true); \
75366ae1d60SPeiming Liu }
75466ae1d60SPeiming Liu FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_CONJ)
755fd9b3e47SAart Bik #undef IMPL_MERGER_TEST_OPTIMIZED_CONJ
75666ae1d60SPeiming Liu
757faf7cd97SPeiming Liu /// Vector element-wise comparison (disjunction) of 2 vectors. i.e.;
758faf7cd97SPeiming Liu /// a(i) = b(i) + c(i)
759faf7cd97SPeiming Liu /// which should form the 3 lattice points
760faf7cd97SPeiming Liu /// {
761faf7cd97SPeiming Liu /// lat( i_00 i_01 / (tensor_0 cmp tensor_1) )
762faf7cd97SPeiming Liu /// lat( i_00 / tensor_0 cmp 0 )
763faf7cd97SPeiming Liu /// lat( i_01 / 0 cmp tensor_1 )
764faf7cd97SPeiming Liu /// }
765faf7cd97SPeiming Liu /// and after optimization, the lattice points do not change (as there is no
766faf7cd97SPeiming Liu /// duplicated point and all input vectors are sparse vector).
767faf7cd97SPeiming Liu /// {
768faf7cd97SPeiming Liu /// lat( i_00 i_01 / (tensor_0 cmp tensor_1) )
769faf7cd97SPeiming Liu /// lat( i_00 / tensor_0 cmp 0 )
770faf7cd97SPeiming Liu /// lat( i_01 / 0 cmp tensor_1 )
771faf7cd97SPeiming Liu /// }
772d82e93e7SPeiming Liu TEST_P(MergerTest3T1L, vector_cmp) {
773faf7cd97SPeiming Liu const auto e = cmpiExpr(tensor(0), tensor(1));
774faf7cd97SPeiming Liu const auto l0 = lid(0);
775faf7cd97SPeiming Liu const auto t0 = tid(0);
776faf7cd97SPeiming Liu const auto t1 = tid(1);
777fd9b3e47SAart Bik const Match &zero = synZeroMatch();
778fd9b3e47SAart Bik const Match &p0 = tensorMatch(t0);
779fd9b3e47SAart Bik const Match &p1 = tensorMatch(t1);
780faf7cd97SPeiming Liu auto s = merger.buildLattices(e, l0);
781fd9b3e47SAart Bik expectLatPoint(s, 0, cmpiMatch(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
782fd9b3e47SAart Bik expectLatPointWithinRange(s, 1, 2, cmpiMatch(p0, zero),
783faf7cd97SPeiming Liu loopsToBits({{l0, t0}}));
784fd9b3e47SAart Bik expectLatPointWithinRange(s, 1, 2, cmpiMatch(zero, p1),
785faf7cd97SPeiming Liu loopsToBits({{l0, t1}}));
786faf7cd97SPeiming Liu s = merger.optimizeSet(s);
787fd9b3e47SAart Bik expectLatPoint(s, 0, cmpiMatch(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
788fd9b3e47SAart Bik expectLatPointWithinRange(s, 1, 2, cmpiMatch(p0, zero),
789faf7cd97SPeiming Liu loopsToBits({{l0, t0}}));
790fd9b3e47SAart Bik expectLatPointWithinRange(s, 1, 2, cmpiMatch(zero, p1),
791faf7cd97SPeiming Liu loopsToBits({{l0, t1}}));
792faf7cd97SPeiming Liu }
793faf7cd97SPeiming Liu
794faf7cd97SPeiming Liu /// Vector element-wise comparsion (disjunction) of 2 vectors, i.e.;
795faf7cd97SPeiming Liu /// a(i) = b(i) cmp c(i)
796faf7cd97SPeiming Liu /// which should form the 3 lattice points
797faf7cd97SPeiming Liu /// {
798faf7cd97SPeiming Liu /// lat( i_00 i_01 / (sparse_tensor_0 cmp dense_tensor_1) )
799faf7cd97SPeiming Liu /// lat( i_00 / sparse_tensor_0 cmp 0)
800faf7cd97SPeiming Liu /// lat( i_01 / 0 cmp dense_tensor_1 )
801faf7cd97SPeiming Liu /// }
802faf7cd97SPeiming Liu /// which should be optimized to
803faf7cd97SPeiming Liu /// {
804faf7cd97SPeiming Liu /// lat( i_00 i_01 / (sparse_tensor_0 cmp dense_tensor_1) ) (not singleton)
805faf7cd97SPeiming Liu /// lat( i_01 / 0 cmp dense_tensor_0 ) ()
806faf7cd97SPeiming Liu /// }
807faf7cd97SPeiming Liu ///
808faf7cd97SPeiming Liu /// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff
809faf7cd97SPeiming Liu /// with lat( i_00 i_01 / (sparse_tensor_0 cmp dense_tensor_1) ).
TEST_P(MergerTest3T1LD,vector_cmp)810d82e93e7SPeiming Liu TEST_P(MergerTest3T1LD, vector_cmp) {
811faf7cd97SPeiming Liu const auto e = cmpiExpr(tensor(0), tensor(1));
812faf7cd97SPeiming Liu const auto l0 = lid(0);
813faf7cd97SPeiming Liu const auto t0 = tid(0);
814faf7cd97SPeiming Liu const auto t1 = tid(1);
815fd9b3e47SAart Bik const Match &zero = synZeroMatch();
816fd9b3e47SAart Bik const Match &p0 = tensorMatch(t0);
817fd9b3e47SAart Bik const Match &p1 = tensorMatch(t1);
818faf7cd97SPeiming Liu auto s = merger.buildLattices(e, l0);
819fd9b3e47SAart Bik expectLatPoint(s, 0, cmpiMatch(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
820fd9b3e47SAart Bik expectLatPointWithinRange(s, 1, 2, cmpiMatch(p0, zero),
821faf7cd97SPeiming Liu loopsToBits({{l0, t0}}));
822fd9b3e47SAart Bik expectLatPointWithinRange(s, 1, 2, cmpiMatch(zero, p1),
823faf7cd97SPeiming Liu loopsToBits({{l0, t1}}));
824faf7cd97SPeiming Liu s = merger.optimizeSet(s);
825fd9b3e47SAart Bik expectLatPoint(s, 0, cmpiMatch(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
826fd9b3e47SAart Bik expectLatPointWithinRange(s, 1, 2, cmpiMatch(zero, p1),
827faf7cd97SPeiming Liu loopsToBits({{l0, t1}}));
828faf7cd97SPeiming Liu }
829