Add database support for measurements and historic prediction data. (#848)

The database supports backend selection, compression, incremental data load,
automatic data saving to storage, and automatic vacuum and compaction.

Make SQLite3 and LMDB database backends available.

Update tests for new interface conventions regarding data sequences,
data containers, and data providers. This includes the measurements provider and
the prediction providers.

Add database documentation.

The change includes several bug fixes that are not directly related to the database
implementation but are necessary to keep EOS running properly and to test and
document the changes.

* fix: config eos test setup

  Make the config_eos fixture generate a new instance of the config_eos singleton.
  Use the correct env names to set up the data folder path.

* fix: startup with no config

  Make cache and measurements complain about missing data path configuration but
  do not bail out.

* fix: soc data preparation and usage for genetic optimization.

  Search for SoC measurements within 48 hours around the optimization start time.
  Only clamp the SoC to its maximum in the battery device simulation.

* fix: dashboard bailout on zero value solution display

  Do not use zero values when calculating the chart value adjustment for display.

* fix: openapi generation script

  Make the script also replace data_folder_path and data_output_path to hide
  real (test) environment paths.

* feat: add make repeated task function

  make_repeated_task wraps a function so that it can be repeated cyclically.
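
  A minimal sketch of such a wrapper (the signature and names are illustrative,
  not the actual EOS API), using a stop event so the repetition can be ended
  cleanly:

```python
import threading
import time
from typing import Callable

def make_repeated_task(
    func: Callable[[], None], interval_sec: float
) -> tuple[Callable[[], None], threading.Event]:
    """Wrap `func` into a runner that calls it every `interval_sec` seconds
    until the returned stop event is set."""
    stop = threading.Event()

    def runner() -> None:
        # Event.wait() returns True as soon as stop is set, ending the loop.
        while not stop.wait(interval_sec):
            func()

    return runner, stop

# Usage: repeat a small job a few times, then stop it.
calls: list[int] = []
runner, stop = make_repeated_task(lambda: calls.append(1), interval_sec=0.01)
thread = threading.Thread(target=runner, daemon=True)
thread.start()
time.sleep(0.06)
stop.set()
thread.join(timeout=1.0)
```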

* chore: removed index based data sequence access

  Index-based data sequence access no longer makes sense, as the sequence can be
  backed by the database. The sequence is now purely time series data.
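
  To illustrate the design point with a hypothetical in-memory stand-in (not the
  EOS data sequence API): records are selected by time range instead of by
  positional index.

```python
from datetime import datetime, timezone

# Hypothetical stand-in for a time-series-backed sequence: one value per hour.
records = {
    datetime(2025, 1, 1, h, tzinfo=timezone.utc): float(h) for h in range(24)
}

def values_between(start: datetime, end: datetime) -> list[float]:
    """Select records by time range rather than by positional index."""
    return [v for ts, v in sorted(records.items()) if start <= ts < end]

morning = values_between(
    datetime(2025, 1, 1, 6, tzinfo=timezone.utc),
    datetime(2025, 1, 1, 9, tzinfo=timezone.utc),
)
# morning == [6.0, 7.0, 8.0]
```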

* chore: refactor eos startup to avoid module import startup

  Avoid initialisation at module import time, especially of the EOS configuration.
  Config mutation, singleton initialization, logging setup, argparse parsing,
  background task definitions that depend on the config, and environment-dependent
  behavior are now done at function startup.

* chore: introduce retention manager

  A single long-running background task that owns the scheduling of all periodic
  server-maintenance jobs (cache cleanup, DB autosave, …).
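
  A minimal asyncio sketch of the idea (illustrative only, not the EOS
  implementation): one task owns all periodic jobs instead of one task per job.

```python
import asyncio
from typing import Callable

async def retention_manager(
    jobs: list[tuple[float, Callable[[], None]]], stop: asyncio.Event
) -> None:
    """Run each (interval_sec, job) pair cyclically until `stop` is set."""
    loop = asyncio.get_running_loop()
    next_due = [0.0] * len(jobs)
    while not stop.is_set():
        now = loop.time()
        for i, (interval, job) in enumerate(jobs):
            if now >= next_due[i]:
                job()
                next_due[i] = now + interval
        try:
            # Sleep briefly, but wake immediately if a stop is requested.
            await asyncio.wait_for(stop.wait(), timeout=0.01)
        except asyncio.TimeoutError:
            pass

calls: list[str] = []

async def main() -> None:
    stop = asyncio.Event()
    task = asyncio.create_task(
        retention_manager(
            [
                (0.02, lambda: calls.append("cache_cleanup")),
                (0.05, lambda: calls.append("db_autosave")),
            ],
            stop,
        )
    )
    await asyncio.sleep(0.1)
    stop.set()
    await task

asyncio.run(main())
```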

* chore: canonicalize timezone name for UTC

  Timezone names that are semantically identical to UTC are canonicalized to UTC.
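
  The rule can be sketched as follows (a heuristic illustration, not the actual
  EOS helper): any timezone name that resolves to a zero offset in both winter
  and summer is mapped to the canonical name "UTC".

```python
from datetime import datetime, timedelta
from zoneinfo import ZoneInfo

def canonicalize_tz(name: str) -> str:
    """Return "UTC" for timezone names whose offset is zero year-round
    (e.g. "Etc/UTC", "Zulu"); otherwise return the name unchanged."""
    tz = ZoneInfo(name)
    # Probe a winter and a summer date to exclude DST zones that merely
    # pass through offset zero.
    probes = (datetime(2025, 1, 1, tzinfo=tz), datetime(2025, 7, 1, tzinfo=tz))
    if all(d.utcoffset() == timedelta(0) for d in probes):
        return "UTC"
    return name
```

  For example, canonicalize_tz("Etc/UTC") yields "UTC", while
  canonicalize_tz("Europe/Berlin") is returned unchanged.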

* chore: extend config file migration for default value handling

  Extend the config file migration to handle None or missing values, which now
  invoke default value generation in the new config file. Also adapt the tests
  to handle this situation.

* chore: extend datetime util test cases

* chore: make version test check for untracked files

  Check for files that are not tracked by git. The version calculation will be
  wrong if these files are not committed.

* chore: bump pandas to 3.0.0

  Pandas 3.0 now infers the appropriate resolution (a.k.a. unit) for the output
  dtype, which may become datetime64[us] (previously it was always ns). Numeric
  dtype detection is also stricter, which requires a different approach to
  detecting numerics.
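
  A small version-agnostic sketch of both checks (using the public pandas dtype
  API rather than comparing dtype names):

```python
import pandas as pd

# Pandas 3.0 may infer datetime64[us] where earlier versions always produced
# datetime64[ns], so checks should be resolution-agnostic.
stamps = pd.Series(pd.to_datetime(["2025-01-01T00:00:00", "2025-01-01T01:00:00"]))
assert pd.api.types.is_datetime64_any_dtype(stamps)  # true for ns and us units

# Likewise, use the dtype API for numeric detection instead of string
# comparisons against dtype names.
numbers = pd.Series([1.0, 2.5, 3.0])
assert pd.api.types.is_numeric_dtype(numbers)
```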

* chore: bump pydantic-settings to 2.12.0

  pydantic-settings 2.12.0 behaves differently under pytest. The tests were
  adapted and a workaround was introduced. ConfigEOS was also adapted to allow
  fine-grained initialization control, making it possible to switch off certain
  settings sources, such as file settings, during tests.

* chore: remove scikit-learn from dependencies

  scikit-learn is not strictly necessary as long as we have scipy.

* chore: add documentation mode guarding for sphinx autosummary

  Sphinx autosummary executes functions. Prevent exceptions in pure documentation
  mode.

* chore: adapt docker-build CI workflow to stricter GitHub handling

Signed-off-by: Bobby Noelte <b0661n0e17e@gmail.com>
Committed by: Bobby Noelte (via GitHub), 2026-02-22 14:12:42 +01:00
Parent: 5f66591d21
Commit: 6498c7dc32
92 changed files with 12710 additions and 2173 deletions


@@ -22,7 +22,8 @@ from _pytest.logging import LogCaptureFixture
from loguru import logger
from xprocess import ProcessStarter, XProcess
from akkudoktoreos.config.config import ConfigEOS, get_config
from akkudoktoreos.config.config import ConfigEOS
from akkudoktoreos.core.coreabc import get_config, get_prediction, singletons_init
from akkudoktoreos.core.version import _version_hash, version
from akkudoktoreos.server.server import get_default_host
@@ -134,8 +135,6 @@ def is_ci() -> bool:
@pytest.fixture
def prediction_eos():
from akkudoktoreos.prediction.prediction import get_prediction
return get_prediction()
@@ -172,6 +171,37 @@ def cfg_non_existent(request):
)
# ------------------------------------
# Provide pytest EOS config management
# ------------------------------------
@pytest.fixture
def config_default_dirs(tmpdir):
"""Fixture that provides a list of directories to be used as config dir."""
tmp_user_home_dir = Path(tmpdir)
# Default config directory from platform user config directory
config_default_dir_user = tmp_user_home_dir / "config"
# Default config directory from current working directory
config_default_dir_cwd = tmp_user_home_dir / "cwd"
config_default_dir_cwd.mkdir()
# Default config directory from default config file
config_default_dir_default = Path(__file__).parent.parent.joinpath("src/akkudoktoreos/data")
# Default data directory from platform user data directory
data_default_dir_user = tmp_user_home_dir
return (
config_default_dir_user,
config_default_dir_cwd,
config_default_dir_default,
data_default_dir_user,
)
@pytest.fixture(autouse=True)
def user_cwd(config_default_dirs):
"""Patch cwd provided by module pathlib.Path.cwd."""
@@ -203,64 +233,102 @@ def user_data_dir(config_default_dirs):
@pytest.fixture
def config_eos(
def config_eos_factory(
disable_debug_logging,
user_config_dir,
user_data_dir,
user_cwd,
config_default_dirs,
monkeypatch,
) -> ConfigEOS:
"""Fixture to reset EOS config to default values."""
monkeypatch.setenv(
"EOS_CONFIG__DATA_CACHE_SUBPATH", str(config_default_dirs[-1] / "data/cache")
)
monkeypatch.setenv(
"EOS_CONFIG__DATA_OUTPUT_SUBPATH", str(config_default_dirs[-1] / "data/output")
)
config_file = config_default_dirs[0] / ConfigEOS.CONFIG_FILE_NAME
config_file_cwd = config_default_dirs[1] / ConfigEOS.CONFIG_FILE_NAME
assert not config_file.exists()
assert not config_file_cwd.exists()
):
"""Factory fixture for creating a fully initialized ``ConfigEOS`` instance.
config_eos = get_config()
config_eos.reset_settings()
assert config_file == config_eos.general.config_file_path
assert config_file.exists()
assert not config_file_cwd.exists()
Returns a callable that creates a ``ConfigEOS`` singleton with a controlled
filesystem layout and environment variables. Allows tests to customize which
pydantic-settings sources are enabled (init, env, dotenv, file, secrets).
# Check user data directory paths (config_default_dirs[-1] == data_default_dir_user)
assert config_default_dirs[-1] / "data" == config_eos.general.data_folder_path
assert config_default_dirs[-1] / "data/cache" == config_eos.cache.path()
assert config_default_dirs[-1] / "data/output" == config_eos.general.data_output_path
assert config_default_dirs[-1] / "data/output/eos.log" == config_eos.logging.file_path
return config_eos
The factory ensures:
- Required directories exist
- No pre-existing config files are present
- Settings are reloaded to respect test-specific configuration
- Dependent singletons are initialized
The singleton instance is reset during fixture teardown.
"""
def _create(init: dict[str, bool] | None = None) -> ConfigEOS:
init = init or {
"with_init_settings": True,
"with_env_settings": True,
"with_dotenv_settings": False,
"with_file_settings": False,
"with_file_secret_settings": False,
}
# reset singleton before touching env or config
ConfigEOS.reset_instance()
ConfigEOS._init_config_eos = {
"with_init_settings": True,
"with_env_settings": True,
"with_dotenv_settings": True,
"with_file_settings": True,
"with_file_secret_settings": True,
}
ConfigEOS._config_file_path = None
ConfigEOS._force_documentation_mode = False
data_folder_path = config_default_dirs[-1] / "data"
data_folder_path.mkdir(exist_ok=True)
config_dir = config_default_dirs[0]
config_dir.mkdir(exist_ok=True)
cwd = config_default_dirs[1]
cwd.mkdir(exist_ok=True)
monkeypatch.setenv("EOS_CONFIG_DIR", str(config_dir))
monkeypatch.setenv("EOS_GENERAL__DATA_FOLDER_PATH", str(data_folder_path))
monkeypatch.setenv("EOS_GENERAL__DATA_CACHE_SUBPATH", "cache")
monkeypatch.setenv("EOS_GENERAL__DATA_OUTPUT_SUBPATH", "output")
# Ensure no config files exist
config_file = config_dir / ConfigEOS.CONFIG_FILE_NAME
config_file_cwd = cwd / ConfigEOS.CONFIG_FILE_NAME
assert not config_file.exists()
assert not config_file_cwd.exists()
config_eos = get_config(init=init)
# Ensure newly created configurations are respected
# Note: Workaround for pydantic_settings and pytest
config_eos.reset_settings()
# Check user data directory paths (config_default_dirs[-1] == data_default_dir_user)
assert config_eos.general.data_folder_path == data_folder_path
assert config_eos.general.data_output_subpath == Path("output")
assert config_eos.cache.subpath == "cache"
assert config_eos.cache.path() == config_default_dirs[-1] / "data/cache"
assert config_eos.logging.file_path == config_default_dirs[-1] / "data/output/eos.log"
# Check config file path
assert str(config_eos.general.config_file_path) == str(config_file)
assert config_file.exists()
assert not config_file_cwd.exists()
# Initialize all other singletons (if not already initialized)
singletons_init()
return config_eos
yield _create
# teardown - final safety net
ConfigEOS.reset_instance()
@pytest.fixture
def config_default_dirs(tmpdir):
"""Fixture that provides a list of directories to be used as config dir."""
tmp_user_home_dir = Path(tmpdir)
# Default config directory from platform user config directory
config_default_dir_user = tmp_user_home_dir / "config"
# Default config directory from current working directory
config_default_dir_cwd = tmp_user_home_dir / "cwd"
config_default_dir_cwd.mkdir()
# Default config directory from default config file
config_default_dir_default = Path(__file__).parent.parent.joinpath("src/akkudoktoreos/data")
# Default data directory from platform user data directory
data_default_dir_user = tmp_user_home_dir
return (
config_default_dir_user,
config_default_dir_cwd,
config_default_dir_default,
data_default_dir_user,
)
def config_eos(config_eos_factory) -> ConfigEOS:
"""Fixture to reset EOS config to default values."""
config_eos = config_eos_factory()
return config_eos
# ------------------------------------
@@ -405,7 +473,11 @@ def server_base(
Yields:
dict[str, str]: A dictionary containing:
- "server" (str): URL of the server.
- "port": port
- "eosdash_server": eosdash_server
- "eosdash_port": eosdash_port
- "eos_dir" (str): Path to the temporary EOS_DIR.
- "timeout": server_timeout
"""
host = get_default_host()
port = 8503
@@ -427,12 +499,14 @@ def server_base(
eos_tmp_dir = tempfile.TemporaryDirectory()
eos_dir = str(eos_tmp_dir.name)
eos_general_data_folder_path = str(Path(eos_dir) / "data")
class Starter(ProcessStarter):
# Set environment for server run
env = os.environ.copy()
env["EOS_DIR"] = eos_dir
env["EOS_CONFIG_DIR"] = eos_dir
env["EOS_GENERAL__DATA_FOLDER_PATH"] = eos_general_data_folder_path
if extra_env:
env.update(extra_env)


@@ -12,13 +12,11 @@ from typing import Any
import numpy as np
from loguru import logger
from akkudoktoreos.config.config import get_config
from akkudoktoreos.core.ems import get_ems
from akkudoktoreos.core.coreabc import get_config, get_ems, get_prediction
from akkudoktoreos.core.emsettings import EnergyManagementMode
from akkudoktoreos.optimization.genetic.geneticparams import (
GeneticOptimizationParameters,
)
from akkudoktoreos.prediction.prediction import get_prediction
from akkudoktoreos.utils.datetimeutil import to_datetime
config_eos = get_config()


@@ -6,8 +6,7 @@ import pstats
import sys
import time
from akkudoktoreos.config.config import get_config
from akkudoktoreos.prediction.prediction import get_prediction
from akkudoktoreos.core.coreabc import get_config, get_prediction
config_eos = get_config()
prediction_eos = get_prediction()


@@ -12,11 +12,11 @@ import pytest
from akkudoktoreos.adapter.adapter import (
Adapter,
AdapterCommonSettings,
get_adapter,
)
from akkudoktoreos.adapter.adapterabc import AdapterContainer
from akkudoktoreos.adapter.homeassistant import HomeAssistantAdapter
from akkudoktoreos.adapter.nodered import NodeREDAdapter
from akkudoktoreos.core.coreabc import get_adapter
# ---------- Typed aliases for fixtures ----------
AdapterFixture: TypeAlias = Adapter


@@ -167,12 +167,21 @@ def temp_store_file():
@pytest.fixture
def cache_file_store(temp_store_file):
def cache_file_store(temp_store_file, monkeypatch):
"""A pytest fixture that creates a new CacheFileStore instance for testing."""
cache = CacheFileStore()
cache._store_file = temp_store_file
# Patch the _cache_file method to return the temp file
monkeypatch.setattr(
cache,
"_store_file",
lambda: temp_store_file,
)
cache.clear(clear_all=True)
assert len(cache._store) == 0
return cache
@@ -481,7 +490,7 @@ class TestCacheFileStore:
cache_file_store.save_store()
# Verify the file content
with cache_file_store._store_file.open("r", encoding="utf-8", newline=None) as f:
with cache_file_store._store_file().open("r", encoding="utf-8", newline=None) as f:
store_loaded = json.load(f)
assert "test_key" in store_loaded
assert store_loaded["test_key"]["cache_file"] == "cache_file_path"
@@ -501,7 +510,7 @@ class TestCacheFileStore:
"ttl_duration": None,
}
}
with cache_file_store._store_file.open("w", encoding="utf-8", newline="\n") as f:
with cache_file_store._store_file().open("w", encoding="utf-8", newline="\n") as f:
json.dump(cache_record, f, indent=4)
# Mock the open function to return a MagicMock for the cache file


@@ -109,17 +109,18 @@ def test_config_ipaddress(monkeypatch, config_eos):
assert config_eos.server.host == "localhost"
def test_singleton_behavior(config_eos, config_default_dirs):
def test_singleton_behavior(config_eos, config_default_dirs, monkeypatch):
"""Test that ConfigEOS behaves as a singleton."""
initial_cfg_file = config_eos.general.config_file_path
with patch(
"akkudoktoreos.config.config.user_config_dir", return_value=str(config_default_dirs[0])
):
instance1 = ConfigEOS()
instance2 = ConfigEOS()
assert instance1 is config_eos
config_eos.reset_instance()
monkeypatch.setenv("EOS_CONFIG_DIR", str(config_default_dirs[0]))
instance1 = ConfigEOS()
instance2 = ConfigEOS()
assert instance1 is not config_eos
assert instance1 is instance2
assert instance1.general.config_file_path == initial_cfg_file
assert instance1._config_file_path == instance2._config_file_path
def test_config_file_priority(config_default_dirs):
@@ -169,17 +170,22 @@ def test_get_config_file_path(user_config_dir_patch, config_eos, config_default_
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir_path = Path(temp_dir)
monkeypatch.setenv("EOS_DIR", str(temp_dir_path))
monkeypatch.delenv("EOS_CONFIG_DIR", raising=False)
assert config_eos._get_config_file_path() == (cfg_file(temp_dir_path), False)
monkeypatch.setenv("EOS_CONFIG_DIR", "config")
config_dir = temp_dir_path / "config"
config_dir.mkdir(exist_ok=True)
assert config_eos._get_config_file_path() == (
cfg_file(temp_dir_path / "config"),
cfg_file(config_dir),
False,
)
monkeypatch.setenv("EOS_CONFIG_DIR", str(temp_dir_path / "config2"))
config_dir = temp_dir_path / "config2"
config_dir.mkdir(exist_ok=True)
assert config_eos._get_config_file_path() == (
cfg_file(temp_dir_path / "config2"),
cfg_file(config_dir),
False,
)
@@ -188,8 +194,10 @@ def test_get_config_file_path(user_config_dir_patch, config_eos, config_default_
assert config_eos._get_config_file_path() == (cfg_file(config_default_dir_user), False)
monkeypatch.setenv("EOS_CONFIG_DIR", str(temp_dir_path / "config3"))
config_dir = temp_dir_path / "config3"
config_dir.mkdir(exist_ok=True)
assert config_eos._get_config_file_path() == (
cfg_file(temp_dir_path / "config3"),
cfg_file(config_dir),
False,
)
@@ -199,7 +207,7 @@ def test_config_copy(config_eos, monkeypatch):
with tempfile.TemporaryDirectory() as temp_dir:
temp_folder_path = Path(temp_dir)
temp_config_file_path = temp_folder_path.joinpath(config_eos.CONFIG_FILE_NAME).resolve()
monkeypatch.setenv(config_eos.EOS_DIR, str(temp_folder_path))
monkeypatch.setenv("EOS_CONFIG_DIR", str(temp_folder_path))
assert not temp_config_file_path.exists()
with patch("akkudoktoreos.config.config.user_config_dir", return_value=temp_dir):
assert config_eos._get_config_file_path() == (temp_config_file_path, False)


@@ -26,6 +26,9 @@ MIGRATION_PAIRS = [
# (DIR_TESTDATA / "old_config_X.json", DIR_TESTDATA / "expected_config_X.json"),
]
# Any sentinel in expected data
_ANY_SENTINEL = "__ANY__"
def _dict_contains(superset: Any, subset: Any, path="") -> list[str]:
"""Recursively verify that all key-value pairs from a subset dictionary or list exist in a superset.
@@ -60,6 +63,9 @@ def _dict_contains(superset: Any, subset: Any, path="") -> list[str]:
errors.extend(_dict_contains(superset[i], elem, f"{path}[{i}]" if path else f"[{i}]"))
else:
# "__ANY__" in expected means "accept whatever value the migration produces"
if subset == _ANY_SENTINEL:
return errors
# Compare values (with numeric tolerance)
if isinstance(subset, (int, float)) and isinstance(superset, (int, float)):
if abs(float(subset) - float(superset)) > 1e-6:
@@ -162,6 +168,7 @@ class TestConfigMigration:
assert backup_file.exists(), f"Backup file not created for {old_file.name}"
# --- Compare migrated result with expected output ---
old_data = json.loads(old_file.read_text(encoding="utf-8"))
new_data = json.loads(working_file.read_text(encoding="utf-8"))
expected_data = json.loads(expected_file.read_text(encoding="utf-8"))
@@ -202,6 +209,14 @@ class TestConfigMigration:
# Verify the migrated value matches the expected one
new_value = configmigrate._get_json_nested_value(new_data, new_path)
if new_value != expected_value:
# Check if this mapping uses _KEEP_DEFAULT and the old value was None/missing
old_value = configmigrate._get_json_nested_value(old_data, old_path)
keep_default = (
isinstance(mapping, tuple)
and configmigrate._KEEP_DEFAULT in mapping
)
if keep_default and old_value is None:
continue # acceptable: old was None, new model keeps its default
mismatched_values.append(
f"{old_path}{new_path}: expected {expected_value!r}, got {new_value!r}"
)

File diff suppressed because it is too large

tests/test_dataabccompact.py (new file, 1114 lines)

File diff suppressed because it is too large

tests/test_database.py (new file, 1148 lines)

File diff suppressed because it is too large

tests/test_databaseabc.py (new file, 888 lines)

@@ -0,0 +1,888 @@
from typing import Any, Iterator, Literal, Optional, Type, cast
import pytest
from numpydantic import NDArray, Shape
from pydantic import BaseModel, Field
from akkudoktoreos.core.databaseabc import (
DATABASE_METADATA_KEY,
DatabaseRecordProtocolMixin,
DatabaseTimestamp,
_DatabaseTimestampUnbound,
)
from akkudoktoreos.utils.datetimeutil import (
DateTime,
Duration,
to_datetime,
to_duration,
)
# ---------------------------------------------------------------------------
# Test record
# ---------------------------------------------------------------------------
class SampleRecord(BaseModel):
date_time: Optional[DateTime] = Field(
default=None, json_schema_extra={"description": "DateTime"}
)
value: Optional[float] = None
def __getitem__(self, key: str) -> Any:
if key == "date_time":
return self.date_time
if key == "value":
return self.value
assert key is None
return None
# ---------------------------------------------------------------------------
# Fake database backend
# ---------------------------------------------------------------------------
class SampleDatabase:
def __init__(self):
self._data: dict[Optional[str], dict[bytes, bytes]] = {}
self._metadata: Optional[bytes] = None
self.is_open = True
self.compression = False
self.compression_level = 0
self.storage_path = "/fake"
# serialization (pass-through)
def serialize_data(self, data: bytes) -> bytes:
return data
def deserialize_data(self, data: bytes) -> bytes:
return data
# metadata
def set_metadata(self, metadata: Optional[bytes], *, namespace: Optional[str] = None) -> None:
self._metadata = metadata
def get_metadata(self, namespace: Optional[str] = None) -> Optional[bytes]:
return self._metadata
# write
def save_records(
self, records: list[tuple[bytes, bytes]], namespace: Optional[str] = None
) -> int:
ns = self._data.setdefault(namespace, {})
saved = 0
for key, value in records:
ns[key] = value
saved += 1
return saved
def delete_records(
self, keys: Iterator[bytes], namespace: Optional[str] = None
) -> int:
ns_data = self._data.get(namespace, {})
deleted = 0
for key in keys:
if key in ns_data:
del ns_data[key]
deleted += 1
return deleted
# read
def iterate_records(
self,
start_key: Optional[bytes] = None,
end_key: Optional[bytes] = None,
namespace: Optional[str] = None,
reverse: bool = False,
) -> Iterator[tuple[bytes, bytes]]:
items = self._data.get(namespace, {})
keys = sorted(items, reverse=reverse)
for k in keys:
if k == DATABASE_METADATA_KEY:
continue
if start_key and k < start_key:
continue
if end_key and k >= end_key:
continue
yield k, items[k]
# stats
def count_records(
self,
start_key: Optional[bytes] = None,
end_key: Optional[bytes] = None,
*,
namespace: Optional[str] = None,
) -> int:
items = self._data.get(namespace, {})
count = 0
for k in items:
if k == DATABASE_METADATA_KEY:
continue
if start_key and k < start_key:
continue
if end_key and k >= end_key:
continue
count += 1
return count
def get_key_range(
self, namespace: Optional[str] = None
) -> tuple[Optional[bytes], Optional[bytes]]:
items = self._data.get(namespace, {})
keys = sorted(k for k in items if k != DATABASE_METADATA_KEY)
if not keys:
return None, None
return keys[0], keys[-1]
def get_backend_stats(self, namespace: Optional[str] = None) -> dict:
return {}
def flush(self, namespace: Optional[str] = None) -> None:
pass
# ---------------------------------------------------------------------------
# Concrete test sequence — minimal, no Pydantic / singleton overhead
# ---------------------------------------------------------------------------
class SampleSequence(DatabaseRecordProtocolMixin[SampleRecord]):
"""Minimal concrete implementation for unit-testing the mixin."""
def __init__(self):
self.records: list[SampleRecord] = []
self._db_record_index: dict[DatabaseTimestamp, SampleRecord] = {}
self._db_sorted_timestamps: list[DatabaseTimestamp] = []
self._db_dirty_timestamps: set[DatabaseTimestamp] = set()
self._db_new_timestamps: set[DatabaseTimestamp] = set()
self._db_deleted_timestamps: set[DatabaseTimestamp] = set()
self._db_initialized: bool = True
self._db_storage_initialized: bool = False
self._db_metadata: Optional[dict] = None
self._db_loaded_range = None
from akkudoktoreos.core.databaseabc import DatabaseRecordProtocolLoadPhase
self._db_load_phase = DatabaseRecordProtocolLoadPhase.NONE
self._db_version: int = 1
self.database = SampleDatabase()
self.config = type(
"Cfg",
(),
{
"database": type(
"DBCfg",
(),
{
"auto_save": False,
"compression_level": 0,
"autosave_interval_sec": 10,
"initial_load_window_h": None,
"keep_duration_h": None,
},
)()
},
)()
@classmethod
def record_class(cls) -> Type[SampleRecord]:
return SampleRecord
def db_namespace(self) -> str:
return "test"
@property
def record_keys_writable(self) -> list[str]:
"""Return writable field names of SampleRecord.
Required by _db_compact_tier which iterates record_keys_writable
to decide which fields to resample. Must match exactly what
key_to_array accepts — only 'value' here, not 'date_time'.
"""
return ["value"]
# Override key_to_array for the mixin tests — the full DataSequence
# implementation lives in dataabc.py; here we provide a minimal version
# that resamples the single `value` field to demonstrate compaction.
def key_to_array(
self,
key: str,
start_datetime: Optional[DateTime] = None,
end_datetime: Optional[DateTime] = None,
interval: Optional[Duration] = None,
fill_method: Optional[str] = None,
dropna: Optional[bool] = True,
boundary: Literal["strict", "context"] = "context",
align_to_interval: bool = False,
) -> NDArray[Shape["*"], Any]:
import numpy as np
import pandas as pd
if interval is None:
interval = to_duration("1 hour")
dates = []
values = []
for record in self.records:
if record.date_time is None:
continue
ts = DatabaseTimestamp.from_datetime(record.date_time)
if start_datetime and DatabaseTimestamp.from_datetime(start_datetime) > ts:
continue
if end_datetime and DatabaseTimestamp.from_datetime(end_datetime) <= ts:
continue
dates.append(record.date_time)
values.append(getattr(record, key, None))
if not dates:
return np.array([])
index = pd.to_datetime(dates, utc=True)
series = pd.Series(values, index=index, dtype=float)
freq = f"{int(interval.total_seconds())}s"
origin = start_datetime if start_datetime else "start_day"
resampled = series.resample(freq, origin=origin).mean().interpolate("time")
if start_datetime is not None:
resampled = resampled.truncate(before=start_datetime)
if end_datetime is not None:
resampled = resampled.truncate(after=end_datetime)
return resampled.values
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _insert_records_every_n_minutes(
seq: SampleSequence,
base: DateTime,
count: int,
interval_minutes: int,
value_fn=None,
) -> None:
"""Insert `count` records spaced `interval_minutes` apart starting at `base`."""
for i in range(count):
dt = base.add(minutes=i * interval_minutes)
value = value_fn(i) if value_fn else float(i)
seq.db_insert_record(SampleRecord(date_time=dt, value=value))
seq.db_save_records()
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def seq():
return SampleSequence()
@pytest.fixture
def seq_with_15min_data():
"""Sequence with 15-min records spanning 4 weeks, so both tiers have data."""
s = SampleSequence()
now = to_datetime().in_timezone("UTC")
# 4 weeks × 7 days × 24 h × 4 records/h = 2688 records
base = now.subtract(weeks=4)
_insert_records_every_n_minutes(s, base, count=2688, interval_minutes=15)
return s, now
@pytest.fixture
def seq_sparse():
"""Sequence with only 3 records spread over 4 weeks — sparse, no compaction benefit."""
s = SampleSequence()
now = to_datetime().in_timezone("UTC")
base = now.subtract(weeks=4)
for offset_days in [0, 14, 27]:
dt = base.add(days=offset_days)
s.db_insert_record(SampleRecord(date_time=dt, value=float(offset_days)))
s.db_save_records()
return s, now
# ---------------------------------------------------------------------------
# Existing tests (unchanged)
# ---------------------------------------------------------------------------
class TestDatabaseRecordProtocolMixin:
@pytest.mark.parametrize(
"start_str, value_count, interval_seconds",
[
("2024-11-10 00:00:00", 24, 3600),
("2024-08-10 00:00:00", 24, 3600),
("2024-03-31 00:00:00", 24, 3600),
("2024-10-27 00:00:00", 24, 3600),
],
)
def test_db_generate_timestamps_utc_spacing(
self, seq, start_str, value_count, interval_seconds
):
start_dt = to_datetime(start_str, in_timezone="Europe/Berlin")
assert start_dt.tz.name == "Europe/Berlin"
db_start = DatabaseTimestamp.from_datetime(start_dt)
generated = list(seq.db_generate_timestamps(db_start, value_count))
assert len(generated) == value_count
for db_dt in generated:
dt = DatabaseTimestamp.to_datetime(db_dt)
assert dt.tz.name == "UTC"
assert len(generated) == len(set(generated)), "Duplicate UTC datetimes found"
for i in range(1, len(generated)):
last_dt = DatabaseTimestamp.to_datetime(generated[i - 1])
current_dt = DatabaseTimestamp.to_datetime(generated[i])
delta = (current_dt - last_dt).total_seconds()
assert delta == interval_seconds, f"Spacing mismatch at index {i}: {delta}s"
def test_insert_and_memory_range(self, seq):
t0 = to_datetime()
t1 = t0.add(hours=1)
seq.db_insert_record(SampleRecord(date_time=t0, value=1))
seq.db_insert_record(SampleRecord(date_time=t1, value=2))
assert seq.records[0].date_time == t0
assert seq.records[-1].date_time == t1
assert len(seq.records) == 2
def test_roundtrip_reload(self):
seq = SampleSequence()
t0 = to_datetime()
t1 = t0.add(hours=1)
seq.db_insert_record(SampleRecord(date_time=t0, value=1))
seq.db_insert_record(SampleRecord(date_time=t1, value=2))
assert seq.db_save_records() == 2
db = seq.database
seq2 = SampleSequence()
seq2.database = db
loaded = seq2.db_load_records()
assert loaded == 2
assert len(seq2.records) == 2
def test_db_count_records(self, seq):
t0 = to_datetime()
seq.db_insert_record(SampleRecord(date_time=t0, value=1))
assert seq.db_count_records() == 1
seq.db_save_records()
assert seq.db_count_records() == 1
def test_delete_range(self, seq):
base = to_datetime()
for i in range(5):
seq.db_insert_record(SampleRecord(date_time=base.add(minutes=i), value=i))
db_start = DatabaseTimestamp.from_datetime(base.add(minutes=1))
db_end = DatabaseTimestamp.from_datetime(base.add(minutes=4))
deleted = seq.db_delete_records(start_timestamp=db_start, end_timestamp=db_end)
assert deleted == 3
assert [r.value for r in seq.records] == [0, 4]
def test_db_count_records_memory_only_multiple(self):
seq = SampleSequence()
base = to_datetime()
for i in range(3):
seq.db_insert_record(SampleRecord(date_time=base.add(minutes=i), value=i))
assert seq.db_count_records() == 3
def test_db_count_records_memory_newer_than_db(self):
seq = SampleSequence()
base = to_datetime()
seq.db_insert_record(SampleRecord(date_time=base, value=1))
seq.db_save_records()
seq.db_insert_record(SampleRecord(date_time=base.add(hours=1), value=2))
seq.db_insert_record(SampleRecord(date_time=base.add(hours=2), value=3))
assert seq.db_count_records() == 3
def test_db_count_records_memory_older_than_db(self):
seq = SampleSequence()
base = to_datetime()
seq.db_insert_record(SampleRecord(date_time=base.add(hours=1), value=2))
seq.db_save_records()
seq.db_insert_record(SampleRecord(date_time=base, value=1))
assert seq.db_count_records() == 2
def test_db_count_records_empty_everywhere(self):
seq = SampleSequence()
assert seq.db_count_records() == 0
def test_metadata_not_counted(self, seq):
seq.database._data.setdefault("test", {})[DATABASE_METADATA_KEY] = b"meta"
assert seq.db_count_records() == 0
def test_key_range_excludes_metadata(self, seq):
ns = seq.db_namespace()
seq.database._data.setdefault(ns, {})[DATABASE_METADATA_KEY] = b"meta"
assert seq.database.get_key_range(ns) == (None, None)
# ---------------------------------------------------------------------------
# Compaction tests
# ---------------------------------------------------------------------------
class TestCompactTiers:
"""Tests for db_compact_tiers() and the tier hook."""
def test_default_tiers_returns_two_entries(self, seq):
tiers = seq.db_compact_tiers()
assert len(tiers) == 2
def test_default_tiers_ordered_shortest_first(self, seq):
tiers = seq.db_compact_tiers()
ages = [t[0].total_seconds() for t in tiers]
assert ages == sorted(ages), "Tiers must be ordered shortest age first"
def test_default_tiers_first_is_2h_to_15min(self, seq):
tiers = seq.db_compact_tiers()
age_sec, interval_sec = (
tiers[0][0].total_seconds(),
tiers[0][1].total_seconds(),
)
assert age_sec == 2 * 3600
assert interval_sec == 15 * 60
def test_default_tiers_second_is_2weeks_to_1h(self, seq):
tiers = seq.db_compact_tiers()
age_sec, interval_sec = (
tiers[1][0].total_seconds(),
tiers[1][1].total_seconds(),
)
assert age_sec == 14 * 24 * 3600
assert interval_sec == 3600
def test_override_tiers(self):
class CustomSeq(SampleSequence):
def db_compact_tiers(self):
return [(to_duration("7 days"), to_duration("1 hour"))]
s = CustomSeq()
tiers = s.db_compact_tiers()
assert len(tiers) == 1
assert tiers[0][1].total_seconds() == 3600
def test_empty_tiers_disables_compaction(self):
class NoCompactSeq(SampleSequence):
def db_compact_tiers(self):
return []
s = NoCompactSeq()
now = to_datetime().in_timezone("UTC")
base = now.subtract(weeks=4)
_insert_records_every_n_minutes(s, base, count=100, interval_minutes=15)
deleted = s.db_compact()
assert deleted == 0
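The tier table these tests describe (2 h → 15 min, then 2 weeks → 1 h, ordered shortest age first, empty list disables compaction) can be sketched as plain `(age, interval)` pairs. The helper below is an assumption of this example, not the project's `db_compact_tiers()` hook; it only illustrates how an age maps to a target interval:

```python
from datetime import timedelta

# Stand-in for the default tiers: (age threshold, target interval),
# ordered shortest age first. An empty list disables compaction entirely.
DEFAULT_TIERS = [
    (timedelta(hours=2), timedelta(minutes=15)),
    (timedelta(weeks=2), timedelta(hours=1)),
]

def target_interval(record_age, tiers=DEFAULT_TIERS):
    """Return the coarsest interval whose age threshold the record exceeds."""
    chosen = None
    for age, interval in tiers:  # tiers are ordered shortest age first
        if record_age >= age:
            chosen = interval  # later (older) tiers win
    return chosen
```

A record one hour old maps to no tier, a five-hour-old record to the 15-minute tier, and a three-week-old record to the 1-hour tier.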
class TestCompactState:
"""Tests for _db_get_compact_state / _db_set_compact_state."""
def test_get_state_returns_none_when_no_metadata(self, seq):
interval = to_duration("1 hour")
assert seq._db_get_compact_state(interval) is None
def test_set_and_get_state_roundtrip(self, seq):
interval = to_duration("1 hour")
now = to_datetime().in_timezone("UTC")
ts = DatabaseTimestamp.from_datetime(now)
seq._db_set_compact_state(interval, ts)
retrieved = seq._db_get_compact_state(interval)
assert retrieved == ts
def test_state_is_per_tier(self, seq):
"""Different tier intervals must not overwrite each other."""
interval_15min = to_duration("15 minutes")
interval_1h = to_duration("1 hour")
now = to_datetime().in_timezone("UTC")
ts_15 = DatabaseTimestamp.from_datetime(now)
ts_1h = DatabaseTimestamp.from_datetime(now.subtract(days=1))
seq._db_set_compact_state(interval_15min, ts_15)
seq._db_set_compact_state(interval_1h, ts_1h)
assert seq._db_get_compact_state(interval_15min) == ts_15
assert seq._db_get_compact_state(interval_1h) == ts_1h
def test_state_persists_in_metadata(self, seq):
"""State must survive a metadata reload."""
interval = to_duration("1 hour")
now = to_datetime().in_timezone("UTC")
ts = DatabaseTimestamp.from_datetime(now)
seq._db_set_compact_state(interval, ts)
# Reload metadata from fake DB
seq2 = SampleSequence()
seq2.database = seq.database
seq2._db_metadata = seq2._db_load_metadata()
assert seq2._db_get_compact_state(interval) == ts
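The state tests above rely on the compaction cutoff being stored per tier interval, so tiers never clobber each other's bookkeeping. A minimal sketch of that idea, with a plain dict standing in for the persisted metadata store (all names here are illustrative, not the project's `_db_get_compact_state` / `_db_set_compact_state` API):

```python
from datetime import timedelta

class CompactState:
    """Per-tier cutoff bookkeeping, keyed by the tier's interval length."""

    def __init__(self):
        self._meta = {}  # stands in for metadata persisted with the records

    @staticmethod
    def _key(interval):
        # One metadata key per tier interval, so tiers stay independent.
        return f"compact_cutoff_{int(interval.total_seconds())}"

    def get(self, interval):
        return self._meta.get(self._key(interval))

    def set(self, interval, cutoff_epoch):
        self._meta[self._key(interval)] = cutoff_epoch
```

Because the key embeds the interval length, setting the 15-minute tier's cutoff cannot overwrite the 1-hour tier's, mirroring `test_state_is_per_tier`.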
class TestCompactSparseGuard:
"""The inflation guard must skip compaction when records are already sparse."""
def test_sparse_data_aligns_but_does_not_reduce_cardinality(self, seq_sparse):
"""Sparse data must be aligned to the target interval for all records that were modified."""
seq, _ = seq_sparse
interval = to_duration("15 minutes")
interval_sec = int(interval.total_seconds())
# Snapshot original timestamps
before_epochs = {
int(r.date_time.timestamp())
for r in seq.records
}
seq._db_compact_tier(
to_duration("30 minutes"),
interval,
)
after_epochs = {
int(r.date_time.timestamp())
for r in seq.records
}
# Cardinality must not increase
assert len(after_epochs) <= len(before_epochs)
# Any timestamp that changed must now be aligned
changed_epochs = after_epochs - before_epochs
for epoch in changed_epochs:
assert epoch % interval_sec == 0
def test_sparse_guard_advances_cutoff(self, seq_sparse):
"""Even when skipped, the cutoff should be stored so the next run skips the same window."""

seq, _ = seq_sparse
interval_1h = to_duration("1 hour")
interval_15min = to_duration("15 minutes")
seq.db_compact()
# Both tiers should have stored a cutoff even though nothing was deleted
assert seq._db_get_compact_state(interval_1h) is not None
assert seq._db_get_compact_state(interval_15min) is not None
def test_exactly_at_boundary_remains_stable(self, seq):
now = to_datetime().in_timezone("UTC")
interval = to_duration("1 hour")
raw_base = now.subtract(hours=5).set(minute=0, second=0, microsecond=0)
base = raw_base.subtract(seconds=int(raw_base.timestamp()) % 3600)
for i in range(4):
seq.db_insert_record(
SampleRecord(
date_time=base.add(hours=i),
value=float(i),
)
)
seq.db_insert_record(
SampleRecord(date_time=now.subtract(seconds=1), value=0.0)
)
seq.db_save_records()
before = [
(int(r.date_time.timestamp()), r.value)
for r in seq.records
]
seq._db_compact_tier(
to_duration("30 minutes"),
interval,
)
after = [
(int(r.date_time.timestamp()), r.value)
for r in seq.records
]
assert before == after
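The guard behavior tested in this class boils down to bucketing by floored epoch: every timestamp is floored to the target interval, and since several records can share a bucket but no record can occupy two, cardinality can only shrink or stay equal. A sketch under the assumption that records are plain `(epoch_seconds, value)` tuples with mean aggregation per bucket (the real implementation resamples via the sequence's `key_to_array`):

```python
def compact_buckets(records, interval_sec):
    """Floor each timestamp to the interval boundary and average per bucket.

    Cardinality never increases: distinct buckets <= distinct input epochs,
    and already-sparse data passes through with the same record count.
    """
    buckets = {}
    for epoch, value in records:
        aligned = (epoch // interval_sec) * interval_sec  # floor to boundary
        buckets.setdefault(aligned, []).append(value)
    return sorted(
        (epoch, sum(vals) / len(vals)) for epoch, vals in buckets.items()
    )
```

Dense 1-minute data collapses into a handful of aligned buckets, while records already spaced wider than the interval each keep their own bucket.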
class TestCompactTierWorker:
"""Unit tests for _db_compact_tier directly."""
def test_empty_sequence_returns_zero(self, seq):
age = to_duration("2 hours")
interval = to_duration("15 minutes")
assert seq._db_compact_tier(age, interval) == 0
def test_all_records_too_recent_skipped(self):
"""Records within the age threshold must not be touched."""
seq = SampleSequence()
now = to_datetime().in_timezone("UTC")
# Insert 10 records from 30 minutes ago — all within 2h threshold
base = now.subtract(minutes=30)
_insert_records_every_n_minutes(seq, base, count=10, interval_minutes=1)
before = seq.db_count_records()
deleted = seq._db_compact_tier(to_duration("2 hours"), to_duration("15 minutes"))
assert deleted == 0
assert seq.db_count_records() == before
def test_compaction_reduces_record_count(self):
"""Dense 1-min records older than 2h should be downsampled to 15-min."""
seq = SampleSequence()
now = to_datetime().in_timezone("UTC")
# Insert 1-min records for 6 hours ending 3 hours ago
base = now.subtract(hours=9)
_insert_records_every_n_minutes(seq, base, count=6 * 60, interval_minutes=1)
before = seq.db_count_records()
deleted = seq._db_compact_tier(to_duration("2 hours"), to_duration("15 minutes"))
after = seq.db_count_records()
assert deleted > 0
assert after < before
def test_records_within_threshold_preserved(self):
"""Records newer than age_threshold must remain untouched after compaction."""
seq = SampleSequence()
now = to_datetime().in_timezone("UTC")
# Old dense records (will be compacted)
old_base = now.subtract(hours=6)
_insert_records_every_n_minutes(seq, old_base, count=4 * 60, interval_minutes=1)
# Recent records (must not be touched) — insert 5 records in the last hour
recent_base = now.subtract(minutes=50)
_insert_records_every_n_minutes(seq, recent_base, count=5, interval_minutes=10)
recent_before = [
r for r in seq.records
if r.date_time and r.date_time >= recent_base
]
seq._db_compact_tier(to_duration("2 hours"), to_duration("15 minutes"))
recent_after = [
r for r in seq.records
if r.date_time and r.date_time >= recent_base
]
assert len(recent_after) == len(recent_before)
def test_incremental_cutoff_prevents_recompaction(self):
"""Running compaction twice must not re-compact already-compacted data."""
seq = SampleSequence()
now = to_datetime().in_timezone("UTC")
base = now.subtract(hours=8)
_insert_records_every_n_minutes(seq, base, count=5 * 60, interval_minutes=1)
age = to_duration("2 hours")
interval = to_duration("15 minutes")
deleted_first = seq._db_compact_tier(age, interval)
count_after_first = seq.db_count_records()
deleted_second = seq._db_compact_tier(age, interval)
count_after_second = seq.db_count_records()
assert deleted_first > 0
assert deleted_second == 0, "Second run must be a no-op"
assert count_after_first == count_after_second
def test_cutoff_stored_after_compaction(self):
"""Cutoff timestamp must be persisted after a successful compaction run."""
seq = SampleSequence()
now = to_datetime().in_timezone("UTC")
base = now.subtract(hours=8)
_insert_records_every_n_minutes(seq, base, count=5 * 60, interval_minutes=1)
interval = to_duration("15 minutes")
seq._db_compact_tier(to_duration("2 hours"), interval)
assert seq._db_get_compact_state(interval) is not None
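The incremental behavior verified above (a second run is a no-op) can be sketched as a worker that only processes records between the stored cutoff and `now - age`, then advances the cutoff. All names below are illustrative; the real `_db_compact_tier` works against the database and resampler rather than an in-memory list:

```python
def compact_tier(records, state, interval_sec, age_sec, now_epoch):
    """Downsample records older than ``age_sec``, resuming from the cutoff.

    ``records`` is a sorted list of (epoch, value); ``state`` is a dict
    holding the per-tier cutoff. Returns (new_records, deleted_count).
    """
    window_end = now_epoch - age_sec
    window_start = state.get(interval_sec, float("-inf"))
    in_window = [(e, v) for e, v in records if window_start <= e < window_end]
    keep = [(e, v) for e, v in records if not (window_start <= e < window_end)]
    buckets = {}
    for epoch, value in in_window:
        buckets.setdefault((epoch // interval_sec) * interval_sec, []).append(value)
    compacted = [(e, sum(v) / len(v)) for e, v in buckets.items()]
    state[interval_sec] = window_end  # next run resumes after this window
    new_records = sorted(keep + compacted)
    return new_records, len(records) - len(new_records)
```

Because the cutoff advances to `window_end`, a second call over unchanged data finds an empty window and deletes nothing, as `test_incremental_cutoff_prevents_recompaction` requires.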
class TestDbCompact:
"""Integration tests for the public db_compact() entry point."""
def test_compact_dense_data_both_tiers(self, seq_with_15min_data):
"""4 weeks of 15-min data should be reduced by both tiers."""
seq, _ = seq_with_15min_data
before = seq.db_count_records()
total_deleted = seq.db_compact()
after = seq.db_count_records()
assert total_deleted > 0
assert after < before
def test_compact_coarsest_tier_runs_first(self, seq_with_15min_data):
"""The 1-hour tier (coarsest) must run before the 15-min tier.
If coarsest ran last it would re-compact records the 15-min tier
had already downsampled — verified by checking that the 1-hour
cutoff is not later than the 15-min cutoff.
"""
seq, _ = seq_with_15min_data
seq.db_compact()
cutoff_1h = seq._db_get_compact_state(to_duration("1 hour"))
cutoff_15min = seq._db_get_compact_state(to_duration("15 minutes"))
assert cutoff_1h is not None
assert cutoff_15min is not None
# The 1-h tier covers older data, so its cutoff must not be later than the 15-min tier's
assert cutoff_1h <= cutoff_15min
def test_compact_idempotent(self, seq_with_15min_data):
"""Running db_compact twice must not change record count."""
seq, _ = seq_with_15min_data
seq.db_compact()
after_first = seq.db_count_records()
seq.db_compact()
after_second = seq.db_count_records()
assert after_first == after_second
def test_compact_empty_sequence_returns_zero(self, seq):
assert seq.db_compact() == 0
def test_compact_with_override_tiers(self):
"""Passing compact_tiers directly must override db_compact_tiers()."""
seq = SampleSequence()
now = to_datetime().in_timezone("UTC")
base = now.subtract(weeks=3)
_insert_records_every_n_minutes(seq, base, count=3 * 7 * 24 * 4, interval_minutes=15)
before = seq.db_count_records()
deleted = seq.db_compact(
compact_tiers=[(to_duration("1 day"), to_duration("1 hour"))]
)
assert deleted > 0
assert seq.db_count_records() < before
def test_compact_only_processes_new_window_on_second_call(self):
"""Second call processes only the new window, not the full history."""
seq = SampleSequence()
now = to_datetime().in_timezone("UTC")
base = now.subtract(weeks=3)
# Dense 1-min data for 3 weeks
_insert_records_every_n_minutes(seq, base, count=3 * 7 * 24 * 60, interval_minutes=1)
seq.db_compact()
count_after_first = seq.db_count_records()
# Add one more day of dense data in the past (simulate new old data arriving)
extra_base = now.subtract(weeks=3).subtract(days=1)
_insert_records_every_n_minutes(seq, extra_base, count=24 * 60, interval_minutes=1)
seq.db_compact()
count_after_second = seq.db_count_records()
# The newly added old data lies behind the cutoff stored by the first run,
# so the incremental second compact may or may not re-process it; at minimum
# the call must leave the store in a consistent state.
assert count_after_second >= 0  # basic sanity
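As the docstring in `test_compact_coarsest_tier_runs_first` explains, `db_compact` must apply the coarsest (largest age threshold) tier before finer ones; otherwise a fine tier would first downsample data that the coarse tier then re-compacts. The ordering step alone can be sketched like this, assuming tiers are `(age_seconds, interval_seconds)` pairs (an assumption of this example, not the project's duration types):

```python
def order_for_compaction(tiers):
    """Return tiers ordered coarsest first, i.e. by age threshold descending.

    db_compact_tiers() stores tiers shortest age first, so the compaction
    driver iterates the reversed order before each tier advances its cutoff.
    """
    return sorted(tiers, key=lambda tier: tier[0], reverse=True)
```

For the default table this runs the 2-week/1-hour tier before the 2-hour/15-minute tier, which is exactly what the cutoff comparison in the test checks.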
class TestCompactDataIntegrity:
"""Verify value integrity is preserved after compaction."""
def test_constant_value_preserved(self):
"""Constant value field must survive mean-resampling unchanged."""
seq = SampleSequence()
now = to_datetime().in_timezone("UTC")
base = now.subtract(hours=6)
# All values = 42.0
_insert_records_every_n_minutes(
seq, base, count=6 * 60, interval_minutes=1, value_fn=lambda _: 42.0
)
seq._db_compact_tier(to_duration("2 hours"), to_duration("15 minutes"))
for record in seq.records:
if record.date_time and record.date_time < now.subtract(hours=2):
assert record.value == pytest.approx(42.0, abs=1e-6)
def test_recent_records_not_modified(self):
"""Records newer than the age threshold must have unchanged values."""
seq = SampleSequence()
now = to_datetime().in_timezone("UTC")
old_base = now.subtract(hours=6)
_insert_records_every_n_minutes(seq, old_base, count=3 * 60, interval_minutes=1)
# Known recent values
recent_base = now.subtract(minutes=30)
expected = {i * 10: float(100 + i) for i in range(3)}
for offset, val in expected.items():
dt = recent_base.add(minutes=offset)
seq.db_insert_record(SampleRecord(date_time=dt, value=val))
seq.db_save_records()
seq._db_compact_tier(to_duration("2 hours"), to_duration("15 minutes"))
for record in seq.records:
if record.date_time and record.date_time >= recent_base:
offset = int((record.date_time - recent_base).total_seconds() / 60)
if offset in expected:
assert record.value == pytest.approx(expected[offset], abs=1e-6)
def test_compacted_timestamps_spacing(self):
"""Resampled records must be fewer than original and span the compaction window.
Exact per-bucket spacing depends on the full DataSequence.key_to_array
implementation (pandas resampling). The stub key_to_array in SampleSequence
only guarantees a reduction in count — uniform spacing is verified in
test_dataabc_compact.py against the real implementation.
"""
seq = SampleSequence()
now = to_datetime().in_timezone("UTC")
base = now.subtract(hours=6)
_insert_records_every_n_minutes(seq, base, count=5 * 60, interval_minutes=1)
before = seq.db_count_records()
seq._db_compact_tier(to_duration("2 hours"), to_duration("15 minutes"))
cutoff = now.subtract(hours=2)
compacted = sorted(
[r for r in seq.records if r.date_time and r.date_time < cutoff],
key=lambda r: cast(DateTime, r.date_time),
)
# Must have produced fewer records than the original 1-min data
assert len(compacted) > 0, "Expected at least one compacted record"
assert len(compacted) < before, "Compaction must reduce record count"
# Window start is floored to interval boundary
interval_sec = 15 * 60
expected_window_start = DateTime.fromtimestamp(
(int(base.timestamp()) // interval_sec) * interval_sec,
tz="UTC",
)
assert compacted[0].date_time >= expected_window_start
# Last compacted record must be before the cutoff
assert compacted[-1].date_time < cutoff

View File

@@ -1460,6 +1460,17 @@ class TestTimeWindowSequence:
# - without local timezone as UTC
(
"TC014",
"UTC",
"2024-01-03",
None,
"UTC",
None,
False,
pendulum.datetime(2024, 1, 3, 0, 0, 0, tz="UTC"),
False,
),
(
"TC015",
"Atlantic/Canary",
"02/02/24",
None,
@@ -1470,7 +1481,7 @@ class TestTimeWindowSequence:
False,
),
(
"TC015",
"TC016",
"Atlantic/Canary",
"2024-03-03T10:20:30.000Z", # No daylight saving time at this date
None,
@@ -1484,7 +1495,7 @@ class TestTimeWindowSequence:
# from pendulum.datetime to pendulum.datetime object
# ---------------------------------------
(
"TC016",
"TC017",
"Atlantic/Canary",
pendulum.datetime(2024, 4, 4, 0, 0, 0),
None,
@@ -1495,7 +1506,7 @@ class TestTimeWindowSequence:
False,
),
(
"TC017",
"TC018",
"Atlantic/Canary",
pendulum.datetime(2024, 4, 4, 1, 0, 0),
None,
@@ -1506,7 +1517,7 @@ class TestTimeWindowSequence:
False,
),
(
"TC018",
"TC019",
"Atlantic/Canary",
pendulum.datetime(2024, 4, 4, 1, 0, 0, tz="Etc/UTC"),
None,
@@ -1517,7 +1528,7 @@ class TestTimeWindowSequence:
False,
),
(
"TC019",
"TC020",
"Atlantic/Canary",
pendulum.datetime(2024, 4, 4, 2, 0, 0, tz="Europe/Berlin"),
None,
@@ -1533,7 +1544,7 @@ class TestTimeWindowSequence:
# - no timezone
# local timezone UTC
(
"TC020",
"TC021",
"Etc/UTC",
"2023-11-06T00:00:00",
"UTC",
@@ -1545,7 +1556,7 @@ class TestTimeWindowSequence:
),
# local timezone "Europe/Berlin"
(
"TC021",
"TC022",
"Europe/Berlin",
"2023-11-06T00:00:00",
"UTC",
@@ -1557,7 +1568,7 @@ class TestTimeWindowSequence:
),
# - no microseconds
(
"TC022",
"TC023",
"Atlantic/Canary",
"2024-10-30T00:00:00+01:00",
"UTC",
@@ -1568,7 +1579,7 @@ class TestTimeWindowSequence:
False,
),
(
"TC023",
"TC024",
"Atlantic/Canary",
"2024-10-30T01:00:00+01:00",
"utc",
@@ -1580,7 +1591,7 @@ class TestTimeWindowSequence:
),
# - with microseconds
(
"TC024",
"TC025",
"Atlantic/Canary",
"2024-10-07T10:20:30.000+02:00",
"UTC",
@@ -1596,7 +1607,7 @@ class TestTimeWindowSequence:
# - no timezone
# local timezone
(
"TC025",
"TC026",
None,
None,
None,

View File

@@ -14,8 +14,10 @@ DIR_DOCS_GENERATED = DIR_PROJECT_ROOT / "docs" / "_generated"
DIR_TEST_GENERATED = DIR_TESTDATA / "docs" / "_generated"
def test_openapi_spec_current(config_eos):
def test_openapi_spec_current(config_eos, set_other_timezone):
"""Verify the openapi spec hasn't changed."""
set_other_timezone("UTC") # CI runs on UTC
expected_spec_path = DIR_PROJECT_ROOT / "openapi.json"
new_spec_path = DIR_TESTDATA / "openapi-new.json"
@@ -23,7 +25,7 @@ def test_openapi_spec_current(config_eos):
expected_spec = json.load(f_expected)
# Patch get_config and import within guard to patch global variables within the eos module.
with patch("akkudoktoreos.config.config.get_config", return_value=config_eos):
with patch("akkudoktoreos.core.coreabc.get_config", return_value=config_eos):
# Ensure the script works correctly as part of a package
root_dir = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(root_dir))
@@ -39,7 +41,7 @@ def test_openapi_spec_current(config_eos):
expected_spec_str = json.dumps(expected_spec, indent=4, sort_keys=True)
try:
assert spec_str == expected_spec_str
assert json.loads(spec_str) == json.loads(expected_spec_str)
except AssertionError as e:
pytest.fail(
f"Expected {new_spec_path} to equal {expected_spec_path}.\n"
@@ -47,8 +49,10 @@ def test_openapi_spec_current(config_eos):
)
def test_openapi_md_current(config_eos):
def test_openapi_md_current(config_eos, set_other_timezone):
"""Verify the generated openapi markdown hasn't changed."""
set_other_timezone("UTC") # CI runs on UTC
expected_spec_md_path = DIR_PROJECT_ROOT / "docs" / "_generated" / "openapi.md"
new_spec_md_path = DIR_TESTDATA / "openapi-new.md"
@@ -56,7 +60,7 @@ def test_openapi_md_current(config_eos):
expected_spec_md = f_expected.read()
# Patch get_config and import within guard to patch global variables within the eos module.
with patch("akkudoktoreos.config.config.get_config", return_value=config_eos):
with patch("akkudoktoreos.core.coreabc.get_config", return_value=config_eos):
# Ensure the script works correctly as part of a package
root_dir = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(root_dir))
@@ -76,8 +80,10 @@ def test_openapi_md_current(config_eos):
)
def test_config_md_current(config_eos):
def test_config_md_current(config_eos, set_other_timezone):
"""Verify the generated configuration markdown hasn't changed."""
set_other_timezone("UTC") # CI runs on UTC
assert DIR_DOCS_GENERATED.exists()
# Remove any leftover files from last run
@@ -88,7 +94,7 @@ def test_config_md_current(config_eos):
DIR_TEST_GENERATED.mkdir(parents=True, exist_ok=True)
# Patch get_config and import within guard to patch global variables within the eos module.
with patch("akkudoktoreos.config.config.get_config", return_value=config_eos):
with patch("akkudoktoreos.core.coreabc.get_config", return_value=config_eos):
# Ensure the script works correctly as part of a package
root_dir = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(root_dir))
@@ -106,7 +112,11 @@ def test_config_md_current(config_eos):
tested.append(DIR_TEST_GENERATED / file_name)
# Create test files
config_md = generate_config_md.generate_config_md(tested[0], config_eos)
try:
config_eos._force_documentation_mode = True
config_md = generate_config_md.generate_config_md(tested[0], config_eos)
finally:
config_eos._force_documentation_mode = False
# Check test files are the same as the expected files
for i, expected_path in enumerate(expected):

View File

@@ -9,6 +9,8 @@ from typing import Optional
import pytest
from akkudoktoreos.core.coreabc import singletons_init
DIR_PROJECT_ROOT = Path(__file__).absolute().parent.parent
DIR_BUILD = DIR_PROJECT_ROOT / "build"
DIR_BUILD_DOCS = DIR_PROJECT_ROOT / "build" / "docs"
@@ -80,6 +82,7 @@ class TestSphinxDocumentation:
def test_sphinx_build(self, sphinx_changed: Optional[str], is_finalize: bool):
"""Build Sphinx documentation and ensure no major warnings appear in the build output."""
# Ensure docs folder exists
if not DIR_DOCS.exists():
pytest.skip(f"Skipping Sphinx build test - docs folder not present: {DIR_DOCS}")
@@ -88,7 +91,7 @@ class TestSphinxDocumentation:
pytest.skip(f"Skipping Sphinx build — no relevant file changes detected: {HASH_FILE}")
if not is_finalize:
pytest.skip("Skipping Sphinx test — not full run")
pytest.skip("Skipping Sphinx test — not finalize")
# Clean directories
self._cleanup_autosum_dirs()
@@ -123,7 +126,11 @@ class TestSphinxDocumentation:
# Remove temporary EOS_DIR
eos_tmp_dir.cleanup()
assert returncode == 0
if returncode != 0:
pytest.fail(
f"Sphinx build failed with exit code {returncode}.\n"
f"{output}\n"
)
# Possible markers: ERROR: WARNING: TRACEBACK:
major_markers = ("ERROR:", "TRACEBACK:")

View File

@@ -8,7 +8,7 @@ import requests
from loguru import logger
from akkudoktoreos.core.cache import CacheFileStore
from akkudoktoreos.core.ems import get_ems
from akkudoktoreos.core.coreabc import get_ems
from akkudoktoreos.prediction.elecpriceakkudoktor import (
AkkudoktorElecPrice,
AkkudoktorElecPriceValue,

View File

@@ -8,7 +8,7 @@ import requests
from loguru import logger
from akkudoktoreos.core.cache import CacheFileStore
from akkudoktoreos.core.ems import get_ems
from akkudoktoreos.core.coreabc import get_ems
from akkudoktoreos.prediction.elecpriceakkudoktor import (
AkkudoktorElecPrice,
AkkudoktorElecPriceValue,

View File

@@ -1,11 +1,12 @@
import json
from pathlib import Path
import numpy.testing as npt
import pytest
from akkudoktoreos.core.ems import get_ems
from akkudoktoreos.core.coreabc import get_ems
from akkudoktoreos.prediction.elecpriceimport import ElecPriceImport
from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime
from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime, to_duration
DIR_TESTDATA = Path(__file__).absolute().parent.joinpath("testdata")
@@ -83,6 +84,7 @@ def test_invalid_provider(provider, config_eos):
)
def test_import(provider, sample_import_1_json, start_datetime, from_file, config_eos):
"""Test fetching forecast from Import."""
key = "elecprice_marketprice_wh"
ems_eos = get_ems()
ems_eos.set_start_datetime(to_datetime(start_datetime, in_timezone="Europe/Berlin"))
if from_file:
@@ -91,7 +93,7 @@ def test_import(provider, sample_import_1_json, start_datetime, from_file, confi
else:
config_eos.elecprice.elecpriceimport.import_file_path = None
assert config_eos.elecprice.elecpriceimport.import_file_path is None
provider.clear()
provider.delete_by_datetime(start_datetime=None, end_datetime=None)
# Call the method
provider.update_data()
@@ -100,16 +102,13 @@ def test_import(provider, sample_import_1_json, start_datetime, from_file, confi
assert provider.ems_start_datetime is not None
assert provider.total_hours is not None
assert compare_datetimes(provider.ems_start_datetime, ems_eos.start_datetime).equal
values = sample_import_1_json["elecprice_marketprice_wh"]
value_datetime_mapping = provider.import_datetimes(ems_eos.start_datetime, len(values))
for i, mapping in enumerate(value_datetime_mapping):
assert i < len(provider.records)
expected_datetime, expected_value_index = mapping
expected_value = values[expected_value_index]
result_datetime = provider.records[i].date_time
result_value = provider.records[i]["elecprice_marketprice_wh"]
# print(f"{i}: Expected: {expected_datetime}:{expected_value}")
# print(f"{i}: Result: {result_datetime}:{result_value}")
assert compare_datetimes(result_datetime, expected_datetime).equal
assert result_value == expected_value
expected_values = sample_import_1_json[key]
result_values = provider.key_to_array(
key=key,
start_datetime=provider.ems_start_datetime,
end_datetime=provider.ems_start_datetime + to_duration(f"{len(expected_values)} hours"),
interval=to_duration("1 hour"),
)
# Allow for some difference due to value calculation on DST change
npt.assert_allclose(result_values, expected_values, rtol=0.001)

View File

@@ -3,7 +3,7 @@ from pathlib import Path
import pytest
from akkudoktoreos.core.ems import get_ems
from akkudoktoreos.core.coreabc import get_ems
from akkudoktoreos.prediction.feedintarifffixed import FeedInTariffFixed
from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime

View File

@@ -7,7 +7,7 @@ import pytest
from akkudoktoreos.config.config import ConfigEOS
from akkudoktoreos.core.cache import CacheEnergyManagementStore
from akkudoktoreos.core.ems import get_ems
from akkudoktoreos.core.coreabc import get_ems
from akkudoktoreos.optimization.genetic.genetic import GeneticOptimization
from akkudoktoreos.optimization.genetic.geneticparams import (
GeneticOptimizationParameters,
@@ -18,7 +18,7 @@ from akkudoktoreos.utils.visualize import (
prepare_visualize, # Import the new prepare_visualize
)
ems_eos = get_ems()
ems_eos = get_ems(init=True) # init once
DIR_TESTDATA = Path(__file__).parent / "testdata"

View File

@@ -4,8 +4,8 @@ import numpy as np
import pendulum
import pytest
from akkudoktoreos.core.ems import get_ems
from akkudoktoreos.measurement.measurement import MeasurementDataRecord, get_measurement
from akkudoktoreos.core.coreabc import get_ems, get_measurement
from akkudoktoreos.measurement.measurement import MeasurementDataRecord
from akkudoktoreos.prediction.loadakkudoktor import (
LoadAkkudoktor,
LoadAkkudoktorAdjusted,
@@ -63,7 +63,7 @@ def measurement_eos():
dt = to_datetime("2024-01-01T00:00:00")
interval = to_duration("1 hour")
for i in range(25):
measurement.records.append(
measurement.insert_by_datetime(
MeasurementDataRecord(
date_time=dt,
load0_mr=load0_mr,
@@ -138,7 +138,7 @@ def test_update_data(mock_load_data, loadakkudoktor):
ems_eos.set_start_datetime(pendulum.datetime(2024, 1, 1))
# Assure there are no prediction records
loadakkudoktor.clear()
loadakkudoktor.delete_by_datetime(start_datetime=None, end_datetime=None)
assert len(loadakkudoktor) == 0
# Execute the method
@@ -152,6 +152,24 @@ def test_calculate_adjustment(loadakkudoktoradjusted, measurement_eos):
"""Test `_calculate_adjustment` for various scenarios."""
data_year_energy = np.random.rand(365, 2, 24)
# Check the test setup
assert loadakkudoktoradjusted.measurement is measurement_eos
assert measurement_eos.min_datetime == to_datetime("2024-01-01T00:00:00")
assert measurement_eos.max_datetime == to_datetime("2024-01-02T00:00:00")
# Use same calculation as in _calculate_adjustment
compare_start = measurement_eos.max_datetime - to_duration("7 days")
if compare_datetimes(compare_start, measurement_eos.min_datetime).lt:
# Not enough measurements for 7 days - use what is available
compare_start = measurement_eos.min_datetime
compare_end = measurement_eos.max_datetime
compare_interval = to_duration("1 hour")
load_total_kwh_array = measurement_eos.load_total_kwh(
start_datetime=compare_start,
end_datetime=compare_end,
interval=compare_interval,
)
np.testing.assert_allclose(load_total_kwh_array, [0.1] * 24)
# Call the method and validate results
weekday_adjust, weekend_adjust = loadakkudoktoradjusted._calculate_adjustment(data_year_energy)
assert weekday_adjust.shape == (24,)

View File

@@ -3,10 +3,17 @@ import pytest
from pendulum import datetime, duration
from akkudoktoreos.config.config import SettingsEOS
from akkudoktoreos.core.coreabc import get_measurement
from akkudoktoreos.measurement.measurement import (
MeasurementCommonSettings,
MeasurementDataRecord,
get_measurement,
)
from akkudoktoreos.utils.datetimeutil import (
DateTime,
Duration,
compare_datetimes,
to_datetime,
to_duration,
)
@@ -41,8 +48,9 @@ class TestMeasurementDataRecord:
def test_getitem_existing_field(self, record):
"""Test that __getitem__ returns correct value for existing native field."""
record.date_time = "2024-01-01T00:00:00+00:00"
assert record["date_time"] is not None
date_time = "2024-01-01T00:00:00+00:00"
record.date_time = date_time
assert compare_datetimes(record["date_time"], to_datetime(date_time)).equal
def test_getitem_existing_measurement(self, record):
"""Test that __getitem__ retrieves existing measurement values."""
@@ -220,6 +228,7 @@ class TestMeasurement:
# Load meter readings are in kWh
config_eos.measurement.load_emr_keys = ["load0_mr", "load1_mr", "load2_mr", "load3_mr"]
measurement = get_measurement()
measurement.delete_by_datetime(None, None)
record0 = MeasurementDataRecord(
date_time=datetime(2023, 1, 1, hour=0),
load0_mr=100,
@@ -227,52 +236,54 @@ class TestMeasurement:
)
assert record0.load0_mr == 100
assert record0.load1_mr == 200
measurement.records = [
records = [
MeasurementDataRecord(
date_time=datetime(2023, 1, 1, hour=0),
date_time=to_datetime("2023-01-01T00:00:00"),
load0_mr=100,
load1_mr=200,
),
MeasurementDataRecord(
date_time=datetime(2023, 1, 1, hour=1),
date_time=to_datetime("2023-01-01T01:00:00"),
load0_mr=150,
load1_mr=250,
),
MeasurementDataRecord(
date_time=datetime(2023, 1, 1, hour=2),
date_time=to_datetime("2023-01-01T02:00:00"),
load0_mr=200,
load1_mr=300,
),
MeasurementDataRecord(
date_time=datetime(2023, 1, 1, hour=3),
date_time=to_datetime("2023-01-01T03:00:00"),
load0_mr=250,
load1_mr=350,
),
MeasurementDataRecord(
date_time=datetime(2023, 1, 1, hour=4),
date_time=to_datetime("2023-01-01T04:00:00"),
load0_mr=300,
load1_mr=400,
),
MeasurementDataRecord(
date_time=datetime(2023, 1, 1, hour=5),
date_time=to_datetime("2023-01-01T05:00:00"),
load0_mr=350,
load1_mr=450,
),
]
for record in records:
measurement.insert_by_datetime(record)
return measurement
def test_interval_count(self, measurement_eos):
"""Test interval count calculation."""
start = datetime(2023, 1, 1, 0)
end = datetime(2023, 1, 1, 3)
start = to_datetime("2023-01-01T00:00:00")
end = to_datetime("2023-01-01T03:00:00")
interval = duration(hours=1)
assert measurement_eos._interval_count(start, end, interval) == 3
def test_interval_count_invalid_end_before_start(self, measurement_eos):
"""Test interval count raises ValueError when end_datetime is before start_datetime."""
start = datetime(2023, 1, 1, 3)
end = datetime(2023, 1, 1, 0)
start = to_datetime("2023-01-01T03:00:00")
end = to_datetime("2023-01-01T00:00:00")
interval = duration(hours=1)
with pytest.raises(ValueError, match="end_datetime must be after start_datetime"):
@@ -280,8 +291,8 @@ class TestMeasurement:
def test_interval_count_invalid_non_positive_interval(self, measurement_eos):
"""Test interval count raises ValueError when interval is non-positive."""
start = datetime(2023, 1, 1, 0)
end = datetime(2023, 1, 1, 3)
start = to_datetime("2023-01-01T00:00:00")
end = to_datetime("2023-01-01T03:00:00")
with pytest.raises(ValueError, match="interval must be positive"):
measurement_eos._interval_count(start, end, duration(hours=0))
@@ -289,8 +300,8 @@ class TestMeasurement:
def test_energy_from_meter_readings_valid_input(self, measurement_eos):
"""Test _energy_from_meter_readings with valid inputs and proper alignment of load data."""
key = "load0_mr"
start_datetime = datetime(2023, 1, 1, 0)
end_datetime = datetime(2023, 1, 1, 5)
start_datetime = to_datetime("2023-01-01T00:00:00")
end_datetime = to_datetime("2023-01-01T05:00:00")
interval = duration(hours=1)
load_array = measurement_eos._energy_from_meter_readings(
@@ -303,12 +314,12 @@ class TestMeasurement:
def test_energy_from_meter_readings_empty_array(self, measurement_eos):
"""Test _energy_from_meter_readings with no data (empty array)."""
key = "load0_mr"
start_datetime = datetime(2023, 1, 1, 0)
end_datetime = datetime(2023, 1, 1, 5)
start_datetime = to_datetime("2023-01-01T00:00:00")
end_datetime = to_datetime("2023-01-01T05:00:00")
interval = duration(hours=1)
# Use empty records array
measurement_eos.records = []
measurement_eos.delete_by_datetime(start_datetime, end_datetime)
load_array = measurement_eos._energy_from_meter_readings(
key, start_datetime, end_datetime, interval
@@ -324,25 +335,46 @@ class TestMeasurement:
def test_energy_from_meter_readings_misaligned_array(self, measurement_eos):
"""Test _energy_from_meter_readings with misaligned array size."""
key = "load1_mr"
start_datetime = measurement_eos.min_datetime
end_datetime = measurement_eos.max_datetime
interval = duration(hours=1)
start_datetime = to_datetime("2023-01-01T00:00:00")
end_datetime = to_datetime("2023-01-01T05:00:00")
# Use misaligned array, latest interval set to 2 hours (instead of 1 hour)
measurement_eos.records[-1].date_time = datetime(2023, 1, 1, 6)
latest_record_datetime = to_datetime("2023-01-01T05:00:00")
new_record_datetime = to_datetime("2023-01-01T06:00:00")
record = measurement_eos.get_by_datetime(latest_record_datetime)
assert record is not None
measurement_eos.delete_by_datetime(start_datetime=latest_record_datetime,
end_datetime=new_record_datetime)
record.date_time = new_record_datetime
measurement_eos.insert_by_datetime(record)
# Check test setup
dates, values = measurement_eos.key_to_lists(key, start_datetime, None)
assert dates == [
to_datetime("2023-01-01T00:00:00"),
to_datetime("2023-01-01T01:00:00"),
to_datetime("2023-01-01T02:00:00"),
to_datetime("2023-01-01T03:00:00"),
to_datetime("2023-01-01T04:00:00"),
to_datetime("2023-01-01T06:00:00"),
]
assert values == [200, 250, 300, 350, 400, 450]
array = measurement_eos.key_to_array(key, start_datetime, end_datetime + interval, interval=interval)
np.testing.assert_array_equal(array, [200, 250, 300, 350, 400, 425])
load_array = measurement_eos._energy_from_meter_readings(
key, start_datetime, end_datetime, interval
)
expected_load_array = np.array([50, 50, 50, 50, 25]) # Differences between consecutive readings
expected_load_array = np.array([50., 50., 50., 50., 25.]) # Differences between consecutive readings
np.testing.assert_array_equal(load_array, expected_load_array)
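The misaligned case above boils down to resampling the meter readings onto the hourly grid and differencing consecutive values; a minimal standalone sketch of that idea (using plain `np.interp` as an assumed stand-in for whatever interpolation `key_to_array` actually performs):

```python
import numpy as np

# Meter readings at hours 0..4 and 6 (the 5 h reading was moved to 6 h)
hours = np.array([0, 1, 2, 3, 4, 6], dtype=float)
readings = np.array([200, 250, 300, 350, 400, 450], dtype=float)

# Resample onto an hourly grid; hour 5 is linearly interpolated to 425
resampled = np.interp(np.arange(6, dtype=float), hours, readings)

# Per-interval energy is the difference of consecutive resampled readings
energy = np.diff(resampled)
print(energy)  # [50. 50. 50. 50. 25.]
```

This reproduces the `[50., 50., 50., 50., 25.]` expectation: the 50 kWh between hours 4 and 6 is split, so the 4-to-5 interval carries only 25.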
def test_energy_from_meter_readings_partial_data(self, measurement_eos, caplog):
"""Test _energy_from_meter_readings with partial data (misaligned but empty array)."""
key = "load2_mr"
start_datetime = datetime(2023, 1, 1, 0)
end_datetime = datetime(2023, 1, 1, 5)
start_datetime = to_datetime("2023-01-01T00:00:00")
end_datetime = to_datetime("2023-01-01T05:00:00")
interval = duration(hours=1)
with caplog.at_level("DEBUG"):
@@ -359,8 +391,8 @@ class TestMeasurement:
def test_energy_from_meter_readings_negative_interval(self, measurement_eos):
"""Test _energy_from_meter_readings with a negative interval."""
key = "load3_mr"
start_datetime = datetime(2023, 1, 1, 0)
end_datetime = datetime(2023, 1, 1, 5)
start_datetime = to_datetime("2023-01-01T00:00:00")
end_datetime = to_datetime("2023-01-01T05:00:00")
interval = duration(hours=-1)
with pytest.raises(ValueError, match="interval must be positive"):
@@ -368,11 +400,11 @@ class TestMeasurement:
def test_load_total_kwh(self, measurement_eos):
"""Test total load calculation."""
start = datetime(2023, 1, 1, 0)
end = datetime(2023, 1, 1, 2)
start_datetime = to_datetime("2023-01-01T03:00:00")
end_datetime = to_datetime("2023-01-01T05:00:00")
interval = duration(hours=1)
result = measurement_eos.load_total_kwh(start_datetime=start, end_datetime=end, interval=interval)
result = measurement_eos.load_total_kwh(start_datetime=start_datetime, end_datetime=end_datetime, interval=interval)
# Expected total load per interval
expected = np.array([100, 100]) # Differences between consecutive meter readings
@@ -381,20 +413,20 @@ class TestMeasurement:
def test_load_total_kwh_no_data(self, measurement_eos):
"""Test total load calculation with no data."""
measurement_eos.records = []
start = datetime(2023, 1, 1, 0)
end = datetime(2023, 1, 1, 3)
start_datetime = to_datetime("2023-01-01T00:00:00")
end_datetime = to_datetime("2023-01-01T03:00:00")
interval = duration(hours=1)
result = measurement_eos.load_total_kwh(start_datetime=start, end_datetime=end, interval=interval)
result = measurement_eos.load_total_kwh(start_datetime=start_datetime, end_datetime=end_datetime, interval=interval)
expected = np.zeros(3) # No data, so all intervals are zero
np.testing.assert_array_equal(result, expected)
def test_load_total_kwh_partial_intervals(self, measurement_eos):
"""Test total load calculation with partial intervals."""
start = datetime(2023, 1, 1, 0, 30) # Start in the middle of an interval
end = datetime(2023, 1, 1, 1, 30) # End in the middle of another interval
start_datetime = to_datetime("2023-01-01T00:30:00") # Start in the middle of an interval
end_datetime = to_datetime("2023-01-01T01:30:00") # End in the middle of another interval
interval = duration(hours=1)
result = measurement_eos.load_total_kwh(start_datetime=start, end_datetime=end, interval=interval)
result = measurement_eos.load_total_kwh(start_datetime=start_datetime, end_datetime=end_datetime, interval=interval)
expected = np.array([100]) # Only one complete interval covered
np.testing.assert_array_equal(result, expected)

View File

@@ -1,6 +1,7 @@
import pytest
from pydantic import ValidationError
from akkudoktoreos.core.coreabc import get_prediction
from akkudoktoreos.prediction.elecpriceakkudoktor import ElecPriceAkkudoktor
from akkudoktoreos.prediction.elecpriceenergycharts import ElecPriceEnergyCharts
from akkudoktoreos.prediction.elecpriceimport import ElecPriceImport
@@ -15,7 +16,6 @@ from akkudoktoreos.prediction.loadvrm import LoadVrm
from akkudoktoreos.prediction.prediction import (
Prediction,
PredictionCommonSettings,
get_prediction,
)
from akkudoktoreos.prediction.pvforecastakkudoktor import PVForecastAkkudoktor
from akkudoktoreos.prediction.pvforecastimport import PVForecastImport

View File

@@ -7,10 +7,10 @@ import pendulum
import pytest
from pydantic import Field
from akkudoktoreos.core.ems import get_ems
from akkudoktoreos.core.coreabc import get_ems
from akkudoktoreos.prediction.prediction import PredictionCommonSettings
from akkudoktoreos.prediction.predictionabc import (
PredictionBase,
PredictionABC,
PredictionContainer,
PredictionProvider,
PredictionRecord,
@@ -28,7 +28,7 @@ class DerivedConfig(PredictionCommonSettings):
class_constant: Optional[int] = Field(default=None, description="Test config by class constant")
class DerivedBase(PredictionBase):
class DerivedBase(PredictionABC):
instance_field: Optional[str] = Field(default=None, description="Field Value")
class_constant: ClassVar[int] = 30
@@ -84,7 +84,7 @@ class DerivedPredictionContainer(PredictionContainer):
# ----------
class TestPredictionBase:
class TestPredictionABC:
@pytest.fixture
def base(self, monkeypatch):
# Provide default values for configuration
@@ -216,17 +216,19 @@ class TestPredictionProvider:
def test_delete_by_datetime(self, provider, sample_start_datetime):
"""Test `delete_by_datetime` method for removing records by datetime range."""
# Add records to the provider for deletion testing
provider.records = [
records = [
self.create_test_record(sample_start_datetime - to_duration("3 hours"), 1),
self.create_test_record(sample_start_datetime - to_duration("1 hour"), 2),
self.create_test_record(sample_start_datetime + to_duration("1 hour"), 3),
]
for record in records:
provider.insert_by_datetime(record)
provider.delete_by_datetime(
start_datetime=sample_start_datetime - to_duration("2 hours"),
end_datetime=sample_start_datetime + to_duration("2 hours"),
)
assert len(provider.records) == 1, (
assert len(provider) == 1, (
"Only one record should remain after deletion by datetime."
)
assert provider.records[0].date_time == sample_start_datetime - to_duration("3 hours"), (
@@ -243,15 +245,17 @@ class TestPredictionContainer:
@pytest.fixture
def container_with_providers(self):
record1 = self.create_test_record(datetime(2023, 11, 5), 1)
record2 = self.create_test_record(datetime(2023, 11, 6), 2)
record3 = self.create_test_record(datetime(2023, 11, 7), 3)
records = [
# Test records - include 'prediction_value' key
self.create_test_record(datetime(2023, 11, 5), 1),
self.create_test_record(datetime(2023, 11, 6), 2),
self.create_test_record(datetime(2023, 11, 7), 3),
]
provider = DerivedPredictionProvider()
provider.clear()
provider.delete_by_datetime(start_datetime=None, end_datetime=None)
assert len(provider) == 0
provider.append(record1)
provider.append(record2)
provider.append(record3)
for record in records:
provider.insert_by_datetime(record)
assert len(provider) == 3
container = DerivedPredictionContainer()
container.providers.clear()
@@ -378,7 +382,9 @@ class TestPredictionContainer:
assert len(container_with_providers.providers) == 1
# check all keys are available (don't care for position)
for key in ["prediction_value", "date_time"]:
assert key in list(container_with_providers.keys())
assert key in container_with_providers.record_keys
for key in ["prediction_value", "date_time"]:
assert key in container_with_providers.keys()
series = container_with_providers["prediction_value"]
assert isinstance(series, pd.Series)
assert series.name == "prediction_value"

View File

@@ -5,8 +5,7 @@ from unittest.mock import Mock, patch
import pytest
from loguru import logger
from akkudoktoreos.core.ems import get_ems
from akkudoktoreos.prediction.prediction import get_prediction
from akkudoktoreos.core.coreabc import get_ems, get_prediction
from akkudoktoreos.prediction.pvforecastakkudoktor import (
AkkudoktorForecastHorizon,
AkkudoktorForecastMeta,
@@ -137,7 +136,7 @@ def provider():
def provider_empty_instance():
"""Fixture that returns an empty instance of PVForecast."""
empty_instance = PVForecastAkkudoktor()
empty_instance.clear()
empty_instance.delete_by_datetime(start_datetime=None, end_datetime=None)
assert len(empty_instance) == 0
return empty_instance
@@ -277,7 +276,7 @@ def test_pvforecast_akkudoktor_update_with_sample_forecast(
ems_eos.set_start_datetime(sample_forecast_start)
provider.update_data(force_enable=True, force_update=True)
assert compare_datetimes(provider.ems_start_datetime, sample_forecast_start).equal
assert compare_datetimes(provider[0].date_time, to_datetime(sample_forecast_start)).equal
assert compare_datetimes(provider.records[0].date_time, to_datetime(sample_forecast_start)).equal
# Report Generation Test
@@ -290,7 +289,7 @@ def test_report_ac_power_and_measurement(provider, config_eos):
pvforecast_dc_power=450.0,
pvforecast_ac_power=400.0,
)
provider.append(record)
provider.insert_by_datetime(record)
report = provider.report_ac_power_and_measurement()
assert "DC: 450.0" in report
@@ -323,19 +322,19 @@ def test_timezone_behaviour(
expected_datetime = to_datetime("2024-10-06T00:00:00+0200", in_timezone=other_timezone)
assert compare_datetimes(other_start_datetime, expected_datetime).equal
provider.clear()
provider.delete_by_datetime(start_datetime=None, end_datetime=None)
assert len(provider) == 0
ems_eos = get_ems()
ems_eos.set_start_datetime(other_start_datetime)
provider.update_data(force_update=True)
assert compare_datetimes(provider.ems_start_datetime, other_start_datetime).equal
# Check whether first record starts at requested sample start time
assert compare_datetimes(provider[0].date_time, sample_forecast_start).equal
assert compare_datetimes(provider.records[0].date_time, sample_forecast_start).equal
# Test updating AC power measurement for a specific date.
provider.update_value(sample_forecast_start, "pvforecastakkudoktor_ac_power_measured", 1000)
# Check whether first record was filled with ac power measurement
assert provider[0].pvforecastakkudoktor_ac_power_measured == 1000
assert provider.records[0].pvforecastakkudoktor_ac_power_measured == 1000
# Test fetching temperature forecast for a specific date.
other_end_datetime = other_start_datetime + to_duration("24 hours")

View File

@@ -1,11 +1,12 @@
import json
from pathlib import Path
import numpy.testing as npt
import pytest
from akkudoktoreos.core.ems import get_ems
from akkudoktoreos.core.coreabc import get_ems
from akkudoktoreos.prediction.pvforecastimport import PVForecastImport
from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime
from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime, to_duration
DIR_TESTDATA = Path(__file__).absolute().parent.joinpath("testdata")
@@ -87,6 +88,7 @@ def test_invalid_provider(provider, config_eos):
)
def test_import(provider, sample_import_1_json, start_datetime, from_file, config_eos):
"""Test fetching forecast from import."""
key = "pvforecast_ac_power"
ems_eos = get_ems()
ems_eos.set_start_datetime(to_datetime(start_datetime, in_timezone="Europe/Berlin"))
if from_file:
@@ -95,7 +97,7 @@ def test_import(provider, sample_import_1_json, start_datetime, from_file, confi
else:
config_eos.pvforecast.provider_settings.PVForecastImport.import_file_path = None
assert config_eos.pvforecast.provider_settings.PVForecastImport.import_file_path is None
provider.clear()
provider.delete_by_datetime(start_datetime=None, end_datetime=None)
# Call the method
provider.update_data()
@@ -104,16 +106,13 @@ def test_import(provider, sample_import_1_json, start_datetime, from_file, confi
assert provider.ems_start_datetime is not None
assert provider.total_hours is not None
assert compare_datetimes(provider.ems_start_datetime, ems_eos.start_datetime).equal
values = sample_import_1_json["pvforecast_ac_power"]
value_datetime_mapping = provider.import_datetimes(ems_eos.start_datetime, len(values))
for i, mapping in enumerate(value_datetime_mapping):
assert i < len(provider.records)
expected_datetime, expected_value_index = mapping
expected_value = values[expected_value_index]
result_datetime = provider.records[i].date_time
result_value = provider.records[i]["pvforecast_ac_power"]
# print(f"{i}: Expected: {expected_datetime}:{expected_value}")
# print(f"{i}: Result: {result_datetime}:{result_value}")
assert compare_datetimes(result_datetime, expected_datetime).equal
assert result_value == expected_value
expected_values = sample_import_1_json[key]
result_values = provider.key_to_array(
key=key,
start_datetime=provider.ems_start_datetime,
end_datetime=provider.ems_start_datetime + to_duration(f"{len(expected_values)} hours"),
interval=to_duration("1 hour"),
)
# Allow for some difference due to value calculation on DST change
npt.assert_allclose(result_values, expected_values, rtol=0.001)

View File

@@ -0,0 +1,701 @@
"""Tests for RetentionManager and JobState."""
from __future__ import annotations
import asyncio
import time
from typing import Any
from unittest.mock import AsyncMock, MagicMock, call, patch
import pytest
from loguru import logger
import akkudoktoreos.server.retentionmanager
from akkudoktoreos.server.retentionmanager import JobState, RetentionManager
# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------
INTERVAL = 10.0
DUE_INTERVAL = 0.001 # non-zero so interval() does not fall back to fallback_interval
FALLBACK = 300.0
def make_config_getter(interval: float = INTERVAL) -> Any:
"""Return a simple config getter that always yields ``interval`` for any key."""
return lambda key: interval
def make_config_getter_none() -> Any:
"""Return a config getter that always yields ``None`` (job disabled)."""
return lambda key: None
def make_manager(interval: float = INTERVAL, shutdown_timeout: float = 5.0) -> RetentionManager:
"""Return a ``RetentionManager`` backed by a fixed-interval config getter."""
return RetentionManager(make_config_getter(interval), shutdown_timeout=shutdown_timeout)
def make_manager_none(shutdown_timeout: float = 5.0) -> RetentionManager:
"""Return a ``RetentionManager`` whose config getter always returns None (all jobs disabled)."""
return RetentionManager(make_config_getter_none(), shutdown_timeout=shutdown_timeout)
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestRetentionManager:
"""Tests for :class:`RetentionManager` and :class:`JobState`."""
# ------------------------------------------------------------------
# Initialisation
# ------------------------------------------------------------------
def test_init_stores_config_getter(self) -> None:
"""The config getter passed to __init__ is stored and forwarded to jobs."""
getter = make_config_getter()
manager = RetentionManager(getter)
assert manager._config_getter is getter
def test_init_empty_job_registry(self) -> None:
"""A newly created manager has no registered jobs."""
manager = make_manager()
assert manager._jobs == {}
# ------------------------------------------------------------------
# register / unregister
# ------------------------------------------------------------------
def test_register_adds_job(self) -> None:
"""Registering a function adds a JobState entry."""
manager = make_manager()
func = MagicMock()
manager.register("job1", func, interval_attr="some/key")
assert "job1" in manager._jobs
def test_register_job_state_fields(self) -> None:
"""Registered JobState carries the correct initial field values."""
manager = make_manager()
func = MagicMock()
manager.register("job1", func, interval_attr="some/key", fallback_interval=60.0)
job = manager._jobs["job1"]
assert job.name == "job1"
assert job.func is func
assert job.interval_attr == "some/key"
assert job.fallback_interval == 60.0
assert job.config_getter is manager._config_getter
assert job.on_exception is None
assert job.last_run_at == 0.0
assert job.run_count == 0
assert job.is_running is False
def test_register_stores_on_exception(self) -> None:
"""The on_exception callback is stored on the JobState."""
manager = make_manager()
handler = MagicMock()
manager.register("job1", MagicMock(), interval_attr="k", on_exception=handler)
assert manager._jobs["job1"].on_exception is handler
def test_register_duplicate_raises(self) -> None:
"""Registering the same name twice raises ValueError."""
manager = make_manager()
manager.register("job1", MagicMock(), interval_attr="k")
with pytest.raises(ValueError, match="job1"):
manager.register("job1", MagicMock(), interval_attr="k")
def test_unregister_removes_job(self) -> None:
"""Unregistering a job removes it from the registry."""
manager = make_manager()
manager.register("job1", MagicMock(), interval_attr="k")
manager.unregister("job1")
assert "job1" not in manager._jobs
def test_unregister_missing_job_is_noop(self) -> None:
"""Unregistering a non-existent job does not raise."""
manager = make_manager()
manager.unregister("nonexistent") # must not raise
# ------------------------------------------------------------------
# JobState.interval()
# ------------------------------------------------------------------
def test_job_interval_from_config_getter(self) -> None:
"""JobState.interval() returns the value provided by config_getter."""
manager = make_manager(interval=42.0)
manager.register("job1", MagicMock(), interval_attr="k")
assert manager._jobs["job1"].interval() == 42.0
def test_job_interval_none_when_config_returns_none(self) -> None:
"""JobState.interval() returns None when config_getter returns None (job disabled)."""
manager = make_manager_none()
manager.register("job1", MagicMock(), interval_attr="k", fallback_interval=FALLBACK)
assert manager._jobs["job1"].interval() is None
def test_job_interval_none_does_not_fall_back(self) -> None:
"""A None config value must NOT fall back to fallback_interval -- None means disabled."""
manager = make_manager_none()
manager.register("job1", MagicMock(), interval_attr="k", fallback_interval=99.0)
# If None incorrectly fell back, this would return 99.0 instead of None
assert manager._jobs["job1"].interval() is None
def test_job_interval_fallback_on_key_error(self) -> None:
"""JobState.interval() uses fallback_interval when config_getter raises KeyError."""
manager = RetentionManager(lambda key: (_ for _ in ()).throw(KeyError(key)))
manager.register("job1", MagicMock(), interval_attr="k", fallback_interval=99.0)
assert manager._jobs["job1"].interval() == 99.0
def test_job_interval_fallback_on_index_error(self) -> None:
"""JobState.interval() uses fallback_interval when config_getter raises IndexError."""
manager = RetentionManager(lambda key: (_ for _ in ()).throw(IndexError()))
manager.register("job1", MagicMock(), interval_attr="k", fallback_interval=77.0)
assert manager._jobs["job1"].interval() == 77.0
def test_job_interval_fallback_on_zero_value(self) -> None:
"""JobState.interval() uses fallback_interval when config_getter returns zero."""
manager = RetentionManager(lambda key: 0)
manager.register("job1", MagicMock(), interval_attr="k", fallback_interval=55.0)
assert manager._jobs["job1"].interval() == 55.0
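The interval tests above pin down three distinct cases: `None` means disabled (deliberately no fallback), a lookup error falls back, and a zero/falsy value falls back. A compact sketch of that resolution logic (illustrative names, not the real `JobState` implementation):

```python
from typing import Any, Callable, Optional

def resolve_interval(
    config_getter: Callable[[str], Any], key: str, fallback: float
) -> Optional[float]:
    """Resolve a job interval with the semantics the tests pin down."""
    try:
        value = config_getter(key)
    except (KeyError, IndexError):
        return fallback          # missing config key -> fallback interval
    if value is None:
        return None              # disabled -- deliberately no fallback
    if not value:
        return fallback          # zero / falsy -> fallback interval
    return value
```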
# ------------------------------------------------------------------
# JobState.is_due()
# ------------------------------------------------------------------
def test_job_is_due_when_never_run(self) -> None:
"""A job is always due when it has never been run (last_run_at == 0.0)."""
manager = make_manager(interval=INTERVAL)
manager.register("job1", MagicMock(), interval_attr="k")
assert manager._jobs["job1"].is_due() is True
def test_job_is_not_due_immediately_after_run(self) -> None:
"""A job is not due immediately after last_run_at is set to now."""
manager = make_manager(interval=INTERVAL)
manager.register("job1", MagicMock(), interval_attr="k")
manager._jobs["job1"].last_run_at = time.monotonic()
assert manager._jobs["job1"].is_due() is False
def test_job_is_due_after_interval_elapsed(self) -> None:
"""A job becomes due once the interval has passed since last_run_at."""
manager = make_manager(interval=1.0)
manager.register("job1", MagicMock(), interval_attr="k")
manager._jobs["job1"].last_run_at = time.monotonic() - 2.0 # 2 s ago > 1 s interval
assert manager._jobs["job1"].is_due() is True
def test_job_is_never_due_when_interval_is_none(self) -> None:
"""is_due() returns False when interval() is None, even if last_run_at is 0."""
manager = make_manager_none()
manager.register("job1", MagicMock(), interval_attr="k")
job = manager._jobs["job1"]
# last_run_at == 0.0 would make any enabled job due immediately
assert job.last_run_at == 0.0
assert job.is_due() is False
def test_job_is_never_due_when_disabled_regardless_of_last_run(self) -> None:
"""is_due() stays False for a disabled job even long after its last run."""
manager = make_manager_none()
manager.register("job1", MagicMock(), interval_attr="k")
job = manager._jobs["job1"]
job.last_run_at = time.monotonic() - 365 * 24 * 3600 # "ran" a year ago
assert job.is_due() is False
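Taken together, the due-check tests describe a simple monotonic-clock comparison gated by the disabled state; a sketch under those assumptions (the real `is_due()` may differ in detail):

```python
import time
from typing import Optional

def is_due(last_run_at: float, interval: Optional[float]) -> bool:
    """Due-check matching the tested behaviour: a disabled job (interval
    None) is never due; otherwise the job is due once `interval` seconds
    have elapsed since `last_run_at`. A last_run_at of 0.0 means 'never
    run', so any enabled job is immediately due."""
    if interval is None:
        return False
    return (time.monotonic() - last_run_at) >= interval
```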
# ------------------------------------------------------------------
# JobState.summary()
# ------------------------------------------------------------------
def test_summary_keys(self) -> None:
"""summary() returns all expected keys including interval_s."""
manager = make_manager()
manager.register("job1", MagicMock(), interval_attr="k")
summary = manager._jobs["job1"].summary()
assert set(summary.keys()) == {
"name", "interval_attr", "interval_s", "last_run_at",
"last_duration_s", "last_error", "run_count", "is_running",
}
def test_summary_interval_s_reflects_config(self) -> None:
"""summary()['interval_s'] matches the value returned by interval()."""
manager = make_manager(interval=42.0)
manager.register("job1", MagicMock(), interval_attr="k")
assert manager._jobs["job1"].summary()["interval_s"] == 42.0
def test_summary_interval_s_is_none_when_disabled(self) -> None:
"""summary()['interval_s'] is None when the job is disabled via config."""
manager = make_manager_none()
manager.register("job1", MagicMock(), interval_attr="k")
assert manager._jobs["job1"].summary()["interval_s"] is None
def test_summary_values(self) -> None:
"""summary() reflects the current JobState values."""
manager = make_manager()
manager.register("job1", MagicMock(), interval_attr="my/key")
job = manager._jobs["job1"]
job.last_run_at = 1234.5
job.last_duration = 0.12345
job.last_error = "oops"
job.run_count = 3
job.is_running = True
s = job.summary()
assert s["name"] == "job1"
assert s["interval_attr"] == "my/key"
assert s["last_run_at"] == 1234.5
assert s["last_duration_s"] == 0.1235 # rounded to 4 dp
assert s["last_error"] == "oops"
assert s["run_count"] == 3
assert s["is_running"] is True
# ------------------------------------------------------------------
# status()
# ------------------------------------------------------------------
def test_status_empty(self) -> None:
"""status() returns an empty list when no jobs are registered."""
assert make_manager().status() == []
def test_status_contains_all_jobs(self) -> None:
"""status() returns one entry per registered job."""
manager = make_manager()
manager.register("a", MagicMock(), interval_attr="k1")
manager.register("b", MagicMock(), interval_attr="k2")
names = {s["name"] for s in manager.status()}
assert names == {"a", "b"}
def test_status_shows_disabled_job(self) -> None:
"""status() includes disabled jobs with interval_s == None."""
manager = make_manager_none()
manager.register("disabled", MagicMock(), interval_attr="k")
entries = manager.status()
assert len(entries) == 1
assert entries[0]["interval_s"] is None
# ------------------------------------------------------------------
# tick() -- job dispatch
# ------------------------------------------------------------------
@pytest.mark.asyncio
async def test_tick_runs_due_sync_job(self) -> None:
"""tick() executes a sync job that is due."""
manager = make_manager(interval=DUE_INTERVAL)
func = MagicMock()
manager.register("job1", func, interval_attr="k")
await manager.tick()
await asyncio.sleep(0) # yield so ensure_future tasks are scheduled
await asyncio.sleep(0) # second yield ensures tasks have started
await manager.shutdown()
func.assert_called_once()
@pytest.mark.asyncio
async def test_tick_runs_due_async_job(self) -> None:
"""tick() executes an async job that is due."""
manager = make_manager(interval=DUE_INTERVAL)
func = AsyncMock()
manager.register("job1", func, interval_attr="k")
await manager.tick()
await asyncio.sleep(0) # yield so ensure_future tasks are scheduled
await asyncio.sleep(0) # second yield ensures tasks have started
await manager.shutdown()
func.assert_called_once()
@pytest.mark.asyncio
async def test_tick_skips_not_due_job(self) -> None:
"""tick() does not execute a job whose interval has not yet elapsed."""
manager = make_manager(interval=9999.0)
func = MagicMock()
manager.register("job1", func, interval_attr="k")
manager._jobs["job1"].last_run_at = time.monotonic() # just ran
await manager.tick()
await asyncio.sleep(0)
await manager.shutdown()
func.assert_not_called()
@pytest.mark.asyncio
async def test_tick_skips_disabled_job(self) -> None:
"""tick() never executes a job whose interval is None, even if never run before."""
manager = make_manager_none()
func = MagicMock()
manager.register("disabled", func, interval_attr="k")
job = manager._jobs["disabled"]
# last_run_at == 0.0 would fire any enabled job immediately
assert job.last_run_at == 0.0
await manager.tick()
await asyncio.sleep(0)
await manager.shutdown()
func.assert_not_called()
@pytest.mark.asyncio
async def test_tick_skips_disabled_job_adds_no_task(self) -> None:
"""tick() adds no task to _running_tasks for a disabled job."""
manager = make_manager_none()
manager.register("disabled", AsyncMock(), interval_attr="k")
await manager.tick()
await asyncio.sleep(0)
assert len(manager._running_tasks) == 0
@pytest.mark.asyncio
async def test_tick_enabled_and_disabled_jobs_mixed(self) -> None:
"""tick() fires enabled jobs and silently skips disabled ones in the same manager."""
results: list[str] = []
async def enabled_job() -> None:
results.append("ran")
manager = RetentionManager(
lambda key: DUE_INTERVAL if key == "enabled/interval" else None,
shutdown_timeout=5.0,
)
manager.register("enabled", enabled_job, interval_attr="enabled/interval")
manager.register("disabled", AsyncMock(), interval_attr="disabled/interval")
await manager.tick()
await asyncio.sleep(0)
await asyncio.sleep(0)
await manager.shutdown()
assert results == ["ran"], "Only the enabled job must have run"
@pytest.mark.asyncio
async def test_tick_skips_already_running_job(self) -> None:
"""tick() does not start a job that is still marked as running."""
manager = make_manager(interval=DUE_INTERVAL)
func = MagicMock()
manager.register("job1", func, interval_attr="k")
manager._jobs["job1"].is_running = True
await manager.tick()
await asyncio.sleep(0)
await manager.shutdown()
func.assert_not_called()
@pytest.mark.asyncio
async def test_tick_runs_multiple_jobs_concurrently(self) -> None:
"""tick() fires all due jobs as independent tasks."""
manager = make_manager(interval=DUE_INTERVAL)
results: list[str] = []
async def job_a() -> None:
results.append("a")
async def job_b() -> None:
results.append("b")
manager.register("a", job_a, interval_attr="k")
manager.register("b", job_b, interval_attr="k")
await manager.tick()
await asyncio.sleep(0) # yield so ensure_future tasks are scheduled
await asyncio.sleep(0) # second yield ensures tasks have started
await manager.shutdown()
assert sorted(results) == ["a", "b"]
@pytest.mark.asyncio
async def test_tick_adds_tasks_to_running_set(self) -> None:
"""tick() adds a task to _running_tasks for each due job."""
barrier = asyncio.Event()
manager = make_manager(interval=DUE_INTERVAL)
async def blocking_job() -> None:
await barrier.wait()
manager.register("job1", blocking_job, interval_attr="k")
await manager.tick()
await asyncio.sleep(0) # yield so ensure_future tasks are scheduled
await asyncio.sleep(0) # second yield ensures tasks have started
# Task is still running (barrier not set), so it must be in the set.
assert len(manager._running_tasks) == 1
barrier.set()
await manager.shutdown()
@pytest.mark.asyncio
async def test_tick_removes_task_from_running_set_on_completion(self) -> None:
"""Completed tasks are removed from _running_tasks automatically."""
manager = make_manager(interval=DUE_INTERVAL)
manager.register("job1", AsyncMock(), interval_attr="k")
await manager.tick()
await asyncio.sleep(0) # yield so ensure_future tasks are scheduled
await manager.shutdown()
assert len(manager._running_tasks) == 0
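The repeated double `await asyncio.sleep(0)` in these tests exists because `tick()` only schedules tasks; the yields give the event loop a chance to actually start them. A self-contained sketch of the dispatch-and-cleanup pattern the tests exercise (names are illustrative, not the RetentionManager internals):

```python
import asyncio

async def demo_tick() -> tuple[list[str], int]:
    """Wrap a due job in a task, track it in a running set, and rely on a
    done-callback to discard it from the set once it completes."""
    running: set[asyncio.Task] = set()
    results: list[str] = []

    async def job() -> None:
        results.append("ran")

    task = asyncio.ensure_future(job())
    running.add(task)
    task.add_done_callback(running.discard)  # auto-cleanup on completion

    await asyncio.sleep(0)           # yield so the scheduled task can start
    await asyncio.gather(*running)   # wait for in-flight work to finish
    await asyncio.sleep(0)           # let remaining done-callbacks run
    return results, len(running)
```

Running `asyncio.run(demo_tick())` yields `(["ran"], 0)`: the job executed and the tracking set is empty again, mirroring `test_tick_removes_task_from_running_set_on_completion`.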
# ------------------------------------------------------------------
# shutdown()
# ------------------------------------------------------------------
@pytest.mark.asyncio
async def test_shutdown_returns_immediately_when_no_tasks(self) -> None:
"""shutdown() completes without blocking when no tasks are running."""
manager = make_manager()
await manager.shutdown() # must return promptly without raising
@pytest.mark.asyncio
async def test_shutdown_waits_for_in_flight_task(self) -> None:
"""shutdown() blocks until a long-running job task finishes."""
barrier = asyncio.Event()
finished: list[bool] = []
manager = make_manager(interval=DUE_INTERVAL)
async def slow_job() -> None:
await barrier.wait()
finished.append(True)
manager.register("job1", slow_job, interval_attr="k")
await manager.tick()
await asyncio.sleep(0) # yield so ensure_future tasks are scheduled
await asyncio.sleep(0) # second yield ensures tasks have started
assert finished == [] # job still blocked
barrier.set()
await manager.shutdown()
assert finished == [True] # job completed before shutdown returned
@pytest.mark.asyncio
async def test_shutdown_waits_for_multiple_in_flight_tasks(self) -> None:
"""shutdown() waits for all concurrently running job tasks."""
barrier = asyncio.Event()
finished: list[str] = []
manager = make_manager(interval=DUE_INTERVAL)
async def slow_a() -> None:
await barrier.wait()
finished.append("a")
async def slow_b() -> None:
await barrier.wait()
finished.append("b")
manager.register("a", slow_a, interval_attr="k")
manager.register("b", slow_b, interval_attr="k")
await manager.tick()
await asyncio.sleep(0) # yield so ensure_future tasks are scheduled
await asyncio.sleep(0) # second yield ensures tasks have started
assert finished == []
barrier.set()
await manager.shutdown()
assert sorted(finished) == ["a", "b"]
@pytest.mark.asyncio
async def test_shutdown_does_not_raise_when_job_failed(self) -> None:
"""shutdown() completes without raising even if a job task raised an exception."""
manager = make_manager(interval=DUE_INTERVAL)
def failing_func() -> None:
raise RuntimeError("job error")
manager.register("job1", failing_func, interval_attr="k")
await manager.tick()
await asyncio.sleep(0) # yield so ensure_future tasks are scheduled
await manager.shutdown() # must not raise
@pytest.mark.asyncio
async def test_shutdown_clears_running_tasks_set(self) -> None:
"""_running_tasks is empty after shutdown() completes."""
manager = make_manager(interval=DUE_INTERVAL)
manager.register("job1", AsyncMock(), interval_attr="k")
await manager.tick()
await asyncio.sleep(0) # yield so ensure_future tasks are scheduled
await manager.shutdown()
assert manager._running_tasks == set()
@pytest.mark.asyncio
async def test_shutdown_timeout_returns_without_blocking(self) -> None:
"""shutdown() returns once the timeout elapses even if a job is still running."""
stuck = asyncio.Event() # never set -- job blocks forever
manager = RetentionManager(make_config_getter(DUE_INTERVAL), shutdown_timeout=0.05)
async def forever() -> None:
await stuck.wait()
manager.register("stuck", forever, interval_attr="k")
await manager.tick()
await asyncio.sleep(0)
await asyncio.sleep(0)
# Must return within the timeout, not block forever.
await manager.shutdown()
@pytest.mark.asyncio
async def test_shutdown_timeout_logs_error_for_pending_jobs(self) -> None:
"""An error is logged listing jobs still running after the timeout."""
stuck = asyncio.Event()
manager = RetentionManager(make_config_getter(DUE_INTERVAL), shutdown_timeout=0.05)
async def forever() -> None:
await stuck.wait()
manager.register("stuck_job", forever, interval_attr="k")
await manager.tick()
await asyncio.sleep(0)
await asyncio.sleep(0)
with patch.object(logger, "error") as mock_error:
await manager.shutdown()
assert mock_error.called, "Expected logger.error to be called on timeout"
# All positional args joined: the stuck job name must appear.
logged = str(mock_error.call_args_list)
assert "stuck_job" in logged
@pytest.mark.asyncio
async def test_shutdown_timeout_clears_running_tasks_set(self) -> None:
"""_running_tasks is cleared even when the timeout elapses."""
stuck = asyncio.Event()
manager = RetentionManager(make_config_getter(DUE_INTERVAL), shutdown_timeout=0.05)
async def forever() -> None:
await stuck.wait()
manager.register("stuck", forever, interval_attr="k")
await manager.tick()
await asyncio.sleep(0)
await asyncio.sleep(0)
await manager.shutdown()
assert manager._running_tasks == set()
@pytest.mark.asyncio
async def test_shutdown_no_error_logged_when_all_finish_in_time(self) -> None:
"""No error is logged when all tasks complete within the timeout."""
manager = RetentionManager(make_config_getter(DUE_INTERVAL), shutdown_timeout=5.0)
manager.register("job1", AsyncMock(), interval_attr="k")
await manager.tick()
await asyncio.sleep(0)
with patch.object(logger, "error") as mock_error:
await manager.shutdown()
mock_error.assert_not_called()
def test_init_stores_shutdown_timeout(self) -> None:
"""The shutdown_timeout passed to __init__ is stored on the instance."""
manager = RetentionManager(make_config_getter(), shutdown_timeout=99.0)
assert manager._shutdown_timeout == 99.0
def test_init_default_shutdown_timeout(self) -> None:
"""The default shutdown_timeout is 30 seconds."""
manager = RetentionManager(make_config_getter())
assert manager._shutdown_timeout == 30.0
# ------------------------------------------------------------------
# _run_job() -- state updates
# ------------------------------------------------------------------
@pytest.mark.asyncio
async def test_run_job_increments_run_count(self) -> None:
"""_run_job() increments run_count after each execution."""
manager = make_manager()
manager.register("job1", MagicMock(), interval_attr="k")
job = manager._jobs["job1"]
await manager._run_job(job)
await manager._run_job(job)
assert job.run_count == 2
@pytest.mark.asyncio
async def test_run_job_updates_last_run_at(self) -> None:
"""_run_job() sets last_run_at to a recent monotonic timestamp."""
manager = make_manager()
manager.register("job1", MagicMock(), interval_attr="k")
job = manager._jobs["job1"]
before = time.monotonic()
await manager._run_job(job)
assert job.last_run_at >= before
@pytest.mark.asyncio
async def test_run_job_updates_last_duration(self) -> None:
"""_run_job() records a non-negative last_duration."""
manager = make_manager()
manager.register("job1", MagicMock(), interval_attr="k")
job = manager._jobs["job1"]
await manager._run_job(job)
assert job.last_duration >= 0.0
@pytest.mark.asyncio
async def test_run_job_clears_is_running_on_success(self) -> None:
"""is_running is False after a successful job execution."""
manager = make_manager()
manager.register("job1", MagicMock(), interval_attr="k")
job = manager._jobs["job1"]
await manager._run_job(job)
assert job.is_running is False
@pytest.mark.asyncio
async def test_run_job_clears_last_error_on_success(self) -> None:
"""last_error is set to None after a successful execution."""
manager = make_manager()
manager.register("job1", MagicMock(), interval_attr="k")
job = manager._jobs["job1"]
job.last_error = "stale error"
await manager._run_job(job)
assert job.last_error is None
# ------------------------------------------------------------------
# _run_job() -- exception handling
# ------------------------------------------------------------------
@pytest.mark.asyncio
async def test_run_job_stores_exception_message(self) -> None:
"""last_error is set to the exception message when the job raises."""
manager = make_manager()
def failing_func() -> None:
raise RuntimeError("boom")
manager.register("job1", failing_func, interval_attr="k")
job = manager._jobs["job1"]
await manager._run_job(job)
assert job.last_error == "boom"
@pytest.mark.asyncio
async def test_run_job_still_updates_state_after_exception(self) -> None:
"""run_count and last_run_at are updated even when the job raises."""
manager = make_manager()
def failing_func() -> None:
raise RuntimeError("boom")
manager.register("job1", failing_func, interval_attr="k")
job = manager._jobs["job1"]
before = time.monotonic()
await manager._run_job(job)
assert job.run_count == 1
assert job.last_run_at >= before
assert job.is_running is False
@pytest.mark.asyncio
async def test_run_job_calls_sync_on_exception_handler(self) -> None:
"""A sync on_exception handler is called with the raised exception."""
manager = make_manager()
handler = MagicMock()
exc = RuntimeError("oops")
def failing_func() -> None:
raise exc
manager.register("job1", failing_func, interval_attr="k", on_exception=handler)
await manager._run_job(manager._jobs["job1"])
handler.assert_called_once_with(exc)
@pytest.mark.asyncio
async def test_run_job_calls_async_on_exception_handler(self) -> None:
"""An async on_exception handler is awaited with the raised exception."""
manager = make_manager()
handler = AsyncMock()
exc = RuntimeError("oops")
def failing_func() -> None:
raise exc
manager.register("job1", failing_func, interval_attr="k", on_exception=handler)
await manager._run_job(manager._jobs["job1"])
handler.assert_called_once_with(exc)
@pytest.mark.asyncio
async def test_run_job_no_on_exception_handler_does_not_raise(self) -> None:
"""A failing job without on_exception does not propagate the exception."""
manager = make_manager()
def failing_func() -> None:
raise RuntimeError("silent failure")
manager.register("job1", failing_func, interval_attr="k")
await manager._run_job(manager._jobs["job1"]) # must not raise
@pytest.mark.asyncio
async def test_run_job_on_exception_not_called_on_success(self) -> None:
"""on_exception is not called when the job succeeds."""
manager = make_manager()
handler = MagicMock()
manager.register("job1", MagicMock(), interval_attr="k", on_exception=handler)
await manager._run_job(manager._jobs["job1"])
handler.assert_not_called()
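The shutdown behavior exercised by the tests above can be sketched standalone. This is a minimal sketch of the expected contract, not the real `RetentionManager`: the class name `ShutdownSketch` and the `start()` helper are made up here, while `_running_tasks` and `_shutdown_timeout` mirror the attribute names the tests assert on.

```python
import asyncio
import logging

logger = logging.getLogger(__name__)


class ShutdownSketch:
    """Minimal sketch of the shutdown contract the tests above exercise."""

    def __init__(self, shutdown_timeout: float = 30.0) -> None:
        self._shutdown_timeout = shutdown_timeout
        self._running_tasks: set[asyncio.Task] = set()

    def start(self, name: str, coro) -> None:
        # Schedule the job and remember it so shutdown() can wait for it.
        task = asyncio.ensure_future(coro)
        task.set_name(name)
        self._running_tasks.add(task)

    async def shutdown(self) -> None:
        if self._running_tasks:
            # asyncio.wait() never re-raises job exceptions, so a failed
            # job cannot make shutdown() raise.
            done, pending = await asyncio.wait(
                self._running_tasks, timeout=self._shutdown_timeout
            )
            for task in done:
                if not task.cancelled():
                    task.exception()  # retrieve to avoid "never retrieved" warnings
            if pending:
                names = ", ".join(task.get_name() for task in pending)
                logger.error("Jobs still running after shutdown timeout: %s", names)
                for task in pending:
                    task.cancel()
        # The set is cleared even when the timeout elapsed.
        self._running_tasks.clear()
```

The key design point is `asyncio.wait(..., timeout=...)`: it returns `(done, pending)` instead of raising on timeout or on task failure, which is exactly why `shutdown()` can both tolerate failing jobs and return promptly when a job blocks forever.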

View File

@@ -2,11 +2,22 @@
import subprocess
import sys
from pathlib import Path
from typing import Optional, Union
import pytest
import yaml
from akkudoktoreos.core.version import _version_calculate, _version_hash
from akkudoktoreos.core.version import (
ALLOWED_SUFFIXES,
DIR_PACKAGE_ROOT,
EXCLUDED_DIR_PATTERNS,
EXCLUDED_FILES,
HashConfig,
_version_calculate,
_version_hash,
collect_files,
hash_files,
)
DIR_PROJECT_ROOT = Path(__file__).parent.parent
GET_VERSION_SCRIPT = DIR_PROJECT_ROOT / "scripts" / "get_version.py"
@@ -14,11 +25,166 @@ BUMP_DEV_SCRIPT = DIR_PROJECT_ROOT / "scripts" / "bump_dev_version.py"
UPDATE_SCRIPT = DIR_PROJECT_ROOT / "scripts" / "update_version.py"
# --- Git helpers ---
def get_git_tracked_files(repo_path: Path) -> Optional[set[Path]]:
"""Get set of all files tracked by git in the repository.
Returns None if not a git repository or git command fails.
"""
try:
result = subprocess.run(
["git", "ls-files"],
cwd=repo_path,
capture_output=True,
text=True,
check=True
)
# Convert relative paths to absolute paths
tracked_files = {
(repo_path / line.strip()).resolve()
for line in result.stdout.splitlines()
if line.strip()
}
return tracked_files
except (subprocess.CalledProcessError, FileNotFoundError):
return None
def is_git_repository(path: Path) -> bool:
"""Check if path is inside a git repository."""
try:
subprocess.run(
["git", "rev-parse", "--git-dir"],
cwd=path,
capture_output=True,
check=True
)
return True
except (subprocess.CalledProcessError, FileNotFoundError):
return False
def get_git_root(path: Path) -> Optional[Path]:
"""Get the root directory of the git repository containing path."""
try:
result = subprocess.run(
["git", "rev-parse", "--show-toplevel"],
cwd=path,
capture_output=True,
text=True,
check=True
)
return Path(result.stdout.strip())
except (subprocess.CalledProcessError, FileNotFoundError):
return None
def check_files_in_git(
files: list[Path],
base_path: Optional[Path] = None
) -> tuple[list[Path], list[Path]]:
"""Check which files are tracked by git.
Args:
files: List of files to check
base_path: Base path to check for git repository (uses first file's parent if None)
Returns:
Tuple of (tracked_files, untracked_files)
Example:
>>> files = collect_files(config)
>>> tracked, untracked = check_files_in_git(files)
>>> if untracked:
... print(f"Warning: {len(untracked)} files not in git")
"""
if not files:
return [], []
check_path = base_path or files[0].parent
if not is_git_repository(check_path):
    return [], files
git_root = get_git_root(check_path)
if not git_root:
return [], files
git_tracked = get_git_tracked_files(git_root)
if git_tracked is None:
return [], files
tracked = [f for f in files if f in git_tracked]
untracked = [f for f in files if f not in git_tracked]
return tracked, untracked
# --- Helper to create test files ---
def write_file(path: Path, content: str):
path.write_text(content, encoding="utf-8")
return path
# --- Test version calculation ---
def test_version_hash() -> None:
"""Test which files are used for version hash calculation."""
watched_paths = [DIR_PACKAGE_ROOT]
# Collect files
config = HashConfig(
paths=watched_paths,
allowed_suffixes=ALLOWED_SUFFIXES,
excluded_dir_patterns=EXCLUDED_DIR_PATTERNS,
excluded_files=EXCLUDED_FILES
)
files = collect_files(config)
hash_digest = hash_files(files)
# Check git
tracked, untracked = check_files_in_git(files, DIR_PACKAGE_ROOT)
tracked_files: list[Path] = tracked
untracked_files: list[Path] = untracked
if untracked_files:
error_msg = f"\n{'='*60}\n"
error_msg += "Version Hash Inspection\n"
error_msg += f"{'='*60}\n"
error_msg += f"Hash: {hash_digest}\n"
error_msg += f"Based on {len(files)} files:\n"
error_msg += f"OK: {len(tracked_files)} files tracked by git:\n"
for i, file_path in enumerate(files, 1):
try:
rel_path = file_path.relative_to(DIR_PACKAGE_ROOT)
status = ""
if file_path in untracked_files:
continue
elif file_path in tracked_files:
status = " [tracked]"
error_msg += f" {i:3d}. {rel_path}{status}\n"
except ValueError:
error_msg += f" {i:3d}. {file_path}\n"
error_msg += f"Warning: {len(untracked_files)} files not tracked by git:\n"
for i, file_path in enumerate(files, 1):
try:
rel_path = file_path.relative_to(DIR_PACKAGE_ROOT)
status = ""
if file_path in untracked_files:
status = " [NOT IN GIT]"
elif file_path in tracked_files:
continue
error_msg += f" {i:3d}. {rel_path}{status}\n"
except ValueError:
error_msg += f" {i:3d}. {file_path}\n"
error_msg += f"\n{'='*60}\n"
pytest.fail(error_msg)
# --- Test version helpers ---
def test_version_non_dev(monkeypatch):
@@ -38,7 +204,7 @@ def test_version_dev_precision_8(monkeypatch):
result = _version_calculate()
# compute expected suffix
# Compute expected suffix using the same logic as _version_calculate
hash_value = int(fake_hash, 16)
expected_digits = str(hash_value % (10 ** 8)).zfill(8)
@@ -60,12 +226,17 @@ def test_version_dev_precision_8_different_hash(monkeypatch):
result = _version_calculate()
# Compute expected suffix using the same logic as _version_calculate
hash_value = int(fake_hash, 16)
expected_digits = str(hash_value % (10 ** 8)).zfill(8)
expected = f"0.2.0.dev{expected_digits}"
assert result == expected
assert len(expected_digits) == 8
assert result.startswith("0.2.0.dev")
assert result == expected
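The dev-suffix arithmetic these tests verify can be reproduced standalone. Here `fake_hash` is a deliberately tiny made-up digest (the real one comes from `hash_files()` over the package sources), and `precision` names the 8-digit width used above.

```python
# fake_hash is a made-up digest for illustration; the real digest comes
# from hash_files() over the package sources.
fake_hash = "ff"
hash_value = int(fake_hash, 16)  # 255

precision = 8  # number of decimal digits kept for the .dev suffix
expected_digits = str(hash_value % (10 ** precision)).zfill(precision)
version = f"0.2.0.dev{expected_digits}"  # → "0.2.0.dev00000255"
```

Taking the hash modulo `10 ** precision` and zero-filling guarantees a fixed-width, purely numeric suffix, which keeps the result a valid PEP 440 dev version.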
# --- 1. Test get_version.py ---

View File

@@ -6,7 +6,7 @@ import pandas as pd
import pytest
from akkudoktoreos.core.cache import CacheFileStore
from akkudoktoreos.core.ems import get_ems
from akkudoktoreos.core.coreabc import get_ems
from akkudoktoreos.prediction.weatherbrightsky import WeatherBrightSky
from akkudoktoreos.utils.datetimeutil import to_datetime

View File

@@ -10,7 +10,7 @@ import pytest
from bs4 import BeautifulSoup
from akkudoktoreos.core.cache import CacheFileStore
from akkudoktoreos.core.ems import get_ems
from akkudoktoreos.core.coreabc import get_ems
from akkudoktoreos.prediction.weatherclearoutside import WeatherClearOutside
from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime

View File

@@ -1,11 +1,12 @@
import json
from pathlib import Path
import numpy.testing as npt
import pytest
from akkudoktoreos.core.ems import get_ems
from akkudoktoreos.core.coreabc import get_ems
from akkudoktoreos.prediction.weatherimport import WeatherImport
from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime
from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime, to_duration
DIR_TESTDATA = Path(__file__).absolute().parent.joinpath("testdata")
@@ -87,6 +88,7 @@ def test_invalid_provider(provider, config_eos, monkeypatch):
)
def test_import(provider, sample_import_1_json, start_datetime, from_file, config_eos):
"""Test fetching forecast from Import."""
key = "weather_temp_air"
ems_eos = get_ems()
ems_eos.set_start_datetime(to_datetime(start_datetime, in_timezone="Europe/Berlin"))
if from_file:
@@ -95,7 +97,7 @@ def test_import(provider, sample_import_1_json, start_datetime, from_file, confi
else:
config_eos.weather.provider_settings.WeatherImport.import_file_path = None
assert config_eos.weather.provider_settings.WeatherImport.import_file_path is None
provider.clear()
provider.delete_by_datetime(start_datetime=None, end_datetime=None)
# Call the method
provider.update_data()
@@ -104,16 +106,13 @@ def test_import(provider, sample_import_1_json, start_datetime, from_file, confi
assert provider.ems_start_datetime is not None
assert provider.total_hours is not None
assert compare_datetimes(provider.ems_start_datetime, ems_eos.start_datetime).equal
values = sample_import_1_json["weather_temp_air"]
value_datetime_mapping = provider.import_datetimes(ems_eos.start_datetime, len(values))
for i, mapping in enumerate(value_datetime_mapping):
assert i < len(provider.records)
expected_datetime, expected_value_index = mapping
expected_value = values[expected_value_index]
result_datetime = provider.records[i].date_time
result_value = provider.records[i]["weather_temp_air"]
# print(f"{i}: Expected: {expected_datetime}:{expected_value}")
# print(f"{i}: Result: {result_datetime}:{result_value}")
assert compare_datetimes(result_datetime, expected_datetime).equal
assert result_value == expected_value
expected_values = sample_import_1_json[key]
result_values = provider.key_to_array(
key=key,
start_datetime=provider.ems_start_datetime,
end_datetime=provider.ems_start_datetime + to_duration(f"{len(expected_values)} hours"),
interval=to_duration("1 hour"),
)
# Allow for some difference due to value calculation on DST change
npt.assert_allclose(result_values, expected_values, rtol=0.001)

View File

@@ -1,6 +1,6 @@
{
"general": {
"data_folder_path": null,
"data_folder_path": "__ANY__",
"data_output_subpath": "output",
"latitude": 52.5,
"longitude": 13.4