This integration bridges Keras's intuitive, high-level abstractions and JAX's high-performance backend through the Flax NNX module system. It simplifies the path from rapid prototyping to custom, production-grade training.
Deep Dive
Keras provides user-friendliness and accessibility, while JAX offers high-performance numerical computation.
We can now leverage the strengths of both, especially for detailed state management and advanced training capabilities within the JAX ecosystem.
This is all thanks to a new integration of Keras with Flax NNX. Hi there, my name is Yufeng, and today we're going to look at how to use Keras with the Flax NNX module system, and how this integration improves variable handling and training control. Keras is highly valued for its high-level, intuitive API, which makes deep learning development straightforward.
JAX is excellent for high-performance machine learning research and scalability. Flax NNX is a modular neural network library built on top of JAX and designed for simplicity and power.
It promotes ease of use through standard Python classes for modules and offers explicit state management via typed variable collections.
The integration of Keras with NNX allows you to use the modularity of Keras for model construction, while benefiting from the power and explicit control of NNX and JAX for variable management and advanced training loops. To activate this feature, we must first set two environment variables before importing Keras.
This enables NNX as an opt-in feature.
We set the backend to JAX with KERAS_BACKEND, and then explicitly enable NNX mode by setting the KERAS_NNX_ENABLED environment variable to true.
The core of this integration lies in keras.Variable, which is designed so that every Keras variable is also an instance of nnx.Variable from the Flax NNX ecosystem. This means you can freely mix Keras and NNX components, and NNX's state management tools will track your Keras variables.

Here's an example of what that looks like. We've got an NNX module with a linear layer we've called linear, and a vector value I've named nnx_variable, so we can easily keep track of things. I've also added a Keras variable to the model called custom_variable. In the __call__ method, we add nnx_variable and custom_variable to the result of the linear transform.
Once we have the model instantiated, there are a few tests we can run to verify that custom_variable is set up just as we'd expect.
First, we check that the Keras variable has what's called a trace state, the NNX bookkeeping that every nnx.Variable carries. It shows that NNX is managing the variable like one of its own, so it can take part in transforms such as just-in-time compilation along with the rest of the model.
We can do this by confirming that it has the _trace_state attribute. Second, we want to make sure that NNX counts this variable among all the variables it's aware of, using nnx.variables(model).
This shows all the variables that NNX is tracking, and indeed, we do see that our custom variable is listed.
Third and finally, let's confirm that the variable's value can be accessed directly by NNX. Even though it's a Keras variable inside an NNX model, the NNX model has no problem fetching its value.

Hopefully, I've convinced you by now that keras.Variable is successfully integrated with NNX, allowing Keras state and NNX state to coexist seamlessly. This integration provides two powerful training workflows. The first one feels just like classic Keras: it runs NNX modules inside Keras and lets Keras manage the training workflow.
The other approach uses NNX to run the training workflow with Keras models inside of NNX training loops.
In this first approach, your existing high-level Keras code, including model.compile and model.fit, works out of the box, and under the hood this productive experience is powered by JAX and NNX.
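As a sketch of "works out of the box", the block below is ordinary Keras code run with NNX mode enabled. It deliberately uses only Keras layers; embedding an NNX module inside a Keras model is not shown here. The toy regression data (y = 3x - 1 plus noise), the model size and the optimizer settings are all hypothetical.

```python
import os

os.environ["KERAS_BACKEND"] = "jax"
os.environ["KERAS_NNX_ENABLED"] = "true"

import keras
import numpy as np

# Hypothetical toy data: y = 3x - 1 with a little noise.
rng = np.random.default_rng(0)
x = rng.normal(size=(256, 1)).astype("float32")
y = (3.0 * x - 1.0 + 0.05 * rng.normal(size=(256, 1))).astype("float32")

# Plain Keras: nothing below is NNX-specific.
model = keras.Sequential([keras.Input(shape=(1,)), keras.layers.Dense(1)])
model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.1), loss="mse")
history = model.fit(x, y, epochs=20, batch_size=32, verbose=0)

print(history.history["loss"][0], history.history["loss"][-1])
```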
Here we have the other path. For maximum flexibility and fine-grained control, you can treat any Keras model or layer as an nnx.Module.
This allows you to write your own custom training loop using JAX libraries, such as Optax for optimizers, while mixing and matching the model's components, just as we saw in our very first example with custom_variable.
You can think of this as Keras inside of NNX. Here we see an example of a Keras model with a couple of dense layers. Once that model is created, the rest of the workflow is entirely NNX and JAX code: we select an optimizer and write a custom train step that computes the loss and the gradients and updates the model weights. Notice that we're using the decorator nnx.jit instead of jax.jit. Unlike jax.jit, nnx.jit accepts NNX modules as arguments and carries their state updates through the compiled function, so our Keras model gets the speed of just-in-time compilation, too.
In short, a Keras model object gets to do everything that an NNX model can do.
It can be passed to nnx.Optimizer, differentiated using nnx.grad, and used with the broader JAX ecosystem of libraries.

The Keras NNX integration offers a significant step forward, providing a unified framework for both rapid prototyping and high-performance, customizable research. You can leverage the entire JAX ecosystem, including nnx.jit and libraries like Optax, while still using familiar Keras APIs like model.fit and model.save. The code shown today was adapted from a complete guide on the keras.io website, so if you're ready to try Keras with NNX for yourself, definitely go through that guide first.
It's the perfect starting point to get hands-on. So, what are you going to be building with Keras along with the NNX backend? Share your thoughts in the comments below. Remember, for the complete guide and code examples, hit up that link in the description.
Thanks for watching, and I'll catch you in the next one.