JAX

Automatic differentiation and XLA compilation brought together for high-performance machine learning research.
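As a minimal sketch of those two pieces working together (the linear model and data below are illustrative assumptions), `jax.grad` turns a loss function into its gradient function and `jax.jit` compiles the result with XLA:

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Mean squared error of a linear model: predictions are x @ w.
    return jnp.mean((x @ w - y) ** 2)

# jax.grad differentiates with respect to the first argument (w) by default;
# jax.jit traces the gradient function and compiles it with XLA on first call.
grad_loss = jax.jit(jax.grad(loss))

x = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y = jnp.array([1.0, 2.0, 3.0])
w = jnp.zeros(2)

# Analytically (2/3) * x.T @ (x @ w - y), i.e. [-8/3, -10/3] at w = 0.
g = grad_loss(w, x, y)

# Plain gradient descent reuses the already-compiled function.
for _ in range(200):
    w = w - 0.1 * grad_loss(w, x, y)
```

The two transformations compose in either order, so `jax.jit(jax.grad(loss))` and `jax.grad(jax.jit(loss))` are both valid.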

204 resources, 7 categories

Libraries (100 items)

A

ALX

Open-source library for distributed matrix factorization using Alternating Least Squares; more info in ALX: Large Scale Matrix Factorization on TPUs.

Libraries
B

bayex

Bayesian Optimization powered by JAX.

Libraries
B

BlackJAX

Library of samplers for JAX.

Libraries
B

BrainPy

Brain Dynamics Programming in Python.

Libraries
B

brainstate

State-based Transformation System for Program Compilation and Augmentation.

Libraries
B

braintaichi

Leveraging Taichi Lang to customize brain dynamics operators.

Libraries
B

brainunit

Physical units and unit-aware mathematical system in JAX.

Libraries
B

BRAX

Differentiable physics engine to simulate environments along with learning algorithms to train agents for these environments.

Libraries
C

Chex

Utilities to write and test reliable JAX code.

Libraries
C

Coax

Turn RL papers into code, the easy way.

Libraries
C

Coreax

Algorithms for finding coresets to compress large datasets while retaining their statistical properties.

Libraries
C

CR.Sparse

XLA accelerated algorithms for sparse representations and compressive sensing.

Libraries
C

cvxpylayers

Construct differentiable convex optimization layers.

Libraries
D

delta PV

A photovoltaic simulator with automatic differentiation.

Libraries
D

dendritex

Dendritic Modeling in JAX.

Libraries
D

Diffrax

Numerical differential equation solvers in JAX.

Libraries
D

Distrax

Reimplementation of TensorFlow Probability, containing probability distributions and bijectors.

Libraries
D

dynamiqs

High-performance and differentiable simulations of quantum systems with JAX.

Libraries
E

EasyDeL

EasyDeL 🔮 is an open-source library that makes training faster and more optimized, with flexible options for training and serving models (Llama, MPT, Mixtral, Falcon, etc.) in JAX.

Libraries
E

EasyLM

LLMs made easy: Pre-training, finetuning, evaluating and serving LLMs in JAX/Flax.

Libraries
E

econpizza

Solve macroeconomic models with heterogeneous agents using JAX.

Libraries
E

efax

Exponential Families in JAX.

Libraries
E

Einshape

DSL-based reshaping library for JAX and other frameworks.

Libraries
E

Elegy

A High Level API for Deep Learning in JAX. Supports Flax, Haiku, and Optax.

Libraries
E

Equinox

Callable PyTrees and filtered JIT/grad transformations => neural networks in JAX.

Libraries
E

Equivariant MLP

Construct equivariant neural network layers.

Libraries
E

Eqxvision

Equinox version of Torchvision.

Libraries
E

EvoJAX

Hardware-Accelerated Neuroevolution.

Libraries
E

evosax

JAX-Based Evolution Strategies.

Libraries
E

exojax

Automatically differentiable spectrum modeling of exoplanets/brown dwarfs, compatible with JAX.

Libraries
F

FedJAX

Federated learning in JAX, built on Optax and Haiku.

Libraries
F

Flax

Centered on flexibility and clarity.

Libraries
F

Flax NNX

An evolution of Flax by the same team.

Libraries
F

flaxmodels

Pretrained models for JAX/Flax.

Libraries
F

FlaxVision

Flax version of TorchVision.

Libraries
F

flowjax

Distributions and normalizing flows built as equinox modules.

Libraries
F

foragax

Agent-Based modelling framework in JAX.

Libraries
F

Fortuna

AWS library for Uncertainty Quantification in Deep Learning.

Libraries
G

GPJax

Gaussian processes in JAX.

Libraries
G

gymnax

Reinforcement Learning Environments with the well-known gym API.

Libraries
H

Haiku

Focused on simplicity, created by the authors of Sonnet at DeepMind.

Libraries
H

HuggingFace Transformers

Ecosystem of pretrained Transformers for a wide range of natural language tasks (Flax).

Libraries
I

imax

Image augmentations and transformations.

Libraries
J

JAX Toolbox

Nightly CI and optimized examples for JAX on NVIDIA GPUs using libraries such as T5x, Paxml, and Transformer Engine.

Libraries
J

JAX, M.D.

Accelerated, differentiable molecular dynamics.

Libraries
J

jax-cosmo

Differentiable cosmology library.

Libraries
J

jax-flows

Normalizing flows in JAX.

Libraries
J

jax-models

Implementations of research papers originally without code or code written with frameworks other than JAX.

Libraries
J

jax-resnet

Implementations and checkpoints for ResNet variants in Flax.

Libraries
J

jax-tqdm

Add a tqdm progress bar to JAX scans and loops.

Libraries
J

jax-unirep

Library implementing the UniRep model for protein machine learning applications.

Libraries
J

JaxDF

Framework for differentiable simulators with arbitrary discretizations.

Libraries
J

JAXFit

Accelerated curve fitting library for nonlinear least-squares problems (see arXiv paper).

Libraries
J

jaxlie

Lie theory library for rigid body transformations and optimization.

Libraries
J

JAXopt

Hardware accelerated (GPU/TPU), batchable and differentiable optimizers in JAX.

Libraries
J

Jraph

Lightweight graph neural network library.

Libraries
J

Jumanji

A Suite of Industry-Driven Hardware-Accelerated RL Environments written in JAX.

Libraries
J

jwave

A library for differentiable acoustic simulations.

Libraries
K

Kernex

Differentiable stencil decorators in JAX.

Libraries
K

KFAC-JAX

Second Order Optimization with Approximate Curvature for NNs.

Libraries
L

Levanter

Legible, Scalable, Reproducible Foundation Models with Named Tensors and JAX.

Libraries
L

Lorax

Automatically apply LoRA to JAX models (Flax, Haiku, etc.)

Libraries
M

MaxText

A simple, performant and scalable JAX LLM written in pure Python/JAX and targeting Google Cloud TPUs.

Libraries
M

Mctx

Monte Carlo tree search algorithms in native JAX.

Libraries
M

mcx

Express & compile probabilistic programs for performant inference.

Libraries
M

mpi4jax

Combine MPI operations with your JAX code on CPUs and GPUs.

Libraries
N

NAVIX

A reimplementation of MiniGrid, a Reinforcement Learning environment, in JAX.

Libraries
N

NetKet

Machine Learning toolbox for Quantum Physics.

Libraries
N

Neural Tangents

High-level API for specifying neural networks of both finite and infinite width.

Libraries
N

NumPyro

Probabilistic programming based on the Pyro library.

Libraries
O

Objax

Has an object-oriented design similar to PyTorch.

Libraries
O

Optax

Gradient processing and optimization library.

Libraries
O

Optimal Transport Tools

Toolbox that bundles utilities to solve optimal transport problems.

Libraries
O

Optimistix

Root finding, minimisation, fixed points, and least squares.

Libraries
O

Oryx

Probabilistic programming language based on program transformations.

Libraries
O

OTT-JAX

Optimal transport tools in JAX.

Libraries
P

Parallax

Immutable Torch Modules for JAX.

Libraries
P

Pax

A JAX-based machine learning framework for training large-scale models.

Libraries
P

Penzai

Prioritizes legibility, visualization, and easy editing of neural network models with composable tools and a simple mental model.

Libraries
P

PGMax

A framework for building discrete Probabilistic Graphical Models (PGMs) and running inference on them via JAX.

Libraries
P

Pgx

Vectorized board game environments for RL with an AlphaZero example.

Libraries
P

PIX

PIX is an image processing library in JAX, for JAX.

Libraries
P

Praxis

The layer library for Pax with a goal to be usable by other JAX-based ML projects.

Libraries
P

purejaxrl

Vectorisable, end-to-end RL algorithms in JAX.

Libraries
Q

QDax

Quality Diversity optimization in JAX.

Libraries
R

RLax

Library for implementing reinforcement learning agents.

Libraries
S

safejax

Serialize JAX, Flax, Haiku, or Objax model params with 🤗 `safetensors`.

Libraries
S

Scenic

A JAX Library for Computer Vision Research and Beyond.

Libraries
S

SCICO

Scientific computational imaging in JAX.

Libraries
S

sklearn-jax-kernels

`scikit-learn` kernel matrices using JAX.

Libraries
S

SPU

A domain-specific compiler and runtime suite to run JAX code with MPC (Secure Multi-Party Computation).

Libraries
S

Spyx

Spiking Neural Networks in JAX for machine learning on neuromorphic hardware.

Libraries
S

SymJAX

Symbolic CPU/GPU/TPU programming.

Libraries
T

TensorLy

Tensor learning made simple.

Libraries
T

TF2JAX

Convert functions/graphs to JAX functions.

Libraries
T

tinygp

The tiniest of Gaussian process libraries in JAX.

Libraries
T

tmmax

Vectorized calculation of optical properties in thin-film structures using JAX; a Swiss Army knife tool for thin-film optics research.

Libraries
T

Trax

"Batteries included" deep learning library focused on providing solutions for common workloads.

Libraries
T

tree-math

Convert functions that operate on arrays into functions that operate on PyTrees.

Libraries
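A rough illustration of that idea using only core JAX (this is not tree-math's own API; the helper `on_pytrees` is hypothetical) flattens a PyTree into one vector with `jax.flatten_util.ravel_pytree`, applies an ordinary array function, and restores the original structure:

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def on_pytrees(array_fn):
    """Hypothetical helper: lift a function on flat 1-D arrays to one on PyTrees."""
    def wrapped(tree):
        flat, unravel = ravel_pytree(tree)   # concatenate all leaves into one vector
        return unravel(array_fn(flat))       # map the result back to the tree shape
    return wrapped

def normalize(v):
    # An array-level function: scale a vector to unit L2 norm.
    return v / jnp.linalg.norm(v)

params = {"w": jnp.array([3.0, 0.0]), "b": jnp.array([4.0])}
unit_params = on_pytrees(normalize)(params)
```

Dictionary leaves are flattened in sorted-key order, so `b` precedes `w` in the flat vector; the norm is taken over all leaves jointly.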
X

XLB

A Differentiable Massively Parallel Lattice Boltzmann Library in Python for Physics-Based Machine Learning.

Libraries

Models and Projects (61 items)

A

Accurate Quantized Training

Tools and libraries for running and analyzing neural network quantization experiments in JAX and Flax.

Models and Projects
A

Adversarial Robustness

Reference code for Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples and Fixing Data Augmentation to Improve Adversarial Robustness.

Models and Projects
A

AlphaFold

Implementation of the inference pipeline of AlphaFold v2.0, presented in Highly accurate protein structure prediction with AlphaFold.

Models and Projects
A

Amortized Bayesian Optimization

Code related to Amortized Bayesian Optimization over Discrete Spaces.

Models and Projects
A

AQuaDem

Official implementation of Continuous Control with Action Quantization from Demonstrations.

Models and Projects
A

ARDM

Official implementation of Autoregressive Diffusion Models.

Models and Projects
B

Big Transfer (BiT)

Implementation of Big Transfer (BiT): General Visual Representation Learning.

Models and Projects
B

BNN-HMC

Implementation for the paper What Are Bayesian Neural Network Posteriors Really Like?

Models and Projects
B

Bootstrap Your Own Latent

Implementation for the paper Bootstrap your own latent: A new approach to self-supervised Learning.

Models and Projects
C

Combiner

Official implementation of Combiner: Full Attention Transformer with Sparse Computation Cost.

Models and Projects
D

D3PM

Official implementation of Structured Denoising Diffusion Models in Discrete State-Spaces.

Models and Projects
D

DeepSeek-R1-Flax-1.5B-Distill

Flax implementation of DeepSeek-R1 1.5B distilled reasoning LLM.

Models and Projects
D

DETR

Flax implementation of DETR: End-to-end Object Detection with Transformers using Sinkhorn solver and parallel bipartite matching.

Models and Projects
D

Distributed Shampoo

Implementation of Second Order Optimization Made Practical.

Models and Projects
D

Dreamfields

Official implementation of the ICLR 2022 paper Progressive Distillation for Fast Sampling of Diffusion Models.

Models and Projects
F

FID computation

Port of mseitzer/pytorch-fid to Flax.

Models and Projects
F

Flax Models

Collection of models and methods implemented in Flax.

Models and Projects
F

FNet

Official implementation of FNet: Mixing Tokens with Fourier Transforms.

Models and Projects
F

Fourier Feature Networks

Official implementation of Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains.

Models and Projects
G

Gated Linear Networks

GLNs are a family of backpropagation-free neural networks.

Models and Projects
G

GFSA

Official implementation of Learning Graph Structure With A Finite-State Automaton Layer.

Models and Projects
G

GIFT

Official implementation of Gradual Domain Adaptation in the Wild: When Intermediate Distributions are Absent.

Models and Projects
G

Glassy Dynamics

Open source implementation of the paper Unveiling the predictive power of static structure in glassy systems.

Models and Projects
G

gMLP

Implementation of Pay Attention to MLPs.

Models and Projects
G

GNNs for Solving Combinatorial Optimization Pro...

A JAX + Flax implementation of Combinatorial Optimization with Physics-Inspired Graph Neural Networks.

Models and Projects
G

Gumbel-max Causal Mechanisms

Code for Learning Generalized Gumbel-max Causal Mechanisms, with extra code in GuyLor/gumbelmaxcausalgadgetspart2.

Models and Projects
I

IPA-GNN

Official implementation of Learning to Execute Programs with Instruction Pointer Attention Graph Neural Networks.

Models and Projects
J

JAX RL

Implementations of reinforcement learning algorithms.

Models and Projects
J

JAX-DFT

One-dimensional density functional theory (DFT) in JAX, with an implementation of Kohn-Sham equations as regularizer: building prior knowledge into machine-learned physics.

Models and Projects
J

JaxNeRF

Implementation of NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis with multi-device GPU/TPU support.

Models and Projects
J

JaxNeuS

Implementation of NeuS: Learning Neural Implicit Surfaces by Volume Rendering for Multi-view Reconstruction.

Models and Projects
J

jaxns

Nested sampling in JAX.

Models and Projects
K

kalman-jax

Approximate inference for Markov (i.e., temporal) Gaussian processes using iterated Kalman filtering and smoothing.

Models and Projects
L

Latent Programmer

Code for the ICML 2021 paper Latent Programmer: Discrete Latent Codes for Program Synthesis.

Models and Projects
L

Light Field Neural Rendering

Official implementation of Light Field Neural Rendering.

Models and Projects
L

lqg

Official implementation of Bayesian inverse optimal control for linear-quadratic Gaussian problems from the paper Putting perception into action with inverse optimal control for continuous psychophysics.

Models and Projects
M

mip-NeRF

Official implementation of Mip-NeRF: A Multiscale Representation for Anti-Aliasing Neural Radiance Fields.

Models and Projects
M

MLP Mixer

Minimal implementation of MLP-Mixer: An all-MLP Architecture for Vision.

Models and Projects
M

MMV

Code for the models in Self-Supervised MultiModal Versatile Networks.

Models and Projects
M

MUSIQ

Checkpoints and model inference code for the ICCV 2021 paper MUSIQ: Multi-scale Image Quality Transformer.

Models and Projects
N

NesT

Official implementation of Aggregating Nested Transformers.

Models and Projects
N

Normalizer-Free Networks

Official Haiku implementation of NFNets.

Models and Projects
N

NuX

Normalizing flows with JAX.

Models and Projects
O

OGB-LSC

This repository contains DeepMind's entry to the PCQM4M-LSC (quantum chemistry) and MAG240M-LSC (academic graph) tracks of the OGB Large-Scale Challenge (OGB-LSC).

Models and Projects
P

Performer

Flax implementation of the Performer (linear transformer via FAVOR+) architecture.

Models and Projects
P

Persistent Evolution Strategies

Code used for the paper Unbiased Gradient Estimation in Unrolled Computation Graphs with Persistent Evolution Strategies.

Models and Projects
P

Protein LM

Implements BERT and autoregressive models for proteins, as described in Biological Structure and Function Emerge from Scaling Unsupervised Learning to 250 Million Protein Sequences and ProGen: Language Modeling for Protein Generation.

Models and Projects
R

Reformer

Implementation of the Reformer (efficient transformer) architecture.

Models and Projects
R

RegNeRF

Official implementation of RegNeRF: Regularizing Neural Radiance Fields for View Synthesis from Sparse Inputs.

Models and Projects
R

Robust Loss

Reference code for the paper A General and Adaptive Robust Loss Function.

Models and Projects
S

Sharpened Cosine Similarity in JAX by Raphael P...

A JAX/Flax implementation of the Sharpened Cosine Similarity layer.

Models and Projects
S

Slot Attention

Reference implementation for Differentiable Patch Selection for Image Recognition.

Models and Projects
S

SNeRG

Official implementation of Baking Neural Radiance Fields for Real-Time View Synthesis.

Models and Projects
S

Spin-weighted Spherical CNNs

Adaptation of Spin-Weighted Spherical CNNs.

Models and Projects
S

Symbolic Functionals

Demonstration from Evolving symbolic density functionals.

Models and Projects
T

TriMap

Official JAX implementation of TriMap: Large-scale Dimensionality Reduction Using Triplets.

Models and Projects
T

Two Player Auction Learning

JAX implementation of the paper Auction learning as a two-player game.

Models and Projects
V

VDVAE

Adaptation of Very Deep VAEs Generalize Autoregressive Models and Can Outperform Them on Images, original code at openai/vdvae.

Models and Projects
V

Vision Transformer

Official implementation of An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale.

Models and Projects
W

WikiGraphs

Baseline code to reproduce results in WikiGraphs: A Wikipedia Text - Knowledge Graph Paired Dataset.

Models and Projects
X

XMC-GAN

Official implementation of Cross-Modal Contrastive Learning for Text-to-Image Generation.

Models and Projects

Tutorials and Blog Posts (26 items)

A

Achieving 4000x Speedups with PureJaxRL

A blog post on how JAX can massively speed up RL training through vectorisation.

Tutorials and Blog Posts
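The core trick can be sketched with a hypothetical toy environment (not code from the post): write a single-environment step function, `jax.vmap` it over a batch of environments, and drive the whole rollout with `jax.lax.scan` under `jax.jit` so it compiles to one XLA program.

```python
import jax
import jax.numpy as jnp

NUM_ENVS, NUM_STEPS = 1024, 100

def env_step(state, action):
    # Hypothetical toy environment: move left/right on a line, reward = -|position|.
    new_state = state + jnp.where(action == 1, 1.0, -1.0)
    reward = -jnp.abs(new_state)
    return new_state, reward

# One call now steps every environment in the batch.
batched_step = jax.vmap(env_step)

@jax.jit
def rollout(key):
    def body(states, step_key):
        # Random policy: one action per environment.
        actions = jax.random.bernoulli(step_key, 0.5, (NUM_ENVS,)).astype(jnp.int32)
        states, rewards = batched_step(states, actions)
        return states, rewards
    keys = jax.random.split(key, NUM_STEPS)
    final_states, rewards = jax.lax.scan(body, jnp.zeros(NUM_ENVS), keys)
    return final_states, rewards   # rewards has shape (NUM_STEPS, NUM_ENVS)

final_states, rewards = rollout(jax.random.PRNGKey(0))
```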
D

Deep Learning tutorials with JAX+Flax by Philli...

A series of notebooks explaining various deep learning concepts, from basics (e.g., intro to JAX/Flax, activation functions) to recent advances (e.g., Vision Transformers, SimCLR), with translations to PyTorch.

Tutorials and Blog Posts
D

Deterministic ADVI in JAX by Martin Ingram

Walkthrough of implementing automatic differentiation variational inference (ADVI) easily and cleanly with JAX.

Tutorials and Blog Posts
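A minimal sketch of the deterministic variant, assuming a 1-D Gaussian target and a mean-field Gaussian approximation (both chosen only for illustration, not taken from the post): the standard-normal draws `Z` are sampled once and reused, so the ELBO estimate becomes a deterministic function that plain gradient descent can optimize.

```python
import jax
import jax.numpy as jnp

TARGET_MEAN, TARGET_STD = 3.0, 1.0   # hypothetical 1-D Gaussian target

def log_target(x):
    # Unnormalized log-density of the target.
    return -0.5 * ((x - TARGET_MEAN) / TARGET_STD) ** 2

# Fixed draws, reused at every step: this is what makes the objective deterministic.
Z = jax.random.normal(jax.random.PRNGKey(0), (200,))

def neg_elbo(params):
    mu, log_sigma = params
    x = mu + jnp.exp(log_sigma) * Z   # reparameterization trick
    # Gaussian entropy is log_sigma + const; the constant is dropped.
    return -(jnp.mean(log_target(x)) + log_sigma)

grad_fn = jax.jit(jax.grad(neg_elbo))

params = (jnp.float32(0.0), jnp.float32(0.0))
for _ in range(500):
    g = grad_fn(params)
    params = tuple(p - 0.1 * gi for p, gi in zip(params, g))

mu, sigma = params[0], jnp.exp(params[1])
```

With fixed draws the optimum depends on `Z` itself (here `mu` settles at `TARGET_MEAN - sigma * mean(Z)`), which is why the number of draws matters.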
D

Differentiable Path Tracing on the GPU/TPU by E...

Tutorial on implementing path tracing.

Tutorials and Blog Posts
E

Ensemble networks by Mat Kelcey

Ensemble nets are a method of representing an ensemble of models as one single logical model.

Tutorials and Blog Posts
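One common way to realize this in JAX, shown here as a sketch with a hypothetical two-layer MLP rather than the post's code, is to stack every member's parameters along a leading axis and `jax.vmap` a single-model forward pass over that axis:

```python
import jax
import jax.numpy as jnp

def init_params(key, in_dim=3, hidden=16):
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (in_dim, hidden)) * 0.1,
        "b1": jnp.zeros(hidden),
        "w2": jax.random.normal(k2, (hidden, 1)) * 0.1,
        "b2": jnp.zeros(1),
    }

def mlp(params, x):
    # Forward pass of a single ensemble member.
    h = jnp.tanh(x @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]

NUM_MEMBERS = 5
keys = jax.random.split(jax.random.PRNGKey(0), NUM_MEMBERS)

# vmap over keys stacks every leaf along a new leading "member" axis.
ensemble_params = jax.vmap(init_params)(keys)

# One logical model: map over the member axis of params, share the input batch.
ensemble_apply = jax.vmap(mlp, in_axes=(0, None))

x = jnp.ones((8, 3))
preds = ensemble_apply(ensemble_params, x)   # shape (NUM_MEMBERS, 8, 1)
mean_pred = preds.mean(axis=0)               # shape (8, 1)
```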
E

Evolved channel selection by Mat Kelcey

Trains a classification model robust to different combinations of input channels at different resolutions, then uses a genetic algorithm to decide the best combination for a particular loss.

Tutorials and Blog Posts
E

Evolving Neural Networks in JAX by Robert Tjark...

Explores how JAX can power the next generation of scalable neuroevolution algorithms.

Tutorials and Blog Posts
E

Exploring hyperparameter meta-loss landscapes w...

Demonstrates how to use JAX to perform inner-loss optimization with SGD and Momentum, outer-loss optimization with gradients, and outer-loss optimization using evolutionary strategies.

Tutorials and Blog Posts
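The gradient-based outer loop can be sketched by unrolling a few SGD steps inside a Python function and differentiating the final loss with respect to the learning rate. The quadratic objective, three-step unroll, and starting point are assumptions chosen so the answer can be checked by hand; the evolutionary-strategies variant from the post is not shown.

```python
import jax
import jax.numpy as jnp

def inner_loss(w):
    # Hypothetical inner objective: a simple quadratic.
    return 0.5 * jnp.sum(w ** 2)

def outer_loss(lr, w0, num_steps=3):
    # Unroll num_steps of plain SGD on the inner loss...
    w = w0
    for _ in range(num_steps):
        w = w - lr * jax.grad(inner_loss)(w)
    # ...then score the final iterate (here with the same quadratic).
    return inner_loss(w)

# Differentiate through the whole unrolled optimization w.r.t. the learning rate.
meta_grad = jax.grad(outer_loss)(0.1, jnp.array([1.0]))
```

For this quadratic, w after K steps is (1 - lr)^K w0, so the meta-gradient is -K (1 - lr)^(2K-1) w0^2, about -1.7715 for K = 3, lr = 0.1, w0 = 1.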
E

Extending JAX with custom C++ and CUDA code by ...

Tutorial demonstrating the infrastructure required to provide custom ops in JAX.

Tutorials and Blog Posts
F

From PyTorch to JAX: towards neural net framewo...

Showcases how to move from a PyTorch-like style of coding to a more functional style.

Tutorials and Blog Posts
G

Get started with JAX by Aleksa Gordić

A series of notebooks and videos going from zero JAX knowledge to building neural networks in Haiku.

Tutorials and Blog Posts
G

Getting started with JAX (MLPs, CNNs & RNNs) by...

Neural network building blocks from scratch with the basic JAX operators.

Tutorials and Blog Posts
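In the same from-scratch spirit, a small MLP can be written with nothing but `jax.numpy`, `jax.grad`, and `jax.tree_util.tree_map`. The layer sizes, learning rate, and toy regression target (y = x^2) below are illustrative assumptions:

```python
import jax
import jax.numpy as jnp

LR = 0.02   # hypothetical learning rate

def init_mlp(key, sizes):
    # One (weight, bias) pair per layer, He-style scaled weights.
    params = []
    keys = jax.random.split(key, len(sizes) - 1)
    for k, (n_in, n_out) in zip(keys, zip(sizes[:-1], sizes[1:])):
        w = jax.random.normal(k, (n_in, n_out)) * jnp.sqrt(2.0 / n_in)
        params.append((w, jnp.zeros(n_out)))
    return params

def mlp(params, x):
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b   # linear output layer

def loss(params, x, y):
    return jnp.mean((mlp(params, x) - y) ** 2)

@jax.jit
def sgd_step(params, x, y):
    grads = jax.grad(loss)(params, x, y)
    # tree_map applies the update leaf by leaf over the list-of-tuples structure.
    return jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)

x = jnp.linspace(-1, 1, 64).reshape(-1, 1)
y = x ** 2
params = init_mlp(jax.random.PRNGKey(0), [1, 32, 1])

initial = loss(params, x, y)
for _ in range(200):
    params = sgd_step(params, x, y)
final = loss(params, x, y)
```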
H

How to add a progress bar to JAX scans and loop...

Tutorial on how to add a progress bar to compiled loops in JAX using the `host_callback` module.

Tutorials and Blog Posts
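A sketch of the same idea using `jax.debug.callback`, which recent JAX versions provide for calling host Python from compiled code, in place of the `host_callback` module the tutorial uses. The step count and reporting interval are hypothetical; `jax.lax.cond` limits the host round-trip to every `PRINT_EVERY`-th iteration.

```python
import jax
import jax.numpy as jnp

NUM_STEPS = 1000
PRINT_EVERY = 250   # hypothetical reporting interval

def report(step):
    # Runs on the host, outside the compiled computation.
    print(f"step {int(step)}/{NUM_STEPS}")

def _print_step(step):
    jax.debug.callback(report, step)

def _skip(step):
    pass

def body(carry, step):
    # Only call back to the host on reporting steps.
    jax.lax.cond(step % PRINT_EVERY == 0, _print_step, _skip, step)
    return carry + 1.0, None

@jax.jit
def run(init):
    final, _ = jax.lax.scan(body, init, jnp.arange(NUM_STEPS))
    return final

result = run(jnp.float32(0.0))
```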
I

Implementing NeRF in JAX by Soumik Rakshit and ...

A tutorial on 3D volumetric rendering of scenes represented by Neural Radiance Fields in JAX.

Tutorials and Blog Posts
I

Introduction to JAX by Kevin Murphy

Colab that introduces various aspects of the language and applies them to simple ML problems.

Tutorials and Blog Posts
L

Learn JAX: From Linear Regression to Neural Net...

A gentle introduction to JAX, using it to implement Linear and Logistic Regression and Neural Network models and to solve real-world problems.

Tutorials and Blog Posts
M

Meta-Learning in 50 Lines of JAX by Eric Jang

Introduction to both JAX and Meta-Learning.

Tutorials and Blog Posts
N

Normalizing Flows in 100 Lines of JAX by Eric Jang

Concise implementation of RealNVP.

Tutorials and Blog Posts
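The building block of RealNVP is the affine coupling layer. A minimal sketch, with a hypothetical one-hidden-layer conditioner network and `tanh`-bounded log-scales (a common stabilizing choice, not necessarily the post's): half the input passes through unchanged and parameterizes an elementwise affine map of the other half, so the Jacobian is triangular and its log-determinant is just `sum(log_s)`.

```python
import jax
import jax.numpy as jnp

def init_net(key, d_in, d_out, hidden=32):
    k1, k2 = jax.random.split(key)
    return (jax.random.normal(k1, (d_in, hidden)) * 0.5, jnp.zeros(hidden),
            jax.random.normal(k2, (hidden, 2 * d_out)) * 0.5, jnp.zeros(2 * d_out))

def net(params, x1):
    # Conditioner: maps the untouched half to a log-scale and a shift.
    w1, b1, w2, b2 = params
    h = jnp.tanh(x1 @ w1 + b1)
    log_s, t = jnp.split(h @ w2 + b2, 2, axis=-1)
    return jnp.tanh(log_s), t   # tanh keeps the scales bounded

def forward(params, x):
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    log_s, t = net(params, x1)
    y2 = x2 * jnp.exp(log_s) + t
    log_det = jnp.sum(log_s, axis=-1)   # triangular Jacobian
    return jnp.concatenate([x1, y2], axis=-1), log_det

def inverse(params, y):
    d = y.shape[-1] // 2
    y1, y2 = y[..., :d], y[..., d:]
    log_s, t = net(params, y1)          # y1 == x1, so the same scale/shift
    x2 = (y2 - t) * jnp.exp(-log_s)
    return jnp.concatenate([y1, x2], axis=-1)

params = init_net(jax.random.PRNGKey(0), d_in=2, d_out=2)
x = jax.random.normal(jax.random.PRNGKey(1), (4,))
y, log_det = forward(params, x)
x_rec = inverse(params, y)
```

Because `inverse` only needs `net` evaluated on the untouched half, it is exact and as cheap as `forward`.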
O

Out of distribution (OOD) detection by Mat Kelcey

Implements different methods for OOD detection.

Tutorials and Blog Posts
P

Plugging Into JAX by Nick Doiron

Compares Flax, Haiku, and Objax on the Kaggle flower classification challenge.

Tutorials and Blog Posts
S

Simple PDE solver + Constrained Optimization wi...

A simple example of solving the advection-diffusion equations with JAX and using it in a constrained optimization problem to find initial conditions that yield a desired result.

Tutorials and Blog Posts
T

Tutorial: image classification with JAX and Fla...

Learn how to create a simple convolutional network with the Linen API by Flax and train it to recognize handwritten digits.

Tutorials and Blog Posts
U

Understanding Autodiff with JAX by Srihari Radh...

Understand how autodiff works using JAX.

Tutorials and Blog Posts
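The two modes autodiff is built on can be seen directly with `jax.jvp` (forward mode, Jacobian-vector products) and `jax.vjp` (reverse mode, vector-Jacobian products). The function `f` below is an arbitrary example:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Arbitrary R^2 -> R^2 function for illustration.
    return jnp.array([x[0] * x[1], jnp.sin(x[0])])

x = jnp.array([1.0, 2.0])

# Forward mode: push a tangent v through f to get J @ v.
v = jnp.array([1.0, 0.0])
y, jv = jax.jvp(f, (x,), (v,))

# Reverse mode: pull a cotangent u back through f to get u @ J.
u = jnp.array([0.0, 1.0])
y2, vjp_fn = jax.vjp(f, x)
(uj,) = vjp_fn(u)

# Full Jacobian, for comparison.
J = jax.jacfwd(f)(x)
```

For f(x) = (x0·x1, sin x0) at x = (1, 2), the Jacobian is [[2, 1], [cos 1, 0]], so `jv` is (2, cos 1) and `uj` is (cos 1, 0).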
U

Using JAX to accelerate our research by David B...

Describes the state of JAX and the JAX ecosystem at DeepMind.

Tutorials and Blog Posts
W

Writing a Training Loop in JAX + FLAX by Saurav...

A tutorial on writing a simple end-to-end training and evaluation pipeline in JAX, Flax and Optax.

Tutorials and Blog Posts
W

Writing an MCMC sampler in JAX by Jeremie Coullon

Tutorial on the different ways to write an MCMC sampler in JAX along with speed benchmarks.

Tutorials and Blog Posts
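As a point of reference for the approaches the tutorial compares, a random-walk Metropolis sampler can be written as a `jax.lax.scan` over pre-split PRNG keys and compiled with `jax.jit`. The standard-normal target and step size are assumptions for illustration:

```python
import jax
import jax.numpy as jnp

def log_prob(x):
    # Hypothetical target: standard normal (unnormalized log-density).
    return -0.5 * x ** 2

def rw_metropolis(key, x0, num_steps, step_size=1.0):
    def step(x, step_key):
        k_prop, k_acc = jax.random.split(step_key)
        proposal = x + step_size * jax.random.normal(k_prop)
        # Symmetric proposal, so the acceptance ratio is just the density ratio.
        log_alpha = log_prob(proposal) - log_prob(x)
        accept = jnp.log(jax.random.uniform(k_acc)) < log_alpha
        x_new = jnp.where(accept, proposal, x)
        return x_new, x_new
    keys = jax.random.split(key, num_steps)
    _, samples = jax.lax.scan(step, x0, keys)
    return samples

# num_steps sets an array shape, so it must be static under jit.
sampler = jax.jit(rw_metropolis, static_argnums=2)
samples = sampler(jax.random.PRNGKey(0), jnp.float32(0.0), 20_000)
```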