//===- BroadcastShapeTest.cpp - broadcasting shape unit tests -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

97972dcefSLei Zhang #include "mlir/Dialect/Traits.h"
10eeadfbc1SLei Zhang #include "llvm/ADT/SmallVector.h"
117972dcefSLei Zhang #include "gmock/gmock.h"
127972dcefSLei Zhang
137972dcefSLei Zhang using namespace mlir::OpTrait::util;
147972dcefSLei Zhang
15eeadfbc1SLei Zhang using llvm::SmallVector;
167972dcefSLei Zhang using ::testing::ElementsAre;
177972dcefSLei Zhang
// Two empty (rank-0) shapes are expected to broadcast successfully to an empty
// shape.
TEST(BroadcastShapeTest, CompatibleScalarAndScalar) {
  SmallVector<int64_t, 4> result;
  // ASSERT (not EXPECT): the result is only meaningful if broadcasting worked.
  ASSERT_TRUE(getBroadcastedShape({}, {}, result));
  EXPECT_TRUE(result.empty());
}

// An empty (rank-0) shape broadcast against a 1-D shape is expected to yield
// the 1-D shape unchanged.
TEST(BroadcastShapeTest, Compatible0DAnd1DTensor) {
  SmallVector<int64_t, 4> result;
  ASSERT_TRUE(getBroadcastedShape({}, {4}, result));
  EXPECT_THAT(result, ElementsAre(4));
}

// An empty (rank-0) shape broadcast against a 3-D shape is expected to yield
// the 3-D shape unchanged.
TEST(BroadcastShapeTest, Compatible0DAnd3DTensor) {
  SmallVector<int64_t, 4> result;
  ASSERT_TRUE(getBroadcastedShape({}, {3, 5, 4}, result));
  EXPECT_THAT(result, ElementsAre(3, 5, 4));
}

// Shapes of different rank whose trailing dimensions match: the expected
// result is the higher-rank shape.
TEST(BroadcastShapeTest, CompatibleTensorAndTensor) {
  SmallVector<int64_t, 4> result;
  ASSERT_TRUE(getBroadcastedShape({1, 7, 8, 9}, {8, 9}, result));
  EXPECT_THAT(result, ElementsAre(1, 7, 8, 9));
}

// Size-1 dimensions alternate between the two operands. With the shapes
// aligned on their trailing dimension, each expected result dimension is the
// non-1 extent of the pair.
TEST(BroadcastShapeTest, InterleavingOnes) {
  SmallVector<int64_t, 4> result;
  ASSERT_TRUE(getBroadcastedShape({8, 1, 2, 1, 4}, {5, 1, 7, 1}, result));
  EXPECT_THAT(result, ElementsAre(8, 5, 2, 7, 4));
}

// Dynamic (unknown) dimensions mixed with static ones. Per the expectation
// below: dynamic paired with 1 or with dynamic stays dynamic, while dynamic
// paired with a static extent greater than 1 takes that static extent.
TEST(BroadcastShapeTest, InterleavingUnknowns) {
  SmallVector<int64_t, 4> result;
  const int64_t dyn = mlir::ShapedType::kDynamic;
  ASSERT_TRUE(getBroadcastedShape({1, 2, dyn, dyn, dyn}, {dyn, dyn, dyn, 4, 1},
                                  result));
  EXPECT_THAT(result, ElementsAre(dyn, 2, dyn, 4, dyn));
}

// The trailing dimensions (5 vs. 4) disagree and neither is 1, so broadcasting
// is expected to fail and leave the result empty.
TEST(BroadcastShapeTest, IncompatibleLowDim) {
  SmallVector<int64_t, 4> result;
  ASSERT_FALSE(getBroadcastedShape({4, 3, 5, 5}, {3, 5, 4}, result));
  EXPECT_TRUE(result.empty());
}

// With the shapes aligned on their trailing dimension, a middle pair (5 vs. 7)
// disagrees and neither is 1, so broadcasting is expected to fail and leave
// the result empty.
TEST(BroadcastShapeTest, IncompatibleMiddleDim) {
  SmallVector<int64_t, 4> result;
  ASSERT_FALSE(getBroadcastedShape({4, 3, 5, 5}, {3, 7, 5}, result));
  EXPECT_TRUE(result.empty());
}

// The leading dimensions (3 vs. 4) disagree and neither is 1, so broadcasting
// is expected to fail and leave the result empty.
TEST(BroadcastShapeTest, IncompatibleHighDim) {
  SmallVector<int64_t, 4> result;
  ASSERT_FALSE(getBroadcastedShape({3, 5, 5}, {4, 5, 5}, result));
  EXPECT_TRUE(result.empty());
}