torch.amp.autocast

Creator
Seonglae Cho
Created
2024 Jan 12 1:51
Editor
Edited
2024 Jan 12 1:52
Refs
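import torch
from torch.cuda.amp import autocast, GradScaler

# YourModel, YourOptimizer, device, loss_fn, data_loader and num_epochs are placeholders
model = YourModel().to(device)
optimizer = torch.optim.YourOptimizer(model.parameters())
scaler = GradScaler()  # scales the loss so small fp16 gradients do not underflow

for epoch in range(num_epochs):
    for input, target in data_loader:
        optimizer.zero_grad()
        with autocast():  # forward pass and loss run in mixed precision
            output = model(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()  # backward pass on the scaled loss
        scaler.step(optimizer)  # unscales gradients; skips the step on inf/NaN
        scaler.update()  # adjusts the scale factor for the next iteration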
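The loop above imports autocast from torch.cuda.amp, while the page title names the device-agnostic torch.amp.autocast, which takes the device type as an argument. A minimal sketch of that form, assuming a toy Linear model and random input (both hypothetical, chosen only so the example runs on CPU without a GPU):

```python
import torch

# Hypothetical toy model and input, for illustration only.
model = torch.nn.Linear(4, 2)
x = torch.randn(8, 4)

# device_type selects the backend; "cpu" with bfloat16 is assumed here
# so the sketch runs without CUDA. On a GPU this would be device_type="cuda".
with torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = model(x)  # the linear op is executed in the lower-precision dtype

out_dtype = y.dtype               # dtype of the autocast output
param_dtype = model.weight.dtype  # parameters themselves are not converted
print(out_dtype, param_dtype)
```

The parameters stay in float32 and only the eligible ops inside the context are cast, which is why the training loop does not need to call model.half().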
