Backend abstraction layer for JAX/NumPy compatibility.
Attributes:
| Name | Type | Description |
|---|---|---|
BackendName | Literal['jax', 'numpy'] | Supported backend names. |
Classes:
| Name | Description |
|---|---|
ArrayOps | Protocol for array operations across backends. |
Functions:
| Name | Description |
|---|---|
backend | Context manager for temporary backend switching. |
clear_ops_cache | Clear the backend operations cache. |
get_backend | Get the current backend name. |
get_ops | Get array operations for the current backend. |
lock_backend | Lock the backend to prevent switching after model fitting. |
reset_backend | Reset backend state (for testing only). |
set_backend | Set the backend to use for computations. |
Modules:
| Name | Description |
|---|---|
dispatch | Backend detection, switching, and ArrayOps dispatch. |
jax | JAX backend implementation. |
numpy | NumPy backend implementation. |
protocol | Array operations protocol for backend abstraction. |
Attributes¶
BackendName¶
BackendName = Literal['jax', 'numpy']
Classes¶
ArrayOps¶
Bases: Protocol
Protocol for array operations across backends.
This defines the interface that both NumPyBackend and JAXBackend must implement. All linear algebra and array manipulation needed by bossanova should go through this interface.
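Because ArrayOps is a typing.Protocol, conformance is structural: any object with matching methods satisfies it without inheriting from it, so model code can be written once against the interface. The sketch below illustrates that idea under stated assumptions. MiniOps is a hypothetical three-method subset of ArrayOps and MiniNumpyOps is a hypothetical NumPy stand-in, not bossanova's NumPyBackend. The example solves ordinary least squares through the Cholesky factor of X.T @ X using only protocol methods.

```python
from typing import Any, Protocol

import numpy as np


class MiniOps(Protocol):
    """Hypothetical subset of ArrayOps, for illustration only."""

    def asarray(self, x: Any, dtype: Any = None) -> Any: ...
    def cholesky(self, a: Any) -> Any: ...
    def solve_triangular(self, a: Any, b: Any, lower: bool = False) -> Any: ...


class MiniNumpyOps:
    """Hypothetical NumPy stand-in; it never inherits from MiniOps."""

    def asarray(self, x: Any, dtype: Any = None) -> np.ndarray:
        return np.asarray(x, dtype=dtype)

    def cholesky(self, a: np.ndarray) -> np.ndarray:
        # NumPy returns the lower-triangular factor L with a = L @ L.T.
        return np.linalg.cholesky(a)

    def solve_triangular(self, a: np.ndarray, b: np.ndarray, lower: bool = False) -> np.ndarray:
        # A general solver keeps the sketch short; a real backend would exploit triangularity.
        return np.linalg.solve(a, b)


def ols_coefficients(ops: MiniOps, X: Any, y: Any) -> Any:
    """Solve (X.T @ X) beta = X.T @ y via Cholesky, touching only protocol methods."""
    X = ops.asarray(X, dtype=float)
    y = ops.asarray(y, dtype=float)
    L = ops.cholesky(X.T @ X)
    z = ops.solve_triangular(L, X.T @ y, lower=True)   # forward solve: L z = X.T y
    return ops.solve_triangular(L.T, z, lower=False)   # back solve: L.T beta = z


X = [[1.0, 0.0], [1.0, 1.0], [1.0, 2.0]]
y = [1.0, 3.0, 5.0]
print(ols_coefficients(MiniNumpyOps(), X, y))  # intercept 1 and slope 2 (up to rounding)
```

Swapping in another object with the same methods would leave ols_coefficients unchanged, which is the motivation for routing linear algebra through the protocol.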
Attributes:
| Name | Type | Description |
|---|---|---|
np | Any | The numpy-like module (numpy or jax.numpy). |
Functions:
| Name | Description |
|---|---|
arange | Create array with evenly spaced values. |
asarray | Convert input to array. |
cholesky | Cholesky decomposition. |
det | Matrix determinant. |
eigh | Eigendecomposition of symmetric matrix. |
eye | Create identity matrix. |
full | Create array filled with a scalar value. |
inv | Matrix inverse. |
jit | JIT-compile a function. |
lstsq | Least squares solution. |
norm | Matrix or vector norm. |
ones | Create array of ones. |
qr | QR decomposition. |
solve | Solve linear system a @ x = b. |
solve_triangular | Solve triangular linear system. |
svd | Singular value decomposition. |
vmap | Vectorize a function over a batch dimension. |
zeros | Create array of zeros. |
Attributes¶
np¶
np: Any
Functions¶
arange¶
arange(start: int, stop: int | None = None, step: int = 1) -> Array
Create array with evenly spaced values.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
start | int | Start value (or stop if stop is None). | required |
stop | int | None | Stop value (exclusive). | None |
step | int | Step size. | 1 |
Returns:
| Type | Description |
|---|---|
Array | Array of evenly spaced values. |
asarray¶
asarray(x: Any, dtype: Any = None) -> Array
Convert input to array.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x | Any | Input data (list, tuple, ndarray, etc.). | required |
dtype | Any | Desired data type. | None |
Returns:
| Type | Description |
|---|---|
Array | Array of the appropriate backend type. |
cholesky¶
cholesky(a: Array) -> Array
Cholesky decomposition.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
a | Array | Symmetric positive-definite matrix. | required |
Returns:
| Type | Description |
|---|---|
Array | Lower triangular Cholesky factor L such that a = L @ L.T. |
det¶
det(a: Array) -> Array
Matrix determinant.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
a | Array | Square matrix. | required |
Returns:
| Type | Description |
|---|---|
Array | Determinant of a. |
eigh¶
eigh(a: Array) -> tuple[Array, Array]
Eigendecomposition of symmetric matrix.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
a | Array | Symmetric matrix. | required |
Returns:
| Type | Description |
|---|---|
tuple[Array, Array] | Tuple (eigenvalues, eigenvectors), with eigenvalues in ascending order and eigenvectors as columns. |
eye¶
eye(n: int, dtype: Any = None) -> Array
Create identity matrix.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
n | int | Size of the identity matrix. | required |
dtype | Any | Data type (defaults to float64). | None |
Returns:
| Type | Description |
|---|---|
Array | Identity matrix of shape (n, n). |
full¶
full(shape: tuple[int, ...], fill_value: float, dtype: Any = None) -> Array
Create array filled with a scalar value.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
shape | tuple[int, ...] | Shape of the array. | required |
fill_value | float | Value to fill the array with. | required |
dtype | Any | Data type (defaults to float64). | None |
Returns:
| Type | Description |
|---|---|
Array | Array filled with fill_value. |
inv¶
inv(a: Array) -> Array
Matrix inverse.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
a | Array | Square matrix. | required |
Returns:
| Type | Description |
|---|---|
Array | Inverse of a. |
jit¶
jit(fn: Callable, *, donate_argnums: tuple[int, ...] | None = None) -> Callable
JIT-compile a function.
For NumPy backend, this is a no-op (returns the function unchanged). For JAX backend, this wraps the function with jax.jit.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
fn | Callable | Function to compile. | required |
donate_argnums | tuple[int, ...] | None | Positional arg indices whose buffers can be reused by XLA (memory optimization). Donated buffers must not be accessed after the call. Ignored on NumPy backend. | None |
Returns:
| Type | Description |
|---|---|
Callable | JIT-compiled function (or original for NumPy). |
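On the NumPy backend the contract for jit is small: accept the same keyword options as the JAX path and return the function untouched, so callers can wrap functions with ops.jit regardless of backend. A minimal sketch under that assumption (numpy_jit is a hypothetical name, not bossanova's code):

```python
from __future__ import annotations

from typing import Callable


def numpy_jit(fn: Callable, *, donate_argnums: tuple[int, ...] | None = None) -> Callable:
    """No-op JIT: accepts jax.jit-style options and returns fn unchanged."""
    del donate_argnums  # buffer donation is an XLA concept with no NumPy equivalent
    return fn


def scale(x, factor):
    return x * factor


fast_scale = numpy_jit(scale, donate_argnums=(0,))
print(fast_scale is scale, fast_scale(3, 2))  # True 6
```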
lstsq¶
lstsq(a: Array, b: Array) -> tuple[Array, Any, Any, Any]
Least squares solution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
a | Array | Coefficient matrix. | required |
b | Array | Right-hand side. | required |
Returns:
| Type | Description |
|---|---|
tuple[Array, Any, Any, Any] | Tuple (solution, residuals, rank, singular_values). |
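The four-element return mirrors numpy.linalg.lstsq. The snippet below calls NumPy directly, not the backend, to show what each element holds for a full-rank, overdetermined system. Per NumPy's documentation, residuals is an empty array when the matrix is rank-deficient or has no more rows than columns, and rank is a plain int; the JAX backend's signature types rank as an array instead.

```python
import numpy as np

# Fit y = c0 + c1 * x to four points: a full-rank, overdetermined system.
a = np.array([[1.0, 0.0], [1.0, 1.0], [1.0, 2.0], [1.0, 3.0]])
b = np.array([1.0, 2.0, 2.0, 4.0])

# rcond=None uses NumPy's default machine-precision cutoff for small singular values.
solution, residuals, rank, singular_values = np.linalg.lstsq(a, b, rcond=None)

print(solution)         # intercept and slope, both 0.9 here
print(residuals)        # one-element array: sum of squared residuals (0.7)
print(rank)             # 2
print(singular_values)  # singular values of a, largest first
```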
norm¶
norm(a: Array, ord: Any = None, axis: int | None = None) -> Array
Matrix or vector norm.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
a | Array | Input array. | required |
ord | Any | Order of the norm. | None |
axis | int | None | Axis along which to compute. | None |
Returns:
| Type | Description |
|---|---|
Array | Norm of the array. |
ones¶
ones(shape: tuple[int, ...], dtype: Any = None) -> Array
Create array of ones.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
shape | tuple[int, ...] | Shape of the array. | required |
dtype | Any | Data type (defaults to float64). | None |
Returns:
| Type | Description |
|---|---|
Array | Array of ones. |
qr¶
qr(a: Array) -> tuple[Array, Array]
QR decomposition.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
a | Array | Matrix to decompose. | required |
Returns:
| Type | Description |
|---|---|
tuple[Array, Array] | Tuple (Q, R) where Q has orthonormal columns and R is upper triangular. |
solve¶
solve(a: Array, b: Array) -> Array
Solve linear system a @ x = b.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
a | Array | Coefficient matrix. | required |
b | Array | Right-hand side. | required |
Returns:
| Type | Description |
|---|---|
Array | Solution x. |
solve_triangular¶
solve_triangular(a: Array, b: Array, lower: bool = False) -> Array
Solve triangular linear system.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
a | Array | Triangular coefficient matrix. | required |
b | Array | Right-hand side. | required |
lower | bool | If True, a is lower triangular. | False |
Returns:
| Type | Description |
|---|---|
Array | Solution x. |
svd¶
svd(a: Array, full_matrices: bool = True) -> tuple[Array, Array, Array]
Singular value decomposition.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
a | Array | Matrix to decompose. | required |
full_matrices | bool | If True, return full U and Vt matrices. | True |
Returns:
| Type | Description |
|---|---|
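tuple[Array, Array, Array] | Tuple (U, s, Vt) with singular values s in descending order, such that a = U[:, :k] @ diag(s) @ Vt[:k, :] where k = len(s) (simply U @ diag(s) @ Vt when full_matrices=False). |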
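The effect of full_matrices on shapes, shown with numpy.linalg.svd called directly (an illustration of the convention, not a claim about which routine each backend wraps). For an m x n input with k = min(m, n), the full U is m x m and only its first k columns pair with s.

```python
import numpy as np

a = np.arange(12.0).reshape(4, 3)  # a 4 x 3 matrix, so k = 3

U, s, Vt = np.linalg.svd(a, full_matrices=True)
print(U.shape, s.shape, Vt.shape)  # (4, 4) (3,) (3, 3)

# With the full U, only its first k columns multiply against diag(s).
k = s.shape[0]
assert np.allclose(a, U[:, :k] @ np.diag(s) @ Vt[:k, :])

U_r, s_r, Vt_r = np.linalg.svd(a, full_matrices=False)
print(U_r.shape, s_r.shape, Vt_r.shape)  # (4, 3) (3,) (3, 3)
assert np.allclose(a, U_r @ np.diag(s_r) @ Vt_r)
```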
vmap¶
vmap(fn: Callable, in_axes: int | tuple[int, ...] = 0) -> Callable
Vectorize a function over a batch dimension.
For NumPy backend, this uses a Python loop with np.stack. For JAX backend, this uses jax.vmap.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
fn | Callable | Function to vectorize. | required |
in_axes | int | tuple[int, ...] | Axis to map over for each input. | 0 |
Returns:
| Type | Description |
|---|---|
Callable | Vectorized function. |
zeros¶
zeros(shape: tuple[int, ...], dtype: Any = None) -> Array
Create array of zeros.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
shape | tuple[int, ...] | Shape of the array. | required |
dtype | Any | Data type (defaults to float64). | None |
Returns:
| Type | Description |
|---|---|
Array | Array of zeros. |
Functions¶
backend¶
backend(name: BackendName) -> Iterator[None]
Context manager for temporary backend switching.
This is primarily intended for testing. It temporarily switches the backend and restores the previous state on exit.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
name | BackendName | Backend name to use within the context. | required |
Yields None; restores previous backend on exit.
Examples:
import bossanova
with bossanova.backend("numpy"):
    print(bossanova.get_backend())
# numpy
clear_ops_cache¶
clear_ops_cache() -> None
Clear the backend operations cache.
This is primarily for testing purposes. Clears the cached backend instances so that the next call to get_ops() creates a fresh instance.
This should only be used in tests. Using it in production code can lead to inconsistent behavior.
get_backend¶
get_backend() -> BackendName
Get the current backend name.
If no backend has been explicitly set, auto-detects the best available backend on first call.
Returns:
| Type | Description |
|---|---|
BackendName | The current backend name (‘jax’ or ‘numpy’). |
Examples:
import bossanova
bossanova.get_backend()
# 'jax' (the result depends on which backends are installed)
get_ops¶
get_ops() -> 'ArrayOps'
Get array operations for the current backend.
Returns the appropriate backend instance (NumPyBackend or JAXBackend) based on the current backend setting. The instance is cached so that repeated calls return the same object.
Returns:
| Type | Description |
|---|---|
ArrayOps | ArrayOps instance for the current backend. |
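The caching described for get_ops() can be sketched with functools.lru_cache keyed on the backend name, where clear_ops_cache() maps naturally onto cache_clear(). This is an assumption about the pattern, not bossanova's source: FakeNumpyBackend and FakeJaxBackend are hypothetical empty placeholders for NumPyBackend and JAXBackend, and _current is hypothetical module state.

```python
from functools import lru_cache


class FakeNumpyBackend:
    """Hypothetical stand-in for NumPyBackend."""


class FakeJaxBackend:
    """Hypothetical stand-in for JAXBackend."""


# Hypothetical module state: the currently selected backend name.
_current = {"backend": "numpy"}


@lru_cache(maxsize=None)
def _ops_for(name: str):
    """Build one backend object per name; repeat lookups return the cached object."""
    if name == "numpy":
        return FakeNumpyBackend()
    if name == "jax":
        return FakeJaxBackend()
    raise ValueError(f"unknown backend: {name!r}")


def get_ops_sketch():
    return _ops_for(_current["backend"])


def clear_ops_cache_sketch() -> None:
    # Forget cached instances so the next lookup constructs a fresh one.
    _ops_for.cache_clear()


first = get_ops_sketch()
print(get_ops_sketch() is first)  # True: same cached object
clear_ops_cache_sketch()
print(get_ops_sketch() is first)  # False: rebuilt after clearing
```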
Examples:
>>> from maths import get_ops
>>> ops = get_ops()
>>> X = ops.asarray([[1, 2], [3, 4]])
>>> L = ops.cholesky(X @ X.T)
lock_backend¶
lock_backend() -> None
Lock the backend to prevent switching after model fitting.
This is called internally when models are fitted to ensure consistent behavior throughout a session.
reset_backend¶
reset_backend() -> None
Reset backend state (for testing only).
This should only be used in tests. Using it in production code can lead to inconsistent behavior.
set_backend¶
set_backend(name: BackendName) -> None
Set the backend to use for computations.
Must be called before any model fitting occurs. Once a model has been fitted, the backend is locked and cannot be changed.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
name | BackendName | Backend name, either ‘jax’ or ‘numpy’. | required |
Examples:
import bossanova
bossanova.set_backend("numpy")
bossanova.get_backend()
# 'numpy'
Modules¶
dispatch¶
Backend detection, switching, and ArrayOps dispatch.
Functions:
| Name | Description |
|---|---|
backend | Context manager for temporary backend switching. |
clear_ops_cache | Clear the backend operations cache. |
get_backend | Get the current backend name. |
get_ops | Get array operations for the current backend. |
lock_backend | Lock the backend to prevent switching after model fitting. |
reset_backend | Reset backend state (for testing only). |
set_backend | Set the backend to use for computations. |
Attributes:
| Name | Type | Description |
|---|---|---|
BackendName | Literal['jax', 'numpy'] | Supported backend names. |
Attributes¶
BackendName¶
BackendName = Literal['jax', 'numpy']
Classes¶
Functions¶
backend¶
backend(name: BackendName) -> Iterator[None]
Context manager for temporary backend switching.
This is primarily intended for testing. It temporarily switches the backend and restores the previous state on exit.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
name | BackendName | Backend name to use within the context. | required |
Yields None; restores previous backend on exit.
Examples:
import bossanova
with bossanova.backend("numpy"):
    print(bossanova.get_backend())
# numpy
clear_ops_cache¶
clear_ops_cache() -> None
Clear the backend operations cache.
This is primarily for testing purposes. Clears the cached backend instances so that the next call to get_ops() creates a fresh instance.
This should only be used in tests. Using it in production code can lead to inconsistent behavior.
get_backend¶
get_backend() -> BackendName
Get the current backend name.
If no backend has been explicitly set, auto-detects the best available backend on first call.
Returns:
| Type | Description |
|---|---|
BackendName | The current backend name (‘jax’ or ‘numpy’). |
Examples:
import bossanova
bossanova.get_backend()
# 'jax' (the result depends on which backends are installed)
get_ops¶
get_ops() -> 'ArrayOps'
Get array operations for the current backend.
Returns the appropriate backend instance (NumPyBackend or JAXBackend) based on the current backend setting. The instance is cached so that repeated calls return the same object.
Returns:
| Type | Description |
|---|---|
ArrayOps | ArrayOps instance for the current backend. |
Examples:
>>> from maths import get_ops
>>> ops = get_ops()
>>> X = ops.asarray([[1, 2], [3, 4]])
>>> L = ops.cholesky(X @ X.T)
lock_backend¶
lock_backend() -> None
Lock the backend to prevent switching after model fitting.
This is called internally when models are fitted to ensure consistent behavior throughout a session.
reset_backend¶
reset_backend() -> None
Reset backend state (for testing only).
This should only be used in tests. Using it in production code can lead to inconsistent behavior.
set_backend¶
set_backend(name: BackendName) -> None
Set the backend to use for computations.
Must be called before any model fitting occurs. Once a model has been fitted, the backend is locked and cannot be changed.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
name | BackendName | Backend name, either ‘jax’ or ‘numpy’. | required |
Examples:
import bossanova
bossanova.set_backend("numpy")
bossanova.get_backend()
# 'numpy'
jax¶
JAX backend implementation.
This module provides the JAX implementation of array operations for bossanova. It implements the ArrayOps protocol.
Note: JAX x64 config is set in bossanova/__init__.py at package import time, before any submodules are loaded. By the time this module is imported (lazily, via get_ops()), x64 is already enabled.
Classes:
| Name | Description |
|---|---|
JAXBackend | JAX implementation of array operations. |
Classes¶
JAXBackend¶
JAXBackend() -> None
JAX implementation of array operations.
This backend uses JAX for array operations and provides JIT compilation and vectorization via vmap.
Attributes:
| Name | Type | Description |
|---|---|---|
np | | The jax.numpy module. |
jax | | The jax module (for jit, vmap, etc.). |
Functions:
| Name | Description |
|---|---|
arange | Create array with evenly spaced values. |
asarray | Convert input to JAX array. |
cho_solve | Solve using Cholesky factor. |
cholesky | Cholesky decomposition (lower triangular). |
det | Matrix determinant. |
eigh | Eigendecomposition of symmetric matrix. |
eye | Create identity matrix. |
full | Create array filled with a scalar value. |
grad | Compute gradient of a function. |
inv | Matrix inverse. |
jit | JIT-compile a function using JAX. |
lstsq | Least squares solution. |
norm | Matrix or vector norm. |
ones | Create array of ones. |
qr | QR decomposition. |
qr_pivoted | Pivoted QR decomposition for rank detection. |
solve | Solve linear system a @ x = b. |
solve_triangular | Solve triangular linear system. |
svd | Singular value decomposition. |
value_and_grad | Compute value and gradient of a function. |
vmap | Vectorize a function over a batch dimension using JAX. |
zeros | Create array of zeros. |
Attributes¶
jax¶
jax = jax
np¶
np = jnp
Functions¶
arange¶
arange(start: int, stop: int | None = None, step: int = 1) -> jax.Array
Create array with evenly spaced values.
asarray¶
asarray(x: Any, dtype: Any = None) -> jax.Array
Convert input to JAX array.
cho_solve¶
cho_solve(c_and_lower: tuple[jax.Array, bool], b: jax.Array) -> jax.Array
Solve using Cholesky factor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
c_and_lower | tuple[Array, bool] | Tuple of (cholesky_factor, is_lower). | required |
b | Array | Right-hand side. | required |
Returns:
| Type | Description |
|---|---|
Array | Solution x to A @ x = b, where A = L @ L.T for a lower factor L (or A = U.T @ U for an upper factor U). |
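Solving A x = b from a Cholesky factor takes two triangular solves: forward substitution with L, then back substitution with L.T. The plain-NumPy sketch below spells this out for the lower-factor case and a 1-D right-hand side, to show the arithmetic cho_solve performs. forward_sub, back_sub, and cho_solve_lower are hypothetical helpers, not the JAX backend's implementation.

```python
import numpy as np


def forward_sub(L: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Solve L y = b for lower-triangular L, from the first row downward."""
    n = L.shape[0]
    y = np.zeros(n)
    for i in range(n):
        y[i] = (b[i] - L[i, :i] @ y[:i]) / L[i, i]
    return y


def back_sub(U: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Solve U x = y for upper-triangular U, from the last row upward."""
    n = U.shape[0]
    x = np.zeros(n)
    for i in range(n - 1, -1, -1):
        x[i] = (y[i] - U[i, i + 1:] @ x[i + 1:]) / U[i, i]
    return x


def cho_solve_lower(L: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Solve (L @ L.T) x = b given the lower Cholesky factor L."""
    return back_sub(L.T, forward_sub(L, b))


A = np.array([[4.0, 2.0], [2.0, 3.0]])
b = np.array([2.0, 1.0])
L = np.linalg.cholesky(A)
x = cho_solve_lower(L, b)
print(np.allclose(A @ x, b))  # True
```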
cholesky¶
cholesky(a: jax.Array) -> jax.Array
Cholesky decomposition (lower triangular).
det¶
det(a: jax.Array) -> jax.Array
Matrix determinant.
eigh¶
eigh(a: jax.Array) -> tuple[jax.Array, jax.Array]
Eigendecomposition of symmetric matrix.
eye¶
eye(n: int, dtype: Any = None) -> jax.Array
Create identity matrix.
full¶
full(shape: tuple[int, ...], fill_value: float, dtype: Any = None) -> jax.Array
Create array filled with a scalar value.
grad¶
grad(fn: Callable, argnums: int = 0) -> Callable
Compute gradient of a function.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
fn | Callable | Function to differentiate. | required |
argnums | int | Which argument to differentiate with respect to. | 0 |
Returns:
| Type | Description |
|---|---|
Callable | Function that computes the gradient. |
inv¶
inv(a: jax.Array) -> jax.Array
Matrix inverse.
jit¶
jit(fn: Callable, *, donate_argnums: tuple[int, ...] | None = None) -> Callable
JIT-compile a function using JAX.
lstsq¶
lstsq(a: jax.Array, b: jax.Array) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]
Least squares solution.
norm¶
norm(a: jax.Array, ord: Any = None, axis: int | None = None) -> jax.Array
Matrix or vector norm.
ones¶
ones(shape: tuple[int, ...], dtype: Any = None) -> jax.Array
Create array of ones.
qr¶
qr(a: jax.Array) -> tuple[jax.Array, jax.Array]
QR decomposition.
qr_pivoted¶
qr_pivoted(a: jax.Array) -> tuple[jax.Array, jax.Array, jax.Array]
Pivoted QR decomposition for rank detection.
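Returns (Q, R, P) for an input a of shape (n, p), where:
Q: matrix with orthonormal columns, shape (n, min(n, p))
R: upper triangular, shape (min(n, p), p)
P: column permutation indices, shape (p,)
The factorization satisfies a[:, P] = Q @ R.
At each step, column pivoting moves the remaining column with the largest norm to the front, so the magnitudes on the diagonal of R are non-increasing; the numerical rank is the number of diagonal entries above a tolerance.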
solve¶
solve(a: jax.Array, b: jax.Array) -> jax.Array
Solve linear system a @ x = b.
solve_triangular¶
solve_triangular(a: jax.Array, b: jax.Array, lower: bool = False) -> jax.Array
Solve triangular linear system.
svd¶
svd(a: jax.Array, full_matrices: bool = True) -> tuple[jax.Array, jax.Array, jax.Array]
Singular value decomposition.
value_and_grad¶
value_and_grad(fn: Callable, argnums: int = 0) -> Callable
Compute value and gradient of a function.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
fn | Callable | Function to differentiate. | required |
argnums | int | Which argument to differentiate with respect to. | 0 |
Returns:
| Type | Description |
|---|---|
Callable | Function that returns (value, gradient). |
vmap¶
vmap(fn: Callable, in_axes: int | tuple[int, ...] = 0) -> Callable
Vectorize a function over a batch dimension using JAX.
zeros¶
zeros(shape: tuple[int, ...], dtype: Any = None) -> jax.Array
Create array of zeros.
numpy¶
NumPy backend implementation.
This module provides the NumPy/SciPy implementation of array operations for bossanova. It implements the ArrayOps protocol.
Classes:
| Name | Description |
|---|---|
NumPyBackend | NumPy/SciPy implementation of array operations. |
Classes¶
NumPyBackend¶
NumPyBackend() -> None
NumPy/SciPy implementation of array operations.
This backend uses NumPy for array operations and SciPy for linear algebra. jit is a no-op that returns the function unchanged, and vmap is a simple Python loop over the batch dimension.
Attributes:
| Name | Type | Description |
|---|---|---|
np | | The numpy module. |
Functions:
| Name | Description |
|---|---|
arange | Create array with evenly spaced values. |
asarray | Convert input to numpy array. |
cholesky | Cholesky decomposition (lower triangular). |
det | Matrix determinant. |
eigh | Eigendecomposition of symmetric matrix. |
eye | Create identity matrix. |
full | Create array filled with a scalar value. |
inv | Matrix inverse. |
jit | No-op: NumPy doesn’t have JIT compilation. |
lstsq | Least squares solution. |
norm | Matrix or vector norm. |
ones | Create array of ones. |
qr | QR decomposition (economic mode). |
qr_pivoted | Pivoted QR decomposition for rank detection. |
solve | Solve linear system a @ x = b. |
solve_triangular | Solve triangular linear system. |
svd | Singular value decomposition. |
vmap | Vectorize via Python loop. |
zeros | Create array of zeros. |
Attributes¶
np¶
np = np
Functions¶
arange¶
arange(start: int, stop: int | None = None, step: int = 1) -> np.ndarray
Create array with evenly spaced values.
asarray¶
asarray(x: Any, dtype: Any = None) -> np.ndarray
Convert input to numpy array.
cholesky¶
cholesky(a: np.ndarray) -> np.ndarray
Cholesky decomposition (lower triangular).
det¶
det(a: np.ndarray) -> np.floating[Any]
Matrix determinant.
eigh¶
eigh(a: np.ndarray) -> tuple[np.ndarray, np.ndarray]
Eigendecomposition of symmetric matrix.
eye¶
eye(n: int, dtype: Any = None) -> np.ndarray
Create identity matrix.
full¶
full(shape: tuple[int, ...], fill_value: float, dtype: Any = None) -> np.ndarray
Create array filled with a scalar value.
inv¶
inv(a: np.ndarray) -> np.ndarray
Matrix inverse.
jit¶
jit(fn: Callable, *, donate_argnums: tuple[int, ...] | None = None) -> Callable
No-op: NumPy doesn’t have JIT compilation.
lstsq¶
lstsq(a: np.ndarray, b: np.ndarray) -> tuple[np.ndarray, np.ndarray, int, np.ndarray]
Least squares solution.
norm¶
norm(a: np.ndarray, ord: Any = None, axis: int | None = None) -> np.ndarray
Matrix or vector norm.
ones¶
ones(shape: tuple[int, ...], dtype: Any = None) -> np.ndarray
Create array of ones.
qr¶
qr(a: np.ndarray) -> tuple[np.ndarray, np.ndarray]
QR decomposition (economic mode).
qr_pivoted¶
qr_pivoted(a: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]
Pivoted QR decomposition for rank detection.
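Returns (Q, R, P) for an input a of shape (n, p), where:
Q: matrix with orthonormal columns, shape (n, min(n, p))
R: upper triangular, shape (min(n, p), p)
P: column permutation indices, shape (p,)
The factorization satisfies a[:, P] = Q @ R.
At each step, column pivoting moves the remaining column with the largest norm to the front, so the magnitudes on the diagonal of R are non-increasing; the numerical rank is the number of diagonal entries above a tolerance.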
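NumPy itself has no column-pivoted QR (SciPy's scipy.linalg.qr with pivoting=True does), and this page does not say which routine the backend calls. To make the (Q, R, P) contract and the rank-detection idea concrete, here is a self-contained modified Gram-Schmidt version with largest-norm column pivoting. It is illustrative only: qr_pivoted_mgs and numerical_rank are hypothetical names, the 1e-10 relative tolerance is an arbitrary choice, and production code should prefer a Householder-based LAPACK routine for numerical stability.

```python
import numpy as np


def qr_pivoted_mgs(a):
    """Column-pivoted QR by modified Gram-Schmidt; returns (Q, R, P) with a[:, P] = Q @ R."""
    work = np.array(a, dtype=float)  # copy, so the caller's array is untouched
    n, p = work.shape
    k = min(n, p)
    perm = np.arange(p)
    q = np.zeros((n, k))
    r = np.zeros((k, p))
    for j in range(k):
        # Pivot: bring the remaining column with the largest residual norm to position j.
        m = j + int(np.argmax(np.linalg.norm(work[:, j:], axis=0)))
        work[:, [j, m]] = work[:, [m, j]]
        r[:, [j, m]] = r[:, [m, j]]
        perm[[j, m]] = perm[[m, j]]
        r[j, j] = np.linalg.norm(work[:, j])
        if r[j, j] == 0.0:
            break  # every remaining column is exactly zero
        q[:, j] = work[:, j] / r[j, j]
        # Remove the q_j component from the columns not yet processed.
        r[j, j + 1:] = q[:, j] @ work[:, j + 1:]
        work[:, j + 1:] -= np.outer(q[:, j], r[j, j + 1:])
    return q, r, perm


def numerical_rank(r, rtol=1e-10):
    """Count diagonal entries of R that are large relative to the first one."""
    d = np.abs(np.diag(r))
    return int(np.sum(d > rtol * d[0])) if d.size and d[0] > 0 else 0


# The third column equals the first plus the second, so the rank is 2.
A = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 9.0], [7.0, 8.0, 15.0], [1.0, 0.0, 1.0]])
Q, R, P = qr_pivoted_mgs(A)
print(np.allclose(A[:, P], Q @ R), numerical_rank(R))  # True 2
```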
solve¶
solve(a: np.ndarray, b: np.ndarray) -> np.ndarray
Solve linear system a @ x = b.
solve_triangular¶
solve_triangular(a: np.ndarray, b: np.ndarray, lower: bool = False) -> np.ndarray
Solve triangular linear system.
svd¶
svd(a: np.ndarray, full_matrices: bool = True) -> tuple[np.ndarray, np.ndarray, np.ndarray]
Singular value decomposition.
vmap¶
vmap(fn: Callable, in_axes: int | tuple[int, ...] = 0) -> Callable
Vectorize via Python loop.
This is a simple implementation that loops over the batch dimension and stacks results. It’s slower than JAX’s vmap but provides the same interface.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
fn | Callable | Function to vectorize. | required |
in_axes | int | tuple[int, ...] | Axis to map over (only supports 0 or tuple of 0s/Nones). | 0 |
Returns:
| Type | Description |
|---|---|
Callable | Vectorized function. |
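A loop-based vmap of the kind described above might look like the following. This is a sketch, not bossanova's source: loop_vmap is a hypothetical name, it accepts only in_axes entries of 0 (map over the leading axis) or None (pass the argument through unchanged), following the restriction in the table, and it assumes fn returns one array or scalar per call.

```python
from __future__ import annotations

from typing import Callable

import numpy as np


def loop_vmap(fn: Callable, in_axes: int | tuple[int | None, ...] = 0) -> Callable:
    """Map fn over axis 0 of the selected arguments with a Python loop, then stack."""

    def batched(*args):
        # A bare int applies to every argument; a tuple gives one entry per argument.
        axes = (in_axes,) * len(args) if isinstance(in_axes, int) else tuple(in_axes)
        if len(axes) != len(args):
            raise ValueError("in_axes needs one entry per positional argument")
        if any(ax not in (0, None) for ax in axes):
            raise NotImplementedError("only in_axes entries of 0 or None are supported")
        # Mapped arguments become arrays so they can be sliced along axis 0.
        prepared = [np.asarray(a) if ax == 0 else a for a, ax in zip(args, axes)]
        sizes = {a.shape[0] for a, ax in zip(prepared, axes) if ax == 0}
        if len(sizes) != 1:
            raise ValueError("mapped arguments must exist and share one leading size")
        (n,) = sizes
        outputs = [
            fn(*[a[i] if ax == 0 else a for a, ax in zip(prepared, axes)])
            for i in range(n)
        ]
        return np.stack(outputs)

    return batched


def weighted_sum(x, w):
    return np.dot(x, w)


rows = np.arange(6.0).reshape(3, 2)  # three 2-vectors: [0, 1], [2, 3], [4, 5]
w = np.array([10.0, 1.0])
out = loop_vmap(weighted_sum, in_axes=(0, None))(rows, w)
print(out)  # values 1.0, 23.0, 45.0 (same as rows @ w)
```

The loop keeps the interface identical to the JAX path at the cost of Python-level iteration, which is the trade-off the NumPy backend's description accepts.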
zeros¶
zeros(shape: tuple[int, ...], dtype: Any = None) -> np.ndarray
Create array of zeros.
protocol¶
Array operations protocol for backend abstraction.
This module defines the interface that both NumPy and JAX backends must implement. Using a Protocol allows for static type checking while maintaining flexibility in implementation.
Classes:
| Name | Description |
|---|---|
ArrayOps | Protocol for array operations across backends. |
Attributes:
| Name | Type | Description |
|---|---|---|
Array | TypeVar | Generic array type used in the protocol signatures. |
Attributes¶
Array¶
Array = TypeVar('Array')
Classes¶
ArrayOps¶
Bases: Protocol
Protocol for array operations across backends.
This defines the interface that both NumPyBackend and JAXBackend must implement. All linear algebra and array manipulation needed by bossanova should go through this interface.
Attributes:
| Name | Type | Description |
|---|---|---|
np | Any | The numpy-like module (numpy or jax.numpy). |
Functions:
| Name | Description |
|---|---|
arange | Create array with evenly spaced values. |
asarray | Convert input to array. |
cholesky | Cholesky decomposition. |
det | Matrix determinant. |
eigh | Eigendecomposition of symmetric matrix. |
eye | Create identity matrix. |
full | Create array filled with a scalar value. |
inv | Matrix inverse. |
jit | JIT-compile a function. |
lstsq | Least squares solution. |
norm | Matrix or vector norm. |
ones | Create array of ones. |
qr | QR decomposition. |
solve | Solve linear system a @ x = b. |
solve_triangular | Solve triangular linear system. |
svd | Singular value decomposition. |
vmap | Vectorize a function over a batch dimension. |
zeros | Create array of zeros. |
Attributes¶
np¶
np: Any
Functions¶
arange¶
arange(start: int, stop: int | None = None, step: int = 1) -> Array
Create array with evenly spaced values.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
start | int | Start value (or stop if stop is None). | required |
stop | int | None | Stop value (exclusive). | None |
step | int | Step size. | 1 |
Returns:
| Type | Description |
|---|---|
Array | Array of evenly spaced values. |
asarray¶
asarray(x: Any, dtype: Any = None) -> Array
Convert input to array.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x | Any | Input data (list, tuple, ndarray, etc.). | required |
dtype | Any | Desired data type. | None |
Returns:
| Type | Description |
|---|---|
Array | Array of the appropriate backend type. |
cholesky¶
cholesky(a: Array) -> Array
Cholesky decomposition.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
a | Array | Symmetric positive-definite matrix. | required |
Returns:
| Type | Description |
|---|---|
Array | Lower triangular Cholesky factor L such that a = L @ L.T. |
det¶
det(a: Array) -> Array
Matrix determinant.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
a | Array | Square matrix. | required |
Returns:
| Type | Description |
|---|---|
Array | Determinant of a. |
eigh¶
eigh(a: Array) -> tuple[Array, Array]
Eigendecomposition of symmetric matrix.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
a | Array | Symmetric matrix. | required |
Returns:
| Type | Description |
|---|---|
tuple[Array, Array] | Tuple (eigenvalues, eigenvectors), with eigenvalues in ascending order and eigenvectors as columns. |
eye¶
eye(n: int, dtype: Any = None) -> Array
Create identity matrix.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
n | int | Size of the identity matrix. | required |
dtype | Any | Data type (defaults to float64). | None |
Returns:
| Type | Description |
|---|---|
Array | Identity matrix of shape (n, n). |
full¶
full(shape: tuple[int, ...], fill_value: float, dtype: Any = None) -> Array
Create array filled with a scalar value.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
shape | tuple[int, ...] | Shape of the array. | required |
fill_value | float | Value to fill the array with. | required |
dtype | Any | Data type (defaults to float64). | None |
Returns:
| Type | Description |
|---|---|
Array | Array filled with fill_value. |
inv¶
inv(a: Array) -> Array
Matrix inverse.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
a | Array | Square matrix. | required |
Returns:
| Type | Description |
|---|---|
Array | Inverse of a. |
jit¶
jit(fn: Callable, *, donate_argnums: tuple[int, ...] | None = None) -> Callable
JIT-compile a function.
For NumPy backend, this is a no-op (returns the function unchanged). For JAX backend, this wraps the function with jax.jit.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
fn | Callable | Function to compile. | required |
donate_argnums | tuple[int, ...] | None | Positional arg indices whose buffers can be reused by XLA (memory optimization). Donated buffers must not be accessed after the call. Ignored on NumPy backend. | None |
Returns:
| Type | Description |
|---|---|
Callable | JIT-compiled function (or original for NumPy). |
lstsq¶
lstsq(a: Array, b: Array) -> tuple[Array, Any, Any, Any]
Least squares solution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
a | Array | Coefficient matrix. | required |
b | Array | Right-hand side. | required |
Returns:
| Type | Description |
|---|---|
tuple[Array, Any, Any, Any] | Tuple (solution, residuals, rank, singular_values). |
norm¶
norm(a: Array, ord: Any = None, axis: int | None = None) -> Array
Matrix or vector norm.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
a | Array | Input array. | required |
ord | Any | Order of the norm. | None |
axis | int | None | Axis along which to compute. | None |
Returns:
| Type | Description |
|---|---|
Array | Norm of the array. |
ones¶
ones(shape: tuple[int, ...], dtype: Any = None) -> Array
Create array of ones.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
shape | tuple[int, ...] | Shape of the array. | required |
dtype | Any | Data type (defaults to float64). | None |
Returns:
| Type | Description |
|---|---|
Array | Array of ones. |
qr¶
qr(a: Array) -> tuple[Array, Array]
QR decomposition.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
a | Array | Matrix to decompose. | required |
Returns:
| Type | Description |
|---|---|
tuple[Array, Array] | Tuple (Q, R) where Q has orthonormal columns and R is upper triangular. |
solve¶
solve(a: Array, b: Array) -> Array
Solve linear system a @ x = b.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
a | Array | Coefficient matrix. | required |
b | Array | Right-hand side. | required |
Returns:
| Type | Description |
|---|---|
Array | Solution x. |
solve_triangular¶
solve_triangular(a: Array, b: Array, lower: bool = False) -> Array
Solve triangular linear system.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
a | Array | Triangular coefficient matrix. | required |
b | Array | Right-hand side. | required |
lower | bool | If True, a is lower triangular. | False |
Returns:
| Type | Description |
|---|---|
Array | Solution x. |
svd¶
svd(a: Array, full_matrices: bool = True) -> tuple[Array, Array, Array]
Singular value decomposition.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
a | Array | Matrix to decompose. | required |
full_matrices | bool | If True, return full U and Vt matrices. | True |
Returns:
| Type | Description |
|---|---|
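tuple[Array, Array, Array] | Tuple (U, s, Vt) with singular values s in descending order, such that a = U[:, :k] @ diag(s) @ Vt[:k, :] where k = len(s) (simply U @ diag(s) @ Vt when full_matrices=False). |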
vmap¶
vmap(fn: Callable, in_axes: int | tuple[int, ...] = 0) -> Callable
Vectorize a function over a batch dimension.
For NumPy backend, this uses a Python loop with np.stack. For JAX backend, this uses jax.vmap.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
fn | Callable | Function to vectorize. | required |
in_axes | int | tuple[int, ...] | Axis to map over for each input. | 0 |
Returns:
| Type | Description |
|---|---|
Callable | Vectorized function. |
zeros¶
zeros(shape: tuple[int, ...], dtype: Any = None) -> Array
Create array of zeros.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
shape | tuple[int, ...] | Shape of the array. | required |
dtype | Any | Data type (defaults to float64). | None |
Returns:
| Type | Description |
|---|---|
Array | Array of zeros. |