
Unified RNG abstraction for JAX and NumPy backends.

Classes:

| Name | Description |
| --- | --- |
| RNG | Unified RNG wrapper for JAX and NumPy backends. |

Functions:

| Name | Description |
| --- | --- |
| build_rng | Create RNG from seed (convenience function). |

Classes

RNG

RNG(key: Any, backend: str | None = None)

Unified RNG wrapper for JAX and NumPy backends.

This class wraps either a JAX PRNGKey or a NumPy Generator, providing a consistent interface for random number generation across backends.

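The key insight is that JAX uses a functional RNG model (immutable keys that are split to produce new keys), while NumPy uses a stateful one (mutable generators whose internal state advances on every draw). This class bridges the gap by giving both backends the same interface: split an RNG to obtain independent child RNGs, then draw samples from a child through helper methods such as normal and uniform.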
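
To make the two models concrete, here is a NumPy-only sketch; it is not this class's implementation. A stateful Generator returns new values on every call, while the functional style derives each draw from an explicit key and gets fresh keys by splitting. The helpers functional_normal and functional_split are hypothetical: they mimic JAX-style keys using numpy.random.SeedSequence objects that are never mutated.

```python
from __future__ import annotations

import numpy as np

# Stateful model (NumPy): every draw advances the generator's hidden state.
gen = np.random.default_rng(42)
first = gen.standard_normal(3)
second = gen.standard_normal(3)  # a different draw from the same generator


# Functional model (JAX-like), emulated with a SeedSequence used as a "key".
def functional_normal(key: np.random.SeedSequence, shape) -> np.ndarray:
    # Output depends only on the key: reusing a key repeats the values.
    return np.random.default_rng(key).standard_normal(shape)


def functional_split(key: np.random.SeedSequence, n: int = 2) -> list[np.random.SeedSequence]:
    # Build children from the parent's entropy and spawn_key without mutating
    # the parent (unlike SeedSequence.spawn, which advances a child counter).
    return [
        np.random.SeedSequence(key.entropy, spawn_key=key.spawn_key + (i,))
        for i in range(n)
    ]


key = np.random.SeedSequence(42)
a = functional_normal(key, (3,))
b = functional_normal(key, (3,))  # identical to `a`: the key was not consumed
k1, k2 = functional_split(key)
c = functional_normal(k1, (3,))  # a new stream derived from a child key
```

The functional version trades convenience for explicitness: nothing changes behind the caller's back, which is what makes JAX keys safe to use inside transformed and parallel code.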

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| key | Any | The underlying random state. For JAX backend, this is a PRNGKey that can be passed directly to jax.random functions. For NumPy backend, this is a Generator (but prefer using the helper methods). |

Examples:

>>> rng = RNG.from_seed(42)
>>> rng1, rng2 = rng.split()
>>> values = rng1.normal(shape=(100,))

Functions:

| Name | Description |
| --- | --- |
| bernoulli | Generate Bernoulli random values. |
| choice | Random choice from [0, 1, ..., n-1]. |
| from_seed | Create RNG from integer seed. |
| normal | Generate standard normal random values. |
| permutation | Generate random permutation of [0, 1, ..., n-1]. |
| poisson | Generate Poisson random values. |
| split | Split RNG into n independent RNGs. |
| split_one | Split into two RNGs, returning (new_self, child). |
| uniform | Generate uniform random values. |

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| key | Any | JAX PRNGKey or NumPy Generator. | required |
| backend | str \| None | Backend name ("jax" or "numpy"). If None, the backend is auto-detected from key. | None |
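
The page does not say how auto-detection works. One plausible approach, sketched below under that assumption, is to inspect the type of key. Both detect_backend and resolve_backend are hypothetical helpers, and the module-name check for JAX arrays is a heuristic chosen so that JAX need not be imported just to detect it.

```python
from __future__ import annotations

import numpy as np


def detect_backend(key) -> str:
    """Infer "numpy" or "jax" from the type of `key` (hypothetical helper)."""
    if isinstance(key, np.random.Generator):
        return "numpy"
    # JAX arrays (including PRNG keys) are defined in modules whose names start
    # with "jax" (for example "jaxlib..."), so no JAX import is needed here.
    if type(key).__module__.startswith("jax"):
        return "jax"
    raise TypeError(f"cannot infer backend from key of type {type(key).__name__}")


def resolve_backend(key, backend: str | None = None) -> str:
    """Honor an explicit backend name; otherwise fall back to auto-detection."""
    if backend is None:
        return detect_backend(key)
    if backend not in ("jax", "numpy"):
        raise ValueError(f"unknown backend {backend!r}; expected 'jax' or 'numpy'")
    return backend
```

Raising on unrecognized key types, rather than silently defaulting to one backend, surfaces mistakes such as passing a raw integer seed where a key is expected.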

Attributes

key
key: Any

Get the underlying key/generator.

For JAX backend, returns a PRNGKey suitable for jax.random functions. For NumPy backend, returns a Generator.

Examples:

>>> import jax
>>> rng = RNG.from_seed(42)
>>> # JAX backend only: here rng.key is a PRNGKey
>>> values = jax.random.normal(rng.key, shape=(100,))

Functions

bernoulli
bernoulli(p: float | Any, shape: tuple[int, ...] | None = None) -> Any

Generate Bernoulli random values.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| p | float \| Any | Probability of True/1. Can be a scalar or an array of probabilities for element-wise sampling. | required |
| shape | tuple[int, ...] \| None | Shape of output array. If None and p is an array, samples one value per element of p. If None and p is scalar, returns a single sample. | None |

Returns:

| Type | Description |
| --- | --- |
| Any | Boolean array (JAX) or int array 0/1 (NumPy). Shape is determined by the shape parameter if provided, otherwise by p's shape. |

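
On the JAX side, jax.random.bernoulli(key, p, shape) already returns booleans. For the NumPy side, a minimal sketch of the documented 0/1 behaviour (an assumption about how it could be done, not the library's code) is a binomial draw with one trial; the helper name bernoulli_numpy is hypothetical.

```python
import numpy as np


def bernoulli_numpy(gen: np.random.Generator, p, shape=None):
    """Draw Bernoulli(p) samples as 0/1 integers (hypothetical NumPy-side helper)."""
    # A binomial draw with n=1 is a Bernoulli draw. With size=None, NumPy returns
    # one sample per element of an array `p`, or a single integer for a scalar `p`;
    # an explicit `shape` must be broadcast-compatible with `p`.
    return gen.binomial(1, p, size=shape)


gen = np.random.default_rng(0)
probs = np.array([0.0, 1.0, 0.25, 0.75])
single = bernoulli_numpy(gen, 0.5)                 # one 0/1 sample
per_element = bernoulli_numpy(gen, probs)          # shape (4,), one draw per p
grid = bernoulli_numpy(gen, probs, shape=(3, 4))   # p broadcast across rows
```
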
choice
choice(n: int, shape: tuple[int, ...], replace: bool = True) -> Any

Random choice from [0, 1, ..., n-1].

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| n | int | Upper bound (exclusive) for choices. | required |
| shape | tuple[int, ...] | Shape of output array. | required |
| replace | bool | Whether to sample with replacement. | True |

Returns:

| Type | Description |
| --- | --- |
| Any | Array of shape `shape` with random integers in [0, n). |

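
Sampling without replacement has a constraint the table leaves implicit: the total number of requested values cannot exceed n. The NumPy-only sketch below illustrates this with Generator.choice (JAX offers jax.random.choice for the same purpose); the helper choice_numpy is hypothetical and only assumes the NumPy backend behaves like Generator.choice.

```python
from __future__ import annotations

import numpy as np


def choice_numpy(gen: np.random.Generator, n: int, shape: tuple[int, ...], replace: bool = True):
    """Sample integers from [0, n) with NumPy (hypothetical helper)."""
    # Passing an int as the population makes NumPy sample from np.arange(n).
    # Without replacement, the total sample size cannot exceed n; NumPy raises
    # ValueError in that case.
    return gen.choice(n, size=shape, replace=replace)


gen = np.random.default_rng(7)
with_replacement = choice_numpy(gen, 10, (2, 3))                    # repeats allowed
without_replacement = choice_numpy(gen, 10, (2, 3), replace=False)  # 6 distinct values
```
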
from_seed
from_seed(seed: int | None = None) -> RNG

Create RNG from integer seed.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| seed | int \| None | Integer seed for reproducibility. If None, a random seed is used. | None |

Returns:

| Type | Description |
| --- | --- |
| RNG | New RNG instance. |

Examples:

>>> rng = RNG.from_seed(42)
>>> rng.split(n=3)  # Returns list of 3 RNGs
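
The following NumPy-only sketch shows the two seeding cases, assuming (the source does not say) that the NumPy backend seeds its Generator the way np.random.default_rng does. Passing the same integer twice reproduces the stream; for seed=None, keeping the entropy that NumPy drew makes the run repeatable afterwards.

```python
import numpy as np

# The same integer seed reproduces the same stream.
x = np.random.default_rng(42).standard_normal(4)
y = np.random.default_rng(42).standard_normal(4)

# With no seed, NumPy gathers fresh entropy from the operating system.
# Recording that entropy lets an "unseeded" run be replayed later.
seq = np.random.SeedSequence()
run = np.random.default_rng(seq).standard_normal(4)
saved_seed = seq.entropy  # e.g. write this to the run's log
replay = np.random.default_rng(saved_seed).standard_normal(4)
```
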
normal
normal(shape: tuple[int, ...]) -> Any

Generate standard normal random values.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| shape | tuple[int, ...] | Shape of output array. | required |

Returns:

| Type | Description |
| --- | --- |
| Any | Array of shape `shape` with N(0, 1) values. |

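
Standard normal draws correspond to Generator.standard_normal in NumPy and jax.random.normal(key, shape) in JAX. Because normal only produces N(0, 1) values, other normal distributions come from an affine transform, mu + sigma * z. The sketch below uses plain NumPy rather than this class and checks both properties on a large sample.

```python
import numpy as np

gen = np.random.default_rng(0)
z = gen.standard_normal((100_000,))  # N(0, 1) samples

# Only the standard normal is provided; shift and scale for N(mu, sigma**2).
mu, sigma = 5.0, 2.0
x = mu + sigma * z
```

The same transform applies to the output of rng.normal(shape=...).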
permutation
permutation(n: int) -> Any

Generate random permutation of [0, 1, ..., n-1].

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| n | int | Length of permutation. | required |

Returns:

| Type | Description |
| --- | --- |
| Any | Array of shape (n,) containing a random permutation. |

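
A common use of a permutation of indices, rather than shuffling data in place, is to reorder several paired arrays with one shared permutation so their rows stay aligned. This NumPy-only sketch uses Generator.permutation (JAX has jax.random.permutation); the features and labels arrays are illustrative data.

```python
import numpy as np

gen = np.random.default_rng(3)
perm = gen.permutation(5)  # a shuffled copy of np.arange(5)

# Row i of `features` is [2*i, 2*i + 1] and belongs to label i.
features = np.arange(10).reshape(5, 2)
labels = np.arange(5)

# Indexing both arrays with the same permutation keeps each row with its label.
shuffled_features = features[perm]
shuffled_labels = labels[perm]
```
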
poisson
poisson(lam: Any) -> Any

Generate Poisson random values.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| lam | Any | Rate parameter (can be scalar or array). | required |

Returns:

| Type | Description |
| --- | --- |
| Any | Array of the same shape as `lam` with Poisson samples. |

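
Because the output shape follows lam, several rates can be sampled in one call, and many draws at a single rate come from repeating that rate. This is a NumPy-only sketch with Generator.poisson (JAX provides jax.random.poisson); it assumes the NumPy backend behaves like Generator.poisson, which requires every rate to be nonnegative.

```python
import numpy as np

gen = np.random.default_rng(1)
lam = np.array([[0.0, 1.0], [4.0, 10.0]])
counts = gen.poisson(lam)  # same shape as lam: (2, 2); a rate of 0 always gives 0

# Many draws at one rate: repeat the rate; the sample mean approaches lam.
many = gen.poisson(np.full(50_000, 4.0))
```
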
split
split(n: int = 2) -> list[RNG]

Split RNG into n independent RNGs.

This is the key operation for parallel random number generation: each returned RNG is independent of the others, so the RNGs can be handed to separate parallel tasks.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| n | int | Number of RNGs to create. | 2 |

Returns:

| Type | Description |
| --- | --- |
| list[RNG] | List of n independent RNG objects. |

Examples:

>>> rng = RNG.from_seed(42)
>>> rng1, rng2 = rng.split()
>>> rngs = rng.split(n=10)  # e.g. one RNG per parallel task

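
Splitting is native in JAX (jax.random.split). For NumPy, the standard way to derive independent generators from one seed is numpy.random.SeedSequence.spawn. The sketch below assumes, without confirmation from the source, that a NumPy backend could work this way; split_numpy is a hypothetical helper.

```python
from __future__ import annotations

import numpy as np


def split_numpy(seed: int, n: int) -> list[np.random.Generator]:
    """Derive n independent Generators from one seed (hypothetical helper)."""
    # SeedSequence.spawn creates child seed sequences designed to give each
    # child its own independent stream, deterministically from the parent seed.
    children = np.random.SeedSequence(seed).spawn(n)
    return [np.random.default_rng(child) for child in children]


workers = split_numpy(42, 10)  # e.g. one Generator per parallel task
draws = [g.standard_normal() for g in workers]
```

Since the parent seed fully determines the children, rerunning with seed 42 gives every worker the same stream as before, regardless of how the tasks are scheduled.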
split_one
split_one() -> tuple[RNG, RNG]

Split into two RNGs, returning (new_self, child).

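This supports the common pattern of advancing the current RNG while obtaining one child RNG to consume. Rebind the first element to the original name, as in the example below, so the old state is not reused.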

Returns:

| Type | Description |
| --- | --- |
| tuple[RNG, RNG] | Tuple of (new_self, child) RNGs. |

Examples:

>>> rng = RNG.from_seed(42)
>>> rng, child = rng.split_one()
>>> values = child.normal(shape=(100,))
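
The rebinding pattern matters most in loops. The sketch below uses a hypothetical NumPy-only stand-in, TinyRNG, whose split_one and normal mimic the functional semantics described on this page (it is not the RNG class). Each iteration rebinds rng and consumes only the child, so every batch differs, while a reused child would repeat its values.

```python
from __future__ import annotations

import numpy as np


class TinyRNG:
    """Hypothetical NumPy-only stand-in with split_one-style semantics."""

    def __init__(self, seq: np.random.SeedSequence):
        self._seq = seq

    @classmethod
    def from_seed(cls, seed: int) -> TinyRNG:
        return cls(np.random.SeedSequence(seed))

    def split_one(self) -> tuple[TinyRNG, TinyRNG]:
        # Derive two children deterministically; the parent itself is never mutated.
        a, b = (
            np.random.SeedSequence(self._seq.entropy, spawn_key=self._seq.spawn_key + (i,))
            for i in range(2)
        )
        return TinyRNG(a), TinyRNG(b)

    def normal(self, shape: tuple[int, ...]) -> np.ndarray:
        # Values depend only on this RNG's seed sequence, as with a JAX key.
        return np.random.default_rng(self._seq).standard_normal(shape)


def sample_batches(seed: int, n_batches: int) -> list[np.ndarray]:
    rng = TinyRNG.from_seed(seed)
    batches = []
    for _ in range(n_batches):
        rng, child = rng.split_one()  # rebind `rng` so the next iteration differs
        batches.append(child.normal((2,)))
    return batches


batches = sample_batches(42, 3)
```
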
uniform
uniform(shape: tuple[int, ...], minval: float = 0.0, maxval: float = 1.0) -> Any

Generate uniform random values.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| shape | tuple[int, ...] | Shape of output array. | required |
| minval | float | Minimum value (inclusive). | 0.0 |
| maxval | float | Maximum value (exclusive). | 1.0 |

Returns:

| Type | Description |
| --- | --- |
| Any | Array of shape `shape` with U(minval, maxval) values. |
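
Both reference libraries use the half-open interval [minval, maxval) documented above: jax.random.uniform(key, shape, minval=..., maxval=...) in JAX and Generator.uniform(low, high, size) in NumPy. This NumPy-only sketch also shows the equivalent shift-and-scale construction from U(0, 1) draws; it is an illustration, not this class's implementation.

```python
import numpy as np

gen = np.random.default_rng(5)
minval, maxval = -2.0, 3.0

# Direct sampling on [minval, maxval).
u = gen.uniform(low=minval, high=maxval, size=(4, 3))

# Equivalent construction: scale and shift U(0, 1) draws from [0, 1).
v = minval + (maxval - minval) * gen.random((4, 3))
```
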

Functions

build_rng

build_rng(seed: int | None = None) -> RNG

Create RNG from seed (convenience function).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| seed | int \| None | Integer seed. If None, a random seed is used. | None |

Returns:

| Type | Description |
| --- | --- |
| RNG | New RNG instance. |

Examples:

>>> rng = build_rng(42)
>>> values = rng.normal(shape=(100,))