1 //===- TensorSpecTest.cpp - test for TensorSpec ---------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Analysis/TensorSpec.h" 10 #include "llvm/Support/Path.h" 11 #include "llvm/Support/SourceMgr.h" 12 #include "llvm/Testing/Support/SupportHelpers.h" 13 #include "gtest/gtest.h" 14 15 using namespace llvm; 16 17 extern const char *TestMainArgv0; 18 19 TEST(TensorSpecTest, JSONParsing) { 20 auto Value = json::parse( 21 R"({"name": "tensor_name", 22 "port": 2, 23 "type": "int32_t", 24 "shape":[1,4] 25 })"); 26 EXPECT_TRUE(!!Value); 27 LLVMContext Ctx; 28 std::optional<TensorSpec> Spec = getTensorSpecFromJSON(Ctx, *Value); 29 EXPECT_TRUE(Spec); 30 EXPECT_EQ(*Spec, TensorSpec::createSpec<int32_t>("tensor_name", {1, 4}, 2)); 31 } 32 33 TEST(TensorSpecTest, JSONParsingInvalidTensorType) { 34 auto Value = json::parse( 35 R"( 36 {"name": "tensor_name", 37 "port": 2, 38 "type": "no such type", 39 "shape":[1,4] 40 } 41 )"); 42 EXPECT_TRUE(!!Value); 43 LLVMContext Ctx; 44 auto Spec = getTensorSpecFromJSON(Ctx, *Value); 45 EXPECT_FALSE(Spec); 46 } 47 48 TEST(TensorSpecTest, TensorSpecSizesAndTypes) { 49 auto Spec1D = TensorSpec::createSpec<int16_t>("Hi1", {1}); 50 auto Spec2D = TensorSpec::createSpec<int16_t>("Hi2", {1, 1}); 51 auto Spec1DLarge = TensorSpec::createSpec<float>("Hi3", {10}); 52 auto Spec3DLarge = TensorSpec::createSpec<float>("Hi3", {2, 4, 10}); 53 EXPECT_TRUE(Spec1D.isElementType<int16_t>()); 54 EXPECT_FALSE(Spec3DLarge.isElementType<double>()); 55 EXPECT_EQ(Spec1D.getElementCount(), 1U); 56 EXPECT_EQ(Spec2D.getElementCount(), 1U); 57 EXPECT_EQ(Spec1DLarge.getElementCount(), 10U); 58 EXPECT_EQ(Spec3DLarge.getElementCount(), 80U); 59 EXPECT_EQ(Spec3DLarge.getElementByteSize(), sizeof(float)); 60 EXPECT_EQ(Spec1D.getElementByteSize(), sizeof(int16_t)); 61 } 62