torch.Tensor.detach()

Creator
Seonglae Cho
Created
2024 Sep 12 23:6
Edited
2025 Jul 20 20:59

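In PyTorch, detach() is a tensor method that blocks gradient flow: it returns a new tensor that holds the same data as the original but is cut off from the computation graph, so backpropagation stops at that point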
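A minimal sketch of what blocking gradient flow means in practice. The tensors x, y and y_det are illustrative names, not from the original note. The detached copy is treated as a constant during backward, so only the path through y contributes to x.grad.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x * 2            # tracked by autograd: y.grad_fn is set
y_det = y.detach()   # same values, but no graph history

# y contributes to the gradient, y_det acts as a constant
loss = (y * y_det).sum()
loss.backward()

# d(loss)/dx = 2 * y_det = 4 * x, since y_det is not differentiated
print(x.grad)               # tensor([ 4.,  8., 12.])
print(y_det.requires_grad)  # False
print(y_det.grad_fn)        # None
```

Without the detach() call the loss would be (2x)^2 summed, and the gradient would be 8 * x instead of 4 * x.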

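The returned tensor has requires_grad=False and no grad_fn, so operations on it are not tracked and do not affect gradients. Keeping the detached tensor instead of the original (for logging or caching, for example) can save memory, because it holds no reference to the computation graph and lets the graph be freed. Note that detach() does not copy data: the result shares storage with the original, so in-place changes to one are visible in the other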
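One common place where this matters is recording loss values during training. The sketch below is an assumed toy setup (a small nn.Linear model and random inputs, neither from the original note): each stored loss is detached, so the list keeps only values and no reference to the per-step graphs.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)   # hypothetical toy model
history = []

for _ in range(3):
    x = torch.randn(8, 4)       # random stand-in for a real batch
    loss = model(x).pow(2).mean()
    loss.backward()
    # Store the value only; the detached tensor carries no grad_fn,
    # so it does not keep this step's graph alive
    history.append(loss.detach())
    model.zero_grad()

print([h.grad_fn for h in history])  # [None, None, None]
```

Appending loss itself would keep each step's grad_fn chain reachable from the list for as long as the list lives.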
```python
import torch
from transformers import BertModel, BertTokenizer

# Load the small pre-trained BERT model
tokenizer = BertTokenizer.from_pretrained("prajjwal1/bert-small")
model = BertModel.from_pretrained("prajjwal1/bert-small")

# Example text input
text = "Hello, how are you?"

# Tokenize the input text
inputs = tokenizer(text, return_tensors="pt")

# Efficient inference: no graph is recorded inside no_grad()
with torch.no_grad():
    output = model(**inputs)

# Detach tensor (already untracked here because of no_grad(); detach()
# matters when the forward pass ran with gradients enabled)
detached_tensor = output.last_hidden_state.detach()
```
torch.Tensor.detach — PyTorch 2.4 documentation
This method also affects forward mode AD gradients and the result will never have forward mode AD gradients.
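The forward-mode statement above can be checked with a small sketch using torch.autograd.forward_ad. The primal and tangent values are arbitrary illustrative choices, and the behavior shown is what the quoted documentation sentence describes rather than something verified against every PyTorch version.

```python
import torch
import torch.autograd.forward_ad as fwAD

primal = torch.tensor([1.0, 2.0])
tangent = torch.tensor([1.0, 1.0])

with fwAD.dual_level():
    dual = fwAD.make_dual(primal, tangent)
    out = dual * 3

    # The regular result carries a forward-mode tangent (3 * tangent)
    out_tangent = fwAD.unpack_dual(out).tangent
    # The detached result carries none
    det_tangent = fwAD.unpack_dual(out.detach()).tangent

print(out_tangent)
print(det_tangent)
```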
 
 
