xref: /llvm-project/mlir/include/mlir/ExecutionEngine/SparseTensor/ArithmeticUtils.h (revision 9ce445b8c762450d5772dd05a1a10297b0856a4e)
1 //===- ArithmeticUtils.h - Arithmetic helper functions ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// A collection of "safe" comparison and "checked" arithmetic helper functions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_EXECUTIONENGINE_SPARSETENSOR_ARITHMETICUTILS_H
14 #define MLIR_EXECUTIONENGINE_SPARSETENSOR_ARITHMETICUTILS_H
15 
16 #include <cassert>
17 #include <cinttypes>
18 #include <limits>
19 #include <type_traits>
20 
21 namespace mlir {
22 namespace sparse_tensor {
23 namespace detail {
24 
//===----------------------------------------------------------------------===//
//
// Safe comparison functions.
//
// Variants of the `==`, `!=`, `<`, `<=`, `>`, and `>=` operators which
// are careful to ensure that negatives are always considered strictly
// less than non-negatives regardless of the signedness of the types of
// the two arguments.  They are "safe" in that they are guaranteed to *always*
// give an output and that the output is correct; in particular, this means
// they never use assertions or other mechanisms for "returning an error".
//
// These functions are C++17-compatible backports of the safe integer
// comparison functions added in C++20 (`std::cmp_equal`, `std::cmp_less`,
// etc.), and the implementations are based on the "possible implementation"
// given on cppreference:
// <https://en.cppreference.com/w/cpp/utility/intcmp>.
//
//===----------------------------------------------------------------------===//
42 
/// Returns whether `t` and `u` are equal as mathematical integers.  Unlike
/// the built-in `==`, a negative value never compares equal to a
/// non-negative one, whatever the signedness of `T` and `U` (the built-in
/// operator would first convert the signed operand to unsigned).
///
/// Like C++20 `std::cmp_equal`, non-integer types and `bool` are rejected at
/// compile time (character types are still accepted here).  Enumerations in
/// particular must be rejected: `std::is_signed_v` is false for every enum,
/// so an unscoped enum holding a negative enumerator would otherwise be
/// silently treated as unsigned and compared incorrectly.
template <typename T, typename U>
constexpr bool safelyEQ(T t, U u) noexcept {
  static_assert(std::is_integral_v<T> &&
                    !std::is_same_v<std::remove_cv_t<T>, bool>,
                "safelyEQ: T must be a non-bool integer type");
  static_assert(std::is_integral_v<U> &&
                    !std::is_same_v<std::remove_cv_t<U>, bool>,
                "safelyEQ: U must be a non-bool integer type");
  using UT = std::make_unsigned_t<T>;
  using UU = std::make_unsigned_t<U>;
  if constexpr (std::is_signed_v<T> == std::is_signed_v<U>)
    // Same signedness: the usual arithmetic conversions preserve values.
    return t == u;
  else if constexpr (std::is_signed_v<T>)
    // A negative `t` can never equal an unsigned `u`; a non-negative `t` is
    // representable in its unsigned counterpart without change of value.
    return t < 0 ? false : static_cast<UT>(t) == u;
  else
    // Mirror image of the previous case, with `u` as the signed operand.
    return u < 0 ? false : t == static_cast<UU>(u);
}
54 
/// Safe inequality: the exact negation of `safelyEQ`, so `t` and `u` are
/// compared by mathematical value regardless of their signedness.
template <typename T, typename U>
constexpr bool safelyNE(T t, U u) noexcept {
  const bool equal = safelyEQ(t, u);
  return !equal;
}
59 
/// Returns whether the mathematical value of `t` is less than that of `u`.
/// Unlike the built-in `<`, a negative value is always less than a
/// non-negative one, whatever the signedness of `T` and `U`.
///
/// Like C++20 `std::cmp_less`, non-integer types and `bool` are rejected at
/// compile time (character types are still accepted here).  Enumerations in
/// particular must be rejected: `std::is_signed_v` is false for every enum,
/// so an unscoped enum holding a negative enumerator would otherwise be
/// silently treated as unsigned and compared incorrectly.
template <typename T, typename U>
constexpr bool safelyLT(T t, U u) noexcept {
  static_assert(std::is_integral_v<T> &&
                    !std::is_same_v<std::remove_cv_t<T>, bool>,
                "safelyLT: T must be a non-bool integer type");
  static_assert(std::is_integral_v<U> &&
                    !std::is_same_v<std::remove_cv_t<U>, bool>,
                "safelyLT: U must be a non-bool integer type");
  using UT = std::make_unsigned_t<T>;
  using UU = std::make_unsigned_t<U>;
  if constexpr (std::is_signed_v<T> == std::is_signed_v<U>)
    // Same signedness: the usual arithmetic conversions preserve values.
    return t < u;
  else if constexpr (std::is_signed_v<T>)
    // Signed `t` vs. unsigned `u`: a negative `t` is below every `u`;
    // otherwise compare in the unsigned domain without changing values.
    return t < 0 ? true : static_cast<UT>(t) < u;
  else
    // Unsigned `t` vs. signed `u`: no `t` is below a negative `u`.
    return u < 0 ? false : t < static_cast<UU>(u);
}
71 
/// Safe greater-than: `t > u` holds exactly when `u < t`, so this forwards
/// to `safelyLT` with the operands swapped.
template <typename T, typename U>
constexpr bool safelyGT(T t, U u) noexcept {
  const bool uIsSmaller = safelyLT(u, t);
  return uIsSmaller;
}
76 
/// Safe less-than-or-equal: `t <= u` is the negation of `u < t`.
template <typename T, typename U>
constexpr bool safelyLE(T t, U u) noexcept {
  // Same result as `!safelyGT(t, u)`, because `safelyGT(t, u)` is defined
  // as `safelyLT(u, t)`.
  return !safelyLT(u, t);
}
81 
/// Safe greater-than-or-equal: `t >= u` is the negation of `t < u`.
template <typename T, typename U>
constexpr bool safelyGE(T t, U u) noexcept {
  const bool tIsSmaller = safelyLT(t, u);
  return !tIsSmaller;
}
86 
//===----------------------------------------------------------------------===//
//
// Overflow checking functions.
//
// These functions use assertions to ensure correctness with respect to
// overflow/underflow.  Unlike the "safe" functions above, these "checked"
// functions only guarantee that *if* they return an answer then that answer
// is correct.  When assertions are enabled, they try to keep the checks
// cheap, e.g. by skipping a check entirely when the types alone prove it
// cannot fail (this matters because MLIR keeps assertions enabled by
// default, even for optimized builds).  When assertions are disabled, they
// reduce to the standard unchecked operations.
//
//===----------------------------------------------------------------------===//
100 
/// A version of `static_cast<To>` which checks for overflow/underflow.
/// The implementation avoids performing runtime assertions whenever
/// the types alone are sufficient to statically prove that overflow
/// cannot happen.
///
/// Both `To` and `From` must be integral types.  The bounds below come from
/// `std::numeric_limits`.
/// - For floating-point types, `min()` is the smallest positive normal value
///   rather than the lowest value.
/// - For enumerations, `std::numeric_limits` is not specialized, so `min()`
///   and `max()` are both a value-initialized `0`.
/// In either case the range checks would be wrong, so such types are
/// rejected at compile time.
template <typename To, typename From>
[[nodiscard]] inline To checkOverflowCast(From x) {
  static_assert(std::is_integral_v<To> && std::is_integral_v<From>,
                "checkOverflowCast requires integral types");
  // Check the lower bound.  This is only needed when `From` can hold values
  // below `minTo` (e.g. when casting from a signed type); otherwise the
  // check is discarded at compile time.
  constexpr To minTo = std::numeric_limits<To>::min();
  constexpr From minFrom = std::numeric_limits<From>::min();
  if constexpr (!safelyGE(minFrom, minTo))
    assert(safelyGE(x, minTo) && "cast would underflow");
  // Check the upper bound, likewise only when `From` can exceed `maxTo`.
  constexpr To maxTo = std::numeric_limits<To>::max();
  constexpr From maxFrom = std::numeric_limits<From>::max();
  if constexpr (!safelyLE(maxFrom, maxTo))
    assert(safelyLE(x, maxTo) && "cast would overflow");
  // Now do the cast itself.
  return static_cast<To>(x);
}
120 
/// Multiplies two `uint64_t` values, asserting (when assertions are enabled)
/// that the product fits in 64 bits.  With `NDEBUG` defined this is a plain
/// unchecked `lhs * rhs`, which wraps modulo 2^64 on overflow.
inline uint64_t checkedMul(uint64_t lhs, uint64_t rhs) {
  // Division-based pre-check: the product fits iff `lhs` is zero or `rhs`
  // does not exceed floor(UINT64_MAX / lhs).  This detects overflow without
  // ever computing an overflowing product.
  assert((lhs == 0 || rhs <= std::numeric_limits<uint64_t>::max() / lhs) &&
         "Integer overflow");
  const uint64_t product = lhs * rhs;
  return product;
}
128 
129 } // namespace detail
130 } // namespace sparse_tensor
131 } // namespace mlir
132 
133 #endif // MLIR_EXECUTIONENGINE_SPARSETENSOR_ARITHMETICUTILS_H
134