20/04/2026 10:09 AM

⬅️ [17/04/2026 12:45 PM - ReID More explantation](<./17_04_2026 12_45 PM - ReID More explantation.md>) | ⬆️ [2026 - April](<./README.md>) | [21/04/2026 3:22 PM - VIGIL Projection Meeting](<./21_04_2026 3_22 PM - VIGIL Projection Meeting.md>) ➡️

I need to get the trees working better. At the very least, I should roll out trees from the early RL tests. But really I need to do the RL project, because I don't know if anyone else is going to.

I'm thinking of implementing CURL. It was designed for off-policy algorithms (SAC and Rainbow in the paper), but we collect so many environment steps that it feels like it should still work with PPO.

Got Gemini to create a list of things that need to be updated:

1. Model Definitions (models.py)

  • [ ] Create a separate Encoder class: Extract the convolutional layers and flattening step into their own nn.Module.
  • [ ] Modify Actor/Critic to accept embeddings: Change ActorCriticConv to accept the flattened output of the Encoder rather than raw images.
  • [ ] Add a CURLHead class: Add a bilinear projection layer (e.g., a trainable parameter matrix W) to project the query embedding into the key embedding space.

2. Network Initialization (make_train)

  • [ ] Initialize the separate components: Initialize the Encoder, the ActorCritic heads, and the CURLHead.
  • [ ] Create a unified TrainState: Bundle the Encoder, ActorCritic, and CURLHead parameters into a single TrainState for joint backpropagation.
  • [ ] Initialize the Momentum (Target) Encoder: Create a copy of the Encoder parameters to serve as the momentum encoder (store this in the runner state, not the Optax TrainState).

3. Data Augmentation Setup

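  • [ ] Add Augmentation Logic: Bring in a JAX-compatible random crop function (e.g., dm_pix.random_crop, or jax.lax.dynamic_slice with random offsets; jax.image only provides resizing and scaling, not a random crop).
  • [ ] Define Crop Sizes: Define the random crop size as a configuration variable (e.g., 64x64 for an 84x84 base observation; the CURL paper renders at 100x100 and crops to 84x84).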
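A dependency-free batched random crop, sketched with `jax.lax.dynamic_slice` under `jax.vmap` so that every image in the batch gets its own offset. `random_crop` is a hypothetical helper name, and `crop_size` has to be a plain Python int because slice sizes must be static under `jit`.

```python
import jax
import jax.numpy as jnp


def random_crop(key, obs, crop_size):
    """Independently random-crop each image in a batch.

    obs: (B, H, W, C) -> returns (B, crop_size, crop_size, C).
    """
    B, H, W, C = obs.shape
    kx, ky = jax.random.split(key)
    # One top-left corner per image, uniform over all valid positions (maxval is exclusive).
    xs = jax.random.randint(kx, (B,), 0, W - crop_size + 1)
    ys = jax.random.randint(ky, (B,), 0, H - crop_size + 1)

    def crop_one(img, y, x):
        return jax.lax.dynamic_slice(img, (y, x, 0), (crop_size, crop_size, C))

    return jax.vmap(crop_one)(obs, ys, xs)
```

Calling it twice with different keys on the same batch gives the two views needed for the query/key pair.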

4. Trajectory Collection (_env_step)

  • [ ] Augment before acting: Apply a random crop to last_obs before passing it to the network to get pi and value.
  • [ ] Store Unaugmented Obs in Transition: Store the unaugmented observations in the Transition tuple to allow for fresh random crops during PPO updates.
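A sketch of the acting step, assuming a discrete action space and a `policy_fn` that maps cropped images to `(logits, value)` (for example the encoder followed by the heads). `act` is a hypothetical name and the `Transition` fields are loosely modelled on a PureJaxRL-style tuple; the key point is that `obs` stores the raw `last_obs`, not the crop the policy saw.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Transition(NamedTuple):  # hypothetical field layout
    done: jnp.ndarray
    action: jnp.ndarray
    value: jnp.ndarray
    reward: jnp.ndarray
    log_prob: jnp.ndarray
    obs: jnp.ndarray  # UNaugmented observation; re-cropped freshly during the update


def random_crop(key, obs, crop_size):
    """Per-image random crop: (B, H, W, C) -> (B, crop_size, crop_size, C)."""
    B, H, W, C = obs.shape
    kx, ky = jax.random.split(key)
    xs = jax.random.randint(kx, (B,), 0, W - crop_size + 1)
    ys = jax.random.randint(ky, (B,), 0, H - crop_size + 1)
    crop = lambda img, y, x: jax.lax.dynamic_slice(img, (y, x, 0), (crop_size, crop_size, C))
    return jax.vmap(crop)(obs, ys, xs)


def act(rng, last_obs, policy_fn, crop_size):
    """Crop last_obs, query the policy, and sample a discrete action."""
    rng, crop_key, sample_key = jax.random.split(rng, 3)
    cropped = random_crop(crop_key, last_obs, crop_size)  # augmented view, used for acting only
    logits, value = policy_fn(cropped)
    action = jax.random.categorical(sample_key, logits)
    log_prob = jnp.take_along_axis(jax.nn.log_softmax(logits), action[:, None], axis=-1)[:, 0]
    return rng, action, value, log_prob
```

Because the stored `log_prob` comes from one crop and the update re-evaluates the policy on a fresh crop, the PPO ratio will not generally start at exactly 1 on the first epoch, which may matter when tuning the clip range.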

5. The Loss Function (_loss_fn inside _update_minbatch)

  • [ ] Generate Pairs: Apply the random crop function twice with different random keys to traj_batch.obs to generate obs_query and obs_key.
  • [ ] Forward Pass (Online): Pass obs_query through the online Encoder, then pass the resulting z_query through the ActorCritic heads to calculate standard PPO losses.
  • [ ] Forward Pass (Momentum): Pass obs_key through the momentum encoder to get z_key (use jax.lax.stop_gradient to prevent backpropagation).
  • [ ] Calculate InfoNCE Loss: Pass z_query through the CURLHead matrix W, compute a B x B logits matrix by multiplying with z_key.T, and calculate categorical cross-entropy with the diagonal entries (matching query/key pairs) as the positive labels.
  • [ ] Combine Losses: Add the CURL loss to the total_loss objective (total_loss = ppo_loss + config["CURL_COEF"] * curl_loss).
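The InfoNCE term and the loss combination, sketched in plain JAX. `curl_infonce_loss` and `combine_losses` are hypothetical names; the inputs are assumed to be `z_query` from the online encoder, `z_key` from the momentum encoder, and `W` taken from the CURLHead parameters. Row i's positive is the key built from the same observation (column i), and every other key in the batch acts as a negative.

```python
import jax
import jax.numpy as jnp


def curl_infonce_loss(z_query, z_key, W):
    """Bilinear InfoNCE loss.

    z_query: (B, D) online-encoder embeddings of obs_query.
    z_key:   (B, D) momentum-encoder embeddings of obs_key.
    W:       (D, D) CURLHead matrix.
    """
    z_key = jax.lax.stop_gradient(z_key)  # no gradient into the momentum branch
    logits = (z_query @ W) @ z_key.T  # logits[i, j] = q_i^T W k_j, shape (B, B)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Cross-entropy with label i for row i: positives sit on the diagonal.
    return -jnp.mean(jnp.diagonal(log_probs))


def combine_losses(ppo_loss, curl_loss, curl_coef):
    # total_loss = ppo_loss + config["CURL_COEF"] * curl_loss
    return ppo_loss + curl_coef * curl_loss
```

A useful sanity check: with uninformative (all-zero) embeddings the loss equals log(B), so early in training it should sit near that value and fall as the encoder learns to match views.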

6. The Momentum Update

  • [ ] Implement EMA: Write a Polyak averaging update using jax.tree_util.tree_map at the end of _update_minbatch.
  • [ ] Apply Formula: momentum_params = tau * momentum_params + (1 - tau) * online_encoder_params.
  • [ ] Pass State: Ensure the updated momentum_params are returned from _update_minbatch and threaded through the jax.lax.scan loops.
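A sketch of the Polyak update and of threading `momentum_params` through a `jax.lax.scan` carry. `ema_update` and `run_minibatches` are hypothetical names, and the plain SGD step inside `_update_minbatch` is a stand-in for `train_state.apply_gradients`. In the real loop the online side of the EMA would be only the encoder slice, `train_state.params["encoder"]`.

```python
import jax
import jax.numpy as jnp


def ema_update(momentum_params, online_params, tau):
    """Polyak averaging: momentum <- tau * momentum + (1 - tau) * online."""
    return jax.tree_util.tree_map(
        lambda m, o: tau * m + (1.0 - tau) * o, momentum_params, online_params
    )


def run_minibatches(online_params, momentum_params, grads_per_step, tau, lr=0.1):
    """Toy scan: grads_per_step leaves carry a leading (num_minibatches,) axis."""

    def _update_minbatch(carry, grads):
        online, momentum = carry
        # Stand-in for train_state.apply_gradients(grads=grads).
        online = jax.tree_util.tree_map(lambda p, g: p - lr * g, online, grads)
        # EMA at the end of the minibatch, using the freshly updated online params.
        momentum = ema_update(momentum, online, tau)
        return (online, momentum), None

    (online_params, momentum_params), _ = jax.lax.scan(
        _update_minbatch, (online_params, momentum_params), grads_per_step
    )
    return online_params, momentum_params
```

Keeping the momentum params in the carry (rather than the `TrainState`) means the optimizer never touches them, and the outer epoch/update scans just need to pass the same tuple through.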

7. Arguments & Configuration

  • [ ] Add --curl_coef: The weight of the CURL loss in the total loss.
  • [ ] Add --momentum_tau: The EMA update rate (e.g., 0.99 or 0.999).
  • [ ] Add --crop_size: The size of the random crop.
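A possible argparse wiring for the three new flags. The `--momentum_tau` default of 0.99 and the `--crop_size` default of 64 follow the examples above; the `--curl_coef` default of 1.0, the `parse_config` helper, and the `MOMENTUM_TAU`/`CROP_SIZE` config key names are assumptions.

```python
import argparse


def build_parser():
    parser = argparse.ArgumentParser(description="PPO + CURL")
    # Assumption: default of 1.0 is a placeholder, not a tuned value.
    parser.add_argument("--curl_coef", type=float, default=1.0,
                        help="Weight of the CURL loss in the total loss")
    parser.add_argument("--momentum_tau", type=float, default=0.99,
                        help="EMA rate for the momentum encoder, e.g. 0.99 or 0.999")
    parser.add_argument("--crop_size", type=int, default=64,
                        help="Side length of the square random crop")
    return parser


def parse_config(argv=None):
    parser = build_parser()
    args = parser.parse_args(argv)
    if not 0.0 <= args.momentum_tau < 1.0:
        parser.error("--momentum_tau must be in [0, 1)")
    # Hypothetical config keys; CURL_COEF matches the total_loss formula in the checklist.
    return {
        "CURL_COEF": args.curl_coef,
        "MOMENTUM_TAU": args.momentum_tau,
        "CROP_SIZE": args.crop_size,
    }
```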