
Memory-aware batch sizing for JAX operations.

Functions:

| Name | Description |
|---|---|
| `compute_batch_size` | Compute optimal batch size for `jax.lax.map`. |
| `get_available_memory_gb` | Query available system memory in GB. |

Attributes

Functions

compute_batch_size

compute_batch_size(*, n_items: int, bytes_per_item: int, max_mem: float | None = None, min_batch: int = MIN_BATCH_SIZE, max_batch: int = MAX_BATCH_SIZE) -> int

Compute optimal batch size for jax.lax.map.

Balances memory usage, parallelism efficiency, and cache locality. Uses a safety margin to account for XLA compilation overhead and intermediate arrays.
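
The memory-budget step can be sketched roughly as follows. This is an illustration under stated assumptions, not the library's code: the helper name `raw_batch_size`, the 0.5 `safety_margin` factor, and the use of binary gigabytes (1024**3 bytes) are all hypothetical.

```python
def raw_batch_size(available_gb: float, bytes_per_item: int,
                   max_mem: float = 0.5, safety_margin: float = 0.5) -> int:
    """Hypothetical sketch: how many items fit in the memory budget.

    Assumptions: GB means 1024**3 bytes, and safety_margin is an assumed
    factor reserving headroom for XLA compilation buffers and
    intermediate arrays (the real factor is not documented here).
    """
    # Usable bytes = available memory * fraction allowed * safety headroom.
    budget_bytes = available_gb * 1024**3 * max_mem * safety_margin
    # Number of whole items that fit; never report fewer than one.
    return max(1, int(budget_bytes // bytes_per_item))


# 8 GB available, 100 float64 coefficients (800 bytes) per item.
print(raw_batch_size(8.0, 100 * 8))  # 2684354
```

With 8 GB available this budget fits about 2.7 million items, so in such a case the result would be governed by the `max_batch` and `n_items` bounds rather than by memory.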

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `n_items` | `int` | Total items to process (e.g., `n_boot`, `n_perm`). | *required* |
| `bytes_per_item` | `int` | Memory per output item in bytes. For bootstrap/permutation with `p` coefficients: `p * 8` (float64). | *required* |
| `max_mem` | `float \| None` | Fraction of available system memory to use (0.0-1.0). `None` defaults to 0.5 (50%). Values outside [0.01, 1.0] are clamped. | `None` |
| `min_batch` | `int` | Minimum batch size for parallelism efficiency. | `MIN_BATCH_SIZE` |
| `max_batch` | `int` | Maximum batch size for cache efficiency. | `MAX_BATCH_SIZE` |
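
The documented `max_mem` rules (a `None` default of 0.5 and clamping to [0.01, 1.0]) amount to a small normalization step. A minimal sketch, using a hypothetical helper name that is not part of the documented API:

```python
from __future__ import annotations


def resolve_max_mem(max_mem: float | None) -> float:
    """Hypothetical helper mirroring the documented max_mem semantics."""
    if max_mem is None:
        return 0.5  # documented default: use 50% of available memory
    # Documented behavior: values outside [0.01, 1.0] are clamped.
    return min(max(max_mem, 0.01), 1.0)


print(resolve_max_mem(None))  # 0.5
print(resolve_max_mem(2.0))   # 1.0
```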

Returns:

| Type | Description |
|---|---|
| `int` | Batch size clamped to [`min_batch`, `min(max_batch, n_items)`]. |
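
The final clamp can be sketched as below. `clamp_batch` and the bounds 16 and 1024 in the usage lines are hypothetical; the actual values of `MIN_BATCH_SIZE` and `MAX_BATCH_SIZE` are not stated here. When `n_items < min_batch` the documented interval is empty; this sketch applies the upper bound last so the result never exceeds `n_items`, which is an assumption about how the library resolves that case.

```python
def clamp_batch(raw: int, n_items: int, min_batch: int, max_batch: int) -> int:
    """Hypothetical sketch of clamping to [min_batch, min(max_batch, n_items)]."""
    # Raise tiny memory-derived sizes to the parallelism floor first...
    lower_clamped = max(raw, min_batch)
    # ...then cap by the cache-efficiency ceiling and the item count.
    # Applying the cap last means n_items wins if n_items < min_batch.
    return min(lower_clamped, max_batch, n_items)


print(clamp_batch(2_684_354, n_items=10_000, min_batch=16, max_batch=1024))  # 1024
print(clamp_batch(5, n_items=8, min_batch=16, max_batch=1024))  # 8
```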

Examples:

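>>> import jax
>>> jax.config.update("jax_enable_x64", True)  # produce float64 outputs
>>> key = jax.random.PRNGKey(0)
>>> fn = lambda k: jax.random.normal(k, (100,))  # hypothetical stand-in for one bootstrap replicate
>>> # 10000 bootstrap samples, each producing 100 float64 coefficients
>>> batch_size = compute_batch_size(n_items=10000, bytes_per_item=100 * 8)
>>> keys = jax.random.split(key, 10000)
>>> boot_samples = jax.lax.map(fn, keys, batch_size=batch_size)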

get_available_memory_gb

get_available_memory_gb() -> float

Query available system memory in GB.

Attempts to detect available memory with the following strategies, tried in order:

  1. JAX device stats (for GPU/TPU backends with memory tracking)
  2. psutil for CPU/system memory
  3. A conservative 4 GB fallback, with a warning
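
A sketch of this fallback chain follows. It assumes the JAX backend reports `bytes_limit` and `bytes_in_use` through `Device.memory_stats()` (CPU backends typically return `None`), treats a GB as 1024**3 bytes, and uses a hypothetical function name; the library's real detection logic may differ.

```python
import warnings

FALLBACK_GB = 4.0  # documented conservative fallback
MIN_GB = 0.5       # documented floor on the returned value


def available_memory_gb_sketch() -> float:
    """Hypothetical sketch of the documented three-step fallback chain."""
    gb = None
    # 1. JAX device stats (GPU/TPU backends that track allocator memory).
    try:
        import jax
        stats = jax.devices()[0].memory_stats()  # often None on CPU
        if stats and "bytes_limit" in stats:
            free = stats["bytes_limit"] - stats.get("bytes_in_use", 0)
            gb = free / 1024**3
    except Exception:
        pass  # JAX missing or backend without memory stats
    # 2. psutil for CPU/system memory.
    if gb is None:
        try:
            import psutil
            gb = psutil.virtual_memory().available / 1024**3
        except ImportError:
            pass
    # 3. Conservative fallback, with a warning.
    if gb is None:
        warnings.warn("Could not detect available memory; assuming 4 GB.")
        gb = FALLBACK_GB
    # Enforce the documented 0.5 GB minimum.
    return max(float(gb), MIN_GB)


print(f"Available: {available_memory_gb_sketch():.1f} GB")
```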

Returns:

| Type | Description |
|---|---|
| `float` | Available memory in GB (minimum 0.5 GB). |

Examples:

>>> mem_gb = get_available_memory_gb()
>>> print(f"Available: {mem_gb:.1f} GB")