//===- IntRangeOptimizations.cpp - Optimizations based on integer ranges --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"

#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::arith {
#define GEN_PASS_DEF_ARITHINTRANGEOPTS
#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"

#define GEN_PASS_DEF_ARITHINTRANGENARROWING
#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
} // namespace mlir::arith

using namespace mlir;
using namespace mlir::arith;
using namespace mlir::dataflow;

static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
                                                  Value value) {
  auto *maybeInferredRange =
      solver.lookupState<IntegerValueRangeLattice>(value);
  if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
    return std::nullopt;
  const ConstantIntRanges &inferredRange =
      maybeInferredRange->getValue().getValue();
  return inferredRange.getConstantValue();
}

static void copyIntegerRange(DataFlowSolver &solver, Value oldVal,
                             Value newVal) {
  assert(oldVal.getType() == newVal.getType() &&
         "Can't copy integer ranges between different types");
  auto *oldState = solver.lookupState<IntegerValueRangeLattice>(oldVal);
  if (!oldState)
    return;
  (void)solver.getOrCreateState<IntegerValueRangeLattice>(newVal)->join(
      *oldState);
}

/// Patterned after SCCP.
static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
                                              PatternRewriter &rewriter,
                                              Value value) {
  if (value.use_empty())
    return failure();
  std::optional<APInt> maybeConstValue = getMaybeConstantValue(solver, value);
  if (!maybeConstValue.has_value())
    return failure();

  Type type = value.getType();
  Location loc = value.getLoc();
  Operation *maybeDefiningOp = value.getDefiningOp();
  Dialect *valueDialect =
      maybeDefiningOp ? maybeDefiningOp->getDialect()
                      : value.getParentRegion()->getParentOp()->getDialect();

  Attribute constAttr;
  if (auto shaped = dyn_cast<ShapedType>(type)) {
    constAttr = mlir::DenseIntElementsAttr::get(shaped, *maybeConstValue);
  } else {
    constAttr = rewriter.getIntegerAttr(type, *maybeConstValue);
  }
  Operation *constOp =
      valueDialect->materializeConstant(rewriter, constAttr, type, loc);
  // Fall back to arith.constant if the dialect materializer doesn't know what
  // to do with an integer constant.
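  // (E.g. a value defined as a block argument resolves `valueDialect` to its
  // parent op's dialect, which may not materialize integer constants at all;
  // illustrative scenario, not tied to a specific dialect.)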
  if (!constOp)
    constOp = rewriter.getContext()
                  ->getLoadedDialect<ArithDialect>()
                  ->materializeConstant(rewriter, constAttr, type, loc);
  if (!constOp)
    return failure();

  copyIntegerRange(solver, value, constOp->getResult(0));
  rewriter.replaceAllUsesWith(value, constOp->getResult(0));
  return success();
}

namespace {
class DataFlowListener : public RewriterBase::Listener {
public:
  DataFlowListener(DataFlowSolver &s) : s(s) {}

protected:
  void notifyOperationErased(Operation *op) override {
    s.eraseState(s.getProgramPointAfter(op));
    for (Value res : op->getResults())
      s.eraseState(res);
  }

  DataFlowSolver &s;
};

/// Rewrite any results of `op` that were inferred to be constant integers and
/// replace their uses with that constant. Return success() if all results
/// were thus replaced and the operation is erased. Also replace any block
/// arguments with their constant values.
struct MaterializeKnownConstantValues : public RewritePattern {
  MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s)
      : RewritePattern(Pattern::MatchAnyOpTypeTag(), /*benefit=*/1, context),
        solver(s) {}

  LogicalResult match(Operation *op) const override {
    if (matchPattern(op, m_Constant()))
      return failure();

    auto needsReplacing = [&](Value v) {
      return getMaybeConstantValue(solver, v).has_value() && !v.use_empty();
    };
    bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing);
    if (op->getNumRegions() == 0)
      return success(hasConstantResults);
    bool hasConstantRegionArgs = false;
    for (Region &region : op->getRegions()) {
      for (Block &block : region.getBlocks()) {
        hasConstantRegionArgs |=
            llvm::any_of(block.getArguments(), needsReplacing);
      }
    }
    return success(hasConstantResults || hasConstantRegionArgs);
  }

  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
    bool replacedAll = (op->getNumResults() != 0);
    for (Value v : op->getResults())
      replacedAll &=
          (succeeded(maybeReplaceWithConstant(solver, rewriter, v)) ||
           v.use_empty());
    if (replacedAll && isOpTriviallyDead(op)) {
      rewriter.eraseOp(op);
      return;
    }

    PatternRewriter::InsertionGuard guard(rewriter);
    for (Region &region : op->getRegions()) {
      for (Block &block : region.getBlocks()) {
        rewriter.setInsertionPointToStart(&block);
        for (BlockArgument &arg : block.getArguments()) {
          (void)maybeReplaceWithConstant(solver, rewriter, arg);
        }
      }
    }
  }

private:
  DataFlowSolver &solver;
};

template <typename RemOp>
struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
  DeleteTrivialRem(MLIRContext *context, DataFlowSolver &s)
      : OpRewritePattern<RemOp>(context), solver(s) {}

  LogicalResult matchAndRewrite(RemOp op,
                                PatternRewriter &rewriter) const override {
    Value lhs = op.getOperand(0);
    Value rhs = op.getOperand(1);
    auto maybeModulus = getConstantIntValue(rhs);
    if (!maybeModulus.has_value())
      return failure();
    int64_t modulus = *maybeModulus;
    if (modulus <= 0)
      return failure();
    auto *maybeLhsRange = solver.lookupState<IntegerValueRangeLattice>(lhs);
    if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized())
      return failure();
    const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue();
    const APInt &min = isa<RemUIOp>(op) ? lhsRange.umin() : lhsRange.smin();
    const APInt &max = isa<RemUIOp>(op) ? lhsRange.umax() : lhsRange.smax();
    // The minima and maxima here are given as closed ranges; we must be
    // strictly less than the modulus.
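    // For example (hypothetical ranges): with modulus 8 and `lhs` known to be
    // in [0, 7], both bounds pass the checks below and the rem folds away,
    // while `lhs` in [0, 8] fails the max.uge(modulus) check.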
    if (min.isNegative() || min.uge(modulus))
      return failure();
    if (max.isNegative() || max.uge(modulus))
      return failure();
    if (!min.ule(max))
      return failure();

    // With all those conditions out of the way, we know that this invocation
    // of a remainder is a noop because the input is strictly within the range
    // [0, modulus), so get rid of it.
    rewriter.replaceOp(op, ValueRange{lhs});
    return success();
  }

private:
  DataFlowSolver &solver;
};

/// Gather ranges for all the values in `values`. Appends to the existing
/// vector.
static LogicalResult collectRanges(DataFlowSolver &solver, ValueRange values,
                                   SmallVectorImpl<ConstantIntRanges> &ranges) {
  for (Value val : values) {
    auto *maybeInferredRange =
        solver.lookupState<IntegerValueRangeLattice>(val);
    if (!maybeInferredRange ||
        maybeInferredRange->getValue().isUninitialized())
      return failure();

    const ConstantIntRanges &inferredRange =
        maybeInferredRange->getValue().getValue();
    ranges.push_back(inferredRange);
  }
  return success();
}

/// Return int type truncated to `targetBitwidth`. If `srcType` is shaped,
/// return a shaped type as well.
static Type getTargetType(Type srcType, unsigned targetBitwidth) {
  auto dstType = IntegerType::get(srcType.getContext(), targetBitwidth);
  if (auto shaped = dyn_cast<ShapedType>(srcType))
    return shaped.clone(dstType);

  assert(srcType.isIntOrIndex() && "Invalid src type");
  return dstType;
}

namespace {
// Enum for tracking which type of truncation should be performed
// to narrow an operation, if any.
enum class CastKind : uint8_t { None, Signed, Unsigned, Both };
} // namespace

/// If the values within `range` can be represented using only `targetWidth`
/// bits, return the kind of truncation needed to preserve that property.
///
/// This check relies on the fact that the signed and unsigned ranges are both
/// always correct, but that one might be an approximation of the other,
/// so we want to use the correct truncation operation.
static CastKind checkTruncatability(const ConstantIntRanges &range,
                                    unsigned targetWidth) {
  unsigned srcWidth = range.smin().getBitWidth();
  if (srcWidth <= targetWidth)
    return CastKind::None;
  unsigned removedWidth = srcWidth - targetWidth;
  // The sign bits need to extend into the sign bit of the target width. For
  // example, if we're truncating 64 bits to 32, we need 64 - 32 + 1 = 33 sign
  // bits.
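  // (As a concrete illustration: the i64 value -1 has 64 sign bits, so it can
  // be sign-truncated to any width, while 2^40 has only 23 sign bits, fewer
  // than the 33 required, and so cannot be sign-truncated to i32.)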
  bool canTruncateSigned =
      range.smin().getNumSignBits() >= (removedWidth + 1) &&
      range.smax().getNumSignBits() >= (removedWidth + 1);
  bool canTruncateUnsigned = range.umin().countLeadingZeros() >= removedWidth &&
                             range.umax().countLeadingZeros() >= removedWidth;

  if (canTruncateSigned && canTruncateUnsigned)
    return CastKind::Both;
  if (canTruncateSigned)
    return CastKind::Signed;
  if (canTruncateUnsigned)
    return CastKind::Unsigned;
  return CastKind::None;
}

static CastKind mergeCastKinds(CastKind lhs, CastKind rhs) {
  if (lhs == CastKind::None || rhs == CastKind::None)
    return CastKind::None;
  if (lhs == CastKind::Both)
    return rhs;
  if (rhs == CastKind::Both)
    return lhs;
  if (lhs == rhs)
    return lhs;
  return CastKind::None;
}

static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType,
                    CastKind castKind) {
  Type srcType = src.getType();
  assert(isa<VectorType>(srcType) == isa<VectorType>(dstType) &&
         "Mixing vector and non-vector types");
  assert(castKind != CastKind::None && "Can't cast when casting isn't allowed");
  Type srcElemType = getElementTypeOrSelf(srcType);
  Type dstElemType = getElementTypeOrSelf(dstType);
  assert(srcElemType.isIntOrIndex() && "Invalid src type");
  assert(dstElemType.isIntOrIndex() && "Invalid dst type");
  if (srcType == dstType)
    return src;

  if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType)) {
    if (castKind == CastKind::Signed)
      return builder.create<arith::IndexCastOp>(loc, dstType, src);
    return builder.create<arith::IndexCastUIOp>(loc, dstType, src);
  }

  auto srcInt = cast<IntegerType>(srcElemType);
  auto dstInt = cast<IntegerType>(dstElemType);
  if (dstInt.getWidth() < srcInt.getWidth())
    return builder.create<arith::TruncIOp>(loc, dstType, src);

  if (castKind == CastKind::Signed)
    return builder.create<arith::ExtSIOp>(loc, dstType, src);
  return builder.create<arith::ExtUIOp>(loc, dstType, src);
}

struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
  NarrowElementwise(MLIRContext *context, DataFlowSolver &s,
                    ArrayRef<unsigned> target)
      : OpTraitRewritePattern(context), solver(s), targetBitwidths(target) {}

  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() == 0)
      return rewriter.notifyMatchFailure(op, "can't narrow resultless op");

    SmallVector<ConstantIntRanges> ranges;
    if (failed(collectRanges(solver, op->getOperands(), ranges)))
      return rewriter.notifyMatchFailure(op, "input without specified range");
    if (failed(collectRanges(solver, op->getResults(), ranges)))
      return rewriter.notifyMatchFailure(op, "output without specified range");

    Type srcType = op->getResult(0).getType();
    if (!llvm::all_equal(op->getResultTypes()))
      return rewriter.notifyMatchFailure(op, "mismatched result types");
    if (op->getNumOperands() == 0 ||
        !llvm::all_of(op->getOperandTypes(),
                      [=](Type t) { return t == srcType; }))
      return rewriter.notifyMatchFailure(
          op, "no operands or operand types don't match result type");

    for (unsigned targetBitwidth : targetBitwidths) {
      CastKind castKind = CastKind::Both;
      for (const ConstantIntRanges &range : ranges) {
        castKind = mergeCastKinds(castKind,
                                  checkTruncatability(range, targetBitwidth));
        if (castKind == CastKind::None)
          break;
      }
      if (castKind == CastKind::None)
        continue;
      Type targetType = getTargetType(srcType, targetBitwidth);
      if (targetType == srcType)
        continue;

      Location loc = op->getLoc();
      IRMapping mapping;
      for (auto [arg, argRange] : llvm::zip_first(op->getOperands(), ranges)) {
        CastKind argCastKind = castKind;
        // When dealing with `index` values, preserve non-negativity in the
        // index_casts by preferring the unsigned cast when the argument is
        // known non-negative, since that fact can't be recovered from a
        // signed index_cast later.
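        // (Illustrative: an `index` argument whose smin() is 0 is upgraded to
        // CastKind::Both below, so doCast emits index_castui rather than
        // index_cast.)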
        if (argCastKind == CastKind::Signed && argRange.smin().isNonNegative())
          argCastKind = CastKind::Both;
        Value newArg = doCast(rewriter, loc, arg, targetType, argCastKind);
        mapping.map(arg, newArg);
      }

      Operation *newOp = rewriter.clone(*op, mapping);
      rewriter.modifyOpInPlace(newOp, [&]() {
        for (OpResult res : newOp->getResults()) {
          res.setType(targetType);
        }
      });
      SmallVector<Value> newResults;
      for (auto [newRes, oldRes] :
           llvm::zip_equal(newOp->getResults(), op->getResults())) {
        Value castBack = doCast(rewriter, loc, newRes, srcType, castKind);
        copyIntegerRange(solver, oldRes, castBack);
        newResults.push_back(castBack);
      }
      rewriter.replaceOp(op, newResults);
      return success();
    }
    return failure();
  }

private:
  DataFlowSolver &solver;
  SmallVector<unsigned, 4> targetBitwidths;
};

struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
  NarrowCmpI(MLIRContext *context, DataFlowSolver &s,
             ArrayRef<unsigned> target)
      : OpRewritePattern(context), solver(s), targetBitwidths(target) {}

  LogicalResult matchAndRewrite(arith::CmpIOp op,
                                PatternRewriter &rewriter) const override {
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();

    SmallVector<ConstantIntRanges> ranges;
    if (failed(collectRanges(solver, op.getOperands(), ranges)))
      return failure();
    const ConstantIntRanges &lhsRange = ranges[0];
    const ConstantIntRanges &rhsRange = ranges[1];

    Type srcType = lhs.getType();
    for (unsigned targetBitwidth : targetBitwidths) {
      CastKind lhsCastKind = checkTruncatability(lhsRange, targetBitwidth);
      CastKind rhsCastKind = checkTruncatability(rhsRange, targetBitwidth);
      CastKind castKind = mergeCastKinds(lhsCastKind, rhsCastKind);
      // Note: this includes the case target width > src width.
      if (castKind == CastKind::None)
        continue;
      Type targetType = getTargetType(srcType, targetBitwidth);
      if (targetType == srcType)
        continue;

      Location loc = op->getLoc();
      IRMapping mapping;
      Value lhsCast = doCast(rewriter, loc, lhs, targetType, lhsCastKind);
      Value rhsCast = doCast(rewriter, loc, rhs, targetType, rhsCastKind);
      mapping.map(lhs, lhsCast);
      mapping.map(rhs, rhsCast);
      Operation *newOp = rewriter.clone(*op, mapping);
      copyIntegerRange(solver, op.getResult(), newOp->getResult(0));
      rewriter.replaceOp(op, newOp->getResults());
      return success();
    }
    return failure();
  }

private:
  DataFlowSolver &solver;
  SmallVector<unsigned, 4> targetBitwidths;
};

/// Fold index_cast(index_cast(%arg: i8, index), i8) -> %arg.
/// This pattern assumes all passed `targetBitwidths` are not wider than the
/// index type.
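/// For example (illustrative IR):
///   %idx = arith.index_castui %arg : i8 to index
///   %res = arith.index_castui %idx : index to i8
/// folds to just `%arg`.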
template <typename CastOp>
struct FoldIndexCastChain final : OpRewritePattern<CastOp> {
  FoldIndexCastChain(MLIRContext *context, ArrayRef<unsigned> target)
      : OpRewritePattern<CastOp>(context), targetBitwidths(target) {}

  LogicalResult matchAndRewrite(CastOp op,
                                PatternRewriter &rewriter) const override {
    auto srcOp = op.getIn().template getDefiningOp<CastOp>();
    if (!srcOp)
      return rewriter.notifyMatchFailure(op, "doesn't come from an index cast");

    Value src = srcOp.getIn();
    if (src.getType() != op.getType())
      return rewriter.notifyMatchFailure(op, "outer types don't match");

    if (!srcOp.getType().isIndex())
      return rewriter.notifyMatchFailure(op, "intermediate type isn't index");

    auto intType = dyn_cast<IntegerType>(op.getType());
    if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth()))
      return failure();

    rewriter.replaceOp(op, src);
    return success();
  }

private:
  SmallVector<unsigned, 4> targetBitwidths;
};

struct IntRangeOptimizationsPass final
    : arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {

  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *ctx = op->getContext();
    DataFlowSolver solver;
    solver.load<DeadCodeAnalysis>();
    solver.load<IntegerRangeAnalysis>();
    if (failed(solver.initializeAndRun(op)))
      return signalPassFailure();

    DataFlowListener listener(solver);

    RewritePatternSet patterns(ctx);
    populateIntRangeOptimizationsPatterns(patterns, solver);

    GreedyRewriteConfig config;
    config.listener = &listener;

    if (failed(applyPatternsGreedily(op, std::move(patterns), config)))
      signalPassFailure();
  }
};

struct IntRangeNarrowingPass final
    : arith::impl::ArithIntRangeNarrowingBase<IntRangeNarrowingPass> {
  using ArithIntRangeNarrowingBase::ArithIntRangeNarrowingBase;

  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *ctx = op->getContext();
    DataFlowSolver solver;
    solver.load<DeadCodeAnalysis>();
    solver.load<IntegerRangeAnalysis>();
    if (failed(solver.initializeAndRun(op)))
      return signalPassFailure();

    DataFlowListener listener(solver);

    RewritePatternSet patterns(ctx);
    populateIntRangeNarrowingPatterns(patterns, solver, bitwidthsSupported);

    GreedyRewriteConfig config;
    // We specifically need bottom-up traversal, as the cmpi pattern needs
    // range data attached to its original argument values.
    config.useTopDownTraversal = false;
    config.listener = &listener;

    if (failed(applyPatternsGreedily(op, std::move(patterns), config)))
      signalPassFailure();
  }
};
} // namespace

void mlir::arith::populateIntRangeOptimizationsPatterns(
    RewritePatternSet &patterns, DataFlowSolver &solver) {
  patterns.add<MaterializeKnownConstantValues, DeleteTrivialRem<RemSIOp>,
               DeleteTrivialRem<RemUIOp>>(patterns.getContext(), solver);
}

void mlir::arith::populateIntRangeNarrowingPatterns(
    RewritePatternSet &patterns, DataFlowSolver &solver,
    ArrayRef<unsigned> bitwidthsSupported) {
  patterns.add<NarrowElementwise, NarrowCmpI>(patterns.getContext(), solver,
                                              bitwidthsSupported);
  patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>,
               FoldIndexCastChain<arith::IndexCastOp>>(patterns.getContext(),
                                                       bitwidthsSupported);
}

std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
  return std::make_unique<IntRangeOptimizationsPass>();
}
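// Example invocations (pass flag and option names assumed from the usual
// tablegen registration; verify against the arith Passes.td):
//   mlir-opt --int-range-optimizations input.mlir
//   mlir-opt --arith-int-range-narrowing="int-bitwidths-supported=32" \
//     input.mlir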