torchft

Creator
Seonglae Cho
Created
2025 Jan 9 13:49
Editor
Edited
2025 Jan 9 13:50
Refs
FSDP
DDP

torchft provides per-step fault tolerance for PyTorch training: each optimizer step is coordinated across replicas, so a worker failure costs only the current step rather than a full restart from checkpoint.

import torch
import torch.nn as nn
import torch.optim as optim

from torchft import DistributedDataParallel, Manager, Optimizer, ProcessGroupGloo

# The Manager coordinates per-step quorum and recovery across replicas
manager = Manager(pg=ProcessGroupGloo())

model = nn.Linear(2, 3)
# torchft's DDP wrapper ties gradient synchronization to the manager
model = DistributedDataParallel(manager, model)
# The Optimizer wrapper only commits a step that the group agreed on
optimizer = Optimizer(manager, optim.AdamW(model.parameters()))

for epoch in range(1000):
    optimizer.zero_grad()
    output = model(torch.rand(2, 2))
    loss = output.sum()
    loss.backward()
    optimizer.step()
