xref: /llvm-project/offload/unittests/OffloadAPI/common/Environment.cpp (revision fd3907ccb583df99e9c19d2fe84e4e7c52d75de9)
1 //===------- Offload API tests - gtest environment ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "Environment.hpp"
10 #include "Fixtures.hpp"
11 #include "llvm/Support/CommandLine.h"
12 #include <OffloadAPI.h>
13 
14 using namespace llvm;
15 
16 // Wrapper so we don't have to constantly init and shutdown Offload in every
17 // test, while having sensible lifetime for the platform environment
18 struct OffloadInitWrapper {
19   OffloadInitWrapper() { olInit(); }
20   ~OffloadInitWrapper() { olShutDown(); }
21 };
22 static OffloadInitWrapper Wrapper{};
23 
// Command-line option "platform". TestEnvironment::getPlatform() compares this
// value against the empty string to decide whether a platform was requested;
// when one was, only a platform whose printed name equals the value is used.
static cl::opt<std::string>
    SelectedPlatform("platform", cl::desc("Only test the specified platform"),
                     cl::value_desc("platform"));
27 
28 std::ostream &operator<<(std::ostream &Out,
29                          const ol_platform_handle_t &Platform) {
30   size_t Size;
31   olGetPlatformInfoSize(Platform, OL_PLATFORM_INFO_NAME, &Size);
32   std::vector<char> Name(Size);
33   olGetPlatformInfo(Platform, OL_PLATFORM_INFO_NAME, Size, Name.data());
34   Out << Name.data();
35   return Out;
36 }
37 
38 std::ostream &operator<<(std::ostream &Out,
39                          const std::vector<ol_platform_handle_t> &Platforms) {
40   for (auto Platform : Platforms) {
41     Out << "\n  * \"" << Platform << "\"";
42   }
43   return Out;
44 }
45 
46 const std::vector<ol_platform_handle_t> &TestEnvironment::getPlatforms() {
47   static std::vector<ol_platform_handle_t> Platforms{};
48 
49   if (Platforms.empty()) {
50     uint32_t PlatformCount = 0;
51     olGetPlatformCount(&PlatformCount);
52     if (PlatformCount > 0) {
53       Platforms.resize(PlatformCount);
54       olGetPlatform(PlatformCount, Platforms.data());
55     }
56   }
57 
58   return Platforms;
59 }
60 
61 // Get a single platform, which may be selected by the user.
62 ol_platform_handle_t TestEnvironment::getPlatform() {
63   static ol_platform_handle_t Platform = nullptr;
64   const auto &Platforms = getPlatforms();
65 
66   if (!Platform) {
67     if (SelectedPlatform != "") {
68       for (const auto CandidatePlatform : Platforms) {
69         std::stringstream PlatformName;
70         PlatformName << CandidatePlatform;
71         if (SelectedPlatform == PlatformName.str()) {
72           Platform = CandidatePlatform;
73           return Platform;
74         }
75       }
76       std::cout << "No platform found with the name \"" << SelectedPlatform
77                 << "\". Choose from:" << Platforms << "\n";
78       std::exit(1);
79     } else {
80       // Pick a single platform. We prefer one that has available devices, but
81       // just pick the first initially in case none have any devices.
82       Platform = Platforms[0];
83       for (auto CandidatePlatform : Platforms) {
84         uint32_t NumDevices = 0;
85         if (olGetDeviceCount(CandidatePlatform, &NumDevices) == OL_SUCCESS) {
86           if (NumDevices > 0) {
87             Platform = CandidatePlatform;
88             break;
89           }
90         }
91       }
92     }
93   }
94 
95   return Platform;
96 }
97