Mirror of https://github.com/Akkudoktor-EOS/EOS.git (synced 2025-09-20 10:41:14 +00:00)
Fix2 config and predictions revamp. (#281)
measurement:
- Add new measurement class to hold real world measurements.
- Handles load meter readings, grid import and export meter readings.
- Aggregates load meter readings (aka. measurements) to total load.
- Can import measurements from files, pandas datetime series, pandas datetime dataframes, simple datetime arrays and programmatically.
- May be expanded to other measurement values.
- Should be used for load prediction adaptation by real world measurements.

core/coreabc:
- Add mixin class to access measurements.

core/pydantic:
- Add pydantic models for pandas datetime series and dataframes.
- Add pydantic models for simple datetime arrays.

core/dataabc:
- Provide DataImportMixin class for generic import handling. Imports from JSON strings and files. Imports from pandas datetime dataframes and simple datetime arrays. The signature of the import method changed to allow import datetimes to be given programmatically and by data content.
- Use pydantic models for datetime series, dataframes and arrays.
- Validate generic imports by pydantic models.
- Provide new attributes min_datetime and max_datetime for DataSequence.
- Add parameter dropna to drop NaN/None values when creating lists, pandas series or numpy arrays from DataSequence.

config/config:
- Add common settings for the measurement module.

predictions/elecpriceakkudoktor:
- Use mean values of the last 7 days to fill prediction values not provided by akkudoktor.net (only provides 24 values).

prediction/loadabc:
- Extend the generic prediction keys by 'load_total_adjusted' for load predictions that adjust the predicted total load by measured load values.

prediction/loadakkudoktor:
- Extend the Akkudoktor load prediction by load adjustment using measured load values.

prediction/load_aggregator:
- Module removed. Load aggregation is now handled by the measurement module.

prediction/load_corrector:
- Module removed. Load correction (aka. adjustment of the load prediction by measured load energy) is handled by the LoadAkkudoktor prediction and the generic 'load_mean_adjusted' prediction key.

prediction/load_forecast:
- Module removed. Functionality is now completely handled by the LoadAkkudoktor prediction.

utils/cacheutil:
- Use pydantic.
- Fix potential bug in ttl (time to live) duration handling.

utils/datetimeutil:
- Added missing handling of pendulum.DateTime and pendulum.Duration instances as input. They were handled before as datetime.datetime and datetime.timedelta.

utils/visualize:
- Move main to generate_example_report() for better testing support.

server/server:
- Added new configuration option server_fastapi_startup_server_fasthtml to make startup of the FastHTML server by the FastAPI server conditional.

server/fastapi_server:
- Add APIs for measurements.
- Improve APIs to provide or take pandas datetime series and dataframes controlled by a pydantic model.
- Improve APIs to provide or take simple datetime data arrays controlled by a pydantic model.
- Move the FastAPI server API to v1 for new APIs.
- Update pre-v1 endpoints to use the new prediction and measurement capabilities.
- Only start the FastHTML server if the 'server_fastapi_startup_server_fasthtml' config option is set.

tests:
- Adapt import tests to the changed import method signature.
- Adapt the server test to use the v1 API.
- Extend the dataabc test to cover array generation from data with several data interval scenarios.
- Extend the datetimeutil test to also cover correct handling of to_datetime() providing now().
- Adapt the LoadAkkudoktor test to the new adjustment calculation.
- Adapt the visualization test to use the example report function instead of running visualize.py as a process.
- Removed test_load_aggregator. Functionality is now tested in test_measurement.
- Added tests for the measurement module.

docs:
- Remove sphinxcontrib-openapi as it prevents the documentation build ("site-packages/sphinxcontrib/openapi/openapi31.py", line 305, in _get_type_from_schema, for t in schema["anyOf"]: KeyError: 'anyOf').

Signed-off-by: Bobby Noelte <b0661n0e17e@gmail.com>
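The new generic import format accepts one value list per key plus the optional special keys `start_datetime` and `interval`. A minimal sketch of such a payload is shown below; the measurement key names are hypothetical, and the `import_from_json()` call mirrors the DataImportMixin method shown in the diff further down.

```python
# Minimal sketch of the generic import payload described above.
# Assumptions: the keys "load0_mr" and "load1_mr" are made up for illustration;
# import_from_json() and key_prefix follow the DataImportMixin shown in the diff below.
import json

payload = {
    "start_datetime": "2024-11-10 00:00:00",  # optional; falls back to the provider's start_datetime
    "interval": "30 minutes",                 # optional; defaults to 1 hour
    "load0_mr": [20.5, 21.0, 22.1],
    "load1_mr": [50, 55, 60],
}
json_str = json.dumps(payload)

# A provider using DataImportMixin could then ingest it, e.g.:
#   provider.import_from_json(json_str, key_prefix="load")
```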
@@ -21,6 +21,7 @@ from akkudoktoreos.utils.logutil import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
config_eos: Any = None
|
||||
measurement_eos: Any = None
|
||||
prediction_eos: Any = None
|
||||
devices_eos: Any = None
|
||||
ems_eos: Any = None
|
||||
@@ -50,7 +51,7 @@ class ConfigMixin:
|
||||
|
||||
@property
|
||||
def config(self) -> Any:
|
||||
"""Convenience method/ attribute to retrieve the EOS onfiguration data.
|
||||
"""Convenience method/ attribute to retrieve the EOS configuration data.
|
||||
|
||||
Returns:
|
||||
ConfigEOS: The configuration.
|
||||
@@ -65,6 +66,46 @@ class ConfigMixin:
|
||||
return config_eos
|
||||
|
||||
|
||||
class MeasurementMixin:
|
||||
"""Mixin class for managing EOS measurement data.
|
||||
|
||||
This class serves as a foundational component for EOS-related classes requiring access
|
||||
to global measurement data. It provides a `measurement` property that dynamically retrieves
|
||||
the measurement instance, ensuring up-to-date access to measurement results.
|
||||
|
||||
Usage:
|
||||
Subclass this base class to gain access to the `measurement` attribute, which retrieves the
|
||||
global measurement instance lazily to avoid import-time circular dependencies.
|
||||
|
||||
Attributes:
|
||||
measurement (Measurement): Property to access the global EOS measurement data.
|
||||
|
||||
Example:
|
||||
```python
|
||||
class MyOptimizationClass(MeasurementMixin):
|
||||
def analyze_mymeasurement(self):
|
||||
measurement_data = self.measurement.mymeasurement
|
||||
# Perform analysis
|
||||
```
|
||||
"""
|
||||
|
||||
@property
|
||||
def measurement(self) -> Any:
|
||||
"""Convenience method/ attribute to retrieve the EOS measurement data.
|
||||
|
||||
Returns:
|
||||
Measurement: The measurement.
|
||||
"""
|
||||
# avoid circular dependency at import time
|
||||
global measurement_eos
|
||||
if measurement_eos is None:
|
||||
from akkudoktoreos.measurement.measurement import get_measurement
|
||||
|
||||
measurement_eos = get_measurement()
|
||||
|
||||
return measurement_eos
|
||||
|
||||
|
||||
class PredictionMixin:
|
||||
"""Mixin class for managing EOS prediction data.
|
||||
|
||||
|
@@ -21,10 +21,21 @@ import pandas as pd
|
||||
import pendulum
|
||||
from numpydantic import NDArray, Shape
|
||||
from pendulum import DateTime, Duration
|
||||
from pydantic import AwareDatetime, ConfigDict, Field, computed_field, field_validator
|
||||
from pydantic import (
|
||||
AwareDatetime,
|
||||
ConfigDict,
|
||||
Field,
|
||||
ValidationError,
|
||||
computed_field,
|
||||
field_validator,
|
||||
)
|
||||
|
||||
from akkudoktoreos.core.coreabc import ConfigMixin, SingletonMixin, StartMixin
|
||||
from akkudoktoreos.core.pydantic import PydanticBaseModel
|
||||
from akkudoktoreos.core.pydantic import (
|
||||
PydanticBaseModel,
|
||||
PydanticDateTimeData,
|
||||
PydanticDateTimeDataFrame,
|
||||
)
|
||||
from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime, to_duration
|
||||
from akkudoktoreos.utils.logutil import get_logger
|
||||
|
||||
@@ -47,22 +58,22 @@ class DataRecord(DataBase, MutableMapping):
|
||||
and attribute-style access (`record.field_name`).
|
||||
|
||||
Attributes:
|
||||
date_time (Optional[AwareDatetime]): Aware datetime indicating when the data record applies.
|
||||
date_time (Optional[DateTime]): Aware datetime indicating when the data record applies.
|
||||
|
||||
Configurations:
|
||||
- Allows mutation after creation.
|
||||
- Supports non-standard data types like `datetime`.
|
||||
"""
|
||||
|
||||
date_time: Optional[AwareDatetime] = Field(default=None, description="DateTime")
|
||||
date_time: Optional[DateTime] = Field(default=None, description="DateTime")
|
||||
|
||||
# Pydantic v2 model configuration
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True)
|
||||
|
||||
@field_validator("date_time", mode="before")
|
||||
@classmethod
|
||||
def transform_to_datetime(cls, value: Any) -> DateTime:
|
||||
"""Converts various datetime formats into AwareDatetime."""
|
||||
def transform_to_datetime(cls, value: Any) -> Optional[DateTime]:
|
||||
"""Converts various datetime formats into DateTime."""
|
||||
if value is None:
|
||||
# Allow to set to default.
|
||||
return None
|
||||
@@ -307,6 +318,38 @@ class DataSequence(DataBase, MutableSequence):
|
||||
records: List[DataRecord] = Field(default_factory=list, description="List of data records")
|
||||
|
||||
# Derived fields (computed)
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@property
|
||||
def min_datetime(self) -> Optional[DateTime]:
|
||||
"""Minimum (earliest) datetime in the sorted sequence of data records.
|
||||
|
||||
This property computes the earliest datetime from the sequence of data records.
|
||||
If no records are present, it returns `None`.
|
||||
|
||||
Returns:
|
||||
Optional[DateTime]: The earliest datetime in the sequence, or `None` if no
|
||||
data records exist.
|
||||
"""
|
||||
if len(self.records) == 0:
|
||||
return None
|
||||
return self.records[0].date_time
|
||||
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@property
|
||||
def max_datetime(self) -> Optional[DateTime]:
|
||||
"""Maximum (latest) datetime in the sorted sequence of data records.
|
||||
|
||||
This property computes the latest datetime from the sequence of data records.
|
||||
If no records are present, it returns `None`.
|
||||
|
||||
Returns:
|
||||
Optional[DateTime]: The latest datetime in the sequence, or `None` if no
|
||||
data records exist.
|
||||
"""
|
||||
if len(self.records) == 0:
|
||||
return None
|
||||
return self.records[-1].date_time
|
||||
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@property
|
||||
def record_keys(self) -> List[str]:
|
||||
@@ -319,12 +362,31 @@ class DataSequence(DataBase, MutableSequence):
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@property
|
||||
def record_keys_writable(self) -> List[str]:
|
||||
"""Returns the keys of all fields in the data records that are writable."""
|
||||
"""Get the keys of all writable fields in the data records.
|
||||
|
||||
This property retrieves the keys of all fields in the data records that
|
||||
can be written to. It uses the `record_class` to determine the model's
|
||||
field structure.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of field keys that are writable in the data records.
|
||||
"""
|
||||
return list(self.record_class().model_fields.keys())
|
||||
|
||||
@classmethod
|
||||
def record_class(cls) -> Type:
|
||||
"""Returns the class of the data record this data sequence handles."""
|
||||
"""Get the class of the data record handled by this data sequence.
|
||||
|
||||
This method determines the class of the data record type associated with
|
||||
the `records` field of the model. The field is expected to be a list, and
|
||||
the element type of the list should be a subclass of `DataRecord`.
|
||||
|
||||
Raises:
|
||||
ValueError: If the record type is not a subclass of `DataRecord`.
|
||||
|
||||
Returns:
|
||||
Type: The class of the data record handled by the data sequence.
|
||||
"""
|
||||
# Access the model field metadata
|
||||
field_info = cls.model_fields["records"]
|
||||
# Get the list element type from the 'type_' attribute
|
||||
@@ -573,6 +635,7 @@ class DataSequence(DataBase, MutableSequence):
|
||||
key: str,
|
||||
start_datetime: Optional[DateTime] = None,
|
||||
end_datetime: Optional[DateTime] = None,
|
||||
dropna: Optional[bool] = None,
|
||||
) -> Dict[DateTime, Any]:
|
||||
"""Extract a dictionary indexed by the date_time field of the DataRecords.
|
||||
|
||||
@@ -583,6 +646,7 @@ class DataSequence(DataBase, MutableSequence):
|
||||
key (str): The field name in the DataRecord from which to extract values.
|
||||
start_datetime (datetime, optional): The start date to filter records (inclusive).
|
||||
end_datetime (datetime, optional): The end date to filter records (exclusive).
|
||||
dropna: (bool, optional): Whether to drop NAN/ None values before processing. Defaults to True.
|
||||
|
||||
Returns:
|
||||
Dict[datetime, Any]: A dictionary with the date_time of each record as the key
|
||||
@@ -597,12 +661,22 @@ class DataSequence(DataBase, MutableSequence):
|
||||
end_datetime = to_datetime(end_datetime, to_maxtime=False) if end_datetime else None
|
||||
|
||||
# Create a dictionary to hold date_time and corresponding values
|
||||
filtered_data = {
|
||||
to_datetime(record.date_time, as_string=True): getattr(record, key, None)
|
||||
for record in self.records
|
||||
if (start_datetime is None or compare_datetimes(record.date_time, start_datetime).ge)
|
||||
and (end_datetime is None or compare_datetimes(record.date_time, end_datetime).lt)
|
||||
}
|
||||
if dropna is None:
|
||||
dropna = True
|
||||
filtered_data = {}
|
||||
for record in self.records:
|
||||
if (
|
||||
record.date_time is None
|
||||
or (dropna and getattr(record, key, None) is None)
|
||||
or (dropna and getattr(record, key, None) == float("nan"))
|
||||
):
|
||||
continue
|
||||
if (
|
||||
start_datetime is None or compare_datetimes(record.date_time, start_datetime).ge
|
||||
) and (end_datetime is None or compare_datetimes(record.date_time, end_datetime).lt):
|
||||
filtered_data[to_datetime(record.date_time, as_string=True)] = getattr(
|
||||
record, key, None
|
||||
)
|
||||
|
||||
return filtered_data
|
||||
|
||||
@@ -611,6 +685,7 @@ class DataSequence(DataBase, MutableSequence):
|
||||
key: str,
|
||||
start_datetime: Optional[DateTime] = None,
|
||||
end_datetime: Optional[DateTime] = None,
|
||||
dropna: Optional[bool] = None,
|
||||
) -> Tuple[List[DateTime], List[Optional[float]]]:
|
||||
"""Extracts two lists from data records within an optional date range.
|
||||
|
||||
@@ -622,6 +697,7 @@ class DataSequence(DataBase, MutableSequence):
|
||||
key (str): The key of the attribute in DataRecord to extract.
|
||||
start_datetime (datetime, optional): The start date for filtering the records (inclusive).
|
||||
end_datetime (datetime, optional): The end date for filtering the records (exclusive).
|
||||
dropna: (bool, optional): Whether to drop NAN/ None values before processing. Defaults to True.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing a list of datetime values and a list of extracted values.
|
||||
@@ -635,9 +711,15 @@ class DataSequence(DataBase, MutableSequence):
|
||||
end_datetime = to_datetime(end_datetime, to_maxtime=False) if end_datetime else None
|
||||
|
||||
# Create two lists to hold date_time and corresponding values
|
||||
if dropna is None:
|
||||
dropna = True
|
||||
filtered_records = []
|
||||
for record in self.records:
|
||||
if record.date_time is None:
|
||||
if (
|
||||
record.date_time is None
|
||||
or (dropna and getattr(record, key, None) is None)
|
||||
or (dropna and getattr(record, key, None) == float("nan"))
|
||||
):
|
||||
continue
|
||||
if (
|
||||
start_datetime is None or compare_datetimes(record.date_time, start_datetime).ge
|
||||
@@ -653,6 +735,7 @@ class DataSequence(DataBase, MutableSequence):
|
||||
key: str,
|
||||
start_datetime: Optional[DateTime] = None,
|
||||
end_datetime: Optional[DateTime] = None,
|
||||
dropna: Optional[bool] = None,
|
||||
) -> pd.Series:
|
||||
"""Extract a series indexed by the date_time field from data records within an optional date range.
|
||||
|
||||
@@ -660,6 +743,7 @@ class DataSequence(DataBase, MutableSequence):
|
||||
key (str): The field name in the DataRecord from which to extract values.
|
||||
start_datetime (datetime, optional): The start date for filtering the records (inclusive).
|
||||
end_datetime (datetime, optional): The end date for filtering the records (exclusive).
|
||||
dropna: (bool, optional): Whether to drop NAN/ None values before processing. Defaults to True.
|
||||
|
||||
Returns:
|
||||
pd.Series: A Pandas Series with the index as the date_time of each record
|
||||
@@ -668,7 +752,9 @@ class DataSequence(DataBase, MutableSequence):
|
||||
Raises:
|
||||
KeyError: If the specified key is not found in any of the DataRecords.
|
||||
"""
|
||||
dates, values = self.key_to_lists(key, start_datetime, end_datetime)
|
||||
dates, values = self.key_to_lists(
|
||||
key=key, start_datetime=start_datetime, end_datetime=end_datetime, dropna=dropna
|
||||
)
|
||||
return pd.Series(data=values, index=pd.DatetimeIndex(dates), name=key)
|
||||
|
||||
def key_from_series(self, key: str, series: pd.Series) -> None:
|
||||
@@ -704,6 +790,7 @@ class DataSequence(DataBase, MutableSequence):
|
||||
end_datetime: Optional[DateTime] = None,
|
||||
interval: Optional[Duration] = None,
|
||||
fill_method: Optional[str] = None,
|
||||
dropna: Optional[bool] = None,
|
||||
) -> NDArray[Shape["*"], Any]:
|
||||
"""Extract an array indexed by fixed time intervals from data records within an optional date range.
|
||||
|
||||
@@ -717,6 +804,7 @@ class DataSequence(DataBase, MutableSequence):
|
||||
- 'ffill': Forward fill missing values.
|
||||
- 'bfill': Backward fill missing values.
|
||||
- 'none': Defaults to 'linear' for numeric values, otherwise 'ffill'.
|
||||
dropna: (bool, optional): Whether to drop NAN/ None values before processing. Defaults to True.
|
||||
|
||||
Returns:
|
||||
np.ndarray: A NumPy Array of the values extracted from the specified key.
|
||||
@@ -724,10 +812,54 @@ class DataSequence(DataBase, MutableSequence):
|
||||
Raises:
|
||||
KeyError: If the specified key is not found in any of the DataRecords.
|
||||
"""
|
||||
self._validate_key(key)
|
||||
# Ensure datetime objects are normalized
|
||||
start_datetime = to_datetime(start_datetime, to_maxtime=False) if start_datetime else None
|
||||
end_datetime = to_datetime(end_datetime, to_maxtime=False) if end_datetime else None
|
||||
|
||||
resampled = None
|
||||
if interval is None:
|
||||
interval = to_duration("1 hour")
|
||||
series = self.key_to_series(key)
|
||||
|
||||
dates, values = self.key_to_lists(key=key, dropna=dropna)
|
||||
values_len = len(values)
|
||||
|
||||
if values_len < 1:
|
||||
# No values, assume at least one value set to None
|
||||
if start_datetime is not None:
|
||||
dates.append(start_datetime - interval)
|
||||
else:
|
||||
dates.append(to_datetime(to_maxtime=False))
|
||||
values.append(None)
|
||||
|
||||
if start_datetime is not None:
|
||||
start_index = 0
|
||||
while start_index < values_len:
|
||||
if compare_datetimes(dates[start_index], start_datetime).ge:
|
||||
break
|
||||
start_index += 1
|
||||
if start_index == 0:
|
||||
# No value before start
|
||||
# Add dummy value
|
||||
dates.insert(0, dates[0] - interval)
|
||||
values.insert(0, values[0])
|
||||
elif start_index > 1:
|
||||
# Truncate all values before latest value before start_datetime
|
||||
dates = dates[start_index - 1 :]
|
||||
values = values[start_index - 1 :]
|
||||
|
||||
if end_datetime is not None:
|
||||
if compare_datetimes(dates[-1], end_datetime).lt:
|
||||
# Add dummy value at end_datetime
|
||||
dates.append(end_datetime)
|
||||
values.append(values[-1])
|
||||
|
||||
series = pd.Series(data=values, index=pd.DatetimeIndex(dates), name=key)
|
||||
if not series.index.inferred_type == "datetime64":
|
||||
raise TypeError(
|
||||
f"Expected DatetimeIndex, but got {type(series.index)} "
|
||||
f"infered to {series.index.inferred_type}: {series}"
|
||||
)
|
||||
|
||||
# Handle missing values
|
||||
if series.dtype in [np.float64, np.float32, np.int64, np.int32]:
|
||||
@@ -735,7 +867,7 @@ class DataSequence(DataBase, MutableSequence):
|
||||
if fill_method is None:
|
||||
fill_method = "linear"
|
||||
# Resample the series to the specified interval
|
||||
resampled = series.resample(interval).mean()
|
||||
resampled = series.resample(interval, origin="start").first()
|
||||
if fill_method == "linear":
|
||||
resampled = resampled.interpolate(method="linear")
|
||||
elif fill_method == "ffill":
|
||||
@@ -749,7 +881,7 @@ class DataSequence(DataBase, MutableSequence):
|
||||
if fill_method is None:
|
||||
fill_method = "ffill"
|
||||
# Resample the series to the specified interval
|
||||
resampled = series.resample(interval).first()
|
||||
resampled = series.resample(interval, origin="start").first()
|
||||
if fill_method == "ffill":
|
||||
resampled = resampled.ffill()
|
||||
elif fill_method == "bfill":
|
||||
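The resampling change above switches to `origin="start"`, which anchors the resample bins at the first timestamp of the series instead of at clock-aligned boundaries, presumably so the resulting array lines up with the first data record. A small, self-contained pandas sketch (not part of the diff) illustrating the difference:

```python
# Illustrative only: how origin="start" changes pandas resampling bins.
import pandas as pd

idx = pd.DatetimeIndex(["2024-01-01 00:30", "2024-01-01 01:30", "2024-01-01 02:30"])
s = pd.Series([1.0, 2.0, 3.0], index=idx)

clock_aligned = s.resample("1h").first()                  # bins at 00:00, 01:00, 02:00
start_aligned = s.resample("1h", origin="start").first()  # bins at 00:30, 01:30, 02:30
print(clock_aligned)
print(start_aligned)
```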
@@ -955,18 +1087,29 @@ class DataProvider(SingletonMixin, DataSequence):
|
||||
self.sort_by_datetime()
|
||||
|
||||
|
||||
class DataImportProvider(DataProvider):
|
||||
"""Abstract base class for data providers that import generic data.
|
||||
class DataImportMixin:
|
||||
"""Mixin class for import of generic data.
|
||||
|
||||
This class is designed to handle generic data provided in the form of a key-value dictionary.
|
||||
- **Keys**: Represent identifiers from the record keys of a specific data.
|
||||
- **Values**: Are lists of data values starting at a specified `start_datetime`, where
|
||||
each value corresponds to a subsequent time interval (e.g., hourly).
|
||||
|
||||
Subclasses must implement the logic for managing generic data based on the imported records.
|
||||
Two special keys are handled. `start_datetime` may be used to define the starting datetime of
|
||||
the values. `interval` may be used to define the fixed time interval between two values.
|
||||
|
||||
On import, `self.update_value(datetime, key, value)` is called, which has to be provided by the using class.
|
||||
Also `self.start_datetime` may be necessary as a default in case `start_datetime` is not given.
|
||||
"""
|
||||
|
||||
def import_datetimes(self, value_count: int) -> List[Tuple[DateTime, int]]:
|
||||
# Attributes required but defined elsewhere.
|
||||
# - start_datetime
|
||||
# - record_keys_writable
|
||||
# - update_value
|
||||
|
||||
def import_datetimes(
|
||||
self, start_datetime: DateTime, value_count: int, interval: Optional[Duration] = None
|
||||
) -> List[Tuple[DateTime, int]]:
|
||||
"""Generates a list of tuples containing timestamps and their corresponding value indices.
|
||||
|
||||
The function accounts for daylight saving time (DST) transitions:
|
||||
@@ -975,7 +1118,9 @@ class DataImportProvider(DataProvider):
|
||||
but they share the same value index.
|
||||
|
||||
Args:
|
||||
start_datetime (DateTime): Start datetime of values
|
||||
value_count (int): The number of timestamps to generate.
|
||||
interval (duration, optional): The fixed time interval. Defaults to 1 hour.
|
||||
|
||||
Returns:
|
||||
List[Tuple[DateTime, int]]:
|
||||
@@ -990,7 +1135,7 @@ class DataImportProvider(DataProvider):
|
||||
|
||||
Example:
|
||||
>>> start_datetime = pendulum.datetime(2024, 11, 3, 0, 0, tz="America/New_York")
|
||||
>>> import_datetimes(5)
|
||||
>>> import_datetimes(start_datetime, 5)
|
||||
[(DateTime(2024, 11, 3, 0, 0, tzinfo=Timezone('America/New_York')), 0),
|
||||
(DateTime(2024, 11, 3, 1, 0, tzinfo=Timezone('America/New_York')), 1),
|
||||
(DateTime(2024, 11, 3, 1, 0, tzinfo=Timezone('America/New_York')), 1), # Repeated hour
|
||||
@@ -998,7 +1143,16 @@ class DataImportProvider(DataProvider):
|
||||
(DateTime(2024, 11, 3, 3, 0, tzinfo=Timezone('America/New_York')), 3)]
|
||||
"""
|
||||
timestamps_with_indices: List[Tuple[DateTime, int]] = []
|
||||
value_datetime = self.start_datetime
|
||||
|
||||
if interval is None:
|
||||
interval = to_duration("1 hour")
|
||||
interval_steps_per_hour = int(3600 / interval.total_seconds())
|
||||
if interval.total_seconds() * interval_steps_per_hour != 3600:
|
||||
error_msg = f"Interval {interval} does not fit into hour."
|
||||
logger.error(error_msg)
|
||||
raise NotImplementedError(error_msg)
|
||||
|
||||
value_datetime = start_datetime
|
||||
value_index = 0
|
||||
|
||||
while value_index < value_count:
|
||||
@@ -1006,37 +1160,219 @@ class DataImportProvider(DataProvider):
|
||||
logger.debug(f"{i}: Insert at {value_datetime} with index {value_index}")
|
||||
timestamps_with_indices.append((value_datetime, value_index))
|
||||
|
||||
# Check if there is a DST transition
|
||||
next_time = value_datetime.add(hours=1)
|
||||
if next_time <= value_datetime:
|
||||
# Check if there is a DST transition (i.e., ambiguous time during fall back)
|
||||
# Repeat the hour value (reuse value index)
|
||||
value_datetime = next_time
|
||||
logger.debug(f"{i+1}: Repeat at {value_datetime} with index {value_index}")
|
||||
timestamps_with_indices.append((value_datetime, value_index))
|
||||
elif next_time.hour != value_datetime.hour + 1 and value_datetime.hour != 23:
|
||||
# Skip the hour value (spring forward in value index)
|
||||
value_index += 1
|
||||
logger.debug(f"{i+1}: Skip at {next_time} with index {value_index}")
|
||||
next_time = value_datetime.add(seconds=interval.total_seconds())
|
||||
|
||||
# Increment value index and value_datetime for new hour
|
||||
# Check if there is a DST transition
|
||||
if next_time.dst() != value_datetime.dst():
|
||||
if next_time.hour == value_datetime.hour:
|
||||
# We jump back by 1 hour
|
||||
# Repeat the value(s) (reuse value index)
|
||||
for i in range(interval_steps_per_hour):
|
||||
logger.debug(f"{i+1}: Repeat at {next_time} with index {value_index}")
|
||||
timestamps_with_indices.append((next_time, value_index))
|
||||
next_time = next_time.add(seconds=interval.total_seconds())
|
||||
else:
|
||||
# We jump forward by 1 hour
|
||||
# Drop the value(s)
|
||||
logger.debug(
|
||||
f"{i+1}: Skip {interval_steps_per_hour} at {next_time} with index {value_index}"
|
||||
)
|
||||
value_index += interval_steps_per_hour
|
||||
|
||||
# Increment value index and value_datetime for new interval
|
||||
value_index += 1
|
||||
value_datetime = value_datetime.add(hours=1)
|
||||
value_datetime = next_time
|
||||
|
||||
return timestamps_with_indices
|
||||
|
||||
def import_from_json(self, json_str: str, key_prefix: str = "") -> None:
|
||||
def import_from_dict(
|
||||
self,
|
||||
import_data: dict,
|
||||
key_prefix: str = "",
|
||||
start_datetime: Optional[DateTime] = None,
|
||||
interval: Optional[Duration] = None,
|
||||
) -> None:
|
||||
"""Updates generic data by importing it from a dictionary.
|
||||
|
||||
This method reads generic data from a dictionary, matches keys based on the
|
||||
record keys and the provided `key_prefix`, and updates the data values sequentially.
|
||||
All value lists must have the same length.
|
||||
|
||||
Args:
|
||||
import_data (dict): Dictionary containing the generic data with optional
|
||||
'start_datetime' and 'interval' keys.
|
||||
key_prefix (str, optional): A prefix to filter relevant keys from the generic data.
|
||||
Only keys starting with this prefix will be considered. Defaults to an empty string.
|
||||
start_datetime (DateTime, optional): Start datetime of values if not in dict.
|
||||
interval (Duration, optional): The fixed time interval if not in dict.
|
||||
|
||||
Raises:
|
||||
ValueError: If value lists have different lengths or if datetime conversion fails.
|
||||
"""
|
||||
# Handle datetime and interval from dict or parameters
|
||||
if "start_datetime" in import_data:
|
||||
try:
|
||||
start_datetime = to_datetime(import_data["start_datetime"])
|
||||
except (ValueError, TypeError) as e:
|
||||
raise ValueError(f"Invalid start_datetime in import data: {e}")
|
||||
|
||||
if start_datetime is None:
|
||||
start_datetime = self.start_datetime # type: ignore
|
||||
|
||||
if "interval" in import_data:
|
||||
try:
|
||||
interval = to_duration(import_data["interval"])
|
||||
except (ValueError, TypeError) as e:
|
||||
raise ValueError(f"Invalid interval in import data: {e}")
|
||||
|
||||
# Filter keys based on key_prefix and record_keys_writable
|
||||
valid_keys = [
|
||||
key
|
||||
for key in import_data.keys()
|
||||
if key.startswith(key_prefix)
|
||||
and key in self.record_keys_writable # type: ignore
|
||||
and key not in ("start_datetime", "interval")
|
||||
]
|
||||
|
||||
if not valid_keys:
|
||||
return
|
||||
|
||||
# Validate all value lists have the same length
|
||||
value_lengths = []
|
||||
for key in valid_keys:
|
||||
value_list = import_data[key]
|
||||
if not isinstance(value_list, (list, tuple, np.ndarray)):
|
||||
raise ValueError(f"Value for key '{key}' must be a list, tuple, or array")
|
||||
value_lengths.append(len(value_list))
|
||||
|
||||
if len(set(value_lengths)) > 1:
|
||||
raise ValueError(
|
||||
f"All value lists must have the same length. Found lengths: "
|
||||
f"{dict(zip(valid_keys, value_lengths))}"
|
||||
)
|
||||
|
||||
# Generate datetime mapping once for the common length
|
||||
values_count = value_lengths[0]
|
||||
value_datetime_mapping = self.import_datetimes(
|
||||
start_datetime, values_count, interval=interval
|
||||
)
|
||||
|
||||
# Process each valid key
|
||||
for key in valid_keys:
|
||||
try:
|
||||
value_list = import_data[key]
|
||||
|
||||
# Update values, skipping any None/NaN
|
||||
for value_datetime, value_index in value_datetime_mapping:
|
||||
value = value_list[value_index]
|
||||
if value is not None and not pd.isna(value):
|
||||
self.update_value(value_datetime, key, value) # type: ignore
|
||||
|
||||
except (IndexError, TypeError) as e:
|
||||
raise ValueError(f"Error processing values for key '{key}': {e}")
|
||||
|
||||
def import_from_dataframe(
|
||||
self,
|
||||
df: pd.DataFrame,
|
||||
key_prefix: str = "",
|
||||
start_datetime: Optional[DateTime] = None,
|
||||
interval: Optional[Duration] = None,
|
||||
) -> None:
|
||||
"""Updates generic data by importing it from a pandas DataFrame.
|
||||
|
||||
This method reads generic data from a DataFrame, matches columns based on the
|
||||
record keys and the provided `key_prefix`, and updates the data values using
|
||||
the DataFrame's index as timestamps.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): DataFrame containing the generic data with datetime index
|
||||
or sequential values.
|
||||
key_prefix (str, optional): A prefix to filter relevant columns from the DataFrame.
|
||||
Only columns starting with this prefix will be considered. Defaults to an empty string.
|
||||
start_datetime (DateTime, optional): Start datetime if DataFrame doesn't have datetime index.
|
||||
interval (Duration, optional): The fixed time interval if DataFrame doesn't have datetime index.
|
||||
|
||||
Raises:
|
||||
ValueError: If DataFrame structure is invalid or datetime conversion fails.
|
||||
"""
|
||||
# Validate DataFrame
|
||||
if not isinstance(df, pd.DataFrame):
|
||||
raise ValueError("Input must be a pandas DataFrame")
|
||||
|
||||
# Handle datetime index
|
||||
if isinstance(df.index, pd.DatetimeIndex):
|
||||
try:
|
||||
index_datetimes = [to_datetime(dt) for dt in df.index]
|
||||
has_datetime_index = True
|
||||
except (ValueError, TypeError) as e:
|
||||
raise ValueError(f"Invalid datetime index in DataFrame: {e}")
|
||||
else:
|
||||
if start_datetime is None:
|
||||
start_datetime = self.start_datetime # type: ignore
|
||||
has_datetime_index = False
|
||||
|
||||
# Filter columns based on key_prefix and record_keys_writable
|
||||
valid_columns = [
|
||||
col
|
||||
for col in df.columns
|
||||
if col.startswith(key_prefix) and col in self.record_keys_writable # type: ignore
|
||||
]
|
||||
|
||||
if not valid_columns:
|
||||
return
|
||||
|
||||
# For DataFrame, length validation is implicit since all columns have same length
|
||||
values_count = len(df)
|
||||
|
||||
# Generate value_datetime_mapping once if not using datetime index
|
||||
if not has_datetime_index:
|
||||
value_datetime_mapping = self.import_datetimes(
|
||||
start_datetime, values_count, interval=interval
|
||||
)
|
||||
|
||||
# Process each valid column
|
||||
for column in valid_columns:
|
||||
try:
|
||||
values = df[column].tolist()
|
||||
|
||||
if has_datetime_index:
|
||||
# Use the DataFrame's datetime index
|
||||
for dt, value in zip(index_datetimes, values):
|
||||
if value is not None and not pd.isna(value):
|
||||
self.update_value(dt, column, value) # type: ignore
|
||||
else:
|
||||
# Use the pre-generated datetime mapping
|
||||
for value_datetime, value_index in value_datetime_mapping:
|
||||
value = values[value_index]
|
||||
if value is not None and not pd.isna(value):
|
||||
self.update_value(value_datetime, column, value) # type: ignore
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error processing column '{column}': {e}")
|
||||
|
||||
def import_from_json(
|
||||
self,
|
||||
json_str: str,
|
||||
key_prefix: str = "",
|
||||
start_datetime: Optional[DateTime] = None,
|
||||
interval: Optional[Duration] = None,
|
||||
) -> None:
|
||||
"""Updates generic data by importing it from a JSON string.
|
||||
|
||||
This method reads generic data from a JSON string, matches keys based on the
|
||||
record keys and the provided `key_prefix`, and updates the data values sequentially,
|
||||
starting from the `start_datetime`. Each data value is associated with an hourly
|
||||
interval.
|
||||
starting from the `start_datetime`.
|
||||
|
||||
If start_datetime and/or interval are given in the JSON dict, they will be used. Otherwise
|
||||
the given parameters are used. If None is given start_datetime defaults to
|
||||
'self.start_datetime' and interval defaults to 1 hour.
|
||||
|
||||
Args:
|
||||
json_str (str): The JSON string containing the generic data.
|
||||
key_prefix (str, optional): A prefix to filter relevant keys from the generic data.
|
||||
Only keys starting with this prefix will be considered. Defaults to an empty string.
|
||||
start_datetime (DateTime, optional): Start datetime of values.
|
||||
interval (duration, optional): The fixed time interval. Defaults to 1 hour.
|
||||
|
||||
Raises:
|
||||
JSONDecodeError: If the file content is not valid JSON.
|
||||
@@ -1045,22 +1381,56 @@ class DataImportProvider(DataProvider):
|
||||
Given a JSON string with the following content:
|
||||
```json
|
||||
{
|
||||
"load0_mean": [20.5, 21.0, 22.1],
|
||||
"load1_mean": [50, 55, 60]
|
||||
"start_datetime": "2024-11-10 00:00:00"
|
||||
"interval": "30 minutes"
|
||||
"load_mean": [20.5, 21.0, 22.1],
|
||||
"other_xyz: [10.5, 11.0, 12.1],
|
||||
}
|
||||
```
|
||||
and `key_prefix = "load1"`, only the "load1_mean" key will be processed even though
|
||||
and `key_prefix = "load"`, only the "load_mean" key will be processed even though
|
||||
both keys are in the record.
|
||||
"""
|
||||
import_data = json.loads(json_str)
|
||||
for key in self.record_keys_writable:
|
||||
if key.startswith(key_prefix) and key in import_data:
|
||||
value_list = import_data[key]
|
||||
value_datetime_mapping = self.import_datetimes(len(value_list))
|
||||
for value_datetime, value_index in value_datetime_mapping:
|
||||
self.update_value(value_datetime, key, value_list[value_index])
|
||||
# Try pandas dataframe with orient="split"
|
||||
try:
|
||||
import_data = PydanticDateTimeDataFrame.model_validate_json(json_str)
|
||||
self.import_from_dataframe(import_data.to_dataframe())
|
||||
return
|
||||
except ValidationError as e:
|
||||
error_msg = ""
|
||||
for error in e.errors():
|
||||
field = " -> ".join(str(x) for x in error["loc"])
|
||||
message = error["msg"]
|
||||
error_type = error["type"]
|
||||
error_msg += f"Field: {field}\nError: {message}\nType: {error_type}\n"
|
||||
logger.debug(f"PydanticDateTimeDataFrame import: {error_msg}")
|
||||
|
||||
def import_from_file(self, import_file_path: Path, key_prefix: str = "") -> None:
|
||||
# Try dictionary with special keys start_datetime and interval
|
||||
try:
|
||||
import_data = PydanticDateTimeData.model_validate_json(json_str)
|
||||
self.import_from_dict(import_data.to_dict())
|
||||
return
|
||||
except ValidationError as e:
|
||||
error_msg = ""
|
||||
for error in e.errors():
|
||||
field = " -> ".join(str(x) for x in error["loc"])
|
||||
message = error["msg"]
|
||||
error_type = error["type"]
|
||||
error_msg += f"Field: {field}\nError: {message}\nType: {error_type}\n"
|
||||
logger.debug(f"PydanticDateTimeData import: {error_msg}")
|
||||
|
||||
# Use simple dict format
|
||||
import_data = json.loads(json_str)
|
||||
self.import_from_dict(
|
||||
import_data, key_prefix=key_prefix, start_datetime=start_datetime, interval=interval
|
||||
)
|
||||
|
||||
def import_from_file(
|
||||
self,
|
||||
import_file_path: Path,
|
||||
key_prefix: str = "",
|
||||
start_datetime: Optional[DateTime] = None,
|
||||
interval: Optional[Duration] = None,
|
||||
) -> None:
|
||||
"""Updates generic data by importing it from a file.
|
||||
|
||||
This method reads generic data from a JSON file, matches keys based on the
|
||||
@@ -1068,10 +1438,16 @@ class DataImportProvider(DataProvider):
|
||||
starting from the `start_datetime`. Each data value is associated with an hourly
|
||||
interval.
|
||||
|
||||
If start_datetime and/or interval are given in the JSON dict, they will be used. Otherwise
|
||||
the given parameters are used. If None is given start_datetime defaults to
|
||||
'self.start_datetime' and interval defaults to 1 hour.
|
||||
|
||||
Args:
|
||||
import_file_path (Path): The path to the JSON file containing the generic data.
|
||||
key_prefix (str, optional): A prefix to filter relevant keys from the generic data.
|
||||
Only keys starting with this prefix will be considered. Defaults to an empty string.
|
||||
start_datetime (DateTime, optional): Start datetime of values.
|
||||
interval (duration, optional): The fixed time interval. Defaults to 1 hour.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the specified file does not exist.
|
||||
@@ -1081,16 +1457,32 @@ class DataImportProvider(DataProvider):
|
||||
Given a JSON file with the following content:
|
||||
```json
|
||||
{
|
||||
"load0_mean": [20.5, 21.0, 22.1],
|
||||
"load1_mean": [50, 55, 60]
|
||||
"load_mean": [20.5, 21.0, 22.1],
|
||||
"other_xyz: [10.5, 11.0, 12.1],
|
||||
}
|
||||
```
|
||||
and `key_prefix = "load1"`, only the "load1_mean" key will be processed even though
|
||||
and `key_prefix = "load"`, only the "load_mean" key will be processed even though
|
||||
both keys are in the record.
|
||||
"""
|
||||
with import_file_path.open("r") as import_file:
|
||||
import_str = import_file.read()
|
||||
self.import_from_json(import_str, key_prefix)
|
||||
self.import_from_json(
|
||||
import_str, key_prefix=key_prefix, start_datetime=start_datetime, interval=interval
|
||||
)
|
||||
|
||||
|
||||
class DataImportProvider(DataImportMixin, DataProvider):
|
||||
"""Abstract base class for data providers that import generic data.
|
||||
|
||||
This class is designed to handle generic data provided in the form of a key-value dictionary.
|
||||
- **Keys**: Represent identifiers from the record keys of a specific data.
|
||||
- **Values**: Are lists of data values starting at a specified `start_datetime`, where
|
||||
each value corresponds to a subsequent time interval (e.g., hourly).
|
||||
|
||||
Subclasses must implement the logic for managing generic data based on the imported records.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DataContainer(SingletonMixin, DataBase, MutableMapping):
|
||||
@@ -1129,6 +1521,24 @@ class DataContainer(SingletonMixin, DataBase, MutableMapping):
|
||||
enab.append(provider)
|
||||
return enab
|
||||
|
||||
@property
|
||||
def record_keys(self) -> list[str]:
|
||||
"""Returns the keys of all fields in the data records of all enabled providers."""
|
||||
key_set = set(
|
||||
chain.from_iterable(provider.record_keys for provider in self.enabled_providers)
|
||||
)
|
||||
return list(key_set)
|
||||
|
||||
@property
|
||||
def record_keys_writable(self) -> list[str]:
|
||||
"""Returns the keys of all fields in the data records that are writable of all enabled providers."""
|
||||
key_set = set(
|
||||
chain.from_iterable(
|
||||
provider.record_keys_writable for provider in self.enabled_providers
|
||||
)
|
||||
)
|
||||
return list(key_set)
|
||||
|
||||
def __getitem__(self, key: str) -> pd.Series:
|
||||
"""Retrieve a Pandas Series for a specified key from the data in each DataProvider.
|
||||
|
||||
@@ -1206,9 +1616,7 @@ class DataContainer(SingletonMixin, DataBase, MutableMapping):
|
||||
Returns:
|
||||
Iterator[str]: An iterator over the unique keys from all providers.
|
||||
"""
|
||||
return iter(
|
||||
set(chain.from_iterable(provider.record_keys for provider in self.enabled_providers))
|
||||
)
|
||||
return iter(self.record_keys)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of keys in the container.
|
||||
@@ -1216,9 +1624,7 @@ class DataContainer(SingletonMixin, DataBase, MutableMapping):
|
||||
Returns:
|
||||
int: The total number of keys in this container.
|
||||
"""
|
||||
return len(
|
||||
list(chain.from_iterable(provider.record_keys for provider in self.enabled_providers))
|
||||
)
|
||||
return len(self.record_keys)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Provide a string representation of the DataContainer instance.
|
||||
@@ -1242,6 +1648,48 @@ class DataContainer(SingletonMixin, DataBase, MutableMapping):
|
||||
for provider in self.enabled_providers:
|
||||
provider.update_data(force_enable=force_enable, force_update=force_update)
|
||||
|
||||
def key_to_series(
|
||||
self,
|
||||
key: str,
|
||||
start_datetime: Optional[DateTime] = None,
|
||||
end_datetime: Optional[DateTime] = None,
|
||||
dropna: Optional[bool] = None,
|
||||
) -> pd.Series:
|
||||
"""Extract a series indexed by the date_time field from data records within an optional date range.
|
||||
|
||||
Iterates through providers to find and return the first available series for the specified key.
|
||||
|
||||
Args:
|
||||
key (str): The field name in the DataRecord from which to extract values.
|
||||
start_datetime (datetime, optional): The start date for filtering the records (inclusive).
|
||||
end_datetime (datetime, optional): The end date for filtering the records (exclusive).
|
||||
dropna: (bool, optional): Whether to drop NAN/ None values before processing. Defaults to True.
|
||||
|
||||
Returns:
|
||||
pd.Series: A Pandas Series with the index as the date_time of each record
|
||||
and the values extracted from the specified key.
|
||||
|
||||
Raises:
|
||||
KeyError: If the specified key is not found in any of the DataRecords.
|
||||
"""
|
||||
series = None
|
||||
for provider in self.enabled_providers:
|
||||
try:
|
||||
series = provider.key_to_series(
|
||||
key,
|
||||
start_datetime=start_datetime,
|
||||
end_datetime=end_datetime,
|
||||
dropna=dropna,
|
||||
)
|
||||
break
|
||||
except KeyError:
|
||||
continue
|
||||
|
||||
if series is None:
|
||||
raise KeyError(f"No data found for key '{key}'.")
|
||||
|
||||
return series
|
||||
|
||||
def key_to_array(
|
||||
self,
|
||||
key: str,
|
||||
|
@@ -1,68 +1,43 @@
|
||||
"""Module for managing and serializing Pydantic-based models with custom support.
|
||||
|
||||
This module introduces the `PydanticBaseModel` class, which extends Pydantic’s `BaseModel` to facilitate
|
||||
custom serialization and deserialization for `pendulum.DateTime` objects. The main features include
|
||||
automatic handling of `pendulum.DateTime` fields, custom serialization to ISO 8601 format, and utility
|
||||
methods for converting model instances to and from dictionary and JSON formats.
|
||||
This module provides classes that extend Pydantic’s functionality to include robust handling
|
||||
of `pendulum.DateTime` fields, offering seamless serialization and deserialization into ISO 8601 format.
|
||||
These enhancements facilitate the use of Pydantic models in applications requiring timezone-aware
|
||||
datetime fields and consistent data serialization.
|
||||
|
||||
Key Classes:
|
||||
- PendulumDateTime: A custom type adapter that provides serialization and deserialization
|
||||
functionality for `pendulum.DateTime` objects, converting them to ISO 8601 strings and back.
|
||||
- PydanticBaseModel: A base model class for handling prediction records or configuration data
|
||||
with automatic Pendulum DateTime handling and additional methods for JSON and dictionary
|
||||
conversion.
|
||||
|
||||
Classes:
|
||||
PendulumDateTime(TypeAdapter[pendulum.DateTime]): Type adapter for `pendulum.DateTime` fields
|
||||
with ISO 8601 serialization. Includes:
|
||||
- serialize: Converts `pendulum.DateTime` instances to ISO 8601 string.
|
||||
- deserialize: Converts ISO 8601 strings to `pendulum.DateTime` instances.
|
||||
- is_iso8601: Validates if a string matches the ISO 8601 date format.
|
||||
|
||||
PydanticBaseModel(BaseModel): Extends `pydantic.BaseModel` to handle `pendulum.DateTime` fields
|
||||
and adds convenience methods for dictionary and JSON serialization. Key methods:
|
||||
- model_dump: Dumps the model, converting `pendulum.DateTime` fields to ISO 8601.
|
||||
- model_construct: Constructs a model instance with automatic deserialization of
|
||||
`pendulum.DateTime` fields from ISO 8601.
|
||||
- to_dict: Serializes the model instance to a dictionary.
|
||||
- from_dict: Constructs a model instance from a dictionary.
|
||||
- to_json: Converts the model instance to a JSON string.
|
||||
- from_json: Creates a model instance from a JSON string.
|
||||
|
||||
Usage Example:
|
||||
# Define custom settings in a model using PydanticBaseModel
|
||||
class PredictionCommonSettings(PydanticBaseModel):
|
||||
prediction_start: pendulum.DateTime = Field(...)
|
||||
|
||||
# Serialize a model instance to a dictionary or JSON
|
||||
config = PredictionCommonSettings(prediction_start=pendulum.now())
|
||||
config_dict = config.to_dict()
|
||||
config_json = config.to_json()
|
||||
|
||||
# Deserialize from dictionary or JSON
|
||||
new_config = PredictionCommonSettings.from_dict(config_dict)
|
||||
restored_config = PredictionCommonSettings.from_json(config_json)
|
||||
|
||||
Dependencies:
|
||||
- `pendulum`: Required for handling timezone-aware datetime fields.
|
||||
- `pydantic`: Required for model and validation functionality.
|
||||
|
||||
Notes:
|
||||
- This module enables custom handling of Pendulum DateTime fields within Pydantic models,
|
||||
which is particularly useful for applications requiring consistent ISO 8601 datetime formatting
|
||||
and robust timezone-aware datetime support.
|
||||
Key Features:
|
||||
- Custom type adapter for `pendulum.DateTime` fields with automatic serialization to ISO 8601 strings.
|
||||
- Utility methods for converting models to and from dictionaries and JSON strings.
|
||||
- Validation tools for maintaining data consistency, including specialized support for
|
||||
pandas DataFrames and Series with datetime indexes.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Type
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import pandas as pd
|
||||
import pendulum
|
||||
from pydantic import BaseModel, ConfigDict, TypeAdapter
|
||||
from pandas.api.types import is_datetime64_any_dtype
|
||||
from pydantic import (
|
||||
AwareDatetime,
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
RootModel,
|
||||
TypeAdapter,
|
||||
ValidationError,
|
||||
ValidationInfo,
|
||||
field_validator,
|
||||
)
|
||||
|
||||
from akkudoktoreos.utils.datetimeutil import to_datetime, to_duration
|
||||
|
||||
|
||||
# Custom type adapter for Pendulum DateTime fields
|
||||
class PendulumDateTime(TypeAdapter[pendulum.DateTime]):
|
||||
class PydanticTypeAdapterDateTime(TypeAdapter[pendulum.DateTime]):
|
||||
"""Custom type adapter for Pendulum DateTime fields."""
|
||||
|
||||
@classmethod
|
||||
def serialize(cls, value: Any) -> str:
|
||||
"""Convert pendulum.DateTime to ISO 8601 string."""
|
||||
@@ -105,41 +80,69 @@ class PydanticBaseModel(BaseModel):
|
||||
validate_assignment=True,
|
||||
)
|
||||
|
||||
@field_validator("*", mode="before")
|
||||
def validate_and_convert_pendulum(cls, value: Any, info: ValidationInfo) -> Any:
|
||||
"""Validator to convert fields of type `pendulum.DateTime`.
|
||||
|
||||
Converts fields to proper `pendulum.DateTime` objects, ensuring correct input types.
|
||||
|
||||
This method is invoked for every field before the field value is set. If the field's type
|
||||
is `pendulum.DateTime`, it tries to convert string or timestamp values to `pendulum.DateTime`
|
||||
objects. If the value cannot be converted, a validation error is raised.
|
||||
|
||||
Args:
|
||||
value: The value to be assigned to the field.
|
||||
info: Validation information for the field.
|
||||
|
||||
Returns:
|
||||
The converted value, if successful.
|
||||
|
||||
Raises:
|
||||
ValidationError: If the value cannot be converted to `pendulum.DateTime`.
|
||||
"""
|
||||
# Get the field name and expected type
|
||||
field_name = info.field_name
|
||||
expected_type = cls.model_fields[field_name].annotation
|
||||
|
||||
# Convert
|
||||
if expected_type is pendulum.DateTime or expected_type is AwareDatetime:
|
||||
try:
|
||||
value = to_datetime(value)
|
||||
except:
|
||||
pass
|
||||
return value
|
||||
|
||||
# Override Pydantic’s serialization for all DateTime fields
|
||||
def model_dump(self, *args: Any, **kwargs: Any) -> dict:
|
||||
"""Custom dump method to handle serialization for DateTime fields."""
|
||||
result = super().model_dump(*args, **kwargs)
|
||||
for key, value in result.items():
|
||||
if isinstance(value, pendulum.DateTime):
|
||||
result[key] = PendulumDateTime.serialize(value)
|
||||
result[key] = PydanticTypeAdapterDateTime.serialize(value)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def model_construct(cls, data: dict) -> "PydanticBaseModel":
|
||||
def model_construct(
|
||||
cls, _fields_set: set[str] | None = None, **values: Any
|
||||
) -> "PydanticBaseModel":
|
||||
"""Custom constructor to handle deserialization for DateTime fields."""
|
||||
for key, value in data.items():
|
||||
if isinstance(value, str) and PendulumDateTime.is_iso8601(value):
|
||||
data[key] = PendulumDateTime.deserialize(value)
|
||||
return super().model_construct(data)
|
||||
for key, value in values.items():
|
||||
if isinstance(value, str) and PydanticTypeAdapterDateTime.is_iso8601(value):
|
||||
values[key] = PydanticTypeAdapterDateTime.deserialize(value)
|
||||
return super().model_construct(_fields_set, **values)
|
||||
|
||||
def reset_optional(self) -> "PydanticBaseModel":
|
||||
"""Resets all optional fields in the model to None.
|
||||
|
||||
Iterates through all model fields and sets any optional (non-required)
|
||||
fields to None. The modification is done in-place on the current instance.
|
||||
|
||||
Returns:
|
||||
PydanticBaseModel: The current instance with all optional fields
|
||||
reset to None.
|
||||
|
||||
Example:
|
||||
>>> settings = PydanticBaseModel(name="test", optional_field="value")
|
||||
>>> settings.reset_optional()
|
||||
>>> assert settings.optional_field is None
|
||||
"""
|
||||
for field_name, field in self.model_fields.items():
|
||||
if field.is_required is False: # Check if field is optional
|
||||
setattr(self, field_name, None)
|
||||
def reset_to_defaults(self) -> "PydanticBaseModel":
|
||||
"""Resets the fields to their default values."""
|
||||
for field_name, field_info in self.model_fields.items():
|
||||
if field_info.default_factory is not None: # Handle fields with default_factory
|
||||
default_value = field_info.default_factory()
|
||||
else:
|
||||
default_value = field_info.default
|
||||
try:
|
||||
setattr(self, field_name, default_value)
|
||||
except (AttributeError, TypeError, ValidationError):
|
||||
# Skip fields that are read-only or dynamically computed or can not be set to default
|
||||
pass
|
||||
return self
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
@@ -167,40 +170,6 @@ class PydanticBaseModel(BaseModel):
|
||||
"""
|
||||
return cls.model_validate(data)
|
||||
|
||||
@classmethod
|
||||
def from_dict_with_reset(cls, data: dict | None = None) -> "PydanticBaseModel":
|
||||
"""Creates a new instance with reset optional fields, then updates from dict.
|
||||
|
||||
First creates an instance with default values, resets all optional fields
|
||||
to None, then updates the instance with the provided dictionary data if any.
|
||||
|
||||
Args:
|
||||
data (dict | None): Dictionary containing field values to initialize
|
||||
the instance with. Defaults to None.
|
||||
|
||||
Returns:
|
||||
PydanticBaseModel: A new instance with all optional fields initially
|
||||
reset to None and then updated with provided data.
|
||||
|
||||
Example:
|
||||
>>> data = {"name": "test", "optional_field": "value"}
|
||||
>>> settings = PydanticBaseModel.from_dict_with_reset(data)
|
||||
>>> # All non-specified optional fields will be None
|
||||
"""
|
||||
# Create instance with model defaults
|
||||
instance = cls()
|
||||
|
||||
# Reset all optional fields to None
|
||||
instance.reset_optional()
|
||||
|
||||
# Update with provided data if any
|
||||
if data:
|
||||
# Use model_validate to ensure proper type conversion and validation
|
||||
updated_instance = instance.model_validate({**instance.model_dump(), **data})
|
||||
return updated_instance
|
||||
|
||||
return instance
|
||||
|
||||
def to_json(self) -> str:
|
||||
"""Convert the PydanticBaseModel instance to a JSON string.
|
||||
|
||||
@@ -224,3 +193,287 @@ class PydanticBaseModel(BaseModel):
|
||||
"""
|
||||
data = json.loads(json_str)
|
||||
return cls.model_validate(data)
|
||||
|
||||
|
||||
class PydanticDateTimeData(RootModel):
|
||||
"""Pydantic model for time series data with consistent value lengths.
|
||||
|
||||
This model validates a dictionary where:
|
||||
- Keys are strings representing data series names
|
||||
- Values are lists of numeric or string values
|
||||
- Special keys 'start_datetime' and 'interval' can contain string values
|
||||
for time series indexing
|
||||
- All value lists must have the same length
|
||||
|
||||
Example:
|
||||
{
|
||||
"start_datetime": "2024-01-01 00:00:00", # optional
|
||||
"interval": "1 Hour", # optional
|
||||
"load_mean": [20.5, 21.0, 22.1],
|
||||
"load_min": [18.5, 19.0, 20.1]
|
||||
}
|
||||
"""
|
||||
|
||||
root: Dict[str, Union[str, List[Union[float, int, str, None]]]]
|
||||
|
||||
@field_validator("root", mode="after")
|
||||
@classmethod
|
||||
def validate_root(
|
||||
cls, value: Dict[str, Union[str, List[Union[float, int, str, None]]]]
|
||||
) -> Dict[str, Union[str, List[Union[float, int, str, None]]]]:
|
||||
# Validate that all keys are strings
|
||||
if not all(isinstance(k, str) for k in value.keys()):
|
||||
raise ValueError("All keys in the dictionary must be strings.")
|
||||
|
||||
# Validate that no lists contain only None values
|
||||
for v in value.values():
|
||||
if isinstance(v, list) and all(item is None for item in v):
|
||||
raise ValueError("Lists cannot contain only None values.")
|
||||
|
||||
# Validate that all lists have consistent lengths (if they are lists)
|
||||
list_lengths = [len(v) for v in value.values() if isinstance(v, list)]
|
||||
if len(set(list_lengths)) > 1:
|
||||
raise ValueError("All lists in the dictionary must have the same length.")
|
||||
|
||||
# Validate special keys
|
||||
if "start_datetime" in value.keys():
|
||||
value["start_datetime"] = to_datetime(value["start_datetime"])
|
||||
if "interval" in value.keys():
|
||||
value["interval"] = to_duration(value["interval"])
|
||||
|
||||
return value
|
||||
|
||||
def to_dict(self) -> Dict[str, Union[str, List[Union[float, int, str, None]]]]:
|
||||
"""Convert the model to a plain dictionary.
|
||||
|
||||
Returns:
|
||||
Dict containing the validated data.
|
||||
"""
|
||||
return self.root
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "PydanticDateTimeData":
|
||||
"""Create a PydanticDateTimeData instance from a dictionary.
|
||||
|
||||
Args:
|
||||
data: Input dictionary
|
||||
|
||||
Returns:
|
||||
PydanticDateTimeData instance
|
||||
"""
|
||||
return cls(root=data)
|
||||
|
||||
|
||||
class PydanticDateTimeDataFrame(PydanticBaseModel):
|
||||
"""Pydantic model for validating pandas DataFrame data with datetime index."""
|
||||
|
||||
data: Dict[str, Dict[str, Any]]
|
||||
dtypes: Dict[str, str] = Field(default_factory=dict)
|
||||
tz: Optional[str] = Field(default=None, description="Timezone for datetime values")
|
||||
datetime_columns: list[str] = Field(
|
||||
default_factory=lambda: ["date_time"], description="Columns to be treated as datetime"
|
||||
)
|
||||
|
||||
@field_validator("tz")
|
||||
def validate_timezone(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""Validate that the timezone is valid."""
|
||||
if v is not None:
|
||||
try:
|
||||
ZoneInfo(v)
|
||||
except KeyError:
|
||||
raise ValueError(f"Invalid timezone: {v}")
|
||||
return v
|
||||
|
||||
@field_validator("data", mode="before")
|
||||
@classmethod
|
||||
def validate_data(cls, v: Dict[str, Any], info: ValidationInfo) -> Dict[str, Any]:
|
||||
if not v:
|
||||
return v
|
||||
|
||||
# Validate consistent columns
|
||||
columns = set(next(iter(v.values())).keys())
|
||||
if not all(set(row.keys()) == columns for row in v.values()):
|
||||
raise ValueError("All rows must have the same columns")
|
||||
|
||||
# Convert index datetime strings
|
||||
try:
|
||||
d = {
|
||||
to_datetime(dt, as_string=True, in_timezone=info.data.get("tz")): value
|
||||
for dt, value in v.items()
|
||||
}
|
||||
v = d
|
||||
except (ValueError, TypeError) as e:
|
||||
raise ValueError(f"Invalid datetime string in index: {e}")
|
||||
|
||||
# Convert datetime columns
|
||||
datetime_cols = info.data.get("datetime_columns", [])
|
||||
try:
|
||||
for dt_str, value in v.items():
|
||||
for column_name, column_value in value.items():
|
||||
if column_name in datetime_cols and column_value is not None:
|
||||
v[dt_str][column_name] = to_datetime(
|
||||
column_value, in_timezone=info.data.get("tz")
|
||||
)
|
||||
except (ValueError, TypeError) as e:
|
||||
raise ValueError(f"Invalid datetime value in column: {e}")
|
||||
|
||||
return v
|
||||
|
||||
@field_validator("dtypes")
|
||||
@classmethod
|
||||
def validate_dtypes(cls, v: Dict[str, str], info: ValidationInfo) -> Dict[str, str]:
|
||||
if not v:
|
||||
return v
|
||||
|
||||
valid_dtypes = {"int64", "float64", "bool", "datetime64[ns]", "object", "string"}
|
||||
invalid_dtypes = set(v.values()) - valid_dtypes
|
||||
if invalid_dtypes:
|
||||
raise ValueError(f"Unsupported dtypes: {invalid_dtypes}")
|
||||
|
||||
data = info.data.get("data", {})
|
||||
if data:
|
||||
columns = set(next(iter(data.values())).keys())
|
||||
if not all(col in columns for col in v.keys()):
|
||||
raise ValueError("dtype columns must exist in data columns")
|
||||
return v
|
||||
|
||||
def to_dataframe(self) -> pd.DataFrame:
|
||||
"""Convert the validated model data to a pandas DataFrame."""
|
||||
df = pd.DataFrame.from_dict(self.data, orient="index")
|
||||
|
||||
# Convert index to datetime
|
||||
index = pd.Index([to_datetime(dt, in_timezone=self.tz) for dt in df.index])
|
||||
df.index = index
|
||||
|
||||
dtype_mapping = {
|
||||
"int": int,
|
||||
"float": float,
|
||||
"str": str,
|
||||
"bool": bool,
|
||||
}
|
||||
|
||||
# Apply dtypes
|
||||
for col, dtype in self.dtypes.items():
|
||||
if dtype == "datetime64[ns]":
|
||||
df[col] = pd.to_datetime(to_datetime(df[col], in_timezone=self.tz))
|
||||
elif dtype in dtype_mapping.keys():
|
||||
df[col] = df[col].astype(dtype_mapping[dtype])
|
||||
else:
|
||||
pass
|
||||
|
||||
return df
|
||||
|
||||
@classmethod
|
||||
def from_dataframe(
|
||||
cls, df: pd.DataFrame, tz: Optional[str] = None
|
||||
) -> "PydanticDateTimeDataFrame":
|
||||
"""Create a PydanticDateTimeDataFrame instance from a pandas DataFrame."""
|
||||
index = pd.Index([to_datetime(dt, as_string=True, in_timezone=tz) for dt in df.index])
|
||||
df.index = index
|
||||
|
||||
datetime_columns = [col for col in df.columns if is_datetime64_any_dtype(df[col])]
|
||||
|
||||
return cls(
|
||||
data=df.to_dict(orient="index"),
|
||||
dtypes={col: str(dtype) for col, dtype in df.dtypes.items()},
|
||||
tz=tz,
|
||||
datetime_columns=datetime_columns,
|
||||
)
|
||||
|
||||
|
||||
class PydanticDateTimeSeries(PydanticBaseModel):
|
||||
"""Pydantic model for validating pandas Series with datetime index in JSON format.
|
||||
|
||||
This model handles Series data serialized with orient='index', where the keys are
|
||||
datetime strings and values are the series values. Provides validation and
|
||||
conversion between JSON and pandas Series with datetime index.
|
||||
|
||||
Attributes:
|
||||
data (Dict[str, Any]): Dictionary mapping datetime strings to values.
|
||||
dtype (str): The data type of the series values.
|
||||
tz (str | None): Timezone name if the datetime index is timezone-aware.
|
||||
"""
|
||||
|
||||
data: Dict[str, Any]
|
||||
dtype: str = Field(default="float64")
|
||||
tz: Optional[str] = Field(default=None)
|
||||
|
||||
@field_validator("data", mode="after")
|
||||
@classmethod
|
||||
def validate_datetime_index(cls, v: Dict[str, Any], info: ValidationInfo) -> Dict[str, Any]:
|
||||
"""Validate that all keys can be parsed as datetime strings.
|
||||
|
||||
Args:
|
||||
v: Dictionary with datetime string keys and series values.
|
||||
|
||||
Returns:
|
||||
The validated data dictionary.
|
||||
|
||||
Raises:
|
||||
ValueError: If any key cannot be parsed as a datetime.
|
||||
"""
|
||||
tz = info.data.get("tz")
|
||||
if tz is not None:
|
||||
try:
|
||||
ZoneInfo(tz)
|
||||
except KeyError:
|
||||
tz = None
|
||||
try:
|
||||
# Attempt to parse each key as datetime
|
||||
d = dict()
|
||||
for dt_str, value in v.items():
|
||||
d[to_datetime(dt_str, as_string=True, in_timezone=tz)] = value
|
||||
return d
|
||||
except (ValueError, TypeError) as e:
|
||||
raise ValueError(f"Invalid datetime string in index: {e}")
|
||||
|
||||
@field_validator("tz")
|
||||
def validate_timezone(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""Validate that the timezone is valid."""
|
||||
if v is not None:
|
||||
try:
|
||||
ZoneInfo(v)
|
||||
except KeyError:
|
||||
raise ValueError(f"Invalid timezone: {v}")
|
||||
return v
|
||||
|
||||
def to_series(self) -> pd.Series:
|
||||
"""Convert the validated model data to a pandas Series.
|
||||
|
||||
Returns:
|
||||
A pandas Series with datetime index constructed from the model data.
|
||||
"""
|
||||
index = [to_datetime(dt, in_timezone=self.tz) for dt in list(self.data.keys())]
|
||||
|
||||
series = pd.Series(data=list(self.data.values()), index=index, dtype=self.dtype)
|
||||
return series
|
||||
|
||||
@classmethod
|
||||
def from_series(cls, series: pd.Series, tz: Optional[str] = None) -> "PydanticDateTimeSeries":
|
||||
"""Create a PydanticDateTimeSeries instance from a pandas Series.
|
||||
|
||||
Args:
|
||||
series: The pandas Series with datetime index to convert.
|
||||
|
||||
Returns:
|
||||
A new instance containing the Series data.
|
||||
|
||||
Raises:
|
||||
ValueError: If series index is not datetime type.
|
||||
|
||||
Example:
|
||||
>>> dates = pd.date_range('2024-01-01', periods=3)
|
||||
>>> s = pd.Series([1.1, 2.2, 3.3], index=dates)
|
||||
>>> model = PydanticDateTimeSeries.from_series(s)
|
||||
"""
|
||||
index = pd.Index([to_datetime(dt, as_string=True, in_timezone=tz) for dt in series.index])
|
||||
series.index = index
|
||||
|
||||
if len(index) > 0:
|
||||
tz = to_datetime(series.index[0]).timezone.name
|
||||
|
||||
return cls(
|
||||
data=series.to_dict(),
|
||||
dtype=str(series.dtype),
|
||||
tz=tz,
|
||||
)
|
||||
|
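A hedged usage sketch for the new series model: the class and its `from_series`/`to_series`/`to_json`/`from_json` methods are taken from the diff above, while the import path is an assumption based on the commit description.

```python
# Sketch only: round-tripping a pandas Series through PydanticDateTimeSeries.
import pandas as pd
# from akkudoktoreos.core.pydantic import PydanticDateTimeSeries  # assumed path

dates = pd.date_range("2024-01-01", periods=3, tz="Europe/Berlin")
series = pd.Series([1.1, 2.2, 3.3], index=dates)

# model = PydanticDateTimeSeries.from_series(series)
# json_payload = model.to_json()                     # JSON-safe for the new v1 API
# restored = PydanticDateTimeSeries.from_json(json_payload).to_series()
```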