This commit is contained in:
Andreas 2024-12-11 07:41:24 +01:00 committed by Andreas
parent c115435ab3
commit 5c2bbb4de2

View File

@ -3,133 +3,98 @@ import json
import zoneinfo import zoneinfo
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from pathlib import Path from pathlib import Path
from typing import Union from typing import Any, Sequence
import numpy as np import numpy as np
import requests import requests
from akkudoktoreos.config import AppConfig, SetupIncomplete from akkudoktoreos.config import AppConfig, SetupIncomplete
# Initialize logger with DEBUG level
def repeat_to_shape(array: np.ndarray, target_shape: Sequence[int]) -> np.ndarray:
    """Expand ``array`` to ``target_shape`` by repeating (tiling) its contents.

    The previous implementation used floor division for the repeat counts, so a
    target dimension that was not an exact multiple of the source dimension
    silently produced an array *smaller* than requested.  The repeat count is
    now rounded up and the tiled result trimmed, so the output always has
    exactly ``target_shape``.  For exact multiples the result is unchanged.

    Args:
        array: Source array to be expanded.
        target_shape: Desired shape; must have the same number of dimensions
            as ``array``.

    Returns:
        An array with shape exactly ``target_shape``.

    Raises:
        ValueError: If ``target_shape`` has a different number of dimensions
            than ``array``.
    """
    if len(target_shape) != array.ndim:
        raise ValueError("Array and target shape must have the same number of dimensions")
    # Ceil-divide (-(-a // b)) so partial multiples still cover the target,
    # then trim each axis back to the exact requested size.
    repeats = tuple(-(-target_shape[i] // array.shape[i]) for i in range(array.ndim))
    tiled = np.tile(array, repeats)
    return tiled[tuple(slice(0, size) for size in target_shape)]
class HourlyElectricityPriceForecast: class HourlyElectricityPriceForecast:
def __init__( def __init__(
self, self,
source: Union[str, Path], source: str | Path,
config: AppConfig, config: AppConfig,
charges: float = 0.000228, charges: float = 0.000228,
use_cache: bool = True, use_cache: bool = True,
) -> None: ): # 228
# logger.debug("Initializing HourlyElectricityPriceForecast")
self.cache_dir = config.working_dir / config.directories.cache self.cache_dir = config.working_dir / config.directories.cache
self.use_cache = use_cache self.use_cache = use_cache
self.charges = charges
self.prediction_hours = config.eos.prediction_hours
if not self.cache_dir.is_dir(): if not self.cache_dir.is_dir():
error_msg = f"Output path does not exist: {self.cache_dir}" raise SetupIncomplete(f"Output path does not exist: {self.cache_dir}.")
# logger.debug(f"Validation did not succeed: {error_msg}")
raise SetupIncomplete(error_msg)
self.cache_time_file = self.cache_dir / "cache_timestamp.txt" self.cache_time_file = self.cache_dir / "cache_timestamp.txt"
self.prices = self.load_data(source) self.prices = self.load_data(source)
self.charges = charges
self.prediction_hours = config.eos.prediction_hours
def load_data(self, source: Union[str, Path]) -> list[dict[str, Union[str, float]]]: def load_data(self, source: str | Path) -> list[dict[str, Any]]:
"""Loads data from a cache file or source, returns a list of price entries."""
cache_file = self.get_cache_file(source) cache_file = self.get_cache_file(source)
# logger.debug(f"Loading data from source: {source}, using cache file: {cache_file}")
if (
isinstance(source, str)
and self.use_cache
and cache_file.is_file()
and not self.is_cache_expired()
):
# logger.debug("Loading data from cache...")
with cache_file.open("r") as file:
json_data = json.load(file)
else:
# logger.debug("Fetching data from source and updating cache...")
json_data = self.fetch_and_cache_data(source, cache_file)
return json_data.get("values", [])
def get_cache_file(self, source: Union[str, Path]) -> Path:
"""Generates a unique cache file path for the source URL."""
url = str(source)
hash_object = hashlib.sha256(url.encode())
hex_dig = hash_object.hexdigest()
cache_file = self.cache_dir / f"cache_{hex_dig}.json"
# logger.debug(f"Generated cache file path: {cache_file}")
return cache_file
def is_cache_expired(self) -> bool:
"""Checks if the cache has expired based on a one-hour limit."""
if not self.cache_time_file.is_file():
# logger.debug("Cache timestamp file does not exist; cache considered expired")
return True
with self.cache_time_file.open("r") as file:
timestamp_str = file.read()
last_cache_time = datetime.strptime(timestamp_str, "%Y-%m-%d %H:%M:%S")
cache_expired = datetime.now() - last_cache_time > timedelta(hours=1)
# logger.debug(f"Cache expired: {cache_expired}")
return cache_expired
def update_cache_timestamp(self) -> None:
"""Updates the cache timestamp to the current time."""
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
with self.cache_time_file.open("w") as file:
file.write(current_time)
# logger.debug(f"Updated cache timestamp to {current_time}")
def fetch_and_cache_data(self, source: Union[str, Path], cache_file: Path) -> dict:
"""Fetches data from a URL or file and caches it."""
if isinstance(source, str): if isinstance(source, str):
# logger.debug(f"Fetching data from URL: {source}") if cache_file.is_file() and not self.is_cache_expired() and self.use_cache:
response = requests.get(source) print("Loading data from cache...")
if response.status_code != 200: with cache_file.open("r") as file:
error_msg = f"Error fetching data: {response.status_code}" json_data = json.load(file)
# logger.debug(f"Validation did not succeed: {error_msg}") else:
raise Exception(error_msg) print("Loading data from the URL...")
response = requests.get(source)
json_data = response.json() if response.status_code == 200:
with cache_file.open("w") as file: json_data = response.json()
json.dump(json_data, file) with cache_file.open("w") as file:
self.update_cache_timestamp() json.dump(json_data, file)
self.update_cache_timestamp()
else:
raise Exception(f"Error fetching data: {response.status_code}")
elif source.is_file(): elif source.is_file():
# logger.debug(f"Loading data from file: {source}")
with source.open("r") as file: with source.open("r") as file:
json_data = json.load(file) json_data = json.load(file)
else: else:
error_msg = f"Invalid input path: {source}" raise ValueError(f"Input is not a valid path: {source}")
# logger.debug(f"Validation did not succeed: {error_msg}") return json_data["values"]
raise ValueError(error_msg)
return json_data def get_cache_file(self, url: str | Path) -> Path:
if isinstance(url, Path):
url = str(url)
hash_object = hashlib.sha256(url.encode())
hex_dig = hash_object.hexdigest()
return self.cache_dir / f"cache_{hex_dig}.json"
def is_cache_expired(self) -> bool:
if not self.cache_time_file.is_file():
return True
with self.cache_time_file.open("r") as file:
timestamp_str = file.read()
last_cache_time = datetime.strptime(timestamp_str, "%Y-%m-%d %H:%M:%S")
return datetime.now() - last_cache_time > timedelta(hours=1)
def update_cache_timestamp(self) -> None:
with self.cache_time_file.open("w") as file:
file.write(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
def get_price_for_date(self, date_str: str) -> np.ndarray: def get_price_for_date(self, date_str: str) -> np.ndarray:
"""Retrieves all prices for a specified date, adding the previous day's last price if needed.""" """Returns all prices for the specified date, including the price from 00:00 of the previous day."""
# logger.debug(f"Getting prices for date: {date_str}") # Convert date string to datetime object
date_obj = datetime.strptime(date_str, "%Y-%m-%d") date_obj = datetime.strptime(date_str, "%Y-%m-%d")
previous_day_str = (date_obj - timedelta(days=1)).strftime("%Y-%m-%d")
# Calculate the previous day
previous_day = date_obj - timedelta(days=1)
previous_day_str = previous_day.strftime("%Y-%m-%d")
# Extract the price from 00:00 of the previous day
previous_day_prices = [ previous_day_prices = [
entry["marketpriceEurocentPerKWh"] + self.charges entry["marketpriceEurocentPerKWh"] + self.charges
for entry in self.prices for entry in self.prices
@ -137,33 +102,30 @@ class HourlyElectricityPriceForecast:
] ]
last_price_of_previous_day = previous_day_prices[-1] if previous_day_prices else 0 last_price_of_previous_day = previous_day_prices[-1] if previous_day_prices else 0
# Extract all prices for the specified date
date_prices = [ date_prices = [
entry["marketpriceEurocentPerKWh"] + self.charges entry["marketpriceEurocentPerKWh"] + self.charges
for entry in self.prices for entry in self.prices
if date_str in entry["end"] if date_str in entry["end"]
] ]
print(f"getPrice: {len(date_prices)}")
if len(date_prices) < 24: # Add the last price of the previous day at the start of the list
if len(date_prices) == 23:
date_prices.insert(0, last_price_of_previous_day) date_prices.insert(0, last_price_of_previous_day)
# logger.debug(f"Retrieved {len(date_prices)} prices for date {date_str}") return np.array(date_prices) / (1000.0 * 100.0) + self.charges
return np.round(np.array(date_prices) / 100000.0, 10)
def get_price_for_daterange(self, start_date_str: str, end_date_str: str) -> np.ndarray: def get_price_for_daterange(self, start_date_str: str, end_date_str: str) -> np.ndarray:
"""Retrieves all prices within a specified date range.""" """Returns all prices between the start and end dates."""
# logger.debug(f"Getting prices from {start_date_str} to {end_date_str}") print(start_date_str)
start_date = ( print(end_date_str)
datetime.strptime(start_date_str, "%Y-%m-%d") start_date_utc = datetime.strptime(start_date_str, "%Y-%m-%d").replace(tzinfo=timezone.utc)
.replace(tzinfo=timezone.utc) end_date_utc = datetime.strptime(end_date_str, "%Y-%m-%d").replace(tzinfo=timezone.utc)
.astimezone(zoneinfo.ZoneInfo("Europe/Berlin")) start_date = start_date_utc.astimezone(zoneinfo.ZoneInfo("Europe/Berlin"))
) end_date = end_date_utc.astimezone(zoneinfo.ZoneInfo("Europe/Berlin"))
end_date = (
datetime.strptime(end_date_str, "%Y-%m-%d")
.replace(tzinfo=timezone.utc)
.astimezone(zoneinfo.ZoneInfo("Europe/Berlin"))
)
price_list = [] price_list: list[float] = []
while start_date < end_date: while start_date < end_date:
date_str = start_date.strftime("%Y-%m-%d") date_str = start_date.strftime("%Y-%m-%d")
@ -173,9 +135,10 @@ class HourlyElectricityPriceForecast:
price_list.extend(daily_prices) price_list.extend(daily_prices)
start_date += timedelta(days=1) start_date += timedelta(days=1)
if self.prediction_hours > 0: price_list_np = np.array(price_list)
# logger.debug(f"Reshaping price list to match prediction hours: {self.prediction_hours}")
price_list = repeat_to_shape(np.array(price_list), (self.prediction_hours,))
# logger.debug(f"Total prices retrieved for date range: {len(price_list)}") # If prediction hours are greater than 0, reshape the price list
return np.round(np.array(price_list), 10) if self.prediction_hours > 0:
price_list_np = repeat_to_shape(price_list_np, (self.prediction_hours,))
return price_list_np