xref: /llvm-project/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVFuncSignatureConversion.cpp (revision f83950ab8dfda1da882a6ef7b508639df251621a)
16867e49fSAngel Zhang //===- TestSPIRVFuncSignatureConversion.cpp - Test signature conversion -===//
26867e49fSAngel Zhang //
36867e49fSAngel Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
46867e49fSAngel Zhang // See https://llvm.org/LICENSE.txt for license information.
56867e49fSAngel Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66867e49fSAngel Zhang //
76867e49fSAngel Zhang //===-------------------------------------------------------------------===//
86867e49fSAngel Zhang 
96867e49fSAngel Zhang #include "mlir/Dialect/Arith/IR/Arith.h"
106867e49fSAngel Zhang #include "mlir/Dialect/Func/IR/FuncOps.h"
116867e49fSAngel Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
126867e49fSAngel Zhang #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
136867e49fSAngel Zhang #include "mlir/Dialect/Vector/IR/VectorOps.h"
146867e49fSAngel Zhang #include "mlir/Pass/Pass.h"
156867e49fSAngel Zhang #include "mlir/Pass/PassManager.h"
166867e49fSAngel Zhang #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
176867e49fSAngel Zhang 
186867e49fSAngel Zhang namespace mlir {
196867e49fSAngel Zhang namespace {
206867e49fSAngel Zhang 
216867e49fSAngel Zhang struct TestSPIRVFuncSignatureConversion final
226867e49fSAngel Zhang     : PassWrapper<TestSPIRVFuncSignatureConversion, OperationPass<ModuleOp>> {
236867e49fSAngel Zhang   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSPIRVFuncSignatureConversion)
246867e49fSAngel Zhang 
256867e49fSAngel Zhang   StringRef getArgument() const final {
266867e49fSAngel Zhang     return "test-spirv-func-signature-conversion";
276867e49fSAngel Zhang   }
286867e49fSAngel Zhang 
296867e49fSAngel Zhang   StringRef getDescription() const final {
306867e49fSAngel Zhang     return "Test patterns that convert vector inputs and results in function "
316867e49fSAngel Zhang            "signatures";
326867e49fSAngel Zhang   }
336867e49fSAngel Zhang 
346867e49fSAngel Zhang   void getDependentDialects(DialectRegistry &registry) const override {
356867e49fSAngel Zhang     registry.insert<arith::ArithDialect, func::FuncDialect, spirv::SPIRVDialect,
366867e49fSAngel Zhang                     vector::VectorDialect>();
376867e49fSAngel Zhang   }
386867e49fSAngel Zhang 
396867e49fSAngel Zhang   void runOnOperation() override {
40*f83950abSAngel Zhang     Operation *op = getOperation();
41*f83950abSAngel Zhang     (void)spirv::unrollVectorsInSignatures(op);
426867e49fSAngel Zhang   }
436867e49fSAngel Zhang };
446867e49fSAngel Zhang 
456867e49fSAngel Zhang } // namespace
466867e49fSAngel Zhang 
476867e49fSAngel Zhang namespace test {
486867e49fSAngel Zhang void registerTestSPIRVFuncSignatureConversion() {
496867e49fSAngel Zhang   PassRegistration<TestSPIRVFuncSignatureConversion>();
506867e49fSAngel Zhang }
516867e49fSAngel Zhang } // namespace test
526867e49fSAngel Zhang } // namespace mlir
53