xref: /llvm-project/offload/tools/offload-tblgen/PrintGen.cpp (revision fd3907ccb583df99e9c19d2fe84e4e7c52d75de9)
//===- offload-tblgen/PrintGen.cpp - Tablegen backend for Offload printing ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is a Tablegen backend that produces print functions for the Offload API
10 // entry point functions.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/Support/FormatVariadic.h"
15 #include "llvm/TableGen/Record.h"
16 
17 #include "GenCommon.hpp"
18 #include "RecordTypes.hpp"
19 
20 using namespace llvm;
21 using namespace offload::tblgen;
22 
/// formatv() template for the banner comment written above each generated
/// enum `operator<<`. `{0}` is substituted with the enum's type name.
constexpr const char *PrintEnumHeader =
    R"(///////////////////////////////////////////////////////////////////////////////
/// @brief Print operator for the {0} type
/// @returns std::ostream &
)";
28 
/// formatv() template for the banner comment written above each generated
/// `printTagged` specialization. `{0}` is substituted with the enum's type
/// name.
constexpr const char *PrintTaggedEnumHeader =
    R"(///////////////////////////////////////////////////////////////////////////////
/// @brief Print type-tagged {0} enum value
/// @returns std::ostream &
)";
34 
35 static void ProcessEnum(const EnumRec &Enum, raw_ostream &OS) {
36   OS << formatv(PrintEnumHeader, Enum.getName());
37   OS << formatv(
38       "inline std::ostream &operator<<(std::ostream &os, enum {0} value) "
39       "{{\n" TAB_1 "switch (value) {{\n",
40       Enum.getName());
41 
42   for (const auto &Val : Enum.getValues()) {
43     auto Name = Enum.getEnumValNamePrefix() + "_" + Val.getName();
44     OS << formatv(TAB_1 "case {0}:\n", Name);
45     OS << formatv(TAB_2 "os << \"{0}\";\n", Name);
46     OS << formatv(TAB_2 "break;\n");
47   }
48 
49   OS << TAB_1 "default:\n" TAB_2 "os << \"unknown enumerator\";\n" TAB_2
50               "break;\n" TAB_1 "}\n" TAB_1 "return os;\n}\n\n";
51 
52   if (!Enum.isTyped()) {
53     return;
54   }
55 
56   OS << formatv(PrintTaggedEnumHeader, Enum.getName());
57 
58   OS << formatv(R"""(template <>
59 inline void printTagged(std::ostream &os, const void *ptr, {0} value, size_t size) {{
60   if (ptr == NULL) {{
61     printPtr(os, ptr);
62     return;
63   }
64 
65   switch (value) {{
66 )""",
67                 Enum.getName());
68 
69   for (const auto &Val : Enum.getValues()) {
70     auto Name = Enum.getEnumValNamePrefix() + "_" + Val.getName();
71     auto Type = Val.getTaggedType();
72     OS << formatv(TAB_1 "case {0}: {{\n", Name);
73     // Special case for strings
74     if (Type == "char[]") {
75       OS << formatv(TAB_2 "printPtr(os, (const char*) ptr);\n");
76     } else {
77       OS << formatv(TAB_2 "const {0} * const tptr = (const {0} * const)ptr;\n",
78                     Type);
79       // TODO: Handle other cases here
80       OS << TAB_2 "os << (const void *)tptr << \" (\";\n";
81       if (Type.ends_with("*")) {
82         OS << TAB_2 "os << printPtr(os, tptr);\n";
83       } else {
84         OS << TAB_2 "os << *tptr;\n";
85       }
86       OS << TAB_2 "os << \")\";\n";
87     }
88     OS << formatv(TAB_2 "break;\n" TAB_1 "}\n");
89   }
90 
91   OS << TAB_1 "default:\n" TAB_2 "os << \"unknown enumerator\";\n" TAB_2
92               "break;\n" TAB_1 "}\n";
93 
94   OS << "}\n";
95 }
96 
97 static void EmitResultPrint(raw_ostream &OS) {
98   OS << R""(
99 inline std::ostream &operator<<(std::ostream &os,
100                                 const ol_error_struct_t *Err) {
101   if (Err == nullptr) {
102     os << "OL_SUCCESS";
103   } else {
104     os << Err->Code;
105   }
106   return os;
107 }
108 )"";
109 }
110 
111 static void EmitFunctionParamStructPrint(const FunctionRec &Func,
112                                          raw_ostream &OS) {
113   if (Func.getParams().size() == 0) {
114     return;
115   }
116 
117   OS << formatv(R"(
118 inline std::ostream &operator<<(std::ostream &os, const struct {0} *params) {{
119 )",
120                 Func.getParamStructName());
121 
122   for (const auto &Param : Func.getParams()) {
123     OS << formatv(TAB_1 "os << \".{0} = \";\n", Param.getName());
124     if (auto Range = Param.getRange()) {
125       OS << formatv(TAB_1 "os << \"{{\";\n");
126       OS << formatv(TAB_1 "for (size_t i = {0}; i < *params->p{1}; i++) {{\n",
127                     Range->first, Range->second);
128       OS << TAB_2 "if (i > 0) {\n";
129       OS << TAB_3 " os << \", \";\n";
130       OS << TAB_2 "}\n";
131       OS << formatv(TAB_2 "printPtr(os, (*params->p{0})[i]);\n",
132                     Param.getName());
133       OS << formatv(TAB_1 "}\n");
134       OS << formatv(TAB_1 "os << \"}\";\n");
135     } else if (auto TypeInfo = Param.getTypeInfo()) {
136       OS << formatv(
137           TAB_1
138           "printTagged(os, *params->p{0}, *params->p{1}, *params->p{2});\n",
139           Param.getName(), TypeInfo->first, TypeInfo->second);
140     } else if (Param.isPointerType() || Param.isHandleType()) {
141       OS << formatv(TAB_1 "printPtr(os, *params->p{0});\n", Param.getName());
142     } else {
143       OS << formatv(TAB_1 "os << *params->p{0};\n", Param.getName());
144     }
145     if (Param != Func.getParams().back()) {
146       OS << TAB_1 "os << \", \";\n";
147     }
148   }
149 
150   OS << TAB_1 "return os;\n}\n";
151 }
152 
153 void EmitOffloadPrintHeader(const RecordKeeper &Records, raw_ostream &OS) {
154   OS << GenericHeader;
155   OS << R"""(
156 // Auto-generated file, do not manually edit.
157 
158 #pragma once
159 
160 #include <OffloadAPI.h>
161 #include <ostream>
162 
163 
164 template <typename T> inline ol_result_t printPtr(std::ostream &os, const T *ptr);
165 template <typename T> inline void printTagged(std::ostream &os, const void *ptr, T value, size_t size);
166 )""";
167 
168   // ==========
169   OS << "template <typename T> struct is_handle : std::false_type {};\n";
170   for (auto *R : Records.getAllDerivedDefinitions("Handle")) {
171     HandleRec H{R};
172     OS << formatv("template <> struct is_handle<{0}> : std::true_type {{};\n",
173                   H.getName());
174   }
175   OS << "template <typename T> inline constexpr bool is_handle_v = "
176         "is_handle<T>::value;\n";
177   // =========
178 
179   // Forward declare the operator<< overloads so their implementations can
180   // use each other.
181   OS << "\n";
182   for (auto *R : Records.getAllDerivedDefinitions("Enum")) {
183     OS << formatv(
184         "inline std::ostream &operator<<(std::ostream &os, enum {0} value);\n",
185         EnumRec{R}.getName());
186   }
187   OS << "\n";
188 
189   // Create definitions
190   for (auto *R : Records.getAllDerivedDefinitions("Enum")) {
191     EnumRec E{R};
192     ProcessEnum(E, OS);
193   }
194   EmitResultPrint(OS);
195 
196   // Emit print functions for the function param structs
197   for (auto *R : Records.getAllDerivedDefinitions("Function")) {
198     EmitFunctionParamStructPrint(FunctionRec{R}, OS);
199   }
200 
201   OS << R"""(
202 ///////////////////////////////////////////////////////////////////////////////
203 // @brief Print pointer value
204 template <typename T> inline ol_result_t printPtr(std::ostream &os, const T *ptr) {
205     if (ptr == nullptr) {
206         os << "nullptr";
207     } else if constexpr (std::is_pointer_v<T>) {
208         os << (const void *)(ptr) << " (";
209         printPtr(os, *ptr);
210         os << ")";
211     } else if constexpr (std::is_void_v<T> || is_handle_v<T *>) {
212         os << (const void *)ptr;
213     } else if constexpr (std::is_same_v<std::remove_cv_t< T >, char>) {
214         os << (const void *)(ptr) << " (";
215         os << ptr;
216         os << ")";
217     } else {
218         os << (const void *)(ptr) << " (";
219         os << *ptr;
220         os << ")";
221     }
222 
223     return OL_SUCCESS;
224 }
225   )""";
226 }
227