xref: /openbsd-src/gnu/llvm/llvm/lib/Support/BranchProbability.cpp (revision 73471bf04ceb096474c7f0fa83b1b65c70a787a1)
109467b48Spatrick //===-------------- lib/Support/BranchProbability.cpp -----------*- C++ -*-===//
209467b48Spatrick //
309467b48Spatrick // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
409467b48Spatrick // See https://llvm.org/LICENSE.txt for license information.
509467b48Spatrick // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
609467b48Spatrick //
709467b48Spatrick //===----------------------------------------------------------------------===//
809467b48Spatrick //
909467b48Spatrick // This file implements Branch Probability class.
1009467b48Spatrick //
1109467b48Spatrick //===----------------------------------------------------------------------===//
1209467b48Spatrick 
1309467b48Spatrick #include "llvm/Support/BranchProbability.h"
1409467b48Spatrick #include "llvm/Config/llvm-config.h"
1509467b48Spatrick #include "llvm/Support/Debug.h"
1609467b48Spatrick #include "llvm/Support/Format.h"
1709467b48Spatrick #include "llvm/Support/raw_ostream.h"
1809467b48Spatrick #include <cassert>
19*73471bf0Spatrick #include <cmath>
2009467b48Spatrick 
2109467b48Spatrick using namespace llvm;
2209467b48Spatrick 
23097a140dSpatrick constexpr uint32_t BranchProbability::D;
2409467b48Spatrick 
print(raw_ostream & OS) const2509467b48Spatrick raw_ostream &BranchProbability::print(raw_ostream &OS) const {
2609467b48Spatrick   if (isUnknown())
2709467b48Spatrick     return OS << "?%";
2809467b48Spatrick 
2909467b48Spatrick   // Get a percentage rounded to two decimal digits. This avoids
3009467b48Spatrick   // implementation-defined rounding inside printf.
3109467b48Spatrick   double Percent = rint(((double)N / D) * 100.0 * 100.0) / 100.0;
3209467b48Spatrick   return OS << format("0x%08" PRIx32 " / 0x%08" PRIx32 " = %.2f%%", N, D,
3309467b48Spatrick                       Percent);
3409467b48Spatrick }
3509467b48Spatrick 
3609467b48Spatrick #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
dump() const3709467b48Spatrick LLVM_DUMP_METHOD void BranchProbability::dump() const { print(dbgs()) << '\n'; }
3809467b48Spatrick #endif
3909467b48Spatrick 
BranchProbability(uint32_t Numerator,uint32_t Denominator)4009467b48Spatrick BranchProbability::BranchProbability(uint32_t Numerator, uint32_t Denominator) {
4109467b48Spatrick   assert(Denominator > 0 && "Denominator cannot be 0!");
4209467b48Spatrick   assert(Numerator <= Denominator && "Probability cannot be bigger than 1!");
4309467b48Spatrick   if (Denominator == D)
4409467b48Spatrick     N = Numerator;
4509467b48Spatrick   else {
4609467b48Spatrick     uint64_t Prob64 =
4709467b48Spatrick         (Numerator * static_cast<uint64_t>(D) + Denominator / 2) / Denominator;
4809467b48Spatrick     N = static_cast<uint32_t>(Prob64);
4909467b48Spatrick   }
5009467b48Spatrick }
5109467b48Spatrick 
5209467b48Spatrick BranchProbability
getBranchProbability(uint64_t Numerator,uint64_t Denominator)5309467b48Spatrick BranchProbability::getBranchProbability(uint64_t Numerator,
5409467b48Spatrick                                         uint64_t Denominator) {
5509467b48Spatrick   assert(Numerator <= Denominator && "Probability cannot be bigger than 1!");
5609467b48Spatrick   // Scale down Denominator to fit in a 32-bit integer.
5709467b48Spatrick   int Scale = 0;
5809467b48Spatrick   while (Denominator > UINT32_MAX) {
5909467b48Spatrick     Denominator >>= 1;
6009467b48Spatrick     Scale++;
6109467b48Spatrick   }
6209467b48Spatrick   return BranchProbability(Numerator >> Scale, Denominator);
6309467b48Spatrick }
6409467b48Spatrick 
// Compute floor(Num * N / D), returning UINT64_MAX when the quotient does not
// fit in 64 bits. The product Num * N (up to 96 bits) is built from 32-bit
// digits and then divided by D with two 64-by-32 long-division steps.
//
// When ConstD is nonzero it overrides D. The divisor is then a compile-time
// constant, so the divisions and modulo by D can be optimized even if this
// function is not inlined.
template <uint32_t ConstD>
static uint64_t scale(uint64_t Num, uint32_t N, uint32_t D) {
  if (ConstD > 0)
    D = ConstD;

  assert(D && "divide by 0");

  // Multiplying zero, or multiplying by N/D == 1, leaves Num unchanged.
  if (Num == 0 || N == D)
    return Num;

  // Write Num as Hi * 2^32 + Lo. Each 32x32-bit partial product fits in 64
  // bits.
  const uint64_t HiProd = (Num >> 32) * N;
  const uint64_t LoProd = (Num & UINT32_MAX) * N;

  // Assemble the product as three 32-bit digits [Top : Middle : Bottom].
  const uint32_t Bottom = static_cast<uint32_t>(LoProd);
  const uint32_t MidFromHi = static_cast<uint32_t>(HiProd);
  const uint32_t Middle = MidFromHi + static_cast<uint32_t>(LoProd >> 32);
  // If the middle digit wrapped around, carry one into the top digit.
  const uint32_t Top =
      static_cast<uint32_t>(HiProd >> 32) + (Middle < MidFromHi ? 1 : 0);

  // Step 1: divide the high 64 bits [Top : Middle] by D.
  const uint64_t High = (static_cast<uint64_t>(Top) << 32) | Middle;
  const uint64_t QuotHi = High / D;
  if (QuotHi > UINT32_MAX)
    return UINT64_MAX; // The full quotient needs more than 64 bits.

  // Step 2: append the bottom digit to the remainder and divide again.
  const uint64_t Low = ((High % D) << 32) | Bottom;
  const uint64_t QuotLo = Low / D;

  const uint64_t Result = (QuotHi << 32) + QuotLo;
  // Defensive wraparound check on the final recombination.
  return Result < QuotLo ? UINT64_MAX : Result;
}
10609467b48Spatrick 
scale(uint64_t Num) const10709467b48Spatrick uint64_t BranchProbability::scale(uint64_t Num) const {
10809467b48Spatrick   return ::scale<D>(Num, N, D);
10909467b48Spatrick }
11009467b48Spatrick 
scaleByInverse(uint64_t Num) const11109467b48Spatrick uint64_t BranchProbability::scaleByInverse(uint64_t Num) const {
11209467b48Spatrick   return ::scale<0>(Num, D, N);
11309467b48Spatrick }
114