1*81ad6265SDimitry Andric //===- TensorSpec.cpp - tensor type abstraction ---------------------------===// 2*81ad6265SDimitry Andric // 3*81ad6265SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*81ad6265SDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 5*81ad6265SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*81ad6265SDimitry Andric // 7*81ad6265SDimitry Andric //===----------------------------------------------------------------------===// 8*81ad6265SDimitry Andric // 9*81ad6265SDimitry Andric // Implementation file for the abstraction of a tensor type, and JSON loading 10*81ad6265SDimitry Andric // utils. 11*81ad6265SDimitry Andric // 12*81ad6265SDimitry Andric //===----------------------------------------------------------------------===// 13*81ad6265SDimitry Andric #include "llvm/Config/config.h" 14*81ad6265SDimitry Andric 15*81ad6265SDimitry Andric #include "llvm/ADT/Twine.h" 16*81ad6265SDimitry Andric #include "llvm/Analysis/TensorSpec.h" 17*81ad6265SDimitry Andric #include "llvm/Support/CommandLine.h" 18*81ad6265SDimitry Andric #include "llvm/Support/Debug.h" 19*81ad6265SDimitry Andric #include "llvm/Support/JSON.h" 20*81ad6265SDimitry Andric #include "llvm/Support/ManagedStatic.h" 21*81ad6265SDimitry Andric #include "llvm/Support/MemoryBuffer.h" 22*81ad6265SDimitry Andric #include "llvm/Support/Path.h" 23*81ad6265SDimitry Andric #include "llvm/Support/raw_ostream.h" 24*81ad6265SDimitry Andric #include <cassert> 25*81ad6265SDimitry Andric #include <numeric> 26*81ad6265SDimitry Andric 27*81ad6265SDimitry Andric using namespace llvm; 28*81ad6265SDimitry Andric 29*81ad6265SDimitry Andric namespace llvm { 30*81ad6265SDimitry Andric 31*81ad6265SDimitry Andric #define TFUTILS_GETDATATYPE_IMPL(T, E) \ 32*81ad6265SDimitry Andric template <> TensorType TensorSpec::getDataType<T>() { return TensorType::E; } 33*81ad6265SDimitry Andric 34*81ad6265SDimitry Andric SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_IMPL) 35*81ad6265SDimitry Andric 36*81ad6265SDimitry Andric #undef TFUTILS_GETDATATYPE_IMPL 37*81ad6265SDimitry Andric 38*81ad6265SDimitry Andric TensorSpec::TensorSpec(const std::string &Name, int Port, TensorType Type, 39*81ad6265SDimitry Andric size_t ElementSize, const std::vector<int64_t> &Shape) 40*81ad6265SDimitry Andric : Name(Name), Port(Port), Type(Type), Shape(Shape), 41*81ad6265SDimitry Andric ElementCount(std::accumulate(Shape.begin(), Shape.end(), 1, 42*81ad6265SDimitry Andric std::multiplies<int64_t>())), 43*81ad6265SDimitry Andric ElementSize(ElementSize) {} 44*81ad6265SDimitry Andric 45*81ad6265SDimitry Andric Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx, 46*81ad6265SDimitry Andric const json::Value &Value) { 47*81ad6265SDimitry Andric auto EmitError = [&](const llvm::Twine &Message) -> Optional<TensorSpec> { 48*81ad6265SDimitry Andric std::string S; 49*81ad6265SDimitry Andric llvm::raw_string_ostream OS(S); 50*81ad6265SDimitry Andric OS << Value; 51*81ad6265SDimitry Andric Ctx.emitError("Unable to parse JSON Value as spec (" + Message + "): " + S); 52*81ad6265SDimitry Andric return None; 53*81ad6265SDimitry Andric }; 54*81ad6265SDimitry Andric // FIXME: accept a Path as a parameter, and use it for error reporting. 55*81ad6265SDimitry Andric json::Path::Root Root("tensor_spec"); 56*81ad6265SDimitry Andric json::ObjectMapper Mapper(Value, Root); 57*81ad6265SDimitry Andric if (!Mapper) 58*81ad6265SDimitry Andric return EmitError("Value is not a dict"); 59*81ad6265SDimitry Andric 60*81ad6265SDimitry Andric std::string TensorName; 61*81ad6265SDimitry Andric int TensorPort = -1; 62*81ad6265SDimitry Andric std::string TensorType; 63*81ad6265SDimitry Andric std::vector<int64_t> TensorShape; 64*81ad6265SDimitry Andric 65*81ad6265SDimitry Andric if (!Mapper.map<std::string>("name", TensorName)) 66*81ad6265SDimitry Andric return EmitError("'name' property not present or not a string"); 67*81ad6265SDimitry Andric if (!Mapper.map<std::string>("type", TensorType)) 68*81ad6265SDimitry Andric return EmitError("'type' property not present or not a string"); 69*81ad6265SDimitry Andric if (!Mapper.map<int>("port", TensorPort)) 70*81ad6265SDimitry Andric return EmitError("'port' property not present or not an int"); 71*81ad6265SDimitry Andric if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape)) 72*81ad6265SDimitry Andric return EmitError("'shape' property not present or not an int array"); 73*81ad6265SDimitry Andric 74*81ad6265SDimitry Andric #define PARSE_TYPE(T, E) \ 75*81ad6265SDimitry Andric if (TensorType == #T) \ 76*81ad6265SDimitry Andric return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort); 77*81ad6265SDimitry Andric SUPPORTED_TENSOR_TYPES(PARSE_TYPE) 78*81ad6265SDimitry Andric #undef PARSE_TYPE 79*81ad6265SDimitry Andric return None; 80*81ad6265SDimitry Andric } 81*81ad6265SDimitry Andric 82*81ad6265SDimitry Andric Optional<std::vector<LoggedFeatureSpec>> 83*81ad6265SDimitry Andric loadOutputSpecs(LLVMContext &Ctx, StringRef ExpectedDecisionName, 84*81ad6265SDimitry Andric StringRef ModelPath, StringRef SpecFileOverride) { 85*81ad6265SDimitry Andric SmallVector<char, 128> OutputSpecsPath; 86*81ad6265SDimitry Andric StringRef FileName = SpecFileOverride; 87*81ad6265SDimitry Andric if (FileName.empty()) { 88*81ad6265SDimitry Andric llvm::sys::path::append(OutputSpecsPath, ModelPath, "output_spec.json"); 89*81ad6265SDimitry Andric FileName = {OutputSpecsPath.data(), OutputSpecsPath.size()}; 90*81ad6265SDimitry Andric } 91*81ad6265SDimitry Andric 92*81ad6265SDimitry Andric auto BufferOrError = MemoryBuffer::getFileOrSTDIN(FileName); 93*81ad6265SDimitry Andric if (!BufferOrError) { 94*81ad6265SDimitry Andric Ctx.emitError("Error opening output specs file: " + FileName + " : " + 95*81ad6265SDimitry Andric BufferOrError.getError().message()); 96*81ad6265SDimitry Andric return None; 97*81ad6265SDimitry Andric } 98*81ad6265SDimitry Andric auto ParsedJSONValues = json::parse(BufferOrError.get()->getBuffer()); 99*81ad6265SDimitry Andric if (!ParsedJSONValues) { 100*81ad6265SDimitry Andric Ctx.emitError("Could not parse specs file: " + FileName); 101*81ad6265SDimitry Andric return None; 102*81ad6265SDimitry Andric } 103*81ad6265SDimitry Andric auto ValuesArray = ParsedJSONValues->getAsArray(); 104*81ad6265SDimitry Andric if (!ValuesArray) { 105*81ad6265SDimitry Andric Ctx.emitError("Expected an array of {tensor_spec:<TensorSpec>, " 106*81ad6265SDimitry Andric "logging_name:<name>} dictionaries"); 107*81ad6265SDimitry Andric return None; 108*81ad6265SDimitry Andric } 109*81ad6265SDimitry Andric std::vector<LoggedFeatureSpec> Ret; 110*81ad6265SDimitry Andric for (const auto &Value : *ValuesArray) 111*81ad6265SDimitry Andric if (const auto *Obj = Value.getAsObject()) 112*81ad6265SDimitry Andric if (const auto *SpecPart = Obj->get("tensor_spec")) 113*81ad6265SDimitry Andric if (auto TensorSpec = getTensorSpecFromJSON(Ctx, *SpecPart)) 114*81ad6265SDimitry Andric if (auto LoggingName = Obj->getString("logging_name")) { 115*81ad6265SDimitry Andric if (!TensorSpec->isElementType<int64_t>() && 116*81ad6265SDimitry Andric !TensorSpec->isElementType<int32_t>() && 117*81ad6265SDimitry Andric !TensorSpec->isElementType<float>()) { 118*81ad6265SDimitry Andric Ctx.emitError( 119*81ad6265SDimitry Andric "Only int64, int32, and float tensors are supported. " 120*81ad6265SDimitry Andric "Found unsupported type for tensor named " + 121*81ad6265SDimitry Andric TensorSpec->name()); 122*81ad6265SDimitry Andric return None; 123*81ad6265SDimitry Andric } 124*81ad6265SDimitry Andric Ret.push_back({*TensorSpec, LoggingName->str()}); 125*81ad6265SDimitry Andric } 126*81ad6265SDimitry Andric 127*81ad6265SDimitry Andric if (ValuesArray->size() != Ret.size()) { 128*81ad6265SDimitry Andric Ctx.emitError( 129*81ad6265SDimitry Andric "Unable to parse output spec. It should be a json file containing an " 130*81ad6265SDimitry Andric "array of dictionaries. Each dictionary must have a 'tensor_spec' key, " 131*81ad6265SDimitry Andric "with a json object describing a TensorSpec; and a 'logging_name' key, " 132*81ad6265SDimitry Andric "which is a string to use as name when logging this tensor in the " 133*81ad6265SDimitry Andric "training log."); 134*81ad6265SDimitry Andric return None; 135*81ad6265SDimitry Andric } 136*81ad6265SDimitry Andric if (Ret.empty() || *Ret[0].LoggingName != ExpectedDecisionName) { 137*81ad6265SDimitry Andric Ctx.emitError("The first output spec must describe the decision tensor, " 138*81ad6265SDimitry Andric "and must have the logging_name " + 139*81ad6265SDimitry Andric StringRef(ExpectedDecisionName)); 140*81ad6265SDimitry Andric return None; 141*81ad6265SDimitry Andric } 142*81ad6265SDimitry Andric return Ret; 143*81ad6265SDimitry Andric } 144*81ad6265SDimitry Andric } // namespace llvm 145