Marginal effects and estimated marginal means computation.

Call chain:

model.explore() -> build_reference_grid() -> compute_emm() / compute_slopes() -> compute_contrasts()

Classes:

| Name | Description |
| --- | --- |
| `Condition` | A conditioning specification in explore formula. |
| `ExploreFormulaError` | Error in explore formula syntax. |
| `ExploreFormulaSpec` | Parsed explore formula. |
| `ResolvedConditions` | Typed buckets for resolved conditioning specifications. |

Functions:

| Name | Description |
| --- | --- |
| `apply_bracket_contrasts` | Apply bracket contrast expression to an EMM MeeState. |
| `apply_bracket_contrasts_grouped` | Apply bracket contrasts within each condition group of a crossed MeeState. |
| `apply_contrasts` | Apply contrast matrix to marginal means/effects. |
| `apply_contrasts_grouped` | Apply contrasts within each condition group of a crossed MeeState. |
| `apply_rhs_bracket_contrast` | Apply a bracket contrast on a RHS condition column. |
| `build_all_pairwise_matrix` | Build all pairwise contrasts between EMM levels. |
| `build_bracket_contrast_matrix` | Build contrast matrix and labels from bracket contrast expression. |
| `build_contrast_matrix` | Build a contrast matrix based on contrast type. |
| `build_helmert_matrix` | Build Helmert contrasts (each level vs mean of previous levels). |
| `build_pairwise_matrix` | Build (n-1) linearly independent pairwise contrasts. |
| `build_poly_matrix` | Build orthogonal polynomial contrast matrix for EMMs. |
| `build_reference_grid` | Construct reference grid for marginal effects evaluation. |
| `build_sequential_matrix` | Build sequential (successive differences) contrasts. |
| `build_sum_to_zero_matrix` | Build sum-to-zero contrasts (deviation coding). |
| `build_treatment_matrix` | Build treatment (Dunnett-style) contrasts against a reference level. |
| `combine_resolved` | Merge two ResolvedConditions, with b taking precedence on conflicts. |
| `compose_contrast_matrix` | Compose contrast matrix with prediction matrix. |
| `compute_compound_bracket_contrasts` | Compute bracket contrasts for a compound focal variable. |
| `compute_conditional_emm` | Compute per-group conditional EMMs incorporating intercept BLUPs. |
| `compute_conditional_slopes` | Compute per-group conditional slopes incorporating BLUPs. |
| `compute_contrasts` | Apply contrast matrix to EMMs. |
| `compute_emm` | Compute estimated marginal means for a categorical focal variable. |
| `compute_joint_test` | Compute joint hypothesis tests for model terms. |
| `compute_mee_inference` | Compute delta method inference for marginal effects. |
| `compute_mee_inference_fallback` | Compute inference for MEE without L_matrix (fallback path). |
| `compute_mee_se` | Compute standard errors for MEE estimates (means or slopes). |
| `compute_slopes` | Compute marginal slope for a continuous focal variable. |
| `compute_slopes_crossed` | Compute crossed slopes over focal variable x condition grid. |
| `compute_slopes_finite_diff` | Compute marginal slopes via centered finite differences. |
| `dispatch_marginal_computation` | Route a parsed explore formula to the appropriate marginal computation. |
| `get_contrast_labels` | Generate human-readable labels for contrasts. |
| `parse_explore_formula` | Parse an explore formula string. |
| `resolve_conditions` | Classify each Condition into the appropriate typed bucket. |

Modules:

| Name | Description |
| --- | --- |
| `bracket_contrasts` | Bracket contrast matrix builder and application. |
| `compute` | Marginal effects dispatch and routing. |
| `conditions` | Condition resolution for explore formula RHS conditioning. |
| `contrasts` | Contrast computation for marginal effects. |
| `emm` | Estimated marginal means computation. |
| `explore` | Explore formula parser. |
| `explore_parser` | Explore formula recursive descent parser. |
| `explore_scanner` | Explore formula scanner/tokenizer. |
| `factors` | Factor level extraction and handling utilities. |
| `grid` | Reference grid construction for marginal effects. |
| `inference` | Inference for marginal effects using the delta method. |
| `joint_tests` | ANOVA-style joint hypothesis tests for model terms. |
| `matrices` | Contrast matrix builders for EMM comparisons. |
| `resolve` | Resolution helpers for marginal effects dispatch. |
| `slopes` | Marginal slopes computation. |
| `transforms` | Pure transform operations for result DataFrames. |
| `validation` | Validation guards for marginal effects operations. |

Classes

Condition

A conditioning specification in explore formula.

Represents a variable to condition on, optionally with specific values, a method for generating values, or a bracket contrast expression.

- Created by: parse_explore_formula(), ExploreParser
- Consumed by: resolve_all_conditions(), dispatch_marginal_computation()
- Augmented by: Never

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `var` | `str` | Variable name to condition on. |
| `at_values` | `tuple \| None` | Specific values to evaluate at (e.g., `(50.0,)` or `("A", "B")`). |
| `at_range` | `int \| None` | Number of evenly spaced values across the variable's range. |
| `at_quantile` | `int \| None` | Number of quantile values to use. |
| `contrast_expr` | `ContrastExpr \| None` | Bracket contrast expression on this condition variable (e.g., from `Dose[High - Low]` on the RHS). When set, the variable is treated as a grid categorical during condition resolution, and the contrast is applied as a post-processing step. |

Attributes

at_quantile
at_quantile: int | None = field(default=None, validator=is_optional_positive_int)
at_range
at_range: int | None = field(default=None, validator=is_optional_positive_int)
at_values
at_values: tuple | None = field(default=None)
contrast_expr
contrast_expr: ContrastExpr | None = field(default=None, validator=(validators.optional(validators.instance_of(ContrastExpr))))
var
var: str = field(validator=(validators.instance_of(str)))
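
The at_range attribute requests evenly spaced values across a variable's range. Below is a minimal sketch of one way to materialize it. It assumes the values run from the observed minimum to the observed maximum, inclusive. The helper name `range_values` is hypothetical and not part of this module.

```python
import numpy as np


def range_values(x: np.ndarray, n: int) -> tuple[float, ...]:
    """Hypothetical helper: n evenly spaced values from min(x) to max(x)."""
    if n < 1:
        raise ValueError("n must be a positive integer")
    # Assumption: both endpoints of the observed range are included.
    return tuple(float(v) for v in np.linspace(x.min(), x.max(), n))


days = np.array([0.0, 2.0, 5.0, 9.0])
print(range_values(days, 4))  # (0.0, 3.0, 6.0, 9.0)
```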

ExploreFormulaError

ExploreFormulaError(message: str, formula: str, position: int | None = None) -> None

Bases: ValueError

Error in explore formula syntax.

Provides helpful error messages with position indicators for syntax errors.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `message` | `str` | Error description. | required |
| `formula` | `str` | The formula that caused the error. | required |
| `position` | `int \| None` | Character position of the error (optional). | `None` |

Attributes:

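| Name | Type | Description |
| --- | --- | --- |
| `formula` | `str` | The formula that caused the error. |
| `position` | `int \| None` | Character position of the error. |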


Attributes

formula
formula = formula
position
position = position

ExploreFormulaSpec

Parsed explore formula.

Represents a parsed explore formula with focal variable, optional contrast type, and conditioning specifications.

- Created by: parse_explore_formula()
- Consumed by: dispatch_marginal_computation(), plot_predict(), plot_explore()
- Augmented by: attrs.evolve() in resolve_focal_at_spec(), which materializes focal at-values

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `focal_var` | `str` | The variable to compute marginal effects for. |
| `contrast_type` | `str \| None` | Type of contrast (`pairwise`, `sequential`, `poly`, `treatment`, `sum`, `helmert`, `custom`) or `None` for simple EMMs. Set to `"custom"` for bracket contrast expressions. |
| `contrast_degree` | `int \| None` | Degree parameter for polynomial contrasts (default `None` means use `n_levels - 1`, i.e., the maximum degree). |
| `contrast_ref` | `str \| None` | Reference level for treatment/dummy contrasts (e.g., `"Placebo"` from `treatment(Drug, ref=Placebo)`). |
| `contrast_level_ordering` | `tuple[str, ...] \| None` | Explicit level ordering for order-dependent contrasts (helmert, sequential, poly). Parsed from bracket list syntax, e.g. `poly(dose, [low, med, high])`. |
| `contrast_expr` | `ContrastExpr \| None` | Bracket contrast expression AST (e.g., from `Drug[Active - Placebo]` syntax). `None` for named contrast functions or simple EMMs. |
| `conditions` | `tuple[Condition, ...]` | Tuple of Condition objects specifying conditioning variables. |
| `focal_at_values` | `tuple[float \| str, ...] \| None` | Specific values to evaluate the focal variable at (e.g., from `Days@[0, 3, 6, 9]` syntax). `None` means use all levels. |
| `focal_at_range` | `int \| None` | Number of evenly spaced values across the focal variable's range (e.g., from `Days@range(5)` syntax). `None` means not set. |
| `focal_at_quantile` | `int \| None` | Number of quantile values for the focal variable (e.g., from `Days@quantile(3)` syntax). `None` means not set. |

Attributes

conditions
conditions: tuple[Condition, ...] = field(factory=tuple, converter=tuple)
contrast_degree
contrast_degree: int | None = field(default=None, validator=is_optional_positive_int)
contrast_expr
contrast_expr: ContrastExpr | None = field(default=None)
contrast_level_ordering
contrast_level_ordering: tuple[str, ...] | None = field(default=None, validator=is_optional_tuple_of_str)
contrast_ref
contrast_ref: str | None = field(default=None, validator=is_optional_str)
contrast_type
contrast_type: str | None = field(default=None, validator=(validators.optional(is_choice_str(('pairwise', 'sequential', 'poly', 'treatment', 'sum', 'helmert', 'custom')))))
focal_at_quantile
focal_at_quantile: int | None = field(default=None, validator=is_optional_positive_int)
focal_at_range
focal_at_range: int | None = field(default=None, validator=is_optional_positive_int)
focal_at_values
focal_at_values: tuple[float | str, ...] | None = field(default=None)
focal_var
focal_var: str = field(validator=(validators.instance_of(str)))
has_conditions
has_conditions: bool

Return True if conditioning variables are specified.

has_contrast
has_contrast: bool

Return True if any contrast is specified (named function or bracket expr).

has_contrast_expr
has_contrast_expr: bool

Return True if a bracket contrast expression is specified.

has_rhs_contrasts
has_rhs_contrasts: bool

Return True if any RHS condition has a bracket contrast expression.

ResolvedConditions

Typed buckets for resolved conditioning specifications.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `at_overrides` | `dict[str, float]` | Single numeric pin per variable (e.g. `Income@50`). |
| `set_categoricals` | `dict[str, str]` | Single categorical pin per variable (e.g. `Ethnicity@Asian`). |
| `grid_categoricals` | `dict[str, list[str]]` | Multi-level categoricals to cross (e.g. bare `Ethnicity` → all levels, or `Ethnicity@(Asian, Caucasian)` → those two levels). |
| `grid_numerics` | `dict[str, list[float]]` | Multi-value numerics to cross (e.g. `Income@(10, 20)` or `Income@:range(5)`). |

Attributes

at_overrides
at_overrides: dict[str, float]
grid_categoricals
grid_categoricals: dict[str, list[str]]
grid_numerics
grid_numerics: dict[str, list[float]]
has_grid
has_grid: bool

Return True if any condition requires grid expansion.

raw_at_overrides
raw_at_overrides: dict[str, float] = Factory(dict)

Pre-transform at-overrides with original (raw) variable names.

Populated by the dispatch layer after _resolve_at_overrides remaps keys to design-matrix names. The slopes path needs raw keys because it operates on a data grid with raw column names, not design-matrix names. Empty when no transform resolution was applied.

set_categoricals
set_categoricals: dict[str, str]

Functions

apply_bracket_contrasts

apply_bracket_contrasts(mee_state: MeeState, expr: ContrastExpr) -> MeeState

Apply bracket contrast expression to an EMM MeeState.

Computes contrasts by building a weight matrix from the bracket expression and multiplying by the EMM estimates.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `mee_state` | `MeeState` | MeeState with EMM estimates. | required |
| `expr` | `ContrastExpr` | ContrastExpr from parser. | required |

Returns:

| Type | Description |
| --- | --- |
| `MeeState` | New MeeState with contrast estimates, type `"contrasts"`, and `contrast_method="custom"`. |

apply_bracket_contrasts_grouped

apply_bracket_contrasts_grouped(mee_state: MeeState, expr: ContrastExpr) -> MeeState

Apply bracket contrasts within each condition group of a crossed MeeState.

Similar to apply_contrasts_grouped() but for bracket contrast expressions. Applies the same contrast matrix independently within each group of n_focal consecutive rows.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `mee_state` | `MeeState` | MeeState from crossed EMM computation with n_focal x n_groups rows. | required |
| `expr` | `ContrastExpr` | ContrastExpr from parser. | required |

Returns:

| Type | Description |
| --- | --- |
| `MeeState` | MeeState with n_contrasts x n_groups rows and condition columns. |

apply_contrasts

apply_contrasts(mee_state: MeeState, contrast_type: str, fit: FitState | None = None, *, degree: int | None = None, ref_idx: int | None = None, level_ordering: tuple[str, ...] | None = None) -> MeeState

Apply contrast matrix to marginal means/effects.

High-level function that takes a MeeState with EMM estimates and returns a new MeeState with contrast estimates.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `mee_state` | `MeeState` | MeeState containing EMM estimates from `compute_emm()`. | required |
| `contrast_type` | `str` | Type of contrast to apply: `"pairwise"` (all pairwise comparisons: B-A, C-A, C-B, ...); `"sequential"` (adjacent differences: B-A, C-B, D-C, ...); `"poly"` (orthogonal polynomial contrasts); `"treatment"` (each level vs reference; requires `ref_idx`); `"sum"` (each level vs grand mean); `"helmert"` (each level vs mean of previous levels). | required |
| `fit` | `FitState \| None` | FitState with vcov for variance propagation (optional; only needed if computing SEs during this call). | `None` |
| `degree` | `int \| None` | Maximum polynomial degree for poly contrasts (default: n-1). | `None` |
| `ref_idx` | `int \| None` | Reference level index for treatment contrasts (0-based). | `None` |
| `level_ordering` | `tuple[str, ...] \| None` | Explicit level ordering for order-dependent contrasts (e.g. from `poly(dose, [low, med, high])`). When provided, EMM estimates and grid are reordered to match before applying the contrast matrix. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `MeeState` | New MeeState with: `grid` (DataFrame with contrast labels); `estimate` (contrast estimates); `type` (`"contrasts"`); `focal_var` (same as input); `explore_formula` (updated to reflect the contrast). |

Examples:

>>> mee = compute_emm(bundle, fit, "treatment", "treatment")
>>> contrasts = apply_contrasts(mee, "pairwise")
>>> contrasts.estimate  # B-A, C-A, C-B differences

Note: Variance propagation for inference is deferred to .infer(). This function only computes point estimates.

apply_contrasts_grouped

apply_contrasts_grouped(mee_state: MeeState, contrast_type: str | None, *, degree: int | None = None, ref_idx: int | None = None, level_ordering: tuple[str, ...] | None = None) -> MeeState

Apply contrasts within each condition group of a crossed MeeState.

When EMMs have been computed over a crossed grid (focal levels x condition groups), this function applies the contrast matrix independently within each group of n_focal consecutive rows, then stacks the results.

The number of focal levels is inferred from the focal_var column of mee_state.grid.
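
Applying a contrast within each group amounts to a reshape followed by one matrix product. Below is a minimal numpy sketch of that idea. It assumes estimates are ordered with focal levels varying fastest, as described for mee_state. `apply_grouped` is a hypothetical name, and it operates on plain arrays rather than a MeeState.

```python
import numpy as np


def apply_grouped(estimates: np.ndarray, C: np.ndarray) -> np.ndarray:
    """Apply contrast matrix C within each block of n_focal consecutive rows."""
    n_focal = C.shape[1]
    if estimates.size % n_focal != 0:
        raise ValueError("estimates length must be a multiple of n_focal")
    # Each row of `blocks` holds the n_focal estimates of one condition group.
    blocks = estimates.reshape(-1, n_focal)  # (n_groups, n_focal)
    # Contrasts per group, then stacked group by group.
    return (blocks @ C.T).reshape(-1)  # (n_groups * n_contrasts,)


C = np.array([[-1.0, 1.0, 0.0], [0.0, -1.0, 1.0]])  # sequential, 3 levels
est = np.array([1.0, 2.0, 4.0, 10.0, 10.0, 13.0])  # 3 levels x 2 groups
print(apply_grouped(est, C))  # [1. 2. 0. 3.]
```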

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `mee_state` | `MeeState` | MeeState from `_compute_emm_crossed()` with n_focal x n_groups rows ordered as `[focal_1_group_1, ..., focal_k_group_1, focal_1_group_2, ...]`. | required |
| `contrast_type` | `str \| None` | Type of contrast (pairwise, sequential, poly, treatment, sum, helmert). | required |
| `degree` | `int \| None` | Degree for polynomial contrasts. | `None` |
| `ref_idx` | `int \| None` | Reference level index for treatment contrasts. | `None` |
| `level_ordering` | `tuple[str, ...] \| None` | Explicit level ordering for order-dependent contrasts. When provided, EMM rows within each group are reordered before applying the contrast matrix. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `MeeState` | MeeState with n_contrasts x n_groups rows and condition columns. |

apply_rhs_bracket_contrast

apply_rhs_bracket_contrast(mee_state: MeeState, expr: ContrastExpr) -> MeeState

Apply a bracket contrast on a RHS condition column.

After the main computation produces a crossed MeeState with a condition column (e.g., Dose with levels [High, Low]), this collapses that column by applying the bracket contrast.

The condition column must vary slowest in the grid (standard layout from crossed EMM/slope computation).
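
Because the condition column varies slowest, each of its levels occupies a contiguous block of rows. A single bracket contrast can therefore collapse the column with one weighted sum of blocks. Below is a numpy sketch under that layout assumption. `collapse_slowest` is hypothetical and handles one contrast row only.

```python
import numpy as np


def collapse_slowest(estimates: np.ndarray, weights: np.ndarray) -> np.ndarray:
    """Combine contiguous condition blocks with contrast weights (one contrast)."""
    weights = np.asarray(weights, dtype=float)
    # Rows are [cond_1: inner_1..inner_m, cond_2: inner_1..inner_m, ...].
    blocks = np.asarray(estimates, dtype=float).reshape(weights.size, -1)
    return weights @ blocks  # one value per inner (non-condition) row


# Dose levels [High, Low] vary slowest, with two focal rows within each dose.
est = np.array([5.0, 7.0, 3.0, 4.0])
print(collapse_slowest(est, np.array([1.0, -1.0])))  # Dose[High - Low] -> [2. 3.]
```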

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `mee_state` | `MeeState` | MeeState from a crossed computation. Must contain `expr.var` as a column in the grid. | required |
| `expr` | `ContrastExpr` | ContrastExpr specifying the contrast on the condition variable. | required |

Returns:

| Type | Description |
| --- | --- |
| `MeeState` | New MeeState with the condition column replaced by contrast rows. |

build_all_pairwise_matrix

build_all_pairwise_matrix(n_levels: int) -> np.ndarray

Build all pairwise contrasts between EMM levels.

Creates C(n, 2) = n*(n-1)/2 contrasts for all unique pairs. Unlike build_pairwise_matrix(), this includes all pairs, not just comparisons to the reference level.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `n_levels` | `int` | Number of EMM levels (factor levels). | required |

Returns:

| Type | Description |
| --- | --- |
| `ndarray` | Contrast matrix of shape `(n*(n-1)/2, n_levels)`. Each row compares two levels (`level_j - level_i` where `j > i`). |

Examples:

>>> C = build_all_pairwise_matrix(3)
>>> C
array([[-1.,  1.,  0.],
       [-1.,  0.,  1.],
       [ 0., -1.,  1.]])
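
The row order in the doctest above matches iterating over pairs (i, j) with i < j in lexicographic order. The sketch below reproduces it with itertools.combinations. It is an illustration, not necessarily how build_all_pairwise_matrix is implemented.

```python
from itertools import combinations

import numpy as np


def all_pairwise(n_levels: int) -> np.ndarray:
    """All C(n, 2) pairwise contrasts, one row per pair (i, j) with i < j."""
    pairs = list(combinations(range(n_levels), 2))  # lexicographic order
    C = np.zeros((len(pairs), n_levels))
    for row, (i, j) in enumerate(pairs):
        C[row, i] = -1.0  # minus level_i
        C[row, j] = 1.0  # plus level_j
    return C


print(all_pairwise(3))
```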

build_bracket_contrast_matrix

build_bracket_contrast_matrix(expr: ContrastExpr, levels: list[str]) -> tuple[np.ndarray, list[str]]

Build contrast matrix and labels from bracket contrast expression.

Handles auto-normalization (each operand side sums to +1 or -1) and wildcard expansion (* expands to all unmentioned levels).
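
The sketch below illustrates both rules with a simplified input. It takes one contrast as plain lists of levels on its positive and negative sides, rather than a ContrastExpr AST. The `bracket_weights` helper and its list-based input are hypothetical, and the real parser supports richer expressions.

```python
import numpy as np


def bracket_weights(levels: list[str], plus: list[str], minus: list[str]) -> np.ndarray:
    """Hypothetical: weight vector for a single bracket contrast `plus - minus`."""
    mentioned = {lvl for lvl in plus + minus if lvl != "*"}
    unmentioned = [lvl for lvl in levels if lvl not in mentioned]

    def expand(side: list[str]) -> list[str]:
        # Wildcard expansion: "*" stands for every level not named explicitly.
        out: list[str] = []
        for lvl in side:
            out.extend(unmentioned if lvl == "*" else [lvl])
        return out

    w = np.zeros(len(levels))
    for side, sign in ((expand(plus), 1.0), (expand(minus), -1.0)):
        for lvl in side:
            # Auto-normalization: each side's weights sum to +1 or -1.
            w[levels.index(lvl)] += sign / len(side)
    return w


levels = ["Placebo", "Low", "High"]
print(bracket_weights(levels, ["High"], ["*"]))  # High - mean(Placebo, Low)
```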

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `expr` | `ContrastExpr` | ContrastExpr AST from the parser. | required |
| `levels` | `list[str]` | Actual factor levels from the EMM grid. | required |

Returns:

| Type | Description |
| --- | --- |
| `tuple[ndarray, list[str]]` | Tuple of `(contrast_matrix, labels)` where `contrast_matrix` has shape `(n_contrasts, n_levels)` and `labels` has length `n_contrasts`. |

build_contrast_matrix

build_contrast_matrix(contrast_type: str | dict, levels: list, normalize: bool = False) -> np.ndarray

Build a contrast matrix based on contrast type.

Dispatcher function that builds the appropriate contrast matrix based on the contrast type specification.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `contrast_type` | `str \| dict` | Type of contrast: `"pairwise"` (all pairwise comparisons, k(k-1)/2 contrasts); `"trt.vs.ctrl"` or `"treatment"` (compare each level to the first level); `"sequential"` (successive differences, `level[i+1] - level[i]`); `"sum"` (each level vs grand mean, deviation coding); `"helmert"` (each level vs mean of previous levels); or a dict of custom contrasts with names as keys and weights as values. | required |
| `levels` | `list` | List of factor levels. | required |
| `normalize` | `bool` | If True, normalize custom contrasts to sum to 1/-1. | `False` |

Returns:

| Type | Description |
| --- | --- |
| `ndarray` | Contrast matrix of shape `(n_contrasts, n_levels)`. |

Examples:

>>> build_contrast_matrix("pairwise", ["A", "B", "C"])
array([[-1.,  1.,  0.],
       [-1.,  0.,  1.],
       [ 0., -1.,  1.]])
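
For the dict form, one plausible reading of normalize=True is to rescale each row so that its positive weights sum to 1 and its negative weights sum to -1. The sketch below implements that reading. `custom_contrasts` is a hypothetical helper, and the exact normalization used by build_contrast_matrix may differ.

```python
import numpy as np


def custom_contrasts(spec: dict[str, list[float]], levels: list, normalize: bool = False) -> np.ndarray:
    """Hypothetical: stack named weight vectors into a contrast matrix."""
    C = np.array(list(spec.values()), dtype=float)
    if C.ndim != 2 or C.shape[1] != len(levels):
        raise ValueError("each weight vector needs one entry per level")
    if normalize:
        pos = np.where(C > 0, C, 0.0)
        neg = np.where(C < 0, C, 0.0)
        pos_sum = pos.sum(axis=1, keepdims=True)
        neg_sum = -neg.sum(axis=1, keepdims=True)
        # Positive weights rescaled to sum to 1, negative weights to sum to -1.
        C = np.divide(pos, pos_sum, out=np.zeros_like(pos), where=pos_sum > 0) + np.divide(
            neg, neg_sum, out=np.zeros_like(neg), where=neg_sum > 0
        )
    return C


print(custom_contrasts({"A vs B+C": [2, -1, -1]}, ["A", "B", "C"], normalize=True))
```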

build_helmert_matrix

build_helmert_matrix(n_levels: int) -> np.ndarray

Build Helmert contrasts (each level vs mean of previous levels).

Creates n_levels - 1 contrasts where level k is compared to the mean of levels 0, 1, ..., k-1.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `n_levels` | `int` | Number of EMM levels (factor levels). | required |

Returns:

| Type | Description |
| --- | --- |
| `ndarray` | Contrast matrix of shape `(n_levels - 1, n_levels)`. |

Examples:

>>> C = build_helmert_matrix(3)
>>> C
array([[-1. ,  1. ,  0. ],
       [-0.5, -0.5,  1. ]])
>>> C = build_helmert_matrix(4)
>>> C
array([[-1.        ,  1.        ,  0.        ,  0.        ],
       [-0.5       , -0.5       ,  1.        ,  0.        ],
       [-0.33333333, -0.33333333, -0.33333333,  1.        ]])
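
Below is a minimal numpy sketch that reproduces both doctest outputs. Row k-1 puts weight -1/k on each of the first k levels and +1 on level k. It is an illustration, not necessarily the module's implementation.

```python
import numpy as np


def helmert(n_levels: int) -> np.ndarray:
    """Row k-1 compares level k with the mean of levels 0..k-1."""
    C = np.zeros((n_levels - 1, n_levels))
    for k in range(1, n_levels):
        C[k - 1, :k] = -1.0 / k  # minus the mean of the previous k levels
        C[k - 1, k] = 1.0  # plus level k
    return C


print(helmert(3))
```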

build_pairwise_matrix

build_pairwise_matrix(n_levels: int) -> np.ndarray

Build (n-1) linearly independent pairwise contrasts.

Creates “treatment-style” contrasts comparing each level to the first (reference) level. This produces n_levels - 1 contrasts that span the space of all pairwise differences.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `n_levels` | `int` | Number of EMM levels (factor levels). | required |

Returns:

| Type | Description |
| --- | --- |
| `ndarray` | Contrast matrix of shape `(n_levels - 1, n_levels)`. Row i compares level i+1 to level 0. |

Examples:

>>> C = build_pairwise_matrix(3)
>>> C
array([[-1.,  1.,  0.],
       [-1.,  0.,  1.]])

build_poly_matrix

build_poly_matrix(n_levels: int, degree: int | None = None) -> np.ndarray

Build orthogonal polynomial contrast matrix for EMMs.

Creates orthogonal polynomial contrasts for ordered factors. Tests for linear, quadratic, cubic (etc.) trends across the factor levels.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `n_levels` | `int` | Number of factor levels. Must be >= 2. | required |
| `degree` | `int \| None` | Maximum polynomial degree. If None, uses `n_levels - 1`. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `ndarray` | Contrast matrix of shape `(degree, n_levels)`. Rows are orthonormal polynomial basis vectors. |

Examples:

>>> build_poly_matrix(5)  # 4x5 matrix
>>> build_poly_matrix(5, degree=2)  # 2x5 matrix (linear + quadratic)

build_reference_grid

build_reference_grid(bundle: DataBundle, focal_vars: list[str], *, at: dict[str, Any] | None = None, covariate_means: dict[str, float] | None = None) -> pl.DataFrame

Construct reference grid for marginal effects evaluation.

Creates a grid of covariate values at which to evaluate predictions. This function transforms the parsed explore formula’s conditions into a concrete evaluation grid.

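The grid construction follows emmeans conventions: focal variables vary over all of their levels, conditioning variables are fixed at the values supplied in at, and continuous covariates are set to their means when covariate_means is provided.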

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `bundle` | `DataBundle` | DataBundle with model data and metadata. Used to extract `factor_levels` (dict mapping categorical variables to their levels) and `X_names` (column names for identifying variable types). | required |
| `focal_vars` | `list[str]` | Variables to vary across their range/levels. These become the rows of the output grid. | required |
| `at` | `dict[str, Any] \| None` | Fixed values for conditioning variables. Keys are variable names; values are the conditioning value(s). This dict is typically built from the Condition objects in the parsed explore formula. | `None` |
| `covariate_means` | `dict[str, float] \| None` | Pre-computed means for continuous covariates. If not provided, covariates are omitted from the grid (the caller must handle them separately or pass them via `at`). | `None` |

Returns:

| Type | Description |
| --- | --- |
| `DataFrame` | Polars DataFrame with one row per grid point. Columns include all focal variables and all conditioned variables (if any). The number of rows equals the number of combinations in the Cartesian product of focal variable levels. |

Integration with explore(): Called from model._compute_marginal_effects() after parsing::

    # From parsed explore formula "treatment ~ age@50"
    parsed = ExploreFormulaSpec(
        focal_var="treatment",
        conditions=[Condition(var="age", at_values=(50,))]
    )

    # Convert conditions to at dict
    at_dict = {"age": 50.0}

    # Build grid
    grid = build_reference_grid(
        bundle=self._bundle,
        focal_vars=["treatment"],
        at=at_dict,
    )
    # grid: pl.DataFrame with columns ["treatment"] and rows ["A", "B", "C"]
    # age is NOT in the grid (it's held fixed at 50)

Examples:

Grid for categorical focal::

    grid = build_reference_grid(bundle, ["treatment"])
    # Returns grid with one row per treatment level

Grid with conditioning::

    grid = build_reference_grid(bundle, ["treatment"], at={"age": 50})
    # Returns grid at age=50

Multiple focal variables (interaction EMMs)::

    grid = build_reference_grid(bundle, ["treatment", "sex"])
    # Returns Cartesian product: treatment x sex levels
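
The row count of a multi-focal grid is the size of the Cartesian product of the focal levels. Below is a minimal sketch that uses plain dictionaries in place of the Polars DataFrame. The `factor_levels` mapping mimics the bundle metadata described above. Conditioning values and covariate means are left out for brevity, and `reference_grid` is a hypothetical name.

```python
from itertools import product


def reference_grid(factor_levels: dict[str, list[str]], focal_vars: list[str]) -> list[dict[str, str]]:
    """Hypothetical: one row per combination of focal levels."""
    # product() varies the last focal variable fastest.
    combos = product(*(factor_levels[v] for v in focal_vars))
    return [dict(zip(focal_vars, combo)) for combo in combos]


levels = {"treatment": ["A", "B", "C"], "sex": ["F", "M"]}
grid = reference_grid(levels, ["treatment", "sex"])
print(len(grid), grid[:2])
```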

build_sequential_matrix

build_sequential_matrix(n_levels: int) -> np.ndarray

Build sequential (successive differences) contrasts.

Creates n_levels - 1 contrasts comparing each level to the previous one.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `n_levels` | `int` | Number of EMM levels (factor levels). | required |

Returns:

| Type | Description |
| --- | --- |
| `ndarray` | Contrast matrix of shape `(n_levels - 1, n_levels)`. Row i compares level i+1 to level i. |

Examples:

>>> C = build_sequential_matrix(3)
>>> C
array([[-1.,  1.,  0.],
       [ 0., -1.,  1.]])
>>> C = build_sequential_matrix(4)
>>> C
array([[-1.,  1.,  0.,  0.],
       [ 0., -1.,  1.,  0.],
       [ 0.,  0., -1.,  1.]])
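
Successive differences are exactly the row-wise differences of an identity matrix, which gives a compact numpy sketch. It is an illustration, not necessarily the module's implementation.

```python
import numpy as np


def sequential(n_levels: int) -> np.ndarray:
    # Row i of np.diff(I, axis=0) is e_{i+1} - e_i, i.e. level i+1 minus level i.
    return np.diff(np.eye(n_levels), axis=0)


print(sequential(4))
```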

build_sum_to_zero_matrix

build_sum_to_zero_matrix(n_levels: int) -> np.ndarray

Build sum-to-zero contrasts (deviation coding).

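Creates n_levels - 1 contrasts, each comparing one level to the unweighted grand mean of all levels. The contrast for the last level is omitted because it equals minus the sum of the others.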

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `n_levels` | `int` | Number of EMM levels. | required |

Returns:

| Type | Description |
| --- | --- |
| `ndarray` | Contrast matrix of shape `(n_levels - 1, n_levels)`. |

Examples:

>>> C = build_sum_to_zero_matrix(3)
>>> C
array([[ 0.667, -0.333, -0.333],
       [-0.333,  0.667, -0.333]])
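
Deviation contrasts can be written as each unit vector minus the vector of 1/n weights that forms the unweighted grand mean. The numpy sketch below keeps the first n_levels - 1 rows, matching the shape above. It is an illustration, not necessarily the module's code.

```python
import numpy as np


def sum_to_zero(n_levels: int) -> np.ndarray:
    # Row i is e_i - (1/n) * ones: level i minus the unweighted grand mean.
    full = np.eye(n_levels) - 1.0 / n_levels
    # The last row equals minus the sum of the others, so it is dropped.
    return full[: n_levels - 1]


print(np.round(sum_to_zero(3), 3))
```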

build_treatment_matrix

build_treatment_matrix(n_levels: int, ref_idx: int = 0) -> np.ndarray

Build treatment (Dunnett-style) contrasts against a reference level.

Creates n_levels - 1 contrasts, each comparing one non-reference level to the specified reference level.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `n_levels` | `int` | Number of EMM levels (factor levels). | required |
| `ref_idx` | `int` | Index of the reference level (0-based). | `0` |

Returns:

| Type | Description |
| --- | --- |
| `ndarray` | Contrast matrix of shape `(n_levels - 1, n_levels)`. |

Examples:

>>> C = build_treatment_matrix(3, ref_idx=0)
>>> C
array([[-1.,  1.,  0.],
       [-1.,  0.,  1.]])
>>> C = build_treatment_matrix(3, ref_idx=2)
>>> C
array([[ 1.,  0., -1.],
       [ 0.,  1., -1.]])
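
Below is a minimal numpy sketch that reproduces both doctest outputs. With ref_idx=0 it also coincides with the build_pairwise_matrix() example. This is an illustration rather than the module's implementation.

```python
import numpy as np


def treatment(n_levels: int, ref_idx: int = 0) -> np.ndarray:
    if not 0 <= ref_idx < n_levels:
        raise ValueError("ref_idx must index an existing level")
    others = [i for i in range(n_levels) if i != ref_idx]
    C = np.zeros((n_levels - 1, n_levels))
    C[np.arange(n_levels - 1), others] = 1.0  # each non-reference level ...
    C[:, ref_idx] = -1.0  # ... minus the reference level
    return C


print(treatment(3, ref_idx=2))
```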

combine_resolved

combine_resolved(a: ResolvedConditions, b: ResolvedConditions) -> ResolvedConditions

Merge two ResolvedConditions, with b taking precedence on conflicts.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `a` | `ResolvedConditions` | First resolved conditions. | required |
| `b` | `ResolvedConditions` | Second resolved conditions (takes precedence). | required |

Returns:

| Type | Description |
| --- | --- |
| `ResolvedConditions` | Merged ResolvedConditions. |

compose_contrast_matrix

compose_contrast_matrix(C: np.ndarray, X_ref: np.ndarray) -> np.ndarray

Compose contrast matrix with prediction matrix.

$$L_{\text{emm}} = C\,X_{\text{ref}}, \qquad L_{\text{emm}}\,\beta = C\,(X_{\text{ref}}\,\beta) = C\,\text{EMMs}$$
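
A small numeric check of the identity follows. It uses randomly generated X_ref and beta purely for illustration. Composing once with C @ X_ref gives the same contrasts as first forming the EMMs and then applying C.

```python
import numpy as np

rng = np.random.default_rng(0)
X_ref = rng.normal(size=(3, 4))  # 3 EMM rows x 4 coefficients (illustrative values)
beta = rng.normal(size=4)
C = np.array([[-1.0, 1.0, 0.0], [-1.0, 0.0, 1.0]])  # treatment-style contrasts

L_emm = C @ X_ref  # composed matrix, shape (n_contrasts, n_coef)
emms = X_ref @ beta  # EMMs, shape (n_emms,)
print(np.allclose(L_emm @ beta, C @ emms))  # True
```

Working with L_emm directly is convenient for inference, because the variance of the contrasts is then L_emm @ Var(beta) @ L_emm.T.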

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `C` | `ndarray` | Contrast matrix of shape `(n_contrasts, n_emms)`. | required |
| `X_ref` | `ndarray` | Prediction matrix of shape `(n_emms, n_coef)`. | required |

Returns:

| Type | Description |
| --- | --- |
| `ndarray` | Composed contrast matrix `L_emm` of shape `(n_contrasts, n_coef)`. |

compute_compound_bracket_contrasts

compute_compound_bracket_contrasts(bundle: object, fit: object, focal_var: str, contrast_expr: ContrastExpr, *, data: pl.DataFrame, spec: object | None = None, effect_scale: str = 'link', resolved: object | None = None) -> MeeState

Compute bracket contrasts for a compound focal variable.

Handles compound variables like Drug:Dose by building a crossed EMM grid over the component variables, creating compound level names, and then applying the bracket contrast.

For example, Drug:Dose[Active:High - Placebo:Low] builds EMMs over Drug × Dose, labels each cell as Active:High, Active:Low, etc., and applies the contrast Active:High - Placebo:Low.
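
The compound labels can be sketched with itertools.product. Two details here are assumptions: the cells are ordered with Drug varying slowest, and the single-contrast weight vector is built directly from the labels.

```python
from itertools import product

drug = ["Active", "Placebo"]
dose = ["High", "Low"]

# Compound cell labels over Drug x Dose (Drug varies slowest in this sketch).
cells = [f"{d}:{s}" for d, s in product(drug, dose)]

# Weights for the contrast Active:High - Placebo:Low.
weights = [1.0 if c == "Active:High" else -1.0 if c == "Placebo:Low" else 0.0 for c in cells]
print(cells)  # ['Active:High', 'Active:Low', 'Placebo:High', 'Placebo:Low']
print(weights)  # [1.0, 0.0, 0.0, -1.0]
```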

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `bundle` | `object` | DataBundle with model data and metadata. | required |
| `fit` | `object` | FitState with fitted coefficients. | required |
| `focal_var` | `str` | Compound variable name (e.g., `"Drug:Dose"`). | required |
| `contrast_expr` | `ContrastExpr` | ContrastExpr from the parser. | required |
| `data` | `DataFrame` | Model data DataFrame. | required |
| `spec` | `object \| None` | ModelSpec with link/family info. | `None` |
| `effect_scale` | `str` | Scale of estimates: `"link"` or `"response"`. | `'link'` |
| `resolved` | `object \| None` | ResolvedConditions for additional conditioning. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `MeeState` | MeeState with compound bracket contrast estimates. |

compute_conditional_emm

compute_conditional_emm(bundle: 'DataBundle', fit: 'FitState', focal_var: str, explore_formula: str, *, spec: object, varying_offsets: object, grouping_var: str, effect_scale: str = 'link', levels: list[str] | None = None, at_overrides: dict[str, float] | None = None, set_categoricals: dict[str, str] | None = None) -> MeeState

Compute per-group conditional EMMs incorporating intercept BLUPs.

For each group g and each focal level, the conditional EMM is:

$$\eta_g = X_{\text{ref,row}}\,\beta + b_{\text{intercept},g} \quad (+\ \text{other BLUP contributions})$$

When effect_scale=“response”: estimates = g⁻¹(η_g).
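
Below is a numpy sketch of the intercept-only case. It uses made-up numbers, with a logistic inverse link standing in for g⁻¹. `conditional_emm` is a hypothetical helper that ignores other BLUP contributions.

```python
import numpy as np


def conditional_emm(x_ref_row, beta, b_intercept, inverse_link=None):
    """eta_g = x_ref_row @ beta + b_intercept_g for every group g."""
    eta = np.asarray(x_ref_row) @ np.asarray(beta) + np.asarray(b_intercept)
    return eta if inverse_link is None else inverse_link(eta)


def inv_logit(eta):
    return 1.0 / (1.0 + np.exp(-eta))


x_ref_row = np.array([1.0, 2.0])  # intercept + one covariate at its reference value
beta = np.array([0.5, 0.25])  # fixed part of eta is 1.0
b_intercept = np.array([-1.0, 0.0, 0.5])  # intercept BLUPs for three groups

print(conditional_emm(x_ref_row, beta, b_intercept))  # link scale
print(conditional_emm(x_ref_row, beta, b_intercept, inv_logit))  # response scale
```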

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `bundle` | `'DataBundle'` | DataBundle with model data and metadata. | required |
| `fit` | `'FitState'` | FitState with fitted coefficients. | required |
| `focal_var` | `str` | Name of the categorical variable. | required |
| `explore_formula` | `str` | The explore formula string. | required |
| `spec` | `object` | ModelSpec with link function info. | required |
| `varying_offsets` | `object` | VaryingState with BLUPs per group. | required |
| `grouping_var` | `str` | Name of the grouping variable. | required |
| `effect_scale` | `str` | Scale of estimates: `"link"` or `"response"`. | `'link'` |
| `levels` | `list[str] \| None` | Optional list of focal levels. | `None` |
| `at_overrides` | `dict[str, float] \| None` | Optional covariate overrides. | `None` |
| `set_categoricals` | `dict[str, str] \| None` | Optional dict pinning non-focal categoricals. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `MeeState` | MeeState with per-(level, group) estimates. |

compute_conditional_slopes

compute_conditional_slopes(bundle: 'DataBundle', fit: 'FitState', focal_var: str, explore_formula: str, *, spec: object, varying_offsets: object, grouping_var: str, effect_scale: str = 'link') -> MeeState

Compute per-group conditional slopes incorporating BLUPs.

For each group, the conditional slope is:

$$\text{slope}_g = \begin{cases} \beta_x + b_{x,g} & \text{random slopes model} \\ \beta_x & \text{random intercepts only} \end{cases}$$

When effect_scale="response" and the link is non-identity, the response-scale slope is:

$$\text{slope}^{\text{response}}_g = \text{slope}^{\text{link}}_g \times \frac{d\mu}{d\eta}(\eta_g)$$

where η_g is the linear predictor at the reference point for group g.
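
For the logit link, dμ/dη = μ(1 - μ), which makes the chain-rule conversion easy to sketch. The helper below is hypothetical, assumes a logit link, and takes η_g as given.

```python
import numpy as np


def response_slopes_logit(beta_x, b_x, eta):
    """Hypothetical: per-group response-scale slopes under a logit link."""
    slope_link = beta_x + np.asarray(b_x)  # random slopes; pass zeros for intercepts only
    mu = 1.0 / (1.0 + np.exp(-np.asarray(eta)))
    return slope_link * mu * (1.0 - mu)  # chain rule: dmu/dx = dmu/deta * deta/dx


# Link-scale slopes 1.0 and 0.5 at eta = 0 (mu = 0.5) become 0.25 and 0.125.
print(response_slopes_logit(0.75, [0.25, -0.25], [0.0, 0.0]))
```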

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `bundle` | `'DataBundle'` | DataBundle with model data and metadata. | required |
| `fit` | `'FitState'` | FitState with fitted coefficients. | required |
| `focal_var` | `str` | Name of the continuous variable to get slopes for. | required |
| `explore_formula` | `str` | The explore formula string (for result metadata). | required |
| `spec` | `object` | ModelSpec with link function info. | required |
| `varying_offsets` | `object` | VaryingState with BLUPs per group. | required |
| `grouping_var` | `str` | Name of the grouping variable. | required |
| `effect_scale` | `str` | Scale of estimates: `"link"` or `"response"`. | `'link'` |

Returns:

| Type | Description |
| --- | --- |
| `MeeState` | MeeState with per-group slope estimates. |

compute_contrasts

compute_contrasts(emm: np.ndarray, contrast_matrix: np.ndarray) -> np.ndarray

Apply contrast matrix to EMMs.

Transforms a vector of estimated marginal means into contrast estimates by matrix multiplication. This is the final step when the explore formula includes a contrast function like pairwise(treatment).

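$$\text{contrasts} = C\,\text{EMM}$$

where C is the contrast matrix of shape (n_contrasts, n_levels) and EMM is the vector of estimated marginal means of shape (n_levels,).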

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `emm` | `ndarray` | Array of estimated marginal means from `compute_emm()`. Shape: `(n_levels,)`. | required |
| `contrast_matrix` | `ndarray` | Contrast matrix from `build_all_pairwise_matrix()`, `build_sequential_matrix()`, or `build_poly_matrix()`. Shape: `(n_contrasts, n_levels)`. Rows must be orthogonal to the constant vector (sum to 0). | required |

Returns:

| Type | Description |
| --- | --- |
| `ndarray` | Array of contrast estimates. Shape: `(n_contrasts,)`. |

Integration with explore(): Called after compute_emm when a contrast function is specified::

    # In model._compute_marginal_effects():
    if parsed.has_contrast:
        # First compute EMMs
        emms = compute_emm(...)

        # Build appropriate contrast matrix
        if parsed.contrast_type == "pairwise":
            C = build_all_pairwise_matrix(n_levels)
        elif parsed.contrast_type == "sequential":
            C = build_sequential_matrix(n_levels)
        elif parsed.contrast_type == "poly":
            C = build_poly_matrix(n_levels, degree=parsed.contrast_degree)

        # Apply contrasts
        contrast_estimates = compute_contrasts(emms, C)

        return build_mee_state(
            grid=contrast_grid,  # with contrast labels
            estimate=contrast_estimates,
            mee_type="contrasts",
            ...
        )

Note: For inference on contrasts, the variance of contrasts is: Var(C @ EMM) = C @ Var(EMM) @ C.T

This is handled by the inference methods (delta method for asymp, direct computation for bootstrap).
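
The variance formula can be checked numerically. The sketch below uses a cell-means setup, where X_ref is the identity so the EMMs are the coefficients, together with a made-up diagonal vcov. It computes standard errors as the square roots of the diagonal of C @ Var(EMM) @ C.T. The helper name `contrast_se` is hypothetical.

```python
import numpy as np


def contrast_se(C: np.ndarray, X_ref: np.ndarray, vcov_beta: np.ndarray) -> np.ndarray:
    """Standard errors for linear contrasts of EMMs."""
    var_emm = X_ref @ vcov_beta @ X_ref.T  # Var(EMM)
    cov = C @ var_emm @ C.T  # Var(C @ EMM) = C @ Var(EMM) @ C.T
    return np.sqrt(np.diag(cov))


C = np.array([[-1.0, 1.0, 0.0], [-1.0, 0.0, 1.0], [0.0, -1.0, 1.0]])  # B-A, C-A, C-B
X_ref = np.eye(3)  # cell-means coding: EMMs equal the coefficients
vcov_beta = np.diag([1.0, 4.0, 9.0])
print(contrast_se(C, X_ref, vcov_beta) ** 2)  # variances 5, 10, 13
```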

Examples:

Apply pairwise contrasts to 3-level EMMs::

    import numpy as np

    emm = np.array([2.0, 3.5, 2.8])  # A, B, C means
    C = build_all_pairwise_matrix(3)
    # C = [[-1, 1, 0],   # B - A
    #      [-1, 0, 1],   # C - A
    #      [0, -1, 1]]   # C - B

    contrasts = compute_contrasts(emm, C)
    # contrasts: array([1.5, 0.8, -0.7])

compute_emm

compute_emm(bundle: 'DataBundle', fit: 'FitState', focal_var: str, explore_formula: str, *, levels: list[str] | None = None, at_overrides: dict[str, float] | None = None, set_categoricals: dict[str, str] | None = None, spec: object | None = None, how: str = 'mem', effect_scale: str = 'link') -> MeeState

Compute estimated marginal means for a categorical focal variable.

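Supports two averaging methods via the how parameter: "mem" (the default) averages over a balanced reference grid (emmeans-style), while "ame" performs g-computation, averaging predictions over the observed data.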

For linear models (identity link), both approaches give identical results. For GLMs (non-identity link), they diverge because mean(g⁻¹(Xᵢβ)) ≠ g⁻¹(mean(Xᵢ) · β).
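A small NumPy illustration of that inequality, using made-up numbers: a logistic model with a single covariate and no focal factor. It does not attempt to mirror how compute_emm builds its grid. It only compares averaging after the inverse link with applying the inverse link at the averaged covariates.

```python
import numpy as np


def expit(eta):
    """Inverse logit link, g^-1."""
    return 1.0 / (1.0 + np.exp(-eta))


# Made-up logistic model: intercept plus one covariate.
beta = np.array([-1.0, 2.0])
x = np.array([-1.0, 0.0, 0.5, 3.0])
X = np.column_stack([np.ones_like(x), x])

g_computation = expit(X @ beta).mean()          # mean(g^-1(X_i beta))
reference_point = expit(X.mean(axis=0) @ beta)  # g^-1(mean(X_i) . beta)

# With the identity link, averaging commutes with the linear predictor.
identity_equal = np.isclose((X @ beta).mean(), X.mean(axis=0) @ beta)
```

Here the two logit-scale summaries come out near 0.452 and 0.562, while the identity-link versions agree exactly.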

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| bundle | DataBundle | DataBundle with model data and metadata. Used to extract X (design matrix for computing covariate means), X_names (column names to identify variable types), and factor_levels (level metadata for categorical variables). | required |
| fit | FitState | FitState with fitted coefficients (fit.coef array). | required |
| focal_var | str | Name of the categorical variable to compute EMMs for. | required |
| explore_formula | str | The explore formula string (for result metadata). | required |
| levels | list[str] \| None | Optional list of levels to compute EMMs for. If None, uses all levels from bundle.factor_levels. | None |
| at_overrides | dict[str, float] \| None | Optional dict of {variable: value} to fix specific covariates at given values instead of their means. Used for conditioning (e.g., cyl ~ wt@[3.0] fixes wt at 3.0). | None |
| set_categoricals | dict[str, str] \| None | Optional dict mapping non-focal categorical variable names to specific levels to pin them at (instead of marginalizing at column means), e.g. {"Ethnicity": "Asian"}. | None |
| spec | object \| None | ModelSpec with link/family info (needed for effect_scale="response"). | None |
| how | str | Averaging method: "mem" for balanced reference grid (emmeans-style), "ame" for g-computation over observed data. | 'mem' |
| effect_scale | str | Scale of estimates: "link" or "response". | 'link' |

Returns:

| Type | Description |
| --- | --- |
| MeeState | MeeState with grid of levels and their estimated means. |

Examples:

Compute EMMs with default (reference grid) averaging::

    mee = compute_emm(bundle, fit, "treatment", "treatment")

G-computation averaging for a logistic regression::

    mee = compute_emm(bundle, fit, "treatment", "treatment",
                      spec=spec, how="ame", effect_scale="response")

Note: For how="ame" with effect_scale="response" on a non-identity link, confidence intervals use the response-scale delta method (symmetric CIs) rather than the link-scale back-transformation (asymmetric CIs) used by how="mem".

compute_joint_test

compute_joint_test(fit: FitState, bundle: DataBundle, spec: ModelSpec, terms: list[str] | None = None, *, errors: str = 'iid', data: pl.DataFrame | None = None) -> JointTestState

Compute joint hypothesis tests for model terms.

Tests each model term (continuous variables, factors, and interactions) using direct coefficient tests. This matches R’s emmeans::joint_tests() behavior.
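The standard Wald form of a joint coefficient test can be sketched as follows. The sketch assumes each term maps to a known set of coefficient columns (idx) and computes only the statistic. It does not reproduce the package's term-to-column mapping, its df handling, or its family-dependent choice between F and chi-square. wald_term_test is a hypothetical name.

```python
import numpy as np


def wald_term_test(coef, vcov, idx):
    """Wald statistic for H0: coef[idx] == 0, jointly over one term's columns.

    Returns the chi-square form W = b' V^-1 b and the F form W / q,
    where q is the number of coefficients tested (numerator df).
    """
    idx = np.asarray(idx)
    b = coef[idx]
    V = vcov[np.ix_(idx, idx)]
    w = float(b @ np.linalg.solve(V, b))
    q = idx.size
    return w, w / q


# Made-up fit: intercept plus a 3-level factor coded as two columns.
coef = np.array([0.5, 1.0, -2.0])
vcov = np.diag([1.0, 0.25, 1.0])
chi2, F = wald_term_test(coef, vcov, idx=[1, 2])  # test the factor term
```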

For each term:

The test type is determined by the model family:

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| fit | FitState | FitState with fitted coefficients and vcov. | required |
| bundle | DataBundle | DataBundle with X_names for term structure. | required |
| spec | ModelSpec | ModelSpec for model type detection. | required |
| terms | list[str] \| None | Specific terms to test, or None for all terms. | None |
| errors | str | Error structure assumption. "iid" (default): standard errors from fit.vcov. "HC0" to "HC3", "hetero": sandwich SEs. "unequal_var": Welch ANOVA, i.e. sandwich SEs from per-cell variances with per-term Satterthwaite df2. For single-df terms this matches the Welch t-test; for multi-df terms (k-level factors), per-row dfs are combined via the Fai-Cornelius (1996) approximation. | 'iid' |
| data | DataFrame \| None | Original data frame, required for errors='unequal_var'. | None |

Returns:

| Type | Description |
| --- | --- |
| JointTestState | JointTestState with test results for each term. |

Examples:

Basic usage::

    from marginal import compute_joint_test
    state = compute_joint_test(fit, bundle, spec)
    # state.terms: ("x", "group", "x:group")
    # state.statistic: F or chi2 values per term
    # state.p_value: p-values per term

Welch ANOVA::

    state = compute_joint_test(fit, bundle, spec, errors="unequal_var", data=df)
    # Per-term Satterthwaite df2 in state.df2

Test specific terms only::

    state = compute_joint_test(fit, bundle, spec, terms=["group"])

Note: Intercept is never tested. Random effect grouping factors are excluded for mixed models.

compute_mee_inference

compute_mee_inference(mee: MeeState, vcov: np.ndarray, df_resid: float | np.ndarray | None, conf_level: float = 0.95, null: float = 0.0, alternative: str = 'two-sided') -> MeeState

Compute delta method inference for marginal effects.

SE = sqrt(diag(L @ vcov @ L'))

where L is the design matrix (X_ref for EMMs, selector for slopes).
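For the asymptotic case (df_resid=None), the computation can be sketched as follows. The sketch assumes a two-sided test and applies no multiplicity adjustment. It uses the standard-library NormalDist to avoid a SciPy dependency. delta_method_inference is a hypothetical helper, not the package's compute_mee_inference; a t-based version would use t quantiles instead.

```python
from statistics import NormalDist

import numpy as np


def delta_method_inference(L, beta, vcov, conf_level=0.95, null=0.0):
    """z-based delta-method inference for the linear functionals L @ beta."""
    estimate = L @ beta
    se = np.sqrt(np.diag(L @ vcov @ L.T))  # SE = sqrt(diag(L @ vcov @ L'))
    statistic = (estimate - null) / se
    nd = NormalDist()
    p_value = np.array([2.0 * (1.0 - nd.cdf(abs(z))) for z in statistic])  # two-sided
    crit = nd.inv_cdf(0.5 + conf_level / 2.0)
    return estimate, se, statistic, p_value, estimate - crit * se, estimate + crit * se


# Made-up example: two estimates that are just the coefficients themselves.
L = np.eye(2)
beta = np.array([1.0, 3.0])
vcov = np.diag([4.0, 9.0])
est, se, z, p, ci_lower, ci_upper = delta_method_inference(L, beta, vcov)
```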

Applies multiplicity adjustment for contrast CIs to match R’s emmeans:

For response-scale EMMs (effect_scale="response"), CIs are computed on the link scale then back-transformed via the inverse link function, matching R’s emmeans behavior.
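With a logit link, this back-transformation can be sketched as follows, assuming a z critical value (as when df_resid is None). The interval is symmetric around the estimate on the link scale but becomes asymmetric once the inverse link is applied. The function names are illustrative only.

```python
import math
from statistics import NormalDist


def expit(eta: float) -> float:
    """Inverse logit link."""
    return 1.0 / (1.0 + math.exp(-eta))


def back_transformed_ci(eta_hat: float, se_link: float, conf_level: float = 0.95):
    """Build a symmetric CI on the link scale, then map each endpoint through expit."""
    crit = NormalDist().inv_cdf(0.5 + conf_level / 2.0)
    lower_link = eta_hat - crit * se_link
    upper_link = eta_hat + crit * se_link
    return expit(eta_hat), expit(lower_link), expit(upper_link)


# Made-up link-scale EMM and standard error.
prob, lower, upper = back_transformed_ci(eta_hat=1.5, se_link=0.4)
```

Because expit is monotone, the endpoints stay inside (0, 1). Here the lower arm (about 0.146) is longer than the upper arm (about 0.090).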

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| mee | MeeState | MeeState with L_matrix from compute_emm or compute_slopes. | required |
| vcov | ndarray | Variance-covariance matrix of coefficients, shape (p, p). | required |
| df_resid | float \| ndarray \| None | Degrees of freedom for the t-distribution. Can be None (z-distribution, asymptotic normality), a scalar float (single residual df for all estimates), or an np.ndarray (per-estimate Satterthwaite df for mixed models). | required |
| conf_level | float | Confidence level for intervals (default 0.95). | 0.95 |
| null | float | Null hypothesis value (default 0.0). | 0.0 |
| alternative | str | Alternative hypothesis direction (default "two-sided"). | 'two-sided' |

Returns:

| Type | Description |
| --- | --- |
| MeeState | MeeState augmented with inference fields (se, df, statistic, p_value, ci_lower, ci_upper, conf_level). |

Examples:

Compute inference for EMMs::

    from marginal import compute_emm
    from inference import compute_mee_inference

    mee = compute_emm(bundle, fit, "treatment", "treatment")
    mee_with_inf = compute_mee_inference(mee, fit.vcov, fit.df_resid)
    # mee_with_inf.se: standard errors
    # mee_with_inf.ci_lower, mee_with_inf.ci_upper: confidence bounds

Note: For EMMs, L_matrix is the reference design matrix X_ref. For slopes, L_matrix is a selector row that picks the coefficient. For contrasts, L_matrix is the contrast matrix applied to EMMs.

compute_mee_inference_fallback

compute_mee_inference_fallback(mee: 'MeeState', bundle: 'DataBundle', fit: 'FitState', data: 'pl.DataFrame', df_resid: float | None, conf_level: float = 0.95, null: float = 0.0, alternative: str = 'two-sided') -> 'MeeState'

Compute inference for MEE without L_matrix (fallback path).

Used when L_matrix is not available (legacy MEE). Computes simplified SEs via compute_mee_se, then builds test statistics, p-values, and confidence intervals.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| mee | MeeState | MeeState without L_matrix. | required |
| bundle | DataBundle | Data bundle with X_names. | required |
| fit | FitState | Fit state with vcov, sigma, etc. | required |
| data | pl.DataFrame | Original data frame for group counting. | required |
| df_resid | float \| None | Residual degrees of freedom (None for z-distribution). | required |
| conf_level | float | Confidence level for intervals. | 0.95 |
| null | float | Null hypothesis value (default 0.0). | 0.0 |
| alternative | str | Alternative hypothesis direction (default "two-sided"). | 'two-sided' |

Returns:

| Type | Description |
| --- | --- |
| MeeState | MeeState augmented with inference fields. |

compute_mee_se

compute_mee_se(mee: 'MeeState', bundle: 'DataBundle', fit: 'FitState', data: 'pl.DataFrame') -> np.ndarray

Compute standard errors for MEE estimates (means or slopes).

When an L_matrix is available, uses the delta method: SE = sqrt(diag(L @ vcov @ L')). This correctly handles multi-row L_matrices (e.g. per-group conditional slopes).

Otherwise falls back to simplified type-based dispatch:

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| mee | MeeState | MeeState with type ("means" or "slopes") and focal_var. | required |
| bundle | DataBundle | Data bundle with X_names. | required |
| fit | FitState | Fit state with sigma, dispersion, residuals, vcov, df_resid. | required |
| data | pl.DataFrame | Original data frame for group counting. | required |

Returns:

| Type | Description |
| --- | --- |
| ndarray | Array of standard errors, one per estimate. |

compute_slopes

compute_slopes(bundle: 'DataBundle', fit: 'FitState', focal_var: str, explore_formula: str, *, spec: object | None = None, effect_scale: str = 'link') -> MeeState

Compute marginal slope for a continuous focal variable.

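For a linear model without interactions, the marginal slope is simply the variable's coefficient, which this function extracts. Models with interactions or a non-identity link require averaging across the grid and are routed to compute_slopes_finite_diff instead (see the note below).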

Marginal slopes answer: “By how much does the predicted outcome change for a one-unit increase in this variable?”

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| bundle | DataBundle | DataBundle with model data. Used to extract X_names, to locate the variable's coefficient index. | required |
| fit | FitState | FitState with fitted coefficients. | required |
| focal_var | str | Name of the continuous variable to get the slope for. | required |
| explore_formula | str | The explore formula string (for result metadata). | required |
| spec | object \| None | ModelSpec with link/family info (for effect_scale="response"). | None |
| effect_scale | str | Scale of estimates: "link" or "response". | 'link' |

Returns:

| Type | Description |
| --- | --- |
| MeeState | MeeState with marginal slope estimate. |

Examples:

Compute marginal effect of age::

    from marginal import compute_slopes
    mee = compute_slopes(bundle, fit, "age", "age")
    # For y ~ age: returns MeeState with estimate = [coef_age]

Note: This is the fast path for Gaussian identity-link models without interactions. For models with interactions or GLMs (non-identity link), the dispatcher routes to compute_slopes_finite_diff which handles both cases via centered finite differences.

compute_slopes_crossed

compute_slopes_crossed(bundle: 'DataBundle', fit: 'FitState', focal_var: str, *, resolved: object, data: pl.DataFrame, spec: 'ModelSpec | None' = None, formula_spec: 'FormulaSpec | None' = None, effect_scale: str = 'link', delta_frac: float = 0.001) -> MeeState

Compute crossed slopes over the focal variable × condition grid.

Builds a Cartesian product of condition levels (grid categoricals, grid numerics, scalar pins) and computes a finite-difference AME for each combination. Analogous to _compute_emm_crossed but for continuous focal variables.
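A sketch of that Cartesian product with itertools.product. It assumes the conditions are held in plain dicts shaped like the ResolvedConditions buckets; the values are made up. Rows follow itertools.product order (last variable varying fastest), which may differ from the package's layout.

```python
from itertools import product

# Made-up conditions shaped like the ResolvedConditions buckets.
grid_categoricals = {"Ethnicity": ["Asian", "Caucasian"]}
grid_numerics = {"Income": [10.0, 20.0, 30.0]}
scalar_pins = {"Age": 40.0}

names = list(grid_categoricals) + list(grid_numerics)
value_lists = list(grid_categoricals.values()) + list(grid_numerics.values())

# One dict per condition combination; scalar pins are attached to every row.
combos = [
    {**dict(zip(names, values)), **scalar_pins}
    for values in product(*value_lists)
]
# A finite-difference AME would then be computed once per entry of combos.
```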

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| bundle | DataBundle | DataBundle with model data and metadata. | required |
| fit | FitState | FitState with fitted coefficients. | required |
| focal_var | str | Continuous variable to compute slopes for. | required |
| resolved | object | ResolvedConditions with grid and scalar conditions. | required |
| data | DataFrame | Raw model data (for computing ranges and covariate means). | required |
| spec | ModelSpec \| None | ModelSpec with family/link info. | None |
| formula_spec | FormulaSpec \| None | FormulaSpec with learned encoding for evaluate_newdata. | None |
| effect_scale | str | "link" (linear predictor scale) or "response" (inverse-link / data scale). | 'link' |
| delta_frac | float | Fraction of the focal variable's range used as the finite-difference step size. | 0.001 |

Returns:

| Type | Description |
| --- | --- |
| MeeState | MeeState with one slope estimate per condition combination. |

compute_slopes_finite_diff

compute_slopes_finite_diff(bundle: DataBundle, fit: FitState, focal_var: str, explore_formula: str, *, spec: ModelSpec, formula_spec: FormulaSpec, data: pl.DataFrame, how: str = 'mem', effect_scale: str = 'link', delta_frac: float = 0.001) -> MeeState

Compute marginal slopes via centered finite differences.

Supports two averaging methods via the how parameter: "mem" evaluates the slope on a balanced reference grid, while "ame" averages it over the actual data rows.

For linear models (identity link), both approaches give identical results.

Algorithm:

  1. delta = delta_frac × range(focal_var)

  2. Build evaluation grid (balanced or observed data)

  3. Perturb: grid_plus = grid[focal + delta/2], grid_minus = grid[focal − delta/2]

  4. X_plus, X_minus = evaluate_newdata(formula_spec, grid_*)

  5. L_diff = (X_plus − X_minus) / delta (link-scale functional)

  6. If effect_scale="response" and how="ame": J_i = [f'(η_i+) X_i+ − f'(η_i−) X_i−] / delta (per-observation Jacobian, with f' evaluated at the perturbed points)

  7. L_avg = mean(J, axis=0, keepdims=True), where J is the per-row functional (on the link scale, J = L_diff)
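A NumPy sketch of this procedure for how="ame" with effect_scale="response" and a logit link, where f'(η) = μ(1 − μ). It assumes the observed rows serve as the evaluation grid. A hypothetical design() function with an x:z interaction stands in for evaluate_newdata. The last lines compute the analytic AME for comparison. The averaged Jacobian row is the gradient of the AME with respect to the coefficients, which is what delta-method inference needs.

```python
import numpy as np


def expit(eta):
    return 1.0 / (1.0 + np.exp(-eta))


def design(x, z):
    """Hypothetical stand-in for evaluate_newdata: columns 1, x, z, x:z."""
    return np.column_stack([np.ones_like(x), x, z, x * z])


def fd_ame_response(beta, x, z, delta_frac=1e-3):
    """Centered finite-difference AME of x on the response (probability) scale."""
    delta = delta_frac * (x.max() - x.min())  # step size from the focal range
    X_plus = design(x + delta / 2, z)          # perturb the focal variable up ...
    X_minus = design(x - delta / 2, z)         # ... and down, for every observed row
    mu_plus, mu_minus = expit(X_plus @ beta), expit(X_minus @ beta)
    ame = np.mean((mu_plus - mu_minus) / delta)
    # Per-row Jacobian [f'(eta+) X+ - f'(eta-) X-] / delta, with f'(eta) = mu (1 - mu)
    J = ((mu_plus * (1 - mu_plus))[:, None] * X_plus
         - (mu_minus * (1 - mu_minus))[:, None] * X_minus) / delta
    L_avg = J.mean(axis=0, keepdims=True)      # one-row functional for the delta method
    return ame, L_avg


rng = np.random.default_rng(1)
x = rng.uniform(0.0, 2.0, size=50)
z = rng.integers(0, 2, size=50).astype(float)
beta = np.array([-0.5, 0.8, 0.3, -0.4])
ame, L_avg = fd_ame_response(beta, x, z)

# Analytic AME for comparison: mean of mu (1 - mu) * (b_x + b_xz * z)
mu = expit(design(x, z) @ beta)
analytic_ame = np.mean(mu * (1 - mu) * (beta[1] + beta[3] * z))
```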

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| bundle | DataBundle | DataBundle with model data and metadata. | required |
| fit | FitState | FitState with fitted coefficients. | required |
| focal_var | str | Continuous variable to compute the slope for. | required |
| explore_formula | str | Explore formula string (for result metadata). | required |
| spec | ModelSpec | ModelSpec with family/link info. | required |
| formula_spec | FormulaSpec | FormulaSpec with learned encoding for evaluate_newdata. | required |
| data | DataFrame | Raw model data (for computing ranges and covariate means). | required |
| how | str | "mem" for balanced reference grid, "ame" for actual data rows. | 'mem' |
| effect_scale | str | "link" (linear predictor scale) or "response" (inverse-link / data scale). | 'link' |
| delta_frac | float | Fraction of the focal variable's range used as the finite-difference step size. Default 0.001 matches R's emmeans. | 0.001 |

Returns:

| Type | Description |
| --- | --- |
| MeeState | MeeState with a single-row AME estimate and L_matrix for delta-method inference. |

dispatch_marginal_computation

dispatch_marginal_computation(parsed: ExploreFormulaSpec, bundle: DataBundle, fit: FitState, data: pl.DataFrame, *, spec: ModelSpec | None = None, formula_spec: object | None = None, varying_offsets: VaryingState | None = None, effect_scale: str = 'link', varying: str = 'exclude', how: str = 'auto', inverse_transforms: bool = True, by: str | None = None) -> MeeState

Route a parsed explore formula to the appropriate marginal computation.

Validates the focal variable, determines whether it is categorical or continuous, and dispatches to the correct computation function.

Supports effect_scale= and varying= for scale transforms and conditional (group-specific) effects in mixed models. When RHS conditions contain bracket contrasts (e.g., Drug[A - B] ~ Dose[High - Low]), applies them as a post-processing step after the main computation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| parsed | ExploreFormulaSpec | Parsed explore formula with focal variable and contrast info. | required |
| bundle | DataBundle | DataBundle with model data and metadata. | required |
| fit | FitState | FitState with fitted coefficients. | required |
| data | DataFrame | The model's data DataFrame (for variable lookup and type detection). | required |
| spec | ModelSpec \| None | ModelSpec with link/family info (needed for effect_scale="response"). | None |
| formula_spec | object \| None | FormulaSpec with learned encoding (needed for finite-diff slopes). | None |
| varying_offsets | VaryingState \| None | VaryingState with BLUPs (needed for conditional effects). | None |
| effect_scale | str | "link" (default) or "response" (inverse-link / data scale). | 'link' |
| varying | str | "exclude" (default) or "include" (conditional effects). | 'exclude' |
| how | str | "auto" (default), "mem" (emmeans-style), or "ame" (g-computation / average marginal effect). | 'auto' |
| inverse_transforms | bool | When True (default), raw variable names and values are auto-resolved through learned formula transforms. Set to False to use transformed-scale names/values directly. | True |
| by | str \| None | Grouping variable for faceted effects (default: None). | None |

Returns:

| Type | Description |
| --- | --- |
| MeeState | MeeState with computed marginal effects. |

get_contrast_labels

get_contrast_labels(levels: list[str], contrast_type: str = 'pairwise') -> list[str]

Generate human-readable labels for contrasts.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| levels | list[str] | List of factor level names. | required |
| contrast_type | str | Type of contrast ("pairwise", "all_pairwise", or "sequential"). | 'pairwise' |

Returns:

| Type | Description |
| --- | --- |
| list[str] | List of contrast labels like "B - A", "C - A", etc. |

Examples:

>>> get_contrast_labels(["A", "B", "C"], "pairwise")
['B - A', 'C - A']
>>> get_contrast_labels(["A", "B", "C"], "all_pairwise")
['B - A', 'C - A', 'C - B']

parse_explore_formula

parse_explore_formula(formula: str, model_terms: list[str] | None = None) -> ExploreFormulaSpec

Parse an explore formula string.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| formula | str | Explore formula (e.g., "pairwise(treatment) ~ age@50"). | required |
| model_terms | list[str] \| None | Optional list of valid model terms for validation. | None |

Returns:

| Type | Description |
| --- | --- |
| ExploreFormulaSpec | Parsed ExploreFormulaSpec object. |

Examples:

Simple term::

>>> parse_explore_formula("treatment")
ExploreFormulaSpec(focal_var='treatment', contrast_type=None, conditions=())

Contrast function::

>>> parse_explore_formula("pairwise(treatment)")
ExploreFormulaSpec(focal_var='treatment', contrast_type='pairwise', conditions=())

With conditioning::

>>> parse_explore_formula("treatment ~ age@50")
ExploreFormulaSpec(
    focal_var='treatment',
    contrast_type=None,
    conditions=(Condition(var='age', at_values=(50.0,)),)
)

resolve_conditions

resolve_conditions(conditions: tuple[Condition, ...], bundle: DataBundle, data: pl.DataFrame) -> ResolvedConditions

Classify each Condition into the appropriate typed bucket.

Uses bundle.factor_levels to distinguish categorical from continuous variables, and the data DataFrame to compute range / quantile grids for continuous variables.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| conditions | tuple[Condition, ...] | Parsed Condition objects from the explore formula. | required |
| bundle | DataBundle | DataBundle with factor_levels metadata. | required |
| data | DataFrame | Original data DataFrame for computing range/quantile values. | required |

Returns:

| Type | Description |
| --- | --- |
| ResolvedConditions | ResolvedConditions with conditions classified into typed buckets. |

Modules

bracket_contrasts

Bracket contrast matrix builder and application.

Builds contrast weight matrices from bracket contrast expressions (e.g., Drug[Active - Placebo], Drug[* - Placebo], Drug[(A + B) - C]).

Auto-normalization: the weights on each side of the - operator are normalized so that the left side sums to +1 and the right side to -1. Wildcard expansion: * expands to all levels not mentioned on the other side, producing one contrast row per expanded level.
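Both rules can be sketched in plain NumPy. The sketch assumes the expression has already been parsed into a list of level names for each side; parsing is not shown. bracket_rows and side_weights are hypothetical helpers, not the package's build_bracket_contrast_matrix.

```python
import numpy as np


def side_weights(names, levels, sign):
    """Equal weights on the named levels, normalized so the side sums to sign (+1 or -1)."""
    w = np.zeros(len(levels))
    for name in names:
        w[levels.index(name)] = sign / len(names)
    return w


def bracket_rows(lhs, rhs, levels):
    """Hypothetical builder for lhs - rhs; a lone "*" expands to one row per unmentioned level."""
    if lhs == ["*"]:
        return np.vstack([side_weights([lv], levels, +1.0) + side_weights(rhs, levels, -1.0)
                          for lv in levels if lv not in rhs])
    if rhs == ["*"]:
        return np.vstack([side_weights(lhs, levels, +1.0) + side_weights([lv], levels, -1.0)
                          for lv in levels if lv not in lhs])
    return np.atleast_2d(side_weights(lhs, levels, +1.0) + side_weights(rhs, levels, -1.0))


levels = ["A", "B", "C", "Placebo"]
combined = bracket_rows(["A", "B"], ["C"], levels)    # Drug[(A + B) - C]
wildcard = bracket_rows(["*"], ["Placebo"], levels)   # Drug[* - Placebo]
```

Drug[(A + B) - C] yields a single row [0.5, 0.5, -1, 0]. Drug[* - Placebo] yields three rows: A, B and C, each against Placebo.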

Functions:

| Name | Description |
| --- | --- |
| apply_bracket_contrasts | Apply bracket contrast expression to an EMM MeeState. |
| apply_bracket_contrasts_grouped | Apply bracket contrasts within each condition group of a crossed MeeState. |
| apply_rhs_bracket_contrast | Apply a bracket contrast on a RHS condition column. |
| build_bracket_contrast_matrix | Build contrast matrix and labels from bracket contrast expression. |
| compute_compound_bracket_contrasts | Compute bracket contrasts for a compound focal variable. |
| dispatch_bracket_contrasts | Compute bracket contrast expression for a categorical focal variable. |

Functions

apply_bracket_contrasts
apply_bracket_contrasts(mee_state: MeeState, expr: ContrastExpr) -> MeeState

Apply bracket contrast expression to an EMM MeeState.

Computes contrasts by building a weight matrix from the bracket expression and multiplying by the EMM estimates.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| mee_state | MeeState | MeeState with EMM estimates. | required |
| expr | ContrastExpr | ContrastExpr from parser. | required |

Returns:

| Type | Description |
| --- | --- |
| MeeState | New MeeState with contrast estimates, type "contrasts", and contrast_method="custom". |

apply_bracket_contrasts_grouped
apply_bracket_contrasts_grouped(mee_state: MeeState, expr: ContrastExpr) -> MeeState

Apply bracket contrasts within each condition group of a crossed MeeState.

Similar to apply_contrasts_grouped() but for bracket contrast expressions. Applies the same contrast matrix independently within each group of n_focal consecutive rows.
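Assume rows are ordered with all focal levels of group 1 first, then all focal levels of group 2, and so on. Under that layout, applying one contrast matrix per group is a reshape followed by a matrix product. Equivalently, it is a single block-diagonal matrix built with np.kron. Below is a sketch with made-up estimates; apply_grouped is a hypothetical name.

```python
import numpy as np


def apply_grouped(estimates: np.ndarray, C: np.ndarray) -> np.ndarray:
    """Apply contrast matrix C within each block of n_focal consecutive rows."""
    n_focal = C.shape[1]
    blocks = estimates.reshape(-1, n_focal)  # row g holds the focal levels of group g
    return (blocks @ C.T).reshape(-1)        # contrasts for group 1, then group 2, ...


C = np.array([[-1.0, 1.0, 0.0],   # B - A
              [-1.0, 0.0, 1.0],   # C - A
              [0.0, -1.0, 1.0]])  # C - B
est = np.array([2.0, 3.5, 2.8,    # group 1: A, B, C
                1.0, 1.5, 3.0])   # group 2: A, B, C
out = apply_grouped(est, C)

# Equivalent single matrix: block-diagonal kron(I_groups, C)
out_kron = np.kron(np.eye(2), C) @ est
```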

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| mee_state | MeeState | MeeState from crossed EMM computation with n_focal × n_groups rows. | required |
| expr | ContrastExpr | ContrastExpr from parser. | required |

Returns:

| Type | Description |
| --- | --- |
| MeeState | MeeState with n_contrasts × n_groups rows and condition columns. |

apply_rhs_bracket_contrast
apply_rhs_bracket_contrast(mee_state: MeeState, expr: ContrastExpr) -> MeeState

Apply a bracket contrast on a RHS condition column.

After the main computation produces a crossed MeeState with a condition column (e.g., Dose with levels [High, Low]), this collapses that column by applying the bracket contrast.

The condition column must vary slowest in the grid (standard layout from crossed EMM/slope computation).
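Because the condition column varies slowest, each condition level occupies one contiguous block of rows. The collapse can therefore be written as the Kronecker product of the contrast weights with an identity matrix. Below is a sketch with made-up numbers; collapse_condition is a hypothetical helper, not the package's implementation.

```python
import numpy as np


def collapse_condition(estimates, c, n_inner):
    """Apply contrast weights c over a slowest-varying condition column.

    estimates holds len(c) blocks of n_inner consecutive rows, one block per
    condition level; kron(c, I) combines matching rows across blocks.
    """
    c = np.atleast_2d(np.asarray(c, dtype=float))  # (n_contrasts, n_condition_levels)
    return np.kron(c, np.eye(n_inner)) @ estimates


est = np.array([1.2, 0.4,   # Dose = High: two focal rows
                0.7, 0.1])  # Dose = Low: the same two focal rows
out = collapse_condition(est, [1.0, -1.0], n_inner=2)  # Dose[High - Low]
```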

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| mee_state | MeeState | MeeState from a crossed computation. Must contain expr.var as a column in the grid. | required |
| expr | ContrastExpr | ContrastExpr specifying the contrast on the condition variable. | required |

Returns:

| Type | Description |
| --- | --- |
| MeeState | New MeeState with the condition column replaced by contrast rows. |

build_bracket_contrast_matrix
build_bracket_contrast_matrix(expr: ContrastExpr, levels: list[str]) -> tuple[np.ndarray, list[str]]

Build contrast matrix and labels from bracket contrast expression.

Handles auto-normalization (each operand side sums to +1 or -1) and wildcard expansion (* expands to all unmentioned levels).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| expr | ContrastExpr | ContrastExpr AST from the parser. | required |
| levels | list[str] | Actual factor levels from the EMM grid. | required |

Returns:

| Type | Description |
| --- | --- |
| tuple[ndarray, list[str]] | Tuple of (contrast_matrix, labels), where contrast_matrix has shape (n_contrasts, n_levels) and labels has length n_contrasts. |

compute_compound_bracket_contrasts
compute_compound_bracket_contrasts(bundle: object, fit: object, focal_var: str, contrast_expr: ContrastExpr, *, data: pl.DataFrame, spec: object | None = None, effect_scale: str = 'link', resolved: object | None = None) -> MeeState

Compute bracket contrasts for a compound focal variable.

Handles compound variables like Drug:Dose by building a crossed EMM grid over the component variables, creating compound level names, and then applying the bracket contrast.

For example, Drug:Dose[Active:High - Placebo:Low] builds EMMs over Drug × Dose, labels each cell as Active:High, Active:Low, etc., and applies the contrast Active:High - Placebo:Low.
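The labelling step can be sketched with itertools.product. The sketch assumes the component level lists shown (made up for illustration) and a cell order in which the first component varies slowest. The resulting weights correspond to Drug:Dose[Active:High - Placebo:Low].

```python
from itertools import product

drug_levels = ["Active", "Placebo"]
dose_levels = ["High", "Low"]

# Drug varies slowest, Dose fastest: Active:High, Active:Low, ...
compound_levels = [f"{drug}:{dose}" for drug, dose in product(drug_levels, dose_levels)]

# Weights for Drug:Dose[Active:High - Placebo:Low] over the compound cells
weights = [0.0] * len(compound_levels)
weights[compound_levels.index("Active:High")] = 1.0
weights[compound_levels.index("Placebo:Low")] = -1.0
```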

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| bundle | object | DataBundle with model data and metadata. | required |
| fit | object | FitState with fitted coefficients. | required |
| focal_var | str | Compound variable name (e.g., "Drug:Dose"). | required |
| contrast_expr | ContrastExpr | ContrastExpr from the parser. | required |
| data | DataFrame | Model data DataFrame. | required |
| spec | object \| None | ModelSpec with link/family info. | None |
| effect_scale | str | Scale of estimates: "link" or "response". | 'link' |
| resolved | object \| None | ResolvedConditions for additional conditioning. | None |

Returns:

| Type | Description |
| --- | --- |
| MeeState | MeeState with compound bracket contrast estimates. |

dispatch_bracket_contrasts
dispatch_bracket_contrasts(bundle: DataBundle, fit: FitState, focal_var: str, contrast_expr: ContrastExpr, *, data: pl.DataFrame, spec: ModelSpec | None = None, effect_scale: str = 'link', how: str = 'mem', resolved: ResolvedConditions | None = None) -> MeeState

Compute bracket contrast expression for a categorical focal variable.

First computes EMMs, then applies the bracket contrast matrix. When resolved contains grid conditions, computes crossed EMMs and applies grouped bracket contrasts.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| bundle | DataBundle | DataBundle with model data and metadata. | required |
| fit | FitState | FitState with fitted coefficients. | required |
| focal_var | str | Name of the categorical variable. | required |
| contrast_expr | ContrastExpr | ContrastExpr AST from the parser. | required |
| data | DataFrame | Model data DataFrame. | required |
| spec | ModelSpec \| None | ModelSpec with link info (for effect_scale="response"). | None |
| effect_scale | str | Scale of estimates: "link" or "response". | 'link' |
| how | str | Averaging method: "mem" or "ame". | 'mem' |
| resolved | ResolvedConditions \| None | Resolved conditions for conditioning. | None |

Returns:

| Type | Description |
| --- | --- |
| MeeState | MeeState with bracket contrast estimates. |

compute

Marginal effects dispatch and routing.

Routes parsed explore formulas to the appropriate computation: EMMs for categorical focal variables, slopes for continuous, or contrasts. Supports effect_scale= (link/response scale), how= (mem/ame), and varying= (marginal/conditional).

Functions:

| Name | Description |
| --- | --- |
| compute_emm_categorical | Compute EMMs for a categorical focal variable. |
| dispatch_marginal_computation | Route a parsed explore formula to the appropriate marginal computation. |

Functions

compute_emm_categorical
compute_emm_categorical(bundle: DataBundle, fit: FitState, focal_var: str, *, levels: list[str] | None = None, at_overrides: dict[str, float] | None = None, set_categoricals: dict[str, str] | None = None, spec: ModelSpec | None = None, effect_scale: str = 'link', how: str = 'mem') -> MeeState

Compute EMMs for a categorical focal variable.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| bundle | DataBundle | DataBundle with model data and metadata. | required |
| fit | FitState | FitState with fitted coefficients. | required |
| focal_var | str | Name of the categorical variable. | required |
| levels | list[str] \| None | Optional subset of levels to compute EMMs for (e.g. from cyl@[4, 8] syntax). If None, all levels are used. | None |
| at_overrides | dict[str, float] \| None | Optional covariate overrides for conditioning. | None |
| set_categoricals | dict[str, str] \| None | Optional dict pinning non-focal categoricals to specific levels (e.g. {"Ethnicity": "Asian"}). | None |
| spec | ModelSpec \| None | ModelSpec with link info (for effect_scale="response"). | None |
| effect_scale | str | Scale of estimates: "link" or "response". | 'link' |
| how | str | Averaging method: "mem" or "ame". | 'mem' |

Returns:

| Type | Description |
| --- | --- |
| MeeState | MeeState with grid of levels and their estimated means. |

dispatch_marginal_computation

Same function as the package-level dispatch_marginal_computation documented above; see that entry for the signature, parameters, and return value.

conditions

Condition resolution for explore formula RHS conditioning.

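Resolves parsed Condition objects into typed buckets that downstream EMM / slope / contrast computation can consume: single numeric pins (at_overrides), single categorical pins (set_categoricals), multi-level categoricals to cross (grid_categoricals), and multi-value numerics to cross (grid_numerics).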

Classes:

| Name | Description |
| --- | --- |
| ResolvedConditions | Typed buckets for resolved conditioning specifications. |

Functions:

| Name | Description |
| --- | --- |
| combine_resolved | Merge two ResolvedConditions, with b taking precedence on conflicts. |
| get_column_values | Extract a data column as a numpy array for range/quantile computation. |
| resolve_conditions | Classify each Condition into the appropriate typed bucket. |

Classes

ResolvedConditions

Typed buckets for resolved conditioning specifications.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| at_overrides | dict[str, float] | Single numeric pin per variable (e.g. Income@50). |
| set_categoricals | dict[str, str] | Single categorical pin per variable (e.g. Ethnicity@Asian). |
| grid_categoricals | dict[str, list[str]] | Multi-level categoricals to cross (e.g. bare Ethnicity → all levels, or Ethnicity@(Asian, Caucasian) → those two levels). |
| grid_numerics | dict[str, list[float]] | Multi-value numerics to cross (e.g. Income@(10, 20) or Income@:range(5)). |

Attributes
at_overrides
at_overrides: dict[str, float]
grid_categoricals
grid_categoricals: dict[str, list[str]]
grid_numerics
grid_numerics: dict[str, list[float]]
has_grid
has_grid: bool

Return True if any condition requires grid expansion.

raw_at_overrides
raw_at_overrides: dict[str, float] = Factory(dict)

Pre-transform at-overrides with original (raw) variable names.

Populated by the dispatch layer after _resolve_at_overrides remaps keys to design-matrix names. The slopes path needs raw keys because it operates on a data grid with raw column names, not design-matrix names. Empty when no transform resolution was applied.

set_categoricals
set_categoricals: dict[str, str]

Functions

combine_resolved
combine_resolved(a: ResolvedConditions, b: ResolvedConditions) -> ResolvedConditions

Merge two ResolvedConditions, with b taking precedence on conflicts.
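A minimal sketch of these merge semantics, assuming a bucket-by-bucket dict merge. Resolved and combine are hypothetical stand-ins for ResolvedConditions and combine_resolved. The sketch resolves conflicts only within the same bucket. It does not cover how the package treats a variable that appears in different buckets of a and b.

```python
from dataclasses import dataclass, field


@dataclass
class Resolved:
    """Hypothetical stand-in for ResolvedConditions (four dict buckets)."""
    at_overrides: dict[str, float] = field(default_factory=dict)
    set_categoricals: dict[str, str] = field(default_factory=dict)
    grid_categoricals: dict[str, list[str]] = field(default_factory=dict)
    grid_numerics: dict[str, list[float]] = field(default_factory=dict)


def combine(a: Resolved, b: Resolved) -> Resolved:
    """Merge bucket by bucket; for a key present in both, b's value wins."""
    return Resolved(
        at_overrides={**a.at_overrides, **b.at_overrides},
        set_categoricals={**a.set_categoricals, **b.set_categoricals},
        grid_categoricals={**a.grid_categoricals, **b.grid_categoricals},
        grid_numerics={**a.grid_numerics, **b.grid_numerics},
    )


a = Resolved(at_overrides={"Income": 50.0, "Age": 30.0})
b = Resolved(at_overrides={"Income": 60.0}, set_categoricals={"Ethnicity": "Asian"})
merged = combine(a, b)
```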

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| a | ResolvedConditions | First resolved conditions. | required |
| b | ResolvedConditions | Second resolved conditions (takes precedence). | required |

Returns:

| Type | Description |
| --- | --- |
| ResolvedConditions | Merged ResolvedConditions. |

get_column_values
get_column_values(data: pl.DataFrame, bundle: DataBundle, var: str) -> np.ndarray

Extract a data column as a numpy array for range/quantile computation.

Looks in the original data DataFrame first, falling back to the design matrix column if the variable name matches an X_names entry.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data | DataFrame | Original data DataFrame. | required |
| bundle | DataBundle | DataBundle with design matrix. | required |
| var | str | Variable name. | required |

Returns:

| Type | Description |
| --- | --- |
| ndarray | 1-D numpy array of column values. |

resolve_conditions

Same function as the package-level resolve_conditions documented above; see that entry for the signature, parameters, and return value.

contrasts

Contrast computation for marginal effects.

Applies contrast matrices to EMM estimates. Matrix builders live in marginal.matrices.

- apply_contrasts: Apply contrasts to MeeState
- apply_contrasts_grouped: Apply contrasts within condition groups
- compute_contrasts: Apply contrast matrix to EMM vector

Functions:

| Name | Description |
| --- | --- |
| apply_contrasts | Apply contrast matrix to marginal means/effects. |
| apply_contrasts_grouped | Apply contrasts within each condition group of a crossed MeeState. |
| compute_contrasts | Apply contrast matrix to EMMs. |
| dispatch_contrasts | Compute contrasts for a categorical focal variable. |

Functions

apply_contrasts
apply_contrasts(mee_state: MeeState, contrast_type: str, fit: FitState | None = None, *, degree: int | None = None, ref_idx: int | None = None, level_ordering: tuple[str, ...] | None = None) -> MeeState

Apply contrast matrix to marginal means/effects.

High-level function that takes a MeeState with EMM estimates and returns a new MeeState with contrast estimates.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| mee_state | MeeState | MeeState containing EMM estimates from compute_emm(). | required |
| contrast_type | str | Type of contrast to apply: "pairwise" (all pairwise comparisons: B-A, C-A, C-B, ...); "sequential" (adjacent differences: B-A, C-B, D-C, ...); "poly" (orthogonal polynomial contrasts); "treatment" (each level vs reference, requires ref_idx); "sum" (each level vs grand mean); "helmert" (each level vs mean of previous levels). | required |
| fit | FitState \| None | FitState with vcov for variance propagation (optional, only needed if computing SE during this call). | None |
| degree | int \| None | Maximum polynomial degree for poly contrasts (default: n-1). | None |
| ref_idx | int \| None | Reference level index for treatment contrasts (0-based). | None |
| level_ordering | tuple[str, ...] \| None | Explicit level ordering for order-dependent contrasts (e.g. from poly(dose, [low, med, high])). When provided, EMM estimates and grid are reordered to match before applying the contrast matrix. | None |

Returns:

- MeeState: New MeeState with:
  - grid: DataFrame with contrast labels
  - estimate: contrast estimates
  - type: "contrasts"
  - focal_var: same as input
  - explore_formula: updated to reflect the contrast

Examples:

>>> mee = compute_emm(bundle, fit, "treatment", "treatment")
>>> contrasts = apply_contrasts(mee, "pairwise")
>>> contrasts.estimate  # B-A, C-A, C-B differences

Note: Variance propagation for inference is deferred to .infer(). This function only computes point estimates.
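The NumPy sketch below shows, for n levels, the kind of matrix each contrast_type could produce (rows are contrasts, columns are levels). It is not the library's build_*_matrix implementation. The function contrast_matrix is hypothetical, "sum" is assumed to return one row per level, and "poly" is omitted.

```python
import numpy as np


def contrast_matrix(n, kind, ref_idx=0):
    """Hypothetical contrast matrices for n EMM levels (rows = contrasts)."""
    eye = np.eye(n)
    if kind == "pairwise":      # B-A, C-A, C-B, ...
        rows = [eye[j] - eye[i] for i in range(n) for j in range(i + 1, n)]
    elif kind == "sequential":  # B-A, C-B, D-C, ...
        rows = [eye[i + 1] - eye[i] for i in range(n - 1)]
    elif kind == "treatment":   # each level vs the reference level
        rows = [eye[i] - eye[ref_idx] for i in range(n) if i != ref_idx]
    elif kind == "sum":         # each level vs the grand mean (assumed n rows)
        rows = [eye[i] - np.full(n, 1.0 / n) for i in range(n)]
    elif kind == "helmert":     # level k vs the mean of levels 0..k-1
        rows = [eye[k] - np.r_[np.full(k, 1.0 / k), np.zeros(n - k)]
                for k in range(1, n)]
    else:
        raise ValueError(f"unknown contrast type: {kind}")
    return np.vstack(rows)


emm = np.array([2.0, 3.5, 2.8])                # A, B, C means
print(contrast_matrix(3, "pairwise") @ emm)    # approximately [1.5, 0.8, -0.7]
print(contrast_matrix(3, "helmert"))
```

Every row sums to zero, so each contrast is unaffected by adding a constant to all EMMs.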

apply_contrasts_grouped
apply_contrasts_grouped(mee_state: MeeState, contrast_type: str | None, *, degree: int | None = None, ref_idx: int | None = None, level_ordering: tuple[str, ...] | None = None) -> MeeState

Apply contrasts within each condition group of a crossed MeeState.

When EMMs have been computed over a crossed grid (focal levels x condition groups), this function applies the contrast matrix independently within each group of n_focal consecutive rows, then stacks the results.

The number of focal levels is inferred from the focal_var column of mee_state.grid.
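The grouping step can be sketched as follows, assuming the estimates arrive in the row order described for mee_state below (all focal levels of group 1, then all of group 2, and so on). The helper apply_contrasts_by_group is hypothetical. It reshapes the estimate vector into one row per group, applies C to each row, and flattens the result back in group-major order.

```python
import numpy as np


def apply_contrasts_by_group(estimates, C, n_focal):
    """Apply C within each block of n_focal consecutive rows, then stack the blocks."""
    est = np.asarray(estimates, dtype=float)
    n_groups, remainder = divmod(est.size, n_focal)
    if remainder:
        raise ValueError("number of estimates must be a multiple of n_focal")
    blocks = est.reshape(n_groups, n_focal)   # row g holds the focal levels of group g
    return (blocks @ C.T).ravel()             # group-major: all contrasts of group 1 first


C = np.array([[-1.0, 1.0, 0.0],   # sequential: B - A
              [0.0, -1.0, 1.0]])  #             C - B
est = [2.0, 3.5, 2.8,   # group 1: A, B, C
       1.0, 1.5, 3.0]   # group 2: A, B, C
out = apply_contrasts_by_group(est, C, n_focal=3)

# The same estimates come from a single block-diagonal matrix kron(I_groups, C)
same = np.kron(np.eye(2), C) @ np.asarray(est)
```

The block-diagonal form is also one way a per-row L matrix could be carried along for later variance propagation. That is an assumption about design, not a description of this module.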

Parameters:

- mee_state (MeeState, required): MeeState from _compute_emm_crossed() with n_focal x n_groups rows, ordered as [focal_1_group_1, ..., focal_k_group_1, focal_1_group_2, ...].
- contrast_type (str | None, required): Type of contrast (pairwise, sequential, poly, treatment, sum, helmert).
- degree (int | None, default None): Degree for polynomial contrasts.
- ref_idx (int | None, default None): Reference level index for treatment contrasts.
- level_ordering (tuple[str, ...] | None, default None): Explicit level ordering for order-dependent contrasts. When provided, EMM rows within each group are reordered before the contrast matrix is applied.

Returns:

- MeeState: MeeState with n_contrasts x n_groups rows and condition columns.
compute_contrasts
compute_contrasts(emm: np.ndarray, contrast_matrix: np.ndarray) -> np.ndarray

Apply contrast matrix to EMMs.

Transforms a vector of estimated marginal means into contrast estimates by matrix multiplication. This is the final step when the explore formula includes a contrast function like pairwise(treatment).

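contrasts = C @ EMM

where C is the contrast matrix, shape (n_contrasts, n_levels), and EMM is the vector of estimated marginal means, shape (n_levels,).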

Parameters:

- emm (ndarray, required): Array of estimated marginal means from compute_emm(). Shape: (n_levels,).
- contrast_matrix (ndarray, required): Contrast matrix from build_all_pairwise_matrix(), build_sequential_matrix(), or build_poly_matrix(). Shape: (n_contrasts, n_levels). Each row must be orthogonal to the constant vector (i.e., sum to 0).

Returns:

- ndarray: Array of contrast estimates, shape (n_contrasts,).

Integration with explore(): Called after compute_emm when a contrast function is specified::

    # In model._compute_marginal_effects():
    if parsed.has_contrast:
        # First compute EMMs
        emms = compute_emm(...)

        # Build appropriate contrast matrix
        if parsed.contrast_type == "pairwise":
            C = build_all_pairwise_matrix(n_levels)
        elif parsed.contrast_type == "sequential":
            C = build_sequential_matrix(n_levels)
        elif parsed.contrast_type == "poly":
            C = build_poly_matrix(n_levels, degree=parsed.contrast_degree)

        # Apply contrasts
        contrast_estimates = compute_contrasts(emms, C)

        return build_mee_state(
            grid=contrast_grid,  # with contrast labels
            estimate=contrast_estimates,
            mee_type="contrasts",
            ...
        )

Note: For inference on contrasts, the variance of contrasts is: Var(C @ EMM) = C @ Var(EMM) @ C.T

This is handled by the inference methods (delta method for asymp, direct computation for bootstrap).
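The identity Var(C @ EMM) = C @ Var(EMM) @ C.T can be checked numerically. The sketch below uses made-up coefficients and a made-up vcov for a treatment-coded three-level factor and builds the EMMs as L @ beta. It then confirms that propagating through Var(EMM) gives the same result as composing K = C @ L and applying the delta method to vcov directly.

```python
import numpy as np

# Hypothetical fit: intercept (level A) plus dummies for B and C
beta = np.array([2.0, 1.5, 0.8])
vcov = np.array([[0.04, -0.04, -0.04],
                 [-0.04, 0.08, 0.04],
                 [-0.04, 0.04, 0.08]])
L = np.array([[1.0, 0.0, 0.0],    # EMM of A
              [1.0, 1.0, 0.0],    # EMM of B
              [1.0, 0.0, 1.0]])   # EMM of C
C = np.array([[-1.0, 1.0, 0.0],   # B - A
              [-1.0, 0.0, 1.0],   # C - A
              [0.0, -1.0, 1.0]])  # C - B

emm = L @ beta                                     # [2.0, 3.5, 2.8]
var_emm = L @ vcov @ L.T                           # Var(EMM)
se_via_emm = np.sqrt(np.diag(C @ var_emm @ C.T))   # Var(C @ EMM) = C Var(EMM) C'
K = C @ L                                          # contrasts as linear combos of beta
se_via_beta = np.sqrt(np.diag(K @ vcov @ K.T))
print(se_via_emm, se_via_beta)                     # both ≈ [0.283, 0.283, 0.283]
```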

Examples:

Apply pairwise contrasts to 3-level EMMs::

import numpy as np

emm = np.array([2.0, 3.5, 2.8])  # A, B, C means
C = build_all_pairwise_matrix(3)
# C = [[-1, 1, 0],   # B - A
#      [-1, 0, 1],   # C - A
#      [0, -1, 1]]   # C - B

contrasts = compute_contrasts(emm, C)
# contrasts: array([1.5, 0.8, -0.7])
dispatch_contrasts
dispatch_contrasts(bundle: DataBundle, fit: FitState, focal_var: str, contrast_type: str | None, contrast_degree: int | None, data: pl.DataFrame, *, spec: ModelSpec | None = None, effect_scale: str = 'link', how: str = 'mem', resolved: ResolvedConditions | None = None, focal_at_values: tuple[float | str, ...] | None = None, contrast_ref: str | None = None, level_ordering: tuple[str, ...] | None = None) -> MeeState

Compute contrasts for a categorical focal variable.

First computes EMMs, then applies the requested contrast matrix. When resolved contains grid conditions, computes crossed EMMs and applies grouped contrasts. When focal_at_values is provided (e.g. from pairwise(cyl@[4, 8])), contrasts are computed only over the requested subset of levels.
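This sketch of the level-selection step assumes that levels are stored as strings and that numeric at-values match their integer labels. For example, 4.0 from cyl@[4, 8] should match "4", since the parser example for age@50 yields (50.0,). The helpers select_levels and _as_level are hypothetical. The same index lookup handles both a focal_at_values subset and a level_ordering reorder, and it also shows how a contrast_ref name could map to ref_idx.

```python
import numpy as np


def _as_level(value):
    # Assumption: numeric at-values such as 4.0 map to level labels like "4"
    if isinstance(value, float) and value.is_integer():
        return str(int(value))
    return str(value)


def select_levels(levels, emm, wanted):
    """Subset and/or reorder EMMs to the requested levels (hypothetical helper)."""
    idx = [levels.index(_as_level(v)) for v in wanted]   # ValueError if a level is unknown
    return [levels[i] for i in idx], np.asarray(emm, dtype=float)[idx]


levels = ["4", "6", "8"]               # e.g. levels of cyl, stored as strings
emm = np.array([26.7, 19.7, 15.1])     # hypothetical EMMs, one per level

sub_levels, sub_emm = select_levels(levels, emm, (4.0, 8.0))  # pairwise(cyl@[4, 8])
diff_8_vs_4 = sub_emm[1] - sub_emm[0]  # the single pairwise contrast: 8 - 4

ref_idx = levels.index("6")            # contrast_ref name -> 0-based ref_idx
```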

Parameters:

- bundle (DataBundle, required): DataBundle with model data and metadata.
- fit (FitState, required): FitState with fitted coefficients.
- focal_var (str, required): Name of the categorical variable.
- contrast_type (str | None, required): Type of contrast (pairwise, sequential, poly, treatment, sum, helmert).
- contrast_degree (int | None, required): Degree for polynomial contrasts (None = maximum degree).
- data (DataFrame, required): Model data for level extraction.
- spec (ModelSpec | None, default None): ModelSpec with link info (for effect_scale="response").
- effect_scale (str, default 'link'): Scale of estimates: "link" or "response".
- how (str, default 'mem'): Averaging method: "mem" or "ame".
- resolved (ResolvedConditions | None, default None): Resolved conditions for conditioning.
- focal_at_values (tuple[float | str, ...] | None, default None): Optional subset of levels from at-spec syntax (e.g. pairwise(cyl@[4, 8])).
- contrast_ref (str | None, default None): Reference level name for treatment contrasts.
- level_ordering (tuple[str, ...] | None, default None): Explicit level ordering for order-dependent contrasts (e.g. from poly(dose, [low, med, high])).

Returns:

- MeeState: MeeState with contrast estimates.

emm

Estimated marginal means computation.

This module provides EMM computation for categorical focal variables.

compute_emm: Compute estimated marginal means at grid points

Functions:

- compute_conditional_emm: Compute per-group conditional EMMs incorporating intercept BLUPs.
- compute_emm: Compute estimated marginal means for a categorical focal variable.
- compute_emm_crossed: Compute crossed EMMs over focal levels x condition grid.

Functions

compute_conditional_emm
compute_conditional_emm(bundle: 'DataBundle', fit: 'FitState', focal_var: str, explore_formula: str, *, spec: object, varying_offsets: object, grouping_var: str, effect_scale: str = 'link', levels: list[str] | None = None, at_overrides: dict[str, float] | None = None, set_categoricals: dict[str, str] | None = None) -> MeeState

Compute per-group conditional EMMs incorporating intercept BLUPs.

For each group g and each focal level, the conditional EMM is: η_g = X_ref_row @ β + b_intercept_g (+ other BLUP contributions)

When effect_scale=“response”: estimates = g⁻¹(η_g).
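The following NumPy sketch implements the formula above for intercept-only BLUPs. It assumes one reference row per focal level (X_ref) and one intercept deviation per group. conditional_emm and expit are hypothetical helpers. Other BLUP contributions, such as random slopes, are left out.

```python
import numpy as np


def expit(x):
    return 1.0 / (1.0 + np.exp(-x))


def conditional_emm(X_ref, beta, b_intercept, inverse_link=None):
    """eta[l, g] = X_ref[l] @ beta + b_intercept[g]; optionally mapped through g^-1."""
    eta = (X_ref @ beta)[:, None] + np.asarray(b_intercept, dtype=float)[None, :]
    return eta if inverse_link is None else inverse_link(eta)


X_ref = np.array([[1.0, 0.0],    # reference row for level A
                  [1.0, 1.0]])   # reference row for level B
beta = np.array([0.0, 1.0])
b_intercept = [-0.5, 0.5]        # hypothetical intercept BLUPs for two groups

eta = conditional_emm(X_ref, beta, b_intercept)            # link scale, shape (2, 2)
prob = conditional_emm(X_ref, beta, b_intercept, expit)    # response scale
```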

Parameters:

- bundle (DataBundle, required): DataBundle with model data and metadata.
- fit (FitState, required): FitState with fitted coefficients.
- focal_var (str, required): Name of the categorical variable.
- explore_formula (str, required): The explore formula string.
- spec (object, required): ModelSpec with link function info.
- varying_offsets (object, required): VaryingState with BLUPs per group.
- grouping_var (str, required): Name of the grouping variable.
- effect_scale (str, default 'link'): Scale of estimates: "link" or "response".
- levels (list[str] | None, default None): Optional list of focal levels.
- at_overrides (dict[str, float] | None, default None): Optional covariate overrides.
- set_categoricals (dict[str, str] | None, default None): Optional dict pinning non-focal categoricals.

Returns:

- MeeState: MeeState with per-(level, group) estimates.
compute_emm
compute_emm(bundle: 'DataBundle', fit: 'FitState', focal_var: str, explore_formula: str, *, levels: list[str] | None = None, at_overrides: dict[str, float] | None = None, set_categoricals: dict[str, str] | None = None, spec: object | None = None, how: str = 'mem', effect_scale: str = 'link') -> MeeState

Compute estimated marginal means for a categorical focal variable.

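Supports two averaging methods via the how parameter:

- "mem": evaluate predictions on a balanced reference grid (emmeans-style).
- "ame": g-computation; predict for every observed row with the focal variable set to each level, then average the predictions.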

For linear models (identity link), both approaches give identical results. For GLMs (non-identity link), they diverge because mean(g⁻¹(Xᵢβ)) ≠ g⁻¹(mean(Xᵢ) · β).
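The divergence for non-identity links follows from Jensen's inequality. The sketch below uses simulated data and a logit link. For simplicity, "mem" is approximated by plugging in the column means of the observed design rows rather than a full balanced reference grid, which is an assumption of this sketch. mem_and_ame is a hypothetical helper.

```python
import numpy as np


def expit(x):
    return 1.0 / (1.0 + np.exp(-x))


def mem_and_ame(X, beta, inverse_link):
    """X: observed rows with the focal level already set; returns (MEM, AME)."""
    mem = inverse_link(X.mean(axis=0) @ beta)   # g^-1(mean(X_i) . beta)
    ame = inverse_link(X @ beta).mean()         # mean(g^-1(X_i beta))
    return float(mem), float(ame)


rng = np.random.default_rng(0)
x = rng.normal(size=500)                                  # a continuous covariate
X_B = np.column_stack([np.ones(500), np.ones(500), x])    # intercept, level-B dummy, x
beta = np.array([-1.0, 0.5, 2.0])

print(mem_and_ame(X_B, beta, lambda z: z))  # identity link: the two values agree
print(mem_and_ame(X_B, beta, expit))        # logit link: the two values differ
```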

Parameters:

- bundle (DataBundle, required): DataBundle with model data and metadata. Used to extract:
  - X: design matrix for computing covariate means
  - X_names: column names to identify variable types
  - factor_levels: level metadata for categorical variables
- fit (FitState, required): FitState with fitted coefficients (fit.coef array).
- focal_var (str, required): Name of the categorical variable to compute EMMs for.
- explore_formula (str, required): The explore formula string (for result metadata).
- levels (list[str] | None, default None): Optional list of levels to compute EMMs for. If None, uses all levels from bundle.factor_levels.
- at_overrides (dict[str, float] | None, default None): Optional dict of {variable: value} that fixes specific covariates at the given values instead of their means. Used for conditioning (e.g., cyl ~ wt@[3.0] fixes wt at 3.0).
- set_categoricals (dict[str, str] | None, default None): Optional dict mapping non-focal categorical variable names to the levels they are pinned at (instead of being marginalized at column means), e.g. {"Ethnicity": "Asian"}.
- spec (object | None, default None): ModelSpec with link/family info (needed for effect_scale="response").
- how (str, default 'mem'): Averaging method: "mem" for a balanced reference grid (emmeans-style), "ame" for g-computation over the observed data.
- effect_scale (str, default 'link'): Scale of estimates: "link" or "response".

Returns:

- MeeState: MeeState with a grid of levels and their estimated means.

Examples:

Compute EMMs with default (reference grid) averaging::

mee = compute_emm(bundle, fit, "treatment", "treatment")

G-computation averaging for a logistic regression::

mee = compute_emm(bundle, fit, "treatment", "treatment",
                  spec=spec, how="ame", effect_scale="response")

Note: For how="ame" with effect_scale="response" on a non-identity link, confidence intervals use the response-scale delta method (symmetric CIs) rather than the link-scale back-transformation (asymmetric CIs) used by how="mem".

compute_emm_crossed
compute_emm_crossed(bundle: DataBundle, fit: FitState, focal_var: str, *, resolved: ResolvedConditions, levels: list[str] | None = None, spec: ModelSpec | None = None, effect_scale: str = 'link', how: str = 'mem') -> MeeState

Compute crossed EMMs over focal levels x condition grid.

Builds a Cartesian product of focal levels with all grid conditions, computing one EMM row per combination.
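In the sketch below, condition groups form the outer loop and focal levels form the inner loop. This produces the [focal_1_group_1, ..., focal_k_group_1, focal_1_group_2, ...] row order that apply_contrasts_grouped expects. The sketch uses itertools.product, and the condition names and values are hypothetical.

```python
from itertools import product

focal_levels = ["A", "B", "C"]
grid_conditions = {"sex": ["F", "M"], "age": [30.0, 50.0]}   # hypothetical grid conditions

names = list(grid_conditions)
rows = []
for combo in product(*grid_conditions.values()):   # outer loop: condition groups
    condition = dict(zip(names, combo))
    for level in focal_levels:                      # inner loop: focal levels
        rows.append({"treatment": level, **condition})

print(len(rows))   # 12 = 3 focal levels x 4 condition groups
print(rows[3])     # {'treatment': 'A', 'sex': 'F', 'age': 50.0}
```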

Supports two averaging methods via how:

- "mem": balanced reference grid (emmeans-style).
- "ame": g-computation over the observed data.

Parameters:

- bundle (DataBundle, required): DataBundle with model data and metadata.
- fit (FitState, required): FitState with fitted coefficients.
- focal_var (str, required): Name of the categorical focal variable.
- resolved (ResolvedConditions, required): ResolvedConditions with grid and scalar conditions.
- levels (list[str] | None, default None): Optional subset of focal levels to include (e.g. from treatment@[A, B]). If None, all levels are used.
- spec (ModelSpec | None, default None): ModelSpec with link info (for effect_scale="response").
- effect_scale (str, default 'link'): Scale of estimates: "link" or "response".
- how (str, default 'mem'): Averaging method: "mem" or "ame".

Returns:

- MeeState: MeeState with n_focal x n_grid rows and condition columns.

explore

Explore formula parser.

Tokenizes and parses explore formula strings into ExploreFormulaSpec containers using ExploreScanner and ExploreParser.

Grammar::

explore_formula := lhs [ '~' rhs ]

lhs := focal_term | contrast_fn '(' focal_term { ',' arg } ')'
focal_term := varname [ '@' at_spec ]
at_spec := value | '[' value_list ']' | [':']'range(' n ')' | [':']'quantile(' n ')'
contrast_fn := 'pairwise' | 'sequential' | 'poly'
             | 'treatment' | 'dummy' | 'sum' | 'deviation' | 'helmert'
arg := IDENTIFIER '=' value | value

rhs := condition [ '+' condition ]*
condition := term [ '@' at_spec ]

Classes:

- Condition: A conditioning specification in explore formula.
- ExploreFormulaError: Error in explore formula syntax.
- ExploreFormulaSpec: Parsed explore formula.

Functions:

- parse_explore_formula: Parse an explore formula string.
- parse_lhs: Parse LHS of explore formula.

Classes

Condition

A conditioning specification in explore formula.

Represents a variable to condition on, optionally with specific values, a method for generating values, or a bracket contrast expression.

Created by: parse_explore_formula(), ExploreParser
Consumed by: resolve_all_conditions(), dispatch_marginal_computation()
Augmented by: never

Attributes:

- var (str): Variable name to condition on.
- at_values (tuple | None): Specific values to evaluate at (e.g., (50.0,) or ("A", "B")).
- at_range (int | None): Number of evenly spaced values across the variable's range.
- at_quantile (int | None): Number of quantile values to use.
- contrast_expr (ContrastExpr | None): Bracket contrast expression on this condition variable (e.g., from Dose[High - Low] on the RHS). When set, the variable is treated as a grid categorical during condition resolution, and the contrast is applied as a post-processing step.
Attributes
at_quantile
at_quantile: int | None = field(default=None, validator=is_optional_positive_int)
at_range
at_range: int | None = field(default=None, validator=is_optional_positive_int)
at_values
at_values: tuple | None = field(default=None)
contrast_expr
contrast_expr: ContrastExpr | None = field(default=None, validator=(validators.optional(validators.instance_of(ContrastExpr))))
var
var: str = field(validator=(validators.instance_of(str)))
ExploreFormulaError
ExploreFormulaError(message: str, formula: str, position: int | None = None) -> None

Bases: ValueError

Error in explore formula syntax.

Provides helpful error messages with position indicators for syntax errors.

Parameters:

- message (str, required): Error description.
- formula (str, required): The formula that caused the error.
- position (int | None, default None): Character position of the error (optional).

Attributes:

- formula (str): The formula that caused the error.
- position (int | None): Character position of the error.

Attributes
formula
formula = formula
position
position = position
ExploreFormulaSpec

Parsed explore formula.

Represents a parsed explore formula with focal variable, optional contrast type, and conditioning specifications.

Created by: parse_explore_formula()
Consumed by: dispatch_marginal_computation(), plot_predict(), plot_explore()
Augmented by: attrs.evolve() in resolve_focal_at_spec(), which materializes focal at-values

Attributes:

- focal_var (str): The variable to compute marginal effects for.
- contrast_type (str | None): Type of contrast (pairwise, sequential, poly, treatment, sum, helmert, custom), or None for simple EMMs. Set to "custom" for bracket contrast expressions.
- contrast_degree (int | None): Degree parameter for polynomial contrasts (None means n_levels - 1, i.e., the maximum degree).
- contrast_ref (str | None): Reference level for treatment/dummy contrasts (e.g., "Placebo" from treatment(Drug, ref=Placebo)).
- contrast_level_ordering (tuple[str, ...] | None): Explicit level ordering for order-dependent contrasts (helmert, sequential, poly). Parsed from bracket list syntax, e.g. poly(dose, [low, med, high]).
- contrast_expr (ContrastExpr | None): Bracket contrast expression AST (e.g., from Drug[Active - Placebo] syntax). None for named contrast functions or simple EMMs.
- conditions (tuple[Condition, ...]): Tuple of Condition objects specifying conditioning variables.
- focal_at_values (tuple[float | str, ...] | None): Specific values to evaluate the focal variable at (e.g., from Days@[0, 3, 6, 9] syntax). None means use all levels.
- focal_at_range (int | None): Number of evenly spaced values across the focal variable's range (e.g., from Days@range(5) syntax). None means not set.
- focal_at_quantile (int | None): Number of quantile values for the focal variable (e.g., from Days@quantile(3) syntax). None means not set.
Attributes
conditions
conditions: tuple[Condition, ...] = field(factory=tuple, converter=tuple)
contrast_degree
contrast_degree: int | None = field(default=None, validator=is_optional_positive_int)
contrast_expr
contrast_expr: ContrastExpr | None = field(default=None)
contrast_level_ordering
contrast_level_ordering: tuple[str, ...] | None = field(default=None, validator=is_optional_tuple_of_str)
contrast_ref
contrast_ref: str | None = field(default=None, validator=is_optional_str)
contrast_type
contrast_type: str | None = field(default=None, validator=(validators.optional(is_choice_str(('pairwise', 'sequential', 'poly', 'treatment', 'sum', 'helmert', 'custom')))))
focal_at_quantile
focal_at_quantile: int | None = field(default=None, validator=is_optional_positive_int)
focal_at_range
focal_at_range: int | None = field(default=None, validator=is_optional_positive_int)
focal_at_values
focal_at_values: tuple[float | str, ...] | None = field(default=None)
focal_var
focal_var: str = field(validator=(validators.instance_of(str)))
has_conditions
has_conditions: bool

Return True if conditioning variables are specified.

has_contrast
has_contrast: bool

Return True if any contrast is specified (named function or bracket expr).

has_contrast_expr
has_contrast_expr: bool

Return True if a bracket contrast expression is specified.

has_rhs_contrasts
has_rhs_contrasts: bool

Return True if any RHS condition has a bracket contrast expression.

Functions

parse_explore_formula
parse_explore_formula(formula: str, model_terms: list[str] | None = None) -> ExploreFormulaSpec

Parse an explore formula string.

Parameters:

- formula (str, required): Explore formula (e.g., "pairwise(treatment) ~ age@50").
- model_terms (list[str] | None, default None): Optional list of valid model terms for validation.

Returns:

- ExploreFormulaSpec: Parsed ExploreFormulaSpec object.

Examples:

Simple term::

>>> parse_explore_formula("treatment")
ExploreFormulaSpec(focal_var='treatment', contrast_type=None, conditions=())

Contrast function::

>>> parse_explore_formula("pairwise(treatment)")
ExploreFormulaSpec(focal_var='treatment', contrast_type='pairwise', conditions=())

With conditioning::

>>> parse_explore_formula("treatment ~ age@50")
ExploreFormulaSpec(
    focal_var='treatment',
    contrast_type=None,
    conditions=(Condition(var='age', at_values=(50.0,)),)
)
parse_lhs
parse_lhs(lhs: str, formula: str) -> tuple[str, str | None, int | None, tuple[float | str, ...] | None, int | None, int | None]

Parse LHS of explore formula.

Backward-compatible wrapper that tokenizes and parses just the LHS portion of an explore formula.

Parameters:

- lhs (str, required): Left-hand side of the formula.
- formula (str, required): Full formula for error messages.

Returns:

- tuple[str, str | None, int | None, tuple[float | str, ...] | None, int | None, int | None]: Tuple of (focal_var, contrast_type, contrast_degree, focal_at_values, focal_at_range, focal_at_quantile).

Note: This wrapper does not return contrast_ref, contrast_level_ordering, or contrast_expr. Use parse_explore_formula() for the full result.

explore_parser

Explore formula recursive descent parser.

Parses token streams from ExploreScanner into ExploreFormulaSpec containers.

Grammar::

explore_formula  := lhs [ TILDE rhs ]

lhs              := contrast_fn_call | contrast_expr | focal_term
contrast_fn_call := fn_name LEFT_PAREN focal_term { COMMA arg } RIGHT_PAREN
fn_name          := 'pairwise' | 'sequential' | 'poly'
                  | 'treatment' | 'dummy'
                  | 'sum' | 'deviation'
                  | 'helmert'
arg              := kwarg | positional_arg
kwarg            := IDENTIFIER EQUAL value
positional_arg   := value

contrast_expr    := IDENTIFIER LEFT_BRACKET contrast_list RIGHT_BRACKET
contrast_list    := contrast_item { COMMA contrast_item }
contrast_item    := contrast_operand MINUS contrast_operand
contrast_operand := STAR
                  | LEFT_PAREN level_name { PLUS level_name } RIGHT_PAREN
                  | level_name
level_name       := IDENTIFIER | STRING | NUMBER

focal_term       := IDENTIFIER [ AT at_spec ]
at_spec          := value
                  | LEFT_BRACKET value_list RIGHT_BRACKET
                  | [COLON] range_or_quantile
range_or_quantile := IDENTIFIER LEFT_PAREN NUMBER RIGHT_PAREN
                     (where IDENTIFIER is 'range' or 'quantile')
value_list       := value { COMMA value }
value            := NUMBER | STRING | IDENTIFIER

rhs              := condition { PLUS condition }
condition        := IDENTIFIER LEFT_BRACKET contrast_list RIGHT_BRACKET
                  | IDENTIFIER [ AT at_spec ]
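The sketch below illustrates the bracket syntax by converting one contrast_item into a weight per level. It assumes that a parenthesized group such as (A + B) means the equal-weight average of its levels. It does not handle the STAR operand, quoted STRING levels, or level names containing "-". bracket_weights and the regex are hypothetical and are not the ExploreParser code.

```python
import re

# Operand: "(A + B)" style group, or a bare level name
OPERAND = r"\(\s*[^()]+?\s*\)|[^\s()+\-]+"
ITEM = re.compile(rf"^\s*({OPERAND})\s*-\s*({OPERAND})\s*$")


def operand_levels(op):
    op = op.strip()
    if op.startswith("("):
        return [name.strip() for name in op[1:-1].split("+")]
    return [op]


def bracket_weights(item, levels):
    """Weights for 'left - right'; a group contributes +/- 1/len(group) per member."""
    m = ITEM.match(item)
    if m is None:
        raise ValueError(f"cannot parse contrast item: {item!r}")
    w = [0.0] * len(levels)
    for op, sign in ((m.group(1), 1.0), (m.group(2), -1.0)):
        names = operand_levels(op)
        for name in names:
            w[levels.index(name)] += sign / len(names)
    return w


print(bracket_weights("(A + B) - C", ["A", "B", "C"]))           # [0.5, 0.5, -1.0]
print([bracket_weights(i, ["A", "B", "C"]) for i in "A - C, B - C".split(",")])
```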

Classes:

- ExploreParser: Recursive descent parser for explore formula syntax.

Classes

ExploreParser
ExploreParser(tokens: list[Token], formula: str) -> None

Recursive descent parser for explore formula syntax.

Consumes tokens produced by ExploreScanner and builds an ExploreFormulaSpec container.

Parameters:

- tokens (list[Token], required): List of tokens from the scanner.
- formula (str, required): Original formula string (for error messages).

Functions:

- parse: Parse explore formula tokens into ExploreFormulaSpec container.
Functions
parse
parse() -> ExploreFormulaSpec

Parse explore formula tokens into ExploreFormulaSpec container.

Returns:

- ExploreFormulaSpec: Parsed ExploreFormulaSpec.

explore_scanner

Explore formula scanner/tokenizer.

Extends the formula scanner with @ token support for explore-specific at-spec syntax (e.g., Days@50, Income@range(5)).
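The sketch below tokenizes a formula with an AT token, using Python's re module with named groups. The token kinds mirror the names in the parser grammar (IDENTIFIER, AT, TILDE, ...). This tokenize function is hypothetical and is not the Scanner subclass documented here. Like the explore scanner, it inserts no intercept and accepts any number of ~.

```python
import re

TOKEN_SPEC = [
    ("NUMBER", r"\d+(?:\.\d*)?|\.\d+"),
    ("IDENTIFIER", r"[A-Za-z_][A-Za-z0-9_.]*"),
    ("STRING", r"'[^']*'|\"[^\"]*\""),
    ("AT", r"@"), ("TILDE", r"~"), ("COLON", r":"), ("EQUAL", r"="),
    ("LEFT_PAREN", r"\("), ("RIGHT_PAREN", r"\)"),
    ("LEFT_BRACKET", r"\["), ("RIGHT_BRACKET", r"\]"),
    ("COMMA", r","), ("PLUS", r"\+"), ("MINUS", r"-"), ("STAR", r"\*"),
    ("SKIP", r"\s+"),
]
MASTER = re.compile("|".join(f"(?P<{name}>{pat})" for name, pat in TOKEN_SPEC))


def tokenize(formula):
    """Return (kind, text) pairs; whitespace is dropped, unknown characters raise."""
    tokens, pos = [], 0
    while pos < len(formula):
        m = MASTER.match(formula, pos)
        if m is None:
            raise ValueError(f"unexpected character {formula[pos]!r} at position {pos}")
        if m.lastgroup != "SKIP":
            tokens.append((m.lastgroup, m.group()))
        pos = m.end()
    return tokens


print(tokenize("Days@[0, 3, 6]"))
print(tokenize("pairwise(treatment) ~ Income@range(5)"))
```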

Classes:

- ExploreScanner: Scanner for explore formulas.

Classes

ExploreScanner

Bases: Scanner

Scanner for explore formulas.

Extends the base formula scanner to:

  1. Recognize @ as an AT token.

  2. Skip intercept insertion (explore formulas have no implicit intercept).

  3. Skip the multiple-tilde check (in an explore formula, ~ separates the LHS from the RHS, not the response from the predictors).

Parameters:

- code (str, required): The explore formula string to scan.

Functions:

- add_token
- advance
- at_end
- backquote
- char
- floatnum
- identifier
- match
- number
- peek
- peek_next
- scan: Scan explore formula string.
- scan_token: Scan a single token, adding @ support.

Attributes:

- code
- current
- start
- tokens (list[Token])
Attributes
code
code = code
current
current = 0
start
start = 0
tokens
tokens: list[Token] = []
Functions
add_token
add_token(kind: str, literal: object = None) -> None
advance
advance() -> str
at_end
at_end() -> bool
backquote
backquote() -> None
char
char() -> None
floatnum
floatnum() -> None
identifier
identifier() -> None
match
match(expected: str) -> bool
number
number() -> None
peek
peek() -> str
peek_next
peek_next() -> str
scan
scan(add_intercept: bool = False) -> list[Token]

Scan explore formula string.

Overrides base to skip intercept insertion and skip the multiple-tilde validation (explore formulas use ~ differently).

Parameters:

- add_intercept (bool, default False): Ignored; always False for explore formulas.

Returns:

- list[Token]: A list of Token objects.
scan_token
scan_token() -> None

Scan a single token, adding @ support.

factors

Factor level extraction and handling utilities.

Functions:

- detect_factor_levels_from_data: Infer factor levels from unique values in a data column.
- get_factor_levels: Extract factor levels for a variable from bundle or data.

Functions

detect_factor_levels_from_data
detect_factor_levels_from_data(data: 'pl.DataFrame', var: str) -> list[str]

Infer factor levels from unique values in data column.

Extracts unique values and sorts them for consistent ordering. Values are converted to strings for uniform handling.

Parameters:

- data (pl.DataFrame, required): DataFrame containing the variable.
- var (str, required): Column name to extract levels from.

Returns:

- list[str]: Sorted list of unique values as strings.

Examples:

>>> import polars as pl
>>> df = pl.DataFrame({"group": ["B", "A", "B", "C", "A"]})
>>> detect_factor_levels_from_data(df, "group")
['A', 'B', 'C']
get_factor_levels
get_factor_levels(bundle: 'DataBundle', var: str, *, fallback_data: 'pl.DataFrame | None' = None, allow_inference: bool = True) -> list[str]

Extract factor levels for a variable from bundle or data.

Priority order:

  1. bundle.factor_levels[var] if present

  2. Infer from fallback_data[var].unique() if provided and allow_inference=True

  3. Raise ValueError if the levels are not in bundle.factor_levels and cannot be inferred from fallback_data

Parameters:

- bundle (DataBundle, required): DataBundle with factor_levels metadata.
- var (str, required): Variable name to get levels for.
- fallback_data (pl.DataFrame | None, default None): Optional DataFrame to infer levels from if they are not in the bundle.
- allow_inference (bool, default True): If True, allow inferring levels from data. If False, require levels to be in bundle.factor_levels.

Returns:

- list[str]: List of level values as strings, sorted.

Examples:

Get levels from bundle metadata::

levels = get_factor_levels(bundle, "treatment")
# ["control", "drug_a", "drug_b"]

Get levels with data fallback::

levels = get_factor_levels(bundle, "group", fallback_data=df)
# Inferred from df["group"].unique()
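The priority order above can be sketched without a DataBundle by using a plain dict in place of bundle.factor_levels and a list in place of the fallback column. get_levels is a hypothetical stand-in. Sorting the raw values before converting them to strings is an assumption, chosen because it keeps 4, 6, 10 in numeric order.

```python
def get_levels(factor_levels, var, fallback_values=None, allow_inference=True):
    """Hypothetical stand-in for the bundle -> data -> error priority order."""
    # 1. Bundle metadata wins (used in its stored order)
    if var in factor_levels:
        return [str(level) for level in factor_levels[var]]
    # 2. Infer from data, if permitted: unique values, sorted, then stringified
    if fallback_values is not None and allow_inference:
        return [str(v) for v in sorted(set(fallback_values))]
    # 3. Nothing to go on
    raise ValueError(f"no factor levels available for {var!r}")


print(get_levels({"treatment": ["control", "drug_a", "drug_b"]}, "treatment"))
print(get_levels({}, "cyl", fallback_values=[8, 4, 6, 4, 10]))  # ['4', '6', '8', '10']
```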

grid

Reference grid construction for marginal effects.

Functions:

- build_reference_grid: Construct reference grid for marginal effects evaluation.

Functions

build_reference_grid
build_reference_grid(bundle: DataBundle, focal_vars: list[str], *, at: dict[str, Any] | None = None, covariate_means: dict[str, float] | None = None) -> pl.DataFrame

Construct reference grid for marginal effects evaluation.

Creates a grid of covariate values at which to evaluate predictions. This function transforms the parsed explore formula’s conditions into a concrete evaluation grid.

The grid construction follows emmeans conventions:

Parameters:

- bundle (DataBundle, required): DataBundle with model data and metadata. Used to extract:
  - factor_levels: dict mapping categorical vars to their levels
  - X_names: column names for identifying variable types
- focal_vars (list[str], required): Variables to vary across their range/levels. These become the rows of the output grid.
- at (dict[str, Any] | None, default None): Fixed values for conditioning variables. Keys are variable names; values are the conditioning value(s). This dict is typically built from the Condition objects in the parsed explore formula.
- covariate_means (dict[str, float] | None, default None): Pre-computed means for continuous covariates. If not provided, covariates are omitted from the grid (the caller must handle them separately or pass them via at).

Returns:

- DataFrame: Polars DataFrame with one row per grid point. Columns include:
  - all focal variables
  - all conditioned variables (if any)
  The number of rows equals the size of the Cartesian product of the focal variables' levels.

Integration with explore(): Called from model._compute_marginal_effects() after parsing::

    # From parsed explore formula "treatment ~ age@50"
    parsed = ExploreFormulaSpec(
        focal_var="treatment",
        conditions=[Condition(var="age", at_values=(50,))]
    )

    # Convert conditions to at dict
    at_dict = {"age": 50.0}

    # Build grid
    grid = build_reference_grid(
        bundle=self._bundle,
        focal_vars=["treatment"],
        at=at_dict,
    )
    # grid: pl.DataFrame with columns ["treatment"] and rows ["A", "B", "C"]
    # age is NOT in the grid (it's held fixed at 50)

Examples:

Grid for categorical focal::

grid = build_reference_grid(bundle, ["treatment"])
# Returns grid with one row per treatment level

Grid with conditioning::

grid = build_reference_grid(bundle, ["treatment"], at={"age": 50})
# Returns grid at age=50

Multiple focal variables (interaction EMMs)::

grid = build_reference_grid(bundle, ["treatment", "sex"])
# Returns Cartesian product: treatment x sex levels
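The sketch below shows the Cartesian-product rule for the focal variables, using a plain dict of levels in place of the bundle. reference_grid is hypothetical and returns a list of dicts rather than a Polars DataFrame. How at-values and covariate means are attached is deliberately left out.

```python
from itertools import product


def reference_grid(factor_levels, focal_vars):
    """One row per combination of focal-variable levels (hypothetical sketch)."""
    return [dict(zip(focal_vars, combo))
            for combo in product(*(factor_levels[v] for v in focal_vars))]


levels = {"treatment": ["A", "B", "C"], "sex": ["F", "M"]}
grid = reference_grid(levels, ["treatment", "sex"])
print(len(grid))   # 6
print(grid[:2])    # [{'treatment': 'A', 'sex': 'F'}, {'treatment': 'A', 'sex': 'M'}]
```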

inference

Inference for marginal effects using the delta method.

This module provides inference computation for MeeState results using the delta method: SE = sqrt(diag(L @ vcov @ L')).

- compute_mee_inference: Add inference to MeeState via delta method
- compute_mee_se: Compute SEs for MEE (means or slopes)
- compute_mee_inference_fallback: Fallback inference without L_matrix

Functions:

- compute_mee_inference: Compute delta method inference for marginal effects.
- compute_mee_inference_fallback: Compute inference for MEE without L_matrix (fallback path).
- compute_mee_se: Compute standard errors for MEE estimates (means or slopes).

Functions

compute_mee_inference
compute_mee_inference(mee: MeeState, vcov: np.ndarray, df_resid: float | np.ndarray | None, conf_level: float = 0.95, null: float = 0.0, alternative: str = 'two-sided') -> MeeState

Compute delta method inference for marginal effects.

SE = sqrt(diag(L @ vcov @ L'))

where L is the design matrix (X_ref for EMMs, selector for slopes).

Applies multiplicity adjustment for contrast CIs to match R’s emmeans:

For response-scale EMMs (effect_scale="response"), CIs are computed on the link scale then back-transformed via the inverse link function, matching R’s emmeans behavior.
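The sketch below computes the SE formula above with a z reference distribution (the df_resid=None case). It also shows the link-scale CI back-transformation described for response-scale EMMs. delta_method_inference and expit are hypothetical. The multiplicity adjustment and t/Satterthwaite degrees of freedom are not shown.

```python
import numpy as np
from statistics import NormalDist


def expit(x):
    return 1.0 / (1.0 + np.exp(-x))


def delta_method_inference(L, beta, vcov, conf_level=0.95, null=0.0, inverse_link=None):
    """z-based (df_resid=None), two-sided delta-method inference for L @ beta."""
    est = L @ beta
    se = np.sqrt(np.diag(L @ vcov @ L.T))
    stat = (est - null) / se
    normal = NormalDist()
    p_value = np.array([2.0 * (1.0 - normal.cdf(abs(s))) for s in stat])
    crit = normal.inv_cdf(0.5 + conf_level / 2.0)
    lower, upper = est - crit * se, est + crit * se
    if inverse_link is not None:
        # CIs are built on the link scale and then mapped through g^-1, so they
        # are asymmetric on the response scale; se and stat stay on the link scale.
        est, lower, upper = inverse_link(est), inverse_link(lower), inverse_link(upper)
    return est, se, stat, p_value, lower, upper


L = np.eye(2)
beta = np.array([0.0, 1.96])
vcov = np.eye(2)
est, se, stat, p, lo, hi = delta_method_inference(L, beta, vcov, inverse_link=expit)
print(est, lo, hi)   # response-scale estimates with asymmetric intervals
```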

Parameters:

- mee (MeeState, required): MeeState with L_matrix from compute_emm or compute_slopes.
- vcov (ndarray, required): Variance-covariance matrix of coefficients, shape (p, p).
- df_resid (float | ndarray | None, required): Degrees of freedom for the t-distribution. Can be:
  - None: z-distribution (asymptotic normality)
  - scalar float: a single df for all estimates (residual df)
  - np.ndarray: per-estimate Satterthwaite df (mixed models)
- conf_level (float, default 0.95): Confidence level for intervals.
- null (float, default 0.0): Null hypothesis value.
- alternative (str, default 'two-sided'): Alternative hypothesis direction.

Returns:

- MeeState: MeeState augmented with inference fields (se, df, statistic, p_value, ci_lower, ci_upper, conf_level).

Examples:

Compute inference for EMMs::

from marginal import compute_emm
from inference import compute_mee_inference

mee = compute_emm(bundle, fit, "treatment", "treatment")
mee_with_inf = compute_mee_inference(mee, fit.vcov, fit.df_resid)
# mee_with_inf.se: standard errors
# mee_with_inf.ci_lower, mee_with_inf.ci_upper: confidence bounds

Note: For EMMs, L_matrix is the reference design matrix X_ref. For slopes, L_matrix is a selector row that picks the coefficient. For contrasts, L_matrix is the contrast matrix applied to EMMs.

compute_mee_inference_fallback
compute_mee_inference_fallback(mee: 'MeeState', bundle: 'DataBundle', fit: 'FitState', data: 'pl.DataFrame', df_resid: float | None, conf_level: float = 0.95, null: float = 0.0, alternative: str = 'two-sided') -> 'MeeState'

Compute inference for MEE without L_matrix (fallback path).

Used when L_matrix is not available (legacy MEE). Computes simplified SEs via compute_mee_se, then builds test statistics, p-values, and confidence intervals.

Parameters:

- mee ('MeeState', required): MeeState without L_matrix.
- bundle ('DataBundle', required): Data bundle with X_names.
- fit ('FitState', required): Fit state with vcov, sigma, etc.
- data ('pl.DataFrame', required): Original data frame for group counting.
- df_resid (float | None, required): Residual degrees of freedom (None for z-distribution).
- conf_level (float, default 0.95): Confidence level for intervals.
- null (float, default 0.0): Null hypothesis value.
- alternative (str, default 'two-sided'): Alternative hypothesis direction.

Returns:

- 'MeeState': MeeState augmented with inference fields.
compute_mee_se
compute_mee_se(mee: 'MeeState', bundle: 'DataBundle', fit: 'FitState', data: 'pl.DataFrame') -> np.ndarray

Compute standard errors for MEE estimates (means or slopes).

When an L_matrix is available, uses the delta method: SE = sqrt(diag(L @ vcov @ L')). This correctly handles multi-row L_matrices (e.g. per-group conditional slopes).

Otherwise falls back to simplified type-based dispatch:

Parameters:

- mee ('MeeState', required): MeeState with type ("means" or "slopes") and focal_var.
- bundle ('DataBundle', required): Data bundle with X_names.
- fit ('FitState', required): Fit state with sigma, dispersion, residuals, vcov, df_resid.
- data ('pl.DataFrame', required): Original data frame for group counting.

Returns:

- ndarray: Array of standard errors, one per estimate.

joint_tests

ANOVA-style joint hypothesis tests for model terms.

Functions:

- compute_joint_test: Compute joint hypothesis tests for model terms.

Functions

compute_joint_test
compute_joint_test(fit: FitState, bundle: DataBundle, spec: ModelSpec, terms: list[str] | None = None, *, errors: str = 'iid', data: pl.DataFrame | None = None) -> JointTestState

Compute joint hypothesis tests for model terms.

Tests each model term (continuous variables, factors, and interactions) using direct coefficient tests. This matches R’s emmeans::joint_tests() behavior.

For each term:

The test type is determined by the model family:

Parameters:

- fit (FitState, required): FitState with fitted coefficients and vcov.
- bundle (DataBundle, required): DataBundle with X_names for term structure.
- spec (ModelSpec, required): ModelSpec for model type detection.
- terms (list[str] | None, default None): Specific terms to test, or None for all terms.
- errors (str, default 'iid'): Error structure assumption. Options:
  - "iid" (default): standard errors from fit.vcov.
  - "HC0"-"HC3", "hetero": sandwich SEs.
  - "unequal_var": Welch ANOVA, using sandwich SEs from per-cell variances with per-term Satterthwaite df2. For single-df terms this matches the Welch t-test; for multi-df terms (k-level factors), per-row dfs are combined via the Fai-Cornelius (1996) approximation.
- data (DataFrame | None, default None): Original data frame, required for errors='unequal_var'.

Returns:

- JointTestState: JointTestState with test results for each term.

Examples:

Basic usage::

    from marginal import compute_joint_test
    state = compute_joint_test(fit, bundle, spec)
    # state.terms: ("x", "group", "x:group")
    # state.statistic: F or chi2 values per term
    # state.p_value: p-values per term

Welch ANOVA::

    state = compute_joint_test(fit, bundle, spec, errors="unequal_var", data=df)
    # Per-term Satterthwaite df2 in state.df2

Test specific terms only::

    state = compute_joint_test(fit, bundle, spec, terms=["group"])

Note: The intercept is never tested. For mixed models, random-effect grouping factors are excluded.
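For readers who want the arithmetic behind a joint test, here is a minimal sketch of a standard Wald test of H0: Lβ = 0. It assumes L simply selects the coefficients belonging to one term. It illustrates the textbook statistic and is not the library's code. W = (Lβ)'(L V L')⁻¹(Lβ), F = W / q with q the number of selected coefficients, and the chi-square variant reports W itself. The helper wald_joint_test is hypothetical.

```python
import numpy as np


def wald_joint_test(beta, vcov, coef_idx):
    """Hypothetical sketch: Wald F and chi-square statistics for one model term."""
    beta = np.asarray(beta, dtype=float)
    q = len(coef_idx)  # numerator df, assuming the selected rows are linearly independent
    # Each row of L picks one coefficient of the term.
    L = np.zeros((q, beta.size))
    L[np.arange(q), coef_idx] = 1.0
    Lb = L @ beta
    # W = (L b)' (L V L')^{-1} (L b), solved without forming an explicit inverse
    W = float(Lb @ np.linalg.solve(L @ vcov @ L.T, Lb))
    return {"df1": q, "chi2": W, "F": W / q}


beta = np.array([1.0, 0.5, -0.3, 0.8])
vcov = np.diag([0.1, 0.04, 0.09, 0.16])
# Hypothetical term "group" coded by two dummy columns at indices 2 and 3
res = wald_joint_test(beta, vcov, [2, 3])
```

For a single-coefficient term the statistic reduces to the squared t (or z) ratio.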

matrices

Contrast matrix builders for EMM comparisons.

Provides functions to build hypothesis contrast matrices used in marginal effects computation. These matrices transform EMM vectors into contrast estimates (e.g., pairwise differences, sequential differences, polynomial trends).

This module is distinct from design.coding, which builds the coding matrices used to encode categorical variables in the design matrix during model fitting.

Origin: absorbed from maths/inference/contrasts.py (which had no consumers outside marginal/), plus poly_contrasts from contrasts.py. Consumed by: marginal/contrasts.py (apply_contrasts, apply_contrasts_grouped).

Functions:

- build_all_pairwise_matrix: Build all pairwise contrasts between EMM levels.
- build_contrast_matrix: Build a contrast matrix based on contrast type.
- build_helmert_matrix: Build Helmert contrasts (each level vs mean of previous levels).
- build_pairwise_matrix: Build (n-1) linearly independent pairwise contrasts.
- build_poly_matrix: Build orthogonal polynomial contrast matrix for EMMs.
- build_sequential_matrix: Build sequential (successive differences) contrasts.
- build_sum_to_zero_matrix: Build sum-to-zero contrasts (deviation coding).
- build_treatment_matrix: Build treatment (Dunnett-style) contrasts against a reference level.
- compose_contrast_matrix: Compose contrast matrix with prediction matrix.
- get_contrast_labels: Generate human-readable labels for contrasts.

Functions

build_all_pairwise_matrix
build_all_pairwise_matrix(n_levels: int) -> np.ndarray

Build all pairwise contrasts between EMM levels.

Creates C(n, 2) = n*(n-1)/2 contrasts for all unique pairs. Unlike build_pairwise_matrix(), this includes all pairs, not just comparisons to the reference level.

Parameters:

- n_levels (int, required): Number of EMM levels (factor levels).

Returns:

- ndarray: Contrast matrix of shape (n*(n-1)/2, n_levels). Each row compares two levels (level_j - level_i where j > i).

Examples:

>>> C = build_all_pairwise_matrix(3)
>>> C
array([[-1.,  1.,  0.],
       [-1.,  0.,  1.],
       [ 0., -1.,  1.]])
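One way to build this matrix is shown below as a hypothetical re-implementation that reproduces the documented output; all_pairwise_sketch is not part of the library. It iterates over pairs (i, j) with j > i in lexicographic order and puts -1 at i and +1 at j.

```python
from itertools import combinations

import numpy as np


def all_pairwise_sketch(n_levels):
    """Hypothetical re-implementation: one row per pair (i, j) with j > i."""
    pairs = list(combinations(range(n_levels), 2))
    C = np.zeros((len(pairs), n_levels))
    for row, (i, j) in enumerate(pairs):
        C[row, i] = -1.0  # subtract the earlier level
        C[row, j] = 1.0   # add the later level
    return C
```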
build_contrast_matrix
build_contrast_matrix(contrast_type: str | dict, levels: list, normalize: bool = False) -> np.ndarray

Build a contrast matrix based on contrast type.

Dispatcher function that builds the appropriate contrast matrix based on the contrast type specification.

Parameters:

- contrast_type (str | dict, required): Type of contrast:
  - "pairwise": all pairwise comparisons (k(k-1)/2 contrasts)
  - "trt.vs.ctrl" or "treatment": compare each level to the first level
  - "sequential": successive differences (level[i+1] - level[i])
  - "sum": each level vs the grand mean (deviation coding)
  - "helmert": each level vs the mean of previous levels
  - dict: custom contrasts, with names as keys and weights as values
- levels (list, required): List of factor levels.
- normalize (bool, default False): If True, normalize custom contrasts to sum to 1/-1.

Returns:

- ndarray: Contrast matrix of shape (n_contrasts, n_levels).

Examples:

>>> build_contrast_matrix("pairwise", ["A", "B", "C"])
array([[-1.,  1.,  0.],
       [-1.,  0.,  1.],
       [ 0., -1.,  1.]])
build_helmert_matrix
build_helmert_matrix(n_levels: int) -> np.ndarray

Build Helmert contrasts (each level vs mean of previous levels).

Creates n_levels - 1 contrasts where level k is compared to the mean of levels 0, 1, ..., k-1.

Parameters:

- n_levels (int, required): Number of EMM levels (factor levels).

Returns:

- ndarray: Contrast matrix of shape (n_levels - 1, n_levels).

Examples:

>>> C = build_helmert_matrix(3)
>>> C
array([[-1. ,  1. ,  0. ],
       [-0.5, -0.5,  1. ]])
>>> C = build_helmert_matrix(4)
>>> C
array([[-1.        ,  1.        ,  0.        ,  0.        ],
       [-0.5       , -0.5       ,  1.        ,  0.        ],
       [-0.33333333, -0.33333333, -0.33333333,  1.        ]])
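Below is a hypothetical re-implementation (helmert_sketch) consistent with the examples above. Row k-1 puts -1/k on each of the first k levels and +1 on level k, so every row sums to zero.

```python
import numpy as np


def helmert_sketch(n_levels):
    """Hypothetical re-implementation: level k vs the mean of levels 0..k-1."""
    C = np.zeros((n_levels - 1, n_levels))
    for k in range(1, n_levels):
        C[k - 1, :k] = -1.0 / k  # minus the mean of the previous k levels
        C[k - 1, k] = 1.0
    return C
```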
build_pairwise_matrix
build_pairwise_matrix(n_levels: int) -> np.ndarray

Build (n-1) linearly independent pairwise contrasts.

Creates “treatment-style” contrasts comparing each level to the first (reference) level. This produces n_levels - 1 contrasts that span the space of all pairwise differences.

Parameters:

- n_levels (int, required): Number of EMM levels (factor levels).

Returns:

- ndarray: Contrast matrix of shape (n_levels - 1, n_levels). Row i compares level i+1 to level 0.

Examples:

>>> C = build_pairwise_matrix(3)
>>> C
array([[-1.,  1.,  0.],
       [-1.,  0.,  1.]])
build_poly_matrix
build_poly_matrix(n_levels: int, degree: int | None = None) -> np.ndarray

Build orthogonal polynomial contrast matrix for EMMs.

Creates orthogonal polynomial contrasts for ordered factors, used to test for linear, quadratic, cubic, and higher-order trends across the factor levels.

Parameters:

- n_levels (int, required): Number of factor levels. Must be >= 2.
- degree (int | None, default None): Maximum polynomial degree. If None, uses n_levels - 1.

Returns:

- ndarray: Contrast matrix of shape (degree, n_levels). Rows are orthonormal polynomial basis vectors.

Examples:

>>> build_poly_matrix(5).shape  # linear through quartic
(4, 5)
>>> build_poly_matrix(5, degree=2).shape  # linear + quadratic
(2, 5)
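The sketch below shows one standard construction and assumes the library follows the usual approach. It orthonormalizes the columns 1, x, x², ... over equally spaced level positions with a QR decomposition, then drops the constant column. poly_sketch is hypothetical. The sign normalization (last entry of each row positive, matching the sign pattern of R's contr.poly) is an assumption about the convention used here.

```python
import numpy as np


def poly_sketch(n_levels, degree=None):
    """Hypothetical sketch: orthonormal polynomial contrasts via QR of a Vandermonde matrix."""
    if n_levels < 2:
        raise ValueError("n_levels must be >= 2")
    if degree is None:
        degree = n_levels - 1
    if not 1 <= degree <= n_levels - 1:
        raise ValueError("degree must be between 1 and n_levels - 1")
    # Centered, equally spaced level positions (centering improves conditioning)
    x = np.arange(n_levels, dtype=float) - (n_levels - 1) / 2.0
    # Columns 1, x, x^2, ..., x^degree
    V = np.vander(x, degree + 1, increasing=True)
    Q, _ = np.linalg.qr(V)
    # Drop the constant column; remaining columns are orthonormal trends orthogonal to 1
    C = Q[:, 1:].T
    # QR column signs are arbitrary; make the last entry of each row positive
    return C * np.sign(C[:, -1])[:, None]
```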
build_sequential_matrix
build_sequential_matrix(n_levels: int) -> np.ndarray

Build sequential (successive differences) contrasts.

Creates n_levels - 1 contrasts comparing each level to the previous one.

Parameters:

- n_levels (int, required): Number of EMM levels (factor levels).

Returns:

- ndarray: Contrast matrix of shape (n_levels - 1, n_levels). Row i compares level i+1 to level i.

Examples:

>>> C = build_sequential_matrix(3)
>>> C
array([[-1.,  1.,  0.],
       [ 0., -1.,  1.]])
>>> C = build_sequential_matrix(4)
>>> C
array([[-1.,  1.,  0.,  0.],
       [ 0., -1.,  1.,  0.],
       [ 0.,  0., -1.,  1.]])
build_sum_to_zero_matrix
build_sum_to_zero_matrix(n_levels: int) -> np.ndarray

Build sum-to-zero contrasts (deviation coding).

Creates contrasts comparing each level to the grand mean.

Parameters:

- n_levels (int, required): Number of EMM levels.

Returns:

- ndarray: Contrast matrix of shape (n_levels - 1, n_levels).

Examples:

>>> C = build_sum_to_zero_matrix(3)
>>> C.round(3)
array([[ 0.667, -0.333, -0.333],
       [-0.333,  0.667, -0.333]])
build_treatment_matrix
build_treatment_matrix(n_levels: int, ref_idx: int = 0) -> np.ndarray

Build treatment (Dunnett-style) contrasts against a reference level.

Creates n_levels - 1 contrasts, each comparing one non-reference level to the specified reference level.

Parameters:

- n_levels (int, required): Number of EMM levels (factor levels).
- ref_idx (int, default 0): Index of the reference level (0-based).

Returns:

- ndarray: Contrast matrix of shape (n_levels - 1, n_levels).

Examples:

>>> C = build_treatment_matrix(3, ref_idx=0)
>>> C
array([[-1.,  1.,  0.],
       [-1.,  0.,  1.]])
>>> C = build_treatment_matrix(3, ref_idx=2)
>>> C
array([[ 1.,  0., -1.],
       [ 0.,  1., -1.]])
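Here is a hypothetical re-implementation (treatment_sketch) that reproduces both examples. It builds one row per non-reference level, with +1 on that level and -1 on the reference.

```python
import numpy as np


def treatment_sketch(n_levels, ref_idx=0):
    """Hypothetical re-implementation: each non-reference level minus the reference."""
    if not 0 <= ref_idx < n_levels:
        raise ValueError("ref_idx out of range")
    others = [j for j in range(n_levels) if j != ref_idx]
    C = np.zeros((len(others), n_levels))
    for row, j in enumerate(others):
        C[row, j] = 1.0         # the compared level
        C[row, ref_idx] = -1.0  # minus the reference level
    return C
```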
compose_contrast_matrix
compose_contrast_matrix(C: np.ndarray, X_ref: np.ndarray) -> np.ndarray

Compose contrast matrix with prediction matrix.

L_emm = C @ X_ref, so that L_emm @ beta = C @ (X_ref @ beta) = C @ EMMs.

Parameters:

- C (ndarray, required): Contrast matrix of shape (n_contrasts, n_emms).
- X_ref (ndarray, required): Prediction matrix of shape (n_emms, n_coef).

Returns:

- ndarray: Composed contrast matrix L_emm of shape (n_contrasts, n_coef).
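The following is a small numeric check of the identity. It uses made-up X_ref, beta, and C for a three-level factor with treatment coding; all values are hypothetical. Composing first gives the same contrast estimates as contrasting the EMMs directly. The composed L_emm is the matrix that delta-method SEs would be computed from.

```python
import numpy as np

# Hypothetical reference grid for a 3-level factor with treatment coding;
# columns are (Intercept, B, C).
X_ref = np.array([
    [1.0, 0.0, 0.0],  # level A
    [1.0, 1.0, 0.0],  # level B
    [1.0, 0.0, 1.0],  # level C
])
beta = np.array([5.0, 1.5, -0.5])
C = np.array([
    [-1.0, 1.0, 0.0],  # B - A
    [-1.0, 0.0, 1.0],  # C - A
    [0.0, -1.0, 1.0],  # C - B
])

emms = X_ref @ beta       # [5.0, 6.5, 4.5]
L_emm = C @ X_ref         # shape (n_contrasts, n_coef)
contrasts = L_emm @ beta  # same as C @ emms
```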
get_contrast_labels
get_contrast_labels(levels: list[str], contrast_type: str = 'pairwise') -> list[str]

Generate human-readable labels for contrasts.

Parameters:

- levels (list[str], required): List of factor level names.
- contrast_type (str, default 'pairwise'): Type of contrast ("pairwise", "all_pairwise", or "sequential").

Returns:

- list[str]: List of contrast labels like "B - A", "C - A", etc.

Examples:

>>> get_contrast_labels(["A", "B", "C"], "pairwise")
['B - A', 'C - A']
>>> get_contrast_labels(["A", "B", "C"], "all_pairwise")
['B - A', 'C - A', 'C - B']
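Here is a hypothetical re-implementation (contrast_labels_sketch) of the documented label patterns. The "sequential" format, each level minus its predecessor, is an assumption based on the sequential contrast definition rather than a documented example.

```python
from itertools import combinations


def contrast_labels_sketch(levels, contrast_type="pairwise"):
    """Hypothetical re-implementation of the label patterns documented above."""
    if contrast_type == "pairwise":
        # Each level against the first (reference) level
        return [f"{lvl} - {levels[0]}" for lvl in levels[1:]]
    if contrast_type == "all_pairwise":
        # Every pair (i, j) with j > i, labelled "later - earlier"
        return [f"{levels[j]} - {levels[i]}" for i, j in combinations(range(len(levels)), 2)]
    if contrast_type == "sequential":
        # Assumed format: each level minus its predecessor
        return [f"{levels[i + 1]} - {levels[i]}" for i in range(len(levels) - 1)]
    raise ValueError(f"unknown contrast_type: {contrast_type!r}")
```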

resolve

Resolution helpers for marginal effects dispatch.

Translates raw user-facing formula elements into internal representations: variable name resolution through transforms, at-value forward transforms, condition classification, and focal at-range/at-quantile materialization.

Functions:

- build_var_transform_map: Build a mapping from raw variable names to their transformed column info.
- resolve_all_conditions: Resolve formula conditions into typed buckets.
- resolve_at_overrides: Remap at-override keys through forward transforms; optionally transform values.
- resolve_conditional: Determine if conditional effects are requested and resolve the grouping var.
- resolve_focal_at_spec: Resolve focal_at_range / focal_at_quantile into focal_at_values.
- resolve_focal_at_values: Forward-transform focal at-values through a formula transform.
- resolve_focal_var: Resolve a raw focal variable name to its transformed column name.

Attributes:

- INVERTIBLE_TRANSFORMS (frozenset): Names of the transforms treated as invertible (center, norm, zscore, scale).

Attributes

INVERTIBLE_TRANSFORMS
INVERTIBLE_TRANSFORMS = frozenset({'center', 'norm', 'zscore', 'scale'})

Functions

build_var_transform_map
build_var_transform_map(formula_spec: FormulaSpec | None) -> dict[str, tuple[str, StatefulTransform]]

Build a mapping from raw variable names to their transformed column info.

Inspects formula_spec.transform_state keys (e.g. "center(x)") and extracts the inner variable name. Only includes invertible transforms (center, norm, zscore, scale).

Parameters:

- formula_spec (FormulaSpec | None, required): FormulaSpec with transform_state and transforms dicts.

Returns:

- dict[str, tuple[str, StatefulTransform]]: Dict mapping each raw variable name to (transformed_col_name, transform_obj); empty when no transforms are present or formula_spec is None.
resolve_all_conditions
resolve_all_conditions(parsed: ExploreFormulaSpec, bundle: DataBundle, data: pl.DataFrame, fspec: FormulaSpec | None, inverse_transforms: bool) -> ResolvedConditions

Resolve formula conditions into typed buckets.

Classifies RHS conditions from the parsed formula, then applies forward transforms to numeric overrides.

Parameters:

- parsed (ExploreFormulaSpec, required): Parsed explore formula.
- bundle (DataBundle, required): DataBundle with factor_levels and design matrix.
- data (DataFrame, required): Original data DataFrame.
- fspec (FormulaSpec | None, required): FormulaSpec with learned transforms (for forward resolution).
- inverse_transforms (bool, required): Whether to apply forward transforms.

Returns:

- ResolvedConditions: ResolvedConditions with all conditions classified.
resolve_at_overrides
resolve_at_overrides(at_overrides: dict[str, float] | None, formula_spec: FormulaSpec | None, inverse_transforms: bool) -> dict[str, float] | None

Remap at-override keys through forward transforms; optionally transform values.

Always remaps raw variable names to their transformed column names (e.g. "x" → "center(x)"). When inverse_transforms=True, also forward-transforms the values (e.g. 50 → 50 - mean). When False, values are left as-is on the user-specified scale.

Parameters:

- at_overrides (dict[str, float] | None, required): Original at-override dict (may be None).
- formula_spec (FormulaSpec | None, required): FormulaSpec with learned transforms.
- inverse_transforms (bool, required): Whether to apply forward transforms to values.

Returns:

- dict[str, float] | None: Resolved at_overrides dict, or None if the input was None.
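Below is a minimal sketch of the key remapping and value transform described above. It handles only center(). It also assumes the learned centering means are available as a plain dict keyed by transformed column name, as a stand-in for FormulaSpec. remap_at_overrides and learned_means are hypothetical names. Other invertible transforms (norm, zscore, scale) would need their own forward formulas.

```python
import re

# Matches a transformed column such as "center(x)" and captures the raw name "x"
_CENTER = re.compile(r"^center\((.+)\)$")


def remap_at_overrides(at_overrides, learned_means, inverse_transforms):
    """Hypothetical sketch: map raw names to center() columns, optionally shift values."""
    if at_overrides is None:
        return None
    # Raw variable name -> transformed column name, e.g. "x" -> "center(x)"
    raw_to_col = {}
    for col in learned_means:
        match = _CENTER.match(col)
        if match:
            raw_to_col[match.group(1)] = col
    resolved = {}
    for name, value in at_overrides.items():
        col = raw_to_col.get(name)
        if col is None:
            resolved[name] = value  # untransformed variable: keep key and value
        elif inverse_transforms:
            resolved[col] = value - learned_means[col]  # forward center: 50 -> 50 - mean
        else:
            resolved[col] = value  # key remapped, value left on the user's scale
    return resolved
```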
resolve_conditional
resolve_conditional(parsed: ExploreFormulaSpec, bundle: DataBundle, varying: str, varying_offsets: VaryingState | None) -> tuple[bool, str | None]

Determine if conditional effects are requested and resolve grouping var.

Parameters:

- parsed (ExploreFormulaSpec, required): Parsed explore formula.
- bundle (DataBundle, required): DataBundle with RE metadata.
- varying (str, required): "exclude" or "include".
- varying_offsets (VaryingState | None, required): VaryingState (None for non-mixed models).

Returns:

- tuple[bool, str | None]: Tuple of (is_conditional, grouping_var or None).
resolve_focal_at_spec
resolve_focal_at_spec(parsed: ExploreFormulaSpec, data: pl.DataFrame, bundle: DataBundle) -> ExploreFormulaSpec

Resolve focal_at_range / focal_at_quantile into focal_at_values.

When the parsed formula contains focal_at_range=n or focal_at_quantile=n, this function computes the concrete values from the data and returns an evolved copy with focal_at_values populated.

Parameters:

- parsed (ExploreFormulaSpec, required): Parsed explore formula (may have range/quantile fields set).
- data (DataFrame, required): Original data DataFrame for computing range/quantile values.
- bundle (DataBundle, required): DataBundle with design matrix metadata.

Returns:

- ExploreFormulaSpec: ExploreFormulaSpec with focal_at_values resolved (or unchanged if neither range nor quantile was set).
resolve_focal_at_values
resolve_focal_at_values(focal_var: str, at_values: tuple[float, ...], formula_spec: FormulaSpec | None, inverse_transforms: bool) -> tuple[float, ...]

Forward-transform focal at-values through a formula transform.

When inverse_transforms=True and the focal variable has a learned transform (e.g. center(x)), applies the forward transform to each at-value so that raw-scale values become transformed-scale values.

Parameters:

- focal_var (str, required): Raw focal variable name (e.g. "Days").
- at_values (tuple[float, ...], required): User-specified at-values on the raw scale.
- formula_spec (FormulaSpec | None, required): FormulaSpec with learned transforms.
- inverse_transforms (bool, required): Whether to apply forward transforms.

Returns:

- tuple[float, ...]: Transformed at-values tuple (or the original if no transform applies).
resolve_focal_var
resolve_focal_var(focal_var: str, X_names: tuple[str, ...] | list[str], formula_spec: FormulaSpec | None, inverse_transforms: bool) -> str

Resolve a raw focal variable name to its transformed column name.

Always attempts resolution regardless of inverse_transforms so that raw variable names (e.g. "x") are mapped to their design-matrix column names (e.g. "center(x)").

Parameters:

- focal_var (str, required): Raw variable name from the explore formula.
- X_names (tuple[str, ...] | list[str], required): Design matrix column names.
- formula_spec (FormulaSpec | None, required): FormulaSpec with transform info.
- inverse_transforms (bool, required): Unused; kept for API compatibility.

Returns:

- str: The resolved column name (transformed or original).

slopes

Marginal slopes computation.

This module provides marginal slope computation for continuous focal variables.

Three strategies are available:

- compute_slopes: coefficient-extraction slopes (Gaussian identity link, no interactions)
- compute_slopes_finite_diff: finite-difference AME slopes
- compute_conditional_slopes: per-group conditional slopes with BLUPs

Functions:

- compute_conditional_slopes: Compute per-group conditional slopes incorporating BLUPs.
- compute_slopes: Compute marginal slope for a continuous focal variable.
- compute_slopes_crossed: Compute crossed slopes over focal variable x condition grid.
- compute_slopes_finite_diff: Compute marginal slopes via centered finite differences.

Functions

compute_conditional_slopes
compute_conditional_slopes(bundle: 'DataBundle', fit: 'FitState', focal_var: str, explore_formula: str, *, spec: object, varying_offsets: object, grouping_var: str, effect_scale: str = 'link') -> MeeState

Compute per-group conditional slopes incorporating BLUPs.

For each group g, the conditional slope is:

slope_g = β_x + b_x_g (random-slopes model)
slope_g = β_x (random intercepts only)

When effect_scale=“response” and the link is non-identity, the response-scale slope is: slope_response_g = slope_link_g × dμ/dη(η_g)

where η_g is the linear predictor at the reference point for group g.
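To make the chain rule concrete, here is a sketch for a logit link, where dμ/dη = μ(1 − μ). The helper name is hypothetical, and so are the inputs: the fixed slope, the per-group BLUP offsets, and the per-group η_g at the reference point.

```python
import numpy as np


def conditional_slopes_logit(beta_x, b_x, eta_ref):
    """Hypothetical sketch of per-group slopes on the link and response scales (logit link).

    beta_x : fixed-effect slope for the focal variable
    b_x    : per-group random-slope BLUPs (all zeros for a random-intercepts-only model)
    eta_ref: per-group linear predictor at the reference point
    """
    slope_link = beta_x + np.asarray(b_x, dtype=float)
    mu = 1.0 / (1.0 + np.exp(-np.asarray(eta_ref, dtype=float)))  # inverse logit
    dmu_deta = mu * (1.0 - mu)  # derivative of the inverse logit at eta_g
    return slope_link, slope_link * dmu_deta


link, response = conditional_slopes_logit(0.8, [0.1, -0.2, 0.0], [0.0, 1.0, -1.0])
```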

Parameters:

- bundle ('DataBundle', required): DataBundle with model data and metadata.
- fit ('FitState', required): FitState with fitted coefficients.
- focal_var (str, required): Name of the continuous variable to get slopes for.
- explore_formula (str, required): The explore formula string (for result metadata).
- spec (object, required): ModelSpec with link function info.
- varying_offsets (object, required): VaryingState with BLUPs per group.
- grouping_var (str, required): Name of the grouping variable.
- effect_scale (str, default 'link'): Scale of estimates: "link" or "response".

Returns:

- MeeState: MeeState with per-group slope estimates.
compute_slopes
compute_slopes(bundle: 'DataBundle', fit: 'FitState', focal_var: str, explore_formula: str, *, spec: object | None = None, effect_scale: str = 'link') -> MeeState

Compute marginal slope for a continuous focal variable.

For a linear model, this extracts the coefficient for the variable. For models with interactions, this would need to average marginal effects across the grid (future enhancement).

Marginal slopes answer: “By how much does the predicted outcome change for a one-unit increase in this variable?”

Parameters:

- bundle ('DataBundle', required): DataBundle with model data; its X_names are used to locate the variable's coefficient index.
- fit ('FitState', required): FitState with fitted coefficients.
- focal_var (str, required): Name of the continuous variable to get the slope for.
- explore_formula (str, required): The explore formula string (for result metadata).
- spec (object | None, default None): ModelSpec with link/family info (for effect_scale="response").
- effect_scale (str, default 'link'): Scale of estimates: "link" or "response".

Returns:

- MeeState: MeeState with the marginal slope estimate.

Examples:

Compute marginal effect of age::

    from marginal import compute_slopes
    mee = compute_slopes(bundle, fit, "age", "age")
    # For y ~ age: returns MeeState with estimate = [coef_age]

Note: This is the fast path for Gaussian identity-link models without interactions. For models with interactions or GLMs (non-identity link), the dispatcher routes to compute_slopes_finite_diff which handles both cases via centered finite differences.

compute_slopes_crossed
compute_slopes_crossed(bundle: 'DataBundle', fit: 'FitState', focal_var: str, *, resolved: object, data: pl.DataFrame, spec: 'ModelSpec | None' = None, formula_spec: 'FormulaSpec | None' = None, effect_scale: str = 'link', delta_frac: float = 0.001) -> MeeState

Compute crossed slopes over focal variable x condition grid.

Builds a Cartesian product of condition levels (grid categoricals, grid numerics, scalar pins) and computes a finite-difference AME for each combination. Analogous to _compute_emm_crossed but for continuous focal variables.

Parameters:

- bundle ('DataBundle', required): DataBundle with model data and metadata.
- fit ('FitState', required): FitState with fitted coefficients.
- focal_var (str, required): Continuous variable to compute slopes for.
- resolved (object, required): ResolvedConditions with grid and scalar conditions.
- data (DataFrame, required): Raw model data (for computing ranges and covariate means).
- spec ('ModelSpec | None', default None): ModelSpec with family/link info.
- formula_spec ('FormulaSpec | None', default None): FormulaSpec with learned encoding for evaluate_newdata.
- effect_scale (str, default 'link'): "link" (linear predictor scale) or "response" (inverse-link / data scale).
- delta_frac (float, default 0.001): Fraction of the focal variable's range used as the finite-difference step size.

Returns:

- MeeState: MeeState with one slope estimate per condition combination.
compute_slopes_finite_diff
compute_slopes_finite_diff(bundle: DataBundle, fit: FitState, focal_var: str, explore_formula: str, *, spec: ModelSpec, formula_spec: FormulaSpec, data: pl.DataFrame, how: str = 'mem', effect_scale: str = 'link', delta_frac: float = 0.001) -> MeeState

Compute marginal slopes via centered finite differences.

Supports two averaging methods via the how parameter: "mem" evaluates at a balanced reference grid, and "ame" averages over the actual data rows.

For linear models (identity link), both approaches give identical results.

Algorithm:

  1. delta = delta_frac × range(focal_var)

  2. Build evaluation grid (balanced or observed data)

  3. Perturb: grid_plus = grid[focal + delta/2], grid_minus = grid[focal − delta/2]

  4. X_plus, X_minus = evaluate_newdata(formula_spec, grid_*)

  5. L_diff = (X_plus − X_minus) / delta (link-scale functional)

  6. If effect_scale="response" and how="ame": J_i = [f'(η_i+) X_i+ − f'(η_i−) X_i−] / delta (per-observation Jacobian with f' evaluated at the perturbed points)

  7. L_avg = mean(J, axis=0, keepdims=True)
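Below is a self-contained sketch of these steps for how="ame" with effect_scale="response", assuming a logistic model. The design function is a hypothetical stand-in for evaluate_newdata (columns 1, x, x²), and the data and coefficients are made up. It returns the averaged response-scale slope together with L_avg, the averaged Jacobian row that feeds delta-method inference.

```python
import numpy as np


def design(x):
    """Hypothetical stand-in for evaluate_newdata: columns (1, x, x^2)."""
    x = np.asarray(x, dtype=float)
    return np.column_stack([np.ones_like(x), x, x**2])


def expit(eta):
    return 1.0 / (1.0 + np.exp(-eta))


def ame_slope_response(x, beta, delta_frac=0.001):
    """Centered finite-difference AME on the response scale (logit link)."""
    x = np.asarray(x, dtype=float)
    delta = delta_frac * (x.max() - x.min())
    X_plus = design(x + delta / 2)
    X_minus = design(x - delta / 2)
    eta_plus, eta_minus = X_plus @ beta, X_minus @ beta
    # Per-observation response-scale slopes
    slopes = (expit(eta_plus) - expit(eta_minus)) / delta
    # f'(eta) for the inverse logit is mu * (1 - mu), evaluated at the perturbed points
    fp_plus = expit(eta_plus) * (1 - expit(eta_plus))
    fp_minus = expit(eta_minus) * (1 - expit(eta_minus))
    # Per-observation Jacobian J_i = [f'(eta_i+) X_i+ - f'(eta_i-) X_i-] / delta
    J = (fp_plus[:, None] * X_plus - fp_minus[:, None] * X_minus) / delta
    L_avg = J.mean(axis=0, keepdims=True)
    return slopes.mean(), L_avg


rng = np.random.default_rng(0)
x_obs = rng.uniform(-2.0, 2.0, size=200)
beta = np.array([-0.5, 0.8, 0.3])
estimate, L_avg = ame_slope_response(x_obs, beta)
vcov = np.diag([0.04, 0.01, 0.0025])  # made-up coefficient covariance
se = float(np.sqrt(L_avg @ vcov @ L_avg.T)[0, 0])
```

L_avg is the gradient of the averaged estimate with respect to beta. Comparing it against a numeric gradient is a quick check of such a Jacobian.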

Parameters:

- bundle (DataBundle, required): DataBundle with model data and metadata.
- fit (FitState, required): FitState with fitted coefficients.
- focal_var (str, required): Continuous variable to compute the slope for.
- explore_formula (str, required): Explore formula string (for result metadata).
- spec (ModelSpec, required): ModelSpec with family/link info.
- formula_spec (FormulaSpec, required): FormulaSpec with learned encoding for evaluate_newdata.
- data (DataFrame, required): Raw model data (for computing ranges and covariate means).
- how (str, default 'mem'): "mem" for a balanced reference grid, "ame" for the actual data rows.
- effect_scale (str, default 'link'): "link" (linear predictor scale) or "response" (inverse-link / data scale).
- delta_frac (float, default 0.001): Fraction of the focal variable's range used as the finite-difference step size. The default of 0.001 matches R's emmeans.

Returns:

- MeeState: MeeState with a single-row AME estimate and an L_matrix for delta-method inference.

transforms

Pure transform operations for result DataFrames.

Stateless functions that transform params/effects DataFrames. Each function takes a DataFrame + context and returns a new DataFrame.

Functions:

- convert_to_effect_size: Compute standardized effect sizes from a params DataFrame.
- convert_to_odds_ratio: Exponentiate estimates and CIs to odds ratio scale.
- convert_to_response_scale: Apply inverse link to transform estimates to response scale.
- filter_effects_dataframe: Filter an effects DataFrame by term, level, or contrast.

Functions

convert_to_effect_size
convert_to_effect_size(df: pl.DataFrame, sigma: float | None, df_resid: float | None, family: str = 'gaussian', *, include_intercept: bool = False) -> pl.DataFrame

Compute standardized effect sizes from a params DataFrame.

Computes:

Parameters:

- df (DataFrame, required): Params DataFrame with 'term', 'estimate', and optionally 'statistic', 'ci_lower', 'ci_upper'.
- sigma (float | None, required): Residual standard deviation; None if unavailable.
- df_resid (float | None, required): Residual degrees of freedom; None if unavailable.
- family (str, default 'gaussian'): Model family string.
- include_intercept (bool, default False): Whether to include the intercept row.

Returns:

- DataFrame: DataFrame with added effect size columns.
convert_to_odds_ratio
convert_to_odds_ratio(df: pl.DataFrame, family: str) -> pl.DataFrame

Exponentiate estimates and CIs to odds ratio scale.

Parameters:

- df (DataFrame, required): Params DataFrame with 'estimate' and optionally 'ci_lower', 'ci_upper'.
- family (str, required): Model family string (must be "binomial").

Returns:

- DataFrame: DataFrame with exponentiated estimate/CI columns.
convert_to_response_scale
convert_to_response_scale(df: pl.DataFrame, link: str) -> pl.DataFrame

Apply inverse link to transform estimates to response scale.

Parameters:

- df (DataFrame, required): Effects DataFrame with 'estimate' and optionally 'ci_lower', 'ci_upper'.
- link (str, required): Link function name (e.g., "logit", "log").

Returns:

- DataFrame: DataFrame with transformed columns on the response scale.
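Below is a minimal sketch of the back-transformation, using plain NumPy arrays rather than a polars DataFrame. The INVERSE_LINKS table and to_response_scale are hypothetical and cover only three links. These inverse links are monotone increasing, so transforming the CI bounds keeps their order. Applying np.exp in the same way to logit-scale coefficients and their bounds is the odds-ratio conversion that convert_to_odds_ratio describes.

```python
import numpy as np

# Hypothetical table of inverse links; only a few common links are covered.
INVERSE_LINKS = {
    "identity": lambda eta: eta,
    "log": np.exp,
    "logit": lambda eta: 1.0 / (1.0 + np.exp(-eta)),
}


def to_response_scale(estimate, ci_lower, ci_upper, link):
    """Apply the inverse link elementwise to the estimate and both CI bounds."""
    inv = INVERSE_LINKS[link]
    # Monotone increasing inverse links preserve lower <= estimate <= upper.
    return (
        inv(np.asarray(estimate, dtype=float)),
        inv(np.asarray(ci_lower, dtype=float)),
        inv(np.asarray(ci_upper, dtype=float)),
    )


est, lo, hi = to_response_scale([0.0, 1.0], [-0.5, 0.2], [0.5, 1.8], "logit")
```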
filter_effects_dataframe
filter_effects_dataframe(df: pl.DataFrame, *, terms: list[str] | str | None = None, levels: list[str] | str | None = None, contrasts: list[str] | str | None = None) -> pl.DataFrame

Filter an effects DataFrame by term, level, or contrast.

Pure function that applies one or more filters to a marginal-effects DataFrame produced by the effects property.

Parameters:

- df (DataFrame, required): Effects DataFrame (from model.effects).
- terms (list[str] | str | None, default None): Term name(s) to keep (filters the term column).
- levels (list[str] | str | None, default None): Level name(s) to keep (filters the first grid column, i.e. the first column that is not a standard statistic column).
- contrasts (list[str] | str | None, default None): Contrast label(s) to keep (filters the contrast column).

Returns:

- DataFrame: Filtered DataFrame containing only matching rows.

validation

Validation guards for marginal effects operations.

Shared precondition checks for compute_emm, compute_slopes, etc.

Functions:

- validate_focal_var: Guard: focal variable must be a predictor, not the response.

Functions

validate_focal_var
validate_focal_var(bundle: DataBundle, focal_var: str) -> None

Guard: focal variable must be a predictor, not the response.

Parameters:

- bundle (DataBundle, required): DataBundle with model metadata.
- focal_var (str, required): Name of the variable to compute marginal effects for.