
Kvax: Fast and easy-to-use FlashAttention implementation for JAX

Today, we’re open-sourcing Kvax, our FlashAttention implementation built on JAX. Designed for efficient training on long sequences, Kvax supports context parallelism and optimized computation of document masks. It outperforms many other FlashAttention implementations in long-context training with dense sequence packing, achieving state-of-the-art performance.
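To make the document-mask idea concrete: when several documents are densely packed into one sequence, attention must be restricted so tokens only attend within their own document. Kvax's actual API is not shown here; the following is a plain JAX sketch of that masking, with hypothetical helper names (`document_mask`, `masked_attention`) chosen for illustration.

```python
import jax
import jax.numpy as jnp
from jax import random

def document_mask(segment_ids):
    # segment_ids: (batch, seq) integer document ids of packed tokens.
    # Returns a boolean (batch, seq, seq) mask that is True where the
    # query and key positions belong to the same packed document.
    return segment_ids[:, :, None] == segment_ids[:, None, :]

def masked_attention(q, k, v, mask):
    # q, k, v: (batch, seq, heads, head_dim); plain scaled dot-product
    # attention with disallowed pairs set to -inf before the softmax.
    scale = q.shape[-1] ** -0.5
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale
    scores = jnp.where(mask[:, None, :, :], scores, -jnp.inf)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)

key = random.PRNGKey(0)
B, S, H, D = 1, 4, 2, 8
kq, kk, kv = random.split(key, 3)
q = random.normal(kq, (B, S, H, D))
k = random.normal(kk, (B, S, H, D))
v = random.normal(kv, (B, S, H, D))

# Two documents of two tokens each, packed into one sequence.
segment_ids = jnp.array([[0, 0, 1, 1]])
out = masked_attention(q, k, v, document_mask(segment_ids))
```

Because the mask is block-diagonal over documents, perturbing the second document's keys and values leaves the first document's outputs unchanged; a FlashAttention kernel that is aware of this structure can skip the fully masked blocks entirely, which is where the speedup on packed long-context batches comes from.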