An efficient path to production AI: Kakao’s journey with JAX and Cloud TPUs
When your messaging platform serves 49 million people – 93% of South Korea’s population – every technical decision carries enormous weight. The engineering team at Kakao faced exactly this challenge when their existing infrastructure hit critical limitations. Their solution? A strategic shift to Google Cloud TPUs using the JAX framework that not only solved their immediate scalability needs but opened new possibilities for advanced AI model development.
Kakao’s approach provides a compelling example of leveraging the high-performance array computing framework JAX for AI model development at scale. While their primary training environment was GPU-based, the team made a strategic decision to adopt the JAX stack on Google Cloud TPUs to optimize for cost and efficiency.
This work laid the groundwork for the development of their proprietary Kanana model family, and several Kanana models — including Kanana-MoE — have recently been released as open source on Hugging Face Hub.
In this post, Minho Ryu and Nayeon Kim detail Kakao’s technical journey. They cover specific implementation details, from adapting MaxText, the JAX-based large language model training framework, for custom data pipelines, to their work on mixture-of-experts (MoE) model training.
Kakao’s journey by Minho and Nayeon:
As engineers at Kakao, we develop models that serve KakaoTalk, a platform supporting services that extend far beyond text. Our rich ecosystem includes chat with over 700,000 images and stickers (emojis), voice and video calls, finance, and navigation.
KakaoTalk’s massive scale and complexity demand that our language models are not only highly efficient but also excel at understanding the Korean language and are flexible enough for diverse applications. These real-world product requirements directly influenced our technical decisions and our need for a customizable training framework.
Our journey with JAX began at an important inflection point. Our existing GPU-based infrastructure was reaching its power and budget limits. We had two options: expand our GPU infrastructure and keep our existing codebase, or adopt Cloud TPUs, which offered cost-performance advantages but required us to adopt a new toolchain. We chose Cloud TPUs, viewing the short-term investment as worthwhile for long-term cost-performance benefits, and built our stack on JAX.
We use XPK for Kubernetes cluster management, which simplifies job creation and management on GKE without requiring Kubernetes expertise. For the data pipeline, we adopted Grain due to its deterministic behavior, which is essential for the stability of long-running AI model training jobs.
We focused on adapting the MaxText framework to fit our specific research and compatibility needs. We made two key customizations to the pipeline:
1. Multi-source data blending: When we began exploring training with MaxText, it assumed a single, pre-mixed corpus. Our research requires blending different data sources, such as web text, code, and math, with specific, dynamically adjusted weights during different training phases. To achieve this flexibility without reprocessing terabytes of data for each experiment, we implemented a solution using Grain’s mix function (the first sketch after this list illustrates the approach). This allows us to define blending ratios in our configuration, providing the adaptability essential for our iterative research process. We filed a PR for this feature to be supported in MaxText natively, and it has since been incorporated.
2. Token Processing for Efficiency and Compatibility: To maintain compatibility with our existing Megatron-LM pipeline and improve efficiency, we modified MaxText’s token processing logic. Our data preparation method constructs each training sequence by appending the first token of the subsequent sequence. This creates overlapping, continuous sequences, ensuring that no information is lost at sequence boundaries and maximizing data utilization (the second sketch after this list illustrates the idea).
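For illustration, here is a minimal sketch of the per-source blending described in item 1, assuming a recent Grain release that exposes the grain.MapDataset API; the source names, toy data, and weights are placeholders rather than Kakao’s actual corpora or configuration.

```python
# Minimal sketch: blending multiple data sources with Grain's mix function.
# The toy in-memory sources below stand in for tokenized corpora; in practice
# each would wrap an on-disk data source read through MaxText's Grain input path.
import grain

web_ds = grain.MapDataset.source(["web_doc_%d" % i for i in range(1000)])
code_ds = grain.MapDataset.source(["code_doc_%d" % i for i in range(1000)])
math_ds = grain.MapDataset.source(["math_doc_%d" % i for i in range(1000)])

# Blending ratios come from the experiment config, so they can be changed per
# training phase without reprocessing the underlying data (values are hypothetical).
blend_weights = [0.7, 0.2, 0.1]

mixed = grain.MapDataset.mix(
    [web_ds, code_ds, math_ds],
    weights=blend_weights,
)

# Downstream steps stay deterministic given a fixed seed.
pipeline = mixed.shuffle(seed=42).batch(batch_size=8)
print(pipeline[0])
```

Because the ratios live in configuration, changing the blend for a new experiment only means changing the weights, not rebuilding the data.

The second sketch is a small, framework-agnostic illustration of the boundary-overlap idea from item 2; it is not MaxText’s actual token-processing code, just a toy version of carrying one extra token per sequence.

```python
# Each training sequence of length `seq_len` carries one extra token, the first
# token of the following sequence, so the final next-token target at a sequence
# boundary is never dropped.
import numpy as np

def build_overlapping_sequences(token_stream: np.ndarray, seq_len: int) -> np.ndarray:
    """Slices a flat token stream into rows of seq_len + 1 overlapping tokens."""
    num_seqs = (len(token_stream) - 1) // seq_len
    rows = []
    for i in range(num_seqs):
        start = i * seq_len
        # seq_len input tokens plus the first token of the next sequence.
        rows.append(token_stream[start : start + seq_len + 1])
    return np.stack(rows)

stream = np.arange(20)
seqs = build_overlapping_sequences(stream, seq_len=4)
# Inputs are seqs[:, :-1]; next-token targets are seqs[:, 1:].
print(seqs)
```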
To validate our new TPU-based workflow, we trained two models. First, we trained the Kanana 2.1 billion parameter model from scratch, and the results demonstrated that our MaxText implementation achieved performance comparable to our existing GPU-based Megatron-LM pipeline at each stage. Second, we performed depth upscaling with continued pre-training from our existing 8B model to a 9.8B architecture. Both approaches succeeded and showed consistent improvements across various benchmarks, confirming that the results on GPU were effectively reproduced on TPU.
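The post does not spell out the exact depth-upscaling recipe, but the general pattern is to duplicate a block of existing transformer layers and then continue pre-training the deeper model. The sketch below is only an illustration of that pattern; the layer counts and parameter placeholders are hypothetical.

```python
# Generic depth-upscaling sketch: re-insert a copy of a contiguous block of
# layers to grow a pretrained stack before continued pre-training.
import copy

def depth_upscale(layers, repeat_start, repeat_end):
    """Returns a deeper layer stack by duplicating layers[repeat_start:repeat_end]."""
    duplicated = copy.deepcopy(layers[repeat_start:repeat_end])
    return layers[:repeat_end] + duplicated + layers[repeat_end:]

# e.g. a 32-layer model whose middle 8 layers are duplicated -> 40 layers
original = [f"layer_{i}_params" for i in range(32)]
deeper = depth_upscale(original, repeat_start=12, repeat_end=20)
print(len(original), "->", len(deeper))  # 32 -> 40
```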
Advancing our approach: Training Mixture-of-Experts (MoE) models with MaxText
With the core pipeline validated, we began experimenting with more advanced architectures, specifically MoE models, to build inference-efficient models that maintain strong performance. Our objectives were to explore upcycling an existing dense model into an MoE structure and to evaluate the suitability of the TPU and MaxText stack for this task.
For the experiment, we upcycled our 2.1B dense model into a 13.4B parameter (2.3B active) MoE architecture with 64 experts and 8 active experts per token. We trained this model on the exact same dataset as the original dense model to isolate the impact of the architectural change. The training was performed on v5e TPUs using MaxText with Fully Sharded Data Parallelism (FSDP).
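For readers unfamiliar with upcycling, the sketch below shows the general idea: every expert is initialized from the pretrained dense FFN weights, and a fresh router is added. The shapes and parameter names are illustrative and do not mirror MaxText’s actual MoE parameter layout.

```python
# Illustrative sketch of "upcycling" a dense FFN into an MoE layer: each expert
# starts as a copy of the dense MLP, so the upcycled model starts close to the
# dense checkpoint; only the router is new.
import jax
import jax.numpy as jnp

num_experts = 64
d_model, d_ff = 1024, 4096

# Stand-ins for the pretrained dense MLP weights.
dense_mlp = {
    "wi": jnp.zeros((d_model, d_ff)),
    "wo": jnp.zeros((d_ff, d_model)),
}

moe_params = {
    # Stack the dense weights along a new leading expert axis.
    "experts": jax.tree_util.tree_map(
        lambda w: jnp.broadcast_to(w, (num_experts,) + w.shape), dense_mlp
    ),
    # The router is randomly initialized from scratch.
    "router": jax.random.normal(jax.random.PRNGKey(0), (d_model, num_experts)) * 0.02,
}

print(moe_params["experts"]["wi"].shape)  # (64, 1024, 4096)
```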
The implementation process was straightforward. We found that MaxText’s flexible design, built on Flax, Optax, and Orbax, was well-suited for the wide range of ablations required for MoE research. Specifically:
- Integrated Kernels: Megablocks MoE kernels, which support optimized MoE features like Group GEMM, were already integrated into JAX.
- Combining Schedules: We used the optax.join_schedules function to combine multiple learning rate schedules (e.g., warmup, constant, and annealing) into a single, custom schedule for our training run. The ability to combine different schedules makes it easy to experiment with different training strategies (see the first sketch after this list).
- Code Customization: We needed to enable the load balancing loss for our sparse matmul implementation. This required inserting a single line of code in the permute function within the MoE block of MaxText to calculate the loss directly from the router logits (the second sketch after this list shows the general form of such a loss).
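The first sketch shows how a warmup-constant-annealing combination can be expressed with optax.join_schedules; the step counts and learning rates are placeholders rather than the values used for Kanana training.

```python
# Sketch: stitching warmup, constant, and annealing phases into one schedule.
import optax

warmup_steps, constant_steps, decay_steps = 2_000, 50_000, 10_000
peak_lr, final_lr = 3e-4, 3e-5

schedule = optax.join_schedules(
    schedules=[
        # Linear warmup from 0 to the peak learning rate.
        optax.linear_schedule(init_value=0.0, end_value=peak_lr,
                              transition_steps=warmup_steps),
        # Hold the peak learning rate for the main phase.
        optax.constant_schedule(peak_lr),
        # Cosine annealing down to the final learning rate.
        optax.cosine_decay_schedule(init_value=peak_lr, decay_steps=decay_steps,
                                    alpha=final_lr / peak_lr),
    ],
    # Boundaries are the global steps at which the next schedule takes over.
    boundaries=[warmup_steps, warmup_steps + constant_steps],
)

optimizer = optax.adamw(learning_rate=schedule)
```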
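The second sketch shows the general form of a Switch-Transformer-style load-balancing loss computed from router logits for a top-k router. The exact line added inside MaxText’s permute function is not reproduced in the post, so treat this as an illustration of the idea rather than the actual implementation.

```python
# Generic load-balancing (auxiliary) loss from router logits: it penalizes
# routing distributions that concentrate tokens on a few experts.
import jax
import jax.numpy as jnp

def load_balancing_loss(router_logits: jnp.ndarray, num_experts: int, k: int) -> jnp.ndarray:
    """router_logits has shape [num_tokens, num_experts]."""
    probs = jax.nn.softmax(router_logits, axis=-1)
    # Average routing probability assigned to each expert.
    mean_prob = jnp.mean(probs, axis=0)
    # Fraction of tokens dispatched to each expert under top-k routing.
    top_k_idx = jnp.argsort(router_logits, axis=-1)[:, -k:]
    dispatch = jax.nn.one_hot(top_k_idx, num_experts).sum(axis=1)  # [tokens, experts]
    mean_dispatch = jnp.mean(dispatch, axis=0) / k
    # Minimized when both distributions are uniform over experts.
    return num_experts * jnp.sum(mean_prob * mean_dispatch)

logits = jax.random.normal(jax.random.PRNGKey(0), (16, 64))
print(load_balancing_loss(logits, num_experts=64, k=8))
```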
The results showed performance improvements, particularly in code and math benchmarks, suggesting domain specialization among the experts.
[Figure: performance evaluation of the upcycled MoE model]
This met our objectives and further demonstrated the JAX stack’s utility for advanced model development. We are now extending this work by experimenting with shared experts and replacing initial MoE layers with dense layers, modifications which are simple to implement within the MaxText framework.
Performance improvements and key takeaways
During our work, we gained early access to Trillium TPUs. We managed the transition from v5e by changing a few parameters in our XPK cluster and workload configurations. We observed an immediate and substantial throughput increase of 2.7x across our models, along with improved cost-performance efficiency.
Based on our experience, the JAX stack on TPUs provides a comprehensive and efficient environment for AI model development. The key advantages for our team include:
- Performance and scalability: The JAX and XLA combination provides just-in-time compilation, and MaxText is optimized for large-scale parallel computing with support for paradigms like SPMD and FSDP (see the sketch after this list).
- Customizability and control: The codebase, being pure Python and built on libraries like Flax, Optax, and Orbax, is intuitive and easy to modify. This allows us to implement custom data pipelines, training strategies, and novel architectures with minimal overhead.
- Rapid feature adoption: The MaxText framework is updated quickly with features from new state-of-the-art models, allowing us to stay current with our research.
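To make the SPMD point concrete, here is a minimal, generic JAX sharding example. MaxText handles this wiring internally; the mesh axis name, shapes, and sizes below are illustrative only.

```python
# Generic SPMD sketch: shard arrays across a device mesh and let jit/XLA
# insert the collectives needed for the sharded computation.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1D device mesh with a single "fsdp" axis over all available devices.
mesh = Mesh(np.array(jax.devices()), axis_names=("fsdp",))

# Parameters sharded along the fsdp axis, one shard per device.
w = jax.device_put(jnp.ones((8192, 1024)), NamedSharding(mesh, P("fsdp", None)))
# The batch is sharded over the same axis, as in FSDP-style data parallelism.
x = jax.device_put(jnp.ones((64, 8192)), NamedSharding(mesh, P("fsdp", None)))

@jax.jit
def forward(x, w):
    # XLA adds the communication required to compute the sharded matmul.
    return jnp.tanh(x @ w)

y = forward(x, w)
print(y.sharding)  # the output stays sharded across the mesh
```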
These strengths have made the JAX stack a powerful and flexible foundation for our work in training large language models at Kakao.
Build your language models with the JAX ecosystem
Kakao’s journey demonstrates how the JAX ecosystem’s modular design — including MaxText, Flax, Optax, and Orbax — enables the customization required for both production pipelines and advanced research, from tailored data blending to rapid experimentation with MoE architectures.
Our sincere thanks to Minho, Nayeon and their team for sharing their insightful engineering work. We look forward to seeing how they and other leading enterprises worldwide continue to use the JAX ecosystem to build the next generation of powerful and efficient language models.