GCP – Announcing a new monitoring library to optimize TPU performance
For more than a decade, TPUs have powered Google’s most demanding AI training and serving workloads. And there is strong demand from customers for Cloud TPUs as well. When running advanced AI workloads, you need to be able to monitor and optimize the efficiency of your training and inference jobs, and swiftly diagnose performance bottlenecks, node degradation, and host interruptions. Ultimately, you need real-time optimization logic built into your training and inference pipelines so you can maximize the efficiency of your applications — whether you’re optimizing for ML Goodput, operational cost, or time-to-market.
Today, we’re thrilled to introduce a new monitoring library for Google Cloud TPUs: a set of observability and diagnostic tools that provides granular, integrated insights into performance and accelerator utilization, so you can continuously assess and improve the efficiency of your Cloud TPU workloads.
Note: If you have shell access to the TPU VM and just need some diagnostic information (e.g., to observe memory usage for a running process), you can use tpu-info, a command-line tool for viewing TPU metrics.
Unlocking dynamic optimization: Key metrics in action
The monitoring library provides snapshot-mode access to a rich set of metrics, such as Tensor core utilization, high-bandwidth memory (HBM) usage, and buffer transfer latency. Metrics are sampled every second (1 Hz) for consistency. See the documentation for a full list of metrics and how you can use them.
You can use these metrics directly in your code to dynamically optimize for greater efficiency. For instance, if your duty_cycle_pct (a measure of utilization) is consistently low, you can programmatically adjust your data pipeline or increase the batch size to better saturate the Tensor core. If hbm_capacity_usage approaches its limits, your code could trigger a dynamic reduction in model size or activate memory-saving strategies to avoid out-of-memory errors. Similarly, hlo_exec_timing (how long operations take to execute on the accelerator) and hlo_queue_size (how many operations are waiting to be executed) can inform runtime adjustments to communication patterns or workload distribution based on observed bottlenecks.
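As an illustration of that kind of runtime logic, here is a minimal sketch. It assumes the values returned by data() can be parsed as numbers (check the documentation for the exact format), and the tune_step helper, its thresholds, and the adjustments it makes are hypothetical placeholders for your own logic:

```python
from libtpu.sdk import tpumonitoring

DUTY_CYCLE_TARGET = 80.0   # percent; assumed threshold, tune for your workload
HBM_USAGE_CEILING = 0.90   # assumed fraction of HBM capacity, tune for your workload

def tune_step(batch_size, model_config):
    """Hypothetical helper: nudge the workload based on current TPU metrics."""
    duty_cycle = tpumonitoring.get_metric(metric_name="duty_cycle_pct").data()
    hbm_usage = tpumonitoring.get_metric(metric_name="hbm_capacity_usage").data()

    # Assumption: data() yields one reading per chip; average them here.
    mean_duty = sum(float(v) for v in duty_cycle) / len(duty_cycle)
    mean_hbm = sum(float(v) for v in hbm_usage) / len(hbm_usage)

    if mean_duty < DUTY_CYCLE_TARGET:
        batch_size = min(batch_size * 2, 4096)          # try to saturate the Tensor core
    if mean_hbm > HBM_USAGE_CEILING:
        model_config["gradient_checkpointing"] = True   # trade compute for memory
    return batch_size, model_config
```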
Let’s see how to set up the library with a couple of realistic examples.
Getting started with the library
The TPU monitoring library is integrated within the LibTPU library. Here’s how to install it:
```
pip install libtpu
```
For JAX or PyTorch users, libTPU is included in your installation when you install jax[tpu] or torch_xla[tpu] (read more about PyTorch/XLA and JAX installation). You can then refer to the library in your Python code with from libtpu.sdk import tpumonitoring, discover supported functionality with sdk.monitoring.help(), and list the available metric names using tpumonitoring.list_supported_metrics().
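Putting those pieces together, a quick sanity check might look like the following minimal sketch. The calls mirror the ones named above and in the examples below; exact return formats can vary between library versions:

```python
from libtpu.sdk import tpumonitoring

# List the metric names this libtpu build can report.
print(tpumonitoring.list_supported_metrics())

# Fetch a single metric and inspect its description and current data.
metric = tpumonitoring.get_metric(metric_name="duty_cycle_pct")
print(metric.description())
print(metric.data())
```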
Example 1. Monitoring TPU duty cycle during training for dynamic adjustment
Integrate duty_cycle_pct logging into your JAX training loop to track how busy the TPUs are.
```python
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time

# --- Your JAX model and training setup would go here ---
# --- Example placeholder model and data (replace with your actual setup) ---
def simple_model(x):
    return jnp.sum(x)

def loss_fn(params, x, y):
    preds = simple_model(x)
    return jnp.mean((preds - y) ** 2)

def train_step(params, x, y, optimizer):
    grads = jax.grad(loss_fn)(params, x, y)
    if optimizer is None:
        return params - 0.01 * grads  # plain SGD fallback so this placeholder runs as-is
    return optimizer.update(grads, params)

key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0])  # Example params
optimizer = None  # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))

num_epochs = 10
log_interval_steps = 2  # Log duty cycle every 2 steps

for epoch in range(num_epochs):
    for step in range(5):  # Example steps per epoch
        params = train_step(params, data_x, data_y, optimizer)

        if (step + 1) % log_interval_steps == 0:
            # --- Integrate TPU Monitoring Library here to get duty_cycle ---
            duty_cycle_metric = tpumonitoring.get_metric(metric_name="duty_cycle_pct")
            duty_cycle_data = duty_cycle_metric.data()
            print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
            print(f"  Description: {duty_cycle_metric.description()}")
            print(f"  Data: {duty_cycle_data}")
            # --- End TPU Monitoring Library Integration ---

        # --- Rest of your training loop logic ---
        time.sleep(0.1)  # Simulate some computation

print("Training complete.")
```
A consistently low duty cycle suggests potential CPU bottlenecks or inefficient data loading. This example simply prints the value, but in the real world you could trigger re-sharding or other corrective actions.
Example 2. Checking HBM utilization before JAX inference for resource management
While running JAX programs on Cloud TPUs, optimizing HBM usage during compilation presents a significant opportunity. By proactively getting insights into potential TPU memory reservations during compilation, you can unlock greater efficiency and prevent out-of-memory (OOM) errors, which is especially crucial for scaling large models. By checking the hbm_capacity_usage metric from the monitoring library, you can see the available HBM, allowing for dynamic adjustments to your inference strategy and mitigating memory errors.
```python
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring

# --- Your JAX model and inference setup would go here ---
# --- Example placeholder model (replace with your actual model loading/setup) ---
def simple_model(x):
    return jnp.sum(x)

key = jax.random.PRNGKey(0)
params = None  # Load your trained parameters

# Integrate TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric(metric_name="hbm_capacity_usage")
hbm_util_data = hbm_util_metric.data()
print("HBM Utilization Before Inference:")
print(f"  Description: {hbm_util_metric.description()}")
print(f"  Data: {hbm_util_data}")
# End TPU Monitoring Library Integration

# Your inference logic
input_data = jnp.ones((1, 10))  # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)

print("Inference complete.")
```
If HBM usage is unexpectedly high, you might consider optimizing your model size, batching strategy, or input data pipelines.
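For instance, a guard along these lines could run before you build your inference batches. The pick_inference_batch_size helper, its threshold, and the interpretation of the metric's data() values as a fraction of capacity are all assumptions to calibrate against your TPU type and library version:

```python
from libtpu.sdk import tpumonitoring

def pick_inference_batch_size(default_batch_size=32, reduced_batch_size=8,
                              usage_threshold=0.85):
    """Hypothetical helper: shrink the batch when HBM usage looks too high."""
    readings = tpumonitoring.get_metric(metric_name="hbm_capacity_usage").data()
    worst = max(float(v) for v in readings)  # assumed: one numeric reading per chip
    return reduced_batch_size if worst > usage_threshold else default_batch_size

batch_size = pick_inference_batch_size()
print(f"Using inference batch size: {batch_size}")
```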
Maximize your TPU utilization
In this post, we showed you two simple examples of how you can improve the efficiency of your TPU workloads with some proactive monitoring. The TPU monitoring library can help you improve the utilization of your accelerators, dynamically tune them to your use cases, and achieve the best cost efficiency.
To learn more about the TPU monitoring library, please visit the documentation. To get started with Cloud TPUs, please visit our Intro to TPU documentation.