# Mixed-precision training loop with PyTorch AMP (autocast + GradScaler).
# NOTE(review): torch.cuda.amp is the legacy namespace; newer PyTorch prefers
# torch.amp.autocast("cuda") and torch.amp.GradScaler("cuda") — see the note
# title. The placeholder names (YourModel, YourOptimizer, loss_fn, data_loader,
# device, num_epochs) are expected to be defined by the surrounding project.
from torch.cuda.amp import autocast, GradScaler

model = YourModel().to(device)
optimizer = torch.optim.YourOptimizer(model.parameters())
# GradScaler scales the loss up before backward so small fp16 gradients
# do not underflow to zero, then unscales before the optimizer step.
scaler = GradScaler()

for epoch in range(num_epochs):  # fixed: original had "inrange" (missing space)
    # renamed loop variable from "input" to avoid shadowing the builtin
    for batch, target in data_loader:
        optimizer.zero_grad()
        # Forward pass runs under autocast: eligible ops execute in
        # half precision, numerically sensitive ops stay in fp32.
        with autocast():
            output = model(batch)
            loss = loss_fn(output, target)
        # Backward on the scaled loss; step() unscales gradients first and
        # skips the update if inf/NaN gradients are found; update() adjusts
        # the scale factor for the next iteration.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
torch.amp.autocast
Creator: Seonglae Cho
Created: 2024 Jan 12 1:51
Editor: Seonglae Cho
Edited: 2024 Jan 12 1:52
Refs:
Mixed Precision