1b194ef69SJakub Kuderski //===- TestVectorReductionToSPIRVDotProd.cpp - Test reduction to dot prod -===// 2b194ef69SJakub Kuderski // 3b194ef69SJakub Kuderski // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4b194ef69SJakub Kuderski // See https://llvm.org/LICENSE.txt for license information. 5b194ef69SJakub Kuderski // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6b194ef69SJakub Kuderski // 7b194ef69SJakub Kuderski //===----------------------------------------------------------------------===// 8b194ef69SJakub Kuderski 9b194ef69SJakub Kuderski #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h" 10b194ef69SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 11b194ef69SJakub Kuderski #include "mlir/Dialect/Func/IR/FuncOps.h" 12b194ef69SJakub Kuderski #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 13b194ef69SJakub Kuderski #include "mlir/Dialect/Vector/IR/VectorOps.h" 14b194ef69SJakub Kuderski #include "mlir/Pass/Pass.h" 15b194ef69SJakub Kuderski #include "mlir/Pass/PassManager.h" 16b194ef69SJakub Kuderski #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 17b194ef69SJakub Kuderski 18b194ef69SJakub Kuderski namespace mlir { 19b194ef69SJakub Kuderski namespace { 20b194ef69SJakub Kuderski 21b194ef69SJakub Kuderski struct TestVectorReductionToSPIRVDotProd 22b194ef69SJakub Kuderski : PassWrapper<TestVectorReductionToSPIRVDotProd, 23b194ef69SJakub Kuderski OperationPass<func::FuncOp>> { 24b194ef69SJakub Kuderski MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 25b194ef69SJakub Kuderski TestVectorReductionToSPIRVDotProd) 26b194ef69SJakub Kuderski 27b194ef69SJakub Kuderski StringRef getArgument() const final { 28b194ef69SJakub Kuderski return "test-vector-reduction-to-spirv-dot-prod"; 29b194ef69SJakub Kuderski } 30b194ef69SJakub Kuderski 31b194ef69SJakub Kuderski StringRef getDescription() const final { 32b194ef69SJakub Kuderski return "Test lowering patterns that converts vector.reduction to SPIR-V " 33b194ef69SJakub Kuderski "integer dot product ops"; 34b194ef69SJakub Kuderski } 35b194ef69SJakub Kuderski 36b194ef69SJakub Kuderski void getDependentDialects(DialectRegistry ®istry) const override { 37b194ef69SJakub Kuderski registry.insert<arith::ArithDialect, func::FuncDialect, spirv::SPIRVDialect, 38b194ef69SJakub Kuderski vector::VectorDialect>(); 39b194ef69SJakub Kuderski } 40b194ef69SJakub Kuderski 41b194ef69SJakub Kuderski void runOnOperation() override { 42b194ef69SJakub Kuderski RewritePatternSet patterns(&getContext()); 43b194ef69SJakub Kuderski populateVectorReductionToSPIRVDotProductPatterns(patterns); 44*09dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 45b194ef69SJakub Kuderski } 46b194ef69SJakub Kuderski }; 47b194ef69SJakub Kuderski 48b194ef69SJakub Kuderski } // namespace 49b194ef69SJakub Kuderski 50b194ef69SJakub Kuderski namespace test { 51b194ef69SJakub Kuderski void registerTestVectorReductionToSPIRVDotProd() { 52b194ef69SJakub Kuderski PassRegistration<TestVectorReductionToSPIRVDotProd>(); 53b194ef69SJakub Kuderski } 54b194ef69SJakub Kuderski } // namespace test 55b194ef69SJakub Kuderski } // namespace mlir 56