xref: /llvm-project/offload/tools/offload-tblgen/EntryPointGen.cpp (revision fd3907ccb583df99e9c19d2fe84e4e7c52d75de9)
1 //===- offload-tblgen/EntryPointGen.cpp - Tablegen backend for Offload ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is a Tablegen backend that produces the actual entry points for the
10 // Offload API. It serves as a place to integrate functionality like tracing
11 // and validation before dispatching to the actual implementations.
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/Support/FormatVariadic.h"
15 #include "llvm/TableGen/Record.h"
16 
17 #include "GenCommon.hpp"
18 #include "RecordTypes.hpp"
19 
20 using namespace llvm;
21 using namespace offload::tblgen;
22 
23 static void EmitValidationFunc(const FunctionRec &F, raw_ostream &OS) {
24   OS << CommentsHeader;
25   // Emit preamble
26   OS << formatv("{0}_impl_result_t {1}_val(\n  ", PrefixLower, F.getName());
27   // Emit arguments
28   std::string ParamNameList = "";
29   for (auto &Param : F.getParams()) {
30     OS << Param.getType() << " " << Param.getName();
31     if (Param != F.getParams().back()) {
32       OS << ", ";
33     }
34     ParamNameList += Param.getName().str() + ", ";
35   }
36   OS << ") {\n";
37 
38   OS << TAB_1 "if (true /*enableParameterValidation*/) {\n";
39   // Emit validation checks
40   for (const auto &Return : F.getReturns()) {
41     for (auto &Condition : Return.getConditions()) {
42       if (Condition.starts_with("`") && Condition.ends_with("`")) {
43         auto ConditionString = Condition.substr(1, Condition.size() - 2);
44         OS << formatv(TAB_2 "if ({0}) {{\n", ConditionString);
45         OS << formatv(TAB_3 "return {0};\n", Return.getValue());
46         OS << TAB_2 "}\n\n";
47       }
48     }
49   }
50   OS << TAB_1 "}\n\n";
51 
52   // Perform actual function call to the implementation
53   ParamNameList = ParamNameList.substr(0, ParamNameList.size() - 2);
54   OS << formatv(TAB_1 "return {0}_impl({1});\n\n", F.getName(), ParamNameList);
55   OS << "}\n";
56 }
57 
58 static void EmitEntryPointFunc(const FunctionRec &F, raw_ostream &OS) {
59   // Emit preamble
60   OS << formatv("{1}_APIEXPORT {0}_result_t {1}_APICALL {2}(\n  ", PrefixLower,
61                 PrefixUpper, F.getName());
62   // Emit arguments
63   std::string ParamNameList = "";
64   for (auto &Param : F.getParams()) {
65     OS << Param.getType() << " " << Param.getName();
66     if (Param != F.getParams().back()) {
67       OS << ", ";
68     }
69     ParamNameList += Param.getName().str() + ", ";
70   }
71   OS << ") {\n";
72 
73   // Emit pre-call prints
74   OS << TAB_1 "if (offloadConfig().TracingEnabled) {\n";
75   OS << formatv(TAB_2 "std::cout << \"---> {0}\";\n", F.getName());
76   OS << TAB_1 "}\n\n";
77 
78   // Perform actual function call to the validation wrapper
79   ParamNameList = ParamNameList.substr(0, ParamNameList.size() - 2);
80   OS << formatv(TAB_1 "{0}_result_t Result = {1}_val({2});\n\n", PrefixLower,
81                 F.getName(), ParamNameList);
82 
83   // Emit post-call prints
84   OS << TAB_1 "if (offloadConfig().TracingEnabled) {\n";
85   if (F.getParams().size() > 0) {
86     OS << formatv(TAB_2 "{0} Params = {{", F.getParamStructName());
87     for (const auto &Param : F.getParams()) {
88       OS << "&" << Param.getName();
89       if (Param != F.getParams().back()) {
90         OS << ", ";
91       }
92     }
93     OS << formatv("};\n");
94     OS << TAB_2 "std::cout << \"(\" << &Params << \")\";\n";
95   } else {
96     OS << TAB_2 "std::cout << \"()\";\n";
97   }
98   OS << TAB_2 "std::cout << \"-> \" << Result << \"\\n\";\n";
99   OS << TAB_2 "if (Result && Result->Details) {\n";
100   OS << TAB_3 "std::cout << \"     *Error Details* \" << Result->Details "
101               "<< \" \\n\";\n";
102   OS << TAB_2 "}\n";
103   OS << TAB_1 "}\n";
104 
105   OS << TAB_1 "return Result;\n";
106   OS << "}\n";
107 }
108 
109 static void EmitCodeLocWrapper(const FunctionRec &F, raw_ostream &OS) {
110   // Emit preamble
111   OS << formatv("{0}_result_t {1}WithCodeLoc(\n  ", PrefixLower, F.getName());
112   // Emit arguments
113   std::string ParamNameList = "";
114   for (auto &Param : F.getParams()) {
115     OS << Param.getType() << " " << Param.getName() << ", ";
116     ParamNameList += Param.getName().str();
117     if (Param != F.getParams().back()) {
118       ParamNameList += ", ";
119     }
120   }
121   OS << "ol_code_location_t *CodeLocation";
122   OS << ") {\n";
123   OS << TAB_1 "currentCodeLocation() = CodeLocation;\n";
124   OS << formatv(TAB_1 "{0}_result_t Result = {1}({2});\n\n", PrefixLower,
125                 F.getName(), ParamNameList);
126   OS << TAB_1 "currentCodeLocation() = nullptr;\n";
127   OS << TAB_1 "return Result;\n";
128   OS << "}\n";
129 }
130 
131 void EmitOffloadEntryPoints(const RecordKeeper &Records, raw_ostream &OS) {
132   OS << GenericHeader;
133   for (auto *R : Records.getAllDerivedDefinitions("Function")) {
134     EmitValidationFunc(FunctionRec{R}, OS);
135     EmitEntryPointFunc(FunctionRec{R}, OS);
136     EmitCodeLocWrapper(FunctionRec{R}, OS);
137   }
138 }
139