<!--lint ignore double-link-->
<h1 id="awesome-jax-awesome">Awesome JAX <a
href="https://awesome.re"><img src="https://awesome.re/badge.svg"
alt="Awesome" /></a><a
href="https://github.com/google/jax"><img src="https://raw.githubusercontent.com/google/jax/master/images/jax_logo_250px.png" alt="JAX Logo" align="right" height="100"></a></h1>
<!--lint ignore double-link-->
<p><a href="https://github.com/google/jax">JAX</a> brings automatic
differentiation and the <a href="https://www.tensorflow.org/xla">XLA
compiler</a> together through a <a
href="https://numpy.org/">NumPy</a>-like API for high performance
machine learning research on accelerators like GPUs and TPUs.
<!--lint enable double-link--></p>
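<p>A minimal sketch of that combination in practice:
<code>jax.grad</code> differentiates an ordinary NumPy-style function
and <code>jax.jit</code> compiles it with XLA (the toy loss below is
illustrative only):</p>
<pre><code>import jax
import jax.numpy as jnp

def loss(w, x, y):
    # ordinary NumPy-style code, written against jax.numpy
    return jnp.mean((x @ w - y) ** 2)

# grad differentiates; jit compiles the result with XLA
grad_fn = jax.jit(jax.grad(loss))

w = jnp.zeros(3)
x, y = jnp.ones((10, 3)), jnp.ones(10)
g = grad_fn(w, x, y)  # gradient of the loss w.r.t. w</code></pre>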
<p>This is a curated list of awesome JAX libraries, projects, and other
resources. Contributions are welcome!</p>
<h2 id="contents">Contents</h2>
<ul>
<li><a href="#libraries">Libraries</a></li>
<li><a href="#models-and-projects">Models and Projects</a></li>
<li><a href="#videos">Videos</a></li>
<li><a href="#papers">Papers</a></li>
<li><a href="#tutorials-and-blog-posts">Tutorials and Blog
Posts</a></li>
<li><a href="#books">Books</a></li>
<li><a href="#community">Community</a></li>
</ul>
<p><a name="libraries" /></p>
<h2 id="libraries">Libraries</h2>
<ul>
<li>Neural Network Libraries
<ul>
<li><a href="https://github.com/google/flax">Flax</a> - Centered on
flexibility and clarity.
<img src="https://img.shields.io/github/stars/google/flax?style=social" align="center"></li>
<li><a href="https://github.com/google/flax/tree/main/flax/nnx">Flax
NNX</a> - An evolution on Flax by the same team
<img src="https://img.shields.io/github/stars/google/flax?style=social" align="center"></li>
<li><a href="https://github.com/deepmind/dm-haiku">Haiku</a> - Focused
on simplicity, created by the authors of Sonnet at DeepMind.
<img src="https://img.shields.io/github/stars/deepmind/dm-haiku?style=social" align="center"></li>
<li><a href="https://github.com/google/objax">Objax</a> - Has an object
oriented design similar to PyTorch.
<img src="https://img.shields.io/github/stars/google/objax?style=social" align="center"></li>
<li><a href="https://poets-ai.github.io/elegy/">Elegy</a> - A High Level
API for Deep Learning in JAX. Supports Flax, Haiku, and Optax.
<img src="https://img.shields.io/github/stars/poets-ai/elegy?style=social" align="center"></li>
<li><a href="https://github.com/google/trax">Trax</a> - “Batteries
included” deep learning library focused on providing solutions for
common workloads.
<img src="https://img.shields.io/github/stars/google/trax?style=social" align="center"></li>
<li><a href="https://github.com/deepmind/jraph">Jraph</a> - Lightweight
graph neural network library.
<img src="https://img.shields.io/github/stars/deepmind/jraph?style=social" align="center"></li>
<li><a href="https://github.com/google/neural-tangents">Neural
Tangents</a> - High-level API for specifying neural networks of both
finite and <em>infinite</em> width.
<img src="https://img.shields.io/github/stars/google/neural-tangents?style=social" align="center"></li>
<li><a href="https://github.com/huggingface/transformers">HuggingFace
Transformers</a> - Ecosystem of pretrained Transformers for a wide range
of natural language tasks (Flax).
<img src="https://img.shields.io/github/stars/huggingface/transformers?style=social" align="center"></li>
<li><a href="https://github.com/patrick-kidger/equinox">Equinox</a> -
Callable PyTrees and filtered JIT/grad transformations =&gt; neural
networks in JAX.
<img src="https://img.shields.io/github/stars/patrick-kidger/equinox?style=social" align="center"></li>
<li><a href="https://github.com/google-research/scenic">Scenic</a> - A
Jax Library for Computer Vision Research and Beyond.
<img src="https://img.shields.io/github/stars/google-research/scenic?style=social" align="center"></li>
<li><a href="https://github.com/google-deepmind/penzai">Penzai</a> -
Prioritizes legibility, visualization, and easy editing of neural
network models with composable tools and a simple mental model.
<img src="https://img.shields.io/github/stars/google-deepmind/penzai?style=social" align="center"></li>
</ul></li>
<li><a href="https://github.com/stanford-crfm/levanter">Levanter</a> -
Legible, Scalable, Reproducible Foundation Models with Named Tensors and
JAX.
<img src="https://img.shields.io/github/stars/stanford-crfm/levanter?style=social" align="center"></li>
<li><a href="https://github.com/young-geng/EasyLM">EasyLM</a> - LLMs
made easy: Pre-training, finetuning, evaluating and serving LLMs in
JAX/Flax.
<img src="https://img.shields.io/github/stars/young-geng/EasyLM?style=social" align="center"></li>
<li><a href="https://github.com/pyro-ppl/numpyro">NumPyro</a> -
Probabilistic programming based on the Pyro library.
<img src="https://img.shields.io/github/stars/pyro-ppl/numpyro?style=social" align="center"></li>
<li><a href="https://github.com/deepmind/chex">Chex</a> - Utilities to
write and test reliable JAX code.
<img src="https://img.shields.io/github/stars/deepmind/chex?style=social" align="center"></li>
<li><a href="https://github.com/deepmind/optax">Optax</a> - Gradient
processing and optimization library.
<img src="https://img.shields.io/github/stars/deepmind/optax?style=social" align="center"></li>
<li><a href="https://github.com/deepmind/rlax">RLax</a> - Library for
implementing reinforcement learning agents.
<img src="https://img.shields.io/github/stars/deepmind/rlax?style=social" align="center"></li>
<li><a href="https://github.com/google/jax-md">JAX, M.D.</a> -
Accelerated, differential molecular dynamics.
<img src="https://img.shields.io/github/stars/google/jax-md?style=social" align="center"></li>
<li><a href="https://github.com/coax-dev/coax">Coax</a> - Turn RL papers
into code, the easy way.
<img src="https://img.shields.io/github/stars/coax-dev/coax?style=social" align="center"></li>
<li><a href="https://github.com/deepmind/distrax">Distrax</a> -
Reimplementation of TensorFlow Probability, containing probability
distributions and bijectors.
<img src="https://img.shields.io/github/stars/deepmind/distrax?style=social" align="center"></li>
<li><a href="https://github.com/cvxgrp/cvxpylayers">cvxpylayers</a> -
Construct differentiable convex optimization layers.
<img src="https://img.shields.io/github/stars/cvxgrp/cvxpylayers?style=social" align="center"></li>
<li><a href="https://github.com/tensorly/tensorly">TensorLy</a> - Tensor
learning made simple.
<img src="https://img.shields.io/github/stars/tensorly/tensorly?style=social" align="center"></li>
<li><a href="https://github.com/netket/netket">NetKet</a> - Machine
Learning toolbox for Quantum Physics.
<img src="https://img.shields.io/github/stars/netket/netket?style=social" align="center"></li>
<li><a href="https://github.com/awslabs/fortuna">Fortuna</a> - AWS
library for Uncertainty Quantification in Deep Learning.
<img src="https://img.shields.io/github/stars/awslabs/fortuna?style=social" align="center"></li>
<li><a href="https://github.com/blackjax-devs/blackjax">BlackJAX</a> -
Library of samplers for JAX.
<img src="https://img.shields.io/github/stars/blackjax-devs/blackjax?style=social" align="center"></li>
</ul>
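<p>To make the Optax entry above concrete, here is a minimal sketch of
its usual update loop. The linear loss and data are placeholders;
<code>adam</code>, <code>init</code>, <code>update</code>, and
<code>apply_updates</code> are Optax's public API:</p>
<pre><code>import jax
import jax.numpy as jnp
import optax

def loss(params, x, y):
    return jnp.mean((x @ params - y) ** 2)

params = jnp.zeros(3)
tx = optax.adam(learning_rate=1e-3)
opt_state = tx.init(params)

x, y = jnp.ones((10, 3)), jnp.ones(10)
grads = jax.grad(loss)(params, x, y)
updates, opt_state = tx.update(grads, opt_state)  # gradient processing
params = optax.apply_updates(params, updates)     # apply the step</code></pre>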
<p><a name="new-libraries" /></p>
<h3 id="new-libraries">New Libraries</h3>
<p>This section contains libraries that are well-made and useful, but
have not necessarily been battle-tested by a large userbase yet.</p>
<ul>
<li>Neural Network Libraries
<ul>
<li><a href="https://github.com/google/fedjax">FedJAX</a> - Federated
learning in JAX, built on Optax and Haiku.
<img src="https://img.shields.io/github/stars/google/fedjax?style=social" align="center"></li>
<li><a href="https://github.com/mfinzi/equivariant-MLP">Equivariant
MLP</a> - Construct equivariant neural network layers.
<img src="https://img.shields.io/github/stars/mfinzi/equivariant-MLP?style=social" align="center"></li>
<li><a href="https://github.com/n2cholas/jax-resnet/">jax-resnet</a> -
Implementations and checkpoints for ResNet variants in Flax.
<img src="https://img.shields.io/github/stars/n2cholas/jax-resnet?style=social" align="center"></li>
<li><a href="https://github.com/srush/parallax">Parallax</a> - Immutable
Torch Modules for JAX.
<img src="https://img.shields.io/github/stars/srush/parallax?style=social" align="center"></li>
</ul></li>
<li>Nonlinear Optimization
<ul>
<li><a
href="https://github.com/patrick-kidger/optimistix">Optimistix</a> -
Root finding, minimisation, fixed points, and least squares.
<img src="https://img.shields.io/github/stars/deepmind/optax?style=social" align="center"></li>
<li><a href="https://github.com/google/jaxopt">JAXopt</a> - Hardware
accelerated (GPU/TPU), batchable and differentiable optimizers in JAX.
<img src="https://img.shields.io/github/stars/google/jaxopt?style=social" align="center"></li>
</ul></li>
<li><a href="https://github.com/ElArkk/jax-unirep">jax-unirep</a> -
Library implementing the <a
href="https://www.nature.com/articles/s41592-019-0598-1">UniRep
model</a> for protein machine learning applications.
<img src="https://img.shields.io/github/stars/ElArkk/jax-unirep?style=social" align="center"></li>
<li><a href="https://github.com/danielward27/flowjax">flowjax</a> -
Distributions and normalizing flows built as equinox modules.
<img src="https://img.shields.io/github/stars/danielward27/flowjax?style=social" align="center"></li>
<li><a href="https://github.com/ChrisWaites/jax-flows">jax-flows</a> -
Normalizing flows in JAX.
<img src="https://img.shields.io/github/stars/ChrisWaites/jax-flows?style=social" align="center"></li>
<li><a
href="https://github.com/ExpectationMax/sklearn-jax-kernels">sklearn-jax-kernels</a>
- <code>scikit-learn</code> kernel matrices using JAX.
<img src="https://img.shields.io/github/stars/ExpectationMax/sklearn-jax-kernels?style=social" align="center"></li>
<li><a
href="https://github.com/DifferentiableUniverseInitiative/jax_cosmo">jax-cosmo</a>
- Differentiable cosmology library.
<img src="https://img.shields.io/github/stars/DifferentiableUniverseInitiative/jax_cosmo?style=social" align="center"></li>
<li><a href="https://github.com/NeilGirdhar/efax">efax</a> - Exponential
Families in JAX.
<img src="https://img.shields.io/github/stars/NeilGirdhar/efax?style=social" align="center"></li>
<li><a href="https://github.com/PhilipVinc/mpi4jax">mpi4jax</a> -
Combine MPI operations with your Jax code on CPUs and GPUs.
<img src="https://img.shields.io/github/stars/PhilipVinc/mpi4jax?style=social" align="center"></li>
<li><a href="https://github.com/4rtemi5/imax">imax</a> - Image
augmentations and transformations.
<img src="https://img.shields.io/github/stars/4rtemi5/imax?style=social" align="center"></li>
<li><a href="https://github.com/rolandgvc/flaxvision">FlaxVision</a> -
Flax version of TorchVision.
<img src="https://img.shields.io/github/stars/rolandgvc/flaxvision?style=social" align="center"></li>
<li><a
href="https://github.com/tensorflow/probability/tree/master/spinoffs/oryx">Oryx</a>
- Probabilistic programming language based on program
transformations.</li>
<li><a href="https://github.com/google-research/ott">Optimal Transport
Tools</a> - Toolbox that bundles utilities to solve optimal transport
problems.</li>
<li><a href="https://github.com/romanodev/deltapv">delta PV</a> - A
photovoltaic simulator with automatic differentiation.
<img src="https://img.shields.io/github/stars/romanodev/deltapv?style=social" align="center"></li>
<li><a href="https://github.com/brentyi/jaxlie">jaxlie</a> - Lie theory
library for rigid body transformations and optimization.
<img src="https://img.shields.io/github/stars/brentyi/jaxlie?style=social" align="center"></li>
<li><a href="https://github.com/google/brax">BRAX</a> - Differentiable
physics engine to simulate environments along with learning algorithms
to train agents for these environments.
<img src="https://img.shields.io/github/stars/google/brax?style=social" align="center"></li>
<li><a
href="https://github.com/matthias-wright/flaxmodels">flaxmodels</a> -
Pretrained models for Jax/Flax.
<img src="https://img.shields.io/github/stars/matthias-wright/flaxmodels?style=social" align="center"></li>
<li><a href="https://github.com/carnotresearch/cr-sparse">CR.Sparse</a>
- XLA accelerated algorithms for sparse representations and compressive
sensing.
<img src="https://img.shields.io/github/stars/carnotresearch/cr-sparse?style=social" align="center"></li>
<li><a href="https://github.com/HajimeKawahara/exojax">exojax</a> -
Automatic differentiable spectrum modeling of exoplanets/brown dwarfs
compatible to JAX.
<img src="https://img.shields.io/github/stars/HajimeKawahara/exojax?style=social" align="center"></li>
<li><a href="https://github.com/deepmind/dm_pix">PIX</a> - PIX is an
image processing library in JAX, for JAX.
<img src="https://img.shields.io/github/stars/deepmind/dm_pix?style=social" align="center"></li>
<li><a href="https://github.com/alonfnt/bayex">bayex</a> - Bayesian
Optimization powered by JAX.
<img src="https://img.shields.io/github/stars/alonfnt/bayex?style=social" align="center"></li>
<li><a href="https://github.com/ucl-bug/jaxdf">JaxDF</a> - Framework for
differentiable simulators with arbitrary discretizations.
<img src="https://img.shields.io/github/stars/ucl-bug/jaxdf?style=social" align="center"></li>
<li><a href="https://github.com/google/tree-math">tree-math</a> -
Convert functions that operate on arrays into functions that operate on
PyTrees.
<img src="https://img.shields.io/github/stars/google/tree-math?style=social" align="center"></li>
<li><a
href="https://github.com/DarshanDeshpande/jax-models">jax-models</a> -
Implementations of research papers originally without code or code
written with frameworks other than JAX.
<img src="https://img.shields.io/github/stars/DarshanDeshpande/jax-modelsa?style=social" align="center"></li>
<li><a href="https://github.com/vicariousinc/PGMax">PGMax</a> - A
framework for building discrete Probabilistic Graphical Models (PGMs)
and running inference inference on them via JAX.
<img src="https://img.shields.io/github/stars/vicariousinc/pgmax?style=social" align="center"></li>
<li><a href="https://github.com/google/evojax">EvoJAX</a> -
Hardware-Accelerated Neuroevolution
<img src="https://img.shields.io/github/stars/google/evojax?style=social" align="center"></li>
<li><a href="https://github.com/RobertTLange/evosax">evosax</a> -
JAX-Based Evolution Strategies
<img src="https://img.shields.io/github/stars/RobertTLange/evosax?style=social" align="center"></li>
<li><a href="https://github.com/SymJAX/SymJAX">SymJAX</a> - Symbolic
CPU/GPU/TPU programming.
<img src="https://img.shields.io/github/stars/SymJAX/SymJAX?style=social" align="center"></li>
<li><a href="https://github.com/rlouf/mcx">mcx</a> - Express &amp;
compile probabilistic programs for performant inference.
<img src="https://img.shields.io/github/stars/rlouf/mcx?style=social" align="center"></li>
<li><a href="https://github.com/deepmind/einshape">Einshape</a> -
DSL-based reshaping library for JAX and other frameworks.
<img src="https://img.shields.io/github/stars/deepmind/einshape?style=social" align="center"></li>
<li><a
href="https://github.com/google-research/google-research/tree/master/alx">ALX</a>
- Open-source library for distributed matrix factorization using
Alternating Least Squares, more info in <a
href="https://arxiv.org/abs/2112.02194"><em>ALX: Large Scale Matrix
Factorization on TPUs</em></a>.</li>
<li><a href="https://github.com/patrick-kidger/diffrax">Diffrax</a> -
Numerical differential equation solvers in JAX.
<img src="https://img.shields.io/github/stars/patrick-kidger/diffrax?style=social" align="center"></li>
<li><a href="https://github.com/dfm/tinygp">tinygp</a> - The
<em>tiniest</em> of Gaussian process libraries in JAX.
<img src="https://img.shields.io/github/stars/dfm/tinygp?style=social" align="center"></li>
<li><a href="https://github.com/RobertTLange/gymnax">gymnax</a> -
Reinforcement Learning Environments with the well-known gym API.
<img src="https://img.shields.io/github/stars/RobertTLange/gymnax?style=social" align="center"></li>
<li><a href="https://github.com/deepmind/mctx">Mctx</a> - Monte Carlo
tree search algorithms in native JAX.
<img src="https://img.shields.io/github/stars/deepmind/mctx?style=social" align="center"></li>
<li><a href="https://github.com/deepmind/kfac-jax">KFAC-JAX</a> - Second
Order Optimization with Approximate Curvature for NNs.
<img src="https://img.shields.io/github/stars/deepmind/kfac-jax?style=social" align="center"></li>
<li><a href="https://github.com/deepmind/tf2jax">TF2JAX</a> - Convert
functions/graphs to JAX functions.
<img src="https://img.shields.io/github/stars/deepmind/tf2jax?style=social" align="center"></li>
<li><a href="https://github.com/ucl-bug/jwave">jwave</a> - A library for
differentiable acoustic simulations
<img src="https://img.shields.io/github/stars/ucl-bug/jwave?style=social" align="center"></li>
<li><a href="https://github.com/thomaspinder/GPJax">GPJax</a> - Gaussian
processes in JAX.</li>
<li><a href="https://github.com/instadeepai/jumanji">Jumanji</a> - A
Suite of Industry-Driven Hardware-Accelerated RL Environments written in
JAX.
<img src="https://img.shields.io/github/stars/instadeepai/jumanji?style=social" align="center"></li>
<li><a href="https://github.com/paganpasta/eqxvision">Eqxvision</a> -
Equinox version of Torchvision.
<img src="https://img.shields.io/github/stars/paganpasta/eqxvision?style=social" align="center"></li>
<li><a href="https://github.com/dipolar-quantum-gases/jaxfit">JAXFit</a>
- Accelerated curve fitting library for nonlinear least-squares problems
(see <a href="https://arxiv.org/abs/2208.12187">arXiv paper</a>).
<img src="https://img.shields.io/github/stars/dipolar-quantum-gases/jaxfit?style=social" align="center"></li>
<li><a href="https://github.com/gboehl/econpizza">econpizza</a> - Solve
macroeconomic models with hetereogeneous agents using JAX.
<img src="https://img.shields.io/github/stars/gboehl/econpizza?style=social" align="center"></li>
<li><a href="https://github.com/secretflow/spu">SPU</a> - A
domain-specific compiler and runtime suite to run JAX code with
MPC(Secure Multi-Party Computation).
<img src="https://img.shields.io/github/stars/secretflow/spu?style=social" align="center"></li>
<li><a href="https://github.com/jeremiecoullon/jax-tqdm">jax-tqdm</a> -
Add a tqdm progress bar to JAX scans and loops.
<img src="https://img.shields.io/github/stars/jeremiecoullon/jax-tqdm?style=social" align="center"></li>
<li><a href="https://github.com/alvarobartt/safejax">safejax</a> -
Serialize JAX, Flax, Haiku, or Objax model params with
🤗<code>safetensors</code>.
<img src="https://img.shields.io/github/stars/alvarobartt/safejax?style=social" align="center"></li>
<li><a href="https://github.com/ASEM000/kernex">Kernex</a> -
Differentiable stencil decorators in JAX.
<img src="https://img.shields.io/github/stars/ASEM000/kernex?style=social" align="center"></li>
<li><a href="https://github.com/google/maxtext">MaxText</a> - A simple,
performant and scalable Jax LLM written in pure Python/Jax and targeting
Google Cloud TPUs.
<img src="https://img.shields.io/github/stars/google/maxtext?style=social" align="center"></li>
<li><a href="https://github.com/google/paxml">Pax</a> - A Jax-based
machine learning framework for training large scale models.
<img src="https://img.shields.io/github/stars/google/paxml?style=social" align="center"></li>
<li><a href="https://github.com/google/praxis">Praxis</a> - The layer
library for Pax with a goal to be usable by other JAX-based ML projects.
<img src="https://img.shields.io/github/stars/google/praxis?style=social" align="center"></li>
<li><a href="https://github.com/luchris429/purejaxrl">purejaxrl</a> -
Vectorisable, end-to-end RL algorithms in JAX.
<img src="https://img.shields.io/github/stars/luchris429/purejaxrl?style=social" align="center"></li>
<li><a href="https://github.com/davisyoshida/lorax">Lorax</a> -
Automatically apply LoRA to JAX models (Flax, Haiku, etc.)</li>
<li><a href="https://github.com/lanl/scico">SCICO</a> - Scientific
computational imaging in JAX.
<img src="https://img.shields.io/github/stars/lanl/scico?style=social" align="center"></li>
<li><a href="https://github.com/kmheckel/spyx">Spyx</a> - Spiking Neural
Networks in JAX for machine learning on neuromorphic hardware.
<img src="https://img.shields.io/github/stars/kmheckel/spyx?style=social" align="center"></li>
<li>Brain Dynamics Programming Ecosystem
<ul>
<li><a href="https://github.com/brainpy/BrainPy">BrainPy</a> - Brain
Dynamics Programming in Python.
<img src="https://img.shields.io/github/stars/brainpy/BrainPy?style=social" align="center"></li>
<li><a href="https://github.com/chaobrain/brainunit">brainunit</a> -
Physical units and unit-aware mathematical system in JAX.
<img src="https://img.shields.io/github/stars/chaobrain/brainunit?style=social" align="center"></li>
<li><a href="https://github.com/chaobrain/dendritex">dendritex</a> -
Dendritic Modeling in JAX.
<img src="https://img.shields.io/github/stars/chaobrain/dendritex?style=social" align="center"></li>
<li><a href="https://github.com/chaobrain/brainstate">brainstate</a> -
State-based Transformation System for Program Compilation and
Augmentation.
<img src="https://img.shields.io/github/stars/chaobrain/brainstate?style=social" align="center"></li>
<li><a href="https://github.com/chaobrain/braintaichi">braintaichi</a> -
Leveraging Taichi Lang to customize brain dynamics operators.
<img src="https://img.shields.io/github/stars/chaobrain/braintaichi?style=social" align="center"></li>
</ul></li>
<li><a href="https://github.com/ott-jax/ott">OTT-JAX</a> - Optimal
transport tools in JAX.
<img src="https://img.shields.io/github/stars/ott-jax/ott?style=social" align="center"></li>
<li><a
href="https://github.com/adaptive-intelligent-robotics/QDax">QDax</a> -
Quality Diversity optimization in Jax.
<img src="https://img.shields.io/github/stars/adaptive-intelligent-robotics/QDax?style=social" align="center"></li>
<li><a href="https://github.com/NVIDIA/JAX-Toolbox">JAX Toolbox</a> -
Nightly CI and optimized examples for JAX on NVIDIA GPUs using libraries
such as T5x, Paxml, and Transformer Engine.
<img src="https://img.shields.io/github/stars/NVIDIA/JAX-Toolbox?style=social" align="center"></li>
<li><a href="http://github.com/sotetsuk/pgx">Pgx</a> - Vectorized board
game environments for RL with an AlphaZero example.
<img src="https://img.shields.io/github/stars/sotetsuk/pgx?style=social" align="center"></li>
<li><a href="https://github.com/erfanzar/EasyDeL">EasyDeL</a> - EasyDeL
🔮 is an OpenSource Library to make your training faster and more
Optimized With cool Options for training and serving (Llama, MPT,
Mixtral, Falcon, etc) in JAX
<img src="https://img.shields.io/github/stars/erfanzar/EasyDeL?style=social" align="center"></li>
<li><a href="https://github.com/Autodesk/XLB">XLB</a> - A Differentiable
Massively Parallel Lattice Boltzmann Library in Python for Physics-Based
Machine Learning.
<img src="https://img.shields.io/github/stars/Autodesk/XLB?style=social" align="center"></li>
<li><a href="https://github.com/dynamiqs/dynamiqs">dynamiqs</a> -
High-performance and differentiable simulations of quantum systems with
JAX.
<img src="https://img.shields.io/github/stars/dynamiqs/dynamiqs?style=social" align="center"></li>
<li><a href="https://github.com/i-m-iron-man/Foragax">foragax</a> -
Agent-Based modelling framework in JAX.
<img src="https://img.shields.io/github/stars/i-m-iron-man/Foragax?style=social" align="center"></li>
<li><a href="https://github.com/bahremsd/tmmax">tmmax</a> - Vectorized
calculation of optical properties in thin-film structures using JAX.
Swiss Army knife tool for thin-film optics research
<img src="https://img.shields.io/github/stars/bahremsd/tmmax" align="center"></li>
<li><a href="https://github.com/gchq/coreax">Coreax</a> - Algorithms for
finding coresets to compress large datasets while retaining their
statistical properties.
<img src="https://img.shields.io/github/stars/gchq/coreax?style=social" align="center"></li>
<li><a href="https://github.com/epignatelli/navix">NAVIX</a> - A
reimplementation of MiniGrid, a Reinforcement Learning environment, in
JAX
<img src="https://img.shields.io/github/stars/epignatelli/navix?style=social" align="center"></li>
</ul>
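<p>As referenced in the Diffrax entry above, a minimal sketch of
solving an ODE with its <code>diffeqsolve</code> API (the toy vector
field is a placeholder):</p>
<pre><code>import diffrax
import jax.numpy as jnp

def vector_field(t, y, args):
    return -y  # dy/dt = -y, so y(t) = y0 * exp(-t)

term = diffrax.ODETerm(vector_field)
solver = diffrax.Tsit5()
sol = diffrax.diffeqsolve(term, solver, t0=0.0, t1=1.0, dt0=0.1,
                          y0=jnp.array(1.0))
print(sol.ys)  # solution at t1, close to exp(-1)</code></pre>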
<p><a name="models-and-projects" /></p>
<h2 id="models-and-projects">Models and Projects</h2>
<h3 id="jax">JAX</h3>
<ul>
<li><a href="https://github.com/tancik/fourier-feature-networks">Fourier
Feature Networks</a> - Official implementation of <a
href="https://people.eecs.berkeley.edu/~bmild/fourfeat"><em>Fourier
Features Let Networks Learn High Frequency Functions in Low Dimensional
Domains</em></a>.</li>
<li><a href="https://github.com/AaltoML/kalman-jax">kalman-jax</a> -
Approximate inference for Markov (i.e., temporal) Gaussian processes
using iterated Kalman filtering and smoothing.</li>
<li><a href="https://github.com/Joshuaalbert/jaxns">jaxns</a> - Nested
sampling in JAX.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/amortized_bo">Amortized
Bayesian Optimization</a> - Code related to <a
href="http://www.auai.org/uai2020/proceedings/329_main_paper.pdf"><em>Amortized
Bayesian Optimization over Discrete Spaces</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/aqt">Accurate
Quantized Training</a> - Tools and libraries for running and analyzing
neural network quantization experiments in JAX and Flax.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/bnn_hmc">BNN-HMC</a>
- Implementation for the paper <a
href="https://arxiv.org/abs/2104.14421"><em>What Are Bayesian Neural
Network Posteriors Really Like?</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/jax_dft">JAX-DFT</a>
- One-dimensional density functional theory (DFT) in JAX, with
implementation of <a
href="https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.126.036401"><em>Kohn-Sham
equations as regularizer: building prior knowledge into machine-learned
physics</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/robust_loss_jax">Robust
Loss</a> - Reference code for the paper <a
href="https://arxiv.org/abs/1701.03077"><em>A General and Adaptive
Robust Loss Function</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/symbolic_functionals">Symbolic
Functionals</a> - Demonstration from <a
href="https://arxiv.org/abs/2203.02540"><em>Evolving symbolic density
functionals</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/trimap">TriMap</a>
- Official JAX implementation of <a
href="https://arxiv.org/abs/1910.00204"><em>TriMap: Large-scale
Dimensionality Reduction Using Triplets</em></a>.</li>
</ul>
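<p>As referenced in the Fourier Feature Networks entry above, the
paper's random-feature encoding is small enough to sketch directly.
This is an illustrative reimplementation, not the official code; the
scale 10.0 and the sizes are arbitrary choices:</p>
<pre><code>import jax
import jax.numpy as jnp

def fourier_features(v, B):
    # gamma(v) = [cos(2 pi B v), sin(2 pi B v)]
    proj = 2 * jnp.pi * v @ B.T
    return jnp.concatenate([jnp.cos(proj), jnp.sin(proj)], axis=-1)

B = 10.0 * jax.random.normal(jax.random.PRNGKey(0), (256, 2))
coords = jnp.zeros((100, 2))           # e.g. pixel coordinates in [0, 1]^2
encoded = fourier_features(coords, B)  # shape (100, 512)</code></pre>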
<h3 id="flax">Flax</h3>
<ul>
<li><a
href="https://github.com/J-Rosser-UK/Torch2Jax-DeepSeek-R1-Distill-Qwen-1.5B">DeepSeek-R1-Flax-1.5B-Distill</a>
- Flax implementation of DeepSeek-R1 1.5B distilled reasoning LLM.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/performer/fast_attention/jax">Performer</a>
- Flax implementation of the Performer (linear transformer via FAVOR+)
architecture.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/jaxnerf">JaxNeRF</a>
- Implementation of <a
href="http://www.matthewtancik.com/nerf"><em>NeRF: Representing Scenes
as Neural Radiance Fields for View Synthesis</em></a> with multi-device
GPU/TPU support.</li>
<li><a href="https://github.com/google/mipnerf">mip-NeRF</a> - Official
implementation of <a href="https://jonbarron.info/mipnerf"><em>Mip-NeRF:
A Multiscale Representation for Anti-Aliasing Neural Radiance
Fields</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/regnerf">RegNeRF</a>
- Official implementation of <a
href="https://m-niemeyer.github.io/regnerf/"><em>RegNeRF: Regularizing
Neural Radiance Fields for View Synthesis from Sparse
Inputs</em></a>.</li>
<li><a href="https://github.com/huangjuite/jaxneus">JaxNeuS</a> -
Implementation of <a
href="https://lingjie0206.github.io/papers/NeuS/"><em>NeuS: Learning
Neural Implicit Surfaces by Volume Rendering for Multi-view
Reconstruction</em></a>.</li>
<li><a href="https://github.com/google-research/big_transfer">Big
Transfer (BiT)</a> - Implementation of <a
href="https://arxiv.org/abs/1912.11370"><em>Big Transfer (BiT): General
Visual Representation Learning</em></a>.</li>
<li><a href="https://github.com/ikostrikov/jax-rl">JAX RL</a> -
Implementations of reinforcement learning algorithms.</li>
<li><a href="https://github.com/SauravMaheshkar/gMLP">gMLP</a> -
Implementation of <a href="https://arxiv.org/abs/2105.08050"><em>Pay
Attention to MLPs</em></a>.</li>
<li><a href="https://github.com/SauravMaheshkar/MLP-Mixer">MLP Mixer</a>
- Minimal implementation of <a
href="https://arxiv.org/abs/2105.01601"><em>MLP-Mixer: An all-MLP
Architecture for Vision</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/scalable_shampoo">Distributed
Shampoo</a> - Implementation of <a
href="https://arxiv.org/abs/2002.09018"><em>Second Order Optimization
Made Practical</em></a>.</li>
<li><a
href="https://github.com/google-research/nested-transformer">NesT</a> -
Official implementation of <a
href="https://arxiv.org/abs/2105.12723"><em>Aggregating Nested
Transformers</em></a>.</li>
<li><a
href="https://github.com/google-research/xmcgan_image_generation">XMC-GAN</a>
- Official implementation of <a
href="https://arxiv.org/abs/2101.04702"><em>Cross-Modal Contrastive
Learning for Text-to-Image Generation</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/f_net">FNet</a>
- Official implementation of <a
href="https://arxiv.org/abs/2105.03824"><em>FNet: Mixing Tokens with
Fourier Transforms</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/gfsa">GFSA</a>
- Official implementation of <a
href="https://arxiv.org/abs/2007.04929"><em>Learning Graph Structure
With A Finite-State Automaton Layer</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/ipagnn">IPA-GNN</a>
- Official implementation of <a
href="https://arxiv.org/abs/2010.12621"><em>Learning to Execute Programs
with Instruction Pointer Attention Graph Neural Networks</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/flax_models">Flax
Models</a> - Collection of models and methods implemented in Flax.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/protein_lm">Protein
LM</a> - Implements BERT and autoregressive models for proteins, as
described in <a
href="https://www.biorxiv.org/content/10.1101/622803v1.full"><em>Biological
Structure and Function Emerge from Scaling Unsupervised Learning to 250
Million Protein Sequences</em></a> and <a
href="https://www.biorxiv.org/content/10.1101/2020.03.07.982272v2"><em>ProGen:
Language Modeling for Protein Generation</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/ptopk_patch_selection">Slot
Attention</a> - Reference implementation for <a
href="https://arxiv.org/abs/2104.03059"><em>Differentiable Patch
Selection for Image Recognition</em></a>.</li>
<li><a
href="https://github.com/google-research/vision_transformer">Vision
Transformer</a> - Official implementation of <a
href="https://arxiv.org/abs/2010.11929"><em>An Image is Worth 16x16
Words: Transformers for Image Recognition at Scale</em></a>.</li>
<li><a href="https://github.com/matthias-wright/jax-fid">FID
computation</a> - Port of <a
href="https://github.com/mseitzer/pytorch-fid">mseitzer/pytorch-fid</a>
to Flax.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/autoregressive_diffusion">ARDM</a>
- Official implementation of <a
href="https://arxiv.org/abs/2110.02037"><em>Autoregressive Diffusion
Models</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/d3pm">D3PM</a>
- Official implementation of <a
href="https://arxiv.org/abs/2107.03006"><em>Structured Denoising
Diffusion Models in Discrete State-Spaces</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/gumbel_max_causal_gadgets">Gumbel-max
Causal Mechanisms</a> - Code for <a
href="https://arxiv.org/abs/2111.06888"><em>Learning Generalized
Gumbel-max Causal Mechanisms</em></a>, with extra code in <a
href="https://github.com/GuyLor/gumbel_max_causal_gadgets_part2">GuyLor/gumbel_max_causal_gadgets_part2</a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/latent_programmer">Latent
Programmer</a> - Code for the ICML 2021 paper <a
href="https://arxiv.org/abs/2012.00377"><em>Latent Programmer: Discrete
Latent Codes for Program Synthesis</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/snerg">SNeRG</a>
- Official implementation of <a
href="https://phog.github.io/snerg"><em>Baking Neural Radiance Fields
for Real-Time View Synthesis</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/spin_spherical_cnns">Spin-weighted
Spherical CNNs</a> - Adaptation of <a
href="https://arxiv.org/abs/2006.10731"><em>Spin-Weighted Spherical
CNNs</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/vdvae_flax">VDVAE</a>
- Adaptation of <a href="https://arxiv.org/abs/2011.10650"><em>Very Deep
VAEs Generalize Autoregressive Models and Can Outperform Them on
Images</em></a>, original code at <a
href="https://github.com/openai/vdvae">openai/vdvae</a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/musiq">MUSIQ</a>
- Checkpoints and model inference code for the ICCV 2021 paper <a
href="https://arxiv.org/abs/2108.05997"><em>MUSIQ: Multi-scale Image
Quality Transformer</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/aquadem">AQuaDem</a>
- Official implementation of <a
href="https://arxiv.org/abs/2110.10149"><em>Continuous Control with
Action Quantization from Demonstrations</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/combiner">Combiner</a>
- Official implementation of <a
href="https://arxiv.org/abs/2107.05768"><em>Combiner: Full Attention
Transformer with Sparse Computation Cost</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/dreamfields">Dreamfields</a>
- Official implementation of the CVPR 2022 paper <a
href="https://ajayj.com/dreamfields"><em>Zero-Shot Text-Guided Object
Generation with Dream Fields</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/gift">GIFT</a>
- Official implementation of <a
href="https://arxiv.org/abs/2106.06080"><em>Gradual Domain Adaptation in
the Wild: When Intermediate Distributions are Absent</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/light_field_neural_rendering">Light
Field Neural Rendering</a> - Official implementation of <a
href="https://arxiv.org/abs/2112.09687"><em>Light Field Neural
Rendering</em></a>.</li>
<li><a
href="https://colab.research.google.com/drive/1KUKFEMneQMS3OzPYnWZGkEnry3PdzCfn?usp=sharing">Sharpened
Cosine Similarity in JAX by Raphael Pisoni</a> - A JAX/Flax
implementation of the Sharpened Cosine Similarity layer.</li>
<li><a
href="https://github.com/IvanIsCoding/GNN-for-Combinatorial-Optimization">GNNs
for Solving Combinatorial Optimization Problems</a> - A JAX + Flax
implementation of <a
href="https://arxiv.org/abs/2107.01188">Combinatorial Optimization with
Physics-Inspired Graph Neural Networks</a>.</li>
<li><a href="https://github.com/MasterSkepticista/detr">DETR</a> - Flax
implementation of <a
href="https://github.com/facebookresearch/detr"><em>DETR: End-to-end
Object Detection with Transformers</em></a> using Sinkhorn solver and
parallel bipartite matching.</li>
</ul>
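<p>Most of the Flax projects above share the same Linen
<code>init</code>/<code>apply</code> pattern; a minimal sketch (the
tiny MLP is a placeholder, not taken from any listed project):</p>
<pre><code>import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(self.features)(x))
        return nn.Dense(1)(x)

model = MLP(features=32)
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
out = model.apply(params, jnp.ones((4, 8)))  # shape (4, 1)</code></pre>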
<h3 id="haiku">Haiku</h3>
<ul>
<li><a href="https://github.com/deepmind/alphafold">AlphaFold</a> -
Implementation of the inference pipeline of AlphaFold v2.0, presented in
<a href="https://www.nature.com/articles/s41586-021-03819-2"><em>Highly
accurate protein structure prediction with AlphaFold</em></a>.</li>
<li><a
href="https://github.com/deepmind/deepmind-research/tree/master/adversarial_robustness">Adversarial
Robustness</a> - Reference code for <a
href="https://arxiv.org/abs/2010.03593"><em>Uncovering the Limits of
Adversarial Training against Norm-Bounded Adversarial Examples</em></a>
and <a href="https://arxiv.org/abs/2103.01946"><em>Fixing Data
Augmentation to Improve Adversarial Robustness</em></a>.</li>
<li><a
href="https://github.com/deepmind/deepmind-research/tree/master/byol">Bootstrap
Your Own Latent</a> - Implementation for the paper <a
href="https://arxiv.org/abs/2006.07733"><em>Bootstrap your own latent: A
new approach to self-supervised Learning</em></a>.</li>
<li><a
href="https://github.com/deepmind/deepmind-research/tree/master/gated_linear_networks">Gated
Linear Networks</a> - GLNs are a family of backpropagation-free neural
networks.</li>
<li><a
href="https://github.com/deepmind/deepmind-research/tree/master/glassy_dynamics">Glassy
Dynamics</a> - Open source implementation of the paper <a
href="https://www.nature.com/articles/s41567-020-0842-8"><em>Unveiling
the predictive power of static structure in glassy
systems</em></a>.</li>
<li><a
href="https://github.com/deepmind/deepmind-research/tree/master/mmv">MMV</a>
- Code for the models in <a
href="https://arxiv.org/abs/2006.16228"><em>Self-Supervised MultiModal
Versatile Networks</em></a>.</li>
<li><a
href="https://github.com/deepmind/deepmind-research/tree/master/nfnets">Normalizer-Free
Networks</a> - Official Haiku implementation of <a
href="https://arxiv.org/abs/2102.06171"><em>NFNets</em></a>.</li>
<li><a
href="https://github.com/Information-Fusion-Lab-Umass/NuX">NuX</a> -
Normalizing flows with JAX.</li>
<li><a
href="https://github.com/deepmind/deepmind-research/tree/master/ogb_lsc">OGB-LSC</a>
- This repository contains DeepMind's entry to the <a
href="https://ogb.stanford.edu/kddcup2021/pcqm4m/">PCQM4M-LSC</a>
(quantum chemistry) and <a
href="https://ogb.stanford.edu/kddcup2021/mag240m/">MAG240M-LSC</a>
(academic graph) tracks of the <a
href="https://ogb.stanford.edu/kddcup2021/">OGB Large-Scale
Challenge</a> (OGB-LSC).</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/persistent_es">Persistent
Evolution Strategies</a> - Code used for the paper <a
href="http://proceedings.mlr.press/v139/vicol21a.html"><em>Unbiased
Gradient Estimation in Unrolled Computation Graphs with Persistent
Evolution Strategies</em></a>.</li>
<li><a href="https://github.com/degregat/two-player-auctions">Two Player
Auction Learning</a> - JAX implementation of the paper <a
href="https://arxiv.org/abs/2006.05684"><em>Auction learning as a
two-player game</em></a>.</li>
<li><a
href="https://github.com/deepmind/deepmind-research/tree/master/wikigraphs">WikiGraphs</a>
- Baseline code to reproduce results in <a
href="https://aclanthology.org/2021.textgraphs-1.7"><em>WikiGraphs: A
Wikipedia Text - Knowledge Graph Paired Datase</em></a>.</li>
</ul>
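<p>The Haiku projects above are built around <code>hk.transform</code>,
which turns module-style code into a pure init/apply pair; a minimal
sketch:</p>
<pre><code>import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    x = jax.nn.relu(hk.Linear(32)(x))
    return hk.Linear(1)(x)

net = hk.transform(forward)
params = net.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
out = net.apply(params, None, jnp.ones((4, 8)))  # rng=None: deterministic</code></pre>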
<h3 id="trax">Trax</h3>
<ul>
<li><a
href="https://github.com/google/trax/tree/master/trax/models/reformer">Reformer</a>
- Implementation of the Reformer (efficient transformer)
architecture.</li>
</ul>
<h3 id="numpyro">NumPyro</h3>
<ul>
<li><a href="https://github.com/RothkopfLab/lqg">lqg</a> - Official
implementation of Bayesian inverse optimal control for linear-quadratic
Gaussian problems from the paper <a
href="https://elifesciences.org/articles/76635"><em>Putting perception
into action with inverse optimal control for continuous
psychophysics</em></a>.</li>
</ul>
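<p>For context on the NumPyro entry, a minimal sketch of defining a
model and running NUTS with its public <code>MCMC</code> API (the toy
data is a placeholder):</p>
<pre><code>import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def model(y):
    mu = numpyro.sample("mu", dist.Normal(0.0, 1.0))
    numpyro.sample("obs", dist.Normal(mu, 1.0), obs=y)

y = jnp.array([0.5, 1.0, 1.5])
mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), y)
print(mcmc.get_samples()["mu"].mean())</code></pre>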
<p><a name="videos" /></p>
<h2 id="videos">Videos</h2>
<ul>
<li><a href="https://www.youtube.com/watch?v=iDxJxIyzSiM">NeurIPS 2020:
JAX Ecosystem Meetup</a> - JAX, its use at DeepMind, and discussion
between engineers, scientists, and JAX core team.</li>
<li><a href="https://youtu.be/0mVmRHMaOJ4">Introduction to JAX</a> -
Simple neural network from scratch in JAX.</li>
<li><a href="https://youtu.be/z-WSrQDXkuM">JAX: Accelerated Machine
Learning Research | SciPy 2020 | VanderPlas</a> - JAXs core design, how
its powering new research, and how you can start using it.</li>
<li><a href="https://youtu.be/CecuWGpoztw">Bayesian Programming with JAX
+ NumPyro — Andy Kitchen</a> - Introduction to Bayesian modelling using
NumPyro.</li>
<li><a
href="https://slideslive.com/38923687/jax-accelerated-machinelearning-research-via-composable-function-transformations-in-python">JAX:
Accelerated machine-learning research via composable function
transformations in Python | NeurIPS 2019 | Skye Wanderman-Milne</a> -
JAX intro presentation in <a
href="https://program-transformations.github.io"><em>Program
Transformations for Machine Learning</em></a> workshop.</li>
<li><a
href="https://drive.google.com/file/d/1jKxefZT1xJDUxMman6qrQVed7vWI0MIn/edit">JAX
on Cloud TPUs | NeurIPS 2020 | Skye Wanderman-Milne and James
Bradbury</a> - Presentation of TPU host access with demo.</li>
<li><a
href="https://slideslive.com/38935810/deep-implicit-layers-neural-odes-equilibrium-models-and-beyond">Deep
Implicit Layers - Neural ODEs, Deep Equilibrium Models, and Beyond |
NeurIPS 2020</a> - Tutorial created by Zico Kolter, David Duvenaud, and
Matt Johnson with Colab notebooks available in <a
href="http://implicit-layers-tutorial.org"><em>Deep Implicit
Layers</em></a>.</li>
<li><a href="http://matpalm.com/blog/ymxb_pod_slice/">Solving y=mx+b
with Jax on a TPU Pod slice - Mat Kelcey</a> - A four-part YouTube
tutorial series with Colab notebooks that starts with Jax fundamentals
and moves up to training with a data parallel approach on a v3-32 TPU
Pod slice.</li>
<li><a
href="https://github.com/huggingface/transformers/blob/9160d81c98854df44b1d543ce5d65a6aa28444a2/examples/research_projects/jax-projects/README.md#talks">JAX,
Flax &amp; Transformers 🤗</a> - 3 days of talks around JAX / Flax,
Transformers, large-scale language modeling and other great topics.</li>
</ul>
<p><a name="papers" /></p>
<h2 id="papers">Papers</h2>
<p>This section contains papers focused on JAX (e.g., JAX-based library
whitepapers, research on JAX, etc.). Papers implemented in JAX are
listed in the <a href="#models-and-projects">Models and Projects</a>
section.</p>
<!--lint ignore awesome-list-item-->
<ul>
<li><a
href="https://mlsys.org/Conferences/doc/2018/146.pdf"><strong>Compiling
machine learning programs via high-level tracing</strong>. Roy Frostig,
Matthew James Johnson, Chris Leary. <em>MLSys 2018</em>.</a> - White
paper describing an early version of JAX, detailing how computation is
traced and compiled.</li>
<li><a href="https://arxiv.org/abs/1912.04232"><strong>JAX, M.D.: A
Framework for Differentiable Physics</strong>. Samuel S. Schoenholz,
Ekin D. Cubuk. <em>NeurIPS 2020</em>.</a> - Introduces JAX, M.D., a
differentiable physics library which includes simulation environments,
interaction potentials, neural networks, and more.</li>
<li><a href="https://arxiv.org/abs/2010.09063"><strong>Enabling Fast
Differentially Private SGD via Just-in-Time Compilation and
Vectorization</strong>. Pranav Subramani, Nicholas Vadivelu, Gautam
Kamath. <em>arXiv 2020</em>.</a> - Uses JAXs JIT and VMAP to achieve
faster differentially private than existing libraries.</li>
<li><a href="https://arxiv.org/abs/2311.16080"><strong>XLB: A
Differentiable Massively Parallel Lattice Boltzmann Library in
Python</strong>. Mohammadmehdi Ataei, Hesam Salehipour. <em>arXiv
2023</em>.</a> - White paper describing the XLB library: benchmarks,
validations, and more details about the library.</li>
</ul>
<!--lint enable awesome-list-item-->
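<p>The speedup in the differentially private SGD paper above comes from
the fact that <code>jax.vmap</code> composed with <code>jax.grad</code>
yields vectorized per-example gradients; a minimal sketch (loss and
data are placeholders):</p>
<pre><code>import jax
import jax.numpy as jnp

def loss(w, x, y):
    return (jnp.dot(x, w) - y) ** 2  # loss for a single example

# map over the batch axes of x and y, but not over w
per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))

w = jnp.zeros(3)
xs, ys = jnp.ones((10, 3)), jnp.ones(10)
grads = per_example_grads(w, xs, ys)  # shape (10, 3): one row per example</code></pre>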
<p><a name="tutorials-and-blog-posts" /></p>
<h2 id="tutorials-and-blog-posts">Tutorials and Blog Posts</h2>
<ul>
<li><a
href="https://deepmind.com/blog/article/using-jax-to-accelerate-our-research">Using
JAX to accelerate our research by David Budden and Matteo Hessel</a> -
Describes the state of JAX and the JAX ecosystem at DeepMind.</li>
<li><a
href="https://roberttlange.github.io/posts/2020/03/blog-post-10/">Getting
started with JAX (MLPs, CNNs &amp; RNNs) by Robert Lange</a> - Neural
network building blocks from scratch with the basic JAX operators.</li>
<li><a href="https://www.kaggle.com/code/truthr/jax-0">Learn JAX: From
Linear Regression to Neural Networks by Rito Ghosh</a> - A gentle
introduction to JAX that implements linear and logistic regression and
neural network models, and applies them to real-world problems.</li>
<li><a
href="https://github.com/8bitmp3/JAX-Flax-Tutorial-Image-Classification-with-Linen">Tutorial:
image classification with JAX and Flax Linen by 8bitmp3</a> - Learn how
to create a simple convolutional network with Flax's Linen API and
train it to recognize handwritten digits.</li>
<li><a
href="https://medium.com/swlh/plugging-into-jax-16c120ec3302">Plugging
Into JAX by Nick Doiron</a> - Compares Flax, Haiku, and Objax on the
Kaggle flower classification challenge.</li>
<li><a
href="https://blog.evjang.com/2019/02/maml-jax.html">Meta-Learning in 50
Lines of JAX by Eric Jang</a> - Introduction to both JAX and
Meta-Learning.</li>
<li><a href="https://blog.evjang.com/2019/07/nf-jax.html">Normalizing
Flows in 100 Lines of JAX by Eric Jang</a> - Concise implementation of
<a href="https://arxiv.org/abs/1605.08803">RealNVP</a>.</li>
<li><a href="https://blog.evjang.com/2019/11/jaxpt.html">Differentiable
Path Tracing on the GPU/TPU by Eric Jang</a> - Tutorial on implementing
path tracing.</li>
<li><a href="http://matpalm.com/blog/ensemble_nets">Ensemble networks by
Mat Kelcey</a> - Ensemble nets are a method of representing an ensemble
of models as one single logical model.</li>
<li><a href="http://matpalm.com/blog/ood_using_focal_loss">Out of
distribution (OOD) detection by Mat Kelcey</a> - Implements different
methods for OOD detection.</li>
<li><a href="https://www.radx.in/jax.html">Understanding Autodiff with
JAX by Srihari Radhakrishna</a> - Understand how autodiff works using
JAX.</li>
<li><a href="https://sjmielke.com/jax-purify.htm">From PyTorch to JAX:
towards neural net frameworks that purify stateful code by Sabrina J.
Mielke</a> - Showcases how to go from a PyTorch-like style of coding to
a more Functional-style of coding.</li>
<li><a href="https://github.com/dfm/extending-jax">Extending JAX with
custom C++ and CUDA code by Dan Foreman-Mackey</a> - Tutorial
demonstrating the infrastructure required to provide custom ops in
JAX.</li>
<li><a
href="https://roberttlange.github.io/posts/2021/02/cma-es-jax/">Evolving
Neural Networks in JAX by Robert Tjarko Lange</a> - Explores how JAX can
power the next generation of scalable neuroevolution algorithms.</li>
<li><a
href="http://lukemetz.com/exploring-hyperparameter-meta-loss-landscapes-with-jax/">Exploring
hyperparameter meta-loss landscapes with JAX by Luke Metz</a> -
Demonstrates how to use JAX to perform inner-loss optimization with SGD
and Momentum, outer-loss optimization with gradients, and outer-loss
optimization using evolutionary strategies.</li>
<li><a
href="https://martiningram.github.io/deterministic-advi/">Deterministic
ADVI in JAX by Martin Ingram</a> - Walk through of implementing
automatic differentiation variational inference (ADVI) easily and
cleanly with JAX.</li>
<li><a href="http://matpalm.com/blog/evolved_channel_selection/">Evolved
channel selection by Mat Kelcey</a> - Trains a classification model
robust to different combinations of input channels at different
resolutions, then uses a genetic algorithm to decide the best
combination for a particular loss.</li>
<li><a
href="https://colab.research.google.com/github/probml/probml-notebooks/blob/main/notebooks/jax_intro.ipynb">Introduction
to JAX by Kevin Murphy</a> - Colab that introduces various aspects of
the language and applies them to simple ML problems.</li>
<li><a
href="https://www.jeremiecoullon.com/2020/11/10/mcmcjax3ways/">Writing
an MCMC sampler in JAX by Jeremie Coullon</a> - Tutorial on the
different ways to write an MCMC sampler in JAX along with speed
benchmarks.</li>
<li><a
href="https://www.jeremiecoullon.com/2021/01/29/jax_progress_bar/">How
to add a progress bar to JAX scans and loops by Jeremie Coullon</a> -
Tutorial on how to add a progress bar to compiled loops in JAX using the
<code>host_callback</code> module (see the <code>lax.scan</code> sketch
after this list).</li>
<li><a href="https://github.com/gordicaleksa/get-started-with-JAX">Get
started with JAX by Aleksa Gordić</a> - A series of notebooks and videos
going from zero JAX knowledge to building neural networks in Haiku.</li>
<li><a
href="https://wandb.ai/jax-series/simple-training-loop/reports/Writing-a-Training-Loop-in-JAX-FLAX--VmlldzoyMzA4ODEy">Writing
a Training Loop in JAX + FLAX by Saurav Maheshkar and Soumik Rakshit</a>
- A tutorial on writing a simple end-to-end training and evaluation
pipeline in JAX, Flax and Optax.</li>
<li><a
href="https://wandb.ai/wandb/nerf-jax/reports/Implementing-NeRF-in-JAX--VmlldzoxODA2NDk2?galleryTag=jax">Implementing
NeRF in JAX by Soumik Rakshit and Saurav Maheshkar</a> - A tutorial on
3D volumetric rendering of scenes represented by Neural Radiance Fields
in JAX.</li>
<li><a
href="https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial2/Introduction_to_JAX.html">Deep
Learning tutorials with JAX+Flax by Phillip Lippe</a> - A series of
notebooks explaining various deep learning concepts, from basics
(e.g., intro to JAX/Flax, activation functions) to recent advances
(e.g., Vision Transformers, SimCLR), with translations to PyTorch.</li>
<li><a href="https://chrislu.page/blog/meta-disco/">Achieving 4000x
Speedups with PureJaxRL</a> - A blog post on how JAX can massively
speed up RL training through vectorisation.</li>
<li><a
href="https://levelup.gitconnected.com/create-your-own-automatically-differentiable-simulation-with-python-jax-46951e120fbb?sk=e8b9213dd2c6a5895926b2695d28e4aa">Simple
PDE solver + Constrained Optimization with JAX by Philip Mocz</a> - A
simple example of solving the advection-diffusion equations with JAX and
using it in a constrained optimization problem to find initial
conditions that yield a desired result.</li>
</ul>
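<p>Several of the tutorials above (the MCMC and progress-bar posts in
particular) revolve around <code>jax.lax.scan</code>, JAX's compiled
loop primitive; a minimal sketch of its carry/output convention:</p>
<pre><code>import jax
import jax.numpy as jnp

def step(carry, x):
    carry = carry + x
    return carry, carry  # (new carry, per-step output)

total, cumsum = jax.lax.scan(step, 0.0, jnp.arange(5.0))
# total == 10.0, cumsum == [0., 1., 3., 6., 10.]</code></pre>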
<p><a name="books" /></p>
<h2 id="books">Books</h2>
<ul>
<li><a href="https://www.manning.com/books/jax-in-action">Jax in
Action</a> - A hands-on guide to using JAX for deep learning and other
mathematically-intensive applications.</li>
</ul>
<p><a name="community" /></p>
<h2 id="community">Community</h2>
<ul>
<li><a
href="https://discord.com/channels/1107832795377713302/1107832795688083561">JaxLLM
(Unofficial) Discord</a></li>
<li><a href="https://github.com/google/jax/discussions">JAX GitHub
Discussions</a></li>
<li><a href="https://www.reddit.com/r/JAX/">Reddit</a></li>
</ul>
<h2 id="contributing">Contributing</h2>
<p>Contributions welcome! Read the <a
href="contributing.md">contribution guidelines</a> first.</p>
<p><a href="https://github.com/n2cholas/awesome-jax">jax.md
Github</a></p>