1 //===- TestSPIRVFuncSignatureConversion.cpp - Test signature conversion -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===-------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arith/IR/Arith.h" 10 #include "mlir/Dialect/Func/IR/FuncOps.h" 11 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 12 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 13 #include "mlir/Dialect/Vector/IR/VectorOps.h" 14 #include "mlir/Pass/Pass.h" 15 #include "mlir/Pass/PassManager.h" 16 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 17 18 namespace mlir { 19 namespace { 20 21 struct TestSPIRVFuncSignatureConversion final 22 : PassWrapper<TestSPIRVFuncSignatureConversion, OperationPass<ModuleOp>> { 23 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSPIRVFuncSignatureConversion) 24 25 StringRef getArgument() const final { 26 return "test-spirv-func-signature-conversion"; 27 } 28 29 StringRef getDescription() const final { 30 return "Test patterns that convert vector inputs and results in function " 31 "signatures"; 32 } 33 34 void getDependentDialects(DialectRegistry ®istry) const override { 35 registry.insert<arith::ArithDialect, func::FuncDialect, spirv::SPIRVDialect, 36 vector::VectorDialect>(); 37 } 38 39 void runOnOperation() override { 40 Operation *op = getOperation(); 41 (void)spirv::unrollVectorsInSignatures(op); 42 } 43 }; 44 45 } // namespace 46 47 namespace test { 48 void registerTestSPIRVFuncSignatureConversion() { 49 PassRegistration<TestSPIRVFuncSignatureConversion>(); 50 } 51 } // namespace test 52 } // namespace mlir 53