13f429e82SAnthony Canino //===- LoopRangeFolding.cpp - Code to perform loop range folding-----------===// 23f429e82SAnthony Canino // 33f429e82SAnthony Canino // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 43f429e82SAnthony Canino // See https://llvm.org/LICENSE.txt for license information. 53f429e82SAnthony Canino // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 63f429e82SAnthony Canino // 73f429e82SAnthony Canino //===----------------------------------------------------------------------===// 83f429e82SAnthony Canino // 93f429e82SAnthony Canino // This file implements loop range folding. 103f429e82SAnthony Canino // 113f429e82SAnthony Canino //===----------------------------------------------------------------------===// 123f429e82SAnthony Canino 1367d0d7acSMichele Scuttari #include "mlir/Dialect/SCF/Transforms/Passes.h" 1467d0d7acSMichele Scuttari 15abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 168b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h" 178b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Transforms.h" 18f40475c7SAdrian Kuegel #include "mlir/Dialect/SCF/Utils/Utils.h" 194d67b278SJeff Niu #include "mlir/IR/IRMapping.h" 203f429e82SAnthony Canino 2167d0d7acSMichele Scuttari namespace mlir { 2267d0d7acSMichele Scuttari #define GEN_PASS_DEF_SCFFORLOOPRANGEFOLDING 2367d0d7acSMichele Scuttari #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" 2467d0d7acSMichele Scuttari } // namespace mlir 2567d0d7acSMichele Scuttari 263f429e82SAnthony Canino using namespace mlir; 273f429e82SAnthony Canino using namespace mlir::scf; 283f429e82SAnthony Canino 293f429e82SAnthony Canino namespace { 30039b969bSMichele Scuttari struct ForLoopRangeFolding 3167d0d7acSMichele Scuttari : public impl::SCFForLoopRangeFoldingBase<ForLoopRangeFolding> { 323f429e82SAnthony Canino void runOnOperation() override; 333f429e82SAnthony Canino }; 343f429e82SAnthony Canino } // namespace 353f429e82SAnthony Canino 36039b969bSMichele Scuttari void ForLoopRangeFolding::runOnOperation() { 373f429e82SAnthony Canino getOperation()->walk([&](ForOp op) { 383f429e82SAnthony Canino Value indVar = op.getInductionVar(); 393f429e82SAnthony Canino 403f429e82SAnthony Canino auto canBeFolded = [&](Value value) { 413f429e82SAnthony Canino return op.isDefinedOutsideOfLoop(value) || value == indVar; 423f429e82SAnthony Canino }; 433f429e82SAnthony Canino 443f429e82SAnthony Canino // Fold until a fixed point is reached 453f429e82SAnthony Canino while (true) { 463f429e82SAnthony Canino 473f429e82SAnthony Canino // If the induction variable is used more than once, we can't fold its 483f429e82SAnthony Canino // arith ops into the loop range 493f429e82SAnthony Canino if (!indVar.hasOneUse()) 503f429e82SAnthony Canino break; 513f429e82SAnthony Canino 523f429e82SAnthony Canino Operation *user = *indVar.getUsers().begin(); 53a54f4eaeSMogball if (!isa<arith::AddIOp, arith::MulIOp>(user)) 543f429e82SAnthony Canino break; 553f429e82SAnthony Canino 563f429e82SAnthony Canino if (!llvm::all_of(user->getOperands(), canBeFolded)) 573f429e82SAnthony Canino break; 583f429e82SAnthony Canino 593f429e82SAnthony Canino OpBuilder b(op); 604d67b278SJeff Niu IRMapping lbMap; 61c0342a2dSJacques Pienaar lbMap.map(indVar, op.getLowerBound()); 624d67b278SJeff Niu IRMapping ubMap; 63c0342a2dSJacques Pienaar ubMap.map(indVar, op.getUpperBound()); 644d67b278SJeff Niu IRMapping stepMap; 65c0342a2dSJacques Pienaar stepMap.map(indVar, op.getStep()); 663f429e82SAnthony Canino 67a54f4eaeSMogball if (isa<arith::AddIOp>(user)) { 683f429e82SAnthony Canino Operation *lbFold = b.clone(*user, lbMap); 693f429e82SAnthony Canino Operation *ubFold = b.clone(*user, ubMap); 703f429e82SAnthony Canino 713f429e82SAnthony Canino op.setLowerBound(lbFold->getResult(0)); 723f429e82SAnthony Canino op.setUpperBound(ubFold->getResult(0)); 733f429e82SAnthony Canino 74a54f4eaeSMogball } else if (isa<arith::MulIOp>(user)) { 75*36a40551SSasha Lopoukhine Operation *lbFold = b.clone(*user, lbMap); 763f429e82SAnthony Canino Operation *ubFold = b.clone(*user, ubMap); 773f429e82SAnthony Canino Operation *stepFold = b.clone(*user, stepMap); 783f429e82SAnthony Canino 79*36a40551SSasha Lopoukhine op.setLowerBound(lbFold->getResult(0)); 803f429e82SAnthony Canino op.setUpperBound(ubFold->getResult(0)); 813f429e82SAnthony Canino op.setStep(stepFold->getResult(0)); 823f429e82SAnthony Canino } 833f429e82SAnthony Canino 843f429e82SAnthony Canino ValueRange wrapIndvar(indVar); 853f429e82SAnthony Canino user->replaceAllUsesWith(wrapIndvar); 863f429e82SAnthony Canino user->erase(); 873f429e82SAnthony Canino } 883f429e82SAnthony Canino }); 893f429e82SAnthony Canino } 90039b969bSMichele Scuttari 91039b969bSMichele Scuttari std::unique_ptr<Pass> mlir::createForLoopRangeFoldingPass() { 92039b969bSMichele Scuttari return std::make_unique<ForLoopRangeFolding>(); 93039b969bSMichele Scuttari } 94