Chex is a library of utilities for helping to write reliable JAX code.

This includes utils to help:

  • Instrument your code (e.g. assertions).

  • Debug (e.g. transforming pmaps into vmaps within a context manager).

  • Test JAX code across many variants (e.g. jitted vs non-jitted).

An overview of the modules can be found on GitHub.


Chex can be installed with pip directly from GitHub using the following command:

pip install git+git://

or from PyPI:

pip install chex

Citing Chex

This repository is part of the DeepMind JAX Ecosystem.

To cite Chex, please use the DeepMind JAX Ecosystem citation.



If you are having issues, please let us know by filing an issue on our issue tracker.


Chex is licensed under the Apache 2.0 License.

Indices and Tables