Source code for wraquant.viz.portfolio

"""Portfolio-related visualizations.

Portfolio weight charts, efficient frontier, risk contributions, and
correlation matrix heatmaps.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

from wraquant.core.decorators import requires_extra
from wraquant.viz.themes import COLORS, apply_theme

if TYPE_CHECKING:
    import matplotlib.axes
    import matplotlib.figure
    import numpy as np
    import numpy.typing as npt
    import pandas as pd


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _get_or_create_ax(
    ax: matplotlib.axes.Axes | None,
    figsize: tuple[float, float] = (12, 6),
) -> tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]:
    """Resolve the figure/axes pair a plotting function should draw on.

    Parameters:
        ax: Caller-supplied axes, or *None* to request a brand-new figure.
        figsize: Figure size in inches; only used when a new figure is made.

    Returns:
        A ``(figure, axes)`` tuple. A caller-supplied *ax* is returned
        untouched together with its parent figure (no theme is applied to
        it); otherwise a fresh figure/axes pair is created and styled with
        ``apply_theme``.
    """
    import matplotlib.pyplot as plt

    if ax is None:
        new_fig, new_ax = plt.subplots(figsize=figsize)
        apply_theme(new_fig, new_ax)
        return new_fig, new_ax

    # Respect the caller's axes: reuse its figure and leave styling alone.
    return ax.get_figure(), ax


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


@requires_extra("viz")
def plot_weights(
    weights: pd.Series | npt.NDArray[np.floating],
    names: list[str] | None = None,
    ax: matplotlib.axes.Axes | None = None,
) -> matplotlib.axes.Axes:
    """Plot portfolio weights as a horizontal bar chart.

    Non-negative weights are drawn in ``COLORS["positive"]`` and negative
    (short) weights in ``COLORS["negative"]``; a thin vertical line marks
    zero. Bars are placed at y positions ``0 .. n-1`` in input order, so with
    a default (non-inverted) y-axis the first asset appears at the bottom.

    Parameters:
        weights: One-dimensional portfolio weight vector (pandas Series or
            array-like). Values are coerced to ``float``.
        names: Asset names. If *weights* is a Series its index is used by
            default; otherwise sequential integers are used.
        ax: Matplotlib axes to plot on. A new figure is created when *None*.

    Returns:
        The matplotlib Axes containing the plot.

    Raises:
        ValueError: If *weights* is not one-dimensional, or if *names* is
            given and its length differs from the number of weights.
    """
    import numpy as np
    import pandas as pd

    # Coerce to float so object-dtype input (e.g. Decimal, mixed) still
    # supports the sign comparison below and plots as numbers.
    if isinstance(weights, pd.Series):
        values = weights.to_numpy(dtype=float)
    else:
        values = np.asarray(weights, dtype=float)
    if values.ndim != 1:
        raise ValueError(
            f"weights must be one-dimensional, got shape {values.shape}"
        )

    if names is not None:
        labels = list(names)
    elif isinstance(weights, pd.Series):
        labels = list(weights.index)
    else:
        labels = [str(i) for i in range(len(values))]
    if len(labels) != len(values):
        raise ValueError(
            f"names has {len(labels)} entries but weights has {len(values)}"
        )

    # Validate before creating a figure so a bad call leaves no empty figure.
    _, ax = _get_or_create_ax(ax)

    colors = [COLORS["positive"] if v >= 0 else COLORS["negative"] for v in values]
    positions = np.arange(len(values))
    ax.barh(positions, values, color=colors, edgecolor="white", height=0.6)
    ax.set_yticks(positions)
    ax.set_yticklabels(labels)
    ax.set_xlabel("Weight")
    ax.set_title("Portfolio Weights")
    ax.axvline(0, color=COLORS["neutral"], linewidth=0.8)
    return ax
@requires_extra("viz")
def plot_efficient_frontier(
    returns_range: npt.NDArray[np.floating],
    vol_range: npt.NDArray[np.floating],
    sharpe_range: npt.NDArray[np.floating] | None = None,
    optimal_point: tuple[float, float] | None = None,
    ax: matplotlib.axes.Axes | None = None,
) -> matplotlib.axes.Axes:
    """Draw the efficient frontier in volatility/return space.

    Volatility is on the x-axis and expected return on the y-axis. When
    Sharpe ratios are supplied each portfolio is drawn as a point colored on
    the ``viridis`` colormap (with a colorbar); otherwise the portfolios are
    joined by a single line in the input order.

    Parameters:
        returns_range: Expected return of each candidate portfolio.
        vol_range: Volatility of each candidate portfolio.
        sharpe_range: Optional Sharpe ratio per portfolio, used to color
            the scatter points.
        optimal_point: Optional ``(volatility, return)`` pair highlighted
            with a star marker and a legend entry "Optimal".
        ax: Matplotlib axes to plot on. A new figure is created when *None*.

    Returns:
        The matplotlib Axes containing the plot.
    """
    fig, ax = _get_or_create_ax(ax)

    if sharpe_range is None:
        ax.plot(vol_range, returns_range, color=COLORS["primary"], linewidth=1.5)
    else:
        sharpe_points = ax.scatter(
            vol_range,
            returns_range,
            c=sharpe_range,
            cmap="viridis",
            s=10,
            alpha=0.7,
        )
        fig.colorbar(sharpe_points, ax=ax, label="Sharpe Ratio", shrink=0.8)

    if optimal_point is not None:
        # Index rather than unpack: only the first two entries are used.
        star_x = [optimal_point[0]]
        star_y = [optimal_point[1]]
        ax.scatter(
            star_x,
            star_y,
            color=COLORS["negative"],
            marker="*",
            s=200,
            zorder=5,
            label="Optimal",
        )
        ax.legend()

    ax.set_title("Efficient Frontier")
    ax.set_xlabel("Volatility")
    ax.set_ylabel("Expected Return")
    return ax
@requires_extra("viz")
def plot_risk_contribution(
    contributions: pd.Series | npt.NDArray[np.floating],
    names: list[str] | None = None,
    ax: matplotlib.axes.Axes | None = None,
) -> matplotlib.axes.Axes:
    """Plot per-asset risk contributions as a vertical bar chart.

    Each asset gets its own bar, colored by cycling through the ``tab10``
    colormap. Bars sit at integer positions ``0 .. n-1`` and the asset names
    are applied as tick labels. As a result, duplicate names are not merged
    into one bar, and numeric or date labels are not used as x coordinates.

    Parameters:
        contributions: One-dimensional risk contribution per asset (pandas
            Series or array-like). Values are coerced to ``float``.
        names: Asset names. If *contributions* is a Series its index is
            used by default; otherwise sequential integers are used.
        ax: Matplotlib axes to plot on. A new figure is created when *None*.

    Returns:
        The matplotlib Axes containing the plot.

    Raises:
        ValueError: If *contributions* is not one-dimensional, or if *names*
            is given and its length differs from the number of assets.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd

    if isinstance(contributions, pd.Series):
        values = contributions.to_numpy(dtype=float)
    else:
        values = np.asarray(contributions, dtype=float)
    if values.ndim != 1:
        raise ValueError(
            f"contributions must be one-dimensional, got shape {values.shape}"
        )

    if names is not None:
        labels = list(names)
    elif isinstance(contributions, pd.Series):
        labels = list(contributions.index)
    else:
        labels = [str(i) for i in range(len(values))]
    if len(labels) != len(values):
        raise ValueError(
            f"names has {len(labels)} entries but contributions has {len(values)}"
        )

    # Validate before creating a figure so a bad call leaves no empty figure.
    _, ax = _get_or_create_ax(ax)

    cmap = plt.cm.tab10  # type: ignore[attr-defined]
    bar_colors = [cmap(i % 10) for i in range(len(values))]

    # Positional bars + tick labels instead of ax.bar(labels, ...): passing
    # labels as x values would make matplotlib treat strings as categories
    # (merging duplicates) and numbers/dates as coordinates.
    positions = np.arange(len(values))
    ax.bar(positions, values, color=bar_colors, edgecolor="white")
    ax.set_xticks(positions)
    ax.set_xticklabels(labels)
    ax.set_title("Risk Contribution by Asset")
    ax.set_ylabel("Risk Contribution")
    ax.set_xlabel("")
    # Rotate labels if there are many assets
    if len(labels) > 6:
        ax.tick_params(axis="x", rotation=45)
    return ax
@requires_extra("viz")
def plot_correlation_matrix(
    corr_matrix: pd.DataFrame | npt.NDArray[np.floating],
    labels: list[str] | None = None,
    ax: matplotlib.axes.Axes | None = None,
) -> matplotlib.axes.Axes:
    """Plot a correlation matrix as an annotated heatmap.

    Cells are colored with the diverging ``RdBu_r`` colormap on a fixed
    ``[-1, 1]`` scale and annotated with their value to two decimals.
    Annotation text switches to white where ``|value| > 0.6``, for contrast
    against the saturated ends of the colormap. A colorbar is added to the
    figure.

    Parameters:
        corr_matrix: Square correlation matrix (DataFrame or 2-D array-like).
            Values are coerced to ``float``.
        labels: Axis labels. If *corr_matrix* is a DataFrame its columns are
            used by default; otherwise sequential integers are used.
        ax: Matplotlib axes to plot on. A new figure is created when *None*.

    Returns:
        The matplotlib Axes containing the plot.

    Raises:
        ValueError: If *corr_matrix* is not a square 2-D matrix, or if
            *labels* is given and its length differs from the matrix size.
    """
    import numpy as np
    import pandas as pd

    if isinstance(corr_matrix, pd.DataFrame):
        data = corr_matrix.to_numpy(dtype=float)
    else:
        data = np.asarray(corr_matrix, dtype=float)

    # Reject non-square input up front. Otherwise the annotation loop below
    # would raise IndexError or silently skip columns.
    if data.ndim != 2 or data.shape[0] != data.shape[1]:
        raise ValueError(
            f"corr_matrix must be a square 2-D matrix, got shape {data.shape}"
        )
    n = data.shape[0]

    if labels is not None:
        tick_labels = list(labels)
    elif isinstance(corr_matrix, pd.DataFrame):
        tick_labels = list(corr_matrix.columns)
    else:
        tick_labels = [str(i) for i in range(n)]
    if len(tick_labels) != n:
        raise ValueError(
            f"labels has {len(tick_labels)} entries but corr_matrix is {n}x{n}"
        )

    # Validate before creating a figure so a bad call leaves no empty figure.
    fig, ax = _get_or_create_ax(ax, figsize=(8, 7))

    im = ax.imshow(data, cmap="RdBu_r", vmin=-1, vmax=1, aspect="auto")
    ticks = np.arange(n)
    ax.set_xticks(ticks)
    ax.set_xticklabels(tick_labels, rotation=45, ha="right")
    ax.set_yticks(ticks)
    ax.set_yticklabels(tick_labels)

    # Annotate cells; imshow puts row i at y=i and column j at x=j.
    for (i, j), val in np.ndenumerate(data):
        ax.text(
            j,
            i,
            f"{val:.2f}",
            ha="center",
            va="center",
            fontsize=8,
            color="white" if abs(val) > 0.6 else "black",
        )

    ax.set_title("Correlation Matrix")
    fig.colorbar(im, ax=ax, shrink=0.8)
    return ax