Price prediction failed; used Normann's fixes for the new code

Andreas 2024-12-11 07:21:25 +01:00 committed by Andreas
parent 4fbfff2baf
commit c115435ab3


@@ -3,128 +3,167 @@ import json
 import zoneinfo
 from datetime import datetime, timedelta, timezone
 from pathlib import Path
-from typing import Any, Sequence
+from typing import Union

 import numpy as np
 import requests

 from akkudoktoreos.config import AppConfig, SetupIncomplete

+# Initialize logger with DEBUG level

-def repeat_to_shape(array: np.ndarray, target_shape: Sequence[int]) -> np.ndarray:
-    # Check if the array fits the target shape
+def repeat_to_shape(array: np.ndarray, target_shape: tuple[int, ...]) -> np.ndarray:
+    """Expands an array to a specified shape using repetition."""
+    # logger.debug(f"Expanding array with shape {array.shape} to target shape {target_shape}")
     if len(target_shape) != array.ndim:
-        raise ValueError("Array and target shape must have the same number of dimensions")
+        error_msg = "Array and target shape must have the same number of dimensions"
+        # logger.debug(f"Validation did not succeed: {error_msg}")
+        raise ValueError(error_msg)

-    # Number of repetitions per dimension
     repeats = tuple(target_shape[i] // array.shape[i] for i in range(array.ndim))
-    # Use np.tile to expand the array
     expanded_array = np.tile(array, repeats)
+    # logger.debug(f"Expanded array shape: {expanded_array.shape}")
     return expanded_array


 class HourlyElectricityPriceForecast:
     def __init__(
         self,
-        source: str | Path,
+        source: Union[str, Path],
         config: AppConfig,
         charges: float = 0.000228,
         use_cache: bool = True,
-    ):  # 228
+    ) -> None:
+        # logger.debug("Initializing HourlyElectricityPriceForecast")
         self.cache_dir = config.working_dir / config.directories.cache
         self.use_cache = use_cache
-        if not self.cache_dir.is_dir():
-            raise SetupIncomplete(f"Output path does not exist: {self.cache_dir}.")
-
-        self.cache_time_file = self.cache_dir / "cache_timestamp.txt"
-        self.prices = self.load_data(source)
         self.charges = charges
         self.prediction_hours = config.eos.prediction_hours

-    def load_data(self, source: str | Path) -> list[dict[str, Any]]:
+        if not self.cache_dir.is_dir():
+            error_msg = f"Output path does not exist: {self.cache_dir}"
+            # logger.debug(f"Validation did not succeed: {error_msg}")
+            raise SetupIncomplete(error_msg)
+
+        self.cache_time_file = self.cache_dir / "cache_timestamp.txt"
+        self.prices = self.load_data(source)
+
+    def load_data(self, source: Union[str, Path]) -> list[dict[str, Union[str, float]]]:
+        """Loads data from a cache file or source, returns a list of price entries."""
         cache_file = self.get_cache_file(source)
+        # logger.debug(f"Loading data from source: {source}, using cache file: {cache_file}")
+        if (
+            isinstance(source, str)
+            and self.use_cache
+            and cache_file.is_file()
+            and not self.is_cache_expired()
+        ):
+            # logger.debug("Loading data from cache...")
+            with cache_file.open("r") as file:
+                json_data = json.load(file)
+        else:
+            # logger.debug("Fetching data from source and updating cache...")
+            json_data = self.fetch_and_cache_data(source, cache_file)
+        return json_data.get("values", [])
+
+    def get_cache_file(self, source: Union[str, Path]) -> Path:
+        """Generates a unique cache file path for the source URL."""
+        url = str(source)
+        hash_object = hashlib.sha256(url.encode())
+        hex_dig = hash_object.hexdigest()
+        cache_file = self.cache_dir / f"cache_{hex_dig}.json"
+        # logger.debug(f"Generated cache file path: {cache_file}")
+        return cache_file
+
+    def is_cache_expired(self) -> bool:
+        """Checks if the cache has expired based on a one-hour limit."""
+        if not self.cache_time_file.is_file():
+            # logger.debug("Cache timestamp file does not exist; cache considered expired")
+            return True
+        with self.cache_time_file.open("r") as file:
+            timestamp_str = file.read()
+            last_cache_time = datetime.strptime(timestamp_str, "%Y-%m-%d %H:%M:%S")
+        cache_expired = datetime.now() - last_cache_time > timedelta(hours=1)
+        # logger.debug(f"Cache expired: {cache_expired}")
+        return cache_expired
+
+    def update_cache_timestamp(self) -> None:
+        """Updates the cache timestamp to the current time."""
+        current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+        with self.cache_time_file.open("w") as file:
+            file.write(current_time)
+        # logger.debug(f"Updated cache timestamp to {current_time}")
+
+    def fetch_and_cache_data(self, source: Union[str, Path], cache_file: Path) -> dict:
+        """Fetches data from a URL or file and caches it."""
         if isinstance(source, str):
-            if cache_file.is_file() and not self.is_cache_expired() and self.use_cache:
-                print("Loading data from cache...")
-                with cache_file.open("r") as file:
-                    json_data = json.load(file)
-            else:
-                print("Loading data from the URL...")
-                response = requests.get(source)
-                if response.status_code == 200:
-                    json_data = response.json()
-                    with cache_file.open("w") as file:
-                        json.dump(json_data, file)
-                    self.update_cache_timestamp()
-                else:
-                    raise Exception(f"Error fetching data: {response.status_code}")
+            # logger.debug(f"Fetching data from URL: {source}")
+            response = requests.get(source)
+            if response.status_code != 200:
+                error_msg = f"Error fetching data: {response.status_code}"
+                # logger.debug(f"Validation did not succeed: {error_msg}")
+                raise Exception(error_msg)
+
+            json_data = response.json()
+            with cache_file.open("w") as file:
+                json.dump(json_data, file)
+            self.update_cache_timestamp()
         elif source.is_file():
+            # logger.debug(f"Loading data from file: {source}")
             with source.open("r") as file:
                 json_data = json.load(file)
         else:
-            raise ValueError(f"Input is not a valid path: {source}")
-        return json_data["values"]
-
-    def get_cache_file(self, url: str | Path) -> Path:
-        if isinstance(url, Path):
-            url = str(url)
-        hash_object = hashlib.sha256(url.encode())
-        hex_dig = hash_object.hexdigest()
-        return self.cache_dir / f"cache_{hex_dig}.json"
-
-    def is_cache_expired(self) -> bool:
-        if not self.cache_time_file.is_file():
-            return True
-        with self.cache_time_file.open("r") as file:
-            timestamp_str = file.read()
-            last_cache_time = datetime.strptime(timestamp_str, "%Y-%m-%d %H:%M:%S")
-        return datetime.now() - last_cache_time > timedelta(hours=1)
-
-    def update_cache_timestamp(self) -> None:
-        with self.cache_time_file.open("w") as file:
-            file.write(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
+            error_msg = f"Invalid input path: {source}"
+            # logger.debug(f"Validation did not succeed: {error_msg}")
+            raise ValueError(error_msg)
+
+        return json_data

     def get_price_for_date(self, date_str: str) -> np.ndarray:
-        """Returns all prices for the specified date, including the price from 00:00 of the previous day."""
-        # Convert date string to datetime object
+        """Retrieves all prices for a specified date, adding the previous day's last price if needed."""
+        # logger.debug(f"Getting prices for date: {date_str}")
         date_obj = datetime.strptime(date_str, "%Y-%m-%d")
+        previous_day_str = (date_obj - timedelta(days=1)).strftime("%Y-%m-%d")

-        # Calculate the previous day
-        previous_day = date_obj - timedelta(days=1)
-        previous_day_str = previous_day.strftime("%Y-%m-%d")
-
-        # Extract the price from 00:00 of the previous day
-        last_price_of_previous_day = [
+        previous_day_prices = [
             entry["marketpriceEurocentPerKWh"] + self.charges
             for entry in self.prices
             if previous_day_str in entry["end"]
-        ][-1]
+        ]
+        last_price_of_previous_day = previous_day_prices[-1] if previous_day_prices else 0

-        # Extract all prices for the specified date
         date_prices = [
             entry["marketpriceEurocentPerKWh"] + self.charges
             for entry in self.prices
             if date_str in entry["end"]
         ]
-        print(f"getPrice: {len(date_prices)}")

-        # Add the last price of the previous day at the start of the list
-        if len(date_prices) == 23:
+        if len(date_prices) < 24:
             date_prices.insert(0, last_price_of_previous_day)
-        return np.array(date_prices) / (1000.0 * 100.0) + self.charges
+        # logger.debug(f"Retrieved {len(date_prices)} prices for date {date_str}")
+        return np.round(np.array(date_prices) / 100000.0, 10)

     def get_price_for_daterange(self, start_date_str: str, end_date_str: str) -> np.ndarray:
-        """Returns all prices between the start and end dates."""
-        print(start_date_str)
-        print(end_date_str)
-        start_date_utc = datetime.strptime(start_date_str, "%Y-%m-%d").replace(tzinfo=timezone.utc)
-        end_date_utc = datetime.strptime(end_date_str, "%Y-%m-%d").replace(tzinfo=timezone.utc)
-        start_date = start_date_utc.astimezone(zoneinfo.ZoneInfo("Europe/Berlin"))
-        end_date = end_date_utc.astimezone(zoneinfo.ZoneInfo("Europe/Berlin"))
+        """Retrieves all prices within a specified date range."""
+        # logger.debug(f"Getting prices from {start_date_str} to {end_date_str}")
+        start_date = (
+            datetime.strptime(start_date_str, "%Y-%m-%d")
+            .replace(tzinfo=timezone.utc)
+            .astimezone(zoneinfo.ZoneInfo("Europe/Berlin"))
+        )
+        end_date = (
+            datetime.strptime(end_date_str, "%Y-%m-%d")
+            .replace(tzinfo=timezone.utc)
+            .astimezone(zoneinfo.ZoneInfo("Europe/Berlin"))
+        )

-        price_list: list[float] = []
+        price_list = []

         while start_date < end_date:
             date_str = start_date.strftime("%Y-%m-%d")
@@ -134,10 +173,9 @@ class HourlyElectricityPriceForecast:
             price_list.extend(daily_prices)
             start_date += timedelta(days=1)

-        price_list_np = np.array(price_list)
-
-        # If prediction hours are greater than 0, reshape the price list
         if self.prediction_hours > 0:
-            price_list_np = repeat_to_shape(price_list_np, (self.prediction_hours,))
+            # logger.debug(f"Reshaping price list to match prediction hours: {self.prediction_hours}")
+            price_list = repeat_to_shape(np.array(price_list), (self.prediction_hours,))

-        return price_list_np
+        # logger.debug(f"Total prices retrieved for date range: {len(price_list)}")
+        return np.round(np.array(price_list), 10)
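
For context, a minimal usage sketch of the updated class follows. It is not part of the commit; the import path, the source URL, the way the AppConfig instance is obtained, and the date values are assumptions for illustration only.

# Hypothetical usage sketch -- not part of this commit.
# Assumptions: module path and source URL are placeholders; `config` is an
# AppConfig instance provided by the project's configuration setup.
from akkudoktoreos.config import AppConfig
from akkudoktoreos.prediction.price_forecast import HourlyElectricityPriceForecast  # path assumed


def example(config: AppConfig) -> None:
    forecast = HourlyElectricityPriceForecast(
        source="https://example.invalid/prices.json",  # placeholder URL (assumption)
        config=config,
        use_cache=True,  # reuse the cached JSON if it is younger than one hour
    )
    # End date is exclusive (the while-loop stops before end_date).
    # Values are the raw marketpriceEurocentPerKWh plus charges, divided by
    # 100000 and rounded to 10 decimals, as in get_price_for_date above.
    prices = forecast.get_price_for_daterange("2024-12-10", "2024-12-12")
    # If config.eos.prediction_hours > 0, the array is tiled via repeat_to_shape
    # to match that number of hourly values.
    print(prices.shape, prices[:3])
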