Awesome JAX 

JAX brings automatic
differentiation and the XLA
compiler together through a NumPy-like API for high-performance
machine learning research on accelerators like GPUs and TPUs.
This is a curated list of awesome JAX libraries, projects, and other
resources. Contributions are welcome!
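Here is a minimal sketch of those ingredients using only core JAX: `jax.numpy` as the NumPy-like API, `jax.grad` for automatic differentiation, `jax.jit` for XLA compilation, and `jax.vmap` for vectorization. The toy least-squares loss and data are illustrative assumptions and are not taken from any library listed below.

```python
import jax
import jax.numpy as jnp

# A toy least-squares loss written with the NumPy-like jax.numpy API.
def loss(w, x, y):
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

# jax.grad differentiates with respect to the first argument (w) by default;
# jax.jit traces the resulting function on first call and compiles it with XLA.
grad_loss = jax.jit(jax.grad(loss))

x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 2.0, 3.0])
w = jnp.zeros(2)

g = grad_loss(w, x, y)

# jax.vmap maps a function over the leading axis without a Python loop.
row_norms = jax.vmap(jnp.linalg.norm)(x)

# make_jaxpr prints the intermediate representation JAX records while tracing.
print(jax.make_jaxpr(loss)(w, x, y))
```

`jax.grad`, `jax.jit` and `jax.vmap` each return ordinary Python functions, so they compose freely. Many of the libraries below build on that composability.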
Libraries
- Neural Network Libraries
- Flax - Centered on
flexibility and clarity.

- Flax NNX - An evolution of Flax by the same team.

- Haiku - Focused
on simplicity, created by the authors of Sonnet at DeepMind.

- Objax - Has an object-oriented
design similar to PyTorch.

- Elegy - A high-level
API for deep learning in JAX. Supports Flax, Haiku, and Optax.

- Trax - “Batteries
included” deep learning library focused on providing solutions for
common workloads.

- Jraph - Lightweight
graph neural network library.

- Neural
Tangents - High-level API for specifying neural networks of both
finite and infinite width.

- HuggingFace
Transformers - Ecosystem of pretrained Transformers for a wide range
of natural language tasks (Flax).

- Equinox -
Callable PyTrees and filtered JIT/grad transformations => neural
networks in JAX.

- Scenic - A
JAX library for computer vision research and beyond.

- Penzai -
Prioritizes legibility, visualization, and easy editing of neural
network models with composable tools and a simple mental model.

- Levanter -
Legible, Scalable, Reproducible Foundation Models with Named Tensors and
JAX.

- EasyLM - LLMs
made easy: Pre-training, finetuning, evaluating and serving LLMs in
JAX/Flax.

- NumPyro -
Probabilistic programming based on the Pyro library.

- Chex - Utilities to
write and test reliable JAX code.

- Optax - Gradient
processing and optimization library.

- RLax - Library for
implementing reinforcement learning agents.

- JAX, M.D. -
Accelerated, differentiable molecular dynamics.

- Coax - Turn RL papers
into code, the easy way.

- Distrax -
Reimplementation of TensorFlow Probability, containing probability
distributions and bijectors.

- cvxpylayers -
Construct differentiable convex optimization layers.

- TensorLy - Tensor
learning made simple.

- NetKet - Machine
Learning toolbox for Quantum Physics.

- Fortuna - AWS
library for Uncertainty Quantification in Deep Learning.

- BlackJAX -
Library of samplers for JAX.
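Several libraries above, including Optax, Flax, Haiku, and Equinox, represent model parameters as PyTrees. A PyTree is a nested container of arrays that JAX transformations can traverse. The sketch below is not Optax's API. It is a hand-written SGD-with-momentum step that uses only core `jax.tree_util` functions, to show the kind of leaf-wise gradient processing these optimizers perform. The parameter names, the linear model, the synthetic data, and the hyperparameters are all illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Parameters as a PyTree: a dict of arrays (key names are illustrative).
params = {"w": jnp.array([0.0, 0.0]), "b": jnp.array(0.0)}

def loss(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

def sgd_momentum_step(params, velocity, grads, lr=0.05, beta=0.9):
    # Update the velocity leaf by leaf, then apply it to each parameter leaf.
    velocity = jax.tree_util.tree_map(lambda v, g: beta * v + g, velocity, grads)
    params = jax.tree_util.tree_map(lambda p, v: p - lr * v, params, velocity)
    return params, velocity

# Synthetic data generated from y = 2*x0 - x1 + 0.5 (assumed for the demo).
x = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 1.0]])
y = x @ jnp.array([2.0, -1.0]) + 0.5

# Velocity has the same PyTree structure as params, initialised to zeros.
velocity = jax.tree_util.tree_map(jnp.zeros_like, params)

@jax.jit
def train_step(params, velocity):
    grads = jax.grad(loss)(params, x, y)  # grads is a dict shaped like params
    return sgd_momentum_step(params, velocity, grads)

for _ in range(500):
    params, velocity = train_step(params, velocity)
```

Optax packages this idea as composable gradient transformations. For this particular update, `optax.sgd` with its `momentum` argument is the closest built-in.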

New Libraries
This section contains libraries that are well-made and useful, but
have not necessarily been battle-tested by a large user base yet.
- Neural Network Libraries
- FedJAX - Federated
learning in JAX, built on Optax and Haiku.

- Equivariant
MLP - Construct equivariant neural network layers.

- jax-resnet -
Implementations and checkpoints for ResNet variants in Flax.

- Parallax - Immutable
Torch Modules for JAX.

- Nonlinear Optimization
- Optimistix -
Root finding, minimisation, fixed points, and least squares.

- JAXopt - Hardware-accelerated
(GPU/TPU), batchable, and differentiable optimizers in JAX.

- jax-unirep -
Library implementing the UniRep
model for protein machine learning applications.

- flowjax -
Distributions and normalizing flows built as Equinox modules.

- jax-flows -
Normalizing flows in JAX.

- sklearn-jax-kernels -
scikit-learn kernel matrices using JAX.

- jax-cosmo -
Differentiable cosmology library.

- efax - Exponential
Families in JAX.

- mpi4jax -
Combine MPI operations with your JAX code on CPUs and GPUs.

- imax - Image
augmentations and transformations.

- FlaxVision -
Flax version of TorchVision.

- Oryx - Probabilistic programming language based on program
transformations.

- Optimal Transport
Tools - Toolbox that bundles utilities to solve optimal transport
problems.

- delta PV - A
photovoltaic simulator with automatic differentiation.

- jaxlie - Lie theory
library for rigid body transformations and optimization.

- BRAX - Differentiable
physics engine to simulate environments along with learning algorithms
to train agents for these environments.

- flaxmodels -
Pretrained models for JAX/Flax.

- CR.Sparse -
XLA-accelerated algorithms for sparse representations and compressive
sensing.

- exojax -
Automatically differentiable spectrum modeling of exoplanets/brown dwarfs,
compatible with JAX.

- PIX - PIX is an
image processing library in JAX, for JAX.

- bayex - Bayesian
Optimization powered by JAX.

- JaxDF - Framework for
differentiable simulators with arbitrary discretizations.

- tree-math -
Convert functions that operate on arrays into functions that operate on
PyTrees.

- jax-models -
Implementations of research papers originally without code or code
written with frameworks other than JAX.

- PGMax - A
framework for building discrete Probabilistic Graphical Models (PGMs)
and running inference on them via JAX.

- EvoJAX -
Hardware-Accelerated Neuroevolution

- evosax -
JAX-Based Evolution Strategies

- SymJAX - Symbolic
CPU/GPU/TPU programming.

- mcx - Express &
compile probabilistic programs for performant inference.

- Einshape -
DSL-based reshaping library for JAX and other frameworks.

- ALX -
Open-source library for distributed matrix factorization using
Alternating Least Squares. More information is in the paper ALX: Large
Scale Matrix Factorization on TPUs.

- Diffrax -
Numerical differential equation solvers in JAX.

- tinygp - The
tiniest of Gaussian process libraries in JAX.

- gymnax -
Reinforcement Learning Environments with the well-known gym API.

- Mctx - Monte Carlo
tree search algorithms in native JAX.

- KFAC-JAX - Second
Order Optimization with Approximate Curvature for NNs.

- TF2JAX - Convert
TensorFlow functions/graphs to JAX functions.

- jwave - A library for
differentiable acoustic simulations

- GPJax - Gaussian
processes in JAX.

- Jumanji - A
Suite of Industry-Driven Hardware-Accelerated RL Environments written in
JAX.

- Eqxvision -
Equinox version of Torchvision.

- JAXFit -
Accelerated curve fitting library for nonlinear least-squares problems
(see arXiv paper).

- econpizza - Solve
macroeconomic models with heterogeneous agents using JAX.

- SPU - A
domain-specific compiler and runtime suite to run JAX code with
MPC (Secure Multi-Party Computation).

- jax-tqdm -
Add a tqdm progress bar to JAX scans and loops.

- safejax -
Serialize JAX, Flax, Haiku, or Objax model params with 🤗 safetensors.

- Kernex -
Differentiable stencil decorators in JAX.

- MaxText - A simple,
performant and scalable JAX LLM written in pure Python/JAX and targeting
Google Cloud TPUs.

- Pax - A JAX-based
machine learning framework for training large-scale models.

- Praxis - The layer
library for Pax with a goal to be usable by other JAX-based ML projects.

- purejaxrl -
Vectorisable, end-to-end RL algorithms in JAX.

- Lorax -
Automatically apply LoRA to JAX models (Flax, Haiku, etc.).

- SCICO - Scientific
computational imaging in JAX.

- Spyx - Spiking Neural
Networks in JAX for machine learning on neuromorphic hardware.

- Brain Dynamics Programming Ecosystem
- BrainPy - Brain
Dynamics Programming in Python.

- brainunit -
Physical units and unit-aware mathematical system in JAX.

- dendritex -
Dendritic Modeling in JAX.

- brainstate -
State-based Transformation System for Program Compilation and
Augmentation.

- braintaichi -
Leveraging Taichi Lang to customize brain dynamics operators.

- OTT-JAX - Optimal
transport tools in JAX.

- QDax -
Quality Diversity optimization in JAX.

- JAX Toolbox -
Nightly CI and optimized examples for JAX on NVIDIA GPUs using libraries
such as T5x, Paxml, and Transformer Engine.

- Pgx - Vectorized board
game environments for RL with an AlphaZero example.

- EasyDeL - EasyDeL
🔮 is an open-source library that makes training faster and more
optimized, with options for training and serving models (Llama, MPT,
Mixtral, Falcon, etc.) in JAX.

- XLB - A Differentiable
Massively Parallel Lattice Boltzmann Library in Python for Physics-Based
Machine Learning.

- dynamiqs -
High-performance and differentiable simulations of quantum systems with
JAX.

- foragax -
Agent-based modelling framework in JAX.

- tmmax - Vectorized
calculation of optical properties in thin-film structures using JAX.
A Swiss Army knife for thin-film optics research.

- Coreax - Algorithms for
finding coresets to compress large datasets while retaining their
statistical properties.

- NAVIX - A
reimplementation of MiniGrid, a Reinforcement Learning environment, in
JAX
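EvoJAX and evosax, listed above, apply JAX to evolution strategies. The sketch below gives a rough idea of why vectorization helps there. It is not the API of either library. It estimates a search gradient from antithetic Gaussian perturbations and scores the whole population in a single `jax.vmap` call. The toy objective, the population size, the noise scale `sigma`, and the learning rate are all assumptions chosen for the example.

```python
import jax
import jax.numpy as jnp

TARGET = jnp.array([1.0, -2.0])  # optimum of the toy objective (assumption)

def fitness(theta):
    # Toy objective to maximize: negative squared distance to TARGET.
    return -jnp.sum((theta - TARGET) ** 2)

@jax.jit
def es_step(theta, key):
    pop_half, sigma, lr = 32, 0.1, 0.05   # illustrative hyperparameters
    eps = jax.random.normal(key, (pop_half, theta.shape[0]))
    eps = jnp.concatenate([eps, -eps])    # antithetic (mirrored) perturbations
    # Evaluate the whole population in one vectorized call.
    scores = jax.vmap(fitness)(theta + sigma * eps)
    scores = scores - scores.mean()       # baseline subtraction
    # Monte Carlo estimate of the gradient of the Gaussian-smoothed fitness.
    grad_est = (eps * scores[:, None]).mean(axis=0) / sigma
    return theta + lr * grad_est          # gradient ascent on fitness

theta = jnp.zeros(2)
key = jax.random.PRNGKey(0)
for _ in range(300):
    key, subkey = jax.random.split(key)
    theta = es_step(theta, subkey)
```

With mirrored perturbations, the baseline cancels exactly within each pair, which reduces the variance of the estimate. For this quadratic objective, the estimate points toward the optimum on average.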

Models and Projects
JAX
Flax
Haiku
Trax
- Reformer -
Implementation of the Reformer (efficient transformer)
architecture.
NumPyro
Videos
- NeurIPS 2020:
JAX Ecosystem Meetup - JAX, its use at DeepMind, and discussion
between engineers, scientists, and JAX core team.
- Introduction to JAX -
Simple neural network from scratch in JAX.
- JAX: Accelerated Machine
Learning Research | SciPy 2020 | VanderPlas - JAX’s core design, how
it’s powering new research, and how you can start using it.
- Bayesian Programming with JAX
+ NumPyro — Andy Kitchen - Introduction to Bayesian modelling using
NumPyro.
- JAX:
Accelerated machine-learning research via composable function
transformations in Python | NeurIPS 2019 | Skye Wanderman-Milne -
JAX intro presentation in Program
Transformations for Machine Learning workshop.
- JAX
on Cloud TPUs | NeurIPS 2020 | Skye Wanderman-Milne and James
Bradbury - Presentation of TPU host access with demo.
- Deep
Implicit Layers - Neural ODEs, Deep Equilibrium Models, and Beyond |
NeurIPS 2020 - Tutorial created by Zico Kolter, David Duvenaud, and
Matt Johnson with Colab notebooks available in Deep Implicit
Layers.
- Solving y=mx+b
with Jax on a TPU Pod slice - Mat Kelcey - A four part YouTube
tutorial series with Colab notebooks that starts with JAX fundamentals
and moves up to training with a data parallel approach on a v3-32 TPU
Pod slice.
- JAX,
Flax & Transformers 🤗 - 3 days of talks around JAX / Flax,
Transformers, large-scale language modeling and other great topics.
Papers
This section contains papers focused on JAX (e.g. JAX-based library
whitepapers, research on JAX, etc.). Papers implemented in JAX are listed
in the Models and Projects section.
- Compiling
machine learning programs via high-level tracing. Roy Frostig,
Matthew James Johnson, Chris Leary. MLSys 2018. - White
paper describing an early version of JAX, detailing how computation is
traced and compiled.
- JAX, M.D.: A
Framework for Differentiable Physics. Samuel S. Schoenholz,
Ekin D. Cubuk. NeurIPS 2020. - Introduces JAX, M.D., a
differentiable physics library which includes simulation environments,
interaction potentials, neural networks, and more.
- Enabling Fast
Differentially Private SGD via Just-in-Time Compilation and
Vectorization. Pranav Subramani, Nicholas Vadivelu, Gautam
Kamath. arXiv 2020. - Uses JAX’s JIT and vmap to achieve
faster differentially private SGD than existing libraries.
- XLB: A
Differentiable Massively Parallel Lattice Boltzmann Library in
Python. Mohammadmehdi Ataei, Hesam Salehipour. arXiv
2023. - White paper describing the XLB library: benchmarks,
validations, and more details about the library.
Tutorials and Blog Posts
- Using
JAX to accelerate our research by David Budden and Matteo Hessel -
Describes the state of JAX and the JAX ecosystem at DeepMind.
- Getting
started with JAX (MLPs, CNNs & RNNs) by Robert Lange - Neural
network building blocks from scratch with the basic JAX operators.
- Learn JAX: From
Linear Regression to Neural Networks by Rito Ghosh - A gentle
introduction to JAX, using it to implement linear regression, logistic
regression, and neural network models, and applying them to solve
real-world problems.
- Tutorial:
image classification with JAX and Flax Linen by 8bitmp3 - Learn how
to create a simple convolutional network with the Linen API by Flax and
train it to recognize handwritten digits.
- Plugging
Into JAX by Nick Doiron - Compares Flax, Haiku, and Objax on the
Kaggle flower classification challenge.
- Meta-Learning in 50
Lines of JAX by Eric Jang - Introduction to both JAX and
Meta-Learning.
- Normalizing
Flows in 100 Lines of JAX by Eric Jang - Concise implementation of
RealNVP.
- Differentiable
Path Tracing on the GPU/TPU by Eric Jang - Tutorial on implementing
path tracing.
- Ensemble networks by
Mat Kelcey - Ensemble nets are a method of representing an ensemble
of models as one single logical model.
- Out of
distribution (OOD) detection by Mat Kelcey - Implements different
methods for OOD detection.
- Understanding Autodiff with
JAX by Srihari Radhakrishna - Understand how autodiff works using
JAX.
- From PyTorch to JAX:
towards neural net frameworks that purify stateful code by Sabrina J.
Mielke - Showcases how to go from a PyTorch-like style of coding to
a more functional style of coding.
- Extending JAX with
custom C++ and CUDA code by Dan Foreman-Mackey - Tutorial
demonstrating the infrastructure required to provide custom ops in
JAX.
- Evolving
Neural Networks in JAX by Robert Tjarko Lange - Explores how JAX can
power the next generation of scalable neuroevolution algorithms.
- Exploring
hyperparameter meta-loss landscapes with JAX by Luke Metz -
Demonstrates how to use JAX to perform inner-loss optimization with SGD
and Momentum, outer-loss optimization with gradients, and outer-loss
optimization using evolutionary strategies.
- Deterministic
ADVI in JAX by Martin Ingram - Walkthrough of implementing
automatic differentiation variational inference (ADVI) easily and
cleanly with JAX.
- Evolved
channel selection by Mat Kelcey - Trains a classification model
robust to different combinations of input channels at different
resolutions, then uses a genetic algorithm to decide the best
combination for a particular loss.
- Introduction
to JAX by Kevin Murphy - Colab that introduces various aspects of
the language and applies them to simple ML problems.
- Writing
an MCMC sampler in JAX by Jeremie Coullon - Tutorial on the
different ways to write an MCMC sampler in JAX along with speed
benchmarks.
- How
to add a progress bar to JAX scans and loops by Jeremie Coullon -
Tutorial on how to add a progress bar to compiled loops in JAX using the
host_callback module.
- Get
started with JAX by Aleksa Gordić - A series of notebooks and videos
going from zero JAX knowledge to building neural networks in Haiku.
- Writing
a Training Loop in JAX + FLAX by Saurav Maheshkar and Soumik Rakshit
- A tutorial on writing a simple end-to-end training and evaluation
pipeline in JAX, Flax and Optax.
- Implementing
NeRF in JAX by Soumik Rakshit and Saurav Maheshkar - A tutorial on
3D volumetric rendering of scenes represented by Neural Radiance Fields
in JAX.
- Deep
Learning tutorials with JAX+Flax by Phillip Lippe - A series of
notebooks explaining various deep learning concepts, from basics
(e.g. intro to JAX/Flax, activation functions) to recent advances
(e.g., Vision Transformers, SimCLR), with translations to PyTorch.
- Achieving 4000x
Speedups with PureJaxRL - A blog post on how JAX can massively
speed up RL training through vectorisation.
- Simple
PDE solver + Constrained Optimization with JAX by Philip Mocz - A
simple example of solving the advection-diffusion equations with JAX and
using it in a constrained optimization problem to find initial
conditions that yield a desired result.
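Several tutorials above, such as the MCMC sampler and progress-bar posts, build loops with `jax.lax.scan` so that the whole loop compiles as a single XLA program. The sketch below is not the code from any of those posts. It implements random-walk Metropolis for an assumed standard normal target, with an arbitrary step size and seed.

```python
import jax
import jax.numpy as jnp

def log_prob(x):
    # Unnormalized log density of a standard normal (the assumed target).
    return -0.5 * x ** 2

def rw_metropolis(key, x0, num_steps, step_size=1.0):
    def step(x, key):
        key_prop, key_accept = jax.random.split(key)
        proposal = x + step_size * jax.random.normal(key_prop)
        log_ratio = log_prob(proposal) - log_prob(x)
        # Accept with probability min(1, exp(log_ratio)).
        accept = jnp.log(jax.random.uniform(key_accept)) < log_ratio
        x_new = jnp.where(accept, proposal, x)
        return x_new, x_new  # (new carry, value stored for this step)

    # One PRNG key per step; scan iterates over the leading axis.
    keys = jax.random.split(key, num_steps)
    _, samples = jax.lax.scan(step, x0, keys)
    return samples

# num_steps fixes an array shape, so it is marked static for jit.
sample_fn = jax.jit(rw_metropolis, static_argnums=2)
samples = sample_fn(jax.random.PRNGKey(42), jnp.array(0.0), 20_000)
```

`scan` threads the current state through each step as the carry and stacks the per-step outputs. This avoids a Python-level loop that would otherwise be unrolled during tracing.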
Books
- JAX in
Action - A hands-on guide to using JAX for deep learning and other
mathematically-intensive applications.
Contributing
Contributions welcome! Read the contribution guidelines first.