xref: /llvm-project/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp (revision b5c5c2b26fd4bd0d0d237aaf77a01ca528810707)
1 //===- IntegerRangeAnalysis.cpp - Integer range analysis --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the dataflow analysis class for integer range inference
10 // which is used in transformations over the `arith` dialect such as
11 // branch elimination or signed->unsigned rewriting
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
16 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
17 #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
18 #include "mlir/Analysis/DataFlowFramework.h"
19 #include "mlir/IR/BuiltinAttributes.h"
20 #include "mlir/IR/Dialect.h"
21 #include "mlir/IR/OpDefinition.h"
22 #include "mlir/IR/TypeUtilities.h"
23 #include "mlir/IR/Value.h"
24 #include "mlir/Interfaces/ControlFlowInterfaces.h"
25 #include "mlir/Interfaces/InferIntRangeInterface.h"
26 #include "mlir/Interfaces/LoopLikeInterface.h"
27 #include "mlir/Support/LLVM.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/Support/Casting.h"
30 #include "llvm/Support/Debug.h"
31 #include <cassert>
32 #include <optional>
33 #include <utility>
34 
35 #define DEBUG_TYPE "int-range-analysis"
36 
37 using namespace mlir;
38 using namespace mlir::dataflow;
39 
40 void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
41   Lattice::onUpdate(solver);
42 
43   // If the integer range can be narrowed to a constant, update the constant
44   // value of the SSA value.
45   std::optional<APInt> constant = getValue().getValue().getConstantValue();
46   auto value = cast<Value>(anchor);
47   auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(value);
48   if (!constant)
49     return solver->propagateIfChanged(
50         cv, cv->join(ConstantValue::getUnknownConstant()));
51 
52   Dialect *dialect;
53   if (auto *parent = value.getDefiningOp())
54     dialect = parent->getDialect();
55   else
56     dialect = value.getParentBlock()->getParentOp()->getDialect();
57 
58   Type type = getElementTypeOrSelf(value);
59   solver->propagateIfChanged(
60       cv, cv->join(ConstantValue(IntegerAttr::get(type, *constant), dialect)));
61 }
62 
63 LogicalResult IntegerRangeAnalysis::visitOperation(
64     Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands,
65     ArrayRef<IntegerValueRangeLattice *> results) {
66   auto inferrable = dyn_cast<InferIntRangeInterface>(op);
67   if (!inferrable) {
68     setAllToEntryStates(results);
69     return success();
70   }
71 
72   LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
73   auto argRanges = llvm::map_to_vector(
74       operands, [](const IntegerValueRangeLattice *lattice) {
75         return lattice->getValue();
76       });
77 
78   auto joinCallback = [&](Value v, const IntegerValueRange &attrs) {
79     auto result = dyn_cast<OpResult>(v);
80     if (!result)
81       return;
82     assert(llvm::is_contained(op->getResults(), result));
83 
84     LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
85     IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
86     IntegerValueRange oldRange = lattice->getValue();
87 
88     ChangeResult changed = lattice->join(attrs);
89 
90     // Catch loop results with loop variant bounds and conservatively make
91     // them [-inf, inf] so we don't circle around infinitely often (because
92     // the dataflow analysis in MLIR doesn't attempt to work out trip counts
93     // and often can't).
94     bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
95       return op->hasTrait<OpTrait::IsTerminator>();
96     });
97     if (isYieldedResult && !oldRange.isUninitialized() &&
98         !(lattice->getValue() == oldRange)) {
99       LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
100       changed |= lattice->join(IntegerValueRange::getMaxRange(v));
101     }
102     propagateIfChanged(lattice, changed);
103   };
104 
105   inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
106   return success();
107 }
108 
109 void IntegerRangeAnalysis::visitNonControlFlowArguments(
110     Operation *op, const RegionSuccessor &successor,
111     ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) {
112   if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
113     LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
114 
115     auto argRanges = llvm::map_to_vector(op->getOperands(), [&](Value value) {
116       return getLatticeElementFor(getProgramPointAfter(op), value)->getValue();
117     });
118 
119     auto joinCallback = [&](Value v, const IntegerValueRange &attrs) {
120       auto arg = dyn_cast<BlockArgument>(v);
121       if (!arg)
122         return;
123       if (!llvm::is_contained(successor.getSuccessor()->getArguments(), arg))
124         return;
125 
126       LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
127       IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()];
128       IntegerValueRange oldRange = lattice->getValue();
129 
130       ChangeResult changed = lattice->join(attrs);
131 
132       // Catch loop results with loop variant bounds and conservatively make
133       // them [-inf, inf] so we don't circle around infinitely often (because
134       // the dataflow analysis in MLIR doesn't attempt to work out trip counts
135       // and often can't).
136       bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
137         return op->hasTrait<OpTrait::IsTerminator>();
138       });
139       if (isYieldedValue && !oldRange.isUninitialized() &&
140           !(lattice->getValue() == oldRange)) {
141         LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
142         changed |= lattice->join(IntegerValueRange::getMaxRange(v));
143       }
144       propagateIfChanged(lattice, changed);
145     };
146 
147     inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
148     return;
149   }
150 
151   /// Given the results of getConstant{Lower,Upper}Bound() or getConstantStep()
152   /// on a LoopLikeInterface return the lower/upper bound for that result if
153   /// possible.
154   auto getLoopBoundFromFold = [&](std::optional<OpFoldResult> loopBound,
155                                   Type boundType, bool getUpper) {
156     unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType);
157     if (loopBound.has_value()) {
158       if (auto attr = dyn_cast<Attribute>(*loopBound)) {
159         if (auto bound = dyn_cast_or_null<IntegerAttr>(attr))
160           return bound.getValue();
161       } else if (auto value = llvm::dyn_cast_if_present<Value>(*loopBound)) {
162         const IntegerValueRangeLattice *lattice =
163             getLatticeElementFor(getProgramPointAfter(op), value);
164         if (lattice != nullptr && !lattice->getValue().isUninitialized())
165           return getUpper ? lattice->getValue().getValue().smax()
166                           : lattice->getValue().getValue().smin();
167       }
168     }
169     // Given the results of getConstant{Lower,Upper}Bound()
170     // or getConstantStep() on a LoopLikeInterface return the lower/upper
171     // bound
172     return getUpper ? APInt::getSignedMaxValue(width)
173                     : APInt::getSignedMinValue(width);
174   };
175 
176   // Infer bounds for loop arguments that have static bounds
177   if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
178     std::optional<Value> iv = loop.getSingleInductionVar();
179     if (!iv) {
180       return SparseForwardDataFlowAnalysis ::visitNonControlFlowArguments(
181           op, successor, argLattices, firstIndex);
182     }
183     std::optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
184     std::optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
185     std::optional<OpFoldResult> step = loop.getSingleStep();
186     APInt min = getLoopBoundFromFold(lowerBound, iv->getType(),
187                                      /*getUpper=*/false);
188     APInt max = getLoopBoundFromFold(upperBound, iv->getType(),
189                                      /*getUpper=*/true);
190     // Assume positivity for uniscoverable steps by way of getUpper = true.
191     APInt stepVal =
192         getLoopBoundFromFold(step, iv->getType(), /*getUpper=*/true);
193 
194     if (stepVal.isNegative()) {
195       std::swap(min, max);
196     } else {
197       // Correct the upper bound by subtracting 1 so that it becomes a <=
198       // bound, because loops do not generally include their upper bound.
199       max -= 1;
200     }
201 
202     // If we infer the lower bound to be larger than the upper bound, the
203     // resulting range is meaningless and should not be used in further
204     // inferences.
205     if (max.sge(min)) {
206       IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv);
207       auto ivRange = ConstantIntRanges::fromSigned(min, max);
208       propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange}));
209     }
210     return;
211   }
212 
213   return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments(
214       op, successor, argLattices, firstIndex);
215 }
216