# Source code for wraquant.viz.timeseries

"""Time series visualizations.

Basic time series line plots, regime overlay, and seasonal decomposition
panels.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

from wraquant.core.decorators import requires_extra
from wraquant.viz.themes import COLORS, apply_theme

if TYPE_CHECKING:
    import matplotlib.axes
    import matplotlib.figure
    import pandas as pd


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _get_or_create_ax(
    ax: matplotlib.axes.Axes | None,
    figsize: tuple[float, float] = (12, 6),
) -> tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]:
    """Resolve the figure/axes pair a plotting function should draw on.

    When *ax* is given it is reused as-is, paired with its parent figure, and
    no theme is applied to it. Otherwise a new single-axes figure of size
    *figsize* is created and the project theme is applied to it.
    """
    import matplotlib.pyplot as plt

    if ax is None:
        new_fig, new_ax = plt.subplots(figsize=figsize)
        apply_theme(new_fig, new_ax)
        return new_fig, new_ax
    return ax.get_figure(), ax


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


@requires_extra("viz")
def plot_series(
    data: pd.Series | pd.DataFrame,
    title: str | None = None,
    ylabel: str | None = None,
    ax: matplotlib.axes.Axes | None = None,
) -> matplotlib.axes.Axes:
    """Plot a basic time series line chart.

    Parameters:
        data: Time series data. A Series produces a single line; a DataFrame
            plots one line per column, labelled with the column name.
        title: Plot title. An empty title is set when *None*.
        ylabel: Y-axis label. Left unchanged when falsy.
        ax: Matplotlib axes to plot on. A new figure is created when *None*.

    Returns:
        The matplotlib Axes containing the plot.
    """
    import pandas as pd

    _, ax = _get_or_create_ax(ax)

    if isinstance(data, pd.DataFrame):
        # ``items()`` yields exactly one Series per column, even when column
        # labels are duplicated (``data[col]`` would return a 2-D frame).
        for col, series in data.items():
            ax.plot(data.index, series.values, label=col, linewidth=1.2)
        # Avoid matplotlib's "no artists with labels" warning on empty frames.
        if len(data.columns) > 0:
            ax.legend()
    else:
        name = data.name
        ax.plot(
            data.index,
            data.values,
            color=COLORS["primary"],
            linewidth=1.2,
            label=name,
        )
        # Test for a *missing* name rather than truthiness so that valid
        # falsy names such as ``0`` still get a legend entry.
        if name is not None and str(name) != "":
            ax.legend()

    ax.set_title(title or "")
    if ylabel:
        ax.set_ylabel(ylabel)
    ax.set_xlabel("")
    return ax
@requires_extra("viz")
def plot_regime_overlay(
    data: pd.Series,
    regimes: pd.Series,
    ax: matplotlib.axes.Axes | None = None,
) -> matplotlib.axes.Axes:
    """Plot a time series with colored background regions for regimes.

    Consecutive points that share a regime label are merged into one shaded
    span. A span ends at the index label where the next regime begins. The
    first span starts at ``data.index[0]`` and the last one ends at
    ``data.index[-1]``. Points whose regime label is missing (NaN/None) are
    left unshaded.

    Parameters:
        data: Time series to plot as a line.
        regimes: Integer or categorical series of the same index indicating
            the regime at each point.
        ax: Matplotlib axes to plot on. A new figure is created when *None*.

    Returns:
        The matplotlib Axes containing the plot. When *data* is empty or
        *regimes* has no non-missing labels, only the line is drawn (no
        shading, no legend).
    """
    import matplotlib.patches as mpatches
    import matplotlib.pyplot as plt
    import pandas as pd

    _, ax = _get_or_create_ax(ax)
    ax.plot(data.index, data.values, color=COLORS["primary"], linewidth=1.2)
    ax.set_title("Time Series with Regime Overlay")
    ax.set_xlabel("")

    labelled = regimes.dropna()
    if data.empty or labelled.empty:
        # Nothing to shade; indexing element 0 below would raise IndexError.
        return ax

    unique_regimes = sorted(labelled.unique())
    cmap = plt.cm.Pastel1  # type: ignore[attr-defined]
    regime_colors = {r: cmap(i % cmap.N) for i, r in enumerate(unique_regimes)}

    def same_regime(a: object, b: object) -> bool:
        # NaN never equals itself, so missing labels need an explicit check
        # to be grouped into a single (unshaded) run.
        a_missing, b_missing = bool(pd.isna(a)), bool(pd.isna(b))
        if a_missing or b_missing:
            return a_missing and b_missing
        return bool(a == b)

    def shade(start: object, end: object, regime: object) -> None:
        # Missing labels have no color entry; leave that stretch unshaded.
        if not pd.isna(regime):
            ax.axvspan(start, end, alpha=0.25, color=regime_colors[regime])

    # Shade regime spans: close the current span whenever the label changes.
    prev_regime = regimes.iloc[0]
    span_start = data.index[0]
    for idx, regime in zip(regimes.index[1:], regimes.iloc[1:], strict=False):
        if not same_regime(regime, prev_regime):
            shade(span_start, idx, prev_regime)
            span_start = idx
            prev_regime = regime
    # Final span runs to the end of the plotted data.
    shade(span_start, data.index[-1], prev_regime)

    patches = [
        mpatches.Patch(color=regime_colors[r], alpha=0.4, label=f"Regime {r}")
        for r in unique_regimes
    ]
    ax.legend(handles=patches, loc="best")
    return ax
@requires_extra("viz")
def plot_decomposition(
    trend: pd.Series,
    seasonal: pd.Series,
    residual: pd.Series,
    ax: matplotlib.axes.Axes | None = None,
) -> matplotlib.figure.Figure:
    """Draw trend, seasonal and residual components as stacked panels.

    Parameters:
        trend: Trend component series, drawn in the top panel.
        seasonal: Seasonal component series, drawn in the middle panel.
        residual: Residual component series, drawn in the bottom panel.
        ax: Ignored. A new 3-panel figure is always created.

    Returns:
        The matplotlib Figure holding the three vertically stacked panels,
        which share their x-axis.
    """
    import matplotlib.pyplot as plt

    fig, panels = plt.subplots(3, 1, figsize=(12, 9), sharex=True)

    # Theme every panel before anything is drawn on any of them.
    for panel in panels:
        apply_theme(fig, panel)

    panel_specs = (
        ("Trend", trend, COLORS["primary"]),
        ("Seasonal", seasonal, COLORS["secondary"]),
        ("Residual", residual, COLORS["accent"]),
    )
    for position, (name, component, line_color) in enumerate(panel_specs):
        panel = panels[position]
        panel.plot(
            component.index, component.values, color=line_color, linewidth=1.2
        )
        panel.set_ylabel(name)
        panel.set_title(name)

    panels[-1].set_xlabel("")
    fig.suptitle(
        "Time Series Decomposition", fontsize=14, fontweight="bold", y=1.01
    )
    fig.tight_layout()
    return fig