mirror of
https://github.com/Akkudoktor-EOS/EOS.git
synced 2025-10-24 11:26:19 +00:00
Bugfixes
This commit is contained in:
@@ -3,133 +3,98 @@ import json
|
|||||||
import zoneinfo
|
import zoneinfo
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Any, Sequence
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from akkudoktoreos.config import AppConfig, SetupIncomplete
|
from akkudoktoreos.config import AppConfig, SetupIncomplete
|
||||||
|
|
||||||
# Initialize logger with DEBUG level
|
|
||||||
|
|
||||||
|
def repeat_to_shape(array: np.ndarray, target_shape: Sequence[int]) -> np.ndarray:
|
||||||
def repeat_to_shape(array: np.ndarray, target_shape: tuple[int, ...]) -> np.ndarray:
|
# Check if the array fits the target shape
|
||||||
"""Expands an array to a specified shape using repetition."""
|
|
||||||
# logger.debug(f"Expanding array with shape {array.shape} to target shape {target_shape}")
|
|
||||||
if len(target_shape) != array.ndim:
|
if len(target_shape) != array.ndim:
|
||||||
error_msg = "Array and target shape must have the same number of dimensions"
|
raise ValueError("Array and target shape must have the same number of dimensions")
|
||||||
# logger.debug(f"Validation did not succeed: {error_msg}")
|
|
||||||
raise ValueError(error_msg)
|
|
||||||
|
|
||||||
|
# Number of repetitions per dimension
|
||||||
repeats = tuple(target_shape[i] // array.shape[i] for i in range(array.ndim))
|
repeats = tuple(target_shape[i] // array.shape[i] for i in range(array.ndim))
|
||||||
|
|
||||||
|
# Use np.tile to expand the array
|
||||||
expanded_array = np.tile(array, repeats)
|
expanded_array = np.tile(array, repeats)
|
||||||
# logger.debug(f"Expanded array shape: {expanded_array.shape}")
|
|
||||||
return expanded_array
|
return expanded_array
|
||||||
|
|
||||||
|
|
||||||
class HourlyElectricityPriceForecast:
|
class HourlyElectricityPriceForecast:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
source: Union[str, Path],
|
source: str | Path,
|
||||||
config: AppConfig,
|
config: AppConfig,
|
||||||
charges: float = 0.000228,
|
charges: float = 0.000228,
|
||||||
use_cache: bool = True,
|
use_cache: bool = True,
|
||||||
) -> None:
|
): # 228
|
||||||
# logger.debug("Initializing HourlyElectricityPriceForecast")
|
|
||||||
self.cache_dir = config.working_dir / config.directories.cache
|
self.cache_dir = config.working_dir / config.directories.cache
|
||||||
self.use_cache = use_cache
|
self.use_cache = use_cache
|
||||||
self.charges = charges
|
|
||||||
self.prediction_hours = config.eos.prediction_hours
|
|
||||||
|
|
||||||
if not self.cache_dir.is_dir():
|
if not self.cache_dir.is_dir():
|
||||||
error_msg = f"Output path does not exist: {self.cache_dir}"
|
raise SetupIncomplete(f"Output path does not exist: {self.cache_dir}.")
|
||||||
# logger.debug(f"Validation did not succeed: {error_msg}")
|
|
||||||
raise SetupIncomplete(error_msg)
|
|
||||||
|
|
||||||
self.cache_time_file = self.cache_dir / "cache_timestamp.txt"
|
self.cache_time_file = self.cache_dir / "cache_timestamp.txt"
|
||||||
self.prices = self.load_data(source)
|
self.prices = self.load_data(source)
|
||||||
|
self.charges = charges
|
||||||
|
self.prediction_hours = config.eos.prediction_hours
|
||||||
|
|
||||||
def load_data(self, source: Union[str, Path]) -> list[dict[str, Union[str, float]]]:
|
def load_data(self, source: str | Path) -> list[dict[str, Any]]:
|
||||||
"""Loads data from a cache file or source, returns a list of price entries."""
|
|
||||||
cache_file = self.get_cache_file(source)
|
cache_file = self.get_cache_file(source)
|
||||||
# logger.debug(f"Loading data from source: {source}, using cache file: {cache_file}")
|
|
||||||
|
|
||||||
if (
|
|
||||||
isinstance(source, str)
|
|
||||||
and self.use_cache
|
|
||||||
and cache_file.is_file()
|
|
||||||
and not self.is_cache_expired()
|
|
||||||
):
|
|
||||||
# logger.debug("Loading data from cache...")
|
|
||||||
with cache_file.open("r") as file:
|
|
||||||
json_data = json.load(file)
|
|
||||||
else:
|
|
||||||
# logger.debug("Fetching data from source and updating cache...")
|
|
||||||
json_data = self.fetch_and_cache_data(source, cache_file)
|
|
||||||
|
|
||||||
return json_data.get("values", [])
|
|
||||||
|
|
||||||
def get_cache_file(self, source: Union[str, Path]) -> Path:
|
|
||||||
"""Generates a unique cache file path for the source URL."""
|
|
||||||
url = str(source)
|
|
||||||
hash_object = hashlib.sha256(url.encode())
|
|
||||||
hex_dig = hash_object.hexdigest()
|
|
||||||
cache_file = self.cache_dir / f"cache_{hex_dig}.json"
|
|
||||||
# logger.debug(f"Generated cache file path: {cache_file}")
|
|
||||||
return cache_file
|
|
||||||
|
|
||||||
def is_cache_expired(self) -> bool:
|
|
||||||
"""Checks if the cache has expired based on a one-hour limit."""
|
|
||||||
if not self.cache_time_file.is_file():
|
|
||||||
# logger.debug("Cache timestamp file does not exist; cache considered expired")
|
|
||||||
return True
|
|
||||||
|
|
||||||
with self.cache_time_file.open("r") as file:
|
|
||||||
timestamp_str = file.read()
|
|
||||||
last_cache_time = datetime.strptime(timestamp_str, "%Y-%m-%d %H:%M:%S")
|
|
||||||
cache_expired = datetime.now() - last_cache_time > timedelta(hours=1)
|
|
||||||
# logger.debug(f"Cache expired: {cache_expired}")
|
|
||||||
return cache_expired
|
|
||||||
|
|
||||||
def update_cache_timestamp(self) -> None:
|
|
||||||
"""Updates the cache timestamp to the current time."""
|
|
||||||
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
|
||||||
with self.cache_time_file.open("w") as file:
|
|
||||||
file.write(current_time)
|
|
||||||
|
|
||||||
# logger.debug(f"Updated cache timestamp to {current_time}")
|
|
||||||
|
|
||||||
def fetch_and_cache_data(self, source: Union[str, Path], cache_file: Path) -> dict:
|
|
||||||
"""Fetches data from a URL or file and caches it."""
|
|
||||||
if isinstance(source, str):
|
if isinstance(source, str):
|
||||||
# logger.debug(f"Fetching data from URL: {source}")
|
if cache_file.is_file() and not self.is_cache_expired() and self.use_cache:
|
||||||
response = requests.get(source)
|
print("Loading data from cache...")
|
||||||
if response.status_code != 200:
|
with cache_file.open("r") as file:
|
||||||
error_msg = f"Error fetching data: {response.status_code}"
|
json_data = json.load(file)
|
||||||
# logger.debug(f"Validation did not succeed: {error_msg}")
|
else:
|
||||||
raise Exception(error_msg)
|
print("Loading data from the URL...")
|
||||||
|
response = requests.get(source)
|
||||||
json_data = response.json()
|
if response.status_code == 200:
|
||||||
with cache_file.open("w") as file:
|
json_data = response.json()
|
||||||
json.dump(json_data, file)
|
with cache_file.open("w") as file:
|
||||||
self.update_cache_timestamp()
|
json.dump(json_data, file)
|
||||||
|
self.update_cache_timestamp()
|
||||||
|
else:
|
||||||
|
raise Exception(f"Error fetching data: {response.status_code}")
|
||||||
elif source.is_file():
|
elif source.is_file():
|
||||||
# logger.debug(f"Loading data from file: {source}")
|
|
||||||
with source.open("r") as file:
|
with source.open("r") as file:
|
||||||
json_data = json.load(file)
|
json_data = json.load(file)
|
||||||
else:
|
else:
|
||||||
error_msg = f"Invalid input path: {source}"
|
raise ValueError(f"Input is not a valid path: {source}")
|
||||||
# logger.debug(f"Validation did not succeed: {error_msg}")
|
return json_data["values"]
|
||||||
raise ValueError(error_msg)
|
|
||||||
|
|
||||||
return json_data
|
def get_cache_file(self, url: str | Path) -> Path:
|
||||||
|
if isinstance(url, Path):
|
||||||
|
url = str(url)
|
||||||
|
hash_object = hashlib.sha256(url.encode())
|
||||||
|
hex_dig = hash_object.hexdigest()
|
||||||
|
return self.cache_dir / f"cache_{hex_dig}.json"
|
||||||
|
|
||||||
|
def is_cache_expired(self) -> bool:
|
||||||
|
if not self.cache_time_file.is_file():
|
||||||
|
return True
|
||||||
|
with self.cache_time_file.open("r") as file:
|
||||||
|
timestamp_str = file.read()
|
||||||
|
last_cache_time = datetime.strptime(timestamp_str, "%Y-%m-%d %H:%M:%S")
|
||||||
|
return datetime.now() - last_cache_time > timedelta(hours=1)
|
||||||
|
|
||||||
|
def update_cache_timestamp(self) -> None:
|
||||||
|
with self.cache_time_file.open("w") as file:
|
||||||
|
file.write(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
|
||||||
|
|
||||||
def get_price_for_date(self, date_str: str) -> np.ndarray:
|
def get_price_for_date(self, date_str: str) -> np.ndarray:
|
||||||
"""Retrieves all prices for a specified date, adding the previous day's last price if needed."""
|
"""Returns all prices for the specified date, including the price from 00:00 of the previous day."""
|
||||||
# logger.debug(f"Getting prices for date: {date_str}")
|
# Convert date string to datetime object
|
||||||
date_obj = datetime.strptime(date_str, "%Y-%m-%d")
|
date_obj = datetime.strptime(date_str, "%Y-%m-%d")
|
||||||
previous_day_str = (date_obj - timedelta(days=1)).strftime("%Y-%m-%d")
|
|
||||||
|
|
||||||
|
# Calculate the previous day
|
||||||
|
previous_day = date_obj - timedelta(days=1)
|
||||||
|
previous_day_str = previous_day.strftime("%Y-%m-%d")
|
||||||
|
|
||||||
|
# Extract the price from 00:00 of the previous day
|
||||||
previous_day_prices = [
|
previous_day_prices = [
|
||||||
entry["marketpriceEurocentPerKWh"] + self.charges
|
entry["marketpriceEurocentPerKWh"] + self.charges
|
||||||
for entry in self.prices
|
for entry in self.prices
|
||||||
@@ -137,33 +102,30 @@ class HourlyElectricityPriceForecast:
|
|||||||
]
|
]
|
||||||
last_price_of_previous_day = previous_day_prices[-1] if previous_day_prices else 0
|
last_price_of_previous_day = previous_day_prices[-1] if previous_day_prices else 0
|
||||||
|
|
||||||
|
# Extract all prices for the specified date
|
||||||
date_prices = [
|
date_prices = [
|
||||||
entry["marketpriceEurocentPerKWh"] + self.charges
|
entry["marketpriceEurocentPerKWh"] + self.charges
|
||||||
for entry in self.prices
|
for entry in self.prices
|
||||||
if date_str in entry["end"]
|
if date_str in entry["end"]
|
||||||
]
|
]
|
||||||
|
print(f"getPrice: {len(date_prices)}")
|
||||||
|
|
||||||
if len(date_prices) < 24:
|
# Add the last price of the previous day at the start of the list
|
||||||
|
if len(date_prices) == 23:
|
||||||
date_prices.insert(0, last_price_of_previous_day)
|
date_prices.insert(0, last_price_of_previous_day)
|
||||||
|
|
||||||
# logger.debug(f"Retrieved {len(date_prices)} prices for date {date_str}")
|
return np.array(date_prices) / (1000.0 * 100.0) + self.charges
|
||||||
return np.round(np.array(date_prices) / 100000.0, 10)
|
|
||||||
|
|
||||||
def get_price_for_daterange(self, start_date_str: str, end_date_str: str) -> np.ndarray:
|
def get_price_for_daterange(self, start_date_str: str, end_date_str: str) -> np.ndarray:
|
||||||
"""Retrieves all prices within a specified date range."""
|
"""Returns all prices between the start and end dates."""
|
||||||
# logger.debug(f"Getting prices from {start_date_str} to {end_date_str}")
|
print(start_date_str)
|
||||||
start_date = (
|
print(end_date_str)
|
||||||
datetime.strptime(start_date_str, "%Y-%m-%d")
|
start_date_utc = datetime.strptime(start_date_str, "%Y-%m-%d").replace(tzinfo=timezone.utc)
|
||||||
.replace(tzinfo=timezone.utc)
|
end_date_utc = datetime.strptime(end_date_str, "%Y-%m-%d").replace(tzinfo=timezone.utc)
|
||||||
.astimezone(zoneinfo.ZoneInfo("Europe/Berlin"))
|
start_date = start_date_utc.astimezone(zoneinfo.ZoneInfo("Europe/Berlin"))
|
||||||
)
|
end_date = end_date_utc.astimezone(zoneinfo.ZoneInfo("Europe/Berlin"))
|
||||||
end_date = (
|
|
||||||
datetime.strptime(end_date_str, "%Y-%m-%d")
|
|
||||||
.replace(tzinfo=timezone.utc)
|
|
||||||
.astimezone(zoneinfo.ZoneInfo("Europe/Berlin"))
|
|
||||||
)
|
|
||||||
|
|
||||||
price_list = []
|
price_list: list[float] = []
|
||||||
|
|
||||||
while start_date < end_date:
|
while start_date < end_date:
|
||||||
date_str = start_date.strftime("%Y-%m-%d")
|
date_str = start_date.strftime("%Y-%m-%d")
|
||||||
@@ -173,9 +135,10 @@ class HourlyElectricityPriceForecast:
|
|||||||
price_list.extend(daily_prices)
|
price_list.extend(daily_prices)
|
||||||
start_date += timedelta(days=1)
|
start_date += timedelta(days=1)
|
||||||
|
|
||||||
if self.prediction_hours > 0:
|
price_list_np = np.array(price_list)
|
||||||
# logger.debug(f"Reshaping price list to match prediction hours: {self.prediction_hours}")
|
|
||||||
price_list = repeat_to_shape(np.array(price_list), (self.prediction_hours,))
|
|
||||||
|
|
||||||
# logger.debug(f"Total prices retrieved for date range: {len(price_list)}")
|
# If prediction hours are greater than 0, reshape the price list
|
||||||
return np.round(np.array(price_list), 10)
|
if self.prediction_hours > 0:
|
||||||
|
price_list_np = repeat_to_shape(price_list_np, (self.prediction_hours,))
|
||||||
|
|
||||||
|
return price_list_np
|
||||||
|
Reference in New Issue
Block a user