xref: /llvm-project/mlir/test/lib/Pass/TestVulkanRunnerPipeline.cpp (revision 25ae1a266d50f24a8fffc57152d7f3c3fcb65517)
1 //===------------------ TestVulkanRunnerPipeline.cpp --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Implements a pipeline for use by Vulkan runner tests.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
14 #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
15 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
18 #include "mlir/Dialect/GPU/Transforms/Passes.h"
19 #include "mlir/Dialect/LLVMIR/Transforms/RequestCWrappers.h"
20 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
21 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
22 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
23 #include "mlir/Pass/PassManager.h"
24 #include "mlir/Pass/PassOptions.h"
25 
26 using namespace mlir;
27 
28 // Defined in the test directory, no public header.
29 namespace mlir::test {
30 std::unique_ptr<Pass> createTestConvertToSPIRVPass(bool convertGPUModules,
31                                                    bool nestInGPUModule);
32 }
33 
34 namespace {
35 
36 struct VulkanRunnerPipelineOptions
37     : PassPipelineOptions<VulkanRunnerPipelineOptions> {
38   Option<bool> spirvWebGPUPrepare{
39       *this, "spirv-webgpu-prepare",
40       llvm::cl::desc("Run MLIR transforms used when targetting WebGPU")};
41 };
42 
43 void buildTestVulkanRunnerPipeline(OpPassManager &passManager,
44                                    const VulkanRunnerPipelineOptions &options) {
45   passManager.addPass(createGpuKernelOutliningPass());
46   passManager.addPass(memref::createFoldMemRefAliasOpsPass());
47 
48   GpuSPIRVAttachTargetOptions attachTargetOptions{};
49   attachTargetOptions.spirvVersion = "v1.0";
50   attachTargetOptions.spirvCapabilities.push_back("Shader");
51   attachTargetOptions.spirvExtensions.push_back(
52       "SPV_KHR_storage_buffer_storage_class");
53   passManager.addPass(createGpuSPIRVAttachTarget(attachTargetOptions));
54 
55   passManager.addPass(test::createTestConvertToSPIRVPass(
56       /*convertGPUModules=*/true, /*nestInGPUModule=*/true));
57 
58   OpPassManager &spirvModulePM =
59       passManager.nest<gpu::GPUModuleOp>().nest<spirv::ModuleOp>();
60   spirvModulePM.addPass(spirv::createSPIRVLowerABIAttributesPass());
61   spirvModulePM.addPass(spirv::createSPIRVUpdateVCEPass());
62   if (options.spirvWebGPUPrepare)
63     spirvModulePM.addPass(spirv::createSPIRVWebGPUPreparePass());
64 
65   passManager.addPass(createGpuModuleToBinaryPass());
66 
67   passManager.addPass(createFinalizeMemRefToLLVMConversionPass());
68   passManager.nest<func::FuncOp>().addPass(LLVM::createRequestCWrappersPass());
69   // VulkanRuntimeWrappers.cpp requires these calling convention options.
70   GpuToLLVMConversionPassOptions opt;
71   opt.hostBarePtrCallConv = false;
72   opt.kernelBarePtrCallConv = true;
73   opt.kernelIntersperseSizeCallConv = true;
74   passManager.addPass(createGpuToLLVMConversionPass(opt));
75 }
76 
77 } // namespace
78 
79 namespace mlir::test {
80 void registerTestVulkanRunnerPipeline() {
81   PassPipelineRegistration<VulkanRunnerPipelineOptions>(
82       "test-vulkan-runner-pipeline",
83       "Runs a series of passes intended for Vulkan runner tests. Lowers GPU "
84       "dialect to LLVM dialect for the host and to serialized Vulkan SPIR-V "
85       "for the device.",
86       buildTestVulkanRunnerPipeline);
87 }
88 } // namespace mlir::test
89