1 //===- LoopRangeFolding.cpp - Code to perform loop range folding-----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements loop range folding. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "PassDetail.h" 14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 15 #include "mlir/Dialect/SCF/Passes.h" 16 #include "mlir/Dialect/SCF/SCF.h" 17 #include "mlir/Dialect/SCF/Transforms.h" 18 #include "mlir/Dialect/SCF/Utils.h" 19 #include "mlir/Dialect/StandardOps/IR/Ops.h" 20 #include "mlir/IR/BlockAndValueMapping.h" 21 22 using namespace mlir; 23 using namespace mlir::scf; 24 25 namespace { 26 struct ForLoopRangeFolding 27 : public SCFForLoopRangeFoldingBase<ForLoopRangeFolding> { 28 void runOnOperation() override; 29 }; 30 } // namespace 31 32 void ForLoopRangeFolding::runOnOperation() { 33 getOperation()->walk([&](ForOp op) { 34 Value indVar = op.getInductionVar(); 35 36 auto canBeFolded = [&](Value value) { 37 return op.isDefinedOutsideOfLoop(value) || value == indVar; 38 }; 39 40 // Fold until a fixed point is reached 41 while (true) { 42 43 // If the induction variable is used more than once, we can't fold its 44 // arith ops into the loop range 45 if (!indVar.hasOneUse()) 46 break; 47 48 Operation *user = *indVar.getUsers().begin(); 49 if (!isa<arith::AddIOp, arith::MulIOp>(user)) 50 break; 51 52 if (!llvm::all_of(user->getOperands(), canBeFolded)) 53 break; 54 55 OpBuilder b(op); 56 BlockAndValueMapping lbMap; 57 lbMap.map(indVar, op.getLowerBound()); 58 BlockAndValueMapping ubMap; 59 ubMap.map(indVar, op.getUpperBound()); 60 BlockAndValueMapping stepMap; 61 stepMap.map(indVar, op.getStep()); 62 63 if (isa<arith::AddIOp>(user)) { 64 Operation *lbFold = b.clone(*user, lbMap); 65 Operation *ubFold = b.clone(*user, ubMap); 66 67 op.setLowerBound(lbFold->getResult(0)); 68 op.setUpperBound(ubFold->getResult(0)); 69 70 } else if (isa<arith::MulIOp>(user)) { 71 Operation *ubFold = b.clone(*user, ubMap); 72 Operation *stepFold = b.clone(*user, stepMap); 73 74 op.setUpperBound(ubFold->getResult(0)); 75 op.setStep(stepFold->getResult(0)); 76 } 77 78 ValueRange wrapIndvar(indVar); 79 user->replaceAllUsesWith(wrapIndvar); 80 user->erase(); 81 } 82 }); 83 } 84 85 std::unique_ptr<Pass> mlir::createForLoopRangeFoldingPass() { 86 return std::make_unique<ForLoopRangeFolding>(); 87 } 88