Grammar-first generalized linear modeling in JAX. JIT-compiled end-to-end and differentiable through the fitted parameters via the implicit function theorem.
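The claim that fits are differentiable through the fitted parameters can be illustrated with a minimal pure-JAX sketch. It does not use glmax: `poisson_score`, `newton_solve`, and `fit` are hypothetical names, and the fixed-iteration Newton solver is an assumption made for illustration. At the maximum-likelihood estimate the score `g(beta, y) = X^T (y - exp(X beta))` vanishes, so the implicit function theorem gives `d beta / d y = (X^T W X)^{-1} X^T` with `W = diag(mu)`. The backward pass then needs one linear solve at the solution instead of differentiating every solver iteration.

```python
import jax
import jax.numpy as jnp


def poisson_score(beta, X, y):
    # Gradient of the Poisson log-likelihood under the canonical log link:
    # g(beta; X, y) = X^T (y - exp(X beta)). The MLE solves g = 0.
    return X.T @ (y - jnp.exp(X @ beta))


def newton_solve(X, y, num_iters=25):
    # Newton's method (equivalent to IRLS / Fisher scoring for the log link).
    def step(beta, _):
        mu = jnp.exp(X @ beta)
        hess = X.T @ (mu[:, None] * X)  # X^T W X with W = diag(mu)
        return beta + jnp.linalg.solve(hess, poisson_score(beta, X, y)), None

    beta, _ = jax.lax.scan(step, jnp.zeros(X.shape[1]), None, length=num_iters)
    return beta


@jax.custom_vjp
def fit(X, y):
    return newton_solve(X, y)


def fit_fwd(X, y):
    beta = newton_solve(X, y)
    return beta, (beta, X, y)  # save only the solution, not the iterations


def fit_bwd(res, beta_bar):
    beta, X, y = res
    # IFT: d beta / d theta = H^{-1} dg/dtheta, where H = -dg/dbeta = X^T W X.
    mu = jnp.exp(X @ beta)
    H = X.T @ (mu[:, None] * X)
    u = jnp.linalg.solve(H, beta_bar)  # H is symmetric, so H^{-T} = H^{-1}
    # Pull u back through g with respect to (X, y): one VJP, no unrolling.
    _, g_vjp = jax.vjp(lambda X_, y_: poisson_score(beta, X_, y_), X, y)
    return g_vjp(u)  # cotangents for (X, y)


fit.defvjp(fit_fwd, fit_bwd)
```

For example, `jax.grad(lambda y: jnp.sum(fit(X, y) ** 2))(y)` differentiates a scalar function of the coefficients with respect to the response. Because only the converged `beta` is saved for the backward pass, memory does not grow with the number of solver iterations.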
A minimal end-to-end workflow on a toy Poisson problem:

```python
import jax.numpy as jnp
import glmax

# Design matrix (intercept column plus one covariate) and count response
X = jnp.array([[1.0, 0.5], [1.0, -0.3], [1.0, 1.2], [1.0, -0.8]])
y = jnp.array([2.0, 1.0, 4.0, 1.0])

fitted = glmax.fit(glmax.Poisson(), X, y)              # fit a Poisson GLM
pred = glmax.predict(fitted.family, fitted.params, X)  # response-scale means
result = glmax.infer(fitted)                           # inference on the fit
diag = glmax.check(fitted)                             # model diagnostics
```

Four verbs cover the full modeling workflow: `fit`, `predict`, `infer`, and `check`. Each takes explicit inputs and returns an explicit result; no hidden state is threaded between calls.
`predict` returns fitted means on the response scale, matching `fitted.mu`.
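The explicit-inputs, explicit-result style and the response-scale convention for `predict` can be sketched in plain JAX. `PoissonFit`, `fit_poisson`, and `predict_mean` are hypothetical stand-ins, not glmax's types or signatures, and the fixed-iteration Newton solver is an assumption. The result is an immutable `NamedTuple`, which JAX treats as a pytree, so it can flow in and out of `jax.jit` without hidden state; predictions apply the inverse log link to the linear predictor, so they land on the response scale.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class PoissonFit(NamedTuple):
    # Hypothetical immutable result; NamedTuples are JAX pytrees.
    params: jax.Array  # coefficients on the link (log) scale
    eta: jax.Array     # linear predictor, X @ params
    mu: jax.Array      # fitted means on the response scale, exp(eta)


@jax.jit
def fit_poisson(X, y):
    # Everything the fit needs comes in as arguments; everything it
    # produces goes out in the returned PoissonFit.
    def newton_step(beta, _):
        mu = jnp.exp(X @ beta)
        hess = X.T @ (mu[:, None] * X)
        return beta + jnp.linalg.solve(hess, X.T @ (y - mu)), None

    beta, _ = jax.lax.scan(newton_step, jnp.zeros(X.shape[1]), None, length=25)
    eta = X @ beta
    return PoissonFit(params=beta, eta=eta, mu=jnp.exp(eta))


@jax.jit
def predict_mean(params, X):
    # Response scale: apply the inverse of the log link to the linear predictor.
    return jnp.exp(X @ params)
```

A useful sanity check: with an intercept column and the canonical log link, the fitted means sum to the observed total, so `fitted.mu.sum()` should be close to `y.sum()`.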
See the docs for the full API reference and guides.
Install the latest version from GitHub:

```bash
pip install git+https://github.com/mancusolab/glmax.git
```

We benchmarked glmax against statsmodels on Poisson regression. Each timing uses 10 steady-state runs after JIT warm-up. Because statsmodels does not natively support GPU or TPU, all statsmodels runtimes are measured on the CPU. Colab provides a different host CPU depending on the selected runtime, so we report the statsmodels CPU runtime separately for each runtime environment.
| n (observations) | p (covariates) | statsmodels (ms) | glmax (ms) | speedup (statsmodels / glmax) | Colab runtime |
|---|---|---|---|---|---|
| 500 | 10 | 4.32 | 0.92 | 4.7× | CPU |
| 2,000 | 20 | 277.76 | 4.14 | 67.1× | CPU |
| 10,000 | 50 | 1428.76 | 42.77 | 33.4× | CPU |
| 500 | 10 | 2.97 | 3.00 | 1.0× | T4 GPU |
| 2,000 | 20 | 13.94 | 4.38 | 3.2× | T4 GPU |
| 10,000 | 50 | 212.70 | 17.94 | 11.9× | T4 GPU |
| 500 | 10 | 1.90 | 2.89 | 0.7× | v5e-1 TPU |
| 2,000 | 20 | 8.46 | 8.80 | 1.0× | v5e-1 TPU |
| 10,000 | 50 | 1220.66 | 25.65 | 47.6× | v5e-1 TPU |
See `examples/benchmark_colab.ipynb` for the full benchmark notebook.
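The timing protocol (JIT warm-up, then 10 steady-state runs) can be sketched as follows. This is not the notebook's code: `simulate_poisson`, `fit_poisson`, `time_steady_state`, and `speedup` are hypothetical, the synthetic design and the fixed-iteration Newton solver are assumptions, and summarizing the runs by their median is also an assumption. The detail that matters for JAX is `jax.block_until_ready`: dispatch is asynchronous, so stopping the clock before the result is ready would undercount.

```python
import statistics
import time

import jax
import jax.numpy as jnp


def simulate_poisson(key, n, p):
    # Assumed synthetic design: an intercept plus p - 1 Gaussian covariates.
    k_x, k_y = jax.random.split(key)
    covariates = 0.3 * jax.random.normal(k_x, (n, p - 1))
    X = jnp.concatenate([jnp.ones((n, 1)), covariates], axis=1)
    beta_true = jnp.full((p,), 0.2)
    y = jax.random.poisson(k_y, jnp.exp(X @ beta_true)).astype(jnp.float32)
    return X, y


@jax.jit
def fit_poisson(X, y):
    # Fixed-iteration Newton solver for the Poisson log-link MLE.
    def step(beta, _):
        mu = jnp.exp(X @ beta)
        hess = X.T @ (mu[:, None] * X)
        return beta + jnp.linalg.solve(hess, X.T @ (y - mu)), None

    beta, _ = jax.lax.scan(step, jnp.zeros(X.shape[1]), None, length=25)
    return beta


def time_steady_state(fn, *args, n_runs=10):
    # Warm-up call: triggers tracing and XLA compilation; excluded from timing.
    jax.block_until_ready(fn(*args))
    times_ms = []
    for _ in range(n_runs):
        start = time.perf_counter()
        # Wait for the asynchronously dispatched result before stopping the clock.
        jax.block_until_ready(fn(*args))
        times_ms.append((time.perf_counter() - start) * 1e3)
    return times_ms


def speedup(baseline_ms, candidate_ms):
    # A ratio above 1 means the candidate (glmax in the table) is faster.
    return baseline_ms / candidate_ms


if __name__ == "__main__":
    X, y = simulate_poisson(jax.random.PRNGKey(0), n=2_000, p=20)
    runs = time_steady_state(fit_poisson, X, y)
    print(f"median of {len(runs)} runs: {statistics.median(runs):.2f} ms")
```

The speedup column is the ratio of the two timings; for example, 277.76 ms / 4.14 ms ≈ 67.1× for n = 2,000 and p = 20 on CPU.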
Run the test suite with:

```bash
pytest -p no:capture tests
```

glmax was developed by members of the Mancuso Lab with assistance from
Claude Code and Codex, following the practices described in the
scientific-software-playbook.
glmax is released under the MIT License.