288 lines
7.3 KiB
C++
288 lines
7.3 KiB
C++
/*
 * This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at https://mozilla.org/MPL/2.0/.
 */
|
|
|
|
#pragma once

#include <cstddef>
#include <cstdint>
#include <functional>
#include <limits>
#include <type_traits>

// <compare> is required for operator<=> on floating-point operands (C++20).
#if defined(__has_include)
#  if __has_include(<compare>)
#    include <compare>
#  endif
#endif
|
|
|
|
// __is_identifier(x) is a Clang built-in; when the compiler does not provide it,
// fall back to a stub that reports every token as an ordinary identifier, so that
// __has_keyword() below evaluates to 0.
#ifndef __is_identifier
#define __is_identifier(x) 1 // Compatibility with non-clang compilers.
#endif

// __has_keyword(x) is non-zero when __is_identifier(x) reports that x is not an
// ordinary identifier, i.e. when the compiler reserves x as a keyword. The
// expansion is fully parenthesised so it can be combined with other operators.
#define __has_keyword(__x) (!(__is_identifier(__x)))
|
|
|
|
// Map a native half-precision type, if the compiler provides one as a keyword,
// to float16. When neither _Float16 nor __fp16 is available, the #else branch
// supplies a software implementation.
#if __has_keyword(_Float16)
using float16 = _Float16;
#elif __has_keyword(__fp16)
using float16 = __fp16;
#else
|
|
|
|
/**
 * Software half-precision (binary16) floating-point value.
 *
 * The value is stored as a raw 16-bit pattern (1 sign bit, 5 exponent bits,
 * 10 mantissa bits). All arithmetic is carried out in `float` and the result
 * is converted back to 16 bits.
 *
 * Conversion from float:
 *  - excess mantissa bits are truncated (no rounding),
 *  - float zeros and float subnormals become a signed zero,
 *  - magnitudes too small for a half subnormal become a signed zero,
 *  - magnitudes too large for a finite half become a signed infinity,
 *  - every NaN becomes a NaN with mantissa 1.
 */
class float16
{
    // Raw bit pattern: bit 15 = sign, bits 14..10 = exponent, bits 9..0 = mantissa.
    uint16_t m_data;

    static_assert(sizeof(float) == sizeof(uint32_t), "float16 expects a 32-bit float");

    // Helper for conversion: views the storage of a float as a 32-bit integer.
    // Reading the member that was not written last relies on the compiler
    // supporting union type punning; it cannot be evaluated at compile time.
    union IEEE_FP32
    {
        float Float;
        uint32_t bits;
    };

    /// Returns the 32-bit pattern that represents @p value.
    static constexpr uint32_t float_to_bits(float value)
    {
        IEEE_FP32 pun{};
        pun.Float = value;
        return pun.bits;
    }

    /// Returns the float whose representation is the 32-bit pattern @p bits.
    static constexpr float bits_to_float(uint32_t bits)
    {
        IEEE_FP32 pun{};
        pun.bits = bits;
        return pun.Float;
    }

    /// Assembles a 16-bit pattern from its fields; each field is truncated to its width.
    static constexpr uint16_t pack(uint32_t mantissa, uint32_t exponent, uint32_t sign)
    {
        return static_cast<uint16_t>(((sign & 0x1u) << 15) | ((exponent & 0x1Fu) << 10) | (mantissa & 0x3FFu));
    }

    /// Converts @p value to a half-precision bit pattern (see the class comment for the rules).
    static constexpr uint16_t float2half(float value)
    {
        const uint32_t bits = float_to_bits(value);
        const uint32_t sign = bits >> 31;
        const uint32_t exponent = (bits >> 23) & 0xFFu;
        const uint32_t mantissa = bits & 0x7FFFFFu;

        if (exponent == 0u)
        {
            // zero or float subnormal: far below the half range
            return pack(0u, 0u, sign);
        }
        if (exponent == 0xFFu)
        {
            // NaN or INF
            return pack(mantissa != 0u ? 1u : 0u, 31u, sign);
        }

        // regular number
        const int new_exp = static_cast<int>(exponent) - 127;

        if (new_exp < -24)
        {
            // this maps to 0
            return pack(0u, 0u, sign);
        }
        if (new_exp < -14)
        {
            // this maps to a denorm: the implicit leading one becomes an explicit
            // mantissa bit, shifted right by 1..10 places
            const uint32_t shift = static_cast<uint32_t>(-14 - new_exp);
            return pack((0x400u >> shift) + (mantissa >> (13u + shift)), 0u, sign);
        }
        if (new_exp > 15)
        {
            // map this value to infinity
            return pack(0u, 31u, sign);
        }
        return pack(mantissa >> 13, static_cast<uint32_t>(new_exp + 15), sign);
    }

public:
    /// Constructs positive zero.
    constexpr float16() : m_data(0) {}

    constexpr float16(const float16&) = default;

    constexpr float16(float16&&) = default;

    // Declaring the move constructor suppresses the implicit copy assignment
    // operator, so both assignment operators are defaulted explicitly; without
    // them `a = b` between two float16 objects would not compile.
    constexpr float16& operator=(const float16&) = default;

    constexpr float16& operator=(float16&&) = default;

    /// Converts any value that is convertible to float.
    template<typename T, std::enable_if_t<std::is_convertible_v<T, float>, bool> = true>
    constexpr float16(const T& other) : m_data(float2half(static_cast<float>(other))) {}

    /// Builds a value directly from its raw fields; each argument is truncated
    /// to its field width (10, 5 and 1 bits).
    constexpr float16(uint16_t mantissa, uint16_t exponent, uint16_t sign)
        : m_data(pack(mantissa, exponent, sign))
    {
    }

    //region operators
    // Operator +=, -=, *=, /= and +, -, *, /
    // Both operands are converted to float, combined, and the result is
    // converted back to half precision.
#define BINARY_ARITHMETIC_OPERATOR(OP)                                                       \
    template<typename T, std::enable_if_t<std::is_convertible_v<T, float>, bool> = true>     \
    constexpr float16& operator OP##=(const T& rhs) {                                        \
        return *this = operator float() OP static_cast<float>(rhs);                          \
    }                                                                                        \
    template<typename T, std::enable_if_t<std::is_convertible_v<T, float>, bool> = true>     \
    [[nodiscard]] constexpr float16 operator OP(const T& rhs) const {                        \
        return { operator float() OP static_cast<float>(rhs) };                              \
    }

    BINARY_ARITHMETIC_OPERATOR(+)
    BINARY_ARITHMETIC_OPERATOR(-)
    BINARY_ARITHMETIC_OPERATOR(*)
    BINARY_ARITHMETIC_OPERATOR(/)

#undef BINARY_ARITHMETIC_OPERATOR

    // Operator ++, --
    constexpr float16& operator++()
    {
        return *this += 1;
    }

    constexpr float16 operator++(int)
    {
        float16 ret(*this);
        operator++();
        return ret;
    }

    constexpr float16& operator--()
    {
        return *this -= 1;
    }

    constexpr float16 operator--(int)
    {
        float16 ret(*this);
        operator--();
        return ret;
    }

    /// Assigns any value that is convertible to float.
    template<typename T, std::enable_if_t<std::is_convertible_v<T, float>, bool> = true>
    constexpr float16& operator=(const T& rhs)
    {
        m_data = float2half(static_cast<float>(rhs));
        return *this;
    }

    // Three-way comparison needs C++20. Without it, relational operators still
    // work through the implicit conversion to float.
#if defined(__cpp_impl_three_way_comparison) && __cpp_impl_three_way_comparison >= 201907L
    template<typename T, std::enable_if_t<std::is_convertible_v<T, float>, bool> = true>
    [[nodiscard]] constexpr auto operator<=>(const T& other) const { return operator float() <=> static_cast<float>(other); }
#endif
    //endregion

    // Operator float
    /// Widens the stored value to float. NaNs are returned as a NaN with mantissa 1.
    [[nodiscard]] constexpr operator float() const
    {
        const uint32_t sign = (static_cast<uint32_t>(m_data) >> 15) & 0x1u;
        const uint32_t exponent = (static_cast<uint32_t>(m_data) >> 10) & 0x1Fu;
        const uint32_t mantissa = static_cast<uint32_t>(m_data) & 0x3FFu;

        if (exponent == 0u)
        {
            if (mantissa == 0u)
            {
                // signed zero
                return bits_to_float(sign << 31);
            }
            // half subnormal: mantissa / 1024 * 2^-14
            const float half_denorm = 1.0f / 16384.0f;
            const float fraction = static_cast<float>(mantissa) / 1024.0f;
            const float sgn = sign ? -1.0f : 1.0f;
            return sgn * fraction * half_denorm;
        }
        if (exponent == 31u)
        {
            // INF (mantissa 0) or NaN (mantissa non-zero)
            return bits_to_float((sign << 31) | (0xFFu << 23) | (mantissa != 0u ? 1u : 0u));
        }
        // normal number: rebias the exponent (127 - 15 = 112) and widen the mantissa
        return bits_to_float((sign << 31) | ((exponent + 112u) << 23) | (mantissa << 13));
    }

    // std::hash<float16> reads m_data directly.
    template<typename Key> friend
    struct std::hash;
};
|
|
|
|
namespace std
|
|
{
|
|
template<>
|
|
struct hash<float16>
|
|
{
|
|
std::size_t operator()(const float16& key) const
|
|
{
|
|
return hash<uint16_t>()(key.m_data);
|
|
}
|
|
};
|
|
|
|
template <>
|
|
class numeric_limits<float16>
|
|
{
|
|
public:
|
|
// General -- meaningful for all specializations.
|
|
static const bool is_specialized = true;
|
|
static float16 min() { return float16(0, 1, 0); }
|
|
static float16 max() { return float16(~0, 30, 0); }
|
|
static const int radix = 2;
|
|
static const int digits = 10; // conservative assumption
|
|
static const int digits10 = 2; // conservative assumption
|
|
static const bool is_signed = true;
|
|
static const bool is_integer = true;
|
|
static const bool is_exact = false;
|
|
static const bool traps = false;
|
|
static const bool is_modulo = false;
|
|
static const bool is_bounded = true;
|
|
|
|
// Floating point specific.
|
|
|
|
static float16 epsilon() { return float16(0.00097656f); } // from OpenEXR, needs to be confirmed
|
|
static float16 round_error() { return float16(0.00097656f/2); }
|
|
static const int min_exponent10 = -9;
|
|
static const int max_exponent10 = 9;
|
|
static const int min_exponent = -15;
|
|
static const int max_exponent = 15;
|
|
|
|
static const bool has_infinity = true;
|
|
static const bool has_quiet_NaN = true;
|
|
static const bool has_signaling_NaN = true;
|
|
static const bool is_iec559 = false;
|
|
static const bool has_denorm = denorm_present;
|
|
static const bool tinyness_before = false;
|
|
static const float_round_style round_style = round_to_nearest;
|
|
|
|
static float16 denorm_min() { return float16(1, 0, 1); }
|
|
static float16 infinity() { return float16(0, 31, 0); }
|
|
static float16 quiet_NaN() { return float16(1, 31, 0); }
|
|
static float16 signaling_NaN () { return float16(1, 31, 0); }
|
|
};
|
|
}
|
|
|
|
#endif
|
|
#undef __has_keyword
|
|
|
|
|
|
typedef float16 fp16;
|
|
typedef float16 half;
|