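The problem with the pre hook: on the first forward pass all residuals come in and get modified,
but at test time the residuals of previous tokens are no longer modified.
So should I just rely on the fact that only the KV is cached, not the residual?
I knew this from before, but since tokens in the same layer do not affect each other anyway, sequential learning is not possible either.
Maybe because of layer norm, attaching a layer post hook makes almost no difference unless something very large is added, so I probably have to force-update the residual at every generation step.
If it is already being updated properly without adding it, i.e. the forced update is not needed, that is even more hopeless, because the results are bad.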
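A minimal sketch of why the hook stops touching earlier tokens, assuming cached decoding feeds only the newest position through the layer at each step. The `nn.Linear` is a stand-in for a decoder layer and `AddSteering` is a hypothetical hook, not the project's steering class.

```python
import torch
from torch import nn


class AddSteering:
    """Forward pre-hook that adds a fixed vector to every position it receives."""

    def __init__(self, vector: torch.Tensor) -> None:
        self.vector = vector
        self.seen_lengths: list[int] = []  # sequence length seen on each call

    def __call__(self, module: nn.Module, args: tuple):
        hidden = args[0]  # (batch, seq_len, d_model)
        self.seen_lengths.append(hidden.shape[1])
        # Returning a tuple replaces the positional inputs of the module.
        return (hidden + self.vector,) + tuple(args[1:])


d_model = 8
layer = nn.Linear(d_model, d_model)  # stand-in for a decoder layer (assumption)
hook = AddSteering(torch.ones(d_model))
handle = layer.register_forward_pre_hook(hook)

layer(torch.zeros(1, 5, d_model))      # prefill: all 5 prompt positions are steered
for _ in range(3):                     # cached decoding: one new position per step
    layer(torch.zeros(1, 1, d_model))
handle.remove()

print(hook.seen_lengths)  # [5, 1, 1, 1]
```

Under this assumption, earlier positions can only influence later tokens through the K/V that were computed when they originally passed through the hook; their residuals are never revisited.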
Improvement ideas
        self.policy = policy
        self.critic = critic
        self.batch_size = batch_size
        self.ppo_clip = ppo_clip
        self.sigma = sigma
        self.optimizer_policy = optim.Adam(self.policy.parameters(), lr=lr)
        self.optimizer_critic = optim.Adam(self.critic.parameters(), lr=lr)

    def compute_advantages(self, rewards: Tensor, values: Tensor) -> Tensor:
        return rewards - values

    def train_step(
        self,
        observations: Tensor,
        actions: Tensor,
        rewards: Tensor,
        old_log_probs: Tensor,
        critic_observations: Tensor,
    ) -> tuple[float, float, float, float]:
        mean: Tensor = self.policy(observations)
        sigma: Tensor = torch.ones_like(mean) * self.sigma
        dist: torch.distributions.Normal = torch.distributions.Normal(mean, sigma)
        new_log_probs: Tensor = dist.log_prob(actions).sum(dim=-1)
        values: Tensor = self.critic(critic_observations).squeeze(-1)
        rewards = rewards.unsqueeze(-1)
        advantages: Tensor = self.compute_advantages(rewards, values)
        ratio: Tensor = torch.exp(new_log_probs - old_log_probs)
        surr1: Tensor = ratio * advantages
        surr2: Tensor = (
            torch.clamp(ratio, 1.0 - self.ppo_clip, 1.0 + self.ppo_clip) * advantages
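The pasted fragment stops at `surr2`. As a framework-free sanity check of the arithmetic, here is a sketch of the textbook PPO clipped surrogate, assuming the loss is the negative mean of `min(surr1, surr2)`; that is the standard objective, not necessarily how the truncated code continues. The function names are hypothetical. One thing worth checking in the fragment itself: if `rewards` has shape (B,) and `values` is squeezed to (B,), then `rewards.unsqueeze(-1) - values` broadcasts to (B, B).

```python
import math


def clipped_surrogate(new_log_prob: float, old_log_prob: float,
                      advantage: float, clip: float = 0.2) -> float:
    """Per-sample PPO-clip term: min(ratio * A, clamp(ratio) * A)."""
    ratio = math.exp(new_log_prob - old_log_prob)
    clipped_ratio = min(max(ratio, 1.0 - clip), 1.0 + clip)
    surr1 = ratio * advantage
    surr2 = clipped_ratio * advantage
    return min(surr1, surr2)


def policy_loss(new_log_probs, old_log_probs, rewards, values, clip: float = 0.2) -> float:
    """Negative mean of the clipped surrogate, with advantage = reward - value."""
    terms = [
        clipped_surrogate(n, o, r - v, clip)
        for n, o, r, v in zip(new_log_probs, old_log_probs, rewards, values)
    ]
    return -sum(terms) / len(terms)


# ratio = 1.5 with clip = 0.2: a positive advantage is capped at 1.2 * A,
# while a negative advantage keeps the unclipped (more pessimistic) 1.5 * A.
print(clipped_surrogate(math.log(1.5), 0.0, 1.0))
print(clipped_surrogate(math.log(1.5), 0.0, -1.0))
```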
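- Dynamic coefficient or gating, so that steering is not kept at the minimum
- Encourage steering only when the critic value is low
- Or split by layer and introduce causality between layers
- Being parallelizable without causality does seem like a plus, though
- Still, it might be good to use the next layer's steering as the observation and steer the previous layer
- Put a limit on thinking in the early phase
- Gemma transcoder
- Instead of reasoning: bias, jailbreak, or hallucination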
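A rough sketch of the first two ideas (a dynamic coefficient that is gated by the critic value), assuming the critic outputs a scalar roughly in [0, 1]. `steering_gate`, `gated_coefficient`, the threshold, the temperature, and the maximum coefficient of 30.0 are all hypothetical choices for illustration.

```python
import math


def steering_gate(value: float, threshold: float = 0.5, temperature: float = 0.1) -> float:
    """Soft gate in (0, 1): near 1 when the critic value is well below the threshold."""
    z = (value - threshold) / temperature
    z = max(-60.0, min(60.0, z))  # keep exp() in a numerically safe range
    return 1.0 / (1.0 + math.exp(z))


def gated_coefficient(value: float, max_coeff: float = 30.0,
                      threshold: float = 0.5, temperature: float = 0.1) -> float:
    """Steering strength that shrinks toward 0 as the critic becomes confident."""
    return max_coeff * steering_gate(value, threshold, temperature)


for v in (0.0, 0.5, 1.0):
    print(v, round(gated_coefficient(v), 3))
```

A soft gate keeps the coefficient differentiable with respect to the critic value; a hard threshold would be simpler but gives no gradient.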
Problem solving: CoT coding
Hard to find the sample-wise thinking finishing index # Left padding for context & right padding for answer
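One way the per-sample finishing index could be computed, assuming every row is laid out as left padding, prompt, generated tokens, end marker, right padding, and that the padded prompt width is the same for all rows. `finishing_indices` and the token ids are hypothetical.

```python
def finishing_indices(sequences: list[list[int]], prompt_width: int, end_id: int) -> list[int]:
    """Return, per row, the first position at or after prompt_width holding end_id.

    Rows without an end marker get len(row), meaning the sample never finished.
    Searching only from prompt_width onward skips the left padding, so this
    also works when the pad id and the end id are the same token.
    """
    result = []
    for row in sequences:
        index = len(row)
        for pos in range(prompt_width, len(row)):
            if row[pos] == end_id:
                index = pos
                break
        result.append(index)
    return result


PAD, END = 0, 2
batch = [
    [PAD, PAD, 5, 6, 7, 8, END, PAD],   # finishes at position 6
    [5, 6, 7, 8, 9, END, PAD, PAD],     # finishes at position 5
    [PAD, 5, 6, 7, 8, 9, 10, 11],       # never finishes
]
print(finishing_indices(batch, prompt_width=4, end_id=END))  # [6, 5, 8]
```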
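Nasty performance bug
Here, with mmlu, max new token 1 and non-CoT, why on earth did the number of unique indices used in the first validation grow from 4 to 400?
- Checked out every commit and ran the experiment
- Compared the version that gave the original performance against every kind of diff
- Tried everything from the seed to the data and the answer judging
- transformer eager turned out to be the problem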
/cs/st/projects2/a/2/se/control-ai 24-5544 *2 ?1 ❯ python train.py train --eval --layers="24," --task="mmlu" --decode --select_token
Loading checkpoint shards: 100%|███████████████████████████████████████| 2/2 [00:03<00:00, 1.67s/it]
wandb: Currently logged in as: seonglae (texonom). Use `wandb login --relogin` to force relogin
wandb: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
wandb: Tracking run with wandb version 0.19.4
wandb: Run data is saved locally in /cs/student/projects2/aisd/2024/seongcho/control-ai/wandb/run-20250720_225934-o2h4d4te
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run gemma2b_mmlu_24_ppo_1e-05_0720_225934_30.0_select
wandb: ⭐️ View project at https://wandb.ai/texonom/control_rl
wandb: 🚀 View run at https://wandb.ai/texonom/control_rl/runs/o2h4d4te
Training Steps: 0%| | 0/501 [00:00<?, ?it/s]
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Step 0: Avg Train Acc 0.7500, Val Acc 0.5880
Layer 24: Policy Loss -0.0000, Critic Loss 0.9404, Grad Norms (P/C) 0.00/135.62, Recon Loss 36.8710, Unique Indices: 481, Avg Activation: 30.0000
Training Steps: 20%|█████████▏ | 100/501 [01:46<04:53, 1.36it/s]
Step 100: Avg Train Acc 0.6950, Val Acc 0.5920
Layer 24: Policy Loss 0.0661, Critic Loss 0.1462, Grad Norms (P/C) 0.00/25.43, Recon Loss 36.9874, Unique Indices: 481, Avg Activation: 30.0000
Training Steps: 40%|██████████████████▎ | 200/501 [03:30<04:08, 1.21it/s]
Step 200: Avg Train Acc 0.7150, Val Acc 0.5900
Layer 24: Policy Loss 0.0602, Critic Loss 0.1387, Grad Norms (P/C) 0.00/22.88, Recon Loss 38.2507, Unique Indices: 486, Avg Activation: 30.0000
Training Steps: 60%|███████████████████████████▌ | 300/501 [05:20<02:40, 1.25it/s]
Step 300: Avg Train Acc 0.7037, Val Acc 0.5940
Layer 24: Policy Loss 0.0519, Critic Loss 0.0420, Grad Norms (P/C) 0.00/22.74, Recon Loss 39.3205, Unique Indices: 484, Avg Activation: 30.0000
Training Steps: 80%|████████████████████████████████████▋ | 400/501 [07:03<00:45, 2.23it/s]
Step 400: Avg Train Acc 0.7037, Val Acc 0.5920
Layer 24: Policy Loss 0.1766, Critic Loss 0.1142, Grad Norms (P/C) 0.00/59.14, Recon Loss 39.3443, Unique Indices: 484, Avg Activation: 30.0000
Training Steps: 100%|█████████████████████████████████████████████▉| 500/501 [08:36<00:00, 2.82it/s]
Step 500: Avg Train Acc 0.7150, Val Acc 0.5940
Layer 24: Policy Loss 0.0893, Critic Loss 0.0963, Grad Norms (P/C) 0.00/21.85, Recon Loss 38.8676, Unique Indices: 485, Avg Activation: 30.0000
Training Steps: 100%|██████████████████████████████████████████████| 501/501 [08:49<00:00, 1.06s/it]
/cs/student/projects2/aisd/2024/seongcho/control-ai/eval.py:412: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  ckpt = TrainResult.model_validate(torch.load(checkpoint))
Config
model: gemma2b
task: mmlu
layers: [24]
select_token: True
decode: True
category: None
Evaluating: 100%|██████████████████████████████████████████████| 14042/14042 [05:56<00:00, 39.39it/s]
{585: 5132, 608: 4466, 586: 2255, 599: 2189}
Final mmlu Accuracy with Steering: 55.44%
Results saved to ./checkpoints/gemma2b_mmlu_24_ppo_1e-05_0720_225934_30.0_select/mmlu_24_steered.json
Stats saved to ./checkpoints/gemma2b_mmlu_24_ppo_1e-05_0720_225934_30.0_select/mmlu_eval.json
wandb:
wandb: 🚀 View run gemma2b_mmlu_24_ppo_1e-05_0720_225934_30.0_select at: https://wandb.ai/texonom/control_rl/runs/o2h4d4te
wandb: Find logs at: wandb/run-20250720_225934-o2h4d4te/logs
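Set max new tokens per task dataset in @config.py: bbq and mmlu get 1, gsm8k gets 1024, and for the newly added datasets wmdp gets 1, tofu 128, and simpleqa 32.

Data structure of each dataset:
- wmdp is the closest to the existing bbq and mmlu: simple question and choices columns, and the answer is 0, 1, 2, 3, almost identical to mmlu.
- tofu has similar question and answer columns, but it is generation rather than single choice.
- simpleqa has problem and answer columns and is likewise an "answer" type, like tofu.

Answer extraction depends on the type:
- reason (like gsm8k): CoT extraction.
- select (like bbq, mmlu): a single token, as now.
- open generation (like tofu): take the whole generated part up to EOS as the answer.
Naturally, put the function in @utils.py and add only one line to train and eval.

Things to keep in mind: how to handle the test and val splits matters.
- simpleqa: this one also has just a single test split of about 4000 examples, so it falls into the else branch here. Since there is only that one split and no separate test in the dataset config, set .train = 'test' and carve test_size and val_size out of it. Implement the case where both test and val are missing this way; when val exists, keep the existing code as is but add test handling as appropriate.
- wmdp: it only has test, so should load_dataloaders split it automatically when merged is false? The subsets wmdp-bio, wmdp-chem, and wmdp-cyber are combined, and merged=False.

All data processing also goes into @dataset.py; import the mapping.

What matters is how answers are scored. A plain right-or-wrong reward applies to the select and reason types. For the answer type it differs by method, so the definition is given per task in config.py as reward_func:. simpleqa uses Normalized Exact Match, and tofu uses "We report token-level F1 score between the model's prediction and the gold answer, following SQuAD-style evaluation".

Actually, leave tofu out for now and add only the other two; tofu reportedly needs fine-tuning, so it probably will not work. Still, add the reward = token_f1(pred, gold) function separately so it can be tried on simpleqa as well.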
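A sketch of the two reward functions and a per-task config, assuming the usual SQuAD-style normalization (lowercase, strip punctuation, drop the articles a/an/the, collapse whitespace). `token_f1` is the name used in the note; `normalize_answer`, `normalized_exact_match`, and the layout of `TASK_CONFIG` are hypothetical and may not match the real config.py.

```python
import re
import string
from collections import Counter

_PUNCTUATION = set(string.punctuation)


def normalize_answer(text: str) -> str:
    """Lowercase, strip punctuation and articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in _PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def normalized_exact_match(pred: str, gold: str) -> float:
    """1.0 if prediction and gold answer match after normalization, else 0.0."""
    return float(normalize_answer(pred) == normalize_answer(gold))


def token_f1(pred: str, gold: str) -> float:
    """Token-level F1 between normalized prediction and gold answer."""
    pred_tokens = normalize_answer(pred).split()
    gold_tokens = normalize_answer(gold).split()
    if not pred_tokens or not gold_tokens:
        # Both empty counts as a match; exactly one empty counts as a miss.
        return float(pred_tokens == gold_tokens)
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


# Hypothetical layout; token budgets and answer types are the ones listed in the note.
TASK_CONFIG = {
    "bbq": {"max_new_tokens": 1, "answer_type": "select"},
    "mmlu": {"max_new_tokens": 1, "answer_type": "select"},
    "wmdp": {"max_new_tokens": 1, "answer_type": "select"},
    "gsm8k": {"max_new_tokens": 1024, "answer_type": "reason"},
    "simpleqa": {"max_new_tokens": 32, "answer_type": "answer",
                 "reward_func": normalized_exact_match},
}

reward = TASK_CONFIG["simpleqa"]["reward_func"]("The Eiffel Tower.", "eiffel tower")
print(reward, token_f1("Paris France", "Paris"))
```

Swapping `reward_func` to `token_f1` for simpleqa would give partial credit for overlapping tokens instead of the all-or-nothing exact match.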
CoT Error
Loading checkpoint shards: 100%|█████████████████████████████████████████████████| 2/2 [00:02<00:00, 1.23s/it]
wandb: Currently logged in as: seonglae (texonom). Use `wandb login --relogin` to force relogin
wandb: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
wandb: Tracking run with wandb version 0.19.4
wandb: Run data is saved locally in /cs/student/projects2/aisd/2024/seongcho/steer-rl/wandb/run-20250721_165303-l8sxauet
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run gemma2b_gsm8k_20_ppo_1e-05_0721_165303_30.0_cot
wandb: ⭐️ View project at https://wandb.ai/texonom/control_rl
wandb: 🚀 View run at https://wandb.ai/texonom/control_rl/runs/l8sxauet
Training Steps: 0%| | 0/126 [00:00<?, ?it/s]
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
W0721 16:53:08.011000 2982756 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0] Graph break from `Tensor.item()`, consider setting:
W0721 16:53:08.011000 2982756 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0] torch._dynamo.config.capture_scalar_outputs = True
W0721 16:53:08.011000 2982756 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0] or:
W0721 16:53:08.011000 2982756 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0721 16:53:08.011000 2982756 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0] to include these operations in the captured graph.
W0721 16:53:08.011000 2982756 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0]
W0721 16:53:08.011000 2982756 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0] Graph break: from user code at:
W0721 16:53:08.011000 2982756 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0] File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
W0721 16:53:08.011000 2982756 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0] return func(*args, **kwargs)
W0721 16:53:08.011000 2982756 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0] File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 887, in forward
W0721 16:53:08.011000 2982756 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0] outputs = self.model(
W0721 16:53:08.011000 2982756 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0] File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 667, in forward
W0721 16:53:08.011000 2982756 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0] layer_outputs = decoder_layer(
W0721 16:53:08.011000 2982756 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0] File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1841, in _call_impl
W0721 16:53:08.011000 2982756 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0] return inner()
W0721 16:53:08.011000 2982756 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0] File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1779, in inner
W0721 16:53:08.011000 2982756 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0] args_result = hook(self, args)
W0721 16:53:08.011000 2982756 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0] File "/cs/student/projects2/aisd/2024/seongcho/steer-rl/control_rl/steer.py", line 48, in __call__
W0721 16:53:08.011000 2982756 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0] action, log_prob = self.policy_net.select_action(observation_detached)
W0721 16:53:08.011000 2982756 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0] File "/cs/student/projects2/aisd/2024/seongcho/steer-rl/control_rl/ppo.py", line 77, in select_action
W0721 16:53:08.011000 2982756 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0] topk_vals, topk_indices = torch.topk(logits_noisy, k=int(self.topk), dim=-1)
W0721 16:53:08.011000 2982756 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0]
W0721 16:53:08.011000 2982756 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0]
Training Steps: 0%| | 0/126 [00:02<?, ?it/s]
Traceback (most recent call last):
  File "/cs/student/projects2/aisd/2024/seongcho/steer-rl/train.py", line 777, in <module>
    fire.Fire(TrainController)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/fire/core.py", line 135, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/fire/core.py", line 468, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/fire/core.py", line 684, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/steer-rl/train.py", line 752, in train
    train_metrics = self.perform_training_step(batch)
  File "/cs/student/projects2/aisd/2024/seongcho/steer-rl/train.py", line 237, in perform_training_step
    input_ids, attention_mask, generated_ids, correct_answers = self.generate_steered(batch)
  File "/cs/student/projects2/aisd/2024/seongcho/steer-rl/train.py", line 127, in generate_steered
    generated_ids = self.llm.generate(
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/transformers/generation/utils.py", line 2223, in generate
    result = self._sample(
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/transformers/generation/utils.py", line 3214, in _sample
    outputs = model_forward(**model_inputs, return_dict=True)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1269, in __call__
    return self._torchdynamo_orig_callable(
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 526, in __call__
    return _compile(
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 924, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 666, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
    return function(*args, **kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 699, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 219, in _fn
    return fn(*args, **kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 634, in transform
    tracer.run()
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2796, in run
    super().run()
  [... repeated torch/_dynamo inlining frames omitted (symbolic_convert.py, variables/lazy.py, variables/functions.py, variables/nn_module.py, variables/user_defined.py) ...]
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1602, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 967, in call_function
    return handler(tx, args, kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 711, in <lambda>
    return lambda tx, args, kwargs: obj.call_function(
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 967, in call_function
    return handler(tx, args, kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 848, in builtin_dispatch
    rv = fn(tx, args, kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 766, in call_self_handler
    result = self_handler(tx, *args, **kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 1002, in _call_int_float
    item = arg.call_method(tx, "item", [], {})
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/variables/tensor.py", line 527, in call_method
    result = handler_method(*args, **kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/variables/tensor.py", line 769, in method_item
    unimplemented("Tensor.item")
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 297, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: Tensor.item

from user code:
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 887, in forward
    outputs = self.model(
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 667, in forward
    layer_outputs = decoder_layer(
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1841, in _call_impl
    return inner()
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1779, in inner
    args_result = hook(self, args)
  File "/cs/student/projects2/aisd/2024/seongcho/steer-rl/control_rl/steer.py", line 48, in __call__
    action, log_prob = self.policy_net.select_action(observation_detached)
  File "/cs/student/projects2/aisd/2024/seongcho/steer-rl/control_rl/ppo.py", line 77, in select_action
    topk_vals, topk_indices = torch.topk(logits_noisy, k=int(self.topk), dim=-1)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Traceback (most recent call last):
  File "/cs/student/projects2/aisd/2024/seongcho/steer-rl/train.py", line 777, in <module>
    fire.Fire(TrainController)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/fire/core.py", line 135, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/fire/core.py", line 468, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/fire/core.py", line 684, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/steer-rl/train.py", line 752, in train
    train_metrics = self.perform_training_step(batch)
  File "/cs/student/projects2/aisd/2024/seongcho/steer-rl/train.py", line 237, in perform_training_step
    input_ids, attention_mask, generated_ids, correct_answers = self.generate_steered(batch)
  File "/cs/student/projects2/aisd/2024/seongcho/steer-rl/train.py", line 127, in generate_steered
    generated_ids = self.llm.generate(
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/transformers/generation/utils.py", line 2223, in generate
    result = self._sample(
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/transformers/generation/utils.py", line 3214, in _sample
    outputs = model_forward(**model_inputs, return_dict=True)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1269, in __call__
    return self._torchdynamo_orig_callable(
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 526, in __call__
    return _compile(
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 924, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 666, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
    return function(*args, **kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 699, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 219, in _fn
    return fn(*args, **kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 634, in transform
    tracer.run()
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2796, in run
    super().run()
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1680, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/variables/lazy.py", line 156, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1680, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/variables/lazy.py", line 156, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 899, in call_function
    return variables.UserFunctionVariable(fn, source=source).call_function(
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1680, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 899, in call_function
    return variables.UserFunctionVariable(fn, source=source).call_function(
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1602, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1602, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/variables/user_defined.py", line 928, in call_function
    return self.call_method(tx, "__call__", args, kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/variables/user_defined.py", line 788, in call_method
    return UserMethodVariable(method, self, source=source).call_function(
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 385, in call_function
    return super().call_function(tx, args, kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1602, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 385, in call_function
    return super().call_function(tx, args, kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1602, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 967, in call_function
    return handler(tx, args, kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 711, in <lambda>
    return lambda tx, args, kwargs: obj.call_function(
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 967, in call_function
    return handler(tx, args, kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 848, in builtin_dispatch
    rv = fn(tx, args, kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 766, in call_self_handler
    result = self_handler(tx, *args, **kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 1002, in _call_int_float
    item = arg.call_method(tx, "item", [], {})
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/variables/tensor.py", line 527, in call_method
    result = handler_method(*args, **kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/variables/tensor.py", line 769, in method_item
    unimplemented("Tensor.item")
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 297, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: Tensor.item

from user code:
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 887, in forward
    outputs = self.model(
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 667, in forward
    layer_outputs = decoder_layer(
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1841, in _call_impl
    return inner()
  File "/cs/student/projects2/aisd/2024/seongcho/miniconda3/envs/sae/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1779, in inner
    args_result = hook(self, args)
  File "/cs/student/projects2/aisd/2024/seongcho/steer-rl/control_rl/steer.py", line 48, in __call__
    action, log_prob = self.policy_net.select_action(observation_detached)
  File "/cs/student/projects2/aisd/2024/seongcho/steer-rl/control_rl/ppo.py", line 77, in select_action
    topk_vals, topk_indices = torch.topk(logits_noisy, k=int(self.topk), dim=-1)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

wandb:
wandb: 🚀 View run gemma2b_gsm8k_20_ppo_1e-05_0721_165303_30.0_cot at: https://wandb.ai/texonom/control_rl/runs/l8sxauet
wandb: Find logs at: wandb/run-20250721_165303-l8sxauet/logs
Seonglae Cho
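The traceback ends at `k=int(self.topk)` in `select_action`: Dynamo's `_call_int_float` handler turns that `int(...)` call into `Tensor.item`, which it could not trace in this run. That suggests `self.topk` is a tensor rather than a plain int. A minimal sketch of one possible fix, assuming `topk` never changes after construction, is to convert it once in `__init__` so the traced hook only sees a Python int. `TopKPolicy`, `proj`, and `dim` are hypothetical names and not the real `ppo.py`; the noise and log-prob logic is left out.

```python
import torch
from torch import nn


class TopKPolicy(nn.Module):
    """Hypothetical stand-in for the policy behind select_action in ppo.py."""

    def __init__(self, dim: int, topk) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        # Assumption: topk may arrive as a 0-dim tensor (e.g. from a config).
        # Converting here runs eagerly, once, outside any traced forward pass.
        self.topk: int = int(topk)

    def select_action(self, observation: torch.Tensor):
        logits = self.proj(observation)
        # k is already a Python int, so no int(tensor) -> Tensor.item is traced.
        topk_vals, topk_indices = torch.topk(logits, k=self.topk, dim=-1)
        return topk_vals, topk_indices


policy = TopKPolicy(dim=8, topk=torch.tensor(3))
vals, indices = policy.select_action(torch.randn(2, 8))
print(vals.shape, indices.shape)
```

If `topk` has to stay a tensor, other things to try would be `torch._dynamo.config.capture_scalar_outputs = True`, or keeping the hook out of the compiled region with `torch.compiler.disable`. Whether either works with this torch and transformers version is untested here. The `suppress_errors = True` fallback from the log only hides the failure by running eagerly.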