xref: /llvm-project/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h (revision af22e274e9c5643780f25066442e05b5bd453328)
1 //===-- Passes.h - TOSA optimization pass declarations ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file declares the optimization passes for the TOSA Dialect in MLIR.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H
14 #define MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H
15 
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/Dialect/Tosa/Transforms/PassesEnums.h.inc"
18 #include "mlir/Pass/Pass.h"
19 
20 namespace mlir {
21 class TypeConverter;
22 namespace tosa {
23 
24 #define GEN_PASS_DECL
25 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
26 
27 // Expose Rewrite Functions that decompose TOSA Ops into further TOSA Ops.
28 // The rewrites can be selectively added to a conversion pass.
29 void populateTosaDecomposeConv2D(MLIRContext *ctx, RewritePatternSet &patterns);
30 void populateTosaDecomposeTransposeConv(MLIRContext *ctx,
31                                         RewritePatternSet &patterns);
32 void populateTosaDecomposeDepthwise(MLIRContext *ctx,
33                                     RewritePatternSet &patterns);
34 void populateTosaFoldConstantReciprocalPatterns(MLIRContext *ctx,
35                                                 RewritePatternSet &patterns);
36 void populateTosaFoldConstantTransposePatterns(MLIRContext *ctx,
37                                                RewritePatternSet &patterns);
38 void populateTosaConstantReduction(MLIRContext *ctx,
39                                    RewritePatternSet &patterns,
40                                    bool aggressiveReduceConstant);
41 
42 void populateTosaTypeConversion(TypeConverter &converter);
43 
44 std::unique_ptr<Pass> createTosaLayerwiseConstantFoldPass();
45 std::unique_ptr<Pass> createTosaLayerwiseConstantFoldPass(
46     const TosaLayerwiseConstantFoldPassOptions &options);
47 std::unique_ptr<Pass> createTosaInferShapesPass();
48 std::unique_ptr<Pass> createTosaMakeBroadcastablePass();
49 std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();
50 std::unique_ptr<Pass> createTosaOptionalDecompositions();
51 
52 struct ValidationOptions {
53   /// Validate if operations match for the given profile.
54   TosaProfileEnum profile = TosaProfileEnum::Undefined;
setProfileValidationOptions55   ValidationOptions &setProfile(TosaProfileEnum profile) {
56     this->profile = profile;
57     return *this;
58   }
59   /// Verify if the properties of certain operations align the spec requirement.
60   bool strictOperationSpecAlignment = false;
61   ValidationOptions &enableStrictOperationSpecAlignment(bool enable = true) {
62     strictOperationSpecAlignment = enable;
63     return *this;
64   }
65   /// Validate if operator parameters are within specfication for the given
66   /// level.
67   TosaLevelEnum level = TosaLevelEnum::EightK;
setLevelValidationOptions68   ValidationOptions &setLevel(TosaLevelEnum level) {
69     this->level = level;
70     return *this;
71   }
72 };
73 
74 #define GEN_PASS_REGISTRATION
75 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
76 
77 } // namespace tosa
78 } // namespace mlir
79 
80 #endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H
81