fix: logging, prediction update, multiple bugs (#584)

* Fix logging configuration issues that caused logging to stop working. Switch from Python's
  standard logging to Loguru. Enable console and file logging with different log levels.
  Add logging documentation.
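  A minimal sketch of the kind of dual-sink setup this enables (illustrative values, not the
  exact EOS defaults):

      import sys
      from loguru import logger

      logger.remove()  # drop Loguru's default stderr handler
      # Console sink and file sink can run at different log levels.
      logger.add(sys.stderr, level="INFO", enqueue=True, backtrace=True)
      logger.add("eos.log", level="DEBUG", rotation="100 MB", retention="3 days",
                 serialize=True, enqueue=True)  # JSON lines, rotated and pruned
      logger.info("console and file logging active")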

* Fix the logging configuration and the EOS configuration getting out of sync. Add tracking
  support for nested value updates of Pydantic models. This is used to update the logging
  configuration whenever the EOS configuration for logging changes. It keeps the logging config
  and the EOS config in sync as long as all changes to the EOS logging configuration go through
  set_nested_value(), which is the case for the REST API.
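  A sketch of the tracking hook, following the mixin's docstring example (the EOS logging code
  registers a similar callback for the "logging" subtree):

      from akkudoktoreos.core.pydantic import PydanticBaseModel

      class Address(PydanticBaseModel):
          city: str

      class User(PydanticBaseModel):
          name: str
          address: Address

      def on_change(model, set_path, old_value, new_value):
          # e.g. reconfigure the Loguru handlers when a logging value changes
          print(f"{set_path}: {old_value} -> {new_value}")

      user = User(name="Alice", address=Address(city="NY"))
      user.track_nested_value("address/city", on_change)
      user.set_nested_value("address/city", "LA")  # triggers on_change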

* Fix the energy management task looping endlessly after the second update when trying to
  update the last_update datetime.
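  The underlying issue is that pendulum datetimes are immutable, so add() has to be re-assigned
  (pattern sketch, not the exact EOS code):

      import pendulum

      last = pendulum.datetime(2025, 6, 10, 12, 0)
      last.add(seconds=300)         # returns a new object; last itself never advances
      last = last.add(seconds=300)  # correct: keep the advanced timestamp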

* Fix get_nested_value() to correctly read values from dicts held by a Pydantic model instance.
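  For example (hypothetical model, using the mixin's '/'-separated path syntax):

      from akkudoktoreos.core.pydantic import PydanticBaseModel

      class Settings(PydanticBaseModel):
          options: dict[str, str] = {}

      s = Settings(options={"theme": "dark"})
      s.get_nested_value("options/theme")  # returns "dark" after the fix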

* Fix nested value access to use model classes instead of model instances when evaluating the
  value type associated with each key.
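  The field type information needed for this lookup lives on the model class, not on the
  instance (sketch, Pydantic v2):

      from akkudoktoreos.core.pydantic import PydanticBaseModel

      class Address(PydanticBaseModel):
          city: str

      addr = Address(city="NY")
      Address.model_fields["city"].annotation  # <class 'str'> - defined on the class
      type(addr) is Address                    # True - the class, not the instance, is passed on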

* Fix invalid JSON in the prediction documentation for the PVForecastAkkudoktor provider.

* Fix documentation quirks and add EOS Connect to the integrations.

* Support deprecated configuration fields in documentation generation and EOSdash.
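  Deprecated settings are declared via Pydantic's Field(deprecated=...), as in the logging
  settings shown further below (abbreviated sketch):

      from typing import Optional
      from pydantic import Field
      from akkudoktoreos.config.configabc import SettingsBaseModel

      class LoggingCommonSettings(SettingsBaseModel):
          level: Optional[str] = Field(
              default=None,
              description="EOS default logging level.",
              deprecated="This is deprecated. Use console_level and file_level instead.",
          )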

* Enhance the EOSdash demo to show BrightSky humidity data (which is often missing).

* Update documentation reference to German EOS installation videos.

Signed-off-by: Bobby Noelte <b0661n0e17e@gmail.com>
Bobby Noelte, committed by GitHub, 2025-06-10 22:00:28 +02:00
commit bd38b3c5ef (parent 9d46f3c08e)
70 changed files with 5927 additions and 5035 deletions

View File

@@ -27,17 +27,14 @@ from typing import (
)
import cachebox
from loguru import logger
from pendulum import DateTime, Duration
from pydantic import Field
from akkudoktoreos.core.coreabc import ConfigMixin, SingletonMixin
from akkudoktoreos.core.logging import get_logger
from akkudoktoreos.core.pydantic import PydanticBaseModel
from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime, to_duration
logger = get_logger(__name__)
# ---------------------------------
# In-Memory Caching Functionality
# ---------------------------------

View File

@@ -13,13 +13,10 @@ Classes:
import threading
from typing import Any, ClassVar, Dict, Optional, Type
from loguru import logger
from pendulum import DateTime
from pydantic import computed_field
from akkudoktoreos.core.logging import get_logger
logger = get_logger(__name__)
config_eos: Any = None
measurement_eos: Any = None
prediction_eos: Any = None

View File

@@ -19,6 +19,7 @@ from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union, over
import numpy as np
import pandas as pd
import pendulum
from loguru import logger
from numpydantic import NDArray, Shape
from pendulum import DateTime, Duration
from pydantic import (
@@ -31,7 +32,6 @@ from pydantic import (
)
from akkudoktoreos.core.coreabc import ConfigMixin, SingletonMixin, StartMixin
from akkudoktoreos.core.logging import get_logger
from akkudoktoreos.core.pydantic import (
PydanticBaseModel,
PydanticDateTimeData,
@@ -39,8 +39,6 @@ from akkudoktoreos.core.pydantic import (
)
from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime, to_duration
logger = get_logger(__name__)
class DataBase(ConfigMixin, StartMixin, PydanticBaseModel):
"""Base class for handling generic data.

View File

@@ -1,10 +1,6 @@
from collections.abc import Callable
from typing import Any, Optional
from akkudoktoreos.core.logging import get_logger
logger = get_logger(__name__)
class classproperty:
"""A decorator to define a read-only property at the class level.

View File

@@ -2,6 +2,7 @@ import traceback
from typing import Any, ClassVar, Optional
import numpy as np
from loguru import logger
from numpydantic import NDArray, Shape
from pendulum import DateTime
from pydantic import ConfigDict, Field, computed_field, field_validator, model_validator
@@ -9,7 +10,6 @@ from typing_extensions import Self
from akkudoktoreos.core.cache import CacheUntilUpdateStore
from akkudoktoreos.core.coreabc import ConfigMixin, PredictionMixin, SingletonMixin
from akkudoktoreos.core.logging import get_logger
from akkudoktoreos.core.pydantic import ParametersBaseModel, PydanticBaseModel
from akkudoktoreos.devices.battery import Battery
from akkudoktoreos.devices.generic import HomeAppliance
@@ -17,8 +17,6 @@ from akkudoktoreos.devices.inverter import Inverter
from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime
from akkudoktoreos.utils.utils import NumpyEncoder
logger = get_logger(__name__)
class EnergyManagementParameters(ParametersBaseModel):
pv_prognose_wh: list[float] = Field(
@@ -283,6 +281,8 @@ class EnergyManagement(SingletonMixin, ConfigMixin, PredictionMixin, PydanticBas
self.prediction.update_data(force_enable=force_enable, force_update=force_update)
# TODO: Create optimisation problem that calls into devices.update_data() for simulations.
logger.info("Energy management run (crippled version - prediction update only)")
def manage_energy(self) -> None:
"""Repeating task for managing energy.
@@ -302,6 +302,7 @@ class EnergyManagement(SingletonMixin, ConfigMixin, PredictionMixin, PydanticBas
Note: The task maintains the interval even if some intervals are missed.
"""
current_datetime = to_datetime()
interval = self.config.ems.interval # interval maybe changed in between
if EnergyManagement._last_datetime is None:
# Never run before
@@ -316,13 +317,13 @@ class EnergyManagement(SingletonMixin, ConfigMixin, PredictionMixin, PydanticBas
logger.error(message)
return
if self.config.ems.interval is None or self.config.ems.interval == float("nan"):
if interval is None or interval == float("nan"):
# No Repetition
return
if (
compare_datetimes(current_datetime, self._last_datetime).time_diff
< self.config.ems.interval
compare_datetimes(current_datetime, EnergyManagement._last_datetime).time_diff
< interval
):
# Wait for next run
return
@@ -337,9 +338,9 @@ class EnergyManagement(SingletonMixin, ConfigMixin, PredictionMixin, PydanticBas
# Remember the energy management run - keep on interval even if we missed some intervals
while (
compare_datetimes(current_datetime, EnergyManagement._last_datetime).time_diff
>= self.config.ems.interval
>= interval
):
EnergyManagement._last_datetime.add(seconds=self.config.ems.interval)
EnergyManagement._last_datetime = EnergyManagement._last_datetime.add(seconds=interval)
def set_start_hour(self, start_hour: Optional[int] = None) -> None:
"""Sets start datetime to given hour.

View File

@@ -1,20 +1,3 @@
"""Abstract and base classes for logging."""
import logging
def logging_str_to_level(level_str: str) -> int:
"""Convert log level string to logging level."""
if level_str == "DEBUG":
level = logging.DEBUG
elif level_str == "INFO":
level = logging.INFO
elif level_str == "WARNING":
level = logging.WARNING
elif level_str == "CRITICAL":
level = logging.CRITICAL
elif level_str == "ERROR":
level = logging.ERROR
else:
raise ValueError(f"Unknown loggin level: {level_str}")
return level
LOGGING_LEVELS: list[str] = ["TRACE", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]

View File

@@ -1,95 +1,241 @@
"""Utility functions for handling logging tasks.
Functions:
----------
- get_logger: Creates and configures a logger with console and optional rotating file logging.
Example usage:
--------------
# Logger setup
>>> logger = get_logger(__name__, log_file="app.log", logging_level="DEBUG")
>>> logger.info("Logging initialized.")
Notes:
------
- The logger supports rotating log files to prevent excessive log file size.
"""
"""Utility for configuring Loguru loggers."""
import json
import logging as pylogging
import os
from logging.handlers import RotatingFileHandler
from typing import Optional
import re
import sys
from pathlib import Path
from types import FrameType
from typing import Any, List, Optional
from akkudoktoreos.core.logabc import logging_str_to_level
import pendulum
from loguru import logger
from akkudoktoreos.core.logabc import LOGGING_LEVELS
def get_logger(
name: str,
log_file: Optional[str] = None,
logging_level: Optional[str] = None,
max_bytes: int = 5000000,
backup_count: int = 5,
) -> pylogging.Logger:
"""Creates and configures a logger with a given name.
class InterceptHandler(pylogging.Handler):
"""A logging handler that redirects standard Python logging messages to Loguru.
The logger supports logging to both the console and an optional log file. File logging is
handled by a rotating file handler to prevent excessive log file size.
This handler ensures consistency between the `logging` module and Loguru by intercepting
logs sent to the standard logging system and re-emitting them through Loguru with proper
formatting and context (including exception info and call depth).
Attributes:
loglevel_mapping (dict): Mapping from standard logging levels to Loguru level names.
"""
loglevel_mapping: dict[int, str] = {
50: "CRITICAL",
40: "ERROR",
30: "WARNING",
20: "INFO",
10: "DEBUG",
5: "TRACE",
0: "NOTSET",
}
def emit(self, record: pylogging.LogRecord) -> None:
"""Emits a logging record by forwarding it to Loguru with preserved metadata.
Args:
record (logging.LogRecord): A record object containing log message and metadata.
"""
try:
level = logger.level(record.levelname).name
except AttributeError:
level = self.loglevel_mapping.get(record.levelno, "INFO")
frame: Optional[FrameType] = pylogging.currentframe()
depth: int = 2
while frame and frame.f_code.co_filename == pylogging.__file__:
frame = frame.f_back
depth += 1
log = logger.bind(request_id="app")
log.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())
console_handler_id = None
file_handler_id = None
def track_logging_config(config_eos: Any, path: str, old_value: Any, value: Any) -> None:
"""Track logging config changes."""
global console_handler_id, file_handler_id
if not path.startswith("logging"):
raise ValueError(f"Logging shall not track '{path}'")
if not config_eos.logging.console_level:
# No value given - check environment value - may also be None
config_eos.logging.console_level = os.getenv("EOS_LOGGING__LEVEL")
if not config_eos.logging.file_level:
# No value given - check environment value - may also be None
config_eos.logging.file_level = os.getenv("EOS_LOGGING__LEVEL")
# Remove handlers
if console_handler_id:
try:
logger.remove(console_handler_id)
except Exception as e:
logger.debug("Exception on logger.remove: {}", e, exc_info=True)
console_handler_id = None
if file_handler_id:
try:
logger.remove(file_handler_id)
except Exception as e:
logger.debug("Exception on logger.remove: {}", e, exc_info=True)
file_handler_id = None
# Create handlers with new configuration
# Always add console handler
if config_eos.logging.console_level not in LOGGING_LEVELS:
logger.error(
f"Invalid console log level '{config_eos.logging.console_level} - forced to INFO'."
)
config_eos.logging.console_level = "INFO"
console_handler_id = logger.add(
sys.stderr,
enqueue=True,
backtrace=True,
level=config_eos.logging.console_level,
# format=_console_format
)
# Add file handler
if config_eos.logging.file_level and config_eos.logging.file_path:
if config_eos.logging.file_level not in LOGGING_LEVELS:
logger.error(
f"Invalid file log level '{config_eos.logging.console_level}' - forced to INFO."
)
config_eos.logging.file_level = "INFO"
file_handler_id = logger.add(
sink=config_eos.logging.file_path,
rotation="100 MB",
retention="3 days",
enqueue=True,
backtrace=True,
level=config_eos.logging.file_level,
serialize=True, # JSON dict formatting
# format=_file_format
)
# Redirect standard logging to Loguru
pylogging.basicConfig(handlers=[InterceptHandler()], level=0)
# Redirect uvicorn and fastapi logging to Loguru
pylogging.getLogger("uvicorn.access").handlers = [InterceptHandler()]
for pylogger_name in ["uvicorn", "uvicorn.error", "fastapi"]:
pylogger = pylogging.getLogger(pylogger_name)
pylogger.handlers = [InterceptHandler()]
pylogger.propagate = False
logger.info(
f"Logger reconfigured - console: {config_eos.logging.console_level}, file: {config_eos.logging.file_level}."
)
def read_file_log(
log_path: Path,
limit: int = 100,
level: Optional[str] = None,
contains: Optional[str] = None,
regex: Optional[str] = None,
from_time: Optional[str] = None,
to_time: Optional[str] = None,
tail: bool = False,
) -> List[dict]:
"""Read and filter structured log entries from a JSON-formatted log file.
Args:
name (str): The name of the logger, typically `__name__` from the calling module.
log_file (Optional[str]): Path to the log file for file logging. If None, no file logging is done.
logging_level (Optional[str]): Logging level (e.g., "INFO", "DEBUG"). Defaults to "INFO".
max_bytes (int): Maximum size in bytes for log file before rotation. Defaults to 5 MB.
backup_count (int): Number of backup log files to keep. Defaults to 5.
log_path (Path): Path to the JSON-formatted log file.
limit (int, optional): Maximum number of log entries to return. Defaults to 100.
level (Optional[str], optional): Filter logs by log level (e.g., "INFO", "ERROR"). Defaults to None.
contains (Optional[str], optional): Filter logs that contain this substring in their message. Case-insensitive. Defaults to None.
regex (Optional[str], optional): Filter logs whose message matches this regular expression. Defaults to None.
from_time (Optional[str], optional): ISO 8601 datetime string to filter logs not earlier than this time. Defaults to None.
to_time (Optional[str], optional): ISO 8601 datetime string to filter logs not later than this time. Defaults to None.
tail (bool, optional): If True, read the last lines of the file (like `tail -n`). Defaults to False.
Returns:
logging.Logger: Configured logger instance.
List[dict]: A list of filtered log entries as dictionaries.
Example:
logger = get_logger(__name__, log_file="app.log", logging_level="DEBUG")
logger.info("Application started")
Raises:
FileNotFoundError: If the log file does not exist.
ValueError: If the datetime strings are invalid or improperly formatted.
Exception: For other unforeseen I/O or parsing errors.
"""
# Create a logger with the specified name
logger = pylogging.getLogger(name)
logger.propagate = True
# This is already supported by pydantic-settings in LoggingCommonSettings, however in case
# loading the config itself fails and to set the level before we load the config, we set it here manually.
if logging_level is None and (env_level := os.getenv("EOS_LOGGING__LEVEL")) is not None:
logging_level = env_level
if logging_level is not None:
level = logging_str_to_level(logging_level)
logger.setLevel(level)
if not log_path.exists():
raise FileNotFoundError("Log file not found")
# The log message format
formatter = pylogging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
try:
from_dt = pendulum.parse(from_time) if from_time else None
to_dt = pendulum.parse(to_time) if to_time else None
except Exception as e:
raise ValueError(f"Invalid date/time format: {e}")
# Prevent loggers from being added multiple times
# There may already be a logger from pytest
if not logger.handlers:
# Create a console handler with a standard output stream
console_handler = pylogging.StreamHandler()
if logging_level is not None:
console_handler.setLevel(level)
console_handler.setFormatter(formatter)
regex_pattern = re.compile(regex) if regex else None
# Add the console handler to the logger
logger.addHandler(console_handler)
def matches_filters(log: dict) -> bool:
if level and log.get("level", {}).get("name") != level.upper():
return False
if contains and contains.lower() not in log.get("message", "").lower():
return False
if regex_pattern and not regex_pattern.search(log.get("message", "")):
return False
if from_dt or to_dt:
try:
log_time = pendulum.parse(log["time"])
except Exception:
return False
if from_dt and log_time < from_dt:
return False
if to_dt and log_time > to_dt:
return False
return True
if log_file and len(logger.handlers) < 2: # We assume a console logger to be the first logger
# If a log file path is specified, create a rotating file handler
matched_logs = []
lines: list[str] = []
# Ensure the log directory exists
log_dir = os.path.dirname(log_file)
if log_dir and not os.path.exists(log_dir):
os.makedirs(log_dir)
if tail:
with log_path.open("rb") as f:
f.seek(0, 2)
end = f.tell()
buffer = bytearray()
pointer = end
# Create a rotating file handler
file_handler = RotatingFileHandler(log_file, maxBytes=max_bytes, backupCount=backup_count)
if logging_level is not None:
file_handler.setLevel(level)
file_handler.setFormatter(formatter)
while pointer > 0 and len(lines) < limit * 5:
pointer -= 1
f.seek(pointer)
byte = f.read(1)
if byte == b"\n":
if buffer:
line = buffer[::-1].decode("utf-8", errors="ignore")
lines.append(line)
buffer.clear()
else:
buffer.append(byte[0])
if buffer:
line = buffer[::-1].decode("utf-8", errors="ignore")
lines.append(line)
lines = lines[::-1]
else:
with log_path.open("r", encoding="utf-8", newline=None) as f_txt:
lines = f_txt.readlines()
# Add the file handler to the logger
logger.addHandler(file_handler)
for line in lines:
if not line.strip():
continue
try:
log = json.loads(line)
except json.JSONDecodeError:
continue
if matches_filters(log):
matched_logs.append(log)
if len(matched_logs) >= limit:
break
return logger
return matched_logs

View File

@@ -3,13 +3,13 @@
Kept in an extra module to avoid cyclic dependencies on package import.
"""
import logging
from pathlib import Path
from typing import Optional
from pydantic import Field, computed_field, field_validator
from akkudoktoreos.config.configabc import SettingsBaseModel
from akkudoktoreos.core.logabc import logging_str_to_level
from akkudoktoreos.core.logabc import LOGGING_LEVELS
class LoggingCommonSettings(SettingsBaseModel):
@@ -17,27 +17,47 @@ class LoggingCommonSettings(SettingsBaseModel):
level: Optional[str] = Field(
default=None,
description="EOS default logging level.",
examples=["INFO", "DEBUG", "WARNING", "ERROR", "CRITICAL"],
deprecated="This is deprecated. Use console_level and file_level instead.",
)
# Validators
@field_validator("level", mode="after")
@classmethod
def set_default_logging_level(cls, value: Optional[str]) -> Optional[str]:
if isinstance(value, str) and value.upper() == "NONE":
value = None
if value is None:
return None
level = logging_str_to_level(value)
logging.getLogger().setLevel(level)
return value
console_level: Optional[str] = Field(
default=None,
description="Logging level when logging to console.",
examples=LOGGING_LEVELS,
)
file_level: Optional[str] = Field(
default=None,
description="Logging level when logging to file.",
examples=LOGGING_LEVELS,
)
# Computed fields
@computed_field # type: ignore[prop-decorator]
@property
def root_level(self) -> str:
"""Root logger logging level."""
level = logging.getLogger().getEffectiveLevel()
level_name = logging.getLevelName(level)
return level_name
def file_path(self) -> Optional[Path]:
"""Computed log file path based on data output path."""
try:
path = SettingsBaseModel.config.general.data_output_path / "eos.log"
except:
# Config may not be fully set up
path = None
return path
# Validators
@field_validator("console_level", "file_level", mode="after")
@classmethod
def validate_level(cls, value: Optional[str]) -> Optional[str]:
"""Validate logging level string."""
if value is None:
# Nothing to set
return None
if isinstance(value, str):
level = value.upper()
if level == "NONE":
return None
if level not in LOGGING_LEVELS:
raise ValueError(f"Logging level {value} not supported")
value = level
else:
raise TypeError(f"Invalid {type(value)} of logging level {value}")
return value

View File

@@ -12,20 +12,35 @@ Key Features:
pandas DataFrames and Series with datetime indexes.
"""
import inspect
import json
import re
import uuid
import weakref
from copy import deepcopy
from typing import Any, Dict, List, Optional, Type, Union, get_args, get_origin
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Type,
Union,
get_args,
get_origin,
)
from zoneinfo import ZoneInfo
import pandas as pd
import pendulum
from loguru import logger
from pandas.api.types import is_datetime64_any_dtype
from pydantic import (
AwareDatetime,
BaseModel,
ConfigDict,
Field,
PrivateAttr,
RootModel,
TypeAdapter,
ValidationError,
@@ -35,6 +50,10 @@ from pydantic import (
from akkudoktoreos.utils.datetimeutil import to_datetime, to_duration
# Global weakref dictionary to hold external state per model instance
# Used as a workaround for PrivateAttr not working in e.g. Mixin Classes
_model_private_state: "weakref.WeakKeyDictionary[Union[PydanticBaseModel, PydanticModelNestedValueMixin], Dict[str, Any]]" = weakref.WeakKeyDictionary()
def merge_models(source: BaseModel, update_dict: dict[str, Any]) -> dict[str, Any]:
def deep_update(source_dict: dict[str, Any], update_dict: dict[str, Any]) -> dict[str, Any]:
@@ -83,13 +102,164 @@ class PydanticTypeAdapterDateTime(TypeAdapter[pendulum.DateTime]):
class PydanticModelNestedValueMixin:
"""A mixin providing methods to get and set nested values within a Pydantic model.
"""A mixin providing methods to get, set and track nested values within a Pydantic model.
The methods use a '/'-separated path to denote the nested values.
Supports handling `Optional`, `List`, and `Dict` types, ensuring correct initialization of
missing attributes.
Example:
class Address(PydanticBaseModel):
city: str
class User(PydanticBaseModel):
name: str
address: Address
def on_city_change(old, new, path):
print(f"{path}: {old} -> {new}")
user = User(name="Alice", address=Address(city="NY"))
user.track_nested_value("address/city", on_city_change)
user.set_nested_value("address/city", "LA") # triggers callback
"""
def track_nested_value(self, path: str, callback: Callable[[Any, str, Any, Any], None]) -> None:
"""Register a callback for a specific path (or subtree).
Callback triggers if set path is equal or deeper.
Args:
path (str): '/'-separated path to track.
callback (callable): Function called as callback(model_instance, set_path, old_value, new_value).
"""
try:
self._validate_path_structure(path)
pass
except:
raise ValueError(f"Path '{path}' is invalid")
path = path.strip("/")
# Use private data workaround
# Should be:
# _nested_value_callbacks: dict[str, list[Callable[[str, Any, Any], None]]]
# = PrivateAttr(default_factory=dict)
nested_value_callbacks = get_private_attr(self, "nested_value_callbacks", dict())
if path not in nested_value_callbacks:
nested_value_callbacks[path] = []
nested_value_callbacks[path].append(callback)
set_private_attr(self, "nested_value_callbacks", nested_value_callbacks)
logger.debug("Nested value callbacks {}", nested_value_callbacks)
def _validate_path_structure(self, path: str) -> None:
"""Validate that a '/'-separated path is structurally valid for this model.
Checks that each segment of the path corresponds to a field or index in the model's type structure,
without requiring that all intermediate values are currently initialized. This method is intended
to ensure that the path could be valid for nested access or assignment, according to the model's
class definition.
Args:
path (str): The '/'-separated attribute/index path to validate (e.g., "address/city" or "items/0/value").
Raises:
ValueError: If any segment of the path does not correspond to a valid field in the model,
or an invalid transition is made (such as an attribute on a non-model).
Example:
class Address(PydanticBaseModel):
city: str
class User(PydanticBaseModel):
name: str
address: Address
user = User(name="Alice", address=Address(city="NY"))
user._validate_path_structure("address/city") # OK
user._validate_path_structure("address/zipcode") # Raises ValueError
"""
path_elements = path.strip("/").split("/")
# The model we are currently working on
model: Any = self
# The model we get the type information from. It is a pydantic BaseModel
parent: BaseModel = model
# The field that provides type information for the current key
# Fields may have nested types that translates to a sequence of keys, not just one
# - my_field: Optional[list[OtherModel]] -> e.g. "myfield/0" for index 0
# parent_key = ["myfield",] ... ["myfield", "0"]
# parent_key_types = [list, OtherModel]
parent_key: list[str] = []
parent_key_types: list = []
for i, key in enumerate(path_elements):
is_final_key = i == len(path_elements) - 1
# Add current key to parent key to enable nested type tracking
parent_key.append(key)
# Get next value
next_value = None
if isinstance(model, BaseModel):
# Track parent and key for possible assignment later
parent = model
parent_key = [
key,
]
parent_key_types = self._get_key_types(model.__class__, key)
# If this is the final key, set the value
if is_final_key:
return
# Attempt to access the next attribute, handling None values
next_value = getattr(model, key, None)
# Handle missing values (initialize dict/list/model if necessary)
if next_value is None:
next_type = parent_key_types[len(parent_key) - 1]
next_value = self._initialize_value(next_type)
elif isinstance(model, list):
# Handle lists
try:
idx = int(key)
except Exception as e:
raise IndexError(
f"Invalid list index '{key}' at '{path}': key = '{key}'; parent = '{parent}', parent_key = '{parent_key}'; model = '{model}'; {e}"
)
# Get next type from parent key type information
next_type = parent_key_types[len(parent_key) - 1]
if len(model) > idx:
next_value = model[idx]
else:
return
if is_final_key:
return
elif isinstance(model, dict):
# Handle dictionaries (auto-create missing keys)
# Get next type from parent key type information
next_type = parent_key_types[len(parent_key) - 1]
if is_final_key:
return
if key not in model:
return
else:
next_value = model[key]
else:
raise KeyError(f"Key '{key}' not found in model.")
# Move deeper
model = next_value
def get_nested_value(self, path: str) -> Any:
"""Retrieve a nested value from the model using a '/'-separated path.
@@ -128,6 +298,11 @@ class PydanticModelNestedValueMixin:
model = model[int(key)]
except (ValueError, IndexError) as e:
raise IndexError(f"Invalid list index at '{path}': {key}; {e}")
elif isinstance(model, dict):
try:
model = model[key]
except Exception as e:
raise KeyError(f"Invalid dict key at '{path}': {key}; {e}")
elif isinstance(model, BaseModel):
model = getattr(model, key)
else:
@@ -142,6 +317,8 @@ class PydanticModelNestedValueMixin:
Automatically initializes missing `Optional`, `Union`, `dict`, and `list` fields if necessary.
If a missing field cannot be initialized, raises an exception.
Triggers the callbacks registered by track_nested_value().
Args:
path (str): A '/'-separated path to the nested attribute (e.g., "key1/key2/0").
value (Any): The new value to set.
@@ -170,6 +347,44 @@ class PydanticModelNestedValueMixin:
print(user.settings) # Output: {'theme': 'dark'}
```
"""
path = path.strip("/")
# Store old value (if possible)
try:
old_value = self.get_nested_value(path)
except Exception as e:
# We can not get the old value
# raise ValueError(f"Can not get old (current) value of '{path}': {e}") from e
old_value = None
# Proceed with core logic
self._set_nested_value(path, value)
# Trigger all callbacks whose path is a prefix of set path
triggered = set()
nested_value_callbacks = get_private_attr(self, "nested_value_callbacks", dict())
for cb_path, callbacks in nested_value_callbacks.items():
# Match: cb_path == path, or cb_path is a prefix (parent) of path
pass
if path == cb_path or path.startswith(cb_path + "/"):
for cb in callbacks:
# Prevent duplicate calls
if (cb_path, id(cb)) not in triggered:
cb(self, path, old_value, value)
triggered.add((cb_path, id(cb)))
def _set_nested_value(self, path: str, value: Any) -> None:
"""Set a nested value core logic.
Args:
path (str): A '/'-separated path to the nested attribute (e.g., "key1/key2/0").
value (Any): The new value to set.
Raises:
KeyError: If a key is not found in the model.
IndexError: If a list index is out of bounds or invalid.
ValueError: If a validation error occurs.
TypeError: If a missing field cannot be initialized.
"""
path_elements = path.strip("/").split("/")
# The model we are currently working on
model: Any = self
@@ -191,6 +406,13 @@ class PydanticModelNestedValueMixin:
# Get next value
next_value = None
if isinstance(model, BaseModel):
# Track parent and key for possible assignment later
parent = model
parent_key = [
key,
]
parent_key_types = self._get_key_types(model.__class__, key)
# If this is the final key, set the value
if is_final_key:
try:
@@ -199,13 +421,6 @@ class PydanticModelNestedValueMixin:
raise ValueError(f"Error updating model: {e}") from e
return
# Track parent and key for possible assignment later
parent = model
parent_key = [
key,
]
parent_key_types = self._get_key_types(model, key)
# Attempt to access the next attribute, handling None values
next_value = getattr(model, key, None)
@@ -227,7 +442,7 @@ class PydanticModelNestedValueMixin:
idx = int(key)
except Exception as e:
raise IndexError(
f"Invalid list index '{key}' at '{path}': key = {key}; parent = {parent}, parent_key = {parent_key}; model = {model}; {e}"
f"Invalid list index '{key}' at '{path}': key = '{key}'; parent = '{parent}', parent_key = '{parent_key}'; model = '{model}'; {e}"
)
# Get next type from parent key type information
@@ -309,6 +524,9 @@ class PydanticModelNestedValueMixin:
Raises:
TypeError: If the key does not exist or lacks a valid type annotation.
"""
if not inspect.isclass(model):
raise TypeError(f"Model '{model}' is not of class type.")
if key not in model.model_fields:
raise TypeError(f"Field '{key}' does not exist in model '{model.__name__}'.")
@@ -408,11 +626,13 @@ class PydanticModelNestedValueMixin:
raise TypeError(f"Unsupported type hint '{type_hint}' for initialization.")
class PydanticBaseModel(BaseModel, PydanticModelNestedValueMixin):
"""Base model class with automatic serialization and deserialization of `pendulum.DateTime` fields.
class PydanticBaseModel(PydanticModelNestedValueMixin, BaseModel):
"""Base model with pendulum datetime support, nested value utilities, and stable hashing.
This model serializes pendulum.DateTime objects to ISO 8601 strings and
deserializes ISO 8601 strings to pendulum.DateTime objects.
This class provides:
- ISO 8601 serialization/deserialization of `pendulum.DateTime` fields.
- Nested attribute access and mutation via `PydanticModelNestedValueMixin`.
- A consistent hash using a UUID for use in sets and as dictionary keys
"""
# Enable custom serialization globally in config
@@ -422,6 +642,17 @@ class PydanticBaseModel(BaseModel, PydanticModelNestedValueMixin):
validate_assignment=True,
)
_uuid: str = PrivateAttr(default_factory=lambda: str(uuid.uuid4()))
"""str: A private UUID string generated on instantiation, used for hashing."""
def __hash__(self) -> int:
"""Returns a stable hash based on the instance's UUID.
Returns:
int: Hash value derived from the model's UUID.
"""
return hash(self._uuid)
@field_validator("*", mode="before")
def validate_and_convert_pendulum(cls, value: Any, info: ValidationInfo) -> Any:
"""Validator to convert fields of type `pendulum.DateTime`.
@@ -839,3 +1070,27 @@ class PydanticDateTimeSeries(PydanticBaseModel):
class ParametersBaseModel(PydanticBaseModel):
model_config = ConfigDict(extra="forbid")
def set_private_attr(
model: Union[PydanticBaseModel, PydanticModelNestedValueMixin], key: str, value: Any
) -> None:
"""Set a private attribute for a model instance (not stored in model itself)."""
if model not in _model_private_state:
_model_private_state[model] = {}
_model_private_state[model][key] = value
def get_private_attr(
model: Union[PydanticBaseModel, PydanticModelNestedValueMixin], key: str, default: Any = None
) -> Any:
"""Get a private attribute or return default."""
return _model_private_state.get(model, {}).get(key, default)
def del_private_attr(
model: Union[PydanticBaseModel, PydanticModelNestedValueMixin], key: str
) -> None:
"""Delete a private attribute."""
if model in _model_private_state and key in _model_private_state[model]:
del _model_private_state[model][key]