<!--lint ignore double-link-->
<h1 id="awesome-jax-awesome">Awesome JAX <a
href="https://awesome.re"><img src="https://awesome.re/badge.svg"
alt="Awesome" /></a><a
href="https://github.com/google/jax"><img src="https://raw.githubusercontent.com/google/jax/master/images/jax_logo_250px.png" alt="JAX Logo" align="right" height="100"></a></h1>
<!--lint ignore double-link-->
<p><a href="https://github.com/google/jax">JAX</a> brings automatic
differentiation and the <a href="https://www.tensorflow.org/xla">XLA
compiler</a> together through a <a
href="https://numpy.org/">NumPy</a>-like API for high-performance
machine learning research on accelerators like GPUs and TPUs.
<!--lint enable double-link--></p>
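<p>As a minimal sketch of how these pieces combine (the linear model, toy data, and step size below are illustrative assumptions, not taken from any project listed here), <code>jax.numpy</code> supplies the NumPy-like API, <code>jax.grad</code> supplies automatic differentiation, and <code>jax.jit</code> compiles the resulting function with XLA:</p>
<pre><code class="language-python">import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Mean squared error of a toy linear model, written with NumPy-style jnp calls.
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)


# jax.grad returns a new function computing d(loss)/dw (w is the first argument);
# jax.jit compiles that gradient function with XLA when it is first called.
grad_loss = jax.jit(jax.grad(loss))

# Hypothetical toy data: two samples with two features each.
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

g = grad_loss(w, x, y)  # gradient of the loss at w = 0
w_next = w - 0.1 * g    # one plain gradient-descent step
</code></pre>
<p>Here <code>g</code> evaluates to <code>[-7., -10.]</code>. The first call to <code>grad_loss</code> triggers compilation; later calls with the same input shapes and dtypes reuse the compiled code.</p>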
<p>This is a curated list of awesome JAX libraries, projects, and other
resources. Contributions are welcome!</p>
<h2 id="contents">Contents</h2>
<ul>
<li><a href="#libraries">Libraries</a></li>
<li><a href="#models-and-projects">Models and Projects</a></li>
<li><a href="#videos">Videos</a></li>
<li><a href="#papers">Papers</a></li>
<li><a href="#tutorials-and-blog-posts">Tutorials and Blog
Posts</a></li>
<li><a href="#books">Books</a></li>
<li><a href="#community">Community</a></li>
</ul>
<p><a name="libraries"></a></p>
<h2 id="libraries">Libraries</h2>
<ul>
<li>Neural Network Libraries
<ul>
<li><a href="https://github.com/google/flax">Flax</a> - Centered on
flexibility and clarity.
<img src="https://img.shields.io/github/stars/google/flax?style=social" align="center"></li>
<li><a href="https://github.com/google/flax/tree/main/flax/nnx">Flax
NNX</a> - An evolution of Flax by the same team.
<img src="https://img.shields.io/github/stars/google/flax?style=social" align="center"></li>
<li><a href="https://github.com/deepmind/dm-haiku">Haiku</a> - Focused
on simplicity, created by the authors of Sonnet at DeepMind.
<img src="https://img.shields.io/github/stars/deepmind/dm-haiku?style=social" align="center"></li>
<li><a href="https://github.com/google/objax">Objax</a> - Has an
object-oriented design similar to PyTorch.
<img src="https://img.shields.io/github/stars/google/objax?style=social" align="center"></li>
<li><a href="https://poets-ai.github.io/elegy/">Elegy</a> - A high-level
API for deep learning in JAX. Supports Flax, Haiku, and Optax.
<img src="https://img.shields.io/github/stars/poets-ai/elegy?style=social" align="center"></li>
<li><a href="https://github.com/google/trax">Trax</a> - “Batteries
included” deep learning library focused on providing solutions for
common workloads.
<img src="https://img.shields.io/github/stars/google/trax?style=social" align="center"></li>
<li><a href="https://github.com/deepmind/jraph">Jraph</a> - Lightweight
graph neural network library.
<img src="https://img.shields.io/github/stars/deepmind/jraph?style=social" align="center"></li>
<li><a href="https://github.com/google/neural-tangents">Neural
Tangents</a> - High-level API for specifying neural networks of both
finite and <em>infinite</em> width.
<img src="https://img.shields.io/github/stars/google/neural-tangents?style=social" align="center"></li>
<li><a href="https://github.com/huggingface/transformers">HuggingFace
Transformers</a> - Ecosystem of pretrained Transformers for a wide range
of natural language tasks (Flax).
<img src="https://img.shields.io/github/stars/huggingface/transformers?style=social" align="center"></li>
<li><a href="https://github.com/patrick-kidger/equinox">Equinox</a> -
Callable PyTrees and filtered JIT/grad transformations => neural
networks in JAX.
<img src="https://img.shields.io/github/stars/patrick-kidger/equinox?style=social" align="center"></li>
<li><a href="https://github.com/google-research/scenic">Scenic</a> - A
JAX library for computer vision research and beyond.
<img src="https://img.shields.io/github/stars/google-research/scenic?style=social" align="center"></li>
<li><a href="https://github.com/google-deepmind/penzai">Penzai</a> -
Prioritizes legibility, visualization, and easy editing of neural
network models with composable tools and a simple mental model.
<img src="https://img.shields.io/github/stars/google-deepmind/penzai?style=social" align="center"></li>
</ul></li>
<li><a href="https://github.com/stanford-crfm/levanter">Levanter</a> -
Legible, Scalable, Reproducible Foundation Models with Named Tensors and
JAX.
<img src="https://img.shields.io/github/stars/stanford-crfm/levanter?style=social" align="center"></li>
<li><a href="https://github.com/young-geng/EasyLM">EasyLM</a> - LLMs
made easy: pre-training, fine-tuning, evaluating, and serving LLMs in
JAX/Flax.
<img src="https://img.shields.io/github/stars/young-geng/EasyLM?style=social" align="center"></li>
<li><a href="https://github.com/pyro-ppl/numpyro">NumPyro</a> -
Probabilistic programming based on the Pyro library.
<img src="https://img.shields.io/github/stars/pyro-ppl/numpyro?style=social" align="center"></li>
<li><a href="https://github.com/deepmind/chex">Chex</a> - Utilities to
write and test reliable JAX code.
<img src="https://img.shields.io/github/stars/deepmind/chex?style=social" align="center"></li>
<li><a href="https://github.com/deepmind/optax">Optax</a> - Gradient
processing and optimization library.
<img src="https://img.shields.io/github/stars/deepmind/optax?style=social" align="center"></li>
<li><a href="https://github.com/deepmind/rlax">RLax</a> - Library for
implementing reinforcement learning agents.
<img src="https://img.shields.io/github/stars/deepmind/rlax?style=social" align="center"></li>
<li><a href="https://github.com/google/jax-md">JAX, M.D.</a> -
Accelerated, differentiable molecular dynamics.
<img src="https://img.shields.io/github/stars/google/jax-md?style=social" align="center"></li>
<li><a href="https://github.com/coax-dev/coax">Coax</a> - Turn RL papers
into code, the easy way.
<img src="https://img.shields.io/github/stars/coax-dev/coax?style=social" align="center"></li>
<li><a href="https://github.com/deepmind/distrax">Distrax</a> -
Reimplementation of TensorFlow Probability, containing probability
distributions and bijectors.
<img src="https://img.shields.io/github/stars/deepmind/distrax?style=social" align="center"></li>
<li><a href="https://github.com/cvxgrp/cvxpylayers">cvxpylayers</a> -
Construct differentiable convex optimization layers.
<img src="https://img.shields.io/github/stars/cvxgrp/cvxpylayers?style=social" align="center"></li>
<li><a href="https://github.com/tensorly/tensorly">TensorLy</a> - Tensor
learning made simple.
<img src="https://img.shields.io/github/stars/tensorly/tensorly?style=social" align="center"></li>
<li><a href="https://github.com/netket/netket">NetKet</a> - Machine
learning toolbox for quantum physics.
<img src="https://img.shields.io/github/stars/netket/netket?style=social" align="center"></li>
<li><a href="https://github.com/awslabs/fortuna">Fortuna</a> - AWS
library for uncertainty quantification in deep learning.
<img src="https://img.shields.io/github/stars/awslabs/fortuna?style=social" align="center"></li>
<li><a href="https://github.com/blackjax-devs/blackjax">BlackJAX</a> -
Library of samplers for JAX.
<img src="https://img.shields.io/github/stars/blackjax-devs/blackjax?style=social" align="center"></li>
</ul>
<p><a name="new-libraries"></a></p>
<h3 id="new-libraries">New Libraries</h3>
<p>This section contains libraries that are well-made and useful, but
have not necessarily been battle-tested by a large user base yet.</p>
<ul>
<li>Neural Network Libraries
<ul>
<li><a href="https://github.com/google/fedjax">FedJAX</a> - Federated
learning in JAX, built on Optax and Haiku.
<img src="https://img.shields.io/github/stars/google/fedjax?style=social" align="center"></li>
<li><a href="https://github.com/mfinzi/equivariant-MLP">Equivariant
MLP</a> - Construct equivariant neural network layers.
<img src="https://img.shields.io/github/stars/mfinzi/equivariant-MLP?style=social" align="center"></li>
<li><a href="https://github.com/n2cholas/jax-resnet/">jax-resnet</a> -
Implementations and checkpoints for ResNet variants in Flax.
<img src="https://img.shields.io/github/stars/n2cholas/jax-resnet?style=social" align="center"></li>
<li><a href="https://github.com/srush/parallax">Parallax</a> - Immutable
Torch Modules for JAX.
<img src="https://img.shields.io/github/stars/srush/parallax?style=social" align="center"></li>
</ul></li>
<li>Nonlinear Optimization
<ul>
<li><a
href="https://github.com/patrick-kidger/optimistix">Optimistix</a> -
Root finding, minimisation, fixed points, and least squares.
<img src="https://img.shields.io/github/stars/patrick-kidger/optimistix?style=social" align="center"></li>
<li><a href="https://github.com/google/jaxopt">JAXopt</a> -
Hardware-accelerated (GPU/TPU), batchable and differentiable optimizers
in JAX.
<img src="https://img.shields.io/github/stars/google/jaxopt?style=social" align="center"></li>
</ul></li>
<li><a href="https://github.com/ElArkk/jax-unirep">jax-unirep</a> -
Library implementing the <a
href="https://www.nature.com/articles/s41592-019-0598-1">UniRep
model</a> for protein machine learning applications.
<img src="https://img.shields.io/github/stars/ElArkk/jax-unirep?style=social" align="center"></li>
<li><a href="https://github.com/danielward27/flowjax">flowjax</a> -
Distributions and normalizing flows built as Equinox modules.
<img src="https://img.shields.io/github/stars/danielward27/flowjax?style=social" align="center"></li>
<li><a href="https://github.com/ChrisWaites/jax-flows">jax-flows</a> -
Normalizing flows in JAX.
<img src="https://img.shields.io/github/stars/ChrisWaites/jax-flows?style=social" align="center"></li>
<li><a
href="https://github.com/ExpectationMax/sklearn-jax-kernels">sklearn-jax-kernels</a>
- <code>scikit-learn</code> kernel matrices using JAX.
<img src="https://img.shields.io/github/stars/ExpectationMax/sklearn-jax-kernels?style=social" align="center"></li>
<li><a
href="https://github.com/DifferentiableUniverseInitiative/jax_cosmo">jax-cosmo</a>
- Differentiable cosmology library.
<img src="https://img.shields.io/github/stars/DifferentiableUniverseInitiative/jax_cosmo?style=social" align="center"></li>
<li><a href="https://github.com/NeilGirdhar/efax">efax</a> - Exponential
families in JAX.
<img src="https://img.shields.io/github/stars/NeilGirdhar/efax?style=social" align="center"></li>
<li><a href="https://github.com/PhilipVinc/mpi4jax">mpi4jax</a> -
Combine MPI operations with your JAX code on CPUs and GPUs.
<img src="https://img.shields.io/github/stars/PhilipVinc/mpi4jax?style=social" align="center"></li>
<li><a href="https://github.com/4rtemi5/imax">imax</a> - Image
augmentations and transformations.
<img src="https://img.shields.io/github/stars/4rtemi5/imax?style=social" align="center"></li>
<li><a href="https://github.com/rolandgvc/flaxvision">FlaxVision</a> -
Flax version of TorchVision.
<img src="https://img.shields.io/github/stars/rolandgvc/flaxvision?style=social" align="center"></li>
<li><a
href="https://github.com/tensorflow/probability/tree/master/spinoffs/oryx">Oryx</a>
- Probabilistic programming language based on program
transformations.</li>
<li><a href="https://github.com/google-research/ott">Optimal Transport
Tools</a> - Toolbox that bundles utilities to solve optimal transport
problems.</li>
<li><a href="https://github.com/romanodev/deltapv">delta PV</a> - A
photovoltaic simulator with automatic differentiation.
<img src="https://img.shields.io/github/stars/romanodev/deltapv?style=social" align="center"></li>
<li><a href="https://github.com/brentyi/jaxlie">jaxlie</a> - Lie theory
library for rigid body transformations and optimization.
<img src="https://img.shields.io/github/stars/brentyi/jaxlie?style=social" align="center"></li>
<li><a href="https://github.com/google/brax">BRAX</a> - Differentiable
physics engine to simulate environments along with learning algorithms
to train agents for these environments.
<img src="https://img.shields.io/github/stars/google/brax?style=social" align="center"></li>
<li><a
href="https://github.com/matthias-wright/flaxmodels">flaxmodels</a> -
Pretrained models for JAX/Flax.
<img src="https://img.shields.io/github/stars/matthias-wright/flaxmodels?style=social" align="center"></li>
<li><a href="https://github.com/carnotresearch/cr-sparse">CR.Sparse</a>
- XLA-accelerated algorithms for sparse representations and compressive
sensing.
<img src="https://img.shields.io/github/stars/carnotresearch/cr-sparse?style=social" align="center"></li>
<li><a href="https://github.com/HajimeKawahara/exojax">exojax</a> -
Automatically differentiable spectrum modeling of exoplanets/brown
dwarfs, compatible with JAX.
<img src="https://img.shields.io/github/stars/HajimeKawahara/exojax?style=social" align="center"></li>
<li><a href="https://github.com/deepmind/dm_pix">PIX</a> - PIX is an
image processing library in JAX, for JAX.
<img src="https://img.shields.io/github/stars/deepmind/dm_pix?style=social" align="center"></li>
<li><a href="https://github.com/alonfnt/bayex">bayex</a> - Bayesian
optimization powered by JAX.
<img src="https://img.shields.io/github/stars/alonfnt/bayex?style=social" align="center"></li>
<li><a href="https://github.com/ucl-bug/jaxdf">JaxDF</a> - Framework for
differentiable simulators with arbitrary discretizations.
<img src="https://img.shields.io/github/stars/ucl-bug/jaxdf?style=social" align="center"></li>
<li><a href="https://github.com/google/tree-math">tree-math</a> -
Convert functions that operate on arrays into functions that operate on
PyTrees.
<img src="https://img.shields.io/github/stars/google/tree-math?style=social" align="center"></li>
<li><a
href="https://github.com/DarshanDeshpande/jax-models">jax-models</a> -
Implementations of research papers originally without code or code
written with frameworks other than JAX.
<img src="https://img.shields.io/github/stars/DarshanDeshpande/jax-models?style=social" align="center"></li>
<li><a href="https://github.com/vicariousinc/PGMax">PGMax</a> - A
framework for building discrete Probabilistic Graphical Models (PGMs)
and running inference on them via JAX.
<img src="https://img.shields.io/github/stars/vicariousinc/pgmax?style=social" align="center"></li>
<li><a href="https://github.com/google/evojax">EvoJAX</a> -
Hardware-Accelerated Neuroevolution.
<img src="https://img.shields.io/github/stars/google/evojax?style=social" align="center"></li>
<li><a href="https://github.com/RobertTLange/evosax">evosax</a> -
JAX-Based Evolution Strategies.
<img src="https://img.shields.io/github/stars/RobertTLange/evosax?style=social" align="center"></li>
<li><a href="https://github.com/SymJAX/SymJAX">SymJAX</a> - Symbolic
CPU/GPU/TPU programming.
<img src="https://img.shields.io/github/stars/SymJAX/SymJAX?style=social" align="center"></li>
<li><a href="https://github.com/rlouf/mcx">mcx</a> - Express &amp;
compile probabilistic programs for performant inference.
<img src="https://img.shields.io/github/stars/rlouf/mcx?style=social" align="center"></li>
<li><a href="https://github.com/deepmind/einshape">Einshape</a> -
DSL-based reshaping library for JAX and other frameworks.
<img src="https://img.shields.io/github/stars/deepmind/einshape?style=social" align="center"></li>
<li><a
href="https://github.com/google-research/google-research/tree/master/alx">ALX</a>
- Open-source library for distributed matrix factorization using
Alternating Least Squares. For details, see <a
href="https://arxiv.org/abs/2112.02194"><em>ALX: Large Scale Matrix
Factorization on TPUs</em></a>.</li>
<li><a href="https://github.com/patrick-kidger/diffrax">Diffrax</a> -
Numerical differential equation solvers in JAX.
<img src="https://img.shields.io/github/stars/patrick-kidger/diffrax?style=social" align="center"></li>
<li><a href="https://github.com/dfm/tinygp">tinygp</a> - The
<em>tiniest</em> of Gaussian process libraries in JAX.
<img src="https://img.shields.io/github/stars/dfm/tinygp?style=social" align="center"></li>
<li><a href="https://github.com/RobertTLange/gymnax">gymnax</a> -
Reinforcement learning environments with the well-known Gym API.
<img src="https://img.shields.io/github/stars/RobertTLange/gymnax?style=social" align="center"></li>
<li><a href="https://github.com/deepmind/mctx">Mctx</a> - Monte Carlo
tree search algorithms in native JAX.
<img src="https://img.shields.io/github/stars/deepmind/mctx?style=social" align="center"></li>
<li><a href="https://github.com/deepmind/kfac-jax">KFAC-JAX</a> -
Second-order optimization with approximate curvature for neural
networks.
<img src="https://img.shields.io/github/stars/deepmind/kfac-jax?style=social" align="center"></li>
<li><a href="https://github.com/deepmind/tf2jax">TF2JAX</a> - Convert
TensorFlow functions/graphs to JAX functions.
<img src="https://img.shields.io/github/stars/deepmind/tf2jax?style=social" align="center"></li>
<li><a href="https://github.com/ucl-bug/jwave">jwave</a> - A library for
differentiable acoustic simulations.
<img src="https://img.shields.io/github/stars/ucl-bug/jwave?style=social" align="center"></li>
<li><a href="https://github.com/thomaspinder/GPJax">GPJax</a> - Gaussian
processes in JAX.</li>
<li><a href="https://github.com/instadeepai/jumanji">Jumanji</a> - A
suite of industry-driven, hardware-accelerated RL environments written
in JAX.
<img src="https://img.shields.io/github/stars/instadeepai/jumanji?style=social" align="center"></li>
<li><a href="https://github.com/paganpasta/eqxvision">Eqxvision</a> -
Equinox version of Torchvision.
<img src="https://img.shields.io/github/stars/paganpasta/eqxvision?style=social" align="center"></li>
<li><a href="https://github.com/dipolar-quantum-gases/jaxfit">JAXFit</a>
- Accelerated curve fitting library for nonlinear least-squares problems
(see <a href="https://arxiv.org/abs/2208.12187">arXiv paper</a>).
<img src="https://img.shields.io/github/stars/dipolar-quantum-gases/jaxfit?style=social" align="center"></li>
<li><a href="https://github.com/gboehl/econpizza">econpizza</a> - Solve
macroeconomic models with heterogeneous agents using JAX.
<img src="https://img.shields.io/github/stars/gboehl/econpizza?style=social" align="center"></li>
<li><a href="https://github.com/secretflow/spu">SPU</a> - A
domain-specific compiler and runtime suite to run JAX code with MPC
(Secure Multi-Party Computation).
<img src="https://img.shields.io/github/stars/secretflow/spu?style=social" align="center"></li>
<li><a href="https://github.com/jeremiecoullon/jax-tqdm">jax-tqdm</a> -
Add a tqdm progress bar to JAX scans and loops.
<img src="https://img.shields.io/github/stars/jeremiecoullon/jax-tqdm?style=social" align="center"></li>
<li><a href="https://github.com/alvarobartt/safejax">safejax</a> -
Serialize JAX, Flax, Haiku, or Objax model params with
🤗 <code>safetensors</code>.
<img src="https://img.shields.io/github/stars/alvarobartt/safejax?style=social" align="center"></li>
<li><a href="https://github.com/ASEM000/kernex">Kernex</a> -
Differentiable stencil decorators in JAX.
<img src="https://img.shields.io/github/stars/ASEM000/kernex?style=social" align="center"></li>
<li><a href="https://github.com/google/maxtext">MaxText</a> - A simple,
performant and scalable JAX LLM written in pure Python/JAX and targeting
Google Cloud TPUs.
<img src="https://img.shields.io/github/stars/google/maxtext?style=social" align="center"></li>
<li><a href="https://github.com/google/paxml">Pax</a> - A JAX-based
machine learning framework for training large-scale models.
<img src="https://img.shields.io/github/stars/google/paxml?style=social" align="center"></li>
<li><a href="https://github.com/google/praxis">Praxis</a> - The layer
library for Pax, designed to also be usable by other JAX-based ML
projects.
<img src="https://img.shields.io/github/stars/google/praxis?style=social" align="center"></li>
<li><a href="https://github.com/luchris429/purejaxrl">purejaxrl</a> -
Vectorisable, end-to-end RL algorithms in JAX.
<img src="https://img.shields.io/github/stars/luchris429/purejaxrl?style=social" align="center"></li>
<li><a href="https://github.com/davisyoshida/lorax">Lorax</a> -
Automatically apply LoRA to JAX models (Flax, Haiku, etc.).</li>
<li><a href="https://github.com/lanl/scico">SCICO</a> - Scientific
computational imaging in JAX.
<img src="https://img.shields.io/github/stars/lanl/scico?style=social" align="center"></li>
<li><a href="https://github.com/kmheckel/spyx">Spyx</a> - Spiking Neural
Networks in JAX for machine learning on neuromorphic hardware.
<img src="https://img.shields.io/github/stars/kmheckel/spyx?style=social" align="center"></li>
<li>Brain Dynamics Programming Ecosystem
<ul>
<li><a href="https://github.com/brainpy/BrainPy">BrainPy</a> - Brain
Dynamics Programming in Python.
<img src="https://img.shields.io/github/stars/brainpy/BrainPy?style=social" align="center"></li>
<li><a href="https://github.com/chaobrain/brainunit">brainunit</a> -
Physical units and unit-aware mathematical system in JAX.
<img src="https://img.shields.io/github/stars/chaobrain/brainunit?style=social" align="center"></li>
<li><a href="https://github.com/chaobrain/dendritex">dendritex</a> -
Dendritic Modeling in JAX.
<img src="https://img.shields.io/github/stars/chaobrain/dendritex?style=social" align="center"></li>
<li><a href="https://github.com/chaobrain/brainstate">brainstate</a> -
State-based Transformation System for Program Compilation and
Augmentation.
<img src="https://img.shields.io/github/stars/chaobrain/brainstate?style=social" align="center"></li>
<li><a href="https://github.com/chaobrain/braintaichi">braintaichi</a> -
Leveraging Taichi Lang to customize brain dynamics operators.
<img src="https://img.shields.io/github/stars/chaobrain/braintaichi?style=social" align="center"></li>
</ul></li>
<li><a href="https://github.com/ott-jax/ott">OTT-JAX</a> - Optimal
transport tools in JAX.
<img src="https://img.shields.io/github/stars/ott-jax/ott?style=social" align="center"></li>
<li><a
href="https://github.com/adaptive-intelligent-robotics/QDax">QDax</a> -
Quality Diversity optimization in JAX.
<img src="https://img.shields.io/github/stars/adaptive-intelligent-robotics/QDax?style=social" align="center"></li>
<li><a href="https://github.com/NVIDIA/JAX-Toolbox">JAX Toolbox</a> -
Nightly CI and optimized examples for JAX on NVIDIA GPUs using libraries
such as T5x, Paxml, and Transformer Engine.
<img src="https://img.shields.io/github/stars/NVIDIA/JAX-Toolbox?style=social" align="center"></li>
<li><a href="https://github.com/sotetsuk/pgx">Pgx</a> - Vectorized board
game environments for RL with an AlphaZero example.
<img src="https://img.shields.io/github/stars/sotetsuk/pgx?style=social" align="center"></li>
<li><a href="https://github.com/erfanzar/EasyDeL">EasyDeL</a> - EasyDeL
🔮 is an open-source library for faster, more optimized training and
serving of models (Llama, MPT, Mixtral, Falcon, etc.) in JAX.
<img src="https://img.shields.io/github/stars/erfanzar/EasyDeL?style=social" align="center"></li>
<li><a href="https://github.com/Autodesk/XLB">XLB</a> - A Differentiable
Massively Parallel Lattice Boltzmann Library in Python for Physics-Based
Machine Learning.
<img src="https://img.shields.io/github/stars/Autodesk/XLB?style=social" align="center"></li>
<li><a href="https://github.com/dynamiqs/dynamiqs">dynamiqs</a> -
High-performance and differentiable simulations of quantum systems with
JAX.
<img src="https://img.shields.io/github/stars/dynamiqs/dynamiqs?style=social" align="center"></li>
<li><a href="https://github.com/i-m-iron-man/Foragax">foragax</a> -
Agent-based modelling framework in JAX.
<img src="https://img.shields.io/github/stars/i-m-iron-man/Foragax?style=social" align="center"></li>
<li><a href="https://github.com/bahremsd/tmmax">tmmax</a> - Vectorized
calculation of optical properties in thin-film structures using JAX; a
Swiss Army knife for thin-film optics research.
<img src="https://img.shields.io/github/stars/bahremsd/tmmax?style=social" align="center"></li>
<li><a href="https://github.com/gchq/coreax">Coreax</a> - Algorithms for
finding coresets to compress large datasets while retaining their
statistical properties.
<img src="https://img.shields.io/github/stars/gchq/coreax?style=social" align="center"></li>
<li><a href="https://github.com/epignatelli/navix">NAVIX</a> - A
reimplementation of MiniGrid, a reinforcement learning environment, in
JAX.
<img src="https://img.shields.io/github/stars/epignatelli/navix?style=social" align="center"></li>
</ul>
<p><a name="models-and-projects"></a></p>
<h2 id="models-and-projects">Models and Projects</h2>
<h3 id="jax">JAX</h3>
<ul>
<li><a href="https://github.com/tancik/fourier-feature-networks">Fourier
Feature Networks</a> - Official implementation of <a
href="https://people.eecs.berkeley.edu/~bmild/fourfeat"><em>Fourier
Features Let Networks Learn High Frequency Functions in Low Dimensional
Domains</em></a>.</li>
<li><a href="https://github.com/AaltoML/kalman-jax">kalman-jax</a> -
Approximate inference for Markov (i.e., temporal) Gaussian processes
using iterated Kalman filtering and smoothing.</li>
<li><a href="https://github.com/Joshuaalbert/jaxns">jaxns</a> - Nested
sampling in JAX.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/amortized_bo">Amortized
Bayesian Optimization</a> - Code related to <a
href="http://www.auai.org/uai2020/proceedings/329_main_paper.pdf"><em>Amortized
Bayesian Optimization over Discrete Spaces</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/aqt">Accurate
Quantized Training</a> - Tools and libraries for running and analyzing
neural network quantization experiments in JAX and Flax.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/bnn_hmc">BNN-HMC</a>
- Implementation for the paper <a
href="https://arxiv.org/abs/2104.14421"><em>What Are Bayesian Neural
Network Posteriors Really Like?</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/jax_dft">JAX-DFT</a>
- One-dimensional density functional theory (DFT) in JAX, with
implementation of <a
href="https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.126.036401"><em>Kohn-Sham
equations as regularizer: building prior knowledge into machine-learned
physics</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/robust_loss_jax">Robust
Loss</a> - Reference code for the paper <a
href="https://arxiv.org/abs/1701.03077"><em>A General and Adaptive
Robust Loss Function</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/symbolic_functionals">Symbolic
Functionals</a> - Demonstration from <a
href="https://arxiv.org/abs/2203.02540"><em>Evolving symbolic density
functionals</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/trimap">TriMap</a>
- Official JAX implementation of <a
href="https://arxiv.org/abs/1910.00204"><em>TriMap: Large-scale
Dimensionality Reduction Using Triplets</em></a>.</li>
</ul>
<h3 id="flax">Flax</h3>
|
||
<ul>
|
||
<li><a
|
||
href="https://github.com/J-Rosser-UK/Torch2Jax-DeepSeek-R1-Distill-Qwen-1.5B">DeepSeek-R1-Flax-1.5B-Distill</a>
|
||
- Flax implementation of DeepSeek-R1 1.5B distilled reasoning LLM.</li>
|
||
<li><a
|
||
href="https://github.com/google-research/google-research/tree/master/performer/fast_attention/jax">Performer</a>
|
||
- Flax implementation of the Performer (linear transformer via FAVOR+)
|
||
architecture.</li>
|
||
<li><a
|
||
href="https://github.com/google-research/google-research/tree/master/jaxnerf">JaxNeRF</a>
|
||
- Implementation of <a
|
||
href="http://www.matthewtancik.com/nerf"><em>NeRF: Representing Scenes
|
||
as Neural Radiance Fields for View Synthesis</em></a> with multi-device
|
||
GPU/TPU support.</li>
|
||
<li><a href="https://github.com/google/mipnerf">mip-NeRF</a> - Official
|
||
implementation of <a href="https://jonbarron.info/mipnerf"><em>Mip-NeRF:
|
||
A Multiscale Representation for Anti-Aliasing Neural Radiance
|
||
Fields</em></a>.</li>
|
||
<li><a
|
||
href="https://github.com/google-research/google-research/tree/master/regnerf">RegNeRF</a>
|
||
- Official implementation of <a
|
||
href="https://m-niemeyer.github.io/regnerf/"><em>RegNeRF: Regularizing
|
||
Neural Radiance Fields for View Synthesis from Sparse
|
||
Inputs</em></a>.</li>
|
||
<li><a href="https://github.com/huangjuite/jaxneus">JaxNeuS</a> -
|
||
Implementation of <a
|
||
href="https://lingjie0206.github.io/papers/NeuS/"><em>NeuS: Learning
|
||
Neural Implicit Surfaces by Volume Rendering for Multi-view
|
||
Reconstruction</em></a></li>
<li><a href="https://github.com/google-research/big_transfer">Big
Transfer (BiT)</a> - Implementation of <a
href="https://arxiv.org/abs/1912.11370"><em>Big Transfer (BiT): General
Visual Representation Learning</em></a>.</li>
<li><a href="https://github.com/ikostrikov/jax-rl">JAX RL</a> -
Implementations of reinforcement learning algorithms.</li>
<li><a href="https://github.com/SauravMaheshkar/gMLP">gMLP</a> -
Implementation of <a href="https://arxiv.org/abs/2105.08050"><em>Pay
Attention to MLPs</em></a>.</li>
<li><a href="https://github.com/SauravMaheshkar/MLP-Mixer">MLP Mixer</a>
- Minimal implementation of <a
href="https://arxiv.org/abs/2105.01601"><em>MLP-Mixer: An all-MLP
Architecture for Vision</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/scalable_shampoo">Distributed
Shampoo</a> - Implementation of <a
href="https://arxiv.org/abs/2002.09018"><em>Second Order Optimization
Made Practical</em></a>.</li>
<li><a
href="https://github.com/google-research/nested-transformer">NesT</a> -
Official implementation of <a
href="https://arxiv.org/abs/2105.12723"><em>Aggregating Nested
Transformers</em></a>.</li>
<li><a
href="https://github.com/google-research/xmcgan_image_generation">XMC-GAN</a>
- Official implementation of <a
href="https://arxiv.org/abs/2101.04702"><em>Cross-Modal Contrastive
Learning for Text-to-Image Generation</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/f_net">FNet</a>
- Official implementation of <a
href="https://arxiv.org/abs/2105.03824"><em>FNet: Mixing Tokens with
Fourier Transforms</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/gfsa">GFSA</a>
- Official implementation of <a
href="https://arxiv.org/abs/2007.04929"><em>Learning Graph Structure
With A Finite-State Automaton Layer</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/ipagnn">IPA-GNN</a>
- Official implementation of <a
href="https://arxiv.org/abs/2010.12621"><em>Learning to Execute Programs
with Instruction Pointer Attention Graph Neural Networks</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/flax_models">Flax
Models</a> - Collection of models and methods implemented in Flax.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/protein_lm">Protein
LM</a> - Implements BERT and autoregressive models for proteins, as
described in <a
href="https://www.biorxiv.org/content/10.1101/622803v1.full"><em>Biological
Structure and Function Emerge from Scaling Unsupervised Learning to 250
Million Protein Sequences</em></a> and <a
href="https://www.biorxiv.org/content/10.1101/2020.03.07.982272v2"><em>ProGen:
Language Modeling for Protein Generation</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/ptopk_patch_selection">Differentiable
Patch Selection</a> - Reference implementation for <a
href="https://arxiv.org/abs/2104.03059"><em>Differentiable Patch
Selection for Image Recognition</em></a>.</li>
<li><a
href="https://github.com/google-research/vision_transformer">Vision
Transformer</a> - Official implementation of <a
href="https://arxiv.org/abs/2010.11929"><em>An Image is Worth 16x16
Words: Transformers for Image Recognition at Scale</em></a>.</li>
<li><a href="https://github.com/matthias-wright/jax-fid">FID
computation</a> - Port of <a
href="https://github.com/mseitzer/pytorch-fid">mseitzer/pytorch-fid</a>
to Flax.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/autoregressive_diffusion">ARDM</a>
- Official implementation of <a
href="https://arxiv.org/abs/2110.02037"><em>Autoregressive Diffusion
Models</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/d3pm">D3PM</a>
- Official implementation of <a
href="https://arxiv.org/abs/2107.03006"><em>Structured Denoising
Diffusion Models in Discrete State-Spaces</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/gumbel_max_causal_gadgets">Gumbel-max
Causal Mechanisms</a> - Code for <a
href="https://arxiv.org/abs/2111.06888"><em>Learning Generalized
Gumbel-max Causal Mechanisms</em></a>, with extra code in <a
href="https://github.com/GuyLor/gumbel_max_causal_gadgets_part2">GuyLor/gumbel_max_causal_gadgets_part2</a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/latent_programmer">Latent
Programmer</a> - Code for the ICML 2021 paper <a
href="https://arxiv.org/abs/2012.00377"><em>Latent Programmer: Discrete
Latent Codes for Program Synthesis</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/snerg">SNeRG</a>
- Official implementation of <a
href="https://phog.github.io/snerg"><em>Baking Neural Radiance Fields
for Real-Time View Synthesis</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/spin_spherical_cnns">Spin-weighted
Spherical CNNs</a> - Adaptation of <a
href="https://arxiv.org/abs/2006.10731"><em>Spin-Weighted Spherical
CNNs</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/vdvae_flax">VDVAE</a>
- Adaptation of <a href="https://arxiv.org/abs/2011.10650"><em>Very Deep
VAEs Generalize Autoregressive Models and Can Outperform Them on
Images</em></a>, original code at <a
href="https://github.com/openai/vdvae">openai/vdvae</a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/musiq">MUSIQ</a>
- Checkpoints and model inference code for the ICCV 2021 paper <a
href="https://arxiv.org/abs/2108.05997"><em>MUSIQ: Multi-scale Image
Quality Transformer</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/aquadem">AQuaDem</a>
- Official implementation of <a
href="https://arxiv.org/abs/2110.10149"><em>Continuous Control with
Action Quantization from Demonstrations</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/combiner">Combiner</a>
- Official implementation of <a
href="https://arxiv.org/abs/2107.05768"><em>Combiner: Full Attention
Transformer with Sparse Computation Cost</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/dreamfields">Dreamfields</a>
- Official implementation of the CVPR 2022 paper <a
href="https://ajayj.com/dreamfields"><em>Zero-Shot Text-Guided Object
Generation with Dream Fields</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/gift">GIFT</a>
- Official implementation of <a
href="https://arxiv.org/abs/2106.06080"><em>Gradual Domain Adaptation in
the Wild: When Intermediate Distributions are Absent</em></a>.</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/light_field_neural_rendering">Light
Field Neural Rendering</a> - Official implementation of <a
href="https://arxiv.org/abs/2112.09687"><em>Light Field Neural
Rendering</em></a>.</li>
<li><a
href="https://colab.research.google.com/drive/1KUKFEMneQMS3OzPYnWZGkEnry3PdzCfn?usp=sharing">Sharpened
Cosine Similarity in JAX by Raphael Pisoni</a> - A JAX/Flax
implementation of the Sharpened Cosine Similarity layer.</li>
<li><a
href="https://github.com/IvanIsCoding/GNN-for-Combinatorial-Optimization">GNNs
for Solving Combinatorial Optimization Problems</a> - A JAX + Flax
implementation of <a
href="https://arxiv.org/abs/2107.01188"><em>Combinatorial Optimization with
Physics-Inspired Graph Neural Networks</em></a>.</li>
<li><a href="https://github.com/MasterSkepticista/detr">DETR</a> - Flax
implementation of <a
href="https://github.com/facebookresearch/detr"><em>DETR: End-to-end
Object Detection with Transformers</em></a> using a Sinkhorn solver and
parallel bipartite matching.</li>
</ul>
<h3 id="haiku">Haiku</h3>
<ul>
<li><a href="https://github.com/deepmind/alphafold">AlphaFold</a> -
Implementation of the inference pipeline of AlphaFold v2.0, presented in
<a href="https://www.nature.com/articles/s41586-021-03819-2"><em>Highly
accurate protein structure prediction with AlphaFold</em></a>.</li>
<li><a
href="https://github.com/deepmind/deepmind-research/tree/master/adversarial_robustness">Adversarial
Robustness</a> - Reference code for <a
href="https://arxiv.org/abs/2010.03593"><em>Uncovering the Limits of
Adversarial Training against Norm-Bounded Adversarial Examples</em></a>
and <a href="https://arxiv.org/abs/2103.01946"><em>Fixing Data
Augmentation to Improve Adversarial Robustness</em></a>.</li>
<li><a
href="https://github.com/deepmind/deepmind-research/tree/master/byol">Bootstrap
Your Own Latent</a> - Implementation for the paper <a
href="https://arxiv.org/abs/2006.07733"><em>Bootstrap your own latent: A
new approach to self-supervised Learning</em></a>.</li>
<li><a
href="https://github.com/deepmind/deepmind-research/tree/master/gated_linear_networks">Gated
Linear Networks</a> - GLNs are a family of backpropagation-free neural
networks.</li>
<li><a
href="https://github.com/deepmind/deepmind-research/tree/master/glassy_dynamics">Glassy
Dynamics</a> - Open-source implementation of the paper <a
href="https://www.nature.com/articles/s41567-020-0842-8"><em>Unveiling
the predictive power of static structure in glassy
systems</em></a>.</li>
<li><a
href="https://github.com/deepmind/deepmind-research/tree/master/mmv">MMV</a>
- Code for the models in <a
href="https://arxiv.org/abs/2006.16228"><em>Self-Supervised MultiModal
Versatile Networks</em></a>.</li>
<li><a
href="https://github.com/deepmind/deepmind-research/tree/master/nfnets">Normalizer-Free
Networks</a> - Official Haiku implementation of <a
href="https://arxiv.org/abs/2102.06171"><em>NFNets</em></a>.</li>
<li><a
href="https://github.com/Information-Fusion-Lab-Umass/NuX">NuX</a> -
Normalizing flows with JAX.</li>
<li><a
href="https://github.com/deepmind/deepmind-research/tree/master/ogb_lsc">OGB-LSC</a>
- This repository contains DeepMind’s entry to the <a
href="https://ogb.stanford.edu/kddcup2021/pcqm4m/">PCQM4M-LSC</a>
(quantum chemistry) and <a
href="https://ogb.stanford.edu/kddcup2021/mag240m/">MAG240M-LSC</a>
(academic graph) tracks of the <a
href="https://ogb.stanford.edu/kddcup2021/">OGB Large-Scale
Challenge</a> (OGB-LSC).</li>
<li><a
href="https://github.com/google-research/google-research/tree/master/persistent_es">Persistent
Evolution Strategies</a> - Code used for the paper <a
href="http://proceedings.mlr.press/v139/vicol21a.html"><em>Unbiased
Gradient Estimation in Unrolled Computation Graphs with Persistent
Evolution Strategies</em></a>.</li>
<li><a href="https://github.com/degregat/two-player-auctions">Two Player
Auction Learning</a> - JAX implementation of the paper <a
href="https://arxiv.org/abs/2006.05684"><em>Auction learning as a
two-player game</em></a>.</li>
<li><a
href="https://github.com/deepmind/deepmind-research/tree/master/wikigraphs">WikiGraphs</a>
- Baseline code to reproduce results in <a
href="https://aclanthology.org/2021.textgraphs-1.7"><em>WikiGraphs: A
Wikipedia Text - Knowledge Graph Paired Dataset</em></a>.</li>
</ul>
<h3 id="trax">Trax</h3>
<ul>
<li><a
href="https://github.com/google/trax/tree/master/trax/models/reformer">Reformer</a>
- Implementation of the Reformer (efficient transformer)
architecture.</li>
</ul>
<h3 id="numpyro">NumPyro</h3>
<ul>
<li><a href="https://github.com/RothkopfLab/lqg">lqg</a> - Official
implementation of Bayesian inverse optimal control for linear-quadratic
Gaussian problems from the paper <a
href="https://elifesciences.org/articles/76635"><em>Putting perception
into action with inverse optimal control for continuous
psychophysics</em></a>.</li>
</ul>
<p><a name="videos" /></p>
<h2 id="videos">Videos</h2>
<ul>
<li><a href="https://www.youtube.com/watch?v=iDxJxIyzSiM">NeurIPS 2020:
JAX Ecosystem Meetup</a> - JAX, its use at DeepMind, and discussion
between engineers, scientists, and the JAX core team.</li>
<li><a href="https://youtu.be/0mVmRHMaOJ4">Introduction to JAX</a> -
Simple neural network from scratch in JAX.</li>
<li><a href="https://youtu.be/z-WSrQDXkuM">JAX: Accelerated Machine
Learning Research | SciPy 2020 | VanderPlas</a> - JAX’s core design, how
it’s powering new research, and how you can start using it.</li>
<li><a href="https://youtu.be/CecuWGpoztw">Bayesian Programming with JAX
+ NumPyro — Andy Kitchen</a> - Introduction to Bayesian modelling using
NumPyro.</li>
<li><a
href="https://slideslive.com/38923687/jax-accelerated-machinelearning-research-via-composable-function-transformations-in-python">JAX:
Accelerated machine-learning research via composable function
transformations in Python | NeurIPS 2019 | Skye Wanderman-Milne</a> -
JAX intro presentation at the <a
href="https://program-transformations.github.io"><em>Program
Transformations for Machine Learning</em></a> workshop.</li>
<li><a
href="https://drive.google.com/file/d/1jKxefZT1xJDUxMman6qrQVed7vWI0MIn/edit">JAX
on Cloud TPUs | NeurIPS 2020 | Skye Wanderman-Milne and James
Bradbury</a> - Presentation of TPU host access with demo.</li>
<li><a
href="https://slideslive.com/38935810/deep-implicit-layers-neural-odes-equilibrium-models-and-beyond">Deep
Implicit Layers - Neural ODEs, Deep Equilibrium Models, and Beyond |
NeurIPS 2020</a> - Tutorial created by Zico Kolter, David Duvenaud, and
Matt Johnson with Colab notebooks available in <a
href="http://implicit-layers-tutorial.org"><em>Deep Implicit
Layers</em></a>.</li>
<li><a href="http://matpalm.com/blog/ymxb_pod_slice/">Solving y=mx+b
with Jax on a TPU Pod slice - Mat Kelcey</a> - A four-part YouTube
tutorial series with Colab notebooks that starts with JAX fundamentals
and moves up to training with a data-parallel approach on a v3-32 TPU
Pod slice.</li>
<li><a
href="https://github.com/huggingface/transformers/blob/9160d81c98854df44b1d543ce5d65a6aa28444a2/examples/research_projects/jax-projects/README.md#talks">JAX,
Flax & Transformers 🤗</a> - Three days of talks on JAX/Flax,
Transformers, large-scale language modeling and other great topics.</li>
</ul>
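<p>As a small illustration of where the <em>Solving y=mx+b</em> series starts, the sketch below fits a line with <code>jax.grad</code> and <code>jax.jit</code> on a single device. It is a minimal, hypothetical example written for this list, not code from the series: the synthetic data (slope 3.0, intercept 0.5), <code>LEARNING_RATE</code>, and <code>NUM_STEPS</code> are arbitrary assumptions, and nothing here covers TPU Pod slices or data parallelism.</p>

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical step size
NUM_STEPS = 500      # hypothetical number of full-batch updates


def loss(params, x, y):
    # Mean squared error of the line y_hat = m * x + b.
    m, b = params
    return jnp.mean((m * x + b - y) ** 2)


@jax.jit
def sgd_step(params, x, y):
    # One full-batch gradient-descent update of (m, b).
    grads = jax.grad(loss)(params, x, y)
    return tuple(p - LEARNING_RATE * g for p, g in zip(params, grads))


# Synthetic, noise-free data from a hypothetical line y = 3.0 * x + 0.5.
x = jax.random.uniform(jax.random.PRNGKey(0), (100,), minval=-1.0, maxval=1.0)
y = 3.0 * x + 0.5

params = (jnp.array(0.0), jnp.array(0.0))
for _ in range(NUM_STEPS):
    params = sgd_step(params, x, y)
m, b = params
print(float(m), float(b))  # should be close to 3.0 and 0.5
```

<p>Because the data is noise-free, plain gradient descent converges to the generating slope and intercept; scaling this up to many devices is what the later parts of the series address.</p>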
<p><a name="papers" /></p>
<h2 id="papers">Papers</h2>
<p>This section contains papers focused on JAX (e.g., JAX-based library
white papers, research on JAX). Papers implemented in JAX are listed
in the <a href="#projects">Models and Projects</a> section.</p>
<!--lint ignore awesome-list-item-->
<ul>
<li><a
href="https://mlsys.org/Conferences/doc/2018/146.pdf"><strong>Compiling
machine learning programs via high-level tracing</strong>. Roy Frostig,
Matthew James Johnson, Chris Leary. <em>MLSys 2018</em>.</a> - White
paper describing an early version of JAX, detailing how computation is
traced and compiled.</li>
<li><a href="https://arxiv.org/abs/1912.04232"><strong>JAX, M.D.: A
Framework for Differentiable Physics</strong>. Samuel S. Schoenholz,
Ekin D. Cubuk. <em>NeurIPS 2020</em>.</a> - Introduces JAX, M.D., a
differentiable physics library which includes simulation environments,
interaction potentials, neural networks, and more.</li>
<li><a href="https://arxiv.org/abs/2010.09063"><strong>Enabling Fast
Differentially Private SGD via Just-in-Time Compilation and
Vectorization</strong>. Pranav Subramani, Nicholas Vadivelu, Gautam
Kamath. <em>arXiv 2020</em>.</a> - Uses JAX’s <code>jit</code> and
<code>vmap</code> to run differentially private SGD faster than existing
libraries.</li>
<li><a href="https://arxiv.org/abs/2311.16080"><strong>XLB: A
Differentiable Massively Parallel Lattice Boltzmann Library in
Python</strong>. Mohammadmehdi Ataei, Hesam Salehipour. <em>arXiv
2023</em>.</a> - White paper describing the XLB library: benchmarks,
validations, and more details about the library.</li>
</ul>
<!--lint enable awesome-list-item-->
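<p>The per-example gradients that make DP-SGD expensive are where <code>vmap</code> helps: mapping <code>jax.grad</code> over the batch axis yields one gradient per example, which can then be clipped and noised. The sketch below is a minimal, generic DP-SGD step under stated assumptions, not the authors' code: the linear model, <code>CLIP_NORM</code>, <code>NOISE_MULTIPLIER</code>, and <code>LEARNING_RATE</code> are hypothetical choices, and no privacy accounting is performed.</p>

```python
import jax
import jax.numpy as jnp

CLIP_NORM = 1.0         # hypothetical per-example L2 clipping bound
NOISE_MULTIPLIER = 1.1  # hypothetical noise scale relative to CLIP_NORM
LEARNING_RATE = 0.1     # hypothetical step size


def loss(w, x, y):
    # Squared error of a linear model on a single example (x: [d], y: scalar).
    return (jnp.dot(x, w) - y) ** 2


# vmap the single-example gradient over the batch axis of x and y,
# sharing the same weights w: the result has shape [batch, d].
per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))


def clipped_grads(w, xs, ys):
    # Rescale each example's gradient so its L2 norm is at most CLIP_NORM.
    g = per_example_grads(w, xs, ys)
    norms = jnp.linalg.norm(g, axis=1, keepdims=True)
    return g * jnp.minimum(1.0, CLIP_NORM / (norms + 1e-12))


@jax.jit
def dp_sgd_step(w, xs, ys, key):
    # Sum clipped gradients, add Gaussian noise, average, and take a step.
    g = clipped_grads(w, xs, ys).sum(axis=0)
    noise = NOISE_MULTIPLIER * CLIP_NORM * jax.random.normal(key, w.shape)
    return w - LEARNING_RATE * (g + noise) / xs.shape[0]


data_key, noise_key = jax.random.split(jax.random.PRNGKey(0))
xs = jax.random.normal(data_key, (32, 5))
ys = xs @ jnp.arange(5.0)
w = jnp.zeros(5)
w_next = dp_sgd_step(w, xs, ys, noise_key)
```

<p>In an actual training loop a fresh noise key would be split off for every step, and the privacy budget would be tracked with a dedicated accountant; both are omitted here.</p>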
<p><a name="tutorials-and-blog-posts" /></p>
<h2 id="tutorials-and-blog-posts">Tutorials and Blog Posts</h2>
<ul>
<li><a
href="https://deepmind.com/blog/article/using-jax-to-accelerate-our-research">Using
JAX to accelerate our research by David Budden and Matteo Hessel</a> -
Describes the state of JAX and the JAX ecosystem at DeepMind.</li>
<li><a
href="https://roberttlange.github.io/posts/2020/03/blog-post-10/">Getting
started with JAX (MLPs, CNNs & RNNs) by Robert Lange</a> - Neural
network building blocks from scratch with the basic JAX operators.</li>
<li><a href="https://www.kaggle.com/code/truthr/jax-0">Learn JAX: From
Linear Regression to Neural Networks by Rito Ghosh</a> - A gentle
introduction to JAX that uses it to implement linear regression, logistic
regression, and neural network models and applies them to real-world
problems.</li>
<li><a
href="https://github.com/8bitmp3/JAX-Flax-Tutorial-Image-Classification-with-Linen">Tutorial:
image classification with JAX and Flax Linen by 8bitmp3</a> - Learn how
to create a simple convolutional network with the Linen API by Flax and
train it to recognize handwritten digits.</li>
<li><a
href="https://medium.com/swlh/plugging-into-jax-16c120ec3302">Plugging
Into JAX by Nick Doiron</a> - Compares Flax, Haiku, and Objax on the
Kaggle flower classification challenge.</li>
<li><a
href="https://blog.evjang.com/2019/02/maml-jax.html">Meta-Learning in 50
Lines of JAX by Eric Jang</a> - Introduction to both JAX and
Meta-Learning.</li>
<li><a href="https://blog.evjang.com/2019/07/nf-jax.html">Normalizing
Flows in 100 Lines of JAX by Eric Jang</a> - Concise implementation of
<a href="https://arxiv.org/abs/1605.08803">RealNVP</a>.</li>
<li><a href="https://blog.evjang.com/2019/11/jaxpt.html">Differentiable
Path Tracing on the GPU/TPU by Eric Jang</a> - Tutorial on implementing
path tracing.</li>
<li><a href="http://matpalm.com/blog/ensemble_nets">Ensemble networks by
Mat Kelcey</a> - Ensemble nets are a method of representing an ensemble
of models as a single logical model.</li>
<li><a href="http://matpalm.com/blog/ood_using_focal_loss">Out of
distribution (OOD) detection by Mat Kelcey</a> - Implements different
methods for OOD detection.</li>
<li><a href="https://www.radx.in/jax.html">Understanding Autodiff with
JAX by Srihari Radhakrishna</a> - Understand how autodiff works using
JAX.</li>
<li><a href="https://sjmielke.com/jax-purify.htm">From PyTorch to JAX:
towards neural net frameworks that purify stateful code by Sabrina J.
Mielke</a> - Showcases how to go from a PyTorch-like style of coding to
a more functional style of coding.</li>
<li><a href="https://github.com/dfm/extending-jax">Extending JAX with
custom C++ and CUDA code by Dan Foreman-Mackey</a> - Tutorial
demonstrating the infrastructure required to provide custom ops in
JAX.</li>
<li><a
href="https://roberttlange.github.io/posts/2021/02/cma-es-jax/">Evolving
Neural Networks in JAX by Robert Tjarko Lange</a> - Explores how JAX can
power the next generation of scalable neuroevolution algorithms.</li>
<li><a
href="http://lukemetz.com/exploring-hyperparameter-meta-loss-landscapes-with-jax/">Exploring
hyperparameter meta-loss landscapes with JAX by Luke Metz</a> -
Demonstrates how to use JAX to perform inner-loss optimization with SGD
and momentum, outer-loss optimization with gradients, and outer-loss
optimization using evolutionary strategies.</li>
<li><a
href="https://martiningram.github.io/deterministic-advi/">Deterministic
ADVI in JAX by Martin Ingram</a> - Walkthrough of implementing
automatic differentiation variational inference (ADVI) easily and
cleanly with JAX.</li>
<li><a href="http://matpalm.com/blog/evolved_channel_selection/">Evolved
channel selection by Mat Kelcey</a> - Trains a classification model
robust to different combinations of input channels at different
resolutions, then uses a genetic algorithm to decide the best
combination for a particular loss.</li>
<li><a
href="https://colab.research.google.com/github/probml/probml-notebooks/blob/main/notebooks/jax_intro.ipynb">Introduction
to JAX by Kevin Murphy</a> - Colab that introduces various aspects of
the language and applies them to simple ML problems.</li>
<li><a
href="https://www.jeremiecoullon.com/2020/11/10/mcmcjax3ways/">Writing
an MCMC sampler in JAX by Jeremie Coullon</a> - Tutorial on the
different ways to write an MCMC sampler in JAX along with speed
benchmarks.</li>
<li><a
href="https://www.jeremiecoullon.com/2021/01/29/jax_progress_bar/">How
to add a progress bar to JAX scans and loops by Jeremie Coullon</a> -
Tutorial on how to add a progress bar to compiled loops in JAX using the
<code>host_callback</code> module.</li>
<li><a href="https://github.com/gordicaleksa/get-started-with-JAX">Get
started with JAX by Aleksa Gordić</a> - A series of notebooks and videos
going from zero JAX knowledge to building neural networks in Haiku.</li>
<li><a
href="https://wandb.ai/jax-series/simple-training-loop/reports/Writing-a-Training-Loop-in-JAX-FLAX--VmlldzoyMzA4ODEy">Writing
a Training Loop in JAX + FLAX by Saurav Maheshkar and Soumik Rakshit</a>
- A tutorial on writing a simple end-to-end training and evaluation
pipeline in JAX, Flax and Optax.</li>
<li><a
href="https://wandb.ai/wandb/nerf-jax/reports/Implementing-NeRF-in-JAX--VmlldzoxODA2NDk2?galleryTag=jax">Implementing
NeRF in JAX by Soumik Rakshit and Saurav Maheshkar</a> - A tutorial on
3D volumetric rendering of scenes represented by Neural Radiance Fields
in JAX.</li>
<li><a
href="https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial2/Introduction_to_JAX.html">Deep
Learning tutorials with JAX+Flax by Phillip Lippe</a> - A series of
notebooks explaining various deep learning concepts, from basics
(e.g., intro to JAX/Flax, activation functions) to recent advances
(e.g., Vision Transformers, SimCLR), with translations to PyTorch.</li>
<li><a href="https://chrislu.page/blog/meta-disco/">Achieving 4000x
Speedups with PureJaxRL</a> - A blog post on how JAX can massively
speed up RL training through vectorisation.</li>
<li><a
href="https://levelup.gitconnected.com/create-your-own-automatically-differentiable-simulation-with-python-jax-46951e120fbb?sk=e8b9213dd2c6a5895926b2695d28e4aa">Simple
PDE solver + Constrained Optimization with JAX by Philip Mocz</a> - A
simple example of solving the advection-diffusion equations with JAX and
using it in a constrained optimization problem to find initial
conditions that yield a desired result.</li>
</ul>
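<p>One common way to express the "ensemble as a single logical model" idea from the ensemble networks post in JAX, not necessarily the approach that post takes, is to stack every member's parameters along a leading axis and <code>jax.vmap</code> a single-member forward pass over that axis. The two-layer MLP, <code>NUM_MEMBERS</code>, and layer sizes below are hypothetical choices for illustration.</p>

```python
import jax
import jax.numpy as jnp

NUM_MEMBERS = 4        # hypothetical ensemble size
IN_DIM, HIDDEN = 3, 8  # hypothetical layer sizes


def init_member(key):
    # Parameters for one small two-layer MLP.
    k1, k2 = jax.random.split(key)
    return {
        "w1": 0.1 * jax.random.normal(k1, (IN_DIM, HIDDEN)),
        "w2": 0.1 * jax.random.normal(k2, (HIDDEN, 1)),
    }


def predict_member(params, x):
    # Forward pass of a single member: [batch, IN_DIM] -> [batch, 1].
    return jnp.tanh(x @ params["w1"]) @ params["w2"]


# vmap over the keys gives every parameter a leading [NUM_MEMBERS] axis.
keys = jax.random.split(jax.random.PRNGKey(0), NUM_MEMBERS)
ensemble_params = jax.vmap(init_member)(keys)

# Map over the member axis of the parameters, sharing the same inputs.
predict_ensemble = jax.jit(jax.vmap(predict_member, in_axes=(0, None)))

x = jnp.ones((5, IN_DIM))
preds = predict_ensemble(ensemble_params, x)  # [NUM_MEMBERS, 5, 1]
mean_pred = preds.mean(axis=0)                # ensemble average, [5, 1]
```

<p>Because all members live in one pytree, a training step can be mapped over the same member axis, so the whole ensemble updates in a single compiled call.</p>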
<p><a name="books" /></p>
<h2 id="books">Books</h2>
<ul>
<li><a href="https://www.manning.com/books/jax-in-action">JAX in
Action</a> - A hands-on guide to using JAX for deep learning and other
mathematically intensive applications.</li>
</ul>
<p><a name="community" /></p>
<h2 id="community">Community</h2>
<ul>
<li><a
href="https://discord.com/channels/1107832795377713302/1107832795688083561">JaxLLM
(Unofficial) Discord</a></li>
<li><a href="https://github.com/google/jax/discussions">JAX GitHub
Discussions</a></li>
<li><a href="https://www.reddit.com/r/JAX/">Reddit</a></li>
</ul>
<h2 id="contributing">Contributing</h2>
<p>Contributions welcome! Read the <a
href="contributing.md">contribution guidelines</a> first.</p>
<p><a href="https://github.com/n2cholas/awesome-jax">jax.md
Github</a></p>