xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/LoopRangeFolding.cpp (revision 36a405519bf54c7b9bc1247286c59beca0d8eff8)
13f429e82SAnthony Canino //===- LoopRangeFolding.cpp - Code to perform loop range folding-----------===//
23f429e82SAnthony Canino //
33f429e82SAnthony Canino // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
43f429e82SAnthony Canino // See https://llvm.org/LICENSE.txt for license information.
53f429e82SAnthony Canino // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
63f429e82SAnthony Canino //
73f429e82SAnthony Canino //===----------------------------------------------------------------------===//
83f429e82SAnthony Canino //
93f429e82SAnthony Canino // This file implements loop range folding.
103f429e82SAnthony Canino //
113f429e82SAnthony Canino //===----------------------------------------------------------------------===//
123f429e82SAnthony Canino 
#include "mlir/Dialect/SCF/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
203f429e82SAnthony Canino 
2167d0d7acSMichele Scuttari namespace mlir {
2267d0d7acSMichele Scuttari #define GEN_PASS_DEF_SCFFORLOOPRANGEFOLDING
2367d0d7acSMichele Scuttari #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
2467d0d7acSMichele Scuttari } // namespace mlir
2567d0d7acSMichele Scuttari 
263f429e82SAnthony Canino using namespace mlir;
273f429e82SAnthony Canino using namespace mlir::scf;
283f429e82SAnthony Canino 
293f429e82SAnthony Canino namespace {
30039b969bSMichele Scuttari struct ForLoopRangeFolding
3167d0d7acSMichele Scuttari     : public impl::SCFForLoopRangeFoldingBase<ForLoopRangeFolding> {
323f429e82SAnthony Canino   void runOnOperation() override;
333f429e82SAnthony Canino };
343f429e82SAnthony Canino } // namespace
353f429e82SAnthony Canino 
36039b969bSMichele Scuttari void ForLoopRangeFolding::runOnOperation() {
373f429e82SAnthony Canino   getOperation()->walk([&](ForOp op) {
383f429e82SAnthony Canino     Value indVar = op.getInductionVar();
393f429e82SAnthony Canino 
403f429e82SAnthony Canino     auto canBeFolded = [&](Value value) {
413f429e82SAnthony Canino       return op.isDefinedOutsideOfLoop(value) || value == indVar;
423f429e82SAnthony Canino     };
433f429e82SAnthony Canino 
443f429e82SAnthony Canino     // Fold until a fixed point is reached
453f429e82SAnthony Canino     while (true) {
463f429e82SAnthony Canino 
473f429e82SAnthony Canino       // If the induction variable is used more than once, we can't fold its
483f429e82SAnthony Canino       // arith ops into the loop range
493f429e82SAnthony Canino       if (!indVar.hasOneUse())
503f429e82SAnthony Canino         break;
513f429e82SAnthony Canino 
523f429e82SAnthony Canino       Operation *user = *indVar.getUsers().begin();
53a54f4eaeSMogball       if (!isa<arith::AddIOp, arith::MulIOp>(user))
543f429e82SAnthony Canino         break;
553f429e82SAnthony Canino 
563f429e82SAnthony Canino       if (!llvm::all_of(user->getOperands(), canBeFolded))
573f429e82SAnthony Canino         break;
583f429e82SAnthony Canino 
593f429e82SAnthony Canino       OpBuilder b(op);
604d67b278SJeff Niu       IRMapping lbMap;
61c0342a2dSJacques Pienaar       lbMap.map(indVar, op.getLowerBound());
624d67b278SJeff Niu       IRMapping ubMap;
63c0342a2dSJacques Pienaar       ubMap.map(indVar, op.getUpperBound());
644d67b278SJeff Niu       IRMapping stepMap;
65c0342a2dSJacques Pienaar       stepMap.map(indVar, op.getStep());
663f429e82SAnthony Canino 
67a54f4eaeSMogball       if (isa<arith::AddIOp>(user)) {
683f429e82SAnthony Canino         Operation *lbFold = b.clone(*user, lbMap);
693f429e82SAnthony Canino         Operation *ubFold = b.clone(*user, ubMap);
703f429e82SAnthony Canino 
713f429e82SAnthony Canino         op.setLowerBound(lbFold->getResult(0));
723f429e82SAnthony Canino         op.setUpperBound(ubFold->getResult(0));
733f429e82SAnthony Canino 
74a54f4eaeSMogball       } else if (isa<arith::MulIOp>(user)) {
75*36a40551SSasha Lopoukhine         Operation *lbFold = b.clone(*user, lbMap);
763f429e82SAnthony Canino         Operation *ubFold = b.clone(*user, ubMap);
773f429e82SAnthony Canino         Operation *stepFold = b.clone(*user, stepMap);
783f429e82SAnthony Canino 
79*36a40551SSasha Lopoukhine         op.setLowerBound(lbFold->getResult(0));
803f429e82SAnthony Canino         op.setUpperBound(ubFold->getResult(0));
813f429e82SAnthony Canino         op.setStep(stepFold->getResult(0));
823f429e82SAnthony Canino       }
833f429e82SAnthony Canino 
843f429e82SAnthony Canino       ValueRange wrapIndvar(indVar);
853f429e82SAnthony Canino       user->replaceAllUsesWith(wrapIndvar);
863f429e82SAnthony Canino       user->erase();
873f429e82SAnthony Canino     }
883f429e82SAnthony Canino   });
893f429e82SAnthony Canino }
90039b969bSMichele Scuttari 
91039b969bSMichele Scuttari std::unique_ptr<Pass> mlir::createForLoopRangeFoldingPass() {
92039b969bSMichele Scuttari   return std::make_unique<ForLoopRangeFolding>();
93039b969bSMichele Scuttari }
94