Host offloading with JAX on Intel® Xeon® processors
As Large Language Models (LLMs) continue to scale into the hundreds of billions of parameters, device memory capacity has become a major limiting factor in training, because the intermediate activations produced by every layer in the forward pass must be kept for the backward pass. To reduce device memory pressure, these activations can instead be rematerialized (recomputed) during the backward pass, trading memory for compute. While rematerialization enables larger models to fit within limited device memory, it significantly increases training time and cost.
Intel® Xeon® processors (5th and 6th Gen) with Advanced Matrix Extensions (AMX) enable practical host offloading of selected memory- and compute-intensive components in JAX training workflows. This approach can help teams train larger models, relieve accelerator memory pressure, improve end-to-end throughput, and reduce total cost of ownership—particularly on TPU-based Google Cloud instances.
By publishing these results and implementation details, Google and Intel aim to promote transparency and share practical guidance with the community. This post describes how to enable activation offloading for JAX on TPU platforms and outlines considerations for building scalable, cost-aware hybrid CPU–accelerator training workflows.
Host offloading
Traditional LLM training is usually done on device accelerators alone. However, modern host machines have far more memory than accelerators (512 GB or more) and can offer extra compute power, e.g., the matrix-multiply throughput of Intel® Xeon® Scalable Processors with AMX. Leveraging host resources can be a compelling alternative to rematerialization. Host offloading selectively moves computation or data between host and device to optimize performance and memory usage.
Host memory offloading keeps frequently accessed tensors on the device and spills the rest to CPU memory as an extra level of cache. Activation offloading transfers activations computed on-device in the forward pass to the host, stores them in host memory, and brings them back to the device in the backward pass for gradient computation. This unlocks the ability to train larger models, use bigger batch sizes, and improve throughput.
In this blog post, we provide a practical guide to offload activations through JAX to efficiently train larger models on TPUs with an Intel® Xeon® Scalable Processor.
Enabling memory offloading in JAX
JAX offers multiple strategies for offloading activations, model parameters, and optimizer states to the host. Users can call checkpoint_name() to label a tensor so that a checkpoint policy can later decide how to handle it. The snippet below labels the activation x:
from jax.ad_checkpoint import checkpoint_name
def layer(x, w):
    w1, w2 = w
    x = checkpoint_name(x, "x")  # label this activation as "x"
    y = x @ w1
    return y @ w2, None
Users can then pass a policy from jax.checkpoint_policies to select the appropriate memory optimization strategy for intermediate values. There are three strategies:
- Recomputing during backward pass (default behavior)
- Storing on device
- Offloading to host memory after forward pass and loading back during backward pass
The code below creates a policy that moves x from device memory to pinned host memory after the forward pass.
from jax import checkpoint_policies as cp
policy = cp.save_and_offload_only_these_names(
    names_which_can_be_saved=[],          # No values stored on device
    names_which_can_be_offloaded=["x"],   # Offload activations labeled "x"
    offload_src="device",                 # Move from device memory
    offload_dst="pinned_host",            # To pinned host memory
)
Measuring Host Offloading Benefits on TPU v5p
We examined TPU host offloading with JAX on both fine-tuning and training workloads. All our experiments were run on Google Cloud Platform, using a single v5p-8 TPU instance with a single-host 4th Gen Intel® Xeon® Scalable Processor.
Fine-tuning PaliGemma2: Using the base PaliGemma2 28B model for vision-language tasks, we fine-tuned the attention layers of the language model (Gemma2 27B) while keeping all other parameters frozen. During fine-tuning, we set the LLM sequence length to 256 and the batch size to 256.
The default checkpoint policy is nothing_saveable, which keeps no activations on-device during the forward pass; they are rematerialized during the backward pass for gradient computation. While this approach reduces memory pressure on the TPU, it increases compute time. To apply host offloading, we instead offload the Q, K, and V projection activations using save_and_offload_only_these_names. These activations are transferred to host memory (D2H) during the forward pass and fetched back (H2D) during the backward pass, so the device neither stores nor recomputes them. Figure 2 shows a 10% reduction in training time from host offloading. This translates directly into a similar reduction in TPU core-hours, yielding meaningful cost savings. The complete fine-tuning recipe is available at [JAX host offloading].
(Bottom) Memory analysis with and without host offloading.
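The shape of this policy can be sketched as follows. The label names below ("query_proj", "key_proj", "value_proj") and the qkv_projections helper are hypothetical; the actual labels depend on how the model code tags its activations with checkpoint_name.

```python
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name
from jax import checkpoint_policies as cp

# Hypothetical label names for the Q/K/V projection activations.
QKV_NAMES = ["query_proj", "key_proj", "value_proj"]

def qkv_projections(x, wq, wk, wv):
    # Label each projection output so the policy below can offload it.
    q = checkpoint_name(x @ wq, "query_proj")
    k = checkpoint_name(x @ wk, "key_proj")
    v = checkpoint_name(x @ wv, "value_proj")
    return q, k, v

policy = cp.save_and_offload_only_these_names(
    names_which_can_be_saved=[],             # nothing kept on device
    names_which_can_be_offloaded=QKV_NAMES,  # offload labeled QKV activations
    offload_src="device",
    offload_dst="pinned_host",
)

x = jnp.ones((2, 8))
wq = wk = wv = jnp.ones((8, 8))
q, k, v = qkv_projections(x, wq, wk, wv)
```

Outside of a remat-wrapped computation, checkpoint_name is effectively an identity, so labeling is cheap; the policy only changes behavior during differentiation.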
Training Llama2-13B using MaxText: MaxText offers several rematerialization strategies that can be specified in the training configuration file. We used the policy remat_policy: 'qkv_proj_offloaded' to offload the Q, K, and V projection activations. Figure 3 shows a ~5% reduction in per-step training time compared to fully rematerializing all activations (remat_policy: 'full').
The step time was 5% faster with host offloading.
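In MaxText this is a one-line change in the training configuration. A sketch of the relevant config excerpt, using the two remat_policy values compared above:

```yaml
# MaxText training config excerpt:
# 'full' fully rematerializes all activations (the baseline above);
# 'qkv_proj_offloaded' offloads the QKV projection activations to the host.
remat_policy: 'qkv_proj_offloaded'
```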
When to offload activations
Activation offloading is beneficial when the time to transfer activations across host and device is lower than the time to recompute them. The timing depends on multiple factors such as PCIe bandwidth, model size, batch size, sequence length, activation tensor sizes, compute capabilities of the device, etc. An additional factor is how much the data movement can be overlapped with computation to keep the device busy. Figure 4 demonstrates an efficient overlap of the device-to-host transfer with compute during the backward pass in PaliGemma2 28B training.
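The trade-off in the paragraph above can be estimated with back-of-envelope arithmetic. The numbers below are illustrative assumptions, not measurements, and the model ignores overlap with compute, which shifts the break-even point in favor of offloading:

```python
# Offloading wins when the round-trip transfer time is lower than
# the time to recompute the same activations on the device.
activation_bytes = 2 * 1024**3   # 2 GiB of labeled activations (assumed)
link_bandwidth = 32 * 1024**3    # 32 GiB/s effective host<->device bandwidth (assumed)
recompute_flops = 10e12          # FLOPs to rematerialize those activations (assumed)
device_flops = 200e12            # sustained device FLOP/s (assumed)

transfer_s = 2 * activation_bytes / link_bandwidth  # D2H in fwd + H2D in bwd
recompute_s = recompute_flops / device_flops

decision = "offload" if transfer_s < recompute_s else "rematerialize"
print(f"transfer {transfer_s:.3f}s vs recompute {recompute_s:.3f}s -> {decision}")
```

With these particular numbers recomputation wins; larger activation tensors, faster links, or more expensive recomputation flip the decision, which is why the right policy is workload-specific.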
Host-to-device memory transfers overlap effectively with compute during the backward pass.
Smaller model variants such as PaliGemma2 3B and 9B did not benefit from host offloading because rematerializing all tensors is faster than transferring them to and from the host. Identifying the appropriate workload and offloading policy is therefore crucial to realizing performance gains from host offloading.
Call to Action
If you train on TPUs and are limited by device memory, consider evaluating activation offloading. Start by labeling candidate activations (for example, Q/K/V projections) and compare step time, memory headroom, and overall cost across representative workloads.
In our experiments, we observed up to ~10% improvement in end-to-end training time for larger workloads, which can reduce total cost of ownership (TCO) by shortening time-to-train or enabling the same workload on smaller instances.
Acknowledgments
Emilio Cota and Karlo Basioli from Google, and Eugene Zhulenev (formerly at Google).