Mirror of https://github.com/Akkudoktor-EOS/EOS.git, synced 2026-02-24 01:46:21 +00:00
Add database support for measurements and historic prediction data. (#848)
The database supports backend selection, compression, incremental data load, automatic saving of data to storage, and automatic vacuum and compaction. Make SQLite3 and LMDB database backends available. Update tests for the new interface conventions regarding data sequences, data containers, and data providers; this includes the measurements provider and the prediction providers. Add database documentation.

The change also includes several bug fixes that are not directly related to the database implementation but are necessary to keep EOS running properly and to test and document the changes.

* fix: config eos test setup
  Make the config_eos fixture generate a new instance of the config_eos singleton. Use the correct env names to set up the data folder path.

* fix: startup with no config
  Make cache and measurements complain about a missing data path configuration but do not bail out.

* fix: SoC data preparation and usage for genetic optimization
  Search for SoC measurements 48 hours around the optimization start time. Only clamp SoC to the maximum in the battery device simulation.

* fix: dashboard bailout on zero-value solution display
  Do not use zero values to calculate the chart value adjustment for display.

* fix: openapi generation script
  Make the script also replace data_folder_path and data_output_path to hide real (test) environment paths.

* feat: add make repeated task function
  make_repeated_task wraps a function so that it is repeated cyclically.

* chore: remove index-based data sequence access
  Index-based data sequence access does not make sense because the sequence can be backed by the database. The sequence is now purely time series data.

* chore: refactor EOS startup to avoid initialization at module import
  Avoid initialization at module import, especially of the EOS configuration. Config mutation, singleton initialization, logging setup, argparse parsing, background task definitions that depend on the config, and environment-dependent behavior are now done at function startup.

* chore: introduce retention manager
  A single long-running background task now owns the scheduling of all periodic server-maintenance jobs (cache cleanup, DB autosave, …).

* chore: canonicalize timezone names for UTC
  Timezone names that are semantically identical to UTC are canonicalized to UTC.

* chore: extend config file migration for default value handling
  Extend the config file migration to handle values that are None or missing; these now invoke default value generation in the new config file. Also adapt the test to handle this situation.

* chore: extend datetime util test cases

* chore: make version test check for untracked files
  Check for files that are not tracked by git. The version calculation will be wrong if these files are not committed.

* chore: bump pandas to 3.0.0
  pandas 3.0 now infers the appropriate resolution (a.k.a. unit) for the output dtype, which may become datetime64[us] (before, it was ns). Numeric dtype detection is also stricter, which requires a different way to detect numerics.

* chore: bump pydantic-settings to 2.12.0
  pydantic-settings 2.12.0 behaves differently under pytest. The tests were adapted and a workaround was introduced. ConfigEOS was also adapted to allow fine-grained initialization control, so that individual settings sources, such as file settings, can be switched off during tests.

* chore: remove scikit-learn from the dependencies
  scikit-learn is not strictly necessary as long as we have scipy.

* chore: add documentation mode guarding for Sphinx autosummary
  Sphinx autosummary executes functions. Prevent exceptions when running in pure documentation mode.

* chore: adapt docker-build CI workflow to stricter GitHub handling

Signed-off-by: Bobby Noelte <b0661n0e17e@gmail.com>
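The repeated-task helper described above is essentially a timer loop around a callable. A minimal sketch, assuming a `make_repeated_task(task, interval_sec)` signature; the real helper in the EOS codebase may differ:

```python
import threading
from typing import Callable


def make_repeated_task(task: Callable[[], None], interval_sec: float) -> Callable[[], threading.Event]:
    """Wrap `task` so it is repeated cyclically every `interval_sec` seconds (sketch)."""

    def start() -> threading.Event:
        stop = threading.Event()

        def _loop() -> None:
            # Event.wait() returns True once `stop` is set, which ends the cycle.
            while not stop.wait(interval_sec):
                task()

        threading.Thread(target=_loop, daemon=True).start()
        return stop

    return start
```

A retention manager as described would then start one such task per maintenance job, e.g. `stop = make_repeated_task(cache_cleanup, 60.0)()`.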
@@ -22,7 +22,8 @@ from _pytest.logging import LogCaptureFixture
from loguru import logger
from xprocess import ProcessStarter, XProcess

from akkudoktoreos.config.config import ConfigEOS, get_config
from akkudoktoreos.config.config import ConfigEOS
from akkudoktoreos.core.coreabc import get_config, get_prediction, singletons_init
from akkudoktoreos.core.version import _version_hash, version
from akkudoktoreos.server.server import get_default_host

@@ -134,8 +135,6 @@ def is_ci() -> bool:

@pytest.fixture
def prediction_eos():
    from akkudoktoreos.prediction.prediction import get_prediction

    return get_prediction()

@@ -172,6 +171,37 @@ def cfg_non_existent(request):
    )


# ------------------------------------
# Provide pytest EOS config management
# ------------------------------------


@pytest.fixture
def config_default_dirs(tmpdir):
    """Fixture that provides a list of directories to be used as config dir."""
    tmp_user_home_dir = Path(tmpdir)

    # Default config directory from platform user config directory
    config_default_dir_user = tmp_user_home_dir / "config"

    # Default config directory from current working directory
    config_default_dir_cwd = tmp_user_home_dir / "cwd"
    config_default_dir_cwd.mkdir()

    # Default config directory from default config file
    config_default_dir_default = Path(__file__).parent.parent.joinpath("src/akkudoktoreos/data")

    # Default data directory from platform user data directory
    data_default_dir_user = tmp_user_home_dir

    return (
        config_default_dir_user,
        config_default_dir_cwd,
        config_default_dir_default,
        data_default_dir_user,
    )


@pytest.fixture(autouse=True)
def user_cwd(config_default_dirs):
    """Patch cwd provided by module pathlib.Path.cwd."""

@@ -203,64 +233,102 @@ def user_data_dir(config_default_dirs):


@pytest.fixture
def config_eos(
def config_eos_factory(
    disable_debug_logging,
    user_config_dir,
    user_data_dir,
    user_cwd,
    config_default_dirs,
    monkeypatch,
) -> ConfigEOS:
    """Fixture to reset EOS config to default values."""
    monkeypatch.setenv(
        "EOS_CONFIG__DATA_CACHE_SUBPATH", str(config_default_dirs[-1] / "data/cache")
    )
    monkeypatch.setenv(
        "EOS_CONFIG__DATA_OUTPUT_SUBPATH", str(config_default_dirs[-1] / "data/output")
    )
    config_file = config_default_dirs[0] / ConfigEOS.CONFIG_FILE_NAME
    config_file_cwd = config_default_dirs[1] / ConfigEOS.CONFIG_FILE_NAME
    assert not config_file.exists()
    assert not config_file_cwd.exists()
):
    """Factory fixture for creating a fully initialized ``ConfigEOS`` instance.

    config_eos = get_config()
    config_eos.reset_settings()
    assert config_file == config_eos.general.config_file_path
    assert config_file.exists()
    assert not config_file_cwd.exists()
    Returns a callable that creates a ``ConfigEOS`` singleton with a controlled
    filesystem layout and environment variables. Allows tests to customize which
    pydantic-settings sources are enabled (init, env, dotenv, file, secrets).

    # Check user data directory paths (config_default_dirs[-1] == data_default_dir_user)
    assert config_default_dirs[-1] / "data" == config_eos.general.data_folder_path
    assert config_default_dirs[-1] / "data/cache" == config_eos.cache.path()
    assert config_default_dirs[-1] / "data/output" == config_eos.general.data_output_path
    assert config_default_dirs[-1] / "data/output/eos.log" == config_eos.logging.file_path
    return config_eos
    The factory ensures:
    - Required directories exist
    - No pre-existing config files are present
    - Settings are reloaded to respect test-specific configuration
    - Dependent singletons are initialized

    The singleton instance is reset during fixture teardown.
    """
    def _create(init: dict[str, bool] | None = None) -> ConfigEOS:
        init = init or {
            "with_init_settings": True,
            "with_env_settings": True,
            "with_dotenv_settings": False,
            "with_file_settings": False,
            "with_file_secret_settings": False,
        }

        # reset singleton before touching env or config
        ConfigEOS.reset_instance()
        ConfigEOS._init_config_eos = {
            "with_init_settings": True,
            "with_env_settings": True,
            "with_dotenv_settings": True,
            "with_file_settings": True,
            "with_file_secret_settings": True,
        }
        ConfigEOS._config_file_path = None
        ConfigEOS._force_documentation_mode = False

        data_folder_path = config_default_dirs[-1] / "data"
        data_folder_path.mkdir(exist_ok=True)

        config_dir = config_default_dirs[0]
        config_dir.mkdir(exist_ok=True)

        cwd = config_default_dirs[1]
        cwd.mkdir(exist_ok=True)

        monkeypatch.setenv("EOS_CONFIG_DIR", str(config_dir))
        monkeypatch.setenv("EOS_GENERAL__DATA_FOLDER_PATH", str(data_folder_path))
        monkeypatch.setenv("EOS_GENERAL__DATA_CACHE_SUBPATH", "cache")
        monkeypatch.setenv("EOS_GENERAL__DATA_OUTPUT_SUBPATH", "output")

        # Ensure no config files exist
        config_file = config_dir / ConfigEOS.CONFIG_FILE_NAME
        config_file_cwd = cwd / ConfigEOS.CONFIG_FILE_NAME
        assert not config_file.exists()
        assert not config_file_cwd.exists()

        config_eos = get_config(init=init)
        # Ensure newly created configurations are respected
        # Note: Workaround for pydantic_settings and pytest
        config_eos.reset_settings()

        # Check user data directory paths (config_default_dirs[-1] == data_default_dir_user)
        assert config_eos.general.data_folder_path == data_folder_path
        assert config_eos.general.data_output_subpath == Path("output")
        assert config_eos.cache.subpath == "cache"
        assert config_eos.cache.path() == config_default_dirs[-1] / "data/cache"
        assert config_eos.logging.file_path == config_default_dirs[-1] / "data/output/eos.log"

        # Check config file path
        assert str(config_eos.general.config_file_path) == str(config_file)
        assert config_file.exists()
        assert not config_file_cwd.exists()

        # Initialize all other singletons (if not already initialized)
        singletons_init()

        return config_eos

    yield _create

    # teardown - final safety net
    ConfigEOS.reset_instance()


@pytest.fixture
def config_default_dirs(tmpdir):
    """Fixture that provides a list of directories to be used as config dir."""
    tmp_user_home_dir = Path(tmpdir)

    # Default config directory from platform user config directory
    config_default_dir_user = tmp_user_home_dir / "config"

    # Default config directory from current working directory
    config_default_dir_cwd = tmp_user_home_dir / "cwd"
    config_default_dir_cwd.mkdir()

    # Default config directory from default config file
    config_default_dir_default = Path(__file__).parent.parent.joinpath("src/akkudoktoreos/data")

    # Default data directory from platform user data directory
    data_default_dir_user = tmp_user_home_dir

    return (
        config_default_dir_user,
        config_default_dir_cwd,
        config_default_dir_default,
        data_default_dir_user,
    )
def config_eos(config_eos_factory) -> ConfigEOS:
    """Fixture to reset EOS config to default values."""
    config_eos = config_eos_factory()
    return config_eos


# ------------------------------------
@@ -405,7 +473,11 @@ def server_base(
    Yields:
        dict[str, str]: A dictionary containing:
            - "server" (str): URL of the server.
            - "port": port
            - "eosdash_server": eosdash_server
            - "eosdash_port": eosdash_port
            - "eos_dir" (str): Path to the temporary EOS_DIR.
            - "timeout": server_timeout
    """
    host = get_default_host()
    port = 8503
@@ -427,12 +499,14 @@ def server_base(

    eos_tmp_dir = tempfile.TemporaryDirectory()
    eos_dir = str(eos_tmp_dir.name)
    eos_general_data_folder_path = str(Path(eos_dir) / "data")

    class Starter(ProcessStarter):
        # Set environment for server run
        env = os.environ.copy()
        env["EOS_DIR"] = eos_dir
        env["EOS_CONFIG_DIR"] = eos_dir
        env["EOS_GENERAL__DATA_FOLDER_PATH"] = eos_general_data_folder_path
        if extra_env:
            env.update(extra_env)
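
Since `config_eos_factory` yields a callable, a test can control which pydantic-settings sources are active. A sketch of such a test; the init-flag names are taken from `_create()` above, while the test itself is illustrative:

```python
def test_config_with_file_settings(config_eos_factory):
    # Opt in to the config-file settings source in addition to the defaults
    # used by the plain `config_eos` fixture (sketch; flag names from _create()).
    config_eos = config_eos_factory(
        init={
            "with_init_settings": True,
            "with_env_settings": True,
            "with_dotenv_settings": False,
            "with_file_settings": True,
            "with_file_secret_settings": False,
        }
    )
    # The factory still guarantees a fully initialized singleton.
    assert config_eos.general.config_file_path is not None
```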

@@ -12,13 +12,11 @@ from typing import Any
import numpy as np
from loguru import logger

from akkudoktoreos.config.config import get_config
from akkudoktoreos.core.ems import get_ems
from akkudoktoreos.core.coreabc import get_config, get_ems, get_prediction
from akkudoktoreos.core.emsettings import EnergyManagementMode
from akkudoktoreos.optimization.genetic.geneticparams import (
    GeneticOptimizationParameters,
)
from akkudoktoreos.prediction.prediction import get_prediction
from akkudoktoreos.utils.datetimeutil import to_datetime

config_eos = get_config()

@@ -6,8 +6,7 @@ import pstats
import sys
import time

from akkudoktoreos.config.config import get_config
from akkudoktoreos.prediction.prediction import get_prediction
from akkudoktoreos.core.coreabc import get_config, get_prediction

config_eos = get_config()
prediction_eos = get_prediction()

@@ -12,11 +12,11 @@ import pytest
from akkudoktoreos.adapter.adapter import (
    Adapter,
    AdapterCommonSettings,
    get_adapter,
)
from akkudoktoreos.adapter.adapterabc import AdapterContainer
from akkudoktoreos.adapter.homeassistant import HomeAssistantAdapter
from akkudoktoreos.adapter.nodered import NodeREDAdapter
from akkudoktoreos.core.coreabc import get_adapter

# ---------- Typed aliases for fixtures ----------
AdapterFixture: TypeAlias = Adapter

@@ -167,12 +167,21 @@ def temp_store_file():


@pytest.fixture
def cache_file_store(temp_store_file):
def cache_file_store(temp_store_file, monkeypatch):
    """A pytest fixture that creates a new CacheFileStore instance for testing."""
    cache = CacheFileStore()
    cache._store_file = temp_store_file

    # Patch the _cache_file method to return the temp file
    monkeypatch.setattr(
        cache,
        "_store_file",
        lambda: temp_store_file,
    )

    cache.clear(clear_all=True)
    assert len(cache._store) == 0

    return cache


@@ -481,7 +490,7 @@ class TestCacheFileStore:
        cache_file_store.save_store()

        # Verify the file content
        with cache_file_store._store_file.open("r", encoding="utf-8", newline=None) as f:
        with cache_file_store._store_file().open("r", encoding="utf-8", newline=None) as f:
            store_loaded = json.load(f)
            assert "test_key" in store_loaded
            assert store_loaded["test_key"]["cache_file"] == "cache_file_path"
@@ -501,7 +510,7 @@ class TestCacheFileStore:
                "ttl_duration": None,
            }
        }
        with cache_file_store._store_file.open("w", encoding="utf-8", newline="\n") as f:
        with cache_file_store._store_file().open("w", encoding="utf-8", newline="\n") as f:
            json.dump(cache_record, f, indent=4)

        # Mock the open function to return a MagicMock for the cache file

@@ -109,17 +109,18 @@ def test_config_ipaddress(monkeypatch, config_eos):
    assert config_eos.server.host == "localhost"


def test_singleton_behavior(config_eos, config_default_dirs):
def test_singleton_behavior(config_eos, config_default_dirs, monkeypatch):
    """Test that ConfigEOS behaves as a singleton."""
    initial_cfg_file = config_eos.general.config_file_path
    with patch(
        "akkudoktoreos.config.config.user_config_dir", return_value=str(config_default_dirs[0])
    ):
        instance1 = ConfigEOS()
        instance2 = ConfigEOS()
        assert instance1 is config_eos
    config_eos.reset_instance()

    monkeypatch.setenv("EOS_CONFIG_DIR", str(config_default_dirs[0]))

    instance1 = ConfigEOS()
    instance2 = ConfigEOS()

    assert instance1 is not config_eos
    assert instance1 is instance2
    assert instance1.general.config_file_path == initial_cfg_file
    assert instance1._config_file_path == instance2._config_file_path


def test_config_file_priority(config_default_dirs):
@@ -169,17 +170,22 @@ def test_get_config_file_path(user_config_dir_patch, config_eos, config_default_
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_dir_path = Path(temp_dir)
        monkeypatch.setenv("EOS_DIR", str(temp_dir_path))
        monkeypatch.delenv("EOS_CONFIG_DIR", raising=False)
        assert config_eos._get_config_file_path() == (cfg_file(temp_dir_path), False)

        monkeypatch.setenv("EOS_CONFIG_DIR", "config")
        config_dir = temp_dir_path / "config"
        config_dir.mkdir(exist_ok=True)
        assert config_eos._get_config_file_path() == (
            cfg_file(temp_dir_path / "config"),
            cfg_file(config_dir),
            False,
        )

        monkeypatch.setenv("EOS_CONFIG_DIR", str(temp_dir_path / "config2"))
        config_dir = temp_dir_path / "config2"
        config_dir.mkdir(exist_ok=True)
        assert config_eos._get_config_file_path() == (
            cfg_file(temp_dir_path / "config2"),
            cfg_file(config_dir),
            False,
        )

@@ -188,8 +194,10 @@ def test_get_config_file_path(user_config_dir_patch, config_eos, config_default_
        assert config_eos._get_config_file_path() == (cfg_file(config_default_dir_user), False)

        monkeypatch.setenv("EOS_CONFIG_DIR", str(temp_dir_path / "config3"))
        config_dir = temp_dir_path / "config3"
        config_dir.mkdir(exist_ok=True)
        assert config_eos._get_config_file_path() == (
            cfg_file(temp_dir_path / "config3"),
            cfg_file(config_dir),
            False,
        )

@@ -199,7 +207,7 @@ def test_config_copy(config_eos, monkeypatch):
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_folder_path = Path(temp_dir)
        temp_config_file_path = temp_folder_path.joinpath(config_eos.CONFIG_FILE_NAME).resolve()
        monkeypatch.setenv(config_eos.EOS_DIR, str(temp_folder_path))
        monkeypatch.setenv("EOS_CONFIG_DIR", str(temp_folder_path))
        assert not temp_config_file_path.exists()
        with patch("akkudoktoreos.config.config.user_config_dir", return_value=temp_dir):
            assert config_eos._get_config_file_path() == (temp_config_file_path, False)

@@ -26,6 +26,9 @@ MIGRATION_PAIRS = [
    # (DIR_TESTDATA / "old_config_X.json", DIR_TESTDATA / "expected_config_X.json"),
]

# Any sentinel in expected data
_ANY_SENTINEL = "__ANY__"


def _dict_contains(superset: Any, subset: Any, path="") -> list[str]:
    """Recursively verify that all key-value pairs from a subset dictionary or list exist in a superset.
@@ -60,6 +63,9 @@ def _dict_contains(superset: Any, subset: Any, path="") -> list[str]:
            errors.extend(_dict_contains(superset[i], elem, f"{path}[{i}]" if path else f"[{i}]"))

    else:
        # "__ANY__" in expected means "accept whatever value the migration produces"
        if subset == _ANY_SENTINEL:
            return errors
        # Compare values (with numeric tolerance)
        if isinstance(subset, (int, float)) and isinstance(superset, (int, float)):
            if abs(float(subset) - float(superset)) > 1e-6:
@@ -162,6 +168,7 @@ class TestConfigMigration:
        assert backup_file.exists(), f"Backup file not created for {old_file.name}"

        # --- Compare migrated result with expected output ---
        old_data = json.loads(old_file.read_text(encoding="utf-8"))
        new_data = json.loads(working_file.read_text(encoding="utf-8"))
        expected_data = json.loads(expected_file.read_text(encoding="utf-8"))

@@ -202,6 +209,14 @@ class TestConfigMigration:
            # Verify the migrated value matches the expected one
            new_value = configmigrate._get_json_nested_value(new_data, new_path)
            if new_value != expected_value:
                # Check if this mapping uses _KEEP_DEFAULT and the old value was None/missing
                old_value = configmigrate._get_json_nested_value(old_data, old_path)
                keep_default = (
                    isinstance(mapping, tuple)
                    and configmigrate._KEEP_DEFAULT in mapping
                )
                if keep_default and old_value is None:
                    continue  # acceptable: old was None, new model keeps its default
                mismatched_values.append(
                    f"{old_path} → {new_path}: expected {expected_value!r}, got {new_value!r}"
                )
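
The `_KEEP_DEFAULT` branch above implies that a migration mapping entry may be a tuple carrying a sentinel next to the target path. A hypothetical entry to illustrate the shape; only the tuple-with-sentinel structure is taken from the test, the paths and table name are invented:

```python
# Sketch of what a migration table entry could look like. The real table lives
# in the configmigrate module; _KEEP_DEFAULT marks mappings whose new-config
# value falls back to the pydantic model default when the old value is None
# or missing.
MIGRATION_MAP = {
    # plain move: the old value is copied to the new path
    "server/port": "server/port",
    # move with fallback: if the old value is None/missing, keep the default
    "prediction/hours": ("prediction/hours", _KEEP_DEFAULT),
}
```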

tests/test_dataabccompact.py (new file, 1114 lines): diff suppressed because it is too large
tests/test_database.py (new file, 1148 lines): diff suppressed because it is too large
tests/test_databaseabc.py (new file, 888 lines):
@@ -0,0 +1,888 @@
from typing import Any, Iterator, Literal, Optional, Type, cast

import pytest
from numpydantic import NDArray, Shape
from pydantic import BaseModel, Field

from akkudoktoreos.core.databaseabc import (
    DATABASE_METADATA_KEY,
    DatabaseRecordProtocolMixin,
    DatabaseTimestamp,
    _DatabaseTimestampUnbound,
)
from akkudoktoreos.utils.datetimeutil import (
    DateTime,
    Duration,
    to_datetime,
    to_duration,
)

# ---------------------------------------------------------------------------
# Test record
# ---------------------------------------------------------------------------


class SampleRecord(BaseModel):
    date_time: Optional[DateTime] = Field(
        default=None, json_schema_extra={"description": "DateTime"}
    )
    value: Optional[float] = None

    def __getitem__(self, key: str) -> Any:
        if key == "date_time":
            return self.date_time
        if key == "value":
            return self.value
        assert key is None
        return None


# ---------------------------------------------------------------------------
# Fake database backend
# ---------------------------------------------------------------------------


class SampleDatabase:
    def __init__(self):
        self._data: dict[Optional[str], dict[bytes, bytes]] = {}
        self._metadata: Optional[bytes] = None
        self.is_open = True
        self.compression = False
        self.compression_level = 0
        self.storage_path = "/fake"

    # serialization (pass-through)

    def serialize_data(self, data: bytes) -> bytes:
        return data

    def deserialize_data(self, data: bytes) -> bytes:
        return data

    # metadata

    def set_metadata(self, metadata: Optional[bytes], *, namespace: Optional[str] = None) -> None:
        self._metadata = metadata

    def get_metadata(self, namespace: Optional[str] = None) -> Optional[bytes]:
        return self._metadata

    # write

    def save_records(
        self, records: list[tuple[bytes, bytes]], namespace: Optional[str] = None
    ) -> int:
        ns = self._data.setdefault(namespace, {})
        saved = 0
        for key, value in records:
            ns[key] = value
            saved += 1
        return saved

    def delete_records(
        self, keys: Iterator[bytes], namespace: Optional[str] = None
    ) -> int:
        ns_data = self._data.get(namespace, {})
        deleted = 0
        for key in keys:
            if key in ns_data:
                del ns_data[key]
                deleted += 1
        return deleted

    # read

    def iterate_records(
        self,
        start_key: Optional[bytes] = None,
        end_key: Optional[bytes] = None,
        namespace: Optional[str] = None,
        reverse: bool = False,
    ) -> Iterator[tuple[bytes, bytes]]:
        items = self._data.get(namespace, {})
        keys = sorted(items, reverse=reverse)
        for k in keys:
            if k == DATABASE_METADATA_KEY:
                continue
            if start_key and k < start_key:
                continue
            if end_key and k >= end_key:
                continue
            yield k, items[k]

    # stats

    def count_records(
        self,
        start_key: Optional[bytes] = None,
        end_key: Optional[bytes] = None,
        *,
        namespace: Optional[str] = None,
    ) -> int:
        items = self._data.get(namespace, {})
        count = 0
        for k in items:
            if k == DATABASE_METADATA_KEY:
                continue
            if start_key and k < start_key:
                continue
            if end_key and k >= end_key:
                continue
            count += 1
        return count

    def get_key_range(
        self, namespace: Optional[str] = None
    ) -> tuple[Optional[bytes], Optional[bytes]]:
        items = self._data.get(namespace, {})
        keys = sorted(k for k in items if k != DATABASE_METADATA_KEY)
        if not keys:
            return None, None
        return keys[0], keys[-1]

    def get_backend_stats(self, namespace: Optional[str] = None) -> dict:
        return {}

    def flush(self, namespace: Optional[str] = None) -> None:
        pass


# ---------------------------------------------------------------------------
# Concrete test sequence — minimal, no Pydantic / singleton overhead
# ---------------------------------------------------------------------------


class SampleSequence(DatabaseRecordProtocolMixin[SampleRecord]):
    """Minimal concrete implementation for unit-testing the mixin."""

    def __init__(self):
        self.records: list[SampleRecord] = []
        self._db_record_index: dict[DatabaseTimestamp, SampleRecord] = {}
        self._db_sorted_timestamps: list[DatabaseTimestamp] = []
        self._db_dirty_timestamps: set[DatabaseTimestamp] = set()
        self._db_new_timestamps: set[DatabaseTimestamp] = set()
        self._db_deleted_timestamps: set[DatabaseTimestamp] = set()
        self._db_initialized: bool = True
        self._db_storage_initialized: bool = False
        self._db_metadata: Optional[dict] = None
        self._db_loaded_range = None
        from akkudoktoreos.core.databaseabc import DatabaseRecordProtocolLoadPhase
        self._db_load_phase = DatabaseRecordProtocolLoadPhase.NONE
        self._db_version: int = 1

        self.database = SampleDatabase()
        self.config = type(
            "Cfg",
            (),
            {
                "database": type(
                    "DBCfg",
                    (),
                    {
                        "auto_save": False,
                        "compression_level": 0,
                        "autosave_interval_sec": 10,
                        "initial_load_window_h": None,
                        "keep_duration_h": None,
                    },
                )()
            },
        )()

    @classmethod
    def record_class(cls) -> Type[SampleRecord]:
        return SampleRecord

    def db_namespace(self) -> str:
        return "test"

    @property
    def record_keys_writable(self) -> list[str]:
        """Return writable field names of SampleRecord.

        Required by _db_compact_tier which iterates record_keys_writable
        to decide which fields to resample. Must match exactly what
        key_to_array accepts — only 'value' here, not 'date_time'.
        """
        return ["value"]

    # Override key_to_array for the mixin tests — the full DataSequence
    # implementation lives in dataabc.py; here we provide a minimal version
    # that resamples the single `value` field to demonstrate compaction.
    def key_to_array(
        self,
        key: str,
        start_datetime: Optional[DateTime] = None,
        end_datetime: Optional[DateTime] = None,
        interval: Optional[Duration] = None,
        fill_method: Optional[str] = None,
        dropna: Optional[bool] = True,
        boundary: Literal["strict", "context"] = "context",
        align_to_interval: bool = False,
    ) -> NDArray[Shape["*"], Any]:
        import numpy as np
        import pandas as pd

        if interval is None:
            interval = to_duration("1 hour")

        dates = []
        values = []
        for record in self.records:
            if record.date_time is None:
                continue
            ts = DatabaseTimestamp.from_datetime(record.date_time)
            if start_datetime and DatabaseTimestamp.from_datetime(start_datetime) > ts:
                continue
            if end_datetime and DatabaseTimestamp.from_datetime(end_datetime) <= ts:
                continue
            dates.append(record.date_time)
            values.append(getattr(record, key, None))

        if not dates:
            return np.array([])

        index = pd.to_datetime(dates, utc=True)
        series = pd.Series(values, index=index, dtype=float)
        freq = f"{int(interval.total_seconds())}s"
        origin = start_datetime if start_datetime else "start_day"
        resampled = series.resample(freq, origin=origin).mean().interpolate("time")

        if start_datetime is not None:
            resampled = resampled.truncate(before=start_datetime)
        if end_datetime is not None:
            resampled = resampled.truncate(after=end_datetime)

        return resampled.values


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _insert_records_every_n_minutes(
    seq: SampleSequence,
    base: DateTime,
    count: int,
    interval_minutes: int,
    value_fn=None,
) -> None:
    """Insert `count` records spaced `interval_minutes` apart starting at `base`."""
    for i in range(count):
        dt = base.add(minutes=i * interval_minutes)
        value = value_fn(i) if value_fn else float(i)
        seq.db_insert_record(SampleRecord(date_time=dt, value=value))
    seq.db_save_records()


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest.fixture
def seq():
    return SampleSequence()


@pytest.fixture
def seq_with_15min_data():
    """Sequence with 15-min records spanning 4 weeks, so both tiers have data."""
    s = SampleSequence()
    now = to_datetime().in_timezone("UTC")
    # 4 weeks × 7 days × 24 h × 4 records/h = 2688 records
    base = now.subtract(weeks=4)
    _insert_records_every_n_minutes(s, base, count=2688, interval_minutes=15)
    return s, now


@pytest.fixture
def seq_sparse():
    """Sequence with only 3 records spread over 4 weeks — sparse, no compaction benefit."""
    s = SampleSequence()
    now = to_datetime().in_timezone("UTC")
    base = now.subtract(weeks=4)
    for offset_days in [0, 14, 27]:
        dt = base.add(days=offset_days)
        s.db_insert_record(SampleRecord(date_time=dt, value=float(offset_days)))
    s.db_save_records()
    return s, now


# ---------------------------------------------------------------------------
# Existing tests (unchanged)
# ---------------------------------------------------------------------------


class TestDatabaseRecordProtocolMixin:

    @pytest.mark.parametrize(
        "start_str, value_count, interval_seconds",
        [
            ("2024-11-10 00:00:00", 24, 3600),
            ("2024-08-10 00:00:00", 24, 3600),
            ("2024-03-31 00:00:00", 24, 3600),
            ("2024-10-27 00:00:00", 24, 3600),
        ],
    )
    def test_db_generate_timestamps_utc_spacing(
        self, seq, start_str, value_count, interval_seconds
    ):
        start_dt = to_datetime(start_str, in_timezone="Europe/Berlin")
        assert start_dt.tz.name == "Europe/Berlin"

        db_start = DatabaseTimestamp.from_datetime(start_dt)
        generated = list(seq.db_generate_timestamps(db_start, value_count))

        assert len(generated) == value_count

        for db_dt in generated:
            dt = DatabaseTimestamp.to_datetime(db_dt)
            assert dt.tz.name == "UTC"

        assert len(generated) == len(set(generated)), "Duplicate UTC datetimes found"

        for i in range(1, len(generated)):
            last_dt = DatabaseTimestamp.to_datetime(generated[i - 1])
            current_dt = DatabaseTimestamp.to_datetime(generated[i])
            delta = (current_dt - last_dt).total_seconds()
            assert delta == interval_seconds, f"Spacing mismatch at index {i}: {delta}s"

    def test_insert_and_memory_range(self, seq):
        t0 = to_datetime()
        t1 = t0.add(hours=1)

        seq.db_insert_record(SampleRecord(date_time=t0, value=1))
        seq.db_insert_record(SampleRecord(date_time=t1, value=2))

        assert seq.records[0].date_time == t0
        assert seq.records[-1].date_time == t1
        assert len(seq.records) == 2

    def test_roundtrip_reload(self):
        seq = SampleSequence()
        t0 = to_datetime()
        t1 = t0.add(hours=1)

        seq.db_insert_record(SampleRecord(date_time=t0, value=1))
        seq.db_insert_record(SampleRecord(date_time=t1, value=2))
        assert seq.db_save_records() == 2

        db = seq.database
        seq2 = SampleSequence()
        seq2.database = db
        loaded = seq2.db_load_records()

        assert loaded == 2
        assert len(seq2.records) == 2

    def test_db_count_records(self, seq):
        t0 = to_datetime()
        seq.db_insert_record(SampleRecord(date_time=t0, value=1))
        assert seq.db_count_records() == 1
        seq.db_save_records()
        assert seq.db_count_records() == 1

    def test_delete_range(self, seq):
        base = to_datetime()
        for i in range(5):
            seq.db_insert_record(SampleRecord(date_time=base.add(minutes=i), value=i))

        db_start = DatabaseTimestamp.from_datetime(base.add(minutes=1))
        db_end = DatabaseTimestamp.from_datetime(base.add(minutes=4))
        deleted = seq.db_delete_records(start_timestamp=db_start, end_timestamp=db_end)

        assert deleted == 3
        assert [r.value for r in seq.records] == [0, 4]

    def test_db_count_records_memory_only_multiple(self):
        seq = SampleSequence()
        base = to_datetime()
        for i in range(3):
            seq.db_insert_record(SampleRecord(date_time=base.add(minutes=i), value=i))
        assert seq.db_count_records() == 3

    def test_db_count_records_memory_newer_than_db(self):
        seq = SampleSequence()
        base = to_datetime()
        seq.db_insert_record(SampleRecord(date_time=base, value=1))
        seq.db_save_records()
        seq.db_insert_record(SampleRecord(date_time=base.add(hours=1), value=2))
        seq.db_insert_record(SampleRecord(date_time=base.add(hours=2), value=3))
        assert seq.db_count_records() == 3

    def test_db_count_records_memory_older_than_db(self):
        seq = SampleSequence()
        base = to_datetime()
        seq.db_insert_record(SampleRecord(date_time=base.add(hours=1), value=2))
        seq.db_save_records()
        seq.db_insert_record(SampleRecord(date_time=base, value=1))
        assert seq.db_count_records() == 2

    def test_db_count_records_empty_everywhere(self):
        seq = SampleSequence()
        assert seq.db_count_records() == 0

    def test_metadata_not_counted(self, seq):
        seq.database._data.setdefault("test", {})[DATABASE_METADATA_KEY] = b"meta"
        assert seq.db_count_records() == 0

    def test_key_range_excludes_metadata(self, seq):
        ns = seq.db_namespace()
        seq.database._data.setdefault(ns, {})[DATABASE_METADATA_KEY] = b"meta"
        assert seq.database.get_key_range(ns) == (None, None)


# ---------------------------------------------------------------------------
# Compaction tests
# ---------------------------------------------------------------------------


class TestCompactTiers:
    """Tests for db_compact_tiers() and the tier hook."""

    def test_default_tiers_returns_two_entries(self, seq):
        tiers = seq.db_compact_tiers()
        assert len(tiers) == 2

    def test_default_tiers_ordered_shortest_first(self, seq):
        tiers = seq.db_compact_tiers()
        ages = [t[0].total_seconds() for t in tiers]
        assert ages == sorted(ages), "Tiers must be ordered shortest age first"

    def test_default_tiers_first_is_2h_to_15min(self, seq):
        tiers = seq.db_compact_tiers()
        age_sec, interval_sec = (
            tiers[0][0].total_seconds(),
            tiers[0][1].total_seconds(),
        )
        assert age_sec == 2 * 3600
        assert interval_sec == 15 * 60

    def test_default_tiers_second_is_2weeks_to_1h(self, seq):
        tiers = seq.db_compact_tiers()
        age_sec, interval_sec = (
            tiers[1][0].total_seconds(),
            tiers[1][1].total_seconds(),
        )
        assert age_sec == 14 * 24 * 3600
        assert interval_sec == 3600

    def test_override_tiers(self):
        class CustomSeq(SampleSequence):
            def db_compact_tiers(self):
                return [(to_duration("7 days"), to_duration("1 hour"))]

        s = CustomSeq()
        tiers = s.db_compact_tiers()
        assert len(tiers) == 1
        assert tiers[0][1].total_seconds() == 3600

    def test_empty_tiers_disables_compaction(self):
        class NoCompactSeq(SampleSequence):
            def db_compact_tiers(self):
                return []

        s = NoCompactSeq()
        now = to_datetime().in_timezone("UTC")
        base = now.subtract(weeks=4)
        _insert_records_every_n_minutes(s, base, count=100, interval_minutes=15)

        deleted = s.db_compact()
        assert deleted == 0


class TestCompactState:
    """Tests for _db_get_compact_state / _db_set_compact_state."""

    def test_get_state_returns_none_when_no_metadata(self, seq):
        interval = to_duration("1 hour")
        assert seq._db_get_compact_state(interval) is None

    def test_set_and_get_state_roundtrip(self, seq):
        interval = to_duration("1 hour")
        now = to_datetime().in_timezone("UTC")
        ts = DatabaseTimestamp.from_datetime(now)

        seq._db_set_compact_state(interval, ts)
        retrieved = seq._db_get_compact_state(interval)

        assert retrieved == ts

    def test_state_is_per_tier(self, seq):
        """Different tier intervals must not overwrite each other."""
        interval_15min = to_duration("15 minutes")
        interval_1h = to_duration("1 hour")

        now = to_datetime().in_timezone("UTC")
        ts_15 = DatabaseTimestamp.from_datetime(now)
        ts_1h = DatabaseTimestamp.from_datetime(now.subtract(days=1))

        seq._db_set_compact_state(interval_15min, ts_15)
        seq._db_set_compact_state(interval_1h, ts_1h)

        assert seq._db_get_compact_state(interval_15min) == ts_15
        assert seq._db_get_compact_state(interval_1h) == ts_1h

    def test_state_persists_in_metadata(self, seq):
        """State must survive a metadata reload."""
        interval = to_duration("1 hour")
        now = to_datetime().in_timezone("UTC")
        ts = DatabaseTimestamp.from_datetime(now)

        seq._db_set_compact_state(interval, ts)

        # Reload metadata from fake DB
        seq2 = SampleSequence()
        seq2.database = seq.database
        seq2._db_metadata = seq2._db_load_metadata()

        assert seq2._db_get_compact_state(interval) == ts


class TestCompactSparseGuard:
    """The inflation guard must skip compaction when records are already sparse."""

    def test_sparse_data_aligns_but_does_not_reduce_cardinality(self, seq_sparse):
        """Sparse data must be aligned to the target interval for all records that were modified."""
        seq, _ = seq_sparse

        interval = to_duration("15 minutes")
        interval_sec = int(interval.total_seconds())

        # Snapshot original timestamps
        before_epochs = {
            int(r.date_time.timestamp())
            for r in seq.records
        }

        seq._db_compact_tier(
            to_duration("30 minutes"),
            interval,
        )

        after_epochs = {
            int(r.date_time.timestamp())
            for r in seq.records
        }

        # Cardinality must not increase
        assert len(after_epochs) <= len(before_epochs)

        # Any timestamp that changed must now be aligned
        changed_epochs = after_epochs - before_epochs

        for epoch in changed_epochs:
            assert epoch % interval_sec == 0

    def test_sparse_guard_advances_cutoff(self, seq_sparse):
        """Even when skipped, the cutoff should be stored so next run skips the same window."""
        seq, _ = seq_sparse
        interval_1h = to_duration("1 hour")
        interval_15min = to_duration("15 minutes")

        seq.db_compact()

        # Both tiers should have stored a cutoff even though nothing was deleted
        assert seq._db_get_compact_state(interval_1h) is not None
        assert seq._db_get_compact_state(interval_15min) is not None

    def test_exactly_at_boundary_remains_stable(self, seq):
        now = to_datetime().in_timezone("UTC")
        interval = to_duration("1 hour")

        raw_base = now.subtract(hours=5).set(minute=0, second=0, microsecond=0)
        base = raw_base.subtract(seconds=int(raw_base.timestamp()) % 3600)

        for i in range(4):
            seq.db_insert_record(
                SampleRecord(
                    date_time=base.add(hours=i),
                    value=float(i),
                )
            )

        seq.db_insert_record(
            SampleRecord(date_time=now.subtract(seconds=1), value=0.0)
        )
        seq.db_save_records()

        before = [
            (int(r.date_time.timestamp()), r.value)
            for r in seq.records
        ]

        seq._db_compact_tier(
            to_duration("30 minutes"),
            interval,
        )

        after = [
            (int(r.date_time.timestamp()), r.value)
            for r in seq.records
        ]

        assert before == after


class TestCompactTierWorker:
    """Unit tests for _db_compact_tier directly."""

    def test_empty_sequence_returns_zero(self, seq):
        age = to_duration("2 hours")
        interval = to_duration("15 minutes")
        assert seq._db_compact_tier(age, interval) == 0

    def test_all_records_too_recent_skipped(self):
        """Records within the age threshold must not be touched."""
        seq = SampleSequence()
        now = to_datetime().in_timezone("UTC")
        # Insert 10 records from 30 minutes ago — all within 2h threshold
        base = now.subtract(minutes=30)
        _insert_records_every_n_minutes(seq, base, count=10, interval_minutes=1)

        before = seq.db_count_records()
        deleted = seq._db_compact_tier(to_duration("2 hours"), to_duration("15 minutes"))

        assert deleted == 0
        assert seq.db_count_records() == before

    def test_compaction_reduces_record_count(self):
        """Dense 1-min records older than 2h should be downsampled to 15-min."""
        seq = SampleSequence()
        now = to_datetime().in_timezone("UTC")
        # Insert 1-min records for 6 hours ending 3 hours ago
        base = now.subtract(hours=9)
        _insert_records_every_n_minutes(seq, base, count=6 * 60, interval_minutes=1)

        before = seq.db_count_records()
        deleted = seq._db_compact_tier(to_duration("2 hours"), to_duration("15 minutes"))

        after = seq.db_count_records()
        assert deleted > 0
        assert after < before

    def test_records_within_threshold_preserved(self):
        """Records newer than age_threshold must remain untouched after compaction."""
        seq = SampleSequence()
        now = to_datetime().in_timezone("UTC")

        # Old dense records (will be compacted)
        old_base = now.subtract(hours=6)
        _insert_records_every_n_minutes(seq, old_base, count=4 * 60, interval_minutes=1)

        # Recent records (must not be touched) — insert 5 records in the last hour
        recent_base = now.subtract(minutes=50)
        _insert_records_every_n_minutes(seq, recent_base, count=5, interval_minutes=10)

        recent_before = [
            r for r in seq.records
            if r.date_time and r.date_time >= recent_base
        ]

        seq._db_compact_tier(to_duration("2 hours"), to_duration("15 minutes"))

        recent_after = [
            r for r in seq.records
            if r.date_time and r.date_time >= recent_base
        ]
        assert len(recent_after) == len(recent_before)

    def test_incremental_cutoff_prevents_recompaction(self):
        """Running compaction twice must not re-compact already-compacted data."""
        seq = SampleSequence()
        now = to_datetime().in_timezone("UTC")
        base = now.subtract(hours=8)
        _insert_records_every_n_minutes(seq, base, count=5 * 60, interval_minutes=1)

        age = to_duration("2 hours")
        interval = to_duration("15 minutes")

        deleted_first = seq._db_compact_tier(age, interval)
        count_after_first = seq.db_count_records()

        deleted_second = seq._db_compact_tier(age, interval)
        count_after_second = seq.db_count_records()

        assert deleted_first > 0
        assert deleted_second == 0, "Second run must be a no-op"
        assert count_after_first == count_after_second

    def test_cutoff_stored_after_compaction(self):
        """Cutoff timestamp must be persisted after a successful compaction run."""
        seq = SampleSequence()
        now = to_datetime().in_timezone("UTC")
        base = now.subtract(hours=8)
        _insert_records_every_n_minutes(seq, base, count=5 * 60, interval_minutes=1)

        interval = to_duration("15 minutes")
        seq._db_compact_tier(to_duration("2 hours"), interval)

        assert seq._db_get_compact_state(interval) is not None


class TestDbCompact:
    """Integration tests for the public db_compact() entry point."""

    def test_compact_dense_data_both_tiers(self, seq_with_15min_data):
        """4 weeks of 15-min data should be reduced by both tiers."""
        seq, _ = seq_with_15min_data
        before = seq.db_count_records()

        total_deleted = seq.db_compact()

        after = seq.db_count_records()
        assert total_deleted > 0
        assert after < before

    def test_compact_coarsest_tier_runs_first(self, seq_with_15min_data):
        """The 1-hour tier (coarsest) must run before the 15-min tier.

        If coarsest ran last it would re-compact records the 15-min tier
        had already downsampled — verified by checking that the 1-hour
        cutoff is not later than the 15-min cutoff.
        """
        seq, _ = seq_with_15min_data
        seq.db_compact()

        cutoff_1h = seq._db_get_compact_state(to_duration("1 hour"))
        cutoff_15min = seq._db_get_compact_state(to_duration("15 minutes"))

        assert cutoff_1h is not None
        assert cutoff_15min is not None
        # The 1h tier covers older data → its cutoff must be earlier than 15min tier
        assert cutoff_1h <= cutoff_15min

    def test_compact_idempotent(self, seq_with_15min_data):
        """Running db_compact twice must not change record count."""
        seq, _ = seq_with_15min_data
        seq.db_compact()
        after_first = seq.db_count_records()

        seq.db_compact()
        after_second = seq.db_count_records()

        assert after_first == after_second

    def test_compact_empty_sequence_returns_zero(self, seq):
        assert seq.db_compact() == 0

    def test_compact_with_override_tiers(self):
        """Passing compact_tiers directly must override db_compact_tiers()."""
        seq = SampleSequence()
        now = to_datetime().in_timezone("UTC")
        base = now.subtract(weeks=3)
        _insert_records_every_n_minutes(seq, base, count=3 * 7 * 24 * 4, interval_minutes=15)

        before = seq.db_count_records()
        deleted = seq.db_compact(
            compact_tiers=[(to_duration("1 day"), to_duration("1 hour"))]
        )

        assert deleted > 0
        assert seq.db_count_records() < before

    def test_compact_only_processes_new_window_on_second_call(self):
        """Second call processes only the new window, not the full history."""
        seq = SampleSequence()
        now = to_datetime().in_timezone("UTC")
        base = now.subtract(weeks=3)
        # Dense 1-min data for 3 weeks
        _insert_records_every_n_minutes(seq, base, count=3 * 7 * 24 * 60, interval_minutes=1)

        seq.db_compact()
        count_after_first = seq.db_count_records()

        # Add one more day of dense data in the past (simulate new old data arriving)
        extra_base = now.subtract(weeks=3).subtract(days=1)
        _insert_records_every_n_minutes(seq, extra_base, count=24 * 60, interval_minutes=1)

        seq.db_compact()
        count_after_second = seq.db_count_records()

        # Second compact should have processed the newly added old data
        # Record count may change but should not exceed first compacted count by much
        assert count_after_second >= 0  # basic sanity


class TestCompactDataIntegrity:
    """Verify value integrity is preserved after compaction."""

    def test_constant_value_preserved(self):
        """Constant value field must survive mean-resampling unchanged."""
        seq = SampleSequence()
        now = to_datetime().in_timezone("UTC")
        base = now.subtract(hours=6)

        # All values = 42.0
        _insert_records_every_n_minutes(
            seq, base, count=6 * 60, interval_minutes=1, value_fn=lambda _: 42.0
        )

        seq._db_compact_tier(to_duration("2 hours"), to_duration("15 minutes"))

        for record in seq.records:
            if record.date_time and record.date_time < now.subtract(hours=2):
                assert record.value == pytest.approx(42.0, abs=1e-6)

    def test_recent_records_not_modified(self):
        """Records newer than the age threshold must have unchanged values."""
        seq = SampleSequence()
        now = to_datetime().in_timezone("UTC")

        old_base = now.subtract(hours=6)
        _insert_records_every_n_minutes(seq, old_base, count=3 * 60, interval_minutes=1)

        # Known recent values
        recent_base = now.subtract(minutes=30)
        expected = {i * 10: float(100 + i) for i in range(3)}
        for offset, val in expected.items():
            dt = recent_base.add(minutes=offset)
            seq.db_insert_record(SampleRecord(date_time=dt, value=val))
        seq.db_save_records()

        seq._db_compact_tier(to_duration("2 hours"), to_duration("15 minutes"))

        for record in seq.records:
            if record.date_time and record.date_time >= recent_base:
                offset = int((record.date_time - recent_base).total_seconds() / 60)
                if offset in expected:
                    assert record.value == pytest.approx(expected[offset], abs=1e-6)

    def test_compacted_timestamps_spacing(self):
        """Resampled records must be fewer than original and span the compaction window.

        Exact per-bucket spacing depends on the full DataSequence.key_to_array
        implementation (pandas resampling). The stub key_to_array in SampleSequence
        only guarantees a reduction in count — uniform spacing is verified in
        test_dataabc_compact.py against the real implementation.
        """
        seq = SampleSequence()
        now = to_datetime().in_timezone("UTC")
        base = now.subtract(hours=6)
        _insert_records_every_n_minutes(seq, base, count=5 * 60, interval_minutes=1)

        before = seq.db_count_records()
        seq._db_compact_tier(to_duration("2 hours"), to_duration("15 minutes"))

        cutoff = now.subtract(hours=2)
        compacted = sorted(
            [r for r in seq.records if r.date_time and r.date_time < cutoff],
            key=lambda r: cast(DateTime, r.date_time),
        )

        # Must have produced fewer records than the original 1-min data
        assert len(compacted) > 0, "Expected at least one compacted record"
        assert len(compacted) < before, "Compaction must reduce record count"

        # Window start is floored to interval boundary
        interval_sec = 15 * 60
        expected_window_start = DateTime.fromtimestamp(
            (int(base.timestamp()) // interval_sec) * interval_sec,
            tz="UTC",
        )
        assert compacted[0].date_time >= expected_window_start

        # Last compacted record must be before the cutoff
        assert compacted[-1].date_time < cutoff
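
Taken together, these tests pin down the compaction policy: `db_compact_tiers()` returns `(age, interval)` duration pairs ordered shortest age first, and `db_compact()` runs the coarsest tier first while persisting a per-tier cutoff. A sketch of that behaviour, using method names from the diff; the loop body is a simplified reading of what the tests require, not the actual implementation:

```python
from akkudoktoreos.utils.datetimeutil import to_duration


def db_compact_tiers(self):
    # Default tiers asserted by TestCompactTiers, ordered shortest age first:
    # records older than 2 hours are resampled to 15-minute spacing, records
    # older than 2 weeks to 1-hour spacing.
    return [
        (to_duration("2 hours"), to_duration("15 minutes")),
        (to_duration("2 weeks"), to_duration("1 hour")),
    ]


def db_compact(self, compact_tiers=None) -> int:
    tiers = compact_tiers if compact_tiers is not None else self.db_compact_tiers()
    deleted = 0
    # Coarsest tier first (see test_compact_coarsest_tier_runs_first), so the
    # finer tier never re-compacts records the coarser tier just produced.
    for age, interval in sorted(tiers, key=lambda t: t[0].total_seconds(), reverse=True):
        deleted += self._db_compact_tier(age, interval)
    return deleted
```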

@@ -1460,6 +1460,17 @@ class TestTimeWindowSequence:
            # - without local timezone as UTC
+            (
+                "TC014",
+                "UTC",
+                "2024-01-03",
+                None,
+                "UTC",
+                None,
+                False,
+                pendulum.datetime(2024, 1, 3, 0, 0, 0, tz="UTC"),
+                False,
+            ),
            (
                "TC015",
                "Atlantic/Canary",
                "02/02/24",
                None,
@@ -1470,7 +1481,7 @@ class TestTimeWindowSequence:
                False,
            ),
            (
-                "TC015",
+                "TC016",
                "Atlantic/Canary",
                "2024-03-03T10:20:30.000Z",  # No daylight saving time at this date
                None,
@@ -1484,7 +1495,7 @@ class TestTimeWindowSequence:
            # from pendulum.datetime to pendulum.datetime object
            # ---------------------------------------
            (
-                "TC016",
+                "TC017",
                "Atlantic/Canary",
                pendulum.datetime(2024, 4, 4, 0, 0, 0),
                None,
@@ -1495,7 +1506,7 @@ class TestTimeWindowSequence:
                False,
            ),
            (
-                "TC017",
+                "TC018",
                "Atlantic/Canary",
                pendulum.datetime(2024, 4, 4, 1, 0, 0),
                None,
@@ -1506,7 +1517,7 @@ class TestTimeWindowSequence:
                False,
            ),
            (
-                "TC018",
+                "TC019",
                "Atlantic/Canary",
                pendulum.datetime(2024, 4, 4, 1, 0, 0, tz="Etc/UTC"),
                None,
@@ -1517,7 +1528,7 @@ class TestTimeWindowSequence:
                False,
            ),
            (
-                "TC019",
+                "TC020",
                "Atlantic/Canary",
                pendulum.datetime(2024, 4, 4, 2, 0, 0, tz="Europe/Berlin"),
                None,
@@ -1533,7 +1544,7 @@ class TestTimeWindowSequence:
            # - no timezone
            # local timezone UTC
            (
-                "TC020",
+                "TC021",
                "Etc/UTC",
                "2023-11-06T00:00:00",
                "UTC",
@@ -1545,7 +1556,7 @@ class TestTimeWindowSequence:
            ),
            # local timezone "Europe/Berlin"
            (
-                "TC021",
+                "TC022",
                "Europe/Berlin",
                "2023-11-06T00:00:00",
                "UTC",
@@ -1557,7 +1568,7 @@ class TestTimeWindowSequence:
            ),
            # - no microseconds
            (
-                "TC022",
+                "TC023",
                "Atlantic/Canary",
                "2024-10-30T00:00:00+01:00",
                "UTC",
@@ -1568,7 +1579,7 @@ class TestTimeWindowSequence:
                False,
            ),
            (
-                "TC023",
+                "TC024",
                "Atlantic/Canary",
                "2024-10-30T01:00:00+01:00",
                "utc",
@@ -1580,7 +1591,7 @@ class TestTimeWindowSequence:
            ),
            # - with microseconds
            (
-                "TC024",
+                "TC025",
                "Atlantic/Canary",
                "2024-10-07T10:20:30.000+02:00",
                "UTC",
@@ -1596,7 +1607,7 @@ class TestTimeWindowSequence:
            # - no timezone
            # local timezone
            (
-                "TC025",
+                "TC026",
                None,
                None,
                None,
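
Read as data, each row pairs an input for the datetime conversion with the pendulum value it should produce. A standalone illustration of the inserted TC014 case, assuming datetimeutil's to_datetime(value, in_timezone=...) signature, which the import tests later in this diff also use:

    import pendulum

    from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime

    # Date-only input interpreted in UTC yields midnight UTC of that day (TC014).
    result = to_datetime("2024-01-03", in_timezone="UTC")
    assert compare_datetimes(result, pendulum.datetime(2024, 1, 3, 0, 0, 0, tz="UTC")).equal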

@@ -14,8 +14,10 @@ DIR_DOCS_GENERATED = DIR_PROJECT_ROOT / "docs" / "_generated"
DIR_TEST_GENERATED = DIR_TESTDATA / "docs" / "_generated"


-def test_openapi_spec_current(config_eos):
+def test_openapi_spec_current(config_eos, set_other_timezone):
    """Verify the openapi spec hasn't changed."""
+    set_other_timezone("UTC")  # CI runs on UTC

    expected_spec_path = DIR_PROJECT_ROOT / "openapi.json"
    new_spec_path = DIR_TESTDATA / "openapi-new.json"

@@ -23,7 +25,7 @@ def test_openapi_spec_current(config_eos):
        expected_spec = json.load(f_expected)

    # Patch get_config and import within guard to patch global variables within the eos module.
-    with patch("akkudoktoreos.config.config.get_config", return_value=config_eos):
+    with patch("akkudoktoreos.core.coreabc.get_config", return_value=config_eos):
        # Ensure the script works correctly as part of a package
        root_dir = Path(__file__).resolve().parent.parent
        sys.path.insert(0, str(root_dir))
@@ -39,7 +41,7 @@ def test_openapi_spec_current(config_eos):
    expected_spec_str = json.dumps(expected_spec, indent=4, sort_keys=True)

    try:
-        assert spec_str == expected_spec_str
+        assert json.loads(spec_str) == json.loads(expected_spec_str)
    except AssertionError as e:
        pytest.fail(
            f"Expected {new_spec_path} to equal {expected_spec_path}.\n"
@@ -47,8 +49,10 @@ def test_openapi_spec_current(config_eos):
        )


-def test_openapi_md_current(config_eos):
+def test_openapi_md_current(config_eos, set_other_timezone):
    """Verify the generated openapi markdown hasn't changed."""
+    set_other_timezone("UTC")  # CI runs on UTC

    expected_spec_md_path = DIR_PROJECT_ROOT / "docs" / "_generated" / "openapi.md"
    new_spec_md_path = DIR_TESTDATA / "openapi-new.md"

@@ -56,7 +60,7 @@ def test_openapi_md_current(config_eos):
        expected_spec_md = f_expected.read()

    # Patch get_config and import within guard to patch global variables within the eos module.
-    with patch("akkudoktoreos.config.config.get_config", return_value=config_eos):
+    with patch("akkudoktoreos.core.coreabc.get_config", return_value=config_eos):
        # Ensure the script works correctly as part of a package
        root_dir = Path(__file__).resolve().parent.parent
        sys.path.insert(0, str(root_dir))
@@ -76,8 +80,10 @@ def test_openapi_md_current(config_eos):
        )


-def test_config_md_current(config_eos):
+def test_config_md_current(config_eos, set_other_timezone):
    """Verify the generated configuration markdown hasn't changed."""
+    set_other_timezone("UTC")  # CI runs on UTC

    assert DIR_DOCS_GENERATED.exists()

    # Remove any leftover files from last run
@@ -88,7 +94,7 @@ def test_config_md_current(config_eos):
    DIR_TEST_GENERATED.mkdir(parents=True, exist_ok=True)

    # Patch get_config and import within guard to patch global variables within the eos module.
-    with patch("akkudoktoreos.config.config.get_config", return_value=config_eos):
+    with patch("akkudoktoreos.core.coreabc.get_config", return_value=config_eos):
        # Ensure the script works correctly as part of a package
        root_dir = Path(__file__).resolve().parent.parent
        sys.path.insert(0, str(root_dir))
@@ -106,7 +112,11 @@ def test_config_md_current(config_eos):
        tested.append(DIR_TEST_GENERATED / file_name)

    # Create test files
-    config_md = generate_config_md.generate_config_md(tested[0], config_eos)
+    try:
+        config_eos._force_documentation_mode = True
+        config_md = generate_config_md.generate_config_md(tested[0], config_eos)
+    finally:
+        config_eos._force_documentation_mode = False

    # Check test files are the same as the expected files
    for i, expected_path in enumerate(expected):
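
The try/finally wrapper above is the usual pattern for temporarily flipping a singleton flag so that an exception cannot leave it set. A minimal sketch of the same idea in isolation (the FakeConfig class and render_docs function are made up for illustration):

    class FakeConfig:
        """Stand-in for a singleton whose flag must never leak out of a test."""
        _force_documentation_mode = False

    def render_docs(config: FakeConfig) -> str:
        return "docs (documentation mode)" if config._force_documentation_mode else "docs"

    config = FakeConfig()
    try:
        config._force_documentation_mode = True   # enable doc mode for this call only
        output = render_docs(config)
    finally:
        config._force_documentation_mode = False  # restored even if render_docs raises

    assert config._force_documentation_mode is False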

@@ -9,6 +9,8 @@ from typing import Optional

import pytest

+from akkudoktoreos.core.coreabc import singletons_init
+
DIR_PROJECT_ROOT = Path(__file__).absolute().parent.parent
DIR_BUILD = DIR_PROJECT_ROOT / "build"
DIR_BUILD_DOCS = DIR_PROJECT_ROOT / "build" / "docs"
@@ -80,6 +82,7 @@ class TestSphinxDocumentation:

    def test_sphinx_build(self, sphinx_changed: Optional[str], is_finalize: bool):
        """Build Sphinx documentation and ensure no major warnings appear in the build output."""

        # Ensure docs folder exists
        if not DIR_DOCS.exists():
            pytest.skip(f"Skipping Sphinx build test - docs folder not present: {DIR_DOCS}")
@@ -88,7 +91,7 @@ class TestSphinxDocumentation:
            pytest.skip(f"Skipping Sphinx build — no relevant file changes detected: {HASH_FILE}")

        if not is_finalize:
-            pytest.skip("Skipping Sphinx test — not full run")
+            pytest.skip("Skipping Sphinx test — not finalize")

        # Clean directories
        self._cleanup_autosum_dirs()
@@ -123,7 +126,11 @@ class TestSphinxDocumentation:
        # Remove temporary EOS_DIR
        eos_tmp_dir.cleanup()

-        assert returncode == 0
+        if returncode != 0:
+            pytest.fail(
+                f"Sphinx build failed with exit code {returncode}.\n"
+                f"{output}\n"
+            )

        # Possible markers: ERROR: WARNING: TRACEBACK:
        major_markers = ("ERROR:", "TRACEBACK:")
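
Replacing a bare `assert returncode == 0` with an explicit pytest.fail carries the captured build log into the test report instead of a terse comparison. The same shape in isolation (the subprocess invocation and helper name are illustrative only):

    import subprocess

    import pytest

    def check_build(cmd: list[str]) -> None:
        """Run a build command and fail with the full output on a non-zero exit."""
        proc = subprocess.run(cmd, capture_output=True, text=True)
        output = proc.stdout + proc.stderr
        if proc.returncode != 0:
            # pytest.fail includes the whole log in the failure message,
            # unlike `assert proc.returncode == 0`.
            pytest.fail(f"Build failed with exit code {proc.returncode}.\n{output}\n")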

@@ -8,7 +8,7 @@ import requests
from loguru import logger

from akkudoktoreos.core.cache import CacheFileStore
-from akkudoktoreos.core.ems import get_ems
+from akkudoktoreos.core.coreabc import get_ems
from akkudoktoreos.prediction.elecpriceakkudoktor import (
    AkkudoktorElecPrice,
    AkkudoktorElecPriceValue,

@@ -8,7 +8,7 @@ import requests
from loguru import logger

from akkudoktoreos.core.cache import CacheFileStore
-from akkudoktoreos.core.ems import get_ems
+from akkudoktoreos.core.coreabc import get_ems
from akkudoktoreos.prediction.elecpriceakkudoktor import (
    AkkudoktorElecPrice,
    AkkudoktorElecPriceValue,

@@ -1,11 +1,12 @@
import json
from pathlib import Path

+import numpy.testing as npt
import pytest

-from akkudoktoreos.core.ems import get_ems
+from akkudoktoreos.core.coreabc import get_ems
from akkudoktoreos.prediction.elecpriceimport import ElecPriceImport
-from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime
+from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime, to_duration

DIR_TESTDATA = Path(__file__).absolute().parent.joinpath("testdata")
@@ -83,6 +84,7 @@ def test_invalid_provider(provider, config_eos):
)
def test_import(provider, sample_import_1_json, start_datetime, from_file, config_eos):
    """Test fetching forecast from Import."""
+    key = "elecprice_marketprice_wh"
    ems_eos = get_ems()
    ems_eos.set_start_datetime(to_datetime(start_datetime, in_timezone="Europe/Berlin"))
    if from_file:
@@ -91,7 +93,7 @@ def test_import(provider, sample_import_1_json, start_datetime, from_file, confi
    else:
        config_eos.elecprice.elecpriceimport.import_file_path = None
        assert config_eos.elecprice.elecpriceimport.import_file_path is None
-    provider.clear()
+    provider.delete_by_datetime(start_datetime=None, end_datetime=None)

    # Call the method
    provider.update_data()
@@ -100,16 +102,13 @@ def test_import(provider, sample_import_1_json, start_datetime, from_file, confi
    assert provider.ems_start_datetime is not None
    assert provider.total_hours is not None
    assert compare_datetimes(provider.ems_start_datetime, ems_eos.start_datetime).equal
-    values = sample_import_1_json["elecprice_marketprice_wh"]
-    value_datetime_mapping = provider.import_datetimes(ems_eos.start_datetime, len(values))
-    for i, mapping in enumerate(value_datetime_mapping):
-        assert i < len(provider.records)
-        expected_datetime, expected_value_index = mapping
-        expected_value = values[expected_value_index]
-        result_datetime = provider.records[i].date_time
-        result_value = provider.records[i]["elecprice_marketprice_wh"]
-
-        # print(f"{i}: Expected: {expected_datetime}:{expected_value}")
-        # print(f"{i}: Result: {result_datetime}:{result_value}")
-        assert compare_datetimes(result_datetime, expected_datetime).equal
-        assert result_value == expected_value
+    expected_values = sample_import_1_json[key]
+    result_values = provider.key_to_array(
+        key=key,
+        start_datetime=provider.ems_start_datetime,
+        end_datetime=provider.ems_start_datetime + to_duration(f"{len(expected_values)} hours"),
+        interval=to_duration("1 hour"),
+    )
+    # Allow for some difference due to value calculation on DST change
+    npt.assert_allclose(result_values, expected_values, rtol=0.001)
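
For reference, assert_allclose with rtol=0.001 tolerates a relative error of 0.1 percent per element, which is what absorbs the interpolation wobble on DST-change days. A self-contained illustration:

    import numpy as np
    import numpy.testing as npt

    expected = np.array([100.0, 200.0, 300.0])
    result = np.array([100.05, 199.9, 300.2])  # each within 0.1 % of expected

    # Passes: |result - expected| <= rtol * |expected| element-wise.
    npt.assert_allclose(result, expected, rtol=0.001)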

@@ -3,7 +3,7 @@ from pathlib import Path

import pytest

-from akkudoktoreos.core.ems import get_ems
+from akkudoktoreos.core.coreabc import get_ems
from akkudoktoreos.prediction.feedintarifffixed import FeedInTariffFixed
from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime

@@ -7,7 +7,7 @@ import pytest

from akkudoktoreos.config.config import ConfigEOS
from akkudoktoreos.core.cache import CacheEnergyManagementStore
-from akkudoktoreos.core.ems import get_ems
+from akkudoktoreos.core.coreabc import get_ems
from akkudoktoreos.optimization.genetic.genetic import GeneticOptimization
from akkudoktoreos.optimization.genetic.geneticparams import (
    GeneticOptimizationParameters,
@@ -18,7 +18,7 @@ from akkudoktoreos.utils.visualize import (
    prepare_visualize,  # Import the new prepare_visualize
)

-ems_eos = get_ems()
+ems_eos = get_ems(init=True)  # init once

DIR_TESTDATA = Path(__file__).parent / "testdata"

@@ -4,8 +4,8 @@ import numpy as np
import pendulum
import pytest

-from akkudoktoreos.core.ems import get_ems
-from akkudoktoreos.measurement.measurement import MeasurementDataRecord, get_measurement
+from akkudoktoreos.core.coreabc import get_ems, get_measurement
+from akkudoktoreos.measurement.measurement import MeasurementDataRecord
from akkudoktoreos.prediction.loadakkudoktor import (
    LoadAkkudoktor,
    LoadAkkudoktorAdjusted,
@@ -63,7 +63,7 @@ def measurement_eos():
    dt = to_datetime("2024-01-01T00:00:00")
    interval = to_duration("1 hour")
    for i in range(25):
-        measurement.records.append(
+        measurement.insert_by_datetime(
            MeasurementDataRecord(
                date_time=dt,
                load0_mr=load0_mr,
@@ -138,7 +138,7 @@ def test_update_data(mock_load_data, loadakkudoktor):
    ems_eos.set_start_datetime(pendulum.datetime(2024, 1, 1))

    # Assure there are no prediction records
-    loadakkudoktor.clear()
+    loadakkudoktor.delete_by_datetime(start_datetime=None, end_datetime=None)
    assert len(loadakkudoktor) == 0

    # Execute the method
@@ -152,6 +152,24 @@ def test_calculate_adjustment(loadakkudoktoradjusted, measurement_eos):
    """Test `_calculate_adjustment` for various scenarios."""
    data_year_energy = np.random.rand(365, 2, 24)

+    # Check the test setup
+    assert loadakkudoktoradjusted.measurement is measurement_eos
+    assert measurement_eos.min_datetime == to_datetime("2024-01-01T00:00:00")
+    assert measurement_eos.max_datetime == to_datetime("2024-01-02T00:00:00")
+    # Use same calculation as in _calculate_adjustment
+    compare_start = measurement_eos.max_datetime - to_duration("7 days")
+    if compare_datetimes(compare_start, measurement_eos.min_datetime).lt:
+        # Not enough measurements for 7 days - use what is available
+        compare_start = measurement_eos.min_datetime
+    compare_end = measurement_eos.max_datetime
+    compare_interval = to_duration("1 hour")
+    load_total_kwh_array = measurement_eos.load_total_kwh(
+        start_datetime=compare_start,
+        end_datetime=compare_end,
+        interval=compare_interval,
+    )
+    np.testing.assert_allclose(load_total_kwh_array, [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])

    # Call the method and validate results
    weekday_adjust, weekend_adjust = loadakkudoktoradjusted._calculate_adjustment(data_year_energy)
    assert weekday_adjust.shape == (24,)
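
The compare-window arithmetic duplicated in this test is simple clamping: look back seven days from the newest measurement, but never earlier than the oldest one. A standalone sketch with pendulum (the function name is illustrative):

    import pendulum

    def compare_window(
        min_dt: pendulum.DateTime, max_dt: pendulum.DateTime
    ) -> tuple[pendulum.DateTime, pendulum.DateTime]:
        """Window of up to 7 days ending at the newest measurement, clamped to available data."""
        start = max_dt.subtract(days=7)
        if start < min_dt:
            # Not enough measurements for 7 days - use what is available
            start = min_dt
        return start, max_dt

    start, end = compare_window(
        pendulum.datetime(2024, 1, 1, tz="UTC"),
        pendulum.datetime(2024, 1, 2, tz="UTC"),
    )
    assert start == pendulum.datetime(2024, 1, 1, tz="UTC")  # clamped: only 1 day available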

@@ -3,10 +3,17 @@ import pytest
from pendulum import datetime, duration

from akkudoktoreos.config.config import SettingsEOS
+from akkudoktoreos.core.coreabc import get_measurement
from akkudoktoreos.measurement.measurement import (
    MeasurementCommonSettings,
    MeasurementDataRecord,
-    get_measurement,
)
from akkudoktoreos.utils.datetimeutil import (
    DateTime,
    Duration,
    compare_datetimes,
    to_datetime,
    to_duration,
)


@@ -41,8 +48,9 @@ class TestMeasurementDataRecord:

    def test_getitem_existing_field(self, record):
        """Test that __getitem__ returns correct value for existing native field."""
-        record.date_time = "2024-01-01T00:00:00+00:00"
-        assert record["date_time"] is not None
+        date_time = "2024-01-01T00:00:00+00:00"
+        record.date_time = date_time
+        assert compare_datetimes(record["date_time"], to_datetime(date_time)).equal

    def test_getitem_existing_measurement(self, record):
        """Test that __getitem__ retrieves existing measurement values."""
@@ -220,6 +228,7 @@ class TestMeasurement:
        # Load meter readings are in kWh
        config_eos.measurement.load_emr_keys = ["load0_mr", "load1_mr", "load2_mr", "load3_mr"]
        measurement = get_measurement()
+        measurement.delete_by_datetime(None, None)
        record0 = MeasurementDataRecord(
            date_time=datetime(2023, 1, 1, hour=0),
            load0_mr=100,
@@ -227,52 +236,54 @@ class TestMeasurement:
        )
        assert record0.load0_mr == 100
        assert record0.load1_mr == 200
-        measurement.records = [
+        records = [
            MeasurementDataRecord(
-                date_time=datetime(2023, 1, 1, hour=0),
+                date_time=to_datetime("2023-01-01T00:00:00"),
                load0_mr=100,
                load1_mr=200,
            ),
            MeasurementDataRecord(
-                date_time=datetime(2023, 1, 1, hour=1),
+                date_time=to_datetime("2023-01-01T01:00:00"),
                load0_mr=150,
                load1_mr=250,
            ),
            MeasurementDataRecord(
-                date_time=datetime(2023, 1, 1, hour=2),
+                date_time=to_datetime("2023-01-01T02:00:00"),
                load0_mr=200,
                load1_mr=300,
            ),
            MeasurementDataRecord(
-                date_time=datetime(2023, 1, 1, hour=3),
+                date_time=to_datetime("2023-01-01T03:00:00"),
                load0_mr=250,
                load1_mr=350,
            ),
            MeasurementDataRecord(
-                date_time=datetime(2023, 1, 1, hour=4),
+                date_time=to_datetime("2023-01-01T04:00:00"),
                load0_mr=300,
                load1_mr=400,
            ),
            MeasurementDataRecord(
-                date_time=datetime(2023, 1, 1, hour=5),
+                date_time=to_datetime("2023-01-01T05:00:00"),
                load0_mr=350,
                load1_mr=450,
            ),
        ]
+        for record in records:
+            measurement.insert_by_datetime(record)
        return measurement

    def test_interval_count(self, measurement_eos):
        """Test interval count calculation."""
-        start = datetime(2023, 1, 1, 0)
-        end = datetime(2023, 1, 1, 3)
+        start = to_datetime("2023-01-01T00:00:00")
+        end = to_datetime("2023-01-01T03:00:00")
        interval = duration(hours=1)

        assert measurement_eos._interval_count(start, end, interval) == 3

    def test_interval_count_invalid_end_before_start(self, measurement_eos):
        """Test interval count raises ValueError when end_datetime is before start_datetime."""
-        start = datetime(2023, 1, 1, 3)
-        end = datetime(2023, 1, 1, 0)
+        start = to_datetime("2023-01-01T03:00:00")
+        end = to_datetime("2023-01-01T00:00:00")
        interval = duration(hours=1)

        with pytest.raises(ValueError, match="end_datetime must be after start_datetime"):
@@ -280,8 +291,8 @@ class TestMeasurement:

    def test_interval_count_invalid_non_positive_interval(self, measurement_eos):
        """Test interval count raises ValueError when interval is non-positive."""
-        start = datetime(2023, 1, 1, 0)
-        end = datetime(2023, 1, 1, 3)
+        start = to_datetime("2023-01-01T00:00:00")
+        end = to_datetime("2023-01-01T03:00:00")

        with pytest.raises(ValueError, match="interval must be positive"):
            measurement_eos._interval_count(start, end, duration(hours=0))
@@ -289,8 +300,8 @@ class TestMeasurement:
    def test_energy_from_meter_readings_valid_input(self, measurement_eos):
        """Test _energy_from_meter_readings with valid inputs and proper alignment of load data."""
        key = "load0_mr"
-        start_datetime = datetime(2023, 1, 1, 0)
-        end_datetime = datetime(2023, 1, 1, 5)
+        start_datetime = to_datetime("2023-01-01T00:00:00")
+        end_datetime = to_datetime("2023-01-01T05:00:00")
        interval = duration(hours=1)

        load_array = measurement_eos._energy_from_meter_readings(
@@ -303,12 +314,12 @@ class TestMeasurement:
    def test_energy_from_meter_readings_empty_array(self, measurement_eos):
        """Test _energy_from_meter_readings with no data (empty array)."""
        key = "load0_mr"
-        start_datetime = datetime(2023, 1, 1, 0)
-        end_datetime = datetime(2023, 1, 1, 5)
+        start_datetime = to_datetime("2023-01-01T00:00:00")
+        end_datetime = to_datetime("2023-01-01T05:00:00")
        interval = duration(hours=1)

        # Use empty records array
-        measurement_eos.records = []
+        measurement_eos.delete_by_datetime(start_datetime, end_datetime)

        load_array = measurement_eos._energy_from_meter_readings(
            key, start_datetime, end_datetime, interval
@@ -324,25 +335,46 @@ class TestMeasurement:
    def test_energy_from_meter_readings_misaligned_array(self, measurement_eos):
        """Test _energy_from_meter_readings with misaligned array size."""
        key = "load1_mr"
-        start_datetime = measurement_eos.min_datetime
-        end_datetime = measurement_eos.max_datetime
        interval = duration(hours=1)
+        start_datetime = to_datetime("2023-01-01T00:00:00")
+        end_datetime = to_datetime("2023-01-01T05:00:00")

        # Use misaligned array, latest interval set to 2 hours (instead of 1 hour)
-        measurement_eos.records[-1].date_time = datetime(2023, 1, 1, 6)
+        latest_record_datetime = to_datetime("2023-01-01T05:00:00")
+        new_record_datetime = to_datetime("2023-01-01T06:00:00")
+        record = measurement_eos.get_by_datetime(latest_record_datetime)
+        assert record is not None
+        measurement_eos.delete_by_datetime(start_datetime=latest_record_datetime,
+                                           end_datetime=new_record_datetime)
+        record.date_time = new_record_datetime
+        measurement_eos.insert_by_datetime(record)

+        # Check test setup
+        dates, values = measurement_eos.key_to_lists(key, start_datetime, None)
+        assert dates == [
+            to_datetime("2023-01-01T00:00:00"),
+            to_datetime("2023-01-01T01:00:00"),
+            to_datetime("2023-01-01T02:00:00"),
+            to_datetime("2023-01-01T03:00:00"),
+            to_datetime("2023-01-01T04:00:00"),
+            to_datetime("2023-01-01T06:00:00"),
+        ]
+        assert values == [200, 250, 300, 350, 400, 450]
+        array = measurement_eos.key_to_array(key, start_datetime, end_datetime + interval, interval=interval)
+        np.testing.assert_array_equal(array, [200, 250, 300, 350, 400, 425])

        load_array = measurement_eos._energy_from_meter_readings(
            key, start_datetime, end_datetime, interval
        )

-        expected_load_array = np.array([50, 50, 50, 50, 25])  # Differences between consecutive readings
+        expected_load_array = np.array([50., 50., 50., 50., 25.])  # Differences between consecutive readings
        np.testing.assert_array_equal(load_array, expected_load_array)

    def test_energy_from_meter_readings_partial_data(self, measurement_eos, caplog):
        """Test _energy_from_meter_readings with partial data (misaligned but empty array)."""
        key = "load2_mr"
-        start_datetime = datetime(2023, 1, 1, 0)
-        end_datetime = datetime(2023, 1, 1, 5)
+        start_datetime = to_datetime("2023-01-01T00:00:00")
+        end_datetime = to_datetime("2023-01-01T05:00:00")
        interval = duration(hours=1)

        with caplog.at_level("DEBUG"):
@@ -359,8 +391,8 @@ class TestMeasurement:
    def test_energy_from_meter_readings_negative_interval(self, measurement_eos):
        """Test _energy_from_meter_readings with a negative interval."""
        key = "load3_mr"
-        start_datetime = datetime(2023, 1, 1, 0)
-        end_datetime = datetime(2023, 1, 1, 5)
+        start_datetime = to_datetime("2023-01-01T00:00:00")
+        end_datetime = to_datetime("2023-01-01T05:00:00")
        interval = duration(hours=-1)

        with pytest.raises(ValueError, match="interval must be positive"):
@@ -368,11 +400,11 @@ class TestMeasurement:

    def test_load_total_kwh(self, measurement_eos):
        """Test total load calculation."""
-        start = datetime(2023, 1, 1, 0)
-        end = datetime(2023, 1, 1, 2)
+        start_datetime = to_datetime("2023-01-01T03:00:00")
+        end_datetime = to_datetime("2023-01-01T05:00:00")
        interval = duration(hours=1)

-        result = measurement_eos.load_total_kwh(start_datetime=start, end_datetime=end, interval=interval)
+        result = measurement_eos.load_total_kwh(start_datetime=start_datetime, end_datetime=end_datetime, interval=interval)

        # Expected total load per interval
        expected = np.array([100, 100])  # Differences between consecutive meter readings
@@ -381,20 +413,20 @@ class TestMeasurement:
    def test_load_total_kwh_no_data(self, measurement_eos):
        """Test total load calculation with no data."""
        measurement_eos.records = []
-        start = datetime(2023, 1, 1, 0)
-        end = datetime(2023, 1, 1, 3)
+        start_datetime = to_datetime("2023-01-01T00:00:00")
+        end_datetime = to_datetime("2023-01-01T03:00:00")
        interval = duration(hours=1)

-        result = measurement_eos.load_total_kwh(start_datetime=start, end_datetime=end, interval=interval)
+        result = measurement_eos.load_total_kwh(start_datetime=start_datetime, end_datetime=end_datetime, interval=interval)
        expected = np.zeros(3)  # No data, so all intervals are zero
        np.testing.assert_array_equal(result, expected)

    def test_load_total_kwh_partial_intervals(self, measurement_eos):
        """Test total load calculation with partial intervals."""
-        start = datetime(2023, 1, 1, 0, 30)  # Start in the middle of an interval
-        end = datetime(2023, 1, 1, 1, 30)  # End in the middle of another interval
+        start_datetime = to_datetime("2023-01-01T00:30:00")  # Start in the middle of an interval
+        end_datetime = to_datetime("2023-01-01T01:30:00")  # End in the middle of another interval
        interval = duration(hours=1)

-        result = measurement_eos.load_total_kwh(start_datetime=start, end_datetime=end, interval=interval)
+        result = measurement_eos.load_total_kwh(start_datetime=start_datetime, end_datetime=end_datetime, interval=interval)
        expected = np.array([100])  # Only one complete interval covered
        np.testing.assert_array_equal(result, expected)
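
The expected values in the misaligned test follow from linear interpolation plus differencing: readings at hours 0-4 plus one reading moved to hour 6 are resampled onto an hourly grid (hour 5 interpolates 400..450 to 425), and per-interval energy is the difference of consecutive resampled readings. A numpy-only sketch of that arithmetic (np.interp stands in for whatever resampling the provider actually performs):

    import numpy as np

    hours = np.array([0, 1, 2, 3, 4, 6])          # last reading shifted to hour 6
    readings = np.array([200, 250, 300, 350, 400, 450])

    grid = np.arange(0, 6)                        # hourly grid, hours 0..5
    resampled = np.interp(grid, hours, readings)  # hour 5 -> 425 by interpolation
    np.testing.assert_array_equal(resampled, [200, 250, 300, 350, 400, 425])

    energy = np.diff(resampled)                   # energy per interval
    np.testing.assert_array_equal(energy, [50., 50., 50., 50., 25.])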

@@ -1,6 +1,7 @@
import pytest
from pydantic import ValidationError

+from akkudoktoreos.core.coreabc import get_prediction
from akkudoktoreos.prediction.elecpriceakkudoktor import ElecPriceAkkudoktor
from akkudoktoreos.prediction.elecpriceenergycharts import ElecPriceEnergyCharts
from akkudoktoreos.prediction.elecpriceimport import ElecPriceImport
@@ -15,7 +16,6 @@ from akkudoktoreos.prediction.loadvrm import LoadVrm
from akkudoktoreos.prediction.prediction import (
    Prediction,
    PredictionCommonSettings,
-    get_prediction,
)
from akkudoktoreos.prediction.pvforecastakkudoktor import PVForecastAkkudoktor
from akkudoktoreos.prediction.pvforecastimport import PVForecastImport

@@ -7,10 +7,10 @@ import pendulum
import pytest
from pydantic import Field

-from akkudoktoreos.core.ems import get_ems
+from akkudoktoreos.core.coreabc import get_ems
from akkudoktoreos.prediction.prediction import PredictionCommonSettings
from akkudoktoreos.prediction.predictionabc import (
-    PredictionBase,
+    PredictionABC,
    PredictionContainer,
    PredictionProvider,
    PredictionRecord,
@@ -28,7 +28,7 @@ class DerivedConfig(PredictionCommonSettings):
    class_constant: Optional[int] = Field(default=None, description="Test config by class constant")


-class DerivedBase(PredictionBase):
+class DerivedBase(PredictionABC):
    instance_field: Optional[str] = Field(default=None, description="Field Value")
    class_constant: ClassVar[int] = 30

@@ -84,7 +84,7 @@ class DerivedPredictionContainer(PredictionContainer):
# ----------


-class TestPredictionBase:
+class TestPredictionABC:
    @pytest.fixture
    def base(self, monkeypatch):
        # Provide default values for configuration
@@ -216,17 +216,19 @@ class TestPredictionProvider:
    def test_delete_by_datetime(self, provider, sample_start_datetime):
        """Test `delete_by_datetime` method for removing records by datetime range."""
        # Add records to the provider for deletion testing
-        provider.records = [
+        records = [
            self.create_test_record(sample_start_datetime - to_duration("3 hours"), 1),
            self.create_test_record(sample_start_datetime - to_duration("1 hour"), 2),
            self.create_test_record(sample_start_datetime + to_duration("1 hour"), 3),
        ]
+        for record in records:
+            provider.insert_by_datetime(record)

        provider.delete_by_datetime(
            start_datetime=sample_start_datetime - to_duration("2 hours"),
            end_datetime=sample_start_datetime + to_duration("2 hours"),
        )
-        assert len(provider.records) == 1, (
+        assert len(provider) == 1, (
            "Only one record should remain after deletion by datetime."
        )
        assert provider.records[0].date_time == sample_start_datetime - to_duration("3 hours"), (
@@ -243,15 +245,17 @@ class TestPredictionContainer:

    @pytest.fixture
    def container_with_providers(self):
-        record1 = self.create_test_record(datetime(2023, 11, 5), 1)
-        record2 = self.create_test_record(datetime(2023, 11, 6), 2)
-        record3 = self.create_test_record(datetime(2023, 11, 7), 3)
+        records = [
+            # Test records - include 'prediction_value' key
+            self.create_test_record(datetime(2023, 11, 5), 1),
+            self.create_test_record(datetime(2023, 11, 6), 2),
+            self.create_test_record(datetime(2023, 11, 7), 3),
+        ]
        provider = DerivedPredictionProvider()
-        provider.clear()
+        provider.delete_by_datetime(start_datetime=None, end_datetime=None)
        assert len(provider) == 0
-        provider.append(record1)
-        provider.append(record2)
-        provider.append(record3)
+        for record in records:
+            provider.insert_by_datetime(record)
        assert len(provider) == 3
        container = DerivedPredictionContainer()
        container.providers.clear()
@@ -378,7 +382,9 @@ class TestPredictionContainer:
        assert len(container_with_providers.providers) == 1
        # check all keys are available (don't care for position)
        for key in ["prediction_value", "date_time"]:
-            assert key in list(container_with_providers.keys())
+            assert key in container_with_providers.record_keys
+        for key in ["prediction_value", "date_time"]:
+            assert key in container_with_providers.keys()
        series = container_with_providers["prediction_value"]
        assert isinstance(series, pd.Series)
        assert series.name == "prediction_value"
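
A container behaving like a mapping of pandas Series, which is what the assertions above rely on, can be pictured with plain pandas. This sketch only mirrors the tested interface, not the real PredictionContainer:

    import pandas as pd

    data = {
        "date_time": pd.to_datetime(["2023-11-05", "2023-11-06", "2023-11-07"]),
        "prediction_value": [1.0, 2.0, 3.0],
    }
    frame = pd.DataFrame(data)

    # Key lookup yields a named Series, just like container["prediction_value"].
    series = frame["prediction_value"]
    assert isinstance(series, pd.Series)
    assert series.name == "prediction_value"
    assert "prediction_value" in frame.keys()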

@@ -5,8 +5,7 @@ from unittest.mock import Mock, patch
import pytest
from loguru import logger

-from akkudoktoreos.core.ems import get_ems
-from akkudoktoreos.prediction.prediction import get_prediction
+from akkudoktoreos.core.coreabc import get_ems, get_prediction
from akkudoktoreos.prediction.pvforecastakkudoktor import (
    AkkudoktorForecastHorizon,
    AkkudoktorForecastMeta,
@@ -137,7 +136,7 @@ def provider():
def provider_empty_instance():
    """Fixture that returns an empty instance of PVForecast."""
    empty_instance = PVForecastAkkudoktor()
-    empty_instance.clear()
+    empty_instance.delete_by_datetime(start_datetime=None, end_datetime=None)
    assert len(empty_instance) == 0
    return empty_instance

@@ -277,7 +276,7 @@ def test_pvforecast_akkudoktor_update_with_sample_forecast(
    ems_eos.set_start_datetime(sample_forecast_start)
    provider.update_data(force_enable=True, force_update=True)
    assert compare_datetimes(provider.ems_start_datetime, sample_forecast_start).equal
-    assert compare_datetimes(provider[0].date_time, to_datetime(sample_forecast_start)).equal
+    assert compare_datetimes(provider.records[0].date_time, to_datetime(sample_forecast_start)).equal


# Report Generation Test
@@ -290,7 +289,7 @@ def test_report_ac_power_and_measurement(provider, config_eos):
        pvforecast_dc_power=450.0,
        pvforecast_ac_power=400.0,
    )
-    provider.append(record)
+    provider.insert_by_datetime(record)

    report = provider.report_ac_power_and_measurement()
    assert "DC: 450.0" in report
@@ -323,19 +322,19 @@ def test_timezone_behaviour(
    expected_datetime = to_datetime("2024-10-06T00:00:00+0200", in_timezone=other_timezone)
    assert compare_datetimes(other_start_datetime, expected_datetime).equal

-    provider.clear()
+    provider.delete_by_datetime(start_datetime=None, end_datetime=None)
    assert len(provider) == 0
    ems_eos = get_ems()
    ems_eos.set_start_datetime(other_start_datetime)
    provider.update_data(force_update=True)
    assert compare_datetimes(provider.ems_start_datetime, other_start_datetime).equal
    # Check whether first record starts at requested sample start time
-    assert compare_datetimes(provider[0].date_time, sample_forecast_start).equal
+    assert compare_datetimes(provider.records[0].date_time, sample_forecast_start).equal

    # Test updating AC power measurement for a specific date.
    provider.update_value(sample_forecast_start, "pvforecastakkudoktor_ac_power_measured", 1000)
    # Check whether first record was filled with ac power measurement
-    assert provider[0].pvforecastakkudoktor_ac_power_measured == 1000
+    assert provider.records[0].pvforecastakkudoktor_ac_power_measured == 1000

    # Test fetching temperature forecast for a specific date.
    other_end_datetime = other_start_datetime + to_duration("24 hours")

@@ -1,11 +1,12 @@
import json
from pathlib import Path

+import numpy.testing as npt
import pytest

-from akkudoktoreos.core.ems import get_ems
+from akkudoktoreos.core.coreabc import get_ems
from akkudoktoreos.prediction.pvforecastimport import PVForecastImport
-from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime
+from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime, to_duration

DIR_TESTDATA = Path(__file__).absolute().parent.joinpath("testdata")
@@ -87,6 +88,7 @@ def test_invalid_provider(provider, config_eos):
)
def test_import(provider, sample_import_1_json, start_datetime, from_file, config_eos):
    """Test fetching forecast from import."""
+    key = "pvforecast_ac_power"
    ems_eos = get_ems()
    ems_eos.set_start_datetime(to_datetime(start_datetime, in_timezone="Europe/Berlin"))
    if from_file:
@@ -95,7 +97,7 @@ def test_import(provider, sample_import_1_json, start_datetime, from_file, confi
    else:
        config_eos.pvforecast.provider_settings.PVForecastImport.import_file_path = None
        assert config_eos.pvforecast.provider_settings.PVForecastImport.import_file_path is None
-    provider.clear()
+    provider.delete_by_datetime(start_datetime=None, end_datetime=None)

    # Call the method
    provider.update_data()
@@ -104,16 +106,13 @@ def test_import(provider, sample_import_1_json, start_datetime, from_file, confi
    assert provider.ems_start_datetime is not None
    assert provider.total_hours is not None
    assert compare_datetimes(provider.ems_start_datetime, ems_eos.start_datetime).equal
-    values = sample_import_1_json["pvforecast_ac_power"]
-    value_datetime_mapping = provider.import_datetimes(ems_eos.start_datetime, len(values))
-    for i, mapping in enumerate(value_datetime_mapping):
-        assert i < len(provider.records)
-        expected_datetime, expected_value_index = mapping
-        expected_value = values[expected_value_index]
-        result_datetime = provider.records[i].date_time
-        result_value = provider.records[i]["pvforecast_ac_power"]
-
-        # print(f"{i}: Expected: {expected_datetime}:{expected_value}")
-        # print(f"{i}: Result: {result_datetime}:{result_value}")
-        assert compare_datetimes(result_datetime, expected_datetime).equal
-        assert result_value == expected_value
+    expected_values = sample_import_1_json[key]
+    result_values = provider.key_to_array(
+        key=key,
+        start_datetime=provider.ems_start_datetime,
+        end_datetime=provider.ems_start_datetime + to_duration(f"{len(expected_values)} hours"),
+        interval=to_duration("1 hour"),
+    )
+    # Allow for some difference due to value calculation on DST change
+    npt.assert_allclose(result_values, expected_values, rtol=0.001)

tests/test_retentionmanager.py (new file, 701 lines)
@@ -0,0 +1,701 @@
"""Tests for RetentionManager and JobState."""

from __future__ import annotations

import asyncio
import time
from typing import Any
from unittest.mock import AsyncMock, MagicMock, call, patch

import pytest
from loguru import logger

import akkudoktoreos.server.retentionmanager
from akkudoktoreos.server.retentionmanager import JobState, RetentionManager

# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------

INTERVAL = 10.0
DUE_INTERVAL = 0.001  # non-zero so interval() does not fall back to fallback_interval
FALLBACK = 300.0


def make_config_getter(interval: float = INTERVAL) -> Any:
    """Return a simple config getter that always yields ``interval`` for any key."""
    return lambda key: interval


def make_config_getter_none() -> Any:
    """Return a config getter that always yields ``None`` (job disabled)."""
    return lambda key: None


def make_manager(interval: float = INTERVAL, shutdown_timeout: float = 5.0) -> RetentionManager:
    """Return a ``RetentionManager`` backed by a fixed-interval config getter."""
    return RetentionManager(make_config_getter(interval), shutdown_timeout=shutdown_timeout)


def make_manager_none(shutdown_timeout: float = 5.0) -> RetentionManager:
    """Return a ``RetentionManager`` whose config getter always returns None (all jobs disabled)."""
    return RetentionManager(make_config_getter_none(), shutdown_timeout=shutdown_timeout)

# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


class TestRetentionManager:
    """Tests for :class:`RetentionManager` and :class:`JobState`."""

    # ------------------------------------------------------------------
    # Initialisation
    # ------------------------------------------------------------------

    def test_init_stores_config_getter(self) -> None:
        """The config getter passed to __init__ is stored and forwarded to jobs."""
        getter = make_config_getter()
        manager = RetentionManager(getter)
        assert manager._config_getter is getter

    def test_init_empty_job_registry(self) -> None:
        """A newly created manager has no registered jobs."""
        manager = make_manager()
        assert manager._jobs == {}

    # ------------------------------------------------------------------
    # register / unregister
    # ------------------------------------------------------------------

    def test_register_adds_job(self) -> None:
        """Registering a function adds a JobState entry."""
        manager = make_manager()
        func = MagicMock()
        manager.register("job1", func, interval_attr="some/key")
        assert "job1" in manager._jobs

    def test_register_job_state_fields(self) -> None:
        """Registered JobState carries the correct initial field values."""
        manager = make_manager()
        func = MagicMock()
        manager.register("job1", func, interval_attr="some/key", fallback_interval=60.0)
        job = manager._jobs["job1"]
        assert job.name == "job1"
        assert job.func is func
        assert job.interval_attr == "some/key"
        assert job.fallback_interval == 60.0
        assert job.config_getter is manager._config_getter
        assert job.on_exception is None
        assert job.last_run_at == 0.0
        assert job.run_count == 0
        assert job.is_running is False

    def test_register_stores_on_exception(self) -> None:
        """The on_exception callback is stored on the JobState."""
        manager = make_manager()
        handler = MagicMock()
        manager.register("job1", MagicMock(), interval_attr="k", on_exception=handler)
        assert manager._jobs["job1"].on_exception is handler

    def test_register_duplicate_raises(self) -> None:
        """Registering the same name twice raises ValueError."""
        manager = make_manager()
        manager.register("job1", MagicMock(), interval_attr="k")
        with pytest.raises(ValueError, match="job1"):
            manager.register("job1", MagicMock(), interval_attr="k")

    def test_unregister_removes_job(self) -> None:
        """Unregistering a job removes it from the registry."""
        manager = make_manager()
        manager.register("job1", MagicMock(), interval_attr="k")
        manager.unregister("job1")
        assert "job1" not in manager._jobs

    def test_unregister_missing_job_is_noop(self) -> None:
        """Unregistering a non-existent job does not raise."""
        manager = make_manager()
        manager.unregister("nonexistent")  # must not raise

    # ------------------------------------------------------------------
    # JobState.interval()
    # ------------------------------------------------------------------

    def test_job_interval_from_config_getter(self) -> None:
        """JobState.interval() returns the value provided by config_getter."""
        manager = make_manager(interval=42.0)
        manager.register("job1", MagicMock(), interval_attr="k")
        assert manager._jobs["job1"].interval() == 42.0

    def test_job_interval_none_when_config_returns_none(self) -> None:
        """JobState.interval() returns None when config_getter returns None (job disabled)."""
        manager = make_manager_none()
        manager.register("job1", MagicMock(), interval_attr="k", fallback_interval=FALLBACK)
        assert manager._jobs["job1"].interval() is None

    def test_job_interval_none_does_not_fall_back(self) -> None:
        """A None config value must NOT fall back to fallback_interval -- None means disabled."""
        manager = make_manager_none()
        manager.register("job1", MagicMock(), interval_attr="k", fallback_interval=99.0)
        # If None incorrectly fell back, this would return 99.0 instead of None
        assert manager._jobs["job1"].interval() is None

    def test_job_interval_fallback_on_key_error(self) -> None:
        """JobState.interval() uses fallback_interval when config_getter raises KeyError."""
        manager = RetentionManager(lambda key: (_ for _ in ()).throw(KeyError(key)))
        manager.register("job1", MagicMock(), interval_attr="k", fallback_interval=99.0)
        assert manager._jobs["job1"].interval() == 99.0

    def test_job_interval_fallback_on_index_error(self) -> None:
        """JobState.interval() uses fallback_interval when config_getter raises IndexError."""
        manager = RetentionManager(lambda key: (_ for _ in ()).throw(IndexError()))
        manager.register("job1", MagicMock(), interval_attr="k", fallback_interval=77.0)
        assert manager._jobs["job1"].interval() == 77.0

    def test_job_interval_fallback_on_zero_value(self) -> None:
        """JobState.interval() uses fallback_interval when config_getter returns zero."""
        manager = RetentionManager(lambda key: 0)
        manager.register("job1", MagicMock(), interval_attr="k", fallback_interval=55.0)
        assert manager._jobs["job1"].interval() == 55.0
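
Taken together, these cases pin down a small decision table for interval(): None means disabled, a missing key or a zero value means fall back. A sketch consistent with the tests (the real JobState may differ); note also that `lambda key: (_ for _ in ()).throw(KeyError(key))` is just the usual trick for raising from a lambda, equivalent to a one-line def that raises:

    from typing import Any, Callable, Optional

    def interval(config_getter: Callable[[str], Any], interval_attr: str,
                 fallback_interval: float) -> Optional[float]:
        """Sketch of the behavior the tests above specify."""
        try:
            value = config_getter(interval_attr)
        except (KeyError, IndexError):
            return fallback_interval  # unknown key: fall back
        if value is None:
            return None               # explicitly disabled: no fallback
        if not value:
            return fallback_interval  # zero or falsy: fall back
        return float(value)

    assert interval(lambda k: 42.0, "k", 300.0) == 42.0
    assert interval(lambda k: None, "k", 300.0) is None
    assert interval(lambda k: 0, "k", 300.0) == 300.0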

    # ------------------------------------------------------------------
    # JobState.is_due()
    # ------------------------------------------------------------------

    def test_job_is_due_when_never_run(self) -> None:
        """A job is always due when it has never been run (last_run_at == 0.0)."""
        manager = make_manager(interval=INTERVAL)
        manager.register("job1", MagicMock(), interval_attr="k")
        assert manager._jobs["job1"].is_due() is True

    def test_job_is_not_due_immediately_after_run(self) -> None:
        """A job is not due immediately after last_run_at is set to now."""
        manager = make_manager(interval=INTERVAL)
        manager.register("job1", MagicMock(), interval_attr="k")
        manager._jobs["job1"].last_run_at = time.monotonic()
        assert manager._jobs["job1"].is_due() is False

    def test_job_is_due_after_interval_elapsed(self) -> None:
        """A job becomes due once the interval has passed since last_run_at."""
        manager = make_manager(interval=1.0)
        manager.register("job1", MagicMock(), interval_attr="k")
        manager._jobs["job1"].last_run_at = time.monotonic() - 2.0  # 2 s ago > 1 s interval
        assert manager._jobs["job1"].is_due() is True

    def test_job_is_never_due_when_interval_is_none(self) -> None:
        """is_due() returns False when interval() is None, even if last_run_at is 0."""
        manager = make_manager_none()
        manager.register("job1", MagicMock(), interval_attr="k")
        job = manager._jobs["job1"]
        # last_run_at == 0.0 would make any enabled job due immediately
        assert job.last_run_at == 0.0
        assert job.is_due() is False

    def test_job_is_never_due_when_disabled_regardless_of_last_run(self) -> None:
        """is_due() stays False for a disabled job even long after its last run."""
        manager = make_manager_none()
        manager.register("job1", MagicMock(), interval_attr="k")
        job = manager._jobs["job1"]
        job.last_run_at = time.monotonic() - 365 * 24 * 3600  # "ran" a year ago
        assert job.is_due() is False
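
Which makes is_due() a two-clause predicate over a monotonic clock. A sketch matching the tested behavior (again illustrative, not the actual implementation):

    import time
    from typing import Optional

    def is_due(interval_s: Optional[float], last_run_at: float) -> bool:
        """Due when enabled and the interval has elapsed since the last run."""
        if interval_s is None:
            return False  # disabled job: never due
        return (time.monotonic() - last_run_at) >= interval_s

    assert is_due(None, 0.0) is False               # disabled beats "never run"
    assert is_due(10.0, 0.0) is True                # never run: long since elapsed
    assert is_due(10.0, time.monotonic()) is False  # just ran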

    # ------------------------------------------------------------------
    # JobState.summary()
    # ------------------------------------------------------------------

    def test_summary_keys(self) -> None:
        """summary() returns all expected keys including interval_s."""
        manager = make_manager()
        manager.register("job1", MagicMock(), interval_attr="k")
        summary = manager._jobs["job1"].summary()
        assert set(summary.keys()) == {
            "name", "interval_attr", "interval_s", "last_run_at",
            "last_duration_s", "last_error", "run_count", "is_running",
        }

    def test_summary_interval_s_reflects_config(self) -> None:
        """summary()['interval_s'] matches the value returned by interval()."""
        manager = make_manager(interval=42.0)
        manager.register("job1", MagicMock(), interval_attr="k")
        assert manager._jobs["job1"].summary()["interval_s"] == 42.0

    def test_summary_interval_s_is_none_when_disabled(self) -> None:
        """summary()['interval_s'] is None when the job is disabled via config."""
        manager = make_manager_none()
        manager.register("job1", MagicMock(), interval_attr="k")
        assert manager._jobs["job1"].summary()["interval_s"] is None

    def test_summary_values(self) -> None:
        """summary() reflects the current JobState values."""
        manager = make_manager()
        manager.register("job1", MagicMock(), interval_attr="my/key")
        job = manager._jobs["job1"]
        job.last_run_at = 1234.5
        job.last_duration = 0.12345
        job.last_error = "oops"
        job.run_count = 3
        job.is_running = True
        s = job.summary()
        assert s["name"] == "job1"
        assert s["interval_attr"] == "my/key"
        assert s["last_run_at"] == 1234.5
        assert s["last_duration_s"] == 0.1235  # rounded to 4 dp
        assert s["last_error"] == "oops"
        assert s["run_count"] == 3
        assert s["is_running"] is True

    # ------------------------------------------------------------------
    # status()
    # ------------------------------------------------------------------

    def test_status_empty(self) -> None:
        """status() returns an empty list when no jobs are registered."""
        assert make_manager().status() == []

    def test_status_contains_all_jobs(self) -> None:
        """status() returns one entry per registered job."""
        manager = make_manager()
        manager.register("a", MagicMock(), interval_attr="k1")
        manager.register("b", MagicMock(), interval_attr="k2")
        names = {s["name"] for s in manager.status()}
        assert names == {"a", "b"}

    def test_status_shows_disabled_job(self) -> None:
        """status() includes disabled jobs with interval_s == None."""
        manager = make_manager_none()
        manager.register("disabled", MagicMock(), interval_attr="k")
        entries = manager.status()
        assert len(entries) == 1
        assert entries[0]["interval_s"] is None

    # ------------------------------------------------------------------
    # tick() -- job dispatch
    # ------------------------------------------------------------------

    @pytest.mark.asyncio
    async def test_tick_runs_due_sync_job(self) -> None:
        """tick() executes a sync job that is due."""
        manager = make_manager(interval=DUE_INTERVAL)
        func = MagicMock()
        manager.register("job1", func, interval_attr="k")
        await manager.tick()
        await asyncio.sleep(0)  # yield so ensure_future tasks are scheduled
        await asyncio.sleep(0)  # second yield ensures tasks have started
        await manager.shutdown()
        func.assert_called_once()

    @pytest.mark.asyncio
    async def test_tick_runs_due_async_job(self) -> None:
        """tick() executes an async job that is due."""
        manager = make_manager(interval=DUE_INTERVAL)
        func = AsyncMock()
        manager.register("job1", func, interval_attr="k")
        await manager.tick()
        await asyncio.sleep(0)  # yield so ensure_future tasks are scheduled
        await asyncio.sleep(0)  # second yield ensures tasks have started
        await manager.shutdown()
        func.assert_called_once()

    @pytest.mark.asyncio
    async def test_tick_skips_not_due_job(self) -> None:
        """tick() does not execute a job whose interval has not yet elapsed."""
        manager = make_manager(interval=9999.0)
        func = MagicMock()
        manager.register("job1", func, interval_attr="k")
        manager._jobs["job1"].last_run_at = time.monotonic()  # just ran
        await manager.tick()
        await asyncio.sleep(0)
        await manager.shutdown()
        func.assert_not_called()

    @pytest.mark.asyncio
    async def test_tick_skips_disabled_job(self) -> None:
        """tick() never executes a job whose interval is None, even if never run before."""
        manager = make_manager_none()
        func = MagicMock()
        manager.register("disabled", func, interval_attr="k")
        job = manager._jobs["disabled"]
        # last_run_at == 0.0 would fire any enabled job immediately
        assert job.last_run_at == 0.0
        await manager.tick()
        await asyncio.sleep(0)
        await manager.shutdown()
        func.assert_not_called()

    @pytest.mark.asyncio
    async def test_tick_skips_disabled_job_adds_no_task(self) -> None:
        """tick() adds no task to _running_tasks for a disabled job."""
        manager = make_manager_none()
        manager.register("disabled", AsyncMock(), interval_attr="k")
        await manager.tick()
        await asyncio.sleep(0)
        assert len(manager._running_tasks) == 0

    @pytest.mark.asyncio
    async def test_tick_enabled_and_disabled_jobs_mixed(self) -> None:
        """tick() fires enabled jobs and silently skips disabled ones in the same manager."""
        results: list[str] = []

        async def enabled_job() -> None:
            results.append("ran")

        manager = RetentionManager(
            lambda key: DUE_INTERVAL if key == "enabled/interval" else None,
            shutdown_timeout=5.0,
        )
        manager.register("enabled", enabled_job, interval_attr="enabled/interval")
        manager.register("disabled", AsyncMock(), interval_attr="disabled/interval")

        await manager.tick()
        await asyncio.sleep(0)
        await asyncio.sleep(0)
        await manager.shutdown()

        assert results == ["ran"], "Only the enabled job must have run"

    @pytest.mark.asyncio
    async def test_tick_skips_already_running_job(self) -> None:
        """tick() does not start a job that is still marked as running."""
        manager = make_manager(interval=DUE_INTERVAL)
        func = MagicMock()
        manager.register("job1", func, interval_attr="k")
        manager._jobs["job1"].is_running = True
        await manager.tick()
        await asyncio.sleep(0)
        await manager.shutdown()
        func.assert_not_called()

    @pytest.mark.asyncio
    async def test_tick_runs_multiple_jobs_concurrently(self) -> None:
        """tick() fires all due jobs as independent tasks."""
        manager = make_manager(interval=DUE_INTERVAL)
        results: list[str] = []

        async def job_a() -> None:
            results.append("a")

        async def job_b() -> None:
            results.append("b")

        manager.register("a", job_a, interval_attr="k")
        manager.register("b", job_b, interval_attr="k")
        await manager.tick()
        await asyncio.sleep(0)  # yield so ensure_future tasks are scheduled
        await asyncio.sleep(0)  # second yield ensures tasks have started
        await manager.shutdown()
        assert sorted(results) == ["a", "b"]

    @pytest.mark.asyncio
    async def test_tick_adds_tasks_to_running_set(self) -> None:
        """tick() adds a task to _running_tasks for each due job."""
        barrier = asyncio.Event()
        manager = make_manager(interval=DUE_INTERVAL)

        async def blocking_job() -> None:
            await barrier.wait()

        manager.register("job1", blocking_job, interval_attr="k")
        await manager.tick()
        await asyncio.sleep(0)  # yield so ensure_future tasks are scheduled
        await asyncio.sleep(0)  # second yield ensures tasks have started
        # Task is still running (barrier not set), so it must be in the set.
        assert len(manager._running_tasks) == 1
        barrier.set()
        await manager.shutdown()

    @pytest.mark.asyncio
    async def test_tick_removes_task_from_running_set_on_completion(self) -> None:
        """Completed tasks are removed from _running_tasks automatically."""
        manager = make_manager(interval=DUE_INTERVAL)
        manager.register("job1", AsyncMock(), interval_attr="k")
        await manager.tick()
        await asyncio.sleep(0)  # yield so ensure_future tasks are scheduled
        await manager.shutdown()
        assert len(manager._running_tasks) == 0
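
The repeated `await asyncio.sleep(0)` is the standard way to hand control back to the event loop so tasks created with ensure_future actually start before the test asserts anything. A self-contained demonstration of the mechanism (the tests above use a second yield as a safety margin for tasks that themselves await before doing work):

    import asyncio

    async def main() -> None:
        started: list[bool] = []

        async def job() -> None:
            started.append(True)

        asyncio.ensure_future(job())  # task created but not yet run
        assert started == []
        await asyncio.sleep(0)        # yield: the loop runs the pending task
        assert started == [True]

    asyncio.run(main())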
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# shutdown()
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shutdown_returns_immediately_when_no_tasks(self) -> None:
|
||||
"""shutdown() completes without blocking when no tasks are running."""
|
||||
manager = make_manager()
|
||||
await manager.shutdown() # must return promptly without raising
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shutdown_waits_for_in_flight_task(self) -> None:
|
||||
"""shutdown() blocks until a long-running job task finishes."""
|
||||
barrier = asyncio.Event()
|
||||
finished: list[bool] = []
|
||||
manager = make_manager(interval=DUE_INTERVAL)
|
||||
|
||||
async def slow_job() -> None:
|
||||
await barrier.wait()
|
||||
finished.append(True)
|
||||
|
||||
manager.register("job1", slow_job, interval_attr="k")
|
||||
await manager.tick()
|
||||
await asyncio.sleep(0) # yield so ensure_future tasks are scheduled
|
||||
await asyncio.sleep(0) # second yield ensures tasks have started
|
||||
assert finished == [] # job still blocked
|
||||
barrier.set()
|
||||
await manager.shutdown()
|
||||
assert finished == [True] # job completed before shutdown returned
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shutdown_waits_for_multiple_in_flight_tasks(self) -> None:
|
||||
"""shutdown() waits for all concurrently running job tasks."""
|
||||
barrier = asyncio.Event()
|
||||
finished: list[str] = []
|
||||
manager = make_manager(interval=DUE_INTERVAL)
|
||||
|
||||
async def slow_a() -> None:
|
||||
await barrier.wait()
|
||||
finished.append("a")
|
||||
|
||||
async def slow_b() -> None:
|
||||
await barrier.wait()
|
||||
finished.append("b")
|
||||
|
||||
manager.register("a", slow_a, interval_attr="k")
|
||||
manager.register("b", slow_b, interval_attr="k")
|
||||
await manager.tick()
|
||||
await asyncio.sleep(0) # yield so ensure_future tasks are scheduled
|
||||
await asyncio.sleep(0) # second yield ensures tasks have started
|
||||
assert finished == []
|
||||
barrier.set()
|
||||
await manager.shutdown()
|
||||
assert sorted(finished) == ["a", "b"]
|
||||
|
||||
    @pytest.mark.asyncio
    async def test_shutdown_does_not_raise_when_job_failed(self) -> None:
        """shutdown() completes without raising even if a job task raised an exception."""
        manager = make_manager(interval=DUE_INTERVAL)

        def failing_func() -> None:
            raise RuntimeError("job error")

        manager.register("job1", failing_func, interval_attr="k")
        await manager.tick()
        await asyncio.sleep(0)  # yield so ensure_future tasks are scheduled
        await manager.shutdown()  # must not raise

    @pytest.mark.asyncio
    async def test_shutdown_clears_running_tasks_set(self) -> None:
        """_running_tasks is empty after shutdown() completes."""
        manager = make_manager(interval=DUE_INTERVAL)
        manager.register("job1", AsyncMock(), interval_attr="k")
        await manager.tick()
        await asyncio.sleep(0)  # yield so ensure_future tasks are scheduled
        await manager.shutdown()
        assert manager._running_tasks == set()

    @pytest.mark.asyncio
    async def test_shutdown_timeout_returns_without_blocking(self) -> None:
        """shutdown() returns once the timeout elapses even if a job is still running."""
        stuck = asyncio.Event()  # never set -- job blocks forever
        manager = RetentionManager(make_config_getter(DUE_INTERVAL), shutdown_timeout=0.05)

        async def forever() -> None:
            await stuck.wait()

        manager.register("stuck", forever, interval_attr="k")
        await manager.tick()
        await asyncio.sleep(0)
        await asyncio.sleep(0)
        # Must return within the timeout, not block forever.
        await manager.shutdown()

    @pytest.mark.asyncio
    async def test_shutdown_timeout_logs_error_for_pending_jobs(self) -> None:
        """An error is logged listing jobs still running after the timeout."""
        stuck = asyncio.Event()
        manager = RetentionManager(make_config_getter(DUE_INTERVAL), shutdown_timeout=0.05)

        async def forever() -> None:
            await stuck.wait()

        manager.register("stuck_job", forever, interval_attr="k")
        await manager.tick()
        await asyncio.sleep(0)
        await asyncio.sleep(0)

        with patch.object(logger, "error") as mock_error:
            await manager.shutdown()
        assert mock_error.called, "Expected logger.error to be called on timeout"
        # All positional args joined: the stuck job name must appear.
        logged = str(mock_error.call_args_list)
        assert "stuck_job" in logged

    @pytest.mark.asyncio
    async def test_shutdown_timeout_clears_running_tasks_set(self) -> None:
        """_running_tasks is cleared even when the timeout elapses."""
        stuck = asyncio.Event()
        manager = RetentionManager(make_config_getter(DUE_INTERVAL), shutdown_timeout=0.05)

        async def forever() -> None:
            await stuck.wait()

        manager.register("stuck", forever, interval_attr="k")
        await manager.tick()
        await asyncio.sleep(0)
        await asyncio.sleep(0)
        await manager.shutdown()
        assert manager._running_tasks == set()

    @pytest.mark.asyncio
    async def test_shutdown_no_error_logged_when_all_finish_in_time(self) -> None:
        """No error is logged when all tasks complete within the timeout."""
        manager = RetentionManager(make_config_getter(DUE_INTERVAL), shutdown_timeout=5.0)
        manager.register("job1", AsyncMock(), interval_attr="k")
        await manager.tick()
        await asyncio.sleep(0)

        with patch.object(logger, "error") as mock_error:
            await manager.shutdown()
        mock_error.assert_not_called()

    def test_init_stores_shutdown_timeout(self) -> None:
        """The shutdown_timeout passed to __init__ is stored on the instance."""
        manager = RetentionManager(make_config_getter(), shutdown_timeout=99.0)
        assert manager._shutdown_timeout == 99.0

    def test_init_default_shutdown_timeout(self) -> None:
        """The default shutdown_timeout is 30 seconds."""
        manager = RetentionManager(make_config_getter())
        assert manager._shutdown_timeout == 30.0

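    # Taken together, the shutdown() tests assume a wait-with-timeout pattern.
    # A minimal sketch of what they exercise (illustrative, not the actual
    # implementation):
    #
    #     done, pending = await asyncio.wait(
    #         self._running_tasks, timeout=self._shutdown_timeout
    #     )
    #     if pending:
    #         logger.error("Jobs still running after shutdown timeout: ...")
    #     self._running_tasks.clear()
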
    # ------------------------------------------------------------------
    # _run_job() -- state updates
    # ------------------------------------------------------------------

    @pytest.mark.asyncio
    async def test_run_job_increments_run_count(self) -> None:
        """_run_job() increments run_count after each execution."""
        manager = make_manager()
        manager.register("job1", MagicMock(), interval_attr="k")
        job = manager._jobs["job1"]
        await manager._run_job(job)
        await manager._run_job(job)
        assert job.run_count == 2

    @pytest.mark.asyncio
    async def test_run_job_updates_last_run_at(self) -> None:
        """_run_job() sets last_run_at to a recent monotonic timestamp."""
        manager = make_manager()
        manager.register("job1", MagicMock(), interval_attr="k")
        job = manager._jobs["job1"]
        before = time.monotonic()
        await manager._run_job(job)
        assert job.last_run_at >= before

    @pytest.mark.asyncio
    async def test_run_job_updates_last_duration(self) -> None:
        """_run_job() records a non-negative last_duration."""
        manager = make_manager()
        manager.register("job1", MagicMock(), interval_attr="k")
        job = manager._jobs["job1"]
        await manager._run_job(job)
        assert job.last_duration >= 0.0

    @pytest.mark.asyncio
    async def test_run_job_clears_is_running_on_success(self) -> None:
        """is_running is False after a successful job execution."""
        manager = make_manager()
        manager.register("job1", MagicMock(), interval_attr="k")
        job = manager._jobs["job1"]
        await manager._run_job(job)
        assert job.is_running is False

    @pytest.mark.asyncio
    async def test_run_job_clears_last_error_on_success(self) -> None:
        """last_error is set to None after a successful execution."""
        manager = make_manager()
        manager.register("job1", MagicMock(), interval_attr="k")
        job = manager._jobs["job1"]
        job.last_error = "stale error"
        await manager._run_job(job)
        assert job.last_error is None

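    # The state-update tests above imply roughly this bookkeeping inside
    # _run_job(); a sketch under the assumption that timestamps come from
    # time.monotonic() (attribute names follow the assertions):
    #
    #     job.is_running = True
    #     started = time.monotonic()
    #     try:
    #         ...  # call the job function, awaiting it if it is a coroutine
    #         job.last_error = None
    #     finally:
    #         job.run_count += 1
    #         job.last_run_at = time.monotonic()
    #         job.last_duration = job.last_run_at - started
    #         job.is_running = False
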
    # ------------------------------------------------------------------
    # _run_job() -- exception handling
    # ------------------------------------------------------------------

    @pytest.mark.asyncio
    async def test_run_job_stores_exception_message(self) -> None:
        """last_error is set to the exception message when the job raises."""
        manager = make_manager()

        def failing_func() -> None:
            raise RuntimeError("boom")

        manager.register("job1", failing_func, interval_attr="k")
        job = manager._jobs["job1"]
        await manager._run_job(job)
        assert job.last_error == "boom"

    @pytest.mark.asyncio
    async def test_run_job_still_updates_state_after_exception(self) -> None:
        """run_count and last_run_at are updated even when the job raises."""
        manager = make_manager()

        def failing_func() -> None:
            raise RuntimeError("boom")

        manager.register("job1", failing_func, interval_attr="k")
        job = manager._jobs["job1"]
        before = time.monotonic()
        await manager._run_job(job)
        assert job.run_count == 1
        assert job.last_run_at >= before
        assert job.is_running is False

    @pytest.mark.asyncio
    async def test_run_job_calls_sync_on_exception_handler(self) -> None:
        """A sync on_exception handler is called with the raised exception."""
        manager = make_manager()
        handler = MagicMock()
        exc = RuntimeError("oops")

        def failing_func() -> None:
            raise exc

        manager.register("job1", failing_func, interval_attr="k", on_exception=handler)
        await manager._run_job(manager._jobs["job1"])
        handler.assert_called_once_with(exc)

    @pytest.mark.asyncio
    async def test_run_job_calls_async_on_exception_handler(self) -> None:
        """An async on_exception handler is awaited with the raised exception."""
        manager = make_manager()
        handler = AsyncMock()
        exc = RuntimeError("oops")

        def failing_func() -> None:
            raise exc

        manager.register("job1", failing_func, interval_attr="k", on_exception=handler)
        await manager._run_job(manager._jobs["job1"])
        handler.assert_called_once_with(exc)

    @pytest.mark.asyncio
    async def test_run_job_no_on_exception_handler_does_not_raise(self) -> None:
        """A failing job without on_exception does not propagate the exception."""
        manager = make_manager()

        def failing_func() -> None:
            raise RuntimeError("silent failure")

        manager.register("job1", failing_func, interval_attr="k")
        await manager._run_job(manager._jobs["job1"])  # must not raise

    @pytest.mark.asyncio
    async def test_run_job_on_exception_not_called_on_success(self) -> None:
        """on_exception is not called when the job succeeds."""
        manager = make_manager()
        handler = MagicMock()
        manager.register("job1", MagicMock(), interval_attr="k", on_exception=handler)
        await manager._run_job(manager._jobs["job1"])
        handler.assert_not_called()
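
    # The handler tests assume _run_job() swallows job exceptions and
    # dispatches to an optional handler that may be sync or async; a sketch
    # (the awaitable check is an assumption, not confirmed by the tests):
    #
    #     except Exception as exc:
    #         job.last_error = str(exc)
    #         if job.on_exception is not None:
    #             result = job.on_exception(exc)
    #             if inspect.isawaitable(result):
    #                 await result
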
@@ -2,11 +2,22 @@
 import subprocess
 import sys
 from pathlib import Path
 from typing import Optional, Union

 import pytest
 import yaml

-from akkudoktoreos.core.version import _version_calculate, _version_hash
+from akkudoktoreos.core.version import (
+    ALLOWED_SUFFIXES,
+    DIR_PACKAGE_ROOT,
+    EXCLUDED_DIR_PATTERNS,
+    EXCLUDED_FILES,
+    HashConfig,
+    _version_calculate,
+    _version_hash,
+    collect_files,
+    hash_files,
+)

 DIR_PROJECT_ROOT = Path(__file__).parent.parent
 GET_VERSION_SCRIPT = DIR_PROJECT_ROOT / "scripts" / "get_version.py"
@@ -14,11 +25,166 @@ BUMP_DEV_SCRIPT = DIR_PROJECT_ROOT / "scripts" / "bump_dev_version.py"
 UPDATE_SCRIPT = DIR_PROJECT_ROOT / "scripts" / "update_version.py"


# --- Git helpers ---

def get_git_tracked_files(repo_path: Path) -> Optional[set[Path]]:
    """Get set of all files tracked by git in the repository.

    Returns None if not a git repository or git command fails.
    """
    try:
        result = subprocess.run(
            ["git", "ls-files"],
            cwd=repo_path,
            capture_output=True,
            text=True,
            check=True
        )
        # Convert relative paths to absolute paths
        tracked_files = {
            (repo_path / line.strip()).resolve()
            for line in result.stdout.splitlines()
            if line.strip()
        }
        return tracked_files
    except (subprocess.CalledProcessError, FileNotFoundError):
        return None


def is_git_repository(path: Path) -> bool:
    """Check if path is inside a git repository."""
    try:
        subprocess.run(
            ["git", "rev-parse", "--git-dir"],
            cwd=path,
            capture_output=True,
            check=True
        )
        return True
    except (subprocess.CalledProcessError, FileNotFoundError):
        return False


def get_git_root(path: Path) -> Optional[Path]:
    """Get the root directory of the git repository containing path."""
    try:
        result = subprocess.run(
            ["git", "rev-parse", "--show-toplevel"],
            cwd=path,
            capture_output=True,
            text=True,
            check=True
        )
        return Path(result.stdout.strip())
    except (subprocess.CalledProcessError, FileNotFoundError):
        return None


def check_files_in_git(
    files: list[Path],
    base_path: Optional[Path] = None
) -> tuple[list[Path], list[Path]]:
    """Check which files are tracked by git.

    Args:
        files: List of files to check
        base_path: Base path to check for git repository (uses first file's parent if None)

    Returns:
        Tuple of (tracked_files, untracked_files)

    Example:
        >>> files = collect_files(config)
        >>> tracked, untracked = check_files_in_git(files)
        >>> if untracked:
        ...     print(f"Warning: {len(untracked)} files not in git")
    """
    if not files:
        return [], []

    check_path = base_path or files[0].parent

    assert is_git_repository(check_path)

    git_root = get_git_root(check_path)
    if not git_root:
        return [], files

    git_tracked = get_git_tracked_files(git_root)
    if git_tracked is None:
        return [], files

    tracked = [f for f in files if f in git_tracked]
    untracked = [f for f in files if f not in git_tracked]

    return tracked, untracked


# --- Helper to create test files ---
def write_file(path: Path, content: str) -> Path:
    path.write_text(content, encoding="utf-8")
    return path

# --- Test version calculation ---

def test_version_hash() -> None:
    """Test which files are used for version hash calculation."""

    watched_paths = [DIR_PACKAGE_ROOT]

    # Collect files
    config = HashConfig(
        paths=watched_paths,
        allowed_suffixes=ALLOWED_SUFFIXES,
        excluded_dir_patterns=EXCLUDED_DIR_PATTERNS,
        excluded_files=EXCLUDED_FILES
    )

    files = collect_files(config)
    hash_digest = hash_files(files)

    # Check git
    tracked, untracked = check_files_in_git(files, DIR_PACKAGE_ROOT)
    tracked_files: list[Path] = tracked
    untracked_files: list[Path] = untracked

    if untracked_files:
        error_msg = f"\n{'='*60}\n"
        error_msg += "Version Hash Inspection\n"
        error_msg += f"{'='*60}\n"
        error_msg += f"Hash: {hash_digest}\n"
        error_msg += f"Based on {len(files)} files:\n"

        error_msg += f"OK: {len(tracked_files)} files tracked by git:\n"
        for i, file_path in enumerate(files, 1):
            try:
                rel_path = file_path.relative_to(DIR_PACKAGE_ROOT)
                status = ""
                if file_path in untracked_files:
                    continue
                elif file_path in tracked_files:
                    status = " [tracked]"
                error_msg += f" {i:3d}. {rel_path}{status}\n"
            except ValueError:
                error_msg += f" {i:3d}. {file_path}\n"

        error_msg += f"Warning: {len(untracked_files)} files not tracked by git:\n"
        for i, file_path in enumerate(files, 1):
            try:
                rel_path = file_path.relative_to(DIR_PACKAGE_ROOT)
                status = ""
                if file_path in untracked_files:
                    status = " [NOT IN GIT]"
                elif file_path in tracked_files:
                    continue
                error_msg += f" {i:3d}. {rel_path}{status}\n"
            except ValueError:
                error_msg += f" {i:3d}. {file_path}\n"

        error_msg += f"\n{'='*60}\n"

        pytest.fail(error_msg)


# --- Test version helpers ---
def test_version_non_dev(monkeypatch):
@@ -38,7 +204,7 @@ def test_version_dev_precision_8(monkeypatch):

    result = _version_calculate()

-    # compute expected suffix
+    # Compute expected suffix using the same logic as _version_calculate
    hash_value = int(fake_hash, 16)
    expected_digits = str(hash_value % (10 ** 8)).zfill(8)

@@ -60,12 +226,17 @@ def test_version_dev_precision_8_different_hash(monkeypatch):

    result = _version_calculate()

    # Compute expected suffix using the same logic as _version_calculate
    hash_value = int(fake_hash, 16)
    expected_digits = str(hash_value % (10 ** 8)).zfill(8)

    expected = f"0.2.0.dev{expected_digits}"

-    assert result == expected
+    assert len(expected_digits) == 8
+    assert result.startswith("0.2.0.dev")
+    assert result == expected


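# Worked example of the dev-suffix derivation asserted above (the hash value
# is illustrative):
#
#     >>> hash_value = int("deadbeef", 16)      # 3735928559
#     >>> str(hash_value % (10 ** 8)).zfill(8)  # keep the last 8 decimal digits
#     '35928559'
#     >>> f"0.2.0.dev{str(hash_value % (10 ** 8)).zfill(8)}"
#     '0.2.0.dev35928559'
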
# --- 1️⃣ Test get_version.py ---

@@ -6,7 +6,7 @@ import pandas as pd
 import pytest

 from akkudoktoreos.core.cache import CacheFileStore
-from akkudoktoreos.core.ems import get_ems
+from akkudoktoreos.core.coreabc import get_ems
 from akkudoktoreos.prediction.weatherbrightsky import WeatherBrightSky
 from akkudoktoreos.utils.datetimeutil import to_datetime


@@ -10,7 +10,7 @@ import pytest
 from bs4 import BeautifulSoup

 from akkudoktoreos.core.cache import CacheFileStore
-from akkudoktoreos.core.ems import get_ems
+from akkudoktoreos.core.coreabc import get_ems
 from akkudoktoreos.prediction.weatherclearoutside import WeatherClearOutside
 from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime


@@ -1,11 +1,12 @@
 import json
 from pathlib import Path

 import numpy.testing as npt
 import pytest

-from akkudoktoreos.core.ems import get_ems
+from akkudoktoreos.core.coreabc import get_ems
 from akkudoktoreos.prediction.weatherimport import WeatherImport
-from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime
+from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime, to_duration

 DIR_TESTDATA = Path(__file__).absolute().parent.joinpath("testdata")

@@ -87,6 +88,7 @@ def test_invalid_provider(provider, config_eos, monkeypatch):
 )
 def test_import(provider, sample_import_1_json, start_datetime, from_file, config_eos):
     """Test fetching forecast from Import."""
+    key = "weather_temp_air"
     ems_eos = get_ems()
     ems_eos.set_start_datetime(to_datetime(start_datetime, in_timezone="Europe/Berlin"))
     if from_file:
@@ -95,7 +97,7 @@ def test_import(provider, sample_import_1_json, start_datetime, from_file, confi
     else:
         config_eos.weather.provider_settings.WeatherImport.import_file_path = None
     assert config_eos.weather.provider_settings.WeatherImport.import_file_path is None
-    provider.clear()
+    provider.delete_by_datetime(start_datetime=None, end_datetime=None)

     # Call the method
     provider.update_data()
@@ -104,16 +106,13 @@
     assert provider.ems_start_datetime is not None
     assert provider.total_hours is not None
     assert compare_datetimes(provider.ems_start_datetime, ems_eos.start_datetime).equal
-    values = sample_import_1_json["weather_temp_air"]
-    value_datetime_mapping = provider.import_datetimes(ems_eos.start_datetime, len(values))
-    for i, mapping in enumerate(value_datetime_mapping):
-        assert i < len(provider.records)
-        expected_datetime, expected_value_index = mapping
-        expected_value = values[expected_value_index]
-        result_datetime = provider.records[i].date_time
-        result_value = provider.records[i]["weather_temp_air"]
-
-        # print(f"{i}: Expected: {expected_datetime}:{expected_value}")
-        # print(f"{i}: Result: {result_datetime}:{result_value}")
-        assert compare_datetimes(result_datetime, expected_datetime).equal
-        assert result_value == expected_value
+    expected_values = sample_import_1_json[key]
+    result_values = provider.key_to_array(
+        key=key,
+        start_datetime=provider.ems_start_datetime,
+        end_datetime=provider.ems_start_datetime + to_duration(f"{len(expected_values)} hours"),
+        interval=to_duration("1 hour"),
+    )
+    # Allow for some difference due to value calculation on DST change
+    npt.assert_allclose(result_values, expected_values, rtol=0.001)

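# A minimal usage sketch of the time-series access exercised above (names
# follow the diff; the return value is assumed to be array-like, one value
# per interval step over a half-open datetime range):
#
#     temps = provider.key_to_array(
#         key="weather_temp_air",
#         start_datetime=provider.ems_start_datetime,
#         end_datetime=provider.ems_start_datetime + to_duration("3 hours"),
#         interval=to_duration("1 hour"),
#     )  # -> 3 values, resampled onto the hourly grid
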
tests/testdata/eos_config_andreas_now.json (vendored)
@@ -1,6 +1,6 @@
 {
     "general": {
-        "data_folder_path": null,
+        "data_folder_path": "__ANY__",
         "data_output_subpath": "output",
         "latitude": 52.5,
         "longitude": 13.4
