1 //===- offload-tblgen/APIGen.cpp - Tablegen backend for Offload printing --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This is a Tablegen backend that produces print functions for the Offload API 10 // entry point functions. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/Support/FormatVariadic.h" 15 #include "llvm/TableGen/Record.h" 16 17 #include "GenCommon.hpp" 18 #include "RecordTypes.hpp" 19 20 using namespace llvm; 21 using namespace offload::tblgen; 22 23 constexpr auto PrintEnumHeader = 24 R"(/////////////////////////////////////////////////////////////////////////////// 25 /// @brief Print operator for the {0} type 26 /// @returns std::ostream & 27 )"; 28 29 constexpr auto PrintTaggedEnumHeader = 30 R"(/////////////////////////////////////////////////////////////////////////////// 31 /// @brief Print type-tagged {0} enum value 32 /// @returns std::ostream & 33 )"; 34 35 static void ProcessEnum(const EnumRec &Enum, raw_ostream &OS) { 36 OS << formatv(PrintEnumHeader, Enum.getName()); 37 OS << formatv( 38 "inline std::ostream &operator<<(std::ostream &os, enum {0} value) " 39 "{{\n" TAB_1 "switch (value) {{\n", 40 Enum.getName()); 41 42 for (const auto &Val : Enum.getValues()) { 43 auto Name = Enum.getEnumValNamePrefix() + "_" + Val.getName(); 44 OS << formatv(TAB_1 "case {0}:\n", Name); 45 OS << formatv(TAB_2 "os << \"{0}\";\n", Name); 46 OS << formatv(TAB_2 "break;\n"); 47 } 48 49 OS << TAB_1 "default:\n" TAB_2 "os << \"unknown enumerator\";\n" TAB_2 50 "break;\n" TAB_1 "}\n" TAB_1 "return os;\n}\n\n"; 51 52 if (!Enum.isTyped()) { 53 return; 54 } 55 56 OS << formatv(PrintTaggedEnumHeader, Enum.getName()); 57 58 OS << formatv(R"""(template <> 59 inline void printTagged(std::ostream &os, const void *ptr, {0} value, size_t size) {{ 60 if (ptr == NULL) {{ 61 printPtr(os, ptr); 62 return; 63 } 64 65 switch (value) {{ 66 )""", 67 Enum.getName()); 68 69 for (const auto &Val : Enum.getValues()) { 70 auto Name = Enum.getEnumValNamePrefix() + "_" + Val.getName(); 71 auto Type = Val.getTaggedType(); 72 OS << formatv(TAB_1 "case {0}: {{\n", Name); 73 // Special case for strings 74 if (Type == "char[]") { 75 OS << formatv(TAB_2 "printPtr(os, (const char*) ptr);\n"); 76 } else { 77 OS << formatv(TAB_2 "const {0} * const tptr = (const {0} * const)ptr;\n", 78 Type); 79 // TODO: Handle other cases here 80 OS << TAB_2 "os << (const void *)tptr << \" (\";\n"; 81 if (Type.ends_with("*")) { 82 OS << TAB_2 "os << printPtr(os, tptr);\n"; 83 } else { 84 OS << TAB_2 "os << *tptr;\n"; 85 } 86 OS << TAB_2 "os << \")\";\n"; 87 } 88 OS << formatv(TAB_2 "break;\n" TAB_1 "}\n"); 89 } 90 91 OS << TAB_1 "default:\n" TAB_2 "os << \"unknown enumerator\";\n" TAB_2 92 "break;\n" TAB_1 "}\n"; 93 94 OS << "}\n"; 95 } 96 97 static void EmitResultPrint(raw_ostream &OS) { 98 OS << R""( 99 inline std::ostream &operator<<(std::ostream &os, 100 const ol_error_struct_t *Err) { 101 if (Err == nullptr) { 102 os << "OL_SUCCESS"; 103 } else { 104 os << Err->Code; 105 } 106 return os; 107 } 108 )""; 109 } 110 111 static void EmitFunctionParamStructPrint(const FunctionRec &Func, 112 raw_ostream &OS) { 113 if (Func.getParams().size() == 0) { 114 return; 115 } 116 117 OS << formatv(R"( 118 inline std::ostream &operator<<(std::ostream &os, const struct {0} *params) {{ 119 )", 120 Func.getParamStructName()); 121 122 for (const auto &Param : Func.getParams()) { 123 OS << formatv(TAB_1 "os << \".{0} = \";\n", Param.getName()); 124 if (auto Range = Param.getRange()) { 125 OS << formatv(TAB_1 "os << \"{{\";\n"); 126 OS << formatv(TAB_1 "for (size_t i = {0}; i < *params->p{1}; i++) {{\n", 127 Range->first, Range->second); 128 OS << TAB_2 "if (i > 0) {\n"; 129 OS << TAB_3 " os << \", \";\n"; 130 OS << TAB_2 "}\n"; 131 OS << formatv(TAB_2 "printPtr(os, (*params->p{0})[i]);\n", 132 Param.getName()); 133 OS << formatv(TAB_1 "}\n"); 134 OS << formatv(TAB_1 "os << \"}\";\n"); 135 } else if (auto TypeInfo = Param.getTypeInfo()) { 136 OS << formatv( 137 TAB_1 138 "printTagged(os, *params->p{0}, *params->p{1}, *params->p{2});\n", 139 Param.getName(), TypeInfo->first, TypeInfo->second); 140 } else if (Param.isPointerType() || Param.isHandleType()) { 141 OS << formatv(TAB_1 "printPtr(os, *params->p{0});\n", Param.getName()); 142 } else { 143 OS << formatv(TAB_1 "os << *params->p{0};\n", Param.getName()); 144 } 145 if (Param != Func.getParams().back()) { 146 OS << TAB_1 "os << \", \";\n"; 147 } 148 } 149 150 OS << TAB_1 "return os;\n}\n"; 151 } 152 153 void EmitOffloadPrintHeader(const RecordKeeper &Records, raw_ostream &OS) { 154 OS << GenericHeader; 155 OS << R"""( 156 // Auto-generated file, do not manually edit. 157 158 #pragma once 159 160 #include <OffloadAPI.h> 161 #include <ostream> 162 163 164 template <typename T> inline ol_result_t printPtr(std::ostream &os, const T *ptr); 165 template <typename T> inline void printTagged(std::ostream &os, const void *ptr, T value, size_t size); 166 )"""; 167 168 // ========== 169 OS << "template <typename T> struct is_handle : std::false_type {};\n"; 170 for (auto *R : Records.getAllDerivedDefinitions("Handle")) { 171 HandleRec H{R}; 172 OS << formatv("template <> struct is_handle<{0}> : std::true_type {{};\n", 173 H.getName()); 174 } 175 OS << "template <typename T> inline constexpr bool is_handle_v = " 176 "is_handle<T>::value;\n"; 177 // ========= 178 179 // Forward declare the operator<< overloads so their implementations can 180 // use each other. 181 OS << "\n"; 182 for (auto *R : Records.getAllDerivedDefinitions("Enum")) { 183 OS << formatv( 184 "inline std::ostream &operator<<(std::ostream &os, enum {0} value);\n", 185 EnumRec{R}.getName()); 186 } 187 OS << "\n"; 188 189 // Create definitions 190 for (auto *R : Records.getAllDerivedDefinitions("Enum")) { 191 EnumRec E{R}; 192 ProcessEnum(E, OS); 193 } 194 EmitResultPrint(OS); 195 196 // Emit print functions for the function param structs 197 for (auto *R : Records.getAllDerivedDefinitions("Function")) { 198 EmitFunctionParamStructPrint(FunctionRec{R}, OS); 199 } 200 201 OS << R"""( 202 /////////////////////////////////////////////////////////////////////////////// 203 // @brief Print pointer value 204 template <typename T> inline ol_result_t printPtr(std::ostream &os, const T *ptr) { 205 if (ptr == nullptr) { 206 os << "nullptr"; 207 } else if constexpr (std::is_pointer_v<T>) { 208 os << (const void *)(ptr) << " ("; 209 printPtr(os, *ptr); 210 os << ")"; 211 } else if constexpr (std::is_void_v<T> || is_handle_v<T *>) { 212 os << (const void *)ptr; 213 } else if constexpr (std::is_same_v<std::remove_cv_t< T >, char>) { 214 os << (const void *)(ptr) << " ("; 215 os << ptr; 216 os << ")"; 217 } else { 218 os << (const void *)(ptr) << " ("; 219 os << *ptr; 220 os << ")"; 221 } 222 223 return OL_SUCCESS; 224 } 225 )"""; 226 } 227