1 //===- ConvertToSPIRVPass.cpp - MLIR SPIR-V Conversion --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h" 10 #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h" 11 #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h" 12 #include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h" 13 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h" 14 #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h" 15 #include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h" 16 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h" 17 #include "mlir/Dialect/Arith/Transforms/Passes.h" 18 #include "mlir/Dialect/GPU/IR/GPUDialect.h" 19 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" 20 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 21 #include "mlir/Dialect/Vector/IR/VectorOps.h" 22 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 23 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 24 #include "mlir/IR/PatternMatch.h" 25 #include "mlir/Pass/Pass.h" 26 #include "mlir/Pass/PassOptions.h" 27 #include "mlir/Rewrite/FrozenRewritePatternSet.h" 28 #include "mlir/Transforms/DialectConversion.h" 29 #include <memory> 30 31 #define DEBUG_TYPE "test-convert-to-spirv" 32 33 using namespace mlir; 34 35 namespace { 36 37 /// Map memRef memory space to SPIR-V storage class. 38 void mapToMemRef(Operation *op, spirv::TargetEnvAttr &targetAttr) { 39 spirv::TargetEnv targetEnv(targetAttr); 40 bool targetEnvSupportsKernelCapability = 41 targetEnv.allows(spirv::Capability::Kernel); 42 spirv::MemorySpaceToStorageClassMap memorySpaceMap = 43 targetEnvSupportsKernelCapability 44 ? spirv::mapMemorySpaceToOpenCLStorageClass 45 : spirv::mapMemorySpaceToVulkanStorageClass; 46 spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap); 47 spirv::convertMemRefTypesAndAttrs(op, converter); 48 } 49 50 /// Populate patterns for each dialect. 51 void populateConvertToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, 52 ScfToSPIRVContext &scfToSPIRVContext, 53 RewritePatternSet &patterns) { 54 arith::populateCeilFloorDivExpandOpsPatterns(patterns); 55 arith::populateArithToSPIRVPatterns(typeConverter, patterns); 56 populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns); 57 populateFuncToSPIRVPatterns(typeConverter, patterns); 58 populateGPUToSPIRVPatterns(typeConverter, patterns); 59 index::populateIndexToSPIRVPatterns(typeConverter, patterns); 60 populateMemRefToSPIRVPatterns(typeConverter, patterns); 61 populateVectorToSPIRVPatterns(typeConverter, patterns); 62 populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns); 63 ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns); 64 } 65 66 /// A pass to perform the SPIR-V conversion. 67 struct TestConvertToSPIRVPass final 68 : PassWrapper<TestConvertToSPIRVPass, OperationPass<>> { 69 Option<bool> runSignatureConversion{ 70 *this, "run-signature-conversion", 71 llvm::cl::desc( 72 "Run function signature conversion to convert vector types"), 73 llvm::cl::init(true)}; 74 Option<bool> runVectorUnrolling{ 75 *this, "run-vector-unrolling", 76 llvm::cl::desc( 77 "Run vector unrolling to convert vector types in function bodies"), 78 llvm::cl::init(true)}; 79 Option<bool> convertGPUModules{ 80 *this, "convert-gpu-modules", 81 llvm::cl::desc("Clone and convert GPU modules"), llvm::cl::init(false)}; 82 Option<bool> nestInGPUModule{ 83 *this, "nest-in-gpu-module", 84 llvm::cl::desc("Put converted SPIR-V module inside the gpu.module " 85 "instead of alongside it."), 86 llvm::cl::init(false)}; 87 88 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertToSPIRVPass) 89 90 StringRef getArgument() const final { return "test-convert-to-spirv"; } 91 StringRef getDescription() const final { 92 return "Conversion to SPIR-V pass only used for internal tests."; 93 } 94 void getDependentDialects(DialectRegistry ®istry) const override { 95 registry.insert<spirv::SPIRVDialect>(); 96 registry.insert<vector::VectorDialect>(); 97 } 98 99 TestConvertToSPIRVPass() = default; 100 TestConvertToSPIRVPass(bool convertGPUModules, bool nestInGPUModule) { 101 this->convertGPUModules = convertGPUModules; 102 this->nestInGPUModule = nestInGPUModule; 103 }; 104 TestConvertToSPIRVPass(const TestConvertToSPIRVPass &) {} 105 106 void runOnOperation() override { 107 Operation *op = getOperation(); 108 MLIRContext *context = &getContext(); 109 110 // Unroll vectors in function signatures to native size. 111 if (runSignatureConversion && failed(spirv::unrollVectorsInSignatures(op))) 112 return signalPassFailure(); 113 114 // Unroll vectors in function bodies to native size. 115 if (runVectorUnrolling && failed(spirv::unrollVectorsInFuncBodies(op))) 116 return signalPassFailure(); 117 118 // Generic conversion. 119 if (!convertGPUModules) { 120 spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op); 121 std::unique_ptr<ConversionTarget> target = 122 SPIRVConversionTarget::get(targetAttr); 123 SPIRVTypeConverter typeConverter(targetAttr); 124 RewritePatternSet patterns(context); 125 ScfToSPIRVContext scfToSPIRVContext; 126 mapToMemRef(op, targetAttr); 127 populateConvertToSPIRVPatterns(typeConverter, scfToSPIRVContext, 128 patterns); 129 if (failed(applyPartialConversion(op, *target, std::move(patterns)))) 130 return signalPassFailure(); 131 return; 132 } 133 134 // Clone each GPU kernel module for conversion, given that the GPU 135 // launch op still needs the original GPU kernel module. 136 SmallVector<Operation *, 1> gpuModules; 137 OpBuilder builder(context); 138 op->walk([&](gpu::GPUModuleOp gpuModule) { 139 if (nestInGPUModule) 140 builder.setInsertionPointToStart(gpuModule.getBody()); 141 else 142 builder.setInsertionPoint(gpuModule); 143 gpuModules.push_back(builder.clone(*gpuModule)); 144 }); 145 // Run conversion for each module independently as they can have 146 // different TargetEnv attributes. 147 for (Operation *gpuModule : gpuModules) { 148 spirv::TargetEnvAttr targetAttr = 149 spirv::lookupTargetEnvOrDefault(gpuModule); 150 std::unique_ptr<ConversionTarget> target = 151 SPIRVConversionTarget::get(targetAttr); 152 SPIRVTypeConverter typeConverter(targetAttr); 153 RewritePatternSet patterns(context); 154 ScfToSPIRVContext scfToSPIRVContext; 155 mapToMemRef(gpuModule, targetAttr); 156 populateConvertToSPIRVPatterns(typeConverter, scfToSPIRVContext, 157 patterns); 158 if (failed(applyFullConversion(gpuModule, *target, std::move(patterns)))) 159 return signalPassFailure(); 160 } 161 } 162 }; 163 164 } // namespace 165 166 namespace mlir::test { 167 void registerTestConvertToSPIRVPass() { 168 PassRegistration<TestConvertToSPIRVPass>(); 169 } 170 std::unique_ptr<Pass> createTestConvertToSPIRVPass(bool convertGPUModules, 171 bool nestInGPUModule) { 172 return std::make_unique<TestConvertToSPIRVPass>(convertGPUModules, 173 nestInGPUModule); 174 } 175 } // namespace mlir::test 176