Commit f74f895550 (parent d63601d742), authored and committed by Andreas on 2024-12-11 07:41:24 +01:00.


import hashlib
import json
import zoneinfo
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, Sequence
import numpy as np
import requests
from akkudoktoreos.config import AppConfig, SetupIncomplete


def repeat_to_shape(array: np.ndarray, target_shape: Sequence[int]) -> np.ndarray:
    """Expands an array to a specified shape using repetition."""
    if len(target_shape) != array.ndim:
        raise ValueError("Array and target shape must have the same number of dimensions")

    # Number of repetitions per dimension
    repeats = tuple(target_shape[i] // array.shape[i] for i in range(array.ndim))

    # Use np.tile to expand the array
    return np.tile(array, repeats)
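
# A quick illustration of repeat_to_shape (doctest-style, not executed here):
# a 24-hour price vector tiled up to a 48-hour horizon keeps the daily pattern.
#
#     >>> repeat_to_shape(np.arange(24.0), (48,)).shape
#     (48,)
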
class HourlyElectricityPriceForecast:
    def __init__(
        self,
        source: str | Path,
        config: AppConfig,
        charges: float = 0.000228,
        use_cache: bool = True,
    ) -> None:
        self.cache_dir = config.working_dir / config.directories.cache
        self.use_cache = use_cache
        self.charges = charges
        self.prediction_hours = config.eos.prediction_hours

        if not self.cache_dir.is_dir():
            raise SetupIncomplete(f"Output path does not exist: {self.cache_dir}.")

        self.cache_time_file = self.cache_dir / "cache_timestamp.txt"
        self.prices = self.load_data(source)

    def load_data(self, source: str | Path) -> list[dict[str, Any]]:
        """Loads data from a cache file or source, returns a list of price entries."""
        cache_file = self.get_cache_file(source)
        if (
            isinstance(source, str)
            and self.use_cache
            and cache_file.is_file()
            and not self.is_cache_expired()
        ):
            with cache_file.open("r") as file:
                json_data = json.load(file)
        else:
            json_data = self.fetch_and_cache_data(source, cache_file)
        return json_data.get("values", [])
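
    # load_data expects a payload shaped roughly like the sketch below; the field
    # names are taken from the accesses in this class, the example values and the
    # exact timestamp format are illustrative assumptions only:
    #
    #     {"values": [{"end": "2024-12-11T01:00:00+01:00",
    #                  "marketpriceEurocentPerKWh": 11.27}, ...]}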

    def get_cache_file(self, source: str | Path) -> Path:
        """Generates a unique cache file path for the source URL."""
        url = str(source)
        hex_dig = hashlib.sha256(url.encode()).hexdigest()
        return self.cache_dir / f"cache_{hex_dig}.json"

    def is_cache_expired(self) -> bool:
        """Checks if the cache has expired based on a one-hour limit."""
        if not self.cache_time_file.is_file():
            return True
        with self.cache_time_file.open("r") as file:
            timestamp_str = file.read()
        last_cache_time = datetime.strptime(timestamp_str, "%Y-%m-%d %H:%M:%S")
        return datetime.now() - last_cache_time > timedelta(hours=1)

    def update_cache_timestamp(self) -> None:
        """Updates the cache timestamp to the current time."""
        with self.cache_time_file.open("w") as file:
            file.write(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

    def fetch_and_cache_data(self, source: str | Path, cache_file: Path) -> dict:
        """Fetches data from a URL or file and caches it."""
        if isinstance(source, str):
            response = requests.get(source)
            if response.status_code != 200:
                raise Exception(f"Error fetching data: {response.status_code}")
            json_data = response.json()
            with cache_file.open("w") as file:
                json.dump(json_data, file)
            self.update_cache_timestamp()
        elif source.is_file():
            with source.open("r") as file:
                json_data = json.load(file)
        else:
            raise ValueError(f"Invalid input path: {source}")
        return json_data

    def get_price_for_date(self, date_str: str) -> np.ndarray:
        """Retrieves all prices for a specified date, adding the previous day's last price if needed."""
        # Convert the date string to a datetime object and derive the previous day
        date_obj = datetime.strptime(date_str, "%Y-%m-%d")
        previous_day_str = (date_obj - timedelta(days=1)).strftime("%Y-%m-%d")

        # Extract the last price of the previous day
        previous_day_prices = [
            entry["marketpriceEurocentPerKWh"] + self.charges
            for entry in self.prices
            if previous_day_str in entry["end"]
        ]
        last_price_of_previous_day = previous_day_prices[-1] if previous_day_prices else 0

        # Extract all prices for the specified date
        date_prices = [
            entry["marketpriceEurocentPerKWh"] + self.charges
            for entry in self.prices
            if date_str in entry["end"]
        ]

        # Pad a short day (fewer than 24 entries) with the previous day's last price
        if len(date_prices) < 24:
            date_prices.insert(0, last_price_of_previous_day)

        return np.round(np.array(date_prices) / (1000.0 * 100.0), 10)
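
    # In the return above, dividing marketpriceEurocentPerKWh by (1000.0 * 100.0)
    # converts ct/kWh to EUR/Wh, e.g. 22.8 ct/kWh / 100000 = 0.000228 EUR/Wh.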

    def get_price_for_daterange(self, start_date_str: str, end_date_str: str) -> np.ndarray:
        """Retrieves all prices within a specified date range."""
        start_date = (
            datetime.strptime(start_date_str, "%Y-%m-%d")
            .replace(tzinfo=timezone.utc)
            .astimezone(zoneinfo.ZoneInfo("Europe/Berlin"))
        )
        end_date = (
            datetime.strptime(end_date_str, "%Y-%m-%d")
            .replace(tzinfo=timezone.utc)
            .astimezone(zoneinfo.ZoneInfo("Europe/Berlin"))
        )

        price_list: list[float] = []
        while start_date < end_date:
            date_str = start_date.strftime("%Y-%m-%d")
            daily_prices = self.get_price_for_date(date_str)
            price_list.extend(daily_prices)
            start_date += timedelta(days=1)

        price_list_np = np.array(price_list)

        # If prediction hours are greater than 0, tile the price list to that length
        if self.prediction_hours > 0:
            price_list_np = repeat_to_shape(price_list_np, (self.prediction_hours,))

        return np.round(price_list_np, 10)
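
A minimal usage sketch, assuming an already-initialised AppConfig (its construction happens elsewhere in akkudoktoreos) and relying on the imports and class defined above; the file name prices.json and the date range are placeholders, not values from the commit:

def example_usage(config: AppConfig) -> None:
    # Hypothetical local file containing a {"values": [...]} payload.
    forecast = HourlyElectricityPriceForecast(
        source=Path("prices.json"),
        config=config,
        use_cache=False,
    )
    # Prices for two days, tiled to config.eos.prediction_hours if that is > 0.
    prices = forecast.get_price_for_daterange("2024-12-10", "2024-12-12")
    print(prices.shape, prices[:3])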