xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/LoopRangeFolding.cpp (revision 2be8af8f0e0780901213b6fd3013a5268ddc3359)
1 //===- LoopRangeFolding.cpp - Code to perform loop range folding-----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop range folding.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/SCF/Transforms/Passes.h"

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Matchers.h"
20 
21 namespace mlir {
22 #define GEN_PASS_DEF_SCFFORLOOPRANGEFOLDINGPASS
23 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
24 } // namespace mlir
25 
26 using namespace mlir;
27 using namespace mlir::scf;
28 
29 namespace {
30 struct SCFForLoopRangeFoldingPass
31     : public impl::SCFForLoopRangeFoldingPassBase<SCFForLoopRangeFoldingPass> {
32   using SCFForLoopRangeFoldingPassBase::SCFForLoopRangeFoldingPassBase;
33 
34   void runOnOperation() override;
35 };
36 } // namespace
37 
38 void SCFForLoopRangeFoldingPass::runOnOperation() {
39   getOperation()->walk([&](ForOp op) {
40     Value indVar = op.getInductionVar();
41 
42     auto canBeFolded = [&](Value value) {
43       return op.isDefinedOutsideOfLoop(value) || value == indVar;
44     };
45 
46     // Fold until a fixed point is reached
47     while (true) {
48 
49       // If the induction variable is used more than once, we can't fold its
50       // arith ops into the loop range
51       if (!indVar.hasOneUse())
52         break;
53 
54       Operation *user = *indVar.getUsers().begin();
55       if (!isa<arith::AddIOp, arith::MulIOp>(user))
56         break;
57 
58       if (!llvm::all_of(user->getOperands(), canBeFolded))
59         break;
60 
61       OpBuilder b(op);
62       BlockAndValueMapping lbMap;
63       lbMap.map(indVar, op.getLowerBound());
64       BlockAndValueMapping ubMap;
65       ubMap.map(indVar, op.getUpperBound());
66       BlockAndValueMapping stepMap;
67       stepMap.map(indVar, op.getStep());
68 
69       if (isa<arith::AddIOp>(user)) {
70         Operation *lbFold = b.clone(*user, lbMap);
71         Operation *ubFold = b.clone(*user, ubMap);
72 
73         op.setLowerBound(lbFold->getResult(0));
74         op.setUpperBound(ubFold->getResult(0));
75 
76       } else if (isa<arith::MulIOp>(user)) {
77         Operation *ubFold = b.clone(*user, ubMap);
78         Operation *stepFold = b.clone(*user, stepMap);
79 
80         op.setUpperBound(ubFold->getResult(0));
81         op.setStep(stepFold->getResult(0));
82       }
83 
84       ValueRange wrapIndvar(indVar);
85       user->replaceAllUsesWith(wrapIndvar);
86       user->erase();
87     }
88   });
89 }
90