Assertions

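chex.assert_axis_dimension – Checks that tensor.shape[axis] == expected.
chex.assert_axis_dimension_comparator – Asserts that pass_fn(tensor.shape[axis]) passes.
chex.assert_axis_dimension_gt – Checks that tensor.shape[axis] > val.
chex.assert_axis_dimension_gteq – Checks that tensor.shape[axis] >= val.
chex.assert_axis_dimension_lt – Checks that tensor.shape[axis] < val.
chex.assert_axis_dimension_lteq – Checks that tensor.shape[axis] <= val.
chex.assert_devices_available – Checks that n devices of a given type are available.
chex.assert_equal – Checks that the two objects are equal as determined by the == operator.
chex.assert_equal_rank – Checks that all arrays have the same rank.
chex.assert_equal_size – Checks that all arrays have the same size.
chex.assert_equal_shape – Checks that all arrays have the same shape.
chex.assert_equal_shape_prefix – Checks that the leading prefix_len dims of all inputs have the same shape.
chex.assert_equal_shape_suffix – Checks that the final suffix_len dims of all inputs have the same shape.
chex.assert_exactly_one_is_none – Checks that one and only one of the arguments is None.
chex.assert_gpu_available – Checks that at least one GPU device is available.
chex.assert_is_broadcastable – Checks that an array of shape_a is broadcastable to one of shape_b.
chex.assert_is_divisible – Checks that numerator is divisible by denominator.
chex.assert_max_traces – Checks that a function is traced at most n times (inclusively).
chex.assert_not_both_none – Checks that at least one of the arguments is not None.
chex.assert_numerical_grads – Checks that autodiff and numerical gradients of a function match.
chex.assert_rank – Checks that the rank of all inputs matches specified expected_ranks.
chex.assert_scalar – Checks that x is a scalar, as defined in pytypes.py (int or float).
chex.assert_scalar_in – Checks that argument is a scalar within segment (by default).
chex.assert_scalar_negative – Checks that a scalar is negative.
chex.assert_scalar_non_negative – Checks that a scalar is non-negative.
chex.assert_scalar_positive – Checks that a scalar is positive.
chex.assert_size – Checks that the size of all inputs matches specified expected_sizes.
chex.assert_shape – Checks that the shape of all inputs matches specified expected_shapes.
chex.assert_tpu_available – Checks that at least one TPU device is available.
chex.assert_tree_all_finite – Checks that all leaves in a tree are finite.
chex.assert_tree_has_only_ndarrays – Checks that all tree's leaves are n-dimensional arrays (tensors).
chex.assert_tree_is_on_device – Checks that all leaves are ndarrays residing in device memory (in HBM).
chex.assert_tree_is_on_host – Checks that all leaves are ndarrays residing in the host memory (on CPU).
chex.assert_tree_is_sharded – Checks that all leaves are ndarrays sharded across the specified devices.
chex.assert_tree_no_nones – Checks that a tree does not contain None.
chex.assert_tree_shape_prefix – Checks that all tree leaves' shapes have the same prefix.
chex.assert_tree_shape_suffix – Checks that all tree leaves' shapes have the same suffix.
chex.assert_trees_all_close – Checks that all trees have leaves with approximately equal values.
chex.assert_trees_all_close_ulp – Checks that tree leaves differ by at most maxulp Units in the Last Place.
chex.assert_trees_all_equal – Checks that all trees have leaves with exactly equal values.
chex.assert_trees_all_equal_comparator – Checks that all trees are equal as per the custom comparator for leaves.
chex.assert_trees_all_equal_dtypes – Checks that trees' leaves have the same dtype.
chex.assert_trees_all_equal_sizes – Checks that trees have the same structure and leaves' sizes.
chex.assert_trees_all_equal_shapes – Checks that trees have the same structure and leaves' shapes.
chex.assert_trees_all_equal_shapes_and_dtypes – Checks that trees' leaves have the same shape and dtype.
chex.assert_trees_all_equal_structs – Checks that trees have the same structure.
chex.assert_type – Checks that the type of all inputs matches specified expected_types.
chex.chexify – Wraps a transformed function fn to enable Chex value assertions.
chex.ChexifyChecks – A set of checks imported from checkify.
chex.with_jittable_assertions – An alias for chexify (see the docs).
chex.block_until_chexify_assertions_complete – Waits until all asynchronous checks complete.
chex.Dimensions – A lightweight utility that maps strings to shape tuples.
chex.disable_asserts – Disables all Chex assertions.
chex.enable_asserts – Enables Chex assertions.
chex.clear_trace_counter – Clears Chex traces' counter for assert_max_traces checks.
chex.if_args_not_none – Wraps a Chex assertion so that it is only evaluated if the positional args are not None.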
Jax Assertions
 chex.assert_max_traces(fn=None, n=None)[source]
Checks that a function is traced at most n times (inclusively).
JAX retraces jitted functions every time the structure of passed arguments changes. Often this behaviour is inadvertent and leads to a significant performance drop which is hard to debug. This wrapper checks that the function is retraced at most n times during program execution.
Examples:
    @jax.jit
    @chex.assert_max_traces(n=1)
    def fn_sum_jitted(x, y):
      return x + y

    def fn_sub(x, y):
      return x - y

    fn_sub_pmapped = jax.pmap(chex.assert_max_traces(fn_sub, n=10))
 More about tracing:
https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html
 Parameters
fn (Optional[Union[Callable[..., Any], int]]) – A pure Python function to wrap (i.e. it must not be a jitted function).
n (Optional[Union[Callable[..., Any], int]]) – The maximum allowed number of retraces (non-negative).
 Returns
Decorated function that raises an exception when it is retraced for the (n+1)-th time.
 Raises
ValueError – If fn has already been jitted.
 chex.assert_devices_available(n, devtype, backend=None, not_less_than=False)[source]
Checks that n devices of a given type are available.
 Parameters
n (int) – A required number of devices of the given type.
devtype (str) – A type of devices, one of {'cpu', 'gpu', 'tpu'}.
backend (Optional[str]) – A type of backend to use (uses the JAX default if not provided).
not_less_than (bool) – Whether to check that the number of devices is not less than n, instead of requiring an exact match.
 Raises
AssertionError – If the number of available devices of the given type is not equal to n (or is less than n when not_less_than is True).
 Return type
None
Value (Runtime) Assertions
 chex.chexify(fn, async_check=True, errors=frozenset({<class 'jax._src.checkify.FailedCheckError'>}))[source]
Wraps a transformed function fn to enable Chex value assertions.
Chex value/runtime assertions access concrete values of tensors (e.g. assert_tree_all_finite) which are not available during JAX tracing, see https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html and https://jax.readthedocs.io/en/latest/_modules/jax/_src/errors.html#ConcretizationTypeError.
This wrapper enables them in jitted/pmapped functions by performing a specifically designed JAX transformation (https://jax.readthedocs.io/en/latest/debugging/checkify_guide.html#the-checkify-transformation) and calling functionalised checks (https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.checkify.check.html).
Example:
    @chex.chexify
    @jax.jit
    def logp1_abs_safe(x: chex.Array) -> chex.Array:
      chex.assert_tree_all_finite(x)
      return jnp.log(jnp.abs(x) + 1)

    logp1_abs_safe(jnp.ones(2))  # OK
    logp1_abs_safe(jnp.array([jnp.nan, 3]))  # FAILS
    logp1_abs_safe.wait_checks()
Note 1: This wrapper allows identifying the first failed assertion in a jitted code by printing a pointer to the line where the failed assertion was invoked. For getting verbose messages (including concrete tensor values), an unjitted version of the code will need to be executed with the same input values. Chex does not currently provide tools to help with this.
Note 2: This wrapper fully supports asynchronous executions (see https://jax.readthedocs.io/en/latest/async_dispatch.html). To block program execution until asynchronous checks for a _chexified_ function fn complete, call fn.wait_checks(). Similarly, chex.block_until_chexify_assertions_complete() will block program execution until _all_ asynchronous checks complete.
Note 3: Chex automatically selects the backend for executing its assertions (i.e. CPU or device accelerator) depending on the program context.
Note 4: Value assertions can have impact on the performance of a function, see https://jax.readthedocs.io/en/latest/debugging/checkify_guide.html#limitations
Note 5: static assertions, such as assert_shape or assert_trees_all_equal_dtypes, can be called from a jitted function without chexify wrapper (since they do not access concrete values, only shapes and/or dtypes which are available during JAX tracing).
More examples can be found at https://github.com/deepmind/chex/blob/master/chex/_src/asserts_chexify_test.py
 Parameters
fn (Callable[..., Any]) – A transformed function to wrap.
async_check (bool) – Whether to check errors in the async dispatch mode. See https://jax.readthedocs.io/en/latest/async_dispatch.html.
errors (FrozenSet[checkify.ErrorCategory]) – A set of checkify.ErrorCategory values which defines the set of enabled checks. By default only explicit checks are enabled (user). You can also, for example, enable NaN and div-by-0 errors by passing the float set, or combine multiple sets through set operations (float | user).
 Return type
Callable[…, Any]
 Returns
A _chexified_ function, i.e. the one with enabled value assertions. The returned function has wait_checks() method that blocks the caller until all pending async checks complete.

 class chex.ChexifyChecks
A set of checks imported from checkify.
Tree Assertions
 chex.assert_tree_all_finite(tree_like)[source]
Checks that all leaves in a tree are finite.
 Parameters
tree_like (ArrayTree) – A pytree with array leaves.
 Raises
AssertionError – If any leaf in tree_like is non-finite.
 Return type
None
 chex.assert_tree_has_only_ndarrays(tree)[source]
Checks that all tree’s leaves are n-dimensional arrays (tensors).
 Parameters
tree (ArrayTree) – A tree to assert.
 Raises
AssertionError – If the tree contains an object which is not an ndarray.
 Return type
None
 chex.assert_tree_is_on_device(tree, *, platform=('gpu', 'tpu'), device=None)[source]
Checks that all leaves are ndarrays residing in device memory (in HBM).
Sharded DeviceArrays are disallowed.
 Parameters
tree (ArrayTree) – A tree to assert.
platform (Union[Sequence[str], str]) – A platform or a list of platforms where the leaves are expected to reside. Ignored if device is specified.
device (Optional[pytypes.Device]) – An optional device where the tree’s arrays are expected to reside. Any device (except CPU) is accepted if not specified.
 Raises
AssertionError – If the tree contains a leaf that is not an ndarray or does not reside on the specified device or platform.
 Return type
None
 chex.assert_tree_is_on_host(tree, *, allow_cpu_device=True, allow_sharded_arrays=False)[source]
Checks that all leaves are ndarrays residing in the host memory (on CPU).
This assertion only accepts trees consisting of ndarrays.
 Parameters
tree (ArrayTree) – A tree to assert.
allow_cpu_device (bool) – Whether to allow JAX arrays that reside on a CPU device.
allow_sharded_arrays (bool) – Whether to allow sharded JAX arrays. Sharded arrays are considered “on host” only if they are sharded across CPU devices and allow_cpu_device is True.
 Raises
AssertionError – If the tree contains a leaf that is not an ndarray or does not reside on host.
 Return type
None
 chex.assert_tree_is_sharded(tree, *, devices)[source]
Checks that all leaves are ndarrays sharded across the specified devices.
 Parameters
tree (ArrayTree) – A tree to assert.
devices (Sequence[pytypes.Device]) – A list of devices which the tree’s leaves are expected to be sharded across. This list is order-sensitive.
 Raises
AssertionError – If the tree contains a leaf that is not a device array sharded across the specified devices.
 Return type
None
 chex.assert_tree_no_nones(tree)[source]
Checks that a tree does not contain None.
 Parameters
tree (ArrayTree) – A tree to assert.
 Raises
AssertionError – If the tree contains at least one None.
 Return type
None
 chex.assert_tree_shape_prefix(tree, shape_prefix)[source]
Checks that all tree leaves’ shapes have the same prefix.
 Parameters
tree (ArrayTree) – A tree to check.
shape_prefix (Sequence[int]) – An expected shape prefix.
 Raises
AssertionError – If some leaf’s shape doesn’t start with shape_prefix.
 Return type
None
 chex.assert_tree_shape_suffix(tree, shape_suffix)[source]
Checks that all tree leaves’ shapes have the same suffix.
 Parameters
tree (ArrayTree) – A tree to check.
shape_suffix (Sequence[int]) – An expected shape suffix.
 Raises
AssertionError – If some leaf’s shape doesn’t end with shape_suffix.
 Return type
None
 chex.assert_trees_all_close(*trees, rtol=1e-06, atol=0.0)[source]
Checks that all trees have leaves with approximately equal values.
This compares the difference between values of actual and desired up to atol + rtol * abs(desired).
 Parameters
*trees – A sequence of (at least 2) trees with array leaves.
rtol (float) – A relative tolerance.
atol (float) – An absolute tolerance.
 Raises
AssertionError – If actual and desired values are not equal up to specified tolerance.
 Return type
None
 chex.assert_trees_all_close_ulp(*trees, maxulp=1)[source]
Checks that tree leaves differ by at most maxulp Units in the Last Place.
This is the Chex version of np.testing.assert_array_max_ulp.
Assertions on floating point values are tricky because the precision varies depending on the value. For example, with float32, the precision at 1 is np.spacing(np.float32(1.0)) ≈ 1e-7, but the precision at 5,000,000 is only np.spacing(np.float32(5e6)) = 0.5. This makes it hard to predict ahead of time what tolerance to use when checking whether two numbers are equal: a difference of only a couple of bits can equate to an arbitrarily large absolute difference.
Assertions based on _relative_ differences are one solution to this problem, but have the disadvantage that it’s hard to choose the tolerance. If you want to verify that two calculations produce _exactly_ the same result modulo the inherent non-determinism of floating point operations, do you set the tolerance to…0.01? 0.001? It’s hard to be sure you’ve set it low enough that you won’t miss one of your computations being slightly wrong.
Assertions based on ‘units in the last place’ (ULP) instead solve this problem by letting you specify tolerances in terms of the precision actually available at the current scale of your values. The ULP at some value x is essentially the spacing between the floating point numbers actually representable in the vicinity of x, equivalent to the ‘precision’ we discussed above. With a tolerance of, say, maxulp=5, you’re saying that two values are within 5 actually-representable numbers of each other: a strong guarantee that two computations are as close as possible to identical, while still allowing reasonable wiggle room for small differences due to e.g. different operator orderings.
Note that this function is not currently supported within JIT contexts, and does not currently support bfloat16 dtypes.
 Parameters
*trees – A sequence of (at least 2) trees with array leaves.
maxulp (int) – The maximum number of ULPs by which leaves may differ.
 Raises
AssertionError – If actual and desired values are not equal up to specified tolerance.
 Return type
None
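The spacing argument above can be checked directly with NumPy. The sketch below uses np.testing.assert_array_max_ulp, which this assertion is described as wrapping, rather than Chex itself, so it only illustrates the ULP idea; the variable names are invented for the example.

```python
import numpy as np

values = np.array([1.0, 5e6], dtype=np.float32)

# Distance to the next representable float32 at each value:
# roughly 1.19e-07 at 1.0 and exactly 0.5 at 5e6.
spacing = np.spacing(values)

inf32 = np.float32(np.inf)
one_ulp_up = np.nextafter(values, inf32)
three_ulps_up = np.nextafter(np.nextafter(one_ulp_up, inf32), inf32)

# Passes: each pair is one representable number apart, even though the
# absolute differences (about 1.19e-07 and 0.5) are wildly different.
np.testing.assert_array_max_ulp(values, one_ulp_up, maxulp=1)

# Fails: three representable numbers apart exceeds maxulp=1.
three_ulps_rejected = False
try:
    np.testing.assert_array_max_ulp(values, three_ulps_up, maxulp=1)
except AssertionError:
    three_ulps_rejected = True
```

A single maxulp value therefore expresses the same closeness at both magnitudes, which a fixed absolute tolerance cannot do.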
 chex.assert_trees_all_equal(*trees, strict=False)[source]
Checks that all trees have leaves with exactly equal values.
If you are comparing floating point numbers, an exact equality check may not be appropriate; consider using assert_trees_all_close.
 Parameters
*trees – A sequence of (at least 2) trees with array leaves.
strict (bool) – If True, disable special scalar handling as described in the notes section of np.testing.assert_array_equal.
 Raises
AssertionError – If the leaf values actual and desired are not exactly equal.
 Return type
None
 chex.assert_trees_all_equal_comparator(equality_comparator, error_msg_fn, *trees)[source]
Checks that all trees are equal as per the custom comparator for leaves.
 Parameters
equality_comparator (_ai.TLeavesEqCmpFn) – A custom function that accepts two leaves and checks whether they are equal. Expected to be transitive.
error_msg_fn (_ai.TLeavesEqCmpErrorFn) – A function that accepts two leaves that are unequal as per equality_comparator and returns an error message.
*trees – A sequence of (at least 2) trees to check for equality as per equality_comparator.
 Raises
ValueError – If trees does not contain at least 2 elements.
AssertionError – If equality_comparator returns False for any pair of trees from trees.
 Return type
None
 chex.assert_trees_all_equal_dtypes(*trees)[source]
Checks that trees’ leaves have the same dtype.
 Parameters
*trees – A sequence of (at least 2) trees to check.
 Raises
AssertionError – If leaves’ dtypes for any two trees differ.
 Return type
None
 chex.assert_trees_all_equal_sizes(*trees)[source]
Checks that trees have the same structure and leaves’ sizes.
 Parameters
*trees – A sequence of (at least 2) trees with array leaves.
 Raises
AssertionError – If trees’ structures or leaves’ sizes are different.
 Return type
None
 chex.assert_trees_all_equal_shapes(*trees)[source]
Checks that trees have the same structure and leaves’ shapes.
 Parameters
*trees – A sequence of (at least 2) trees with array leaves.
 Raises
AssertionError – If trees’ structures or leaves’ shapes are different.
 Return type
None
 chex.assert_trees_all_equal_shapes_and_dtypes(*trees)[source]
Checks that trees’ leaves have the same shape and dtype.
 Parameters
*trees – A sequence of (at least 2) trees to check.
 Raises
AssertionError – If leaves’ shapes or dtypes for any two trees differ.
 Return type
None
 chex.assert_trees_all_equal_structs(*trees)[source]
Checks that trees have the same structure.
 Parameters
*trees – A sequence of (at least 2) trees to assert equal structure between.
 Raises
ValueError – If trees does not contain at least 2 elements.
AssertionError – If structures of any two trees are different.
 Return type
None
Generic Assertions
 chex.assert_axis_dimension(tensor, axis, expected)[source]
Checks that tensor.shape[axis] == expected.
 Parameters
tensor (Array) – A JAX array.
axis (int) – An integer specifying which axis to assert.
expected (int) – An expected value of tensor.shape[axis].
 Raises
AssertionError – The dimension of the specified axis does not match the prescribed value.
 Return type
None
 chex.assert_axis_dimension_comparator(tensor, axis, pass_fn, error_string)[source]
Asserts that pass_fn(tensor.shape[axis]) passes.
Used to implement ==, >, >=, <, <= checks.
 Parameters
tensor (Array) – A JAX array.
axis (int) – An integer specifying which axis to assert.
pass_fn (Callable[[int], bool]) – A callable which takes the size of the given dimension and returns False when the assertion should fail.
error_string (str) – A string which is inserted in assertion failure messages: ‘expected tensor to have dimension {error_string} on axis …’.
 Raises
AssertionError – If pass_fn(tensor.shape[axis]) does not return True.
 chex.assert_axis_dimension_gt(tensor, axis, val)[source]
Checks that tensor.shape[axis] > val.
 Parameters
tensor (Array) – A JAX array.
axis (int) – An integer specifying which axis to assert.
val (int) – A value tensor.shape[axis] must be greater than.
 Raises
AssertionError – If the dimension of axis is <= val.
 Return type
None
 chex.assert_axis_dimension_gteq(tensor, axis, val)[source]
Checks that tensor.shape[axis] >= val.
 Parameters
tensor (Array) – A JAX array.
axis (int) – An integer specifying which axis to assert.
val (int) – A value tensor.shape[axis] must be greater than or equal to.
 Raises
AssertionError – If the dimension of axis is < val.
 Return type
None
 chex.assert_axis_dimension_lt(tensor, axis, val)[source]
Checks that tensor.shape[axis] < val.
 Parameters
tensor (Array) – A JAX array.
axis (int) – An integer specifying which axis to assert.
val (int) – A value tensor.shape[axis] must be less than.
 Raises
AssertionError – If the dimension of axis is >= val.
 Return type
None
 chex.assert_axis_dimension_lteq(tensor, axis, val)[source]
Checks that tensor.shape[axis] <= val.
 Parameters
tensor (Array) – A JAX array.
axis (int) – An integer specifying which axis to assert.
val (int) – A value tensor.shape[axis] must be less than or equal to.
 Raises
AssertionError – If the dimension of axis is > val.
 Return type
None
 chex.assert_equal(first, second)[source]
Checks that the two objects are equal as determined by the == operator.
Arrays with more than one element cannot be compared. Use assert_trees_all_close to compare arrays.
 Parameters
first (Any) – A first object.
second (Any) – A second object.
 Raises
AssertionError – If not (first == second).
 Return type
None
 chex.assert_equal_rank(inputs)[source]
Checks that all arrays have the same rank.
 Parameters
inputs (Sequence[Array]) – A collection of arrays.
 Raises
AssertionError – If the ranks of all arrays do not match.
ValueError – If inputs is not a collection of arrays.
 Return type
None
 chex.assert_equal_size(inputs)[source]
Checks that all arrays have the same size.
 Parameters
inputs (Sequence[Array]) – A collection of arrays.
 Raises
AssertionError – If the sizes of the arrays do not all match.
 Return type
None
 chex.assert_equal_shape(inputs, *, dims=None)[source]
Checks that all arrays have the same shape.
 Parameters
inputs (Sequence[Array]) – A collection of arrays.
dims (Optional[Union[int, Sequence[int]]]) – An optional integer or sequence of integers. If not provided, every dimension of every shape must match. If provided, equality of shape will only be asserted for the specified dim(s); e.g. to ensure that all arrays in a group have the same size in the first two dimensions, call assert_equal_shape(tensors_list, dims=(0, 1)).
 Raises
AssertionError – If the shapes of all arrays at specified dims do not match.
ValueError – If the provided dims are invalid indices into any of the arrays, or if inputs is not a collection of arrays.
 Return type
None
 chex.assert_equal_shape_prefix(inputs, prefix_len)[source]
Checks that the leading prefix_len dims of all inputs have the same shape.
 Parameters
inputs (Sequence[Array]) – A collection of input arrays.
prefix_len (int) – A number of leading dimensions to compare; each input’s shape will be sliced to shape[:prefix_len]. Negative values are accepted and have the conventional Python indexing semantics.
 Raises
AssertionError – If the shapes of all arrays do not match.
ValueError – If inputs is not a collection of arrays.
 Return type
None
 chex.assert_equal_shape_suffix(inputs, suffix_len)[source]
Checks that the final suffix_len dims of all inputs have the same shape.
 Parameters
inputs (Sequence[Array]) – A collection of input arrays.
suffix_len (int) – A number of trailing dimensions to compare; each input’s shape will be sliced to shape[-suffix_len:]. Negative values are accepted and have the conventional Python indexing semantics.
 Raises
AssertionError – If the shapes of all arrays do not match.
ValueError – If inputs is not a collection of arrays.
 Return type
None
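The prefix and suffix comparisons come down to ordinary tuple slicing. The sketch below reproduces that slicing in plain Python; it is not Chex's implementation, the helper names shape_prefix and shape_suffix are hypothetical, and the suffix slice is assumed to be shape[-suffix_len:].

```python
def shape_prefix(shape, prefix_len):
    # Leading dimensions, as in shape[:prefix_len].
    return tuple(shape[:prefix_len])


def shape_suffix(shape, suffix_len):
    # Trailing dimensions, assumed here to be shape[-suffix_len:].
    return tuple(shape[-suffix_len:])


shape_a = (8, 16, 3)
shape_b = (8, 16, 5)

same_prefix = shape_prefix(shape_a, 2) == shape_prefix(shape_b, 2)  # (8, 16) on both sides
same_suffix = shape_suffix(shape_a, 1) == shape_suffix(shape_b, 1)  # (3,) against (5,)

# A negative prefix length follows normal Python slicing: everything but the last axis.
all_but_last = shape_prefix(shape_a, -1)
```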
 chex.assert_exactly_one_is_none(first, second)[source]
Checks that one and only one of the arguments is None.
 Parameters
first (Any) – A first object.
second (Any) – A second object.
 Raises
AssertionError – If (first is None) xor (second is None) is False.
 Return type
None
 chex.assert_is_broadcastable(shape_a, shape_b)[source]
Checks that an array of shape_a is broadcastable to one of shape_b.
 Parameters
shape_a (Sequence[int]) – A shape of the array to check.
shape_b (Sequence[int]) – A target shape after broadcasting.
 Raises
AssertionError – If shape_a is not broadcastable to shape_b.
 Return type
None
 chex.assert_is_divisible(numerator, denominator)[source]
Checks that numerator is divisible by denominator.
 Parameters
numerator (int) – A numerator.
denominator (int) – A denominator.
 Raises
AssertionError – If numerator is not divisible by denominator.
 Return type
None
 chex.assert_not_both_none(first, second)[source]
Checks that at least one of the arguments is not None.
 Parameters
first (Any) – A first object.
second (Any) – A second object.
 Raises
AssertionError – If (first is None) and (second is None).
 Return type
None
 chex.assert_numerical_grads(f, f_args, order, atol=0.01, **check_kwargs)[source]
Checks that autodiff and numerical gradients of a function match.
 Parameters
f (Callable[..., Array]) – A function to check.
f_args (Sequence[Array]) – Arguments of the function.
order (int) – An order of gradients.
atol (float) – An absolute tolerance.
**check_kwargs – Kwargs for jax_test.check_grads.
 Raises
AssertionError – If automatic differentiation gradients deviate from finite difference gradients.
 Return type
None
 chex.assert_rank(inputs, expected_ranks)[source]
Checks that the rank of all inputs matches specified expected_ranks.
Valid usages include:
    assert_rank(x, 0)            # x is scalar
    assert_rank(x, 2)            # x is a rank-2 array
    assert_rank(x, {0, 2})       # x is scalar or rank-2 array
    assert_rank([x, y], 2)       # x and y are rank-2 arrays
    assert_rank([x, y], [0, 2])  # x is scalar and y is a rank-2 array
    assert_rank([x, y], {0, 2})  # x and y are scalar or rank-2 arrays
 Parameters
inputs (Union[Scalar, Union[Array, Sequence[Array]]]) – An array or a sequence of arrays.
expected_ranks (Union[int, Set[int], Sequence[Union[int, Set[int]]]]) – A sequence of expected ranks associated with each input, where the expected rank is either an integer or a set of integer options; if all inputs have the same rank, a single scalar or set of scalars may be passed as expected_ranks.
 Raises
AssertionError – If lengths of inputs and expected_ranks don’t match; if expected_ranks has wrong type; if the ranks of inputs do not match expected_ranks.
ValueError – If expected_ranks is not an integer and not a sequence of integers.
 Return type
None
 chex.assert_scalar(x)[source]
Checks that x is a scalar, as defined in pytypes.py (int or float).
 Parameters
x (Scalar) – An object to check.
 Raises
AssertionError – If x is not a scalar as per definition in pytypes.py.
 Return type
None
 chex.assert_scalar_in(x, min_, max_, included=True)[source]
Checks that argument is a scalar within segment (by default).
 Parameters
x (Any) – An object to check.
min_ – A left border of the segment.
max_ – A right border of the segment.
included (bool) – Whether to include the borders of the segment in the set of allowed values.
 Raises
AssertionError – If x is not a scalar; if x falls out of the segment.
 Return type
None
 chex.assert_scalar_negative(x)[source]
Checks that a scalar is negative.
 Parameters
x (Scalar) – A value to check.
 Raises
AssertionError – If x is not a scalar or is not strictly negative.
 Return type
None
 chex.assert_scalar_non_negative(x)[source]
Checks that a scalar is non-negative.
 Parameters
x (Scalar) – A value to check.
 Raises
AssertionError – If x is not a scalar or is negative.
 Return type
None
 chex.assert_scalar_positive(x)[source]
Checks that a scalar is positive.
 Parameters
x (Scalar) – A value to check.
 Raises
AssertionError – If x is not a scalar or is not strictly positive.
 Return type
None
 chex.assert_size(inputs, expected_sizes)[source]
Checks that the size of all inputs matches specified expected_sizes.
Valid usages include:
    assert_size(x, 1)                 # x is scalar (size 1)
    assert_size([x, y], (2, {1, 3}))  # x has size 2, y has size 1 or 3
    assert_size([x, y], (2, ...))     # x has size 2, y has any size
    assert_size([x, y], 1)            # x and y are scalar (size 1)
    assert_size((x, y), (5, 2))       # x has size 5, y has size 2
 Parameters
inputs (Union[Scalar, Union[Array, Sequence[Array]]]) – An array or a sequence of arrays.
expected_sizes (Union[_ai.TShapeMatcher, Sequence[_ai.TShapeMatcher]]) – A sequence of expected sizes associated with each input, where the expected size is a sequence of integer and None dimensions; if all inputs have the same size, a single size may be passed as expected_sizes.
 Raises
AssertionError – If the lengths of inputs and expected_sizes do not match; if expected_sizes has wrong type; if size of input does not match expected_sizes.
 Return type
None
 chex.assert_shape(inputs, expected_shapes)[source]
Checks that the shape of all inputs matches specified expected_shapes.
Valid usages include:
    assert_shape(x, ())                    # x is scalar
    assert_shape(x, (2, 3))                # x has shape (2, 3)
    assert_shape(x, (2, {1, 3}))           # x has shape (2, 1) or (2, 3)
    assert_shape(x, (2, None))             # x has rank 2 and `x.shape[0] == 2`
    assert_shape(x, (2, ...))              # x has rank >= 1 and `x.shape[0] == 2`
    assert_shape([x, y], ())               # x and y are scalar
    assert_shape([x, y], [(), (2,3)])      # x is scalar and y has shape (2, 3)
 Parameters
inputs (Union[Scalar, Union[Array, Sequence[Array]]]) – An array or a sequence of arrays.
expected_shapes (Union[_ai.TShapeMatcher, Sequence[_ai.TShapeMatcher]]) – A sequence of expected shapes associated with each input, where the expected shape is a sequence of integer and None dimensions; if all inputs have the same shape, a single shape may be passed as expected_shapes.
 Raises
AssertionError – If the lengths of inputs and expected_shapes do not match; if expected_shapes has wrong type; if shape of input does not match expected_shapes.
 Return type
None
 chex.assert_type(inputs, expected_types)[source]
Checks that the type of all inputs matches specified expected_types.
Valid usages include:
    assert_type(7, int)
    assert_type(7.1, float)
    assert_type(False, bool)
    assert_type([7, 8], int)
    assert_type([7, 7.1], [int, float])
    assert_type(np.array(7), int)
    assert_type(np.array(7.1), float)
    assert_type(jnp.array(7), int)
    assert_type([jnp.array([7, 8]), np.array(7.1)], [int, float])
 Parameters
inputs (Union[Scalar, Union[Array, Sequence[Array]]]) – An array or a sequence of arrays or scalars.
expected_types (Union[Type[Scalar], Sequence[Type[Scalar]]]) – A sequence of expected types associated with each input; if all inputs have the same type, a single type may be passed as expected_types.
 Raises
AssertionError – If lengths of inputs and expected_types don’t match; if expected_types contains unsupported pytype; if the types of inputs do not match the expected types.
 Return type
None
Shapes and Named Dimensions
 class chex.Dimensions(**dim_sizes)[source]
A lightweight utility that maps strings to shape tuples.
The most basic usage is:
    >>> dims = chex.Dimensions(B=3, T=5, N=7)  # You can specify any letters.
    >>> dims['NBT']
    (7, 3, 5)
This is useful when dealing with many differently shaped arrays. For instance, let’s check the shape of this array:
    >>> x = jnp.array([[2, 0, 5, 6, 3],
    ...                [5, 4, 4, 3, 3],
    ...                [0, 0, 5, 2, 0]])
    >>> chex.assert_shape(x, dims['BT'])
The dimension sizes can be accessed directly, e.g. dims.N == 7. This can be useful in many applications. For instance, let’s one-hot encode our array.
    >>> y = jax.nn.one_hot(x, dims.N)
    >>> chex.assert_shape(y, dims['BTN'])
You can also store the shape of a given array in dims, e.g.
    >>> z = jnp.array([[0, 6, 0, 2],
    ...                [4, 2, 2, 4]])
    >>> dims['XY'] = z.shape
    >>> dims
    Dimensions(B=3, N=7, T=5, X=2, Y=4)
You can set a wildcard dimension, cf. chex.assert_shape():
    >>> dims.W = None
    >>> dims['BTW']
    (3, 5, None)
Or you can use the wildcard character ‘*’ directly:
    >>> dims['BT*']
    (3, 5, None)
Single digits are interpreted as literal integers. Note that this notation is limited to single-digit literals.
    >>> dims['BT123']
    (3, 5, 1, 2, 3)
Support for single digits was mainly included to accommodate dummy axes introduced for consistent broadcasting. For instance, instead of using jnp.expand_dims you could do the following:
    >>> w = y * x  # Cannot broadcast (3, 5, 7) with (3, 5)
    Traceback (most recent call last):
        ...
    ValueError: Incompatible shapes for broadcasting: ((3, 5, 7), (1, 3, 5))
    >>> w = y * x.reshape(dims['BT1'])
    >>> chex.assert_shape(w, dims['BTN'])
Sometimes you only care about some array dimensions but not all. You can use an underscore to ignore an axis, e.g.
    >>> chex.assert_rank(y, 3)
    >>> dims['__M'] = y.shape  # Skip the first two axes.
Finally, note that a single-character key returns a tuple of length one.
    >>> dims['M']
    (7,)
Backend restriction
 chex.restrict_backends(*, allowed=None, forbidden=None)[source]
Disallows JAX compilation for certain backends.
 Parameters
allowed (Optional[Sequence[str]]) – Names of backend platforms (e.g. ‘cpu’ or ‘tpu’) for which compilation is still to be permitted.
forbidden (Optional[Sequence[str]]) – Names of backend platforms for which compilation is to be forbidden.
 Yields
None, in a context where compilation for forbidden platforms will raise a RestrictedBackendError.
 Raises
ValueError – if neither allowed nor forbidden is specified (i.e. they are both None), or if anything is both allowed and forbidden.
Dataclasses
 chex.dataclass(cls=None, *, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, kw_only=False, mappable_dataclass=True)[source]
JAXfriendly wrapper for
dataclasses.dataclass()
.This wrapper class registers new dataclasses with JAX so that tree utils operate correctly. Additionally a replace method is provided making it easy to operate on the class when made immutable (frozen=True).
 Parameters
cls – A class to decorate.
init – See
dataclasses.dataclass()
.repr – See
dataclasses.dataclass()
.eq – See
dataclasses.dataclass()
.order – See
dataclasses.dataclass()
.unsafe_hash – See
dataclasses.dataclass()
.frozen – See
dataclasses.dataclass()
.kw_only (bool) – See
dataclasses.dataclass()
.mappable_dataclass – If True (the default), methods to make the class implement the
collections.abc.Mapping
interface will be generated and the class will includecollections.abc.Mapping
in its base classes. True is the default, because being an instance of Mapping makes chex.dataclass compatible with e.g. jax.tree_util.tree_* methods, the tree library, or methods related to tensorflow/python/utils/nest.py. As a sideeffect, e.g. np.testing.assert_array_equal will only check the field names are equal and not the content. Use chex.assert_tree_* instead.
 Returns
A JAX-friendly dataclass.
 chex.mappable_dataclass(cls)[source]
Exposes a dataclass as a collections.abc.Mapping descendant.
Allows dataclasses to be traversed by methods from the dm-tree library.
NOTE: changes the dataclass constructor to dict-type (i.e. positional args aren’t supported; however, generators/iterables can be used).
 Parameters
cls – A dataclass to mutate.
 Returns
Mutated dataclass implementing the collections.abc.Mapping interface.
 chex.register_dataclass_type_with_jax_tree_util(data_class)[source]
Register an existing dataclass so JAX knows how to handle it.
This means that functions in jax.tree_util operate over the fields of the dataclass. See https://jax.readthedocs.io/en/latest/pytrees.html#extendingpytrees for further information.
 Parameters
data_class – A class created using dataclasses.dataclass. It must be constructable from keyword arguments corresponding to the members exposed in instance.__dict__.
Fakes

chex.fake_jit – Context manager for patching jax.jit with the identity function.
chex.fake_pmap – Context manager for patching jax.pmap with jax.vmap.
chex.fake_pmap_and_jit – Context manager for patching jax.jit and jax.pmap.
chex.set_n_cpu_devices – Forces XLA to use n CPU threads as host devices.
Transformations
 chex.fake_jit(enable_patching=True)[source]
Context manager for patching jax.jit with the identity function.
This is intended to be used as a debugging tool to programmatically enable or disable JIT compilation.
Can be used either as a context managed scope:
    with chex.fake_jit():
      @jax.jit
      def foo(x):
        ...
or by calling start and stop:
    fake_jit_context = chex.fake_jit()
    fake_jit_context.start()
    @jax.jit
    def foo(x):
      ...
    fake_jit_context.stop()
 Parameters
enable_patching (bool) – Whether to patch jax.jit.
 Return type
FakeContext
 Returns
Context where jax.jit is patched with the identity function, and JAX is configured to avoid jitting internally whenever possible in functions such as jax.lax.scan.
 chex.fake_pmap(enable_patching=True, jit_result=False, ignore_axis_index_groups=False, fake_parallel_axis=False)[source]
Context manager for patching jax.pmap with jax.vmap.
This is intended to be used as a debugging tool to programmatically replace pmap transformations with a non-parallel vmap transformation.
Can be used either as a context managed scope:
    with chex.fake_pmap():
      @jax.pmap
      def foo(x):
        ...
or by calling start and stop:
    fake_pmap_context = chex.fake_pmap()
    fake_pmap_context.start()
    @jax.pmap
    def foo(x):
      ...
    fake_pmap_context.stop()
 Parameters
enable_patching (bool) – Whether to patch jax.pmap.
jit_result (bool) – Whether the transformed function should be jitted despite not being pmapped.
ignore_axis_index_groups (bool) – Whether to force any parallel operation within the context to set axis_index_groups to None. This is a compatibility option that allows users of the axis_index_groups parameter to run under the fake_pmap context. The feature is not currently supported in vmap and would fail, so the parameter is forced to None. Warning: this will produce different results from running under jax.pmap.
fake_parallel_axis (bool) – Whether to fake a parallel axis.
 Return type
FakeContext
 Returns
Context where jax.pmap is patched with jax.vmap.
 chex.fake_pmap_and_jit(enable_pmap_patching=True, enable_jit_patching=True)[source]
Context manager for patching jax.jit and jax.pmap.
This is a convenience function, equivalent to nested chex.fake_pmap and chex.fake_jit contexts.
Note that calling (the true implementation of) jax.pmap will compile the function, so faking jax.jit in this case will not stop the function from being compiled.
 Parameters
enable_pmap_patching (bool) – Whether to patch jax.pmap.
enable_jit_patching (bool) – Whether to patch jax.jit.
 Return type
FakeContext
 Returns
Context where jax.pmap and jax.jit are patched with jax.vmap and the identity function, respectively.
Devices
 chex.set_n_cpu_devices(n=None)[source]
Forces XLA to use n CPU threads as host devices.
This allows jax.pmap to be tested on a single-CPU platform. This utility only takes effect before XLA backends are initialized, i.e. before any JAX operation is executed (including jax.devices() etc.). See https://github.com/google/jax/issues/1408.
 Parameters
n (Optional[int]) – A required number of CPU devices (FLAGS.chex_n_cpu_devices is used by default).
 Raises
RuntimeError – If XLA backends were already initialized.
 Return type
None
Pytypes

A descriptor of an available device.
Variants
 class chex.ChexVariantType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]
An enumeration of available Chex variants.
Use self.variant.type to get the type of the current test variant. See the docstring of chex.variants for more information.
 class chex.TestCase(*args, **kwargs)[source]
A class for Chex tests that use variants.
See the docstring for chex.variants for more information.
Note: chex.variants returns a generator producing one test per variant. Therefore, the test class used must support dynamic unrolling of these generators during module import. This is implemented (and battle-tested) in absl.parameterized.TestCase, and here we subclass from it.
 chex.variants(test_method='__no__default__', with_jit: bool = False, without_jit: bool = False, with_device: bool = False, without_device: bool = False, with_pmap: bool = False) VariantsTestCaseGenerator
Decorates a test to expose Chex variants.
The decorated test has access to a decorator called self.variant, which may be applied to functions to test different JAX behaviors. Consider:
    @chex.variants(with_jit=True, without_jit=True)
    def test(self):
      @self.variant
      def f(x, y):
        return x + y
      self.assertEqual(f(1, 2), 3)
In this example, the function test will be called twice: once with f jitted (i.e. using jax.jit) and once with f not jitted.
Variants with_jit=True and with_pmap=True accept additional arguments specific to them. Example:
    @chex.variants(with_jit=True)
    def test(self):
      @self.variant(static_argnums=(1,))
      def f(x, y):
        # `y` is not traced.
        return x + y
      self.assertEqual(f(1, 2), 3)
Variant with_pmap=True also accepts broadcast_args_to_devices (whether to broadcast each input argument to all participating devices), reduce_fn (a function to apply to results of pmapped fn), and n_devices (number of devices to use in the pmap computation). See the docstring of _with_pmap for more details (including default values).
If used with absl.testing.parameterized, @chex.variants must wrap it:
    @chex.variants(with_jit=True, without_jit=True)
    @parameterized.named_parameters('test', *args)
    def test(self, *args):
      ...
Tests that use this wrapper must inherit from parameterized.TestCase. For more examples see variants_test.py.
 Parameters
test_method – A test method to decorate.
with_jit – Whether to test with jax.jit.
without_jit – Whether to test without jax.jit. Any jit compilation done within the test method will not be affected.
with_device – Whether to test with args placed on device, using jax.device_put.
without_device – Whether to test with args (explicitly) not placed on device, using jax.device_get.
with_pmap – Whether to test with jax.pmap, with computation duplicated across devices.
 Returns
A decorated test_method.
 chex.all_variants(test_method='__no__default__', with_jit: bool = True, without_jit: bool = True, with_device: bool = True, without_device: bool = True, with_pmap: bool = True) VariantsTestCaseGenerator
Equivalent to chex.variants but with flipped defaults.
 chex.params_product(*params_lists, named=False)[source]
Generates a cartesian product of params_lists.
See tests from variants_test.py for examples of usage.
 Parameters
*params_lists – A list of params combinations.
named (bool) – Whether to generate test names (for absl.parameterized.named_parameters(…)).
 Return type
Sequence[Sequence[Any]]
 Returns
A cartesian product of params_lists combinations.