//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for arith -===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/Utils/InferIntRangeCommon.h" #include "llvm/Support/Debug.h" #include #define DEBUG_TYPE "int-range-analysis" using namespace mlir; using namespace mlir::index; using namespace mlir::intrange; //===----------------------------------------------------------------------===// // Constants //===----------------------------------------------------------------------===// void ConstantOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { const APInt &value = getValue(); setResultRange(getResult(), ConstantIntRanges::constant(value)); } void BoolConstantOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { bool value = getValue(); APInt asInt(/*numBits=*/1, value); setResultRange(getResult(), ConstantIntRanges::constant(asInt)); } //===----------------------------------------------------------------------===// // Arithmec operations. All of these operations will have their results inferred // using both the 64-bit values and truncated 32-bit values of their inputs, // with the results being the union of those inferences, except where the // truncation of the 64-bit result is equal to the 32-bit result (at which time // we take the 64-bit result). //===----------------------------------------------------------------------===// // Some arithmetic inference functions allow specifying special overflow / wrap // behavior. We do not require this for the IndexOps and use this helper to call // the inference function without any `OverflowFlags`. static std::function)> inferWithoutOverflowFlags(InferRangeWithOvfFlagsFn inferWithOvfFn) { return [inferWithOvfFn](ArrayRef argRanges) { return inferWithOvfFn(argRanges, OverflowFlags::None); }; } void AddOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferAdd), argRanges, CmpMode::Both)); } void SubOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferSub), argRanges, CmpMode::Both)); } void MulOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferMul), argRanges, CmpMode::Both)); } void DivUOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferIndexOp(inferDivU, argRanges, CmpMode::Unsigned)); } void DivSOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferIndexOp(inferDivS, argRanges, CmpMode::Signed)); } void CeilDivUOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferIndexOp(inferCeilDivU, argRanges, CmpMode::Unsigned)); } void CeilDivSOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferIndexOp(inferCeilDivS, argRanges, CmpMode::Signed)); } void FloorDivSOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { return setResultRange( getResult(), inferIndexOp(inferFloorDivS, argRanges, CmpMode::Signed)); } void RemSOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferIndexOp(inferRemS, argRanges, CmpMode::Signed)); } void RemUOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferIndexOp(inferRemU, argRanges, CmpMode::Unsigned)); } void MaxSOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferIndexOp(inferMaxS, argRanges, CmpMode::Signed)); } void MaxUOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferIndexOp(inferMaxU, argRanges, CmpMode::Unsigned)); } void MinSOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferIndexOp(inferMinS, argRanges, CmpMode::Signed)); } void MinUOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferIndexOp(inferMinU, argRanges, CmpMode::Unsigned)); } void ShlOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferShl), argRanges, CmpMode::Both)); } void ShrSOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferIndexOp(inferShrS, argRanges, CmpMode::Signed)); } void ShrUOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferIndexOp(inferShrU, argRanges, CmpMode::Unsigned)); } void AndOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferIndexOp(inferAnd, argRanges, CmpMode::Unsigned)); } void OrOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferIndexOp(inferOr, argRanges, CmpMode::Unsigned)); } void XOrOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferIndexOp(inferXor, argRanges, CmpMode::Unsigned)); } //===----------------------------------------------------------------------===// // Casts //===----------------------------------------------------------------------===// static ConstantIntRanges makeLikeDest(const ConstantIntRanges &range, unsigned srcWidth, unsigned destWidth, bool isSigned) { if (srcWidth < destWidth) return isSigned ? extSIRange(range, destWidth) : extUIRange(range, destWidth); if (srcWidth > destWidth) return truncRange(range, destWidth); return range; } // When casting to `index`, we will take the union of the possible fixed-width // casts. static ConstantIntRanges inferIndexCast(const ConstantIntRanges &range, Type sourceType, Type destType, bool isSigned) { unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType); unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); if (sourceType.isIndex()) return makeLikeDest(range, srcWidth, destWidth, isSigned); // We are casting to indexs, so use the union of the 32-bit and 64-bit casts ConstantIntRanges storageRange = makeLikeDest(range, srcWidth, destWidth, isSigned); ConstantIntRanges minWidthRange = makeLikeDest(range, srcWidth, indexMinWidth, isSigned); ConstantIntRanges minWidthExt = extRange(minWidthRange, destWidth); ConstantIntRanges ret = storageRange.rangeUnion(minWidthExt); return ret; } void CastSOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { Type sourceType = getOperand().getType(); Type destType = getResult().getType(); setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType, /*isSigned=*/true)); } void CastUOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { Type sourceType = getOperand().getType(); Type destType = getResult().getType(); setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType, /*isSigned=*/false)); } //===----------------------------------------------------------------------===// // CmpOp //===----------------------------------------------------------------------===// void CmpOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { index::IndexCmpPredicate indexPred = getPred(); intrange::CmpPredicate pred = static_cast(indexPred); const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; APInt min = APInt::getZero(1); APInt max = APInt::getAllOnes(1); std::optional truthValue64 = intrange::evaluatePred(pred, lhs, rhs); ConstantIntRanges lhsTrunc = truncRange(lhs, indexMinWidth), rhsTrunc = truncRange(rhs, indexMinWidth); std::optional truthValue32 = intrange::evaluatePred(pred, lhsTrunc, rhsTrunc); if (truthValue64 == truthValue32) { if (truthValue64.has_value() && *truthValue64) min = max; else if (truthValue64.has_value() && !(*truthValue64)) max = min; } setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max)); } //===----------------------------------------------------------------------===// // SizeOf, which is bounded between the two supported bitwidth (32 and 64). //===----------------------------------------------------------------------===// void SizeOfOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { unsigned storageWidth = ConstantIntRanges::getStorageBitwidth(getResult().getType()); APInt min(/*numBits=*/storageWidth, indexMinWidth); APInt max(/*numBits=*/storageWidth, indexMaxWidth); setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max)); }