Files
OpenVulkano/openVulkanoCpp/Math/Float16.hpp

292 lines
7.8 KiB
C++

/*
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at https://mozilla.org/MPL/2.0/.
*/
#pragma once

#include <bit>
#include <compare>
#include <cstdint>
#include <functional>
#include <limits>
#include <type_traits>
// Keyword detection helper: __is_identifier(x) is expected to be supplied by clang.
// When it is not defined, the fallback below always yields 1, so __has_keyword(x)
// evaluates to 0 and the custom float16 class in the #else branch is selected.
#ifndef __is_identifier
#define __is_identifier(x) 1 // Compatibility with non-clang compilers.
#endif
// __has_keyword(x) is true when x is NOT usable as an identifier, i.e. it is a keyword.
// The macro is removed again with #undef once the type selection is finished.
#define __has_keyword(__x) !(__is_identifier(__x))
// Map a half float type, if available, to float16.
// Preference order: _Float16, then __fp16, then the software implementation below.
#if __has_keyword(_Float16)
typedef _Float16 float16;
#elif __has_keyword(__fp16)
typedef __fp16 float16;
#else
// Marks that float16 is the software class defined below rather than a compiler type.
#define USING_CUSTOM_FLOAT16
/**
 * Software implementation of an IEEE 754 binary16 ("half") value.
 *
 * The value is stored as its raw 16 bit pattern (1 sign bit, 5 exponent bits, 10 mantissa bits).
 * All arithmetic is carried out in float: operands are widened with operator float(), the float
 * operation is executed and the result is narrowed again with round-to-nearest-even.
 *
 * Conversions are implemented with plain integer arithmetic and a bit cast, so they do not rely on
 * union type punning or on the bit-field layout of the compiler and can be evaluated at compile time.
 */
class float16
{
	uint16_t m_data;

	/// Returns the raw IEEE 754 binary32 bit pattern of value.
	static constexpr uint32_t FloatToBits(float value) noexcept
	{
#if defined(__cpp_lib_bit_cast)
		return std::bit_cast<uint32_t>(value);
#else
		return __builtin_bit_cast(uint32_t, value); // compiler builtin used when std::bit_cast is unavailable
#endif
	}

	/// Returns the float whose IEEE 754 binary32 bit pattern is bits.
	static constexpr float BitsToFloat(uint32_t bits) noexcept
	{
#if defined(__cpp_lib_bit_cast)
		return std::bit_cast<float>(bits);
#else
		return __builtin_bit_cast(float, bits); // compiler builtin used when std::bit_cast is unavailable
#endif
	}

	/**
	 * Converts a float to the bit pattern of the nearest half value (ties go to the even mantissa).
	 *
	 * - NaN becomes a quiet NaN, infinities stay infinities.
	 * - Magnitudes that round above 65504 become infinity.
	 * - Magnitudes that are too small for a half denormal (including all float denormals) become
	 *   a zero that keeps the sign of the input.
	 *
	 * @param value the float to narrow
	 * @return the 16 bit pattern of the resulting half
	 */
	static constexpr uint16_t float2half(float value)
	{
		const uint32_t bits = FloatToBits(value);
		const uint32_t sign = (bits >> 16) & 0x8000u;
		const uint32_t exponent = (bits >> 23) & 0xFFu;
		const uint32_t mantissa = bits & 0x007FFFFFu;

		if (exponent == 0xFFu)
		{
			// INF keeps an empty mantissa, every NaN is stored with the quiet bit set
			return static_cast<uint16_t>(sign | 0x7C00u | (mantissa != 0 ? 0x0200u : 0u));
		}

		const int new_exp = static_cast<int>(exponent) - 127;
		if (new_exp < -25)
		{
			// Below half of the smallest half denormal (2^-24); also covers float zero and float denormals
			return static_cast<uint16_t>(sign);
		}
		if (new_exp > 15)
		{
			// Too large for a finite half
			return static_cast<uint16_t>(sign | 0x7C00u);
		}

		uint32_t half = 0;      // exponent and mantissa bits of the result before rounding
		uint32_t remainder = 0; // bits that were cut off
		uint32_t halfway = 0;   // value of the cut off bits that is exactly half a unit in the last place
		if (new_exp < -14)
		{
			// Denormal half: the implicit leading one becomes an explicit mantissa bit
			const uint32_t full = mantissa | 0x00800000u;
			const int shift = -1 - new_exp; // 14 (2^-15) up to 24 (2^-25)
			half = full >> shift;
			remainder = full & ((1u << shift) - 1u);
			halfway = 1u << (shift - 1);
		}
		else
		{
			// Normal half: rebias the exponent and keep the upper 10 mantissa bits
			half = (static_cast<uint32_t>(new_exp + 15) << 10) | (mantissa >> 13);
			remainder = mantissa & 0x1FFFu;
			halfway = 0x1000u;
		}

		// Round to nearest, ties to even. A carry out of the mantissa increments the exponent, which
		// correctly yields the smallest normal number respectively infinity.
		if (remainder > halfway || (remainder == halfway && (half & 1u) != 0))
		{
			++half;
		}
		return static_cast<uint16_t>(sign | half);
	}

public:
	/// Creates a half with the value +0.
	constexpr float16() : m_data(0) {}
	constexpr float16(const float16&) = default;
	constexpr float16(float16&&) = default;
	// The user declared move constructor suppresses the implicit copy assignment operator, and a
	// deleted copy assignment would win against the template operator= below. Declare both explicitly.
	float16& operator=(const float16&) = default;
	float16& operator=(float16&&) = default;

	/// Creates a half from any value that is convertible to float.
	template<typename T, std::enable_if_t<std::is_convertible_v<T, float>, bool> = true>
	constexpr float16(const T& other) : m_data(float2half(static_cast<float>(other))) {}

	/**
	 * Creates a half from its raw fields.
	 * @param mantissa only the lowest 10 bits are used
	 * @param exponent biased exponent, only the lowest 5 bits are used
	 * @param sign only the lowest bit is used, 1 means negative
	 */
	constexpr float16(uint16_t mantissa, uint16_t exponent, uint16_t sign)
		: m_data(static_cast<uint16_t>(((sign & 0x1u) << 15) | ((exponent & 0x1Fu) << 10) | (mantissa & 0x3FFu)))
	{}

	//region operators
	// Operator +=, -=, *=, /= and +, -, *, /
	// Both forms convert the right hand side to float so that they produce identical results.
#define BINARY_ARITHMETIC_OPERATOR(OP) \
	template<typename T, std::enable_if_t<std::is_convertible_v<T, float>, bool> = true> \
	constexpr float16& operator OP##=(const T& rhs) { \
		return *this = operator float() OP static_cast<float>(rhs); \
	} \
	template<typename T, std::enable_if_t<std::is_convertible_v<T, float>, bool> = true> \
	[[nodiscard]] constexpr float16 operator OP(const T& rhs) const { \
		return { operator float() OP static_cast<float>(rhs) }; \
	}
	BINARY_ARITHMETIC_OPERATOR(+)
	BINARY_ARITHMETIC_OPERATOR(-)
	BINARY_ARITHMETIC_OPERATOR(*)
	BINARY_ARITHMETIC_OPERATOR(/)
#undef BINARY_ARITHMETIC_OPERATOR

	// Operator ++, --
	constexpr float16& operator++()
	{
		return *this += 1;
	}

	constexpr float16 operator++(int)
	{
		float16 ret(*this);
		operator++();
		return ret;
	}

	constexpr float16& operator--()
	{
		return *this -= 1;
	}

	constexpr float16 operator--(int)
	{
		float16 ret(*this);
		operator--();
		return ret;
	}

	/// Assigns any value that is convertible to float.
	template<typename T, std::enable_if_t<std::is_convertible_v<T, float>, bool> = true>
	constexpr float16& operator=(const T& rhs)
	{
		m_data = float2half(static_cast<float>(rhs));
		return *this;
	}

#if defined(__cpp_impl_three_way_comparison)
	/// Compares the float values of both operands.
	template<typename T, std::enable_if_t<std::is_convertible_v<T, float>, bool> = true>
	[[nodiscard]] constexpr auto operator<=>(const T& other) const { return operator float() <=> static_cast<float>(other); }
#endif
	// Without three-way comparison support the built-in float comparisons are used through operator float().
	//endregion

	/// Widens the half to float. Every half value is exactly representable as float.
	[[nodiscard]] constexpr operator float() const
	{
		const uint32_t sign = static_cast<uint32_t>(m_data & 0x8000u) << 16;
		const uint32_t exponent = (m_data >> 10) & 0x1Fu;
		const uint32_t mantissa = m_data & 0x03FFu;

		if (exponent == 0)
		{
			if (mantissa == 0)
			{
				return BitsToFloat(sign); // signed zero
			}
			// Denormal half: mantissa * 2^-24, which is a normal float and therefore exact
			const float magnitude = static_cast<float>(mantissa) * (1.0f / 16777216.0f);
			return (sign != 0) ? -magnitude : magnitude;
		}
		if (exponent == 31)
		{
			// INF (mantissa == 0) or NaN; NaN keeps its payload and is made quiet
			const uint32_t payload = (mantissa != 0) ? (0x00400000u | (mantissa << 13)) : 0u;
			return BitsToFloat(sign | 0x7F800000u | payload);
		}
		// Normal number: rebias the exponent (127 - 15 = 112) and widen the mantissa
		return BitsToFloat(sign | ((exponent + 112u) << 23) | (mantissa << 13));
	}

	// std::hash needs access to the raw bits
	template<typename Key> friend
	struct std::hash;
};
namespace std
{
	/**
	 * Hash support for the software float16, based on the raw 16 bit pattern.
	 *
	 * Positive and negative zero compare equal once converted to float, therefore both are
	 * hashed as positive zero so that equal keys always produce equal hashes.
	 */
	template<>
	struct hash<float16>
	{
		std::size_t operator()(const float16& key) const
		{
			// 0x8000 is the bit pattern of -0
			const uint16_t bits = (key.m_data == 0x8000u) ? uint16_t{0} : key.m_data;
			return hash<uint16_t>()(bits);
		}
	};
}
#endif
#undef __has_keyword
namespace std
{
	/**
	 * Numeric limits of the IEEE 754 binary16 format (1 sign bit, 5 exponent bits, 10 stored
	 * mantissa bits). All values are produced from their bit pattern, so the specialization works
	 * for a compiler provided half type as well as for the software float16 class.
	 */
	template <>
	class numeric_limits<float16>
	{
		/// Creates a float16 from its raw bit pattern without violating aliasing rules.
		static constexpr float16 FromBits(uint16_t bits) noexcept
		{
			static_assert(sizeof(float16) == sizeof(uint16_t), "float16 must be exactly 16 bits wide");
#if defined(__cpp_lib_bit_cast)
			return std::bit_cast<float16>(bits);
#else
			return __builtin_bit_cast(float16, bits); // compiler builtin used when std::bit_cast is unavailable
#endif
		}

	public:
		// General -- meaningful for all specializations.
		static constexpr bool is_specialized = true;
		static constexpr float16 min() noexcept { return FromBits(0x0400); }    // 2^-14, smallest positive normal value
		static constexpr float16 max() noexcept { return FromBits(0x7bff); }    // 65504
		static constexpr float16 lowest() noexcept { return FromBits(0xfbff); } // -65504
		static constexpr int radix = 2;
		static constexpr int digits = 11;      // 10 stored mantissa bits + the implicit leading bit
		static constexpr int digits10 = 3;     // floor((digits - 1) * log10(2))
		static constexpr int max_digits10 = 5; // ceil(digits * log10(2) + 1)
		static constexpr bool is_signed = true;
		static constexpr bool is_integer = false;
		static constexpr bool is_exact = false;
		static constexpr bool traps = false;
		static constexpr bool is_modulo = false;
		static constexpr bool is_bounded = true;

		// Floating point specific.
		static constexpr float16 epsilon() noexcept { return FromBits(0x1400); }     // 2^-10, distance from 1.0 to the next value
		static constexpr float16 round_error() noexcept { return FromBits(0x3800); } // 0.5 units in the last place
		static constexpr int min_exponent10 = -4; // 10^-4 is the smallest power of ten that is a normal value
		static constexpr int max_exponent10 = 4;  // 10^4 is the largest power of ten that is finite
		static constexpr int min_exponent = -13;  // 2^(min_exponent - 1) == min()
		static constexpr int max_exponent = 16;   // 2^(max_exponent - 1) is the largest finite power of two
		static constexpr bool has_infinity = true;
		static constexpr bool has_quiet_NaN = true;
		static constexpr bool has_signaling_NaN = true;
		static constexpr bool is_iec559 = false;
		static constexpr float_denorm_style has_denorm = denorm_present;
		static constexpr bool has_denorm_loss = false;
		static constexpr bool tinyness_before = false;
		static constexpr float_round_style round_style = round_to_nearest;
		static constexpr float16 denorm_min() noexcept { return FromBits(0x0001); }    // 2^-24, smallest positive denormal
		static constexpr float16 infinity() noexcept { return FromBits(0x7c00); }
		static constexpr float16 quiet_NaN() noexcept { return FromBits(0x7e00); }     // most significant mantissa bit set
		static constexpr float16 signaling_NaN() noexcept { return FromBits(0x7c01); } // most significant mantissa bit clear
	};
}
// Convenience aliases for the half precision type selected above.
using fp16 = float16;
using half = float16;