1 //===- offload-tblgen/EntryPointGen.cpp - Tablegen backend for Offload ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This is a Tablegen backend that produces the actual entry points for the 10 // Offload API. It serves as a place to integrate functionality like tracing 11 // and validation before dispatching to the actual implementations. 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/Support/FormatVariadic.h" 15 #include "llvm/TableGen/Record.h" 16 17 #include "GenCommon.hpp" 18 #include "RecordTypes.hpp" 19 20 using namespace llvm; 21 using namespace offload::tblgen; 22 23 static void EmitValidationFunc(const FunctionRec &F, raw_ostream &OS) { 24 OS << CommentsHeader; 25 // Emit preamble 26 OS << formatv("{0}_impl_result_t {1}_val(\n ", PrefixLower, F.getName()); 27 // Emit arguments 28 std::string ParamNameList = ""; 29 for (auto &Param : F.getParams()) { 30 OS << Param.getType() << " " << Param.getName(); 31 if (Param != F.getParams().back()) { 32 OS << ", "; 33 } 34 ParamNameList += Param.getName().str() + ", "; 35 } 36 OS << ") {\n"; 37 38 OS << TAB_1 "if (true /*enableParameterValidation*/) {\n"; 39 // Emit validation checks 40 for (const auto &Return : F.getReturns()) { 41 for (auto &Condition : Return.getConditions()) { 42 if (Condition.starts_with("`") && Condition.ends_with("`")) { 43 auto ConditionString = Condition.substr(1, Condition.size() - 2); 44 OS << formatv(TAB_2 "if ({0}) {{\n", ConditionString); 45 OS << formatv(TAB_3 "return {0};\n", Return.getValue()); 46 OS << TAB_2 "}\n\n"; 47 } 48 } 49 } 50 OS << TAB_1 "}\n\n"; 51 52 // Perform actual function call to the implementation 53 ParamNameList = ParamNameList.substr(0, ParamNameList.size() - 2); 54 OS << formatv(TAB_1 "return {0}_impl({1});\n\n", F.getName(), ParamNameList); 55 OS << "}\n"; 56 } 57 58 static void EmitEntryPointFunc(const FunctionRec &F, raw_ostream &OS) { 59 // Emit preamble 60 OS << formatv("{1}_APIEXPORT {0}_result_t {1}_APICALL {2}(\n ", PrefixLower, 61 PrefixUpper, F.getName()); 62 // Emit arguments 63 std::string ParamNameList = ""; 64 for (auto &Param : F.getParams()) { 65 OS << Param.getType() << " " << Param.getName(); 66 if (Param != F.getParams().back()) { 67 OS << ", "; 68 } 69 ParamNameList += Param.getName().str() + ", "; 70 } 71 OS << ") {\n"; 72 73 // Emit pre-call prints 74 OS << TAB_1 "if (offloadConfig().TracingEnabled) {\n"; 75 OS << formatv(TAB_2 "std::cout << \"---> {0}\";\n", F.getName()); 76 OS << TAB_1 "}\n\n"; 77 78 // Perform actual function call to the validation wrapper 79 ParamNameList = ParamNameList.substr(0, ParamNameList.size() - 2); 80 OS << formatv(TAB_1 "{0}_result_t Result = {1}_val({2});\n\n", PrefixLower, 81 F.getName(), ParamNameList); 82 83 // Emit post-call prints 84 OS << TAB_1 "if (offloadConfig().TracingEnabled) {\n"; 85 if (F.getParams().size() > 0) { 86 OS << formatv(TAB_2 "{0} Params = {{", F.getParamStructName()); 87 for (const auto &Param : F.getParams()) { 88 OS << "&" << Param.getName(); 89 if (Param != F.getParams().back()) { 90 OS << ", "; 91 } 92 } 93 OS << formatv("};\n"); 94 OS << TAB_2 "std::cout << \"(\" << &Params << \")\";\n"; 95 } else { 96 OS << TAB_2 "std::cout << \"()\";\n"; 97 } 98 OS << TAB_2 "std::cout << \"-> \" << Result << \"\\n\";\n"; 99 OS << TAB_2 "if (Result && Result->Details) {\n"; 100 OS << TAB_3 "std::cout << \" *Error Details* \" << Result->Details " 101 "<< \" \\n\";\n"; 102 OS << TAB_2 "}\n"; 103 OS << TAB_1 "}\n"; 104 105 OS << TAB_1 "return Result;\n"; 106 OS << "}\n"; 107 } 108 109 static void EmitCodeLocWrapper(const FunctionRec &F, raw_ostream &OS) { 110 // Emit preamble 111 OS << formatv("{0}_result_t {1}WithCodeLoc(\n ", PrefixLower, F.getName()); 112 // Emit arguments 113 std::string ParamNameList = ""; 114 for (auto &Param : F.getParams()) { 115 OS << Param.getType() << " " << Param.getName() << ", "; 116 ParamNameList += Param.getName().str(); 117 if (Param != F.getParams().back()) { 118 ParamNameList += ", "; 119 } 120 } 121 OS << "ol_code_location_t *CodeLocation"; 122 OS << ") {\n"; 123 OS << TAB_1 "currentCodeLocation() = CodeLocation;\n"; 124 OS << formatv(TAB_1 "{0}_result_t Result = {1}({2});\n\n", PrefixLower, 125 F.getName(), ParamNameList); 126 OS << TAB_1 "currentCodeLocation() = nullptr;\n"; 127 OS << TAB_1 "return Result;\n"; 128 OS << "}\n"; 129 } 130 131 void EmitOffloadEntryPoints(const RecordKeeper &Records, raw_ostream &OS) { 132 OS << GenericHeader; 133 for (auto *R : Records.getAllDerivedDefinitions("Function")) { 134 EmitValidationFunc(FunctionRec{R}, OS); 135 EmitEntryPointFunc(FunctionRec{R}, OS); 136 EmitCodeLocWrapper(FunctionRec{R}, OS); 137 } 138 } 139