
jepax v0: an implementation of IJEPA-B training in JAX/Equinox

Owen L., Anton S.

2026-02-11

This post is a work in progress!

Repository: github.com/sugolov/jepax

Overview

A little while ago, Owen and I got interested in JEPAs and the self-supervised approach to learning good latent representations. One recurring theme in ongoing JEPA work is new loss regularizers: the training setups stay similar, but small augmentations to the loss improve stability or the properties of the learned representations. We set out to build jepax, a JAX/Equinox implementation of the method, with the goal of a simple, modifiable codebase that enables fast iteration.

Figure: Training loss and linear probe accuracy for IJEPA-B trained for 300 epochs on 8xA100.

For this first release, jepax v0, we focused on (1) matching the configs, losses, and logging of the original PyTorch implementation 1-to-1, and (2) reproducing IJEPA-B with data parallelism on 8xA100. I think we collected a lot of interesting insights about JEPA training, which I want to describe in this post. Some of the themes to discuss:

  1. Background on JEPA and IJEPA
  2. Interesting failure modes (a loss sketch follows this list)
    1. Smooth \(L_1\) loss and unnormalized \(L_2\)
    2. Target normalization
  3. JAX-specific considerations
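
To make the loss-related items concrete before the full write-up, here is a minimal sketch of the variants in plain jax.numpy. The function names and shapes are illustrative rather than the actual jepax API; assume pred and target are (num_targets, dim) arrays of predicted and target patch representations.

```python
import jax.numpy as jnp

def smooth_l1(pred, target, beta=1.0):
    # Huber-style smooth L1: quadratic for small errors, linear in the tails,
    # so large per-element errors are penalized less steeply than under L2.
    diff = jnp.abs(pred - target)
    return jnp.mean(jnp.where(diff < beta, 0.5 * diff**2 / beta, diff - 0.5 * beta))

def unnormalized_l2(pred, target):
    # Plain mean squared error against the raw (unnormalized) targets.
    return jnp.mean((pred - target) ** 2)

def layer_norm_targets(target, eps=1e-6):
    # Per-token layer norm of the target representations before the loss;
    # whether and how the targets are normalized is the theme of item 2.2.
    mu = jnp.mean(target, axis=-1, keepdims=True)
    var = jnp.var(target, axis=-1, keepdims=True)
    return (target - mu) / jnp.sqrt(var + eps)
```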
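
On the data-parallel reproduction: the details belong in the JAX-specific section, but as a rough illustration of what sharding the batch over 8 devices can look like in JAX (the mesh axis name, shapes, and placeholder loss_fn below are ours for illustration, not the jepax training loop):

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One mesh axis over all local devices (e.g. 8 A100s); the batch dimension
# is sharded over it and the parameters are replicated.
mesh = Mesh(jax.devices(), axis_names=("data",))
batch_sharding = NamedSharding(mesh, P("data"))
replicated = NamedSharding(mesh, P())

@jax.jit
def loss_fn(params, images):
    # Placeholder standing in for the IJEPA forward pass and loss; under jit,
    # XLA inserts the cross-device reduction over the sharded batch axis.
    return jnp.mean(images) * jnp.sum(params)

params = jax.device_put(jnp.ones((4,)), replicated)
images = jax.device_put(jnp.zeros((256, 224, 224, 3)), batch_sharding)
print(loss_fn(params, images))
```

With jit plus input shardings, XLA handles the collectives, so the training step reads the same as the single-device version.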

Stay tuned!