AWS – AWS Neuron introduces Neuron Kernel Interface (NKI), NxD Training, and JAX support for training
Today, AWS announces the release of Neuron 2.20, introducing the Neuron Kernel Interface (NKI) (beta), a programming interface for AWS Trainium and Inferentia that enables developers to build optimized compute kernels for new functionality, optimizations, and science innovation. This release also introduces NxD Training (beta), a PyTorch-based library for efficient distributed training with a user-friendly, NeMo-compatible interface, as well as beta support for the JAX framework.
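To illustrate, a minimal NKI kernel might look like the sketch below. It follows the pattern in the NKI beta getting-started examples; the specific helpers (nki.jit, nl.load, nl.store, nl.ndarray, nl.shared_hbm) are assumptions that may differ between Neuron SDK releases.

```python
# Minimal sketch of an NKI element-wise add kernel (NKI is in beta; module paths
# and helper names follow the NKI getting-started examples and may change
# between Neuron SDK releases).
from neuronxcc import nki
import neuronxcc.nki.language as nl


@nki.jit
def tensor_add_kernel(a_input, b_input):
    # Allocate the kernel output in device HBM.
    c_output = nl.ndarray(a_input.shape, dtype=a_input.dtype, buffer=nl.shared_hbm)
    # Load both operands from HBM into on-chip memory.
    a_tile = nl.load(a_input)
    b_tile = nl.load(b_input)
    # Compute on-chip and store the result back to HBM.
    nl.store(c_output, value=a_tile + b_tile)
    return c_output
```

On a Trn1 or Inf2 instance, such a kernel can be invoked from a framework tensor; the NKI documentation describes the full programming model and how kernels integrate with PyTorch and JAX.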
AWS Neuron is the SDK for AWS Inferentia- and Trainium-based instances, purpose-built for generative AI. Neuron integrates with popular ML frameworks such as PyTorch. It includes a compiler, runtime, tools, and libraries to support high-performance training and inference of AI models on Trn1 and Inf2 instances.
This release adds new features and performance improvements for model training and inference. For training, it adds support for Llama 3.1 8B and 70B models with sequence lengths up to 32K, along with torch.autocast() for native PyTorch mixed precision and PEFT LoRA techniques. For inference, Neuron 2.20 adds support for Llama 3.1 (405B, 70B, and 8B) and Diffusion Transformer (DiT) models such as Pixart-alpha and Pixart-sigma. It also adds on-device top-p sampling and 128K context length with Flash Decoding for inference, support for the Rocky Linux 9.0 operating system, and RMSNorm and RMSNormDx operators in the Neuron Compiler.
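As an example of the native mixed-precision path, a training step with torch.autocast() on a Neuron device might look like the sketch below. The torch_xla device setup, the "xla" device_type, and the bfloat16 dtype are illustrative assumptions, not a prescribed recipe.

```python
# Sketch of torch.autocast() mixed-precision training on a Neuron (XLA) device.
# The torch_xla device handling and bfloat16 dtype are illustrative assumptions.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(1024, 1024).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

inputs = torch.randn(8, 1024, device=device)
targets = torch.randn(8, 1024, device=device)

optimizer.zero_grad()
# Run the forward pass and loss in bfloat16 under autocast.
with torch.autocast(device_type="xla", dtype=torch.bfloat16):
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
loss.backward()
optimizer.step()
xm.mark_step()  # materialize the lazily traced XLA graph
```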
For more information, see Neuron Release Notes.