Backend abstraction layer for JAX/NumPy compatibility.

Attributes:

BackendName

Classes:

ArrayOps: Protocol for array operations across backends.

Functions:

backend: Context manager for temporary backend switching.
clear_ops_cache: Clear the backend operations cache.
get_backend: Get the current backend name.
get_ops: Get array operations for the current backend.
lock_backend: Lock the backend to prevent switching after model fitting.
reset_backend: Reset backend state (for testing only).
set_backend: Set the backend to use for computations.

Modules:

dispatch: Backend detection, switching, and ArrayOps dispatch.
jax: JAX backend implementation.
numpy: NumPy backend implementation.
protocol: Array operations protocol for backend abstraction.

Attributes

BackendName

BackendName = Literal['jax', 'numpy']

Classes

ArrayOps

Bases: Protocol

Protocol for array operations across backends.

This defines the interface that both NumPyBackend and JAXBackend must implement. All linear algebra and array manipulation needed by bossanova should go through this interface.
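Because ArrayOps is a typing.Protocol, any object with matching methods satisfies it structurally; no inheritance is required. The sketch below is illustrative only: MiniOps, TinyNumPyOps and ols_coefficients are hypothetical names, and MiniOps covers just three ArrayOps methods. It shows a backend-agnostic routine that touches arrays only through the ops object, solving least squares via a Cholesky factor and two triangular solves.

```python
from typing import Any, Protocol, runtime_checkable

import numpy as np


@runtime_checkable
class MiniOps(Protocol):
    """Hypothetical three-method subset of ArrayOps, for illustration only."""

    np: Any

    def asarray(self, x: Any, dtype: Any = None) -> Any: ...

    def cholesky(self, a: Any) -> Any: ...

    def solve_triangular(self, a: Any, b: Any, lower: bool = False) -> Any: ...


class TinyNumPyOps:
    """Satisfies MiniOps structurally; it never subclasses it."""

    np = np

    def asarray(self, x, dtype=None):
        return np.asarray(x, dtype=dtype)

    def cholesky(self, a):
        return np.linalg.cholesky(a)  # lower-triangular factor

    def solve_triangular(self, a, b, lower=False):
        # Stand-in: np.linalg.solve returns the same x but ignores the triangular structure.
        return np.linalg.solve(a, b)


def ols_coefficients(ops: MiniOps, X, y):
    """Solve the normal equations (X'X) beta = X'y using only ops methods."""
    X = ops.asarray(X, dtype=float)
    y = ops.asarray(y, dtype=float)
    L = ops.cholesky(X.T @ X)                         # X'X = L L'
    z = ops.solve_triangular(L, X.T @ y, lower=True)  # L z = X'y
    return ops.solve_triangular(L.T, z, lower=False)  # L' beta = z


ops = TinyNumPyOps()
print(isinstance(ops, MiniOps))  # True
print(ols_coefficients(ops, [[1, 0], [1, 1], [1, 2]], [1, 3, 5]))
```

Since ols_coefficients depends only on the protocol, a JAX-based class exposing the same three methods would be accepted just as well.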

Attributes:

np (Any): The numpy-like module (numpy or jax.numpy).

Functions:

arange: Create array with evenly spaced values.
asarray: Convert input to array.
cholesky: Cholesky decomposition.
det: Matrix determinant.
eigh: Eigendecomposition of symmetric matrix.
eye: Create identity matrix.
full: Create array filled with a scalar value.
inv: Matrix inverse.
jit: JIT-compile a function.
lstsq: Least squares solution.
norm: Matrix or vector norm.
ones: Create array of ones.
qr: QR decomposition.
solve: Solve linear system a @ x = b.
solve_triangular: Solve triangular linear system.
svd: Singular value decomposition.
vmap: Vectorize a function over a batch dimension.
zeros: Create array of zeros.

Attributes

np
np: Any

Functions

arange
arange(start: int, stop: int | None = None, step: int = 1) -> Array

Create array with evenly spaced values.

Parameters:

start (int, required): Start value (or stop if stop is None).
stop (int | None, default None): Stop value (exclusive).
step (int, default 1): Step size.

Returns:

Array: Array of evenly spaced values.

asarray
asarray(x: Any, dtype: Any = None) -> Array

Convert input to array.

Parameters:

x (Any, required): Input data (list, tuple, ndarray, etc.).
dtype (Any, default None): Desired data type.

Returns:

Array: Array of the appropriate backend type.

cholesky
cholesky(a: Array) -> Array

Cholesky decomposition.

Parameters:

a (Array, required): Symmetric positive-definite matrix.

Returns:

Array: Lower-triangular Cholesky factor L such that a = L @ L.T.

det
det(a: Array) -> Array

Matrix determinant.

Parameters:

a (Array, required): Square matrix.

Returns:

Array: Determinant of a.

eigh
eigh(a: Array) -> tuple[Array, Array]

Eigendecomposition of symmetric matrix.

Parameters:

a (Array, required): Symmetric matrix.

Returns:

tuple[Array, Array]: Tuple (eigenvalues, eigenvectors).

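The (eigenvalues, eigenvectors) contract is easiest to see on a small example. The sketch below calls numpy.linalg.eigh directly as a stand-in for the backend method (an assumption; the docs do not say which routine NumPyBackend wraps). NumPy returns eigenvalues in ascending order with eigenvectors as columns, so V @ diag(w) @ V.T reconstructs the input and the first eigenvalue gives a quick positive-definiteness check.

```python
import numpy as np

a = np.array([[4.0, 1.0], [1.0, 3.0]])
w, V = np.linalg.eigh(a)  # w: eigenvalues (ascending), V: eigenvectors as columns

# The eigenvectors are orthonormal and reconstruct the symmetric input.
reconstructed = V @ np.diag(w) @ V.T
print(np.allclose(reconstructed, a))  # True

# A symmetric matrix is positive definite iff its smallest eigenvalue is > 0.
is_positive_definite = bool(w[0] > 0)
print(is_positive_definite)  # True
```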
eye
eye(n: int, dtype: Any = None) -> Array

Create identity matrix.

Parameters:

n (int, required): Size of the identity matrix.
dtype (Any, default None): Data type (defaults to float64).

Returns:

Array: Identity matrix of shape (n, n).

full
full(shape: tuple[int, ...], fill_value: float, dtype: Any = None) -> Array

Create array filled with a scalar value.

Parameters:

shape (tuple[int, ...], required): Shape of the array.
fill_value (float, required): Value to fill the array with.
dtype (Any, default None): Data type (defaults to float64).

Returns:

Array: Array filled with fill_value.

inv
inv(a: Array) -> Array

Matrix inverse.

Parameters:

a (Array, required): Square matrix.

Returns:

Array: Inverse of a.

jit
jit(fn: Callable, *, donate_argnums: tuple[int, ...] | None = None) -> Callable

JIT-compile a function.

For NumPy backend, this is a no-op (returns the function unchanged). For JAX backend, this wraps the function with jax.jit.

Parameters:

fn (Callable, required): Function to compile.
donate_argnums (tuple[int, ...] | None, default None): Positional argument indices whose buffers XLA may reuse (a memory optimization). Donated buffers must not be accessed after the call. Ignored on the NumPy backend.

Returns:

Callable: JIT-compiled function (or the original function on NumPy).

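The no-op behaviour on NumPy amounts to a function that hands back its argument. In the sketch below, numpy_jit and pick_jit are hypothetical names, not part of bossanova; on JAX the helper defers to jax.jit, which accepts the same donate_argnums keyword.

```python
from __future__ import annotations

from typing import Callable


def numpy_jit(fn: Callable, *, donate_argnums: tuple[int, ...] | None = None) -> Callable:
    """Hypothetical NumPy-side jit: there is nothing to compile, so fn is returned as-is."""
    return fn  # donate_argnums is accepted for interface parity and ignored


def pick_jit(backend_name: str) -> Callable:
    """Hypothetical helper: return a jit with the same keyword interface on either backend."""
    if backend_name == "jax":
        import jax  # imported lazily so NumPy-only environments never need JAX

        return jax.jit
    return numpy_jit


def weighted_sum(x, w):
    return (x * w).sum()


fast = pick_jit("numpy")(weighted_sum, donate_argnums=(0,))
print(fast is weighted_sum)  # True
```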
lstsq
lstsq(a: Array, b: Array) -> tuple[Array, Any, Any, Any]

Least squares solution.

Parameters:

a (Array, required): Coefficient matrix.
b (Array, required): Right-hand side.

Returns:

tuple[Array, Any, Any, Any]: Tuple (solution, residuals, rank, singular_values).

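The 4-tuple has the same shape as the result of numpy.linalg.lstsq, which the sketch below calls directly as a stand-in (an assumption about what the NumPy backend wraps). One NumPy detail worth knowing: residuals is an empty array when the coefficient matrix is rank-deficient or has no more rows than columns, so code should not assume it is always populated.

```python
import numpy as np

X = np.array([[1.0, 0.0], [1.0, 1.0], [1.0, 2.0], [1.0, 3.0]])
y = np.array([1.0, 2.9, 5.1, 7.0])

# rcond=None selects NumPy's machine-precision cutoff for small singular values.
x, residuals, rank, s = np.linalg.lstsq(X, y, rcond=None)
print(rank, residuals.shape)  # 2 (1,)

# Duplicating a column makes X rank-deficient, so residuals comes back empty.
X_dup = np.column_stack([X, X[:, 1]])
_, residuals_dup, rank_dup, _ = np.linalg.lstsq(X_dup, y, rcond=None)
print(rank_dup, residuals_dup.size)  # 2 0
```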
norm
norm(a: Array, ord: Any = None, axis: int | None = None) -> Array

Matrix or vector norm.

Parameters:

a (Array, required): Input array.
ord (Any, default None): Order of the norm.
axis (int | None, default None): Axis along which to compute the norm.

Returns:

Array: Norm of the array.

ones
ones(shape: tuple[int, ...], dtype: Any = None) -> Array

Create array of ones.

Parameters:

shape (tuple[int, ...], required): Shape of the array.
dtype (Any, default None): Data type (defaults to float64).

Returns:

Array: Array of ones.

qr
qr(a: Array) -> tuple[Array, Array]

QR decomposition.

Parameters:

a (Array, required): Matrix to decompose.

Returns:

tuple[Array, Array]: Tuple (Q, R), where Q has orthonormal columns (orthogonal when square) and R is upper triangular.

solve
solve(a: Array, b: Array) -> Array

Solve linear system a @ x = b.

Parameters:

a (Array, required): Coefficient matrix.
b (Array, required): Right-hand side.

Returns:

Array: Solution x.

solve_triangular
solve_triangular(a: Array, b: Array, lower: bool = False) -> Array

Solve triangular linear system.

Parameters:

a (Array, required): Triangular coefficient matrix.
b (Array, required): Right-hand side.
lower (bool, default False): If True, a is lower triangular.

Returns:

Array: Solution x.

svd
svd(a: Array, full_matrices: bool = True) -> tuple[Array, Array, Array]

Singular value decomposition.

Parameters:

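a (Array, required): Matrix to decompose.
full_matrices (bool, default True): If True, return full-size U and Vt; if False, return the reduced (thin) factors.

Returns:

tuple[Array, Array, Array]: Tuple (U, s, Vt) such that a = U @ diag(s) @ Vt (for a non-square a with full_matrices=True, U or Vt must first be trimmed to min(m, n) singular vectors).
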
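The full_matrices flag changes the shapes. The sketch below uses numpy.linalg.svd as a stand-in for the backend method (assuming it follows the same convention): for a 3×2 input, the full U is 3×3 and must be trimmed before reconstruction, while the thin U is already 3×2.

```python
import numpy as np

a = np.arange(6.0).reshape(3, 2)  # m=3 rows, n=2 columns

U, s, Vt = np.linalg.svd(a, full_matrices=True)
print(U.shape, s.shape, Vt.shape)  # (3, 3) (2,) (2, 2)
k = s.shape[0]
full_ok = np.allclose(U[:, :k] @ np.diag(s) @ Vt, a)  # keep only the first k columns of U

U_thin, s_thin, Vt_thin = np.linalg.svd(a, full_matrices=False)
print(U_thin.shape)  # (3, 2)
thin_ok = np.allclose(U_thin @ np.diag(s_thin) @ Vt_thin, a)
```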
vmap
vmap(fn: Callable, in_axes: int | tuple[int, ...] = 0) -> Callable

Vectorize a function over a batch dimension.

For NumPy backend, this uses a Python loop with np.stack. For JAX backend, this uses jax.vmap.

Parameters:

fn (Callable, required): Function to vectorize.
in_axes (int | tuple[int, ...], default 0): Axis to map over for each input.

Returns:

Callable: Vectorized function.

zeros
zeros(shape: tuple[int, ...], dtype: Any = None) -> Array

Create array of zeros.

Parameters:

shape (tuple[int, ...], required): Shape of the array.
dtype (Any, default None): Data type (defaults to float64).

Returns:

Array: Array of zeros.

Functions

backend

backend(name: BackendName) -> Iterator[None]

Context manager for temporary backend switching.

This is primarily intended for testing. It temporarily switches the backend and restores the previous state on exit.

Parameters:

name (BackendName, required): Backend name to use within the context.

Yields:

None. The previous backend is restored on exit.

Examples:

import bossanova
with bossanova.backend("numpy"):
    print(bossanova.get_backend())
# numpy

clear_ops_cache

clear_ops_cache() -> None

Clear the backend operations cache.

This is primarily for testing purposes. Clears the cached backend instances so that the next call to get_ops() creates a fresh instance.

This should only be used in tests. Using it in production code can lead to inconsistent behavior.

get_backend

get_backend() -> BackendName

Get the current backend name.

If no backend has been explicitly set, auto-detects the best available backend on first call.

Returns:

BackendName: The current backend name ('jax' or 'numpy').

Examples:

import bossanova
bossanova.get_backend()
# 'jax' (auto-detected; the result depends on which libraries are installed)

get_ops

get_ops() -> 'ArrayOps'

Get array operations for the current backend.

Returns the appropriate backend instance (NumPyBackend or JAXBackend) based on the current backend setting. The instance is cached so that repeated calls return the same object.

Returns:

ArrayOps: ArrayOps instance for the current backend.

Examples:

>>> from maths import get_ops
>>> ops = get_ops()
>>> X = ops.asarray([[1, 2], [3, 4]])
>>> L = ops.cholesky(X @ X.T)

lock_backend

lock_backend() -> None

Lock the backend to prevent switching after model fitting.

This is called internally when models are fitted to ensure consistent behavior throughout a session.

reset_backend

reset_backend() -> None

Reset backend state (for testing only).

This should only be used in tests. Using it in production code can lead to inconsistent behavior.

set_backend

set_backend(name: BackendName) -> None

Set the backend to use for computations.

Must be called before any model fitting occurs. Once a model has been fitted, the backend is locked and cannot be changed.

Parameters:

name (BackendName, required): Backend name, either 'jax' or 'numpy'.

Examples:

import bossanova
bossanova.set_backend("numpy")
bossanova.get_backend()
# 'numpy'
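The functions above share a small amount of module-level state: the current backend name, a lock flag, and a cache of ops instances. The following is a hypothetical, simplified model of how that state could fit together, not bossanova's implementation. The _StubOps class, the find_spec-based auto-detection rule, the RuntimeError raised on a locked backend, and the context manager restoring the previous name directly on exit are all assumptions made for the sketch.

```python
from __future__ import annotations

import importlib.util
from contextlib import contextmanager
from typing import Iterator, Literal

# Hypothetical model of the dispatch state; not bossanova's actual code.
BackendName = Literal["jax", "numpy"]

_current: BackendName | None = None  # None until set explicitly or auto-detected
_locked = False
_ops_cache: dict[str, _StubOps] = {}


class _StubOps:
    """Stand-in for NumPyBackend / JAXBackend (hypothetical)."""

    def __init__(self, name: str) -> None:
        self.name = name


def get_backend() -> BackendName:
    global _current
    if _current is None:
        # Assumed rule: prefer JAX when it is installed, else fall back to NumPy.
        _current = "jax" if importlib.util.find_spec("jax") is not None else "numpy"
    return _current


def set_backend(name: BackendName) -> None:
    global _current
    if name not in ("jax", "numpy"):
        raise ValueError(f"unknown backend: {name!r}")
    if _locked and name != get_backend():
        raise RuntimeError("backend is locked after model fitting")  # assumed error type
    _current = name


def lock_backend() -> None:
    global _locked
    get_backend()  # resolve auto-detection before freezing the choice
    _locked = True


def get_ops() -> _StubOps:
    name = get_backend()
    if name not in _ops_cache:
        _ops_cache[name] = _StubOps(name)  # built once, then reused
    return _ops_cache[name]


def clear_ops_cache() -> None:
    _ops_cache.clear()


def reset_backend() -> None:
    global _current, _locked
    _current, _locked = None, False
    clear_ops_cache()


@contextmanager
def backend(name: BackendName) -> Iterator[None]:
    global _current
    previous = get_backend()
    set_backend(name)  # respects the lock on entry
    try:
        yield
    finally:
        _current = previous  # restore the previous name directly on exit
```

Keying the cache by backend name is what lets repeated get_ops() calls return the same object, and why a test that wants a fresh instance has to call clear_ops_cache() first.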

Modules

dispatch

Backend detection, switching, and ArrayOps dispatch.

Functions:

backend: Context manager for temporary backend switching.
clear_ops_cache: Clear the backend operations cache.
get_backend: Get the current backend name.
get_ops: Get array operations for the current backend.
lock_backend: Lock the backend to prevent switching after model fitting.
reset_backend: Reset backend state (for testing only).
set_backend: Set the backend to use for computations.

Attributes:

BackendName

Attributes

BackendName
BackendName = Literal['jax', 'numpy']

jax

JAX backend implementation.

This module provides the JAX implementation of array operations for bossanova. It implements the ArrayOps protocol.

Note: JAX x64 config is set in bossanova/__init__.py at package import time, before any submodules are loaded. By the time this module is imported (lazily, via get_ops()), x64 is already enabled.

Classes:

JAXBackend: JAX implementation of array operations.

Classes

JAXBackend
JAXBackend() -> None

JAX implementation of array operations.

This backend uses JAX for array operations and provides JIT compilation and vectorization via vmap.

Attributes:

np: The jax.numpy module.
jax: The jax module (for jit, vmap, etc.).

Functions:

arange: Create array with evenly spaced values.
asarray: Convert input to JAX array.
cho_solve: Solve using Cholesky factor.
cholesky: Cholesky decomposition (lower triangular).
det: Matrix determinant.
eigh: Eigendecomposition of symmetric matrix.
eye: Create identity matrix.
full: Create array filled with a scalar value.
grad: Compute gradient of a function.
inv: Matrix inverse.
jit: JIT-compile a function using JAX.
lstsq: Least squares solution.
norm: Matrix or vector norm.
ones: Create array of ones.
qr: QR decomposition.
qr_pivoted: Pivoted QR decomposition for rank detection.
solve: Solve linear system a @ x = b.
solve_triangular: Solve triangular linear system.
svd: Singular value decomposition.
value_and_grad: Compute value and gradient of a function.
vmap: Vectorize a function over a batch dimension using JAX.
zeros: Create array of zeros.

Attributes

jax
jax = jax

np
np = jnp

Functions

arange
arange(start: int, stop: int | None = None, step: int = 1) -> jax.Array

Create array with evenly spaced values.

asarray
asarray(x: Any, dtype: Any = None) -> jax.Array

Convert input to JAX array.

cho_solve
cho_solve(c_and_lower: tuple[jax.Array, bool], b: jax.Array) -> jax.Array

Solve using Cholesky factor.

Parameters:

c_and_lower (tuple[Array, bool], required): Tuple of (cholesky_factor, is_lower).
b (Array, required): Right-hand side.

Returns:

Array: Solution x to A @ x = b, where A = L @ L.T.

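Solving with a Cholesky factor replaces one general solve by two triangular ones: with A = L @ L.T, first solve L @ z = b by forward substitution, then L.T @ x = z by back substitution. The NumPy-only sketch below spells out those substitutions for the lower-triangular case and a 1-D right-hand side; forward_substitute, back_substitute and cho_solve_lower are hypothetical helper names, and a real backend would call an optimized routine instead.

```python
import numpy as np


def forward_substitute(L, b):
    """Solve L @ z = b for lower-triangular L, one row at a time from the top."""
    n = L.shape[0]
    z = np.zeros(n)
    for i in range(n):
        z[i] = (b[i] - L[i, :i] @ z[:i]) / L[i, i]
    return z


def back_substitute(U, z):
    """Solve U @ x = z for upper-triangular U, one row at a time from the bottom."""
    n = U.shape[0]
    x = np.zeros(n)
    for i in range(n - 1, -1, -1):
        x[i] = (z[i] - U[i, i + 1:] @ x[i + 1:]) / U[i, i]
    return x


def cho_solve_lower(L, b):
    """Solve A @ x = b given the lower Cholesky factor L of A (A = L @ L.T)."""
    return back_substitute(L.T, forward_substitute(L, b))


A = np.array([[4.0, 2.0], [2.0, 3.0]])
b = np.array([2.0, 1.0])
L = np.linalg.cholesky(A)
x = cho_solve_lower(L, b)
print(np.allclose(A @ x, b))  # True
```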
cholesky
cholesky(a: jax.Array) -> jax.Array

Cholesky decomposition (lower triangular).

det
det(a: jax.Array) -> jax.Array

Matrix determinant.

eigh
eigh(a: jax.Array) -> tuple[jax.Array, jax.Array]

Eigendecomposition of symmetric matrix.

eye
eye(n: int, dtype: Any = None) -> jax.Array

Create identity matrix.

full
full(shape: tuple[int, ...], fill_value: float, dtype: Any = None) -> jax.Array

Create array filled with a scalar value.

grad
grad(fn: Callable, argnums: int = 0) -> Callable

Compute gradient of a function.

Parameters:

fn (Callable, required): Function to differentiate.
argnums (int, default 0): Index of the argument to differentiate with respect to.

Returns:

Callable: Function that computes the gradient.

inv
inv(a: jax.Array) -> jax.Array

Matrix inverse.

jit
jit(fn: Callable, *, donate_argnums: tuple[int, ...] | None = None) -> Callable

JIT-compile a function using JAX.

lstsq
lstsq(a: jax.Array, b: jax.Array) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]

Least squares solution.

norm
norm(a: jax.Array, ord: Any = None, axis: int | None = None) -> jax.Array

Matrix or vector norm.

ones
ones(shape: tuple[int, ...], dtype: Any = None) -> jax.Array

Create array of ones.

qr
qr(a: jax.Array) -> tuple[jax.Array, jax.Array]

QR decomposition.

qr_pivoted
qr_pivoted(a: jax.Array) -> tuple[jax.Array, jax.Array, jax.Array]

Pivoted QR decomposition for rank detection.

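Returns (Q, R, P), where Q has orthonormal columns, R is upper triangular, and P is an integer array of column permutation indices.

The factorization satisfies A[:, P] = Q @ R.

Column pivoting orders the columns so that the absolute values on the diagonal of R are non-increasing, so the numerical rank can be estimated by counting the diagonal entries above a tolerance.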

solve
solve(a: jax.Array, b: jax.Array) -> jax.Array

Solve linear system a @ x = b.

solve_triangular
solve_triangular(a: jax.Array, b: jax.Array, lower: bool = False) -> jax.Array

Solve triangular linear system.

svd
svd(a: jax.Array, full_matrices: bool = True) -> tuple[jax.Array, jax.Array, jax.Array]

Singular value decomposition.

value_and_grad
value_and_grad(fn: Callable, argnums: int = 0) -> Callable

Compute value and gradient of a function.

Parameters:

fn (Callable, required): Function to differentiate.
argnums (int, default 0): Index of the argument to differentiate with respect to.

Returns:

Callable: Function that returns (value, gradient).

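As an example, the sketch below calls jax.value_and_grad directly on a sum-of-squares loss. We assume JAXBackend.value_and_grad wraps it, which the matching signature suggests but the docs do not state. The result is checked against the analytic gradient -2 X.T @ (y - X @ beta).

```python
import jax

jax.config.update("jax_enable_x64", True)  # float64, as bossanova enables at import time

import jax.numpy as jnp


def sse(beta, X, y):
    """Sum of squared residuals for a linear model."""
    resid = y - X @ beta
    return jnp.sum(resid**2)


X = jnp.array([[1.0, 0.0], [1.0, 1.0], [1.0, 2.0]])
y = jnp.array([1.0, 2.0, 4.0])
beta = jnp.zeros(2)

# argnums=0: differentiate with respect to beta only.
loss_and_grad = jax.value_and_grad(sse, argnums=0)
value, grad = loss_and_grad(beta, X, y)

expected_grad = -2.0 * X.T @ (y - X @ beta)  # analytic gradient
print(float(value))  # 21.0
print(jnp.allclose(grad, expected_grad))  # True
```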
vmap
vmap(fn: Callable, in_axes: int | tuple[int, ...] = 0) -> Callable

Vectorize a function over a batch dimension using JAX.

zeros
zeros(shape: tuple[int, ...], dtype: Any = None) -> jax.Array

Create array of zeros.

numpy

NumPy backend implementation.

This module provides the NumPy/SciPy implementation of array operations for bossanova. It implements the ArrayOps protocol.

Classes:

NumPyBackend: NumPy/SciPy implementation of array operations.

Classes

NumPyBackend
NumPyBackend() -> None

NumPy/SciPy implementation of array operations.

This backend uses NumPy for array operations and SciPy for linear algebra. jit is a no-op that returns the function unchanged, and vmap is a simple loop-based implementation.

Attributes:

np: The numpy module.

Functions:

arange: Create array with evenly spaced values.
asarray: Convert input to numpy array.
cholesky: Cholesky decomposition (lower triangular).
det: Matrix determinant.
eigh: Eigendecomposition of symmetric matrix.
eye: Create identity matrix.
full: Create array filled with a scalar value.
inv: Matrix inverse.
jit: No-op, since NumPy has no JIT compilation.
lstsq: Least squares solution.
norm: Matrix or vector norm.
ones: Create array of ones.
qr: QR decomposition (economic mode).
qr_pivoted: Pivoted QR decomposition for rank detection.
solve: Solve linear system a @ x = b.
solve_triangular: Solve triangular linear system.
svd: Singular value decomposition.
vmap: Vectorize via Python loop.
zeros: Create array of zeros.

Attributes

np
np = np

Functions

arange
arange(start: int, stop: int | None = None, step: int = 1) -> np.ndarray

Create array with evenly spaced values.

asarray
asarray(x: Any, dtype: Any = None) -> np.ndarray

Convert input to numpy array.

cholesky
cholesky(a: np.ndarray) -> np.ndarray

Cholesky decomposition (lower triangular).

det
det(a: np.ndarray) -> np.floating[Any]

Matrix determinant.

eigh
eigh(a: np.ndarray) -> tuple[np.ndarray, np.ndarray]

Eigendecomposition of symmetric matrix.

eye
eye(n: int, dtype: Any = None) -> np.ndarray

Create identity matrix.

full
full(shape: tuple[int, ...], fill_value: float, dtype: Any = None) -> np.ndarray

Create array filled with a scalar value.

inv
inv(a: np.ndarray) -> np.ndarray

Matrix inverse.

jit
jit(fn: Callable, *, donate_argnums: tuple[int, ...] | None = None) -> Callable

No-op: NumPy doesn’t have JIT compilation.

lstsq
lstsq(a: np.ndarray, b: np.ndarray) -> tuple[np.ndarray, np.ndarray, int, np.ndarray]

Least squares solution.

norm
norm(a: np.ndarray, ord: Any = None, axis: int | None = None) -> np.ndarray

Matrix or vector norm.

ones
ones(shape: tuple[int, ...], dtype: Any = None) -> np.ndarray

Create array of ones.

qr
qr(a: np.ndarray) -> tuple[np.ndarray, np.ndarray]

QR decomposition (economic mode).

qr_pivoted
qr_pivoted(a: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]

Pivoted QR decomposition for rank detection.

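Returns (Q, R, P), where Q has orthonormal columns, R is upper triangular, and P is an integer array of column permutation indices.

The factorization satisfies A[:, P] = Q @ R.

Column pivoting orders the columns so that the absolute values on the diagonal of R are non-increasing, so the numerical rank can be estimated by counting the diagonal entries above a tolerance.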
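SciPy exposes pivoted QR as scipy.linalg.qr(..., pivoting=True); to stay dependency-free, the sketch below instead implements column-pivoted QR with modified Gram-Schmidt in plain NumPy. qr_pivoted_mgs and numerical_rank are hypothetical names, and the relative tolerance of 1e-10 is an assumption, not the backend's setting. It returns (Q, R, P) with A[:, P] = Q @ R and estimates rank from the diagonal of R.

```python
import numpy as np


def qr_pivoted_mgs(a):
    """Column-pivoted QR via modified Gram-Schmidt: returns Q, R, P with a[:, P] = Q @ R."""
    work = np.array(a, dtype=float)  # columns are orthogonalized in place
    m, n = work.shape
    k = min(m, n)
    q = np.zeros((m, k))
    r = np.zeros((k, n))
    perm = np.arange(n)
    for j in range(k):
        # Pivot: bring the remaining column with the largest norm to position j.
        p = j + int(np.argmax(np.linalg.norm(work[:, j:], axis=0)))
        work[:, [j, p]] = work[:, [p, j]]
        r[:, [j, p]] = r[:, [p, j]]
        perm[[j, p]] = perm[[p, j]]
        r[j, j] = np.linalg.norm(work[:, j])
        if r[j, j] == 0.0:
            break  # every remaining column is exactly zero
        q[:, j] = work[:, j] / r[j, j]
        # Remove the q_j component from all later columns.
        r[j, j + 1:] = q[:, j] @ work[:, j + 1:]
        work[:, j + 1:] -= np.outer(q[:, j], r[j, j + 1:])
    return q, r, perm


def numerical_rank(a, rtol=1e-10):
    """Count diagonal entries of R that are large relative to the first one."""
    _, r, _ = qr_pivoted_mgs(a)
    d = np.abs(np.diag(r))
    if d.size == 0 or d[0] == 0.0:
        return 0
    return int(np.sum(d > rtol * d[0]))


x1 = np.array([1.0, 2.0, 3.0, 4.0])
x2 = np.array([0.5, -1.0, 2.0, 0.0])
X = np.column_stack([x1, x2, x1 + x2])  # third column is redundant
Q, R, P = qr_pivoted_mgs(X)
print(np.allclose(X[:, P], Q @ R))  # True
print(numerical_rank(X))  # 2
```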

solve
solve(a: np.ndarray, b: np.ndarray) -> np.ndarray

Solve linear system a @ x = b.

solve_triangular
solve_triangular(a: np.ndarray, b: np.ndarray, lower: bool = False) -> np.ndarray

Solve triangular linear system.

svd
svd(a: np.ndarray, full_matrices: bool = True) -> tuple[np.ndarray, np.ndarray, np.ndarray]

Singular value decomposition.

vmap
vmap(fn: Callable, in_axes: int | tuple[int, ...] = 0) -> Callable

Vectorize via Python loop.

This is a simple implementation that loops over the batch dimension and stacks the results. It is slower than JAX's vmap and supports a narrower range of in_axes values, but it has the same call signature.

Parameters:

fn (Callable, required): Function to vectorize.
in_axes (int | tuple[int, ...], default 0): Axis to map over (only 0, or a tuple of 0s and Nones, is supported).

Returns:

Callable: Vectorized function.

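A loop-based vmap of this kind fits in a few lines. loop_vmap is a hypothetical name; following the documented restriction, it accepts only in_axes of 0 or a tuple of 0s and Nones (None meaning the argument is passed unchanged to every call), and it assumes fn returns a single array or scalar.

```python
import numpy as np


def loop_vmap(fn, in_axes=0):
    """Map fn over axis 0 of the batched arguments with a Python loop, then stack."""

    def vectorized(*args):
        axes = (in_axes,) * len(args) if isinstance(in_axes, int) else tuple(in_axes)
        if len(axes) != len(args) or any(ax not in (0, None) for ax in axes):
            raise ValueError("in_axes must be 0 or a tuple of 0s and Nones, one per argument")
        batched = [arg for arg, ax in zip(args, axes) if ax == 0]
        if not batched:
            raise ValueError("at least one argument must be mapped")
        results = []
        for i in range(len(batched[0])):
            # Slice mapped arguments; pass None-axis arguments through unchanged.
            call_args = [arg[i] if ax == 0 else arg for arg, ax in zip(args, axes)]
            results.append(fn(*call_args))
        return np.stack(results)

    return vectorized


X = np.arange(6.0).reshape(3, 2)  # batch of 3 row vectors
w = np.array([10.0, 1.0])         # shared across the batch
dot_each = loop_vmap(lambda x, w: x @ w, in_axes=(0, None))
print(dot_each(X, w))  # [ 1. 23. 45.]
```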
zeros
zeros(shape: tuple[int, ...], dtype: Any = None) -> np.ndarray

Create array of zeros.

protocol

Array operations protocol for backend abstraction.

This module defines the interface that both NumPy and JAX backends must implement. Using a Protocol allows for static type checking while maintaining flexibility in implementation.

Classes:

ArrayOps: Protocol for array operations across backends.

Attributes:

Array

Attributes

Array
Array = TypeVar('Array')

Classes

ArrayOps

Bases: Protocol

Protocol for array operations across backends.

This defines the interface that both NumPyBackend and JAXBackend must implement. All linear algebra and array manipulation needed by bossanova should go through this interface.
