GLM family functions for link, variance, and deviance computations.

This module provides pure JAX functions for GLM families, designed for JIT compilation and automatic differentiation. All link, variance, and deviance functions are stateless and composable.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| CANONICAL_LINKS | dict[str, str] | |
| ESTIMATED_DISPERSION_FAMILIES | frozenset[str] | |
| LINK_FUNCTIONS | | |

Classes:

| Name | Description |
| --- | --- |
| Family | Family configuration for GLM fitting. |

Functions:

| Name | Description |
| --- | --- |
| apply_link | Apply link function by name: η = g(μ). |
| apply_link_deriv | Apply link function derivative by name: dη/dμ. |
| apply_link_inverse | Apply inverse link function by name: μ = g⁻¹(η). |
| binomial_deviance | Binomial unit deviance: d(y, μ) = 2[y log(y/μ) + (1-y) log((1-y)/(1-μ))]. |
| binomial_dispersion | Dispersion parameter for binomial family. |
| binomial_initialize | Initialize μ for binomial family. |
| binomial_loglik | Binomial conditional log-likelihood (per observation). |
| binomial_variance | Binomial variance function: V(μ) = μ(1-μ). |
| build_family | Create a Family object from family and link names. |
| cloglog_link | Complementary log-log link function: η = log(-log(1-μ)). |
| cloglog_link_deriv | Cloglog link derivative: dη/dμ = 1/((1-μ) * (-log(1-μ))). |
| cloglog_link_inverse | Cloglog inverse link: μ = 1 - exp(-exp(η)). |
| gamma_deviance | Gamma unit deviance: d(y, μ) = 2[-log(y/μ) + (y - μ)/μ]. |
| gamma_dispersion | Estimate dispersion parameter for Gamma family. |
| gamma_initialize | Initialize μ for Gamma family. |
| gamma_loglik | Gamma conditional log-likelihood (per observation). |
| gamma_variance | Gamma variance function: V(μ) = μ². |
| gaussian_deviance | Gaussian unit deviance: d(y, μ) = (y - μ)². |
| gaussian_dispersion | Estimate dispersion parameter for Gaussian family. |
| gaussian_initialize | Initialize μ for Gaussian family. |
| gaussian_loglik | Gaussian conditional log-likelihood (per observation). |
| gaussian_variance | Gaussian variance function: V(μ) = 1. |
| identity_link | Identity link function: η = μ. |
| identity_link_deriv | Identity link derivative: dη/dμ = 1. |
| identity_link_inverse | Identity inverse link: μ = η. |
| inverse_link | Inverse link function: η = 1/μ. |
| inverse_link_deriv | Inverse link derivative: dη/dμ = -1/μ². |
| inverse_link_inverse | Inverse of the inverse link: μ = 1/η. |
| log_link | Log link function: η = log(μ). |
| log_link_deriv | Log link derivative: dη/dμ = 1/μ. |
| log_link_inverse | Log inverse link: μ = exp(η). |
| logit_link | Logit link function: η = log(μ/(1-μ)). |
| logit_link_deriv | Logit link derivative: dη/dμ = 1/(μ(1-μ)). |
| logit_link_inverse | Logit inverse link: μ = 1/(1 + exp(-η)). |
| poisson_deviance | Poisson unit deviance: d(y, μ) = 2[y log(y/μ) - (y - μ)]. |
| poisson_dispersion | Dispersion parameter for Poisson family. |
| poisson_initialize | Initialize μ for Poisson family. |
| poisson_loglik | Poisson conditional log-likelihood (per observation). |
| poisson_variance | Poisson variance function: V(μ) = μ. |
| probit_link | Probit link function: η = Φ⁻¹(μ). |
| probit_link_deriv | Probit link derivative: dη/dμ = 1/φ(Φ⁻¹(μ)). |
| probit_link_inverse | Probit inverse link: μ = Φ(η). |
| resolve_sigma | Resolve optional sigma to a concrete float. |
| sample_response | Sample response values from a GLM family distribution. |
| tdist_deviance | Placeholder - use tdist(df=...) factory to get proper function. |
| tdist_dispersion | Estimate dispersion (scale) parameter for Student-t family. |
| tdist_initialize | Initialize μ for Student-t family. |
| tdist_loglik | Placeholder - use tdist(df=...) factory to get proper function. |
| tdist_robust_weights | Placeholder - use tdist(df=...) factory to get proper function. |
| tdist_variance | Student-t variance function: V(μ) = 1. |

Modules:

| Name | Description |
| --- | --- |
| binomial | Binomial family functions for GLM fitting. |
| create | Family object construction from string names. |
| gamma | Gamma family functions for GLM fitting. |
| gaussian | Gaussian family functions for GLM fitting. |
| links | Link functions for GLM families. |
| poisson | Poisson family functions for GLM fitting. |
| response | Response sampling and sigma resolution for GLM families. |
| schema | Family configuration dataclass. |
| tdist | Student-t family functions for robust GLM fitting. |

Attributes

CANONICAL_LINKS

CANONICAL_LINKS: dict[str, str] = {'gaussian': 'identity', 'binomial': 'logit', 'poisson': 'log', 'gamma': 'inverse', 'tdist': 'identity'}

ESTIMATED_DISPERSION_FAMILIES

ESTIMATED_DISPERSION_FAMILIES: frozenset[str] = frozenset({'gaussian', 'gamma', 'tdist'})

LINK_FUNCTIONS

LINK_FUNCTIONS = {'identity': (identity_link, identity_link_inverse, identity_link_deriv), 'log': (log_link, log_link_inverse, log_link_deriv), 'logit': (logit_link, logit_link_inverse, logit_link_deriv), 'probit': (probit_link, probit_link_inverse, probit_link_deriv), 'inverse': (inverse_link, inverse_link_inverse, inverse_link_deriv), 'cloglog': (cloglog_link, cloglog_link_inverse, cloglog_link_deriv)}
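
A registry keyed by link name, as LINK_FUNCTIONS is, reduces dispatch to a dictionary lookup that returns a (link, inverse, derivative) triple. The sketch below is a standalone NumPy illustration of that shape with only two entries; `LINKS_SKETCH` and `round_trip` are hypothetical names, not part of this module.

```python
import numpy as np

# Illustrative registry: name -> (link, inverse link, link derivative).
# It mirrors the tuple layout of LINK_FUNCTIONS but is not the module's code.
LINKS_SKETCH = {
    "identity": (lambda mu: mu, lambda eta: eta, lambda mu: np.ones_like(mu)),
    "log": (np.log, np.exp, lambda mu: 1.0 / mu),
}

def round_trip(name, mu):
    """Map mu to the linear predictor scale and back again."""
    link, link_inverse, _ = LINKS_SKETCH[name]
    return link_inverse(link(mu))

mu = np.array([0.5, 1.0, 4.0])
recovered = round_trip("log", mu)   # should reproduce mu up to rounding
```

Keeping the three callables together means a caller that has a link name never needs to know which concrete functions implement it.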

Classes

Family

Family(name: str, link_name: str, link: Callable[[jnp.ndarray], jnp.ndarray], link_inverse: Callable[[jnp.ndarray], jnp.ndarray], link_deriv: Callable[[jnp.ndarray], jnp.ndarray], variance: Callable[[jnp.ndarray], jnp.ndarray], deviance: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray], loglik: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray], initialize: Callable[..., jnp.ndarray], dispersion: Callable[[jnp.ndarray, jnp.ndarray, int], float], df: int | None = None, robust_weights: Callable[[jnp.ndarray, jnp.ndarray, float], jnp.ndarray] | None = None) -> None

Family configuration for GLM fitting.

All functions are pure JAX operations, enabling JIT compilation and automatic differentiation. This is a simple data container with no methods.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| name | str | Family name (e.g., “gaussian”, “binomial”, “poisson”, “tdist”) |
| link_name | str | Link function name (e.g., “identity”, “logit”, “log”) |
| link | Callable[[ndarray], ndarray] | Link function η = g(μ) |
| link_inverse | Callable[[ndarray], ndarray] | Inverse link μ = g⁻¹(η) |
| link_deriv | Callable[[ndarray], ndarray] | Link derivative dη/dμ |
| variance | Callable[[ndarray], ndarray] | Variance function V(μ) |
| deviance | Callable[[ndarray, ndarray], ndarray] | Unit deviance function d(y, μ) |
| loglik | Callable[[ndarray, ndarray], ndarray] | Conditional log-likelihood function log p(y \| μ) |
| initialize | Callable[..., ndarray] | Initialization function for starting μ values. Signature: (y, weights=None) -> mu_init. The weights parameter is optional and only used by the binomial family. |
| dispersion | Callable[[ndarray, ndarray, int], float] | Dispersion parameter estimation function |
| df | int \| None | Degrees of freedom for the t-distribution family (None for others) |
| robust_weights | Callable[[ndarray, ndarray, float], ndarray] \| None | Optional function for residual-based weights (t-dist). Signature: (y, mu, scale) -> weights. Returns multiplicative weights that downweight outliers. None for standard families. |
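
To show how the link, link_deriv, and variance callables fit together, here is a sketch of the textbook working response and working weights used in iteratively reweighted least squares (IRLS). This page does not state which fitting algorithm consumes a Family, so treat IRLS as an assumption; `irls_working_quantities` is a hypothetical helper written in plain NumPy.

```python
import numpy as np

def irls_working_quantities(y, mu, link, link_deriv, variance):
    """Textbook IRLS working response z and working weights w for one step."""
    g_prime = link_deriv(mu)
    z = link(mu) + (y - mu) * g_prime          # working response
    w = 1.0 / (variance(mu) * g_prime ** 2)    # working weights
    return z, w

# Poisson family with log link: V(mu) = mu, g(mu) = log(mu), g'(mu) = 1/mu.
y = np.array([0.0, 2.0, 5.0])
mu = np.array([1.0, 2.0, 4.0])
z, w = irls_working_quantities(y, mu, np.log, lambda m: 1.0 / m, lambda m: m)
```

For this canonical pairing the weights reduce to μ itself, which is a quick sanity check on any implementation.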

Attributes

deviance: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
df: int | None = None
dispersion: Callable[[jnp.ndarray, jnp.ndarray, int], float]
initialize: Callable[..., jnp.ndarray]
link: Callable[[jnp.ndarray], jnp.ndarray]
link_deriv: Callable[[jnp.ndarray], jnp.ndarray]
link_inverse: Callable[[jnp.ndarray], jnp.ndarray]
link_name: str
loglik: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
name: str
robust_weights: Callable[[jnp.ndarray, jnp.ndarray, float], jnp.ndarray] | None = None
variance: Callable[[jnp.ndarray], jnp.ndarray]

Functions

apply_link

apply_link(link: str, mu: np.ndarray) -> np.ndarray

Apply link function by name: η = g(μ).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| link | str | Link function name. One of: identity, log, logit, probit, inverse, cloglog. | required |
| mu | np.ndarray | Mean values (response scale). | required |

Returns:

| Type | Description |
| --- | --- |
| np.ndarray | Linear predictor values η. |

Examples:

>>> import numpy as np
>>> mu = np.array([0.2, 0.5, 0.8])
>>> apply_link("logit", mu)
array([-1.386,  0.   ,  1.386])

apply_link_deriv

apply_link_deriv(link: str, mu: np.ndarray) -> np.ndarray

Apply link function derivative by name: dη/dμ.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| link | str | Link function name. One of: identity, log, logit, probit, inverse, cloglog. | required |
| mu | np.ndarray | Mean values (response scale). | required |

Returns:

| Type | Description |
| --- | --- |
| np.ndarray | Derivative values dη/dμ. |

Examples:

>>> import numpy as np
>>> mu = np.array([0.2, 0.5, 0.8])
>>> apply_link_deriv("logit", mu)
array([6.25, 4.  , 6.25])

apply_link_inverse

apply_link_inverse(link: str, eta: np.ndarray) -> np.ndarray

Apply inverse link function by name: μ = g⁻¹(η).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| link | str | Link function name. One of: identity, log, logit, probit, inverse, cloglog. | required |
| eta | np.ndarray | Linear predictor values. | required |

Returns:

| Type | Description |
| --- | --- |
| np.ndarray | Mean values μ (response scale). |

Examples:

>>> import numpy as np
>>> eta = np.array([-1, 0, 1])
>>> apply_link_inverse("logit", eta)
array([0.269, 0.5  , 0.731])

binomial_deviance

binomial_deviance(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray

Binomial unit deviance: d(y, μ) = 2[y log(y/μ) + (1-y) log((1-y)/(1-μ))].

Uses log-space arithmetic for numerical stability.

Parameters:

NameTypeDescriptionDefault
yndarrayResponse values (n,), must be in [0, 1]required
mundarrayFitted mean values (n,), must be in (0, 1)required

Returns:

TypeDescription
ndarrayUnit deviance values (n,)
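
The formula has removable singularities at y = 0 and y = 1, where a term of the form 0·log(0) must evaluate to 0. The following NumPy sketch shows one way to handle that; the masking with `np.where` and the `eps` clipping of μ are assumptions for illustration, not a description of this function's internals.

```python
import numpy as np

def binomial_deviance_sketch(y, mu, eps=1e-10):
    """Unit deviance 2[y log(y/mu) + (1-y) log((1-y)/(1-mu))] with 0*log(0) = 0."""
    y = np.asarray(y, dtype=float)
    mu = np.clip(np.asarray(mu, dtype=float), eps, 1.0 - eps)
    # Where the prefactor is 0, feed log() the value 1 so the term vanishes.
    t1 = y * np.log(np.where(y > 0, y / mu, 1.0))
    t2 = (1.0 - y) * np.log(np.where(y < 1, (1.0 - y) / (1.0 - mu), 1.0))
    return 2.0 * (t1 + t2)

d = binomial_deviance_sketch([0.0, 1.0, 0.3], [0.25, 0.25, 0.3])
```

For y = 0 the result is -2 log(1-μ), for y = 1 it is -2 log(μ), and it is 0 whenever y = μ.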

binomial_dispersion

binomial_dispersion(y: jnp.ndarray, mu: jnp.ndarray, df_resid: int) -> float

Dispersion parameter for binomial family.

Fixed at 1.0 for binomial models.

Parameters:

NameTypeDescriptionDefault
yndarrayResponse values (n,)required
mundarrayFitted mean values (n,)required
df_residintResidual degrees of freedom (unused)required

Returns:

TypeDescription
floatDispersion value (always 1.0)

binomial_initialize

binomial_initialize(y: jnp.ndarray, weights: jnp.ndarray | None = None) -> jnp.ndarray

Initialize μ for binomial family.

Uses the weighted formula from R’s stats::binomial family: mustart <- (weights * y + 0.5) / (weights + 1)

This avoids boundary values (0 or 1) while accounting for prior weights. When weights=1 (unweighted), this gives (y + 0.5) / 2.

Parameters:

NameTypeDescriptionDefault
yndarrayResponse values (n,), must be in [0, 1]required
weightsndarray | NoneOptional prior weights (n,). Defaults to 1.0 for all observations.None

Returns:

TypeDescription
ndarrayInitial mean values (n,)
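
The starting-value formula quoted above can be checked by hand with a few lines of NumPy. `binomial_mustart` is a hypothetical standalone re-implementation of the formula, not this function itself.

```python
import numpy as np

def binomial_mustart(y, weights=None):
    """Starting values (w*y + 0.5) / (w + 1), following the formula quoted above."""
    y = np.asarray(y, dtype=float)
    w = np.ones_like(y) if weights is None else np.asarray(weights, dtype=float)
    return (w * y + 0.5) / (w + 1.0)

start_unweighted = binomial_mustart([0.0, 1.0])                    # 0.25 and 0.75
start_weighted = binomial_mustart([0.0, 1.0], weights=[9.0, 9.0])  # 0.05 and 0.95
```

Larger prior weights pull the starting values closer to the observed proportions while still keeping them strictly inside (0, 1).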

binomial_loglik

binomial_loglik(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray

Binomial conditional log-likelihood (per observation).

Computes log p(y|μ) = y*log(μ) + (1-y)*log(1-μ) for Bernoulli trials. Uses the same numerical stability patterns as binomial_deviance.

For binomial trials with n > 1, the binomial coefficient log(n choose k) would be added, but for Bernoulli (n=1) this term is zero.

Parameters:

NameTypeDescriptionDefault
yndarrayResponse values (n,), must be in [0, 1]required
mundarrayFitted mean values (n,), must be in (0, 1)required

Returns:

TypeDescription
ndarrayPer-observation log-likelihood values (n,), NOT summed

binomial_variance

binomial_variance(mu: jnp.ndarray) -> jnp.ndarray

Binomial variance function: V(μ) = μ(1-μ).

Parameters:

NameTypeDescriptionDefault
mundarrayMean values (n,), must be in (0, 1)required

Returns:

TypeDescription
ndarrayVariance values (n,)

build_family

build_family(family_name: str, link_name: str | None = None) -> Family

Create a Family object from family and link names.

This is a convenience function that dispatches to the appropriate family factory function based on the family name string.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| family_name | str | Name of the family (“gaussian”, “binomial”, “poisson”, “gamma”). Note: “tdist” is not supported here as it requires a df parameter. | required |
| link_name | str \| None | Optional link function name. If None, uses the canonical link for each family: gaussian: “identity”; binomial: “logit”; poisson: “log”; gamma: “inverse”. | None |

Returns:

| Type | Description |
| --- | --- |
| Family | Family configuration object |

Examples:

>>> fam = build_family("gaussian")
>>> fam.name
'gaussian'
>>> fam = build_family("binomial", "probit")
>>> fam.link_name
'probit'
>>> fam = build_family("gamma", "log")
>>> fam.name
'gamma'
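
The canonical-link fallback described for link_name can be sketched as a small lookup. `resolve_link_name` is a hypothetical helper, and raising ValueError for an unsupported family is an assumption; this page does not say how build_family reports that case.

```python
# Entries copied from the CANONICAL_LINKS attribute documented above
# ("tdist" is left out because build_family does not accept it).
CANONICAL = {"gaussian": "identity", "binomial": "logit",
             "poisson": "log", "gamma": "inverse"}

def resolve_link_name(family_name, link_name=None):
    """Hypothetical helper: return the requested link, or the canonical one."""
    if family_name not in CANONICAL:
        raise ValueError(f"unsupported family: {family_name!r}")
    return CANONICAL[family_name] if link_name is None else link_name

default_gamma = resolve_link_name("gamma")              # 'inverse'
probit_choice = resolve_link_name("binomial", "probit")  # 'probit'
```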

cloglog_link

cloglog_link(mu: jnp.ndarray) -> jnp.ndarray

Complementary log-log link function: η = log(-log(1-μ)).

Used for asymmetric binary responses where P(Y=1) approaches 1 faster than it approaches 0.

Parameters:

NameTypeDescriptionDefault
mundarrayMean values (n,), must be in (0, 1)required

Returns:

TypeDescription
ndarrayLinear predictor values η (n,)

cloglog_link_deriv

cloglog_link_deriv(mu: jnp.ndarray) -> jnp.ndarray

Cloglog link derivative: dη/dμ = 1/((1-μ) * (-log(1-μ))).

Parameters:

NameTypeDescriptionDefault
mundarrayMean values (n,), must be in (0, 1)required

Returns:

TypeDescription
ndarrayDerivative values (n,)

cloglog_link_inverse

cloglog_link_inverse(eta: jnp.ndarray) -> jnp.ndarray

Cloglog inverse link: μ = 1 - exp(-exp(η)).

Parameters:

NameTypeDescriptionDefault
etandarrayLinear predictor values (n,)required

Returns:

TypeDescription
ndarrayMean values μ (n,) in (0, 1)
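
The three cloglog formulas can be checked against one another numerically. The sketch below is a standalone NumPy version; using `log1p` and `expm1` for accuracy near 0 is an assumption about good practice, not a statement about how these functions are implemented.

```python
import numpy as np

def cloglog(mu):
    return np.log(-np.log1p(-mu))               # log(-log(1 - mu))

def cloglog_inverse(eta):
    return -np.expm1(-np.exp(eta))              # 1 - exp(-exp(eta))

def cloglog_deriv(mu):
    return 1.0 / ((1.0 - mu) * -np.log1p(-mu))  # 1 / ((1 - mu) * (-log(1 - mu)))

mu = np.array([0.1, 0.5, 0.9])
eta = cloglog(mu)
mu_at_zero = cloglog_inverse(0.0)               # 1 - exp(-1), not 0.5
```

Unlike logit and probit, η = 0 does not map to μ = 0.5, which is the asymmetry mentioned in the cloglog_link description.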

gamma_deviance

gamma_deviance(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray

Gamma unit deviance: d(y, μ) = 2[-log(y/μ) + (y - μ)/μ].

Uses log-space arithmetic for numerical stability.

Parameters:

NameTypeDescriptionDefault
yndarrayResponse values (n,), must be positiverequired
mundarrayFitted mean values (n,), must be positiverequired

Returns:

TypeDescription
ndarrayUnit deviance values (n,)
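
A direct NumPy transcription of the formula is shown below as a sketch. Reading "log-space arithmetic" as computing log(y/μ) via log(y) - log(μ) is an assumption made for this example.

```python
import numpy as np

def gamma_deviance_sketch(y, mu):
    """Unit deviance 2[-log(y/mu) + (y - mu)/mu] for positive y and mu."""
    y = np.asarray(y, dtype=float)
    mu = np.asarray(mu, dtype=float)
    log_ratio = np.log(y) - np.log(mu)   # log(y/mu) as a difference of logs
    return 2.0 * (-log_ratio + (y - mu) / mu)

d = gamma_deviance_sketch([1.0, 2.0, 8.0], [2.0, 2.0, 2.0])
```

The deviance is zero when y = μ and positive otherwise.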

gamma_dispersion

gamma_dispersion(y: jnp.ndarray, mu: jnp.ndarray, df_resid: int) -> float

Estimate dispersion parameter for Gamma family.

Uses Pearson χ² / df_resid. For Gamma, the dispersion parameter is 1/shape, so the shape parameter is 1/dispersion.

Parameters:

NameTypeDescriptionDefault
yndarrayResponse values (n,)required
mundarrayFitted mean values (n,)required
df_residintResidual degrees of freedomrequired

Returns:

TypeDescription
floatDispersion estimate φ̂
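
The Pearson estimate divides the sum of squared residuals, each scaled by V(μ), by the residual degrees of freedom. Here is a minimal NumPy sketch with made-up data; `pearson_dispersion` is a hypothetical helper that takes the variance function as an argument.

```python
import numpy as np

def pearson_dispersion(y, mu, variance, df_resid):
    """Pearson chi-square statistic divided by residual degrees of freedom."""
    chi2 = np.sum((y - mu) ** 2 / variance(mu))
    return float(chi2 / df_resid)

# Illustrative data only. Gamma variance function: V(mu) = mu**2.
y = np.array([1.0, 3.0, 2.0, 6.0])
mu = np.array([2.0, 2.0, 4.0, 4.0])
phi_hat = pearson_dispersion(y, mu, lambda m: m ** 2, df_resid=2)
shape_hat = 1.0 / phi_hat   # Gamma shape is the reciprocal of the dispersion
```

Each of the four scaled squared residuals equals 0.25 here, so the estimate is 1.0 / 2 = 0.5 and the implied shape is 2.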

gamma_initialize

gamma_initialize(y: jnp.ndarray, weights: jnp.ndarray | None = None) -> jnp.ndarray

Initialize μ for Gamma family.

Uses y directly as the starting value, ensuring positive values.

Parameters:

NameTypeDescriptionDefault
yndarrayResponse values (n,), must be positiverequired
weightsndarray | NoneOptional prior weights (n,). Unused for Gamma, included for API consistency with other families.None

Returns:

TypeDescription
ndarrayInitial mean values (n,)

gamma_loglik

gamma_loglik(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray

Gamma conditional log-likelihood (per observation).

Computes log p(y|μ) for a Gamma distribution with shape parameter k and scale θ = μ/k, so that the mean is μ. Terms that do not depend on μ are ignored.

The log-density kernel is (k-1) log(y) - y/θ - k log(θ). Substituting θ = μ/k leaves k[-y/μ - log(μ)] plus terms free of μ, so the μ-dependent part is proportional to -y/μ - log(μ).

Parameters:

NameTypeDescriptionDefault
yndarrayResponse values (n,), must be positiverequired
mundarrayFitted mean values (n,), must be positiverequired

Returns:

TypeDescription
ndarrayPer-observation log-likelihood values (n,), NOT summed
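
For reference, the simplification of the Gamma log-density under the mean parameterization can be written out as follows. This is the standard shape-scale density with θ = μ/k substituted; it adds nothing beyond the description above.

```latex
\begin{aligned}
\log p(y \mid k, \theta)
  &= (k-1)\log y - \frac{y}{\theta} - k\log\theta - \log\Gamma(k),
  \qquad \theta = \frac{\mu}{k} \\
  &= k\left(-\frac{y}{\mu} - \log\mu\right)
   + \underbrace{k\log k + (k-1)\log y - \log\Gamma(k)}_{\text{free of } \mu}
\end{aligned}
```

Only the first bracket changes with μ, which is why the k-dependent terms can be dropped when optimizing over μ.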

gamma_variance

gamma_variance(mu: jnp.ndarray) -> jnp.ndarray

Gamma variance function: V(μ) = μ².

Parameters:

NameTypeDescriptionDefault
mundarrayMean values (n,), must be positiverequired

Returns:

TypeDescription
ndarrayVariance values (n,)

gaussian_deviance

gaussian_deviance(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray

Gaussian unit deviance: d(y, μ) = (y - μ)².

Parameters:

NameTypeDescriptionDefault
yndarrayResponse values (n,)required
mundarrayFitted mean values (n,)required

Returns:

TypeDescription
ndarrayUnit deviance values (n,)

gaussian_dispersion

gaussian_dispersion(y: jnp.ndarray, mu: jnp.ndarray, df_resid: int) -> float

Estimate dispersion parameter for Gaussian family.

Uses Pearson χ² / df_resid.

Parameters:

NameTypeDescriptionDefault
yndarrayResponse values (n,)required
mundarrayFitted mean values (n,)required
df_residintResidual degrees of freedomrequired

Returns:

TypeDescription
floatDispersion estimate φ̂

gaussian_initialize

gaussian_initialize(y: jnp.ndarray, weights: jnp.ndarray | None = None) -> jnp.ndarray

Initialize μ for Gaussian family.

Uses y directly as the starting value (matches R’s gaussian family).

Parameters:

NameTypeDescriptionDefault
yndarrayResponse values (n,)required
weightsndarray | NoneOptional prior weights (n,). Unused for Gaussian, included for API consistency with other families.None

Returns:

TypeDescription
ndarrayInitial mean values (n,)

gaussian_loglik

gaussian_loglik(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray

Gaussian conditional log-likelihood (per observation).

Computes log p(y|μ) = -0.5 * (y - μ)² ignoring constant terms. The full Gaussian log-likelihood includes -0.5*log(2πσ²), but this constant term cancels in optimization and is added separately when computing final log-likelihood values.

Parameters:

NameTypeDescriptionDefault
yndarrayResponse values (n,)required
mundarrayFitted mean values (n,)required

Returns:

TypeDescription
ndarrayPer-observation log-likelihood values (n,), NOT summed

gaussian_variance

gaussian_variance(mu: jnp.ndarray) -> jnp.ndarray

Gaussian variance function: V(μ) = 1.

Parameters:

NameTypeDescriptionDefault
mundarrayMean values (n,)required

Returns:

TypeDescription
ndarrayVariance values (n,), all ones

identity_link

identity_link(mu: jnp.ndarray) -> jnp.ndarray

Identity link function: η = μ.

Parameters:

NameTypeDescriptionDefault
mundarrayMean values (n,)required

Returns:

TypeDescription
ndarrayLinear predictor values η (n,)

identity_link_deriv

identity_link_deriv(mu: jnp.ndarray) -> jnp.ndarray

Identity link derivative: dη/dμ = 1.

Parameters:

NameTypeDescriptionDefault
mundarrayMean values (n,)required

Returns:

TypeDescription
ndarrayDerivative values (n,)

identity_link_inverse

identity_link_inverse(eta: jnp.ndarray) -> jnp.ndarray

Identity inverse link: μ = η.

Parameters:

NameTypeDescriptionDefault
etandarrayLinear predictor values (n,)required

Returns:

TypeDescription
ndarrayMean values μ (n,)

inverse_link

inverse_link(mu: jnp.ndarray) -> jnp.ndarray

Inverse link function: η = 1/μ.

Canonical link for Gamma family.

Parameters:

NameTypeDescriptionDefault
mundarrayMean values (n,), must be positiverequired

Returns:

TypeDescription
ndarrayLinear predictor values η (n,)

inverse_link_deriv

inverse_link_deriv(mu: jnp.ndarray) -> jnp.ndarray

Inverse link derivative: dη/dμ = -1/μ².

Parameters:

NameTypeDescriptionDefault
mundarrayMean values (n,), must be positiverequired

Returns:

TypeDescription
ndarrayDerivative values (n,)

inverse_link_inverse

inverse_link_inverse(eta: jnp.ndarray) -> jnp.ndarray

Inverse of the inverse (reciprocal) link: μ = 1/η.

Parameters:

NameTypeDescriptionDefault
etandarrayLinear predictor values (n,), must be positiverequired

Returns:

TypeDescription
ndarrayMean values μ (n,)

log_link

log_link(mu: jnp.ndarray) -> jnp.ndarray

Log link function: η = log(μ).

Parameters:

NameTypeDescriptionDefault
mundarrayMean values (n,), must be positiverequired

Returns:

TypeDescription
ndarrayLinear predictor values η (n,)

log_link_deriv

log_link_deriv(mu: jnp.ndarray) -> jnp.ndarray

Log link derivative: dη/dμ = 1/μ.

Parameters:

NameTypeDescriptionDefault
mundarrayMean values (n,), must be positiverequired

Returns:

TypeDescription
ndarrayDerivative values (n,)

log_link_inverse

log_link_inverse(eta: jnp.ndarray) -> jnp.ndarray

Log inverse link: μ = exp(η).

Parameters:

NameTypeDescriptionDefault
etandarrayLinear predictor values (n,)required

Returns:

TypeDescription
ndarrayMean values μ (n,)

logit_link

logit_link(mu: jnp.ndarray) -> jnp.ndarray

Logit link function: η = log(μ/(1-μ)).

Values are clipped to [1e-10, 1-1e-10] to avoid log(0).

Parameters:

NameTypeDescriptionDefault
mundarrayMean values (n,), must be in (0, 1)required

Returns:

TypeDescription
ndarrayLinear predictor values η (n,)

logit_link_deriv

logit_link_deriv(mu: jnp.ndarray) -> jnp.ndarray

Logit link derivative: dη/dμ = 1/(μ(1-μ)).

Values are clipped to [1e-10, 1-1e-10] to avoid division by zero.

Parameters:

NameTypeDescriptionDefault
mundarrayMean values (n,), must be in (0, 1)required

Returns:

TypeDescription
ndarrayDerivative values (n,)

logit_link_inverse

logit_link_inverse(eta: jnp.ndarray) -> jnp.ndarray

Logit inverse link: μ = 1/(1 + exp(-η)).

Uses numerically stable computation to avoid overflow.

Parameters:

NameTypeDescriptionDefault
etandarrayLinear predictor values (n,)required

Returns:

TypeDescription
ndarrayMean values μ (n,) in (0, 1)
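
A naive 1/(1 + exp(-η)) overflows inside exp for large negative η. One common stable formulation, sketched here in NumPy, only ever exponentiates a non-positive number. This is an illustration of the idea and an assumption about what "numerically stable" means on this page, not the function's actual code.

```python
import numpy as np

def sigmoid_stable(eta):
    """1/(1 + exp(-eta)) evaluated without overflowing exp for large |eta|."""
    eta = np.asarray(eta, dtype=float)
    e = np.exp(-np.abs(eta))    # argument is <= 0, so the result lies in [0, 1]
    return np.where(eta >= 0, 1.0 / (1.0 + e), e / (1.0 + e))

mu = sigmoid_stable([-1000.0, -1.0, 0.0, 1.0, 1000.0])
```

In float64 the extreme inputs saturate to exactly 0.0 and 1.0, which is one reason the logit link and its derivative clip μ away from the boundaries.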

poisson_deviance

poisson_deviance(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray

Poisson unit deviance: d(y, μ) = 2[y log(y/μ) - (y - μ)].

Uses log-space arithmetic for numerical stability.

Parameters:

NameTypeDescriptionDefault
yndarrayResponse values (n,), must be non-negativerequired
mundarrayFitted mean values (n,), must be positiverequired

Returns:

TypeDescription
ndarrayUnit deviance values (n,)
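
Because y may be 0 for count data, the y·log(y/μ) term needs the convention 0·log(0) = 0. The NumPy sketch below shows one way to apply it; the `np.where` masking is an assumption for illustration, not a description of this function's internals.

```python
import numpy as np

def poisson_deviance_sketch(y, mu):
    """Unit deviance 2[y log(y/mu) - (y - mu)] with 0*log(0) treated as 0."""
    y = np.asarray(y, dtype=float)
    mu = np.asarray(mu, dtype=float)
    # For y == 0 pass 1 to log(), so the product y * log(...) is exactly 0.
    ylogy = y * np.log(np.where(y > 0, y / mu, 1.0))
    return 2.0 * (ylogy - (y - mu))

d = poisson_deviance_sketch([0.0, 2.0, 3.0], [1.5, 2.0, 1.0])
```

A zero count contributes 2μ to the deviance, and an observation with y = μ contributes nothing.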

poisson_dispersion

poisson_dispersion(y: jnp.ndarray, mu: jnp.ndarray, df_resid: int) -> float

Dispersion parameter for Poisson family.

Fixed at 1.0 for Poisson models (can estimate for quasi-Poisson).

Parameters:

NameTypeDescriptionDefault
yndarrayResponse values (n,)required
mundarrayFitted mean values (n,)required
df_residintResidual degrees of freedom (unused)required

Returns:

TypeDescription
floatDispersion value (always 1.0)

poisson_initialize

poisson_initialize(y: jnp.ndarray, weights: jnp.ndarray | None = None) -> jnp.ndarray

Initialize μ for Poisson family.

Adds small value to avoid zero counts (matches R’s poisson family).

Parameters:

NameTypeDescriptionDefault
yndarrayResponse values (n,), must be non-negativerequired
weightsndarray | NoneOptional prior weights (n,). Unused for Poisson, included for API consistency with other families.None

Returns:

TypeDescription
ndarrayInitial mean values (n,)

poisson_loglik

poisson_loglik(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray

Poisson conditional log-likelihood (per observation).

Computes log p(y|μ) = y*log(μ) - μ - log(y!). The log(y!) term uses log-gamma: log(Γ(y+1)) = log(y!).

Uses numerical stability patterns similar to poisson_deviance.

Parameters:

NameTypeDescriptionDefault
yndarrayResponse values (n,), must be non-negativerequired
mundarrayFitted mean values (n,), must be positiverequired

Returns:

TypeDescription
ndarrayPer-observation log-likelihood values (n,), NOT summed

poisson_variance

poisson_variance(mu: jnp.ndarray) -> jnp.ndarray

Poisson variance function: V(μ) = μ.

Parameters:

NameTypeDescriptionDefault
mundarrayMean values (n,), must be positiverequired

Returns:

TypeDescription
ndarrayVariance values (n,)

probit_link

probit_link(mu: jnp.ndarray) -> jnp.ndarray

Probit link function: η = Φ⁻¹(μ).

Uses the inverse error function for numerical stability. Values are clipped to [1e-10, 1-1e-10] to avoid infinite results.

Parameters:

NameTypeDescriptionDefault
mundarrayMean values (n,), must be in (0, 1)required

Returns:

TypeDescription
ndarrayLinear predictor values η (n,)

probit_link_deriv

probit_link_deriv(mu: jnp.ndarray) -> jnp.ndarray

Probit link derivative: dη/dμ = 1/φ(Φ⁻¹(μ)).

Where φ is the standard normal PDF. Values are clipped to [1e-10, 1-1e-10] to avoid infinite results.

Parameters:

NameTypeDescriptionDefault
mundarrayMean values (n,), must be in (0, 1)required

Returns:

TypeDescription
ndarrayDerivative values (n,)

probit_link_inverse

probit_link_inverse(eta: jnp.ndarray) -> jnp.ndarray

Probit inverse link: μ = Φ(η).

Parameters:

NameTypeDescriptionDefault
etandarrayLinear predictor values (n,)required

Returns:

TypeDescription
ndarrayMean values μ (n,) in (0, 1)
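
The three probit formulas can be verified on scalars with the Python standard library alone. This sketch uses `statistics.NormalDist` and `math.erf` rather than JAX, so it illustrates the mathematics only, not the array implementation documented here.

```python
import math
from statistics import NormalDist

STD_NORMAL = NormalDist()   # mean 0, standard deviation 1

def probit(mu):
    return STD_NORMAL.inv_cdf(mu)                       # eta = Phi^{-1}(mu)

def probit_inverse(eta):
    return 0.5 * (1.0 + math.erf(eta / math.sqrt(2)))   # mu = Phi(eta)

def probit_deriv(mu):
    return 1.0 / STD_NORMAL.pdf(probit(mu))             # 1 / phi(Phi^{-1}(mu))

eta = probit(0.975)                 # close to the familiar 1.96
slope_at_half = probit_deriv(0.5)   # 1 / phi(0) = sqrt(2 * pi)
```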

resolve_sigma

resolve_sigma(sigma: float | None) -> float

Resolve optional sigma to a concrete float.

GLMs without a dispersion parameter (binomial, poisson) store sigma=None in FitState. This resolves to 1.0 for those families.

Parameters:

NameTypeDescriptionDefault
sigmafloat | NoneResidual SD from FitState, or None.required

Returns:

TypeDescription
floatThe sigma value, or 1.0 if None.

sample_response

sample_response(family: str, mu: np.ndarray, sigma: float, rng: np.random.Generator) -> np.ndarray

Sample response values from a GLM family distribution.

Given the conditional mean on the response scale (after inverse link), draws observations from the appropriate distribution.

Parameters:

NameTypeDescriptionDefault
familystrDistribution family name (“gaussian”, “binomial”, “poisson”).required
mundarrayConditional mean on the response scale, shape (n,).required
sigmafloatResidual standard deviation (used only for gaussian/tdist).required
rngGeneratorNumPy random number generator.required

Returns:

TypeDescription
ndarraySampled response values, shape (n,).
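
As a rough sketch of what sampling on the response scale can look like, the following uses a NumPy Generator. The distribution chosen for each family (normal with standard deviation sigma, Bernoulli, Poisson) and the ValueError for unknown names are assumptions for this example; `sample_response_sketch` is hypothetical and may differ from the documented function.

```python
import numpy as np

def sample_response_sketch(family, mu, sigma, rng):
    """Draw one response per mean; the distribution choices are assumptions."""
    if family == "gaussian":
        return rng.normal(loc=mu, scale=sigma)
    if family == "binomial":
        return rng.binomial(n=1, p=mu).astype(float)   # Bernoulli draws
    if family == "poisson":
        return rng.poisson(lam=mu).astype(float)
    raise ValueError(f"unsupported family: {family!r}")

rng = np.random.default_rng(0)
mu = np.array([0.2, 0.5, 0.8])
y_binary = sample_response_sketch("binomial", mu, 1.0, rng)
y_counts = sample_response_sketch("poisson", mu, 1.0, rng)
y_normal = sample_response_sketch("gaussian", mu, 0.1, rng)
```

Passing the generator in explicitly keeps the function free of hidden state, so a fixed seed reproduces the same draws.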

tdist_deviance

tdist_deviance(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray

Placeholder - use tdist(df=...) factory to get proper function.

tdist_dispersion

tdist_dispersion(y: jnp.ndarray, mu: jnp.ndarray, df_resid: int) -> float

Estimate dispersion (scale) parameter for Student-t family.

Uses MAD (median absolute deviation) for robust scale estimation: σ = MAD / 0.6745

where 0.6745 is the MAD of the standard normal distribution.

Parameters:

NameTypeDescriptionDefault
yndarrayResponse values (n,)required
mundarrayFitted mean values (n,)required
df_residintResidual degrees of freedom (unused for MAD)required

Returns:

TypeDescription
floatDispersion estimate σ̂
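
The MAD-based scale can be sketched in a few lines of NumPy. Whether the residuals are centered on their median before taking absolute values is not stated above, so the centering here is an assumption; `mad_scale` is a hypothetical helper.

```python
import numpy as np

def mad_scale(y, mu):
    """Robust scale: median absolute deviation of residuals divided by 0.6745."""
    r = np.asarray(y, dtype=float) - np.asarray(mu, dtype=float)
    mad = np.median(np.abs(r - np.median(r)))   # median centering is an assumption
    return float(mad / 0.6745)

y = np.array([-2.0, -1.0, 0.0, 1.0, 50.0])      # one gross outlier
sigma_hat = mad_scale(y, np.zeros(5))
sd_classical = float(np.std(y))                 # inflated by the outlier
```

The single outlier barely moves the MAD-based estimate, while the classical standard deviation is dominated by it.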

tdist_initialize

tdist_initialize(y: jnp.ndarray, weights: jnp.ndarray | None = None) -> jnp.ndarray

Initialize μ for Student-t family.

Uses y directly as the starting value (like Gaussian).

Parameters:

NameTypeDescriptionDefault
yndarrayResponse values (n,)required
weightsndarray | NoneOptional prior weights (n,). Unused for t-distribution.None

Returns:

TypeDescription
ndarrayInitial mean values (n,)

tdist_loglik

tdist_loglik(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray

Placeholder - use tdist(df=...) factory to get proper function.

tdist_robust_weights

tdist_robust_weights(y: jnp.ndarray, mu: jnp.ndarray, scale: float) -> jnp.ndarray

Placeholder - use tdist(df=...) factory to get proper function.

tdist_variance

tdist_variance(mu: jnp.ndarray) -> jnp.ndarray

Student-t variance function: V(μ) = 1.

Like Gaussian, the variance function is constant. The heavy-tailed behavior comes from the robust weights, not the variance function.

Parameters:

NameTypeDescriptionDefault
mundarrayMean values (n,)required

Returns:

TypeDescription
ndarrayVariance values (n,), all ones

Modules

binomial

Binomial family functions for GLM fitting.

Functions:

NameDescription
binomialCreate binomial family for binary or proportion data.
binomial_devianceBinomial unit deviance: d(y, μ) = 2[y log(y/μ) + (1-y) log((1-y)/(1-μ))].
binomial_dispersionDispersion parameter for binomial family.
binomial_initializeInitialize μ for binomial family.
binomial_loglikBinomial conditional log-likelihood (per observation).
binomial_varianceBinomial variance function: V(μ) = μ(1-μ).

Functions

binomial
binomial(link: str | None = None) -> Family

Create binomial family for binary or proportion data.

The binomial family is appropriate for binary outcomes (0/1) or proportions (successes/trials). Commonly used for logistic and probit regression.

Parameters:

NameTypeDescriptionDefault
linkstr | NoneLink function name. Options are “logit” (default) or “probit”. Defaults to “logit” if None.None

Returns:

TypeDescription
FamilyBinomial family configuration

Examples:

>>> # Logistic regression (default)
>>> fam = binomial()
>>> fam.link_name
'logit'
>>> # Probit regression
>>> fam = binomial("probit")
>>> fam.link_name
'probit'
binomial_deviance
binomial_deviance(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray

Binomial unit deviance: d(y, μ) = 2[y log(y/μ) + (1-y) log((1-y)/(1-μ))].

Uses log-space arithmetic for numerical stability.

Parameters:

NameTypeDescriptionDefault
yndarrayResponse values (n,), must be in [0, 1]required
mundarrayFitted mean values (n,), must be in (0, 1)required

Returns:

TypeDescription
ndarrayUnit deviance values (n,)
binomial_dispersion
binomial_dispersion(y: jnp.ndarray, mu: jnp.ndarray, df_resid: int) -> float

Dispersion parameter for binomial family.

Fixed at 1.0 for binomial models.

Parameters:

NameTypeDescriptionDefault
yndarrayResponse values (n,)required
mundarrayFitted mean values (n,)required
df_residintResidual degrees of freedom (unused)required

Returns:

TypeDescription
floatDispersion value (always 1.0)
binomial_initialize
binomial_initialize(y: jnp.ndarray, weights: jnp.ndarray | None = None) -> jnp.ndarray

Initialize μ for binomial family.

Uses the weighted formula from R’s stats::binomial family: mustart <- (weights * y + 0.5) / (weights + 1)

This avoids boundary values (0 or 1) while accounting for prior weights. When weights=1 (unweighted), this gives (y + 0.5) / 2.

Parameters:

NameTypeDescriptionDefault
yndarrayResponse values (n,), must be in [0, 1]required
weightsndarray | NoneOptional prior weights (n,). Defaults to 1.0 for all observations.None

Returns:

TypeDescription
ndarrayInitial mean values (n,)
binomial_loglik
binomial_loglik(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray

Binomial conditional log-likelihood (per observation).

Computes log p(y|μ) = y*log(μ) + (1-y)*log(1-μ) for Bernoulli trials. Uses the same numerical stability patterns as binomial_deviance.

For binomial trials with n > 1, the binomial coefficient log(n choose k) would be added, but for Bernoulli (n=1) this term is zero.

Parameters:

NameTypeDescriptionDefault
yndarrayResponse values (n,), must be in [0, 1]required
mundarrayFitted mean values (n,), must be in (0, 1)required

Returns:

TypeDescription
ndarrayPer-observation log-likelihood values (n,), NOT summed
binomial_variance
binomial_variance(mu: jnp.ndarray) -> jnp.ndarray

Binomial variance function: V(μ) = μ(1-μ).

Parameters:

NameTypeDescriptionDefault
mundarrayMean values (n,), must be in (0, 1)required

Returns:

TypeDescription
ndarrayVariance values (n,)

create

Family object construction from string names.

Functions:

NameDescription
build_familyCreate a Family object from family and link names.

Functions

build_family
build_family(family_name: str, link_name: str | None = None) -> Family

Create a Family object from family and link names.

This is a convenience function that dispatches to the appropriate family factory function based on the family name string.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| family_name | str | Name of the family (“gaussian”, “binomial”, “poisson”, “gamma”). Note: “tdist” is not supported here as it requires a df parameter. | required |
| link_name | str \| None | Optional link function name. If None, uses the canonical link for each family: gaussian: “identity”; binomial: “logit”; poisson: “log”; gamma: “inverse”. | None |

Returns:

| Type | Description |
| --- | --- |
| Family | Family configuration object |

Examples:

>>> fam = build_family("gaussian")
>>> fam.name
'gaussian'
>>> fam = build_family("binomial", "probit")
>>> fam.link_name
'probit'
>>> fam = build_family("gamma", "log")
>>> fam.name
'gamma'

gamma

Gamma family functions for GLM fitting.

Functions:

NameDescription
gammaCreate Gamma family for positive continuous data.
gamma_devianceGamma unit deviance: d(y, μ) = 2[-log(y/μ) + (y - μ)/μ].
gamma_dispersionEstimate dispersion parameter for Gamma family.
gamma_initializeInitialize μ for Gamma family.
gamma_loglikGamma conditional log-likelihood (per observation).
gamma_varianceGamma variance function: V(μ) = μ².

Functions

gamma
gamma(link: str | None = None) -> Family

Create Gamma family for positive continuous data.

The Gamma family is appropriate for positive continuous response data where variance increases with the mean (V(μ) = μ²). Commonly used for modeling waiting times, insurance claims, or other positive-valued data.

Parameters:

NameTypeDescriptionDefault
linkstr | NoneLink function name. Options are “inverse” (canonical, default) or “log”. Defaults to “inverse” if None.None

Returns:

TypeDescription
FamilyGamma family configuration

Examples:

>>> # Gamma with inverse link (canonical)
>>> fam = gamma()
>>> fam.name
'gamma'
>>> fam.link_name
'inverse'
>>> # Gamma with log link
>>> fam = gamma("log")
>>> fam.link_name
'log'
gamma_deviance
gamma_deviance(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray

Gamma unit deviance: d(y, μ) = 2[-log(y/μ) + (y - μ)/μ].

Uses log-space arithmetic for numerical stability.

Parameters:

NameTypeDescriptionDefault
yndarrayResponse values (n,), must be positiverequired
mundarrayFitted mean values (n,), must be positiverequired

Returns:

TypeDescription
ndarrayUnit deviance values (n,)
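
A minimal NumPy sketch of the unit deviance formula follows. It is not the module's JAX implementation; `gamma_unit_deviance` is a hypothetical name, and writing log(y/μ) as a difference of logs is only one reading of "log-space arithmetic".

```python
import numpy as np


def gamma_unit_deviance(y, mu):
    """d(y, mu) = 2 * [-log(y / mu) + (y - mu) / mu], elementwise."""
    y = np.asarray(y, dtype=float)
    mu = np.asarray(mu, dtype=float)
    # log(y / mu) computed as log(y) - log(mu)
    return 2.0 * (-(np.log(y) - np.log(mu)) + (y - mu) / mu)


# The deviance is zero where y equals mu and positive elsewhere.
d = gamma_unit_deviance([1.0, 2.0, 4.0], [2.0, 2.0, 2.0])
```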
gamma_dispersion
gamma_dispersion(y: jnp.ndarray, mu: jnp.ndarray, df_resid: int) -> float

Estimate dispersion parameter for Gamma family.

Uses Pearson χ² / df_resid. For Gamma, the dispersion parameter is 1/shape, so the shape parameter is 1/dispersion.

Parameters:

NameTypeDescriptionDefault
yndarrayResponse values (n,)required
mundarrayFitted mean values (n,)required
df_residintResidual degrees of freedomrequired

Returns:

TypeDescription
floatDispersion estimate φ̂
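
The Pearson estimate can be sketched as below, assuming the usual definition χ² = Σ (y - μ)² / V(μ). The helper name `pearson_dispersion` and the toy data are hypothetical; the module's own function takes no `variance` argument.

```python
import numpy as np


def pearson_dispersion(y, mu, df_resid, variance):
    """Pearson chi-square divided by the residual degrees of freedom."""
    y = np.asarray(y, dtype=float)
    mu = np.asarray(mu, dtype=float)
    chi2 = np.sum((y - mu) ** 2 / variance(mu))
    return float(chi2 / df_resid)


# Gamma variance function V(mu) = mu^2
phi_hat = pearson_dispersion(
    [1.0, 3.0, 2.0], [2.0, 2.0, 2.0], df_resid=2, variance=lambda m: m**2
)
shape_hat = 1.0 / phi_hat  # shape = 1 / dispersion for the Gamma family
```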
gamma_initialize
gamma_initialize(y: jnp.ndarray, weights: jnp.ndarray | None = None) -> jnp.ndarray

Initialize μ for Gamma family.

Uses y directly as the starting value, ensuring positive values.

Parameters:

NameTypeDescriptionDefault
yndarrayResponse values (n,), must be positiverequired
weightsndarray | NoneOptional prior weights (n,). Unused for Gamma, included for API consistency with other families.None

Returns:

TypeDescription
ndarrayInitial mean values (n,)
gamma_loglik
gamma_loglik(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray

Gamma conditional log-likelihood (per observation).

Computes log p(y|μ) for a Gamma distribution with shape parameter k and scale θ = μ/k, so that the mean is μ. Terms that depend only on k are dropped.

The kernel is (k-1)·log(y) - y/θ - k·log(θ). With θ = μ/k, the μ-dependent part is proportional to -y/μ - log(μ).

Parameters:

NameTypeDescriptionDefault
yndarrayResponse values (n,), must be positiverequired
mundarrayFitted mean values (n,), must be positiverequired

Returns:

TypeDescription
ndarrayPer-observation log-likelihood values (n,), NOT summed
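
The reduction to -y/μ - log(μ) can be written out as a short derivation. It assumes the standard shape/scale Gamma density; c(y, k) collects every term that does not involve μ.

```latex
\begin{aligned}
\log p(y \mid k, \theta)
  &= (k-1)\log y - \frac{y}{\theta} - k\log\theta - \log\Gamma(k) \\
  &= (k-1)\log y - \frac{k\,y}{\mu} - k\log\mu + k\log k - \log\Gamma(k)
     \qquad (\theta = \mu/k) \\
  &= k\left(-\frac{y}{\mu} - \log\mu\right) + c(y, k),
\qquad c(y, k) = (k-1)\log y + k\log k - \log\Gamma(k).
\end{aligned}
```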
gamma_variance
gamma_variance(mu: jnp.ndarray) -> jnp.ndarray

Gamma variance function: V(μ) = μ².

Parameters:

NameTypeDescriptionDefault
mundarrayMean values (n,), must be positiverequired

Returns:

TypeDescription
ndarrayVariance values (n,)

gaussian

Gaussian family functions for GLM fitting.

Functions:

| Name | Description |
| --- | --- |
| gaussian | Create Gaussian family with identity link. |
| gaussian_deviance | Gaussian unit deviance: d(y, μ) = (y - μ)². |
| gaussian_dispersion | Estimate dispersion parameter for Gaussian family. |
| gaussian_initialize | Initialize μ for Gaussian family. |
| gaussian_loglik | Gaussian conditional log-likelihood (per observation). |
| gaussian_variance | Gaussian variance function: V(μ) = 1. |

Functions

gaussian
gaussian(link: str | None = None) -> Family

Create Gaussian family with identity link.

The Gaussian family is appropriate for continuous response data with constant variance. This is equivalent to ordinary least squares (OLS) when using the identity link.

Parameters:

NameTypeDescriptionDefault
linkstr | NoneLink function name. Only “identity” supported (canonical). Defaults to “identity” if None.None

Returns:

TypeDescription
FamilyGaussian family configuration

Examples:

>>> fam = gaussian()
>>> fam.name
'gaussian'
>>> fam.link_name
'identity'
gaussian_deviance
gaussian_deviance(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray

Gaussian unit deviance: d(y, μ) = (y - μ)².

Parameters:

NameTypeDescriptionDefault
yndarrayResponse values (n,)required
mundarrayFitted mean values (n,)required

Returns:

TypeDescription
ndarrayUnit deviance values (n,)
gaussian_dispersion
gaussian_dispersion(y: jnp.ndarray, mu: jnp.ndarray, df_resid: int) -> float

Estimate dispersion parameter for Gaussian family.

Uses Pearson χ² / df_resid.

Parameters:

NameTypeDescriptionDefault
yndarrayResponse values (n,)required
mundarrayFitted mean values (n,)required
df_residintResidual degrees of freedomrequired

Returns:

TypeDescription
floatDispersion estimate φ̂
gaussian_initialize
gaussian_initialize(y: jnp.ndarray, weights: jnp.ndarray | None = None) -> jnp.ndarray

Initialize μ for Gaussian family.

Uses y directly as the starting value (matches R’s gaussian family).

Parameters:

NameTypeDescriptionDefault
yndarrayResponse values (n,)required
weightsndarray | NoneOptional prior weights (n,). Unused for Gaussian, included for API consistency with other families.None

Returns:

TypeDescription
ndarrayInitial mean values (n,)
gaussian_loglik
gaussian_loglik(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray

Gaussian conditional log-likelihood (per observation).

Computes log p(y|μ) = -0.5 * (y - μ)² ignoring constant terms. The full Gaussian log-likelihood includes -0.5*log(2πσ²), but this constant term cancels in optimization and is added separately when computing final log-likelihood values.

Parameters:

NameTypeDescriptionDefault
yndarrayResponse values (n,)required
mundarrayFitted mean values (n,)required

Returns:

TypeDescription
ndarrayPer-observation log-likelihood values (n,), NOT summed
gaussian_variance
gaussian_variance(mu: jnp.ndarray) -> jnp.ndarray

Gaussian variance function: V(μ) = 1.

Parameters:

NameTypeDescriptionDefault
mundarrayMean values (n,)required

Returns:

TypeDescription
ndarrayVariance values (n,), all ones

Link functions for GLM families.

Functions:

| Name | Description |
| --- | --- |
| apply_link | Apply link function by name: η = g(μ). |
| apply_link_deriv | Apply link function derivative by name: dη/dμ. |
| apply_link_inverse | Apply inverse link function by name: μ = g⁻¹(η). |
| apply_link_inverse_deriv | Compute derivative of inverse link function: dμ/dη = d/dη[g⁻¹(η)]. |
| cloglog_link | Complementary log-log link function: η = log(-log(1-μ)). |
| cloglog_link_deriv | Cloglog link derivative: dη/dμ = 1/((1-μ) * (-log(1-μ))). |
| cloglog_link_inverse | Cloglog inverse link: μ = 1 - exp(-exp(η)). |
| identity_link | Identity link function: η = μ. |
| identity_link_deriv | Identity link derivative: dη/dμ = 1. |
| identity_link_inverse | Identity inverse link: μ = η. |
| inverse_link | Inverse link function: η = 1/μ. |
| inverse_link_deriv | Inverse link derivative: dη/dμ = -1/μ². |
| inverse_link_inverse | Inverse link inverse: μ = 1/η. |
| log_link | Log link function: η = log(μ). |
| log_link_deriv | Log link derivative: dη/dμ = 1/μ. |
| log_link_inverse | Log inverse link: μ = exp(η). |
| logit_link | Logit link function: η = log(μ/(1-μ)). |
| logit_link_deriv | Logit link derivative: dη/dμ = 1/(μ(1-μ)). |
| logit_link_inverse | Logit inverse link: μ = 1/(1 + exp(-η)). |
| probit_link | Probit link function: η = Φ⁻¹(μ). |
| probit_link_deriv | Probit link derivative: dη/dμ = 1/φ(Φ⁻¹(μ)). |
| probit_link_inverse | Probit inverse link: μ = Φ(η). |

Attributes:

NameTypeDescription
LINK_FUNCTIONS

Attributes

LINK_FUNCTIONS = {'identity': (identity_link, identity_link_inverse, identity_link_deriv), 'log': (log_link, log_link_inverse, log_link_deriv), 'logit': (logit_link, logit_link_inverse, logit_link_deriv), 'probit': (probit_link, probit_link_inverse, probit_link_deriv), 'inverse': (inverse_link, inverse_link_inverse, inverse_link_deriv), 'cloglog': (cloglog_link, cloglog_link_inverse, cloglog_link_deriv)}
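
The mapping stores one (link, inverse link, derivative) triple per name, which suggests how name-based dispatch can work. The sketch below mimics that layout with two NumPy-backed entries; `REGISTRY` and `apply_by_name` are hypothetical stand-ins, not the module's `apply_link*` functions.

```python
import numpy as np

# Same (link, inverse link, link derivative) layout as LINK_FUNCTIONS,
# restricted to two links for brevity.
REGISTRY = {
    "identity": (lambda mu: mu, lambda eta: eta, lambda mu: np.ones_like(mu)),
    "log": (np.log, np.exp, lambda mu: 1.0 / mu),
}


def apply_by_name(link, values, which=0):
    """Look up a link by name; which is 0 (link), 1 (inverse) or 2 (deriv)."""
    if link not in REGISTRY:
        raise ValueError(f"unknown link: {link!r}")
    return REGISTRY[link][which](np.asarray(values, dtype=float))


mu = np.array([0.5, 1.0, 2.0])
eta = apply_by_name("log", mu, which=0)
mu_back = apply_by_name("log", eta, which=1)  # round trip recovers mu
deriv = apply_by_name("log", mu, which=2)     # 1 / mu
```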

Functions

apply_link(link: str, mu: 'np.ndarray') -> 'np.ndarray'

Apply link function by name: η = g(μ).

Parameters:

NameTypeDescriptionDefault
linkstrLink function name. One of: identity, log, logit, probit, inverse, cloglog.required
mu‘np.ndarray’Mean values (response scale).required

Returns:

TypeDescription
‘np.ndarray’Linear predictor values η.

Examples:

>>> import numpy as np
>>> mu = np.array([0.2, 0.5, 0.8])
>>> apply_link("logit", mu)
array([-1.386,  0.   ,  1.386])
apply_link_deriv(link: str, mu: 'np.ndarray') -> 'np.ndarray'

Apply link function derivative by name: dη/dμ.

Parameters:

NameTypeDescriptionDefault
linkstrLink function name. One of: identity, log, logit, probit, inverse, cloglog.required
mu‘np.ndarray’Mean values (response scale).required

Returns:

TypeDescription
‘np.ndarray’Derivative values dη/dμ.

Examples:

>>> import numpy as np
>>> mu = np.array([0.2, 0.5, 0.8])
>>> apply_link_deriv("logit", mu)
array([6.25, 4.  , 6.25])
apply_link_inverse(link: str, eta: 'np.ndarray') -> 'np.ndarray'

Apply inverse link function by name: μ = g⁻¹(η).

Parameters:

NameTypeDescriptionDefault
linkstrLink function name. One of: identity, log, logit, probit, inverse, cloglog.required
eta‘np.ndarray’Linear predictor values.required

Returns:

TypeDescription
‘np.ndarray’Mean values μ (response scale).

Examples:

>>> import numpy as np
>>> eta = np.array([-1, 0, 1])
>>> apply_link_inverse("logit", eta)
array([0.269, 0.5  , 0.731])
apply_link_inverse_deriv(link: str, eta: 'np.ndarray') -> 'np.ndarray'

Compute derivative of inverse link function: dμ/dη = d/dη[g⁻¹(η)].

Used for delta method transformation from link to data scale: SE_data = |dμ/dη| × SE_link.

The derivative is computed as 1 / (dη/dμ) evaluated at μ = g⁻¹(η). For logit this is μ(1-μ), for log it is exp(η), and for identity it is 1.

Parameters:

NameTypeDescriptionDefault
linkstrLink function name. One of: identity, log, logit, probit, inverse, cloglog.required
eta‘np.ndarray’Linear predictor values.required

Returns:

TypeDescription
‘np.ndarray’Derivative values dμ/dη.
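
The delta-method use case can be illustrated for the logit link. This is a NumPy sketch under assumed inputs: the standard errors on the link scale are made up, and the two helper functions are local re-implementations rather than calls into the module.

```python
import numpy as np


def logit_inverse(eta):
    return 1.0 / (1.0 + np.exp(-eta))


def logit_inverse_deriv(eta):
    # d(mu)/d(eta) = mu * (1 - mu) for the logit link
    mu = logit_inverse(eta)
    return mu * (1.0 - mu)


eta = np.array([-1.0, 0.0, 1.0])       # predictions on the link scale
se_link = np.array([0.2, 0.2, 0.2])    # hypothetical link-scale SEs
se_data = np.abs(logit_inverse_deriv(eta)) * se_link
```

The same link-scale uncertainty maps to a smaller response-scale standard error away from η = 0, because the logistic curve flattens there.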
cloglog_link(mu: jnp.ndarray) -> jnp.ndarray

Complementary log-log link function: η = log(-log(1-μ)).

Used for asymmetric binary responses where P(Y=1) approaches 1 faster than it approaches 0.

Parameters:

NameTypeDescriptionDefault
mundarrayMean values (n,), must be in (0, 1)required

Returns:

TypeDescription
ndarrayLinear predictor values η (n,)
cloglog_link_deriv(mu: jnp.ndarray) -> jnp.ndarray

Cloglog link derivative: dη/dμ = 1/((1-μ) * (-log(1-μ))).

Parameters:

NameTypeDescriptionDefault
mundarrayMean values (n,), must be in (0, 1)required

Returns:

TypeDescription
ndarrayDerivative values (n,)
cloglog_link_inverse(eta: jnp.ndarray) -> jnp.ndarray

Cloglog inverse link: μ = 1 - exp(-exp(η)).

Parameters:

NameTypeDescriptionDefault
etandarrayLinear predictor values (n,)required

Returns:

TypeDescription
ndarrayMean values μ (n,) in (0, 1)
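
A short NumPy sketch of the cloglog pair makes the asymmetry visible. The function names are local to the example, and using `log1p`/`expm1` is a choice made here, not a documented detail of the module.

```python
import numpy as np


def cloglog(mu):
    # eta = log(-log(1 - mu))
    return np.log(-np.log1p(-mu))


def cloglog_inverse(eta):
    # mu = 1 - exp(-exp(eta))
    return -np.expm1(-np.exp(eta))


eta = np.array([-2.0, 0.0, 2.0])
mu = cloglog_inverse(eta)
eta_back = cloglog(mu)  # round trip recovers eta
```

Unlike the logit, η = 0 does not map to 0.5 but to 1 - exp(-1), and the values at η = -2 and η = 2 do not sum to 1.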
identity_link(mu: jnp.ndarray) -> jnp.ndarray

Identity link function: η = μ.

Parameters:

NameTypeDescriptionDefault
mundarrayMean values (n,)required

Returns:

TypeDescription
ndarrayLinear predictor values η (n,)
identity_link_deriv(mu: jnp.ndarray) -> jnp.ndarray

Identity link derivative: dη/dμ = 1.

Parameters:

NameTypeDescriptionDefault
mundarrayMean values (n,)required

Returns:

TypeDescription
ndarrayDerivative values (n,)
identity_link_inverse(eta: jnp.ndarray) -> jnp.ndarray

Identity inverse link: μ = η.

Parameters:

NameTypeDescriptionDefault
etandarrayLinear predictor values (n,)required

Returns:

TypeDescription
ndarrayMean values μ (n,)
inverse_link(mu: jnp.ndarray) -> jnp.ndarray

Inverse link function: η = 1/μ.

Canonical link for Gamma family.

Parameters:

NameTypeDescriptionDefault
mundarrayMean values (n,), must be positiverequired

Returns:

TypeDescription
ndarrayLinear predictor values η (n,)
inverse_link_deriv(mu: jnp.ndarray) -> jnp.ndarray

Inverse link derivative: dη/dμ = -1/μ².

Parameters:

NameTypeDescriptionDefault
mundarrayMean values (n,), must be positiverequired

Returns:

TypeDescription
ndarrayDerivative values (n,)
inverse_link_inverse(eta: jnp.ndarray) -> jnp.ndarray

Inverse link inverse: μ = 1/η.

Parameters:

NameTypeDescriptionDefault
etandarrayLinear predictor values (n,), must be positiverequired

Returns:

TypeDescription
ndarrayMean values μ (n,)
log_link(mu: jnp.ndarray) -> jnp.ndarray

Log link function: η = log(μ).

Parameters:

NameTypeDescriptionDefault
mundarrayMean values (n,), must be positiverequired

Returns:

TypeDescription
ndarrayLinear predictor values η (n,)
log_link_deriv(mu: jnp.ndarray) -> jnp.ndarray

Log link derivative: dη/dμ = 1/μ.

Parameters:

NameTypeDescriptionDefault
mundarrayMean values (n,), must be positiverequired

Returns:

TypeDescription
ndarrayDerivative values (n,)
log_link_inverse(eta: jnp.ndarray) -> jnp.ndarray

Log inverse link: μ = exp(η).

Parameters:

NameTypeDescriptionDefault
etandarrayLinear predictor values (n,)required

Returns:

TypeDescription
ndarrayMean values μ (n,)
logit_link(mu: jnp.ndarray) -> jnp.ndarray

Logit link function: η = log(μ/(1-μ)).

Values are clipped to [1e-10, 1-1e-10] to avoid log(0).

Parameters:

NameTypeDescriptionDefault
mundarrayMean values (n,), must be in (0, 1)required

Returns:

TypeDescription
ndarrayLinear predictor values η (n,)
logit_link_deriv(mu: jnp.ndarray) -> jnp.ndarray

Logit link derivative: dη/dμ = 1/(μ(1-μ)).

Values are clipped to [1e-10, 1-1e-10] to avoid division by zero.

Parameters:

NameTypeDescriptionDefault
mundarrayMean values (n,), must be in (0, 1)required

Returns:

TypeDescription
ndarrayDerivative values (n,)
logit_link_inverse(eta: jnp.ndarray) -> jnp.ndarray

Logit inverse link: μ = 1/(1 + exp(-η)).

Uses numerically stable computation to avoid overflow.

Parameters:

NameTypeDescriptionDefault
etandarrayLinear predictor values (n,)required

Returns:

TypeDescription
ndarrayMean values μ (n,) in (0, 1)
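
One common way to evaluate the logistic function without overflow is sketched below. Whether the module uses this exact formulation is not stated; `stable_logistic` is a hypothetical name.

```python
import numpy as np


def stable_logistic(eta):
    """Evaluate 1 / (1 + exp(-eta)) without exponentiating a large positive number."""
    eta = np.asarray(eta, dtype=float)
    z = np.exp(-np.abs(eta))  # exponent is never positive, so no overflow
    # eta >= 0: 1 / (1 + exp(-eta)); eta < 0: exp(eta) / (1 + exp(eta))
    return np.where(eta >= 0, 1.0 / (1.0 + z), z / (1.0 + z))


mu = stable_logistic([-1000.0, -1.0, 0.0, 1.0, 1000.0])
```

In double precision the extreme inputs saturate to exactly 0.0 and 1.0, so downstream code that takes logs of μ still needs clipping.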
probit_link(mu: jnp.ndarray) -> jnp.ndarray

Probit link function: η = Φ⁻¹(μ).

Uses the inverse error function for numerical stability. Values are clipped to [1e-10, 1-1e-10] to avoid infinite results.

Parameters:

NameTypeDescriptionDefault
mundarrayMean values (n,), must be in (0, 1)required

Returns:

TypeDescription
ndarrayLinear predictor values η (n,)
probit_link_deriv(mu: jnp.ndarray) -> jnp.ndarray

Probit link derivative: dη/dμ = 1/φ(Φ⁻¹(μ)).

Where φ is the standard normal PDF. Values are clipped to [1e-10, 1-1e-10] to avoid infinite results.

Parameters:

NameTypeDescriptionDefault
mundarrayMean values (n,), must be in (0, 1)required

Returns:

TypeDescription
ndarrayDerivative values (n,)
probit_link_inverse(eta: jnp.ndarray) -> jnp.ndarray

Probit inverse link: μ = Φ(η).

Parameters:

NameTypeDescriptionDefault
etandarrayLinear predictor values (n,)required

Returns:

TypeDescription
ndarrayMean values μ (n,) in (0, 1)
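
The three probit formulas can be checked with the Python standard library alone. This scalar sketch swaps in `statistics.NormalDist.inv_cdf` for the inverse error function mentioned above, so it illustrates the math rather than the module's implementation.

```python
import math
from statistics import NormalDist

_STD_NORMAL = NormalDist()  # mean 0, standard deviation 1


def probit(mu):
    return _STD_NORMAL.inv_cdf(mu)  # eta = Phi^{-1}(mu)


def probit_inverse(eta):
    return 0.5 * (1.0 + math.erf(eta / math.sqrt(2.0)))  # mu = Phi(eta)


def probit_deriv(mu):
    return 1.0 / _STD_NORMAL.pdf(probit(mu))  # 1 / phi(Phi^{-1}(mu))


round_trip = probit_inverse(probit(0.8))
deriv_at_half = probit_deriv(0.5)  # equals sqrt(2 * pi)
```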

poisson

Poisson family functions for GLM fitting.

Functions:

| Name | Description |
| --- | --- |
| poisson | Create Poisson family for count data. |
| poisson_deviance | Poisson unit deviance: d(y, μ) = 2[y log(y/μ) - (y - μ)]. |
| poisson_dispersion | Dispersion parameter for Poisson family. |
| poisson_initialize | Initialize μ for Poisson family. |
| poisson_loglik | Poisson conditional log-likelihood (per observation). |
| poisson_variance | Poisson variance function: V(μ) = μ. |

Functions

poisson
poisson(link: str | None = None) -> Family

Create Poisson family for count data.

The Poisson family is appropriate for count data where the variance equals the mean. Commonly used for modeling event counts, frequencies, or rates.

Parameters:

NameTypeDescriptionDefault
linkstr | NoneLink function name. Only “log” supported (canonical). Defaults to “log” if None.None

Returns:

TypeDescription
FamilyPoisson family configuration

Examples:

>>> fam = poisson()
>>> fam.name
'poisson'
>>> fam.link_name
'log'
poisson_deviance
poisson_deviance(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray

Poisson unit deviance: d(y, μ) = 2[y log(y/μ) - (y - μ)].

Uses log-space arithmetic for numerical stability.

Parameters:

NameTypeDescriptionDefault
yndarrayResponse values (n,), must be non-negativerequired
mundarrayFitted mean values (n,), must be positiverequired

Returns:

TypeDescription
ndarrayUnit deviance values (n,)
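
Since y may be zero, the y·log(y/μ) term needs care: its limit as y → 0 is 0, but a naive evaluation produces NaN. The NumPy sketch below shows one way to guard it; the guard shown is an assumption about what "numerical stability" covers, and `poisson_unit_deviance` is a hypothetical name.

```python
import numpy as np


def poisson_unit_deviance(y, mu):
    """d(y, mu) = 2 * [y * log(y / mu) - (y - mu)], with 0 * log(0) taken as 0."""
    y = np.asarray(y, dtype=float)
    mu = np.asarray(mu, dtype=float)
    safe_y = np.where(y > 0, y, 1.0)  # placeholder so log() never sees zero
    y_log_term = np.where(y > 0, y * (np.log(safe_y) - np.log(mu)), 0.0)
    return 2.0 * (y_log_term - (y - mu))


d = poisson_unit_deviance([0.0, 2.0, 5.0], [1.5, 2.0, 3.0])
```

For y = 0 the deviance reduces to 2μ, which is 3.0 for the first observation.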
poisson_dispersion
poisson_dispersion(y: jnp.ndarray, mu: jnp.ndarray, df_resid: int) -> float

Dispersion parameter for Poisson family.

Fixed at 1.0 for Poisson models (it would be estimated for a quasi-Poisson model).

Parameters:

NameTypeDescriptionDefault
yndarrayResponse values (n,)required
mundarrayFitted mean values (n,)required
df_residintResidual degrees of freedom (unused)required

Returns:

TypeDescription
floatDispersion value (always 1.0)
poisson_initialize
poisson_initialize(y: jnp.ndarray, weights: jnp.ndarray | None = None) -> jnp.ndarray

Initialize μ for Poisson family.

Adds a small value to y so that the starting μ is never zero (matches R’s poisson family).

Parameters:

NameTypeDescriptionDefault
yndarrayResponse values (n,), must be non-negativerequired
weightsndarray | NoneOptional prior weights (n,). Unused for Poisson, included for API consistency with other families.None

Returns:

TypeDescription
ndarrayInitial mean values (n,)
poisson_loglik
poisson_loglik(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray

Poisson conditional log-likelihood (per observation).

Computes log p(y|μ) = y*log(μ) - μ - log(y!). The log(y!) term uses log-gamma: log(Γ(y+1)) = log(y!).

Uses numerical stability patterns similar to poisson_deviance.

Parameters:

NameTypeDescriptionDefault
yndarrayResponse values (n,), must be non-negativerequired
mundarrayFitted mean values (n,), must be positiverequired

Returns:

TypeDescription
ndarrayPer-observation log-likelihood values (n,), NOT summed
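
The formula translates directly into a few lines. This sketch uses `math.lgamma` from the standard library for log(y!) and expects 1-D input; it is not the module's JAX code, and `poisson_loglik_sketch` is a hypothetical name.

```python
import math

import numpy as np


def poisson_loglik_sketch(y, mu):
    """Per-observation y * log(mu) - mu - log(y!), not summed."""
    y = np.asarray(y, dtype=float)
    mu = np.asarray(mu, dtype=float)
    log_factorial = np.array([math.lgamma(v + 1.0) for v in y])  # log(y!)
    return y * np.log(mu) - mu - log_factorial


ll = poisson_loglik_sketch([0.0, 2.0], [1.0, 2.0])
```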
poisson_variance
poisson_variance(mu: jnp.ndarray) -> jnp.ndarray

Poisson variance function: V(μ) = μ.

Parameters:

NameTypeDescriptionDefault
mundarrayMean values (n,), must be positiverequired

Returns:

TypeDescription
ndarrayVariance values (n,)

response

Response sampling and sigma resolution for GLM families.

Provides sample_response to draw response values from a GLM family distribution given the mean on the response scale, and resolve_sigma to handle optional dispersion parameters.

These utilities eliminate triplicated family-dispatch code in simulation, prediction, and varying-state operations.

Functions:

| Name | Description |
| --- | --- |
| resolve_sigma | Resolve optional sigma to a concrete float. |
| sample_response | Sample response values from a GLM family distribution. |

Attributes:

NameTypeDescription
CANONICAL_LINKSdict[str, str]

Attributes

CANONICAL_LINKS: dict[str, str] = {'gaussian': 'identity', 'binomial': 'logit', 'poisson': 'log', 'gamma': 'inverse', 'tdist': 'identity'}

Functions

resolve_sigma
resolve_sigma(sigma: float | None) -> float

Resolve optional sigma to a concrete float.

GLMs without a dispersion parameter (binomial, poisson) store sigma=None in FitState. This resolves to 1.0 for those families.

Parameters:

NameTypeDescriptionDefault
sigmafloat | NoneResidual SD from FitState, or None.required

Returns:

TypeDescription
floatThe sigma value, or 1.0 if None.
sample_response
sample_response(family: str, mu: np.ndarray, sigma: float, rng: np.random.Generator) -> np.ndarray

Sample response values from a GLM family distribution.

Given the conditional mean on the response scale (after inverse link), draws observations from the appropriate distribution.

Parameters:

NameTypeDescriptionDefault
familystrDistribution family name (“gaussian”, “binomial”, “poisson”).required
mundarrayConditional mean on the response scale, shape (n,).required
sigmafloatResidual standard deviation (used only for gaussian/tdist).required
rngGeneratorNumPy random number generator.required

Returns:

TypeDescription
ndarraySampled response values, shape (n,).
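
A family dispatch of this kind might look like the sketch below. It is a guess at the shape of the logic, not the documented implementation: the function name is hypothetical, the binomial branch assumes Bernoulli (single-trial) draws, and only the three families named in the parameter table are handled.

```python
import numpy as np


def sample_response_sketch(family, mu, sigma, rng):
    """Draw one response per element of mu from the named family."""
    mu = np.asarray(mu, dtype=float)
    if family == "gaussian":
        return rng.normal(loc=mu, scale=sigma)
    if family == "binomial":
        # Assumption: one Bernoulli trial per observation.
        return rng.binomial(1, mu).astype(float)
    if family == "poisson":
        return rng.poisson(mu).astype(float)
    raise ValueError(f"unsupported family: {family!r}")


rng = np.random.default_rng(0)
sigma = None
sigma = 1.0 if sigma is None else float(sigma)  # same rule as resolve_sigma
draws = sample_response_sketch("poisson", [1.0, 4.0, 9.0], sigma, rng)
```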

schema

Family configuration dataclass.

Classes:

NameDescription
FamilyFamily configuration for GLM fitting.

Attributes:

NameTypeDescription
ESTIMATED_DISPERSION_FAMILIESfrozenset[str]

Attributes

ESTIMATED_DISPERSION_FAMILIES
ESTIMATED_DISPERSION_FAMILIES: frozenset[str] = frozenset({'gaussian', 'gamma', 'tdist'})

Classes

Family
Family(name: str, link_name: str, link: Callable[[jnp.ndarray], jnp.ndarray], link_inverse: Callable[[jnp.ndarray], jnp.ndarray], link_deriv: Callable[[jnp.ndarray], jnp.ndarray], variance: Callable[[jnp.ndarray], jnp.ndarray], deviance: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray], loglik: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray], initialize: Callable[..., jnp.ndarray], dispersion: Callable[[jnp.ndarray, jnp.ndarray, int], float], df: int | None = None, robust_weights: Callable[[jnp.ndarray, jnp.ndarray, float], jnp.ndarray] | None = None) -> None

Family configuration for GLM fitting.

All functions are pure JAX operations, enabling JIT compilation and automatic differentiation. This is a simple data container with no methods.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| name | str | Family name (e.g., “gaussian”, “binomial”, “poisson”, “tdist”) |
| link_name | str | Link function name (e.g., “identity”, “logit”, “log”) |
| link | Callable[[ndarray], ndarray] | Link function η = g(μ) |
| link_inverse | Callable[[ndarray], ndarray] | Inverse link μ = g⁻¹(η) |
| link_deriv | Callable[[ndarray], ndarray] | Link derivative dη/dμ |
| variance | Callable[[ndarray], ndarray] | Variance function V(μ) |
| deviance | Callable[[ndarray, ndarray], ndarray] | Unit deviance function d(y, μ) |
| loglik | Callable[[ndarray, ndarray], ndarray] | Conditional log-likelihood function log p(y \| μ) |
| initialize | Callable[..., ndarray] | Initialization function for starting μ values. Signature: (y, weights=None) -> mu_init. The weights parameter is optional and only used by the binomial family. |
| dispersion | Callable[[ndarray, ndarray, int], float] | Dispersion parameter estimation function |
| df | int \| None | Degrees of freedom for t-distribution family (None for others) |
| robust_weights | Callable[[ndarray, ndarray, float], ndarray] \| None | Optional function for residual-based weights (t-dist). Signature: (y, mu, scale) -> weights. Returns multiplicative weights that downweight outliers. None for standard families. |
Attributes
deviance
deviance: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
df
df: int | None = None
dispersion
dispersion: Callable[[jnp.ndarray, jnp.ndarray, int], float]
initialize
initialize: Callable[..., jnp.ndarray]
link
link: Callable[[jnp.ndarray], jnp.ndarray]
link_deriv
link_deriv: Callable[[jnp.ndarray], jnp.ndarray]
link_inverse
link_inverse: Callable[[jnp.ndarray], jnp.ndarray]
link_name
link_name: str
loglik
loglik: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
name
name: str
robust_weights
robust_weights: Callable[[jnp.ndarray, jnp.ndarray, float], jnp.ndarray] | None = None
variance
variance: Callable[[jnp.ndarray], jnp.ndarray]
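
To show how the link, link derivative, and variance callables fit together, here is a textbook IRLS update written in NumPy. The fitting routine itself is not documented in this section, so `irls_step` and the toy data are assumptions; plain functions stand in for the corresponding `Family` fields.

```python
import numpy as np


def irls_step(X, y, mu, link, link_deriv, variance):
    """One weighted least-squares update of the coefficients."""
    g = link_deriv(mu)                     # d(eta)/d(mu)
    z = link(mu) + (y - mu) * g            # working response
    w = 1.0 / (variance(mu) * g**2)        # working weights
    XtW = X.T * w                          # scales each observation by its weight
    return np.linalg.solve(XtW @ X, XtW @ z)


# Poisson family with log link on data that satisfy y = 2**x exactly.
X = np.column_stack([np.ones(4), np.arange(4.0)])
y = np.array([1.0, 2.0, 4.0, 8.0])
mu = y + 0.1                               # simple positive starting values
for _ in range(25):
    beta = irls_step(X, y, mu, np.log, lambda m: 1.0 / m, lambda m: m)
    mu = np.exp(X @ beta)                  # inverse link
```

Because the toy responses lie exactly on exp(x·log 2), the iteration settles at an intercept of 0 and a slope of log 2.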

tdist

Student-t family functions for robust GLM fitting.

Functions:

| Name | Description |
| --- | --- |
| tdist | Create Student-t family for robust regression. |
| tdist_deviance | Placeholder - use tdist(df=...) factory to get proper function. |
| tdist_dispersion | Estimate dispersion (scale) parameter for Student-t family. |
| tdist_initialize | Initialize μ for Student-t family. |
| tdist_loglik | Placeholder - use tdist(df=...) factory to get proper function. |
| tdist_robust_weights | Placeholder - use tdist(df=...) factory to get proper function. |
| tdist_variance | Student-t variance function: V(μ) = 1. |

Functions

tdist
tdist(df: int, link: str | None = None) -> Family

Create Student-t family for robust regression.

The Student-t family provides robust regression by using a t-distribution for the errors instead of Gaussian. This downweights outliers through iteratively reweighted least squares (IRLS).

The t-distribution has heavier tails than the Gaussian.

To set df automatically from the residual degrees of freedom (n - p), pass family="tdist" to glm(); df is then set at fit() time.

Parameters:

NameTypeDescriptionDefault
dfintDegrees of freedom. Must be > 0. Typically set to n - p (residual degrees of freedom) for proper inference.required
linkstr | NoneLink function name. Only “identity” supported. Defaults to “identity” if None.None

Returns:

TypeDescription
FamilyStudent-t family configuration with robust weights.

Examples:

>>> # Robust regression with df=10
>>> fam = tdist(df=10)
>>> fam.name
'tdist'
>>> # In practice, use with glm (df set automatically)
>>> model = glm("y ~ x", data=df, family="tdist").fit()
tdist_deviance
tdist_deviance(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray

Placeholder - use tdist(df=...) factory to get proper function.

tdist_dispersion
tdist_dispersion(y: jnp.ndarray, mu: jnp.ndarray, df_resid: int) -> float

Estimate dispersion (scale) parameter for Student-t family.

Uses MAD (median absolute deviation) for robust scale estimation: σ = MAD / 0.6745

where 0.6745 is the MAD of the standard normal distribution.

Parameters:

NameTypeDescriptionDefault
yndarrayResponse values (n,)required
mundarrayFitted mean values (n,)required
df_residintResidual degrees of freedom (unused for MAD)required

Returns:

TypeDescription
floatDispersion estimate σ̂
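
The MAD-based scale can be sketched in NumPy as follows. Centering the residuals on their median is an assumption (the description above does not say whether the residuals are centered), and `mad_scale` is a hypothetical name.

```python
import numpy as np


def mad_scale(y, mu):
    """Robust scale estimate: MAD of the residuals divided by 0.6745."""
    r = np.asarray(y, dtype=float) - np.asarray(mu, dtype=float)
    mad = np.median(np.abs(r - np.median(r)))
    return float(mad / 0.6745)


# One gross outlier (50.0) leaves the estimate untouched.
sigma_hat = mad_scale([-2.0, -1.0, 0.0, 1.0, 50.0], np.zeros(5))
sigma_clean = mad_scale([-2.0, -1.0, 0.0, 1.0, 2.0], np.zeros(5))
```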
tdist_initialize
tdist_initialize(y: jnp.ndarray, weights: jnp.ndarray | None = None) -> jnp.ndarray

Initialize μ for Student-t family.

Uses y directly as the starting value (like Gaussian).

Parameters:

NameTypeDescriptionDefault
yndarrayResponse values (n,)required
weightsndarray | NoneOptional prior weights (n,). Unused for t-distribution.None

Returns:

TypeDescription
ndarrayInitial mean values (n,)
tdist_loglik
tdist_loglik(y: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray

Placeholder - use tdist(df=...) factory to get proper function.

tdist_robust_weights
tdist_robust_weights(y: jnp.ndarray, mu: jnp.ndarray, scale: float) -> jnp.ndarray

Placeholder - use tdist(df=...) factory to get proper function.
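
Because the real weight function is only available through the tdist(df=...) factory, the following is an illustration of the general idea and not the module's formula. It uses a weight commonly paired with Student-t errors in IRLS, w = (df + 1) / (df + u²) with u the scaled residual; the function name and inputs are hypothetical.

```python
import numpy as np


def t_weights_sketch(y, mu, scale, df):
    """Assumed textbook form: weights shrink as the scaled residual grows."""
    u = (np.asarray(y, dtype=float) - np.asarray(mu, dtype=float)) / scale
    return (df + 1.0) / (df + u**2)


w = t_weights_sketch([0.0, 1.0, 10.0], np.zeros(3), scale=1.0, df=10)
```

With df = 10, a residual ten scale units from the fit receives roughly a tenth of the weight of a residual one scale unit away.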

tdist_variance
tdist_variance(mu: jnp.ndarray) -> jnp.ndarray

Student-t variance function: V(μ) = 1.

Like Gaussian, the variance function is constant. The heavy-tailed behavior comes from the robust weights, not the variance function.

Parameters:

NameTypeDescriptionDefault
mundarrayMean values (n,)required

Returns:

TypeDescription
ndarrayVariance values (n,), all ones