Pallas: a JAX kernel language — JAX documentation
Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU.
It aims to provide fine-grained control over the generated code, combined with
the high-level ergonomics of JAX tracing and the jax.numpy API.
https://docs.jax.dev/en/latest/pallas/index.html
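
To make the description above concrete, here is a minimal sketch of a Pallas kernel that adds two vectors element-wise. It follows the pattern of the Pallas quickstart. The names `add_vectors_kernel` and `add_vectors` are illustrative. Passing `interpret=True` is an assumption made here so the sketch can also run on a CPU-only machine; on a GPU or TPU that flag would normally be dropped so the kernel is actually compiled.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_vectors_kernel(x_ref, y_ref, o_ref):
    # Kernel arguments are references to buffers, not arrays.
    # Indexing with [...] reads or writes the whole referenced block.
    o_ref[...] = x_ref[...] + y_ref[...]


def add_vectors(x, y):
    # pallas_call wraps the kernel into a function callable from JAX.
    return pl.pallas_call(
        add_vectors_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # assumption: interpreter mode, so no GPU/TPU is required
    )(x, y)


x = jnp.arange(8, dtype=jnp.float32)
result = add_vectors(x, x)
print(result)
```

The kernel body uses ordinary `jax.numpy`-style arithmetic on the values read from the references, which is the "high-level ergonomics" part; the explicit reads and writes through `x_ref`, `y_ref`, and `o_ref` are where the finer-grained control comes in.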

Seonglae Cho