Training Large Language Models to Reason in a Continuous Latent Space
The last hidden state (the final residual-stream vector) is fed back as the next input embedding instead of being decoded into a token. To handle the resulting distribution shift, they use multi-stage training: the model is first trained on CoT data, and CoT steps are then gradually replaced with continuous thoughts so that the distribution mismatch is minimized.
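Concretely, the core mechanism can be sketched as below. This is a minimal illustration under my own assumptions (not the authors' released code), using a Hugging Face GPT-2 model; the prompt text and `n_thoughts` are placeholders.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt_ids = tokenizer("Question: ...", return_tensors="pt").input_ids
# Start from the ordinary token embeddings of the prompt.
embeds = model.get_input_embeddings()(prompt_ids)          # (1, T, d)

n_thoughts = 3  # number of continuous thoughts (illustrative)
for _ in range(n_thoughts):
    out = model(inputs_embeds=embeds, output_hidden_states=True)
    # Last-layer hidden state at the final position = the "continuous thought".
    thought = out.hidden_states[-1][:, -1:, :]              # (1, 1, d)
    # Feed it back directly as the next input embedding (no decoding to a token).
    embeds = torch.cat([embeds, thought], dim=1)

# After the latent thoughts, decoding continues in ordinary token space.
out = model(inputs_embeds=embeds)
next_token = out.logits[:, -1, :].argmax(dim=-1)
```

The multi-stage curriculum then controls how many CoT steps in the training data are swapped for such continuous thoughts at each stage.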
Conclusion
Latent-state analysis showed that CoT commits to a single reasoning path at a time (akin to depth-first search), whereas continuous thoughts can keep multiple candidate paths in play, closer to a breadth-first search. However, the experiments used only GPT-2, and performance on GSM8K actually decreased compared to CoT.
Implementation
- Loss masking: the cross-entropy loss is masked over the continuous-thought positions, so the objective encourages the thoughts to facilitate prediction of the subsequent reasoning rather than forcing them to decode into specific output tokens (see the sketch after this list).
- Reset optimizer state: re-initialize the optimizer at each stage transition so the model can adapt to the new data distribution of that stage.
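A minimal sketch of these two implementation details, assuming PyTorch and a standard causal-LM `logits`/`labels` setup; `IGNORE_INDEX`, the learning rate, and the stage loop are illustrative assumptions, not the authors' exact code.

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # positions with this label are skipped by cross_entropy

def compute_loss(logits, labels, thought_mask):
    """Cross-entropy over text tokens only; continuous-thought positions are
    masked out so they receive no direct token-level loss."""
    labels = labels.masked_fill(thought_mask, IGNORE_INDEX)
    # Standard next-token shift: position t predicts token t+1.
    return F.cross_entropy(
        logits[:, :-1].reshape(-1, logits.size(-1)),
        labels[:, 1:].reshape(-1),
        ignore_index=IGNORE_INDEX,
    )

def new_optimizer(model):
    # Re-created at every stage boundary so stale moment estimates from the
    # previous data distribution do not carry over (lr is illustrative).
    return torch.optim.AdamW(model.parameters(), lr=1e-4)

# for stage in range(num_stages):        # each stage replaces more CoT steps
#     optimizer = new_optimizer(model)   # reset optimizer state per stage
#     train_one_stage(model, optimizer, stage)
```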