Add running average

This commit is contained in:
Georg Hagen
2025-05-13 17:37:20 +02:00
parent 9215a0a305
commit 28bf76b792
2 changed files with 277 additions and 0 deletions

View File

@@ -0,0 +1,97 @@
/*
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at https://mozilla.org/MPL/2.0/.
*/
#pragma once
#include "Data/Containers/RingBuffer.hpp"
#include <type_traits>
#include <concepts>
namespace OpenVulkano::Math
{
template<typename T>
concept SupportsRunningAverage = requires(T a, T b, int64_t size)
{
{ a += b } -> std::same_as<T&>;
{ a -= b } -> std::same_as<T&>;
{ a * size } -> std::convertible_to<T>;
{ a / size } -> std::convertible_to<T>;
};
template<typename T,
size_t SIZE = RING_BUFFER_DYNAMIC,
bool RECALC = std::is_floating_point_v<T>,
size_t RECALC_OP_COUNT = 100> requires SupportsRunningAverage<T>
class RunningAverage final
{
struct Empty {};
RingBuffer<T, SIZE> m_buffer;
T m_value;
[[no_unique_address]] std::conditional_t<RECALC, size_t, Empty> m_operations;
public:
RunningAverage(size_t size) requires (std::is_default_constructible_v<T> && SIZE == RING_BUFFER_DYNAMIC)
: m_buffer(size), m_value(T() * static_cast<int64_t>(size))
{
assert(size > 0);
m_buffer.Fill(T());
if constexpr (RECALC) m_operations = 0;
}
RunningAverage(size_t size, const T& def) requires (SIZE == RING_BUFFER_DYNAMIC)
: m_buffer(size), m_value(def * static_cast<int64_t>(size))
{
assert(size > 0);
m_buffer.Fill(def);
if constexpr (RECALC) m_operations = 0;
}
RunningAverage() requires (std::is_default_constructible_v<T> && SIZE != RING_BUFFER_DYNAMIC)
: m_value(T() * static_cast<int64_t>(SIZE))
{
static_assert(SIZE > 0);
m_buffer.Fill(T());
if constexpr (RECALC) m_operations = 0;
}
RunningAverage(const T& def) requires (SIZE != RING_BUFFER_DYNAMIC)
: m_value(def * static_cast<int64_t>(SIZE))
{
static_assert(SIZE > 0);
m_buffer.Fill(def);
if constexpr (RECALC) m_operations = 0;
}
void Push(const T& value)
{
m_value -= m_buffer.PushAndOverwrite(value);
m_value += value;
if constexpr (RECALC)
{
if (++m_operations >= RECALC_OP_COUNT)
{
m_operations = 0;
Recalculate();
}
}
}
[[nodiscard]] T Get() const { return m_value / static_cast<int64_t>(m_buffer.Size()); }
void Recalculate()
{
m_value = m_buffer[0];
auto it = m_buffer.cbegin();
it++;
for(auto end = m_buffer.cend(); it != end; it++)
{
m_value += *it;
}
}
};
}

View File

@@ -0,0 +1,180 @@
/*
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at https://mozilla.org/MPL/2.0/.
*/
#include <catch2/catch_all.hpp>
#include "Math/RunningAverage.hpp"
#include "Math/Math.hpp"
using namespace OpenVulkano::Math;
TEST_CASE("RunningAverage Basic Functionality - Integer Type", "[RunningAverage]")
{
SECTION("Dynamic Size Constructor with Default Value")
{
RunningAverage<int> avg(5);
REQUIRE(avg.Get() == 0);
}
SECTION("Dynamic Size Constructor with Custom Default")
{
RunningAverage<int> avg(5, 10);
REQUIRE(avg.Get() == 10);
}
SECTION("Static Size Constructor with Default Value")
{
RunningAverage<int, 5> avg;
REQUIRE(avg.Get() == 0);
}
SECTION("Static Size Constructor with Custom Default")
{
RunningAverage<int, 5> avg(10);
REQUIRE(avg.Get() == 10);
}
SECTION("Push Single Value")
{
RunningAverage<int> avg(5, 0);
avg.Push(5);
REQUIRE(avg.Get() == 1); // (0+0+0+0+5)/5 = 1
}
SECTION("Push Multiple Values")
{
RunningAverage<int> avg(4, 0);
avg.Push(4);
avg.Push(8);
avg.Push(12);
avg.Push(16);
REQUIRE(avg.Get() == 10); // (4+8+12+16)/4 = 10
}
SECTION("Push with Overwrite")
{
RunningAverage<int> avg(3, 0);
avg.Push(3);
avg.Push(6);
avg.Push(9);
REQUIRE(avg.Get() == 6); // (3+6+9)/3 = 6
avg.Push(12); // Overwrite 3
REQUIRE(avg.Get() == 9); // (6+9+12)/3 = 9
}
}
TEST_CASE("RunningAverage with Floating Point Types", "[RunningAverage]")
{
SECTION("Basic Floating Point Operation")
{
RunningAverage<double> avg(3, 0.0);
avg.Push(1.5);
avg.Push(2.5);
avg.Push(3.5);
REQUIRE(avg.Get() == Catch::Approx(2.5)); // (1.5+2.5+3.5)/3 = 2.5
}
SECTION("Recalculation with Floating Point")
{
// Set recalculation to happen after 3 operations
RunningAverage<double, std::numeric_limits<size_t>::max(), true, 3> avg(3, 0.0);
avg.Push(1.0);
avg.Push(2.0);
avg.Push(3.0); // This should trigger recalculation
REQUIRE(avg.Get() == Catch::Approx(2.0)); // (1.0+2.0+3.0)/3 = 2.0
// Push another value to check that recalculation didn't affect the result
avg.Push(4.0);
REQUIRE(avg.Get() == Catch::Approx(3.0)); // (2.0+3.0+4.0)/3 = 3.0
}
SECTION("Manual Recalculation")
{
RunningAverage<double> avg(3, 0.0);
avg.Push(1.5);
avg.Push(2.5);
avg.Push(3.5);
// Force recalculation and verify it doesn't change the result
avg.Recalculate();
REQUIRE(avg.Get() == Catch::Approx(2.5));
}
}
TEST_CASE("RunningAverage Edge Cases", "[RunningAverage]")
{
SECTION("Large Values")
{
RunningAverage<int> avg(3, 0);
avg.Push(1000000);
avg.Push(2000000);
avg.Push(3000000);
REQUIRE(avg.Get() == 2000000);
}
SECTION("Negative Values")
{
RunningAverage<int> avg(3, 0);
avg.Push(-10);
avg.Push(-20);
avg.Push(-30);
REQUIRE(avg.Get() == -20);
}
SECTION("Mixed Sign Values")
{
RunningAverage<int> avg(4, 0);
avg.Push(-10);
avg.Push(10);
avg.Push(-10);
avg.Push(10);
REQUIRE(avg.Get() == 0);
}
}
TEST_CASE("Custom Types with RunningAverage", "[RunningAverage]")
{
SECTION("Vector Type Average")
{
RunningAverage<Vector2f> avg(3, Vector2f(0, 0));
avg.Push(Vector2f(1, 2));
avg.Push(Vector2f(3, 4));
avg.Push(Vector2f(5, 6));
Vector2f result = avg.Get();
REQUIRE(result.x == Catch::Approx(3.0f));
REQUIRE(result.y == Catch::Approx(4.0f));
}
}
// This test specifically checks the correctness of recalculation for floating point
TEST_CASE("Precision with Floating Point Types", "[RunningAverage]")
{
SECTION("Accumulation Error Correction")
{
// Create a running average that would normally suffer from floating point errors
// Force recalculation after 20 operations
RunningAverage<double, std::numeric_limits<size_t>::max(), true, 20> avg(5, 0.0);
// Push values that would cause floating point rounding errors
for (int i = 0; i < 100; i++) {
avg.Push(0.1); // 0.1 cannot be represented exactly in binary floating point
}
// Check that the result is still accurate due to periodic recalculation
REQUIRE(avg.Get() == Catch::Approx(0.1).epsilon(0.0001));
}
}
TEST_CASE("Class size", "[RunningAverage]")
{
REQUIRE(sizeof(RunningAverage<int32_t>) == 40);
REQUIRE(sizeof(RunningAverage<int32_t, ::OpenVulkano::RING_BUFFER_DYNAMIC, true>) == 48);
REQUIRE(sizeof(RunningAverage<float>) == 48);
REQUIRE(sizeof(RunningAverage<int32_t, 5>) == 48);
REQUIRE(sizeof(RunningAverage<Vector2f>) == 40);
}