xref: /llvm-project/mlir/lib/Interfaces/InferIntRangeInterface.cpp (revision 616aff126caaf93a0d9868d279e4c99d1e45fef0)
195aff23eSKrzysztof Drewniak //===- InferIntRangeInterface.cpp -  Integer range inference interface ---===//
295aff23eSKrzysztof Drewniak //
395aff23eSKrzysztof Drewniak // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
495aff23eSKrzysztof Drewniak // See https://llvm.org/LICENSE.txt for license information.
595aff23eSKrzysztof Drewniak // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
695aff23eSKrzysztof Drewniak //
795aff23eSKrzysztof Drewniak //===----------------------------------------------------------------------===//
895aff23eSKrzysztof Drewniak 
995aff23eSKrzysztof Drewniak #include "mlir/Interfaces/InferIntRangeInterface.h"
1095aff23eSKrzysztof Drewniak #include "mlir/IR/BuiltinTypes.h"
11*616aff12SKrzysztof Drewniak #include "mlir/IR/TypeUtilities.h"
1295aff23eSKrzysztof Drewniak #include "mlir/Interfaces/InferIntRangeInterface.cpp.inc"
13d99d258eSMehdi Amini #include <optional>
1495aff23eSKrzysztof Drewniak 
1595aff23eSKrzysztof Drewniak using namespace mlir;
1695aff23eSKrzysztof Drewniak 
1795aff23eSKrzysztof Drewniak bool ConstantIntRanges::operator==(const ConstantIntRanges &other) const {
1895aff23eSKrzysztof Drewniak   return umin().getBitWidth() == other.umin().getBitWidth() &&
1995aff23eSKrzysztof Drewniak          umin() == other.umin() && umax() == other.umax() &&
2095aff23eSKrzysztof Drewniak          smin() == other.smin() && smax() == other.smax();
2195aff23eSKrzysztof Drewniak }
2295aff23eSKrzysztof Drewniak 
2395aff23eSKrzysztof Drewniak const APInt &ConstantIntRanges::umin() const { return uminVal; }
2495aff23eSKrzysztof Drewniak 
2595aff23eSKrzysztof Drewniak const APInt &ConstantIntRanges::umax() const { return umaxVal; }
2695aff23eSKrzysztof Drewniak 
2795aff23eSKrzysztof Drewniak const APInt &ConstantIntRanges::smin() const { return sminVal; }
2895aff23eSKrzysztof Drewniak 
2995aff23eSKrzysztof Drewniak const APInt &ConstantIntRanges::smax() const { return smaxVal; }
3095aff23eSKrzysztof Drewniak 
3195aff23eSKrzysztof Drewniak unsigned ConstantIntRanges::getStorageBitwidth(Type type) {
32*616aff12SKrzysztof Drewniak   type = getElementTypeOrSelf(type);
3395aff23eSKrzysztof Drewniak   if (type.isIndex())
3495aff23eSKrzysztof Drewniak     return IndexType::kInternalStorageBitWidth;
355550c821STres Popp   if (auto integerType = dyn_cast<IntegerType>(type))
3695aff23eSKrzysztof Drewniak     return integerType.getWidth();
3795aff23eSKrzysztof Drewniak   // Non-integer types have their bounds stored in width 0 `APInt`s.
3895aff23eSKrzysztof Drewniak   return 0;
3995aff23eSKrzysztof Drewniak }
4095aff23eSKrzysztof Drewniak 
4175bfc6f2SKrzysztof Drewniak ConstantIntRanges ConstantIntRanges::maxRange(unsigned bitwidth) {
4275bfc6f2SKrzysztof Drewniak   return fromUnsigned(APInt::getZero(bitwidth), APInt::getMaxValue(bitwidth));
4375bfc6f2SKrzysztof Drewniak }
4475bfc6f2SKrzysztof Drewniak 
4575bfc6f2SKrzysztof Drewniak ConstantIntRanges ConstantIntRanges::constant(const APInt &value) {
4675bfc6f2SKrzysztof Drewniak   return {value, value, value, value};
4775bfc6f2SKrzysztof Drewniak }
4875bfc6f2SKrzysztof Drewniak 
4975bfc6f2SKrzysztof Drewniak ConstantIntRanges ConstantIntRanges::range(const APInt &min, const APInt &max,
5075bfc6f2SKrzysztof Drewniak                                            bool isSigned) {
5175bfc6f2SKrzysztof Drewniak   if (isSigned)
5275bfc6f2SKrzysztof Drewniak     return fromSigned(min, max);
5375bfc6f2SKrzysztof Drewniak   return fromUnsigned(min, max);
5495aff23eSKrzysztof Drewniak }
5595aff23eSKrzysztof Drewniak 
5695aff23eSKrzysztof Drewniak ConstantIntRanges ConstantIntRanges::fromSigned(const APInt &smin,
5795aff23eSKrzysztof Drewniak                                                 const APInt &smax) {
5895aff23eSKrzysztof Drewniak   unsigned int width = smin.getBitWidth();
5995aff23eSKrzysztof Drewniak   APInt umin, umax;
6095aff23eSKrzysztof Drewniak   if (smin.isNonNegative() == smax.isNonNegative()) {
6195aff23eSKrzysztof Drewniak     umin = smin.ult(smax) ? smin : smax;
6295aff23eSKrzysztof Drewniak     umax = smin.ugt(smax) ? smin : smax;
6395aff23eSKrzysztof Drewniak   } else {
6495aff23eSKrzysztof Drewniak     umin = APInt::getMinValue(width);
6595aff23eSKrzysztof Drewniak     umax = APInt::getMaxValue(width);
6695aff23eSKrzysztof Drewniak   }
6795aff23eSKrzysztof Drewniak   return {umin, umax, smin, smax};
6895aff23eSKrzysztof Drewniak }
6995aff23eSKrzysztof Drewniak 
7095aff23eSKrzysztof Drewniak ConstantIntRanges ConstantIntRanges::fromUnsigned(const APInt &umin,
7195aff23eSKrzysztof Drewniak                                                   const APInt &umax) {
7295aff23eSKrzysztof Drewniak   unsigned int width = umin.getBitWidth();
7395aff23eSKrzysztof Drewniak   APInt smin, smax;
7495aff23eSKrzysztof Drewniak   if (umin.isNonNegative() == umax.isNonNegative()) {
7595aff23eSKrzysztof Drewniak     smin = umin.slt(umax) ? umin : umax;
7695aff23eSKrzysztof Drewniak     smax = umin.sgt(umax) ? umin : umax;
7795aff23eSKrzysztof Drewniak   } else {
7895aff23eSKrzysztof Drewniak     smin = APInt::getSignedMinValue(width);
7995aff23eSKrzysztof Drewniak     smax = APInt::getSignedMaxValue(width);
8095aff23eSKrzysztof Drewniak   }
8195aff23eSKrzysztof Drewniak   return {umin, umax, smin, smax};
8295aff23eSKrzysztof Drewniak }
8395aff23eSKrzysztof Drewniak 
8495aff23eSKrzysztof Drewniak ConstantIntRanges
8595aff23eSKrzysztof Drewniak ConstantIntRanges::rangeUnion(const ConstantIntRanges &other) const {
8695aff23eSKrzysztof Drewniak   // "Not an integer" poisons everything and also cannot be fed to comparison
8795aff23eSKrzysztof Drewniak   // operators.
8895aff23eSKrzysztof Drewniak   if (umin().getBitWidth() == 0)
8995aff23eSKrzysztof Drewniak     return *this;
9095aff23eSKrzysztof Drewniak   if (other.umin().getBitWidth() == 0)
9195aff23eSKrzysztof Drewniak     return other;
9295aff23eSKrzysztof Drewniak 
9395aff23eSKrzysztof Drewniak   const APInt &uminUnion = umin().ult(other.umin()) ? umin() : other.umin();
9495aff23eSKrzysztof Drewniak   const APInt &umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax();
9595aff23eSKrzysztof Drewniak   const APInt &sminUnion = smin().slt(other.smin()) ? smin() : other.smin();
9695aff23eSKrzysztof Drewniak   const APInt &smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax();
9795aff23eSKrzysztof Drewniak 
9895aff23eSKrzysztof Drewniak   return {uminUnion, umaxUnion, sminUnion, smaxUnion};
9995aff23eSKrzysztof Drewniak }
10095aff23eSKrzysztof Drewniak 
10175bfc6f2SKrzysztof Drewniak ConstantIntRanges
10275bfc6f2SKrzysztof Drewniak ConstantIntRanges::intersection(const ConstantIntRanges &other) const {
10375bfc6f2SKrzysztof Drewniak   // "Not an integer" poisons everything and also cannot be fed to comparison
10475bfc6f2SKrzysztof Drewniak   // operators.
10575bfc6f2SKrzysztof Drewniak   if (umin().getBitWidth() == 0)
10675bfc6f2SKrzysztof Drewniak     return *this;
10775bfc6f2SKrzysztof Drewniak   if (other.umin().getBitWidth() == 0)
10875bfc6f2SKrzysztof Drewniak     return other;
10975bfc6f2SKrzysztof Drewniak 
11075bfc6f2SKrzysztof Drewniak   const APInt &uminIntersect = umin().ugt(other.umin()) ? umin() : other.umin();
11175bfc6f2SKrzysztof Drewniak   const APInt &umaxIntersect = umax().ult(other.umax()) ? umax() : other.umax();
11275bfc6f2SKrzysztof Drewniak   const APInt &sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin();
11375bfc6f2SKrzysztof Drewniak   const APInt &smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax();
11475bfc6f2SKrzysztof Drewniak 
11575bfc6f2SKrzysztof Drewniak   return {uminIntersect, umaxIntersect, sminIntersect, smaxIntersect};
11675bfc6f2SKrzysztof Drewniak }
11775bfc6f2SKrzysztof Drewniak 
1180a81ace0SKazu Hirata std::optional<APInt> ConstantIntRanges::getConstantValue() const {
11995aff23eSKrzysztof Drewniak   // Note: we need to exclude the trivially-equal width 0 values here.
12095aff23eSKrzysztof Drewniak   if (umin() == umax() && umin().getBitWidth() != 0)
12195aff23eSKrzysztof Drewniak     return umin();
12295aff23eSKrzysztof Drewniak   if (smin() == smax() && smin().getBitWidth() != 0)
12395aff23eSKrzysztof Drewniak     return smin();
1241a36588eSKazu Hirata   return std::nullopt;
12595aff23eSKrzysztof Drewniak }
12695aff23eSKrzysztof Drewniak 
12795aff23eSKrzysztof Drewniak raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) {
12895aff23eSKrzysztof Drewniak   return os << "unsigned : [" << range.umin() << ", " << range.umax()
12995aff23eSKrzysztof Drewniak             << "] signed : [" << range.smin() << ", " << range.smax() << "]";
13095aff23eSKrzysztof Drewniak }
1316aeea700SSpenser Bauman 
1326aeea700SSpenser Bauman IntegerValueRange IntegerValueRange::getMaxRange(Value value) {
1336aeea700SSpenser Bauman   unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType());
1346aeea700SSpenser Bauman   if (width == 0)
1356aeea700SSpenser Bauman     return {};
1366aeea700SSpenser Bauman 
1376aeea700SSpenser Bauman   APInt umin = APInt::getMinValue(width);
1386aeea700SSpenser Bauman   APInt umax = APInt::getMaxValue(width);
1396aeea700SSpenser Bauman   APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
1406aeea700SSpenser Bauman   APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax;
1416aeea700SSpenser Bauman   return IntegerValueRange{ConstantIntRanges{umin, umax, smin, smax}};
1426aeea700SSpenser Bauman }
1436aeea700SSpenser Bauman 
1446aeea700SSpenser Bauman raw_ostream &mlir::operator<<(raw_ostream &os, const IntegerValueRange &range) {
1456aeea700SSpenser Bauman   range.print(os);
1466aeea700SSpenser Bauman   return os;
1476aeea700SSpenser Bauman }
1486aeea700SSpenser Bauman 
1496aeea700SSpenser Bauman void mlir::intrange::detail::defaultInferResultRanges(
1506aeea700SSpenser Bauman     InferIntRangeInterface interface, ArrayRef<IntegerValueRange> argRanges,
1516aeea700SSpenser Bauman     SetIntLatticeFn setResultRanges) {
1526aeea700SSpenser Bauman   llvm::SmallVector<ConstantIntRanges> unpacked;
1536aeea700SSpenser Bauman   unpacked.reserve(argRanges.size());
1546aeea700SSpenser Bauman 
1556aeea700SSpenser Bauman   for (const IntegerValueRange &range : argRanges) {
1566aeea700SSpenser Bauman     if (range.isUninitialized())
1576aeea700SSpenser Bauman       return;
1586aeea700SSpenser Bauman     unpacked.push_back(range.getValue());
1596aeea700SSpenser Bauman   }
1606aeea700SSpenser Bauman 
1616aeea700SSpenser Bauman   interface.inferResultRanges(
1626aeea700SSpenser Bauman       unpacked,
1636aeea700SSpenser Bauman       [&setResultRanges](Value value, const ConstantIntRanges &argRanges) {
1646aeea700SSpenser Bauman         setResultRanges(value, IntegerValueRange{argRanges});
1656aeea700SSpenser Bauman       });
1666aeea700SSpenser Bauman }
1676aeea700SSpenser Bauman 
1686aeea700SSpenser Bauman void mlir::intrange::detail::defaultInferResultRangesFromOptional(
1696aeea700SSpenser Bauman     InferIntRangeInterface interface, ArrayRef<ConstantIntRanges> argRanges,
1706aeea700SSpenser Bauman     SetIntRangeFn setResultRanges) {
1716aeea700SSpenser Bauman   auto ranges = llvm::to_vector_of<IntegerValueRange>(argRanges);
1726aeea700SSpenser Bauman   interface.inferResultRangesFromOptional(
1736aeea700SSpenser Bauman       ranges,
1746aeea700SSpenser Bauman       [&setResultRanges](Value value, const IntegerValueRange &argRanges) {
1756aeea700SSpenser Bauman         if (!argRanges.isUninitialized())
1766aeea700SSpenser Bauman           setResultRanges(value, argRanges.getValue());
1776aeea700SSpenser Bauman       });
1786aeea700SSpenser Bauman }
179