xref: /llvm-project/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //===- TosaOptionalDecompositions.cpp -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Pass that applies the optional TOSA operation decompositions exposed as
// populate functions in
// include/mlir/Dialect/Tosa/Transforms/Passes.h
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
16 
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 
22 namespace mlir {
23 namespace tosa {
24 #define GEN_PASS_DEF_TOSAOPTIONALDECOMPOSITIONS
25 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
26 } // namespace tosa
27 } // namespace mlir
28 
29 using namespace mlir;
30 
31 namespace {
32 
33 struct TosaOptionalDecompositions
34     : public tosa::impl::TosaOptionalDecompositionsBase<
35           TosaOptionalDecompositions> {
36   void runOnOperation() override {
37     auto *ctx = &getContext();
38     RewritePatternSet patterns(ctx);
39     auto func = getOperation();
40 
41     mlir::tosa::populateTosaDecomposeConv2D(ctx, patterns);
42     mlir::tosa::populateTosaDecomposeTransposeConv(ctx, patterns);
43     mlir::tosa::populateTosaDecomposeDepthwise(ctx, patterns);
44 
45     if (applyPatternsGreedily(func, std::move(patterns)).failed())
46       signalPassFailure();
47   }
48 };
49 
50 } // namespace
51 
52 std::unique_ptr<Pass> mlir::tosa::createTosaOptionalDecompositions() {
53   return std::make_unique<TosaOptionalDecompositions>();
54 }
55