* Mypy: Initial support

 * Add to pre-commit (currently installs its own deps; this could be changed
   to use the poetry venv in the future, to reuse the environment and avoid
   duplicating the type-stub deps).
 * Add type hints.

* Mypy: Add missing annotations
This commit is contained in:
Dominique Lasserre
2024-11-26 22:28:05 +01:00
committed by GitHub
parent 2a163569bc
commit 1163ddb4ac
31 changed files with 637 additions and 531 deletions

View File

@@ -25,6 +25,8 @@ Notes:
- Cache files are automatically associated with the current date unless specified.
"""
from __future__ import annotations
import hashlib
import inspect
import os
@@ -32,7 +34,7 @@ import pickle
import tempfile
import threading
from datetime import date, datetime, time, timedelta
from typing import List, Optional, Union
from typing import IO, Callable, Generic, List, Optional, ParamSpec, TypeVar, Union
from akkudoktoreos.utils.datetimeutil import to_datetime, to_timedelta
from akkudoktoreos.utils.logutil import get_logger
@@ -40,15 +42,20 @@ from akkudoktoreos.utils.logutil import get_logger
logger = get_logger(__file__)
class CacheFileStoreMeta(type):
T = TypeVar("T")
Param = ParamSpec("Param")
RetType = TypeVar("RetType")
class CacheFileStoreMeta(type, Generic[T]):
"""A thread-safe implementation of CacheFileStore."""
_instances = {}
_instances: dict[CacheFileStoreMeta[T], T] = {}
_lock: threading.Lock = threading.Lock()
"""Lock object to synchronize threads on first access to CacheFileStore."""
def __call__(cls):
def __call__(cls) -> T:
"""Return CacheFileStore instance."""
with cls._lock:
if cls not in cls._instances:
@@ -80,18 +87,18 @@ class CacheFileStore(metaclass=CacheFileStoreMeta):
>>> print(cache_file.read()) # Output: 'Some data'
"""
def __init__(self):
def __init__(self) -> None:
"""Initializes the CacheFileStore instance.
This constructor sets up an empty key-value store (a dictionary) where each key
corresponds to a cache file that is associated with a given key and an optional date.
"""
self._store = {}
self._store: dict[str, tuple[IO[bytes], datetime]] = {}
self._store_lock = threading.Lock()
def _generate_cache_file_key(
self, key: str, until_datetime: Union[datetime, None]
) -> (str, datetime):
) -> tuple[str, datetime]:
"""Generates a unique cache file key based on the key and date.
The cache file key is a combination of the input key and the date (if provided),
@@ -114,7 +121,7 @@ class CacheFileStore(metaclass=CacheFileStoreMeta):
cache_key = hashlib.sha256(f"{key}{key_datetime}".encode("utf-8")).hexdigest()
return (f"{cache_key}", until_datetime)
def _get_file_path(self, file_obj):
def _get_file_path(self, file_obj: IO[bytes]) -> Optional[str]:
"""Retrieve the file path from a file-like object.
Args:
@@ -136,7 +143,7 @@ class CacheFileStore(metaclass=CacheFileStoreMeta):
until_date: Union[datetime, date, str, int, float, None] = None,
until_datetime: Union[datetime, date, str, int, float, None] = None,
with_ttl: Union[timedelta, str, int, float, None] = None,
):
) -> datetime:
"""Get until_datetime from the given options."""
if until_datetime:
until_datetime = to_datetime(until_datetime)
@@ -152,11 +159,11 @@ class CacheFileStore(metaclass=CacheFileStoreMeta):
def _is_valid_cache_item(
self,
cache_item: (),
until_datetime: datetime = None,
at_datetime: datetime = None,
before_datetime: datetime = None,
):
cache_item: tuple[IO[bytes], datetime],
until_datetime: Optional[datetime] = None,
at_datetime: Optional[datetime] = None,
before_datetime: Optional[datetime] = None,
) -> bool:
cache_file_datetime = cache_item[1] # Extract the datetime associated with the cache item
if (
(until_datetime and until_datetime == cache_file_datetime)
@@ -169,10 +176,10 @@ class CacheFileStore(metaclass=CacheFileStoreMeta):
def _search(
self,
key: str,
until_datetime: Union[datetime, date, str, int, float] = None,
at_datetime: Union[datetime, date, str, int, float] = None,
before_datetime: Union[datetime, date, str, int, float] = None,
):
until_datetime: Union[datetime, date, str, int, float, None] = None,
at_datetime: Union[datetime, date, str, int, float, None] = None,
before_datetime: Union[datetime, date, str, int, float, None] = None,
) -> Optional[tuple[str, IO[bytes], datetime]]:
"""Searches for a cached item that matches the key and falls within the datetime range.
This method looks for a cache item with a key that matches the given `key`, and whose associated
@@ -193,20 +200,23 @@ class CacheFileStore(metaclass=CacheFileStoreMeta):
otherwise returns `None`.
"""
# Convert input to datetime if they are not None
if until_datetime:
until_datetime = to_datetime(until_datetime)
if at_datetime:
at_datetime = to_datetime(at_datetime)
if before_datetime:
before_datetime = to_datetime(before_datetime)
until_datetime_dt: Optional[datetime] = None
if until_datetime is not None:
until_datetime_dt = to_datetime(until_datetime)
at_datetime_dt: Optional[datetime] = None
if at_datetime is not None:
at_datetime_dt = to_datetime(at_datetime)
before_datetime_dt: Optional[datetime] = None
if before_datetime is not None:
before_datetime_dt = to_datetime(before_datetime)
for cache_file_key, cache_item in self._store.items():
# Check if the cache file datetime matches the given criteria
if self._is_valid_cache_item(
cache_item,
until_datetime=until_datetime,
at_datetime=at_datetime,
before_datetime=before_datetime,
until_datetime=until_datetime_dt,
at_datetime=at_datetime_dt,
before_datetime=before_datetime_dt,
):
# This cache file is within the given datetime range
# Extract the datetime associated with the cache item
@@ -231,7 +241,7 @@ class CacheFileStore(metaclass=CacheFileStoreMeta):
mode: str = "wb+",
delete: bool = False,
suffix: Optional[str] = None,
):
) -> IO[bytes]:
"""Creates a new file-like tempfile object associated with the given key.
If a cache file with the given key and valid timedate already exists, the existing file is
@@ -262,31 +272,31 @@ class CacheFileStore(metaclass=CacheFileStoreMeta):
>>> cache_file.seek(0)
>>> print(cache_file.read()) # Output: 'Some cached data'
"""
until_datetime = self._until_datetime_by_options(
until_datetime_dt = self._until_datetime_by_options(
until_datetime=until_datetime, until_date=until_date, with_ttl=with_ttl
)
cache_file_key, until_date = self._generate_cache_file_key(key, until_datetime)
cache_file_key, _ = self._generate_cache_file_key(key, until_datetime_dt)
with self._store_lock: # Synchronize access to _store
if cache_file_key in self._store:
if (cache_file_item := self._store.get(cache_file_key)) is not None:
# File already available
cache_file_obj, until_datetime = self._store.get(cache_file_key)
cache_file_obj = cache_file_item[0]
else:
cache_file_obj = tempfile.NamedTemporaryFile(
mode=mode, delete=delete, suffix=suffix
)
self._store[cache_file_key] = (cache_file_obj, until_datetime)
self._store[cache_file_key] = (cache_file_obj, until_datetime_dt)
cache_file_obj.seek(0)
return cache_file_obj
def set(
self,
key: str,
file_obj,
file_obj: IO[bytes],
until_date: Union[datetime, date, str, int, float, None] = None,
until_datetime: Union[datetime, date, str, int, float, None] = None,
with_ttl: Union[timedelta, str, int, float, None] = None,
):
) -> None:
"""Stores a file-like object in the cache under the specified key and date.
This method allows you to manually set a file-like object into the cache with a specific key
@@ -309,11 +319,11 @@ class CacheFileStore(metaclass=CacheFileStoreMeta):
Example:
>>> cache_store.set('example_file', io.BytesIO(b'Some binary data'))
"""
until_datetime = self._until_datetime_by_options(
until_datetime_dt = self._until_datetime_by_options(
until_datetime=until_datetime, until_date=until_date, with_ttl=with_ttl
)
cache_file_key, until_date = self._generate_cache_file_key(key, until_datetime)
cache_file_key, until_date = self._generate_cache_file_key(key, until_datetime_dt)
with self._store_lock: # Synchronize access to _store
if cache_file_key in self._store:
raise ValueError(f"Key already in store: `{key}`.")
@@ -327,7 +337,7 @@ class CacheFileStore(metaclass=CacheFileStoreMeta):
until_datetime: Union[datetime, date, str, int, float, None] = None,
at_datetime: Union[datetime, date, str, int, float, None] = None,
before_datetime: Union[datetime, date, str, int, float, None] = None,
):
) -> Optional[IO[bytes]]:
"""Retrieves the cache file associated with the given key and validity datetime.
If no cache file is found for the provided key and datetime, the method returns None.
@@ -374,11 +384,11 @@ class CacheFileStore(metaclass=CacheFileStoreMeta):
def delete(
self,
key,
key: str,
until_date: Union[datetime, date, str, int, float, None] = None,
until_datetime: Union[datetime, date, str, int, float, None] = None,
before_datetime: Union[datetime, date, str, int, float, None] = None,
):
) -> None:
"""Deletes the cache file associated with the given key and datetime.
This method removes the cache file from the store.
@@ -429,8 +439,10 @@ class CacheFileStore(metaclass=CacheFileStoreMeta):
logger.error(f"Error deleting cache file {file_path}: {e}")
def clear(
self, clear_all=False, before_datetime: Union[datetime, date, str, int, float, None] = None
):
self,
clear_all: bool = False,
before_datetime: Union[datetime, date, str, int, float, None] = None,
) -> None:
"""Deletes all cache files or those expiring before `before_datetime`.
Args:
@@ -500,7 +512,7 @@ def cache_in_file(
mode: str = "wb+",
delete: bool = False,
suffix: Optional[str] = None,
):
) -> Callable[[Callable[Param, RetType]], Callable[Param, RetType]]:
"""Decorator to cache the output of a function into a temporary file.
The decorator caches function output to a cache file based on its inputs as key to identify the
@@ -545,35 +557,35 @@ def cache_in_file(
>>> result = expensive_computation(until_date = date.today())
"""
def decorator(func):
def decorator(func: Callable[Param, RetType]) -> Callable[Param, RetType]:
nonlocal ignore_params, until_date, until_datetime, with_ttl, mode, delete, suffix
func_source_code = inspect.getsource(func)
def wrapper(*args, **kwargs):
def wrapper(*args: Param.args, **kwargs: Param.kwargs) -> RetType:
nonlocal ignore_params, until_date, until_datetime, with_ttl, mode, delete, suffix
# Convert args to a dictionary based on the function's signature
args_names = func.__code__.co_varnames[: func.__code__.co_argcount]
args_dict = dict(zip(args_names, args))
# Search for caching parameters of function and remove
force_update = None
force_update: Optional[bool] = None
for param in ["force_update", "until_datetime", "with_ttl", "until_date"]:
if param in kwargs:
if param == "force_update":
force_update = kwargs[param]
force_update = kwargs[param] # type: ignore[assignment]
kwargs.pop("force_update")
if param == "until_datetime":
until_datetime = kwargs[param]
until_datetime = kwargs[param] # type: ignore[assignment]
until_date = None
with_ttl = None
elif param == "with_ttl":
until_datetime = None
until_date = None
with_ttl = kwargs[param]
with_ttl = kwargs[param] # type: ignore[assignment]
elif param == "until_date":
until_datetime = None
until_date = kwargs[param]
until_date = kwargs[param] # type: ignore[assignment]
with_ttl = None
kwargs.pop("until_datetime", None)
kwargs.pop("until_date", None)
@@ -589,7 +601,7 @@ def cache_in_file(
# Create key based on argument names, argument values, and function source code
key = str(args_dict) + str(kwargs_clone) + str(func_source_code)
result = None
result: Optional[RetType | bytes] = None
# Get cache file that is currently valid
cache_file = CacheFileStore().get(key)
if not force_update and cache_file is not None:
@@ -624,11 +636,11 @@ def cache_in_file(
if "b" in mode:
pickle.dump(result, cache_file)
else:
cache_file.write(result)
cache_file.write(result) # type: ignore[call-overload]
except Exception as e:
logger.info(f"Write failed: {e}")
CacheFileStore().delete(key)
return result
return result # type: ignore[return-value]
return wrapper

View File

@@ -24,19 +24,39 @@ Example usage:
import re
from datetime import date, datetime, time, timedelta, timezone
from typing import Optional, Union
from typing import Annotated, Literal, Optional, Union, overload
from zoneinfo import ZoneInfo
from timezonefinder import TimezoneFinder
@overload
def to_datetime(
date_input: Union[datetime, date, str, int, float, None],
as_string: str | Literal[True],
to_timezone: Optional[Union[ZoneInfo, str]] = None,
to_naiv: Optional[bool] = None,
to_maxtime: Optional[bool] = None,
) -> str: ...
@overload
def to_datetime(
date_input: Union[datetime, date, str, int, float, None],
as_string: Literal[False] | None = None,
to_timezone: Optional[Union[ZoneInfo, str]] = None,
to_naiv: Optional[bool] = None,
to_maxtime: Optional[bool] = None,
) -> datetime: ...
def to_datetime(
date_input: Union[datetime, date, str, int, float, None],
as_string: Optional[Union[str, bool]] = None,
to_timezone: Optional[Union[timezone, str]] = None,
to_timezone: Optional[Union[ZoneInfo, str]] = None,
to_naiv: Optional[bool] = None,
to_maxtime: Optional[bool] = None,
):
) -> str | datetime:
"""Converts a date input to a datetime object or a formatted string with timezone support.
Args:
@@ -67,7 +87,9 @@ def to_datetime(
Raises:
ValueError: If the date input is not a valid type or format.
RuntimeError: If no local timezone information available.
"""
dt_object: Optional[datetime] = None
if isinstance(date_input, datetime):
dt_object = date_input
elif isinstance(date_input, date):
@@ -104,7 +126,6 @@ def to_datetime(
dt_object = datetime.strptime(date_input, fmt)
break
except ValueError as e:
dt_object = None
continue
if dt_object is None:
raise ValueError(f"Date string {date_input} does not match any known formats.")
@@ -120,11 +141,13 @@ def to_datetime(
local_date = datetime.now().astimezone()
local_tz_name = local_date.tzname()
local_utc_offset = local_date.utcoffset()
if local_tz_name is None or local_utc_offset is None:
raise RuntimeError("Could not determine local time zone")
local_timezone = timezone(local_utc_offset, local_tz_name)
# Get target timezone
if to_timezone:
if isinstance(to_timezone, timezone):
if isinstance(to_timezone, ZoneInfo):
target_timezone = to_timezone
elif isinstance(to_timezone, str):
try:
@@ -168,7 +191,11 @@ def to_datetime(
return dt_object
def to_timedelta(input_value):
def to_timedelta(
input_value: Union[
timedelta, str, int, float, tuple[int, int, int, int], Annotated[list[int], 4]
],
) -> timedelta:
"""Converts various input types into a timedelta object.
Args:
@@ -238,7 +265,15 @@ def to_timedelta(input_value):
raise ValueError(f"Unsupported input type: {type(input_value)}")
def to_timezone(lat: float, lon: float, as_string: Optional[bool] = None):
@overload
def to_timezone(lat: float, lon: float, as_string: Literal[True]) -> str: ...
@overload
def to_timezone(lat: float, lon: float, as_string: Literal[False] | None = None) -> ZoneInfo: ...
def to_timezone(lat: float, lon: float, as_string: Optional[bool] = None) -> str | ZoneInfo:
"""Determines the timezone for a given geographic location specified by latitude and longitude.
By default, it returns a `ZoneInfo` object representing the timezone.
@@ -269,11 +304,13 @@ def to_timezone(lat: float, lon: float, as_string: Optional[bool] = None):
"""
# Initialize the static variable only once
if not hasattr(to_timezone, "timezone_finder"):
to_timezone.timezone_finder = TimezoneFinder() # static variable
# static variable
to_timezone.timezone_finder = TimezoneFinder() # type: ignore[attr-defined]
# Check and convert coordinates to timezone
tz_name: Optional[str] = None
try:
tz_name = to_timezone.timezone_finder.timezone_at(lat=lat, lng=lon)
tz_name = to_timezone.timezone_finder.timezone_at(lat=lat, lng=lon) # type: ignore[attr-defined]
if not tz_name:
raise ValueError(f"No timezone found for coordinates: latitude {lat}, longitude {lon}")
except Exception as e:

View File

@@ -1,12 +1,13 @@
import datetime
import json
import zoneinfo
from typing import Any
import numpy as np
# currently unused
def ist_dst_wechsel(tag: datetime.datetime, timezone="Europe/Berlin") -> bool:
def ist_dst_wechsel(tag: datetime.datetime, timezone: str = "Europe/Berlin") -> bool:
"""Checks if Daylight Saving Time (DST) starts or ends on a given day."""
tz = zoneinfo.ZoneInfo(timezone)
# Get the current day and the next day
@@ -20,15 +21,25 @@ def ist_dst_wechsel(tag: datetime.datetime, timezone="Europe/Berlin") -> bool:
class NumpyEncoder(json.JSONEncoder):
def default(self, obj):
@classmethod
def convert_numpy(cls, obj: Any) -> tuple[Any, bool]:
if isinstance(obj, np.ndarray):
return obj.tolist() # Convert NumPy arrays to lists
# Convert NumPy arrays to lists
return [
None if isinstance(x, (int, float)) and np.isnan(x) else x for x in obj.tolist()
], True
if isinstance(obj, np.generic):
return obj.item() # Convert NumPy scalars to native Python types
return obj.item(), True # Convert NumPy scalars to native Python types
return obj, False
def default(self, obj: Any) -> Any:
obj, converted = NumpyEncoder.convert_numpy(obj)
if converted:
return obj
return super(NumpyEncoder, self).default(obj)
@staticmethod
def dumps(data):
def dumps(data: Any) -> str:
"""Static method to serialize a Python object into a JSON string using NumpyEncoder.
Args: