GLM family functions for link, variance, and deviance computations.
This module provides pure JAX functions for GLM families, designed for JIT compilation and automatic differentiation. All link, variance, and deviance functions are stateless and composable.
Attributes:
| Name | Type | Description |
|---|---|---|
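CANONICAL_LINKS | dict[str, str] | Maps each family name to the name of its canonical link. |
ESTIMATED_DISPERSION_FAMILIES | frozenset[str] | Families whose dispersion parameter is estimated from the data rather than fixed. |
LINK_FUNCTIONS | dict | Maps each link name to its (link, inverse link, link derivative) function triple. |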
Classes:
| Name | Description |
|---|---|
Family | Family configuration for GLM fitting. |
Functions:
| Name | Description |
|---|---|
apply_link | Apply link function by name: η = g(μ). |
apply_link_deriv | Apply link function derivative by name: dη/dμ. |
apply_link_inverse | Apply inverse link function by name: μ = g⁻¹(η). |
binomial_deviance | Binomial unit deviance: d(y, μ) = 2[y log(y/μ) + (1-y) log((1-y)/(1-μ))]. |
binomial_dispersion | Dispersion parameter for binomial family. |
binomial_initialize | Initialize μ for binomial family. |
binomial_loglik | Binomial conditional log-likelihood (per observation). |
binomial_variance | Binomial variance function: V(μ) = μ(1-μ). |
build_family | Create a Family object from family and link names. |
cloglog_link | Complementary log-log link function: η = log(-log(1-μ)). |
cloglog_link_deriv | Cloglog link derivative: dη/dμ = 1/((1-μ) * (-log(1-μ))). |
cloglog_link_inverse | Cloglog inverse link: μ = 1 - exp(-exp(η)). |
gamma_deviance | Gamma unit deviance: d(y, μ) = 2[-log(y/μ) + (y - μ)/μ]. |
gamma_dispersion | Estimate dispersion parameter for Gamma family. |
gamma_initialize | Initialize μ for Gamma family. |
gamma_loglik | Gamma conditional log-likelihood (per observation). |
gamma_variance | Gamma variance function: V(μ) = μ². |
gaussian_deviance | Gaussian unit deviance: d(y, μ) = (y - μ)². |
gaussian_dispersion | Estimate dispersion parameter for Gaussian family. |
gaussian_initialize | Initialize μ for Gaussian family. |
gaussian_loglik | Gaussian conditional log-likelihood (per observation). |
gaussian_variance | Gaussian variance function: V(μ) = 1. |
identity_link | Identity link function: η = μ. |
identity_link_deriv | Identity link derivative: dη/dμ = 1. |
identity_link_inverse | Identity inverse link: μ = η. |
inverse_link | Inverse link function: η = 1/μ. |
inverse_link_deriv | Inverse link derivative: dη/dμ = -1/μ². |
inverse_link_inverse | Inverse of the inverse (reciprocal) link: μ = 1/η. |
log_link | Log link function: η = log(μ). |
log_link_deriv | Log link derivative: dη/dμ = 1/μ. |
log_link_inverse | Log inverse link: μ = exp(η). |
logit_link | Logit link function: η = log(μ/(1-μ)). |
logit_link_deriv | Logit link derivative: dη/dμ = 1/(μ(1-μ)). |
logit_link_inverse | Logit inverse link: μ = 1/(1 + exp(-η)). |
poisson_deviance | Poisson unit deviance: d(y, μ) = 2[y log(y/μ) - (y - μ)]. |
poisson_dispersion | Dispersion parameter for Poisson family. |
poisson_initialize | Initialize μ for Poisson family. |
poisson_loglik | Poisson conditional log-likelihood (per observation). |
poisson_variance | Poisson variance function: V(μ) = μ. |
probit_link | Probit link function: η = Φ⁻¹(μ). |
probit_link_deriv | Probit link derivative: dη/dμ = 1/φ(Φ⁻¹(μ)). |
probit_link_inverse | Probit inverse link: μ = Φ(η). |
resolve_sigma | Resolve optional sigma to a concrete float. |
sample_response | Sample response values from a GLM family distribution. |
tdist_deviance | Placeholder - use tdist(df=...) factory to get proper function. |
tdist_dispersion | Estimate dispersion (scale) parameter for Student-t family. |
tdist_initialize | Initialize μ for Student-t family. |
tdist_loglik | Placeholder - use tdist(df=...) factory to get proper function. |
tdist_robust_weights | Placeholder - use tdist(df=...) factory to get proper function. |
tdist_variance | Student-t variance function: V(μ) = 1. |
Modules:
| Name | Description |
|---|---|
binomial | Binomial family functions for GLM fitting. |
create | Family object construction from string names. |
gamma | Gamma family functions for GLM fitting. |
gaussian | Gaussian family functions for GLM fitting. |
links | Link functions for GLM families. |
poisson | Poisson family functions for GLM fitting. |
response | Response sampling and sigma resolution for GLM families. |
schema | Family configuration dataclass. |
tdist | Student-t family functions for robust GLM fitting. |
Attributes¶
CANONICAL_LINKS¶
CANONICAL_LINKS: dict[str, str] = {'gaussian': 'identity', 'binomial': 'logit', 'poisson': 'log', 'gamma': 'inverse', 'tdist': 'identity'}
ESTIMATED_DISPERSION_FAMILIES¶
ESTIMATED_DISPERSION_FAMILIES: frozenset[str] = frozenset({'gaussian', 'gamma', 'tdist'})
LINK_FUNCTIONS¶
LINK_FUNCTIONS = {
    'identity': (identity_link, identity_link_inverse, identity_link_deriv),
    'log': (log_link, log_link_inverse, log_link_deriv),
    'logit': (logit_link, logit_link_inverse, logit_link_deriv),
    'probit': (probit_link, probit_link_inverse, probit_link_deriv),
    'inverse': (inverse_link, inverse_link_inverse, inverse_link_deriv),
    'cloglog': (cloglog_link, cloglog_link_inverse, cloglog_link_deriv),
}
Classes¶
Family¶
Family(name: str, link_name: str, link: Callable[[jnp.ndarray], jnp.ndarray], link_inverse: Callable[[jnp.ndarray], jnp.ndarray], link_deriv: Callable[[jnp.ndarray], jnp.ndarray], variance: Callable[[jnp.ndarray], jnp.ndarray], deviance: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray], loglik: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray], initialize: Callable[..., jnp.ndarray], dispersion: Callable[[jnp.ndarray, jnp.ndarray, int], float], df: int | None = None, robust_weights: Callable[[jnp.ndarray, jnp.ndarray, float], jnp.ndarray] | None = None) -> None
Family configuration for GLM fitting.
All functions are pure JAX operations, enabling JIT compilation and automatic differentiation. This is a simple data container with no methods.
Attributes:
| Name | Type | Description |
|---|---|---|
name | str | Family name (e.g., “gaussian”, “binomial”, “poisson”, “tdist”) |
link_name | str | Link function name (e.g., “identity”, “logit”, “log”) |
link | Callable[[ndarray], ndarray] | Link function η = g(μ) |
link_inverse | Callable[[ndarray], ndarray] | Inverse link μ = g⁻¹(η) |
link_deriv | Callable[[ndarray], ndarray] | Link derivative dη/dμ |
variance | Callable[[ndarray], ndarray] | Variance function V(μ) |
deviance | Callable[[ndarray, ndarray], ndarray] | Unit deviance function d(y, μ) |
loglik | Callable[[ndarray, ndarray], ndarray] | Conditional log-likelihood function log p(y \| μ) |
initialize | Callable[..., ndarray] | Initialization function for starting μ values. Signature: (y, weights=None) -> mu_init. The weights parameter is optional and only used by the binomial family. |
dispersion | Callable[[ndarray, ndarray, int], float] | Dispersion parameter estimation function |
df | int \| None | Degrees of freedom for t-distribution family (None for others) |
robust_weights | Callable[[ndarray, ndarray, float], ndarray] \| None | Optional function for residual-based weights (t-dist). Signature: (y, mu, scale) -> weights. Returns multiplicative weights that downweight outliers. None for standard families. |
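To make the shape of this container concrete, the following is a minimal sketch of a frozen dataclass carrying a subset of the documented fields, filled in for a Poisson family with log link. `FamilySketch` and `poisson_sketch` are hypothetical names, the callables work on plain floats rather than JAX arrays, and the working-weight line is a standard IRLS formula added for illustration, not something this reference states about the library.

```python
import math
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass(frozen=True)
class FamilySketch:
    """Hypothetical stand-in mirroring some of the documented Family fields."""

    name: str
    link_name: str
    link: Callable[[float], float]
    link_inverse: Callable[[float], float]
    link_deriv: Callable[[float], float]
    variance: Callable[[float], float]
    df: Optional[int] = None  # only meaningful for the t-distribution family


# Poisson with its canonical log link, written for scalars for brevity.
poisson_sketch = FamilySketch(
    name="poisson",
    link_name="log",
    link=math.log,
    link_inverse=math.exp,
    link_deriv=lambda mu: 1.0 / mu,
    variance=lambda mu: mu,
)

mu = 2.5
eta = poisson_sketch.link(mu)

# Standard IRLS working weight for one observation: 1 / (V(mu) * (d eta/d mu)^2).
# For Poisson with log link this reduces to mu itself.
working_weight = 1.0 / (poisson_sketch.variance(mu) * poisson_sketch.link_deriv(mu) ** 2)
```

Because the container only holds functions, a fitting routine can stay family-agnostic and call `link`, `variance`, and `link_deriv` without branching on the family name.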
Attributes¶
deviance¶
deviance: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
df¶
df: int | None = None
dispersion¶
dispersion: Callable[[jnp.ndarray, jnp.ndarray, int], float]
initialize¶
initialize: Callable[..., jnp.ndarray]
link¶
link: Callable[[jnp.ndarray], jnp.ndarray]
link_deriv¶
link_deriv: Callable[[jnp.ndarray], jnp.ndarray]
link_inverse¶
link_inverse: Callable[[jnp.ndarray], jnp.ndarray]
link_name¶
link_name: str
loglik¶
loglik: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
name¶
name: str
robust_weights¶
robust_weights: Callable[[jnp.ndarray, jnp.ndarray, float], jnp.ndarray] | None = None
variance¶
variance: Callable[[jnp.ndarray], jnp.ndarray]
Functions¶
apply_link¶
apply_link(link: str, mu: 'np.ndarray') -> 'np.ndarray'
Apply link function by name: η = g(μ).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
link | str | Link function name. One of: identity, log, logit, probit, inverse, cloglog. | required |
mu | ‘np.ndarray’ | Mean values (response scale). | required |
Returns:
| Type | Description |
|---|---|
‘np.ndarray’ | Linear predictor values η. |
Examples:
>>> import numpy as np
>>> mu = np.array([0.2, 0.5, 0.8])
>>> apply_link("logit", mu)
array([-1.386, 0. , 1.386])
apply_link_deriv¶
apply_link_deriv(link: str, mu: 'np.ndarray') -> 'np.ndarray'
Apply link function derivative by name: dη/dμ.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
link | str | Link function name. One of: identity, log, logit, probit, inverse, cloglog. | required |
mu | ‘np.ndarray’ | Mean values (response scale). | required |
Returns:
| Type | Description |
|---|---|
‘np.ndarray’ | Derivative values dη/dμ. |
Examples:
>>> import numpy as np
>>> mu = np.array([0.2, 0.5, 0.8])
>>> apply_link_deriv("logit", mu)
array([6.25, 4. , 6.25])
apply_link_inverse¶
apply_link_inverse(link: str, eta: 'np.ndarray') -> 'np.ndarray'
Apply inverse link function by name: μ = g⁻¹(η).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
link | str | Link function name. One of: identity, log, logit, probit, inverse, cloglog. | required |
eta | ‘np.ndarray’ | Linear predictor values. | required |
Returns:
| Type | Description |
|---|---|
‘np.ndarray’ | Mean values μ (response scale). |
Examples:
>>> import numpy as np
>>> eta = np.array([-1, 0, 1])
>>> apply_link_inverse("logit", eta)
array([0.269, 0.5 , 0.731])
binomial_deviance¶
binomial_deviance(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray
Binomial unit deviance: d(y, μ) = 2[y log(y/μ) + (1-y) log((1-y)/(1-μ))].
Uses log-space arithmetic for numerical stability.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
y | ndarray | Response values (n,), must be in [0, 1] | required |
mu | ndarray | Fitted mean values (n,), must be in (0, 1) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Unit deviance values (n,) |
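As an illustration of the formula and of why boundary care matters, here is a scalar sketch in plain Python. The helper names `xlogy` and `binomial_unit_deviance` and the `eps` clipping value are assumptions made for this example; the reference does not show how the library handles y = 0 or y = 1 internally.

```python
import math


def xlogy(x: float, y: float) -> float:
    """Return x * log(y), using the convention 0 * log(0) = 0."""
    return 0.0 if x == 0.0 else x * math.log(y)


def binomial_unit_deviance(y: float, mu: float, eps: float = 1e-10) -> float:
    """Scalar sketch of d(y, mu) = 2[y log(y/mu) + (1-y) log((1-y)/(1-mu))]."""
    mu = min(max(mu, eps), 1.0 - eps)  # assumed clipping to keep mu inside (0, 1)
    # Work with differences of logs so y = 0 and y = 1 never evaluate log(0).
    term_success = xlogy(y, y) - xlogy(y, mu)
    term_failure = xlogy(1.0 - y, 1.0 - y) - xlogy(1.0 - y, 1.0 - mu)
    return 2.0 * (term_success + term_failure)


d_hit = binomial_unit_deviance(1.0, 0.8)   # equals -2 * log(0.8)
d_miss = binomial_unit_deviance(0.0, 0.8)  # equals -2 * log(0.2)
d_exact = binomial_unit_deviance(0.5, 0.5)  # zero when y equals mu
```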
binomial_dispersion¶
binomial_dispersion(y: jnp.ndarray, mu: jnp.ndarray, df_resid: int) -> float
Dispersion parameter for binomial family.
Fixed at 1.0 for binomial models.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
y | ndarray | Response values (n,) | required |
mu | ndarray | Fitted mean values (n,) | required |
df_resid | int | Residual degrees of freedom (unused) | required |
Returns:
| Type | Description |
|---|---|
float | Dispersion value (always 1.0) |
binomial_initialize¶
binomial_initialize(y: jnp.ndarray, weights: jnp.ndarray | None = None) -> jnp.ndarray
Initialize μ for binomial family.
Uses the weighted formula from R’s stats::binomial family:
mustart <- (weights * y + 0.5) / (weights + 1)
This avoids boundary values (0 or 1) while accounting for prior weights. When weights=1 (unweighted), this gives (y + 0.5) / 2.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
y | ndarray | Response values (n,), must be in [0, 1] | required |
weights | ndarray \| None | Optional prior weights (n,). Defaults to 1.0 for all observations. | None |
Returns:
| Type | Description |
|---|---|
ndarray | Initial mean values (n,) |
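The starting-value formula can be sketched in a few lines of plain Python. `binomial_mustart` is a hypothetical name, and lists stand in for the JAX arrays the library uses.

```python
def binomial_mustart(y, weights=None):
    """Sketch of mustart = (weights * y + 0.5) / (weights + 1), elementwise."""
    if weights is None:
        weights = [1.0] * len(y)  # unweighted case
    return [(w * yi + 0.5) / (w + 1.0) for yi, w in zip(y, weights)]


unweighted = binomial_mustart([0.0, 1.0])        # [0.25, 0.75]
weighted = binomial_mustart([0.3], [10.0])       # [3.5 / 11]
```

In the unweighted case the boundary responses 0 and 1 are pulled to 0.25 and 0.75, so the logit of the starting value is finite.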
binomial_loglik¶
binomial_loglik(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray
Binomial conditional log-likelihood (per observation).
Computes log p(y|μ) = y*log(μ) + (1-y)*log(1-μ) for Bernoulli trials. Uses the same numerical stability patterns as binomial_deviance.
For binomial trials with n > 1, the binomial coefficient log(n choose k) would be added, but for Bernoulli (n=1) this term is zero.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
y | ndarray | Response values (n,), must be in [0, 1] | required |
mu | ndarray | Fitted mean values (n,), must be in (0, 1) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Per-observation log-likelihood values (n,), NOT summed |
binomial_variance¶
binomial_variance(mu: jnp.ndarray) -> jnp.ndarray
Binomial variance function: V(μ) = μ(1-μ).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mu | ndarray | Mean values (n,), must be in (0, 1) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Variance values (n,) |
build_family¶
build_family(family_name: str, link_name: str | None = None) -> Family
Create a Family object from family and link names.
This is a convenience function that dispatches to the appropriate family factory function based on the family name string.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
family_name | str | Name of the family (“gaussian”, “binomial”, “poisson”, “gamma”). Note: “tdist” is not supported here as it requires a df parameter. | required |
link_name | str \| None | Optional link function name. If None, uses the canonical link for the family: gaussian: “identity”; binomial: “logit”; poisson: “log”; gamma: “inverse”. | None |
Returns:
| Type | Description |
|---|---|
Family | Family configuration object |
Examples:
>>> fam = build_family("gaussian")
>>> fam.name
'gaussian'
>>> fam = build_family("binomial", "probit")
>>> fam.link_name
'probit'
>>> fam = build_family("gamma", "log")
>>> fam.name
'gamma'
cloglog_link¶
cloglog_link(mu: jnp.ndarray) -> jnp.ndarray
Complementary log-log link function: η = log(-log(1-μ)).
Used for asymmetric binary responses where P(Y=1) approaches 1 faster than it approaches 0.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mu | ndarray | Mean values (n,), must be in (0, 1) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Linear predictor values η (n,) |
cloglog_link_deriv¶
cloglog_link_deriv(mu: jnp.ndarray) -> jnp.ndarray
Cloglog link derivative: dη/dμ = 1/((1-μ) * (-log(1-μ))).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mu | ndarray | Mean values (n,), must be in (0, 1) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Derivative values (n,) |
cloglog_link_inverse¶
cloglog_link_inverse(eta: jnp.ndarray) -> jnp.ndarray
Cloglog inverse link: μ = 1 - exp(-exp(η)).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
eta | ndarray | Linear predictor values (n,) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Mean values μ (n,) in (0, 1) |
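The three cloglog formulas can be checked against each other with a short scalar sketch. The function names below are hypothetical, and the use of `log1p`/`expm1` is a choice made for this example rather than a documented detail of the library.

```python
import math


def cloglog(mu: float) -> float:
    """eta = log(-log(1 - mu)); log1p keeps precision for small mu."""
    return math.log(-math.log1p(-mu))


def cloglog_inverse(eta: float) -> float:
    """mu = 1 - exp(-exp(eta)); expm1 keeps precision when mu is near 0."""
    return -math.expm1(-math.exp(eta))


def cloglog_deriv(mu: float) -> float:
    """d eta / d mu = 1 / ((1 - mu) * (-log(1 - mu)))."""
    return 1.0 / ((1.0 - mu) * (-math.log1p(-mu)))


mu = 0.3
round_trip = cloglog_inverse(cloglog(mu))

# Central finite difference as an independent check of the derivative formula.
h = 1e-6
finite_diff = (cloglog(mu + h) - cloglog(mu - h)) / (2.0 * h)

# Asymmetry: eta = 0 maps to 1 - exp(-1), about 0.632, not 0.5 as with logit.
mu_at_zero = cloglog_inverse(0.0)
```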
gamma_deviance¶
gamma_deviance(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray
Gamma unit deviance: d(y, μ) = 2[-log(y/μ) + (y - μ)/μ].
Uses log-space arithmetic for numerical stability.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
y | ndarray | Response values (n,), must be positive | required |
mu | ndarray | Fitted mean values (n,), must be positive | required |
Returns:
| Type | Description |
|---|---|
ndarray | Unit deviance values (n,) |
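A scalar sketch of the Gamma unit deviance, written with a difference of logs instead of the log of a ratio. `gamma_unit_deviance` is a hypothetical name and this is not the library's JAX implementation.

```python
import math


def gamma_unit_deviance(y: float, mu: float) -> float:
    """Scalar sketch of d(y, mu) = 2[-log(y/mu) + (y - mu)/mu] for y, mu > 0."""
    log_ratio = math.log(y) - math.log(mu)  # log(y/mu) without forming the ratio
    return 2.0 * (-log_ratio + (y - mu) / mu)


d_perfect = gamma_unit_deviance(2.0, 2.0)  # zero when y equals mu
d_over = gamma_unit_deviance(2.0, 1.0)     # equals 2 * (1 - log(2))
```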
gamma_dispersion¶
gamma_dispersion(y: jnp.ndarray, mu: jnp.ndarray, df_resid: int) -> float
Estimate dispersion parameter for Gamma family.
Uses Pearson χ² / df_resid. For Gamma, the dispersion parameter is 1/shape, so the shape parameter is 1/dispersion.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
y | ndarray | Response values (n,) | required |
mu | ndarray | Fitted mean values (n,) | required |
df_resid | int | Residual degrees of freedom | required |
Returns:
| Type | Description |
|---|---|
float | Dispersion estimate φ̂ |
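The Pearson estimate divides the sum of squared residuals, each scaled by the variance function, by the residual degrees of freedom. The sketch below assumes that standard definition; `pearson_dispersion` is a hypothetical helper operating on Python lists.

```python
def pearson_dispersion(y, mu, variance, df_resid):
    """Sketch of phi_hat = sum((y - mu)^2 / V(mu)) / df_resid."""
    chi2 = sum((yi - mi) ** 2 / variance(mi) for yi, mi in zip(y, mu))
    return chi2 / df_resid


y = [1.0, 3.0, 2.0]
mu = [2.0, 2.0, 2.0]

# Gamma: V(mu) = mu^2, and the shape parameter is 1 / dispersion.
gamma_phi = pearson_dispersion(y, mu, lambda m: m * m, df_resid=2)
gamma_shape = 1.0 / gamma_phi

# Gaussian: V(mu) = 1, so the same helper gives the residual variance.
gaussian_phi = pearson_dispersion(y, mu, lambda m: 1.0, df_resid=2)
```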
gamma_initialize¶
gamma_initialize(y: jnp.ndarray, weights: jnp.ndarray | None = None) -> jnp.ndarray
Initialize μ for Gamma family.
Uses y directly as the starting value, ensuring positive values.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
y | ndarray | Response values (n,), must be positive | required |
weights | ndarray \| None | Optional prior weights (n,). Unused for Gamma, included for API consistency with other families. | None |
Returns:
| Type | Description |
|---|---|
ndarray | Initial mean values (n,) |
gamma_loglik¶
gamma_loglik(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray
Gamma conditional log-likelihood (per observation).
Computes log p(y|μ) for Gamma distribution with shape parameter k and scale θ = μ/k, parameterized by mean μ. Ignores terms involving k.
The kernel is (k-1)·log(y) - y/θ - k·log(θ). Substituting θ = μ/k and dropping the terms that do not depend on μ leaves a quantity proportional to -y/μ - log(μ).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
y | ndarray | Response values (n,), must be positive | required |
mu | ndarray | Fitted mean values (n,), must be positive | required |
Returns:
| Type | Description |
|---|---|
ndarray | Per-observation log-likelihood values (n,), NOT summed |
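For readers who want the intermediate step, the simplification follows from substituting θ = μ/k into the standard Gamma log-density. The grouping below is a worked derivation under that parameterization, not a statement about how the library arranges the computation.

```latex
\begin{aligned}
\log p(y \mid k, \theta)
  &= (k-1)\log y - \frac{y}{\theta} - k\log\theta - \log\Gamma(k) \\
  &= (k-1)\log y - \frac{k\,y}{\mu} - k\log\mu + k\log k - \log\Gamma(k)
     \qquad (\theta = \mu/k) \\
  &= k\left(-\frac{y}{\mu} - \log\mu\right)
     + \underbrace{(k-1)\log y + k\log k - \log\Gamma(k)}_{\text{does not depend on } \mu}
\end{aligned}
```

Only the first group depends on μ, which is why the per-observation value can be reduced to -y/μ - log(μ) up to the factor k.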
gamma_variance¶
gamma_variance(mu: jnp.ndarray) -> jnp.ndarray
Gamma variance function: V(μ) = μ².
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mu | ndarray | Mean values (n,), must be positive | required |
Returns:
| Type | Description |
|---|---|
ndarray | Variance values (n,) |
gaussian_deviance¶
gaussian_deviance(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray
Gaussian unit deviance: d(y, μ) = (y - μ)².
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
y | ndarray | Response values (n,) | required |
mu | ndarray | Fitted mean values (n,) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Unit deviance values (n,) |
gaussian_dispersion¶
gaussian_dispersion(y: jnp.ndarray, mu: jnp.ndarray, df_resid: int) -> float
Estimate dispersion parameter for Gaussian family.
Uses Pearson χ² / df_resid.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
y | ndarray | Response values (n,) | required |
mu | ndarray | Fitted mean values (n,) | required |
df_resid | int | Residual degrees of freedom | required |
Returns:
| Type | Description |
|---|---|
float | Dispersion estimate φ̂ |
gaussian_initialize¶
gaussian_initialize(y: jnp.ndarray, weights: jnp.ndarray | None = None) -> jnp.ndarray
Initialize μ for Gaussian family.
Uses y directly as the starting value (matches R’s gaussian family).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
y | ndarray | Response values (n,) | required |
weights | ndarray \| None | Optional prior weights (n,). Unused for Gaussian, included for API consistency with other families. | None |
Returns:
| Type | Description |
|---|---|
ndarray | Initial mean values (n,) |
gaussian_loglik¶
gaussian_loglik(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray
Gaussian conditional log-likelihood (per observation).
Computes log p(y|μ) = -0.5 * (y - μ)² ignoring constant terms. The full Gaussian log-likelihood includes -0.5*log(2πσ²), but this constant term cancels in optimization and is added separately when computing final log-likelihood values.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
y | ndarray | Response values (n,) | required |
mu | ndarray | Fitted mean values (n,) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Per-observation log-likelihood values (n,), NOT summed |
gaussian_variance¶
gaussian_variance(mu: jnp.ndarray) -> jnp.ndarray
Gaussian variance function: V(μ) = 1.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mu | ndarray | Mean values (n,) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Variance values (n,), all ones |
identity_link¶
identity_link(mu: jnp.ndarray) -> jnp.ndarray
Identity link function: η = μ.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mu | ndarray | Mean values (n,) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Linear predictor values η (n,) |
identity_link_deriv¶
identity_link_deriv(mu: jnp.ndarray) -> jnp.ndarray
Identity link derivative: dη/dμ = 1.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mu | ndarray | Mean values (n,) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Derivative values (n,) |
identity_link_inverse¶
identity_link_inverse(eta: jnp.ndarray) -> jnp.ndarray
Identity inverse link: μ = η.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
eta | ndarray | Linear predictor values (n,) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Mean values μ (n,) |
inverse_link¶
inverse_link(mu: jnp.ndarray) -> jnp.ndarray
Inverse link function: η = 1/μ.
Canonical link for Gamma family.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mu | ndarray | Mean values (n,), must be positive | required |
Returns:
| Type | Description |
|---|---|
ndarray | Linear predictor values η (n,) |
inverse_link_deriv¶
inverse_link_deriv(mu: jnp.ndarray) -> jnp.ndarray
Inverse link derivative: dη/dμ = -1/μ².
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mu | ndarray | Mean values (n,), must be positive | required |
Returns:
| Type | Description |
|---|---|
ndarray | Derivative values (n,) |
inverse_link_inverse¶
inverse_link_inverse(eta: jnp.ndarray) -> jnp.ndarray
Inverse of the inverse (reciprocal) link: μ = 1/η.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
eta | ndarray | Linear predictor values (n,), must be positive | required |
Returns:
| Type | Description |
|---|---|
ndarray | Mean values μ (n,) |
log_link¶
log_link(mu: jnp.ndarray) -> jnp.ndarray
Log link function: η = log(μ).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mu | ndarray | Mean values (n,), must be positive | required |
Returns:
| Type | Description |
|---|---|
ndarray | Linear predictor values η (n,) |
log_link_deriv¶
log_link_deriv(mu: jnp.ndarray) -> jnp.ndarray
Log link derivative: dη/dμ = 1/μ.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mu | ndarray | Mean values (n,), must be positive | required |
Returns:
| Type | Description |
|---|---|
ndarray | Derivative values (n,) |
log_link_inverse¶
log_link_inverse(eta: jnp.ndarray) -> jnp.ndarray
Log inverse link: μ = exp(η).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
eta | ndarray | Linear predictor values (n,) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Mean values μ (n,) |
logit_link¶
logit_link(mu: jnp.ndarray) -> jnp.ndarray
Logit link function: η = log(μ/(1-μ)).
Values are clipped to [1e-10, 1-1e-10] to avoid log(0).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mu | ndarray | Mean values (n,), must be in (0, 1) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Linear predictor values η (n,) |
logit_link_deriv¶
logit_link_deriv(mu: jnp.ndarray) -> jnp.ndarray
Logit link derivative: dη/dμ = 1/(μ(1-μ)).
Values are clipped to [1e-10, 1-1e-10] to avoid division by zero.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mu | ndarray | Mean values (n,), must be in (0, 1) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Derivative values (n,) |
logit_link_inverse¶
logit_link_inverse(eta: jnp.ndarray) -> jnp.ndarray
Logit inverse link: μ = 1/(1 + exp(-η)).
Uses numerically stable computation to avoid overflow.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
eta | ndarray | Linear predictor values (n,) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Mean values μ (n,) in (0, 1) |
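The reference does not say which stable formulation the library uses. One common technique, sketched here for scalars under that caveat, is to branch on the sign of η so that `exp` is only ever called on a non-positive argument. `stable_sigmoid` is a hypothetical name.

```python
import math


def stable_sigmoid(eta: float) -> float:
    """Sketch of mu = 1 / (1 + exp(-eta)) that never overflows in exp."""
    if eta >= 0.0:
        # exp(-eta) is at most 1 here, so the direct formula is safe.
        return 1.0 / (1.0 + math.exp(-eta))
    # For negative eta use the equivalent form exp(eta) / (1 + exp(eta)).
    z = math.exp(eta)
    return z / (1.0 + z)


mid = stable_sigmoid(0.0)
far_left = stable_sigmoid(-1000.0)   # the naive formula would overflow here
far_right = stable_sigmoid(1000.0)
```

In floating point the result saturates to exactly 0.0 or 1.0 for extreme η, which is one reason link functions that take μ as input clip it away from the boundaries.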
poisson_deviance¶
poisson_deviance(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray
Poisson unit deviance: d(y, μ) = 2[y log(y/μ) - (y - μ)].
Uses log-space arithmetic for numerical stability.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
y | ndarray | Response values (n,), must be non-negative | required |
mu | ndarray | Fitted mean values (n,), must be positive | required |
Returns:
| Type | Description |
|---|---|
ndarray | Unit deviance values (n,) |
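Zero counts are the case that needs care, because y log(y/μ) must be read as 0 when y = 0. The scalar sketch below makes that convention explicit; `poisson_unit_deviance` is a hypothetical name and the branch is an assumption about how the limit is handled, not the library's code.

```python
import math


def poisson_unit_deviance(y: float, mu: float) -> float:
    """Scalar sketch of d(y, mu) = 2[y log(y/mu) - (y - mu)] for y >= 0, mu > 0."""
    # Treat y * log(y/mu) as 0 at y = 0 (its limit), avoiding log(0).
    y_log_ratio = 0.0 if y == 0.0 else y * (math.log(y) - math.log(mu))
    return 2.0 * (y_log_ratio - (y - mu))


d_zero_count = poisson_unit_deviance(0.0, 1.5)  # reduces to 2 * mu
d_perfect = poisson_unit_deviance(3.0, 3.0)     # zero when y equals mu
d_over = poisson_unit_deviance(2.0, 1.0)        # 2 * (2 log 2 - 1)
```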
poisson_dispersion¶
poisson_dispersion(y: jnp.ndarray, mu: jnp.ndarray, df_resid: int) -> float
Dispersion parameter for Poisson family.
Fixed at 1.0 for Poisson models (can estimate for quasi-Poisson).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
y | ndarray | Response values (n,) | required |
mu | ndarray | Fitted mean values (n,) | required |
df_resid | int | Residual degrees of freedom (unused) | required |
Returns:
| Type | Description |
|---|---|
float | Dispersion value (always 1.0) |
poisson_initialize¶
poisson_initialize(y: jnp.ndarray, weights: jnp.ndarray | None = None) -> jnp.ndarray
Initialize μ for Poisson family.
Adds small value to avoid zero counts (matches R’s poisson family).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
y | ndarray | Response values (n,), must be non-negative | required |
weights | ndarray \| None | Optional prior weights (n,). Unused for Poisson, included for API consistency with other families. | None |
Returns:
| Type | Description |
|---|---|
ndarray | Initial mean values (n,) |
poisson_loglik¶
poisson_loglik(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray
Poisson conditional log-likelihood (per observation).
Computes log p(y|μ) = y*log(μ) - μ - log(y!). The log(y!) term uses log-gamma: log(Γ(y+1)) = log(y!).
Uses numerical stability patterns similar to poisson_deviance.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
y | ndarray | Response values (n,), must be non-negative | required |
mu | ndarray | Fitted mean values (n,), must be positive | required |
Returns:
| Type | Description |
|---|---|
ndarray | Per-observation log-likelihood values (n,), NOT summed |
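The log-gamma identity can be sketched with the standard library's `math.lgamma`. `poisson_logpmf` is a hypothetical scalar helper; the library's version works on JAX arrays and is not reproduced here.

```python
import math


def poisson_logpmf(y: float, mu: float) -> float:
    """Sketch of log p(y | mu) = y*log(mu) - mu - log(y!), with log(y!) = lgamma(y + 1)."""
    return y * math.log(mu) - mu - math.lgamma(y + 1.0)


mu = 1.5
log_p2 = poisson_logpmf(2.0, mu)

# Direct evaluation of the pmf at y = 2 for comparison: exp(-mu) * mu^2 / 2!.
direct_p2 = math.exp(-mu) * mu ** 2 / 2.0

# The probabilities should sum to (almost exactly) one over a long enough range.
total = sum(math.exp(poisson_logpmf(float(k), mu)) for k in range(51))
```

Using `lgamma` avoids computing y! directly, which would overflow a float for large counts.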
poisson_variance¶
poisson_variance(mu: jnp.ndarray) -> jnp.ndarray
Poisson variance function: V(μ) = μ.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mu | ndarray | Mean values (n,), must be positive | required |
Returns:
| Type | Description |
|---|---|
ndarray | Variance values (n,) |
probit_link¶
probit_link(mu: jnp.ndarray) -> jnp.ndarray
Probit link function: η = Φ⁻¹(μ).
Uses the inverse error function for numerical stability. Values are clipped to [1e-10, 1-1e-10] to avoid infinite results.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mu | ndarray | Mean values (n,), must be in (0, 1) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Linear predictor values η (n,) |
probit_link_deriv¶
probit_link_deriv(mu: jnp.ndarray) -> jnp.ndarray
Probit link derivative: dη/dμ = 1/φ(Φ⁻¹(μ)).
Here φ is the standard normal PDF. Values are clipped to [1e-10, 1-1e-10] to avoid infinite results.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mu | ndarray | Mean values (n,), must be in (0, 1) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Derivative values (n,) |
probit_link_inverse¶
probit_link_inverse(eta: jnp.ndarray) -> jnp.ndarray
Probit inverse link: μ = Φ(η).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
eta | ndarray | Linear predictor values (n,) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Mean values μ (n,) in (0, 1) |
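The three probit formulas fit together as in the scalar sketch below. Note the substitution: the library is documented as using the inverse error function, whereas this example uses `statistics.NormalDist.inv_cdf` from the standard library, because `math` has no `erfinv`. The function names are hypothetical and the `eps` default mirrors the documented clipping range.

```python
import math
from statistics import NormalDist

_STD_NORMAL = NormalDist()  # mean 0, standard deviation 1


def probit_inverse(eta: float) -> float:
    """mu = Phi(eta), written via the error function."""
    return 0.5 * (1.0 + math.erf(eta / math.sqrt(2.0)))


def probit(mu: float, eps: float = 1e-10) -> float:
    """eta = Phi^{-1}(mu), with mu clipped to [eps, 1 - eps]."""
    mu = min(max(mu, eps), 1.0 - eps)
    return _STD_NORMAL.inv_cdf(mu)


def probit_deriv(mu: float, eps: float = 1e-10) -> float:
    """d eta / d mu = 1 / phi(Phi^{-1}(mu)), where phi is the standard normal PDF."""
    return 1.0 / _STD_NORMAL.pdf(probit(mu, eps))


eta_mid = probit(0.5)                       # 0.0
round_trip = probit(probit_inverse(1.0))    # back to 1.0 up to rounding
slope_mid = probit_deriv(0.5)               # sqrt(2 * pi)
clipped = probit(0.0)                       # finite thanks to clipping
```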
resolve_sigma¶
resolve_sigma(sigma: float | None) -> float
Resolve optional sigma to a concrete float.
GLMs without an estimated dispersion parameter (binomial, poisson) store sigma=None in FitState. This resolves to 1.0 for those families.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
sigma | float \| None | Residual SD from FitState, or None. | required |
Returns:
| Type | Description |
|---|---|
float | The sigma value, or 1.0 if None. |
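The documented behaviour amounts to a one-line default. The sketch below uses the hypothetical name `resolve_sigma_sketch` and reflects only what the description above states.

```python
from typing import Optional


def resolve_sigma_sketch(sigma: Optional[float]) -> float:
    """Return sigma as a float, or 1.0 when no sigma was stored."""
    return 1.0 if sigma is None else float(sigma)


fixed_dispersion = resolve_sigma_sketch(None)   # e.g. binomial or poisson fits
estimated = resolve_sigma_sketch(2.5)           # e.g. a gaussian fit
```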
sample_response¶
sample_response(family: str, mu: np.ndarray, sigma: float, rng: np.random.Generator) -> np.ndarray
Sample response values from a GLM family distribution.
Given the conditional mean on the response scale (after inverse link), draws observations from the appropriate distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
family | str | Distribution family name (“gaussian”, “binomial”, “poisson”). | required |
mu | ndarray | Conditional mean on the response scale, shape (n,). | required |
sigma | float | Residual standard deviation (used only for gaussian/tdist). | required |
rng | Generator | NumPy random number generator. | required |
Returns:
| Type | Description |
|---|---|
ndarray | Sampled response values, shape (n,). |
tdist_deviance¶
tdist_deviance(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray
Placeholder - use tdist(df=...) factory to get proper function.
tdist_dispersion¶
tdist_dispersion(y: jnp.ndarray, mu: jnp.ndarray, df_resid: int) -> float
Estimate dispersion (scale) parameter for Student-t family.
Uses MAD (median absolute deviation) for robust scale estimation: σ = MAD / 0.6745, where 0.6745 is the MAD of the standard normal distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
y | ndarray | Response values (n,) | required |
mu | ndarray | Fitted mean values (n,) | required |
df_resid | int | Residual degrees of freedom (unused for MAD) | required |
Returns:
| Type | Description |
|---|---|
float | Dispersion estimate σ̂ |
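A minimal sketch of MAD-based scale estimation on residuals follows. Whether the library centres the residuals at their median before taking absolute deviations is not stated in this reference, so that step is an assumption; `mad_scale` is a hypothetical name.

```python
from statistics import median


def mad_scale(y, mu):
    """Sketch of sigma_hat = MAD(residuals) / 0.6745."""
    residuals = [yi - mi for yi, mi in zip(y, mu)]
    center = median(residuals)  # assumed: deviations are taken about the median
    mad = median(abs(r - center) for r in residuals)
    return mad / 0.6745


# One gross outlier (100.0) barely moves the estimate.
y = [1.0, 2.0, 3.0, 4.0, 100.0]
mu = [0.0, 0.0, 0.0, 0.0, 0.0]
sigma_hat = mad_scale(y, mu)  # MAD is 1.0, so this is 1 / 0.6745
```

A Pearson-style estimate on the same data would be dominated by the outlier, which is the motivation for using a median-based scale with the Student-t family.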
tdist_initialize¶
tdist_initialize(y: jnp.ndarray, weights: jnp.ndarray | None = None) -> jnp.ndarray
Initialize μ for Student-t family.
Uses y directly as the starting value (like Gaussian).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
y | ndarray | Response values (n,) | required |
weights | ndarray \| None | Optional prior weights (n,). Unused for t-distribution. | None |
Returns:
| Type | Description |
|---|---|
ndarray | Initial mean values (n,) |
tdist_loglik¶
tdist_loglik(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray
Placeholder - use tdist(df=...) factory to get proper function.
tdist_robust_weights¶
tdist_robust_weights(y: jnp.ndarray, mu: jnp.ndarray, scale: float) -> jnp.ndarray
Placeholder - use tdist(df=...) factory to get proper function.
tdist_variance¶
tdist_variance(mu: jnp.ndarray) -> jnp.ndarray
Student-t variance function: V(μ) = 1.
Like Gaussian, the variance function is constant. The heavy-tailed behavior comes from the robust weights, not the variance function.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mu | ndarray | Mean values (n,) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Variance values (n,), all ones |
Modules¶
binomial¶
Binomial family functions for GLM fitting.
Functions:
| Name | Description |
|---|---|
binomial | Create binomial family for binary or proportion data. |
binomial_deviance | Binomial unit deviance: d(y, μ) = 2[y log(y/μ) + (1-y) log((1-y)/(1-μ))]. |
binomial_dispersion | Dispersion parameter for binomial family. |
binomial_initialize | Initialize μ for binomial family. |
binomial_loglik | Binomial conditional log-likelihood (per observation). |
binomial_variance | Binomial variance function: V(μ) = μ(1-μ). |
Functions¶
binomial¶
binomial(link: str | None = None) -> Family
Create binomial family for binary or proportion data.
The binomial family is appropriate for binary outcomes (0/1) or proportions (successes/trials). Commonly used for logistic and probit regression.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
link | str \| None | Link function name: “logit” or “probit”. Defaults to “logit” if None. | None |
Returns:
| Type | Description |
|---|---|
Family | Binomial family configuration |
Examples:
>>> # Logistic regression (default)
>>> fam = binomial()
>>> fam.link_name
'logit'
>>> # Probit regression
>>> fam = binomial("probit")
>>> fam.link_name
'probit'
binomial_deviance¶
binomial_deviance(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray
Binomial unit deviance: d(y, μ) = 2[y log(y/μ) + (1-y) log((1-y)/(1-μ))].
Uses log-space arithmetic for numerical stability.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
y | ndarray | Response values (n,), must be in [0, 1] | required |
mu | ndarray | Fitted mean values (n,), must be in (0, 1) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Unit deviance values (n,) |
binomial_dispersion¶
binomial_dispersion(y: jnp.ndarray, mu: jnp.ndarray, df_resid: int) -> float
Dispersion parameter for binomial family.
Fixed at 1.0 for binomial models.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
y | ndarray | Response values (n,) | required |
mu | ndarray | Fitted mean values (n,) | required |
df_resid | int | Residual degrees of freedom (unused) | required |
Returns:
| Type | Description |
|---|---|
float | Dispersion value (always 1.0) |
binomial_initialize¶
binomial_initialize(y: jnp.ndarray, weights: jnp.ndarray | None = None) -> jnp.ndarray
Initialize μ for binomial family.
Uses the weighted formula from R’s stats::binomial family: mustart <- (weights * y + 0.5) / (weights + 1)
This avoids boundary values (0 or 1) while accounting for prior weights. When weights=1 (unweighted), this gives (y + 0.5) / 2.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
y | ndarray | Response values (n,), must be in [0, 1] | required |
weights | ndarray \| None | Optional prior weights (n,). Defaults to 1.0 for all observations. | None |
Returns:
| Type | Description |
|---|---|
ndarray | Initial mean values (n,) |
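The starting-value formula quoted above can be written out directly. This is a small NumPy sketch under the assumption that missing weights mean a weight of 1 for every observation; the function name is hypothetical.

```python
import numpy as np


def binomial_mustart(y, weights=None):
    """Starting values (w * y + 0.5) / (w + 1), as in R's binomial family."""
    y = np.asarray(y, dtype=float)
    # Assumption: weights=None is treated as unit prior weights.
    w = np.ones_like(y) if weights is None else np.asarray(weights, dtype=float)
    return (w * y + 0.5) / (w + 1.0)
```

For unweighted 0/1 data this maps 0 to 0.25 and 1 to 0.75, so the starting μ never touches the boundary.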
binomial_loglik¶
binomial_loglik(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray
Binomial conditional log-likelihood (per observation).
Computes log p(y|μ) = y*log(μ) + (1-y)*log(1-μ) for Bernoulli trials. Uses the same numerical stability patterns as binomial_deviance.
For binomial trials with n > 1, the binomial coefficient log(n choose k) would be added, but for Bernoulli (n=1) this term is zero.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
y | ndarray | Response values (n,), must be in [0, 1] | required |
mu | ndarray | Fitted mean values (n,), must be in (0, 1) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Per-observation log-likelihood values (n,), NOT summed |
binomial_variance¶
binomial_variance(mu: jnp.ndarray) -> jnp.ndarray
Binomial variance function: V(μ) = μ(1-μ).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mu | ndarray | Mean values (n,), must be in (0, 1) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Variance values (n,) |
create¶
Family object construction from string names.
Functions:
| Name | Description |
|---|---|
build_family | Create a Family object from family and link names. |
Functions¶
build_family¶
build_family(family_name: str, link_name: str | None = None) -> Family
Create a Family object from family and link names.
This is a convenience function that dispatches to the appropriate family factory function based on the family name string.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
family_name | str | Name of the family (“gaussian”, “binomial”, “poisson”, “gamma”). Note: “tdist” is not supported here as it requires a df parameter. | required |
link_name | str \| None | Optional link function name. If None, the canonical link for the family is used: “identity” for gaussian, “logit” for binomial, “log” for poisson, “inverse” for gamma. | None |
Returns:
| Type | Description |
|---|---|
Family | Family configuration object |
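The name-based dispatch can be pictured as two lookups: pick the canonical link when none is given, then check the link against what the family supports. The sketch below is only an illustration; `resolve_family_link` and `SUPPORTED_LINKS` are hypothetical names, and the supported-link sets are copied from the per-family factory descriptions on this page.

```python
# Canonical links as listed for build_family (tdist is excluded there).
CANONICAL_LINKS = {
    "gaussian": "identity",
    "binomial": "logit",
    "poisson": "log",
    "gamma": "inverse",
}

# Hypothetical table of supported links, taken from the factory docs.
SUPPORTED_LINKS = {
    "gaussian": {"identity"},
    "binomial": {"logit", "probit"},
    "poisson": {"log"},
    "gamma": {"inverse", "log"},
}


def resolve_family_link(family_name, link_name=None):
    """Return (family, link) after applying defaults and validation."""
    if family_name not in CANONICAL_LINKS:
        raise ValueError(f"unsupported family: {family_name!r}")
    link = CANONICAL_LINKS[family_name] if link_name is None else link_name
    if link not in SUPPORTED_LINKS[family_name]:
        raise ValueError(f"link {link!r} not supported for {family_name!r}")
    return family_name, link
```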
Examples:
>>> fam = build_family("gaussian")
>>> fam.name
'gaussian'
>>> fam = build_family("binomial", "probit")
>>> fam.link_name
'probit'
>>> fam = build_family("gamma", "log")
>>> fam.name
'gamma'
gamma¶
Gamma family functions for GLM fitting.
Functions:
| Name | Description |
|---|---|
gamma | Create Gamma family for positive continuous data. |
gamma_deviance | Gamma unit deviance: d(y, μ) = 2[-log(y/μ) + (y - μ)/μ]. |
gamma_dispersion | Estimate dispersion parameter for Gamma family. |
gamma_initialize | Initialize μ for Gamma family. |
gamma_loglik | Gamma conditional log-likelihood (per observation). |
gamma_variance | Gamma variance function: V(μ) = μ². |
Functions¶
gamma¶
gamma(link: str | None = None) -> Family
Create Gamma family for positive continuous data.
The Gamma family is appropriate for positive continuous response data where variance increases with the mean (V(μ) = μ²). Commonly used for modeling waiting times, insurance claims, or other positive-valued data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
link | str \| None | Link function name: “inverse” (canonical, default) or “log”. If None, “inverse” is used. | None |
Returns:
| Type | Description |
|---|---|
Family | Gamma family configuration |
Examples:
>>> # Gamma with inverse link (canonical)
>>> fam = gamma()
>>> fam.name
'gamma'
>>> fam.link_name
'inverse'
>>> # Gamma with log link
>>> fam = gamma("log")
>>> fam.link_name
'log'
gamma_deviance¶
gamma_deviance(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray
Gamma unit deviance: d(y, μ) = 2[-log(y/μ) + (y - μ)/μ].
Uses log-space arithmetic for numerical stability.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
y | ndarray | Response values (n,), must be positive | required |
mu | ndarray | Fitted mean values (n,), must be positive | required |
Returns:
| Type | Description |
|---|---|
ndarray | Unit deviance values (n,) |
gamma_dispersion¶
gamma_dispersion(y: jnp.ndarray, mu: jnp.ndarray, df_resid: int) -> float
Estimate dispersion parameter for Gamma family.
Uses Pearson χ² / df_resid. For Gamma, the dispersion parameter is 1/shape, so the shape parameter is 1/dispersion.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
y | ndarray | Response values (n,) | required |
mu | ndarray | Fitted mean values (n,) | required |
df_resid | int | Residual degrees of freedom | required |
Returns:
| Type | Description |
|---|---|
float | Dispersion estimate φ̂ |
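As a rough illustration of the Pearson estimate, the statistic is the sum of squared residuals scaled by the variance function, divided by the residual degrees of freedom. The sketch below assumes unit prior weights and uses a hypothetical helper name; it is not the library's code.

```python
import numpy as np


def pearson_dispersion(y, mu, df_resid, variance):
    """Pearson chi-square / df_resid for a given variance function V(mu)."""
    y = np.asarray(y, dtype=float)
    mu = np.asarray(mu, dtype=float)
    chi2 = np.sum((y - mu) ** 2 / variance(mu))
    return float(chi2 / df_resid)


# Gamma variance function V(mu) = mu^2.
phi_hat = pearson_dispersion([1.0, 2.0, 3.0], [1.0, 2.0, 2.0], df_resid=1,
                             variance=lambda m: m ** 2)
shape_hat = 1.0 / phi_hat  # Gamma shape is the reciprocal of the dispersion
```

Here only the third observation contributes, (3 - 2)² / 2² = 0.25, so the dispersion estimate is 0.25 and the implied shape is 4.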
gamma_initialize¶
gamma_initialize(y: jnp.ndarray, weights: jnp.ndarray | None = None) -> jnp.ndarray
Initialize μ for Gamma family.
Uses y directly as the starting value, ensuring positive values.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
y | ndarray | Response values (n,), must be positive | required |
weights | ndarray \| None | Optional prior weights (n,). Unused for Gamma, included for API consistency with other families. | None |
Returns:
| Type | Description |
|---|---|
ndarray | Initial mean values (n,) |
gamma_loglik¶
gamma_loglik(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray
Gamma conditional log-likelihood (per observation).
Computes log p(y|μ) for a Gamma distribution with shape parameter k and scale θ = μ/k, so that the mean is μ. Terms that do not depend on μ are ignored.
The kernel is (k-1)·log(y) - y/θ - k·log(θ). After substituting θ = μ/k, the μ-dependent part is proportional to -y/μ - log(μ).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
y | ndarray | Response values (n,), must be positive | required |
mu | ndarray | Fitted mean values (n,), must be positive | required |
Returns:
| Type | Description |
|---|---|
ndarray | Per-observation log-likelihood values (n,), NOT summed |
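The claim that the reduced kernel differs from the full Gamma log-density only by terms free of μ can be checked numerically. The sketch below uses only the standard library; both function names are hypothetical, and the full density is the textbook shape/scale form.

```python
import math


def gamma_logpdf(y, mu, k):
    """Full Gamma log-density with shape k and scale theta = mu / k."""
    theta = mu / k
    return (k - 1.0) * math.log(y) - y / theta - k * math.log(theta) - math.lgamma(k)


def gamma_kernel(y, mu):
    """mu-dependent kernel -y/mu - log(mu), up to the factor k."""
    return -y / mu - math.log(mu)


# For fixed y and k, the gap between the full log-density and k * kernel
# should be the same for every mu.
y, k = 2.0, 3.0
offsets = [gamma_logpdf(y, mu, k) - k * gamma_kernel(y, mu) for mu in (0.5, 1.5, 4.0)]
```

The offsets all equal (k-1)·log(y) + k·log(k) - log Γ(k), which involves y and k but not μ.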
gamma_variance¶
gamma_variance(mu: jnp.ndarray) -> jnp.ndarray
Gamma variance function: V(μ) = μ².
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mu | ndarray | Mean values (n,), must be positive | required |
Returns:
| Type | Description |
|---|---|
ndarray | Variance values (n,) |
gaussian¶
Gaussian family functions for GLM fitting.
Functions:
| Name | Description |
|---|---|
gaussian | Create Gaussian family with identity link. |
gaussian_deviance | Gaussian unit deviance: d(y, μ) = (y - μ)². |
gaussian_dispersion | Estimate dispersion parameter for Gaussian family. |
gaussian_initialize | Initialize μ for Gaussian family. |
gaussian_loglik | Gaussian conditional log-likelihood (per observation). |
gaussian_variance | Gaussian variance function: V(μ) = 1. |
Functions¶
gaussian¶
gaussian(link: str | None = None) -> Family
Create Gaussian family with identity link.
The Gaussian family is appropriate for continuous response data with constant variance. This is equivalent to ordinary least squares (OLS) when using the identity link.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
link | str \| None | Link function name. Only “identity” (canonical) is supported. If None, “identity” is used. | None |
Returns:
| Type | Description |
|---|---|
Family | Gaussian family configuration |
Examples:
>>> fam = gaussian()
>>> fam.name
'gaussian'
>>> fam.link_name
'identity'
gaussian_deviance¶
gaussian_deviance(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray
Gaussian unit deviance: d(y, μ) = (y - μ)².
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
y | ndarray | Response values (n,) | required |
mu | ndarray | Fitted mean values (n,) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Unit deviance values (n,) |
gaussian_dispersion¶
gaussian_dispersion(y: jnp.ndarray, mu: jnp.ndarray, df_resid: int) -> float
Estimate dispersion parameter for Gaussian family.
Uses Pearson χ² / df_resid.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
y | ndarray | Response values (n,) | required |
mu | ndarray | Fitted mean values (n,) | required |
df_resid | int | Residual degrees of freedom | required |
Returns:
| Type | Description |
|---|---|
float | Dispersion estimate φ̂ |
gaussian_initialize¶
gaussian_initialize(y: jnp.ndarray, weights: jnp.ndarray | None = None) -> jnp.ndarray
Initialize μ for Gaussian family.
Uses y directly as the starting value (matches R’s gaussian family).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
y | ndarray | Response values (n,) | required |
weights | ndarray \| None | Optional prior weights (n,). Unused for Gaussian, included for API consistency with other families. | None |
Returns:
| Type | Description |
|---|---|
ndarray | Initial mean values (n,) |
gaussian_loglik¶
gaussian_loglik(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray
Gaussian conditional log-likelihood (per observation).
Computes log p(y|μ) = -0.5 * (y - μ)² ignoring constant terms. The full Gaussian log-likelihood includes -0.5*log(2πσ²), but this constant term cancels in optimization and is added separately when computing final log-likelihood values.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
y | ndarray | Response values (n,) | required |
mu | ndarray | Fitted mean values (n,) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Per-observation log-likelihood values (n,), NOT summed |
gaussian_variance¶
gaussian_variance(mu: jnp.ndarray) -> jnp.ndarray
Gaussian variance function: V(μ) = 1.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mu | ndarray | Mean values (n,) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Variance values (n,), all ones |
links¶
Link functions for GLM families.
Functions:
| Name | Description |
|---|---|
apply_link | Apply link function by name: η = g(μ). |
apply_link_deriv | Apply link function derivative by name: dη/dμ. |
apply_link_inverse | Apply inverse link function by name: μ = g⁻¹(η). |
apply_link_inverse_deriv | Compute derivative of inverse link function: dμ/dη = d/dη[g⁻¹(η)]. |
cloglog_link | Complementary log-log link function: η = log(-log(1-μ)). |
cloglog_link_deriv | Cloglog link derivative: dη/dμ = 1/((1-μ) * (-log(1-μ))). |
cloglog_link_inverse | Cloglog inverse link: μ = 1 - exp(-exp(η)). |
identity_link | Identity link function: η = μ. |
identity_link_deriv | Identity link derivative: dη/dμ = 1. |
identity_link_inverse | Identity inverse link: μ = η. |
inverse_link | Inverse link function: η = 1/μ. |
inverse_link_deriv | Inverse link derivative: dη/dμ = -1/μ². |
inverse_link_inverse | Inverse link inverse: μ = 1/η. |
log_link | Log link function: η = log(μ). |
log_link_deriv | Log link derivative: dη/dμ = 1/μ. |
log_link_inverse | Log inverse link: μ = exp(η). |
logit_link | Logit link function: η = log(μ/(1-μ)). |
logit_link_deriv | Logit link derivative: dη/dμ = 1/(μ(1-μ)). |
logit_link_inverse | Logit inverse link: μ = 1/(1 + exp(-η)). |
probit_link | Probit link function: η = Φ⁻¹(μ). |
probit_link_deriv | Probit link derivative: dη/dμ = 1/φ(Φ⁻¹(μ)). |
probit_link_inverse | Probit inverse link: μ = Φ(η). |
Attributes:
| Name | Type | Description |
|---|---|---|
LINK_FUNCTIONS | | Registry mapping each link name to its (link, inverse link, link derivative) functions. |
Attributes¶
LINK_FUNCTIONS¶
LINK_FUNCTIONS = {
    'identity': (identity_link, identity_link_inverse, identity_link_deriv),
    'log': (log_link, log_link_inverse, log_link_deriv),
    'logit': (logit_link, logit_link_inverse, logit_link_deriv),
    'probit': (probit_link, probit_link_inverse, probit_link_deriv),
    'inverse': (inverse_link, inverse_link_inverse, inverse_link_deriv),
    'cloglog': (cloglog_link, cloglog_link_inverse, cloglog_link_deriv),
}
Functions¶
apply_link¶
apply_link(link: str, mu: 'np.ndarray') -> 'np.ndarray'
Apply link function by name: η = g(μ).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
link | str | Link function name. One of: identity, log, logit, probit, inverse, cloglog. | required |
mu | ‘np.ndarray’ | Mean values (response scale). | required |
Returns:
| Type | Description |
|---|---|
‘np.ndarray’ | Linear predictor values η. |
Examples:
>>> import numpy as np
>>> mu = np.array([0.2, 0.5, 0.8])
>>> apply_link("logit", mu)
array([-1.386, 0. , 1.386])
apply_link_deriv¶
apply_link_deriv(link: str, mu: 'np.ndarray') -> 'np.ndarray'
Apply link function derivative by name: dη/dμ.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
link | str | Link function name. One of: identity, log, logit, probit, inverse, cloglog. | required |
mu | ‘np.ndarray’ | Mean values (response scale). | required |
Returns:
| Type | Description |
|---|---|
‘np.ndarray’ | Derivative values dη/dμ. |
Examples:
>>> import numpy as np
>>> mu = np.array([0.2, 0.5, 0.8])
>>> apply_link_deriv("logit", mu)
array([6.25, 4. , 6.25])
apply_link_inverse¶
apply_link_inverse(link: str, eta: 'np.ndarray') -> 'np.ndarray'
Apply inverse link function by name: μ = g⁻¹(η).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
link | str | Link function name. One of: identity, log, logit, probit, inverse, cloglog. | required |
eta | ‘np.ndarray’ | Linear predictor values. | required |
Returns:
| Type | Description |
|---|---|
‘np.ndarray’ | Mean values μ (response scale). |
Examples:
>>> import numpy as np
>>> eta = np.array([-1, 0, 1])
>>> apply_link_inverse("logit", eta)
array([0.269, 0.5 , 0.731])
apply_link_inverse_deriv¶
apply_link_inverse_deriv(link: str, eta: 'np.ndarray') -> 'np.ndarray'
Compute derivative of inverse link function: dμ/dη = d/dη[g⁻¹(η)].
Used for delta method transformation from link to data scale: SE_data = |dμ/dη| × SE_link.
The derivative is computed as 1 / (dη/dμ) evaluated at μ = g⁻¹(η). For logit: p·(1-p). For log: exp(η). For identity: 1.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
link | str | Link function name. One of: identity, log, logit, probit, inverse, cloglog. | required |
eta | ‘np.ndarray’ | Linear predictor values. | required |
Returns:
| Type | Description |
|---|---|
‘np.ndarray’ | Derivative values dμ/dη. |
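The reciprocal rule and the delta-method scaling described above can be sketched in a few lines of NumPy. The mini-registry `_LINKS` and both function names are hypothetical stand-ins covering only three links; they are not the module's own objects.

```python
import numpy as np

# Hypothetical mini-registry: name -> (inverse link, link derivative d(eta)/d(mu)).
_LINKS = {
    "identity": (lambda eta: eta, lambda mu: np.ones_like(mu)),
    "log": (np.exp, lambda mu: 1.0 / mu),
    "logit": (lambda eta: 1.0 / (1.0 + np.exp(-eta)), lambda mu: 1.0 / (mu * (1.0 - mu))),
}


def inverse_link_deriv(link, eta):
    """d(mu)/d(eta) computed as 1 / (d(eta)/d(mu)) at mu = g^{-1}(eta)."""
    eta = np.asarray(eta, dtype=float)
    link_inverse, link_deriv = _LINKS[link]
    mu = link_inverse(eta)
    return 1.0 / link_deriv(mu)


def se_on_response_scale(link, eta, se_link):
    """Delta method: SE_data = |d(mu)/d(eta)| * SE_link."""
    return np.abs(inverse_link_deriv(link, eta)) * np.asarray(se_link, dtype=float)
```

For the logit link at η = 0 the derivative is p·(1-p) = 0.25, so a link-scale standard error of 0.4 becomes 0.1 on the probability scale.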
cloglog_link¶
cloglog_link(mu: jnp.ndarray) -> jnp.ndarray
Complementary log-log link function: η = log(-log(1-μ)).
Used for asymmetric binary responses where P(Y=1) approaches 1 faster than it approaches 0.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mu | ndarray | Mean values (n,), must be in (0, 1) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Linear predictor values η (n,) |
cloglog_link_deriv¶
cloglog_link_deriv(mu: jnp.ndarray) -> jnp.ndarray
Cloglog link derivative: dη/dμ = 1/((1-μ) * (-log(1-μ))).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mu | ndarray | Mean values (n,), must be in (0, 1) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Derivative values (n,) |
cloglog_link_inverse¶
cloglog_link_inverse(eta: jnp.ndarray) -> jnp.ndarray
Cloglog inverse link: μ = 1 - exp(-exp(η)).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
eta | ndarray | Linear predictor values (n,) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Mean values μ (n,) in (0, 1) |
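The three cloglog formulas fit together as a link, its inverse, and its derivative. The NumPy sketch below writes them with `log1p` and `expm1`, which is a choice made here for accuracy near the boundaries and an assumption about style, not a statement about how the module computes them.

```python
import numpy as np


def cloglog(mu):
    """eta = log(-log(1 - mu)); log1p(-mu) evaluates log(1 - mu) accurately."""
    return np.log(-np.log1p(-np.asarray(mu, dtype=float)))


def cloglog_inverse(eta):
    """mu = 1 - exp(-exp(eta)), written with expm1 to limit cancellation."""
    return -np.expm1(-np.exp(np.asarray(eta, dtype=float)))


def cloglog_deriv(mu):
    """d(eta)/d(mu) = 1 / ((1 - mu) * (-log(1 - mu)))."""
    mu = np.asarray(mu, dtype=float)
    return 1.0 / ((1.0 - mu) * (-np.log1p(-mu)))
```

Unlike the logit, the link is not symmetric about μ = 0.5: `cloglog(0.5)` equals log(log 2), which is negative rather than zero.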
identity_link¶
identity_link(mu: jnp.ndarray) -> jnp.ndarray
Identity link function: η = μ.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mu | ndarray | Mean values (n,) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Linear predictor values η (n,) |
identity_link_deriv¶
identity_link_deriv(mu: jnp.ndarray) -> jnp.ndarray
Identity link derivative: dη/dμ = 1.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mu | ndarray | Mean values (n,) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Derivative values (n,) |
identity_link_inverse¶
identity_link_inverse(eta: jnp.ndarray) -> jnp.ndarray
Identity inverse link: μ = η.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
eta | ndarray | Linear predictor values (n,) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Mean values μ (n,) |
inverse_link¶
inverse_link(mu: jnp.ndarray) -> jnp.ndarray
Inverse link function: η = 1/μ.
Canonical link for Gamma family.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mu | ndarray | Mean values (n,), must be positive | required |
Returns:
| Type | Description |
|---|---|
ndarray | Linear predictor values η (n,) |
inverse_link_deriv¶
inverse_link_deriv(mu: jnp.ndarray) -> jnp.ndarray
Inverse link derivative: dη/dμ = -1/μ².
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mu | ndarray | Mean values (n,), must be positive | required |
Returns:
| Type | Description |
|---|---|
ndarray | Derivative values (n,) |
inverse_link_inverse¶
inverse_link_inverse(eta: jnp.ndarray) -> jnp.ndarray
Inverse link inverse: μ = 1/η.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
eta | ndarray | Linear predictor values (n,), must be positive | required |
Returns:
| Type | Description |
|---|---|
ndarray | Mean values μ (n,) |
log_link¶
log_link(mu: jnp.ndarray) -> jnp.ndarray
Log link function: η = log(μ).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mu | ndarray | Mean values (n,), must be positive | required |
Returns:
| Type | Description |
|---|---|
ndarray | Linear predictor values η (n,) |
log_link_deriv¶
log_link_deriv(mu: jnp.ndarray) -> jnp.ndarray
Log link derivative: dη/dμ = 1/μ.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mu | ndarray | Mean values (n,), must be positive | required |
Returns:
| Type | Description |
|---|---|
ndarray | Derivative values (n,) |
log_link_inverse¶
log_link_inverse(eta: jnp.ndarray) -> jnp.ndarray
Log inverse link: μ = exp(η).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
eta | ndarray | Linear predictor values (n,) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Mean values μ (n,) |
logit_link¶
logit_link(mu: jnp.ndarray) -> jnp.ndarray
Logit link function: η = log(μ/(1-μ)).
Values are clipped to [1e-10, 1-1e-10] to avoid log(0).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mu | ndarray | Mean values (n,), must be in (0, 1) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Linear predictor values η (n,) |
logit_link_deriv¶
logit_link_deriv(mu: jnp.ndarray) -> jnp.ndarray
Logit link derivative: dη/dμ = 1/(μ(1-μ)).
Values are clipped to [1e-10, 1-1e-10] to avoid division by zero.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mu | ndarray | Mean values (n,), must be in (0, 1) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Derivative values (n,) |
logit_link_inverse¶
logit_link_inverse(eta: jnp.ndarray) -> jnp.ndarray
Logit inverse link: μ = 1/(1 + exp(-η)).
Uses numerically stable computation to avoid overflow.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
eta | ndarray | Linear predictor values (n,) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Mean values μ (n,) in (0, 1) |
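One common way to make the logistic function overflow-safe is to exponentiate only non-positive numbers. The sketch below shows that pattern in NumPy; it is an assumed illustration of "numerically stable computation", and the module may use a different technique.

```python
import numpy as np


def stable_sigmoid(eta):
    """mu = 1 / (1 + exp(-eta)) without overflow for large |eta|."""
    eta = np.asarray(eta, dtype=float)
    # exp(-|eta|) is always in (0, 1], so it can never overflow.
    e = np.exp(-np.abs(eta))
    # eta >= 0: 1 / (1 + exp(-eta));  eta < 0: exp(eta) / (1 + exp(eta)).
    return np.where(eta >= 0, 1.0 / (1.0 + e), e / (1.0 + e))
```

In floating point the result saturates to exactly 0.0 or 1.0 for very large |η|, which is one reason the link and its derivative clip μ before taking logs or reciprocals.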
probit_link¶
probit_link(mu: jnp.ndarray) -> jnp.ndarray
Probit link function: η = Φ⁻¹(μ).
Uses the inverse error function for numerical stability. Values are clipped to [1e-10, 1-1e-10] to avoid infinite results.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mu | ndarray | Mean values (n,), must be in (0, 1) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Linear predictor values η (n,) |
probit_link_deriv¶
probit_link_deriv(mu: jnp.ndarray) -> jnp.ndarray
Probit link derivative: dη/dμ = 1/φ(Φ⁻¹(μ)).
Here φ is the standard normal PDF and Φ⁻¹ is the standard normal quantile function. Values are clipped to [1e-10, 1-1e-10] to avoid infinite results.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mu | ndarray | Mean values (n,), must be in (0, 1) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Derivative values (n,) |
probit_link_inverse¶
probit_link_inverse(eta: jnp.ndarray) -> jnp.ndarray
Probit inverse link: μ = Φ(η).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
eta | ndarray | Linear predictor values (n,) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Mean values μ (n,) in (0, 1) |
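The probit trio can be illustrated for scalars with Python's standard-library `statistics.NormalDist`, which provides Φ, Φ⁻¹, and φ. This swaps in a stdlib routine for the inverse-error-function approach mentioned above; the function names are hypothetical and the clipping bound simply mirrors the documented 1e-10.

```python
from statistics import NormalDist

_STD_NORMAL = NormalDist()  # mean 0, standard deviation 1


def _clip(mu, eps=1e-10):
    # Keep mu strictly inside (0, 1) so the quantile function stays finite.
    return min(max(mu, eps), 1.0 - eps)


def probit(mu):
    """eta = Phi^{-1}(mu)."""
    return _STD_NORMAL.inv_cdf(_clip(mu))


def probit_inverse(eta):
    """mu = Phi(eta)."""
    return _STD_NORMAL.cdf(eta)


def probit_deriv(mu):
    """d(eta)/d(mu) = 1 / phi(Phi^{-1}(mu))."""
    return 1.0 / _STD_NORMAL.pdf(_STD_NORMAL.inv_cdf(_clip(mu)))
```

At μ = 0.5 the link is 0 and the derivative is 1/φ(0) = √(2π), roughly 2.5066.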
poisson¶
Poisson family functions for GLM fitting.
Functions:
| Name | Description |
|---|---|
poisson | Create Poisson family for count data. |
poisson_deviance | Poisson unit deviance: d(y, μ) = 2[y log(y/μ) - (y - μ)]. |
poisson_dispersion | Dispersion parameter for Poisson family. |
poisson_initialize | Initialize μ for Poisson family. |
poisson_loglik | Poisson conditional log-likelihood (per observation). |
poisson_variance | Poisson variance function: V(μ) = μ. |
Functions¶
poisson¶
poisson(link: str | None = None) -> Family
Create Poisson family for count data.
The Poisson family is appropriate for count data where the variance equals the mean. Commonly used for modeling event counts, frequencies, or rates.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
link | str \| None | Link function name. Only “log” (canonical) is supported. If None, “log” is used. | None |
Returns:
| Type | Description |
|---|---|
Family | Poisson family configuration |
Examples:
>>> fam = poisson()
>>> fam.name
'poisson'
>>> fam.link_name
'log'
poisson_deviance¶
poisson_deviance(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray
Poisson unit deviance: d(y, μ) = 2[y log(y/μ) - (y - μ)].
Uses log-space arithmetic for numerical stability.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
y | ndarray | Response values (n,), must be non-negative | required |
mu | ndarray | Fitted mean values (n,), must be positive | required |
Returns:
| Type | Description |
|---|---|
ndarray | Unit deviance values (n,) |
poisson_dispersion¶
poisson_dispersion(y: jnp.ndarray, mu: jnp.ndarray, df_resid: int) -> float
Dispersion parameter for Poisson family.
Fixed at 1.0 for Poisson models (can estimate for quasi-Poisson).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
y | ndarray | Response values (n,) | required |
mu | ndarray | Fitted mean values (n,) | required |
df_resid | int | Residual degrees of freedom (unused) | required |
Returns:
| Type | Description |
|---|---|
float | Dispersion value (always 1.0) |
poisson_initialize¶
poisson_initialize(y: jnp.ndarray, weights: jnp.ndarray | None = None) -> jnp.ndarray
Initialize μ for Poisson family.
Adds a small value to y so that zero counts do not produce a starting μ of 0 (matches R’s poisson family).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
y | ndarray | Response values (n,), must be non-negative | required |
weights | ndarray \| None | Optional prior weights (n,). Unused for Poisson, included for API consistency with other families. | None |
Returns:
| Type | Description |
|---|---|
ndarray | Initial mean values (n,) |
poisson_loglik¶
poisson_loglik(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray
Poisson conditional log-likelihood (per observation).
Computes log p(y|μ) = y*log(μ) - μ - log(y!). The log(y!) term uses log-gamma: log(Γ(y+1)) = log(y!).
Uses numerical stability patterns similar to poisson_deviance.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
y | ndarray | Response values (n,), must be non-negative | required |
mu | ndarray | Fitted mean values (n,), must be positive | required |
Returns:
| Type | Description |
|---|---|
ndarray | Per-observation log-likelihood values (n,), NOT summed |
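The log-gamma identity makes the log(y!) term cheap and safe for large counts. Below is a plain-Python sketch of the per-observation formula using `math.lgamma`; the function name is hypothetical and it omits any clipping the library may apply.

```python
import math


def poisson_loglik_sketch(y, mu):
    """Per-observation y*log(mu) - mu - log(y!), with log(y!) = lgamma(y + 1)."""
    return [
        yi * math.log(mi) - mi - math.lgamma(yi + 1.0)
        for yi, mi in zip(y, mu)
    ]


# Two observations: y = 0 with mu = 1, and y = 2 with mu = 2.
ll = poisson_loglik_sketch([0, 2], [1.0, 2.0])
```

The values are returned per observation, not summed; here they are -1 and log(2) - 2, the logs of the Poisson probabilities e⁻¹ and 2e⁻².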
poisson_variance¶
poisson_variance(mu: jnp.ndarray) -> jnp.ndarray
Poisson variance function: V(μ) = μ.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mu | ndarray | Mean values (n,), must be positive | required |
Returns:
| Type | Description |
|---|---|
ndarray | Variance values (n,) |
response¶
Response sampling and sigma resolution for GLM families.
Provides sample_response to draw response values from a GLM family
distribution given the mean on the response scale, and resolve_sigma
to handle optional dispersion parameters.
These utilities eliminate triplicated family-dispatch code in simulation, prediction, and varying-state operations.
Functions:
| Name | Description |
|---|---|
resolve_sigma | Resolve optional sigma to a concrete float. |
sample_response | Sample response values from a GLM family distribution. |
Attributes:
| Name | Type | Description |
|---|---|---|
CANONICAL_LINKS | dict[str, str] | Canonical link name for each family. |
Attributes¶
CANONICAL_LINKS¶
CANONICAL_LINKS: dict[str, str] = {'gaussian': 'identity', 'binomial': 'logit', 'poisson': 'log', 'gamma': 'inverse', 'tdist': 'identity'}
Functions¶
resolve_sigma¶
resolve_sigma(sigma: float | None) -> float
Resolve optional sigma to a concrete float.
GLMs without a dispersion parameter (binomial, poisson) store sigma=None in FitState. This resolves to 1.0 for those families.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
sigma | float \| None | Residual SD from FitState, or None. | required |
Returns:
| Type | Description |
|---|---|
float | The sigma value, or 1.0 if None. |
sample_response¶
sample_response(family: str, mu: np.ndarray, sigma: float, rng: np.random.Generator) -> np.ndarray
Sample response values from a GLM family distribution.
Given the conditional mean on the response scale (after inverse link), draws observations from the appropriate distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
family | str | Distribution family name (“gaussian”, “binomial”, “poisson”). | required |
mu | ndarray | Conditional mean on the response scale, shape (n,). | required |
sigma | float | Residual standard deviation (used only for gaussian/tdist). | required |
rng | Generator | NumPy random number generator. | required |
Returns:
| Type | Description |
|---|---|
ndarray | Sampled response values, shape (n,). |
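The family dispatch for sampling can be sketched with NumPy's `Generator` methods. This is an assumed illustration covering only the three families named in the parameter table, with binomial treated as Bernoulli draws; the function name is hypothetical and the real implementation may differ.

```python
import numpy as np


def sample_response_sketch(family, mu, sigma, rng):
    """Draw one response per element of mu from the named family."""
    mu = np.asarray(mu, dtype=float)
    if family == "gaussian":
        # sigma is the residual standard deviation.
        return rng.normal(loc=mu, scale=sigma)
    if family == "binomial":
        # Assumption: Bernoulli trials with success probability mu.
        return rng.binomial(1, mu).astype(float)
    if family == "poisson":
        return rng.poisson(mu).astype(float)
    raise ValueError(f"unsupported family: {family!r}")
```

A typical call passes a seeded generator, for example `sample_response_sketch("poisson", mu, 1.0, np.random.default_rng(0))`, so that simulations are reproducible.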
schema¶
Family configuration dataclass.
Classes:
| Name | Description |
|---|---|
Family | Family configuration for GLM fitting. |
Attributes:
| Name | Type | Description |
|---|---|---|
ESTIMATED_DISPERSION_FAMILIES | frozenset[str] | Families whose dispersion is estimated from the data rather than fixed at 1.0. |
Attributes¶
ESTIMATED_DISPERSION_FAMILIES¶
ESTIMATED_DISPERSION_FAMILIES: frozenset[str] = frozenset({'gaussian', 'gamma', 'tdist'})
Classes¶
Family¶
Family(name: str, link_name: str, link: Callable[[jnp.ndarray], jnp.ndarray], link_inverse: Callable[[jnp.ndarray], jnp.ndarray], link_deriv: Callable[[jnp.ndarray], jnp.ndarray], variance: Callable[[jnp.ndarray], jnp.ndarray], deviance: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray], loglik: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray], initialize: Callable[..., jnp.ndarray], dispersion: Callable[[jnp.ndarray, jnp.ndarray, int], float], df: int | None = None, robust_weights: Callable[[jnp.ndarray, jnp.ndarray, float], jnp.ndarray] | None = None) -> None
Family configuration for GLM fitting.
All functions are pure JAX operations, enabling JIT compilation and automatic differentiation. This is a simple data container with no methods.
Attributes:
| Name | Type | Description |
|---|---|---|
name | str | Family name (e.g., “gaussian”, “binomial”, “poisson”, “tdist”) |
link_name | str | Link function name (e.g., “identity”, “logit”, “log”) |
link | Callable[[ndarray], ndarray] | Link function η = g(μ) |
link_inverse | Callable[[ndarray], ndarray] | Inverse link μ = g⁻¹(η) |
link_deriv | Callable[[ndarray], ndarray] | Link derivative dη/dμ |
variance | Callable[[ndarray], ndarray] | Variance function V(μ) |
deviance | Callable[[ndarray, ndarray], ndarray] | Unit deviance function d(y, μ) |
loglik | Callable[[ndarray, ndarray], ndarray] | Conditional log-likelihood function log p(y \| μ) |
initialize | Callable[..., ndarray] | Initialization function for starting μ values. Signature: (y, weights=None) -> mu_init. The weights parameter is optional and only used by the binomial family. |
dispersion | Callable[[ndarray, ndarray, int], float] | Dispersion parameter estimation function |
df | int \| None | Degrees of freedom for t-distribution family (None for others) |
robust_weights | Callable[[ndarray, ndarray, float], ndarray] \| None | Optional function for residual-based weights (t-dist). Signature: (y, mu, scale) -> weights. Returns multiplicative weights that downweight outliers. None for standard families. |
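To show how the link, link derivative, and variance callables combine during fitting, here is a sketch of the textbook IRLS working response and working weights. `MiniFamily` is a hypothetical stand-in with only four of the fields above, and the update shown is the standard GLM formulation, not necessarily the exact code path of this library.

```python
from typing import Callable, NamedTuple

import numpy as np


class MiniFamily(NamedTuple):
    """Hypothetical cut-down container mirroring four Family fields."""
    link: Callable
    link_inverse: Callable
    link_deriv: Callable
    variance: Callable


# Poisson family with log link: g(mu) = log(mu), g'(mu) = 1/mu, V(mu) = mu.
poisson_log = MiniFamily(
    link=np.log,
    link_inverse=np.exp,
    link_deriv=lambda mu: 1.0 / mu,
    variance=lambda mu: mu,
)


def irls_working_quantities(fam, y, mu):
    """Working response z = eta + (y - mu) g'(mu), weights w = 1 / (V(mu) g'(mu)^2)."""
    y = np.asarray(y, dtype=float)
    mu = np.asarray(mu, dtype=float)
    g_prime = fam.link_deriv(mu)
    z = fam.link(mu) + (y - mu) * g_prime
    w = 1.0 / (fam.variance(mu) * g_prime ** 2)
    return z, w
```

For the Poisson log-link case the weights reduce to μ itself, which is a quick sanity check on any family definition.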
Attributes¶
deviance¶
deviance: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
df¶
df: int | None = None
dispersion¶
dispersion: Callable[[jnp.ndarray, jnp.ndarray, int], float]
initialize¶
initialize: Callable[..., jnp.ndarray]
link¶
link: Callable[[jnp.ndarray], jnp.ndarray]
link_deriv¶
link_deriv: Callable[[jnp.ndarray], jnp.ndarray]
link_inverse¶
link_inverse: Callable[[jnp.ndarray], jnp.ndarray]
link_name¶
link_name: str
loglik¶
loglik: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
name¶
name: str
robust_weights¶
robust_weights: Callable[[jnp.ndarray, jnp.ndarray, float], jnp.ndarray] | None = None
variance¶
variance: Callable[[jnp.ndarray], jnp.ndarray]
tdist¶
Student-t family functions for robust GLM fitting.
Functions:
| Name | Description |
|---|---|
tdist | Create Student-t family for robust regression. |
tdist_deviance | Placeholder - use tdist(df=...) factory to get proper function. |
tdist_dispersion | Estimate dispersion (scale) parameter for Student-t family. |
tdist_initialize | Initialize μ for Student-t family. |
tdist_loglik | Placeholder - use tdist(df=...) factory to get proper function. |
tdist_robust_weights | Placeholder - use tdist(df=...) factory to get proper function. |
tdist_variance | Student-t variance function: V(μ) = 1. |
Functions¶
tdist¶
tdist(df: int, link: str | None = None) -> Family
Create Student-t family for robust regression.
The Student-t family provides robust regression by using a t-distribution for the errors instead of Gaussian. This downweights outliers through iteratively reweighted least squares (IRLS).
The t-distribution has heavier tails than the Gaussian:
- df=1: Cauchy distribution (very heavy tails)
- df=4-5: Moderately heavy tails (common for robust regression)
- df→∞: Converges to Gaussian
For automatic df based on residual degrees of freedom (n - p), use family="tdist" in glm(); df is then set at fit() time.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
df | int | Degrees of freedom. Must be > 0. Typically set to n - p (residual degrees of freedom) for proper inference. | required |
link | str \| None | Link function name. Only “identity” is supported. If None, “identity” is used. | None |
Returns:
| Type | Description |
|---|---|
Family | Student-t family configuration with robust weights. |
Examples:
>>> # Robust regression with df=10
>>> fam = tdist(df=10)
>>> fam.name
'tdist'
>>> # In practice, use with glm (df set automatically)
>>> model = glm("y ~ x", data=df, family="tdist").fit()
tdist_deviance¶
tdist_deviance(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray
Placeholder - use tdist(df=...) factory to get proper function.
tdist_dispersion¶
tdist_dispersion(y: jnp.ndarray, mu: jnp.ndarray, df_resid: int) -> float
Estimate dispersion (scale) parameter for Student-t family.
Uses MAD (median absolute deviation) for robust scale estimation: σ = MAD / 0.6745
where 0.6745 is the MAD of the standard normal distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
y | ndarray | Response values (n,) | required |
mu | ndarray | Fitted mean values (n,) | required |
df_resid | int | Residual degrees of freedom (unused for MAD) | required |
Returns:
| Type | Description |
|---|---|
float | Dispersion estimate σ̂ |
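A MAD-based scale estimate is short to write down. The sketch below centers the residuals at their median before taking absolute deviations, which is one common convention and an assumption here (some implementations use the median of |r| directly); the function name is hypothetical.

```python
import numpy as np


def mad_scale(y, mu):
    """Robust scale: median(|r - median(r)|) / 0.6745 with r = y - mu."""
    r = np.asarray(y, dtype=float) - np.asarray(mu, dtype=float)
    mad = np.median(np.abs(r - np.median(r)))
    return float(mad / 0.6745)


# One gross outlier barely moves the estimate.
sigma_hat = mad_scale([1.0, 2.0, 3.0, 4.0, 100.0], np.zeros(5))
```

Here the residual median is 3, the absolute deviations are 2, 1, 0, 1, 97, and their median is 1, so the estimate is 1/0.6745 regardless of how extreme the outlier is.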
tdist_initialize¶
tdist_initialize(y: jnp.ndarray, weights: jnp.ndarray | None = None) -> jnp.ndarray
Initialize μ for Student-t family.
Uses y directly as the starting value (like Gaussian).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
y | ndarray | Response values (n,) | required |
weights | ndarray \| None | Optional prior weights (n,). Unused for t-distribution. | None |
Returns:
| Type | Description |
|---|---|
ndarray | Initial mean values (n,) |
tdist_loglik¶
tdist_loglik(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray
Placeholder - use tdist(df=...) factory to get proper function.
tdist_robust_weights¶
tdist_robust_weights(y: jnp.ndarray, mu: jnp.ndarray, scale: float) -> jnp.ndarray
Placeholder - use tdist(df=...) factory to get proper function.
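The weight function produced by the factory is not documented here. For orientation only, the standard weight for Student-t errors in iteratively reweighted least squares is (df + 1) / (df + r²), with r the scaled residual; the sketch below implements that textbook form, and it is an assumption that the library's function follows it. The name and the explicit `df` argument are hypothetical.

```python
import numpy as np


def t_robust_weights(y, mu, scale, df):
    """Textbook Student-t IRLS weights: (df + 1) / (df + ((y - mu) / scale)^2)."""
    r = (np.asarray(y, dtype=float) - np.asarray(mu, dtype=float)) / scale
    return (df + 1.0) / (df + r ** 2)


# Residuals of 0, 2, and 10 scale units with df = 4.
w = t_robust_weights([0.0, 2.0, 10.0], [0.0, 0.0, 0.0], scale=1.0, df=4)
```

The weights fall from 1.25 at zero residual to 0.625 at two scale units and about 0.048 at ten, which is the downweighting of outliers described for this family. As df grows the weights approach 1 everywhere, recovering ordinary least squares.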
tdist_variance¶
tdist_variance(mu: jnp.ndarray) -> jnp.ndarray
Student-t variance function: V(μ) = 1.
Like Gaussian, the variance function is constant. The heavy-tailed behavior comes from the robust weights, not the variance function.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mu | ndarray | Mean values (n,) | required |
Returns:
| Type | Description |
|---|---|
ndarray | Variance values (n,), all ones |