Training Marin 32B: What an open lab can build with TPUs, JAX, and a little persistence

Thursday, December 18, 2025

Last summer, we partnered with Google to share how Marin trained a fully open 8B foundation model using JAX and TPUs. Since then, our process hasn't changed much, but the scale has. Over the summer, we trained a 32B model entirely in the open, and most days there was just one person keeping the run moving.

Large-scale training is usually associated with big teams and bigger infrastructure. Large model releases typically have hundreds of authors. Marin tests a different hypothesis: using open source software and data, small teams can train serious foundation models if the tooling is good, the platform is stable, and the process is transparent. The Marin 32B run was our strongest validation yet.


A model built with one hand on the helm

Marin was started at Stanford University's Center for Research on Foundation Models with the goal of building radically open foundation models. In May, we released Marin 8B Base, which bested the popular Llama 3.1 8B Base on 14 of 19 benchmarks. Marin 8B was trained using Google TPU v4 and TPU v5e from the TPU Research Cloud.

Building on that success, we set out to build a 32B model starting in June. Our 32B training run followed Marin's usual "Tootsie Roll" style: start with a solid recipe, instrument heavily, and adapt mid-flight when necessary. That flexibility matters, because the first time you train at a larger scale, issues inevitably arise.

The timing, however, was less than ideal, as universities tend to empty out over the summer. Students graduate, get internships, go home, or travel the world. Marin was no different. By June, our team was down to one full-time research engineer, with a few PhD students providing guidance when they weren't busy with their dissertations. Nevertheless, we pushed forward.

To spoil the ending, the model turned out quite well. On release, Marin 32B Base was the best open source base model, and it outperformed comparable open-weights models like Google's Gemma 3 27B PT on 24 of 42 base-model evaluations.

There were many bumps along the way, resulting in multiple mid-run corrections, but through it all Google's TPU infrastructure stayed rock-solid, and JAX's predictable performance let us iterate quickly. This meant that even with a tiny team, we could diagnose, patch, and continue training without losing momentum.

To be blunt: one researcher kept the 32B run alive all summer, juggling preemptible slices, rebuilding optimizer state, switching architectures, and generally shepherding ~6.4 trillion tokens across v5p and v4 pods—while mostly working on other Marin projects. The fact that this was possible speaks to the stability of the TPU platform and the maturity of the JAX/Marin stack.

The short version of a long summer

Our retrospective goes into much more detail about every spike, switch and cooldown. Here's the condensed version.

We began with a Llama-3-style 32B backbone and our best 8B data mix, running on preemptible TPU v5p pods. Preemptions were predictable, and recovery was nearly automatic. As availability tightened, however, we moved to dedicated TPU v4 capacity. After a slight tweak to gradient checkpointing to accommodate the older hardware (made easy by JAX's built-in support), we were back up and running and performance stayed excellent.

Around 70k steps, persistent loss spikes appeared. We tried clipping, update-norm guards, skip-step heuristics, "necromancy" (rebuilding optimizer state), and swapping in optimizers like Muon. Nothing helped. The model needed architectural support.
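
A skip-step guard of this kind can be written as a small custom optax transformation. The sketch below illustrates the technique under our own assumptions (variable names and threshold are illustrative); it is not Marin's actual training code, and as noted above, guards like this were not enough on their own.

    import jax
    import jax.numpy as jnp
    import optax

    def skip_large_updates(inner: optax.GradientTransformation,
                           max_norm: float) -> optax.GradientTransformation:
        """Skip any step whose global gradient norm exceeds max_norm:
        the update is zeroed and the optimizer state is left unchanged."""

        def init(params):
            return inner.init(params)

        def update(grads, state, params=None):
            updates, new_state = inner.update(grads, state, params)
            skip = optax.global_norm(grads) > max_norm
            # Zero the update and keep the old optimizer state on a skipped step.
            updates = jax.tree_util.tree_map(
                lambda u: jnp.where(skip, jnp.zeros_like(u), u), updates)
            new_state = jax.tree_util.tree_map(
                lambda new, old: jnp.where(skip, old, new), new_state, state)
            return updates, new_state

        return optax.GradientTransformation(init, update)

    # e.g. optimizer = skip_large_updates(optax.adamw(3e-4), max_norm=1.0)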

So, we warm-started the run onto a Qwen3-style architecture, which is the same as the Llama 3 architecture, except that it adds QK-Norm to attention. After a brief loss bump, the spikes vanished. The model recovered to its expected trajectory within ~10 billion tokens and remained stable.
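
To show roughly what QK-Norm changes, here is the core of scaled dot-product attention with the queries and keys RMS-normalized per head before the dot product. This is a simplified sketch (no learned scale, rotary embeddings, or masking), not Marin's actual attention code.

    import jax
    import jax.numpy as jnp

    def rms_norm(x, eps=1e-6):
        # Normalize over the head dimension; real implementations add a learned scale.
        return x * jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)

    def attention_with_qk_norm(q, k, v):
        # q, k, v: [batch, heads, seq, head_dim]
        q, k = rms_norm(q), rms_norm(k)        # the QK-Norm step
        scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
        return jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(scores, axis=-1), v)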

Towards the end of training, it was time for a cooldown. When training LLMs, one "cools down" the model by lowering the learning rate and shifting the data mix toward higher-quality data. Our first cooldown surfaced two issues: contamination from a cached math dataset, and a training-loss phase shift caused by our linear-congruential shuffle. Switching to a Feistel-based shuffle fixed the latter completely. After cleaning the data, the second cooldown ran smoothly and produced the final model.
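
To give a flavor of why the shuffle matters, here is a minimal, purely illustrative Feistel-style index permutation (not the shuffle our stack actually ships): each dataset index is mixed through a few keyed rounds, so nearby indices do not land in correlated positions the way they can under a linear-congruential scheme.

    def feistel_permute(index: int, keys: list, domain_bits: int) -> int:
        """Map a dataset index to a pseudo-random position with a small Feistel
        network; domain_bits must be even, and out-of-range outputs can simply
        be permuted again (cycle-walking)."""
        half = domain_bits // 2
        mask = (1 << half) - 1
        left, right = index >> half, index & mask
        for key in keys:
            # Any keyed integer mixing function works as the round function.
            f = (right * 0x9E3779B1 + key) & mask
            left, right = right, left ^ f
        return (left << half) | right

    # Example: a 4-round permutation over a 2**20-element dataset.
    # order = [feistel_permute(i, [11, 23, 37, 53], 20) for i in range(2**20)]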

The result: a strong, open 32B base model

Marin 32B Base is a competitive open-source base model. It outperformed Olmo 2 32B Base—the previous best fully open-source base model—on 32 of 42 tasks, and it performs especially well on knowledge-heavy evaluations like ARC, BoolQ, and PIQA.

Head-to-head, Marin 32B Base also beat Gemma 3 27B PT on 24 of 42 tasks, and its overall average rank places it alongside Qwen 2.5 32B and the newer Olmo 3 32B models. On our evaluation suite, Marin 32B Base actually ties Olmo 3 32B Base in win rate, despite Olmo 3 being trained by a much larger team and arriving a month later.

Mean rank across our evaluation suite (lower is better). Marin 32B Base lands in the top cluster of open(-weight) models, alongside Qwen 2.5 and Olmo 3, and ahead of Gemma 3 27B PT and Olmo 2 32B. Gray bars indicate open weight models, while blue bars indicate open source models.

While Olmo 3 32B Base now comfortably leads on math and coding benchmarks, Marin 32B Base holds its own and still leads on many knowledge QA evaluations. For a model trained with a fraction of the team size typically expected for a 30B-scale run, we're proud of where it landed.

Because Marin 32B Base (like Olmo 3 32B) is open source, the weights, code, data recipes, and every experimental detour are public. Anyone can reproduce, audit, or build on the work.


The stack that made it possible

TPU stability across large slices

During the run, we moved across preemptible v5p-512 slices coordinated with Cloud TPU Multislice, a v4-2048 slice for the long middle, and several mid-run architectural transitions. Throughout, TPUs were completely reliable for us: no mysterious hangs, no collective-op debugging. Preemptions were predictable and easy to recover from.

JAX + Levanter = predictable performance

Levanter builds on JAX's XLA compilation. In practice, what mattered for us was deterministic restarts, stable MFU at scale without custom kernels, and JAX's activation checkpointing, which made the v5p to v4 migration easy.
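
As an example of how little code an activation-checkpointing change involves, here is a sketch using jax.checkpoint (also known as jax.remat) around a placeholder transformer block; the block itself is hypothetical, not Levanter's.

    import jax
    import jax.numpy as jnp

    def block(params, x):
        # Stand-in for one decoder layer.
        h = jax.nn.gelu(jnp.dot(x, params["w1"]))
        return jnp.dot(h, params["w2"])

    # Recompute this block's activations in the backward pass instead of storing
    # them, trading compute for memory; a policy from jax.checkpoint_policies can
    # tune exactly which intermediates are saved.
    block_remat = jax.checkpoint(block)

    def forward(layer_params, x):
        for p in layer_params:
            x = block_remat(p, x)
        return x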

Marin's experiment system

Marin logs every step of the experimental pipeline: hyperparameters, code versions, datasets, metrics, and artifacts. Even with architectural switches and restarts, the run never devolved into a tangle of scripts. And because it's all open, anyone can retrace or reproduce the training.

What's next

Marin 32B Base is a strong base model, but we're not done. Here's what's coming next:

  • A reasoning-optimized Marin 32B
  • Hardened multislice TPU support for smoother preemptible training
  • Exploring MoE variants for the next scale
  • Continuing to release everything, including successes and failures, openly

Closing thought

Training a 32B model with a small team isn't about heroics but about using the right tools and infrastructure. TPUs' reliability, JAX's clarity and performance, and Marin's open, reproducible process provided the leverage we needed. If the 8B run showed that open labs can build credible models, the 32B run showed they can do it at scale: quietly, steadily, and with far fewer people than you might expect.

SpatialReasoner: Teaching VLMs to "see" structure — Accelerated with Tunix on TPUs

Wednesday, December 17, 2025

Introduction

We are seeing increasing interest in Tunix among researchers focused on the post-training phase of model development. As a native JAX library, Tunix offers the flexibility needed to refine foundation models, including Vision-Language Models (VLMs) and not just LLMs, helping them significantly improve their spatial reasoning capabilities.

Today, we are highlighting the work of the PLAN Lab (Perception and LANguage Lab) at the University of Illinois Urbana-Champaign (UIUC). To address the critical lack of spatial awareness in VLMs, they built SpatialReasoner-R1, a model capable of fine-grained spatial logic. They utilized Tunix and leveraged the Google TPU Research Cloud (TRC) to scale their experiments.

In this blog, Professor Ismini Lourentzou and her team explain how they used Tunix's modular design to implement novel alignment algorithms and improve spatial reasoning in VLMs.

The "Where" Problem in VLMs

Modern Vision-Language Models (VLMs) can describe images and answer basic visual questions with impressive fluency. However, they often struggle with fine-grained spatial understanding. If you ask a VLM to estimate distances, directions, or the precise relative positions of objects, it frequently "hallucinates" coordinates or produces inconsistent reasoning with vague answers.

These capabilities are critical for real-world applications, such as robotics, where precise spatial reasoning enables safe and intelligent interaction with physical environments.

To bridge this gap, we developed SpatialReasoner-R1 (in 4B and 8B versions), a model trained to perform step-by-step, visually grounded spatial reasoning. Our 8B fDPO model achieves 95.59 Qualitative Accuracy and 77.3 Quantitative Accuracy on the SPATIALRGPT-Bench, outperforming the strongest baseline by ~9% in average accuracy while preserving strong general vision-language abilities.

The Method: Fine-Grained Direct Preference Optimization (fDPO)

The secret sauce behind SpatialReasoner-R1 is a new technique called Fine-Grained Direct Preference Optimization (fDPO).

Standard alignment methods (like DPO) usually give a model a simple "thumbs up" or "thumbs down" for an entire response. But spatial reasoning is complex: a model might correctly identify an object yet make a flawed logical inference about its location.

fDPO introduces segment-specific preference granularity. We optimize separate loss components for:

  1. Descriptive Grounding: Does the model correctly perceive and describe the objects in the image?
  2. Logical Reasoning: Is the step-by-step deduction sound, and does it follow coherent spatial logic?
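
As a rough illustration of what "segment-specific" means in practice, the sketch below computes a DPO-style loss separately over the descriptive and reasoning token spans and combines them with per-segment weights. It is our simplified reading of the idea, not the actual fDPO implementation, and for brevity it applies the same segment masks to the chosen and rejected responses.

    import jax
    import jax.numpy as jnp

    def segment_dpo_loss(logp_chosen, logp_rejected,
                         ref_logp_chosen, ref_logp_rejected,
                         desc_mask, reason_mask,
                         beta=0.1, w_desc=1.0, w_reason=1.0):
        # Per-token log-probs: [batch, seq]; each mask marks one segment's tokens.
        def seg_margin(mask):
            chosen = ((logp_chosen - ref_logp_chosen) * mask).sum(-1)
            rejected = ((logp_rejected - ref_logp_rejected) * mask).sum(-1)
            return chosen - rejected

        loss_desc = -jax.nn.log_sigmoid(beta * seg_margin(desc_mask))
        loss_reason = -jax.nn.log_sigmoid(beta * seg_margin(reason_mask))
        return (w_desc * loss_desc + w_reason * loss_reason).mean()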

To generate high-quality training signals, we built a Multi-Model Monte Carlo Tree Search (M3CTS) data generation pipeline, which constructs diverse reasoning trajectories that guide the model toward reliable spatial understanding.

Tunix: Modularity for Novel Research

Implementing a custom objective like fDPO can be difficult in rigid frameworks. Tunix addresses this by providing a well-structured and extensible DPOTrainer that makes it possible to introduce new alignment objectives without reengineering the training pipeline.

This modularity meant we could reuse the entire underlying training stack—sharding, data loading, and loop management—while injecting our novel research logic with just a small amount of well-contained code.

While our backbone model (Sa2VA) required specific architectural handling, the core fDPO algorithm is model-agnostic. We found the Tunix experience smooth and well-documented, making it easy to prototype and iterate on fine-tuning workflows without reinventing the wheel.

Google TRC & TPUs: Reliability at Scale

Training a model to reason over long horizons requires significant compute. The Google TPU Research Cloud (TRC) provided the infrastructure we needed to make large-scale training practical.

  • Scalability: Tunix's integration with TPUs allowed us to scale our experiments seamlessly.
  • Reliability: The system performed reliably across multiple TPU runs, which was essential for conducting large-scale spatial reasoning benchmarks.
  • Support: The Google Tunix and TRC teams assisted with infrastructure setup and experiment design, helping us refine our multi-model exploration strategy.

Looking Ahead: Open Source Contributions

We believe that open-source, extensible tools like Tunix are vital for fostering innovation. They lower the barrier for researchers to experiment with new training objectives without rebuilding core infrastructure.

In that spirit, we contributed our fDPO implementation back to the Tunix ecosystem. We have open-sourced the core fDPO components, enabling the community to apply segment-specific preference optimization to their own models.

Get Started

You can explore our research and the tools we used below:

GRL: Turning verifiable games into a post-training suite for LLM agents with Tunix on TPUs

Tuesday, December 16, 2025

Introduction

JAX is widely recognized for its power in training large-scale AI models. However, a primary bottleneck in the next phase of AI development—LLM post-training with Reinforcement Learning (RL)—is the scarcity of environments with verifiable rewards.

Today, we are highlighting the work of the GRL (Game Reinforcement Learning) team at UC San Diego. To solve the data bottleneck, they have built a pipeline that turns video games into rigorous reasoning benchmarks. They utilized Tunix, a JAX-native, research-friendly RL framework with multi-host, multi-turn capabilities, and leveraged the Google TPU Research Cloud (TRC) to scale their experiments. The results are promising: this approach has yielded significant improvements in model quality, particularly in planning and reasoning tasks, showing that games can be a viable substrate for serious AI capability training.

In this blog, the GRL team explains how they are combining game environments, the modular Tunix library for RL post-training, and TPU compute to train the next generation of agents.


Why Verifiable Games for LLM Post-Training?

Current RL post-training has shown strong gains in domains like math and coding because success can be auto-checked. However, these settings are often narrow and short-term. We are effectively overfitting RL to clean problems, while the next generation of agents must operate in messy, multi-step worlds.

To unlock RL as a systematic method for reasoning, we need a diverse pool of environments where rewards are grounded in explicit, machine-checkable rules. Games are this missing, underused substrate.

  1. The Performance Gap: LLMs still perform surprisingly poorly on many strategy games, revealing a clear gap between model behavior and human-level interactive competence.
  2. Verifiable Signals: Games come with built-in verifiable signals—wins, scores, puzzle completion—meaning outcomes are automatically and unambiguously graded without human labeling.
  3. Long-Horizon Reasoning: Unlike short QA tasks, games force models to plan, explore, and reason over many steps.
  4. Abundance: Decades of RL research have produced a standardized ecosystem of diverse environments ready to be recycled.

Game Reinforcement Learning (GRL): A Unified Game-to-Post-Training Pipeline

To harness this ecosystem, we built GRL, a comprehensive suite designed to recycle diverse game environments into a reusable post-training resource. Our mission is to prioritize environments with executable success checks—ranging from text-based puzzles to embodied 3D worlds and web/GUI workflows. Our code and ecosystem live under the LM Games organization (lmgame.org).

GRL provides three key capabilities:

  • A Unified Pipeline: We standardize the conversion of games into RL-ready environments with structured states and consistent metrics. This makes results comparable across models and research groups.
  • Versatile Configuration: Researchers can tailor interaction styles (e.g., max_turns, natural language feedback) while mixing training data from different tasks seamlessly. This allows for training on puzzles, math, and web tasks within a single run.
  • Algorithm-Agnostic Interface: GRL works with any agentic training algorithm. While we frequently use PPO, the system serves as a robust testbed for developing new RL techniques.

The Engine: Plugging into the Tunix RL Framework

Designed for Research Flexibility and Multi-Turn Agents

In practice, plugging a GRL game agent into Tunix is seamless thanks to its modular design. Tunix is built specifically to support multi-turn agentic tasks, allowing researchers to leverage native one-turn inference APIs to achieve complex multi-turn rollouts, then batch those outputs directly back into the training flow. This research flexibility is key; the framework is lightweight enough for quick iteration and benchmarking, yet modular enough to allow fine-grained adjustments to reward functions, algorithms, and hardware-aware settings like mesh sizes.

We first define an agent_cfg (see the picture above) that tells the system which game to play (e.g., Sokoban or Tetris), how the LLM should talk (chat template + reasoning style), and its budgets (max turns, tokens per turn, action format). On the Tunix side, we then load a pre-trained model into three roles (actor, critic, and reference), and build a ClusterConfig to specify rollout and training configs and a PpoConfig to specify RL hyperparameters. The glue is minimal and the layout is clear and research-friendly: once agent_cfg, ppo_cfg, and cluster_cfg are defined, we construct an RLCluster and pass everything into PpoLearner, which gives us a complete multi-turn PPO trainer in JAX.

Our multi-turn RL workflow is equally lightweight from the user's point of view. For example, with a 5-turn budget, the trainer repeatedly lets the LLM "play" the game for up to five conversational turns: at each turn it sees the current grid or state, reasons in language using the chat template, outputs a series of actions, and receives the next state and a verifiable reward signal (win/loss/score/step penalty). GRL's agent and environment configs handle all the orchestration: they log observations, actions, and rewards into structured trajectories, which Tunix then turns into token-level advantages and returns for PPO updates. You don't manually build datasets or rollouts; the trainer owns the loop: interact -> log -> compute rewards -> update policy -> repeat.
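
Schematically, one episode of that loop looks like the sketch below. The env and generate objects are hypothetical placeholders for a game environment and a single-turn LLM call, not the actual GRL or Tunix APIs; the real orchestration lives in GRL's agent and environment configs and in Tunix.

    def rollout_episode(env, generate, max_turns=5):
        """One multi-turn episode: observe, reason and act in language, receive
        a verifiable reward, repeat; the logged trajectory later becomes
        token-level advantages for PPO."""
        trajectory = []
        observation = env.reset()                    # e.g. the current Sokoban grid
        for _ in range(max_turns):
            prompt = env.render_prompt(observation)  # chat-template the state
            action_text = generate(prompt)           # one single-turn LLM call
            observation, reward, done = env.step(action_text)
            trajectory.append((prompt, action_text, reward))
            if done:
                break
        return trajectory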

In our preliminary experiments using this setup, training Qwen2.5-7B-Instruct on Sokoban and Tetris yielded strong in-domain gains (+2-56% across game variants). We also observed modest generalization to out-of-domain tasks, with consistent improvements in planning tasks (Blocksworld: +3-7%) and positive but unstable signals in computer use (Webshop: ~+6%). All scripts and configs are available in the GRL repo: https://github.com/lmgame-org/GRL/tree/main. To reproduce the end-to-end Tunix + GRL training example (including our Sokoban/Tetris runs), you can simply clone the repo and run one line: bash tunix_quick_training_example.sh.

Google TRC & TPUs: Accelerating Game-Based RL at Scale

A critical component of our research was the Google TPU Research Cloud (TRC) program. Access to Cloud TPUs allowed us to move from small-scale prototypes to production-grade training runs with minimal friction.

TPUs and JAX directly attacked our biggest bottlenecks:

  1. Rollout Throughput: Using the vLLM-TPU path via tpu-inference, we could serve multiple model families on the same TPU v5p backend. This boosted sampling throughput, making the data-collection loop tighter and multi-environment concurrency cheaper.
  2. Multi-Host Scale for 7B Models: Tunix's lightweight design combined with JAX's mesh-based sharding allowed us to scale the same code from a single host to multi-host setups declaratively. This capability was essential for our experiments with 7B-parameter models (such as Qwen2.5-7B), where we leveraged 2 v5p-8 hosts with minimal code change (in fact, only an env var config). The scale-up is seamless, proving that the infrastructure can handle the heavy computational lifting required for modern LLM post-training without requiring complex engineering overhauls.
  3. Hardware Advantage: At the hardware level, the performance gains were significant. Each TPU v5p chip delivers around 459 BF16 TFLOPs, compared to roughly 312 on an NVIDIA A100. This raw power, combined with the TRC program's support, meant that large-N studies—involving more seeds, longer horizons, and more environments—became routine experiments rather than "special ops" engineering challenges.

This combination of Tunix's flexible abstraction and TRC's massive compute resources allowed us to iterate quickly on ideas while benefiting from production-grade infrastructure.

Get Started

GRL and Tunix are open for the community to explore. You can reproduce our end-to-end training example (including the Sokoban/Tetris runs) by cloning the repo, following the installation instructions, and then running a single command:

bash tunix_quick_training_example.sh

ESCA: Grounding embodied agents with scene graphs — Accelerated by JAX

Monday, December 15, 2025


Introduction

Multi-Modal Language Models (MLLMs) are increasingly serving as the brain of general-purpose embodied agents, such as robots that navigate and act in the physical world. While MLLMs are making rapid progress, they often stumble on a critical hurdle: precise visual perception. They struggle to reliably capture the fine-grained links between low-level visual features and high-level textual semantics.

Today, we are highlighting the work of Prof. Mayur Naik's research team at the University of Pennsylvania. To bridge the gap between high-level language and low-level visual features, they developed ESCA (Embodied and Scene-Graph Contextualized Agent). By porting their neurosymbolic pipeline to JAX, they achieved the real-time performance necessary for high-throughput decision-making. This work also demonstrates that JAX drives performance gains across a wide range of hardware, including standard CPUs and NVIDIA GPUs, and not just on Google TPUs.

In this blog, the UPenn team explains how they combined structured scene graphs with JAX's functional design to reduce perception errors by over 50% and achieve a 25% speedup in inference.


The "Grounding" Problem in Embodied AI

Existing MLLMs are powerful, but they can be surprisingly "blind" when tasked with interacting with the physical world. In our empirical analysis of 60 navigation tasks from EmbodiedBench, we found that 69% of agent failures stemmed from perception errors. See the figure below.

The three top-level error types are Perception, Reasoning, and Planning; the second-level errors are Hallucination, Wrong Recognition, Spatial Understanding, Spatial Reasoning, Reflection Error, Inaccurate Action, and Collision. The figure labels these error types with acronyms.

The models struggle to capture fine-grained links between visual features and textual semantics. They might recognize a "kitchen," but fail to identify the specific spatial relationship between a knife and a cutting board required to complete a task.

Enter ESCA: The Anglerfish of AI

To solve this, we introduced ESCA, a framework designed to contextualize MLLMs through open-domain scene graph generation.

Think of ESCA like the bioluminescent lure of a deep-sea anglerfish. Just as the fish illuminates its dark surroundings to reveal prey, ESCA "illuminates" the agent's environment by generating a structured Scene Graph—a map of objects, attributes, and relationships (e.g., Cup [Red] ON Table).

A key innovation here is Selective Grounding. Injecting a massive scene graph of everything in the room can overwhelm the model. Instead, ESCA identifies only the subset of objects and relations pertinent to the current instruction. It performs probabilistic reasoning to construct prompts enriched with exactly the contextual details the agent needs to act.
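
As a toy illustration of selective grounding (the real system performs probabilistic reasoning rather than this keyword filter), the snippet below prunes a scene graph to the facts that mention objects named in the instruction.

    scene_graph = [
        ("cup", "color", "red"),
        ("cup", "on", "table"),
        ("knife", "left_of", "cutting_board"),
        ("lamp", "on", "desk"),
    ]

    def ground(instruction, graph):
        # Keep only facts whose subject or object appears in the instruction.
        text = instruction.lower()
        return [(s, r, o) for s, r, o in graph
                if s.replace("_", " ") in text or o.replace("_", " ") in text]

    print(ground("Put the knife next to the cutting board", scene_graph))
    # -> [('knife', 'left_of', 'cutting_board')]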

The Engine: LASER and Scallop

At the core of ESCA is LASER, a CLIP-based foundation model trained on 87k video-caption pairs. LASER uses Scallop—our neurosymbolic programming language that supports JAX backends—to align predicted scene graphs with logical specifications. This pipeline allows us to train low-level perception models to produce detailed graphs without needing tedious frame-level annotations.

JAX User Experience

1. The Power of Statelessness

JAX's design encouraged a fully functional, stateless architecture. Every component, from feature extraction to similarity computation, was made into a pure modular function. This structure enabled effective use of jit (Just-In-Time) compilation. The XLA compiler could fuse sequences—like normalization, matrix multiplication, and softmax—into fewer kernels, reducing intermediate buffers and lowering GPU overhead.
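
The pattern looks roughly like the toy function below (not LASER's actual code): a pure function that normalizes, multiplies, and applies softmax, which XLA can fuse aggressively once it is wrapped in jax.jit.

    import jax
    import jax.numpy as jnp

    @jax.jit
    def similarity_scores(image_feats, text_feats):
        # L2-normalize, take the dot-product similarity matrix, softmax over text.
        img = image_feats / jnp.linalg.norm(image_feats, axis=-1, keepdims=True)
        txt = text_feats / jnp.linalg.norm(text_feats, axis=-1, keepdims=True)
        return jax.nn.softmax(img @ txt.T, axis=-1)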

2. Handling Complex Control Flow

Our pipeline requires selecting the "top-k" most relevant objects from a probabilistic scene graph. This introduces complex control flow. JAX provided the primitives we needed to handle this efficiently:

  • We used jax.lax.cond to manage control flow inside the probabilistic graph.
  • We leveraged jax.nn and jax.numpy for all activation functions and batched math in a JIT-friendly way.
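
A condensed, hypothetical example of the pattern (not ESCA's code): pick the top-k scene-graph facts with jax.lax.top_k, and use jax.lax.cond to fall back to the single best fact when nothing clears a relevance threshold.

    from functools import partial
    import jax
    import jax.numpy as jnp

    @partial(jax.jit, static_argnames=("k",))
    def select_relevant(scores, facts, k=5, threshold=0.1):
        # scores: [num_facts]; facts: [num_facts, dim] embeddings of graph facts.
        top_scores, top_idx = jax.lax.top_k(scores, k)

        def keep_topk(_):
            # Keep the k best facts, zeroing any below the threshold.
            return facts[top_idx] * (top_scores > threshold)[:, None]

        def keep_best_only(_):
            # Nothing cleared the threshold: keep only the single best fact.
            out = jnp.zeros((k,) + facts.shape[1:], facts.dtype)
            return out.at[0].set(facts[top_idx[0]])

        return jax.lax.cond(jnp.any(top_scores > threshold),
                            keep_topk, keep_best_only, None)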

3. Debugging and Transparency

Migrating to JAX was also a learning experience. Tools like jax.debug.print() and jax.debug.callback() allowed us to inspect values inside jit-compiled functions, while jax.disable_jit() let us switch to eager execution and step through the program to see intermediate values.
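
A tiny example of both tools (the function and values are ours, purely illustrative): jax.debug.print emits values from inside a compiled function, and jax.disable_jit runs the same code eagerly so a debugger can step through it.

    import jax
    import jax.numpy as jnp

    @jax.jit
    def normalize(x):
        norm = jnp.linalg.norm(x, axis=-1, keepdims=True)
        jax.debug.print("min feature norm = {n}", n=norm.min())  # works under jit
        return x / norm

    with jax.disable_jit():          # eager mode: step through with a debugger
        normalize(jnp.ones((4, 8)))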

Furthermore, the transparency of the open-source system was impressive. Being able to read the annotated source code and see how Python functions trace into jaxpr (JAX expression) gave us deep insight into how to design inference logic that scales.

4. Seamless Integration with Flax

Flax NNX fits into our workflow perfectly. We used nnx.Module to structure the model and FrozenDict to keep parameters organized and immutable. The TrainState object made managing model parameters and optimizer states straightforward, without adding the complexity often found in other frameworks.
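
For readers new to NNX, a module in this style looks roughly like the toy example below (not LASER's model code).

    from flax import nnx
    import jax.numpy as jnp

    class SimilarityHead(nnx.Module):
        def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
            self.proj = nnx.Linear(din, dout, rngs=rngs)

        def __call__(self, x):
            return jnp.tanh(self.proj(x))

    model = SimilarityHead(512, 128, rngs=nnx.Rngs(0))
    features = model(jnp.ones((2, 512)))   # shape [2, 128]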

JAX Performance: A 25% Speedup

Embodied agents operate in a continuous loop: planning, acting, and updating their understanding of a dynamic world. High latency here is a dealbreaker. We ported LASER from PyTorch to JAX to improve real-time performance, rewriting our core similarity computations and feature pipelines as pure functions wrapped in jax.jit. The gains were substantial.

On an NVIDIA H100 GPU, the JAX port reduced the average time per frame from 18.15 ms (PyTorch) to 14.55 ms, a roughly 25% speedup in throughput.

Framework | Hardware | Avg Time Per Frame (ms) ↓ | FPS ↑
PyTorch   | H100 GPU | 18.15 ± 0.73              | 55.15 ± 2.31
JAX       | H100 GPU | 14.55 ± 0.64              | 68.82 ± 3.13

Conclusion

ESCA demonstrates that better data—structured, grounded scene graphs—can solve the perception bottleneck in Embodied AI. But it also demonstrates that better infrastructure is required to run these systems in the real world. JAX provided the speed, transparency, and modularity needed to turn our research into a real-time agent capable of reliable reasoning.

Acknowledgements

This research was made possible through support from a Google Research Award to the University of Pennsylvania and from the ARPA-H program on Safe and Explainable AI under award D24AC00253-00.

Get Started

You can explore the LASER code, the ESCA framework and documentation for JAX and Flax at:

Empowering app developers: Fine-tuning Gemma 3 for mobile with Tunix in Google Colab

Thursday, December 11, 2025

In the rapidly evolving world of AI models for mobile devices, a persistent challenge is how to bring state-of-the-art LLMs to smartphones without compromising privacy or requiring app developers to be machine learning engineers.

Today, we are excited to talk about how Cactus, a startup building a next-gen inference engine for mobile devices, fine-tunes the open-source Gemma 3 model. By leveraging Tunix, the LLM post-training library in the JAX ML ecosystem, they achieved this entirely on Google Colab's Free Tier.

The Challenge: Making Small Models "Expert"

For app developers, running Large Language Models (LLMs) in the cloud isn't always an option due to privacy concerns (like GDPR) and latency requirements. The solution lies in running models locally on the device. However, most smartphones globally lack specialized MPUs (Micro Processing Units), meaning developers need highly efficient, smaller models.

While compact models like Gemma (270M or 1B parameters) are incredibly efficient, they are often "generalists." To be useful for specific mobile applications—such as a medical imaging assistant or a legal document analyzer—they need to be fine-tuned to become domain experts.

The problem? Most app developers are not ML infrastructure experts. Setting up complex training pipelines, managing dependencies, and navigating steep learning curves creates too much friction.

The Solution: SFT via Tunix on Google Colab

To solve this, Cactus created a simplified, low-friction workflow: a Python script that uses Tunix's Supervised Fine-Tuning (SFT) APIs inside a Colab notebook.

1. The Engine: Tunix

Cactus utilized Tunix, Google's lightweight and modular LLM post-training library, which supports both SFT and leading RL algorithms, and executes natively on TPUs. Tunix strips away the complexity of heavy frameworks, offering a simplified path to Supervised Fine-Tuning (SFT).

2. The Access: Google Colab Free Tier

Accessibility was a key requirement. Instead of requiring developers to set up complex cloud billing and project IDs immediately, the workflow operates entirely within a Google Colab Notebook. By utilizing the free tier of Colab, developers can:

  • Load the Gemma 3 model.
  • Upload their specific dataset (e.g., medical data or customer service logs).
  • Run an SFT (Supervised Fine-Tuning) job using Tunix.
  • Export the weights for conversion.

3. The Deployment: Cactus

Once tuned, the model is converted into the Cactus graph format. This allows the now-specialized Gemma 3 model to be deployed directly into a Flutter or native mobile app with just a few lines of code, running efficiently on a wide range of smartphone hardware.

Why This Matters

"Our users are app developers, not ML engineers," explains Henry Ndubuaku, co-founder of Cactus. "They want to pick a model, upload data, and click 'tune.' By using Tunix and Colab, we can give them a 'clone-and-run' experience that removes the intimidation factor from fine-tuning."

This workflow represents the "lowest hanging fruit" in democratizing AI:

  • No complex local environment setup.
  • No upfront infrastructure costs.
  • A high-performance, JAX-native library (Tunix) to tune a leading OSS model (Gemma).

What's Next?

While the Colab notebook provides an immediate, accessible solution, Cactus is exploring a full GUI-based portal for fine-tuning and quantizing LLMs, backed by Google Cloud TPU compute. This would allow scalable training of larger models and even more seamless integration into the mobile development lifecycle.

Get Started

Ready to turn your mobile app into an AI powerhouse? Check out the Tunix SFT Notebook for Cactus and start fine-tuning Gemma 3 for your device today:

You can explore Tunix sample scripts, documentation and repo at:
