xref: /llvm-project/mlir/test/lib/Pass/TestConvertToSPIRVPass.cpp (revision 25ae1a266d50f24a8fffc57152d7f3c3fcb65517)
//===- TestConvertToSPIRVPass.cpp - MLIR SPIR-V Conversion ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
10 #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
11 #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
12 #include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
13 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
14 #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
15 #include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
16 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
17 #include "mlir/Dialect/Arith/Transforms/Passes.h"
18 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
19 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
20 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
21 #include "mlir/Dialect/Vector/IR/VectorOps.h"
22 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
23 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Pass/Pass.h"
26 #include "mlir/Pass/PassOptions.h"
27 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
28 #include "mlir/Transforms/DialectConversion.h"
29 #include <memory>
30 
31 #define DEBUG_TYPE "test-convert-to-spirv"
32 
33 using namespace mlir;
34 
35 namespace {
36 
37 /// Map memRef memory space to SPIR-V storage class.
38 void mapToMemRef(Operation *op, spirv::TargetEnvAttr &targetAttr) {
39   spirv::TargetEnv targetEnv(targetAttr);
40   bool targetEnvSupportsKernelCapability =
41       targetEnv.allows(spirv::Capability::Kernel);
42   spirv::MemorySpaceToStorageClassMap memorySpaceMap =
43       targetEnvSupportsKernelCapability
44           ? spirv::mapMemorySpaceToOpenCLStorageClass
45           : spirv::mapMemorySpaceToVulkanStorageClass;
46   spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
47   spirv::convertMemRefTypesAndAttrs(op, converter);
48 }
49 
50 /// Populate patterns for each dialect.
51 void populateConvertToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
52                                     ScfToSPIRVContext &scfToSPIRVContext,
53                                     RewritePatternSet &patterns) {
54   arith::populateCeilFloorDivExpandOpsPatterns(patterns);
55   arith::populateArithToSPIRVPatterns(typeConverter, patterns);
56   populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
57   populateFuncToSPIRVPatterns(typeConverter, patterns);
58   populateGPUToSPIRVPatterns(typeConverter, patterns);
59   index::populateIndexToSPIRVPatterns(typeConverter, patterns);
60   populateMemRefToSPIRVPatterns(typeConverter, patterns);
61   populateVectorToSPIRVPatterns(typeConverter, patterns);
62   populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
63   ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
64 }
65 
66 /// A pass to perform the SPIR-V conversion.
67 struct TestConvertToSPIRVPass final
68     : PassWrapper<TestConvertToSPIRVPass, OperationPass<>> {
69   Option<bool> runSignatureConversion{
70       *this, "run-signature-conversion",
71       llvm::cl::desc(
72           "Run function signature conversion to convert vector types"),
73       llvm::cl::init(true)};
74   Option<bool> runVectorUnrolling{
75       *this, "run-vector-unrolling",
76       llvm::cl::desc(
77           "Run vector unrolling to convert vector types in function bodies"),
78       llvm::cl::init(true)};
79   Option<bool> convertGPUModules{
80       *this, "convert-gpu-modules",
81       llvm::cl::desc("Clone and convert GPU modules"), llvm::cl::init(false)};
82   Option<bool> nestInGPUModule{
83       *this, "nest-in-gpu-module",
84       llvm::cl::desc("Put converted SPIR-V module inside the gpu.module "
85                      "instead of alongside it."),
86       llvm::cl::init(false)};
87 
88   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertToSPIRVPass)
89 
90   StringRef getArgument() const final { return "test-convert-to-spirv"; }
91   StringRef getDescription() const final {
92     return "Conversion to SPIR-V pass only used for internal tests.";
93   }
94   void getDependentDialects(DialectRegistry &registry) const override {
95     registry.insert<spirv::SPIRVDialect>();
96     registry.insert<vector::VectorDialect>();
97   }
98 
99   TestConvertToSPIRVPass() = default;
100   TestConvertToSPIRVPass(bool convertGPUModules, bool nestInGPUModule) {
101     this->convertGPUModules = convertGPUModules;
102     this->nestInGPUModule = nestInGPUModule;
103   };
104   TestConvertToSPIRVPass(const TestConvertToSPIRVPass &) {}
105 
106   void runOnOperation() override {
107     Operation *op = getOperation();
108     MLIRContext *context = &getContext();
109 
110     // Unroll vectors in function signatures to native size.
111     if (runSignatureConversion && failed(spirv::unrollVectorsInSignatures(op)))
112       return signalPassFailure();
113 
114     // Unroll vectors in function bodies to native size.
115     if (runVectorUnrolling && failed(spirv::unrollVectorsInFuncBodies(op)))
116       return signalPassFailure();
117 
118     // Generic conversion.
119     if (!convertGPUModules) {
120       spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
121       std::unique_ptr<ConversionTarget> target =
122           SPIRVConversionTarget::get(targetAttr);
123       SPIRVTypeConverter typeConverter(targetAttr);
124       RewritePatternSet patterns(context);
125       ScfToSPIRVContext scfToSPIRVContext;
126       mapToMemRef(op, targetAttr);
127       populateConvertToSPIRVPatterns(typeConverter, scfToSPIRVContext,
128                                      patterns);
129       if (failed(applyPartialConversion(op, *target, std::move(patterns))))
130         return signalPassFailure();
131       return;
132     }
133 
134     // Clone each GPU kernel module for conversion, given that the GPU
135     // launch op still needs the original GPU kernel module.
136     SmallVector<Operation *, 1> gpuModules;
137     OpBuilder builder(context);
138     op->walk([&](gpu::GPUModuleOp gpuModule) {
139       if (nestInGPUModule)
140         builder.setInsertionPointToStart(gpuModule.getBody());
141       else
142         builder.setInsertionPoint(gpuModule);
143       gpuModules.push_back(builder.clone(*gpuModule));
144     });
145     // Run conversion for each module independently as they can have
146     // different TargetEnv attributes.
147     for (Operation *gpuModule : gpuModules) {
148       spirv::TargetEnvAttr targetAttr =
149           spirv::lookupTargetEnvOrDefault(gpuModule);
150       std::unique_ptr<ConversionTarget> target =
151           SPIRVConversionTarget::get(targetAttr);
152       SPIRVTypeConverter typeConverter(targetAttr);
153       RewritePatternSet patterns(context);
154       ScfToSPIRVContext scfToSPIRVContext;
155       mapToMemRef(gpuModule, targetAttr);
156       populateConvertToSPIRVPatterns(typeConverter, scfToSPIRVContext,
157                                      patterns);
158       if (failed(applyFullConversion(gpuModule, *target, std::move(patterns))))
159         return signalPassFailure();
160     }
161   }
162 };
163 
164 } // namespace
165 
166 namespace mlir::test {
167 void registerTestConvertToSPIRVPass() {
168   PassRegistration<TestConvertToSPIRVPass>();
169 }
170 std::unique_ptr<Pass> createTestConvertToSPIRVPass(bool convertGPUModules,
171                                                    bool nestInGPUModule) {
172   return std::make_unique<TestConvertToSPIRVPass>(convertGPUModules,
173                                                   nestInGPUModule);
174 }
175 } // namespace mlir::test
176