1 //===-- Passes.h - TOSA optimization pass declarations ----------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file declares the optimization passes for the TOSA Dialect in MLIR. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H 14 #define MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H 15 16 #include "mlir/Dialect/Tensor/IR/Tensor.h" 17 #include "mlir/Dialect/Tosa/Transforms/PassesEnums.h.inc" 18 #include "mlir/Pass/Pass.h" 19 20 namespace mlir { 21 class TypeConverter; 22 namespace tosa { 23 24 #define GEN_PASS_DECL 25 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" 26 27 // Expose Rewrite Functions that decompose TOSA Ops into further TOSA Ops. 28 // The rewrites can be selectively added to a conversion pass. 29 void populateTosaDecomposeConv2D(MLIRContext *ctx, RewritePatternSet &patterns); 30 void populateTosaDecomposeTransposeConv(MLIRContext *ctx, 31 RewritePatternSet &patterns); 32 void populateTosaDecomposeDepthwise(MLIRContext *ctx, 33 RewritePatternSet &patterns); 34 void populateTosaFoldConstantReciprocalPatterns(MLIRContext *ctx, 35 RewritePatternSet &patterns); 36 void populateTosaFoldConstantTransposePatterns(MLIRContext *ctx, 37 RewritePatternSet &patterns); 38 void populateTosaConstantReduction(MLIRContext *ctx, 39 RewritePatternSet &patterns, 40 bool aggressiveReduceConstant); 41 42 void populateTosaTypeConversion(TypeConverter &converter); 43 44 std::unique_ptr<Pass> createTosaLayerwiseConstantFoldPass(); 45 std::unique_ptr<Pass> createTosaLayerwiseConstantFoldPass( 46 const TosaLayerwiseConstantFoldPassOptions &options); 47 std::unique_ptr<Pass> createTosaInferShapesPass(); 48 std::unique_ptr<Pass> createTosaMakeBroadcastablePass(); 49 std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass(); 50 std::unique_ptr<Pass> createTosaOptionalDecompositions(); 51 52 struct ValidationOptions { 53 /// Validate if operations match for the given profile. 54 TosaProfileEnum profile = TosaProfileEnum::Undefined; setProfileValidationOptions55 ValidationOptions &setProfile(TosaProfileEnum profile) { 56 this->profile = profile; 57 return *this; 58 } 59 /// Verify if the properties of certain operations align the spec requirement. 60 bool strictOperationSpecAlignment = false; 61 ValidationOptions &enableStrictOperationSpecAlignment(bool enable = true) { 62 strictOperationSpecAlignment = enable; 63 return *this; 64 } 65 /// Validate if operator parameters are within specfication for the given 66 /// level. 67 TosaLevelEnum level = TosaLevelEnum::EightK; setLevelValidationOptions68 ValidationOptions &setLevel(TosaLevelEnum level) { 69 this->level = level; 70 return *this; 71 } 72 }; 73 74 #define GEN_PASS_REGISTRATION 75 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" 76 77 } // namespace tosa 78 } // namespace mlir 79 80 #endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H 81