1 //===- TosaLayerwiseConstantFoldPass.cpp ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements constant folding transformations on TOSA operations 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Tosa/Transforms/Passes.h" 14 15 #include "mlir/Dialect/Func/IR/FuncOps.h" 16 #include "mlir/Dialect/Tosa/IR/TosaOps.h" 17 #include "mlir/Pass/Pass.h" 18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 19 20 namespace mlir { 21 namespace tosa { 22 #define GEN_PASS_DEF_TOSALAYERWISECONSTANTFOLDPASS 23 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" 24 } // namespace tosa 25 } // namespace mlir 26 27 using namespace mlir; 28 using namespace mlir::tosa; 29 30 namespace { 31 32 template <typename... Args> 33 void addOpsCanonicalizations(MLIRContext *ctx, RewritePatternSet &patterns) { 34 (Args::getCanonicalizationPatterns(patterns, ctx), ...); 35 } 36 37 void populateTosaOpsCanonicalizationPatterns(MLIRContext *ctx, 38 RewritePatternSet &patterns) { 39 addOpsCanonicalizations< 40 #define GET_OP_LIST 41 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc" 42 >(ctx, patterns); 43 } 44 45 struct TosaLayerwiseConstantFoldPass 46 : public tosa::impl::TosaLayerwiseConstantFoldPassBase< 47 TosaLayerwiseConstantFoldPass> { 48 TosaLayerwiseConstantFoldPass( 49 const TosaLayerwiseConstantFoldPassOptions &options) 50 : TosaLayerwiseConstantFoldPassBase(options) {} 51 52 void runOnOperation() override { 53 auto *ctx = &getContext(); 54 RewritePatternSet patterns(ctx); 55 auto func = getOperation(); 56 57 mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns); 58 mlir::tosa::populateTosaFoldConstantTransposePatterns(ctx, patterns); 59 mlir::tosa::populateTosaConstantReduction(ctx, patterns, 60 aggressiveReduceConstant); 61 populateTosaOpsCanonicalizationPatterns(ctx, patterns); 62 63 if (applyPatternsGreedily(func, std::move(patterns)).failed()) 64 signalPassFailure(); 65 } 66 }; 67 68 } // namespace 69 70 std::unique_ptr<Pass> mlir::tosa::createTosaLayerwiseConstantFoldPass() { 71 return std::make_unique<TosaLayerwiseConstantFoldPass>( 72 TosaLayerwiseConstantFoldPassOptions{false}); 73 } 74 75 std::unique_ptr<Pass> mlir::tosa::createTosaLayerwiseConstantFoldPass( 76 const TosaLayerwiseConstantFoldPassOptions &options) { 77 return std::make_unique<TosaLayerwiseConstantFoldPass>(options); 78 } 79