109467b48Spatrick //===-------------- lib/Support/BranchProbability.cpp -----------*- C++ -*-===//
209467b48Spatrick //
309467b48Spatrick // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
409467b48Spatrick // See https://llvm.org/LICENSE.txt for license information.
509467b48Spatrick // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
609467b48Spatrick //
709467b48Spatrick //===----------------------------------------------------------------------===//
809467b48Spatrick //
909467b48Spatrick // This file implements Branch Probability class.
1009467b48Spatrick //
1109467b48Spatrick //===----------------------------------------------------------------------===//
1209467b48Spatrick
1309467b48Spatrick #include "llvm/Support/BranchProbability.h"
1409467b48Spatrick #include "llvm/Config/llvm-config.h"
1509467b48Spatrick #include "llvm/Support/Debug.h"
1609467b48Spatrick #include "llvm/Support/Format.h"
1709467b48Spatrick #include "llvm/Support/raw_ostream.h"
1809467b48Spatrick #include <cassert>
19*73471bf0Spatrick #include <cmath>
2009467b48Spatrick
2109467b48Spatrick using namespace llvm;
2209467b48Spatrick
// Out-of-line definition of the static constexpr member D. Before C++17 such a
// definition is required whenever D is ODR-used, e.g. bound to a reference.
// NOTE(review): presumably this happens through the arguments passed to
// format() in print(); confirm. From C++17 on, static constexpr data members
// are implicitly inline, which makes this line redundant (deprecated) but
// harmless.
constexpr uint32_t BranchProbability::D;
2409467b48Spatrick
print(raw_ostream & OS) const2509467b48Spatrick raw_ostream &BranchProbability::print(raw_ostream &OS) const {
2609467b48Spatrick if (isUnknown())
2709467b48Spatrick return OS << "?%";
2809467b48Spatrick
2909467b48Spatrick // Get a percentage rounded to two decimal digits. This avoids
3009467b48Spatrick // implementation-defined rounding inside printf.
3109467b48Spatrick double Percent = rint(((double)N / D) * 100.0 * 100.0) / 100.0;
3209467b48Spatrick return OS << format("0x%08" PRIx32 " / 0x%08" PRIx32 " = %.2f%%", N, D,
3309467b48Spatrick Percent);
3409467b48Spatrick }
3509467b48Spatrick
3609467b48Spatrick #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
dump() const3709467b48Spatrick LLVM_DUMP_METHOD void BranchProbability::dump() const { print(dbgs()) << '\n'; }
3809467b48Spatrick #endif
3909467b48Spatrick
BranchProbability(uint32_t Numerator,uint32_t Denominator)4009467b48Spatrick BranchProbability::BranchProbability(uint32_t Numerator, uint32_t Denominator) {
4109467b48Spatrick assert(Denominator > 0 && "Denominator cannot be 0!");
4209467b48Spatrick assert(Numerator <= Denominator && "Probability cannot be bigger than 1!");
4309467b48Spatrick if (Denominator == D)
4409467b48Spatrick N = Numerator;
4509467b48Spatrick else {
4609467b48Spatrick uint64_t Prob64 =
4709467b48Spatrick (Numerator * static_cast<uint64_t>(D) + Denominator / 2) / Denominator;
4809467b48Spatrick N = static_cast<uint32_t>(Prob64);
4909467b48Spatrick }
5009467b48Spatrick }
5109467b48Spatrick
5209467b48Spatrick BranchProbability
getBranchProbability(uint64_t Numerator,uint64_t Denominator)5309467b48Spatrick BranchProbability::getBranchProbability(uint64_t Numerator,
5409467b48Spatrick uint64_t Denominator) {
5509467b48Spatrick assert(Numerator <= Denominator && "Probability cannot be bigger than 1!");
5609467b48Spatrick // Scale down Denominator to fit in a 32-bit integer.
5709467b48Spatrick int Scale = 0;
5809467b48Spatrick while (Denominator > UINT32_MAX) {
5909467b48Spatrick Denominator >>= 1;
6009467b48Spatrick Scale++;
6109467b48Spatrick }
6209467b48Spatrick return BranchProbability(Numerator >> Scale, Denominator);
6309467b48Spatrick }
6409467b48Spatrick
// Compute floor(Num * N / D), saturating to UINT64_MAX when the result does not
// fit in 64 bits. If ConstD is non-zero it replaces the run-time D. The
// divisions and modulo below are then by a compile-time constant and can be
// optimized even if this function is not inlined.
template <uint32_t ConstD>
static uint64_t scale(uint64_t Num, uint32_t N, uint32_t D) {
  const uint32_t Divisor = ConstD > 0 ? ConstD : D;
  assert(Divisor && "divide by 0");

  // Scaling zero, or scaling by exactly 1.0, leaves Num unchanged.
  if (Num == 0 || Divisor == N)
    return Num;

  // Build the 96-bit product Num * N from two 32x32->64 partial products and
  // split it into three 32-bit digits Hi:Mid:Lo.
  const uint64_t HighPart = (Num >> 32) * N;
  const uint64_t LowPart = (Num & UINT32_MAX) * N;

  const uint32_t Lo = static_cast<uint32_t>(LowPart);
  const uint32_t MidBase = static_cast<uint32_t>(HighPart);
  const uint32_t Mid = MidBase + static_cast<uint32_t>(LowPart >> 32);
  // A wrap-around in the 32-bit addition above is a carry into Hi.
  const uint32_t Hi =
      static_cast<uint32_t>(HighPart >> 32) + (Mid < MidBase ? 1u : 0u);

  // Long division, first digit: divide Hi:Mid. A quotient digit wider than
  // 32 bits means the final result cannot fit in 64 bits.
  uint64_t Partial = (static_cast<uint64_t>(Hi) << 32) | Mid;
  const uint64_t QuotHi = Partial / Divisor;
  if (QuotHi > UINT32_MAX)
    return UINT64_MAX;

  // Second digit: bring Lo down beneath the remainder and divide again.
  Partial = ((Partial % Divisor) << 32) | Lo;
  const uint64_t QuotLo = Partial / Divisor;
  const uint64_t Result = (QuotHi << 32) + QuotLo;

  // Guard against wrap-around when recombining the two quotient digits.
  return Result < QuotLo ? UINT64_MAX : Result;
}
10609467b48Spatrick
scale(uint64_t Num) const10709467b48Spatrick uint64_t BranchProbability::scale(uint64_t Num) const {
10809467b48Spatrick return ::scale<D>(Num, N, D);
10909467b48Spatrick }
11009467b48Spatrick
scaleByInverse(uint64_t Num) const11109467b48Spatrick uint64_t BranchProbability::scaleByInverse(uint64_t Num) const {
11209467b48Spatrick return ::scale<0>(Num, D, N);
11309467b48Spatrick }
114