# Source code for wraquant.viz.advanced

"""Advanced and unconventional Plotly financial visualizations.

Regime overlays, 3-D volatility surfaces, animated yield curves, copula
scatters, network graphs, Sankey rebalancing flows, treemaps, and radar
charts --- the *wacky* side of quant viz.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

from wraquant.core.decorators import requires_extra
from wraquant.viz.themes import COLORS

if TYPE_CHECKING:
    import numpy as np
    import numpy.typing as npt
    import pandas as pd
    import plotly.graph_objects as go

# Public API of this module: the ``plotly_*`` figure builders defined below.
# Each one is annotated to return a ``plotly.graph_objects.Figure``.
__all__ = [
    "plotly_regime_overlay",
    "plotly_vol_surface",
    "plotly_term_structure",
    "plotly_copula_scatter",
    "plotly_network_graph",
    "plotly_sankey_flow",
    "plotly_treemap",
    "plotly_radar",
]


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

_PLOTLY_TEMPLATE = "plotly_white"

_REGIME_COLORS = [
    "rgba(31, 119, 180, 0.18)",   # blue
    "rgba(255, 127, 14, 0.18)",   # orange
    "rgba(44, 160, 44, 0.18)",    # green
    "rgba(214, 39, 40, 0.18)",    # red
    "rgba(148, 103, 189, 0.18)",  # purple
    "rgba(140, 86, 75, 0.18)",    # brown
    "rgba(227, 119, 194, 0.18)",  # pink
    "rgba(188, 189, 34, 0.18)",   # olive
]

_REGIME_COLORS_SOLID = [
    "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728",
    "#9467bd", "#8c564b", "#e377c2", "#bcbd22",
]


def _base_layout(**overrides: object) -> dict:
    """Build the shared wraquant layout kwargs for ``Figure.update_layout``.

    Parameters:
        **overrides: Layout keys that replace (or extend) the defaults.
            A key that overrides a default keeps the default's position.

    Returns:
        A fresh ``dict`` of layout keyword arguments.
    """
    return {
        "template": _PLOTLY_TEMPLATE,
        "font": {"family": "sans-serif", "size": 12, "color": "#333333"},
        "plot_bgcolor": "white",
        "paper_bgcolor": "white",
        "hovermode": "closest",
        "margin": {"l": 60, "r": 30, "t": 50, "b": 50},
        **overrides,
    }


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


@requires_extra("viz")
def plotly_regime_overlay(
    prices: pd.Series,
    regime_labels: pd.Series,
) -> go.Figure:
    """Price chart with colored background bands for market regimes.

    Each run of consecutive observations that share a regime label is
    shaded with a translucent band behind the price line. One legend
    entry is added per distinct regime.

    Parameters:
        prices: Price or level time series.
        regime_labels: Integer series (same index) indicating the regime
            at each observation. Labels are mapped to colors with
            ``int(label) % 8``, so they must be convertible to ``int``.

    Returns:
        A ``plotly.graph_objects.Figure``. If either series is empty, the
        figure contains only the price line and no bands or regime legend.
    """
    import plotly.graph_objects as go

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=prices.index,
            y=prices.values,
            mode="lines",
            name=prices.name or "Price",
            line=dict(color=COLORS["primary"], width=1.8),
        )
    )

    def _band(x0: object, x1: object, regime: object) -> dict:
        """Return a full-height, borderless rectangle spanning x0..x1."""
        return dict(
            type="rect",
            xref="x",
            yref="paper",
            x0=x0,
            x1=x1,
            y0=0,
            y1=1,
            fillcolor=_REGIME_COLORS[int(regime) % len(_REGIME_COLORS)],
            line=dict(width=0),
            layer="below",
        )

    # Guard: reading element 0 of an empty series would raise IndexError.
    if len(prices) > 0 and len(regime_labels) > 0:
        shapes: list[dict] = []
        prev = regime_labels.iloc[0]
        start = prices.index[0]
        for idx, regime in regime_labels.iloc[1:].items():
            if regime != prev:
                # Close the current run at the first bar of the next regime.
                shapes.append(_band(start, idx, prev))
                start = idx
                prev = regime
        # Final span runs to the last price observation.
        shapes.append(_band(start, prices.index[-1], prev))
        # Assign every band in one call rather than one add_vrect per band.
        # Each add_* call rebuilds the layout's shapes tuple, which gets slow
        # when the regime switches frequently.
        fig.update_layout(shapes=shapes)

        # Invisible traces so each regime gets a legend entry.
        for r in sorted(regime_labels.unique()):
            fig.add_trace(
                go.Scatter(
                    x=[None],
                    y=[None],
                    mode="markers",
                    marker=dict(
                        size=12,
                        color=_REGIME_COLORS_SOLID[
                            int(r) % len(_REGIME_COLORS_SOLID)
                        ],
                    ),
                    name=f"Regime {r}",
                )
            )

    fig.update_layout(
        **_base_layout(
            title="Price with Regime Overlay",
            yaxis_title="Price",
            showlegend=True,
        )
    )
    return fig
@requires_extra("viz")
def plotly_vol_surface(
    strikes: npt.NDArray[np.floating],
    expiries: npt.NDArray[np.floating],
    implied_vols: npt.NDArray[np.floating],
) -> go.Figure:
    """3-D implied volatility surface.

    Parameters:
        strikes: 1-D array (or sequence) of strike prices.
        expiries: 1-D array (or sequence) of expiries (e.g. days to
            expiry or years).
        implied_vols: 2-D array of shape ``(len(expiries), len(strikes))``
            containing implied volatilities. The hover label formats them
            as percentages, so ``0.2`` is shown as ``20.00%``.

    Returns:
        A ``plotly.graph_objects.Figure`` with a 3-D surface.

    Raises:
        ValueError: If *implied_vols* does not have shape
            ``(len(expiries), len(strikes))``.
    """
    import numpy as np
    import plotly.graph_objects as go

    strikes = np.asarray(strikes)
    expiries = np.asarray(expiries)
    implied_vols = np.asarray(implied_vols)

    # A transposed or mismatched grid would otherwise render a wrong surface
    # silently instead of failing.
    expected_shape = (expiries.size, strikes.size)
    if implied_vols.shape != expected_shape:
        raise ValueError(
            "implied_vols must have shape (len(expiries), len(strikes)) = "
            f"{expected_shape}, got {implied_vols.shape}"
        )

    # Both grids have shape (len(expiries), len(strikes)), matching the
    # row-per-expiry layout of implied_vols.
    strike_grid, expiry_grid = np.meshgrid(strikes, expiries)

    fig = go.Figure(
        data=go.Surface(
            x=strike_grid,
            y=expiry_grid,
            z=implied_vols,
            colorscale="Plasma",
            colorbar=dict(title="IV"),
            hovertemplate=(
                "Strike: %{x:.1f}<br>"
                "Expiry: %{y:.2f}<br>"
                "IV: %{z:.2%}<br>"
                "<extra></extra>"
            ),
        )
    )
    fig.update_layout(
        title="Implied Volatility Surface",
        scene=dict(
            xaxis_title="Strike",
            yaxis_title="Expiry",
            zaxis_title="Implied Vol",
            camera=dict(eye=dict(x=1.6, y=-1.6, z=0.8)),
        ),
        template=_PLOTLY_TEMPLATE,
        width=800,
        height=650,
    )
    return fig
@requires_extra("viz")
def plotly_term_structure(
    maturities: npt.NDArray[np.floating],
    yields: npt.NDArray[np.floating],
    dates: list[str] | None = None,
) -> go.Figure:
    """Animated yield curve through time.

    Each row of *yields* is a snapshot of the yield curve at one date.
    The animation steps through dates. A slider jumps to a date and
    Play/Pause buttons control playback. The y-axis is fixed to the range
    of all curves, so level moves stay visible from frame to frame.

    Parameters:
        maturities: 1-D array of maturities (e.g. years).
        yields: 2-D array of shape ``(n_dates, len(maturities))``. A 1-D
            array is treated as a single curve (one frame).
        dates: Optional list of date labels for the slider, one per row
            of *yields*. Labels are converted with ``str`` because frames
            are addressed by name.

    Returns:
        A ``plotly.graph_objects.Figure`` with animation frames.

    Raises:
        ValueError: If *yields* has no rows, its row length differs from
            ``len(maturities)``, or *dates* has a different length than
            the number of rows.
    """
    import numpy as np
    import plotly.graph_objects as go

    maturities = np.asarray(maturities)
    yields = np.atleast_2d(np.asarray(yields, dtype=float))
    n_dates = yields.shape[0]
    if n_dates == 0:
        raise ValueError("yields must contain at least one curve")
    if yields.shape[1] != maturities.size:
        raise ValueError(
            f"each yield curve must have {maturities.size} points "
            f"(len(maturities)), got {yields.shape[1]}"
        )

    if dates is None:
        labels = [f"t={i}" for i in range(n_dates)]
    else:
        if len(dates) != n_dates:
            raise ValueError(
                f"dates has {len(dates)} labels but yields has {n_dates} rows"
            )
        labels = [str(d) for d in dates]

    # Initial frame
    fig = go.Figure(
        data=go.Scatter(
            x=maturities,
            y=yields[0],
            mode="lines+markers",
            line=dict(color=COLORS["primary"], width=2.5),
            marker=dict(size=6),
            name="Yield Curve",
        )
    )

    # One animation frame per date, addressed by its label.
    fig.frames = [
        go.Frame(
            data=[
                go.Scatter(
                    x=maturities,
                    y=curve,
                    mode="lines+markers",
                    line=dict(color=COLORS["primary"], width=2.5),
                    marker=dict(size=6),
                )
            ],
            name=label,
        )
        for curve, label in zip(yields, labels, strict=True)
    ]

    # Plotly does not compute a union of ranges across animation frames.
    # Pin the y-axis to the span of every curve so the scale does not shift
    # or clip between dates.
    range_kwargs: dict = {}
    finite = yields[np.isfinite(yields)]
    if finite.size:
        lo, hi = float(finite.min()), float(finite.max())
        pad = 0.05 * (hi - lo) or max(0.05 * abs(hi), 1e-4)
        range_kwargs["yaxis_range"] = [lo - pad, hi + pad]

    # Slider and buttons
    sliders = [
        dict(
            active=0,
            steps=[
                dict(
                    args=[
                        [label],
                        dict(frame=dict(duration=200, redraw=True), mode="immediate"),
                    ],
                    label=label,
                    method="animate",
                )
                for label in labels
            ],
            currentvalue=dict(prefix="Date: "),
        )
    ]

    fig.update_layout(
        **_base_layout(
            title="Yield Curve Term Structure",
            xaxis_title="Maturity (Years)",
            yaxis_title="Yield",
            yaxis_tickformat=".2%",
            sliders=sliders,
            updatemenus=[
                dict(
                    type="buttons",
                    showactive=False,
                    buttons=[
                        dict(
                            label="Play",
                            method="animate",
                            args=[None, dict(frame=dict(duration=200), fromcurrent=True)],
                        ),
                        dict(
                            label="Pause",
                            method="animate",
                            args=[[None], dict(frame=dict(duration=0), mode="immediate")],
                        ),
                    ],
                    x=0.05,
                    y=1.12,
                )
            ],
            **range_kwargs,
        )
    )
    return fig
@requires_extra("viz")
def plotly_copula_scatter(
    u: npt.NDArray[np.floating],
    v: npt.NDArray[np.floating],
    copula_type: str = "empirical",
) -> go.Figure:
    """Copula scatter plot with marginal histograms.

    The joint sample ``(u, v)`` fills the large bottom-left panel. The
    histogram of *u* sits above it and the histogram of *v* to its right.
    Both are placed in a 2x2 grid with shared x and y axes.

    Parameters:
        u: 1-D array of uniform marginals for the first variable.
        v: 1-D array of uniform marginals for the second variable.
        copula_type: Label for the copula (used in the title).

    Returns:
        A ``plotly.graph_objects.Figure`` with marginal histograms.
    """
    import plotly.graph_objects as go
    from plotly.subplots import make_subplots

    # Narrow top row and right column hold the marginals; the top-right
    # cell of the grid carries no trace.
    fig = make_subplots(
        rows=2,
        cols=2,
        row_heights=[0.2, 0.8],
        column_widths=[0.8, 0.2],
        shared_xaxes=True,
        shared_yaxes=True,
        vertical_spacing=0.02,
        horizontal_spacing=0.02,
    )

    joint = go.Scatter(
        x=u,
        y=v,
        mode="markers",
        marker=dict(size=3, color=COLORS["primary"], opacity=0.5),
        name="Copula",
    )
    u_marginal = go.Histogram(
        x=u,
        nbinsx=40,
        marker_color=COLORS["primary"],
        opacity=0.6,
        showlegend=False,
    )
    v_marginal = go.Histogram(
        y=v,
        nbinsy=40,
        marker_color=COLORS["secondary"],
        opacity=0.6,
        showlegend=False,
    )

    # Placement order matters only for fig.data ordering: scatter first.
    for trace, row, col in (
        (joint, 2, 1),
        (u_marginal, 1, 1),
        (v_marginal, 2, 2),
    ):
        fig.add_trace(trace, row=row, col=col)

    # Restrict the joint panel to the unit square.
    fig.update_xaxes(range=[0, 1], row=2, col=1)
    fig.update_yaxes(range=[0, 1], row=2, col=1)

    layout = _base_layout(
        title=f"Copula Scatter ({copula_type})",
        showlegend=False,
        height=600,
        width=650,
    )
    fig.update_layout(**layout)
    return fig
@requires_extra("viz")
def plotly_network_graph(
    correlation_matrix: npt.NDArray[np.floating] | pd.DataFrame,
    threshold: float = 0.5,
) -> go.Figure:
    """Asset correlation network graph.

    Assets are placed evenly on a circle. An edge joins two assets whose
    absolute correlation is at least *threshold*. Edges use
    ``COLORS["positive"]`` for positive correlation and
    ``COLORS["negative"]`` otherwise. Node size grows with the number of
    edges (degree), with a minimum marker size of 12.

    Parameters:
        correlation_matrix: Square correlation matrix (array or DataFrame).
            DataFrame column names become node labels; otherwise nodes are
            labelled ``"0"``, ``"1"``, ...
        threshold: Minimum absolute correlation (inclusive) to draw an edge.

    Returns:
        A ``plotly.graph_objects.Figure``.

    Raises:
        ValueError: If *correlation_matrix* is not a square 2-D matrix.
    """
    import math

    import numpy as np
    import pandas as pd
    import plotly.graph_objects as go

    if isinstance(correlation_matrix, pd.DataFrame):
        labels = list(correlation_matrix.columns)
        corr = correlation_matrix.to_numpy()
    else:
        corr = np.asarray(correlation_matrix)
        labels = None
    if corr.ndim != 2 or corr.shape[0] != corr.shape[1]:
        raise ValueError(
            f"correlation_matrix must be a square 2-D matrix, got shape {corr.shape}"
        )
    if labels is None:
        labels = [str(i) for i in range(corr.shape[0])]
    n = len(labels)

    # Circular layout
    angles = [2 * math.pi * i / n for i in range(n)]
    x_nodes = [math.cos(a) for a in angles]
    y_nodes = [math.sin(a) for a in angles]

    # Collect edge segments per sign. A None between segments breaks the
    # polyline, so one trace can hold many disjoint edges.
    pos_x: list[float | None] = []
    pos_y: list[float | None] = []
    neg_x: list[float | None] = []
    neg_y: list[float | None] = []
    degree = [0] * n
    for i in range(n):
        for j in range(i + 1, n):
            c = corr[i, j]
            if abs(c) >= threshold:
                xs, ys = (pos_x, pos_y) if c > 0 else (neg_x, neg_y)
                xs.extend((x_nodes[i], x_nodes[j], None))
                ys.extend((y_nodes[i], y_nodes[j], None))
                degree[i] += 1
                degree[j] += 1

    fig = go.Figure()

    # One trace per sign rather than one per edge. The picture is the same,
    # but a 50-asset matrix no longer yields up to 1,225 separate traces.
    for xs, ys, color in (
        (pos_x, pos_y, COLORS["positive"]),
        (neg_x, neg_y, COLORS["negative"]),
    ):
        if xs:
            fig.add_trace(
                go.Scatter(
                    x=xs,
                    y=ys,
                    mode="lines",
                    line=dict(width=1.5, color=color),
                    hoverinfo="none",
                    showlegend=False,
                )
            )

    # Draw nodes
    node_sizes = [max(12, 8 + d * 4) for d in degree]
    fig.add_trace(
        go.Scatter(
            x=x_nodes,
            y=y_nodes,
            mode="markers+text",
            marker=dict(
                size=node_sizes,
                color=COLORS["primary"],
                line=dict(width=1.5, color="white"),
            ),
            text=labels,
            textposition="top center",
            textfont=dict(size=11),
            hovertemplate="<b>%{text}</b><br>Connections: %{customdata}<extra></extra>",
            customdata=degree,
            name="Assets",
        )
    )

    fig.update_layout(
        **_base_layout(
            title=f"Correlation Network (|corr| >= {threshold})",
            showlegend=False,
            xaxis=dict(visible=False),
            yaxis=dict(visible=False, scaleanchor="x"),
            width=650,
            height=650,
        )
    )
    return fig
@requires_extra("viz")
def plotly_sankey_flow(
    sectors: list[str],
    allocations_before: list[float],
    allocations_after: list[float],
) -> go.Figure:
    """Sankey diagram showing portfolio rebalancing flows.

    Left nodes show the *before* weights and right nodes the *after*
    weights. Two kinds of links connect them:

    * Retained weight ``min(before, after)`` runs from each sector's
      before-node to its own after-node. It is blue if the sector's weight
      grew or stayed flat, and red if it shrank.
    * Transferred weight: each shrinking sector's reduction
      ``before - after`` is split across the growing sectors in proportion
      to their increase ``after - before`` (gray links).

    With the transfer links, each before-node's outflow equals its before
    weight. Each after-node's inflow equals its after weight whenever both
    allocations have the same total; otherwise the increases are scaled by
    ``total reduction / total increase``. Weights are assumed to be
    non-negative.

    Parameters:
        sectors: Sector / asset names.
        allocations_before: Weights before rebalancing.
        allocations_after: Weights after rebalancing.

    Returns:
        A ``plotly.graph_objects.Figure``.

    Raises:
        ValueError: If the three inputs differ in length.
    """
    import plotly.graph_objects as go

    n = len(sectors)
    if not len(allocations_before) == len(allocations_after) == n:
        raise ValueError(
            "sectors, allocations_before and allocations_after must have the "
            f"same length, got {n}, {len(allocations_before)} and "
            f"{len(allocations_after)}"
        )

    # Nodes: left side (before) + right side (after)
    node_labels = [f"{s} (before)" for s in sectors] + [
        f"{s} (after)" for s in sectors
    ]
    node_colors = [COLORS["primary"]] * n + [COLORS["secondary"]] * n

    pairs = list(zip(allocations_before, allocations_after, strict=True))
    sources: list[int] = []
    targets: list[int] = []
    values: list[float] = []
    link_colors: list[str] = []

    # Retained weight: each sector keeps min(before, after) of its allocation.
    for i, (b, a) in enumerate(pairs):
        sources.append(i)
        targets.append(n + i)
        values.append(min(b, a))
        link_colors.append(
            "rgba(31, 119, 180, 0.3)" if a >= b else "rgba(214, 39, 40, 0.3)"
        )

    # Transferred weight. Without these links, a growing sector's after-node
    # would only be as large as its before weight, because a Sankey node is
    # sized by its link flow.
    sells = [max(b - a, 0.0) for b, a in pairs]
    buys = [max(a - b, 0.0) for b, a in pairs]
    total_buy = sum(buys)
    if total_buy > 0:
        for i, sell in enumerate(sells):
            if sell <= 0:
                continue
            for j, buy in enumerate(buys):
                if buy > 0:
                    sources.append(i)
                    targets.append(n + j)
                    values.append(sell * buy / total_buy)
                    link_colors.append("rgba(150, 150, 150, 0.3)")

    fig = go.Figure(
        data=go.Sankey(
            node=dict(
                pad=15,
                thickness=20,
                label=node_labels,
                color=node_colors,
            ),
            link=dict(
                source=sources,
                target=targets,
                value=values,
                color=link_colors,
            ),
        )
    )

    fig.update_layout(
        title="Portfolio Rebalancing Flow",
        template=_PLOTLY_TEMPLATE,
        font=dict(size=11),
        height=500,
    )
    return fig
@requires_extra("viz")
def plotly_treemap(
    weights: list[float],
    sectors: list[str],
    returns: list[float],
) -> go.Figure:
    """Portfolio treemap with tiles sized by weight and colored by return.

    Parameters:
        weights: Portfolio weight per asset/sector (tile size).
        sectors: Sector or asset labels.
        returns: Period return per asset/sector (tile color). The diverging
            color scale is centered on zero and spans +/- the largest
            absolute finite return (or +/-1% if there is none).

    Returns:
        A ``plotly.graph_objects.Figure``.

    Raises:
        ValueError: If the three inputs differ in length.
    """
    import math

    import plotly.graph_objects as go

    if not len(weights) == len(sectors) == len(returns):
        raise ValueError(
            "weights, sectors and returns must have the same length, got "
            f"{len(weights)}, {len(sectors)} and {len(returns)}"
        )

    hover_text = [
        f"<b>{s}</b><br>Weight: {w:.1%}<br>Return: {r:.2%}"
        for s, w, r in zip(sectors, weights, returns, strict=True)
    ]

    # Symmetric color limits. Skip NaN/inf, which would poison max(), and
    # fall back to 1% for empty or all-zero returns.
    max_abs = (
        max((abs(r) for r in returns if math.isfinite(r)), default=0.0) or 0.01
    )

    fig = go.Figure(
        go.Treemap(
            labels=sectors,
            parents=[""] * len(sectors),
            values=weights,
            marker=dict(
                colors=returns,
                colorscale="RdYlGn",
                cmid=0,
                cmin=-max_abs,
                cmax=max_abs,
                colorbar=dict(title="Return"),
                line=dict(width=2, color="white"),
            ),
            text=hover_text,
            hoverinfo="text",
            textinfo="label+percent parent",
        )
    )

    fig.update_layout(
        title="Portfolio Treemap",
        template=_PLOTLY_TEMPLATE,
        margin=dict(l=10, r=10, t=50, b=10),
        height=550,
    )
    return fig
def _to_fill_color(color: str, alpha: float = 0.15) -> str: """Convert a color string to an rgba fill color.""" if color.startswith("rgb("): return color.replace("rgb(", "rgba(").replace(")", f", {alpha})") if color.startswith("#") and len(color) == 7: r, g, b = int(color[1:3], 16), int(color[3:5], 16), int(color[5:7], 16) return f"rgba({r}, {g}, {b}, {alpha})" return f"rgba(100, 100, 100, {alpha})"
@requires_extra("viz")
def plotly_radar(
    metrics_dict: dict[str, dict[str, float]],
) -> go.Figure:
    """Radar / spider chart comparing portfolio metrics.

    The axes (categories) are taken from the first portfolio's metric
    names, in insertion order. Each portfolio is drawn as a closed, filled
    polygon, with colors cycling through a six-color palette.

    Parameters:
        metrics_dict: Mapping from portfolio name to a dict of
            ``{metric_name: value}``. All portfolios must share the same
            metric names.

    Returns:
        A ``plotly.graph_objects.Figure``.

    Raises:
        ValueError: If *metrics_dict* is empty or its first portfolio has
            no metrics.
        KeyError: If a portfolio lacks one of the first portfolio's metrics.
    """
    import plotly.graph_objects as go

    # Fail with a clear message instead of a bare StopIteration/IndexError.
    if not metrics_dict:
        raise ValueError("metrics_dict must contain at least one portfolio")

    palette = [
        COLORS["primary"],
        COLORS["secondary"],
        COLORS["positive"],
        COLORS["accent"],
        COLORS["info"],
        COLORS["warning"],
    ]

    # All metric names (from first portfolio)
    first_key = next(iter(metrics_dict))
    categories = list(metrics_dict[first_key].keys())
    if not categories:
        raise ValueError(f"portfolio {first_key!r} has no metrics to plot")

    fig = go.Figure()

    for i, (name, metrics) in enumerate(metrics_dict.items()):
        values = [metrics[c] for c in categories]
        # Repeat the first point so the polygon closes.
        values_closed = values + [values[0]]
        cats_closed = categories + [categories[0]]
        color = palette[i % len(palette)]

        fig.add_trace(
            go.Scatterpolar(
                r=values_closed,
                theta=cats_closed,
                fill="toself",
                fillcolor=_to_fill_color(color),
                line=dict(color=color, width=2),
                name=name,
            )
        )

    fig.update_layout(
        polar=dict(
            radialaxis=dict(visible=True, showticklabels=True),
        ),
        title="Portfolio Metrics Comparison",
        template=_PLOTLY_TEMPLATE,
        showlegend=True,
        height=550,
        width=600,
    )
    return fig