JAX

Creator
Seonglae Cho
Created
2021 Jun 13 13:33
Edited
2026 Jan 6 0:40
Refs
XLA
TPU

Stack

JAX is for functional, accelerator-focused computation
  • Flax: model authoring
  • Optax: composable optimizers
  • Orbax: large-scale asynchronous checkpointing
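A minimal sketch of the functional style this stack builds on, using plain JAX (no Flax/Optax); the loss function and shapes here are illustrative, not from the source:

```python
import jax
import jax.numpy as jnp

# JAX centres on pure functions; jit, grad, and vmap compose as
# transformations over them.
def loss(w, x, y):
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

grad_fn = jax.jit(jax.grad(loss))  # XLA-compiled gradient function

w = jnp.zeros((3,))
x = jnp.ones((4, 3))
y = jnp.ones((4,))
g = grad_fn(w, x, y)  # gradient with the same shape as w
```

Because `loss` is pure, the same function can be handed to Flax modules, Optax optimizers, or further transformations without change.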

From

Jax Notion
Jax Usages
  • Infrastructure: XLA (operation fusion & memory optimization), Pathways (tens of thousands of chips, distribution & fault recovery).
  • Impact: real-world cases show significant throughput & cost-efficiency improvements (e.g., Kakao: 2.7× throughput with XPK).
JAX AI Stack
JAX AI Stack tested releases provide high reliability by ensuring compatibility across its libraries.
Google for Developers Blog
Introducing the JAX AI Stack: a modular, end-to-end, production-ready platform built on JAX and co-designed with Cloud TPUs for scalable ML.
Kakao’s journey with JAX and Cloud TPUs | Google Cloud Blog
Kakao’s approach provides a compelling example of the high-performance array computing framework JAX for AI model development at scale.

Programming

Programming TPUs in JAX | How To Scale Your Model
How to use JAX to program TPUs efficiently! Much of this section is taken from the shard_map JEP (https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html).
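A minimal sketch of the named-sharding idea behind that guide, run on whatever devices are visible (a single CPU here, a TPU slice in practice); it assumes the device count divides the array length:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D device mesh; the axis name "data" labels how arrays split.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
spec = NamedSharding(mesh, P("data"))

x = jax.device_put(jnp.arange(8.0), spec)  # shard along the "data" axis

@jax.jit
def f(a):
    return (a * 2.0).sum()  # XLA fuses the multiply into the reduction

out = f(x)  # jit accepts the sharded input; collectives are inserted as needed
```

On a multi-chip mesh the same code runs unchanged; only the mesh shape and partition specs change.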

JAX for PyTorch

A guide to JAX for PyTorch developers | Google Cloud Blog
PyTorch users can learn about JAX in this tutorial that connects JAX concepts to the PyTorch building blocks that they’re already familiar with.
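A sketch of the main shift that guide covers for PyTorch users: where `optimizer.step()` mutates parameters in place, JAX keeps state explicit, so an update step takes parameters and returns new ones (the SGD step and toy data here are illustrative):

```python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    return jnp.mean((x @ params - y) ** 2)

@jax.jit
def sgd_step(params, x, y, lr=0.1):
    g = jax.grad(loss)(params, x, y)
    return params - lr * g  # returns new params; the old ones are untouched

params = jnp.zeros((1,))
x, y = jnp.ones((2, 1)), jnp.ones((2,))
new_params = sgd_step(params, x, y)  # one explicit, functional update
```

This explicit-state loop is what Flax and Optax wrap; the training step stays a pure function either way.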

Recommendations