# Proposed signatures only; bodies are left as stubs.
import jax


def byol_loss(
    online_projection_1: jax.typing.ArrayLike,
    target_projection_2: jax.typing.ArrayLike,
    online_projection_2: jax.typing.ArrayLike,
    target_projection_1: jax.typing.ArrayLike,
    eps: jax.typing.ArrayLike = 1e-6,
) -> jax.Array:
    ...


def simsiam_loss(
    predictor_projection_1: jax.typing.ArrayLike,
    target_projection_2: jax.typing.ArrayLike,
    predictor_projection_2: jax.typing.ArrayLike,
    target_projection_1: jax.typing.ArrayLike,
    eps: jax.typing.ArrayLike = 1e-6,
) -> jax.Array:
    ...


def dino_loss(
    student_logits: jax.typing.ArrayLike,
    teacher_logits: jax.typing.ArrayLike,
    student_temperature: jax.typing.ArrayLike = 0.1,
    teacher_temperature: jax.typing.ArrayLike = 0.04,
    teacher_center: jax.typing.ArrayLike = 0.0,
) -> jax.Array:
    ...


def barlow_twins_loss(
    projection_1: jax.typing.ArrayLike,
    projection_2: jax.typing.ArrayLike,
    off_diagonal_scale: jax.typing.ArrayLike = 5e-3,
    eps: jax.typing.ArrayLike = 1e-12,
) -> jax.Array:
    ...
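To make the first two signatures above more concrete, here is a minimal sketch of how they might be implemented. It is an illustration, not the proposed Optax code: the `_sketch` suffix and the `_cosine_similarity` helper are hypothetical names, inputs are assumed to have shape `[batch, dim]`, the result is assumed to be a per-example loss of shape `[batch]`, and the target branch is assumed to be wrapped in `stop_gradient` inside the loss. The formulas follow the papers: BYOL uses `2 - 2 * cos` per view pair, and SimSiam uses the symmetrized negative cosine similarity.

```python
import jax
import jax.numpy as jnp


def _cosine_similarity(x, y, eps):
    # Row-wise cosine similarity; eps keeps the denominator away from zero.
    x_norm = jnp.maximum(jnp.linalg.norm(x, axis=-1), eps)
    y_norm = jnp.maximum(jnp.linalg.norm(y, axis=-1), eps)
    return jnp.sum(x * y, axis=-1) / (x_norm * y_norm)


def byol_loss_sketch(
    online_projection_1,
    target_projection_2,
    online_projection_2,
    target_projection_1,
    eps=1e-6,
):
    # Assumption: inputs are [batch, dim]; the output is [batch].
    o1, o2 = jnp.asarray(online_projection_1), jnp.asarray(online_projection_2)
    # Assumption: targets never receive gradients from this loss.
    t1 = jax.lax.stop_gradient(jnp.asarray(target_projection_1))
    t2 = jax.lax.stop_gradient(jnp.asarray(target_projection_2))
    # Squared distance between L2-normalized vectors equals 2 - 2 * cos.
    loss_12 = 2.0 - 2.0 * _cosine_similarity(o1, t2, eps)
    loss_21 = 2.0 - 2.0 * _cosine_similarity(o2, t1, eps)
    return loss_12 + loss_21


def simsiam_loss_sketch(
    predictor_projection_1,
    target_projection_2,
    predictor_projection_2,
    target_projection_1,
    eps=1e-6,
):
    # Assumption: inputs are [batch, dim]; the output is [batch].
    p1 = jnp.asarray(predictor_projection_1)
    p2 = jnp.asarray(predictor_projection_2)
    z1 = jax.lax.stop_gradient(jnp.asarray(target_projection_1))
    z2 = jax.lax.stop_gradient(jnp.asarray(target_projection_2))
    # Symmetrized negative cosine similarity, averaged over the two views.
    return -0.5 * (
        _cosine_similarity(p1, z2, eps) + _cosine_similarity(p2, z1, eps)
    )
```

Under these assumptions a caller would reduce the result with something like `jnp.mean(byol_loss_sketch(q1, z2, q2, z1))`. Whether the real functions should return per-example values or a scalar, and whether `stop_gradient` belongs inside the loss or in user code, are open design questions for the proposal.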
I’d like to propose adding a few popular self-supervised losses to
optax.losses._self_supervised:

- byol_loss
- simsiam_loss
- dino_loss
- barlow_twins_loss

Motivation
These objectives are widely used in modern self-supervised representation
learning pipelines, especially for vision, and having them in Optax would:
- [...] JAX/Flax.
- [...] different copies.
- [...] ntxent and triplet margin losses already present in
  _self_supervised.py.

Proposed API (high-level)
All functions follow the same style as existing Optax losses:
- [...] jit/vmap.
- jax.typing.ArrayLike arguments and jax.Array return types.
- utils.check_subdtype for float inputs.

Rough signatures:
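As an illustration of the style described above (array-like inputs, `jax.Array` outputs, usable under `jit`), here is a hedged sketch of the DINO and Barlow Twins objectives. Everything below is an assumption rather than the proposed implementation. The `_sketch` names are hypothetical and inputs are assumed to be `[batch, dim]`. The DINO sketch returns a per-example cross-entropy, and leaves multi-crop averaging and the running update of the teacher center to the caller. The Barlow Twins sketch returns a scalar. Dtype validation via Optax's internal `utils.check_subdtype` is omitted so the block depends only on JAX.

```python
import jax
import jax.numpy as jnp


def dino_loss_sketch(
    student_logits,
    teacher_logits,
    student_temperature=0.1,
    teacher_temperature=0.04,
    teacher_center=0.0,
):
    # Assumption: logits are [batch, num_prototypes]; the output is [batch].
    student_logits = jnp.asarray(student_logits)
    teacher_logits = jnp.asarray(teacher_logits)
    # Teacher distribution: centered, sharpened, and treated as a constant.
    teacher_probs = jax.nn.softmax(
        (teacher_logits - teacher_center) / teacher_temperature, axis=-1
    )
    teacher_probs = jax.lax.stop_gradient(teacher_probs)
    student_log_probs = jax.nn.log_softmax(
        student_logits / student_temperature, axis=-1
    )
    # Cross-entropy between the teacher and student distributions.
    return -jnp.sum(teacher_probs * student_log_probs, axis=-1)


def barlow_twins_loss_sketch(
    projection_1,
    projection_2,
    off_diagonal_scale=5e-3,
    eps=1e-12,
):
    # Assumption: projections are [batch, dim]; the output is a scalar.
    z1 = jnp.asarray(projection_1)
    z2 = jnp.asarray(projection_2)
    batch_size = z1.shape[0]
    # Standardize every feature along the batch axis.
    z1 = (z1 - z1.mean(axis=0)) / (z1.std(axis=0) + eps)
    z2 = (z2 - z2.mean(axis=0)) / (z2.std(axis=0) + eps)
    # Cross-correlation matrix between the two views, shape [dim, dim].
    c = z1.T @ z2 / batch_size
    diag = jnp.diagonal(c)
    # Pull the diagonal towards 1 and the off-diagonal entries towards 0.
    on_diag = jnp.sum((1.0 - diag) ** 2)
    off_diag = jnp.sum(c**2) - jnp.sum(diag**2)
    return on_diag + off_diagonal_scale * off_diag
```

Both functions use only static shape information and `jax.numpy` operations, so they can be wrapped in `jax.jit`, for example `jax.jit(barlow_twins_loss_sketch)(z1, z2)`.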
References
BYOL – Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning
https://arxiv.org/abs/2006.07733
SimSiam – Exploring Simple Siamese Representation Learning
https://arxiv.org/abs/2011.10566
DINO – Emerging Properties in Self-Supervised Vision Transformers
https://arxiv.org/abs/2104.14294
Barlow Twins – Barlow Twins: Self-Supervised Learning via Redundancy Reduction
https://arxiv.org/abs/2103.03230