xref: /llvm-project/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
11a867bf1SIvan Butygin //===- IntRangeOptimizations.cpp - Optimizations based on integer ranges --===//
21a867bf1SIvan Butygin //
31a867bf1SIvan Butygin // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
41a867bf1SIvan Butygin // See https://llvm.org/LICENSE.txt for license information.
51a867bf1SIvan Butygin // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
61a867bf1SIvan Butygin //
71a867bf1SIvan Butygin //===----------------------------------------------------------------------===//
81a867bf1SIvan Butygin 
951911a62SMehdi Amini #include <utility>
1051911a62SMehdi Amini 
1147229111SKrzysztof Drewniak #include "mlir/Analysis/DataFlowFramework.h"
121a867bf1SIvan Butygin #include "mlir/Dialect/Arith/Transforms/Passes.h"
131a867bf1SIvan Butygin 
141a867bf1SIvan Butygin #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
151a867bf1SIvan Butygin #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
161a867bf1SIvan Butygin #include "mlir/Dialect/Arith/IR/Arith.h"
1747229111SKrzysztof Drewniak #include "mlir/Dialect/Utils/StaticValueUtils.h"
189f0f6df0SIvan Butygin #include "mlir/IR/IRMapping.h"
1947229111SKrzysztof Drewniak #include "mlir/IR/Matchers.h"
2047229111SKrzysztof Drewniak #include "mlir/IR/PatternMatch.h"
219f0f6df0SIvan Butygin #include "mlir/IR/TypeUtilities.h"
2247229111SKrzysztof Drewniak #include "mlir/Interfaces/SideEffectInterfaces.h"
2347229111SKrzysztof Drewniak #include "mlir/Transforms/FoldUtils.h"
241a867bf1SIvan Butygin #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
251a867bf1SIvan Butygin 
261a867bf1SIvan Butygin namespace mlir::arith {
271a867bf1SIvan Butygin #define GEN_PASS_DEF_ARITHINTRANGEOPTS
281a867bf1SIvan Butygin #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
299f0f6df0SIvan Butygin 
309f0f6df0SIvan Butygin #define GEN_PASS_DEF_ARITHINTRANGENARROWING
319f0f6df0SIvan Butygin #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
321a867bf1SIvan Butygin } // namespace mlir::arith
331a867bf1SIvan Butygin 
341a867bf1SIvan Butygin using namespace mlir;
351a867bf1SIvan Butygin using namespace mlir::arith;
361a867bf1SIvan Butygin using namespace mlir::dataflow;
371a867bf1SIvan Butygin 
3847229111SKrzysztof Drewniak static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
3947229111SKrzysztof Drewniak                                                   Value value) {
4047229111SKrzysztof Drewniak   auto *maybeInferredRange =
4147229111SKrzysztof Drewniak       solver.lookupState<IntegerValueRangeLattice>(value);
4247229111SKrzysztof Drewniak   if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
4347229111SKrzysztof Drewniak     return std::nullopt;
4447229111SKrzysztof Drewniak   const ConstantIntRanges &inferredRange =
4547229111SKrzysztof Drewniak       maybeInferredRange->getValue().getValue();
4647229111SKrzysztof Drewniak   return inferredRange.getConstantValue();
471a867bf1SIvan Butygin }
481a867bf1SIvan Butygin 
499bf79308SKrzysztof Drewniak static void copyIntegerRange(DataFlowSolver &solver, Value oldVal,
509bf79308SKrzysztof Drewniak                              Value newVal) {
519bf79308SKrzysztof Drewniak   assert(oldVal.getType() == newVal.getType() &&
529bf79308SKrzysztof Drewniak          "Can't copy integer ranges between different types");
539bf79308SKrzysztof Drewniak   auto *oldState = solver.lookupState<IntegerValueRangeLattice>(oldVal);
549bf79308SKrzysztof Drewniak   if (!oldState)
559bf79308SKrzysztof Drewniak     return;
569bf79308SKrzysztof Drewniak   (void)solver.getOrCreateState<IntegerValueRangeLattice>(newVal)->join(
579bf79308SKrzysztof Drewniak       *oldState);
589bf79308SKrzysztof Drewniak }
599bf79308SKrzysztof Drewniak 
6047229111SKrzysztof Drewniak /// Patterned after SCCP
6147229111SKrzysztof Drewniak static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
6247229111SKrzysztof Drewniak                                               PatternRewriter &rewriter,
6347229111SKrzysztof Drewniak                                               Value value) {
6447229111SKrzysztof Drewniak   if (value.use_empty())
651a867bf1SIvan Butygin     return failure();
6647229111SKrzysztof Drewniak   std::optional<APInt> maybeConstValue = getMaybeConstantValue(solver, value);
6747229111SKrzysztof Drewniak   if (!maybeConstValue.has_value())
681a867bf1SIvan Butygin     return failure();
691a867bf1SIvan Butygin 
70f54cdc5dSIvan Butygin   Type type = value.getType();
71f54cdc5dSIvan Butygin   Location loc = value.getLoc();
7247229111SKrzysztof Drewniak   Operation *maybeDefiningOp = value.getDefiningOp();
7347229111SKrzysztof Drewniak   Dialect *valueDialect =
7447229111SKrzysztof Drewniak       maybeDefiningOp ? maybeDefiningOp->getDialect()
7547229111SKrzysztof Drewniak                       : value.getParentRegion()->getParentOp()->getDialect();
76f54cdc5dSIvan Butygin 
77f54cdc5dSIvan Butygin   Attribute constAttr;
78f54cdc5dSIvan Butygin   if (auto shaped = dyn_cast<ShapedType>(type)) {
79f54cdc5dSIvan Butygin     constAttr = mlir::DenseIntElementsAttr::get(shaped, *maybeConstValue);
80f54cdc5dSIvan Butygin   } else {
81f54cdc5dSIvan Butygin     constAttr = rewriter.getIntegerAttr(type, *maybeConstValue);
82f54cdc5dSIvan Butygin   }
83f54cdc5dSIvan Butygin   Operation *constOp =
84f54cdc5dSIvan Butygin       valueDialect->materializeConstant(rewriter, constAttr, type, loc);
8547229111SKrzysztof Drewniak   // Fall back to arith.constant if the dialect materializer doesn't know what
8647229111SKrzysztof Drewniak   // to do with an integer constant.
8747229111SKrzysztof Drewniak   if (!constOp)
8847229111SKrzysztof Drewniak     constOp = rewriter.getContext()
8947229111SKrzysztof Drewniak                   ->getLoadedDialect<ArithDialect>()
90f54cdc5dSIvan Butygin                   ->materializeConstant(rewriter, constAttr, type, loc);
9147229111SKrzysztof Drewniak   if (!constOp)
921a867bf1SIvan Butygin     return failure();
931a867bf1SIvan Butygin 
949bf79308SKrzysztof Drewniak   copyIntegerRange(solver, value, constOp->getResult(0));
9547229111SKrzysztof Drewniak   rewriter.replaceAllUsesWith(value, constOp->getResult(0));
9647229111SKrzysztof Drewniak   return success();
971a867bf1SIvan Butygin }
981a867bf1SIvan Butygin 
991a867bf1SIvan Butygin namespace {
10078b3a004SFelix Schneider class DataFlowListener : public RewriterBase::Listener {
10178b3a004SFelix Schneider public:
10278b3a004SFelix Schneider   DataFlowListener(DataFlowSolver &s) : s(s) {}
10378b3a004SFelix Schneider 
10478b3a004SFelix Schneider protected:
10578b3a004SFelix Schneider   void notifyOperationErased(Operation *op) override {
1064b3f251bSdonald chen     s.eraseState(s.getProgramPointAfter(op));
10778b3a004SFelix Schneider     for (Value res : op->getResults())
10878b3a004SFelix Schneider       s.eraseState(res);
10978b3a004SFelix Schneider   }
11078b3a004SFelix Schneider 
11178b3a004SFelix Schneider   DataFlowSolver &s;
11278b3a004SFelix Schneider };
11378b3a004SFelix Schneider 
11447229111SKrzysztof Drewniak /// Rewrite any results of `op` that were inferred to be constant integers to
11547229111SKrzysztof Drewniak /// and replace their uses with that constant. Return success() if all results
11647229111SKrzysztof Drewniak /// where thus replaced and the operation is erased. Also replace any block
11747229111SKrzysztof Drewniak /// arguments with their constant values.
11847229111SKrzysztof Drewniak struct MaterializeKnownConstantValues : public RewritePattern {
11947229111SKrzysztof Drewniak   MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s)
12047229111SKrzysztof Drewniak       : RewritePattern(Pattern::MatchAnyOpTypeTag(), /*benefit=*/1, context),
12147229111SKrzysztof Drewniak         solver(s) {}
1221a867bf1SIvan Butygin 
12347229111SKrzysztof Drewniak   LogicalResult match(Operation *op) const override {
12447229111SKrzysztof Drewniak     if (matchPattern(op, m_Constant()))
12547229111SKrzysztof Drewniak       return failure();
1261a867bf1SIvan Butygin 
12747229111SKrzysztof Drewniak     auto needsReplacing = [&](Value v) {
12847229111SKrzysztof Drewniak       return getMaybeConstantValue(solver, v).has_value() && !v.use_empty();
12947229111SKrzysztof Drewniak     };
13047229111SKrzysztof Drewniak     bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing);
13147229111SKrzysztof Drewniak     if (op->getNumRegions() == 0)
13247229111SKrzysztof Drewniak       return success(hasConstantResults);
13347229111SKrzysztof Drewniak     bool hasConstantRegionArgs = false;
13447229111SKrzysztof Drewniak     for (Region &region : op->getRegions()) {
13547229111SKrzysztof Drewniak       for (Block &block : region.getBlocks()) {
13647229111SKrzysztof Drewniak         hasConstantRegionArgs |=
13747229111SKrzysztof Drewniak             llvm::any_of(block.getArguments(), needsReplacing);
13847229111SKrzysztof Drewniak       }
13947229111SKrzysztof Drewniak     }
14047229111SKrzysztof Drewniak     return success(hasConstantResults || hasConstantRegionArgs);
14147229111SKrzysztof Drewniak   }
14247229111SKrzysztof Drewniak 
14347229111SKrzysztof Drewniak   void rewrite(Operation *op, PatternRewriter &rewriter) const override {
14447229111SKrzysztof Drewniak     bool replacedAll = (op->getNumResults() != 0);
14547229111SKrzysztof Drewniak     for (Value v : op->getResults())
14647229111SKrzysztof Drewniak       replacedAll &=
14747229111SKrzysztof Drewniak           (succeeded(maybeReplaceWithConstant(solver, rewriter, v)) ||
14847229111SKrzysztof Drewniak            v.use_empty());
14947229111SKrzysztof Drewniak     if (replacedAll && isOpTriviallyDead(op)) {
15047229111SKrzysztof Drewniak       rewriter.eraseOp(op);
15147229111SKrzysztof Drewniak       return;
15247229111SKrzysztof Drewniak     }
15347229111SKrzysztof Drewniak 
15447229111SKrzysztof Drewniak     PatternRewriter::InsertionGuard guard(rewriter);
15547229111SKrzysztof Drewniak     for (Region &region : op->getRegions()) {
15647229111SKrzysztof Drewniak       for (Block &block : region.getBlocks()) {
15747229111SKrzysztof Drewniak         rewriter.setInsertionPointToStart(&block);
15847229111SKrzysztof Drewniak         for (BlockArgument &arg : block.getArguments()) {
15947229111SKrzysztof Drewniak           (void)maybeReplaceWithConstant(solver, rewriter, arg);
16047229111SKrzysztof Drewniak         }
16147229111SKrzysztof Drewniak       }
16247229111SKrzysztof Drewniak     }
16347229111SKrzysztof Drewniak   }
16447229111SKrzysztof Drewniak 
16547229111SKrzysztof Drewniak private:
16647229111SKrzysztof Drewniak   DataFlowSolver &solver;
16747229111SKrzysztof Drewniak };
16847229111SKrzysztof Drewniak 
16947229111SKrzysztof Drewniak template <typename RemOp>
17047229111SKrzysztof Drewniak struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
17147229111SKrzysztof Drewniak   DeleteTrivialRem(MLIRContext *context, DataFlowSolver &s)
17247229111SKrzysztof Drewniak       : OpRewritePattern<RemOp>(context), solver(s) {}
17347229111SKrzysztof Drewniak 
17447229111SKrzysztof Drewniak   LogicalResult matchAndRewrite(RemOp op,
1751a867bf1SIvan Butygin                                 PatternRewriter &rewriter) const override {
17647229111SKrzysztof Drewniak     Value lhs = op.getOperand(0);
17747229111SKrzysztof Drewniak     Value rhs = op.getOperand(1);
17847229111SKrzysztof Drewniak     auto maybeModulus = getConstantIntValue(rhs);
17947229111SKrzysztof Drewniak     if (!maybeModulus.has_value())
18047229111SKrzysztof Drewniak       return failure();
18147229111SKrzysztof Drewniak     int64_t modulus = *maybeModulus;
18247229111SKrzysztof Drewniak     if (modulus <= 0)
18347229111SKrzysztof Drewniak       return failure();
18447229111SKrzysztof Drewniak     auto *maybeLhsRange = solver.lookupState<IntegerValueRangeLattice>(lhs);
18547229111SKrzysztof Drewniak     if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized())
18647229111SKrzysztof Drewniak       return failure();
18747229111SKrzysztof Drewniak     const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue();
18847229111SKrzysztof Drewniak     const APInt &min = isa<RemUIOp>(op) ? lhsRange.umin() : lhsRange.smin();
18947229111SKrzysztof Drewniak     const APInt &max = isa<RemUIOp>(op) ? lhsRange.umax() : lhsRange.smax();
19047229111SKrzysztof Drewniak     // The minima and maxima here are given as closed ranges, we must be
19147229111SKrzysztof Drewniak     // strictly less than the modulus.
19247229111SKrzysztof Drewniak     if (min.isNegative() || min.uge(modulus))
19347229111SKrzysztof Drewniak       return failure();
19447229111SKrzysztof Drewniak     if (max.isNegative() || max.uge(modulus))
19547229111SKrzysztof Drewniak       return failure();
19647229111SKrzysztof Drewniak     if (!min.ule(max))
1971a867bf1SIvan Butygin       return failure();
1981a867bf1SIvan Butygin 
19947229111SKrzysztof Drewniak     // With all those conditions out of the way, we know thas this invocation of
20047229111SKrzysztof Drewniak     // a remainder is a noop because the input is strictly within the range
20147229111SKrzysztof Drewniak     // [0, modulus), so get rid of it.
20247229111SKrzysztof Drewniak     rewriter.replaceOp(op, ValueRange{lhs});
2031a867bf1SIvan Butygin     return success();
2041a867bf1SIvan Butygin   }
2051a867bf1SIvan Butygin 
2061a867bf1SIvan Butygin private:
2071a867bf1SIvan Butygin   DataFlowSolver &solver;
2081a867bf1SIvan Butygin };
2091a867bf1SIvan Butygin 
2109bf79308SKrzysztof Drewniak /// Gather ranges for all the values in `values`. Appends to the existing
2119bf79308SKrzysztof Drewniak /// vector.
2129bf79308SKrzysztof Drewniak static LogicalResult collectRanges(DataFlowSolver &solver, ValueRange values,
2139bf79308SKrzysztof Drewniak                                    SmallVectorImpl<ConstantIntRanges> &ranges) {
2149bf79308SKrzysztof Drewniak   for (Value val : values) {
2159f0f6df0SIvan Butygin     auto *maybeInferredRange =
2169bf79308SKrzysztof Drewniak         solver.lookupState<IntegerValueRangeLattice>(val);
2179f0f6df0SIvan Butygin     if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
2189bf79308SKrzysztof Drewniak       return failure();
2199f0f6df0SIvan Butygin 
2209f0f6df0SIvan Butygin     const ConstantIntRanges &inferredRange =
2219f0f6df0SIvan Butygin         maybeInferredRange->getValue().getValue();
2229bf79308SKrzysztof Drewniak     ranges.push_back(inferredRange);
2239f0f6df0SIvan Butygin   }
2249bf79308SKrzysztof Drewniak   return success();
2259f0f6df0SIvan Butygin }
2269f0f6df0SIvan Butygin 
2279f0f6df0SIvan Butygin /// Return int type truncated to `targetBitwidth`. If `srcType` is shaped,
2289f0f6df0SIvan Butygin /// return shaped type as well.
2299f0f6df0SIvan Butygin static Type getTargetType(Type srcType, unsigned targetBitwidth) {
2309f0f6df0SIvan Butygin   auto dstType = IntegerType::get(srcType.getContext(), targetBitwidth);
2319f0f6df0SIvan Butygin   if (auto shaped = dyn_cast<ShapedType>(srcType))
2329f0f6df0SIvan Butygin     return shaped.clone(dstType);
2339f0f6df0SIvan Butygin 
2349f0f6df0SIvan Butygin   assert(srcType.isIntOrIndex() && "Invalid src type");
2359f0f6df0SIvan Butygin   return dstType;
2369f0f6df0SIvan Butygin }
2379f0f6df0SIvan Butygin 
namespace {
/// Tracks which kind of truncation, if any, may be used to narrow an
/// operation.
enum class CastKind : uint8_t {
  /// Narrowing is not possible.
  None,
  /// Only a signed truncation is valid.
  Signed,
  /// Only an unsigned truncation is valid.
  Unsigned,
  /// Signed and unsigned truncation are both valid.
  Both
};
} // namespace
2439bf79308SKrzysztof Drewniak 
2449bf79308SKrzysztof Drewniak /// If the values within `range` can be represented using only `width` bits,
2459bf79308SKrzysztof Drewniak /// return the kind of truncation needed to preserve that property.
2469bf79308SKrzysztof Drewniak ///
2479bf79308SKrzysztof Drewniak /// This check relies on the fact that the signed and unsigned ranges are both
2489bf79308SKrzysztof Drewniak /// always correct, but that one might be an approximation of the other,
2499bf79308SKrzysztof Drewniak /// so we want to use the correct truncation operation.
2509bf79308SKrzysztof Drewniak static CastKind checkTruncatability(const ConstantIntRanges &range,
2519bf79308SKrzysztof Drewniak                                     unsigned targetWidth) {
2529bf79308SKrzysztof Drewniak   unsigned srcWidth = range.smin().getBitWidth();
2539bf79308SKrzysztof Drewniak   if (srcWidth <= targetWidth)
2549bf79308SKrzysztof Drewniak     return CastKind::None;
2559bf79308SKrzysztof Drewniak   unsigned removedWidth = srcWidth - targetWidth;
2569bf79308SKrzysztof Drewniak   // The sign bits need to extend into the sign bit of the target width. For
2579bf79308SKrzysztof Drewniak   // example, if we're truncating 64 bits to 32, we need 64 - 32 + 1 = 33 sign
2589bf79308SKrzysztof Drewniak   // bits.
2599bf79308SKrzysztof Drewniak   bool canTruncateSigned =
2609bf79308SKrzysztof Drewniak       range.smin().getNumSignBits() >= (removedWidth + 1) &&
2619bf79308SKrzysztof Drewniak       range.smax().getNumSignBits() >= (removedWidth + 1);
2629bf79308SKrzysztof Drewniak   bool canTruncateUnsigned = range.umin().countLeadingZeros() >= removedWidth &&
2639bf79308SKrzysztof Drewniak                              range.umax().countLeadingZeros() >= removedWidth;
2649bf79308SKrzysztof Drewniak   if (canTruncateSigned && canTruncateUnsigned)
2659bf79308SKrzysztof Drewniak     return CastKind::Both;
2669bf79308SKrzysztof Drewniak   if (canTruncateSigned)
2679bf79308SKrzysztof Drewniak     return CastKind::Signed;
2689bf79308SKrzysztof Drewniak   if (canTruncateUnsigned)
2699bf79308SKrzysztof Drewniak     return CastKind::Unsigned;
2709bf79308SKrzysztof Drewniak   return CastKind::None;
2719f0f6df0SIvan Butygin }
2729f0f6df0SIvan Butygin 
2739bf79308SKrzysztof Drewniak static CastKind mergeCastKinds(CastKind lhs, CastKind rhs) {
2749bf79308SKrzysztof Drewniak   if (lhs == CastKind::None || rhs == CastKind::None)
2759bf79308SKrzysztof Drewniak     return CastKind::None;
2769bf79308SKrzysztof Drewniak   if (lhs == CastKind::Both)
2779bf79308SKrzysztof Drewniak     return rhs;
2789bf79308SKrzysztof Drewniak   if (rhs == CastKind::Both)
2799bf79308SKrzysztof Drewniak     return lhs;
2809bf79308SKrzysztof Drewniak   if (lhs == rhs)
2819bf79308SKrzysztof Drewniak     return lhs;
2829bf79308SKrzysztof Drewniak   return CastKind::None;
2839bf79308SKrzysztof Drewniak }
2849bf79308SKrzysztof Drewniak 
2859bf79308SKrzysztof Drewniak static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType,
2869bf79308SKrzysztof Drewniak                     CastKind castKind) {
2879f0f6df0SIvan Butygin   Type srcType = src.getType();
2889f0f6df0SIvan Butygin   assert(isa<VectorType>(srcType) == isa<VectorType>(dstType) &&
2899f0f6df0SIvan Butygin          "Mixing vector and non-vector types");
2909bf79308SKrzysztof Drewniak   assert(castKind != CastKind::None && "Can't cast when casting isn't allowed");
2919f0f6df0SIvan Butygin   Type srcElemType = getElementTypeOrSelf(srcType);
2929f0f6df0SIvan Butygin   Type dstElemType = getElementTypeOrSelf(dstType);
2939f0f6df0SIvan Butygin   assert(srcElemType.isIntOrIndex() && "Invalid src type");
2949f0f6df0SIvan Butygin   assert(dstElemType.isIntOrIndex() && "Invalid dst type");
2959f0f6df0SIvan Butygin   if (srcType == dstType)
2969f0f6df0SIvan Butygin     return src;
2979f0f6df0SIvan Butygin 
2989bf79308SKrzysztof Drewniak   if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType)) {
2999bf79308SKrzysztof Drewniak     if (castKind == CastKind::Signed)
3009bf79308SKrzysztof Drewniak       return builder.create<arith::IndexCastOp>(loc, dstType, src);
3019f0f6df0SIvan Butygin     return builder.create<arith::IndexCastUIOp>(loc, dstType, src);
3029bf79308SKrzysztof Drewniak   }
3039f0f6df0SIvan Butygin 
3049f0f6df0SIvan Butygin   auto srcInt = cast<IntegerType>(srcElemType);
3059f0f6df0SIvan Butygin   auto dstInt = cast<IntegerType>(dstElemType);
3069f0f6df0SIvan Butygin   if (dstInt.getWidth() < srcInt.getWidth())
3079f0f6df0SIvan Butygin     return builder.create<arith::TruncIOp>(loc, dstType, src);
3089f0f6df0SIvan Butygin 
3099bf79308SKrzysztof Drewniak   if (castKind == CastKind::Signed)
3109bf79308SKrzysztof Drewniak     return builder.create<arith::ExtSIOp>(loc, dstType, src);
3119f0f6df0SIvan Butygin   return builder.create<arith::ExtUIOp>(loc, dstType, src);
3129f0f6df0SIvan Butygin }
3139f0f6df0SIvan Butygin 
3149f0f6df0SIvan Butygin struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
3159f0f6df0SIvan Butygin   NarrowElementwise(MLIRContext *context, DataFlowSolver &s,
3169f0f6df0SIvan Butygin                     ArrayRef<unsigned> target)
3179f0f6df0SIvan Butygin       : OpTraitRewritePattern(context), solver(s), targetBitwidths(target) {}
3189f0f6df0SIvan Butygin 
3199f0f6df0SIvan Butygin   using OpTraitRewritePattern::OpTraitRewritePattern;
3209f0f6df0SIvan Butygin   LogicalResult matchAndRewrite(Operation *op,
3219f0f6df0SIvan Butygin                                 PatternRewriter &rewriter) const override {
3229bf79308SKrzysztof Drewniak     if (op->getNumResults() == 0)
3239bf79308SKrzysztof Drewniak       return rewriter.notifyMatchFailure(op, "can't narrow resultless op");
3249f0f6df0SIvan Butygin 
3259bf79308SKrzysztof Drewniak     SmallVector<ConstantIntRanges> ranges;
3269bf79308SKrzysztof Drewniak     if (failed(collectRanges(solver, op->getOperands(), ranges)))
3279bf79308SKrzysztof Drewniak       return rewriter.notifyMatchFailure(op, "input without specified range");
3289bf79308SKrzysztof Drewniak     if (failed(collectRanges(solver, op->getResults(), ranges)))
3299bf79308SKrzysztof Drewniak       return rewriter.notifyMatchFailure(op, "output without specified range");
3309f0f6df0SIvan Butygin 
3319f0f6df0SIvan Butygin     Type srcType = op->getResult(0).getType();
3329bf79308SKrzysztof Drewniak     if (!llvm::all_equal(op->getResultTypes()))
3339bf79308SKrzysztof Drewniak       return rewriter.notifyMatchFailure(op, "mismatched result types");
3349bf79308SKrzysztof Drewniak     if (op->getNumOperands() == 0 ||
3359bf79308SKrzysztof Drewniak         !llvm::all_of(op->getOperandTypes(),
3369bf79308SKrzysztof Drewniak                       [=](Type t) { return t == srcType; }))
3379bf79308SKrzysztof Drewniak       return rewriter.notifyMatchFailure(
3389bf79308SKrzysztof Drewniak           op, "no operands or operand types don't match result type");
3399f0f6df0SIvan Butygin 
3409bf79308SKrzysztof Drewniak     for (unsigned targetBitwidth : targetBitwidths) {
3419bf79308SKrzysztof Drewniak       CastKind castKind = CastKind::Both;
3429bf79308SKrzysztof Drewniak       for (const ConstantIntRanges &range : ranges) {
3439bf79308SKrzysztof Drewniak         castKind = mergeCastKinds(castKind,
3449bf79308SKrzysztof Drewniak                                   checkTruncatability(range, targetBitwidth));
3459bf79308SKrzysztof Drewniak         if (castKind == CastKind::None)
3469bf79308SKrzysztof Drewniak           break;
3479bf79308SKrzysztof Drewniak       }
3489bf79308SKrzysztof Drewniak       if (castKind == CastKind::None)
3499f0f6df0SIvan Butygin         continue;
3509f0f6df0SIvan Butygin       Type targetType = getTargetType(srcType, targetBitwidth);
3519f0f6df0SIvan Butygin       if (targetType == srcType)
3529f0f6df0SIvan Butygin         continue;
3539f0f6df0SIvan Butygin 
3549f0f6df0SIvan Butygin       Location loc = op->getLoc();
3559f0f6df0SIvan Butygin       IRMapping mapping;
3569bf79308SKrzysztof Drewniak       for (auto [arg, argRange] : llvm::zip_first(op->getOperands(), ranges)) {
3579bf79308SKrzysztof Drewniak         CastKind argCastKind = castKind;
3589bf79308SKrzysztof Drewniak         // When dealing with `index` values, preserve non-negativity in the
3599bf79308SKrzysztof Drewniak         // index_casts since we can't recover this in unsigned when equivalent.
3609bf79308SKrzysztof Drewniak         if (argCastKind == CastKind::Signed && argRange.smin().isNonNegative())
3619bf79308SKrzysztof Drewniak           argCastKind = CastKind::Both;
3629bf79308SKrzysztof Drewniak         Value newArg = doCast(rewriter, loc, arg, targetType, argCastKind);
3639f0f6df0SIvan Butygin         mapping.map(arg, newArg);
3649f0f6df0SIvan Butygin       }
3659f0f6df0SIvan Butygin 
3669f0f6df0SIvan Butygin       Operation *newOp = rewriter.clone(*op, mapping);
3679f0f6df0SIvan Butygin       rewriter.modifyOpInPlace(newOp, [&]() {
3689f0f6df0SIvan Butygin         for (OpResult res : newOp->getResults()) {
3699f0f6df0SIvan Butygin           res.setType(targetType);
3709f0f6df0SIvan Butygin         }
3719f0f6df0SIvan Butygin       });
3729f0f6df0SIvan Butygin       SmallVector<Value> newResults;
3739bf79308SKrzysztof Drewniak       for (auto [newRes, oldRes] :
3749bf79308SKrzysztof Drewniak            llvm::zip_equal(newOp->getResults(), op->getResults())) {
3759bf79308SKrzysztof Drewniak         Value castBack = doCast(rewriter, loc, newRes, srcType, castKind);
3769bf79308SKrzysztof Drewniak         copyIntegerRange(solver, oldRes, castBack);
3779bf79308SKrzysztof Drewniak         newResults.push_back(castBack);
3789bf79308SKrzysztof Drewniak       }
3799f0f6df0SIvan Butygin 
3809f0f6df0SIvan Butygin       rewriter.replaceOp(op, newResults);
3819f0f6df0SIvan Butygin       return success();
3829f0f6df0SIvan Butygin     }
3839f0f6df0SIvan Butygin     return failure();
3849f0f6df0SIvan Butygin   }
3859f0f6df0SIvan Butygin 
3869f0f6df0SIvan Butygin private:
3879f0f6df0SIvan Butygin   DataFlowSolver &solver;
3889f0f6df0SIvan Butygin   SmallVector<unsigned, 4> targetBitwidths;
3899f0f6df0SIvan Butygin };
3909f0f6df0SIvan Butygin 
3919f0f6df0SIvan Butygin struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
3929f0f6df0SIvan Butygin   NarrowCmpI(MLIRContext *context, DataFlowSolver &s, ArrayRef<unsigned> target)
3939f0f6df0SIvan Butygin       : OpRewritePattern(context), solver(s), targetBitwidths(target) {}
3949f0f6df0SIvan Butygin 
3959f0f6df0SIvan Butygin   LogicalResult matchAndRewrite(arith::CmpIOp op,
3969f0f6df0SIvan Butygin                                 PatternRewriter &rewriter) const override {
3979f0f6df0SIvan Butygin     Value lhs = op.getLhs();
3989f0f6df0SIvan Butygin     Value rhs = op.getRhs();
3999f0f6df0SIvan Butygin 
4009bf79308SKrzysztof Drewniak     SmallVector<ConstantIntRanges> ranges;
4019bf79308SKrzysztof Drewniak     if (failed(collectRanges(solver, op.getOperands(), ranges)))
4029f0f6df0SIvan Butygin       return failure();
4039bf79308SKrzysztof Drewniak     const ConstantIntRanges &lhsRange = ranges[0];
4049bf79308SKrzysztof Drewniak     const ConstantIntRanges &rhsRange = ranges[1];
4059f0f6df0SIvan Butygin 
4069f0f6df0SIvan Butygin     Type srcType = lhs.getType();
4079bf79308SKrzysztof Drewniak     for (unsigned targetBitwidth : targetBitwidths) {
4089bf79308SKrzysztof Drewniak       CastKind lhsCastKind = checkTruncatability(lhsRange, targetBitwidth);
4099bf79308SKrzysztof Drewniak       CastKind rhsCastKind = checkTruncatability(rhsRange, targetBitwidth);
4109bf79308SKrzysztof Drewniak       CastKind castKind = mergeCastKinds(lhsCastKind, rhsCastKind);
4119bf79308SKrzysztof Drewniak       // Note: this includes target width > src width.
4129bf79308SKrzysztof Drewniak       if (castKind == CastKind::None)
4139f0f6df0SIvan Butygin         continue;
4149f0f6df0SIvan Butygin 
4159f0f6df0SIvan Butygin       Type targetType = getTargetType(srcType, targetBitwidth);
4169f0f6df0SIvan Butygin       if (targetType == srcType)
4179f0f6df0SIvan Butygin         continue;
4189f0f6df0SIvan Butygin 
4199f0f6df0SIvan Butygin       Location loc = op->getLoc();
4209f0f6df0SIvan Butygin       IRMapping mapping;
4219bf79308SKrzysztof Drewniak       Value lhsCast = doCast(rewriter, loc, lhs, targetType, lhsCastKind);
4229bf79308SKrzysztof Drewniak       Value rhsCast = doCast(rewriter, loc, rhs, targetType, rhsCastKind);
4239bf79308SKrzysztof Drewniak       mapping.map(lhs, lhsCast);
4249bf79308SKrzysztof Drewniak       mapping.map(rhs, rhsCast);
4259f0f6df0SIvan Butygin 
4269f0f6df0SIvan Butygin       Operation *newOp = rewriter.clone(*op, mapping);
4279bf79308SKrzysztof Drewniak       copyIntegerRange(solver, op.getResult(), newOp->getResult(0));
4289f0f6df0SIvan Butygin       rewriter.replaceOp(op, newOp->getResults());
4299f0f6df0SIvan Butygin       return success();
4309f0f6df0SIvan Butygin     }
4319f0f6df0SIvan Butygin     return failure();
4329f0f6df0SIvan Butygin   }
4339f0f6df0SIvan Butygin 
4349f0f6df0SIvan Butygin private:
4359f0f6df0SIvan Butygin   DataFlowSolver &solver;
4369f0f6df0SIvan Butygin   SmallVector<unsigned, 4> targetBitwidths;
4379f0f6df0SIvan Butygin };
4389f0f6df0SIvan Butygin 
4399f0f6df0SIvan Butygin /// Fold index_cast(index_cast(%arg: i8, index), i8) -> %arg
4409f0f6df0SIvan Butygin /// This pattern assumes all passed `targetBitwidths` are not wider than index
4419f0f6df0SIvan Butygin /// type.
4429bf79308SKrzysztof Drewniak template <typename CastOp>
4439bf79308SKrzysztof Drewniak struct FoldIndexCastChain final : OpRewritePattern<CastOp> {
4449f0f6df0SIvan Butygin   FoldIndexCastChain(MLIRContext *context, ArrayRef<unsigned> target)
4459bf79308SKrzysztof Drewniak       : OpRewritePattern<CastOp>(context), targetBitwidths(target) {}
4469f0f6df0SIvan Butygin 
4479bf79308SKrzysztof Drewniak   LogicalResult matchAndRewrite(CastOp op,
4489f0f6df0SIvan Butygin                                 PatternRewriter &rewriter) const override {
4499bf79308SKrzysztof Drewniak     auto srcOp = op.getIn().template getDefiningOp<CastOp>();
4509f0f6df0SIvan Butygin     if (!srcOp)
4519bf79308SKrzysztof Drewniak       return rewriter.notifyMatchFailure(op, "doesn't come from an index cast");
4529f0f6df0SIvan Butygin 
4539f0f6df0SIvan Butygin     Value src = srcOp.getIn();
4549f0f6df0SIvan Butygin     if (src.getType() != op.getType())
4559bf79308SKrzysztof Drewniak       return rewriter.notifyMatchFailure(op, "outer types don't match");
4569bf79308SKrzysztof Drewniak 
4579bf79308SKrzysztof Drewniak     if (!srcOp.getType().isIndex())
4589bf79308SKrzysztof Drewniak       return rewriter.notifyMatchFailure(op, "intermediate type isn't index");
4599f0f6df0SIvan Butygin 
4609f0f6df0SIvan Butygin     auto intType = dyn_cast<IntegerType>(op.getType());
4619f0f6df0SIvan Butygin     if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth()))
4629f0f6df0SIvan Butygin       return failure();
4639f0f6df0SIvan Butygin 
4649f0f6df0SIvan Butygin     rewriter.replaceOp(op, src);
4659f0f6df0SIvan Butygin     return success();
4669f0f6df0SIvan Butygin   }
4679f0f6df0SIvan Butygin 
4689f0f6df0SIvan Butygin private:
4699f0f6df0SIvan Butygin   SmallVector<unsigned, 4> targetBitwidths;
4709f0f6df0SIvan Butygin };
4719f0f6df0SIvan Butygin 
4729f0f6df0SIvan Butygin struct IntRangeOptimizationsPass final
4739f0f6df0SIvan Butygin     : arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
4741a867bf1SIvan Butygin 
4751a867bf1SIvan Butygin   void runOnOperation() override {
4761a867bf1SIvan Butygin     Operation *op = getOperation();
4771a867bf1SIvan Butygin     MLIRContext *ctx = op->getContext();
4781a867bf1SIvan Butygin     DataFlowSolver solver;
4791a867bf1SIvan Butygin     solver.load<DeadCodeAnalysis>();
4801a867bf1SIvan Butygin     solver.load<IntegerRangeAnalysis>();
4811a867bf1SIvan Butygin     if (failed(solver.initializeAndRun(op)))
4821a867bf1SIvan Butygin       return signalPassFailure();
4831a867bf1SIvan Butygin 
48478b3a004SFelix Schneider     DataFlowListener listener(solver);
48578b3a004SFelix Schneider 
4861a867bf1SIvan Butygin     RewritePatternSet patterns(ctx);
4871a867bf1SIvan Butygin     populateIntRangeOptimizationsPatterns(patterns, solver);
4881a867bf1SIvan Butygin 
48978b3a004SFelix Schneider     GreedyRewriteConfig config;
49078b3a004SFelix Schneider     config.listener = &listener;
49178b3a004SFelix Schneider 
492*09dfc571SJacques Pienaar     if (failed(applyPatternsGreedily(op, std::move(patterns), config)))
4931a867bf1SIvan Butygin       signalPassFailure();
4941a867bf1SIvan Butygin   }
4951a867bf1SIvan Butygin };
4969f0f6df0SIvan Butygin 
4979f0f6df0SIvan Butygin struct IntRangeNarrowingPass final
4989f0f6df0SIvan Butygin     : arith::impl::ArithIntRangeNarrowingBase<IntRangeNarrowingPass> {
4999f0f6df0SIvan Butygin   using ArithIntRangeNarrowingBase::ArithIntRangeNarrowingBase;
5009f0f6df0SIvan Butygin 
5019f0f6df0SIvan Butygin   void runOnOperation() override {
5029f0f6df0SIvan Butygin     Operation *op = getOperation();
5039f0f6df0SIvan Butygin     MLIRContext *ctx = op->getContext();
5049f0f6df0SIvan Butygin     DataFlowSolver solver;
5059f0f6df0SIvan Butygin     solver.load<DeadCodeAnalysis>();
5069f0f6df0SIvan Butygin     solver.load<IntegerRangeAnalysis>();
5079f0f6df0SIvan Butygin     if (failed(solver.initializeAndRun(op)))
5089f0f6df0SIvan Butygin       return signalPassFailure();
5099f0f6df0SIvan Butygin 
5109f0f6df0SIvan Butygin     DataFlowListener listener(solver);
5119f0f6df0SIvan Butygin 
5129f0f6df0SIvan Butygin     RewritePatternSet patterns(ctx);
5139f0f6df0SIvan Butygin     populateIntRangeNarrowingPatterns(patterns, solver, bitwidthsSupported);
5149f0f6df0SIvan Butygin 
5159f0f6df0SIvan Butygin     GreedyRewriteConfig config;
5169f0f6df0SIvan Butygin     // We specifically need bottom-up traversal as cmpi pattern needs range
5179f0f6df0SIvan Butygin     // data, attached to its original argument values.
5189f0f6df0SIvan Butygin     config.useTopDownTraversal = false;
5199f0f6df0SIvan Butygin     config.listener = &listener;
5209f0f6df0SIvan Butygin 
521*09dfc571SJacques Pienaar     if (failed(applyPatternsGreedily(op, std::move(patterns), config)))
5229f0f6df0SIvan Butygin       signalPassFailure();
5239f0f6df0SIvan Butygin   }
5249f0f6df0SIvan Butygin };
5251a867bf1SIvan Butygin } // namespace
5261a867bf1SIvan Butygin 
5271a867bf1SIvan Butygin void mlir::arith::populateIntRangeOptimizationsPatterns(
5281a867bf1SIvan Butygin     RewritePatternSet &patterns, DataFlowSolver &solver) {
52947229111SKrzysztof Drewniak   patterns.add<MaterializeKnownConstantValues, DeleteTrivialRem<RemSIOp>,
53047229111SKrzysztof Drewniak                DeleteTrivialRem<RemUIOp>>(patterns.getContext(), solver);
5311a867bf1SIvan Butygin }
5321a867bf1SIvan Butygin 
5339f0f6df0SIvan Butygin void mlir::arith::populateIntRangeNarrowingPatterns(
5349f0f6df0SIvan Butygin     RewritePatternSet &patterns, DataFlowSolver &solver,
5359f0f6df0SIvan Butygin     ArrayRef<unsigned> bitwidthsSupported) {
5369f0f6df0SIvan Butygin   patterns.add<NarrowElementwise, NarrowCmpI>(patterns.getContext(), solver,
5379f0f6df0SIvan Butygin                                               bitwidthsSupported);
5389bf79308SKrzysztof Drewniak   patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>,
5399bf79308SKrzysztof Drewniak                FoldIndexCastChain<arith::IndexCastOp>>(patterns.getContext(),
5409bf79308SKrzysztof Drewniak                                                        bitwidthsSupported);
5419f0f6df0SIvan Butygin }
5429f0f6df0SIvan Butygin 
5431a867bf1SIvan Butygin std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
5441a867bf1SIvan Butygin   return std::make_unique<IntRangeOptimizationsPass>();
5451a867bf1SIvan Butygin }
546