Source code for wraquant.ts.changepoint

"""Change-point detection for time series."""

from __future__ import annotations

import numpy as np
import pandas as pd

from wraquant.core.decorators import requires_extra


[docs] def cusum(data: pd.Series, threshold: float = 1.0) -> pd.Series: """Cumulative sum (CUSUM) control chart for change detection. Returns a CUSUM statistic series. Values exceeding *threshold* standard deviations indicate a potential shift. Parameters: data: Time series. threshold: Detection threshold in standard deviation units. Returns: CUSUM statistic series. """ clean = data.dropna() mean = clean.mean() std = clean.std() if std == 0: return pd.Series(0.0, index=clean.index) normalised = (clean - mean) / std cusum_pos = pd.Series(0.0, index=clean.index) cusum_neg = pd.Series(0.0, index=clean.index) s_pos = 0.0 s_neg = 0.0 for i, val in enumerate(normalised.values): s_pos = max(0.0, s_pos + val - threshold / 2) s_neg = max(0.0, s_neg - val - threshold / 2) cusum_pos.iloc[i] = s_pos cusum_neg.iloc[i] = s_neg return cusum_pos + cusum_neg
[docs] @requires_extra("timeseries") def detect_changepoints( data: pd.Series, method: str = "pelt", penalty: float | None = None, ) -> list[int]: """Detect change-points using the ``ruptures`` library. Parameters: data: Time series. method: Algorithm — ``"pelt"`` (default), ``"binseg"``, or ``"window"``. penalty: Penalty value for the algorithm. When *None*, a sensible default is chosen. Returns: List of change-point indices (positions in the array, excluding the final point). Raises: ValueError: If *method* is not recognized. """ import ruptures as rpt signal = data.dropna().values.reshape(-1, 1) if penalty is None: penalty = np.log(len(signal)) * data.std() ** 2 if method == "pelt": algo = rpt.Pelt(model="rbf").fit(signal) result = algo.predict(pen=penalty) elif method == "binseg": algo = rpt.Binseg(model="rbf").fit(signal) result = algo.predict(pen=penalty) elif method == "window": algo = rpt.Window(model="rbf").fit(signal) result = algo.predict(pen=penalty) else: msg = f"Unknown changepoint method: {method!r}" raise ValueError(msg) # ruptures includes the final index; remove it return [int(x) for x in result if x < len(signal)]