11a867bf1SIvan Butygin //===- IntRangeOptimizations.cpp - Optimizations based on integer ranges --===// 21a867bf1SIvan Butygin // 31a867bf1SIvan Butygin // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 41a867bf1SIvan Butygin // See https://llvm.org/LICENSE.txt for license information. 51a867bf1SIvan Butygin // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 61a867bf1SIvan Butygin // 71a867bf1SIvan Butygin //===----------------------------------------------------------------------===// 81a867bf1SIvan Butygin 951911a62SMehdi Amini #include <utility> 1051911a62SMehdi Amini 1147229111SKrzysztof Drewniak #include "mlir/Analysis/DataFlowFramework.h" 121a867bf1SIvan Butygin #include "mlir/Dialect/Arith/Transforms/Passes.h" 131a867bf1SIvan Butygin 141a867bf1SIvan Butygin #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" 151a867bf1SIvan Butygin #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" 161a867bf1SIvan Butygin #include "mlir/Dialect/Arith/IR/Arith.h" 1747229111SKrzysztof Drewniak #include "mlir/Dialect/Utils/StaticValueUtils.h" 189f0f6df0SIvan Butygin #include "mlir/IR/IRMapping.h" 1947229111SKrzysztof Drewniak #include "mlir/IR/Matchers.h" 2047229111SKrzysztof Drewniak #include "mlir/IR/PatternMatch.h" 219f0f6df0SIvan Butygin #include "mlir/IR/TypeUtilities.h" 2247229111SKrzysztof Drewniak #include "mlir/Interfaces/SideEffectInterfaces.h" 2347229111SKrzysztof Drewniak #include "mlir/Transforms/FoldUtils.h" 241a867bf1SIvan Butygin #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 251a867bf1SIvan Butygin 261a867bf1SIvan Butygin namespace mlir::arith { 271a867bf1SIvan Butygin #define GEN_PASS_DEF_ARITHINTRANGEOPTS 281a867bf1SIvan Butygin #include "mlir/Dialect/Arith/Transforms/Passes.h.inc" 299f0f6df0SIvan Butygin 309f0f6df0SIvan Butygin #define GEN_PASS_DEF_ARITHINTRANGENARROWING 319f0f6df0SIvan Butygin #include "mlir/Dialect/Arith/Transforms/Passes.h.inc" 321a867bf1SIvan Butygin } // namespace mlir::arith 331a867bf1SIvan Butygin 341a867bf1SIvan Butygin using namespace mlir; 351a867bf1SIvan Butygin using namespace mlir::arith; 361a867bf1SIvan Butygin using namespace mlir::dataflow; 371a867bf1SIvan Butygin 3847229111SKrzysztof Drewniak static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver, 3947229111SKrzysztof Drewniak Value value) { 4047229111SKrzysztof Drewniak auto *maybeInferredRange = 4147229111SKrzysztof Drewniak solver.lookupState<IntegerValueRangeLattice>(value); 4247229111SKrzysztof Drewniak if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized()) 4347229111SKrzysztof Drewniak return std::nullopt; 4447229111SKrzysztof Drewniak const ConstantIntRanges &inferredRange = 4547229111SKrzysztof Drewniak maybeInferredRange->getValue().getValue(); 4647229111SKrzysztof Drewniak return inferredRange.getConstantValue(); 471a867bf1SIvan Butygin } 481a867bf1SIvan Butygin 499bf79308SKrzysztof Drewniak static void copyIntegerRange(DataFlowSolver &solver, Value oldVal, 509bf79308SKrzysztof Drewniak Value newVal) { 519bf79308SKrzysztof Drewniak assert(oldVal.getType() == newVal.getType() && 529bf79308SKrzysztof Drewniak "Can't copy integer ranges between different types"); 539bf79308SKrzysztof Drewniak auto *oldState = solver.lookupState<IntegerValueRangeLattice>(oldVal); 549bf79308SKrzysztof Drewniak if (!oldState) 559bf79308SKrzysztof Drewniak return; 569bf79308SKrzysztof Drewniak (void)solver.getOrCreateState<IntegerValueRangeLattice>(newVal)->join( 579bf79308SKrzysztof Drewniak *oldState); 589bf79308SKrzysztof Drewniak } 599bf79308SKrzysztof Drewniak 6047229111SKrzysztof Drewniak /// Patterned after SCCP 6147229111SKrzysztof Drewniak static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver, 6247229111SKrzysztof Drewniak PatternRewriter &rewriter, 6347229111SKrzysztof Drewniak Value value) { 6447229111SKrzysztof Drewniak if (value.use_empty()) 651a867bf1SIvan Butygin return failure(); 6647229111SKrzysztof Drewniak std::optional<APInt> maybeConstValue = getMaybeConstantValue(solver, value); 6747229111SKrzysztof Drewniak if (!maybeConstValue.has_value()) 681a867bf1SIvan Butygin return failure(); 691a867bf1SIvan Butygin 70f54cdc5dSIvan Butygin Type type = value.getType(); 71f54cdc5dSIvan Butygin Location loc = value.getLoc(); 7247229111SKrzysztof Drewniak Operation *maybeDefiningOp = value.getDefiningOp(); 7347229111SKrzysztof Drewniak Dialect *valueDialect = 7447229111SKrzysztof Drewniak maybeDefiningOp ? maybeDefiningOp->getDialect() 7547229111SKrzysztof Drewniak : value.getParentRegion()->getParentOp()->getDialect(); 76f54cdc5dSIvan Butygin 77f54cdc5dSIvan Butygin Attribute constAttr; 78f54cdc5dSIvan Butygin if (auto shaped = dyn_cast<ShapedType>(type)) { 79f54cdc5dSIvan Butygin constAttr = mlir::DenseIntElementsAttr::get(shaped, *maybeConstValue); 80f54cdc5dSIvan Butygin } else { 81f54cdc5dSIvan Butygin constAttr = rewriter.getIntegerAttr(type, *maybeConstValue); 82f54cdc5dSIvan Butygin } 83f54cdc5dSIvan Butygin Operation *constOp = 84f54cdc5dSIvan Butygin valueDialect->materializeConstant(rewriter, constAttr, type, loc); 8547229111SKrzysztof Drewniak // Fall back to arith.constant if the dialect materializer doesn't know what 8647229111SKrzysztof Drewniak // to do with an integer constant. 8747229111SKrzysztof Drewniak if (!constOp) 8847229111SKrzysztof Drewniak constOp = rewriter.getContext() 8947229111SKrzysztof Drewniak ->getLoadedDialect<ArithDialect>() 90f54cdc5dSIvan Butygin ->materializeConstant(rewriter, constAttr, type, loc); 9147229111SKrzysztof Drewniak if (!constOp) 921a867bf1SIvan Butygin return failure(); 931a867bf1SIvan Butygin 949bf79308SKrzysztof Drewniak copyIntegerRange(solver, value, constOp->getResult(0)); 9547229111SKrzysztof Drewniak rewriter.replaceAllUsesWith(value, constOp->getResult(0)); 9647229111SKrzysztof Drewniak return success(); 971a867bf1SIvan Butygin } 981a867bf1SIvan Butygin 991a867bf1SIvan Butygin namespace { 10078b3a004SFelix Schneider class DataFlowListener : public RewriterBase::Listener { 10178b3a004SFelix Schneider public: 10278b3a004SFelix Schneider DataFlowListener(DataFlowSolver &s) : s(s) {} 10378b3a004SFelix Schneider 10478b3a004SFelix Schneider protected: 10578b3a004SFelix Schneider void notifyOperationErased(Operation *op) override { 1064b3f251bSdonald chen s.eraseState(s.getProgramPointAfter(op)); 10778b3a004SFelix Schneider for (Value res : op->getResults()) 10878b3a004SFelix Schneider s.eraseState(res); 10978b3a004SFelix Schneider } 11078b3a004SFelix Schneider 11178b3a004SFelix Schneider DataFlowSolver &s; 11278b3a004SFelix Schneider }; 11378b3a004SFelix Schneider 11447229111SKrzysztof Drewniak /// Rewrite any results of `op` that were inferred to be constant integers to 11547229111SKrzysztof Drewniak /// and replace their uses with that constant. Return success() if all results 11647229111SKrzysztof Drewniak /// where thus replaced and the operation is erased. Also replace any block 11747229111SKrzysztof Drewniak /// arguments with their constant values. 11847229111SKrzysztof Drewniak struct MaterializeKnownConstantValues : public RewritePattern { 11947229111SKrzysztof Drewniak MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s) 12047229111SKrzysztof Drewniak : RewritePattern(Pattern::MatchAnyOpTypeTag(), /*benefit=*/1, context), 12147229111SKrzysztof Drewniak solver(s) {} 1221a867bf1SIvan Butygin 12347229111SKrzysztof Drewniak LogicalResult match(Operation *op) const override { 12447229111SKrzysztof Drewniak if (matchPattern(op, m_Constant())) 12547229111SKrzysztof Drewniak return failure(); 1261a867bf1SIvan Butygin 12747229111SKrzysztof Drewniak auto needsReplacing = [&](Value v) { 12847229111SKrzysztof Drewniak return getMaybeConstantValue(solver, v).has_value() && !v.use_empty(); 12947229111SKrzysztof Drewniak }; 13047229111SKrzysztof Drewniak bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing); 13147229111SKrzysztof Drewniak if (op->getNumRegions() == 0) 13247229111SKrzysztof Drewniak return success(hasConstantResults); 13347229111SKrzysztof Drewniak bool hasConstantRegionArgs = false; 13447229111SKrzysztof Drewniak for (Region ®ion : op->getRegions()) { 13547229111SKrzysztof Drewniak for (Block &block : region.getBlocks()) { 13647229111SKrzysztof Drewniak hasConstantRegionArgs |= 13747229111SKrzysztof Drewniak llvm::any_of(block.getArguments(), needsReplacing); 13847229111SKrzysztof Drewniak } 13947229111SKrzysztof Drewniak } 14047229111SKrzysztof Drewniak return success(hasConstantResults || hasConstantRegionArgs); 14147229111SKrzysztof Drewniak } 14247229111SKrzysztof Drewniak 14347229111SKrzysztof Drewniak void rewrite(Operation *op, PatternRewriter &rewriter) const override { 14447229111SKrzysztof Drewniak bool replacedAll = (op->getNumResults() != 0); 14547229111SKrzysztof Drewniak for (Value v : op->getResults()) 14647229111SKrzysztof Drewniak replacedAll &= 14747229111SKrzysztof Drewniak (succeeded(maybeReplaceWithConstant(solver, rewriter, v)) || 14847229111SKrzysztof Drewniak v.use_empty()); 14947229111SKrzysztof Drewniak if (replacedAll && isOpTriviallyDead(op)) { 15047229111SKrzysztof Drewniak rewriter.eraseOp(op); 15147229111SKrzysztof Drewniak return; 15247229111SKrzysztof Drewniak } 15347229111SKrzysztof Drewniak 15447229111SKrzysztof Drewniak PatternRewriter::InsertionGuard guard(rewriter); 15547229111SKrzysztof Drewniak for (Region ®ion : op->getRegions()) { 15647229111SKrzysztof Drewniak for (Block &block : region.getBlocks()) { 15747229111SKrzysztof Drewniak rewriter.setInsertionPointToStart(&block); 15847229111SKrzysztof Drewniak for (BlockArgument &arg : block.getArguments()) { 15947229111SKrzysztof Drewniak (void)maybeReplaceWithConstant(solver, rewriter, arg); 16047229111SKrzysztof Drewniak } 16147229111SKrzysztof Drewniak } 16247229111SKrzysztof Drewniak } 16347229111SKrzysztof Drewniak } 16447229111SKrzysztof Drewniak 16547229111SKrzysztof Drewniak private: 16647229111SKrzysztof Drewniak DataFlowSolver &solver; 16747229111SKrzysztof Drewniak }; 16847229111SKrzysztof Drewniak 16947229111SKrzysztof Drewniak template <typename RemOp> 17047229111SKrzysztof Drewniak struct DeleteTrivialRem : public OpRewritePattern<RemOp> { 17147229111SKrzysztof Drewniak DeleteTrivialRem(MLIRContext *context, DataFlowSolver &s) 17247229111SKrzysztof Drewniak : OpRewritePattern<RemOp>(context), solver(s) {} 17347229111SKrzysztof Drewniak 17447229111SKrzysztof Drewniak LogicalResult matchAndRewrite(RemOp op, 1751a867bf1SIvan Butygin PatternRewriter &rewriter) const override { 17647229111SKrzysztof Drewniak Value lhs = op.getOperand(0); 17747229111SKrzysztof Drewniak Value rhs = op.getOperand(1); 17847229111SKrzysztof Drewniak auto maybeModulus = getConstantIntValue(rhs); 17947229111SKrzysztof Drewniak if (!maybeModulus.has_value()) 18047229111SKrzysztof Drewniak return failure(); 18147229111SKrzysztof Drewniak int64_t modulus = *maybeModulus; 18247229111SKrzysztof Drewniak if (modulus <= 0) 18347229111SKrzysztof Drewniak return failure(); 18447229111SKrzysztof Drewniak auto *maybeLhsRange = solver.lookupState<IntegerValueRangeLattice>(lhs); 18547229111SKrzysztof Drewniak if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized()) 18647229111SKrzysztof Drewniak return failure(); 18747229111SKrzysztof Drewniak const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue(); 18847229111SKrzysztof Drewniak const APInt &min = isa<RemUIOp>(op) ? lhsRange.umin() : lhsRange.smin(); 18947229111SKrzysztof Drewniak const APInt &max = isa<RemUIOp>(op) ? lhsRange.umax() : lhsRange.smax(); 19047229111SKrzysztof Drewniak // The minima and maxima here are given as closed ranges, we must be 19147229111SKrzysztof Drewniak // strictly less than the modulus. 19247229111SKrzysztof Drewniak if (min.isNegative() || min.uge(modulus)) 19347229111SKrzysztof Drewniak return failure(); 19447229111SKrzysztof Drewniak if (max.isNegative() || max.uge(modulus)) 19547229111SKrzysztof Drewniak return failure(); 19647229111SKrzysztof Drewniak if (!min.ule(max)) 1971a867bf1SIvan Butygin return failure(); 1981a867bf1SIvan Butygin 19947229111SKrzysztof Drewniak // With all those conditions out of the way, we know thas this invocation of 20047229111SKrzysztof Drewniak // a remainder is a noop because the input is strictly within the range 20147229111SKrzysztof Drewniak // [0, modulus), so get rid of it. 20247229111SKrzysztof Drewniak rewriter.replaceOp(op, ValueRange{lhs}); 2031a867bf1SIvan Butygin return success(); 2041a867bf1SIvan Butygin } 2051a867bf1SIvan Butygin 2061a867bf1SIvan Butygin private: 2071a867bf1SIvan Butygin DataFlowSolver &solver; 2081a867bf1SIvan Butygin }; 2091a867bf1SIvan Butygin 2109bf79308SKrzysztof Drewniak /// Gather ranges for all the values in `values`. Appends to the existing 2119bf79308SKrzysztof Drewniak /// vector. 2129bf79308SKrzysztof Drewniak static LogicalResult collectRanges(DataFlowSolver &solver, ValueRange values, 2139bf79308SKrzysztof Drewniak SmallVectorImpl<ConstantIntRanges> &ranges) { 2149bf79308SKrzysztof Drewniak for (Value val : values) { 2159f0f6df0SIvan Butygin auto *maybeInferredRange = 2169bf79308SKrzysztof Drewniak solver.lookupState<IntegerValueRangeLattice>(val); 2179f0f6df0SIvan Butygin if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized()) 2189bf79308SKrzysztof Drewniak return failure(); 2199f0f6df0SIvan Butygin 2209f0f6df0SIvan Butygin const ConstantIntRanges &inferredRange = 2219f0f6df0SIvan Butygin maybeInferredRange->getValue().getValue(); 2229bf79308SKrzysztof Drewniak ranges.push_back(inferredRange); 2239f0f6df0SIvan Butygin } 2249bf79308SKrzysztof Drewniak return success(); 2259f0f6df0SIvan Butygin } 2269f0f6df0SIvan Butygin 2279f0f6df0SIvan Butygin /// Return int type truncated to `targetBitwidth`. If `srcType` is shaped, 2289f0f6df0SIvan Butygin /// return shaped type as well. 2299f0f6df0SIvan Butygin static Type getTargetType(Type srcType, unsigned targetBitwidth) { 2309f0f6df0SIvan Butygin auto dstType = IntegerType::get(srcType.getContext(), targetBitwidth); 2319f0f6df0SIvan Butygin if (auto shaped = dyn_cast<ShapedType>(srcType)) 2329f0f6df0SIvan Butygin return shaped.clone(dstType); 2339f0f6df0SIvan Butygin 2349f0f6df0SIvan Butygin assert(srcType.isIntOrIndex() && "Invalid src type"); 2359f0f6df0SIvan Butygin return dstType; 2369f0f6df0SIvan Butygin } 2379f0f6df0SIvan Butygin 2389bf79308SKrzysztof Drewniak namespace { 2399bf79308SKrzysztof Drewniak // Enum for tracking which type of truncation should be performed 2409bf79308SKrzysztof Drewniak // to narrow an operation, if any. 2419bf79308SKrzysztof Drewniak enum class CastKind : uint8_t { None, Signed, Unsigned, Both }; 2429bf79308SKrzysztof Drewniak } // namespace 2439bf79308SKrzysztof Drewniak 2449bf79308SKrzysztof Drewniak /// If the values within `range` can be represented using only `width` bits, 2459bf79308SKrzysztof Drewniak /// return the kind of truncation needed to preserve that property. 2469bf79308SKrzysztof Drewniak /// 2479bf79308SKrzysztof Drewniak /// This check relies on the fact that the signed and unsigned ranges are both 2489bf79308SKrzysztof Drewniak /// always correct, but that one might be an approximation of the other, 2499bf79308SKrzysztof Drewniak /// so we want to use the correct truncation operation. 2509bf79308SKrzysztof Drewniak static CastKind checkTruncatability(const ConstantIntRanges &range, 2519bf79308SKrzysztof Drewniak unsigned targetWidth) { 2529bf79308SKrzysztof Drewniak unsigned srcWidth = range.smin().getBitWidth(); 2539bf79308SKrzysztof Drewniak if (srcWidth <= targetWidth) 2549bf79308SKrzysztof Drewniak return CastKind::None; 2559bf79308SKrzysztof Drewniak unsigned removedWidth = srcWidth - targetWidth; 2569bf79308SKrzysztof Drewniak // The sign bits need to extend into the sign bit of the target width. For 2579bf79308SKrzysztof Drewniak // example, if we're truncating 64 bits to 32, we need 64 - 32 + 1 = 33 sign 2589bf79308SKrzysztof Drewniak // bits. 2599bf79308SKrzysztof Drewniak bool canTruncateSigned = 2609bf79308SKrzysztof Drewniak range.smin().getNumSignBits() >= (removedWidth + 1) && 2619bf79308SKrzysztof Drewniak range.smax().getNumSignBits() >= (removedWidth + 1); 2629bf79308SKrzysztof Drewniak bool canTruncateUnsigned = range.umin().countLeadingZeros() >= removedWidth && 2639bf79308SKrzysztof Drewniak range.umax().countLeadingZeros() >= removedWidth; 2649bf79308SKrzysztof Drewniak if (canTruncateSigned && canTruncateUnsigned) 2659bf79308SKrzysztof Drewniak return CastKind::Both; 2669bf79308SKrzysztof Drewniak if (canTruncateSigned) 2679bf79308SKrzysztof Drewniak return CastKind::Signed; 2689bf79308SKrzysztof Drewniak if (canTruncateUnsigned) 2699bf79308SKrzysztof Drewniak return CastKind::Unsigned; 2709bf79308SKrzysztof Drewniak return CastKind::None; 2719f0f6df0SIvan Butygin } 2729f0f6df0SIvan Butygin 2739bf79308SKrzysztof Drewniak static CastKind mergeCastKinds(CastKind lhs, CastKind rhs) { 2749bf79308SKrzysztof Drewniak if (lhs == CastKind::None || rhs == CastKind::None) 2759bf79308SKrzysztof Drewniak return CastKind::None; 2769bf79308SKrzysztof Drewniak if (lhs == CastKind::Both) 2779bf79308SKrzysztof Drewniak return rhs; 2789bf79308SKrzysztof Drewniak if (rhs == CastKind::Both) 2799bf79308SKrzysztof Drewniak return lhs; 2809bf79308SKrzysztof Drewniak if (lhs == rhs) 2819bf79308SKrzysztof Drewniak return lhs; 2829bf79308SKrzysztof Drewniak return CastKind::None; 2839bf79308SKrzysztof Drewniak } 2849bf79308SKrzysztof Drewniak 2859bf79308SKrzysztof Drewniak static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType, 2869bf79308SKrzysztof Drewniak CastKind castKind) { 2879f0f6df0SIvan Butygin Type srcType = src.getType(); 2889f0f6df0SIvan Butygin assert(isa<VectorType>(srcType) == isa<VectorType>(dstType) && 2899f0f6df0SIvan Butygin "Mixing vector and non-vector types"); 2909bf79308SKrzysztof Drewniak assert(castKind != CastKind::None && "Can't cast when casting isn't allowed"); 2919f0f6df0SIvan Butygin Type srcElemType = getElementTypeOrSelf(srcType); 2929f0f6df0SIvan Butygin Type dstElemType = getElementTypeOrSelf(dstType); 2939f0f6df0SIvan Butygin assert(srcElemType.isIntOrIndex() && "Invalid src type"); 2949f0f6df0SIvan Butygin assert(dstElemType.isIntOrIndex() && "Invalid dst type"); 2959f0f6df0SIvan Butygin if (srcType == dstType) 2969f0f6df0SIvan Butygin return src; 2979f0f6df0SIvan Butygin 2989bf79308SKrzysztof Drewniak if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType)) { 2999bf79308SKrzysztof Drewniak if (castKind == CastKind::Signed) 3009bf79308SKrzysztof Drewniak return builder.create<arith::IndexCastOp>(loc, dstType, src); 3019f0f6df0SIvan Butygin return builder.create<arith::IndexCastUIOp>(loc, dstType, src); 3029bf79308SKrzysztof Drewniak } 3039f0f6df0SIvan Butygin 3049f0f6df0SIvan Butygin auto srcInt = cast<IntegerType>(srcElemType); 3059f0f6df0SIvan Butygin auto dstInt = cast<IntegerType>(dstElemType); 3069f0f6df0SIvan Butygin if (dstInt.getWidth() < srcInt.getWidth()) 3079f0f6df0SIvan Butygin return builder.create<arith::TruncIOp>(loc, dstType, src); 3089f0f6df0SIvan Butygin 3099bf79308SKrzysztof Drewniak if (castKind == CastKind::Signed) 3109bf79308SKrzysztof Drewniak return builder.create<arith::ExtSIOp>(loc, dstType, src); 3119f0f6df0SIvan Butygin return builder.create<arith::ExtUIOp>(loc, dstType, src); 3129f0f6df0SIvan Butygin } 3139f0f6df0SIvan Butygin 3149f0f6df0SIvan Butygin struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> { 3159f0f6df0SIvan Butygin NarrowElementwise(MLIRContext *context, DataFlowSolver &s, 3169f0f6df0SIvan Butygin ArrayRef<unsigned> target) 3179f0f6df0SIvan Butygin : OpTraitRewritePattern(context), solver(s), targetBitwidths(target) {} 3189f0f6df0SIvan Butygin 3199f0f6df0SIvan Butygin using OpTraitRewritePattern::OpTraitRewritePattern; 3209f0f6df0SIvan Butygin LogicalResult matchAndRewrite(Operation *op, 3219f0f6df0SIvan Butygin PatternRewriter &rewriter) const override { 3229bf79308SKrzysztof Drewniak if (op->getNumResults() == 0) 3239bf79308SKrzysztof Drewniak return rewriter.notifyMatchFailure(op, "can't narrow resultless op"); 3249f0f6df0SIvan Butygin 3259bf79308SKrzysztof Drewniak SmallVector<ConstantIntRanges> ranges; 3269bf79308SKrzysztof Drewniak if (failed(collectRanges(solver, op->getOperands(), ranges))) 3279bf79308SKrzysztof Drewniak return rewriter.notifyMatchFailure(op, "input without specified range"); 3289bf79308SKrzysztof Drewniak if (failed(collectRanges(solver, op->getResults(), ranges))) 3299bf79308SKrzysztof Drewniak return rewriter.notifyMatchFailure(op, "output without specified range"); 3309f0f6df0SIvan Butygin 3319f0f6df0SIvan Butygin Type srcType = op->getResult(0).getType(); 3329bf79308SKrzysztof Drewniak if (!llvm::all_equal(op->getResultTypes())) 3339bf79308SKrzysztof Drewniak return rewriter.notifyMatchFailure(op, "mismatched result types"); 3349bf79308SKrzysztof Drewniak if (op->getNumOperands() == 0 || 3359bf79308SKrzysztof Drewniak !llvm::all_of(op->getOperandTypes(), 3369bf79308SKrzysztof Drewniak [=](Type t) { return t == srcType; })) 3379bf79308SKrzysztof Drewniak return rewriter.notifyMatchFailure( 3389bf79308SKrzysztof Drewniak op, "no operands or operand types don't match result type"); 3399f0f6df0SIvan Butygin 3409bf79308SKrzysztof Drewniak for (unsigned targetBitwidth : targetBitwidths) { 3419bf79308SKrzysztof Drewniak CastKind castKind = CastKind::Both; 3429bf79308SKrzysztof Drewniak for (const ConstantIntRanges &range : ranges) { 3439bf79308SKrzysztof Drewniak castKind = mergeCastKinds(castKind, 3449bf79308SKrzysztof Drewniak checkTruncatability(range, targetBitwidth)); 3459bf79308SKrzysztof Drewniak if (castKind == CastKind::None) 3469bf79308SKrzysztof Drewniak break; 3479bf79308SKrzysztof Drewniak } 3489bf79308SKrzysztof Drewniak if (castKind == CastKind::None) 3499f0f6df0SIvan Butygin continue; 3509f0f6df0SIvan Butygin Type targetType = getTargetType(srcType, targetBitwidth); 3519f0f6df0SIvan Butygin if (targetType == srcType) 3529f0f6df0SIvan Butygin continue; 3539f0f6df0SIvan Butygin 3549f0f6df0SIvan Butygin Location loc = op->getLoc(); 3559f0f6df0SIvan Butygin IRMapping mapping; 3569bf79308SKrzysztof Drewniak for (auto [arg, argRange] : llvm::zip_first(op->getOperands(), ranges)) { 3579bf79308SKrzysztof Drewniak CastKind argCastKind = castKind; 3589bf79308SKrzysztof Drewniak // When dealing with `index` values, preserve non-negativity in the 3599bf79308SKrzysztof Drewniak // index_casts since we can't recover this in unsigned when equivalent. 3609bf79308SKrzysztof Drewniak if (argCastKind == CastKind::Signed && argRange.smin().isNonNegative()) 3619bf79308SKrzysztof Drewniak argCastKind = CastKind::Both; 3629bf79308SKrzysztof Drewniak Value newArg = doCast(rewriter, loc, arg, targetType, argCastKind); 3639f0f6df0SIvan Butygin mapping.map(arg, newArg); 3649f0f6df0SIvan Butygin } 3659f0f6df0SIvan Butygin 3669f0f6df0SIvan Butygin Operation *newOp = rewriter.clone(*op, mapping); 3679f0f6df0SIvan Butygin rewriter.modifyOpInPlace(newOp, [&]() { 3689f0f6df0SIvan Butygin for (OpResult res : newOp->getResults()) { 3699f0f6df0SIvan Butygin res.setType(targetType); 3709f0f6df0SIvan Butygin } 3719f0f6df0SIvan Butygin }); 3729f0f6df0SIvan Butygin SmallVector<Value> newResults; 3739bf79308SKrzysztof Drewniak for (auto [newRes, oldRes] : 3749bf79308SKrzysztof Drewniak llvm::zip_equal(newOp->getResults(), op->getResults())) { 3759bf79308SKrzysztof Drewniak Value castBack = doCast(rewriter, loc, newRes, srcType, castKind); 3769bf79308SKrzysztof Drewniak copyIntegerRange(solver, oldRes, castBack); 3779bf79308SKrzysztof Drewniak newResults.push_back(castBack); 3789bf79308SKrzysztof Drewniak } 3799f0f6df0SIvan Butygin 3809f0f6df0SIvan Butygin rewriter.replaceOp(op, newResults); 3819f0f6df0SIvan Butygin return success(); 3829f0f6df0SIvan Butygin } 3839f0f6df0SIvan Butygin return failure(); 3849f0f6df0SIvan Butygin } 3859f0f6df0SIvan Butygin 3869f0f6df0SIvan Butygin private: 3879f0f6df0SIvan Butygin DataFlowSolver &solver; 3889f0f6df0SIvan Butygin SmallVector<unsigned, 4> targetBitwidths; 3899f0f6df0SIvan Butygin }; 3909f0f6df0SIvan Butygin 3919f0f6df0SIvan Butygin struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> { 3929f0f6df0SIvan Butygin NarrowCmpI(MLIRContext *context, DataFlowSolver &s, ArrayRef<unsigned> target) 3939f0f6df0SIvan Butygin : OpRewritePattern(context), solver(s), targetBitwidths(target) {} 3949f0f6df0SIvan Butygin 3959f0f6df0SIvan Butygin LogicalResult matchAndRewrite(arith::CmpIOp op, 3969f0f6df0SIvan Butygin PatternRewriter &rewriter) const override { 3979f0f6df0SIvan Butygin Value lhs = op.getLhs(); 3989f0f6df0SIvan Butygin Value rhs = op.getRhs(); 3999f0f6df0SIvan Butygin 4009bf79308SKrzysztof Drewniak SmallVector<ConstantIntRanges> ranges; 4019bf79308SKrzysztof Drewniak if (failed(collectRanges(solver, op.getOperands(), ranges))) 4029f0f6df0SIvan Butygin return failure(); 4039bf79308SKrzysztof Drewniak const ConstantIntRanges &lhsRange = ranges[0]; 4049bf79308SKrzysztof Drewniak const ConstantIntRanges &rhsRange = ranges[1]; 4059f0f6df0SIvan Butygin 4069f0f6df0SIvan Butygin Type srcType = lhs.getType(); 4079bf79308SKrzysztof Drewniak for (unsigned targetBitwidth : targetBitwidths) { 4089bf79308SKrzysztof Drewniak CastKind lhsCastKind = checkTruncatability(lhsRange, targetBitwidth); 4099bf79308SKrzysztof Drewniak CastKind rhsCastKind = checkTruncatability(rhsRange, targetBitwidth); 4109bf79308SKrzysztof Drewniak CastKind castKind = mergeCastKinds(lhsCastKind, rhsCastKind); 4119bf79308SKrzysztof Drewniak // Note: this includes target width > src width. 4129bf79308SKrzysztof Drewniak if (castKind == CastKind::None) 4139f0f6df0SIvan Butygin continue; 4149f0f6df0SIvan Butygin 4159f0f6df0SIvan Butygin Type targetType = getTargetType(srcType, targetBitwidth); 4169f0f6df0SIvan Butygin if (targetType == srcType) 4179f0f6df0SIvan Butygin continue; 4189f0f6df0SIvan Butygin 4199f0f6df0SIvan Butygin Location loc = op->getLoc(); 4209f0f6df0SIvan Butygin IRMapping mapping; 4219bf79308SKrzysztof Drewniak Value lhsCast = doCast(rewriter, loc, lhs, targetType, lhsCastKind); 4229bf79308SKrzysztof Drewniak Value rhsCast = doCast(rewriter, loc, rhs, targetType, rhsCastKind); 4239bf79308SKrzysztof Drewniak mapping.map(lhs, lhsCast); 4249bf79308SKrzysztof Drewniak mapping.map(rhs, rhsCast); 4259f0f6df0SIvan Butygin 4269f0f6df0SIvan Butygin Operation *newOp = rewriter.clone(*op, mapping); 4279bf79308SKrzysztof Drewniak copyIntegerRange(solver, op.getResult(), newOp->getResult(0)); 4289f0f6df0SIvan Butygin rewriter.replaceOp(op, newOp->getResults()); 4299f0f6df0SIvan Butygin return success(); 4309f0f6df0SIvan Butygin } 4319f0f6df0SIvan Butygin return failure(); 4329f0f6df0SIvan Butygin } 4339f0f6df0SIvan Butygin 4349f0f6df0SIvan Butygin private: 4359f0f6df0SIvan Butygin DataFlowSolver &solver; 4369f0f6df0SIvan Butygin SmallVector<unsigned, 4> targetBitwidths; 4379f0f6df0SIvan Butygin }; 4389f0f6df0SIvan Butygin 4399f0f6df0SIvan Butygin /// Fold index_cast(index_cast(%arg: i8, index), i8) -> %arg 4409f0f6df0SIvan Butygin /// This pattern assumes all passed `targetBitwidths` are not wider than index 4419f0f6df0SIvan Butygin /// type. 4429bf79308SKrzysztof Drewniak template <typename CastOp> 4439bf79308SKrzysztof Drewniak struct FoldIndexCastChain final : OpRewritePattern<CastOp> { 4449f0f6df0SIvan Butygin FoldIndexCastChain(MLIRContext *context, ArrayRef<unsigned> target) 4459bf79308SKrzysztof Drewniak : OpRewritePattern<CastOp>(context), targetBitwidths(target) {} 4469f0f6df0SIvan Butygin 4479bf79308SKrzysztof Drewniak LogicalResult matchAndRewrite(CastOp op, 4489f0f6df0SIvan Butygin PatternRewriter &rewriter) const override { 4499bf79308SKrzysztof Drewniak auto srcOp = op.getIn().template getDefiningOp<CastOp>(); 4509f0f6df0SIvan Butygin if (!srcOp) 4519bf79308SKrzysztof Drewniak return rewriter.notifyMatchFailure(op, "doesn't come from an index cast"); 4529f0f6df0SIvan Butygin 4539f0f6df0SIvan Butygin Value src = srcOp.getIn(); 4549f0f6df0SIvan Butygin if (src.getType() != op.getType()) 4559bf79308SKrzysztof Drewniak return rewriter.notifyMatchFailure(op, "outer types don't match"); 4569bf79308SKrzysztof Drewniak 4579bf79308SKrzysztof Drewniak if (!srcOp.getType().isIndex()) 4589bf79308SKrzysztof Drewniak return rewriter.notifyMatchFailure(op, "intermediate type isn't index"); 4599f0f6df0SIvan Butygin 4609f0f6df0SIvan Butygin auto intType = dyn_cast<IntegerType>(op.getType()); 4619f0f6df0SIvan Butygin if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth())) 4629f0f6df0SIvan Butygin return failure(); 4639f0f6df0SIvan Butygin 4649f0f6df0SIvan Butygin rewriter.replaceOp(op, src); 4659f0f6df0SIvan Butygin return success(); 4669f0f6df0SIvan Butygin } 4679f0f6df0SIvan Butygin 4689f0f6df0SIvan Butygin private: 4699f0f6df0SIvan Butygin SmallVector<unsigned, 4> targetBitwidths; 4709f0f6df0SIvan Butygin }; 4719f0f6df0SIvan Butygin 4729f0f6df0SIvan Butygin struct IntRangeOptimizationsPass final 4739f0f6df0SIvan Butygin : arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> { 4741a867bf1SIvan Butygin 4751a867bf1SIvan Butygin void runOnOperation() override { 4761a867bf1SIvan Butygin Operation *op = getOperation(); 4771a867bf1SIvan Butygin MLIRContext *ctx = op->getContext(); 4781a867bf1SIvan Butygin DataFlowSolver solver; 4791a867bf1SIvan Butygin solver.load<DeadCodeAnalysis>(); 4801a867bf1SIvan Butygin solver.load<IntegerRangeAnalysis>(); 4811a867bf1SIvan Butygin if (failed(solver.initializeAndRun(op))) 4821a867bf1SIvan Butygin return signalPassFailure(); 4831a867bf1SIvan Butygin 48478b3a004SFelix Schneider DataFlowListener listener(solver); 48578b3a004SFelix Schneider 4861a867bf1SIvan Butygin RewritePatternSet patterns(ctx); 4871a867bf1SIvan Butygin populateIntRangeOptimizationsPatterns(patterns, solver); 4881a867bf1SIvan Butygin 48978b3a004SFelix Schneider GreedyRewriteConfig config; 49078b3a004SFelix Schneider config.listener = &listener; 49178b3a004SFelix Schneider 492*09dfc571SJacques Pienaar if (failed(applyPatternsGreedily(op, std::move(patterns), config))) 4931a867bf1SIvan Butygin signalPassFailure(); 4941a867bf1SIvan Butygin } 4951a867bf1SIvan Butygin }; 4969f0f6df0SIvan Butygin 4979f0f6df0SIvan Butygin struct IntRangeNarrowingPass final 4989f0f6df0SIvan Butygin : arith::impl::ArithIntRangeNarrowingBase<IntRangeNarrowingPass> { 4999f0f6df0SIvan Butygin using ArithIntRangeNarrowingBase::ArithIntRangeNarrowingBase; 5009f0f6df0SIvan Butygin 5019f0f6df0SIvan Butygin void runOnOperation() override { 5029f0f6df0SIvan Butygin Operation *op = getOperation(); 5039f0f6df0SIvan Butygin MLIRContext *ctx = op->getContext(); 5049f0f6df0SIvan Butygin DataFlowSolver solver; 5059f0f6df0SIvan Butygin solver.load<DeadCodeAnalysis>(); 5069f0f6df0SIvan Butygin solver.load<IntegerRangeAnalysis>(); 5079f0f6df0SIvan Butygin if (failed(solver.initializeAndRun(op))) 5089f0f6df0SIvan Butygin return signalPassFailure(); 5099f0f6df0SIvan Butygin 5109f0f6df0SIvan Butygin DataFlowListener listener(solver); 5119f0f6df0SIvan Butygin 5129f0f6df0SIvan Butygin RewritePatternSet patterns(ctx); 5139f0f6df0SIvan Butygin populateIntRangeNarrowingPatterns(patterns, solver, bitwidthsSupported); 5149f0f6df0SIvan Butygin 5159f0f6df0SIvan Butygin GreedyRewriteConfig config; 5169f0f6df0SIvan Butygin // We specifically need bottom-up traversal as cmpi pattern needs range 5179f0f6df0SIvan Butygin // data, attached to its original argument values. 5189f0f6df0SIvan Butygin config.useTopDownTraversal = false; 5199f0f6df0SIvan Butygin config.listener = &listener; 5209f0f6df0SIvan Butygin 521*09dfc571SJacques Pienaar if (failed(applyPatternsGreedily(op, std::move(patterns), config))) 5229f0f6df0SIvan Butygin signalPassFailure(); 5239f0f6df0SIvan Butygin } 5249f0f6df0SIvan Butygin }; 5251a867bf1SIvan Butygin } // namespace 5261a867bf1SIvan Butygin 5271a867bf1SIvan Butygin void mlir::arith::populateIntRangeOptimizationsPatterns( 5281a867bf1SIvan Butygin RewritePatternSet &patterns, DataFlowSolver &solver) { 52947229111SKrzysztof Drewniak patterns.add<MaterializeKnownConstantValues, DeleteTrivialRem<RemSIOp>, 53047229111SKrzysztof Drewniak DeleteTrivialRem<RemUIOp>>(patterns.getContext(), solver); 5311a867bf1SIvan Butygin } 5321a867bf1SIvan Butygin 5339f0f6df0SIvan Butygin void mlir::arith::populateIntRangeNarrowingPatterns( 5349f0f6df0SIvan Butygin RewritePatternSet &patterns, DataFlowSolver &solver, 5359f0f6df0SIvan Butygin ArrayRef<unsigned> bitwidthsSupported) { 5369f0f6df0SIvan Butygin patterns.add<NarrowElementwise, NarrowCmpI>(patterns.getContext(), solver, 5379f0f6df0SIvan Butygin bitwidthsSupported); 5389bf79308SKrzysztof Drewniak patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>, 5399bf79308SKrzysztof Drewniak FoldIndexCastChain<arith::IndexCastOp>>(patterns.getContext(), 5409bf79308SKrzysztof Drewniak bitwidthsSupported); 5419f0f6df0SIvan Butygin } 5429f0f6df0SIvan Butygin 5431a867bf1SIvan Butygin std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() { 5441a867bf1SIvan Butygin return std::make_unique<IntRangeOptimizationsPass>(); 5451a867bf1SIvan Butygin } 546