xref: /llvm-project/mlir/test/lib/Conversion/VectorToSPIRV/TestVectorReductionToSPIRVDotProd.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1b194ef69SJakub Kuderski //===- TestVectorReductionToSPIRVDotProd.cpp - Test reduction to dot prod -===//
2b194ef69SJakub Kuderski //
3b194ef69SJakub Kuderski // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4b194ef69SJakub Kuderski // See https://llvm.org/LICENSE.txt for license information.
5b194ef69SJakub Kuderski // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6b194ef69SJakub Kuderski //
7b194ef69SJakub Kuderski //===----------------------------------------------------------------------===//
8b194ef69SJakub Kuderski 
9b194ef69SJakub Kuderski #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
10b194ef69SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
11b194ef69SJakub Kuderski #include "mlir/Dialect/Func/IR/FuncOps.h"
12b194ef69SJakub Kuderski #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
13b194ef69SJakub Kuderski #include "mlir/Dialect/Vector/IR/VectorOps.h"
14b194ef69SJakub Kuderski #include "mlir/Pass/Pass.h"
15b194ef69SJakub Kuderski #include "mlir/Pass/PassManager.h"
16b194ef69SJakub Kuderski #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17b194ef69SJakub Kuderski 
18b194ef69SJakub Kuderski namespace mlir {
19b194ef69SJakub Kuderski namespace {
20b194ef69SJakub Kuderski 
21b194ef69SJakub Kuderski struct TestVectorReductionToSPIRVDotProd
22b194ef69SJakub Kuderski     : PassWrapper<TestVectorReductionToSPIRVDotProd,
23b194ef69SJakub Kuderski                   OperationPass<func::FuncOp>> {
24b194ef69SJakub Kuderski   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
25b194ef69SJakub Kuderski       TestVectorReductionToSPIRVDotProd)
26b194ef69SJakub Kuderski 
27b194ef69SJakub Kuderski   StringRef getArgument() const final {
28b194ef69SJakub Kuderski     return "test-vector-reduction-to-spirv-dot-prod";
29b194ef69SJakub Kuderski   }
30b194ef69SJakub Kuderski 
31b194ef69SJakub Kuderski   StringRef getDescription() const final {
32b194ef69SJakub Kuderski     return "Test lowering patterns that converts vector.reduction to SPIR-V "
33b194ef69SJakub Kuderski            "integer dot product ops";
34b194ef69SJakub Kuderski   }
35b194ef69SJakub Kuderski 
36b194ef69SJakub Kuderski   void getDependentDialects(DialectRegistry &registry) const override {
37b194ef69SJakub Kuderski     registry.insert<arith::ArithDialect, func::FuncDialect, spirv::SPIRVDialect,
38b194ef69SJakub Kuderski                     vector::VectorDialect>();
39b194ef69SJakub Kuderski   }
40b194ef69SJakub Kuderski 
41b194ef69SJakub Kuderski   void runOnOperation() override {
42b194ef69SJakub Kuderski     RewritePatternSet patterns(&getContext());
43b194ef69SJakub Kuderski     populateVectorReductionToSPIRVDotProductPatterns(patterns);
44*09dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
45b194ef69SJakub Kuderski   }
46b194ef69SJakub Kuderski };
47b194ef69SJakub Kuderski 
48b194ef69SJakub Kuderski } // namespace
49b194ef69SJakub Kuderski 
50b194ef69SJakub Kuderski namespace test {
51b194ef69SJakub Kuderski void registerTestVectorReductionToSPIRVDotProd() {
52b194ef69SJakub Kuderski   PassRegistration<TestVectorReductionToSPIRVDotProd>();
53b194ef69SJakub Kuderski }
54b194ef69SJakub Kuderski } // namespace test
55b194ef69SJakub Kuderski } // namespace mlir
56