xref: /freebsd-src/contrib/llvm-project/llvm/lib/Analysis/TensorSpec.cpp (revision 81ad626541db97eb356e2c1d4a20eb2a26a766ab)
1*81ad6265SDimitry Andric //===- TensorSpec.cpp - tensor type abstraction ---------------------------===//
2*81ad6265SDimitry Andric //
3*81ad6265SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*81ad6265SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5*81ad6265SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*81ad6265SDimitry Andric //
7*81ad6265SDimitry Andric //===----------------------------------------------------------------------===//
8*81ad6265SDimitry Andric //
9*81ad6265SDimitry Andric // Implementation file for the abstraction of a tensor type, and JSON loading
10*81ad6265SDimitry Andric // utils.
11*81ad6265SDimitry Andric //
12*81ad6265SDimitry Andric //===----------------------------------------------------------------------===//
13*81ad6265SDimitry Andric #include "llvm/Config/config.h"
14*81ad6265SDimitry Andric 
15*81ad6265SDimitry Andric #include "llvm/ADT/Twine.h"
16*81ad6265SDimitry Andric #include "llvm/Analysis/TensorSpec.h"
17*81ad6265SDimitry Andric #include "llvm/Support/CommandLine.h"
18*81ad6265SDimitry Andric #include "llvm/Support/Debug.h"
19*81ad6265SDimitry Andric #include "llvm/Support/JSON.h"
20*81ad6265SDimitry Andric #include "llvm/Support/ManagedStatic.h"
21*81ad6265SDimitry Andric #include "llvm/Support/MemoryBuffer.h"
22*81ad6265SDimitry Andric #include "llvm/Support/Path.h"
23*81ad6265SDimitry Andric #include "llvm/Support/raw_ostream.h"
24*81ad6265SDimitry Andric #include <cassert>
25*81ad6265SDimitry Andric #include <numeric>
26*81ad6265SDimitry Andric 
27*81ad6265SDimitry Andric using namespace llvm;
28*81ad6265SDimitry Andric 
29*81ad6265SDimitry Andric namespace llvm {
30*81ad6265SDimitry Andric 
31*81ad6265SDimitry Andric #define TFUTILS_GETDATATYPE_IMPL(T, E)                                         \
32*81ad6265SDimitry Andric   template <> TensorType TensorSpec::getDataType<T>() { return TensorType::E; }
33*81ad6265SDimitry Andric 
34*81ad6265SDimitry Andric SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_IMPL)
35*81ad6265SDimitry Andric 
36*81ad6265SDimitry Andric #undef TFUTILS_GETDATATYPE_IMPL
37*81ad6265SDimitry Andric 
38*81ad6265SDimitry Andric TensorSpec::TensorSpec(const std::string &Name, int Port, TensorType Type,
39*81ad6265SDimitry Andric                        size_t ElementSize, const std::vector<int64_t> &Shape)
40*81ad6265SDimitry Andric     : Name(Name), Port(Port), Type(Type), Shape(Shape),
41*81ad6265SDimitry Andric       ElementCount(std::accumulate(Shape.begin(), Shape.end(), 1,
42*81ad6265SDimitry Andric                                    std::multiplies<int64_t>())),
43*81ad6265SDimitry Andric       ElementSize(ElementSize) {}
44*81ad6265SDimitry Andric 
45*81ad6265SDimitry Andric Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
46*81ad6265SDimitry Andric                                            const json::Value &Value) {
47*81ad6265SDimitry Andric   auto EmitError = [&](const llvm::Twine &Message) -> Optional<TensorSpec> {
48*81ad6265SDimitry Andric     std::string S;
49*81ad6265SDimitry Andric     llvm::raw_string_ostream OS(S);
50*81ad6265SDimitry Andric     OS << Value;
51*81ad6265SDimitry Andric     Ctx.emitError("Unable to parse JSON Value as spec (" + Message + "): " + S);
52*81ad6265SDimitry Andric     return None;
53*81ad6265SDimitry Andric   };
54*81ad6265SDimitry Andric   // FIXME: accept a Path as a parameter, and use it for error reporting.
55*81ad6265SDimitry Andric   json::Path::Root Root("tensor_spec");
56*81ad6265SDimitry Andric   json::ObjectMapper Mapper(Value, Root);
57*81ad6265SDimitry Andric   if (!Mapper)
58*81ad6265SDimitry Andric     return EmitError("Value is not a dict");
59*81ad6265SDimitry Andric 
60*81ad6265SDimitry Andric   std::string TensorName;
61*81ad6265SDimitry Andric   int TensorPort = -1;
62*81ad6265SDimitry Andric   std::string TensorType;
63*81ad6265SDimitry Andric   std::vector<int64_t> TensorShape;
64*81ad6265SDimitry Andric 
65*81ad6265SDimitry Andric   if (!Mapper.map<std::string>("name", TensorName))
66*81ad6265SDimitry Andric     return EmitError("'name' property not present or not a string");
67*81ad6265SDimitry Andric   if (!Mapper.map<std::string>("type", TensorType))
68*81ad6265SDimitry Andric     return EmitError("'type' property not present or not a string");
69*81ad6265SDimitry Andric   if (!Mapper.map<int>("port", TensorPort))
70*81ad6265SDimitry Andric     return EmitError("'port' property not present or not an int");
71*81ad6265SDimitry Andric   if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape))
72*81ad6265SDimitry Andric     return EmitError("'shape' property not present or not an int array");
73*81ad6265SDimitry Andric 
74*81ad6265SDimitry Andric #define PARSE_TYPE(T, E)                                                       \
75*81ad6265SDimitry Andric   if (TensorType == #T)                                                        \
76*81ad6265SDimitry Andric     return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort);
77*81ad6265SDimitry Andric   SUPPORTED_TENSOR_TYPES(PARSE_TYPE)
78*81ad6265SDimitry Andric #undef PARSE_TYPE
79*81ad6265SDimitry Andric   return None;
80*81ad6265SDimitry Andric }
81*81ad6265SDimitry Andric 
82*81ad6265SDimitry Andric Optional<std::vector<LoggedFeatureSpec>>
83*81ad6265SDimitry Andric loadOutputSpecs(LLVMContext &Ctx, StringRef ExpectedDecisionName,
84*81ad6265SDimitry Andric                 StringRef ModelPath, StringRef SpecFileOverride) {
85*81ad6265SDimitry Andric   SmallVector<char, 128> OutputSpecsPath;
86*81ad6265SDimitry Andric   StringRef FileName = SpecFileOverride;
87*81ad6265SDimitry Andric   if (FileName.empty()) {
88*81ad6265SDimitry Andric     llvm::sys::path::append(OutputSpecsPath, ModelPath, "output_spec.json");
89*81ad6265SDimitry Andric     FileName = {OutputSpecsPath.data(), OutputSpecsPath.size()};
90*81ad6265SDimitry Andric   }
91*81ad6265SDimitry Andric 
92*81ad6265SDimitry Andric   auto BufferOrError = MemoryBuffer::getFileOrSTDIN(FileName);
93*81ad6265SDimitry Andric   if (!BufferOrError) {
94*81ad6265SDimitry Andric     Ctx.emitError("Error opening output specs file: " + FileName + " : " +
95*81ad6265SDimitry Andric                   BufferOrError.getError().message());
96*81ad6265SDimitry Andric     return None;
97*81ad6265SDimitry Andric   }
98*81ad6265SDimitry Andric   auto ParsedJSONValues = json::parse(BufferOrError.get()->getBuffer());
99*81ad6265SDimitry Andric   if (!ParsedJSONValues) {
100*81ad6265SDimitry Andric     Ctx.emitError("Could not parse specs file: " + FileName);
101*81ad6265SDimitry Andric     return None;
102*81ad6265SDimitry Andric   }
103*81ad6265SDimitry Andric   auto ValuesArray = ParsedJSONValues->getAsArray();
104*81ad6265SDimitry Andric   if (!ValuesArray) {
105*81ad6265SDimitry Andric     Ctx.emitError("Expected an array of {tensor_spec:<TensorSpec>, "
106*81ad6265SDimitry Andric                   "logging_name:<name>} dictionaries");
107*81ad6265SDimitry Andric     return None;
108*81ad6265SDimitry Andric   }
109*81ad6265SDimitry Andric   std::vector<LoggedFeatureSpec> Ret;
110*81ad6265SDimitry Andric   for (const auto &Value : *ValuesArray)
111*81ad6265SDimitry Andric     if (const auto *Obj = Value.getAsObject())
112*81ad6265SDimitry Andric       if (const auto *SpecPart = Obj->get("tensor_spec"))
113*81ad6265SDimitry Andric         if (auto TensorSpec = getTensorSpecFromJSON(Ctx, *SpecPart))
114*81ad6265SDimitry Andric           if (auto LoggingName = Obj->getString("logging_name")) {
115*81ad6265SDimitry Andric             if (!TensorSpec->isElementType<int64_t>() &&
116*81ad6265SDimitry Andric                 !TensorSpec->isElementType<int32_t>() &&
117*81ad6265SDimitry Andric                 !TensorSpec->isElementType<float>()) {
118*81ad6265SDimitry Andric               Ctx.emitError(
119*81ad6265SDimitry Andric                   "Only int64, int32, and float tensors are supported. "
120*81ad6265SDimitry Andric                   "Found unsupported type for tensor named " +
121*81ad6265SDimitry Andric                   TensorSpec->name());
122*81ad6265SDimitry Andric               return None;
123*81ad6265SDimitry Andric             }
124*81ad6265SDimitry Andric             Ret.push_back({*TensorSpec, LoggingName->str()});
125*81ad6265SDimitry Andric           }
126*81ad6265SDimitry Andric 
127*81ad6265SDimitry Andric   if (ValuesArray->size() != Ret.size()) {
128*81ad6265SDimitry Andric     Ctx.emitError(
129*81ad6265SDimitry Andric         "Unable to parse output spec. It should be a json file containing an "
130*81ad6265SDimitry Andric         "array of dictionaries. Each dictionary must have a 'tensor_spec' key, "
131*81ad6265SDimitry Andric         "with a json object describing a TensorSpec; and a 'logging_name' key, "
132*81ad6265SDimitry Andric         "which is a string to use as name when logging this tensor in the "
133*81ad6265SDimitry Andric         "training log.");
134*81ad6265SDimitry Andric     return None;
135*81ad6265SDimitry Andric   }
136*81ad6265SDimitry Andric   if (Ret.empty() || *Ret[0].LoggingName != ExpectedDecisionName) {
137*81ad6265SDimitry Andric     Ctx.emitError("The first output spec must describe the decision tensor, "
138*81ad6265SDimitry Andric                   "and must have the logging_name " +
139*81ad6265SDimitry Andric                   StringRef(ExpectedDecisionName));
140*81ad6265SDimitry Andric     return None;
141*81ad6265SDimitry Andric   }
142*81ad6265SDimitry Andric   return Ret;
143*81ad6265SDimitry Andric }
144*81ad6265SDimitry Andric } // namespace llvm
145