1 //===- LoopRangeFolding.cpp - Code to perform loop range folding-----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements loop range folding. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SCF/Transforms/Passes.h" 14 15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 16 #include "mlir/Dialect/SCF/IR/SCF.h" 17 #include "mlir/Dialect/SCF/Transforms/Transforms.h" 18 #include "mlir/Dialect/SCF/Utils/Utils.h" 19 #include "mlir/IR/BlockAndValueMapping.h" 20 21 namespace mlir { 22 #define GEN_PASS_DEF_SCFFORLOOPRANGEFOLDINGPASS 23 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" 24 } // namespace mlir 25 26 using namespace mlir; 27 using namespace mlir::scf; 28 29 namespace { 30 struct SCFForLoopRangeFoldingPass 31 : public impl::SCFForLoopRangeFoldingPassBase<SCFForLoopRangeFoldingPass> { 32 using SCFForLoopRangeFoldingPassBase::SCFForLoopRangeFoldingPassBase; 33 34 void runOnOperation() override; 35 }; 36 } // namespace 37 38 void SCFForLoopRangeFoldingPass::runOnOperation() { 39 getOperation()->walk([&](ForOp op) { 40 Value indVar = op.getInductionVar(); 41 42 auto canBeFolded = [&](Value value) { 43 return op.isDefinedOutsideOfLoop(value) || value == indVar; 44 }; 45 46 // Fold until a fixed point is reached 47 while (true) { 48 49 // If the induction variable is used more than once, we can't fold its 50 // arith ops into the loop range 51 if (!indVar.hasOneUse()) 52 break; 53 54 Operation *user = *indVar.getUsers().begin(); 55 if (!isa<arith::AddIOp, arith::MulIOp>(user)) 56 break; 57 58 if (!llvm::all_of(user->getOperands(), canBeFolded)) 59 break; 60 61 OpBuilder b(op); 62 BlockAndValueMapping lbMap; 63 lbMap.map(indVar, op.getLowerBound()); 64 BlockAndValueMapping ubMap; 65 ubMap.map(indVar, op.getUpperBound()); 66 BlockAndValueMapping stepMap; 67 stepMap.map(indVar, op.getStep()); 68 69 if (isa<arith::AddIOp>(user)) { 70 Operation *lbFold = b.clone(*user, lbMap); 71 Operation *ubFold = b.clone(*user, ubMap); 72 73 op.setLowerBound(lbFold->getResult(0)); 74 op.setUpperBound(ubFold->getResult(0)); 75 76 } else if (isa<arith::MulIOp>(user)) { 77 Operation *ubFold = b.clone(*user, ubMap); 78 Operation *stepFold = b.clone(*user, stepMap); 79 80 op.setUpperBound(ubFold->getResult(0)); 81 op.setStep(stepFold->getResult(0)); 82 } 83 84 ValueRange wrapIndvar(indVar); 85 user->replaceAllUsesWith(wrapIndvar); 86 user->erase(); 87 } 88 }); 89 } 90