torch.DataLoader

Creator
Seonglae Cho
Created
2024 Feb 24 16:17
Edited
2024 Dec 22 14:2
Refs
Setting pin_memory=True and num_workers to a positive value can significantly boost data-loading performance: worker processes prefetch batches in parallel, and pinned memory speeds up host-to-GPU copies.
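A minimal sketch of that configuration (dataset is a placeholder for any torch.utils.data.Dataset; num_workers=4 is an illustrative value to tune per machine):

import torch
from torch.utils.data import DataLoader

# num_workers > 0 spawns worker processes that prefetch batches in parallel
loader = DataLoader(
    dataset,          # placeholder: any torch.utils.data.Dataset
    batch_size=64,
    shuffle=True,
    num_workers=4,    # illustrative value; tune to CPU core count
    pin_memory=True,  # page-locked buffers for faster host-to-GPU copies
)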
A custom collate_fn lets the DataLoader batch variable-length sequences by padding them to a common length per batch:

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Define a custom collate function that pads variable-length sequences
def custom_collate_fn(batch):
    texts, labels = zip(*batch)
    texts = [torch.tensor(t) for t in texts]
    texts_padded = pad_sequence(texts, batch_first=True, padding_value=0)
    labels = torch.tensor(labels)
    return texts_padded, labels

# DataLoader with the custom collate_fn
data_loader = DataLoader(your_dataset, batch_size=2, collate_fn=custom_collate_fn)
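For example, with a hypothetical toy dataset of (token_ids, label) pairs of varying lengths, each batch is padded to its longest sequence:

# Hypothetical data, just to exercise the padding
toy_dataset = [([1, 2, 3], 0), ([4, 5], 1), ([6], 0)]
loader = DataLoader(toy_dataset, batch_size=2, collate_fn=custom_collate_fn)
for texts_padded, labels in loader:
    print(texts_padded.shape, labels)  # first batch: torch.Size([2, 3]) tensor([0, 1])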

pin_memory

pin_memory=True places loaded batches in page-locked (pinned) host memory, which makes host-to-GPU transfers faster and enables asynchronous copies.
import torch

train_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    pin_memory=True,
)
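Pinned memory pays off mainly when paired with asynchronous copies; a minimal sketch of the transfer side, assuming a CUDA device and that train_loader yields (inputs, targets) tensor pairs:

device = torch.device('cuda')
for inputs, targets in train_loader:
    # non_blocking=True lets the copy overlap with compute because the source is pinned
    inputs = inputs.to(device, non_blocking=True)
    targets = targets.to(device, non_blocking=True)
    # forward/backward pass here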

Recommendations