1 //===- unittests/Driver/DXCModeTest.cpp --- DXC Mode tests ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Unit tests for driver DXCMode. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "clang/Basic/DiagnosticIDs.h" 14 #include "clang/Basic/DiagnosticOptions.h" 15 #include "clang/Basic/LLVM.h" 16 #include "clang/Basic/TargetOptions.h" 17 #include "clang/Driver/Compilation.h" 18 #include "clang/Driver/Driver.h" 19 #include "clang/Driver/ToolChain.h" 20 #include "clang/Frontend/CompilerInstance.h" 21 #include "llvm/Support/VirtualFileSystem.h" 22 #include "llvm/Support/raw_ostream.h" 23 #include "gtest/gtest.h" 24 #include <memory> 25 26 #include "SimpleDiagnosticConsumer.h" 27 28 using namespace clang; 29 using namespace clang::driver; 30 31 static void validateTargetProfile( 32 StringRef TargetProfile, StringRef ExpectTriple, 33 IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem> &InMemoryFileSystem, 34 DiagnosticsEngine &Diags) { 35 Driver TheDriver("/bin/clang", "", Diags, "", InMemoryFileSystem); 36 std::unique_ptr<Compilation> C{TheDriver.BuildCompilation( 37 {"clang", "--driver-mode=dxc", TargetProfile.data(), "foo.hlsl", "-Vd"})}; 38 EXPECT_TRUE(C); 39 EXPECT_STREQ(TheDriver.getTargetTriple().c_str(), ExpectTriple.data()); 40 EXPECT_EQ(Diags.getNumErrors(), 0u); 41 } 42 43 static void validateTargetProfile( 44 StringRef TargetProfile, StringRef ExpectError, 45 IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem> &InMemoryFileSystem, 46 DiagnosticsEngine &Diags, SimpleDiagnosticConsumer *DiagConsumer, 47 unsigned NumOfErrors) { 48 Driver TheDriver("/bin/clang", "", Diags, "", InMemoryFileSystem); 49 std::unique_ptr<Compilation> C{TheDriver.BuildCompilation( 50 {"clang", "--driver-mode=dxc", TargetProfile.data(), "foo.hlsl", "-Vd"})}; 51 EXPECT_TRUE(C); 52 EXPECT_EQ(Diags.getNumErrors(), NumOfErrors); 53 EXPECT_STREQ(DiagConsumer->Errors.back().c_str(), ExpectError.data()); 54 DiagConsumer->clear(); 55 } 56 57 TEST(DxcModeTest, TargetProfileValidation) { 58 IntrusiveRefCntPtr<DiagnosticIDs> DiagID(new DiagnosticIDs()); 59 60 IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem> InMemoryFileSystem( 61 new llvm::vfs::InMemoryFileSystem); 62 63 InMemoryFileSystem->addFile("foo.hlsl", 0, 64 llvm::MemoryBuffer::getMemBuffer("\n")); 65 66 auto *DiagConsumer = new SimpleDiagnosticConsumer; 67 IntrusiveRefCntPtr<DiagnosticOptions> DiagOpts = new DiagnosticOptions(); 68 DiagnosticsEngine Diags(DiagID, &*DiagOpts, DiagConsumer); 69 70 validateTargetProfile("-Tvs_6_0", "dxilv1.0--shadermodel6.0-vertex", 71 InMemoryFileSystem, Diags); 72 validateTargetProfile("-Ths_6_1", "dxilv1.1--shadermodel6.1-hull", 73 InMemoryFileSystem, Diags); 74 validateTargetProfile("-Tds_6_2", "dxilv1.2--shadermodel6.2-domain", 75 InMemoryFileSystem, Diags); 76 validateTargetProfile("-Tds_6_2", "dxilv1.2--shadermodel6.2-domain", 77 InMemoryFileSystem, Diags); 78 validateTargetProfile("-Tgs_6_3", "dxilv1.3--shadermodel6.3-geometry", 79 InMemoryFileSystem, Diags); 80 validateTargetProfile("-Tps_6_4", "dxilv1.4--shadermodel6.4-pixel", 81 InMemoryFileSystem, Diags); 82 validateTargetProfile("-Tcs_6_5", "dxilv1.5--shadermodel6.5-compute", 83 InMemoryFileSystem, Diags); 84 validateTargetProfile("-Tms_6_6", "dxilv1.6--shadermodel6.6-mesh", 85 InMemoryFileSystem, Diags); 86 validateTargetProfile("-Tas_6_7", "dxilv1.7--shadermodel6.7-amplification", 87 InMemoryFileSystem, Diags); 88 validateTargetProfile("-Tcs_6_8", "dxilv1.8--shadermodel6.8-compute", 89 InMemoryFileSystem, Diags); 90 validateTargetProfile("-Tlib_6_x", "dxilv1.8--shadermodel6.15-library", 91 InMemoryFileSystem, Diags); 92 93 // Invalid tests. 94 validateTargetProfile("-Tpss_6_1", "invalid profile : pss_6_1", 95 InMemoryFileSystem, Diags, DiagConsumer, 1); 96 97 validateTargetProfile("-Tps_6_x", "invalid profile : ps_6_x", 98 InMemoryFileSystem, Diags, DiagConsumer, 2); 99 validateTargetProfile("-Tlib_6_1", "invalid profile : lib_6_1", 100 InMemoryFileSystem, Diags, DiagConsumer, 3); 101 validateTargetProfile("-Tfoo", "invalid profile : foo", InMemoryFileSystem, 102 Diags, DiagConsumer, 4); 103 validateTargetProfile("", "target profile option (-T) is missing", 104 InMemoryFileSystem, Diags, DiagConsumer, 5); 105 } 106 107 TEST(DxcModeTest, ValidatorVersionValidation) { 108 IntrusiveRefCntPtr<DiagnosticIDs> DiagID(new DiagnosticIDs()); 109 110 IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem> InMemoryFileSystem( 111 new llvm::vfs::InMemoryFileSystem); 112 113 InMemoryFileSystem->addFile("foo.hlsl", 0, 114 llvm::MemoryBuffer::getMemBuffer("\n")); 115 116 auto *DiagConsumer = new SimpleDiagnosticConsumer; 117 IntrusiveRefCntPtr<DiagnosticOptions> DiagOpts = new DiagnosticOptions(); 118 DiagnosticsEngine Diags(DiagID, &*DiagOpts, DiagConsumer); 119 Driver TheDriver("/bin/clang", "", Diags, "", InMemoryFileSystem); 120 std::unique_ptr<Compilation> C(TheDriver.BuildCompilation( 121 {"clang", "--driver-mode=dxc", "-Tlib_6_7", "foo.hlsl"})); 122 EXPECT_TRUE(C); 123 EXPECT_TRUE(!C->containsError()); 124 125 auto &TC = C->getDefaultToolChain(); 126 bool ContainsError = false; 127 auto Args = TheDriver.ParseArgStrings({"-validator-version", "1.1"}, false, 128 ContainsError); 129 EXPECT_FALSE(ContainsError); 130 auto DAL = std::make_unique<llvm::opt::DerivedArgList>(Args); 131 for (auto *A : Args) 132 DAL->append(A); 133 134 std::unique_ptr<llvm::opt::DerivedArgList> TranslatedArgs{ 135 TC.TranslateArgs(*DAL, "0", Action::OffloadKind::OFK_None)}; 136 EXPECT_NE(TranslatedArgs, nullptr); 137 if (TranslatedArgs) { 138 auto *A = TranslatedArgs->getLastArg( 139 clang::driver::options::OPT_dxil_validator_version); 140 EXPECT_NE(A, nullptr); 141 if (A) { 142 EXPECT_STREQ(A->getValue(), "1.1"); 143 } 144 } 145 EXPECT_EQ(Diags.getNumErrors(), 0u); 146 147 // Invalid tests. 148 Args = TheDriver.ParseArgStrings({"-validator-version", "0.1"}, false, 149 ContainsError); 150 EXPECT_FALSE(ContainsError); 151 DAL = std::make_unique<llvm::opt::DerivedArgList>(Args); 152 for (auto *A : Args) 153 DAL->append(A); 154 155 TranslatedArgs.reset( 156 TC.TranslateArgs(*DAL, "0", Action::OffloadKind::OFK_None)); 157 EXPECT_EQ(Diags.getNumErrors(), 1u); 158 EXPECT_STREQ( 159 DiagConsumer->Errors.back().c_str(), 160 "invalid validator version : 0.1; if validator major version is 0, " 161 "minor version must also be 0"); 162 DiagConsumer->clear(); 163 164 Args = TheDriver.ParseArgStrings({"-validator-version", "1"}, false, 165 ContainsError); 166 EXPECT_FALSE(ContainsError); 167 DAL = std::make_unique<llvm::opt::DerivedArgList>(Args); 168 for (auto *A : Args) 169 DAL->append(A); 170 171 TranslatedArgs.reset( 172 TC.TranslateArgs(*DAL, "0", Action::OffloadKind::OFK_None)); 173 EXPECT_EQ(Diags.getNumErrors(), 2u); 174 EXPECT_STREQ(DiagConsumer->Errors.back().c_str(), 175 "invalid validator version : 1; format of validator version is " 176 "\"<major>.<minor>\" (ex:\"1.4\")"); 177 DiagConsumer->clear(); 178 179 Args = TheDriver.ParseArgStrings({"-validator-version", "-Tlib_6_7"}, false, 180 ContainsError); 181 EXPECT_FALSE(ContainsError); 182 DAL = std::make_unique<llvm::opt::DerivedArgList>(Args); 183 for (auto *A : Args) 184 DAL->append(A); 185 186 TranslatedArgs.reset( 187 TC.TranslateArgs(*DAL, "0", Action::OffloadKind::OFK_None)); 188 EXPECT_EQ(Diags.getNumErrors(), 3u); 189 EXPECT_STREQ( 190 DiagConsumer->Errors.back().c_str(), 191 "invalid validator version : -Tlib_6_7; format of validator version is " 192 "\"<major>.<minor>\" (ex:\"1.4\")"); 193 DiagConsumer->clear(); 194 195 Args = TheDriver.ParseArgStrings({"-validator-version", "foo"}, false, 196 ContainsError); 197 EXPECT_FALSE(ContainsError); 198 DAL = std::make_unique<llvm::opt::DerivedArgList>(Args); 199 for (auto *A : Args) 200 DAL->append(A); 201 202 TranslatedArgs.reset( 203 TC.TranslateArgs(*DAL, "0", Action::OffloadKind::OFK_None)); 204 EXPECT_EQ(Diags.getNumErrors(), 4u); 205 EXPECT_STREQ( 206 DiagConsumer->Errors.back().c_str(), 207 "invalid validator version : foo; format of validator version is " 208 "\"<major>.<minor>\" (ex:\"1.4\")"); 209 DiagConsumer->clear(); 210 } 211 212 TEST(DxcModeTest, DefaultEntry) { 213 IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem> InMemoryFileSystem( 214 new llvm::vfs::InMemoryFileSystem); 215 216 InMemoryFileSystem->addFile("foo.hlsl", 0, 217 llvm::MemoryBuffer::getMemBuffer("\n")); 218 219 const char *Args[] = {"clang", "--driver-mode=dxc", "-Tcs_6_7", "foo.hlsl"}; 220 221 IntrusiveRefCntPtr<DiagnosticsEngine> Diags = 222 CompilerInstance::createDiagnostics(*InMemoryFileSystem, 223 new DiagnosticOptions()); 224 225 CreateInvocationOptions CIOpts; 226 CIOpts.Diags = Diags; 227 std::unique_ptr<CompilerInvocation> CInvok = 228 createInvocation(Args, std::move(CIOpts)); 229 EXPECT_TRUE(CInvok); 230 // Make sure default entry is "main". 231 EXPECT_STREQ(CInvok->getTargetOpts().HLSLEntry.c_str(), "main"); 232 233 const char *EntryArgs[] = {"clang", "--driver-mode=dxc", "-Ebar", "-Tcs_6_7", 234 "foo.hlsl"}; 235 CInvok = createInvocation(EntryArgs, std::move(CIOpts)); 236 EXPECT_TRUE(CInvok); 237 // Make sure "-E" will set entry. 238 EXPECT_STREQ(CInvok->getTargetOpts().HLSLEntry.c_str(), "bar"); 239 } 240