Use IDTA to extract sparse features in real time, and show higher performance than fine-tuning despite using less real-time compute.
Results
- A single layer did not work well
- Placing the activation in the middle (between two layers) worked better
- tanh worked better than sigmoid
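These findings map onto the `deep` and `act` options of `PolicyNetwork` further down. Below is a stripped-down sketch of the two head shapes being compared. `PolicyHead` is a hypothetical name, and the sketch only mirrors the `fc1 → act → fc2` layout, not the full policy.

```python
import torch
import torch.nn as nn


class PolicyHead(nn.Module):
    """Hypothetical minimal head: latent activations -> SAE-dictionary logits."""

    def __init__(self, latent_dim: int, dict_size: int, deep: bool = True, act: str = "tanh"):
        super().__init__()
        acts = {"tanh": nn.Tanh(), "sigmoid": nn.Sigmoid()}
        self.deep = deep
        # Single layer: latent -> dict directly. Deep: latent -> latent, then act, then latent -> dict.
        self.fc1 = nn.Linear(latent_dim, latent_dim if deep else dict_size)
        self.act = acts[act]
        self.fc2 = nn.Linear(latent_dim, dict_size) if deep else None

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        x = self.act(self.fc1(obs))
        # In the deep variant the activation sits between fc1 and fc2
        return self.fc2(x) if self.deep else x
```

In the single-layer case the activation bounds the output logits directly (tanh to [-1, 1], sigmoid to [0, 1]). In the deep case `fc2` can rescale them freely.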
Target
Revised at 1:45
```
~/cloudfiles/code/Users/Seonglae.Cho/corr-steer main *1 azureml_py38 azureuser@a100research 22:34:42
❯ python train.py train --layer=global --task=harmbench
/anaconda/envs/azureml_py38/lib/python3.10/site-packages/pydantic/_internal/_fields.py:198: UserWarning: Field name "validate" in "CorrConfig" shadows an attribute in parent "BaseModel"
  warnings.warn(
Loading checkpoint shards: 100%|██████████| 2/2 [01:00<00:00, 30.08s/it]
Training correlations for layers: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25]
Collecting correlations:   0%|
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Device set to use cuda:0
Collecting correlations:   0%|▏
You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a datasets: 8it [00:16, 2.06s/it]
Collecting correlations:   3%|█▋ | 108/4000 [01:10<42:14, 1.54samples/s]
Layer 1: pos 9572 r=0.6924, neg None
Layer 2: pos 6712 r=0.6920, neg None
Layer 3: pos 16207 r=0.6858, neg None
Layer 4: pos 3109 r=0.6960, neg None
Layer 5: pos 11099 r=0.7374, neg None
Layer 6: pos 12241 r=0.7345, neg None
Layer 7: pos 11722 r=0.7794, neg None
Layer 8: pos 8642 r=0.7455, neg None
Layer 9: pos 9298 r=0.7751, neg None
Layer 10: pos 3037 r=0.7230, neg None
Layer 11: pos 6905 r=0.7349, neg None
Layer 12: pos 12039 r=0.7407, neg None
Layer 13: pos 6715 r=0.7092, neg None
Layer 14: pos 2949 r=0.7391, neg None
Layer 15: pos 1570 r=0.7418, neg None
Layer 16: pos 5113 r=0.7427, neg None
Layer 17: pos 5887 r=0.7196, neg None
Layer 18: pos 1411 r=0.7119, neg None
Layer 19: pos 324 r=0.7102, neg None
Layer 20: pos 5192 r=0.7175, neg None
Layer 21: pos 7129 r=0.7211, neg None
Layer 22: pos 3311 r=0.7465, neg None
Layer 23: pos 11246 r=0.7108, neg None
Layer 24: pos 12773 r=0.6995, neg None
Layer 25: pos 3912 r=0.7106, neg None
Global best: Layer 7 using positive feature 11722 with correlation 0.7794
CorrSteer (global) saved to checkpoints/gemma2b_harmbench_global.json
Analyzing top correlation features...
Layer 1: Using positive feature 9572 with coefficient 5.2061 (corr=0.6924) [SAE]
Layer 2: Using positive feature 6712 with coefficient 5.6994 (corr=0.6920) [SAE]
Layer 3: Using positive feature 16207 with coefficient 2.5830 (corr=0.6858) [SAE]
Layer 4: Using positive feature 3109 with coefficient 5.8908 (corr=0.6960) [SAE]
Layer 5: Using positive feature 11099 with coefficient 16.9340 (corr=0.7374) [SAE]
Layer 6: Using positive feature 12241 with coefficient 7.3383 (corr=0.7345) [SAE]
Layer 7: Using positive feature 11722 with coefficient 5.0351 (corr=0.7794) [SAE]
Layer 8: Using positive feature 8642 with coefficient 8.7294 (corr=0.7455) [SAE]
Layer 9: Using positive feature 9298 with coefficient 7.5245 (corr=0.7751) [SAE]
Layer 10: Using positive feature 3037 with coefficient 6.6667 (corr=0.7230) [SAE]
Layer 11: Using positive feature 6905 with coefficient 13.8096 (corr=0.7349) [SAE]
Layer 12: Using positive feature 12039 with coefficient 5.2533 (corr=0.7407) [SAE]
Layer 13: Using positive feature 6715 with coefficient 6.9916 (corr=0.7092) [SAE]
Layer 14: Using positive feature 2949 with coefficient 16.6202 (corr=0.7391) [SAE]
Layer 15: Using positive feature 1570 with coefficient 23.8238 (corr=0.7418) [SAE]
Layer 16: Using positive feature 5113 with coefficient 21.8320 (corr=0.7427) [SAE]
Layer 17: Using positive feature 5887 with coefficient 11.3889 (corr=0.7196) [SAE]
Layer 18: Using positive feature 1411 with coefficient 20.5374 (corr=0.7119) [SAE]
Layer 19: Using positive feature 324 with coefficient 35.6101 (corr=0.7102) [SAE]
Layer 20: Using positive feature 5192 with coefficient 45.6623 (corr=0.7175) [SAE]
Layer 21: Using positive feature 7129 with coefficient 33.2255 (corr=0.7211) [SAE]
Layer 22: Using positive feature 3311 with coefficient 19.0001 (corr=0.7465) [SAE]
Layer 23: Using positive feature 11246 with coefficient 61.6424 (corr=0.7108) [SAE]
Layer 24: Using positive feature 12773 with coefficient 50.3317 (corr=0.6995) [SAE]
Layer 25: Using positive feature 3912 with coefficient 57.4309 (corr=0.7106) [SAE]
Evaluating: 100%|██████████| 280/280 [01:05<00:00, 4.26it/s]
Fixed feature accuracy: 67.50%
Results saved to checkpoints/gemma2b_harmbench_multi_25.json
Evaluation accuracy saved to checkpoints/gemma2b_harmbench_global_accuracy.json (accuracy=67.50%)
```
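The "Global best" line in the Target run matches simply taking the layer whose best positive feature has the highest r. Below is a minimal sketch of that selection, using the per-layer values copied from the log above. It assumes `train.py` does a plain argmax over r, which the log is consistent with but does not confirm. `global_best` and `TARGET_RUN` are hypothetical names.

```python
from typing import Dict, Tuple

# (best positive feature, r) per layer, copied from the Target run log
TARGET_RUN: Dict[int, Tuple[int, float]] = {
    1: (9572, 0.6924), 2: (6712, 0.6920), 3: (16207, 0.6858), 4: (3109, 0.6960), 5: (11099, 0.7374),
    6: (12241, 0.7345), 7: (11722, 0.7794), 8: (8642, 0.7455), 9: (9298, 0.7751), 10: (3037, 0.7230),
    11: (6905, 0.7349), 12: (12039, 0.7407), 13: (6715, 0.7092), 14: (2949, 0.7391), 15: (1570, 0.7418),
    16: (5113, 0.7427), 17: (5887, 0.7196), 18: (1411, 0.7119), 19: (324, 0.7102), 20: (5192, 0.7175),
    21: (7129, 0.7211), 22: (3311, 0.7465), 23: (11246, 0.7108), 24: (12773, 0.6995), 25: (3912, 0.7106),
}


def global_best(per_layer: Dict[int, Tuple[int, float]]) -> Tuple[int, int, float]:
    """Return (layer, feature, r) for the layer with the highest correlation."""
    layer = max(per_layer, key=lambda l: per_layer[l][1])
    feature, r = per_layer[layer]
    return layer, feature, r
```

Layer 9 (r=0.7751) is a close runner-up to layer 7 (r=0.7794), so the global choice is not clear-cut on these numbers.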
Current
```
~/cloudfiles/code/Users/Seonglae.Cho/ControlRL main ⇡1 +4 !3 · 11m 57s azureml_py38 azureuser@a100research 00:07:19
❯ python train.py train --task=harmbench --layers=all --eval --flatten
Loading checkpoint shards: 100%|██████████| 2/2 [00:49<00:00, 24.77s/it]
wandb: Currently logged in as: seonglae (texonom) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.21.1
wandb: Run data is saved locally in /mnt/batch/tasks/shared/LS_root/mounts/clusters/a100research/code/Users/Seonglae.Cho/ControlRL/wandb/run-20250822_002120-iu0mfhzj
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run gemma2b_harmbench_1_2_3_4_5_6_7_8_9_10_11_12_13_14_15_16_17_18_19_20_21_22_23_24_25_ppo_1e-05_0822_002120
wandb: ⭐️ View project at https://wandb.ai/texonom/control_rl
wandb: 🚀 View run at https://wandb.ai/texonom/control_rl/runs/iu0mfhzj
Training Steps:   0%|          | 0/14 [00:00<?, ?it/s]
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
You have set `use_cache` to `False`, but cache_implementation is set to hybrid. cache_implementation will have no effect.
Device set to use cuda:0
You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset
Step 0: Avg Train Acc 0.3750, Val Acc 0.5000
Layer 1: Policy Loss 0.0000, Critic Loss 0.0000, Grad Norms (P/C) 0.00/0.00, Recon Loss 0.0000, Unique Indices: 43, Avg Activation: 0.0000, Avg Act Values: 7.4332, Top Corr: 1.0000 (idx: 14790, coeff: 40.3346)
Layer 2: Policy Loss 0.0000, Critic Loss 0.0000, Grad Norms (P/C) 0.00/0.00, Recon Loss 0.0000, Unique Indices: 25, Avg Activation: 0.0000, Avg Act Values: 11.2021, Top Corr: 1.0000 (idx: 8863, coeff: 37.5957)
Layer 3: Policy Loss 0.0000, Critic Loss 0.0000, Grad Norms (P/C) 0.00/0.00, Recon Loss 0.0000, Unique Indices: 40, Avg Activation: 0.0000, Avg Act Values: 7.7847, Top Corr: 0.9998 (idx: 13505, coeff: 45.9490)
Layer 4: Policy Loss 0.0000, Critic Loss 0.0000, Grad Norms (P/C) 0.00/0.00, Recon Loss 0.0000, Unique Indices: 35, Avg Activation: 0.0000, Avg Act Values: 4.9902, Top Corr: 0.9999 (idx: 5786, coeff: 24.3388)
Layer 5: Policy Loss 0.0000, Critic Loss 0.0000, Grad Norms (P/C) 0.00/0.00, Recon Loss 0.0000, Unique Indices: 33, Avg Activation: 0.0000, Avg Act Values: 10.6308, Top Corr: 1.0000 (idx: 7569, coeff: 25.9536)
Layer 6: Policy Loss 0.0000, Critic Loss 0.0000, Grad Norms (P/C) 0.00/0.00, Recon Loss 0.0000, Unique Indices: 36, Avg Activation: 0.0000, Avg Act Values: 7.4637, Top Corr: 1.0000 (idx: 13114, coeff: 23.3237)
Layer 7: Policy Loss 0.0000, Critic Loss 0.0000, Grad Norms (P/C) 0.00/0.00, Recon Loss 0.0000, Unique Indices: 34, Avg Activation: 0.0000, Avg Act Values: 7.9720, Top Corr: 0.9999 (idx: 11358, coeff: 23.2175)
Layer 8: Policy Loss 0.0000, Critic Loss 0.0000, Grad Norms (P/C) 0.00/0.00, Recon Loss 0.0000, Unique Indices: 28, Avg Activation: 0.0000, Avg Act Values: 5.9078, Top Corr: 0.9999 (idx: 6221, coeff: 23.3399)
Layer 9: Policy Loss 0.0000, Critic Loss 0.0000, Grad Norms (P/C) 0.00/0.00, Recon Loss 0.0000, Unique Indices: 24, Avg Activation: 0.0000, Avg Act Values: 9.2464, Top Corr: 1.0000 (idx: 8675, coeff: 12.2697)
Layer 10: Policy Loss 0.0000, Critic Loss 0.0000, Grad Norms (P/C) 0.00/0.00, Recon Loss 0.0000, Unique Indices: 24, Avg Activation: 0.0000, Avg Act Values: 11.6061, Top Corr: 0.9998 (idx: 8361, coeff: 17.4053)
Layer 11: Policy Loss 0.0000, Critic Loss 0.0000, Grad Norms (P/C) 0.00/0.00, Recon Loss 0.0000, Unique Indices: 17, Avg Activation: 0.0000, Avg Act Values: 9.9951, Top Corr: 0.9999 (idx: 16251, coeff: 30.7651)
Layer 12: Policy Loss 0.0000, Critic Loss 0.0000, Grad Norms (P/C) 0.00/0.00, Recon Loss 0.0000, Unique Indices: 23, Avg Activation: 0.0000, Avg Act Values: 16.1377, Top Corr: 1.0000 (idx: 4854, coeff: 15.5613)
Layer 13: Policy Loss 0.0000, Critic Loss 0.0000, Grad Norms (P/C) 0.00/0.00, Recon Loss 0.0000, Unique Indices: 20, Avg Activation: 0.0000, Avg Act Values: 12.6093, Top Corr: 1.0000 (idx: 15254, coeff: 13.6998)
Layer 14: Policy Loss 0.0000, Critic Loss 0.0000, Grad Norms (P/C) 0.00/0.00, Recon Loss 0.0000, Unique Indices: 25, Avg Activation: 0.0000, Avg Act Values: 10.6163, Top Corr: 0.9999 (idx: 10643, coeff: 5.3389)
Layer 15: Policy Loss 0.0000, Critic Loss 0.0000, Grad Norms (P/C) 0.00/0.00, Recon Loss 0.0000, Unique Indices: 19, Avg Activation: 0.0000, Avg Act Values: 9.6818, Top Corr: 0.9999 (idx: 8902, coeff: 6.2690)
Layer 16: Policy Loss 0.0000, Critic Loss 0.0000, Grad Norms (P/C) 0.00/0.00, Recon Loss 0.0000, Unique Indices: 19, Avg Activation: 0.0000, Avg Act Values: 14.5731, Top Corr: 1.0000 (idx: 5113, coeff: 22.2766)
Layer 17: Policy Loss 0.0000, Critic Loss 0.0000, Grad Norms (P/C) 0.00/0.00, Recon Loss 0.0000, Unique Indices: 18, Avg Activation: 0.0000, Avg Act Values: 13.4663, Top Corr: 0.9997 (idx: 1200, coeff: 18.2989)
Layer 18: Policy Loss 0.0000, Critic Loss 0.0000, Grad Norms (P/C) 0.00/0.00, Recon Loss 0.0000, Unique Indices: 18, Avg Activation: 0.0000, Avg Act Values: 21.1193, Top Corr: 0.9999 (idx: 1504, coeff: 7.6450)
Layer 19: Policy Loss 0.0000, Critic Loss 0.0000, Grad Norms (P/C) 0.00/0.00, Recon Loss 0.0000, Unique Indices: 14, Avg Activation: 0.0000, Avg Act Values: 44.7919, Top Corr: 0.9996 (idx: 9637, coeff: 57.7203)
Layer 20: Policy Loss 0.0000, Critic Loss 0.0000, Grad Norms (P/C) 0.00/0.00, Recon Loss 0.0000, Unique Indices: 15, Avg Activation: 0.0000, Avg Act Values: 17.9265, Top Corr: 0.9997 (idx: 3423, coeff: 15.1552)
Layer 21: Policy Loss 0.0000, Critic Loss 0.0000, Grad Norms (P/C) 0.00/0.00, Recon Loss 0.0000, Unique Indices: 11, Avg Activation: 0.0000, Avg Act Values: 20.1886, Top Corr: 0.9992 (idx: 5834, coeff: 83.5779)
Layer 22: Policy Loss 0.0000, Critic Loss 0.0000, Grad Norms (P/C) 0.00/0.00, Recon Loss 0.0000, Unique Indices: 11, Avg Activation: 0.0000, Avg Act Values: 17.6585, Top Corr: 0.9999 (idx: 14848, coeff: 15.0354)
Layer 23: Policy Loss 0.0000, Critic Loss 0.0000, Grad Norms (P/C) 0.00/0.00, Recon Loss 0.0000, Unique Indices: 9, Avg Activation: 0.0000, Avg Act Values: 29.8452, Top Corr: 0.9994 (idx: 13403, coeff: 20.9856)
Layer 24: Policy Loss 0.0000, Critic Loss 0.0000, Grad Norms (P/C) 0.00/0.00, Recon Loss 0.0000, Unique Indices: 12, Avg Activation: 0.0000, Avg Act Values: 25.5663, Top Corr: 0.9999 (idx: 5380, coeff: 16.8423)
Layer 25: Policy Loss 0.0000, Critic Loss 0.0000, Grad Norms (P/C) 0.00/0.00, Recon Loss 0.0000, Unique Indices: 8, Avg Activation: 0.0000, Avg Act Values: 24.8593, Top Corr: 0.9999 (idx: 1558, coeff: 32.6608)
Training Steps: 100%|██████████| 14/14 [03:28<00:00, 14.89s/it]
=== Final Correlation Results ===
Layer 1: Using positive feature 1513 with coefficient 7.5780 (corr=0.7034) [SAE]
Layer 2: Using positive feature 6712 with coefficient 6.1642 (corr=0.7028) [SAE]
Layer 3: Using positive feature 16207 with coefficient 2.6573 (corr=0.7362) [SAE]
Layer 4: Using positive feature 3109 with coefficient 6.1323 (corr=0.7408) [SAE]
Layer 5: Using positive feature 11099 with coefficient 17.5734 (corr=0.7783) [SAE]
Layer 6: Using positive feature 12241 with coefficient 7.7448 (corr=0.7763) [SAE]
Layer 7: Using positive feature 11099 with coefficient 19.6308 (corr=0.7847) [SAE]
Layer 8: Using positive feature 8642 with coefficient 9.1598 (corr=0.7897) [SAE]
Layer 9: Using positive feature 9298 with coefficient 7.5653 (corr=0.7680) [SAE]
Layer 10: Using positive feature 5996 with coefficient 12.2603 (corr=0.7276) [SAE]
Layer 11: Using positive feature 6905 with coefficient 14.4484 (corr=0.7744) [SAE]
Layer 12: Using positive feature 13016 with coefficient 12.3013 (corr=0.7407) [SAE]
Layer 13: Using positive feature 6715 with coefficient 7.1451 (corr=0.7414) [SAE]
Layer 14: Using positive feature 2949 with coefficient 17.1668 (corr=0.7632) [SAE]
Layer 15: Using positive feature 1570 with coefficient 24.0848 (corr=0.7495) [SAE]
Layer 16: Using positive feature 5113 with coefficient 22.6561 (corr=0.7864) [SAE]
Layer 17: Using positive feature 14231 with coefficient 15.3567 (corr=0.7553) [SAE]
Layer 18: Using positive feature 1411 with coefficient 21.9191 (corr=0.7644) [SAE]
Layer 19: Using positive feature 324 with coefficient 36.2297 (corr=0.7207) [SAE]
Layer 20: Using positive feature 14645 with coefficient 15.4051 (corr=0.7253) [SAE]
Layer 21: Using positive feature 7129 with coefficient 34.9367 (corr=0.7407) [SAE]
Layer 22: Using positive feature 3311 with coefficient 19.7669 (corr=0.7797) [SAE]
Layer 23: Using positive feature 11246 with coefficient 64.2619 (corr=0.7479) [SAE]
Layer 24: Using positive feature 12433 with coefficient 62.1589 (corr=0.7288) [SAE]
Layer 25: Using positive feature 3912 with coefficient 60.3746 (corr=0.7381) [SAE]
Config
  model: gemma2b
  task: harmbench
  layers: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25]
  select_token: False
  decode: False
  category: None
  cot: False
Evaluating: 100%|██████████| 280/280 [02:11<00:00, 2.14it/s]
Final harmbench Accuracy with Steering: 45.71%
Results saved to ./checkpoints/gemma2b_harmbench_1_2_3_4_5_6_7_8_9_10_11_12_13_14_15_16_17_18_19_20_21_22_23_24_25_ppo_1e-05_0822_002120/harmbench_1_2_3_4_5_6_7_8_9_10_11_12_13_14_15_16_17_18_19_20_21_22_23_24_25_steered.json
Stats saved to ./checkpoints/gemma2b_harmbench_1_2_3_4_5_6_7_8_9_10_11_12_13_14_15_16_17_18_19_20_21_22_23_24_25_ppo_1e-05_0822_002120/harmbench_eval.json
Every outputs are saved to the folder ./checkpoints/gemma2b_harmbench_1_2_3_4_5_6_7_8_9_10_11_12_13_14_15_16_17_18_19_20_21_22_23_24_25_ppo_1e-05_0822_002120
wandb:
wandb: 🚀 View run gemma2b_harmbench_1_2_3_4_5_6_7_8_9_10_11_12_13_14_15_16_17_18_19_20_21_22_23_24_25_ppo_1e-05_0822_002120 at: https://wandb.ai/texonom/control_rl/runs/iu0mfhzj
wandb: Find logs at: ../../../../../../../mnt/batch/tasks/shared/LS_root/mounts/clusters/a100research/code/Users/Seonglae.Cho/ControlRL/wandb/run-20250822_002120-iu0mfhzj/logs
~/cloudfiles/code/Users/Seonglae.Cho/ControlRL main ⇡1 +4 !3 · 12m 29s azureml_py38 azureuser@a100research 00:30:21
❯
```
```
Layer 1: Using positive feature 1513 with coefficient 7.3448 (corr=0.7034) [SAE]
Layer 2: Using positive feature 6712 with coefficient 5.7849 (corr=0.7028) [SAE]
Layer 3: Using positive feature 16207 with coefficient 2.6573 (corr=0.7362) [SAE]
Layer 4: Using positive feature 3109 with coefficient 6.1323 (corr=0.7408) [SAE]
Layer 5: Using positive feature 11099 with coefficient 17.5734 (corr=0.7783) [SAE]
Layer 6: Using positive feature 12241 with coefficient 7.7448 (corr=0.7763) [SAE]
Layer 7: Using positive feature 11099 with coefficient 19.6308 (corr=0.7847) [SAE]
Layer 8: Using positive feature 8642 with coefficient 9.1598 (corr=0.7897) [SAE]
Layer 9: Using positive feature 9298 with coefficient 7.5653 (corr=0.7680) [SAE]
Layer 10: Using positive feature 5996 with coefficient 12.2603 (corr=0.7276) [SAE]
Layer 11: Using positive feature 6905 with coefficient 14.4484 (corr=0.7744) [SAE]
Layer 12: Using positive feature 13016 with coefficient 12.3013 (corr=0.7407) [SAE]
Layer 13: Using positive feature 6715 with coefficient 7.1451 (corr=0.7414) [SAE]
Layer 14: Using positive feature 2949 with coefficient 16.9027 (corr=0.7632) [SAE]
Layer 15: Using positive feature 1570 with coefficient 24.0848 (corr=0.7495) [SAE]
Layer 16: Using positive feature 5113 with coefficient 22.6561 (corr=0.7864) [SAE]
Layer 17: Using positive feature 14231 with coefficient 15.3567 (corr=0.7553) [SAE]
Layer 18: Using positive feature 1411 with coefficient 21.2447 (corr=0.7644) [SAE]
Layer 19: Using positive feature 324 with coefficient 36.2297 (corr=0.7207) [SAE]
Layer 20: Using positive feature 14645 with coefficient 14.6941 (corr=0.7253) [SAE]
Layer 21: Using positive feature 7129 with coefficient 34.3992 (corr=0.7407) [SAE]
Layer 22: Using positive feature 3311 with coefficient 19.4628 (corr=0.7797) [SAE]
Layer 23: Using positive feature 11246 with coefficient 63.2732 (corr=0.7479) [SAE]
Layer 24: Using positive feature 12433 with coefficient 61.2026 (corr=0.7288) [SAE]
Layer 25: Using positive feature 3912 with coefficient 59.4457 (corr=0.7381) [SAE]
```


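More effective sparse selection architectures
- Token decay, or simply multiplying by the correlation instead of adding it: performance stayed at the 73 previously reached with a value of 100, so this is better because it removes a hyperparameter. The same features are still selected.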
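The difference between adding and multiplying the correlation boost can be sketched as follows. The additive form needs a separate scale to balance the correlations against the policy logits. The multiplicative form reweights each logit by its correlation and has no extra knob. `boost_add`, `boost_mul` and `scale` are hypothetical names for illustration. They are not the `selection_boost` code used in the Stage 2 snippets below.

```python
import torch


def boost_add(raw_logits: torch.Tensor, corr: torch.Tensor, scale: float) -> torch.Tensor:
    # Additive boost: corr (dict_size,) is broadcast over (batch, seq_len, dict_size);
    # `scale` is the extra hyperparameter that must be tuned
    return raw_logits + scale * corr.view(1, 1, -1)


def boost_mul(raw_logits: torch.Tensor, corr: torch.Tensor) -> torch.Tensor:
    # Multiplicative boost: each feature's logit is reweighted by its correlation
    return raw_logits * corr.view(1, 1, -1)


raw = torch.tensor([[[0.5, 0.2, 0.9]]])
corr = torch.tensor([0.1, 0.9, 0.3])
print(boost_mul(raw, corr))              # ≈ [[[0.05, 0.18, 0.27]]]
print(boost_add(raw, corr, scale=10.0))  # ≈ [[[1.5, 9.2, 3.9]]]
```

One caveat with the multiplicative form: with a tanh policy the logits can be negative. A negative logit multiplied by a larger correlation becomes more negative, which inverts the intended preference for those features.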
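Should the cases of negative correlation and negative logits be considered?
When updating the correlation during actual steering, the correlation should be computed from the selected SAE feature and the coefficient that was actually added.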
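One possible reading of this note: across samples, correlate the coefficient actually added for each SAE feature (0 where the feature was not selected) with the reward. Below is a minimal Pearson-correlation sketch under that assumption. `feature_reward_corr` and the `(num_samples, dict_size)` layout are hypothetical and do not come from the training code.

```python
import torch


def feature_reward_corr(applied: torch.Tensor, rewards: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Pearson r between each feature's applied coefficient and the reward.

    applied: (num_samples, dict_size), coefficient added per feature (0 if not selected)
    rewards: (num_samples,)
    Returns: (dict_size,); features never selected get r = 0
    """
    a = applied - applied.mean(dim=0, keepdim=True)
    r = rewards - rewards.mean()
    cov = (a * r.unsqueeze(1)).sum(dim=0)
    denom = a.pow(2).sum(dim=0).sqrt() * r.pow(2).sum().sqrt()
    return cov / (denom + eps)
```

Because `applied` holds the added coefficient rather than a 0/1 selection indicator, a feature that helps only when steered strongly will show a higher r than one that helps regardless of strength.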
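- Decay the steering activation by 0.99 or 0.95 per step? Performance was maintained (73.21), but because the generated token lengths are short it made little difference.
- We could just mask to the features that are active at the current token (in the SAE encoding), as long as performance holds.
- Or the inverse: mask out the currently active features so that only new features get added.
- Where the correlation is multiplied in, we could also use either 1 or the correlation itself.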
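The options above can be combined into one weighting function. The function below is a sketch of one interpretation:
- decay is geometric in the generation step;
- "active" masking keeps only features the SAE encodes as positive at the current token;
- "inactive" masking keeps only the others;
- `binary=True` replaces the correlation with 1 for every feature with positive correlation.

`steering_weights` and all of its parameters are hypothetical names. `select_action` does take a `generation_step` argument, but how it is used is not shown here.

```python
import torch


def steering_weights(
    selection_weights: torch.Tensor,  # (dict_size,) correlation per SAE feature
    encoded: torch.Tensor,  # (batch, seq_len, dict_size) SAE activations at the current tokens
    generation_step: int = 0,
    decay: float = 0.99,
    mode: str = "active",  # "active", "inactive", or "none"
    binary: bool = False,
) -> torch.Tensor:
    # Option: weight every positively correlated feature by 1 instead of by its correlation
    if binary:
        w = (selection_weights > 0).to(encoded.dtype)
    else:
        w = selection_weights.to(encoded.dtype)
    w = w.view(1, 1, -1)
    active = (encoded > 0).to(encoded.dtype)
    if mode == "active":
        w = w * active  # boost only features already firing on this token
    elif mode == "inactive":
        w = w * (1 - active)  # boost only features not yet firing, to add something new
    elif mode != "none":
        raise ValueError(f"unknown mode: {mode}")
    # Geometric decay of steering strength over generation steps (e.g. 0.99 or 0.95)
    return w * (decay ** generation_step)


sel = torch.tensor([0.5, 0.8, 0.0])
enc = torch.tensor([[[1.0, 0.0, 2.0]]])
print(steering_weights(sel, enc, mode="active"))    # [[[0.5, 0.0, 0.0]]]
print(steering_weights(sel, enc, mode="inactive"))  # [[[0.0, 0.8, 0.0]]]
```

With decay 0.99, the steering is still at about 0.99^50 ≈ 0.6 of its initial strength after 50 tokens. This fits the observation that decay barely matters for short generations.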
1. Gumbel Softmax + Top-K Selection
Performance dropped.
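```python
import torch
import torch.nn as nn


class GumbelTopK(nn.Module):
    def __init__(self, temperature=1.0, k=1):
        super().__init__()  # required before assigning attributes on an nn.Module
        self.temperature = temperature
        self.k = k

    def forward(self, logits):
        # Gumbel noise -log(-log(U)), U ~ Uniform[0, 1); clamp avoids log(0)
        u = torch.rand_like(logits).clamp_min(torch.finfo(logits.dtype).tiny)
        gumbel_noise = -torch.log(-torch.log(u))
        noisy_logits = (logits + gumbel_noise) / self.temperature
        # Hard top-k on the perturbed logits: returns (values, indices);
        # the values carry gradients, the selected indices do not
        return torch.topk(noisy_logits, self.k, dim=-1)
```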
2. Sparse Attention Mechanism
Performance maintained.
```python
import torch
import torch.nn as nn


class SparseAttention(nn.Module):
    def __init__(self, dim, sparsity=0.1):
        super().__init__()  # required before registering submodules
        self.attention = nn.Linear(dim, dim)
        self.sparsity = sparsity

    def forward(self, x, correlation_weights):
        attn_weights = torch.softmax(self.attention(x), dim=-1)
        # Apply sparsity mask based on correlation: keep the top `sparsity` fraction of features
        sparse_mask = correlation_weights > correlation_weights.quantile(1 - self.sparsity)
        return attn_weights * sparse_mask
```
3. Straight-Through Estimator (STE)
Performance maintained.
```python
import torch
import torch.nn as nn


class StraightThroughTopK(nn.Module):
    def forward(self, logits, k=1):
        # Forward: discrete top-k (1.0 at the k selected entries, 0 elsewhere)
        _, indices = torch.topk(logits, k, dim=-1)
        y_hard = torch.zeros_like(logits).scatter_(-1, indices, 1.0)
        # Backward: continuous gradients through the softmax surrogate
        y_soft = torch.softmax(logits, dim=-1)
        return y_hard - y_soft.detach() + y_soft
```
4. Learnable Sparse Gates
Performance maintained.
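```python
import torch
import torch.nn as nn


class SparseGate(nn.Module):
    def __init__(self, dim):
        super().__init__()  # required before registering parameters
        self.gate = nn.Parameter(torch.ones(dim))
        self.threshold = nn.Parameter(torch.tensor(0.5))

    def forward(self, x, correlation_bias=None):
        gate_scores = torch.sigmoid(self.gate)
        if correlation_bias is not None:
            gate_scores = gate_scores + correlation_bias
        # Sparsity threshold; note the hard comparison passes no gradient
        # to `gate` or `threshold`, so neither is actually updated by backprop here
        sparse_mask = (gate_scores > self.threshold).float()
        return x * sparse_mask
```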
1. Gumbel Softmax: differentiable discrete sampling
2. Sparse Attention: uses the correlation as attention weights
3. STE: discrete selection + continuous gradients
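The STE trick in item 3 can be checked numerically. In the forward pass the output equals the hard one-hot selection. In the backward pass the gradient equals that of the softmax. `straight_through_topk` below is a standalone copy of the `StraightThroughTopK.forward` logic for illustration.

```python
import torch


def straight_through_topk(logits: torch.Tensor, k: int = 1) -> torch.Tensor:
    _, idx = torch.topk(logits, k, dim=-1)
    y_hard = torch.zeros_like(logits).scatter(-1, idx, 1.0)
    y_soft = torch.softmax(logits, dim=-1)
    # Value equals y_hard; only y_soft contributes to the gradient
    return y_hard - y_soft.detach() + y_soft


logits = torch.tensor([[1.0, 3.0, 2.0]], requires_grad=True)
y = straight_through_topk(logits)
print(y)  # ≈ [[0., 1., 0.]]: hard selection of index 1
y[0, 1].backward()  # gradient of the selected slot w.r.t. all logits
print(logits.grad)  # equals d softmax_1 / d logits = p_1 * (e_1 - p)
```

The hard selection still receives the softmax gradient, so every logit gets a learning signal even though only one feature is selected. This may be why it kept performance while the Gumbel variant did not.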
or simply
- Token-wise context-dependent correlation → decreased
```python
else:
    # Stage 2: PPO + context-dependent correlation boost
    raw_logits = self.fc1(observation)  # (batch, seq_len, dict_size)
    if selection_weights is not None:
        # Make correlation boost context-dependent by scaling with observation magnitude
        obs_magnitude = torch.norm(observation, dim=-1, keepdim=True)  # (batch, seq_len, 1)
        context_scale = torch.sigmoid(obs_magnitude)  # norm >= 0, so this lies in [0.5, 1)
        selection_boost = selection_weights.unsqueeze(0).unsqueeze(0) * context_scale  # (batch, seq_len, dict_size)
        raw_logits = raw_logits + selection_boost
```
- Token-position linear weighting → same (emphasizing later tokens actually lowered it)
```python
else:
    # Stage 2: PPO + token-wise correlation boost
    raw_logits = self.fc1(observation)  # (batch, seq_len, dict_size)
    if selection_weights is not None:
        # Method 1: Position-dependent correlation (different per token)
        batch_size, seq_len, dict_size = raw_logits.shape
        # Create position-dependent correlation weights
        # Use a simple linear combination based on token position
        position_weights = torch.linspace(0.1, 1.0, seq_len, device=raw_logits.device)
        position_weights = position_weights.view(1, seq_len, 1)  # (1, seq_len, 1)
        # Apply position-dependent correlation boost
        selection_boost = selection_weights.unsqueeze(0).unsqueeze(0) * position_weights
        raw_logits = raw_logits + selection_boost
```
- Attention-based correlation weighting → same
```python
else:
    # Stage 2: PPO + adaptive correlation boost
    raw_logits = self.fc1(observation)  # (batch, seq_len, dict_size)
    if selection_weights is not None:
        # Option 1: Simple scaling (existing approach)
        # selection_boost = selection_weights.unsqueeze(0).unsqueeze(0)
        # Option 2: Context-dependent scaling
        # Compute attention weights based on PPO logits strength
        ppo_strength = torch.norm(raw_logits, dim=-1, keepdim=True)  # (batch, seq_len, 1)
        attention_weights = torch.softmax(ppo_strength.squeeze(-1), dim=-1).unsqueeze(-1)  # (batch, seq_len, 1)
        # Apply correlation boost proportional to PPO confidence
        selection_boost = selection_weights.unsqueeze(0).unsqueeze(0) * attention_weights
        raw_logits = raw_logits + selection_boost
```
- Learnable mixing parameter → same
```python
else:
    # Stage 2: PPO + learnable correlation boost
    raw_logits = self.fc1(observation)  # (batch, seq_len, dict_size)
    if selection_weights is not None:
        # Apply learnable correlation boost
        selection_boost = selection_weights.unsqueeze(0).unsqueeze(0) * self.mixing_weight
        raw_logits = raw_logits + selection_boost
    epsilon = torch.randn_like(raw_logits) * self.epsilon
    raw_logits_noisy = raw_logits + epsilon
    if isinstance(self.act, JumpReLU):
        logits_noisy = self.act(raw_logits_noisy, critic_values)
    else:
```
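The learnable-mixing snippet reads `self.mixing_weight`, whose declaration is not shown. A minimal sketch of how such a scalar could be registered and trained follows. It assumes a single `nn.Parameter` initialised to 1.0. `BoostedPolicy` and `init_mix` are hypothetical names.

```python
import torch
import torch.nn as nn


class BoostedPolicy(nn.Module):
    """Hypothetical policy head with a learnable scalar correlation-boost weight."""

    def __init__(self, latent_dim: int, dict_size: int, init_mix: float = 1.0):
        super().__init__()
        self.fc1 = nn.Linear(latent_dim, dict_size)
        # Single scalar shared across all features; trained jointly with fc1
        self.mixing_weight = nn.Parameter(torch.tensor(init_mix))

    def forward(self, observation: torch.Tensor, selection_weights: torch.Tensor) -> torch.Tensor:
        raw_logits = self.fc1(observation)  # (batch, seq_len, dict_size)
        selection_boost = selection_weights.view(1, 1, -1) * self.mixing_weight
        return raw_logits + selection_boost
```

A single scalar can only rescale the whole correlation vector. It cannot change which feature has the highest correlation, so the boost keeps favouring the same features unless `fc1` outweighs it. This is consistent with the "→ same" result.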
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch import Tensor
from typing import Optional


class JumpReLU(nn.Module):
    """JumpReLU with learnable thresholds and critic-based gating.

    Implements: JumpReLU(x, t) = x * H(x - exp(t)) where t is learnable.
    With critic-based gating, the critic value is added to the threshold,
    so higher critic values -> higher gate threshold (fewer active features).
    """

    def __init__(self, num_features: int, threshold: float = 5.0, critic_gating: bool = True):
        super(JumpReLU, self).__init__()
        # Single shared threshold parameter (in log space for stability);
        # the default 5.0 gives exp(5.0) ≈ 148. num_features is currently unused.
        self.log_threshold = nn.Parameter(torch.tensor(threshold))
        self.critic_gating = critic_gating
        # Single gating bias for critic-based threshold adjustment
        if critic_gating:
            self.gate_bias = nn.Parameter(torch.tensor(0.0))

    def forward(self, x: Tensor, critic_values: Optional[Tensor] = None) -> Tensor:
        """Forward pass with optional critic-based gating.

        Args:
            x: Input tensor (batch_size, seq_len, num_features)
            critic_values: Critic values for gating (batch_size, seq_len)
        Returns:
            output: Gated output tensor
        """
        # Compute base threshold from learned parameter (single shared threshold)
        threshold = torch.exp(self.log_threshold)  # scalar
        if self.critic_gating and critic_values is not None:
            adjusted_threshold = threshold + critic_values.unsqueeze(-1) + self.gate_bias  # (batch_size, seq_len, 1)
        else:
            adjusted_threshold = threshold
        # Note: the hard comparison passes no gradient to log_threshold or gate_bias
        gates = (x > adjusted_threshold).float()
        output = x * gates
        return output


class PolicyNetwork(nn.Module):
    fc1: nn.Module
    act: nn.Module
    topk: Tensor
    deep: bool
    epsilon: Tensor
    lastk: Tensor

    def __init__(
        self,
        latent_dim: int,
        dict_size: int,
        topk: int = 1,
        epsilon: float = 0.01,
        act: str = "tanh",
        deep: bool = False,
        lastk: int = 1,
    ):
        super(PolicyNetwork, self).__init__()
        self.deep = deep
        self.dict_size = dict_size
        self.fc1 = nn.Linear(latent_dim, latent_dim if deep else dict_size)
        if act == "tanh":
            self.act = nn.Tanh()
        elif act == "gelu":
            self.act = nn.GELU()
        elif act == "silu":
            self.act = nn.SiLU()
        elif act == "linear":
            self.act = nn.Identity()
        elif act == "relu":
            self.act = nn.ReLU()
        elif act == "leaky_relu":
            self.act = nn.LeakyReLU()
        elif act == "jumprelu":
            self.act = JumpReLU(dict_size, critic_gating=True)
        else:
            raise ValueError(f"Invalid activation function: {act}")
        if deep:
            self.fc2 = nn.Linear(latent_dim, dict_size)
        self.register_buffer("epsilon", torch.tensor(epsilon))
        self.register_buffer("topk", torch.tensor(topk))
        self.register_buffer("lastk", torch.tensor(lastk))

    def forward(self, obs: Tensor, critic_values: Optional[Tensor] = None, flatten: bool = False) -> Tensor:
        """Forward pass with optional critic values for gating."""
        if flatten:
            # Note: this width is read from obs (the latent width), not self.dict_size
            batch_size, seq_len, dict_size = obs.shape
            return torch.zeros(batch_size, seq_len, dict_size, device=obs.device, dtype=obs.dtype)
        logits: Tensor = self.fc1(obs)
        if isinstance(self.act, JumpReLU):
            activated = self.act(logits, critic_values)
        else:
            activated = self.act(logits)
        if self.deep:
            activated = self.fc2(activated)
        return activated

    def select_action(
        self,
        observation: Tensor,
        critic_values: Optional[Tensor] = None,
        selection_weights: Optional[Tensor] = None,
        coefficient_weights: Optional[Tensor] = None,
        stage2: bool = False,
        sae=None,
        layer_id=None,
        generation_step: int = 0,
    ) -> tuple[Tensor, Tensor]:
        """Select action with optional critic-based gating and correlation weighting."""
        # Stage 1: Use correlation weights for feature selection
        if coefficient_weights is not None and selection_weights is not None and not stage2:
            batch_size, seq_len = observation.size(0), observation.size(1)
            if sae is not None:
                hidden_dim = observation.shape[-1]
                observation_2d = observation.view(-1, hidden_dim)
                encoded_2d = sae.encode(observation_2d)  # (batch*seq_len, dict_size)
                dict_size = encoded_2d.shape[-1]
                encoded_features = encoded_2d.view(batch_size, seq_len, dict_size)
                active_feature_mask = (encoded_features > 0).float()  # (batch, seq_len, dict_size)
                masked_selection_weights = selection_weights.unsqueeze(0).unsqueeze(0) * active_feature_mask
                inactive_penalty = (1 - active_feature_mask) * (-1e9)
                logits_noisy = masked_selection_weights + inactive_penalty
            else:
                logits_noisy = selection_weights.unsqueeze(0).unsqueeze(0).expand(batch_size, seq_len, -1)
            # Apply JumpReLU if needed
            if isinstance(self.act, JumpReLU):
                epsilon = torch.randn_like(logits_noisy) * self.epsilon
                logits_noisy = logits_noisy + epsilon
                logits_noisy = self.act(logits_noisy, critic_values)
                # Apply coefficient weights only to active features
                active_mask = (logits_noisy > 0).float()
                coeff_expanded = coefficient_weights.unsqueeze(0).unsqueeze(0).expand(logits_noisy.size(0), logits_noisy.size(1), -1)
                action = logits_noisy * coeff_expanded * active_mask
            else:
                # Original topk selection
                effective_topk = min(int(self.topk), logits_noisy.size(-1))
                _, topk_indices = torch.topk(logits_noisy, k=effective_topk, dim=-1)
                steering_coeffs = torch.gather(
                    coefficient_weights.unsqueeze(0).unsqueeze(0).expand(logits_noisy.size(0), logits_noisy.size(1), -1),
                    dim=-1,
                    index=topk_indices,
                )
                action = torch.zeros_like(logits_noisy)
                action = action.scatter(dim=-1, index=topk_indices, src=steering_coeffs)
        # Stage 2: PPO logits
        else:
            raw_logits = self.fc1(observation)
            if selection_weights is not None:
                selection_boost = selection_weights.unsqueeze(0).unsqueeze(0)
                raw_logits = raw_logits + selection_boost
            epsilon = torch.randn_like(raw_logits) * self.epsilon
            raw_logits_noisy = raw_logits + epsilon
            if isinstance(self.act, JumpReLU):
                logits_noisy = self.act(raw_logits_noisy, critic_values)
            else:
                logits_noisy = self.act(raw_logits_noisy)
            if self.deep:
                logits_noisy = self.fc2(logits_noisy)
            # Apply SAE masking after activation
            if selection_weights is not None:
                if sae is not None:
                    batch_size, seq_len, hidden_dim = observation.shape
                    observation_2d = observation.view(-1, hidden_dim)
                    encoded_2d = sae.encode(observation_2d)  # (batch*seq_len, dict_size)
                    dict_size = encoded_2d.shape[-1]
                    encoded_features = encoded_2d.view(batch_size, seq_len, dict_size)
                    active_mask = (encoded_features > 0).float()
                    if isinstance(self.act, JumpReLU):
                        # For JumpReLU: zero out inactive features (no negative penalty)
                        masked_logits = logits_noisy * active_mask
                    else:
                        # For non-JumpReLU: use negative penalty
                        masked_logits = logits_noisy * active_mask + (1 - active_mask) * (-1e9)
                else:
                    masked_logits = logits_noisy
            else:
                masked_logits = logits_noisy
            if isinstance(self.act, JumpReLU):
                # DEBUG: Check if all logits are zero or negative
                active_mask = (masked_logits > 0).float()
                all_inactive = (active_mask.sum(dim=-1) == 0).float()
                print(f"DEBUG Stage2: All inactive ratio: {all_inactive.mean().item():.3f}, Active features: {active_mask.sum(dim=-1).mean().item():.1f}, Min/Max: {masked_logits.min().item():.2f}/{masked_logits.max().item():.2f}")
                action = masked_logits
            else:
                effective_topk = min(int(self.topk), masked_logits.size(-1))
                topk_vals, topk_indices = torch.topk(masked_logits, k=effective_topk, dim=-1)
                steering_coeffs = topk_vals
                action = torch.zeros_like(logits_noisy)
                action = action.scatter(dim=-1, index=topk_indices, src=steering_coeffs)
        probs = torch.softmax(logits_noisy, dim=-1)
        if coefficient_weights is not None or isinstance(self.act, JumpReLU):
            action_mask = (action > 0).float()
            log_prob = torch.log(probs + 1e-8) * action_mask
            log_prob = log_prob.sum(dim=-1) / (action_mask.sum(dim=-1) + 1e-8)
        else:
            chosen_probs = torch.gather(input=probs, dim=-1, index=topk_indices)
            log_prob = torch.log(chosen_probs + 1e-8).mean(dim=-1)
        return action, log_prob


class CriticNetwork(nn.Module):
    fc1: nn.Module
    act: nn.Module
    deep: bool

    def __init__(self, latent_dim, act="tanh", deep=False):
        super(CriticNetwork, self).__init__()
        self.deep = deep
        self.fc1 = nn.Linear(latent_dim, latent_dim if deep else 1)
        if act == "tanh":
            self.act = nn.Tanh()
        elif act == "relu":
            self.act = nn.ReLU()
        elif act == "leaky_relu":
            self.act = nn.LeakyReLU()
        elif act == "gelu":
            self.act = nn.GELU()
        elif act == "silu":
            self.act = nn.SiLU()
        elif act == "linear":
            self.act = nn.Identity()
        elif act == "jumprelu":
            self.act = JumpReLU(1)
        else:
            # Fail early instead of leaving self.act undefined
            raise ValueError(f"Invalid activation function: {act}")
        if self.deep:
            self.fc2 = nn.Linear(latent_dim, 1)

    def forward(self, obs: Tensor) -> Tensor:
        x = self.act(self.fc1(obs))
        if self.deep:
            x = self.fc2(x)
        return x


class PPOTrainer:
    policy: PolicyNetwork
    critic: CriticNetwork
    batch_size: int
    ppo_clip: float
    sigma: float = 0.1
    loss_type: str
    q: bool
    sae: object
    sparse: bool
    decode: bool
    multiple: int
    optimizer_policy: optim.Adam
    optimizer_critic: optim.Adam

    def __init__(
        self,
        policy: PolicyNetwork,
        critic: CriticNetwork,
        batch_size: int,
        ppo_clip: float,
        lr: float,
        sigma: float = 5.0,
        loss_type: str = "normal",
        q: bool = False,
        sae: object = None,
        sparse: bool = False,
        decode: bool = False,
        multiple: int = 1,
    ):
        self.policy = policy
        self.critic = critic
        self.batch_size = batch_size
        self.ppo_clip = ppo_clip
        self.sigma = sigma
        self.loss_type = loss_type
        self.q = q
        self.sae = sae
        self.sparse = sparse
        self.decode = decode
        self.multiple = multiple
        self.optimizer_policy = optim.Adam(self.policy.parameters(), lr=lr)
        self.optimizer_critic = optim.Adam(self.critic.parameters(), lr=lr)

    def compute_advantages(self, rewards: Tensor, values: Tensor, observations: Optional[Tensor] = None, actions: Optional[Tensor] = None) -> Tensor:
        if self.q and observations is not None and actions is not None:
            steered_obs = self.action2steer(observations, actions)
            with torch.enable_grad():
                steered_values = self.critic(steered_obs).squeeze(-1)
            return steered_values - values
        else:
            return rewards - values

    def train_step(
        self,
        observations: Tensor,
        actions: Tensor,
        rewards: Tensor,
        old_log_probs: Tensor,
        critic_values: Tensor,
        eos_position: Tensor,
    ) -> tuple[float, float, float, float]:
        think_lengths = eos_position.cpu().numpy().tolist()
        policy_losses = []
        critic_losses = []
        for i, think_length in enumerate(think_lengths):
            sample_rewards = rewards[i].repeat(think_length)
            sample_actions = actions[i, :think_length, :]
            sample_log_probs = old_log_probs[i, :think_length]
            sample_observations = observations[i, :think_length, :]  # 2D: (think_length, features)
            sample_critic_values = critic_values[i, :think_length]  # 1D: (think_length)
            if isinstance(self.policy.act, JumpReLU):
                sample_observations = sample_observations.unsqueeze(0)  # (1, think_length, features)
                sample_critic_values = sample_critic_values.unsqueeze(0)  # (1, think_length)
            sample_advantages = self.compute_advantages(sample_rewards, sample_critic_values, sample_observations, sample_actions)
            if self.loss_type == "softmax":
                if isinstance(self.policy.act, JumpReLU):
                    sample_logits = self.policy(sample_observations, sample_critic_values)
                    sample_logits = sample_logits.squeeze(0)  # (1, think_length, dict_size) -> (think_length, dict_size)
                else:
                    sample_logits = self.policy(sample_observations)
                sample_targets: Tensor = sample_actions.argmax(dim=-1)
                sample_logits_log_probs: Tensor = torch.log_softmax(sample_logits, dim=-1)
                sample_new_log_probs = sample_logits_log_probs.gather(-1, sample_targets.unsqueeze(-1)).squeeze(-1)
            elif self.loss_type == "normal":
                if isinstance(self.policy.act, JumpReLU):
                    sample_mean = self.policy(sample_observations, sample_critic_values)
                    sample_mean = sample_mean.squeeze(0)  # (1, think_length, dict_size) -> (think_length, dict_size)
                else:
                    sample_mean = self.policy(sample_observations)
                sample_mean_fp32: Tensor = sample_mean.float()
                sample_actions_fp32: Tensor = sample_actions.float()
                sample_sigma_fp32: Tensor = torch.ones_like(sample_mean_fp32) * self.sigma
                sample_dist: torch.distributions.Normal = torch.distributions.Normal(sample_mean_fp32, sample_sigma_fp32)
                log_prob_full: Tensor = sample_dist.log_prob(sample_actions_fp32)
                selected_mask: Tensor = (sample_actions_fp32 != 0).to(log_prob_full.dtype)
                sample_new_log_probs = (log_prob_full * selected_mask).sum(dim=-1)
            else:
                raise ValueError(f"Invalid loss type: {self.loss_type}")
            # Clamp log prob ratio to prevent extreme values
            log_ratio = torch.clamp(sample_new_log_probs - sample_log_probs, min=-10, max=10)
            sample_ratio: Tensor = torch.exp(log_ratio)
            sample_surr1: Tensor =
```
sample_ratio * sample_advantages sample_surr2: Tensor = torch.clamp(sample_ratio, 1.0 - self.ppo_clip, 1.0 + self.ppo_clip) * sample_advantages sample_policy_loss: Tensor = -torch.min(sample_surr1, sample_surr2).mean() # Use mean instead of sum if isinstance(self.policy.act, JumpReLU): sample_critic_values_flat = sample_critic_values.squeeze(0) # (1, think_length) -> (think_length) else: sample_critic_values_flat = sample_critic_values # Already (think_length,) sample_critic_loss: Tensor = nn.MSELoss()(sample_critic_values_flat, sample_rewards).mean() # Use mean instead of sum policy_losses.append(sample_policy_loss) critic_losses.append(sample_critic_loss) # Calculate average losses for logging total_policy_loss = sum(policy_losses) / len(policy_losses) total_critic_loss = sum(critic_losses) / len(critic_losses) self.optimizer_policy.zero_grad() self.optimizer_critic.zero_grad() total_loss = total_policy_loss + total_critic_loss total_loss.backward() torch.nn.utils.clip_grad_norm_(self.policy.parameters(), max_norm=1.0) torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=1.0) policy_grad_norm = float( sum( p.grad.data.norm(2).item() ** 2 for p in self.policy.parameters() if p.grad is not None ) ** 0.5 ) critic_grad_norm = float( sum( p.grad.data.norm(2).item() ** 2 for p in self.critic.parameters() if p.grad is not None ) ** 0.5 ) self.optimizer_policy.step() self.optimizer_critic.step() return ( total_policy_loss.item(), total_critic_loss.item(), policy_grad_norm, critic_grad_norm, ) def action2steer(self, observations: Tensor, actions: Tensor) -> Tensor: if self.decode: batch_size, seq_len, dict_size = actions.shape action_2d = actions.view(-1, dict_size) decoded_2d = self.sae.decode(action_2d) hidden_dim = decoded_2d.shape[-1] steering = decoded_2d.view(batch_size, seq_len, hidden_dim) * self.multiple else: steering = actions @ self.sae.W_dec * self.multiple return observations.clone() + steering def load_ppo_network( latent_dim: int, dict_size: 
int, device: str, lr: float, batch_size: int, ppo_clip: float, topk: int, epsilon: float, dtype: torch.dtype, critic_deep: bool, policy_deep: bool, lastk: int, act="tanh", sigma: float = 5.0, loss_type: str = "normal", q: bool = False, sae: object = None, sparse: bool = False, decode: bool = False, multiple: int = 1, **kwargs, # Accept but ignore l0_lambda and other unused parameters ) -> tuple[PolicyNetwork, CriticNetwork, PPOTrainer]: policy_net: PolicyNetwork = ( PolicyNetwork( latent_dim, dict_size, topk=topk, epsilon=epsilon, deep=policy_deep, lastk=lastk, act=act, ) .to(device) .to(dtype) ) critic_net: CriticNetwork = ( CriticNetwork(latent_dim, deep=critic_deep, act=act).to(device).to(dtype) ) ppo_trainer = PPOTrainer( policy_net, critic_net, batch_size=batch_size, ppo_clip=ppo_clip, lr=lr, sigma=sigma, loss_type=loss_type, q=q, sae=sae, sparse=sparse, decode=decode, multiple=multiple ) return policy_net, critic_net, ppo_trainer
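For the `"normal"` loss type, each iteration of `train_step` computes two things: a Gaussian log-probability that counts only the selected (non-zero) action dimensions, and PPO's clipped surrogate with the log-ratio clamped to ±10. The plain-Python sketch below reproduces both outside PyTorch. The function names are hypothetical, and the sketch assumes a fixed isotropic `sigma` as in `PPOTrainer`. It leaves out batching, autograd and the critic MSE term.

```python
import math
from typing import Sequence


def masked_normal_log_prob(mean: Sequence[float], action: Sequence[float], sigma: float) -> float:
    """Sum of Normal(mean, sigma) log-densities over the selected (non-zero) action dims only."""
    total = 0.0
    for mu, a in zip(mean, action):
        if a != 0:  # unselected dims are masked out, like `selected_mask` in train_step
            total += -((a - mu) ** 2) / (2 * sigma**2) - math.log(sigma) - 0.5 * math.log(2 * math.pi)
    return total


def clipped_surrogate_loss(
    new_log_probs: Sequence[float],
    old_log_probs: Sequence[float],
    advantages: Sequence[float],
    clip: float,
    max_log_ratio: float = 10.0,
) -> float:
    """Negated PPO clipped objective, averaged over the think tokens of one sample."""
    terms = []
    for new, old, adv in zip(new_log_probs, old_log_probs, advantages):
        # Clamp the log-ratio before exponentiating, as train_step does with min=-10, max=10
        log_ratio = max(-max_log_ratio, min(max_log_ratio, new - old))
        ratio = math.exp(log_ratio)
        clipped_ratio = max(1.0 - clip, min(1.0 + clip, ratio))
        terms.append(min(ratio * adv, clipped_ratio * adv))
    return -sum(terms) / len(terms)


# Hypothetical single-token example: only the first dictionary feature was selected
new_lp = masked_normal_log_prob(mean=[0.5, 0.1, -0.2], action=[1.0, 0.0, 0.0], sigma=5.0)
old_lp = new_lp - 0.05  # pretend the old policy was slightly less likely to pick this action
print(clipped_surrogate_loss([new_lp], [old_lp], advantages=[1.0], clip=0.2))
```

The mask is the key design choice. With a sparse action vector, summing the log-density over every dictionary feature would let thousands of unselected zero dimensions dominate the ratio. Masking keeps the ratio focused on the features that were actually steered.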
```python
        if selection_weights is not None:
            selection_boost = selection_weights.unsqueeze(0).unsqueeze(0)
            raw_logits = raw_logits + selection_boost
        epsilon = torch.randn_like(raw_logits) * self.epsilon
        raw_logits_noisy = raw_logits + epsilon
        if isinstance(self.act, JumpReLU):
            logits_noisy = self.act(raw_logits_noisy, critic_values)
        else:
            logits_noisy = self.act(raw_logits_noisy)
        if self.deep:
            logits_noisy = self.fc2(logits_noisy)
        # Apply SAE masking after activation
        if selection_weights is not None:
            if sae is not None:
                batch_size, seq_len, hidden_dim = observation.shape
                observation_2d = observation.view(-1, hidden_dim)
                encoded_2d = sae.encode(observation_2d)  # (batch*seq_len, dict_size)
                dict_size = encoded_2d.shape[-1]
                encoded_features = encoded_2d.view(batch_size, seq_len, dict_size)
                active_mask = (encoded_features > 0).float()
                if isinstance(self.act, JumpReLU):
                    # For JumpReLU: zero out inactive features (no negative penalty)
                    masked_logits = logits_noisy * active_mask
                else:
                    # For non-JumpReLU: use negative penalty
                    masked_logits = logits_noisy * active_mask + (1 - active_mask) * (-1e9)
            else:
```
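The SAE masking step at the end of this forward pass can be illustrated on a single token. The sketch below is plain Python. `sae_codes` stands in for one row of the `sae.encode(...)` output, and the helper names are hypothetical. The source builds the same mask over a `(batch, seq_len, dict_size)` tensor, and `-1e9` is its penalty value. The sketch shows how the two branches differ: the JumpReLU branch simply zeroes inactive features. The other branch pushes them to a large negative value, presumably so that a later softmax gives them (near) zero probability.

```python
import math
from typing import List


def mask_logits_by_active_features(
    logits: List[float], sae_codes: List[float], zero_out: bool, penalty: float = -1e9
) -> List[float]:
    """Restrict one token's policy logits to the SAE features active on that token.

    zero_out=True mirrors the JumpReLU branch (inactive features become 0);
    zero_out=False mirrors the other branch (inactive features get a large negative value).
    """
    masked = []
    for logit, code in zip(logits, sae_codes):
        active = 1.0 if code > 0 else 0.0  # same test as (encoded_features > 0)
        if zero_out:
            masked.append(logit * active)
        else:
            masked.append(logit * active + (1.0 - active) * penalty)
    return masked


def softmax(xs: List[float]) -> List[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical values: three dictionary features, and the SAE fires on features 0 and 2
logits = [0.3, -0.2, 0.9]
codes = [1.5, 0.0, 0.2]
print(mask_logits_by_active_features(logits, codes, zero_out=True))
print(softmax(mask_logits_by_active_features(logits, codes, zero_out=False)))
```

Zeroing fits outputs that are used directly as steering magnitudes, because an inactive feature then contributes nothing. Under a softmax, however, a zero logit still gets non-zero probability. That is why the penalty form is the one that actually excludes inactive features from selection.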
Seonglae Cho