xref: /llvm-project/mlir/test/lib/Pass/TestVulkanRunnerPipeline.cpp (revision 25ae1a266d50f24a8fffc57152d7f3c3fcb65517)
197358730SAndrea Faulds //===------------------ TestVulkanRunnerPipeline.cpp --------------------===//
297358730SAndrea Faulds //
397358730SAndrea Faulds // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
497358730SAndrea Faulds // See https://llvm.org/LICENSE.txt for license information.
597358730SAndrea Faulds // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
697358730SAndrea Faulds //
797358730SAndrea Faulds //===----------------------------------------------------------------------===//
897358730SAndrea Faulds //
9e7e3c45bSAndrea Faulds // Implements a pipeline for use by Vulkan runner tests.
1097358730SAndrea Faulds //
1197358730SAndrea Faulds //===----------------------------------------------------------------------===//
1297358730SAndrea Faulds 
1399a562b3SAndrea Faulds #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
1497358730SAndrea Faulds #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
1599a562b3SAndrea Faulds #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
1699a562b3SAndrea Faulds #include "mlir/Dialect/Func/IR/FuncOps.h"
177724be97SAndrea Faulds #include "mlir/Dialect/GPU/IR/GPUDialect.h"
1897358730SAndrea Faulds #include "mlir/Dialect/GPU/Transforms/Passes.h"
1999a562b3SAndrea Faulds #include "mlir/Dialect/LLVMIR/Transforms/RequestCWrappers.h"
2097358730SAndrea Faulds #include "mlir/Dialect/MemRef/Transforms/Passes.h"
2197358730SAndrea Faulds #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
2297358730SAndrea Faulds #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
2397358730SAndrea Faulds #include "mlir/Pass/PassManager.h"
247724be97SAndrea Faulds #include "mlir/Pass/PassOptions.h"
2597358730SAndrea Faulds 
2697358730SAndrea Faulds using namespace mlir;
2797358730SAndrea Faulds 
28*25ae1a26SAndrea Faulds // Defined in the test directory, no public header.
29*25ae1a26SAndrea Faulds namespace mlir::test {
30*25ae1a26SAndrea Faulds std::unique_ptr<Pass> createTestConvertToSPIRVPass(bool convertGPUModules,
31*25ae1a26SAndrea Faulds                                                    bool nestInGPUModule);
32*25ae1a26SAndrea Faulds }
33*25ae1a26SAndrea Faulds 
3497358730SAndrea Faulds namespace {
3597358730SAndrea Faulds 
367724be97SAndrea Faulds struct VulkanRunnerPipelineOptions
377724be97SAndrea Faulds     : PassPipelineOptions<VulkanRunnerPipelineOptions> {
387724be97SAndrea Faulds   Option<bool> spirvWebGPUPrepare{
397724be97SAndrea Faulds       *this, "spirv-webgpu-prepare",
407724be97SAndrea Faulds       llvm::cl::desc("Run MLIR transforms used when targetting WebGPU")};
417724be97SAndrea Faulds };
427724be97SAndrea Faulds 
437724be97SAndrea Faulds void buildTestVulkanRunnerPipeline(OpPassManager &passManager,
447724be97SAndrea Faulds                                    const VulkanRunnerPipelineOptions &options) {
4597358730SAndrea Faulds   passManager.addPass(createGpuKernelOutliningPass());
4697358730SAndrea Faulds   passManager.addPass(memref::createFoldMemRefAliasOpsPass());
4797358730SAndrea Faulds 
487724be97SAndrea Faulds   GpuSPIRVAttachTargetOptions attachTargetOptions{};
497724be97SAndrea Faulds   attachTargetOptions.spirvVersion = "v1.0";
507724be97SAndrea Faulds   attachTargetOptions.spirvCapabilities.push_back("Shader");
517724be97SAndrea Faulds   attachTargetOptions.spirvExtensions.push_back(
527724be97SAndrea Faulds       "SPV_KHR_storage_buffer_storage_class");
537724be97SAndrea Faulds   passManager.addPass(createGpuSPIRVAttachTarget(attachTargetOptions));
547724be97SAndrea Faulds 
55*25ae1a26SAndrea Faulds   passManager.addPass(test::createTestConvertToSPIRVPass(
56*25ae1a26SAndrea Faulds       /*convertGPUModules=*/true, /*nestInGPUModule=*/true));
577724be97SAndrea Faulds 
587724be97SAndrea Faulds   OpPassManager &spirvModulePM =
597724be97SAndrea Faulds       passManager.nest<gpu::GPUModuleOp>().nest<spirv::ModuleOp>();
607724be97SAndrea Faulds   spirvModulePM.addPass(spirv::createSPIRVLowerABIAttributesPass());
617724be97SAndrea Faulds   spirvModulePM.addPass(spirv::createSPIRVUpdateVCEPass());
627724be97SAndrea Faulds   if (options.spirvWebGPUPrepare)
637724be97SAndrea Faulds     spirvModulePM.addPass(spirv::createSPIRVWebGPUPreparePass());
647724be97SAndrea Faulds 
657724be97SAndrea Faulds   passManager.addPass(createGpuModuleToBinaryPass());
6699a562b3SAndrea Faulds 
6799a562b3SAndrea Faulds   passManager.addPass(createFinalizeMemRefToLLVMConversionPass());
68e7e3c45bSAndrea Faulds   passManager.nest<func::FuncOp>().addPass(LLVM::createRequestCWrappersPass());
69e7e3c45bSAndrea Faulds   // VulkanRuntimeWrappers.cpp requires these calling convention options.
7099a562b3SAndrea Faulds   GpuToLLVMConversionPassOptions opt;
7199a562b3SAndrea Faulds   opt.hostBarePtrCallConv = false;
72733be4edSAndrea Faulds   opt.kernelBarePtrCallConv = true;
73733be4edSAndrea Faulds   opt.kernelIntersperseSizeCallConv = true;
7499a562b3SAndrea Faulds   passManager.addPass(createGpuToLLVMConversionPass(opt));
7599a562b3SAndrea Faulds }
7697358730SAndrea Faulds 
7797358730SAndrea Faulds } // namespace
7897358730SAndrea Faulds 
7997358730SAndrea Faulds namespace mlir::test {
8097358730SAndrea Faulds void registerTestVulkanRunnerPipeline() {
817724be97SAndrea Faulds   PassPipelineRegistration<VulkanRunnerPipelineOptions>(
8297358730SAndrea Faulds       "test-vulkan-runner-pipeline",
83e7e3c45bSAndrea Faulds       "Runs a series of passes intended for Vulkan runner tests. Lowers GPU "
84e7e3c45bSAndrea Faulds       "dialect to LLVM dialect for the host and to serialized Vulkan SPIR-V "
85e7e3c45bSAndrea Faulds       "for the device.",
8697358730SAndrea Faulds       buildTestVulkanRunnerPipeline);
8797358730SAndrea Faulds }
8897358730SAndrea Faulds } // namespace mlir::test
89