mirror of
				https://github.com/Akkudoktor-EOS/EOS.git
				synced 2025-11-04 08:46:20 +00:00 
			
		
		
		
	Price Prediction failed, used Normanns fixes for the new code
This commit is contained in:
		@@ -3,128 +3,167 @@ import json
 | 
			
		||||
import zoneinfo
 | 
			
		||||
from datetime import datetime, timedelta, timezone
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from typing import Any, Sequence
 | 
			
		||||
from typing import Union
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
import requests
 | 
			
		||||
 | 
			
		||||
from akkudoktoreos.config import AppConfig, SetupIncomplete
 | 
			
		||||
 | 
			
		||||
# Initialize logger with DEBUG level
 | 
			
		||||
 | 
			
		||||
def repeat_to_shape(array: np.ndarray, target_shape: Sequence[int]) -> np.ndarray:
 | 
			
		||||
    # Check if the array fits the target shape
 | 
			
		||||
 | 
			
		||||
def repeat_to_shape(array: np.ndarray, target_shape: tuple[int, ...]) -> np.ndarray:
 | 
			
		||||
    """Expands an array to a specified shape using repetition."""
 | 
			
		||||
    # logger.debug(f"Expanding array with shape {array.shape} to target shape {target_shape}")
 | 
			
		||||
    if len(target_shape) != array.ndim:
 | 
			
		||||
        raise ValueError("Array and target shape must have the same number of dimensions")
 | 
			
		||||
        error_msg = "Array and target shape must have the same number of dimensions"
 | 
			
		||||
        # logger.debug(f"Validation did not succeed: {error_msg}")
 | 
			
		||||
        raise ValueError(error_msg)
 | 
			
		||||
 | 
			
		||||
    # Number of repetitions per dimension
 | 
			
		||||
    repeats = tuple(target_shape[i] // array.shape[i] for i in range(array.ndim))
 | 
			
		||||
 | 
			
		||||
    # Use np.tile to expand the array
 | 
			
		||||
    expanded_array = np.tile(array, repeats)
 | 
			
		||||
    #  logger.debug(f"Expanded array shape: {expanded_array.shape}")
 | 
			
		||||
    return expanded_array
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HourlyElectricityPriceForecast:
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        source: str | Path,
 | 
			
		||||
        source: Union[str, Path],
 | 
			
		||||
        config: AppConfig,
 | 
			
		||||
        charges: float = 0.000228,
 | 
			
		||||
        use_cache: bool = True,
 | 
			
		||||
    ):  # 228
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        # logger.debug("Initializing HourlyElectricityPriceForecast")
 | 
			
		||||
        self.cache_dir = config.working_dir / config.directories.cache
 | 
			
		||||
        self.use_cache = use_cache
 | 
			
		||||
        if not self.cache_dir.is_dir():
 | 
			
		||||
            raise SetupIncomplete(f"Output path does not exist: {self.cache_dir}.")
 | 
			
		||||
 | 
			
		||||
        self.cache_time_file = self.cache_dir / "cache_timestamp.txt"
 | 
			
		||||
        self.prices = self.load_data(source)
 | 
			
		||||
        self.charges = charges
 | 
			
		||||
        self.prediction_hours = config.eos.prediction_hours
 | 
			
		||||
 | 
			
		||||
    def load_data(self, source: str | Path) -> list[dict[str, Any]]:
 | 
			
		||||
        if not self.cache_dir.is_dir():
 | 
			
		||||
            error_msg = f"Output path does not exist: {self.cache_dir}"
 | 
			
		||||
            # logger.debug(f"Validation did not succeed: {error_msg}")
 | 
			
		||||
            raise SetupIncomplete(error_msg)
 | 
			
		||||
 | 
			
		||||
        self.cache_time_file = self.cache_dir / "cache_timestamp.txt"
 | 
			
		||||
        self.prices = self.load_data(source)
 | 
			
		||||
 | 
			
		||||
    def load_data(self, source: Union[str, Path]) -> list[dict[str, Union[str, float]]]:
 | 
			
		||||
        """Loads data from a cache file or source, returns a list of price entries."""
 | 
			
		||||
        cache_file = self.get_cache_file(source)
 | 
			
		||||
        # logger.debug(f"Loading data from source: {source}, using cache file: {cache_file}")
 | 
			
		||||
 | 
			
		||||
        if (
 | 
			
		||||
            isinstance(source, str)
 | 
			
		||||
            and self.use_cache
 | 
			
		||||
            and cache_file.is_file()
 | 
			
		||||
            and not self.is_cache_expired()
 | 
			
		||||
        ):
 | 
			
		||||
            # logger.debug("Loading data from cache...")
 | 
			
		||||
            with cache_file.open("r") as file:
 | 
			
		||||
                json_data = json.load(file)
 | 
			
		||||
        else:
 | 
			
		||||
            # logger.debug("Fetching data from source and updating cache...")
 | 
			
		||||
            json_data = self.fetch_and_cache_data(source, cache_file)
 | 
			
		||||
 | 
			
		||||
        return json_data.get("values", [])
 | 
			
		||||
 | 
			
		||||
    def get_cache_file(self, source: Union[str, Path]) -> Path:
 | 
			
		||||
        """Generates a unique cache file path for the source URL."""
 | 
			
		||||
        url = str(source)
 | 
			
		||||
        hash_object = hashlib.sha256(url.encode())
 | 
			
		||||
        hex_dig = hash_object.hexdigest()
 | 
			
		||||
        cache_file = self.cache_dir / f"cache_{hex_dig}.json"
 | 
			
		||||
        # logger.debug(f"Generated cache file path: {cache_file}")
 | 
			
		||||
        return cache_file
 | 
			
		||||
 | 
			
		||||
    def is_cache_expired(self) -> bool:
 | 
			
		||||
        """Checks if the cache has expired based on a one-hour limit."""
 | 
			
		||||
        if not self.cache_time_file.is_file():
 | 
			
		||||
            # logger.debug("Cache timestamp file does not exist; cache considered expired")
 | 
			
		||||
            return True
 | 
			
		||||
 | 
			
		||||
        with self.cache_time_file.open("r") as file:
 | 
			
		||||
            timestamp_str = file.read()
 | 
			
		||||
        last_cache_time = datetime.strptime(timestamp_str, "%Y-%m-%d %H:%M:%S")
 | 
			
		||||
        cache_expired = datetime.now() - last_cache_time > timedelta(hours=1)
 | 
			
		||||
        # logger.debug(f"Cache expired: {cache_expired}")
 | 
			
		||||
        return cache_expired
 | 
			
		||||
 | 
			
		||||
    def update_cache_timestamp(self) -> None:
 | 
			
		||||
        """Updates the cache timestamp to the current time."""
 | 
			
		||||
        current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
 | 
			
		||||
        with self.cache_time_file.open("w") as file:
 | 
			
		||||
            file.write(current_time)
 | 
			
		||||
 | 
			
		||||
    # logger.debug(f"Updated cache timestamp to {current_time}")
 | 
			
		||||
 | 
			
		||||
    def fetch_and_cache_data(self, source: Union[str, Path], cache_file: Path) -> dict:
 | 
			
		||||
        """Fetches data from a URL or file and caches it."""
 | 
			
		||||
        if isinstance(source, str):
 | 
			
		||||
            if cache_file.is_file() and not self.is_cache_expired() and self.use_cache:
 | 
			
		||||
                print("Loading data from cache...")
 | 
			
		||||
                with cache_file.open("r") as file:
 | 
			
		||||
                    json_data = json.load(file)
 | 
			
		||||
            else:
 | 
			
		||||
                print("Loading data from the URL...")
 | 
			
		||||
                response = requests.get(source)
 | 
			
		||||
                if response.status_code == 200:
 | 
			
		||||
                    json_data = response.json()
 | 
			
		||||
                    with cache_file.open("w") as file:
 | 
			
		||||
                        json.dump(json_data, file)
 | 
			
		||||
                    self.update_cache_timestamp()
 | 
			
		||||
                else:
 | 
			
		||||
                    raise Exception(f"Error fetching data: {response.status_code}")
 | 
			
		||||
            # logger.debug(f"Fetching data from URL: {source}")
 | 
			
		||||
            response = requests.get(source)
 | 
			
		||||
            if response.status_code != 200:
 | 
			
		||||
                error_msg = f"Error fetching data: {response.status_code}"
 | 
			
		||||
                # logger.debug(f"Validation did not succeed: {error_msg}")
 | 
			
		||||
                raise Exception(error_msg)
 | 
			
		||||
 | 
			
		||||
            json_data = response.json()
 | 
			
		||||
            with cache_file.open("w") as file:
 | 
			
		||||
                json.dump(json_data, file)
 | 
			
		||||
            self.update_cache_timestamp()
 | 
			
		||||
        elif source.is_file():
 | 
			
		||||
            # logger.debug(f"Loading data from file: {source}")
 | 
			
		||||
            with source.open("r") as file:
 | 
			
		||||
                json_data = json.load(file)
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError(f"Input is not a valid path: {source}")
 | 
			
		||||
        return json_data["values"]
 | 
			
		||||
            error_msg = f"Invalid input path: {source}"
 | 
			
		||||
            # logger.debug(f"Validation did not succeed: {error_msg}")
 | 
			
		||||
            raise ValueError(error_msg)
 | 
			
		||||
 | 
			
		||||
    def get_cache_file(self, url: str | Path) -> Path:
 | 
			
		||||
        if isinstance(url, Path):
 | 
			
		||||
            url = str(url)
 | 
			
		||||
        hash_object = hashlib.sha256(url.encode())
 | 
			
		||||
        hex_dig = hash_object.hexdigest()
 | 
			
		||||
        return self.cache_dir / f"cache_{hex_dig}.json"
 | 
			
		||||
 | 
			
		||||
    def is_cache_expired(self) -> bool:
 | 
			
		||||
        if not self.cache_time_file.is_file():
 | 
			
		||||
            return True
 | 
			
		||||
        with self.cache_time_file.open("r") as file:
 | 
			
		||||
            timestamp_str = file.read()
 | 
			
		||||
            last_cache_time = datetime.strptime(timestamp_str, "%Y-%m-%d %H:%M:%S")
 | 
			
		||||
        return datetime.now() - last_cache_time > timedelta(hours=1)
 | 
			
		||||
 | 
			
		||||
    def update_cache_timestamp(self) -> None:
 | 
			
		||||
        with self.cache_time_file.open("w") as file:
 | 
			
		||||
            file.write(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
 | 
			
		||||
        return json_data
 | 
			
		||||
 | 
			
		||||
    def get_price_for_date(self, date_str: str) -> np.ndarray:
 | 
			
		||||
        """Returns all prices for the specified date, including the price from 00:00 of the previous day."""
 | 
			
		||||
        # Convert date string to datetime object
 | 
			
		||||
        """Retrieves all prices for a specified date, adding the previous day's last price if needed."""
 | 
			
		||||
        # logger.debug(f"Getting prices for date: {date_str}")
 | 
			
		||||
        date_obj = datetime.strptime(date_str, "%Y-%m-%d")
 | 
			
		||||
        previous_day_str = (date_obj - timedelta(days=1)).strftime("%Y-%m-%d")
 | 
			
		||||
 | 
			
		||||
        # Calculate the previous day
 | 
			
		||||
        previous_day = date_obj - timedelta(days=1)
 | 
			
		||||
        previous_day_str = previous_day.strftime("%Y-%m-%d")
 | 
			
		||||
 | 
			
		||||
        # Extract the price from 00:00 of the previous day
 | 
			
		||||
        last_price_of_previous_day = [
 | 
			
		||||
        previous_day_prices = [
 | 
			
		||||
            entry["marketpriceEurocentPerKWh"] + self.charges
 | 
			
		||||
            for entry in self.prices
 | 
			
		||||
            if previous_day_str in entry["end"]
 | 
			
		||||
        ][-1]
 | 
			
		||||
        ]
 | 
			
		||||
        last_price_of_previous_day = previous_day_prices[-1] if previous_day_prices else 0
 | 
			
		||||
 | 
			
		||||
        # Extract all prices for the specified date
 | 
			
		||||
        date_prices = [
 | 
			
		||||
            entry["marketpriceEurocentPerKWh"] + self.charges
 | 
			
		||||
            for entry in self.prices
 | 
			
		||||
            if date_str in entry["end"]
 | 
			
		||||
        ]
 | 
			
		||||
        print(f"getPrice: {len(date_prices)}")
 | 
			
		||||
 | 
			
		||||
        # Add the last price of the previous day at the start of the list
 | 
			
		||||
        if len(date_prices) == 23:
 | 
			
		||||
        if len(date_prices) < 24:
 | 
			
		||||
            date_prices.insert(0, last_price_of_previous_day)
 | 
			
		||||
 | 
			
		||||
        return np.array(date_prices) / (1000.0 * 100.0) + self.charges
 | 
			
		||||
        # logger.debug(f"Retrieved {len(date_prices)} prices for date {date_str}")
 | 
			
		||||
        return np.round(np.array(date_prices) / 100000.0, 10)
 | 
			
		||||
 | 
			
		||||
    def get_price_for_daterange(self, start_date_str: str, end_date_str: str) -> np.ndarray:
 | 
			
		||||
        """Returns all prices between the start and end dates."""
 | 
			
		||||
        print(start_date_str)
 | 
			
		||||
        print(end_date_str)
 | 
			
		||||
        start_date_utc = datetime.strptime(start_date_str, "%Y-%m-%d").replace(tzinfo=timezone.utc)
 | 
			
		||||
        end_date_utc = datetime.strptime(end_date_str, "%Y-%m-%d").replace(tzinfo=timezone.utc)
 | 
			
		||||
        start_date = start_date_utc.astimezone(zoneinfo.ZoneInfo("Europe/Berlin"))
 | 
			
		||||
        end_date = end_date_utc.astimezone(zoneinfo.ZoneInfo("Europe/Berlin"))
 | 
			
		||||
        """Retrieves all prices within a specified date range."""
 | 
			
		||||
        # logger.debug(f"Getting prices from {start_date_str} to {end_date_str}")
 | 
			
		||||
        start_date = (
 | 
			
		||||
            datetime.strptime(start_date_str, "%Y-%m-%d")
 | 
			
		||||
            .replace(tzinfo=timezone.utc)
 | 
			
		||||
            .astimezone(zoneinfo.ZoneInfo("Europe/Berlin"))
 | 
			
		||||
        )
 | 
			
		||||
        end_date = (
 | 
			
		||||
            datetime.strptime(end_date_str, "%Y-%m-%d")
 | 
			
		||||
            .replace(tzinfo=timezone.utc)
 | 
			
		||||
            .astimezone(zoneinfo.ZoneInfo("Europe/Berlin"))
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        price_list: list[float] = []
 | 
			
		||||
        price_list = []
 | 
			
		||||
 | 
			
		||||
        while start_date < end_date:
 | 
			
		||||
            date_str = start_date.strftime("%Y-%m-%d")
 | 
			
		||||
@@ -134,10 +173,9 @@ class HourlyElectricityPriceForecast:
 | 
			
		||||
                price_list.extend(daily_prices)
 | 
			
		||||
            start_date += timedelta(days=1)
 | 
			
		||||
 | 
			
		||||
        price_list_np = np.array(price_list)
 | 
			
		||||
 | 
			
		||||
        # If prediction hours are greater than 0, reshape the price list
 | 
			
		||||
        if self.prediction_hours > 0:
 | 
			
		||||
            price_list_np = repeat_to_shape(price_list_np, (self.prediction_hours,))
 | 
			
		||||
            # logger.debug(f"Reshaping price list to match prediction hours: {self.prediction_hours}")
 | 
			
		||||
            price_list = repeat_to_shape(np.array(price_list), (self.prediction_hours,))
 | 
			
		||||
 | 
			
		||||
        return price_list_np
 | 
			
		||||
        # logger.debug(f"Total prices retrieved for date range: {len(price_list)}")
 | 
			
		||||
        return np.round(np.array(price_list), 10)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user