1 //===- TosaOptionalDecompositions.cpp -------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Pass to apply the Tosa operations decompositions 10 // exposed as populate functions in 11 // include/mlir/Dialect/Tosa/Transforms/Passes.h 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "mlir/Dialect/Tosa/Transforms/Passes.h" 16 17 #include "mlir/Dialect/Func/IR/FuncOps.h" 18 #include "mlir/Dialect/Tosa/IR/TosaOps.h" 19 #include "mlir/Pass/Pass.h" 20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 21 22 namespace mlir { 23 namespace tosa { 24 #define GEN_PASS_DEF_TOSAOPTIONALDECOMPOSITIONS 25 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" 26 } // namespace tosa 27 } // namespace mlir 28 29 using namespace mlir; 30 31 namespace { 32 33 struct TosaOptionalDecompositions 34 : public tosa::impl::TosaOptionalDecompositionsBase< 35 TosaOptionalDecompositions> { 36 void runOnOperation() override { 37 auto *ctx = &getContext(); 38 RewritePatternSet patterns(ctx); 39 auto func = getOperation(); 40 41 mlir::tosa::populateTosaDecomposeConv2D(ctx, patterns); 42 mlir::tosa::populateTosaDecomposeTransposeConv(ctx, patterns); 43 mlir::tosa::populateTosaDecomposeDepthwise(ctx, patterns); 44 45 if (applyPatternsGreedily(func, std::move(patterns)).failed()) 46 signalPassFailure(); 47 } 48 }; 49 50 } // namespace 51 52 std::unique_ptr<Pass> mlir::tosa::createTosaOptionalDecompositions() { 53 return std::make_unique<TosaOptionalDecompositions>(); 54 } 55