This document provides a comprehensive overview of linear equation solving algorithms available in Julia, JAX, PyTorch, Lineax, and NumPy. It covers the theory, convergence properties, usage scenarios, and implementation details for each library.
This section covers the theoretical underpinnings of common linear solvers.
Direct methods compute the exact solution (up to floating-point error) in a finite number of steps, typically involving matrix factorization.
| Algorithm | Theory | Convergence / Stability | Best For | Complexity |
|---|---|---|---|---|
| LU (Gaussian Elimination) | Factors $A = PLU$ (Permutation, Lower, Upper). Reduces problem to two triangular solves. | Stable (with partial pivoting). Backward stable. | General square, non-singular matrices. Default for most dense solvers. | $O(N^3)$ (factor) + $O(N^2)$ (solve) |
| Cholesky | Factors $A = LL^T$ (Lower triangular). Requires $A$ to be Symmetric and Positive Definite (SPD). | Very Stable. Fails if matrix is not PD. ~2x faster than LU. | SPD matrices (e.g., covariance matrices, physics simulations). | $\frac{1}{3}N^3$ |
| QR Decomposition | Factors $A = QR$ ($Q$ orthogonal, $R$ upper triangular). Solves $Rx = Q^T b$. | Extremely stable. Better numerical properties than LU for ill-conditioned matrices. | Least squares, rectangular systems, or highly ill-conditioned square matrices. | $O(N^3)$ (higher constant than LU) |
| SVD (Singular Value) | Factors $A = U \Sigma V^T$. Solves via pseudoinverse $x = V \Sigma^+ U^T b$. | Most Stable. Handles rank-deficient and near-singular matrices perfectly. | Rank-deficient systems, minimum-norm least squares, analysis of system stability. | $O(N^3)$ (very high constant) |
| Diagonal / Triangular | Direct substitution (Forward/Backward). | Exact and stable. | Diagonal or Triangular systems. Often the final step of other factorizations. | $O(N)$ (diag) / $O(N^2)$ (tri) |
| TDMA (Thomas Algorithm) | Gaussian elimination optimized for tridiagonal systems. | Stable for diagonally dominant or SPD matrices. Unstable otherwise. | 1D PDEs (heat/wave equations), cubic splines, time-series smoothing. | $O(N)$ |
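To make the trade-offs above concrete, here is a short NumPy/SciPy sketch (variable names are ours) solving the same SPD system with three of the listed factorizations; all three agree, differing only in cost and robustness:

```python
import numpy as np
from scipy.linalg import lu_factor, lu_solve, cho_factor, cho_solve

rng = np.random.default_rng(0)
N = 50
M = rng.standard_normal((N, N))
A = M @ M.T + N * np.eye(N)      # SPD and well conditioned by construction
b = rng.standard_normal(N)

# LU: the general-purpose default
x_lu = lu_solve(lu_factor(A), b)

# Cholesky: exploits SPD structure (~2x cheaper factorization)
x_chol = cho_solve(cho_factor(A), b)

# SVD pseudoinverse: x = V * diag(1/s) * U^T * b (most robust, highest cost)
U, s, Vt = np.linalg.svd(A)
x_svd = Vt.T @ ((U.T @ b) / s)

assert np.allclose(x_lu, x_chol) and np.allclose(x_lu, x_svd)
```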
Iterative methods approximate the solution by minimizing an error function over a subspace. They are preferred for large sparse matrices where $O(N^2)$ storage of factors is prohibitive.
| Algorithm | Theory | Convergence | Best For | Memory |
|---|---|---|---|---|
| CG (Conjugate Gradient) | Minimizes error in $A$-norm over Krylov subspace. | Depends on $\sqrt{\kappa(A)}$ (condition number) and eigenvalue clustering. Guaranteed for SPD. | Large Sparse SPD matrices. | Low ($O(N)$) |
| GMRES (Generalized Minimal Residual) | Minimizes the residual norm $\|b - Ax_k\|_2$ over the Krylov subspace, built via Arnoldi iteration. | Monotonically decreasing residual. Depends on eigenvalue distribution. | General non-symmetric square systems. | High (stores basis vectors; often restarted: GMRES(k)). |
| BiCGStab (Bi-Conjugate Gradient Stabilized) | Variation of BiCG with stabilizing updates to smooth convergence. | Irregular convergence (spiky residual norm), but often faster than GMRES per step. No theoretical guarantee. | General non-symmetric systems where GMRES memory is too high. | Low ($O(N)$) |
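A minimal SciPy sketch of when each iterative family applies (the shifted-Laplacian test matrix is our own choice; default tolerances assumed):

```python
import numpy as np
import scipy.sparse as sp
import scipy.sparse.linalg as spla

N = 200
# Shifted 1D Laplacian: sparse, SPD, and strictly diagonally dominant,
# so both solvers converge quickly at default tolerances
A = sp.diags([-1.0, 2.5, -1.0], [-1, 0, 1], shape=(N, N), format='csr')
b = np.ones(N)

# CG applies because A is SPD
x_cg, info_cg = spla.cg(A, b)          # info == 0 signals convergence

# A one-sided perturbation makes the matrix non-symmetric: use GMRES
B = (A + sp.diags([0.4], [1], shape=(N, N))).tocsr()
x_gm, info_gm = spla.gmres(B, b)
```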
### Julia (LinearAlgebra & LinearSolve.jl)

Julia uses a powerful polyalgorithm via the `\` operator, dispatching to LAPACK (dense) or SuiteSparse/specialized code (sparse).
`A \ b` checks matrix properties (Triangular -> Diagonal -> Tridiagonal -> Hermitian -> General) and dispatches accordingly:
- Tridiagonal: `dgtsv` or native code.
- Hermitian positive definite: `dposv` (Cholesky).
- General square: `dgesv` (LU).
- Rectangular: `dgels` (QR min-norm solution).

For sparse matrices, `A \ b` dispatches to SuiteSparse factorizations.
Iterative solvers live in packages such as IterativeSolvers.jl or Krylov.jl; LinearSolve.jl provides a unified interface over all of these.

### Lineax (JAX)

Lineax is a dedicated JAX library for linear solves, designed for differentiation and structure awareness.
The entry point is `lineax.linear_solve(operator, vector, solver=...)`. Available solvers:
- `lineax.AutoLinearSolver`: automatically selects based on operator structure (e.g., `TridiagonalLinearOperator` $\to$ `Tridiagonal`, `DiagonalLinearOperator` $\to$ `Diagonal`, `MatrixLinearOperator` $\to$ LU or QR).
- `lineax.Tridiagonal`: $O(N)$ solver for tridiagonal operators.
- `lineax.LU`, `lineax.QR`, `lineax.SVD`: standard direct solvers.
- `lineax.Cholesky`: for positive definite operators.
- `lineax.CG`, `lineax.GMRES`, `lineax.BiCGStab`: iterative solvers written in JAX.

### JAX (jax.numpy & jax.scipy)

JAX wraps standard LAPACK/cuSOLVER routines.
- Dense: `jax.numpy.linalg.solve` (LU), `jax.numpy.linalg.lstsq` (SVD/QR).
- Iterative: `jax.scipy.sparse.linalg` contains iterative solvers only (`cg`, `gmres`, `bicgstab`).
- Sparse direct support is experimental (`spsolve` exists but is limited).
- Tridiagonal: `jax.lax.linalg.tridiagonal_solve` (TPU-tailored, uses the Thomas algorithm).

### PyTorch (torch.linalg)

PyTorch provides dense solvers similar to NumPy/JAX, powered by MAGMA/cuSOLVER on GPU.
- `torch.linalg.solve` (LU), `torch.linalg.lstsq` (QR/SVD).
- `torch.sparse` exists, but solving sparse systems usually requires conversion to dense or external libraries, though simple sparse-dense solves exist.

### NumPy (numpy.linalg)

The standard CPU reference.
- `numpy.linalg.solve` (LAPACK `_gesv` LU).
- Banded systems: `scipy.linalg.solve_banded` (LAPACK `dgbsv`) for $O(N)$ performance.
- No sparse support in `numpy` itself; users must use `scipy.sparse.linalg`.

| Feature | Julia | Lineax (JAX) | JAX (Native) | PyTorch | NumPy |
|---|---|---|---|---|---|
| Default | Polyalgorithm `\` (LU/QR/Chol) | `AutoLinearSolver` | LU (`solve`) | LU (`solve`) | LU (`solve`) |
| Sparse Direct | Yes (integrated) | No (dense mostly) | No (experimental) | No | No (use SciPy) |
| Sparse Iterative | Via Packages | Yes (CG/GMRES) | Yes (CG/GMRES) | Limited | No (use SciPy) |
| Gradients | Via Zygote/Enzyme | First-class | First-class | First-class | No |
| GPU | Yes (CUDA.jl) | Yes | Yes | Yes | No |
import numpy as np
# System: Ax = b
N = 100
A = np.random.rand(N, N)
# Ensure non-singularity
A = A + np.eye(N) * N
b = np.random.rand(N)
# 1. Standard Solve (LU)
x_lu = np.linalg.solve(A, b)
# 2. Least Squares (if A is not square)
A_rect = np.random.rand(N + 10, N)
b_rect = np.random.rand(N + 10)
x_lstsq, residuals, rank, s = np.linalg.lstsq(A_rect, b_rect, rcond=None)
# 3. Cholesky (Manual)
# Only for SPD matrices: A must be symmetric and positive definite
A_spd = A @ A.T # Construct SPD matrix
L = np.linalg.cholesky(A_spd) # A = L L^T
# Solve L y = b, then L^T x = y (scipy.linalg.solve_triangular would
# exploit the triangular structure and do each solve in O(N^2))
y = np.linalg.solve(L, b)
x_chol = np.linalg.solve(L.T, y)
import jax
import jax.numpy as jnp
import jax.scipy.sparse.linalg
key = jax.random.PRNGKey(0)
key_A, key_b = jax.random.split(key)  # never reuse a PRNG key
# Data Generation
N = 100
A = jax.random.normal(key_A, (N, N))
A = A + jnp.eye(N) * N # Diagonally dominant -> non-singular
b = jax.random.normal(key_b, (N,))
# 1. Standard Dense Solve (LU on GPU)
x_dense = jnp.linalg.solve(A, b)
# 2. Sparse / Iterative Solve (CG)
# JAX requires a linear operator function (matvec)
def matvec(x):
return A @ x
# CG requires a Symmetric Positive Definite (SPD) matrix,
# so we solve the normal equations (A^T A) x = A^T b instead
A_spd = A.T @ A
b_spd = A.T @ b
x_cg, info = jax.scipy.sparse.linalg.cg(
lambda v: A_spd @ v,
b_spd,
maxiter=1000
)
# 3. GMRES (General matrices)
x_gmres, info = jax.scipy.sparse.linalg.gmres(matvec, b)
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
N = 100
A = torch.randn(N, N, device=device)
A = A + torch.eye(N, device=device) * N
b = torch.randn(N, 1, device=device) # (N, 1) for one RHS column; (N,) is treated as a vector
# 1. Standard Solve (LU)
x_lu = torch.linalg.solve(A, b)
# 2. Cholesky Solve
# Faster for SPD systems
A_spd = A @ A.T
# torch.linalg.cholesky_ex is safer for avoiding errors
L = torch.linalg.cholesky(A_spd)
x_chol = torch.cholesky_solve(b, L) # Takes L, not A
# 3. Least Squares
# driver='gels' (QR) or 'gelsd' (SVD - more stable); rank and
# singular values are only computed by the rank-revealing drivers
x_lstsq, residuals, rank, singular_values = torch.linalg.lstsq(A, b, driver='gels')
import jax
import jax.numpy as jnp
import lineax as lx
key = jax.random.PRNGKey(0)
N = 100
A_val = jax.random.normal(key, (N, N))
b = jax.random.normal(key, (N,))
# Lineax uses "LinearOperators"
operator = lx.MatrixLinearOperator(A_val)
# 1. Auto Solve (Best Practice)
# Automatically chooses LU, QR, etc.
solver = lx.AutoLinearSolver(well_posed=True)
solution = lx.linear_solve(operator, b, solver)
print(solution.value)
# 2. Explicit Choice (e.g., QR)
solver_qr = lx.QR()
solution_qr = lx.linear_solve(operator, b, solver_qr)
# ... (previous code)
# 3. Large Scale / Iterative (GMRES)
# Useful if 'operator' is defined functionally, not as a matrix
solver_gmres = lx.GMRES(rtol=1e-5, atol=1e-5)
solution_iter = lx.linear_solve(operator, b, solver_gmres)
Worked examples for solving tridiagonal systems in $O(N)$.
import numpy as np
import scipy.linalg
N = 100
# scipy.linalg.solve_banded expects the matrix in banded storage of
# shape (u + l + 1, N); for tridiagonal (l=1, u=1) that is (3, N):
# Row 0: upper diagonal (first element ignored)
# Row 1: main diagonal
# Row 2: lower diagonal (last element ignored)
ab = np.zeros((3, N))
ab[0, 1:] = np.random.rand(N-1) # Upper
ab[1, :] = np.random.rand(N) + 2 # Main (diagonally dominant)
ab[2, :-1] = np.random.rand(N-1) # Lower
b = np.random.rand(N)
x = scipy.linalg.solve_banded((1, 1), ab, b)
import jax.random as jr
import lineax as lx
import jax.numpy as jnp
N = 100
diagonal = jr.normal(jr.PRNGKey(0), (N,)) + 4.0  # shift keeps pivots away from zero in practice
upper = jr.normal(jr.PRNGKey(1), (N-1,))
lower = jr.normal(jr.PRNGKey(2), (N-1,))
b = jr.normal(jr.PRNGKey(3), (N,))
operator = lx.TridiagonalLinearOperator(diagonal, lower, upper)
solution = lx.linear_solve(operator, b, lx.Tridiagonal())
import jax
import jax.numpy as jnp
N = 100
# jax.lax.linalg.tridiagonal_solve(dl, d, du, b)
# dl: lower diagonal, shape (N,) (first element ignored)
# d:  main diagonal,  shape (N,)
# du: upper diagonal, shape (N,) (last element ignored)
# Note: the layout differs from scipy's banded [upper, main, lower] rows
dl = jax.random.normal(jax.random.key(0), (N,))
d = jax.random.normal(jax.random.key(1), (N,)) + 4.0  # shift towards diagonal dominance
du = jax.random.normal(jax.random.key(2), (N,))
b = jax.random.normal(jax.random.key(3), (N, 1))  # b needs a trailing RHS dimension
x = jax.lax.linalg.tridiagonal_solve(dl, d, du, b)
This section provides a deep dive into the theory, mathematical formulation, and pseudo-algorithms for the solvers mentioned above.
Direct solvers factorize matrix $A$ into simpler forms (triangular, diagonal, orthogonal) to make solving $Ax = b$ trivial (e.g., via simple back-substitution).
Theory: Any square matrix $A$ can be decomposed into a lower triangular matrix $L$ (with unit diagonal) and an upper triangular matrix $U$, such that $PA = LU$, where $P$ is a permutation matrix to ensure numerical stability (partial pivoting). Solving $Ax = b$ becomes two triangular solves: first $Ly = Pb$ (forward substitution), then $Ux = y$ (back substitution).
Pseudo-Code (Simplified without pivoting):
function LU_Decomposition(A):
n = size(A, 1)
L = eye(n)
U = copy(A)
for k = 1 to n-1:
for i = k+1 to n:
factor = U[i, k] / U[k, k]
L[i, k] = factor
U[i, k:] = U[i, k:] - factor * U[k, k:]
return L, U
Background: This is the standard “Gaussian Elimination” taught in linear algebra. Without pivoting ($P = I$), it is unstable if diagonal elements are near zero. With partial pivoting, the cost remains $O(N^3)$ and it is the industry standard for general dense systems.
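The pseudo-code above maps almost line-for-line onto NumPy; this sketch (ours; unpivoted, illustration only, so it can fail on a near-zero pivot) verifies $LU = A$ on a diagonally dominant matrix where pivoting is unnecessary:

```python
import numpy as np

def lu_decompose(A):
    """Doolittle LU without pivoting, mirroring the pseudo-code above."""
    A = A.astype(float)
    n = A.shape[0]
    L, U = np.eye(n), A.copy()
    for k in range(n - 1):
        for i in range(k + 1, n):
            factor = U[i, k] / U[k, k]   # fails if the pivot ~ 0 (hence pivoting)
            L[i, k] = factor
            U[i, k:] -= factor * U[k, k:]
    return L, U

rng = np.random.default_rng(0)
A = rng.standard_normal((5, 5)) + 5 * np.eye(5)  # dominant diagonal: safe pivots
L, U = lu_decompose(A)
assert np.allclose(L @ U, A)
```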
Theory: If $A$ is Symmetric and Positive Definite (SPD) (i.e., $x^T Ax > 0$ for all $x \neq 0$), it can be factored uniquely as $A = LL^T$, where $L$ is lower triangular. Solving $Ax=b \implies LL^T x = b$: first solve $Ly = b$, then $L^T x = y$.
Pseudo-Code:
function Cholesky(A):
n = size(A, 1)
L = zeros(n, n)
for i = 1 to n:
for j = 1 to i:
sum_val = sum(L[i, k] * L[j, k] for k = 1 to j-1)
if i == j: # Diagonal elements
L[i, j] = sqrt(A[i, i] - sum_val)
else:
L[i, j] = (1.0 / L[j, j]) * (A[i, j] - sum_val)
return L
Background: Cholesky is roughly twice as fast as LU because it exploits symmetry (only calculates lower triangle). It is numerically very stable; if the algorithm encounters a negative number inside the square root, it proves the matrix is not positive definite.
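A runnable version of the pseudo-code (ours, for illustration; real code should call a library routine), including the positive-definiteness check described above:

```python
import numpy as np

def cholesky(A):
    """Textbook Cholesky (lower-triangular), mirroring the pseudo-code.
    Raises if A is not positive definite (negative under the sqrt)."""
    n = A.shape[0]
    L = np.zeros((n, n))
    for i in range(n):
        for j in range(i + 1):
            s = L[i, :j] @ L[j, :j]          # sum of products over k < j
            if i == j:                        # diagonal element
                val = A[i, i] - s
                if val <= 0:
                    raise ValueError("matrix is not positive definite")
                L[i, j] = np.sqrt(val)
            else:
                L[i, j] = (A[i, j] - s) / L[j, j]
    return L

rng = np.random.default_rng(1)
M = rng.standard_normal((4, 4))
A = M @ M.T + 4 * np.eye(4)                  # SPD by construction
L = cholesky(A)
assert np.allclose(L @ L.T, A)
```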
Theory: Factors $A = QR$, where $Q$ is an orthogonal matrix ($Q^T Q = I$) and $R$ is upper triangular. Solving $Ax = b \implies QRx = b \implies Rx = Q^T b$. Since $Q$ is orthogonal, multiplying by $Q^T$ does not amplify errors, making it extremely stable.
Pseudo-Code (Householder Reflections):
function QR_Householder(A):
m, n = size(A)
Q = eye(m)
R = copy(A)
for k = 1 to n:
x = R[k:m, k]
# Construct Householder vector v to zero out elements below diagonal
e1 = zeros(length(x)); e1[0] = 1
v = sign(x[0]) * norm(x) * e1 + x
v = v / norm(v)
# Apply reflection to R and Q
R[k:m, k:n] = R[k:m, k:n] - 2 * outer(v, dot(v, R[k:m, k:n]))
Q[k:m, :] = Q[k:m, :] - 2 * outer(v, dot(v, Q[k:m, :]))
return Q.T, R
Background: While Gram-Schmidt is intuitively simpler, Householder reflections are implemented in libraries like LAPACK because they maintain orthogonality much better in floating-point arithmetic.
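The Householder pseudo-code, made concrete in NumPy (our sketch; `np.copysign` handles the sign choice that avoids cancellation, and a zero column is simply skipped):

```python
import numpy as np

def qr_householder(A):
    """Householder QR mirroring the pseudo-code above (assumes m >= n)."""
    m, n = A.shape
    Q = np.eye(m)                       # accumulates H_k ... H_1
    R = A.astype(float).copy()
    for k in range(n):
        x = R[k:, k]
        e1 = np.zeros_like(x)
        e1[0] = 1.0
        v = np.copysign(np.linalg.norm(x), x[0]) * e1 + x  # avoid cancellation
        nv = np.linalg.norm(v)
        if nv == 0:
            continue                    # column already zeroed below the diagonal
        v /= nv
        # Apply the reflection I - 2 v v^T to the trailing blocks of R and Q
        R[k:, k:] -= 2.0 * np.outer(v, v @ R[k:, k:])
        Q[k:, :] -= 2.0 * np.outer(v, v @ Q[k:, :])
    return Q.T, R                       # Q^T = H_1 ... H_n is the orthogonal factor

rng = np.random.default_rng(2)
A = rng.standard_normal((6, 4))
Q, R = qr_householder(A)
assert np.allclose(Q @ R, A)
```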
Theory: A specialized version of Gaussian Elimination for tridiagonal matrices. Since most elements are zero, we only eliminate the sub-diagonal. System: $a_i x_{i-1} + b_i x_i + c_i x_{i+1} = d_i$
Pseudo-Code:
function TDMA(a, b, c, d):
n = length(d)
# Forward elimination
c'[0] = c[0] / b[0]
d'[0] = d[0] / b[0]
for i = 1 to n-1:
temp = b[i] - a[i] * c'[i-1]
c'[i] = c[i] / temp
d'[i] = (d[i] - a[i] * d'[i-1]) / temp
# Backward substitution
x[n-1] = d'[n-1]
for i = n-2 down to 0:
x[i] = d'[i] - c'[i] * x[i+1]
return x
Background: This is an $O(N)$ algorithm, essential for 1D PDE solvers (like solving the heat equation implicitly). It is stable if the matrix is strictly diagonally dominant ($|b_i| > |a_i| + |c_i|$).
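The TDMA pseudo-code as runnable NumPy (ours), verified against the equivalent dense system; note the test matrix is built strictly diagonally dominant, per the stability condition above:

```python
import numpy as np

def tdma(a, b, c, d):
    """Thomas algorithm mirroring the pseudo-code.
    a: sub-diagonal (a[0] unused), b: main diagonal, c: super-diagonal (c[-1] unused)."""
    n = len(d)
    cp, dp = np.zeros(n), np.zeros(n)
    cp[0] = c[0] / b[0]
    dp[0] = d[0] / b[0]
    for i in range(1, n):                       # forward elimination
        denom = b[i] - a[i] * cp[i - 1]
        cp[i] = c[i] / denom
        dp[i] = (d[i] - a[i] * dp[i - 1]) / denom
    x = np.zeros(n)
    x[-1] = dp[-1]
    for i in range(n - 2, -1, -1):              # backward substitution
        x[i] = dp[i] - cp[i] * x[i + 1]
    return x

rng = np.random.default_rng(3)
n = 50
a, c = rng.random(n), rng.random(n)
b = a + c + rng.random(n) + 1.0                 # strictly diagonally dominant
d = rng.random(n)
x = tdma(a, b, c, d)
# Verify against the dense tridiagonal system
A = np.diag(b) + np.diag(a[1:], -1) + np.diag(c[:-1], 1)
assert np.allclose(A @ x, d)
```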
Direct solvers convert $A$ to a soluble form. Iterative solvers use only matrix-vector products ($v \to Av$) to search for the solution in a “Krylov subspace” $\mathcal{K}_k = \text{span}\{r_0, Ar_0, A^2 r_0, \dots, A^{k-1} r_0\}$.
Theory: Builds the solution from a sequence of $A$-orthogonal search directions $p_k$, meaning $p_i^T A p_j = 0$ for $i \neq j$. This “conjugacy” property ensures that each step brings us optimally closer to the solution in the $A$-norm, without undoing previous progress. Strictly requires $A$ to be SPD.
Pseudo-Code:
function CG(A, b, x0):
    x = x0
    r = b - A @ x0
p = r
rho = dot(r, r)
for k = 1 to max_iter:
Ap = A @ p
alpha = rho / dot(p, Ap)
x = x + alpha * p
r = r - alpha * Ap
new_rho = dot(r, r)
if sqrt(new_rho) < tolerance:
break
beta = new_rho / rho
p = r + beta * p
rho = new_rho
return x
Background: The magic of CG is that it only requires storage of a few vectors ($x, r, p$). For SPD matrices, it is the theoretical optimal Krylov solver.
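The CG pseudo-code as runnable NumPy (our sketch; the SPD test matrix is constructed so convergence is fast):

```python
import numpy as np

def cg(A, b, x0, tol=1e-10, max_iter=1000):
    """Conjugate Gradient mirroring the pseudo-code. A must be SPD."""
    x = x0.copy()
    r = b - A @ x
    p = r.copy()
    rho = r @ r
    for _ in range(max_iter):
        Ap = A @ p
        alpha = rho / (p @ Ap)          # step length along p
        x += alpha * p
        r -= alpha * Ap
        rho_new = r @ r
        if np.sqrt(rho_new) < tol:      # residual small enough
            break
        p = r + (rho_new / rho) * p     # next A-orthogonal direction
        rho = rho_new
    return x

rng = np.random.default_rng(4)
N = 100
M = rng.standard_normal((N, N))
A = M @ M.T + N * np.eye(N)             # SPD, well conditioned
b = rng.standard_normal(N)
x = cg(A, b, np.zeros(N))
assert np.allclose(A @ x, b)
```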
Theory: For general non-symmetric matrices, we cannot rely on short recurrences like CG. GMRES builds an orthonormal basis for the Krylov subspace explicitly (with Arnoldi iteration) and finds the vector $x_k$ in that subspace that minimizes the residual norm $\| b - Ax_k \|_2$.
Pseudo-Code (Simplified Arnoldi-based):
function GMRES(A, b, x0, m):
# m is restart parameter (GMRES(m))
r0 = b - A @ x0
beta = norm(r0)
V = [r0 / beta] # Basis vectors
H = zeros(m+1, m) # Hessenberg matrix
for j = 0 to m-1:
w = A @ V[j]
# Arnoldi Orthogonalization (Gram-Schmidt on Krylov vectors)
for i = 0 to j:
H[i, j] = dot(w, V[i])
w = w - H[i, j] * V[i]
        H[j+1, j] = norm(w)
        if H[j+1, j] ~ 0: break  # "happy breakdown": solution lies in current subspace
        V.append(w / H[j+1, j])
    # Solve the small least squares for y: min || beta * e1 - H * y ||_2
    # Return x = x0 + V_k * y
Background: GMRES stores all basis vectors $V_k$, so memory grows linearly with iterations. To manage this, we use “Restarted GMRES(m)”, where we discard the basis and restart with the current $x$ as guess after $m$ steps. It is the robust default for non-symmetric systems.
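A single (non-restarted) cycle of the pseudo-code in NumPy (our sketch; it uses `numpy.linalg.lstsq` for the small least-squares problem instead of the Givens-rotation update used in production codes):

```python
import numpy as np

def gmres_cycle(A, b, x0, m):
    """One cycle of GMRES(m): Arnoldi basis + small least-squares solve."""
    n = len(b)
    r0 = b - A @ x0
    beta = np.linalg.norm(r0)
    V = np.zeros((n, m + 1))            # orthonormal Krylov basis
    H = np.zeros((m + 1, m))            # Hessenberg matrix
    V[:, 0] = r0 / beta
    for j in range(m):
        w = A @ V[:, j]
        for i in range(j + 1):          # modified Gram-Schmidt against the basis
            H[i, j] = w @ V[:, i]
            w -= H[i, j] * V[:, i]
        H[j + 1, j] = np.linalg.norm(w)
        if H[j + 1, j] < 1e-14:         # happy breakdown: exact solution found
            m = j + 1
            break
        V[:, j + 1] = w / H[j + 1, j]
    e1 = np.zeros(m + 1)
    e1[0] = beta
    y, *_ = np.linalg.lstsq(H[:m + 1, :m], e1, rcond=None)
    return x0 + V[:, :m] @ y

rng = np.random.default_rng(5)
N = 60
A = rng.standard_normal((N, N)) + N * np.eye(N)   # non-symmetric, dominant
b = rng.standard_normal(N)
x = gmres_cycle(A, b, np.zeros(N), m=N)           # m = N: full (exact) GMRES
assert np.allclose(A @ x, b)
```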
Theory: An attempt to get the low memory of CG for non-symmetric systems. It uses a “shadow” Krylov subspace (using $A^T$) to maintain short recurrences (biorthogonality) like BiCG, but “stabilizes” the irregular convergence of BiCG by combining it with GMRES-like local minimization steps.
Pseudo-Code:
function BiCGStab(A, b, x0):
    x = x0
    r = b - A @ x0
r_hat = r # Shadow residual, arbitrary
p = r
rho = dot(r_hat, r)
for k = 1 to max_iter:
v = A @ p
alpha = rho / dot(r_hat, v)
s = r - alpha * v
t = A @ s
omega = dot(t, s) / dot(t, t)
x = x + alpha * p + omega * s
r = s - omega * t
if norm(r) < tolerance: break
new_rho = dot(r_hat, r)
beta = (new_rho / rho) * (alpha / omega)
p = r + beta * (p - omega * v)
rho = new_rho
return x
Background: BiCGStab is very popular because it often converges smoothly like GMRES but with constant low memory usage like CG. However, it can “break down” (divide by zero) in rare unlucky cases, unlike GMRES which is robust.
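The BiCGStab pseudo-code in runnable form (ours; it adds one guard against the $0/0$ in $\omega$ when $s$ is already zero, illustrating the breakdown issue mentioned above):

```python
import numpy as np

def bicgstab(A, b, x0, tol=1e-10, max_iter=1000):
    """BiCGStab mirroring the pseudo-code. Production codes guard
    every division here against (near-)zero inner products."""
    x = x0.copy()
    r = b - A @ x
    r_hat = r.copy()                    # shadow residual (arbitrary fixed vector)
    p = r.copy()
    rho = r_hat @ r
    for _ in range(max_iter):
        v = A @ p
        alpha = rho / (r_hat @ v)
        s = r - alpha * v
        if np.linalg.norm(s) < tol:     # already converged; avoid 0/0 in omega
            x += alpha * p
            break
        t = A @ s
        omega = (t @ s) / (t @ t)       # local GMRES(1)-style minimization
        x += alpha * p + omega * s
        r = s - omega * t
        if np.linalg.norm(r) < tol:
            break
        rho_new = r_hat @ r
        beta = (rho_new / rho) * (alpha / omega)
        p = r + beta * (p - omega * v)
        rho = rho_new
    return x

rng = np.random.default_rng(6)
N = 80
A = rng.standard_normal((N, N)) + N * np.eye(N)   # non-symmetric, dominant
b = rng.standard_normal(N)
x = bicgstab(A, b, np.zeros(N))
assert np.allclose(A @ x, b)
```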