//===- SerializeNVVMTarget.cpp ----------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Config/mlir-config.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Target/LLVM/NVVM/Target.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"

#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/Config/llvm-config.h" // for LLVM_HAS_NVPTX_TARGET
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/MemoryBufferRef.h"
#include "llvm/Support/Process.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/TargetParser/Host.h"

#include "gmock/gmock.h"
#include <cstdint>

using namespace mlir;

// Skip the test if the NVPTX target was not built.
37 #if LLVM_HAS_NVPTX_TARGET 38 #define SKIP_WITHOUT_NVPTX(x) x 39 #else 40 #define SKIP_WITHOUT_NVPTX(x) DISABLED_##x 41 #endif 42 43 class MLIRTargetLLVMNVVM : public ::testing::Test { 44 protected: 45 void SetUp() override { 46 registerBuiltinDialectTranslation(registry); 47 registerLLVMDialectTranslation(registry); 48 registerGPUDialectTranslation(registry); 49 registerNVVMDialectTranslation(registry); 50 NVVM::registerNVVMTargetInterfaceExternalModels(registry); 51 } 52 53 // Checks if PTXAS is in PATH. 54 bool hasPtxas() { 55 // Find the `ptxas` compiler. 56 std::optional<std::string> ptxasCompiler = 57 llvm::sys::Process::FindInEnvPath("PATH", "ptxas"); 58 return ptxasCompiler.has_value(); 59 } 60 61 // Dialect registry. 62 DialectRegistry registry; 63 64 // MLIR module used for the tests. 65 const std::string moduleStr = R"mlir( 66 gpu.module @nvvm_test { 67 llvm.func @nvvm_kernel(%arg0: f32) attributes {gpu.kernel, nvvm.kernel} { 68 llvm.return 69 } 70 })mlir"; 71 }; 72 73 // Test NVVM serialization to LLVM. 74 TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(SerializeNVVMMToLLVM)) { 75 MLIRContext context(registry); 76 77 OwningOpRef<ModuleOp> module = 78 parseSourceString<ModuleOp>(moduleStr, &context); 79 ASSERT_TRUE(!!module); 80 81 // Create an NVVM target. 82 NVVM::NVVMTargetAttr target = NVVM::NVVMTargetAttr::get(&context); 83 84 // Serialize the module. 85 auto serializer = dyn_cast<gpu::TargetAttrInterface>(target); 86 ASSERT_TRUE(!!serializer); 87 gpu::TargetOptions options("", {}, "", "", gpu::CompilationTarget::Offload); 88 for (auto gpuModule : (*module).getBody()->getOps<gpu::GPUModuleOp>()) { 89 std::optional<SmallVector<char, 0>> object = 90 serializer.serializeToObject(gpuModule, options); 91 // Check that the serializer was successful. 92 ASSERT_TRUE(object != std::nullopt); 93 ASSERT_TRUE(!object->empty()); 94 95 // Read the serialized module. 
96 llvm::MemoryBufferRef buffer(StringRef(object->data(), object->size()), 97 "module"); 98 llvm::LLVMContext llvmContext; 99 llvm::Expected<std::unique_ptr<llvm::Module>> llvmModule = 100 llvm::getLazyBitcodeModule(buffer, llvmContext); 101 ASSERT_TRUE(!!llvmModule); 102 ASSERT_TRUE(!!*llvmModule); 103 104 // Check that it has a function named `foo`. 105 ASSERT_TRUE((*llvmModule)->getFunction("nvvm_kernel") != nullptr); 106 } 107 } 108 109 // Test NVVM serialization to PTX. 110 TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(SerializeNVVMToPTX)) { 111 MLIRContext context(registry); 112 113 OwningOpRef<ModuleOp> module = 114 parseSourceString<ModuleOp>(moduleStr, &context); 115 ASSERT_TRUE(!!module); 116 117 // Create an NVVM target. 118 NVVM::NVVMTargetAttr target = NVVM::NVVMTargetAttr::get(&context); 119 120 // Serialize the module. 121 auto serializer = dyn_cast<gpu::TargetAttrInterface>(target); 122 ASSERT_TRUE(!!serializer); 123 gpu::TargetOptions options("", {}, "", "", gpu::CompilationTarget::Assembly); 124 for (auto gpuModule : (*module).getBody()->getOps<gpu::GPUModuleOp>()) { 125 std::optional<SmallVector<char, 0>> object = 126 serializer.serializeToObject(gpuModule, options); 127 // Check that the serializer was successful. 128 ASSERT_TRUE(object != std::nullopt); 129 ASSERT_TRUE(!object->empty()); 130 131 ASSERT_TRUE( 132 StringRef(object->data(), object->size()).contains("nvvm_kernel")); 133 } 134 } 135 136 // Test NVVM serialization to Binary. 137 TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(SerializeNVVMToBinary)) { 138 if (!hasPtxas()) 139 GTEST_SKIP() << "PTXAS compiler not found, skipping test."; 140 141 MLIRContext context(registry); 142 143 OwningOpRef<ModuleOp> module = 144 parseSourceString<ModuleOp>(moduleStr, &context); 145 ASSERT_TRUE(!!module); 146 147 // Create an NVVM target. 148 NVVM::NVVMTargetAttr target = NVVM::NVVMTargetAttr::get(&context); 149 150 // Serialize the module. 
151 auto serializer = dyn_cast<gpu::TargetAttrInterface>(target); 152 ASSERT_TRUE(!!serializer); 153 gpu::TargetOptions options("", {}, "", "", gpu::CompilationTarget::Binary); 154 for (auto gpuModule : (*module).getBody()->getOps<gpu::GPUModuleOp>()) { 155 std::optional<SmallVector<char, 0>> object = 156 serializer.serializeToObject(gpuModule, options); 157 // Check that the serializer was successful. 158 ASSERT_TRUE(object != std::nullopt); 159 ASSERT_TRUE(!object->empty()); 160 } 161 } 162 163 // Test callback functions invoked with LLVM IR and ISA. 164 TEST_F(MLIRTargetLLVMNVVM, 165 SKIP_WITHOUT_NVPTX(CallbackInvokedWithLLVMIRAndISA)) { 166 MLIRContext context(registry); 167 168 OwningOpRef<ModuleOp> module = 169 parseSourceString<ModuleOp>(moduleStr, &context); 170 ASSERT_TRUE(!!module); 171 172 NVVM::NVVMTargetAttr target = NVVM::NVVMTargetAttr::get(&context); 173 174 auto serializer = dyn_cast<gpu::TargetAttrInterface>(target); 175 ASSERT_TRUE(!!serializer); 176 177 std::string initialLLVMIR; 178 auto initialCallback = [&initialLLVMIR](llvm::Module &module) { 179 llvm::raw_string_ostream ros(initialLLVMIR); 180 module.print(ros, nullptr); 181 }; 182 183 std::string linkedLLVMIR; 184 auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) { 185 llvm::raw_string_ostream ros(linkedLLVMIR); 186 module.print(ros, nullptr); 187 }; 188 189 std::string optimizedLLVMIR; 190 auto optimizedCallback = [&optimizedLLVMIR](llvm::Module &module) { 191 llvm::raw_string_ostream ros(optimizedLLVMIR); 192 module.print(ros, nullptr); 193 }; 194 195 std::string isaResult; 196 auto isaCallback = [&isaResult](llvm::StringRef isa) { 197 isaResult = isa.str(); 198 }; 199 200 gpu::TargetOptions options({}, {}, {}, {}, gpu::CompilationTarget::Assembly, 201 {}, initialCallback, linkedCallback, 202 optimizedCallback, isaCallback); 203 204 for (auto gpuModule : (*module).getBody()->getOps<gpu::GPUModuleOp>()) { 205 std::optional<SmallVector<char, 0>> object = 206 
serializer.serializeToObject(gpuModule, options); 207 208 ASSERT_TRUE(object != std::nullopt); 209 ASSERT_TRUE(!object->empty()); 210 ASSERT_TRUE(!initialLLVMIR.empty()); 211 ASSERT_TRUE(!linkedLLVMIR.empty()); 212 ASSERT_TRUE(!optimizedLLVMIR.empty()); 213 ASSERT_TRUE(!isaResult.empty()); 214 215 initialLLVMIR.clear(); 216 linkedLLVMIR.clear(); 217 optimizedLLVMIR.clear(); 218 isaResult.clear(); 219 } 220 } 221 222 // Test linking LLVM IR from a resource attribute. 223 TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) { 224 MLIRContext context(registry); 225 std::string moduleStr = R"mlir( 226 gpu.module @nvvm_test { 227 llvm.func @bar() 228 llvm.func @nvvm_kernel(%arg0: f32) attributes {gpu.kernel, nvvm.kernel} { 229 llvm.call @bar() : () -> () 230 llvm.return 231 } 232 } 233 )mlir"; 234 // Provide the library to link as a serialized bitcode blob. 235 SmallVector<char> bitcodeToLink; 236 { 237 std::string linkedLib = R"llvm( 238 define void @bar() { 239 ret void 240 } 241 )llvm"; 242 llvm::SMDiagnostic err; 243 llvm::MemoryBufferRef buffer(linkedLib, "linkedLib"); 244 llvm::LLVMContext llvmCtx; 245 std::unique_ptr<llvm::Module> module = llvm::parseIR(buffer, err, llvmCtx); 246 ASSERT_TRUE(module) << " Can't parse IR: " << err.getMessage(); 247 { 248 llvm::raw_svector_ostream os(bitcodeToLink); 249 WriteBitcodeToFile(*module, os); 250 } 251 } 252 253 OwningOpRef<ModuleOp> module = 254 parseSourceString<ModuleOp>(moduleStr, &context); 255 ASSERT_TRUE(!!module); 256 Builder builder(&context); 257 258 NVVM::NVVMTargetAttr target = NVVM::NVVMTargetAttr::get(&context); 259 auto serializer = dyn_cast<gpu::TargetAttrInterface>(target); 260 261 // Hook to intercept the LLVM IR after linking external libs. 262 std::string linkedLLVMIR; 263 auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) { 264 llvm::raw_string_ostream ros(linkedLLVMIR); 265 module.print(ros, nullptr); 266 }; 267 268 // Store the bitcode as a DenseI8ArrayAttr. 
269 SmallVector<Attribute> librariesToLink; 270 librariesToLink.push_back(DenseI8ArrayAttr::get( 271 &context, 272 ArrayRef<int8_t>((int8_t *)bitcodeToLink.data(), bitcodeToLink.size()))); 273 gpu::TargetOptions options({}, librariesToLink, {}, {}, 274 gpu::CompilationTarget::Assembly, {}, {}, 275 linkedCallback); 276 for (auto gpuModule : (*module).getBody()->getOps<gpu::GPUModuleOp>()) { 277 std::optional<SmallVector<char, 0>> object = 278 serializer.serializeToObject(gpuModule, options); 279 280 // Verify that we correctly linked in the library: the external call is 281 // replaced by the definition. 282 ASSERT_TRUE(!linkedLLVMIR.empty()); 283 { 284 llvm::SMDiagnostic err; 285 llvm::MemoryBufferRef buffer(linkedLLVMIR, "linkedLLVMIR"); 286 llvm::LLVMContext llvmCtx; 287 std::unique_ptr<llvm::Module> module = 288 llvm::parseIR(buffer, err, llvmCtx); 289 ASSERT_TRUE(module) << " Can't parse linkedLLVMIR: " << err.getMessage() 290 << " IR: \n\b" << linkedLLVMIR; 291 llvm::Function *bar = module->getFunction("bar"); 292 ASSERT_TRUE(bar); 293 ASSERT_FALSE(bar->empty()); 294 } 295 ASSERT_TRUE(object != std::nullopt); 296 ASSERT_TRUE(!object->empty()); 297 } 298 } 299