xref: /llvm-project/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVVectorUnrolling.cpp (revision f83950ab8dfda1da882a6ef7b508639df251621a)
1 //===- TestSPIRVVectorUnrolling.cpp - Test SPIR-V vector unrolling -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===-------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arith/IR/Arith.h"
10 #include "mlir/Dialect/Func/IR/FuncOps.h"
11 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
12 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
13 #include "mlir/Dialect/Vector/IR/VectorOps.h"
14 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Pass/PassManager.h"
17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18 
19 namespace mlir {
20 namespace {
21 
22 struct TestSPIRVVectorUnrolling final
23     : PassWrapper<TestSPIRVVectorUnrolling, OperationPass<ModuleOp>> {
24   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSPIRVVectorUnrolling)
25 
26   StringRef getArgument() const final { return "test-spirv-vector-unrolling"; }
27 
28   StringRef getDescription() const final {
29     return "Test patterns that unroll vectors to types supported by SPIR-V";
30   }
31 
32   void getDependentDialects(DialectRegistry &registry) const override {
33     registry.insert<spirv::SPIRVDialect, vector::VectorDialect>();
34   }
35 
36   void runOnOperation() override {
37     Operation *op = getOperation();
38     (void)spirv::unrollVectorsInSignatures(op);
39     (void)spirv::unrollVectorsInFuncBodies(op);
40   }
41 };
42 
43 } // namespace
44 
45 namespace test {
46 void registerTestSPIRVVectorUnrolling() {
47   PassRegistration<TestSPIRVVectorUnrolling>();
48 }
49 } // namespace test
50 } // namespace mlir
51