"""Hourly electricity price forecast loaded from a URL or a local JSON file."""

import hashlib
import json
import logging
import zoneinfo
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Union

import numpy as np
import requests

from akkudoktoreos.config import AppConfig, SetupIncomplete

logger = logging.getLogger(__name__)

# Format of the timestamp written to / read from the cache timestamp file.
CACHE_TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S"
# Cached data older than this is fetched again.
CACHE_MAX_AGE = timedelta(hours=1)
# Upper bound for the HTTP request so an unresponsive server cannot block the caller forever.
REQUEST_TIMEOUT_SECONDS = 30


def repeat_to_shape(array: np.ndarray, target_shape: tuple[int, ...]) -> np.ndarray:
    """Expand an array to exactly ``target_shape`` by repeating its content.

    The array is tiled along every dimension until it is at least as large as the target
    and then truncated, so the result always has ``target_shape`` - also when the target is
    not a whole multiple of the input or is smaller than the input.

    Args:
        array: The array to expand.
        target_shape: Desired shape; must have as many entries as ``array`` has dimensions.

    Returns:
        A new array of shape ``target_shape``.

    Raises:
        ValueError: If the number of dimensions differs, or if an empty dimension would have
            to be expanded to a non-empty one.
    """
    if len(target_shape) != array.ndim:
        error_msg = "Array and target shape must have the same number of dimensions"
        logger.debug("Validation did not succeed: %s", error_msg)
        raise ValueError(error_msg)

    repeats = []
    for target, size in zip(target_shape, array.shape):
        if size == 0:
            if target != 0:
                raise ValueError("Cannot expand an empty array dimension to a non-empty shape")
            repeats.append(1)
        else:
            # Ceiling division: tile often enough to cover the target, then cut off the rest.
            repeats.append(-(-target // size))

    expanded_array = np.tile(array, tuple(repeats))
    return expanded_array[tuple(slice(0, target) for target in target_shape)]


class HourlyElectricityPriceForecast:
    """Provides hourly electricity prices read from a URL or a JSON file.

    Attributes are set in ``__init__``: ``cache_dir``, ``cache_time_file``, ``use_cache``,
    ``charges``, ``prediction_hours`` and ``prices`` (the list stored under ``"values"`` in
    the loaded JSON document).
    """

    def __init__(
        self,
        source: Union[str, Path],
        config: AppConfig,
        charges: float = 0.000228,
        use_cache: bool = True,
    ) -> None:
        """Load the price data from ``source``.

        Args:
            source: URL (``str``) or path (``Path``) of the JSON price data.
            config: Application configuration; supplies the cache directory and the number
                of prediction hours.
            charges: Fixed amount added to every converted price.
                NOTE(review): presumably EUR per Wh (0.000228 is about 22.8 ct/kWh) - confirm.
            use_cache: Whether a fresh cache file may be used instead of fetching a URL.

        Raises:
            SetupIncomplete: If the cache directory does not exist.
        """
        self.cache_dir = config.working_dir / config.directories.cache
        self.use_cache = use_cache
        self.charges = charges
        self.prediction_hours = config.eos.prediction_hours

        if not self.cache_dir.is_dir():
            error_msg = f"Output path does not exist: {self.cache_dir}"
            logger.debug("Validation did not succeed: %s", error_msg)
            raise SetupIncomplete(error_msg)

        self.cache_time_file = self.cache_dir / "cache_timestamp.txt"
        self.prices = self.load_data(source)

    def load_data(self, source: Union[str, Path]) -> list[dict[str, Union[str, float]]]:
        """Load the price entries from the cache or from ``source``.

        The cache is only consulted for URL sources, when caching is enabled and the cache
        has not expired. An unreadable cache file is ignored and the data is fetched again.

        Args:
            source: URL (``str``) or path (``Path``) of the JSON price data.

        Returns:
            The list stored under ``"values"`` in the JSON document, or an empty list if
            that key is missing.
        """
        cache_file = self.get_cache_file(source)
        json_data = None

        if (
            isinstance(source, str)
            and self.use_cache
            and cache_file.is_file()
            and not self.is_cache_expired()
        ):
            logger.debug("Loading price data from cache file %s", cache_file)
            try:
                with cache_file.open("r", encoding="utf-8") as file:
                    json_data = json.load(file)
            except json.JSONDecodeError:
                # A truncated cache file must not block loading until the cache expires.
                logger.warning("Cache file %s is not valid JSON; fetching again", cache_file)

        if json_data is None:
            json_data = self.fetch_and_cache_data(source, cache_file)

        return json_data.get("values", [])

    def get_cache_file(self, source: Union[str, Path]) -> Path:
        """Return the cache file path derived from the SHA-256 hash of ``source``."""
        hex_dig = hashlib.sha256(str(source).encode()).hexdigest()
        return self.cache_dir / f"cache_{hex_dig}.json"

    def is_cache_expired(self) -> bool:
        """Return whether the cache timestamp is missing, unreadable or older than one hour.

        NOTE(review): one timestamp file is shared by all sources while cache files are per
        source, so fetching one URL also marks the cached data of other URLs as fresh.
        """
        if not self.cache_time_file.is_file():
            logger.debug("Cache timestamp file does not exist; cache considered expired")
            return True

        timestamp_str = self.cache_time_file.read_text(encoding="utf-8").strip()
        try:
            last_cache_time = datetime.strptime(timestamp_str, CACHE_TIMESTAMP_FORMAT)
        except ValueError:
            # A corrupt timestamp is treated like a missing one instead of crashing the load.
            logger.warning("Unreadable cache timestamp %r; cache considered expired", timestamp_str)
            return True

        # The stored timestamp is naive local time, so it is compared against naive local time.
        return datetime.now() - last_cache_time > CACHE_MAX_AGE

    def update_cache_timestamp(self) -> None:
        """Write the current local time to the cache timestamp file."""
        current_time = datetime.now().strftime(CACHE_TIMESTAMP_FORMAT)
        with self.cache_time_file.open("w", encoding="utf-8") as file:
            file.write(current_time)
        logger.debug("Updated cache timestamp to %s", current_time)

    def fetch_and_cache_data(self, source: Union[str, Path], cache_file: Path) -> dict:
        """Read the JSON document from a URL or a file; URL responses are written to the cache.

        Args:
            source: URL (``str``) to download, or ``Path`` of a local JSON file.
            cache_file: File the downloaded document is written to (URL sources only).

        Returns:
            The decoded JSON document.

        Raises:
            RuntimeError: If the server answers with a status code other than 200.
            ValueError: If ``source`` is a path that is not an existing file.
        """
        if isinstance(source, str):
            logger.debug("Fetching data from URL: %s", source)
            response = requests.get(source, timeout=REQUEST_TIMEOUT_SECONDS)
            if response.status_code != 200:
                error_msg = f"Error fetching data: {response.status_code}"
                logger.debug("Validation did not succeed: %s", error_msg)
                raise RuntimeError(error_msg)

            json_data = response.json()
            with cache_file.open("w", encoding="utf-8") as file:
                json.dump(json_data, file)
            self.update_cache_timestamp()
        elif source.is_file():
            logger.debug("Loading data from file: %s", source)
            with source.open("r", encoding="utf-8") as file:
                json_data = json.load(file)
        else:
            error_msg = f"Invalid input path: {source}"
            logger.debug("Validation did not succeed: %s", error_msg)
            raise ValueError(error_msg)

        return json_data

    def _market_prices_for_day(self, day_str: str) -> list:
        """Return the ``marketpriceEurocentPerKWh`` values whose ``end`` field contains ``day_str``."""
        return [
            entry["marketpriceEurocentPerKWh"] for entry in self.prices if day_str in entry["end"]
        ]

    def get_price_for_date(self, date_str: str) -> np.ndarray:
        """Return the prices of one day, converted and including charges.

        If fewer than 24 prices exist for the day, the last price of the previous day is
        inserted at the front (market price 0 if the previous day has no data either).

        Args:
            date_str: Day in ``YYYY-MM-DD`` format.

        Returns:
            Market prices divided by 100000 plus ``self.charges``, rounded to 10 decimals.
        """
        date_obj = datetime.strptime(date_str, "%Y-%m-%d")
        previous_day_str = (date_obj - timedelta(days=1)).strftime("%Y-%m-%d")

        previous_day_prices = self._market_prices_for_day(previous_day_str)
        date_prices = self._market_prices_for_day(date_str)

        if len(date_prices) < 24:
            date_prices.insert(0, previous_day_prices[-1] if previous_day_prices else 0)

        # The charges are added after the unit conversion. Adding them to the ct/kWh value
        # before dividing by 100000 would shrink them to a negligible amount.
        logger.debug("Retrieved %d prices for date %s", len(date_prices), date_str)
        return np.round(np.array(date_prices) / 100000.0 + self.charges, 10)

    def get_price_for_daterange(self, start_date_str: str, end_date_str: str) -> np.ndarray:
        """Return the prices of all days from the start date up to, excluding, the end date.

        Args:
            start_date_str: First day in ``YYYY-MM-DD`` format.
            end_date_str: Day after the last day in ``YYYY-MM-DD`` format.

        Returns:
            The concatenated daily prices; expanded or truncated to ``prediction_hours``
            entries if that setting is greater than 0.

        Raises:
            ValueError: If no prices were found and ``prediction_hours`` is greater than 0.
        """
        berlin = zoneinfo.ZoneInfo("Europe/Berlin")
        start_date = (
            datetime.strptime(start_date_str, "%Y-%m-%d")
            .replace(tzinfo=timezone.utc)
            .astimezone(berlin)
        )
        end_date = (
            datetime.strptime(end_date_str, "%Y-%m-%d")
            .replace(tzinfo=timezone.utc)
            .astimezone(berlin)
        )

        price_list: list[float] = []
        while start_date < end_date:
            date_str = start_date.strftime("%Y-%m-%d")
            # NOTE(review): the next two statements (the assignment and the guard) were not
            # visible in the reviewed excerpt and are reconstructed; confirm that only
            # complete 24-hour days are meant to be used.
            daily_prices = self.get_price_for_date(date_str)
            if daily_prices.size == 24:
                price_list.extend(daily_prices)
            start_date += timedelta(days=1)

        prices = np.array(price_list)
        if self.prediction_hours > 0:
            logger.debug("Reshaping price list to match prediction hours: %s", self.prediction_hours)
            prices = repeat_to_shape(prices, (self.prediction_hours,))

        logger.debug("Total prices retrieved for date range: %d", len(prices))
        return np.round(prices, 10)