1 //===- InferIntRangeInterfaceImpls.cpp - Integer range impls for arith -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arith/IR/Arith.h" 10 #include "mlir/Interfaces/InferIntRangeInterface.h" 11 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h" 12 13 #include "llvm/Support/Debug.h" 14 #include <optional> 15 16 #define DEBUG_TYPE "int-range-analysis" 17 18 using namespace mlir; 19 using namespace mlir::arith; 20 using namespace mlir::intrange; 21 22 static intrange::OverflowFlags 23 convertArithOverflowFlags(arith::IntegerOverflowFlags flags) { 24 intrange::OverflowFlags retFlags = intrange::OverflowFlags::None; 25 if (bitEnumContainsAny(flags, arith::IntegerOverflowFlags::nsw)) 26 retFlags |= intrange::OverflowFlags::Nsw; 27 if (bitEnumContainsAny(flags, arith::IntegerOverflowFlags::nuw)) 28 retFlags |= intrange::OverflowFlags::Nuw; 29 return retFlags; 30 } 31 32 //===----------------------------------------------------------------------===// 33 // ConstantOp 34 //===----------------------------------------------------------------------===// 35 36 void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 37 SetIntRangeFn setResultRange) { 38 if (auto scalarCstAttr = llvm::dyn_cast_or_null<IntegerAttr>(getValue())) { 39 const APInt &value = scalarCstAttr.getValue(); 40 setResultRange(getResult(), ConstantIntRanges::constant(value)); 41 return; 42 } 43 if (auto arrayCstAttr = 44 llvm::dyn_cast_or_null<DenseIntElementsAttr>(getValue())) { 45 if (arrayCstAttr.isSplat()) { 46 setResultRange(getResult(), ConstantIntRanges::constant( 47 arrayCstAttr.getSplatValue<APInt>())); 48 return; 49 } 50 51 std::optional<ConstantIntRanges> result; 52 for (const APInt &val : arrayCstAttr) { 53 auto range = ConstantIntRanges::constant(val); 54 result = (result ? result->rangeUnion(range) : range); 55 } 56 57 assert(result && "Zero-sized vectors are not allowed"); 58 setResultRange(getResult(), *result); 59 return; 60 } 61 } 62 63 //===----------------------------------------------------------------------===// 64 // AddIOp 65 //===----------------------------------------------------------------------===// 66 67 void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 68 SetIntRangeFn setResultRange) { 69 setResultRange(getResult(), inferAdd(argRanges, convertArithOverflowFlags( 70 getOverflowFlags()))); 71 } 72 73 //===----------------------------------------------------------------------===// 74 // SubIOp 75 //===----------------------------------------------------------------------===// 76 77 void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 78 SetIntRangeFn setResultRange) { 79 setResultRange(getResult(), inferSub(argRanges, convertArithOverflowFlags( 80 getOverflowFlags()))); 81 } 82 83 //===----------------------------------------------------------------------===// 84 // MulIOp 85 //===----------------------------------------------------------------------===// 86 87 void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 88 SetIntRangeFn setResultRange) { 89 setResultRange(getResult(), inferMul(argRanges, convertArithOverflowFlags( 90 getOverflowFlags()))); 91 } 92 93 //===----------------------------------------------------------------------===// 94 // DivUIOp 95 //===----------------------------------------------------------------------===// 96 97 void arith::DivUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 98 SetIntRangeFn setResultRange) { 99 setResultRange(getResult(), inferDivU(argRanges)); 100 } 101 102 //===----------------------------------------------------------------------===// 103 // DivSIOp 104 //===----------------------------------------------------------------------===// 105 106 void arith::DivSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 107 SetIntRangeFn setResultRange) { 108 setResultRange(getResult(), inferDivS(argRanges)); 109 } 110 111 //===----------------------------------------------------------------------===// 112 // CeilDivUIOp 113 //===----------------------------------------------------------------------===// 114 115 void arith::CeilDivUIOp::inferResultRanges( 116 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) { 117 setResultRange(getResult(), inferCeilDivU(argRanges)); 118 } 119 120 //===----------------------------------------------------------------------===// 121 // CeilDivSIOp 122 //===----------------------------------------------------------------------===// 123 124 void arith::CeilDivSIOp::inferResultRanges( 125 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) { 126 setResultRange(getResult(), inferCeilDivS(argRanges)); 127 } 128 129 //===----------------------------------------------------------------------===// 130 // FloorDivSIOp 131 //===----------------------------------------------------------------------===// 132 133 void arith::FloorDivSIOp::inferResultRanges( 134 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) { 135 return setResultRange(getResult(), inferFloorDivS(argRanges)); 136 } 137 138 //===----------------------------------------------------------------------===// 139 // RemUIOp 140 //===----------------------------------------------------------------------===// 141 142 void arith::RemUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 143 SetIntRangeFn setResultRange) { 144 setResultRange(getResult(), inferRemU(argRanges)); 145 } 146 147 //===----------------------------------------------------------------------===// 148 // RemSIOp 149 //===----------------------------------------------------------------------===// 150 151 void arith::RemSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 152 SetIntRangeFn setResultRange) { 153 setResultRange(getResult(), inferRemS(argRanges)); 154 } 155 156 //===----------------------------------------------------------------------===// 157 // AndIOp 158 //===----------------------------------------------------------------------===// 159 160 void arith::AndIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 161 SetIntRangeFn setResultRange) { 162 setResultRange(getResult(), inferAnd(argRanges)); 163 } 164 165 //===----------------------------------------------------------------------===// 166 // OrIOp 167 //===----------------------------------------------------------------------===// 168 169 void arith::OrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 170 SetIntRangeFn setResultRange) { 171 setResultRange(getResult(), inferOr(argRanges)); 172 } 173 174 //===----------------------------------------------------------------------===// 175 // XOrIOp 176 //===----------------------------------------------------------------------===// 177 178 void arith::XOrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 179 SetIntRangeFn setResultRange) { 180 setResultRange(getResult(), inferXor(argRanges)); 181 } 182 183 //===----------------------------------------------------------------------===// 184 // MaxSIOp 185 //===----------------------------------------------------------------------===// 186 187 void arith::MaxSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 188 SetIntRangeFn setResultRange) { 189 setResultRange(getResult(), inferMaxS(argRanges)); 190 } 191 192 //===----------------------------------------------------------------------===// 193 // MaxUIOp 194 //===----------------------------------------------------------------------===// 195 196 void arith::MaxUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 197 SetIntRangeFn setResultRange) { 198 setResultRange(getResult(), inferMaxU(argRanges)); 199 } 200 201 //===----------------------------------------------------------------------===// 202 // MinSIOp 203 //===----------------------------------------------------------------------===// 204 205 void arith::MinSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 206 SetIntRangeFn setResultRange) { 207 setResultRange(getResult(), inferMinS(argRanges)); 208 } 209 210 //===----------------------------------------------------------------------===// 211 // MinUIOp 212 //===----------------------------------------------------------------------===// 213 214 void arith::MinUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 215 SetIntRangeFn setResultRange) { 216 setResultRange(getResult(), inferMinU(argRanges)); 217 } 218 219 //===----------------------------------------------------------------------===// 220 // ExtUIOp 221 //===----------------------------------------------------------------------===// 222 223 void arith::ExtUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 224 SetIntRangeFn setResultRange) { 225 unsigned destWidth = 226 ConstantIntRanges::getStorageBitwidth(getResult().getType()); 227 setResultRange(getResult(), extUIRange(argRanges[0], destWidth)); 228 } 229 230 //===----------------------------------------------------------------------===// 231 // ExtSIOp 232 //===----------------------------------------------------------------------===// 233 234 void arith::ExtSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 235 SetIntRangeFn setResultRange) { 236 unsigned destWidth = 237 ConstantIntRanges::getStorageBitwidth(getResult().getType()); 238 setResultRange(getResult(), extSIRange(argRanges[0], destWidth)); 239 } 240 241 //===----------------------------------------------------------------------===// 242 // TruncIOp 243 //===----------------------------------------------------------------------===// 244 245 void arith::TruncIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 246 SetIntRangeFn setResultRange) { 247 unsigned destWidth = 248 ConstantIntRanges::getStorageBitwidth(getResult().getType()); 249 setResultRange(getResult(), truncRange(argRanges[0], destWidth)); 250 } 251 252 //===----------------------------------------------------------------------===// 253 // IndexCastOp 254 //===----------------------------------------------------------------------===// 255 256 void arith::IndexCastOp::inferResultRanges( 257 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) { 258 Type sourceType = getOperand().getType(); 259 Type destType = getResult().getType(); 260 unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType); 261 unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); 262 263 if (srcWidth < destWidth) 264 setResultRange(getResult(), extSIRange(argRanges[0], destWidth)); 265 else if (srcWidth > destWidth) 266 setResultRange(getResult(), truncRange(argRanges[0], destWidth)); 267 else 268 setResultRange(getResult(), argRanges[0]); 269 } 270 271 //===----------------------------------------------------------------------===// 272 // IndexCastUIOp 273 //===----------------------------------------------------------------------===// 274 275 void arith::IndexCastUIOp::inferResultRanges( 276 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) { 277 Type sourceType = getOperand().getType(); 278 Type destType = getResult().getType(); 279 unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType); 280 unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); 281 282 if (srcWidth < destWidth) 283 setResultRange(getResult(), extUIRange(argRanges[0], destWidth)); 284 else if (srcWidth > destWidth) 285 setResultRange(getResult(), truncRange(argRanges[0], destWidth)); 286 else 287 setResultRange(getResult(), argRanges[0]); 288 } 289 290 //===----------------------------------------------------------------------===// 291 // CmpIOp 292 //===----------------------------------------------------------------------===// 293 294 void arith::CmpIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 295 SetIntRangeFn setResultRange) { 296 arith::CmpIPredicate arithPred = getPredicate(); 297 intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(arithPred); 298 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 299 300 APInt min = APInt::getZero(1); 301 APInt max = APInt::getAllOnes(1); 302 303 std::optional<bool> truthValue = intrange::evaluatePred(pred, lhs, rhs); 304 if (truthValue.has_value() && *truthValue) 305 min = max; 306 else if (truthValue.has_value() && !(*truthValue)) 307 max = min; 308 309 setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max)); 310 } 311 312 //===----------------------------------------------------------------------===// 313 // SelectOp 314 //===----------------------------------------------------------------------===// 315 316 void arith::SelectOp::inferResultRangesFromOptional( 317 ArrayRef<IntegerValueRange> argRanges, SetIntLatticeFn setResultRange) { 318 std::optional<APInt> mbCondVal = 319 argRanges[0].isUninitialized() 320 ? std::nullopt 321 : argRanges[0].getValue().getConstantValue(); 322 323 const IntegerValueRange &trueCase = argRanges[1]; 324 const IntegerValueRange &falseCase = argRanges[2]; 325 326 if (mbCondVal) { 327 if (mbCondVal->isZero()) 328 setResultRange(getResult(), falseCase); 329 else 330 setResultRange(getResult(), trueCase); 331 return; 332 } 333 setResultRange(getResult(), IntegerValueRange::join(trueCase, falseCase)); 334 } 335 336 //===----------------------------------------------------------------------===// 337 // ShLIOp 338 //===----------------------------------------------------------------------===// 339 340 void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 341 SetIntRangeFn setResultRange) { 342 setResultRange(getResult(), inferShl(argRanges, convertArithOverflowFlags( 343 getOverflowFlags()))); 344 } 345 346 //===----------------------------------------------------------------------===// 347 // ShRUIOp 348 //===----------------------------------------------------------------------===// 349 350 void arith::ShRUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 351 SetIntRangeFn setResultRange) { 352 setResultRange(getResult(), inferShrU(argRanges)); 353 } 354 355 //===----------------------------------------------------------------------===// 356 // ShRSIOp 357 //===----------------------------------------------------------------------===// 358 359 void arith::ShRSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 360 SetIntRangeFn setResultRange) { 361 setResultRange(getResult(), inferShrS(argRanges)); 362 } 363