Stack
JAX is for functional, accelerator-focused computation
- Flax: model authoring
- Optax: composable optimizers
- Orbax: large-scale asynchronous checkpointing
From JAX Notion
JAX Usage
- Infrastructure: XLA (operation fusion & memory optimization), Pathways (distributed execution across tens of thousands of chips, with fault recovery).
- Advanced Tools: Pallas/Tokamax (custom, optimized kernels), Qwix (non-invasive quantization), Grain (reproducible data pipelines).
- End-to-End: MaxText/MaxDiffusion (pre-training), Tunix (fine-tuning), vLLM TPU backend on JAX (inference).
Impact: real-world deployments report significant throughput and cost-efficiency gains (e.g., Kakao: 2.7× throughput with XPK).
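As a small illustration of the XLA layer mentioned above, the sketch below jit-compiles a chain of elementwise operations and dumps the lowered and compiled program text. The function `f` and the input shape are hypothetical. Whether and how operations are fused depends on the backend and XLA version, so the printed text is something to inspect, not a guaranteed result.

```python
import jax
import jax.numpy as jnp

# Hypothetical elementwise chain: a typical candidate for XLA operation fusion.
def f(x):
    return jnp.tanh(x) * 2.0 + 1.0

x = jnp.linspace(-1.0, 1.0, 8)

jit_f = jax.jit(f)

# Ahead-of-time lowering: the program as handed to XLA (StableHLO text).
lowered = jit_f.lower(x)
lowered_text = lowered.as_text()

# After XLA compilation: optimized HLO for the current backend.
compiled = lowered.compile()
compiled_text = compiled.as_text()

eager_out = f(x)     # op-by-op execution
jit_out = jit_f(x)   # single compiled executable
```

Printing `compiled_text` and comparing it with `lowered_text` is one way to see what the compiler did with the three elementwise operations on a given backend.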

Seonglae Cho

