# Source code for wraquant.viz.interactive

"""Core Plotly interactive chart wrappers for financial analysis.

Cumulative returns, drawdowns, rolling statistics, distributions,
correlation heatmaps, efficient frontier, and risk-return scatters ---
all as interactive Plotly figures with hover tooltips and rich formatting.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

from wraquant.core.decorators import requires_extra
from wraquant.viz.themes import COLORS

if TYPE_CHECKING:
    import numpy as np
    import numpy.typing as npt
    import pandas as pd
    import plotly.graph_objects as go

# Public API of this module: one entry per chart-building function defined
# below.  Helpers prefixed with an underscore are intentionally not exported.
__all__ = [
    "plotly_returns",
    "plotly_drawdown",
    "plotly_rolling_stats",
    "plotly_distribution",
    "plotly_correlation_heatmap",
    "plotly_efficient_frontier",
    "plotly_risk_return_scatter",
]


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

# Name of the Plotly template that ``_base_layout`` puts into every layout.
_PLOTLY_TEMPLATE = "plotly_white"

# Ordered list of theme colours, looked up by key from ``COLORS``.
_PALETTE = [
    COLORS[key]
    for key in (
        "primary",
        "secondary",
        "positive",
        "negative",
        "accent",
        "info",
        "warning",
    )
]


def _base_layout(**overrides: object) -> dict:
    """Build the layout keyword arguments shared by the charts in this module.

    Parameters:
        **overrides: Layout entries that replace (or extend) the defaults.
            A key given here wins over the default with the same name.

    Returns:
        A new ``dict`` suitable for ``fig.update_layout(**...)``.
    """
    layout: dict = {
        "template": _PLOTLY_TEMPLATE,
        "font": {"family": "sans-serif", "size": 12, "color": "#333333"},
        "plot_bgcolor": "white",
        "paper_bgcolor": "white",
        "hovermode": "x unified",
        "margin": {"l": 60, "r": 30, "t": 50, "b": 50},
    }
    # Later entries win, so caller-supplied overrides replace the defaults.
    return {**layout, **overrides}


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


@requires_extra("viz")
def plotly_returns(
    returns: pd.Series,
    benchmark: pd.Series | None = None,
    title: str | None = None,
) -> go.Figure:
    """Interactive cumulative returns chart with hover tooltips.

    The strategy trace's hover text shows the date, the cumulative return
    and the drawdown from the running peak at each point.  The benchmark
    trace, when supplied, shows the date and cumulative return only.

    Parameters:
        returns: Simple (non-cumulative) return series.  Index entries are
            rendered with the ``%Y-%m-%d`` format spec, so they must
            support it (e.g. timestamps).
        benchmark: Optional benchmark return series for comparison; drawn
            as a dashed line.
        title: Chart title.  Defaults to ``"Cumulative Returns"``.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    import plotly.graph_objects as go

    # Compound the returns once and derive every other series from it
    # instead of recomputing the cumulative product for each quantity.
    wealth = (1 + returns).cumprod()
    cum = wealth - 1
    running_max = wealth.cummax()
    drawdown = (wealth - running_max) / running_max

    hover_text = [
        f"Date: {d:%Y-%m-%d}<br>Return: {r:.2%}<br>Drawdown: {dd:.2%}"
        for d, r, dd in zip(cum.index, cum.values, drawdown.values, strict=False)
    ]

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=cum.index,
            y=cum.values,
            mode="lines",
            name=returns.name or "Strategy",
            line=dict(color=COLORS["primary"], width=2),
            text=hover_text,
            hoverinfo="text",
        )
    )

    if benchmark is not None:
        cum_bench = (1 + benchmark).cumprod() - 1
        bench_hover = [
            f"Date: {d:%Y-%m-%d}<br>Return: {r:.2%}"
            for d, r in zip(cum_bench.index, cum_bench.values, strict=False)
        ]
        fig.add_trace(
            go.Scatter(
                x=cum_bench.index,
                y=cum_bench.values,
                mode="lines",
                name=benchmark.name or "Benchmark",
                line=dict(color=COLORS["benchmark"], width=2, dash="dash"),
                text=bench_hover,
                hoverinfo="text",
            )
        )

    # Zero line as a visual reference for break-even.
    fig.add_hline(y=0, line_color=COLORS["neutral"], line_width=0.8)
    fig.update_layout(
        **_base_layout(
            title=title or "Cumulative Returns",
            yaxis_title="Cumulative Return",
            yaxis_tickformat=".0%",
            showlegend=True,
        )
    )
    return fig
@requires_extra("viz")
def plotly_drawdown(
    returns: pd.Series,
) -> go.Figure:
    """Interactive underwater chart with recovery periods highlighted.

    Plots the drawdown from the running peak as a filled area, shades
    stretches where the drawdown is negative but shrinking, and annotates
    the deepest drawdown.

    Parameters:
        returns: Simple return series.

    Returns:
        A ``plotly.graph_objects.Figure``.  If ``returns`` yields no finite
        drawdown values (empty or all-NaN input) the figure is returned
        without the max-drawdown annotation.
    """
    import plotly.graph_objects as go

    cum = (1 + returns).cumprod()
    running_max = cum.cummax()
    drawdown = (cum - running_max) / running_max

    fig = go.Figure()

    # Filled area for drawdown
    fig.add_trace(
        go.Scatter(
            x=drawdown.index,
            y=drawdown.values,
            fill="tozeroy",
            mode="lines",
            name="Drawdown",
            line=dict(color=COLORS["drawdown"], width=1),
            fillcolor="rgba(214, 39, 40, 0.25)",
        )
    )

    # A point counts as "recovering" when it is still under water but the
    # drawdown is smaller than at the previous observation.
    recovering = (drawdown < 0) & (drawdown.diff() > 0)

    # Collapse consecutive recovering points into (start, end) spans.
    spans: list[tuple[object, object]] = []
    in_span = False
    start = None
    for idx, val in recovering.items():
        if val and not in_span:
            in_span = True
            start = idx
        elif not val and in_span:
            in_span = False
            spans.append((start, idx))
    if in_span:
        # Series ended mid-recovery: close the span at the last observation.
        spans.append((start, drawdown.index[-1]))

    for s, e in spans[:10]:  # limit shapes for performance
        fig.add_vrect(
            x0=s,
            x1=e,
            fillcolor="rgba(44, 160, 44, 0.08)",
            line_width=0,
            layer="below",
        )

    # Annotate max drawdown.  Skipped when there is nothing to annotate,
    # because ``idxmin`` cannot locate a minimum in an empty/all-NaN series.
    if drawdown.notna().any():
        min_idx = drawdown.idxmin()
        min_val = drawdown.min()
        fig.add_annotation(
            x=min_idx,
            y=min_val,
            text=f"Max DD: {min_val:.2%}",
            showarrow=True,
            arrowhead=2,
            arrowcolor=COLORS["negative"],
            font=dict(color=COLORS["negative"], size=11),
        )

    fig.update_layout(
        **_base_layout(
            title="Underwater Plot (Drawdowns)",
            yaxis_title="Drawdown",
            yaxis_tickformat=".1%",
        )
    )
    return fig
@requires_extra("viz")
def plotly_rolling_stats(
    returns: pd.Series,
    window: int = 63,
    *,
    periods_per_year: int = 252,
) -> go.Figure:
    """Rolling Sharpe, volatility, and trend correlation in stacked subplots.

    The three rows are:

    1. Rolling Sharpe ratio: rolling mean divided by rolling standard
       deviation, scaled by ``sqrt(periods_per_year)`` (no risk-free rate).
    2. Rolling annualized volatility: rolling standard deviation scaled by
       ``sqrt(periods_per_year)``.
    3. "Trend Beta": the Pearson correlation between the returns inside
       each window and a linear time index ``0, 1, 2, ...``.  Despite the
       chart label this is a correlation with time, not a market beta.

    Parameters:
        returns: Simple return series.
        window: Rolling window in periods (default 63 ~ 1 quarter).
        periods_per_year: Number of return observations per year used for
            annualization (default 252, i.e. daily data).

    Returns:
        A ``plotly.graph_objects.Figure`` with three subplot rows.
    """
    import numpy as np
    import plotly.graph_objects as go
    from plotly.subplots import make_subplots

    annualizer = np.sqrt(periods_per_year)

    roll_mean = returns.rolling(window).mean()
    roll_std = returns.rolling(window).std()
    roll_sharpe = (roll_mean / roll_std) * annualizer
    roll_vol = roll_std * annualizer

    def _trend_corr(window_values):
        # Correlation of the window's returns with their position in time.
        positions = np.arange(len(window_values))
        return np.corrcoef(window_values, positions)[0, 1]

    roll_beta = returns.rolling(window).apply(_trend_corr, raw=True)

    fig = make_subplots(
        rows=3,
        cols=1,
        shared_xaxes=True,
        vertical_spacing=0.06,
        subplot_titles=(
            f"Rolling {window}-Day Sharpe Ratio",
            f"Rolling {window}-Day Annualized Volatility",
            f"Rolling {window}-Day Trend Beta",
        ),
    )

    fig.add_trace(
        go.Scatter(
            x=roll_sharpe.index,
            y=roll_sharpe.values,
            mode="lines",
            name="Sharpe",
            line=dict(color=COLORS["primary"], width=1.5),
        ),
        row=1,
        col=1,
    )
    fig.add_hline(y=0, line_color=COLORS["neutral"], line_width=0.6, row=1, col=1)

    fig.add_trace(
        go.Scatter(
            x=roll_vol.index,
            y=roll_vol.values,
            mode="lines",
            name="Volatility",
            line=dict(color=COLORS["secondary"], width=1.5),
            fill="tozeroy",
            fillcolor="rgba(255, 127, 14, 0.15)",
        ),
        row=2,
        col=1,
    )

    fig.add_trace(
        go.Scatter(
            x=roll_beta.index,
            y=roll_beta.values,
            mode="lines",
            name="Trend Beta",
            line=dict(color=COLORS["accent"], width=1.5),
        ),
        row=3,
        col=1,
    )
    fig.add_hline(y=0, line_color=COLORS["neutral"], line_width=0.6, row=3, col=1)

    fig.update_layout(
        **_base_layout(
            title=f"Rolling Statistics (window={window})",
            height=750,
            showlegend=False,
        )
    )
    # Per-row tick formats: plain ratio, percentage, correlation.
    fig.update_yaxes(tickformat=".1f", row=1, col=1)
    fig.update_yaxes(tickformat=".1%", row=2, col=1)
    fig.update_yaxes(tickformat=".2f", row=3, col=1)
    return fig
@requires_extra("viz")
def plotly_distribution(
    returns: pd.Series,
    bins: int = 50,
    overlay_normal: bool = True,
) -> go.Figure:
    """Interactive histogram with KDE and optional fitted normal overlay.

    Missing values are dropped first.  The histogram is density-normalised
    so that the Gaussian KDE curve and the optional normal PDF share its
    scale.  A boxed annotation in the top-right corner reports the sample
    skewness and kurtosis as returned by ``scipy.stats.skew`` and
    ``scipy.stats.kurtosis`` with their default arguments.

    Parameters:
        returns: Simple return series.
        bins: Number of histogram bins.
        overlay_normal: If *True*, overlay the PDF of a normal distribution
            whose mean and standard deviation match the sample.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    import numpy as np
    import plotly.graph_objects as go
    from scipy.stats import gaussian_kde, kurtosis, skew

    sample = returns.dropna().values

    fig = go.Figure()

    histogram = go.Histogram(
        x=sample,
        nbinsx=bins,
        histnorm="probability density",
        name="Returns",
        marker_color=COLORS["primary"],
        opacity=0.65,
    )
    fig.add_trace(histogram)

    # Kernel density estimate evaluated on an even grid over the data range.
    density = gaussian_kde(sample)
    grid = np.linspace(sample.min(), sample.max(), 300)
    kde_curve = go.Scatter(
        x=grid,
        y=density(grid),
        mode="lines",
        name="KDE",
        line={"color": COLORS["accent"], "width": 2},
    )
    fig.add_trace(kde_curve)

    if overlay_normal:
        from scipy.stats import norm

        loc = sample.mean()
        scale = sample.std()
        normal_curve = go.Scatter(
            x=grid,
            y=norm.pdf(grid, loc, scale),
            mode="lines",
            name="Normal Fit",
            line={"color": COLORS["negative"], "width": 1.5, "dash": "dash"},
        )
        fig.add_trace(normal_curve)

    # Shape statistics shown in paper coordinates (independent of the data).
    skewness = skew(sample)
    kurt = kurtosis(sample)
    fig.add_annotation(
        x=0.98,
        y=0.95,
        xref="paper",
        yref="paper",
        text=f"Skew: {skewness:.2f}<br>Kurt: {kurt:.2f}",
        showarrow=False,
        font={"size": 11},
        bgcolor="rgba(255,255,255,0.8)",
        bordercolor=COLORS["neutral"],
    )

    layout = _base_layout(
        title="Return Distribution",
        xaxis_title="Return",
        yaxis_title="Density",
        bargap=0.02,
    )
    fig.update_layout(**layout)
    return fig
@requires_extra("viz")
def plotly_correlation_heatmap(
    returns_df: pd.DataFrame,
) -> go.Figure:
    """Interactive correlation matrix heatmap with hierarchical clustering.

    Reorders assets by Ward hierarchical clustering so that correlated
    groups appear together.  Hover shows the pair and correlation value.

    The clustering distance is ``sqrt(2 * (1 - corr))``, which is the
    Euclidean distance between standardised return vectors; Ward linkage
    is only well defined for Euclidean distances.  With fewer than two
    assets there is nothing to cluster and the original column order is
    kept.

    Parameters:
        returns_df: DataFrame of asset returns (columns = assets).

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    import numpy as np
    import plotly.graph_objects as go
    from scipy.cluster.hierarchy import leaves_list, linkage

    corr = returns_df.corr()
    n_assets = corr.shape[0]

    if n_assets > 1:
        # Clip guards against tiny negative values from floating-point
        # error when a correlation is marginally above 1.
        dist = np.sqrt((2.0 * (1.0 - corr.to_numpy())).clip(min=0.0))
        np.fill_diagonal(dist, 0.0)
        # Ensure symmetry
        dist = (dist + dist.T) / 2
        # Upper triangle in row-major order is SciPy's condensed form.
        condensed = dist[np.triu_indices_from(dist, k=1)]
        order = leaves_list(linkage(condensed, method="ward"))
    else:
        order = np.arange(n_assets)

    labels = [corr.columns[i] for i in order]
    ordered = corr.iloc[order, order]
    z = ordered.to_numpy()

    hover_text = [
        [
            f"{row_label} vs {col_label}<br>Corr: {z[i, j]:.3f}"
            for j, col_label in enumerate(labels)
        ]
        for i, row_label in enumerate(labels)
    ]

    fig = go.Figure(
        data=go.Heatmap(
            z=z,
            x=labels,
            y=labels,
            colorscale="RdBu_r",
            zmin=-1,
            zmax=1,
            text=hover_text,
            hoverinfo="text",
            colorbar=dict(title="Corr"),
        )
    )
    fig.update_layout(
        **_base_layout(
            title="Correlation Matrix (Hierarchically Clustered)",
            width=700,
            height=650,
            xaxis=dict(side="bottom"),
            # Reversed so the first label sits at the top, matrix-style.
            yaxis=dict(autorange="reversed"),
        )
    )
    return fig
@requires_extra("viz")
def plotly_efficient_frontier(
    expected_returns: npt.NDArray[np.floating],
    cov_matrix: npt.NDArray[np.floating],
    n_portfolios: int = 5000,
) -> go.Figure:
    """Interactive Monte Carlo risk-return cloud with portfolio weights on hover.

    Draws random long-only portfolios (Dirichlet weights, which are
    non-negative and sum to one) and plots each one by volatility and
    expected return, coloured by its return/volatility ratio.  The sampled
    portfolio with the highest ratio is marked with a star.  The efficient
    frontier itself is not traced; it is only suggested by the upper-left
    edge of the cloud.

    The random generator is seeded with a fixed value, so repeated calls
    with the same inputs produce the same figure.

    Parameters:
        expected_returns: 1-D array of expected returns per asset.
        cov_matrix: 2-D covariance matrix.
        n_portfolios: Number of random portfolios to simulate.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    import numpy as np
    import plotly.graph_objects as go

    mu = np.asarray(expected_returns, dtype=float).ravel()
    sigma = np.asarray(cov_matrix, dtype=float)
    n_assets = len(mu)
    rng = np.random.default_rng(42)

    # One vectorised draw: each row is one portfolio's weight vector.
    all_weights = rng.dirichlet(np.ones(n_assets), size=n_portfolios)
    port_returns = all_weights @ mu
    # Row-wise quadratic form w' Sigma w without a Python loop.
    port_vols = np.sqrt(((all_weights @ sigma) * all_weights).sum(axis=1))

    # Return per unit of volatility (no risk-free rate is subtracted).
    sharpes = port_returns / port_vols

    # Build hover text with weights
    hover_texts = [
        "<br>".join(
            [
                f"Return: {ret:.2%}",
                f"Vol: {vol:.2%}",
                f"Sharpe: {sharpe:.2f}",
                *(f"w{j}: {w:.1%}" for j, w in enumerate(weights)),
            ]
        )
        for ret, vol, sharpe, weights in zip(
            port_returns, port_vols, sharpes, all_weights
        )
    ]

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=port_vols,
            y=port_returns,
            mode="markers",
            marker=dict(
                size=4,
                color=sharpes,
                colorscale="Viridis",
                colorbar=dict(title="Sharpe"),
                opacity=0.7,
            ),
            text=hover_texts,
            hoverinfo="text",
            name="Portfolios",
        )
    )

    # Mark max-Sharpe portfolio
    best = int(np.argmax(sharpes))
    fig.add_trace(
        go.Scatter(
            x=[port_vols[best]],
            y=[port_returns[best]],
            mode="markers",
            marker=dict(
                size=14,
                color=COLORS["negative"],
                symbol="star",
                line=dict(width=1, color="white"),
            ),
            name="Max Sharpe",
            text=[hover_texts[best]],
            hoverinfo="text",
        )
    )

    fig.update_layout(
        **_base_layout(
            title="Efficient Frontier (Monte Carlo)",
            xaxis_title="Volatility",
            yaxis_title="Expected Return",
            xaxis_tickformat=".1%",
            yaxis_tickformat=".1%",
        )
    )
    return fig
@requires_extra("viz")
def plotly_risk_return_scatter(
    returns_df: pd.DataFrame,
    *,
    periods_per_year: int = 252,
) -> go.Figure:
    """Risk-return scatter plot with asset labels and hover details.

    Each asset is one labelled marker positioned by annualized volatility
    (x) and annualized mean return (y), coloured by the ratio of the two.
    Hovering a marker shows the asset name, return and volatility.

    Parameters:
        returns_df: DataFrame of asset returns (columns = assets).
        periods_per_year: Number of return observations per year used for
            annualization (default 252, i.e. daily data).

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    import numpy as np
    import plotly.graph_objects as go

    # Mean return scales linearly with time, volatility with its square root.
    ann_ret = returns_df.mean() * periods_per_year
    ann_vol = returns_df.std() * np.sqrt(periods_per_year)
    sharpe = ann_ret / ann_vol

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=ann_vol.values,
            y=ann_ret.values,
            mode="markers+text",
            text=returns_df.columns.tolist(),
            textposition="top center",
            marker=dict(
                size=12,
                color=sharpe.values,
                colorscale="Viridis",
                colorbar=dict(title="Sharpe"),
                line=dict(width=1, color="white"),
            ),
            # An empty <extra> tag suppresses Plotly's secondary hover box.
            hovertemplate=(
                "<b>%{text}</b><br>"
                "Return: %{y:.2%}<br>"
                "Volatility: %{x:.2%}<br>"
                "<extra></extra>"
            ),
        )
    )

    fig.update_layout(
        **_base_layout(
            title="Risk-Return Scatter",
            xaxis_title="Annualized Volatility",
            yaxis_title="Annualized Return",
            xaxis_tickformat=".1%",
            yaxis_tickformat=".1%",
            showlegend=False,
        )
    )
    return fig