1b1fa5ac3SMircea Trofin //===- TensorSpec.cpp - tensor type abstraction ---------------------------===//
2b1fa5ac3SMircea Trofin //
3b1fa5ac3SMircea Trofin // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4b1fa5ac3SMircea Trofin // See https://llvm.org/LICENSE.txt for license information.
5b1fa5ac3SMircea Trofin // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6b1fa5ac3SMircea Trofin //
7b1fa5ac3SMircea Trofin //===----------------------------------------------------------------------===//
8b1fa5ac3SMircea Trofin //
9b1fa5ac3SMircea Trofin // Implementation file for the abstraction of a tensor type, and JSON loading
10b1fa5ac3SMircea Trofin // utils.
11b1fa5ac3SMircea Trofin //
12b1fa5ac3SMircea Trofin //===----------------------------------------------------------------------===//
135b8dc7c8SMircea Trofin #include "llvm/ADT/STLExtras.h"
14b1fa5ac3SMircea Trofin #include "llvm/Config/config.h"
15b1fa5ac3SMircea Trofin
165b8dc7c8SMircea Trofin #include "llvm/ADT/StringExtras.h"
17b1fa5ac3SMircea Trofin #include "llvm/ADT/Twine.h"
18b1fa5ac3SMircea Trofin #include "llvm/Analysis/TensorSpec.h"
19b1fa5ac3SMircea Trofin #include "llvm/Support/CommandLine.h"
20b1fa5ac3SMircea Trofin #include "llvm/Support/Debug.h"
21b1fa5ac3SMircea Trofin #include "llvm/Support/JSON.h"
22b1fa5ac3SMircea Trofin #include "llvm/Support/ManagedStatic.h"
23b1fa5ac3SMircea Trofin #include "llvm/Support/raw_ostream.h"
244c97745bSMircea Trofin #include <array>
25b1fa5ac3SMircea Trofin #include <cassert>
26b1fa5ac3SMircea Trofin #include <numeric>
27b1fa5ac3SMircea Trofin
28b1fa5ac3SMircea Trofin using namespace llvm;
29b1fa5ac3SMircea Trofin
30b1fa5ac3SMircea Trofin namespace llvm {
31b1fa5ac3SMircea Trofin
32b1fa5ac3SMircea Trofin #define TFUTILS_GETDATATYPE_IMPL(T, E) \
33b1fa5ac3SMircea Trofin template <> TensorType TensorSpec::getDataType<T>() { return TensorType::E; }
34b1fa5ac3SMircea Trofin
SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_IMPL)35b1fa5ac3SMircea Trofin SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_IMPL)
36b1fa5ac3SMircea Trofin
37b1fa5ac3SMircea Trofin #undef TFUTILS_GETDATATYPE_IMPL
38b1fa5ac3SMircea Trofin
394c97745bSMircea Trofin static std::array<std::string, static_cast<size_t>(TensorType::Total)>
404c97745bSMircea Trofin TensorTypeNames{"INVALID",
414c97745bSMircea Trofin #define TFUTILS_GETNAME_IMPL(T, _) #T,
424c97745bSMircea Trofin SUPPORTED_TENSOR_TYPES(TFUTILS_GETNAME_IMPL)
434c97745bSMircea Trofin #undef TFUTILS_GETNAME_IMPL
444c97745bSMircea Trofin };
454c97745bSMircea Trofin
toString(TensorType TT)464c97745bSMircea Trofin StringRef toString(TensorType TT) {
474c97745bSMircea Trofin return TensorTypeNames[static_cast<size_t>(TT)];
484c97745bSMircea Trofin }
494c97745bSMircea Trofin
toJSON(json::OStream & OS) const504c97745bSMircea Trofin void TensorSpec::toJSON(json::OStream &OS) const {
514c97745bSMircea Trofin OS.object([&]() {
524c97745bSMircea Trofin OS.attribute("name", name());
534c97745bSMircea Trofin OS.attribute("type", toString(type()));
544c97745bSMircea Trofin OS.attribute("port", port());
554c97745bSMircea Trofin OS.attributeArray("shape", [&]() {
564c97745bSMircea Trofin for (size_t D : shape())
574c97745bSMircea Trofin OS.value(static_cast<int64_t>(D));
584c97745bSMircea Trofin });
594c97745bSMircea Trofin });
604c97745bSMircea Trofin }
614c97745bSMircea Trofin
TensorSpec(const std::string & Name,int Port,TensorType Type,size_t ElementSize,const std::vector<int64_t> & Shape)62b1fa5ac3SMircea Trofin TensorSpec::TensorSpec(const std::string &Name, int Port, TensorType Type,
63b1fa5ac3SMircea Trofin size_t ElementSize, const std::vector<int64_t> &Shape)
64b1fa5ac3SMircea Trofin : Name(Name), Port(Port), Type(Type), Shape(Shape),
65b1fa5ac3SMircea Trofin ElementCount(std::accumulate(Shape.begin(), Shape.end(), 1,
66b1fa5ac3SMircea Trofin std::multiplies<int64_t>())),
67b1fa5ac3SMircea Trofin ElementSize(ElementSize) {}
68b1fa5ac3SMircea Trofin
getTensorSpecFromJSON(LLVMContext & Ctx,const json::Value & Value)69d4b6fcb3SFangrui Song std::optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
70b1fa5ac3SMircea Trofin const json::Value &Value) {
71d4b6fcb3SFangrui Song auto EmitError =
72d4b6fcb3SFangrui Song [&](const llvm::Twine &Message) -> std::optional<TensorSpec> {
73b1fa5ac3SMircea Trofin std::string S;
74b1fa5ac3SMircea Trofin llvm::raw_string_ostream OS(S);
75b1fa5ac3SMircea Trofin OS << Value;
76b1fa5ac3SMircea Trofin Ctx.emitError("Unable to parse JSON Value as spec (" + Message + "): " + S);
7719aff0f3SKazu Hirata return std::nullopt;
78b1fa5ac3SMircea Trofin };
79b1fa5ac3SMircea Trofin // FIXME: accept a Path as a parameter, and use it for error reporting.
80b1fa5ac3SMircea Trofin json::Path::Root Root("tensor_spec");
81b1fa5ac3SMircea Trofin json::ObjectMapper Mapper(Value, Root);
82b1fa5ac3SMircea Trofin if (!Mapper)
83b1fa5ac3SMircea Trofin return EmitError("Value is not a dict");
84b1fa5ac3SMircea Trofin
85b1fa5ac3SMircea Trofin std::string TensorName;
86b1fa5ac3SMircea Trofin int TensorPort = -1;
87b1fa5ac3SMircea Trofin std::string TensorType;
88b1fa5ac3SMircea Trofin std::vector<int64_t> TensorShape;
89b1fa5ac3SMircea Trofin
90b1fa5ac3SMircea Trofin if (!Mapper.map<std::string>("name", TensorName))
91b1fa5ac3SMircea Trofin return EmitError("'name' property not present or not a string");
92b1fa5ac3SMircea Trofin if (!Mapper.map<std::string>("type", TensorType))
93b1fa5ac3SMircea Trofin return EmitError("'type' property not present or not a string");
94b1fa5ac3SMircea Trofin if (!Mapper.map<int>("port", TensorPort))
95b1fa5ac3SMircea Trofin return EmitError("'port' property not present or not an int");
96b1fa5ac3SMircea Trofin if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape))
97b1fa5ac3SMircea Trofin return EmitError("'shape' property not present or not an int array");
98b1fa5ac3SMircea Trofin
99b1fa5ac3SMircea Trofin #define PARSE_TYPE(T, E) \
100b1fa5ac3SMircea Trofin if (TensorType == #T) \
101b1fa5ac3SMircea Trofin return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort);
102b1fa5ac3SMircea Trofin SUPPORTED_TENSOR_TYPES(PARSE_TYPE)
103b1fa5ac3SMircea Trofin #undef PARSE_TYPE
10419aff0f3SKazu Hirata return std::nullopt;
105b1fa5ac3SMircea Trofin }
106b1fa5ac3SMircea Trofin
tensorValueToString(const char * Buffer,const TensorSpec & Spec)1075b8dc7c8SMircea Trofin std::string tensorValueToString(const char *Buffer, const TensorSpec &Spec) {
1085b8dc7c8SMircea Trofin switch (Spec.type()) {
1095b8dc7c8SMircea Trofin #define _IMR_DBG_PRINTER(T, N) \
1105b8dc7c8SMircea Trofin case TensorType::N: { \
1115b8dc7c8SMircea Trofin const T *TypedBuff = reinterpret_cast<const T *>(Buffer); \
1125b8dc7c8SMircea Trofin auto R = llvm::make_range(TypedBuff, TypedBuff + Spec.getElementCount()); \
1135b8dc7c8SMircea Trofin return llvm::join( \
1145b8dc7c8SMircea Trofin llvm::map_range(R, [](T V) { return std::to_string(V); }), ","); \
1155b8dc7c8SMircea Trofin }
1165b8dc7c8SMircea Trofin SUPPORTED_TENSOR_TYPES(_IMR_DBG_PRINTER)
1175b8dc7c8SMircea Trofin #undef _IMR_DBG_PRINTER
1185b8dc7c8SMircea Trofin case TensorType::Total:
1195b8dc7c8SMircea Trofin case TensorType::Invalid:
1205b8dc7c8SMircea Trofin llvm_unreachable("invalid tensor type");
1215b8dc7c8SMircea Trofin }
122*7d31d3b0SMircea Trofin // To appease warnings about not all control paths returning a value.
123*7d31d3b0SMircea Trofin return "";
1245b8dc7c8SMircea Trofin }
1255b8dc7c8SMircea Trofin
126b1fa5ac3SMircea Trofin } // namespace llvm
127