xref: /llvm-project/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVVectorUnrolling.cpp (revision f83950ab8dfda1da882a6ef7b508639df251621a)
//===- TestSPIRVVectorUnrolling.cpp - Test SPIR-V vector unrolling --------===//
2*f83950abSAngel Zhang //
3*f83950abSAngel Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*f83950abSAngel Zhang // See https://llvm.org/LICENSE.txt for license information.
5*f83950abSAngel Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*f83950abSAngel Zhang //
7*f83950abSAngel Zhang //===-------------------------------------------------------------------===//
8*f83950abSAngel Zhang 
9*f83950abSAngel Zhang #include "mlir/Dialect/Arith/IR/Arith.h"
10*f83950abSAngel Zhang #include "mlir/Dialect/Func/IR/FuncOps.h"
11*f83950abSAngel Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
12*f83950abSAngel Zhang #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
13*f83950abSAngel Zhang #include "mlir/Dialect/Vector/IR/VectorOps.h"
14*f83950abSAngel Zhang #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
15*f83950abSAngel Zhang #include "mlir/Pass/Pass.h"
16*f83950abSAngel Zhang #include "mlir/Pass/PassManager.h"
17*f83950abSAngel Zhang #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18*f83950abSAngel Zhang 
19*f83950abSAngel Zhang namespace mlir {
20*f83950abSAngel Zhang namespace {
21*f83950abSAngel Zhang 
22*f83950abSAngel Zhang struct TestSPIRVVectorUnrolling final
23*f83950abSAngel Zhang     : PassWrapper<TestSPIRVVectorUnrolling, OperationPass<ModuleOp>> {
24*f83950abSAngel Zhang   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSPIRVVectorUnrolling)
25*f83950abSAngel Zhang 
26*f83950abSAngel Zhang   StringRef getArgument() const final { return "test-spirv-vector-unrolling"; }
27*f83950abSAngel Zhang 
28*f83950abSAngel Zhang   StringRef getDescription() const final {
29*f83950abSAngel Zhang     return "Test patterns that unroll vectors to types supported by SPIR-V";
30*f83950abSAngel Zhang   }
31*f83950abSAngel Zhang 
32*f83950abSAngel Zhang   void getDependentDialects(DialectRegistry &registry) const override {
33*f83950abSAngel Zhang     registry.insert<spirv::SPIRVDialect, vector::VectorDialect>();
34*f83950abSAngel Zhang   }
35*f83950abSAngel Zhang 
36*f83950abSAngel Zhang   void runOnOperation() override {
37*f83950abSAngel Zhang     Operation *op = getOperation();
38*f83950abSAngel Zhang     (void)spirv::unrollVectorsInSignatures(op);
39*f83950abSAngel Zhang     (void)spirv::unrollVectorsInFuncBodies(op);
40*f83950abSAngel Zhang   }
41*f83950abSAngel Zhang };
42*f83950abSAngel Zhang 
43*f83950abSAngel Zhang } // namespace
44*f83950abSAngel Zhang 
45*f83950abSAngel Zhang namespace test {
46*f83950abSAngel Zhang void registerTestSPIRVVectorUnrolling() {
47*f83950abSAngel Zhang   PassRegistration<TestSPIRVVectorUnrolling>();
48*f83950abSAngel Zhang }
49*f83950abSAngel Zhang } // namespace test
50*f83950abSAngel Zhang } // namespace mlir
51