1 //===- unittests/Driver/DXCModeTest.cpp --- DXC Mode tests ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Unit tests for driver DXCMode. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "clang/Basic/DiagnosticIDs.h" 14 #include "clang/Basic/DiagnosticOptions.h" 15 #include "clang/Basic/LLVM.h" 16 #include "clang/Basic/TargetOptions.h" 17 #include "clang/Driver/Compilation.h" 18 #include "clang/Driver/Driver.h" 19 #include "clang/Driver/ToolChain.h" 20 #include "clang/Frontend/CompilerInstance.h" 21 #include "llvm/Support/VirtualFileSystem.h" 22 #include "llvm/Support/raw_ostream.h" 23 #include "gtest/gtest.h" 24 #include <memory> 25 26 #include "SimpleDiagnosticConsumer.h" 27 28 using namespace clang; 29 using namespace clang::driver; 30 31 static void validateTargetProfile( 32 StringRef TargetProfile, StringRef ExpectTriple, 33 IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem> &InMemoryFileSystem, 34 DiagnosticsEngine &Diags) { 35 Driver TheDriver("/bin/clang", "", Diags, "", InMemoryFileSystem); 36 std::unique_ptr<Compilation> C{TheDriver.BuildCompilation( 37 {"clang", "--driver-mode=dxc", TargetProfile.data(), "foo.hlsl", "-Vd"})}; 38 EXPECT_TRUE(C); 39 EXPECT_STREQ(TheDriver.getTargetTriple().c_str(), ExpectTriple.data()); 40 EXPECT_EQ(Diags.getNumErrors(), 0u); 41 } 42 43 static void validateTargetProfile( 44 StringRef TargetProfile, StringRef ExpectError, 45 IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem> &InMemoryFileSystem, 46 DiagnosticsEngine &Diags, SimpleDiagnosticConsumer *DiagConsumer, 47 unsigned NumOfErrors) { 48 Driver TheDriver("/bin/clang", "", Diags, "", InMemoryFileSystem); 49 std::unique_ptr<Compilation> C{TheDriver.BuildCompilation( 50 {"clang", "--driver-mode=dxc", TargetProfile.data(), "foo.hlsl", "-Vd"})}; 51 EXPECT_TRUE(C); 52 EXPECT_EQ(Diags.getNumErrors(), NumOfErrors); 53 EXPECT_STREQ(DiagConsumer->Errors.back().c_str(), ExpectError.data()); 54 Diags.Clear(); 55 DiagConsumer->clear(); 56 } 57 58 TEST(DxcModeTest, TargetProfileValidation) { 59 IntrusiveRefCntPtr<DiagnosticIDs> DiagID(new DiagnosticIDs()); 60 61 IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem> InMemoryFileSystem( 62 new llvm::vfs::InMemoryFileSystem); 63 64 InMemoryFileSystem->addFile("foo.hlsl", 0, 65 llvm::MemoryBuffer::getMemBuffer("\n")); 66 67 auto *DiagConsumer = new SimpleDiagnosticConsumer; 68 IntrusiveRefCntPtr<DiagnosticOptions> DiagOpts = new DiagnosticOptions(); 69 DiagnosticsEngine Diags(DiagID, &*DiagOpts, DiagConsumer); 70 71 validateTargetProfile("-Tvs_6_0", "dxil--shadermodel6.0-vertex", 72 InMemoryFileSystem, Diags); 73 validateTargetProfile("-Ths_6_1", "dxil--shadermodel6.1-hull", 74 InMemoryFileSystem, Diags); 75 validateTargetProfile("-Tds_6_2", "dxil--shadermodel6.2-domain", 76 InMemoryFileSystem, Diags); 77 validateTargetProfile("-Tds_6_2", "dxil--shadermodel6.2-domain", 78 InMemoryFileSystem, Diags); 79 validateTargetProfile("-Tgs_6_3", "dxil--shadermodel6.3-geometry", 80 InMemoryFileSystem, Diags); 81 validateTargetProfile("-Tps_6_4", "dxil--shadermodel6.4-pixel", 82 InMemoryFileSystem, Diags); 83 validateTargetProfile("-Tcs_6_5", "dxil--shadermodel6.5-compute", 84 InMemoryFileSystem, Diags); 85 validateTargetProfile("-Tms_6_6", "dxil--shadermodel6.6-mesh", 86 InMemoryFileSystem, Diags); 87 validateTargetProfile("-Tas_6_7", "dxil--shadermodel6.7-amplification", 88 InMemoryFileSystem, Diags); 89 validateTargetProfile("-Tlib_6_x", "dxil--shadermodel6.15-library", 90 InMemoryFileSystem, Diags); 91 92 // Invalid tests. 93 validateTargetProfile("-Tpss_6_1", "invalid profile : pss_6_1", 94 InMemoryFileSystem, Diags, DiagConsumer, 1); 95 96 validateTargetProfile("-Tps_6_x", "invalid profile : ps_6_x", 97 InMemoryFileSystem, Diags, DiagConsumer, 2); 98 validateTargetProfile("-Tlib_6_1", "invalid profile : lib_6_1", 99 InMemoryFileSystem, Diags, DiagConsumer, 3); 100 validateTargetProfile("-Tfoo", "invalid profile : foo", InMemoryFileSystem, 101 Diags, DiagConsumer, 4); 102 validateTargetProfile("", "target profile option (-T) is missing", 103 InMemoryFileSystem, Diags, DiagConsumer, 5); 104 } 105 106 TEST(DxcModeTest, ValidatorVersionValidation) { 107 IntrusiveRefCntPtr<DiagnosticIDs> DiagID(new DiagnosticIDs()); 108 109 IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem> InMemoryFileSystem( 110 new llvm::vfs::InMemoryFileSystem); 111 112 InMemoryFileSystem->addFile("foo.hlsl", 0, 113 llvm::MemoryBuffer::getMemBuffer("\n")); 114 115 auto *DiagConsumer = new SimpleDiagnosticConsumer; 116 IntrusiveRefCntPtr<DiagnosticOptions> DiagOpts = new DiagnosticOptions(); 117 DiagnosticsEngine Diags(DiagID, &*DiagOpts, DiagConsumer); 118 Driver TheDriver("/bin/clang", "", Diags, "", InMemoryFileSystem); 119 std::unique_ptr<Compilation> C(TheDriver.BuildCompilation( 120 {"clang", "--driver-mode=dxc", "-Tlib_6_7", "foo.hlsl"})); 121 EXPECT_TRUE(C); 122 EXPECT_TRUE(!C->containsError()); 123 124 auto &TC = C->getDefaultToolChain(); 125 bool ContainsError = false; 126 auto Args = TheDriver.ParseArgStrings({"-validator-version", "1.1"}, false, 127 ContainsError); 128 EXPECT_FALSE(ContainsError); 129 auto DAL = std::make_unique<llvm::opt::DerivedArgList>(Args); 130 for (auto *A : Args) 131 DAL->append(A); 132 133 std::unique_ptr<llvm::opt::DerivedArgList> TranslatedArgs{ 134 TC.TranslateArgs(*DAL, "0", Action::OffloadKind::OFK_None)}; 135 EXPECT_NE(TranslatedArgs, nullptr); 136 if (TranslatedArgs) { 137 auto *A = TranslatedArgs->getLastArg( 138 clang::driver::options::OPT_dxil_validator_version); 139 EXPECT_NE(A, nullptr); 140 if (A) 141 EXPECT_STREQ(A->getValue(), "1.1"); 142 } 143 EXPECT_EQ(Diags.getNumErrors(), 0u); 144 145 // Invalid tests. 146 Args = TheDriver.ParseArgStrings({"-validator-version", "0.1"}, false, 147 ContainsError); 148 EXPECT_FALSE(ContainsError); 149 DAL = std::make_unique<llvm::opt::DerivedArgList>(Args); 150 for (auto *A : Args) 151 DAL->append(A); 152 153 TranslatedArgs.reset( 154 TC.TranslateArgs(*DAL, "0", Action::OffloadKind::OFK_None)); 155 EXPECT_EQ(Diags.getNumErrors(), 1u); 156 EXPECT_STREQ(DiagConsumer->Errors.back().c_str(), 157 "invalid validator version : 0.1\nIf validator major version is " 158 "0, minor version must also be 0."); 159 Diags.Clear(); 160 DiagConsumer->clear(); 161 162 Args = TheDriver.ParseArgStrings({"-validator-version", "1"}, false, 163 ContainsError); 164 EXPECT_FALSE(ContainsError); 165 DAL = std::make_unique<llvm::opt::DerivedArgList>(Args); 166 for (auto *A : Args) 167 DAL->append(A); 168 169 TranslatedArgs.reset( 170 TC.TranslateArgs(*DAL, "0", Action::OffloadKind::OFK_None)); 171 EXPECT_EQ(Diags.getNumErrors(), 2u); 172 EXPECT_STREQ(DiagConsumer->Errors.back().c_str(), 173 "invalid validator version : 1\nFormat of validator version is " 174 "\"<major>.<minor>\" (ex:\"1.4\")."); 175 Diags.Clear(); 176 DiagConsumer->clear(); 177 178 Args = TheDriver.ParseArgStrings({"-validator-version", "-Tlib_6_7"}, false, 179 ContainsError); 180 EXPECT_FALSE(ContainsError); 181 DAL = std::make_unique<llvm::opt::DerivedArgList>(Args); 182 for (auto *A : Args) 183 DAL->append(A); 184 185 TranslatedArgs.reset( 186 TC.TranslateArgs(*DAL, "0", Action::OffloadKind::OFK_None)); 187 EXPECT_EQ(Diags.getNumErrors(), 3u); 188 EXPECT_STREQ( 189 DiagConsumer->Errors.back().c_str(), 190 "invalid validator version : -Tlib_6_7\nFormat of validator version is " 191 "\"<major>.<minor>\" (ex:\"1.4\")."); 192 Diags.Clear(); 193 DiagConsumer->clear(); 194 195 Args = TheDriver.ParseArgStrings({"-validator-version", "foo"}, false, 196 ContainsError); 197 EXPECT_FALSE(ContainsError); 198 DAL = std::make_unique<llvm::opt::DerivedArgList>(Args); 199 for (auto *A : Args) 200 DAL->append(A); 201 202 TranslatedArgs.reset( 203 TC.TranslateArgs(*DAL, "0", Action::OffloadKind::OFK_None)); 204 EXPECT_EQ(Diags.getNumErrors(), 4u); 205 EXPECT_STREQ( 206 DiagConsumer->Errors.back().c_str(), 207 "invalid validator version : foo\nFormat of validator version is " 208 "\"<major>.<minor>\" (ex:\"1.4\")."); 209 Diags.Clear(); 210 DiagConsumer->clear(); 211 } 212 213 TEST(DxcModeTest, DefaultEntry) { 214 IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem> InMemoryFileSystem( 215 new llvm::vfs::InMemoryFileSystem); 216 217 InMemoryFileSystem->addFile("foo.hlsl", 0, 218 llvm::MemoryBuffer::getMemBuffer("\n")); 219 220 const char *Args[] = {"clang", "--driver-mode=dxc", "-Tcs_6_7", "foo.hlsl"}; 221 222 IntrusiveRefCntPtr<DiagnosticsEngine> Diags = 223 CompilerInstance::createDiagnostics(new DiagnosticOptions()); 224 225 CreateInvocationOptions CIOpts; 226 CIOpts.Diags = Diags; 227 std::unique_ptr<CompilerInvocation> CInvok = 228 createInvocation(Args, std::move(CIOpts)); 229 EXPECT_TRUE(CInvok); 230 // Make sure default entry is "main". 231 EXPECT_STREQ(CInvok->getTargetOpts().HLSLEntry.c_str(), "main"); 232 233 const char *EntryArgs[] = {"clang", "--driver-mode=dxc", "-Ebar", "-Tcs_6_7", 234 "foo.hlsl"}; 235 CInvok = createInvocation(EntryArgs, std::move(CIOpts)); 236 EXPECT_TRUE(CInvok); 237 // Make sure "-E" will set entry. 238 EXPECT_STREQ(CInvok->getTargetOpts().HLSLEntry.c_str(), "bar"); 239 } 240