How JAX makes high-performance economics accessible | Google Open Source Blog

Tuesday, November 11, 2025

JAX is widely recognized for its power in training large-scale AI models, but its core design as a system for composable function transformations unlocks its potential across a much broader scientific landscape. We're seeing adoption in applications ranging from AI-driven protein engineering to solving high-order partial differential equations (PDEs). Today, we're excited to highlight another frontier where JAX is making a significant impact: computational economics, where it enables economists to model complex, real-world scenarios that shape national policy.
I recently spoke with economist John Stachurski, a co-founder of QuantEcon and an early advocate for open-source scientific computing. His story of collaborating with the Central Bank of Chile demonstrates how JAX makes high performance easy to achieve. John's journey shows how JAX's intuitive design and abstractions allow domain experts to solve scientific problems without needing to become parallel programming specialists. John tells the story in his own words.


A Tale of Two Implementations: The Central Bank of Chile's Challenge
Due to my work with QuantEcon, I was contacted by the Central Bank of Chile (CBC), which was facing a computational bottleneck with one of its core models. The bank's work is high-stakes: its role is to set monetary policy and to act as the lender of last resort during financial crises. Such crises are inherently non-linear, involving self-reinforcing cycles and feedback loops that make them challenging to model and assess.
To better prepare for such crises, the CBC began working on a model originally developed by Javier Bianchi, in which an economic shock worsens the balance sheets of domestic economic agents, reducing collateral and tightening credit constraints. This leads to further deterioration in balance sheets, which again tightens credit constraints, and so on. The result is a downward spiral. The ramifications can be large in a country such as Chile, where economic and political instability are historically linked.
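For readers who want the mechanism in symbols, here is a stylized collateral constraint of the kind used in models in the spirit of Bianchi's work. The exact specification in the CBC's model is not described here, so treat this form and its notation as an illustrative assumption. Here $b_{t+1}$ is next period's bond position (negative values are debt), $r$ is the world interest rate, $\kappa$ is the fraction of income that can be pledged as collateral, $y^T_t$ and $y^N_t$ are tradable and nontradable income, and $p^N_t$ is the relative price of nontradables.

```latex
\frac{b_{t+1}}{1+r} \;\ge\; -\kappa \left( y^T_t + p^N_t \, y^N_t \right)
```

A negative shock that lowers $p^N_t$ shrinks the value of collateral on the right-hand side, so the constraint tightens. Agents must borrow and spend less, which pushes $p^N_t$ down further and tightens the constraint again: the downward spiral in the paragraph above.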

The Problem:

The task of implementing this model was led by Carlos Rondon, a talented CBC economist. Carlos wrote the first version using a well-known proprietary package for mathematical modeling that economists have used extensively over the past few decades. The completed model took 12 hours to run (that is, to generate the prices and quantities implied by a fixed set of parameters) on a $10,000 mainframe with 356 CPUs and a terabyte of RAM. A 12-hour run time made it almost impossible to calibrate the model and run useful scenarios. A better solution had to be found.

Carlos and I agreed that the problem was rooted in the underlying software package. To avoid slow loops, all operations had to be vectorized so that they could be passed to precompiled binaries generated from Fortran libraries such as LAPACK. However, as users of these traditional vectorization-based environments will know, obtaining a given output array often requires generating many intermediate arrays. When these arrays are high-dimensional, the process is slow and extremely memory-intensive. Moreover, while some manual parallelization is possible, truly efficient parallelization is difficult to achieve.
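To make the intermediate-array problem concrete, here is a minimal NumPy sketch of the broadcasting style applied to a hypothetical one-step savings problem. The grids, the gross interest rate `R`, and the log utility are illustrative assumptions, not the CBC model. Every loop index becomes an array axis, so the code materializes full three-dimensional arrays before it can reduce them to the result it actually needs.

```python
import numpy as np

# Hypothetical grids: wealth today (w), income shock (z), wealth tomorrow (wp).
n_w, n_z = 200, 50
w_grid = np.linspace(0.1, 10.0, n_w)
z_grid = np.linspace(0.5, 1.5, n_z)
R = 1.03  # assumed gross interest rate

def utility(c):
    # Log utility; infeasible (non-positive) consumption gets -inf.
    return np.where(c > 0, np.log(np.maximum(c, 1e-12)), -np.inf)

# Broadcasting turns each loop index into an array axis.
w = w_grid[:, None, None]    # shape (n_w, 1, 1)
z = z_grid[None, :, None]    # shape (1, n_z, 1)
wp = w_grid[None, None, :]   # shape (1, 1, n_w)

c = R * w + z - wp           # intermediate array, shape (n_w, n_z, n_w)
u = utility(c)               # another intermediate of the same shape
best = u.max(axis=2)         # only now reduced to shape (n_w, n_z)
```

Here `c` alone holds 200 × 50 × 200 = 2,000,000 floats, and each additional state variable multiplies that count by the size of its grid, which is why memory quickly becomes the bottleneck.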

The JAX Solution:

I flew to Santiago and we began a complete rewrite in JAX. Working side by side, we soon found that JAX was exactly the right tool for the task. In only two days we reimplemented the model and, running it on a consumer-grade GPU, observed a dramatic improvement in wall-clock time. The algorithm was unchanged, but even a cheap GPU outperformed the industrial server by a factor of a thousand. Now the model was fully operational: fast, clean, and ready for calibration.
There were several factors behind the project's success. First, JAX's elegant functional style allowed us to express the economic model's logic in a way that closely mirrored the underlying mathematics. Second, we fully exploited JAX's vmap, layering it to represent nested for loops. This allowed us to work with functions that operate on scalar values (think of a function that performs the calculations inside a nested for loop), rather than attempting to operate directly on high-dimensional arrays, a process that is inherently error-prone and difficult to visualize.
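Here is a minimal sketch of the layered-vmap idea, using a hypothetical savings problem (the grids, the interest rate `R`, and the log utility are illustrative assumptions, not the CBC code). The function `payoff` is the scalar body of a triple loop over today's wealth `w`, the shock `z`, and tomorrow's wealth `wp`. Each `jax.vmap` layer replaces one loop, and `in_axes` states which argument that loop runs over.

```python
import jax
import jax.numpy as jnp

R = 1.03                                # assumed gross interest rate
w_grid = jnp.linspace(0.1, 10.0, 200)   # hypothetical wealth grid
z_grid = jnp.linspace(0.5, 1.5, 50)     # hypothetical income-shock grid

def payoff(w, z, wp):
    # Body of the innermost loop: w, z and wp are all scalars here.
    c = R * w + z - wp
    return jnp.where(c > 0, jnp.log(jnp.maximum(c, 1e-12)), -jnp.inf)

# One vmap per loop, from the innermost outward.
# in_axes marks the mapped argument (0) and the fixed ones (None).
over_wp = jax.vmap(payoff, in_axes=(None, None, 0))          # loop over wp
over_z_wp = jax.vmap(over_wp, in_axes=(None, 0, None))       # loop over z
over_w_z_wp = jax.vmap(over_z_wp, in_axes=(0, None, None))   # loop over w

vals = over_w_z_wp(w_grid, z_grid, w_grid)   # shape (200, 50, 200)
best = vals.max(axis=2)                      # shape (200, 50)
```

Entry `vals[i, j, k]` equals `payoff(w_grid[i], z_grid[j], w_grid[k])`, exactly what the nested loop would compute, yet the only function written by hand works on scalars.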

Third, JAX automates parallelization and does it extremely efficiently. We both had experience with manual parallelization before using JAX, and I even fancied I was good at it. But at the end of the day, most of our expertise is in economics and mathematics, not computer science. Once we handed parallelization over to XLA, the OpenXLA compiler that powers JAX, we saw a massive speedup. Of course, the fact that XLA generates specialized GPU kernels on the fly was a key part of our success.
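A minimal sketch of handing the whole computation to the compiler, again using a hypothetical savings problem (`R`, the grids, and the log utility are illustrative assumptions). Wrapping the nested-vmap function in `jax.jit` lets XLA compile it for whatever backend is available, and XLA is free to fuse the elementwise arithmetic with the final reduction. The first call traces and compiles; later calls with the same shapes and dtypes reuse the compiled code. Because JAX dispatches work asynchronously, `block_until_ready()` is called before reading the timer.

```python
import time

import jax
import jax.numpy as jnp

R = 1.03  # assumed gross interest rate (illustrative)

def payoff(w, z, wp):
    # Scalar body of the innermost loop.
    c = R * w + z - wp
    return jnp.where(c > 0, jnp.log(jnp.maximum(c, 1e-12)), -jnp.inf)

@jax.jit
def best_payoff(w_grid, z_grid):
    # Three vmap layers replace three nested loops; jit compiles the lot.
    f = jax.vmap(payoff, in_axes=(None, None, 0))
    f = jax.vmap(f, in_axes=(None, 0, None))
    f = jax.vmap(f, in_axes=(0, None, None))
    return f(w_grid, z_grid, w_grid).max(axis=2)

w_grid = jnp.linspace(0.1, 10.0, 200)
z_grid = jnp.linspace(0.5, 1.5, 50)

best = best_payoff(w_grid, z_grid).block_until_ready()  # trace, compile, run

start = time.perf_counter()
best = best_payoff(w_grid, z_grid).block_until_ready()  # compiled code only
elapsed = time.perf_counter() - start

print(jax.devices())     # CPU devices on a laptop, GPU devices on a GPU machine
print(f"{elapsed:.6f} s")
```

Nothing in this script names a device, so the same code runs unchanged on a laptop CPU or on a GPU, which is the portability John highlights.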
I have to stress how much I enjoyed completing this project with JAX. First, we could write code on a laptop and then run exactly the same code on any GPU, without changing a single line. Second, for scientific computing, the pairing of an interpreted language like Python with a powerful JIT compiler provides the ideal combination of interactivity and speed. To my mind, everything about the JAX framework and compilers is just right. A functional programming style makes perfect sense in a world where functions are individually JIT-compiled. Once we adopt this paradigm, everything becomes cleaner. Throw in automatic differentiation and NumPy API compatibility and you have a close-to-perfect environment for writing high performance code for economic modeling.
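As a small illustration of automatic differentiation combined with the NumPy-style API, the sketch below differentiates a CRRA utility function (a standard textbook choice, used here as an assumption rather than taken from the CBC model) to obtain marginal utility, then vectorizes and compiles it. Analytically, the derivative of (c^(1-γ) - 1)/(1-γ) with respect to c is c^(-γ), so the result can be checked by hand.

```python
import jax
import jax.numpy as jnp

def crra_utility(c, gamma):
    # CRRA utility written with jax.numpy, which mirrors the NumPy API.
    # gamma = 1 would need the log-utility limit, which is not handled here.
    return (c ** (1.0 - gamma) - 1.0) / (1.0 - gamma)

# jax.grad differentiates with respect to the first argument (c) by default.
marginal_utility = jax.grad(crra_utility)

# Vectorize over a consumption grid (gamma held fixed), then compile.
mu = jax.jit(jax.vmap(marginal_utility, in_axes=(0, None)))

c_grid = jnp.linspace(0.5, 2.0, 4)
gamma = 2.0
mu_vals = mu(c_grid, gamma)
print(mu_vals)  # should agree with c_grid ** -gamma up to float32 rounding
```

The three transformations (`grad`, `vmap`, `jit`) compose freely, which is what makes the functional style pay off in practice.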


Unlocking the Next Generation of Economic Models

John's story captures the essence of JAX's power. By making high performance accessible to researchers, JAX is not just accelerating existing workloads; it's democratizing access to performance and enabling entirely new avenues of research.
As economists build models that incorporate more realistic heterogeneity (such as varying wealth levels, firm sizes, ages, and education), JAX enables them to take full advantage of modern accelerators like GPUs and Google TPUs. JAX's strengths in both scientific computing and deep learning make it an ideal foundation for bridging the gap between these richer models and the hardware needed to solve them.

Explore the JAX Scientific Computing Ecosystem

Stories like John's highlight a growing trend: JAX is much more than a framework for building the largest machine learning models on the planet. It is a powerful, general-purpose framework for array-based computing across all sciences which, together with accelerators such as Google TPUs and GPUs, is empowering a new generation of scientific discovery. The JAX team at Google is committed to supporting and growing this vibrant ecosystem, and that starts with hearing directly from you.

  • Share your story: Are you using JAX to tackle a challenging scientific problem? We would love to learn how JAX is accelerating your research.
  • Help guide our roadmap: Are there new features or capabilities that would unlock your next breakthrough? Your feature requests are essential for guiding the evolution of JAX.

Please reach out to the team via GitHub to share your work or discuss what you need from JAX. You can also find documentation, examples, news, events, and more at jaxstack.ai and jax.dev.

Sincere thanks to John Stachurski for sharing his insightful journey with us. We're excited to see how he and other researchers continue to leverage JAX to solve the world's most complex scientific problems.
