Memory-aware batch sizing for JAX operations.
Functions:
| Name | Description |
|---|---|
| compute_batch_size | Compute optimal batch size for jax.lax.map. |
| get_available_memory_gb | Query available system memory in GB. |
Attributes
Functions
compute_batch_size
compute_batch_size(*, n_items: int, bytes_per_item: int, max_mem: float | None = None, min_batch: int = MIN_BATCH_SIZE, max_batch: int = MAX_BATCH_SIZE) -> int
Compute optimal batch size for jax.lax.map.
Balances memory usage, parallelism efficiency, and cache locality. Uses a safety margin to account for XLA compilation overhead and intermediate arrays.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_items | int | Total number of items to process (e.g., n_boot, n_perm). | required |
| bytes_per_item | int | Memory per output item, in bytes. For bootstrap or permutation with p float64 coefficients, this is p * 8. | required |
| max_mem | float \| None | Fraction of available system memory to use. None defaults to 0.5 (50%); values outside [0.01, 1.0] are clamped into that range. | None |
| min_batch | int | Minimum batch size, for parallelism efficiency. | MIN_BATCH_SIZE |
| max_batch | int | Maximum batch size, for cache efficiency. | MAX_BATCH_SIZE |
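The `bytes_per_item` rule above (p coefficients times 8 bytes for float64) generalizes to any output shape and dtype. A minimal sketch, assuming each mapped call returns one or more arrays whose shapes are known in advance; `bytes_per_output` is a hypothetical helper, not part of this module:

```python
import math
from typing import Sequence, Tuple


def bytes_per_output(shapes: Sequence[Tuple[int, ...]], itemsize: int = 8) -> int:
    """Hypothetical helper: bytes produced by ONE call of the mapped function.

    `shapes` lists the shape of every array the function returns; `itemsize`
    is the number of bytes per element (8 for float64, 4 for float32).
    """
    # Each array contributes (number of elements) * (bytes per element).
    return sum(math.prod(shape) * itemsize for shape in shapes)


# Bootstrap with p = 100 float64 coefficients: p * 8 bytes.
print(bytes_per_output([(100,)]))  # 800
# Coefficients plus a 100x100 covariance matrix per replicate, in float32.
print(bytes_per_output([(100,), (100, 100)], itemsize=4))  # 40400
```

If JAX is installed, the output shapes and dtypes can be obtained without running the function by calling `jax.eval_shape` on a single input.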
Returns:
| Type | Description |
|---|---|
| int | Batch size clamped to [min_batch, min(max_batch, n_items)]. |
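The description says the result balances a memory budget against `min_batch` and `max_batch` and applies a safety margin, but it does not give the formula. The sketch below is one plausible reading, not the library's implementation. The `safety_factor` of 4, the 2**30-bytes-per-GB conversion, and the illustrative bounds of 16 and 1024 are all assumptions (the real `MIN_BATCH_SIZE` and `MAX_BATCH_SIZE` values are not shown here), and available memory is passed in rather than queried. `sketch_batch_size` is a hypothetical name.

```python
from typing import Optional


def sketch_batch_size(
    *,
    n_items: int,
    bytes_per_item: int,
    available_gb: float,
    max_mem: Optional[float] = None,
    min_batch: int = 16,  # hypothetical stand-in for MIN_BATCH_SIZE
    max_batch: int = 1024,  # hypothetical stand-in for MAX_BATCH_SIZE
    safety_factor: float = 4.0,  # assumed headroom for XLA temporaries
) -> int:
    """Illustrative heuristic only; not the module's actual formula."""
    if n_items <= 0 or bytes_per_item <= 0:
        raise ValueError("n_items and bytes_per_item must be positive")
    # Documented: None means 50%, and the fraction is clamped to [0.01, 1.0].
    fraction = 0.5 if max_mem is None else min(max(max_mem, 0.01), 1.0)
    # Budget in bytes (assumes 1 GB = 2**30 bytes), shrunk by the safety factor.
    budget = available_gb * 2**30 * fraction / safety_factor
    # How many items' outputs fit in the budget at once.
    raw = int(budget // bytes_per_item)
    # Documented clamp to [min_batch, min(max_batch, n_items)]. If n_items < min_batch
    # the interval is empty; this sketch lets the upper bound win (an assumption).
    upper = min(max_batch, n_items)
    return min(max(raw, min_batch), upper)


print(sketch_batch_size(n_items=10_000, bytes_per_item=800, available_gb=16.0))  # 1024
print(sketch_batch_size(n_items=10_000, bytes_per_item=800, available_gb=0.001))  # 167
```

With 16 GB the budget is far larger than needed, so `max_batch` decides; with a tiny budget, memory becomes the binding constraint and the batch shrinks accordingly.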
Examples:
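>>> import jax
>>> # 10000 bootstrap samples, each producing 100 float64 coefficients
>>> batch_size = compute_batch_size(n_items=10000, bytes_per_item=100 * 8)
>>> key = jax.random.PRNGKey(0)
>>> keys = jax.random.split(key, 10000)
>>> def fn(k):  # stand-in for one bootstrap replicate
...     return jax.random.normal(k, (100,))
>>> boot_samples = jax.lax.map(fn, keys, batch_size=batch_size)
get_available_memory_gb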
get_available_memory_gb() -> float
Query available system memory in GB.
Detects available memory by trying the following strategies in order, falling back to the next when one is unavailable:
1. JAX device stats (for GPU/TPU backends that track memory)
2. psutil, for CPU/system memory
3. A conservative 4 GB fallback, with a warning
Returns:
| Type | Description |
|---|---|
| float | Available memory in GB (minimum 0.5 GB). |
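The fallback chain and the 0.5 GB floor can be sketched as follows. This is an illustration under stated assumptions, not the module's code: it treats the `bytes_limit` and `bytes_in_use` entries of the first local JAX device's `memory_stats()` as the accelerator budget, uses `psutil.virtual_memory().available` for system RAM, and assumes 1 GB = 2**30 bytes. `sketch_available_memory_gb` is a hypothetical name.

```python
import warnings


def sketch_available_memory_gb(fallback_gb: float = 4.0, floor_gb: float = 0.5) -> float:
    """Hypothetical re-creation of the documented fallback chain."""
    gib = 2**30  # assumption: "GB" means 2**30 bytes
    # 1. JAX device stats: accelerators may report a limit and current usage.
    try:
        import jax

        stats = jax.local_devices()[0].memory_stats()  # often None on CPU
        if stats and "bytes_limit" in stats:
            free = stats["bytes_limit"] - stats.get("bytes_in_use", 0)
            return max(free / gib, floor_gb)
    except Exception:  # JAX missing or stats unsupported: try the next strategy
        pass
    # 2. psutil: RAM currently available to new allocations.
    try:
        import psutil

        return max(psutil.virtual_memory().available / gib, floor_gb)
    except ImportError:
        pass
    # 3. Conservative constant, with a warning so the guess is visible.
    warnings.warn(f"Could not detect available memory; assuming {fallback_gb} GB")
    return max(float(fallback_gb), floor_gb)
```

Catching a broad `Exception` in the first step keeps an unsupported backend from aborting detection; the later steps only need to guard against a missing package.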
Examples:
>>> mem_gb = get_available_memory_gb()
>>> print(f"Available: {mem_gb:.1f} GB")