xref: /llvm-project/offload/liboffload/include/OffloadImpl.hpp (revision fd3907ccb583df99e9c19d2fe84e4e7c52d75de9)
//===- OffloadImpl.hpp - Implementation helpers for the Offload library ---===//
2*fd3907ccSCallum Fare //
3*fd3907ccSCallum Fare // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*fd3907ccSCallum Fare // See https://llvm.org/LICENSE.txt for license information.
5*fd3907ccSCallum Fare // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*fd3907ccSCallum Fare //
7*fd3907ccSCallum Fare //===----------------------------------------------------------------------===//
8*fd3907ccSCallum Fare #pragma once
9*fd3907ccSCallum Fare 
#include <OffloadAPI.h>
#include <cassert>
#include <cstring>
#include <iostream>
#include <memory>
#include <optional>
#include <set>
#include <string>
#include <string_view>
#include <unordered_set>
#include <vector>

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
22*fd3907ccSCallum Fare 
/// Configuration options for the Offload library implementation.
struct OffloadConfig {
  /// Tracing toggle; a default-constructed config has it switched off.
  bool TracingEnabled{false};
};

/// Returns a reference to the library's OffloadConfig object. The definition
/// lives out of line (not in this header).
auto offloadConfig() -> OffloadConfig &;
28*fd3907ccSCallum Fare 
29*fd3907ccSCallum Fare // Use the StringSet container to efficiently deduplicate repeated error
30*fd3907ccSCallum Fare // strings (e.g. if the same error is hit constantly in a long running program)
31*fd3907ccSCallum Fare llvm::StringSet<> &errorStrs();
32*fd3907ccSCallum Fare 
33*fd3907ccSCallum Fare // Use an unordered_set to avoid duplicates of error structs themselves.
34*fd3907ccSCallum Fare // We cannot store the structs directly as returned pointers to them must always
35*fd3907ccSCallum Fare // be valid, and a rehash of the set may invalidate them. This requires
36*fd3907ccSCallum Fare // custom hash and equal_to function objects.
37*fd3907ccSCallum Fare using ErrPtrT = std::unique_ptr<ol_error_struct_t>;
38*fd3907ccSCallum Fare struct ErrPtrEqual {
39*fd3907ccSCallum Fare   bool operator()(const ErrPtrT &lhs, const ErrPtrT &rhs) const {
40*fd3907ccSCallum Fare     if (!lhs && !rhs) {
41*fd3907ccSCallum Fare       return true;
42*fd3907ccSCallum Fare     }
43*fd3907ccSCallum Fare     if (!lhs || !rhs) {
44*fd3907ccSCallum Fare       return false;
45*fd3907ccSCallum Fare     }
46*fd3907ccSCallum Fare 
47*fd3907ccSCallum Fare     bool StrsEqual = false;
48*fd3907ccSCallum Fare     if (lhs->Details == NULL && rhs->Details == NULL) {
49*fd3907ccSCallum Fare       StrsEqual = true;
50*fd3907ccSCallum Fare     } else if (lhs->Details != NULL && rhs->Details != NULL) {
51*fd3907ccSCallum Fare       StrsEqual = (std::strcmp(lhs->Details, rhs->Details) == 0);
52*fd3907ccSCallum Fare     }
53*fd3907ccSCallum Fare     return (lhs->Code == rhs->Code) && StrsEqual;
54*fd3907ccSCallum Fare   }
55*fd3907ccSCallum Fare };
56*fd3907ccSCallum Fare struct ErrPtrHash {
57*fd3907ccSCallum Fare   size_t operator()(const ErrPtrT &e) const {
58*fd3907ccSCallum Fare     if (!e) {
59*fd3907ccSCallum Fare       // We shouldn't store empty errors (i.e. success), but just in case
60*fd3907ccSCallum Fare       return 0lu;
61*fd3907ccSCallum Fare     } else {
62*fd3907ccSCallum Fare       return std::hash<int>{}(e->Code);
63*fd3907ccSCallum Fare     }
64*fd3907ccSCallum Fare   }
65*fd3907ccSCallum Fare };
66*fd3907ccSCallum Fare using ErrSetT = std::unordered_set<ErrPtrT, ErrPtrHash, ErrPtrEqual>;
67*fd3907ccSCallum Fare ErrSetT &errors();
68*fd3907ccSCallum Fare 
69*fd3907ccSCallum Fare struct ol_impl_result_t {
70*fd3907ccSCallum Fare   ol_impl_result_t(std::nullptr_t) : Result(OL_SUCCESS) {}
71*fd3907ccSCallum Fare   ol_impl_result_t(ol_errc_t Code) {
72*fd3907ccSCallum Fare     if (Code == OL_ERRC_SUCCESS) {
73*fd3907ccSCallum Fare       Result = nullptr;
74*fd3907ccSCallum Fare     } else {
75*fd3907ccSCallum Fare       auto Err = std::unique_ptr<ol_error_struct_t>(
76*fd3907ccSCallum Fare           new ol_error_struct_t{Code, nullptr});
77*fd3907ccSCallum Fare       Result = errors().emplace(std::move(Err)).first->get();
78*fd3907ccSCallum Fare     }
79*fd3907ccSCallum Fare   }
80*fd3907ccSCallum Fare 
81*fd3907ccSCallum Fare   ol_impl_result_t(ol_errc_t Code, llvm::StringRef Details) {
82*fd3907ccSCallum Fare     assert(Code != OL_ERRC_SUCCESS);
83*fd3907ccSCallum Fare     Result = nullptr;
84*fd3907ccSCallum Fare     auto DetailsStr = errorStrs().insert(Details).first->getKeyData();
85*fd3907ccSCallum Fare     auto Err = std::unique_ptr<ol_error_struct_t>(
86*fd3907ccSCallum Fare         new ol_error_struct_t{Code, DetailsStr});
87*fd3907ccSCallum Fare     Result = errors().emplace(std::move(Err)).first->get();
88*fd3907ccSCallum Fare   }
89*fd3907ccSCallum Fare 
90*fd3907ccSCallum Fare   operator ol_result_t() { return Result; }
91*fd3907ccSCallum Fare 
92*fd3907ccSCallum Fare private:
93*fd3907ccSCallum Fare   ol_result_t Result;
94*fd3907ccSCallum Fare };
95