Chex is a library of utilities for helping to write reliable JAX code.
It includes utilities to help:
- Instrument your code (e.g. assertions).
- Debug (e.g. transforming pmaps into vmaps within a context manager).
- Test JAX code across many variants (e.g. jitted vs non-jitted).
An overview of the modules can be found on GitHub.
Chex can be installed with pip directly from GitHub, with the following command:
pip install git+https://github.com/deepmind/chex.git
or from PyPI:
pip install chex
This repository is part of the DeepMind JAX Ecosystem.
To cite Chex, please use the DeepMind JAX Ecosystem citation.