1 //===- UnsignedWhenEquivalent.cpp - Pass to replace signed operations with 2 // unsigned 3 // ones when all their arguments and results are statically non-negative --===// 4 // 5 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 6 // See https://llvm.org/LICENSE.txt for license information. 7 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 8 // 9 //===----------------------------------------------------------------------===// 10 11 #include "mlir/Dialect/Arith/Transforms/Passes.h" 12 13 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" 14 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" 15 #include "mlir/Dialect/Arith/IR/Arith.h" 16 #include "mlir/IR/PatternMatch.h" 17 #include "mlir/Transforms/WalkPatternRewriteDriver.h" 18 19 namespace mlir { 20 namespace arith { 21 #define GEN_PASS_DEF_ARITHUNSIGNEDWHENEQUIVALENT 22 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc" 23 } // namespace arith 24 } // namespace mlir 25 26 using namespace mlir; 27 using namespace mlir::arith; 28 using namespace mlir::dataflow; 29 30 /// Succeeds when a value is statically non-negative in that it has a lower 31 /// bound on its value (if it is treated as signed) and that bound is 32 /// non-negative. 33 // TODO: IntegerRangeAnalysis internally assumes index is 64bit and this pattern 34 // relies on this. These transformations may not be valid for 32bit index, 35 // need more investigation. 36 static LogicalResult staticallyNonNegative(DataFlowSolver &solver, Value v) { 37 auto *result = solver.lookupState<IntegerValueRangeLattice>(v); 38 if (!result || result->getValue().isUninitialized()) 39 return failure(); 40 const ConstantIntRanges &range = result->getValue().getValue(); 41 return success(range.smin().isNonNegative()); 42 } 43 44 /// Succeeds if an op can be converted to its unsigned equivalent without 45 /// changing its semantics. This is the case when none of its openands or 46 /// results can be below 0 when analyzed from a signed perspective. 47 static LogicalResult staticallyNonNegative(DataFlowSolver &solver, 48 Operation *op) { 49 auto nonNegativePred = [&solver](Value v) -> bool { 50 return succeeded(staticallyNonNegative(solver, v)); 51 }; 52 return success(llvm::all_of(op->getOperands(), nonNegativePred) && 53 llvm::all_of(op->getResults(), nonNegativePred)); 54 } 55 56 /// Succeeds when the comparison predicate is a signed operation and all the 57 /// operands are non-negative, indicating that the cmpi operation `op` can have 58 /// its predicate changed to an unsigned equivalent. 59 static LogicalResult isCmpIConvertable(DataFlowSolver &solver, CmpIOp op) { 60 CmpIPredicate pred = op.getPredicate(); 61 switch (pred) { 62 case CmpIPredicate::sle: 63 case CmpIPredicate::slt: 64 case CmpIPredicate::sge: 65 case CmpIPredicate::sgt: 66 return success(llvm::all_of(op.getOperands(), [&solver](Value v) -> bool { 67 return succeeded(staticallyNonNegative(solver, v)); 68 })); 69 default: 70 return failure(); 71 } 72 } 73 74 /// Return the unsigned equivalent of a signed comparison predicate, 75 /// or the predicate itself if there is none. 76 static CmpIPredicate toUnsignedPred(CmpIPredicate pred) { 77 switch (pred) { 78 case CmpIPredicate::sle: 79 return CmpIPredicate::ule; 80 case CmpIPredicate::slt: 81 return CmpIPredicate::ult; 82 case CmpIPredicate::sge: 83 return CmpIPredicate::uge; 84 case CmpIPredicate::sgt: 85 return CmpIPredicate::ugt; 86 default: 87 return pred; 88 } 89 } 90 91 namespace { 92 class DataFlowListener : public RewriterBase::Listener { 93 public: 94 DataFlowListener(DataFlowSolver &s) : s(s) {} 95 96 protected: 97 void notifyOperationErased(Operation *op) override { 98 s.eraseState(s.getProgramPointAfter(op)); 99 for (Value res : op->getResults()) 100 s.eraseState(res); 101 } 102 103 DataFlowSolver &s; 104 }; 105 106 template <typename Signed, typename Unsigned> 107 struct ConvertOpToUnsigned final : OpRewritePattern<Signed> { 108 ConvertOpToUnsigned(MLIRContext *context, DataFlowSolver &s) 109 : OpRewritePattern<Signed>(context), solver(s) {} 110 111 LogicalResult matchAndRewrite(Signed op, PatternRewriter &rw) const override { 112 if (failed( 113 staticallyNonNegative(this->solver, static_cast<Operation *>(op)))) 114 return failure(); 115 116 rw.replaceOpWithNewOp<Unsigned>(op, op->getResultTypes(), op->getOperands(), 117 op->getAttrs()); 118 return success(); 119 } 120 121 private: 122 DataFlowSolver &solver; 123 }; 124 125 struct ConvertCmpIToUnsigned final : OpRewritePattern<CmpIOp> { 126 ConvertCmpIToUnsigned(MLIRContext *context, DataFlowSolver &s) 127 : OpRewritePattern<CmpIOp>(context), solver(s) {} 128 129 LogicalResult matchAndRewrite(CmpIOp op, PatternRewriter &rw) const override { 130 if (failed(isCmpIConvertable(this->solver, op))) 131 return failure(); 132 133 rw.replaceOpWithNewOp<CmpIOp>(op, toUnsignedPred(op.getPredicate()), 134 op.getLhs(), op.getRhs()); 135 return success(); 136 } 137 138 private: 139 DataFlowSolver &solver; 140 }; 141 142 struct ArithUnsignedWhenEquivalentPass 143 : public arith::impl::ArithUnsignedWhenEquivalentBase< 144 ArithUnsignedWhenEquivalentPass> { 145 146 void runOnOperation() override { 147 Operation *op = getOperation(); 148 MLIRContext *ctx = op->getContext(); 149 DataFlowSolver solver; 150 solver.load<DeadCodeAnalysis>(); 151 solver.load<IntegerRangeAnalysis>(); 152 if (failed(solver.initializeAndRun(op))) 153 return signalPassFailure(); 154 155 DataFlowListener listener(solver); 156 157 RewritePatternSet patterns(ctx); 158 populateUnsignedWhenEquivalentPatterns(patterns, solver); 159 160 walkAndApplyPatterns(op, std::move(patterns), &listener); 161 } 162 }; 163 } // end anonymous namespace 164 165 void mlir::arith::populateUnsignedWhenEquivalentPatterns( 166 RewritePatternSet &patterns, DataFlowSolver &solver) { 167 patterns.add<ConvertOpToUnsigned<DivSIOp, DivUIOp>, 168 ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>, 169 ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>, 170 ConvertOpToUnsigned<RemSIOp, RemUIOp>, 171 ConvertOpToUnsigned<MinSIOp, MinUIOp>, 172 ConvertOpToUnsigned<MaxSIOp, MaxUIOp>, 173 ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>( 174 patterns.getContext(), solver); 175 } 176 177 std::unique_ptr<Pass> mlir::arith::createArithUnsignedWhenEquivalentPass() { 178 return std::make_unique<ArithUnsignedWhenEquivalentPass>(); 179 } 180