Speed up using JIT Compilation
Unlike torch.jit, torch.compile is a completely different JIT: it captures the computational graph directly from Python bytecode (via TorchDynamo) and hands it to a backend compiler for kernel-level optimization, falling back to normal Python execution for anything it cannot capture.
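A minimal sketch of what this looks like in practice (the function and tensor shapes here are illustrative, not from the original text):

    import torch

    # torch.compile works on plain functions as well as nn.Modules.
    # On the first call, TorchDynamo captures the graph and the backend
    # (Inductor by default) can fuse these elementwise ops into one kernel.
    @torch.compile
    def fused_op(x):
        return torch.relu(x) * torch.sigmoid(x)

    y = fused_op(torch.randn(1024, 1024))  # first call triggers compilation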
Backends
- "inductor" (the default)
- "nvfuser"
- "aot_eager": runs in PyTorch eager mode but still extracts the graph
- XLA, for TPUs
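A quick sketch of selecting a backend by name (the Linear layer is a stand-in for any model; "aot_eager" is handy for checking whether graph capture itself is the problem):

    import torch
    import torch.nn as nn

    model = nn.Linear(16, 16)  # stand-in for any nn.Module

    # Pick a backend explicitly; omitting the argument gives "inductor".
    compiled = torch.compile(model, backend="aot_eager")
    out = compiled(torch.randn(4, 16))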
    import torch

    model = MyModel()  # any torch.nn.Module you have defined
    compiled_model = torch.compile(model, dynamic=True, fullgraph=False)
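The first call pays the compilation cost; later calls reuse the compiled kernels. A rough timing sketch (the model, shapes, and wall-clock measurement are illustrative, not a rigorous benchmark):

    import time
    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))
    compiled_model = torch.compile(model)
    x = torch.randn(64, 512)

    t0 = time.perf_counter()
    compiled_model(x)  # slow: triggers graph capture and compilation
    print(f"first call:  {time.perf_counter() - t0:.3f}s")

    t0 = time.perf_counter()
    compiled_model(x)  # fast: reuses the compiled kernels
    print(f"second call: {time.perf_counter() - t0:.3f}s")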
To avoid crashes, keep fullgraph=False so PyTorch compiles only the parts it can handle and falls back to eager execution at graph breaks (fullgraph=True instead raises an error on any break). If your model takes inputs of varying sizes, e.g., text sequences or images, set dynamic=True so shapes are treated as symbolic and every new input size does not trigger a recompilation.
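A small sketch of the dynamic-shape case, assuming a toy per-token layer (the layer and sizes are made up for illustration):

    import torch
    import torch.nn as nn

    encoder = nn.Linear(32, 32)  # hypothetical per-token layer
    compiled = torch.compile(encoder, dynamic=True, fullgraph=False)

    # Batches with different sequence lengths reuse the same compiled
    # graph instead of recompiling for each new shape.
    for seq_len in (10, 57, 128):
        out = compiled(torch.randn(seq_len, 32))
        print(seq_len, out.shape)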