1dac75ae5SNicolas Vasilache //===- TestConstantFold.cpp - Pass to test constant folding ---------------===// 2dac75ae5SNicolas Vasilache // 330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information. 556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6dac75ae5SNicolas Vasilache // 756222a06SMehdi Amini //===----------------------------------------------------------------------===// 8dac75ae5SNicolas Vasilache 9dac75ae5SNicolas Vasilache #include "mlir/Pass/Pass.h" 10dac75ae5SNicolas Vasilache #include "mlir/Transforms/FoldUtils.h" 11dac75ae5SNicolas Vasilache 12dac75ae5SNicolas Vasilache using namespace mlir; 13dac75ae5SNicolas Vasilache 14dac75ae5SNicolas Vasilache namespace { 15dac75ae5SNicolas Vasilache /// Simple constant folding pass. 16*9bdfa8dfSMatthias Springer struct TestConstantFold : public PassWrapper<TestConstantFold, OperationPass<>>, 17*9bdfa8dfSMatthias Springer public RewriterBase::Listener { 185e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConstantFold) 195e50dd04SRiver Riddle 20b5e22e6dSMehdi Amini StringRef getArgument() const final { return "test-constant-fold"; } 21b5e22e6dSMehdi Amini StringRef getDescription() const final { 22b5e22e6dSMehdi Amini return "Test operation constant folding"; 23b5e22e6dSMehdi Amini } 2487d6bf37SRiver Riddle // All constants in the operation post folding. 2587d6bf37SRiver Riddle SmallVector<Operation *> existingConstants; 26dac75ae5SNicolas Vasilache 27dac75ae5SNicolas Vasilache void foldOperation(Operation *op, OperationFolder &helper); 2841574554SRiver Riddle void runOnOperation() override; 29*9bdfa8dfSMatthias Springer 30*9bdfa8dfSMatthias Springer void notifyOperationInserted(Operation *op) override { 31*9bdfa8dfSMatthias Springer existingConstants.push_back(op); 32*9bdfa8dfSMatthias Springer } 33*9bdfa8dfSMatthias Springer void notifyOperationRemoved(Operation *op) override { 34*9bdfa8dfSMatthias Springer auto it = llvm::find(existingConstants, op); 35*9bdfa8dfSMatthias Springer if (it != existingConstants.end()) 36*9bdfa8dfSMatthias Springer existingConstants.erase(it); 37*9bdfa8dfSMatthias Springer } 38dac75ae5SNicolas Vasilache }; 39be0a7e9fSMehdi Amini } // namespace 40dac75ae5SNicolas Vasilache 41dac75ae5SNicolas Vasilache void TestConstantFold::foldOperation(Operation *op, OperationFolder &helper) { 42dac75ae5SNicolas Vasilache // Attempt to fold the specified operation, including handling unused or 43dac75ae5SNicolas Vasilache // duplicated constants. 44*9bdfa8dfSMatthias Springer (void)helper.tryToFold(op); 45dac75ae5SNicolas Vasilache } 46dac75ae5SNicolas Vasilache 4741574554SRiver Riddle void TestConstantFold::runOnOperation() { 48dac75ae5SNicolas Vasilache existingConstants.clear(); 49dac75ae5SNicolas Vasilache 5087d6bf37SRiver Riddle // Collect and fold the operations within the operation. 51dac75ae5SNicolas Vasilache SmallVector<Operation *, 8> ops; 5287d6bf37SRiver Riddle getOperation()->walk([&](Operation *op) { ops.push_back(op); }); 53dac75ae5SNicolas Vasilache 54dac75ae5SNicolas Vasilache // Fold the constants in reverse so that the last generated constants from 55dac75ae5SNicolas Vasilache // folding are at the beginning. This creates somewhat of a linear ordering to 56dac75ae5SNicolas Vasilache // the newly generated constants that matches the operation order and improves 57dac75ae5SNicolas Vasilache // the readability of test cases. 58*9bdfa8dfSMatthias Springer OperationFolder helper(&getContext(), /*listener=*/this); 59dac75ae5SNicolas Vasilache for (Operation *op : llvm::reverse(ops)) 60dac75ae5SNicolas Vasilache foldOperation(op, helper); 61dac75ae5SNicolas Vasilache 62dac75ae5SNicolas Vasilache // By the time we are done, we may have simplified a bunch of code, leaving 63dac75ae5SNicolas Vasilache // around dead constants. Check for them now and remove them. 64dac75ae5SNicolas Vasilache for (auto *cst : existingConstants) { 65dac75ae5SNicolas Vasilache if (cst->use_empty()) 66dac75ae5SNicolas Vasilache cst->erase(); 67dac75ae5SNicolas Vasilache } 68dac75ae5SNicolas Vasilache } 69dac75ae5SNicolas Vasilache 70c6477050SMehdi Amini namespace mlir { 7172c65b69SAlexander Belyaev namespace test { 72b5e22e6dSMehdi Amini void registerTestConstantFold() { PassRegistration<TestConstantFold>(); } 7372c65b69SAlexander Belyaev } // namespace test 74c6477050SMehdi Amini } // namespace mlir 75