Unified RNG abstraction for JAX and NumPy backends.
Classes:
| Name | Description |
|---|---|
| RNG | Unified RNG wrapper for JAX and NumPy backends. |
Functions:
| Name | Description |
|---|---|
| build_rng | Create RNG from seed (convenience function). |
Classes
RNG
RNG(key: Any, backend: str | None = None)
Unified RNG wrapper for JAX and NumPy backends.
This class wraps either a JAX PRNGKey or a NumPy Generator, providing a consistent interface for random number generation across backends.
The key insight is that JAX uses functional RNG (immutable keys that are split to produce new keys), while NumPy uses stateful RNG (mutable generators). This class bridges the gap by:
- JAX backend: wraps a PRNGKey; split() returns new RNG objects with split keys.
- NumPy backend: wraps a Generator; split() returns new RNG objects with spawned generators (using SeedSequence.spawn).
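The NumPy side of this bridge can be sketched without JAX installed. `SketchRNG` below is a hypothetical stand-in, not the library's `RNG` class; it assumes the wrapper keeps a `SeedSequence` next to its `Generator` so that children can be spawned from it. On the JAX side, the analogous step would be `jax.random.split(key, n)`.

```python
import numpy as np


class SketchRNG:
    """Hypothetical stand-in for a NumPy-backed RNG wrapper (illustration only)."""

    def __init__(self, seed_seq: np.random.SeedSequence):
        # Keep the SeedSequence so that children can be spawned from it later.
        self._seed_seq = seed_seq
        self.key = np.random.default_rng(seed_seq)

    @classmethod
    def from_seed(cls, seed=None):
        # seed=None lets NumPy pull fresh entropy from the operating system.
        return cls(np.random.SeedSequence(seed))

    def split(self, n=2):
        # spawn() derives n child seed sequences, one per new wrapper.
        return [SketchRNG(child) for child in self._seed_seq.spawn(n)]

    def normal(self, shape):
        return self.key.standard_normal(shape)


rng = SketchRNG.from_seed(42)
rng1, rng2 = rng.split()
a = rng1.normal((3,))
b = rng2.normal((3,))

# Re-creating the parent from the same seed reproduces the same first child.
a_again = SketchRNG.from_seed(42).split()[0].normal((3,))
```

The two children draw from different streams, while rebuilding the parent from the same seed reproduces them. This is the property that lets a stateful generator mimic key splitting.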
Attributes:
| Name | Type | Description |
|---|---|---|
| key | Any | The underlying random state. For JAX backend, this is a PRNGKey that can be passed directly to jax.random functions. For NumPy backend, this is a Generator (but prefer using the helper methods). |
Examples:
>>> rng = RNG.from_seed(42)
>>> rng1, rng2 = rng.split()
>>> values = rng1.normal(shape=(100,))
Functions:
| Name | Description |
|---|---|
| bernoulli | Generate Bernoulli random values. |
| choice | Random choice from [0, 1, ..., n-1]. |
| from_seed | Create RNG from integer seed. |
| normal | Generate standard normal random values. |
| permutation | Generate random permutation of [0, 1, ..., n-1]. |
| poisson | Generate Poisson random values. |
| split | Split RNG into n independent RNGs. |
| split_one | Split into two RNGs, returning (new_self, child). |
| uniform | Generate uniform random values. |
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| key | Any | JAX PRNGKey or NumPy Generator. | required |
| backend | str \| None | Backend name (“jax” or “numpy”). If None, auto-detects. | None |
Attributes
key
key: Any
Get the underlying key/generator.
For JAX backend, returns a PRNGKey suitable for jax.random functions. For NumPy backend, returns a Generator.
Examples:
>>> rng = RNG.from_seed(42)
>>> # JAX usage
>>> values = jax.random.normal(rng.key, shape=(100,))
Functions
bernoulli
bernoulli(p: float | Any, shape: tuple[int, ...] | None = None) -> Any
Generate Bernoulli random values.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| p | float \| Any | Probability of True/1. Can be a scalar or an array of probabilities for element-wise sampling. | required |
| shape | tuple[int, ...] \| None | Shape of output array. If None and p is an array, samples one value per element of p. If None and p is scalar, returns a single sample. | None |
Returns:
| Type | Description |
|---|---|
| Any | Boolean array (JAX) or int array 0/1 (NumPy). Shape is determined by the shape parameter if provided, otherwise by p’s shape. |
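The shape rules above can be illustrated with plain NumPy. `bernoulli_sketch` is a hypothetical helper, and drawing through `Generator.binomial` with one trial is an assumption chosen because it matches the documented 0/1 integer output; the library's actual implementation may differ.

```python
import numpy as np


def bernoulli_sketch(gen, p, shape=None):
    """Hypothetical NumPy-only helper mirroring the documented shape rules."""
    # A binomial draw with one trial is a Bernoulli draw. With size=None,
    # NumPy follows the shape of p, so an array p yields one sample per element.
    return gen.binomial(1, p, size=shape)


gen = np.random.default_rng(0)
probs = np.array([0.0, 1.0, 0.5])

# Array p, no shape: output shape follows p, here (3,).
per_element = bernoulli_sketch(gen, probs)

# Scalar p with an explicit shape: output shape follows `shape`.
explicit = bernoulli_sketch(gen, 0.3, shape=(2, 4))

# Scalar p, no shape: a single sample.
single = bernoulli_sketch(gen, 0.3)
```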
choice
choice(n: int, shape: tuple[int, ...], replace: bool = True) -> Any
Random choice from [0, 1, ..., n-1].
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n | int | Upper bound (exclusive) for choices. | required |
| shape | tuple[int, ...] | Shape of output array. | required |
| replace | bool | Whether to sample with replacement. | True |
Returns:
| Type | Description |
|---|---|
| Any | Array with the given shape, containing random integers in [0, n). |
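The effect of replace is easiest to see on the NumPy side. The sketch below calls `numpy.random.Generator.choice` directly rather than the wrapper, so it shows NumPy's behavior only; the wrapper is assumed to forward to something similar.

```python
import numpy as np

gen = np.random.default_rng(0)

# With replacement: any output shape is allowed and values may repeat.
with_repl = gen.choice(5, size=(2, 6), replace=True)

# Without replacement: the output cannot hold more than n entries.
without_repl = gen.choice(5, size=(5,), replace=False)

try:
    gen.choice(5, size=(6,), replace=False)
    oversampled_ok = True
except ValueError:
    # NumPy refuses to draw 6 distinct values from a population of 5.
    oversampled_ok = False
```

When sampling without replacement, the total number of requested elements therefore has to stay at or below n.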
from_seed
from_seed(seed: int | None = None) -> RNG
Create RNG from integer seed.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| seed | int \| None | Integer seed for reproducibility. If None, uses random seed. | None |
Returns:
| Type | Description |
|---|---|
| RNG | New RNG instance. |
Examples:
>>> rng = RNG.from_seed(42)
>>> rng.split(n=3) # Returns list of 3 RNGs
normal
normal(shape: tuple[int, ...]) -> Any
Generate standard normal random values.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| shape | tuple[int, ...] | Shape of output array. | required |
Returns:
| Type | Description |
|---|---|
| Any | Array with the given shape, containing N(0, 1) values. |
permutation
permutation(n: int) -> Any
Generate random permutation of [0, 1, ..., n-1].
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n | int | Length of permutation. | required |
Returns:
| Type | Description |
|---|---|
| Any | Array of shape (n,) containing a random permutation. |
poisson
poisson(lam: Any) -> Any
Generate Poisson random values.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| lam | Any | Rate parameter (can be scalar or array). | required |
Returns:
| Type | Description |
|---|---|
| Any | Array of same shape as lam with Poisson samples. |
split
split(n: int = 2) -> list[RNG]
Split RNG into n independent RNGs.
This is the key operation for parallel random number generation. Each returned RNG is independent and can be used in parallel.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n | int | Number of RNGs to create (default: 2). | 2 |
Returns:
| Type | Description |
|---|---|
| list[RNG] | List of n independent RNG objects. |
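A minimal sketch of the parallel use case, written against plain NumPy rather than the wrapper. `worker_mean` and `run` are hypothetical names, and `SeedSequence.spawn` stands in for what `split(n)` is described as doing on the NumPy backend.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def worker_mean(seed_seq):
    # Each task builds its own generator from its own child seed sequence,
    # so no generator state is shared between threads.
    gen = np.random.default_rng(seed_seq)
    return float(gen.standard_normal(1000).mean())


def run(seed, n_workers=4):
    children = np.random.SeedSequence(seed).spawn(n_workers)
    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        # map() returns results in input order regardless of scheduling.
        return list(pool.map(worker_mean, children))


first = run(42)
second = run(42)
```

Because every worker owns its stream, the results depend only on the seed and the worker index, not on the order in which the threads happen to run.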
Examples:
>>> rng = RNG.from_seed(42)
>>> rng1, rng2 = rng.split()
>>> keys = rng.split(n=10) # For parallel operations
split_one
split_one() -> tuple[RNG, RNG]
Split into two RNGs, returning (new_self, child).
A common pattern: rebind the current RNG to new_self and use the child for sampling.
Returns:
| Type | Description |
|---|---|
| tuple[RNG, RNG] | Tuple of (new_self, child) RNGs. |
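The advance-and-consume pattern is mostly useful inside loops. The sketch below mimics it with NumPy seed sequences; `split_one_sketch` is a hypothetical helper and is only assumed to resemble how the wrapper behaves on the NumPy backend.

```python
import numpy as np


def split_one_sketch(seed_seq):
    """Hypothetical helper: return (new_self, child) seed sequences."""
    new_self, child = seed_seq.spawn(2)
    return new_self, child


state = np.random.SeedSequence(42)
draws = []
for _ in range(3):
    # Rebind `state` each iteration so the next split starts from a new parent.
    state, child = split_one_sketch(state)
    draws.append(np.random.default_rng(child).standard_normal(2))
```

Each iteration consumes a different child, so no two steps of the loop reuse the same random stream.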
Examples:
>>> rng = RNG.from_seed(42)
>>> rng, child = rng.split_one()
>>> values = child.normal(shape=(100,))
uniform
uniform(shape: tuple[int, ...], minval: float = 0.0, maxval: float = 1.0) -> Any
Generate uniform random values.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| shape | tuple[int, ...] | Shape of output array. | required |
| minval | float | Minimum value (inclusive). | 0.0 |
| maxval | float | Maximum value (exclusive). | 1.0 |
Returns:
| Type | Description |
|---|---|
| Any | Array with the given shape, containing U(minval, maxval) values. |
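The half-open interval [minval, maxval) can be checked with plain NumPy. The sketch calls `Generator.uniform` and `Generator.random` directly, not the wrapper, and the bounds are arbitrary illustrative values.

```python
import numpy as np

gen = np.random.default_rng(0)
minval, maxval = -2.0, 3.0

# Direct draw from the half-open interval [minval, maxval).
direct = gen.uniform(minval, maxval, size=(1000,))

# Equivalent construction: rescale U(0, 1) samples.
rescaled = minval + (maxval - minval) * gen.random((1000,))
```

Both arrays follow the same distribution; the second form makes explicit that U(minval, maxval) is a shifted and scaled copy of U(0, 1).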
Functions
build_rng
build_rng(seed: int | None = None) -> RNG
Create RNG from seed (convenience function).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| seed | int \| None | Integer seed. If None, uses random seed. | None |
Returns:
| Type | Description |
|---|---|
| RNG | New RNG instance. |
Examples:
>>> rng = build_rng(42)
>>> values = rng.normal(shape=(100,))