PyTorch Per-Step Fault Tolerance
import torch
from torch import nn, optim
from torchft import Manager, DistributedDataParallel, Optimizer, ProcessGroupGloo

# The manager is shared by the wrapped model and the wrapped optimizer.
manager = Manager(pg=ProcessGroupGloo())

model = nn.Linear(2, 3)
# Use the torchft wrappers, which take the manager as their first argument.
model = DistributedDataParallel(manager, model)
optimizer = Optimizer(manager, optim.AdamW(model.parameters()))

for epoch in range(1000):
    optimizer.zero_grad()
    output = model(torch.rand(2, 2))
    loss = output.sum()
    loss.backward()
    optimizer.step()
torchft — pytorch/torchft main documentation
This repository implements primitives and end-to-end (E2E) solutions for
per-step fault tolerance, so that training can continue when errors occur
without interrupting the entire training job.
https://pytorch-labs.github.io/torchft/
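
To make the idea of per-step fault tolerance concrete, here is a minimal toy simulation in plain Python. It does not use torchft and is not how torchft is implemented. It only illustrates the general pattern of deciding at every step which replicas are healthy, averaging gradients over those replicas, and skipping the step (instead of aborting the job) when none are available. The function name `train_with_per_step_recovery`, its parameters, and the random failure model are all hypothetical, chosen for illustration.

```python
import random


def train_with_per_step_recovery(num_replicas=4, steps=200, fail_prob=0.3, seed=0):
    """Toy simulation (hypothetical, not torchft): fit a scalar weight to a target.

    At each step every replica may fail independently. The update uses the
    average gradient of the replicas that survived that step. If no replica
    survived, the step is skipped and training simply continues.
    """
    rng = random.Random(seed)
    target, lr = 3.0, 0.1
    weight = 0.0
    committed, skipped = 0, 0

    for _ in range(steps):
        # Assumed failure model: each replica fails with probability fail_prob.
        healthy = [r for r in range(num_replicas) if rng.random() >= fail_prob]
        if not healthy:
            skipped += 1  # no healthy replica: keep the old weight, do not abort
            continue

        # Each healthy replica sees a noisy sample of the target and computes
        # the gradient of 0.5 * (weight - sample) ** 2 with respect to weight.
        grads = [weight - (target + rng.gauss(0.0, 0.1)) for _ in healthy]
        weight -= lr * sum(grads) / len(grads)
        committed += 1

    return weight, committed, skipped


if __name__ == "__main__":
    w, committed, skipped = train_with_per_step_recovery()
    print(f"weight={w:.3f} committed={committed} skipped={skipped}")
```

In this sketch a failure costs at most the affected step: the loop keeps running with whichever replicas remain, which is the behavior the description above attributes to per-step fault tolerance.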

Seonglae Cho