xref: /llvm-project/offload/tools/offload-tblgen/EntryPointGen.cpp (revision fd3907ccb583df99e9c19d2fe84e4e7c52d75de9)
1*fd3907ccSCallum Fare //===- offload-tblgen/EntryPointGen.cpp - Tablegen backend for Offload ----===//
2*fd3907ccSCallum Fare //
3*fd3907ccSCallum Fare // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*fd3907ccSCallum Fare // See https://llvm.org/LICENSE.txt for license information.
5*fd3907ccSCallum Fare // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*fd3907ccSCallum Fare //
7*fd3907ccSCallum Fare //===----------------------------------------------------------------------===//
8*fd3907ccSCallum Fare //
9*fd3907ccSCallum Fare // This is a Tablegen backend that produces the actual entry points for the
10*fd3907ccSCallum Fare // Offload API. It serves as a place to integrate functionality like tracing
11*fd3907ccSCallum Fare // and validation before dispatching to the actual implementations.
12*fd3907ccSCallum Fare //===----------------------------------------------------------------------===//
13*fd3907ccSCallum Fare 
14*fd3907ccSCallum Fare #include "llvm/Support/FormatVariadic.h"
15*fd3907ccSCallum Fare #include "llvm/TableGen/Record.h"
16*fd3907ccSCallum Fare 
17*fd3907ccSCallum Fare #include "GenCommon.hpp"
18*fd3907ccSCallum Fare #include "RecordTypes.hpp"
19*fd3907ccSCallum Fare 
20*fd3907ccSCallum Fare using namespace llvm;
21*fd3907ccSCallum Fare using namespace offload::tblgen;
22*fd3907ccSCallum Fare 
23*fd3907ccSCallum Fare static void EmitValidationFunc(const FunctionRec &F, raw_ostream &OS) {
24*fd3907ccSCallum Fare   OS << CommentsHeader;
25*fd3907ccSCallum Fare   // Emit preamble
26*fd3907ccSCallum Fare   OS << formatv("{0}_impl_result_t {1}_val(\n  ", PrefixLower, F.getName());
27*fd3907ccSCallum Fare   // Emit arguments
28*fd3907ccSCallum Fare   std::string ParamNameList = "";
29*fd3907ccSCallum Fare   for (auto &Param : F.getParams()) {
30*fd3907ccSCallum Fare     OS << Param.getType() << " " << Param.getName();
31*fd3907ccSCallum Fare     if (Param != F.getParams().back()) {
32*fd3907ccSCallum Fare       OS << ", ";
33*fd3907ccSCallum Fare     }
34*fd3907ccSCallum Fare     ParamNameList += Param.getName().str() + ", ";
35*fd3907ccSCallum Fare   }
36*fd3907ccSCallum Fare   OS << ") {\n";
37*fd3907ccSCallum Fare 
38*fd3907ccSCallum Fare   OS << TAB_1 "if (true /*enableParameterValidation*/) {\n";
39*fd3907ccSCallum Fare   // Emit validation checks
40*fd3907ccSCallum Fare   for (const auto &Return : F.getReturns()) {
41*fd3907ccSCallum Fare     for (auto &Condition : Return.getConditions()) {
42*fd3907ccSCallum Fare       if (Condition.starts_with("`") && Condition.ends_with("`")) {
43*fd3907ccSCallum Fare         auto ConditionString = Condition.substr(1, Condition.size() - 2);
44*fd3907ccSCallum Fare         OS << formatv(TAB_2 "if ({0}) {{\n", ConditionString);
45*fd3907ccSCallum Fare         OS << formatv(TAB_3 "return {0};\n", Return.getValue());
46*fd3907ccSCallum Fare         OS << TAB_2 "}\n\n";
47*fd3907ccSCallum Fare       }
48*fd3907ccSCallum Fare     }
49*fd3907ccSCallum Fare   }
50*fd3907ccSCallum Fare   OS << TAB_1 "}\n\n";
51*fd3907ccSCallum Fare 
52*fd3907ccSCallum Fare   // Perform actual function call to the implementation
53*fd3907ccSCallum Fare   ParamNameList = ParamNameList.substr(0, ParamNameList.size() - 2);
54*fd3907ccSCallum Fare   OS << formatv(TAB_1 "return {0}_impl({1});\n\n", F.getName(), ParamNameList);
55*fd3907ccSCallum Fare   OS << "}\n";
56*fd3907ccSCallum Fare }
57*fd3907ccSCallum Fare 
58*fd3907ccSCallum Fare static void EmitEntryPointFunc(const FunctionRec &F, raw_ostream &OS) {
59*fd3907ccSCallum Fare   // Emit preamble
60*fd3907ccSCallum Fare   OS << formatv("{1}_APIEXPORT {0}_result_t {1}_APICALL {2}(\n  ", PrefixLower,
61*fd3907ccSCallum Fare                 PrefixUpper, F.getName());
62*fd3907ccSCallum Fare   // Emit arguments
63*fd3907ccSCallum Fare   std::string ParamNameList = "";
64*fd3907ccSCallum Fare   for (auto &Param : F.getParams()) {
65*fd3907ccSCallum Fare     OS << Param.getType() << " " << Param.getName();
66*fd3907ccSCallum Fare     if (Param != F.getParams().back()) {
67*fd3907ccSCallum Fare       OS << ", ";
68*fd3907ccSCallum Fare     }
69*fd3907ccSCallum Fare     ParamNameList += Param.getName().str() + ", ";
70*fd3907ccSCallum Fare   }
71*fd3907ccSCallum Fare   OS << ") {\n";
72*fd3907ccSCallum Fare 
73*fd3907ccSCallum Fare   // Emit pre-call prints
74*fd3907ccSCallum Fare   OS << TAB_1 "if (offloadConfig().TracingEnabled) {\n";
75*fd3907ccSCallum Fare   OS << formatv(TAB_2 "std::cout << \"---> {0}\";\n", F.getName());
76*fd3907ccSCallum Fare   OS << TAB_1 "}\n\n";
77*fd3907ccSCallum Fare 
78*fd3907ccSCallum Fare   // Perform actual function call to the validation wrapper
79*fd3907ccSCallum Fare   ParamNameList = ParamNameList.substr(0, ParamNameList.size() - 2);
80*fd3907ccSCallum Fare   OS << formatv(TAB_1 "{0}_result_t Result = {1}_val({2});\n\n", PrefixLower,
81*fd3907ccSCallum Fare                 F.getName(), ParamNameList);
82*fd3907ccSCallum Fare 
83*fd3907ccSCallum Fare   // Emit post-call prints
84*fd3907ccSCallum Fare   OS << TAB_1 "if (offloadConfig().TracingEnabled) {\n";
85*fd3907ccSCallum Fare   if (F.getParams().size() > 0) {
86*fd3907ccSCallum Fare     OS << formatv(TAB_2 "{0} Params = {{", F.getParamStructName());
87*fd3907ccSCallum Fare     for (const auto &Param : F.getParams()) {
88*fd3907ccSCallum Fare       OS << "&" << Param.getName();
89*fd3907ccSCallum Fare       if (Param != F.getParams().back()) {
90*fd3907ccSCallum Fare         OS << ", ";
91*fd3907ccSCallum Fare       }
92*fd3907ccSCallum Fare     }
93*fd3907ccSCallum Fare     OS << formatv("};\n");
94*fd3907ccSCallum Fare     OS << TAB_2 "std::cout << \"(\" << &Params << \")\";\n";
95*fd3907ccSCallum Fare   } else {
96*fd3907ccSCallum Fare     OS << TAB_2 "std::cout << \"()\";\n";
97*fd3907ccSCallum Fare   }
98*fd3907ccSCallum Fare   OS << TAB_2 "std::cout << \"-> \" << Result << \"\\n\";\n";
99*fd3907ccSCallum Fare   OS << TAB_2 "if (Result && Result->Details) {\n";
100*fd3907ccSCallum Fare   OS << TAB_3 "std::cout << \"     *Error Details* \" << Result->Details "
101*fd3907ccSCallum Fare               "<< \" \\n\";\n";
102*fd3907ccSCallum Fare   OS << TAB_2 "}\n";
103*fd3907ccSCallum Fare   OS << TAB_1 "}\n";
104*fd3907ccSCallum Fare 
105*fd3907ccSCallum Fare   OS << TAB_1 "return Result;\n";
106*fd3907ccSCallum Fare   OS << "}\n";
107*fd3907ccSCallum Fare }
108*fd3907ccSCallum Fare 
109*fd3907ccSCallum Fare static void EmitCodeLocWrapper(const FunctionRec &F, raw_ostream &OS) {
110*fd3907ccSCallum Fare   // Emit preamble
111*fd3907ccSCallum Fare   OS << formatv("{0}_result_t {1}WithCodeLoc(\n  ", PrefixLower, F.getName());
112*fd3907ccSCallum Fare   // Emit arguments
113*fd3907ccSCallum Fare   std::string ParamNameList = "";
114*fd3907ccSCallum Fare   for (auto &Param : F.getParams()) {
115*fd3907ccSCallum Fare     OS << Param.getType() << " " << Param.getName() << ", ";
116*fd3907ccSCallum Fare     ParamNameList += Param.getName().str();
117*fd3907ccSCallum Fare     if (Param != F.getParams().back()) {
118*fd3907ccSCallum Fare       ParamNameList += ", ";
119*fd3907ccSCallum Fare     }
120*fd3907ccSCallum Fare   }
121*fd3907ccSCallum Fare   OS << "ol_code_location_t *CodeLocation";
122*fd3907ccSCallum Fare   OS << ") {\n";
123*fd3907ccSCallum Fare   OS << TAB_1 "currentCodeLocation() = CodeLocation;\n";
124*fd3907ccSCallum Fare   OS << formatv(TAB_1 "{0}_result_t Result = {1}({2});\n\n", PrefixLower,
125*fd3907ccSCallum Fare                 F.getName(), ParamNameList);
126*fd3907ccSCallum Fare   OS << TAB_1 "currentCodeLocation() = nullptr;\n";
127*fd3907ccSCallum Fare   OS << TAB_1 "return Result;\n";
128*fd3907ccSCallum Fare   OS << "}\n";
129*fd3907ccSCallum Fare }
130*fd3907ccSCallum Fare 
131*fd3907ccSCallum Fare void EmitOffloadEntryPoints(const RecordKeeper &Records, raw_ostream &OS) {
132*fd3907ccSCallum Fare   OS << GenericHeader;
133*fd3907ccSCallum Fare   for (auto *R : Records.getAllDerivedDefinitions("Function")) {
134*fd3907ccSCallum Fare     EmitValidationFunc(FunctionRec{R}, OS);
135*fd3907ccSCallum Fare     EmitEntryPointFunc(FunctionRec{R}, OS);
136*fd3907ccSCallum Fare     EmitCodeLocWrapper(FunctionRec{R}, OS);
137*fd3907ccSCallum Fare   }
138*fd3907ccSCallum Fare }
139