Automatic differentiation and XLA compilation brought together for high-performance machine learning research.
Open-source library for distributed matrix factorization using Alternating Least Squares; more info in ALX: Large Scale Matrix Factorization on TPUs.
Bayesian Optimization powered by JAX.
Library of samplers for JAX.
Brain Dynamics Programming in Python.
State-based Transformation System for Program Compilation and Augmentation.
Leveraging Taichi Lang to customize brain dynamics operators.
Physical units and unit-aware mathematical system in JAX.
Differentiable physics engine to simulate environments along with learning algorithms to train agents for these environments.
Utilities to write and test reliable JAX code.
Turn RL papers into code, the easy way.
Algorithms for finding coresets to compress large datasets while retaining their statistical properties.
XLA accelerated algorithms for sparse representations and compressive sensing.
Construct differentiable convex optimization layers.
A photovoltaic simulator with automatic differentiation.
Dendritic Modeling in JAX.
Numerical differential equation solvers in JAX.
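For orientation, a hedged sketch of what such a solver does under the hood (not this library's actual API; the fixed-step Euler integrator below is purely illustrative):

```python
import jax
import jax.numpy as jnp

def euler_solve(f, y0, t0, t1, num_steps):
    """Fixed-step Euler integration of dy/dt = f(t, y), written with lax.scan."""
    dt = (t1 - t0) / num_steps
    ts = t0 + dt * jnp.arange(num_steps, dtype=jnp.float32)

    def step(y, t):
        y_next = y + dt * f(t, y)
        return y_next, y_next          # carry and per-step output

    _, ys = jax.lax.scan(step, jnp.float32(y0), ts)
    return ys

# Decaying exponential dy/dt = -y: ys approximates exp(-t) on (0, 1].
ys = euler_solve(lambda t, y: -y, 1.0, 0.0, 1.0, 100)
```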
Reimplementation of TensorFlow Probability, containing probability distributions and bijectors.
High-performance and differentiable simulations of quantum systems with JAX.
EasyDeL 🔮 is an open-source library for faster, more optimized training and serving of models (Llama, MPT, Mixtral, Falcon, etc.) in JAX.
LLMs made easy: Pre-training, finetuning, evaluating and serving LLMs in JAX/Flax.
Solve macroeconomic models with heterogeneous agents using JAX.
Exponential Families in JAX.
DSL-based reshaping library for JAX and other frameworks.
A High Level API for Deep Learning in JAX. Supports Flax, Haiku, and Optax.
Callable PyTrees and filtered JIT/grad transformations => neural networks in JAX.
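A minimal sketch of that idea (the model class and sizes below are illustrative, not taken from the entry above): the model is itself a PyTree, and the filtered transformations decide which leaves are differentiated or traced.

```python
import jax
import jax.numpy as jnp
import equinox as eqx

class Linear(eqx.Module):
    weight: jax.Array
    bias: jax.Array

    def __init__(self, in_size, out_size, key):
        wkey, bkey = jax.random.split(key)
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size,))

    def __call__(self, x):
        return self.weight @ x + self.bias

@eqx.filter_jit                      # JIT, treating array leaves as traced
def loss(model, x, y):
    return jnp.mean((model(x) - y) ** 2)

model = Linear(2, 1, jax.random.PRNGKey(0))
grads = eqx.filter_grad(loss)(model, jnp.ones(2), jnp.ones(1))  # a Linear-shaped PyTree of gradients
```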
Construct equivariant neural network layers.
Equinox version of Torchvision.
Hardware-Accelerated Neuroevolution.
JAX-Based Evolution Strategies.
Automatically differentiable spectrum modeling of exoplanets/brown dwarfs, compatible with JAX.
Federated learning in JAX, built on Optax and Haiku.
Centered on flexibility and clarity.
An evolution of Flax by the same team.
Pretrained models for Jax/Flax.
Flax version of TorchVision.
Distributions and normalizing flows built as equinox modules.
Agent-Based modelling framework in JAX.
AWS library for Uncertainty Quantification in Deep Learning.
Gaussian processes in JAX.
Reinforcement Learning Environments with the well-known gym API.
Focused on simplicity, created by the authors of Sonnet at DeepMind.
Ecosystem of pretrained Transformers for a wide range of natural language tasks (Flax).
Image augmentations and transformations.
Nightly CI and optimized examples for JAX on NVIDIA GPUs using libraries such as T5x, Paxml, and Transformer Engine.
Accelerated, differentiable molecular dynamics.
Differentiable cosmology library.
Normalizing flows in JAX.
Implementations of research papers originally without code or code written with frameworks other than JAX.
Implementations and checkpoints for ResNet variants in Flax.
Add a tqdm progress bar to JAX scans and loops.
Library implementing the UniRep model for protein machine learning applications.
Framework for differentiable simulators with arbitrary discretizations.
Accelerated curve fitting library for nonlinear least-squares problems (see arXiv paper).
Lie theory library for rigid body transformations and optimization.
Hardware accelerated (GPU/TPU), batchable and differentiable optimizers in JAX.
Lightweight graph neural network library.
A Suite of Industry-Driven Hardware-Accelerated RL Environments written in JAX.
A library for differentiable acoustic simulations.
Differentiable stencil decorators in JAX.
Second Order Optimization with Approximate Curvature for NNs.
Legible, Scalable, Reproducible Foundation Models with Named Tensors and JAX.
Automatically apply LoRA to JAX models (Flax, Haiku, etc.).
A simple, performant, and scalable JAX LLM written in pure Python/JAX and targeting Google Cloud TPUs.
Monte Carlo tree search algorithms in native JAX.
Express & compile probabilistic programs for performant inference.
Combine MPI operations with your JAX code on CPUs and GPUs.
A reimplementation of MiniGrid, a Reinforcement Learning environment, in JAX.
Machine Learning toolbox for Quantum Physics.
High-level API for specifying neural networks of both finite and infinite width.
Probabilistic programming based on the Pyro library.
Has an object-oriented design similar to PyTorch.
Gradient processing and optimization library.
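A minimal usage sketch (the toy loss and hyperparameters are illustrative assumptions, not part of the entry above): chain gradient clipping with Adam and apply the resulting updates.

```python
import jax
import jax.numpy as jnp
import optax

params = {"w": jnp.ones(3)}                      # toy parameters
loss = lambda p: jnp.sum(p["w"] ** 2)            # toy quadratic loss

opt = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-2))
opt_state = opt.init(params)

for _ in range(100):
    grads = jax.grad(loss)(params)
    updates, opt_state = opt.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
```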
Toolbox that bundles utilities to solve optimal transport problems.
Root finding, minimisation, fixed points, and least squares.
Probabilistic programming language based on program transformations.
Optimal transport tools in JAX.
Immutable Torch Modules for JAX.
A JAX-based machine learning framework for training large-scale models.
Prioritizes legibility, visualization, and easy editing of neural network models with composable tools and a simple mental model.
A framework for building discrete Probabilistic Graphical Models (PGMs) and running inference on them via JAX.
Vectorized board game environments for RL with an AlphaZero example.
PIX is an image processing library in JAX, for JAX.
The layer library for Pax, with the goal of being usable by other JAX-based ML projects.
Vectorisable, end-to-end RL algorithms in JAX.
Quality-Diversity optimization in JAX.
Library for implementing reinforcement learning agents.
Serialize JAX, Flax, Haiku, or Objax model params with 🤗`safetensors`.
A JAX Library for Computer Vision Research and Beyond.
Scientific computational imaging in JAX.
`scikit-learn` kernel matrices using JAX.
A domain-specific compiler and runtime suite to run JAX code with MPC (Secure Multi-Party Computation).
Spiking Neural Networks in JAX for machine learning on neuromorphic hardware.
Symbolic CPU/GPU/TPU programming.
Tensor learning made simple.
Convert functions/graphs to JAX functions.
The tiniest of Gaussian process libraries in JAX.
Vectorized calculation of optical properties in thin-film structures using JAX; a Swiss Army knife for thin-film optics research.
"Batteries included" deep learning library focused on providing solutions for common workloads.
Convert functions that operate on arrays into functions that operate on PyTrees.
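A hedged sketch of the general idea only (not this library's actual API): lift an array-level function so it applies leaf-wise across matching PyTrees.

```python
import jax.numpy as jnp
from jax import tree_util

def treeify(fn):
    """Lift an array -> array function to act leaf-wise on PyTrees."""
    def wrapped(*trees):
        return tree_util.tree_map(fn, *trees)
    return wrapped

scaled_add = treeify(lambda a, b: a + 0.5 * b)
params = {"w": jnp.ones(3), "b": jnp.zeros(2)}
grads = {"w": jnp.full(3, 2.0), "b": jnp.ones(2)}
new_params = scaled_add(params, grads)   # same nested structure as the inputs
```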
A Differentiable Massively Parallel Lattice Boltzmann Library in Python for Physics-Based Machine Learning.
Tools and libraries for running and analyzing neural network quantization experiments in JAX and Flax.
Reference code for Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples and Fixing Data Augmentation to Improve Adversarial Robustness.
Implementation of the inference pipeline of AlphaFold v2.0, presented in Highly accurate protein structure prediction with AlphaFold.
Code related to Amortized Bayesian Optimization over Discrete Spaces.
Official implementation of Continuous Control with Action Quantization from Demonstrations.
Official implementation of Autoregressive Diffusion Models.
Implementation of Big Transfer (BiT): General Visual Representation Learning.
Implementation for the paper What Are Bayesian Neural Network Posteriors Really Like?.
Implementation for the paper Bootstrap your own latent: A new approach to self-supervised Learning.
Official implementation of Combiner: Full Attention Transformer with Sparse Computation Cost.
Official implementation of Structured Denoising Diffusion Models in Discrete State-Spaces.
Flax implementation of DeepSeek-R1 1.5B distilled reasoning LLM.
Flax implementation of DETR: End-to-end Object Detection with Transformers using Sinkhorn solver and parallel bipartite matching.
Implementation of Second Order Optimization Made Practical.
Official implementation of the ICLR 2022 paper Progressive Distillation for Fast Sampling of Diffusion Models.
Port of mseitzer/pytorch-fid to Flax.
Collection of models and methods implemented in Flax.
Official implementation of FNet: Mixing Tokens with Fourier Transforms.
Official implementation of Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains.
GLNs are a family of backpropagation-free neural networks.
Official implementation of Learning Graph Structure With A Finite-State Automaton Layer.
Official implementation of Gradual Domain Adaptation in the Wild: When Intermediate Distributions are Absent.
Open source implementation of the paper Unveiling the predictive power of static structure in glassy systems.
Implementation of Pay Attention to MLPs.
A JAX + Flax implementation of Combinatorial Optimization with Physics-Inspired Graph Neural Networks.
Code for Learning Generalized Gumbel-max Causal Mechanisms, with extra code in GuyLor/gumbelmaxcausalgadgetspart2.
Official implementation of Learning to Execute Programs with Instruction Pointer Attention Graph Neural Networks.
Implementations of reinforcement learning algorithms.
One-dimensional density functional theory (DFT) in JAX, with implementation of Kohn-Sham equations as regularizer: building prior knowledge into machine-learned physics.
Implementation of NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis with multi-device GPU/TPU support.
Implementation of NeuS: Learning Neural Implicit Surfaces by Volume Rendering for Multi-view Reconstruction.
Nested sampling in JAX.
Approximate inference for Markov (i.e., temporal) Gaussian processes using iterated Kalman filtering and smoothing.
Code for the ICML 2021 paper Latent Programmer: Discrete Latent Codes for Program Synthesis.
Official implementation of Light Field Neural Rendering.
Official implementation of Bayesian inverse optimal control for linear-quadratic Gaussian problems from the paper Putting perception into action with inverse optimal control for continuous psychophysics.
Official implementation of Mip-NeRF: A Multiscale Representation for Anti-Aliasing Neural Radiance Fields.
Minimal implementation of MLP-Mixer: An all-MLP Architecture for Vision.
Code for the models in Self-Supervised MultiModal Versatile Networks.
Checkpoints and model inference code for the ICCV 2021 paper MUSIQ: Multi-scale Image Quality Transformer.
Official implementation of Aggregating Nested Transformers.
Official Haiku implementation of NFNets.
Normalizing flows with JAX.
This repository contains DeepMind's entry to the PCQM4M-LSC (quantum chemistry) and MAG240M-LSC (academic graph) tracks of the OGB Large-Scale Challenge (OGB-LSC).
Flax implementation of the Performer (linear transformer via FAVOR+) architecture.
Code used for the paper Unbiased Gradient Estimation in Unrolled Computation Graphs with Persistent Evolution Strategies.
Implements BERT and autoregressive models for proteins, as described in Biological Structure and Function Emerge from Scaling Unsupervised Learning to 250 Million Protein Sequences and ProGen: Language Modeling for Protein Generation.
Implementation of the Reformer (efficient transformer) architecture.
Official implementation of RegNeRF: Regularizing Neural Radiance Fields for View Synthesis from Sparse Inputs.
Reference code for the paper A General and Adaptive Robust Loss Function.
A JAX/Flax implementation of the Sharpened Cosine Similarity layer.
Reference implementation for Differentiable Patch Selection for Image Recognition.
Official implementation of Baking Neural Radiance Fields for Real-Time View Synthesis.
Adaptation of Spin-Weighted Spherical CNNs.
Demonstration from Evolving symbolic density functionals.
Official JAX implementation of TriMap: Large-scale Dimensionality Reduction Using Triplets.
JAX implementation of the paper Auction learning as a two-player game.
Adaptation of Very Deep VAEs Generalize Autoregressive Models and Can Outperform Them on Images, original code at openai/vdvae.
Official implementation of An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale.
Baseline code to reproduce results in WikiGraphs: A Wikipedia Text - Knowledge Graph Paired Dataset.
Official implementation of Cross-Modal Contrastive Learning for Text-to-Image Generation.
White paper describing an early version of JAX, detailing how computation is traced and compiled.
Uses JAX's JIT and VMAP to achieve faster differentially private training than existing libraries.
Introduces JAX, M.D., a differentiable physics library which includes simulation environments, interaction potentials, neural networks, and more.
White paper describing the XLB library: benchmarks, validations, and more details about the library.
A blog post on how JAX can massively speed up RL training through vectorisation.
A series of notebooks explaining various deep learning concepts, from basics (e.g., intro to JAX/Flax, activation functions) to recent advances (e.g., Vision Transformers, SimCLR), with translations to PyTorch.
Walkthrough of implementing automatic differentiation variational inference (ADVI) easily and cleanly with JAX.
Tutorial on implementing path tracing.
Ensemble nets are a method of representing an ensemble of models as one single logical model.
Trains a classification model robust to different combinations of input channels at different resolutions, then uses a genetic algorithm to decide the best combination for a particular loss.
Explores how JAX can power the next generation of scalable neuroevolution algorithms.
Demonstrates how to use JAX to perform inner-loss optimization with SGD and Momentum, outer-loss optimization with gradients, and outer-loss optimization using evolutionary strategies.
Tutorial demonstrating the infrastructure required to provide custom ops in JAX.
Showcases how to go from a PyTorch-like style of coding to a more functional style of coding.
A series of notebooks and videos going from zero JAX knowledge to building neural networks in Haiku.
Neural network building blocks from scratch with the basic JAX operators.
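In the same spirit, a tiny hedged sketch (names and sizes are illustrative, not from the linked notebook): a dense layer built only from basic JAX operators.

```python
import jax
import jax.numpy as jnp

def init_dense(key, in_dim, out_dim):
    return {"w": jax.random.normal(key, (in_dim, out_dim)) * 0.01,
            "b": jnp.zeros(out_dim)}

def dense(params, x):
    return jnp.tanh(x @ params["w"] + params["b"])

params = init_dense(jax.random.PRNGKey(0), 4, 8)
out = dense(params, jnp.ones((2, 4)))    # batch of 2 inputs -> shape (2, 8)
```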
Tutorial on how to add a progress bar to compiled loops in JAX using the `host_callback` module.
A tutorial on 3D volumetric rendering of scenes represented by Neural Radiance Fields in JAX.
Colab that introduces various aspects of the language and applies them to simple ML problems.
A gentle introduction to JAX, using it to implement linear regression, logistic regression, and neural network models and applying them to real-world problems.
Introduction to both JAX and Meta-Learning.
Concise implementation of RealNVP.
Implements different methods for OOD detection.
Compares Flax, Haiku, and Objax on the Kaggle flower classification challenge.
A simple example of solving the advection-diffusion equations with JAX and using it in a constrained optimization problem to find initial conditions that yield a desired result.
Learn how to create a simple convolutional network with the Linen API by Flax and train it to recognize handwritten digits.
Understand how autodiff works using JAX.
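As a minimal illustration of the topic (not code from the linked material): reverse-mode and forward-mode derivatives of the same scalar function agree.

```python
import jax
import jax.numpy as jnp

f = lambda x: jnp.sin(x) * x**2

x = 1.5
df_rev = jax.grad(f)(x)                 # reverse mode
_, df_fwd = jax.jvp(f, (x,), (1.0,))    # forward mode (JVP seeded with 1.0)
# Both equal 2*x*sin(x) + x**2*cos(x)
```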
Describes the state of JAX and the JAX ecosystem at DeepMind.
A tutorial on writing a simple end-to-end training and evaluation pipeline in JAX, Flax and Optax.
Tutorial on the different ways to write an MCMC sampler in JAX along with speed benchmarks.
Introduction to Bayesian modelling using NumPyro.
Tutorial created by Zico Kolter, David Duvenaud, and Matt Johnson with Colab notebooks available in Deep Implicit Layers.
Simple neural network from scratch in JAX.
Presentation of TPU host access with demo.
3 days of talks around JAX / Flax, Transformers, large-scale language modeling and other great topics.
JAX's core design, how it's powering new research, and how you can start using it.
JAX intro presentation in Program Transformations for Machine Learning workshop.
JAX, its use at DeepMind, and discussion between engineers, scientists, and JAX core team.
A four-part YouTube tutorial series with Colab notebooks that starts with JAX fundamentals and moves up to training with a data-parallel approach on a v3-32 TPU Pod slice.