xref: /llvm-project/offload/tools/offload-tblgen/PrintGen.cpp (revision fd3907ccb583df99e9c19d2fe84e4e7c52d75de9)
1*fd3907ccSCallum Fare //===- offload-tblgen/PrintGen.cpp - Tablegen backend for Offload printing ===//
2*fd3907ccSCallum Fare //
3*fd3907ccSCallum Fare // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*fd3907ccSCallum Fare // See https://llvm.org/LICENSE.txt for license information.
5*fd3907ccSCallum Fare // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*fd3907ccSCallum Fare //
7*fd3907ccSCallum Fare //===----------------------------------------------------------------------===//
8*fd3907ccSCallum Fare //
9*fd3907ccSCallum Fare // This is a Tablegen backend that produces print functions for the Offload API
10*fd3907ccSCallum Fare // entry point functions.
11*fd3907ccSCallum Fare //
12*fd3907ccSCallum Fare //===----------------------------------------------------------------------===//
13*fd3907ccSCallum Fare 
14*fd3907ccSCallum Fare #include "llvm/Support/FormatVariadic.h"
15*fd3907ccSCallum Fare #include "llvm/TableGen/Record.h"
16*fd3907ccSCallum Fare 
17*fd3907ccSCallum Fare #include "GenCommon.hpp"
18*fd3907ccSCallum Fare #include "RecordTypes.hpp"
19*fd3907ccSCallum Fare 
20*fd3907ccSCallum Fare using namespace llvm;
21*fd3907ccSCallum Fare using namespace offload::tblgen;
22*fd3907ccSCallum Fare 
/// Doxygen banner placed above each generated enum `operator<<` overload.
/// Used as a formatv pattern: `{0}` is substituted with the enum's name.
static constexpr char PrintEnumHeader[] =
    R"(///////////////////////////////////////////////////////////////////////////////
/// @brief Print operator for the {0} type
/// @returns std::ostream &
)";

/// Doxygen banner placed above each generated `printTagged` specialization.
/// Used as a formatv pattern: `{0}` is substituted with the enum's name.
static constexpr char PrintTaggedEnumHeader[] =
    R"(///////////////////////////////////////////////////////////////////////////////
/// @brief Print type-tagged {0} enum value
/// @returns std::ostream &
)";
34*fd3907ccSCallum Fare 
35*fd3907ccSCallum Fare static void ProcessEnum(const EnumRec &Enum, raw_ostream &OS) {
36*fd3907ccSCallum Fare   OS << formatv(PrintEnumHeader, Enum.getName());
37*fd3907ccSCallum Fare   OS << formatv(
38*fd3907ccSCallum Fare       "inline std::ostream &operator<<(std::ostream &os, enum {0} value) "
39*fd3907ccSCallum Fare       "{{\n" TAB_1 "switch (value) {{\n",
40*fd3907ccSCallum Fare       Enum.getName());
41*fd3907ccSCallum Fare 
42*fd3907ccSCallum Fare   for (const auto &Val : Enum.getValues()) {
43*fd3907ccSCallum Fare     auto Name = Enum.getEnumValNamePrefix() + "_" + Val.getName();
44*fd3907ccSCallum Fare     OS << formatv(TAB_1 "case {0}:\n", Name);
45*fd3907ccSCallum Fare     OS << formatv(TAB_2 "os << \"{0}\";\n", Name);
46*fd3907ccSCallum Fare     OS << formatv(TAB_2 "break;\n");
47*fd3907ccSCallum Fare   }
48*fd3907ccSCallum Fare 
49*fd3907ccSCallum Fare   OS << TAB_1 "default:\n" TAB_2 "os << \"unknown enumerator\";\n" TAB_2
50*fd3907ccSCallum Fare               "break;\n" TAB_1 "}\n" TAB_1 "return os;\n}\n\n";
51*fd3907ccSCallum Fare 
52*fd3907ccSCallum Fare   if (!Enum.isTyped()) {
53*fd3907ccSCallum Fare     return;
54*fd3907ccSCallum Fare   }
55*fd3907ccSCallum Fare 
56*fd3907ccSCallum Fare   OS << formatv(PrintTaggedEnumHeader, Enum.getName());
57*fd3907ccSCallum Fare 
58*fd3907ccSCallum Fare   OS << formatv(R"""(template <>
59*fd3907ccSCallum Fare inline void printTagged(std::ostream &os, const void *ptr, {0} value, size_t size) {{
60*fd3907ccSCallum Fare   if (ptr == NULL) {{
61*fd3907ccSCallum Fare     printPtr(os, ptr);
62*fd3907ccSCallum Fare     return;
63*fd3907ccSCallum Fare   }
64*fd3907ccSCallum Fare 
65*fd3907ccSCallum Fare   switch (value) {{
66*fd3907ccSCallum Fare )""",
67*fd3907ccSCallum Fare                 Enum.getName());
68*fd3907ccSCallum Fare 
69*fd3907ccSCallum Fare   for (const auto &Val : Enum.getValues()) {
70*fd3907ccSCallum Fare     auto Name = Enum.getEnumValNamePrefix() + "_" + Val.getName();
71*fd3907ccSCallum Fare     auto Type = Val.getTaggedType();
72*fd3907ccSCallum Fare     OS << formatv(TAB_1 "case {0}: {{\n", Name);
73*fd3907ccSCallum Fare     // Special case for strings
74*fd3907ccSCallum Fare     if (Type == "char[]") {
75*fd3907ccSCallum Fare       OS << formatv(TAB_2 "printPtr(os, (const char*) ptr);\n");
76*fd3907ccSCallum Fare     } else {
77*fd3907ccSCallum Fare       OS << formatv(TAB_2 "const {0} * const tptr = (const {0} * const)ptr;\n",
78*fd3907ccSCallum Fare                     Type);
79*fd3907ccSCallum Fare       // TODO: Handle other cases here
80*fd3907ccSCallum Fare       OS << TAB_2 "os << (const void *)tptr << \" (\";\n";
81*fd3907ccSCallum Fare       if (Type.ends_with("*")) {
82*fd3907ccSCallum Fare         OS << TAB_2 "os << printPtr(os, tptr);\n";
83*fd3907ccSCallum Fare       } else {
84*fd3907ccSCallum Fare         OS << TAB_2 "os << *tptr;\n";
85*fd3907ccSCallum Fare       }
86*fd3907ccSCallum Fare       OS << TAB_2 "os << \")\";\n";
87*fd3907ccSCallum Fare     }
88*fd3907ccSCallum Fare     OS << formatv(TAB_2 "break;\n" TAB_1 "}\n");
89*fd3907ccSCallum Fare   }
90*fd3907ccSCallum Fare 
91*fd3907ccSCallum Fare   OS << TAB_1 "default:\n" TAB_2 "os << \"unknown enumerator\";\n" TAB_2
92*fd3907ccSCallum Fare               "break;\n" TAB_1 "}\n";
93*fd3907ccSCallum Fare 
94*fd3907ccSCallum Fare   OS << "}\n";
95*fd3907ccSCallum Fare }
96*fd3907ccSCallum Fare 
97*fd3907ccSCallum Fare static void EmitResultPrint(raw_ostream &OS) {
98*fd3907ccSCallum Fare   OS << R""(
99*fd3907ccSCallum Fare inline std::ostream &operator<<(std::ostream &os,
100*fd3907ccSCallum Fare                                 const ol_error_struct_t *Err) {
101*fd3907ccSCallum Fare   if (Err == nullptr) {
102*fd3907ccSCallum Fare     os << "OL_SUCCESS";
103*fd3907ccSCallum Fare   } else {
104*fd3907ccSCallum Fare     os << Err->Code;
105*fd3907ccSCallum Fare   }
106*fd3907ccSCallum Fare   return os;
107*fd3907ccSCallum Fare }
108*fd3907ccSCallum Fare )"";
109*fd3907ccSCallum Fare }
110*fd3907ccSCallum Fare 
111*fd3907ccSCallum Fare static void EmitFunctionParamStructPrint(const FunctionRec &Func,
112*fd3907ccSCallum Fare                                          raw_ostream &OS) {
113*fd3907ccSCallum Fare   if (Func.getParams().size() == 0) {
114*fd3907ccSCallum Fare     return;
115*fd3907ccSCallum Fare   }
116*fd3907ccSCallum Fare 
117*fd3907ccSCallum Fare   OS << formatv(R"(
118*fd3907ccSCallum Fare inline std::ostream &operator<<(std::ostream &os, const struct {0} *params) {{
119*fd3907ccSCallum Fare )",
120*fd3907ccSCallum Fare                 Func.getParamStructName());
121*fd3907ccSCallum Fare 
122*fd3907ccSCallum Fare   for (const auto &Param : Func.getParams()) {
123*fd3907ccSCallum Fare     OS << formatv(TAB_1 "os << \".{0} = \";\n", Param.getName());
124*fd3907ccSCallum Fare     if (auto Range = Param.getRange()) {
125*fd3907ccSCallum Fare       OS << formatv(TAB_1 "os << \"{{\";\n");
126*fd3907ccSCallum Fare       OS << formatv(TAB_1 "for (size_t i = {0}; i < *params->p{1}; i++) {{\n",
127*fd3907ccSCallum Fare                     Range->first, Range->second);
128*fd3907ccSCallum Fare       OS << TAB_2 "if (i > 0) {\n";
129*fd3907ccSCallum Fare       OS << TAB_3 " os << \", \";\n";
130*fd3907ccSCallum Fare       OS << TAB_2 "}\n";
131*fd3907ccSCallum Fare       OS << formatv(TAB_2 "printPtr(os, (*params->p{0})[i]);\n",
132*fd3907ccSCallum Fare                     Param.getName());
133*fd3907ccSCallum Fare       OS << formatv(TAB_1 "}\n");
134*fd3907ccSCallum Fare       OS << formatv(TAB_1 "os << \"}\";\n");
135*fd3907ccSCallum Fare     } else if (auto TypeInfo = Param.getTypeInfo()) {
136*fd3907ccSCallum Fare       OS << formatv(
137*fd3907ccSCallum Fare           TAB_1
138*fd3907ccSCallum Fare           "printTagged(os, *params->p{0}, *params->p{1}, *params->p{2});\n",
139*fd3907ccSCallum Fare           Param.getName(), TypeInfo->first, TypeInfo->second);
140*fd3907ccSCallum Fare     } else if (Param.isPointerType() || Param.isHandleType()) {
141*fd3907ccSCallum Fare       OS << formatv(TAB_1 "printPtr(os, *params->p{0});\n", Param.getName());
142*fd3907ccSCallum Fare     } else {
143*fd3907ccSCallum Fare       OS << formatv(TAB_1 "os << *params->p{0};\n", Param.getName());
144*fd3907ccSCallum Fare     }
145*fd3907ccSCallum Fare     if (Param != Func.getParams().back()) {
146*fd3907ccSCallum Fare       OS << TAB_1 "os << \", \";\n";
147*fd3907ccSCallum Fare     }
148*fd3907ccSCallum Fare   }
149*fd3907ccSCallum Fare 
150*fd3907ccSCallum Fare   OS << TAB_1 "return os;\n}\n";
151*fd3907ccSCallum Fare }
152*fd3907ccSCallum Fare 
153*fd3907ccSCallum Fare void EmitOffloadPrintHeader(const RecordKeeper &Records, raw_ostream &OS) {
154*fd3907ccSCallum Fare   OS << GenericHeader;
155*fd3907ccSCallum Fare   OS << R"""(
156*fd3907ccSCallum Fare // Auto-generated file, do not manually edit.
157*fd3907ccSCallum Fare 
158*fd3907ccSCallum Fare #pragma once
159*fd3907ccSCallum Fare 
160*fd3907ccSCallum Fare #include <OffloadAPI.h>
161*fd3907ccSCallum Fare #include <ostream>
162*fd3907ccSCallum Fare 
163*fd3907ccSCallum Fare 
164*fd3907ccSCallum Fare template <typename T> inline ol_result_t printPtr(std::ostream &os, const T *ptr);
165*fd3907ccSCallum Fare template <typename T> inline void printTagged(std::ostream &os, const void *ptr, T value, size_t size);
166*fd3907ccSCallum Fare )""";
167*fd3907ccSCallum Fare 
168*fd3907ccSCallum Fare   // ==========
169*fd3907ccSCallum Fare   OS << "template <typename T> struct is_handle : std::false_type {};\n";
170*fd3907ccSCallum Fare   for (auto *R : Records.getAllDerivedDefinitions("Handle")) {
171*fd3907ccSCallum Fare     HandleRec H{R};
172*fd3907ccSCallum Fare     OS << formatv("template <> struct is_handle<{0}> : std::true_type {{};\n",
173*fd3907ccSCallum Fare                   H.getName());
174*fd3907ccSCallum Fare   }
175*fd3907ccSCallum Fare   OS << "template <typename T> inline constexpr bool is_handle_v = "
176*fd3907ccSCallum Fare         "is_handle<T>::value;\n";
177*fd3907ccSCallum Fare   // =========
178*fd3907ccSCallum Fare 
179*fd3907ccSCallum Fare   // Forward declare the operator<< overloads so their implementations can
180*fd3907ccSCallum Fare   // use each other.
181*fd3907ccSCallum Fare   OS << "\n";
182*fd3907ccSCallum Fare   for (auto *R : Records.getAllDerivedDefinitions("Enum")) {
183*fd3907ccSCallum Fare     OS << formatv(
184*fd3907ccSCallum Fare         "inline std::ostream &operator<<(std::ostream &os, enum {0} value);\n",
185*fd3907ccSCallum Fare         EnumRec{R}.getName());
186*fd3907ccSCallum Fare   }
187*fd3907ccSCallum Fare   OS << "\n";
188*fd3907ccSCallum Fare 
189*fd3907ccSCallum Fare   // Create definitions
190*fd3907ccSCallum Fare   for (auto *R : Records.getAllDerivedDefinitions("Enum")) {
191*fd3907ccSCallum Fare     EnumRec E{R};
192*fd3907ccSCallum Fare     ProcessEnum(E, OS);
193*fd3907ccSCallum Fare   }
194*fd3907ccSCallum Fare   EmitResultPrint(OS);
195*fd3907ccSCallum Fare 
196*fd3907ccSCallum Fare   // Emit print functions for the function param structs
197*fd3907ccSCallum Fare   for (auto *R : Records.getAllDerivedDefinitions("Function")) {
198*fd3907ccSCallum Fare     EmitFunctionParamStructPrint(FunctionRec{R}, OS);
199*fd3907ccSCallum Fare   }
200*fd3907ccSCallum Fare 
201*fd3907ccSCallum Fare   OS << R"""(
202*fd3907ccSCallum Fare ///////////////////////////////////////////////////////////////////////////////
203*fd3907ccSCallum Fare // @brief Print pointer value
204*fd3907ccSCallum Fare template <typename T> inline ol_result_t printPtr(std::ostream &os, const T *ptr) {
205*fd3907ccSCallum Fare     if (ptr == nullptr) {
206*fd3907ccSCallum Fare         os << "nullptr";
207*fd3907ccSCallum Fare     } else if constexpr (std::is_pointer_v<T>) {
208*fd3907ccSCallum Fare         os << (const void *)(ptr) << " (";
209*fd3907ccSCallum Fare         printPtr(os, *ptr);
210*fd3907ccSCallum Fare         os << ")";
211*fd3907ccSCallum Fare     } else if constexpr (std::is_void_v<T> || is_handle_v<T *>) {
212*fd3907ccSCallum Fare         os << (const void *)ptr;
213*fd3907ccSCallum Fare     } else if constexpr (std::is_same_v<std::remove_cv_t< T >, char>) {
214*fd3907ccSCallum Fare         os << (const void *)(ptr) << " (";
215*fd3907ccSCallum Fare         os << ptr;
216*fd3907ccSCallum Fare         os << ")";
217*fd3907ccSCallum Fare     } else {
218*fd3907ccSCallum Fare         os << (const void *)(ptr) << " (";
219*fd3907ccSCallum Fare         os << *ptr;
220*fd3907ccSCallum Fare         os << ")";
221*fd3907ccSCallum Fare     }
222*fd3907ccSCallum Fare 
223*fd3907ccSCallum Fare     return OL_SUCCESS;
224*fd3907ccSCallum Fare }
225*fd3907ccSCallum Fare   )""";
226*fd3907ccSCallum Fare }
227