1 //===- LoopRangeFolding.cpp - Code to perform loop range folding-----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements loop range folding. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SCF/Transforms/Passes.h" 14 15 #include "mlir/Dialect/Arith/IR/Arith.h" 16 #include "mlir/Dialect/SCF/IR/SCF.h" 17 #include "mlir/Dialect/SCF/Transforms/Transforms.h" 18 #include "mlir/Dialect/SCF/Utils/Utils.h" 19 #include "mlir/IR/IRMapping.h" 20 21 namespace mlir { 22 #define GEN_PASS_DEF_SCFFORLOOPRANGEFOLDING 23 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" 24 } // namespace mlir 25 26 using namespace mlir; 27 using namespace mlir::scf; 28 29 namespace { 30 struct ForLoopRangeFolding 31 : public impl::SCFForLoopRangeFoldingBase<ForLoopRangeFolding> { 32 void runOnOperation() override; 33 }; 34 } // namespace 35 36 void ForLoopRangeFolding::runOnOperation() { 37 getOperation()->walk([&](ForOp op) { 38 Value indVar = op.getInductionVar(); 39 40 auto canBeFolded = [&](Value value) { 41 return op.isDefinedOutsideOfLoop(value) || value == indVar; 42 }; 43 44 // Fold until a fixed point is reached 45 while (true) { 46 47 // If the induction variable is used more than once, we can't fold its 48 // arith ops into the loop range 49 if (!indVar.hasOneUse()) 50 break; 51 52 Operation *user = *indVar.getUsers().begin(); 53 if (!isa<arith::AddIOp, arith::MulIOp>(user)) 54 break; 55 56 if (!llvm::all_of(user->getOperands(), canBeFolded)) 57 break; 58 59 OpBuilder b(op); 60 IRMapping lbMap; 61 lbMap.map(indVar, op.getLowerBound()); 62 IRMapping ubMap; 63 ubMap.map(indVar, op.getUpperBound()); 64 IRMapping stepMap; 65 stepMap.map(indVar, op.getStep()); 66 67 if (isa<arith::AddIOp>(user)) { 68 Operation *lbFold = b.clone(*user, lbMap); 69 Operation *ubFold = b.clone(*user, ubMap); 70 71 op.setLowerBound(lbFold->getResult(0)); 72 op.setUpperBound(ubFold->getResult(0)); 73 74 } else if (isa<arith::MulIOp>(user)) { 75 Operation *lbFold = b.clone(*user, lbMap); 76 Operation *ubFold = b.clone(*user, ubMap); 77 Operation *stepFold = b.clone(*user, stepMap); 78 79 op.setLowerBound(lbFold->getResult(0)); 80 op.setUpperBound(ubFold->getResult(0)); 81 op.setStep(stepFold->getResult(0)); 82 } 83 84 ValueRange wrapIndvar(indVar); 85 user->replaceAllUsesWith(wrapIndvar); 86 user->erase(); 87 } 88 }); 89 } 90 91 std::unique_ptr<Pass> mlir::createForLoopRangeFoldingPass() { 92 return std::make_unique<ForLoopRangeFolding>(); 93 } 94