xref: /llvm-project/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //===- TosaLayerwiseConstantFoldPass.cpp ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements constant folding transformations on TOSA operations
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
14 
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
17 #include "mlir/Pass/Pass.h"
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19 
20 namespace mlir {
21 namespace tosa {
22 #define GEN_PASS_DEF_TOSALAYERWISECONSTANTFOLDPASS
23 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
24 } // namespace tosa
25 } // namespace mlir
26 
27 using namespace mlir;
28 using namespace mlir::tosa;
29 
30 namespace {
31 
32 template <typename... Args>
33 void addOpsCanonicalizations(MLIRContext *ctx, RewritePatternSet &patterns) {
34   (Args::getCanonicalizationPatterns(patterns, ctx), ...);
35 }
36 
/// Adds canonicalization patterns for the ops listed by the TOSA dialect's
/// generated op list to `patterns`.
///
/// The template argument list of `addOpsCanonicalizations` is not spelled out
/// by hand: defining GET_OP_LIST and including TosaOps.cpp.inc expands the
/// generated op list in place (presumably a comma-separated list of all TOSA
/// op classes -- confirm against the generated file).
void populateTosaOpsCanonicalizationPatterns(MLIRContext *ctx,
                                             RewritePatternSet &patterns) {
  addOpsCanonicalizations<
// The preprocessor lines must stay inside the angle brackets; they supply the
// template arguments.
#define GET_OP_LIST
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
      >(ctx, patterns);
}
44 
45 struct TosaLayerwiseConstantFoldPass
46     : public tosa::impl::TosaLayerwiseConstantFoldPassBase<
47           TosaLayerwiseConstantFoldPass> {
48   TosaLayerwiseConstantFoldPass(
49       const TosaLayerwiseConstantFoldPassOptions &options)
50       : TosaLayerwiseConstantFoldPassBase(options) {}
51 
52   void runOnOperation() override {
53     auto *ctx = &getContext();
54     RewritePatternSet patterns(ctx);
55     auto func = getOperation();
56 
57     mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns);
58     mlir::tosa::populateTosaFoldConstantTransposePatterns(ctx, patterns);
59     mlir::tosa::populateTosaConstantReduction(ctx, patterns,
60                                               aggressiveReduceConstant);
61     populateTosaOpsCanonicalizationPatterns(ctx, patterns);
62 
63     if (applyPatternsGreedily(func, std::move(patterns)).failed())
64       signalPassFailure();
65   }
66 };
67 
68 } // namespace
69 
70 std::unique_ptr<Pass> mlir::tosa::createTosaLayerwiseConstantFoldPass() {
71   return std::make_unique<TosaLayerwiseConstantFoldPass>(
72       TosaLayerwiseConstantFoldPassOptions{false});
73 }
74 
75 std::unique_ptr<Pass> mlir::tosa::createTosaLayerwiseConstantFoldPass(
76     const TosaLayerwiseConstantFoldPassOptions &options) {
77   return std::make_unique<TosaLayerwiseConstantFoldPass>(options);
78 }
79