# Source code for wraquant.viz.risk

"""Risk visualizations.

VaR backtests, rolling volatility, and tail distribution plots.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

from wraquant.core.decorators import requires_extra
from wraquant.viz.themes import COLORS, apply_theme

if TYPE_CHECKING:
    import matplotlib.axes
    import matplotlib.figure
    import pandas as pd


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _get_or_create_ax(
    ax: matplotlib.axes.Axes | None,
    figsize: tuple[float, float] = (12, 6),
) -> tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]:
    """Resolve the figure/axes pair that a plotting function should draw on.

    Parameters:
        ax: Axes supplied by the caller, or *None* to request a new one.
        figsize: Size passed to ``plt.subplots`` when a new figure is
            created; ignored when *ax* is given.

    Returns:
        A ``(figure, axes)`` tuple. A caller-supplied axes is returned
        unchanged together with the result of ``ax.get_figure()``; a newly
        created pair has ``apply_theme`` applied before it is returned.
    """
    import matplotlib.pyplot as plt

    if ax is None:
        new_fig, new_ax = plt.subplots(figsize=figsize)
        apply_theme(new_fig, new_ax)
        return new_fig, new_ax

    # Caller-owned axes: do not restyle it, just look up its figure.
    return ax.get_figure(), ax


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


@requires_extra("viz")
def plot_var_backtest(
    returns: pd.Series,
    var_level: pd.Series | float,
    confidence: float = 0.95,
    ax: matplotlib.axes.Axes | None = None,
) -> matplotlib.axes.Axes:
    """Plot returns with a VaR threshold line, highlighting breaches.

    A breach is any observation where ``returns < var_level``, so the
    threshold must be expressed on the same scale and sign as *returns*.

    Parameters:
        returns: Simple return series.
        var_level: Value-at-Risk threshold. Pass a scalar (Python or NumPy
            number) for a constant threshold or a Series for a time-varying
            VaR. A Series whose index differs from ``returns.index`` is
            aligned to it for breach detection; dates without a VaR value
            are never counted as breaches.
        confidence: VaR confidence level (used only for the title and
            legend label).
        ax: Matplotlib axes to plot on. A new figure is created when *None*.

    Returns:
        The matplotlib Axes containing the plot.
    """
    import pandas as pd

    _, ax = _get_or_create_ax(ax)

    ax.plot(
        returns.index,
        returns.values,
        color=COLORS["primary"],
        linewidth=0.8,
        alpha=0.7,
        label="Returns",
    )

    # VaR line. Anything that is not already a Series (int, float, NumPy
    # scalar, ...) is broadcast over the returns index; checking for
    # ``pd.Series`` instead of ``(int, float)`` keeps NumPy scalars such as
    # ``np.float32`` from falling into the Series branch.
    if isinstance(var_level, pd.Series):
        var_series = var_level
    else:
        var_series = pd.Series(var_level, index=returns.index)

    ax.plot(
        var_series.index,
        var_series.values,
        color=COLORS["negative"],
        linewidth=1.2,
        linestyle="--",
        label=f"VaR ({confidence:.0%})",
    )

    # Highlight breaches. Comparing two Series with different labels raises
    # ValueError in pandas, so align the threshold to the returns first.
    # Reindexing is skipped when the indexes already match so that inputs
    # that worked before (including duplicate labels) take the same path.
    if var_series.index.equals(returns.index):
        threshold = var_series
    else:
        threshold = var_series.reindex(returns.index)

    # NaN thresholds compare as False, i.e. no breach where VaR is missing.
    breaches = returns[returns < threshold]
    if not breaches.empty:
        ax.scatter(
            breaches.index,
            breaches.values,
            color=COLORS["negative"],
            s=18,
            zorder=5,
            label=f"Breaches ({len(breaches)})",
        )

    ax.set_title(f"VaR Backtest ({confidence:.0%} confidence)")
    ax.set_ylabel("Return")
    ax.set_xlabel("")
    ax.legend(loc="lower left")
    ax.axhline(0, color=COLORS["neutral"], linewidth=0.5)

    return ax
@requires_extra("viz")
def plot_rolling_volatility(
    returns: pd.Series,
    window: int = 21,
    annualize: bool = True,
    ax: matplotlib.axes.Axes | None = None,
    *,
    periods_per_year: int = 252,
) -> matplotlib.axes.Axes:
    """Plot rolling volatility line chart.

    Volatility is the rolling standard deviation of *returns* over *window*
    observations, optionally scaled by ``sqrt(periods_per_year)``.

    Parameters:
        returns: Simple return series.
        window: Rolling window in periods (default 21 for ~1 month daily).
        annualize: Whether to annualize (multiply by
            ``sqrt(periods_per_year)``).
        ax: Matplotlib axes to plot on. A new figure is created when *None*.
        periods_per_year: Number of return observations per year used for
            annualization. Defaults to 252 (daily data); pass e.g. 52 for
            weekly or 12 for monthly returns. Ignored when *annualize* is
            False.

    Returns:
        The matplotlib Axes containing the plot.
    """
    import numpy as np

    _, ax = _get_or_create_ax(ax)

    # The first ``window - 1`` values are NaN and are simply not drawn.
    vol = returns.rolling(window).std()
    if annualize:
        vol = vol * np.sqrt(periods_per_year)

    ax.plot(vol.index, vol.values, color=COLORS["primary"], linewidth=1.2)
    ax.fill_between(vol.index, 0, vol.values, color=COLORS["primary"], alpha=0.15)

    label = "Annualized " if annualize else ""
    ax.set_title(f"Rolling {window}-Day {label}Volatility")
    ax.set_ylabel("Volatility")
    ax.set_xlabel("")

    return ax
@requires_extra("viz")
def plot_tail_distribution(
    returns: pd.Series,
    threshold_percentile: float = 5,
    ax: matplotlib.axes.Axes | None = None,
) -> matplotlib.axes.Axes:
    """Plot the return distribution with its left tail highlighted.

    Two histograms are overlaid: all non-missing returns, and the subset at
    or below the *threshold_percentile* cut-off. Each histogram is drawn with
    ``density=True`` on its own data, so the tail bars are normalised
    separately from the full distribution. A dashed vertical line marks the
    cut-off value.

    Parameters:
        returns: Simple return series. Missing values are dropped first.
        threshold_percentile: Percentile below which tails are highlighted.
        ax: Matplotlib axes to plot on. A new figure is created when *None*.

    Returns:
        The matplotlib Axes containing the plot.
    """
    import numpy as np

    _, ax = _get_or_create_ax(ax)

    sample = returns.dropna()
    cutoff = np.percentile(sample.values, threshold_percentile)

    # Background: every observation, in a light colour.
    ax.hist(
        sample.values,
        bins=80,
        density=True,
        color=COLORS["primary"],
        alpha=0.35,
        edgecolor="white",
        label="All returns",
    )

    tail_color = COLORS["negative"]

    # Foreground: only the observations at or below the cut-off.
    left_tail = sample[sample <= cutoff]
    if len(left_tail) > 0:
        ax.hist(
            left_tail.values,
            bins=30,
            density=True,
            color=tail_color,
            alpha=0.7,
            edgecolor="white",
            label=f"Left tail (<= {threshold_percentile}th pctile)",
        )

    ax.axvline(
        cutoff,
        color=tail_color,
        linestyle="--",
        linewidth=1.2,
        label=f"{threshold_percentile}th percentile: {cutoff:.4f}",
    )

    ax.set_title("Tail Distribution")
    ax.set_xlabel("Return")
    ax.set_ylabel("Density")
    ax.legend(loc="upper left")

    return ax