Source code for wraquant.bayes.integrations

"""External package wrappers for Bayesian analysis.

Functions in this module require the ``bayes`` optional dependency group
(PyMC, ArviZ, NumPyro) and are guarded by ``@requires_extra('bayes')``.
"""

from __future__ import annotations

from typing import Any

import numpy as np
import pandas as pd

from wraquant.core.decorators import requires_extra

# Public API of this module: the names exported by
# ``from wraquant.bayes.integrations import *``. Every entry is a wrapper
# function decorated with ``@requires_extra("bayes")``.
__all__ = [
    "pymc_regression",
    "arviz_summary",
    "numpyro_regression",
    "bambi_regression",
    "emcee_sample",
    "blackjax_nuts",
]


# ---------------------------------------------------------------------------
# PyMC Bayesian regression
# ---------------------------------------------------------------------------


@requires_extra("bayes")
def pymc_regression(
    y: np.ndarray | pd.Series,
    X: np.ndarray | pd.DataFrame,
    samples: int = 2_000,
    chains: int = 2,
    target_accept: float = 0.9,
    random_seed: int = 42,
) -> dict[str, Any]:
    """Bayesian linear regression using PyMC.

    Fits the model ``y ~ Normal(X @ beta, sigma)`` with weakly informative
    priors ``beta ~ Normal(0, 10)`` and ``sigma ~ HalfNormal(5)``.

    Parameters
    ----------
    y : array-like
        Response variable; flattened to 1-D.
    X : array-like
        Design matrix of shape ``(n, k)``; a 1-D input is treated as a single
        column. An intercept column is prepended automatically, so the
        returned coefficients have length ``k + 1`` with the intercept first.
    samples : int
        Number of posterior draws per chain.
    chains : int
        Number of MCMC chains.
    target_accept : float
        Target acceptance rate for NUTS.
    random_seed : int
        Random seed for reproducibility.

    Returns
    -------
    dict
        ``trace``: InferenceData returned by ``pm.sample``,
        ``coefficients_mean``: np.ndarray of posterior mean coefficients,
        ``coefficients_std``: np.ndarray of posterior std of coefficients,
        ``sigma_mean``: float — posterior mean of the noise std,
        ``model``: the PyMC Model object.

    Raises
    ------
    ValueError
        If ``X`` has more than two dimensions, has no rows, or its row count
        does not match the length of ``y``.
    """
    import pymc as pm

    y_arr = np.asarray(y, dtype=float).ravel()
    X_arr = np.asarray(X, dtype=float)
    if X_arr.ndim == 1:
        X_arr = X_arr.reshape(-1, 1)
    # Validate shapes up front: otherwise a mismatch only surfaces as an
    # opaque shape error deep inside model construction or sampling.
    if X_arr.ndim != 2:
        raise ValueError(f"X must be 1-D or 2-D, got {X_arr.ndim} dimensions")
    n, k = X_arr.shape
    if n == 0:
        raise ValueError("X and y must contain at least one observation")
    if y_arr.shape[0] != n:
        raise ValueError(
            f"y has {y_arr.shape[0]} observations but X has {n} rows"
        )

    # Prepend the intercept column; beta[0] is the intercept.
    X_arr = np.column_stack([np.ones(n), X_arr])
    k_total = k + 1

    with pm.Model() as model:
        beta = pm.Normal("beta", mu=0, sigma=10, shape=k_total)
        sigma = pm.HalfNormal("sigma", sigma=5)
        mu = pm.math.dot(X_arr, beta)
        pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y_arr)
        trace = pm.sample(
            draws=samples,
            chains=chains,
            target_accept=target_accept,
            random_seed=random_seed,
            return_inferencedata=True,
        )

    # Pool all chains and draws into one sample axis.
    beta_samples = trace.posterior["beta"].values.reshape(-1, k_total)
    sigma_samples = trace.posterior["sigma"].values.ravel()

    return {
        "trace": trace,
        "coefficients_mean": np.mean(beta_samples, axis=0),
        "coefficients_std": np.std(beta_samples, axis=0),
        "sigma_mean": float(np.mean(sigma_samples)),
        "model": model,
    }
# ---------------------------------------------------------------------------
# ArviZ summary
# ---------------------------------------------------------------------------


@requires_extra("bayes")
def arviz_summary(
    trace: Any,
    var_names: list[str] | None = None,
    hdi_prob: float = 0.94,
) -> pd.DataFrame:
    """Generate a summary table from a trace using ArviZ.

    Works with ArviZ releases that name the interval-probability argument
    either ``hdi_prob`` or ``ci_prob``.

    Parameters
    ----------
    trace : InferenceData or dict
        ArviZ InferenceData object or a dict of arrays.
    var_names : list[str] or None
        Variables to include. If None, includes all.
    hdi_prob : float
        Probability mass for the credible (HDI) interval. Default is 0.94.

    Returns
    -------
    pd.DataFrame
        Summary table with mean, sd, credible-interval bounds, ESS, R-hat.
    """
    import inspect

    import arviz as az

    # Older ArviZ releases call the interval argument ``hdi_prob``; newer
    # releases call it ``ci_prob``. Choose the keyword from the signature
    # instead of retrying on TypeError. Retrying would hide a genuine
    # TypeError raised by ArviZ (e.g. for a bad ``trace``) behind a
    # misleading "unexpected keyword" error.
    try:
        accepted = inspect.signature(az.summary).parameters
    except (TypeError, ValueError):
        accepted = {}
    prob_kwarg = "ci_prob" if "ci_prob" in accepted else "hdi_prob"
    return az.summary(trace, var_names=var_names, **{prob_kwarg: hdi_prob})
# ---------------------------------------------------------------------------
# NumPyro regression
# ---------------------------------------------------------------------------


@requires_extra("bayes")
def numpyro_regression(
    y: np.ndarray | pd.Series,
    X: np.ndarray | pd.DataFrame,
    samples: int = 2_000,
    warmup: int = 500,
    chains: int = 1,
    rng_seed: int = 0,
) -> dict[str, Any]:
    """Bayesian linear regression using NumPyro.

    Fits the model ``y ~ Normal(X @ beta, sigma)`` with priors
    ``beta ~ Normal(0, 10)`` and ``sigma ~ HalfNormal(5)`` using NUTS
    sampling.

    Parameters
    ----------
    y : array-like
        Response variable; flattened to 1-D.
    X : array-like
        Design matrix of shape ``(n, k)``; a 1-D input is treated as a single
        column. An intercept column is prepended automatically, so the
        coefficients have length ``k + 1`` with the intercept first.
    samples : int
        Number of posterior samples per chain.
    warmup : int
        Number of warmup (burn-in) samples.
    chains : int
        Number of MCMC chains.
    rng_seed : int
        Random seed for reproducibility.

    Returns
    -------
    dict
        ``samples``: dict of posterior samples (NumPy arrays, chains pooled)
        keyed by parameter name,
        ``coefficients_mean``: np.ndarray of posterior mean coefficients,
        ``coefficients_std``: np.ndarray of posterior std of coefficients,
        ``sigma_mean``: float — posterior mean of the noise std.

    Raises
    ------
    ValueError
        If ``X`` has more than two dimensions, has no rows, or its row count
        does not match the length of ``y``.
    """
    import jax
    import jax.numpy as jnp
    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import MCMC, NUTS

    y_arr = np.asarray(y, dtype=float).ravel()
    X_arr = np.asarray(X, dtype=float)
    if X_arr.ndim == 1:
        X_arr = X_arr.reshape(-1, 1)
    # Validate shapes up front: otherwise a mismatch only surfaces as an
    # opaque broadcasting error while the model is being traced.
    if X_arr.ndim != 2:
        raise ValueError(f"X must be 1-D or 2-D, got {X_arr.ndim} dimensions")
    n, k = X_arr.shape
    if n == 0:
        raise ValueError("X and y must contain at least one observation")
    if y_arr.shape[0] != n:
        raise ValueError(
            f"y has {y_arr.shape[0]} observations but X has {n} rows"
        )

    # Prepend the intercept column; beta[0] is the intercept.
    X_arr = np.column_stack([np.ones(n), X_arr])
    k_total = k + 1

    X_jnp = jnp.asarray(X_arr)
    y_jnp = jnp.asarray(y_arr)

    def model(X_design: jnp.ndarray, y_obs: jnp.ndarray | None = None) -> None:
        beta = numpyro.sample("beta", dist.Normal(0, 10).expand([k_total]))
        sigma = numpyro.sample("sigma", dist.HalfNormal(5))
        mu = jnp.dot(X_design, beta)
        numpyro.sample("y_obs", dist.Normal(mu, sigma), obs=y_obs)

    mcmc = MCMC(
        NUTS(model), num_warmup=warmup, num_samples=samples, num_chains=chains
    )
    mcmc.run(jax.random.PRNGKey(rng_seed), X_jnp, y_jnp)

    # get_samples() pools all chains along the first axis by default.
    posterior = {
        name: np.asarray(draws) for name, draws in mcmc.get_samples().items()
    }
    beta_samples = posterior["beta"]
    sigma_samples = posterior["sigma"]

    return {
        "samples": posterior,
        "coefficients_mean": np.mean(beta_samples, axis=0),
        "coefficients_std": np.std(beta_samples, axis=0),
        "sigma_mean": float(np.mean(sigma_samples)),
    }
# ---------------------------------------------------------------------------
# Bambi regression
# ---------------------------------------------------------------------------


@requires_extra("bayes")
def bambi_regression(
    formula: str,
    data: pd.DataFrame,
    family: str = "gaussian",
    draws: int = 1000,
    chains: int = 2,
    seed: int | None = None,
    hdi_prob: float = 0.94,
) -> dict[str, Any]:
    """Fit a Bayesian regression using Bambi's formula interface.

    Bambi provides R-style formula syntax for Bayesian models built on PyMC.

    Parameters
    ----------
    formula : str
        Model formula string (e.g., ``"y ~ x1 + x2"``).
    data : pd.DataFrame
        DataFrame with variables referenced in formula.
    family : str, default "gaussian"
        Likelihood family (``"gaussian"``, ``"bernoulli"``, ``"poisson"``).
    draws : int, default 1000
        Number of posterior draws per chain.
    chains : int, default 2
        Number of MCMC chains.
    seed : int or None
        Random seed passed to ``model.fit``.
    hdi_prob : float, default 0.94
        Probability mass of the credible interval in the summary table.

    Returns
    -------
    dict
        ``model``: Bambi Model object,
        ``trace``: InferenceData posterior samples,
        ``summary``: pd.DataFrame ArviZ summary table.
    """
    import bambi as bmb

    model = bmb.Model(formula, data, family=family)
    trace = model.fit(draws=draws, chains=chains, random_seed=seed)
    # Reuse arviz_summary so the ArviZ ``hdi_prob``/``ci_prob`` compatibility
    # handling lives in a single place.
    summary = arviz_summary(trace, hdi_prob=hdi_prob)
    return {"model": model, "trace": trace, "summary": summary}
# ---------------------------------------------------------------------------
# emcee ensemble sampling
# ---------------------------------------------------------------------------


@requires_extra("bayes")
def emcee_sample(
    log_prob_fn: Any,
    n_walkers: int,
    n_dim: int,
    n_steps: int = 1000,
    initial: np.ndarray | None = None,
    seed: int | None = None,
) -> dict[str, Any]:
    """Run ensemble MCMC sampling using emcee.

    Parameters
    ----------
    log_prob_fn : callable
        Log probability function accepting a parameter array.
    n_walkers : int
        Number of walkers (must be >= 2 * n_dim).
    n_dim : int
        Number of parameters.
    n_steps : int, default 1000
        Number of MCMC steps.
    initial : np.ndarray or None
        Initial walker positions of shape ``(n_walkers, n_dim)``. If None,
        drawn from a standard normal distribution.
    seed : int or None
        Random seed. When given, it seeds both the default initial positions
        and the random state emcee uses to propose moves. Repeated calls with
        the same seed (and a deterministic ``log_prob_fn``) then reproduce
        the same chain.

    Returns
    -------
    dict
        ``samples``: array of shape ``(n_walkers * n_steps, n_dim)``,
        ``log_prob``: array of shape ``(n_walkers * n_steps,)`` of log
        probabilities,
        ``acceptance_fraction``: mean acceptance fraction across walkers.
    """
    import emcee

    rng = np.random.default_rng(seed)
    if initial is None:
        initial = rng.normal(0, 1, (n_walkers, n_dim))

    # emcee's moves draw from the sampler's own RandomState, not from ``rng``.
    # Seed that state explicitly; otherwise ``seed`` fixes only the starting
    # point, and the chain still differs between runs.
    rstate0 = None
    if seed is not None:
        rstate0 = np.random.RandomState(rng.integers(2**32)).get_state()

    sampler = emcee.EnsembleSampler(n_walkers, n_dim, log_prob_fn)
    sampler.run_mcmc(initial, n_steps, rstate0=rstate0, progress=False)

    return {
        "samples": sampler.get_chain(flat=True),
        "log_prob": sampler.get_log_prob(flat=True),
        "acceptance_fraction": float(np.mean(sampler.acceptance_fraction)),
    }
# ---------------------------------------------------------------------------
# BlackJAX NUTS sampling
# ---------------------------------------------------------------------------


@requires_extra("bayes")
def blackjax_nuts(
    log_prob_fn: Any,
    initial_position: Any,
    n_samples: int = 1000,
    step_size: float = 0.01,
    seed: int | None = None,
) -> dict[str, Any]:
    """Run NUTS sampling using BlackJAX (JAX-based).

    No warmup or step-size/mass-matrix adaptation is performed. The chain
    starts at ``initial_position`` with a fixed ``step_size`` and a unit
    diagonal inverse mass matrix, and every step is returned.

    Parameters
    ----------
    log_prob_fn : callable
        Log probability function (must be JAX-compatible), evaluated at
        positions structured like ``initial_position``.
    initial_position : array-like or pytree
        Initial parameter values: a scalar, an array of any shape, or a pytree
        (e.g. dict or list) of arrays.
    n_samples : int, default 1000
        Number of NUTS steps to run (one sample per step).
    step_size : float, default 0.01
        NUTS step size.
    seed : int or None
        Random seed integer; ``None`` uses seed 0.

    Returns
    -------
    dict
        ``samples``: posterior samples with a leading axis of length
        ``n_samples``. This is a NumPy array for array positions, or a pytree
        shaped like ``initial_position`` with NumPy leaves,
        ``divergences``: number of divergent transitions.

    Raises
    ------
    ValueError
        If ``initial_position`` contains no parameters.
    """
    import blackjax
    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(0 if seed is None else seed)

    # BlackJAX flattens the whole position pytree internally, so the diagonal
    # inverse mass matrix needs one entry per scalar parameter. Counting
    # top-level entries (len of a dict/list) or leading-axis rows would give
    # the wrong size for pytrees and multi-dimensional arrays.
    n_dim = sum(
        int(np.size(leaf)) for leaf in jax.tree_util.tree_leaves(initial_position)
    )
    if n_dim == 0:
        raise ValueError("initial_position must contain at least one parameter")
    inv_mass = jnp.ones(n_dim)

    kernel = blackjax.nuts(
        log_prob_fn, step_size=step_size, inverse_mass_matrix=inv_mass
    )
    initial_state = kernel.init(initial_position)

    def one_step(state, step_key):
        state, info = kernel.step(step_key, state)
        return state, (state, info)

    keys = jax.random.split(key, n_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)

    return {
        # Convert leaf-by-leaf so pytree positions keep their structure
        # instead of becoming a NumPy object array.
        "samples": jax.tree_util.tree_map(np.asarray, states.position),
        "divergences": int(np.sum(np.asarray(infos.is_divergent))),
    }