xref: /llvm-project/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //===- IntRangeOptimizations.cpp - Optimizations based on integer ranges --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/Analysis/DataFlowFramework.h"
12 #include "mlir/Dialect/Arith/Transforms/Passes.h"
13 
14 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
15 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Utils/StaticValueUtils.h"
18 #include "mlir/IR/IRMapping.h"
19 #include "mlir/IR/Matchers.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/IR/TypeUtilities.h"
22 #include "mlir/Interfaces/SideEffectInterfaces.h"
23 #include "mlir/Transforms/FoldUtils.h"
24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25 
namespace mlir::arith {
// Pull in the tablegen-generated pass base classes (ArithIntRangeOptsBase and
// ArithIntRangeNarrowingBase) used by the pass structs below.
#define GEN_PASS_DEF_ARITHINTRANGEOPTS
#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"

#define GEN_PASS_DEF_ARITHINTRANGENARROWING
#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
} // namespace mlir::arith
33 
34 using namespace mlir;
35 using namespace mlir::arith;
36 using namespace mlir::dataflow;
37 
38 static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
39                                                   Value value) {
40   auto *maybeInferredRange =
41       solver.lookupState<IntegerValueRangeLattice>(value);
42   if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
43     return std::nullopt;
44   const ConstantIntRanges &inferredRange =
45       maybeInferredRange->getValue().getValue();
46   return inferredRange.getConstantValue();
47 }
48 
49 static void copyIntegerRange(DataFlowSolver &solver, Value oldVal,
50                              Value newVal) {
51   assert(oldVal.getType() == newVal.getType() &&
52          "Can't copy integer ranges between different types");
53   auto *oldState = solver.lookupState<IntegerValueRangeLattice>(oldVal);
54   if (!oldState)
55     return;
56   (void)solver.getOrCreateState<IntegerValueRangeLattice>(newVal)->join(
57       *oldState);
58 }
59 
/// Patterned after SCCP: if the analysis in `solver` proves `value` is a
/// compile-time constant, materialize that constant and replace all uses of
/// `value` with it. Fails when `value` has no uses, is not known constant, or
/// no dialect can materialize the constant attribute.
static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
                                              PatternRewriter &rewriter,
                                              Value value) {
  if (value.use_empty())
    return failure();
  std::optional<APInt> maybeConstValue = getMaybeConstantValue(solver, value);
  if (!maybeConstValue.has_value())
    return failure();

  Type type = value.getType();
  Location loc = value.getLoc();
  Operation *maybeDefiningOp = value.getDefiningOp();
  // Block arguments have no defining op; fall back to the dialect of the op
  // owning the enclosing region.
  Dialect *valueDialect =
      maybeDefiningOp ? maybeDefiningOp->getDialect()
                      : value.getParentRegion()->getParentOp()->getDialect();

  // For shaped (vector/tensor) types, materialize a splat of the constant.
  Attribute constAttr;
  if (auto shaped = dyn_cast<ShapedType>(type)) {
    constAttr = mlir::DenseIntElementsAttr::get(shaped, *maybeConstValue);
  } else {
    constAttr = rewriter.getIntegerAttr(type, *maybeConstValue);
  }
  // Prefer the owning dialect's materializer so dialect-specific constant ops
  // are produced where possible.
  Operation *constOp =
      valueDialect->materializeConstant(rewriter, constAttr, type, loc);
  // Fall back to arith.constant if the dialect materializer doesn't know what
  // to do with an integer constant.
  if (!constOp)
    constOp = rewriter.getContext()
                  ->getLoadedDialect<ArithDialect>()
                  ->materializeConstant(rewriter, constAttr, type, loc);
  if (!constOp)
    return failure();

  // Keep the analysis state alive on the replacement value before rewiring
  // uses.
  copyIntegerRange(solver, value, constOp->getResult(0));
  rewriter.replaceAllUsesWith(value, constOp->getResult(0));
  return success();
}
98 
99 namespace {
/// Rewriter listener that keeps the dataflow solver consistent with IR
/// mutations: when an operation is erased, drop the analysis state attached
/// to it and to its results so stale lattice values cannot be queried later.
class DataFlowListener : public RewriterBase::Listener {
public:
  DataFlowListener(DataFlowSolver &s) : s(s) {}

protected:
  void notifyOperationErased(Operation *op) override {
    // Erase the program-point state after the op as well as the per-result
    // value states.
    s.eraseState(s.getProgramPointAfter(op));
    for (Value res : op->getResults())
      s.eraseState(res);
  }

  DataFlowSolver &s;
};
113 
/// Rewrite any results of `op` that were inferred to be constant integers and
/// replace their uses with that constant. The op is erased only if all
/// results were thus replaced and the op is otherwise trivially dead. Also
/// replace any block arguments with their constant values.
struct MaterializeKnownConstantValues : public RewritePattern {
  MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s)
      : RewritePattern(Pattern::MatchAnyOpTypeTag(), /*benefit=*/1, context),
        solver(s) {}

  LogicalResult match(Operation *op) const override {
    // Constants are already in their final form; replacing them would loop.
    if (matchPattern(op, m_Constant()))
      return failure();

    // A value is worth replacing when the analysis proves it constant and it
    // still has uses.
    auto needsReplacing = [&](Value v) {
      return getMaybeConstantValue(solver, v).has_value() && !v.use_empty();
    };
    bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing);
    if (op->getNumRegions() == 0)
      return success(hasConstantResults);
    // Ops with regions may also carry constant block arguments.
    bool hasConstantRegionArgs = false;
    for (Region &region : op->getRegions()) {
      for (Block &block : region.getBlocks()) {
        hasConstantRegionArgs |=
            llvm::any_of(block.getArguments(), needsReplacing);
      }
    }
    return success(hasConstantResults || hasConstantRegionArgs);
  }

  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
    // Track whether every result ended up constant-replaced (or was already
    // unused); resultless ops are never "fully replaced".
    bool replacedAll = (op->getNumResults() != 0);
    for (Value v : op->getResults())
      replacedAll &=
          (succeeded(maybeReplaceWithConstant(solver, rewriter, v)) ||
           v.use_empty());
    if (replacedAll && isOpTriviallyDead(op)) {
      rewriter.eraseOp(op);
      return;
    }

    // Otherwise, materialize constants for any provably-constant block
    // arguments at the top of their blocks.
    PatternRewriter::InsertionGuard guard(rewriter);
    for (Region &region : op->getRegions()) {
      for (Block &block : region.getBlocks()) {
        rewriter.setInsertionPointToStart(&block);
        for (BlockArgument &arg : block.getArguments()) {
          (void)maybeReplaceWithConstant(solver, rewriter, arg);
        }
      }
    }
  }

private:
  DataFlowSolver &solver;
};
168 
/// Delete `arith.remsi`/`arith.remui` ops whose LHS is already known to lie
/// in [0, modulus) for a positive constant modulus — the remainder is then
/// the identity and the op can be replaced by its LHS.
template <typename RemOp>
struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
  DeleteTrivialRem(MLIRContext *context, DataFlowSolver &s)
      : OpRewritePattern<RemOp>(context), solver(s) {}

  LogicalResult matchAndRewrite(RemOp op,
                                PatternRewriter &rewriter) const override {
    Value lhs = op.getOperand(0);
    Value rhs = op.getOperand(1);
    // Only a known, strictly-positive constant modulus qualifies.
    auto maybeModulus = getConstantIntValue(rhs);
    if (!maybeModulus.has_value())
      return failure();
    int64_t modulus = *maybeModulus;
    if (modulus <= 0)
      return failure();
    auto *maybeLhsRange = solver.lookupState<IntegerValueRangeLattice>(lhs);
    if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized())
      return failure();
    const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue();
    // Pick the signed or unsigned bounds to match the op's semantics.
    const APInt &min = isa<RemUIOp>(op) ? lhsRange.umin() : lhsRange.smin();
    const APInt &max = isa<RemUIOp>(op) ? lhsRange.umax() : lhsRange.smax();
    // The minima and maxima here are given as closed ranges, we must be
    // strictly less than the modulus.
    if (min.isNegative() || min.uge(modulus))
      return failure();
    if (max.isNegative() || max.uge(modulus))
      return failure();
    if (!min.ule(max))
      return failure();

    // With all those conditions out of the way, we know that this invocation of
    // a remainder is a noop because the input is strictly within the range
    // [0, modulus), so get rid of it.
    rewriter.replaceOp(op, ValueRange{lhs});
    return success();
  }

private:
  DataFlowSolver &solver;
};
209 
210 /// Gather ranges for all the values in `values`. Appends to the existing
211 /// vector.
212 static LogicalResult collectRanges(DataFlowSolver &solver, ValueRange values,
213                                    SmallVectorImpl<ConstantIntRanges> &ranges) {
214   for (Value val : values) {
215     auto *maybeInferredRange =
216         solver.lookupState<IntegerValueRangeLattice>(val);
217     if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
218       return failure();
219 
220     const ConstantIntRanges &inferredRange =
221         maybeInferredRange->getValue().getValue();
222     ranges.push_back(inferredRange);
223   }
224   return success();
225 }
226 
227 /// Return int type truncated to `targetBitwidth`. If `srcType` is shaped,
228 /// return shaped type as well.
229 static Type getTargetType(Type srcType, unsigned targetBitwidth) {
230   auto dstType = IntegerType::get(srcType.getContext(), targetBitwidth);
231   if (auto shaped = dyn_cast<ShapedType>(srcType))
232     return shaped.clone(dstType);
233 
234   assert(srcType.isIntOrIndex() && "Invalid src type");
235   return dstType;
236 }
237 
namespace {
// Enum for tracking which type of truncation should be performed
// to narrow an operation, if any. `Both` means either a signed or an
// unsigned truncation would preserve the value.
enum class CastKind : uint8_t { None, Signed, Unsigned, Both };
} // namespace
243 
244 /// If the values within `range` can be represented using only `width` bits,
245 /// return the kind of truncation needed to preserve that property.
246 ///
247 /// This check relies on the fact that the signed and unsigned ranges are both
248 /// always correct, but that one might be an approximation of the other,
249 /// so we want to use the correct truncation operation.
250 static CastKind checkTruncatability(const ConstantIntRanges &range,
251                                     unsigned targetWidth) {
252   unsigned srcWidth = range.smin().getBitWidth();
253   if (srcWidth <= targetWidth)
254     return CastKind::None;
255   unsigned removedWidth = srcWidth - targetWidth;
256   // The sign bits need to extend into the sign bit of the target width. For
257   // example, if we're truncating 64 bits to 32, we need 64 - 32 + 1 = 33 sign
258   // bits.
259   bool canTruncateSigned =
260       range.smin().getNumSignBits() >= (removedWidth + 1) &&
261       range.smax().getNumSignBits() >= (removedWidth + 1);
262   bool canTruncateUnsigned = range.umin().countLeadingZeros() >= removedWidth &&
263                              range.umax().countLeadingZeros() >= removedWidth;
264   if (canTruncateSigned && canTruncateUnsigned)
265     return CastKind::Both;
266   if (canTruncateSigned)
267     return CastKind::Signed;
268   if (canTruncateUnsigned)
269     return CastKind::Unsigned;
270   return CastKind::None;
271 }
272 
273 static CastKind mergeCastKinds(CastKind lhs, CastKind rhs) {
274   if (lhs == CastKind::None || rhs == CastKind::None)
275     return CastKind::None;
276   if (lhs == CastKind::Both)
277     return rhs;
278   if (rhs == CastKind::Both)
279     return lhs;
280   if (lhs == rhs)
281     return lhs;
282   return CastKind::None;
283 }
284 
/// Cast `src` to `dstType`, selecting the cast op from the element types and
/// from `castKind` (which decides between the signed and unsigned variants;
/// `Both` uses the unsigned form). Returns `src` unchanged when the types
/// already match.
static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType,
                    CastKind castKind) {
  Type srcType = src.getType();
  assert(isa<VectorType>(srcType) == isa<VectorType>(dstType) &&
         "Mixing vector and non-vector types");
  assert(castKind != CastKind::None && "Can't cast when casting isn't allowed");
  Type srcElemType = getElementTypeOrSelf(srcType);
  Type dstElemType = getElementTypeOrSelf(dstType);
  assert(srcElemType.isIntOrIndex() && "Invalid src type");
  assert(dstElemType.isIntOrIndex() && "Invalid dst type");
  if (srcType == dstType)
    return src;

  // Casts to or from `index` must use the index_cast family.
  if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType)) {
    if (castKind == CastKind::Signed)
      return builder.create<arith::IndexCastOp>(loc, dstType, src);
    return builder.create<arith::IndexCastUIOp>(loc, dstType, src);
  }

  // Integer-to-integer: truncate when narrowing, otherwise extend according
  // to `castKind`.
  auto srcInt = cast<IntegerType>(srcElemType);
  auto dstInt = cast<IntegerType>(dstElemType);
  if (dstInt.getWidth() < srcInt.getWidth())
    return builder.create<arith::TruncIOp>(loc, dstType, src);

  if (castKind == CastKind::Signed)
    return builder.create<arith::ExtSIOp>(loc, dstType, src);
  return builder.create<arith::ExtUIOp>(loc, dstType, src);
}
313 
/// Narrow elementwise ops to a smaller bitwidth when the range analysis
/// proves every operand and result fits: operands are cast down to the
/// target type, the op is cloned at that type, and results are cast back up
/// to the original type.
struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
  NarrowElementwise(MLIRContext *context, DataFlowSolver &s,
                    ArrayRef<unsigned> target)
      : OpTraitRewritePattern(context), solver(s), targetBitwidths(target) {}

  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() == 0)
      return rewriter.notifyMatchFailure(op, "can't narrow resultless op");

    // Every operand AND result must have an inferred range — `ranges` holds
    // operand ranges first, then result ranges.
    SmallVector<ConstantIntRanges> ranges;
    if (failed(collectRanges(solver, op->getOperands(), ranges)))
      return rewriter.notifyMatchFailure(op, "input without specified range");
    if (failed(collectRanges(solver, op->getResults(), ranges)))
      return rewriter.notifyMatchFailure(op, "output without specified range");

    // A single target type is used, so all operand/result types must agree.
    Type srcType = op->getResult(0).getType();
    if (!llvm::all_equal(op->getResultTypes()))
      return rewriter.notifyMatchFailure(op, "mismatched result types");
    if (op->getNumOperands() == 0 ||
        !llvm::all_of(op->getOperandTypes(),
                      [=](Type t) { return t == srcType; }))
      return rewriter.notifyMatchFailure(
          op, "no operands or operand types don't match result type");

    // Try the configured bitwidths in order; rewrite with the first one all
    // values can be truncated to.
    for (unsigned targetBitwidth : targetBitwidths) {
      CastKind castKind = CastKind::Both;
      for (const ConstantIntRanges &range : ranges) {
        castKind = mergeCastKinds(castKind,
                                  checkTruncatability(range, targetBitwidth));
        if (castKind == CastKind::None)
          break;
      }
      if (castKind == CastKind::None)
        continue;
      Type targetType = getTargetType(srcType, targetBitwidth);
      if (targetType == srcType)
        continue;

      Location loc = op->getLoc();
      IRMapping mapping;
      // Only operand ranges are consulted here, hence zip_first.
      for (auto [arg, argRange] : llvm::zip_first(op->getOperands(), ranges)) {
        CastKind argCastKind = castKind;
        // When dealing with `index` values, preserve non-negativity in the
        // index_casts since we can't recover this in unsigned when equivalent.
        if (argCastKind == CastKind::Signed && argRange.smin().isNonNegative())
          argCastKind = CastKind::Both;
        Value newArg = doCast(rewriter, loc, arg, targetType, argCastKind);
        mapping.map(arg, newArg);
      }

      // Clone the op onto the narrowed operands and retype its results.
      Operation *newOp = rewriter.clone(*op, mapping);
      rewriter.modifyOpInPlace(newOp, [&]() {
        for (OpResult res : newOp->getResults()) {
          res.setType(targetType);
        }
      });
      // Cast each narrow result back to the original type, keeping the
      // analysis state alive on the replacement values.
      SmallVector<Value> newResults;
      for (auto [newRes, oldRes] :
           llvm::zip_equal(newOp->getResults(), op->getResults())) {
        Value castBack = doCast(rewriter, loc, newRes, srcType, castKind);
        copyIntegerRange(solver, oldRes, castBack);
        newResults.push_back(castBack);
      }

      rewriter.replaceOp(op, newResults);
      return success();
    }
    return failure();
  }

private:
  DataFlowSolver &solver;
  SmallVector<unsigned, 4> targetBitwidths;
};
390 
/// Narrow `arith.cmpi` operands to a smaller bitwidth when both operands'
/// ranges prove the truncation is value-preserving. The comparison is cloned
/// onto the narrowed operands; its i1 result type is unchanged, so no
/// cast-back is needed.
struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
  NarrowCmpI(MLIRContext *context, DataFlowSolver &s, ArrayRef<unsigned> target)
      : OpRewritePattern(context), solver(s), targetBitwidths(target) {}

  LogicalResult matchAndRewrite(arith::CmpIOp op,
                                PatternRewriter &rewriter) const override {
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();

    SmallVector<ConstantIntRanges> ranges;
    if (failed(collectRanges(solver, op.getOperands(), ranges)))
      return failure();
    const ConstantIntRanges &lhsRange = ranges[0];
    const ConstantIntRanges &rhsRange = ranges[1];

    Type srcType = lhs.getType();
    for (unsigned targetBitwidth : targetBitwidths) {
      // Each side may use its own truncation kind, but the merged kind must
      // be non-None so that a consistent narrowing exists for the compare.
      CastKind lhsCastKind = checkTruncatability(lhsRange, targetBitwidth);
      CastKind rhsCastKind = checkTruncatability(rhsRange, targetBitwidth);
      CastKind castKind = mergeCastKinds(lhsCastKind, rhsCastKind);
      // Note: this includes target width > src width.
      if (castKind == CastKind::None)
        continue;

      Type targetType = getTargetType(srcType, targetBitwidth);
      if (targetType == srcType)
        continue;

      Location loc = op->getLoc();
      IRMapping mapping;
      Value lhsCast = doCast(rewriter, loc, lhs, targetType, lhsCastKind);
      Value rhsCast = doCast(rewriter, loc, rhs, targetType, rhsCastKind);
      mapping.map(lhs, lhsCast);
      mapping.map(rhs, rhsCast);

      // Clone the compare onto the narrow operands and carry over the range
      // state of the (i1) result.
      Operation *newOp = rewriter.clone(*op, mapping);
      copyIntegerRange(solver, op.getResult(), newOp->getResult(0));
      rewriter.replaceOp(op, newOp->getResults());
      return success();
    }
    return failure();
  }

private:
  DataFlowSolver &solver;
  SmallVector<unsigned, 4> targetBitwidths;
};
438 
/// Fold index_cast(index_cast(%arg: i8, index), i8) -> %arg
/// This pattern assumes all passed `targetBitwidths` are not wider than index
/// type.
template <typename CastOp>
struct FoldIndexCastChain final : OpRewritePattern<CastOp> {
  FoldIndexCastChain(MLIRContext *context, ArrayRef<unsigned> target)
      : OpRewritePattern<CastOp>(context), targetBitwidths(target) {}

  LogicalResult matchAndRewrite(CastOp op,
                                PatternRewriter &rewriter) const override {
    // The input must itself come from the same kind of cast.
    auto srcOp = op.getIn().template getDefiningOp<CastOp>();
    if (!srcOp)
      return rewriter.notifyMatchFailure(op, "doesn't come from an index cast");

    // Round-trip only: the innermost value's type must match the outer
    // result type.
    Value src = srcOp.getIn();
    if (src.getType() != op.getType())
      return rewriter.notifyMatchFailure(op, "outer types don't match");

    if (!srcOp.getType().isIndex())
      return rewriter.notifyMatchFailure(op, "intermediate type isn't index");

    // Only fold chains at bitwidths this narrowing run introduced.
    auto intType = dyn_cast<IntegerType>(op.getType());
    if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth()))
      return failure();

    rewriter.replaceOp(op, src);
    return success();
  }

private:
  SmallVector<unsigned, 4> targetBitwidths;
};
471 
/// Pass driver: runs dead-code + integer-range analysis, then greedily
/// applies the constant-materialization and trivial-remainder patterns while
/// a listener keeps the solver state in sync with IR erasures.
struct IntRangeOptimizationsPass final
    : arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {

  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *ctx = op->getContext();
    DataFlowSolver solver;
    // DeadCodeAnalysis is required for IntegerRangeAnalysis to see live
    // control flow.
    solver.load<DeadCodeAnalysis>();
    solver.load<IntegerRangeAnalysis>();
    if (failed(solver.initializeAndRun(op)))
      return signalPassFailure();

    DataFlowListener listener(solver);

    RewritePatternSet patterns(ctx);
    populateIntRangeOptimizationsPatterns(patterns, solver);

    GreedyRewriteConfig config;
    config.listener = &listener;

    if (failed(applyPatternsGreedily(op, std::move(patterns), config)))
      signalPassFailure();
  }
};
496 
/// Pass driver for integer narrowing: runs the range analysis, then greedily
/// applies the narrowing patterns using the pass's `bitwidthsSupported`
/// option.
struct IntRangeNarrowingPass final
    : arith::impl::ArithIntRangeNarrowingBase<IntRangeNarrowingPass> {
  using ArithIntRangeNarrowingBase::ArithIntRangeNarrowingBase;

  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *ctx = op->getContext();
    DataFlowSolver solver;
    solver.load<DeadCodeAnalysis>();
    solver.load<IntegerRangeAnalysis>();
    if (failed(solver.initializeAndRun(op)))
      return signalPassFailure();

    DataFlowListener listener(solver);

    RewritePatternSet patterns(ctx);
    populateIntRangeNarrowingPatterns(patterns, solver, bitwidthsSupported);

    GreedyRewriteConfig config;
    // We specifically need bottom-up traversal as cmpi pattern needs range
    // data, attached to its original argument values.
    config.useTopDownTraversal = false;
    config.listener = &listener;

    if (failed(applyPatternsGreedily(op, std::move(patterns), config)))
      signalPassFailure();
  }
};
525 } // namespace
526 
527 void mlir::arith::populateIntRangeOptimizationsPatterns(
528     RewritePatternSet &patterns, DataFlowSolver &solver) {
529   patterns.add<MaterializeKnownConstantValues, DeleteTrivialRem<RemSIOp>,
530                DeleteTrivialRem<RemUIOp>>(patterns.getContext(), solver);
531 }
532 
533 void mlir::arith::populateIntRangeNarrowingPatterns(
534     RewritePatternSet &patterns, DataFlowSolver &solver,
535     ArrayRef<unsigned> bitwidthsSupported) {
536   patterns.add<NarrowElementwise, NarrowCmpI>(patterns.getContext(), solver,
537                                               bitwidthsSupported);
538   patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>,
539                FoldIndexCastChain<arith::IndexCastOp>>(patterns.getContext(),
540                                                        bitwidthsSupported);
541 }
542 
/// Public factory for the integer-range optimizations pass.
std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
  return std::make_unique<IntRangeOptimizationsPass>();
}
546