Mirror of https://github.com/Akkudoktor-EOS/EOS.git (synced 2025-03-18 18:44:04 +00:00)
Fix2 config and predictions revamp. (#281)
measurement:
- Add new measurement class to hold real-world measurements.
- Handles load meter readings, grid import and export meter readings.
- Aggregates load meter readings (measurements) to a total load.
- Can import measurements from files, pandas datetime series, pandas datetime dataframes, simple datetime arrays and programmatically.
- May be expanded to other measurement values.
- Should be used to adapt load predictions with real-world measurements.

core/coreabc:
- Add mixin class to access measurements.

core/pydantic:
- Add pydantic models for pandas datetime series and dataframes.
- Add pydantic models for simple datetime arrays.

core/dataabc:
- Provide DataImportMixin class for generic import handling. Imports from JSON strings and files, from pandas datetime dataframes and from simple datetime arrays. The signature of the import method changed to allow import datetimes to be given programmatically and by data content.
- Use pydantic models for datetime series, dataframes and arrays.
- Validate generic imports by pydantic models.
- Provide new attributes min_datetime and max_datetime for DataSequence.
- Add parameter dropna to drop NaN/None values when creating lists, pandas series or numpy arrays from a DataSequence.

config/config:
- Add common settings for the measurement module.

predictions/elecpriceakkudoktor:
- Use mean values of the last 7 days to fill prediction values not provided by akkudoktor.net (which only provides 24 values).

prediction/loadabc:
- Extend the generic prediction keys by 'load_total_adjusted' for load predictions that adjust the predicted total load by measured load values.

prediction/loadakkudoktor:
- Extend the Akkudoktor load prediction by load adjustment using measured load values.

prediction/load_aggregator:
- Module removed. Load aggregation is now handled by the measurement module.

prediction/load_corrector:
- Module removed. Load correction (i.e. adjustment of the load prediction by measured load energy) is handled by the LoadAkkudoktor prediction and the generic 'load_mean_adjusted' prediction key.

prediction/load_forecast:
- Module removed. Functionality is now completely handled by the LoadAkkudoktor prediction.

utils/cacheutil:
- Use pydantic.
- Fix potential bug in TTL (time to live) duration handling.

utils/datetimeutil:
- Add missing handling of pendulum.DateTime and pendulum.Duration instances as input; they were previously handled as datetime.datetime and datetime.timedelta.

utils/visualize:
- Move main to generate_example_report() for better testing support.

server/server:
- Add new configuration option server_fastapi_startup_server_fasthtml to make startup of the FastHTML server by the FastAPI server conditional.

server/fastapi_server:
- Add APIs for measurements.
- Improve APIs to provide or take pandas datetime series and datetime dataframes controlled by a pydantic model.
- Improve APIs to provide or take simple datetime data arrays controlled by a pydantic model.
- Move the FastAPI server API to v1 for the new APIs.
- Update pre-v1 endpoints to use the new prediction and measurement capabilities.
- Only start the FastHTML server if the 'server_fastapi_startup_server_fasthtml' config option is set.

tests:
- Adapt import tests to the changed import method signature.
- Adapt server test to use the v1 API.
- Extend the dataabc test to cover array generation from data with several data interval scenarios.
- Extend the datetimeutil test to also cover correct handling of to_datetime() providing now().
- Adapt LoadAkkudoktor test to the new adjustment calculation.
- Adapt visualization test to use the example report function instead of running visualize.py as a process.
- Remove test_load_aggregator. Functionality is now tested in test_measurement.
- Add tests for the measurement module.

docs:
- Remove sphinxcontrib-openapi as it prevents the documentation build:
  "site-packages/sphinxcontrib/openapi/openapi31.py", line 305, in _get_type_from_schema
      for t in schema["anyOf"]: KeyError: 'anyOf'

Signed-off-by: Bobby Noelte <b0661n0e17e@gmail.com>
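For illustration, the sketch below shows the generic import payload the revamped import handling accepts, using the optional `start_datetime` and `interval` keys described above. The key name `load_mean` and the `import_from_json()` entry point come from the diff below; the concrete provider instance is assumed.

```python
import json

# Payload in the generic import format: optional special keys plus one value
# list per record key; all value lists must have the same length.
payload = {
    "start_datetime": "2024-11-10 00:00:00",  # optional, else the provider's start_datetime
    "interval": "30 minutes",                 # optional, defaults to 1 hour
    "load_mean": [20.5, 21.0, 22.1],          # one value per interval step
}
json_str = json.dumps(payload)

# A provider that mixes in DataImportMixin would ingest it like this (assumed instance):
# provider.import_from_json(json_str, key_prefix="load")
```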
This commit is contained in: parent 2a8e11d7dc, commit 830af85fca
@@ -23,7 +23,7 @@ extensions = [
     "sphinx.ext.autosummary",
     "sphinx.ext.napoleon",
     "sphinx_rtd_theme",
-    "sphinxcontrib.openapi",
+    # "sphinxcontrib.openapi", buggy
     "myst_parser",
 ]

@@ -118,7 +118,7 @@ autodoc_default_options = {
 autosummary_generate = True

 # -- Options for openapi -----------------------------------------------------
-openapi_default_renderer = "httpdomain:old"
+# openapi_default_renderer = "httpdomain:old" buggy

 # -- Options for napoleon -------------------------------------------------
 napoleon_google_docstring = True
@@ -124,10 +124,10 @@ def run_prediction(provider_id: str, verbose: bool = False) -> str:

 def main():
     """Main function to run the optimization script with optional profiling."""
-    parser = argparse.ArgumentParser(description="Run Energy Optimization Simulation")
+    parser = argparse.ArgumentParser(description="Run Prediction")
     parser.add_argument("--profile", action="store_true", help="Enable performance profiling")
     parser.add_argument(
-        "--verbose", action="store_true", help="Enable verbose output during optimization"
+        "--verbose", action="store_true", help="Enable verbose output during prediction"
     )
     parser.add_argument("--provider-id", type=str, default=0, help="Provider ID of prediction")
@ -12,7 +12,7 @@ Key features:
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any, ClassVar, Optional
|
||||
from typing import Any, ClassVar, List, Optional
|
||||
|
||||
import platformdirs
|
||||
from pydantic import Field, ValidationError, computed_field
|
||||
@ -21,6 +21,7 @@ from pydantic import Field, ValidationError, computed_field
|
||||
from akkudoktoreos.config.configabc import SettingsBaseModel
|
||||
from akkudoktoreos.core.coreabc import SingletonMixin
|
||||
from akkudoktoreos.devices.devices import DevicesCommonSettings
|
||||
from akkudoktoreos.measurement.measurement import MeasurementCommonSettings
|
||||
from akkudoktoreos.optimization.optimization import OptimizationCommonSettings
|
||||
from akkudoktoreos.prediction.elecprice import ElecPriceCommonSettings
|
||||
from akkudoktoreos.prediction.elecpriceimport import ElecPriceImportCommonSettings
|
||||
@ -80,6 +81,7 @@ class ConfigCommonSettings(SettingsBaseModel):
|
||||
class SettingsEOS(
|
||||
ConfigCommonSettings,
|
||||
DevicesCommonSettings,
|
||||
MeasurementCommonSettings,
|
||||
OptimizationCommonSettings,
|
||||
PredictionCommonSettings,
|
||||
ElecPriceCommonSettings,
|
||||
@ -169,6 +171,16 @@ class ConfigEOS(SingletonMixin, SettingsEOS):
|
||||
"""Compute the default config file path."""
|
||||
return Path(__file__).parent.parent.joinpath("data/default.config.json")
|
||||
|
||||
# Computed fields
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@property
|
||||
def config_keys(self) -> List[str]:
|
||||
"""Returns the keys of all fields in the configuration."""
|
||||
key_list = []
|
||||
key_list.extend(list(self.model_fields.keys()))
|
||||
key_list.extend(list(self.__pydantic_decorators__.computed_fields.keys()))
|
||||
return key_list
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initializes the singleton ConfigEOS instance.
|
||||
|
||||
@ -228,9 +240,9 @@ class ConfigEOS(SingletonMixin, SettingsEOS):
|
||||
def merge_settings_from_dict(self, data: dict) -> None:
|
||||
"""Merges the provided dictionary data into the current instance.
|
||||
|
||||
Creates a new settings instance with all optional fields reset to None,
|
||||
then applies the dictionary data through validation, and finally merges
|
||||
the validated settings into the current instance.
|
||||
Creates a new settings instance, then applies the dictionary data through validation,
|
||||
and finally merges the validated settings into the current instance. None values
|
||||
are not merged.
|
||||
|
||||
Args:
|
||||
data (dict): Dictionary containing field values to merge into the
|
||||
@ -245,7 +257,7 @@ class ConfigEOS(SingletonMixin, SettingsEOS):
|
||||
>>> config.merge_settings_from_dict(new_data)
|
||||
"""
|
||||
# Create new settings instance with reset optional fields and merged data
|
||||
settings = SettingsEOS.from_dict_with_reset(data)
|
||||
settings = SettingsEOS.from_dict(data)
|
||||
self.merge_settings(settings)
|
||||
|
||||
def reset_settings(self) -> None:
|
||||
@ -377,7 +389,7 @@ class ConfigEOS(SingletonMixin, SettingsEOS):
|
||||
"""
|
||||
if not self.config_file_path:
|
||||
raise ValueError("Configuration file path unknown.")
|
||||
with self.config_file_path.open("r", encoding=self.ENCODING) as f_out:
|
||||
with self.config_file_path.open("w", encoding=self.ENCODING) as f_out:
|
||||
try:
|
||||
json_str = super().to_json()
|
||||
# Write to file
|
||||
|
@ -10,15 +10,4 @@ class SettingsBaseModel(PydanticBaseModel):
|
||||
Settings property names shall be disjunctive to all existing settings' property names.
|
||||
"""
|
||||
|
||||
def reset_to_defaults(self) -> None:
|
||||
"""Resets the fields to their default values."""
|
||||
for field_name, field_info in self.model_fields.items():
|
||||
if field_info.default_factory is not None: # Handle fields with default_factory
|
||||
default_value = field_info.default_factory()
|
||||
else:
|
||||
default_value = field_info.default
|
||||
try:
|
||||
setattr(self, field_name, default_value)
|
||||
except (AttributeError, TypeError):
|
||||
# Skip fields that are read-only or dynamically computed
|
||||
pass
|
||||
pass
|
||||
|
@ -21,6 +21,7 @@ from akkudoktoreos.utils.logutil import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
config_eos: Any = None
|
||||
measurement_eos: Any = None
|
||||
prediction_eos: Any = None
|
||||
devices_eos: Any = None
|
||||
ems_eos: Any = None
|
||||
@ -50,7 +51,7 @@ class ConfigMixin:
|
||||
|
||||
@property
|
||||
def config(self) -> Any:
|
||||
"""Convenience method/ attribute to retrieve the EOS onfiguration data.
|
||||
"""Convenience method/ attribute to retrieve the EOS configuration data.
|
||||
|
||||
Returns:
|
||||
ConfigEOS: The configuration.
|
||||
@ -65,6 +66,46 @@ class ConfigMixin:
|
||||
return config_eos
|
||||
|
||||
|
||||
class MeasurementMixin:
|
||||
"""Mixin class for managing EOS measurement data.
|
||||
|
||||
This class serves as a foundational component for EOS-related classes requiring access
|
||||
to global measurement data. It provides a `measurement` property that dynamically retrieves
|
||||
the measurement instance, ensuring up-to-date access to measurement results.
|
||||
|
||||
Usage:
|
||||
Subclass this base class to gain access to the `measurement` attribute, which retrieves the
|
||||
global measurement instance lazily to avoid import-time circular dependencies.
|
||||
|
||||
Attributes:
|
||||
measurement (Measurement): Property to access the global EOS measurement data.
|
||||
|
||||
Example:
|
||||
```python
|
||||
class MyOptimizationClass(MeasurementMixin):
|
||||
def analyze_mymeasurement(self):
|
||||
measurement_data = self.measurement.mymeasurement
|
||||
# Perform analysis
|
||||
```
|
||||
"""
|
||||
|
||||
@property
|
||||
def measurement(self) -> Any:
|
||||
"""Convenience method/ attribute to retrieve the EOS measurement data.
|
||||
|
||||
Returns:
|
||||
Measurement: The measurement.
|
||||
"""
|
||||
# avoid circular dependency at import time
|
||||
global measurement_eos
|
||||
if measurement_eos is None:
|
||||
from akkudoktoreos.measurement.measurement import get_measurement
|
||||
|
||||
measurement_eos = get_measurement()
|
||||
|
||||
return measurement_eos
|
||||
|
||||
|
||||
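The mixin above resolves the global measurement instance lazily. A minimal, self-contained sketch of that pattern (illustrative names, not the EOS module layout):

```python
from typing import Any, Optional

_service: Optional[Any] = None

def get_service() -> Any:
    # Stand-in for the real factory (get_measurement() in the code above).
    return {"ready": True}

class ServiceMixin:
    @property
    def service(self) -> Any:
        # Resolve the global instance on first access; in the code above the factory
        # import itself happens inside the property to avoid circular imports.
        global _service
        if _service is None:
            _service = get_service()
        return _service

class MyConsumer(ServiceMixin):
    pass

print(MyConsumer().service)  # {'ready': True}
```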
class PredictionMixin:
|
||||
"""Mixin class for managing EOS prediction data.
|
||||
|
||||
|
@ -21,10 +21,21 @@ import pandas as pd
|
||||
import pendulum
|
||||
from numpydantic import NDArray, Shape
|
||||
from pendulum import DateTime, Duration
|
||||
from pydantic import AwareDatetime, ConfigDict, Field, computed_field, field_validator
|
||||
from pydantic import (
|
||||
AwareDatetime,
|
||||
ConfigDict,
|
||||
Field,
|
||||
ValidationError,
|
||||
computed_field,
|
||||
field_validator,
|
||||
)
|
||||
|
||||
from akkudoktoreos.core.coreabc import ConfigMixin, SingletonMixin, StartMixin
|
||||
from akkudoktoreos.core.pydantic import PydanticBaseModel
|
||||
from akkudoktoreos.core.pydantic import (
|
||||
PydanticBaseModel,
|
||||
PydanticDateTimeData,
|
||||
PydanticDateTimeDataFrame,
|
||||
)
|
||||
from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime, to_duration
|
||||
from akkudoktoreos.utils.logutil import get_logger
|
||||
|
||||
@ -47,22 +58,22 @@ class DataRecord(DataBase, MutableMapping):
|
||||
and attribute-style access (`record.field_name`).
|
||||
|
||||
Attributes:
|
||||
date_time (Optional[AwareDatetime]): Aware datetime indicating when the data record applies.
|
||||
date_time (Optional[DateTime]): Aware datetime indicating when the data record applies.
|
||||
|
||||
Configurations:
|
||||
- Allows mutation after creation.
|
||||
- Supports non-standard data types like `datetime`.
|
||||
"""
|
||||
|
||||
date_time: Optional[AwareDatetime] = Field(default=None, description="DateTime")
|
||||
date_time: Optional[DateTime] = Field(default=None, description="DateTime")
|
||||
|
||||
# Pydantic v2 model configuration
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True)
|
||||
|
||||
@field_validator("date_time", mode="before")
|
||||
@classmethod
|
||||
def transform_to_datetime(cls, value: Any) -> DateTime:
|
||||
"""Converts various datetime formats into AwareDatetime."""
|
||||
def transform_to_datetime(cls, value: Any) -> Optional[DateTime]:
|
||||
"""Converts various datetime formats into DateTime."""
|
||||
if value is None:
|
||||
# Allow to set to default.
|
||||
return None
|
||||
@ -307,6 +318,38 @@ class DataSequence(DataBase, MutableSequence):
|
||||
records: List[DataRecord] = Field(default_factory=list, description="List of data records")
|
||||
|
||||
# Derived fields (computed)
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@property
|
||||
def min_datetime(self) -> Optional[DateTime]:
|
||||
"""Minimum (earliest) datetime in the sorted sequence of data records.
|
||||
|
||||
This property computes the earliest datetime from the sequence of data records.
|
||||
If no records are present, it returns `None`.
|
||||
|
||||
Returns:
|
||||
Optional[DateTime]: The earliest datetime in the sequence, or `None` if no
|
||||
data records exist.
|
||||
"""
|
||||
if len(self.records) == 0:
|
||||
return None
|
||||
return self.records[0].date_time
|
||||
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@property
|
||||
def max_datetime(self) -> Optional[DateTime]:
|
||||
"""Maximum (latest) datetime in the sorted sequence of data records.
|
||||
|
||||
This property computes the latest datetime from the sequence of data records.
|
||||
If no records are present, it returns `None`.
|
||||
|
||||
Returns:
|
||||
Optional[DateTime]: The latest datetime in the sequence, or `None` if no
|
||||
data records exist.
|
||||
"""
|
||||
if len(self.records) == 0:
|
||||
return None
|
||||
return self.records[-1].date_time
|
||||
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@property
|
||||
def record_keys(self) -> List[str]:
|
||||
@ -319,12 +362,31 @@ class DataSequence(DataBase, MutableSequence):
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@property
|
||||
def record_keys_writable(self) -> List[str]:
|
||||
"""Returns the keys of all fields in the data records that are writable."""
|
||||
"""Get the keys of all writable fields in the data records.
|
||||
|
||||
This property retrieves the keys of all fields in the data records that
|
||||
can be written to. It uses the `record_class` to determine the model's
|
||||
field structure.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of field keys that are writable in the data records.
|
||||
"""
|
||||
return list(self.record_class().model_fields.keys())
|
||||
|
||||
@classmethod
|
||||
def record_class(cls) -> Type:
|
||||
"""Returns the class of the data record this data sequence handles."""
|
||||
"""Get the class of the data record handled by this data sequence.
|
||||
|
||||
This method determines the class of the data record type associated with
|
||||
the `records` field of the model. The field is expected to be a list, and
|
||||
the element type of the list should be a subclass of `DataRecord`.
|
||||
|
||||
Raises:
|
||||
ValueError: If the record type is not a subclass of `DataRecord`.
|
||||
|
||||
Returns:
|
||||
Type: The class of the data record handled by the data sequence.
|
||||
"""
|
||||
# Access the model field metadata
|
||||
field_info = cls.model_fields["records"]
|
||||
# Get the list element type from the 'type_' attribute
|
||||
@ -573,6 +635,7 @@ class DataSequence(DataBase, MutableSequence):
|
||||
key: str,
|
||||
start_datetime: Optional[DateTime] = None,
|
||||
end_datetime: Optional[DateTime] = None,
|
||||
dropna: Optional[bool] = None,
|
||||
) -> Dict[DateTime, Any]:
|
||||
"""Extract a dictionary indexed by the date_time field of the DataRecords.
|
||||
|
||||
@ -583,6 +646,7 @@ class DataSequence(DataBase, MutableSequence):
|
||||
key (str): The field name in the DataRecord from which to extract values.
|
||||
start_datetime (datetime, optional): The start date to filter records (inclusive).
|
||||
end_datetime (datetime, optional): The end date to filter records (exclusive).
|
||||
dropna (bool, optional): Whether to drop NaN/None values before processing. Defaults to True.
|
||||
|
||||
Returns:
|
||||
Dict[datetime, Any]: A dictionary with the date_time of each record as the key
|
||||
@ -597,12 +661,22 @@ class DataSequence(DataBase, MutableSequence):
|
||||
end_datetime = to_datetime(end_datetime, to_maxtime=False) if end_datetime else None
|
||||
|
||||
# Create a dictionary to hold date_time and corresponding values
|
||||
filtered_data = {
|
||||
to_datetime(record.date_time, as_string=True): getattr(record, key, None)
|
||||
for record in self.records
|
||||
if (start_datetime is None or compare_datetimes(record.date_time, start_datetime).ge)
|
||||
and (end_datetime is None or compare_datetimes(record.date_time, end_datetime).lt)
|
||||
}
|
||||
if dropna is None:
|
||||
dropna = True
|
||||
filtered_data = {}
|
||||
for record in self.records:
|
||||
if (
|
||||
record.date_time is None
|
||||
or (dropna and getattr(record, key, None) is None)
|
||||
or (dropna and pd.isna(getattr(record, key, None)))
|
||||
):
|
||||
continue
|
||||
if (
|
||||
start_datetime is None or compare_datetimes(record.date_time, start_datetime).ge
|
||||
) and (end_datetime is None or compare_datetimes(record.date_time, end_datetime).lt):
|
||||
filtered_data[to_datetime(record.date_time, as_string=True)] = getattr(
|
||||
record, key, None
|
||||
)
|
||||
|
||||
return filtered_data
|
||||
|
||||
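As a toy illustration (plain pandas, not the EOS classes) of the `dropna` semantics used by `key_to_dict` above: records whose value is None/NaN are skipped before the result is built when `dropna` is enabled.

```python
import pandas as pd

values = {"2024-11-10T00:00:00+01:00": 20.5,
          "2024-11-10T01:00:00+01:00": None,   # would be skipped with dropna=True
          "2024-11-10T02:00:00+01:00": 22.1}

series = pd.Series(values, dtype=float)
print(series.dropna().to_dict())  # dropna=True: the None entry is gone
print(series.to_dict())           # dropna=False: kept as NaN
```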
@ -611,6 +685,7 @@ class DataSequence(DataBase, MutableSequence):
|
||||
key: str,
|
||||
start_datetime: Optional[DateTime] = None,
|
||||
end_datetime: Optional[DateTime] = None,
|
||||
dropna: Optional[bool] = None,
|
||||
) -> Tuple[List[DateTime], List[Optional[float]]]:
|
||||
"""Extracts two lists from data records within an optional date range.
|
||||
|
||||
@ -622,6 +697,7 @@ class DataSequence(DataBase, MutableSequence):
|
||||
key (str): The key of the attribute in DataRecord to extract.
|
||||
start_datetime (datetime, optional): The start date for filtering the records (inclusive).
|
||||
end_datetime (datetime, optional): The end date for filtering the records (exclusive).
|
||||
dropna (bool, optional): Whether to drop NaN/None values before processing. Defaults to True.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing a list of datetime values and a list of extracted values.
|
||||
@ -635,9 +711,15 @@ class DataSequence(DataBase, MutableSequence):
|
||||
end_datetime = to_datetime(end_datetime, to_maxtime=False) if end_datetime else None
|
||||
|
||||
# Create two lists to hold date_time and corresponding values
|
||||
if dropna is None:
|
||||
dropna = True
|
||||
filtered_records = []
|
||||
for record in self.records:
|
||||
if record.date_time is None:
|
||||
if (
|
||||
record.date_time is None
|
||||
or (dropna and getattr(record, key, None) is None)
|
||||
or (dropna and pd.isna(getattr(record, key, None)))
|
||||
):
|
||||
continue
|
||||
if (
|
||||
start_datetime is None or compare_datetimes(record.date_time, start_datetime).ge
|
||||
@ -653,6 +735,7 @@ class DataSequence(DataBase, MutableSequence):
|
||||
key: str,
|
||||
start_datetime: Optional[DateTime] = None,
|
||||
end_datetime: Optional[DateTime] = None,
|
||||
dropna: Optional[bool] = None,
|
||||
) -> pd.Series:
|
||||
"""Extract a series indexed by the date_time field from data records within an optional date range.
|
||||
|
||||
@ -660,6 +743,7 @@ class DataSequence(DataBase, MutableSequence):
|
||||
key (str): The field name in the DataRecord from which to extract values.
|
||||
start_datetime (datetime, optional): The start date for filtering the records (inclusive).
|
||||
end_datetime (datetime, optional): The end date for filtering the records (exclusive).
|
||||
dropna (bool, optional): Whether to drop NaN/None values before processing. Defaults to True.
|
||||
|
||||
Returns:
|
||||
pd.Series: A Pandas Series with the index as the date_time of each record
|
||||
@ -668,7 +752,9 @@ class DataSequence(DataBase, MutableSequence):
|
||||
Raises:
|
||||
KeyError: If the specified key is not found in any of the DataRecords.
|
||||
"""
|
||||
dates, values = self.key_to_lists(key, start_datetime, end_datetime)
|
||||
dates, values = self.key_to_lists(
|
||||
key=key, start_datetime=start_datetime, end_datetime=end_datetime, dropna=dropna
|
||||
)
|
||||
return pd.Series(data=values, index=pd.DatetimeIndex(dates), name=key)
|
||||
|
||||
def key_from_series(self, key: str, series: pd.Series) -> None:
|
||||
@ -704,6 +790,7 @@ class DataSequence(DataBase, MutableSequence):
|
||||
end_datetime: Optional[DateTime] = None,
|
||||
interval: Optional[Duration] = None,
|
||||
fill_method: Optional[str] = None,
|
||||
dropna: Optional[bool] = None,
|
||||
) -> NDArray[Shape["*"], Any]:
|
||||
"""Extract an array indexed by fixed time intervals from data records within an optional date range.
|
||||
|
||||
@ -717,6 +804,7 @@ class DataSequence(DataBase, MutableSequence):
|
||||
- 'ffill': Forward fill missing values.
|
||||
- 'bfill': Backward fill missing values.
|
||||
- 'none': Defaults to 'linear' for numeric values, otherwise 'ffill'.
|
||||
dropna (bool, optional): Whether to drop NaN/None values before processing. Defaults to True.
|
||||
|
||||
Returns:
|
||||
np.ndarray: A NumPy Array of the values extracted from the specified key.
|
||||
@ -724,10 +812,54 @@ class DataSequence(DataBase, MutableSequence):
|
||||
Raises:
|
||||
KeyError: If the specified key is not found in any of the DataRecords.
|
||||
"""
|
||||
self._validate_key(key)
|
||||
# Ensure datetime objects are normalized
|
||||
start_datetime = to_datetime(start_datetime, to_maxtime=False) if start_datetime else None
|
||||
end_datetime = to_datetime(end_datetime, to_maxtime=False) if end_datetime else None
|
||||
|
||||
resampled = None
|
||||
if interval is None:
|
||||
interval = to_duration("1 hour")
|
||||
series = self.key_to_series(key)
|
||||
|
||||
dates, values = self.key_to_lists(key=key, dropna=dropna)
|
||||
values_len = len(values)
|
||||
|
||||
if values_len < 1:
|
||||
# No values, assume at least one value set to None
|
||||
if start_datetime is not None:
|
||||
dates.append(start_datetime - interval)
|
||||
else:
|
||||
dates.append(to_datetime(to_maxtime=False))
|
||||
values.append(None)
|
||||
|
||||
if start_datetime is not None:
|
||||
start_index = 0
|
||||
while start_index < values_len:
|
||||
if compare_datetimes(dates[start_index], start_datetime).ge:
|
||||
break
|
||||
start_index += 1
|
||||
if start_index == 0:
|
||||
# No value before start
|
||||
# Add dummy value
|
||||
dates.insert(0, dates[0] - interval)
|
||||
values.insert(0, values[0])
|
||||
elif start_index > 1:
|
||||
# Truncate all values before latest value before start_datetime
|
||||
dates = dates[start_index - 1 :]
|
||||
values = values[start_index - 1 :]
|
||||
|
||||
if end_datetime is not None:
|
||||
if compare_datetimes(dates[-1], end_datetime).lt:
|
||||
# Add dummy value at end_datetime
|
||||
dates.append(end_datetime)
|
||||
values.append(values[-1])
|
||||
|
||||
series = pd.Series(data=values, index=pd.DatetimeIndex(dates), name=key)
|
||||
if not series.index.inferred_type == "datetime64":
|
||||
raise TypeError(
|
||||
f"Expected DatetimeIndex, but got {type(series.index)} "
|
||||
f"infered to {series.index.inferred_type}: {series}"
|
||||
)
|
||||
|
||||
# Handle missing values
|
||||
if series.dtype in [np.float64, np.float32, np.int64, np.int32]:
|
||||
@ -735,7 +867,7 @@ class DataSequence(DataBase, MutableSequence):
|
||||
if fill_method is None:
|
||||
fill_method = "linear"
|
||||
# Resample the series to the specified interval
|
||||
resampled = series.resample(interval).mean()
|
||||
resampled = series.resample(interval, origin="start").first()
|
||||
if fill_method == "linear":
|
||||
resampled = resampled.interpolate(method="linear")
|
||||
elif fill_method == "ffill":
|
||||
@ -749,7 +881,7 @@ class DataSequence(DataBase, MutableSequence):
|
||||
if fill_method is None:
|
||||
fill_method = "ffill"
|
||||
# Resample the series to the specified interval
|
||||
resampled = series.resample(interval).first()
|
||||
resampled = series.resample(interval, origin="start").first()
|
||||
if fill_method == "ffill":
|
||||
resampled = resampled.ffill()
|
||||
elif fill_method == "bfill":
|
||||
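A small stand-alone sketch of the resampling step used by `key_to_array` above (plain pandas, toy data): values are placed on a fixed grid anchored at the first timestamp via `origin="start"`, and gaps are then filled, here by the linear-interpolation branch.

```python
import pandas as pd

idx = pd.DatetimeIndex(["2024-11-10 00:00", "2024-11-10 02:00", "2024-11-10 05:00"])
series = pd.Series([10.0, 14.0, 20.0], index=idx)

resampled = series.resample("1h", origin="start").first()  # hourly grid from the first timestamp
filled = resampled.interpolate(method="linear")            # fill_method="linear" branch above
print(filled.to_numpy())  # [10. 12. 14. 16. 18. 20.]
```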
@ -955,18 +1087,29 @@ class DataProvider(SingletonMixin, DataSequence):
|
||||
self.sort_by_datetime()
|
||||
|
||||
|
||||
class DataImportProvider(DataProvider):
|
||||
"""Abstract base class for data providers that import generic data.
|
||||
class DataImportMixin:
|
||||
"""Mixin class for import of generic data.
|
||||
|
||||
This class is designed to handle generic data provided in the form of a key-value dictionary.
|
||||
- **Keys**: Represent identifiers from the record keys of a specific data.
|
||||
- **Values**: Are lists of data values starting at a specified `start_datetime`, where
|
||||
each value corresponds to a subsequent time interval (e.g., hourly).
|
||||
|
||||
Subclasses must implement the logic for managing generic data based on the imported records.
|
||||
Two special keys are handled. `start_datetime` may be used to define the starting datetime of
the values. `interval` may be used to define the fixed time interval between two values.
|
||||
|
||||
On import, `self.update_value(datetime, key, value)` is called, which must be provided by the using class.
Also, `self.start_datetime` may be needed as a default in case `start_datetime` is not given.
|
||||
"""
|
||||
|
||||
def import_datetimes(self, value_count: int) -> List[Tuple[DateTime, int]]:
|
||||
# Attributes required but defined elsewhere.
# - start_datetime
# - record_keys_writable
# - update_value
|
||||
|
||||
def import_datetimes(
|
||||
self, start_datetime: DateTime, value_count: int, interval: Optional[Duration] = None
|
||||
) -> List[Tuple[DateTime, int]]:
|
||||
"""Generates a list of tuples containing timestamps and their corresponding value indices.
|
||||
|
||||
The function accounts for daylight saving time (DST) transitions:
|
||||
@ -975,7 +1118,9 @@ class DataImportProvider(DataProvider):
|
||||
but they share the same value index.
|
||||
|
||||
Args:
|
||||
start_datetime (DateTime): Start datetime of values
|
||||
value_count (int): The number of timestamps to generate.
|
||||
interval (duration, optional): The fixed time interval. Defaults to 1 hour.
|
||||
|
||||
Returns:
|
||||
List[Tuple[DateTime, int]]:
|
||||
@ -990,7 +1135,7 @@ class DataImportProvider(DataProvider):
|
||||
|
||||
Example:
|
||||
>>> start_datetime = pendulum.datetime(2024, 11, 3, 0, 0, tz="America/New_York")
|
||||
>>> import_datetimes(5)
|
||||
>>> import_datetimes(start_datetime, 5)
|
||||
[(DateTime(2024, 11, 3, 0, 0, tzinfo=Timezone('America/New_York')), 0),
|
||||
(DateTime(2024, 11, 3, 1, 0, tzinfo=Timezone('America/New_York')), 1),
|
||||
(DateTime(2024, 11, 3, 1, 0, tzinfo=Timezone('America/New_York')), 1), # Repeated hour
|
||||
@ -998,7 +1143,16 @@ class DataImportProvider(DataProvider):
|
||||
(DateTime(2024, 11, 3, 3, 0, tzinfo=Timezone('America/New_York')), 3)]
|
||||
"""
|
||||
timestamps_with_indices: List[Tuple[DateTime, int]] = []
|
||||
value_datetime = self.start_datetime
|
||||
|
||||
if interval is None:
|
||||
interval = to_duration("1 hour")
|
||||
interval_steps_per_hour = int(3600 / interval.total_seconds())
|
||||
if interval.total_seconds() * interval_steps_per_hour != 3600:
|
||||
error_msg = f"Interval {interval} does not fit into hour."
|
||||
logger.error(error_msg)
|
||||
raise NotImplementedError(error_msg)
|
||||
|
||||
value_datetime = start_datetime
|
||||
value_index = 0
|
||||
|
||||
while value_index < value_count:
|
||||
@ -1006,37 +1160,219 @@ class DataImportProvider(DataProvider):
|
||||
logger.debug(f"{i}: Insert at {value_datetime} with index {value_index}")
|
||||
timestamps_with_indices.append((value_datetime, value_index))
|
||||
|
||||
# Check if there is a DST transition
|
||||
next_time = value_datetime.add(hours=1)
|
||||
if next_time <= value_datetime:
|
||||
# Check if there is a DST transition (i.e., ambiguous time during fall back)
|
||||
# Repeat the hour value (reuse value index)
|
||||
value_datetime = next_time
|
||||
logger.debug(f"{i+1}: Repeat at {value_datetime} with index {value_index}")
|
||||
timestamps_with_indices.append((value_datetime, value_index))
|
||||
elif next_time.hour != value_datetime.hour + 1 and value_datetime.hour != 23:
|
||||
# Skip the hour value (spring forward in value index)
|
||||
value_index += 1
|
||||
logger.debug(f"{i+1}: Skip at {next_time} with index {value_index}")
|
||||
next_time = value_datetime.add(seconds=interval.total_seconds())
|
||||
|
||||
# Increment value index and value_datetime for new hour
|
||||
# Check if there is a DST transition
|
||||
if next_time.dst() != value_datetime.dst():
|
||||
if next_time.hour == value_datetime.hour:
|
||||
# We jump back by 1 hour
|
||||
# Repeat the value(s) (reuse value index)
|
||||
for i in range(interval_steps_per_hour):
|
||||
logger.debug(f"{i+1}: Repeat at {next_time} with index {value_index}")
|
||||
timestamps_with_indices.append((next_time, value_index))
|
||||
next_time = next_time.add(seconds=interval.total_seconds())
|
||||
else:
|
||||
# We jump forward by 1 hour
|
||||
# Drop the value(s)
|
||||
logger.debug(
|
||||
f"{i+1}: Skip {interval_steps_per_hour} at {next_time} with index {value_index}"
|
||||
)
|
||||
value_index += interval_steps_per_hour
|
||||
|
||||
# Increment value index and value_datetime for new interval
|
||||
value_index += 1
|
||||
value_datetime = value_datetime.add(hours=1)
|
||||
value_datetime = next_time
|
||||
|
||||
return timestamps_with_indices
|
||||
|
||||
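A self-contained sketch of the interval check described above: the import interval must divide one hour evenly, otherwise `import_datetimes()` refuses it. The helper name and use of `timedelta` here are illustrative; only the guard logic mirrors the code above.

```python
from datetime import timedelta

def steps_per_hour(interval: timedelta) -> int:
    # Mirrors the guard above: the interval has to fit into one hour exactly.
    steps = int(3600 / interval.total_seconds())
    if interval.total_seconds() * steps != 3600:
        raise NotImplementedError(f"Interval {interval} does not fit into hour.")
    return steps

print(steps_per_hour(timedelta(minutes=15)))  # 4 values per hour
print(steps_per_hour(timedelta(minutes=30)))  # 2 values per hour
```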
def import_from_json(self, json_str: str, key_prefix: str = "") -> None:
|
||||
def import_from_dict(
|
||||
self,
|
||||
import_data: dict,
|
||||
key_prefix: str = "",
|
||||
start_datetime: Optional[DateTime] = None,
|
||||
interval: Optional[Duration] = None,
|
||||
) -> None:
|
||||
"""Updates generic data by importing it from a dictionary.
|
||||
|
||||
This method reads generic data from a dictionary, matches keys based on the
|
||||
record keys and the provided `key_prefix`, and updates the data values sequentially.
|
||||
All value lists must have the same length.
|
||||
|
||||
Args:
|
||||
import_data (dict): Dictionary containing the generic data with optional
|
||||
'start_datetime' and 'interval' keys.
|
||||
key_prefix (str, optional): A prefix to filter relevant keys from the generic data.
|
||||
Only keys starting with this prefix will be considered. Defaults to an empty string.
|
||||
start_datetime (DateTime, optional): Start datetime of values if not in dict.
|
||||
interval (Duration, optional): The fixed time interval if not in dict.
|
||||
|
||||
Raises:
|
||||
ValueError: If value lists have different lengths or if datetime conversion fails.
|
||||
"""
|
||||
# Handle datetime and interval from dict or parameters
|
||||
if "start_datetime" in import_data:
|
||||
try:
|
||||
start_datetime = to_datetime(import_data["start_datetime"])
|
||||
except (ValueError, TypeError) as e:
|
||||
raise ValueError(f"Invalid start_datetime in import data: {e}")
|
||||
|
||||
if start_datetime is None:
|
||||
start_datetime = self.start_datetime # type: ignore
|
||||
|
||||
if "interval" in import_data:
|
||||
try:
|
||||
interval = to_duration(import_data["interval"])
|
||||
except (ValueError, TypeError) as e:
|
||||
raise ValueError(f"Invalid interval in import data: {e}")
|
||||
|
||||
# Filter keys based on key_prefix and record_keys_writable
|
||||
valid_keys = [
|
||||
key
|
||||
for key in import_data.keys()
|
||||
if key.startswith(key_prefix)
|
||||
and key in self.record_keys_writable # type: ignore
|
||||
and key not in ("start_datetime", "interval")
|
||||
]
|
||||
|
||||
if not valid_keys:
|
||||
return
|
||||
|
||||
# Validate all value lists have the same length
|
||||
value_lengths = []
|
||||
for key in valid_keys:
|
||||
value_list = import_data[key]
|
||||
if not isinstance(value_list, (list, tuple, np.ndarray)):
|
||||
raise ValueError(f"Value for key '{key}' must be a list, tuple, or array")
|
||||
value_lengths.append(len(value_list))
|
||||
|
||||
if len(set(value_lengths)) > 1:
|
||||
raise ValueError(
|
||||
f"All value lists must have the same length. Found lengths: "
|
||||
f"{dict(zip(valid_keys, value_lengths))}"
|
||||
)
|
||||
|
||||
# Generate datetime mapping once for the common length
|
||||
values_count = value_lengths[0]
|
||||
value_datetime_mapping = self.import_datetimes(
|
||||
start_datetime, values_count, interval=interval
|
||||
)
|
||||
|
||||
# Process each valid key
|
||||
for key in valid_keys:
|
||||
try:
|
||||
value_list = import_data[key]
|
||||
|
||||
# Update values, skipping any None/NaN
|
||||
for value_datetime, value_index in value_datetime_mapping:
|
||||
value = value_list[value_index]
|
||||
if value is not None and not pd.isna(value):
|
||||
self.update_value(value_datetime, key, value) # type: ignore
|
||||
|
||||
except (IndexError, TypeError) as e:
|
||||
raise ValueError(f"Error processing values for key '{key}': {e}")
|
||||
|
||||
def import_from_dataframe(
|
||||
self,
|
||||
df: pd.DataFrame,
|
||||
key_prefix: str = "",
|
||||
start_datetime: Optional[DateTime] = None,
|
||||
interval: Optional[Duration] = None,
|
||||
) -> None:
|
||||
"""Updates generic data by importing it from a pandas DataFrame.
|
||||
|
||||
This method reads generic data from a DataFrame, matches columns based on the
|
||||
record keys and the provided `key_prefix`, and updates the data values using
|
||||
the DataFrame's index as timestamps.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): DataFrame containing the generic data with datetime index
|
||||
or sequential values.
|
||||
key_prefix (str, optional): A prefix to filter relevant columns from the DataFrame.
|
||||
Only columns starting with this prefix will be considered. Defaults to an empty string.
|
||||
start_datetime (DateTime, optional): Start datetime if DataFrame doesn't have datetime index.
|
||||
interval (Duration, optional): The fixed time interval if DataFrame doesn't have datetime index.
|
||||
|
||||
Raises:
|
||||
ValueError: If DataFrame structure is invalid or datetime conversion fails.
|
||||
"""
|
||||
# Validate DataFrame
|
||||
if not isinstance(df, pd.DataFrame):
|
||||
raise ValueError("Input must be a pandas DataFrame")
|
||||
|
||||
# Handle datetime index
|
||||
if isinstance(df.index, pd.DatetimeIndex):
|
||||
try:
|
||||
index_datetimes = [to_datetime(dt) for dt in df.index]
|
||||
has_datetime_index = True
|
||||
except (ValueError, TypeError) as e:
|
||||
raise ValueError(f"Invalid datetime index in DataFrame: {e}")
|
||||
else:
|
||||
if start_datetime is None:
|
||||
start_datetime = self.start_datetime # type: ignore
|
||||
has_datetime_index = False
|
||||
|
||||
# Filter columns based on key_prefix and record_keys_writable
|
||||
valid_columns = [
|
||||
col
|
||||
for col in df.columns
|
||||
if col.startswith(key_prefix) and col in self.record_keys_writable # type: ignore
|
||||
]
|
||||
|
||||
if not valid_columns:
|
||||
return
|
||||
|
||||
# For DataFrame, length validation is implicit since all columns have same length
|
||||
values_count = len(df)
|
||||
|
||||
# Generate value_datetime_mapping once if not using datetime index
|
||||
if not has_datetime_index:
|
||||
value_datetime_mapping = self.import_datetimes(
|
||||
start_datetime, values_count, interval=interval
|
||||
)
|
||||
|
||||
# Process each valid column
|
||||
for column in valid_columns:
|
||||
try:
|
||||
values = df[column].tolist()
|
||||
|
||||
if has_datetime_index:
|
||||
# Use the DataFrame's datetime index
|
||||
for dt, value in zip(index_datetimes, values):
|
||||
if value is not None and not pd.isna(value):
|
||||
self.update_value(dt, column, value) # type: ignore
|
||||
else:
|
||||
# Use the pre-generated datetime mapping
|
||||
for value_datetime, value_index in value_datetime_mapping:
|
||||
value = values[value_index]
|
||||
if value is not None and not pd.isna(value):
|
||||
self.update_value(value_datetime, column, value) # type: ignore
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error processing column '{column}': {e}")
|
||||
|
||||
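For reference, a hedged sketch of the DataFrame shape `import_from_dataframe()` expects when a datetime index is present; the column name and timezone are illustrative and the provider instance is assumed.

```python
import pandas as pd

df = pd.DataFrame(
    {"load_mean": [20.5, 21.0, 22.1]},
    index=pd.date_range("2024-11-10 00:00", periods=3, freq="1h", tz="Europe/Berlin"),
)

# Columns not starting with key_prefix, or not writable record keys, are skipped:
# provider.import_from_dataframe(df, key_prefix="load")
print(df)
```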
def import_from_json(
|
||||
self,
|
||||
json_str: str,
|
||||
key_prefix: str = "",
|
||||
start_datetime: Optional[DateTime] = None,
|
||||
interval: Optional[Duration] = None,
|
||||
) -> None:
|
||||
"""Updates generic data by importing it from a JSON string.
|
||||
|
||||
This method reads generic data from a JSON string, matches keys based on the
|
||||
record keys and the provided `key_prefix`, and updates the data values sequentially,
|
||||
starting from the `start_datetime`. Each data value is associated with an hourly
|
||||
interval.
|
||||
starting from the `start_datetime`.
|
||||
|
||||
If start_datetime and/or interval are given in the JSON dict they are used. Otherwise
the given parameters are used. If None is given, start_datetime defaults to
`self.start_datetime` and interval defaults to 1 hour.
|
||||
|
||||
Args:
|
||||
json_str (str): The JSON string containing the generic data.
|
||||
key_prefix (str, optional): A prefix to filter relevant keys from the generic data.
|
||||
Only keys starting with this prefix will be considered. Defaults to an empty string.
|
||||
start_datetime (DateTime, optional): Start datetime of values.
|
||||
interval (duration, optional): The fixed time interval. Defaults to 1 hour.
|
||||
|
||||
Raises:
|
||||
JSONDecodeError: If the file content is not valid JSON.
|
||||
@ -1045,22 +1381,56 @@ class DataImportProvider(DataProvider):
|
||||
Given a JSON string with the following content:
|
||||
```json
|
||||
{
|
||||
"load0_mean": [20.5, 21.0, 22.1],
|
||||
"load1_mean": [50, 55, 60]
|
||||
"start_datetime": "2024-11-10 00:00:00"
|
||||
"interval": "30 minutes"
|
||||
"load_mean": [20.5, 21.0, 22.1],
|
||||
"other_xyz: [10.5, 11.0, 12.1],
|
||||
}
|
||||
```
|
||||
and `key_prefix = "load1"`, only the "load1_mean" key will be processed even though
|
||||
and `key_prefix = "load"`, only the "load_mean" key will be processed even though
|
||||
both keys are in the record.
|
||||
"""
|
||||
import_data = json.loads(json_str)
|
||||
for key in self.record_keys_writable:
|
||||
if key.startswith(key_prefix) and key in import_data:
|
||||
value_list = import_data[key]
|
||||
value_datetime_mapping = self.import_datetimes(len(value_list))
|
||||
for value_datetime, value_index in value_datetime_mapping:
|
||||
self.update_value(value_datetime, key, value_list[value_index])
|
||||
# Try pandas dataframe with orient="split"
|
||||
try:
|
||||
import_data = PydanticDateTimeDataFrame.model_validate_json(json_str)
|
||||
self.import_from_dataframe(import_data.to_dataframe())
|
||||
return
|
||||
except ValidationError as e:
|
||||
error_msg = ""
|
||||
for error in e.errors():
|
||||
field = " -> ".join(str(x) for x in error["loc"])
|
||||
message = error["msg"]
|
||||
error_type = error["type"]
|
||||
error_msg += f"Field: {field}\nError: {message}\nType: {error_type}\n"
|
||||
logger.debug(f"PydanticDateTimeDataFrame import: {error_msg}")
|
||||
|
||||
def import_from_file(self, import_file_path: Path, key_prefix: str = "") -> None:
|
||||
# Try dictionary with special keys start_datetime and interval
|
||||
try:
|
||||
import_data = PydanticDateTimeData.model_validate_json(json_str)
|
||||
self.import_from_dict(import_data.to_dict())
|
||||
return
|
||||
except ValidationError as e:
|
||||
error_msg = ""
|
||||
for error in e.errors():
|
||||
field = " -> ".join(str(x) for x in error["loc"])
|
||||
message = error["msg"]
|
||||
error_type = error["type"]
|
||||
error_msg += f"Field: {field}\nError: {message}\nType: {error_type}\n"
|
||||
logger.debug(f"PydanticDateTimeData import: {error_msg}")
|
||||
|
||||
# Use simple dict format
|
||||
import_data = json.loads(json_str)
|
||||
self.import_from_dict(
|
||||
import_data, key_prefix=key_prefix, start_datetime=start_datetime, interval=interval
|
||||
)
|
||||
|
||||
def import_from_file(
|
||||
self,
|
||||
import_file_path: Path,
|
||||
key_prefix: str = "",
|
||||
start_datetime: Optional[DateTime] = None,
|
||||
interval: Optional[Duration] = None,
|
||||
) -> None:
|
||||
"""Updates generic data by importing it from a file.
|
||||
|
||||
This method reads generic data from a JSON file, matches keys based on the
|
||||
@ -1068,10 +1438,16 @@ class DataImportProvider(DataProvider):
|
||||
starting from the `start_datetime`. Each data value is associated with an hourly
|
||||
interval.
|
||||
|
||||
If start_datetime and/or interval are given in the JSON dict they are used. Otherwise
the given parameters are used. If None is given, start_datetime defaults to
`self.start_datetime` and interval defaults to 1 hour.
|
||||
|
||||
Args:
|
||||
import_file_path (Path): The path to the JSON file containing the generic data.
|
||||
key_prefix (str, optional): A prefix to filter relevant keys from the generic data.
|
||||
Only keys starting with this prefix will be considered. Defaults to an empty string.
|
||||
start_datetime (DateTime, optional): Start datetime of values.
|
||||
interval (duration, optional): The fixed time interval. Defaults to 1 hour.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the specified file does not exist.
|
||||
@ -1081,16 +1457,32 @@ class DataImportProvider(DataProvider):
|
||||
Given a JSON file with the following content:
|
||||
```json
|
||||
{
|
||||
"load0_mean": [20.5, 21.0, 22.1],
|
||||
"load1_mean": [50, 55, 60]
|
||||
"load_mean": [20.5, 21.0, 22.1],
|
||||
"other_xyz: [10.5, 11.0, 12.1],
|
||||
}
|
||||
```
|
||||
and `key_prefix = "load1"`, only the "load1_mean" key will be processed even though
|
||||
and `key_prefix = "load"`, only the "load_mean" key will be processed even though
|
||||
both keys are in the record.
|
||||
"""
|
||||
with import_file_path.open("r") as import_file:
|
||||
import_str = import_file.read()
|
||||
self.import_from_json(import_str, key_prefix)
|
||||
self.import_from_json(
|
||||
import_str, key_prefix=key_prefix, start_datetime=start_datetime, interval=interval
|
||||
)
|
||||
|
||||
|
||||
class DataImportProvider(DataImportMixin, DataProvider):
|
||||
"""Abstract base class for data providers that import generic data.
|
||||
|
||||
This class is designed to handle generic data provided in the form of a key-value dictionary.
|
||||
- **Keys**: Represent identifiers from the record keys of a specific data.
|
||||
- **Values**: Are lists of data values starting at a specified `start_datetime`, where
|
||||
each value corresponds to a subsequent time interval (e.g., hourly).
|
||||
|
||||
Subclasses must implement the logic for managing generic data based on the imported records.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DataContainer(SingletonMixin, DataBase, MutableMapping):
|
||||
@ -1129,6 +1521,24 @@ class DataContainer(SingletonMixin, DataBase, MutableMapping):
|
||||
enab.append(provider)
|
||||
return enab
|
||||
|
||||
@property
|
||||
def record_keys(self) -> list[str]:
|
||||
"""Returns the keys of all fields in the data records of all enabled providers."""
|
||||
key_set = set(
|
||||
chain.from_iterable(provider.record_keys for provider in self.enabled_providers)
|
||||
)
|
||||
return list(key_set)
|
||||
|
||||
@property
|
||||
def record_keys_writable(self) -> list[str]:
|
||||
"""Returns the keys of all fields in the data records that are writable of all enabled providers."""
|
||||
key_set = set(
|
||||
chain.from_iterable(
|
||||
provider.record_keys_writable for provider in self.enabled_providers
|
||||
)
|
||||
)
|
||||
return list(key_set)
|
||||
|
||||
def __getitem__(self, key: str) -> pd.Series:
|
||||
"""Retrieve a Pandas Series for a specified key from the data in each DataProvider.
|
||||
|
||||
@ -1206,9 +1616,7 @@ class DataContainer(SingletonMixin, DataBase, MutableMapping):
|
||||
Returns:
|
||||
Iterator[str]: An iterator over the unique keys from all providers.
|
||||
"""
|
||||
return iter(
|
||||
set(chain.from_iterable(provider.record_keys for provider in self.enabled_providers))
|
||||
)
|
||||
return iter(self.record_keys)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of keys in the container.
|
||||
@ -1216,9 +1624,7 @@ class DataContainer(SingletonMixin, DataBase, MutableMapping):
|
||||
Returns:
|
||||
int: The total number of keys in this container.
|
||||
"""
|
||||
return len(
|
||||
list(chain.from_iterable(provider.record_keys for provider in self.enabled_providers))
|
||||
)
|
||||
return len(self.record_keys)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Provide a string representation of the DataContainer instance.
|
||||
@ -1242,6 +1648,48 @@ class DataContainer(SingletonMixin, DataBase, MutableMapping):
|
||||
for provider in self.enabled_providers:
|
||||
provider.update_data(force_enable=force_enable, force_update=force_update)
|
||||
|
||||
def key_to_series(
|
||||
self,
|
||||
key: str,
|
||||
start_datetime: Optional[DateTime] = None,
|
||||
end_datetime: Optional[DateTime] = None,
|
||||
dropna: Optional[bool] = None,
|
||||
) -> pd.Series:
|
||||
"""Extract a series indexed by the date_time field from data records within an optional date range.
|
||||
|
||||
Iterates through providers to find and return the first available series for the specified key.
|
||||
|
||||
Args:
|
||||
key (str): The field name in the DataRecord from which to extract values.
|
||||
start_datetime (datetime, optional): The start date for filtering the records (inclusive).
|
||||
end_datetime (datetime, optional): The end date for filtering the records (exclusive).
|
||||
dropna (bool, optional): Whether to drop NaN/None values before processing. Defaults to True.
|
||||
|
||||
Returns:
|
||||
pd.Series: A Pandas Series with the index as the date_time of each record
|
||||
and the values extracted from the specified key.
|
||||
|
||||
Raises:
|
||||
KeyError: If the specified key is not found in any of the DataRecords.
|
||||
"""
|
||||
series = None
|
||||
for provider in self.enabled_providers:
|
||||
try:
|
||||
series = provider.key_to_series(
|
||||
key,
|
||||
start_datetime=start_datetime,
|
||||
end_datetime=end_datetime,
|
||||
dropna=dropna,
|
||||
)
|
||||
break
|
||||
except KeyError:
|
||||
continue
|
||||
|
||||
if series is None:
|
||||
raise KeyError(f"No data found for key '{key}'.")
|
||||
|
||||
return series
|
||||
|
||||
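A minimal stand-in (plain dicts instead of providers) for the lookup strategy `DataContainer.key_to_series()` implements above: the first enabled provider that knows the key wins, and a missing key raises `KeyError`.

```python
providers = [{"load_mean": [1.0, 2.0]}, {"elecprice_marketprice": [0.3, 0.4]}]

def key_to_values(key: str):
    for provider in providers:
        try:
            return provider[key]      # first provider that has the key wins
        except KeyError:
            continue                  # otherwise try the next enabled provider
    raise KeyError(f"No data found for key '{key}'.")

print(key_to_values("elecprice_marketprice"))  # [0.3, 0.4]
```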
def key_to_array(
|
||||
self,
|
||||
key: str,
|
||||
|
@ -1,68 +1,43 @@
|
||||
"""Module for managing and serializing Pydantic-based models with custom support.
|
||||
|
||||
This module introduces the `PydanticBaseModel` class, which extends Pydantic’s `BaseModel` to facilitate
|
||||
custom serialization and deserialization for `pendulum.DateTime` objects. The main features include
|
||||
automatic handling of `pendulum.DateTime` fields, custom serialization to ISO 8601 format, and utility
|
||||
methods for converting model instances to and from dictionary and JSON formats.
|
||||
This module provides classes that extend Pydantic’s functionality to include robust handling
|
||||
of `pendulum.DateTime` fields, offering seamless serialization and deserialization into ISO 8601 format.
|
||||
These enhancements facilitate the use of Pydantic models in applications requiring timezone-aware
|
||||
datetime fields and consistent data serialization.
|
||||
|
||||
Key Classes:
|
||||
- PendulumDateTime: A custom type adapter that provides serialization and deserialization
|
||||
functionality for `pendulum.DateTime` objects, converting them to ISO 8601 strings and back.
|
||||
- PydanticBaseModel: A base model class for handling prediction records or configuration data
|
||||
with automatic Pendulum DateTime handling and additional methods for JSON and dictionary
|
||||
conversion.
|
||||
|
||||
Classes:
|
||||
PendulumDateTime(TypeAdapter[pendulum.DateTime]): Type adapter for `pendulum.DateTime` fields
|
||||
with ISO 8601 serialization. Includes:
|
||||
- serialize: Converts `pendulum.DateTime` instances to ISO 8601 string.
|
||||
- deserialize: Converts ISO 8601 strings to `pendulum.DateTime` instances.
|
||||
- is_iso8601: Validates if a string matches the ISO 8601 date format.
|
||||
|
||||
PydanticBaseModel(BaseModel): Extends `pydantic.BaseModel` to handle `pendulum.DateTime` fields
|
||||
and adds convenience methods for dictionary and JSON serialization. Key methods:
|
||||
- model_dump: Dumps the model, converting `pendulum.DateTime` fields to ISO 8601.
|
||||
- model_construct: Constructs a model instance with automatic deserialization of
|
||||
`pendulum.DateTime` fields from ISO 8601.
|
||||
- to_dict: Serializes the model instance to a dictionary.
|
||||
- from_dict: Constructs a model instance from a dictionary.
|
||||
- to_json: Converts the model instance to a JSON string.
|
||||
- from_json: Creates a model instance from a JSON string.
|
||||
|
||||
Usage Example:
|
||||
# Define custom settings in a model using PydanticBaseModel
|
||||
class PredictionCommonSettings(PydanticBaseModel):
|
||||
prediction_start: pendulum.DateTime = Field(...)
|
||||
|
||||
# Serialize a model instance to a dictionary or JSON
|
||||
config = PredictionCommonSettings(prediction_start=pendulum.now())
|
||||
config_dict = config.to_dict()
|
||||
config_json = config.to_json()
|
||||
|
||||
# Deserialize from dictionary or JSON
|
||||
new_config = PredictionCommonSettings.from_dict(config_dict)
|
||||
restored_config = PredictionCommonSettings.from_json(config_json)
|
||||
|
||||
Dependencies:
|
||||
- `pendulum`: Required for handling timezone-aware datetime fields.
|
||||
- `pydantic`: Required for model and validation functionality.
|
||||
|
||||
Notes:
|
||||
- This module enables custom handling of Pendulum DateTime fields within Pydantic models,
|
||||
which is particularly useful for applications requiring consistent ISO 8601 datetime formatting
|
||||
and robust timezone-aware datetime support.
|
||||
Key Features:
|
||||
- Custom type adapter for `pendulum.DateTime` fields with automatic serialization to ISO 8601 strings.
|
||||
- Utility methods for converting models to and from dictionaries and JSON strings.
|
||||
- Validation tools for maintaining data consistency, including specialized support for
|
||||
pandas DataFrames and Series with datetime indexes.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Type
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import pandas as pd
|
||||
import pendulum
|
||||
from pydantic import BaseModel, ConfigDict, TypeAdapter
|
||||
from pandas.api.types import is_datetime64_any_dtype
|
||||
from pydantic import (
|
||||
AwareDatetime,
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
RootModel,
|
||||
TypeAdapter,
|
||||
ValidationError,
|
||||
ValidationInfo,
|
||||
field_validator,
|
||||
)
|
||||
|
||||
from akkudoktoreos.utils.datetimeutil import to_datetime, to_duration
|
||||
|
||||
|
||||
# Custom type adapter for Pendulum DateTime fields
|
||||
class PendulumDateTime(TypeAdapter[pendulum.DateTime]):
|
||||
class PydanticTypeAdapterDateTime(TypeAdapter[pendulum.DateTime]):
|
||||
"""Custom type adapter for Pendulum DateTime fields."""
|
||||
|
||||
@classmethod
|
||||
def serialize(cls, value: Any) -> str:
|
||||
"""Convert pendulum.DateTime to ISO 8601 string."""
|
||||
@ -105,41 +80,69 @@ class PydanticBaseModel(BaseModel):
|
||||
validate_assignment=True,
|
||||
)
|
||||
|
||||
@field_validator("*", mode="before")
|
||||
def validate_and_convert_pendulum(cls, value: Any, info: ValidationInfo) -> Any:
|
||||
"""Validator to convert fields of type `pendulum.DateTime`.
|
||||
|
||||
Converts fields to proper `pendulum.DateTime` objects, ensuring correct input types.
|
||||
|
||||
This method is invoked for every field before the field value is set. If the field's type
|
||||
is `pendulum.DateTime`, it tries to convert string or timestamp values to `pendulum.DateTime`
|
||||
objects. If the value cannot be converted, a validation error is raised.
|
||||
|
||||
Args:
|
||||
value: The value to be assigned to the field.
|
||||
info: Validation information for the field.
|
||||
|
||||
Returns:
|
||||
The converted value, if successful.
|
||||
|
||||
Raises:
|
||||
ValidationError: If the value cannot be converted to `pendulum.DateTime`.
|
||||
"""
|
||||
# Get the field name and expected type
|
||||
field_name = info.field_name
|
||||
expected_type = cls.model_fields[field_name].annotation
|
||||
|
||||
# Convert
|
||||
if expected_type is pendulum.DateTime or expected_type is AwareDatetime:
|
||||
try:
|
||||
value = to_datetime(value)
|
||||
except:
|
||||
pass
|
||||
return value
|
||||
|
||||
# Override Pydantic’s serialization for all DateTime fields
|
||||
def model_dump(self, *args: Any, **kwargs: Any) -> dict:
|
||||
"""Custom dump method to handle serialization for DateTime fields."""
|
||||
result = super().model_dump(*args, **kwargs)
|
||||
for key, value in result.items():
|
||||
if isinstance(value, pendulum.DateTime):
|
||||
result[key] = PendulumDateTime.serialize(value)
|
||||
result[key] = PydanticTypeAdapterDateTime.serialize(value)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def model_construct(cls, data: dict) -> "PydanticBaseModel":
|
||||
def model_construct(
|
||||
cls, _fields_set: set[str] | None = None, **values: Any
|
||||
) -> "PydanticBaseModel":
|
||||
"""Custom constructor to handle deserialization for DateTime fields."""
|
||||
for key, value in data.items():
|
||||
if isinstance(value, str) and PendulumDateTime.is_iso8601(value):
|
||||
data[key] = PendulumDateTime.deserialize(value)
|
||||
return super().model_construct(data)
|
||||
for key, value in values.items():
|
||||
if isinstance(value, str) and PydanticTypeAdapterDateTime.is_iso8601(value):
|
||||
values[key] = PydanticTypeAdapterDateTime.deserialize(value)
|
||||
return super().model_construct(_fields_set, **values)
|
||||
|
||||
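A short sketch of the round trip handled by the overridden `model_dump()`/`model_construct()` above: pendulum datetimes are written out as ISO 8601 strings and parsed back (plain pendulum here, no model involved).

```python
import pendulum

created = pendulum.datetime(2024, 11, 10, 0, 0, tz="Europe/Berlin")
iso = created.to_iso8601_string()  # what the custom model_dump() emits for DateTime fields
restored = pendulum.parse(iso)     # what model_construct() reverses via the type adapter
print(iso)               # 2024-11-10T00:00:00+01:00
print(restored == created)  # True
```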
def reset_optional(self) -> "PydanticBaseModel":
|
||||
"""Resets all optional fields in the model to None.
|
||||
|
||||
Iterates through all model fields and sets any optional (non-required)
|
||||
fields to None. The modification is done in-place on the current instance.
|
||||
|
||||
Returns:
|
||||
PydanticBaseModel: The current instance with all optional fields
|
||||
reset to None.
|
||||
|
||||
Example:
|
||||
>>> settings = PydanticBaseModel(name="test", optional_field="value")
|
||||
>>> settings.reset_optional()
|
||||
>>> assert settings.optional_field is None
|
||||
"""
|
||||
for field_name, field in self.model_fields.items():
|
||||
if not field.is_required():  # pydantic v2: is_required() is a method; optional fields return False
setattr(self, field_name, None)
return self  # reset_optional() returns the instance, as documented above
|
||||
def reset_to_defaults(self) -> "PydanticBaseModel":
|
||||
"""Resets the fields to their default values."""
|
||||
for field_name, field_info in self.model_fields.items():
|
||||
if field_info.default_factory is not None: # Handle fields with default_factory
|
||||
default_value = field_info.default_factory()
|
||||
else:
|
||||
default_value = field_info.default
|
||||
try:
|
||||
setattr(self, field_name, default_value)
|
||||
except (AttributeError, TypeError, ValidationError):
|
||||
# Skip fields that are read-only or dynamically computed or can not be set to default
|
||||
pass
|
||||
return self
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
@@ -167,40 +170,6 @@ class PydanticBaseModel(BaseModel):
|
||||
"""
|
||||
return cls.model_validate(data)
|
||||
|
||||
@classmethod
|
||||
def from_dict_with_reset(cls, data: dict | None = None) -> "PydanticBaseModel":
|
||||
"""Creates a new instance with reset optional fields, then updates from dict.
|
||||
|
||||
First creates an instance with default values, resets all optional fields
|
||||
to None, then updates the instance with the provided dictionary data if any.
|
||||
|
||||
Args:
|
||||
data (dict | None): Dictionary containing field values to initialize
|
||||
the instance with. Defaults to None.
|
||||
|
||||
Returns:
|
||||
PydanticBaseModel: A new instance with all optional fields initially
|
||||
reset to None and then updated with provided data.
|
||||
|
||||
Example:
|
||||
>>> data = {"name": "test", "optional_field": "value"}
|
||||
>>> settings = PydanticBaseModel.from_dict_with_reset(data)
|
||||
>>> # All non-specified optional fields will be None
|
||||
"""
|
||||
# Create instance with model defaults
|
||||
instance = cls()
|
||||
|
||||
# Reset all optional fields to None
|
||||
instance.reset_optional()
|
||||
|
||||
# Update with provided data if any
|
||||
if data:
|
||||
# Use model_validate to ensure proper type conversion and validation
|
||||
updated_instance = instance.model_validate({**instance.model_dump(), **data})
|
||||
return updated_instance
|
||||
|
||||
return instance
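# Illustrative usage sketch of from_dict_with_reset(); ExampleSettings is a hypothetical
# subclass used only for demonstration and assumes Optional is imported in this module.
class ExampleSettings(PydanticBaseModel):
    name: Optional[str] = "default"
    optional_field: Optional[str] = None

settings = ExampleSettings.from_dict_with_reset({"name": "test"})
assert settings.name == "test"
assert settings.optional_field is None  # non-specified optional fields start out reset to None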
|
||||
|
||||
def to_json(self) -> str:
|
||||
"""Convert the PydanticBaseModel instance to a JSON string.
|
||||
|
||||
@@ -224,3 +193,287 @@ class PydanticBaseModel(BaseModel):
|
||||
"""
|
||||
data = json.loads(json_str)
|
||||
return cls.model_validate(data)
|
||||
|
||||
|
||||
class PydanticDateTimeData(RootModel):
|
||||
"""Pydantic model for time series data with consistent value lengths.
|
||||
|
||||
This model validates a dictionary where:
|
||||
- Keys are strings representing data series names
|
||||
- Values are lists of numeric or string values
|
||||
- Special keys 'start_datetime' and 'interval' can contain string values
|
||||
for time series indexing
|
||||
- All value lists must have the same length
|
||||
|
||||
Example:
|
||||
{
|
||||
"start_datetime": "2024-01-01 00:00:00", # optional
|
||||
"interval": "1 Hour", # optional
|
||||
"load_mean": [20.5, 21.0, 22.1],
|
||||
"load_min": [18.5, 19.0, 20.1]
|
||||
}
|
||||
"""
|
||||
|
||||
root: Dict[str, Union[str, List[Union[float, int, str, None]]]]
|
||||
|
||||
@field_validator("root", mode="after")
|
||||
@classmethod
|
||||
def validate_root(
|
||||
cls, value: Dict[str, Union[str, List[Union[float, int, str, None]]]]
|
||||
) -> Dict[str, Union[str, List[Union[float, int, str, None]]]]:
|
||||
# Validate that all keys are strings
|
||||
if not all(isinstance(k, str) for k in value.keys()):
|
||||
raise ValueError("All keys in the dictionary must be strings.")
|
||||
|
||||
# Validate that no lists contain only None values
|
||||
for v in value.values():
|
||||
if isinstance(v, list) and all(item is None for item in v):
|
||||
raise ValueError("Lists cannot contain only None values.")
|
||||
|
||||
# Validate that all lists have consistent lengths (if they are lists)
|
||||
list_lengths = [len(v) for v in value.values() if isinstance(v, list)]
|
||||
if len(set(list_lengths)) > 1:
|
||||
raise ValueError("All lists in the dictionary must have the same length.")
|
||||
|
||||
# Validate special keys
|
||||
if "start_datetime" in value.keys():
|
||||
value["start_datetime"] = to_datetime(value["start_datetime"])
|
||||
if "interval" in value.keys():
|
||||
value["interval"] = to_duration(value["interval"])
|
||||
|
||||
return value
|
||||
|
||||
def to_dict(self) -> Dict[str, Union[str, List[Union[float, int, str, None]]]]:
|
||||
"""Convert the model to a plain dictionary.
|
||||
|
||||
Returns:
|
||||
Dict containing the validated data.
|
||||
"""
|
||||
return self.root
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "PydanticDateTimeData":
|
||||
"""Create a PydanticDateTimeData instance from a dictionary.
|
||||
|
||||
Args:
|
||||
data: Input dictionary
|
||||
|
||||
Returns:
|
||||
PydanticDateTimeData instance
|
||||
"""
|
||||
return cls(root=data)
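# Illustrative usage sketch of PydanticDateTimeData; the payload keys follow the docstring
# example above and are not a fixed schema.
payload = {
    "start_datetime": "2024-01-01 00:00:00",  # optional, converted by to_datetime()
    "interval": "1 Hour",                     # optional, converted by to_duration()
    "load_mean": [20.5, 21.0, 22.1],
    "load_min": [18.5, 19.0, 20.1],
}
data = PydanticDateTimeData.from_dict(payload)  # validates key types and equal list lengths
plain = data.to_dict()                          # plain dict with parsed start_datetime/interval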
|
||||
|
||||
|
||||
class PydanticDateTimeDataFrame(PydanticBaseModel):
|
||||
"""Pydantic model for validating pandas DataFrame data with datetime index."""
|
||||
|
||||
data: Dict[str, Dict[str, Any]]
|
||||
dtypes: Dict[str, str] = Field(default_factory=dict)
|
||||
tz: Optional[str] = Field(default=None, description="Timezone for datetime values")
|
||||
datetime_columns: list[str] = Field(
|
||||
default_factory=lambda: ["date_time"], description="Columns to be treated as datetime"
|
||||
)
|
||||
|
||||
@field_validator("tz")
|
||||
def validate_timezone(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""Validate that the timezone is valid."""
|
||||
if v is not None:
|
||||
try:
|
||||
ZoneInfo(v)
|
||||
except KeyError:
|
||||
raise ValueError(f"Invalid timezone: {v}")
|
||||
return v
|
||||
|
||||
@field_validator("data", mode="before")
|
||||
@classmethod
|
||||
def validate_data(cls, v: Dict[str, Any], info: ValidationInfo) -> Dict[str, Any]:
|
||||
if not v:
|
||||
return v
|
||||
|
||||
# Validate consistent columns
|
||||
columns = set(next(iter(v.values())).keys())
|
||||
if not all(set(row.keys()) == columns for row in v.values()):
|
||||
raise ValueError("All rows must have the same columns")
|
||||
|
||||
# Convert index datetime strings
|
||||
try:
|
||||
d = {
|
||||
to_datetime(dt, as_string=True, in_timezone=info.data.get("tz")): value
|
||||
for dt, value in v.items()
|
||||
}
|
||||
v = d
|
||||
except (ValueError, TypeError) as e:
|
||||
raise ValueError(f"Invalid datetime string in index: {e}")
|
||||
|
||||
# Convert datetime columns
|
||||
datetime_cols = info.data.get("datetime_columns", [])
|
||||
try:
|
||||
for dt_str, value in v.items():
|
||||
for column_name, column_value in value.items():
|
||||
if column_name in datetime_cols and column_value is not None:
|
||||
v[dt_str][column_name] = to_datetime(
|
||||
column_value, in_timezone=info.data.get("tz")
|
||||
)
|
||||
except (ValueError, TypeError) as e:
|
||||
raise ValueError(f"Invalid datetime value in column: {e}")
|
||||
|
||||
return v
|
||||
|
||||
@field_validator("dtypes")
|
||||
@classmethod
|
||||
def validate_dtypes(cls, v: Dict[str, str], info: ValidationInfo) -> Dict[str, str]:
|
||||
if not v:
|
||||
return v
|
||||
|
||||
valid_dtypes = {"int64", "float64", "bool", "datetime64[ns]", "object", "string"}
|
||||
invalid_dtypes = set(v.values()) - valid_dtypes
|
||||
if invalid_dtypes:
|
||||
raise ValueError(f"Unsupported dtypes: {invalid_dtypes}")
|
||||
|
||||
data = info.data.get("data", {})
|
||||
if data:
|
||||
columns = set(next(iter(data.values())).keys())
|
||||
if not all(col in columns for col in v.keys()):
|
||||
raise ValueError("dtype columns must exist in data columns")
|
||||
return v
|
||||
|
||||
def to_dataframe(self) -> pd.DataFrame:
|
||||
"""Convert the validated model data to a pandas DataFrame."""
|
||||
df = pd.DataFrame.from_dict(self.data, orient="index")
|
||||
|
||||
# Convert index to datetime
|
||||
index = pd.Index([to_datetime(dt, in_timezone=self.tz) for dt in df.index])
|
||||
df.index = index
|
||||
|
||||
dtype_mapping = {
|
||||
"int": int,
|
||||
"float": float,
|
||||
"str": str,
|
||||
"bool": bool,
|
||||
}
|
||||
|
||||
# Apply dtypes
|
||||
for col, dtype in self.dtypes.items():
|
||||
if dtype == "datetime64[ns]":
|
||||
df[col] = pd.to_datetime(to_datetime(df[col], in_timezone=self.tz))
|
||||
elif dtype in dtype_mapping.keys():
|
||||
df[col] = df[col].astype(dtype_mapping[dtype])
|
||||
else:
|
||||
pass
|
||||
|
||||
return df
|
||||
|
||||
@classmethod
|
||||
def from_dataframe(
|
||||
cls, df: pd.DataFrame, tz: Optional[str] = None
|
||||
) -> "PydanticDateTimeDataFrame":
|
||||
"""Create a PydanticDateTimeDataFrame instance from a pandas DataFrame."""
|
||||
index = pd.Index([to_datetime(dt, as_string=True, in_timezone=tz) for dt in df.index])
|
||||
df.index = index
|
||||
|
||||
datetime_columns = [col for col in df.columns if is_datetime64_any_dtype(df[col])]
|
||||
|
||||
return cls(
|
||||
data=df.to_dict(orient="index"),
|
||||
dtypes={col: str(dtype) for col, dtype in df.dtypes.items()},
|
||||
tz=tz,
|
||||
datetime_columns=datetime_columns,
|
||||
)
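# Illustrative round-trip sketch for PydanticDateTimeDataFrame; the column name and values
# are made up, and model_dump_json() is the standard pydantic v2 serializer.
import pandas as pd

df = pd.DataFrame(
    {"load_mean": [20.5, 21.0, 22.1]},
    index=pd.date_range("2024-01-01", periods=3, freq="h", tz="Europe/Berlin"),
)
model = PydanticDateTimeDataFrame.from_dataframe(df, tz="Europe/Berlin")
json_payload = model.model_dump_json()  # safe to exchange via the v1 API
restored = model.to_dataframe()         # datetime index and dtypes are rebuilt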
|
||||
|
||||
|
||||
class PydanticDateTimeSeries(PydanticBaseModel):
|
||||
"""Pydantic model for validating pandas Series with datetime index in JSON format.
|
||||
|
||||
This model handles Series data serialized with orient='index', where the keys are
|
||||
datetime strings and values are the series values. Provides validation and
|
||||
conversion between JSON and pandas Series with datetime index.
|
||||
|
||||
Attributes:
|
||||
data (Dict[str, Any]): Dictionary mapping datetime strings to values.
|
||||
dtype (str): The data type of the series values.
|
||||
tz (str | None): Timezone name if the datetime index is timezone-aware.
|
||||
"""
|
||||
|
||||
data: Dict[str, Any]
|
||||
dtype: str = Field(default="float64")
|
||||
tz: Optional[str] = Field(default=None)
|
||||
|
||||
@field_validator("data", mode="after")
|
||||
@classmethod
|
||||
def validate_datetime_index(cls, v: Dict[str, Any], info: ValidationInfo) -> Dict[str, Any]:
|
||||
"""Validate that all keys can be parsed as datetime strings.
|
||||
|
||||
Args:
|
||||
v: Dictionary with datetime string keys and series values.
|
||||
|
||||
Returns:
|
||||
The validated data dictionary.
|
||||
|
||||
Raises:
|
||||
ValueError: If any key cannot be parsed as a datetime.
|
||||
"""
|
||||
tz = info.data.get("tz")
|
||||
if tz is not None:
|
||||
try:
|
||||
ZoneInfo(tz)
|
||||
except KeyError:
|
||||
tz = None
|
||||
try:
|
||||
# Attempt to parse each key as datetime
|
||||
d = dict()
|
||||
for dt_str, value in v.items():
|
||||
d[to_datetime(dt_str, as_string=True, in_timezone=tz)] = value
|
||||
return d
|
||||
except (ValueError, TypeError) as e:
|
||||
raise ValueError(f"Invalid datetime string in index: {e}")
|
||||
|
||||
@field_validator("tz")
|
||||
def validate_timezone(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""Validate that the timezone is valid."""
|
||||
if v is not None:
|
||||
try:
|
||||
ZoneInfo(v)
|
||||
except KeyError:
|
||||
raise ValueError(f"Invalid timezone: {v}")
|
||||
return v
|
||||
|
||||
def to_series(self) -> pd.Series:
|
||||
"""Convert the validated model data to a pandas Series.
|
||||
|
||||
Returns:
|
||||
A pandas Series with datetime index constructed from the model data.
|
||||
"""
|
||||
index = [to_datetime(dt, in_timezone=self.tz) for dt in list(self.data.keys())]
|
||||
|
||||
series = pd.Series(data=list(self.data.values()), index=index, dtype=self.dtype)
|
||||
return series
|
||||
|
||||
@classmethod
|
||||
def from_series(cls, series: pd.Series, tz: Optional[str] = None) -> "PydanticDateTimeSeries":
|
||||
"""Create a PydanticDateTimeSeries instance from a pandas Series.
|
||||
|
||||
Args:
|
||||
series: The pandas Series with datetime index to convert.
|
||||
|
||||
Returns:
|
||||
A new instance containing the Series data.
|
||||
|
||||
Raises:
|
||||
ValueError: If series index is not datetime type.
|
||||
|
||||
Example:
|
||||
>>> dates = pd.date_range('2024-01-01', periods=3)
|
||||
>>> s = pd.Series([1.1, 2.2, 3.3], index=dates)
|
||||
>>> model = PydanticDateTimeSeries.from_series(s)
|
||||
"""
|
||||
index = pd.Index([to_datetime(dt, as_string=True, in_timezone=tz) for dt in series.index])
|
||||
series.index = index
|
||||
|
||||
if len(index) > 0:
|
||||
tz = to_datetime(series.index[0]).timezone.name
|
||||
|
||||
return cls(
|
||||
data=series.to_dict(),
|
||||
dtype=str(series.dtype),
|
||||
tz=tz,
|
||||
)
|
||||
|
0  src/akkudoktoreos/measurement/__init__.py  Normal file
263  src/akkudoktoreos/measurement/measurement.py  Normal file
@@ -0,0 +1,263 @@
|
||||
"""Measurement module to provide and store measurements.
|
||||
|
||||
This module provides a `Measurement` class to manage and update a sequence of
|
||||
data records for measurements.
|
||||
|
||||
The measurements can be added programmatically or imported from a file or JSON string.
|
||||
"""
|
||||
|
||||
from typing import Any, ClassVar, List, Optional
|
||||
|
||||
import numpy as np
|
||||
from numpydantic import NDArray, Shape
|
||||
from pendulum import DateTime, Duration
|
||||
from pydantic import Field, computed_field
|
||||
|
||||
from akkudoktoreos.config.configabc import SettingsBaseModel
|
||||
from akkudoktoreos.core.coreabc import SingletonMixin
|
||||
from akkudoktoreos.core.dataabc import DataImportMixin, DataRecord, DataSequence
|
||||
from akkudoktoreos.utils.datetimeutil import to_duration
|
||||
from akkudoktoreos.utils.logutil import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MeasurementCommonSettings(SettingsBaseModel):
|
||||
measurement_load0_name: Optional[str] = Field(
|
||||
default=None, description="Name of the load0 source (e.g. 'Household', 'Heat Pump')"
|
||||
)
|
||||
measurement_load1_name: Optional[str] = Field(
|
||||
default=None, description="Name of the load1 source (e.g. 'Household', 'Heat Pump')"
|
||||
)
|
||||
measurement_load2_name: Optional[str] = Field(
|
||||
default=None, description="Name of the load2 source (e.g. 'Household', 'Heat Pump')"
|
||||
)
|
||||
measurement_load3_name: Optional[str] = Field(
|
||||
default=None, description="Name of the load3 source (e.g. 'Household', 'Heat Pump')"
|
||||
)
|
||||
measurement_load4_name: Optional[str] = Field(
|
||||
default=None, description="Name of the load4 source (e.g. 'Household', 'Heat Pump')"
|
||||
)
|
||||
|
||||
|
||||
class MeasurementDataRecord(DataRecord):
|
||||
"""Represents a measurement data record containing various measurements at a specific datetime.
|
||||
|
||||
Attributes:
|
||||
date_time (Optional[DateTime]): The datetime of the record.
|
||||
"""
|
||||
|
||||
# Single loads, to be aggregated to total load
|
||||
measurement_load0_mr: Optional[float] = Field(
|
||||
default=None, ge=0, description="Load0 meter reading [kWh]"
|
||||
)
|
||||
measurement_load1_mr: Optional[float] = Field(
|
||||
default=None, ge=0, description="Load1 meter reading [kWh]"
|
||||
)
|
||||
measurement_load2_mr: Optional[float] = Field(
|
||||
default=None, ge=0, description="Load2 meter reading [kWh]"
|
||||
)
|
||||
measurement_load3_mr: Optional[float] = Field(
|
||||
default=None, ge=0, description="Load3 meter reading [kWh]"
|
||||
)
|
||||
measurement_load4_mr: Optional[float] = Field(
|
||||
default=None, ge=0, description="Load4 meter reading [kWh]"
|
||||
)
|
||||
|
||||
measurement_max_loads: ClassVar[int] = 5 # Maximum number of loads that can be set
|
||||
|
||||
measurement_grid_export_mr: Optional[float] = Field(
|
||||
default=None, ge=0, description="Export to grid meter reading [kWh]"
|
||||
)
|
||||
|
||||
measurement_grid_import_mr: Optional[float] = Field(
|
||||
default=None, ge=0, description="Import from grid meter reading [kWh]"
|
||||
)
|
||||
|
||||
# Computed fields
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@property
|
||||
def measurement_loads(self) -> List[str]:
|
||||
"""Compute a list of active loads."""
|
||||
active_loads = []
|
||||
|
||||
# Loop through measurement_loadx
|
||||
for i in range(self.measurement_max_loads):
|
||||
load_attr = f"measurement_load{i}_mr"
|
||||
|
||||
# Add the load to the active list if its meter reading is set
|
||||
if getattr(self, load_attr, None):
|
||||
active_loads.append(load_attr)
|
||||
|
||||
return active_loads
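# Illustrative sketch of the computed field; the record values are made up and to_datetime()
# is only used to supply a proper datetime.
from akkudoktoreos.utils.datetimeutil import to_datetime

record = MeasurementDataRecord(
    date_time=to_datetime("2024-01-01 00:00:00"),
    measurement_load0_mr=1234.5,
    measurement_load2_mr=42.0,
)
print(record.measurement_loads)  # ['measurement_load0_mr', 'measurement_load2_mr']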
|
||||
|
||||
|
||||
class Measurement(SingletonMixin, DataImportMixin, DataSequence):
|
||||
"""Singleton class that holds measurement data records.
|
||||
|
||||
Measurements can be provided programmatically or read from JSON string or file.
|
||||
"""
|
||||
|
||||
records: List[MeasurementDataRecord] = Field(
|
||||
default_factory=list, description="List of measurement data records"
|
||||
)
|
||||
|
||||
topics: ClassVar[List[str]] = [
|
||||
"measurement_load",
|
||||
]
|
||||
|
||||
def _interval_count(
|
||||
self, start_datetime: DateTime, end_datetime: DateTime, interval: Duration
|
||||
) -> int:
|
||||
"""Calculate number of intervals between two datetimes.
|
||||
|
||||
Args:
|
||||
start_datetime: Starting datetime
|
||||
end_datetime: Ending datetime
|
||||
interval: Time duration for each interval
|
||||
|
||||
Returns:
|
||||
Number of intervals as integer
|
||||
|
||||
Raises:
|
||||
ValueError: If end_datetime is before start_datetime
|
||||
ValueError: If interval is zero or negative
|
||||
"""
|
||||
if end_datetime < start_datetime:
|
||||
raise ValueError("end_datetime must be after start_datetime")
|
||||
|
||||
if interval.total_seconds() <= 0:
|
||||
raise ValueError("interval must be positive")
|
||||
|
||||
# Calculate difference in seconds
|
||||
diff_seconds = end_datetime.diff(start_datetime).total_seconds()
|
||||
interval_seconds = interval.total_seconds()
|
||||
|
||||
# Return ceiling of division to include partial intervals
|
||||
return int(np.ceil(diff_seconds / interval_seconds))
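# Worked example of the ceiling logic with hypothetical values: a 90 minute span at a
# 1 hour interval yields 2 intervals because the partial interval counts.
import numpy as np
import pendulum

start = pendulum.datetime(2024, 1, 1, 0, 0)
end = pendulum.datetime(2024, 1, 1, 1, 30)
interval = pendulum.duration(hours=1)

count = int(np.ceil(end.diff(start).total_seconds() / interval.total_seconds()))
print(count)  # 2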
|
||||
|
||||
def name_to_key(self, name: str, topic: str) -> Optional[str]:
|
||||
"""Provides measurement key for given name and topic."""
|
||||
topic = topic.lower()
|
||||
|
||||
if topic not in self.topics:
|
||||
return None
|
||||
|
||||
topic_keys = [key for key in self.config.config_keys if key.startswith(topic)]
|
||||
key = None
|
||||
if topic == "measurement_load":
|
||||
for config_key in topic_keys:
|
||||
if config_key.endswith("_name") and getattr(self.config, config_key) == name:
|
||||
key = topic + config_key[len(topic) : len(topic) + 1] + "_mr"
|
||||
break
|
||||
|
||||
if key is not None and key not in self.record_keys:
|
||||
# Should never happen
|
||||
error_msg = f"Key '{key}' not available."
|
||||
logger.error(error_msg)
|
||||
raise KeyError(error_msg)
|
||||
|
||||
return key
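# Illustrative sketch of name-to-key resolution; it assumes the EOS configuration contains
# measurement_load0_name = "Household".
measurement = get_measurement()
key = measurement.name_to_key(name="Household", topic="measurement_load")
print(key)  # 'measurement_load0_mr', or None for an unknown name/topic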
|
||||
|
||||
def _energy_from_meter_readings(
|
||||
self,
|
||||
key: str,
|
||||
start_datetime: DateTime,
|
||||
end_datetime: DateTime,
|
||||
interval: Duration,
|
||||
) -> NDArray[Shape["*"], Any]:
|
||||
"""Calculate an energy values array indexed by fixed time intervals from energy metering data within an optional date range.
|
||||
|
||||
Args:
|
||||
key: Key for energy meter readings.
|
||||
start_datetime (datetime): The start date for filtering the energy data (inclusive).
|
||||
end_datetime (datetime): The end date for filtering the energy data (exclusive).
|
||||
interval (duration): The fixed time interval.
|
||||
|
||||
Returns:
|
||||
np.ndarray: A NumPy array of energy values [kWh] per interval, calculated from
the meter readings.
|
||||
"""
|
||||
# Add one interval to end_datetime to ensure we have an energy value interval for all
|
||||
# datetimes from start_datetime (inclusive) to end_datetime (exclusive)
|
||||
end_datetime += interval
|
||||
size = self._interval_count(start_datetime, end_datetime, interval)
|
||||
|
||||
energy_mr_array = self.key_to_array(
|
||||
key=key, start_datetime=start_datetime, end_datetime=end_datetime, interval=interval
|
||||
)
|
||||
if energy_mr_array.size != size:
|
||||
logging_msg = (
|
||||
f"'{key}' meter reading array size: {energy_mr_array.size}"
|
||||
f" does not fit to expected size: {size}, {energy_mr_array}"
|
||||
)
|
||||
if energy_mr_array.size != 0:
|
||||
logger.error(logging_msg)
|
||||
raise ValueError(logging_msg)
|
||||
logger.debug(logging_msg)
|
||||
energy_array = np.zeros(size - 1)
|
||||
elif np.any(energy_mr_array == None):
|
||||
# 'key_to_array()' creates None values array if no data records are available.
|
||||
# Array contains None value -> ignore
|
||||
debug_msg = f"'{key}' meter reading None: {energy_mr_array}"
|
||||
logger.debug(debug_msg)
|
||||
energy_array = np.zeros(size - 1)
|
||||
else:
|
||||
# Calculate load per interval
|
||||
debug_msg = f"'{key}' meter reading: {energy_mr_array}"
|
||||
logger.debug(debug_msg)
|
||||
energy_array = np.diff(energy_mr_array)
|
||||
debug_msg = f"'{key}' energy calculation: {energy_array}"
|
||||
logger.debug(debug_msg)
|
||||
return energy_array
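# Sketch of the core calculation: cumulative meter readings are turned into energy per
# interval with np.diff(); the readings below are hypothetical hourly values.
import numpy as np

meter_readings_kwh = np.array([100.0, 100.4, 101.1, 101.1, 102.0])
energy_per_interval = np.diff(meter_readings_kwh)
print(energy_per_interval)  # [0.4 0.7 0.  0.9] - one value per interval between readings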
|
||||
|
||||
def load_total(
|
||||
self,
|
||||
start_datetime: Optional[DateTime] = None,
|
||||
end_datetime: Optional[DateTime] = None,
|
||||
interval: Optional[Duration] = None,
|
||||
) -> NDArray[Shape["*"], Any]:
|
||||
"""Calculate a total load energy values array indexed by fixed time intervals from load metering data within an optional date range.
|
||||
|
||||
Args:
|
||||
start_datetime (datetime, optional): The start date for filtering the load data (inclusive).
|
||||
end_datetime (datetime, optional): The end date for filtering the load data (exclusive).
|
||||
interval (duration, optional): The fixed time interval. Defaults to 1 hour.
|
||||
|
||||
Returns:
|
||||
np.ndarray: A NumPy array of total load energy values [kWh] per interval, calculated from
the load meter readings.
|
||||
"""
|
||||
if interval is None:
interval = to_duration("1 hour")
if len(self) < 1:
# No data available
if start_datetime is None or end_datetime is None:
size = 0
else:
size = self._interval_count(start_datetime, end_datetime, interval)
return np.zeros(size)
|
||||
if start_datetime is None:
|
||||
start_datetime = self[0].date_time
|
||||
if end_datetime is None:
|
||||
end_datetime = self[-1].date_time
|
||||
size = self._interval_count(start_datetime, end_datetime, interval)
|
||||
load_total_array = np.zeros(size)
|
||||
# Loop through measurement_load<x>_mr
|
||||
for i in range(self.record_class().measurement_max_loads):
|
||||
key = f"measurement_load{i}_mr"
|
||||
# Calculate load per interval
|
||||
load_array = self._energy_from_meter_readings(
|
||||
key=key, start_datetime=start_datetime, end_datetime=end_datetime, interval=interval
|
||||
)
|
||||
# Add calculated load to total load
|
||||
load_total_array += load_array
|
||||
debug_msg = f"Total load '{key}' calculation: {load_total_array}"
|
||||
logger.debug(debug_msg)
|
||||
|
||||
return load_total_array
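# Illustrative usage sketch: aggregate all configured load meter readings into one total
# load energy array; the date range and interval below are assumptions.
from akkudoktoreos.utils.datetimeutil import to_datetime, to_duration

measurement = get_measurement()
load_total_kwh = measurement.load_total(
    start_datetime=to_datetime("2024-01-01 00:00:00"),
    end_datetime=to_datetime("2024-01-02 00:00:00"),
    interval=to_duration("1 hour"),
)
print(load_total_kwh.shape)  # (24,) - one total load energy value [kWh] per hour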
|
||||
|
||||
|
||||
def get_measurement() -> Measurement:
|
||||
"""Gets the EOS measurement data."""
|
||||
return Measurement()
|
@@ -8,13 +8,15 @@ format, enabling consistent access to forecasted and historical electricity pric
|
||||
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from pydantic import ValidationError
|
||||
from numpydantic import NDArray, Shape
|
||||
from pydantic import Field, ValidationError
|
||||
|
||||
from akkudoktoreos.core.pydantic import PydanticBaseModel
|
||||
from akkudoktoreos.prediction.elecpriceabc import ElecPriceDataRecord, ElecPriceProvider
|
||||
from akkudoktoreos.utils.cacheutil import cache_in_file
|
||||
from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime
|
||||
from akkudoktoreos.utils.cacheutil import CacheFileStore, cache_in_file
|
||||
from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime, to_duration
|
||||
from akkudoktoreos.utils.logutil import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -63,6 +65,20 @@ class ElecPriceAkkudoktor(ElecPriceProvider):
|
||||
_update_data(): Processes and updates forecast data from Akkudoktor in ElecPriceDataRecord format.
|
||||
"""
|
||||
|
||||
elecprice_8days: NDArray[Shape["24, 8"], float] = Field(
|
||||
default=np.full((24, 8), np.nan),
|
||||
description="Hourly electricity prices for the last 7 days and today (€/KWh). "
|
||||
"A NumPy array of 24 elements, each representing the hourly prices "
|
||||
"of the last 7 days (index 0..6, Monday..Sunday) and today (index 7).",
|
||||
)
|
||||
elecprice_8days_weights_day_of_week: NDArray[Shape["7, 8"], float] = Field(
|
||||
default=np.full((7, 8), np.nan),
|
||||
description="Daily electricity price weights for the last 7 days and today. "
|
||||
"A NumPy array of 7 elements (Monday..Sunday), each representing "
|
||||
"the daily price weights of the last 7 days (index 0..6, Monday..Sunday) "
|
||||
"and today (index 7).",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def provider_id(cls) -> str:
|
||||
"""Return the unique identifier for the Akkudoktor provider."""
|
||||
@@ -84,6 +100,50 @@ class ElecPriceAkkudoktor(ElecPriceProvider):
|
||||
raise ValueError(error_msg)
|
||||
return akkudoktor_data
|
||||
|
||||
def _calculate_weighted_mean(self, day_of_week: int, hour: int) -> float:
|
||||
"""Calculate the weighted mean price for given day_of_week and hour.
|
||||
|
||||
Args:
|
||||
day_of_week (int): The day of the week to calculate the mean for (0=Monday..6=Sunday).
hour (int): The hour of the day to calculate the mean for (0..23).

Returns:
price_weighted_mean (float): Weighted mean price for the given day_of_week and hour.
|
||||
"""
|
||||
if np.isnan(self.elecprice_8days_weights_day_of_week[0][0]):
|
||||
# Weights not initialized - do now
|
||||
|
||||
# Priority of day: 1=most .. 7=least
|
||||
priority_of_day = np.array(
|
||||
# Available Prediction days /
|
||||
# M,Tu,We,Th,Fr,Sa,Su,Today/ Forecast day_of_week
|
||||
[
|
||||
[1, 2, 3, 4, 5, 6, 7, 1], # Monday
|
||||
[3, 1, 2, 4, 5, 6, 7, 1], # Tuesday
|
||||
[4, 2, 1, 3, 5, 6, 7, 1], # Wednesday
|
||||
[5, 4, 2, 1, 3, 6, 7, 1], # Thursday
|
||||
[5, 4, 3, 2, 1, 6, 7, 1], # Friday
|
||||
[7, 6, 5, 4, 2, 1, 3, 1], # Saturday
|
||||
[7, 6, 5, 4, 3, 2, 1, 1], # Sunday
|
||||
]
|
||||
)
|
||||
# Take priorities above to decrease relevance in 2s exponential
|
||||
self.elecprice_8days_weights_day_of_week = 2 / (2**priority_of_day)
|
||||
|
||||
# Compute the weighted mean for day_of_week and hour
|
||||
prices_of_hour = self.elecprice_8days[hour]
|
||||
if np.isnan(prices_of_hour).all():
|
||||
# No prediction prices available for this hour - use mean value of all prices
|
||||
price_weighted_mean = np.nanmean(self.elecprice_8days)
|
||||
else:
|
||||
weights = self.elecprice_8days_weights_day_of_week[day_of_week]
|
||||
prices_of_hour_masked: NDArray[Shape["24"]] = np.ma.MaskedArray(
|
||||
prices_of_hour, mask=np.isnan(prices_of_hour)
|
||||
)
|
||||
price_weighted_mean = np.ma.average(prices_of_hour_masked, weights=weights)
|
||||
|
||||
return float(price_weighted_mean)
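# Sketch of how a priority row is turned into exponentially decaying weights; the row below
# is the Monday row of the priority table above.
import numpy as np

priority_of_day = np.array([1, 2, 3, 4, 5, 6, 7, 1])
weights = 2 / (2**priority_of_day)
print(weights)  # [1. 0.5 0.25 0.125 0.0625 0.03125 0.015625 1.] - priority 1 weighs most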
|
||||
|
||||
@cache_in_file(with_ttl="1 hour")
|
||||
def _request_forecast(self) -> AkkudoktorElecPrice:
|
||||
"""Fetch electricity price forecast data from Akkudoktor API.
|
||||
@@ -98,13 +158,13 @@ class ElecPriceAkkudoktor(ElecPriceProvider):
|
||||
ValueError: If the API response does not include expected `electricity price` data.
|
||||
"""
|
||||
source = "https://api.akkudoktor.net"
|
||||
date = to_datetime(self.start_datetime, as_string="Y-M-D")
|
||||
# Try to take data from 7 days back for prediction - usually only some hours back are available
|
||||
date = to_datetime(self.start_datetime - to_duration("7 days"), as_string="Y-M-D")
|
||||
last_date = to_datetime(self.end_datetime, as_string="Y-M-D")
|
||||
response = requests.get(
|
||||
f"{source}/prices?date={date}&last_date={last_date}&tz={self.config.timezone}"
|
||||
)
|
||||
url = f"{source}/prices?date={date}&last_date={last_date}&tz={self.config.timezone}"
|
||||
response = requests.get(url)
|
||||
logger.debug(f"Response from {url}: {response}")
|
||||
response.raise_for_status() # Raise an error for bad responses
|
||||
logger.debug(f"Response from {source}: {response}")
|
||||
akkudoktor_data = self._validate_data(response.content)
|
||||
# We are working on fresh data (no cache), report update time
|
||||
self.update_datetime = to_datetime(in_timezone=self.config.timezone)
|
||||
@@ -131,38 +191,66 @@ class ElecPriceAkkudoktor(ElecPriceProvider):
|
||||
f"but only {values_len} data sets are given in forecast data."
|
||||
)
|
||||
|
||||
previous_price = akkudoktor_data.values[0].marketpriceEurocentPerKWh
|
||||
# Get cached 8day values
|
||||
elecprice_cache_file = CacheFileStore().get(key="ElecPriceAkkudoktor8dayCache")
|
||||
if elecprice_cache_file is None:
|
||||
# Cache does not exist - create it
|
||||
elecprice_cache_file = CacheFileStore().create(
|
||||
key="ElecPriceAkkudoktor8dayCache",
|
||||
until_datetime=to_datetime("infinity"),
|
||||
suffix=".npy",
|
||||
)
|
||||
np.save(elecprice_cache_file, self.elecprice_8days)
|
||||
elecprice_cache_file.seek(0)
|
||||
self.elecprice_8days = np.load(elecprice_cache_file)
|
||||
|
||||
for i in range(values_len):
|
||||
original_datetime = akkudoktor_data.values[i].start
|
||||
dt = to_datetime(original_datetime, in_timezone=self.config.timezone)
|
||||
akkudoktor_value = akkudoktor_data.values[i]
|
||||
|
||||
if compare_datetimes(dt, self.start_datetime).le:
|
||||
if compare_datetimes(dt, self.start_datetime).lt:
|
||||
# forecast data is too old
|
||||
previous_price = akkudoktor_data.values[i].marketpriceEurocentPerKWh
|
||||
self.elecprice_8days[dt.hour, dt.day_of_week] = (
|
||||
akkudoktor_value.marketpriceEurocentPerKWh
|
||||
)
|
||||
continue
|
||||
self.elecprice_8days[dt.hour, 7] = akkudoktor_value.marketpriceEurocentPerKWh
|
||||
|
||||
record = ElecPriceDataRecord(
|
||||
date_time=dt,
|
||||
elecprice_marketprice=akkudoktor_data.values[i].marketpriceEurocentPerKWh,
|
||||
elecprice_marketprice=akkudoktor_value.marketpriceEurocentPerKWh,
|
||||
)
|
||||
self.append(record)
|
||||
|
||||
# Update 8day cache
|
||||
elecprice_cache_file.seek(0)
|
||||
np.save(elecprice_cache_file, self.elecprice_8days)
|
||||
|
||||
# Check for new/ valid forecast data
|
||||
if len(self) == 0:
|
||||
# Got no valid forecast data
|
||||
return
|
||||
|
||||
# Assure price starts at start_time
|
||||
if compare_datetimes(self[0].date_time, self.start_datetime).gt:
|
||||
while compare_datetimes(self[0].date_time, self.start_datetime).gt:
|
||||
# Repeat the mean on the 8 day array to cover the missing hours
|
||||
dt = self[0].date_time.subtract(hours=1) # type: ignore
|
||||
value = self._calculate_weighted_mean(dt.day_of_week, dt.hour)
|
||||
|
||||
record = ElecPriceDataRecord(
|
||||
date_time=self.start_datetime,
|
||||
elecprice_marketprice=previous_price,
|
||||
date_time=dt,
|
||||
elecprice_marketprice=value,
|
||||
)
|
||||
self.insert(0, record)
|
||||
# Assure price ends at end_time
|
||||
if compare_datetimes(self[-1].date_time, self.end_datetime).lt:
|
||||
while compare_datetimes(self[-1].date_time, self.end_datetime).lt:
|
||||
# Repeat the mean on the 8 day array to cover the missing hours
|
||||
dt = self[-1].date_time.add(hours=1) # type: ignore
|
||||
value = self._calculate_weighted_mean(dt.day_of_week, dt.hour)
|
||||
|
||||
record = ElecPriceDataRecord(
|
||||
date_time=self.end_datetime,
|
||||
elecprice_marketprice=self[-1].elecprice_marketprice,
|
||||
date_time=dt,
|
||||
elecprice_marketprice=value,
|
||||
)
|
||||
self.append(record)
|
||||
# If some of the hourly values are missing, they will be interpolated when using
|
||||
# `key_to_array`.
|
||||
|
@@ -1,37 +0,0 @@
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
class LoadAggregator:
|
||||
def __init__(self, prediction_hours: int = 24) -> None:
|
||||
"""Initializes the LoadAggregator object with the number of prediction hours.
|
||||
|
||||
:param prediction_hours: Number of hours to predict (default: 24)
|
||||
"""
|
||||
self.loads: defaultdict[str, list[float]] = defaultdict(
|
||||
list
|
||||
) # Dictionary to hold load arrays for different sources
|
||||
self.prediction_hours: int = prediction_hours
|
||||
|
||||
def add_load(self, name: str, last_array: Sequence[float]) -> None:
|
||||
"""Adds a load array for a specific source. Accepts a Sequence of floats.
|
||||
|
||||
:param name: Name of the load source (e.g., "Household", "Heat Pump").
|
||||
:param last_array: Sequence of loads, where each entry corresponds to an hour.
|
||||
:raises ValueError: If the length of last_array doesn't match the prediction hours.
|
||||
"""
|
||||
# Check length of the array without converting
|
||||
if len(last_array) != self.prediction_hours:
|
||||
raise ValueError(f"Total load inconsistent lengths in arrays: {name} {len(last_array)}")
|
||||
self.loads[name] = list(last_array)
|
||||
|
||||
def calculate_total_load(self) -> list[float]:
|
||||
"""Calculates the total load for each hour by summing up the loads from all sources.
|
||||
|
||||
:return: A list representing the total load for each hour.
|
||||
Returns an empty list if no loads have been added.
|
||||
"""
|
||||
# Optimize the summation using a single loop with zip
|
||||
total_load = [sum(hourly_loads) for hourly_loads in zip(*self.loads.values())]
|
||||
|
||||
return total_load
|
@@ -1,202 +0,0 @@
|
||||
from datetime import datetime
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.metrics import mean_squared_error, r2_score
|
||||
|
||||
from akkudoktoreos.prediction.load_forecast import LoadForecast
|
||||
|
||||
|
||||
class LoadPredictionAdjuster:
|
||||
def __init__(
|
||||
self, measured_data: pd.DataFrame, predicted_data: pd.DataFrame, load_forecast: LoadForecast
|
||||
):
|
||||
self.measured_data = measured_data
|
||||
self.predicted_data = predicted_data
|
||||
self.load_forecast = load_forecast
|
||||
self.merged_data = self._merge_data()
|
||||
|
||||
def _remove_outliers(self, data: pd.DataFrame, threshold: int = 2) -> pd.DataFrame:
|
||||
# Calculate the Z-Score of the 'Last' data
|
||||
data["Z-Score"] = np.abs((data["Last"] - data["Last"].mean()) / data["Last"].std())
|
||||
# Filter the data based on the threshold
|
||||
filtered_data = data[data["Z-Score"] < threshold]
|
||||
return filtered_data.drop(columns=["Z-Score"])
|
||||
|
||||
def _merge_data(self) -> pd.DataFrame:
|
||||
# Convert the time column in both DataFrames to datetime
|
||||
self.predicted_data["time"] = pd.to_datetime(self.predicted_data["time"])
|
||||
self.measured_data["time"] = pd.to_datetime(self.measured_data["time"])
|
||||
|
||||
# Ensure both time columns have the same timezone
|
||||
if self.measured_data["time"].dt.tz is None:
|
||||
self.measured_data["time"] = self.measured_data["time"].dt.tz_localize("UTC")
|
||||
|
||||
self.predicted_data["time"] = (
|
||||
self.predicted_data["time"].dt.tz_localize("UTC").dt.tz_convert("Europe/Berlin")
|
||||
)
|
||||
self.measured_data["time"] = self.measured_data["time"].dt.tz_convert("Europe/Berlin")
|
||||
|
||||
# Optionally: Remove timezone information if only working locally
|
||||
self.predicted_data["time"] = self.predicted_data["time"].dt.tz_localize(None)
|
||||
self.measured_data["time"] = self.measured_data["time"].dt.tz_localize(None)
|
||||
|
||||
# Now you can perform the merge
|
||||
merged_data = pd.merge(self.measured_data, self.predicted_data, on="time", how="inner")
|
||||
print(merged_data)
|
||||
merged_data["Hour"] = merged_data["time"].dt.hour
|
||||
merged_data["DayOfWeek"] = merged_data["time"].dt.dayofweek
|
||||
return merged_data
|
||||
|
||||
def calculate_weighted_mean(
|
||||
self, train_period_weeks: int = 9, test_period_weeks: int = 1
|
||||
) -> None:
|
||||
self.merged_data = self._remove_outliers(self.merged_data)
|
||||
train_end_date = self.merged_data["time"].max() - pd.Timedelta(weeks=test_period_weeks)
|
||||
train_start_date = train_end_date - pd.Timedelta(weeks=train_period_weeks)
|
||||
|
||||
test_start_date = train_end_date + pd.Timedelta(hours=1)
|
||||
test_end_date = (
|
||||
test_start_date + pd.Timedelta(weeks=test_period_weeks) - pd.Timedelta(hours=1)
|
||||
)
|
||||
|
||||
self.train_data = self.merged_data[
|
||||
(self.merged_data["time"] >= train_start_date)
|
||||
& (self.merged_data["time"] <= train_end_date)
|
||||
]
|
||||
|
||||
self.test_data = self.merged_data[
|
||||
(self.merged_data["time"] >= test_start_date)
|
||||
& (self.merged_data["time"] <= test_end_date)
|
||||
]
|
||||
|
||||
self.train_data["Difference"] = self.train_data["Last"] - self.train_data["Last Pred"]
|
||||
|
||||
weekdays_train_data = self.train_data[self.train_data["DayOfWeek"] < 5]
|
||||
weekends_train_data = self.train_data[self.train_data["DayOfWeek"] >= 5]
|
||||
|
||||
self.weekday_diff = (
|
||||
weekdays_train_data.groupby("Hour").apply(self._weighted_mean_diff).dropna()
|
||||
)
|
||||
self.weekend_diff = (
|
||||
weekends_train_data.groupby("Hour").apply(self._weighted_mean_diff).dropna()
|
||||
)
|
||||
|
||||
def _weighted_mean_diff(self, data: pd.DataFrame) -> float:
|
||||
train_end_date = self.train_data["time"].max()
|
||||
weights = 1 / (train_end_date - data["time"]).dt.days.replace(0, np.nan)
|
||||
weighted_mean = (data["Difference"] * weights).sum() / weights.sum()
|
||||
return weighted_mean
|
||||
|
||||
def adjust_predictions(self) -> None:
|
||||
self.train_data["Adjusted Pred"] = self.train_data.apply(self._adjust_row, axis=1)
|
||||
self.test_data["Adjusted Pred"] = self.test_data.apply(self._adjust_row, axis=1)
|
||||
|
||||
def _adjust_row(self, row: pd.Series) -> pd.Series:
|
||||
if row["DayOfWeek"] < 5:
|
||||
return row["Last Pred"] + self.weekday_diff.get(row["Hour"], 0)
|
||||
else:
|
||||
return row["Last Pred"] + self.weekend_diff.get(row["Hour"], 0)
|
||||
|
||||
def plot_results(self) -> None:
|
||||
self._plot_data(self.train_data, "Training")
|
||||
self._plot_data(self.test_data, "Testing")
|
||||
|
||||
def _plot_data(self, data: pd.DataFrame, data_type: str) -> None:
|
||||
plt.figure(figsize=(14, 7))
|
||||
plt.plot(data["time"], data["Last"], label=f"Actual Last - {data_type}", color="blue")
|
||||
plt.plot(
|
||||
data["time"],
|
||||
data["Last Pred"],
|
||||
label=f"Predicted Last - {data_type}",
|
||||
color="red",
|
||||
linestyle="--",
|
||||
)
|
||||
plt.plot(
|
||||
data["time"],
|
||||
data["Adjusted Pred"],
|
||||
label=f"Adjusted Predicted Last - {data_type}",
|
||||
color="green",
|
||||
linestyle=":",
|
||||
)
|
||||
plt.xlabel("Time")
|
||||
plt.ylabel("Load")
|
||||
plt.title(f"Actual vs Predicted vs Adjusted Predicted Load ({data_type} Data)")
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
plt.show()
|
||||
|
||||
def evaluate_model(self) -> None:
|
||||
mse = mean_squared_error(self.test_data["Last"], self.test_data["Adjusted Pred"])
|
||||
r2 = r2_score(self.test_data["Last"], self.test_data["Adjusted Pred"])
|
||||
print(f"Mean Squared Error: {mse}")
|
||||
print(f"R-squared: {r2}")
|
||||
|
||||
def predict_next_hours(self, hours_ahead: int) -> pd.DataFrame:
|
||||
last_date = self.merged_data["time"].max()
|
||||
future_dates = [last_date + pd.Timedelta(hours=i) for i in range(1, hours_ahead + 1)]
|
||||
future_df = pd.DataFrame({"time": future_dates})
|
||||
future_df["Hour"] = future_df["time"].dt.hour
|
||||
future_df["DayOfWeek"] = future_df["time"].dt.dayofweek
|
||||
future_df["Last Pred"] = future_df["time"].apply(self._forecast_next_hours)
|
||||
future_df["Adjusted Pred"] = future_df.apply(self._adjust_row, axis=1)
|
||||
return future_df
|
||||
|
||||
def _forecast_next_hours(self, timestamp: datetime) -> float:
|
||||
date_str = timestamp.strftime("%Y-%m-%d")
|
||||
hour = timestamp.hour
|
||||
daily_forecast = self.load_forecast.get_daily_stats(date_str)
|
||||
return daily_forecast[0][hour] if hour < len(daily_forecast[0]) else np.nan
|
||||
|
||||
|
||||
# if __name__ == '__main__':
|
||||
# estimator = LastEstimator()
|
||||
# start_date = "2024-06-01"
|
||||
# end_date = "2024-08-01"
|
||||
# last_df = estimator.get_last(start_date, end_date)
|
||||
|
||||
# selected_columns = last_df[['timestamp', 'Last']]
|
||||
# selected_columns['time'] = pd.to_datetime(selected_columns['timestamp']).dt.floor('H')
|
||||
# selected_columns['Last'] = pd.to_numeric(selected_columns['Last'], errors='coerce')
|
||||
|
||||
# # Drop rows with NaN values
|
||||
# cleaned_data = selected_columns.dropna()
|
||||
|
||||
# print(cleaned_data)
|
||||
# # Create an instance of LoadForecast
|
||||
# lf = LoadForecast(filepath=r'.\load_profiles.npz', year_energy=6000*1000)
|
||||
|
||||
# # Initialize an empty DataFrame to hold the forecast data
|
||||
# forecast_list = []
|
||||
|
||||
# # Loop through each day in the date range
|
||||
# for single_date in pd.date_range(cleaned_data['time'].min().date(), cleaned_data['time'].max().date()):
|
||||
# date_str = single_date.strftime('%Y-%m-%d')
|
||||
# daily_forecast = lf.get_daily_stats(date_str)
|
||||
# mean_values = daily_forecast[0] # Extract the mean values
|
||||
# hours = [single_date + pd.Timedelta(hours=i) for i in range(24)]
|
||||
# daily_forecast_df = pd.DataFrame({'time': hours, 'Last Pred': mean_values})
|
||||
# forecast_list.append(daily_forecast_df)
|
||||
|
||||
# # Concatenate all daily forecasts into a single DataFrame
|
||||
# forecast_df = pd.concat(forecast_list, ignore_index=True)
|
||||
|
||||
# # Create an instance of the LoadPredictionAdjuster class
|
||||
# adjuster = LoadPredictionAdjuster(cleaned_data, forecast_df, lf)
|
||||
|
||||
# # Calculate the weighted mean differences
|
||||
# adjuster.calculate_weighted_mean()
|
||||
|
||||
# # Adjust the predictions
|
||||
# adjuster.adjust_predictions()
|
||||
|
||||
# # Plot the results
|
||||
# adjuster.plot_results()
|
||||
|
||||
# # Evaluate the model
|
||||
# adjuster.evaluate_model()
|
||||
|
||||
# # Predict the next x hours
|
||||
# future_predictions = adjuster.predict_next_hours(48)
|
||||
# print(future_predictions)
|
@@ -1,99 +0,0 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Load the .npz file when the application starts
|
||||
|
||||
|
||||
class LoadForecast:
|
||||
def __init__(self, filepath: str | Path, year_energy: float):
|
||||
self.filepath = filepath
|
||||
self.year_energy = year_energy
|
||||
self.load_data()
|
||||
|
||||
def get_daily_stats(self, date_str: str) -> np.ndarray:
|
||||
"""Returns the 24-hour profile with mean and standard deviation for a given date.
|
||||
|
||||
:param date_str: Date as a string in the format "YYYY-MM-DD"
|
||||
:return: An array with shape (2, 24), contains means and standard deviations
|
||||
"""
|
||||
# Convert the date string into a datetime object
|
||||
date = self._convert_to_datetime(date_str)
|
||||
|
||||
# Calculate the day of the year (1 to 365)
|
||||
day_of_year = date.timetuple().tm_yday
|
||||
|
||||
# Extract the 24-hour profile for the given date
|
||||
daily_stats = self.data_year_energy[day_of_year - 1] # -1 because indexing starts at 0
|
||||
return daily_stats
|
||||
|
||||
def get_hourly_stats(self, date_str: str, hour: int) -> np.ndarray:
|
||||
"""Returns the mean and standard deviation for a specific hour of a given date.
|
||||
|
||||
:param date_str: Date as a string in the format "YYYY-MM-DD"
|
||||
:param hour: Specific hour (0 to 23)
|
||||
:return: An array with shape (2,), contains mean and standard deviation for the specified hour
|
||||
"""
|
||||
# Convert the date string into a datetime object
|
||||
date = self._convert_to_datetime(date_str)
|
||||
|
||||
# Calculate the day of the year (1 to 365)
|
||||
day_of_year = date.timetuple().tm_yday
|
||||
|
||||
# Extract mean and standard deviation for the given hour
|
||||
hourly_stats = self.data_year_energy[day_of_year - 1, :, hour] # Access the specific hour
|
||||
|
||||
return hourly_stats
|
||||
|
||||
def get_stats_for_date_range(self, start_date_str: str, end_date_str: str) -> np.ndarray:
|
||||
"""Returns the means and standard deviations for a date range.
|
||||
|
||||
:param start_date_str: Start date as a string in the format "YYYY-MM-DD"
|
||||
:param end_date_str: End date as a string in the format "YYYY-MM-DD"
|
||||
:return: An array with aggregated data for the date range
|
||||
"""
|
||||
start_date = self._convert_to_datetime(start_date_str)
|
||||
end_date = self._convert_to_datetime(end_date_str)
|
||||
|
||||
start_day_of_year = start_date.timetuple().tm_yday
|
||||
end_day_of_year = end_date.timetuple().tm_yday
|
||||
|
||||
# Note that in leap years, the day of the year may need adjustment
|
||||
stats_for_range = self.data_year_energy[
|
||||
start_day_of_year:end_day_of_year
|
||||
] # -1 because indexing starts at 0
|
||||
stats_for_range = stats_for_range.swapaxes(1, 0)
|
||||
|
||||
stats_for_range = stats_for_range.reshape(stats_for_range.shape[0], -1)
|
||||
return stats_for_range
|
||||
|
||||
def load_data(self) -> None:
|
||||
"""Loads data from the specified file."""
|
||||
try:
|
||||
data = np.load(self.filepath)
|
||||
self.data = np.array(list(zip(data["yearly_profiles"], data["yearly_profiles_std"])))
|
||||
self.data_year_energy = self.data * self.year_energy
|
||||
# pprint(self.data_year_energy)
|
||||
except FileNotFoundError:
|
||||
print(f"Error: File {self.filepath} not found.")
|
||||
except Exception as e:
|
||||
print(f"An error occurred while loading data: {e}")
|
||||
|
||||
def get_price_data(self) -> None:
|
||||
"""Returns price data (currently not implemented)."""
|
||||
raise NotImplementedError
|
||||
# return self.price_data
|
||||
|
||||
def _convert_to_datetime(self, date_str: str) -> datetime:
|
||||
"""Converts a date string to a datetime object."""
|
||||
return datetime.strptime(date_str, "%Y-%m-%d")
|
||||
|
||||
|
||||
# Example usage of the class
|
||||
if __name__ == "__main__":
|
||||
filepath = r"..\data\load_profiles.npz" # Adjust the path to the .npz file
|
||||
lf = LoadForecast(filepath=filepath, year_energy=2000)
|
||||
specific_date_prices = lf.get_daily_stats("2024-02-16") # Adjust date as needed
|
||||
specific_hour_stats = lf.get_hourly_stats("2024-02-16", 12) # Adjust date and hour as needed
|
||||
print(specific_hour_stats)
|
@@ -18,8 +18,14 @@ logger = get_logger(__name__)
|
||||
class LoadDataRecord(PredictionRecord):
|
||||
"""Represents a load data record containing various load attributes at a specific datetime."""
|
||||
|
||||
load_mean: Optional[float] = Field(default=None, description="Load mean value (W)")
|
||||
load_std: Optional[float] = Field(default=None, description="Load standard deviation (W)")
|
||||
load_mean: Optional[float] = Field(default=None, description="Predicted load mean value (W)")
|
||||
load_std: Optional[float] = Field(
|
||||
default=None, description="Predicted load standard deviation (W)"
|
||||
)
|
||||
|
||||
load_mean_adjusted: Optional[float] = Field(
|
||||
default=None, description="Predicted load mean value adjusted by load measurement (W)"
|
||||
)
|
||||
|
||||
|
||||
class LoadProvider(PredictionProvider):
|
||||
|
@@ -8,7 +8,7 @@ from pydantic import Field
|
||||
|
||||
from akkudoktoreos.config.configabc import SettingsBaseModel
|
||||
from akkudoktoreos.prediction.loadabc import LoadProvider
|
||||
from akkudoktoreos.utils.datetimeutil import to_datetime, to_duration
|
||||
from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime, to_duration
|
||||
from akkudoktoreos.utils.logutil import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -30,6 +30,58 @@ class LoadAkkudoktor(LoadProvider):
|
||||
"""Return the unique identifier for the LoadAkkudoktor provider."""
|
||||
return "LoadAkkudoktor"
|
||||
|
||||
def _calculate_adjustment(self, data_year_energy: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""Calculate weekday and week end adjustment from total load measurement data.
|
||||
|
||||
Returns:
|
||||
weekday_adjust (np.ndarray): hourly adjustment for Monday to Friday.
|
||||
weekend_adjust (np.ndarray): hourly adjustment for Saturday and Sunday.
|
||||
"""
|
||||
weekday_adjust = np.zeros(24)
|
||||
weekday_adjust_weight = np.zeros(24)
|
||||
weekend_adjust = np.zeros(24)
|
||||
weekend_adjust_weight = np.zeros(24)
|
||||
|
||||
if self.measurement.max_datetime is None:
|
||||
# No measurements - return 0 adjustment
|
||||
return (weekday_adjust, weekend_adjust)
|
||||
|
||||
# compare predictions with real measurement - try to use last 7 days
|
||||
compare_start = self.measurement.max_datetime - to_duration("7 days")
|
||||
if compare_datetimes(compare_start, self.measurement.min_datetime).lt:
|
||||
# Not enough measurements for 7 days - use what is available
|
||||
compare_start = self.measurement.min_datetime
|
||||
compare_end = self.measurement.max_datetime
|
||||
compare_interval = to_duration("1 hour")
|
||||
|
||||
load_total_array = self.measurement.load_total(
|
||||
start_datetime=compare_start,
|
||||
end_datetime=compare_end,
|
||||
interval=compare_interval,
|
||||
)
|
||||
compare_dt = compare_start
|
||||
for i in range(len(load_total_array)):
|
||||
load_total = load_total_array[i]
|
||||
# Extract mean (index 0) and standard deviation (index 1) for the given day and hour
|
||||
# day_of_year starts at 1 while array indexing starts at 0, hence the -1
|
||||
hourly_stats = data_year_energy[compare_dt.day_of_year - 1, :, compare_dt.hour]
|
||||
weight = 1 / ((compare_end - compare_dt).days + 1)
|
||||
if compare_dt.day_of_week < 5:
|
||||
weekday_adjust[compare_dt.hour] += (load_total - hourly_stats[0]) * weight
|
||||
weekday_adjust_weight[compare_dt.hour] += weight
|
||||
else:
|
||||
weekend_adjust[compare_dt.hour] += (load_total - hourly_stats[0]) * weight
|
||||
weekend_adjust_weight[compare_dt.hour] += weight
|
||||
compare_dt += compare_interval
|
||||
# Calculate mean
|
||||
for i in range(24):
|
||||
if weekday_adjust_weight[i] > 0:
|
||||
weekday_adjust[i] = weekday_adjust[i] / weekday_adjust_weight[i]
|
||||
if weekend_adjust_weight[i] > 0:
|
||||
weekend_adjust[i] = weekend_adjust[i] / weekend_adjust_weight[i]
|
||||
|
||||
return (weekday_adjust, weekend_adjust)
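# Illustrative recomputation of a single adjustment bucket: the difference between measured
# and predicted load is averaged with weights 1/(age_in_days + 1), so newer days count more.
# All values below are made up.
import numpy as np

measured = np.array([520.0, 480.0, 510.0])   # measured total load at 08:00 on three weekdays
predicted = np.array([450.0, 450.0, 450.0])  # profile mean for the same hours
age_days = np.array([2, 1, 0])               # age of each sample relative to the newest one

weights = 1.0 / (age_days + 1)
adjustment = np.sum((measured - predicted) * weights) / np.sum(weights)
print(round(adjustment, 1))  # 53.6 - offset added to load_mean for weekday hour 08:00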
|
||||
|
||||
def load_data(self) -> np.ndarray:
|
||||
"""Loads data from the Akkudoktor load file."""
|
||||
load_file = Path(__file__).parent.parent.joinpath("data/load_profiles.npz")
|
||||
@@ -54,13 +106,24 @@ class LoadAkkudoktor(LoadProvider):
|
||||
def _update_data(self, force_update: Optional[bool] = False) -> None:
|
||||
"""Adds the load means and standard deviations."""
|
||||
data_year_energy = self.load_data()
|
||||
weekday_adjust, weekend_adjust = self._calculate_adjustment(data_year_energy)
|
||||
date = self.start_datetime
|
||||
for i in range(self.config.prediction_hours):
|
||||
# Extract mean and standard deviation for the given day and hour
|
||||
# Extract mean (index 0) and standard deviation (index 1) for the given day and hour
|
||||
# day_of_year starts at 1 while array indexing starts at 0, hence the -1
|
||||
hourly_stats = data_year_energy[date.day_of_year - 1, :, date.hour]
|
||||
self.update_value(date, "load_mean", hourly_stats[0])
|
||||
self.update_value(date, "load_std", hourly_stats[1])
|
||||
if date.day_of_week < 5:
|
||||
# Monday to Friday (0..4)
|
||||
self.update_value(
|
||||
date, "load_mean_adjusted", hourly_stats[0] + weekday_adjust[date.hour]
|
||||
)
|
||||
else:
|
||||
# Saturday, Sunday (5, 6)
|
||||
self.update_value(
|
||||
date, "load_mean_adjusted", hourly_stats[0] + weekend_adjust[date.hour]
|
||||
)
|
||||
date += to_duration("1 hour")
|
||||
# We are working on fresh data (no cache), report update time
|
||||
self.update_datetime = to_datetime(in_timezone=self.config.timezone)
|
||||
|
@@ -13,6 +13,7 @@ from typing import List, Optional
|
||||
from pendulum import DateTime
|
||||
from pydantic import Field, computed_field
|
||||
|
||||
from akkudoktoreos.core.coreabc import MeasurementMixin
|
||||
from akkudoktoreos.core.dataabc import (
|
||||
DataBase,
|
||||
DataContainer,
|
||||
@ -27,10 +28,11 @@ from akkudoktoreos.utils.logutil import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PredictionBase(DataBase):
|
||||
class PredictionBase(DataBase, MeasurementMixin):
|
||||
"""Base class for handling prediction data.
|
||||
|
||||
Enables access to EOS configuration data (attribute `config`).
|
||||
Enables access to EOS configuration data (attribute `config`) and EOS measurement data
|
||||
(attribute `measurement`).
|
||||
"""
|
||||
|
||||
pass
|
||||
|
@@ -70,7 +70,7 @@ class PVForecastCommonSettings(SettingsBaseModel):
|
||||
pvforecast0_inverter_paco: Optional[int] = Field(
|
||||
default=None, description="AC power rating of the inverter. [W]"
|
||||
)
|
||||
pvforecast0_modules_per_string: Optional[str] = Field(
|
||||
pvforecast0_modules_per_string: Optional[int] = Field(
|
||||
default=None, description="Number of the PV modules of the strings of this plane."
|
||||
)
|
||||
pvforecast0_strings_per_inverter: Optional[str] = Field(
|
||||
@@ -124,7 +124,7 @@ class PVForecastCommonSettings(SettingsBaseModel):
|
||||
pvforecast1_inverter_paco: Optional[int] = Field(
|
||||
default=None, description="AC power rating of the inverter. [W]"
|
||||
)
|
||||
pvforecast1_modules_per_string: Optional[str] = Field(
|
||||
pvforecast1_modules_per_string: Optional[int] = Field(
|
||||
default=None, description="Number of the PV modules of the strings of this plane."
|
||||
)
|
||||
pvforecast1_strings_per_inverter: Optional[str] = Field(
|
||||
@ -178,7 +178,7 @@ class PVForecastCommonSettings(SettingsBaseModel):
|
||||
pvforecast2_inverter_paco: Optional[int] = Field(
|
||||
default=None, description="AC power rating of the inverter. [W]"
|
||||
)
|
||||
pvforecast2_modules_per_string: Optional[str] = Field(
|
||||
pvforecast2_modules_per_string: Optional[int] = Field(
|
||||
default=None, description="Number of the PV modules of the strings of this plane."
|
||||
)
|
||||
pvforecast2_strings_per_inverter: Optional[str] = Field(
|
||||
@ -232,7 +232,7 @@ class PVForecastCommonSettings(SettingsBaseModel):
|
||||
pvforecast3_inverter_paco: Optional[int] = Field(
|
||||
default=None, description="AC power rating of the inverter. [W]"
|
||||
)
|
||||
pvforecast3_modules_per_string: Optional[str] = Field(
|
||||
pvforecast3_modules_per_string: Optional[int] = Field(
|
||||
default=None, description="Number of the PV modules of the strings of this plane."
|
||||
)
|
||||
pvforecast3_strings_per_inverter: Optional[str] = Field(
|
||||
@ -286,7 +286,7 @@ class PVForecastCommonSettings(SettingsBaseModel):
|
||||
pvforecast4_inverter_paco: Optional[int] = Field(
|
||||
default=None, description="AC power rating of the inverter. [W]"
|
||||
)
|
||||
pvforecast4_modules_per_string: Optional[str] = Field(
|
||||
pvforecast4_modules_per_string: Optional[int] = Field(
|
||||
default=None, description="Number of the PV modules of the strings of this plane."
|
||||
)
|
||||
pvforecast4_strings_per_inverter: Optional[str] = Field(
|
||||
@ -340,7 +340,7 @@ class PVForecastCommonSettings(SettingsBaseModel):
|
||||
pvforecast5_inverter_paco: Optional[int] = Field(
|
||||
default=None, description="AC power rating of the inverter. [W]"
|
||||
)
|
||||
pvforecast5_modules_per_string: Optional[str] = Field(
|
||||
pvforecast5_modules_per_string: Optional[int] = Field(
|
||||
default=None, description="Number of the PV modules of the strings of this plane."
|
||||
)
|
||||
pvforecast5_strings_per_inverter: Optional[str] = Field(
|
||||
|
@@ -7,38 +7,55 @@ from pathlib import Path
|
||||
from typing import Annotated, Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
import pandas as pd
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, Query, Request
|
||||
from fastapi.exceptions import HTTPException
|
||||
from fastapi.responses import FileResponse, RedirectResponse, Response
|
||||
from pendulum import DateTime
|
||||
|
||||
from akkudoktoreos.config.config import ConfigEOS, SettingsEOS, get_config
|
||||
from akkudoktoreos.core.pydantic import PydanticBaseModel
|
||||
from akkudoktoreos.core.ems import get_ems
|
||||
from akkudoktoreos.core.pydantic import (
|
||||
PydanticBaseModel,
|
||||
PydanticDateTimeData,
|
||||
PydanticDateTimeDataFrame,
|
||||
PydanticDateTimeSeries,
|
||||
)
|
||||
from akkudoktoreos.measurement.measurement import get_measurement
|
||||
from akkudoktoreos.optimization.genetic import (
|
||||
OptimizationParameters,
|
||||
OptimizeResponse,
|
||||
optimization_problem,
|
||||
)
|
||||
|
||||
# Still to be adapted
|
||||
from akkudoktoreos.prediction.load_aggregator import LoadAggregator
|
||||
from akkudoktoreos.prediction.load_corrector import LoadPredictionAdjuster
|
||||
from akkudoktoreos.prediction.load_forecast import LoadForecast
|
||||
from akkudoktoreos.prediction.prediction import get_prediction
|
||||
from akkudoktoreos.utils.datetimeutil import to_datetime, to_duration
|
||||
from akkudoktoreos.utils.logutil import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
config_eos = get_config()
|
||||
measurement_eos = get_measurement()
|
||||
prediction_eos = get_prediction()
|
||||
ems_eos = get_ems()
|
||||
|
||||
|
||||
def start_fasthtml_server() -> subprocess.Popen:
|
||||
"""Start the fasthtml server as a subprocess."""
|
||||
server_process = subprocess.Popen(
|
||||
[sys.executable, str(server_dir.joinpath("fasthtml_server.py"))],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
return server_process
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
"""Lifespan manager for the app."""
|
||||
# On startup
|
||||
if config_eos.server_fasthtml_host and config_eos.server_fasthtml_port:
|
||||
if (
|
||||
config_eos.server_fastapi_startup_server_fasthtml
|
||||
and config_eos.server_fasthtml_host
|
||||
and config_eos.server_fasthtml_port
|
||||
):
|
||||
try:
|
||||
fasthtml_process = start_fasthtml_server()
|
||||
except Exception as e:
|
||||
@ -72,41 +89,238 @@ class PdfResponse(FileResponse):
|
||||
media_type = "application/pdf"
|
||||
|
||||
|
||||
@app.get("/config")
|
||||
@app.get("/v1/config")
|
||||
def fastapi_config_get() -> ConfigEOS:
|
||||
"""Get the current configuration."""
|
||||
return config_eos
|
||||
|
||||
|
||||
@app.put("/config")
|
||||
def fastapi_config_put(settings: SettingsEOS) -> ConfigEOS:
|
||||
"""Merge settings into current configuration."""
|
||||
@app.put("/v1/config")
|
||||
def fastapi_config_put(
|
||||
settings: SettingsEOS,
|
||||
save: Optional[bool] = None,
|
||||
) -> ConfigEOS:
|
||||
"""Merge settings into current configuration.
|
||||
|
||||
Args:
|
||||
settings (SettingsEOS): The settings to merge into the current configuration.
|
||||
save (Optional[bool]): Save the resulting configuration to the configuration file.
|
||||
Defaults to False.
|
||||
"""
|
||||
config_eos.merge_settings(settings)
|
||||
if save:
|
||||
try:
|
||||
config_eos.to_config_file()
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Cannot save configuration to file '{config_eos.config_file_path}'.",
|
||||
)
|
||||
return config_eos
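For reference, a minimal client sketch for this endpoint; it assumes the FastAPI server is reachable on the default port 8503 and uses an illustrative settings fragment:

import httpx

# Merge a settings fragment into the running configuration and persist it.
response = httpx.put(
    "http://localhost:8503/v1/config",
    params={"save": True},          # save the merged configuration to the config file
    json={"prediction_hours": 48},  # request body is parsed into SettingsEOS
)
response.raise_for_status()
print(response.json())              # the merged ConfigEOS as JSON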
|
||||
|
||||
|
||||
@app.get("/prediction/keys")
|
||||
def fastapi_prediction_keys() -> list[str]:
|
||||
@app.get("/v1/measurement/keys")
|
||||
def fastapi_measurement_keys_get() -> list[str]:
|
||||
"""Get a list of available measurement keys."""
|
||||
return sorted(measurement_eos.record_keys)
|
||||
|
||||
|
||||
@app.get("/v1/measurement/load-mr/series/by-name")
|
||||
def fastapi_measurement_load_mr_series_by_name_get(name: str) -> PydanticDateTimeSeries:
|
||||
"""Get the meter reading of given load name as series."""
|
||||
key = measurement_eos.name_to_key(name=name, topic="measurement_load")
|
||||
if key is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Measurement load with name '{name}' not available."
|
||||
)
|
||||
if key not in measurement_eos.record_keys:
|
||||
raise HTTPException(status_code=404, detail=f"Key '{key}' not available.")
|
||||
pdseries = measurement_eos.key_to_series(key=key)
|
||||
return PydanticDateTimeSeries.from_series(pdseries)
|
||||
|
||||
|
||||
@app.put("/v1/measurement/load-mr/value/by-name")
|
||||
def fastapi_measurement_load_mr_value_by_name_put(
|
||||
datetime: Any, name: str, value: Union[float, str]
|
||||
) -> PydanticDateTimeSeries:
|
||||
"""Merge the meter reading of given load name and value into EOS measurements at given datetime."""
|
||||
key = measurement_eos.name_to_key(name=name, topic="measurement_load")
|
||||
if key is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Measurement load with name '{name}' not available."
|
||||
)
|
||||
if key not in measurement_eos.record_keys:
|
||||
raise HTTPException(status_code=404, detail=f"Key '{key}' not available.")
|
||||
measurement_eos.update_value(datetime, key, value)
|
||||
pdseries = measurement_eos.key_to_series(key=key)
|
||||
return PydanticDateTimeSeries.from_series(pdseries)
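A hedged usage sketch for this endpoint; the load name "Household" and the meter reading value are placeholders, and the name must match a load configured in the measurement settings:

import httpx

httpx.put(
    "http://localhost:8503/v1/measurement/load-mr/value/by-name",
    params={
        "datetime": "2024-11-10T12:00:00+01:00",  # any format accepted by to_datetime()
        "name": "Household",                      # placeholder load name
        "value": 1234.5,                          # placeholder meter reading
    },
)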
|
||||
|
||||
|
||||
@app.put("/v1/measurement/load-mr/series/by-name")
|
||||
def fastapi_measurement_load_mr_series_by_name_put(
|
||||
name: str, series: PydanticDateTimeSeries
|
||||
) -> PydanticDateTimeSeries:
|
||||
"""Merge the meter readings series of given load name into EOS measurements at given datetime."""
|
||||
key = measurement_eos.name_to_key(name=name, topic="measurement_load")
|
||||
if key is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Measurement load with name '{name}' not available."
|
||||
)
|
||||
if key not in measurement_eos.record_keys:
|
||||
raise HTTPException(status_code=404, detail=f"Key '{key}' not available.")
|
||||
pdseries = series.to_series() # make pandas series from PydanticDateTimeSeries
|
||||
measurement_eos.key_from_series(key=key, series=pdseries)
|
||||
pdseries = measurement_eos.key_to_series(key=key)
|
||||
return PydanticDateTimeSeries.from_series(pdseries)
|
||||
|
||||
|
||||
@app.get("/v1/measurement/series")
|
||||
def fastapi_measurement_series_get(key: str) -> PydanticDateTimeSeries:
|
||||
"""Get the measurements of given key as series."""
|
||||
if key not in measurement_eos.record_keys:
|
||||
raise HTTPException(status_code=404, detail=f"Key '{key}' not available.")
|
||||
pdseries = measurement_eos.key_to_series(key=key)
|
||||
return PydanticDateTimeSeries.from_series(pdseries)
|
||||
|
||||
|
||||
@app.put("/v1/measurement/value")
|
||||
def fastapi_measurement_value_put(
|
||||
datetime: Any, key: str, value: Union[float, str]
|
||||
) -> PydanticDateTimeSeries:
|
||||
"""Merge the measurement of given key and value into EOS measurements at given datetime."""
|
||||
if key not in measurement_eos.record_keys:
|
||||
raise HTTPException(status_code=404, detail=f"Key '{key}' not available.")
|
||||
measurement_eos.update_value(datetime, key, value)
|
||||
pdseries = measurement_eos.key_to_series(key=key)
|
||||
return PydanticDateTimeSeries.from_series(pdseries)
|
||||
|
||||
|
||||
@app.put("/v1/measurement/series")
|
||||
def fastapi_measurement_series_put(
|
||||
key: str, series: PydanticDateTimeSeries
|
||||
) -> PydanticDateTimeSeries:
|
||||
"""Merge measurement given as series into given key."""
|
||||
if key not in measurement_eos.record_keys:
|
||||
raise HTTPException(status_code=404, detail=f"Key '{key}' not available.")
|
||||
pdseries = series.to_series() # make pandas series from PydanticDateTimeSeries
|
||||
measurement_eos.key_from_series(key=key, series=pdseries)
|
||||
pdseries = measurement_eos.key_to_series(key=key)
|
||||
return PydanticDateTimeSeries.from_series(pdseries)
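A sketch of pushing a whole series through this endpoint; the key and values are placeholders, and model_dump(mode="json") assumes the pydantic v2 API:

import httpx
import pandas as pd

from akkudoktoreos.core.pydantic import PydanticDateTimeSeries

series = pd.Series(
    [0.0, 0.5, 1.1],
    index=pd.DatetimeIndex(
        ["2024-11-10T00:00:00", "2024-11-10T01:00:00", "2024-11-10T02:00:00"]
    ),
)
payload = PydanticDateTimeSeries.from_series(series)  # same model the endpoint returns
httpx.put(
    "http://localhost:8503/v1/measurement/series",
    params={"key": "measurement_load0_mr"},  # placeholder key, must be in record_keys
    json=payload.model_dump(mode="json"),
)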
|
||||
|
||||
|
||||
@app.put("/v1/measurement/dataframe")
|
||||
def fastapi_measurement_dataframe_put(data: PydanticDateTimeDataFrame) -> None:
|
||||
"""Merge the measurement data given as dataframe into EOS measurements."""
|
||||
dataframe = data.to_dataframe()
|
||||
measurement_eos.import_from_dataframe(dataframe)
|
||||
|
||||
|
||||
@app.put("/v1/measurement/data")
|
||||
def fastapi_measurement_data_put(data: PydanticDateTimeData) -> None:
|
||||
"""Merge the measurement data given as datetime data into EOS measurements."""
|
||||
datetimedata = data.to_dict()
|
||||
measurement_eos.import_from_dict(datetimedata)
|
||||
|
||||
|
||||
@app.get("/v1/prediction/keys")
|
||||
def fastapi_prediction_keys_get() -> list[str]:
|
||||
"""Get a list of available prediction keys."""
|
||||
return sorted(list(prediction_eos.keys()))
|
||||
return sorted(prediction_eos.record_keys)
|
||||
|
||||
|
||||
@app.get("/prediction")
|
||||
def fastapi_prediction(key: str) -> list[Union[float | str]]:
|
||||
"""Get the current configuration."""
|
||||
values = prediction_eos[key].to_list()
|
||||
return values
|
||||
@app.get("/v1/prediction/series")
|
||||
def fastapi_prediction_series_get(
|
||||
key: str,
|
||||
start_datetime: Optional[str] = None,
|
||||
end_datetime: Optional[str] = None,
|
||||
) -> PydanticDateTimeSeries:
|
||||
"""Get prediction for given key within given date range as series.
|
||||
|
||||
Args:
|
||||
start_datetime: Starting datetime (inclusive).
|
||||
Defaults to start datetime of latest prediction.
|
||||
end_datetime: Ending datetime (exclusive).
|
||||
Defaults to end datetime of latest prediction.
|
||||
"""
|
||||
if key not in prediction_eos.record_keys:
|
||||
raise HTTPException(status_code=404, detail=f"Key '{key}' not available.")
|
||||
if start_datetime is None:
|
||||
start_datetime = prediction_eos.start_datetime
|
||||
else:
|
||||
start_datetime = to_datetime(start_datetime)
|
||||
if end_datetime is None:
|
||||
end_datetime = prediction_eos.end_datetime
|
||||
else:
|
||||
end_datetime = to_datetime(end_datetime)
|
||||
pdseries = prediction_eos.key_to_series(
|
||||
key=key, start_datetime=start_datetime, end_datetime=end_datetime
|
||||
)
|
||||
return PydanticDateTimeSeries.from_series(pdseries)
|
||||
|
||||
|
||||
@app.get("/v1/prediction/list")
|
||||
def fastapi_prediction_list_get(
|
||||
key: str,
|
||||
start_datetime: Optional[str] = None,
|
||||
end_datetime: Optional[str] = None,
|
||||
interval: Optional[str] = None,
|
||||
) -> List[Any]:
|
||||
"""Get prediction for given key within given date range as value list.
|
||||
|
||||
Args:
|
||||
start_datetime: Starting datetime (inclusive).
|
||||
Defaults to start datetime of latest prediction.
|
||||
end_datetime: Ending datetime (exclusive).
|
||||
Defaults to end datetime of latest prediction.
|
||||
interval: Time duration for each interval.
|
||||
Defaults to 1 hour.
|
||||
"""
|
||||
if key not in prediction_eos.record_keys:
|
||||
raise HTTPException(status_code=404, detail=f"Key '{key}' not available.")
|
||||
if start_datetime is None:
|
||||
start_datetime = prediction_eos.start_datetime
|
||||
else:
|
||||
start_datetime = to_datetime(start_datetime)
|
||||
if end_datetime is None:
|
||||
end_datetime = prediction_eos.end_datetime
|
||||
else:
|
||||
end_datetime = to_datetime(end_datetime)
|
||||
if interval is None:
|
||||
interval = to_duration("1 hour")
|
||||
else:
|
||||
interval = to_duration(interval)
|
||||
prediction_list = prediction_eos.key_to_array(
|
||||
key=key,
|
||||
start_datetime=start_datetime,
|
||||
end_datetime=end_datetime,
|
||||
interval=interval,
|
||||
).tolist()
|
||||
return prediction_list
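For reference, the corresponding client call; key, interval and date range are illustrative:

import httpx

response = httpx.get(
    "http://localhost:8503/v1/prediction/list",
    params={
        "key": "elecprice_marketprice",  # any key listed by /v1/prediction/keys
        "start_datetime": "2024-11-10",
        "end_datetime": "2024-11-12",
        "interval": "1 hour",
    },
)
values = response.json()  # plain list of prediction values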
|
||||
|
||||
|
||||
@app.get("/strompreis")
|
||||
def fastapi_strompreis() -> list[float]:
|
||||
"""Deprecated: Electricity Market Price Prediction.
|
||||
|
||||
Note:
|
||||
Use '/v1/prediction/list?key=elecprice_marketprice' instead.
|
||||
"""
|
||||
settings = SettingsEOS(
|
||||
elecprice_provider="ElecPriceAkkudoktor",
|
||||
)
|
||||
config_eos.merge_settings(settings=settings)
|
||||
ems_eos.set_start_datetime() # Set energy management start datetime to current hour.
|
||||
|
||||
# Create electricity price forecast
|
||||
prediction_eos.update_data(force_update=True)
|
||||
|
||||
# Get the current date and the end date based on prediction hours
|
||||
marketprice_series = prediction_eos["elecprice_marketprice"]
|
||||
# Fetch prices for the specified date range
|
||||
specific_date_prices = marketprice_series.loc[
|
||||
prediction_eos.start_datetime : prediction_eos.end_datetime
|
||||
]
|
||||
return specific_date_prices.tolist()
|
||||
return prediction_eos.key_to_array(
|
||||
key="elecprice_marketprice",
|
||||
start_datetime=prediction_eos.start_datetime,
|
||||
end_datetime=prediction_eos.end_datetime,
|
||||
).tolist()
|
||||
|
||||
|
||||
class GesamtlastRequest(PydanticBaseModel):
|
||||
@ -117,83 +331,79 @@ class GesamtlastRequest(PydanticBaseModel):
|
||||
|
||||
@app.post("/gesamtlast")
|
||||
def fastapi_gesamtlast(request: GesamtlastRequest) -> list[float]:
|
||||
"""Endpoint to handle total load calculation based on the latest measured data."""
|
||||
# Request-Daten extrahieren
|
||||
year_energy = request.year_energy
|
||||
measured_data = request.measured_data
|
||||
hours = request.hours
|
||||
"""Deprecated: Total Load Prediction with adjustment.
|
||||
|
||||
# Ab hier bleibt der Code unverändert ...
|
||||
measured_data_df = pd.DataFrame(measured_data)
|
||||
measured_data_df["time"] = pd.to_datetime(measured_data_df["time"])
|
||||
Endpoint to handle total load prediction adjusted by latest measured data.
|
||||
|
||||
# Zeitzonenmanagement
|
||||
if measured_data_df["time"].dt.tz is None:
|
||||
measured_data_df["time"] = measured_data_df["time"].dt.tz_localize("Europe/Berlin")
|
||||
else:
|
||||
measured_data_df["time"] = measured_data_df["time"].dt.tz_convert("Europe/Berlin")
|
||||
|
||||
# Zeitzone entfernen
|
||||
measured_data_df["time"] = measured_data_df["time"].dt.tz_localize(None)
|
||||
|
||||
# Forecast erstellen
|
||||
lf = LoadForecast(
|
||||
filepath=server_dir / ".." / "data" / "load_profiles.npz", year_energy=year_energy
|
||||
Note:
|
||||
Use '/v1/prediction/list?key=load_mean_adjusted' instead.
|
||||
Load energy meter readings can be added to EOS measurements via:
|
||||
'/v1/measurement/load-mr/value/by-name' or
|
||||
'/v1/measurement/value'
|
||||
"""
|
||||
settings = SettingsEOS(
|
||||
prediction_hours=request.hours,
|
||||
load_provider="LoadAkkudoktor",
|
||||
loadakkudoktor_year_energy=request.year_energy,
|
||||
)
|
||||
forecast_list = []
|
||||
config_eos.merge_settings(settings=settings)
|
||||
ems_eos.set_start_datetime() # Set energy management start datetime to current hour.
|
||||
|
||||
for single_date in pd.date_range(
|
||||
measured_data_df["time"].min().date(), measured_data_df["time"].max().date()
|
||||
):
|
||||
date_str = single_date.strftime("%Y-%m-%d")
|
||||
daily_forecast = lf.get_daily_stats(date_str)
|
||||
mean_values = daily_forecast[0]
|
||||
fc_hours = [single_date + pd.Timedelta(hours=i) for i in range(24)]
|
||||
daily_forecast_df = pd.DataFrame({"time": fc_hours, "Last Pred": mean_values})
|
||||
forecast_list.append(daily_forecast_df)
|
||||
# Insert measured data into EOS measurement
|
||||
# Convert from energy per interval to dummy energy meter readings
|
||||
measurement_key = "measurement_load0_mr"
|
||||
measurement_eos.key_delete_by_datetime(key=measurement_key) # delete all load0_mr measurements
|
||||
energy = {}
|
||||
for data_dict in request.measured_data:
|
||||
for date_time, value in data_dict.items():
|
||||
dt_str = to_datetime(date_time, as_string=True)
|
||||
energy[dt_str] = value
|
||||
energy_mr = 0
|
||||
for i, key in enumerate(sorted(energy)):
|
||||
energy_mr += energy[key]
|
||||
dt = to_datetime(key)
|
||||
if i == 0:
|
||||
# first element, add start value before
|
||||
dt_before = dt - to_duration("1 hour")
|
||||
measurement_eos.update_value(date=dt_before, key=measurement_key, value=0.0)
|
||||
measurement_eos.update_value(date=dt, key=measurement_key, value=energy_mr)
|
||||
|
||||
predicted_data = pd.concat(forecast_list, ignore_index=True)
|
||||
# Create load forecast
|
||||
prediction_eos.update_data(force_update=True)
|
||||
|
||||
adjuster = LoadPredictionAdjuster(measured_data_df, predicted_data, lf)
|
||||
adjuster.calculate_weighted_mean()
|
||||
adjuster.adjust_predictions()
|
||||
future_predictions = adjuster.predict_next_hours(hours)
|
||||
|
||||
leistung_haushalt = future_predictions["Adjusted Pred"].to_numpy()
|
||||
gesamtlast = LoadAggregator(prediction_hours=hours)
|
||||
gesamtlast.add_load(
|
||||
"Haushalt",
|
||||
tuple(leistung_haushalt),
|
||||
)
|
||||
|
||||
return gesamtlast.calculate_total_load()
|
||||
prediction_list = prediction_eos.key_to_array(
|
||||
key="load_mean_adjusted",
|
||||
start_datetime=prediction_eos.start_datetime,
|
||||
end_datetime=prediction_eos.end_datetime,
|
||||
).tolist()
|
||||
return prediction_list
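The loop above converts per-interval energy values into dummy cumulative meter readings before importing them as measurements; a minimal standalone sketch of the same idea, with made-up timestamps and values:

energy = {
    "2024-11-10T00:00:00": 0.5,
    "2024-11-10T01:00:00": 0.4,
    "2024-11-10T02:00:00": 0.6,
}
readings = {}
total = 0.0
for timestamp in sorted(energy):
    total += energy[timestamp]
    readings[timestamp] = total  # monotonically increasing, like a real energy meter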
|
||||
|
||||
|
||||
@app.get("/gesamtlast_simple")
|
||||
def fastapi_gesamtlast_simple(year_energy: float) -> list[float]:
|
||||
###############
|
||||
# Load Forecast
|
||||
###############
|
||||
lf = LoadForecast(
|
||||
filepath=server_dir / ".." / "data" / "load_profiles.npz", year_energy=year_energy
|
||||
) # Instantiate LoadForecast with specified parameters
|
||||
leistung_haushalt = lf.get_stats_for_date_range(
|
||||
prediction_eos.start_datetime, prediction_eos.end_datetime
|
||||
)[0] # Get expected household load for the date range
|
||||
"""Deprecated: Total Load Prediction.
|
||||
|
||||
prediction_hours = config_eos.prediction_hours if config_eos.prediction_hours else 48
|
||||
gesamtlast = LoadAggregator(prediction_hours=prediction_hours) # Create Gesamtlast instance
|
||||
gesamtlast.add_load(
|
||||
"Haushalt", tuple(leistung_haushalt)
|
||||
) # Add household to total load calculation
|
||||
Endpoint to handle total load prediction.
|
||||
|
||||
# ###############
|
||||
# # WP (Heat Pump)
|
||||
# ##############
|
||||
# leistung_wp = wp.simulate_24h(temperature_forecast) # Simulate heat pump load for 24 hours
|
||||
# gesamtlast.hinzufuegen("Heatpump", leistung_wp) # Add heat pump load to total load calculation
|
||||
Note:
|
||||
Use '/v1/prediction/list?key=load_mean' instead.
|
||||
"""
|
||||
settings = SettingsEOS(
|
||||
load_provider="LoadAkkudoktor",
|
||||
loadakkudoktor_year_energy=year_energy,
|
||||
)
|
||||
config_eos.merge_settings(settings=settings)
|
||||
ems_eos.set_start_datetime() # Set energy management start datetime to current hour.
|
||||
|
||||
return gesamtlast.calculate_total_load()
|
||||
# Create load forecast
|
||||
prediction_eos.update_data(force_update=True)
|
||||
|
||||
prediction_list = prediction_eos.key_to_array(
|
||||
key="load_mean",
|
||||
start_datetime=prediction_eos.start_datetime,
|
||||
end_datetime=prediction_eos.end_datetime,
|
||||
).tolist()
|
||||
return prediction_list
|
||||
|
||||
|
||||
class ForecastResponse(PydanticBaseModel):
|
||||
@ -231,7 +441,7 @@ def fastapi_optimize(
|
||||
] = None,
|
||||
) -> OptimizeResponse:
|
||||
if start_hour is None:
|
||||
start_hour = DateTime.now().hour
|
||||
start_hour = to_datetime().hour
|
||||
|
||||
# TODO: Remove when config and prediction update is done by EMS.
|
||||
config_eos.update()
|
||||
@ -313,16 +523,6 @@ async def proxy(request: Request, path: str) -> Union[Response | RedirectRespons
|
||||
return RedirectResponse(url="/docs")
|
||||
|
||||
|
||||
def start_fasthtml_server() -> subprocess.Popen:
|
||||
"""Start the fasthtml server as a subprocess."""
|
||||
server_process = subprocess.Popen(
|
||||
[sys.executable, str(server_dir.joinpath("fasthtml_server.py"))],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
return server_process
|
||||
|
||||
|
||||
def start_fastapi_server() -> None:
|
||||
"""Start FastAPI server."""
|
||||
try:
|
||||
|
@ -23,6 +23,9 @@ class ServerCommonSettings(SettingsBaseModel):
|
||||
server_fastapi_port: Optional[int] = Field(
|
||||
default=8503, description="FastAPI server IP port number."
|
||||
)
|
||||
server_fastapi_startup_server_fasthtml: Optional[bool] = Field(
|
||||
default=True, description="FastAPI server to startup application FastHTML server."
|
||||
)
|
||||
server_fasthtml_host: Optional[IPvAnyAddress] = Field(
|
||||
default="0.0.0.0", description="FastHTML server IP address."
|
||||
)
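A short sketch of switching the new startup option off programmatically; the same setting can also be merged via PUT /v1/config:

from akkudoktoreos.config.config import SettingsEOS, get_config

config = get_config()
config.merge_settings(
    SettingsEOS(server_fastapi_startup_server_fasthtml=False)  # do not spawn FastHTML
)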
|
||||
|
@ -31,21 +31,23 @@ import os
|
||||
import pickle
|
||||
import tempfile
|
||||
import threading
|
||||
from datetime import date, datetime, time, timedelta
|
||||
from typing import (
|
||||
IO,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
ParamSpec,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from akkudoktoreos.utils.datetimeutil import to_datetime, to_duration
|
||||
from pendulum import DateTime, Duration
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime, to_duration
|
||||
from akkudoktoreos.utils.logutil import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@ -56,6 +58,21 @@ Param = ParamSpec("Param")
|
||||
RetType = TypeVar("RetType")
|
||||
|
||||
|
||||
class CacheFileRecord(BaseModel):
|
||||
# Enable custom serialization globally in config
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
use_enum_values=True,
|
||||
validate_assignment=True,
|
||||
)
|
||||
|
||||
cache_file: Any = Field(..., description="File descriptor of the cache file.")
|
||||
until_datetime: DateTime = Field(..., description="Datetime until the cache file is valid.")
|
||||
ttl_duration: Optional[Duration] = Field(
|
||||
default=None, description="Duration the cache file is valid."
|
||||
)
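An illustrative record, mirroring what CacheFileStore keeps internally; the temporary file is just an example payload:

import tempfile

from akkudoktoreos.utils.cacheutil import CacheFileRecord
from akkudoktoreos.utils.datetimeutil import to_datetime, to_duration

record = CacheFileRecord(
    cache_file=tempfile.NamedTemporaryFile(mode="wb+"),
    until_datetime=to_datetime().add(hours=1),
    ttl_duration=to_duration("1 hour"),
)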
|
||||
|
||||
|
||||
class CacheFileStoreMeta(type, Generic[T]):
|
||||
"""A thread-safe implementation of CacheFileStore."""
|
||||
|
||||
@ -102,12 +119,36 @@ class CacheFileStore(metaclass=CacheFileStoreMeta):
|
||||
This constructor sets up an empty key-value store (a dictionary) where each key
|
||||
corresponds to a cache file that is associated with a given key and an optional date.
|
||||
"""
|
||||
self._store: dict[str, tuple[IO[bytes], datetime]] = {}
|
||||
self._store: Dict[str, CacheFileRecord] = {}
|
||||
self._store_lock = threading.Lock()
|
||||
|
||||
def _until_datetime_by_options(
|
||||
self,
|
||||
until_date: Optional[Any] = None,
|
||||
until_datetime: Optional[Any] = None,
|
||||
with_ttl: Optional[Any] = None,
|
||||
) -> tuple[DateTime, Optional[Duration]]:
|
||||
"""Get until_datetime and ttl_duration from the given options."""
|
||||
ttl_duration = None
|
||||
if until_datetime:
|
||||
until_datetime = to_datetime(until_datetime)
|
||||
elif with_ttl:
|
||||
ttl_duration = to_duration(with_ttl)
|
||||
until_datetime = to_datetime() + ttl_duration
|
||||
elif until_date:
|
||||
until_datetime = to_datetime(until_date).end_of("day")
|
||||
else:
|
||||
# end of today
|
||||
until_datetime = to_datetime().end_of("day")
|
||||
return (until_datetime, ttl_duration)
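Illustrative calls against this helper, showing the option precedence; it is a private method, shown here only for clarity:

from akkudoktoreos.utils.cacheutil import CacheFileStore

store = CacheFileStore()
store._until_datetime_by_options(until_datetime="2024-11-11T00:00:00")  # explicit datetime, no TTL
store._until_datetime_by_options(with_ttl="1 hour")                     # now + 1 hour, TTL duration returned
store._until_datetime_by_options(until_date="2024-11-11")               # end of that day
store._until_datetime_by_options()                                      # end of today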
|
||||
|
||||
def _generate_cache_file_key(
|
||||
self, key: str, until_datetime: Union[datetime, None]
|
||||
) -> tuple[str, datetime]:
|
||||
self,
|
||||
key: str,
|
||||
until_date: Optional[Any] = None,
|
||||
until_datetime: Optional[Any] = None,
|
||||
with_ttl: Optional[Any] = None,
|
||||
) -> tuple[str, DateTime, Optional[Duration]]:
|
||||
"""Generates a unique cache file key based on the key and date.
|
||||
|
||||
The cache file key is a combination of the input key and the date (if provided),
|
||||
@ -115,7 +156,7 @@ class CacheFileStore(metaclass=CacheFileStoreMeta):
|
||||
|
||||
Args:
|
||||
key (str): The key that identifies the cache file.
|
||||
until_datetime (Optional[Any]): The datetime
|
||||
until_datetime (Optional[DateTime]): The datetime
|
||||
until the cache file is valid. The default is the current date at maximum time
|
||||
(23:59:59).
|
||||
|
||||
@ -123,12 +164,18 @@ class CacheFileStore(metaclass=CacheFileStoreMeta):
|
||||
A tuple of:
|
||||
str: A hashed string that serves as the unique identifier for the cache file.
|
||||
datetime: The datetime until the cache file is valid.
|
||||
Optional[Duration]: Duration for TTL control.
|
||||
"""
|
||||
if until_datetime is None:
|
||||
until_datetime = datetime.combine(date.today(), time.max)
|
||||
key_datetime = to_datetime(until_datetime, as_string="UTC")
|
||||
until_datetime_dt, ttl_duration = self._until_datetime_by_options(
|
||||
until_date, until_datetime, with_ttl
|
||||
)
|
||||
if ttl_duration:
|
||||
# We need a special key for with_ttl, only encoding the with_ttl
|
||||
key_datetime = ttl_duration.in_words()
|
||||
else:
|
||||
key_datetime = to_datetime(until_datetime_dt, as_string="UTC")
|
||||
cache_key = hashlib.sha256(f"{key}{key_datetime}".encode("utf-8")).hexdigest()
|
||||
return (f"{cache_key}", until_datetime)
|
||||
return (f"{cache_key}", until_datetime_dt, ttl_duration)
|
||||
|
||||
def _get_file_path(self, file_obj: IO[bytes]) -> Optional[str]:
|
||||
"""Retrieve the file path from a file-like object.
|
||||
@ -147,37 +194,17 @@ class CacheFileStore(metaclass=CacheFileStoreMeta):
|
||||
file_path = file_obj.name # Get the file path from the cache file object
|
||||
return file_path
|
||||
|
||||
def _until_datetime_by_options(
|
||||
self,
|
||||
until_date: Optional[Any] = None,
|
||||
until_datetime: Optional[Any] = None,
|
||||
with_ttl: Union[timedelta, str, int, float, None] = None,
|
||||
) -> datetime:
|
||||
"""Get until_datetime from the given options."""
|
||||
if until_datetime:
|
||||
until_datetime = to_datetime(until_datetime)
|
||||
elif with_ttl:
|
||||
with_ttl = to_duration(with_ttl)
|
||||
until_datetime = to_datetime(datetime.now() + with_ttl)
|
||||
elif until_date:
|
||||
until_datetime = to_datetime(to_datetime(until_date).date())
|
||||
else:
|
||||
# end of today
|
||||
until_datetime = to_datetime(datetime.combine(date.today(), time.max))
|
||||
return until_datetime
|
||||
|
||||
def _is_valid_cache_item(
|
||||
self,
|
||||
cache_item: tuple[IO[bytes], datetime],
|
||||
until_datetime: Optional[datetime] = None,
|
||||
at_datetime: Optional[datetime] = None,
|
||||
before_datetime: Optional[datetime] = None,
|
||||
cache_item: CacheFileRecord,
|
||||
until_datetime: Optional[DateTime] = None,
|
||||
at_datetime: Optional[DateTime] = None,
|
||||
before_datetime: Optional[DateTime] = None,
|
||||
) -> bool:
|
||||
cache_file_datetime = cache_item[1] # Extract the datetime associated with the cache item
|
||||
if (
|
||||
(until_datetime and until_datetime == cache_file_datetime)
|
||||
or (at_datetime and at_datetime <= cache_file_datetime)
|
||||
or (before_datetime and cache_file_datetime < before_datetime)
|
||||
(until_datetime and until_datetime == cache_item.until_datetime)
|
||||
or (at_datetime and at_datetime <= cache_item.until_datetime)
|
||||
or (before_datetime and cache_item.until_datetime < before_datetime)
|
||||
):
|
||||
return True
|
||||
return False
|
||||
@ -188,7 +215,8 @@ class CacheFileStore(metaclass=CacheFileStoreMeta):
|
||||
until_datetime: Optional[Any] = None,
|
||||
at_datetime: Optional[Any] = None,
|
||||
before_datetime: Optional[Any] = None,
|
||||
) -> Optional[tuple[str, IO[bytes], datetime]]:
|
||||
ttl_duration: Optional[Any] = None,
|
||||
) -> tuple[str, Optional[CacheFileRecord]]:
|
||||
"""Searches for a cached item that matches the key and falls within the datetime range.
|
||||
|
||||
This method looks for a cache item with a key that matches the given `key`, and whose associated
|
||||
@ -203,48 +231,62 @@ class CacheFileStore(metaclass=CacheFileStoreMeta):
|
||||
before_datetime (Optional[Any]): The datetime to compare the cache item's datetime to be before.
|
||||
|
||||
Returns:
|
||||
Optional[tuple]: Returns the cache_file_key, chache_file, cache_file_datetime if found,
|
||||
otherwise returns `None`.
|
||||
tuple[str, Optional[CacheFileRecord]]: Returns the cache_file_key, cache file record if found, otherwise returns `None`.
|
||||
"""
|
||||
# Convert input to datetime if they are not None
|
||||
until_datetime_dt: Optional[datetime] = None
|
||||
if until_datetime is not None:
|
||||
until_datetime_dt = to_datetime(until_datetime)
|
||||
at_datetime_dt: Optional[datetime] = None
|
||||
if at_datetime is not None:
|
||||
at_datetime_dt = to_datetime(at_datetime)
|
||||
before_datetime_dt: Optional[datetime] = None
|
||||
if before_datetime is not None:
|
||||
before_datetime_dt = to_datetime(before_datetime)
|
||||
if ttl_duration is not None:
|
||||
# TTL duration - use current datetime
|
||||
if until_datetime or at_datetime or before_datetime:
|
||||
raise NotImplementedError(
|
||||
f"Search with ttl_duration and datetime filter until:{until_datetime}, at:{at_datetime}, before:{before_datetime} is not implemented"
|
||||
)
|
||||
at_datetime = to_datetime()
|
||||
else:
|
||||
if until_datetime is not None:
|
||||
until_datetime = to_datetime(until_datetime)
|
||||
if at_datetime is not None:
|
||||
at_datetime = to_datetime(at_datetime)
|
||||
if before_datetime is not None:
|
||||
before_datetime = to_datetime(before_datetime)
|
||||
if until_datetime is None and at_datetime is None and before_datetime is None:
|
||||
at_datetime = to_datetime().end_of("day")
|
||||
|
||||
for cache_file_key, cache_item in self._store.items():
|
||||
# Check if the cache file datetime matches the given criteria
|
||||
if self._is_valid_cache_item(
|
||||
cache_item,
|
||||
until_datetime=until_datetime_dt,
|
||||
at_datetime=at_datetime_dt,
|
||||
before_datetime=before_datetime_dt,
|
||||
until_datetime=until_datetime,
|
||||
at_datetime=at_datetime,
|
||||
before_datetime=before_datetime,
|
||||
):
|
||||
# This cache file is within the given datetime range
|
||||
# Extract the datetime associated with the cache item
|
||||
cache_file_datetime = cache_item[1]
|
||||
|
||||
# Generate a cache file key based on the given key and the cache file datetime
|
||||
generated_key, _until_dt = self._generate_cache_file_key(key, cache_file_datetime)
|
||||
if cache_item.ttl_duration:
|
||||
generated_key, _until_dt, _ttl_duration = self._generate_cache_file_key(
|
||||
key, with_ttl=cache_item.ttl_duration
|
||||
)
|
||||
else:
|
||||
generated_key, _until_dt, _ttl_duration = self._generate_cache_file_key(
|
||||
key, until_datetime=cache_item.until_datetime
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Search: ttl:{ttl_duration}, until:{until_datetime}, at:{at_datetime}, before:{before_datetime} -> hit: {generated_key == cache_file_key}, item: {cache_item.cache_file.seek(0), cache_item.cache_file.read()}"
|
||||
)
|
||||
|
||||
if generated_key == cache_file_key:
|
||||
# The key matches, return the key and the cache item
|
||||
return (cache_file_key, cache_item[0], cache_file_datetime)
|
||||
# The key matches, return the cache item
|
||||
return (cache_file_key, cache_item)
|
||||
|
||||
# Return None if no matching cache item is found
|
||||
return None
|
||||
return ("<not found>", None)
|
||||
|
||||
def create(
|
||||
self,
|
||||
key: str,
|
||||
until_date: Optional[Any] = None,
|
||||
until_datetime: Optional[Any] = None,
|
||||
with_ttl: Union[timedelta, str, int, float, None] = None,
|
||||
with_ttl: Optional[Any] = None,
|
||||
mode: str = "wb+",
|
||||
delete: bool = False,
|
||||
suffix: Optional[str] = None,
|
||||
@ -261,8 +303,7 @@ class CacheFileStore(metaclass=CacheFileStoreMeta):
|
||||
until_datetime (Optional[Any]): The datetime
|
||||
until the cache file is valid. Time of day is set to maximum time (23:59:59) if not
|
||||
provided.
|
||||
with_ttl (Union[timedelta, str, int, float, None], optional): The time to live that
|
||||
the cache file is valid. Time starts now.
|
||||
with_ttl (Optional[Any]): The time to live that the cache file is valid. Time starts now.
|
||||
mode (str, optional): The mode in which the tempfile is opened
|
||||
(e.g., 'w+', 'r+', 'wb+'). Defaults to 'wb+'.
|
||||
delete (bool, optional): Whether to delete the file after it is closed.
|
||||
@ -279,20 +320,22 @@ class CacheFileStore(metaclass=CacheFileStoreMeta):
|
||||
>>> cache_file.seek(0)
|
||||
>>> print(cache_file.read()) # Output: 'Some cached data'
|
||||
"""
|
||||
until_datetime_dt = self._until_datetime_by_options(
|
||||
until_datetime=until_datetime, until_date=until_date, with_ttl=with_ttl
|
||||
cache_file_key, until_datetime_dt, ttl_duration = self._generate_cache_file_key(
|
||||
key, until_datetime=until_datetime, until_date=until_date, with_ttl=with_ttl
|
||||
)
|
||||
|
||||
cache_file_key, _ = self._generate_cache_file_key(key, until_datetime_dt)
|
||||
with self._store_lock: # Synchronize access to _store
|
||||
if (cache_file_item := self._store.get(cache_file_key)) is not None:
|
||||
if (cache_item := self._store.get(cache_file_key)) is not None:
|
||||
# File already available
|
||||
cache_file_obj = cache_file_item[0]
|
||||
cache_file_obj = cache_item.cache_file
|
||||
else:
|
||||
cache_file_obj = tempfile.NamedTemporaryFile(
|
||||
mode=mode, delete=delete, suffix=suffix
|
||||
)
|
||||
self._store[cache_file_key] = (cache_file_obj, until_datetime_dt)
|
||||
self._store[cache_file_key] = CacheFileRecord(
|
||||
cache_file=cache_file_obj,
|
||||
until_datetime=until_datetime_dt,
|
||||
ttl_duration=ttl_duration,
|
||||
)
|
||||
cache_file_obj.seek(0)
|
||||
return cache_file_obj
|
||||
|
||||
@ -302,7 +345,7 @@ class CacheFileStore(metaclass=CacheFileStoreMeta):
|
||||
file_obj: IO[bytes],
|
||||
until_date: Optional[Any] = None,
|
||||
until_datetime: Optional[Any] = None,
|
||||
with_ttl: Union[timedelta, str, int, float, None] = None,
|
||||
with_ttl: Optional[Any] = None,
|
||||
) -> None:
|
||||
"""Stores a file-like object in the cache under the specified key and date.
|
||||
|
||||
@ -317,8 +360,7 @@ class CacheFileStore(metaclass=CacheFileStoreMeta):
|
||||
until_datetime (Optional[Any]): The datetime
|
||||
until the cache file is valid. Time of day is set to maximum time (23:59:59) if not
|
||||
provided.
|
||||
with_ttl (Union[timedelta, str, int, float, None], optional): The time to live that
|
||||
the cache file is valid. Time starts now.
|
||||
with_ttl (Optional[Any]): The time to live that the cache file is valid. Time starts now.
|
||||
|
||||
Raises:
|
||||
ValueError: If the key is already in store.
|
||||
@ -326,16 +368,26 @@ class CacheFileStore(metaclass=CacheFileStoreMeta):
|
||||
Example:
|
||||
>>> cache_store.set('example_file', io.BytesIO(b'Some binary data'))
|
||||
"""
|
||||
until_datetime_dt = self._until_datetime_by_options(
|
||||
until_datetime=until_datetime, until_date=until_date, with_ttl=with_ttl
|
||||
cache_file_key, until_datetime_dt, ttl_duration = self._generate_cache_file_key(
|
||||
key, until_datetime=until_datetime, until_date=until_date, with_ttl=with_ttl
|
||||
)
|
||||
|
||||
cache_file_key, until_date = self._generate_cache_file_key(key, until_datetime_dt)
|
||||
with self._store_lock: # Synchronize access to _store
|
||||
if cache_file_key in self._store:
|
||||
raise ValueError(f"Key already in store: `{key}`.")
|
||||
if ttl_duration:
|
||||
# Special with_ttl case
|
||||
if compare_datetimes(
|
||||
self._store[cache_file_key].until_datetime, to_datetime()
|
||||
).lt:
|
||||
# File is outdated - replace by new file
|
||||
self.delete(key=cache_file_key)
|
||||
else:
|
||||
raise ValueError(f"Key already in store: `{key}`.")
|
||||
else:
|
||||
raise ValueError(f"Key already in store: `{key}`.")
|
||||
|
||||
self._store[cache_file_key] = (file_obj, until_date)
|
||||
self._store[cache_file_key] = CacheFileRecord(
|
||||
cache_file=file_obj, until_datetime=until_datetime_dt, ttl_duration=ttl_duration
|
||||
)
|
||||
|
||||
def get(
|
||||
self,
|
||||
@ -344,6 +396,7 @@ class CacheFileStore(metaclass=CacheFileStoreMeta):
|
||||
until_datetime: Optional[Any] = None,
|
||||
at_datetime: Optional[Any] = None,
|
||||
before_datetime: Optional[Any] = None,
|
||||
ttl_duration: Optional[Any] = None,
|
||||
) -> Optional[IO[bytes]]:
|
||||
"""Retrieves the cache file associated with the given key and validity datetime.
|
||||
|
||||
@ -362,6 +415,8 @@ class CacheFileStore(metaclass=CacheFileStoreMeta):
|
||||
provided. Defaults to the current datetime if None is provided.
|
||||
before_datetime (Optional[Any]): The datetime
|
||||
to compare the cache files datetime to be before.
|
||||
ttl_duration (Optional[Any]): The time to live to compare the cache files time to live
|
||||
to be equal.
|
||||
|
||||
Returns:
|
||||
file_obj: The file-like cache object, or None if no file is found.
|
||||
@ -373,21 +428,20 @@ class CacheFileStore(metaclass=CacheFileStoreMeta):
|
||||
>>> print(cache_file.read()) # Output: Cached data (if exists)
|
||||
"""
|
||||
if until_datetime or until_date:
|
||||
until_datetime = self._until_datetime_by_options(
|
||||
until_datetime, _ttl_duration = self._until_datetime_by_options(
|
||||
until_datetime=until_datetime, until_date=until_date
|
||||
)
|
||||
elif at_datetime:
|
||||
at_datetime = to_datetime(at_datetime)
|
||||
elif before_datetime:
|
||||
before_datetime = to_datetime(before_datetime)
|
||||
else:
|
||||
at_datetime = to_datetime(datetime.now())
|
||||
|
||||
with self._store_lock: # Synchronize access to _store
|
||||
search_item = self._search(key, until_datetime, at_datetime, before_datetime)
|
||||
_cache_file_key, search_item = self._search(
|
||||
key,
|
||||
until_datetime=until_datetime,
|
||||
at_datetime=at_datetime,
|
||||
before_datetime=before_datetime,
|
||||
ttl_duration=ttl_duration,
|
||||
)
|
||||
if search_item is None:
|
||||
return None
|
||||
return search_item[1]
|
||||
return search_item.cache_file
|
||||
|
||||
def delete(
|
||||
self,
|
||||
@ -418,17 +472,15 @@ class CacheFileStore(metaclass=CacheFileStoreMeta):
|
||||
elif before_datetime:
|
||||
before_datetime = to_datetime(before_datetime)
|
||||
else:
|
||||
today = datetime.now().date() # Get today's date
|
||||
tomorrow = today + timedelta(days=1) # Add one day to get tomorrow's date
|
||||
before_datetime = to_datetime(datetime.combine(tomorrow, time.min))
|
||||
# Make before_datetime tomorrow at start of day
|
||||
before_datetime = to_datetime().add(days=1).start_of("day")
|
||||
|
||||
with self._store_lock: # Synchronize access to _store
|
||||
search_item = self._search(key, until_datetime, None, before_datetime)
|
||||
cache_file_key, search_item = self._search(
|
||||
key, until_datetime=until_datetime, before_datetime=before_datetime
|
||||
)
|
||||
if search_item:
|
||||
cache_file_key = search_item[0]
|
||||
cache_file = search_item[1]
|
||||
cache_file_datetime = search_item[2]
|
||||
file_path = self._get_file_path(cache_file)
|
||||
file_path = self._get_file_path(search_item.cache_file)
|
||||
if file_path is None:
|
||||
logger.warning(
|
||||
f"The cache file with key '{cache_file_key}' is an in memory "
|
||||
@ -436,9 +488,10 @@ class CacheFileStore(metaclass=CacheFileStoreMeta):
|
||||
)
|
||||
self._store.pop(cache_file_key)
|
||||
return
|
||||
file_path = cache_file.name # Get the file path from the cache file object
|
||||
# Get the file path from the cache file object
|
||||
file_path = search_item.cache_file.name
|
||||
del self._store[cache_file_key]
|
||||
if os.path.exists(file_path):
|
||||
if file_path and os.path.exists(file_path):
|
||||
try:
|
||||
os.remove(file_path)
|
||||
logger.debug(f"Deleted cache file: {file_path}")
|
||||
@ -462,30 +515,31 @@ class CacheFileStore(metaclass=CacheFileStoreMeta):
|
||||
OSError: If there's an error during file deletion.
|
||||
"""
|
||||
delete_keys = [] # List of keys to delete, prevent deleting when traversing the store
|
||||
clear_timestamp = None
|
||||
|
||||
# Some weird logic to prevent calling to_datetime on clear_all.
|
||||
# Clear_all may be set on __del__. At this time some info for to_datetime will
|
||||
# not be available anymore.
|
||||
if not clear_all:
|
||||
if before_datetime is None:
|
||||
before_datetime = to_datetime().start_of("day")
|
||||
else:
|
||||
before_datetime = to_datetime(before_datetime)
|
||||
|
||||
with self._store_lock: # Synchronize access to _store
|
||||
for cache_file_key, cache_item in self._store.items():
|
||||
cache_file = cache_item[0]
|
||||
|
||||
# Some weired logic to prevent calling to_datetime on clear_all.
|
||||
# Clear_all may be set on __del__. At this time some info for to_datetime will
|
||||
# not be available anymore.
|
||||
clear_file = clear_all
|
||||
if not clear_all:
|
||||
if clear_timestamp is None:
|
||||
before_datetime = to_datetime(before_datetime, to_maxtime=False)
|
||||
# Convert the threshold date to a timestamp (seconds since epoch)
|
||||
clear_timestamp = to_datetime(before_datetime).timestamp()
|
||||
cache_file_timestamp = to_datetime(cache_item[1]).timestamp()
|
||||
if cache_file_timestamp < clear_timestamp:
|
||||
clear_file = True
|
||||
if clear_all:
|
||||
clear_file = True
|
||||
else:
|
||||
clear_file = compare_datetimes(cache_item.until_datetime, before_datetime).lt
|
||||
|
||||
if clear_file:
|
||||
# We have to clear this cache file
|
||||
delete_keys.append(cache_file_key)
|
||||
|
||||
file_path = self._get_file_path(cache_file)
|
||||
file_path = self._get_file_path(cache_item.cache_file)
|
||||
|
||||
if file_path is None:
|
||||
# In memory file like object
|
||||
@ -516,7 +570,7 @@ def cache_in_file(
|
||||
force_update: Optional[bool] = None,
|
||||
until_date: Optional[Any] = None,
|
||||
until_datetime: Optional[Any] = None,
|
||||
with_ttl: Union[timedelta, str, int, float, None] = None,
|
||||
with_ttl: Optional[Any] = None,
|
||||
mode: Literal["w", "w+", "wb", "wb+", "r", "r+", "rb", "rb+"] = "wb+",
|
||||
delete: bool = False,
|
||||
suffix: Optional[str] = None,
|
||||
@ -620,7 +674,7 @@ def cache_in_file(
|
||||
elif param == "with_ttl":
|
||||
until_datetime = None
|
||||
until_date = None
|
||||
with_ttl = kwargs[param] # type: ignore[assignment]
|
||||
with_ttl = kwargs[param]
|
||||
elif param == "until_date":
|
||||
until_datetime = None
|
||||
until_date = kwargs[param]
|
||||
@ -642,7 +696,9 @@ def cache_in_file(
|
||||
|
||||
result: Optional[RetType | bytes] = None
|
||||
# Get cache file that is currently valid
|
||||
cache_file = CacheFileStore().get(key)
|
||||
cache_file = CacheFileStore().get(
|
||||
key, until_date=until_date, until_datetime=until_datetime, ttl_duration=with_ttl
|
||||
)
|
||||
if not force_update and cache_file is not None:
|
||||
# cache file is available
|
||||
try:
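For context, the decorator is typically used like this; the pattern mirrors the tests further below, with a placeholder function name and TTL:

from akkudoktoreos.utils.cacheutil import cache_in_file


@cache_in_file(mode="w+")
def fetch_data(until_date=None):
    return "expensive result"


first = fetch_data(with_ttl="1 hour")   # runs the function and caches the result
second = fetch_data(with_ttl="1 hour")  # served from the cache while the TTL is valid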
|
||||
|
@ -19,7 +19,7 @@ Example usage:
|
||||
>>> to_duration("2 days 5 hours")
|
||||
|
||||
# Timezone detection
|
||||
>>> to_timezone(location={40.7128, -74.0060})
|
||||
>>> to_timezone(location=(40.7128, -74.0060))
|
||||
"""
|
||||
|
||||
import re
|
||||
@ -27,7 +27,7 @@ from datetime import date, datetime, timedelta
|
||||
from typing import Any, List, Literal, Optional, Tuple, Union, overload
|
||||
|
||||
import pendulum
|
||||
from pendulum import DateTime
|
||||
from pendulum import Date, DateTime, Duration
|
||||
from pendulum.tz.timezone import Timezone
|
||||
from timezonefinder import TimezoneFinder
|
||||
|
||||
@ -71,6 +71,7 @@ def to_datetime(
|
||||
date_input (Optional[Any]): The date input to convert. Supported types include:
|
||||
- `str`: A date string in various formats (e.g., "2024-10-13", "13 Oct 2024").
|
||||
- `pendulum.DateTime`: A Pendulum DateTime object.
|
||||
- `pendulum.Date`: A Pendulum Date object, which will be converted to a datetime at the start or end of the day.
|
||||
- `datetime.datetime`: A standard Python datetime object.
|
||||
- `datetime.date`: A date object, which will be converted to a datetime at the start or end of the day.
|
||||
- `int` or `float`: A Unix timestamp, interpreted as seconds since the epoch (UTC).
|
||||
@ -123,6 +124,14 @@ def to_datetime(
|
||||
|
||||
if isinstance(date_input, DateTime):
|
||||
dt = date_input
|
||||
elif isinstance(date_input, Date):
|
||||
dt = pendulum.datetime(
|
||||
year=date_input.year, month=date_input.month, day=date_input.day, tz=in_timezone
|
||||
)
|
||||
if to_maxtime:
|
||||
dt = dt.end_of("day")
|
||||
else:
|
||||
dt = dt.start_of("day")
|
||||
elif isinstance(date_input, str):
|
||||
# Convert to timezone aware datetime
|
||||
dt = None
|
||||
@ -161,14 +170,22 @@ def to_datetime(
|
||||
except pendulum.parsing.exceptions.ParserError as e:
|
||||
logger.debug(f"Date string {date_input} does not match any Pendulum formats: {e}")
|
||||
dt = None
|
||||
if dt is None:
|
||||
# Some special values
|
||||
if date_input.lower() == "infinity":
|
||||
# Subtract one year from max as max datetime will create an overflow error in certain contexts.
|
||||
dt = DateTime.max.subtract(years=1)
|
||||
if dt is None:
|
||||
try:
|
||||
timestamp = float(date_input)
|
||||
dt = pendulum.from_timestamp(timestamp, tz="UTC")
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.debug(f"Date string {date_input} does not match timestamp format: {e}")
|
||||
dt = None
|
||||
if dt is None:
|
||||
raise ValueError(f"Date string {date_input} does not match any known formats.")
|
||||
elif date_input is None:
|
||||
dt = (
|
||||
pendulum.today(tz=in_timezone).end_of("day")
|
||||
if to_maxtime
|
||||
else pendulum.today(tz=in_timezone).start_of("day")
|
||||
)
|
||||
dt = pendulum.now(tz=in_timezone)
|
||||
elif isinstance(date_input, datetime):
|
||||
dt = pendulum.instance(date_input)
|
||||
elif isinstance(date_input, date):
|
||||
@ -206,19 +223,19 @@ def to_datetime(
|
||||
|
||||
|
||||
def to_duration(
|
||||
input_value: Union[timedelta, str, int, float, Tuple[int, int, int, int], List[int]],
|
||||
) -> timedelta:
|
||||
"""Converts various input types into a timedelta object using pendulum.
|
||||
input_value: Union[Duration, timedelta, str, int, float, Tuple[int, int, int, int], List[int]],
|
||||
) -> Duration:
|
||||
"""Converts various input types into a Duration object using pendulum.
|
||||
|
||||
Args:
|
||||
input_value (Union[timedelta, str, int, float, tuple, list]): Input to be converted
|
||||
input_value (Union[Duration, timedelta, str, int, float, tuple, list]): Input to be converted
|
||||
into a timedelta:
|
||||
- str: A duration string like "2 days", "5 hours", "30 minutes", or a combination.
|
||||
- int/float: Number representing seconds.
|
||||
- tuple/list: A tuple or list in the format (days, hours, minutes, seconds).
|
||||
|
||||
Returns:
|
||||
timedelta: A timedelta object corresponding to the input value.
|
||||
duration: A Duration object corresponding to the input value.
|
||||
|
||||
Raises:
|
||||
ValueError: If the input format is not supported.
|
||||
@ -233,18 +250,21 @@ def to_duration(
|
||||
>>> to_duration((1, 2, 30, 15))
|
||||
timedelta(days=1, seconds=90315)
|
||||
"""
|
||||
if isinstance(input_value, timedelta):
|
||||
if isinstance(input_value, Duration):
|
||||
return input_value
|
||||
|
||||
if isinstance(input_value, timedelta):
|
||||
return pendulum.duration(seconds=input_value.total_seconds())
|
||||
|
||||
if isinstance(input_value, (int, float)):
|
||||
# Handle integers or floats as seconds
|
||||
return timedelta(seconds=input_value)
|
||||
return pendulum.duration(seconds=input_value)
|
||||
|
||||
elif isinstance(input_value, (tuple, list)):
|
||||
# Handle tuple or list: (days, hours, minutes, seconds)
|
||||
if len(input_value) == 4:
|
||||
days, hours, minutes, seconds = input_value
|
||||
return timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)
|
||||
return pendulum.duration(days=days, hours=hours, minutes=minutes, seconds=seconds)
|
||||
else:
|
||||
error_msg = f"Expected a tuple or list of length 4, got {len(input_value)}"
|
||||
logger.error(error_msg)
|
||||
@ -340,7 +360,7 @@ def to_timezone(
|
||||
>>> to_timezone(utc_offset=5.5, as_string=True)
|
||||
'UTC+05:30'
|
||||
|
||||
>>> to_timezone(location={40.7128, -74.0060})
|
||||
>>> to_timezone(location=(40.7128, -74.0060))
|
||||
<Timezone [America/New_York]>
|
||||
|
||||
>>> to_timezone()
|
||||
|
@ -427,9 +427,9 @@ def prepare_visualize(
|
||||
report.generate_pdf()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Example usage
|
||||
report = VisualizationReport("example_report.pdf")
|
||||
def generate_example_report(filename: str = "example_report.pdf") -> None:
|
||||
"""Generate example visualization report."""
|
||||
report = VisualizationReport(filename)
|
||||
x_hours = 0 # Define x-axis start values (e.g., hours)
|
||||
|
||||
# Group 1: Adding charts to be displayed on the same page
|
||||
@ -502,3 +502,7 @@ if __name__ == "__main__":
|
||||
|
||||
# Generate the PDF report
|
||||
report.generate_pdf()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_example_report()
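A minimal call sketch; the import path assumes the module is importable as akkudoktoreos.utils.visualize:

from akkudoktoreos.utils.visualize import generate_example_report

generate_example_report(filename="example_report.pdf")  # writes the PDF report to the given path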
|
||||
|
@ -2,13 +2,13 @@
|
||||
|
||||
import io
|
||||
import pickle
|
||||
from datetime import date, datetime, time, timedelta
|
||||
from datetime import date, datetime, timedelta
|
||||
from time import sleep
|
||||
|
||||
import pytest
|
||||
|
||||
from akkudoktoreos.utils.cacheutil import CacheFileStore, cache_in_file
|
||||
from akkudoktoreos.utils.datetimeutil import to_datetime
|
||||
from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime, to_duration
|
||||
|
||||
# -----------------------------
|
||||
# CacheFileStore
|
||||
@ -18,24 +18,63 @@ from akkudoktoreos.utils.datetimeutil import to_datetime
|
||||
@pytest.fixture
|
||||
def cache_store():
|
||||
"""A pytest fixture that creates a new CacheFileStore instance for testing."""
|
||||
return CacheFileStore()
|
||||
cache = CacheFileStore()
|
||||
cache.clear(clear_all=True)
|
||||
assert len(cache._store) == 0
|
||||
return cache
|
||||
|
||||
|
||||
def test_generate_cache_file_key(cache_store):
|
||||
"""Test cache file key generation based on URL and date."""
|
||||
key = "http://example.com"
|
||||
until_dt = to_datetime("2024-10-01").date()
|
||||
cache_file_key, cache_file_until_dt = cache_store._generate_cache_file_key(key, until_dt)
|
||||
|
||||
# Provide until date - assure until_dt is used.
|
||||
until_dt = to_datetime("2024-10-01")
|
||||
cache_file_key, cache_file_until_dt, ttl_duration = cache_store._generate_cache_file_key(
|
||||
key=key, until_datetime=until_dt
|
||||
)
|
||||
assert cache_file_key is not None
|
||||
assert cache_file_until_dt == until_dt
|
||||
assert compare_datetimes(cache_file_until_dt, until_dt).equal
|
||||
|
||||
# Provide until date again - assure same key is generated.
|
||||
cache_file_key1, cache_file_until_dt1, ttl_duration1 = cache_store._generate_cache_file_key(
|
||||
key=key, until_datetime=until_dt
|
||||
)
|
||||
assert cache_file_key1 == cache_file_key
|
||||
assert compare_datetimes(cache_file_until_dt1, until_dt).equal
|
||||
|
||||
# Provide no until date - assure today EOD is used.
|
||||
until_dt = datetime.combine(date.today(), time.max)
|
||||
cache_file_key, cache_file_until_dt = cache_store._generate_cache_file_key(key, None)
|
||||
assert cache_file_until_dt == until_dt
|
||||
cache_file_key1, cache_file_until_dt1 = cache_store._generate_cache_file_key(key, until_dt)
|
||||
assert cache_file_key == cache_file_key1
|
||||
assert cache_file_until_dt == until_dt
|
||||
no_until_dt = to_datetime().end_of("day")
|
||||
cache_file_key, cache_file_until_dt, ttl_duration = cache_store._generate_cache_file_key(key)
|
||||
assert cache_file_key is not None
|
||||
assert compare_datetimes(cache_file_until_dt, no_until_dt).equal
|
||||
|
||||
# Provide with_ttl - assure until_dt is used.
|
||||
until_dt = to_datetime().add(hours=1)
|
||||
cache_file_key, cache_file_until_dt, ttl_duration = cache_store._generate_cache_file_key(
|
||||
key, with_ttl="1 hour"
|
||||
)
|
||||
assert cache_file_key is not None
|
||||
assert compare_datetimes(cache_file_until_dt, until_dt).approximately_equal
|
||||
assert ttl_duration == to_duration("1 hour")
|
||||
|
||||
# Provide with_ttl again - assure same key is generated.
|
||||
until_dt = to_datetime().add(hours=1)
|
||||
cache_file_key1, cache_file_until_dt1, ttl_duration1 = cache_store._generate_cache_file_key(
|
||||
key=key, with_ttl="1 hour"
|
||||
)
|
||||
assert cache_file_key1 == cache_file_key
|
||||
assert compare_datetimes(cache_file_until_dt1, until_dt).approximately_equal
|
||||
assert ttl_duration1 == to_duration("1 hour")
|
||||
|
||||
# Provide different with_ttl - assure different key is generated.
|
||||
until_dt = to_datetime().add(hours=1, minutes=1)
|
||||
cache_file_key2, cache_file_until_dt2, ttl_duration2 = cache_store._generate_cache_file_key(
|
||||
key=key, with_ttl="1 hour 1 minute"
|
||||
)
|
||||
assert cache_file_key2 != cache_file_key
|
||||
assert compare_datetimes(cache_file_until_dt2, until_dt).approximately_equal
|
||||
assert ttl_duration2 == to_duration("1 hour 1 minute")
|
||||
|
||||
|
||||
def test_get_file_path(cache_store):
|
||||
@ -46,6 +85,77 @@ def test_get_file_path(cache_store):
|
||||
assert file_path is not None
|
||||
|
||||
|
||||
def test_until_datetime_by_options(cache_store):
|
||||
"""Test until datetime calculation based on options."""
|
||||
now = to_datetime()
|
||||
|
||||
# Test with until_datetime
|
||||
result, ttl_duration = cache_store._until_datetime_by_options(until_datetime=now)
|
||||
assert result == now
|
||||
assert ttl_duration is None
|
||||
|
||||
# -- From now on we expect a until_datetime in one hour
|
||||
ttl_duration_expected = to_duration("1 hour")
|
||||
|
||||
# Test with with_ttl as timedelta
|
||||
until_datetime_expected = to_datetime().add(hours=1)
|
||||
ttl = timedelta(hours=1)
|
||||
result, ttl_duration = cache_store._until_datetime_by_options(with_ttl=ttl)
|
||||
assert compare_datetimes(result, until_datetime_expected).approximately_equal
assert ttl_duration == ttl_duration_expected

# Test with with_ttl as int (seconds)
until_datetime_expected = to_datetime().add(hours=1)
ttl_seconds = 3600
result, ttl_duration = cache_store._until_datetime_by_options(with_ttl=ttl_seconds)
assert compare_datetimes(result, until_datetime_expected).approximately_equal
assert ttl_duration == ttl_duration_expected

# Test with with_ttl as string ("1 hour")
until_datetime_expected = to_datetime().add(hours=1)
ttl_string = "1 hour"
result, ttl_duration = cache_store._until_datetime_by_options(with_ttl=ttl_string)
assert compare_datetimes(result, until_datetime_expected).approximately_equal
assert ttl_duration == ttl_duration_expected

# -- From now on we expect a until_datetime today at end of day
until_datetime_expected = to_datetime().end_of("day")
ttl_duration_expected = None

# Test default case (end of today)
result, ttl_duration = cache_store._until_datetime_by_options()
assert compare_datetimes(result, until_datetime_expected).equal
assert ttl_duration == ttl_duration_expected

# -- From now on we expect a until_datetime in one day at end of day
until_datetime_expected = to_datetime().add(days=1).end_of("day")
assert ttl_duration == ttl_duration_expected

# Test with until_date as date
until_date = date.today() + timedelta(days=1)
result, ttl_duration = cache_store._until_datetime_by_options(until_date=until_date)
assert compare_datetimes(result, until_datetime_expected).equal
assert ttl_duration == ttl_duration_expected

# -- Test with multiple options (until_datetime takes precedence)
specific_datetime = to_datetime().add(days=2)
result, ttl_duration = cache_store._until_datetime_by_options(
until_date=to_datetime().add(days=1).date(),
until_datetime=specific_datetime,
with_ttl=ttl,
)
assert compare_datetimes(result, specific_datetime).equal
assert ttl_duration is None

# Test with invalid inputs
with pytest.raises(ValueError):
cache_store._until_datetime_by_options(until_date="invalid-date")
with pytest.raises(ValueError):
cache_store._until_datetime_by_options(with_ttl="invalid-ttl")
with pytest.raises(ValueError):
cache_store._until_datetime_by_options(until_datetime="invalid-datetime")

def test_create_cache_file(cache_store):
"""Test the creation of a cache file and ensure it is stored correctly."""
# Create a cache file for today's date
@@ -145,7 +255,7 @@ def test_clear_cache_files_by_date(cache_store):
assert cache_store.get("file2") is cache_file2

# Clear cache files that are older than today
cache_store.clear(before_datetime=datetime.combine(date.today(), time.min))
cache_store.clear(before_datetime=to_datetime().start_of("day"))

# Ensure the files are in the store
assert cache_store.get("file1") is cache_file1
@@ -228,7 +338,7 @@ def test_cache_in_file_decorator_caches_function_result(cache_store):

# Check if the result was written to the cache file
key = next(iter(cache_store._store))
cache_file = cache_store._store[key][0]
cache_file = cache_store._store[key].cache_file
assert cache_file is not None

# Assert correct content was written to the file
@@ -248,12 +358,12 @@ def test_cache_in_file_decorator_uses_cache(cache_store):
return "New result"

# Call the decorated function (should store result in cache)
result = my_function(until_date=datetime.now() + timedelta(days=1))
result = my_function(until_date=to_datetime().add(days=1))
assert result == "New result"

# Assert result was written to cache file
key = next(iter(cache_store._store))
cache_file = cache_store._store[key][0]
cache_file = cache_store._store[key].cache_file
assert cache_file is not None
cache_file.seek(0) # Move to the start of the file
assert cache_file.read() == result
@@ -264,7 +374,7 @@ def test_cache_in_file_decorator_uses_cache(cache_store):
cache_file.write(result2)

# Call the decorated function again (should get result from cache)
result = my_function(until_date=datetime.now() + timedelta(days=1))
result = my_function(until_date=to_datetime().add(days=1))
assert result == result2

@@ -279,7 +389,7 @@ def test_cache_in_file_decorator_forces_update_data(cache_store):
def my_function(until_date=None):
return "New result"

until_date = datetime.now() + timedelta(days=1)
until_date = to_datetime().add(days=1).date()

# Call the decorated function (should store result in cache)
result1 = "New result"
@@ -288,7 +398,7 @@ def test_cache_in_file_decorator_forces_update_data(cache_store):

# Assert result was written to cache file
key = next(iter(cache_store._store))
cache_file = cache_store._store[key][0]
cache_file = cache_store._store[key].cache_file
assert cache_file is not None
cache_file.seek(0) # Move to the start of the file
assert cache_file.read() == result
@@ -297,6 +407,8 @@ def test_cache_in_file_decorator_forces_update_data(cache_store):
result2 = "Cached result"
cache_file.seek(0)
cache_file.write(result2)
cache_file.seek(0) # Move to the start of the file
assert cache_file.read() == result2

# Call the decorated function again with force update (should get result from function)
result = my_function(until_date=until_date, force_update=True)  # type: ignore[call-arg]
@@ -309,9 +421,6 @@ def test_cache_in_file_decorator_forces_update_data(cache_store):

def test_cache_in_file_handles_ttl(cache_store):
"""Test that the cache_infile decorator handles the with_ttl parameter."""
# Clear store to assure it is empty
cache_store.clear(clear_all=True)
assert len(cache_store._store) == 0

# Define a simple function to decorate
@cache_in_file(mode="w+")
@@ -319,26 +428,37 @@ def test_cache_in_file_handles_ttl(cache_store):
return "New result"

# Call the decorated function
result = my_function(with_ttl="1 second")  # type: ignore[call-arg]
result1 = my_function(with_ttl="1 second")  # type: ignore[call-arg]
assert result1 == "New result"
assert len(cache_store._store) == 1
key = list(cache_store._store.keys())[0]

# Overwrite cache file
# Assert result was written to cache file
key = next(iter(cache_store._store))
cache_file = cache_store._store[key][0]
cache_file = cache_store._store[key].cache_file
assert cache_file is not None
cache_file.seek(0) # Move to the start of the file
cache_file.write("Modified result")
cache_file.seek(0) # Move to the start of the file
assert cache_file.read() == "Modified result"
assert cache_file.read() == result1

# Modify cache file
result2 = "Cached result"
cache_file.seek(0)
cache_file.write(result2)
cache_file.seek(0) # Move to the start of the file
assert cache_file.read() == result2

# Call the decorated function again
result = my_function(with_ttl="1 second")  # type: ignore[call-arg]
assert result == "Modified result"
cache_file.seek(0) # Move to the start of the file
assert cache_file.read() == result2
assert result == result2

# Wait one second to let the cache time out
sleep(1)
sleep(2)

# Call again - cache should be timed out
result = my_function(with_ttl="1 second")  # type: ignore[call-arg]
assert result == "New result"
assert result == result1

def test_cache_in_file_handles_bytes_return(cache_store):
@@ -357,7 +477,7 @@ def test_cache_in_file_handles_bytes_return(cache_store):

# Check if the binary data was written to the cache file
key = next(iter(cache_store._store))
cache_file = cache_store._store[key][0]
cache_file = cache_store._store[key].cache_file
assert len(cache_store._store) == 1
assert cache_file is not None
cache_file.seek(0)
@@ -367,5 +487,5 @@ def test_cache_in_file_handles_bytes_return(cache_store):
# Access cache
result = my_function(until_date=datetime.now() + timedelta(days=1))
assert len(cache_store._store) == 1
assert cache_store._store[key][0] is not None
assert cache_store._store[key].cache_file is not None
assert result1 == result
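A minimal usage sketch of the cache_in_file decorator as it is exercised by the tests above. The decorator name and the until_date, with_ttl and force_update keyword arguments are taken from the test code; the fetch_data function itself is purely illustrative and not part of the repository.

# Sketch only, assuming the decorator accepts the same keyword arguments as in the tests.
from akkudoktoreos.utils.cacheutil import cache_in_file

@cache_in_file(mode="w+")
def fetch_data(until_date=None):
    # Expensive work whose result is cached until `until_date` or until the TTL expires.
    return "expensive result"

# First call computes and stores the result; later calls reuse the cached file.
first = fetch_data(with_ttl="1 hour")
# A forced update bypasses the cached file and re-runs the function.
fresh = fetch_data(with_ttl="1 hour", force_update=True)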
@@ -346,6 +346,127 @@ class TestDataSequence:
assert array[1] == 7
assert array[2] == last_datetime.day

def test_key_to_array_linear_interpolation(self, sequence):
"""Test key_to_array with linear interpolation for numeric data."""
interval = to_duration("1 hour")
record1 = self.create_test_record(pendulum.datetime(2023, 11, 6, 0), 0.8)
record2 = self.create_test_record(pendulum.datetime(2023, 11, 6, 2), 1.0)  # Gap of 2 hours
sequence.insert_by_datetime(record1)
sequence.insert_by_datetime(record2)

array = sequence.key_to_array(
key="data_value",
start_datetime=pendulum.datetime(2023, 11, 6),
end_datetime=pendulum.datetime(2023, 11, 6, 3),
interval=interval,
fill_method="linear",
)
assert len(array) == 3
assert array[0] == 0.8
assert array[1] == 0.9  # Interpolated value
assert array[2] == 1.0

def test_key_to_array_ffill(self, sequence):
"""Test key_to_array with forward filling for missing values."""
interval = to_duration("1 hour")
record1 = self.create_test_record(pendulum.datetime(2023, 11, 6, 0), 0.8)
record2 = self.create_test_record(pendulum.datetime(2023, 11, 6, 2), 1.0)
sequence.insert_by_datetime(record1)
sequence.insert_by_datetime(record2)

array = sequence.key_to_array(
key="data_value",
start_datetime=pendulum.datetime(2023, 11, 6),
end_datetime=pendulum.datetime(2023, 11, 6, 3),
interval=interval,
fill_method="ffill",
)
assert len(array) == 3
assert array[0] == 0.8
assert array[1] == 0.8  # Forward-filled value
assert array[2] == 1.0

def test_key_to_array_bfill(self, sequence):
"""Test key_to_array with backward filling for missing values."""
interval = to_duration("1 hour")
record1 = self.create_test_record(pendulum.datetime(2023, 11, 6, 0), 0.8)
record2 = self.create_test_record(pendulum.datetime(2023, 11, 6, 2), 1.0)
sequence.insert_by_datetime(record1)
sequence.insert_by_datetime(record2)

array = sequence.key_to_array(
key="data_value",
start_datetime=pendulum.datetime(2023, 11, 6),
end_datetime=pendulum.datetime(2023, 11, 6, 3),
interval=interval,
fill_method="bfill",
)
assert len(array) == 3
assert array[0] == 0.8
assert array[1] == 1.0  # Backward-filled value
assert array[2] == 1.0

def test_key_to_array_with_truncation(self, sequence):
"""Test truncation behavior in key_to_array."""
interval = to_duration("1 hour")
record1 = self.create_test_record(pendulum.datetime(2023, 11, 5, 23), 0.8)
record2 = self.create_test_record(pendulum.datetime(2023, 11, 6, 1), 1.0)
sequence.insert_by_datetime(record1)
sequence.insert_by_datetime(record2)

array = sequence.key_to_array(
key="data_value",
start_datetime=pendulum.datetime(2023, 11, 6),
end_datetime=pendulum.datetime(2023, 11, 6, 2),
interval=interval,
)
assert len(array) == 2
assert array[0] == 0.9  # Interpolated from previous day
assert array[1] == 1.0

def test_key_to_array_with_none(self, sequence):
"""Test handling of empty series in key_to_array."""
interval = to_duration("1 hour")
array = sequence.key_to_array(
key="data_value",
start_datetime=pendulum.datetime(2023, 11, 6),
end_datetime=pendulum.datetime(2023, 11, 6, 3),
interval=interval,
)
assert isinstance(array, np.ndarray)
assert np.all(array == None)

def test_key_to_array_with_one(self, sequence):
"""Test handling of one element series in key_to_array."""
interval = to_duration("1 hour")
record1 = self.create_test_record(pendulum.datetime(2023, 11, 5, 23), 0.8)
sequence.insert_by_datetime(record1)

array = sequence.key_to_array(
key="data_value",
start_datetime=pendulum.datetime(2023, 11, 6),
end_datetime=pendulum.datetime(2023, 11, 6, 2),
interval=interval,
)
assert len(array) == 2
assert array[0] == 0.8  # Interpolated from previous day
assert array[1] == 0.8

def test_key_to_array_invalid_fill_method(self, sequence):
"""Test invalid fill_method raises an error."""
interval = to_duration("1 hour")
record1 = self.create_test_record(pendulum.datetime(2023, 11, 6, 0), 0.8)
sequence.insert_by_datetime(record1)

with pytest.raises(ValueError, match="Unsupported fill method: invalid"):
sequence.key_to_array(
key="data_value",
start_datetime=pendulum.datetime(2023, 11, 6),
end_datetime=pendulum.datetime(2023, 11, 6, 1),
interval=interval,
fill_method="invalid",
)

def test_to_datetimeindex(self, sequence2):
record1 = self.create_test_record(datetime(2023, 11, 5), 0.8)
record2 = self.create_test_record(datetime(2023, 11, 6), 0.9)
@@ -531,10 +652,9 @@ class TestDataImportProvider:
],
)
def test_import_datetimes(self, provider, start_datetime, value_count, expected_mapping_count):
ems_eos = get_ems()
ems_eos.set_start_datetime(to_datetime(start_datetime, in_timezone="Europe/Berlin"))
start_datetime = to_datetime(start_datetime, in_timezone="Europe/Berlin")

value_datetime_mapping = provider.import_datetimes(value_count)
value_datetime_mapping = provider.import_datetimes(start_datetime, value_count)

assert len(value_datetime_mapping) == expected_mapping_count

@@ -551,11 +671,10 @@ class TestDataImportProvider:
self, set_other_timezone, provider, start_datetime, value_count, expected_mapping_count
):
original_tz = set_other_timezone("Etc/UTC")
ems_eos = get_ems()
ems_eos.set_start_datetime(to_datetime(start_datetime, in_timezone="Europe/Berlin"))
assert ems_eos.start_datetime.timezone.name == "Europe/Berlin"
start_datetime = to_datetime(start_datetime, in_timezone="Europe/Berlin")
assert start_datetime.timezone.name == "Europe/Berlin"

value_datetime_mapping = provider.import_datetimes(value_count)
value_datetime_mapping = provider.import_datetimes(start_datetime, value_count)

assert len(value_datetime_mapping) == expected_mapping_count

@@ -636,7 +755,7 @@ class TestDataContainer:
del container_with_providers["data_value"]
series = container_with_providers["data_value"]
assert series.name == "data_value"
assert series.tolist() == [None, None, None]
assert series.tolist() == []

def test_delitem_non_existing_key(self, container_with_providers):
with pytest.raises(KeyError, match="Key 'non_existent_key' not found"):
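An illustration of the fill semantics the key_to_array tests above expect, written with plain pandas. This is not the DataSequence implementation, only the same idea applied to a toy series with a one-hour gap.

# Sketch only: demonstrate linear / ffill / bfill gap handling on a toy series.
import numpy as np
import pandas as pd

idx = pd.date_range("2023-11-06 00:00", "2023-11-06 02:00", freq="1h")
series = pd.Series([0.8, np.nan, 1.0], index=idx)

linear = series.interpolate(method="linear")  # 0.8, 0.9, 1.0
ffill = series.ffill()                        # 0.8, 0.8, 1.0
bfill = series.bfill()                        # 0.8, 1.0, 1.0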
@@ -19,7 +19,7 @@ from akkudoktoreos.utils.datetimeutil import (

# Test cases for valid pendulum.duration inputs
@pytest.mark.parametrize(
"test_case, local_timezone, date_input, as_string, in_timezone, to_naiv, to_maxtime, expected_output",
"test_case, local_timezone, date_input, as_string, in_timezone, to_naiv, to_maxtime, expected_output, expected_approximately",
[
# ---------------------------------------
# from string to pendulum.datetime object
@@ -34,6 +34,7 @@ from akkudoktoreos.utils.datetimeutil import (
None,
False,
pendulum.datetime(2024, 1, 1, 0, 0, 0, tz="Etc/UTC"),
False,
),
(
"TC002",
@@ -44,6 +45,7 @@ from akkudoktoreos.utils.datetimeutil import (
None,
False,
pendulum.datetime(2024, 1, 1, 0, 0, 0, tz="Europe/Berlin"),
False,
),
(
"TC003",
@@ -54,6 +56,7 @@ from akkudoktoreos.utils.datetimeutil import (
None,
False,
pendulum.datetime(2023, 12, 31, 23, 0, 0, tz="Etc/UTC"),
False,
),
(
"TC004",
@@ -64,6 +67,7 @@ from akkudoktoreos.utils.datetimeutil import (
None,
False,
pendulum.datetime(2024, 1, 1, 0, 0, 0, tz="Europe/Paris"),
False,
),
(
"TC005",
@@ -74,6 +78,7 @@ from akkudoktoreos.utils.datetimeutil import (
None,
False,
pendulum.datetime(2024, 1, 1, 1, 0, 0, tz="Europe/Berlin"),
False,
),
(
"TC006",
@@ -84,6 +89,7 @@ from akkudoktoreos.utils.datetimeutil import (
None,
False,
pendulum.datetime(2023, 12, 31, 23, 0, 0, tz="Etc/UTC"),
False,
),
(
"TC007",
@@ -102,6 +108,7 @@ from akkudoktoreos.utils.datetimeutil import (
0,
tz="Atlantic/Canary",
),
False,
),
(
"TC008",
@@ -112,6 +119,7 @@ from akkudoktoreos.utils.datetimeutil import (
None,
False,
pendulum.datetime(2024, 1, 1, 13, 0, 0, tz="Europe/Berlin"),
False,
),
(
"TC009",
@@ -122,6 +130,7 @@ from akkudoktoreos.utils.datetimeutil import (
None,
False,
pendulum.datetime(2024, 1, 1, 11, 0, 0, tz="Etc/UTC"),
False,
),
# - with timezone
(
@@ -133,6 +142,7 @@ from akkudoktoreos.utils.datetimeutil import (
None,
False,
pendulum.datetime(2024, 2, 2, 0, 0, 0, tz="Europe/Berlin"),
False,
),
(
"TC011",
@@ -143,6 +153,7 @@ from akkudoktoreos.utils.datetimeutil import (
None,
None,
pendulum.datetime(2024, 3, 3, 10, 20, 30, 0, tz="Europe/Berlin"),
False,
),
(
"TC012",
@@ -153,6 +164,7 @@ from akkudoktoreos.utils.datetimeutil import (
False,
None,
pendulum.datetime(2024, 4, 4, 10, 20, 30, 0, tz="Europe/Berlin"),
False,
),
(
"TC013",
@@ -163,6 +175,7 @@ from akkudoktoreos.utils.datetimeutil import (
True,
None,
pendulum.naive(2024, 5, 5, 10, 20, 30, 0),
False,
),
# - without local timezone as UTC
(
@@ -174,6 +187,7 @@ from akkudoktoreos.utils.datetimeutil import (
None,
False,
pendulum.datetime(2024, 2, 2, 0, 0, 0, tz="UTC"),
False,
),
(
"TC015",
@@ -184,6 +198,7 @@ from akkudoktoreos.utils.datetimeutil import (
None,
None,
pendulum.datetime(2024, 3, 3, 10, 20, 30, 0, tz="UTC"),
False,
),
# ---------------------------------------
# from pendulum.datetime to pendulum.datetime object
@@ -197,6 +212,7 @@ from akkudoktoreos.utils.datetimeutil import (
None,
False,
pendulum.datetime(2024, 4, 4, 0, 0, 0, tz="Etc/UTC"),
False,
),
(
"TC017",
@@ -207,6 +223,7 @@ from akkudoktoreos.utils.datetimeutil import (
None,
False,
pendulum.datetime(2024, 4, 4, 3, 0, 0, tz="Europe/Berlin"),
False,
),
(
"TC018",
@@ -217,6 +234,7 @@ from akkudoktoreos.utils.datetimeutil import (
None,
False,
pendulum.datetime(2024, 4, 4, 3, 0, 0, tz="Europe/Berlin"),
False,
),
(
"TC019",
@@ -227,6 +245,7 @@ from akkudoktoreos.utils.datetimeutil import (
None,
False,
pendulum.datetime(2024, 4, 4, 0, 0, 0, tz="Etc/UTC"),
False,
),
# ---------------------------------------
# from string to UTC string
@@ -242,6 +261,7 @@ from akkudoktoreos.utils.datetimeutil import (
None,
None,
"2023-11-06T00:00:00Z",
False,
),
# local timezone "Europe/Berlin"
(
@@ -253,6 +273,7 @@ from akkudoktoreos.utils.datetimeutil import (
None,
None,
"2023-11-05T23:00:00Z",
False,
),
# - no microseconds
(
@@ -264,6 +285,7 @@ from akkudoktoreos.utils.datetimeutil import (
None,
None,
"2024-10-29T23:00:00Z",
False,
),
(
"TC023",
@@ -274,6 +296,7 @@ from akkudoktoreos.utils.datetimeutil import (
None,
None,
"2024-10-30T00:00:00Z",
False,
),
# - with microseconds
(
@@ -285,6 +308,23 @@ from akkudoktoreos.utils.datetimeutil import (
None,
None,
"2024-10-07T08:20:30Z",
False,
),
# ---------------------------------------
# from None to pendulum.datetime object
# ---------------------------------------
# - no timezone
# local timezone
(
"TC025",
None,
None,
None,
None,
None,
None,
pendulum.now(),
True,
),
],
)
@@ -298,6 +338,7 @@ def test_to_datetime(
to_naiv,
to_maxtime,
expected_output,
expected_approximately,
):
"""Test pendulum.datetime conversion with valid inputs."""
set_other_timezone(local_timezone)
@@ -326,7 +367,10 @@ def test_to_datetime(
# print(f"Expected: {expected_output} tz={expected_output.timezone}")
# print(f"Result: {result} tz={result.timezone}")
# print(f"Compare: {compare}")
assert compare.equal == True
if expected_approximately:
assert compare.time_diff < 200
else:
assert compare.equal == True

# -----------------------------
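A sketch of how the new TC025 case is meant to be checked: to_datetime() with no input returns "now", so equality can only be asserted approximately via compare_datetimes. The function names come from akkudoktoreos.utils.datetimeutil; the 200 bound simply mirrors the test above, and its unit is whatever compare_datetimes reports as time_diff.

# Sketch only, assuming compare_datetimes exposes .time_diff as in the test.
from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime

result = to_datetime()                      # current time in the local timezone
compare = compare_datetimes(result, to_datetime())
assert compare.time_diff < 200              # approximately equal, not exactly equal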
@@ -2,7 +2,9 @@ import json
from pathlib import Path
from unittest.mock import Mock, patch

import numpy as np
import pytest
import requests

from akkudoktoreos.core.ems import get_ems
from akkudoktoreos.prediction.elecpriceakkudoktor import (
@@ -66,6 +68,35 @@ def test_invalid_provider(elecprice_provider, monkeypatch):
# ------------------------------------------------

@patch("akkudoktoreos.prediction.elecpriceakkudoktor.logger.error")
def test_validate_data_invalid_format(mock_logger, elecprice_provider):
"""Test validation for invalid Akkudoktor data."""
invalid_data = '{"invalid": "data"}'
with pytest.raises(ValueError):
elecprice_provider._validate_data(invalid_data)
mock_logger.assert_called_once_with(mock_logger.call_args[0][0])

def test_calculate_weighted_mean(elecprice_provider):
"""Test calculation of weighted mean for electricity prices."""
elecprice_provider.elecprice_8days = np.random.rand(24, 8) * 100
price_mean = elecprice_provider._calculate_weighted_mean(day_of_week=2, hour=10)
assert isinstance(price_mean, float)
assert not np.isnan(price_mean)
expected = np.array(
[
[1.0, 0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 1.0],
[0.25, 1.0, 0.5, 0.125, 0.0625, 0.03125, 0.015625, 1.0],
[0.125, 0.5, 1.0, 0.25, 0.0625, 0.03125, 0.015625, 1.0],
[0.0625, 0.125, 0.5, 1.0, 0.25, 0.03125, 0.015625, 1.0],
[0.0625, 0.125, 0.25, 0.5, 1.0, 0.03125, 0.015625, 1.0],
[0.015625, 0.03125, 0.0625, 0.125, 0.5, 1.0, 0.25, 1.0],
[0.015625, 0.03125, 0.0625, 0.125, 0.25, 0.5, 1.0, 1.0],
]
)
np.testing.assert_array_equal(elecprice_provider.elecprice_8days_weights_day_of_week, expected)

@patch("requests.get")
def test_request_forecast(mock_get, elecprice_provider, sample_akkudoktor_1_json):
"""Test requesting forecast from Akkudoktor."""
@@ -110,7 +141,7 @@ def test_update_data(mock_get, elecprice_provider, sample_akkudoktor_1_json, cac

# Assert: Verify the result is as expected
mock_get.assert_called_once()
assert len(elecprice_provider) == 25
assert len(elecprice_provider) == 49  # prediction hours + 1

# Assert we get prediction_hours price values by resampling
np_price_array = elecprice_provider.key_to_array(
@@ -124,6 +155,63 @@ def test_update_data(mock_get, elecprice_provider, sample_akkudoktor_1_json, cac
# f_out.write(elecprice_provider.to_json())

@patch("requests.get")
def test_update_data_with_incomplete_forecast(mock_get, elecprice_provider):
"""Test `_update_data` with incomplete or missing forecast data."""
incomplete_data: dict = {"meta": {}, "values": []}
mock_response = Mock()
mock_response.status_code = 200
mock_response.content = json.dumps(incomplete_data)
mock_get.return_value = mock_response
with pytest.raises(ValueError):
elecprice_provider._update_data(force_update=True)

@pytest.mark.parametrize(
"status_code, exception",
[(400, requests.exceptions.HTTPError), (500, requests.exceptions.HTTPError), (200, None)],
)
@patch("requests.get")
def test_request_forecast_status_codes(
mock_get, elecprice_provider, sample_akkudoktor_1_json, status_code, exception
):
"""Test handling of various API status codes."""
mock_response = Mock()
mock_response.status_code = status_code
mock_response.content = json.dumps(sample_akkudoktor_1_json)
mock_response.raise_for_status.side_effect = (
requests.exceptions.HTTPError if exception else None
)
mock_get.return_value = mock_response
if exception:
with pytest.raises(exception):
elecprice_provider._request_forecast()
else:
elecprice_provider._request_forecast()

@patch("akkudoktoreos.utils.cacheutil.CacheFileStore")
def test_cache_integration(mock_cache, elecprice_provider):
"""Test caching of 8-day electricity price data."""
mock_cache_instance = mock_cache.return_value
mock_cache_instance.get.return_value = None  # Simulate no cache
elecprice_provider._update_data(force_update=True)
mock_cache_instance.create.assert_called_once()
mock_cache_instance.get.assert_called_once()

def test_key_to_array_resampling(elecprice_provider):
"""Test resampling of forecast data to NumPy array."""
elecprice_provider.update_data(force_update=True)
array = elecprice_provider.key_to_array(
key="elecprice_marketprice",
start_datetime=elecprice_provider.start_datetime,
end_datetime=elecprice_provider.end_datetime,
)
assert isinstance(array, np.ndarray)
assert len(array) == elecprice_provider.total_hours

# ------------------------------------------------
# Development Akkudoktor
# ------------------------------------------------
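A rough sketch of the gap-filling idea behind elecprice_8days: hours that akkudoktor.net does not deliver (it only provides 24 values) are estimated from the same hour of the previous days. The real provider uses _calculate_weighted_mean with per-weekday weights as tested above; this sketch only shows the plain, unweighted version of the idea.

# Sketch only: estimate a missing hour from the same hour of the last 7 days.
import numpy as np

elecprice_8days = np.random.rand(24, 8) * 100   # hour x day history, shaped as in the test
hour = 10
estimate = elecprice_8days[hour, :7].mean()      # plain mean over the previous 7 days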
@@ -96,7 +96,9 @@ def test_import(elecprice_provider, sample_import_1_json, start_datetime, from_f
assert elecprice_provider.total_hours is not None
assert compare_datetimes(elecprice_provider.start_datetime, ems_eos.start_datetime).equal
values = sample_import_1_json["elecprice_marketprice"]
value_datetime_mapping = elecprice_provider.import_datetimes(len(values))
value_datetime_mapping = elecprice_provider.import_datetimes(
ems_eos.start_datetime, len(values)
)
for i, mapping in enumerate(value_datetime_mapping):
assert i < len(elecprice_provider.records)
expected_datetime, expected_value_index = mapping
@@ -1,39 +0,0 @@
import pytest

from akkudoktoreos.prediction.load_aggregator import LoadAggregator

def test_initialization():
aggregator = LoadAggregator()
assert aggregator.prediction_hours == 24
assert aggregator.loads == {}

def test_add_load_valid():
aggregator = LoadAggregator(prediction_hours=3)
aggregator.add_load("Source1", [10.0, 20.0, 30.0])
assert aggregator.loads["Source1"] == [10.0, 20.0, 30.0]

def test_add_load_invalid_length():
aggregator = LoadAggregator(prediction_hours=3)
with pytest.raises(ValueError, match="Total load inconsistent lengths in arrays: Source1 2"):
aggregator.add_load("Source1", [10.0, 20.0])

def test_calculate_total_load_empty():
aggregator = LoadAggregator()
assert aggregator.calculate_total_load() == []

def test_calculate_total_load():
aggregator = LoadAggregator(prediction_hours=3)
aggregator.add_load("Source1", [10.0, 20.0, 30.0])
aggregator.add_load("Source2", [5.0, 15.0, 25.0])
assert aggregator.calculate_total_load() == [15.0, 35.0, 55.0]

def test_calculate_total_load_single_source():
aggregator = LoadAggregator(prediction_hours=3)
aggregator.add_load("Source1", [10.0, 20.0, 30.0])
assert aggregator.calculate_total_load() == [10.0, 20.0, 30.0]
@@ -6,18 +6,20 @@ import pytest

from akkudoktoreos.config.config import get_config
from akkudoktoreos.core.ems import get_ems
from akkudoktoreos.measurement.measurement import MeasurementDataRecord, get_measurement
from akkudoktoreos.prediction.loadakkudoktor import (
LoadAkkudoktor,
LoadAkkudoktorCommonSettings,
)
from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime, to_duration

config_eos = get_config()
ems_eos = get_ems()

@pytest.fixture
def load_provider(monkeypatch):
"""Fixture to create a LoadAkkudoktor instance."""
def load_provider():
"""Fixture to initialise the LoadAkkudoktor instance."""
settings = {
"load_provider": "LoadAkkudoktor",
"load_name": "Akkudoktor Profile",
@@ -27,6 +29,30 @@ def load_provider(monkeypatch):
return LoadAkkudoktor()

@pytest.fixture
def measurement_eos():
"""Fixture to initialise the Measurement instance."""
measurement = get_measurement()
load0_mr = 500
load1_mr = 500
dt = to_datetime("2024-01-01T00:00:00")
interval = to_duration("1 hour")
for i in range(25):
measurement.records.append(
MeasurementDataRecord(
date_time=dt,
measurement_load0_mr=load0_mr,
measurement_load1_mr=load1_mr,
)
)
dt += interval
load0_mr += 50
load1_mr += 50
assert compare_datetimes(measurement.min_datetime, to_datetime("2024-01-01T00:00:00")).equal
assert compare_datetimes(measurement.max_datetime, to_datetime("2024-01-02T00:00:00")).equal
return measurement

@pytest.fixture
def mock_load_profiles_file(tmp_path):
"""Fixture to create a mock load profiles file."""
@@ -97,3 +123,90 @@ def test_update_data(mock_load_data, load_provider):

# Validate that update_value is called
assert len(load_provider) > 0

def test_calculate_adjustment(load_provider, measurement_eos):
"""Test `_calculate_adjustment` for various scenarios."""
data_year_energy = np.random.rand(365, 2, 24)

# Call the method and validate results
weekday_adjust, weekend_adjust = load_provider._calculate_adjustment(data_year_energy)
assert weekday_adjust.shape == (24,)
assert weekend_adjust.shape == (24,)

data_year_energy = np.zeros((365, 2, 24))
weekday_adjust, weekend_adjust = load_provider._calculate_adjustment(data_year_energy)

assert weekday_adjust.shape == (24,)
expected = np.array(
[
100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0,
100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0,
100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0,
]
)
np.testing.assert_array_equal(weekday_adjust, expected)

assert weekend_adjust.shape == (24,)
expected = np.array(
[
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
]
)
np.testing.assert_array_equal(weekend_adjust, expected)

def test_load_provider_adjustments_with_mock_data(load_provider):
"""Test full integration of adjustments with mock data."""
with patch(
"akkudoktoreos.prediction.loadakkudoktor.LoadAkkudoktor._calculate_adjustment"
) as mock_adjust:
mock_adjust.return_value = (np.zeros(24), np.zeros(24))

# Test execution
load_provider._update_data()
assert mock_adjust.called
tests/test_measurement.py (new file, 218 lines)
@@ -0,0 +1,218 @@
import numpy as np
import pytest
from pendulum import datetime, duration

from akkudoktoreos.config.config import SettingsEOS
from akkudoktoreos.measurement.measurement import MeasurementDataRecord, get_measurement

@pytest.fixture
def measurement_eos():
"""Fixture to create a Measurement instance."""
measurement = get_measurement()
measurement.records = [
MeasurementDataRecord(
date_time=datetime(2023, 1, 1, hour=0),
measurement_load0_mr=100,
measurement_load1_mr=200,
),
MeasurementDataRecord(
date_time=datetime(2023, 1, 1, hour=1),
measurement_load0_mr=150,
measurement_load1_mr=250,
),
MeasurementDataRecord(
date_time=datetime(2023, 1, 1, hour=2),
measurement_load0_mr=200,
measurement_load1_mr=300,
),
MeasurementDataRecord(
date_time=datetime(2023, 1, 1, hour=3),
measurement_load0_mr=250,
measurement_load1_mr=350,
),
MeasurementDataRecord(
date_time=datetime(2023, 1, 1, hour=4),
measurement_load0_mr=300,
measurement_load1_mr=400,
),
MeasurementDataRecord(
date_time=datetime(2023, 1, 1, hour=5),
measurement_load0_mr=350,
measurement_load1_mr=450,
),
]
return measurement

def test_interval_count(measurement_eos):
"""Test interval count calculation."""
start = datetime(2023, 1, 1, 0)
end = datetime(2023, 1, 1, 3)
interval = duration(hours=1)

assert measurement_eos._interval_count(start, end, interval) == 3

def test_interval_count_invalid_end_before_start(measurement_eos):
"""Test interval count raises ValueError when end_datetime is before start_datetime."""
start = datetime(2023, 1, 1, 3)
end = datetime(2023, 1, 1, 0)
interval = duration(hours=1)

with pytest.raises(ValueError, match="end_datetime must be after start_datetime"):
measurement_eos._interval_count(start, end, interval)

def test_interval_count_invalid_non_positive_interval(measurement_eos):
"""Test interval count raises ValueError when interval is non-positive."""
start = datetime(2023, 1, 1, 0)
end = datetime(2023, 1, 1, 3)

with pytest.raises(ValueError, match="interval must be positive"):
measurement_eos._interval_count(start, end, duration(hours=0))

def test_energy_from_meter_readings_valid_input(measurement_eos):
"""Test _energy_from_meter_readings with valid inputs and proper alignment of load data."""
key = "measurement_load0_mr"
start_datetime = datetime(2023, 1, 1, 0)
end_datetime = datetime(2023, 1, 1, 5)
interval = duration(hours=1)

load_array = measurement_eos._energy_from_meter_readings(
key, start_datetime, end_datetime, interval
)

expected_load_array = np.array([50, 50, 50, 50, 50])  # Differences between consecutive readings
np.testing.assert_array_equal(load_array, expected_load_array)

def test_energy_from_meter_readings_empty_array(measurement_eos):
"""Test _energy_from_meter_readings with no data (empty array)."""
key = "measurement_load0_mr"
start_datetime = datetime(2023, 1, 1, 0)
end_datetime = datetime(2023, 1, 1, 5)
interval = duration(hours=1)

# Use empty records array
measurement_eos.records = []

load_array = measurement_eos._energy_from_meter_readings(
key, start_datetime, end_datetime, interval
)

# Expected: an array of zeros with one less than the number of intervals
expected_size = (
measurement_eos._interval_count(start_datetime, end_datetime + interval, interval) - 1
)
expected_load_array = np.zeros(expected_size)
np.testing.assert_array_equal(load_array, expected_load_array)

def test_energy_from_meter_readings_misaligned_array(measurement_eos):
"""Test _energy_from_meter_readings with misaligned array size."""
key = "measurement_load1_mr"
start_datetime = measurement_eos.min_datetime
end_datetime = measurement_eos.max_datetime
interval = duration(hours=1)

# Use misaligned array, latest interval set to 2 hours (instead of 1 hour)
measurement_eos.records[-1].date_time = datetime(2023, 1, 1, 6)

load_array = measurement_eos._energy_from_meter_readings(
key, start_datetime, end_datetime, interval
)

expected_load_array = np.array([50, 50, 50, 50, 25])  # Differences between consecutive readings
np.testing.assert_array_equal(load_array, expected_load_array)

def test_energy_from_meter_readings_partial_data(measurement_eos, caplog):
"""Test _energy_from_meter_readings with partial data (misaligned but empty array)."""
key = "measurement_load2_mr"
start_datetime = datetime(2023, 1, 1, 0)
end_datetime = datetime(2023, 1, 1, 5)
interval = duration(hours=1)

with caplog.at_level("DEBUG"):
load_array = measurement_eos._energy_from_meter_readings(
key, start_datetime, end_datetime, interval
)

expected_size = (
measurement_eos._interval_count(start_datetime, end_datetime + interval, interval) - 1
)
expected_load_array = np.zeros(expected_size)
np.testing.assert_array_equal(load_array, expected_load_array)

def test_energy_from_meter_readings_negative_interval(measurement_eos):
"""Test _energy_from_meter_readings with a negative interval."""
key = "measurement_load3_mr"
start_datetime = datetime(2023, 1, 1, 0)
end_datetime = datetime(2023, 1, 1, 5)
interval = duration(hours=-1)

with pytest.raises(ValueError, match="interval must be positive"):
measurement_eos._energy_from_meter_readings(key, start_datetime, end_datetime, interval)

def test_load_total(measurement_eos):
"""Test total load calculation."""
start = datetime(2023, 1, 1, 0)
end = datetime(2023, 1, 1, 2)
interval = duration(hours=1)

result = measurement_eos.load_total(start_datetime=start, end_datetime=end, interval=interval)

# Expected total load per interval
expected = np.array([100, 100])  # Differences between consecutive meter readings
np.testing.assert_array_equal(result, expected)

def test_load_total_no_data(measurement_eos):
"""Test total load calculation with no data."""
measurement_eos.records = []
start = datetime(2023, 1, 1, 0)
end = datetime(2023, 1, 1, 3)
interval = duration(hours=1)

result = measurement_eos.load_total(start_datetime=start, end_datetime=end, interval=interval)
expected = np.zeros(3)  # No data, so all intervals are zero
np.testing.assert_array_equal(result, expected)

def test_name_to_key(measurement_eos):
"""Test name_to_key functionality."""
settings = SettingsEOS(
measurement_load0_name="Household",
measurement_load1_name="Heat Pump",
)
measurement_eos.config.merge_settings(settings)

assert measurement_eos.name_to_key("Household", "measurement_load") == "measurement_load0_mr"
assert measurement_eos.name_to_key("Heat Pump", "measurement_load") == "measurement_load1_mr"
assert measurement_eos.name_to_key("Unknown", "measurement_load") is None

def test_name_to_key_invalid_topic(measurement_eos):
"""Test name_to_key with an invalid topic."""
settings = SettingsEOS(
measurement_load0_name="Household",
measurement_load1_name="Heat Pump",
)
measurement_eos.config.merge_settings(settings)

assert measurement_eos.name_to_key("Household", "invalid_topic") is None

def test_load_total_partial_intervals(measurement_eos):
"""Test total load calculation with partial intervals."""
start = datetime(2023, 1, 1, 0, 30)  # Start in the middle of an interval
end = datetime(2023, 1, 1, 1, 30)  # End in the middle of another interval
interval = duration(hours=1)

result = measurement_eos.load_total(start_datetime=start, end_datetime=end, interval=interval)
expected = np.array([100])  # Only one complete interval covered
np.testing.assert_array_equal(result, expected)
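The _energy_from_meter_readings tests above boil down to differencing cumulative meter readings: the energy per interval is the increase of the meter between consecutive interval boundaries. A plain numpy sketch of that idea follows; the actual method additionally resamples the records onto the requested interval grid, which is not shown here.

# Sketch only: energy per interval from cumulative meter readings.
import numpy as np

readings = np.array([100, 150, 200, 250, 300, 350])  # cumulative meter values, one per hour
energy_per_interval = np.diff(readings)              # -> [50, 50, 50, 50, 50]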
@@ -406,7 +406,7 @@ class TestPredictionContainer:
del container_with_providers["prediction_value"]
series = container_with_providers["prediction_value"]
assert series.name == "prediction_value"
assert series.tolist() == [None, None, None]
assert series.tolist() == []

def test_delitem_non_existing_key(self, container_with_providers):
with pytest.raises(KeyError, match="Key 'non_existent_key' not found"):
@@ -303,5 +303,5 @@ def test_timezone_behaviour(
forecast_measured = provider.key_to_series(
"pvforecastakkudoktor_ac_power_measured", other_start_datetime, other_end_datetime
)
assert len(forecast_measured) == 48
assert len(forecast_measured) == 1
assert forecast_measured.iloc[0] == 1000.0  # changed before
@@ -96,7 +96,9 @@ def test_import(pvforecast_provider, sample_import_1_json, start_datetime, from_
assert pvforecast_provider.total_hours is not None
assert compare_datetimes(pvforecast_provider.start_datetime, ems_eos.start_datetime).equal
values = sample_import_1_json["pvforecast_ac_power"]
value_datetime_mapping = pvforecast_provider.import_datetimes(len(values))
value_datetime_mapping = pvforecast_provider.import_datetimes(
ems_eos.start_datetime, len(values)
)
for i, mapping in enumerate(value_datetime_mapping):
assert i < len(pvforecast_provider.records)
expected_datetime, expected_value_index = mapping
tests/test_pydantic.py (new file, 116 lines)
@@ -0,0 +1,116 @@
from typing import Optional

import pandas as pd
import pendulum
import pytest
from pydantic import Field, ValidationError

from akkudoktoreos.core.pydantic import (
PydanticBaseModel,
PydanticDateTimeData,
PydanticDateTimeDataFrame,
PydanticDateTimeSeries,
)
from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime

class PydanticTestModel(PydanticBaseModel):
datetime_field: pendulum.DateTime = Field(
..., description="A datetime field with pendulum support."
)
optional_field: Optional[str] = Field(default=None, description="An optional field.")

class TestPydanticBaseModel:
def test_valid_pendulum_datetime(self):
dt = pendulum.now()
model = PydanticTestModel(datetime_field=dt)
assert model.datetime_field == dt

def test_invalid_datetime_string(self):
with pytest.raises(ValidationError, match="Input should be an instance of DateTime"):
PydanticTestModel(datetime_field="invalid_datetime")

def test_iso8601_serialization(self):
dt = pendulum.datetime(2024, 12, 21, 15, 0, 0)
model = PydanticTestModel(datetime_field=dt)
serialized = model.to_dict()
expected_dt = to_datetime(dt)
result_dt = to_datetime(serialized["datetime_field"])
assert compare_datetimes(result_dt, expected_dt)

def test_reset_to_defaults(self):
dt = pendulum.now()
model = PydanticTestModel(datetime_field=dt, optional_field="some value")
model.reset_to_defaults()
assert model.datetime_field == dt
assert model.optional_field is None

def test_from_dict_and_to_dict(self):
dt = pendulum.now()
model = PydanticTestModel(datetime_field=dt)
data = model.to_dict()
restored_model = PydanticTestModel.from_dict(data)
assert restored_model.datetime_field == dt

def test_to_json_and_from_json(self):
dt = pendulum.now()
model = PydanticTestModel(datetime_field=dt)
json_data = model.to_json()
restored_model = PydanticTestModel.from_json(json_data)
assert restored_model.datetime_field == dt

class TestPydanticDateTimeData:
def test_valid_list_lengths(self):
data = {
"timestamps": ["2024-12-21T15:00:00+00:00"],
"values": [100],
}
model = PydanticDateTimeData(root=data)
assert pendulum.parse(model.root["timestamps"][0]) == pendulum.parse(
"2024-12-21T15:00:00+00:00"
)

def test_invalid_list_lengths(self):
data = {
"timestamps": ["2024-12-21T15:00:00+00:00"],
"values": [100, 200],
}
with pytest.raises(
ValidationError, match="All lists in the dictionary must have the same length"
):
PydanticDateTimeData(root=data)

class TestPydanticDateTimeDataFrame:
def test_valid_dataframe(self):
df = pd.DataFrame(
{
"value": [100, 200],
},
index=pd.to_datetime(["2024-12-21", "2024-12-22"]),
)
model = PydanticDateTimeDataFrame.from_dataframe(df)
result = model.to_dataframe()

# Check index
assert len(result.index) == len(df.index)
for i, dt in enumerate(df.index):
expected_dt = to_datetime(dt)
result_dt = to_datetime(result.index[i])
assert compare_datetimes(result_dt, expected_dt).equal

class TestPydanticDateTimeSeries:
def test_valid_series(self):
series = pd.Series([100, 200], index=pd.to_datetime(["2024-12-21", "2024-12-22"]))
model = PydanticDateTimeSeries.from_series(series)
result = model.to_series()

# Check index
assert len(result.index) == len(series.index)
for i, dt in enumerate(series.index):
expected_dt = to_datetime(dt)
result_dt = to_datetime(result.index[i])
assert compare_datetimes(result_dt, expected_dt).equal
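A round-trip sketch for PydanticDateTimeSeries as it is used by the new v1 API endpoints: a pandas Series with a datetime index is wrapped in the pydantic model so it can travel as JSON, then restored afterwards. The class and the from_series/to_series method names are taken from the tests above; everything else is illustrative.

# Sketch only: wrap a datetime-indexed Series in the pydantic transport model and back.
import pandas as pd
from akkudoktoreos.core.pydantic import PydanticDateTimeSeries

series = pd.Series([100, 200], index=pd.to_datetime(["2024-12-21", "2024-12-22"]))
model = PydanticDateTimeSeries.from_series(series)   # pydantic model, JSON-serialisable
restored = model.to_series()                         # back to a pandas Series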
@@ -13,5 +13,5 @@ def test_server(server):
assert config_eos.data_folder_path is not None
assert config_eos.data_folder_path.is_dir()

result = requests.get(f"{server}/config?")
result = requests.get(f"{server}/v1/config?")
assert result.status_code == HTTPStatus.OK
@@ -1,10 +1,10 @@
import os
import subprocess
from pathlib import Path

from matplotlib.testing.compare import compare_images

from akkudoktoreos.config.config import get_config
from akkudoktoreos.utils.visualize import generate_example_report

filename = "example_report.pdf"

@@ -17,14 +17,13 @@ DIR_TESTDATA = Path(__file__).parent / "testdata"
reference_file = DIR_TESTDATA / "test_example_report.pdf"

def test_generate_pdf_main():
def test_generate_pdf_example():
"""Test generation of example visualization report."""
# Delete the old generated file if it exists
if os.path.isfile(output_file):
os.remove(output_file)

# Execute the __main__ block of visualize.py by running it as a script
script_path = Path(__file__).parent.parent / "src" / "akkudoktoreos" / "utils" / "visualize.py"
subprocess.run(["python", str(script_path)], check=True)
generate_example_report(filename)

# Check if the file exists
assert os.path.isfile(output_file)
@@ -96,7 +96,7 @@ def test_import(weather_provider, sample_import_1_json, start_datetime, from_fil
assert weather_provider.total_hours is not None
assert compare_datetimes(weather_provider.start_datetime, ems_eos.start_datetime).equal
values = sample_import_1_json["weather_temp_air"]
value_datetime_mapping = weather_provider.import_datetimes(len(values))
value_datetime_mapping = weather_provider.import_datetimes(ems_eos.start_datetime, len(values))
for i, mapping in enumerate(value_datetime_mapping):
assert i < len(weather_provider.records)
expected_datetime, expected_value_index = mapping