//===- TensorSpecTest.cpp - test for TensorSpec ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8b1fa5ac3SMircea Trofin
#include "llvm/Analysis/TensorSpec.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Testing/Support/Error.h"
#include "llvm/Testing/Support/SupportHelpers.h"
#include "gtest/gtest.h"
15b1fa5ac3SMircea Trofin
using namespace llvm;

// Declared extern, so it is defined in another translation unit (presumably
// the unittest driver's main). NOTE(review): no test in this file references
// it; confirm it is not needed before removing.
extern const char *TestMainArgv0;
19b1fa5ac3SMircea Trofin
// Parses a well-formed JSON tensor description and checks that the name,
// port, element type and shape all carry over into the resulting TensorSpec.
TEST(TensorSpecTest, JSONParsing) {
  auto Value = json::parse(
      R"({"name": "tensor_name",
        "port": 2,
        "type": "int32_t",
        "shape":[1,4]
        })");
  // ASSERT rather than EXPECT: *Value below is only valid if parsing
  // succeeded. ASSERT_THAT_EXPECTED also consumes and reports the error, so a
  // bad parse yields a readable failure instead of an unchecked-Error abort.
  ASSERT_THAT_EXPECTED(Value, Succeeded());
  LLVMContext Ctx;
  std::optional<TensorSpec> Spec = getTensorSpecFromJSON(Ctx, *Value);
  // ASSERT rather than EXPECT: dereferencing an empty optional is undefined
  // behavior, so the test must stop here if no spec was produced.
  ASSERT_TRUE(Spec.has_value());
  EXPECT_EQ(*Spec, TensorSpec::createSpec<int32_t>("tensor_name", {1, 4}, 2));
}
33b1fa5ac3SMircea Trofin
// A syntactically valid JSON description whose "type" is not a recognized
// element type must not produce a TensorSpec.
TEST(TensorSpecTest, JSONParsingInvalidTensorType) {
  auto Value = json::parse(
      R"(
        {"name": "tensor_name",
         "port": 2,
         "type": "no such type",
         "shape":[1,4]
        }
      )");
  // The JSON itself is well-formed. Stop here if parsing failed, because
  // *Value below would otherwise read an Expected in its error state.
  ASSERT_THAT_EXPECTED(Value, Succeeded());
  LLVMContext Ctx;
  auto Spec = getTensorSpecFromJSON(Ctx, *Value);
  EXPECT_FALSE(Spec.has_value());
}
48b1fa5ac3SMircea Trofin
// Checks the element-type query, the element count (product of the shape's
// dimensions) and the per-element byte size reported by TensorSpec.
TEST(TensorSpecTest, TensorSpecSizesAndTypes) {
  // Single-element int16_t tensors, once as rank 1 and once as rank 2.
  TensorSpec ScalarRank1 = TensorSpec::createSpec<int16_t>("Hi1", {1});
  TensorSpec ScalarRank2 = TensorSpec::createSpec<int16_t>("Hi2", {1, 1});
  // Multi-element float tensors: 10 elements flat, 2*4*10 elements in 3D.
  TensorSpec FloatVector = TensorSpec::createSpec<float>("Hi3", {10});
  TensorSpec FloatCube = TensorSpec::createSpec<float>("Hi3", {2, 4, 10});

  // Element type queries: positive match and mismatch.
  EXPECT_TRUE(ScalarRank1.isElementType<int16_t>());
  EXPECT_FALSE(FloatCube.isElementType<double>());

  // Element counts.
  EXPECT_EQ(ScalarRank1.getElementCount(), 1U);
  EXPECT_EQ(ScalarRank2.getElementCount(), 1U);
  EXPECT_EQ(FloatVector.getElementCount(), 10U);
  EXPECT_EQ(FloatCube.getElementCount(), 80U);

  // Per-element byte sizes follow the element type.
  EXPECT_EQ(FloatCube.getElementByteSize(), sizeof(float));
  EXPECT_EQ(ScalarRank1.getElementByteSize(), sizeof(int16_t));
}
635b8dc7c8SMircea Trofin
// tensorValueToString renders a raw tensor buffer as comma-separated values,
// interpreting the bytes according to the given spec.
TEST(TensorSpecTest, PrintValueForDebug) {
  std::vector<int32_t> Elements{1, 3};
  TensorSpec TwoInts = TensorSpec::createSpec<int32_t>("name", {2});
  // The API takes the tensor as an untyped byte buffer.
  const char *RawBytes = reinterpret_cast<const char *>(Elements.data());
  EXPECT_EQ(tensorValueToString(RawBytes, TwoInts), "1,3");
}
70