xref: /llvm-project/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVFuncSignatureConversion.cpp (revision f83950ab8dfda1da882a6ef7b508639df251621a)
1 //===- TestSPIRVFuncSignatureConversion.cpp - Test signature conversion -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===-------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arith/IR/Arith.h"
10 #include "mlir/Dialect/Func/IR/FuncOps.h"
11 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
12 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
13 #include "mlir/Dialect/Vector/IR/VectorOps.h"
14 #include "mlir/Pass/Pass.h"
15 #include "mlir/Pass/PassManager.h"
16 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17 
18 namespace mlir {
19 namespace {
20 
21 struct TestSPIRVFuncSignatureConversion final
22     : PassWrapper<TestSPIRVFuncSignatureConversion, OperationPass<ModuleOp>> {
23   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSPIRVFuncSignatureConversion)
24 
25   StringRef getArgument() const final {
26     return "test-spirv-func-signature-conversion";
27   }
28 
29   StringRef getDescription() const final {
30     return "Test patterns that convert vector inputs and results in function "
31            "signatures";
32   }
33 
34   void getDependentDialects(DialectRegistry &registry) const override {
35     registry.insert<arith::ArithDialect, func::FuncDialect, spirv::SPIRVDialect,
36                     vector::VectorDialect>();
37   }
38 
39   void runOnOperation() override {
40     Operation *op = getOperation();
41     (void)spirv::unrollVectorsInSignatures(op);
42   }
43 };
44 
45 } // namespace
46 
47 namespace test {
48 void registerTestSPIRVFuncSignatureConversion() {
49   PassRegistration<TestSPIRVFuncSignatureConversion>();
50 }
51 } // namespace test
52 } // namespace mlir
53