xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/LoopRangeFolding.cpp (revision 039b969b32b64b64123dce30dd28ec4e343d893f)
1 //===- LoopRangeFolding.cpp - Code to perform loop range folding-----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop range folding.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
15 #include "mlir/Dialect/SCF/IR/SCF.h"
16 #include "mlir/Dialect/SCF/Transforms/Passes.h"
17 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
18 #include "mlir/Dialect/SCF/Utils/Utils.h"
19 #include "mlir/IR/BlockAndValueMapping.h"
20 
21 using namespace mlir;
22 using namespace mlir::scf;
23 
24 namespace {
25 struct ForLoopRangeFolding
26     : public SCFForLoopRangeFoldingBase<ForLoopRangeFolding> {
27   void runOnOperation() override;
28 };
29 } // namespace
30 
31 void ForLoopRangeFolding::runOnOperation() {
32   getOperation()->walk([&](ForOp op) {
33     Value indVar = op.getInductionVar();
34 
35     auto canBeFolded = [&](Value value) {
36       return op.isDefinedOutsideOfLoop(value) || value == indVar;
37     };
38 
39     // Fold until a fixed point is reached
40     while (true) {
41 
42       // If the induction variable is used more than once, we can't fold its
43       // arith ops into the loop range
44       if (!indVar.hasOneUse())
45         break;
46 
47       Operation *user = *indVar.getUsers().begin();
48       if (!isa<arith::AddIOp, arith::MulIOp>(user))
49         break;
50 
51       if (!llvm::all_of(user->getOperands(), canBeFolded))
52         break;
53 
54       OpBuilder b(op);
55       BlockAndValueMapping lbMap;
56       lbMap.map(indVar, op.getLowerBound());
57       BlockAndValueMapping ubMap;
58       ubMap.map(indVar, op.getUpperBound());
59       BlockAndValueMapping stepMap;
60       stepMap.map(indVar, op.getStep());
61 
62       if (isa<arith::AddIOp>(user)) {
63         Operation *lbFold = b.clone(*user, lbMap);
64         Operation *ubFold = b.clone(*user, ubMap);
65 
66         op.setLowerBound(lbFold->getResult(0));
67         op.setUpperBound(ubFold->getResult(0));
68 
69       } else if (isa<arith::MulIOp>(user)) {
70         Operation *ubFold = b.clone(*user, ubMap);
71         Operation *stepFold = b.clone(*user, stepMap);
72 
73         op.setUpperBound(ubFold->getResult(0));
74         op.setStep(stepFold->getResult(0));
75       }
76 
77       ValueRange wrapIndvar(indVar);
78       user->replaceAllUsesWith(wrapIndvar);
79       user->erase();
80     }
81   });
82 }
83 
84 std::unique_ptr<Pass> mlir::createForLoopRangeFoldingPass() {
85   return std::make_unique<ForLoopRangeFolding>();
86 }
87