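In PyTorch, detach() is a tensor method that blocks gradient flow.
It returns a new tensor that shares storage with the original but is cut off from the computation graph, so operations on it are not tracked by autograd and no gradients flow back through it. Keeping only the detached tensor can also save memory, because the graph and its saved intermediate buffers can be freed once nothing else references them.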
import torch
from transformers import BertModel, BertTokenizer

# Load the small pre-trained BERT model
tokenizer = BertTokenizer.from_pretrained("prajjwal1/bert-small")
model = BertModel.from_pretrained("prajjwal1/bert-small")

# Example text input
text = "Hello, how are you?"

# Tokenize the input text
inputs = tokenizer(text, return_tensors="pt")

# Efficient inference: no graph is built inside no_grad
with torch.no_grad():
    output = model(**inputs)

# Detach tensor. Under no_grad the output already has requires_grad=False,
# so detach() changes nothing for autograd here; it matters when the
# forward pass runs with gradient tracking enabled.
detached_tensor = output.last_hidden_state.detach()
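The BERT example above runs under no_grad, so it does not show detach() actually cutting a gradient path. The following is a minimal sketch of that effect on a toy tensor; the names x, y, y_detached, and loss are illustrative and not part of the original note.

```python
import torch

# Leaf tensor that tracks gradients.
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

y = x * 2                 # tracked by autograd
y_detached = y.detach()   # same values, cut off from the graph

# y feeds the loss through a tracked path, y_detached acts as a constant.
loss = (y * y_detached).sum()
loss.backward()

# The detached tensor does not require gradients.
print(y_detached.requires_grad)

# Only the tracked path contributes: d(loss)/dx = 2 * y_detached = 4 * x.
# Without detach() the gradient would be 8 * x instead.
print(x.grad)

# detach() does not copy data: both tensors point at the same storage.
print(y_detached.data_ptr() == y.data_ptr())
```

Because the storage is shared, in-place changes to the detached tensor also change the original; calling .clone() after .detach() is a common way to get an independent copy.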
torch.Tensor.detach — PyTorch 2.4 documentation
This method also affects forward mode AD gradients and the result will never
have forward mode AD gradients.
https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html

Seonglae Cho