xref: /llvm-project/clang/unittests/Driver/DXCModeTest.cpp (revision 6d8901488f160dd92aea5b98fcc21c7fa7c1cbe6)
1 //===- unittests/Driver/DXCModeTest.cpp --- DXC Mode tests ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Unit tests for driver DXCMode.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "clang/Basic/DiagnosticIDs.h"
#include "clang/Basic/DiagnosticOptions.h"
#include "clang/Basic/LLVM.h"
#include "clang/Basic/TargetOptions.h"
#include "clang/Driver/Compilation.h"
#include "clang/Driver/Driver.h"
#include "clang/Driver/ToolChain.h"
#include "clang/Frontend/CompilerInstance.h"
#include "llvm/Support/VirtualFileSystem.h"
#include "llvm/Support/raw_ostream.h"
#include "gtest/gtest.h"
#include <memory>
#include <string>
25 
26 #include "SimpleDiagnosticConsumer.h"
27 
28 using namespace clang;
29 using namespace clang::driver;
30 
31 static void validateTargetProfile(
32     StringRef TargetProfile, StringRef ExpectTriple,
33     IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem> &InMemoryFileSystem,
34     DiagnosticsEngine &Diags) {
35   Driver TheDriver("/bin/clang", "", Diags, "", InMemoryFileSystem);
36   std::unique_ptr<Compilation> C{TheDriver.BuildCompilation(
37       {"clang", "--driver-mode=dxc", TargetProfile.data(), "foo.hlsl", "-Vd"})};
38   EXPECT_TRUE(C);
39   EXPECT_STREQ(TheDriver.getTargetTriple().c_str(), ExpectTriple.data());
40   EXPECT_EQ(Diags.getNumErrors(), 0u);
41 }
42 
43 static void validateTargetProfile(
44     StringRef TargetProfile, StringRef ExpectError,
45     IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem> &InMemoryFileSystem,
46     DiagnosticsEngine &Diags, SimpleDiagnosticConsumer *DiagConsumer,
47     unsigned NumOfErrors) {
48   Driver TheDriver("/bin/clang", "", Diags, "", InMemoryFileSystem);
49   std::unique_ptr<Compilation> C{TheDriver.BuildCompilation(
50       {"clang", "--driver-mode=dxc", TargetProfile.data(), "foo.hlsl", "-Vd"})};
51   EXPECT_TRUE(C);
52   EXPECT_EQ(Diags.getNumErrors(), NumOfErrors);
53   EXPECT_STREQ(DiagConsumer->Errors.back().c_str(), ExpectError.data());
54   Diags.Clear();
55   DiagConsumer->clear();
56 }
57 
58 TEST(DxcModeTest, TargetProfileValidation) {
59   IntrusiveRefCntPtr<DiagnosticIDs> DiagID(new DiagnosticIDs());
60 
61   IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem> InMemoryFileSystem(
62       new llvm::vfs::InMemoryFileSystem);
63 
64   InMemoryFileSystem->addFile("foo.hlsl", 0,
65                               llvm::MemoryBuffer::getMemBuffer("\n"));
66 
67   auto *DiagConsumer = new SimpleDiagnosticConsumer;
68   IntrusiveRefCntPtr<DiagnosticOptions> DiagOpts = new DiagnosticOptions();
69   DiagnosticsEngine Diags(DiagID, &*DiagOpts, DiagConsumer);
70 
71   validateTargetProfile("-Tvs_6_0", "dxilv1.0--shadermodel6.0-vertex",
72                         InMemoryFileSystem, Diags);
73   validateTargetProfile("-Ths_6_1", "dxilv1.1--shadermodel6.1-hull",
74                         InMemoryFileSystem, Diags);
75   validateTargetProfile("-Tds_6_2", "dxilv1.2--shadermodel6.2-domain",
76                         InMemoryFileSystem, Diags);
77   validateTargetProfile("-Tds_6_2", "dxilv1.2--shadermodel6.2-domain",
78                         InMemoryFileSystem, Diags);
79   validateTargetProfile("-Tgs_6_3", "dxilv1.3--shadermodel6.3-geometry",
80                         InMemoryFileSystem, Diags);
81   validateTargetProfile("-Tps_6_4", "dxilv1.4--shadermodel6.4-pixel",
82                         InMemoryFileSystem, Diags);
83   validateTargetProfile("-Tcs_6_5", "dxilv1.5--shadermodel6.5-compute",
84                         InMemoryFileSystem, Diags);
85   validateTargetProfile("-Tms_6_6", "dxilv1.6--shadermodel6.6-mesh",
86                         InMemoryFileSystem, Diags);
87   validateTargetProfile("-Tas_6_7", "dxilv1.7--shadermodel6.7-amplification",
88                         InMemoryFileSystem, Diags);
89   validateTargetProfile("-Tcs_6_8", "dxilv1.8--shadermodel6.8-compute",
90                         InMemoryFileSystem, Diags);
91   validateTargetProfile("-Tlib_6_x", "dxilv1.8--shadermodel6.15-library",
92                         InMemoryFileSystem, Diags);
93 
94   // Invalid tests.
95   validateTargetProfile("-Tpss_6_1", "invalid profile : pss_6_1",
96                         InMemoryFileSystem, Diags, DiagConsumer, 1);
97 
98   validateTargetProfile("-Tps_6_x", "invalid profile : ps_6_x",
99                         InMemoryFileSystem, Diags, DiagConsumer, 2);
100   validateTargetProfile("-Tlib_6_1", "invalid profile : lib_6_1",
101                         InMemoryFileSystem, Diags, DiagConsumer, 3);
102   validateTargetProfile("-Tfoo", "invalid profile : foo", InMemoryFileSystem,
103                         Diags, DiagConsumer, 4);
104   validateTargetProfile("", "target profile option (-T) is missing",
105                         InMemoryFileSystem, Diags, DiagConsumer, 5);
106 }
107 
108 TEST(DxcModeTest, ValidatorVersionValidation) {
109   IntrusiveRefCntPtr<DiagnosticIDs> DiagID(new DiagnosticIDs());
110 
111   IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem> InMemoryFileSystem(
112       new llvm::vfs::InMemoryFileSystem);
113 
114   InMemoryFileSystem->addFile("foo.hlsl", 0,
115                               llvm::MemoryBuffer::getMemBuffer("\n"));
116 
117   auto *DiagConsumer = new SimpleDiagnosticConsumer;
118   IntrusiveRefCntPtr<DiagnosticOptions> DiagOpts = new DiagnosticOptions();
119   DiagnosticsEngine Diags(DiagID, &*DiagOpts, DiagConsumer);
120   Driver TheDriver("/bin/clang", "", Diags, "", InMemoryFileSystem);
121   std::unique_ptr<Compilation> C(TheDriver.BuildCompilation(
122       {"clang", "--driver-mode=dxc", "-Tlib_6_7", "foo.hlsl"}));
123   EXPECT_TRUE(C);
124   EXPECT_TRUE(!C->containsError());
125 
126   auto &TC = C->getDefaultToolChain();
127   bool ContainsError = false;
128   auto Args = TheDriver.ParseArgStrings({"-validator-version", "1.1"}, false,
129                                         ContainsError);
130   EXPECT_FALSE(ContainsError);
131   auto DAL = std::make_unique<llvm::opt::DerivedArgList>(Args);
132   for (auto *A : Args)
133     DAL->append(A);
134 
135   std::unique_ptr<llvm::opt::DerivedArgList> TranslatedArgs{
136       TC.TranslateArgs(*DAL, "0", Action::OffloadKind::OFK_None)};
137   EXPECT_NE(TranslatedArgs, nullptr);
138   if (TranslatedArgs) {
139     auto *A = TranslatedArgs->getLastArg(
140         clang::driver::options::OPT_dxil_validator_version);
141     EXPECT_NE(A, nullptr);
142     if (A) {
143       EXPECT_STREQ(A->getValue(), "1.1");
144     }
145   }
146   EXPECT_EQ(Diags.getNumErrors(), 0u);
147 
148   // Invalid tests.
149   Args = TheDriver.ParseArgStrings({"-validator-version", "0.1"}, false,
150                                    ContainsError);
151   EXPECT_FALSE(ContainsError);
152   DAL = std::make_unique<llvm::opt::DerivedArgList>(Args);
153   for (auto *A : Args)
154     DAL->append(A);
155 
156   TranslatedArgs.reset(
157       TC.TranslateArgs(*DAL, "0", Action::OffloadKind::OFK_None));
158   EXPECT_EQ(Diags.getNumErrors(), 1u);
159   EXPECT_STREQ(DiagConsumer->Errors.back().c_str(),
160                "invalid validator version : 0.1\nIf validator major version is "
161                "0, minor version must also be 0.");
162   Diags.Clear();
163   DiagConsumer->clear();
164 
165   Args = TheDriver.ParseArgStrings({"-validator-version", "1"}, false,
166                                    ContainsError);
167   EXPECT_FALSE(ContainsError);
168   DAL = std::make_unique<llvm::opt::DerivedArgList>(Args);
169   for (auto *A : Args)
170     DAL->append(A);
171 
172   TranslatedArgs.reset(
173       TC.TranslateArgs(*DAL, "0", Action::OffloadKind::OFK_None));
174   EXPECT_EQ(Diags.getNumErrors(), 2u);
175   EXPECT_STREQ(DiagConsumer->Errors.back().c_str(),
176                "invalid validator version : 1\nFormat of validator version is "
177                "\"<major>.<minor>\" (ex:\"1.4\").");
178   Diags.Clear();
179   DiagConsumer->clear();
180 
181   Args = TheDriver.ParseArgStrings({"-validator-version", "-Tlib_6_7"}, false,
182                                    ContainsError);
183   EXPECT_FALSE(ContainsError);
184   DAL = std::make_unique<llvm::opt::DerivedArgList>(Args);
185   for (auto *A : Args)
186     DAL->append(A);
187 
188   TranslatedArgs.reset(
189       TC.TranslateArgs(*DAL, "0", Action::OffloadKind::OFK_None));
190   EXPECT_EQ(Diags.getNumErrors(), 3u);
191   EXPECT_STREQ(
192       DiagConsumer->Errors.back().c_str(),
193       "invalid validator version : -Tlib_6_7\nFormat of validator version is "
194       "\"<major>.<minor>\" (ex:\"1.4\").");
195   Diags.Clear();
196   DiagConsumer->clear();
197 
198   Args = TheDriver.ParseArgStrings({"-validator-version", "foo"}, false,
199                                    ContainsError);
200   EXPECT_FALSE(ContainsError);
201   DAL = std::make_unique<llvm::opt::DerivedArgList>(Args);
202   for (auto *A : Args)
203     DAL->append(A);
204 
205   TranslatedArgs.reset(
206       TC.TranslateArgs(*DAL, "0", Action::OffloadKind::OFK_None));
207   EXPECT_EQ(Diags.getNumErrors(), 4u);
208   EXPECT_STREQ(
209       DiagConsumer->Errors.back().c_str(),
210       "invalid validator version : foo\nFormat of validator version is "
211       "\"<major>.<minor>\" (ex:\"1.4\").");
212   Diags.Clear();
213   DiagConsumer->clear();
214 }
215 
216 TEST(DxcModeTest, DefaultEntry) {
217   IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem> InMemoryFileSystem(
218       new llvm::vfs::InMemoryFileSystem);
219 
220   InMemoryFileSystem->addFile("foo.hlsl", 0,
221                               llvm::MemoryBuffer::getMemBuffer("\n"));
222 
223   const char *Args[] = {"clang", "--driver-mode=dxc", "-Tcs_6_7", "foo.hlsl"};
224 
225   IntrusiveRefCntPtr<DiagnosticsEngine> Diags =
226       CompilerInstance::createDiagnostics(new DiagnosticOptions());
227 
228   CreateInvocationOptions CIOpts;
229   CIOpts.Diags = Diags;
230   std::unique_ptr<CompilerInvocation> CInvok =
231       createInvocation(Args, std::move(CIOpts));
232   EXPECT_TRUE(CInvok);
233   // Make sure default entry is "main".
234   EXPECT_STREQ(CInvok->getTargetOpts().HLSLEntry.c_str(), "main");
235 
236   const char *EntryArgs[] = {"clang", "--driver-mode=dxc", "-Ebar", "-Tcs_6_7",
237                              "foo.hlsl"};
238   CInvok = createInvocation(EntryArgs, std::move(CIOpts));
239   EXPECT_TRUE(CInvok);
240   // Make sure "-E" will set entry.
241   EXPECT_STREQ(CInvok->getTargetOpts().HLSLEntry.c_str(), "bar");
242 }
243