diff --git a/.env b/.env index 4f9e32a..99b3c44 100644 --- a/.env +++ b/.env @@ -11,7 +11,7 @@ DOCKER_COMPOSE_DATA_DIR=${HOME}/.local/share/net.akkudoktor.eos # ----------------------------------------------------------------------------- # Image / build # ----------------------------------------------------------------------------- -VERSION=0.2.0.dev84352035 +VERSION=0.2.0.dev58204789 PYTHON_VERSION=3.13.9 # ----------------------------------------------------------------------------- diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index b021526..a82b84a 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -60,6 +60,9 @@ jobs: - linux/arm64 exclude: ${{ fromJSON(needs.platform-excludes.outputs.excludes) }} steps: + - name: Checkout repository + uses: actions/checkout@v4 + - name: Prepare run: | platform=${{ matrix.platform }} @@ -114,6 +117,7 @@ jobs: id: build uses: docker/build-push-action@v6 with: + context: . platforms: ${{ matrix.platform }} labels: ${{ steps.meta.outputs.labels }} annotations: ${{ steps.meta.outputs.annotations }} diff --git a/CHANGELOG.md b/CHANGELOG.md index e8da6ac..f91fb90 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,7 @@ All notable changes to the akkudoktoreos project will be documented in this file The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## 0.3.0 (2025-12-??) +## 0.3.0 (2026-02-??) Adapters for Home Assistant and NodeRed integration are added. These adapters provide a simplified interface to these HEMS besides the standard REST interface. @@ -13,94 +13,126 @@ Akkudoktor-EOS can now be run as Home Assistant add-on and standalone. As Home Assistant add-on EOS uses ingress to fully integrate the EOSdash dashboard in Home Assistant. +The prediction and measurement data can now be backed by a database. 
The database allows +to keep historic prediction data and measurement data for long time without keeping +it in memory. The database supports backend selection, compression, incremental data load, +automatic data saving to storage, automatic vacuum and compaction. Two database backends +are integrated and can be configured, LMDB and SQLite. + In addition, bugs were fixed and new features were added. ### Feat +- add database support for measurements and historic prediction data. + The prediction and measurement data can now be backed by a database. The database allows + to keep historic prediction data and measurement data for long time without keeping + it in memory. Two database backends are integrated and can be configured, LMDB and SQLite. - add adapters for integrations - Adapters for Home Assistant and NodeRED integration are added. Akkudoktor-EOS can now be run as Home Assistant add-on and standalone. - As Home Assistant add-on EOS uses ingress to fully integrate the EOSdash dashboard in Home Assistant. - +- add make repeated task function + make_repeated_task allows to wrap a function to be repeated cyclically. - allow eos to be started with root permissions and drop priviledges - Home assistant starts all add-ons with root permissions. Eos now drops root permissions if an applicable user is defined by paramter --run_as_user. The docker image defines the user eos to be used. - - make eos supervise and monitor EOSdash - Eos now not only starts EOSdash but also monitors EOSdash during runtime and restarts EOSdash on fault. EOSdash logging is captured by EOS and forwarded to the EOS log to provide better visibility. - - add duration to string conversion - Make to_duration to also return the duration as string on request. ### Fixed +- config eos test setup + Make the config_eos fixture generate a new instance of the config_eos singleton. + Use correct env names to setup data folder path. 
+- startup with no config + Make cache and measurements complain about missing data path configuration but + do not bail out. +- soc data preparation and usage for genetic optimization. + Search for soc measurements 48 hours around the optimization start time. + Only clamp soc to maximum in battery device simulation. +- dashboard bailout on zero value solution display + Do not use zero values to calculate the chart values adjustment for display. +- openapi generation script + Make the script also replace data_folder_path and data_output_path to hide + real (test) environment paths. - development version scheme - The development versioning scheme is adaptet to fit to docker and home assistant expectations. The new scheme is x.y.z and x.y.z.dev. Hash is only digits as expected by home assistant. Development version is appended by .dev as expected by docker. - - use mean value in interval on resampling for array - When downsampling data use the mean value of all values within the new sampling interval. - - default battery ev soc and appliance wh - Make the genetic simulation return default values for the battery SoC, electric vehicle SoC and appliance load if these assets are not used. - - import json string - Strip outer quotes from JSON strings on import to be compliant to json.loads() expectation. - - default interval definition for import data - Default interval must be defined in lowercase human definition to be accepted by pendulum. - - clearoutside schema change ### Chore +- removed index based data sequence access + Index based data sequence access does not make sense as the sequence can be backed + by the database. The sequence is now purely time series data. +- refactor eos startup to avoid module import startup + Avoid module import initialisation especially of the EOS configuration. 
+ Config mutation, singleton initialization, logging setup, argparse parsing, + background task definitions depending on config and environment-dependent behavior + is now done at function startup. +- introduce retention manager + A single long-running background task that owns the scheduling of all periodic + server-maintenance jobs (cache cleanup, DB autosave, …) +- canonicalize timezone name for UTC + Timezone names that are semantically identical to UTC are canonicalized to UTC. +- extend config file migration for default value handling +- extend datetime util test cases +- make version test check for untracked files + Check for files that are not tracked by git. Version calculation will be + wrong if these files will not be committed. +- bump pandas to 3.0.0 + Pandas 3.0 now performs inference on the appropriate resolution (a.k.a. unit) + for the output dtype which may become datetime64[us] (before it was ns). Also + numeric dtype detection is now more strict which needs a different detection for + numerics. +- bump pydantic-settings to 2.12.0 + pydantic-settings 2.12.0 under pytest creates a different behaviour. The tests + were adapted and a workaround was introduced. Also ConfigEOS was adapted + to allow for fine grain initialization control to be able to switch + off certain settings such as file settings during test. +- remove sci learn kit from dependencies + The sci learn kit is not strictly necessary as long as we have scipy. +- add documentation mode guarding for sphinx autosummary + Sphinx autosummary executes functions. Prevent exceptions in case of pure doc + mode. +- adapt docker-build CI workflow to stricter GitHub handling - Use info logging to report missing optimization parameters - In parameter preparation for automatic optimization an error was logged for missing paramters. Log is now down using the info level. 
- - make EOSdash use the EOS data directory for file import/ export - EOSdash use the EOS data directory for file import/ export by default. This allows to use the configuration import/ export function also within docker images. - - improve EOSdash config tab display - Improve display of JSON code and add more forms for config value update. - - make docker image file system layout similar to home assistant - Only use /data directory for persistent data. This is handled as a docker volume. The /data volume is mapped to ~/.local/share/net.akkudoktor.eos if using docker compose. - - add home assistant add-on development environment - Add VSCode devcontainer and task definition for home assistant add-on development. - - improve documentation ## 0.2.0 (2025-11-09) diff --git a/config.yaml b/config.yaml index d747830..2a93cfd 100644 --- a/config.yaml +++ b/config.yaml @@ -6,7 +6,7 @@ # the root directory (no add-on folder as usual). name: "Akkudoktor-EOS" -version: "0.2.0.dev84352035" +version: "0.2.0.dev58204789" slug: "eos" description: "Akkudoktor-EOS add-on" url: "https://github.com/Akkudoktor-EOS/EOS" diff --git a/docs/_generated/config.md b/docs/_generated/config.md index 42a8a61..ab268d9 100644 --- a/docs/_generated/config.md +++ b/docs/_generated/config.md @@ -4,6 +4,7 @@ ../_generated/configadapter.md ../_generated/configcache.md +../_generated/configdatabase.md ../_generated/configdevices.md ../_generated/configelecprice.md ../_generated/configems.md diff --git a/docs/_generated/configadapter.md b/docs/_generated/configadapter.md index 753ab40..f7d307b 100644 --- a/docs/_generated/configadapter.md +++ b/docs/_generated/configadapter.md @@ -10,7 +10,7 @@ | homeassistant | `EOS_ADAPTER__HOMEASSISTANT` | `HomeAssistantAdapterCommonSettings` | `rw` | `required` | Home Assistant adapter settings. | | nodered | `EOS_ADAPTER__NODERED` | `NodeREDAdapterCommonSettings` | `rw` | `required` | NodeRED adapter settings. 
| | provider | `EOS_ADAPTER__PROVIDER` | `Optional[list[str]]` | `rw` | `None` | List of adapter provider id(s) of provider(s) to be used. | -| providers | | `list[str]` | `ro` | `N/A` | Available electricity price provider ids. | +| providers | | `list[str]` | `ro` | `N/A` | Available adapter provider ids. | ::: @@ -33,10 +33,7 @@ "pv_production_emr_entity_ids": null, "device_measurement_entity_ids": null, "device_instruction_entity_ids": null, - "solution_entity_ids": null, - "homeassistant_entity_ids": [], - "eos_solution_entity_ids": [], - "eos_device_instruction_entity_ids": [] + "solution_entity_ids": null }, "nodered": { "host": "127.0.0.1", diff --git a/docs/_generated/configcache.md b/docs/_generated/configcache.md index 9a8ff8b..89129d4 100644 --- a/docs/_generated/configcache.md +++ b/docs/_generated/configcache.md @@ -7,7 +7,7 @@ | Name | Environment Variable | Type | Read-Only | Default | Description | | ---- | -------------------- | ---- | --------- | ------- | ----------- | -| cleanup_interval | `EOS_CACHE__CLEANUP_INTERVAL` | `float` | `rw` | `300` | Intervall in seconds for EOS file cache cleanup. | +| cleanup_interval | `EOS_CACHE__CLEANUP_INTERVAL` | `float` | `rw` | `300.0` | Intervall in seconds for EOS file cache cleanup. | | subpath | `EOS_CACHE__SUBPATH` | `Optional[pathlib.Path]` | `rw` | `cache` | Sub-path for the EOS cache data directory. | ::: diff --git a/docs/_generated/configdatabase.md b/docs/_generated/configdatabase.md new file mode 100644 index 0000000..01a7a70 --- /dev/null +++ b/docs/_generated/configdatabase.md @@ -0,0 +1,72 @@ +## Configuration model for database settings + +Attributes: + provider: Optional provider identifier (e.g. "LMDB"). + max_records_in_memory: Maximum records kept in memory before auto-save. + auto_save: Whether to auto-save when threshold exceeded. + batch_size: Batch size for batch operations. 
+ + +:::{table} database +:widths: 10 20 10 5 5 30 +:align: left + +| Name | Environment Variable | Type | Read-Only | Default | Description | +| ---- | -------------------- | ---- | --------- | ------- | ----------- | +| autosave_interval_sec | `EOS_DATABASE__AUTOSAVE_INTERVAL_SEC` | `Optional[int]` | `rw` | `10` | Automatic saving interval [seconds]. +Set to None to disable automatic saving. | +| batch_size | `EOS_DATABASE__BATCH_SIZE` | `int` | `rw` | `100` | Number of records to process in batch operations. | +| compaction_interval_sec | `EOS_DATABASE__COMPACTION_INTERVAL_SEC` | `Optional[int]` | `rw` | `604800` | Interval in between automatic tiered compaction runs [seconds]. +Compaction downsamples old records to reduce storage while retaining coverage. Set to None to disable automatic compaction. | +| compression_level | `EOS_DATABASE__COMPRESSION_LEVEL` | `int` | `rw` | `9` | Compression level for database record data. | +| initial_load_window_h | `EOS_DATABASE__INITIAL_LOAD_WINDOW_H` | `Optional[int]` | `rw` | `None` | Specifies the default duration of the initial load window when loading records from the database, in hours. If set to None, the full available range is loaded. The window is centered around the current time by default, unless a different center time is specified. Different database namespaces may define their own default windows. | +| keep_duration_h | `EOS_DATABASE__KEEP_DURATION_H` | `Optional[int]` | `rw` | `None` | Default maximum duration records shall be kept in database [hours, none]. +None indicates forever. Database namespaces may have diverging definitions. | +| provider | `EOS_DATABASE__PROVIDER` | `Optional[str]` | `rw` | `None` | Database provider id of provider to be used. | +| providers | | `List[str]` | `ro` | `N/A` | Return available database provider ids. 
| +::: + + + +**Example Input** + + + +```json + { + "database": { + "provider": "LMDB", + "compression_level": 0, + "initial_load_window_h": 48, + "keep_duration_h": 48, + "autosave_interval_sec": 5, + "compaction_interval_sec": 604800, + "batch_size": 100 + } + } +``` + + + +**Example Output** + + + +```json + { + "database": { + "provider": "LMDB", + "compression_level": 0, + "initial_load_window_h": 48, + "keep_duration_h": 48, + "autosave_interval_sec": 5, + "compaction_interval_sec": 604800, + "batch_size": 100, + "providers": [ + "LMDB", + "SQLite" + ] + } + } +``` + diff --git a/docs/_generated/configdevices.md b/docs/_generated/configdevices.md index 2560feb..e372090 100644 --- a/docs/_generated/configdevices.md +++ b/docs/_generated/configdevices.md @@ -50,19 +50,7 @@ 1.0 ], "min_soc_percentage": 0, - "max_soc_percentage": 100, - "measurement_key_soc_factor": "battery1-soc-factor", - "measurement_key_power_l1_w": "battery1-power-l1-w", - "measurement_key_power_l2_w": "battery1-power-l2-w", - "measurement_key_power_l3_w": "battery1-power-l3-w", - "measurement_key_power_3_phase_sym_w": "battery1-power-3-phase-sym-w", - "measurement_keys": [ - "battery1-soc-factor", - "battery1-power-l1-w", - "battery1-power-l2-w", - "battery1-power-l3-w", - "battery1-power-3-phase-sym-w" - ] + "max_soc_percentage": 100 } ], "max_batteries": 1, @@ -89,19 +77,7 @@ 1.0 ], "min_soc_percentage": 0, - "max_soc_percentage": 100, - "measurement_key_soc_factor": "battery1-soc-factor", - "measurement_key_power_l1_w": "battery1-power-l1-w", - "measurement_key_power_l2_w": "battery1-power-l2-w", - "measurement_key_power_l3_w": "battery1-power-l3-w", - "measurement_key_power_3_phase_sym_w": "battery1-power-3-phase-sym-w", - "measurement_keys": [ - "battery1-soc-factor", - "battery1-power-l1-w", - "battery1-power-l2-w", - "battery1-power-l3-w", - "battery1-power-3-phase-sym-w" - ] + "max_soc_percentage": 100 } ], "max_electric_vehicles": 1, diff --git a/docs/_generated/configems.md 
b/docs/_generated/configems.md index fe336cd..3930576 100644 --- a/docs/_generated/configems.md +++ b/docs/_generated/configems.md @@ -7,7 +7,7 @@ | Name | Environment Variable | Type | Read-Only | Default | Description | | ---- | -------------------- | ---- | --------- | ------- | ----------- | -| interval | `EOS_EMS__INTERVAL` | `Optional[float]` | `rw` | `None` | Intervall in seconds between EOS energy management runs. | +| interval | `EOS_EMS__INTERVAL` | `float` | `rw` | `300.0` | Intervall between EOS energy management runs [seconds]. | | mode | `EOS_EMS__MODE` | `Optional[akkudoktoreos.core.emsettings.EnergyManagementMode]` | `rw` | `None` | Energy management mode [OPTIMIZATION | PREDICTION]. | | startup_delay | `EOS_EMS__STARTUP_DELAY` | `float` | `rw` | `5` | Startup delay in seconds for EOS energy management runs. | ::: diff --git a/docs/_generated/configexample.md b/docs/_generated/configexample.md index 1b1a09a..b3ea38c 100644 --- a/docs/_generated/configexample.md +++ b/docs/_generated/configexample.md @@ -15,10 +15,7 @@ "pv_production_emr_entity_ids": null, "device_measurement_entity_ids": null, "device_instruction_entity_ids": null, - "solution_entity_ids": null, - "homeassistant_entity_ids": [], - "eos_solution_entity_ids": [], - "eos_device_instruction_entity_ids": [] + "solution_entity_ids": null }, "nodered": { "host": "127.0.0.1", @@ -29,6 +26,15 @@ "subpath": "cache", "cleanup_interval": 300.0 }, + "database": { + "provider": "LMDB", + "compression_level": 0, + "initial_load_window_h": 48, + "keep_duration_h": 48, + "autosave_interval_sec": 5, + "compaction_interval_sec": 604800, + "batch_size": 100 + }, "devices": { "batteries": [ { @@ -53,19 +59,7 @@ 1.0 ], "min_soc_percentage": 0, - "max_soc_percentage": 100, - "measurement_key_soc_factor": "battery1-soc-factor", - "measurement_key_power_l1_w": "battery1-power-l1-w", - "measurement_key_power_l2_w": "battery1-power-l2-w", - "measurement_key_power_l3_w": "battery1-power-l3-w", - 
"measurement_key_power_3_phase_sym_w": "battery1-power-3-phase-sym-w", - "measurement_keys": [ - "battery1-soc-factor", - "battery1-power-l1-w", - "battery1-power-l2-w", - "battery1-power-l3-w", - "battery1-power-3-phase-sym-w" - ] + "max_soc_percentage": 100 } ], "max_batteries": 1, @@ -92,19 +86,7 @@ 1.0 ], "min_soc_percentage": 0, - "max_soc_percentage": 100, - "measurement_key_soc_factor": "battery1-soc-factor", - "measurement_key_power_l1_w": "battery1-power-l1-w", - "measurement_key_power_l2_w": "battery1-power-l2-w", - "measurement_key_power_l3_w": "battery1-power-l3-w", - "measurement_key_power_3_phase_sym_w": "battery1-power-3-phase-sym-w", - "measurement_keys": [ - "battery1-soc-factor", - "battery1-power-l1-w", - "battery1-power-l2-w", - "battery1-power-l3-w", - "battery1-power-3-phase-sym-w" - ] + "max_soc_percentage": 100 } ], "max_electric_vehicles": 1, @@ -138,8 +120,8 @@ } }, "general": { - "version": "0.2.0.dev84352035", - "data_folder_path": null, + "version": "0.2.0.dev58204789", + "data_folder_path": "/home/user/.local/share/net.akkudoktoreos.net", "data_output_subpath": "output", "latitude": 52.52, "longitude": 13.405 @@ -157,6 +139,7 @@ "file_level": "TRACE" }, "measurement": { + "historic_hours": 17520, "load_emr_keys": [ "load0_emr" ], diff --git a/docs/_generated/configgeneral.md b/docs/_generated/configgeneral.md index 8b42f74..1a102b4 100644 --- a/docs/_generated/configgeneral.md +++ b/docs/_generated/configgeneral.md @@ -9,14 +9,14 @@ | ---- | -------------------- | ---- | --------- | ------- | ----------- | | config_file_path | | `Optional[pathlib.Path]` | `ro` | `N/A` | Path to EOS configuration file. | | config_folder_path | | `Optional[pathlib.Path]` | `ro` | `N/A` | Path to EOS configuration directory. | -| data_folder_path | `EOS_GENERAL__DATA_FOLDER_PATH` | `Optional[pathlib.Path]` | `rw` | `None` | Path to EOS data directory. 
| +| data_folder_path | `EOS_GENERAL__DATA_FOLDER_PATH` | `Path` | `rw` | `required` | Path to EOS data folder. | | data_output_path | | `Optional[pathlib.Path]` | `ro` | `N/A` | Computed data_output_path based on data_folder_path. | -| data_output_subpath | `EOS_GENERAL__DATA_OUTPUT_SUBPATH` | `Optional[pathlib.Path]` | `rw` | `output` | Sub-path for the EOS output data directory. | -| home_assistant_addon | | `bool` | `ro` | `N/A` | EOS is running as home assistant add-on. | +| data_output_subpath | `EOS_GENERAL__DATA_OUTPUT_SUBPATH` | `Optional[pathlib.Path]` | `rw` | `output` | Sub-path for the EOS output data folder. | +| home_assistant_addon | `EOS_GENERAL__HOME_ASSISTANT_ADDON` | `bool` | `rw` | `required` | EOS is running as home assistant add-on. | | latitude | `EOS_GENERAL__LATITUDE` | `Optional[float]` | `rw` | `52.52` | Latitude in decimal degrees between -90 and 90. North is positive (ISO 19115) (°) | | longitude | `EOS_GENERAL__LONGITUDE` | `Optional[float]` | `rw` | `13.405` | Longitude in decimal degrees within -180 to 180 (°) | | timezone | | `Optional[str]` | `ro` | `N/A` | Computed timezone based on latitude and longitude. | -| version | `EOS_GENERAL__VERSION` | `str` | `rw` | `0.2.0.dev84352035` | Configuration file version. Used to check compatibility. | +| version | `EOS_GENERAL__VERSION` | `str` | `rw` | `0.2.0.dev58204789` | Configuration file version. Used to check compatibility. 
| ::: @@ -28,8 +28,8 @@ ```json { "general": { - "version": "0.2.0.dev84352035", - "data_folder_path": null, + "version": "0.2.0.dev58204789", + "data_folder_path": "/home/user/.local/share/net.akkudoktoreos.net", "data_output_subpath": "output", "latitude": 52.52, "longitude": 13.405 @@ -46,16 +46,15 @@ ```json { "general": { - "version": "0.2.0.dev84352035", - "data_folder_path": null, + "version": "0.2.0.dev58204789", + "data_folder_path": "/home/user/.local/share/net.akkudoktoreos.net", "data_output_subpath": "output", "latitude": 52.52, "longitude": 13.405, "timezone": "Europe/Berlin", - "data_output_path": null, + "data_output_path": "/home/user/.local/share/net.akkudoktoreos.net/output", "config_folder_path": "/home/user/.config/net.akkudoktoreos.net", - "config_file_path": "/home/user/.config/net.akkudoktoreos.net/EOS.config.json", - "home_assistant_addon": false + "config_file_path": "/home/user/.config/net.akkudoktoreos.net/EOS.config.json" } } ``` diff --git a/docs/_generated/configmeasurement.md b/docs/_generated/configmeasurement.md index 158df67..303e679 100644 --- a/docs/_generated/configmeasurement.md +++ b/docs/_generated/configmeasurement.md @@ -9,6 +9,7 @@ | ---- | -------------------- | ---- | --------- | ------- | ----------- | | grid_export_emr_keys | `EOS_MEASUREMENT__GRID_EXPORT_EMR_KEYS` | `Optional[list[str]]` | `rw` | `None` | The keys of the measurements that are energy meter readings of energy export to grid [kWh]. | | grid_import_emr_keys | `EOS_MEASUREMENT__GRID_IMPORT_EMR_KEYS` | `Optional[list[str]]` | `rw` | `None` | The keys of the measurements that are energy meter readings of energy import from grid [kWh]. | +| historic_hours | `EOS_MEASUREMENT__HISTORIC_HOURS` | `Optional[int]` | `rw` | `17520` | Number of hours into the past for measurement data | | keys | | `list[str]` | `ro` | `N/A` | The keys of the measurements that can be stored. 
| | load_emr_keys | `EOS_MEASUREMENT__LOAD_EMR_KEYS` | `Optional[list[str]]` | `rw` | `None` | The keys of the measurements that are energy meter readings of a load [kWh]. | | pv_production_emr_keys | `EOS_MEASUREMENT__PV_PRODUCTION_EMR_KEYS` | `Optional[list[str]]` | `rw` | `None` | The keys of the measurements that are PV production energy meter readings [kWh]. | @@ -23,6 +24,7 @@ ```json { "measurement": { + "historic_hours": 17520, "load_emr_keys": [ "load0_emr" ], @@ -48,6 +50,7 @@ ```json { "measurement": { + "historic_hours": 17520, "load_emr_keys": [ "load0_emr" ], diff --git a/docs/_generated/openapi.md b/docs/_generated/openapi.md index da82a6b..2e58177 100644 --- a/docs/_generated/openapi.md +++ b/docs/_generated/openapi.md @@ -1,6 +1,6 @@ # Akkudoktor-EOS -**Version**: `v0.2.0.dev84352035` +**Version**: `v0.2.0.dev58204789` **Description**: This project provides a comprehensive solution for simulating and optimizing an energy system based on renewable energy sources. With a focus on photovoltaic (PV) systems, battery storage (batteries), load management (consumer requirements), heat pumps, electric vehicles, and consideration of electricity price data, this system enables forecasting and optimization of energy flow and costs over a specified period. @@ -338,6 +338,56 @@ Returns: --- +## GET /v1/admin/database/stats + + +**Links**: [local](http://localhost:8503/docs#/default/fastapi_admin_database_stats_get_v1_admin_database_stats_get), [eos](https://petstore3.swagger.io/?url=https://raw.githubusercontent.com/Akkudoktor-EOS/EOS/refs/heads/main/openapi.json#/default/fastapi_admin_database_stats_get_v1_admin_database_stats_get) + + +Fastapi Admin Database Stats Get + + +```python +""" +Get statistics from database. 
+ +Returns: + data (dict): The database statistics +""" +``` + + +**Responses**: + +- **200**: Successful Response + +--- + +## POST /v1/admin/database/vacuum + + +**Links**: [local](http://localhost:8503/docs#/default/fastapi_admin_database_vacuum_post_v1_admin_database_vacuum_post), [eos](https://petstore3.swagger.io/?url=https://raw.githubusercontent.com/Akkudoktor-EOS/EOS/refs/heads/main/openapi.json#/default/fastapi_admin_database_vacuum_post_v1_admin_database_vacuum_post) + + +Fastapi Admin Database Vacuum Post + + +```python +""" +Remove old records from database. + +Returns: + data (dict): The database stats after removal of old records. +""" +``` + + +**Responses**: + +- **200**: Successful Response + +--- + ## POST /v1/admin/server/restart diff --git a/docs/akkudoktoreos/database.md b/docs/akkudoktoreos/database.md new file mode 100644 index 0000000..e8b0d36 --- /dev/null +++ b/docs/akkudoktoreos/database.md @@ -0,0 +1,599 @@ +% SPDX-License-Identifier: Apache-2.0 +(database-page)= + +# Database + +## Overview + +The EOS database system provides a flexible, pluggable persistence layer for time-series data +records with automatic lazy loading, dirty tracking, and multi-backend support. The architecture +separates the abstract database interface from concrete storage implementations, allowing seamless +switching between LMDB and SQLite backends. 
+ +## Architecture + +### Three-Layer Design + +**Abstract Interface Layer** (`DatabaseABC`) + +- Defines the contract for all database operations +- Provides compression/decompression utilities +- Backend-agnostic API + +**Backend Implementation Layer** (`DatabaseBackendABC`) + +- Concrete implementations: `LMDBDatabase`, `SQLiteDatabase` +- Singleton pattern ensures single instance per backend +- Thread-safe operations via internal locking + +**Record Protocol Layer** (`DatabaseRecordProtocolMixin`) + +- Manages in-memory record lifecycle +- Implements lazy loading strategies +- Handles dirty tracking and autosave + +## Configuration + +### Database Settings (`DatabaseCommonSettings`) + +```python +provider: Optional[str] = None # "LMDB" or "SQLite" +compression_level: int = 9 # 0-9, gzip compression +initial_load_window_h: Optional[int] = None # Hours, None = full load +keep_duration_h: Optional[int] = None # Retention period +autosave_interval_sec: Optional[int] = None # Auto-flush interval +compaction_interval_sec: Optional[int] = 604800 # Compaction interval +batch_size: int = 100 # Batch operation size +``` + +### User Configuration Guide + +This section explains what each setting does in practical terms and gives +concrete recommendations for common deployment scenarios. + +#### `provider` — choosing a backend + +Set `provider` to `"LMDB"` or `"SQLite"`. Leave it `None` only during +development or unit testing — with `None` set, nothing is persisted to disk and +all data is lost on restart. + +**Use LMDB** for a long-running home server that records data continuously. It +is significantly faster for high-frequency writes and range reads because it +uses memory-mapped files. The trade-off is that it pre-allocates a large file +on disk (default 10 GB) even when mostly empty. + +**Use SQLite** when disk space is constrained, for portable single-file +deployments, or when you want to inspect or manipulate the database with +standard SQL tools. 
SQLite is slightly slower for bulk writes but perfectly +adequate for home energy data volumes. + +**Do not** switch backends while data exists in the old backend — records are +not migrated automatically. If you need to switch, vacuum the old database +first, export your data, then reconfigure. + +#### `compression_level` — storage size vs. CPU + +Values range from `0` (no compression) to `9` (maximum compression). The default of `9` is +appropriate for most deployments: home energy time-series data compresses very well (often +60–80 % reduction) and the CPU overhead is negligible on modern hardware. + +**Set to `0`** only if you are running on very constrained hardware (e.g. a single-core ARM +board at full load) and storage space is not a concern. + +**Do not** change this setting after data has been written — the database stores each record +with the compression level active at write time and auto-detects the format on read, so mixed +levels are fine technically, but you will not reclaim space from already-written records until +they are rewritten by compaction. + +#### `initial_load_window_h` — startup memory usage + +Controls how much history is loaded into memory when the application first accesses a namespace. + +**Set a window** (e.g. `48`) on systems with limited RAM or large databases. Only the most +recent 48 hours are loaded immediately; older data is fetched on demand if a query reaches +outside that window. + +**Leave as `None`** (the default) on well-resourced systems or when you need guaranteed +access to all history from the first query. Full load is simpler and avoids the small latency +spike of incremental loads. + +**Do not** set this to a very small value (e.g. `1`) if your forecasting or reporting queries +routinely look back further — every out-of-window query triggers a database read, and many +small reads are slower than one full load. + +#### `keep_duration_h` — data retention + +Sets the age limit (in hours) for the vacuum operation. 
Records older than +`max_timestamp - keep_duration_h` are permanently deleted when vacuum runs. + +**Set this** to match your actual analysis needs. If your forecast models only look back 7 days, +keeping 14 days (`336`) gives a comfortable safety margin without accumulating indefinitely. + +**Leave as `None`** only if you have a strong archival requirement and understand that the +database will grow without bound. Even with compaction reducing resolution, old data is not +deleted unless vacuum runs with a retention limit. + +**Do not** set `keep_duration_h` shorter than the oldest data your forecast or reporting +queries ever request — vacuum is permanent and irreversible. + +#### `autosave_interval_sec` — write durability + +Controls how often dirty (modified) records are flushed to disk automatically, in seconds. + +**Set to a low value** (e.g. `10`–`30`) on a system that could lose power unexpectedly, +such as a Raspberry Pi without a UPS. A power cut between autosaves loses that window of data. + +**Set to a higher value** (e.g. `300`) on stable systems to reduce write amplification. Each +autosave is a full flush of all dirty records, so frequent saves on large dirty sets are +more expensive. + +**Leave as `None`** only if you call `db_save_records()` manually at appropriate points in +your application code. With `None`, data written since the last manual save is lost on crash. + +#### `compaction_interval_sec` — automatic tiered downsampling + +Controls how often the compaction maintenance job runs, in seconds. The default is +604 800 (one week). Set to `None` to disable automatic compaction entirely. 
+ +Compaction applies a tiered downsampling policy to old records: + +- Records older than **2 hours** are downsampled to **15-minute** resolution +- Records older than **14 days** are downsampled to **1-hour** resolution + +This reduces storage and speeds up range queries on historical data while preserving full +resolution for recent data where it matters most. Each tier is processed incrementally — +only the window since the last compaction run is examined, so weekly runs are fast regardless +of total history length. + +**Leave at the default weekly interval** for most deployments. Compaction is idempotent and +cheap when run frequently on small new windows. + +**Set to a shorter interval** (e.g. `86400`, daily) if your device records at very high +frequency (sub-minute) and disk space is a concern. + +**Set to `None`** only if you have a custom retention policy and manage downsampling manually, +or if you store data that must not be averaged (e.g. raw event logs where mean resampling +would be meaningless). + +**Do not** set the interval shorter than `autosave_interval_sec` — compaction reads from the +backend and a record that has not been saved yet will not be visible to it. + +**Interaction with vacuum:** compaction and vacuum are complementary. Compaction reduces +resolution of old data; vacuum deletes it entirely past `keep_duration_h`. The recommended +pipeline is: compaction runs first (weekly), then vacuum runs immediately after. This means +vacuum always operates on already-downsampled data, which is faster and produces cleaner +storage boundaries. 
+ +### Recommended Configurations by Scenario + +#### Home server, typical (Raspberry Pi 4, SSD) + +```python +provider = "LMDB" +compression_level = 9 +initial_load_window_h = 48 +keep_duration_h = 720 # 30 days +autosave_interval_sec = 30 +compaction_interval_sec = 604800 # weekly +``` + +#### Home server, low storage (Raspberry Pi Zero, SD card) + +```python +provider = "SQLite" +compression_level = 9 +initial_load_window_h = 24 +keep_duration_h = 168 # 7 days +autosave_interval_sec = 60 +compaction_interval_sec = 86400 # daily — reclaim space faster +``` + +#### Development / testing + +```python +provider = "SQLite" # or None for fully in-memory +compression_level = 0 # faster without compression overhead +initial_load_window_h = None # always load everything +keep_duration_h = None # never vacuum automatically +autosave_interval_sec = None # manual saves only +compaction_interval_sec = None # disable compaction +``` + +#### High-frequency recording (sub-minute intervals) + +```python +provider = "LMDB" +compression_level = 9 +initial_load_window_h = 24 +keep_duration_h = 336 # 14 days +autosave_interval_sec = 10 +compaction_interval_sec = 86400 # daily — essential at high frequency +``` + +## Storage Backends + +### LMDB Backend + +**Characteristics:** + +- Memory-mapped file database +- Native namespace support via DBIs (Database Instances) +- High-performance reads with MVCC +- Configurable map size (default: 10 GB) + +**Configuration:** + +```python +map_size: int = 10 * 1024 * 1024 * 1024 # 10 GB +writemap=True, map_async=True # Performance optimizations +max_dbs=128 # Maximum namespaces +``` + +**File Structure:** + +```text +data_folder_path/ +└── db/ + └── lmdbdatabase/ + ├── data.mdb + └── lock.mdb +``` + +### SQLite Backend + +**Characteristics:** + +- Single-file relational database +- Namespace emulation via `namespace` column +- ACID transactions with autocommit mode +- Cross-platform compatibility + +**Schema:** + +```sql +CREATE TABLE records ( 
+ namespace TEXT NOT NULL DEFAULT '', + key BLOB NOT NULL, + value BLOB NOT NULL, + PRIMARY KEY (namespace, key) +); + +CREATE TABLE metadata ( + namespace TEXT PRIMARY KEY, + value BLOB +); +``` + +**File Structure:** + +```text +data_folder_path/ +└── db/ + └── sqlitedatabase/ + └── data.db +``` + +## Timestamp System + +### DatabaseTimestamp + +All records are indexed by UTC timestamps in sortable ISO 8601 format: + +```python +DatabaseTimestamp.from_datetime(dt: DateTime) -> "20241027T123456[Z]" +``` + +**Properties:** +- Always stored in UTC (timezone-aware required) +- Lexicographically sortable +- Bijective conversion to/from `pendulum.DateTime` +- Second-level precision + +### Unbounded Sentinels + +```python +UNBOUND_START # Smaller than any timestamp +UNBOUND_END # Greater than any timestamp +``` + +Used for open-ended range queries without special-casing `None`. + +## Lazy Loading Strategy + +### Three-Phase Loading + +The system uses a progressive loading model to minimize memory footprint: + +#### **Phase 0: NONE** + +- No records loaded +- First query triggers either: + - Initial window load (if `initial_load_window_h` configured) + - Full database load (if `initial_load_window_h = None`) + - Targeted range load (if explicit range requested) + +#### **Phase 1: INITIAL** + +- Partial time window loaded +- `_db_loaded_range` tracks coverage: `[start_timestamp, end_timestamp)` +- Out-of-window queries trigger incremental expansion: + - Left expansion: load records before current window + - Right expansion: load records after current window +- Unbounded queries escalate to FULL + +#### **Phase 2: FULL** + +- All database records in memory +- No further database access needed +- `_db_loaded_range` spans entire dataset + +### Boundary Extension + +When loading a range `[start, end)`, the system automatically extends boundaries to include: +- **First record before** `start` (for interpolation/context) +- **First record at or after** `end` (for closing 
boundary) + +This prevents additional database lookups during nearest-neighbor searches. + +## Namespace Support + +Namespaces provide logical isolation within a single database instance: + +```python +# LMDB: uses native DBIs +db.save_records(records, namespace="measurement") + +# SQLite: uses namespace column +SELECT * FROM records WHERE namespace='measurement' +``` + +**Default Namespace:** +- Can be set during `open(namespace="default")` +- Operations with `namespace=None` use the default +- Each record class typically defines its own namespace via `db_namespace()` + +## Record Lifecycle + +### Insertion + +```python +db_insert_record(record, mark_dirty=True) +``` + +1. Normalize `record.date_time` to UTC `DatabaseTimestamp` +2. Ensure timestamp range is loaded (lazy load if needed) +3. Check for duplicates (raises `ValueError`) +4. Insert into sorted position in memory +5. Update index: `_db_record_index[timestamp] = record` +6. Mark dirty if `mark_dirty=True` + +### Retrieval + +```python +db_get_record(target_timestamp, time_window=None) +``` + +**Search Strategies:** + +| `time_window` | Behavior | +|---|---| +| `None` | Exact match only | +| `UNBOUND_WINDOW` | Nearest record (unlimited search) | +| `Duration` | Nearest within symmetric window | + +**Memory-First:** Checks in-memory index before querying database. + +### Deletion + +```python +db_delete_records(start_timestamp, end_timestamp) +``` + +1. Ensure range is fully loaded +2. Remove from memory: `records`, `_db_sorted_timestamps`, `_db_record_index` +3. Add to `_db_deleted_timestamps` (tombstone) +4. Discard from dirty sets (cancel pending writes) +5. 
Physical deletion deferred until `db_save_records()` + +## Dirty Tracking + +The system maintains three dirty sets to optimize writes: + +```python +_db_dirty_timestamps: set[DatabaseTimestamp] # Modified records +_db_new_timestamps: set[DatabaseTimestamp] # Newly inserted +_db_deleted_timestamps: set[DatabaseTimestamp] # Pending deletes +``` + +**Write Strategy:** + +1. **Saves first:** Insert/update all dirty records +2. **Deletes last:** Remove tombstoned records +3. **Clear tracking sets:** Reset dirty state + +**Autosave:** Triggered periodically if `autosave_interval_sec` configured. + +## Compression + +Optional gzip compression reduces storage footprint: + +```python +# Serialize +data = pickle.dumps(record.model_dump()) +if compression_level > 0: + data = gzip.compress(data, compresslevel=compression_level) + +# Deserialize (auto-detect) +if data[:2] == b'\x1f\x8b': # gzip magic bytes + data = gzip.decompress(data) +record_data = pickle.loads(data) +``` + +**Compression is transparent:** Application code never handles compressed data directly. + +## Metadata + +Each namespace can store arbitrary metadata (version, creation time, provider): + +```python +_db_metadata = { + "version": 1, + "created": "2024-01-01T00:00:00Z", + "provider_id": "LMDB", + "compression": True, + "backend": "LMDBDatabase" +} +``` + +Stored separately from records using reserved key `__metadata__`. + +## Compaction + +Compaction reduces storage by downsampling old records to a lower time resolution. Unlike +vacuum — which deletes records outright — compaction preserves the full time span of the +data while replacing many fine-grained records with fewer coarse-grained averages. 
+ +### Tiered Downsampling Policy + +The default policy has two tiers, applied coarsest-first: + +| Age threshold | Target resolution | Effect | +|---|---|---| +| Older than 14 days | 1 hour | 15-min records → 1 per hour (75 % reduction) | +| Older than 2 hours | 15 minutes | 1-min records → 1 per 15 min (93 % reduction) | + +Records within the most recent 2 hours are never touched. + +### How Compaction Works + +Each tier is processed incrementally using a stored cutoff timestamp per tier. On each run, +only the window `[last_cutoff, new_cutoff)` is examined — records already compacted in a +previous run are never re-processed. This makes weekly runs fast even on years of history. + +For each writable numeric field, records in the window are mean-resampled at the target +interval using time interpolation. The original records are deleted and the downsampled +records are written back. A **sparse-data guard** skips any window where the existing record +count is already at or below the resampled bucket count, preventing compaction from +accidentally *increasing* record count for data that is already coarse or irregular. + +### Customising the Policy per Namespace + +Individual data providers can override `db_compact_tiers()` to use a different policy: + +```python +class PriceDataProvider(DataProvider): + def db_compact_tiers(self): + # Price data is already at 15-min resolution from the source. + # Skip the first tier; only compact to hourly after 2 weeks. 
+        return [(to_duration("14 days"), to_duration("1 hour"))]
+```
+
+Return an empty list to disable compaction for a specific namespace entirely:
+
+```python
+class EventLogProvider(DataProvider):
+    def db_compact_tiers(self):
+        return []  # Raw events must not be averaged
+```
+
+### Manual Invocation
+
+```python
+# Compact all providers in the container
+data_container.db_compact()
+
+# Compact a single provider
+provider.db_compact()
+
+# Use a one-off policy without changing the instance default
+provider.db_compact(compact_tiers=[
+    (to_duration("7 days"), to_duration("1 hour"))
+])
+```
+
+### Interaction with Vacuum
+
+Compaction and vacuum are complementary and should always run in this order:
+
+```text
+compact → vacuum
+```
+
+Compact first so that vacuum operates on already-downsampled records. This produces cleaner
+retention boundaries and ensures the vacuum cutoff falls on hour-aligned timestamps rather
+than arbitrary sub-minute ones. Running them in reverse order (vacuum then compact) gives
+up this benefit: vacuum would delete at an arbitrary, full-resolution record boundary.
+
+The `RetentionManager` registers both jobs and ensures compaction always runs before vacuum
+within the same maintenance window.
+ +## Vacuum Operation + +Remove old records to reclaim space: + +```python +db_vacuum(keep_hours=48) # Keep last 48 hours +db_vacuum(keep_timestamp=cutoff) # Keep from cutoff onward +``` + +**Strategy:** +- Computes cutoff relative to `max_timestamp - keep_hours` +- Deletes all records before cutoff +- Immediately persists changes via `db_save_records()` + +## Thread Safety + +- **LMDB:** Internal lock protects write transactions; reads are lock-free via MVCC +- **SQLite:** Lock guards all operations (autocommit mode eliminates transaction deadlocks) +- **Record Protocol:** No internal locking (assumes single-threaded access per instance) + +## Performance Characteristics + +| Operation | LMDB | SQLite | +|---|---|---| +| Sequential read | Excellent (mmap) | Good (indexed) | +| Random read | Excellent (mmap) | Good (B-tree) | +| Bulk write | Excellent (single txn) | Good (batch insert) | +| Range query | Excellent (cursor) | Good (indexed scan) | +| Disk usage | Moderate (pre-allocated) | Compact (auto-grow) | +| Concurrency | High (MVCC readers) | Low (write serialization) | + +**Recommendation:** Use LMDB for high-frequency time-series workloads; +SQLite for portability and simpler deployment. 
+ +## Example Usage + +```python +# Configuration +config.database.provider = "LMDB" +config.database.compression_level = 9 +config.database.initial_load_window_h = 24 # Load last 24h initially +config.database.keep_duration_h = 720 # Retain 30 days +config.database.compaction_interval_sec = 604800 # Compact weekly + +# Access (automatic singleton initialization) +class MeasurementData(DatabaseRecordProtocolMixin): + records: list[MeasurementRecord] = [] + + def db_namespace(self) -> str: + return "measurement" + +# Operations +measurement = MeasurementData() + +# Lazy load on first access +record = measurement.db_get_record( + DatabaseTimestamp.from_datetime(now), + time_window=Duration(hours=1) +) + +# Insert new record +measurement.db_insert_record(new_record) + +# Automatic save (if autosave configured) or manual +measurement.db_save_records() + +# Maintenance pipeline (normally handled by RetentionManager) +measurement.db_compact() # downsample old records first +measurement.db_vacuum(keep_hours=720) # then delete beyond retention +``` diff --git a/docs/conf.py b/docs/conf.py index f533f45..a1f5082 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -18,7 +18,7 @@ from akkudoktoreos.core.version import __version__ # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information project = "Akkudoktor EOS" -copyright = "2025, Andreas Schmitz" +copyright = "2025..2026, Andreas Schmitz" author = "Andreas Schmitz" release = __version__ diff --git a/docs/index.md b/docs/index.md index f4a8ffd..e9ff054 100644 --- a/docs/index.md +++ b/docs/index.md @@ -50,6 +50,7 @@ akkudoktoreos/prediction.md akkudoktoreos/measurement.md akkudoktoreos/integration.md akkudoktoreos/logging.md +akkudoktoreos/database.md akkudoktoreos/adapter.md akkudoktoreos/serverapi.md akkudoktoreos/api.rst diff --git a/openapi.json b/openapi.json index 3cfe8de..60fa9c4 100644 --- a/openapi.json +++ b/openapi.json @@ -2,8 +2,13 @@ "openapi": "3.1.0", "info": { "title": 
"Akkudoktor-EOS", + "summary": "Comprehensive solution for simulating and optimizing an energy system based on renewable energy sources", "description": "This project provides a comprehensive solution for simulating and optimizing an energy system based on renewable energy sources. With a focus on photovoltaic (PV) systems, battery storage (batteries), load management (consumer requirements), heat pumps, electric vehicles, and consideration of electricity price data, this system enables forecasting and optimization of energy flow and costs over a specified period.", - "version": "v0.2.0.dev84352035" + "license": { + "name": "Apache 2.0", + "url": "https://www.apache.org/licenses/LICENSE-2.0.html" + }, + "version": "v0.2.0.dev58204789" }, "paths": { "/v1/admin/cache/clear": { @@ -126,6 +131,54 @@ } } }, + "/v1/admin/database/stats": { + "get": { + "tags": [ + "admin" + ], + "summary": "Fastapi Admin Database Stats Get", + "description": "Get statistics from database.\n\nReturns:\n data (dict): The database statistics", + "operationId": "fastapi_admin_database_stats_get_v1_admin_database_stats_get", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "additionalProperties": true, + "type": "object", + "title": "Response Fastapi Admin Database Stats Get V1 Admin Database Stats Get" + } + } + } + } + } + } + }, + "/v1/admin/database/vacuum": { + "post": { + "tags": [ + "admin" + ], + "summary": "Fastapi Admin Database Vacuum Post", + "description": "Remove old records from database.\n\nReturns:\n data (dict): The database stats after removal of old records.", + "operationId": "fastapi_admin_database_vacuum_post_v1_admin_database_vacuum_post", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "additionalProperties": true, + "type": "object", + "title": "Response Fastapi Admin Database Vacuum Post V1 Admin Database Vacuum Post" + 
} + } + } + } + } + } + }, "/v1/admin/server/restart": { "post": { "tags": [ @@ -2102,7 +2155,7 @@ }, "type": "array", "title": "Providers", - "description": "Available electricity price provider ids.", + "description": "Available adapter provider ids.", "readOnly": true } }, @@ -2493,9 +2546,10 @@ }, "cleanup_interval": { "type": "number", + "minimum": 5.0, "title": "Cleanup Interval", "description": "Intervall in seconds for EOS file cache cleanup.", - "default": 300 + "default": 300.0 } }, "type": "object", @@ -2523,170 +2577,61 @@ "ConfigEOS": { "properties": { "general": { - "$ref": "#/components/schemas/GeneralSettings-Output", - "default": { - "version": "0.2.0.dev84352035", - "data_output_subpath": "output", - "latitude": 52.52, - "longitude": 13.405, - "timezone": "Europe/Berlin", - "config_folder_path": "/home/user/.config/net.akkudoktoreos.net", - "config_file_path": "/home/user/.config/net.akkudoktoreos.net/EOS.config.json", - "home_assistant_addon": false - } + "$ref": "#/components/schemas/GeneralSettings-Output" }, "cache": { - "$ref": "#/components/schemas/CacheCommonSettings", - "default": { - "subpath": "cache", - "cleanup_interval": 300.0 - } + "$ref": "#/components/schemas/CacheCommonSettings" + }, + "database": { + "$ref": "#/components/schemas/DatabaseCommonSettings-Output" }, "ems": { - "$ref": "#/components/schemas/EnergyManagementCommonSettings", - "default": { - "startup_delay": 5.0 - } + "$ref": "#/components/schemas/EnergyManagementCommonSettings" }, "logging": { - "$ref": "#/components/schemas/LoggingCommonSettings-Output", - "default": { - "file_path": "/home/user/.local/share/net.akkudoktoreos.net/output/eos.log" - } + "$ref": "#/components/schemas/LoggingCommonSettings-Output" }, "devices": { - "$ref": "#/components/schemas/DevicesCommonSettings-Output", - "default": { - "measurement_keys": [] - } + "$ref": "#/components/schemas/DevicesCommonSettings-Output" }, "measurement": { - "$ref": 
"#/components/schemas/MeasurementCommonSettings-Output", - "default": { - "keys": [] - } + "$ref": "#/components/schemas/MeasurementCommonSettings-Output" }, "optimization": { - "$ref": "#/components/schemas/OptimizationCommonSettings-Output", - "default": { - "horizon_hours": 24, - "interval": 3600, - "algorithm": "GENETIC", - "genetic": { - "generations": 400, - "individuals": 300 - }, - "keys": [] - } + "$ref": "#/components/schemas/OptimizationCommonSettings-Output" }, "prediction": { - "$ref": "#/components/schemas/PredictionCommonSettings", - "default": { - "hours": 48, - "historic_hours": 48 - } + "$ref": "#/components/schemas/PredictionCommonSettings" }, "elecprice": { - "$ref": "#/components/schemas/ElecPriceCommonSettings-Output", - "default": { - "vat_rate": 1.19, - "elecpriceimport": {}, - "energycharts": { - "bidding_zone": "DE-LU" - }, - "providers": [ - "ElecPriceAkkudoktor", - "ElecPriceEnergyCharts", - "ElecPriceImport" - ] - } + "$ref": "#/components/schemas/ElecPriceCommonSettings-Output" }, "feedintariff": { - "$ref": "#/components/schemas/FeedInTariffCommonSettings-Output", - "default": { - "provider_settings": {}, - "providers": [ - "FeedInTariffFixed", - "FeedInTariffImport" - ] - } + "$ref": "#/components/schemas/FeedInTariffCommonSettings-Output" }, "load": { - "$ref": "#/components/schemas/LoadCommonSettings-Output", - "default": { - "provider_settings": {}, - "providers": [ - "LoadAkkudoktor", - "LoadAkkudoktorAdjusted", - "LoadVrm", - "LoadImport" - ] - } + "$ref": "#/components/schemas/LoadCommonSettings-Output" }, "pvforecast": { - "$ref": "#/components/schemas/PVForecastCommonSettings-Output", - "default": { - "provider_settings": {}, - "max_planes": 0, - "providers": [ - "PVForecastAkkudoktor", - "PVForecastVrm", - "PVForecastImport" - ], - "planes_peakpower": [], - "planes_azimuth": [], - "planes_tilt": [], - "planes_userhorizon": [], - "planes_inverter_paco": [] - } + "$ref": "#/components/schemas/PVForecastCommonSettings-Output" 
}, "weather": { - "$ref": "#/components/schemas/WeatherCommonSettings-Output", - "default": { - "provider_settings": {}, - "providers": [ - "BrightSky", - "ClearOutside", - "WeatherImport" - ] - } + "$ref": "#/components/schemas/WeatherCommonSettings-Output" }, "server": { - "$ref": "#/components/schemas/ServerCommonSettings", - "default": { - "host": "127.0.0.1", - "port": 8503, - "verbose": false, - "startup_eosdash": true - } + "$ref": "#/components/schemas/ServerCommonSettings" }, "utils": { - "$ref": "#/components/schemas/UtilsCommonSettings", - "default": {} + "$ref": "#/components/schemas/UtilsCommonSettings" }, "adapter": { - "$ref": "#/components/schemas/AdapterCommonSettings-Output", - "default": { - "homeassistant": { - "eos_device_instruction_entity_ids": [], - "eos_solution_entity_ids": [], - "homeassistant_entity_ids": [] - }, - "nodered": { - "host": "127.0.0.1", - "port": 1880 - }, - "providers": [ - "HomeAssistant", - "NodeRED" - ] - } + "$ref": "#/components/schemas/AdapterCommonSettings-Output" } }, "additionalProperties": false, "type": "object", "title": "ConfigEOS", - "description": "Singleton configuration handler for the EOS application.\n\nConfigEOS extends `SettingsEOS` with support for default configuration paths and automatic\ninitialization.\n\n`ConfigEOS` ensures that only one instance of the class is created throughout the application,\nallowing consistent access to EOS configuration settings. This singleton instance loads\nconfiguration data from a predefined set of directories or creates a default configuration if\nnone is found.\n\nInitialization Process:\n - Upon instantiation, the singleton instance attempts to load a configuration file in this order:\n 1. The directory specified by the `EOS_CONFIG_DIR` environment variable\n 2. The directory specified by the `EOS_DIR` environment variable.\n 3. A platform specific default directory for EOS.\n 4. 
The current working directory.\n - The first available configuration file found in these directories is loaded.\n - If no configuration file is found, a default configuration file is created in the platform\n specific default directory, and default settings are loaded into it.\n\nAttributes from the loaded configuration are accessible directly as instance attributes of\n`ConfigEOS`, providing a centralized, shared configuration object for EOS.\n\nSingleton Behavior:\n - This class uses the `SingletonMixin` to ensure that all requests for `ConfigEOS` return\n the same instance, which contains the most up-to-date configuration. Modifying the configuration\n in one part of the application reflects across all references to this class.\n\nAttributes:\n config_folder_path (Optional[Path]): Path to the configuration directory.\n config_file_path (Optional[Path]): Path to the configuration file.\n\nRaises:\n FileNotFoundError: If no configuration file is found, and creating a default configuration fails.\n\nExample:\n To initialize and access configuration attributes (only one instance is created):\n .. code-block:: python\n\n config_eos = ConfigEOS() # Always returns the same instance\n print(config_eos.prediction.hours) # Access a setting from the loaded configuration" + "description": "Singleton configuration handler for the EOS application.\n\nConfigEOS extends `SettingsEOS` with support for default configuration paths and automatic\ninitialization.\n\n`ConfigEOS` ensures that only one instance of the class is created throughout the application,\nallowing consistent access to EOS configuration settings. This singleton instance loads\nconfiguration data from a predefined set of directories or creates a default configuration if\nnone is found.\n\nInitialization Process:\n - Upon instantiation, the singleton instance attempts to load a configuration file in this order:\n 1. The directory specified by the `EOS_CONFIG_DIR` environment variable\n 2. 
The directory specified by the `EOS_DIR` environment variable.\n 3. A platform specific default directory for EOS.\n 4. The current working directory.\n - The first available configuration file found in these directories is loaded.\n - If no configuration file is found, a default configuration file is created in the platform\n specific default directory, and default settings are loaded into it.\n\nAttributes from the loaded configuration are accessible directly as instance attributes of\n`ConfigEOS`, providing a centralized, shared configuration object for EOS.\n\nSingleton Behavior:\n - This class uses the `SingletonMixin` to ensure that all requests for `ConfigEOS` return\n the same instance, which contains the most up-to-date configuration. Modifying the configuration\n in one part of the application reflects across all references to this class.\n\nRaises:\n FileNotFoundError: If no configuration file is found, and creating a default configuration fails.\n\nExample:\n To initialize and access configuration attributes (only one instance is created):\n .. code-block:: python\n\n config_eos = ConfigEOS() # Always returns the same instance\n print(config_eos.prediction.hours) # Access a setting from the loaded configuration" }, "DDBCActuatorStatus": { "properties": { @@ -2809,6 +2754,240 @@ "title": "DDBCInstruction", "description": "Instruction for Demand Driven Based Control (DDBC).\n\nContains information about when and how to activate a specific operation mode\nfor an actuator. Used to command resources to change their operation at a specified time." 
}, + "DatabaseCommonSettings-Input": { + "properties": { + "provider": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Provider", + "description": "Database provider id of provider to be used.", + "examples": [ + "LMDB" + ] + }, + "compression_level": { + "type": "integer", + "maximum": 9.0, + "minimum": 0.0, + "title": "Compression Level", + "description": "Compression level for database record data.", + "default": 9, + "examples": [ + 0, + 9 + ] + }, + "initial_load_window_h": { + "anyOf": [ + { + "type": "integer", + "minimum": 0.0 + }, + { + "type": "null" + } + ], + "title": "Initial Load Window H", + "description": "Specifies the default duration of the initial load window when loading records from the database, in hours. If set to None, the full available range is loaded. The window is centered around the current time by default, unless a different center time is specified. Different database namespaces may define their own default windows.", + "examples": [ + "48", + "None" + ] + }, + "keep_duration_h": { + "anyOf": [ + { + "type": "integer", + "minimum": 0.0 + }, + { + "type": "null" + } + ], + "title": "Keep Duration H", + "description": "Default maximum duration records shall be kept in database [hours, none].\nNone indicates forever. 
Database namespaces may have diverging definitions.", + "examples": [ + 48, + "none" + ] + }, + "autosave_interval_sec": { + "anyOf": [ + { + "type": "integer", + "minimum": 5.0 + }, + { + "type": "null" + } + ], + "title": "Autosave Interval Sec", + "description": "Automatic saving interval [seconds].\nSet to None to disable automatic saving.", + "default": 10, + "examples": [ + 5 + ] + }, + "compaction_interval_sec": { + "anyOf": [ + { + "type": "integer", + "minimum": 0.0 + }, + { + "type": "null" + } + ], + "title": "Compaction Interval Sec", + "description": "Interval in between automatic tiered compaction runs [seconds].\nCompaction downsamples old records to reduce storage while retaining coverage. Set to None to disable automatic compaction.", + "default": 604800, + "examples": [ + 604800 + ] + }, + "batch_size": { + "type": "integer", + "title": "Batch Size", + "description": "Number of records to process in batch operations.", + "default": 100, + "examples": [ + 100 + ] + } + }, + "type": "object", + "title": "DatabaseCommonSettings", + "description": "Configuration model for database settings.\n\nAttributes:\n provider: Optional provider identifier (e.g. \"LMDB\").\n max_records_in_memory: Maximum records kept in memory before auto-save.\n auto_save: Whether to auto-save when threshold exceeded.\n batch_size: Batch size for batch operations." 
+ }, + "DatabaseCommonSettings-Output": { + "properties": { + "provider": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Provider", + "description": "Database provider id of provider to be used.", + "examples": [ + "LMDB" + ] + }, + "compression_level": { + "type": "integer", + "maximum": 9.0, + "minimum": 0.0, + "title": "Compression Level", + "description": "Compression level for database record data.", + "default": 9, + "examples": [ + 0, + 9 + ] + }, + "initial_load_window_h": { + "anyOf": [ + { + "type": "integer", + "minimum": 0.0 + }, + { + "type": "null" + } + ], + "title": "Initial Load Window H", + "description": "Specifies the default duration of the initial load window when loading records from the database, in hours. If set to None, the full available range is loaded. The window is centered around the current time by default, unless a different center time is specified. Different database namespaces may define their own default windows.", + "examples": [ + "48", + "None" + ] + }, + "keep_duration_h": { + "anyOf": [ + { + "type": "integer", + "minimum": 0.0 + }, + { + "type": "null" + } + ], + "title": "Keep Duration H", + "description": "Default maximum duration records shall be kept in database [hours, none].\nNone indicates forever. 
Database namespaces may have diverging definitions.", + "examples": [ + 48, + "none" + ] + }, + "autosave_interval_sec": { + "anyOf": [ + { + "type": "integer", + "minimum": 5.0 + }, + { + "type": "null" + } + ], + "title": "Autosave Interval Sec", + "description": "Automatic saving interval [seconds].\nSet to None to disable automatic saving.", + "default": 10, + "examples": [ + 5 + ] + }, + "compaction_interval_sec": { + "anyOf": [ + { + "type": "integer", + "minimum": 0.0 + }, + { + "type": "null" + } + ], + "title": "Compaction Interval Sec", + "description": "Interval in between automatic tiered compaction runs [seconds].\nCompaction downsamples old records to reduce storage while retaining coverage. Set to None to disable automatic compaction.", + "default": 604800, + "examples": [ + 604800 + ] + }, + "batch_size": { + "type": "integer", + "title": "Batch Size", + "description": "Number of records to process in batch operations.", + "default": 100, + "examples": [ + 100 + ] + }, + "providers": { + "items": { + "type": "string" + }, + "type": "array", + "title": "Providers", + "description": "Return available database provider ids.", + "readOnly": true + } + }, + "type": "object", + "required": [ + "providers" + ], + "title": "DatabaseCommonSettings", + "description": "Configuration model for database settings.\n\nAttributes:\n provider: Optional provider identifier (e.g. \"LMDB\").\n max_records_in_memory: Maximum records kept in memory before auto-save.\n auto_save: Whether to auto-save when threshold exceeded.\n batch_size: Batch size for batch operations." 
+ }, "DevicesCommonSettings-Input": { "properties": { "batteries": { @@ -3583,16 +3762,11 @@ "default": 5 }, "interval": { - "anyOf": [ - { - "type": "number" - }, - { - "type": "null" - } - ], + "type": "number", + "minimum": 60.0, "title": "Interval", - "description": "Intervall in seconds between EOS energy management runs.", + "description": "Intervall between EOS energy management runs [seconds].", + "default": 300.0, "examples": [ "300" ] @@ -4268,28 +4442,22 @@ }, "GeneralSettings-Input": { "properties": { + "home_assistant_addon": { + "type": "boolean", + "title": "Home Assistant Addon", + "description": "EOS is running as home assistant add-on." + }, "version": { "type": "string", "title": "Version", "description": "Configuration file version. Used to check compatibility.", - "default": "0.2.0.dev84352035" + "default": "0.2.0.dev58204789" }, "data_folder_path": { - "anyOf": [ - { - "type": "string", - "format": "path" - }, - { - "type": "null" - } - ], + "type": "string", + "format": "path", "title": "Data Folder Path", - "description": "Path to EOS data directory.", - "examples": [ - null, - "/home/eos/data" - ] + "description": "Path to EOS data folder." }, "data_output_subpath": { "anyOf": [ @@ -4302,7 +4470,7 @@ } ], "title": "Data Output Subpath", - "description": "Sub-path for the EOS output data directory.", + "description": "Sub-path for the EOS output data folder.", "default": "output" }, "latitude": { @@ -4346,24 +4514,13 @@ "type": "string", "title": "Version", "description": "Configuration file version. Used to check compatibility.", - "default": "0.2.0.dev84352035" + "default": "0.2.0.dev58204789" }, "data_folder_path": { - "anyOf": [ - { - "type": "string", - "format": "path" - }, - { - "type": "null" - } - ], + "type": "string", + "format": "path", "title": "Data Folder Path", - "description": "Path to EOS data directory.", - "examples": [ - null, - "/home/eos/data" - ] + "description": "Path to EOS data folder." 
}, "data_output_subpath": { "anyOf": [ @@ -4376,7 +4533,7 @@ } ], "title": "Data Output Subpath", - "description": "Sub-path for the EOS output data directory.", + "description": "Sub-path for the EOS output data folder.", "default": "output" }, "latitude": { @@ -4463,12 +4620,6 @@ "title": "Config File Path", "description": "Path to EOS configuration file.", "readOnly": true - }, - "home_assistant_addon": { - "type": "boolean", - "title": "Home Assistant Addon", - "description": "EOS is running as home assistant add-on.", - "readOnly": true } }, "type": "object", @@ -4476,8 +4627,7 @@ "timezone", "data_output_path", "config_folder_path", - "config_file_path", - "home_assistant_addon" + "config_file_path" ], "title": "GeneralSettings", "description": "General settings." @@ -6091,6 +6241,23 @@ }, "MeasurementCommonSettings-Input": { "properties": { + "historic_hours": { + "anyOf": [ + { + "type": "integer", + "minimum": 0.0 + }, + { + "type": "null" + } + ], + "title": "Historic Hours", + "description": "Number of hours into the past for measurement data", + "default": 17520, + "examples": [ + 17520 + ] + }, "load_emr_keys": { "anyOf": [ { @@ -6178,6 +6345,23 @@ }, "MeasurementCommonSettings-Output": { "properties": { + "historic_hours": { + "anyOf": [ + { + "type": "integer", + "minimum": 0.0 + }, + { + "type": "null" + } + ], + "title": "Historic Hours", + "description": "Number of hours into the past for measurement data", + "default": 17520, + "examples": [ + 17520 + ] + }, "load_emr_keys": { "anyOf": [ { @@ -8105,6 +8289,17 @@ ], "description": "Cache Settings" }, + "database": { + "anyOf": [ + { + "$ref": "#/components/schemas/DatabaseCommonSettings-Input" + }, + { + "type": "null" + } + ], + "description": "Database Settings" + }, "ems": { "anyOf": [ { diff --git a/requirements.txt b/requirements.txt index 5e910e1..089687f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,12 +14,11 @@ markdown-it-py==4.0.0 mdit-py-plugins==0.5.0 bokeh==3.8.2 
uvicorn==0.40.0 -scikit-learn==1.8.0 scipy==1.17.0 tzfpy==1.1.1 deap==1.4.3 requests==2.32.5 -pandas==2.3.3 +pandas==3.0.0 pendulum==3.2.0 platformdirs==4.9.2 psutil==7.2.2 @@ -27,6 +26,7 @@ pvlib==0.15.0 pydantic==2.12.5 pydantic_extra_types==2.11.0 statsmodels==0.14.6 -pydantic-settings==2.11.0 +pydantic-settings==2.12.0 linkify-it-py==2.0.3 loguru==0.7.3 +lmdb==1.7.5 diff --git a/scripts/generate_config_md.py b/scripts/generate_config_md.py index b88a2c1..e902e46 100755 --- a/scripts/generate_config_md.py +++ b/scripts/generate_config_md.py @@ -14,7 +14,8 @@ from loguru import logger from pydantic.fields import ComputedFieldInfo, FieldInfo from pydantic_core import PydanticUndefined -from akkudoktoreos.config.config import ConfigEOS, GeneralSettings, get_config +from akkudoktoreos.config.config import ConfigEOS, default_data_folder_path +from akkudoktoreos.core.coreabc import get_config, singletons_init from akkudoktoreos.core.pydantic import PydanticBaseModel from akkudoktoreos.utils.datetimeutil import to_datetime @@ -361,12 +362,6 @@ def generate_config_md(file_path: Optional[Union[str, Path]], config_eos: Config Returns: str: The Markdown representation of the configuration spec. 
""" - # Fix file path for general settings to not show local/test file path - GeneralSettings._config_file_path = Path( - "/home/user/.config/net.akkudoktoreos.net/EOS.config.json" - ) - GeneralSettings._config_folder_path = config_eos.general.config_file_path.parent - markdown = "" if file_path: @@ -446,6 +441,19 @@ def write_to_file(file_path: Optional[Union[str, Path]], config_md: str): '/home/user/.local/share/net.akkudoktor.eos/output/eos.log', config_md ) + # Assure pathes are set to default for documentation + replacements = [ + ("data_folder_path", "/home/user/.local/share/net.akkudoktoreos.net"), + ("data_output_path", "/home/user/.local/share/net.akkudoktoreos.net/output"), + ("config_folder_path", "/home/user/.config/net.akkudoktoreos.net"), + ("config_file_path", "/home/user/.config/net.akkudoktoreos.net/EOS.config.json"), + ] + for key, value in replacements: + config_md = re.sub( + rf'("{key}":\s*)"[^"]*"', + rf'\1"{value}"', + config_md + ) # Assure timezone name does not leak to documentation tz_name = to_datetime().timezone_name @@ -477,16 +485,31 @@ def main(): ) args = parser.parse_args() - config_eos = get_config() + + # Ensure we are in documentation mode + ConfigEOS._force_documentation_mode = True + + # Make minimal config to make the generation reproducable + config_eos = get_config(init={ + "with_init_settings": True, + "with_env_settings": False, + "with_dotenv_settings": False, + "with_file_settings": False, + "with_file_secret_settings": False, + }) + + # Also init other singletons to get same list of e.g. providers + singletons_init() try: config_md = generate_config_md(args.output_file, config_eos) - except Exception as e: print(f"Error during Configuration Specification generation: {e}", file=sys.stderr) # keep throwing error to debug potential problems (e.g. 
invalid examples) raise e - + finally: + # Ensure we are out of documentation mode + ConfigEOS._force_documentation_mode = False if __name__ == "__main__": main() diff --git a/scripts/generate_openapi.py b/scripts/generate_openapi.py index 13ec283..04bb90d 100755 --- a/scripts/generate_openapi.py +++ b/scripts/generate_openapi.py @@ -19,32 +19,44 @@ import json import os import sys -from fastapi.openapi.utils import get_openapi - +from akkudoktoreos.core.coreabc import get_config from akkudoktoreos.server.eos import app def generate_openapi() -> dict: - """Generate the OpenAPI specification. + # Make minimal config to make the generation reproducable + config_eos = get_config(init={ + "with_init_settings": True, + "with_env_settings": False, + "with_dotenv_settings": False, + "with_file_settings": False, + "with_file_secret_settings": False, + }) - Returns: - openapi_spec (dict): OpenAPI specification. - """ - openapi_spec = get_openapi( - title=app.title, - version=app.version, - openapi_version=app.openapi_version, - description=app.description, - routes=app.routes, + openapi_spec = app.openapi() + + config_schema = ( + openapi_spec + .get("components", {}) + .get("schemas", {}) + .get("ConfigEOS", {}) + .get("properties", {}) ) - # Fix file path for general settings to not show local/test file path - general = openapi_spec["components"]["schemas"]["ConfigEOS"]["properties"]["general"]["default"] - general["config_file_path"] = "/home/user/.config/net.akkudoktoreos.net/EOS.config.json" - general["config_folder_path"] = "/home/user/.config/net.akkudoktoreos.net" - # Fix file path for logging settings to not show local/test file path - logging = openapi_spec["components"]["schemas"]["ConfigEOS"]["properties"]["logging"]["default"] - logging["file_path"] = "/home/user/.local/share/net.akkudoktoreos.net/output/eos.log" + # ---- General settings ---- + general = config_schema.get("general", {}).get("default") + if general: + general.update({ + "config_file_path": 
"/home/user/.config/net.akkudoktoreos.net/EOS.config.json", + "config_folder_path": "/home/user/.config/net.akkudoktoreos.net", + "data_folder_path": "/home/user/.local/share/net.akkudoktoreos.net", + "data_output_path": "/home/user/.local/share/net.akkudoktoreos.net/output", + }) + + # ---- Logging settings ---- + logging_cfg = config_schema.get("logging", {}).get("default") + if logging_cfg: + logging_cfg["file_path"] = "/home/user/.local/share/net.akkudoktoreos.net/output/eos.log" return openapi_spec diff --git a/src/akkudoktoreos/adapter/adapter.py b/src/akkudoktoreos/adapter/adapter.py index 3aa07bd..d19ced4 100644 --- a/src/akkudoktoreos/adapter/adapter.py +++ b/src/akkudoktoreos/adapter/adapter.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Optional, Union +from typing import Optional, Union from pydantic import Field, computed_field, field_validator @@ -10,9 +10,6 @@ from akkudoktoreos.adapter.homeassistant import ( from akkudoktoreos.adapter.nodered import NodeREDAdapter, NodeREDAdapterCommonSettings from akkudoktoreos.config.configabc import SettingsBaseModel -if TYPE_CHECKING: - adapter_providers: list[str] - class AdapterCommonSettings(SettingsBaseModel): """Adapter Configuration.""" @@ -38,8 +35,9 @@ class AdapterCommonSettings(SettingsBaseModel): @computed_field # type: ignore[prop-decorator] @property def providers(self) -> list[str]: - """Available electricity price provider ids.""" - return adapter_providers + """Available adapter provider ids.""" + adapter_provider_ids = [provider.provider_id() for provider in adapter_providers()] + return adapter_provider_ids # Validators @field_validator("provider", mode="after") @@ -47,48 +45,39 @@ class AdapterCommonSettings(SettingsBaseModel): def validate_provider(cls, value: Optional[list[str]]) -> Optional[list[str]]: if value is None: return value + adapter_provider_ids = [provider.provider_id() for provider in adapter_providers()] for provider_id in value: - if provider_id not in 
adapter_providers: + if provider_id not in adapter_provider_ids: raise ValueError( - f"Provider '{value}' is not a valid adapter provider: {adapter_providers}." + f"Provider '{value}' is not a valid adapter provider: {adapter_provider_ids}." ) return value -class Adapter(AdapterContainer): - """Adapter container to manage multiple adapter providers. - - Attributes: - providers (List[Union[PVForecastAkkudoktor, WeatherBrightSky, WeatherClearOutside]]): - List of forecast provider instances, in the order they should be updated. - Providers may depend on updates from others. - """ - - providers: list[ - Union[ - HomeAssistantAdapter, - NodeREDAdapter, - ] - ] = Field(default_factory=list, json_schema_extra={"description": "List of adapter providers"}) - - # Initialize adapter providers, all are singletons. homeassistant_adapter = HomeAssistantAdapter() nodered_adapter = NodeREDAdapter() -def get_adapter() -> Adapter: - """Gets the EOS adapter data.""" - # Initialize Adapter instance with providers in the required order - # Care for provider sequence as providers may rely on others to be updated before. 
- adapter = Adapter( - providers=[ - homeassistant_adapter, - nodered_adapter, +def adapter_providers() -> list[Union["HomeAssistantAdapter", "NodeREDAdapter"]]: + """Return list of adapter providers.""" + global homeassistant_adapter, nodered_adapter + + return [ + homeassistant_adapter, + nodered_adapter, + ] + + +class Adapter(AdapterContainer): + """Adapter container to manage multiple adapter providers.""" + + providers: list[ + Union[ + HomeAssistantAdapter, + NodeREDAdapter, ] + ] = Field( + default_factory=adapter_providers, + json_schema_extra={"description": "List of adapter providers"}, ) - return adapter - - -# Valid adapter providers -adapter_providers = [provider.provider_id() for provider in get_adapter().providers] diff --git a/src/akkudoktoreos/adapter/homeassistant.py b/src/akkudoktoreos/adapter/homeassistant.py index 4957e9b..f7b2fb2 100644 --- a/src/akkudoktoreos/adapter/homeassistant.py +++ b/src/akkudoktoreos/adapter/homeassistant.py @@ -10,12 +10,12 @@ from pydantic import Field, computed_field, field_validator from akkudoktoreos.adapter.adapterabc import AdapterProvider from akkudoktoreos.config.configabc import SettingsBaseModel +from akkudoktoreos.core.coreabc import get_adapter from akkudoktoreos.core.emplan import ( DDBCInstruction, FRBCInstruction, ) from akkudoktoreos.core.ems import EnergyManagementStage -from akkudoktoreos.devices.devices import get_resource_registry from akkudoktoreos.utils.datetimeutil import to_datetime # Supervisor API endpoint and token (injected automatically in add-on container) @@ -29,8 +29,6 @@ HEADERS = { HOMEASSISTANT_ENTITY_ID_PREFIX = "sensor.eos_" -resources_eos = get_resource_registry() - class HomeAssistantAdapterCommonSettings(SettingsBaseModel): """Common settings for the home assistant adapter.""" @@ -146,8 +144,6 @@ class HomeAssistantAdapterCommonSettings(SettingsBaseModel): def homeassistant_entity_ids(self) -> list[str]: """Entity IDs available at Home Assistant.""" try: - from 
akkudoktoreos.adapter.adapter import get_adapter - adapter_eos = get_adapter() result = adapter_eos.provider_by_id("HomeAssistant").get_homeassistant_entity_ids() except: @@ -159,8 +155,6 @@ class HomeAssistantAdapterCommonSettings(SettingsBaseModel): def eos_solution_entity_ids(self) -> list[str]: """Entity IDs for optimization solution available at EOS.""" try: - from akkudoktoreos.adapter.adapter import get_adapter - adapter_eos = get_adapter() result = adapter_eos.provider_by_id("HomeAssistant").get_eos_solution_entity_ids() except: @@ -172,8 +166,6 @@ class HomeAssistantAdapterCommonSettings(SettingsBaseModel): def eos_device_instruction_entity_ids(self) -> list[str]: """Entity IDs for energy management instructions available at EOS.""" try: - from akkudoktoreos.adapter.adapter import get_adapter - adapter_eos = get_adapter() result = adapter_eos.provider_by_id( "HomeAssistant" diff --git a/src/akkudoktoreos/config/config.py b/src/akkudoktoreos/config/config.py index a9c8a76..cd21e93 100644 --- a/src/akkudoktoreos/config/config.py +++ b/src/akkudoktoreos/config/config.py @@ -11,6 +11,7 @@ Key features: import json import os +import sys import tempfile from pathlib import Path from typing import Any, ClassVar, Optional, Type, Union @@ -26,6 +27,7 @@ from akkudoktoreos.config.configabc import SettingsBaseModel from akkudoktoreos.config.configmigrate import migrate_config_data, migrate_config_file from akkudoktoreos.core.cachesettings import CacheCommonSettings from akkudoktoreos.core.coreabc import SingletonMixin +from akkudoktoreos.core.database import DatabaseCommonSettings from akkudoktoreos.core.decorators import classproperty from akkudoktoreos.core.emsettings import ( EnergyManagementCommonSettings, @@ -65,16 +67,66 @@ def get_absolute_path( return None +def is_home_assistant_addon() -> bool: + """Detect Home Assistant add-on environment. + + Home Assistant sets this environment variable automatically. 
+ """ + return "HASSIO_TOKEN" in os.environ or "SUPERVISOR_TOKEN" in os.environ + + +def default_data_folder_path() -> Path: + """Provide default data folder path. + + 1. From EOS_DATA_DIR env + 2. From EOS_DIR env + 3. From platform specific default path + 4. Current working directory + + Note: + When running as Home Assistant add-on the path is fixed to /data. + """ + if is_home_assistant_addon(): + return Path("/data") + + # 1. From EOS_DATA_DIR env + if env_dir := os.getenv(ConfigEOS.EOS_DATA_DIR): + try: + data_dir = Path(env_dir).resolve() + data_dir.mkdir(parents=True, exist_ok=True) + return data_dir + except Exception as e: + logger.warning(f"Could not setup data folder {data_dir}: {e}") + + # 2. From EOS_DIR env + if env_dir := os.getenv(ConfigEOS.EOS_DIR): + try: + data_dir = Path(env_dir).resolve() + data_dir.mkdir(parents=True, exist_ok=True) + return data_dir + except Exception as e: + logger.warning(f"Could not setup data folder {data_dir}: {e}") + + # 3. From platform specific default path + try: + data_dir = Path(user_data_dir(ConfigEOS.APP_NAME, ConfigEOS.APP_AUTHOR)) + if data_dir is not None: + data_dir.mkdir(parents=True, exist_ok=True) + return data_dir + except Exception as e: + logger.warning(f"Could not setup data folder {data_dir}: {e}") + + # 4. 
Current working directory + return Path.cwd() + + class GeneralSettings(SettingsBaseModel): """General settings.""" - _config_folder_path: ClassVar[Optional[Path]] = None - _config_file_path: ClassVar[Optional[Path]] = None - - # Detect Home Assistant add-on environment - # Home Assistant sets this environment variable automatically - _home_assistant_addon: ClassVar[bool] = ( - "HASSIO_TOKEN" in os.environ or "SUPERVISOR_TOKEN" in os.environ + home_assistant_addon: bool = Field( + default_factory=is_home_assistant_addon, + json_schema_extra={"description": "EOS is running as home assistant add-on."}, + exclude=True, ) version: str = Field( @@ -84,17 +136,16 @@ class GeneralSettings(SettingsBaseModel): }, ) - data_folder_path: Optional[Path] = Field( - default=None, + data_folder_path: Path = Field( + default_factory=default_data_folder_path, json_schema_extra={ - "description": "Path to EOS data directory.", - "examples": [None, "/home/eos/data"], + "description": "Path to EOS data folder.", }, ) data_output_subpath: Optional[Path] = Field( default="output", - json_schema_extra={"description": "Sub-path for the EOS output data directory."}, + json_schema_extra={"description": "Sub-path for the EOS output data folder."}, ) latitude: Optional[float] = Field( @@ -134,19 +185,13 @@ class GeneralSettings(SettingsBaseModel): @property def config_folder_path(self) -> Optional[Path]: """Path to EOS configuration directory.""" - return self._config_folder_path + return self.config._config_file_path.parent @computed_field # type: ignore[prop-decorator] @property def config_file_path(self) -> Optional[Path]: """Path to EOS configuration file.""" - return self._config_file_path - - @computed_field # type: ignore[prop-decorator] - @property - def home_assistant_addon(self) -> bool: - """EOS is running as home assistant add-on.""" - return self._home_assistant_addon + return self.config._config_file_path compatible_versions: ClassVar[list[str]] = [__version__] @@ -164,17 +209,19 
@@ class GeneralSettings(SettingsBaseModel): @field_validator("data_folder_path", mode="after") @classmethod - def validate_data_folder_path(cls, value: Optional[Union[str, Path]]) -> Optional[Path]: + def validate_data_folder_path(cls, value: Optional[Union[str, Path]]) -> Path: """Ensure dir is available.""" - if cls._home_assistant_addon: + if is_home_assistant_addon(): # Force to home assistant add-on /data directory return Path("/data") if value is None: - return None + return default_data_folder_path() if isinstance(value, str): value = Path(value) - value.resolve() - if not value.is_dir(): + try: + value.resolve() + value.mkdir(parents=True, exist_ok=True) + except Exception: raise ValueError(f"Data folder path '{value}' is not a directory.") return value @@ -191,6 +238,9 @@ class SettingsEOS(pydantic_settings.BaseSettings, PydanticModelNestedValueMixin) cache: Optional[CacheCommonSettings] = Field( default=None, json_schema_extra={"description": "Cache Settings"} ) + database: Optional[DatabaseCommonSettings] = Field( + default=None, json_schema_extra={"description": "Database Settings"} + ) ems: Optional[EnergyManagementCommonSettings] = Field( default=None, json_schema_extra={"description": "Energy Management Settings"} ) @@ -248,22 +298,23 @@ class SettingsEOSDefaults(SettingsEOS): Used by ConfigEOS instance to make all fields available. 
""" - general: GeneralSettings = GeneralSettings() - cache: CacheCommonSettings = CacheCommonSettings() - ems: EnergyManagementCommonSettings = EnergyManagementCommonSettings() - logging: LoggingCommonSettings = LoggingCommonSettings() - devices: DevicesCommonSettings = DevicesCommonSettings() - measurement: MeasurementCommonSettings = MeasurementCommonSettings() - optimization: OptimizationCommonSettings = OptimizationCommonSettings() - prediction: PredictionCommonSettings = PredictionCommonSettings() - elecprice: ElecPriceCommonSettings = ElecPriceCommonSettings() - feedintariff: FeedInTariffCommonSettings = FeedInTariffCommonSettings() - load: LoadCommonSettings = LoadCommonSettings() - pvforecast: PVForecastCommonSettings = PVForecastCommonSettings() - weather: WeatherCommonSettings = WeatherCommonSettings() - server: ServerCommonSettings = ServerCommonSettings() - utils: UtilsCommonSettings = UtilsCommonSettings() - adapter: AdapterCommonSettings = AdapterCommonSettings() + general: GeneralSettings = Field(default_factory=GeneralSettings) + cache: CacheCommonSettings = Field(default_factory=CacheCommonSettings) + database: DatabaseCommonSettings = Field(default_factory=DatabaseCommonSettings) + ems: EnergyManagementCommonSettings = Field(default_factory=EnergyManagementCommonSettings) + logging: LoggingCommonSettings = Field(default_factory=LoggingCommonSettings) + devices: DevicesCommonSettings = Field(default_factory=DevicesCommonSettings) + measurement: MeasurementCommonSettings = Field(default_factory=MeasurementCommonSettings) + optimization: OptimizationCommonSettings = Field(default_factory=OptimizationCommonSettings) + prediction: PredictionCommonSettings = Field(default_factory=PredictionCommonSettings) + elecprice: ElecPriceCommonSettings = Field(default_factory=ElecPriceCommonSettings) + feedintariff: FeedInTariffCommonSettings = Field(default_factory=FeedInTariffCommonSettings) + load: LoadCommonSettings = Field(default_factory=LoadCommonSettings) 
+ pvforecast: PVForecastCommonSettings = Field(default_factory=PVForecastCommonSettings) + weather: WeatherCommonSettings = Field(default_factory=WeatherCommonSettings) + server: ServerCommonSettings = Field(default_factory=ServerCommonSettings) + utils: UtilsCommonSettings = Field(default_factory=UtilsCommonSettings) + adapter: AdapterCommonSettings = Field(default_factory=AdapterCommonSettings) def __hash__(self) -> int: # Just for usage in configmigrate, finally overwritten when used by ConfigEOS. @@ -300,10 +351,6 @@ class ConfigEOS(SingletonMixin, SettingsEOSDefaults): the same instance, which contains the most up-to-date configuration. Modifying the configuration in one part of the application reflects across all references to this class. - Attributes: - config_folder_path (Optional[Path]): Path to the configuration directory. - config_file_path (Optional[Path]): Path to the configuration file. - Raises: FileNotFoundError: If no configuration file is found, and creating a default configuration fails. @@ -323,6 +370,15 @@ class ConfigEOS(SingletonMixin, SettingsEOSDefaults): EOS_CONFIG_DIR: ClassVar[str] = "EOS_CONFIG_DIR" ENCODING: ClassVar[str] = "UTF-8" CONFIG_FILE_NAME: ClassVar[str] = "EOS.config.json" + _init_config_eos: ClassVar[dict[str, bool]] = { + "with_init_settings": True, + "with_env_settings": True, + "with_dotenv_settings": True, + "with_file_settings": True, + "with_file_secret_settings": True, + } + _config_file_path: ClassVar[Optional[Path]] = None + _force_documentation_mode = False def __hash__(self) -> int: # ConfigEOS is a singleton @@ -377,31 +433,156 @@ class ConfigEOS(SingletonMixin, SettingsEOSDefaults): configuration directory cannot be created. - It ensures that a fallback to a default configuration file is always possible. """ - # Ensure we know and have the config folder path and the config file - config_file = cls._setup_config_file() + + def lazy_config_file_settings() -> dict: + """Config file settings. 
+ + This function runs at **instance creation**, not class definition. Ensures if ConfigEOS + is recreated this function is run. + """ + config_file_path, exists = cls._get_config_file_path() + if not exists: + # Create minimum config file + config_minimum_content = '{ "general": { "version": "' + __version__ + '" } }' + if config_file_path.is_relative_to(ConfigEOS.package_root_path): + # Never write into package directory + error_msg = ( + f"Could not create minimum config file. " + f"Config file path '{config_file_path}' is within package root " + f"'{ConfigEOS.package_root_path}'" + ) + logger.error(error_msg) + raise RuntimeError(error_msg) + try: + config_file_path.parent.mkdir(parents=True, exist_ok=True) + config_file_path.write_text(config_minimum_content, encoding="utf-8") + except Exception as exc: + # Create minimum config in temporary config directory as last resort + error_msg = ( + f"Could not create minimum config file in {config_file_path.parent}: {exc}" + ) + logger.error(error_msg) + temp_dir = Path(tempfile.mkdtemp()) + info_msg = f"Using temporary config directory {temp_dir}" + logger.info(info_msg) + config_file_path = temp_dir / config_file_path.name + config_file_path.write_text(config_minimum_content, encoding="utf-8") + + # Remember for other lazy settings and computed_field + cls._config_file_path = config_file_path + + return {} + + def lazy_data_folder_path_settings() -> dict: + """Data folder path settings. + + This function runs at **instance creation**, not class definition. Ensures if ConfigEOS + is recreated this function is run. + """ + # Updates path to the data directory. + data_folder_settings = { + "general": { + "data_folder_path": default_data_folder_path(), + }, + } + + return data_folder_settings + + def lazy_init_settings() -> dict: + """Init settings. + + This function runs at **instance creation**, not class definition. Ensures if ConfigEOS + is recreated this function is run. 
+ """ + if not cls._init_config_eos.get("with_init_settings", True): + logger.debug("Config initialisation with init settings is disabled.") + return {} + + settings = init_settings() + + return settings + + def lazy_env_settings() -> dict: + """Env settings. + + This function runs at **instance creation**, not class definition. Ensures if ConfigEOS + is recreated this function is run. + """ + if not cls._init_config_eos.get("with_env_settings", True): + logger.debug("Config initialisation with env settings is disabled.") + return {} + + return env_settings() + + def lazy_dotenv_settings() -> dict: + """Dotenv settings. + + This function runs at **instance creation**, not class definition. Ensures if ConfigEOS + is recreated this function is run. + """ + if not cls._init_config_eos.get("with_dotenv_settings", True): + logger.debug("Config initialisation with dotenv settings is disabled.") + return {} + + return dotenv_settings() + + def lazy_file_settings() -> dict: + """File settings. + + This function runs at **instance creation**, not class definition. Ensures if ConfigEOS + is recreated this function is run. + + Ensures the config file exists and creates a backup if necessary. 
+ """ + if not cls._init_config_eos.get("with_file_settings", True): + logger.debug("Config initialisation with file settings is disabled.") + return {} + + config_file = cls._config_file_path # provided by lazy_config_file_settings + if config_file is None: + # This should not happen + raise RuntimeError("Config file path not set.") + + try: + backup_file = config_file.with_suffix(f".{to_datetime(as_string='YYYYMMDDHHmmss')}") + if migrate_config_file(config_file, backup_file): + # If the config file does have the correct version add it as settings source + settings = pydantic_settings.JsonConfigSettingsSource( + settings_cls, json_file=config_file + )() + except Exception as ex: + logger.error( + f"Error reading config file '{config_file}' (falling back to default config): {ex}" + ) + settings = {} + + return settings + + def lazy_file_secret_settings() -> dict: + """File secret settings. + + This function runs at **instance creation**, not class definition. Ensures if ConfigEOS + is recreated this function is run. + """ + if not cls._init_config_eos.get("with_file_secret_settings", True): + logger.debug("Config initialisation with file secret settings is disabled.") + return {} + + return file_secret_settings() # All the settings sources in priority sequence + # The settings are all lazyly evaluated at instance creation time to allow for + # runtime configuration. 
setting_sources = [ - init_settings, - env_settings, - dotenv_settings, + lazy_config_file_settings, # Prio high + lazy_init_settings, + lazy_env_settings, + lazy_dotenv_settings, + lazy_file_settings, + lazy_data_folder_path_settings, + lazy_file_secret_settings, # Prio low ] - # Append file settings to sources - file_settings: Optional[pydantic_settings.JsonConfigSettingsSource] = None - try: - backup_file = config_file.with_suffix(f".{to_datetime(as_string='YYYYMMDDHHmmss')}") - if migrate_config_file(config_file, backup_file): - # If the config file does have the correct version add it as settings source - file_settings = pydantic_settings.JsonConfigSettingsSource( - settings_cls, json_file=config_file - ) - setting_sources.append(file_settings) - except Exception as ex: - logger.error( - f"Error reading config file '{config_file}' (falling back to default config): {ex}" - ) - return tuple(setting_sources) @classproperty @@ -409,30 +590,41 @@ class ConfigEOS(SingletonMixin, SettingsEOSDefaults): """Compute the package root path.""" return Path(__file__).parent.parent.resolve() + @classmethod + def documentation_mode(cls) -> bool: + """Are we running in documentation mode. + + Some checks may be relaxed to allow for proper documentation execution. + """ + # Detect if Sphinx is importing this module + is_sphinx = "sphinx" in sys.modules or getattr(sys, "_called_from_sphinx", False) + return cls._force_documentation_mode or is_sphinx + def __init__(self, *args: Any, **kwargs: Any) -> None: """Initializes the singleton ConfigEOS instance. Configuration data is loaded from a configuration file or a default one is created if none exists. 
""" - logger.debug("Config init with parameters {} {}", args, kwargs) # Check for singleton guard if hasattr(self, "_initialized"): + logger.debug("Config init called again with parameters {} {}", args, kwargs) return + logger.debug("Config init with parameters {} {}", args, kwargs) self._setup(self, *args, **kwargs) def _setup(self, *args: Any, **kwargs: Any) -> None: """Re-initialize global settings.""" logger.debug("Config setup with parameters {} {}", args, kwargs) + # Assure settings base knows the singleton EOS configuration SettingsBaseModel.config = self + # (Re-)load settings - call base class init SettingsEOSDefaults.__init__(self, *args, **kwargs) - # Init config file and data folder pathes - self._setup_config_file() - self._update_data_folder_path() + self._initialized = True - logger.debug("Config setup:\n{}", self) + logger.debug(f"Config setup:\n{self}") def merge_settings(self, settings: SettingsEOS) -> None: """Merges the provided settings into the global settings for EOS, with optional overwrite. 
@@ -562,48 +754,6 @@ class ConfigEOS(SingletonMixin, SettingsEOSDefaults): return result - def _update_data_folder_path(self) -> None: - """Updates path to the data directory.""" - # From Settings - if data_dir := self.general.data_folder_path: - try: - data_dir.mkdir(parents=True, exist_ok=True) - self.general.data_folder_path = data_dir - return - except Exception as e: - logger.warning(f"Could not setup data dir {data_dir}: {e}") - # From EOS_DATA_DIR env - if env_dir := os.getenv(self.EOS_DATA_DIR): - try: - data_dir = Path(env_dir).resolve() - data_dir.mkdir(parents=True, exist_ok=True) - self.general.data_folder_path = data_dir - return - except Exception as e: - logger.warning(f"Could not setup data dir {data_dir}: {e}") - # From EOS_DIR env - if env_dir := os.getenv(self.EOS_DIR): - try: - data_dir = Path(env_dir).resolve() - data_dir.mkdir(parents=True, exist_ok=True) - self.general.data_folder_path = data_dir - return - except Exception as e: - logger.warning(f"Could not setup data dir {data_dir}: {e}") - # From platform specific default path - try: - data_dir = Path(user_data_dir(self.APP_NAME, self.APP_AUTHOR)) - if data_dir is not None: - data_dir.mkdir(parents=True, exist_ok=True) - self.general.data_folder_path = data_dir - return - except Exception as e: - logger.warning(f"Could not setup data dir {data_dir}: {e}") - # Current working directory - data_dir = Path.cwd() - logger.warning(f"Using data dir {data_dir}") - self.general.data_folder_path = data_dir - @classmethod def _get_config_file_path(cls) -> tuple[Path, bool]: """Find a valid configuration file or return the desired path for a new config file. 
@@ -618,32 +768,80 @@ class ConfigEOS(SingletonMixin, SettingsEOSDefaults): Returns: tuple[Path, bool]: The path to the configuration file and if there is already a config file there """ - if GeneralSettings._home_assistant_addon: + if is_home_assistant_addon(): # Only /data is persistent for home assistant add-on cfile = Path("/data/config") / cls.CONFIG_FILE_NAME logger.debug(f"Config file forced to: '{cfile}'") return cfile, cfile.exists() config_dirs = [] - env_eos_dir = os.getenv(cls.EOS_DIR) - logger.debug(f"Environment EOS_DIR: '{env_eos_dir}'") - env_eos_config_dir = os.getenv(cls.EOS_CONFIG_DIR) - logger.debug(f"Environment EOS_CONFIG_DIR: '{env_eos_config_dir}'") - env_config_dir = get_absolute_path(env_eos_dir, env_eos_config_dir) - logger.debug(f"Resulting environment config dir: '{env_config_dir}'") + # 1. Directory specified by EOS_CONFIG_DIR + config_dir: Optional[Union[Path, str]] = os.getenv(cls.EOS_CONFIG_DIR) + if config_dir: + logger.debug(f"Environment EOS_CONFIG_DIR: '{config_dir}'") + config_dir = Path(config_dir).resolve() + if config_dir.exists(): + config_dirs.append(config_dir) + else: + logger.info(f"Environment EOS_CONFIG_DIR: '{config_dir}' does not exist.") - if env_config_dir is not None: - config_dirs.append(env_config_dir.resolve()) - config_dirs.append(Path(user_config_dir(cls.APP_NAME, cls.APP_AUTHOR))) - config_dirs.append(Path.cwd()) + # 2. Directory specified by EOS_DIR / EOS_CONFIG_DIR + eos_dir = os.getenv(cls.EOS_DIR) + eos_config_dir = os.getenv(cls.EOS_CONFIG_DIR) + if eos_dir and eos_config_dir: + logger.debug(f"Environment EOS_DIR/EOS_CONFIG_DIR: '{eos_dir}/{eos_config_dir}'") + config_dir = get_absolute_path(eos_dir, eos_config_dir) + if config_dir: + config_dir = Path(config_dir).resolve() + if config_dir.exists(): + config_dirs.append(config_dir) + else: + logger.info( + f"Environment EOS_DIR/EOS_CONFIG_DIR: '{config_dir}' does not exist." 
+ ) + else: + logger.debug( + f"Environment EOS_DIR/EOS_CONFIG_DIR: '{eos_dir}/{eos_config_dir}' not a valid path" + ) + + # 3. Directory specified by EOS_DIR + config_dir = os.getenv(cls.EOS_DIR) + if config_dir: + logger.debug(f"Environment EOS_DIR: '{config_dir}'") + config_dir = Path(config_dir).resolve() + if config_dir.exists(): + config_dirs.append(config_dir) + else: + logger.info(f"Environment EOS_DIR: '{config_dir}' does not exist.") + + # 4. User configuration directory + config_dir = Path(user_config_dir(cls.APP_NAME, cls.APP_AUTHOR)).resolve() + logger.debug(f"User config dir: '{config_dir}'") + if config_dir.exists(): + config_dirs.append(config_dir) + else: + logger.info(f"User config dir: '{config_dir}' does not exist.") + + # 5. Current working directory + config_dir = Path.cwd() + logger.debug(f"Current working dir: '{config_dir}'") + if config_dir.exists(): + config_dirs.append(config_dir) + else: + logger.info(f"Current working dir: '{config_dir}' does not exist.") + + # Search for file for cdir in config_dirs: cfile = cdir.joinpath(cls.CONFIG_FILE_NAME) if cfile.exists(): logger.debug(f"Found config file: '{cfile}'") return cfile, True - return config_dirs[0].joinpath(cls.CONFIG_FILE_NAME), False + # Return highest priority directory with standard file name appended + default_config_file = config_dirs[0].joinpath(cls.CONFIG_FILE_NAME) + logger.debug(f"No config file found. Defaulting to: '{default_config_file}'") + return default_config_file, False @classmethod def _setup_config_file(cls) -> Path: @@ -714,8 +912,3 @@ class ConfigEOS(SingletonMixin, SettingsEOSDefaults): The first non None value in priority order is taken. 
""" self._setup(**self.model_dump()) - - -def get_config() -> ConfigEOS: - """Gets the EOS configuration data.""" - return ConfigEOS() diff --git a/src/akkudoktoreos/config/configmigrate.py b/src/akkudoktoreos/config/configmigrate.py index 9ad075f..4bbdec5 100644 --- a/src/akkudoktoreos/config/configmigrate.py +++ b/src/akkudoktoreos/config/configmigrate.py @@ -3,7 +3,7 @@ import json import shutil from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Set, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Set, Tuple, Union, cast from loguru import logger @@ -13,19 +13,33 @@ if TYPE_CHECKING: # There are circular dependencies - only import here for type checking from akkudoktoreos.config.config import SettingsEOSDefaults + +_KEEP_DEFAULT = object() + # ----------------------------- # Global migration map constant # ----------------------------- # key: old JSON path, value: either # - str (new model path) # - tuple[str, Callable[[Any], Any]] (new path + transform) +# - _KEEP_DEFAULT (keep new default if old value is none or not given) # - None (drop) -MIGRATION_MAP: Dict[str, Union[str, Tuple[str, Callable[[Any], Any]], None]] = { +MIGRATION_MAP: Dict[ + str, + Union[ + str, # simple rename + Tuple[str, Callable[[Any], Any]], # rename + transform + Tuple[str, object], # rename + _KEEP_DEFAULT + Tuple[str, object, Callable[[Any], Any]], # rename + _KEEP_DEFAULT + transform + None, # drop + ], +] = { # 0.2.0.dev -> 0.2.0.dev "adapter/homeassistant/optimization_solution_entity_ids": ( "adapter/homeassistant/solution_entity_ids", lambda v: v if isinstance(v, list) else None, ), + "general/data_folder_path": ("general/data_folder_path", _KEEP_DEFAULT), # 0.2.0 -> 0.2.0+dev "elecprice/provider_settings/ElecPriceImport/import_file_path": "elecprice/elecpriceimport/import_file_path", "elecprice/provider_settings/ElecPriceImport/import_json": "elecprice/elecpriceimport/import_json", @@ -91,20 +105,32 @@ def 
migrate_config_data(config_data: Dict[str, Any]) -> "SettingsEOSDefaults": for old_path, mapping in MIGRATION_MAP.items(): new_path = None transform = None + keep_default = False + if mapping is None: migrated_source_paths.add(old_path.strip("/")) logger.debug(f"🗑️ Migration map: dropping '{old_path}'") continue if isinstance(mapping, tuple): - new_path, transform = mapping + new_path = mapping[0] + for m in mapping[1:]: + if m is _KEEP_DEFAULT: + keep_default = True + elif callable(m): + transform = cast(Callable[[Any], Any], m) else: new_path = mapping old_value = _get_json_nested_value(config_data, old_path) if old_value is None: - migrated_source_paths.add(old_path.strip("/")) - mapped_count += 1 - logger.debug(f"✅ Migrated mapped '{old_path}' → 'None'") + if keep_default: + migrated_source_paths.add(old_path.strip("/")) + mapped_count += 1 + logger.debug(f"✅ Migrated mapped '{old_path}' → keeping new default") + else: + migrated_source_paths.add(old_path.strip("/")) + mapped_count += 1 + logger.debug(f"✅ Migrated mapped '{old_path}' → 'None'") continue try: diff --git a/src/akkudoktoreos/core/cache.py b/src/akkudoktoreos/core/cache.py index ec92c64..0dc3078 100644 --- a/src/akkudoktoreos/core/cache.py +++ b/src/akkudoktoreos/core/cache.py @@ -13,6 +13,7 @@ import os import pickle import tempfile import threading +from pathlib import Path from typing import ( IO, Any, @@ -236,6 +237,24 @@ Param = ParamSpec("Param") RetType = TypeVar("RetType") +def cache_clear(clear_all: Optional[bool] = None) -> None: + """Cleanup expired cache files.""" + if clear_all: + CacheFileStore().clear(clear_all=True) + else: + CacheFileStore().clear(before_datetime=to_datetime()) + + +def cache_load() -> dict: + """Load cache from cachefilestore.json.""" + return CacheFileStore().load_store() + + +def cache_save() -> dict: + """Save cache to cachefilestore.json.""" + return CacheFileStore().save_store() + + class CacheFileRecord(PydanticBaseModel): cache_file: Any = Field( ..., 
json_schema_extra={"description": "File descriptor of the cache file."} @@ -284,9 +303,16 @@ class CacheFileStore(ConfigMixin, SingletonMixin): return self._store: Dict[str, CacheFileRecord] = {} self._store_lock = threading.RLock() - self._store_file = self.config.cache.path().joinpath("cachefilestore.json") super().__init__(*args, **kwargs) + def _store_file(self) -> Optional[Path]: + """Get file to store the cache.""" + try: + return self.config.cache.path().joinpath("cachefilestore.json") + except Exception: + logger.error("Path for cache files missing. Please configure!") + return None + def _until_datetime_by_options( self, until_date: Optional[Any] = None, @@ -496,10 +522,18 @@ class CacheFileStore(ConfigMixin, SingletonMixin): # File already available cache_file_obj = cache_item.cache_file else: - self.config.cache.path().mkdir(parents=True, exist_ok=True) - cache_file_obj = tempfile.NamedTemporaryFile( - mode=mode, delete=delete, suffix=suffix, dir=self.config.cache.path() - ) + # Create cache file + store_file = self._store_file() + if store_file: + store_file.parent.mkdir(parents=True, exist_ok=True) + cache_file_obj = tempfile.NamedTemporaryFile( + mode=mode, delete=delete, suffix=suffix, dir=store_file.parent + ) + else: + # Cache storage not configured, use temporary path + cache_file_obj = tempfile.NamedTemporaryFile( + mode=mode, delete=delete, suffix=suffix + ) self._store[cache_file_key] = CacheFileRecord( cache_file=cache_file_obj, until_datetime=until_datetime_dt, @@ -766,10 +800,14 @@ class CacheFileStore(ConfigMixin, SingletonMixin): Returns: data (dict): cache management data that was saved. 
""" + store_file = self._store_file() + if store_file is None: + return {} + with self._store_lock: - self._store_file.parent.mkdir(parents=True, exist_ok=True) + store_file.parent.mkdir(parents=True, exist_ok=True) store_to_save = self.current_store() - with self._store_file.open("w", encoding="utf-8", newline="\n") as f: + with store_file.open("w", encoding="utf-8", newline="\n") as f: try: json.dump(store_to_save, f, indent=4) except Exception as e: @@ -782,18 +820,22 @@ class CacheFileStore(ConfigMixin, SingletonMixin): Returns: data (dict): cache management data that was loaded. """ + store_file = self._store_file() + if store_file is None: + return {} + with self._store_lock: store_loaded = {} - if self._store_file.exists(): - with self._store_file.open("r", encoding="utf-8", newline=None) as f: + if store_file.exists(): + with store_file.open("r", encoding="utf-8", newline=None) as f: try: store_to_load = json.load(f) except Exception as e: logger.error( f"Error loading cache file store: {e}\n" - + f"Deleting the store file {self._store_file}." + + f"Deleting the store file {store_file}." ) - self._store_file.unlink() + store_file.unlink() return {} for key, record in store_to_load.items(): if record is None: diff --git a/src/akkudoktoreos/core/cachesettings.py b/src/akkudoktoreos/core/cachesettings.py index 56bc5c3..2ed55f4 100644 --- a/src/akkudoktoreos/core/cachesettings.py +++ b/src/akkudoktoreos/core/cachesettings.py @@ -20,7 +20,8 @@ class CacheCommonSettings(SettingsBaseModel): ) cleanup_interval: float = Field( - default=5 * 60, + default=5.0 * 60, + ge=5.0, json_schema_extra={"description": "Intervall in seconds for EOS file cache cleanup."}, ) diff --git a/src/akkudoktoreos/core/coreabc.py b/src/akkudoktoreos/core/coreabc.py index 31b4b45..514670b 100644 --- a/src/akkudoktoreos/core/coreabc.py +++ b/src/akkudoktoreos/core/coreabc.py @@ -1,28 +1,76 @@ """Abstract and base classes for EOS core. 
-This module provides foundational classes for handling configuration and prediction functionality
-in EOS. It includes base classes that provide convenient access to global
-configuration and prediction instances through properties.
-
-Classes:
-    - ConfigMixin: Mixin class for managing and accessing global configuration.
-    - PredictionMixin: Mixin class for managing and accessing global prediction data.
-    - SingletonMixin: Mixin class to create singletons.
+This module provides foundational classes and functions to access global EOS resources.
 """
+from __future__ import (
+    annotations,  # use types lazy as strings, helps to prevent circular dependencies
+)
+
 import threading
-from typing import Any, ClassVar, Dict, Optional, Type
+from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, Type, Union
 
 from loguru import logger
 
 from akkudoktoreos.core.decorators import classproperty
 from akkudoktoreos.utils.datetimeutil import DateTime
 
-adapter_eos: Any = None
-config_eos: Any = None
-measurement_eos: Any = None
-prediction_eos: Any = None
-ems_eos: Any = None
+if TYPE_CHECKING:
+    # Prevents circular dependencies
+    from akkudoktoreos.adapter.adapter import Adapter
+    from akkudoktoreos.config.config import ConfigEOS
+    from akkudoktoreos.core.database import Database
+    from akkudoktoreos.core.ems import EnergyManagement
+    from akkudoktoreos.devices.devices import ResourceRegistry
+    from akkudoktoreos.measurement.measurement import Measurement
+    from akkudoktoreos.prediction.prediction import Prediction
+
+
+# Module level singleton cache
+_adapter_eos: Optional[Adapter] = None
+_config_eos: Optional[ConfigEOS] = None
+_ems_eos: Optional[EnergyManagement] = None
+_database_eos: Optional[Database] = None
+_measurement_eos: Optional[Measurement] = None
+_prediction_eos: Optional[Prediction] = None
+_resource_registry_eos: Optional[ResourceRegistry] = None
+
+
+def get_adapter(init: bool = False) -> Adapter:
+    """Retrieve the singleton EOS Adapter instance. 
+ + This function provides access to the global EOS Adapter instance. The Adapter + object is created on first access if `init` is True. If the instance is + accessed before initialization and `init` is False, a RuntimeError is raised. + + Args: + init (bool): If True, create the Adapter instance if it does not exist. + Default is False. + + Returns: + Adapter: The global EOS Adapter instance. + + Raises: + RuntimeError: If accessed before initialization with `init=False`. + + Usage: + .. code-block:: python + + adapter = get_adapter(init=True) # Initialize and retrieve + adapter.do_something() + """ + global _adapter_eos + if _adapter_eos is None: + from akkudoktoreos.config.config import ConfigEOS + + if not init and not ConfigEOS.documentation_mode(): + raise RuntimeError("Adapter access before init.") + + from akkudoktoreos.adapter.adapter import Adapter + + _adapter_eos = Adapter() + + return _adapter_eos class AdapterMixin: @@ -49,20 +97,84 @@ class AdapterMixin: """ @classproperty - def adapter(cls) -> Any: + def adapter(cls) -> Adapter: """Convenience class method/ attribute to retrieve the EOS adapters. Returns: Adapter: The adapters. """ - # avoid circular dependency at import time - global adapter_eos - if adapter_eos is None: - from akkudoktoreos.adapter.adapter import get_adapter + return get_adapter() - adapter_eos = get_adapter() - return adapter_eos +def get_config(init: Union[bool, dict[str, bool]] = False) -> ConfigEOS: + """Retrieve the singleton EOS configuration instance. + + This function provides controlled access to the global EOS configuration + singleton (`ConfigEOS`). The configuration is created lazily on first + access and can be initialized with a configurable set of settings sources. + + By default, accessing the configuration without prior initialization + raises a `RuntimeError`. Passing `init=True` or an initialization + configuration dictionary enables creation of the singleton. 
+ + Args: + init (Union[bool, dict[str, bool]]): + Controls initialization of the configuration. + + - ``False`` (default): Do not initialize. Raises ``RuntimeError`` + if the configuration does not yet exist. + - ``True``: Initialize the configuration using default + initialization behavior (all settings sources enabled). + - ``dict[str, bool]``: Initialize the configuration with fine-grained + control over which settings sources are enabled. Missing keys + default to ``True``. + + Supported keys include: + - ``with_init_settings`` + - ``with_env_settings`` + - ``with_dotenv_settings`` + - ``with_file_settings`` + - ``with_file_secret_settings`` + + Returns: + ConfigEOS: The global EOS configuration singleton instance. + + Raises: + RuntimeError: + If the configuration has not been initialized and ``init`` is + ``False``. + + Usage: + .. code-block:: python + + # Initialize with default behavior (all sources enabled) + config = get_config(init=True) + + # Initialize with explicit source control + config = get_config(init={ + "with_init_settings": True, + "with_env_settings": True, + "with_dotenv_settings": True, + "with_file_settings": False, + "with_file_secret_settings": False, + }) + + # Access existing configuration + host = get_config().server.host + """ + global _config_eos + if _config_eos is None: + from akkudoktoreos.config.config import ConfigEOS + + if not init and not ConfigEOS.documentation_mode(): + raise RuntimeError("Config access before init.") + + if isinstance(init, dict): + ConfigEOS._init_config_eos = init + + _config_eos = ConfigEOS() + + return _config_eos class ConfigMixin: @@ -89,20 +201,51 @@ class ConfigMixin: """ @classproperty - def config(cls) -> Any: + def config(cls) -> ConfigEOS: """Convenience class method/ attribute to retrieve the EOS configuration data. Returns: ConfigEOS: The configuration. 
""" - # avoid circular dependency at import time - global config_eos - if config_eos is None: - from akkudoktoreos.config.config import get_config + return get_config() - config_eos = get_config() - return config_eos +def get_measurement(init: bool = False) -> Measurement: + """Retrieve the singleton EOS Measurement instance. + + This function provides access to the global EOS Measurement object. The + Measurement instance is created on first access if `init` is True. If the + instance is accessed before initialization and `init` is False, a RuntimeError + is raised. + + Args: + init (bool): If True, create the Measurement instance if it does not exist. + Default is False. + + Returns: + Measurement: The global EOS Measurement instance. + + Raises: + RuntimeError: If accessed before initialization with `init=False`. + + Usage: + .. code-block:: python + + measurement = get_measurement(init=True) # Initialize and retrieve + measurement.read_sensor_data() + """ + global _measurement_eos + if _measurement_eos is None: + from akkudoktoreos.config.config import ConfigEOS + + if not init and not ConfigEOS.documentation_mode(): + raise RuntimeError("Measurement access before init.") + + from akkudoktoreos.measurement.measurement import Measurement + + _measurement_eos = Measurement() + + return _measurement_eos class MeasurementMixin: @@ -130,20 +273,51 @@ class MeasurementMixin: """ @classproperty - def measurement(cls) -> Any: + def measurement(cls) -> Measurement: """Convenience class method/ attribute to retrieve the EOS measurement data. Returns: Measurement: The measurement. """ - # avoid circular dependency at import time - global measurement_eos - if measurement_eos is None: - from akkudoktoreos.measurement.measurement import get_measurement + return get_measurement() - measurement_eos = get_measurement() - return measurement_eos +def get_prediction(init: bool = False) -> Prediction: + """Retrieve the singleton EOS Prediction instance. 
+ + This function provides access to the global EOS Prediction object. The + Prediction instance is created on first access if `init` is True. If the + instance is accessed before initialization and `init` is False, a RuntimeError + is raised. + + Args: + init (bool): If True, create the Prediction instance if it does not exist. + Default is False. + + Returns: + Prediction: The global EOS Prediction instance. + + Raises: + RuntimeError: If accessed before initialization with `init=False`. + + Usage: + .. code-block:: python + + prediction = get_prediction(init=True) # Initialize and retrieve + prediction.forecast_next_hour() + """ + global _prediction_eos + if _prediction_eos is None: + from akkudoktoreos.config.config import ConfigEOS + + if not init and not ConfigEOS.documentation_mode(): + raise RuntimeError("Prediction access before init.") + + from akkudoktoreos.prediction.prediction import Prediction + + _prediction_eos = Prediction() + + return _prediction_eos class PredictionMixin: @@ -171,20 +345,50 @@ class PredictionMixin: """ @classproperty - def prediction(cls) -> Any: + def prediction(cls) -> Prediction: """Convenience class method/ attribute to retrieve the EOS prediction data. Returns: Prediction: The prediction. """ - # avoid circular dependency at import time - global prediction_eos - if prediction_eos is None: - from akkudoktoreos.prediction.prediction import get_prediction + return get_prediction() - prediction_eos = get_prediction() - return prediction_eos +def get_ems(init: bool = False) -> EnergyManagement: + """Retrieve the singleton EOS Energy Management System (EMS) instance. + + This function provides access to the global EOS EMS instance. The instance + is created on first access if `init` is True. If the instance is accessed + before initialization and `init` is False, a RuntimeError is raised. + + Args: + init (bool): If True, create the EMS instance if it does not exist. + Default is False. 
+ + Returns: + EnergyManagement: The global EOS EMS instance. + + Raises: + RuntimeError: If accessed before initialization with `init=False`. + + Usage: + .. code-block:: python + + ems = get_ems(init=True) # Initialize and retrieve + ems.start_energy_management_loop() + """ + global _ems_eos + if _ems_eos is None: + from akkudoktoreos.config.config import ConfigEOS + + if not init and not ConfigEOS.documentation_mode(): + raise RuntimeError("EMS access before init.") + + from akkudoktoreos.core.ems import EnergyManagement + + _ems_eos = EnergyManagement() + + return _ems_eos class EnergyManagementSystemMixin: @@ -200,7 +404,7 @@ class EnergyManagementSystemMixin: global EnergyManagementSystem instance lazily to avoid import-time circular dependencies. Attributes: - ems (EnergyManagementSystem): Property to access the global EOS energy management system. + ems (EnergyManagement): Property to access the global EOS energy management system. Example: .. code-block:: python @@ -213,20 +417,120 @@ class EnergyManagementSystemMixin: """ @classproperty - def ems(cls) -> Any: + def ems(cls) -> EnergyManagement: """Convenience class method/ attribute to retrieve the EOS energy management system. Returns: EnergyManagementSystem: The energy management system. """ - # avoid circular dependency at import time - global ems_eos - if ems_eos is None: - from akkudoktoreos.core.ems import get_ems + return get_ems() - ems_eos = get_ems() - return ems_eos +def get_database(init: bool = False) -> Database: + """Retrieve the singleton EOS database instance. + + This function provides access to the global EOS Database instance. The + instance is created on first access if `init` is True. If the instance is + accessed before initialization and `init` is False, a RuntimeError is raised. + + Args: + init (bool): If True, create the Database instance if it does not exist. + Default is False. + + Returns: + Database: The global EOS database instance. 
+
+    Raises:
+        RuntimeError: If accessed before initialization with `init=False`.
+
+    Usage:
+        .. code-block:: python
+
+            db = get_database(init=True)  # Initialize and retrieve
+            db.insert_measurement(...)
+    """
+    global _database_eos
+    if _database_eos is None:
+        from akkudoktoreos.config.config import ConfigEOS
+
+        if not init and not ConfigEOS.documentation_mode():
+            raise RuntimeError("Database access before init.")
+
+        from akkudoktoreos.core.database import Database
+
+        _database_eos = Database()
+
+    return _database_eos
+
+
+class DatabaseMixin:
+    """Mixin class for managing EOS database access.
+
+    This class serves as a foundational component for EOS-related classes requiring access
+    to the EOS database. It provides a `database` property that dynamically retrieves
+    the database instance.
+
+    Usage:
+        Subclass this base class to gain access to the `database` attribute, which retrieves the
+        global database instance lazily to avoid import-time circular dependencies.
+
+    Attributes:
+        database (Database): Property to access the global EOS database.
+
+    Example:
+        .. code-block:: python
+
+            class MyOptimizationClass(DatabaseMixin):
+                def store_something(self):
+                    db = self.database
+
+    """
+
+    @classproperty
+    def database(cls) -> Database:
+        """Convenience class method/ attribute to retrieve the EOS database.
+
+        Returns:
+            Database: The database.
+        """
+        return get_database()
+
+
+def get_resource_registry(init: bool = False) -> ResourceRegistry:
+    """Retrieve the singleton EOS Resource Registry instance.
+
+    This function provides access to the global EOS ResourceRegistry instance.
+    The instance is created on first access if `init` is True. If the instance
+    is accessed before initialization and `init` is False, a RuntimeError is raised.
+
+    Args:
+        init (bool): If True, create the ResourceRegistry instance if it does not exist.
+            Default is False.
+
+    Returns:
+        ResourceRegistry: The global EOS Resource Registry instance. 
+ + Raises: + RuntimeError: If accessed before initialization with `init=False`. + + Usage: + .. code-block:: python + + registry = get_resource_registry(init=True) # Initialize and retrieve + registry.register_device(my_device) + """ + global _resource_registry_eos + if _resource_registry_eos is None: + from akkudoktoreos.config.config import ConfigEOS + + if not init and not ConfigEOS.documentation_mode(): + raise RuntimeError("ResourceRegistry access before init.") + + from akkudoktoreos.devices.devices import ResourceRegistry + + _resource_registry_eos = ResourceRegistry() + + return _resource_registry_eos class StartMixin(EnergyManagementSystemMixin): @@ -243,14 +547,7 @@ class StartMixin(EnergyManagementSystemMixin): Returns: DateTime: The starting datetime of the current or latest energy management, or None. """ - # avoid circular dependency at import time - global ems_eos - if ems_eos is None: - from akkudoktoreos.core.ems import get_ems - - ems_eos = get_ems() - - return ems_eos.start_datetime + return get_ems().start_datetime class SingletonMixin: @@ -332,3 +629,43 @@ class SingletonMixin: if not hasattr(self, "_initialized"): super().__init__(*args, **kwargs) self._initialized = True + + +_singletons_init_running: bool = False + + +def singletons_init() -> None: + """Initialize the singletons for adapter, config, measurement, prediction, database, resource registry.""" + # Prevent recursive calling + global \ + _singletons_init_running, \ + _adapter_eos, \ + _config_eos, \ + _database_eos, \ + _measurement_eos, \ + _prediction_eos, \ + _ems_eos, \ + _resource_registry_eos + + if _singletons_init_running: + return + + _singletons_init_running = True + + try: + if _config_eos is None: + get_config(init=True) + if _adapter_eos is None: + get_adapter(init=True) + if _database_eos is None: + get_database(init=True) + if _ems_eos is None: + get_ems(init=True) + if _measurement_eos is None: + get_measurement(init=True) + if _prediction_eos is None: + 
get_prediction(init=True) + if _resource_registry_eos is None: + get_resource_registry(init=True) + finally: + _singletons_init_running = False diff --git a/src/akkudoktoreos/core/dataabc.py b/src/akkudoktoreos/core/dataabc.py index f4217b0..c035076 100644 --- a/src/akkudoktoreos/core/dataabc.py +++ b/src/akkudoktoreos/core/dataabc.py @@ -11,24 +11,24 @@ and manipulation of configuration and generic data in a clear, scalable, and str import difflib import json from abc import abstractmethod -from collections.abc import MutableMapping, MutableSequence +from collections.abc import KeysView, MutableMapping from itertools import chain from pathlib import Path from typing import ( Any, Dict, Iterator, - List, + Literal, Optional, Tuple, Type, Union, + get_args, overload, ) import numpy as np import pandas as pd -import pendulum from loguru import logger from numpydantic import NDArray, Shape from pydantic import ( @@ -41,7 +41,17 @@ from pydantic import ( model_validator, ) -from akkudoktoreos.core.coreabc import ConfigMixin, SingletonMixin, StartMixin +from akkudoktoreos.core.coreabc import ( + ConfigMixin, + SingletonMixin, + StartMixin, +) +from akkudoktoreos.core.databaseabc import ( + UNBOUND_WINDOW, + DatabaseRecordProtocolMixin, + DatabaseTimestamp, + DatabaseTimeWindowType, +) from akkudoktoreos.core.pydantic import ( PydanticBaseModel, PydanticDateTimeData, @@ -56,7 +66,7 @@ from akkudoktoreos.utils.datetimeutil import ( ) -class DataBase(ConfigMixin, StartMixin, PydanticBaseModel): +class DataABC(ConfigMixin, StartMixin, PydanticBaseModel): """Base class for handling generic data. Enables access to EOS configuration data (attribute `config`). 
@@ -65,7 +75,10 @@ class DataBase(ConfigMixin, StartMixin, PydanticBaseModel): pass -class DataRecord(DataBase, MutableMapping): +# ==================== DataRecord ==================== + + +class DataRecord(DataABC, MutableMapping): """Base class for data records, enabling dynamic access to fields defined in derived classes. Fields can be accessed and mutated both using dictionary-style access (`record['field_name']`) @@ -77,7 +90,8 @@ class DataRecord(DataBase, MutableMapping): dictionary-style and attribute-style access. Attributes: - date_time (Optional[DateTime]): Aware datetime indicating when the data record applies. + date_time (DateTime): Aware datetime indicating when the data record applies. Defaults + to now. Configurations: - Allows mutation after creation. @@ -145,7 +159,7 @@ class DataRecord(DataBase, MutableMapping): return None @classmethod - def record_keys(cls) -> List[str]: + def record_keys(cls) -> list[str]: """Returns the keys of all fields in the data record.""" key_list = [] key_list.extend(list(cls.model_fields.keys())) @@ -158,7 +172,7 @@ class DataRecord(DataBase, MutableMapping): return key_list @classmethod - def record_keys_writable(cls) -> List[str]: + def record_keys_writable(cls) -> list[str]: """Returns the keys of all fields in the data record that are writable.""" keys_writable = [] keys_writable.extend(list(cls.model_fields.keys())) @@ -394,18 +408,18 @@ class DataRecord(DataBase, MutableMapping): @classmethod def keys_from_descriptions( - cls, descriptions: List[str], threshold: float = 0.8 - ) -> List[Optional[str]]: + cls, descriptions: list[str], threshold: float = 0.8 + ) -> list[Optional[str]]: """Returns a list of attribute keys that best matches the provided list of descriptions. Fuzzy matching is used. Args: - descriptions (List[str]): A list of description texts to search for. + descriptions (list[str]): A list of description texts to search for. threshold (float): The minimum ratio for a match (0-1). Default is 0.8. 
        Returns:
-            List[Optional[str]]: A list of attribute keys matching the descriptions, with None for unmatched descriptions.
+            list[Optional[str]]: A list of attribute keys matching the descriptions, with None for unmatched descriptions.
         """
         keys = []
         for description in descriptions:
@@ -414,20 +428,32 @@ class DataRecord(DataBase, MutableMapping):
         return keys
 
 
-class DataSequence(DataBase, MutableSequence):
-    """A managed sequence of DataRecord instances with list-like behavior.
+# ==================== DataSequence ====================
+
+
+class DataSequence(DataABC, DatabaseRecordProtocolMixin[DataRecord]):
+    """A managed sequence of DataRecord instances with time series behavior.
 
     The DataSequence class provides an ordered, mutable collection of DataRecord
-    instances, allowing list-style access for adding, deleting, and retrieving records. It also
-    supports advanced data operations such as JSON serialization, conversion to Pandas Series,
-    and sorting by timestamp.
+    instances.
+
+    It also supports advanced data operations such as
+
+    - JSON serialization,
+    - conversion to Pandas Series,
+    - sorting by timestamp,
+    - and data storage in a database.
 
     Attributes:
-        records (List[DataRecord]): A list of DataRecord instances representing
+        records (list[DataRecord]): A list of DataRecord instances representing
             individual generic data points.
-        record_keys (Optional[List[str]]): A list of field names (keys) expected in each
+        record_keys (Optional[list[str]]): A list of field names (keys) expected in each
             DataRecord.
 
+    Invariant:
+        ``self.records`` is always kept sorted in ascending ``date_time`` order
+        whenever it contains any records.
+
     Note:
         Derived classes have to provide their own records field with correct record type set. 
@@ -436,7 +462,7 @@ class DataSequence(DataBase, MutableSequence): # Example of creating, adding, and using DataSequence class DerivedSequence(DataSquence): - records: List[DerivedDataRecord] = Field(default_factory=list, json_schema_extra={ "description": "List of data records" }) + records: list[DerivedDataRecord] = Field(default_factory=list, json_schema_extra={ "description": "List of data records" }) seq = DerivedSequence() seq.insert(DerivedDataRecord(date_time=datetime.now(), temperature=72)) @@ -452,86 +478,11 @@ class DataSequence(DataBase, MutableSequence): """ # To be overloaded by derived classes. - records: List[DataRecord] = Field( + records: list[DataRecord] = Field( default_factory=list, json_schema_extra={"description": "List of data records"} ) - # Derived fields (computed) - @computed_field # type: ignore[prop-decorator] - @property - def min_datetime(self) -> Optional[DateTime]: - """Minimum (earliest) datetime in the sorted sequence of data records. - - This property computes the earliest datetime from the sequence of data records. - If no records are present, it returns `None`. - - Returns: - Optional[DateTime]: The earliest datetime in the sequence, or `None` if no - data records exist. - """ - if len(self.records) == 0: - return None - return self.records[0].date_time - - @computed_field # type: ignore[prop-decorator] - @property - def max_datetime(self) -> DateTime: - """Maximum (latest) datetime in the sorted sequence of data records. - - This property computes the latest datetime from the sequence of data records. - If no records are present, it returns `None`. - - Returns: - Optional[DateTime]: The latest datetime in the sequence, or `None` if no - data records exist. 
- """ - if len(self.records) == 0: - return None - return self.records[-1].date_time - - @computed_field # type: ignore[prop-decorator] - @property - def record_keys(self) -> List[str]: - """Returns the keys of all fields in the data records.""" - return self.record_class().record_keys() - - @computed_field # type: ignore[prop-decorator] - @property - def record_keys_writable(self) -> List[str]: - """Get the keys of all writable fields in the data records. - - This property retrieves the keys of all fields in the data records that - can be written to. It uses the `record_class` to determine the model's - field structure. - - Returns: - List[str]: A list of field keys that are writable in the data records. - """ - return self.record_class().record_keys_writable() - - @classmethod - def record_class(cls) -> Type: - """Get the class of the data record handled by this data sequence. - - This method determines the class of the data record type associated with - the `records` field of the model. The field is expected to be a list, and - the element type of the list should be a subclass of `DataRecord`. - - Raises: - ValueError: If the record type is not a subclass of `DataRecord`. - - Returns: - Type: The class of the data record handled by the data sequence. - """ - # Access the model field metadata - field_info = cls.model_fields["records"] - # Get the list element type from the 'type_' attribute - list_element_type = field_info.annotation.__args__[0] - if not isinstance(list_element_type(), DataRecord): - raise ValueError( - f"Data record must be an instance of DataRecord: '{list_element_type}'." - ) - return list_element_type + # Sequence helpers def _validate_key(self, key: str) -> None: """Verify that a specified key exists in the current record keys. 
@@ -576,93 +527,124 @@ class DataSequence(DataBase, MutableSequence): # Assure datetime value can be converted to datetime object value.date_time = to_datetime(value.date_time) - @overload - def __getitem__(self, index: int) -> DataRecord: ... + # Sequence state - @overload - def __getitem__(self, index: slice) -> list[DataRecord]: ... + # Derived fields (computed) + @computed_field # type: ignore[prop-decorator] + @property + def min_datetime(self) -> Optional[DateTime]: + """Minimum (earliest) datetime in the time series sequence of data records. - def __getitem__(self, index: Union[int, slice]) -> Union[DataRecord, list[DataRecord]]: - """Retrieve a DataRecord or list of DataRecords by index or slice. - - Supports both single item and slice-based access to the sequence. - - Args: - index (int or slice): The index or slice to access. + This property computes the earliest datetime from the sequence of data records. + If no records are present, it returns `None`. Returns: - DataRecord or list[DataRecord]: A single DataRecord or a list of DataRecords. + Optional[DateTime]: The earliest datetime in the sequence, or `None` if no + data records exist. + """ + min_timestamp, _ = self.db_timestamp_range() + if min_timestamp is None: + return None + # Timestamps are in UTC - convert to timezone + utc_datetime = DatabaseTimestamp.to_datetime(min_timestamp) + return utc_datetime.in_timezone(self.config.general.timezone) + + @computed_field # type: ignore[prop-decorator] + @property + def max_datetime(self) -> Optional[DateTime]: + """Maximum (latest) datetime in the time series sequence of data records. + + This property computes the latest datetime from the sequence of data records. + If no records are present, it returns `None`. + + Returns: + Optional[DateTime]: The latest datetime in the sequence, or `None` if no + data records exist. 
+ """ + _, max_timestamp = self.db_timestamp_range() + if max_timestamp is None: + return None + # Timestamps are in UTC - convert to timezone + utc_datetime = DatabaseTimestamp.to_datetime(max_timestamp) + return utc_datetime.in_timezone(self.config.general.timezone) + + @computed_field # type: ignore[prop-decorator] + @property + def record_keys(self) -> list[str]: + """Returns the keys of all fields in the data records.""" + return self.record_class().record_keys() + + @computed_field # type: ignore[prop-decorator] + @property + def record_keys_writable(self) -> list[str]: + """Get the keys of all writable fields in the data records. + + This property retrieves the keys of all fields in the data records that + can be written to. It uses the `record_class` to determine the model's + field structure. + + Returns: + list[str]: A list of field keys that are writable in the data records. + """ + return self.record_class().record_keys_writable() + + @classmethod + def record_class(cls) -> Type: + """Get the class of the data record handled by this data sequence. + + This method determines the class of the data record type associated with + the `records` field of the model. The field is expected to be a list, and + the element type of the list should be a subclass of `DataRecord`. Raises: - IndexError: If the index is invalid or out of range. + ValueError: If the record type is not a subclass of `DataRecord`. + + Returns: + Type: The class of the data record handled by the data sequence. 
""" - if isinstance(index, int): - # Single item access logic - return self.records[index] - elif isinstance(index, slice): - # Slice access logic - return self.records[index] - raise IndexError("Invalid index") + # Access the model field metadata + field_info = cls.model_fields["records"] + # Get the list element type from the 'type_' attribute + list_element_type = get_args(field_info.annotation)[0] + if not isinstance(list_element_type(), DataRecord): + raise ValueError( + f"Data record must be an instance of DataRecord: '{list_element_type}'." + ) + return list_element_type - def __setitem__(self, index: Any, value: Any) -> None: - """Replace a data record or slice of records with new value(s). + @classmethod + def from_dict(cls, data: dict) -> "DataSequence": + """Reconstruct a sequence from its serialized dictionary form. - Supports setting a single record at an integer index or - multiple records using a slice. - - Args: - index (int or slice): The index or slice to modify. - value (DataRecord or list[DataRecord]): - Single record or list of records to set. - - Raises: - ValueError: If the number of records does not match the slice length. - IndexError: If the index is out of range. + Fully subclass-safe and invariant-safe. """ - if isinstance(index, int): - if isinstance(value, list): - raise ValueError("Cannot assign list to single index") - self._validate_record(value) - self.records[index] = value - elif isinstance(index, slice): - if isinstance(value, DataRecord): - raise ValueError("Cannot assign single record to slice") - for record in value: - self._validate_record(record) - self.records[index] = value - else: - # Should never happen - raise TypeError("Invalid type for index") + if not isinstance(data, dict): + raise TypeError("from_dict() expects a dictionary") - def __delitem__(self, index: Any) -> None: - """Remove a single data record or a slice of records. 
+ records_data = data.get("records", []) + if not isinstance(records_data, list): + raise ValueError("'records' must be a list") - Supports deleting a single record by integer index - or multiple records using a slice. + # Create empty instance of *actual class* + sequence = cls() - Args: - index (int or slice): The index or slice to delete. + # Rebuild records using the sequence's record model + record_model = sequence.record_class() - Raises: - IndexError: If the index is out of range. - """ - del self.records[index] + for record_dict in records_data: + if not isinstance(record_dict, dict): + raise ValueError("Each record must be a dictionary") + + record = record_model(**record_dict) + + # Important: use insert_by_datetime to rebuild invariants + sequence.insert_by_datetime(record) + + return sequence def __len__(self) -> int: - """Get the number of DataRecords in the sequence. - - Returns: - int: The count of records in the sequence. - """ - return len(self.records) - - def __iter__(self) -> Iterator[DataRecord]: - """Create an iterator for accessing DataRecords sequentially. - - Returns: - Iterator[DataRecord]: An iterator for the records. - """ - return iter(self.records) + """Get total number of DataRecords in sequence (DB + memory-only).""" + return self.db_count_records() def __repr__(self) -> str: """Provide a string representation of the DataSequence. @@ -672,52 +654,88 @@ class DataSequence(DataBase, MutableSequence): """ return f"{self.__class__.__name__}([{', '.join(repr(record) for record in self.records)}])" - def insert(self, index: int, value: DataRecord) -> None: - """Insert a DataRecord at a specified index in the sequence. + # Sequence methods - This method inserts a `DataRecord` at the specified index within the sequence of records, - shifting subsequent records to the right. If `index` is 0, the record is added at the beginning - of the sequence, and if `index` is equal to the length of the sequence, the record is appended - at the end. 
+ def __iter__(self) -> Iterator[DataRecord]: + """Create an iterator for accessing DataRecords sequentially. + + Returns: + Iterator[DataRecord]: An iterator for the records. + """ + return iter(self.records) + + def get_by_datetime( + self, target_datetime: DateTime, *, time_window: Optional[Duration] = None + ) -> Optional[DataRecord]: + """Get the record at the specified datetime, with an optional fallback search window. Args: - index (int): The position before which to insert the new record. An index of 0 inserts - the record at the start, while an index equal to the length of the sequence - appends it to the end. - value (DataRecord): The `DataRecord` instance to insert into the sequence. + target_datetime: The datetime to search for. + time_window: Optional total width of the symmetric search window centered on + ``target_datetime``. If provided and no exact match exists, the nearest + record within this window is returned. + + Returns: + The matching DataRecord, the nearest DataRecord within the specified time window + if no exact match exists, or ``None`` if no suitable record is found. + """ + # Ensure datetime objects are normalized + db_target = DatabaseTimestamp.from_datetime(target_datetime) + + return self.db_get_record(db_target, time_window=time_window) + + def get_nearest_by_datetime( + self, target_datetime: DateTime, time_window: Optional[Duration] = None + ) -> Optional[DataRecord]: + """Get the record nearest to the specified datetime within an optional time window. + + Args: + target_datetime: The datetime to search near. + time_window: Total width of the symmetric search window centered on + ``target_datetime``. If ``None``, searches all records. + + Returns: + The nearest DataRecord within the specified time window, or ``None`` if no records + exist or no records fall within the window. Raises: - ValueError: If `value` is not an instance of `DataRecord`. + ValueError: If ``time_window`` is negative. 
""" - self.records.insert(index, value) + # Ensure datetime objects are normalized + db_target = DatabaseTimestamp.from_datetime(target_datetime) - def insert_by_datetime(self, value: DataRecord) -> None: + if time_window is None: + twin: DatabaseTimeWindowType = UNBOUND_WINDOW + else: + twin = time_window + return self.db_get_record(db_target, time_window=twin) + + def insert_by_datetime(self, record: DataRecord) -> None: """Insert or merge a DataRecord into the sequence based on its date. If a record with the same date exists, merges new data fields with the existing record. Otherwise, appends the record and maintains chronological order. Args: - value (DataRecord): The record to add or merge. + record (DataRecord): The record to add or merge. + + Note: + record.date_time shall be a DateTime or None """ - self._validate_record(value) - # Check if a record with the given date already exists - for record in self.records: - if not isinstance(record.date_time, DateTime): - raise ValueError( - f"Record date '{record.date_time}' is not a datetime, but a `{type(record.date_time).__name__}`." 
-                )
-            if compare_datetimes(record.date_time, value.date_time).equal:
-                # Merge values, only updating fields where data record has a non-None value
-                for field, val in value.model_dump(exclude_unset=True).items():
-                    if field in value.record_keys_writable():
-                        setattr(record, field, val)
-                break
+        self._validate_record(record)
+
+        # Ensure datetime objects are normalized
+        record_date_time_timestamp = DatabaseTimestamp.from_datetime(record.date_time)
+
+        avail_record = self.db_get_record(record_date_time_timestamp)
+        if avail_record:
+            # Merge values, only updating fields where data record has a non-None value
+            for field, val in record.model_dump(exclude_unset=True).items():
+                if field in record.record_keys_writable():
+                    setattr(avail_record, field, val)
+            self.db_mark_dirty_record(avail_record)
         else:
-            # Add data record if the date does not exist
-            self.records.append(value)
-        # Sort the list by datetime after adding/updating
-        self.sort_by_datetime()
+            self.db_insert_record(record)

     @overload
     def update_value(self, date: DateTime, key: str, value: Any) -> None: ...
@@ -762,41 +780,19 @@ class DataSequence(DataBase, MutableSequence):
         self._validate_key_writable(key)

         # Ensure datetime objects are normalized
-        date = to_datetime(date, to_maxtime=False)
+        db_target = DatabaseTimestamp.from_datetime(date)

         # Check if a record with the given date already exists
-        for record in self.records:
-            if not isinstance(record.date_time, DateTime):
-                raise ValueError(
-                    f"Record date '{record.date_time}' is not a datetime, but a `{type(record.date_time).__name__}`."
- ) - if compare_datetimes(record.date_time, date).equal: - # Update the DataRecord with all new values - for key, value in values.items(): - setattr(record, key, value) - break - else: + record = self.db_get_record(db_target) + if record is None: # Create a new record and append to the list - record = self.record_class()(date_time=date, **values) - self.records.append(record) - # Sort the list by datetime after adding/updating - self.sort_by_datetime() - - def to_datetimeindex(self) -> pd.DatetimeIndex: - """Generate a Pandas DatetimeIndex from the date_time fields of all records in the sequence. - - Returns: - pd.DatetimeIndex: An index of datetime values corresponding to each record's date_time attribute. - - Raises: - ValueError: If any record does not have a valid date_time attribute. - """ - date_times = [record.date_time for record in self.records if record.date_time is not None] - - if not date_times: - raise ValueError("No valid date_time values found in the records.") - - return pd.DatetimeIndex(date_times) + new_record = self.record_class()(date_time=date, **values) + self.db_insert_record(new_record) + else: + # Update the DataRecord with all new values + for key, value in values.items(): + setattr(record, key, value) + self.db_mark_dirty_record(record) def key_to_dict( self, @@ -824,36 +820,45 @@ class DataSequence(DataBase, MutableSequence): KeyError: If the specified key is not found in any of the DataRecords. 
""" self._validate_key(key) + # Ensure datetime objects are normalized - start_datetime = to_datetime(start_datetime, to_maxtime=False) if start_datetime else None - end_datetime = to_datetime(end_datetime, to_maxtime=False) if end_datetime else None + start_timestamp = ( + DatabaseTimestamp.from_datetime(start_datetime) if start_datetime else None + ) + end_timestamp = DatabaseTimestamp.from_datetime(end_datetime) if end_datetime else None # Create a dictionary to hold date_time and corresponding values if dropna is None: dropna = True filtered_data = {} - for record in self.records: + for record in self.db_iterate_records(start_timestamp, end_timestamp): if ( record.date_time is None or (dropna and getattr(record, key, None) is None) or (dropna and getattr(record, key, None) == float("nan")) ): continue - if ( - start_datetime is None or compare_datetimes(record.date_time, start_datetime).ge - ) and (end_datetime is None or compare_datetimes(record.date_time, end_datetime).lt): + record_date_time_timestamp = DatabaseTimestamp.from_datetime(record.date_time) + if (start_timestamp is None or record_date_time_timestamp >= start_timestamp) and ( + end_timestamp is None or record_date_time_timestamp < end_timestamp + ): filtered_data[to_datetime(record.date_time, as_string=True)] = getattr( record, key, None ) return filtered_data - def key_to_value(self, key: str, target_datetime: DateTime) -> Optional[float]: + def key_to_value( + self, key: str, target_datetime: DateTime, time_window: Optional[Duration] = None + ) -> Optional[float]: """Returns the value corresponding to the specified key that is nearest to the given datetime. Args: key (str): The key of the attribute in DataRecord to extract. - target_datetime (datetime): The datetime to search nearest to. + target_datetime (datetime): The datetime to search for. + time_window: Optional total width of the symmetric search window centered on + ``target_datetime``. 
If provided and no exact match exists, the nearest + record within this window is returned. Returns: Optional[float]: The value nearest to the given datetime, or None if no valid records are found. @@ -863,22 +868,12 @@ class DataSequence(DataBase, MutableSequence): """ self._validate_key(key) - # Filter out records with None or NaN values for the key - valid_records = [ - record - for record in self.records - if record.date_time is not None - and getattr(record, key, None) not in (None, float("nan")) - ] + # Ensure datetime objects are normalized + db_target = DatabaseTimestamp.from_datetime(to_datetime(target_datetime)) - if not valid_records: - return None + record = self.db_get_record(db_target, time_window=time_window) - # Find the record with datetime nearest to target_datetime - target = to_datetime(target_datetime) - nearest_record = min(valid_records, key=lambda r: abs(r.date_time - target)) - - return getattr(nearest_record, key, None) + return getattr(record, key, None) def key_to_lists( self, @@ -886,7 +881,7 @@ class DataSequence(DataBase, MutableSequence): start_datetime: Optional[DateTime] = None, end_datetime: Optional[DateTime] = None, dropna: Optional[bool] = None, - ) -> Tuple[List[DateTime], List[Optional[float]]]: + ) -> Tuple[list[DateTime], list[Optional[float]]]: """Extracts two lists from data records within an optional date range. The lists are: @@ -906,36 +901,42 @@ class DataSequence(DataBase, MutableSequence): KeyError: If the specified key is not found in any of the DataRecords. 
""" self._validate_key(key) + # Ensure datetime objects are normalized - start_datetime = to_datetime(start_datetime, to_maxtime=False) if start_datetime else None - end_datetime = to_datetime(end_datetime, to_maxtime=False) if end_datetime else None + start_timestamp = ( + DatabaseTimestamp.from_datetime(start_datetime) if start_datetime else None + ) + end_timestamp = DatabaseTimestamp.from_datetime(end_datetime) if end_datetime else None # Create two lists to hold date_time and corresponding values if dropna is None: dropna = True filtered_records = [] - for record in self.records: + for record in self.db_iterate_records(start_timestamp, end_timestamp): if ( record.date_time is None or (dropna and getattr(record, key, None) is None) or (dropna and getattr(record, key, None) == float("nan")) ): continue - if ( - start_datetime is None or compare_datetimes(record.date_time, start_datetime).ge - ) and (end_datetime is None or compare_datetimes(record.date_time, end_datetime).lt): + record_date_time_timestamp = DatabaseTimestamp.from_datetime(record.date_time) + if (start_timestamp is None or record_date_time_timestamp >= start_timestamp) and ( + end_timestamp is None or record_date_time_timestamp < end_timestamp + ): filtered_records.append(record) dates = [record.date_time for record in filtered_records] values = [getattr(record, key, None) for record in filtered_records] return dates, values - def key_from_lists(self, key: str, dates: List[DateTime], values: List[float]) -> None: + def key_from_lists(self, key: str, dates: list[DateTime], values: list[float]) -> None: """Update the DataSequence from lists of datetime and value elements. The dates list should represent the date_time of each DataRecord, and the values list should represent the corresponding data values for the specified key. + The list must be ordered starting with the oldest date. + Args: key (str): The field name in the DataRecord that corresponds to the values in the Series. 
dates: List of datetime elements. @@ -945,17 +946,17 @@ class DataSequence(DataBase, MutableSequence): for i, date_time in enumerate(dates): # Ensure datetime objects are normalized - date_time = to_datetime(date_time, to_maxtime=False) if date_time else None + db_target = DatabaseTimestamp.from_datetime(date_time) # Check if there's an existing record for this date_time - existing_record = next((r for r in self.records if r.date_time == date_time), None) - if existing_record: - # Update existing record's specified key - setattr(existing_record, key, values[i]) - else: + avail_record = self.db_get_record(db_target) + if avail_record is None: # Create a new DataRecord if none exists new_record = self.record_class()(date_time=date_time, **{key: values[i]}) - self.records.append(new_record) - self.sort_by_datetime() + self.db_insert_record(new_record) + else: + # Update existing record's specified key + setattr(avail_record, key, values[i]) + self.db_mark_dirty_record(avail_record) def key_to_series( self, @@ -999,17 +1000,17 @@ class DataSequence(DataBase, MutableSequence): for date_time, value in series.items(): # Ensure datetime objects are normalized - date_time = to_datetime(date_time, to_maxtime=False) if date_time else None + db_target = DatabaseTimestamp.from_datetime(to_datetime(date_time)) # Check if there's an existing record for this date_time - existing_record = next((r for r in self.records if r.date_time == date_time), None) - if existing_record: - # Update existing record's specified key - setattr(existing_record, key, value) - else: + avail_record = self.db_get_record(db_target) + if avail_record is None: # Create a new DataRecord if none exists new_record = self.record_class()(date_time=date_time, **{key: value}) - self.records.append(new_record) - self.sort_by_datetime() + self.db_insert_record(new_record) + else: + # Update existing record's specified key + setattr(avail_record, key, value) + self.db_mark_dirty_record(avail_record) def 
key_to_array( self, @@ -1019,6 +1020,8 @@ class DataSequence(DataBase, MutableSequence): interval: Optional[Duration] = None, fill_method: Optional[str] = None, dropna: Optional[bool] = True, + boundary: Literal["strict", "context"] = "context", + align_to_interval: bool = False, ) -> NDArray[Shape["*"], Any]: """Extract an array indexed by fixed time intervals from data records within an optional date range. @@ -1029,11 +1032,31 @@ class DataSequence(DataBase, MutableSequence): interval (duration, optional): The fixed time interval. Defaults to 1 hour. fill_method (str): Method to handle missing values during resampling. - 'linear': Linearly interpolate missing values (for numeric data only). + - 'time': Interpolate missing values (for numeric data only). - 'ffill': Forward fill missing values. - 'bfill': Backward fill missing values. - 'none': Defaults to 'linear' for numeric values, otherwise 'ffill'. dropna: (bool, optional): Whether to drop NAN/ None values before processing. Defaults to True. + boundary (Literal["strict", "context"]): + "strict" → only values inside [start, end) + "context" → include one value before and after for proper resampling + align_to_interval (bool): When True, snap the resample origin to the nearest + UTC epoch-aligned boundary of ``interval`` before resampling. This ensures + that bucket timestamps always fall on wall-clock-round times regardless of + when ``start_datetime`` falls: + + - 15-minute interval → buckets on :00, :15, :30, :45 + - 1-hour interval → buckets on the hour + + When False (default), the origin is ``query_start`` (or ``"start_day"`` when + no start is given), preserving the existing behaviour where buckets are + aligned to the query window rather than the clock. + + Set to True when storing compacted records back to the database so that the + resulting timestamps are predictable and human-readable. 
Leave False for + forecast or reporting queries where alignment to the exact query window is + more important than clock-round boundaries. Returns: np.ndarray: A NumPy Array of the values at the chosen frequency extracted from the @@ -1045,9 +1068,12 @@ class DataSequence(DataBase, MutableSequence): self._validate_key(key) # Validate fill method - if fill_method not in ("ffill", "bfill", "linear", "none", None): + if fill_method not in ("ffill", "bfill", "linear", "time", "none", None): raise ValueError(f"Unsupported fill method: {fill_method}") + if boundary not in ("strict", "context"): + raise ValueError(f"Unsupported boundary mode: {boundary}") + # Ensure datetime objects are normalized start_datetime = to_datetime(start_datetime, to_maxtime=False) if start_datetime else None end_datetime = to_datetime(end_datetime, to_maxtime=False) if end_datetime else None @@ -1058,80 +1084,131 @@ class DataSequence(DataBase, MutableSequence): else: resample_freq = to_duration(interval, as_string="pandas") + # Extend window for context resampling + query_start = start_datetime + query_end = end_datetime + + if boundary == "context": + # include one timestamp before and after for proper resampling + if query_start is not None: + # We have a start datetime - look for previous entry + start_timestamp = DatabaseTimestamp.from_datetime(query_start) + query_start_timestamp = self.db_previous_timestamp(start_timestamp) + if query_start_timestamp: + query_start = DatabaseTimestamp.to_datetime(query_start_timestamp) + if end_datetime is not None: + # We have a end datetime - look for next entry + end_timestamp = DatabaseTimestamp.from_datetime(query_end) + query_end_timestamp = self.db_next_timestamp(end_timestamp) + if query_end_timestamp is None: + # Ensure at least end_datetime is included (excluded by definition) + query_end = end_datetime.add(seconds=1) + else: + query_end = DatabaseTimestamp.to_datetime(query_end_timestamp).add(seconds=1) + # Load raw lists (already sorted & 
filtered) - dates, values = self.key_to_lists(key=key, dropna=dropna) + dates, values = self.key_to_lists( + key=key, start_datetime=query_start, end_datetime=query_end, dropna=dropna + ) values_len = len(values) # Bring lists into shape if values_len < 1: # No values, assume at least one value set to None - if start_datetime is not None: - dates.append(start_datetime - interval) + if query_start is not None: + dates.append(query_start - interval) else: dates.append(to_datetime(to_maxtime=False)) values.append(None) - if start_datetime is not None: + if query_start is not None: start_index = 0 while start_index < values_len: - if compare_datetimes(dates[start_index], start_datetime).ge: + if compare_datetimes(dates[start_index], query_start).ge: break start_index += 1 if start_index == 0: # No value before start # Add dummy value - dates.insert(0, start_datetime - interval) + dates.insert(0, query_start - interval) values.insert(0, values[0]) elif start_index > 1: - # Truncate all values before latest value before start_datetime + # Truncate all values before latest value before query_start dates = dates[start_index - 1 :] values = values[start_index - 1 :] - # We have a start_datetime, align to start datetime - resample_origin = start_datetime + + # Determine resample origin + if align_to_interval: + # Snap to nearest UTC epoch-aligned floor of the interval so that bucket + # timestamps land on wall-clock-round boundaries (:00, :15, :30, :45 etc.) + # regardless of sub-second jitter in query_start. + interval_sec = int(interval.total_seconds()) + if interval_sec > 0: + start_epoch = int(query_start.timestamp()) + floored_epoch = (start_epoch // interval_sec) * interval_sec + resample_origin: Union[str, pd.Timestamp] = pd.Timestamp( + floored_epoch, unit="s", tz="UTC" + ) + else: + resample_origin = query_start + else: + # Original behaviour: align to the query window start. 
+ resample_origin = query_start else: - # We do not have a start_datetime, align resample buckets to midnight of first day + # We do not have a query_start, align resample buckets to midnight of first day resample_origin = "start_day" - if end_datetime is not None: - if compare_datetimes(dates[-1], end_datetime).lt: - # Add dummy value at end_datetime - dates.append(end_datetime) + if query_end is not None: + if compare_datetimes(dates[-1], query_end).lt: + # Add dummy value at query_end + dates.append(query_end) values.append(values[-1]) # Construct series - series = pd.Series(values, index=pd.DatetimeIndex(dates), name=key) + index = pd.to_datetime(dates, utc=True) + series = pd.Series(values, index=index, name=key) if series.index.inferred_type != "datetime64": raise TypeError( f"Expected DatetimeIndex, but got {type(series.index)} " f"infered to {series.index.inferred_type}: {series}" ) + # Check for numeric values + numeric = pd.to_numeric(series.dropna(), errors="coerce") + is_numeric = numeric.notna().all() + # Determine default fill method depending on dtype if fill_method is None: - if pd.api.types.is_numeric_dtype(series): - fill_method = "linear" + if is_numeric: + fill_method = "time" else: fill_method = "ffill" # Perform the resampling - if pd.api.types.is_numeric_dtype(series): - # numeric → use mean - resampled = series.resample(interval, origin=resample_origin).mean() - else: - # non-numeric → fallback (first, last, mode, or ffill) - resampled = series.resample(interval, origin=resample_origin).first() + if is_numeric: + # Step 1: aggregate — collapses sub-interval data (e.g. 4x 15min → 1h mean). + # Produces NaN for buckets where no data existed at all. 
+ resampled = pd.to_numeric( + series.resample(resample_freq, origin=resample_origin).mean(), + errors="coerce", # ← ensures float64, not object dtype + ) - # Handle missing values after resampling - if fill_method == "linear" and pd.api.types.is_numeric_dtype(series): - resampled = resampled.interpolate("linear") - elif fill_method == "ffill": - resampled = resampled.ffill() - elif fill_method == "bfill": - resampled = resampled.bfill() - elif fill_method == "none": - pass + # Step 2: fill gaps — interpolates or fills the NaN buckets from step 1. + if fill_method in ("linear", "time"): + # Both are equivalent post-resample (equally-spaced index), + # but 'time' is kept as the label for clarity. + resampled = resampled.interpolate("time") + elif fill_method == "ffill": + resampled = resampled.ffill() + elif fill_method == "bfill": + resampled = resampled.bfill() + # fill_method == "none": leave NaNs in place else: - raise ValueError(f"Unsupported fill method: {fill_method}") + resampled = series.resample(resample_freq, origin=resample_origin).first() + if fill_method == "ffill": + resampled = resampled.ffill() + elif fill_method == "bfill": + resampled = resampled.bfill() logger.debug( "Resampled for '{}' with length {}: {}...{}", @@ -1182,11 +1259,19 @@ class DataSequence(DataBase, MutableSequence): if not self.records: return pd.DataFrame() # Return empty DataFrame if no records exist - # Use filter_by_datetime to get filtered records - filtered_records = self.filter_by_datetime(start_datetime, end_datetime) + # Ensure datetime objects are normalized + start_timestamp = ( + DatabaseTimestamp.from_datetime(start_datetime) if start_datetime else None + ) + end_timestamp = DatabaseTimestamp.from_datetime(end_datetime) if end_datetime else None # Convert filtered records to a dictionary list - data = [record.model_dump() for record in filtered_records] + data = [ + record.model_dump() + for record in self.db_iterate_records( + start_timestamp=start_timestamp, 
end_timestamp=end_timestamp + ) + ] # Convert to DataFrame df = pd.DataFrame(data) @@ -1201,74 +1286,30 @@ class DataSequence(DataBase, MutableSequence): df.index = pd.DatetimeIndex(df["date_time"]) return df - def sort_by_datetime(self, reverse: bool = False) -> None: - """Sort the DataRecords in the sequence by their date_time attribute. - - This method modifies the existing list of records in place, arranging them in order - based on the date_time attribute of each DataRecord. - - Args: - reverse (bool, optional): If True, sorts in descending order. - If False (default), sorts in ascending order. - - Raises: - TypeError: If any record's date_time attribute is None or not comparable. - """ - try: - # Use a default value (-inf or +inf) for None to make all records comparable - self.records.sort( - key=lambda record: record.date_time or pendulum.datetime(1, 1, 1, 0, 0, 0), - reverse=reverse, - ) - except TypeError as e: - # Provide a more informative error message - none_records = [i for i, record in enumerate(self.records) if record.date_time is None] - if none_records: - raise TypeError( - f"Cannot sort: {len(none_records)} record(s) have None date_time " - f"at indices {none_records}" - ) from e - raise - def delete_by_datetime( - self, start_datetime: Optional[DateTime] = None, end_datetime: Optional[DateTime] = None - ) -> None: - """Delete DataRecords from the sequence within a specified datetime range. + self, + start_datetime: Optional[DateTime] = None, + end_datetime: Optional[DateTime] = None, + ) -> int: + """Delete records in the given datetime range. - Removes records with `date_time` attributes that fall between `start_datetime` (inclusive) - and `end_datetime` (exclusive). If only `start_datetime` is provided, records from that date - onward will be removed. If only `end_datetime` is provided, records up to that date will be - removed. If none is given, no record will be deleted. 
+ Deletes records from memory and, if database storage is enabled, from the database. + Returns the maximum of in-memory and database deletions. Args: - start_datetime (datetime, optional): The start date to begin deleting records (inclusive). - end_datetime (datetime, optional): The end date to stop deleting records (exclusive). + start_datetime: Start datetime (inclusive) + end_datetime: End datetime (exclusive) - Raises: - ValueError: If both `start_datetime` and `end_datetime` are None. + Returns: + Number of records deleted (max of memory and database deletions) """ # Ensure datetime objects are normalized - start_datetime = to_datetime(start_datetime, to_maxtime=False) if start_datetime else None - end_datetime = to_datetime(end_datetime, to_maxtime=False) if end_datetime else None + start_timestamp = ( + DatabaseTimestamp.from_datetime(start_datetime) if start_datetime else None + ) + end_timestamp = DatabaseTimestamp.from_datetime(end_datetime) if end_datetime else None - # Retain records that are outside the specified range - retained_records = [] - for record in self.records: - if record.date_time is None: - continue - if ( - ( - start_datetime is not None - and compare_datetimes(record.date_time, start_datetime).lt - ) - or ( - end_datetime is not None - and compare_datetimes(record.date_time, end_datetime).ge - ) - or (start_datetime is None and end_datetime is None) - ): - retained_records.append(record) - self.records = retained_records + return self.db_delete_records(start_timestamp=start_timestamp, end_timestamp=end_timestamp) def key_delete_by_datetime( self, @@ -1294,39 +1335,49 @@ class DataSequence(DataBase, MutableSequence): KeyError: If `key` is not a valid attribute of the records. 
""" self._validate_key_writable(key) + # Ensure datetime objects are normalized - start_datetime = to_datetime(start_datetime, to_maxtime=False) if start_datetime else None - end_datetime = to_datetime(end_datetime, to_maxtime=False) if end_datetime else None + start_timestamp = ( + DatabaseTimestamp.from_datetime(start_datetime) if start_datetime else None + ) + end_timestamp = DatabaseTimestamp.from_datetime(end_datetime) if end_datetime else None - for record in self.records: - if ( - start_datetime is None or compare_datetimes(record.date_time, start_datetime).ge - ) and (end_datetime is None or compare_datetimes(record.date_time, end_datetime).lt): - del record[key] + for record in self.db_iterate_records(start_timestamp, end_timestamp): + del record[key] + self.db_mark_dirty_record(record) - def filter_by_datetime( - self, start_datetime: Optional[DateTime] = None, end_datetime: Optional[DateTime] = None - ) -> "DataSequence": - """Returns a new DataSequence object containing only records within the specified datetime range. - - Args: - start_datetime (Optional[datetime]): The start of the datetime range (inclusive). If None, no lower limit. - end_datetime (Optional[datetime]): The end of the datetime range (exclusive). If None, no upper limit. + def save(self) -> bool: + """Save data records to persistent storage. Returns: - DataSequence: A new DataSequence object with filtered records. + True in case the data records were saved, False otherwise. 
""" - # Ensure datetime objects are normalized - start_datetime = to_datetime(start_datetime, to_maxtime=False) if start_datetime else None - end_datetime = to_datetime(end_datetime, to_maxtime=False) if end_datetime else None + if not self.db_enabled: + return False - filtered_records = [ - record - for record in self.records - if (start_datetime is None or compare_datetimes(record.date_time, start_datetime).ge) - and (end_datetime is None or compare_datetimes(record.date_time, end_datetime).lt) - ] - return self.__class__(records=filtered_records) + saved = self.db_save_records() + return saved > 0 + + def load(self) -> bool: + """Load data records from from persistent storage. + + Returns: + True in case the data records were loaded, False otherwise. + """ + if not self.db_enabled: + return False + + loaded = self.db_load_records() + return loaded > 0 + + # ----------------------- DataSequence Database Protocol --------------------- + + # Required interface propagated to derived class. + # - db_keep_duration + # - db_namespace + + +# ==================== DataProvider ==================== class DataProvider(SingletonMixin, DataSequence): @@ -1374,6 +1425,10 @@ class DataProvider(SingletonMixin, DataSequence): return super().__init__(*args, **kwargs) + def db_namespace(self) -> str: + """Namespace of database.""" + return self.provider_id() + def update_data( self, force_enable: Optional[bool] = False, @@ -1392,11 +1447,11 @@ class DataProvider(SingletonMixin, DataSequence): # Call the custom update logic self._update_data(force_update=force_update) - # Assure records are sorted. - self.sort_by_datetime() + +# ==================== DataImportMixin ==================== -class DataImportMixin: +class DataImportMixin(StartMixin): """Mixin class for import of generic data. This class is designed to handle generic data provided in the form of a key-value dictionary. @@ -1417,89 +1472,7 @@ class DataImportMixin: # Attributes required but defined elsehere. 
# - start_datetime # - record_keys_writable - # - update_valu - - def import_datetimes( - self, start_datetime: DateTime, value_count: int, interval: Optional[Duration] = None - ) -> List[Tuple[DateTime, int]]: - """Generates a list of tuples containing timestamps and their corresponding value indices. - - The function accounts for daylight saving time (DST) transitions: - - During a spring forward transition (e.g., DST begins), skipped hours are omitted. - - During a fall back transition (e.g., DST ends), repeated hours are included, - but they share the same value index. - - Args: - start_datetime (DateTime): Start datetime of values - value_count (int): The number of timestamps to generate. - interval (duration, optional): The fixed time interval. Defaults to 1 hour. - - Returns: - List[Tuple[DateTime, int]]: - A list of tuples, where each tuple contains: - - A `DateTime` object representing an hourly step from `start_datetime`. - - An integer value index corresponding to the logical hour. - - Behavior: - - Skips invalid timestamps during DST spring forward transitions. - - Includes both instances of repeated timestamps during DST fall back transitions. - - Ensures the list contains exactly 'value_count' entries. - - Example: - .. 
code-block:: python - - start_datetime = pendulum.datetime(2024, 11, 3, 0, 0, tz="America/New_York") - import_datetimes(start_datetime, 5) - - [(DateTime(2024, 11, 3, 0, 0, tzinfo=Timezone('America/New_York')), 0), - (DateTime(2024, 11, 3, 1, 0, tzinfo=Timezone('America/New_York')), 1), - (DateTime(2024, 11, 3, 1, 0, tzinfo=Timezone('America/New_York')), 1), # Repeated hour - (DateTime(2024, 11, 3, 2, 0, tzinfo=Timezone('America/New_York')), 2), - (DateTime(2024, 11, 3, 3, 0, tzinfo=Timezone('America/New_York')), 3)] - - """ - timestamps_with_indices: List[Tuple[DateTime, int]] = [] - - if interval is None: - interval = to_duration("1 hour") - interval_steps_per_hour = int(3600 / interval.total_seconds()) - if interval.total_seconds() * interval_steps_per_hour != 3600: - error_msg = f"Interval {interval} does not fit into hour." - logger.error(error_msg) - raise NotImplementedError(error_msg) - - value_datetime = start_datetime - value_index = 0 - - while value_index < value_count: - i = len(timestamps_with_indices) - logger.debug(f"{i}: Insert at {value_datetime} with index {value_index}") - timestamps_with_indices.append((value_datetime, value_index)) - - next_time = value_datetime.add(seconds=interval.total_seconds()) - - # Check if there is a DST transition - if next_time.dst() != value_datetime.dst(): - if next_time.hour == value_datetime.hour: - # We jump back by 1 hour - # Repeat the value(s) (reuse value index) - for i in range(interval_steps_per_hour): - logger.debug(f"{i + 1}: Repeat at {next_time} with index {value_index}") - timestamps_with_indices.append((next_time, value_index)) - next_time = next_time.add(seconds=interval.total_seconds()) - else: - # We jump forward by 1 hour - # Drop the value(s) - logger.debug( - f"{i + 1}: Skip {interval_steps_per_hour} at {next_time} with index {value_index}" - ) - value_index += interval_steps_per_hour - - # Increment value index and value_datetime for new interval - value_index += 1 - value_datetime = next_time 
- - return timestamps_with_indices + # - update_value def import_from_dict( self, @@ -1533,7 +1506,7 @@ class DataImportMixin: raise ValueError(f"Invalid start_datetime in import data: {e}") if start_datetime is None: - start_datetime = self.ems_start_datetime # type: ignore + start_datetime = self.ems_start_datetime if "interval" in import_data: try: @@ -1541,6 +1514,14 @@ class DataImportMixin: except (ValueError, TypeError) as e: raise ValueError(f"Invalid interval in import data: {e}") + if interval is None: + interval = to_duration("1 hour") + interval_steps_per_hour = int(3600 / interval.total_seconds()) + if interval.total_seconds() * interval_steps_per_hour != 3600: + error_msg = f"Interval {interval} does not fit into hour." + logger.error(error_msg) + raise NotImplementedError(error_msg) + # Filter keys based on key_prefix and record_keys_writable valid_keys = [ key @@ -1567,20 +1548,20 @@ class DataImportMixin: f"{dict(zip(valid_keys, value_lengths))}" ) - # Generate datetime mapping once for the common length values_count = value_lengths[0] - value_datetime_mapping = self.import_datetimes( - start_datetime, values_count, interval=interval - ) # Process each valid key + start_timestamp = DatabaseTimestamp.from_datetime(start_datetime) for key in valid_keys: try: - value_list = import_data[key] + values = import_data[key] # Update values, skipping any None/NaN - for value_datetime, value_index in value_datetime_mapping: - value = value_list[value_index] + for value_index, value_db_datetime in enumerate( + self.db_generate_timestamps(start_timestamp, values_count, interval) # type: ignore[attr-defined] + ): + value = values[value_index] + value_datetime = DatabaseTimestamp.to_datetime(value_db_datetime) if value is not None and not pd.isna(value): self.update_value(value_datetime, key, value) # type: ignore @@ -1624,7 +1605,7 @@ class DataImportMixin: raise ValueError(f"Invalid datetime index in DataFrame: {e}") else: if start_datetime is None: - 
start_datetime = self.ems_start_datetime # type: ignore + start_datetime = self.ems_start_datetime has_datetime_index = False # Filter columns based on key_prefix and record_keys_writable @@ -1642,8 +1623,10 @@ class DataImportMixin: # Generate value_datetime_mapping once if not using datetime index if not has_datetime_index: - value_datetime_mapping = self.import_datetimes( - start_datetime, values_count, interval=interval + # Create values datetime list + start_timestamp = DatabaseTimestamp.from_datetime(start_datetime) + value_db_datetimes = list( + self.db_generate_timestamps(start_timestamp, values_count, interval) # type: ignore[attr-defined] ) # Process each valid column @@ -1657,9 +1640,12 @@ class DataImportMixin: if value is not None and not pd.isna(value): self.update_value(dt, column, value) # type: ignore else: - # Use the pre-generated datetime mapping - for value_datetime, value_index in value_datetime_mapping: + # Use the pre-generated datetime index + for value_index in range(values_count): value = values[value_index] + value_datetime = DatabaseTimestamp.to_datetime( + value_db_datetimes[value_index] + ) if value is not None and not pd.isna(value): self.update_value(value_datetime, column, value) # type: ignore @@ -1802,6 +1788,9 @@ class DataImportMixin: ) +# ==================== DataImportProvider ==================== + + class DataImportProvider(DataImportMixin, DataProvider): """Abstract base class for data providers that import generic data. @@ -1817,7 +1806,10 @@ class DataImportProvider(DataImportMixin, DataProvider): pass -class DataContainer(SingletonMixin, DataBase, MutableMapping): +# ==================== DataContainer ==================== + + +class DataContainer(SingletonMixin, DataABC, MutableMapping): """A container for managing multiple DataProvider instances. 
This class enables access to data from multiple data providers, supporting retrieval and @@ -1830,12 +1822,12 @@ class DataContainer(SingletonMixin, DataBase, MutableMapping): """ # To be overloaded by derived classes. - providers: List[DataProvider] = Field( + providers: list[DataProvider] = Field( default_factory=list, json_schema_extra={"description": "List of data providers"} ) @field_validator("providers", mode="after") - def check_providers(cls, value: List[DataProvider]) -> List[DataProvider]: + def check_providers(cls, value: list[DataProvider]) -> list[DataProvider]: # Check each item in the list for item in value: if not isinstance(item, DataProvider): @@ -1845,7 +1837,7 @@ class DataContainer(SingletonMixin, DataBase, MutableMapping): return value @property - def enabled_providers(self) -> List[Any]: + def enabled_providers(self) -> list[Any]: """List of providers that are currently enabled.""" enab = [] for provider in self.providers: @@ -1971,6 +1963,9 @@ class DataContainer(SingletonMixin, DataBase, MutableMapping): """ return f"{self.__class__.__name__}({self.providers})" + def keys(self) -> KeysView[str]: + return dict.fromkeys(self.record_keys).keys() + def update_data( self, force_enable: Optional[bool] = False, @@ -2039,6 +2034,7 @@ class DataContainer(SingletonMixin, DataBase, MutableMapping): end_datetime: Optional[DateTime] = None, interval: Optional[Duration] = None, fill_method: Optional[str] = None, + boundary: Optional[str] = "context", ) -> NDArray[Shape["*"], Any]: """Retrieve an array indexed by fixed time intervals for a specified key from the data in each DataProvider. 
@@ -2073,6 +2069,7 @@ class DataContainer(SingletonMixin, DataBase, MutableMapping): end_datetime=end_datetime, interval=interval, fill_method=fill_method, + boundary=boundary, ) break except KeyError: @@ -2197,3 +2194,61 @@ class DataContainer(SingletonMixin, DataBase, MutableMapping): logger.error(error_msg) raise ValueError(error_msg) return providers[provider_id] + + # ----------------------- DataContainer Database Protocol --------------------- + + def save(self) -> None: + """Save data records to persistent storage.""" + for provider in self.providers: + try: + provider.save() + except Exception as ex: + error = f"Provider {provider.provider_id()} fails on save: {ex}" + logger.error(error) + raise RuntimeError(error) + + def load(self) -> None: + """Load data records from from persistent storage.""" + for provider in self.providers: + try: + provider.load() + except Exception as ex: + error = f"Provider {provider.provider_id()} fails on load: {ex}" + logger.error(error) + raise RuntimeError(error) + + def db_vacuum(self) -> None: + """Remove old records of all providers from database to free space.""" + for provider in self.providers: + try: + provider.db_vacuum() + except Exception as ex: + error = f"Provider {provider.provider_id()} fails on db vacuum: {ex}" + logger.error(error) + raise RuntimeError(error) + + def db_compact(self) -> None: + """Apply tiered compaction to all providers to reduce storage while retaining coverage.""" + for provider in self.providers: + try: + provider.db_compact() + except Exception as ex: + error = f"Provider {provider.provider_id()} fails on db_compact: {ex}" + logger.error(error) + raise RuntimeError(error) + + def db_get_stats(self) -> dict: + """Get comprehensive statistics about database storage for all providers. 
+ + Returns: + Dictionary with statistics + """ + db_stats = {} + for provider in self.providers: + try: + db_stats[provider.db_namespace()] = provider.db_get_stats() + except Exception as ex: + error = f"Provider {provider.provider_id()} fails on db vacuum: {ex}" + logger.error(error) + raise RuntimeError(error) + return db_stats diff --git a/src/akkudoktoreos/core/database.py b/src/akkudoktoreos/core/database.py new file mode 100644 index 0000000..b806d33 --- /dev/null +++ b/src/akkudoktoreos/core/database.py @@ -0,0 +1,1178 @@ +"""Database persistence extension for data records with plugin architecture. + +Provides an abstract database interface and concrete implementations for various +backends. This version exposes first-class "namespace" support: the Database +abstract interface and concrete implementations accept an optional `namespace` +argument on methods. LMDB uses named DBIs for namespaces; SQLite emulates +namespaces with a `namespace` column. +""" + +from __future__ import annotations + +import shutil +import sqlite3 +from pathlib import Path +from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple + +import lmdb +from loguru import logger +from pydantic import Field, computed_field, field_validator + +from akkudoktoreos.config.configabc import SettingsBaseModel +from akkudoktoreos.core.coreabc import SingletonMixin +from akkudoktoreos.core.databaseabc import ( + DATABASE_METADATA_KEY, + DatabaseABC, + DatabaseBackendABC, +) + +# Valid database providers +database_providers: List[str] = ["LMDB", "SQLite"] + + +class DatabaseCommonSettings(SettingsBaseModel): + """Configuration model for database settings. + + Attributes: + provider: Optional provider identifier (e.g. "LMDB"). + max_records_in_memory: Maximum records kept in memory before auto-save. + auto_save: Whether to auto-save when threshold exceeded. + batch_size: Batch size for batch operations. 
+ """ + + provider: Optional[str] = Field( + default=None, + json_schema_extra={ + "description": "Database provider id of provider to be used.", + "examples": ["LMDB"], + }, + ) + + compression_level: int = Field( + default=9, + ge=0, + le=9, + json_schema_extra={ + "description": "Compression level for database record data.", + "examples": [0, 9], + }, + ) + + initial_load_window_h: Optional[int] = Field( + default=None, + ge=0, + json_schema_extra={ + "description": ( + "Specifies the default duration of the initial load window when " + "loading records from the database, in hours. " + "If set to None, the full available range is loaded. " + "The window is centered around the current time by default, " + "unless a different center time is specified. " + "Different database namespaces may define their own default windows." + ), + "examples": ["48", "None"], + }, + ) + + keep_duration_h: Optional[int] = Field( + default=None, + ge=0, + json_schema_extra={ + "description": ( + "Default maximum duration records shall be kept in database [hours, none].\n" + "None indicates forever. Database namespaces may have diverging definitions." + ), + "examples": [48, "none"], + }, + ) + + autosave_interval_sec: Optional[int] = Field( + default=10, + ge=5, + json_schema_extra={ + "description": ( + "Automatic saving interval [seconds].\nSet to None to disable automatic saving." + ), + "examples": [5], + }, + ) + + compaction_interval_sec: Optional[int] = Field( + default=7 * 24 * 3600, # weekly + ge=0, + json_schema_extra={ + "description": ( + "Interval in between automatic tiered compaction runs [seconds].\n" + "Compaction downsamples old records to reduce storage while retaining " + "coverage. Set to None to disable automatic compaction." 
+ ), + "examples": [604800], # 1 week + }, + ) + + batch_size: int = Field( + default=100, + json_schema_extra={ + "description": "Number of records to process in batch operations.", + "examples": [100], + }, + ) + + @computed_field # type: ignore[prop-decorator] + @property + def providers(self) -> List[str]: + """Return available database provider ids.""" + return database_providers + + @field_validator("provider", mode="after") + @classmethod + def validate_provider(cls, value: Optional[str]) -> Optional[str]: + """Validate provider is in allowed list. + + Args: + value: provider value to validate. + + Returns: + The validated provider or None. + + Raises: + ValueError: if provider is not in the allowed list. + """ + if value is None or value in database_providers: + return value + raise ValueError( + f"Provider '{value}' is not a valid database provider: {database_providers}." + ) + + +class LMDBDatabase(DatabaseBackendABC): + """LMDB implementation using named DBIs for namespaces.""" + + env: Optional[lmdb.Environment] + _dbis: Dict[Optional[str], Optional[Any]] + + def __init__( + self, + map_size: int = 10 * 1024 * 1024 * 1024, + **kwargs: Any, + ) -> None: + """Initialize LMDB backend. + + Args: + storage_path: directory to store LMDB files. + compression: whether to compress values. + compression_level: gzip compression level. + map_size: maximum LMDB map size. + """ + super().__init__() + self.map_size = map_size + self.env = None + self._dbis = {None: None} + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def provider_id(self) -> str: + """Return the unique identifier for the database provider.""" + return "LMDB" + + def open(self, namespace: Optional[str] = None) -> None: + """Open LMDB environment and optionally ensure a namespace DBI. + + Args: + namespace: Optional default namespace to open (DBI created on demand). 
+ """ + self.storage_path.mkdir(parents=True, exist_ok=True) + + self.env = lmdb.open( + str(self.storage_path), + map_size=self.map_size, + max_dbs=128, + writemap=True, + map_async=True, + metasync=False, + sync=False, + lock=True, + ) + + self.connection = self.env + self._is_open = True + self.default_namespace = namespace + + if namespace is not None: + self._ensure_dbi(namespace) + + def close(self) -> None: + """Close the LMDB environment and clear cached DBIs.""" + if self.env: + self.env.sync() + self.env.close() + self.env = None + self.connection = None + self._is_open = False + self._dbis.clear() + logger.debug("Closed LMDB at %s", self.storage_path) + + def flush(self, namespace: Optional[str] = None) -> None: + """Sync LMDB environment (writes to disk).""" + if not isinstance(self.env, lmdb.Environment): + raise ValueError(f"LMDB Environment is of wrong tpe `{type(self.env)}`.") + + with self.lock: + self.env.sync() + + # ------------------------------------------------------------------ + # Namespace helpers + # ------------------------------------------------------------------ + + def _normalize_namespace(self, namespace: Optional[str]) -> Optional[str]: + """Return explicit namespace or default if None.""" + return namespace if namespace is not None else self.default_namespace + + def _ensure_dbi(self, namespace: Optional[str]) -> Optional[Any]: + """Open and cache a DBI for the given namespace. + + Args: + namespace: Namespace name or None for the unnamed DB. + + Returns: + DBI handle (implementation specific) or None for unnamed DB. 
+ """ + if not isinstance(self.env, lmdb.Environment): + raise RuntimeError(f"LMDB Environment is of wrong tpe `{type(self.env)}`.") + + name = self._normalize_namespace(namespace) + + if name in self._dbis: + return self._dbis[name] + + if name is None: + dbi = None + else: + dbi = self.env.open_db(name.encode("utf-8"), create=True) + + self._dbis[name] = dbi + return dbi + + # ------------------------------------------------------------------ + # Metadata Operations + # ------------------------------------------------------------------ + + def set_metadata(self, metadata: Optional[bytes], *, namespace: Optional[str] = None) -> None: + """Save metadata for a given namespace. + + Metadata is treated separately from data records and stored as a single object. + + Args: + metadata (bytes): Arbitrary metadata to save or None to delete metadata. + namespace (Optional[str]): Optional namespace under which to store metadata. + """ + if not isinstance(self.env, lmdb.Environment): + raise RuntimeError(f"LMDB Environment is of wrong tpe `{type(self.env)}`.") + + dbi = self._ensure_dbi(namespace) + + with self.env.begin(write=True) as txn: + if metadata is None: + txn.delete(DATABASE_METADATA_KEY) + else: + txn.put(DATABASE_METADATA_KEY, metadata) + + def get_metadata(self, namespace: Optional[str] = None) -> Optional[bytes]: + """Load metadata for a given namespace. + + Returns None if no metadata exists. + + Args: + namespace (Optional[str]): Optional namespace whose metadata to retrieve. + + Returns: + Optional[bytes]: The loaded metadata, or None if not found. 
+ """ + if not isinstance(self.env, lmdb.Environment): + raise RuntimeError(f"LMDB Environment is of wrong tpe `{type(self.env)}`.") + + dbi = self._ensure_dbi(namespace) + + with self.env.begin(write=False) as txn: + return txn.get(DATABASE_METADATA_KEY) + + # ------------------------------------------------------------------ + # Bulk Write Operations + # ------------------------------------------------------------------ + + def save_records( + self, + records: Iterable[tuple[bytes, bytes]], + namespace: Optional[str] = None, + ) -> int: + """Save multiple records into the specified namespace (or default). + + Args: + records: Iterable providing key, value tuples ordered by key: + - key: Byte key (sortable) for the record. + - value: Serialized (and optionally compressed) bytes to store. + namespace: Optional namespace. + + Returns: + Number of records saved. + + Raises: + RuntimeError: If DB not open or write failed. + """ + if not isinstance(self.env, lmdb.Environment): + raise RuntimeError(f"LMDB Environment is of wrong tpe `{type(self.env)}`.") + + dbi = self._ensure_dbi(namespace) + + saved = 0 + with self.lock: + with self.env.begin(write=True) as txn: + for key, value in records: + if txn.put(key, value, db=dbi): + saved += 1 + + return saved + + def delete_records( + self, + keys: Iterable[bytes], + namespace: Optional[str] = None, + ) -> int: + """Delete multiple records by key from the specified namespace. + + Args: + keys: Iterable that provides the Byte keys to delete. + namespace: Optional namespace. + + Returns: + Number of records actually deleted. 
+ """ + if not isinstance(self.env, lmdb.Environment): + raise RuntimeError("Database not open") + + dbi = self._ensure_dbi(namespace) + + deleted = 0 + with self.lock: + with self.env.begin(write=True) as txn: + for key in keys: + if txn.delete(key, db=dbi): + deleted += 1 + + return deleted + + # ------------------------------------------------------------------ + # Read Operations + # ------------------------------------------------------------------ + + def iterate_records( + self, + start_key: Optional[bytes] = None, + end_key: Optional[bytes] = None, + namespace: Optional[str] = None, + reverse: bool = False, + ) -> Iterator[tuple[bytes, bytes]]: + """Iterate over records in a namespace with optional key bounds. + + The LMDB read transaction is fully closed before yielding any results, + preventing reader-slot leaks even if the caller aborts iteration early. + + Args: + start_key: Inclusive lower bound key, or None. + end_key: Exclusive upper bound key, or None. + namespace: Optional namespace to target. + reverse: If True, iterate in descending key order. + + Yields: + Tuples of (key, value). 
+ """ + if not isinstance(self.env, lmdb.Environment): + raise RuntimeError(f"LMDB Environment is of wrong type `{type(self.env)}`.") + + dbi = self._ensure_dbi(namespace) + META = DATABASE_METADATA_KEY + + results: list[tuple[bytes, bytes]] = [] + + txn = self.env.begin(write=False) + try: + cursor = txn.cursor(dbi) + + if reverse: + # --- Position cursor for reverse scan --- + + if end_key is not None: + # Jump to first key >= end_key, then step one back + if cursor.set_range(end_key): + if not cursor.prev(): + # No smaller key exists + return iter(()) + else: + if not cursor.last(): + return iter(()) + else: + if not cursor.last(): + return iter(()) + + while True: + key = cursor.key() + value = cursor.value() + + if key != META: + if start_key is None or key >= start_key: + results.append((key, value)) + else: + break + + if not cursor.prev(): + break + + else: + # --- Position cursor for forward scan --- + + if start_key is not None: + if not cursor.set_range(start_key): + return iter(()) + else: + if not cursor.first(): + return iter(()) + + while True: + key = cursor.key() + value = cursor.value() + + if end_key is not None and key >= end_key: + break + + if key != META: + results.append((key, value)) + + if not cursor.next(): + break + + finally: + # Ensure reader slot is always released + cursor.close() + txn.abort() + + # Transaction is closed here — safe to yield + return iter(results) + + # ------------------------------------------------------------------ + # Stats / Metadata + # ------------------------------------------------------------------ + + def count_records( + self, + start_key: Optional[bytes] = None, + end_key: Optional[bytes] = None, + *, + namespace: Optional[str] = None, + ) -> int: + """Count records in [start_key, end_key) excluding metadata in specified namespace. + + Excludes metadata records. 
+ """ + if not isinstance(self.env, lmdb.Environment): + raise RuntimeError(f"LMDB Environment is of wrong tpe `{type(self.env)}`.") + + dbi = self._ensure_dbi(namespace) + META = DATABASE_METADATA_KEY + + count = 0 + + with self.env.begin(write=False) as txn: + cursor = txn.cursor(db=dbi) + + # Position cursor + if start_key: + if not cursor.set_range(start_key): + return 0 + else: + if not cursor.first(): + return 0 + + while True: + key = cursor.key() + + if end_key and key >= end_key: + break + + if key != META: + count += 1 + + if not cursor.next(): + break + + return count + + def get_key_range( + self, + namespace: Optional[str] = None, + ) -> tuple[Optional[bytes], Optional[bytes]]: + """Return (min_key, max_key) in the given namespace or (None, None) if empty.""" + if not isinstance(self.env, lmdb.Environment): + raise RuntimeError(f"LMDB Environment is of wrong tpe `{type(self.env)}`.") + + dbi = self._ensure_dbi(namespace) + + with self.env.begin(write=False) as txn: + cursor = txn.cursor(db=dbi) + + if not cursor.first(): + return None, None + + min_key = cursor.key() + if min_key == DATABASE_METADATA_KEY: + if not cursor.next(): + return None, None + min_key = cursor.key() + + if not cursor.last(): + return None, None + + max_key = cursor.key() + if max_key == DATABASE_METADATA_KEY: + if not cursor.prev(): + return None, None + max_key = cursor.key() + + return min_key, max_key + + def get_backend_stats(self, namespace: Optional[str] = None) -> dict[str, Any]: + """Get LMDB backend-specific statistics.""" + if not self.env: + return {} + + dbi = self._ensure_dbi(namespace) + + with self.env.begin(write=False) as txn: + stat = txn.stat(db=dbi) + info = self.env.info() + + return { + "backend": "lmdb", + "entries": int(stat.get("entries", 0)), + "page_size": stat.get("psize"), + "depth": stat.get("depth"), + "branch_pages": stat.get("branch_pages"), + "leaf_pages": stat.get("leaf_pages"), + "overflow_pages": stat.get("overflow_pages"), + "map_size": 
info.get("map_size"), + "last_pgno": info.get("last_pgno"), + "last_txnid": info.get("last_txnid"), + "namespace": namespace or self.default_namespace, + } + + def compact(self) -> None: + """Compact LMDB by copying a compact snapshot and atomically replacing files. + + Raises: + RuntimeError: If the environment is not open. + """ + if not self.env: + raise RuntimeError("Database not open") + + logger.info("Starting LMDB compaction...") + + orig_path = Path(self.storage_path) + backup_parent = orig_path.parent + backup_dir = backup_parent / f"{orig_path.name}_compact_tmp" + final_backup_dir = backup_parent / f"{orig_path.name}_compact" + + try: + if backup_dir.exists(): + shutil.rmtree(backup_dir) + if final_backup_dir.exists(): + shutil.rmtree(final_backup_dir) + except Exception: + logger.exception("Failed to remove existing backup dirs before compaction") + + try: + backup_dir.mkdir(parents=True, exist_ok=False) + with self.lock: + self.env.copy(str(backup_dir), compact=True) + try: + self.close() + except Exception: + logger.exception( + "Failed to close LMDB environment after copy; proceeding with replacement" + ) + + try: + if orig_path.exists(): + shutil.rmtree(orig_path) + shutil.move(str(backup_dir), str(final_backup_dir)) + shutil.move(str(final_backup_dir), str(orig_path)) + except Exception as exc: + logger.exception( + "Failed to replace original LMDB files with compacted copy: %s", exc + ) + try: + if final_backup_dir.exists() and not orig_path.exists(): + shutil.move(str(final_backup_dir), str(orig_path)) + except Exception: + logger.exception("Failed to restore original LMDB after failed replacement") + raise + + try: + self.open() + except Exception: + logger.exception("Failed to re-open LMDB after compaction; DB may be closed") + raise + + logger.info("LMDB compaction completed successfully: %s", str(self.storage_path)) + finally: + try: + if backup_dir.exists(): + shutil.rmtree(backup_dir) + if final_backup_dir.exists(): + 
shutil.rmtree(final_backup_dir) + except Exception: + logger.exception("Failed to clean up temporary backup directories after compaction") + + +# ==================== SQLite Implementation ==================== + + +class SQLiteDatabase(DatabaseBackendABC): + """SQLite implementation that stores a `namespace` column to emulate namespaces.""" + + db_file: Path + conn: Optional[Any] + + def __init__(self, **kwargs: Any) -> None: + """Initialize SQLite backend.""" + super().__init__() + self.db_file = self.storage_path / "data.db" + self.conn = None + + def _ns(self, namespace: Optional[str]) -> str: + """Normalize namespace for storage ('' for None).""" + return namespace if namespace is not None else (self.default_namespace or "") + + def provider_id(self) -> str: + """Return the unique identifier for the database provider.""" + return "SQLite" + + def open(self, namespace: Optional[str] = None) -> None: + """Open SQLite connection and optionally set default namespace. + + Args: + namespace: Optional default namespace to use when operations omit namespace. 
+ """ + self.storage_path.mkdir(parents=True, exist_ok=True) + + self.conn = sqlite3.connect( + str(self.db_file), + isolation_level=None, # autocommit + check_same_thread=False, + ) + + # Create table with namespace column and composite primary key (namespace, key) + self.conn.execute( + """ + CREATE TABLE IF NOT EXISTS records ( + namespace TEXT NOT NULL DEFAULT '', + key BLOB NOT NULL, + value BLOB NOT NULL, + PRIMARY KEY (namespace, key) + ) + """ + ) + + # Index to accelerate range queries per namespace + self.conn.execute("CREATE INDEX IF NOT EXISTS idx_namespace_key ON records(namespace, key)") + + self.connection = self.conn + self._is_open = True + self.default_namespace = namespace + logger.debug("Opened SQLite at %s (default_namespace=%s)", self.db_file, namespace) + + def close(self) -> None: + """Close SQLite connection.""" + if self.conn: + self.conn.close() + self.conn = None + self.connection = None + self._is_open = False + logger.debug("Closed SQLite at %s", self.db_file) + + def flush(self, namespace: Optional[str] = None) -> None: + """Commit any pending transactions to disk (no-op if autocommit).""" + if not isinstance(self.conn, sqlite3.Connection): + raise RuntimeError(f"SQLite connection is of wrong tpe `{type(self.conn)}`.") + + with self.lock: + self.conn.commit() + + def set_metadata(self, metadata: Optional[bytes], *, namespace: Optional[str] = None) -> None: + """Save metadata for a given namespace. + + Metadata is treated separately from data records and stored as a single object. + + Args: + metadata (bytes): Arbitrary metadata to save or None to delete metadata. + namespace (Optional[str]): Optional namespace under which to store metadata. 
+ """ + if not isinstance(self.conn, sqlite3.Connection): + raise RuntimeError("Database not open") + + ns = self._ns(namespace) + + with self.conn: + # Ensure metadata table exists + self.conn.execute(""" + CREATE TABLE IF NOT EXISTS metadata ( + namespace TEXT PRIMARY KEY, + value BLOB + ) + """) + + if metadata is None: + # Delete metadata for the namespace + self.conn.execute("DELETE FROM metadata WHERE namespace=?", (ns,)) + else: + # Insert or update metadata + self.conn.execute( + """ + INSERT INTO metadata(namespace, value) + VALUES (?, ?) + ON CONFLICT(namespace) DO UPDATE SET value=excluded.value + """, + (ns, metadata), + ) + + def get_metadata(self, namespace: Optional[str] = None) -> Optional[bytes]: + """Load metadata for a given namespace. + + Returns None if no metadata exists. + + Args: + namespace (Optional[str]): Optional namespace whose metadata to retrieve. + + Returns: + Optional[bytes]: The loaded metadata, or None if not found. + """ + if not isinstance(self.conn, sqlite3.Connection): + raise RuntimeError("Database not open") + + ns = self._ns(namespace) + + # Ensure metadata table exists + with self.conn: + self.conn.execute(""" + CREATE TABLE IF NOT EXISTS metadata ( + namespace TEXT PRIMARY KEY, + value BLOB + ) + """) + row = self.conn.execute( + "SELECT value FROM metadata WHERE namespace=?", (ns,) + ).fetchone() + return row[0] if row else None + + def save_records( + self, + records: Iterable[tuple[bytes, bytes]], + namespace: Optional[str] = None, + ) -> int: + """Bulk insert or replace records. + + Returns: + Number of records written. 
+ """ + if not isinstance(self.conn, sqlite3.Connection): + raise RuntimeError("Database not open") + + ns = self._ns(namespace) + + rows = [(ns, k, v) for k, v in records] + if not rows: + return 0 + + with self.lock: + self.conn.execute("BEGIN") + self.conn.executemany( + "INSERT OR REPLACE INTO records (namespace, key, value) VALUES (?, ?, ?)", + rows, + ) + self.conn.execute("COMMIT") + + return len(rows) + + def delete_records( + self, + keys: Iterable[bytes], + namespace: Optional[str] = None, + ) -> int: + """Delete multiple records by key. + + Returns True if at least one row was deleted. + """ + if not isinstance(self.conn, sqlite3.Connection): + raise RuntimeError("Database not open") + + ns = self._ns(namespace) + + deleted: int = 0 + with self.lock: + for key in keys: + cursor = self.conn.execute( + "DELETE FROM records WHERE namespace = ? AND key = ?", + (ns, key), + ) + deleted += cursor.rowcount + + return deleted + + def iterate_records( + self, + start_key: Optional[bytes] = None, + end_key: Optional[bytes] = None, + namespace: Optional[str] = None, + reverse: bool = False, + ) -> Iterator[Tuple[bytes, bytes]]: + """Iterate records for a namespace within optional bounds. + + Snapshot-based iteration: + - Query results are materialized while holding the lock. + - Yields happen after releasing the lock. + - Metadata key is excluded. + - Range semantics: [start_key, end_key) + + Args: + start_key: Inclusive lower bound or None. + end_key: Exclusive upper bound or None. + namespace: Optional namespace. + reverse: If True iterate descending. + + Yields: + (key, value) tuples ordered by key. 
+ """ + if not isinstance(self.conn, sqlite3.Connection): + raise RuntimeError(f"SQLite connection is of wrong tpe `{type(self.conn)}`.") + + ns = self._ns(namespace) + order = "DESC" if reverse else "ASC" + + where_clauses = ["namespace = ?", "key != ?"] + params: List[Any] = [ns, DATABASE_METADATA_KEY] + + if start_key is not None: + where_clauses.append("key >= ?") + params.append(start_key) + + if end_key is not None: + where_clauses.append("key < ?") + params.append(end_key) + + where_sql = " AND ".join(where_clauses) + sql = f"SELECT key, value FROM records WHERE {where_sql} ORDER BY key {order}" # noqa: S608 + + # Snapshot rows while holding lock + with self.lock: + cursor = self.conn.execute(sql, tuple(params)) + rows = cursor.fetchall() + + # Yield after releasing lock + for k, v in rows: + yield k, v + + def count_records( + self, + start_key: Optional[bytes] = None, + end_key: Optional[bytes] = None, + *, + namespace: Optional[str] = None, + ) -> int: + """Count records in [start_key, end_key) excluding metadata.""" + if not isinstance(self.conn, sqlite3.Connection): + raise RuntimeError(f"SQLite connection is of wrong tpe `{type(self.conn)}`.") + + ns = self._ns(namespace) + + where_clauses = ["namespace = ?", "key != ?"] + params: List[Any] = [ns, DATABASE_METADATA_KEY] + + if start_key is not None: + where_clauses.append("key >= ?") + params.append(start_key) + + if end_key is not None: + where_clauses.append("key < ?") + params.append(end_key) + + where_sql = " AND ".join(where_clauses) + sql = f"SELECT COUNT(*) FROM records WHERE {where_sql}" # noqa: S608 + + with self.lock: + cursor = self.conn.execute(sql, tuple(params)) + return int(cursor.fetchone()[0]) + + def get_key_range( + self, namespace: Optional[str] = None + ) -> Tuple[Optional[bytes], Optional[bytes]]: + """Return (min_key, max_key) for the namespace or (None, None) if empty.""" + if not isinstance(self.conn, sqlite3.Connection): + raise ValueError(f"SQLite connection is of wrong tpe 
`{type(self.conn)}`.") + + ns = self._ns(namespace) + with self.lock: + cursor = self.conn.execute( + "SELECT MIN(key), MAX(key) FROM records WHERE namespace = ? and key != ?", + (ns, DATABASE_METADATA_KEY), + ) + result = cursor.fetchone() + return result[0], result[1] + + def get_backend_stats(self, namespace: Optional[str] = None) -> Dict[str, Any]: + """Return SQLite-specific stats and namespace metrics.""" + if not self.conn: + return {} + ns = self._ns(namespace) + with self.lock: + cursor = self.conn.execute( + "SELECT page_count, page_size FROM pragma_page_count(), pragma_page_size()" + ) + page_count, page_size = cursor.fetchone() + cursor = self.conn.execute("SELECT COUNT(*) FROM records WHERE namespace = ?", (ns,)) + namespace_count = int(cursor.fetchone()[0]) + return { + "backend": "sqlite", + "page_count": page_count, + "page_size": page_size, + "database_size": page_count * page_size, + "file_path": str(self.db_file), + "namespace": ns, + "namespace_count": namespace_count, + } + + def vacuum(self) -> None: + """Run SQLite VACUUM to reduce file size.""" + if not self.conn: + raise RuntimeError("Database not open") + with self.lock: + self.conn.execute("VACUUM") + logger.info("SQLite vacuum completed") + + +# ==================== Generic Database Implementation ==================== + + +class Database(DatabaseABC, SingletonMixin): + """Generic database. + + All operations accept an optional `namespace` argument. Implementations should + treat None as the default/root namespace. Concrete implementations can map + namespace -> native namespace (LMDB DBI) or emulate namespaces (SQLite uses + a namespace column). 
+ """ + + _db: Optional[DatabaseBackendABC] = None + + @classmethod + def reset_instance(cls) -> None: + """Resets the singleton instance, forcing it to be recreated on next access.""" + with cls._lock: + # Close current database backend + if cls._db: + cls._db.close() + cls._db = None + # Remove current database instance + if cls in cls._instances: + del cls._instances[cls] + logger.debug(f"{cls.__name__} singleton instance has been reset.") + + def __init__(self) -> None: + """Initialize database.""" + super().__init__() + self._db = None + + def _setup_db(self) -> None: + """Setup database.""" + provider_id = self.config.database.provider + database: Optional[DatabaseBackendABC] = None + if provider_id is None: + database = None + elif provider_id == "LMDB": + database = LMDBDatabase() + elif provider_id == "SQLite": + database = SQLiteDatabase() + else: + raise RuntimeError("Invalid database provider '{provider_id}'") + if self._db is not None: + self._db.close() + self._db = database + + def _database(self) -> DatabaseBackendABC: + """Get database.""" + provider_id = self.config.database.provider + if provider_id is None: + raise RuntimeError("Database not configured") + + if self._db is None or self._db.provider_id() != provider_id: + # No database or configuration does not match + self._setup_db() + if self._db is None: + raise RuntimeError("Database not configured") + + if not self._db.is_open: + self._db.open() + + return self._db + + def provider_id(self) -> str: + """Return the unique identifier for the database provider.""" + try: + return self._database().provider_id() + except: + return "None" + + @property + def is_open(self) -> bool: + """Return whether the database connection is open.""" + try: + return self._database().is_open + except: + return False + + @property + def storage_path(self) -> Path: + """Storage path for the database.""" + return self._database().storage_path + + @property + def compression_level(self) -> int: + """Compression 
level for database record data.""" + return self._database().compression_level + + @property + def compression(self) -> bool: + """Whether to compress stored values.""" + return self._database().compression_level > 0 + + # Lifecycle + + def open(self, namespace: Optional[str] = None) -> None: + """Open database connection and optionally set default namespace. + + Args: + namespace: Optional default namespace to prepare. + + Raises: + RuntimeError: If the database cannot be opened. + """ + self._database().open(namespace) + + def close(self) -> None: + """Close the database connection and cleanup resources.""" + self._database().close() + + def flush(self, namespace: Optional[str] = None) -> None: + """Force synchronization of pending writes to storage (optional per-namespace).""" + return self._database().flush(namespace) + + # Metadata operations + + def set_metadata(self, metadata: Optional[bytes], *, namespace: Optional[str] = None) -> None: + """Save metadata for a given namespace. + + Metadata is treated separately from data records and stored as a single object. + + Args: + metadata (bytes): Arbitrary metadata to save or None to delete metadata. + namespace (Optional[str]): Optional namespace under which to store metadata. + """ + self._database().set_metadata(metadata, namespace=namespace) + + def get_metadata(self, namespace: Optional[str] = None) -> Optional[bytes]: + """Load metadata for a given namespace. + + Returns None if no metadata exists. + + Args: + namespace (Optional[str]): Optional namespace whose metadata to retrieve. + + Returns: + Optional[bytes]: The loaded metadata, or None if not found. + """ + return self._database().get_metadata(namespace=namespace) + + # Basic record operations + + def save_records( + self, records: Iterable[tuple[bytes, bytes]], namespace: Optional[str] = None + ) -> int: + """Save multiple records into the specified namespace (or default). 
+ + Args: + records: Iterable providing key, value tuples ordered by key: + - key: Byte key (sortable) for the record. + - value: Serialized (and optionally compressed) bytes to store. + namespace: Optional namespace. + + Returns: + Number of records saved. + + Raises: + RuntimeError: If DB not open or write failed. + """ + return self._database().save_records(records, namespace) + + def delete_records(self, keys: Iterable[bytes], namespace: Optional[str] = None) -> int: + """Delete multiple records by key from the specified namespace. + + Args: + keys: Iterable that provides the Byte keys to delete. + namespace: Optional namespace. + + Returns: + Number of records actually deleted. + """ + return self._database().delete_records(keys, namespace) + + def iterate_records( + self, + start_key: Optional[bytes] = None, + end_key: Optional[bytes] = None, + namespace: Optional[str] = None, + reverse: bool = False, + ) -> Iterator[tuple[bytes, bytes]]: + """Iterate over records for a namespace with optional bounds. + + Args: + start_key: Inclusive start key, or None. + end_key: Exclusive end key, or None. + namespace: Optional namespace to target. + reverse: If True iterate in descending key order. + + Yields: + Tuples of (key, record). + """ + return self._database().iterate_records(start_key, end_key, namespace, reverse) + + def count_records( + self, + start_key: Optional[bytes] = None, + end_key: Optional[bytes] = None, + *, + namespace: Optional[str] = None, + ) -> int: + """Count records in [start_key, end_key) excluding metadata in specified namespace. + + Excludes metadata records. 
+ """ + return self._database().count_records(start_key, end_key, namespace=namespace) + + def get_key_range( + self, namespace: Optional[str] = None + ) -> Tuple[Optional[bytes], Optional[bytes]]: + """Return (min_key, max_key) in the given namespace or (None, None) if empty.""" + return self._database().get_key_range(namespace) + + def get_backend_stats(self, namespace: Optional[str] = None) -> Dict[str, Any]: + """Get backend-specific statistics; implementations may return namespace-specific data.""" + return self._database().get_backend_stats(namespace) diff --git a/src/akkudoktoreos/core/databaseabc.py b/src/akkudoktoreos/core/databaseabc.py new file mode 100644 index 0000000..b80aac7 --- /dev/null +++ b/src/akkudoktoreos/core/databaseabc.py @@ -0,0 +1,2194 @@ +"""Abstract database interface.""" + +from __future__ import annotations + +import bisect +import gzip +import pickle +from abc import ABC, abstractmethod +from enum import Enum, auto +from pathlib import Path +from threading import Lock +from typing import ( + TYPE_CHECKING, + Any, + Final, + Generic, + Iterable, + Iterator, + Literal, + Optional, + Protocol, + Self, + Type, + TypeVar, + Union, +) + +from loguru import logger +from numpydantic import NDArray, Shape + +from akkudoktoreos.core.coreabc import ( + ConfigMixin, + DatabaseMixin, + SingletonMixin, +) +from akkudoktoreos.utils.datetimeutil import ( + DateTime, + Duration, + to_datetime, + to_duration, +) + +# Key used to store metadata +DATABASE_METADATA_KEY: bytes = b"__metadata__" + +# ==================== Abstract Database Interface ==================== + + +class DatabaseABC(ABC, ConfigMixin): + """Abstract base class for database. + + All operations accept an optional `namespace` argument. Implementations should + treat None as the default/root namespace. Concrete implementations can map + namespace -> native namespace (LMDB DBI) or emulate namespaces (SQLite uses + a namespace column). 
+ """ + + @property + @abstractmethod + def is_open(self) -> bool: + """Return whether the database connection is open.""" + raise NotImplementedError + + @property + def storage_path(self) -> Path: + """Storage path for the database.""" + return self.config.general.data_folder_path / "db" / self.__class__.__name__.lower() + + @property + def compression_level(self) -> int: + """Compression level for database record data.""" + return self.config.database.compression_level + + @property + def compression(self) -> bool: + """Whether to compress stored values.""" + return self.config.database.compression_level > 0 + + # Lifecycle + + @abstractmethod + def provider_id(self) -> str: + """Return the unique identifier for the database provider. + + To be implemented by derived classes. + """ + raise NotImplementedError + + @abstractmethod + def open(self, namespace: Optional[str] = None) -> None: + """Open database connection and optionally set default namespace. + + Args: + namespace: Optional default namespace to prepare. + + Raises: + RuntimeError: If the database cannot be opened. + """ + raise NotImplementedError + + @abstractmethod + def close(self) -> None: + """Close the database connection and cleanup resources.""" + raise NotImplementedError + + @abstractmethod + def flush(self, namespace: Optional[str] = None) -> None: + """Force synchronization of pending writes to storage (optional per-namespace).""" + raise NotImplementedError + + # Metadata operations + + @abstractmethod + def set_metadata(self, metadata: Optional[bytes], *, namespace: Optional[str] = None) -> None: + """Save metadata for a given namespace. + + Metadata is treated separately from data records and stored as a single object. + + Args: + metadata (bytes): Arbitrary metadata to save or None to delete metadata. + namespace (Optional[str]): Optional namespace under which to store metadata. 
+ """ + raise NotImplementedError + + @abstractmethod + def get_metadata(self, namespace: Optional[str] = None) -> Optional[bytes]: + """Load metadata for a given namespace. + + Returns None if no metadata exists. + + Args: + namespace (Optional[str]): Optional namespace whose metadata to retrieve. + + Returns: + Optional[bytes]: The loaded metadata, or None if not found. + """ + raise NotImplementedError + + # Basic record operations + + @abstractmethod + def save_records( + self, records: Iterable[tuple[bytes, bytes]], namespace: Optional[str] = None + ) -> int: + """Save multiple records into the specified namespace (or default). + + Args: + records: Iterable providing key, value tuples ordered by key: + - key: Byte key (sortable) for the record. + - value: Serialized (and optionally compressed) bytes to store. + namespace: Optional namespace. + + Returns: + Number of records saved. + + Raises: + RuntimeError: If DB not open or write failed. + """ + raise NotImplementedError + + @abstractmethod + def delete_records(self, keys: Iterable[bytes], namespace: Optional[str] = None) -> int: + """Delete multiple records by key from the specified namespace. + + Args: + keys: Iterable that provides the Byte keys to delete. + namespace: Optional namespace. + + Returns: + Number of records actually deleted. + """ + raise NotImplementedError + + @abstractmethod + def iterate_records( + self, + start_key: Optional[bytes] = None, + end_key: Optional[bytes] = None, + namespace: Optional[str] = None, + reverse: bool = False, + ) -> Iterator[tuple[bytes, bytes]]: + """Iterate over records for a namespace with optional bounds. + + Args: + start_key: Inclusive start key, or None. + end_key: Exclusive end key, or None. + namespace: Optional namespace to target. + reverse: If True iterate in descending key order. + + Yields: + Tuples of (key, record). 
+ """ + raise NotImplementedError + + @abstractmethod + def count_records( + self, + start_key: Optional[bytes] = None, + end_key: Optional[bytes] = None, + *, + namespace: Optional[str] = None, + ) -> int: + """Count records in [start_key, end_key) excluding metadata in specified namespace. + + Excludes metadata records. + """ + raise NotImplementedError + + @abstractmethod + def get_key_range( + self, namespace: Optional[str] = None + ) -> tuple[Optional[bytes], Optional[bytes]]: + """Return (min_key, max_key) in the given namespace or (None, None) if empty.""" + raise NotImplementedError + + @abstractmethod + def get_backend_stats(self, namespace: Optional[str] = None) -> dict[str, Any]: + """Get backend-specific statistics; implementations may return namespace-specific data.""" + raise NotImplementedError + + # Compression helpers + + def serialize_data(self, data: bytes) -> bytes: + """Optionally compress raw pickled data before storage. + + Args: + data: Raw pickled bytes. + + Returns: + Possibly compressed bytes. + """ + if self.compression: + return gzip.compress(data, compresslevel=self.compression_level) + return data + + def deserialize_data(self, data: bytes) -> bytes: + """Optionally decompress stored data. + + Args: + data: Stored bytes. + + Returns: + Raw pickled bytes (decompressed if needed). + """ + if len(data) >= 2 and data[:2] == b"\x1f\x8b": + try: + return gzip.decompress(data) + except gzip.BadGzipFile: + pass + return data + + +class DatabaseBackendABC(DatabaseABC, SingletonMixin): + """Abstract base class for database backends. + + All operations accept an optional `namespace` argument. Implementations should + treat None as the default/root namespace. Concrete implementations can map + namespace -> native namespace (LMDB DBI) or emulate namespaces (SQLite uses + a namespace column). 
+ """ + + connection: Any + lock: Lock + _is_open: bool + default_namespace: Optional[str] + + def __init__(self, **kwargs: Any) -> None: + """Initialize the DatabaseBackendABC base. + + Args: + **kwargs: Backend-specific options (ignored by base). + """ + self.connection = None + self.lock = Lock() + self._is_open = False + self.default_namespace = None + + @property + def is_open(self) -> bool: + """Return whether the database connection is open.""" + return self._is_open + + +# ==================== Database Record Protocol Mixin ==================== + + +class DataRecordProtocol(Protocol): + date_time: DateTime + + def __init__(self, date_time: Any) -> None: ... + + def __getitem__(self, key: str) -> Any: ... + + def model_dump(self) -> dict: ... + + +T_Record = TypeVar("T_Record", bound=DataRecordProtocol) + + +class DatabaseTimestamp(str): + """ISO8601 UTC datetime string used as database timestamp. + + Must always be in UTC and lexicographically sortable. + + Example: + "20241027T123456[Z]" # 2024-10-27 12:34:56 + """ + + __slots__ = () + + @classmethod + def from_datetime(cls, dt: DateTime) -> "DatabaseTimestamp": + if dt.tz is None: + raise ValueError("Timezone-aware datetime required") + + return cls(dt.in_timezone("UTC").format("YYYYMMDDTHHmmss[Z]")) + + def to_datetime(self) -> DateTime: + from pendulum import parse + + return parse(self) + + +class _DatabaseTimestampUnbound(str): + """Sentinel type representing an unbounded datetime value for database usage. + + Instances of this class are designed to be totally ordered relative to + ISO datetime strings: + + - UNBOUND_START is smaller than any other value. + - UNBOUND_END is greater than any other value. + + This makes the type safe for: + - sorted lists + - bisect operations + - dictionary keys + - range queries + + The type inherits from `str` to remain maximally efficient for hashing + and dictionary usage. 
+ """ + + __slots__ = ("_is_start",) + + if TYPE_CHECKING: + _is_start: bool + + def __new__(cls, value: str, is_start: bool) -> "_DatabaseTimestampUnbound": + obj = super().__new__(cls, value) + obj._is_start = is_start + return obj + + def __lt__(self, other: object) -> bool: + if isinstance(other, _DatabaseTimestampUnbound): + return self._is_start and not other._is_start + return self._is_start + + def __le__(self, other: object) -> bool: + if isinstance(other, _DatabaseTimestampUnbound): + return self._is_start or self is other + return self._is_start + + def __gt__(self, other: object) -> bool: + if isinstance(other, _DatabaseTimestampUnbound): + return not self._is_start and other._is_start + return not self._is_start + + def __ge__(self, other: object) -> bool: + if isinstance(other, _DatabaseTimestampUnbound): + return not self._is_start or self is other + return not self._is_start + + def __repr__(self) -> str: + return "UNBOUND_START" if self._is_start else "UNBOUND_END" + + +DatabaseTimestampType = Union[DatabaseTimestamp, _DatabaseTimestampUnbound] + + +# Public sentinels +UNBOUND_START: Final[_DatabaseTimestampUnbound] = _DatabaseTimestampUnbound( + "UNBOUND_START", is_start=True +) +UNBOUND_END: Final[_DatabaseTimestampUnbound] = _DatabaseTimestampUnbound( + "UNBOUND_END", is_start=False +) + + +class _DatabaseTimeWindowUnbound: + """Sentinel representing an unbounded time window. 
+ + This is distinct from `None`: + - None → parameter not provided + - UNBOUND_WINDOW → explicitly infinite duration + + Designed to: + - be identity-compared (is) + - be hashable + - be safe for dict usage + - avoid accidental equality with other values + """ + + __slots__ = () + + def __repr__(self) -> str: + return "UNBOUND_WINDOW" + + def __reduce__(self) -> str: + # Ensures singleton behavior during pickling + return "UNBOUND_WINDOW" + + +DatabaseTimeWindowType = Union[Duration, None, _DatabaseTimeWindowUnbound] + + +UNBOUND_WINDOW: Final[_DatabaseTimeWindowUnbound] = _DatabaseTimeWindowUnbound() + + +class DatabaseRecordProtocol(Protocol, Generic[T_Record]): + # ---- derived class required interface ---- + + records: list[T_Record] + + def model_post_init(self, __context: Any) -> None: ... + + def model_copy(self, *, deep: bool = False) -> Self: ... + + # record class introspection + @classmethod + def record_class(cls) -> Type[T_Record]: ... + + # Duration for which records shall be kept in database storage + def db_keep_duration(self) -> Optional[Duration]: ... + + # namespace + def db_namespace(self) -> str: ... + + # ---- public DB interface ---- + + def _db_reset_state(self) -> None: ... + + @property + def db_enabled(self) -> bool: ... + + def db_timestamp_range(self) -> tuple[DatabaseTimestampType, DatabaseTimestampType]: ... + + def db_generate_timestamps( + self, + start_timestamp: DatabaseTimestamp, + values_count: int, + interval: Optional[Duration] = None, + ) -> Iterator[DatabaseTimestamp]: ... + + def db_get_record(self, target_timestamp: DatabaseTimestamp) -> Optional[T_Record]: ... + + def db_insert_record( + self, + record: T_Record, + *, + mark_dirty: bool = True, + ) -> None: ... + + def db_iterate_records( + self, + start_timestamp: Optional[DatabaseTimestampType] = None, + end_timestamp: Optional[DatabaseTimestampType] = None, + ) -> Iterator[T_Record]: ... 
+ + def db_load_records( + self, + start_timestamp: Optional[DatabaseTimestampType] = None, + end_timestamp: Optional[DatabaseTimestampType] = None, + ) -> int: ... + + def db_delete_records( + self, + start_timestamp: Optional[DatabaseTimestampType] = None, + end_timestamp: Optional[DatabaseTimestampType] = None, + ) -> int: ... + + # ---- dirty tracking ---- + def db_mark_dirty_record(self, record: T_Record) -> None: ... + + def db_save_records(self) -> int: ... + + # ---- autosave ---- + def db_autosave(self) -> int: ... + + # ---- Remove old records from database to free space ---- + def db_vacuum( + self, + keep_hours: Optional[int] = None, + keep_datetime: Optional[DatabaseTimestampType] = None, + ) -> int: ... + + # ---- statistics about database storage ---- + def db_count_records(self) -> int: ... + + def db_get_stats(self) -> dict: ... + + +T_DatabaseRecordProtocol = TypeVar("T_DatabaseRecordProtocol", bound="DatabaseRecordProtocol") + + +class DatabaseRecordProtocolLoadPhase(Enum): + """Database loading phases. + + NONE: + No records have been loaded from the database. + + INITIAL: + A limited initial time window has been loaded, typically centered + around a target datetime. + + FULL: + All records in the database have been loaded into memory. + + The phase controls whether further calls to ``db_ensure_loaded`` may + trigger additional database access. + """ + + NONE = auto() # nothing loaded + INITIAL = auto() # initial window loaded + FULL = auto() # fully expanded + + +class DatabaseRecordProtocolMixin( + ConfigMixin, + DatabaseMixin, + Generic[T_Record], # for typing only +): + """Database Record Protocol Mixin. + + Completely manages in memory records and database storage. + + Expects records with date_time (DatabaseTimestamp) property and the a record list + in self.records of the derived class. + + DatabaseRecordProtocolMixin expects the derived classes to be singletons. 
+ """ + + # Tell mypy these attributes exist (will be provided by subclasses) + if TYPE_CHECKING: + records: list[T_Record] + + @classmethod + def record_class(cls) -> Type[T_Record]: ... + + @property + def record_keys_writable(self) -> list[str]: ... + + def key_to_array( + self, + key: str, + start_datetime: Optional[DateTime] = None, + end_datetime: Optional[DateTime] = None, + interval: Optional[Duration] = None, + fill_method: Optional[str] = None, + dropna: Optional[bool] = True, + boundary: Literal["strict", "context"] = "context", + align_to_interval: bool = False, + ) -> NDArray[Shape["*"], Any]: ... + + # Database configuration + + def db_initial_time_window(self) -> Optional[Duration]: + """Return the initial time window used for database loading. + + This window defines the initial symmetric time span around a target datetime + that should be loaded from the database when no explicit search time window + is specified. It serves as a loading hint and may be expanded by the caller + if no records are found within the initial range. + + Subclasses may override this method to provide a domain-specific default. + + Returns: + The initial loading time window as a Duration, or ``None`` to indicate + that no initial window constraint should be applied. + """ + return None + + # ----------------------------------------------------- + # Initialization + # ----------------------------------------------------- + + def _db_ensure_initialized(self) -> None: + """Initialize DB runtime state. + + Idempotent — safe to call multiple times. 
+ """ + if not getattr(self, "_db_initialized", None): + # record datetime to record mapping for fast lookup + self._db_record_index: dict[DatabaseTimestamp, T_Record] = {} + self._db_sorted_timestamps: list[DatabaseTimestamp] = [] + + # Loading phase tracking + self._db_load_phase: DatabaseRecordProtocolLoadPhase = ( + DatabaseRecordProtocolLoadPhase.NONE + ) + # Range of timestamps the was already queried from database storage during load + self._db_loaded_range: Optional[tuple[DatabaseTimestampType, DatabaseTimestampType]] = ( + None + ) + + # Dirty tracking + # - dirty records since last save + self._db_dirty_timestamps: set[DatabaseTimestamp] = set() + # - records added since last save + self._db_new_timestamps: set[DatabaseTimestamp] = set() + # - deleted records since last save + self._db_deleted_timestamps: set[DatabaseTimestamp] = set() + + self._db_version: int = 1 + + # Storage + self._db_metadata: Optional[dict] = None + self._db_storage_initialized: bool = False + + self._db_initialized: bool = True + + if not self._db_storage_initialized and self.db_enabled: + # Metadata + existing_metadata = self._db_load_metadata() + if existing_metadata: + self._db_metadata = existing_metadata + else: + self._db_metadata = { + "version": self._db_version, + "created": to_datetime(as_string=True), + "provider_id": getattr(self, "provider_id", lambda: "unknown")(), + "compression": self.database.compression, + "backend": self.database.__class__.__name__, + } + self._db_save_metadata(self._db_metadata) + + logger.info( + f"Initialized {self.database.__class__.__name__}:{self.db_namespace()} storage at " + f"{self.database.storage_path} " + f"autosave_interval_sec={self.config.database.autosave_interval_sec})" + ) + + self._db_storage_initialized = True + + def model_post_init(self, __context: Any) -> None: + """Initialize DB state attributes immediately after Pydantic construction.""" + # Always call super() first — other mixins may also define model_post_init + 
super().model_post_init(__context) # type: ignore[misc] + self._db_ensure_initialized() + + # ----------------------------------------------------- + # Helpers + # ----------------------------------------------------- + + def _db_key_from_timestamp(self, dt: DatabaseTimestamp) -> bytes: + """Convert database timestamp to a sortable database backend key.""" + return dt.encode("utf-8") + + def _db_key_to_timestamp(self, dbkey: bytes) -> DatabaseTimestamp: + """Convert database backend key back to database timestamp.""" + return DatabaseTimestamp(dbkey.decode("utf-8")) + + def _db_timestamp_after(self, timestamp: DatabaseTimestamp) -> DatabaseTimestamp: + """Get database timestamp after this timestamp. + + A minimal time span is added to the DatabaseTimestamp to get the first possible timestamp + after DatabaseTimestamp. + """ + target = DatabaseTimestamp.to_datetime(timestamp) + db_datetime_after = DatabaseTimestamp.from_datetime(target.add(seconds=1)) + return db_datetime_after + + def db_previous_timestamp( + self, + timestamp: DatabaseTimestamp, + ) -> Optional[DatabaseTimestamp]: + """Find the largest timestamp < given timestamp. + + Search memory-first, then fallback to database if necessary. 
+ """ + self._db_ensure_initialized() + + # Step 1: Memory-first search + if self._db_sorted_timestamps: + idx = bisect.bisect_left(self._db_sorted_timestamps, timestamp) + if idx > 0: + return self._db_sorted_timestamps[idx - 1] + + # Step 2: Check if DB might contain older keys + if not self.db_enabled: + return None + + db_min_key, _ = self.database.get_key_range(self.db_namespace()) + if db_min_key is None: + return None + + db_min_ts = self._db_key_to_timestamp(db_min_key) + if timestamp <= db_min_ts: + return None + + # Step 3: Load left part of DB if not already in memory + # We want records < timestamp + start_key = None + end_key = self._db_key_from_timestamp(timestamp) + + # Only load if timestamp is out of currently loaded memory + if self._db_loaded_range: + loaded_start, _ = self._db_loaded_range + if isinstance(loaded_start, DatabaseTimestamp) and timestamp > loaded_start: + # Already partially loaded, restrict iterator to unloaded portion + start_key = self._db_key_from_timestamp(loaded_start) + + previous_ts: Optional[DatabaseTimestamp] = None + for key, _ in self.database.iterate_records( + start_key=start_key, + end_key=end_key, + namespace=self.db_namespace(), + ): + ts = self._db_key_to_timestamp(key) + if ts in self._db_deleted_timestamps: + continue + previous_ts = ts # last one before `timestamp` + + return previous_ts + + def db_next_timestamp( + self, + timestamp: DatabaseTimestamp, + ) -> Optional[DatabaseTimestamp]: + """Find the smallest timestamp > given timestamp. + + Search memory-first, then fallback to database if necessary. 
+ """ + self._db_ensure_initialized() + + # Step 1: Memory-first search + if self._db_sorted_timestamps: + idx = bisect.bisect_right(self._db_sorted_timestamps, timestamp) + if idx < len(self._db_sorted_timestamps): + return self._db_sorted_timestamps[idx] + + # Step 2: Check if DB might contain newer keys + if not self.db_enabled: + return None + + _, db_max_key = self.database.get_key_range(self.db_namespace()) + if db_max_key is None: + return None + + db_max_ts = self._db_key_to_timestamp(db_max_key) + if timestamp >= db_max_ts: + return None + + # Step 3: Search right part of DB if not already in memory + timestamp_key = self._db_key_from_timestamp(timestamp) + start_key = timestamp_key + end_key = None + + # Restrict iterator to unloaded portion if partially loaded + if self._db_loaded_range: + _, loaded_end = self._db_loaded_range + # Assumes everything < loaded_end is fully represented in memory. + if isinstance(loaded_end, DatabaseTimestamp) and timestamp < loaded_end: + start_key = self._db_key_from_timestamp(max(timestamp, loaded_end)) + + for key, _ in self.database.iterate_records( + start_key=start_key, + end_key=end_key, + namespace=self.db_namespace(), + ): + if key == timestamp_key: + # skip + continue + + ts = self._db_key_to_timestamp(key) + + # Check for deleted (only necessary for database - memory already removed + if ts in self._db_deleted_timestamps: + continue + + return ts # first valid one + + return None + + def _db_serialize_record(self, record: T_Record) -> bytes: + """Serialize a DataRecord to bytes.""" + if self.database is None: + raise ValueError("Database not defined.") + data = pickle.dumps(record.model_dump(), protocol=pickle.HIGHEST_PROTOCOL) + return self.database.serialize_data(data) + + def _db_deserialize_record(self, data: bytes) -> T_Record: + """Deserialize bytes to a DataRecord.""" + if self.database is None: + raise ValueError("Database not defined.") + data = self.database.deserialize_data(data) + record_data = 
pickle.loads(data) # noqa: S301 + return self.record_class()(**record_data) + + def _db_save_metadata(self, metadata: dict) -> None: + """Save metadata to database.""" + if not self.db_enabled: + return + + key = DATABASE_METADATA_KEY + value = pickle.dumps(metadata) + self.database.set_metadata(value, namespace=self.db_namespace()) + + def _db_load_metadata(self) -> Optional[dict]: + """Load metadata from database.""" + if not self.db_enabled: + return None + + try: + value = self.database.get_metadata(namespace=self.db_namespace()) + return pickle.loads(value) # noqa: S301 + except Exception: + logger.debug("Can not load metadata.") + return None + + def _db_reset_state(self) -> None: + self.records = [] + self._db_loaded_range = None + self._db_load_phase = DatabaseRecordProtocolLoadPhase.NONE + try: + del self._db_initialized + except: + logger.debug("_db_reset_state called on uninitialized sequence") + + def _db_clone_empty(self: T_DatabaseRecordProtocol) -> T_DatabaseRecordProtocol: + """Create an empty internal clone for database operations. + + The clone shares configuration and database access implicitly via + ConfigMixin and DatabaseMixin, but contains no in-memory records + or loaded-range state. + + Internal helper for database workflows only. + """ + clone = self.model_copy(deep=True) + clone._db_reset_state() + + return clone + + def _search_window( + self, + center_timestamp: Optional[DatabaseTimestampType], + time_window: DatabaseTimeWindowType, + ) -> tuple[DatabaseTimestampType, DatabaseTimestampType]: + """Compute a symmetric search window around a center timestamp. + + This method always returns valid database boundary values. + + Args: + center_timestamp: Center of the window. Defaults to current UTC time + if None. Must not be an unbounded timestamp sentinel. + time_window: Total width of the search window. + Half is applied on each side of center_timestamp. + - None: interpreted as unbounded. + - UNBOUND_WINDOW: interpreted as unbounded. 
+ - Duration: symmetric bounded interval. + + Returns: + A tuple (start, end) representing a half-open interval. + Always returns valid database timestamp boundaries: + either concrete timestamps or (UNBOUND_START, UNBOUND_END). + + Raises: + TypeError: If center_timestamp is an unbounded timestamp sentinel. + ValueError: If time_window is a negative Duration. + """ + # Unbounded cases → full DB range + if time_window is None or isinstance(time_window, _DatabaseTimeWindowUnbound): + return UNBOUND_START, UNBOUND_END + + if isinstance(center_timestamp, _DatabaseTimestampUnbound): + raise TypeError("center_timestamp cannot be of unbounded timestamp type.") + + # Resolve center + if center_timestamp is None: + center = to_datetime().in_timezone("UTC") + else: + center = DatabaseTimestamp.to_datetime(center_timestamp) + + duration = to_duration(time_window) + + if duration.total_seconds() < 0: + raise ValueError("time_window must be non-negative") + + # Use duration arithmetic to avoid float precision issues + half = duration / 2 + + start = center - half + end = center + half + + return ( + DatabaseTimestamp.from_datetime(start), + DatabaseTimestamp.from_datetime(end), + ) + + def _db_range_covered( + self, + start_timestamp: DatabaseTimestampType, + end_timestamp: DatabaseTimestampType, + ) -> bool: + """Return True if [start_timestamp, end_timestamp) is fully covered. + + Args: + start_timestamp: Inclusive lower boundary of the requested range. + end_timestamp: Exclusive upper boundary of the requested range. + + Returns: + True if the requested half-open interval is completely contained + within the loaded database range. + + Raises: + TypeError: If start_timestamp or end_timestamp is None. + """ + if start_timestamp is None or end_timestamp is None: + raise TypeError( + "start_timestamp and end_timestamp must not be None. " + "Use UNBOUND_START / UNBOUND_END instead." 
+ ) + + if not isinstance(start_timestamp, (str, _DatabaseTimestampUnbound)): + raise TypeError( + f"Invalid start_timestamp type: {type(start_timestamp)}. " + "Must be DatabaseTimestamp or unbound sentinel." + ) + + if not isinstance(end_timestamp, (str, _DatabaseTimestampUnbound)): + raise TypeError( + f"Invalid end_timestamp type: {type(end_timestamp)}. " + "Must be DatabaseTimestamp or unbound sentinel." + ) + + if self._db_loaded_range is None: + return False + + loaded_start, loaded_end = self._db_loaded_range + + if loaded_start is None or loaded_end is None: + return False + + return loaded_start <= start_timestamp and end_timestamp <= loaded_end + + def _db_load_initial_window( + self, + center_timestamp: Optional[DatabaseTimestampType] = None, + ) -> None: + """Load an initial time window of records from the database. + + This method establishes the first lazy-loading window when the load phase + is ``NONE``. It queries the database for records within a symmetric time + interval around ``center_timestamp`` and transitions the load phase to + ``INITIAL``. + + The loaded interval is recorded in ``self._db_loaded_range`` and represents + **database coverage**, not memory continuity. That is: + + - All database records in the half-open interval + [start_timestamp, end_timestamp) have been queried. + - Records within that interval are either loaded into memory or + confirmed absent. + - The interval does not imply that memory contains continuous records. + + The loaded range is later expanded incrementally if additional + out-of-window ranges are requested. + + If ``center_timestamp`` is not provided, the current time is used. + + Args: + center_timestamp (DatabaseTimestampType): + The central reference time for the initial loading window. + If None, the current time is used. + + Side Effects: + + * Loads records from persistent storage into memory. + * Sets ``self._db_loaded_range`` by db_load_records(). + * Sets ``self._db_load_phase`` to ``INITIAL``. 
+ + Notes: + * The loaded range uses half-open interval semantics: + [start_timestamp, end_timestamp). + * This method does not perform a full database load. + * Empty query results still establish coverage for the interval, + preventing redundant database queries. + """ + if not self.db_enabled: + return + + # Redundant guard - should only be called from load phase None + if self._db_load_phase is not DatabaseRecordProtocolLoadPhase.NONE: + raise RuntimeError( + "_db_load_initial_window() may only be called when load phase is NONE." + ) + + window_h = self.config.database.initial_load_window_h + if window_h is None: + start, end = self._search_window(center_timestamp, UNBOUND_WINDOW) + else: + window = to_duration(window_h * 3600) + start, end = self._search_window(center_timestamp, window) + + self.db_load_records(start, end) + + self._db_load_phase = DatabaseRecordProtocolLoadPhase.INITIAL + + def _db_load_full(self) -> int: + """Load all remaining records from the database into memory. + + This method performs a **full load** of the database, ensuring that all + records are present in memory. After this operation, the `_db_load_phase` + will be set to FULL, and `_db_loaded_range` will cover all known records. + + **State transitions:** + + * Allowed only from the INITIAL phase (partial window loaded) or NONE + (nothing loaded yet). + * If already FULL, the method is a no-op and returns 0. + + Returns: + int: Number of records loaded from the database during this operation. + + Raises: + RuntimeError: If called from an invalid load phase. + """ + if not self.db_enabled: + return 0 + + # Guard: must only run from NONE or INITIAL + if self._db_load_phase not in ( + DatabaseRecordProtocolLoadPhase.NONE, + DatabaseRecordProtocolLoadPhase.INITIAL, + ): + raise RuntimeError( + "_db_load_full() may only be called when load phase is NONE or INITIAL." 
+ ) + + # Perform full database load (memory is authoritative; skips duplicates) + # This also sets _db_loaded_range + loaded_count = self.db_load_records() + + # Update state + self._db_load_phase = DatabaseRecordProtocolLoadPhase.FULL + + return loaded_count + + def _extend_boundaries( + self, + start_timestamp: DatabaseTimestampType, + end_timestamp: DatabaseTimestampType, + ) -> tuple[DatabaseTimestampType, DatabaseTimestampType]: + """Find nearest database records outside requested range. + + Returns: + (new_start, new_end) timestamps to fully cover requested range including neighbors. + """ + if start_timestamp is None or end_timestamp is None: + # Make mypy happy + raise RuntimeError(f"timestamps shall be non None: {start_timestamp}, {end_timestamp}") + + new_start, new_end = start_timestamp, end_timestamp + + # Extend start + if ( + not isinstance(start_timestamp, _DatabaseTimestampUnbound) + and self._db_sorted_timestamps + and start_timestamp < self._db_sorted_timestamps[0] + ): + # There may be earlier DB records + # Reverse iterate to get nearest smaller key + for key, _ in self.database.iterate_records( + start_key=UNBOUND_START, + end_key=self._db_key_from_timestamp(start_timestamp), + namespace=self.db_namespace(), + reverse=True, + ): + ts = self._db_key_to_timestamp(key) + + if ts in self._db_deleted_timestamps: + continue + + if ts < start_timestamp: + new_start = ts + break # first valid record is the nearest + + # Extend end + if ( + not isinstance(end_timestamp, _DatabaseTimestampUnbound) + and self._db_sorted_timestamps + and end_timestamp > self._db_sorted_timestamps[-1] + ): + # There may be later DB records + for key, _ in self.database.iterate_records( + start_key=self._db_key_from_timestamp(end_timestamp), + end_key=UNBOUND_END, + namespace=self.db_namespace(), + ): + ts = self._db_key_to_timestamp(key) + + if ts in self._db_deleted_timestamps: + continue + + if ts >= end_timestamp: + new_end = ts + break # first valid record is the 
nearest + + return new_start, new_end + + def _db_ensure_loaded( + self, + start_timestamp: Optional[DatabaseTimestampType] = None, + end_timestamp: Optional[DatabaseTimestampType] = None, + *, + center_timestamp: Optional[DatabaseTimestampType] = None, + ) -> None: + """Ensure database records for a given timestamp range are available in memory. + + Lazy loading is performed in phases: NONE -> INITIAL -> FULL + + 1. **NONE**: No records loaded yet. + + * If a range is provided, load exactly that range. + * If no range, load an initial window around `center_timestamp`. + + 2. **INITIAL**: A partial window is loaded. + + * If requested range extends beyond loaded window, expand left/right as needed. + * If no range requested, escalate to FULL. + + 3. **FULL**: All records already loaded. Nothing to do. + + Args: + start_timestamp (DatabaseTimestampType): Inclusive start of desired range. + end_timestamp (DatabaseTimestampType): Exclusive end of desired range. + center_timestamp (DatabaseTimestampType): Center for initial window if nothing loaded. + + Notes: + * Only used for preparing memory for subsequent queries; does not return records. + * `center_timestamp` is ignored once an initial window has been established. 
+ """ + if not self.db_enabled: + return + + # Normalize boundaries immediately (strict DB layer rule) + if start_timestamp is None: + start_timestamp = UNBOUND_START + if end_timestamp is None: + end_timestamp = UNBOUND_END + + # Shortcut: memory already covers the extended range + if self._db_sorted_timestamps: + mem_start, mem_end = self._db_sorted_timestamps[0], self._db_sorted_timestamps[-1] + + # Case 1: bounded request + if ( + start_timestamp is not UNBOUND_START + and end_timestamp is not UNBOUND_END + and mem_start < start_timestamp + and mem_end >= end_timestamp + ): + return + + # Case 2: unbounded request only safe if FULL + if ( + self._db_load_phase is DatabaseRecordProtocolLoadPhase.FULL + and (start_timestamp is UNBOUND_START or mem_start < start_timestamp) + and (end_timestamp is UNBOUND_END or mem_end >= end_timestamp) + ): + return + + # Phase 0: NOTHING LOADED + if self._db_load_phase is DatabaseRecordProtocolLoadPhase.NONE: + if start_timestamp is UNBOUND_START and end_timestamp is UNBOUND_END: + self._db_load_initial_window(center_timestamp) + # _db_load_initial_window sets _db_loaded_range and _db_load_phase + else: + # Load the records + loaded = self.db_load_records(start_timestamp, end_timestamp) + self._db_load_phase = DatabaseRecordProtocolLoadPhase.INITIAL + return + + if center_timestamp is not None: + logger.debug( + f"Center timestamp parameter '{center_timestamp}' given outside of load phase NONE" + ) + + # Phase 1: INITIAL WINDOW (PARTIAL) + if self._db_load_phase is DatabaseRecordProtocolLoadPhase.INITIAL: + # Escalate to FULL if no range is specified + if self._db_loaded_range is None: + # Should never happen + raise RuntimeError("_db_loaded_range shall set when load phase is INITIAL") + + if self._db_range_covered(start_timestamp, end_timestamp): + return # already have it + + if start_timestamp == UNBOUND_START and end_timestamp == UNBOUND_END: + self._db_load_full() + return + + current_start, current_end = 
# ---- derived class required interface ----

def db_keep_duration(self) -> Optional[Duration]:
    """Duration for which database records should be retained.

    Used when removing old records from the database to free space.
    Defaults to the general database configuration; may be overridden by
    the derived class.

    Returns:
        Duration, or None meaning "keep forever".
    """
    duration_h: Optional[Duration] = self.config.database.keep_duration_h
    if duration_h is None:
        return None
    return to_duration(duration_h * 3600)

def db_namespace(self) -> str:
    """Namespace of this sequence within the database.

    Must be implemented by the derived class.
    """
    raise NotImplementedError

# ---- public DB interface ----

@property
def db_enabled(self) -> bool:
    """True when the database backend is open and usable."""
    return self.database.is_open

def db_timestamp_range(
    self,
) -> tuple[Optional[DatabaseTimestamp], Optional[DatabaseTimestamp]]:
    """Get the timestamp range of records in the database.

    Considers records in storage plus extra records only held in memory.
    """
    # Defensive call - model_post_init() may not have initialized metadata
    self._db_ensure_initialized()

    if self._db_sorted_timestamps:
        memory_min: Optional[DatabaseTimestamp] = self._db_sorted_timestamps[0]
        memory_max: Optional[DatabaseTimestamp] = self._db_sorted_timestamps[-1]
    else:
        memory_min = None
        memory_max = None

    if not self.db_enabled:
        return memory_min, memory_max

    db_min_key, db_max_key = self.database.get_key_range(self.db_namespace())
    if db_min_key is None or db_max_key is None:
        return memory_min, memory_max

    storage_min = self._db_key_to_timestamp(db_min_key)
    storage_max = self._db_key_to_timestamp(db_max_key)

    # Memory may reach outside the stored range on either side.
    min_timestamp = memory_min if memory_min and memory_min < storage_min else storage_min
    max_timestamp = memory_max if memory_max and memory_max > storage_max else storage_max

    return min_timestamp, max_timestamp
def db_generate_timestamps(
    self,
    start_timestamp: DatabaseTimestamp,
    values_count: int,
    interval: Optional[Duration] = None,
) -> Iterator[DatabaseTimestamp]:
    """Generate database timestamps using fixed absolute time stepping.

    Stepping is done strictly in UTC, guaranteeing constant spacing in
    seconds across daylight saving transitions. Yielded timestamps are
    UTC-based, which avoids ambiguity during fall-back transitions and
    prevents accidental overwriting in UTC-normalized storage backends.

    Args:
        start_timestamp: Starting database timestamp.
        values_count: Number of timestamps to generate.
        interval: Fixed duration between timestamps (default: 1 hour).

    Yields:
        DatabaseTimestamp: UTC-based database timestamps.

    Raises:
        ValueError: If values_count is negative.
    """
    if values_count < 0:
        raise ValueError("values_count must be non-negative")

    step = interval if interval is not None else Duration(hours=1)
    step_seconds = int(step.total_seconds())

    cursor = DatabaseTimestamp.to_datetime(start_timestamp)
    for _ in range(values_count):
        yield DatabaseTimestamp.from_datetime(cursor)
        cursor = cursor.add(seconds=step_seconds)

def db_get_record(
    self,
    target_timestamp: DatabaseTimestamp,
    *,
    time_window: DatabaseTimeWindowType = None,
) -> Optional[T_Record]:
    """Get the record at, or nearest to, the specified timestamp.

    Search strategies:
      * None — exact match only.
      * UNBOUND_WINDOW — nearest record across all stored records.
      * Duration — nearest record within a symmetric window of this total
        width around target_timestamp.

    Args:
        target_timestamp: The timestamp to search for.
        time_window: Controls the search strategy (see above).

    Returns:
        The exact match, the nearest record within the window, or None.
    """
    self._db_ensure_initialized()

    if time_window is None:
        # Exact match only — load the minimal range containing this point.
        self._db_ensure_loaded(
            target_timestamp,
            self._db_timestamp_after(target_timestamp),
            center_timestamp=target_timestamp,
        )
        return self._db_record_index.get(target_timestamp, None)

    # Load the relevant range; an unbounded window escalates to FULL.
    search_start, search_end = self._search_window(target_timestamp, time_window)
    self._db_ensure_loaded(search_start, search_end, center_timestamp=target_timestamp)

    # Exact match first (covers all strategies once loaded).
    exact = self._db_record_index.get(target_timestamp, None)
    if exact is not None:
        return exact

    # Nearest-neighbour: at most two candidates around the insertion point.
    pos = bisect.bisect_left(self._db_sorted_timestamps, target_timestamp)
    candidates = []
    if pos < len(self._db_sorted_timestamps):
        candidates.append(self.records[pos])
    if pos > 0:
        candidates.append(self.records[pos - 1])
    if not candidates:
        return None

    target_dt = DatabaseTimestamp.to_datetime(target_timestamp)
    nearest = min(
        candidates,
        key=lambda r: abs((r.date_time - target_dt).total_seconds()),
    )

    # Bounded windows additionally enforce the distance constraint.
    if not isinstance(time_window, _DatabaseTimeWindowUnbound):
        half_seconds = to_duration(time_window).total_seconds() / 2
        if abs((nearest.date_time - target_dt).total_seconds()) > half_seconds:
            return None

    return nearest

def db_insert_record(
    self,
    record: T_Record,
    *,
    mark_dirty: bool = True,
) -> None:
    """Insert a record into memory, keeping the sorted index consistent.

    Raises:
        ValueError: If a record with the same (UTC-normalized) timestamp
            already exists.
    """
    # Defensive call - model_post_init() may not have initialized metadata
    self._db_ensure_initialized()

    # Ensure normalized to UTC.
    db_record_date_time = DatabaseTimestamp.from_datetime(record.date_time)

    self._db_ensure_loaded(
        start_timestamp=db_record_date_time,
        end_timestamp=db_record_date_time,
    )

    # Memory only — no duplicates allowed.
    if db_record_date_time in self._db_record_index:
        raise ValueError(f"Duplicate timestamp {record.date_time} -> {db_record_date_time}")

    # Re-inserting clears any pending tombstone.
    self._db_deleted_timestamps.discard(db_record_date_time)

    # Sorted insert into timestamps, records and index.
    pos = bisect.bisect_left(self._db_sorted_timestamps, db_record_date_time)
    self._db_sorted_timestamps.insert(pos, db_record_date_time)
    self.records.insert(pos, record)
    self._db_record_index[db_record_date_time] = record

    if mark_dirty:
        self._db_dirty_timestamps.add(db_record_date_time)
        self._db_new_timestamps.add(db_record_date_time)
# -----------------------------------------------------
# Load (range)
# -----------------------------------------------------

def db_load_records(
    self,
    start_timestamp: Optional[DatabaseTimestampType] = None,
    end_timestamp: Optional[DatabaseTimestampType] = None,
) -> int:
    """Load records from the database into memory.

    Merges database records into the in-memory records while preserving
    memory-only records, sorted order and uniqueness (memory is
    authoritative; DB records never overwrite loaded ones).

    The requested range is extended to include the first record before
    ``start_timestamp`` and the first record at or after ``end_timestamp``
    so nearest-neighbour searches need no additional DB lookups. After
    completion, ``_db_loaded_range`` reflects the total queried span.

    Args:
        start_timestamp: Load records from this timestamp (inclusive).
        end_timestamp: Load records until this timestamp (exclusive).

    Returns:
        Number of records loaded from the database.

    Note:
        record.date_time shall be DateTime or None.
    """
    # Defensive call - model_post_init() may not have initialized metadata
    self._db_ensure_initialized()

    if not self.db_enabled:
        return 0

    # Normalize boundaries immediately (strict DB layer rule).
    if start_timestamp is None:
        start_timestamp = UNBOUND_START
    if end_timestamp is None:
        end_timestamp = UNBOUND_END

    # Include the nearest neighbours outside the requested range.
    query_start, query_end = self._extend_boundaries(start_timestamp, end_timestamp)

    start_key = (
        None
        if isinstance(query_start, _DatabaseTimestampUnbound)
        else self._db_key_from_timestamp(query_start)
    )
    end_key = (
        None
        if isinstance(query_end, _DatabaseTimestampUnbound)
        else self._db_key_from_timestamp(query_end)
    )

    namespace = self.db_namespace()
    loaded_count = 0

    # DB iteration is already sorted by key.
    for db_key, value in self.database.iterate_records(
        start_key=start_key,
        end_key=end_key,
        namespace=namespace,
    ):
        if db_key == DATABASE_METADATA_KEY:
            continue

        record = self._db_deserialize_record(value)
        db_record_date_time = DatabaseTimestamp.from_datetime(record.date_time)

        # Never resurrect explicitly deleted records.
        if db_record_date_time in self._db_deleted_timestamps:
            continue

        # Memory is authoritative: skip timestamps already present.
        if db_record_date_time in self._db_record_index:
            continue

        # Sorted insert. Deliberately not via db_insert_record() — that
        # could recurse back into db_load_records().
        pos = bisect.bisect_left(self._db_sorted_timestamps, db_record_date_time)
        self._db_sorted_timestamps.insert(pos, db_record_date_time)
        self.records.insert(pos, record)
        self._db_record_index[db_record_date_time] = record

        loaded_count += 1

    # Track the span that was already queried from storage during load.
    if self._db_loaded_range is None:
        # First load — initialize.
        self._db_loaded_range = query_start, query_end
    else:
        current_start, current_end = self._db_loaded_range
        if query_start < current_start:
            current_start = query_start
        if query_end > current_end:
            current_end = query_end
        self._db_loaded_range = current_start, current_end

    return loaded_count

# -----------------------------------------------------
# Delete (range)
# -----------------------------------------------------

def db_delete_records(
    self,
    start_timestamp: Optional[DatabaseTimestampType] = None,
    end_timestamp: Optional[DatabaseTimestampType] = None,
) -> int:
    """Delete records in the half-open range from memory and tombstone
    them for physical deletion on the next save.

    Returns:
        Number of records deleted.
    """
    # Defensive call - model_post_init() may not have initialized metadata
    self._db_ensure_initialized()

    # Deletion is global — ensure we see everything in range.
    self._db_ensure_loaded(
        start_timestamp=start_timestamp,
        end_timestamp=end_timestamp,
    )

    doomed = [
        dt
        for dt in list(self._db_sorted_timestamps)
        if not (start_timestamp and dt < start_timestamp)
        and not (end_timestamp and dt >= end_timestamp)
    ]

    for dt in doomed:
        record = self._db_record_index.pop(dt, None)
        if record is not None:
            pos = bisect.bisect_left(self._db_sorted_timestamps, dt)
            if pos < len(self._db_sorted_timestamps) and self._db_sorted_timestamps[pos] == dt:
                self._db_sorted_timestamps.pop(pos)
            try:
                self.records.remove(record)
            except Exception as ex:
                logger.debug(f"Failed to remove record: {ex}")

        # Mark for physical deletion.
        self._db_deleted_timestamps.add(dt)

        # A pending insert for this timestamp is simply cancelled.
        self._db_dirty_timestamps.discard(dt)
        self._db_new_timestamps.discard(dt)

    return len(doomed)
# -----------------------------------------------------
# Iteration from DB (no duplicates)
# -----------------------------------------------------

def db_iterate_records(
    self,
    start_timestamp: Optional[DatabaseTimestampType] = None,
    end_timestamp: Optional[DatabaseTimestampType] = None,
) -> Iterator[T_Record]:
    """Iterate records in the requested half-open range.

    Ensures storage for the range is loaded into memory first, then
    iterates over in-memory records only.
    """
    # Defensive call - model_post_init() may not have initialized metadata
    self._db_ensure_initialized()

    # Ensure memory contains the required range.
    self._db_ensure_loaded(
        start_timestamp=start_timestamp,
        end_timestamp=end_timestamp,
    )

    for record in self.records:
        record_ts = DatabaseTimestamp.from_datetime(record.date_time)

        if start_timestamp and record_ts < start_timestamp:
            continue
        # records are sorted, so past the end boundary we can stop early
        if end_timestamp and record_ts >= end_timestamp:
            break
        if record_ts in self._db_deleted_timestamps:
            continue

        yield record

# -----------------------------------------------------
# Dirty tracking
# -----------------------------------------------------

def db_mark_dirty_record(self, record: T_Record) -> None:
    """Mark a record as modified so that the next save persists it."""
    # Defensive call - model_post_init() may not have initialized metadata
    self._db_ensure_initialized()

    self._db_dirty_timestamps.add(DatabaseTimestamp.from_datetime(record.date_time))

# -----------------------------------------------------
# Bulk save (flush dirty only)
# -----------------------------------------------------

def db_save_records(self) -> int:
    """Flush dirty and deleted records to database storage.

    Returns:
        Number of records written plus number of records deleted.
    """
    # Defensive call - model_post_init() may not have initialized metadata
    self._db_ensure_initialized()

    if not self.db_enabled:
        return 0
    if not self._db_dirty_timestamps and not self._db_deleted_timestamps:
        return 0

    namespace = self.db_namespace()

    # Safer order: saves first, deletes last.

    # --- handle inserts/updates ---
    save_items = []
    for dt in self._db_dirty_timestamps:
        record = self._db_record_index.get(dt)
        if record:
            save_items.append(
                (self._db_key_from_timestamp(dt), self._db_serialize_record(record))
            )
    saved_count = len(save_items)
    if saved_count:
        self.database.save_records(save_items, namespace=namespace)
    self._db_dirty_timestamps.clear()
    self._db_new_timestamps.clear()

    # --- handle deletions ---
    # Fixed: deleted_count was only bound inside the if-branch, so a save
    # with dirty records but no deletions raised NameError on return.
    deleted_count = 0
    if self._db_deleted_timestamps:
        delete_keys = [self._db_key_from_timestamp(dt) for dt in self._db_deleted_timestamps]
        self.database.delete_records(delete_keys, namespace=namespace)
        deleted_count = len(self._db_deleted_timestamps)
        self._db_deleted_timestamps.clear()

    return saved_count + deleted_count

def db_autosave(self) -> int:
    """Periodic save hook — delegates to db_save_records()."""
    return self.db_save_records()

def db_vacuum(
    self,
    keep_hours: Optional[int] = None,
    keep_timestamp: Optional[DatabaseTimestampType] = None,
) -> int:
    """Remove old records from the database to free space.

    Semantics:

    - keep_hours is relative to the data's max timestamp:
      cutoff = db_max - keep_hours; records with timestamp < cutoff are deleted.
    - keep_timestamp is an absolute cutoff; records with timestamp < cutoff
      are deleted (exclusive).

    Uses self.db_keep_duration() if both keep_hours and keep_timestamp are None.

    NOTE(review): the protocol declares this parameter as ``keep_datetime``;
    the name ``keep_timestamp`` is kept here for caller compatibility —
    confirm which spelling is intended.

    Args:
        keep_hours: Keep only records of the last N hours (relative to the
            data's max timestamp).
        keep_timestamp: Keep only records from this timestamp on (absolute cutoff).

    Returns:
        Number of records deleted.
    """
    # Defensive call - model_post_init() may not have initialized metadata
    self._db_ensure_initialized()

    if keep_hours is None and keep_timestamp is None:
        keep_duration = self.db_keep_duration()
        if keep_duration is None:
            # No vacuum if everything is None.
            logger.info(
                f"Vacuum requested for database '{self.db_namespace()}' but keep limit is infinite."
            )
            return 0
        # Fixed: Duration.hours is only the hours *component* (e.g. 0 for
        # exactly two days); use the total duration in whole hours instead.
        keep_hours = int(keep_duration.total_seconds() // 3600)

    if keep_hours is not None:
        _, db_max = self.db_timestamp_range()
        if db_max is None or isinstance(db_max, _DatabaseTimestampUnbound):
            return 0  # no records — nothing to delete
        if keep_hours <= 0:
            db_cutoff_timestamp: DatabaseTimestampType = UNBOUND_END
        else:
            # cutoff = first record we want to delete; everything before is
            # removed. NOTE(review): the off-by-one (keep_hours - 1) keeps
            # the boundary hour — confirm this is intended.
            datetime_max: DateTime = DatabaseTimestamp.to_datetime(db_max)
            db_cutoff_timestamp = DatabaseTimestamp.from_datetime(
                datetime_max.subtract(hours=keep_hours - 1)
            )
    elif keep_timestamp is not None:
        db_cutoff_timestamp = keep_timestamp
    else:
        raise ValueError("Must specify either keep_hours or keep_timestamp")

    # Delete, then persist the tombstones immediately.
    deleted_count = self.db_delete_records(end_timestamp=db_cutoff_timestamp)
    self.db_save_records()

    logger.info(
        f"Vacuumed {deleted_count} old records from database '{self.db_namespace()}' "
        f"(before {db_cutoff_timestamp})"
    )
    return deleted_count
+ ) + return 0 + keep_hours = keep_duration.hours + + if keep_hours is not None: + _, db_max = self.db_timestamp_range() + if db_max is None or isinstance(db_max, _DatabaseTimestampUnbound): + # No records + return 0 # nothing to delete + if keep_hours <= 0: + db_cutoff_timestamp: DatabaseTimestampType = UNBOUND_END + else: + # cutoff = first record we want to delete; everything before is removed + datetime_max: DateTime = DatabaseTimestamp.to_datetime(db_max) + db_cutoff_timestamp = DatabaseTimestamp.from_datetime( + datetime_max.subtract(hours=keep_hours - 1) + ) + elif keep_timestamp is not None: + db_cutoff_timestamp = keep_timestamp + else: + raise ValueError("Must specify either keep_hours or keep_timestamp") + + # Delete records + deleted_count = self.db_delete_records(end_timestamp=db_cutoff_timestamp) + + self.db_save_records() + + logger.info( + f"Vacuumed {deleted_count} old records from database '{self.db_namespace()}' " + f"(before {db_cutoff_timestamp})" + ) + return deleted_count + + def db_count_records(self) -> int: + """Return total logical number of records. + + Memory is authoritative. If DB is enabled but not fully loaded, + we conservatively include storage-only records. + """ + # Defensive call - model_post_init() may not have initialized metadata + self._db_ensure_initialized() + + if not self.db_enabled: + return len(self.records) + + # If fully loaded, memory is complete view + if self._db_load_phase is DatabaseRecordProtocolLoadPhase.FULL: + return len(self.records) + + storage_count = self.database.count_records(namespace=self.db_namespace()) + pending_deletes = len(self._db_deleted_timestamps) + new_count = len(self._db_new_timestamps) + + return storage_count + new_count - pending_deletes + + def db_get_stats(self) -> dict: + """Get comprehensive statistics about database storage. 
+ + Returns: + Dictionary with statistics + """ + if not self.db_enabled: + return {"enabled": False} + + ns = self.db_namespace() + + stats = { + "enabled": True, + "backend": self.database.__class__.__name__, + "path": str(self.database.storage_path), + "memory_records": len(self.records), + "compression_enabled": self.database.compression, + "keep_duration_h": self.config.database.keep_duration_h, + "autosave_interval_sec": self.config.database.autosave_interval_sec, + "total_records": self.database.count_records(namespace=ns), + } + + # Add backend-specific stats + stats.update(self.database.get_backend_stats(namespace=ns)) + + min_timestamp, max_timestamp = self.db_timestamp_range() + stats["timestamp_range"] = { + "min": str(min_timestamp), + "max": str(max_timestamp), + } + + return stats + + # ==================== Tiered Compaction ==================== + + def db_compact_tiers(self) -> list[tuple[Duration, Duration]]: + """Compaction tiers as (age_threshold, target_interval) pairs. + + Records older than age_threshold are downsampled to target_interval. + Tiers must be ordered from shortest to longest age threshold. + + Default policy: + + - older than 2 hours → 15 min resolution + - older than 14 days → 1 hour resolution + + Return empty list to disable compaction entirely. + Override in derived classes for domain-specific behaviour. + + Example override to disable: + + .. code-block python + + def db_compact_tiers(self): + return [] + + Example override for price data (already at 15 min, skip first tier): + + .. code-block python + + def db_compact_tiers(self): + return [ + (to_duration("2 weeks"), to_duration("1 hour")), + ] + + .. 
comment + """ + return [ + (to_duration("2 hours"), to_duration("15 minutes")), + (to_duration("14 days"), to_duration("1 hour")), + ] + + # ------------------------------------------------------------------ + # Compaction state helpers (stored in namespace metadata) + # ------------------------------------------------------------------ + + def _db_get_compact_state( + self, + tier_interval: Duration, + ) -> Optional[DatabaseTimestamp]: + """Load the last compaction cutoff timestamp for a given tier interval. + + Args: + tier_interval: The target interval that identifies this tier. + + Returns: + The last cutoff DatabaseTimestamp, or None if never compacted. + """ + if self._db_metadata is None: + return None + key = f"last_compact_cutoff_{int(tier_interval.total_seconds())}" + cutoff_str = self._db_metadata.get(key) + return DatabaseTimestamp(cutoff_str) if cutoff_str else None + + def _db_set_compact_state( + self, + tier_interval: Duration, + cutoff_ts: DatabaseTimestamp, + ) -> None: + """Persist the last compaction cutoff timestamp for a given tier interval. + + Args: + tier_interval: The target interval that identifies this tier. + cutoff_ts: The cutoff timestamp to store. + """ + if self._db_metadata is None: + self._db_metadata = {} + key = f"last_compact_cutoff_{int(tier_interval.total_seconds())}" + self._db_metadata[key] = str(cutoff_ts) + self._db_save_metadata(self._db_metadata) + + # ------------------------------------------------------------------ + # Single-tier worker + # ------------------------------------------------------------------ + + def _db_compact_tier( + self, + age_threshold: Duration, + target_interval: Duration, + ) -> int: + """Downsample records older than age_threshold to target_interval resolution. + + Only processes the window [last_compact_cutoff, new_cutoff) so repeated + runs are cheap. 
+ + The window boundaries are snapped to UTC epoch-aligned interval boundaries + before processing: + + - ``window_start`` is floored to the nearest interval boundary at or before + the raw start. This guarantees that the first resampled bucket always + sits on a clock-round timestamp (e.g. :00/:15/:30/:45 for 15 min) and + that consecutive runs produce gapless, non-overlapping coverage. + - ``window_end`` (the new cutoff stored in metadata) is also floored, so + the boundary stored in metadata is always interval-aligned. Records + between the floored cutoff and the raw cutoff (``newest - age_threshold``) + are left untouched and will be picked up on the next run once more data + arrives and the floored cutoff advances. + + Skips resampling entirely when the existing record count is already at or + below the number of buckets resampling would produce (sparse-data guard). + When data is sparse but timestamps are misaligned the guard is bypassed and + timestamps are snapped to interval boundaries without changing values. + + Args: + age_threshold: Records older than (newest - age_threshold) are compacted. + target_interval: Target resolution after compaction. + + Returns: + Number of original records deleted (before re-insertion of downsampled + records). Returns 0 if skipped. + """ + self._db_ensure_initialized() + + interval_sec = int(target_interval.total_seconds()) + if interval_sec <= 0: + return 0 + + # ---- Determine raw new cutoff ------------------------------------ + _, db_max = self.db_timestamp_range() + if db_max is None or isinstance(db_max, _DatabaseTimestampUnbound): + return 0 + + newest_dt = DatabaseTimestamp.to_datetime(db_max) + raw_cutoff_dt = newest_dt - age_threshold + + # Snap new_cutoff DOWN to the nearest interval boundary. + # Records in [floored_cutoff, raw_cutoff) are left alone until the next + # run — they are inside the age window but straddle an incomplete bucket. 
+ raw_cutoff_epoch = int(raw_cutoff_dt.timestamp()) + floored_cutoff_epoch = (raw_cutoff_epoch // interval_sec) * interval_sec + new_cutoff_dt = DateTime.fromtimestamp(floored_cutoff_epoch, tz="UTC") + new_cutoff_ts = DatabaseTimestamp.from_datetime(new_cutoff_dt) + + # ---- Determine window start (incremental) ------------------------ + last_cutoff_ts = self._db_get_compact_state(target_interval) + + if last_cutoff_ts is not None and last_cutoff_ts >= new_cutoff_ts: + logger.debug( + f"Namespace '{self.db_namespace()}' tier {target_interval} already " + f"compacted up to {new_cutoff_ts}, skipping." + ) + return 0 + + db_min, _ = self.db_timestamp_range() + if db_min is None or isinstance(db_min, _DatabaseTimestampUnbound): + return 0 + + # Raw window start: last cutoff or absolute db minimum + raw_window_start_ts = last_cutoff_ts if last_cutoff_ts is not None else db_min + if raw_window_start_ts >= new_cutoff_ts: + return 0 + + raw_window_start_dt = DatabaseTimestamp.to_datetime(raw_window_start_ts) + + # Snap window_start DOWN to the nearest interval boundary so the first + # resampled bucket is clock-aligned. This may pull the window slightly + # earlier than the last stored cutoff, which is safe: key_to_array with + # boundary="strict" only reads the window we pass and the re-insert step + # is idempotent for already-compacted records (they will simply be + # overwritten with the same values). 
+ raw_start_epoch = int(raw_window_start_dt.timestamp()) + floored_start_epoch = (raw_start_epoch // interval_sec) * interval_sec + window_start_dt = DateTime.fromtimestamp(floored_start_epoch, tz="UTC") + window_start_ts = DatabaseTimestamp.from_datetime(window_start_dt) + + window_end_dt = new_cutoff_dt # exclusive upper bound, already aligned + window_end_ts = new_cutoff_ts + + # ---- Sparse-data guard ------------------------------------------- + existing_count = self.database.count_records( + start_key=self._db_key_from_timestamp(window_start_ts), + end_key=self._db_key_from_timestamp(window_end_ts), + namespace=self.db_namespace(), + ) + + window_sec = int((window_end_dt - window_start_dt).total_seconds()) + # Maximum number of buckets resampling could produce (ceiling division) + resampled_count = (window_sec + interval_sec - 1) // interval_sec + + if existing_count == 0: + # Nothing in window — just advance the cutoff + self._db_set_compact_state(target_interval, new_cutoff_ts) + return 0 + + if existing_count <= resampled_count: + # Data is already sparse — check whether timestamps are aligned. + # If every record already sits on an interval boundary, nothing to do. + # If any are misaligned, snap them in place without resampling. + records_in_window = [ + r + for r in self.records + if r.date_time is not None and window_start_dt <= r.date_time < window_end_dt + ] + misaligned = [ + r for r in records_in_window if int(r.date_time.timestamp()) % interval_sec != 0 + ] + if not misaligned: + logger.debug( + f"Skipping tier {target_interval} compaction for " + f"namespace '{self.db_namespace()}': " + f"existing={existing_count} <= resampled={resampled_count} " + f"and all timestamps already aligned " + f"(window={window_start_dt}..{window_end_dt})" + ) + self._db_set_compact_state(target_interval, new_cutoff_ts) + return 0 + + # ---- Sparse but misaligned: full window rewrite ----------------- + # Delete the entire window and reinsert floor-snapped records. 
+ # Deleting first guarantees no duplicate-timestamp ValueError on + # reinsert, even when an already-aligned record sits at the same + # epoch that a misaligned record floors to. + logger.debug( + f"Rewriting sparse window in namespace '{self.db_namespace()}' " + f"tier {target_interval} (existing={existing_count}, " + f"resampled={resampled_count})" + ) + + # Build snapped buckets from ALL records in window. + # Process chronologically so the earliest record's values win when + # multiple records floor to the same bucket. + snapped_bucket: dict[int, dict[str, Any]] = {} + for r in sorted(records_in_window, key=lambda x: x.date_time): + ts_epoch = int(r.date_time.timestamp()) + snapped_epoch = (ts_epoch // interval_sec) * interval_sec + bucket = snapped_bucket.setdefault(snapped_epoch, {}) + for key in self.record_keys_writable: + if key == "date_time": + continue + try: + val = r[key] + except KeyError: + continue + if val is not None and bucket.get(key) is None: + bucket[key] = val + + # Delete entire window (aligned + misaligned) + deleted = self.db_delete_records( + start_timestamp=window_start_ts, + end_timestamp=window_end_ts, + ) + + # Reinsert one record per bucket + for snapped_epoch, values in snapped_bucket.items(): + if not values: + continue + snapped_dt = DateTime.fromtimestamp(snapped_epoch, tz="UTC") + record = self.record_class()(date_time=snapped_dt, **values) + self.db_insert_record(record, mark_dirty=True) + + self.db_save_records() + self._db_set_compact_state(target_interval, new_cutoff_ts) + logger.info( + f"Rewrote sparse window in namespace '{self.db_namespace()}' " + f"tier {target_interval}: deleted={deleted}, " + f"reinserted={len(snapped_bucket)} buckets " + f"(window={window_start_dt}..{window_end_dt})" + ) + return deleted + + # ---- Full resampling path ---------------------------------------- + # boundary="context" is used here instead of "strict" so that key_to_array + # can include one record on each side of the window for proper 
interpolation + # at the edges. The truncation inside key_to_array then clips the result + # back to [window_start_dt, window_end_dt) so no out-of-window values are + # ever written back. align_to_interval=True ensures buckets land on + # clock-round timestamps regardless of window_start_dt precision. + compacted_data: dict[str, Any] = {} + compacted_timestamps: list[DateTime] = [] + + for key in self.record_keys_writable: + if key == "date_time": + continue + try: + array = self.key_to_array( + key, + start_datetime=window_start_dt, + end_datetime=window_end_dt, + interval=target_interval, + fill_method="time", + boundary="context", + align_to_interval=True, + ) + except (KeyError, TypeError, ValueError): + continue # non-numeric or missing key — skip silently + + if len(array) == 0: + continue + + # Build the shared timestamp spine once from the first successful key. + # The spine is derived from the actual resampled index, not from + # db_generate_timestamps, so it matches exactly what key_to_array + # produced (epoch-aligned, truncated to window). 
+ if not compacted_timestamps: + raw_start_epoch_aligned = ( + int(window_start_dt.timestamp()) // interval_sec + ) * interval_sec + first_bucket_epoch = raw_start_epoch_aligned + # Advance to first bucket >= window_start_dt (truncation in key_to_array + # removes any bucket before window_start_dt) + while first_bucket_epoch < int(window_start_dt.timestamp()): + first_bucket_epoch += interval_sec + compacted_timestamps = [ + DateTime.fromtimestamp(first_bucket_epoch + i * interval_sec, tz="UTC") + for i in range(len(array)) + ] + + # Guard against length mismatch between keys + if len(array) == len(compacted_timestamps): + compacted_data[key] = array + + if not compacted_data or not compacted_timestamps: + # Nothing to write back — still advance cutoff + self._db_set_compact_state(target_interval, new_cutoff_ts) + return 0 + + # ---- Delete originals, re-insert downsampled records ------------- + deleted = self.db_delete_records( + start_timestamp=window_start_ts, + end_timestamp=window_end_ts, + ) + + for i, dt in enumerate(compacted_timestamps): + values = { + key: arr[i] + for key, arr in compacted_data.items() + if i < len(arr) and arr[i] is not None + } + if values: + record = self.record_class()(date_time=dt, **values) + self.db_insert_record(record, mark_dirty=True) + + self.db_save_records() + + # Persist the aligned new cutoff for this tier + self._db_set_compact_state(target_interval, new_cutoff_ts) + + logger.info( + f"Compacted tier {target_interval}: deleted {deleted} records in " + f"namespace '{self.db_namespace()}' " + f"(window={window_start_dt}..{window_end_dt}, " + f"reinserted={len(compacted_timestamps)})" + ) + return deleted + + # ------------------------------------------------------------------ + # Public entry point + # ------------------------------------------------------------------ + + def db_compact( + self, + compact_tiers: Optional[list[tuple[Duration, Duration]]] = None, + ) -> int: + """Apply tiered compaction policy to all records 
in this namespace. + + Tiers are processed coarsest-first (longest age threshold first) to + avoid compacting fine-grained data that an inner tier would immediately + re-compact anyway. + + Args: + compact_tiers: Override tiers for this call. If None, uses + db_compact_tiers(). Each entry is (age_threshold, target_interval), + ordered shortest to longest age threshold. + + Returns: + Total number of original records deleted across all tiers. + """ + if compact_tiers is None: + compact_tiers = self.db_compact_tiers() + + if not compact_tiers: + return 0 + + total_deleted = 0 + + # Coarsest tier first (reversed) to avoid redundant work + for age_threshold, target_interval in reversed(compact_tiers): + total_deleted += self._db_compact_tier(age_threshold, target_interval) + + return total_deleted diff --git a/src/akkudoktoreos/core/ems.py b/src/akkudoktoreos/core/ems.py index 82563b0..06412e9 100644 --- a/src/akkudoktoreos/core/ems.py +++ b/src/akkudoktoreos/core/ems.py @@ -24,7 +24,7 @@ from akkudoktoreos.optimization.genetic.geneticparams import ( ) from akkudoktoreos.optimization.genetic.geneticsolution import GeneticSolution from akkudoktoreos.optimization.optimization import OptimizationSolution -from akkudoktoreos.utils.datetimeutil import DateTime, compare_datetimes, to_datetime +from akkudoktoreos.utils.datetimeutil import DateTime, to_datetime # The executor to execute the CPU heavy energy management run executor = ThreadPoolExecutor(max_workers=1) @@ -44,6 +44,15 @@ class EnergyManagementStage(Enum): return self.value +async def ems_manage_energy() -> None: + """Repeating task for managing energy. + + This task should be executed by the server regularly + to ensure proper energy management. 
+ """ + await EnergyManagement().run() + + class EnergyManagement( SingletonMixin, ConfigMixin, PredictionMixin, AdapterMixin, PydanticBaseModel ): @@ -286,6 +295,9 @@ class EnergyManagement( error_msg = f"Adapter update failed - phase {cls._stage}: {e}\n{trace}" logger.error(error_msg) + # Remember energy run datetime. + EnergyManagement._last_run_datetime = to_datetime() + # energy management run finished cls._stage = EnergyManagementStage.IDLE @@ -346,73 +358,3 @@ class EnergyManagement( ) # Run optimization in background thread to avoid blocking event loop await loop.run_in_executor(executor, func) - - async def manage_energy(self) -> None: - """Repeating task for managing energy. - - This task should be executed by the server regularly (e.g., every 10 seconds) - to ensure proper energy management. Configuration changes to the energy management interval - will only take effect if this task is executed. - - - Initializes and runs the energy management for the first time if it has never been run - before. - - If the energy management interval is not configured or invalid (NaN), the task will not - trigger any repeated energy management runs. - - Compares the current time with the last run time and runs the energy management if the - interval has elapsed. - - Logs any exceptions that occur during the initialization or execution of the energy - management. - - Note: The task maintains the interval even if some intervals are missed. - """ - current_datetime = to_datetime() - interval = self.config.ems.interval # interval maybe changed in between - - if EnergyManagement._last_run_datetime is None: - # Never run before - try: - # Remember energy run datetime. - EnergyManagement._last_run_datetime = current_datetime - # Try to run a first energy management. May fail due to config incomplete. 
- await self.run() - except Exception as e: - trace = "".join(traceback.TracebackException.from_exception(e).format()) - message = f"EOS init: {e}\n{trace}" - logger.error(message) - return - - if interval is None or interval == float("nan"): - # No Repetition - return - - if ( - compare_datetimes(current_datetime, EnergyManagement._last_run_datetime).time_diff - < interval - ): - # Wait for next run - return - - try: - await self.run() - except Exception as e: - trace = "".join(traceback.TracebackException.from_exception(e).format()) - message = f"EOS run: {e}\n{trace}" - logger.error(message) - - # Remember the energy management run - keep on interval even if we missed some intervals - while ( - compare_datetimes(current_datetime, EnergyManagement._last_run_datetime).time_diff - >= interval - ): - EnergyManagement._last_run_datetime = EnergyManagement._last_run_datetime.add( - seconds=interval - ) - - -# Initialize the Energy Management System, it is a singleton. -ems = EnergyManagement() - - -def get_ems() -> EnergyManagement: - """Gets the EOS Energy Management System.""" - return ems diff --git a/src/akkudoktoreos/core/emsettings.py b/src/akkudoktoreos/core/emsettings.py index b22e2b6..0ac4449 100644 --- a/src/akkudoktoreos/core/emsettings.py +++ b/src/akkudoktoreos/core/emsettings.py @@ -29,10 +29,11 @@ class EnergyManagementCommonSettings(SettingsBaseModel): }, ) - interval: Optional[float] = Field( - default=None, + interval: float = Field( + default=300.0, + ge=60.0, json_schema_extra={ - "description": "Intervall in seconds between EOS energy management runs.", + "description": "Intervall between EOS energy management runs [seconds].", "examples": ["300"], }, ) diff --git a/src/akkudoktoreos/core/pydantic.py b/src/akkudoktoreos/core/pydantic.py index 781b4b6..dcce1b9 100644 --- a/src/akkudoktoreos/core/pydantic.py +++ b/src/akkudoktoreos/core/pydantic.py @@ -47,7 +47,12 @@ from pydantic import ( ) from pydantic.fields import ComputedFieldInfo, FieldInfo 
-from akkudoktoreos.utils.datetimeutil import DateTime, to_datetime, to_duration +from akkudoktoreos.utils.datetimeutil import ( + DateTime, + to_datetime, + to_duration, + to_timezone, +) # Global weakref dictionary to hold external state per model instance # Used as a workaround for PrivateAttr not working in e.g. Mixin Classes @@ -683,13 +688,8 @@ class PydanticBaseModel(PydanticModelNestedValueMixin, BaseModel): self, *args: Any, include_computed_fields: bool = True, **kwargs: Any ) -> dict[str, Any]: """Custom dump method to serialize computed fields by default.""" - result = super().model_dump(*args, **kwargs) - - if not include_computed_fields: - for computed_field_name in self.__class__.model_computed_fields: - result.pop(computed_field_name, None) - - return result + kwargs.setdefault("exclude_computed_fields", not include_computed_fields) + return super().model_dump(*args, **kwargs) def to_dict(self) -> dict: """Convert this PredictionRecord instance to a dictionary representation. 
@@ -1061,8 +1061,8 @@ class PydanticDateTimeDataFrame(PydanticBaseModel): valid_base_dtypes = {"int64", "float64", "bool", "object", "string"} def is_valid_dtype(dtype: str) -> bool: - # Allow timezone-aware or naive datetime64 - if dtype.startswith("datetime64[ns"): + # Allow timezone-aware or naive datetime64 - pandas 3.0 also has us + if dtype.startswith("datetime64[ns") or dtype.startswith("datetime64[us"): return True return dtype in valid_base_dtypes @@ -1102,7 +1102,7 @@ class PydanticDateTimeDataFrame(PydanticBaseModel): # Apply dtypes for col, dtype in self.dtypes.items(): - if dtype.startswith("datetime64[ns"): + if dtype.startswith("datetime64[ns") or dtype.startswith("datetime64[us"): df[col] = pd.to_datetime(df[col], utc=True) elif dtype in dtype_mapping.keys(): df[col] = df[col].astype(dtype_mapping[dtype]) @@ -1111,20 +1111,59 @@ class PydanticDateTimeDataFrame(PydanticBaseModel): return df + @classmethod + def _detect_data_tz(cls, df: pd.DataFrame) -> Optional[str]: + """Detect timezone of pandas data.""" + # Index first (strongest signal) + if isinstance(df.index, pd.DatetimeIndex) and df.index.tz is not None: + return str(df.index.tz) + + # Then datetime columns + for col in df.columns: + if is_datetime64_any_dtype(df[col]): + tz = getattr(df[col].dt, "tz", None) + if tz is not None: + return str(tz) + + return None + @classmethod def from_dataframe( cls, df: pd.DataFrame, tz: Optional[str] = None ) -> "PydanticDateTimeDataFrame": """Create a PydanticDateTimeDataFrame instance from a pandas DataFrame.""" - index = pd.Index([to_datetime(dt, as_string=True, in_timezone=tz) for dt in df.index]) + # resolve timezone + data_tz = cls._detect_data_tz(df) + + if tz is not None: + if data_tz and data_tz != tz: + raise ValueError(f"Timezone mismatch: tz='{tz}' but data uses '{data_tz}'") + resolved_tz = tz + else: + if data_tz: + resolved_tz = data_tz + else: + # Use local timezone + resolved_tz = to_timezone(as_string=True) + + # normalize index + index = 
pd.Index( + [to_datetime(dt, as_string=True, in_timezone=resolved_tz) for dt in df.index] + ) df.index = index + # normalize datetime columns datetime_columns = [col for col in df.columns if is_datetime64_any_dtype(df[col])] + for col in datetime_columns: + if df[col].dt.tz is None: + df[col] = df[col].dt.tz_localize(resolved_tz) + else: + df[col] = df[col].dt.tz_convert(resolved_tz) return cls( data=df.to_dict(orient="index"), dtypes={col: str(dtype) for col, dtype in df.dtypes.items()}, - tz=tz, + tz=resolved_tz, datetime_columns=datetime_columns, ) diff --git a/src/akkudoktoreos/core/version.py b/src/akkudoktoreos/core/version.py index bd445ce..059d049 100644 --- a/src/akkudoktoreos/core/version.py +++ b/src/akkudoktoreos/core/version.py @@ -2,6 +2,7 @@ import hashlib import re +from dataclasses import dataclass from fnmatch import fnmatch from pathlib import Path from typing import Optional @@ -16,14 +17,117 @@ HASH_EOS = "" # Number of digits to append to .dev to identify a development version VERSION_DEV_PRECISION = 8 +# Hashing configuration +DIR_PACKAGE_ROOT = Path(__file__).resolve().parent.parent +ALLOWED_SUFFIXES: set[str] = {".py", ".md", ".json"} +EXCLUDED_DIR_PATTERNS: set[str] = {"*_autosum", "*__pycache__", "*_generated"} +EXCLUDED_FILES: set[Path] = set() + + # ------------------------------ # Helpers for version generation # ------------------------------ -def is_excluded_dir(path: Path, excluded_dir_patterns: set[str]) -> bool: - """Check whether a directory should be excluded based on name patterns.""" - return any(fnmatch(path.name, pattern) for pattern in excluded_dir_patterns) +@dataclass +class HashConfig: + """Configuration for file hashing.""" + + paths: list[Path] + allowed_suffixes: set[str] + excluded_dir_patterns: set[str] + excluded_files: set[Path] + + def __post_init__(self) -> None: + """Validate configuration.""" + for path in self.paths: + if not path.exists(): + raise ValueError(f"Path does not exist: {path}") + + +def 
is_excluded_dir(path: Path, patterns: set[str]) -> bool: + """Check if directory matches any exclusion pattern. + + Args: + path: Directory path to check + patterns: set of glob-like patterns (e.g., {``*__pycache__``, ``*_test``}) + + Returns: + True if directory should be excluded + """ + dir_name = path.name + return any(fnmatch(dir_name, pattern) for pattern in patterns) + + +def collect_files(config: HashConfig) -> list[Path]: + """Collect all files that should be included in the hash. + + This function only collects files - it doesn't hash them. + Makes it easy to inspect what will be hashed. + + Args: + config: Hash configuration + + Returns: + Sorted list of files to be hashed + + Example: + >>> config = HashConfig( + ... paths=[Path('src')], + ... allowed_suffixes={'.py'}, + ... excluded_dir_patterns={'*__pycache__'}, + ... excluded_files=set() + ... ) + >>> files = collect_files(config) + >>> print(f"Will hash {len(files)} files") + >>> for f in files[:5]: + ... print(f" {f}") + """ + collected_files: list[Path] = [] + + for root in config.paths: + for p in sorted(root.rglob("*")): + # Skip excluded directories + if p.is_dir() and is_excluded_dir(p, config.excluded_dir_patterns): + continue + + # Skip files inside excluded directories + if any(is_excluded_dir(parent, config.excluded_dir_patterns) for parent in p.parents): + continue + + # Skip excluded files + if p.resolve() in config.excluded_files: + continue + + # Collect only allowed file types + if p.is_file() and p.suffix.lower() in config.allowed_suffixes: + collected_files.append(p.resolve()) + + return sorted(collected_files) + + +def hash_files(files: list[Path]) -> str: + """Calculate SHA256 hash of file contents. + + Args: + files: list of files to hash (order matters!) 
+ + Returns: + SHA256 hex digest + + Example: + >>> files = [Path('file1.py'), Path('file2.py')] + >>> hash_value = hash_files(files) + """ + h = hashlib.sha256() + + for file_path in files: + if not file_path.exists(): + continue + + h.update(file_path.read_bytes()) + + return h.hexdigest() def hash_tree( @@ -31,80 +135,93 @@ def hash_tree( allowed_suffixes: set[str], excluded_dir_patterns: set[str], excluded_files: Optional[set[Path]] = None, -) -> str: - """Return SHA256 hash for files under `paths`. +) -> tuple[str, list[Path]]: + """Return SHA256 hash for files under `paths` and the list of files hashed. - Restricted by suffix, excluding excluded directory patterns and excluded_files. + Args: + paths: list of root paths to hash + allowed_suffixes: set of file suffixes to include (e.g., {'.py', '.json'}) + excluded_dir_patterns: set of directory patterns to exclude + excluded_files: Optional set of specific files to exclude + + Returns: + tuple of (hash_digest, list_of_hashed_files) + + Example: + >>> hash_digest, files = hash_tree( + ... paths=[Path('src')], + ... allowed_suffixes={'.py'}, + ... excluded_dir_patterns={'*__pycache__'}, + ... 
) + >>> print(f"Hash: {hash_digest}") + >>> print(f"Based on {len(files)} files") """ - h = hashlib.sha256() - excluded_files = excluded_files or set() + config = HashConfig( + paths=paths, + allowed_suffixes=allowed_suffixes, + excluded_dir_patterns=excluded_dir_patterns, + excluded_files=excluded_files or set(), + ) - for root in paths: - if not root.exists(): - raise ValueError(f"Root path does not exist: {root}") - for p in sorted(root.rglob("*")): - # Skip excluded directories - if p.is_dir() and is_excluded_dir(p, excluded_dir_patterns): - continue + files = collect_files(config) + digest = hash_files(files) - # Skip files inside excluded directories - if any(is_excluded_dir(parent, excluded_dir_patterns) for parent in p.parents): - continue + return digest, files - # Skip excluded files - if p.resolve() in excluded_files: - continue - # Hash only allowed file types - if p.is_file() and p.suffix.lower() in allowed_suffixes: - h.update(p.read_bytes()) - - digest = h.hexdigest() - - return digest +# --------------------- +# Version hash function +# --------------------- def _version_hash() -> str: """Calculate project hash. - Only package file ins src/akkudoktoreos can be hashed to make it work also for packages. + Only package files in src/akkudoktoreos can be hashed to make it work also for packages. 
+ + Returns: + SHA256 hash of the project files """ - DIR_PACKAGE_ROOT = Path(__file__).resolve().parent.parent + if not str(DIR_PACKAGE_ROOT).endswith("src/akkudoktoreos"): + error_msg = f"DIR_PACKAGE_ROOT does not end with src/akkudoktoreos: {DIR_PACKAGE_ROOT}" + raise ValueError(error_msg) - # Allowed file suffixes to consider - ALLOWED_SUFFIXES: set[str] = {".py", ".md", ".json"} - - # Directory patterns to exclude (glob-like) - EXCLUDED_DIR_PATTERNS: set[str] = {"*_autosum", "*__pycache__", "*_generated"} - - # Files to exclude - EXCLUDED_FILES: set[Path] = set() - - # Directories whose changes shall be part of the project hash + # Configuration watched_paths = [DIR_PACKAGE_ROOT] - hash_current = hash_tree( - watched_paths, ALLOWED_SUFFIXES, EXCLUDED_DIR_PATTERNS, excluded_files=EXCLUDED_FILES + # Collect files and calculate hash + hash_digest, hashed_files = hash_tree( + watched_paths, + ALLOWED_SUFFIXES, + EXCLUDED_DIR_PATTERNS, + excluded_files=EXCLUDED_FILES, ) - return hash_current + + return hash_digest def _version_calculate() -> str: - """Compute version.""" - global HASH_EOS - HASH_EOS = _version_hash() - if VERSION_BASE.endswith("dev"): + """Calculate the full version string. 
+ + For release versions: "x.y.z" + For dev versions: "x.y.z.dev" + + Returns: + Full version string + """ + if VERSION_BASE.endswith(".dev"): # After dev only digits are allowed - convert hexdigest to digits - hash_value = int(HASH_EOS, 16) + hash_value = int(_version_hash(), 16) hash_digits = str(hash_value % (10**VERSION_DEV_PRECISION)).zfill(VERSION_DEV_PRECISION) return f"{VERSION_BASE}{hash_digits}" else: + # Release version - use base as-is return VERSION_BASE # --------------------------- # Project version information -# ---------------------------- +# --------------------------- # The version __version__ = _version_calculate() @@ -114,16 +231,13 @@ __version__ = _version_calculate() # Version info access # ------------------- - # Regular expression to split the version string into pieces VERSION_RE = re.compile( r""" ^(?P\d+\.\d+\.\d+) # x.y.z - (?:[\.\+\-] # .dev starts here - (?: - (?Pdev) # literal 'dev' - (?:(?P[A-Za-z0-9]+))? # optional - ) + (?:\. # .dev starts here + (?Pdev) # literal 'dev' + (?P[a-f0-9]+)? # optional (hex digits) )? $ """, @@ -143,7 +257,7 @@ def version() -> dict[str, Optional[str]]: .. 
code-block:: python { - "version": "0.2.0+dev.a96a65", + "version": "0.2.0.dev.a96a65", "base": "x.y.z", "dev": "dev" or None, "hash": "" or None, @@ -153,7 +267,7 @@ def version() -> dict[str, Optional[str]]: match = VERSION_RE.match(__version__) if not match: - raise ValueError(f"Invalid version format: {version}") + raise ValueError(f"Invalid version format: {__version__}") # Fixed: was 'version' info = match.groupdict() info["version"] = __version__ diff --git a/src/akkudoktoreos/devices/devices.py b/src/akkudoktoreos/devices/devices.py index b7a27c4..c0c34dd 100644 --- a/src/akkudoktoreos/devices/devices.py +++ b/src/akkudoktoreos/devices/devices.py @@ -431,8 +431,3 @@ class ResourceRegistry(SingletonMixin, ConfigMixin, PydanticBaseModel): self.history = loaded.history except Exception as e: logger.error("Can not load resource registry: {}", e) - - -def get_resource_registry() -> ResourceRegistry: - """Gets the EOS resource registry.""" - return ResourceRegistry() diff --git a/src/akkudoktoreos/devices/genetic/battery.py b/src/akkudoktoreos/devices/genetic/battery.py index c4cd7fa..88123c0 100644 --- a/src/akkudoktoreos/devices/genetic/battery.py +++ b/src/akkudoktoreos/devices/genetic/battery.py @@ -87,7 +87,7 @@ class Battery: def reset(self) -> None: """Resets the battery state to its initial values.""" self.soc_wh = (self.initial_soc_percentage / 100) * self.capacity_wh - self.soc_wh = min(max(self.soc_wh, self.min_soc_wh), self.max_soc_wh) + self.soc_wh = min(self.soc_wh, self.max_soc_wh) # Only clamp to max self.discharge_array = np.full(self.prediction_hours, 0) self.charge_array = np.full(self.prediction_hours, 0) diff --git a/src/akkudoktoreos/measurement/measurement.py b/src/akkudoktoreos/measurement/measurement.py index e696e95..dbbdca4 100644 --- a/src/akkudoktoreos/measurement/measurement.py +++ b/src/akkudoktoreos/measurement/measurement.py @@ -6,6 +6,7 @@ data records for measurements. 
The measurements can be added programmatically or imported from a file or JSON string. """ +from pathlib import Path from typing import Any, Optional import numpy as np @@ -16,12 +17,26 @@ from pydantic import Field, computed_field from akkudoktoreos.config.configabc import SettingsBaseModel from akkudoktoreos.core.coreabc import SingletonMixin from akkudoktoreos.core.dataabc import DataImportMixin, DataRecord, DataSequence -from akkudoktoreos.utils.datetimeutil import DateTime, Duration, to_duration +from akkudoktoreos.utils.datetimeutil import ( + DateTime, + Duration, + to_datetime, + to_duration, +) class MeasurementCommonSettings(SettingsBaseModel): """Measurement Configuration.""" + historic_hours: Optional[int] = Field( + default=2 * 365 * 24, + ge=0, + json_schema_extra={ + "description": "Number of hours into the past for measurement data", + "examples": [2 * 365 * 24], + }, + ) + load_emr_keys: Optional[list[str]] = Field( default=None, json_schema_extra={ @@ -94,6 +109,16 @@ class Measurement(SingletonMixin, DataImportMixin, DataSequence): return super().__init__(*args, **kwargs) + def _measurement_file_path(self) -> Optional[Path]: + """Path to measurements file (may be used optional to database).""" + try: + return self.config.general.data_folder_path / "measurement.json" + except Exception: + logger.error( + "Path for measurements is missing. Please configure data folder path or database!" + ) + return None + def _interval_count( self, start_datetime: DateTime, end_datetime: DateTime, interval: Duration ) -> int: @@ -143,30 +168,32 @@ class Measurement(SingletonMixin, DataImportMixin, DataSequence): np.ndarray: A NumPy Array of the energy [kWh] per interval values calculated from the meter readings. 
""" - # Add one interval to end_datetime to assure we have a energy value interval for all - # datetimes from start_datetime (inclusive) to end_datetime (exclusive) - end_datetime += interval size = self._interval_count(start_datetime, end_datetime, interval) energy_mr_array = self.key_to_array( - key=key, start_datetime=start_datetime, end_datetime=end_datetime, interval=interval + key=key, + start_datetime=start_datetime, + end_datetime=end_datetime + interval, + interval=interval, + fill_method="time", + boundary="context", ) - if energy_mr_array.size != size: + if energy_mr_array.size != size + 1: logging_msg = ( f"'{key}' meter reading array size: {energy_mr_array.size}" - f" does not fit to expected size: {size}, {energy_mr_array}" + f" does not fit to expected size: {size + 1}, {energy_mr_array}" ) if energy_mr_array.size != 0: logger.error(logging_msg) raise ValueError(logging_msg) logger.debug(logging_msg) - energy_array = np.zeros(size - 1) + energy_array = np.zeros(size) elif np.any(energy_mr_array == None): # 'key_to_array()' creates None values array if no data records are available. # Array contains None value -> ignore debug_msg = f"'{key}' meter reading None: {energy_mr_array}" logger.debug(debug_msg) - energy_array = np.zeros(size - 1) + energy_array = np.zeros(size) else: # Calculate load per interval debug_msg = f"'{key}' meter reading: {energy_mr_array}" @@ -193,6 +220,9 @@ class Measurement(SingletonMixin, DataImportMixin, DataSequence): np.ndarray: A NumPy Array of the total load energy [kWh] per interval values calculated from the load meter readings. 
""" + if interval is None: + interval = to_duration("1 hour") + if len(self) < 1: # No data available if start_datetime is None or end_datetime is None: @@ -200,14 +230,14 @@ class Measurement(SingletonMixin, DataImportMixin, DataSequence): else: size = self._interval_count(start_datetime, end_datetime, interval) return np.zeros(size) - if interval is None: - interval = to_duration("1 hour") + if start_datetime is None: - start_datetime = self[0].date_time + start_datetime = self.min_datetime if end_datetime is None: - end_datetime = self[-1].date_time + end_datetime = self.max_datetime.add(seconds=1) size = self._interval_count(start_datetime, end_datetime, interval) load_total_kwh_array = np.zeros(size) + # Loop through all loads if isinstance(self.config.measurement.load_emr_keys, list): for key in self.config.measurement.load_emr_keys: @@ -225,7 +255,66 @@ class Measurement(SingletonMixin, DataImportMixin, DataSequence): return load_total_kwh_array + # ----------------------- Measurement Database Protocol --------------------- -def get_measurement() -> Measurement: - """Gets the EOS measurement data.""" - return Measurement() + def db_namespace(self) -> str: + return "Measurement" + + def db_keep_datetime(self) -> Optional[DateTime]: + """Earliest datetime from which database records should be retained. + + Used when removing old records from database to free space. + + Returns: + Datetime or None. + """ + return to_datetime().subtract(hours=self.config.measurement.historic_hours) + + def save(self) -> bool: + """Save the measurements to persistent storage. + + Returns: + True in case the measurements were saved, False otherwise. 
+ """ + # Use db storage if available + saved_to_db = DataSequence.save(self) + if not saved_to_db: + measurement_file_path = self._measurement_file_path() + if measurement_file_path is None: + return False + try: + measurement_file_path.write_text( + self.model_dump_json(indent=4), + encoding="utf-8", + newline="\n", + ) + except Exception as e: + logger.exception("Cannot save measurements") + return True + + def load(self) -> bool: + """Load measurements from persistent storage. + + Returns: + True in case the measurements were loaded, False otherwise. + """ + # Use db storage if available + loaded_from_db = DataSequence.load(self) + if not loaded_from_db: + measurement_file_path = self._measurement_file_path() + if measurement_file_path is None: + return False + if not measurement_file_path.exists(): + return False + try: + # Validate into a temporary instance + loaded = self.__class__.model_validate_json( + measurement_file_path.read_text(encoding="utf-8") + ) + + # Explicitly add data records to the existing singleton + for record in loaded.records: + self.insert_by_datetime(record) + except Exception as e: + logger.exception("Cannot load measurements") + return True diff --git a/src/akkudoktoreos/optimization/genetic/geneticparams.py b/src/akkudoktoreos/optimization/genetic/geneticparams.py index 5c7ff50..f073801 100644 --- a/src/akkudoktoreos/optimization/genetic/geneticparams.py +++ b/src/akkudoktoreos/optimization/genetic/geneticparams.py @@ -18,6 +18,7 @@ from akkudoktoreos.core.coreabc import ( ConfigMixin, MeasurementMixin, PredictionMixin, + get_ems, ) from akkudoktoreos.optimization.genetic.geneticabc import GeneticParametersBaseModel from akkudoktoreos.optimization.genetic.geneticdevices import ( @@ -161,9 +162,6 @@ class GeneticOptimizationParameters( Raises: ValueError: If required configuration values like start time are missing. 
""" - # Avoid circular dependency - from akkudoktoreos.core.ems import get_ems - ems = get_ems() # The optimization paramters @@ -439,6 +437,7 @@ class GeneticOptimizationParameters( initial_soc_factor = cls.measurement.key_to_value( key=battery_config.measurement_key_soc_factor, target_datetime=ems.start_datetime, + time_window=to_duration(to_duration("48 hours")), ) if initial_soc_factor > 1.0 or initial_soc_factor < 0.0: logger.error( @@ -510,6 +509,7 @@ class GeneticOptimizationParameters( initial_soc_factor = cls.measurement.key_to_value( key=electric_vehicle_config.measurement_key_soc_factor, target_datetime=ems.start_datetime, + time_window=to_duration(to_duration("48 hours")), ) if initial_soc_factor > 1.0 or initial_soc_factor < 0.0: logger.error( diff --git a/src/akkudoktoreos/optimization/genetic/geneticsolution.py b/src/akkudoktoreos/optimization/genetic/geneticsolution.py index c558719..210317c 100644 --- a/src/akkudoktoreos/optimization/genetic/geneticsolution.py +++ b/src/akkudoktoreos/optimization/genetic/geneticsolution.py @@ -8,6 +8,8 @@ from pydantic import Field, field_validator from akkudoktoreos.core.coreabc import ( ConfigMixin, + get_ems, + get_prediction, ) from akkudoktoreos.core.emplan import ( DDBCInstruction, @@ -22,7 +24,6 @@ from akkudoktoreos.devices.devicesabc import ( from akkudoktoreos.devices.genetic.battery import Battery from akkudoktoreos.optimization.genetic.geneticdevices import GeneticParametersBaseModel from akkudoktoreos.optimization.optimization import OptimizationSolution -from akkudoktoreos.prediction.prediction import get_prediction from akkudoktoreos.utils.datetimeutil import to_datetime, to_duration from akkudoktoreos.utils.utils import NumpyEncoder @@ -272,8 +273,6 @@ class GeneticSolution(ConfigMixin, GeneticParametersBaseModel): - GRID_SUPPORT_EXPORT: ac_charge == 0 and discharge_allowed == 1 - GRID_SUPPORT_IMPORT: ac_charge > 0 and discharge_allowed == 0 or 1 """ - from akkudoktoreos.core.ems import get_ems - 
start_datetime = get_ems().start_datetime start_day_hour = start_datetime.in_timezone(self.config.general.timezone).hour interval_hours = 1 @@ -567,8 +566,6 @@ class GeneticSolution(ConfigMixin, GeneticParametersBaseModel): def energy_management_plan(self) -> EnergyManagementPlan: """Provide the genetic solution as an energy management plan.""" - from akkudoktoreos.core.ems import get_ems - start_datetime = get_ems().start_datetime start_day_hour = start_datetime.in_timezone(self.config.general.timezone).hour plan = EnergyManagementPlan( diff --git a/src/akkudoktoreos/optimization/optimization.py b/src/akkudoktoreos/optimization/optimization.py index 2540d00..c2e557a 100644 --- a/src/akkudoktoreos/optimization/optimization.py +++ b/src/akkudoktoreos/optimization/optimization.py @@ -3,6 +3,7 @@ from typing import Optional, Union from pydantic import Field, computed_field, model_validator from akkudoktoreos.config.configabc import SettingsBaseModel +from akkudoktoreos.core.coreabc import get_ems from akkudoktoreos.core.pydantic import ( PydanticBaseModel, PydanticDateTimeDataFrame, @@ -91,10 +92,14 @@ class OptimizationCommonSettings(SettingsBaseModel): @property def keys(self) -> list[str]: """The keys of the solution.""" - from akkudoktoreos.core.ems import get_ems + try: + ems_eos = get_ems() + except: + # ems might not be initialized + return [] key_list = [] - optimization_solution = get_ems().optimization_solution() + optimization_solution = ems_eos.optimization_solution() if optimization_solution: # Prepare mapping df = optimization_solution.solution.to_dataframe() diff --git a/src/akkudoktoreos/prediction/elecprice.py b/src/akkudoktoreos/prediction/elecprice.py index 457ab05..6d8576a 100644 --- a/src/akkudoktoreos/prediction/elecprice.py +++ b/src/akkudoktoreos/prediction/elecprice.py @@ -3,21 +3,28 @@ from typing import Optional from pydantic import Field, computed_field, field_validator from akkudoktoreos.config.configabc import SettingsBaseModel +from 
akkudoktoreos.core.coreabc import get_prediction from akkudoktoreos.prediction.elecpriceabc import ElecPriceProvider from akkudoktoreos.prediction.elecpriceenergycharts import ( ElecPriceEnergyChartsCommonSettings, ) from akkudoktoreos.prediction.elecpriceimport import ElecPriceImportCommonSettings -from akkudoktoreos.prediction.prediction import get_prediction -prediction_eos = get_prediction() -# Valid elecprice providers -elecprice_providers = [ - provider.provider_id() - for provider in prediction_eos.providers - if isinstance(provider, ElecPriceProvider) -] +def elecprice_provider_ids() -> list[str]: + """Valid elecprice provider ids.""" + try: + prediction_eos = get_prediction() + except: + # Prediction may not be initialized + # Return at least provider used in example + return ["ElecPriceAkkudoktor"] + + return [ + provider.provider_id() + for provider in prediction_eos.providers + if isinstance(provider, ElecPriceProvider) + ] class ElecPriceCommonSettings(SettingsBaseModel): @@ -61,14 +68,14 @@ class ElecPriceCommonSettings(SettingsBaseModel): @property def providers(self) -> list[str]: """Available electricity price provider ids.""" - return elecprice_providers + return elecprice_provider_ids() # Validators @field_validator("provider", mode="after") @classmethod def validate_provider(cls, value: Optional[str]) -> Optional[str]: - if value is None or value in elecprice_providers: + if value is None or value in elecprice_provider_ids(): return value raise ValueError( - f"Provider '{value}' is not a valid electricity price provider: {elecprice_providers}." + f"Provider '{value}' is not a valid electricity price provider: {elecprice_provider_ids()}." 
) diff --git a/src/akkudoktoreos/prediction/feedintariff.py b/src/akkudoktoreos/prediction/feedintariff.py index 9412b50..4c5ac8d 100644 --- a/src/akkudoktoreos/prediction/feedintariff.py +++ b/src/akkudoktoreos/prediction/feedintariff.py @@ -3,19 +3,26 @@ from typing import Optional from pydantic import Field, computed_field, field_validator from akkudoktoreos.config.configabc import SettingsBaseModel +from akkudoktoreos.core.coreabc import get_prediction from akkudoktoreos.prediction.feedintariffabc import FeedInTariffProvider from akkudoktoreos.prediction.feedintarifffixed import FeedInTariffFixedCommonSettings from akkudoktoreos.prediction.feedintariffimport import FeedInTariffImportCommonSettings -from akkudoktoreos.prediction.prediction import get_prediction -prediction_eos = get_prediction() -# Valid feedintariff providers -feedintariff_providers = [ - provider.provider_id() - for provider in prediction_eos.providers - if isinstance(provider, FeedInTariffProvider) -] +def elecprice_provider_ids() -> list[str]: + """Valid feedintariff provider ids.""" + try: + prediction_eos = get_prediction() + except: + # Prediction may not be initialized + # Return at least provider used in example + return ["FeedInTariffFixed", "FeedInTarifImport"] + + return [ + provider.provider_id() + for provider in prediction_eos.providers + if isinstance(provider, FeedInTariffProvider) + ] class FeedInTariffCommonProviderSettings(SettingsBaseModel): @@ -60,14 +67,14 @@ class FeedInTariffCommonSettings(SettingsBaseModel): @property def providers(self) -> list[str]: """Available feed in tariff provider ids.""" - return feedintariff_providers + return elecprice_provider_ids() # Validators @field_validator("provider", mode="after") @classmethod def validate_provider(cls, value: Optional[str]) -> Optional[str]: - if value is None or value in feedintariff_providers: + if value is None or value in elecprice_provider_ids(): return value raise ValueError( - f"Provider '{value}' is not a 
valid feed in tariff provider: {feedintariff_providers}." + f"Provider '{value}' is not a valid feed in tariff provider: {elecprice_provider_ids()}." ) diff --git a/src/akkudoktoreos/prediction/load.py b/src/akkudoktoreos/prediction/load.py index 2422e6d..92def68 100644 --- a/src/akkudoktoreos/prediction/load.py +++ b/src/akkudoktoreos/prediction/load.py @@ -5,20 +5,27 @@ from typing import Optional from pydantic import Field, computed_field, field_validator from akkudoktoreos.config.configabc import SettingsBaseModel +from akkudoktoreos.core.coreabc import get_prediction from akkudoktoreos.prediction.loadabc import LoadProvider from akkudoktoreos.prediction.loadakkudoktor import LoadAkkudoktorCommonSettings from akkudoktoreos.prediction.loadimport import LoadImportCommonSettings from akkudoktoreos.prediction.loadvrm import LoadVrmCommonSettings -from akkudoktoreos.prediction.prediction import get_prediction -prediction_eos = get_prediction() -# Valid load providers -load_providers = [ - provider.provider_id() - for provider in prediction_eos.providers - if isinstance(provider, LoadProvider) -] +def load_providers() -> list[str]: + """Valid load provider ids.""" + try: + prediction_eos = get_prediction() + except: + # Prediction may not be initialized + # Return at least provider used in example + return ["LoadAkkudoktor", "LoadVrm", "LoadImport"] + + return [ + provider.provider_id() + for provider in prediction_eos.providers + if isinstance(provider, LoadProvider) + ] class LoadCommonProviderSettings(SettingsBaseModel): @@ -66,12 +73,12 @@ class LoadCommonSettings(SettingsBaseModel): @property def providers(self) -> list[str]: """Available load provider ids.""" - return load_providers + return load_providers() # Validators @field_validator("provider", mode="after") @classmethod def validate_provider(cls, value: Optional[str]) -> Optional[str]: - if value is None or value in load_providers: + if value is None or value in load_providers(): return value - raise 
ValueError(f"Provider '{value}' is not a valid load provider: {load_providers}.") + raise ValueError(f"Provider '{value}' is not a valid load provider: {load_providers()}.") diff --git a/src/akkudoktoreos/prediction/loadakkudoktor.py b/src/akkudoktoreos/prediction/loadakkudoktor.py index 26ebbe9..c34f4e0 100644 --- a/src/akkudoktoreos/prediction/loadakkudoktor.py +++ b/src/akkudoktoreos/prediction/loadakkudoktor.py @@ -132,23 +132,32 @@ class LoadAkkudoktorAdjusted(LoadAkkudoktor): compare_dt = compare_start for i in range(len(load_total_kwh_array)): load_total_wh = load_total_kwh_array[i] * 1000 + hour = compare_dt.hour + + # Weight calculated by distance in days to the latest measurement + weight = 1 / ((compare_end - compare_dt).days + 1) + # Extract mean (index 0) and standard deviation (index 1) for the given day and hour # Day indexing starts at 0, -1 because of that - hourly_stats = data_year_energy[compare_dt.day_of_year - 1, :, compare_dt.hour] - weight = 1 / ((compare_end - compare_dt).days + 1) + day_idx = compare_dt.day_of_year - 1 + hourly_stats = data_year_energy[day_idx, :, hour] + + # Calculate adjustments (working days and weekend) if compare_dt.day_of_week < 5: - weekday_adjust[compare_dt.hour] += (load_total_wh - hourly_stats[0]) * weight - weekday_adjust_weight[compare_dt.hour] += weight + weekday_adjust[hour] += (load_total_wh - hourly_stats[0]) * weight + weekday_adjust_weight[hour] += weight else: - weekend_adjust[compare_dt.hour] += (load_total_wh - hourly_stats[0]) * weight - weekend_adjust_weight[compare_dt.hour] += weight + weekend_adjust[hour] += (load_total_wh - hourly_stats[0]) * weight + weekend_adjust_weight[hour] += weight + compare_dt += compare_interval + # Calculate mean - for i in range(24): - if weekday_adjust_weight[i] > 0: - weekday_adjust[i] = weekday_adjust[i] / weekday_adjust_weight[i] - if weekend_adjust_weight[i] > 0: - weekend_adjust[i] = weekend_adjust[i] / weekend_adjust_weight[i] + for hour in range(24): + if 
weekday_adjust_weight[hour] > 0: + weekday_adjust[hour] = weekday_adjust[hour] / weekday_adjust_weight[hour] + if weekend_adjust_weight[hour] > 0: + weekend_adjust[hour] = weekend_adjust[hour] / weekend_adjust_weight[hour] return (weekday_adjust, weekend_adjust) diff --git a/src/akkudoktoreos/prediction/prediction.py b/src/akkudoktoreos/prediction/prediction.py index c83031b..4aaab4c 100644 --- a/src/akkudoktoreos/prediction/prediction.py +++ b/src/akkudoktoreos/prediction/prediction.py @@ -26,7 +26,7 @@ Attributes: weather_clearoutside (WeatherClearOutside): Weather forecast provider using ClearOutside. """ -from typing import List, Optional, Union +from typing import Optional, Union from pydantic import Field @@ -69,38 +69,6 @@ class PredictionCommonSettings(SettingsBaseModel): ) -class Prediction(PredictionContainer): - """Prediction container to manage multiple prediction providers. - - Attributes: - providers (List[Union[PVForecastAkkudoktor, WeatherBrightSky, WeatherClearOutside]]): - List of forecast provider instances, in the order they should be updated. - Providers may depend on updates from others. - """ - - providers: List[ - Union[ - ElecPriceAkkudoktor, - ElecPriceEnergyCharts, - ElecPriceImport, - FeedInTariffFixed, - FeedInTariffImport, - LoadAkkudoktor, - LoadAkkudoktorAdjusted, - LoadVrm, - LoadImport, - PVForecastAkkudoktor, - PVForecastVrm, - PVForecastImport, - WeatherBrightSky, - WeatherClearOutside, - WeatherImport, - ] - ] = Field( - default_factory=list, json_schema_extra={"description": "List of prediction providers"} - ) - - # Initialize forecast providers, all are singletons. 
elecprice_akkudoktor = ElecPriceAkkudoktor() elecprice_energy_charts = ElecPriceEnergyCharts() @@ -119,42 +87,85 @@ weather_clearoutside = WeatherClearOutside() weather_import = WeatherImport() -def get_prediction() -> Prediction: - """Gets the EOS prediction data.""" - # Initialize Prediction instance with providers in the required order +def prediction_providers() -> list[ + Union[ + ElecPriceAkkudoktor, + ElecPriceEnergyCharts, + ElecPriceImport, + FeedInTariffFixed, + FeedInTariffImport, + LoadAkkudoktor, + LoadAkkudoktorAdjusted, + LoadVrm, + LoadImport, + PVForecastAkkudoktor, + PVForecastVrm, + PVForecastImport, + WeatherBrightSky, + WeatherClearOutside, + WeatherImport, + ] +]: + """Return list of prediction providers.""" + global \ + elecprice_akkudoktor, \ + elecprice_energy_charts, \ + elecprice_import, \ + feedintariff_fixed, \ + feedintariff_import, \ + loadforecast_akkudoktor, \ + loadforecast_akkudoktor_adjusted, \ + loadforecast_vrm, \ + loadforecast_import, \ + pvforecast_akkudoktor, \ + pvforecast_vrm, \ + pvforecast_import, \ + weather_brightsky, \ + weather_clearoutside, \ + weather_import + # Care for provider sequence as providers may rely on others to be updated before. 
- prediction = Prediction( - providers=[ - elecprice_akkudoktor, - elecprice_energy_charts, - elecprice_import, - feedintariff_fixed, - feedintariff_import, - loadforecast_akkudoktor, - loadforecast_akkudoktor_adjusted, - loadforecast_vrm, - loadforecast_import, - pvforecast_akkudoktor, - pvforecast_vrm, - pvforecast_import, - weather_brightsky, - weather_clearoutside, - weather_import, + return [ + elecprice_akkudoktor, + elecprice_energy_charts, + elecprice_import, + feedintariff_fixed, + feedintariff_import, + loadforecast_akkudoktor, + loadforecast_akkudoktor_adjusted, + loadforecast_vrm, + loadforecast_import, + pvforecast_akkudoktor, + pvforecast_vrm, + pvforecast_import, + weather_brightsky, + weather_clearoutside, + weather_import, + ] + + +class Prediction(PredictionContainer): + """Prediction container to manage multiple prediction providers.""" + + providers: list[ + Union[ + ElecPriceAkkudoktor, + ElecPriceEnergyCharts, + ElecPriceImport, + FeedInTariffFixed, + FeedInTariffImport, + LoadAkkudoktor, + LoadAkkudoktorAdjusted, + LoadVrm, + LoadImport, + PVForecastAkkudoktor, + PVForecastVrm, + PVForecastImport, + WeatherBrightSky, + WeatherClearOutside, + WeatherImport, ] + ] = Field( + default_factory=prediction_providers, + json_schema_extra={"description": "List of prediction providers"}, ) - return prediction - - -def main() -> None: - """Main function to update and display predictions. - - This function initializes and updates the forecast providers in sequence - according to the `Prediction` instance, then prints the updated prediction data. 
- """ - prediction = get_prediction() - prediction.update_data() - print(f"Prediction: {prediction}") - - -if __name__ == "__main__": - main() diff --git a/src/akkudoktoreos/prediction/predictionabc.py b/src/akkudoktoreos/prediction/predictionabc.py index dd6d2ae..e99aa51 100644 --- a/src/akkudoktoreos/prediction/predictionabc.py +++ b/src/akkudoktoreos/prediction/predictionabc.py @@ -15,17 +15,17 @@ from pydantic import Field, computed_field from akkudoktoreos.core.coreabc import MeasurementMixin from akkudoktoreos.core.dataabc import ( - DataBase, + DataABC, DataContainer, DataImportProvider, DataProvider, DataRecord, DataSequence, ) -from akkudoktoreos.utils.datetimeutil import DateTime, to_duration +from akkudoktoreos.utils.datetimeutil import DateTime, Duration, to_duration -class PredictionBase(DataBase, MeasurementMixin): +class PredictionABC(DataABC, MeasurementMixin): """Base class for handling prediction data. Enables access to EOS configuration data (attribute `config`) and EOS measurement data @@ -95,7 +95,7 @@ class PredictionSequence(DataSequence): ) -class PredictionStartEndKeepMixin(PredictionBase): +class PredictionStartEndKeepMixin(PredictionABC): """A mixin to manage start, end, and historical retention datetimes for prediction data. The starting datetime for prediction data generation is provided by the energy management @@ -196,6 +196,35 @@ class PredictionProvider(PredictionStartEndKeepMixin, DataProvider): Derived classes have to provide their own records field with correct record type set. """ + def db_keep_datetime(self) -> Optional[DateTime]: + """Earliest datetime from which database records should be retained. + + Used when removing old records from database to free space. + + Subclasses may override this method to provide a domain-specific default. + + Returns: + Datetime or None. + """ + return self.keep_datetime + + def db_initial_time_window(self) -> Optional[Duration]: + """Return the initial time window used for database loading. 
+ + This window defines the initial symmetric time span around a target datetime + that should be loaded from the database when no explicit search time window + is specified. It serves as a loading hint and may be expanded by the caller + if no records are found within the initial range. + + Subclasses may override this method to provide a domain-specific default. + + Returns: + The initial loading time window as a Duration, or ``None`` to indicate + that no initial window constraint should be applied. + """ + hours = max(self.config.prediction.hours, self.config.prediction_historic_hours, 24) + return to_duration(hours * 3600) + def update_data( self, force_enable: Optional[bool] = False, @@ -219,9 +248,6 @@ class PredictionProvider(PredictionStartEndKeepMixin, DataProvider): # Call the custom update logic self._update_data(force_update=force_update) - # Assure records are sorted. - self.sort_by_datetime() - class PredictionImportProvider(PredictionProvider, DataImportProvider): """Abstract base class for prediction providers that import prediction data. 
diff --git a/src/akkudoktoreos/prediction/pvforecast.py b/src/akkudoktoreos/prediction/pvforecast.py index 91b8060..f0ddb42 100644 --- a/src/akkudoktoreos/prediction/pvforecast.py +++ b/src/akkudoktoreos/prediction/pvforecast.py @@ -5,19 +5,26 @@ from typing import Any, List, Optional, Self from pydantic import Field, computed_field, field_validator, model_validator from akkudoktoreos.config.configabc import SettingsBaseModel -from akkudoktoreos.prediction.prediction import get_prediction +from akkudoktoreos.core.coreabc import get_prediction from akkudoktoreos.prediction.pvforecastabc import PVForecastProvider from akkudoktoreos.prediction.pvforecastimport import PVForecastImportCommonSettings from akkudoktoreos.prediction.pvforecastvrm import PVForecastVrmCommonSettings -prediction_eos = get_prediction() -# Valid PV forecast providers -pvforecast_providers = [ - provider.provider_id() - for provider in prediction_eos.providers - if isinstance(provider, PVForecastProvider) -] +def pvforecast_provider_ids() -> list[str]: + """Valid PV forecast providers.""" + try: + prediction_eos = get_prediction() + except: + # Prediction may not be initialized + # Return at least provider used in example + return ["PVForecastAkkudoktor", "PVForecastImport", "PVForecastVrm"] + + return [ + provider.provider_id() + for provider in prediction_eos.providers + if isinstance(provider, PVForecastProvider) + ] class PVForecastPlaneSetting(SettingsBaseModel): @@ -264,16 +271,16 @@ class PVForecastCommonSettings(SettingsBaseModel): @property def providers(self) -> list[str]: """Available PVForecast provider ids.""" - return pvforecast_providers + return pvforecast_provider_ids() # Validators @field_validator("provider", mode="after") @classmethod def validate_provider(cls, value: Optional[str]) -> Optional[str]: - if value is None or value in pvforecast_providers: + if value is None or value in pvforecast_provider_ids(): return value raise ValueError( - f"Provider '{value}' is not a valid 
PV forecast provider: {pvforecast_providers}." + f"Provider '{value}' is not a valid PV forecast provider: {pvforecast_provider_ids()}." ) ## Computed fields diff --git a/src/akkudoktoreos/prediction/weather.py b/src/akkudoktoreos/prediction/weather.py index 5a280f7..90873b6 100644 --- a/src/akkudoktoreos/prediction/weather.py +++ b/src/akkudoktoreos/prediction/weather.py @@ -5,18 +5,25 @@ from typing import Optional from pydantic import Field, computed_field, field_validator from akkudoktoreos.config.configabc import SettingsBaseModel -from akkudoktoreos.prediction.prediction import get_prediction +from akkudoktoreos.core.coreabc import get_prediction from akkudoktoreos.prediction.weatherabc import WeatherProvider from akkudoktoreos.prediction.weatherimport import WeatherImportCommonSettings -prediction_eos = get_prediction() -# Valid weather providers -weather_providers = [ - provider.provider_id() - for provider in prediction_eos.providers - if isinstance(provider, WeatherProvider) -] +def weather_provider_ids() -> list[str]: + """Valid weather provider ids.""" + try: + prediction_eos = get_prediction() + except Exception: + # Prediction may not be initialized + # Return at least provider used in example + return ["WeatherImport"] + + return [ + provider.provider_id() + for provider in prediction_eos.providers + if isinstance(provider, WeatherProvider) + ] class WeatherCommonProviderSettings(SettingsBaseModel): @@ -56,14 +63,14 @@ class WeatherCommonSettings(SettingsBaseModel): @property def providers(self) -> list[str]: """Available weather provider ids.""" - return weather_providers + return weather_provider_ids() # Validators @field_validator("provider", mode="after") @classmethod def validate_provider(cls, value: Optional[str]) -> Optional[str]: - if value is None or value in weather_providers: + if value is None or value in weather_provider_ids(): return value raise ValueError( - f"Provider '{value}' is not a valid weather provider: {weather_providers}." 
+ f"Provider '{value}' is not a valid weather provider: {weather_provider_ids()}." ) diff --git a/src/akkudoktoreos/server/dash/admin.py b/src/akkudoktoreos/server/dash/admin.py index ce50ec0..69898b3 100644 --- a/src/akkudoktoreos/server/dash/admin.py +++ b/src/akkudoktoreos/server/dash/admin.py @@ -402,6 +402,75 @@ def AdminConfig( ) +def AdminDatabase( + eos_host: str, eos_port: Union[str, int], data: Optional[dict], config: Optional[dict[str, Any]] +) -> tuple[str, Union[Card, list[Card]]]: + """Creates a database management card. + + Args: + eos_host (str): The hostname of the EOS server. + eos_port (Union[str, int]): The port of the EOS server. + data (Optional[dict]): Incoming data containing action and category for processing. + + Returns: + tuple[str, Union[Card, list[Card]]]: A tuple containing the database category label and the `Card` UI component. + """ + server = f"http://{eos_host}:{eos_port}" + eos_hostname = "EOS server" + eosdash_hostname = "EOSdash server" + + category = "database" + + status_vacuum = None + if data and data.get("category", None) == category: + # This data is for us + if data["action"] == "vacuum": + # Remove old records from database + try: + result = requests.post(f"{server}/v1/admin/database/vacuum", timeout=30) + result.raise_for_status() + status_vacuum = Success( + f"Removed old data records from database on '{eos_hostname}'" + ) + except requests.exceptions.HTTPError as e: + detail = result.json()["detail"] + status_vacuum = Error( + f"Can not remove old data records from database on '{eos_hostname}': {e}, {detail}" + ) + except Exception as e: + status_vacuum = Error( + f"Can not remove old data records from database on '{eos_hostname}': {e}" + ) + + return ( + category, + [ + Card( + Details( + Summary( + Grid( + DivHStacked( + UkIcon(icon="play"), + ConfigButton( + "Vacuum", + hx_post=request_url_for("/eosdash/admin"), + hx_target="#page-content", + hx_swap="innerHTML", + hx_vals='{"category": "database", "action": 
"vacuum"}', + ), + P(f"Remove old data records from database on '{eos_hostname}'"), + ), + status_vacuum, + ), + cls="list-none", + ), + P(f"Remove old data records from database on '{eos_hostname}'."), + ), + ), + ], + ) + + def Admin(eos_host: str, eos_port: Union[str, int], data: Optional[dict] = None) -> Div: """Generates the administrative dashboard layout. @@ -450,6 +519,7 @@ def Admin(eos_host: str, eos_port: Union[str, int], data: Optional[dict] = None) for category, admin in [ AdminCache(eos_host, eos_port, data, config), AdminConfig(eos_host, eos_port, data, config, config_backup), + AdminDatabase(eos_host, eos_port, data, config), ]: if category != last_category: rows.append(H3(category)) diff --git a/src/akkudoktoreos/server/dash/footer.py b/src/akkudoktoreos/server/dash/footer.py index 64defd7..9899829 100644 --- a/src/akkudoktoreos/server/dash/footer.py +++ b/src/akkudoktoreos/server/dash/footer.py @@ -7,7 +7,7 @@ from monsterui.franken import A, ButtonT, DivFullySpaced, P from requests.exceptions import RequestException import akkudoktoreos.server.dash.eosstatus as eosstatus -from akkudoktoreos.config.config import get_config +from akkudoktoreos.core.coreabc import get_config def get_alive(eos_host: str, eos_port: Union[str, int]) -> str: diff --git a/src/akkudoktoreos/server/dash/plan.py b/src/akkudoktoreos/server/dash/plan.py index a488d67..50efad5 100644 --- a/src/akkudoktoreos/server/dash/plan.py +++ b/src/akkudoktoreos/server/dash/plan.py @@ -206,13 +206,20 @@ def SolutionCard(solution: OptimizationSolution, config: SettingsEOS, data: Opti else: continue # Adjust to similar y-axis 0-point + values_min_max = [ + (energy_wh_min, energy_wh_max), + (amt_kwh_min, amt_kwh_max), + (amt_min, amt_max), + (soc_factor_min, soc_factor_max), + ] # First get the maximum factor for the min value related the maximum value - min_max_factor = max( - (energy_wh_min * -1.0) / energy_wh_max, - (amt_kwh_min * -1.0) / amt_kwh_max, - (amt_min * -1.0) / amt_max, - 
(soc_factor_min * -1.0) / soc_factor_max, - ) + min_max_factor = 0.0 + for value_min, value_max in values_min_max: + if value_max > 0: + value_factor = (value_min * -1.0) / value_max + if value_factor > min_max_factor: + min_max_factor = value_factor + # Adapt the min values to have the same relative min/max factor on all y-axis energy_wh_min = min_max_factor * energy_wh_max * -1.0 amt_kwh_min = min_max_factor * amt_kwh_max * -1.0 diff --git a/src/akkudoktoreos/server/eos.py b/src/akkudoktoreos/server/eos.py index 26dfe2d..8d09516 100755 --- a/src/akkudoktoreos/server/eos.py +++ b/src/akkudoktoreos/server/eos.py @@ -26,12 +26,19 @@ from fastapi.responses import ( ) from loguru import logger -from akkudoktoreos.config.config import ConfigEOS, SettingsEOS, get_config -from akkudoktoreos.core.cache import CacheFileStore +from akkudoktoreos.config.config import ConfigEOS, SettingsEOS +from akkudoktoreos.core.cache import CacheFileStore, cache_clear, cache_load, cache_save +from akkudoktoreos.core.coreabc import ( + get_config, + get_ems, + get_measurement, + get_prediction, + get_resource_registry, + singletons_init, +) from akkudoktoreos.core.emplan import EnergyManagementPlan, ResourceStatus -from akkudoktoreos.core.ems import get_ems +from akkudoktoreos.core.ems import ems_manage_energy from akkudoktoreos.core.emsettings import EnergyManagementMode -from akkudoktoreos.core.logabc import LOGGING_LEVELS from akkudoktoreos.core.logging import logging_track_config, read_file_log from akkudoktoreos.core.pydantic import ( PydanticBaseModel, @@ -40,8 +47,7 @@ from akkudoktoreos.core.pydantic import ( PydanticDateTimeSeries, ) from akkudoktoreos.core.version import __version__ -from akkudoktoreos.devices.devices import ResourceKey, get_resource_registry -from akkudoktoreos.measurement.measurement import get_measurement +from akkudoktoreos.devices.devices import ResourceKey from akkudoktoreos.optimization.genetic.geneticparams import ( GeneticOptimizationParameters, ) @@ 
-50,152 +56,19 @@ from akkudoktoreos.optimization.optimization import OptimizationSolution from akkudoktoreos.prediction.elecprice import ElecPriceCommonSettings from akkudoktoreos.prediction.load import LoadCommonProviderSettings, LoadCommonSettings from akkudoktoreos.prediction.loadakkudoktor import LoadAkkudoktorCommonSettings -from akkudoktoreos.prediction.prediction import get_prediction from akkudoktoreos.prediction.pvforecast import PVForecastCommonSettings +from akkudoktoreos.server.rest.cli import cli_apply_args_to_config, cli_parse_args from akkudoktoreos.server.rest.error import create_error_page from akkudoktoreos.server.rest.starteosdash import run_eosdash_supervisor -from akkudoktoreos.server.rest.tasks import repeat_every +from akkudoktoreos.server.rest.tasks import make_repeated_task +from akkudoktoreos.server.retentionmanager import RetentionManager from akkudoktoreos.server.server import ( drop_root_privileges, fix_data_directories_permissions, - get_default_host, get_host_ip, wait_for_port_free, ) from akkudoktoreos.utils.datetimeutil import to_datetime, to_duration -from akkudoktoreos.utils.stringutil import str2bool - -config_eos = get_config() -measurement_eos = get_measurement() -prediction_eos = get_prediction() -ems_eos = get_ems() -resource_registry_eos = get_resource_registry() - -# ------------------------------------ -# Logging configuration at import time -# ------------------------------------ - -logger.remove() -logging_track_config(config_eos, "logging", None, None) - -# ----------------------------- -# Configuration change tracking -# ----------------------------- - -config_eos.track_nested_value("/logging", logging_track_config) - -# ---------------------------- -# Safe argparse at import time -# ---------------------------- - -parser = argparse.ArgumentParser(description="Start EOS server.") - -parser.add_argument( - "--host", - type=str, - help="Host for the EOS server (default: value from config)", -) -parser.add_argument( - 
"--port", - type=int, - help="Port for the EOS server (default: value from config)", -) -parser.add_argument( - "--log_level", - type=str, - default="none", - help='Log level for the server console. Options: "critical", "error", "warning", "info", "debug", "trace" (default: "none")', -) -parser.add_argument( - "--reload", - type=str2bool, - default=False, - help="Enable or disable auto-reload. Useful for development. Options: True or False (default: False)", -) -parser.add_argument( - "--startup_eosdash", - type=str2bool, - default=None, - help="Enable or disable automatic EOSdash startup. Options: True or False (default: value from config)", -) -parser.add_argument( - "--run_as_user", - type=str, - help="The unprivileged user account the EOS server shall switch to after performing root-level startup tasks.", -) - -# Command line arguments -args: argparse.Namespace -args_unknown: list[str] -args, args_unknown = parser.parse_known_args() - - -# ----------------------------- -# Prepare config at import time -# ----------------------------- - -# Set config to actual environment variable & config file content -config_eos.reset_settings() - -# Setup parameters from args, config_eos and default -# Remember parameters in config - -# Setup EOS logging level - first to have the other logging messages logged -if args and args.log_level is not None: - log_level = args.log_level.upper() - # Ensure log_level from command line is in config settings - if log_level in LOGGING_LEVELS: - # Setup console logging level using nested value - # - triggers logging configuration by logging_track_config - config_eos.set_nested_value("logging/console_level", log_level) - logger.debug(f"logging/console_level configuration set by argument to {log_level}") - -# Setup EOS server host -if args and args.host: - host = args.host - logger.debug(f"server/host configuration set by argument to {host}") -elif config_eos.server.host: - host = config_eos.server.host -else: - host = get_default_host() -# 
Ensure host from command line is in config settings -config_eos.set_nested_value("server/host", host) - -# Setup EOS server port -if args and args.port: - port = args.port - logger.debug(f"server/port configuration set by argument to {port}") -elif config_eos.server.port: - port = config_eos.server.port -else: - port = 8503 -# Ensure port from command line is in config settings -config_eos.set_nested_value("server/port", port) - -# Setup EOS reload for development -if args is None or args.reload is None: - reload = False -else: - logger.debug(f"reload set by argument to {args.reload}") - reload = args.reload - -# Setup EOSdash startup -if args and args.startup_eosdash is not None: - # Ensure startup_eosdash from command line is in config settings - config_eos.set_nested_value("server/startup_eosdash", args.startup_eosdash) - logger.debug(f"server/startup_eosdash configuration set by argument to {args.startup_eosdash}") - -if config_eos.server.startup_eosdash: - # Ensure EOSdash host and port config settings are at least set to default values - - # Setup EOS server host - if config_eos.server.eosdash_host is None: - config_eos.set_nested_value("server/eosdash_host", host) - - # Setup EOS server host - if config_eos.server.eosdash_port is None: - config_eos.set_nested_value("server/eosdash_port", port + 1) - # ---------------------- # EOS REST Server @@ -204,14 +77,18 @@ if config_eos.server.startup_eosdash: def save_eos_state() -> None: """Save EOS state.""" - resource_registry_eos.save() + get_resource_registry().save() + get_prediction().save() + get_measurement().save() cache_save() # keep last def load_eos_state() -> None: """Load EOS state.""" cache_load() # keep first - resource_registry_eos.load() + get_measurement().load() + get_prediction().load() + get_resource_registry().load() def terminate_eos() -> None: @@ -225,58 +102,25 @@ def terminate_eos() -> None: logger.info(f"🚀 EOS terminated, PID {pid}") -def cache_clear(clear_all: Optional[bool] = None) -> 
None: - """Cleanup expired cache files.""" - if clear_all: - CacheFileStore().clear(clear_all=True) - else: - CacheFileStore().clear(before_datetime=to_datetime()) +def save_eos_database() -> None: + """Save EOS database.""" + get_prediction().save() + get_measurement().save() -def cache_load() -> dict: - """Load cache from cachefilestore.json.""" - return CacheFileStore().load_store() - - -def cache_save() -> dict: - """Save cache to cachefilestore.json.""" - return CacheFileStore().save_store() - - -def cache_cleanup_on_exception(e: Exception) -> None: - logger.error("Cache cleanup task caught an exception: {}", e, exc_info=True) - - -@repeat_every( - seconds=float(config_eos.cache.cleanup_interval), - on_exception=cache_cleanup_on_exception, -) -def cache_cleanup_task() -> None: - """Repeating task to clear cache from expired cache files.""" - logger.debug("Clear cache") - cache_clear() - - -def energy_management_on_exception(e: Exception) -> None: - logger.error("Energy management task caught an exception: {}", e, exc_info=True) - - -@repeat_every( - seconds=10, - wait_first=config_eos.ems.startup_delay, - on_exception=energy_management_on_exception, -) -async def energy_management_task() -> None: - """Repeating task for energy management.""" - logger.debug("Check EMS run") - await ems_eos.manage_energy() +def compact_eos_database() -> None: + """Compact EOS database.""" + get_prediction().db_compact() + get_measurement().db_compact() + get_prediction().db_vacuum() + get_measurement().db_vacuum() async def server_shutdown_task() -> None: """One-shot task for shutting down the EOS server. This coroutine performs the following actions: - 1. Ensures the cache is saved by calling the cache_save function. + 1. Ensures the EOS state is saved by calling the save_eos_state function. 2. Waits for 5 seconds to allow the EOS server to complete any ongoing tasks. 3. Gracefully shuts down the current process by sending the appropriate signal. 
@@ -303,16 +147,29 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: load_eos_state() - # Start EOS tasks - if config_eos.cache.cleanup_interval is None: - logger.warning("Cache file cleanup disabled. Set cache.cleanup_interval.") - else: - await cache_cleanup_task() - await energy_management_task() + config_eos = get_config() + + # Prepare the Manager and all tasks that are handled by the manager + manager = RetentionManager(config_getter=get_config().get_nested_value, shutdown_timeout=10) + manager.register("cache_clear", cache_clear, interval_attr="cache/cleanup_interval") + manager.register( + "save_eos_database", save_eos_database, interval_attr="database/autosave_interval_sec" + ) + manager.register( + "compact_eos_database", compact_eos_database, interval_attr="database/compact_interval_sec" + ) + manager.register("manage_energy", ems_manage_energy, interval_attr="ems/interval") + + # Start EOS repeated tasks + tick_task = make_repeated_task(manager.tick, seconds=5, wait_first=2) + await tick_task() # Handover to application yield + # waits for any in-flight job to finish cleanly + await manager.shutdown() + # On shutdown save_eos_state() @@ -410,6 +267,46 @@ def fastapi_admin_cache_get() -> dict: return data +@app.get("/v1/admin/database/stats", tags=["admin"]) +def fastapi_admin_database_stats_get() -> dict: + """Get statistics from database. 
+ + Returns: + data (dict): The database statistics + """ + data = {} + try: + # Get the stats + data[get_measurement().db_namespace()] = get_measurement().db_get_stats() + data[get_prediction().__class__.__name__] = get_prediction().db_get_stats() + except Exception as e: + trace = "".join(traceback.TracebackException.from_exception(e).format()) + raise HTTPException( + status_code=400, detail=f"Error on database statistic retrieval: {e}\n{trace}" + ) + return data + + +@app.post("/v1/admin/database/vacuum", tags=["admin"]) +def fastapi_admin_database_vacuum_post() -> dict: + """Remove old records from database. + + Returns: + data (dict): The database stats after removal of old records. + """ + data = {} + try: + get_measurement().db_vacuum() + get_prediction().db_vacuum() + # Get the stats + data[get_measurement().db_namespace()] = get_measurement().db_get_stats() + data[get_prediction().__class__.__name__] = get_prediction().db_get_stats() + except Exception as e: + trace = "".join(traceback.TracebackException.from_exception(e).format()) + raise HTTPException(status_code=400, detail=f"Error on database vacuum: {e}\n{trace}") + return data + + @app.post("/v1/admin/server/restart", tags=["admin"]) async def fastapi_admin_server_restart_post() -> dict: """Restart the server. 
@@ -424,8 +321,8 @@ async def fastapi_admin_server_restart_post() -> dict: # Force a new process group to make the new process easily distinguishable from the current one # Set environment before any subprocess run, to keep custom config dir env = os.environ.copy() - env["EOS_DIR"] = str(config_eos.general.data_folder_path) - env["EOS_CONFIG_DIR"] = str(config_eos.general.config_folder_path) + env["EOS_DIR"] = str(get_config().general.data_folder_path) + env["EOS_CONFIG_DIR"] = str(get_config().general.config_folder_path) if os.name == "nt": # Windows @@ -491,8 +388,8 @@ def fastapi_health_get(): # type: ignore "pid": psutil.Process().pid, "version": __version__, "energy-management": { - "start_datetime": to_datetime(ems_eos.start_datetime, as_string=True), - "last_run_datetime": to_datetime(ems_eos.last_run_datetime, as_string=True), + "start_datetime": to_datetime(get_ems().start_datetime, as_string=True), + "last_run_datetime": to_datetime(get_ems().last_run_datetime, as_string=True), }, } ) @@ -506,13 +403,13 @@ def fastapi_config_reset_post() -> ConfigEOS: configuration (ConfigEOS): The current configuration after update. """ try: - config_eos.reset_settings() + get_config().reset_settings() except Exception as e: raise HTTPException( status_code=404, detail=f"Cannot reset configuration: {e}", ) - return config_eos + return get_config() @app.get("/v1/config/backup", tags=["config"]) @@ -523,7 +420,7 @@ def fastapi_config_backup_get() -> dict[str, dict[str, Any]]: dict[str, dict[str, Any]]: Mapping of backup identifiers to metadata. """ try: - result = config_eos.list_backups() + result = get_config().list_backups() except Exception as e: raise HTTPException( status_code=404, @@ -542,13 +439,13 @@ def fastapi_config_revert_put( configuration (ConfigEOS): The current configuration after revert. 
""" try: - config_eos.revert_settings(backup_id) + get_config().revert_settings(backup_id) except Exception as e: raise HTTPException( status_code=400, detail=f"Error on reverting of configuration: {e}", ) - return config_eos + return get_config() @app.put("/v1/config/file", tags=["config"]) @@ -559,13 +456,13 @@ def fastapi_config_file_put() -> ConfigEOS: configuration (ConfigEOS): The current configuration that was saved. """ try: - config_eos.to_config_file() + get_config().to_config_file() except: raise HTTPException( status_code=404, - detail=f"Cannot save configuration to file '{config_eos.config_file_path}'.", + detail=f"Cannot save configuration to file '{get_config().config_file_path}'.", ) - return config_eos + return get_config() @app.get("/v1/config", tags=["config"]) @@ -575,7 +472,7 @@ def fastapi_config_get() -> ConfigEOS: Returns: configuration (ConfigEOS): The current configuration. """ - return config_eos + return get_config() @app.put("/v1/config", tags=["config"]) @@ -593,10 +490,10 @@ def fastapi_config_put(settings: SettingsEOS) -> ConfigEOS: configuration (ConfigEOS): The current configuration after the write. """ try: - config_eos.merge_settings(settings) + get_config().merge_settings(settings) except Exception as e: raise HTTPException(status_code=400, detail=f"Error on update of configuration: {e}") - return config_eos + return get_config() @app.put("/v1/config/{path:path}", tags=["config"]) @@ -618,7 +515,7 @@ def fastapi_config_put_key( configuration (ConfigEOS): The current configuration after the update. 
""" try: - config_eos.set_nested_value(path, value) + get_config().set_nested_value(path, value) except Exception as e: trace = "".join(traceback.TracebackException.from_exception(e).format()) raise HTTPException( @@ -626,7 +523,7 @@ def fastapi_config_put_key( detail=f"Error on update of configuration '{path}','{value}': {e}\n{trace}", ) - return config_eos + return get_config() @app.get("/v1/config/{path:path}", tags=["config"]) @@ -644,7 +541,7 @@ def fastapi_config_get_key( value (Any): The value of the selected nested key. """ try: - return config_eos.get_nested_value(path) + return get_config().get_nested_value(path) except IndexError as e: raise HTTPException(status_code=400, detail=str(e)) except KeyError as e: @@ -682,7 +579,7 @@ async def fastapi_logging_get_log( Returns: JSONResponse: A JSON list of log entries. """ - log_path = config_eos.logging.file_path + log_path = get_config().logging.file_path try: logs = read_file_log( log_path=log_path, @@ -710,9 +607,9 @@ def fastapi_devices_status_get( latest_status: The latest status of a resource/ device. 
""" key = ResourceKey(resource_id=resource_id, actuator_id=actuator_id) - if not resource_registry_eos.status_exists(key): + if not get_resource_registry().status_exists(key): raise HTTPException(status_code=404, detail=f"Key '{key}' is not available.") - status_latest = resource_registry_eos.status_latest(key) + status_latest = get_resource_registry().status_latest(key) if status_latest is None: raise HTTPException(status_code=404, detail=f"Key '{key}' does not have a status.") return status_latest @@ -731,13 +628,13 @@ def fastapi_devices_status_put( """ key = ResourceKey(resource_id=resource_id, actuator_id=actuator_id) try: - resource_registry_eos.update_status(key, status) + get_resource_registry().update_status(key, status) except Exception as e: raise HTTPException( status_code=400, detail=f"Error on resource status update key='{key}', status='{status}': {e}", ) - status_latest = resource_registry_eos.status_latest(key) + status_latest = get_resource_registry().status_latest(key) if status_latest is None: raise HTTPException(status_code=404, detail=f"Key '{key}' does not have a status.") return status_latest @@ -746,7 +643,7 @@ def fastapi_devices_status_put( @app.get("/v1/measurement/keys", tags=["measurement"]) def fastapi_measurement_keys_get() -> list[str]: """Get a list of available measurement keys.""" - return sorted(measurement_eos.record_keys) + return sorted(get_measurement().record_keys) @app.get("/v1/measurement/series", tags=["measurement"]) @@ -754,9 +651,9 @@ def fastapi_measurement_series_get( key: Annotated[str, Query(description="Measurement key.")], ) -> PydanticDateTimeSeries: """Get the measurements of given key as series.""" - if key not in measurement_eos.record_keys: + if key not in get_measurement().record_keys: raise HTTPException(status_code=404, detail=f"Key '{key}' is not available.") - pdseries = measurement_eos.key_to_series(key=key) + pdseries = get_measurement().key_to_series(key=key) return 
PydanticDateTimeSeries.from_series(pdseries) @@ -767,7 +664,7 @@ def fastapi_measurement_value_put( value: Union[float | str], ) -> PydanticDateTimeSeries: """Merge the measurement of given key and value into EOS measurements at given datetime.""" - if key not in measurement_eos.record_keys: + if key not in get_measurement().record_keys: raise HTTPException(status_code=404, detail=f"Key '{key}' is not available.") if isinstance(value, str): # Try to convert to float @@ -777,8 +674,8 @@ def fastapi_measurement_value_put( logger.debug( f'/v1/measurement/value key: {key} value: "{value}" - string value not convertable to float' ) - measurement_eos.update_value(datetime, key, value) - pdseries = measurement_eos.key_to_series(key=key) + get_measurement().update_value(datetime, key, value) + pdseries = get_measurement().key_to_series(key=key) return PydanticDateTimeSeries.from_series(pdseries) @@ -787,11 +684,11 @@ def fastapi_measurement_series_put( key: Annotated[str, Query(description="Measurement key.")], series: PydanticDateTimeSeries ) -> PydanticDateTimeSeries: """Merge measurement given as series into given key.""" - if key not in measurement_eos.record_keys: + if key not in get_measurement().record_keys: raise HTTPException(status_code=404, detail=f"Key '{key}' is not available.") pdseries = series.to_series() # make pandas series from PydanticDateTimeSeries - measurement_eos.key_from_series(key=key, series=pdseries) - pdseries = measurement_eos.key_to_series(key=key) + get_measurement().key_from_series(key=key, series=pdseries) + pdseries = get_measurement().key_to_series(key=key) return PydanticDateTimeSeries.from_series(pdseries) @@ -799,14 +696,14 @@ def fastapi_measurement_series_put( def fastapi_measurement_dataframe_put(data: PydanticDateTimeDataFrame) -> None: """Merge the measurement data given as dataframe into EOS measurements.""" dataframe = data.to_dataframe() - measurement_eos.import_from_dataframe(dataframe) + 
get_measurement().import_from_dataframe(dataframe) @app.put("/v1/measurement/data", tags=["measurement"]) def fastapi_measurement_data_put(data: PydanticDateTimeData) -> None: """Merge the measurement data given as datetime data into EOS measurements.""" datetimedata = data.to_dict() - measurement_eos.import_from_dict(datetimedata) + get_measurement().import_from_dict(datetimedata) @app.get("/v1/prediction/providers", tags=["prediction"]) @@ -823,7 +720,7 @@ def fastapi_prediction_providers_get(enabled: Optional[bool] = None) -> list[str return sorted( [ provider.provider_id() - for provider in prediction_eos.providers + for provider in get_prediction().providers if provider.enabled() in enabled_status ] ) @@ -832,7 +729,7 @@ def fastapi_prediction_providers_get(enabled: Optional[bool] = None) -> list[str @app.get("/v1/prediction/keys", tags=["prediction"]) def fastapi_prediction_keys_get() -> list[str]: """Get a list of available prediction keys.""" - return sorted(prediction_eos.record_keys) + return sorted(get_prediction().record_keys) @app.get("/v1/prediction/series", tags=["prediction"]) @@ -856,17 +753,17 @@ def fastapi_prediction_series_get( end_datetime (Optional[str]: Ending datetime (exclusive). Defaults to end datetime of latest prediction. 
""" - if key not in prediction_eos.record_keys: + if key not in get_prediction().record_keys: raise HTTPException(status_code=404, detail=f"Key '{key}' is not available.") if start_datetime is None: - start_datetime = prediction_eos.ems_start_datetime + start_datetime = get_prediction().ems_start_datetime else: start_datetime = to_datetime(start_datetime) if end_datetime is None: - end_datetime = prediction_eos.end_datetime + end_datetime = get_prediction().end_datetime else: end_datetime = to_datetime(end_datetime) - pdseries = prediction_eos.key_to_series( + pdseries = get_prediction().key_to_series( key=key, start_datetime=start_datetime, end_datetime=end_datetime ) return PydanticDateTimeSeries.from_series(pdseries) @@ -899,20 +796,20 @@ def fastapi_prediction_dataframe_get( Defaults to end datetime of latest prediction. """ for key in keys: - if key not in prediction_eos.record_keys: + if key not in get_prediction().record_keys: raise HTTPException(status_code=404, detail=f"Key '{key}' is not available.") if start_datetime is None: - start_datetime = prediction_eos.ems_start_datetime + start_datetime = get_prediction().ems_start_datetime else: start_datetime = to_datetime(start_datetime) if end_datetime is None: - end_datetime = prediction_eos.end_datetime + end_datetime = get_prediction().end_datetime else: end_datetime = to_datetime(end_datetime) - df = prediction_eos.keys_to_dataframe( + df = get_prediction().keys_to_dataframe( keys=keys, start_datetime=start_datetime, end_datetime=end_datetime, interval=interval ) - return PydanticDateTimeDataFrame.from_dataframe(df, tz=config_eos.general.timezone) + return PydanticDateTimeDataFrame.from_dataframe(df, tz=get_config().general.timezone) @app.get("/v1/prediction/list", tags=["prediction"]) @@ -942,26 +839,30 @@ def fastapi_prediction_list_get( interval (Optional[str]): Time duration for each interval. Defaults to 1 hour. 
""" - if key not in prediction_eos.record_keys: + if key not in get_prediction().record_keys: raise HTTPException(status_code=404, detail=f"Key '{key}' is not available.") if start_datetime is None: - start_datetime = prediction_eos.ems_start_datetime + start_datetime = get_prediction().ems_start_datetime else: start_datetime = to_datetime(start_datetime) if end_datetime is None: - end_datetime = prediction_eos.end_datetime + end_datetime = get_prediction().end_datetime else: end_datetime = to_datetime(end_datetime) if interval is None: interval = to_duration("1 hour") else: interval = to_duration(interval) - prediction_list = prediction_eos.key_to_array( - key=key, - start_datetime=start_datetime, - end_datetime=end_datetime, - interval=interval, - ).tolist() + prediction_list = ( + get_prediction() + .key_to_array( + key=key, + start_datetime=start_datetime, + end_datetime=end_datetime, + interval=interval, + ) + .tolist() + ) return prediction_list @@ -980,14 +881,14 @@ def fastapi_prediction_import_provider( Defaults to False. 
""" try: - provider = prediction_eos.provider_by_id(provider_id) + provider = get_prediction().provider_by_id(provider_id) except ValueError: raise HTTPException(status_code=404, detail=f"Provider '{provider_id}' not found.") if not provider.enabled() and not force_enable: raise HTTPException(status_code=404, detail=f"Provider '{provider_id}' not enabled.") try: provider.import_from_json(json_str=json.dumps(data)) - provider.update_datetime = to_datetime(in_timezone=config_eos.general.timezone) + provider.update_datetime = to_datetime(in_timezone=get_config().general.timezone) except Exception as e: raise HTTPException( status_code=400, detail=f"Error on import for provider '{provider_id}': {e}" @@ -1009,7 +910,7 @@ async def fastapi_prediction_update( """ # Ensure there is only one optimization/ energy management run at a time try: - await ems_eos.run( + await get_ems().run( mode=EnergyManagementMode.PREDICTION, force_update=force_update, force_enable=force_enable, @@ -1038,13 +939,13 @@ async def fastapi_prediction_update_provider( Defaults to False. 
""" try: - provider = prediction_eos.provider_by_id(provider_id) + provider = get_prediction().provider_by_id(provider_id) except ValueError: raise HTTPException(status_code=404, detail=f"Provider '{provider_id}' not found.") # Ensure there is only one optimization/ energy management run at a time try: - await ems_eos.run( + await get_ems().run( mode=EnergyManagementMode.PREDICTION, force_update=force_update, force_enable=force_enable, @@ -1062,7 +963,7 @@ async def fastapi_prediction_update_provider( @app.get("/v1/energy-management/optimization/solution", tags=["energy-management"]) def fastapi_energy_management_optimization_solution_get() -> OptimizationSolution: """Get the latest solution of the optimization.""" - solution = ems_eos.optimization_solution() + solution = get_ems().optimization_solution() if solution is None: raise HTTPException( status_code=404, @@ -1074,7 +975,7 @@ def fastapi_energy_management_optimization_solution_get() -> OptimizationSolutio @app.get("/v1/energy-management/plan", tags=["energy-management"]) def fastapi_energy_management_plan_get() -> EnergyManagementPlan: """Get the latest energy management plan.""" - plan = ems_eos.plan() + plan = get_ems().plan() if plan is None: raise HTTPException( status_code=404, @@ -1106,11 +1007,11 @@ async def fastapi_strompreis() -> list[float]: provider="ElecPriceAkkudoktor", ) ) - config_eos.merge_settings(settings=settings) + get_config().merge_settings(settings=settings) # Ensure there is only one optimization/ energy management run at a time try: - await ems_eos.run( + await get_ems().run( mode=EnergyManagementMode.PREDICTION, force_update=True, ) @@ -1124,11 +1025,15 @@ async def fastapi_strompreis() -> list[float]: start_datetime = to_datetime().start_of("day") end_datetime = start_datetime.add(days=2) try: - elecprice = prediction_eos.key_to_array( - key="elecprice_marketprice_wh", - start_datetime=start_datetime, - end_datetime=end_datetime, - ).tolist() + elecprice = ( + get_prediction() + 
.key_to_array( + key="elecprice_marketprice_wh", + start_datetime=start_datetime, + end_datetime=end_datetime, + ) + .tolist() + ) except Exception as e: raise HTTPException( status_code=404, @@ -1178,12 +1083,12 @@ async def fastapi_gesamtlast(request: GesamtlastRequest) -> list[float]: "load_emr_keys": ["gesamtlast_emr"], }, } - config_eos.merge_settings_from_dict(settings) + get_config().merge_settings_from_dict(settings) # Insert measured data into EOS measurement # Convert from energy per interval to dummy energy meter readings measurement_key = "gesamtlast_emr" - measurement_eos.key_delete_by_datetime( + get_measurement().key_delete_by_datetime( key=measurement_key ) # delete all gesamtlast_emr measurements energy = {} @@ -1210,11 +1115,11 @@ async def fastapi_gesamtlast(request: GesamtlastRequest) -> list[float]: energy_mr_values.append(0.0) energy_mr_dates.append(dt) energy_mr_values.append(energy_mr) - measurement_eos.key_from_lists(measurement_key, energy_mr_dates, energy_mr_values) + get_measurement().key_from_lists(measurement_key, energy_mr_dates, energy_mr_values) # Ensure there is only one optimization/ energy management run at a time try: - await ems_eos.run( + await get_ems().run( mode=EnergyManagementMode.PREDICTION, force_update=True, ) @@ -1228,11 +1133,15 @@ async def fastapi_gesamtlast(request: GesamtlastRequest) -> list[float]: start_datetime = to_datetime().start_of("day") end_datetime = start_datetime.add(days=2) try: - prediction_list = prediction_eos.key_to_array( - key="loadforecast_power_w", - start_datetime=start_datetime, - end_datetime=end_datetime, - ).tolist() + prediction_list = ( + get_prediction() + .key_to_array( + key="loadforecast_power_w", + start_datetime=start_datetime, + end_datetime=end_datetime, + ) + .tolist() + ) except Exception as e: raise HTTPException( status_code=404, @@ -1271,11 +1180,11 @@ async def fastapi_gesamtlast_simple(year_energy: float) -> list[float]: ), ) ) - 
config_eos.merge_settings(settings=settings) + get_config().merge_settings(settings=settings) # Ensure there is only one optimization/ energy management run at a time try: - await ems_eos.run( + await get_ems().run( mode=EnergyManagementMode.PREDICTION, force_update=True, ) @@ -1289,11 +1198,15 @@ async def fastapi_gesamtlast_simple(year_energy: float) -> list[float]: start_datetime = to_datetime().start_of("day") end_datetime = start_datetime.add(days=2) try: - prediction_list = prediction_eos.key_to_array( - key="loadforecast_power_w", - start_datetime=start_datetime, - end_datetime=end_datetime, - ).tolist() + prediction_list = ( + get_prediction() + .key_to_array( + key="loadforecast_power_w", + start_datetime=start_datetime, + end_datetime=end_datetime, + ) + .tolist() + ) except Exception as e: raise HTTPException( status_code=404, @@ -1326,11 +1239,11 @@ async def fastapi_pvforecast() -> ForecastResponse: '/v1/prediction/list?key=pvforecastakkudoktor_temp_air' instead. """ settings = SettingsEOS(pvforecast=PVForecastCommonSettings(provider="PVForecastAkkudoktor")) - config_eos.merge_settings(settings=settings) + get_config().merge_settings(settings=settings) # Ensure there is only one optimization/ energy management run at a time try: - await ems_eos.run( + await get_ems().run( mode=EnergyManagementMode.PREDICTION, force_update=True, ) @@ -1344,16 +1257,24 @@ async def fastapi_pvforecast() -> ForecastResponse: start_datetime = to_datetime().start_of("day") end_datetime = start_datetime.add(days=2) try: - ac_power = prediction_eos.key_to_array( - key="pvforecast_ac_power", - start_datetime=start_datetime, - end_datetime=end_datetime, - ).tolist() - temp_air = prediction_eos.key_to_array( - key="pvforecastakkudoktor_temp_air", - start_datetime=start_datetime, - end_datetime=end_datetime, - ).tolist() + ac_power = ( + get_prediction() + .key_to_array( + key="pvforecast_ac_power", + start_datetime=start_datetime, + end_datetime=end_datetime, + ) + .tolist() + ) 
+ temp_air = ( + get_prediction() + .key_to_array( + key="pvforecastakkudoktor_temp_air", + start_datetime=start_datetime, + end_datetime=end_datetime, + ) + .tolist() + ) except Exception as e: raise HTTPException( status_code=404, @@ -1388,7 +1309,7 @@ async def fastapi_optimize( # Ensure there is only one optimization/ energy management run at a time try: - await ems_eos.run( + await get_ems().run( start_datetime=start_datetime, mode=EnergyManagementMode.OPTIMIZATION, genetic_parameters=parameters, @@ -1397,7 +1318,7 @@ async def fastapi_optimize( except Exception as e: raise HTTPException(status_code=400, detail=f"Optimize error: {e}.") - solution = ems_eos.genetic_solution() + solution = get_ems().genetic_solution() if solution is None: raise HTTPException(status_code=400, detail="Optimize error: no solution stored by run.") @@ -1407,7 +1328,7 @@ async def fastapi_optimize( @app.get("/visualization_results.pdf", response_class=PdfResponse, tags=["optimize"]) def get_pdf() -> PdfResponse: # Endpoint to serve the generated PDF with visualization results - output_path = config_eos.general.data_output_path + output_path = get_config().general.data_output_path if output_path is None or not output_path.is_dir(): raise HTTPException(status_code=404, detail=f"Output path does not exist: {output_path}.") file_path = output_path / "visualization_results.pdf" @@ -1447,11 +1368,11 @@ async def redirect_put(request: Request, path: str) -> Response: def redirect(request: Request, path: str) -> Union[HTMLResponse, RedirectResponse]: # Path is not for EOSdash if not (path.startswith("eosdash") or path == ""): - host = config_eos.server.eosdash_host + host = get_config().server.eosdash_host if host is None: - host = config_eos.server.host + host = get_config().server.host host = str(host) - port = config_eos.server.eosdash_port + port = get_config().server.eosdash_port if port is None: port = 8504 if host == "0.0.0.0": # noqa: S104 @@ -1470,13 +1391,13 @@ Did you want to 
connect to EOSdash? ) return HTMLResponse(content=error_page, status_code=404) - host = str(config_eos.server.eosdash_host) + host = str(get_config().server.eosdash_host) if host == "0.0.0.0": # noqa: S104 # Use IP of EOS host host = get_host_ip() - if host and config_eos.server.eosdash_port: + if host and get_config().server.eosdash_port: # Redirect to EOSdash server - url = f"http://{host}:{config_eos.server.eosdash_port}/{path}" + url = f"http://{host}:{get_config().server.eosdash_port}/{path}" return RedirectResponse(url=url, status_code=303) # Redirect the root URL to the site map @@ -1492,8 +1413,35 @@ def run_eos() -> None: Returns: None """ + # get_config(init=True) creates the configuration + # this should not be done before nor later + config_eos = get_config(init=True) + + # set logging to what is in config + logger.remove() + logging_track_config(config_eos, "logging", None, None) + + # make logger track logging changes in config + config_eos.track_nested_value("/logging", logging_track_config) + + # Set config to actual environment variable & config file content + config_eos.reset_settings() + + # add arguments to config + args: argparse.Namespace + args_unknown: list[str] + args, args_unknown = cli_parse_args() + cli_apply_args_to_config(args) + + # prepare runtime arguments if args: run_as_user = args.run_as_user + # Setup EOS reload for development + if args.reload is None: + reload = False + else: + logger.debug(f"reload set by argument to {args.reload}") + reload = args.reload else: run_as_user = None @@ -1503,7 +1451,13 @@ def run_eos() -> None: # Switch privileges to run_as_user drop_root_privileges(run_as_user=run_as_user) + # Init the other singletons (besides config_eos) + singletons_init() + # Wait for EOS port to be free - e.g. 
in case of restart + port = config_eos.server.port + if port is None: + port = 8503 wait_for_port_free(port, timeout=120, waiting_app_name="EOS") try: @@ -1511,7 +1465,7 @@ def run_eos() -> None: uvicorn.run( "akkudoktoreos.server.eos:app", host=str(config_eos.server.host), - port=config_eos.server.port, + port=port, log_level="info", # Fix log level for uvicorn to info access_log=True, # Fix server access logging to True reload=reload, diff --git a/src/akkudoktoreos/server/eosdash.py b/src/akkudoktoreos/server/eosdash.py index 3d99d5f..e2f1b72 100644 --- a/src/akkudoktoreos/server/eosdash.py +++ b/src/akkudoktoreos/server/eosdash.py @@ -12,7 +12,7 @@ from monsterui.core import FastHTML, Theme from starlette.middleware import Middleware from starlette.requests import Request -from akkudoktoreos.config.config import get_config +from akkudoktoreos.core.coreabc import get_config from akkudoktoreos.core.logabc import LOGGING_LEVELS from akkudoktoreos.core.logging import logging_track_config from akkudoktoreos.core.version import __version__ @@ -39,7 +39,7 @@ from akkudoktoreos.server.server import ( ) from akkudoktoreos.utils.stringutil import str2bool -config_eos = get_config() +config_eos = get_config(init=True) # ------------------------------------ diff --git a/src/akkudoktoreos/server/rest/cli.py b/src/akkudoktoreos/server/rest/cli.py new file mode 100644 index 0000000..93d0f03 --- /dev/null +++ b/src/akkudoktoreos/server/rest/cli.py @@ -0,0 +1,149 @@ +import argparse + +from loguru import logger + +from akkudoktoreos.core.coreabc import get_config +from akkudoktoreos.core.logabc import LOGGING_LEVELS +from akkudoktoreos.server.server import get_default_host +from akkudoktoreos.utils.stringutil import str2bool + + +def cli_argument_parser() -> argparse.ArgumentParser: + """Build argument parser for EOS cli.""" + parser = argparse.ArgumentParser(description="Start EOS server.") + + parser.add_argument( + "--host", + type=str, + help="Host for the EOS server 
(default: value from config)", + ) + parser.add_argument( + "--port", + type=int, + help="Port for the EOS server (default: value from config)", + ) + parser.add_argument( + "--log_level", + type=str, + default="none", + help='Log level for the server console. Options: "critical", "error", "warning", "info", "debug", "trace" (default: "none")', + ) + parser.add_argument( + "--reload", + type=str2bool, + default=False, + help="Enable or disable auto-reload. Useful for development. Options: True or False (default: False)", + ) + parser.add_argument( + "--startup_eosdash", + type=str2bool, + default=None, + help="Enable or disable automatic EOSdash startup. Options: True or False (default: value from config)", + ) + parser.add_argument( + "--run_as_user", + type=str, + help="The unprivileged user account the EOS server shall switch to after performing root-level startup tasks.", + ) + return parser + + +def cli_parse_args( + argv: list[str] | None = None, +) -> tuple[argparse.Namespace, list[str]]: + """Parse command-line arguments for the EOS CLI. + + This function parses known EOS-specific command-line arguments and + returns any remaining unknown arguments unmodified. Unknown arguments + can be forwarded to other subsystems (e.g. Uvicorn). + + If ``argv`` is ``None``, arguments are read from ``sys.argv[1:]``. + If ``argv`` is provided, it is used instead. + + Args: + argv: Optional list of command-line arguments to parse. If omitted, + the arguments are taken from ``sys.argv[1:]``. + + Returns: + A tuple containing: + - A namespace with parsed EOS CLI arguments. + - A list of unparsed (unknown) command-line arguments. + """ + args, args_unknown = cli_argument_parser().parse_known_args(argv) + return args, args_unknown + + +def cli_apply_args_to_config(args: argparse.Namespace) -> None: + """Apply parsed CLI arguments to the EOS configuration. + + This function updates the EOS configuration with values provided via + the command line. 
For each parameter, the precedence is: + + CLI argument > existing config value > default value + + Currently handled arguments: + + - log_level: Updates "logging/console_level" in config. + - host: Updates "server/host" in config. + - port: Updates "server/port" in config. + - startup_eosdash: Updates "server/startup_eosdash" in config. + - eosdash_host/port: Initialized if EOSdash is enabled and not already set. + + Args: + args: Parsed command-line arguments from argparse. + """ + config_eos = get_config() + + # Setup parameters from args, config_eos and default + # Remember parameters in config + + # Setup EOS logging level - first to have the other logging messages logged + if args.log_level is not None: + log_level = args.log_level.upper() + # Ensure log_level from command line is in config settings + if log_level in LOGGING_LEVELS: + # Setup console logging level using nested value + # - triggers logging configuration by logging_track_config + config_eos.set_nested_value("logging/console_level", log_level) + logger.debug(f"logging/console_level configuration set by argument to {log_level}") + + # Setup EOS server host + if args.host: + host = args.host + logger.debug(f"server/host configuration set by argument to {host}") + elif config_eos.server.host: + host = config_eos.server.host + else: + host = get_default_host() + # Ensure host from command line is in config settings + config_eos.set_nested_value("server/host", host) + + # Setup EOS server port + if args.port: + port = args.port + logger.debug(f"server/port configuration set by argument to {port}") + elif config_eos.server.port: + port = config_eos.server.port + else: + port = 8503 + # Ensure port from command line is in config settings + config_eos.set_nested_value("server/port", port) + + # Setup EOSdash startup + if args.startup_eosdash is not None: + # Ensure startup_eosdash from command line is in config settings + config_eos.set_nested_value("server/startup_eosdash", args.startup_eosdash) + 
logger.debug( + f"server/startup_eosdash configuration set by argument to {args.startup_eosdash}" + ) + + if config_eos.server.startup_eosdash: + # Ensure EOSdash host and port config settings are at least set to default values + + # Setup EOS server host + if config_eos.server.eosdash_host is None: + config_eos.set_nested_value("server/eosdash_host", host) + + # Setup EOS server host + if config_eos.server.eosdash_port is None: + config_eos.set_nested_value("server/eosdash_port", port + 1) diff --git a/src/akkudoktoreos/server/rest/starteosdash.py b/src/akkudoktoreos/server/rest/starteosdash.py index fbb1fe7..f5d9d88 100644 --- a/src/akkudoktoreos/server/rest/starteosdash.py +++ b/src/akkudoktoreos/server/rest/starteosdash.py @@ -8,14 +8,12 @@ from typing import Any, MutableMapping from loguru import logger -from akkudoktoreos.config.config import get_config +from akkudoktoreos.core.coreabc import get_config from akkudoktoreos.server.server import ( validate_ip_or_hostname, wait_for_port_free, ) -config_eos = get_config() - # Loguru to HA stdout logger.add(sys.stdout, format="{time} | {level} | {message}", enqueue=True) @@ -277,14 +275,18 @@ async def forward_stream(stream: asyncio.StreamReader, prefix: str = "") -> None _emit_drop_warning() +# Path to eosdash +eosdash_path = Path(__file__).parent.resolve().joinpath("eosdash.py") + + async def run_eosdash_supervisor() -> None: """Starts EOSdash, pipes its logs, restarts it if it crashes. Runs forever. 
""" - global eosdash_log_queue + global eosdash_log_queue, eosdash_path - eosdash_path = Path(__file__).parent.resolve().joinpath("eosdash.py") + config_eos = get_config() while True: await asyncio.sleep(5) diff --git a/src/akkudoktoreos/server/rest/tasks.py b/src/akkudoktoreos/server/rest/tasks.py index 1cd4cb1..1f03594 100644 --- a/src/akkudoktoreos/server/rest/tasks.py +++ b/src/akkudoktoreos/server/rest/tasks.py @@ -90,3 +90,73 @@ def repeat_every( return wrapped return decorator + + +def make_repeated_task( + func: NoArgsNoReturnAnyFuncT, + *, + seconds: float, + wait_first: float | None = None, + max_repetitions: int | None = None, + on_complete: NoArgsNoReturnAnyFuncT | None = None, + on_exception: ExcArgNoReturnAnyFuncT | None = None, +) -> NoArgsNoReturnAsyncFuncT: + """Create a version of the given function that runs periodically. + + This function wraps `func` with the `repeat_every` decorator at runtime, + allowing decorator parameters to be determined dynamically rather than at import time. + + Args: + func (Callable[[], None] | Callable[[], Coroutine[Any, Any, None]]): + The function to execute periodically. Must accept no arguments. + seconds (float): + Interval in seconds between repeated calls. + wait_first (float | None, optional): + If provided, the function will wait this many seconds before the first call. + max_repetitions (int | None, optional): + Maximum number of times to repeat the function. If None, repeats indefinitely. + on_complete (Callable[[], None] | Callable[[], Coroutine[Any, Any, None]] | None, optional): + Function to call once the repetitions are complete. + on_exception (Callable[[Exception], None] | Callable[[Exception], Coroutine[Any, Any, None]] | None, optional): + Function to call if an exception is raised by `func`. + + Returns: + Callable[[], Coroutine[Any, Any, None]]: + An async function that starts the periodic execution when called. + + Usage: + .. 
code-block:: python + + from my_task import my_task + + from akkudoktoreos.core.coreabc import get_config + from akkudoktoreos.server.rest.tasks import make_repeated_task + + config = get_config() + + # Create a periodic task using configuration-dependent interval + repeated_task = make_repeated_task( + my_task, + seconds=config.server.poll_interval, + wait_first=5, + max_repetitions=None + ) + + # Run the task in the event loop + import asyncio + asyncio.run(repeated_task()) + + + Notes: + - This pattern avoids starting the loop at import time. + - Arguments such as `seconds` can be read from runtime sources (config, CLI args, environment variables). + - The returned function must be awaited to start the periodic loop. + """ + # Return decorated function + return repeat_every( + seconds=seconds, + wait_first=wait_first, + max_repetitions=max_repetitions, + on_complete=on_complete, + on_exception=on_exception, + )(func) diff --git a/src/akkudoktoreos/server/retentionmanager.py b/src/akkudoktoreos/server/retentionmanager.py new file mode 100644 index 0000000..4cf647d --- /dev/null +++ b/src/akkudoktoreos/server/retentionmanager.py @@ -0,0 +1,390 @@ +"""Retention Manager for Akkudoktor-EOS server. + +This module provides a single long-running background task that owns the scheduling of all periodic +server-maintenance jobs (cache cleanup, DB autosave, config reload, …). + +Responsibilities: + - Run a fast "heartbeat" loop (default 5 s) — the *compaction tick*. + - Maintain a registry of ``ManagedJob`` entries, each with its own interval. + - Re-read the live configuration on every tick so interval changes take effect + immediately without a server restart. + - Track per-job state: last run time, last duration, last error, run count. + - Expose that state for health-check / metrics endpoints. 
+ +Example: + Typical usage inside your FastAPI lifespan:: + + from akkudoktoreos.core.coreabc import get_config + from akkudoktoreos.server.rest.retention_manager import RetentionManager + from akkudoktoreos.server.rest.tasks import make_repeated_task + + manager = RetentionManager(get_config().get_nested_value) + manager.register("cache_cleanup", cache_cleanup_fn, interval_attr="server/cache_cleanup_interval") + manager.register("db_autosave", db_autosave_fn, interval_attr="server/db_autosave_interval") + + @asynccontextmanager + async def lifespan(app: FastAPI): + tick_task = make_repeated_task(manager.tick, seconds=5, wait_first=2) + await tick_task() + yield +""" + +from __future__ import annotations + +import asyncio +import time +from dataclasses import dataclass +from typing import Any, Callable, Coroutine, Optional, Union + +from loguru import logger +from starlette.concurrency import run_in_threadpool + +NoArgsNoReturnAnyFuncT = Union[Callable[[], None], Callable[[], Coroutine[Any, Any, None]]] +ExcArgNoReturnAnyFuncT = Union[ + Callable[[Exception], None], Callable[[Exception], Coroutine[Any, Any, None]] +] +ConfigGetterFuncT = Callable[[str], Any] + + +# --------------------------------------------------------------------------- +# Job state — one per registered maintenance task +# --------------------------------------------------------------------------- + + +@dataclass +class JobState: + """Runtime state tracked for a single managed job. + + Attributes: + name: Unique human-readable job name used in logs and metrics. + func: The maintenance callable. Must accept no arguments. + interval_attr: Key passed to ``config_getter`` to retrieve the interval in seconds + for this job. + fallback_interval: Interval in seconds used when the key is not found or returns zero. + config_getter: Callable that accepts a string key and returns the corresponding + configuration value. Invoked with ``interval_attr`` to obtain the interval + in seconds. 
+ on_exception: Optional callable invoked with the raised exception whenever + ``func`` fails. May be sync or async. + last_run_at: Monotonic timestamp of the last completed run; ``0.0`` means never run. + last_duration: How long the last run took, in seconds. + last_error: String representation of the last exception, or ``None`` if the last run succeeded. + run_count: Total number of completed runs (successful or not). + is_running: ``True`` while the job coroutine is currently executing. + """ + + name: str + func: NoArgsNoReturnAnyFuncT + interval_attr: str # key passed to config_getter to obtain the interval in seconds + fallback_interval: float # used when the key is not found or returns zero + config_getter: ConfigGetterFuncT # callable(key: str) -> Any; returns interval in seconds + on_exception: Optional[ExcArgNoReturnAnyFuncT] = None # optional cleanup/alerting hook + + # mutable state + last_run_at: float = 0.0 # monotonic timestamp; 0.0 means "never run" + last_duration: float = 0.0 # seconds the job took + last_error: Optional[str] = None + run_count: int = 0 + is_running: bool = False + + def interval(self) -> Optional[float]: + """Retrieve the current interval by calling ``config_getter`` with ``interval_attr``. + + Returns ``None`` when the config value is ``None``, which signals that the + job is disabled and must never fire. Falls back to ``fallback_interval`` + when the key is not found. + + Returns: + The interval in seconds, or ``None`` if the job is disabled. + """ + try: + value = self.config_getter(self.interval_attr) + if value is None: + return None + return float(value) if value else self.fallback_interval + except (KeyError, IndexError): + logger.warning( + "RetentionManager: config key '{}' not found, using fallback {}s", + self.interval_attr, + self.fallback_interval, + ) + return self.fallback_interval + + def is_due(self) -> bool: + """Check whether enough time has elapsed since the last run to execute this job again. 
+ + Returns ``False`` immediately when `interval` returns ``None`` + (job is disabled), so a disabled job never fires regardless of when it + last ran. + + Returns: + ``True`` if the job should be executed on this tick, ``False`` otherwise. + """ + interval = self.interval() + if interval is None: + return False + return (time.monotonic() - self.last_run_at) >= interval + + def summary(self) -> dict: + """Build a serialisable snapshot of the job's current state. + + Returns: + A dictionary suitable for JSON serialisation, containing the job name, + interval key, last run timestamp, last duration, last error, + run count, and whether the job is currently running. + """ + return { + "name": self.name, + "interval_attr": self.interval_attr, + "interval_s": self.interval(), + "last_run_at": self.last_run_at, + "last_duration_s": round(self.last_duration, 4), + "last_error": self.last_error, + "run_count": self.run_count, + "is_running": self.is_running, + } + + +# --------------------------------------------------------------------------- +# Retention Manager +# --------------------------------------------------------------------------- + + +class RetentionManager: + """Orchestrates all periodic server-maintenance jobs. + + The manager itself is driven by an external ``make_repeated_task`` heartbeat + (the *compaction tick*). A ``config_getter`` callable — accepting a string key + and returning the corresponding value — is supplied at initialisation and + stored on every registered job, keeping the manager decoupled from any + specific config implementation. + + Jobs are launched as independent ``asyncio.Task`` objects so they run + concurrently without blocking the tick. Call `shutdown` during + application teardown to wait for any in-flight tasks to complete before + the event loop closes. A configurable shutdown_timeout prevents the + wait from blocking indefinitely; jobs still running after the timeout are + reported by name but not cancelled. 
+ """ + + def __init__( + self, + config_getter: ConfigGetterFuncT, + *, + shutdown_timeout: float = 30.0, + ) -> None: + """Initialise the manager with a configuration accessor. + + Args: + config_getter: Callable that accepts a string key and returns the + corresponding configuration value. Used by each registered job + to look up its interval in seconds. + shutdown_timeout: Maximum number of seconds to wait for in-flight + jobs to finish during `shutdown`. If the timeout elapses + before all tasks complete, an error is logged and the names of + the still-running jobs are reported. The tasks are not cancelled + so they may continue running until the event loop closes. + Defaults to 30.0. + + Example:: + + manager = RetentionManager(get_config().get_nested_value, shutdown_timeout=60.0) + """ + self._config_getter = config_getter + self._shutdown_timeout = shutdown_timeout + self._jobs: dict[str, JobState] = {} + self._running_tasks: set[asyncio.Task] = set() + + # ------------------------------------------------------------------ + # Registration + # ------------------------------------------------------------------ + + def register( + self, + name: str, + func: NoArgsNoReturnAnyFuncT, + *, + interval_attr: str, + fallback_interval: float = 300.0, + on_exception: Optional[ExcArgNoReturnAnyFuncT] = None, + ) -> None: + """Register a maintenance function with the manager. + + Args: + name: Unique human-readable job name used in logs and metrics. + func: The maintenance callable. Must accept no arguments. + interval_attr: Key passed to ``config_getter`` to retrieve the interval + in seconds for this job. When the config value is ``None`` the job + is treated as disabled and will never fire. + fallback_interval: Seconds to use when the config attribute is missing or zero. + Defaults to ``300.0``. + on_exception: Optional callable invoked with the raised exception whenever + ``func`` fails. Useful for cleanup or alerting. May be sync or async. 
+ + Raises: + ValueError: If a job with the given ``name`` is already registered. + """ + if name in self._jobs: + raise ValueError(f"RetentionManager: job '{name}' is already registered") + + self._jobs[name] = JobState( + name=name, + func=func, + interval_attr=interval_attr, + fallback_interval=fallback_interval, + config_getter=self._config_getter, + on_exception=on_exception, + ) + logger.info("RetentionManager: registered job '{}' (config: {})", name, interval_attr) + + def unregister(self, name: str) -> None: + """Remove a previously registered job from the manager. + + If no job with the given name exists, this is a no-op. + + Args: + name: The name of the job to remove. + """ + self._jobs.pop(name, None) + + # ------------------------------------------------------------------ + # Tick — called by the external heartbeat loop + # ------------------------------------------------------------------ + + async def tick(self) -> None: + """Single compaction tick: check every job and fire those that are due. + + Each job resolves its own interval via the ``config_getter`` captured at + registration time. Jobs whose interval is ``None`` are silently skipped + (disabled). Due jobs are launched as independent ``asyncio.Task`` objects + so they run concurrently without blocking the tick. Each task is tracked + in ``_running_tasks`` and removed automatically on completion, allowing + `shutdown` to await all of them gracefully. + + Jobs that are still running from a previous tick are skipped to prevent + overlapping executions. + + Note: + This is the function you pass to ``make_repeated_task``. 
+ """ + due = [job for job in self._jobs.values() if not job.is_running and job.is_due()] + + if not due: + return + + logger.debug("RetentionManager: {} job(s) due this tick", len(due)) + for job in due: + task = asyncio.ensure_future(self._run_job(job)) + task.set_name(job.name) # used by shutdown() to report timed-out jobs by name + self._running_tasks.add(task) + task.add_done_callback(self._running_tasks.discard) + + async def shutdown(self) -> None: + """Wait for all currently running job tasks to complete. + + Waits up to shutdown_timeout seconds (configured at initialisation) + for in-flight tasks to finish. If the timeout elapses before all tasks + complete, an error is logged listing the names of the jobs that are still + running. Those tasks are **not** cancelled — they continue until the event + loop closes — but `shutdown` returns so that application teardown + is not blocked indefinitely. + + Returns immediately if no tasks are running. + + Example:: + + @asynccontextmanager + async def lifespan(app: FastAPI): + tick_task = make_repeated_task(manager.tick, seconds=5, wait_first=2) + await tick_task() + + Yield: + await manager.shutdown() + """ + if not self._running_tasks: + return + + logger.info( + "RetentionManager: shutdown — waiting up to {}s for {} task(s) to finish", + self._shutdown_timeout, + len(self._running_tasks), + ) + + done, pending = await asyncio.wait(self._running_tasks, timeout=self._shutdown_timeout) + + if pending: + # Task names were set to the job name when the task was created in tick(). 
+ pending_names = [t.get_name() for t in pending] + logger.error( + "RetentionManager: shutdown timed out after {}s — {} job(s) still running: {}", + self._shutdown_timeout, + len(pending), + pending_names, + ) + else: + logger.info("RetentionManager: all tasks finished, shutdown complete") + + self._running_tasks.clear() + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + async def _run_job(self, job: JobState) -> None: + """Execute a single job and update its state regardless of outcome. + + Handles both async and sync callables for both the main function and the + optional ``on_exception`` hook. Exceptions from ``func`` are caught, logged, + stored on the job, and forwarded to ``on_exception`` if provided, so a + failing job never disrupts other concurrent jobs or future ticks. + + Args: + job: The `JobState` instance to execute. + """ + job.is_running = True + start = time.monotonic() + logger.debug("RetentionManager: starting job '{}'", job.name) + try: + if asyncio.iscoroutinefunction(job.func): + await job.func() + else: + await run_in_threadpool(job.func) + + job.last_error = None + logger.debug( + "RetentionManager: job '{}' completed in {:.3f}s", + job.name, + time.monotonic() - start, + ) + + except Exception as exc: # noqa: BLE001 + job.last_error = str(exc) + logger.exception("RetentionManager: job '{}' raised an exception: {}", job.name, exc) + + if job.on_exception is not None: + if asyncio.iscoroutinefunction(job.on_exception): + await job.on_exception(exc) + else: + await run_in_threadpool(job.on_exception, exc) + + finally: + job.last_duration = time.monotonic() - start + job.last_run_at = time.monotonic() + job.run_count += 1 + job.is_running = False + + # ------------------------------------------------------------------ + # Observability + # ------------------------------------------------------------------ + + def status(self) 
-> list[dict]:
+        """Return a snapshot of every job's state for health or metrics endpoints.
+
+        Returns:
+            A list of dictionaries, one per registered job, each produced by
+            `JobState.summary`.
+        """
+        return [job.summary() for job in self._jobs.values()]
+
+    def __repr__(self) -> str:  # pragma: no cover
+        return f"<RetentionManager jobs={len(self._jobs)} running={len(self._running_tasks)}>"
diff --git a/src/akkudoktoreos/server/server.py b/src/akkudoktoreos/server/server.py
index bd304e2..ad6d835 100644
--- a/src/akkudoktoreos/server/server.py
+++ b/src/akkudoktoreos/server/server.py
@@ -14,6 +14,7 @@ from loguru import logger
 from pydantic import Field, field_validator
 
 from akkudoktoreos.config.configabc import SettingsBaseModel
+from akkudoktoreos.core.coreabc import get_config
 
 
 def get_default_host() -> str:
@@ -258,8 +259,6 @@ def fix_data_directories_permissions(run_as_user: Optional[str] = None) -> None:
         run_as_user (Optional[str]): The user who should own the data directories
             and files. Defaults to current one.
     """
-    from akkudoktoreos.config.config import get_config
-
     config_eos = get_config()
 
     base_dirs = [
diff --git a/src/akkudoktoreos/utils/datetimeutil.py b/src/akkudoktoreos/utils/datetimeutil.py
index d442b96..ab0cfc5 100644
--- a/src/akkudoktoreos/utils/datetimeutil.py
+++ b/src/akkudoktoreos/utils/datetimeutil.py
@@ -1868,6 +1868,28 @@ def to_duration(
     raise ValueError(error_msg)
 
 
+# Timezone names that are semantically identical to UTC and should be
+# canonicalized. Keys are lower-cased for case-insensitive matching.
+_UTC_ALIASES: dict[str, str] = { + "utc": "UTC", + "gmt": "UTC", + "z": "UTC", + "etc/utc": "UTC", + "etc/gmt": "UTC", + "etc/gmt+0": "UTC", + "etc/gmt-0": "UTC", + "etc/gmt0": "UTC", + "etc/greenwich": "UTC", + "etc/universal": "UTC", + "etc/zulu": "UTC", +} + + +def _canonicalize_tz_name(name: str) -> str: + """Return 'UTC' when *name* is a known UTC alias, otherwise return unchanged.""" + return _UTC_ALIASES.get(name.lower(), name) + + @overload def to_timezone( utc_offset: Optional[float] = None, @@ -1891,6 +1913,9 @@ def to_timezone( ) -> Union[Timezone, str]: """Determines the timezone either by UTC offset, geographic location, or local system timezone. + Timezone names that are semantically equivalent to UTC (e.g. ``GMT``, ``Z``, + ``Etc/GMT``) are canonicalized to ``"UTC"`` before returning. + By default, it returns a `Timezone` object representing the timezone. If `as_string` is set to `True`, the function returns the timezone name as a string instead. @@ -1925,7 +1950,15 @@ def to_timezone( if not -24 <= utc_offset <= 24: raise ValueError("UTC offset must be within the range -24 to +24 hours.") - # Convert UTC offset to an Etc/GMT-compatible format + # Offset of exactly 0 is plain UTC – no need for Etc/GMT+0 etc. + if utc_offset == 0: + if as_string: + return "UTC" + return pendulum.timezone("UTC") + + # Convert UTC offset to an Etc/GMT-compatible format. + # NOTE: Etc/GMT sign convention is *inverted* relative to the common + # expectation: Etc/GMT+5 means UTC-5. We therefore flip the sign. 
hours = int(utc_offset) minutes = int((abs(utc_offset) - abs(hours)) * 60) sign = "-" if utc_offset >= 0 else "+" @@ -1951,6 +1984,8 @@ def to_timezone( except Exception as e: raise ValueError(f"Error determining timezone for location {location}: {e}") from e + tz_name = _canonicalize_tz_name(tz_name) + if as_string: return tz_name return pendulum.timezone(tz_name) @@ -1958,7 +1993,9 @@ def to_timezone( # Fallback to local timezone local_tz = pendulum.local_timezone() if isinstance(local_tz, str): - local_tz = pendulum.timezone(local_tz) + local_tz = pendulum.timezone(_canonicalize_tz_name(local_tz)) + else: + local_tz = pendulum.timezone(_canonicalize_tz_name(local_tz.name)) if as_string: return local_tz.name return local_tz diff --git a/src/akkudoktoreos/utils/visualize.py b/src/akkudoktoreos/utils/visualize.py index 69929ab..b2b481f 100644 --- a/src/akkudoktoreos/utils/visualize.py +++ b/src/akkudoktoreos/utils/visualize.py @@ -11,8 +11,7 @@ import numpy as np import pendulum from matplotlib.backends.backend_pdf import PdfPages -from akkudoktoreos.core.coreabc import ConfigMixin -from akkudoktoreos.core.ems import get_ems +from akkudoktoreos.core.coreabc import ConfigMixin, get_ems from akkudoktoreos.optimization.genetic.genetic import GeneticOptimizationParameters from akkudoktoreos.utils.datetimeutil import DateTime, to_datetime diff --git a/tests/conftest.py b/tests/conftest.py index 972e967..c4730e1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -22,7 +22,8 @@ from _pytest.logging import LogCaptureFixture from loguru import logger from xprocess import ProcessStarter, XProcess -from akkudoktoreos.config.config import ConfigEOS, get_config +from akkudoktoreos.config.config import ConfigEOS +from akkudoktoreos.core.coreabc import get_config, get_prediction, singletons_init from akkudoktoreos.core.version import _version_hash, version from akkudoktoreos.server.server import get_default_host @@ -134,8 +135,6 @@ def is_ci() -> bool: @pytest.fixture def 
prediction_eos(): - from akkudoktoreos.prediction.prediction import get_prediction - return get_prediction() @@ -172,6 +171,37 @@ def cfg_non_existent(request): ) +# ------------------------------------ +# Provide pytest EOS config management +# ------------------------------------ + + +@pytest.fixture +def config_default_dirs(tmpdir): + """Fixture that provides a list of directories to be used as config dir.""" + tmp_user_home_dir = Path(tmpdir) + + # Default config directory from platform user config directory + config_default_dir_user = tmp_user_home_dir / "config" + + # Default config directory from current working directory + config_default_dir_cwd = tmp_user_home_dir / "cwd" + config_default_dir_cwd.mkdir() + + # Default config directory from default config file + config_default_dir_default = Path(__file__).parent.parent.joinpath("src/akkudoktoreos/data") + + # Default data directory from platform user data directory + data_default_dir_user = tmp_user_home_dir + + return ( + config_default_dir_user, + config_default_dir_cwd, + config_default_dir_default, + data_default_dir_user, + ) + + @pytest.fixture(autouse=True) def user_cwd(config_default_dirs): """Patch cwd provided by module pathlib.Path.cwd.""" @@ -203,64 +233,102 @@ def user_data_dir(config_default_dirs): @pytest.fixture -def config_eos( +def config_eos_factory( disable_debug_logging, user_config_dir, user_data_dir, user_cwd, config_default_dirs, monkeypatch, -) -> ConfigEOS: - """Fixture to reset EOS config to default values.""" - monkeypatch.setenv( - "EOS_CONFIG__DATA_CACHE_SUBPATH", str(config_default_dirs[-1] / "data/cache") - ) - monkeypatch.setenv( - "EOS_CONFIG__DATA_OUTPUT_SUBPATH", str(config_default_dirs[-1] / "data/output") - ) - config_file = config_default_dirs[0] / ConfigEOS.CONFIG_FILE_NAME - config_file_cwd = config_default_dirs[1] / ConfigEOS.CONFIG_FILE_NAME - assert not config_file.exists() - assert not config_file_cwd.exists() +): + """Factory fixture for creating a fully 
initialized ``ConfigEOS`` instance. - config_eos = get_config() - config_eos.reset_settings() - assert config_file == config_eos.general.config_file_path - assert config_file.exists() - assert not config_file_cwd.exists() + Returns a callable that creates a ``ConfigEOS`` singleton with a controlled + filesystem layout and environment variables. Allows tests to customize which + pydantic-settings sources are enabled (init, env, dotenv, file, secrets). - # Check user data directory pathes (config_default_dirs[-1] == data_default_dir_user) - assert config_default_dirs[-1] / "data" == config_eos.general.data_folder_path - assert config_default_dirs[-1] / "data/cache" == config_eos.cache.path() - assert config_default_dirs[-1] / "data/output" == config_eos.general.data_output_path - assert config_default_dirs[-1] / "data/output/eos.log" == config_eos.logging.file_path - return config_eos + The factory ensures: + - Required directories exist + - No pre-existing config files are present + - Settings are reloaded to respect test-specific configuration + - Dependent singletons are initialized + + The singleton instance is reset during fixture teardown. 
+ """ + def _create(init: dict[str, bool] | None = None) -> ConfigEOS: + init = init or { + "with_init_settings": True, + "with_env_settings": True, + "with_dotenv_settings": False, + "with_file_settings": False, + "with_file_secret_settings": False, + } + + # reset singleton before touching env or config + ConfigEOS.reset_instance() + ConfigEOS._init_config_eos = { + "with_init_settings": True, + "with_env_settings": True, + "with_dotenv_settings": True, + "with_file_settings": True, + "with_file_secret_settings": True, + } + ConfigEOS._config_file_path = None + ConfigEOS._force_documentation_mode = False + + data_folder_path = config_default_dirs[-1] / "data" + data_folder_path.mkdir(exist_ok=True) + + config_dir = config_default_dirs[0] + config_dir.mkdir(exist_ok=True) + + cwd = config_default_dirs[1] + cwd.mkdir(exist_ok=True) + + monkeypatch.setenv("EOS_CONFIG_DIR", str(config_dir)) + monkeypatch.setenv("EOS_GENERAL__DATA_FOLDER_PATH", str(data_folder_path)) + monkeypatch.setenv("EOS_GENERAL__DATA_CACHE_SUBPATH", "cache") + monkeypatch.setenv("EOS_GENERAL__DATA_OUTPUT_SUBPATH", "output") + + # Ensure no config files exist + config_file = config_dir / ConfigEOS.CONFIG_FILE_NAME + config_file_cwd = cwd / ConfigEOS.CONFIG_FILE_NAME + assert not config_file.exists() + assert not config_file_cwd.exists() + + config_eos = get_config(init=init) + # Ensure newly created configurations are respected + # Note: Workaround for pydantic_settings and pytest + config_eos.reset_settings() + + # Check user data directory pathes (config_default_dirs[-1] == data_default_dir_user) + assert config_eos.general.data_folder_path == data_folder_path + assert config_eos.general.data_output_subpath == Path("output") + assert config_eos.cache.subpath == "cache" + assert config_eos.cache.path() == config_default_dirs[-1] / "data/cache" + assert config_eos.logging.file_path == config_default_dirs[-1] / "data/output/eos.log" + + # Check config file path + assert 
str(config_eos.general.config_file_path) == str(config_file) + assert config_file.exists() + assert not config_file_cwd.exists() + + # Initialize all other singletons (if not already initialized) + singletons_init() + + return config_eos + + yield _create + + # teardown - final safety net + ConfigEOS.reset_instance() @pytest.fixture -def config_default_dirs(tmpdir): - """Fixture that provides a list of directories to be used as config dir.""" - tmp_user_home_dir = Path(tmpdir) - - # Default config directory from platform user config directory - config_default_dir_user = tmp_user_home_dir / "config" - - # Default config directory from current working directory - config_default_dir_cwd = tmp_user_home_dir / "cwd" - config_default_dir_cwd.mkdir() - - # Default config directory from default config file - config_default_dir_default = Path(__file__).parent.parent.joinpath("src/akkudoktoreos/data") - - # Default data directory from platform user data directory - data_default_dir_user = tmp_user_home_dir - - return ( - config_default_dir_user, - config_default_dir_cwd, - config_default_dir_default, - data_default_dir_user, - ) +def config_eos(config_eos_factory) -> ConfigEOS: + """Fixture to reset EOS config to default values.""" + config_eos = config_eos_factory() + return config_eos # ------------------------------------ @@ -405,7 +473,11 @@ def server_base( Yields: dict[str, str]: A dictionary containing: - "server" (str): URL of the server. + - "port": port + - "eosdash_server": eosdash_server + - "eosdash_port": eosdash_port - "eos_dir" (str): Path to the temporary EOS_DIR. 
+ - "timeout": server_timeout """ host = get_default_host() port = 8503 @@ -427,12 +499,14 @@ def server_base( eos_tmp_dir = tempfile.TemporaryDirectory() eos_dir = str(eos_tmp_dir.name) + eos_general_data_folder_path = str(Path(eos_dir) / "data") class Starter(ProcessStarter): # Set environment for server run env = os.environ.copy() env["EOS_DIR"] = eos_dir env["EOS_CONFIG_DIR"] = eos_dir + env["EOS_GENERAL__DATA_FOLDER_PATH"] = eos_general_data_folder_path if extra_env: env.update(extra_env) diff --git a/tests/single_test_optimization.py b/tests/single_test_optimization.py index dc4b493..76a22d9 100755 --- a/tests/single_test_optimization.py +++ b/tests/single_test_optimization.py @@ -12,13 +12,11 @@ from typing import Any import numpy as np from loguru import logger -from akkudoktoreos.config.config import get_config -from akkudoktoreos.core.ems import get_ems +from akkudoktoreos.core.coreabc import get_config, get_ems, get_prediction from akkudoktoreos.core.emsettings import EnergyManagementMode from akkudoktoreos.optimization.genetic.geneticparams import ( GeneticOptimizationParameters, ) -from akkudoktoreos.prediction.prediction import get_prediction from akkudoktoreos.utils.datetimeutil import to_datetime config_eos = get_config() diff --git a/tests/single_test_prediction.py b/tests/single_test_prediction.py index 41398c1..c08fae4 100644 --- a/tests/single_test_prediction.py +++ b/tests/single_test_prediction.py @@ -6,8 +6,7 @@ import pstats import sys import time -from akkudoktoreos.config.config import get_config -from akkudoktoreos.prediction.prediction import get_prediction +from akkudoktoreos.core.coreabc import get_config, get_prediction config_eos = get_config() prediction_eos = get_prediction() diff --git a/tests/test_adapter.py b/tests/test_adapter.py index 9d002f1..181ab80 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -12,11 +12,11 @@ import pytest from akkudoktoreos.adapter.adapter import ( Adapter, AdapterCommonSettings, - 
get_adapter, ) from akkudoktoreos.adapter.adapterabc import AdapterContainer from akkudoktoreos.adapter.homeassistant import HomeAssistantAdapter from akkudoktoreos.adapter.nodered import NodeREDAdapter +from akkudoktoreos.core.coreabc import get_adapter # ---------- Typed aliases for fixtures ---------- AdapterFixture: TypeAlias = Adapter diff --git a/tests/test_cache.py b/tests/test_cache.py index 3d6b96a..7683586 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -167,12 +167,21 @@ def temp_store_file(): @pytest.fixture -def cache_file_store(temp_store_file): +def cache_file_store(temp_store_file, monkeypatch): """A pytest fixture that creates a new CacheFileStore instance for testing.""" + cache = CacheFileStore() - cache._store_file = temp_store_file + + # Patch the _cache_file method to return the temp file + monkeypatch.setattr( + cache, + "_store_file", + lambda: temp_store_file, + ) + cache.clear(clear_all=True) assert len(cache._store) == 0 + return cache @@ -481,7 +490,7 @@ class TestCacheFileStore: cache_file_store.save_store() # Verify the file content - with cache_file_store._store_file.open("r", encoding="utf-8", newline=None) as f: + with cache_file_store._store_file().open("r", encoding="utf-8", newline=None) as f: store_loaded = json.load(f) assert "test_key" in store_loaded assert store_loaded["test_key"]["cache_file"] == "cache_file_path" @@ -501,7 +510,7 @@ class TestCacheFileStore: "ttl_duration": None, } } - with cache_file_store._store_file.open("w", encoding="utf-8", newline="\n") as f: + with cache_file_store._store_file().open("w", encoding="utf-8", newline="\n") as f: json.dump(cache_record, f, indent=4) # Mock the open function to return a MagicMock for the cache file diff --git a/tests/test_config.py b/tests/test_config.py index e591ab5..4f485b8 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -109,17 +109,18 @@ def test_config_ipaddress(monkeypatch, config_eos): assert config_eos.server.host == "localhost" 
-def test_singleton_behavior(config_eos, config_default_dirs): +def test_singleton_behavior(config_eos, config_default_dirs, monkeypatch): """Test that ConfigEOS behaves as a singleton.""" - initial_cfg_file = config_eos.general.config_file_path - with patch( - "akkudoktoreos.config.config.user_config_dir", return_value=str(config_default_dirs[0]) - ): - instance1 = ConfigEOS() - instance2 = ConfigEOS() - assert instance1 is config_eos + config_eos.reset_instance() + + monkeypatch.setenv("EOS_CONFIG_DIR", str(config_default_dirs[0])) + + instance1 = ConfigEOS() + instance2 = ConfigEOS() + + assert instance1 is not config_eos assert instance1 is instance2 - assert instance1.general.config_file_path == initial_cfg_file + assert instance1._config_file_path == instance2._config_file_path def test_config_file_priority(config_default_dirs): @@ -169,17 +170,22 @@ def test_get_config_file_path(user_config_dir_patch, config_eos, config_default_ with tempfile.TemporaryDirectory() as temp_dir: temp_dir_path = Path(temp_dir) monkeypatch.setenv("EOS_DIR", str(temp_dir_path)) + monkeypatch.delenv("EOS_CONFIG_DIR", raising=False) assert config_eos._get_config_file_path() == (cfg_file(temp_dir_path), False) monkeypatch.setenv("EOS_CONFIG_DIR", "config") + config_dir = temp_dir_path / "config" + config_dir.mkdir(exist_ok=True) assert config_eos._get_config_file_path() == ( - cfg_file(temp_dir_path / "config"), + cfg_file(config_dir), False, ) monkeypatch.setenv("EOS_CONFIG_DIR", str(temp_dir_path / "config2")) + config_dir = temp_dir_path / "config2" + config_dir.mkdir(exist_ok=True) assert config_eos._get_config_file_path() == ( - cfg_file(temp_dir_path / "config2"), + cfg_file(config_dir), False, ) @@ -188,8 +194,10 @@ def test_get_config_file_path(user_config_dir_patch, config_eos, config_default_ assert config_eos._get_config_file_path() == (cfg_file(config_default_dir_user), False) monkeypatch.setenv("EOS_CONFIG_DIR", str(temp_dir_path / "config3")) + config_dir = 
temp_dir_path / "config3" + config_dir.mkdir(exist_ok=True) assert config_eos._get_config_file_path() == ( - cfg_file(temp_dir_path / "config3"), + cfg_file(config_dir), False, ) @@ -199,7 +207,7 @@ def test_config_copy(config_eos, monkeypatch): with tempfile.TemporaryDirectory() as temp_dir: temp_folder_path = Path(temp_dir) temp_config_file_path = temp_folder_path.joinpath(config_eos.CONFIG_FILE_NAME).resolve() - monkeypatch.setenv(config_eos.EOS_DIR, str(temp_folder_path)) + monkeypatch.setenv("EOS_CONFIG_DIR", str(temp_folder_path)) assert not temp_config_file_path.exists() with patch("akkudoktoreos.config.config.user_config_dir", return_value=temp_dir): assert config_eos._get_config_file_path() == (temp_config_file_path, False) diff --git a/tests/test_configmigrate.py b/tests/test_configmigrate.py index cb04603..825f30c 100644 --- a/tests/test_configmigrate.py +++ b/tests/test_configmigrate.py @@ -26,6 +26,9 @@ MIGRATION_PAIRS = [ # (DIR_TESTDATA / "old_config_X.json", DIR_TESTDATA / "expected_config_X.json"), ] +# Any sentinel in expected data +_ANY_SENTINEL = "__ANY__" + def _dict_contains(superset: Any, subset: Any, path="") -> list[str]: """Recursively verify that all key-value pairs from a subset dictionary or list exist in a superset. 
@@ -60,6 +63,9 @@ def _dict_contains(superset: Any, subset: Any, path="") -> list[str]: errors.extend(_dict_contains(superset[i], elem, f"{path}[{i}]" if path else f"[{i}]")) else: + # "__ANY__" in expected means "accept whatever value the migration produces" + if subset == _ANY_SENTINEL: + return errors # Compare values (with numeric tolerance) if isinstance(subset, (int, float)) and isinstance(superset, (int, float)): if abs(float(subset) - float(superset)) > 1e-6: @@ -162,6 +168,7 @@ class TestConfigMigration: assert backup_file.exists(), f"Backup file not created for {old_file.name}" # --- Compare migrated result with expected output --- + old_data = json.loads(old_file.read_text(encoding="utf-8")) new_data = json.loads(working_file.read_text(encoding="utf-8")) expected_data = json.loads(expected_file.read_text(encoding="utf-8")) @@ -202,6 +209,14 @@ class TestConfigMigration: # Verify the migrated value matches the expected one new_value = configmigrate._get_json_nested_value(new_data, new_path) if new_value != expected_value: + # Check if this mapping uses _KEEP_DEFAULT and the old value was None/missing + old_value = configmigrate._get_json_nested_value(old_data, old_path) + keep_default = ( + isinstance(mapping, tuple) + and configmigrate._KEEP_DEFAULT in mapping + ) + if keep_default and old_value is None: + continue # acceptable: old was None, new model keeps its default mismatched_values.append( f"{old_path} → {new_path}: expected {expected_value!r}, got {new_value!r}" ) diff --git a/tests/test_dataabc.py b/tests/test_dataabc.py index 96c9ec7..0f31dc1 100644 --- a/tests/test_dataabc.py +++ b/tests/test_dataabc.py @@ -1,3 +1,4 @@ +import json from datetime import datetime, timezone from typing import Any, ClassVar, List, Optional, Union @@ -8,15 +9,16 @@ import pytest from pydantic import Field, ValidationError from akkudoktoreos.config.configabc import SettingsBaseModel +from akkudoktoreos.core.coreabc import get_ems from akkudoktoreos.core.dataabc 
import ( - DataBase, + DataABC, DataContainer, DataImportProvider, DataProvider, DataRecord, DataSequence, ) -from akkudoktoreos.core.ems import get_ems +from akkudoktoreos.core.databaseabc import DatabaseTimestamp from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime, to_duration # Derived classes for testing @@ -28,7 +30,7 @@ class DerivedConfig(SettingsBaseModel): class_constant: Optional[int] = Field(default=None, description="Test config by class constant") -class DerivedBase(DataBase): +class DerivedBase(DataABC): instance_field: Optional[str] = Field(default=None, description="Field Value") class_constant: ClassVar[int] = 30 @@ -58,6 +60,15 @@ class DerivedSequence(DataSequence): def record_class(cls) -> Any: return DerivedRecord +class DerivedSequence2(DataSequence): + # overload + records: List[DerivedRecord] = Field( + default_factory=list, description="List of DerivedRecord records" + ) + + @classmethod + def record_class(cls) -> Any: + return DerivedRecord class DerivedDataProvider(DataProvider): """A concrete subclass of DataProvider for testing purposes.""" @@ -121,7 +132,7 @@ class DerivedDataContainer(DataContainer): # ---------- -class TestDataBase: +class TestDataABC: @pytest.fixture def base(self): # Provide default values for configuration @@ -141,7 +152,7 @@ class TestDataRecord: @pytest.fixture def record(self): """Fixture to create a sample DerivedDataRecord with some data set.""" - rec = DerivedRecord(date_time=None, data_value=10.0) + rec = DerivedRecord(date_time=to_datetime("1967-01-11"), data_value=10.0) rec.configured_data = {"dish_washer_emr": 123.0, "solar_power": 456.0} return rec @@ -393,8 +404,8 @@ class TestDataSequence: sequence = DerivedSequence() record1 = self.create_test_record(datetime(1970, 1, 1), 1970) record2 = self.create_test_record(datetime(1971, 1, 1), 1971) - sequence.append(record1) - sequence.append(record2) + sequence.insert_by_datetime(record1) + sequence.insert_by_datetime(record2) assert 
len(sequence) == 2 return sequence @@ -403,141 +414,118 @@ class TestDataSequence: return DerivedRecord(date_time=date, data_value=value) # Test cases + @pytest.mark.parametrize("tz_name", ["UTC", "Europe/Berlin", "Atlantic/Canary"]) + def test_min_max_datetime_timezone_and_order(self, sequence, tz_name, monkeypatch, config_eos): + # Monkeypatch the read-only timezone property + monkeypatch.setattr(config_eos.general.__class__, "timezone", property(lambda self: tz_name)) + + # Create timezone-aware datetimes using the patched config + dt_early = to_datetime("2024-01-01T00:00:00", in_timezone=config_eos.general.timezone) + dt_late = to_datetime("2024-01-02T00:00:00", in_timezone=config_eos.general.timezone) + + # Insert in reverse order to verify sorting + record1 = self.create_test_record(dt_late, 1) + record2 = self.create_test_record(dt_early, 2) + + sequence.insert_by_datetime(record1) + sequence.insert_by_datetime(record2) + + min_dt = sequence.min_datetime + max_dt = sequence.max_datetime + + # --- Basic correctness --- + assert min_dt == dt_early + assert max_dt == dt_late + + # --- Must be timezone aware --- + assert min_dt.tzinfo is not None + assert max_dt.tzinfo is not None + + # --- Must preserve timezone --- + assert min_dt.tzinfo.name == tz_name + assert max_dt.tzinfo.name == tz_name + + def test_getitem(self, sequence): assert len(sequence) == 0 - record = self.create_test_record("2024-01-01 00:00:00", 0) + dt = to_datetime("2024-01-01 00:00:00") + record = self.create_test_record(dt, 0) sequence.insert_by_datetime(record) - assert isinstance(sequence[0], DerivedRecord) + assert isinstance(sequence.get_by_datetime(dt), DerivedRecord) def test_setitem(self, sequence2): - new_record = self.create_test_record(datetime(2024, 1, 3, tzinfo=timezone.utc), 1) - sequence2[0] = new_record - assert sequence2[0].date_time == datetime(2024, 1, 3, tzinfo=timezone.utc) + dt = to_datetime("2024-01-03", in_timezone="UTC") + record = self.create_test_record(dt, 1) + 
sequence2.insert_by_datetime(record) + assert sequence2.records[2].date_time == dt - def test_set_record_at_index(self, sequence2): - record1 = self.create_test_record(datetime(2024, 1, 3, tzinfo=timezone.utc), 1) - record2 = self.create_test_record(datetime(2023, 11, 5), 0.8) - sequence2[1] = record1 - assert sequence2[1].date_time == datetime(2024, 1, 3, tzinfo=timezone.utc) - sequence2[0] = record2 - assert len(sequence2) == 2 - assert sequence2[0] == record2 + def test_insert_reversed_date_record(self, sequence2): + dt1 = to_datetime("2023-11-05", in_timezone="UTC") + dt2 = to_datetime("2024-01-03", in_timezone="UTC") + record1 = self.create_test_record(dt2, 0.8) + record2 = self.create_test_record(dt1, 0.9) # reversed date + sequence2.insert_by_datetime(record1) + assert sequence2.records[2].date_time == dt2 + sequence2.insert_by_datetime(record2) + assert len(sequence2) == 4 + assert sequence2.records[2] == record2 def test_insert_duplicate_date_record(self, sequence): - record1 = self.create_test_record(datetime(2023, 11, 5), 0.8) - record2 = self.create_test_record(datetime(2023, 11, 5), 0.9) # Duplicate date + dt1 = to_datetime("2023-11-05") + record1 = self.create_test_record(dt1, 0.8) + record2 = self.create_test_record(dt1, 0.9) # Duplicate date sequence.insert_by_datetime(record1) sequence.insert_by_datetime(record2) assert len(sequence) == 1 - assert sequence[0].data_value == 0.9 # Record should have merged with new value - - def test_sort_by_datetime_ascending(self, sequence): - """Test sorting records in ascending order by date_time.""" - records = [ - self.create_test_record(pendulum.datetime(2024, 11, 1), 0.7), - self.create_test_record(pendulum.datetime(2024, 10, 1), 0.8), - self.create_test_record(pendulum.datetime(2024, 12, 1), 0.9), - ] - for i, record in enumerate(records): - sequence.insert(i, record) - sequence.sort_by_datetime() - sorted_dates = [record.date_time for record in sequence.records] - for i, expected_date in enumerate( - [ - 
pendulum.datetime(2024, 10, 1), - pendulum.datetime(2024, 11, 1), - pendulum.datetime(2024, 12, 1), - ] - ): - assert compare_datetimes(sorted_dates[i], expected_date).equal - - def test_sort_by_datetime_descending(self, sequence): - """Test sorting records in descending order by date_time.""" - records = [ - self.create_test_record(pendulum.datetime(2024, 11, 1), 0.7), - self.create_test_record(pendulum.datetime(2024, 10, 1), 0.8), - self.create_test_record(pendulum.datetime(2024, 12, 1), 0.9), - ] - for i, record in enumerate(records): - sequence.insert(i, record) - sequence.sort_by_datetime(reverse=True) - sorted_dates = [record.date_time for record in sequence.records] - for i, expected_date in enumerate( - [ - pendulum.datetime(2024, 12, 1), - pendulum.datetime(2024, 11, 1), - pendulum.datetime(2024, 10, 1), - ] - ): - assert compare_datetimes(sorted_dates[i], expected_date).equal - - def test_sort_by_datetime_with_none(self, sequence): - """Test sorting records when some date_time values are None.""" - records = [ - self.create_test_record(pendulum.datetime(2024, 11, 1), 0.7), - self.create_test_record(pendulum.datetime(2024, 10, 1), 0.8), - self.create_test_record(pendulum.datetime(2024, 12, 1), 0.9), - ] - for i, record in enumerate(records): - sequence.insert(i, record) - sequence.records[2].date_time = None - assert sequence.records[2].date_time is None - sequence.sort_by_datetime() - sorted_dates = [record.date_time for record in sequence.records] - for i, expected_date in enumerate( - [ - None, # None values should come first - pendulum.datetime(2024, 10, 1), - pendulum.datetime(2024, 11, 1), - ] - ): - if expected_date is None: - assert sorted_dates[i] is None - else: - assert compare_datetimes(sorted_dates[i], expected_date).equal - - def test_sort_by_datetime_error_on_uncomparable(self, sequence): - """Test error is raised when date_time contains uncomparable values.""" - records = [ - self.create_test_record(pendulum.datetime(2024, 11, 1), 0.7), - 
self.create_test_record(pendulum.datetime(2024, 12, 1), 0.9), - self.create_test_record(pendulum.datetime(2024, 10, 1), 0.8), - ] - for i, record in enumerate(records): - sequence.insert(i, record) - with pytest.raises( - ValidationError, match="Date string not_a_datetime does not match any known formats." - ): - sequence.records[2].date_time = "not_a_datetime" # Invalid date_time - sequence.sort_by_datetime() + assert sequence.get_by_datetime(dt1).data_value == 0.9 # Record should have merged with new value def test_key_to_series(self, sequence): - record = self.create_test_record(datetime(2023, 11, 6), 0.8) - sequence.append(record) + dt = to_datetime(datetime(2023, 11, 6)) + record = self.create_test_record(dt, 0.8) + sequence.insert_by_datetime(record) series = sequence.key_to_series("data_value") assert isinstance(series, pd.Series) - assert series[to_datetime(datetime(2023, 11, 6))] == 0.8 + + retrieved_record = sequence.get_by_datetime(dt) + assert retrieved_record is not None + assert retrieved_record.data_value == 0.8 def test_key_from_series(self, sequence): + dt1 = to_datetime(datetime(2023, 11, 5)) + dt2 = to_datetime(datetime(2023, 11, 6)) + series = pd.Series( - data=[0.8, 0.9], index=pd.to_datetime([datetime(2023, 11, 5), datetime(2023, 11, 6)]) + data=[0.8, 0.9], index=pd.to_datetime([dt1, dt2]) ) sequence.key_from_series("data_value", series) assert len(sequence) == 2 - assert sequence[0].data_value == 0.8 - assert sequence[1].data_value == 0.9 + + record1 = sequence.get_by_datetime(dt1) + assert record1 is not None + assert record1.data_value == 0.8 + + record2 = sequence.get_by_datetime(dt2) + assert record2 is not None + assert record2.data_value == 0.9 def test_key_to_array(self, sequence): interval = to_duration("1 day") start_datetime = to_datetime("2023-11-6") last_datetime = to_datetime("2023-11-8") end_datetime = to_datetime("2023-11-9") - record = self.create_test_record(start_datetime, float(start_datetime.day)) - 
sequence.insert_by_datetime(record) - record = self.create_test_record(last_datetime, float(last_datetime.day)) - sequence.insert_by_datetime(record) - assert sequence[0].data_value == 6.0 - assert sequence[1].data_value == 8.0 + + record1 = self.create_test_record(start_datetime, float(start_datetime.day)) + sequence.insert_by_datetime(record1) + record2 = self.create_test_record(last_datetime, float(last_datetime.day)) + sequence.insert_by_datetime(record2) + + retrieved_record1 = sequence.get_by_datetime(start_datetime) + assert retrieved_record1 is not None + assert retrieved_record1.data_value == 6.0 + + retrieved_record2 = sequence.get_by_datetime(last_datetime) + assert retrieved_record2 is not None + assert retrieved_record2.data_value == 8.0 series = sequence.key_to_series( key="data_value", start_datetime=start_datetime, end_datetime=end_datetime @@ -553,10 +541,7 @@ class TestDataSequence: interval=interval, ) assert isinstance(array, np.ndarray) - assert len(array) == 3 - assert array[0] == start_datetime.day - assert array[1] == 7 - assert array[2] == last_datetime.day + np.testing.assert_equal(array, [6.0, 7.0, 8.0]) def test_key_to_array_linear_interpolation(self, sequence): """Test key_to_array with linear interpolation for numeric data.""" @@ -578,6 +563,44 @@ class TestDataSequence: assert array[1] == 0.9 # Interpolated value assert array[2] == 1.0 + + def test_key_to_array_linear_interpolation_out_of_grid(self, sequence): + """Test key_to_array with linear interpolation out of grid.""" + interval = to_duration("1 hour") + start_datetime= to_datetime("2023-11-06T00:30:00") # out of grid + end_datetime=to_datetime("2023-11-06T01:30:00") # out of grid + + record1_datetime = to_datetime("2023-11-06T00:00:00") + record1 = self.create_test_record(record1_datetime, 1.0) + + record2_datetime = to_datetime("2023-11-06T02:00:00") + record2 = self.create_test_record(record2_datetime, 2.0) # Gap of 2 hours + + sequence.insert_by_datetime(record1) + 
sequence.insert_by_datetime(record2) + + # Check test setup + record1_timestamp = DatabaseTimestamp.from_datetime(record1_datetime) + record2_timestamp = DatabaseTimestamp.from_datetime(record2_datetime) + start_timestamp = DatabaseTimestamp.from_datetime(start_datetime) + end_timestamp = DatabaseTimestamp.from_datetime(end_datetime) + + start_previous_timestamp = sequence.db_previous_timestamp(start_timestamp) + assert start_previous_timestamp == record1_timestamp + end_next_timestamp = sequence.db_next_timestamp(end_timestamp) + assert end_next_timestamp == record2_timestamp + + # Test + array = sequence.key_to_array( + key="data_value", + start_datetime=start_datetime, + end_datetime=end_datetime, + interval=interval, + fill_method="linear", + boundary="context", + ) + np.testing.assert_equal(array, [1.5]) + def test_key_to_array_ffill(self, sequence): """Test key_to_array with forward filling for missing values.""" interval = to_duration("1 hour") @@ -645,15 +668,19 @@ class TestDataSequence: sequence.insert_by_datetime(record1) sequence.insert_by_datetime(record2) + #assert sequence is None + array = sequence.key_to_array( key="data_value", - start_datetime=pendulum.datetime(2023, 11, 6), + start_datetime=pendulum.datetime(2023, 11, 5, 23), end_datetime=pendulum.datetime(2023, 11, 6, 2), interval=interval, ) - assert len(array) == 2 - assert array[0] == 0.9 # Interpolated from previous day - assert array[1] == 1.0 + + assert len(array) == 3 + assert array[0] == 0.8 + assert array[1] == 0.9 # Interpolated from previous day + assert array[2] == 1.0 def test_key_to_array_with_none(self, sequence): """Test handling of empty series in key_to_array.""" @@ -675,13 +702,14 @@ class TestDataSequence: array = sequence.key_to_array( key="data_value", - start_datetime=pendulum.datetime(2023, 11, 6), + start_datetime=pendulum.datetime(2023, 11, 5, 23), end_datetime=pendulum.datetime(2023, 11, 6, 2), interval=interval, ) - assert len(array) == 2 - assert array[0] == 0.8 # 
Interpolated from previous day - assert array[1] == 0.8 + assert len(array) == 3 + assert array[0] == 0.8 + assert array[1] == 0.8 # Interpolated from previous day + assert array[2] == 0.8 # Interpolated from previous day def test_key_to_array_invalid_fill_method(self, sequence): """Test invalid fill_method raises an error.""" @@ -725,71 +753,354 @@ class TestDataSequence: # The first interval mean = (1+2+3+4)/4 = 2.5 assert array[0] == pytest.approx(2.5) - def test_to_datetimeindex(self, sequence2): - record1 = self.create_test_record(datetime(2023, 11, 5), 0.8) - record2 = self.create_test_record(datetime(2023, 11, 6), 0.9) - sequence2.insert(0, record1) - sequence2.insert(1, record2) - dt_index = sequence2.to_datetimeindex() - assert isinstance(dt_index, pd.DatetimeIndex) - assert dt_index[0] == to_datetime(datetime(2023, 11, 5)) - assert dt_index[1] == to_datetime(datetime(2023, 11, 6)) + # ------------------------------------------------------------------ + # key_to_array — align_to_interval parameter + # ------------------------------------------------------------------ + # + # The existing tests above use start_datetime values that already sit on + # clean hour/day boundaries, so the default alignment (origin=query_start) + # and clock alignment (origin=epoch-floor) produce identical results. + # The tests below specifically use off-boundary start times to expose + # the difference and verify the new parameter. + + def test_key_to_array_align_false_origin_is_query_start(self, sequence): + """Without align_to_interval the first bucket sits at query_start, not a clock boundary. + + With start_datetime at 10:07:00 and 15-min interval the first resampled + bucket must be at 10:07:00 (origin = query_start), NOT at 10:00:00 or 10:15:00. 
+ """ + # Off-boundary start: 10:07 + start_dt = pendulum.datetime(2024, 6, 1, 10, 7, tz="UTC") + end_dt = pendulum.datetime(2024, 6, 1, 12, 7, tz="UTC") + + # Records every 15 min so the resampled mean equals the input values + for m in range(0, 120, 15): + dt = pendulum.datetime(2024, 6, 1, 10, 7, tz="UTC").add(minutes=m) + sequence.insert_by_datetime(self.create_test_record(dt, float(m))) + + array = sequence.key_to_array( + key="data_value", + start_datetime=start_dt, + end_datetime=end_dt, + interval=to_duration("15 minutes"), + fill_method="time", + boundary="strict", + align_to_interval=False, + ) + + assert len(array) > 0 + # Reconstruct the pandas index that key_to_array used: origin=start_dt + idx = pd.date_range(start=start_dt, periods=len(array), freq="900s") + # First bucket must be exactly at start_dt (10:07) + assert idx[0].minute == 7 + assert idx[0].second == 0 + + def test_key_to_array_align_true_15min_buckets_on_quarter_hours(self, sequence): + """align_to_interval=True produces timestamps on :00/:15/:30/:45 boundaries.""" + # Off-boundary start: 10:07 + start_dt = pendulum.datetime(2024, 6, 1, 10, 7, tz="UTC") + end_dt = pendulum.datetime(2024, 6, 1, 12, 7, tz="UTC") + + # 1-min records across the window so resampling has data to work with + for m in range(0, 121): + dt = pendulum.datetime(2024, 6, 1, 10, 7, tz="UTC").add(minutes=m) + sequence.insert_by_datetime(self.create_test_record(dt, float(m))) + + array = sequence.key_to_array( + key="data_value", + start_datetime=start_dt, + end_datetime=end_dt, + interval=to_duration("15 minutes"), + fill_method="time", + boundary="strict", + align_to_interval=True, + ) + + assert len(array) > 0 + # Reconstruct the epoch-aligned index that key_to_array must have used + import math + epoch = int(start_dt.timestamp()) + floored_epoch = (epoch // 900) * 900 # floor to nearest 15-min boundary + idx = pd.date_range( + start=pd.Timestamp(floored_epoch, unit="s", tz="UTC"), + periods=len(array), + freq="900s", 
+ ) + # Every bucket must land on a :00/:15/:30/:45 minute mark with zero seconds + for ts in idx: + assert ts.minute % 15 == 0, ( + f"Bucket at {ts} is not on a 15-min boundary (minute={ts.minute})" + ) + assert ts.second == 0, ( + f"Bucket at {ts} has non-zero seconds ({ts.second})" + ) + + def test_key_to_array_align_true_1hour_buckets_on_the_hour(self, sequence): + """align_to_interval=True with 1-hour interval produces on-the-hour timestamps.""" + # Off-boundary start: 10:23 + start_dt = pendulum.datetime(2024, 6, 1, 10, 23, tz="UTC") + end_dt = pendulum.datetime(2024, 6, 1, 15, 23, tz="UTC") + + for m in range(0, 301, 15): + dt = pendulum.datetime(2024, 6, 1, 10, 23, tz="UTC").add(minutes=m) + sequence.insert_by_datetime(self.create_test_record(dt, float(m))) + + array = sequence.key_to_array( + key="data_value", + start_datetime=start_dt, + end_datetime=end_dt, + interval=to_duration("1 hour"), + fill_method="time", + boundary="strict", + align_to_interval=True, + ) + + assert len(array) > 0 + epoch = int(start_dt.timestamp()) + floored_epoch = (epoch // 3600) * 3600 # floor to nearest hour + idx = pd.date_range( + start=pd.Timestamp(floored_epoch, unit="s", tz="UTC"), + periods=len(array), + freq="1h", + ) + for ts in idx: + assert ts.minute == 0, ( + f"Bucket at {ts} should be on the hour (minute={ts.minute})" + ) + assert ts.second == 0, ( + f"Bucket at {ts} has non-zero seconds ({ts.second})" + ) + + def test_key_to_array_align_true_when_start_already_on_boundary(self, sequence): + """align_to_interval=True is a no-op when start_datetime is exactly on a boundary. + + With start at a clean 15-min mark both modes must produce identical arrays. 
+ """ + # Exactly on boundary: 10:00:00 + start_dt = pendulum.datetime(2024, 6, 1, 10, 0, tz="UTC") + end_dt = pendulum.datetime(2024, 6, 1, 12, 0, tz="UTC") + + for m in range(0, 121, 15): + dt = pendulum.datetime(2024, 6, 1, 10, 0, tz="UTC").add(minutes=m) + sequence.insert_by_datetime(self.create_test_record(dt, float(m))) + + arr_aligned = sequence.key_to_array( + key="data_value", + start_datetime=start_dt, + end_datetime=end_dt, + interval=to_duration("15 minutes"), + fill_method="time", + boundary="strict", + align_to_interval=True, + ) + arr_default = sequence.key_to_array( + key="data_value", + start_datetime=start_dt, + end_datetime=end_dt, + interval=to_duration("15 minutes"), + fill_method="time", + boundary="strict", + align_to_interval=False, + ) + + assert len(arr_aligned) == len(arr_default) + np.testing.assert_array_almost_equal(arr_aligned, arr_default, decimal=6) + + def test_key_to_array_align_true_without_start_datetime(self, sequence): + """align_to_interval=True with no start_datetime must not raise. + + Without a query_start there is no origin to snap; behaviour falls back + to 'start_day' (same as default). No exception is expected. + """ + for m in range(0, 121, 15): + dt = pendulum.datetime(2024, 6, 1, 10, 7, tz="UTC").add(minutes=m) + sequence.insert_by_datetime(self.create_test_record(dt, float(m))) + + array = sequence.key_to_array( + key="data_value", + start_datetime=None, + end_datetime=pendulum.datetime(2024, 6, 1, 12, 7, tz="UTC"), + interval=to_duration("15 minutes"), + fill_method="time", + boundary="strict", + align_to_interval=True, + ) + + assert isinstance(array, np.ndarray) + assert len(array) > 0 + + def test_key_to_array_align_true_output_within_requested_window(self, sequence): + """align_to_interval=True truncates output to [start_datetime, end_datetime). + + The epoch-floor origin may generate a bucket before start_datetime (e.g. 10:00 + when start is 10:07), but key_to_array must truncate it away. 
The surviving + buckets are verified directly by reconstructing the index from the first + surviving timestamp (the first epoch-aligned bucket >= start_datetime). + + Also checks that all surviving buckets are on 15-min clock boundaries. + """ + start_dt = pendulum.datetime(2024, 6, 1, 10, 7, tz="UTC") + end_dt = pendulum.datetime(2024, 6, 1, 13, 7, tz="UTC") + + for m in range(0, 181): + dt = pendulum.datetime(2024, 6, 1, 10, 7, tz="UTC").add(minutes=m) + sequence.insert_by_datetime(self.create_test_record(dt, float(m))) + + array = sequence.key_to_array( + key="data_value", + start_datetime=start_dt, + end_datetime=end_dt, + interval=to_duration("15 minutes"), + fill_method="time", + boundary="strict", + align_to_interval=True, + ) + + assert len(array) > 0 + + # The first surviving bucket is the first epoch-aligned timestamp >= start_dt. + # Compute it the same way key_to_array does: floor then step forward if needed. + epoch = int(start_dt.timestamp()) + floored_epoch = (epoch // 900) * 900 + first_bucket = pd.Timestamp(floored_epoch, unit="s", tz="UTC") + if first_bucket < pd.Timestamp(start_dt): + first_bucket += pd.Timedelta(seconds=900) + + idx = pd.date_range(start=first_bucket, periods=len(array), freq="900s") + + start_pd = pd.Timestamp(start_dt) + end_pd = pd.Timestamp(end_dt) + for ts in idx: + assert ts >= start_pd, f"Bucket {ts} is before start_datetime {start_pd}" + assert ts < end_pd, f"Bucket {ts} is at or after end_datetime {end_pd}" + assert ts.minute % 15 == 0, f"Bucket {ts} is not on a 15-min boundary" + assert ts.second == 0, f"Bucket {ts} has non-zero seconds" + + def test_key_to_array_align_true_preserves_mean_values(self, sequence): + """align_to_interval=True does not corrupt resampled values. + + A constant-valued series must resample to the same constant regardless + of bucket alignment. 
+ """ + # 1-min records with constant value 42.0, starting off-boundary + start_dt = pendulum.datetime(2024, 6, 1, 10, 7, tz="UTC") + end_dt = pendulum.datetime(2024, 6, 1, 12, 7, tz="UTC") + + for m in range(0, 121): + dt = pendulum.datetime(2024, 6, 1, 10, 7, tz="UTC").add(minutes=m) + sequence.insert_by_datetime(self.create_test_record(dt, 42.0)) + + array = sequence.key_to_array( + key="data_value", + start_datetime=start_dt, + end_datetime=end_dt, + interval=to_duration("15 minutes"), + fill_method="time", + boundary="strict", + align_to_interval=True, + ) + + assert len(array) > 0 + for v in array: + if v is not None: + assert abs(v - 42.0) < 1e-6, f"Expected 42.0, got {v}" + + def test_key_to_array_align_true_compaction_call_pattern(self, sequence): + """Verify the call pattern used by _db_compact_tier produces clock-aligned timestamps. + + _db_compact_tier calls key_to_array with boundary='strict', fill_method='time', + align_to_interval=True on a window whose start has arbitrary sub-second precision. + All output buckets must land on 15-min boundaries so that compacted records are + stored at predictable, human-readable timestamps. 
+ """ + # Non-round base time: 08:43 — chosen to expose any origin-alignment bug + base_dt = pendulum.datetime(2024, 6, 1, 8, 43, tz="UTC") + window_end = pendulum.datetime(2024, 6, 1, 11, 43, tz="UTC") + + for m in range(0, 181): + dt = base_dt.add(minutes=m) + sequence.insert_by_datetime(self.create_test_record(dt, float(m))) + + array = sequence.key_to_array( + key="data_value", + start_datetime=base_dt, + end_datetime=window_end, + interval=to_duration("15 minutes"), + fill_method="time", + boundary="strict", + align_to_interval=True, + ) + + assert len(array) > 0 + epoch = int(base_dt.timestamp()) + floored_epoch = (epoch // 900) * 900 + idx = pd.date_range( + start=pd.Timestamp(floored_epoch, unit="s", tz="UTC"), + periods=len(array), + freq="900s", + ) + for ts in idx: + assert ts.minute % 15 == 0, ( + f"Compacted record at {ts} is not on a 15-min boundary (minute={ts.minute})" + ) + assert ts.second == 0, ( + f"Compacted record at {ts} has non-zero seconds ({ts.second})" + ) def test_delete_by_datetime_range(self, sequence): - record1 = self.create_test_record(datetime(2023, 11, 5), 0.8) - record2 = self.create_test_record(datetime(2023, 11, 6), 0.9) - record3 = self.create_test_record(datetime(2023, 11, 7), 1.0) - sequence.append(record1) - sequence.append(record2) - sequence.append(record3) + dt1 = to_datetime("2023-11-05") + dt2 = to_datetime("2023-11-06") + dt3 = to_datetime("2023-11-07") + record1 = self.create_test_record(dt1, 0.8) + record2 = self.create_test_record(dt2, 0.9) + record3 = self.create_test_record(dt3, 1.0) + sequence.insert_by_datetime(record1) + sequence.insert_by_datetime(record2) + sequence.insert_by_datetime(record3) assert len(sequence) == 3 - sequence.delete_by_datetime( - start_datetime=datetime(2023, 11, 6), end_datetime=datetime(2023, 11, 7) - ) + sequence.delete_by_datetime(start_datetime=dt2, end_datetime=dt3) assert len(sequence) == 2 - assert sequence[0].date_time == to_datetime(datetime(2023, 11, 5)) - assert 
sequence[1].date_time == to_datetime(datetime(2023, 11, 7)) + assert sequence.records[0].date_time == dt1 + assert sequence.records[1].date_time == dt3 def test_delete_by_datetime_start(self, sequence): - record1 = self.create_test_record(datetime(2023, 11, 5), 0.8) - record2 = self.create_test_record(datetime(2023, 11, 6), 0.9) - sequence.append(record1) - sequence.append(record2) + dt1 = to_datetime("2023-11-05") + dt2 = to_datetime("2023-11-06") + record1 = self.create_test_record(dt1, 0.8) + record2 = self.create_test_record(dt2, 0.9) + sequence.insert_by_datetime(record1) + sequence.insert_by_datetime(record2) assert len(sequence) == 2 - sequence.delete_by_datetime(start_datetime=datetime(2023, 11, 6)) + sequence.delete_by_datetime(start_datetime=dt2) assert len(sequence) == 1 - assert sequence[0].date_time == to_datetime(datetime(2023, 11, 5)) + assert sequence.records[0].date_time == dt1 def test_delete_by_datetime_end(self, sequence): - record1 = self.create_test_record(datetime(2023, 11, 5), 0.8) - record2 = self.create_test_record(datetime(2023, 11, 6), 0.9) - sequence.append(record1) - sequence.append(record2) + dt1 = to_datetime("2023-11-05") + dt2 = to_datetime("2023-11-06") + record1 = self.create_test_record(dt1, 0.8) + record2 = self.create_test_record(dt2, 0.9) + sequence.insert_by_datetime(record1) + sequence.insert_by_datetime(record2) assert len(sequence) == 2 - sequence.delete_by_datetime(end_datetime=datetime(2023, 11, 6)) + sequence.delete_by_datetime(end_datetime=dt2) assert len(sequence) == 1 - assert sequence[0].date_time == to_datetime(datetime(2023, 11, 6)) - - def test_filter_by_datetime(self, sequence): - record1 = self.create_test_record(datetime(2023, 11, 5), 0.8) - record2 = self.create_test_record(datetime(2023, 11, 6), 0.9) - sequence.append(record1) - sequence.append(record2) - filtered_sequence = sequence.filter_by_datetime(start_datetime=datetime(2023, 11, 6)) - assert len(filtered_sequence) == 1 - assert 
filtered_sequence[0].date_time == to_datetime(datetime(2023, 11, 6)) + assert sequence.records[0].date_time == dt2 def test_to_dict(self, sequence): - record = self.create_test_record(datetime(2023, 11, 6), 0.8) - sequence.append(record) + dt = to_datetime("2023-11-06") + record = self.create_test_record(dt, 0.8) + sequence.insert_by_datetime(record) data_dict = sequence.to_dict() assert isinstance(data_dict, dict) - sequence_other = sequence.from_dict(data_dict) - assert sequence_other.model_dump() == sequence.model_dump() + # We need a new class - Sequences are singletons + sequence2 = DerivedSequence2.from_dict(data_dict) + assert sequence2.model_dump() == sequence.model_dump() def test_to_json(self, sequence): - record = self.create_test_record(datetime(2023, 11, 6), 0.8) - sequence.append(record) + dt = to_datetime("2023-11-06") + record = self.create_test_record(dt, 0.8) + sequence.insert_by_datetime(record) json_str = sequence.to_json() assert isinstance(json_str, str) assert "2023-11-06" in json_str @@ -799,14 +1110,14 @@ class TestDataSequence: json_str = sequence2.to_json() sequence = sequence.from_json(json_str) assert len(sequence) == len(sequence2) - assert sequence[0].date_time == sequence2[0].date_time - assert sequence[0].data_value == sequence2[0].data_value + assert sequence.records[0].date_time == sequence2.records[0].date_time + assert sequence.records[0].data_value == sequence2.records[0].data_value def test_key_to_value_exact_match(self, sequence): """Test key_to_value returns exact match when datetime matches a record.""" - dt = datetime(2023, 11, 5) + dt = to_datetime("2023-11-05") record = self.create_test_record(dt, 0.75) - sequence.append(record) + sequence.insert_by_datetime(record) result = sequence.key_to_value("data_value", dt) assert result == 0.75 @@ -814,20 +1125,20 @@ class TestDataSequence: """Test key_to_value returns value closest in time to the given datetime.""" record1 = self.create_test_record(datetime(2023, 11, 5, 12), 
0.6) record2 = self.create_test_record(datetime(2023, 11, 6, 12), 0.9) - sequence.append(record1) - sequence.append(record2) + sequence.insert_by_datetime(record1) + sequence.insert_by_datetime(record2) dt = datetime(2023, 11, 6, 10) # closer to record2 - result = sequence.key_to_value("data_value", dt) + result = sequence.key_to_value("data_value", dt, time_window=to_duration("48 hours")) assert result == 0.9 def test_key_to_value_nearest_after(self, sequence): """Test key_to_value returns value nearest after the given datetime.""" record1 = self.create_test_record(datetime(2023, 11, 5, 10), 0.7) record2 = self.create_test_record(datetime(2023, 11, 5, 15), 0.8) - sequence.append(record1) - sequence.append(record2) + sequence.insert_by_datetime(record1) + sequence.insert_by_datetime(record2) dt = datetime(2023, 11, 5, 14) # closer to record2 - result = sequence.key_to_value("data_value", dt) + result = sequence.key_to_value("data_value", dt, time_window=to_duration("48 hours")) assert result == 0.8 def test_key_to_value_empty_sequence(self, sequence): @@ -838,7 +1149,7 @@ class TestDataSequence: def test_key_to_value_missing_key(self, sequence): """Test key_to_value returns None when key is missing in records.""" record = self.create_test_record(datetime(2023, 11, 5), None) - sequence.append(record) + sequence.insert_by_datetime(record) result = sequence.key_to_value("data_value", datetime(2023, 11, 5)) assert result is None @@ -846,16 +1157,16 @@ class TestDataSequence: """Test key_to_value skips records with None values.""" r1 = self.create_test_record(datetime(2023, 11, 5), None) r2 = self.create_test_record(datetime(2023, 11, 6), 1.0) - sequence.append(r1) - sequence.append(r2) - result = sequence.key_to_value("data_value", datetime(2023, 11, 5, 12)) + sequence.insert_by_datetime(r1) + sequence.insert_by_datetime(r2) + result = sequence.key_to_value("data_value", datetime(2023, 11, 5, 12), time_window=to_duration("48 hours")) assert result == 1.0 def 
test_key_to_dict(self, sequence): record1 = self.create_test_record(datetime(2023, 11, 5), 0.8) record2 = self.create_test_record(datetime(2023, 11, 6), 0.9) - sequence.append(record1) - sequence.append(record2) + sequence.insert_by_datetime(record1) + sequence.insert_by_datetime(record2) data_dict = sequence.key_to_dict("data_value") assert isinstance(data_dict, dict) assert data_dict[to_datetime(datetime(2023, 11, 5), as_string=True)] == 0.8 @@ -864,8 +1175,8 @@ class TestDataSequence: def test_key_to_lists(self, sequence): record1 = self.create_test_record(datetime(2023, 11, 5), 0.8) record2 = self.create_test_record(datetime(2023, 11, 6), 0.9) - sequence.append(record1) - sequence.append(record2) + sequence.insert_by_datetime(record1) + sequence.insert_by_datetime(record2) dates, values = sequence.key_to_lists("data_value") assert dates == [to_datetime(datetime(2023, 11, 5)), to_datetime(datetime(2023, 11, 6))] assert values == [0.8, 0.9] @@ -875,9 +1186,9 @@ class TestDataSequence: record1 = self.create_test_record("2024-01-01T12:00:00Z", 10) record2 = self.create_test_record("2024-01-01T13:00:00Z", 20) record3 = self.create_test_record("2024-01-01T14:00:00Z", 30) - sequence.append(record1) - sequence.append(record2) - sequence.append(record3) + sequence.insert_by_datetime(record1) + sequence.insert_by_datetime(record2) + sequence.insert_by_datetime(record3) df = sequence.to_dataframe() @@ -892,9 +1203,9 @@ class TestDataSequence: record1 = self.create_test_record("2024-01-01T12:00:00Z", 10) record2 = self.create_test_record("2024-01-01T13:00:00Z", 20) record3 = self.create_test_record("2024-01-01T14:00:00Z", 30) - sequence.append(record1) - sequence.append(record2) - sequence.append(record3) + sequence.insert_by_datetime(record1) + sequence.insert_by_datetime(record2) + sequence.insert_by_datetime(record3) start = to_datetime("2024-01-01T12:30:00Z") end = to_datetime("2024-01-01T14:00:00Z") @@ -910,8 +1221,8 @@ class TestDataSequence: """Test when no records 
match the given datetime filter.""" record1 = self.create_test_record("2024-01-01T12:00:00Z", 10) record2 = self.create_test_record("2024-01-01T13:00:00Z", 20) - sequence.append(record1) - sequence.append(record2) + sequence.insert_by_datetime(record1) + sequence.insert_by_datetime(record2) start = to_datetime("2024-01-01T14:00:00Z") # Start time after all records end = to_datetime("2024-01-01T15:00:00Z") @@ -935,9 +1246,9 @@ class TestDataSequence: record1 = self.create_test_record("2024-01-01T12:00:00Z", 10) record2 = self.create_test_record("2024-01-01T13:00:00Z", 20) record3 = self.create_test_record("2024-01-01T14:00:00Z", 30) - sequence.append(record1) - sequence.append(record2) - sequence.append(record3) + sequence.insert_by_datetime(record1) + sequence.insert_by_datetime(record2) + sequence.insert_by_datetime(record3) end = to_datetime("2024-01-01T13:00:00Z") # Include only first record @@ -953,9 +1264,9 @@ class TestDataSequence: record1 = self.create_test_record("2024-01-01T12:00:00Z", 10) record2 = self.create_test_record("2024-01-01T13:00:00Z", 20) record3 = self.create_test_record("2024-01-01T14:00:00Z", 30) - sequence.append(record1) - sequence.append(record2) - sequence.append(record3) + sequence.insert_by_datetime(record1) + sequence.insert_by_datetime(record2) + sequence.insert_by_datetime(record3) start = to_datetime("2024-01-01T13:00:00Z") # Include last two records @@ -1018,11 +1329,13 @@ class TestDataProvider: def test_delete_by_datetime(self, provider, sample_start_datetime): """Test `delete_by_datetime` method for removing records by datetime range.""" # Add records to the provider for deletion testing - provider.records = [ + records = [ self.create_test_record(sample_start_datetime - to_duration("3 hours"), 1), self.create_test_record(sample_start_datetime - to_duration("1 hour"), 2), self.create_test_record(sample_start_datetime + to_duration("1 hour"), 3), ] + for record in records: + provider.insert_by_datetime(record) 
provider.delete_by_datetime( start_datetime=sample_start_datetime - to_duration("2 hours"), @@ -1036,50 +1349,175 @@ class TestDataProvider: ) -class TestDataImportProvider: +class NewTestDataImportProvider: + # Fixtures and helper functions @pytest.fixture def provider(self): """Fixture to provide an instance of DerivedDataImportProvider for testing.""" DerivedDataImportProvider.provider_enabled = True - DerivedDataImportProvider.provider_updated = False + DerivedDataImportProvider.provider_updated = True return DerivedDataImportProvider() - @pytest.mark.parametrize( - "start_datetime, value_count, expected_mapping_count", - [ - ("2024-11-10 00:00:00", 24, 24), # No DST in Germany - ("2024-08-10 00:00:00", 24, 24), # DST in Germany - ("2024-03-31 00:00:00", 24, 23), # DST change in Germany (23 hours/ day) - ("2024-10-27 00:00:00", 24, 25), # DST change in Germany (25 hours/ day) - ], - ) - def test_import_datetimes(self, provider, start_datetime, value_count, expected_mapping_count): - start_datetime = to_datetime(start_datetime, in_timezone="Europe/Berlin") +# --------------------------------------------------------------------------- +# import_from_dict +# --------------------------------------------------------------------------- - value_datetime_mapping = provider.import_datetimes(start_datetime, value_count) + def test_import_from_dict_basic(self, provider): + data = { + "start_datetime": "2024-01-01 00:00:00", + "interval": "1 hour", + "power": [1, 2, 3], + } - assert len(value_datetime_mapping) == expected_mapping_count + provider.import_from_dict(data) - @pytest.mark.parametrize( - "start_datetime, value_count, expected_mapping_count", - [ - ("2024-11-10 00:00:00", 24, 24), # No DST in Germany - ("2024-08-10 00:00:00", 24, 24), # DST in Germany - ("2024-03-31 00:00:00", 24, 23), # DST change in Germany (23 hours/ day) - ("2024-10-27 00:00:00", 24, 25), # DST change in Germany (25 hours/ day) - ], - ) - def test_import_datetimes_utc( - self, 
set_other_timezone, provider, start_datetime, value_count, expected_mapping_count - ): - original_tz = set_other_timezone("Etc/UTC") - start_datetime = to_datetime(start_datetime, in_timezone="Europe/Berlin") - assert start_datetime.timezone.name == "Europe/Berlin" + assert provider.records is not None + assert provider.records[0]["power"] == 1 + assert provider.records[1]["power"] == 2 - value_datetime_mapping = provider.import_datetimes(start_datetime, value_count) - assert len(value_datetime_mapping) == expected_mapping_count + def test_import_from_dict_default_start_and_interval(self, provider): + data = { + "power": [10, 20], + } + + provider.import_from_dict(data) + + assert len(provider._updates) == 2 + + + def test_import_from_dict_with_prefix(self, provider): + data = { + "load_power": [1, 2], + "other": [5, 6], + } + + provider.import_from_dict(data, key_prefix="load") + + assert len(provider._updates) == 2 + assert all(update[1] == "load_power" for update in provider._updates) + + + def test_import_from_dict_mismatching_lengths(self, provider): + data = { + "power": [1, 2], + "voltage": [1], + } + + with pytest.raises(ValueError): + provider.import_from_dict(data) + + + def test_import_from_dict_invalid_interval(self, provider): + data = { + "interval": "17 minutes", # does not divide hour + "power": [1, 2, 3], + } + + with pytest.raises(NotImplementedError): + provider.import_from_dict(data) + + + def test_import_from_dict_skips_none_and_nan(self, provider): + data = { + "power": [1, None, np.nan, 4], + } + + provider.import_from_dict(data) + + # only 1 and 4 should be written + assert len(provider._updates) == 2 + assert provider._updates[0][2] == 1 + assert provider._updates[1][2] == 4 + + + def test_import_from_dict_invalid_value_type(self, provider): + data = { + "power": "not a list" + } + + with pytest.raises(ValueError): + provider.import_from_dict(data) + + +# --------------------------------------------------------------------------- +# 
import_from_dataframe +# --------------------------------------------------------------------------- + + def test_import_from_dataframe_with_datetime_index(self, provider): + index = pd.date_range("2024-01-01", periods=3, freq="H") + df = pd.DataFrame({"power": [1, 2, 3]}, index=index) + + provider.import_from_dataframe(df) + + assert len(provider._updates) == 3 + assert provider._updates[0][2] == 1 + + + def test_import_from_dataframe_without_datetime_index(self, provider): + df = pd.DataFrame({"power": [5, 6, 7]}) + + provider.import_from_dataframe( + df, + start_datetime=datetime(2024, 1, 1), + interval=to_duration("1 hour"), + ) + + assert len(provider._updates) == 3 + + + def test_import_from_dataframe_prefix_filter(self, provider): + df = pd.DataFrame({ + "load_power": [1, 2], + "other": [3, 4], + }) + + provider.import_from_dataframe(df, key_prefix="load") + + assert len(provider._updates) == 2 + assert all(update[1] == "load_power" for update in provider._updates) + + + def test_import_from_dataframe_invalid_input(self, provider): + with pytest.raises(ValueError): + provider.import_from_dataframe("not a dataframe") + + +# --------------------------------------------------------------------------- +# import_from_json +# --------------------------------------------------------------------------- + + def test_import_from_json_simple_dict(self, provider): + json_str = json.dumps({ + "power": [1, 2, 3] + }) + + provider.import_from_json(json_str) + + assert len(provider._updates) == 3 + + + def test_import_from_json_invalid(self, provider): + with pytest.raises(ValueError): + provider.import_from_json("this is not json") + + +# --------------------------------------------------------------------------- +# import_from_file +# --------------------------------------------------------------------------- + + def test_import_from_file(self, provider, tmp_path): + file_path = tmp_path / "data.json" + + file_path.write_text(json.dumps({ + "power": [1, 2] + })) + + 
provider.import_from_file(file_path) + + assert len(provider._updates) == 2 + class TestDataContainer: @@ -1095,11 +1533,11 @@ class TestDataContainer: record2 = self.create_test_record(datetime(2023, 11, 6), 2) record3 = self.create_test_record(datetime(2023, 11, 7), 3) provider = DerivedDataProvider() - provider.clear() + provider.delete_by_datetime(start_datetime=None, end_datetime=None) assert len(provider) == 0 - provider.append(record1) - provider.append(record2) - provider.append(record3) + provider.insert_by_datetime(record1) + provider.insert_by_datetime(record2) + provider.insert_by_datetime(record3) assert len(provider) == 3 container = DerivedDataContainer() container.providers.clear() diff --git a/tests/test_dataabccompact.py b/tests/test_dataabccompact.py new file mode 100644 index 0000000..1582387 --- /dev/null +++ b/tests/test_dataabccompact.py @@ -0,0 +1,1114 @@ +"""Compaction tests for DataSequence and DataContainer. + +These tests sit on top of the full DataSequence / DataProvider / DataContainer +stack (dataabc.py) and exercise compaction end-to-end, including the +DataContainer delegation path. + +A temporary SQLite database is configured for the entire test session via the +`configure_database` autouse fixture so that DataSequence instances — which +use the real Database singleton via DatabaseMixin — have a working backend. 
+""" + +from typing import List, Optional, Type + +import numpy as np +import pytest +from pydantic import Field + +from akkudoktoreos.core.dataabc import ( + DataContainer, + DataProvider, + DataRecord, + DataSequence, +) +from akkudoktoreos.core.database import Database +from akkudoktoreos.core.databaseabc import DatabaseTimestamp +from akkudoktoreos.utils.datetimeutil import DateTime, to_datetime, to_duration + +# --------------------------------------------------------------------------- +# Minimal concrete record / sequence / provider +# --------------------------------------------------------------------------- + + +class EnergyRecord(DataRecord): + """Simple numeric record for compaction testing.""" + + power_w: Optional[float] = Field( + default=None, json_schema_extra={"description": "Power in Watts"} + ) + price_eur: Optional[float] = Field( + default=None, json_schema_extra={"description": "Price in EUR/kWh"} + ) + + +class EnergySequence(DataSequence): + records: List[EnergyRecord] = Field( + default_factory=list, + json_schema_extra={"description": "List of energy records"}, + ) + + @classmethod + def record_class(cls) -> Type[EnergyRecord]: + return EnergyRecord + + def db_namespace(self) -> str: + return "energy_test" + + +class PriceSequence(DataSequence): + """Price data — overrides tiers to keep 15-min resolution for 2 weeks.""" + + records: List[EnergyRecord] = Field( + default_factory=list, + json_schema_extra={"description": "List of price records"}, + ) + + @classmethod + def record_class(cls) -> Type[EnergyRecord]: + return EnergyRecord + + def db_namespace(self) -> str: + return "price_test" + + def db_compact_tiers(self): + # Price data: skip first tier (already at target resolution for 2 weeks) + return [(to_duration("14 days"), to_duration("1 hour"))] + + +class EnergyProvider(DataProvider): + records: List[EnergyRecord] = Field( + default_factory=list, + json_schema_extra={"description": "List of energy records"}, + ) + + @classmethod + 
def record_class(cls) -> Type[EnergyRecord]: + return EnergyRecord + + def provider_id(self) -> str: + return "EnergyProvider" + + def enabled(self) -> bool: + return True + + def _update_data(self, force_update=False) -> None: + pass + + def db_namespace(self) -> str: + return self.provider_id() + + +class PriceProvider(DataProvider): + records: List[EnergyRecord] = Field( + default_factory=list, + json_schema_extra={"description": "List of price records"}, + ) + + @classmethod + def record_class(cls) -> Type[EnergyRecord]: + return EnergyRecord + + def provider_id(self) -> str: + return "PriceProvider" + + def enabled(self) -> bool: + return True + + def _update_data(self, force_update=False) -> None: + pass + + def db_namespace(self) -> str: + return self.provider_id() + + def db_compact_tiers(self): + return [(to_duration("14 days"), to_duration("1 hour"))] + + +class EnergyContainer(DataContainer): + providers: List[DataProvider] = Field(default_factory=list) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _aligned_base(now: DateTime, interval_minutes: int = 15) -> DateTime: + """Floor ``now`` to the nearest ``interval_minutes`` boundary. + + All fixtures that feed _fill_sequence use this so that compacted timestamps + are predictably on clock-round boundaries and tests are deterministic. + """ + interval_sec = interval_minutes * 60 + epoch = int(now.timestamp()) + return now.subtract(seconds=epoch % interval_sec).set(microsecond=0) + + +def _fill_sequence( + seq: DataSequence, + base: DateTime, + count: int, + interval_minutes: int, + power_w: float = 1000.0, + price_eur: float = 0.25, +) -> None: + """Insert ``count`` EnergyRecords spaced ``interval_minutes`` apart. + + ``base`` should be interval-aligned (use ``_aligned_base``) so that + compacted bucket timestamps are deterministic across all tests. 
+ """ + for i in range(count): + dt = base.add(minutes=i * interval_minutes) + rec = EnergyRecord(date_time=dt, power_w=power_w + i, price_eur=price_eur) + seq.db_insert_record(rec) + seq.db_save_records() + + +def _reset_singletons() -> None: + """Reset all singleton classes used in these tests. + + DataProvider and DataSequence inherit SingletonMixin, meaning each subclass + only ever has one instance. Without resetting between tests, state from one + test (records, compaction metadata, monkey-patches) leaks into the next. + """ + for cls in (EnergySequence, PriceSequence, EnergyProvider, PriceProvider, EnergyContainer): + try: + cls.reset_instance() + except Exception: + pass + + +@pytest.fixture(autouse=True) +def configure_database(tmp_path): + """Configure a fresh temporary SQLite database for every test. + + DataSequence uses the real Database singleton via DatabaseMixin. + Without an open database backend, count_records() and all other DB + operations raise RuntimeError('Database not configured'). + + This fixture: + 1. Resets the Database singleton so the previous test's state is gone. + 2. Points the database config at a fresh per-test tmp_path directory. + 3. Opens a SQLite backend. + 4. Resets all sequence/provider/container singletons before and after. + 5. Tears everything down cleanly after each test. 
+ """ + _reset_singletons() + + # Reset the Database singleton itself + Database.reset_instance() + + # Patch config to use SQLite in tmp_path + db = Database() + db.config.database.provider = "SQLite" + db.config.general.data_folder_path = tmp_path + db.open() + + yield + + # Teardown + _reset_singletons() + try: + Database.reset_instance() + except Exception: + pass + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def energy_seq(): + """Fresh EnergySequence with no data.""" + return EnergySequence() + + +@pytest.fixture +def dense_energy_seq(): + """EnergySequence with 4 weeks of 15-min records (~2688 records). + + The base timestamp is floored to a 15-min boundary so compacted bucket + timestamps are deterministic and on clock-round marks. + """ + seq = EnergySequence() + now = to_datetime().in_timezone("UTC") + base = _aligned_base(now.subtract(weeks=4), interval_minutes=15) + _fill_sequence(seq, base, count=4 * 7 * 24 * 4, interval_minutes=15) + return seq, now + + +@pytest.fixture +def dense_price_seq(): + """PriceSequence with 4 weeks of 15-min records. + + The base timestamp is floored to a 15-min boundary so compacted bucket + timestamps are deterministic and on clock-round marks. 
+ """ + seq = PriceSequence() + now = to_datetime().in_timezone("UTC") + base = _aligned_base(now.subtract(weeks=4), interval_minutes=15) + _fill_sequence(seq, base, count=4 * 7 * 24 * 4, interval_minutes=15) + return seq, now + + +@pytest.fixture +def energy_container(energy_seq): + """DataContainer with one EnergyProvider and one PriceProvider.""" + ep = EnergyProvider() + pp = PriceProvider() + container = EnergyContainer(providers=[ep, pp]) + return container, ep, pp + + +# --------------------------------------------------------------------------- +# DataSequence — tier configuration +# --------------------------------------------------------------------------- + + +class TestDataSequenceCompactTiers: + + def test_default_tiers_two_entries(self, energy_seq): + tiers = energy_seq.db_compact_tiers() + assert len(tiers) == 2 + + def test_default_first_tier_2h_15min(self, energy_seq): + tiers = energy_seq.db_compact_tiers() + age_sec = tiers[0][0].total_seconds() + interval_sec = tiers[0][1].total_seconds() + assert age_sec == 2 * 3600 + assert interval_sec == 15 * 60 + + def test_default_second_tier_2weeks_1h(self, energy_seq): + tiers = energy_seq.db_compact_tiers() + age_sec = tiers[1][0].total_seconds() + interval_sec = tiers[1][1].total_seconds() + assert age_sec == 14 * 24 * 3600 + assert interval_sec == 3600 + + def test_price_sequence_overrides_to_single_tier(self): + seq = PriceSequence() + tiers = seq.db_compact_tiers() + assert len(tiers) == 1 + assert tiers[0][0].total_seconds() == 14 * 24 * 3600 + assert tiers[0][1].total_seconds() == 3600 + + def test_empty_tiers_disables_compaction(self): + class NoCompact(EnergySequence): + def db_compact_tiers(self): + return [] + + seq = NoCompact() + now = to_datetime().in_timezone("UTC") + base = _aligned_base(now.subtract(weeks=4), interval_minutes=15) + _fill_sequence(seq, base, count=500, interval_minutes=15) + assert seq.db_compact() == 0 + + +# 
--------------------------------------------------------------------------- +# DataSequence — compaction behaviour +# --------------------------------------------------------------------------- + + +class TestDataSequenceCompact: + + def test_empty_sequence_returns_zero(self, energy_seq): + assert energy_seq.db_compact() == 0 + + def test_dense_data_reduces_count(self, dense_energy_seq): + seq, _ = dense_energy_seq + before = seq.db_count_records() + deleted = seq.db_compact() + assert deleted > 0 + assert seq.db_count_records() < before + + def test_all_fields_compacted(self, dense_energy_seq): + """Both power_w and price_eur should be present on compacted records.""" + seq, now = dense_energy_seq + seq.db_compact() + + cutoff = now.subtract(weeks=2) + old_records = [r for r in seq.records if r.date_time and r.date_time < cutoff] + + assert len(old_records) > 0 + for rec in old_records: + assert rec.power_w is not None, "power_w must survive compaction" + assert rec.price_eur is not None, "price_eur must survive compaction" + + def test_recent_records_untouched(self, dense_energy_seq): + """Records within 2 hours of now must not be compacted.""" + seq, now = dense_energy_seq + cutoff = now.subtract(hours=2) + + # Snapshot recent values + recent_before = { + DatabaseTimestamp.from_datetime(r.date_time): r.power_w + for r in seq.records + if r.date_time and r.date_time >= cutoff + } + + seq.db_compact() + + recent_after = { + DatabaseTimestamp.from_datetime(r.date_time): r.power_w + for r in seq.records + if r.date_time and r.date_time >= cutoff + } + + assert recent_before == recent_after + + def test_idempotent(self, dense_energy_seq): + seq, _ = dense_energy_seq + seq.db_compact() + after_first = seq.db_count_records() + + seq.db_compact() + after_second = seq.db_count_records() + + assert after_first == after_second + + def test_price_sequence_preserves_15min_in_recent_2weeks(self, dense_price_seq): + """PriceSequence keeps 15-min resolution for data younger 
than 2 weeks.""" + seq, now = dense_price_seq + seq.db_compact() + + two_weeks_ago = now.subtract(weeks=2) + recent_records = [ + r for r in seq.records + if r.date_time and r.date_time >= two_weeks_ago + ] + # Should still have ~4 records per hour = 15-min resolution + if len(recent_records) > 1: + diffs = [] + sorted_recs = sorted(recent_records, key=lambda r: r.date_time) + for i in range(1, min(len(sorted_recs), 10)): + diff = (sorted_recs[i].date_time - sorted_recs[i - 1].date_time).total_seconds() + diffs.append(diff) + # Average spacing should be ~15 min, not 60 min + avg_spacing = sum(diffs) / len(diffs) + assert avg_spacing <= 20 * 60, ( + f"Expected ~15min spacing in recent 2 weeks, got {avg_spacing/60:.1f} min" + ) + + def test_price_sequence_compacts_older_than_2weeks_to_1h(self, dense_price_seq): + """PriceSequence compacts data older than 2 weeks to 1-hour resolution.""" + seq, now = dense_price_seq + seq.db_compact() + + two_weeks_ago = now.subtract(weeks=2) + old_records = sorted( + [r for r in seq.records if r.date_time and r.date_time < two_weeks_ago], + key=lambda r: r.date_time, + ) + + if len(old_records) > 1: + diffs = [] + for i in range(1, min(len(old_records), 10)): + diff = (old_records[i].date_time - old_records[i - 1].date_time).total_seconds() + diffs.append(diff) + avg_spacing = sum(diffs) / len(diffs) + assert avg_spacing >= 50 * 60, ( + f"Expected ~1h spacing for old price data, got {avg_spacing/60:.1f} min" + ) + + def test_compact_with_custom_tiers_argument(self, dense_energy_seq): + """db_compact(compact_tiers=...) overrides the instance's tiers.""" + seq, _ = dense_energy_seq + before = seq.db_count_records() + + deleted = seq.db_compact( + compact_tiers=[(to_duration("1 day"), to_duration("1 hour"))] + ) + + assert deleted > 0 + assert seq.db_count_records() < before + + def test_compacted_timestamps_are_clock_aligned(self, dense_energy_seq): + """All timestamps produced by compaction must sit on UTC clock boundaries. 
+ + _db_compact_tier floors its cutoff timestamps to interval boundaries, so + the boundary between tiers is not exactly ``now - age`` but the floored + version of it. We compute the same floored cutoffs here. + + - Records older than floored 2-week cutoff → multiple of 3600 s + - Records in floored 2h..2week band → multiple of 900 s + - Records younger than floored 2h cutoff → unchanged + """ + seq, now = dense_energy_seq + seq.db_compact() + + # _db_compact_tier floors new_cutoff from db_max, not from wall-clock now. + # Compute the same floored cutoffs that the implementation used. + _, db_max_ts = seq.db_timestamp_range() + # DatabaseTimestamp already imported at top of file + db_max_epoch = int(DatabaseTimestamp.to_datetime(db_max_ts).timestamp()) + two_weeks_cutoff_epoch = ((db_max_epoch - 14*24*3600) // 3600) * 3600 + two_hours_cutoff_epoch = ((db_max_epoch - 2*3600) // 900) * 900 + + for rec in seq.records: + if rec.date_time is None: + continue + epoch = int(rec.date_time.timestamp()) + if epoch < two_weeks_cutoff_epoch: + assert epoch % 3600 == 0, ( + f"Old record {rec.date_time} not on hour boundary" + ) + elif epoch < two_hours_cutoff_epoch: + assert epoch % 900 == 0, ( + f"Mid record {rec.date_time} not on 15-min boundary" + ) + + +# --------------------------------------------------------------------------- +# DataSequence — data integrity after compaction +# --------------------------------------------------------------------------- + + +class TestDataSequenceCompactIntegrity: + + @staticmethod + def _tier_cutoff(now, age_seconds: int, interval_seconds: int): + """Compute the floored compaction cutoff the same way _db_compact_tier does. + + _db_compact_tier floors new_cutoff_dt to the interval boundary, so + ``newest - age_threshold`` rounded down. Tests must use the same value + to correctly classify which tier a record falls into. 
+ """ + import math + raw_epoch = int(now.subtract(seconds=age_seconds).timestamp()) + floored_epoch = (raw_epoch // interval_seconds) * interval_seconds + return now.__class__.fromtimestamp(floored_epoch, tz=now.tzinfo) + + def test_constant_power_preserved(self): + """Mean resampling of a constant must equal the constant.""" + seq = EnergySequence() + now = to_datetime().in_timezone("UTC") + # Use aligned base so bucket boundaries are deterministic + base = _aligned_base(now.subtract(hours=6), interval_minutes=15) + + for i in range(6 * 60): # 1-min records for 6 hours + dt = base.add(minutes=i) + seq.db_insert_record(EnergyRecord(date_time=dt, power_w=500.0, price_eur=0.30)) + seq.db_save_records() + + seq._db_compact_tier(to_duration("2 hours"), to_duration("15 minutes")) + + cutoff = now.subtract(hours=2) + for rec in seq.records: + if rec.date_time and rec.date_time < cutoff: + assert rec.power_w == pytest.approx(500.0, abs=1e-3) + assert rec.price_eur == pytest.approx(0.30, abs=1e-6) + + def test_record_count_monotonically_decreases(self): + """Each successive tier run should never increase record count.""" + seq = EnergySequence() + now = to_datetime().in_timezone("UTC") + base = _aligned_base(now.subtract(weeks=4), interval_minutes=15) + _fill_sequence(seq, base, count=4 * 7 * 24 * 4, interval_minutes=15) + + counts = [seq.db_count_records()] + for age, interval in reversed(seq.db_compact_tiers()): + seq._db_compact_tier(age, interval) + counts.append(seq.db_count_records()) + + for i in range(1, len(counts)): + assert counts[i] <= counts[i - 1], ( + f"Record count increased from {counts[i-1]} to {counts[i]} at tier {i}" + ) + + def test_no_duplicate_timestamps_after_compaction(self, dense_energy_seq): + """Compaction must not create duplicate timestamps.""" + seq, _ = dense_energy_seq + seq.db_compact() + + timestamps = [ + DatabaseTimestamp.from_datetime(r.date_time) + for r in seq.records + if r.date_time is not None + ] + assert len(timestamps) == 
len(set(timestamps)), "Duplicate timestamps after compaction" + + def test_timestamps_remain_sorted(self, dense_energy_seq): + """Records must remain in ascending order after compaction.""" + seq, _ = dense_energy_seq + seq.db_compact() + + dts = [r.date_time for r in seq.records if r.date_time is not None] + assert dts == sorted(dts) + + def test_compacted_old_timestamps_on_1h_boundaries(self, dense_energy_seq): + """Records older than the floored 2-week cutoff must be on whole-hour UTC boundaries. + + _db_compact_tier floors new_cutoff to the interval boundary, so we must + use the same floored cutoff to decide which records were compacted by the + 1-hour tier. Records between the floored and raw cutoff may still be at + 15-min resolution from the previous tier. + """ + seq, now = dense_energy_seq + seq.db_compact() + + # _db_compact_tier floors new_cutoff from db_max (the newest record), + # not from wall-clock now. Derive the same floored cutoff here. + _, db_max_ts = seq.db_timestamp_range() + # DatabaseTimestamp already imported at top of file + db_max_epoch = int(DatabaseTimestamp.to_datetime(db_max_ts).timestamp()) + two_weeks_cutoff_epoch = ((db_max_epoch - 14*24*3600) // 3600) * 3600 + two_weeks_cutoff_dt = DateTime.fromtimestamp(two_weeks_cutoff_epoch, tz="UTC") + + old_records = [r for r in seq.records if r.date_time and r.date_time < two_weeks_cutoff_dt] + + assert len(old_records) > 0, "Expected compacted records older than 2-week floored cutoff" + for rec in old_records: + epoch = int(rec.date_time.timestamp()) + assert epoch % 3600 == 0, ( + f"Old record at {rec.date_time} is not on an hour boundary" + ) + + def test_compacted_mid_timestamps_on_15min_boundaries(self): + """Records compacted by the 15-min tier must land on 15-min UTC boundaries. + + We run _db_compact_tier directly with the 2h/15min tier on a sequence + of 1-min records spanning 6 hours, then verify every compacted record + sits on a :00/:15/:30/:45 UTC mark. 
+ + The implementation computes new_cutoff as floor(newest - age, 900). + We replicate that exact calculation to identify which records were in + the compaction window. + """ + seq = EnergySequence() + now = to_datetime().in_timezone("UTC") + base = _aligned_base(now.subtract(hours=6), interval_minutes=15) + + # 1-min records for 6 hours; newest record is at base + 359 min + for i in range(6 * 60): + dt = base.add(minutes=i) + seq.db_insert_record(EnergyRecord(date_time=dt, power_w=500.0, price_eur=0.30)) + seq.db_save_records() + + seq._db_compact_tier(to_duration("2 hours"), to_duration("15 minutes")) + + # Replicate the implementation's floored cutoff exactly: + # newest_dt = last inserted record = base + 359min + # new_cutoff = floor(newest_dt - 2h, 900) + newest_dt = base.add(minutes=6 * 60 - 1) + raw_cutoff_epoch = int(newest_dt.subtract(hours=2).timestamp()) + window_end_epoch = (raw_cutoff_epoch // 900) * 900 + + # Records before window_end_epoch must all be on 15-min boundaries + compacted = [ + r for r in seq.records + if r.date_time is not None + and int(r.date_time.timestamp()) < window_end_epoch + ] + + assert len(compacted) > 0, ( + f"Expected compacted records before window_end={window_end_epoch}; " + f"got records at {[int(r.date_time.timestamp()) for r in seq.records if r.date_time]}" + ) + for rec in compacted: + assert rec.date_time is not None + epoch = int(rec.date_time.timestamp()) + assert epoch % 900 == 0, ( + f"15-min-tier record at {rec.date_time} (epoch={epoch}) " + f"is not on a 15-min boundary (epoch % 900 = {epoch % 900})" + ) + + def test_no_compacted_timestamps_between_boundaries(self, dense_energy_seq): + """After compaction no record timestamp must fall between expected bucket boundaries. + + Records older than the floored 2-week cutoff (processed by the 1h tier) + must be on hour marks. Records in the 15-min band must be on 15-min marks. 
+ """ + seq, now = dense_energy_seq + seq.db_compact() + + # Derive floored cutoffs from db_max — same reference as the implementation. + _, db_max_ts = seq.db_timestamp_range() + # DatabaseTimestamp already imported at top of file + db_max_epoch = int(DatabaseTimestamp.to_datetime(db_max_ts).timestamp()) + two_weeks_cutoff_epoch = ((db_max_epoch - 14*24*3600) // 3600) * 3600 + two_hours_cutoff_epoch = ((db_max_epoch - 2*3600) // 900) * 900 + + for rec in seq.records: + if rec.date_time is None: + continue + epoch = int(rec.date_time.timestamp()) + if epoch < two_weeks_cutoff_epoch: + assert epoch % 3600 == 0, ( + f"Record at {rec.date_time} is not hour-aligned in 1h-tier region" + ) + elif epoch < two_hours_cutoff_epoch: + assert epoch % (15 * 60) == 0, ( + f"Record at {rec.date_time} is not 15min-aligned in 15min-tier region" + ) + + +# --------------------------------------------------------------------------- +# DataContainer — delegation +# --------------------------------------------------------------------------- + + +class TestDataContainerCompact: + + def test_compact_delegates_to_all_providers(self, energy_container): + container, ep, pp = energy_container + now = to_datetime().in_timezone("UTC") + + # Fill both providers with 4 weeks of 15-min data + base = _aligned_base(now.subtract(weeks=4), interval_minutes=15) + _fill_sequence(ep, base, count=4 * 7 * 24 * 4, interval_minutes=15) + _fill_sequence(pp, base, count=4 * 7 * 24 * 4, interval_minutes=15) + + ep_before = ep.db_count_records() + pp_before = pp.db_count_records() + + container.db_compact() + + assert ep.db_count_records() < ep_before, "EnergyProvider records should be compacted" + assert pp.db_count_records() < pp_before, "PriceProvider records should be compacted" + + def test_compact_empty_container_no_error(self): + container = EnergyContainer(providers=[]) + container.db_compact() # must not raise + + def test_compact_provider_tiers_respected(self, energy_container): + """PriceProvider 
with single 2-week tier must not compact recent 15-min data.""" + container, ep, pp = energy_container + now = to_datetime().in_timezone("UTC") + + base = _aligned_base(now.subtract(weeks=4), interval_minutes=15) + _fill_sequence(pp, base, count=4 * 7 * 24 * 4, interval_minutes=15) + + container.db_compact() + + # Price data in last 2 weeks should still be at 15-min resolution + two_weeks_ago = now.subtract(weeks=2) + recent = sorted( + [r for r in pp.records if r.date_time and r.date_time >= two_weeks_ago], + key=lambda r: r.date_time, + ) + if len(recent) > 1: + diff = (recent[1].date_time - recent[0].date_time).total_seconds() + assert diff <= 20 * 60, ( + f"PriceProvider recent data should be ~15min, got {diff/60:.1f} min" + ) + + def test_compact_raises_on_provider_failure(self): + """A provider that raises during compaction must bubble up as RuntimeError. + + Monkey-patching is blocked by Pydantic v2's __setattr__ validation, so + we use a subclass that overrides db_compact instead. 
+ """ + class BrokenProvider(EnergyProvider): + def db_compact(self, *args, **kwargs): + raise ValueError("simulated failure") + + def provider_id(self) -> str: + # Distinct id so it doesn't collide with EnergyProvider singleton + return "BrokenProvider" + + def db_namespace(self) -> str: + return self.provider_id() + + bp = BrokenProvider() + container = EnergyContainer(providers=[bp]) + + with pytest.raises(RuntimeError, match="fails on db_compact"): + container.db_compact() + + def test_compact_idempotent_on_container(self, energy_container): + container, ep, pp = energy_container + now = to_datetime().in_timezone("UTC") + base = _aligned_base(now.subtract(weeks=4), interval_minutes=15) + _fill_sequence(ep, base, count=4 * 7 * 24 * 4, interval_minutes=15) + _fill_sequence(pp, base, count=4 * 7 * 24 * 4, interval_minutes=15) + + container.db_compact() + ep_after_first = ep.db_count_records() + pp_after_first = pp.db_count_records() + + container.db_compact() + assert ep.db_count_records() == ep_after_first + assert pp.db_count_records() == pp_after_first + + +# --------------------------------------------------------------------------- +# Sparse guard — DataSequence level +# --------------------------------------------------------------------------- +# +# The sparse guard distinguishes three cases: +# +# 1. Sparse + already aligned → skip entirely (deleted=0, count unchanged) +# 2. Sparse + misaligned → snap timestamps in place (deleted>0, but +# count stays the same or decreases if two +# records collide on the same bucket) +# 3. 
Sparse collision → two records snap to the same bucket; values +# are merged key-by-key; count decreases by 1 +# --------------------------------------------------------------------------- + + +class TestDataSequenceSparseGuard: + + # ------------------------------------------------------------------ + # Case 1: sparse + already aligned → pure skip + # ------------------------------------------------------------------ + + def test_sparse_aligned_data_not_modified(self): + """Sparse records that already sit on interval boundaries must not be touched. + + deleted must be 0 and record count must be unchanged. + """ + seq = EnergySequence() + now = to_datetime().in_timezone("UTC") + base = now.subtract(weeks=4) + + # Insert exactly 3 records, each snapped to a whole hour (aligned) + for offset_days in [0, 14, 27]: + raw = base.add(days=offset_days) + # Floor to nearest hour boundary so timestamp is already aligned + aligned = raw.set(minute=0, second=0, microsecond=0) + seq.db_insert_record(EnergyRecord(date_time=aligned, power_w=100.0)) + seq.db_save_records() + + before = seq.db_count_records() + deleted = seq.db_compact() + + assert deleted == 0, "Aligned sparse records must not be deleted" + assert seq.db_count_records() == before, "Record count must not change" + + def test_sparse_aligned_data_values_untouched(self): + """Values of aligned sparse records must be preserved exactly.""" + seq = EnergySequence() + now = to_datetime().in_timezone("UTC") + base = now.subtract(weeks=4).set(minute=0, second=0, microsecond=0) + + seq.db_insert_record(EnergyRecord(date_time=base, power_w=42.0, price_eur=0.99)) + seq.db_save_records() + + seq.db_compact() + + remaining = [r for r in seq.records if r.date_time == base] + assert len(remaining) == 1 + assert remaining[0].power_w == pytest.approx(42.0) + assert remaining[0].price_eur == pytest.approx(0.99) + + # ------------------------------------------------------------------ + # Case 2: sparse + misaligned → timestamp 
snapping + # ------------------------------------------------------------------ + + @staticmethod + def _make_snapping_seq(now, offsets_minutes, interval_minutes=10, age_minutes=30): + """Build a sequence guaranteed to enter the sparse-snapping path. + + Key insight: _db_compact_tier measures age_threshold from db_max (the + newest record in the database), not from wall-clock now. We therefore + insert a "newest anchor" record 1 second before now so that + db_max ≈ now, making cutoff = db_max - age_threshold ≈ now - age_minutes. + + The test records are placed at now - (age_minutes + margin) + offset, + which puts them clearly before the cutoff and inside the compaction window. + + resampled_count = age_minutes / interval_minutes (the window width in + buckets). We require len(offsets_minutes) > resampled_count so the + snapping path is entered rather than the pure-skip path. + + Returns (seq, age_threshold, target_interval, record_datetimes). + """ + age_td = to_duration(f"{age_minutes} minutes") + interval_td = to_duration(f"{interval_minutes} minutes") + interval_sec = interval_minutes * 60 + + # Margin must be larger than the maximum offset so that ALL test records + # land before window_end = floor(now - age_minutes, interval_sec). + # We need: base + max(offsets) < now - age_minutes + # => now - (age_minutes + margin) + max(offsets) < now - age_minutes + # => max(offsets) < margin + # Use margin = max(offsets_minutes) + 2*interval_minutes + 1 (generous). 
+ max_offset = max(offsets_minutes) if offsets_minutes else 0 + margin = max_offset + 2 * interval_minutes + 1 + + # Floor base to interval boundary so snapping arithmetic is exact + raw_base = now.subtract(minutes=age_minutes + margin).set(second=0, microsecond=0) + base_epoch = int(raw_base.timestamp()) + base = raw_base.subtract(seconds=base_epoch % interval_sec) + + seq = EnergySequence() + dts = [] + for off in offsets_minutes: + dt = base.add(minutes=off) + seq.db_insert_record(EnergyRecord(date_time=dt, power_w=float(off * 10))) + dts.append(dt) + + # Newest anchor: makes db_max ≈ now so cutoff = now - age_threshold + anchor = now.subtract(seconds=1) + seq.db_insert_record(EnergyRecord(date_time=anchor, power_w=0.0)) + seq.db_save_records() + return seq, age_td, interval_td, dts + + def test_sparse_misaligned_records_are_snapped(self): + """Sparse misaligned records must be moved to the nearest boundary. + + Uses a tight window (30 min age, 10 min interval → 3 resampled buckets) + with 4 misaligned records so existing_count(4) > resampled_count(3) and + the snapping path is entered deterministically. + """ + now = to_datetime().in_timezone("UTC") + # 4 records at :03, :08, :13, :18 — all misaligned for a 10-min interval + seq, age_td, interval_td, _ = self._make_snapping_seq( + now, offsets_minutes=[3, 8, 13, 18] + ) + # before includes the anchor record which is NOT in the compaction window + # and therefore NOT deleted. Only the 4 test records are in-window. + n_test_records = len([3, 8, 13, 18]) # offsets_minutes + deleted = seq._db_compact_tier(age_td, interval_td) + after = seq.db_count_records() + + assert deleted == n_test_records, ( + f"All {n_test_records} in-window records must be deleted (whole-window delete); " + f"got deleted={deleted}" + ) + # Net count after: anchor(1) + snapped buckets re-inserted. 
+ # Implementation uses FLOOR division: (epoch // interval_sec) * interval_sec + # offsets [3,8,13,18] with interval=10min map to buckets: + # 3 // 10 = 0 → :00 + # 8 // 10 = 0 → :00 (collision with :03) + # 13 // 10 = 1 → :10 + # 18 // 10 = 1 → :10 (collision with :13) + # → 2 unique buckets + interval_minutes = 10 + n_snapped = len({(off // interval_minutes) * interval_minutes for off in [3, 8, 13, 18]}) + assert after == 1 + n_snapped, ( + f"Expected 1 anchor + {n_snapped} snapped buckets = {1 + n_snapped} records; " + f"got {after}" + ) + + def test_sparse_misaligned_timestamps_become_aligned(self): + """After snapping, in-window timestamps must be on the target interval boundary. + + The anchor record lives outside the compaction window (it is younger than + age_threshold) and is intentionally misaligned — it must NOT be checked. + """ + now = to_datetime().in_timezone("UTC") + interval_minutes = 10 + age_minutes = 30 + seq, age_td, interval_td, dts = self._make_snapping_seq( + now, offsets_minutes=[3, 8, 13, 18], interval_minutes=interval_minutes, + age_minutes=age_minutes, + ) + seq._db_compact_tier(age_td, interval_td) + + # Compute window_end the same way _db_compact_tier does + # (anchor is db_max; raw_cutoff = anchor - age_threshold ≈ now - 30min) + anchor_epoch = int(now.subtract(seconds=1).timestamp()) + raw_cutoff_epoch = anchor_epoch - age_minutes * 60 + window_end_epoch = (raw_cutoff_epoch // (interval_minutes * 60)) * (interval_minutes * 60) + + interval_sec = interval_minutes * 60 + for rec in seq.records: + if rec.date_time is None: + continue + epoch = int(rec.date_time.timestamp()) + if epoch >= window_end_epoch: + continue # anchor or other post-cutoff record — not compacted + assert epoch % interval_sec == 0, ( + f"Snapped timestamp {rec.date_time} (epoch={epoch}) is not on a " + f"{interval_minutes}-min boundary (epoch % {interval_sec} = {epoch % interval_sec})" + ) + + def test_sparse_misaligned_values_preserved_after_snap(self): + 
"""Snapping must not alter the field values of sparse records.""" + seq = EnergySequence() + now = to_datetime().in_timezone("UTC") + # Single misaligned record, old enough for both tiers + dt = now.subtract(weeks=4).set(minute=7, second=0, microsecond=0) + seq.db_insert_record(EnergyRecord(date_time=dt, power_w=777.0, price_eur=0.55)) + seq.db_save_records() + + seq.db_compact() + + # Exactly one record must remain and its values must be unchanged + assert len(seq.records) == 1 + assert seq.records[0].power_w == pytest.approx(777.0) + assert seq.records[0].price_eur == pytest.approx(0.55) + + # ------------------------------------------------------------------ + # Case 3: two sparse records collide on the same snapped bucket + # ------------------------------------------------------------------ + + def test_sparse_collision_merges_records(self): + """Two sparse records that snap to the same bucket must be merged. + + Records at :03 and :04 both round to :00 with a 10-min interval. + With 4 test records and resampled_count=3, the snapping path is entered. + A newest-anchor record at now-1s pushes db_max ≈ now so the compaction + cutoff lands at now-30min, which is after all test records. + """ + now = to_datetime().in_timezone("UTC") + age_td = to_duration("30 minutes") + interval_td = to_duration("10 minutes") + interval_sec = 600 + # Place test records 41+ min ago so they are before cutoff = now - 30min + # base must be far enough back that all records (+17min max) land before + # window_end = floor(now - 30min, 600). Use now - 52min. 
+ raw_base = now.subtract(minutes=52).set(second=0, microsecond=0) + base = raw_base.subtract(seconds=int(raw_base.timestamp()) % interval_sec) + + seq = EnergySequence() + seq.db_insert_record(EnergyRecord(date_time=base.add(minutes=3), + power_w=100.0, price_eur=None)) + seq.db_insert_record(EnergyRecord(date_time=base.add(minutes=4), + power_w=None, price_eur=0.25)) + seq.db_insert_record(EnergyRecord(date_time=base.add(minutes=13), power_w=10.0)) + seq.db_insert_record(EnergyRecord(date_time=base.add(minutes=17), power_w=20.0)) + # Anchor: makes db_max ≈ now → cutoff = now - 30min (after all test records) + seq.db_insert_record(EnergyRecord(date_time=now.subtract(seconds=1), power_w=0.0)) + seq.db_save_records() + + # existing_count in window = 4, resampled_count = 3 → snapping path + seq._db_compact_tier(age_td, interval_td) + + snapped_epoch = int(base.timestamp()) + snapped = [ + r for r in seq.records + if r.date_time is not None and int(r.date_time.timestamp()) == snapped_epoch + ] + assert len(snapped) == 1, "The :03 and :04 records must merge into one :00 bucket" + assert snapped[0].power_w == pytest.approx(100.0), "power_w from :03 must survive" + assert snapped[0].price_eur == pytest.approx(0.25), "price_eur from :04 must survive" + + def test_sparse_collision_keeps_first_value_for_shared_key(self): + """When two sparse records floor to the same bucket, the earlier value wins. + + Two records at :03 (power_w=111) and :04 (power_w=222) both floor to :00 + with a 10-min interval (floor division: 3//10=0, 4//10=0). + existing_count(2) <= resampled_count for the ~22-min window, so the sparse + snapping path is taken rather than full resampling. The merged record at + :00 must carry power_w=111 because the chronologically earlier record wins. + """ + now = to_datetime().in_timezone("UTC") + interval_sec = 600 + # Place both records 52 min ago so they are before window_end ≈ now - 30min. 
+ # Only 2 test records → existing_count(2) <= resampled_count → sparse path. + raw_base = now.subtract(minutes=52).set(second=0, microsecond=0) + base = raw_base.subtract(seconds=int(raw_base.timestamp()) % interval_sec) + seq = EnergySequence() + seq.db_insert_record(EnergyRecord(date_time=base.add(minutes=3), power_w=111.0)) + seq.db_insert_record(EnergyRecord(date_time=base.add(minutes=4), power_w=222.0)) + # Anchor at now-1s: makes db_max ≈ now so cutoff = now - 30min + seq.db_insert_record(EnergyRecord(date_time=now.subtract(seconds=1), power_w=0.0)) + seq.db_save_records() + seq._db_compact_tier(to_duration("30 minutes"), to_duration("10 minutes")) + snapped_epoch = int(base.timestamp()) + snapped = [ + r for r in seq.records + if r.date_time is not None and int(r.date_time.timestamp()) == snapped_epoch + ] + assert len(snapped) == 1, ":03 and :04 must floor-snap into one :00 record" + assert snapped[0].power_w == pytest.approx(111.0), "Earlier record's value must win" + + def test_sparse_collision_with_existing_aligned_record(self): + """A misaligned record that snaps onto an already-aligned record must merge + into it without raising ValueError. The aligned record's existing values win. + + :00 (aligned, power_w=50, price_eur=None) and :03 (misaligned, + power_w=None, price_eur=0.30) both map to :00. Result: power_w=50 + (aligned wins) and price_eur=0.30 (filled from :03). + """ + now = to_datetime().in_timezone("UTC") + interval_sec = 600 + # base must be far enough back that all records (+17min max) land before + # window_end = floor(now - 30min, 600). Use now - 52min. 
+ raw_base = now.subtract(minutes=52).set(second=0, microsecond=0) + base = raw_base.subtract(seconds=int(raw_base.timestamp()) % interval_sec) + + seq = EnergySequence() + seq.db_insert_record(EnergyRecord(date_time=base, + power_w=50.0, price_eur=None)) + seq.db_insert_record(EnergyRecord(date_time=base.add(minutes=3), + power_w=None, price_eur=0.30)) + seq.db_insert_record(EnergyRecord(date_time=base.add(minutes=13), power_w=10.0)) + seq.db_insert_record(EnergyRecord(date_time=base.add(minutes=17), power_w=20.0)) + # Anchor: db_max ≈ now → cutoff = now - 30min, after all test records + seq.db_insert_record(EnergyRecord(date_time=now.subtract(seconds=1), power_w=0.0)) + seq.db_save_records() + + # Must not raise ValueError + seq._db_compact_tier(to_duration("30 minutes"), to_duration("10 minutes")) + + snapped_epoch = int(base.timestamp()) + snapped = [ + r for r in seq.records + if r.date_time is not None and int(r.date_time.timestamp()) == snapped_epoch + ] + assert len(snapped) == 1, ":00 and :03 must merge into one :00 record" + rec = snapped[0] + assert rec.power_w == pytest.approx(50.0), "Aligned record's power_w must win" + assert rec.price_eur == pytest.approx(0.30), ":03 record's price_eur must fill in" + assert rec.date_time is not None + assert int(rec.date_time.timestamp()) % interval_sec == 0 + + def test_sparse_no_duplicate_timestamps_after_collision(self): + """After collision merging, no duplicate timestamps must remain. + + Three records at :02, :03, :04 all round to :00 with a 10-min interval. + Together with a record at :13 this gives existing_count(4) > + resampled_count(3) so the snapping path is entered. + """ + now = to_datetime().in_timezone("UTC") + interval_sec = 600 + # base must be far enough back that all records (+17min max) land before + # window_end = floor(now - 30min, 600). Use now - 52min. 
+ raw_base = now.subtract(minutes=52).set(second=0, microsecond=0) + base = raw_base.subtract(seconds=int(raw_base.timestamp()) % interval_sec) + + seq = EnergySequence() + for offset_min in [2, 3, 4]: # all snap to :00 + seq.db_insert_record(EnergyRecord( + date_time=base.add(minutes=offset_min), power_w=float(offset_min) + )) + seq.db_insert_record(EnergyRecord(date_time=base.add(minutes=13), power_w=10.0)) + # Anchor: db_max ≈ now → cutoff = now - 30min, after all test records + seq.db_insert_record(EnergyRecord(date_time=now.subtract(seconds=1), power_w=0.0)) + seq.db_save_records() + + seq._db_compact_tier(to_duration("30 minutes"), to_duration("10 minutes")) + + timestamps = [ + int(r.date_time.timestamp()) + for r in seq.records + if r.date_time is not None + ] + assert len(timestamps) == len(set(timestamps)), "Duplicate timestamps after collision merge" + + # ------------------------------------------------------------------ + # Existing tier-skip tests (unchanged semantics) + # ------------------------------------------------------------------ + + def test_hourly_data_skips_1h_tier(self): + """Data already at 1-hour resolution and aligned must not be re-compacted.""" + seq = EnergySequence() + now = to_datetime().in_timezone("UTC") + # Use an hour-aligned base so records are on clean boundaries + base = now.subtract(weeks=3).set(minute=0, second=0, microsecond=0) + + _fill_sequence(seq, base, count=3 * 7 * 24, interval_minutes=60) + + before = seq.db_count_records() + deleted = seq._db_compact_tier(to_duration("14 days"), to_duration("1 hour")) + + assert deleted == 0 + assert seq.db_count_records() == before + + def test_15min_data_younger_than_2weeks_skips_1h_tier(self): + """15-min data between 2h and 2weeks old must NOT be compacted by the 1h tier.""" + seq = EnergySequence() + now = to_datetime().in_timezone("UTC") + base = now.subtract(weeks=1).set(minute=0, second=0, microsecond=0) + _fill_sequence(seq, base, count=7 * 24 * 4, interval_minutes=15) + 
+ before = seq.db_count_records() + deleted = seq._db_compact_tier(to_duration("14 days"), to_duration("1 hour")) + + assert deleted == 0 + assert seq.db_count_records() == before diff --git a/tests/test_database.py b/tests/test_database.py new file mode 100644 index 0000000..1e555cf --- /dev/null +++ b/tests/test_database.py @@ -0,0 +1,1148 @@ +"""Pytest tests for database persistence module. + +Tests the abstract Database interface and concrete implementations (LMDB, SQLite). +Also tests the database integration with DataSequence/DataProvider classes via +DatabaseRecordProtocolMixin. + +Design constraints honoured by these tests: +- DatabaseRecordProtocolMixin subclasses are singletons; tests reset state via + _db_reset_state() helpers rather than re-instantiating. +- db_save_records() has no clear_memory or start/end parameters; memory management + is separate from persistence. +- db_delete_records() has no clear_memory parameter. +- _db_ensure_loaded() is private; public callers use db_iterate_records() or + db_load_records() which trigger loading internally. +- db_count_records() correctly combines storage_count + new_count - pending_deletes. +- db_vacuum() end_timestamp is already exclusive; no +1ms offset applied. +- db_save_records() returns saved_count + deleted_count. 
+""" + +import pickle +import shutil +import tempfile +import time +from pathlib import Path +from typing import Iterator, List, Optional, Type + +import pytest +from pydantic import Field + +from akkudoktoreos.core.coreabc import get_database +from akkudoktoreos.core.dataabc import ( + DataProvider, + DataRecord, + DataSequence, +) +from akkudoktoreos.core.database import ( + Database, + LMDBDatabase, + SQLiteDatabase, +) +from akkudoktoreos.core.databaseabc import ( + DatabaseRecordProtocolLoadPhase, + DatabaseTimestamp, +) +from akkudoktoreos.utils.datetimeutil import ( + DateTime, + Duration, + to_datetime, + to_duration, +) + +# ==================== Helpers ==================== + +def _clear_sequence_state(sequence) -> None: + """Clear runtime DB state without re-instantiating the singleton. + + Does _NOT_ initialize the DB state. + """ + sequence.db_delete_records() + try: + sequence._db_metadata = None + sequence.database().set_metadata(None, namespace=sequence.db_namespace()) + except Exception: + # Database may not be available, just skip + pass + try: + del sequence._db_initialized + except Exception: + # May not be set + pass + +def _reset_sequence_state(sequence) -> None: + """Reset runtime DB state without re-instantiating the singleton.""" + try: + sequence.records = [] + del sequence._db_initialized + except Exception: + # May not be set + pass + sequence._db_ensure_initialized() + + +# ==================== Test Fixtures ==================== + +@pytest.fixture +def temp_dir(): + """Create a temporary directory for test databases.""" + temp_path = Path(tempfile.mkdtemp()) + yield temp_path + shutil.rmtree(temp_path, ignore_errors=True) + + +@pytest.fixture(params=["LMDB", "SQLite"]) +def database_provider(request) -> str: + """Parametrize all database backend tests.""" + return request.param + + +@pytest.fixture +def database_instance(config_eos, database_provider: str) -> Iterator[Database]: + """Open a database instance for testing and close it 
afterwards. + + Note: Database is a singleton — we configure and use it, then restore + the provider to None so subsequent tests start clean. + """ + config_eos.database.compression_level = 6 + config_eos.database.provider = database_provider + db = get_database() + + assert db.is_open is True + assert db.provider_id() == database_provider + + yield db + + # Teardown: close and reset provider so next fixture gets a fresh state + db.close() + config_eos.database.provider = None + + +# ==================== Test Data Models ==================== + +class SampleDataRecord(DataRecord): + """Minimal DataRecord for testing.""" + temperature: float = Field(default=0.0) + humidity: float = Field(default=0.0) + pressure: float = Field(default=0.0) + + +class SampleDataSequence(DataSequence): + """DataSequence subclass with database support.""" + records: List[SampleDataRecord] = Field(default_factory=list) + + @classmethod + def record_class(cls) -> Type[SampleDataRecord]: + return SampleDataRecord + + def db_namespace(self) -> str: + return "SampleDataSequence" + + +class SampleDataProvider(DataProvider): + """DataProvider subclass with database support.""" + records: List[SampleDataRecord] = Field(default_factory=list) + + @classmethod + def record_class(cls) -> Type[SampleDataRecord]: + return SampleDataRecord + + def provider_id(self) -> str: + return "SampleDataProvider" + + def enabled(self) -> bool: + return True + + def _update_data(self, force_update: Optional[bool] = False) -> None: + pass + + def db_namespace(self) -> str: + return "SampleDataProvider" + + +# ==================== Database Backend Tests ==================== + +class TestDatabase: + """Tests for the raw Database interface (both backends).""" + + def test_database_creation(self, config_eos, database_provider): + config_eos.database.compression_level = 6 + config_eos.database.provider = database_provider + db = get_database() + + assert db.is_open is True + assert db.compression is True + assert 
db.compression_level == 6 + # storage_path uses the concrete backend class name + assert db.storage_path == ( + config_eos.general.data_folder_path / "db" / db._db.__class__.__name__.lower() + ) + + def test_database_open_close(self, database_instance): + assert database_instance.is_open is True + assert database_instance._db.connection is not None + + database_instance.close() + assert database_instance._db.is_open is False + + def test_save_and_load_single_record(self, database_instance): + key = b"2024-01-01T00:00:00+00:00" + value = b"test_data_12345" + + database_instance.save_records([(key, value)]) + records = list(database_instance.iterate_records(key, key + b"\xff")) + + assert len(records) == 1 + assert records[0] == (key, value) + + def test_save_multiple_records(self, database_instance): + records = [ + (b"2024-01-01T00:00:00+00:00", b"data1"), + (b"2024-01-02T00:00:00+00:00", b"data2"), + (b"2024-01-03T00:00:00+00:00", b"data3"), + ] + saved = database_instance.save_records(records) + assert saved == len(records) + + loaded = list(database_instance.iterate_records()) + assert len(loaded) == len(records) + for expected, actual in zip(records, loaded): + assert expected == actual + + def test_load_records_with_range(self, database_instance): + records = [ + (b"2024-01-01T00:00:00+00:00", b"data1"), + (b"2024-01-02T00:00:00+00:00", b"data2"), + (b"2024-01-03T00:00:00+00:00", b"data3"), + (b"2024-01-04T00:00:00+00:00", b"data4"), + (b"2024-01-05T00:00:00+00:00", b"data5"), + ] + database_instance.save_records(records) + + # Range is half-open: [2024-01-02, 2024-01-04) + start_key = b"2024-01-02T00:00:00+00:00" + end_key = b"2024-01-04T00:00:00+00:00" + loaded = list(database_instance.iterate_records(start_key, end_key)) + + assert len(loaded) == 2 + assert loaded[0][0] == b"2024-01-02T00:00:00+00:00" + assert loaded[1][0] == b"2024-01-03T00:00:00+00:00" + + def test_delete_record(self, database_instance): + key = b"2024-01-01T00:00:00+00:00" + 
database_instance.save_records([(key, b"test_data")]) + assert database_instance.count_records() == 1 + + deleted = database_instance.delete_records([key]) + assert deleted == 1 + assert database_instance.count_records() == 0 + + # Deleting a non-existent key returns 0 + deleted = database_instance.delete_records([key]) + assert deleted == 0 + + def test_count_records(self, database_instance): + assert database_instance.count_records() == 0 + + for i in range(10): + key = f"2024-01-{i + 1:02d}T00:00:00+00:00".encode() + database_instance.save_records([(key, b"data")]) + + assert database_instance.count_records() == 10 + + def test_get_key_range_empty(self, database_instance): + min_key, max_key = database_instance.get_key_range() + assert min_key is None + assert max_key is None + + def test_get_key_range_with_records(self, database_instance): + keys = [ + b"2024-01-01T00:00:00+00:00", + b"2024-01-05T00:00:00+00:00", + b"2024-01-03T00:00:00+00:00", + ] + for key in keys: + database_instance.save_records([(key, b"data")]) + + min_key, max_key = database_instance.get_key_range() + assert min_key == b"2024-01-01T00:00:00+00:00" + assert max_key == b"2024-01-05T00:00:00+00:00" + + def test_iterate_records_forward(self, database_instance): + keys = [ + b"2024-01-01T00:00:00+00:00", + b"2024-01-02T00:00:00+00:00", + b"2024-01-03T00:00:00+00:00", + ] + for key in keys: + database_instance.save_records([(key, b"data")]) + + result_keys = [k for k, _ in database_instance.iterate_records()] + assert result_keys == keys + + def test_iterate_records_reverse(self, database_instance): + keys = [ + b"2024-01-01T00:00:00+00:00", + b"2024-01-02T00:00:00+00:00", + b"2024-01-03T00:00:00+00:00", + ] + for key in keys: + database_instance.save_records([(key, b"data")]) + + result_keys = [k for k, _ in database_instance.iterate_records(reverse=True)] + assert result_keys == list(reversed(keys)) + + def test_compression_reduces_size(self, config_eos, database_provider): + large_data = 
b"A" * 10_000 + + config_eos.database.provider = database_provider + config_eos.database.compression_level = 9 + compressed = get_database().serialize_data(large_data) + assert get_database().deserialize_data(compressed) == large_data + + config_eos.database.compression_level = 0 + uncompressed = get_database().serialize_data(large_data) + assert get_database().deserialize_data(uncompressed) == large_data + + assert len(compressed) < len(uncompressed) + + def test_flush(self, database_instance): + key = b"2024-01-01T00:00:00+00:00" + database_instance.save_records([(key, b"test_data")]) + database_instance.flush() + + loaded = list(database_instance.iterate_records()) + assert len(loaded) == 1 + assert loaded[0] == (key, b"test_data") + + def test_backend_stats(self, database_instance): + stats = database_instance.get_backend_stats() + assert isinstance(stats, dict) + assert "backend" in stats + + for i in range(10): + key = f"2024-01-{i + 1:02d}T00:00:00+00:00".encode() + database_instance.save_records([(key, b"data" * 100)]) + + stats = database_instance.get_backend_stats() + assert stats is not None + + def test_metadata_excluded_from_count(self, database_instance): + """Metadata record stored under DATABASE_METADATA_KEY must not appear in count.""" + # Save a normal record + database_instance.save_records([(b"2024-01-01T00:00:00+00:00", b"data")]) + count = database_instance.count_records() + assert count == 1 # metadata excluded by backend implementation + + +# ==================== DatabaseRecordProtocolMixin Tests ==================== + +class TestDataSequenceDatabaseProtocol: + """Tests for DatabaseRecordProtocolMixin via SampleDataSequence.""" + + def test_db_enabled_when_db_open(self, database_instance): + sequence = SampleDataSequence() + _reset_sequence_state(sequence) + assert sequence.db_enabled is True + + def test_db_disabled_when_db_closed(self, config_eos): + config_eos.database.provider = None + sequence = SampleDataSequence() + 
_reset_sequence_state(sequence) + assert sequence.db_enabled is False + + def test_insert_and_save_records(self, database_instance): + sequence = SampleDataSequence() + _reset_sequence_state(sequence) + base_time = to_datetime("2024-01-01T00:00:00Z") + + for i in range(10): + sequence.db_insert_record( + SampleDataRecord(date_time=base_time.add(hours=i), temperature=20.0 + i) + ) + + # All 10 are dirty/new, none persisted yet + assert len(sequence.records) == 10 + assert len(sequence._db_new_timestamps) == 10 + + saved = sequence.db_save_records() + assert saved == 10 # 10 inserts + 0 deletes + assert len(sequence._db_dirty_timestamps) == 0 + assert len(sequence._db_new_timestamps) == 0 + + def test_save_returns_insert_plus_delete_count(self, database_instance): + """db_save_records() return value = saved_inserts + deleted_count.""" + sequence = SampleDataSequence() + _reset_sequence_state(sequence) + base_time = to_datetime("2024-01-01T00:00:00Z") + + for i in range(5): + sequence.db_insert_record( + SampleDataRecord(date_time=base_time.add(hours=i), temperature=float(i)) + ) + # Persist the 5 records + sequence.db_save_records() + + # Delete 2 of them + db_start = DatabaseTimestamp.from_datetime(base_time.add(hours=2)) + db_end = DatabaseTimestamp.from_datetime(base_time.add(hours=4)) + deleted = sequence.db_delete_records(start_timestamp=db_start, end_timestamp=db_end) + # Insert 3 new ones + for i in range(10, 13): + sequence.db_insert_record( + SampleDataRecord(date_time=base_time.add(hours=i), temperature=float(i)) + ) + + result = sequence.db_save_records() + # 3 inserts + 2 deletes = 5 + assert result == 5 + + def test_load_records_from_db(self, database_instance): + sequence = SampleDataSequence() + _reset_sequence_state(sequence) + base_time = to_datetime("2024-01-01T00:00:00Z") + + for i in range(10): + sequence.db_insert_record( + SampleDataRecord(date_time=base_time.add(hours=i), temperature=20.0 + i) + ) + sequence.db_save_records() + + # Clear 
memory, then reload from DB + _reset_sequence_state(sequence) + loaded = sequence.db_load_records() + + assert loaded == 10 + assert len(sequence.records) == 10 + for i, record in enumerate(sequence.records): + assert record.temperature == 20.0 + i + + def test_load_records_with_range(self, database_instance): + sequence = SampleDataSequence() + _reset_sequence_state(sequence) + base_time = to_datetime("2024-01-01T00:00:00Z") + + for i in range(10): + sequence.db_insert_record( + SampleDataRecord(date_time=base_time.add(hours=i), temperature=20.0 + i) + ) + sequence.db_save_records() + _reset_sequence_state(sequence) + + # Load [hours=3, hours=7) → 4 records (3, 4, 5, 6) + db_start = DatabaseTimestamp.from_datetime(base_time.add(hours=3)) + db_end = DatabaseTimestamp.from_datetime(base_time.add(hours=7)) + loaded = sequence.db_load_records(start_timestamp=db_start, end_timestamp=db_end) + assert loaded == 4 + assert sequence.records[0].temperature == 23.0 + assert sequence.records[-1].temperature == 26.0 + + def test_iterate_records_triggers_lazy_load(self, database_instance): + sequence = SampleDataSequence() + _reset_sequence_state(sequence) + base_time = to_datetime("2024-01-01T00:00:00Z") + + for i in range(10): + sequence.db_insert_record( + SampleDataRecord(date_time=base_time.add(hours=i), temperature=20.0 + i) + ) + sequence.db_save_records() + _reset_sequence_state(sequence) + + # db_iterate_records calls _db_ensure_loaded internally + db_start = DatabaseTimestamp.from_datetime(base_time.add(hours=2)) + db_end = DatabaseTimestamp.from_datetime(base_time.add(hours=5)) + records = list(sequence.db_iterate_records(start_timestamp=db_start, end_timestamp=db_end)) + assert len(records) == 3 + assert all(base_time.add(hours=2) <= r.date_time < base_time.add(hours=5) for r in records) + + def test_delete_records(self, database_instance): + sequence = SampleDataSequence() + _reset_sequence_state(sequence) + base_time = to_datetime("2024-01-01T00:00:00Z") + + for i 
in range(6): + sequence.db_insert_record( + SampleDataRecord(date_time=base_time.add(hours=i), temperature=20.0) + ) + sequence.db_save_records() + + db_start = DatabaseTimestamp.from_datetime(base_time.add(hours=2)) + db_end = DatabaseTimestamp.from_datetime(base_time.add(hours=5)) + deleted = sequence.db_delete_records(start_timestamp=db_start, end_timestamp=db_end) + assert deleted == 3 + + # Persist the deletions + sequence.db_save_records() + + _reset_sequence_state(sequence) + sequence.db_load_records() + assert len(sequence.records) == 3 + + def test_delete_tombstone_prevents_resurrection(self, database_instance): + """Deleted records must not re-appear when db_load_records is called.""" + sequence = SampleDataSequence() + _reset_sequence_state(sequence) + base_time = to_datetime("2024-01-01T00:00:00Z") + + for i in range(3): + sequence.db_insert_record( + SampleDataRecord(date_time=base_time.add(hours=i), temperature=float(i)) + ) + sequence.db_save_records() + + # Delete middle record + db_start = DatabaseTimestamp.from_datetime(base_time.add(hours=1)) + db_end = DatabaseTimestamp.from_datetime(base_time.add(hours=2)) + deleted = sequence.db_delete_records(start_timestamp=db_start, end_timestamp=db_end) + assert deleted == 1 + + # Do NOT persist yet — tombstone lives only in memory + # Loading should not resurrect the tombstoned record + loaded = sequence.db_load_records() + assert all(r.date_time != base_time.add(hours=1) for r in sequence.records) + + def test_insert_after_delete_clears_tombstone(self, database_instance): + """Re-inserting a deleted datetime must clear its tombstone.""" + sequence = SampleDataSequence() + _reset_sequence_state(sequence) + base_time = to_datetime("2024-01-01T00:00:00Z") + dt = base_time.add(hours=5) + + sequence.db_insert_record(SampleDataRecord(date_time=dt, temperature=10.0)) + sequence.db_save_records() + + db_start = DatabaseTimestamp.from_datetime(dt) + db_end = sequence._db_timestamp_after(db_start) + deleted = 
sequence.db_delete_records(start_timestamp=db_start, end_timestamp=db_end) + assert deleted == 1 + + sequence.db_save_records() + + # Re-insert the same datetime + sequence.db_insert_record(SampleDataRecord(date_time=dt, temperature=99.0)) + assert dt not in sequence._db_deleted_timestamps + sequence.db_save_records() + + _reset_sequence_state(sequence) + sequence.db_load_records() + assert any(r.date_time == dt and r.temperature == 99.0 for r in sequence.records) + + def test_db_count_records_memory_only(self): + """When db is disabled, count reflects memory only.""" + sequence = SampleDataSequence() + _reset_sequence_state(sequence) + + # Without a live DB, db_enabled is False + if sequence.db_enabled: + pytest.skip("DB is open; this test requires it to be closed") + + base_time = to_datetime("2024-01-01T00:00:00Z") + for i in range(5): + sequence.db_insert_record( + SampleDataRecord(date_time=base_time.add(hours=i), temperature=float(i)), + mark_dirty=False, + ) + assert sequence.db_count_records() == 5 + + def test_db_count_records_combined(self, database_instance): + """db_count_records = storage + new_unpersisted - pending_deletes.""" + sequence = SampleDataSequence() + _reset_sequence_state(sequence) + base_time = to_datetime("2024-01-01T00:00:00Z") + + # Persist 10 records + for i in range(10): + sequence.db_insert_record( + SampleDataRecord(date_time=base_time.add(hours=i), temperature=float(i)) + ) + sequence.db_save_records() + + # Add 3 new unpersisted records + for i in range(10, 13): + sequence.db_insert_record( + SampleDataRecord(date_time=base_time.add(hours=i), temperature=float(i)) + ) + + # Delete 2 persisted records (not yet saved) + db_start = DatabaseTimestamp.from_datetime(base_time.add(hours=0)) + db_end = DatabaseTimestamp.from_datetime(base_time.add(hours=2)) + deleted = sequence.db_delete_records(start_timestamp=db_start, end_timestamp=db_end) + assert deleted == 2 + + # storage=10, new=3, pending_deletes=2 → expected=11 + assert 
sequence.db_count_records() == 11 + + def test_db_timestamp_range_empty(self, database_instance): + sequence = SampleDataSequence() + _reset_sequence_state(sequence) + min_dt, max_dt = sequence.db_timestamp_range() + assert min_dt is None + assert max_dt is None + + def test_db_timestamp_range_with_records(self, database_instance): + sequence = SampleDataSequence() + _reset_sequence_state(sequence) + base_time = to_datetime("2024-01-01T00:00:00Z") + + for hours in [0, 5, 10]: + sequence.db_insert_record( + SampleDataRecord(date_time=base_time.add(hours=hours), temperature=20.0) + ) + sequence.db_save_records() + _reset_sequence_state(sequence) + + min_dt, max_dt = sequence.db_timestamp_range() + assert min_dt == DatabaseTimestamp.from_datetime(base_time) + assert max_dt == DatabaseTimestamp.from_datetime(base_time.add(hours=10)) + + def test_db_mark_dirty_triggers_save(self, database_instance): + """Marking a record dirty causes it to be re-saved.""" + sequence = SampleDataSequence() + _reset_sequence_state(sequence) + base_time = to_datetime("2024-01-01T00:00:00Z") + + record = SampleDataRecord(date_time=base_time, temperature=20.0) + sequence.db_insert_record(record) + sequence.db_save_records() + + # Mutate and mark dirty + record.temperature = 99.0 + sequence.db_mark_dirty_record(record) + sequence.db_save_records() + + # Reload and verify update was persisted + _reset_sequence_state(sequence) + sequence.db_load_records() + assert sequence.records[0].temperature == 99.0 + + def test_db_vacuum_keep_hours(self, database_instance): + """db_vacuum(keep_hours=N) retains only the last N hours of records.""" + sequence = SampleDataSequence() + _reset_sequence_state(sequence) + base_time = to_datetime("2024-01-01T00:00:00Z") + + # 240 hourly records = 10 days + for i in range(240): + sequence.db_insert_record( + SampleDataRecord(date_time=base_time.add(hours=i), temperature=20.0) + ) + sequence.db_save_records() + _reset_sequence_state(sequence) + + keep_hours = 5 * 24 
# keep last 5 days + deleted = sequence.db_vacuum(keep_hours=keep_hours) + + assert deleted == 240 - keep_hours + assert sequence.db_count_records() == keep_hours + + def test_db_vacuum_keep_timestamp(self, database_instance): + """db_vacuum(keep_timestamp=T) deletes everything before T (exclusive).""" + sequence = SampleDataSequence() + _reset_sequence_state(sequence) + base_time = to_datetime("2024-01-01T00:00:00Z") + + for i in range(10): + sequence.db_insert_record( + SampleDataRecord(date_time=base_time.add(hours=i), temperature=float(i)) + ) + sequence.db_save_records() + _reset_sequence_state(sequence) + + # Keep from hours=5 onward — delete [0, 5), i.e. 5 records + cutoff = base_time.add(hours=5) + db_cutoff = DatabaseTimestamp.from_datetime(cutoff) + deleted = sequence.db_vacuum(keep_timestamp=db_cutoff) + + assert deleted == 5 + assert sequence.db_count_records() == 5 + + # Verify the boundary record (hours=5) was NOT deleted + _reset_sequence_state(sequence) + sequence.db_load_records() + assert any(r.date_time == cutoff for r in sequence.records) + + def test_db_vacuum_no_argument(self, database_instance, config_eos): + sequence = SampleDataSequence() + _reset_sequence_state(sequence) + base_time = to_datetime("2024-01-01T00:00:00Z") + + record = SampleDataRecord(date_time=base_time, temperature=20.0) + sequence.db_insert_record(record) + sequence.db_save_records() + + config_eos.database.keep_duration_h = None + assert sequence.db_vacuum() == 0 + + config_eos.database.keep_duration_h = 0 + assert sequence.db_vacuum() == 1 + + def test_db_vacuum_keep_hours_zero_deletes_all(self, database_instance): + """keep_hours=0 should delete all records.""" + sequence = SampleDataSequence() + _reset_sequence_state(sequence) + base_time = to_datetime("2024-01-01T00:00:00Z") + + for i in range(5): + sequence.db_insert_record( + SampleDataRecord(date_time=base_time.add(hours=i), temperature=float(i)) + ) + sequence.db_save_records() + _reset_sequence_state(sequence) + 
+ deleted = sequence.db_vacuum(keep_hours=0) + assert deleted == 5 + assert sequence.db_count_records() == 0 + + def test_db_get_stats(self, database_instance): + sequence = SampleDataSequence() + _reset_sequence_state(sequence) + stats = sequence.db_get_stats() + + assert stats["enabled"] is True + assert "backend" in stats + assert "path" in stats + assert "memory_records" in stats + assert "total_records" in stats + assert "compression_enabled" in stats + assert "timestamp_range" in stats + assert stats["timestamp_range"]["min"] == "None" + assert stats["timestamp_range"]["max"] == "None" + + def test_db_get_stats_disabled(self, config_eos): + config_eos.database.provider = None + sequence = SampleDataSequence() + _reset_sequence_state(sequence) + stats = sequence.db_get_stats() + assert stats == {"enabled": False} + + def test_lazy_load_phase_none_to_initial(self, database_instance): + """Phase transitions from NONE to INITIAL when a range is loaded via ensure_loaded.""" + sequence = SampleDataSequence() + _reset_sequence_state(sequence) + assert sequence._db_load_phase is DatabaseRecordProtocolLoadPhase.NONE + + base_time = to_datetime("2024-01-01T00:00:00Z") + for i in range(10): + sequence.db_insert_record( + SampleDataRecord(date_time=base_time.add(hours=i), temperature=float(i)) + ) + sequence.db_save_records() + _reset_sequence_state(sequence) + + # Use db_iterate_records — it calls _db_ensure_loaded which owns phase transitions + db_start = DatabaseTimestamp.from_datetime(base_time.add(hours=3)) + db_end = DatabaseTimestamp.from_datetime(base_time.add(hours=7)) + list(sequence.db_iterate_records(start_timestamp=db_start, end_timestamp=db_end)) + + assert sequence._db_load_phase is DatabaseRecordProtocolLoadPhase.INITIAL + + def test_lazy_load_phase_initial_to_full(self, database_instance): + """Phase transitions from INITIAL to FULL when iterate is called without range.""" + sequence = SampleDataSequence() + _reset_sequence_state(sequence) + base_time = 
to_datetime("2024-01-01T00:00:00Z") + + for i in range(10): + sequence.db_insert_record( + SampleDataRecord(date_time=base_time.add(hours=i), temperature=float(i)) + ) + sequence.db_save_records() + _reset_sequence_state(sequence) + + # Load partial range → INITIAL + # Use db_iterate_records — it calls _db_ensure_loaded which owns phase transitions + db_start = DatabaseTimestamp.from_datetime(base_time.add(hours=3)) + db_end = DatabaseTimestamp.from_datetime(base_time.add(hours=7)) + list(sequence.db_iterate_records(start_timestamp=db_start, end_timestamp=db_end)) + assert sequence._db_load_phase is DatabaseRecordProtocolLoadPhase.INITIAL + + # Iterate without range → escalates to FULL + list(sequence.db_iterate_records()) + assert sequence._db_load_phase is DatabaseRecordProtocolLoadPhase.FULL + + def test_range_covered_skips_redundant_load(self, database_instance): + """_db_range_covered prevents a second DB query for the same range.""" + sequence = SampleDataSequence() + _reset_sequence_state(sequence) + base_time = to_datetime("2024-01-01T00:00:00Z") + + for i in range(10): + sequence.db_insert_record( + SampleDataRecord(date_time=base_time.add(hours=i), temperature=float(i)) + ) + sequence.db_save_records() + _reset_sequence_state(sequence) + + db_start = DatabaseTimestamp.from_datetime(base_time.add(hours=2)) + db_end = DatabaseTimestamp.from_datetime(base_time.add(hours=8)) + list(sequence.db_iterate_records(start_timestamp=db_start, end_timestamp=db_end)) + + # Loaded range is now set + assert sequence._db_loaded_range is not None + assert sequence._db_range_covered(db_start, db_end) is True + + db_start = DatabaseTimestamp.from_datetime(base_time.add(hours=0)) + db_end = DatabaseTimestamp.from_datetime(base_time.add(hours=20)) + assert sequence._db_range_covered(db_start, db_end) is False + + def test_loaded_range_not_clobbered_by_expansion(self, database_instance): + """Expanding left or right must not narrow the tracked loaded range.""" + sequence = 
SampleDataSequence() + _reset_sequence_state(sequence) + base_time = to_datetime("2024-01-01T00:00:00Z") + + for i in range(24): + sequence.db_insert_record( + SampleDataRecord(date_time=base_time.add(hours=i), temperature=float(i)) + ) + sequence.db_save_records() + _reset_sequence_state(sequence) + + # Initial window: hours 8–16 + db_start = DatabaseTimestamp.from_datetime(base_time.add(hours=8)) + db_end = DatabaseTimestamp.from_datetime(base_time.add(hours=16)) + list(sequence.db_iterate_records(start_timestamp=db_start, end_timestamp=db_end)) + + assert sequence._db_loaded_range is not None + initial_start, initial_end = sequence._db_loaded_range + assert initial_start is not None + assert initial_end is not None + + # Expand left: load hours 4–8 + db_start = DatabaseTimestamp.from_datetime(base_time.add(hours=4)) + db_end = DatabaseTimestamp.from_datetime(base_time.add(hours=16)) + list(sequence.db_iterate_records(start_timestamp=db_start, end_timestamp=db_end)) + + assert sequence._db_loaded_range is not None + expanded_start, expanded_end = sequence._db_loaded_range + assert expanded_start is not None + assert expanded_end is not None + + # Left boundary must have moved left; right must not have shrunk + assert expanded_start <= initial_start + assert expanded_end >= initial_end + + def test_duplicate_insert_raises(self, database_instance): + sequence = SampleDataSequence() + _reset_sequence_state(sequence) + dt = to_datetime("2024-01-01T00:00:00Z") + + sequence.db_insert_record(SampleDataRecord(date_time=dt, temperature=1.0)) + with pytest.raises(ValueError, match="Duplicate timestamp"): + sequence.db_insert_record(SampleDataRecord(date_time=dt, temperature=2.0)) + + def test_autosave_delegates_to_save_records(self, database_instance): + """db_autosave() is equivalent to db_save_records().""" + sequence = SampleDataSequence() + _reset_sequence_state(sequence) + base_time = to_datetime("2024-01-01T00:00:00Z") + + for i in range(3): + 
sequence.db_insert_record( + SampleDataRecord(date_time=base_time.add(hours=i), temperature=float(i)) + ) + + saved = sequence.db_autosave() + assert saved == 3 + assert len(sequence._db_dirty_timestamps) == 0 + + def test_metadata_round_trip(self, database_instance): + """Metadata can be saved and loaded back correctly.""" + sequence = SampleDataSequence() + + _clear_sequence_state(sequence) + assert sequence._db_metadata is None + + _reset_sequence_state(sequence) + assert sequence._db_metadata is not None + created = sequence._db_metadata["created"] + assert sequence._db_metadata["version"] == 1 + + _reset_sequence_state(sequence) + assert sequence._db_metadata is not None + assert sequence._db_metadata["created"] == created + assert sequence._db_metadata["version"] == 1 + + def test_initial_load_window_respected(self, database_instance): + """db_initial_time_window limits the initial load from DB.""" + + class WindowedSequence(SampleDataSequence): + def db_namespace(self) -> str: + return "WindowedSequence" + + def db_initial_time_window(self) -> Optional[Duration]: + return to_duration("2 hours") + + sequence = WindowedSequence() + _reset_sequence_state(sequence) + base_time = to_datetime("2024-01-01T12:00:00Z") + + # Store 24 hourly records centred on base_time + for i in range(24): + sequence.db_insert_record( + SampleDataRecord( + date_time=base_time.subtract(hours=12).add(hours=i), + temperature=float(i), + ) + ) + sequence.db_save_records() + + _reset_sequence_state(sequence) + + # Trigger initial window load centred on base_time + sequence.config.database.initial_load_window_h = 2 + db_center = DatabaseTimestamp.from_datetime(base_time) + sequence._db_load_initial_window(center_timestamp=db_center) + + # Only records within ±2h of base_time should be in memory + assert len(sequence.records) <= 5 # at most 4h window = 4–5 records + assert sequence._db_load_phase is DatabaseRecordProtocolLoadPhase.INITIAL + + +# ==================== Backend-Specific Tests 
==================== + +class TestLMDBDatabase: + """LMDB-specific tests.""" + + def test_lmdb_compact(self, config_eos): + config_eos.database.compression_level = 0 + config_eos.database.provider = "LMDB" + db = get_database() + assert db.is_open + + for i in range(1000): + key = f"2024-01-01T{i:06d}+00:00".encode() + db.save_records([(key, b"X" * 1000)]) + + for i in range(500): + key = f"2024-01-01T{i:06d}+00:00".encode() + db.delete_records([key]) + + lmdb = db._database() + assert isinstance(lmdb, LMDBDatabase) + lmdb.compact() + assert db.count_records() == 500 + db.close() + + def test_lmdb_namespace_isolation(self, config_eos): + """Records in different namespaces must not interfere.""" + config_eos.database.provider = "LMDB" + db = get_database() + assert db.is_open + + key = b"2024-01-01T00:00:00+00:00" + db.save_records([(key, b"ns_a_data")], namespace="ns_a") + db.save_records([(key, b"ns_b_data")], namespace="ns_b") + + ns_a = list(db.iterate_records(namespace="ns_a")) + ns_b = list(db.iterate_records(namespace="ns_b")) + + assert ns_a[0][1] == b"ns_a_data" + assert ns_b[0][1] == b"ns_b_data" + db.close() + + +class TestSQLiteDatabase: + """SQLite-specific tests.""" + + def test_sqlite_vacuum(self, config_eos): + config_eos.database.compression_level = 0 + config_eos.database.provider = "SQLite" + db = get_database() + assert db.is_open + + records = [ + (f"2024-01-{i + 1:02d}T00:00:00+00:00".encode(), b"data" * 100) + for i in range(100) + ] + db.save_records(records) + + keys_to_delete = [f"2024-01-{i + 1:02d}T00:00:00+00:00".encode() for i in range(50)] + db.delete_records(keys_to_delete) + + sqlitedb = db._database() + assert isinstance(sqlitedb, SQLiteDatabase) + sqlitedb.vacuum() + + assert db.count_records() == 50 + db.close() + + def test_sqlite_namespace_isolation(self, config_eos): + """Records in different namespaces must not interfere.""" + config_eos.database.provider = "SQLite" + db = get_database() + assert db.is_open + + key = 
b"2024-01-01T00:00:00+00:00" + db.save_records([(key, b"ns_a_data")], namespace="ns_a") + db.save_records([(key, b"ns_b_data")], namespace="ns_b") + + ns_a = list(db.iterate_records(namespace="ns_a")) + ns_b = list(db.iterate_records(namespace="ns_b")) + + assert ns_a[0][1] == b"ns_a_data" + assert ns_b[0][1] == b"ns_b_data" + db.close() + + +# ==================== Integration Tests ==================== + +class TestIntegration: + """Full end-to-end workflow tests.""" + + def test_full_workflow(self, config_eos, database_instance): + """Save → partial load → update → vacuum → verify.""" + sequence = SampleDataSequence() + _reset_sequence_state(sequence) + base_time = to_datetime("2024-01-01T00:00:00Z") + + # Step 1: Insert 100 records and persist + for i in range(100): + sequence.db_insert_record( + SampleDataRecord( + date_time=base_time.add(hours=i), + temperature=20.0 + i * 0.1, + humidity=60.0, + ) + ) + sequence.db_save_records() + + storage_count = sequence.database.count_records(namespace="SampleDataSequence") + assert storage_count == 100 + assert sequence.db_count_records() == 100 + + # Step 2: Clear memory and load a specific range + _reset_sequence_state(sequence) + db_start = DatabaseTimestamp.from_datetime(base_time.add(hours=20)) + db_end = DatabaseTimestamp.from_datetime(base_time.add(hours=40)) + loaded = sequence.db_load_records(db_start, db_end) + assert loaded == 20 + assert len(sequence.records) == 20 + + # Step 3: Update records in memory and persist + for record in sequence.records: + record.humidity = 75.0 + sequence.db_mark_dirty_record(record) + sequence.db_save_records() + + # Step 4: Reload the range and verify updates + _reset_sequence_state(sequence) + sequence.db_load_records(db_start, db_end) + assert all(r.humidity == 75.0 for r in sequence.records) + + # Step 5: Vacuum — keep from hours=75 onward (delete first 75) + db_cutoff = DatabaseTimestamp.from_datetime(base_time.add(hours=75)) + deleted = 
sequence.db_vacuum(keep_timestamp=db_cutoff) + assert deleted == 75 + assert sequence.db_count_records() == 25 + + # Step 6: Stats reflect vacuum result + _reset_sequence_state(sequence) + stats = sequence.db_get_stats() + assert stats["total_records"] == 25 + + def test_error_handling_db_disabled(self, config_eos): + """Operations on a disabled DB raise clearly.""" + config_eos.database.provider = None + sequence = SampleDataSequence() + _reset_sequence_state(sequence) + + assert sequence.db_enabled is False + + # Save is a no-op and returns 0 when disabled — no RuntimeError + # (mixin returns 0 early when not enabled) + result = sequence.db_save_records() + assert result == 0 + + def test_persistence_across_resets(self, database_instance): + """Data written in one memory session is available after reset.""" + sequence = SampleDataSequence() + _reset_sequence_state(sequence) + base_time = to_datetime("2024-06-01T00:00:00Z") + + for i in range(20): + sequence.db_insert_record( + SampleDataRecord(date_time=base_time.add(hours=i), temperature=float(i)) + ) + sequence.db_save_records() + + # Simulate a restart: reset memory state + _reset_sequence_state(sequence) + assert len(sequence.records) == 0 + + loaded = sequence.db_load_records() + assert loaded == 20 + assert sequence.records[0].temperature == 0.0 + assert sequence.records[-1].temperature == 19.0 + + +# ==================== Performance Tests ==================== + +class TestPerformance: + """Throughput benchmarks — not correctness tests.""" + + def test_insert_throughput(self, config_eos, database_instance): + sequence = SampleDataSequence() + _reset_sequence_state(sequence) + base_time = to_datetime("2024-01-01T00:00:00Z") + n = 10_000 + + start = time.perf_counter() + for i in range(n): + sequence.db_insert_record( + SampleDataRecord( + date_time=base_time.add(minutes=i), + temperature=20.0 + (i % 100) * 0.1, + ) + ) + insert_duration = time.perf_counter() - start + print(f"\nInserted {n} records in 
{insert_duration:.2f}s " + f"({n / insert_duration:.0f} rec/s)") + + assert len(sequence.records) == n + + def test_save_throughput(self, config_eos, database_instance): + sequence = SampleDataSequence() + _reset_sequence_state(sequence) + base_time = to_datetime("2024-01-01T00:00:00Z") + n = 10_000 + + for i in range(n): + sequence.db_insert_record( + SampleDataRecord( + date_time=base_time.add(minutes=i), + temperature=20.0 + (i % 100) * 0.1, + ) + ) + + start = time.perf_counter() + saved = sequence.db_save_records() + save_duration = time.perf_counter() - start + + assert saved == n + print(f"\nSaved {n} records in {save_duration:.2f}s " + f"({n / save_duration:.0f} rec/s)") + + def test_load_throughput(self, config_eos, database_instance): + sequence = SampleDataSequence() + _reset_sequence_state(sequence) + base_time = to_datetime("2024-01-01T00:00:00Z") + n = 10_000 + + for i in range(n): + sequence.db_insert_record( + SampleDataRecord( + date_time=base_time.add(minutes=i), + temperature=20.0 + (i % 100) * 0.1, + ) + ) + sequence.db_save_records() + _reset_sequence_state(sequence) + + start = time.perf_counter() + loaded = sequence.db_load_records() + load_duration = time.perf_counter() - start + + assert loaded == n + print(f"\nLoaded {n} records in {load_duration:.2f}s " + f"({n / load_duration:.0f} rec/s)") diff --git a/tests/test_databaseabc.py b/tests/test_databaseabc.py new file mode 100644 index 0000000..9dca203 --- /dev/null +++ b/tests/test_databaseabc.py @@ -0,0 +1,888 @@ +from typing import Any, Iterator, Literal, Optional, Type, cast + +import pytest +from numpydantic import NDArray, Shape +from pydantic import BaseModel, Field + +from akkudoktoreos.core.databaseabc import ( + DATABASE_METADATA_KEY, + DatabaseRecordProtocolMixin, + DatabaseTimestamp, + _DatabaseTimestampUnbound, +) +from akkudoktoreos.utils.datetimeutil import ( + DateTime, + Duration, + to_datetime, + to_duration, +) + +# 
--------------------------------------------------------------------------- +# Test record +# --------------------------------------------------------------------------- + + +class SampleRecord(BaseModel): + date_time: Optional[DateTime] = Field( + default=None, json_schema_extra={"description": "DateTime"} + ) + value: Optional[float] = None + + def __getitem__(self, key: str) -> Any: + if key == "date_time": + return self.date_time + if key == "value": + return self.value + assert key is None + return None + + +# --------------------------------------------------------------------------- +# Fake database backend +# --------------------------------------------------------------------------- + + +class SampleDatabase: + def __init__(self): + self._data: dict[Optional[str], dict[bytes, bytes]] = {} + self._metadata: Optional[bytes] = None + self.is_open = True + self.compression = False + self.compression_level = 0 + self.storage_path = "/fake" + + # serialization (pass-through) + + def serialize_data(self, data: bytes) -> bytes: + return data + + def deserialize_data(self, data: bytes) -> bytes: + return data + + # metadata + + def set_metadata(self, metadata: Optional[bytes], *, namespace: Optional[str] = None) -> None: + self._metadata = metadata + + def get_metadata(self, namespace: Optional[str] = None) -> Optional[bytes]: + return self._metadata + + # write + + def save_records( + self, records: list[tuple[bytes, bytes]], namespace: Optional[str] = None + ) -> int: + ns = self._data.setdefault(namespace, {}) + saved = 0 + for key, value in records: + ns[key] = value + saved += 1 + return saved + + def delete_records( + self, keys: Iterator[bytes], namespace: Optional[str] = None + ) -> int: + ns_data = self._data.get(namespace, {}) + deleted = 0 + for key in keys: + if key in ns_data: + del ns_data[key] + deleted += 1 + return deleted + + # read + + def iterate_records( + self, + start_key: Optional[bytes] = None, + end_key: Optional[bytes] = None, + 
namespace: Optional[str] = None, + reverse: bool = False, + ) -> Iterator[tuple[bytes, bytes]]: + items = self._data.get(namespace, {}) + keys = sorted(items, reverse=reverse) + for k in keys: + if k == DATABASE_METADATA_KEY: + continue + if start_key and k < start_key: + continue + if end_key and k >= end_key: + continue + yield k, items[k] + + # stats + + def count_records( + self, + start_key: Optional[bytes] = None, + end_key: Optional[bytes] = None, + *, + namespace: Optional[str] = None, + ) -> int: + items = self._data.get(namespace, {}) + count = 0 + for k in items: + if k == DATABASE_METADATA_KEY: + continue + if start_key and k < start_key: + continue + if end_key and k >= end_key: + continue + count += 1 + return count + + def get_key_range( + self, namespace: Optional[str] = None + ) -> tuple[Optional[bytes], Optional[bytes]]: + items = self._data.get(namespace, {}) + keys = sorted(k for k in items if k != DATABASE_METADATA_KEY) + if not keys: + return None, None + return keys[0], keys[-1] + + def get_backend_stats(self, namespace: Optional[str] = None) -> dict: + return {} + + def flush(self, namespace: Optional[str] = None) -> None: + pass + + +# --------------------------------------------------------------------------- +# Concrete test sequence — minimal, no Pydantic / singleton overhead +# --------------------------------------------------------------------------- + + +class SampleSequence(DatabaseRecordProtocolMixin[SampleRecord]): + """Minimal concrete implementation for unit-testing the mixin.""" + + def __init__(self): + self.records: list[SampleRecord] = [] + self._db_record_index: dict[DatabaseTimestamp, SampleRecord] = {} + self._db_sorted_timestamps: list[DatabaseTimestamp] = [] + self._db_dirty_timestamps: set[DatabaseTimestamp] = set() + self._db_new_timestamps: set[DatabaseTimestamp] = set() + self._db_deleted_timestamps: set[DatabaseTimestamp] = set() + self._db_initialized: bool = True + self._db_storage_initialized: bool = False + 
self._db_metadata: Optional[dict] = None + self._db_loaded_range = None + from akkudoktoreos.core.databaseabc import DatabaseRecordProtocolLoadPhase + self._db_load_phase = DatabaseRecordProtocolLoadPhase.NONE + self._db_version: int = 1 + + self.database = SampleDatabase() + self.config = type( + "Cfg", + (), + { + "database": type( + "DBCfg", + (), + { + "auto_save": False, + "compression_level": 0, + "autosave_interval_sec": 10, + "initial_load_window_h": None, + "keep_duration_h": None, + }, + )() + }, + )() + + @classmethod + def record_class(cls) -> Type[SampleRecord]: + return SampleRecord + + def db_namespace(self) -> str: + return "test" + + @property + def record_keys_writable(self) -> list[str]: + """Return writable field names of SampleRecord. + + Required by _db_compact_tier which iterates record_keys_writable + to decide which fields to resample. Must match exactly what + key_to_array accepts — only 'value' here, not 'date_time'. + """ + return ["value"] + + # Override key_to_array for the mixin tests — the full DataSequence + # implementation lives in dataabc.py; here we provide a minimal version + # that resamples the single `value` field to demonstrate compaction. 
+ def key_to_array( + self, + key: str, + start_datetime: Optional[DateTime] = None, + end_datetime: Optional[DateTime] = None, + interval: Optional[Duration] = None, + fill_method: Optional[str] = None, + dropna: Optional[bool] = True, + boundary: Literal["strict", "context"] = "context", + align_to_interval: bool = False, + ) -> NDArray[Shape["*"], Any]: + import numpy as np + import pandas as pd + + if interval is None: + interval = to_duration("1 hour") + + dates = [] + values = [] + for record in self.records: + if record.date_time is None: + continue + ts = DatabaseTimestamp.from_datetime(record.date_time) + if start_datetime and DatabaseTimestamp.from_datetime(start_datetime) > ts: + continue + if end_datetime and DatabaseTimestamp.from_datetime(end_datetime) <= ts: + continue + dates.append(record.date_time) + values.append(getattr(record, key, None)) + + if not dates: + return np.array([]) + + index = pd.to_datetime(dates, utc=True) + series = pd.Series(values, index=index, dtype=float) + freq = f"{int(interval.total_seconds())}s" + origin = start_datetime if start_datetime else "start_day" + resampled = series.resample(freq, origin=origin).mean().interpolate("time") + + if start_datetime is not None: + resampled = resampled.truncate(before=start_datetime) + if end_datetime is not None: + resampled = resampled.truncate(after=end_datetime) + + return resampled.values + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _insert_records_every_n_minutes( + seq: SampleSequence, + base: DateTime, + count: int, + interval_minutes: int, + value_fn=None, +) -> None: + """Insert `count` records spaced `interval_minutes` apart starting at `base`.""" + for i in range(count): + dt = base.add(minutes=i * interval_minutes) + value = value_fn(i) if value_fn else float(i) + seq.db_insert_record(SampleRecord(date_time=dt, value=value)) + 
seq.db_save_records() + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def seq(): + return SampleSequence() + + +@pytest.fixture +def seq_with_15min_data(): + """Sequence with 15-min records spanning 4 weeks, so both tiers have data.""" + s = SampleSequence() + now = to_datetime().in_timezone("UTC") + # 4 weeks × 7 days × 24 h × 4 records/h = 2688 records + base = now.subtract(weeks=4) + _insert_records_every_n_minutes(s, base, count=2688, interval_minutes=15) + return s, now + + +@pytest.fixture +def seq_sparse(): + """Sequence with only 3 records spread over 4 weeks — sparse, no compaction benefit.""" + s = SampleSequence() + now = to_datetime().in_timezone("UTC") + base = now.subtract(weeks=4) + for offset_days in [0, 14, 27]: + dt = base.add(days=offset_days) + s.db_insert_record(SampleRecord(date_time=dt, value=float(offset_days))) + s.db_save_records() + return s, now + + +# --------------------------------------------------------------------------- +# Existing tests (unchanged) +# --------------------------------------------------------------------------- + + +class TestDatabaseRecordProtocolMixin: + + @pytest.mark.parametrize( + "start_str, value_count, interval_seconds", + [ + ("2024-11-10 00:00:00", 24, 3600), + ("2024-08-10 00:00:00", 24, 3600), + ("2024-03-31 00:00:00", 24, 3600), + ("2024-10-27 00:00:00", 24, 3600), + ], + ) + def test_db_generate_timestamps_utc_spacing( + self, seq, start_str, value_count, interval_seconds + ): + start_dt = to_datetime(start_str, in_timezone="Europe/Berlin") + assert start_dt.tz.name == "Europe/Berlin" + + db_start = DatabaseTimestamp.from_datetime(start_dt) + generated = list(seq.db_generate_timestamps(db_start, value_count)) + + assert len(generated) == value_count + + for db_dt in generated: + dt = DatabaseTimestamp.to_datetime(db_dt) + assert dt.tz.name == "UTC" + + 
assert len(generated) == len(set(generated)), "Duplicate UTC datetimes found" + + for i in range(1, len(generated)): + last_dt = DatabaseTimestamp.to_datetime(generated[i - 1]) + current_dt = DatabaseTimestamp.to_datetime(generated[i]) + delta = (current_dt - last_dt).total_seconds() + assert delta == interval_seconds, f"Spacing mismatch at index {i}: {delta}s" + + def test_insert_and_memory_range(self, seq): + t0 = to_datetime() + t1 = t0.add(hours=1) + + seq.db_insert_record(SampleRecord(date_time=t0, value=1)) + seq.db_insert_record(SampleRecord(date_time=t1, value=2)) + + assert seq.records[0].date_time == t0 + assert seq.records[-1].date_time == t1 + assert len(seq.records) == 2 + + def test_roundtrip_reload(self): + seq = SampleSequence() + t0 = to_datetime() + t1 = t0.add(hours=1) + + seq.db_insert_record(SampleRecord(date_time=t0, value=1)) + seq.db_insert_record(SampleRecord(date_time=t1, value=2)) + assert seq.db_save_records() == 2 + + db = seq.database + seq2 = SampleSequence() + seq2.database = db + loaded = seq2.db_load_records() + + assert loaded == 2 + assert len(seq2.records) == 2 + + def test_db_count_records(self, seq): + t0 = to_datetime() + seq.db_insert_record(SampleRecord(date_time=t0, value=1)) + assert seq.db_count_records() == 1 + seq.db_save_records() + assert seq.db_count_records() == 1 + + def test_delete_range(self, seq): + base = to_datetime() + for i in range(5): + seq.db_insert_record(SampleRecord(date_time=base.add(minutes=i), value=i)) + + db_start = DatabaseTimestamp.from_datetime(base.add(minutes=1)) + db_end = DatabaseTimestamp.from_datetime(base.add(minutes=4)) + deleted = seq.db_delete_records(start_timestamp=db_start, end_timestamp=db_end) + + assert deleted == 3 + assert [r.value for r in seq.records] == [0, 4] + + def test_db_count_records_memory_only_multiple(self): + seq = SampleSequence() + base = to_datetime() + for i in range(3): + seq.db_insert_record(SampleRecord(date_time=base.add(minutes=i), value=i)) + assert 
seq.db_count_records() == 3 + + def test_db_count_records_memory_newer_than_db(self): + seq = SampleSequence() + base = to_datetime() + seq.db_insert_record(SampleRecord(date_time=base, value=1)) + seq.db_save_records() + seq.db_insert_record(SampleRecord(date_time=base.add(hours=1), value=2)) + seq.db_insert_record(SampleRecord(date_time=base.add(hours=2), value=3)) + assert seq.db_count_records() == 3 + + def test_db_count_records_memory_older_than_db(self): + seq = SampleSequence() + base = to_datetime() + seq.db_insert_record(SampleRecord(date_time=base.add(hours=1), value=2)) + seq.db_save_records() + seq.db_insert_record(SampleRecord(date_time=base, value=1)) + assert seq.db_count_records() == 2 + + def test_db_count_records_empty_everywhere(self): + seq = SampleSequence() + assert seq.db_count_records() == 0 + + def test_metadata_not_counted(self, seq): + seq.database._data.setdefault("test", {})[DATABASE_METADATA_KEY] = b"meta" + assert seq.db_count_records() == 0 + + def test_key_range_excludes_metadata(self, seq): + ns = seq.db_namespace() + seq.database._data.setdefault(ns, {})[DATABASE_METADATA_KEY] = b"meta" + assert seq.database.get_key_range(ns) == (None, None) + + +# --------------------------------------------------------------------------- +# Compaction tests +# --------------------------------------------------------------------------- + + +class TestCompactTiers: + """Tests for db_compact_tiers() and the tier hook.""" + + def test_default_tiers_returns_two_entries(self, seq): + tiers = seq.db_compact_tiers() + assert len(tiers) == 2 + + def test_default_tiers_ordered_shortest_first(self, seq): + tiers = seq.db_compact_tiers() + ages = [t[0].total_seconds() for t in tiers] + assert ages == sorted(ages), "Tiers must be ordered shortest age first" + + def test_default_tiers_first_is_2h_to_15min(self, seq): + tiers = seq.db_compact_tiers() + age_sec, interval_sec = ( + tiers[0][0].total_seconds(), + tiers[0][1].total_seconds(), + ) + assert age_sec 
== 2 * 3600 + assert interval_sec == 15 * 60 + + def test_default_tiers_second_is_2weeks_to_1h(self, seq): + tiers = seq.db_compact_tiers() + age_sec, interval_sec = ( + tiers[1][0].total_seconds(), + tiers[1][1].total_seconds(), + ) + assert age_sec == 14 * 24 * 3600 + assert interval_sec == 3600 + + def test_override_tiers(self): + class CustomSeq(SampleSequence): + def db_compact_tiers(self): + return [(to_duration("7 days"), to_duration("1 hour"))] + + s = CustomSeq() + tiers = s.db_compact_tiers() + assert len(tiers) == 1 + assert tiers[0][1].total_seconds() == 3600 + + def test_empty_tiers_disables_compaction(self): + class NoCompactSeq(SampleSequence): + def db_compact_tiers(self): + return [] + + s = NoCompactSeq() + now = to_datetime().in_timezone("UTC") + base = now.subtract(weeks=4) + _insert_records_every_n_minutes(s, base, count=100, interval_minutes=15) + + deleted = s.db_compact() + assert deleted == 0 + + +class TestCompactState: + """Tests for _db_get_compact_state / _db_set_compact_state.""" + + def test_get_state_returns_none_when_no_metadata(self, seq): + interval = to_duration("1 hour") + assert seq._db_get_compact_state(interval) is None + + def test_set_and_get_state_roundtrip(self, seq): + interval = to_duration("1 hour") + now = to_datetime().in_timezone("UTC") + ts = DatabaseTimestamp.from_datetime(now) + + seq._db_set_compact_state(interval, ts) + retrieved = seq._db_get_compact_state(interval) + + assert retrieved == ts + + def test_state_is_per_tier(self, seq): + """Different tier intervals must not overwrite each other.""" + interval_15min = to_duration("15 minutes") + interval_1h = to_duration("1 hour") + + now = to_datetime().in_timezone("UTC") + ts_15 = DatabaseTimestamp.from_datetime(now) + ts_1h = DatabaseTimestamp.from_datetime(now.subtract(days=1)) + + seq._db_set_compact_state(interval_15min, ts_15) + seq._db_set_compact_state(interval_1h, ts_1h) + + assert seq._db_get_compact_state(interval_15min) == ts_15 + assert 
seq._db_get_compact_state(interval_1h) == ts_1h + + def test_state_persists_in_metadata(self, seq): + """State must survive a metadata reload.""" + interval = to_duration("1 hour") + now = to_datetime().in_timezone("UTC") + ts = DatabaseTimestamp.from_datetime(now) + + seq._db_set_compact_state(interval, ts) + + # Reload metadata from fake DB + seq2 = SampleSequence() + seq2.database = seq.database + seq2._db_metadata = seq2._db_load_metadata() + + assert seq2._db_get_compact_state(interval) == ts + + +class TestCompactSparseGuard: + """The inflation guard must skip compaction when records are already sparse.""" + + def test_sparse_data_aligns_but_does_not_reduce_cardinality(self, seq_sparse): + """Sparse data must be aligned to the target interval for all records that were modified.""" + seq, _ = seq_sparse + + interval = to_duration("15 minutes") + interval_sec = int(interval.total_seconds()) + + # Snapshot original timestamps + before_epochs = { + int(r.date_time.timestamp()) + for r in seq.records + } + + seq._db_compact_tier( + to_duration("30 minutes"), + interval, + ) + + after_epochs = { + int(r.date_time.timestamp()) + for r in seq.records + } + + # Cardinality must not increase + assert len(after_epochs) <= len(before_epochs) + + # Any timestamp that changed must now be aligned + changed_epochs = after_epochs - before_epochs + + for epoch in changed_epochs: + assert epoch % interval_sec == 0 + + def test_sparse_guard_advances_cutoff(self, seq_sparse): + """Even when skipped, the cutoff should be stored so next run skips the same window.""" + seq, _ = seq_sparse + interval_1h = to_duration("1 hour") + interval_15min = to_duration("15 minutes") + + seq.db_compact() + + # Both tiers should have stored a cutoff even though nothing was deleted + assert seq._db_get_compact_state(interval_1h) is not None + assert seq._db_get_compact_state(interval_15min) is not None + + def test_exactly_at_boundary_remains_stable(self, seq): + now = 
to_datetime().in_timezone("UTC") + interval = to_duration("1 hour") + + raw_base = now.subtract(hours=5).set(minute=0, second=0, microsecond=0) + base = raw_base.subtract(seconds=int(raw_base.timestamp()) % 3600) + + for i in range(4): + seq.db_insert_record( + SampleRecord( + date_time=base.add(hours=i), + value=float(i), + ) + ) + + seq.db_insert_record( + SampleRecord(date_time=now.subtract(seconds=1), value=0.0) + ) + seq.db_save_records() + + before = [ + (int(r.date_time.timestamp()), r.value) + for r in seq.records + ] + + seq._db_compact_tier( + to_duration("30 minutes"), + interval, + ) + + after = [ + (int(r.date_time.timestamp()), r.value) + for r in seq.records + ] + + assert before == after + + +class TestCompactTierWorker: + """Unit tests for _db_compact_tier directly.""" + + def test_empty_sequence_returns_zero(self, seq): + age = to_duration("2 hours") + interval = to_duration("15 minutes") + assert seq._db_compact_tier(age, interval) == 0 + + def test_all_records_too_recent_skipped(self): + """Records within the age threshold must not be touched.""" + seq = SampleSequence() + now = to_datetime().in_timezone("UTC") + # Insert 10 records from 30 minutes ago — all within 2h threshold + base = now.subtract(minutes=30) + _insert_records_every_n_minutes(seq, base, count=10, interval_minutes=1) + + before = seq.db_count_records() + deleted = seq._db_compact_tier(to_duration("2 hours"), to_duration("15 minutes")) + + assert deleted == 0 + assert seq.db_count_records() == before + + def test_compaction_reduces_record_count(self): + """Dense 1-min records older than 2h should be downsampled to 15-min.""" + seq = SampleSequence() + now = to_datetime().in_timezone("UTC") + # Insert 1-min records for 6 hours ending 3 hours ago + base = now.subtract(hours=9) + _insert_records_every_n_minutes(seq, base, count=6 * 60, interval_minutes=1) + + before = seq.db_count_records() + deleted = seq._db_compact_tier(to_duration("2 hours"), to_duration("15 minutes")) + + 
after = seq.db_count_records() + assert deleted > 0 + assert after < before + + def test_records_within_threshold_preserved(self): + """Records newer than age_threshold must remain untouched after compaction.""" + seq = SampleSequence() + now = to_datetime().in_timezone("UTC") + + # Old dense records (will be compacted) + old_base = now.subtract(hours=6) + _insert_records_every_n_minutes(seq, old_base, count=4 * 60, interval_minutes=1) + + # Recent records (must not be touched) — insert 5 records in the last hour + recent_base = now.subtract(minutes=50) + _insert_records_every_n_minutes(seq, recent_base, count=5, interval_minutes=10) + + recent_before = [ + r for r in seq.records + if r.date_time and r.date_time >= recent_base + ] + + seq._db_compact_tier(to_duration("2 hours"), to_duration("15 minutes")) + + recent_after = [ + r for r in seq.records + if r.date_time and r.date_time >= recent_base + ] + assert len(recent_after) == len(recent_before) + + def test_incremental_cutoff_prevents_recompaction(self): + """Running compaction twice must not re-compact already-compacted data.""" + seq = SampleSequence() + now = to_datetime().in_timezone("UTC") + base = now.subtract(hours=8) + _insert_records_every_n_minutes(seq, base, count=5 * 60, interval_minutes=1) + + age = to_duration("2 hours") + interval = to_duration("15 minutes") + + deleted_first = seq._db_compact_tier(age, interval) + count_after_first = seq.db_count_records() + + deleted_second = seq._db_compact_tier(age, interval) + count_after_second = seq.db_count_records() + + assert deleted_first > 0 + assert deleted_second == 0, "Second run must be a no-op" + assert count_after_first == count_after_second + + def test_cutoff_stored_after_compaction(self): + """Cutoff timestamp must be persisted after a successful compaction run.""" + seq = SampleSequence() + now = to_datetime().in_timezone("UTC") + base = now.subtract(hours=8) + _insert_records_every_n_minutes(seq, base, count=5 * 60, interval_minutes=1) + + 
interval = to_duration("15 minutes")
+        seq._db_compact_tier(to_duration("2 hours"), interval)
+
+        assert seq._db_get_compact_state(interval) is not None
+
+
+class TestDbCompact:
+    """Integration tests for the public db_compact() entry point."""
+
+    def test_compact_dense_data_both_tiers(self, seq_with_15min_data):
+        """4 weeks of 15-min data should be reduced by both tiers."""
+        seq, _ = seq_with_15min_data
+        before = seq.db_count_records()
+
+        total_deleted = seq.db_compact()
+
+        after = seq.db_count_records()
+        assert total_deleted > 0
+        assert after < before
+
+    def test_compact_coarsest_tier_runs_first(self, seq_with_15min_data):
+        """The 1-hour tier (coarsest) must run before the 15-min tier.
+
+        If coarsest ran last it would re-compact records the 15-min tier
+        had already downsampled — verified by checking that the 1-hour
+        cutoff is not later than the 15-min cutoff.
+        """
+        seq, _ = seq_with_15min_data
+        seq.db_compact()
+
+        cutoff_1h = seq._db_get_compact_state(to_duration("1 hour"))
+        cutoff_15min = seq._db_get_compact_state(to_duration("15 minutes"))
+
+        assert cutoff_1h is not None
+        assert cutoff_15min is not None
+        # The 1h tier covers older data → its cutoff must not be later than the 15min tier
+        assert cutoff_1h <= cutoff_15min
+
+    def test_compact_idempotent(self, seq_with_15min_data):
+        """Running db_compact twice must not change record count."""
+        seq, _ = seq_with_15min_data
+        seq.db_compact()
+        after_first = seq.db_count_records()
+
+        seq.db_compact()
+        after_second = seq.db_count_records()
+
+        assert after_first == after_second
+
+    def test_compact_empty_sequence_returns_zero(self, seq):
+        assert seq.db_compact() == 0
+
+    def test_compact_with_override_tiers(self):
+        """Passing compact_tiers directly must override db_compact_tiers()."""
+        seq = SampleSequence()
+        now = to_datetime().in_timezone("UTC")
+        base = now.subtract(weeks=3)
+        _insert_records_every_n_minutes(seq, base, count=3 * 7 * 24 * 4, interval_minutes=15)
+
+        before = 
seq.db_count_records() + deleted = seq.db_compact( + compact_tiers=[(to_duration("1 day"), to_duration("1 hour"))] + ) + + assert deleted > 0 + assert seq.db_count_records() < before + + def test_compact_only_processes_new_window_on_second_call(self): + """Second call processes only the new window, not the full history.""" + seq = SampleSequence() + now = to_datetime().in_timezone("UTC") + base = now.subtract(weeks=3) + # Dense 1-min data for 3 weeks + _insert_records_every_n_minutes(seq, base, count=3 * 7 * 24 * 60, interval_minutes=1) + + seq.db_compact() + count_after_first = seq.db_count_records() + + # Add one more day of dense data in the past (simulate new old data arriving) + extra_base = now.subtract(weeks=3).subtract(days=1) + _insert_records_every_n_minutes(seq, extra_base, count=24 * 60, interval_minutes=1) + + seq.db_compact() + count_after_second = seq.db_count_records() + + # Second compact should have processed the newly added old data + # Record count may change but should not exceed first compacted count by much + assert count_after_second >= 0 # basic sanity + + +class TestCompactDataIntegrity: + """Verify value integrity is preserved after compaction.""" + + def test_constant_value_preserved(self): + """Constant value field must survive mean-resampling unchanged.""" + seq = SampleSequence() + now = to_datetime().in_timezone("UTC") + base = now.subtract(hours=6) + + # All values = 42.0 + _insert_records_every_n_minutes( + seq, base, count=6 * 60, interval_minutes=1, value_fn=lambda _: 42.0 + ) + + seq._db_compact_tier(to_duration("2 hours"), to_duration("15 minutes")) + + for record in seq.records: + if record.date_time and record.date_time < now.subtract(hours=2): + assert record.value == pytest.approx(42.0, abs=1e-6) + + def test_recent_records_not_modified(self): + """Records newer than the age threshold must have unchanged values.""" + seq = SampleSequence() + now = to_datetime().in_timezone("UTC") + + old_base = now.subtract(hours=6) + 
_insert_records_every_n_minutes(seq, old_base, count=3 * 60, interval_minutes=1) + + # Known recent values + recent_base = now.subtract(minutes=30) + expected = {i * 10: float(100 + i) for i in range(3)} + for offset, val in expected.items(): + dt = recent_base.add(minutes=offset) + seq.db_insert_record(SampleRecord(date_time=dt, value=val)) + seq.db_save_records() + + seq._db_compact_tier(to_duration("2 hours"), to_duration("15 minutes")) + + for record in seq.records: + if record.date_time and record.date_time >= recent_base: + offset = int((record.date_time - recent_base).total_seconds() / 60) + if offset in expected: + assert record.value == pytest.approx(expected[offset], abs=1e-6) + + def test_compacted_timestamps_spacing(self): + """Resampled records must be fewer than original and span the compaction window. + + Exact per-bucket spacing depends on the full DataSequence.key_to_array + implementation (pandas resampling). The stub key_to_array in SampleSequence + only guarantees a reduction in count — uniform spacing is verified in + test_dataabc_compact.py against the real implementation. 
+ """ + seq = SampleSequence() + now = to_datetime().in_timezone("UTC") + base = now.subtract(hours=6) + _insert_records_every_n_minutes(seq, base, count=5 * 60, interval_minutes=1) + + before = seq.db_count_records() + seq._db_compact_tier(to_duration("2 hours"), to_duration("15 minutes")) + + cutoff = now.subtract(hours=2) + compacted = sorted( + [r for r in seq.records if r.date_time and r.date_time < cutoff], + key=lambda r: cast(DateTime, r.date_time), + ) + + # Must have produced fewer records than the original 1-min data + assert len(compacted) > 0, "Expected at least one compacted record" + assert len(compacted) < before, "Compaction must reduce record count" + + # Window start is floored to interval boundary + interval_sec = 15 * 60 + expected_window_start = DateTime.fromtimestamp( + (int(base.timestamp()) // interval_sec) * interval_sec, + tz="UTC", + ) + assert compacted[0].date_time >= expected_window_start + + # Last compacted record must be before the cutoff + assert compacted[-1].date_time < cutoff diff --git a/tests/test_datetimeutil.py b/tests/test_datetimeutil.py index bbc4757..09c14c1 100644 --- a/tests/test_datetimeutil.py +++ b/tests/test_datetimeutil.py @@ -1460,6 +1460,17 @@ class TestTimeWindowSequence: # - without local timezone as UTC ( "TC014", + "UTC", + "2024-01-03", + None, + "UTC", + None, + False, + pendulum.datetime(2024, 1, 3, 0, 0, 0, tz="UTC"), + False, + ), + ( + "TC015", "Atlantic/Canary", "02/02/24", None, @@ -1470,7 +1481,7 @@ class TestTimeWindowSequence: False, ), ( - "TC015", + "TC016", "Atlantic/Canary", "2024-03-03T10:20:30.000Z", # No dalight saving time at this date None, @@ -1484,7 +1495,7 @@ class TestTimeWindowSequence: # from pendulum.datetime to pendulum.datetime object # --------------------------------------- ( - "TC016", + "TC017", "Atlantic/Canary", pendulum.datetime(2024, 4, 4, 0, 0, 0), None, @@ -1495,7 +1506,7 @@ class TestTimeWindowSequence: False, ), ( - "TC017", + "TC018", "Atlantic/Canary", 
pendulum.datetime(2024, 4, 4, 1, 0, 0), None, @@ -1506,7 +1517,7 @@ class TestTimeWindowSequence: False, ), ( - "TC018", + "TC019", "Atlantic/Canary", pendulum.datetime(2024, 4, 4, 1, 0, 0, tz="Etc/UTC"), None, @@ -1517,7 +1528,7 @@ class TestTimeWindowSequence: False, ), ( - "TC019", + "TC020", "Atlantic/Canary", pendulum.datetime(2024, 4, 4, 2, 0, 0, tz="Europe/Berlin"), None, @@ -1533,7 +1544,7 @@ class TestTimeWindowSequence: # - no timezone # local timezone UTC ( - "TC020", + "TC021", "Etc/UTC", "2023-11-06T00:00:00", "UTC", @@ -1545,7 +1556,7 @@ class TestTimeWindowSequence: ), # local timezone "Europe/Berlin" ( - "TC021", + "TC022", "Europe/Berlin", "2023-11-06T00:00:00", "UTC", @@ -1557,7 +1568,7 @@ class TestTimeWindowSequence: ), # - no microseconds ( - "TC022", + "TC023", "Atlantic/Canary", "2024-10-30T00:00:00+01:00", "UTC", @@ -1568,7 +1579,7 @@ class TestTimeWindowSequence: False, ), ( - "TC023", + "TC024", "Atlantic/Canary", "2024-10-30T01:00:00+01:00", "utc", @@ -1580,7 +1591,7 @@ class TestTimeWindowSequence: ), # - with microseconds ( - "TC024", + "TC025", "Atlantic/Canary", "2024-10-07T10:20:30.000+02:00", "UTC", @@ -1596,7 +1607,7 @@ class TestTimeWindowSequence: # - no timezone # local timezone ( - "TC025", + "TC026", None, None, None, diff --git a/tests/test_doc.py b/tests/test_doc.py index 5c8ba47..7bee639 100644 --- a/tests/test_doc.py +++ b/tests/test_doc.py @@ -14,8 +14,10 @@ DIR_DOCS_GENERATED = DIR_PROJECT_ROOT / "docs" / "_generated" DIR_TEST_GENERATED = DIR_TESTDATA / "docs" / "_generated" -def test_openapi_spec_current(config_eos): +def test_openapi_spec_current(config_eos, set_other_timezone): """Verify the openapi spec hasn´t changed.""" + set_other_timezone("UTC") # CI runs on UTC + expected_spec_path = DIR_PROJECT_ROOT / "openapi.json" new_spec_path = DIR_TESTDATA / "openapi-new.json" @@ -23,7 +25,7 @@ def test_openapi_spec_current(config_eos): expected_spec = json.load(f_expected) # Patch get_config and import within guard to 
patch global variables within the eos module. - with patch("akkudoktoreos.config.config.get_config", return_value=config_eos): + with patch("akkudoktoreos.core.coreabc.get_config", return_value=config_eos): # Ensure the script works correctly as part of a package root_dir = Path(__file__).resolve().parent.parent sys.path.insert(0, str(root_dir)) @@ -39,7 +41,7 @@ def test_openapi_spec_current(config_eos): expected_spec_str = json.dumps(expected_spec, indent=4, sort_keys=True) try: - assert spec_str == expected_spec_str + assert json.loads(spec_str) == json.loads(expected_spec_str) except AssertionError as e: pytest.fail( f"Expected {new_spec_path} to equal {expected_spec_path}.\n" @@ -47,8 +49,10 @@ def test_openapi_spec_current(config_eos): ) -def test_openapi_md_current(config_eos): +def test_openapi_md_current(config_eos, set_other_timezone): """Verify the generated openapi markdown hasn´t changed.""" + set_other_timezone("UTC") # CI runs on UTC + expected_spec_md_path = DIR_PROJECT_ROOT / "docs" / "_generated" / "openapi.md" new_spec_md_path = DIR_TESTDATA / "openapi-new.md" @@ -56,7 +60,7 @@ def test_openapi_md_current(config_eos): expected_spec_md = f_expected.read() # Patch get_config and import within guard to patch global variables within the eos module. 
- with patch("akkudoktoreos.config.config.get_config", return_value=config_eos): + with patch("akkudoktoreos.core.coreabc.get_config", return_value=config_eos): # Ensure the script works correctly as part of a package root_dir = Path(__file__).resolve().parent.parent sys.path.insert(0, str(root_dir)) @@ -76,8 +80,10 @@ def test_openapi_md_current(config_eos): ) -def test_config_md_current(config_eos): +def test_config_md_current(config_eos, set_other_timezone): """Verify the generated configuration markdown hasn´t changed.""" + set_other_timezone("UTC") # CI runs on UTC + assert DIR_DOCS_GENERATED.exists() # Remove any leftover files from last run @@ -88,7 +94,7 @@ def test_config_md_current(config_eos): DIR_TEST_GENERATED.mkdir(parents=True, exist_ok=True) # Patch get_config and import within guard to patch global variables within the eos module. - with patch("akkudoktoreos.config.config.get_config", return_value=config_eos): + with patch("akkudoktoreos.core.coreabc.get_config", return_value=config_eos): # Ensure the script works correctly as part of a package root_dir = Path(__file__).resolve().parent.parent sys.path.insert(0, str(root_dir)) @@ -106,7 +112,11 @@ def test_config_md_current(config_eos): tested.append(DIR_TEST_GENERATED / file_name) # Create test files - config_md = generate_config_md.generate_config_md(tested[0], config_eos) + try: + config_eos._force_documentation_mode = True + config_md = generate_config_md.generate_config_md(tested[0], config_eos) + finally: + config_eos._force_documentation_mode = False # Check test files are the same as the expected files for i, expected_path in enumerate(expected): diff --git a/tests/test_docsphinx.py b/tests/test_docsphinx.py index f8364b0..0fe6735 100644 --- a/tests/test_docsphinx.py +++ b/tests/test_docsphinx.py @@ -9,6 +9,8 @@ from typing import Optional import pytest +from akkudoktoreos.core.coreabc import singletons_init + DIR_PROJECT_ROOT = Path(__file__).absolute().parent.parent DIR_BUILD = 
DIR_PROJECT_ROOT / "build" DIR_BUILD_DOCS = DIR_PROJECT_ROOT / "build" / "docs" @@ -80,6 +82,7 @@ class TestSphinxDocumentation: def test_sphinx_build(self, sphinx_changed: Optional[str], is_finalize: bool): """Build Sphinx documentation and ensure no major warnings appear in the build output.""" + # Ensure docs folder exists if not DIR_DOCS.exists(): pytest.skip(f"Skipping Sphinx build test - docs folder not present: {DIR_DOCS}") @@ -88,7 +91,7 @@ class TestSphinxDocumentation: pytest.skip(f"Skipping Sphinx build — no relevant file changes detected: {HASH_FILE}") if not is_finalize: - pytest.skip("Skipping Sphinx test — not full run") + pytest.skip("Skipping Sphinx test — not finalize") # Clean directories self._cleanup_autosum_dirs() @@ -123,7 +126,11 @@ class TestSphinxDocumentation: # Remove temporary EOS_DIR eos_tmp_dir.cleanup() - assert returncode == 0 + if returncode != 0: + pytest.fail( + f"Sphinx build failed with exit code {returncode}.\n" + f"{output}\n" + ) # Possible markers: ERROR: WARNING: TRACEBACK: major_markers = ("ERROR:", "TRACEBACK:") diff --git a/tests/test_elecpriceakkudoktor.py b/tests/test_elecpriceakkudoktor.py index 477c383..732e48b 100644 --- a/tests/test_elecpriceakkudoktor.py +++ b/tests/test_elecpriceakkudoktor.py @@ -8,7 +8,7 @@ import requests from loguru import logger from akkudoktoreos.core.cache import CacheFileStore -from akkudoktoreos.core.ems import get_ems +from akkudoktoreos.core.coreabc import get_ems from akkudoktoreos.prediction.elecpriceakkudoktor import ( AkkudoktorElecPrice, AkkudoktorElecPriceValue, diff --git a/tests/test_elecpriceenergycharts.py b/tests/test_elecpriceenergycharts.py index 62ad56a..cd212d4 100644 --- a/tests/test_elecpriceenergycharts.py +++ b/tests/test_elecpriceenergycharts.py @@ -8,7 +8,7 @@ import requests from loguru import logger from akkudoktoreos.core.cache import CacheFileStore -from akkudoktoreos.core.ems import get_ems +from akkudoktoreos.core.coreabc import get_ems from 
akkudoktoreos.prediction.elecpriceakkudoktor import ( AkkudoktorElecPrice, AkkudoktorElecPriceValue, diff --git a/tests/test_elecpriceimport.py b/tests/test_elecpriceimport.py index 09f20e1..5ade60d 100644 --- a/tests/test_elecpriceimport.py +++ b/tests/test_elecpriceimport.py @@ -1,11 +1,12 @@ import json from pathlib import Path +import numpy.testing as npt import pytest -from akkudoktoreos.core.ems import get_ems +from akkudoktoreos.core.coreabc import get_ems from akkudoktoreos.prediction.elecpriceimport import ElecPriceImport -from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime +from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime, to_duration DIR_TESTDATA = Path(__file__).absolute().parent.joinpath("testdata") @@ -83,6 +84,7 @@ def test_invalid_provider(provider, config_eos): ) def test_import(provider, sample_import_1_json, start_datetime, from_file, config_eos): """Test fetching forecast from Import.""" + key = "elecprice_marketprice_wh" ems_eos = get_ems() ems_eos.set_start_datetime(to_datetime(start_datetime, in_timezone="Europe/Berlin")) if from_file: @@ -91,7 +93,7 @@ def test_import(provider, sample_import_1_json, start_datetime, from_file, confi else: config_eos.elecprice.elecpriceimport.import_file_path = None assert config_eos.elecprice.elecpriceimport.import_file_path is None - provider.clear() + provider.delete_by_datetime(start_datetime=None, end_datetime=None) # Call the method provider.update_data() @@ -100,16 +102,13 @@ def test_import(provider, sample_import_1_json, start_datetime, from_file, confi assert provider.ems_start_datetime is not None assert provider.total_hours is not None assert compare_datetimes(provider.ems_start_datetime, ems_eos.start_datetime).equal - values = sample_import_1_json["elecprice_marketprice_wh"] - value_datetime_mapping = provider.import_datetimes(ems_eos.start_datetime, len(values)) - for i, mapping in enumerate(value_datetime_mapping): - assert i < 
len(provider.records) - expected_datetime, expected_value_index = mapping - expected_value = values[expected_value_index] - result_datetime = provider.records[i].date_time - result_value = provider.records[i]["elecprice_marketprice_wh"] - # print(f"{i}: Expected: {expected_datetime}:{expected_value}") - # print(f"{i}: Result: {result_datetime}:{result_value}") - assert compare_datetimes(result_datetime, expected_datetime).equal - assert result_value == expected_value + expected_values = sample_import_1_json[key] + result_values = provider.key_to_array( + key=key, + start_datetime=provider.ems_start_datetime, + end_datetime=provider.ems_start_datetime + to_duration(f"{len(expected_values)} hours"), + interval=to_duration("1 hour"), + ) + # Allow for some difference due to value calculation on DST change + npt.assert_allclose(result_values, expected_values, rtol=0.001) diff --git a/tests/test_feedintarifffixed.py b/tests/test_feedintarifffixed.py index a8e24e8..ca7b6e4 100644 --- a/tests/test_feedintarifffixed.py +++ b/tests/test_feedintarifffixed.py @@ -3,7 +3,7 @@ from pathlib import Path import pytest -from akkudoktoreos.core.ems import get_ems +from akkudoktoreos.core.coreabc import get_ems from akkudoktoreos.prediction.feedintarifffixed import FeedInTariffFixed from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime diff --git a/tests/test_geneticoptimize.py b/tests/test_geneticoptimize.py index add9994..fc5110d 100644 --- a/tests/test_geneticoptimize.py +++ b/tests/test_geneticoptimize.py @@ -7,7 +7,7 @@ import pytest from akkudoktoreos.config.config import ConfigEOS from akkudoktoreos.core.cache import CacheEnergyManagementStore -from akkudoktoreos.core.ems import get_ems +from akkudoktoreos.core.coreabc import get_ems from akkudoktoreos.optimization.genetic.genetic import GeneticOptimization from akkudoktoreos.optimization.genetic.geneticparams import ( GeneticOptimizationParameters, @@ -18,7 +18,7 @@ from akkudoktoreos.utils.visualize 
import ( prepare_visualize, # Import the new prepare_visualize ) -ems_eos = get_ems() +ems_eos = get_ems(init=True) # init once DIR_TESTDATA = Path(__file__).parent / "testdata" diff --git a/tests/test_loadakkudoktor.py b/tests/test_loadakkudoktor.py index eb42831..2b646f9 100644 --- a/tests/test_loadakkudoktor.py +++ b/tests/test_loadakkudoktor.py @@ -4,8 +4,8 @@ import numpy as np import pendulum import pytest -from akkudoktoreos.core.ems import get_ems -from akkudoktoreos.measurement.measurement import MeasurementDataRecord, get_measurement +from akkudoktoreos.core.coreabc import get_ems, get_measurement +from akkudoktoreos.measurement.measurement import MeasurementDataRecord from akkudoktoreos.prediction.loadakkudoktor import ( LoadAkkudoktor, LoadAkkudoktorAdjusted, @@ -63,7 +63,7 @@ def measurement_eos(): dt = to_datetime("2024-01-01T00:00:00") interval = to_duration("1 hour") for i in range(25): - measurement.records.append( + measurement.insert_by_datetime( MeasurementDataRecord( date_time=dt, load0_mr=load0_mr, @@ -138,7 +138,7 @@ def test_update_data(mock_load_data, loadakkudoktor): ems_eos.set_start_datetime(pendulum.datetime(2024, 1, 1)) # Assure there are no prediction records - loadakkudoktor.clear() + loadakkudoktor.delete_by_datetime(start_datetime=None, end_datetime=None) assert len(loadakkudoktor) == 0 # Execute the method @@ -152,6 +152,24 @@ def test_calculate_adjustment(loadakkudoktoradjusted, measurement_eos): """Test `_calculate_adjustment` for various scenarios.""" data_year_energy = np.random.rand(365, 2, 24) + # Check the test setup + assert loadakkudoktoradjusted.measurement is measurement_eos + assert measurement_eos.min_datetime == to_datetime("2024-01-01T00:00:00") + assert measurement_eos.max_datetime == to_datetime("2024-01-02T00:00:00") + # Use same calculation as in _calculate_adjustment + compare_start = measurement_eos.max_datetime - to_duration("7 days") + if compare_datetimes(compare_start, measurement_eos.min_datetime).lt: + # 
Not enough measurements for 7 days - use what is available + compare_start = measurement_eos.min_datetime + compare_end = measurement_eos.max_datetime + compare_interval = to_duration("1 hour") + load_total_kwh_array = measurement_eos.load_total_kwh( + start_datetime=compare_start, + end_datetime=compare_end, + interval=compare_interval, + ) + np.testing.assert_allclose(load_total_kwh_array, [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]) + # Call the method and validate results weekday_adjust, weekend_adjust = loadakkudoktoradjusted._calculate_adjustment(data_year_energy) assert weekday_adjust.shape == (24,) diff --git a/tests/test_measurement.py b/tests/test_measurement.py index efba602..4131f3b 100644 --- a/tests/test_measurement.py +++ b/tests/test_measurement.py @@ -3,10 +3,17 @@ import pytest from pendulum import datetime, duration from akkudoktoreos.config.config import SettingsEOS +from akkudoktoreos.core.coreabc import get_measurement from akkudoktoreos.measurement.measurement import ( MeasurementCommonSettings, MeasurementDataRecord, - get_measurement, +) +from akkudoktoreos.utils.datetimeutil import ( + DateTime, + Duration, + compare_datetimes, + to_datetime, + to_duration, ) @@ -41,8 +48,9 @@ class TestMeasurementDataRecord: def test_getitem_existing_field(self, record): """Test that __getitem__ returns correct value for existing native field.""" - record.date_time = "2024-01-01T00:00:00+00:00" - assert record["date_time"] is not None + date_time = "2024-01-01T00:00:00+00:00" + record.date_time = date_time + assert compare_datetimes(record["date_time"], to_datetime(date_time)).equal def test_getitem_existing_measurement(self, record): """Test that __getitem__ retrieves existing measurement values.""" @@ -220,6 +228,7 @@ class TestMeasurement: # Load meter readings are in kWh config_eos.measurement.load_emr_keys = ["load0_mr", "load1_mr", "load2_mr", "load3_mr"] measurement = 
get_measurement() + measurement.delete_by_datetime(None, None) record0 = MeasurementDataRecord( date_time=datetime(2023, 1, 1, hour=0), load0_mr=100, @@ -227,52 +236,54 @@ class TestMeasurement: ) assert record0.load0_mr == 100 assert record0.load1_mr == 200 - measurement.records = [ + records = [ MeasurementDataRecord( - date_time=datetime(2023, 1, 1, hour=0), + date_time=to_datetime("2023-01-01T00:00:00"), load0_mr=100, load1_mr=200, ), MeasurementDataRecord( - date_time=datetime(2023, 1, 1, hour=1), + date_time=to_datetime("2023-01-01T01:00:00"), load0_mr=150, load1_mr=250, ), MeasurementDataRecord( - date_time=datetime(2023, 1, 1, hour=2), + date_time=to_datetime("2023-01-01T02:00:00"), load0_mr=200, load1_mr=300, ), MeasurementDataRecord( - date_time=datetime(2023, 1, 1, hour=3), + date_time=to_datetime("2023-01-01T03:00:00"), load0_mr=250, load1_mr=350, ), MeasurementDataRecord( - date_time=datetime(2023, 1, 1, hour=4), + date_time=to_datetime("2023-01-01T04:00:00"), load0_mr=300, load1_mr=400, ), MeasurementDataRecord( - date_time=datetime(2023, 1, 1, hour=5), + date_time=to_datetime("2023-01-01T05:00:00"), load0_mr=350, load1_mr=450, ), ] + for record in records: + measurement.insert_by_datetime(record) return measurement def test_interval_count(self, measurement_eos): """Test interval count calculation.""" - start = datetime(2023, 1, 1, 0) - end = datetime(2023, 1, 1, 3) + start = to_datetime("2023-01-01T00:00:00") + end = to_datetime("2023-01-01T03:00:00") interval = duration(hours=1) assert measurement_eos._interval_count(start, end, interval) == 3 def test_interval_count_invalid_end_before_start(self, measurement_eos): """Test interval count raises ValueError when end_datetime is before start_datetime.""" - start = datetime(2023, 1, 1, 3) - end = datetime(2023, 1, 1, 0) + start = to_datetime("2023-01-01T03:00:00") + end = to_datetime("2023-01-01T00:00:00") interval = duration(hours=1) with pytest.raises(ValueError, match="end_datetime must be after 
start_datetime"): @@ -280,8 +291,8 @@ class TestMeasurement: def test_interval_count_invalid_non_positive_interval(self, measurement_eos): """Test interval count raises ValueError when interval is non-positive.""" - start = datetime(2023, 1, 1, 0) - end = datetime(2023, 1, 1, 3) + start = to_datetime("2023-01-01T00:00:00") + end = to_datetime("2023-01-01T03:00:00") with pytest.raises(ValueError, match="interval must be positive"): measurement_eos._interval_count(start, end, duration(hours=0)) @@ -289,8 +300,8 @@ class TestMeasurement: def test_energy_from_meter_readings_valid_input(self, measurement_eos): """Test _energy_from_meter_readings with valid inputs and proper alignment of load data.""" key = "load0_mr" - start_datetime = datetime(2023, 1, 1, 0) - end_datetime = datetime(2023, 1, 1, 5) + start_datetime = to_datetime("2023-01-01T00:00:00") + end_datetime = to_datetime("2023-01-01T05:00:00") interval = duration(hours=1) load_array = measurement_eos._energy_from_meter_readings( @@ -303,12 +314,12 @@ class TestMeasurement: def test_energy_from_meter_readings_empty_array(self, measurement_eos): """Test _energy_from_meter_readings with no data (empty array).""" key = "load0_mr" - start_datetime = datetime(2023, 1, 1, 0) - end_datetime = datetime(2023, 1, 1, 5) + start_datetime = to_datetime("2023-01-01T00:00:00") + end_datetime = to_datetime("2023-01-01T05:00:00") interval = duration(hours=1) # Use empyt records array - measurement_eos.records = [] + measurement_eos.delete_by_datetime(start_datetime, end_datetime) load_array = measurement_eos._energy_from_meter_readings( key, start_datetime, end_datetime, interval @@ -324,25 +335,46 @@ class TestMeasurement: def test_energy_from_meter_readings_misaligned_array(self, measurement_eos): """Test _energy_from_meter_readings with misaligned array size.""" key = "load1_mr" - start_datetime = measurement_eos.min_datetime - end_datetime = measurement_eos.max_datetime interval = duration(hours=1) + start_datetime = 
to_datetime("2023-01-01T00:00:00") + end_datetime = to_datetime("2023-01-01T05:00:00") # Use misaligned array, latest interval set to 2 hours (instead of 1 hour) - measurement_eos.records[-1].date_time = datetime(2023, 1, 1, 6) + latest_record_datetime = to_datetime("2023-01-01T05:00:00") + new_record_datetime = to_datetime("2023-01-01T06:00:00") + record = measurement_eos.get_by_datetime(latest_record_datetime) + assert record is not None + measurement_eos.delete_by_datetime(start_datetime = latest_record_datetime, + end_datetime = new_record_datetime) + record.date_time = new_record_datetime + measurement_eos.insert_by_datetime(record) + + # Check test setup + dates, values = measurement_eos.key_to_lists(key, start_datetime, None) + assert dates == [ + to_datetime("2023-01-01T00:00:00"), + to_datetime("2023-01-01T01:00:00"), + to_datetime("2023-01-01T02:00:00"), + to_datetime("2023-01-01T03:00:00"), + to_datetime("2023-01-01T04:00:00"), + to_datetime("2023-01-01T06:00:00"), + ] + assert values == [200, 250, 300, 350, 400, 450] + array = measurement_eos.key_to_array(key, start_datetime, end_datetime + interval, interval=interval) + np.testing.assert_array_equal(array, [200, 250, 300, 350, 400, 425]) load_array = measurement_eos._energy_from_meter_readings( key, start_datetime, end_datetime, interval ) - expected_load_array = np.array([50, 50, 50, 50, 25]) # Differences between consecutive readings + expected_load_array = np.array([50., 50., 50., 50., 25.]) # Differences between consecutive readings np.testing.assert_array_equal(load_array, expected_load_array) def test_energy_from_meter_readings_partial_data(self, measurement_eos, caplog): """Test _energy_from_meter_readings with partial data (misaligned but empty array).""" key = "load2_mr" - start_datetime = datetime(2023, 1, 1, 0) - end_datetime = datetime(2023, 1, 1, 5) + start_datetime = to_datetime("2023-01-01T00:00:00") + end_datetime = to_datetime("2023-01-01T05:00:00") interval = duration(hours=1) with 
caplog.at_level("DEBUG"): @@ -359,8 +391,8 @@ class TestMeasurement: def test_energy_from_meter_readings_negative_interval(self, measurement_eos): """Test _energy_from_meter_readings with a negative interval.""" key = "load3_mr" - start_datetime = datetime(2023, 1, 1, 0) - end_datetime = datetime(2023, 1, 1, 5) + start_datetime = to_datetime("2023-01-01T00:00:00") + end_datetime = to_datetime("2023-01-01T05:00:00") interval = duration(hours=-1) with pytest.raises(ValueError, match="interval must be positive"): @@ -368,11 +400,11 @@ class TestMeasurement: def test_load_total_kwh(self, measurement_eos): """Test total load calculation.""" - start = datetime(2023, 1, 1, 0) - end = datetime(2023, 1, 1, 2) + start_datetime = to_datetime("2023-01-01T03:00:00") + end_datetime = to_datetime("2023-01-01T05:00:00") interval = duration(hours=1) - result = measurement_eos.load_total_kwh(start_datetime=start, end_datetime=end, interval=interval) + result = measurement_eos.load_total_kwh(start_datetime=start_datetime, end_datetime=end_datetime, interval=interval) # Expected total load per interval expected = np.array([100, 100]) # Differences between consecutive meter readings @@ -381,20 +413,20 @@ class TestMeasurement: def test_load_total_kwh_no_data(self, measurement_eos): """Test total load calculation with no data.""" measurement_eos.records = [] - start = datetime(2023, 1, 1, 0) - end = datetime(2023, 1, 1, 3) + start_datetime = to_datetime("2023-01-01T00:00:00") + end_datetime = to_datetime("2023-01-01T03:00:00") interval = duration(hours=1) - result = measurement_eos.load_total_kwh(start_datetime=start, end_datetime=end, interval=interval) + result = measurement_eos.load_total_kwh(start_datetime=start_datetime, end_datetime=end_datetime, interval=interval) expected = np.zeros(3) # No data, so all intervals are zero np.testing.assert_array_equal(result, expected) def test_load_total_kwh_partial_intervals(self, measurement_eos): """Test total load calculation with partial 
intervals.""" - start = datetime(2023, 1, 1, 0, 30) # Start in the middle of an interval - end = datetime(2023, 1, 1, 1, 30) # End in the middle of another interval + start_datetime = to_datetime("2023-01-01T00:30:00") # Start in the middle of an interval + end_datetime = to_datetime("2023-01-01T01:30:00") # End in the middle of another interval interval = duration(hours=1) - result = measurement_eos.load_total_kwh(start_datetime=start, end_datetime=end, interval=interval) + result = measurement_eos.load_total_kwh(start_datetime=start_datetime, end_datetime=end_datetime, interval=interval) expected = np.array([100]) # Only one complete interval covered np.testing.assert_array_equal(result, expected) diff --git a/tests/test_prediction.py b/tests/test_prediction.py index 18af477..4b241e0 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -1,6 +1,7 @@ import pytest from pydantic import ValidationError +from akkudoktoreos.core.coreabc import get_prediction from akkudoktoreos.prediction.elecpriceakkudoktor import ElecPriceAkkudoktor from akkudoktoreos.prediction.elecpriceenergycharts import ElecPriceEnergyCharts from akkudoktoreos.prediction.elecpriceimport import ElecPriceImport @@ -15,7 +16,6 @@ from akkudoktoreos.prediction.loadvrm import LoadVrm from akkudoktoreos.prediction.prediction import ( Prediction, PredictionCommonSettings, - get_prediction, ) from akkudoktoreos.prediction.pvforecastakkudoktor import PVForecastAkkudoktor from akkudoktoreos.prediction.pvforecastimport import PVForecastImport diff --git a/tests/test_predictionabc.py b/tests/test_predictionabc.py index 4bb84f1..1e9954a 100644 --- a/tests/test_predictionabc.py +++ b/tests/test_predictionabc.py @@ -7,10 +7,10 @@ import pendulum import pytest from pydantic import Field -from akkudoktoreos.core.ems import get_ems +from akkudoktoreos.core.coreabc import get_ems from akkudoktoreos.prediction.prediction import PredictionCommonSettings from akkudoktoreos.prediction.predictionabc 
import ( - PredictionBase, + PredictionABC, PredictionContainer, PredictionProvider, PredictionRecord, @@ -28,7 +28,7 @@ class DerivedConfig(PredictionCommonSettings): class_constant: Optional[int] = Field(default=None, description="Test config by class constant") -class DerivedBase(PredictionBase): +class DerivedBase(PredictionABC): instance_field: Optional[str] = Field(default=None, description="Field Value") class_constant: ClassVar[int] = 30 @@ -84,7 +84,7 @@ class DerivedPredictionContainer(PredictionContainer): # ---------- -class TestPredictionBase: +class TestPredictionABC: @pytest.fixture def base(self, monkeypatch): # Provide default values for configuration @@ -216,17 +216,19 @@ class TestPredictionProvider: def test_delete_by_datetime(self, provider, sample_start_datetime): """Test `delete_by_datetime` method for removing records by datetime range.""" # Add records to the provider for deletion testing - provider.records = [ + records = [ self.create_test_record(sample_start_datetime - to_duration("3 hours"), 1), self.create_test_record(sample_start_datetime - to_duration("1 hour"), 2), self.create_test_record(sample_start_datetime + to_duration("1 hour"), 3), ] + for record in records: + provider.insert_by_datetime(record) provider.delete_by_datetime( start_datetime=sample_start_datetime - to_duration("2 hours"), end_datetime=sample_start_datetime + to_duration("2 hours"), ) - assert len(provider.records) == 1, ( + assert len(provider) == 1, ( "Only one record should remain after deletion by datetime." 
) assert provider.records[0].date_time == sample_start_datetime - to_duration("3 hours"), ( @@ -243,15 +245,17 @@ class TestPredictionContainer: @pytest.fixture def container_with_providers(self): - record1 = self.create_test_record(datetime(2023, 11, 5), 1) - record2 = self.create_test_record(datetime(2023, 11, 6), 2) - record3 = self.create_test_record(datetime(2023, 11, 7), 3) + records = [ + # Test records - include 'prediction_value' key + self.create_test_record(datetime(2023, 11, 5), 1), + self.create_test_record(datetime(2023, 11, 6), 2), + self.create_test_record(datetime(2023, 11, 7), 3), + ] provider = DerivedPredictionProvider() - provider.clear() + provider.delete_by_datetime(start_datetime=None, end_datetime=None) assert len(provider) == 0 - provider.append(record1) - provider.append(record2) - provider.append(record3) + for record in records: + provider.insert_by_datetime(record) assert len(provider) == 3 container = DerivedPredictionContainer() container.providers.clear() @@ -378,7 +382,9 @@ class TestPredictionContainer: assert len(container_with_providers.providers) == 1 # check all keys are available (don't care for position) for key in ["prediction_value", "date_time"]: - assert key in list(container_with_providers.keys()) + assert key in container_with_providers.record_keys + for key in ["prediction_value", "date_time"]: + assert key in container_with_providers.keys() series = container_with_providers["prediction_value"] assert isinstance(series, pd.Series) assert series.name == "prediction_value" diff --git a/tests/test_pvforecastakkudoktor.py b/tests/test_pvforecastakkudoktor.py index f117ca8..0960d8a 100644 --- a/tests/test_pvforecastakkudoktor.py +++ b/tests/test_pvforecastakkudoktor.py @@ -5,8 +5,7 @@ from unittest.mock import Mock, patch import pytest from loguru import logger -from akkudoktoreos.core.ems import get_ems -from akkudoktoreos.prediction.prediction import get_prediction +from akkudoktoreos.core.coreabc import get_ems, 
get_prediction from akkudoktoreos.prediction.pvforecastakkudoktor import ( AkkudoktorForecastHorizon, AkkudoktorForecastMeta, @@ -137,7 +136,7 @@ def provider(): def provider_empty_instance(): """Fixture that returns an empty instance of PVForecast.""" empty_instance = PVForecastAkkudoktor() - empty_instance.clear() + empty_instance.delete_by_datetime(start_datetime=None, end_datetime=None) assert len(empty_instance) == 0 return empty_instance @@ -277,7 +276,7 @@ def test_pvforecast_akkudoktor_update_with_sample_forecast( ems_eos.set_start_datetime(sample_forecast_start) provider.update_data(force_enable=True, force_update=True) assert compare_datetimes(provider.ems_start_datetime, sample_forecast_start).equal - assert compare_datetimes(provider[0].date_time, to_datetime(sample_forecast_start)).equal + assert compare_datetimes(provider.records[0].date_time, to_datetime(sample_forecast_start)).equal # Report Generation Test @@ -290,7 +289,7 @@ def test_report_ac_power_and_measurement(provider, config_eos): pvforecast_dc_power=450.0, pvforecast_ac_power=400.0, ) - provider.append(record) + provider.insert_by_datetime(record) report = provider.report_ac_power_and_measurement() assert "DC: 450.0" in report @@ -323,19 +322,19 @@ def test_timezone_behaviour( expected_datetime = to_datetime("2024-10-06T00:00:00+0200", in_timezone=other_timezone) assert compare_datetimes(other_start_datetime, expected_datetime).equal - provider.clear() + provider.delete_by_datetime(start_datetime=None, end_datetime=None) assert len(provider) == 0 ems_eos = get_ems() ems_eos.set_start_datetime(other_start_datetime) provider.update_data(force_update=True) assert compare_datetimes(provider.ems_start_datetime, other_start_datetime).equal # Check wether first record starts at requested sample start time - assert compare_datetimes(provider[0].date_time, sample_forecast_start).equal + assert compare_datetimes(provider.records[0].date_time, sample_forecast_start).equal # Test updating AC power 
measurement for a specific date. provider.update_value(sample_forecast_start, "pvforecastakkudoktor_ac_power_measured", 1000) # Check wether first record was filled with ac power measurement - assert provider[0].pvforecastakkudoktor_ac_power_measured == 1000 + assert provider.records[0].pvforecastakkudoktor_ac_power_measured == 1000 # Test fetching temperature forecast for a specific date. other_end_datetime = other_start_datetime + to_duration("24 hours") diff --git a/tests/test_pvforecastimport.py b/tests/test_pvforecastimport.py index 882e901..ba554b8 100644 --- a/tests/test_pvforecastimport.py +++ b/tests/test_pvforecastimport.py @@ -1,11 +1,12 @@ import json from pathlib import Path +import numpy.testing as npt import pytest -from akkudoktoreos.core.ems import get_ems +from akkudoktoreos.core.coreabc import get_ems from akkudoktoreos.prediction.pvforecastimport import PVForecastImport -from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime +from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime, to_duration DIR_TESTDATA = Path(__file__).absolute().parent.joinpath("testdata") @@ -87,6 +88,7 @@ def test_invalid_provider(provider, config_eos): ) def test_import(provider, sample_import_1_json, start_datetime, from_file, config_eos): """Test fetching forecast from import.""" + key = "pvforecast_ac_power" ems_eos = get_ems() ems_eos.set_start_datetime(to_datetime(start_datetime, in_timezone="Europe/Berlin")) if from_file: @@ -95,7 +97,7 @@ def test_import(provider, sample_import_1_json, start_datetime, from_file, confi else: config_eos.pvforecast.provider_settings.PVForecastImport.import_file_path = None assert config_eos.pvforecast.provider_settings.PVForecastImport.import_file_path is None - provider.clear() + provider.delete_by_datetime(start_datetime=None, end_datetime=None) # Call the method provider.update_data() @@ -104,16 +106,13 @@ def test_import(provider, sample_import_1_json, start_datetime, from_file, confi 
assert provider.ems_start_datetime is not None assert provider.total_hours is not None assert compare_datetimes(provider.ems_start_datetime, ems_eos.start_datetime).equal - values = sample_import_1_json["pvforecast_ac_power"] - value_datetime_mapping = provider.import_datetimes(ems_eos.start_datetime, len(values)) - for i, mapping in enumerate(value_datetime_mapping): - assert i < len(provider.records) - expected_datetime, expected_value_index = mapping - expected_value = values[expected_value_index] - result_datetime = provider.records[i].date_time - result_value = provider.records[i]["pvforecast_ac_power"] - # print(f"{i}: Expected: {expected_datetime}:{expected_value}") - # print(f"{i}: Result: {result_datetime}:{result_value}") - assert compare_datetimes(result_datetime, expected_datetime).equal - assert result_value == expected_value + expected_values = sample_import_1_json[key] + result_values = provider.key_to_array( + key=key, + start_datetime=provider.ems_start_datetime, + end_datetime=provider.ems_start_datetime + to_duration(f"{len(expected_values)} hours"), + interval=to_duration("1 hour"), + ) + # Allow for some difference due to value calculation on DST change + npt.assert_allclose(result_values, expected_values, rtol=0.001) diff --git a/tests/test_retentionmanager.py b/tests/test_retentionmanager.py new file mode 100644 index 0000000..8489642 --- /dev/null +++ b/tests/test_retentionmanager.py @@ -0,0 +1,701 @@ +"""Tests for RetentionManager and JobState.""" + +from __future__ import annotations + +import asyncio +import time +from typing import Any +from unittest.mock import AsyncMock, MagicMock, call, patch + +import pytest +from loguru import logger + +import akkudoktoreos.server.retentionmanager +from akkudoktoreos.server.retentionmanager import JobState, RetentionManager + +# --------------------------------------------------------------------------- +# Shared helpers +# --------------------------------------------------------------------------- 
+ +INTERVAL = 10.0 +DUE_INTERVAL = 0.001 # non-zero so interval() does not fall back to fallback_interval +FALLBACK = 300.0 + + +def make_config_getter(interval: float = INTERVAL) -> Any: + """Return a simple config getter that always yields ``interval`` for any key.""" + return lambda key: interval + + +def make_config_getter_none() -> Any: + """Return a config getter that always yields ``None`` (job disabled).""" + return lambda key: None + + +def make_manager(interval: float = INTERVAL, shutdown_timeout: float = 5.0) -> RetentionManager: + """Return a ``RetentionManager`` backed by a fixed-interval config getter.""" + return RetentionManager(make_config_getter(interval), shutdown_timeout=shutdown_timeout) + + +def make_manager_none(shutdown_timeout: float = 5.0) -> RetentionManager: + """Return a ``RetentionManager`` whose config getter always returns None (all jobs disabled).""" + return RetentionManager(make_config_getter_none(), shutdown_timeout=shutdown_timeout) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestRetentionManager: + """Tests for :class:`RetentionManager` and :class:`JobState`.""" + + # ------------------------------------------------------------------ + # Initialisation + # ------------------------------------------------------------------ + + def test_init_stores_config_getter(self) -> None: + """The config getter passed to __init__ is stored and forwarded to jobs.""" + getter = make_config_getter() + manager = RetentionManager(getter) + assert manager._config_getter is getter + + def test_init_empty_job_registry(self) -> None: + """A newly created manager has no registered jobs.""" + manager = make_manager() + assert manager._jobs == {} + + # ------------------------------------------------------------------ + # register / unregister + # ------------------------------------------------------------------ + + 
def test_register_adds_job(self) -> None: + """Registering a function adds a JobState entry.""" + manager = make_manager() + func = MagicMock() + manager.register("job1", func, interval_attr="some/key") + assert "job1" in manager._jobs + + def test_register_job_state_fields(self) -> None: + """Registered JobState carries the correct initial field values.""" + manager = make_manager() + func = MagicMock() + manager.register("job1", func, interval_attr="some/key", fallback_interval=60.0) + job = manager._jobs["job1"] + assert job.name == "job1" + assert job.func is func + assert job.interval_attr == "some/key" + assert job.fallback_interval == 60.0 + assert job.config_getter is manager._config_getter + assert job.on_exception is None + assert job.last_run_at == 0.0 + assert job.run_count == 0 + assert job.is_running is False + + def test_register_stores_on_exception(self) -> None: + """The on_exception callback is stored on the JobState.""" + manager = make_manager() + handler = MagicMock() + manager.register("job1", MagicMock(), interval_attr="k", on_exception=handler) + assert manager._jobs["job1"].on_exception is handler + + def test_register_duplicate_raises(self) -> None: + """Registering the same name twice raises ValueError.""" + manager = make_manager() + manager.register("job1", MagicMock(), interval_attr="k") + with pytest.raises(ValueError, match="job1"): + manager.register("job1", MagicMock(), interval_attr="k") + + def test_unregister_removes_job(self) -> None: + """Unregistering a job removes it from the registry.""" + manager = make_manager() + manager.register("job1", MagicMock(), interval_attr="k") + manager.unregister("job1") + assert "job1" not in manager._jobs + + def test_unregister_missing_job_is_noop(self) -> None: + """Unregistering a non-existent job does not raise.""" + manager = make_manager() + manager.unregister("nonexistent") # must not raise + + # ------------------------------------------------------------------ + # JobState.interval() 
+ # ------------------------------------------------------------------ + + def test_job_interval_from_config_getter(self) -> None: + """JobState.interval() returns the value provided by config_getter.""" + manager = make_manager(interval=42.0) + manager.register("job1", MagicMock(), interval_attr="k") + assert manager._jobs["job1"].interval() == 42.0 + + def test_job_interval_none_when_config_returns_none(self) -> None: + """JobState.interval() returns None when config_getter returns None (job disabled).""" + manager = make_manager_none() + manager.register("job1", MagicMock(), interval_attr="k", fallback_interval=FALLBACK) + assert manager._jobs["job1"].interval() is None + + def test_job_interval_none_does_not_fall_back(self) -> None: + """A None config value must NOT fall back to fallback_interval -- None means disabled.""" + manager = make_manager_none() + manager.register("job1", MagicMock(), interval_attr="k", fallback_interval=99.0) + # If None incorrectly fell back, this would return 99.0 instead of None + assert manager._jobs["job1"].interval() is None + + def test_job_interval_fallback_on_key_error(self) -> None: + """JobState.interval() uses fallback_interval when config_getter raises KeyError.""" + manager = RetentionManager(lambda key: (_ for _ in ()).throw(KeyError(key))) + manager.register("job1", MagicMock(), interval_attr="k", fallback_interval=99.0) + assert manager._jobs["job1"].interval() == 99.0 + + def test_job_interval_fallback_on_index_error(self) -> None: + """JobState.interval() uses fallback_interval when config_getter raises IndexError.""" + manager = RetentionManager(lambda key: (_ for _ in ()).throw(IndexError())) + manager.register("job1", MagicMock(), interval_attr="k", fallback_interval=77.0) + assert manager._jobs["job1"].interval() == 77.0 + + def test_job_interval_fallback_on_zero_value(self) -> None: + """JobState.interval() uses fallback_interval when config_getter returns zero.""" + manager = RetentionManager(lambda key: 0) + 
manager.register("job1", MagicMock(), interval_attr="k", fallback_interval=55.0) + assert manager._jobs["job1"].interval() == 55.0 + + # ------------------------------------------------------------------ + # JobState.is_due() + # ------------------------------------------------------------------ + + def test_job_is_due_when_never_run(self) -> None: + """A job is always due when it has never been run (last_run_at == 0.0).""" + manager = make_manager(interval=INTERVAL) + manager.register("job1", MagicMock(), interval_attr="k") + assert manager._jobs["job1"].is_due() is True + + def test_job_is_not_due_immediately_after_run(self) -> None: + """A job is not due immediately after last_run_at is set to now.""" + manager = make_manager(interval=INTERVAL) + manager.register("job1", MagicMock(), interval_attr="k") + manager._jobs["job1"].last_run_at = time.monotonic() + assert manager._jobs["job1"].is_due() is False + + def test_job_is_due_after_interval_elapsed(self) -> None: + """A job becomes due once the interval has passed since last_run_at.""" + manager = make_manager(interval=1.0) + manager.register("job1", MagicMock(), interval_attr="k") + manager._jobs["job1"].last_run_at = time.monotonic() - 2.0 # 2 s ago > 1 s interval + assert manager._jobs["job1"].is_due() is True + + def test_job_is_never_due_when_interval_is_none(self) -> None: + """is_due() returns False when interval() is None, even if last_run_at is 0.""" + manager = make_manager_none() + manager.register("job1", MagicMock(), interval_attr="k") + job = manager._jobs["job1"] + # last_run_at == 0.0 would make any enabled job due immediately + assert job.last_run_at == 0.0 + assert job.is_due() is False + + def test_job_is_never_due_when_disabled_regardless_of_last_run(self) -> None: + """is_due() stays False for a disabled job even long after its last run.""" + manager = make_manager_none() + manager.register("job1", MagicMock(), interval_attr="k") + job = manager._jobs["job1"] + job.last_run_at = 
time.monotonic() - 365 * 24 * 3600 # "ran" a year ago + assert job.is_due() is False + + # ------------------------------------------------------------------ + # JobState.summary() + # ------------------------------------------------------------------ + + def test_summary_keys(self) -> None: + """summary() returns all expected keys including interval_s.""" + manager = make_manager() + manager.register("job1", MagicMock(), interval_attr="k") + summary = manager._jobs["job1"].summary() + assert set(summary.keys()) == { + "name", "interval_attr", "interval_s", "last_run_at", + "last_duration_s", "last_error", "run_count", "is_running", + } + + def test_summary_interval_s_reflects_config(self) -> None: + """summary()['interval_s'] matches the value returned by interval().""" + manager = make_manager(interval=42.0) + manager.register("job1", MagicMock(), interval_attr="k") + assert manager._jobs["job1"].summary()["interval_s"] == 42.0 + + def test_summary_interval_s_is_none_when_disabled(self) -> None: + """summary()['interval_s'] is None when the job is disabled via config.""" + manager = make_manager_none() + manager.register("job1", MagicMock(), interval_attr="k") + assert manager._jobs["job1"].summary()["interval_s"] is None + + def test_summary_values(self) -> None: + """summary() reflects the current JobState values.""" + manager = make_manager() + manager.register("job1", MagicMock(), interval_attr="my/key") + job = manager._jobs["job1"] + job.last_run_at = 1234.5 + job.last_duration = 0.12345 + job.last_error = "oops" + job.run_count = 3 + job.is_running = True + s = job.summary() + assert s["name"] == "job1" + assert s["interval_attr"] == "my/key" + assert s["last_run_at"] == 1234.5 + assert s["last_duration_s"] == 0.1235 # rounded to 4 dp + assert s["last_error"] == "oops" + assert s["run_count"] == 3 + assert s["is_running"] is True + + # ------------------------------------------------------------------ + # status() + # 
------------------------------------------------------------------ + + def test_status_empty(self) -> None: + """status() returns an empty list when no jobs are registered.""" + assert make_manager().status() == [] + + def test_status_contains_all_jobs(self) -> None: + """status() returns one entry per registered job.""" + manager = make_manager() + manager.register("a", MagicMock(), interval_attr="k1") + manager.register("b", MagicMock(), interval_attr="k2") + names = {s["name"] for s in manager.status()} + assert names == {"a", "b"} + + def test_status_shows_disabled_job(self) -> None: + """status() includes disabled jobs with interval_s == None.""" + manager = make_manager_none() + manager.register("disabled", MagicMock(), interval_attr="k") + entries = manager.status() + assert len(entries) == 1 + assert entries[0]["interval_s"] is None + + # ------------------------------------------------------------------ + # tick() -- job dispatch + # ------------------------------------------------------------------ + + @pytest.mark.asyncio + async def test_tick_runs_due_sync_job(self) -> None: + """tick() executes a sync job that is due.""" + manager = make_manager(interval=DUE_INTERVAL) + func = MagicMock() + manager.register("job1", func, interval_attr="k") + await manager.tick() + await asyncio.sleep(0) # yield so ensure_future tasks are scheduled + await asyncio.sleep(0) # second yield ensures tasks have started + await manager.shutdown() + func.assert_called_once() + + @pytest.mark.asyncio + async def test_tick_runs_due_async_job(self) -> None: + """tick() executes an async job that is due.""" + manager = make_manager(interval=DUE_INTERVAL) + func = AsyncMock() + manager.register("job1", func, interval_attr="k") + await manager.tick() + await asyncio.sleep(0) # yield so ensure_future tasks are scheduled + await asyncio.sleep(0) # second yield ensures tasks have started + await manager.shutdown() + func.assert_called_once() + + @pytest.mark.asyncio + async def 
test_tick_skips_not_due_job(self) -> None: + """tick() does not execute a job whose interval has not yet elapsed.""" + manager = make_manager(interval=9999.0) + func = MagicMock() + manager.register("job1", func, interval_attr="k") + manager._jobs["job1"].last_run_at = time.monotonic() # just ran + await manager.tick() + await asyncio.sleep(0) + await manager.shutdown() + func.assert_not_called() + + @pytest.mark.asyncio + async def test_tick_skips_disabled_job(self) -> None: + """tick() never executes a job whose interval is None, even if never run before.""" + manager = make_manager_none() + func = MagicMock() + manager.register("disabled", func, interval_attr="k") + job = manager._jobs["disabled"] + # last_run_at == 0.0 would fire any enabled job immediately + assert job.last_run_at == 0.0 + await manager.tick() + await asyncio.sleep(0) + await manager.shutdown() + func.assert_not_called() + + @pytest.mark.asyncio + async def test_tick_skips_disabled_job_adds_no_task(self) -> None: + """tick() adds no task to _running_tasks for a disabled job.""" + manager = make_manager_none() + manager.register("disabled", AsyncMock(), interval_attr="k") + await manager.tick() + await asyncio.sleep(0) + assert len(manager._running_tasks) == 0 + + @pytest.mark.asyncio + async def test_tick_enabled_and_disabled_jobs_mixed(self) -> None: + """tick() fires enabled jobs and silently skips disabled ones in the same manager.""" + results: list[str] = [] + + async def enabled_job() -> None: + results.append("ran") + + manager = RetentionManager( + lambda key: DUE_INTERVAL if key == "enabled/interval" else None, + shutdown_timeout=5.0, + ) + manager.register("enabled", enabled_job, interval_attr="enabled/interval") + manager.register("disabled", AsyncMock(), interval_attr="disabled/interval") + + await manager.tick() + await asyncio.sleep(0) + await asyncio.sleep(0) + await manager.shutdown() + + assert results == ["ran"], "Only the enabled job must have run" + + @pytest.mark.asyncio + 
async def test_tick_skips_already_running_job(self) -> None: + """tick() does not start a job that is still marked as running.""" + manager = make_manager(interval=DUE_INTERVAL) + func = MagicMock() + manager.register("job1", func, interval_attr="k") + manager._jobs["job1"].is_running = True + await manager.tick() + await asyncio.sleep(0) + await manager.shutdown() + func.assert_not_called() + + @pytest.mark.asyncio + async def test_tick_runs_multiple_jobs_concurrently(self) -> None: + """tick() fires all due jobs as independent tasks.""" + manager = make_manager(interval=DUE_INTERVAL) + results: list[str] = [] + + async def job_a() -> None: + results.append("a") + + async def job_b() -> None: + results.append("b") + + manager.register("a", job_a, interval_attr="k") + manager.register("b", job_b, interval_attr="k") + await manager.tick() + await asyncio.sleep(0) # yield so ensure_future tasks are scheduled + await asyncio.sleep(0) # second yield ensures tasks have started + await manager.shutdown() + assert sorted(results) == ["a", "b"] + + @pytest.mark.asyncio + async def test_tick_adds_tasks_to_running_set(self) -> None: + """tick() adds a task to _running_tasks for each due job.""" + barrier = asyncio.Event() + manager = make_manager(interval=DUE_INTERVAL) + + async def blocking_job() -> None: + await barrier.wait() + + manager.register("job1", blocking_job, interval_attr="k") + await manager.tick() + await asyncio.sleep(0) # yield so ensure_future tasks are scheduled + await asyncio.sleep(0) # second yield ensures tasks have started + # Task is still running (barrier not set), so it must be in the set. 
+ assert len(manager._running_tasks) == 1 + barrier.set() + await manager.shutdown() + + @pytest.mark.asyncio + async def test_tick_removes_task_from_running_set_on_completion(self) -> None: + """Completed tasks are removed from _running_tasks automatically.""" + manager = make_manager(interval=DUE_INTERVAL) + manager.register("job1", AsyncMock(), interval_attr="k") + await manager.tick() + await asyncio.sleep(0) # yield so ensure_future tasks are scheduled + await manager.shutdown() + assert len(manager._running_tasks) == 0 + + # ------------------------------------------------------------------ + # shutdown() + # ------------------------------------------------------------------ + + @pytest.mark.asyncio + async def test_shutdown_returns_immediately_when_no_tasks(self) -> None: + """shutdown() completes without blocking when no tasks are running.""" + manager = make_manager() + await manager.shutdown() # must return promptly without raising + + @pytest.mark.asyncio + async def test_shutdown_waits_for_in_flight_task(self) -> None: + """shutdown() blocks until a long-running job task finishes.""" + barrier = asyncio.Event() + finished: list[bool] = [] + manager = make_manager(interval=DUE_INTERVAL) + + async def slow_job() -> None: + await barrier.wait() + finished.append(True) + + manager.register("job1", slow_job, interval_attr="k") + await manager.tick() + await asyncio.sleep(0) # yield so ensure_future tasks are scheduled + await asyncio.sleep(0) # second yield ensures tasks have started + assert finished == [] # job still blocked + barrier.set() + await manager.shutdown() + assert finished == [True] # job completed before shutdown returned + + @pytest.mark.asyncio + async def test_shutdown_waits_for_multiple_in_flight_tasks(self) -> None: + """shutdown() waits for all concurrently running job tasks.""" + barrier = asyncio.Event() + finished: list[str] = [] + manager = make_manager(interval=DUE_INTERVAL) + + async def slow_a() -> None: + await barrier.wait() + 
finished.append("a") + + async def slow_b() -> None: + await barrier.wait() + finished.append("b") + + manager.register("a", slow_a, interval_attr="k") + manager.register("b", slow_b, interval_attr="k") + await manager.tick() + await asyncio.sleep(0) # yield so ensure_future tasks are scheduled + await asyncio.sleep(0) # second yield ensures tasks have started + assert finished == [] + barrier.set() + await manager.shutdown() + assert sorted(finished) == ["a", "b"] + + @pytest.mark.asyncio + async def test_shutdown_does_not_raise_when_job_failed(self) -> None: + """shutdown() completes without raising even if a job task raised an exception.""" + manager = make_manager(interval=DUE_INTERVAL) + + def failing_func() -> None: + raise RuntimeError("job error") + + manager.register("job1", failing_func, interval_attr="k") + await manager.tick() + await asyncio.sleep(0) # yield so ensure_future tasks are scheduled + await manager.shutdown() # must not raise + + @pytest.mark.asyncio + async def test_shutdown_clears_running_tasks_set(self) -> None: + """_running_tasks is empty after shutdown() completes.""" + manager = make_manager(interval=DUE_INTERVAL) + manager.register("job1", AsyncMock(), interval_attr="k") + await manager.tick() + await asyncio.sleep(0) # yield so ensure_future tasks are scheduled + await manager.shutdown() + assert manager._running_tasks == set() + + @pytest.mark.asyncio + async def test_shutdown_timeout_returns_without_blocking(self) -> None: + """shutdown() returns once the timeout elapses even if a job is still running.""" + stuck = asyncio.Event() # never set -- job blocks forever + manager = RetentionManager(make_config_getter(DUE_INTERVAL), shutdown_timeout=0.05) + + async def forever() -> None: + await stuck.wait() + + manager.register("stuck", forever, interval_attr="k") + await manager.tick() + await asyncio.sleep(0) + await asyncio.sleep(0) + # Must return within the timeout, not block forever. 
+ await manager.shutdown() + + @pytest.mark.asyncio + async def test_shutdown_timeout_logs_error_for_pending_jobs(self) -> None: + """An error is logged listing jobs still running after the timeout.""" + stuck = asyncio.Event() + manager = RetentionManager(make_config_getter(DUE_INTERVAL), shutdown_timeout=0.05) + + async def forever() -> None: + await stuck.wait() + + manager.register("stuck_job", forever, interval_attr="k") + await manager.tick() + await asyncio.sleep(0) + await asyncio.sleep(0) + + with patch.object(logger, "error") as mock_error: + await manager.shutdown() + assert mock_error.called, "Expected logger.error to be called on timeout" + # All positional args joined: the stuck job name must appear. + logged = str(mock_error.call_args_list) + assert "stuck_job" in logged + + @pytest.mark.asyncio + async def test_shutdown_timeout_clears_running_tasks_set(self) -> None: + """_running_tasks is cleared even when the timeout elapses.""" + stuck = asyncio.Event() + manager = RetentionManager(make_config_getter(DUE_INTERVAL), shutdown_timeout=0.05) + + async def forever() -> None: + await stuck.wait() + + manager.register("stuck", forever, interval_attr="k") + await manager.tick() + await asyncio.sleep(0) + await asyncio.sleep(0) + await manager.shutdown() + assert manager._running_tasks == set() + + @pytest.mark.asyncio + async def test_shutdown_no_error_logged_when_all_finish_in_time(self) -> None: + """No error is logged when all tasks complete within the timeout.""" + manager = RetentionManager(make_config_getter(DUE_INTERVAL), shutdown_timeout=5.0) + manager.register("job1", AsyncMock(), interval_attr="k") + await manager.tick() + await asyncio.sleep(0) + + with patch.object(logger, "error") as mock_error: + await manager.shutdown() + mock_error.assert_not_called() + + def test_init_stores_shutdown_timeout(self) -> None: + """The shutdown_timeout passed to __init__ is stored on the instance.""" + manager = RetentionManager(make_config_getter(), 
shutdown_timeout=99.0) + assert manager._shutdown_timeout == 99.0 + + def test_init_default_shutdown_timeout(self) -> None: + """The default shutdown_timeout is 30 seconds.""" + manager = RetentionManager(make_config_getter()) + assert manager._shutdown_timeout == 30.0 + + # ------------------------------------------------------------------ + # _run_job() -- state updates + # ------------------------------------------------------------------ + + @pytest.mark.asyncio + async def test_run_job_increments_run_count(self) -> None: + """_run_job() increments run_count after each execution.""" + manager = make_manager() + manager.register("job1", MagicMock(), interval_attr="k") + job = manager._jobs["job1"] + await manager._run_job(job) + await manager._run_job(job) + assert job.run_count == 2 + + @pytest.mark.asyncio + async def test_run_job_updates_last_run_at(self) -> None: + """_run_job() sets last_run_at to a recent monotonic timestamp.""" + manager = make_manager() + manager.register("job1", MagicMock(), interval_attr="k") + job = manager._jobs["job1"] + before = time.monotonic() + await manager._run_job(job) + assert job.last_run_at >= before + + @pytest.mark.asyncio + async def test_run_job_updates_last_duration(self) -> None: + """_run_job() records a non-negative last_duration.""" + manager = make_manager() + manager.register("job1", MagicMock(), interval_attr="k") + job = manager._jobs["job1"] + await manager._run_job(job) + assert job.last_duration >= 0.0 + + @pytest.mark.asyncio + async def test_run_job_clears_is_running_on_success(self) -> None: + """is_running is False after a successful job execution.""" + manager = make_manager() + manager.register("job1", MagicMock(), interval_attr="k") + job = manager._jobs["job1"] + await manager._run_job(job) + assert job.is_running is False + + @pytest.mark.asyncio + async def test_run_job_clears_last_error_on_success(self) -> None: + """last_error is set to None after a successful execution.""" + manager = 
make_manager() + manager.register("job1", MagicMock(), interval_attr="k") + job = manager._jobs["job1"] + job.last_error = "stale error" + await manager._run_job(job) + assert job.last_error is None + + # ------------------------------------------------------------------ + # _run_job() -- exception handling + # ------------------------------------------------------------------ + + @pytest.mark.asyncio + async def test_run_job_stores_exception_message(self) -> None: + """last_error is set to the exception message when the job raises.""" + manager = make_manager() + + def failing_func() -> None: + raise RuntimeError("boom") + + manager.register("job1", failing_func, interval_attr="k") + job = manager._jobs["job1"] + await manager._run_job(job) + assert job.last_error == "boom" + + @pytest.mark.asyncio + async def test_run_job_still_updates_state_after_exception(self) -> None: + """run_count and last_run_at are updated even when the job raises.""" + manager = make_manager() + + def failing_func() -> None: + raise RuntimeError("boom") + + manager.register("job1", failing_func, interval_attr="k") + job = manager._jobs["job1"] + before = time.monotonic() + await manager._run_job(job) + assert job.run_count == 1 + assert job.last_run_at >= before + assert job.is_running is False + + @pytest.mark.asyncio + async def test_run_job_calls_sync_on_exception_handler(self) -> None: + """A sync on_exception handler is called with the raised exception.""" + manager = make_manager() + handler = MagicMock() + exc = RuntimeError("oops") + + def failing_func() -> None: + raise exc + + manager.register("job1", failing_func, interval_attr="k", on_exception=handler) + await manager._run_job(manager._jobs["job1"]) + handler.assert_called_once_with(exc) + + @pytest.mark.asyncio + async def test_run_job_calls_async_on_exception_handler(self) -> None: + """An async on_exception handler is awaited with the raised exception.""" + manager = make_manager() + handler = AsyncMock() + exc = 
RuntimeError("oops") + + def failing_func() -> None: + raise exc + + manager.register("job1", failing_func, interval_attr="k", on_exception=handler) + await manager._run_job(manager._jobs["job1"]) + handler.assert_called_once_with(exc) + + @pytest.mark.asyncio + async def test_run_job_no_on_exception_handler_does_not_raise(self) -> None: + """A failing job without on_exception does not propagate the exception.""" + manager = make_manager() + + def failing_func() -> None: + raise RuntimeError("silent failure") + + manager.register("job1", failing_func, interval_attr="k") + await manager._run_job(manager._jobs["job1"]) # must not raise + + @pytest.mark.asyncio + async def test_run_job_on_exception_not_called_on_success(self) -> None: + """on_exception is not called when the job succeeds.""" + manager = make_manager() + handler = MagicMock() + manager.register("job1", MagicMock(), interval_attr="k", on_exception=handler) + await manager._run_job(manager._jobs["job1"]) + handler.assert_not_called() diff --git a/tests/test_version.py b/tests/test_version.py index e4da105..c175235 100644 --- a/tests/test_version.py +++ b/tests/test_version.py @@ -2,11 +2,22 @@ import subprocess import sys from pathlib import Path +from typing import Optional, Union import pytest import yaml -from akkudoktoreos.core.version import _version_calculate, _version_hash +from akkudoktoreos.core.version import ( + ALLOWED_SUFFIXES, + DIR_PACKAGE_ROOT, + EXCLUDED_DIR_PATTERNS, + EXCLUDED_FILES, + HashConfig, + _version_calculate, + _version_hash, + collect_files, + hash_files, +) DIR_PROJECT_ROOT = Path(__file__).parent.parent GET_VERSION_SCRIPT = DIR_PROJECT_ROOT / "scripts" / "get_version.py" @@ -14,11 +25,166 @@ BUMP_DEV_SCRIPT = DIR_PROJECT_ROOT / "scripts" / "bump_dev_version.py" UPDATE_SCRIPT = DIR_PROJECT_ROOT / "scripts" / "update_version.py" +# --- Git helpers --- + +def get_git_tracked_files(repo_path: Path) -> Optional[set[Path]]: + """Get set of all files tracked by git in the 
repository. + + Returns None if not a git repository or git command fails. + """ + try: + result = subprocess.run( + ["git", "ls-files"], + cwd=repo_path, + capture_output=True, + text=True, + check=True + ) + # Convert relative paths to absolute paths + tracked_files = { + (repo_path / line.strip()).resolve() + for line in result.stdout.splitlines() + if line.strip() + } + return tracked_files + except (subprocess.CalledProcessError, FileNotFoundError): + return None + + +def is_git_repository(path: Path) -> bool: + """Check if path is inside a git repository.""" + try: + subprocess.run( + ["git", "rev-parse", "--git-dir"], + cwd=path, + capture_output=True, + check=True + ) + return True + except (subprocess.CalledProcessError, FileNotFoundError): + return False + + +def get_git_root(path: Path) -> Optional[Path]: + """Get the root directory of the git repository containing path.""" + try: + result = subprocess.run( + ["git", "rev-parse", "--show-toplevel"], + cwd=path, + capture_output=True, + text=True, + check=True + ) + return Path(result.stdout.strip()) + except (subprocess.CalledProcessError, FileNotFoundError): + return None + + +def check_files_in_git( + files: list[Path], + base_path: Optional[Path] = None +) -> tuple[list[Path], list[Path]]: + """Check which files are tracked by git. + + Args: + files: List of files to check + base_path: Base path to check for git repository (uses first file's parent if None) + + Returns: + Tuple of (tracked_files, untracked_files) + + Example: + >>> files = collect_files(config) + >>> tracked, untracked = check_files_in_git(files) + >>> if untracked: + ... 
print(f"Warning: {len(untracked)} files not in git") + """ + if not files: + return [], [] + + check_path = base_path or files[0].parent + + assert is_git_repository(check_path) + + git_root = get_git_root(check_path) + if not git_root: + return [], files + + git_tracked = get_git_tracked_files(git_root) + if git_tracked is None: + return [], files + + tracked = [f for f in files if f in git_tracked] + untracked = [f for f in files if f not in git_tracked] + + return tracked, untracked + + # --- Helper to create test files --- def write_file(path: Path, content: str): path.write_text(content, encoding="utf-8") return path +# -- Test version calculation --- + +def test_version_hash() -> None: + """Test which files are used for version hash calculation.""" + + watched_paths = [DIR_PACKAGE_ROOT] + + # Collect files + config = HashConfig( + paths=watched_paths, + allowed_suffixes=ALLOWED_SUFFIXES, + excluded_dir_patterns=EXCLUDED_DIR_PATTERNS, + excluded_files=EXCLUDED_FILES + ) + + files = collect_files(config) + hash_digest = hash_files(files) + + # Check git + tracked, untracked = check_files_in_git(files, DIR_PACKAGE_ROOT) + tracked_files: list[Path] = tracked + untracked_files: list[Path] = untracked + + if untracked_files: + error_msg = f"\n{'='*60}" + error_msg += f"Version Hash Inspection" + error_msg += f"{'='*60}\n" + error_msg += f"Hash: {hash_digest}" + error_msg += f"Based on {len(files)} files:\n" + + error_msg += f"OK: {len(tracked_files)} files tracked by git:\n" + for i, file_path in enumerate(files, 1): + try: + rel_path = file_path.relative_to(DIR_PACKAGE_ROOT) + status = "" + if file_path in untracked_files: + continue + elif file_path in tracked_files: + status = " [tracked]" + error_msg += f" {i:3d}. {rel_path}{status}\n" + except ValueError: + error_msg += f" {i:3d}. 
{file_path}\n" + + error_msg += f"Warning: {len(untracked_files)} files not tracked by git:\n" + for i, file_path in enumerate(files, 1): + try: + rel_path = file_path.relative_to(DIR_PACKAGE_ROOT) + status = "" + if file_path in untracked_files: + status = " [NOT IN GIT]" + elif file_path in tracked_files: + continue + error_msg += f" {i:3d}. {rel_path}{status}\n" + except ValueError: + error_msg += f" {i:3d}. {file_path}\n" + + error_msg += f"\n{'='*60}\n" + + pytest.fail(error_msg) + # --- Test version helpers --- def test_version_non_dev(monkeypatch): @@ -38,7 +204,7 @@ def test_version_dev_precision_8(monkeypatch): result = _version_calculate() - # compute expected suffix + # Compute expected suffix using the same logic as _version_calculate hash_value = int(fake_hash, 16) expected_digits = str(hash_value % (10 ** 8)).zfill(8) @@ -60,12 +226,17 @@ def test_version_dev_precision_8_different_hash(monkeypatch): result = _version_calculate() + # Compute expected suffix using the same logic as _version_calculate hash_value = int(fake_hash, 16) expected_digits = str(hash_value % (10 ** 8)).zfill(8) + expected = f"0.2.0.dev{expected_digits}" assert result == expected assert len(expected_digits) == 8 + assert result.startswith("0.2.0.dev") + assert result == expected + # --- 1️⃣ Test get_version.py --- diff --git a/tests/test_weatherbrightsky.py b/tests/test_weatherbrightsky.py index d0b2bf0..7094a52 100644 --- a/tests/test_weatherbrightsky.py +++ b/tests/test_weatherbrightsky.py @@ -6,7 +6,7 @@ import pandas as pd import pytest from akkudoktoreos.core.cache import CacheFileStore -from akkudoktoreos.core.ems import get_ems +from akkudoktoreos.core.coreabc import get_ems from akkudoktoreos.prediction.weatherbrightsky import WeatherBrightSky from akkudoktoreos.utils.datetimeutil import to_datetime diff --git a/tests/test_weatherclearoutside.py b/tests/test_weatherclearoutside.py index 7f1292c..8282b90 100644 --- a/tests/test_weatherclearoutside.py +++ 
b/tests/test_weatherclearoutside.py @@ -10,7 +10,7 @@ import pytest from bs4 import BeautifulSoup from akkudoktoreos.core.cache import CacheFileStore -from akkudoktoreos.core.ems import get_ems +from akkudoktoreos.core.coreabc import get_ems from akkudoktoreos.prediction.weatherclearoutside import WeatherClearOutside from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime diff --git a/tests/test_weatherimport.py b/tests/test_weatherimport.py index 6125cd1..19ccf2a 100644 --- a/tests/test_weatherimport.py +++ b/tests/test_weatherimport.py @@ -1,11 +1,12 @@ import json from pathlib import Path +import numpy.testing as npt import pytest -from akkudoktoreos.core.ems import get_ems +from akkudoktoreos.core.coreabc import get_ems from akkudoktoreos.prediction.weatherimport import WeatherImport -from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime +from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime, to_duration DIR_TESTDATA = Path(__file__).absolute().parent.joinpath("testdata") @@ -87,6 +88,7 @@ def test_invalid_provider(provider, config_eos, monkeypatch): ) def test_import(provider, sample_import_1_json, start_datetime, from_file, config_eos): """Test fetching forecast from Import.""" + key = "weather_temp_air" ems_eos = get_ems() ems_eos.set_start_datetime(to_datetime(start_datetime, in_timezone="Europe/Berlin")) if from_file: @@ -95,7 +97,7 @@ def test_import(provider, sample_import_1_json, start_datetime, from_file, confi else: config_eos.weather.provider_settings.WeatherImport.import_file_path = None assert config_eos.weather.provider_settings.WeatherImport.import_file_path is None - provider.clear() + provider.delete_by_datetime(start_datetime=None, end_datetime=None) # Call the method provider.update_data() @@ -104,16 +106,13 @@ def test_import(provider, sample_import_1_json, start_datetime, from_file, confi assert provider.ems_start_datetime is not None assert provider.total_hours is not None 
assert compare_datetimes(provider.ems_start_datetime, ems_eos.start_datetime).equal - values = sample_import_1_json["weather_temp_air"] - value_datetime_mapping = provider.import_datetimes(ems_eos.start_datetime, len(values)) - for i, mapping in enumerate(value_datetime_mapping): - assert i < len(provider.records) - expected_datetime, expected_value_index = mapping - expected_value = values[expected_value_index] - result_datetime = provider.records[i].date_time - result_value = provider.records[i]["weather_temp_air"] - # print(f"{i}: Expected: {expected_datetime}:{expected_value}") - # print(f"{i}: Result: {result_datetime}:{result_value}") - assert compare_datetimes(result_datetime, expected_datetime).equal - assert result_value == expected_value + expected_values = sample_import_1_json[key] + result_values = provider.key_to_array( + key=key, + start_datetime=provider.ems_start_datetime, + end_datetime=provider.ems_start_datetime + to_duration(f"{len(expected_values)} hours"), + interval=to_duration("1 hour"), + ) + # Allow for some difference due to value calculation on DST change + npt.assert_allclose(result_values, expected_values, rtol=0.001) diff --git a/tests/testdata/eos_config_andreas_now.json b/tests/testdata/eos_config_andreas_now.json index e7847a5..3391bec 100644 --- a/tests/testdata/eos_config_andreas_now.json +++ b/tests/testdata/eos_config_andreas_now.json @@ -1,6 +1,6 @@ { "general": { - "data_folder_path": null, + "data_folder_path": "__ANY__", "data_output_subpath": "output", "latitude": 52.5, "longitude": 13.4