xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/LoopRangeFolding.cpp (revision c0342a2de8aa9ffe129bce50e9c90647a772bb2b)
1 //===- LoopRangeFolding.cpp - Code to perform loop range folding-----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop range folding.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
15 #include "mlir/Dialect/SCF/Passes.h"
16 #include "mlir/Dialect/SCF/SCF.h"
17 #include "mlir/Dialect/SCF/Transforms.h"
18 #include "mlir/Dialect/SCF/Utils.h"
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"
20 #include "mlir/IR/BlockAndValueMapping.h"
21 
22 using namespace mlir;
23 using namespace mlir::scf;
24 
25 namespace {
26 struct ForLoopRangeFolding
27     : public SCFForLoopRangeFoldingBase<ForLoopRangeFolding> {
28   void runOnOperation() override;
29 };
30 } // namespace
31 
32 void ForLoopRangeFolding::runOnOperation() {
33   getOperation()->walk([&](ForOp op) {
34     Value indVar = op.getInductionVar();
35 
36     auto canBeFolded = [&](Value value) {
37       return op.isDefinedOutsideOfLoop(value) || value == indVar;
38     };
39 
40     // Fold until a fixed point is reached
41     while (true) {
42 
43       // If the induction variable is used more than once, we can't fold its
44       // arith ops into the loop range
45       if (!indVar.hasOneUse())
46         break;
47 
48       Operation *user = *indVar.getUsers().begin();
49       if (!isa<arith::AddIOp, arith::MulIOp>(user))
50         break;
51 
52       if (!llvm::all_of(user->getOperands(), canBeFolded))
53         break;
54 
55       OpBuilder b(op);
56       BlockAndValueMapping lbMap;
57       lbMap.map(indVar, op.getLowerBound());
58       BlockAndValueMapping ubMap;
59       ubMap.map(indVar, op.getUpperBound());
60       BlockAndValueMapping stepMap;
61       stepMap.map(indVar, op.getStep());
62 
63       if (isa<arith::AddIOp>(user)) {
64         Operation *lbFold = b.clone(*user, lbMap);
65         Operation *ubFold = b.clone(*user, ubMap);
66 
67         op.setLowerBound(lbFold->getResult(0));
68         op.setUpperBound(ubFold->getResult(0));
69 
70       } else if (isa<arith::MulIOp>(user)) {
71         Operation *ubFold = b.clone(*user, ubMap);
72         Operation *stepFold = b.clone(*user, stepMap);
73 
74         op.setUpperBound(ubFold->getResult(0));
75         op.setStep(stepFold->getResult(0));
76       }
77 
78       ValueRange wrapIndvar(indVar);
79       user->replaceAllUsesWith(wrapIndvar);
80       user->erase();
81     }
82   });
83 }
84 
85 std::unique_ptr<Pass> mlir::createForLoopRangeFoldingPass() {
86   return std::make_unique<ForLoopRangeFolding>();
87 }
88