PyTorch Per-Step Fault Tolerance
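torchft makes fault tolerance a per-step concern: a Manager maintains quorum across replicas, while thin DistributedDataParallel and Optimizer wrappers hook the training loop into it, so training keeps running as individual replicas fail and recover. The minimal loop below shows the wiring: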
import torch
import torch.nn as nn
import torch.optim as optim

from torchft import Manager, DistributedDataParallel, Optimizer, ProcessGroupGloo

# The Manager coordinates quorum and recovery across replica groups.
manager = Manager(pg=ProcessGroupGloo())

model = nn.Linear(2, 3)
# Wrap the model so gradients are synchronized through the manager's process group.
model = DistributedDataParallel(manager, model)
# Wrap the optimizer so steps are only committed when the manager is healthy.
optimizer = Optimizer(manager, optim.AdamW(model.parameters()))

for epoch in range(1000):
    optimizer.zero_grad()
    output = model(torch.rand(2, 2))
    loss = output.sum()
    loss.backward()
    optimizer.step()
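In practice the Manager is usually also given state callbacks so that a recovering replica can fetch live weights from a healthy peer instead of rolling the whole job back to a checkpoint. A minimal sketch, assuming the Manager accepts state_dict and load_state_dict keyword arguments as in torchft's examples:

# Sketch: state callbacks assumed to match the Manager's keyword arguments.
def state_dict():
    # Capture everything a recovering replica needs to rejoin mid-run.
    return {"model": model.state_dict(), "optim": optimizer.state_dict()}

def load_state_dict(state):
    # Restore state received from a healthy peer after a failure.
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optim"])

manager = Manager(
    pg=ProcessGroupGloo(),
    state_dict=state_dict,
    load_state_dict=load_state_dict,
)

With these callbacks in place, a restarted worker can pull the current model and optimizer state over the wire and rejoin the quorum at the next step.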