Mixed Precision
scaler = torch.cuda.amp.GradScaler() ... scaler.scale(loss).backward() # scaler.unscale_(optimizer) # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) scaler.step(optimizer) scaler.update()
unscale is automated
update()
update scaling factor- FP overflow나 underflow 고려해서 적당하게 유지한다.
unscale_
step 내부에는 포함되어 있지만 Gradient Clipping 같이 부가적인 작업 필요하면 미리 호출해야한다.- scaling하다보면 overflow 확률이 올라가니 Gradient clipping해주는 것