Marginal effects and estimated marginal means computation.
Call chain:
model.explore() -> build_reference_grid() -> compute_emm() / compute_slopes() -> compute_contrasts()
Classes:
| Name | Description |
|---|---|
Condition | A conditioning specification in explore formula. |
ExploreFormulaError | Error in explore formula syntax. |
ExploreFormulaSpec | Parsed explore formula. |
ResolvedConditions | Typed buckets for resolved conditioning specifications. |
Functions:
| Name | Description |
|---|---|
apply_bracket_contrasts | Apply bracket contrast expression to an EMM MeeState. |
apply_bracket_contrasts_grouped | Apply bracket contrasts within each condition group of a crossed MeeState. |
apply_contrasts | Apply contrast matrix to marginal means/effects. |
apply_contrasts_grouped | Apply contrasts within each condition group of a crossed MeeState. |
apply_rhs_bracket_contrast | Apply a bracket contrast on a RHS condition column. |
build_all_pairwise_matrix | Build all pairwise contrasts between EMM levels. |
build_bracket_contrast_matrix | Build contrast matrix and labels from bracket contrast expression. |
build_contrast_matrix | Build a contrast matrix based on contrast type. |
build_helmert_matrix | Build Helmert contrasts (each level vs mean of previous levels). |
build_pairwise_matrix | Build (n-1) linearly independent pairwise contrasts. |
build_poly_matrix | Build orthogonal polynomial contrast matrix for EMMs. |
build_reference_grid | Construct reference grid for marginal effects evaluation. |
build_sequential_matrix | Build sequential (successive differences) contrasts. |
build_sum_to_zero_matrix | Build sum-to-zero contrasts (deviation coding). |
build_treatment_matrix | Build treatment (Dunnett-style) contrasts against a reference level. |
combine_resolved | Merge two ResolvedConditions, with b taking precedence on conflicts. |
compose_contrast_matrix | Compose contrast matrix with prediction matrix. |
compute_compound_bracket_contrasts | Compute bracket contrasts for a compound focal variable. |
compute_conditional_emm | Compute per-group conditional EMMs incorporating intercept BLUPs. |
compute_conditional_slopes | Compute per-group conditional slopes incorporating BLUPs. |
compute_contrasts | Apply contrast matrix to EMMs. |
compute_emm | Compute estimated marginal means for a categorical focal variable. |
compute_joint_test | Compute joint hypothesis tests for model terms. |
compute_mee_inference | Compute delta method inference for marginal effects. |
compute_mee_inference_fallback | Compute inference for MEE without L_matrix (fallback path). |
compute_mee_se | Compute standard errors for MEE estimates (means or slopes). |
compute_slopes | Compute marginal slope for a continuous focal variable. |
compute_slopes_crossed | Compute crossed slopes over focal variable x condition grid. |
compute_slopes_finite_diff | Compute marginal slopes via centered finite differences. |
dispatch_marginal_computation | Route a parsed explore formula to the appropriate marginal computation. |
get_contrast_labels | Generate human-readable labels for contrasts. |
parse_explore_formula | Parse an explore formula string. |
resolve_conditions | Classify each Condition into the appropriate typed bucket. |
Modules:
| Name | Description |
|---|---|
bracket_contrasts | Bracket contrast matrix builder and application. |
compute | Marginal effects dispatch and routing. |
conditions | Condition resolution for explore formula RHS conditioning. |
contrasts | Contrast computation for marginal effects. |
emm | Estimated marginal means computation. |
explore | Explore formula parser. |
explore_parser | Explore formula recursive descent parser. |
explore_scanner | Explore formula scanner/tokenizer. |
factors | Factor level extraction and handling utilities. |
grid | Reference grid construction for marginal effects. |
inference | Inference for marginal effects using the delta method. |
joint_tests | ANOVA-style joint hypothesis tests for model terms. |
matrices | Contrast matrix builders for EMM comparisons. |
resolve | Resolution helpers for marginal effects dispatch. |
slopes | Marginal slopes computation. |
transforms | Pure transform operations for result DataFrames. |
validation | Validation guards for marginal effects operations. |
Classes¶
Condition¶
A conditioning specification in explore formula.
Represents a variable to condition on, optionally with specific values, a method for generating values, or a bracket contrast expression.
Created by: parse_explore_formula(), ExploreParser
Consumed by: resolve_all_conditions(), dispatch_marginal_computation()
Augmented by: Never
Attributes:
| Name | Type | Description |
|---|---|---|
var | str | Variable name to condition on. |
at_values | tuple | None | Specific values to evaluate at (e.g., (50.0,) or (“A”, “B”)). |
at_range | int | None | Number of evenly-spaced values across the variable’s range. |
at_quantile | int | None | Number of quantile values to use. |
contrast_expr | ContrastExpr | None | Bracket contrast expression on this condition variable (e.g., from Dose[High - Low] on the RHS). When set, the variable is treated as a grid categorical during condition resolution, and the contrast is applied as a post-processing step. |
Attributes¶
at_quantile¶
at_quantile: int | None = field(default=None, validator=is_optional_positive_int)
at_range¶
at_range: int | None = field(default=None, validator=is_optional_positive_int)
at_values¶
at_values: tuple | None = field(default=None)
contrast_expr¶
contrast_expr: ContrastExpr | None = field(default=None, validator=(validators.optional(validators.instance_of(ContrastExpr))))
var¶
var: str = field(validator=(validators.instance_of(str)))
ExploreFormulaError¶
ExploreFormulaError(message: str, formula: str, position: int | None = None) -> None
Bases: ValueError
Error in explore formula syntax.
Provides helpful error messages with position indicators for syntax errors.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
message | str | Error description. | required |
formula | str | The formula that caused the error. | required |
position | int | None | Character position of the error (optional). | None |
Attributes:
| Name | Type | Description |
|---|---|---|
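formula | str | The formula that caused the error. |
position | int | None | Character position of the error, or None if not given. |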
Attributes¶
formula¶
formula = formula
position¶
position = position
ExploreFormulaSpec¶
Parsed explore formula.
Represents a parsed explore formula with focal variable, optional contrast type, and conditioning specifications.
Created by: parse_explore_formula()
Consumed by: dispatch_marginal_computation(), plot_predict(), plot_explore()
Augmented by: attrs.evolve() in resolve_focal_at_spec() materializes focal at-values
Attributes:
| Name | Type | Description |
|---|---|---|
focal_var | str | The variable to compute marginal effects for. |
contrast_type | str | None | Type of contrast (pairwise, sequential, poly, treatment, sum, helmert, custom) or None for simple EMMs. Set to "custom" for bracket contrast expressions. |
contrast_degree | int | None | Degree parameter for polynomial contrasts (default None means use n_levels - 1, i.e., maximum degree). |
contrast_ref | str | None | Reference level for treatment/dummy contrasts (e.g., "Placebo" from treatment(Drug, ref=Placebo)). |
contrast_level_ordering | tuple[str, ...] | None | Explicit level ordering for order-dependent contrasts (helmert, sequential, poly). Parsed from bracket list syntax, e.g. poly(dose, [low, med, high]). |
contrast_expr | ContrastExpr | None | Bracket contrast expression AST (e.g., from Drug[Active - Placebo] syntax). None for named contrast functions or simple EMMs. |
conditions | tuple[Condition, ...] | Tuple of Condition objects specifying conditioning variables. |
focal_at_values | tuple[float | str, ...] | None | Specific values to evaluate the focal variable at (e.g., from Days@[0, 3, 6, 9] syntax). None means use all levels. |
focal_at_range | int | None | Number of evenly-spaced values across the focal variable’s range (e.g., from Days@range(5) syntax). None means not set. |
focal_at_quantile | int | None | Number of quantile values for the focal variable (e.g., from Days@quantile(3) syntax). None means not set. |
Attributes¶
conditions¶
conditions: tuple[Condition, ...] = field(factory=tuple, converter=tuple)
contrast_degree¶
contrast_degree: int | None = field(default=None, validator=is_optional_positive_int)
contrast_expr¶
contrast_expr: ContrastExpr | None = field(default=None)
contrast_level_ordering¶
contrast_level_ordering: tuple[str, ...] | None = field(default=None, validator=is_optional_tuple_of_str)
contrast_ref¶
contrast_ref: str | None = field(default=None, validator=is_optional_str)
contrast_type¶
contrast_type: str | None = field(default=None, validator=(validators.optional(is_choice_str(('pairwise', 'sequential', 'poly', 'treatment', 'sum', 'helmert', 'custom')))))
focal_at_quantile¶
focal_at_quantile: int | None = field(default=None, validator=is_optional_positive_int)
focal_at_range¶
focal_at_range: int | None = field(default=None, validator=is_optional_positive_int)
focal_at_values¶
focal_at_values: tuple[float | str, ...] | None = field(default=None)
focal_var¶
focal_var: str = field(validator=(validators.instance_of(str)))
has_conditions¶
has_conditions: bool
Return True if conditioning variables are specified.
has_contrast¶
has_contrast: bool
Return True if any contrast is specified (named function or bracket expr).
has_contrast_expr¶
has_contrast_expr: bool
Return True if a bracket contrast expression is specified.
has_rhs_contrasts¶
has_rhs_contrasts: bool
Return True if any RHS condition has a bracket contrast expression.
ResolvedConditions¶
Typed buckets for resolved conditioning specifications.
Attributes:
| Name | Type | Description |
|---|---|---|
at_overrides | dict[str, float] | Single numeric pin per variable (e.g. Income@50). |
set_categoricals | dict[str, str] | Single categorical pin per variable (e.g. Ethnicity@Asian). |
grid_categoricals | dict[str, list[str]] | Multi-level categoricals to cross (e.g. bare Ethnicity → all levels, or Ethnicity@(Asian, Caucasian) → those two levels). |
grid_numerics | dict[str, list[float]] | Multi-value numerics to cross (e.g. Income@(10, 20) or Income@range(5)). |
Attributes¶
at_overrides¶
at_overrides: dict[str, float]
grid_categoricals¶
grid_categoricals: dict[str, list[str]]
grid_numerics¶
grid_numerics: dict[str, list[float]]
has_grid¶
has_grid: bool
Return True if any condition requires grid expansion.
raw_at_overrides¶
raw_at_overrides: dict[str, float] = Factory(dict)
Pre-transform at-overrides with original (raw) variable names.
Populated by the dispatch layer after _resolve_at_overrides remaps keys to design-matrix names. The slopes path needs raw keys because it operates on a data grid with raw column names, not design-matrix names. Empty when no transform resolution was applied.
set_categoricals¶
set_categoricals: dict[str, str]
Functions¶
apply_bracket_contrasts¶
apply_bracket_contrasts(mee_state: MeeState, expr: ContrastExpr) -> MeeState
Apply bracket contrast expression to an EMM MeeState.
Computes contrasts by building a weight matrix from the bracket expression and multiplying by the EMM estimates.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mee_state | MeeState | MeeState with EMM estimates. | required |
expr | ContrastExpr | ContrastExpr from parser. | required |
Returns:
| Type | Description |
|---|---|
MeeState | New MeeState with contrast estimates, type "contrasts", and contrast_method="custom". |
apply_bracket_contrasts_grouped¶
apply_bracket_contrasts_grouped(mee_state: MeeState, expr: ContrastExpr) -> MeeState
Apply bracket contrasts within each condition group of a crossed MeeState.
Similar to apply_contrasts_grouped() but for bracket contrast expressions. Applies the same contrast matrix independently within each group of n_focal consecutive rows.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mee_state | MeeState | MeeState from crossed EMM computation with n_focal x n_groups rows. | required |
expr | ContrastExpr | ContrastExpr from parser. | required |
Returns:
| Type | Description |
|---|---|
MeeState | MeeState with n_contrasts x n_groups rows and condition columns. |
apply_contrasts¶
apply_contrasts(mee_state: MeeState, contrast_type: str, fit: FitState | None = None, *, degree: int | None = None, ref_idx: int | None = None, level_ordering: tuple[str, ...] | None = None) -> MeeState
Apply contrast matrix to marginal means/effects.
High-level function that takes a MeeState with EMM estimates and returns a new MeeState with contrast estimates.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mee_state | MeeState | MeeState containing EMM estimates from compute_emm(). | required |
contrast_type | str | Type of contrast to apply: - “pairwise”: All pairwise comparisons (B-A, C-A, C-B, ...) - “sequential”: Adjacent differences (B-A, C-B, D-C, ...) - “poly”: Orthogonal polynomial contrasts - “treatment”: Each level vs reference (requires ref_idx) - “sum”: Each level vs grand mean - “helmert”: Each level vs mean of previous levels | required |
fit | FitState | None | FitState with vcov for variance propagation (optional, only needed if computing SE during this call). | None |
degree | int | None | Maximum polynomial degree for poly contrasts (default: n-1). | None |
ref_idx | int | None | Reference level index for treatment contrasts (0-based). | None |
level_ordering | tuple[str, ...] | None | Explicit level ordering for order-dependent contrasts (e.g. from poly(dose, [low, med, high])). When provided, EMM estimates and grid are reordered to match before applying the contrast matrix. | None |
Returns:
| Type | Description |
|---|---|
MeeState | New MeeState with: - grid: DataFrame with contrast labels - estimate: Contrast estimates - type: “contrasts” - focal_var: Same as input - explore_formula: Updated to reflect contrast |
Examples:
>>> mee = compute_emm(bundle, fit, "treatment", "treatment")
>>> contrasts = apply_contrasts(mee, "pairwise")
>>> contrasts.estimate # B-A, C-A, C-B differences
Note: Variance propagation for inference is deferred to .infer(). This function only computes point estimates.
apply_contrasts_grouped¶
apply_contrasts_grouped(mee_state: MeeState, contrast_type: str | None, *, degree: int | None = None, ref_idx: int | None = None, level_ordering: tuple[str, ...] | None = None) -> MeeState
Apply contrasts within each condition group of a crossed MeeState.
When EMMs have been computed over a crossed grid (focal levels x condition groups), this function applies the contrast matrix independently within each group of n_focal consecutive rows, then stacks the results.
The number of focal levels is inferred from the focal_var column of mee_state.grid.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mee_state | MeeState | MeeState from _compute_emm_crossed() with n_focal x n_groups rows ordered as [focal_1_group_1, ..., focal_k_group_1, focal_1_group_2, ...]. | required |
contrast_type | str | None | Type of contrast (pairwise, sequential, poly, treatment, sum, helmert). | required |
degree | int | None | Degree for polynomial contrasts. | None |
ref_idx | int | None | Reference level index for treatment contrasts. | None |
level_ordering | tuple[str, ...] | None | Explicit level ordering for order-dependent contrasts. When provided, EMM rows within each group are reordered before applying the contrast matrix. | None |
Returns:
| Type | Description |
|---|---|
MeeState | MeeState with n_contrasts x n_groups rows and condition columns. |
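The grouped application described above can be sketched on plain arrays. This is a minimal illustration under assumptions, not the library's implementation: `apply_contrasts_by_group` is a hypothetical helper, and it works on a bare estimate vector rather than a MeeState. It assumes the row layout stated in the parameter table (focal level varies fastest within each group).

```python
import numpy as np

def apply_contrasts_by_group(estimates: np.ndarray, C: np.ndarray) -> np.ndarray:
    """Hypothetical helper: apply C within each block of n_focal consecutive rows."""
    n_contrasts, n_focal = C.shape
    if estimates.size % n_focal != 0:
        raise ValueError("estimates length must be a multiple of n_focal")
    # Focal level varies fastest, so each row of `blocks` is one condition group.
    blocks = estimates.reshape(-1, n_focal)      # (n_groups, n_focal)
    # Same contrast matrix for every group, results stacked group by group.
    return (blocks @ C.T).reshape(-1)            # (n_groups * n_contrasts,)

# Two condition groups with three focal levels each: [A, B, C | A, B, C]
emms = np.array([1.0, 2.0, 4.0, 10.0, 10.5, 12.0])
C_seq = np.array([[-1.0, 1.0, 0.0],
                  [0.0, -1.0, 1.0]])             # sequential contrasts B-A, C-B
grouped = apply_contrasts_by_group(emms, C_seq)
print(grouped)
```

The output has n_contrasts x n_groups entries, matching the row count given in the Returns table.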
apply_rhs_bracket_contrast¶
apply_rhs_bracket_contrast(mee_state: MeeState, expr: ContrastExpr) -> MeeState
Apply a bracket contrast on a RHS condition column.
After the main computation produces a crossed MeeState with a condition column (e.g., Dose with levels [High, Low]), this collapses that column by applying the bracket contrast.
The condition column must vary slowest in the grid (standard layout from crossed EMM/slope computation).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mee_state | MeeState | MeeState from a crossed computation. Must contain expr.var as a column in the grid. | required |
expr | ContrastExpr | ContrastExpr specifying the contrast on the condition variable. | required |
Returns:
| Type | Description |
|---|---|
MeeState | New MeeState with the condition column replaced by contrast rows. |
build_all_pairwise_matrix¶
build_all_pairwise_matrix(n_levels: int) -> np.ndarray
Build all pairwise contrasts between EMM levels.
Creates C(n, 2) = n*(n-1)/2 contrasts for all unique pairs. Unlike build_pairwise_matrix(), this includes all pairs, not just comparisons to the reference level.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
n_levels | int | Number of EMM levels (factor levels). | required |
Returns:
| Type | Description |
|---|---|
ndarray | Contrast matrix of shape (n*(n-1)/2, n_levels). Each row compares two levels (level_j - level_i where j > i). |
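One way to picture the construction is to enumerate the index pairs (i, j) with i < j and write -1 and +1 into the matching columns. The function below is an illustrative sketch with a hypothetical name (`all_pairwise_sketch`); the library's own builder may be written differently.

```python
from itertools import combinations

import numpy as np

def all_pairwise_sketch(n_levels: int) -> np.ndarray:
    """Illustrative stand-in: one row per unordered pair, level_j - level_i."""
    pairs = list(combinations(range(n_levels), 2))   # (i, j) with i < j
    C = np.zeros((len(pairs), n_levels))
    for row, (i, j) in enumerate(pairs):
        C[row, i] = -1.0
        C[row, j] = 1.0
    return C

C3 = all_pairwise_sketch(3)
print(C3)
```

For three levels this yields the rows B-A, C-A, C-B, in that order.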
Examples:
>>> C = build_all_pairwise_matrix(3)
>>> C
array([[-1., 1., 0.],
[-1., 0., 1.],
[ 0., -1., 1.]])
build_bracket_contrast_matrix¶
build_bracket_contrast_matrix(expr: ContrastExpr, levels: list[str]) -> tuple[np.ndarray, list[str]]
Build contrast matrix and labels from bracket contrast expression.
Handles auto-normalization (each operand side sums to +1 or -1) and wildcard expansion (* expands to all unmentioned levels).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
expr | ContrastExpr | ContrastExpr AST from the parser. | required |
levels | list[str] | Actual factor levels from the EMM grid. | required |
Returns:
| Type | Description |
|---|---|
tuple[ndarray, list[str]] | Tuple of (contrast_matrix, labels) where contrast_matrix has shape (n_contrasts, n_levels) and labels has length n_contrasts. |
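The auto-normalization and wildcard rules can be illustrated with a small sketch. The ContrastExpr AST is not shown in this reference, so the example assumes a simplified input: two plain lists of level names, one for the positive side and one for the negative side. `bracket_weights` is a hypothetical helper, not part of the package.

```python
import numpy as np

def bracket_weights(plus: list, minus: list, levels: list) -> np.ndarray:
    """Hypothetical helper: one contrast row from positive/negative operand lists."""
    mentioned = {name for name in plus + minus if name != "*"}
    rest = [lv for lv in levels if lv not in mentioned]   # targets of the wildcard

    def expand(side):
        out = []
        for name in side:
            out.extend(rest if name == "*" else [name])
        return out

    w = np.zeros(len(levels))
    for side, sign in ((expand(plus), 1.0), (expand(minus), -1.0)):
        for name in side:
            # Equal weights within a side, so that side sums to +1 or -1.
            w[levels.index(name)] += sign / len(side)
    return w

levels = ["Placebo", "Low", "High"]
w_wild = bracket_weights(["High"], ["*"], levels)             # High vs all others
w_avg = bracket_weights(["Low", "High"], ["Placebo"], levels)  # mean(Low, High) vs Placebo
print(w_wild, w_avg)
```

Both rows sum to zero because the positive and negative sides are each normalized to unit total weight.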
build_contrast_matrix¶
build_contrast_matrix(contrast_type: str | dict, levels: list, normalize: bool = False) -> np.ndarray
Build a contrast matrix based on contrast type.
Dispatcher function that builds the appropriate contrast matrix based on the contrast type specification.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
contrast_type | str | dict | Type of contrast: - “pairwise”: All pairwise comparisons (k(k-1)/2 contrasts) - “trt.vs.ctrl” or “treatment”: Compare each level to first level - “sequential”: Successive differences (level[i+1] - level[i]) - “sum”: Each level vs grand mean (deviation coding) - “helmert”: Each level vs mean of previous levels - dict: Custom contrasts with names as keys and weights as values | required |
levels | list | List of factor levels. | required |
normalize | bool | If True, normalize custom contrasts to sum to 1/-1. | False |
Returns:
| Type | Description |
|---|---|
ndarray | Contrast matrix of shape (n_contrasts, n_levels). |
Examples:
>>> build_contrast_matrix("pairwise", ["A", "B", "C"])
array([[-1., 1., 0.],
[-1., 0., 1.],
[ 0., -1., 1.]])
build_helmert_matrix¶
build_helmert_matrix(n_levels: int) -> np.ndarray
Build Helmert contrasts (each level vs mean of previous levels).
Creates n_levels - 1 contrasts where level k is compared to the mean of levels 0, 1, ..., k-1.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
n_levels | int | Number of EMM levels (factor levels). | required |
Returns:
| Type | Description |
|---|---|
ndarray | Contrast matrix of shape (n_levels - 1, n_levels). |
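The "level k vs mean of levels 0..k-1" rule translates directly into row weights: -1/k on the first k columns and +1 on column k. The sketch below uses a hypothetical name (`helmert_sketch`) and is meant only to make the weights concrete.

```python
import numpy as np

def helmert_sketch(n_levels: int) -> np.ndarray:
    """Illustrative stand-in: row k-1 compares level k to the mean of levels 0..k-1."""
    C = np.zeros((n_levels - 1, n_levels))
    for k in range(1, n_levels):
        C[k - 1, :k] = -1.0 / k   # average of the k previous levels
        C[k - 1, k] = 1.0         # the level being compared
    return C

H3 = helmert_sketch(3)
print(H3)
```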
Examples:
>>> C = build_helmert_matrix(3)
>>> C
array([[-1. , 1. , 0. ],
[-0.5, -0.5, 1. ]])
>>> C = build_helmert_matrix(4)
>>> C
array([[-1. , 1. , 0. , 0. ],
[-0.5 , -0.5 , 1. , 0. ],
[-0.33333333, -0.33333333, -0.33333333, 1. ]])
build_pairwise_matrix¶
build_pairwise_matrix(n_levels: int) -> np.ndarray
Build (n-1) linearly independent pairwise contrasts.
Creates “treatment-style” contrasts comparing each level to the first (reference) level. This produces n_levels - 1 contrasts that span the space of all pairwise differences.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
n_levels | int | Number of EMM levels (factor levels). | required |
Returns:
| Type | Description |
|---|---|
ndarray | Contrast matrix of shape (n_levels - 1, n_levels). Row i compares level i+1 to level 0. |
Examples:
>>> C = build_pairwise_matrix(3)
>>> C
array([[-1., 1., 0.],
[-1., 0., 1.]])
build_poly_matrix¶
build_poly_matrix(n_levels: int, degree: int | None = None) -> np.ndarray
Build orthogonal polynomial contrast matrix for EMMs.
Creates orthogonal polynomial contrasts for ordered factors. Tests for linear, quadratic, cubic (etc.) trends across the factor levels.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
n_levels | int | Number of factor levels. Must be >= 2. | required |
degree | int | None | Maximum polynomial degree. If None, uses n_levels - 1. | None |
Returns:
| Type | Description |
|---|---|
ndarray | Contrast matrix of shape (degree, n_levels). Rows are orthonormal polynomial basis vectors. |
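A common way to obtain orthonormal polynomial contrasts is a QR decomposition of a Vandermonde matrix on equally spaced scores. The sketch below assumes that approach and equally spaced levels; whether the library uses the same algorithm, sign convention, or spacing is not stated here, and `poly_sketch` is a hypothetical name.

```python
from typing import Optional

import numpy as np

def poly_sketch(n_levels: int, degree: Optional[int] = None) -> np.ndarray:
    """Illustrative stand-in: orthonormal polynomial contrasts via QR (assumed method)."""
    if n_levels < 2:
        raise ValueError("n_levels must be >= 2")
    degree = n_levels - 1 if degree is None else degree
    x = np.arange(n_levels, dtype=float)
    x -= x.mean()                                   # centered, equally spaced scores
    V = np.vander(x, degree + 1, increasing=True)   # columns: 1, x, x^2, ...
    Q, R = np.linalg.qr(V)
    Q = Q * np.sign(np.diag(R))                     # make each trend point "upward"
    return Q[:, 1:].T                               # drop the constant column

P = poly_sketch(5, degree=2)   # linear and quadratic rows
print(P.round(3))
```

Dropping the constant column is what makes every row sum to zero, since the remaining columns are orthogonal to it.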
Examples:
>>> build_poly_matrix(5) # 4x5 matrix
>>> build_poly_matrix(5, degree=2) # 2x5 matrix (linear + quadratic)
build_reference_grid¶
build_reference_grid(bundle: DataBundle, focal_vars: list[str], *, at: dict[str, Any] | None = None, covariate_means: dict[str, float] | None = None) -> pl.DataFrame
Construct reference grid for marginal effects evaluation.
Creates a grid of covariate values at which to evaluate predictions. This function transforms the parsed explore formula’s conditions into a concrete evaluation grid.
The grid construction follows emmeans conventions:
- Focal variables: vary across their unique levels/values from data
- Non-focal continuous: set to their overall mean from the data
- Non-focal categorical: create rows for all levels (equal weight averaging)
- Conditioned variables (@ spec): use the specified values
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
bundle | DataBundle | DataBundle with model data and metadata. Used to extract: - factor_levels: dict mapping categorical vars to their levels - X_names: column names for identifying variable types | required |
focal_vars | list[str] | Variables to vary across their range/levels. These become the rows of the output grid. | required |
at | dict[str, Any] | None | Fixed values for conditioning variables. Keys are variable names, values are the conditioning value(s). This dict is typically built from the Condition objects in the parsed explore formula. | None |
covariate_means | dict[str, float] | None | Pre-computed means for continuous covariates. If not provided, covariates will be omitted from the grid (caller must handle them separately or pass via at). | None |
Returns:
| Type | Description |
|---|---|
DataFrame | Polars DataFrame with one row per grid point. Columns include: - All focal variables - All conditioned variables (if any) The number of rows is the Cartesian product of focal variable levels. |
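The row count stated above (Cartesian product of focal levels) can be illustrated without the library. This sketch uses plain dictionaries instead of a DataBundle and a Polars DataFrame; `reference_grid_sketch` and the level dictionary are assumptions made for the example.

```python
from itertools import product

def reference_grid_sketch(factor_levels: dict, focal_vars: list) -> list:
    """Illustrative stand-in: one dict per grid row, crossing all focal levels."""
    level_lists = [factor_levels[var] for var in focal_vars]
    # product() varies the last focal variable fastest.
    return [dict(zip(focal_vars, combo)) for combo in product(*level_lists)]

factor_levels = {"treatment": ["A", "B", "C"], "sex": ["F", "M"]}
grid_rows = reference_grid_sketch(factor_levels, ["treatment", "sex"])
for row in grid_rows:
    print(row)
```

With three treatment levels and two sex levels the grid has 3 x 2 = 6 rows.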
Integration with explore():
Called from model._compute_marginal_effects() after parsing::
    # From parsed explore formula "treatment ~ age@50"
    parsed = ExploreFormulaSpec(
        focal_var="treatment",
        conditions=[Condition(var="age", at_values=(50,))]
    )
    # Convert conditions to at dict
    at_dict = {"age": 50.0}
    # Build grid
    grid = build_reference_grid(
        bundle=self._bundle,
        focal_vars=["treatment"],
        at=at_dict,
    )
    # grid: pl.DataFrame with columns ["treatment"] and rows ["A", "B", "C"]
    # age is NOT in the grid (it's held fixed at 50)
Examples:
Grid for categorical focal::
    grid = build_reference_grid(bundle, ["treatment"])
    # Returns grid with one row per treatment level
Grid with conditioning::
    grid = build_reference_grid(bundle, ["treatment"], at={"age": 50})
    # Returns grid at age=50
Multiple focal variables (interaction EMMs)::
    grid = build_reference_grid(bundle, ["treatment", "sex"])
    # Returns Cartesian product: treatment x sex levels
build_sequential_matrix¶
build_sequential_matrix(n_levels: int) -> np.ndarray
Build sequential (successive differences) contrasts.
Creates n_levels - 1 contrasts comparing each level to the previous one.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
n_levels | int | Number of EMM levels (factor levels). | required |
Returns:
| Type | Description |
|---|---|
ndarray | Contrast matrix of shape (n_levels - 1, n_levels). Row i compares level i+1 to level i. |
Examples:
>>> C = build_sequential_matrix(3)
>>> C
array([[-1., 1., 0.],
[ 0., -1., 1.]])
>>> C = build_sequential_matrix(4)
>>> C
array([[-1., 1., 0., 0.],
[ 0., -1., 1., 0.],
[ 0., 0., -1., 1.]])
build_sum_to_zero_matrix¶
build_sum_to_zero_matrix(n_levels: int) -> np.ndarray
Build sum-to-zero contrasts (deviation coding).
Creates contrasts comparing each level to the grand mean.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
n_levels | int | Number of EMM levels. | required |
Returns:
| Type | Description |
|---|---|
ndarray | Contrast matrix of shape (n_levels - 1, n_levels). |
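"Each level vs the grand mean" corresponds to subtracting 1/n from every entry of an identity matrix; keeping the first n - 1 rows gives the stated shape. The sketch below is an illustration under that reading, with a hypothetical name (`sum_to_zero_sketch`).

```python
import numpy as np

def sum_to_zero_sketch(n_levels: int) -> np.ndarray:
    """Illustrative stand-in: level_i minus the grand mean, for the first n-1 levels."""
    centered = np.eye(n_levels) - 1.0 / n_levels   # row i: e_i - mean weights
    return centered[: n_levels - 1]                # last row is redundant

S3 = sum_to_zero_sketch(3)
print(S3.round(3))
```

The dropped row carries no extra information: the n deviations from the grand mean always sum to zero.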
Examples:
>>> C = build_sum_to_zero_matrix(3)
>>> C
array([[ 0.667, -0.333, -0.333],
[-0.333, 0.667, -0.333]])
build_treatment_matrix¶
build_treatment_matrix(n_levels: int, ref_idx: int = 0) -> np.ndarray
Build treatment (Dunnett-style) contrasts against a reference level.
Creates n_levels - 1 contrasts, each comparing one non-reference level to the specified reference level.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
n_levels | int | Number of EMM levels (factor levels). | required |
ref_idx | int | Index of the reference level (0-based). | 0 |
Returns:
| Type | Description |
|---|---|
ndarray | Contrast matrix of shape (n_levels - 1, n_levels). |
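A minimal sketch of the same idea: loop over the non-reference levels, placing +1 on the level and -1 on the reference column. `treatment_sketch` is a hypothetical name used for illustration only.

```python
import numpy as np

def treatment_sketch(n_levels: int, ref_idx: int = 0) -> np.ndarray:
    """Illustrative stand-in: each non-reference level minus the reference level."""
    others = [i for i in range(n_levels) if i != ref_idx]
    C = np.zeros((len(others), n_levels))
    for row, i in enumerate(others):
        C[row, i] = 1.0
        C[row, ref_idx] = -1.0
    return C

T_first = treatment_sketch(3, ref_idx=0)
T_last = treatment_sketch(3, ref_idx=2)
print(T_first)
print(T_last)
```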
Examples:
>>> C = build_treatment_matrix(3, ref_idx=0)
>>> C
array([[-1., 1., 0.],
[-1., 0., 1.]])
>>> C = build_treatment_matrix(3, ref_idx=2)
>>> C
array([[ 1., 0., -1.],
[ 0., 1., -1.]])
combine_resolved¶
combine_resolved(a: ResolvedConditions, b: ResolvedConditions) -> ResolvedConditions
Merge two ResolvedConditions, with b taking precedence on conflicts.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
a | ResolvedConditions | First resolved conditions. | required |
b | ResolvedConditions | Second resolved conditions (takes precedence). | required |
Returns:
| Type | Description |
|---|---|
ResolvedConditions | Merged ResolvedConditions. |
compose_contrast_matrix¶
compose_contrast_matrix(C: np.ndarray, X_ref: np.ndarray) -> np.ndarray
Compose contrast matrix with prediction matrix.
L_emm @ beta = C @ (X_ref @ beta) = C @ EMMs
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
C | ndarray | Contrast matrix of shape (n_contrasts, n_emms). | required |
X_ref | ndarray | Prediction matrix of shape (n_emms, n_coef). | required |
Returns:
| Type | Description |
|---|---|
ndarray | Composed contrast matrix L_emm of shape (n_contrasts, n_coef). |
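The identity above implies that the composed matrix is the plain product C @ X_ref. The check below assumes exactly that and uses made-up numbers; it does not call the library.

```python
import numpy as np

rng = np.random.default_rng(0)
X_ref = rng.normal(size=(3, 4))          # 3 EMM rows, 4 coefficients (made-up values)
beta = rng.normal(size=4)
C = np.array([[-1.0, 1.0, 0.0],
              [-1.0, 0.0, 1.0]])         # two contrasts over the three EMMs

L_emm = C @ X_ref                        # assumed composition, shape (2, 4)
one_step = L_emm @ beta                  # contrasts straight from coefficients
two_step = C @ (X_ref @ beta)            # EMMs first, then contrasts
print(np.allclose(one_step, two_step))
```

Working with L_emm lets contrast variances be computed directly from the coefficient covariance matrix, without forming the EMM covariance as an intermediate.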
compute_compound_bracket_contrasts¶
compute_compound_bracket_contrasts(bundle: object, fit: object, focal_var: str, contrast_expr: ContrastExpr, *, data: pl.DataFrame, spec: object | None = None, effect_scale: str = 'link', resolved: object | None = None) -> MeeState
Compute bracket contrasts for a compound focal variable.
Handles compound variables like Drug:Dose by building a crossed EMM grid over the component variables, creating compound level names, and then applying the bracket contrast.
For example, Drug:Dose[Active:High - Placebo:Low] builds EMMs over Drug × Dose, labels each cell as Active:High, Active:Low, etc., and applies the contrast Active:High - Placebo:Low.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
bundle | object | DataBundle with model data and metadata. | required |
fit | object | FitState with fitted coefficients. | required |
focal_var | str | Compound variable name (e.g., "Drug:Dose"). | required |
contrast_expr | ContrastExpr | ContrastExpr from the parser. | required |
data | DataFrame | Model data DataFrame. | required |
spec | object | None | ModelSpec with link/family info. | None |
effect_scale | str | Scale of estimates: "link" or "response". | ‘link’ |
resolved | object | None | ResolvedConditions for additional conditioning. | None |
Returns:
| Type | Description |
|---|---|
MeeState | MeeState with compound bracket contrast estimates. |
compute_conditional_emm¶
compute_conditional_emm(bundle: 'DataBundle', fit: 'FitState', focal_var: str, explore_formula: str, *, spec: object, varying_offsets: object, grouping_var: str, effect_scale: str = 'link', levels: list[str] | None = None, at_overrides: dict[str, float] | None = None, set_categoricals: dict[str, str] | None = None) -> MeeState
Compute per-group conditional EMMs incorporating intercept BLUPs.
For each group g and each focal level, the conditional EMM is:
η_g = X_ref_row @ β + b_intercept_g (+ other BLUP contributions)
When effect_scale=“response”: estimates = g⁻¹(η_g).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
bundle | ‘DataBundle’ | DataBundle with model data and metadata. | required |
fit | ‘FitState’ | FitState with fitted coefficients. | required |
focal_var | str | Name of the categorical variable. | required |
explore_formula | str | The explore formula string. | required |
spec | object | ModelSpec with link function info. | required |
varying_offsets | object | VaryingState with BLUPs per group. | required |
grouping_var | str | Name of the grouping variable. | required |
effect_scale | str | Scale of estimates: "link" or "response". | ‘link’ |
levels | list[str] | None | Optional list of focal levels. | None |
at_overrides | dict[str, float] | None | Optional covariate overrides. | None |
set_categoricals | dict[str, str] | None | Optional dict pinning non-focal categoricals. | None |
Returns:
| Type | Description |
|---|---|
MeeState | MeeState with per-(level, group) estimates. |
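As a worked illustration of the formula η_g = X_ref_row @ β + b_intercept_g, the sketch below uses made-up coefficients and intercept BLUPs for three groups, and assumes a log link for the response-scale step. None of these numbers or names come from the library.

```python
import numpy as np

x_ref_row = np.array([1.0, 1.0])            # intercept + indicator for one focal level
beta = np.array([0.5, 0.3])                 # made-up fixed effects
b_intercept = np.array([-0.1, 0.0, 0.2])    # made-up intercept BLUPs, one per group

eta_g = x_ref_row @ beta + b_intercept      # link-scale conditional EMM per group
mu_g = np.exp(eta_g)                        # response scale under an assumed log link
print(eta_g, mu_g.round(3))
```

Each group shares the fixed-effect part (0.8 here) and differs only by its BLUP offset.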
compute_conditional_slopes¶
compute_conditional_slopes(bundle: 'DataBundle', fit: 'FitState', focal_var: str, explore_formula: str, *, spec: object, varying_offsets: object, grouping_var: str, effect_scale: str = 'link') -> MeeState
Compute per-group conditional slopes incorporating BLUPs.
For each group, the conditional slope is:
slope_g = β_x + b_x_g (random slopes model)
slope_g = β_x (random intercepts only)
When effect_scale=“response” and the link is non-identity, the response-scale slope is:
slope_response_g = slope_link_g × dμ/dη(η_g)
where η_g is the linear predictor at the reference point for group g.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
bundle | ‘DataBundle’ | DataBundle with model data and metadata. | required |
fit | ‘FitState’ | FitState with fitted coefficients. | required |
focal_var | str | Name of the continuous variable to get slopes for. | required |
explore_formula | str | The explore formula string (for result metadata). | required |
spec | object | ModelSpec with link function info. | required |
varying_offsets | object | VaryingState with BLUPs per group. | required |
grouping_var | str | Name of the grouping variable. | required |
effect_scale | str | Scale of estimates: "link" or "response". | ‘link’ |
Returns:
| Type | Description |
|---|---|
MeeState | MeeState with per-group slope estimates. |
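The chain-rule step slope_response_g = slope_link_g × dμ/dη(η_g) can be checked numerically. The sketch assumes a logit link, for which dμ/dη = μ(1 - μ), and uses made-up values for the fixed slope, random-slope BLUPs, and reference-point linear predictors.

```python
import numpy as np

def inv_logit(eta: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-eta))

beta_x = 0.8                                # made-up fixed slope
b_x = np.array([-0.2, 0.0, 0.3])            # made-up random-slope BLUPs per group
eta_g = np.array([-1.0, 0.0, 1.5])          # made-up linear predictor at the reference point

slope_link = beta_x + b_x                   # link-scale slope per group
mu_g = inv_logit(eta_g)
slope_response = slope_link * mu_g * (1.0 - mu_g)   # chain rule for the logit link

# Cross-check against a centered finite difference in x on the response scale.
h = 1e-6
fd = (inv_logit(eta_g + slope_link * h) - inv_logit(eta_g - slope_link * h)) / (2 * h)
print(slope_response.round(4), np.allclose(slope_response, fd, atol=1e-6))
```

The response-scale slope depends on where it is evaluated: the same link-scale slope shrinks as η_g moves away from 0, where the logistic curve is steepest.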
compute_contrasts¶
compute_contrasts(emm: np.ndarray, contrast_matrix: np.ndarray) -> np.ndarrayApply contrast matrix to EMMs.
Transforms a vector of estimated marginal means into contrast estimates
by matrix multiplication. This is the final step when the explore formula
includes a contrast function like pairwise(treatment).
contrasts = C @ EMM
Where:
C is the contrast matrix, shape (n_contrasts, n_levels)
EMM is the vector of marginal means, shape (n_levels,)
contrasts is the result, shape (n_contrasts,)
Each row sums to 0 (contrasts compare, not estimate absolute levels)
The result represents a comparison (e.g., B - A)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
emm | ndarray | Array of estimated marginal means from compute_emm(). Shape: (n_levels,) | required |
contrast_matrix | ndarray | Contrast matrix from build_all_pairwise_matrix(), build_sequential_matrix(), or build_poly_matrix(). Shape: (n_contrasts, n_levels) Rows must be orthogonal to the constant vector (sum to 0). | required |
Returns:
| Type | Description |
|---|---|
ndarray | Array of contrast estimates. Shape: (n_contrasts,) |
Integration with explore(): Called after compute_emm when a contrast function is specified::
# In model._compute_marginal_effects():
if parsed.has_contrast:
    # First compute EMMs
    emms = compute_emm(...)
    # Build appropriate contrast matrix
    if parsed.contrast_type == "pairwise":
        C = build_all_pairwise_matrix(n_levels)
    elif parsed.contrast_type == "sequential":
        C = build_sequential_matrix(n_levels)
    elif parsed.contrast_type == "poly":
        C = build_poly_matrix(n_levels, degree=parsed.contrast_degree)
    # Apply contrasts
    contrast_estimates = compute_contrasts(emms, C)
    return build_mee_state(
        grid=contrast_grid,  # with contrast labels
        estimate=contrast_estimates,
        mee_type="contrasts",
        ...
    )
Note: For inference on contrasts, the variance of contrasts is: Var(C @ EMM) = C @ Var(EMM) @ C.T
This is handled by the inference methods (delta method for asymp, direct computation for bootstrap).
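As a minimal sketch of the variance formula, the following uses plain NumPy with a hypothetical EMM covariance matrix. It shows the algebra only and is not the library's inference code.

```python
import numpy as np

# All pairwise contrasts for three levels (rows: B - A, C - A, C - B)
C = np.array([[-1.0, 1.0, 0.0],
              [-1.0, 0.0, 1.0],
              [0.0, -1.0, 1.0]])

# Hypothetical covariance matrix of the three EMMs
var_emm = np.array([[0.04, 0.01, 0.01],
                    [0.01, 0.05, 0.01],
                    [0.01, 0.01, 0.06]])

# Var(C @ EMM) = C @ Var(EMM) @ C.T
var_contrasts = C @ var_emm @ C.T

# Standard errors are the square roots of the diagonal
se_contrasts = np.sqrt(np.diag(var_contrasts))
```

Each diagonal entry equals Var(level_i) + Var(level_j) - 2 Cov(level_i, level_j) for the corresponding pair, so positive covariance between EMMs shrinks the contrast standard error.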
Examples:
Apply pairwise contrasts to 3-level EMMs::
emm = np.array([2.0, 3.5, 2.8]) # A, B, C means
C = build_all_pairwise_matrix(3)
# C = [[-1, 1, 0], # B - A
# [-1, 0, 1], # C - A
# [0, -1, 1]] # C - B
contrasts = compute_contrasts(emm, C)
# contrasts: array([1.5, 0.8, -0.7])
compute_emm¶
compute_emm(bundle: 'DataBundle', fit: 'FitState', focal_var: str, explore_formula: str, *, levels: list[str] | None = None, at_overrides: dict[str, float] | None = None, set_categoricals: dict[str, str] | None = None, spec: object | None = None, how: str = 'mem', effect_scale: str = 'link') -> MeeState
Compute estimated marginal means for a categorical focal variable.
Supports two averaging methods via the how parameter:
"mem" (default): Marginal Estimated Mean. Balanced reference grid (emmeans-style). Predictions at a grid where covariates are at their means.
"ame": Average Marginal Effect / g-computation. For each focal level, sets every observation to that level and averages the resulting predictions. Preserves the observed covariate distribution.
For linear models (identity link), both approaches give identical results.
For GLMs (non-identity link), they diverge because
mean(g⁻¹(Xᵢβ)) ≠ g⁻¹(mean(Xᵢ) · β).
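The inequality can be checked numerically. The sketch below assumes a logit link and made-up coefficients; it reproduces only the two averaging orders, not the library's reference-grid construction.

```python
import numpy as np

def inv_logit(eta):
    """Inverse of the logit link."""
    return 1.0 / (1.0 + np.exp(-eta))

beta = np.array([-1.0, 2.0])            # hypothetical intercept and slope
x = np.array([0.0, 1.0, 2.0, 3.0])      # observed covariate values
X = np.column_stack([np.ones_like(x), x])

# "ame"-style: predict for every row, then average: mean(g^-1(X_i beta))
ame_style = inv_logit(X @ beta).mean()

# "mem"-style: average the covariates first, then predict: g^-1(mean(X_i) . beta)
mem_style = inv_logit(X.mean(axis=0) @ beta)

# With an identity link the two orders agree exactly
lin_ame = (X @ beta).mean()
lin_mem = X.mean(axis=0) @ beta
```

Here `ame_style` is about 0.74 and `mem_style` about 0.88, while `lin_ame` and `lin_mem` are both 2.0.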
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
bundle | ‘DataBundle’ | DataBundle with model data and metadata. Used to extract: - X: design matrix for computing covariate means - X_names: column names to identify variable types - factor_levels: level metadata for categorical variables | required |
fit | ‘FitState’ | FitState with fitted coefficients (fit.coef array). | required |
focal_var | str | Name of the categorical variable to compute EMMs for. | required |
explore_formula | str | The explore formula string (for result metadata). | required |
levels | list[str] | None | Optional list of levels to compute EMMs for. If None, uses all levels from bundle.factor_levels. | None |
at_overrides | dict[str, float] | None | Optional dict of {variable: value} to fix specific covariates at given values instead of their means. Used for conditioning (e.g., cyl ~ wt@[3.0] fixes wt at 3.0). | None |
set_categoricals | dict[str, str] | None | Optional dict mapping non-focal categorical variable names to specific levels to pin them at (instead of marginalizing at column means). E.g. {"Ethnicity": "Asian"}. | None |
spec | object | None | ModelSpec with link/family info (needed for effect_scale=“response”). | None |
how | str | Averaging method: "mem" for balanced reference grid (emmeans-style), "ame" for g-computation over observed data. | ‘mem’ |
effect_scale | str | Scale of estimates: "link" or "response". | ‘link’ |
Returns:
| Type | Description |
|---|---|
MeeState | MeeState with grid of levels and their estimated means. |
Examples:
Compute EMMs with default (reference grid) averaging::
mee = compute_emm(bundle, fit, "treatment", "treatment")
G-computation averaging for a logistic regression::
mee = compute_emm(bundle, fit, "treatment", "treatment",
                  spec=spec, how="ame", effect_scale="response")
Note:
For how="ame" with effect_scale="response" on a non-identity
link, confidence intervals use the response-scale delta method
(symmetric CIs) rather than the link-scale back-transformation
(asymmetric CIs) used by how="mem".
compute_joint_test¶
compute_joint_test(fit: FitState, bundle: DataBundle, spec: ModelSpec, terms: list[str] | None = None, *, errors: str = 'iid', data: pl.DataFrame | None = None) -> JointTestState
Compute joint hypothesis tests for model terms.
Tests each model term (continuous variables, factors, and interactions) using direct coefficient tests. This matches R’s emmeans::joint_tests() behavior.
For each term:
Continuous variables: F-test on the single coefficient (df1=1)
Categorical factors: Joint F-test on all indicator coefficients
Interactions: Joint F-test on all interaction coefficients
The test type is determined by the model family:
Gaussian (LM/LMER): F-test with df_resid
Non-Gaussian (GLM/GLMER): Chi-square test (asymptotic)
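A direct coefficient test of this kind is commonly computed as a Wald quadratic form. The sketch below assumes that form, with hypothetical coefficients, a hypothetical vcov, and a hypothetical helper name `wald_statistic`. It returns the statistic only and is not taken from the library's implementation.

```python
import numpy as np

# Hypothetical fit: intercept, x, and two indicator columns for a 3-level factor
coef = np.array([1.0, 0.5, 0.8, -0.4])
vcov = np.diag([0.10, 0.04, 0.16, 0.16])

def wald_statistic(term_cols):
    """Return (F-type statistic, df1) for the coefficients in term_cols."""
    L = np.eye(len(coef))[term_cols]              # selector rows, shape (q, p)
    est = L @ coef                                # tested coefficients
    quad = est @ np.linalg.solve(L @ vcov @ L.T, est)
    q = len(term_cols)
    return quad / q, q                            # quad itself is the chi-square form

f_x, df1_x = wald_statistic([1])                  # continuous term: df1 = 1
f_group, df1_group = wald_statistic([2, 3])       # factor: joint test on both indicators
```

For a Gaussian model the ratio would be referred to an F distribution with (df1, df_resid) degrees of freedom; for a non-Gaussian model the undivided quadratic form would be referred to a chi-square distribution with df1 degrees of freedom.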
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
fit | FitState | FitState with fitted coefficients and vcov. | required |
bundle | DataBundle | DataBundle with X_names for term structure. | required |
spec | ModelSpec | ModelSpec for model type detection. | required |
terms | list[str] | None | Specific terms to test, or None for all terms. | None |
errors | str | Error structure assumption. Options: - "iid" (default): Standard errors from fit.vcov. - "HC0"-"HC3", "hetero": Sandwich SEs. - "unequal_var": Welch ANOVA — sandwich SEs from per-cell variances with per-term Satterthwaite df2. For single-df terms this matches the Welch t-test; for multi-df terms (k-level factors), per-row dfs are combined via the Fai-Cornelius (1996) approximation. | ‘iid’ |
data | DataFrame | None | Original data frame, required for errors='unequal_var'. | None |
Returns:
| Type | Description |
|---|---|
JointTestState | JointTestState with test results for each term. |
Examples:
Basic usage::
from marginal import compute_joint_test
state = compute_joint_test(fit, bundle, spec)
# state.terms: ("x", "group", "x:group")
# state.statistic: F or chi2 values per term
# state.p_value: p-values per term
Welch ANOVA::
state = compute_joint_test(fit, bundle, spec, errors="unequal_var", data=df)
# Per-term Satterthwaite df2 in state.df2
Test specific terms only::
state = compute_joint_test(fit, bundle, spec, terms=["group"])
Note: Intercept is never tested. Random effect grouping factors are excluded for mixed models.
compute_mee_inference¶
compute_mee_inference(mee: MeeState, vcov: np.ndarray, df_resid: float | np.ndarray | None, conf_level: float = 0.95, null: float = 0.0, alternative: str = 'two-sided') -> MeeState
Compute delta method inference for marginal effects.
SE = sqrt(diag(L @ vcov @ L'))
where L is the design matrix (X_ref for EMMs, selector for slopes).
Applies multiplicity adjustment for contrast CIs to match R’s emmeans:
Tukey HSD for pairwise contrasts.
Multivariate-t (MVT) for sequential, treatment, sum, helmert contrasts.
No adjustment for polynomial contrasts or non-contrast MEEs.
For response-scale EMMs (effect_scale="response"), CIs are computed on the
link scale then back-transformed via the inverse link function,
matching R’s emmeans behavior.
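The SE formula and the link-scale back-transformation can be sketched as follows. The coefficients, vcov, and logit link are assumptions for illustration, the z-distribution case (df_resid=None) is used so that only the standard library is needed, and no multiplicity adjustment is applied.

```python
import numpy as np
from statistics import NormalDist

# Hypothetical logistic model: intercept plus one treatment indicator
coef = np.array([-0.5, 1.0])
vcov = np.array([[0.04, -0.02],
                 [-0.02, 0.09]])
L = np.array([[1.0, 0.0],    # reference-grid row for level A
              [1.0, 1.0]])   # reference-grid row for level B

def inv_logit(eta):
    return 1.0 / (1.0 + np.exp(-eta))

# Link-scale estimates and delta-method standard errors
est_link = L @ coef
se_link = np.sqrt(np.diag(L @ vcov @ L.T))

# Two-sided 95% interval on the link scale (z quantile)
z = NormalDist().inv_cdf(0.975)
lower_link = est_link - z * se_link
upper_link = est_link + z * se_link

# Back-transform the estimate and both bounds through the inverse link
est_resp = inv_logit(est_link)
ci_lower = inv_logit(lower_link)
ci_upper = inv_logit(upper_link)
```

Because the inverse link is nonlinear, the back-transformed interval is asymmetric around `est_resp` and always stays inside (0, 1).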
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mee | MeeState | MeeState with L_matrix from compute_emm or compute_slopes. | required |
vcov | ndarray | Variance-covariance matrix of coefficients, shape (p, p). | required |
df_resid | float | ndarray | None | Degrees of freedom for t-distribution. Can be: - None: z-distribution (asymptotic normality). - scalar float: single df for all estimates (residual df). - np.ndarray: per-estimate Satterthwaite df (mixed models). | required |
conf_level | float | Confidence level for intervals (default 0.95). | 0.95 |
null | float | Null hypothesis value (default 0.0). | 0.0 |
alternative | str | Alternative hypothesis direction (default “two-sided”). | ‘two-sided’ |
Returns:
| Type | Description |
|---|---|
MeeState | MeeState augmented with inference fields (se, df, statistic, p_value, ci_lower, ci_upper, conf_level). |
Examples:
Compute inference for EMMs::
from marginal import compute_emm
from inference import compute_mee_inference
mee = compute_emm(bundle, fit, "treatment", "treatment")
mee_with_inf = compute_mee_inference(mee, fit.vcov, fit.df_resid)
# mee_with_inf.se: standard errors
# mee_with_inf.ci_lower, mee_with_inf.ci_upper: confidence bounds
Note: For EMMs, L_matrix is the reference design matrix X_ref. For slopes, L_matrix is a selector row that picks the coefficient. For contrasts, L_matrix is the contrast matrix applied to EMMs.
compute_mee_inference_fallback¶
compute_mee_inference_fallback(mee: 'MeeState', bundle: 'DataBundle', fit: 'FitState', data: 'pl.DataFrame', df_resid: float | None, conf_level: float = 0.95, null: float = 0.0, alternative: str = 'two-sided') -> 'MeeState'
Compute inference for MEE without L_matrix (fallback path).
Used when L_matrix is not available (legacy MEE). Computes simplified
SEs via compute_mee_se, then builds test statistics, p-values,
and confidence intervals.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mee | ‘MeeState’ | MeeState without L_matrix. | required |
bundle | ‘DataBundle’ | Data bundle with X_names. | required |
fit | ‘FitState’ | Fit state with vcov, sigma, etc. | required |
data | ‘pl.DataFrame’ | Original data frame for group counting. | required |
df_resid | float | None | Residual degrees of freedom (None for z-distribution). | required |
conf_level | float | Confidence level for intervals. | 0.95 |
null | float | Null hypothesis value (default 0.0). | 0.0 |
alternative | str | Alternative hypothesis direction (default “two-sided”). | ‘two-sided’ |
Returns:
| Type | Description |
|---|---|
‘MeeState’ | MeeState augmented with inference fields. |
compute_mee_se¶
compute_mee_se(mee: 'MeeState', bundle: 'DataBundle', fit: 'FitState', data: 'pl.DataFrame') -> np.ndarray
Compute standard errors for MEE estimates (means or slopes).
When an L_matrix is available, uses the delta method:
SE = sqrt(diag(L @ vcov @ L')). This correctly handles
multi-row L_matrices (e.g. per-group conditional slopes).
Otherwise falls back to simplified type-based dispatch:
"means": SE = sqrt(MSE / n_per_group) for each group level
"slopes": SE = sqrt(vcov[idx, idx]) from coefficient vcov diagonal
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mee | ‘MeeState’ | MeeState with type (“means” or “slopes”) and focal_var. | required |
bundle | ‘DataBundle’ | Data bundle with X_names. | required |
fit | ‘FitState’ | Fit state with sigma, dispersion, residuals, vcov, df_resid. | required |
data | ‘pl.DataFrame’ | Original data frame for group counting. | required |
Returns:
| Type | Description |
|---|---|
ndarray | Array of standard errors, one per estimate. |
compute_slopes¶
compute_slopes(bundle: 'DataBundle', fit: 'FitState', focal_var: str, explore_formula: str, *, spec: object | None = None, effect_scale: str = 'link') -> MeeState
Compute marginal slope for a continuous focal variable.
For a linear model without interactions, this extracts the coefficient for the variable. It does not average marginal effects across the grid, so it is not appropriate for models with interactions.
Marginal slopes answer: “By how much does the predicted outcome change for a one-unit increase in this variable?”
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
bundle | ‘DataBundle’ | DataBundle with model data. Used to extract: - X_names: to locate the variable’s coefficient index | required |
fit | ‘FitState’ | FitState with fitted coefficients. | required |
focal_var | str | Name of the continuous variable to get slope for. | required |
explore_formula | str | The explore formula string (for result metadata). | required |
spec | object | None | ModelSpec with link/family info (for effect_scale=“response”). | None |
effect_scale | str | Scale of estimates: "link" or "response". | ‘link’ |
Returns:
| Type | Description |
|---|---|
MeeState | MeeState with marginal slope estimate. |
Examples:
Compute marginal effect of age::
from marginal import compute_slopes
mee = compute_slopes(bundle, fit, "age", "age")
# For y ~ age: returns MeeState with estimate = [coef_age]
Note:
This is the fast path for Gaussian identity-link models without
interactions. For models with interactions or GLMs (non-identity
link), the dispatcher routes to compute_slopes_finite_diff
which handles both cases via centered finite differences.
compute_slopes_crossed¶
compute_slopes_crossed(bundle: 'DataBundle', fit: 'FitState', focal_var: str, *, resolved: object, data: pl.DataFrame, spec: 'ModelSpec | None' = None, formula_spec: 'FormulaSpec | None' = None, effect_scale: str = 'link', delta_frac: float = 0.001) -> MeeState
Compute crossed slopes over focal variable x condition grid.
Builds a Cartesian product of condition levels (grid categoricals,
grid numerics, scalar pins) and computes a finite-difference AME
for each combination. Analogous to _compute_emm_crossed but
for continuous focal variables.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
bundle | ‘DataBundle’ | DataBundle with model data and metadata. | required |
fit | ‘FitState’ | FitState with fitted coefficients. | required |
focal_var | str | Continuous variable to compute slopes for. | required |
resolved | object | ResolvedConditions with grid and scalar conditions. | required |
data | DataFrame | Raw model data (for computing ranges and covariate means). | required |
spec | ‘ModelSpec | None’ | ModelSpec with family/link info. | None |
formula_spec | ‘FormulaSpec | None’ | FormulaSpec with learned encoding for evaluate_newdata. | None |
effect_scale | str | "link" (linear predictor scale) or "response" (inverse-link / data scale). | ‘link’ |
delta_frac | float | Fraction of the focal variable’s range used as the finite-difference step size. | 0.001 |
Returns:
| Type | Description |
|---|---|
MeeState | MeeState with one slope estimate per condition combination. |
compute_slopes_finite_diff¶
compute_slopes_finite_diff(bundle: DataBundle, fit: FitState, focal_var: str, explore_formula: str, *, spec: ModelSpec, formula_spec: FormulaSpec, data: pl.DataFrame, how: str = 'mem', effect_scale: str = 'link', delta_frac: float = 0.001) -> MeeState
Compute marginal slopes via centered finite differences.
Supports two averaging methods via the how parameter:
"mem": Average over a balanced reference grid (Cartesian product of categorical levels, continuous covariates at means). Matches R's emmeans::emtrends.
"ame": Average over actual data rows. Gives a true Average Marginal Effect (AME) that preserves the observed covariate distribution.
For linear models (identity link), both approaches give identical results.
The computation proceeds in these steps:
delta = delta_frac × range(focal_var)
Build evaluation grid (balanced or observed data)
Perturb: grid_plus = grid[focal + delta/2], grid_minus = grid[focal − delta/2]
X_plus, X_minus = evaluate_newdata(formula_spec, grid_*)
L_diff = (X_plus − X_minus) / delta (link-scale functional)
If effect_scale="response" and how="ame": J_i = [f'(η_i+) X_i+ − f'(η_i−) X_i−] / delta (per-observation Jacobian with f' at perturbed points)
L_avg = mean(J, axis=0, keepdims=True)
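The link-scale part of these steps can be sketched in NumPy. The `design` function below is a hypothetical stand-in for evaluate_newdata, the model (columns 1, x, z, x*z) and coefficients are made up, and the averaging is over observed rows (the "ame" style).

```python
import numpy as np

def design(x, z):
    """Hypothetical stand-in for evaluate_newdata: columns 1, x, z, x*z."""
    return np.column_stack([np.ones_like(x), x, z, x * z])

coef = np.array([0.5, 2.0, -1.0, 0.3])   # hypothetical fitted coefficients
x = np.array([1.0, 2.0, 4.0, 5.0])       # focal variable
z = np.array([0.0, 1.0, 0.0, 1.0])       # covariate interacting with x

# Step size as a fraction of the focal variable's range
delta_frac = 0.001
delta = delta_frac * (x.max() - x.min())

# Design matrices at the two perturbed grids
X_plus = design(x + delta / 2, z)
X_minus = design(x - delta / 2, z)

# Link-scale functional, one row per observation, then averaged to one row
L_diff = (X_plus - X_minus) / delta
L_avg = L_diff.mean(axis=0, keepdims=True)   # shape (1, p)

slope = (L_avg @ coef)[0]
```

Each row of `L_diff` is approximately [0, 1, 0, z_i], so the averaged slope is coef_x + coef_xz * mean(z) = 2.15. Keeping `L_avg` as a single row is what allows delta-method inference via sqrt(L_avg @ vcov @ L_avg.T).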
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
bundle | DataBundle | DataBundle with model data and metadata. | required |
fit | FitState | FitState with fitted coefficients. | required |
focal_var | str | Continuous variable to compute the slope for. | required |
explore_formula | str | Explore formula string (for result metadata). | required |
spec | ModelSpec | ModelSpec with family/link info. | required |
formula_spec | FormulaSpec | FormulaSpec with learned encoding for evaluate_newdata. | required |
data | DataFrame | Raw model data (for computing ranges and covariate means). | required |
how | str | "mem" for balanced reference grid, "ame" for actual data rows. | ‘mem’ |
effect_scale | str | "link" (linear predictor scale) or "response" (inverse-link / data scale). | ‘link’ |
delta_frac | float | Fraction of the focal variable’s range used as the finite-difference step size. Default 0.001 matches R’s emmeans. | 0.001 |
Returns:
| Type | Description |
|---|---|
MeeState | MeeState with a single-row AME estimate and L_matrix for delta-method inference. |
dispatch_marginal_computation¶
dispatch_marginal_computation(parsed: ExploreFormulaSpec, bundle: DataBundle, fit: FitState, data: pl.DataFrame, *, spec: ModelSpec | None = None, formula_spec: object | None = None, varying_offsets: VaryingState | None = None, effect_scale: str = 'link', varying: str = 'exclude', how: str = 'auto', inverse_transforms: bool = True, by: str | None = None) -> MeeState
Route a parsed explore formula to the appropriate marginal computation.
Validates the focal variable, determines whether it is categorical or continuous, and dispatches to the correct computation function.
Supports effect_scale= and varying= for scale transforms and
conditional (group-specific) effects in mixed models. When RHS conditions
contain bracket contrasts (e.g., Drug[A - B] ~ Dose[High - Low]),
applies them as a post-processing step after the main computation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
parsed | ExploreFormulaSpec | Parsed explore formula with focal variable and contrast info. | required |
bundle | DataBundle | DataBundle with model data and metadata. | required |
fit | FitState | FitState with fitted coefficients. | required |
data | DataFrame | The model’s data DataFrame (for variable lookup and type detection). | required |
spec | ModelSpec | None | ModelSpec with link/family info (needed for effect_scale=“response”). | None |
formula_spec | object | None | FormulaSpec with learned encoding (needed for finite-diff slopes). | None |
varying_offsets | VaryingState | None | VaryingState with BLUPs (needed for conditional effects). | None |
effect_scale | str | "link" (default) or "response" (inverse-link / data scale). | ‘link’ |
varying | str | "exclude" (default) or "include" (conditional effects). | ‘exclude’ |
how | str | "auto" (default), "mem" (emmeans-style), or "ame" (g-computation / average marginal effect). | ‘auto’ |
inverse_transforms | bool | When True (default), raw variable names and values are auto-resolved through learned formula transforms. Set to False to use transformed-scale names/values directly. | True |
by | str | None | Grouping variable for faceted effects (default: None). | None |
Returns:
| Type | Description |
|---|---|
MeeState | MeeState with computed marginal effects. |
get_contrast_labels¶
get_contrast_labels(levels: list[str], contrast_type: str = 'pairwise') -> list[str]
Generate human-readable labels for contrasts.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
levels | list[str] | List of factor level names. | required |
contrast_type | str | Type of contrast (“pairwise”, “all_pairwise”, or “sequential”). | ‘pairwise’ |
Returns:
| Type | Description |
|---|---|
list[str] | List of contrast labels like “B - A”, “C - A”, etc. |
Examples:
>>> get_contrast_labels(["A", "B", "C"], "pairwise")
['B - A', 'C - A']
>>> get_contrast_labels(["A", "B", "C"], "all_pairwise")
['B - A', 'C - A', 'C - B']
parse_explore_formula¶
parse_explore_formula(formula: str, model_terms: list[str] | None = None) -> ExploreFormulaSpec
Parse an explore formula string.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
formula | str | Explore formula (e.g., "pairwise(treatment) ~ age@50"). | required |
model_terms | list[str] | None | Optional list of valid model terms for validation. | None |
Returns:
| Type | Description |
|---|---|
ExploreFormulaSpec | Parsed ExploreFormulaSpec object. |
Examples:
Simple term::
>>> parse_explore_formula("treatment")
ExploreFormulaSpec(focal_var='treatment', contrast_type=None, conditions=())
Contrast function::
>>> parse_explore_formula("pairwise(treatment)")
ExploreFormulaSpec(focal_var='treatment', contrast_type='pairwise', conditions=())
With conditioning::
>>> parse_explore_formula("treatment ~ age@50")
ExploreFormulaSpec(
    focal_var='treatment',
    contrast_type=None,
    conditions=(Condition(var='age', at_values=(50.0,)),)
)
resolve_conditions¶
resolve_conditions(conditions: tuple[Condition, ...], bundle: DataBundle, data: pl.DataFrame) -> ResolvedConditions
Classify each Condition into the appropriate typed bucket.
Uses bundle.factor_levels to distinguish categorical from continuous
variables, and the data DataFrame to compute range / quantile grids
for continuous variables.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
conditions | tuple[Condition, ...] | Parsed Condition objects from the explore formula. | required |
bundle | DataBundle | DataBundle with factor_levels metadata. | required |
data | DataFrame | Original data DataFrame for computing range/quantile values. | required |
Returns:
| Type | Description |
|---|---|
ResolvedConditions | ResolvedConditions with conditions classified into typed buckets. |
Modules¶
bracket_contrasts¶
Bracket contrast matrix builder and application.
Builds contrast weight matrices from bracket contrast expressions
(e.g., Drug[Active - Placebo], Drug[* - Placebo],
Drug[(A + B) - C]).
Auto-normalization: each side of - normalizes so weights sum to 1.
Wildcard expansion: * expands to all levels not mentioned on the
other side, producing one contrast row per expanded level.
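The normalization and wildcard rules can be sketched with a small helper. `bracket_weights` is a hypothetical function written for this illustration, not the module's API, and it takes already-parsed level lists rather than a ContrastExpr.

```python
import numpy as np

def bracket_weights(levels, plus, minus):
    """One contrast row: each side is normalized so its weights sum to 1."""
    w = np.zeros(len(levels))
    for name in plus:
        w[levels.index(name)] += 1.0 / len(plus)
    for name in minus:
        w[levels.index(name)] -= 1.0 / len(minus)
    return w

levels = ["A", "B", "C"]

# Drug[(A + B) - C]: the left side averages A and B
row = bracket_weights(levels, plus=["A", "B"], minus=["C"])

# Drug[* - C]: the wildcard expands to every level not named on the other
# side, giving one contrast row per expanded level
expanded = [lv for lv in levels if lv != "C"]
rows = np.vstack([bracket_weights(levels, [lv], ["C"]) for lv in expanded])
```

Here `row` is [0.5, 0.5, -1.0] and `rows` holds the two contrasts A - C and B - C. Every row sums to zero because each side contributes a total weight of +1 or -1.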
Functions:
| Name | Description |
|---|---|
apply_bracket_contrasts | Apply bracket contrast expression to an EMM MeeState. |
apply_bracket_contrasts_grouped | Apply bracket contrasts within each condition group of a crossed MeeState. |
apply_rhs_bracket_contrast | Apply a bracket contrast on a RHS condition column. |
build_bracket_contrast_matrix | Build contrast matrix and labels from bracket contrast expression. |
compute_compound_bracket_contrasts | Compute bracket contrasts for a compound focal variable. |
dispatch_bracket_contrasts | Compute bracket contrast expression for a categorical focal variable. |
Classes¶
Functions¶
apply_bracket_contrasts¶
apply_bracket_contrasts(mee_state: MeeState, expr: ContrastExpr) -> MeeState
Apply bracket contrast expression to an EMM MeeState.
Computes contrasts by building a weight matrix from the bracket expression and multiplying by the EMM estimates.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mee_state | MeeState | MeeState with EMM estimates. | required |
expr | ContrastExpr | ContrastExpr from parser. | required |
Returns:
| Type | Description |
|---|---|
MeeState | New MeeState with contrast estimates, type "contrasts", and contrast_method="custom". |
apply_bracket_contrasts_grouped¶
apply_bracket_contrasts_grouped(mee_state: MeeState, expr: ContrastExpr) -> MeeState
Apply bracket contrasts within each condition group of a crossed MeeState.
Similar to apply_contrasts_grouped() but for bracket contrast
expressions. Applies the same contrast matrix independently within
each group of n_focal consecutive rows.
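Applying one contrast matrix within each block of n_focal consecutive rows amounts to a reshape followed by a matrix product. The sketch below uses hypothetical estimates and operates on a bare array rather than a MeeState.

```python
import numpy as np

n_focal, n_groups = 3, 2

# Estimates laid out as n_focal consecutive rows per condition group
estimates = np.array([2.0, 3.5, 2.8,    # group 1: levels A, B, C
                      1.0, 1.5, 3.0])   # group 2: levels A, B, C

# Same contrast matrix for every group (rows: B - A, C - A, C - B)
C = np.array([[-1.0, 1.0, 0.0],
              [-1.0, 0.0, 1.0],
              [0.0, -1.0, 1.0]])

# One row per group, then apply C to each row
by_group = estimates.reshape(n_groups, n_focal)
contrasts = (by_group @ C.T).reshape(-1)   # n_contrasts values per group, group-major

# Equivalent block-diagonal form, useful when a single L_matrix is needed
block_C = np.kron(np.eye(n_groups), C)
```

The block-diagonal matrix `block_C` gives the same result as the reshape when multiplied by `estimates`, which shows that contrasts never mix estimates from different groups.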
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mee_state | MeeState | MeeState from crossed EMM computation with n_focal x n_groups rows. | required |
expr | ContrastExpr | ContrastExpr from parser. | required |
Returns:
| Type | Description |
|---|---|
MeeState | MeeState with n_contrasts x n_groups rows and condition columns. |
apply_rhs_bracket_contrast¶
apply_rhs_bracket_contrast(mee_state: MeeState, expr: ContrastExpr) -> MeeState
Apply a bracket contrast on a RHS condition column.
After the main computation produces a crossed MeeState with a condition
column (e.g., Dose with levels [High, Low]), this collapses
that column by applying the bracket contrast.
The condition column must vary slowest in the grid (standard layout from crossed EMM/slope computation).
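A minimal sketch of why the slowest-varying layout matters: when the condition column is the outer index, reshaping puts one condition level per row, and the contrast weights collapse that axis. The values are hypothetical and the code works on a bare array, not a MeeState.

```python
import numpy as np

# Crossed grid with Dose varying slowest and the focal Drug varying fastest.
# Row order: (High, A), (High, B), (Low, A), (Low, B)
estimates = np.array([5.0, 7.0, 4.0, 4.5])
n_cond, n_focal = 2, 2

# Dose[High - Low] as weights over the condition levels [High, Low]
w = np.array([1.0, -1.0])

# One row per condition level, one column per focal level
by_cond = estimates.reshape(n_cond, n_focal)

# Collapse the condition axis: one High - Low difference per focal level
collapsed = w @ by_cond
```

The result has one entry per focal level (A: 1.0, B: 2.5). If the condition column varied fastest instead, the same reshape would pair up the wrong cells.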
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mee_state | MeeState | MeeState from a crossed computation. Must contain expr.var as a column in the grid. | required |
expr | ContrastExpr | ContrastExpr specifying the contrast on the condition variable. | required |
Returns:
| Type | Description |
|---|---|
MeeState | New MeeState with the condition column replaced by contrast rows. |
build_bracket_contrast_matrix¶
build_bracket_contrast_matrix(expr: ContrastExpr, levels: list[str]) -> tuple[np.ndarray, list[str]]
Build contrast matrix and labels from bracket contrast expression.
Handles auto-normalization (each operand side sums to +1 or -1)
and wildcard expansion (* expands to all unmentioned levels).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
expr | ContrastExpr | ContrastExpr AST from the parser. | required |
levels | list[str] | Actual factor levels from the EMM grid. | required |
Returns:
| Type | Description |
|---|---|
tuple[ndarray, list[str]] | Tuple of (contrast_matrix, labels) where contrast_matrix has shape (n_contrasts, n_levels) and labels has length n_contrasts. |
compute_compound_bracket_contrasts¶
compute_compound_bracket_contrasts(bundle: object, fit: object, focal_var: str, contrast_expr: ContrastExpr, *, data: pl.DataFrame, spec: object | None = None, effect_scale: str = 'link', resolved: object | None = None) -> MeeState
Compute bracket contrasts for a compound focal variable.
Handles compound variables like Drug:Dose by building a crossed
EMM grid over the component variables, creating compound level names,
and then applying the bracket contrast.
For example, Drug:Dose[Active:High - Placebo:Low] builds EMMs over
Drug × Dose, labels each cell as Active:High, Active:Low, etc.,
and applies the contrast Active:High - Placebo:Low.
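The labelling and contrast steps for this example can be sketched as follows. The EMM values are hypothetical, and the cell ordering (first variable varying slowest) is an assumption of this sketch rather than a documented guarantee.

```python
import numpy as np
from itertools import product

drug_levels = ["Active", "Placebo"]
dose_levels = ["High", "Low"]

# Compound level names for every cell of the Drug x Dose grid
labels = [f"{drug}:{dose}" for drug, dose in product(drug_levels, dose_levels)]

# Hypothetical EMMs, one per compound cell, in the same order as labels
emm = np.array([6.0, 5.0, 3.5, 3.0])

# Contrast Active:High - Placebo:Low as a weight vector over compound cells
w = np.zeros(len(labels))
w[labels.index("Active:High")] = 1.0
w[labels.index("Placebo:Low")] = -1.0

estimate = w @ emm
```

Here `labels` is ['Active:High', 'Active:Low', 'Placebo:High', 'Placebo:Low'] and `estimate` is 6.0 - 3.0 = 3.0.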
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
bundle | object | DataBundle with model data and metadata. | required |
fit | object | FitState with fitted coefficients. | required |
focal_var | str | Compound variable name (e.g., "Drug:Dose"). | required |
contrast_expr | ContrastExpr | ContrastExpr from the parser. | required |
data | DataFrame | Model data DataFrame. | required |
spec | object | None | ModelSpec with link/family info. | None |
effect_scale | str | Scale of estimates: "link" or "response". | ‘link’ |
resolved | object | None | ResolvedConditions for additional conditioning. | None |
Returns:
| Type | Description |
|---|---|
MeeState | MeeState with compound bracket contrast estimates. |
dispatch_bracket_contrasts¶
dispatch_bracket_contrasts(bundle: DataBundle, fit: FitState, focal_var: str, contrast_expr: ContrastExpr, *, data: pl.DataFrame, spec: ModelSpec | None = None, effect_scale: str = 'link', how: str = 'mem', resolved: ResolvedConditions | None = None) -> MeeState
Compute bracket contrast expression for a categorical focal variable.
First computes EMMs, then applies the bracket contrast matrix.
When resolved contains grid conditions, computes crossed EMMs
and applies grouped bracket contrasts.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
bundle | DataBundle | DataBundle with model data and metadata. | required |
fit | FitState | FitState with fitted coefficients. | required |
focal_var | str | Name of the categorical variable. | required |
contrast_expr | ContrastExpr | ContrastExpr AST from the parser. | required |
data | DataFrame | Model data DataFrame. | required |
spec | ModelSpec | None | ModelSpec with link info (for effect_scale=“response”). | None |
effect_scale | str | Scale of estimates: "link" or "response". | ‘link’ |
how | str | Averaging method: "mem" or "ame". | ‘mem’ |
resolved | ResolvedConditions | None | Resolved conditions for conditioning. | None |
Returns:
| Type | Description |
|---|---|
MeeState | MeeState with bracket contrast estimates. |
compute¶
Marginal effects dispatch and routing.
Routes parsed explore formulas to the appropriate computation: EMMs for
categorical focal variables, slopes for continuous, or contrasts.
Supports effect_scale= (link/response scale), how= (mem/ame),
and varying= (marginal/conditional).
Functions:
| Name | Description |
|---|---|
compute_emm_categorical | Compute EMMs for a categorical focal variable. |
dispatch_marginal_computation | Route a parsed explore formula to the appropriate marginal computation. |
Classes¶
Functions¶
compute_emm_categorical¶
compute_emm_categorical(bundle: DataBundle, fit: FitState, focal_var: str, *, levels: list[str] | None = None, at_overrides: dict[str, float] | None = None, set_categoricals: dict[str, str] | None = None, spec: ModelSpec | None = None, effect_scale: str = 'link', how: str = 'mem') -> MeeState
Compute EMMs for a categorical focal variable.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
bundle | DataBundle | DataBundle with model data and metadata. | required |
fit | FitState | FitState with fitted coefficients. | required |
focal_var | str | Name of the categorical variable. | required |
levels | list[str] | None | Optional subset of levels to compute EMMs for (e.g. from cyl@[4, 8] syntax). If None, all levels are used. | None |
at_overrides | dict[str, float] | None | Optional covariate overrides for conditioning. | None |
set_categoricals | dict[str, str] | None | Optional dict pinning non-focal categoricals to specific levels (e.g. {"Ethnicity": "Asian"}). | None |
spec | ModelSpec | None | ModelSpec with link info (for effect_scale=“response”). | None |
effect_scale | str | Scale of estimates: "link" or "response". | ‘link’ |
how | str | Averaging method: "mem" or "ame". | ‘mem’ |
Returns:
| Type | Description |
|---|---|
MeeState | MeeState with grid of levels and their estimated means. |
dispatch_marginal_computation¶
dispatch_marginal_computation(parsed: ExploreFormulaSpec, bundle: DataBundle, fit: FitState, data: pl.DataFrame, *, spec: ModelSpec | None = None, formula_spec: object | None = None, varying_offsets: VaryingState | None = None, effect_scale: str = 'link', varying: str = 'exclude', how: str = 'auto', inverse_transforms: bool = True, by: str | None = None) -> MeeState
Route a parsed explore formula to the appropriate marginal computation.
Validates the focal variable, determines whether it is categorical or continuous, and dispatches to the correct computation function.
Supports effect_scale= and varying= for scale transforms and
conditional (group-specific) effects in mixed models. When RHS conditions
contain bracket contrasts (e.g., Drug[A - B] ~ Dose[High - Low]),
applies them as a post-processing step after the main computation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
parsed | ExploreFormulaSpec | Parsed explore formula with focal variable and contrast info. | required |
bundle | DataBundle | DataBundle with model data and metadata. | required |
fit | FitState | FitState with fitted coefficients. | required |
data | DataFrame | The model’s data DataFrame (for variable lookup and type detection). | required |
spec | ModelSpec | None | ModelSpec with link/family info (needed for effect_scale=“response”). | None |
formula_spec | object | None | FormulaSpec with learned encoding (needed for finite-diff slopes). | None |
varying_offsets | VaryingState | None | VaryingState with BLUPs (needed for conditional effects). | None |
effect_scale | str | "link" (default) or "response" (inverse-link / data scale). | ‘link’ |
varying | str | "exclude" (default) or "include" (conditional effects). | ‘exclude’ |
how | str | "auto" (default), "mem" (emmeans-style), or "ame" (g-computation / average marginal effect). | ‘auto’ |
inverse_transforms | bool | When True (default), raw variable names and values are auto-resolved through learned formula transforms. Set to False to use transformed-scale names/values directly. | True |
by | str | None | Grouping variable for faceted effects (default: None). | None |
Returns:
| Type | Description |
|---|---|
MeeState | MeeState with computed marginal effects. |
conditions¶
Condition resolution for explore formula RHS conditioning.
Resolves parsed Condition objects into typed buckets that downstream
EMM / slope / contrast computation can consume:
- at_overrides: single numeric pins (Income@50)
- set_categoricals: single categorical pins (Ethnicity@Asian)
- grid_categoricals: multi-level categoricals to cross (bare Ethnicity)
- grid_numerics: multi-value numerics to cross (Income@:range(5))
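The bucket rules above can be sketched as a small classifier. This is a minimal illustration, not the library's `resolve_conditions`: the `Buckets` class and `classify` function are hypothetical stand-ins, and range/quantile expansion is left out.

```python
from dataclasses import dataclass, field


@dataclass
class Buckets:
    """Hypothetical stand-in for ResolvedConditions (illustration only)."""
    at_overrides: dict = field(default_factory=dict)
    set_categoricals: dict = field(default_factory=dict)
    grid_categoricals: dict = field(default_factory=dict)
    grid_numerics: dict = field(default_factory=dict)


def classify(var, values, factor_levels, buckets):
    """Place one condition into a bucket.

    values: tuple of requested values, or None for a bare variable.
    factor_levels: dict mapping categorical variable names to their levels.
    """
    if var in factor_levels:
        # Bare categorical -> all levels; otherwise the requested subset.
        levels = list(factor_levels[var]) if values is None else [str(v) for v in values]
        if len(levels) == 1:
            buckets.set_categoricals[var] = levels[0]
        else:
            buckets.grid_categoricals[var] = levels
    else:
        if values is None:
            raise ValueError(f"numeric condition {var!r} needs explicit values in this sketch")
        nums = [float(v) for v in values]
        if len(nums) == 1:
            buckets.at_overrides[var] = nums[0]
        else:
            buckets.grid_numerics[var] = nums
    return buckets


factor_levels = {"Ethnicity": ["Asian", "Caucasian", "Hispanic"]}
b = Buckets()
classify("Income", (50,), factor_levels, b)       # single numeric pin
classify("Ethnicity", None, factor_levels, b)     # bare categorical -> grid
classify("Age", (30, 40), factor_levels, b)       # multi-value numeric -> grid
```

The single-value versus multi-value split is what decides whether a condition pins the reference grid or expands it.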
Classes:
| Name | Description |
|---|---|
ResolvedConditions | Typed buckets for resolved conditioning specifications. |
Functions:
| Name | Description |
|---|---|
combine_resolved | Merge two ResolvedConditions, with b taking precedence on conflicts. |
get_column_values | Extract a data column as a numpy array for range/quantile computation. |
resolve_conditions | Classify each Condition into the appropriate typed bucket. |
Classes¶
ResolvedConditions¶
Typed buckets for resolved conditioning specifications.
Attributes:
| Name | Type | Description |
|---|---|---|
at_overrides | dict[str, float] | Single numeric pin per variable (e.g. Income@50). |
set_categoricals | dict[str, str] | Single categorical pin per variable (e.g. Ethnicity@Asian). |
grid_categoricals | dict[str, list[str]] | Multi-level categoricals to cross (e.g. bare Ethnicity → all levels, or Ethnicity@(Asian, Caucasian) → those two levels). |
grid_numerics | dict[str, list[float]] | Multi-value numerics to cross (e.g. Income@(10, 20) or Income@:range(5)). |
Attributes¶
at_overrides¶
at_overrides: dict[str, float]
grid_categoricals¶
grid_categoricals: dict[str, list[str]]
grid_numerics¶
grid_numerics: dict[str, list[float]]
has_grid¶
has_grid: bool
Return True if any condition requires grid expansion.
raw_at_overrides¶
raw_at_overrides: dict[str, float] = Factory(dict)
Pre-transform at-overrides with original (raw) variable names.
Populated by the dispatch layer after _resolve_at_overrides remaps
keys to design-matrix names. The slopes path needs raw keys because it
operates on a data grid with raw column names, not design-matrix names.
Empty when no transform resolution was applied.
set_categoricals¶
set_categoricals: dict[str, str]
Functions¶
combine_resolved¶
combine_resolved(a: ResolvedConditions, b: ResolvedConditions) -> ResolvedConditions
Merge two ResolvedConditions, with b taking precedence on conflicts.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
a | ResolvedConditions | First resolved conditions. | required |
b | ResolvedConditions | Second resolved conditions (takes precedence). | required |
Returns:
| Type | Description |
|---|---|
ResolvedConditions | Merged ResolvedConditions. |
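The precedence rule can be illustrated with plain dictionaries. This is a sketch under the assumption that each bucket is merged key by key; `merge_buckets` is a hypothetical helper, and how the real function treats a variable that appears in different buckets of a and b is not specified here.

```python
def merge_buckets(a: dict, b: dict) -> dict:
    """Merge two {bucket_name: {var: value}} mappings; b wins on key conflicts."""
    return {
        name: {**a.get(name, {}), **b.get(name, {})}
        for name in a.keys() | b.keys()
    }


a = {"at_overrides": {"Income": 50.0, "Age": 30.0}}
b = {"at_overrides": {"Income": 75.0}, "set_categoricals": {"Ethnicity": "Asian"}}
merged = merge_buckets(a, b)
# Income comes from b (the later argument); Age survives from a.
```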
get_column_values¶
get_column_values(data: pl.DataFrame, bundle: DataBundle, var: str) -> np.ndarray
Extract a data column as a numpy array for range/quantile computation.
Looks in the original data DataFrame first, falling back to the
design matrix column if the variable name matches an X_names entry.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
data | DataFrame | Original data DataFrame. | required |
bundle | DataBundle | DataBundle with design matrix. | required |
var | str | Variable name. | required |
Returns:
| Type | Description |
|---|---|
ndarray | 1-D numpy array of column values. |
resolve_conditions¶
resolve_conditions(conditions: tuple[Condition, ...], bundle: DataBundle, data: pl.DataFrame) -> ResolvedConditions
Classify each Condition into the appropriate typed bucket.
Uses bundle.factor_levels to distinguish categorical from continuous
variables, and the data DataFrame to compute range / quantile grids
for continuous variables.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
conditions | tuple[Condition, ...] | Parsed Condition objects from the explore formula. | required |
bundle | DataBundle | DataBundle with factor_levels metadata. | required |
data | DataFrame | Original data DataFrame for computing range/quantile values. | required |
Returns:
| Type | Description |
|---|---|
ResolvedConditions | ResolvedConditions with conditions classified into typed buckets. |
contrasts¶
Contrast computation for marginal effects.
Applies contrast matrices to EMM estimates. Matrix builders live in
marginal.matrices.
- apply_contrasts: Apply contrasts to MeeState
- apply_contrasts_grouped: Apply contrasts within condition groups
- compute_contrasts: Apply contrast matrix to EMM vector
Functions:
| Name | Description |
|---|---|
apply_contrasts | Apply contrast matrix to marginal means/effects. |
apply_contrasts_grouped | Apply contrasts within each condition group of a crossed MeeState. |
compute_contrasts | Apply contrast matrix to EMMs. |
dispatch_contrasts | Compute contrasts for a categorical focal variable. |
Classes¶
Functions¶
apply_contrasts¶
apply_contrasts(mee_state: MeeState, contrast_type: str, fit: FitState | None = None, *, degree: int | None = None, ref_idx: int | None = None, level_ordering: tuple[str, ...] | None = None) -> MeeState
Apply contrast matrix to marginal means/effects.
High-level function that takes a MeeState with EMM estimates and returns a new MeeState with contrast estimates.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mee_state | MeeState | MeeState containing EMM estimates from compute_emm(). | required |
contrast_type | str | Type of contrast to apply: - “pairwise”: All pairwise comparisons (B-A, C-A, C-B, ...) - “sequential”: Adjacent differences (B-A, C-B, D-C, ...) - “poly”: Orthogonal polynomial contrasts - “treatment”: Each level vs reference (requires ref_idx) - “sum”: Each level vs grand mean - “helmert”: Each level vs mean of previous levels | required |
fit | FitState | None | FitState with vcov for variance propagation (optional, only needed if computing SE during this call). | None |
degree | int | None | Maximum polynomial degree for poly contrasts (default: n-1). | None |
ref_idx | int | None | Reference level index for treatment contrasts (0-based). | None |
level_ordering | tuple[str, ...] | None | Explicit level ordering for order-dependent contrasts (e.g. from poly(dose, [low, med, high])). When provided, EMM estimates and grid are reordered to match before applying the contrast matrix. | None |
Returns:
| Type | Description |
|---|---|
MeeState | New MeeState with: grid (DataFrame with contrast labels), estimate (contrast estimates), type (“contrasts”), focal_var (same as input), explore_formula (updated to reflect the contrast). |
Examples:
>>> mee = compute_emm(bundle, fit, "treatment", "treatment")
>>> contrasts = apply_contrasts(mee, "pairwise")
>>> contrasts.estimate  # B-A, C-A, C-B differences
Note: Variance propagation for inference is deferred to .infer(). This function only computes point estimates.
apply_contrasts_grouped¶
apply_contrasts_grouped(mee_state: MeeState, contrast_type: str | None, *, degree: int | None = None, ref_idx: int | None = None, level_ordering: tuple[str, ...] | None = None) -> MeeState
Apply contrasts within each condition group of a crossed MeeState.
When EMMs have been computed over a crossed grid (focal levels x
condition groups), this function applies the contrast matrix
independently within each group of n_focal consecutive rows,
then stacks the results.
The number of focal levels is inferred from the focal_var column of
mee_state.grid.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mee_state | MeeState | MeeState from _compute_emm_crossed() with n_focal x n_groups rows ordered as [focal_1_group_1, ..., focal_k_group_1, focal_1_group_2, ...]. | required |
contrast_type | str | None | Type of contrast (pairwise, sequential, poly, treatment, sum, helmert). | required |
degree | int | None | Degree for polynomial contrasts. | None |
ref_idx | int | None | Reference level index for treatment contrasts. | None |
level_ordering | tuple[str, ...] | None | Explicit level ordering for order-dependent contrasts. When provided, EMM rows within each group are reordered before applying the contrast matrix. | None |
Returns:
| Type | Description |
|---|---|
MeeState | MeeState with n_contrasts x n_groups rows and condition columns. |
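The "apply within each block of n_focal consecutive rows, then stack" step can be sketched with a reshape. This is an illustration on a bare estimate vector only; `contrasts_by_group` is a hypothetical helper, and the real function also carries the grid and condition columns.

```python
import numpy as np


def contrasts_by_group(estimates, C):
    """Apply contrast matrix C to each block of n_focal consecutive estimates."""
    n_contrasts, n_focal = C.shape
    est = np.asarray(estimates, dtype=float)
    if est.size % n_focal != 0:
        raise ValueError("number of estimates must be a multiple of n_focal")
    blocks = est.reshape(-1, n_focal)    # shape (n_groups, n_focal)
    return (blocks @ C.T).reshape(-1)    # stacked, group by group


# Sequential contrasts for 3 focal levels: B - A, C - B
C_seq = np.array([[-1.0, 1.0, 0.0],
                  [0.0, -1.0, 1.0]])
# Two condition groups, focal levels ordered A, B, C within each group
emm = [1.0, 2.0, 4.0, 10.0, 13.0, 17.0]
grouped = contrasts_by_group(emm, C_seq)
```

The reshape relies on the row ordering described for mee_state: focal levels vary fastest, condition groups slowest.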
compute_contrasts¶
compute_contrasts(emm: np.ndarray, contrast_matrix: np.ndarray) -> np.ndarray
Apply contrast matrix to EMMs.
Transforms a vector of estimated marginal means into contrast estimates
by matrix multiplication. This is the final step when the explore formula
includes a contrast function like pairwise(treatment).
contrasts = C @ EMM
Where:
- C is the contrast matrix, shape (n_contrasts, n_levels)
- EMM is the vector of marginal means, shape (n_levels,)
- contrasts is the result, shape (n_contrasts,)
Each row of C sums to 0 (contrasts compare levels; they do not estimate absolute levels), so each element of the result represents a comparison (e.g., B - A).
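To make the zero-row-sum property concrete, the sketch below builds an all-pairwise matrix by hand and checks that shifting every mean by a constant leaves the contrasts unchanged. The helper `all_pairwise` is a hypothetical illustration and is not claimed to match the output of `build_all_pairwise_matrix` beyond the row layout shown in this section.

```python
from itertools import combinations

import numpy as np


def all_pairwise(n_levels):
    """Rows are (level j) - (level i) for every pair i < j."""
    pairs = list(combinations(range(n_levels), 2))
    C = np.zeros((len(pairs), n_levels))
    for row, (i, j) in enumerate(pairs):
        C[row, i] = -1.0
        C[row, j] = 1.0
    return C


C = all_pairwise(3)                      # rows: B - A, C - A, C - B
emm = np.array([2.0, 3.5, 2.8])
contrasts = C @ emm
shifted = C @ (emm + 10.0)               # add a constant to every mean
row_sums = C.sum(axis=1)
```

Because every row sums to zero, `contrasts` and `shifted` agree: a contrast only sees differences between levels.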
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
emm | ndarray | Array of estimated marginal means from compute_emm(). Shape: (n_levels,) | required |
contrast_matrix | ndarray | Contrast matrix from build_all_pairwise_matrix(), build_sequential_matrix(), or build_poly_matrix(). Shape: (n_contrasts, n_levels) Rows must be orthogonal to the constant vector (sum to 0). | required |
Returns:
| Type | Description |
|---|---|
ndarray | Array of contrast estimates. Shape: (n_contrasts,) |
Integration with explore(): Called after compute_emm when a contrast function is specified::
# In model._compute_marginal_effects():
if parsed.has_contrast:
    # First compute EMMs
    emms = compute_emm(...)
    # Build appropriate contrast matrix
    if parsed.contrast_type == "pairwise":
        C = build_all_pairwise_matrix(n_levels)
    elif parsed.contrast_type == "sequential":
        C = build_sequential_matrix(n_levels)
    elif parsed.contrast_type == "poly":
        C = build_poly_matrix(n_levels, degree=parsed.contrast_degree)
    # Apply contrasts
    contrast_estimates = compute_contrasts(emms, C)
    return build_mee_state(
        grid=contrast_grid,  # with contrast labels
        estimate=contrast_estimates,
        mee_type="contrasts",
        ...
    )
Note: For inference on contrasts, the variance of contrasts is: Var(C @ EMM) = C @ Var(EMM) @ C.T
This is handled by the inference methods (delta method for asymp, direct computation for bootstrap).
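The variance formula can be checked numerically with a few lines of NumPy. The covariance matrix below is made up for illustration, and the sketch shows only the linear-algebra step, not how the library's inference methods obtain Var(EMM).

```python
import numpy as np

# Pairwise contrast matrix for 3 levels: B - A, C - A, C - B
C = np.array([[-1.0, 1.0, 0.0],
              [-1.0, 0.0, 1.0],
              [0.0, -1.0, 1.0]])

# Hypothetical covariance of the three EMMs (independent here for simplicity)
vcov_emm = np.diag([0.04, 0.09, 0.16])

vcov_contrasts = C @ vcov_emm @ C.T       # Var(C @ EMM) = C Var(EMM) C^T
se_contrasts = np.sqrt(np.diag(vcov_contrasts))
```

With uncorrelated EMMs, each contrast variance is the sum of the two level variances; off-diagonal terms in Var(EMM) would enter through the same product.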
Examples:
Apply pairwise contrasts to 3-level EMMs::
emm = np.array([2.0, 3.5, 2.8])  # A, B, C means
C = build_all_pairwise_matrix(3)
# C = [[-1, 1, 0],   # B - A
#      [-1, 0, 1],   # C - A
#      [0, -1, 1]]   # C - B
contrasts = compute_contrasts(emm, C)
# contrasts: array([1.5, 0.8, -0.7])
dispatch_contrasts¶
dispatch_contrasts(bundle: DataBundle, fit: FitState, focal_var: str, contrast_type: str | None, contrast_degree: int | None, data: pl.DataFrame, *, spec: ModelSpec | None = None, effect_scale: str = 'link', how: str = 'mem', resolved: ResolvedConditions | None = None, focal_at_values: tuple[float | str, ...] | None = None, contrast_ref: str | None = None, level_ordering: tuple[str, ...] | None = None) -> MeeState
Compute contrasts for a categorical focal variable.
First computes EMMs, then applies the requested contrast matrix.
When resolved contains grid conditions, computes crossed EMMs and
applies grouped contrasts. When focal_at_values is provided
(e.g. from pairwise(cyl@[4, 8])), contrasts are computed only
over the requested subset of levels.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
bundle | DataBundle | DataBundle with model data and metadata. | required |
fit | FitState | FitState with fitted coefficients. | required |
focal_var | str | Name of the categorical variable. | required |
contrast_type | str | None | Type of contrast (pairwise, sequential, poly, treatment, sum, helmert). | required |
contrast_degree | int | None | Degree for polynomial contrasts (None = max). | required |
data | DataFrame | Model data for level extraction. | required |
spec | ModelSpec | None | ModelSpec with link info (for effect_scale=“response”). | None |
effect_scale | str | Scale of estimates: "link" or "response". | ‘link’ |
how | str | Averaging method: "mem" or "ame". | ‘mem’ |
resolved | ResolvedConditions | None | Resolved conditions for conditioning. | None |
focal_at_values | tuple[float | str, ...] | None | Optional subset of levels from at-spec syntax (e.g. pairwise(cyl@[4, 8])). | None |
contrast_ref | str | None | Reference level name for treatment contrasts. | None |
level_ordering | tuple[str, ...] | None | Explicit level ordering for order-dependent contrasts (e.g. from poly(dose, [low, med, high])). | None |
Returns:
| Type | Description |
|---|---|
MeeState | MeeState with contrast estimates. |
emm¶
Estimated marginal means computation.
This module provides EMM computation for categorical focal variables.
compute_emm: Compute estimated marginal means at grid points
Functions:
| Name | Description |
|---|---|
compute_conditional_emm | Compute per-group conditional EMMs incorporating intercept BLUPs. |
compute_emm | Compute estimated marginal means for a categorical focal variable. |
compute_emm_crossed | Compute crossed EMMs over focal levels x condition grid. |
Classes¶
Functions¶
compute_conditional_emm¶
compute_conditional_emm(bundle: 'DataBundle', fit: 'FitState', focal_var: str, explore_formula: str, *, spec: object, varying_offsets: object, grouping_var: str, effect_scale: str = 'link', levels: list[str] | None = None, at_overrides: dict[str, float] | None = None, set_categoricals: dict[str, str] | None = None) -> MeeState
Compute per-group conditional EMMs incorporating intercept BLUPs.
For each group g and each focal level, the conditional EMM is: η_g = X_ref_row @ β + b_intercept_g (+ other BLUP contributions)
When effect_scale=“response”: estimates = g⁻¹(η_g).
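The per-group formula can be traced with toy numbers. This sketch assumes a logit link, a two-level focal factor, and random intercepts only; the coefficients, BLUP values, and the `inv_logit` helper are invented for illustration and are not library objects.

```python
import numpy as np


def inv_logit(eta):
    """Inverse of the logit link, used here as an example g^-1."""
    return 1.0 / (1.0 + np.exp(-eta))


beta = np.array([1.0, 0.5])              # intercept, indicator for level B
X_ref = np.array([[1.0, 0.0],            # reference-grid row for level A
                  [1.0, 1.0]])           # reference-grid row for level B
blups = {"g1": -0.2, "g2": 0.3}          # hypothetical intercept BLUP per group

eta_fixed = X_ref @ beta                 # population-level linear predictor
eta_by_group = {g: eta_fixed + b for g, b in blups.items()}
response_by_group = {g: inv_logit(eta) for g, eta in eta_by_group.items()}
```

Each group shifts the whole vector of focal-level predictions by its own intercept offset before the optional inverse-link transform.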
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
bundle | ‘DataBundle’ | DataBundle with model data and metadata. | required |
fit | ‘FitState’ | FitState with fitted coefficients. | required |
focal_var | str | Name of the categorical variable. | required |
explore_formula | str | The explore formula string. | required |
spec | object | ModelSpec with link function info. | required |
varying_offsets | object | VaryingState with BLUPs per group. | required |
grouping_var | str | Name of the grouping variable. | required |
effect_scale | str | Scale of estimates: "link" or "response". | ‘link’ |
levels | list[str] | None | Optional list of focal levels. | None |
at_overrides | dict[str, float] | None | Optional covariate overrides. | None |
set_categoricals | dict[str, str] | None | Optional dict pinning non-focal categoricals. | None |
Returns:
| Type | Description |
|---|---|
MeeState | MeeState with per-(level, group) estimates. |
compute_emm¶
compute_emm(bundle: 'DataBundle', fit: 'FitState', focal_var: str, explore_formula: str, *, levels: list[str] | None = None, at_overrides: dict[str, float] | None = None, set_categoricals: dict[str, str] | None = None, spec: object | None = None, how: str = 'mem', effect_scale: str = 'link') -> MeeState
Compute estimated marginal means for a categorical focal variable.
Supports two averaging methods via the how parameter:
- "mem" (default): Marginal Estimated Mean. Balanced reference grid (emmeans-style). Predictions at a grid where covariates are at their means.
- "ame": Average Marginal Effect / g-computation. For each focal level, sets every observation to that level and averages the resulting predictions. Preserves the observed covariate distribution.
For linear models (identity link), both approaches give identical results.
For GLMs (non-identity link), they diverge because
mean(g⁻¹(Xᵢβ)) ≠ g⁻¹(mean(Xᵢ) · β).
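A small numeric check of this inequality, using a logit link and invented coefficients, is sketched below. It does not call the library; it only contrasts averaging predictions ("ame"-style) with predicting at the averaged covariates ("mem"-style).

```python
import numpy as np


def inv_logit(eta):
    return 1.0 / (1.0 + np.exp(-eta))


beta = np.array([-1.0, 2.0])                  # hypothetical intercept and slope
x = np.array([-2.0, 0.0, 2.0])                # observed covariate values
X = np.column_stack([np.ones_like(x), x])     # design matrix with intercept

# Link scale: averaging before or after the linear predictor is the same
link_avg_of_preds = (X @ beta).mean()
link_pred_at_mean = X.mean(axis=0) @ beta

# Response scale: the nonlinear inverse link breaks the equality
ame_style = inv_logit(X @ beta).mean()        # average of predictions
mem_style = inv_logit(X.mean(axis=0) @ beta)  # prediction at the average row
```

The gap between `ame_style` and `mem_style` grows with the spread of the linear predictor, which is why the two methods agree for identity-link models.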
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
bundle | ‘DataBundle’ | DataBundle with model data and metadata. Used to extract: - X: design matrix for computing covariate means - X_names: column names to identify variable types - factor_levels: level metadata for categorical variables | required |
fit | ‘FitState’ | FitState with fitted coefficients (fit.coef array). | required |
focal_var | str | Name of the categorical variable to compute EMMs for. | required |
explore_formula | str | The explore formula string (for result metadata). | required |
levels | list[str] | None | Optional list of levels to compute EMMs for. If None, uses all levels from bundle.factor_levels. | None |
at_overrides | dict[str, float] | None | Optional dict of {variable: value} to fix specific covariates at given values instead of their means. Used for conditioning (e.g., cyl ~ wt@[3.0] fixes wt at 3.0). | None |
set_categoricals | dict[str, str] | None | Optional dict mapping non-focal categorical variable names to specific levels to pin them at (instead of marginalizing at column means). E.g. {"Ethnicity": "Asian"}. | None |
spec | object | None | ModelSpec with link/family info (needed for effect_scale=“response”). | None |
how | str | Averaging method: "mem" for balanced reference grid (emmeans-style), "ame" for g-computation over observed data. | ‘mem’ |
effect_scale | str | Scale of estimates: "link" or "response". | ‘link’ |
Returns:
| Type | Description |
|---|---|
MeeState | MeeState with grid of levels and their estimated means. |
Examples:
Compute EMMs with default (reference grid) averaging::
mee = compute_emm(bundle, fit, "treatment", "treatment")
G-computation averaging for a logistic regression::
mee = compute_emm(bundle, fit, "treatment", "treatment",
                  spec=spec, how="ame", effect_scale="response")
Note:
For how="ame" with effect_scale="response" on a non-identity
link, confidence intervals use the response-scale delta method
(symmetric CIs) rather than the link-scale back-transformation
(asymmetric CIs) used by how="mem".
compute_emm_crossed¶
compute_emm_crossed(bundle: DataBundle, fit: FitState, focal_var: str, *, resolved: ResolvedConditions, levels: list[str] | None = None, spec: ModelSpec | None = None, effect_scale: str = 'link', how: str = 'mem') -> MeeState
Compute crossed EMMs over focal levels x condition grid.
Builds a Cartesian product of focal levels with all grid conditions, computing one EMM row per combination.
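The Cartesian product can be sketched with `itertools.product`. The helper `crossed_rows` and the variable names are hypothetical; the row order (focal levels varying fastest within each condition combination) follows the layout that `apply_contrasts_grouped` is documented to expect, which is an assumption about this function's output.

```python
from itertools import product


def crossed_rows(focal_var, focal_levels, grid):
    """One dict per (condition combination, focal level), focal varying fastest."""
    names = list(grid)
    rows = []
    for combo in product(*(grid[name] for name in names)):
        for level in focal_levels:
            rows.append({focal_var: level, **dict(zip(names, combo))})
    return rows


grid = {"Ethnicity": ["Asian", "Caucasian"], "Income": [10.0, 20.0]}
rows = crossed_rows("treatment", ["A", "B"], grid)
# 2 focal levels x 2 ethnicities x 2 incomes = 8 rows
```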
Supports two averaging methods via how:
- "mem": Reference grid at covariate means (emmeans-style).
- "ame": G-computation over observed data rows with counterfactual treatment assignment per combo.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
bundle | DataBundle | DataBundle with model data and metadata. | required |
fit | FitState | FitState with fitted coefficients. | required |
focal_var | str | Name of the categorical focal variable. | required |
resolved | ResolvedConditions | ResolvedConditions with grid and scalar conditions. | required |
levels | list[str] | None | Optional subset of focal levels to include (e.g. from treatment@[A, B]). If None, all levels are used. | None |
spec | ModelSpec | None | ModelSpec with link info (for effect_scale=“response”). | None |
effect_scale | str | Scale of estimates: "link" or "response". | ‘link’ |
how | str | Averaging method: "mem" or "ame". | ‘mem’ |
Returns:
| Type | Description |
|---|---|
MeeState | MeeState with n_focal x n_grid rows and condition columns. |
explore¶
Explore formula parser.
Tokenizes and parses explore formula strings into :class:ExploreFormulaSpec
containers using :class:ExploreScanner and :class:ExploreParser.
Grammar::
explore_formula := lhs [ '~' rhs ]
lhs := focal_term | contrast_fn '(' focal_term { ',' arg } ')'
focal_term := varname [ '@' at_spec ]
at_spec := value | '[' value_list ']' | [':']'range(' n ')' | [':']'quantile(' n ')'
contrast_fn := 'pairwise' | 'sequential' | 'poly'
| 'treatment' | 'dummy' | 'sum' | 'deviation' | 'helmert'
arg := IDENTIFIER '=' value | value
rhs := condition [ '+' condition ]*
condition := term [ '@' at_spec ]
Classes:
| Name | Description |
|---|---|
Condition | A conditioning specification in explore formula. |
ExploreFormulaError | Error in explore formula syntax. |
ExploreFormulaSpec | Parsed explore formula. |
Functions:
| Name | Description |
|---|---|
parse_explore_formula | Parse an explore formula string. |
parse_lhs | Parse LHS of explore formula. |
Classes¶
Condition¶
A conditioning specification in explore formula.
Represents a variable to condition on, optionally with specific values, a method for generating values, or a bracket contrast expression.
Created by: parse_explore_formula(), ExploreParser
Consumed by: resolve_all_conditions(), dispatch_marginal_computation()
Augmented by: Never
Attributes:
| Name | Type | Description |
|---|---|---|
var | str | Variable name to condition on. |
at_values | tuple | None | Specific values to evaluate at (e.g., (50.0,) or (“A”, “B”)). |
at_range | int | None | Number of evenly-spaced values across the variable’s range. |
at_quantile | int | None | Number of quantile values to use. |
contrast_expr | ContrastExpr | None | Bracket contrast expression on this condition variable (e.g., from Dose[High - Low] on the RHS). When set, the variable is treated as a grid categorical during condition resolution, and the contrast is applied as a post-processing step. |
Attributes¶
at_quantile¶
at_quantile: int | None = field(default=None, validator=is_optional_positive_int)
at_range¶
at_range: int | None = field(default=None, validator=is_optional_positive_int)
at_values¶
at_values: tuple | None = field(default=None)
contrast_expr¶
contrast_expr: ContrastExpr | None = field(default=None, validator=(validators.optional(validators.instance_of(ContrastExpr))))
var¶
var: str = field(validator=(validators.instance_of(str)))
ExploreFormulaError¶
ExploreFormulaError(message: str, formula: str, position: int | None = None) -> None
Bases: ValueError
Error in explore formula syntax.
Provides helpful error messages with position indicators for syntax errors.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
message | str | Error description. | required |
formula | str | The formula that caused the error. | required |
position | int | None | Character position of the error (optional). | None |
Attributes:
| Name | Type | Description |
|---|---|---|
formula | ||
position |
Attributes¶
formula¶
formula = formula
position¶
position = position
ExploreFormulaSpec¶
Parsed explore formula.
Represents a parsed explore formula with focal variable, optional contrast type, and conditioning specifications.
Created by: parse_explore_formula()
Consumed by: dispatch_marginal_computation(), plot_predict(), plot_explore()
Augmented by: attrs.evolve() in resolve_focal_at_spec() materializes focal at-values
Attributes:
| Name | Type | Description |
|---|---|---|
focal_var | str | The variable to compute marginal effects for. |
contrast_type | str | None | Type of contrast (pairwise, sequential, poly, treatment, sum, helmert, custom) or None for simple EMMs. Set to "custom" for bracket contrast expressions. |
contrast_degree | int | None | Degree parameter for polynomial contrasts (default None means use n_levels - 1, i.e., maximum degree). |
contrast_ref | str | None | Reference level for treatment/dummy contrasts (e.g., "Placebo" from treatment(Drug, ref=Placebo)). |
contrast_level_ordering | tuple[str, ...] | None | Explicit level ordering for order-dependent contrasts (helmert, sequential, poly). Parsed from bracket list syntax, e.g. poly(dose, [low, med, high]). |
contrast_expr | ContrastExpr | None | Bracket contrast expression AST (e.g., from Drug[Active - Placebo] syntax). None for named contrast functions or simple EMMs. |
conditions | tuple[Condition, ...] | Tuple of Condition objects specifying conditioning variables. |
focal_at_values | tuple[float | str, ...] | None | Specific values to evaluate the focal variable at (e.g., from Days@[0, 3, 6, 9] syntax). None means use all levels. |
focal_at_range | int | None | Number of evenly-spaced values across the focal variable’s range (e.g., from Days@range(5) syntax). None means not set. |
focal_at_quantile | int | None | Number of quantile values for the focal variable (e.g., from Days@quantile(3) syntax). None means not set. |
Attributes¶
conditions¶
conditions: tuple[Condition, ...] = field(factory=tuple, converter=tuple)
contrast_degree¶
contrast_degree: int | None = field(default=None, validator=is_optional_positive_int)
contrast_expr¶
contrast_expr: ContrastExpr | None = field(default=None)
contrast_level_ordering¶
contrast_level_ordering: tuple[str, ...] | None = field(default=None, validator=is_optional_tuple_of_str)
contrast_ref¶
contrast_ref: str | None = field(default=None, validator=is_optional_str)
contrast_type¶
contrast_type: str | None = field(default=None, validator=(validators.optional(is_choice_str(('pairwise', 'sequential', 'poly', 'treatment', 'sum', 'helmert', 'custom')))))
focal_at_quantile¶
focal_at_quantile: int | None = field(default=None, validator=is_optional_positive_int)
focal_at_range¶
focal_at_range: int | None = field(default=None, validator=is_optional_positive_int)
focal_at_values¶
focal_at_values: tuple[float | str, ...] | None = field(default=None)
focal_var¶
focal_var: str = field(validator=(validators.instance_of(str)))
has_conditions¶
has_conditions: bool
Return True if conditioning variables are specified.
has_contrast¶
has_contrast: bool
Return True if any contrast is specified (named function or bracket expr).
has_contrast_expr¶
has_contrast_expr: bool
Return True if a bracket contrast expression is specified.
has_rhs_contrasts¶
has_rhs_contrasts: bool
Return True if any RHS condition has a bracket contrast expression.
Functions¶
parse_explore_formula¶
parse_explore_formula(formula: str, model_terms: list[str] | None = None) -> ExploreFormulaSpec
Parse an explore formula string.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
formula | str | Explore formula (e.g., "pairwise(treatment) ~ age@50"). | required |
model_terms | list[str] | None | Optional list of valid model terms for validation. | None |
Returns:
| Type | Description |
|---|---|
ExploreFormulaSpec | Parsed ExploreFormulaSpec object. |
Examples:
Simple term::
>>> parse_explore_formula("treatment")
ExploreFormulaSpec(focal_var='treatment', contrast_type=None, conditions=())
Contrast function::
>>> parse_explore_formula("pairwise(treatment)")
ExploreFormulaSpec(focal_var='treatment', contrast_type='pairwise', conditions=())
With conditioning::
>>> parse_explore_formula("treatment ~ age@50")
ExploreFormulaSpec(
    focal_var='treatment',
    contrast_type=None,
    conditions=(Condition(var='age', at_values=(50.0,)),)
)
parse_lhs¶
parse_lhs(lhs: str, formula: str) -> tuple[str, str | None, int | None, tuple[float | str, ...] | None, int | None, int | None]
Parse LHS of explore formula.
Backward-compatible wrapper that tokenizes and parses just the LHS portion of an explore formula.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
lhs | str | Left-hand side of the formula. | required |
formula | str | Full formula for error messages. | required |
Returns:
| Type | Description |
|---|---|
tuple[str, str | None, int | None, tuple[float | str, ...] | None, int | None, int | None] | Tuple of (focal_var, contrast_type, contrast_degree, focal_at_values, focal_at_range, focal_at_quantile). |
Note:
This wrapper does not return contrast_ref. Use
parse_explore_formula() for the full result including
contrast_ref.
explore_parser¶
Explore formula recursive descent parser.
Parses token streams from :class:ExploreScanner into
:class:ExploreFormulaSpec containers.
Grammar::
explore_formula := lhs [ TILDE rhs ]
lhs := contrast_fn_call | contrast_expr | focal_term
contrast_fn_call := fn_name LEFT_PAREN focal_term { COMMA arg } RIGHT_PAREN
fn_name := 'pairwise' | 'sequential' | 'poly'
| 'treatment' | 'dummy'
| 'sum' | 'deviation'
| 'helmert'
arg := kwarg | positional_arg
kwarg := IDENTIFIER EQUAL value
positional_arg := value
contrast_expr := IDENTIFIER LEFT_BRACKET contrast_list RIGHT_BRACKET
contrast_list := contrast_item { COMMA contrast_item }
contrast_item := contrast_operand MINUS contrast_operand
contrast_operand := STAR
| LEFT_PAREN level_name { PLUS level_name } RIGHT_PAREN
| level_name
level_name := IDENTIFIER | STRING | NUMBER
focal_term := IDENTIFIER [ AT at_spec ]
at_spec := value
| LEFT_BRACKET value_list RIGHT_BRACKET
| [COLON] range_or_quantile
range_or_quantile := IDENTIFIER LEFT_PAREN NUMBER RIGHT_PAREN
(where IDENTIFIER is 'range' or 'quantile')
value_list := value { COMMA value }
value := NUMBER | STRING | IDENTIFIER
rhs := condition { PLUS condition }
condition := IDENTIFIER LEFT_BRACKET contrast_list RIGHT_BRACKET
| IDENTIFIER [ AT at_spec ]
Classes:
| Name | Description |
|---|---|
ExploreParser | Recursive descent parser for explore formula syntax. |
Attributes¶
Classes¶
ExploreParser¶
ExploreParser(tokens: list[Token], formula: str) -> None
Recursive descent parser for explore formula syntax.
Consumes tokens produced by :class:ExploreScanner and builds an
:class:ExploreFormulaSpec container.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
tokens | list[Token] | List of tokens from the scanner. | required |
formula | str | Original formula string (for error messages). | required |
Functions:
| Name | Description |
|---|---|
parse | Parse explore formula tokens into ExploreFormulaSpec container. |
Functions¶
parse¶
parse() -> ExploreFormulaSpec
Parse explore formula tokens into ExploreFormulaSpec container.
Returns:
| Type | Description |
|---|---|
ExploreFormulaSpec | Parsed ExploreFormulaSpec. |
Functions¶
explore_scanner¶
Explore formula scanner/tokenizer.
Extends the formula scanner with @ token support for explore-specific
at-spec syntax (e.g., Days@50, Income@range(5)).
Classes:
| Name | Description |
|---|---|
ExploreScanner | Scanner for explore formulas. |
Classes¶
ExploreScanner¶
Bases: Scanner
Scanner for explore formulas.
Extends the base formula scanner to:
- Recognize @ as an AT token.
- Skip intercept insertion (explore formulas have no implicit intercept).
- Skip the multiple-tilde check (in explore formulas, ~ separates LHS/RHS, not response/predictors).
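A regex-based tokenizer gives a rough feel for what the scanner produces. This is a standalone sketch, not the `ExploreScanner` implementation: the `TOKEN_RE` pattern and `tokenize` function are hypothetical, token names are borrowed from the parser grammar, and quoted strings, backquoted names, and the base scanner's other tokens are left out.

```python
import re

# One named group per token kind; SKIP swallows whitespace.
TOKEN_RE = re.compile(r"""
    (?P<NUMBER>\d+(?:\.\d+)?)
  | (?P<IDENTIFIER>[A-Za-z_][A-Za-z_0-9]*)
  | (?P<AT>@)
  | (?P<TILDE>~)
  | (?P<PLUS>\+)
  | (?P<MINUS>-)
  | (?P<COMMA>,)
  | (?P<COLON>:)
  | (?P<LEFT_PAREN>\()
  | (?P<RIGHT_PAREN>\))
  | (?P<LEFT_BRACKET>\[)
  | (?P<RIGHT_BRACKET>\])
  | (?P<SKIP>\s+)
""", re.VERBOSE)


def tokenize(formula):
    """Return (kind, text) pairs; no intercept token and no tilde-count check."""
    tokens = []
    pos = 0
    while pos < len(formula):
        m = TOKEN_RE.match(formula, pos)
        if m is None:
            raise ValueError(f"unexpected character {formula[pos]!r} at position {pos}")
        if m.lastgroup != "SKIP":
            tokens.append((m.lastgroup, m.group()))
        pos = m.end()
    return tokens


tokens = tokenize("pairwise(treatment) ~ age@50")
```

Reporting the character position on failure mirrors the kind of information `ExploreFormulaError` is documented to carry.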
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
code | str | The explore formula string to scan. | required |
Functions:
| Name | Description |
|---|---|
add_token | |
advance | |
at_end | |
backquote | |
char | |
floatnum | |
identifier | |
match | |
number | |
peek | |
peek_next | |
scan | Scan explore formula string. |
scan_token | Scan a single token, adding @ support. |
Attributes:
| Name | Type | Description |
|---|---|---|
code | ||
current | ||
start | ||
tokens | list[Token] |
Attributes¶
code¶
code = code
current¶
current = 0
start¶
start = 0
tokens¶
tokens: list[Token] = []
Functions¶
add_token¶
add_token(kind: str, literal: object = None) -> None
advance¶
advance() -> str
at_end¶
at_end() -> bool
backquote¶
backquote() -> None
char¶
char() -> None
floatnum¶
floatnum() -> None
identifier¶
identifier() -> None
match¶
match(expected: str) -> bool
number¶
number() -> None
peek¶
peek() -> str
peek_next¶
peek_next() -> str
scan¶
scan(add_intercept: bool = False) -> list[Token]
Scan explore formula string.
Overrides base to skip intercept insertion and skip the
multiple-tilde validation (explore formulas use ~ differently).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
add_intercept | bool | Ignored; always False for explore formulas. | False |
Returns:
| Type | Description |
|---|---|
list[Token] | A list of Token objects. |
scan_token¶
scan_token() -> None
Scan a single token, adding @ support.
Functions¶
factors¶
Factor level extraction and handling utilities.
Functions:
| Name | Description |
|---|---|
detect_factor_levels_from_data | Infer factor levels from unique values in data column. |
get_factor_levels | Extract factor levels for a variable from bundle or data. |
Classes¶
Functions¶
detect_factor_levels_from_data¶
detect_factor_levels_from_data(data: 'pl.DataFrame', var: str) -> list[str]
Infer factor levels from unique values in data column.
Extracts unique values and sorts them for consistent ordering. Values are converted to strings for uniform handling.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
data | ‘pl.DataFrame’ | DataFrame containing the variable. | required |
var | str | Column name to extract levels from. | required |
Returns:
| Type | Description |
|---|---|
list[str] | Sorted list of unique values as strings. |
Examples:
>>> import polars as pl
>>> df = pl.DataFrame({"group": ["B", "A", "B", "C", "A"]})
>>> detect_factor_levels_from_data(df, "group")
['A', 'B', 'C']
get_factor_levels¶
get_factor_levels(bundle: 'DataBundle', var: str, *, fallback_data: 'pl.DataFrame | None' = None, allow_inference: bool = True) -> list[str]
Extract factor levels for a variable from bundle or data.
Priority order:
bundle.factor_levels[var] if present
Infer from fallback_data[var].unique() if provided and allow_inference=True
Raise ValueError if not found and no fallback
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
bundle | ‘DataBundle’ | DataBundle with factor_levels metadata. | required |
var | str | Variable name to get levels for. | required |
fallback_data | ‘pl.DataFrame | None’ | Optional DataFrame to infer levels from if not in bundle. | None |
allow_inference | bool | If True, allow inferring levels from data. If False, require levels to be in bundle.factor_levels. | True |
Returns:
| Type | Description |
|---|---|
list[str] | List of level values as strings, sorted. |
Examples:
Get levels from bundle metadata::
levels = get_factor_levels(bundle, "treatment")
# ["control", "drug_a", "drug_b"]
Get levels with data fallback::
levels = get_factor_levels(bundle, "group", fallback_data=df)
# Inferred from df["group"].unique()
grid¶
Reference grid construction for marginal effects.
Functions:
| Name | Description |
|---|---|
build_reference_grid | Construct reference grid for marginal effects evaluation. |
Classes¶
Functions¶
build_reference_grid¶
build_reference_grid(bundle: DataBundle, focal_vars: list[str], *, at: dict[str, Any] | None = None, covariate_means: dict[str, float] | None = None) -> pl.DataFrame
Construct reference grid for marginal effects evaluation.
Creates a grid of covariate values at which to evaluate predictions. This function transforms the parsed explore formula’s conditions into a concrete evaluation grid.
The grid construction follows emmeans conventions:
Focal variables: vary across their unique levels/values from data
Non-focal continuous: set to their overall mean from the data
Non-focal categorical: create rows for all levels (equal weight averaging)
Conditioned variables (@ spec): use the specified values
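The conventions above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the library's implementation: `sketch_reference_grid` and its `factor_levels` argument are hypothetical names, and the sketch assumes conditioned values are attached to every row as constant columns.

```python
from itertools import product

def sketch_reference_grid(factor_levels, focal_vars, at=None):
    """Cartesian product of focal levels, with conditioned values held fixed."""
    at = at or {}
    level_lists = [factor_levels[v] for v in focal_vars]
    rows = []
    for combo in product(*level_lists):
        row = dict(zip(focal_vars, combo))
        row.update(at)  # conditioned variables take their specified value
        rows.append(row)
    return rows

# Hypothetical levels for two focal factors, conditioning on age = 50
grid = sketch_reference_grid(
    {"treatment": ["A", "B", "C"], "sex": ["F", "M"]},
    ["treatment", "sex"],
    at={"age": 50.0},
)
```

The row count is the product of the focal level counts (3 x 2 here); conditioning adds no rows.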
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
bundle | DataBundle | DataBundle with model data and metadata. Used to extract: - factor_levels: dict mapping categorical vars to their levels - X_names: column names for identifying variable types | required |
focal_vars | list[str] | Variables to vary across their range/levels. These become the rows of the output grid. | required |
at | dict[str, Any] | None | Fixed values for conditioning variables. Keys are variable names, values are the conditioning value(s). This dict is typically built from the Condition objects in the parsed explore formula. | None |
covariate_means | dict[str, float] | None | Pre-computed means for continuous covariates. If not provided, covariates will be omitted from the grid (caller must handle them separately or pass via at). | None |
Returns:
| Type | Description |
|---|---|
DataFrame | Polars DataFrame with one row per grid point. Columns include all focal variables and all conditioned variables (if any). The number of rows is the Cartesian product of focal variable levels. |
Integration with explore():
Called from model._compute_marginal_effects() after parsing::
# From parsed explore formula "treatment ~ age@50"
parsed = ExploreFormulaSpec(
focal_var="treatment",
conditions=[Condition(var="age", at_values=(50,))]
)
# Convert conditions to at dict
at_dict = {"age": 50.0}
# Build grid
grid = build_reference_grid(
bundle=self._bundle,
focal_vars=["treatment"],
at=at_dict,
)
# grid: pl.DataFrame with columns ["treatment"] and rows ["A", "B", "C"]
# age is NOT in the grid (it's held fixed at 50)
Examples:
Grid for categorical focal::
grid = build_reference_grid(bundle, ["treatment"])
# Returns grid with one row per treatment level
Grid with conditioning::
grid = build_reference_grid(bundle, ["treatment"], at={"age": 50})
# Returns grid at age=50
Multiple focal variables (interaction EMMs)::
grid = build_reference_grid(bundle, ["treatment", "sex"])
# Returns Cartesian product: treatment x sex levels
inference¶
Inference for marginal effects using the delta method.
This module provides inference computation for MeeState results. Uses the delta method: SE = sqrt(diag(L @ vcov @ L'))
compute_mee_inference: Add inference to MeeState via delta method
compute_mee_se: Compute SEs for MEE (means or slopes)
compute_mee_inference_fallback: Fallback inference without L_matrix
Functions:
| Name | Description |
|---|---|
compute_mee_inference | Compute delta method inference for marginal effects. |
compute_mee_inference_fallback | Compute inference for MEE without L_matrix (fallback path). |
compute_mee_se | Compute standard errors for MEE estimates (means or slopes). |
Classes¶
Functions¶
compute_mee_inference¶
compute_mee_inference(mee: MeeState, vcov: np.ndarray, df_resid: float | np.ndarray | None, conf_level: float = 0.95, null: float = 0.0, alternative: str = 'two-sided') -> MeeState
Compute delta method inference for marginal effects.
SE = sqrt(diag(L @ vcov @ L'))
where L is the design matrix (X_ref for EMMs, selector for slopes).
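A minimal NumPy sketch of the delta-method formula above. `delta_method_se` is a hypothetical helper, and the `vcov` and `X_ref` values are made up for illustration.

```python
import numpy as np

def delta_method_se(L, vcov):
    """Return sqrt(diag(L @ vcov @ L.T)) without forming the full product."""
    L = np.atleast_2d(np.asarray(L, dtype=float))
    vcov = np.asarray(vcov, dtype=float)
    # Row-wise quadratic form: sum over j, k of L[i, j] * vcov[j, k] * L[i, k]
    return np.sqrt(np.einsum("ij,jk,ik->i", L, vcov, L))

# Two-coefficient model (intercept, group indicator); values are illustrative
vcov = np.array([[0.04, 0.01], [0.01, 0.09]])
X_ref = np.array([[1.0, 0.0], [1.0, 1.0]])  # one row per EMM
se = delta_method_se(X_ref, vcov)
```

The second row's variance combines both coefficient variances plus twice their covariance, which is why the off-diagonal entries of `vcov` matter for EMMs.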
Applies multiplicity adjustment for contrast CIs to match R’s emmeans:
Tukey HSD for pairwise contrasts.
Multivariate-t (MVT) for sequential, treatment, sum, helmert contrasts.
No adjustment for polynomial contrasts or non-contrast MEEs.
For response-scale EMMs (effect_scale="response"), CIs are computed on the
link scale then back-transformed via the inverse link function,
matching R’s emmeans behavior.
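The back-transformation step can be sketched as follows. This assumes a logit link and a z-based interval (the `df_resid=None` case); `response_scale_ci` is a hypothetical helper, not part of the module.

```python
import math
from statistics import NormalDist

def response_scale_ci(eta, se, conf_level=0.95):
    """Build the CI on the link scale, then map each bound through the inverse link."""
    z = NormalDist().inv_cdf(0.5 + conf_level / 2)

    def inv_logit(x):
        return 1.0 / (1.0 + math.exp(-x))

    lower_link = eta - z * se
    upper_link = eta + z * se
    return inv_logit(eta), inv_logit(lower_link), inv_logit(upper_link)

# Link-scale estimate 0.0 with standard error 0.5 (illustrative values)
est, lo, hi = response_scale_ci(eta=0.0, se=0.5)
```

Because the inverse link is monotone, the transformed bounds stay ordered and inside (0, 1), which a symmetric interval built directly on the response scale would not guarantee.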
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mee | MeeState | MeeState with L_matrix from compute_emm or compute_slopes. | required |
vcov | ndarray | Variance-covariance matrix of coefficients, shape (p, p). | required |
df_resid | float | ndarray | None | Degrees of freedom for t-distribution. Can be: - None: z-distribution (asymptotic normality). - scalar float: single df for all estimates (residual df). - np.ndarray: per-estimate Satterthwaite df (mixed models). | required |
conf_level | float | Confidence level for intervals (default 0.95). | 0.95 |
null | float | Null hypothesis value (default 0.0). | 0.0 |
alternative | str | Alternative hypothesis direction (default “two-sided”). | ‘two-sided’ |
Returns:
| Type | Description |
|---|---|
MeeState | MeeState augmented with inference fields (se, df, statistic, p_value, ci_lower, ci_upper, conf_level). |
Examples:
Compute inference for EMMs::
from marginal import compute_emm
from inference import compute_mee_inference
mee = compute_emm(bundle, fit, "treatment", "treatment")
mee_with_inf = compute_mee_inference(mee, fit.vcov, fit.df_resid)
# mee_with_inf.se: standard errors
# mee_with_inf.ci_lower, mee_with_inf.ci_upper: confidence bounds
Note: For EMMs, L_matrix is the reference design matrix X_ref. For slopes, L_matrix is a selector row that picks the coefficient. For contrasts, L_matrix is the contrast matrix applied to EMMs.
compute_mee_inference_fallback¶
compute_mee_inference_fallback(mee: 'MeeState', bundle: 'DataBundle', fit: 'FitState', data: 'pl.DataFrame', df_resid: float | None, conf_level: float = 0.95, null: float = 0.0, alternative: str = 'two-sided') -> 'MeeState'
Compute inference for MEE without L_matrix (fallback path).
Used when L_matrix is not available (legacy MEE). Computes simplified
SEs via compute_mee_se, then builds test statistics, p-values,
and confidence intervals.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mee | ‘MeeState’ | MeeState without L_matrix. | required |
bundle | ‘DataBundle’ | Data bundle with X_names. | required |
fit | ‘FitState’ | Fit state with vcov, sigma, etc. | required |
data | ‘pl.DataFrame’ | Original data frame for group counting. | required |
df_resid | float | None | Residual degrees of freedom (None for z-distribution). | required |
conf_level | float | Confidence level for intervals. | 0.95 |
null | float | Null hypothesis value (default 0.0). | 0.0 |
alternative | str | Alternative hypothesis direction (default “two-sided”). | ‘two-sided’ |
Returns:
| Type | Description |
|---|---|
‘MeeState’ | MeeState augmented with inference fields. |
compute_mee_se¶
compute_mee_se(mee: 'MeeState', bundle: 'DataBundle', fit: 'FitState', data: 'pl.DataFrame') -> np.ndarray
Compute standard errors for MEE estimates (means or slopes).
When an L_matrix is available, uses the delta method:
SE = sqrt(diag(L @ vcov @ L')). This correctly handles
multi-row L_matrices (e.g. per-group conditional slopes).
Otherwise falls back to simplified type-based dispatch:
"means": SE = sqrt(MSE / n_per_group) for each group level
"slopes": SE = sqrt(vcov[idx, idx]) from coefficient vcov diagonal
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mee | ‘MeeState’ | MeeState with type (“means” or “slopes”) and focal_var. | required |
bundle | ‘DataBundle’ | Data bundle with X_names. | required |
fit | ‘FitState’ | Fit state with sigma, dispersion, residuals, vcov, df_resid. | required |
data | ‘pl.DataFrame’ | Original data frame for group counting. | required |
Returns:
| Type | Description |
|---|---|
ndarray | Array of standard errors, one per estimate. |
joint_tests¶
ANOVA-style joint hypothesis tests for model terms.
Functions:
| Name | Description |
|---|---|
compute_joint_test | Compute joint hypothesis tests for model terms. |
Classes¶
Functions¶
compute_joint_test¶
compute_joint_test(fit: FitState, bundle: DataBundle, spec: ModelSpec, terms: list[str] | None = None, *, errors: str = 'iid', data: pl.DataFrame | None = None) -> JointTestState
Compute joint hypothesis tests for model terms.
Tests each model term (continuous variables, factors, and interactions) using direct coefficient tests. This matches R’s emmeans::joint_tests() behavior.
For each term:
Continuous variables: F-test on the single coefficient (df1=1)
Categorical factors: Joint F-test on all indicator coefficients
Interactions: Joint F-test on all interaction coefficients
The test type is determined by the model family:
Gaussian (LM/LMER): F-test with df_resid
Non-Gaussian (GLM/GLMER): Chi-square test (asymptotic)
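A direct coefficient test of this kind is commonly written as a Wald quadratic form. The sketch below is an assumption about the general technique, not the module's code: `wald_statistic` is a hypothetical helper, and the p-value lookup (F or chi-square distribution) is left out.

```python
import numpy as np

def wald_statistic(beta, vcov, idx):
    """Joint Wald test of H0: beta[idx] = 0.

    Returns (chi2, F, df1), where chi2 = b' V^-1 b and F = chi2 / df1.
    """
    b = np.asarray(beta, dtype=float)[idx]
    V = np.asarray(vcov, dtype=float)[np.ix_(idx, idx)]
    chi2 = float(b @ np.linalg.solve(V, b))
    df1 = len(idx)
    return chi2, chi2 / df1, df1

# Illustrative coefficients: intercept plus two indicator columns for one factor
beta = np.array([1.0, 0.5, -0.2])
vcov = np.diag([0.04, 0.01, 0.04])
chi2, F, df1 = wald_statistic(beta, vcov, [1, 2])
```

For a single coefficient (df1 = 1) the statistic reduces to the squared t-ratio, consistent with the continuous-variable case listed above.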
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
fit | FitState | FitState with fitted coefficients and vcov. | required |
bundle | DataBundle | DataBundle with X_names for term structure. | required |
spec | ModelSpec | ModelSpec for model type detection. | required |
terms | list[str] | None | Specific terms to test, or None for all terms. | None |
errors | str | Error structure assumption. Options: - "iid" (default): Standard errors from fit.vcov. - "HC0"-"HC3", "hetero": Sandwich SEs. - "unequal_var": Welch ANOVA — sandwich SEs from per-cell variances with per-term Satterthwaite df2. For single-df terms this matches the Welch t-test; for multi-df terms (k-level factors), per-row dfs are combined via the Fai-Cornelius (1996) approximation. | ‘iid’ |
data | DataFrame | None | Original data frame, required for errors='unequal_var'. | None |
Returns:
| Type | Description |
|---|---|
JointTestState | JointTestState with test results for each term. |
Examples:
Basic usage::
from marginal import compute_joint_test
state = compute_joint_test(fit, bundle, spec)
# state.terms: ("x", "group", "x:group")
# state.statistic: F or chi2 values per term
# state.p_value: p-values per term
Welch ANOVA::
state = compute_joint_test(fit, bundle, spec, errors="unequal_var", data=df)
# Per-term Satterthwaite df2 in state.df2
Test specific terms only::
state = compute_joint_test(fit, bundle, spec, terms=["group"])
Note: Intercept is never tested. Random effect grouping factors are excluded for mixed models.
matrices¶
Contrast matrix builders for EMM comparisons.
Provides functions to build hypothesis contrast matrices used in marginal effects computation. These matrices transform EMM vectors into contrast estimates (e.g., pairwise differences, sequential differences, polynomial trends).
Distinct from design.coding which builds design matrix coding matrices
for categorical variable encoding during model fitting.
Created by: absorbed from maths/inference/contrasts.py (zero consumers
outside marginal/) + poly_contrasts from contrasts.py.
Consumed by: marginal/contrasts.py (apply_contrasts, apply_contrasts_grouped).
Functions:
| Name | Description |
|---|---|
build_all_pairwise_matrix | Build all pairwise contrasts between EMM levels. |
build_contrast_matrix | Build a contrast matrix based on contrast type. |
build_helmert_matrix | Build Helmert contrasts (each level vs mean of previous levels). |
build_pairwise_matrix | Build (n-1) linearly independent pairwise contrasts. |
build_poly_matrix | Build orthogonal polynomial contrast matrix for EMMs. |
build_sequential_matrix | Build sequential (successive differences) contrasts. |
build_sum_to_zero_matrix | Build sum-to-zero contrasts (deviation coding). |
build_treatment_matrix | Build treatment (Dunnett-style) contrasts against a reference level. |
compose_contrast_matrix | Compose contrast matrix with prediction matrix. |
get_contrast_labels | Generate human-readable labels for contrasts. |
Functions¶
build_all_pairwise_matrix¶
build_all_pairwise_matrix(n_levels: int) -> np.ndarray
Build all pairwise contrasts between EMM levels.
Creates C(n, 2) = n*(n-1)/2 contrasts for all unique pairs. Unlike build_pairwise_matrix(), this includes all pairs, not just comparisons to the reference level.
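One way to build such a matrix is to enumerate index pairs with `itertools.combinations`. `sketch_all_pairwise` is a hypothetical stand-in written for illustration; the library's own construction may differ.

```python
from itertools import combinations

import numpy as np

def sketch_all_pairwise(n_levels):
    """One row per unique pair (i, j) with i < j, encoding level_j - level_i."""
    pairs = list(combinations(range(n_levels), 2))
    C = np.zeros((len(pairs), n_levels))
    for row, (i, j) in enumerate(pairs):
        C[row, i] = -1.0
        C[row, j] = 1.0
    return C

C3 = sketch_all_pairwise(3)
```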
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
n_levels | int | Number of EMM levels (factor levels). | required |
Returns:
| Type | Description |
|---|---|
ndarray | Contrast matrix of shape (n*(n-1)/2, n_levels). Each row compares two levels (level_j - level_i where j > i). |
Examples:
>>> C = build_all_pairwise_matrix(3)
>>> C
array([[-1.,  1.,  0.],
       [-1.,  0.,  1.],
       [ 0., -1.,  1.]])
build_contrast_matrix¶
build_contrast_matrix(contrast_type: str | dict, levels: list, normalize: bool = False) -> np.ndarray
Build a contrast matrix based on contrast type.
Dispatcher function that builds the appropriate contrast matrix based on the contrast type specification.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
contrast_type | str | dict | Type of contrast: - “pairwise”: All pairwise comparisons (k(k-1)/2 contrasts) - “trt.vs.ctrl” or “treatment”: Compare each level to first level - “sequential”: Successive differences (level[i+1] - level[i]) - “sum”: Each level vs grand mean (deviation coding) - “helmert”: Each level vs mean of previous levels - dict: Custom contrasts with names as keys and weights as values | required |
levels | list | List of factor levels. | required |
normalize | bool | If True, normalize custom contrasts to sum to 1/-1. | False |
Returns:
| Type | Description |
|---|---|
ndarray | Contrast matrix of shape (n_contrasts, n_levels). |
Examples:
>>> build_contrast_matrix("pairwise", ["A", "B", "C"])
array([[-1.,  1.,  0.],
       [-1.,  0.,  1.],
       [ 0., -1.,  1.]])
build_helmert_matrix¶
build_helmert_matrix(n_levels: int) -> np.ndarray
Build Helmert contrasts (each level vs mean of previous levels).
Creates n_levels - 1 contrasts where level k is compared to the mean of levels 0, 1, ..., k-1.
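The rule "level k minus the mean of levels 0..k-1" translates directly into row weights. `sketch_helmert` is a hypothetical illustration of that rule, not the library function.

```python
import numpy as np

def sketch_helmert(n_levels):
    """Row k-1 puts weight 1 on level k and -1/k on each of levels 0..k-1."""
    C = np.zeros((n_levels - 1, n_levels))
    for k in range(1, n_levels):
        C[k - 1, :k] = -1.0 / k
        C[k - 1, k] = 1.0
    return C

H3 = sketch_helmert(3)
```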
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
n_levels | int | Number of EMM levels (factor levels). | required |
Returns:
| Type | Description |
|---|---|
ndarray | Contrast matrix of shape (n_levels - 1, n_levels). |
Examples:
>>> C = build_helmert_matrix(3)
>>> C
array([[-1. ,  1. ,  0. ],
       [-0.5, -0.5,  1. ]])
>>> C = build_helmert_matrix(4)
>>> C
array([[-1.        ,  1.        ,  0.        ,  0.        ],
       [-0.5       , -0.5       ,  1.        ,  0.        ],
       [-0.33333333, -0.33333333, -0.33333333,  1.        ]])
build_pairwise_matrix¶
build_pairwise_matrix(n_levels: int) -> np.ndarray
Build (n-1) linearly independent pairwise contrasts.
Creates “treatment-style” contrasts comparing each level to the first (reference) level. This produces n_levels - 1 contrasts that span the space of all pairwise differences.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
n_levels | int | Number of EMM levels (factor levels). | required |
Returns:
| Type | Description |
|---|---|
ndarray | Contrast matrix of shape (n_levels - 1, n_levels). Row i compares level i+1 to level 0. |
Examples:
>>> C = build_pairwise_matrix(3)
>>> C
array([[-1.,  1.,  0.],
       [-1.,  0.,  1.]])
build_poly_matrix¶
build_poly_matrix(n_levels: int, degree: int | None = None) -> np.ndarray
Build orthogonal polynomial contrast matrix for EMMs.
Creates orthogonal polynomial contrasts for ordered factors. Tests for linear, quadratic, cubic (etc.) trends across the factor levels.
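One common way to obtain orthonormal polynomial contrasts is a QR decomposition of a centered Vandermonde matrix. The sketch below assumes equally spaced levels and is not necessarily how `build_poly_matrix` is implemented; `sketch_poly` is a hypothetical name, and row signs may differ from other implementations.

```python
import numpy as np

def sketch_poly(n_levels, degree=None):
    """Orthonormal polynomial contrasts for equally spaced levels via QR."""
    degree = n_levels - 1 if degree is None else degree
    x = np.arange(n_levels, dtype=float)
    x -= x.mean()
    V = np.vander(x, degree + 1, increasing=True)  # columns: 1, x, x^2, ...
    Q, R = np.linalg.qr(V)
    Q = Q * np.sign(np.diag(R))  # make each column's leading coefficient positive
    return Q[:, 1:].T  # drop the constant column; rows = linear, quadratic, ...

P = sketch_poly(5)
P2 = sketch_poly(5, degree=2)
```

Each row is orthogonal to the constant vector, so it sums to zero and qualifies as a contrast; orthonormality between rows keeps the trend tests independent of one another under equal variances.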
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
n_levels | int | Number of factor levels. Must be >= 2. | required |
degree | int | None | Maximum polynomial degree. If None, uses n_levels - 1. | None |
Returns:
| Type | Description |
|---|---|
ndarray | Contrast matrix of shape (degree, n_levels). Rows are orthonormal polynomial basis vectors. |
Examples:
>>> build_poly_matrix(5) # 4x5 matrix
>>> build_poly_matrix(5, degree=2) # 2x5 matrix (linear + quadratic)
build_sequential_matrix¶
build_sequential_matrix(n_levels: int) -> np.ndarray
Build sequential (successive differences) contrasts.
Creates n_levels - 1 contrasts comparing each level to the previous one.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
n_levels | int | Number of EMM levels (factor levels). | required |
Returns:
| Type | Description |
|---|---|
ndarray | Contrast matrix of shape (n_levels - 1, n_levels). Row i compares level i+1 to level i. |
Examples:
>>> C = build_sequential_matrix(3)
>>> C
array([[-1.,  1.,  0.],
       [ 0., -1.,  1.]])
>>> C = build_sequential_matrix(4)
>>> C
array([[-1.,  1.,  0.,  0.],
       [ 0., -1.,  1.,  0.],
       [ 0.,  0., -1.,  1.]])
build_sum_to_zero_matrix¶
build_sum_to_zero_matrix(n_levels: int) -> np.ndarray
Build sum-to-zero contrasts (deviation coding).
Creates n_levels - 1 contrasts, each comparing one level to the grand mean.
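"Level minus grand mean" corresponds to rows of the centering matrix I - 1/n. The sketch below keeps the first n - 1 rows, which reproduces the documented 3-level example; `sketch_sum_to_zero` is a hypothetical name and the choice of which row to drop is an assumption based on that example.

```python
import numpy as np

def sketch_sum_to_zero(n_levels):
    """Rows of (I - 1/n): weight 1 - 1/n on the level itself, -1/n elsewhere."""
    centering = np.eye(n_levels) - 1.0 / n_levels
    return centering[: n_levels - 1]  # last row is redundant (rows sum to zero)

S3 = sketch_sum_to_zero(3)
```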
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
n_levels | int | Number of EMM levels. | required |
Returns:
| Type | Description |
|---|---|
ndarray | Contrast matrix of shape (n_levels - 1, n_levels). |
Examples:
>>> C = build_sum_to_zero_matrix(3)
>>> C
array([[ 0.667, -0.333, -0.333],
       [-0.333,  0.667, -0.333]])
build_treatment_matrix¶
build_treatment_matrix(n_levels: int, ref_idx: int = 0) -> np.ndarray
Build treatment (Dunnett-style) contrasts against a reference level.
Creates n_levels - 1 contrasts, each comparing one non-reference level to the specified reference level.
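A short sketch of the construction, with the reference column receiving -1 in every row. `sketch_treatment` is a hypothetical illustration rather than the library function.

```python
import numpy as np

def sketch_treatment(n_levels, ref_idx=0):
    """One row per non-reference level: +1 on that level, -1 on the reference."""
    others = [i for i in range(n_levels) if i != ref_idx]
    C = np.zeros((len(others), n_levels))
    for row, i in enumerate(others):
        C[row, i] = 1.0
        C[row, ref_idx] = -1.0
    return C

T_first = sketch_treatment(3, ref_idx=0)
T_last = sketch_treatment(3, ref_idx=2)
```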
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
n_levels | int | Number of EMM levels (factor levels). | required |
ref_idx | int | Index of the reference level (0-based). | 0 |
Returns:
| Type | Description |
|---|---|
ndarray | Contrast matrix of shape (n_levels - 1, n_levels). |
Examples:
>>> C = build_treatment_matrix(3, ref_idx=0)
>>> C
array([[-1.,  1.,  0.],
       [-1.,  0.,  1.]])
>>> C = build_treatment_matrix(3, ref_idx=2)
>>> C
array([[ 1.,  0., -1.],
       [ 0.,  1., -1.]])
compose_contrast_matrix¶
compose_contrast_matrix(C: np.ndarray, X_ref: np.ndarray) -> np.ndarray
Compose contrast matrix with prediction matrix.
The composed matrix L_emm = C @ X_ref satisfies:
L_emm @ beta = C @ (X_ref @ beta) = C @ EMMs
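The identity can be checked numerically. The design matrix, coefficients, and contrasts below are illustrative values, not taken from the library.

```python
import numpy as np

# Treatment-coded reference design for a 3-level factor (intercept + 2 indicators)
X_ref = np.array([[1.0, 0.0, 0.0],
                  [1.0, 1.0, 0.0],
                  [1.0, 0.0, 1.0]])
beta = np.array([10.0, 2.0, -1.0])
C = np.array([[-1.0, 1.0, 0.0],
              [-1.0, 0.0, 1.0]])  # each level vs the first

emms = X_ref @ beta          # EMMs on the link scale
L_emm = C @ X_ref            # composed matrix acting directly on beta
contrasts_direct = L_emm @ beta
contrasts_via_emms = C @ emms
```

Working with the composed matrix means contrast standard errors can reuse the same delta-method formula as EMMs, with L_emm in place of X_ref.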
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
C | ndarray | Contrast matrix of shape (n_contrasts, n_emms). | required |
X_ref | ndarray | Prediction matrix of shape (n_emms, n_coef). | required |
Returns:
| Type | Description |
|---|---|
ndarray | Composed contrast matrix L_emm of shape (n_contrasts, n_coef). |
get_contrast_labels¶
get_contrast_labels(levels: list[str], contrast_type: str = 'pairwise') -> list[str]
Generate human-readable labels for contrasts.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
levels | list[str] | List of factor level names. | required |
contrast_type | str | Type of contrast (“pairwise”, “all_pairwise”, or “sequential”). | ‘pairwise’ |
Returns:
| Type | Description |
|---|---|
list[str] | List of contrast labels like “B - A”, “C - A”, etc. |
Examples:
>>> get_contrast_labels(["A", "B", "C"], "pairwise")
['B - A', 'C - A']
>>> get_contrast_labels(["A", "B", "C"], "all_pairwise")
['B - A', 'C - A', 'C - B']
resolve¶
Resolution helpers for marginal effects dispatch.
Translates raw user-facing formula elements into internal representations: variable name resolution through transforms, at-value forward transforms, condition classification, and focal at-range/at-quantile materialization.
Functions:
| Name | Description |
|---|---|
build_var_transform_map | Build a mapping from raw variable names to their transformed column info. |
resolve_all_conditions | Resolve formula conditions into typed buckets. |
resolve_at_overrides | Remap at-override keys through forward transforms; optionally transform values. |
resolve_conditional | Determine if conditional effects are requested and resolve grouping var. |
resolve_focal_at_spec | Resolve focal_at_range / focal_at_quantile into focal_at_values. |
resolve_focal_at_values | Forward-transform focal at-values through a formula transform. |
resolve_focal_var | Resolve a raw focal variable name to its transformed column name. |
Attributes:
| Name | Type | Description |
|---|---|---|
INVERTIBLE_TRANSFORMS |
Attributes¶
INVERTIBLE_TRANSFORMS¶
INVERTIBLE_TRANSFORMS = frozenset({'center', 'norm', 'zscore', 'scale'})
Classes¶
Functions¶
build_var_transform_map¶
build_var_transform_map(formula_spec: FormulaSpec | None) -> dict[str, tuple[str, StatefulTransform]]
Build a mapping from raw variable names to their transformed column info.
Inspects formula_spec.transform_state keys (e.g. "center(x)") and
extracts the inner variable name. Only includes invertible transforms
(center, norm, zscore, scale).
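The key inspection can be sketched with a regular expression. This assumes keys have the form `name(variable)` with a single bare variable inside; `sketch_var_transform_map` and the placeholder transform objects are hypothetical.

```python
import re

INVERTIBLE = frozenset({"center", "norm", "zscore", "scale"})
_KEY_PATTERN = re.compile(r"^(\w+)\((\w+)\)$")

def sketch_var_transform_map(transform_state):
    """Map raw variable name -> (transformed column name, transform object)."""
    mapping = {}
    for key, transform in transform_state.items():
        match = _KEY_PATTERN.match(key)
        if match and match.group(1) in INVERTIBLE:
            mapping[match.group(2)] = (key, transform)
    return mapping

# Placeholder strings stand in for real transform objects
state = {"center(x)": "T_center", "log(y)": "T_log", "zscore(age)": "T_z"}
var_map = sketch_var_transform_map(state)
```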
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
formula_spec | FormulaSpec | None | FormulaSpec with transform_state and transforms dicts. | required |
Returns:
| Type | Description |
|---|---|
dict[str, tuple[str, StatefulTransform]] | Dict mapping raw variable name to (transformed_col_name, transform_obj). Empty dict when no transforms are present or formula_spec is None. |
resolve_all_conditions¶
resolve_all_conditions(parsed: ExploreFormulaSpec, bundle: DataBundle, data: pl.DataFrame, fspec: FormulaSpec | None, inverse_transforms: bool) -> ResolvedConditions
Resolve formula conditions into typed buckets.
Classifies RHS conditions from the parsed formula, then applies forward transforms to numeric overrides.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
parsed | ExploreFormulaSpec | Parsed explore formula. | required |
bundle | DataBundle | DataBundle with factor_levels and design matrix. | required |
data | DataFrame | Original data DataFrame. | required |
fspec | FormulaSpec | None | FormulaSpec with learned transforms (for forward resolution). | required |
inverse_transforms | bool | Whether to apply forward transforms. | required |
Returns:
| Type | Description |
|---|---|
ResolvedConditions | ResolvedConditions with all conditions classified. |
resolve_at_overrides¶
resolve_at_overrides(at_overrides: dict[str, float] | None, formula_spec: FormulaSpec | None, inverse_transforms: bool) -> dict[str, float] | None
Remap at-override keys through forward transforms; optionally transform values.
Always remaps raw variable names to their transformed column names
(e.g. "x" → "center(x)"). When inverse_transforms=True,
also forward-transforms the values (e.g. 50 → 50 - mean).
When False, values are left as-is on the user-specified scale.
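The two behaviours (key remapping always, value transformation only on request) can be sketched as below. The `transform_map` argument, pairing each raw name with a column name and a forward function, is a hypothetical simplification of the FormulaSpec, and the learned mean of 42 is made up.

```python
def sketch_resolve_at_overrides(at_overrides, transform_map, inverse_transforms):
    """Rename keys to transformed columns; transform values only when requested."""
    if at_overrides is None:
        return None
    resolved = {}
    for name, value in at_overrides.items():
        if name in transform_map:
            column, forward = transform_map[name]
            resolved[column] = forward(value) if inverse_transforms else value
        else:
            resolved[name] = value  # untransformed variables pass through
    return resolved

# Hypothetical learned centering: mean of x is 42
transform_map = {"x": ("center(x)", lambda v: v - 42.0)}
raw_scale = sketch_resolve_at_overrides({"x": 50.0, "z": 1.0}, transform_map, True)
as_given = sketch_resolve_at_overrides({"x": 50.0, "z": 1.0}, transform_map, False)
```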
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
at_overrides | dict[str, float] | None | Original at-override dict (may be None). | required |
formula_spec | FormulaSpec | None | FormulaSpec with learned transforms. | required |
inverse_transforms | bool | Whether to apply forward transforms to values. | required |
Returns:
| Type | Description |
|---|---|
dict[str, float] | None | Resolved at_overrides dict, or None if input was None. |
resolve_conditional¶
resolve_conditional(parsed: ExploreFormulaSpec, bundle: DataBundle, varying: str, varying_offsets: VaryingState | None) -> tuple[bool, str | None]
Determine if conditional effects are requested and resolve grouping var.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
parsed | ExploreFormulaSpec | Parsed explore formula. | required |
bundle | DataBundle | DataBundle with RE metadata. | required |
varying | str | “exclude” or “include”. | required |
varying_offsets | VaryingState | None | VaryingState (None for non-mixed models). | required |
Returns:
| Type | Description |
|---|---|
tuple[bool, str | None] | Tuple of (is_conditional, grouping_var or None). |
resolve_focal_at_spec¶
resolve_focal_at_spec(parsed: ExploreFormulaSpec, data: pl.DataFrame, bundle: DataBundle) -> ExploreFormulaSpec
Resolve focal_at_range / focal_at_quantile into focal_at_values.
When the parsed formula contains focal_at_range=n or
focal_at_quantile=n, this function computes the concrete values from
the data and returns an evolved copy with focal_at_values populated.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
parsed | ExploreFormulaSpec | Parsed explore formula (may have range/quantile fields set). | required |
data | DataFrame | Original data DataFrame for computing range/quantile values. | required |
bundle | DataBundle | DataBundle with design matrix metadata. | required |
Returns:
| Type | Description |
|---|---|
ExploreFormulaSpec | ExploreFormulaSpec with focal_at_values resolved (or unchanged if neither range nor quantile was set). |
resolve_focal_at_values¶
resolve_focal_at_values(focal_var: str, at_values: tuple[float, ...], formula_spec: FormulaSpec | None, inverse_transforms: bool) -> tuple[float, ...]
Forward-transform focal at-values through a formula transform.
When inverse_transforms=True and the focal variable has a learned
transform (e.g. center(x)), applies the forward transform to each
at-value so that raw-scale values become transformed-scale values.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
focal_var | str | Raw focal variable name (e.g. "Days"). | required |
at_values | tuple[float, ...] | User-specified at-values on the raw scale. | required |
formula_spec | FormulaSpec | None | FormulaSpec with learned transforms. | required |
inverse_transforms | bool | Whether to apply forward transforms. | required |
Returns:
| Type | Description |
|---|---|
tuple[float, ...] | Transformed at-values tuple (or original if no transform applies). |
resolve_focal_var¶
resolve_focal_var(focal_var: str, X_names: tuple[str, ...] | list[str], formula_spec: FormulaSpec | None, inverse_transforms: bool) -> str
Resolve a raw focal variable name to its transformed column name.
Always attempts resolution regardless of inverse_transforms so that
raw variable names (e.g. "x") are mapped to their design-matrix
column names (e.g. "center(x)").
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
focal_var | str | Raw variable name from the explore formula. | required |
X_names | tuple[str, ...] | list[str] | Design matrix column names. | required |
formula_spec | FormulaSpec | None | FormulaSpec with transform info. | required |
inverse_transforms | bool | Unused; kept for API compatibility. | required |
Returns:
| Type | Description |
|---|---|
str | The resolved column name (transformed or original). |
slopes¶
Marginal slopes computation.
This module provides marginal slope computation for continuous focal variables.
Three strategies are available:
Coefficient extraction (compute_slopes): fast path for Gaussian identity-link models without interactions involving the focal variable.
Finite differences (compute_slopes_finite_diff): general path for GLMs and/or models with interactions. Produces the Average Marginal Effect (AME) by building perturbed design matrices and averaging over a reference grid of factor-level combinations.
Conditional slopes (compute_conditional_slopes): per-group slopes incorporating BLUPs for mixed models.
compute_slopes: Coefficient-extraction slopes (Gaussian identity, no interactions)
compute_slopes_finite_diff: Finite-difference AME slopes
compute_conditional_slopes: Per-group conditional slopes with BLUPs
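The finite-difference strategy can be illustrated with a toy model. This is a sketch under simplifying assumptions: `sketch_fd_slope` is a hypothetical helper, `predict` stands in for rebuilding the design matrix and evaluating the fitted model, and the step size `h` is an arbitrary choice.

```python
import numpy as np

def sketch_fd_slope(predict, grid, col, h=1e-4):
    """Average centered finite difference of predict() along one grid column."""
    grid_hi = grid.copy()
    grid_lo = grid.copy()
    grid_hi[:, col] += h
    grid_lo[:, col] -= h
    per_row = (predict(grid_hi) - predict(grid_lo)) / (2.0 * h)
    return float(per_row.mean())  # average over the reference grid rows

def predict(grid):
    # Toy linear predictor with an interaction: 1 + 2*x + 0.5*x*g
    x, g = grid[:, 0], grid[:, 1]
    return 1.0 + 2.0 * x + 0.5 * x * g

# Reference grid: x held at 3.0, one row per level of a binary factor g
grid = np.array([[3.0, 0.0], [3.0, 1.0]])
ame = sketch_fd_slope(predict, grid, col=0)
```

With the interaction present, the per-row slopes are 2.0 and 2.5, so reading off the coefficient of x alone would miss the group-dependent part; averaging over the grid gives the AME.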
Functions:
| Name | Description |
|---|---|
compute_conditional_slopes | Compute per-group conditional slopes incorporating BLUPs. |
compute_slopes | Compute marginal slope for a continuous focal variable. |
compute_slopes_crossed | Compute crossed slopes over focal variable x condition grid. |
compute_slopes_finite_diff | Compute marginal slopes via centered finite differences. |
Classes¶
Functions¶
compute_conditional_slopes¶
compute_conditional_slopes(bundle: 'DataBundle', fit: 'FitState', focal_var: str, explore_formula: str, *, spec: object, varying_offsets: object, grouping_var: str, effect_scale: str = 'link') -> MeeState
Compute per-group conditional slopes incorporating BLUPs.
For each group, the conditional slope is:
slope_g = β_x + b_x_g (random slopes model)
slope_g = β_x (random intercepts only)
When effect_scale="response" and the link is non-identity, the response-scale slope is:
slope_response_g = slope_link_g × dμ/dη(η_g)
where η_g is the linear predictor at the reference point for group g.
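A small sketch of these formulas, assuming a logit link for the response-scale case (where dμ/dη = μ(1 - μ)). `sketch_conditional_slopes` and its dictionary inputs are hypothetical; the real function reads BLUPs from a VaryingState.

```python
import math

def sketch_conditional_slopes(beta_x, blup_slopes, eta_by_group=None):
    """Per-group slope: fixed slope plus group deviation, optionally on response scale."""
    slopes = {}
    for group, deviation in blup_slopes.items():
        slope = beta_x + deviation  # link-scale conditional slope
        if eta_by_group is not None:
            mu = 1.0 / (1.0 + math.exp(-eta_by_group[group]))
            slope *= mu * (1.0 - mu)  # chain rule through the inverse logit
        slopes[group] = slope
    return slopes

blups = {"g1": 0.1, "g2": -0.2}  # illustrative random-slope deviations
link_slopes = sketch_conditional_slopes(0.5, blups)
resp_slopes = sketch_conditional_slopes(0.5, blups, eta_by_group={"g1": 0.0, "g2": 0.0})
```

For a random-intercepts-only model every deviation is zero, so all groups share the fixed slope on the link scale; on the response scale they can still differ because η_g differs by group.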
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
bundle | ‘DataBundle’ | DataBundle with model data and metadata. | required |
fit | ‘FitState’ | FitState with fitted coefficients. | required |
focal_var | str | Name of the continuous variable to get slopes for. | required |
explore_formula | str | The explore formula string (for result metadata). | required |
spec | object | ModelSpec with link function info. | required |
varying_offsets | object | VaryingState with BLUPs per group. | required |
grouping_var | str | Name of the grouping variable. | required |
effect_scale | str | Scale of estimates: "link" or "response". | ‘link’ |
Returns:
| Type | Description |
|---|---|
MeeState | MeeState with per-group slope estimates. |
compute_slopes¶
compute_slopes(bundle: 'DataBundle', fit: 'FitState', focal_var: str, explore_formula: str, *, spec: object | None = None, effect_scale: str = 'link') -> MeeState

Compute marginal slope for a continuous focal variable.
For a linear model without interactions involving the focal variable, this extracts the variable’s coefficient. Models with such interactions require averaging marginal effects across a grid; that case is handled by compute_slopes_finite_diff rather than by this function.
Marginal slopes answer: “By how much does the predicted outcome change for a one-unit increase in this variable?”
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
bundle | ‘DataBundle’ | DataBundle with model data. Used to read X_names, which locates the variable’s coefficient index. | required |
fit | ‘FitState’ | FitState with fitted coefficients. | required |
focal_var | str | Name of the continuous variable to get slope for. | required |
explore_formula | str | The explore formula string (for result metadata). | required |
spec | object \| None | ModelSpec with link/family info (for effect_scale="response"). | None |
effect_scale | str | Scale of estimates: "link" or "response". | ‘link’ |
Returns:
| Type | Description |
|---|---|
MeeState | MeeState with marginal slope estimate. |
Examples:
Compute marginal effect of age::

    from marginal import compute_slopes

    mee = compute_slopes(bundle, fit, "age", "age")
    # For y ~ age: returns MeeState with estimate = [coef_age]

Note:
This is the fast path for Gaussian identity-link models without
interactions. For models with interactions or GLMs (non-identity
link), the dispatcher routes to compute_slopes_finite_diff
which handles both cases via centered finite differences.
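
The fast path amounts to a name lookup in the design-matrix columns. The sketch below is hypothetical: extract_slope is not a library function, and the assumption that interaction columns are named with a ":" separator (as in "age:group") is made only for illustration.

```python
def extract_slope(x_names, coefs, focal_var):
    """Return the coefficient whose design-matrix column is focal_var.

    Assumption: interaction columns are named "a:b"; if the focal variable
    appears in one, coefficient extraction is not valid and we refuse.
    """
    interactions = [
        name for name in x_names
        if ":" in name and focal_var in name.split(":")
    ]
    if interactions:
        raise ValueError(
            f"{focal_var!r} is involved in interactions {interactions}; "
            "use a finite-difference slope instead"
        )
    return coefs[x_names.index(focal_var)]


slope = extract_slope(["Intercept", "age", "income"], [1.2, 0.03, -0.5], "age")

# With an interaction present, the lookup is rejected.
try:
    extract_slope(["Intercept", "age", "group", "age:group"], [1, 2, 3, 4], "age")
    rejected = False
except ValueError:
    rejected = True
```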
compute_slopes_crossed¶
compute_slopes_crossed(bundle: 'DataBundle', fit: 'FitState', focal_var: str, *, resolved: object, data: pl.DataFrame, spec: 'ModelSpec | None' = None, formula_spec: 'FormulaSpec | None' = None, effect_scale: str = 'link', delta_frac: float = 0.001) -> MeeState

Compute crossed slopes over focal variable x condition grid.
Builds a Cartesian product of condition levels (grid categoricals,
grid numerics, scalar pins) and computes a finite-difference AME
for each combination. Analogous to _compute_emm_crossed but
for continuous focal variables.
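
A minimal sketch of the Cartesian-product step, assuming the grid conditions arrive as a mapping from variable name to a list of levels and the scalar pins as a mapping from name to a single value. The function name and input shapes are hypothetical and do not reflect how ResolvedConditions is actually structured.

```python
from itertools import product


def build_condition_grid(grid_conditions, scalar_pins):
    """Cross every level of every grid condition; attach pinned scalars to each row."""
    names = list(grid_conditions)
    rows = []
    for combo in product(*(grid_conditions[name] for name in names)):
        row = dict(zip(names, combo))
        row.update(scalar_pins)   # pinned values are constant across the grid
        rows.append(row)
    return rows


grid = build_condition_grid(
    {"group": ["a", "b"], "dose": [1.0, 2.0, 3.0]},
    {"age": 40.0},
)
```

Each row of the resulting grid would then receive its own finite-difference slope, giving one estimate per condition combination.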
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
bundle | ‘DataBundle’ | DataBundle with model data and metadata. | required |
fit | ‘FitState’ | FitState with fitted coefficients. | required |
focal_var | str | Continuous variable to compute slopes for. | required |
resolved | object | ResolvedConditions with grid and scalar conditions. | required |
data | DataFrame | Raw model data (for computing ranges and covariate means). | required |
spec | ‘ModelSpec \| None’ | ModelSpec with family/link info. | None |
formula_spec | ‘FormulaSpec \| None’ | FormulaSpec with learned encoding for evaluate_newdata. | None |
effect_scale | str | "link" (linear predictor scale) or "response" (inverse-link / data scale). | ‘link’ |
delta_frac | float | Fraction of the focal variable’s range used as the finite-difference step size. | 0.001 |
Returns:
| Type | Description |
|---|---|
MeeState | MeeState with one slope estimate per condition combination. |
compute_slopes_finite_diff¶
compute_slopes_finite_diff(bundle: DataBundle, fit: FitState, focal_var: str, explore_formula: str, *, spec: ModelSpec, formula_spec: FormulaSpec, data: pl.DataFrame, how: str = 'mem', effect_scale: str = 'link', delta_frac: float = 0.001) -> MeeState

Compute marginal slopes via centered finite differences.
Supports two averaging methods via the how parameter:

- "mem": Average over a balanced reference grid (Cartesian product of categorical levels, continuous covariates at means). Matches R’s emmeans::emtrends.
- "ame": Average over actual data rows. Gives a true Average Marginal Effect (AME) that preserves the observed covariate distribution.

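For identity-link models in which the focal variable’s slope does not depend on other covariates (no interactions involving the focal variable), both approaches give identical results.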
Algorithm:

1. delta = delta_frac × range(focal_var)
2. Build evaluation grid (balanced or observed data)
3. Perturb: grid_plus = grid[focal + delta/2], grid_minus = grid[focal − delta/2]
4. X_plus, X_minus = evaluate_newdata(formula_spec, grid_*)
5. L_diff = (X_plus − X_minus) / delta (link-scale functional)
6. If effect_scale="response" and how="ame": J_i = [f'(η_i+) X_i+ − f'(η_i−) X_i−] / delta (per-observation Jacobian with f' at perturbed points)
7. L_avg = mean(J, axis=0, keepdims=True)
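
The steps above can be sketched without any library dependencies. Everything here is an assumption made for illustration: design_row stands in for evaluate_newdata on a hypothetical y ~ x * g model, the evaluation grid is the observed rows (as with how="ame"), and f is taken to be the inverse logit.

```python
import math


def design_row(x, g):
    """Hypothetical design row for y ~ x * g: intercept, x, g, x:g."""
    return [1.0, x, g, x * g]


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def inv_logit(eta):
    return 1.0 / (1.0 + math.exp(-eta))


def inv_logit_deriv(eta):
    mu = inv_logit(eta)
    return mu * (1.0 - mu)


def finite_diff_slope(rows, beta, delta_frac=0.001, response=False):
    """Centered finite-difference slope averaged over rows of (x, g).

    Returns (estimate, l_avg), where l_avg is the averaged row of L_diff
    (link scale) or of the per-observation Jacobian J_i (response scale).
    """
    xs = [x for x, _ in rows]
    delta = delta_frac * (max(xs) - min(xs))          # step 1
    n, p = len(rows), len(beta)
    estimate, l_avg = 0.0, [0.0] * p
    for x, g in rows:                                 # step 2: observed rows
        x_plus = design_row(x + delta / 2, g)         # steps 3-4
        x_minus = design_row(x - delta / 2, g)
        if response:
            eta_p, eta_m = dot(x_plus, beta), dot(x_minus, beta)
            estimate += (inv_logit(eta_p) - inv_logit(eta_m)) / delta / n
            w_p, w_m = inv_logit_deriv(eta_p), inv_logit_deriv(eta_m)
        else:
            w_p = w_m = 1.0                           # step 5: plain L_diff
        for j in range(p):                            # steps 6-7
            l_avg[j] += (w_p * x_plus[j] - w_m * x_minus[j]) / delta / n
    if not response:
        estimate = dot(l_avg, beta)
    return estimate, l_avg


rows = [(1.0, 0.0), (2.0, 1.0), (3.0, 1.0)]
beta = [0.2, 0.5, -0.3, 0.4]
link_est, link_l = finite_diff_slope(rows, beta)
resp_est, resp_l = finite_diff_slope(rows, beta, response=True)
```

On the link scale the averaged functional picks out β_x plus β_x:g times the mean of g, so the interaction is handled without any special casing. On the response scale, l_avg is the gradient of the estimate with respect to the coefficients, which is the quantity delta-method inference needs.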
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
bundle | DataBundle | DataBundle with model data and metadata. | required |
fit | FitState | FitState with fitted coefficients. | required |
focal_var | str | Continuous variable to compute the slope for. | required |
explore_formula | str | Explore formula string (for result metadata). | required |
spec | ModelSpec | ModelSpec with family/link info. | required |
formula_spec | FormulaSpec | FormulaSpec with learned encoding for evaluate_newdata. | required |
data | DataFrame | Raw model data (for computing ranges and covariate means). | required |
how | str | "mem" for balanced reference grid, "ame" for actual data rows. | ‘mem’ |
effect_scale | str | "link" (linear predictor scale) or "response" (inverse-link / data scale). | ‘link’ |
delta_frac | float | Fraction of the focal variable’s range used as the finite-difference step size. Default 0.001 matches R’s emmeans. | 0.001 |
Returns:
| Type | Description |
|---|---|
MeeState | MeeState with a single-row AME estimate and L_matrix for delta-method inference. |
transforms¶
Pure transform operations for result DataFrames.
Stateless functions that transform params/effects DataFrames. Each function takes a DataFrame + context and returns a new DataFrame.
Functions:
| Name | Description |
|---|---|
convert_to_effect_size | Compute standardized effect sizes from a params DataFrame. |
convert_to_odds_ratio | Exponentiate estimates and CIs to odds ratio scale. |
convert_to_response_scale | Apply inverse link to transform estimates to response scale. |
filter_effects_dataframe | Filter an effects DataFrame by term, level, or contrast. |
Functions¶
convert_to_effect_size¶
convert_to_effect_size(df: pl.DataFrame, sigma: float | None, df_resid: float | None, family: str = 'gaussian', *, include_intercept: bool = False) -> pl.DataFrame

Compute standardized effect sizes from a params DataFrame.
Computes:
d (Cohen’s d): estimate / sigma
r_semi (semi-partial r): |t| / sqrt(t^2 + df_resid)
eta_sq (eta-squared): t^2 / (t^2 + df_resid)
odds_ratio: exp(estimate) for binomial only
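
As a worked illustration, the four formulas can be applied to a single parameter row. This sketch uses plain floats rather than a DataFrame, and skipping a quantity when sigma or df_resid is None is an assumption about the intended behavior, not something the documentation states.

```python
import math


def effect_sizes(estimate, statistic, sigma, df_resid, family="gaussian"):
    """Apply the listed formulas to one row; t is the test statistic."""
    out = {}
    if sigma is not None:
        out["d"] = estimate / sigma
    if df_resid is not None:
        t_sq = statistic ** 2
        out["r_semi"] = abs(statistic) / math.sqrt(t_sq + df_resid)
        out["eta_sq"] = t_sq / (t_sq + df_resid)
    if family == "binomial":
        out["odds_ratio"] = math.exp(estimate)
    return out


sizes = effect_sizes(estimate=2.0, statistic=3.0, sigma=4.0, df_resid=16.0)
binom = effect_sizes(estimate=0.0, statistic=0.0, sigma=None, df_resid=None,
                     family="binomial")
```

With t = 3 and 16 residual degrees of freedom, t² + df_resid = 25, so r_semi = 3/5 and eta_sq = 9/25.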
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
df | DataFrame | Params DataFrame with ‘term’, ‘estimate’, optionally ‘statistic’, ‘ci_lower’, ‘ci_upper’. | required |
sigma | float \| None | Residual standard deviation. None if unavailable. | required |
df_resid | float \| None | Residual degrees of freedom. None if unavailable. | required |
family | str | Model family string. | ‘gaussian’ |
include_intercept | bool | Whether to include intercept row. | False |
Returns:
| Type | Description |
|---|---|
DataFrame | DataFrame with added effect size columns. |
convert_to_odds_ratio¶
convert_to_odds_ratio(df: pl.DataFrame, family: str) -> pl.DataFrame

Exponentiate estimates and CIs to odds ratio scale.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
df | DataFrame | Params DataFrame with ‘estimate’ and optionally ‘ci_lower’, ‘ci_upper’. | required |
family | str | Model family string (must be “binomial”). | required |
Returns:
| Type | Description |
|---|---|
DataFrame | DataFrame with exponentiated estimate/CI columns. |
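
A minimal sketch of the transformation on a list of row dicts instead of a DataFrame. The helper name is hypothetical, and raising ValueError for a non-binomial family is an assumption; the documentation only says the family must be "binomial".

```python
import math

SCALE_COLUMNS = ("estimate", "ci_lower", "ci_upper")


def to_odds_ratio(rows, family):
    """Exponentiate estimate and CI bounds; leave other columns untouched."""
    if family != "binomial":
        raise ValueError("odds ratios are only defined here for family='binomial'")
    return [
        {key: (math.exp(value) if key in SCALE_COLUMNS else value)
         for key, value in row.items()}
        for row in rows
    ]


or_rows = to_odds_ratio(
    [{"term": "x", "estimate": 0.0, "ci_lower": -1.0, "ci_upper": 1.0}],
    family="binomial",
)
```

Because exp is monotone increasing, exponentiating the interval endpoints preserves their order.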
convert_to_response_scale¶
convert_to_response_scale(df: pl.DataFrame, link: str) -> pl.DataFrame

Apply inverse link to transform estimates to response scale.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
df | DataFrame | Effects DataFrame with ‘estimate’ and optionally ‘ci_lower’, ‘ci_upper’. | required |
link | str | Link function name (e.g., “logit”, “log”). | required |
Returns:
| Type | Description |
|---|---|
DataFrame | DataFrame with transformed columns on response scale. |
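
The idea can be sketched with a small table of inverse links applied to row dicts. The helper name and the set of supported links ("identity", "log", "logit") are assumptions for illustration; the library may support others.

```python
import math

# Inverse links for a few common link functions (illustrative subset).
INVERSE_LINKS = {
    "identity": lambda eta: eta,
    "log": math.exp,
    "logit": lambda eta: 1.0 / (1.0 + math.exp(-eta)),
}


def to_response_scale(rows, link):
    """Apply the inverse link to estimate and CI columns of each row."""
    inverse = INVERSE_LINKS[link]
    columns = ("estimate", "ci_lower", "ci_upper")
    return [
        {key: (inverse(value) if key in columns else value)
         for key, value in row.items()}
        for row in rows
    ]


link_rows = [{"level": "a", "estimate": 0.0, "ci_lower": -2.0, "ci_upper": 2.0}]
prob_rows = to_response_scale(link_rows, "logit")
rate_rows = to_response_scale(link_rows, "log")
```

All three inverse links shown are monotone increasing, so transformed interval endpoints keep their order.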
filter_effects_dataframe¶
filter_effects_dataframe(df: pl.DataFrame, *, terms: list[str] | str | None = None, levels: list[str] | str | None = None, contrasts: list[str] | str | None = None) -> pl.DataFrame

Filter an effects DataFrame by term, level, or contrast.
Pure function that applies one or more filters to a marginal-effects
DataFrame produced by the effects property.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
df | DataFrame | Effects DataFrame (from model.effects). | required |
terms | list[str] \| str \| None | Term name(s) to keep (filters the term column). | None |
levels | list[str] \| str \| None | Level name(s) to keep (filters the first grid column, i.e. the first column that is not a standard statistic column). | None |
contrasts | list[str] \| str \| None | Contrast label(s) to keep (filters the contrast column). | None |
Returns:
| Type | Description |
|---|---|
DataFrame | Filtered DataFrame containing only matching rows. |
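
A simplified sketch of the filtering logic on row dicts, covering only the terms and contrasts filters. The levels filter is omitted because it depends on which columns count as standard statistic columns, which is not specified here. The helper names are hypothetical.

```python
def _as_list(value):
    """Accept a single string or any iterable of strings."""
    return [value] if isinstance(value, str) else list(value)


def filter_rows(rows, *, terms=None, contrasts=None):
    """Keep rows matching every filter that is not None."""
    for column, wanted in (("term", terms), ("contrast", contrasts)):
        if wanted is not None:
            keep = set(_as_list(wanted))
            rows = [row for row in rows if row.get(column) in keep]
    return rows


effects = [
    {"term": "group", "contrast": "b - a", "estimate": 1.5},
    {"term": "group", "contrast": "c - a", "estimate": 0.7},
    {"term": "dose", "contrast": None, "estimate": 0.2},
]
by_term = filter_rows(effects, terms="group")
by_both = filter_rows(effects, terms=["group"], contrasts="c - a")
```

Filters combine with logical AND, so passing both terms and contrasts narrows the result further.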
validation¶
Validation guards for marginal effects operations.
Shared precondition checks for compute_emm, compute_slopes, etc.
Functions:
| Name | Description |
|---|---|
validate_focal_var | Guard: focal variable must be a predictor, not the response. |
Functions¶
validate_focal_var¶
validate_focal_var(bundle: DataBundle, focal_var: str) -> None

Guard: focal variable must be a predictor, not the response.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
bundle | DataBundle | DataBundle with model metadata. | required |
focal_var | str | Name of the variable to compute marginal effects for. | required |
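
A guard of this kind might look like the sketch below. It is hypothetical throughout: the function takes the response and predictor names directly rather than a DataBundle, and raising ValueError is an assumption, since the exception type is not documented here.

```python
def check_focal_var(response_name, predictor_names, focal_var):
    """Raise if focal_var is the response or is not a known predictor."""
    if focal_var == response_name:
        raise ValueError(
            f"{focal_var!r} is the response; marginal effects need a predictor"
        )
    if focal_var not in predictor_names:
        raise ValueError(f"{focal_var!r} is not a predictor in the model")


# A valid predictor passes silently (the guard returns None).
ok = check_focal_var("y", ["age", "group"], "age")

# The response variable is rejected.
try:
    check_focal_var("y", ["age", "group"], "y")
    rejected = False
except ValueError:
    rejected = True
```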