1 //===- TestSPIRVVectorUnrolling.cpp - Test signature conversion -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===-------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arith/IR/Arith.h" 10 #include "mlir/Dialect/Func/IR/FuncOps.h" 11 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 12 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 13 #include "mlir/Dialect/Vector/IR/VectorOps.h" 14 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 15 #include "mlir/Pass/Pass.h" 16 #include "mlir/Pass/PassManager.h" 17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 18 19 namespace mlir { 20 namespace { 21 22 struct TestSPIRVVectorUnrolling final 23 : PassWrapper<TestSPIRVVectorUnrolling, OperationPass<ModuleOp>> { 24 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSPIRVVectorUnrolling) 25 26 StringRef getArgument() const final { return "test-spirv-vector-unrolling"; } 27 28 StringRef getDescription() const final { 29 return "Test patterns that unroll vectors to types supported by SPIR-V"; 30 } 31 32 void getDependentDialects(DialectRegistry ®istry) const override { 33 registry.insert<spirv::SPIRVDialect, vector::VectorDialect>(); 34 } 35 36 void runOnOperation() override { 37 Operation *op = getOperation(); 38 (void)spirv::unrollVectorsInSignatures(op); 39 (void)spirv::unrollVectorsInFuncBodies(op); 40 } 41 }; 42 43 } // namespace 44 45 namespace test { 46 void registerTestSPIRVVectorUnrolling() { 47 PassRegistration<TestSPIRVVectorUnrolling>(); 48 } 49 } // namespace test 50 } // namespace mlir 51