xref: /llvm-project/llvm/lib/Analysis/TensorSpec.cpp (revision 7d31d3b09844897821db029f96682853160863d0)
//===- TensorSpec.cpp - tensor type abstraction ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implementation file for the abstraction of a tensor type (TensorSpec), and
// JSON loading utilities.
//
//===----------------------------------------------------------------------===//
#include "llvm/ADT/STLExtras.h"
#include "llvm/Config/config.h"

#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/TensorSpec.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/raw_ostream.h"
#include <array>
#include <cassert>
#include <iterator>
#include <numeric>

using namespace llvm;
29b1fa5ac3SMircea Trofin 
30b1fa5ac3SMircea Trofin namespace llvm {
31b1fa5ac3SMircea Trofin 
32b1fa5ac3SMircea Trofin #define TFUTILS_GETDATATYPE_IMPL(T, E)                                         \
33b1fa5ac3SMircea Trofin   template <> TensorType TensorSpec::getDataType<T>() { return TensorType::E; }
34b1fa5ac3SMircea Trofin 
SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_IMPL)35b1fa5ac3SMircea Trofin SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_IMPL)
36b1fa5ac3SMircea Trofin 
37b1fa5ac3SMircea Trofin #undef TFUTILS_GETDATATYPE_IMPL
38b1fa5ac3SMircea Trofin 
394c97745bSMircea Trofin static std::array<std::string, static_cast<size_t>(TensorType::Total)>
404c97745bSMircea Trofin     TensorTypeNames{"INVALID",
414c97745bSMircea Trofin #define TFUTILS_GETNAME_IMPL(T, _) #T,
424c97745bSMircea Trofin                     SUPPORTED_TENSOR_TYPES(TFUTILS_GETNAME_IMPL)
434c97745bSMircea Trofin #undef TFUTILS_GETNAME_IMPL
444c97745bSMircea Trofin     };
454c97745bSMircea Trofin 
toString(TensorType TT)464c97745bSMircea Trofin StringRef toString(TensorType TT) {
474c97745bSMircea Trofin   return TensorTypeNames[static_cast<size_t>(TT)];
484c97745bSMircea Trofin }
494c97745bSMircea Trofin 
toJSON(json::OStream & OS) const504c97745bSMircea Trofin void TensorSpec::toJSON(json::OStream &OS) const {
514c97745bSMircea Trofin   OS.object([&]() {
524c97745bSMircea Trofin     OS.attribute("name", name());
534c97745bSMircea Trofin     OS.attribute("type", toString(type()));
544c97745bSMircea Trofin     OS.attribute("port", port());
554c97745bSMircea Trofin     OS.attributeArray("shape", [&]() {
564c97745bSMircea Trofin       for (size_t D : shape())
574c97745bSMircea Trofin         OS.value(static_cast<int64_t>(D));
584c97745bSMircea Trofin     });
594c97745bSMircea Trofin   });
604c97745bSMircea Trofin }
614c97745bSMircea Trofin 
TensorSpec(const std::string & Name,int Port,TensorType Type,size_t ElementSize,const std::vector<int64_t> & Shape)62b1fa5ac3SMircea Trofin TensorSpec::TensorSpec(const std::string &Name, int Port, TensorType Type,
63b1fa5ac3SMircea Trofin                        size_t ElementSize, const std::vector<int64_t> &Shape)
64b1fa5ac3SMircea Trofin     : Name(Name), Port(Port), Type(Type), Shape(Shape),
65b1fa5ac3SMircea Trofin       ElementCount(std::accumulate(Shape.begin(), Shape.end(), 1,
66b1fa5ac3SMircea Trofin                                    std::multiplies<int64_t>())),
67b1fa5ac3SMircea Trofin       ElementSize(ElementSize) {}
68b1fa5ac3SMircea Trofin 
getTensorSpecFromJSON(LLVMContext & Ctx,const json::Value & Value)69d4b6fcb3SFangrui Song std::optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
70b1fa5ac3SMircea Trofin                                                 const json::Value &Value) {
71d4b6fcb3SFangrui Song   auto EmitError =
72d4b6fcb3SFangrui Song       [&](const llvm::Twine &Message) -> std::optional<TensorSpec> {
73b1fa5ac3SMircea Trofin     std::string S;
74b1fa5ac3SMircea Trofin     llvm::raw_string_ostream OS(S);
75b1fa5ac3SMircea Trofin     OS << Value;
76b1fa5ac3SMircea Trofin     Ctx.emitError("Unable to parse JSON Value as spec (" + Message + "): " + S);
7719aff0f3SKazu Hirata     return std::nullopt;
78b1fa5ac3SMircea Trofin   };
79b1fa5ac3SMircea Trofin   // FIXME: accept a Path as a parameter, and use it for error reporting.
80b1fa5ac3SMircea Trofin   json::Path::Root Root("tensor_spec");
81b1fa5ac3SMircea Trofin   json::ObjectMapper Mapper(Value, Root);
82b1fa5ac3SMircea Trofin   if (!Mapper)
83b1fa5ac3SMircea Trofin     return EmitError("Value is not a dict");
84b1fa5ac3SMircea Trofin 
85b1fa5ac3SMircea Trofin   std::string TensorName;
86b1fa5ac3SMircea Trofin   int TensorPort = -1;
87b1fa5ac3SMircea Trofin   std::string TensorType;
88b1fa5ac3SMircea Trofin   std::vector<int64_t> TensorShape;
89b1fa5ac3SMircea Trofin 
90b1fa5ac3SMircea Trofin   if (!Mapper.map<std::string>("name", TensorName))
91b1fa5ac3SMircea Trofin     return EmitError("'name' property not present or not a string");
92b1fa5ac3SMircea Trofin   if (!Mapper.map<std::string>("type", TensorType))
93b1fa5ac3SMircea Trofin     return EmitError("'type' property not present or not a string");
94b1fa5ac3SMircea Trofin   if (!Mapper.map<int>("port", TensorPort))
95b1fa5ac3SMircea Trofin     return EmitError("'port' property not present or not an int");
96b1fa5ac3SMircea Trofin   if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape))
97b1fa5ac3SMircea Trofin     return EmitError("'shape' property not present or not an int array");
98b1fa5ac3SMircea Trofin 
99b1fa5ac3SMircea Trofin #define PARSE_TYPE(T, E)                                                       \
100b1fa5ac3SMircea Trofin   if (TensorType == #T)                                                        \
101b1fa5ac3SMircea Trofin     return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort);
102b1fa5ac3SMircea Trofin   SUPPORTED_TENSOR_TYPES(PARSE_TYPE)
103b1fa5ac3SMircea Trofin #undef PARSE_TYPE
10419aff0f3SKazu Hirata   return std::nullopt;
105b1fa5ac3SMircea Trofin }
106b1fa5ac3SMircea Trofin 
tensorValueToString(const char * Buffer,const TensorSpec & Spec)1075b8dc7c8SMircea Trofin std::string tensorValueToString(const char *Buffer, const TensorSpec &Spec) {
1085b8dc7c8SMircea Trofin   switch (Spec.type()) {
1095b8dc7c8SMircea Trofin #define _IMR_DBG_PRINTER(T, N)                                                 \
1105b8dc7c8SMircea Trofin   case TensorType::N: {                                                        \
1115b8dc7c8SMircea Trofin     const T *TypedBuff = reinterpret_cast<const T *>(Buffer);                  \
1125b8dc7c8SMircea Trofin     auto R = llvm::make_range(TypedBuff, TypedBuff + Spec.getElementCount());  \
1135b8dc7c8SMircea Trofin     return llvm::join(                                                         \
1145b8dc7c8SMircea Trofin         llvm::map_range(R, [](T V) { return std::to_string(V); }), ",");       \
1155b8dc7c8SMircea Trofin   }
1165b8dc7c8SMircea Trofin     SUPPORTED_TENSOR_TYPES(_IMR_DBG_PRINTER)
1175b8dc7c8SMircea Trofin #undef _IMR_DBG_PRINTER
1185b8dc7c8SMircea Trofin   case TensorType::Total:
1195b8dc7c8SMircea Trofin   case TensorType::Invalid:
1205b8dc7c8SMircea Trofin     llvm_unreachable("invalid tensor type");
1215b8dc7c8SMircea Trofin   }
122*7d31d3b0SMircea Trofin   // To appease warnings about not all control paths returning a value.
123*7d31d3b0SMircea Trofin   return "";
1245b8dc7c8SMircea Trofin }
1255b8dc7c8SMircea Trofin 
126b1fa5ac3SMircea Trofin } // namespace llvm
127