20/04/2026 10:09 AM
⬅️ [17/04/2026 12:45 PM - ReID More explantation](<./17_04_2026 12_45 PM - ReID More explantation.md>) | ⬆️ [2026 - April](<./README.md>) | [21/04/2026 3:22 PM - VIGIL Projection Meeting](<./21_04_2026 3_22 PM - VIGIL Projection Meeting.md>) ➡️
I need to get the trees working better. At the very least, I should roll out trees from the early RL tests. But really I need to do the RL project, because I don't know if anyone else is going to.
I'm thinking of implementing CURL. It was designed for off-policy methods, but we collect so many environment steps that I expect it to work with PPO as well.
Got Gemini to create a list of things that need to be updated:
1. Model Definitions (`models.py`)
- [ ] Create a separate `Encoder` class: Extract the convolutional layers and flattening step into their own `nn.Module`.
- [ ] Modify Actor/Critic to accept embeddings: Change `ActorCriticConv` to accept the flattened output of the `Encoder` rather than raw images.
- [ ] Add a `CURLHead` class: Add a bilinear projection layer (e.g., a trainable parameter matrix `W`) to project the query embedding into the key embedding space.
2. Network Initialization (`make_train`)
- [ ] Initialize the separate components: Initialize the `Encoder`, the `ActorCritic` heads, and the `CURLHead`.
- [ ] Create a unified TrainState: Bundle the `Encoder`, `ActorCritic`, and `CURLHead` parameters into a single `TrainState` for joint backpropagation.
- [ ] Initialize the Momentum (Target) Encoder: Create a copy of the `Encoder` parameters to serve as the momentum encoder. Store it in the runner state, not in the Flax `TrainState`, so the Optax optimizer never updates it.
3. Data Augmentation Setup
- [ ] Add Augmentation Logic: Bring in a JAX-compatible random crop function (e.g., `dm_pix.random_crop`, or a small custom crop built on `jax.lax.dynamic_slice`; `jax.image` has resizing utilities but no random crop).
- [ ] Define Crop Sizes: Define the random crop size as a configuration variable (e.g., 64x64 for an 84x84 base observation).
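A minimal per-image random crop in plain JAX, assuming NHWC observations; `random_crop` is a hypothetical helper, not a library function. It samples one top-left corner per image and slices with `jax.lax.dynamic_slice` under `jax.vmap`, so every image in the batch gets its own crop.

```python
import jax
import jax.numpy as jnp


def random_crop(key, obs, crop_size):
    """Independently random-crop each image in a batch.

    obs: (batch, H, W, C) -> returns (batch, crop_size, crop_size, C).
    crop_size must be a Python int (slice sizes are static in JAX).
    """
    batch, h, w, c = obs.shape
    k_y, k_x = jax.random.split(key)
    # Top-left corner per image; maxval is exclusive, so the crop stays in bounds.
    ys = jax.random.randint(k_y, (batch,), 0, h - crop_size + 1)
    xs = jax.random.randint(k_x, (batch,), 0, w - crop_size + 1)

    def crop_one(img, y, x):
        return jax.lax.dynamic_slice(img, (y, x, 0), (crop_size, crop_size, c))

    return jax.vmap(crop_one)(obs, ys, xs)
```

Under `jax.jit`, `crop_size` has to be marked static, e.g. `jax.jit(random_crop, static_argnums=2)`.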
4. Trajectory Collection (`_env_step`)
- [ ] Augment before acting: Apply a random crop to `last_obs` before passing it to the network to get `pi` and `value`.
- [ ] Store Unaugmented Obs in Transition: Store the unaugmented observations in the `Transition` tuple to allow for fresh random crops during PPO updates.
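A sketch of the acting step, assuming discrete actions and a categorical policy. `crop_fn` and `policy_fn` are hypothetical callables standing in for the crop helper and the encoder + actor-critic forward pass, and the `Transition` fields are trimmed down (no reward/done). The point is that the network sees the crop while the stored `obs` is the raw frame.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class Transition(NamedTuple):
    obs: jnp.ndarray       # raw (uncropped) obs, re-cropped during PPO updates
    action: jnp.ndarray
    value: jnp.ndarray
    log_prob: jnp.ndarray


def act_on_crop(key, last_obs, crop_fn: Callable, policy_fn: Callable):
    """Act on a random crop, but record the uncropped observation."""
    crop_key, sample_key = jax.random.split(key)
    obs_crop = crop_fn(crop_key, last_obs)
    logits, value = policy_fn(obs_crop)
    action = jax.random.categorical(sample_key, logits)  # samples along the last axis
    log_probs = jax.nn.log_softmax(logits)
    log_prob = jnp.take_along_axis(log_probs, action[:, None], axis=-1).squeeze(-1)
    return action, Transition(obs=last_obs, action=action, value=value, log_prob=log_prob)
```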
5. The Loss Function (`_loss_fn` inside `_update_minbatch`)
- [ ] Generate Pairs: Apply the random crop function twice, with different random keys, to `traj_batch.obs` to generate `obs_query` and `obs_key`.
- [ ] Forward Pass (Online): Pass `obs_query` through the online `Encoder`, then pass the resulting `z_query` through the `ActorCritic` heads to calculate the standard PPO losses.
- [ ] Forward Pass (Momentum): Pass `obs_key` through the momentum encoder to get `z_key` (use `jax.lax.stop_gradient` to prevent backpropagation).
- [ ] Calculate InfoNCE Loss: Pass `z_query` through the `CURLHead` matrix `W`, compute logits via matrix multiplication with `z_key.T`, and calculate categorical cross-entropy with the diagonal as the targets (row `i`'s positive is key `i`; the other keys in the batch are negatives).
- [ ] Combine Losses: Add the CURL loss to the `total_loss` objective (`total_loss = ppo_loss + config["CURL_COEF"] * curl_loss`).
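A minimal sketch of the InfoNCE step in plain JAX, assuming `z_query` and `z_key` are `(B, D)` embeddings and `W` is the `(D, D)` `CURLHead` matrix; `curl_infonce_loss` is a hypothetical name. Cross-entropy with integer labels `0..B-1` reduces to averaging the diagonal of the row-wise log-softmax.

```python
import jax
import jax.numpy as jnp


def curl_infonce_loss(z_query, z_key, W):
    """CURL contrastive loss with a bilinear similarity.

    Row i's positive is column i; every other key in the batch is a negative.
    """
    z_key = jax.lax.stop_gradient(z_key)      # no gradient into the momentum encoder
    logits = (z_query @ W) @ z_key.T          # (B, B) similarity matrix
    log_probs = jax.nn.log_softmax(logits, axis=1)
    idx = jnp.arange(z_query.shape[0])
    return -jnp.mean(log_probs[idx, idx])     # cross-entropy against the diagonal
```

In `_loss_fn` this value would be the `curl_loss` term in `total_loss = ppo_loss + config["CURL_COEF"] * curl_loss`. `optax.softmax_cross_entropy_with_integer_labels(logits, idx).mean()` should compute the same quantity if Optax is preferred.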
6. The Momentum Update
- [ ] Implement EMA: Write a Polyak averaging update using `jax.tree_util.tree_map` at the end of `_update_minbatch`.
- [ ] Apply Formula: `momentum_params = tau * momentum_params + (1 - tau) * online_encoder_params`.
- [ ] Pass State: Ensure the updated `momentum_params` are returned from `_update_minbatch` and threaded through the `jax.lax.scan` loops.
7. Arguments & Configuration
- [ ] Add `--curl_coef`: The weight of the CURL loss in the total loss.
- [ ] Add `--momentum_tau`: The EMA decay rate, i.e. the fraction of the old momentum weights kept at each update (e.g., 0.99 or 0.999).
- [ ] Add `--crop_size`: The size of the random crop.
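A sketch of the flags with stdlib `argparse`. The default values and the `MOMENTUM_TAU` / `CROP_SIZE` config keys are assumptions; only `CURL_COEF` appears in the loss above. `parse_curl_args` and `to_config` are hypothetical helpers.

```python
import argparse


def parse_curl_args(argv=None):
    parser = argparse.ArgumentParser(description="PPO + CURL options (sketch)")
    parser.add_argument("--curl_coef", type=float, default=1.0,
                        help="Weight of the CURL loss in the total loss")
    parser.add_argument("--momentum_tau", type=float, default=0.99,
                        help="EMA decay for the momentum encoder")
    parser.add_argument("--crop_size", type=int, default=64,
                        help="Side length of the square random crop")
    return parser.parse_args(argv)


def to_config(args):
    # Map CLI flags onto the config dict read inside the training code.
    return {
        "CURL_COEF": args.curl_coef,
        "MOMENTUM_TAU": args.momentum_tau,
        "CROP_SIZE": args.crop_size,
    }
```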