1 //===- LoopRangeFolding.cpp - Code to perform loop range folding-----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements loop range folding. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "PassDetail.h" 14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 15 #include "mlir/Dialect/SCF/IR/SCF.h" 16 #include "mlir/Dialect/SCF/Transforms/Passes.h" 17 #include "mlir/Dialect/SCF/Transforms/Transforms.h" 18 #include "mlir/Dialect/SCF/Utils/Utils.h" 19 #include "mlir/IR/BlockAndValueMapping.h" 20 21 using namespace mlir; 22 using namespace mlir::scf; 23 24 namespace { 25 struct ForLoopRangeFolding 26 : public SCFForLoopRangeFoldingBase<ForLoopRangeFolding> { 27 void runOnOperation() override; 28 }; 29 } // namespace 30 31 void ForLoopRangeFolding::runOnOperation() { 32 getOperation()->walk([&](ForOp op) { 33 Value indVar = op.getInductionVar(); 34 35 auto canBeFolded = [&](Value value) { 36 return op.isDefinedOutsideOfLoop(value) || value == indVar; 37 }; 38 39 // Fold until a fixed point is reached 40 while (true) { 41 42 // If the induction variable is used more than once, we can't fold its 43 // arith ops into the loop range 44 if (!indVar.hasOneUse()) 45 break; 46 47 Operation *user = *indVar.getUsers().begin(); 48 if (!isa<arith::AddIOp, arith::MulIOp>(user)) 49 break; 50 51 if (!llvm::all_of(user->getOperands(), canBeFolded)) 52 break; 53 54 OpBuilder b(op); 55 BlockAndValueMapping lbMap; 56 lbMap.map(indVar, op.getLowerBound()); 57 BlockAndValueMapping ubMap; 58 ubMap.map(indVar, op.getUpperBound()); 59 BlockAndValueMapping stepMap; 60 stepMap.map(indVar, op.getStep()); 61 62 if (isa<arith::AddIOp>(user)) { 63 Operation *lbFold = b.clone(*user, lbMap); 64 Operation *ubFold = b.clone(*user, ubMap); 65 66 op.setLowerBound(lbFold->getResult(0)); 67 op.setUpperBound(ubFold->getResult(0)); 68 69 } else if (isa<arith::MulIOp>(user)) { 70 Operation *ubFold = b.clone(*user, ubMap); 71 Operation *stepFold = b.clone(*user, stepMap); 72 73 op.setUpperBound(ubFold->getResult(0)); 74 op.setStep(stepFold->getResult(0)); 75 } 76 77 ValueRange wrapIndvar(indVar); 78 user->replaceAllUsesWith(wrapIndvar); 79 user->erase(); 80 } 81 }); 82 } 83 84 std::unique_ptr<Pass> mlir::createForLoopRangeFoldingPass() { 85 return std::make_unique<ForLoopRangeFolding>(); 86 } 87