
Style Guide & Design Constraints

Architecture rationale, coding patterns, and project-specific conventions

UCSD Psychology

The “why” and “how” behind CLAUDE.md’s rules.


1. Scope & Complexity Budget

Spend complexity on: Numerical correctness, performance at 1k–100k scale, ergonomic API.

Don’t spend on: Edge cases that don’t exist in practice, configurability nobody asked for, “defensive” code for impossible states, hypothetical model #5. Only add complexity when a model type concretely requires it.

Out of scope

Facade anti-regression: model/ allowlist

model/ is a facade layer, not a workflow layer. A model method may:

  1. Validate user-facing arguments (types, enum values, mutual exclusivity)

  2. Delegate to one internal owner (a single function call into internal/)

  3. Assign returned state to private attrs fields

If a method coordinates multiple internal phases, mutates multiple pieces of internal lifecycle state, or contains subsystem-specific policy (e.g. “rebuild bundle if weights changed, then fit, then augment”), that logic belongs in internal/. The two lifecycle modules — internal/fit/lifecycle.py and internal/simulation/lifecycle.py — own multi-step orchestration for fit() and simulate() respectively.
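The allowlist can be sketched with stdlib-only stand-ins. This is illustrative only: `Model`, `run_fit_lifecycle`, and the `FitState` shape here are hypothetical names, not the real API.

```python
# Hypothetical sketch of the three allowed facade steps: validate, delegate
# once, assign. Model, run_fit_lifecycle, and FitState are illustrative names.
from dataclasses import dataclass


@dataclass
class FitState:
    coef: tuple[float, ...]


def run_fit_lifecycle(method: str) -> FitState:
    """Stand-in for internal/fit/lifecycle.py, which owns the multi-step work."""
    return FitState(coef=(0.0,))


class Model:
    def fit(self, method: str = "ml") -> "Model":
        # 1. Validate user-facing arguments (types, enum values).
        if method not in ("ml", "reml"):
            raise ValueError(f"method must be 'ml' or 'reml', got {method!r}")
        # 2. Delegate to one internal owner.
        state = run_fit_lifecycle(method)
        # 3. Assign returned state to a private field.
        self._fit = state
        return self
```

Anything beyond these three steps (conditional rebuilds, multi-phase mutation) is a signal the logic belongs in a lifecycle module.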

Anti-patterns (move to internal/ if you see these in model/):


2. Container Conventions

Containers live in internal/containers/. Always @frozen. Augment with attrs.evolve().

import numpy as np
import polars as pl
from attrs import field, frozen, validators
# is_ndarray / is_optional_ndarray are the project's shared validator helpers

@frozen
class MeeState:
    """Marginal effects / EMM results.

    Created by: compute_emm(), compute_slopes()
    Consumed by: model.effects property, compute_mee_inference()
    Augmented by: attrs.evolve() after inference adds SEs/CIs
    """

    grid: pl.DataFrame = field(repr=False)
    estimate: np.ndarray = field(validator=is_ndarray)
    type: str = field(validator=validators.in_(("means", "slopes", "contrasts")))
    se: np.ndarray | None = field(default=None, validator=is_optional_ndarray)
Use Col constants for column names rather than raw string literals:

from bossanova.internal.containers.schemas import Col

# GOOD
data = {Col.TERM: terms, Col.ESTIMATE: estimates}

# BAD
data = {"term": terms, "estimate": estimates}

Pass schema= to pl.DataFrame() when columns are fully static. Skip schema= when columns are dynamic (effects grid, varying offsets, predictions with optional columns).


3. Backend Patterns

Dual environment

| Environment | Backend | Constraint |
| --- | --- | --- |
| Native Python | JAX (default) or NumPy | Full functionality |
| Browser/Pyodide | NumPy only | No unconditional JAX imports, no filesystem/subprocess assumptions |

Pattern A: Early Dispatch

Use when JAX needs different control flow (lax.while_loop, lax.cond):

def fit_glm_irls(...):
    backend = get_backend()
    if backend == "jax" and family.robust_weights is None:
        return _fit_glm_irls_jax(...)   # lax.while_loop
    else:
        return _fit_glm_irls_numpy(...)  # Python loop

Pattern B: Polymorphic Ops

Use when algorithm is identical between backends:

def compute_leverage(X: np.ndarray) -> np.ndarray:
    ops = get_ops()
    Q, _ = ops.qr(X)
    return ops.np.sum(Q**2, axis=1)

JIT Caching

Cache per backend. NumPy’s ops.jit is a no-op. Use JAX only for hot loops (IRLS, resampling) where JIT gives 2-4x speedup — not single-call operations or entire workflows (~100ms compilation overhead).

from typing import Any

_cache: dict[str, Any] = {}

def _make_fn(ops):
    def _core(X, y):
        ...
    return ops.jit(_core)

def _get_fn():
    backend = get_backend()
    if backend not in _cache:
        _cache[backend] = _make_fn(get_ops())
    return _cache[backend]

Direct @jax.jit only for small utilities that never need backend switching.

JAX Import Pattern

try:
    import jax
    import jax.numpy as jnp
except ImportError:
    import numpy as jnp
    class _FakeJax:
        @staticmethod
        def jit(fn): return fn
    jax = _FakeJax()

RNG

Always use RNG from bossanova.internal.maths.rng, never direct jax.random calls.


4. Function Signatures

Canonical parameter order

All arguments keyword-only: containers first, then scalar options:

def dispatch_params_inference(
    *,
    how: str,
    spec: ModelSpec,
    bundle: DataBundle,
    fit: FitState,
    conf_level: float,
    n_boot: int = 1000,
    seed: int | None = None,
) -> InferenceState:
    ...

**kwargs rules


5. Naming Conventions

Verb prefixes for functions in internal/

| Verb | Meaning | Examples |
| --- | --- | --- |
| compute_ | Pure calculation → value | compute_emm, compute_vcov |
| build_ | Construct complex object | build_reference_grid, build_model_spec |
| dispatch_ | Route to implementation | dispatch_infer, dispatch_solver |
| parse_ | String → structured data | parse_explore_formula |
| fit_ | Solver entry (→ FitState) | fit_model, fit_glm_irls |
| apply_ | Transform data | apply_contrasts, apply_link_inverse |
| validate_ | Check + raise on failure | validate_fit_method |
| resolve_ | Ambiguity → concrete choice | resolve_contrast_specs |
| generate_ | Create data/indices | generate_lm_data, generate_kfold_splits |
| run_ | Multi-step workflow | run_power_analysis |

Without verb prefix


6. Polars Idioms

No pandas internally. Convert at API boundary only:

if not isinstance(data, pl.DataFrame):
    data = pl.from_pandas(data)

Chain within one logical step; break at phase boundaries:

# One step: build and annotate grid
grid = (
    _cartesian_product(levels)
    .with_columns(pl.lit(mean_val).alias("covariate"))
    .with_row_index("_row_id")
)

# Break at phase boundary
predictions = model.predict(grid)
grid = grid.with_columns(pl.Series("fit", predictions))

Use lazy for multi-op sequences; .collect() before returning.

Avoid .iter_rows() — use vectorized ops or .to_dicts().

Viz exception: Seaborn’s map_dataframe may pass pandas. Use generic accessors (list(data["col"]), np.asarray(data["col"])) in drawing functions.


7. Resampling Architecture

Use standalone functions (not closures) for joblib compatibility:

# GOOD: Standalone — picklable
def bootstrap_single_iteration(
    rng: RNG, spec: ModelSpec, bundle: DataBundle,
    indices: np.ndarray, n_params: int,
) -> np.ndarray:
    bundle_boot = resample_bundle(bundle, indices)
    try:
        return fit_model(spec, bundle_boot).coef
    except Exception:
        return np.full(n_params, np.nan)

# BAD: Closure — not picklable
def bootstrap(spec, bundle, n_boot):
    def single(key):  # Can't serialize
        ...
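The picklability constraint is easy to demonstrate with the stdlib alone: `pickle` serializes module-level functions by qualified name but rejects closures, which is why joblib's default process-based backend needs standalone functions.

```python
# Stdlib-only demonstration: module-level functions pickle, closures do not.
import pickle


def standalone(x: int) -> int:
    return x + 1


def make_closure():
    y = 1

    def inner(x: int) -> int:  # closes over y: not picklable
        return x + y

    return inner


payload = pickle.dumps(standalone)  # serialized by module-qualified name
try:
    pickle.dumps(make_closure())
except (pickle.PicklingError, AttributeError):
    payload_closure = None  # "Can't pickle local object ..."
```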

8. Error Messages

Show context, cause, and fix:

raise ValueError(
    f'Variable "{name}" not found in data.\n\n'
    f"Available columns: {', '.join(cols[:10])}"
    + (f"\n\nDid you mean: {suggestion}?" if suggestion else "")
)

Where to validate

| Location | What |
| --- | --- |
| Container __init__ (validators) | Field types, allowed values |
| Model class methods | State machine (fitted? has data?) |
| Operations (top of function) | Semantic preconditions |
| Maths functions | Never — trust internal callers |

except Exception only in resampling loops (NaN sentinel for failed replicates).


9. Testing Philosophy

| Layer | Validates | Notes |
| --- | --- | --- |
| Parity (R comparison) | Correctness vs R | Free — proves correctness cheaply |
| Hypothesis (property-based) | Mathematical invariants | Executable documentation of the math |
| Recovery (Monte Carlo) | Statistical properties (bias, coverage) | Medium complexity |
| Redteam (pathology) | Extreme/degenerate inputs | Exploratory, use xfail |

Test domain functions in internal/ directly. Integration tests cover the model class via R parity.
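The spirit of the invariant layer can be shown with a stdlib-only sketch (real tests use `hypothesis.given`; `ols_slope` is an illustrative toy, not project code): check a mathematical property rather than a hard-coded value.

```python
# Invariant-style check: an OLS slope is unchanged by shifting y by a constant.
# Stdlib random stands in for Hypothesis's generated inputs.
import random


def ols_slope(x: list[float], y: list[float]) -> float:
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    sxy = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sxx = sum((a - mx) ** 2 for a in x)
    return sxy / sxx


random.seed(0)
x = [random.uniform(-5, 5) for _ in range(50)]
y = [2.0 * a + random.gauss(0, 1) for a in x]

# The invariant holds for any data, not just this draw.
assert abs(ols_slope(x, y) - ols_slope(x, [v + 10.0 for v in y])) < 1e-9
```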


10. Optimization Stack

nlopt + BOBYQA is the default optimizer (matches lme4). Lives in internal/maths/solvers/. Includes lme4’s numerical tricks: intelligent initialization, restart logic, log-parameterization, per-parameter scaling.
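The log-parameterization trick can be illustrated in isolation with a stdlib-only sketch (no nlopt): optimizing log(θ) keeps a variance parameter positive without bound constraints. `fit_var` and its naive gradient-descent loop are illustrative, not the project's solver.

```python
# Illustrative log-parameterization: optimize a Gaussian variance on the log
# scale so the search space is unconstrained yet the estimate stays positive.
import math


def nll(var: float, data: list[float]) -> float:
    """Negative log-likelihood of N(0, var) data, up to additive constants."""
    return 0.5 * sum(x * x / var + math.log(var) for x in data)


def fit_var(data: list[float], steps: int = 200, lr: float = 0.05) -> float:
    log_var = 0.0  # unconstrained parameter: any real value is valid
    for _ in range(steps):
        eps = 1e-6
        # central finite-difference gradient in log-space
        grad = (nll(math.exp(log_var + eps), data)
                - nll(math.exp(log_var - eps), data)) / (2 * eps)
        log_var -= lr * grad
    return math.exp(log_var)  # positive by construction
```

For zero-mean Gaussian data the maximum-likelihood variance is mean(x²), so the loop should recover that value without ever needing a positivity constraint.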