xref: /llvm-project/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp (revision 0f8a6b7d03550cb58cf49535af2de2230abfe997)
//===- UnsignedWhenEquivalent.cpp - Replace signed ops with unsigned ones -===//
//
// Replaces signed operations with their unsigned equivalents when all of their
// arguments and results are statically known to be non-negative.
4 //
5 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
6 // See https://llvm.org/LICENSE.txt for license information.
7 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8 //
9 //===----------------------------------------------------------------------===//
10 
11 #include "mlir/Dialect/Arith/Transforms/Passes.h"
12 
13 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
14 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/Transforms/WalkPatternRewriteDriver.h"
18 
19 namespace mlir {
20 namespace arith {
21 #define GEN_PASS_DEF_ARITHUNSIGNEDWHENEQUIVALENT
22 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
23 } // namespace arith
24 } // namespace mlir
25 
26 using namespace mlir;
27 using namespace mlir::arith;
28 using namespace mlir::dataflow;
29 
30 /// Succeeds when a value is statically non-negative in that it has a lower
31 /// bound on its value (if it is treated as signed) and that bound is
32 /// non-negative.
33 // TODO: IntegerRangeAnalysis internally assumes index is 64bit and this pattern
34 // relies on this. These transformations may not be valid for 32bit index,
35 // need more investigation.
36 static LogicalResult staticallyNonNegative(DataFlowSolver &solver, Value v) {
37   auto *result = solver.lookupState<IntegerValueRangeLattice>(v);
38   if (!result || result->getValue().isUninitialized())
39     return failure();
40   const ConstantIntRanges &range = result->getValue().getValue();
41   return success(range.smin().isNonNegative());
42 }
43 
44 /// Succeeds if an op can be converted to its unsigned equivalent without
45 /// changing its semantics. This is the case when none of its openands or
46 /// results can be below 0 when analyzed from a signed perspective.
47 static LogicalResult staticallyNonNegative(DataFlowSolver &solver,
48                                            Operation *op) {
49   auto nonNegativePred = [&solver](Value v) -> bool {
50     return succeeded(staticallyNonNegative(solver, v));
51   };
52   return success(llvm::all_of(op->getOperands(), nonNegativePred) &&
53                  llvm::all_of(op->getResults(), nonNegativePred));
54 }
55 
56 /// Succeeds when the comparison predicate is a signed operation and all the
57 /// operands are non-negative, indicating that the cmpi operation `op` can have
58 /// its predicate changed to an unsigned equivalent.
59 static LogicalResult isCmpIConvertable(DataFlowSolver &solver, CmpIOp op) {
60   CmpIPredicate pred = op.getPredicate();
61   switch (pred) {
62   case CmpIPredicate::sle:
63   case CmpIPredicate::slt:
64   case CmpIPredicate::sge:
65   case CmpIPredicate::sgt:
66     return success(llvm::all_of(op.getOperands(), [&solver](Value v) -> bool {
67       return succeeded(staticallyNonNegative(solver, v));
68     }));
69   default:
70     return failure();
71   }
72 }
73 
74 /// Return the unsigned equivalent of a signed comparison predicate,
75 /// or the predicate itself if there is none.
76 static CmpIPredicate toUnsignedPred(CmpIPredicate pred) {
77   switch (pred) {
78   case CmpIPredicate::sle:
79     return CmpIPredicate::ule;
80   case CmpIPredicate::slt:
81     return CmpIPredicate::ult;
82   case CmpIPredicate::sge:
83     return CmpIPredicate::uge;
84   case CmpIPredicate::sgt:
85     return CmpIPredicate::ugt;
86   default:
87     return pred;
88   }
89 }
90 
91 namespace {
92 class DataFlowListener : public RewriterBase::Listener {
93 public:
94   DataFlowListener(DataFlowSolver &s) : s(s) {}
95 
96 protected:
97   void notifyOperationErased(Operation *op) override {
98     s.eraseState(s.getProgramPointAfter(op));
99     for (Value res : op->getResults())
100       s.eraseState(res);
101   }
102 
103   DataFlowSolver &s;
104 };
105 
106 template <typename Signed, typename Unsigned>
107 struct ConvertOpToUnsigned final : OpRewritePattern<Signed> {
108   ConvertOpToUnsigned(MLIRContext *context, DataFlowSolver &s)
109       : OpRewritePattern<Signed>(context), solver(s) {}
110 
111   LogicalResult matchAndRewrite(Signed op, PatternRewriter &rw) const override {
112     if (failed(
113             staticallyNonNegative(this->solver, static_cast<Operation *>(op))))
114       return failure();
115 
116     rw.replaceOpWithNewOp<Unsigned>(op, op->getResultTypes(), op->getOperands(),
117                                     op->getAttrs());
118     return success();
119   }
120 
121 private:
122   DataFlowSolver &solver;
123 };
124 
125 struct ConvertCmpIToUnsigned final : OpRewritePattern<CmpIOp> {
126   ConvertCmpIToUnsigned(MLIRContext *context, DataFlowSolver &s)
127       : OpRewritePattern<CmpIOp>(context), solver(s) {}
128 
129   LogicalResult matchAndRewrite(CmpIOp op, PatternRewriter &rw) const override {
130     if (failed(isCmpIConvertable(this->solver, op)))
131       return failure();
132 
133     rw.replaceOpWithNewOp<CmpIOp>(op, toUnsignedPred(op.getPredicate()),
134                                   op.getLhs(), op.getRhs());
135     return success();
136   }
137 
138 private:
139   DataFlowSolver &solver;
140 };
141 
142 struct ArithUnsignedWhenEquivalentPass
143     : public arith::impl::ArithUnsignedWhenEquivalentBase<
144           ArithUnsignedWhenEquivalentPass> {
145 
146   void runOnOperation() override {
147     Operation *op = getOperation();
148     MLIRContext *ctx = op->getContext();
149     DataFlowSolver solver;
150     solver.load<DeadCodeAnalysis>();
151     solver.load<IntegerRangeAnalysis>();
152     if (failed(solver.initializeAndRun(op)))
153       return signalPassFailure();
154 
155     DataFlowListener listener(solver);
156 
157     RewritePatternSet patterns(ctx);
158     populateUnsignedWhenEquivalentPatterns(patterns, solver);
159 
160     walkAndApplyPatterns(op, std::move(patterns), &listener);
161   }
162 };
163 } // end anonymous namespace
164 
165 void mlir::arith::populateUnsignedWhenEquivalentPatterns(
166     RewritePatternSet &patterns, DataFlowSolver &solver) {
167   patterns.add<ConvertOpToUnsigned<DivSIOp, DivUIOp>,
168                ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>,
169                ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>,
170                ConvertOpToUnsigned<RemSIOp, RemUIOp>,
171                ConvertOpToUnsigned<MinSIOp, MinUIOp>,
172                ConvertOpToUnsigned<MaxSIOp, MaxUIOp>,
173                ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>(
174       patterns.getContext(), solver);
175 }
176 
177 std::unique_ptr<Pass> mlir::arith::createArithUnsignedWhenEquivalentPass() {
178   return std::make_unique<ArithUnsignedWhenEquivalentPass>();
179 }
180