update
This commit is contained in:
908
html/jax.md2.html
Normal file
908
html/jax.md2.html
Normal file
@@ -0,0 +1,908 @@
|
||||
<!--lint ignore double-link-->
|
||||
<h1 id="awesome-jax-awesome">Awesome JAX <a
|
||||
href="https://awesome.re"><img src="https://awesome.re/badge.svg"
|
||||
alt="Awesome" /></a><a
|
||||
href="https://github.com/google/jax"><img src="https://raw.githubusercontent.com/google/jax/master/images/jax_logo_250px.png" alt="JAX Logo" align="right" height="100"></a></h1>
|
||||
<!--lint ignore double-link-->
|
||||
<p><a href="https://github.com/google/jax">JAX</a> brings automatic
|
||||
differentiation and the <a href="https://www.tensorflow.org/xla">XLA
|
||||
compiler</a> together through a <a
|
||||
href="https://numpy.org/">NumPy</a>-like API for high performance
|
||||
machine learning research on accelerators like GPUs and TPUs.
|
||||
<!--lint enable double-link--></p>
|
||||
<p>This is a curated list of awesome JAX libraries, projects, and other
|
||||
resources. Contributions are welcome!</p>
|
||||
<h2 id="contents">Contents</h2>
|
||||
<ul>
|
||||
<li><a href="#libraries">Libraries</a></li>
|
||||
<li><a href="#models-and-projects">Models and Projects</a></li>
|
||||
<li><a href="#videos">Videos</a></li>
|
||||
<li><a href="#papers">Papers</a></li>
|
||||
<li><a href="#tutorials-and-blog-posts">Tutorials and Blog
|
||||
Posts</a></li>
|
||||
<li><a href="#books">Books</a></li>
|
||||
<li><a href="#community">Community</a></li>
|
||||
</ul>
|
||||
<p><a name="libraries" /></p>
|
||||
<h2 id="libraries">Libraries</h2>
|
||||
<ul>
|
||||
<li>Neural Network Libraries
|
||||
<ul>
|
||||
<li><a href="https://github.com/google/flax">Flax</a> - Centered on
|
||||
flexibility and clarity.
|
||||
<img src="https://img.shields.io/github/stars/google/flax?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/google/flax/tree/main/flax/nnx">Flax
|
||||
NNX</a> - An evolution on Flax by the same team
|
||||
<img src="https://img.shields.io/github/stars/google/flax?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/deepmind/dm-haiku">Haiku</a> - Focused
|
||||
on simplicity, created by the authors of Sonnet at DeepMind.
|
||||
<img src="https://img.shields.io/github/stars/deepmind/dm-haiku?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/google/objax">Objax</a> - Has an object
|
||||
oriented design similar to PyTorch.
|
||||
<img src="https://img.shields.io/github/stars/google/objax?style=social" align="center"></li>
|
||||
<li><a href="https://poets-ai.github.io/elegy/">Elegy</a> - A High Level
|
||||
API for Deep Learning in JAX. Supports Flax, Haiku, and Optax.
|
||||
<img src="https://img.shields.io/github/stars/poets-ai/elegy?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/google/trax">Trax</a> - “Batteries
|
||||
included” deep learning library focused on providing solutions for
|
||||
common workloads.
|
||||
<img src="https://img.shields.io/github/stars/google/trax?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/deepmind/jraph">Jraph</a> - Lightweight
|
||||
graph neural network library.
|
||||
<img src="https://img.shields.io/github/stars/deepmind/jraph?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/google/neural-tangents">Neural
|
||||
Tangents</a> - High-level API for specifying neural networks of both
|
||||
finite and <em>infinite</em> width.
|
||||
<img src="https://img.shields.io/github/stars/google/neural-tangents?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/huggingface/transformers">HuggingFace
|
||||
Transformers</a> - Ecosystem of pretrained Transformers for a wide range
|
||||
of natural language tasks (Flax).
|
||||
<img src="https://img.shields.io/github/stars/huggingface/transformers?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/patrick-kidger/equinox">Equinox</a> -
|
||||
Callable PyTrees and filtered JIT/grad transformations => neural
|
||||
networks in JAX.
|
||||
<img src="https://img.shields.io/github/stars/patrick-kidger/equinox?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/google-research/scenic">Scenic</a> - A
|
||||
Jax Library for Computer Vision Research and Beyond.
|
||||
<img src="https://img.shields.io/github/stars/google-research/scenic?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/google-deepmind/penzai">Penzai</a> -
|
||||
Prioritizes legibility, visualization, and easy editing of neural
|
||||
network models with composable tools and a simple mental model.
|
||||
<img src="https://img.shields.io/github/stars/google-deepmind/penzai?style=social" align="center"></li>
|
||||
</ul></li>
|
||||
<li><a href="https://github.com/stanford-crfm/levanter">Levanter</a> -
|
||||
Legible, Scalable, Reproducible Foundation Models with Named Tensors and
|
||||
JAX.
|
||||
<img src="https://img.shields.io/github/stars/stanford-crfm/levanter?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/young-geng/EasyLM">EasyLM</a> - LLMs
|
||||
made easy: Pre-training, finetuning, evaluating and serving LLMs in
|
||||
JAX/Flax.
|
||||
<img src="https://img.shields.io/github/stars/young-geng/EasyLM?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/pyro-ppl/numpyro">NumPyro</a> -
|
||||
Probabilistic programming based on the Pyro library.
|
||||
<img src="https://img.shields.io/github/stars/pyro-ppl/numpyro?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/deepmind/chex">Chex</a> - Utilities to
|
||||
write and test reliable JAX code.
|
||||
<img src="https://img.shields.io/github/stars/deepmind/chex?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/deepmind/optax">Optax</a> - Gradient
|
||||
processing and optimization library.
|
||||
<img src="https://img.shields.io/github/stars/deepmind/optax?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/deepmind/rlax">RLax</a> - Library for
|
||||
implementing reinforcement learning agents.
|
||||
<img src="https://img.shields.io/github/stars/deepmind/rlax?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/google/jax-md">JAX, M.D.</a> -
|
||||
Accelerated, differential molecular dynamics.
|
||||
<img src="https://img.shields.io/github/stars/google/jax-md?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/coax-dev/coax">Coax</a> - Turn RL papers
|
||||
into code, the easy way.
|
||||
<img src="https://img.shields.io/github/stars/coax-dev/coax?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/deepmind/distrax">Distrax</a> -
|
||||
Reimplementation of TensorFlow Probability, containing probability
|
||||
distributions and bijectors.
|
||||
<img src="https://img.shields.io/github/stars/deepmind/distrax?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/cvxgrp/cvxpylayers">cvxpylayers</a> -
|
||||
Construct differentiable convex optimization layers.
|
||||
<img src="https://img.shields.io/github/stars/cvxgrp/cvxpylayers?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/tensorly/tensorly">TensorLy</a> - Tensor
|
||||
learning made simple.
|
||||
<img src="https://img.shields.io/github/stars/tensorly/tensorly?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/netket/netket">NetKet</a> - Machine
|
||||
Learning toolbox for Quantum Physics.
|
||||
<img src="https://img.shields.io/github/stars/netket/netket?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/awslabs/fortuna">Fortuna</a> - AWS
|
||||
library for Uncertainty Quantification in Deep Learning.
|
||||
<img src="https://img.shields.io/github/stars/awslabs/fortuna?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/blackjax-devs/blackjax">BlackJAX</a> -
|
||||
Library of samplers for JAX.
|
||||
<img src="https://img.shields.io/github/stars/blackjax-devs/blackjax?style=social" align="center"></li>
|
||||
</ul>
|
||||
<p><a name="new-libraries" /></p>
|
||||
<h3 id="new-libraries">New Libraries</h3>
|
||||
<p>This section contains libraries that are well-made and useful, but
|
||||
have not necessarily been battle-tested by a large userbase yet.</p>
|
||||
<ul>
|
||||
<li>Neural Network Libraries
|
||||
<ul>
|
||||
<li><a href="https://github.com/google/fedjax">FedJAX</a> - Federated
|
||||
learning in JAX, built on Optax and Haiku.
|
||||
<img src="https://img.shields.io/github/stars/google/fedjax?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/mfinzi/equivariant-MLP">Equivariant
|
||||
MLP</a> - Construct equivariant neural network layers.
|
||||
<img src="https://img.shields.io/github/stars/mfinzi/equivariant-MLP?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/n2cholas/jax-resnet/">jax-resnet</a> -
|
||||
Implementations and checkpoints for ResNet variants in Flax.
|
||||
<img src="https://img.shields.io/github/stars/n2cholas/jax-resnet?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/srush/parallax">Parallax</a> - Immutable
|
||||
Torch Modules for JAX.
|
||||
<img src="https://img.shields.io/github/stars/srush/parallax?style=social" align="center"></li>
|
||||
</ul></li>
|
||||
<li>Nonlinear Optimization
|
||||
<ul>
|
||||
<li><a
|
||||
href="https://github.com/patrick-kidger/optimistix">Optimistix</a> -
|
||||
Root finding, minimisation, fixed points, and least squares.
|
||||
<img src="https://img.shields.io/github/stars/deepmind/optax?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/google/jaxopt">JAXopt</a> - Hardware
|
||||
accelerated (GPU/TPU), batchable and differentiable optimizers in JAX.
|
||||
<img src="https://img.shields.io/github/stars/google/jaxopt?style=social" align="center"></li>
|
||||
</ul></li>
|
||||
<li><a href="https://github.com/ElArkk/jax-unirep">jax-unirep</a> -
|
||||
Library implementing the <a
|
||||
href="https://www.nature.com/articles/s41592-019-0598-1">UniRep
|
||||
model</a> for protein machine learning applications.
|
||||
<img src="https://img.shields.io/github/stars/ElArkk/jax-unirep?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/danielward27/flowjax">flowjax</a> -
|
||||
Distributions and normalizing flows built as equinox modules.
|
||||
<img src="https://img.shields.io/github/stars/danielward27/flowjax?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/ChrisWaites/jax-flows">jax-flows</a> -
|
||||
Normalizing flows in JAX.
|
||||
<img src="https://img.shields.io/github/stars/ChrisWaites/jax-flows?style=social" align="center"></li>
|
||||
<li><a
|
||||
href="https://github.com/ExpectationMax/sklearn-jax-kernels">sklearn-jax-kernels</a>
|
||||
- <code>scikit-learn</code> kernel matrices using JAX.
|
||||
<img src="https://img.shields.io/github/stars/ExpectationMax/sklearn-jax-kernels?style=social" align="center"></li>
|
||||
<li><a
|
||||
href="https://github.com/DifferentiableUniverseInitiative/jax_cosmo">jax-cosmo</a>
|
||||
- Differentiable cosmology library.
|
||||
<img src="https://img.shields.io/github/stars/DifferentiableUniverseInitiative/jax_cosmo?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/NeilGirdhar/efax">efax</a> - Exponential
|
||||
Families in JAX.
|
||||
<img src="https://img.shields.io/github/stars/NeilGirdhar/efax?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/PhilipVinc/mpi4jax">mpi4jax</a> -
|
||||
Combine MPI operations with your Jax code on CPUs and GPUs.
|
||||
<img src="https://img.shields.io/github/stars/PhilipVinc/mpi4jax?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/4rtemi5/imax">imax</a> - Image
|
||||
augmentations and transformations.
|
||||
<img src="https://img.shields.io/github/stars/4rtemi5/imax?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/rolandgvc/flaxvision">FlaxVision</a> -
|
||||
Flax version of TorchVision.
|
||||
<img src="https://img.shields.io/github/stars/rolandgvc/flaxvision?style=social" align="center"></li>
|
||||
<li><a
|
||||
href="https://github.com/tensorflow/probability/tree/master/spinoffs/oryx">Oryx</a>
|
||||
- Probabilistic programming language based on program
|
||||
transformations.</li>
|
||||
<li><a href="https://github.com/google-research/ott">Optimal Transport
|
||||
Tools</a> - Toolbox that bundles utilities to solve optimal transport
|
||||
problems.</li>
|
||||
<li><a href="https://github.com/romanodev/deltapv">delta PV</a> - A
|
||||
photovoltaic simulator with automatic differentation.
|
||||
<img src="https://img.shields.io/github/stars/romanodev/deltapv?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/brentyi/jaxlie">jaxlie</a> - Lie theory
|
||||
library for rigid body transformations and optimization.
|
||||
<img src="https://img.shields.io/github/stars/brentyi/jaxlie?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/google/brax">BRAX</a> - Differentiable
|
||||
physics engine to simulate environments along with learning algorithms
|
||||
to train agents for these environments.
|
||||
<img src="https://img.shields.io/github/stars/google/brax?style=social" align="center"></li>
|
||||
<li><a
|
||||
href="https://github.com/matthias-wright/flaxmodels">flaxmodels</a> -
|
||||
Pretrained models for Jax/Flax.
|
||||
<img src="https://img.shields.io/github/stars/matthias-wright/flaxmodels?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/carnotresearch/cr-sparse">CR.Sparse</a>
|
||||
- XLA accelerated algorithms for sparse representations and compressive
|
||||
sensing.
|
||||
<img src="https://img.shields.io/github/stars/carnotresearch/cr-sparse?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/HajimeKawahara/exojax">exojax</a> -
|
||||
Automatic differentiable spectrum modeling of exoplanets/brown dwarfs
|
||||
compatible to JAX.
|
||||
<img src="https://img.shields.io/github/stars/HajimeKawahara/exojax?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/deepmind/dm_pix">PIX</a> - PIX is an
|
||||
image processing library in JAX, for JAX.
|
||||
<img src="https://img.shields.io/github/stars/deepmind/dm_pix?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/alonfnt/bayex">bayex</a> - Bayesian
|
||||
Optimization powered by JAX.
|
||||
<img src="https://img.shields.io/github/stars/alonfnt/bayex?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/ucl-bug/jaxdf">JaxDF</a> - Framework for
|
||||
differentiable simulators with arbitrary discretizations.
|
||||
<img src="https://img.shields.io/github/stars/ucl-bug/jaxdf?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/google/tree-math">tree-math</a> -
|
||||
Convert functions that operate on arrays into functions that operate on
|
||||
PyTrees.
|
||||
<img src="https://img.shields.io/github/stars/google/tree-math?style=social" align="center"></li>
|
||||
<li><a
|
||||
href="https://github.com/DarshanDeshpande/jax-models">jax-models</a> -
|
||||
Implementations of research papers originally without code or code
|
||||
written with frameworks other than JAX.
|
||||
<img src="https://img.shields.io/github/stars/DarshanDeshpande/jax-modelsa?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/vicariousinc/PGMax">PGMax</a> - A
|
||||
framework for building discrete Probabilistic Graphical Models (PGM’s)
|
||||
and running inference inference on them via JAX.
|
||||
<img src="https://img.shields.io/github/stars/vicariousinc/pgmax?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/google/evojax">EvoJAX</a> -
|
||||
Hardware-Accelerated Neuroevolution
|
||||
<img src="https://img.shields.io/github/stars/google/evojax?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/RobertTLange/evosax">evosax</a> -
|
||||
JAX-Based Evolution Strategies
|
||||
<img src="https://img.shields.io/github/stars/RobertTLange/evosax?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/SymJAX/SymJAX">SymJAX</a> - Symbolic
|
||||
CPU/GPU/TPU programming.
|
||||
<img src="https://img.shields.io/github/stars/SymJAX/SymJAX?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/rlouf/mcx">mcx</a> - Express &
|
||||
compile probabilistic programs for performant inference.
|
||||
<img src="https://img.shields.io/github/stars/rlouf/mcx?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/deepmind/einshape">Einshape</a> -
|
||||
DSL-based reshaping library for JAX and other frameworks.
|
||||
<img src="https://img.shields.io/github/stars/deepmind/einshape?style=social" align="center"></li>
|
||||
<li><a
|
||||
href="https://github.com/google-research/google-research/tree/master/alx">ALX</a>
|
||||
- Open-source library for distributed matrix factorization using
|
||||
Alternating Least Squares, more info in <a
|
||||
href="https://arxiv.org/abs/2112.02194"><em>ALX: Large Scale Matrix
|
||||
Factorization on TPUs</em></a>.</li>
|
||||
<li><a href="https://github.com/patrick-kidger/diffrax">Diffrax</a> -
|
||||
Numerical differential equation solvers in JAX.
|
||||
<img src="https://img.shields.io/github/stars/patrick-kidger/diffrax?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/dfm/tinygp">tinygp</a> - The
|
||||
<em>tiniest</em> of Gaussian process libraries in JAX.
|
||||
<img src="https://img.shields.io/github/stars/dfm/tinygp?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/RobertTLange/gymnax">gymnax</a> -
|
||||
Reinforcement Learning Environments with the well-known gym API.
|
||||
<img src="https://img.shields.io/github/stars/RobertTLange/gymnax?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/deepmind/mctx">Mctx</a> - Monte Carlo
|
||||
tree search algorithms in native JAX.
|
||||
<img src="https://img.shields.io/github/stars/deepmind/mctx?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/deepmind/kfac-jax">KFAC-JAX</a> - Second
|
||||
Order Optimization with Approximate Curvature for NNs.
|
||||
<img src="https://img.shields.io/github/stars/deepmind/kfac-jax?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/deepmind/tf2jax">TF2JAX</a> - Convert
|
||||
functions/graphs to JAX functions.
|
||||
<img src="https://img.shields.io/github/stars/deepmind/tf2jax?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/ucl-bug/jwave">jwave</a> - A library for
|
||||
differentiable acoustic simulations
|
||||
<img src="https://img.shields.io/github/stars/ucl-bug/jwave?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/thomaspinder/GPJax">GPJax</a> - Gaussian
|
||||
processes in JAX.</li>
|
||||
<li><a href="https://github.com/instadeepai/jumanji">Jumanji</a> - A
|
||||
Suite of Industry-Driven Hardware-Accelerated RL Environments written in
|
||||
JAX.
|
||||
<img src="https://img.shields.io/github/stars/instadeepai/jumanji?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/paganpasta/eqxvision">Eqxvision</a> -
|
||||
Equinox version of Torchvision.
|
||||
<img src="https://img.shields.io/github/stars/paganpasta/eqxvision?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/dipolar-quantum-gases/jaxfit">JAXFit</a>
|
||||
- Accelerated curve fitting library for nonlinear least-squares problems
|
||||
(see <a href="https://arxiv.org/abs/2208.12187">arXiv paper</a>).
|
||||
<img src="https://img.shields.io/github/stars/dipolar-quantum-gases/jaxfit?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/gboehl/econpizza">econpizza</a> - Solve
|
||||
macroeconomic models with hetereogeneous agents using JAX.
|
||||
<img src="https://img.shields.io/github/stars/gboehl/econpizza?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/secretflow/spu">SPU</a> - A
|
||||
domain-specific compiler and runtime suite to run JAX code with
|
||||
MPC(Secure Multi-Party Computation).
|
||||
<img src="https://img.shields.io/github/stars/secretflow/spu?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/jeremiecoullon/jax-tqdm">jax-tqdm</a> -
|
||||
Add a tqdm progress bar to JAX scans and loops.
|
||||
<img src="https://img.shields.io/github/stars/jeremiecoullon/jax-tqdm?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/alvarobartt/safejax">safejax</a> -
|
||||
Serialize JAX, Flax, Haiku, or Objax model params with
|
||||
🤗<code>safetensors</code>.
|
||||
<img src="https://img.shields.io/github/stars/alvarobartt/safejax?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/ASEM000/kernex">Kernex</a> -
|
||||
Differentiable stencil decorators in JAX.
|
||||
<img src="https://img.shields.io/github/stars/ASEM000/kernex?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/google/maxtext">MaxText</a> - A simple,
|
||||
performant and scalable Jax LLM written in pure Python/Jax and targeting
|
||||
Google Cloud TPUs.
|
||||
<img src="https://img.shields.io/github/stars/google/maxtext?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/google/paxml">Pax</a> - A Jax-based
|
||||
machine learning framework for training large scale models.
|
||||
<img src="https://img.shields.io/github/stars/google/paxml?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/google/praxis">Praxis</a> - The layer
|
||||
library for Pax with a goal to be usable by other JAX-based ML projects.
|
||||
<img src="https://img.shields.io/github/stars/google/praxis?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/luchris429/purejaxrl">purejaxrl</a> -
|
||||
Vectorisable, end-to-end RL algorithms in JAX.
|
||||
<img src="https://img.shields.io/github/stars/luchris429/purejaxrl?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/davisyoshida/lorax">Lorax</a> -
|
||||
Automatically apply LoRA to JAX models (Flax, Haiku, etc.)</li>
|
||||
<li><a href="https://github.com/lanl/scico">SCICO</a> - Scientific
|
||||
computational imaging in JAX.
|
||||
<img src="https://img.shields.io/github/stars/lanl/scico?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/kmheckel/spyx">Spyx</a> - Spiking Neural
|
||||
Networks in JAX for machine learning on neuromorphic hardware.
|
||||
<img src="https://img.shields.io/github/stars/kmheckel/spyx?style=social" align="center"></li>
|
||||
<li>Brain Dynamics Programming Ecosystem
|
||||
<ul>
|
||||
<li><a href="https://github.com/brainpy/BrainPy">BrainPy</a> - Brain
|
||||
Dynamics Programming in Python.
|
||||
<img src="https://img.shields.io/github/stars/brainpy/BrainPy?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/chaobrain/brainunit">brainunit</a> -
|
||||
Physical units and unit-aware mathematical system in JAX.
|
||||
<img src="https://img.shields.io/github/stars/chaobrain/brainunit?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/chaobrain/dendritex">dendritex</a> -
|
||||
Dendritic Modeling in JAX.
|
||||
<img src="https://img.shields.io/github/stars/chaobrain/dendritex?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/chaobrain/brainstate">brainstate</a> -
|
||||
State-based Transformation System for Program Compilation and
|
||||
Augmentation.
|
||||
<img src="https://img.shields.io/github/stars/chaobrain/brainstate?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/chaobrain/braintaichi">braintaichi</a> -
|
||||
Leveraging Taichi Lang to customize brain dynamics operators.
|
||||
<img src="https://img.shields.io/github/stars/chaobrain/braintaichi?style=social" align="center"></li>
|
||||
</ul></li>
|
||||
<li><a href="https://github.com/ott-jax/ott">OTT-JAX</a> - Optimal
|
||||
transport tools in JAX.
|
||||
<img src="https://img.shields.io/github/stars/ott-jax/ott?style=social" align="center"></li>
|
||||
<li><a
|
||||
href="https://github.com/adaptive-intelligent-robotics/QDax">QDax</a> -
|
||||
Quality Diversity optimization in Jax.
|
||||
<img src="https://img.shields.io/github/stars/adaptive-intelligent-robotics/QDax?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/NVIDIA/JAX-Toolbox">JAX Toolbox</a> -
|
||||
Nightly CI and optimized examples for JAX on NVIDIA GPUs using libraries
|
||||
such as T5x, Paxml, and Transformer Engine.
|
||||
<img src="https://img.shields.io/github/stars/NVIDIA/JAX-Toolbox?style=social" align="center"></li>
|
||||
<li><a href="http://github.com/sotetsuk/pgx">Pgx</a> - Vectorized board
|
||||
game environments for RL with an AlphaZero example.
|
||||
<img src="https://img.shields.io/github/stars/sotetsuk/pgx?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/erfanzar/EasyDeL">EasyDeL</a> - EasyDeL
|
||||
🔮 is an OpenSource Library to make your training faster and more
|
||||
Optimized With cool Options for training and serving (Llama, MPT,
|
||||
Mixtral, Falcon, etc) in JAX
|
||||
<img src="https://img.shields.io/github/stars/erfanzar/EasyDeL?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/Autodesk/XLB">XLB</a> - A Differentiable
|
||||
Massively Parallel Lattice Boltzmann Library in Python for Physics-Based
|
||||
Machine Learning.
|
||||
<img src="https://img.shields.io/github/stars/Autodesk/XLB?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/dynamiqs/dynamiqs">dynamiqs</a> -
|
||||
High-performance and differentiable simulations of quantum systems with
|
||||
JAX.
|
||||
<img src="https://img.shields.io/github/stars/dynamiqs/dynamiqs?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/i-m-iron-man/Foragax">foragax</a> -
|
||||
Agent-Based modelling framework in JAX.
|
||||
<img src="https://img.shields.io/github/stars/i-m-iron-man/Foragax?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/bahremsd/tmmax">tmmax</a> - Vectorized
|
||||
calculation of optical properties in thin-film structures using JAX.
|
||||
Swiss Army knife tool for thin-film optics research
|
||||
<img src="https://img.shields.io/github/stars/bahremsd/tmmax" align="center"></li>
|
||||
<li><a href="https://github.com/gchq/coreax">Coreax</a> - Algorithms for
|
||||
finding coresets to compress large datasets while retaining their
|
||||
statistical properties.
|
||||
<img src="https://img.shields.io/github/stars/gchq/coreax?style=social" align="center"></li>
|
||||
<li><a href="https://github.com/epignatelli/navix">NAVIX</a> - A
|
||||
reimplementation of MiniGrid, a Reinforcement Learning environment, in
|
||||
JAX
|
||||
<img src="https://img.shields.io/github/stars/epignatelli/navix?style=social" align="center"></li>
|
||||
</ul>
|
||||
<p><a name="models-and-projects" /></p>
|
||||
<h2 id="models-and-projects">Models and Projects</h2>
|
||||
<h3 id="jax">JAX</h3>
|
||||
<ul>
|
||||
<li><a href="https://github.com/tancik/fourier-feature-networks">Fourier
|
||||
Feature Networks</a> - Official implementation of <a
|
||||
href="https://people.eecs.berkeley.edu/~bmild/fourfeat"><em>Fourier
|
||||
Features Let Networks Learn High Frequency Functions in Low Dimensional
|
||||
Domains</em></a>.</li>
|
||||
<li><a href="https://github.com/AaltoML/kalman-jax">kalman-jax</a> -
|
||||
Approximate inference for Markov (i.e., temporal) Gaussian processes
|
||||
using iterated Kalman filtering and smoothing.</li>
|
||||
<li><a href="https://github.com/Joshuaalbert/jaxns">jaxns</a> - Nested
|
||||
sampling in JAX.</li>
|
||||
<li><a
|
||||
href="https://github.com/google-research/google-research/tree/master/amortized_bo">Amortized
|
||||
Bayesian Optimization</a> - Code related to <a
|
||||
href="http://www.auai.org/uai2020/proceedings/329_main_paper.pdf"><em>Amortized
|
||||
Bayesian Optimization over Discrete Spaces</em></a>.</li>
|
||||
<li><a
|
||||
href="https://github.com/google-research/google-research/tree/master/aqt">Accurate
|
||||
Quantized Training</a> - Tools and libraries for running and analyzing
|
||||
neural network quantization experiments in JAX and Flax.</li>
|
||||
<li><a
|
||||
href="https://github.com/google-research/google-research/tree/master/bnn_hmc">BNN-HMC</a>
|
||||
- Implementation for the paper <a
|
||||
href="https://arxiv.org/abs/2104.14421"><em>What Are Bayesian Neural
|
||||
Network Posteriors Really Like?</em></a>.</li>
|
||||
<li><a
|
||||
href="https://github.com/google-research/google-research/tree/master/jax_dft">JAX-DFT</a>
|
||||
- One-dimensional density functional theory (DFT) in JAX, with
|
||||
implementation of <a
|
||||
href="https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.126.036401"><em>Kohn-Sham
|
||||
equations as regularizer: building prior knowledge into machine-learned
|
||||
physics</em></a>.</li>
|
||||
<li><a
|
||||
href="https://github.com/google-research/google-research/tree/master/robust_loss_jax">Robust
|
||||
Loss</a> - Reference code for the paper <a
|
||||
href="https://arxiv.org/abs/1701.03077"><em>A General and Adaptive
|
||||
Robust Loss Function</em></a>.</li>
|
||||
<li><a
|
||||
href="https://github.com/google-research/google-research/tree/master/symbolic_functionals">Symbolic
|
||||
Functionals</a> - Demonstration from <a
|
||||
href="https://arxiv.org/abs/2203.02540"><em>Evolving symbolic density
|
||||
functionals</em></a>.</li>
|
||||
<li><a
|
||||
href="https://github.com/google-research/google-research/tree/master/trimap">TriMap</a>
|
||||
- Official JAX implementation of <a
|
||||
href="https://arxiv.org/abs/1910.00204"><em>TriMap: Large-scale
|
||||
Dimensionality Reduction Using Triplets</em></a>.</li>
|
||||
</ul>
|
||||
<h3 id="flax">Flax</h3>
|
||||
<ul>
|
||||
<li><a
|
||||
href="https://github.com/J-Rosser-UK/Torch2Jax-DeepSeek-R1-Distill-Qwen-1.5B">DeepSeek-R1-Flax-1.5B-Distill</a>
|
||||
- Flax implementation of DeepSeek-R1 1.5B distilled reasoning LLM.</li>
|
||||
<li><a
|
||||
href="https://github.com/google-research/google-research/tree/master/performer/fast_attention/jax">Performer</a>
|
||||
- Flax implementation of the Performer (linear transformer via FAVOR+)
|
||||
architecture.</li>
|
||||
<li><a
|
||||
href="https://github.com/google-research/google-research/tree/master/jaxnerf">JaxNeRF</a>
|
||||
- Implementation of <a
|
||||
href="http://www.matthewtancik.com/nerf"><em>NeRF: Representing Scenes
|
||||
as Neural Radiance Fields for View Synthesis</em></a> with multi-device
|
||||
GPU/TPU support.</li>
|
||||
<li><a href="https://github.com/google/mipnerf">mip-NeRF</a> - Official
|
||||
implementation of <a href="https://jonbarron.info/mipnerf"><em>Mip-NeRF:
|
||||
A Multiscale Representation for Anti-Aliasing Neural Radiance
|
||||
Fields</em></a>.</li>
|
||||
<li><a
|
||||
href="https://github.com/google-research/google-research/tree/master/regnerf">RegNeRF</a>
|
||||
- Official implementation of <a
|
||||
href="https://m-niemeyer.github.io/regnerf/"><em>RegNeRF: Regularizing
|
||||
Neural Radiance Fields for View Synthesis from Sparse
|
||||
Inputs</em></a>.</li>
|
||||
<li><a href="https://github.com/huangjuite/jaxneus">JaxNeuS</a> -
|
||||
Implementation of <a
|
||||
href="https://lingjie0206.github.io/papers/NeuS/"><em>NeuS: Learning
|
||||
Neural Implicit Surfaces by Volume Rendering for Multi-view
|
||||
Reconstruction</em></a></li>
|
||||
<li><a href="https://github.com/google-research/big_transfer">Big
|
||||
Transfer (BiT)</a> - Implementation of <a
|
||||
href="https://arxiv.org/abs/1912.11370"><em>Big Transfer (BiT): General
|
||||
Visual Representation Learning</em></a>.</li>
|
||||
<li><a href="https://github.com/ikostrikov/jax-rl">JAX RL</a> -
|
||||
Implementations of reinforcement learning algorithms.</li>
|
||||
<li><a href="https://github.com/SauravMaheshkar/gMLP">gMLP</a> -
|
||||
Implementation of <a href="https://arxiv.org/abs/2105.08050"><em>Pay
|
||||
Attention to MLPs</em></a>.</li>
|
||||
<li><a href="https://github.com/SauravMaheshkar/MLP-Mixer">MLP Mixer</a>
|
||||
- Minimal implementation of <a
|
||||
href="https://arxiv.org/abs/2105.01601"><em>MLP-Mixer: An all-MLP
|
||||
Architecture for Vision</em></a>.</li>
|
||||
<li><a
|
||||
href="https://github.com/google-research/google-research/tree/master/scalable_shampoo">Distributed
|
||||
Shampoo</a> - Implementation of <a
|
||||
href="https://arxiv.org/abs/2002.09018"><em>Second Order Optimization
|
||||
Made Practical</em></a>.</li>
|
||||
<li><a
|
||||
href="https://github.com/google-research/nested-transformer">NesT</a> -
|
||||
Official implementation of <a
|
||||
href="https://arxiv.org/abs/2105.12723"><em>Aggregating Nested
|
||||
Transformers</em></a>.</li>
|
||||
<li><a
|
||||
href="https://github.com/google-research/xmcgan_image_generation">XMC-GAN</a>
|
||||
- Official implementation of <a
|
||||
href="https://arxiv.org/abs/2101.04702"><em>Cross-Modal Contrastive
|
||||
Learning for Text-to-Image Generation</em></a>.</li>
|
||||
<li><a
|
||||
href="https://github.com/google-research/google-research/tree/master/f_net">FNet</a>
|
||||
- Official implementation of <a
|
||||
href="https://arxiv.org/abs/2105.03824"><em>FNet: Mixing Tokens with
|
||||
Fourier Transforms</em></a>.</li>
|
||||
<li><a
|
||||
href="https://github.com/google-research/google-research/tree/master/gfsa">GFSA</a>
|
||||
- Official implementation of <a
|
||||
href="https://arxiv.org/abs/2007.04929"><em>Learning Graph Structure
|
||||
With A Finite-State Automaton Layer</em></a>.</li>
|
||||
<li><a
|
||||
href="https://github.com/google-research/google-research/tree/master/ipagnn">IPA-GNN</a>
|
||||
- Official implementation of <a
|
||||
href="https://arxiv.org/abs/2010.12621"><em>Learning to Execute Programs
|
||||
with Instruction Pointer Attention Graph Neural Networks</em></a>.</li>
|
||||
<li><a
|
||||
href="https://github.com/google-research/google-research/tree/master/flax_models">Flax
|
||||
Models</a> - Collection of models and methods implemented in Flax.</li>
|
||||
<li><a
|
||||
href="https://github.com/google-research/google-research/tree/master/protein_lm">Protein
|
||||
LM</a> - Implements BERT and autoregressive models for proteins, as
|
||||
described in <a
|
||||
href="https://www.biorxiv.org/content/10.1101/622803v1.full"><em>Biological
|
||||
Structure and Function Emerge from Scaling Unsupervised Learning to 250
|
||||
Million Protein Sequences</em></a> and <a
|
||||
href="https://www.biorxiv.org/content/10.1101/2020.03.07.982272v2"><em>ProGen:
|
||||
Language Modeling for Protein Generation</em></a>.</li>
|
||||
<li><a
|
||||
href="https://github.com/google-research/google-research/tree/master/ptopk_patch_selection">Slot
|
||||
Attention</a> - Reference implementation for <a
|
||||
href="https://arxiv.org/abs/2104.03059"><em>Differentiable Patch
|
||||
Selection for Image Recognition</em></a>.</li>
|
||||
<li><a
|
||||
href="https://github.com/google-research/vision_transformer">Vision
|
||||
Transformer</a> - Official implementation of <a
|
||||
href="https://arxiv.org/abs/2010.11929"><em>An Image is Worth 16x16
|
||||
Words: Transformers for Image Recognition at Scale</em></a>.</li>
|
||||
<li><a href="https://github.com/matthias-wright/jax-fid">FID
|
||||
computation</a> - Port of <a
|
||||
href="https://github.com/mseitzer/pytorch-fid">mseitzer/pytorch-fid</a>
|
||||
to Flax.</li>
|
||||
<li><a
|
||||
href="https://github.com/google-research/google-research/tree/master/autoregressive_diffusion">ARDM</a>
|
||||
- Official implementation of <a
|
||||
href="https://arxiv.org/abs/2110.02037"><em>Autoregressive Diffusion
|
||||
Models</em></a>.</li>
|
||||
<li><a
|
||||
href="https://github.com/google-research/google-research/tree/master/d3pm">D3PM</a>
|
||||
- Official implementation of <a
|
||||
href="https://arxiv.org/abs/2107.03006"><em>Structured Denoising
|
||||
Diffusion Models in Discrete State-Spaces</em></a>.</li>
|
||||
<li><a
|
||||
href="https://github.com/google-research/google-research/tree/master/gumbel_max_causal_gadgets">Gumbel-max
|
||||
Causal Mechanisms</a> - Code for <a
|
||||
href="https://arxiv.org/abs/2111.06888"><em>Learning Generalized
|
||||
Gumbel-max Causal Mechanisms</em></a>, with extra code in <a
|
||||
href="https://github.com/GuyLor/gumbel_max_causal_gadgets_part2">GuyLor/gumbel_max_causal_gadgets_part2</a>.</li>
|
||||
<li><a
|
||||
href="https://github.com/google-research/google-research/tree/master/latent_programmer">Latent
|
||||
Programmer</a> - Code for the ICML 2021 paper <a
|
||||
href="https://arxiv.org/abs/2012.00377"><em>Latent Programmer: Discrete
|
||||
Latent Codes for Program Synthesis</em></a>.</li>
|
||||
<li><a
|
||||
href="https://github.com/google-research/google-research/tree/master/snerg">SNeRG</a>
|
||||
- Official implementation of <a
|
||||
href="https://phog.github.io/snerg"><em>Baking Neural Radiance Fields
|
||||
for Real-Time View Synthesis</em></a>.</li>
|
||||
<li><a
|
||||
href="https://github.com/google-research/google-research/tree/master/spin_spherical_cnns">Spin-weighted
|
||||
Spherical CNNs</a> - Adaptation of <a
|
||||
href="https://arxiv.org/abs/2006.10731"><em>Spin-Weighted Spherical
|
||||
CNNs</em></a>.</li>
|
||||
<li><a
|
||||
href="https://github.com/google-research/google-research/tree/master/vdvae_flax">VDVAE</a>
|
||||
- Adaptation of <a href="https://arxiv.org/abs/2011.10650"><em>Very Deep
|
||||
VAEs Generalize Autoregressive Models and Can Outperform Them on
|
||||
Images</em></a>, original code at <a
|
||||
href="https://github.com/openai/vdvae">openai/vdvae</a>.</li>
|
||||
<li><a
|
||||
href="https://github.com/google-research/google-research/tree/master/musiq">MUSIQ</a>
|
||||
- Checkpoints and model inference code for the ICCV 2021 paper <a
|
||||
href="https://arxiv.org/abs/2108.05997"><em>MUSIQ: Multi-scale Image
|
||||
Quality Transformer</em></a></li>
|
||||
<li><a
|
||||
href="https://github.com/google-research/google-research/tree/master/aquadem">AQuaDem</a>
|
||||
- Official implementation of <a
|
||||
href="https://arxiv.org/abs/2110.10149"><em>Continuous Control with
|
||||
Action Quantization from Demonstrations</em></a>.</li>
|
||||
<li><a
|
||||
href="https://github.com/google-research/google-research/tree/master/combiner">Combiner</a>
|
||||
- Official implementation of <a
|
||||
href="https://arxiv.org/abs/2107.05768"><em>Combiner: Full Attention
|
||||
Transformer with Sparse Computation Cost</em></a>.</li>
|
||||
<li><a
|
||||
href="https://github.com/google-research/google-research/tree/master/dreamfields">Dreamfields</a>
|
||||
- Official implementation of the ICLR 2022 paper <a
|
||||
href="https://ajayj.com/dreamfields"><em>Progressive Distillation for
|
||||
Fast Sampling of Diffusion Models</em></a>.</li>
|
||||
<li><a
|
||||
href="https://github.com/google-research/google-research/tree/master/gift">GIFT</a>
|
||||
- Official implementation of <a
|
||||
href="https://arxiv.org/abs/2106.06080"><em>Gradual Domain Adaptation in
|
||||
the Wild:When Intermediate Distributions are Absent</em></a>.</li>
|
||||
<li><a
|
||||
href="https://github.com/google-research/google-research/tree/master/light_field_neural_rendering">Light
|
||||
Field Neural Rendering</a> - Official implementation of <a
|
||||
href="https://arxiv.org/abs/2112.09687"><em>Light Field Neural
|
||||
Rendering</em></a>.</li>
|
||||
<li><a
|
||||
href="https://colab.research.google.com/drive/1KUKFEMneQMS3OzPYnWZGkEnry3PdzCfn?usp=sharing">Sharpened
|
||||
Cosine Similarity in JAX by Raphael Pisoni</a> - A JAX/Flax
|
||||
implementation of the Sharpened Cosine Similarity layer.</li>
|
||||
<li><a
|
||||
href="https://github.com/IvanIsCoding/GNN-for-Combinatorial-Optimization">GNNs
|
||||
for Solving Combinatorial Optimization Problems</a> - A JAX + Flax
|
||||
implementation of <a
|
||||
href="https://arxiv.org/abs/2107.01188">Combinatorial Optimization with
|
||||
Physics-Inspired Graph Neural Networks</a>.</li>
|
||||
<li><a href="https://github.com/MasterSkepticista/detr">DETR</a> - Flax
|
||||
implementation of <a
|
||||
href="https://github.com/facebookresearch/detr"><em>DETR: End-to-end
|
||||
Object Detection with Transformers</em></a> using Sinkhorn solver and
|
||||
parallel bipartite matching.</li>
|
||||
</ul>
|
||||
<h3 id="haiku">Haiku</h3>
|
||||
<ul>
|
||||
<li><a href="https://github.com/deepmind/alphafold">AlphaFold</a> -
|
||||
Implementation of the inference pipeline of AlphaFold v2.0, presented in
|
||||
<a href="https://www.nature.com/articles/s41586-021-03819-2"><em>Highly
|
||||
accurate protein structure prediction with AlphaFold</em></a>.</li>
|
||||
<li><a
|
||||
href="https://github.com/deepmind/deepmind-research/tree/master/adversarial_robustness">Adversarial
|
||||
Robustness</a> - Reference code for <a
|
||||
href="https://arxiv.org/abs/2010.03593"><em>Uncovering the Limits of
|
||||
Adversarial Training against Norm-Bounded Adversarial Examples</em></a>
|
||||
and <a href="https://arxiv.org/abs/2103.01946"><em>Fixing Data
|
||||
Augmentation to Improve Adversarial Robustness</em></a>.</li>
|
||||
<li><a
|
||||
href="https://github.com/deepmind/deepmind-research/tree/master/byol">Bootstrap
|
||||
Your Own Latent</a> - Implementation for the paper <a
|
||||
href="https://arxiv.org/abs/2006.07733"><em>Bootstrap your own latent: A
|
||||
new approach to self-supervised Learning</em></a>.</li>
|
||||
<li><a
|
||||
href="https://github.com/deepmind/deepmind-research/tree/master/gated_linear_networks">Gated
|
||||
Linear Networks</a> - GLNs are a family of backpropagation-free neural
|
||||
networks.</li>
|
||||
<li><a
|
||||
href="https://github.com/deepmind/deepmind-research/tree/master/glassy_dynamics">Glassy
|
||||
Dynamics</a> - Open source implementation of the paper <a
|
||||
href="https://www.nature.com/articles/s41567-020-0842-8"><em>Unveiling
|
||||
the predictive power of static structure in glassy
|
||||
systems</em></a>.</li>
|
||||
<li><a
|
||||
href="https://github.com/deepmind/deepmind-research/tree/master/mmv">MMV</a>
|
||||
- Code for the models in <a
|
||||
href="https://arxiv.org/abs/2006.16228"><em>Self-Supervised MultiModal
|
||||
Versatile Networks</em></a>.</li>
|
||||
<li><a
|
||||
href="https://github.com/deepmind/deepmind-research/tree/master/nfnets">Normalizer-Free
|
||||
Networks</a> - Official Haiku implementation of <a
|
||||
href="https://arxiv.org/abs/2102.06171"><em>NFNets</em></a>.</li>
|
||||
<li><a
|
||||
href="https://github.com/Information-Fusion-Lab-Umass/NuX">NuX</a> -
|
||||
Normalizing flows with JAX.</li>
|
||||
<li><a
|
||||
href="https://github.com/deepmind/deepmind-research/tree/master/ogb_lsc">OGB-LSC</a>
|
||||
- This repository contains DeepMind’s entry to the <a
|
||||
href="https://ogb.stanford.edu/kddcup2021/pcqm4m/">PCQM4M-LSC</a>
|
||||
(quantum chemistry) and <a
|
||||
href="https://ogb.stanford.edu/kddcup2021/mag240m/">MAG240M-LSC</a>
|
||||
(academic graph) tracks of the <a
|
||||
href="https://ogb.stanford.edu/kddcup2021/">OGB Large-Scale
|
||||
Challenge</a> (OGB-LSC).</li>
|
||||
<li><a
|
||||
href="https://github.com/google-research/google-research/tree/master/persistent_es">Persistent
|
||||
Evolution Strategies</a> - Code used for the paper <a
|
||||
href="http://proceedings.mlr.press/v139/vicol21a.html"><em>Unbiased
|
||||
Gradient Estimation in Unrolled Computation Graphs with Persistent
|
||||
Evolution Strategies</em></a>.</li>
|
||||
<li><a href="https://github.com/degregat/two-player-auctions">Two Player
|
||||
Auction Learning</a> - JAX implementation of the paper <a
|
||||
href="https://arxiv.org/abs/2006.05684"><em>Auction learning as a
|
||||
two-player game</em></a>.</li>
|
||||
<li><a
|
||||
href="https://github.com/deepmind/deepmind-research/tree/master/wikigraphs">WikiGraphs</a>
|
||||
- Baseline code to reproduce results in <a
|
||||
href="https://aclanthology.org/2021.textgraphs-1.7"><em>WikiGraphs: A
|
||||
Wikipedia Text - Knowledge Graph Paired Datase</em></a>.</li>
|
||||
</ul>
|
||||
<h3 id="trax">Trax</h3>
|
||||
<ul>
|
||||
<li><a
|
||||
href="https://github.com/google/trax/tree/master/trax/models/reformer">Reformer</a>
|
||||
- Implementation of the Reformer (efficient transformer)
|
||||
architecture.</li>
|
||||
</ul>
|
||||
<h3 id="numpyro">NumPyro</h3>
|
||||
<ul>
|
||||
<li><a href="https://github.com/RothkopfLab/lqg">lqg</a> - Official
|
||||
implementation of Bayesian inverse optimal control for linear-quadratic
|
||||
Gaussian problems from the paper <a
|
||||
href="https://elifesciences.org/articles/76635"><em>Putting perception
|
||||
into action with inverse optimal control for continuous
|
||||
psychophysics</em></a></li>
|
||||
</ul>
|
||||
<p><a name="videos" /></p>
|
||||
<h2 id="videos">Videos</h2>
|
||||
<ul>
|
||||
<li><a href="https://www.youtube.com/watch?v=iDxJxIyzSiM">NeurIPS 2020:
|
||||
JAX Ecosystem Meetup</a> - JAX, its use at DeepMind, and discussion
|
||||
between engineers, scientists, and JAX core team.</li>
|
||||
<li><a href="https://youtu.be/0mVmRHMaOJ4">Introduction to JAX</a> -
|
||||
Simple neural network from scratch in JAX.</li>
|
||||
<li><a href="https://youtu.be/z-WSrQDXkuM">JAX: Accelerated Machine
|
||||
Learning Research | SciPy 2020 | VanderPlas</a> - JAX’s core design, how
|
||||
it’s powering new research, and how you can start using it.</li>
|
||||
<li><a href="https://youtu.be/CecuWGpoztw">Bayesian Programming with JAX
|
||||
+ NumPyro — Andy Kitchen</a> - Introduction to Bayesian modelling using
|
||||
NumPyro.</li>
|
||||
<li><a
|
||||
href="https://slideslive.com/38923687/jax-accelerated-machinelearning-research-via-composable-function-transformations-in-python">JAX:
|
||||
Accelerated machine-learning research via composable function
|
||||
transformations in Python | NeurIPS 2019 | Skye Wanderman-Milne</a> -
|
||||
JAX intro presentation in <a
|
||||
href="https://program-transformations.github.io"><em>Program
|
||||
Transformations for Machine Learning</em></a> workshop.</li>
|
||||
<li><a
|
||||
href="https://drive.google.com/file/d/1jKxefZT1xJDUxMman6qrQVed7vWI0MIn/edit">JAX
|
||||
on Cloud TPUs | NeurIPS 2020 | Skye Wanderman-Milne and James
|
||||
Bradbury</a> - Presentation of TPU host access with demo.</li>
|
||||
<li><a
|
||||
href="https://slideslive.com/38935810/deep-implicit-layers-neural-odes-equilibrium-models-and-beyond">Deep
|
||||
Implicit Layers - Neural ODEs, Deep Equilibirum Models, and Beyond |
|
||||
NeurIPS 2020</a> - Tutorial created by Zico Kolter, David Duvenaud, and
|
||||
Matt Johnson with Colab notebooks avaliable in <a
|
||||
href="http://implicit-layers-tutorial.org"><em>Deep Implicit
|
||||
Layers</em></a>.</li>
|
||||
<li><a href="http://matpalm.com/blog/ymxb_pod_slice/">Solving y=mx+b
|
||||
with Jax on a TPU Pod slice - Mat Kelcey</a> - A four part YouTube
|
||||
tutorial series with Colab notebooks that starts with Jax fundamentals
|
||||
and moves up to training with a data parallel approach on a v3-32 TPU
|
||||
Pod slice.</li>
|
||||
<li><a
|
||||
href="https://github.com/huggingface/transformers/blob/9160d81c98854df44b1d543ce5d65a6aa28444a2/examples/research_projects/jax-projects/README.md#talks">JAX,
|
||||
Flax & Transformers 🤗</a> - 3 days of talks around JAX / Flax,
|
||||
Transformers, large-scale language modeling and other great topics.</li>
|
||||
</ul>
|
||||
<p><a name="papers" /></p>
|
||||
<h2 id="papers">Papers</h2>
|
||||
<p>This section contains papers focused on JAX (e.g. JAX-based library
|
||||
whitepapers, research on JAX, etc). Papers implemented in JAX are listed
|
||||
in the <a href="#projects">Models/Projects</a> section.</p>
|
||||
<!--lint ignore awesome-list-item-->
|
||||
<ul>
|
||||
<li><a
|
||||
href="https://mlsys.org/Conferences/doc/2018/146.pdf"><strong>Compiling
|
||||
machine learning programs via high-level tracing</strong>. Roy Frostig,
|
||||
Matthew James Johnson, Chris Leary. <em>MLSys 2018</em>.</a> - White
|
||||
paper describing an early version of JAX, detailing how computation is
|
||||
traced and compiled.</li>
|
||||
<li><a href="https://arxiv.org/abs/1912.04232"><strong>JAX, M.D.: A
|
||||
Framework for Differentiable Physics</strong>. Samuel S. Schoenholz,
|
||||
Ekin D. Cubuk. <em>NeurIPS 2020</em>.</a> - Introduces JAX, M.D., a
|
||||
differentiable physics library which includes simulation environments,
|
||||
interaction potentials, neural networks, and more.</li>
|
||||
<li><a href="https://arxiv.org/abs/2010.09063"><strong>Enabling Fast
|
||||
Differentially Private SGD via Just-in-Time Compilation and
|
||||
Vectorization</strong>. Pranav Subramani, Nicholas Vadivelu, Gautam
|
||||
Kamath. <em>arXiv 2020</em>.</a> - Uses JAX’s JIT and VMAP to achieve
|
||||
faster differentially private than existing libraries.</li>
|
||||
<li><a href="https://arxiv.org/abs/2311.16080"><strong>XLB: A
|
||||
Differentiable Massively Parallel Lattice Boltzmann Library in
|
||||
Python</strong>. Mohammadmehdi Ataei, Hesam Salehipour. <em>arXiv
|
||||
2023</em>.</a> - White paper describing the XLB library: benchmarks,
|
||||
validations, and more details about the library.</li>
|
||||
</ul>
|
||||
<!--lint enable awesome-list-item-->
|
||||
<p><a name="tutorials-and-blog-posts" /></p>
|
||||
<h2 id="tutorials-and-blog-posts">Tutorials and Blog Posts</h2>
|
||||
<ul>
|
||||
<li><a
|
||||
href="https://deepmind.com/blog/article/using-jax-to-accelerate-our-research">Using
|
||||
JAX to accelerate our research by David Budden and Matteo Hessel</a> -
|
||||
Describes the state of JAX and the JAX ecosystem at DeepMind.</li>
|
||||
<li><a
|
||||
href="https://roberttlange.github.io/posts/2020/03/blog-post-10/">Getting
|
||||
started with JAX (MLPs, CNNs & RNNs) by Robert Lange</a> - Neural
|
||||
network building blocks from scratch with the basic JAX operators.</li>
|
||||
<li><a href="https://www.kaggle.com/code/truthr/jax-0">Learn JAX: From
|
||||
Linear Regression to Neural Networks by Rito Ghosh</a> - A gentle
|
||||
introduction to JAX and using it to implement Linear and Logistic
|
||||
Regression, and Neural Network models and using them to solve real world
|
||||
problems.</li>
|
||||
<li><a
|
||||
href="https://github.com/8bitmp3/JAX-Flax-Tutorial-Image-Classification-with-Linen">Tutorial:
|
||||
image classification with JAX and Flax Linen by 8bitmp3</a> - Learn how
|
||||
to create a simple convolutional network with the Linen API by Flax and
|
||||
train it to recognize handwritten digits.</li>
|
||||
<li><a
|
||||
href="https://medium.com/swlh/plugging-into-jax-16c120ec3302">Plugging
|
||||
Into JAX by Nick Doiron</a> - Compares Flax, Haiku, and Objax on the
|
||||
Kaggle flower classification challenge.</li>
|
||||
<li><a
|
||||
href="https://blog.evjang.com/2019/02/maml-jax.html">Meta-Learning in 50
|
||||
Lines of JAX by Eric Jang</a> - Introduction to both JAX and
|
||||
Meta-Learning.</li>
|
||||
<li><a href="https://blog.evjang.com/2019/07/nf-jax.html">Normalizing
|
||||
Flows in 100 Lines of JAX by Eric Jang</a> - Concise implementation of
|
||||
<a href="https://arxiv.org/abs/1605.08803">RealNVP</a>.</li>
|
||||
<li><a href="https://blog.evjang.com/2019/11/jaxpt.html">Differentiable
|
||||
Path Tracing on the GPU/TPU by Eric Jang</a> - Tutorial on implementing
|
||||
path tracing.</li>
|
||||
<li><a href="http://matpalm.com/blog/ensemble_nets">Ensemble networks by
|
||||
Mat Kelcey</a> - Ensemble nets are a method of representing an ensemble
|
||||
of models as one single logical model.</li>
|
||||
<li><a href="http://matpalm.com/blog/ood_using_focal_loss">Out of
|
||||
distribution (OOD) detection by Mat Kelcey</a> - Implements different
|
||||
methods for OOD detection.</li>
|
||||
<li><a href="https://www.radx.in/jax.html">Understanding Autodiff with
|
||||
JAX by Srihari Radhakrishna</a> - Understand how autodiff works using
|
||||
JAX.</li>
|
||||
<li><a href="https://sjmielke.com/jax-purify.htm">From PyTorch to JAX:
|
||||
towards neural net frameworks that purify stateful code by Sabrina J.
|
||||
Mielke</a> - Showcases how to go from a PyTorch-like style of coding to
|
||||
a more Functional-style of coding.</li>
|
||||
<li><a href="https://github.com/dfm/extending-jax">Extending JAX with
|
||||
custom C++ and CUDA code by Dan Foreman-Mackey</a> - Tutorial
|
||||
demonstrating the infrastructure required to provide custom ops in
|
||||
JAX.</li>
|
||||
<li><a
|
||||
href="https://roberttlange.github.io/posts/2021/02/cma-es-jax/">Evolving
|
||||
Neural Networks in JAX by Robert Tjarko Lange</a> - Explores how JAX can
|
||||
power the next generation of scalable neuroevolution algorithms.</li>
|
||||
<li><a
|
||||
href="http://lukemetz.com/exploring-hyperparameter-meta-loss-landscapes-with-jax/">Exploring
|
||||
hyperparameter meta-loss landscapes with JAX by Luke Metz</a> -
|
||||
Demonstrates how to use JAX to perform inner-loss optimization with SGD
|
||||
and Momentum, outer-loss optimization with gradients, and outer-loss
|
||||
optimization using evolutionary strategies.</li>
|
||||
<li><a
|
||||
href="https://martiningram.github.io/deterministic-advi/">Deterministic
|
||||
ADVI in JAX by Martin Ingram</a> - Walk through of implementing
|
||||
automatic differentiation variational inference (ADVI) easily and
|
||||
cleanly with JAX.</li>
|
||||
<li><a href="http://matpalm.com/blog/evolved_channel_selection/">Evolved
|
||||
channel selection by Mat Kelcey</a> - Trains a classification model
|
||||
robust to different combinations of input channels at different
|
||||
resolutions, then uses a genetic algorithm to decide the best
|
||||
combination for a particular loss.</li>
|
||||
<li><a
|
||||
href="https://colab.research.google.com/github/probml/probml-notebooks/blob/main/notebooks/jax_intro.ipynb">Introduction
|
||||
to JAX by Kevin Murphy</a> - Colab that introduces various aspects of
|
||||
the language and applies them to simple ML problems.</li>
|
||||
<li><a
|
||||
href="https://www.jeremiecoullon.com/2020/11/10/mcmcjax3ways/">Writing
|
||||
an MCMC sampler in JAX by Jeremie Coullon</a> - Tutorial on the
|
||||
different ways to write an MCMC sampler in JAX along with speed
|
||||
benchmarks.</li>
|
||||
<li><a
|
||||
href="https://www.jeremiecoullon.com/2021/01/29/jax_progress_bar/">How
|
||||
to add a progress bar to JAX scans and loops by Jeremie Coullon</a> -
|
||||
Tutorial on how to add a progress bar to compiled loops in JAX using the
|
||||
<code>host_callback</code> module.</li>
|
||||
<li><a href="https://github.com/gordicaleksa/get-started-with-JAX">Get
|
||||
started with JAX by Aleksa Gordić</a> - A series of notebooks and videos
|
||||
going from zero JAX knowledge to building neural networks in Haiku.</li>
|
||||
<li><a
|
||||
href="https://wandb.ai/jax-series/simple-training-loop/reports/Writing-a-Training-Loop-in-JAX-FLAX--VmlldzoyMzA4ODEy">Writing
|
||||
a Training Loop in JAX + FLAX by Saurav Maheshkar and Soumik Rakshit</a>
|
||||
- A tutorial on writing a simple end-to-end training and evaluation
|
||||
pipeline in JAX, Flax and Optax.</li>
|
||||
<li><a
|
||||
href="https://wandb.ai/wandb/nerf-jax/reports/Implementing-NeRF-in-JAX--VmlldzoxODA2NDk2?galleryTag=jax">Implementing
|
||||
NeRF in JAX by Soumik Rakshit and Saurav Maheshkar</a> - A tutorial on
|
||||
3D volumetric rendering of scenes represented by Neural Radiance Fields
|
||||
in JAX.</li>
|
||||
<li><a
|
||||
href="https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial2/Introduction_to_JAX.html">Deep
|
||||
Learning tutorials with JAX+Flax by Phillip Lippe</a> - A series of
|
||||
notebooks explaining various deep learning concepts, from basics
|
||||
(e.g. intro to JAX/Flax, activiation functions) to recent advances
|
||||
(e.g., Vision Transformers, SimCLR), with translations to PyTorch.</li>
|
||||
<li><a href="https://chrislu.page/blog/meta-disco/">Achieving 4000x
|
||||
Speedups with PureJaxRL</a> - A blog post on how JAX can massively
|
||||
speedup RL training through vectorisation.</li>
|
||||
<li><a
|
||||
href="https://levelup.gitconnected.com/create-your-own-automatically-differentiable-simulation-with-python-jax-46951e120fbb?sk=e8b9213dd2c6a5895926b2695d28e4aa">Simple
|
||||
PDE solver + Constrained Optimization with JAX by Philip Mocz</a> - A
|
||||
simple example of solving the advection-diffusion equations with JAX and
|
||||
using it in a constrained optimization problem to find initial
|
||||
conditions that yield desired result.</li>
|
||||
</ul>
|
||||
<p><a name="books" /></p>
|
||||
<h2 id="books">Books</h2>
|
||||
<ul>
|
||||
<li><a href="https://www.manning.com/books/jax-in-action">Jax in
|
||||
Action</a> - A hands-on guide to using JAX for deep learning and other
|
||||
mathematically-intensive applications.</li>
|
||||
</ul>
|
||||
<p><a name="community" /></p>
|
||||
<h2 id="community">Community</h2>
|
||||
<ul>
|
||||
<li><a
|
||||
href="https://discord.com/channels/1107832795377713302/1107832795688083561">JaxLLM
|
||||
(Unofficial) Discord</a></li>
|
||||
<li><a href="https://github.com/google/jax/discussions">JAX GitHub
|
||||
Discussions</a></li>
|
||||
<li><a href="https://www.reddit.com/r/JAX/">Reddit</a></li>
|
||||
</ul>
|
||||
<h2 id="contributing">Contributing</h2>
|
||||
<p>Contributions welcome! Read the <a
|
||||
href="contributing.md">contribution guidelines</a> first.</p>
|
||||
<p><a href="https://github.com/n2cholas/awesome-jax">jax.md
|
||||
Github</a></p>
|
||||
Reference in New Issue
Block a user