Engineering teams use Ray to scale AI workloads across a wide range of hardware, including both GPUs and Cloud TPUs. While Ray provides the core scaling capabilities, developers have often had to manage the unique architectural details of each accelerator themselves. For Cloud TPUs, this included their specific networking model and the Single Program, Multiple Data (SPMD) programming style.
As part of our partnership with Anyscale, we are working on reducing the engineering effort to get started with TPUs on Google Kubernetes Engine (GKE). Our goal is to make the Ray experience on TPUs as native and low-friction as possible.
Today, we are launching several key improvements that help make that possible.
Ray TPU Library for improved TPU awareness and scaling in Ray Core
TPUs have a unique architecture and a specific programming style called SPMD. Large AI jobs run on a TPU slice, which is a collection of chips connected by high-speed networking called interchip interconnect (ICI).

Previously, you needed to manually configure Ray to be aware of this specific hardware topology. This was a major setup step, and if done incorrectly, jobs could get fragmented resources from different, unconnected slices, causing severe performance bottlenecks.
This new library, ray.util.tpu, abstracts away these hardware details. It uses a feature called SlicePlacementGroup along with the new label_selector API to automatically reserve the entire, co-located TPU slice as one atomic unit. This guarantees the job runs on unified hardware, preventing performance issues from fragmentation. Because Ray couldn’t guarantee this single-slice atomicity before, building reliable true multi-slice training (which intentionally spans multiple unique slices) was impossible. This new API also provides the critical foundation for Ray users to use Multislice technology to scale using multiple TPU slices.
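To make the fragmentation problem concrete, here is a minimal, library-free sketch in plain Python. It is not how `ray.util.tpu` or `SlicePlacementGroup` is implemented; the `Host` type, the slice names, and `reserve_whole_slice` are hypothetical names used only to illustrate the "reserve one whole slice or nothing" idea.

```python
from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Host:
    """A hypothetical TPU host record (illustrative, not a Ray object)."""
    name: str
    slice_name: str  # physical slice this host belongs to
    topology: str    # chip topology of that slice, e.g. "4x4"
    free: bool = True


def reserve_whole_slice(
    hosts: list[Host], topology: str, hosts_needed: int
) -> Optional[list[Host]]:
    """Return the hosts of one fully free slice, or None if no such slice exists.

    Hosts from different slices are never combined, even when enough
    free hosts exist in total.
    """
    free_by_slice: dict[str, list[Host]] = defaultdict(list)
    for host in hosts:
        if host.free and host.topology == topology:
            free_by_slice[host.slice_name].append(host)

    for slice_name in sorted(free_by_slice):
        members = free_by_slice[slice_name]
        if len(members) == hosts_needed:
            return members
    return None


# Two half-busy slices: four free hosts in total, but no single whole slice.
cluster = [Host(f"a{i}", "slice-a", "4x4", free=i < 2) for i in range(4)]
cluster += [Host(f"b{i}", "slice-b", "4x4", free=i < 2) for i in range(4)]
fragmented_result = reserve_whole_slice(cluster, "4x4", hosts_needed=4)

# Adding a fully free slice makes the request satisfiable.
cluster += [Host(f"c{i}", "slice-c", "4x4") for i in range(4)]
whole_result = reserve_whole_slice(cluster, "4x4", hosts_needed=4)
```

In this toy model the first request returns `None` rather than mixing hosts from `slice-a` and `slice-b`, which is the behavior the atomic reservation described above is meant to guarantee.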
**Expanded support for JAX, Ray Train, and Ray Serve**
Our developments cover both training and inference. For training, Ray Train now offers alpha support for JAX (via JaxTrainer) and PyTorch on TPUs.
The JaxTrainer API simplifies running JAX workloads on multi-host TPUs. It now automatically handles the complex distributed host initialization. As shown in the code example below, you only need to define your hardware needs—like the number of workers, topology, and accelerator type—within a simple ScalingConfig object. The JaxTrainer takes care of the rest.
This is a significant improvement because it solves a critical performance problem: resource fragmentation. Previously, a job requesting a “4x4” topology (which must run on a single co-located hardware unit called a slice) could instead receive fragmented resources—for example, eight chips from one physical slice and eight chips from a different, unconnected slice. This fragmentation was a major bottleneck, as it prevented the workload from using the high-speed ICI interconnect that only exists within a single, unified slice.
Example of how the JaxTrainer simplifies training on multi-host TPU:
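A minimal sketch of such a configuration follows. It rests on several assumptions that should be checked against the Ray Train documentation for your Ray version: that `JaxTrainer` is importable from `ray.train.v2.jax`, that `ScalingConfig` accepts `use_tpu`, `topology`, `accelerator_type`, and `resources_per_worker`, and that each host exposes four TPU chips. The helper names (`chips_in_topology`, `tpu_scaling_kwargs`, `build_trainer`) are hypothetical. Ray and JAX are imported lazily so the arithmetic can be checked without a TPU cluster.

```python
from math import prod

# Assumption: four TPU chips per host; adjust for your machine type.
CHIPS_PER_HOST = 4


def chips_in_topology(topology: str) -> int:
    """Number of chips in a topology string such as "4x4"."""
    return prod(int(dim) for dim in topology.split("x"))


def tpu_scaling_kwargs(
    topology: str, accelerator_type: str, chips_per_host: int = CHIPS_PER_HOST
) -> dict:
    """Build keyword arguments for ScalingConfig: one worker per TPU host."""
    chips = chips_in_topology(topology)
    if chips % chips_per_host != 0:
        raise ValueError(f"{topology} is not divisible into {chips_per_host}-chip hosts")
    return {
        "use_tpu": True,
        "num_workers": chips // chips_per_host,
        "topology": topology,
        "accelerator_type": accelerator_type,
        "resources_per_worker": {"TPU": chips_per_host},
    }


def train_loop_per_worker() -> None:
    """Placeholder training loop; runs once on every TPU host."""
    import jax  # imported lazily so this module loads without JAX installed

    print(f"process {jax.process_index()} sees {jax.device_count()} devices")


def build_trainer():
    """Create the trainer; requires Ray with TPU support to be installed."""
    from ray.train import ScalingConfig
    from ray.train.v2.jax import JaxTrainer  # assumed import path

    return JaxTrainer(
        train_loop_per_worker=train_loop_per_worker,
        scaling_config=ScalingConfig(**tpu_scaling_kwargs("4x4", "TPU-V6E")),
    )
```

On a cluster with a free 4x4 slice, `build_trainer().fit()` would launch the loop on each host. Under the four-chips-per-host assumption, a "4x4" topology works out to 16 chips and four workers.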
Ray Serve APIs support TPUs, and with the improvements we have made to vLLM TPU, you can continue to use vLLM with Ray when moving to TPUs. This lets you take the same stack you use on GPUs and run it on TPUs with minimal code changes.
Label-based Scheduling API for easy obtainability
The new Label-Based Scheduling API integrates with GKE custom compute classes. A custom compute class is a simple way to define a named hardware configuration. For example, you can create a class called cost-optimized that tells GKE to try acquiring a Spot instance first, then fall back to a Dynamic Workload Scheduler FlexStart instance, and finally to a reserved instance as a last resort. The new Ray API lets you use classes directly from Python. With a simple label_selector, you can request hardware like “TPU-V6E” or target your cost-optimized class, all without managing separate YAML files.
This same label_selector mechanism also exposes deep hardware control for TPUs. As GKE provisions the TPU pods for a slice, it injects metadata (like worker rank and topology) into each one. KubeRay (which manages Ray on GKE) then reads this GKE-provided metadata and automatically translates it into Ray-specific labels as it creates the nodes. This provides key information like the TPU generation (ray.io/accelerator-type), the physical chip topology (ray.io/tpu-topology), and the worker rank within the slice (ray.io/tpu-worker-id). These node labels let you use a Ray label_selector to pin SPMD workloads to specific, co-located hardware, such as a “4x4” topology or a particular worker rank.
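The way node labels and a label_selector interact can be sketched in a few lines of plain Python. This is an illustration only, not Ray's scheduler: it handles exact-match selectors, whereas Ray's real selector syntax may support more. The node names and the `select_nodes` helper are hypothetical; the label keys are the ones named above.

```python
from __future__ import annotations


def matches(node_labels: dict[str, str], label_selector: dict[str, str]) -> bool:
    """True when every selector key is present on the node with the same value."""
    return all(node_labels.get(key) == value for key, value in label_selector.items())


def select_nodes(
    nodes: dict[str, dict[str, str]], label_selector: dict[str, str]
) -> list[str]:
    """Names of all nodes whose labels satisfy the selector, in sorted order."""
    return sorted(name for name, labels in nodes.items() if matches(labels, label_selector))


# Hypothetical cluster: four hosts of one v6e 4x4 slice plus one 2x2 host.
nodes = {
    f"tpu-host-{rank}": {
        "ray.io/accelerator-type": "TPU-V6E",
        "ray.io/tpu-topology": "4x4",
        "ray.io/tpu-worker-id": str(rank),
    }
    for rank in range(4)
}
nodes["small-host"] = {
    "ray.io/accelerator-type": "TPU-V6E",
    "ray.io/tpu-topology": "2x2",
    "ray.io/tpu-worker-id": "0",
}

# Pin to the whole 4x4 slice, or to worker rank 0 within it.
slice_hosts = select_nodes(nodes, {"ray.io/tpu-topology": "4x4"})
rank_zero = select_nodes(
    nodes, {"ray.io/tpu-topology": "4x4", "ray.io/tpu-worker-id": "0"}
)
```

Here `slice_hosts` contains the four hosts of the 4x4 slice and `rank_zero` contains only `tpu-host-0`; the 2x2 host is excluded in both cases because its topology label does not match.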
In the example below, a Ray user requests a v6e-32 TPU slice but instructs GKE, through a custom compute class, to fall back to a v5e-16 slice if the v6e-32 slice is not available. Similarly, the user could start by requesting Spot or DWS resources and fall back to reserved instances if those are not available.
Developers select compute and node pools:

    import ray

    @ray.remote(
        num_cpus=1,
        label_selector={
            "ray.io/tpu-pod-type": "v6e-32",
            "gke-flex-start": "true",
        },
        fallback_strategy=[
            {
                "label_selector": {
                    "ray.io/tpu-pod-type": "v5litepod-16",
                    "reservation-name": "v5e-reservation",
                }
            },
        ],
    )
    def tpu_task():
        # Attempts to run on a node in a v6e 4x8
        # TPU slice, falling back to a node in a
        # v5e 4x4 TPU slice if v6e is unavailable.
        ...

Platform admins set up Kubernetes:

    apiVersion: cloud.google.com/v1
    kind: ComputeClass
    metadata:
      name: cost-optimized
    spec:
      priorities:
      - flexStart:
          enabled: true
        tpu:
          type: tpu-v6e-slice
          count: 8
          topology: 4x8
      - tpu:
          type: tpu-v5-lite-podslice
          count: 4
          topology: 4x4
        reservations:
          specific:
          - name: v5e-reservation
          affinity: Specific
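The ordering implied by a fallback strategy can be illustrated with a small, library-free Python sketch. It is a toy model under stated assumptions, not Ray's or GKE's actual scheduling logic: selectors are tried in order (primary first, then each fallback), matching is exact, and the `pick_selector` helper and the cluster state are hypothetical.

```python
from __future__ import annotations

from typing import Optional


def matches(node_labels: dict[str, str], selector: dict[str, str]) -> bool:
    """Exact-match check: every selector key must equal the node's label."""
    return all(node_labels.get(key) == value for key, value in selector.items())


def pick_selector(
    primary: dict[str, str],
    fallback_strategy: list[dict],
    nodes: list[dict[str, str]],
) -> Optional[dict[str, str]]:
    """Return the first selector, primary then fallbacks in order, that a node satisfies."""
    candidates = [primary] + [entry["label_selector"] for entry in fallback_strategy]
    for selector in candidates:
        if any(matches(labels, selector) for labels in nodes):
            return selector
    return None


primary = {"ray.io/tpu-pod-type": "v6e-32", "gke-flex-start": "true"}
fallback_strategy = [
    {
        "label_selector": {
            "ray.io/tpu-pod-type": "v5litepod-16",
            "reservation-name": "v5e-reservation",
        }
    },
]

# Hypothetical cluster state: only the reserved v5e slice is available.
available_nodes = [
    {"ray.io/tpu-pod-type": "v5litepod-16", "reservation-name": "v5e-reservation"},
]
chosen = pick_selector(primary, fallback_strategy, available_nodes)
```

With only the reserved v5e node present, `chosen` is the fallback selector; if a matching v6e-32 Flex Start node were added to `available_nodes`, the primary selector would win because it is tried first.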
TPU metrics and logs in one place
You can now see key TPU performance metrics, like TensorCore utilization, duty cycle, High-Bandwidth Memory (HBM) usage, and memory bandwidth utilization, directly in the Ray Dashboard. We’ve also added low-level libtpu logs. This makes debugging much faster, as you can immediately check if a failure is caused by the code or by the TPU hardware itself.
Get started today
Together, these updates are a significant step toward making TPUs a seamless part of the Ray ecosystem. They make adapting your existing Ray applications between GPUs and TPUs a much more straightforward process. Here’s how to learn more and get started:
Review the documentation:
- **JAX workloads:** See the new Get Started with JAX guide for using the JaxTrainer, and learn more about JaxTrainer.
- **TPU metrics:** View TPU metrics in the Ray Dashboard or Grafana.
Request TPU capacity: Get started quickly with DWS Flex Start for TPUs, which provides access to TPUs for jobs that run for less than 7 days.