1*fd3907ccSCallum Fare //===- offload-tblgen/EntryPointGen.cpp - Tablegen backend for Offload ----===// 2*fd3907ccSCallum Fare // 3*fd3907ccSCallum Fare // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*fd3907ccSCallum Fare // See https://llvm.org/LICENSE.txt for license information. 5*fd3907ccSCallum Fare // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*fd3907ccSCallum Fare // 7*fd3907ccSCallum Fare //===----------------------------------------------------------------------===// 8*fd3907ccSCallum Fare // 9*fd3907ccSCallum Fare // This is a Tablegen backend that produces the actual entry points for the 10*fd3907ccSCallum Fare // Offload API. It serves as a place to integrate functionality like tracing 11*fd3907ccSCallum Fare // and validation before dispatching to the actual implementations. 12*fd3907ccSCallum Fare //===----------------------------------------------------------------------===// 13*fd3907ccSCallum Fare 14*fd3907ccSCallum Fare #include "llvm/Support/FormatVariadic.h" 15*fd3907ccSCallum Fare #include "llvm/TableGen/Record.h" 16*fd3907ccSCallum Fare 17*fd3907ccSCallum Fare #include "GenCommon.hpp" 18*fd3907ccSCallum Fare #include "RecordTypes.hpp" 19*fd3907ccSCallum Fare 20*fd3907ccSCallum Fare using namespace llvm; 21*fd3907ccSCallum Fare using namespace offload::tblgen; 22*fd3907ccSCallum Fare 23*fd3907ccSCallum Fare static void EmitValidationFunc(const FunctionRec &F, raw_ostream &OS) { 24*fd3907ccSCallum Fare OS << CommentsHeader; 25*fd3907ccSCallum Fare // Emit preamble 26*fd3907ccSCallum Fare OS << formatv("{0}_impl_result_t {1}_val(\n ", PrefixLower, F.getName()); 27*fd3907ccSCallum Fare // Emit arguments 28*fd3907ccSCallum Fare std::string ParamNameList = ""; 29*fd3907ccSCallum Fare for (auto &Param : F.getParams()) { 30*fd3907ccSCallum Fare OS << Param.getType() << " " << Param.getName(); 31*fd3907ccSCallum Fare if (Param != F.getParams().back()) { 32*fd3907ccSCallum Fare OS << ", "; 33*fd3907ccSCallum Fare } 34*fd3907ccSCallum Fare ParamNameList += Param.getName().str() + ", "; 35*fd3907ccSCallum Fare } 36*fd3907ccSCallum Fare OS << ") {\n"; 37*fd3907ccSCallum Fare 38*fd3907ccSCallum Fare OS << TAB_1 "if (true /*enableParameterValidation*/) {\n"; 39*fd3907ccSCallum Fare // Emit validation checks 40*fd3907ccSCallum Fare for (const auto &Return : F.getReturns()) { 41*fd3907ccSCallum Fare for (auto &Condition : Return.getConditions()) { 42*fd3907ccSCallum Fare if (Condition.starts_with("`") && Condition.ends_with("`")) { 43*fd3907ccSCallum Fare auto ConditionString = Condition.substr(1, Condition.size() - 2); 44*fd3907ccSCallum Fare OS << formatv(TAB_2 "if ({0}) {{\n", ConditionString); 45*fd3907ccSCallum Fare OS << formatv(TAB_3 "return {0};\n", Return.getValue()); 46*fd3907ccSCallum Fare OS << TAB_2 "}\n\n"; 47*fd3907ccSCallum Fare } 48*fd3907ccSCallum Fare } 49*fd3907ccSCallum Fare } 50*fd3907ccSCallum Fare OS << TAB_1 "}\n\n"; 51*fd3907ccSCallum Fare 52*fd3907ccSCallum Fare // Perform actual function call to the implementation 53*fd3907ccSCallum Fare ParamNameList = ParamNameList.substr(0, ParamNameList.size() - 2); 54*fd3907ccSCallum Fare OS << formatv(TAB_1 "return {0}_impl({1});\n\n", F.getName(), ParamNameList); 55*fd3907ccSCallum Fare OS << "}\n"; 56*fd3907ccSCallum Fare } 57*fd3907ccSCallum Fare 58*fd3907ccSCallum Fare static void EmitEntryPointFunc(const FunctionRec &F, raw_ostream &OS) { 59*fd3907ccSCallum Fare // Emit preamble 60*fd3907ccSCallum Fare OS << formatv("{1}_APIEXPORT {0}_result_t {1}_APICALL {2}(\n ", PrefixLower, 61*fd3907ccSCallum Fare PrefixUpper, F.getName()); 62*fd3907ccSCallum Fare // Emit arguments 63*fd3907ccSCallum Fare std::string ParamNameList = ""; 64*fd3907ccSCallum Fare for (auto &Param : F.getParams()) { 65*fd3907ccSCallum Fare OS << Param.getType() << " " << Param.getName(); 66*fd3907ccSCallum Fare if (Param != F.getParams().back()) { 67*fd3907ccSCallum Fare OS << ", "; 68*fd3907ccSCallum Fare } 69*fd3907ccSCallum Fare ParamNameList += Param.getName().str() + ", "; 70*fd3907ccSCallum Fare } 71*fd3907ccSCallum Fare OS << ") {\n"; 72*fd3907ccSCallum Fare 73*fd3907ccSCallum Fare // Emit pre-call prints 74*fd3907ccSCallum Fare OS << TAB_1 "if (offloadConfig().TracingEnabled) {\n"; 75*fd3907ccSCallum Fare OS << formatv(TAB_2 "std::cout << \"---> {0}\";\n", F.getName()); 76*fd3907ccSCallum Fare OS << TAB_1 "}\n\n"; 77*fd3907ccSCallum Fare 78*fd3907ccSCallum Fare // Perform actual function call to the validation wrapper 79*fd3907ccSCallum Fare ParamNameList = ParamNameList.substr(0, ParamNameList.size() - 2); 80*fd3907ccSCallum Fare OS << formatv(TAB_1 "{0}_result_t Result = {1}_val({2});\n\n", PrefixLower, 81*fd3907ccSCallum Fare F.getName(), ParamNameList); 82*fd3907ccSCallum Fare 83*fd3907ccSCallum Fare // Emit post-call prints 84*fd3907ccSCallum Fare OS << TAB_1 "if (offloadConfig().TracingEnabled) {\n"; 85*fd3907ccSCallum Fare if (F.getParams().size() > 0) { 86*fd3907ccSCallum Fare OS << formatv(TAB_2 "{0} Params = {{", F.getParamStructName()); 87*fd3907ccSCallum Fare for (const auto &Param : F.getParams()) { 88*fd3907ccSCallum Fare OS << "&" << Param.getName(); 89*fd3907ccSCallum Fare if (Param != F.getParams().back()) { 90*fd3907ccSCallum Fare OS << ", "; 91*fd3907ccSCallum Fare } 92*fd3907ccSCallum Fare } 93*fd3907ccSCallum Fare OS << formatv("};\n"); 94*fd3907ccSCallum Fare OS << TAB_2 "std::cout << \"(\" << &Params << \")\";\n"; 95*fd3907ccSCallum Fare } else { 96*fd3907ccSCallum Fare OS << TAB_2 "std::cout << \"()\";\n"; 97*fd3907ccSCallum Fare } 98*fd3907ccSCallum Fare OS << TAB_2 "std::cout << \"-> \" << Result << \"\\n\";\n"; 99*fd3907ccSCallum Fare OS << TAB_2 "if (Result && Result->Details) {\n"; 100*fd3907ccSCallum Fare OS << TAB_3 "std::cout << \" *Error Details* \" << Result->Details " 101*fd3907ccSCallum Fare "<< \" \\n\";\n"; 102*fd3907ccSCallum Fare OS << TAB_2 "}\n"; 103*fd3907ccSCallum Fare OS << TAB_1 "}\n"; 104*fd3907ccSCallum Fare 105*fd3907ccSCallum Fare OS << TAB_1 "return Result;\n"; 106*fd3907ccSCallum Fare OS << "}\n"; 107*fd3907ccSCallum Fare } 108*fd3907ccSCallum Fare 109*fd3907ccSCallum Fare static void EmitCodeLocWrapper(const FunctionRec &F, raw_ostream &OS) { 110*fd3907ccSCallum Fare // Emit preamble 111*fd3907ccSCallum Fare OS << formatv("{0}_result_t {1}WithCodeLoc(\n ", PrefixLower, F.getName()); 112*fd3907ccSCallum Fare // Emit arguments 113*fd3907ccSCallum Fare std::string ParamNameList = ""; 114*fd3907ccSCallum Fare for (auto &Param : F.getParams()) { 115*fd3907ccSCallum Fare OS << Param.getType() << " " << Param.getName() << ", "; 116*fd3907ccSCallum Fare ParamNameList += Param.getName().str(); 117*fd3907ccSCallum Fare if (Param != F.getParams().back()) { 118*fd3907ccSCallum Fare ParamNameList += ", "; 119*fd3907ccSCallum Fare } 120*fd3907ccSCallum Fare } 121*fd3907ccSCallum Fare OS << "ol_code_location_t *CodeLocation"; 122*fd3907ccSCallum Fare OS << ") {\n"; 123*fd3907ccSCallum Fare OS << TAB_1 "currentCodeLocation() = CodeLocation;\n"; 124*fd3907ccSCallum Fare OS << formatv(TAB_1 "{0}_result_t Result = {1}({2});\n\n", PrefixLower, 125*fd3907ccSCallum Fare F.getName(), ParamNameList); 126*fd3907ccSCallum Fare OS << TAB_1 "currentCodeLocation() = nullptr;\n"; 127*fd3907ccSCallum Fare OS << TAB_1 "return Result;\n"; 128*fd3907ccSCallum Fare OS << "}\n"; 129*fd3907ccSCallum Fare } 130*fd3907ccSCallum Fare 131*fd3907ccSCallum Fare void EmitOffloadEntryPoints(const RecordKeeper &Records, raw_ostream &OS) { 132*fd3907ccSCallum Fare OS << GenericHeader; 133*fd3907ccSCallum Fare for (auto *R : Records.getAllDerivedDefinitions("Function")) { 134*fd3907ccSCallum Fare EmitValidationFunc(FunctionRec{R}, OS); 135*fd3907ccSCallum Fare EmitEntryPointFunc(FunctionRec{R}, OS); 136*fd3907ccSCallum Fare EmitCodeLocWrapper(FunctionRec{R}, OS); 137*fd3907ccSCallum Fare } 138*fd3907ccSCallum Fare } 139