torch.compile()

Creator
Seonglae Cho
Created
2023 Aug 24 17:13
Edited
2026 Mar 22 23:47

Speed up PyTorch code using JIT Compilation

Unlike
torch.jit
, this is a completely different JIT: it captures the computational graph from Python execution and passes it to a backend compiler for kernel-level optimization
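A minimal sketch of this graph-capture-and-compile flow. The `"eager"` debug backend is used here only so the sketch runs without a C++/Triton toolchain; the default backend is `"inductor"`:

```python
import torch

# torch.compile captures the Python-level graph and hands it to a backend.
def f(x):
    return torch.sin(x) ** 2 + torch.cos(x) ** 2

compiled_f = torch.compile(f, backend="eager")

x = torch.randn(10)
same = torch.allclose(compiled_f(x), f(x))  # same numerics as the eager function
```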
One caveat: when using forward hooks to extract intermediate hidden states, torch.compile may skip or break the hooks while optimizing the graph
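One possible workaround (a sketch, not the official recipe): keep a reference to the original, uncompiled module, register hooks on it, and run an eager forward pass whenever intermediate activations are needed. The toy model and the `acts` dict below are hypothetical:

```python
import torch
import torch.nn as nn

# Hypothetical toy model; use the compiled version on the fast path,
# and the original module when hooks must fire.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
compiled = torch.compile(model, backend="eager")

acts = {}
handle = model[0].register_forward_hook(
    lambda mod, inp, out: acts.update(hidden=out.detach())
)
_ = model(torch.randn(2, 8))  # eager pass: hooks fire reliably
handle.remove()
```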

Backends

  • “inductor” (default)
  • “nvfuser”
To avoid crashes, set fullgraph=False so PyTorch compiles only the safe parts and falls back to eager execution for the rest. If the model takes inputs of varying sizes (e.g., text sequences or images), pass dynamic=True to avoid recompiling for every new shape.
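The options above can be sketched together in one call (again using the `"eager"` backend so the example runs without a compiler toolchain; `"inductor"` is the default):

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 4)
compiled = torch.compile(
    model,
    backend="eager",   # "inductor" by default; "eager" avoids needing a toolchain here
    fullgraph=False,   # allow graph breaks instead of erroring on unsupported code
    dynamic=True,      # handle varying input shapes without recompiling
)

out_small = compiled(torch.randn(3, 16))  # batch of 3
out_large = compiled(torch.randn(7, 16))  # batch of 7, no recompilation needed
```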
torch.compile does not work for every Hugging Face model

Accelerating Hugging Face and TIMM models with PyTorch 2.0
torch.compile() makes it easy to experiment with different compiler backends to make PyTorch code faster with a single line decorator torch.compile(). It works either directly over an nn.Module as a drop-in replacement for torch.jit.script() but without requiring you to make any source code changes. We expect this one line code change to provide you with between 30%-2x training time speedups on the vast majority of models that you’re already running.
torch.compile Tutorial — PyTorch Tutorials 2.0.1+cu117 documentation
PyTorch Recipes

Recommendations