xref: /llvm-project/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp (revision b5c5c2b26fd4bd0d0d237aaf77a01ca528810707)
//===- IntegerRangeAnalysis.cpp - Integer range analysis --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the dataflow analysis class for integer range inference,
// which is used in transformations over the `arith` dialect such as branch
// elimination or signed->unsigned rewriting.
//
//===----------------------------------------------------------------------===//
14ab701975SMogball 
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include <cassert>
#include <optional>
#include <utility>
34ab701975SMogball 
35ab701975SMogball #define DEBUG_TYPE "int-range-analysis"
36ab701975SMogball 
37ab701975SMogball using namespace mlir;
38ab701975SMogball using namespace mlir::dataflow;
39ab701975SMogball 
40ab701975SMogball void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
41ab701975SMogball   Lattice::onUpdate(solver);
42ab701975SMogball 
43ab701975SMogball   // If the integer range can be narrowed to a constant, update the constant
44ab701975SMogball   // value of the SSA value.
450a81ace0SKazu Hirata   std::optional<APInt> constant = getValue().getValue().getConstantValue();
46*b5c5c2b2SKazu Hirata   auto value = cast<Value>(anchor);
47ab701975SMogball   auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(value);
48ab701975SMogball   if (!constant)
49de0ebc52SZhixun Tan     return solver->propagateIfChanged(
50de0ebc52SZhixun Tan         cv, cv->join(ConstantValue::getUnknownConstant()));
51ab701975SMogball 
52ab701975SMogball   Dialect *dialect;
53ab701975SMogball   if (auto *parent = value.getDefiningOp())
54ab701975SMogball     dialect = parent->getDialect();
55ab701975SMogball   else
56ab701975SMogball     dialect = value.getParentBlock()->getParentOp()->getDialect();
57f54cdc5dSIvan Butygin 
58f54cdc5dSIvan Butygin   Type type = getElementTypeOrSelf(value);
59ab701975SMogball   solver->propagateIfChanged(
60f54cdc5dSIvan Butygin       cv, cv->join(ConstantValue(IntegerAttr::get(type, *constant), dialect)));
61ab701975SMogball }
62ab701975SMogball 
6315e915a4SIvan Butygin LogicalResult IntegerRangeAnalysis::visitOperation(
64ab701975SMogball     Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands,
65ab701975SMogball     ArrayRef<IntegerValueRangeLattice *> results) {
66ab701975SMogball   auto inferrable = dyn_cast<InferIntRangeInterface>(op);
6715e915a4SIvan Butygin   if (!inferrable) {
6815e915a4SIvan Butygin     setAllToEntryStates(results);
6915e915a4SIvan Butygin     return success();
7015e915a4SIvan Butygin   }
71ab701975SMogball 
72ab701975SMogball   LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
736aeea700SSpenser Bauman   auto argRanges = llvm::map_to_vector(
746aeea700SSpenser Bauman       operands, [](const IntegerValueRangeLattice *lattice) {
756aeea700SSpenser Bauman         return lattice->getValue();
766aeea700SSpenser Bauman       });
77ab701975SMogball 
786aeea700SSpenser Bauman   auto joinCallback = [&](Value v, const IntegerValueRange &attrs) {
795550c821STres Popp     auto result = dyn_cast<OpResult>(v);
80ab701975SMogball     if (!result)
81ab701975SMogball       return;
82360c1111SKazu Hirata     assert(llvm::is_contained(op->getResults(), result));
83ab701975SMogball 
84ab701975SMogball     LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
85ab701975SMogball     IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
8647bf3e38SZhixun Tan     IntegerValueRange oldRange = lattice->getValue();
87ab701975SMogball 
886aeea700SSpenser Bauman     ChangeResult changed = lattice->join(attrs);
89ab701975SMogball 
90ab701975SMogball     // Catch loop results with loop variant bounds and conservatively make
91ab701975SMogball     // them [-inf, inf] so we don't circle around infinitely often (because
92ab701975SMogball     // the dataflow analysis in MLIR doesn't attempt to work out trip counts
93ab701975SMogball     // and often can't).
94ab701975SMogball     bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
95ab701975SMogball       return op->hasTrait<OpTrait::IsTerminator>();
96ab701975SMogball     });
9747bf3e38SZhixun Tan     if (isYieldedResult && !oldRange.isUninitialized() &&
9847bf3e38SZhixun Tan         !(lattice->getValue() == oldRange)) {
99ab701975SMogball       LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
100de0ebc52SZhixun Tan       changed |= lattice->join(IntegerValueRange::getMaxRange(v));
101ab701975SMogball     }
102ab701975SMogball     propagateIfChanged(lattice, changed);
103ab701975SMogball   };
104ab701975SMogball 
1056aeea700SSpenser Bauman   inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
10615e915a4SIvan Butygin   return success();
107ab701975SMogball }
108ab701975SMogball 
109ab701975SMogball void IntegerRangeAnalysis::visitNonControlFlowArguments(
110ab701975SMogball     Operation *op, const RegionSuccessor &successor,
111ab701975SMogball     ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) {
112ab701975SMogball   if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
113ab701975SMogball     LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
114ab701975SMogball 
1156aeea700SSpenser Bauman     auto argRanges = llvm::map_to_vector(op->getOperands(), [&](Value value) {
1164b3f251bSdonald chen       return getLatticeElementFor(getProgramPointAfter(op), value)->getValue();
1176aeea700SSpenser Bauman     });
1186aeea700SSpenser Bauman 
1196aeea700SSpenser Bauman     auto joinCallback = [&](Value v, const IntegerValueRange &attrs) {
1205550c821STres Popp       auto arg = dyn_cast<BlockArgument>(v);
121ab701975SMogball       if (!arg)
122ab701975SMogball         return;
123360c1111SKazu Hirata       if (!llvm::is_contained(successor.getSuccessor()->getArguments(), arg))
124ab701975SMogball         return;
125ab701975SMogball 
126ab701975SMogball       LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
127ab701975SMogball       IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()];
12847bf3e38SZhixun Tan       IntegerValueRange oldRange = lattice->getValue();
129ab701975SMogball 
1306aeea700SSpenser Bauman       ChangeResult changed = lattice->join(attrs);
131ab701975SMogball 
132ab701975SMogball       // Catch loop results with loop variant bounds and conservatively make
133ab701975SMogball       // them [-inf, inf] so we don't circle around infinitely often (because
134ab701975SMogball       // the dataflow analysis in MLIR doesn't attempt to work out trip counts
135ab701975SMogball       // and often can't).
136ab701975SMogball       bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
137ab701975SMogball         return op->hasTrait<OpTrait::IsTerminator>();
138ab701975SMogball       });
13947bf3e38SZhixun Tan       if (isYieldedValue && !oldRange.isUninitialized() &&
14047bf3e38SZhixun Tan           !(lattice->getValue() == oldRange)) {
141ab701975SMogball         LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
142de0ebc52SZhixun Tan         changed |= lattice->join(IntegerValueRange::getMaxRange(v));
143ab701975SMogball       }
144ab701975SMogball       propagateIfChanged(lattice, changed);
145ab701975SMogball     };
146ab701975SMogball 
1476aeea700SSpenser Bauman     inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
148ab701975SMogball     return;
149ab701975SMogball   }
150ab701975SMogball 
151ab701975SMogball   /// Given the results of getConstant{Lower,Upper}Bound() or getConstantStep()
152ab701975SMogball   /// on a LoopLikeInterface return the lower/upper bound for that result if
153ab701975SMogball   /// possible.
15422426110SRamkumar Ramachandra   auto getLoopBoundFromFold = [&](std::optional<OpFoldResult> loopBound,
155ab701975SMogball                                   Type boundType, bool getUpper) {
156ab701975SMogball     unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType);
157491d2701SKazu Hirata     if (loopBound.has_value()) {
158*b5c5c2b2SKazu Hirata       if (auto attr = dyn_cast<Attribute>(*loopBound)) {
159*b5c5c2b2SKazu Hirata         if (auto bound = dyn_cast_or_null<IntegerAttr>(attr))
160ab701975SMogball           return bound.getValue();
16168f58812STres Popp       } else if (auto value = llvm::dyn_cast_if_present<Value>(*loopBound)) {
162ab701975SMogball         const IntegerValueRangeLattice *lattice =
1634b3f251bSdonald chen             getLatticeElementFor(getProgramPointAfter(op), value);
16413c648f6SVictor Perez         if (lattice != nullptr && !lattice->getValue().isUninitialized())
165ab701975SMogball           return getUpper ? lattice->getValue().getValue().smax()
166ab701975SMogball                           : lattice->getValue().getValue().smin();
167ab701975SMogball       }
168ab701975SMogball     }
169ab701975SMogball     // Given the results of getConstant{Lower,Upper}Bound()
170ab701975SMogball     // or getConstantStep() on a LoopLikeInterface return the lower/upper
171ab701975SMogball     // bound
172ab701975SMogball     return getUpper ? APInt::getSignedMaxValue(width)
173ab701975SMogball                     : APInt::getSignedMinValue(width);
174ab701975SMogball   };
175ab701975SMogball 
176ab701975SMogball   // Infer bounds for loop arguments that have static bounds
177ab701975SMogball   if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
17822426110SRamkumar Ramachandra     std::optional<Value> iv = loop.getSingleInductionVar();
179ab701975SMogball     if (!iv) {
180b2b7efb9SAlex Zinenko       return SparseForwardDataFlowAnalysis ::visitNonControlFlowArguments(
181ab701975SMogball           op, successor, argLattices, firstIndex);
182ab701975SMogball     }
18322426110SRamkumar Ramachandra     std::optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
18422426110SRamkumar Ramachandra     std::optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
18522426110SRamkumar Ramachandra     std::optional<OpFoldResult> step = loop.getSingleStep();
186ab701975SMogball     APInt min = getLoopBoundFromFold(lowerBound, iv->getType(),
187ab701975SMogball                                      /*getUpper=*/false);
188ab701975SMogball     APInt max = getLoopBoundFromFold(upperBound, iv->getType(),
189ab701975SMogball                                      /*getUpper=*/true);
190ab701975SMogball     // Assume positivity for uniscoverable steps by way of getUpper = true.
191ab701975SMogball     APInt stepVal =
192ab701975SMogball         getLoopBoundFromFold(step, iv->getType(), /*getUpper=*/true);
193ab701975SMogball 
194ab701975SMogball     if (stepVal.isNegative()) {
195ab701975SMogball       std::swap(min, max);
196ab701975SMogball     } else {
197ab701975SMogball       // Correct the upper bound by subtracting 1 so that it becomes a <=
198ab701975SMogball       // bound, because loops do not generally include their upper bound.
199ab701975SMogball       max -= 1;
200ab701975SMogball     }
201ab701975SMogball 
202b78883fcSFelix Schneider     // If we infer the lower bound to be larger than the upper bound, the
203b78883fcSFelix Schneider     // resulting range is meaningless and should not be used in further
204b78883fcSFelix Schneider     // inferences.
205b78883fcSFelix Schneider     if (max.sge(min)) {
206ab701975SMogball       IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv);
207ab701975SMogball       auto ivRange = ConstantIntRanges::fromSigned(min, max);
20847bf3e38SZhixun Tan       propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange}));
209b78883fcSFelix Schneider     }
210ab701975SMogball     return;
211ab701975SMogball   }
212ab701975SMogball 
213b2b7efb9SAlex Zinenko   return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments(
214ab701975SMogball       op, successor, argLattices, firstIndex);
215ab701975SMogball }
216