From 3121490a239e0fef1e28c19c57e80a29e23df025 Mon Sep 17 00:00:00 2001 From: Normann Date: Tue, 7 Jan 2025 00:30:53 +0100 Subject: [PATCH] mypy, req --- requirements.txt | 2 +- .../prediction/elecpriceakkudoktor.py | 32 +++++++++---------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/requirements.txt b/requirements.txt index 6a17e86..f0d67ce 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,4 @@ pendulum==3.0.0 platformdirs==4.3.6 pvlib==0.11.1 pydantic==2.10.4 -statsmodels==0.14.4 \ No newline at end of file +statsmodels==0.14.4 diff --git a/src/akkudoktoreos/prediction/elecpriceakkudoktor.py b/src/akkudoktoreos/prediction/elecpriceakkudoktor.py index 09b7840..159b68f 100644 --- a/src/akkudoktoreos/prediction/elecpriceakkudoktor.py +++ b/src/akkudoktoreos/prediction/elecpriceakkudoktor.py @@ -12,13 +12,13 @@ import numpy as np import requests from numpydantic import NDArray, Shape from pydantic import Field, ValidationError +from statsmodels.tsa.holtwinters import ExponentialSmoothing from akkudoktoreos.core.logging import get_logger from akkudoktoreos.core.pydantic import PydanticBaseModel from akkudoktoreos.prediction.elecpriceabc import ElecPriceDataRecord, ElecPriceProvider from akkudoktoreos.utils.cacheutil import cache_in_file from akkudoktoreos.utils.datetimeutil import to_datetime, to_duration -from statsmodels.tsa.holtwinters import ExponentialSmoothing logger = get_logger(__name__) @@ -122,22 +122,22 @@ class ElecPriceAkkudoktor(ElecPriceProvider): self.update_datetime = to_datetime(in_timezone=self.config.timezone) return akkudoktor_data - def cap_outliers(data: np.ndarray, sigma: int = 2) -> np.ndarray: - mean = data.mean() - std = data.std() - lower_bound = mean - sigma * std - upper_bound = mean + sigma * std - capped_data = data.clip(lower=lower_bound, upper=upper_bound) - return capped_data + def _cap_outliers(self, data: np.ndarray, sigma: int = 2) -> np.ndarray: + mean = data.mean() + std = data.std() + lower_bound = mean - sigma * std + upper_bound = mean + sigma * std + capped_data = data.clip(min=lower_bound, max=upper_bound) + return capped_data - def predict_ets( - history: np.ndarray, seasonal_periods: int, prediction_hours: int - ) -> np.ndarray: - clean_history = cap_outliers(history) - model = ExponentialSmoothing( - clean_history, seasonal="add", seasonal_periods=seasonal_periods - ).fit() - return model.forecast(prediction_hours) + def _predict_ets( + self, history: np.ndarray, seasonal_periods: int, prediction_hours: int + ) -> np.ndarray: + clean_history = self._cap_outliers(history) + model = ExponentialSmoothing( + clean_history, seasonal="add", seasonal_periods=seasonal_periods + ).fit() + return model.forecast(prediction_hours) def _update_data(self, force_update: Optional[bool] = False) -> None: """Update forecast data in the ElecPriceDataRecord format.