mirror of
				https://github.com/Akkudoktor-EOS/EOS.git
				synced 2025-10-30 22:36:21 +00:00 
			
		
		
		
	Bugfixes
This commit is contained in:
		| @@ -3,133 +3,98 @@ import json | ||||
| import zoneinfo | ||||
| from datetime import datetime, timedelta, timezone | ||||
| from pathlib import Path | ||||
| from typing import Union | ||||
| from typing import Any, Sequence | ||||
|  | ||||
| import numpy as np | ||||
| import requests | ||||
|  | ||||
| from akkudoktoreos.config import AppConfig, SetupIncomplete | ||||
|  | ||||
| # Initialize logger with DEBUG level | ||||
|  | ||||
|  | ||||
| def repeat_to_shape(array: np.ndarray, target_shape: tuple[int, ...]) -> np.ndarray: | ||||
|     """Expands an array to a specified shape using repetition.""" | ||||
|     # logger.debug(f"Expanding array with shape {array.shape} to target shape {target_shape}") | ||||
| def repeat_to_shape(array: np.ndarray, target_shape: Sequence[int]) -> np.ndarray: | ||||
|     # Check if the array fits the target shape | ||||
|     if len(target_shape) != array.ndim: | ||||
|         error_msg = "Array and target shape must have the same number of dimensions" | ||||
|         # logger.debug(f"Validation did not succeed: {error_msg}") | ||||
|         raise ValueError(error_msg) | ||||
|         raise ValueError("Array and target shape must have the same number of dimensions") | ||||
|  | ||||
|     # Number of repetitions per dimension | ||||
|     repeats = tuple(target_shape[i] // array.shape[i] for i in range(array.ndim)) | ||||
|  | ||||
|     # Use np.tile to expand the array | ||||
|     expanded_array = np.tile(array, repeats) | ||||
|     #  logger.debug(f"Expanded array shape: {expanded_array.shape}") | ||||
|     return expanded_array | ||||
|  | ||||
|  | ||||
| class HourlyElectricityPriceForecast: | ||||
|     def __init__( | ||||
|         self, | ||||
|         source: Union[str, Path], | ||||
|         source: str | Path, | ||||
|         config: AppConfig, | ||||
|         charges: float = 0.000228, | ||||
|         use_cache: bool = True, | ||||
|     ) -> None: | ||||
|         # logger.debug("Initializing HourlyElectricityPriceForecast") | ||||
|     ):  # 228 | ||||
|         self.cache_dir = config.working_dir / config.directories.cache | ||||
|         self.use_cache = use_cache | ||||
|         self.charges = charges | ||||
|         self.prediction_hours = config.eos.prediction_hours | ||||
|  | ||||
|         if not self.cache_dir.is_dir(): | ||||
|             error_msg = f"Output path does not exist: {self.cache_dir}" | ||||
|             # logger.debug(f"Validation did not succeed: {error_msg}") | ||||
|             raise SetupIncomplete(error_msg) | ||||
|             raise SetupIncomplete(f"Output path does not exist: {self.cache_dir}.") | ||||
|  | ||||
|         self.cache_time_file = self.cache_dir / "cache_timestamp.txt" | ||||
|         self.prices = self.load_data(source) | ||||
|         self.charges = charges | ||||
|         self.prediction_hours = config.eos.prediction_hours | ||||
|  | ||||
|     def load_data(self, source: Union[str, Path]) -> list[dict[str, Union[str, float]]]: | ||||
|         """Loads data from a cache file or source, returns a list of price entries.""" | ||||
|     def load_data(self, source: str | Path) -> list[dict[str, Any]]: | ||||
|         cache_file = self.get_cache_file(source) | ||||
|         # logger.debug(f"Loading data from source: {source}, using cache file: {cache_file}") | ||||
|  | ||||
|         if ( | ||||
|             isinstance(source, str) | ||||
|             and self.use_cache | ||||
|             and cache_file.is_file() | ||||
|             and not self.is_cache_expired() | ||||
|         ): | ||||
|             # logger.debug("Loading data from cache...") | ||||
|             with cache_file.open("r") as file: | ||||
|                 json_data = json.load(file) | ||||
|         else: | ||||
|             # logger.debug("Fetching data from source and updating cache...") | ||||
|             json_data = self.fetch_and_cache_data(source, cache_file) | ||||
|  | ||||
|         return json_data.get("values", []) | ||||
|  | ||||
|     def get_cache_file(self, source: Union[str, Path]) -> Path: | ||||
|         """Generates a unique cache file path for the source URL.""" | ||||
|         url = str(source) | ||||
|         hash_object = hashlib.sha256(url.encode()) | ||||
|         hex_dig = hash_object.hexdigest() | ||||
|         cache_file = self.cache_dir / f"cache_{hex_dig}.json" | ||||
|         # logger.debug(f"Generated cache file path: {cache_file}") | ||||
|         return cache_file | ||||
|  | ||||
|     def is_cache_expired(self) -> bool: | ||||
|         """Checks if the cache has expired based on a one-hour limit.""" | ||||
|         if not self.cache_time_file.is_file(): | ||||
|             # logger.debug("Cache timestamp file does not exist; cache considered expired") | ||||
|             return True | ||||
|  | ||||
|         with self.cache_time_file.open("r") as file: | ||||
|             timestamp_str = file.read() | ||||
|         last_cache_time = datetime.strptime(timestamp_str, "%Y-%m-%d %H:%M:%S") | ||||
|         cache_expired = datetime.now() - last_cache_time > timedelta(hours=1) | ||||
|         # logger.debug(f"Cache expired: {cache_expired}") | ||||
|         return cache_expired | ||||
|  | ||||
|     def update_cache_timestamp(self) -> None: | ||||
|         """Updates the cache timestamp to the current time.""" | ||||
|         current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") | ||||
|         with self.cache_time_file.open("w") as file: | ||||
|             file.write(current_time) | ||||
|  | ||||
|     # logger.debug(f"Updated cache timestamp to {current_time}") | ||||
|  | ||||
|     def fetch_and_cache_data(self, source: Union[str, Path], cache_file: Path) -> dict: | ||||
|         """Fetches data from a URL or file and caches it.""" | ||||
|         if isinstance(source, str): | ||||
|             # logger.debug(f"Fetching data from URL: {source}") | ||||
|             response = requests.get(source) | ||||
|             if response.status_code != 200: | ||||
|                 error_msg = f"Error fetching data: {response.status_code}" | ||||
|                 # logger.debug(f"Validation did not succeed: {error_msg}") | ||||
|                 raise Exception(error_msg) | ||||
|  | ||||
|             json_data = response.json() | ||||
|             with cache_file.open("w") as file: | ||||
|                 json.dump(json_data, file) | ||||
|             self.update_cache_timestamp() | ||||
|             if cache_file.is_file() and not self.is_cache_expired() and self.use_cache: | ||||
|                 print("Loading data from cache...") | ||||
|                 with cache_file.open("r") as file: | ||||
|                     json_data = json.load(file) | ||||
|             else: | ||||
|                 print("Loading data from the URL...") | ||||
|                 response = requests.get(source) | ||||
|                 if response.status_code == 200: | ||||
|                     json_data = response.json() | ||||
|                     with cache_file.open("w") as file: | ||||
|                         json.dump(json_data, file) | ||||
|                     self.update_cache_timestamp() | ||||
|                 else: | ||||
|                     raise Exception(f"Error fetching data: {response.status_code}") | ||||
|         elif source.is_file(): | ||||
|             # logger.debug(f"Loading data from file: {source}") | ||||
|             with source.open("r") as file: | ||||
|                 json_data = json.load(file) | ||||
|         else: | ||||
|             error_msg = f"Invalid input path: {source}" | ||||
|             # logger.debug(f"Validation did not succeed: {error_msg}") | ||||
|             raise ValueError(error_msg) | ||||
|             raise ValueError(f"Input is not a valid path: {source}") | ||||
|         return json_data["values"] | ||||
|  | ||||
|         return json_data | ||||
|     def get_cache_file(self, url: str | Path) -> Path: | ||||
|         if isinstance(url, Path): | ||||
|             url = str(url) | ||||
|         hash_object = hashlib.sha256(url.encode()) | ||||
|         hex_dig = hash_object.hexdigest() | ||||
|         return self.cache_dir / f"cache_{hex_dig}.json" | ||||
|  | ||||
|     def is_cache_expired(self) -> bool: | ||||
|         if not self.cache_time_file.is_file(): | ||||
|             return True | ||||
|         with self.cache_time_file.open("r") as file: | ||||
|             timestamp_str = file.read() | ||||
|             last_cache_time = datetime.strptime(timestamp_str, "%Y-%m-%d %H:%M:%S") | ||||
|         return datetime.now() - last_cache_time > timedelta(hours=1) | ||||
|  | ||||
|     def update_cache_timestamp(self) -> None: | ||||
|         with self.cache_time_file.open("w") as file: | ||||
|             file.write(datetime.now().strftime("%Y-%m-%d %H:%M:%S")) | ||||
|  | ||||
|     def get_price_for_date(self, date_str: str) -> np.ndarray: | ||||
|         """Retrieves all prices for a specified date, adding the previous day's last price if needed.""" | ||||
|         # logger.debug(f"Getting prices for date: {date_str}") | ||||
|         """Returns all prices for the specified date, including the price from 00:00 of the previous day.""" | ||||
|         # Convert date string to datetime object | ||||
|         date_obj = datetime.strptime(date_str, "%Y-%m-%d") | ||||
|         previous_day_str = (date_obj - timedelta(days=1)).strftime("%Y-%m-%d") | ||||
|  | ||||
|         # Calculate the previous day | ||||
|         previous_day = date_obj - timedelta(days=1) | ||||
|         previous_day_str = previous_day.strftime("%Y-%m-%d") | ||||
|  | ||||
|         # Extract the price from 00:00 of the previous day | ||||
|         previous_day_prices = [ | ||||
|             entry["marketpriceEurocentPerKWh"] + self.charges | ||||
|             for entry in self.prices | ||||
| @@ -137,33 +102,30 @@ class HourlyElectricityPriceForecast: | ||||
|         ] | ||||
|         last_price_of_previous_day = previous_day_prices[-1] if previous_day_prices else 0 | ||||
|  | ||||
|         # Extract all prices for the specified date | ||||
|         date_prices = [ | ||||
|             entry["marketpriceEurocentPerKWh"] + self.charges | ||||
|             for entry in self.prices | ||||
|             if date_str in entry["end"] | ||||
|         ] | ||||
|         print(f"getPrice: {len(date_prices)}") | ||||
|  | ||||
|         if len(date_prices) < 24: | ||||
|         # Add the last price of the previous day at the start of the list | ||||
|         if len(date_prices) == 23: | ||||
|             date_prices.insert(0, last_price_of_previous_day) | ||||
|  | ||||
|         # logger.debug(f"Retrieved {len(date_prices)} prices for date {date_str}") | ||||
|         return np.round(np.array(date_prices) / 100000.0, 10) | ||||
|         return np.array(date_prices) / (1000.0 * 100.0) + self.charges | ||||
|  | ||||
|     def get_price_for_daterange(self, start_date_str: str, end_date_str: str) -> np.ndarray: | ||||
|         """Retrieves all prices within a specified date range.""" | ||||
|         # logger.debug(f"Getting prices from {start_date_str} to {end_date_str}") | ||||
|         start_date = ( | ||||
|             datetime.strptime(start_date_str, "%Y-%m-%d") | ||||
|             .replace(tzinfo=timezone.utc) | ||||
|             .astimezone(zoneinfo.ZoneInfo("Europe/Berlin")) | ||||
|         ) | ||||
|         end_date = ( | ||||
|             datetime.strptime(end_date_str, "%Y-%m-%d") | ||||
|             .replace(tzinfo=timezone.utc) | ||||
|             .astimezone(zoneinfo.ZoneInfo("Europe/Berlin")) | ||||
|         ) | ||||
|         """Returns all prices between the start and end dates.""" | ||||
|         print(start_date_str) | ||||
|         print(end_date_str) | ||||
|         start_date_utc = datetime.strptime(start_date_str, "%Y-%m-%d").replace(tzinfo=timezone.utc) | ||||
|         end_date_utc = datetime.strptime(end_date_str, "%Y-%m-%d").replace(tzinfo=timezone.utc) | ||||
|         start_date = start_date_utc.astimezone(zoneinfo.ZoneInfo("Europe/Berlin")) | ||||
|         end_date = end_date_utc.astimezone(zoneinfo.ZoneInfo("Europe/Berlin")) | ||||
|  | ||||
|         price_list = [] | ||||
|         price_list: list[float] = [] | ||||
|  | ||||
|         while start_date < end_date: | ||||
|             date_str = start_date.strftime("%Y-%m-%d") | ||||
| @@ -173,9 +135,10 @@ class HourlyElectricityPriceForecast: | ||||
|                 price_list.extend(daily_prices) | ||||
|             start_date += timedelta(days=1) | ||||
|  | ||||
|         if self.prediction_hours > 0: | ||||
|             # logger.debug(f"Reshaping price list to match prediction hours: {self.prediction_hours}") | ||||
|             price_list = repeat_to_shape(np.array(price_list), (self.prediction_hours,)) | ||||
|         price_list_np = np.array(price_list) | ||||
|  | ||||
|         # logger.debug(f"Total prices retrieved for date range: {len(price_list)}") | ||||
|         return np.round(np.array(price_list), 10) | ||||
|         # If prediction hours are greater than 0, reshape the price list | ||||
|         if self.prediction_hours > 0: | ||||
|             price_list_np = repeat_to_shape(price_list_np, (self.prediction_hours,)) | ||||
|  | ||||
|         return price_list_np | ||||
|   | ||||
		Reference in New Issue
	
	Block a user