2026-02-11
This post is a work in progress!
Repository: github.com/sugolov/jepax
A little while ago, Owen and I got interested in JEPAs and the
self-supervised approach to learning good latent representations. One
theme in ongoing JEPA work is new loss regularizers: the training
setups stay largely the same, but small additions to the loss improve
stability or the properties of the learned representations. We set out to build
jepax, a JAX/Equinox
implementation of this self-supervised method, with the goal of a simple,
modifiable codebase that enables fast iteration.
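To make the shape of that objective concrete, here is a minimal sketch (not jepax's actual code) of a JEPA-style latent-prediction loss with a slot for an extra regularizer term. The encoder/predictor callables and the VICReg-style variance penalty are illustrative placeholders, not the regularizers used in any particular paper.

```python
import jax
import jax.numpy as jnp


def jepa_loss(context_encoder, target_encoder, predictor,
              x_context, x_target, reg_weight=0.0):
    """L2 distance between predicted and stop-gradient target latents,
    plus an optional regularizer on the predicted latents."""
    z_context = context_encoder(x_context)             # encode visible patches
    z_pred = predictor(z_context)                       # predict masked-patch latents
    z_target = jax.lax.stop_gradient(target_encoder(x_target))  # EMA target, no grads

    pred_loss = jnp.mean((z_pred - z_target) ** 2)

    # Illustrative regularizer slot: a VICReg-style variance hinge that
    # penalizes collapsed (near-constant) latent dimensions.
    std = jnp.sqrt(jnp.var(z_pred, axis=0) + 1e-4)
    reg = jnp.mean(jax.nn.relu(1.0 - std))
    return pred_loss + reg_weight * reg
```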
Figure: Training loss and linear probe accuracy for IJEPA-B trained for 300 epochs on 8xA100.
For this first release, jepax v0, we focused on (1)
matching the original PyTorch implementation 1-to-1 in configs, losses,
and logging, and (2) reproducing IJEPA-B with data
parallelism on 8xA100. I think we collected a lot of interesting
insights about JEPA training, which I want to describe in this post.
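For the data-parallel part, here is a hedged sketch of what a multi-device training step might look like in JAX; this is an assumption about the setup rather than jepax's actual training loop, and loss_fn / optimizer below are toy placeholders.

```python
import functools

import jax
import jax.numpy as jnp
import optax

optimizer = optax.adamw(1e-3)


def loss_fn(params, batch):
    # Toy stand-in for the JEPA objective: linear regression onto targets.
    preds = batch["x"] @ params["w"]
    return jnp.mean((preds - batch["y"]) ** 2)


@functools.partial(jax.pmap, axis_name="devices")
def train_step(params, opt_state, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # Average gradients (and the logged loss) across devices so every
    # replica applies the same update.
    grads = jax.lax.pmean(grads, axis_name="devices")
    loss = jax.lax.pmean(loss, axis_name="devices")
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss
```

With pmap, parameters and optimizer state are replicated per device (e.g. via jax.device_put_replicated) and each batch carries a leading device axis; the same pattern can also be written with the newer jit + sharding APIs.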
Some of the themes to discuss are still to come. Stay tuned!