xref: /llvm-project/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
13bcaf2ebSGeorgios Pinitas //===- TosaLayerwiseConstantFoldPass.cpp ----------------------------------===//
23bcaf2ebSGeorgios Pinitas //
33bcaf2ebSGeorgios Pinitas // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
43bcaf2ebSGeorgios Pinitas // See https://llvm.org/LICENSE.txt for license information.
53bcaf2ebSGeorgios Pinitas // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
63bcaf2ebSGeorgios Pinitas //
73bcaf2ebSGeorgios Pinitas //===----------------------------------------------------------------------===//
83bcaf2ebSGeorgios Pinitas //
93bcaf2ebSGeorgios Pinitas // This file implements constant folding transformations on TOSA operations
103bcaf2ebSGeorgios Pinitas //
113bcaf2ebSGeorgios Pinitas //===----------------------------------------------------------------------===//
123bcaf2ebSGeorgios Pinitas 
13039b969bSMichele Scuttari #include "mlir/Dialect/Tosa/Transforms/Passes.h"
1467d0d7acSMichele Scuttari 
1567d0d7acSMichele Scuttari #include "mlir/Dialect/Func/IR/FuncOps.h"
1667d0d7acSMichele Scuttari #include "mlir/Dialect/Tosa/IR/TosaOps.h"
173bcaf2ebSGeorgios Pinitas #include "mlir/Pass/Pass.h"
183bcaf2ebSGeorgios Pinitas #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
193bcaf2ebSGeorgios Pinitas 
2067d0d7acSMichele Scuttari namespace mlir {
2167d0d7acSMichele Scuttari namespace tosa {
2267d0d7acSMichele Scuttari #define GEN_PASS_DEF_TOSALAYERWISECONSTANTFOLDPASS
2367d0d7acSMichele Scuttari #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
2467d0d7acSMichele Scuttari } // namespace tosa
2567d0d7acSMichele Scuttari } // namespace mlir
2667d0d7acSMichele Scuttari 
273bcaf2ebSGeorgios Pinitas using namespace mlir;
283bcaf2ebSGeorgios Pinitas using namespace mlir::tosa;
293bcaf2ebSGeorgios Pinitas 
303bcaf2ebSGeorgios Pinitas namespace {
313bcaf2ebSGeorgios Pinitas 
321b7feac2SJacques Pienaar template <typename... Args>
331b7feac2SJacques Pienaar void addOpsCanonicalizations(MLIRContext *ctx, RewritePatternSet &patterns) {
3426d811b3SMarkus Böck   (Args::getCanonicalizationPatterns(patterns, ctx), ...);
351b7feac2SJacques Pienaar }
361b7feac2SJacques Pienaar 
371b7feac2SJacques Pienaar void populateTosaOpsCanonicalizationPatterns(MLIRContext *ctx,
381b7feac2SJacques Pienaar                                              RewritePatternSet &patterns) {
391b7feac2SJacques Pienaar   addOpsCanonicalizations<
401b7feac2SJacques Pienaar #define GET_OP_LIST
411b7feac2SJacques Pienaar #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
421b7feac2SJacques Pienaar       >(ctx, patterns);
431b7feac2SJacques Pienaar }
441b7feac2SJacques Pienaar 
453bcaf2ebSGeorgios Pinitas struct TosaLayerwiseConstantFoldPass
4667d0d7acSMichele Scuttari     : public tosa::impl::TosaLayerwiseConstantFoldPassBase<
4767d0d7acSMichele Scuttari           TosaLayerwiseConstantFoldPass> {
489dd15f74SAmir Bishara   TosaLayerwiseConstantFoldPass(
499dd15f74SAmir Bishara       const TosaLayerwiseConstantFoldPassOptions &options)
509dd15f74SAmir Bishara       : TosaLayerwiseConstantFoldPassBase(options) {}
519dd15f74SAmir Bishara 
523bcaf2ebSGeorgios Pinitas   void runOnOperation() override {
533bcaf2ebSGeorgios Pinitas     auto *ctx = &getContext();
543bcaf2ebSGeorgios Pinitas     RewritePatternSet patterns(ctx);
553bcaf2ebSGeorgios Pinitas     auto func = getOperation();
563bcaf2ebSGeorgios Pinitas 
57d84d418eSTina Jung     mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns);
583bcaf2ebSGeorgios Pinitas     mlir::tosa::populateTosaFoldConstantTransposePatterns(ctx, patterns);
599dd15f74SAmir Bishara     mlir::tosa::populateTosaConstantReduction(ctx, patterns,
609dd15f74SAmir Bishara                                               aggressiveReduceConstant);
611b7feac2SJacques Pienaar     populateTosaOpsCanonicalizationPatterns(ctx, patterns);
623bcaf2ebSGeorgios Pinitas 
63*09dfc571SJacques Pienaar     if (applyPatternsGreedily(func, std::move(patterns)).failed())
643bcaf2ebSGeorgios Pinitas       signalPassFailure();
653bcaf2ebSGeorgios Pinitas   }
663bcaf2ebSGeorgios Pinitas };
673bcaf2ebSGeorgios Pinitas 
683bcaf2ebSGeorgios Pinitas } // namespace
69039b969bSMichele Scuttari 
70039b969bSMichele Scuttari std::unique_ptr<Pass> mlir::tosa::createTosaLayerwiseConstantFoldPass() {
719dd15f74SAmir Bishara   return std::make_unique<TosaLayerwiseConstantFoldPass>(
729dd15f74SAmir Bishara       TosaLayerwiseConstantFoldPassOptions{false});
739dd15f74SAmir Bishara }
749dd15f74SAmir Bishara 
759dd15f74SAmir Bishara std::unique_ptr<Pass> mlir::tosa::createTosaLayerwiseConstantFoldPass(
769dd15f74SAmir Bishara     const TosaLayerwiseConstantFoldPassOptions &options) {
779dd15f74SAmir Bishara   return std::make_unique<TosaLayerwiseConstantFoldPass>(options);
78039b969bSMichele Scuttari }
79