xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/LoopRangeFolding.cpp (revision 36a405519bf54c7b9bc1247286c59beca0d8eff8)
1 //===- LoopRangeFolding.cpp - Code to perform loop range folding-----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop range folding.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/SCF/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
20 
21 namespace mlir {
22 #define GEN_PASS_DEF_SCFFORLOOPRANGEFOLDING
23 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
24 } // namespace mlir
25 
26 using namespace mlir;
27 using namespace mlir::scf;
28 
29 namespace {
30 struct ForLoopRangeFolding
31     : public impl::SCFForLoopRangeFoldingBase<ForLoopRangeFolding> {
32   void runOnOperation() override;
33 };
34 } // namespace
35 
36 void ForLoopRangeFolding::runOnOperation() {
37   getOperation()->walk([&](ForOp op) {
38     Value indVar = op.getInductionVar();
39 
40     auto canBeFolded = [&](Value value) {
41       return op.isDefinedOutsideOfLoop(value) || value == indVar;
42     };
43 
44     // Fold until a fixed point is reached
45     while (true) {
46 
47       // If the induction variable is used more than once, we can't fold its
48       // arith ops into the loop range
49       if (!indVar.hasOneUse())
50         break;
51 
52       Operation *user = *indVar.getUsers().begin();
53       if (!isa<arith::AddIOp, arith::MulIOp>(user))
54         break;
55 
56       if (!llvm::all_of(user->getOperands(), canBeFolded))
57         break;
58 
59       OpBuilder b(op);
60       IRMapping lbMap;
61       lbMap.map(indVar, op.getLowerBound());
62       IRMapping ubMap;
63       ubMap.map(indVar, op.getUpperBound());
64       IRMapping stepMap;
65       stepMap.map(indVar, op.getStep());
66 
67       if (isa<arith::AddIOp>(user)) {
68         Operation *lbFold = b.clone(*user, lbMap);
69         Operation *ubFold = b.clone(*user, ubMap);
70 
71         op.setLowerBound(lbFold->getResult(0));
72         op.setUpperBound(ubFold->getResult(0));
73 
74       } else if (isa<arith::MulIOp>(user)) {
75         Operation *lbFold = b.clone(*user, lbMap);
76         Operation *ubFold = b.clone(*user, ubMap);
77         Operation *stepFold = b.clone(*user, stepMap);
78 
79         op.setLowerBound(lbFold->getResult(0));
80         op.setUpperBound(ubFold->getResult(0));
81         op.setStep(stepFold->getResult(0));
82       }
83 
84       ValueRange wrapIndvar(indVar);
85       user->replaceAllUsesWith(wrapIndvar);
86       user->erase();
87     }
88   });
89 }
90 
91 std::unique_ptr<Pass> mlir::createForLoopRangeFoldingPass() {
92   return std::make_unique<ForLoopRangeFolding>();
93 }
94