//===--- Float16bits.cpp - supports 2-byte floats ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements f16 and bf16 to support the compilation and execution
// of programs using these types.
//
//===----------------------------------------------------------------------===//
13ea8ed5cbSbixia1
14ea8ed5cbSbixia1 #include "mlir/ExecutionEngine/Float16bits.h"
150fca5c5fSwren romano
160fca5c5fSwren romano #ifdef MLIR_FLOAT16_DEFINE_FUNCTIONS // We are building this library
170fca5c5fSwren romano
18b3127769SBenjamin Kramer #include <cmath>
19f695554aSBenjamin Kramer #include <cstring>
20ea8ed5cbSbixia1
namespace {

// Returns the raw 32-bit pattern of `value`. std::memcpy is used for the
// reinterpretation so that no inactive union member is ever read.
inline uint32_t floatToBits(float value) {
  uint32_t bits;
  std::memcpy(&bits, &value, sizeof(bits));
  return bits;
}

// Returns the float whose raw 32-bit pattern is `bits`.
inline float bitsToFloat(uint32_t bits) {
  float value;
  std::memcpy(&value, &bits, sizeof(value));
  return value;
}

// Number of explicit mantissa bits in a 32-bit float.
constexpr uint32_t kF32MantiBits = 23;
// Difference in mantissa width between float (23 bits) and half (10 bits).
constexpr uint32_t kF32HalfMantiBitDiff = 13;
// Difference in total width between float (32 bits) and half (16 bits).
constexpr uint32_t kF32HalfBitDiff = 16;
// Bit pattern of the float with biased exponent 113 and zero mantissa, i.e.
// 2^-14. Floats below this magnitude take the subnormal path in float2half.
constexpr uint32_t kF32Magic = uint32_t{113} << kF32MantiBits;
// Exponent re-bias (float bias 127 vs. half bias 15), positioned in the float
// exponent field.
constexpr uint32_t kF32HalfExpAdjust = uint32_t{127 - 15} << kF32MantiBits;
// Difference in mantissa width between float (23 bits) and bfloat (7 bits).
constexpr uint32_t kF32BfMantiBitDiff = 16;

// Constructs the 16 bit representation for a half precision value from a float
// value. This implementation is adapted from Eigen.
//
// NaN inputs map to the quiet NaN pattern 0x7e00, magnitudes of 2^16 and above
// (including infinity) map to infinity, and everything else is rounded to
// nearest-even. The sign bit of the input is always carried over.
uint16_t float2half(float floatValue) {
  constexpr uint32_t kInfBits = uint32_t{255} << kF32MantiBits;
  constexpr uint32_t kF16MaxBits = uint32_t{127 + 16} << kF32MantiBits;
  constexpr uint32_t kDenormMagicBits =
      ((127 - 15) + (kF32MantiBits - 10) + 1) << kF32MantiBits;
  constexpr uint32_t kSignMask = 0x80000000u;

  // Split off the sign and continue with the magnitude only.
  uint32_t bits = floatToBits(floatValue);
  const uint32_t sign = bits & kSignMask;
  bits ^= sign;

  uint16_t halfValue = 0;
  if (bits >= kF16MaxBits) {
    // Magnitude is at least 2^16: a finite overflow, Inf, or NaN.
    constexpr uint16_t kHalfQnan = 0x7e00;
    constexpr uint16_t kHalfInf = 0x7c00;
    // NaN->qNaN; Inf and finite overflow->Inf.
    halfValue = (bits > kInfBits) ? kHalfQnan : kHalfInf;
  } else if (bits < kF32Magic) {
    // The resulting FP16 is subnormal or zero.
    //
    // Use a magic value to align our 10 mantissa bits at the bottom of the
    // float. As long as FP addition is round-to-nearest-even this works.
    bits = floatToBits(bitsToFloat(bits) + bitsToFloat(kDenormMagicBits));
    halfValue = static_cast<uint16_t>(bits - kDenormMagicBits);
  } else {
    // Normalized result. Remember whether the truncated mantissa is odd so
    // that ties round to even.
    const uint32_t mantOdd = (bits >> kF32HalfMantiBitDiff) & 1;

    // Update exponent, rounding bias part 1. The following expression is
    // equivalent to `bits += ((unsigned int)(15 - 127) << kF32MantiBits) +
    // 0xfff`, but without arithmetic overflow.
    bits += 0xc8000fffU;
    // Rounding bias part 2.
    bits += mantOdd;
    halfValue = static_cast<uint16_t>(bits >> kF32HalfMantiBitDiff);
  }

  halfValue |= static_cast<uint16_t>(sign >> kF32HalfBitDiff);
  return halfValue;
}

// Converts the 16 bit representation of a half precision value to a float
// value. This implementation is adapted from Eigen.
float half2float(uint16_t halfValue) {
  // Exponent mask after shift.
  constexpr uint32_t kShiftedExp = uint32_t{0x7c00} << kF32HalfMantiBitDiff;

  // Initialize the float representation with the exponent/mantissa bits.
  uint32_t bits = static_cast<uint32_t>(halfValue & 0x7fff)
                  << kF32HalfMantiBitDiff;
  const uint32_t exp = kShiftedExp & bits;
  bits += kF32HalfExpAdjust; // Adjust the exponent.

  // Handle exponent special cases.
  if (exp == kShiftedExp) {
    // Inf/NaN: apply the adjustment a second time so that all float exponent
    // bits end up set.
    bits += kF32HalfExpAdjust;
  } else if (exp == 0) {
    // Zero/Denormal: bump the exponent by one and subtract the magic value
    // to renormalize the mantissa.
    bits += uint32_t{1} << kF32MantiBits;
    bits = floatToBits(bitsToFloat(bits) - bitsToFloat(kF32Magic));
  }

  // Sign bit. The mask result is widened to uint32_t before shifting so the
  // shift cannot overflow a signed int.
  bits |= static_cast<uint32_t>(halfValue & 0x8000) << kF32HalfBitDiff;
  return bitsToFloat(bits);
}

// Constructs the 16 bit representation for a bfloat value from a float value.
// This implementation is adapted from Eigen.
//
// NaN inputs map to 0x7FC0 (or 0xFFC0 when the sign bit is set); all other
// inputs are rounded to nearest-even on the upper 16 bits.
uint16_t float2bfloat(float floatValue) {
  if (std::isnan(floatValue))
    return std::signbit(floatValue) ? 0xFFC0 : 0x7FC0;

  uint32_t bits = floatToBits(floatValue);

  // Least significant bit of resulting bfloat.
  const uint32_t lsb = (bits >> kF32BfMantiBitDiff) & 1;
  // Adding 0x7fff plus the lsb rounds ties towards an even result.
  const uint32_t roundingBias = 0x7fff + lsb;
  bits += roundingBias;
  return static_cast<uint16_t>(bits >> kF32BfMantiBitDiff);
}

// Converts the 16 bit representation of a bfloat value to a float value. This
// implementation is adapted from Eigen.
float bfloat2float(uint16_t bfloatBits) {
  // The bfloat bits are the upper half of the float; the lower half is zero.
  return bitsToFloat(static_cast<uint32_t>(bfloatBits) << kF32BfMantiBitDiff);
}

} // namespace
138ea8ed5cbSbixia1
f16(float f)139ea8ed5cbSbixia1 f16::f16(float f) : bits(float2half(f)) {}
140ea8ed5cbSbixia1
bf16(float f)141ea8ed5cbSbixia1 bf16::bf16(float f) : bits(float2bfloat(f)) {}
142ea8ed5cbSbixia1
operator <<(std::ostream & os,const f16 & f)143ea8ed5cbSbixia1 std::ostream &operator<<(std::ostream &os, const f16 &f) {
144ea8ed5cbSbixia1 os << half2float(f.bits);
145ea8ed5cbSbixia1 return os;
146ea8ed5cbSbixia1 }
147ea8ed5cbSbixia1
operator <<(std::ostream & os,const bf16 & d)148ea8ed5cbSbixia1 std::ostream &operator<<(std::ostream &os, const bf16 &d) {
149ea8ed5cbSbixia1 os << bfloat2float(d.bits);
150ea8ed5cbSbixia1 return os;
151ea8ed5cbSbixia1 }
1523420cd7cSBenjamin Kramer
operator ==(const f16 & f1,const f16 & f2)153753dc0a0SYinying Li bool operator==(const f16 &f1, const f16 &f2) { return f1.bits == f2.bits; }
154753dc0a0SYinying Li
operator ==(const bf16 & f1,const bf16 & f2)155753dc0a0SYinying Li bool operator==(const bf16 &f1, const bf16 &f2) { return f1.bits == f2.bits; }
156753dc0a0SYinying Li
// Mark these symbols as weak so they don't conflict when compiler-rt also
// defines them. ATTR_WEAK expands to nothing unless the compiler reports
// support for the `weak` attribute and the target is not MinGW, Cygwin, or
// Windows.
#define ATTR_WEAK
#ifdef __has_attribute
#if __has_attribute(weak) && !defined(__MINGW32__) && !defined(__CYGWIN__) && \
    !defined(_WIN32)
#undef ATTR_WEAK
#define ATTR_WEAK __attribute__((__weak__))
#endif
#endif
167f695554aSBenjamin Kramer
#if defined(__x86_64__) || defined(_M_X64)
// On x86 bfloat16 is passed in SSE registers. Since both float and __bf16
// are passed in the same register we can use the wider type and careful casting
// to conform to x86_64 psABI. This only works with the assumption that we're
// dealing with little-endian values passed in wider registers.
// Ideally this would directly use __bf16, but that type isn't supported by all
// compilers.
using BF16ABIType = float;
#else
// Default to uint16_t if we have nothing else.
using BF16ABIType = uint16_t;
#endif
180f695554aSBenjamin Kramer
181f695554aSBenjamin Kramer // Provide a float->bfloat conversion routine in case the runtime doesn't have
182f695554aSBenjamin Kramer // one.
__truncsfbf2(float f)183f695554aSBenjamin Kramer extern "C" BF16ABIType ATTR_WEAK __truncsfbf2(float f) {
184f695554aSBenjamin Kramer uint16_t bf = float2bfloat(f);
185f695554aSBenjamin Kramer // The output can be a float type, bitcast it from uint16_t.
186f695554aSBenjamin Kramer BF16ABIType ret = 0;
187f695554aSBenjamin Kramer std::memcpy(&ret, &bf, sizeof(bf));
188f695554aSBenjamin Kramer return ret;
1893420cd7cSBenjamin Kramer }
1903420cd7cSBenjamin Kramer
1913420cd7cSBenjamin Kramer // Provide a double->bfloat conversion routine in case the runtime doesn't have
1923420cd7cSBenjamin Kramer // one.
__truncdfbf2(double d)193f695554aSBenjamin Kramer extern "C" BF16ABIType ATTR_WEAK __truncdfbf2(double d) {
1943420cd7cSBenjamin Kramer // This does a double rounding step, but it's precise enough for our use
1953420cd7cSBenjamin Kramer // cases.
196726719e9SBenjamin Kramer return __truncsfbf2(static_cast<float>(d));
1973420cd7cSBenjamin Kramer }
1980fca5c5fSwren romano
199657f60a0SAart Bik // Provide these to the CRunner with the local float16 knowledge.
printF16(uint16_t bits)2007b4ea67fSMehdi Amini extern "C" void printF16(uint16_t bits) {
201657f60a0SAart Bik f16 f;
202657f60a0SAart Bik std::memcpy(&f, &bits, sizeof(f16));
203657f60a0SAart Bik std::cout << f;
204657f60a0SAart Bik }
printBF16(uint16_t bits)2057b4ea67fSMehdi Amini extern "C" void printBF16(uint16_t bits) {
206657f60a0SAart Bik bf16 f;
207657f60a0SAart Bik std::memcpy(&f, &bits, sizeof(bf16));
208657f60a0SAart Bik std::cout << f;
209657f60a0SAart Bik }
210657f60a0SAart Bik
2110fca5c5fSwren romano #endif // MLIR_FLOAT16_DEFINE_FUNCTIONS
212