//===- IntRangeOptimizations.cpp - Optimizations based on integer ranges --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"

#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::arith {
#define GEN_PASS_DEF_ARITHINTRANGEOPTS
#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"

#define GEN_PASS_DEF_ARITHINTRANGENARROWING
#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
} // namespace mlir::arith

using namespace mlir;
using namespace mlir::arith;
using namespace mlir::dataflow;

static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
                                                  Value value) {
  auto *maybeInferredRange =
      solver.lookupState<IntegerValueRangeLattice>(value);
  if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
    return std::nullopt;
  const ConstantIntRanges &inferredRange =
      maybeInferredRange->getValue().getValue();
  return inferredRange.getConstantValue();
}

static void copyIntegerRange(DataFlowSolver &solver, Value oldVal,
                             Value newVal) {
  assert(oldVal.getType() == newVal.getType() &&
         "Can't copy integer ranges between different types");
  auto *oldState = solver.lookupState<IntegerValueRangeLattice>(oldVal);
  if (!oldState)
    return;
  (void)solver.getOrCreateState<IntegerValueRangeLattice>(newVal)->join(
      *oldState);
}

/// Patterned after SCCP
static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
                                              PatternRewriter &rewriter,
                                              Value value) {
  if (value.use_empty())
    return failure();
  std::optional<APInt> maybeConstValue = getMaybeConstantValue(solver, value);
  if (!maybeConstValue.has_value())
    return failure();

  Type type = value.getType();
  Location loc = value.getLoc();
  Operation *maybeDefiningOp = value.getDefiningOp();
  Dialect *valueDialect =
      maybeDefiningOp ? maybeDefiningOp->getDialect()
                      : value.getParentRegion()->getParentOp()->getDialect();

  Attribute constAttr;
  if (auto shaped = dyn_cast<ShapedType>(type)) {
    constAttr = mlir::DenseIntElementsAttr::get(shaped, *maybeConstValue);
  } else {
    constAttr = rewriter.getIntegerAttr(type, *maybeConstValue);
  }
  Operation *constOp =
      valueDialect->materializeConstant(rewriter, constAttr, type, loc);
  // Fall back to arith.constant if the dialect materializer doesn't know what
  // to do with an integer constant.
  if (!constOp)
    constOp = rewriter.getContext()
                  ->getLoadedDialect<ArithDialect>()
                  ->materializeConstant(rewriter, constAttr, type, loc);
  if (!constOp)
    return failure();

  copyIntegerRange(solver, value, constOp->getResult(0));
  rewriter.replaceAllUsesWith(value, constOp->getResult(0));
  return success();
}
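
// Illustrative sketch of the helper above (hypothetical IR, not taken from a
// test): if the analysis proves that `%v : i32` always equals 42, all uses of
//   %v = "some.op"() : () -> i32
// are rewritten to use
//   %c42 = arith.constant 42 : i32
// (or a dialect-specific constant), and the inferred range is copied onto the
// new constant so later patterns can still query it.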

namespace {
class DataFlowListener : public RewriterBase::Listener {
public:
  DataFlowListener(DataFlowSolver &s) : s(s) {}

protected:
  void notifyOperationErased(Operation *op) override {
    s.eraseState(s.getProgramPointAfter(op));
    for (Value res : op->getResults())
      s.eraseState(res);
  }

  DataFlowSolver &s;
};

/// Rewrite any results of `op` that were inferred to be constant integers and
/// replace their uses with that constant. Return success() if all results
/// were thus replaced and the operation is erased. Also replace any block
/// arguments with their constant values.
struct MaterializeKnownConstantValues : public RewritePattern {
  MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s)
      : RewritePattern(Pattern::MatchAnyOpTypeTag(), /*benefit=*/1, context),
        solver(s) {}

  LogicalResult match(Operation *op) const override {
    if (matchPattern(op, m_Constant()))
      return failure();

    auto needsReplacing = [&](Value v) {
      return getMaybeConstantValue(solver, v).has_value() && !v.use_empty();
    };
    bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing);
    if (op->getNumRegions() == 0)
      return success(hasConstantResults);
    bool hasConstantRegionArgs = false;
    for (Region &region : op->getRegions()) {
      for (Block &block : region.getBlocks()) {
        hasConstantRegionArgs |=
            llvm::any_of(block.getArguments(), needsReplacing);
      }
    }
    return success(hasConstantResults || hasConstantRegionArgs);
  }

  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
    bool replacedAll = (op->getNumResults() != 0);
    for (Value v : op->getResults())
      replacedAll &=
          (succeeded(maybeReplaceWithConstant(solver, rewriter, v)) ||
           v.use_empty());
    if (replacedAll && isOpTriviallyDead(op)) {
      rewriter.eraseOp(op);
      return;
    }

    PatternRewriter::InsertionGuard guard(rewriter);
    for (Region &region : op->getRegions()) {
      for (Block &block : region.getBlocks()) {
        rewriter.setInsertionPointToStart(&block);
        for (BlockArgument &arg : block.getArguments()) {
          (void)maybeReplaceWithConstant(solver, rewriter, arg);
        }
      }
    }
  }

private:
  DataFlowSolver &solver;
};
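
// The pattern below removes remainder ops whose result provably equals the
// dividend. Illustrative rewrite (hypothetical IR, assuming the analysis has
// bounded %x to the range [0, 10)):
//   %c16 = arith.constant 16 : i32
//   %r = arith.remui %x, %c16 : i32
// Since %x is already strictly less than the modulus, %r is replaced by %x.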
template <typename RemOp>
struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
  DeleteTrivialRem(MLIRContext *context, DataFlowSolver &s)
      : OpRewritePattern<RemOp>(context), solver(s) {}

  LogicalResult matchAndRewrite(RemOp op,
                                PatternRewriter &rewriter) const override {
    Value lhs = op.getOperand(0);
    Value rhs = op.getOperand(1);
    auto maybeModulus = getConstantIntValue(rhs);
    if (!maybeModulus.has_value())
      return failure();
    int64_t modulus = *maybeModulus;
    if (modulus <= 0)
      return failure();
    auto *maybeLhsRange = solver.lookupState<IntegerValueRangeLattice>(lhs);
    if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized())
      return failure();
    const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue();
    const APInt &min = isa<RemUIOp>(op) ? lhsRange.umin() : lhsRange.smin();
    const APInt &max = isa<RemUIOp>(op) ? lhsRange.umax() : lhsRange.smax();
    // The minima and maxima here are given as closed ranges, so they must be
    // strictly less than the modulus.
    if (min.isNegative() || min.uge(modulus))
      return failure();
    if (max.isNegative() || max.uge(modulus))
      return failure();
    if (!min.ule(max))
      return failure();

    // With all those conditions out of the way, we know that this invocation
    // of a remainder is a no-op because the input is strictly within the
    // range [0, modulus), so get rid of it.
    rewriter.replaceOp(op, ValueRange{lhs});
    return success();
  }

private:
  DataFlowSolver &solver;
};

/// Gather ranges for all the values in `values`. Appends to the existing
/// vector.
static LogicalResult collectRanges(DataFlowSolver &solver, ValueRange values,
                                   SmallVectorImpl<ConstantIntRanges> &ranges) {
  for (Value val : values) {
    auto *maybeInferredRange =
        solver.lookupState<IntegerValueRangeLattice>(val);
    if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
      return failure();

    const ConstantIntRanges &inferredRange =
        maybeInferredRange->getValue().getValue();
    ranges.push_back(inferredRange);
  }
  return success();
}

/// Return int type truncated to `targetBitwidth`. If `srcType` is shaped,
/// return shaped type as well.
static Type getTargetType(Type srcType, unsigned targetBitwidth) {
  auto dstType = IntegerType::get(srcType.getContext(), targetBitwidth);
  if (auto shaped = dyn_cast<ShapedType>(srcType))
    return shaped.clone(dstType);

  assert(srcType.isIntOrIndex() && "Invalid src type");
  return dstType;
}

namespace {
// Enum for tracking which type of truncation should be performed
// to narrow an operation, if any.
enum class CastKind : uint8_t { None, Signed, Unsigned, Both };
} // namespace

/// If the values within `range` can be represented using only `targetWidth`
/// bits, return the kind of truncation needed to preserve that property.
///
/// This check relies on the fact that the signed and unsigned ranges are both
/// always correct, but that one might be an approximation of the other,
/// so we want to use the correct truncation operation.
static CastKind checkTruncatability(const ConstantIntRanges &range,
                                    unsigned targetWidth) {
  unsigned srcWidth = range.smin().getBitWidth();
  if (srcWidth <= targetWidth)
    return CastKind::None;
  unsigned removedWidth = srcWidth - targetWidth;
  // The sign bits need to extend into the sign bit of the target width. For
  // example, if we're truncating 64 bits to 32, we need 64 - 32 + 1 = 33 sign
  // bits.
  bool canTruncateSigned =
      range.smin().getNumSignBits() >= (removedWidth + 1) &&
      range.smax().getNumSignBits() >= (removedWidth + 1);
  bool canTruncateUnsigned = range.umin().countLeadingZeros() >= removedWidth &&
                             range.umax().countLeadingZeros() >= removedWidth;
  if (canTruncateSigned && canTruncateUnsigned)
    return CastKind::Both;
  if (canTruncateSigned)
    return CastKind::Signed;
  if (canTruncateUnsigned)
    return CastKind::Unsigned;
  return CastKind::None;
}
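
// Worked example for the check above (numbers are illustrative, not from a
// test case): truncating from i64 to i32 removes 32 bits, so both signed
// bounds need at least 33 sign bits. For a signed range of [-5, 100],
// smin = -5 has 61 sign bits and smax = 100 has 57, so a signed truncation
// is safe; if the unsigned bounds are unconstrained (umax has no leading
// zeros), an unsigned truncation is not, and the result is CastKind::Signed.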
static CastKind mergeCastKinds(CastKind lhs, CastKind rhs) {
  if (lhs == CastKind::None || rhs == CastKind::None)
    return CastKind::None;
  if (lhs == CastKind::Both)
    return rhs;
  if (rhs == CastKind::Both)
    return lhs;
  if (lhs == rhs)
    return lhs;
  return CastKind::None;
}

static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType,
                    CastKind castKind) {
  Type srcType = src.getType();
  assert(isa<VectorType>(srcType) == isa<VectorType>(dstType) &&
         "Mixing vector and non-vector types");
  assert(castKind != CastKind::None && "Can't cast when casting isn't allowed");
  Type srcElemType = getElementTypeOrSelf(srcType);
  Type dstElemType = getElementTypeOrSelf(dstType);
  assert(srcElemType.isIntOrIndex() && "Invalid src type");
  assert(dstElemType.isIntOrIndex() && "Invalid dst type");
  if (srcType == dstType)
    return src;

  if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType)) {
    if (castKind == CastKind::Signed)
      return builder.create<arith::IndexCastOp>(loc, dstType, src);
    return builder.create<arith::IndexCastUIOp>(loc, dstType, src);
  }

  auto srcInt = cast<IntegerType>(srcElemType);
  auto dstInt = cast<IntegerType>(dstElemType);
  if (dstInt.getWidth() < srcInt.getWidth())
    return builder.create<arith::TruncIOp>(loc, dstType, src);

  if (castKind == CastKind::Signed)
    return builder.create<arith::ExtSIOp>(loc, dstType, src);
  return builder.create<arith::ExtUIOp>(loc, dstType, src);
}
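
// Illustrative rewrite performed by the elementwise narrowing below
// (hypothetical IR; assumes 32 is among the target bitwidths and that the
// analysis has bounded all three values to [0, 100]):
//   %r = arith.addi %a, %b : i64
// =>
//   %a32 = arith.trunci %a : i64 to i32
//   %b32 = arith.trunci %b : i64 to i32
//   %r32 = arith.addi %a32, %b32 : i32
//   %r   = arith.extui %r32 : i32 to i64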
struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
  NarrowElementwise(MLIRContext *context, DataFlowSolver &s,
                    ArrayRef<unsigned> target)
      : OpTraitRewritePattern(context), solver(s), targetBitwidths(target) {}

  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() == 0)
      return rewriter.notifyMatchFailure(op, "can't narrow resultless op");

    SmallVector<ConstantIntRanges> ranges;
    if (failed(collectRanges(solver, op->getOperands(), ranges)))
      return rewriter.notifyMatchFailure(op, "input without specified range");
    if (failed(collectRanges(solver, op->getResults(), ranges)))
      return rewriter.notifyMatchFailure(op, "output without specified range");

    Type srcType = op->getResult(0).getType();
    if (!llvm::all_equal(op->getResultTypes()))
      return rewriter.notifyMatchFailure(op, "mismatched result types");
    if (op->getNumOperands() == 0 ||
        !llvm::all_of(op->getOperandTypes(),
                      [=](Type t) { return t == srcType; }))
      return rewriter.notifyMatchFailure(
          op, "no operands or operand types don't match result type");

    for (unsigned targetBitwidth : targetBitwidths) {
      CastKind castKind = CastKind::Both;
      for (const ConstantIntRanges &range : ranges) {
        castKind = mergeCastKinds(castKind,
                                  checkTruncatability(range, targetBitwidth));
        if (castKind == CastKind::None)
          break;
      }
      if (castKind == CastKind::None)
        continue;
      Type targetType = getTargetType(srcType, targetBitwidth);
      if (targetType == srcType)
        continue;

      Location loc = op->getLoc();
      IRMapping mapping;
      for (auto [arg, argRange] : llvm::zip_first(op->getOperands(), ranges)) {
        CastKind argCastKind = castKind;
        // When dealing with `index` values, use unsigned casts for arguments
        // known to be non-negative, since that non-negativity can't be
        // recovered from an otherwise-equivalent signed index_cast.
        if (argCastKind == CastKind::Signed && argRange.smin().isNonNegative())
          argCastKind = CastKind::Both;
        Value newArg = doCast(rewriter, loc, arg, targetType, argCastKind);
        mapping.map(arg, newArg);
      }

      Operation *newOp = rewriter.clone(*op, mapping);
      rewriter.modifyOpInPlace(newOp, [&]() {
        for (OpResult res : newOp->getResults()) {
          res.setType(targetType);
        }
      });
      SmallVector<Value> newResults;
      for (auto [newRes, oldRes] :
           llvm::zip_equal(newOp->getResults(), op->getResults())) {
        Value castBack = doCast(rewriter, loc, newRes, srcType, castKind);
        copyIntegerRange(solver, oldRes, castBack);
        newResults.push_back(castBack);
      }

      rewriter.replaceOp(op, newResults);
      return success();
    }
    return failure();
  }

private:
  DataFlowSolver &solver;
  SmallVector<unsigned, 4> targetBitwidths;
};
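
// Sketch of the cmpi narrowing below (hypothetical IR; assumes 8 is among
// the target bitwidths and both operands are known to fit in i8 under either
// interpretation):
//   %c = arith.cmpi slt, %a, %b : i32
// =>
//   %a8 = arith.trunci %a : i32 to i8
//   %b8 = arith.trunci %b : i32 to i8
//   %c  = arith.cmpi slt, %a8, %b8 : i8
// The i1 result is unchanged, so no cast back is needed.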
struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
  NarrowCmpI(MLIRContext *context, DataFlowSolver &s, ArrayRef<unsigned> target)
      : OpRewritePattern(context), solver(s), targetBitwidths(target) {}

  LogicalResult matchAndRewrite(arith::CmpIOp op,
                                PatternRewriter &rewriter) const override {
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();

    SmallVector<ConstantIntRanges> ranges;
    if (failed(collectRanges(solver, op.getOperands(), ranges)))
      return failure();
    const ConstantIntRanges &lhsRange = ranges[0];
    const ConstantIntRanges &rhsRange = ranges[1];

    Type srcType = lhs.getType();
    for (unsigned targetBitwidth : targetBitwidths) {
      CastKind lhsCastKind = checkTruncatability(lhsRange, targetBitwidth);
      CastKind rhsCastKind = checkTruncatability(rhsRange, targetBitwidth);
      CastKind castKind = mergeCastKinds(lhsCastKind, rhsCastKind);
      // Note: this includes target width > src width.
      if (castKind == CastKind::None)
        continue;

      Type targetType = getTargetType(srcType, targetBitwidth);
      if (targetType == srcType)
        continue;

      Location loc = op->getLoc();
      IRMapping mapping;
      Value lhsCast = doCast(rewriter, loc, lhs, targetType, lhsCastKind);
      Value rhsCast = doCast(rewriter, loc, rhs, targetType, rhsCastKind);
      mapping.map(lhs, lhsCast);
      mapping.map(rhs, rhsCast);

      Operation *newOp = rewriter.clone(*op, mapping);
      copyIntegerRange(solver, op.getResult(), newOp->getResult(0));
      rewriter.replaceOp(op, newOp->getResults());
      return success();
    }
    return failure();
  }

private:
  DataFlowSolver &solver;
  SmallVector<unsigned, 4> targetBitwidths;
};

/// Fold index_cast(index_cast(%arg: i8, index), i8) -> %arg
/// This pattern assumes that none of the passed `targetBitwidths` is wider
/// than the `index` type.
template <typename CastOp>
struct FoldIndexCastChain final : OpRewritePattern<CastOp> {
  FoldIndexCastChain(MLIRContext *context, ArrayRef<unsigned> target)
      : OpRewritePattern<CastOp>(context), targetBitwidths(target) {}

  LogicalResult matchAndRewrite(CastOp op,
                                PatternRewriter &rewriter) const override {
    auto srcOp = op.getIn().template getDefiningOp<CastOp>();
    if (!srcOp)
      return rewriter.notifyMatchFailure(op, "doesn't come from an index cast");

    Value src = srcOp.getIn();
    if (src.getType() != op.getType())
      return rewriter.notifyMatchFailure(op, "outer types don't match");

    if (!srcOp.getType().isIndex())
      return rewriter.notifyMatchFailure(op, "intermediate type isn't index");

    auto intType = dyn_cast<IntegerType>(op.getType());
    if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth()))
      return failure();

    rewriter.replaceOp(op, src);
    return success();
  }

private:
  SmallVector<unsigned, 4> targetBitwidths;
};
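
// Concrete instance of the fold above (hypothetical IR, with 8 among the
// target bitwidths):
//   %i = arith.index_castui %arg : i8 to index
//   %r = arith.index_castui %i : index to i8
// The chain collapses to just %arg.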
struct IntRangeOptimizationsPass final
    : arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {

  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *ctx = op->getContext();
    DataFlowSolver solver;
    solver.load<DeadCodeAnalysis>();
    solver.load<IntegerRangeAnalysis>();
    if (failed(solver.initializeAndRun(op)))
      return signalPassFailure();

    DataFlowListener listener(solver);

    RewritePatternSet patterns(ctx);
    populateIntRangeOptimizationsPatterns(patterns, solver);

    GreedyRewriteConfig config;
    config.listener = &listener;

    if (failed(applyPatternsGreedily(op, std::move(patterns), config)))
      signalPassFailure();
  }
};

struct IntRangeNarrowingPass final
    : arith::impl::ArithIntRangeNarrowingBase<IntRangeNarrowingPass> {
  using ArithIntRangeNarrowingBase::ArithIntRangeNarrowingBase;

  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *ctx = op->getContext();
    DataFlowSolver solver;
    solver.load<DeadCodeAnalysis>();
    solver.load<IntegerRangeAnalysis>();
    if (failed(solver.initializeAndRun(op)))
      return signalPassFailure();

    DataFlowListener listener(solver);

    RewritePatternSet patterns(ctx);
    populateIntRangeNarrowingPatterns(patterns, solver, bitwidthsSupported);

    GreedyRewriteConfig config;
    // We specifically need bottom-up traversal, as the cmpi pattern needs
    // range data attached to its original argument values.
    config.useTopDownTraversal = false;
    config.listener = &listener;

    if (failed(applyPatternsGreedily(op, std::move(patterns), config)))
      signalPassFailure();
  }
};
} // namespace

void mlir::arith::populateIntRangeOptimizationsPatterns(
    RewritePatternSet &patterns, DataFlowSolver &solver) {
  patterns.add<MaterializeKnownConstantValues, DeleteTrivialRem<RemSIOp>,
               DeleteTrivialRem<RemUIOp>>(patterns.getContext(), solver);
}

void mlir::arith::populateIntRangeNarrowingPatterns(
    RewritePatternSet &patterns, DataFlowSolver &solver,
    ArrayRef<unsigned> bitwidthsSupported) {
  patterns.add<NarrowElementwise, NarrowCmpI>(patterns.getContext(), solver,
                                              bitwidthsSupported);
  patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>,
               FoldIndexCastChain<arith::IndexCastOp>>(patterns.getContext(),
                                                       bitwidthsSupported);
}

std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
  return std::make_unique<IntRangeOptimizationsPass>();
}