1 //===- TensorSpec.cpp - tensor type abstraction ---------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Implementation file for the abstraction of a tensor type, and JSON loading 10 // utils. 11 // 12 //===----------------------------------------------------------------------===// 13 #include "llvm/Config/config.h" 14 15 #include "llvm/ADT/Twine.h" 16 #include "llvm/Analysis/TensorSpec.h" 17 #include "llvm/Support/CommandLine.h" 18 #include "llvm/Support/Debug.h" 19 #include "llvm/Support/JSON.h" 20 #include "llvm/Support/ManagedStatic.h" 21 #include "llvm/Support/raw_ostream.h" 22 #include <array> 23 #include <cassert> 24 #include <numeric> 25 26 using namespace llvm; 27 28 namespace llvm { 29 30 #define TFUTILS_GETDATATYPE_IMPL(T, E) \ 31 template <> TensorType TensorSpec::getDataType<T>() { return TensorType::E; } 32 33 SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_IMPL) 34 35 #undef TFUTILS_GETDATATYPE_IMPL 36 37 static std::array<std::string, static_cast<size_t>(TensorType::Total)> 38 TensorTypeNames{"INVALID", 39 #define TFUTILS_GETNAME_IMPL(T, _) #T, 40 SUPPORTED_TENSOR_TYPES(TFUTILS_GETNAME_IMPL) 41 #undef TFUTILS_GETNAME_IMPL 42 }; 43 44 StringRef toString(TensorType TT) { 45 return TensorTypeNames[static_cast<size_t>(TT)]; 46 } 47 48 void TensorSpec::toJSON(json::OStream &OS) const { 49 OS.object([&]() { 50 OS.attribute("name", name()); 51 OS.attribute("type", toString(type())); 52 OS.attribute("port", port()); 53 OS.attributeArray("shape", [&]() { 54 for (size_t D : shape()) 55 OS.value(D); 56 }); 57 }); 58 } 59 60 TensorSpec::TensorSpec(const std::string &Name, int Port, TensorType Type, 61 size_t ElementSize, const std::vector<int64_t> &Shape) 62 : Name(Name), Port(Port), Type(Type), Shape(Shape), 63 ElementCount(std::accumulate(Shape.begin(), Shape.end(), 1, 64 std::multiplies<int64_t>())), 65 ElementSize(ElementSize) {} 66 67 Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx, 68 const json::Value &Value) { 69 auto EmitError = [&](const llvm::Twine &Message) -> Optional<TensorSpec> { 70 std::string S; 71 llvm::raw_string_ostream OS(S); 72 OS << Value; 73 Ctx.emitError("Unable to parse JSON Value as spec (" + Message + "): " + S); 74 return std::nullopt; 75 }; 76 // FIXME: accept a Path as a parameter, and use it for error reporting. 77 json::Path::Root Root("tensor_spec"); 78 json::ObjectMapper Mapper(Value, Root); 79 if (!Mapper) 80 return EmitError("Value is not a dict"); 81 82 std::string TensorName; 83 int TensorPort = -1; 84 std::string TensorType; 85 std::vector<int64_t> TensorShape; 86 87 if (!Mapper.map<std::string>("name", TensorName)) 88 return EmitError("'name' property not present or not a string"); 89 if (!Mapper.map<std::string>("type", TensorType)) 90 return EmitError("'type' property not present or not a string"); 91 if (!Mapper.map<int>("port", TensorPort)) 92 return EmitError("'port' property not present or not an int"); 93 if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape)) 94 return EmitError("'shape' property not present or not an int array"); 95 96 #define PARSE_TYPE(T, E) \ 97 if (TensorType == #T) \ 98 return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort); 99 SUPPORTED_TENSOR_TYPES(PARSE_TYPE) 100 #undef PARSE_TYPE 101 return std::nullopt; 102 } 103 104 } // namespace llvm 105