Assertions

assert_axis_dimension(tensor, axis, expected)

Checks that tensor.shape[axis] == expected.

assert_axis_dimension_comparator(tensor, ...)

Asserts that pass_fn(tensor.shape[axis]) passes.

assert_axis_dimension_gt(tensor, axis, val)

Checks that tensor.shape[axis] > val.

assert_axis_dimension_gteq(tensor, axis, val)

Checks that tensor.shape[axis] >= val.

assert_axis_dimension_lt(tensor, axis, val)

Checks that tensor.shape[axis] < val.

assert_axis_dimension_lteq(tensor, axis, val)

Checks that tensor.shape[axis] <= val.

assert_devices_available(n, devtype[, ...])

Checks that n devices of a given type are available.

assert_equal(first, second)

Checks that the two objects are equal as determined by the == operator.

assert_equal_rank(inputs)

Checks that all arrays have the same rank.

assert_equal_size(inputs)

Checks that all arrays have the same size.

assert_equal_shape(inputs, *[, dims])

Checks that all arrays have the same shape.

assert_equal_shape_prefix(inputs, prefix_len)

Checks that the leading prefix_len dims of all inputs have the same shape.

assert_equal_shape_suffix(inputs, suffix_len)

Checks that the final suffix_len dims of all inputs have the same shape.

assert_exactly_one_is_none(first, second)

Checks that one and only one of the arguments is None.

assert_gpu_available([backend])

Checks that at least one GPU device is available.

assert_is_broadcastable(shape_a, shape_b)

Checks that an array of shape_a is broadcastable to one of shape_b.

assert_is_divisible(numerator, denominator)

Checks that numerator is divisible by denominator.

assert_max_traces([fn, n])

Checks that a function is traced at most n times (inclusively).

assert_not_both_none(first, second)

Checks that at least one of the arguments is not None.

assert_numerical_grads(f, f_args, order[, atol])

Checks that autodiff and numerical gradients of a function match.

assert_rank(inputs, expected_ranks)

Checks that the rank of all inputs matches specified expected_ranks.

assert_scalar(x)

Checks that x is a scalar, as defined in pytypes.py (int or float).

assert_scalar_in(x, min_, max_[, included])

Checks that the argument is a scalar within the segment [min_, max_] (borders included by default).

assert_scalar_negative(x)

Checks that a scalar is negative.

assert_scalar_non_negative(x)

Checks that a scalar is non-negative.

assert_scalar_positive(x)

Checks that a scalar is positive.

assert_size(inputs, expected_sizes)

Checks that the size of all inputs matches specified expected_sizes.

assert_shape(inputs, expected_shapes)

Checks that the shape of all inputs matches specified expected_shapes.

assert_tpu_available([backend])

Checks that at least one TPU device is available.

assert_tree_all_finite(tree_like)

Checks that all leaves in a tree are finite.

assert_tree_has_only_ndarrays(tree)

Checks that all of the tree's leaves are n-dimensional arrays (tensors).

assert_tree_is_on_device(tree, *[, ...])

Checks that all leaves are ndarrays residing in device memory (in HBM).

assert_tree_is_on_host(tree, *[, ...])

Checks that all leaves are ndarrays residing in the host memory (on CPU).

assert_tree_is_sharded(tree, *, devices)

Checks that all leaves are ndarrays sharded across the specified devices.

assert_tree_no_nones(tree)

Checks that a tree does not contain None.

assert_tree_shape_prefix(tree, shape_prefix)

Checks that all tree leaves' shapes have the same prefix.

assert_tree_shape_suffix(tree, shape_suffix)

Checks that all tree leaves' shapes have the same suffix.

assert_trees_all_close(*trees[, rtol, atol])

Checks that all trees have leaves with approximately equal values.

assert_trees_all_close_ulp(*trees[, maxulp])

Checks that tree leaves differ by at most maxulp Units in the Last Place.

assert_trees_all_equal(*trees[, strict])

Checks that all trees have leaves with exactly equal values.

assert_trees_all_equal_comparator(...)

Checks that all trees are equal as per the custom comparator for leaves.

assert_trees_all_equal_dtypes(*trees)

Checks that trees' leaves have the same dtype.

assert_trees_all_equal_sizes(*trees)

Checks that trees have the same structure and leaves' sizes.

assert_trees_all_equal_shapes(*trees)

Checks that trees have the same structure and leaves' shapes.

assert_trees_all_equal_shapes_and_dtypes(*trees)

Checks that trees' leaves have the same shape and dtype.

assert_trees_all_equal_structs(*trees)

Checks that trees have the same structure.

assert_type(inputs, expected_types)

Checks that the type of all inputs matches specified expected_types.

chexify(fn[, async_check, errors])

Wraps a transformed function fn to enable Chex value assertions.

ChexifyChecks

A set of checks imported from checkify.

with_jittable_assertions(fn[, async_check])

An alias for chexify (see the docs).

block_until_chexify_assertions_complete()

Waits until all asynchronous checks complete.

Dimensions(**dim_sizes)

A lightweight utility that maps strings to shape tuples.

disable_asserts()

Disables all Chex assertions.

enable_asserts()

Enables Chex assertions.

clear_trace_counter()

Clears Chex traces' counter for assert_max_traces checks.

if_args_not_none(fn, *args, **kwargs)

Wraps a Chex assertion so that it is only evaluated if its positional args are not None.

JAX Assertions

chex.assert_max_traces(fn=None, n=None)

Checks that a function is traced at most n times (inclusively).

JAX re-traces jitted functions every time the structure of passed arguments changes. Often this behaviour is inadvertent and leads to a significant performance drop which is hard to debug. This wrapper checks that the function is re-traced at most n times during program execution.

Examples:

import chex
import jax

@jax.jit
@chex.assert_max_traces(n=1)
def fn_sum_jitted(x, y):
  return x + y

def fn_sub(x, y):
  return x - y

# Wrap the pure Python function first, then pmap the wrapped function.
fn_sub_pmapped = jax.pmap(chex.assert_max_traces(fn_sub, n=10))

More about tracing:

https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html

Parameters
  • fn (Optional[Union[Callable[..., Any], int]]) – A pure Python function to wrap (i.e. it must not be a jitted function).

  • n (Optional[Union[Callable[..., Any], int]]) – The maximum allowed number of retraces (non-negative).

Returns

A decorated function that raises an exception when it is traced for the (n+1)-th time.

Raises

ValueError – If fn has already been jitted.

chex.assert_devices_available(n, devtype, backend=None, not_less_than=False)

Checks that n devices of a given type are available.

Parameters
  • n (int) – A required number of devices of the given type.

  • devtype (str) – A type of devices, one of {'cpu', 'gpu', 'tpu'}.

  • backend (Optional[str]) – A type of backend to use (uses JAX default if not provided).

  • not_less_than (bool) – Whether to check if the number of devices is not less than n, instead of precise comparison.

Raises

AssertionError – If the number of available devices of the given type is not equal to n or, when not_less_than is True, is less than n.

Return type

None

chex.assert_gpu_available(backend=None)

Checks that at least one GPU device is available.

Parameters

backend (Optional[str]) – A type of backend to use (uses JAX default if not provided).

Raises

AssertionError – If no GPU device is available.

Return type

None

chex.assert_tpu_available(backend=None)

Checks that at least one TPU device is available.

Parameters

backend (Optional[str]) – A type of backend to use (uses JAX default if not provided).

Raises

AssertionError – If no TPU device is available.

Return type

None

Value (Runtime) Assertions

chex.chexify(fn, async_check=True, errors=frozenset({<class 'jax._src.checkify.FailedCheckError'>}))

Wraps a transformed function fn to enable Chex value assertions.

Chex value/runtime assertions (e.g. assert_tree_all_finite) access concrete values of tensors, which are not available during JAX tracing; see https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html and https://jax.readthedocs.io/en/latest/_modules/jax/_src/errors.html#ConcretizationTypeError.

This wrapper enables them in jitted/pmapped functions by applying a specifically designed JAX transformation (https://jax.readthedocs.io/en/latest/debugging/checkify_guide.html#the-checkify-transformation) and calling functionalised checks (https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.checkify.check.html).

Example:

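import chex
import jax
import jax.numpy as jnp

@chex.chexify
@jax.jit
def logp1_abs_safe(x: chex.Array) -> chex.Array:
  chex.assert_tree_all_finite(x)
  return jnp.log(jnp.abs(x) + 1)

logp1_abs_safe(jnp.ones(2))  # OK
logp1_abs_safe(jnp.array([jnp.nan, 3]))  # FAILS (the error may be reported asynchronously)
logp1_abs_safe.wait_checks()  # Blocks until pending checks complete.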

Note 1: This wrapper identifies the first failed assertion in jitted code by printing a pointer to the line where that assertion was invoked. To get verbose messages (including concrete tensor values), run an unjitted version of the code with the same input values. Chex does not currently provide tools to help with this.

Note 2: This wrapper fully supports asynchronous execution (see https://jax.readthedocs.io/en/latest/async_dispatch.html). To block program execution until asynchronous checks for a _chexified_ function fn complete, call fn.wait_checks(). Similarly, chex.block_until_chexify_assertions_complete() blocks program execution until _all_ asynchronous checks complete.

Note 3: Chex automatically selects the backend for executing its assertions (i.e. CPU or device accelerator) depending on the program context.

Note 4: Value assertions can have an impact on the performance of a function; see https://jax.readthedocs.io/en/latest/debugging/checkify_guide.html#limitations.

Note 5: Static assertions, such as assert_shape or assert_trees_all_equal_dtypes, can be called from a jitted function without the chexify wrapper, since they do not access concrete values, only shapes and/or dtypes, which are available during JAX tracing.

More examples can be found at https://github.com/deepmind/chex/blob/master/chex/_src/asserts_chexify_test.py

Parameters
  • fn (Callable[..., Any]) – A transformed function to wrap.

  • async_check (bool) – Whether to check errors in the async dispatch mode. See https://jax.readthedocs.io/en/latest/async_dispatch.html.

  • errors (FrozenSet[checkify.ErrorCategory]) – A set of checkify.ErrorCategory values that defines which checks are enabled. By default, only explicit checks (the user set) are enabled. You can also, for example, enable NaN and division-by-zero errors by passing the float set, or combine multiple sets with set operations (float | user).

Return type

Callable[…, Any]

Returns

A _chexified_ function, i.e. one with value assertions enabled. The returned function has a wait_checks() method that blocks the caller until all pending async checks complete.

ChexifyChecks

A set of checks imported from checkify.

chex.with_jittable_assertions(fn, async_check=True)

An alias for chexify (see the docs).

Return type

Callable[…, Any]

chex.block_until_chexify_assertions_complete()

Waits until all asynchronous checks complete.

See chexify for more detail.

Return type

None

Tree Assertions

chex.assert_tree_all_finite(tree_like)

Checks that all leaves in a tree are finite.

Parameters

tree_like (ArrayTree) – A pytree with array leaves.

Raises

AssertionError – If any leaf in tree_like is non-finite.

Return type

None

chex.assert_tree_has_only_ndarrays(tree)

Checks that all of the tree’s leaves are n-dimensional arrays (tensors).

Parameters

tree (ArrayTree) – A tree to assert.

Raises

AssertionError – If the tree contains an object which is not an ndarray.

Return type

None

chex.assert_tree_is_on_device(tree, *, platform=('gpu', 'tpu'), device=None)

Checks that all leaves are ndarrays residing in device memory (in HBM).

Sharded DeviceArrays are disallowed.

Parameters
  • tree (ArrayTree) – A tree to assert.

  • platform (Union[Sequence[str], str]) – A platform or a list of platforms where the leaves are expected to reside. Ignored if device is specified.

  • device (Optional[pytypes.Device]) – An optional device where the tree’s arrays are expected to reside. Any device (except CPU) is accepted if not specified.

Raises

AssertionError – If the tree contains a leaf that is not an ndarray or does not reside on the specified device or platform.

Return type

None

chex.assert_tree_is_on_host(tree, *, allow_cpu_device=True, allow_sharded_arrays=False)

Checks that all leaves are ndarrays residing in the host memory (on CPU).

This assertion only accepts trees consisting of ndarrays.

Parameters
  • tree (ArrayTree) – A tree to assert.

  • allow_cpu_device (bool) – Whether to allow JAX arrays that reside on a CPU device.

  • allow_sharded_arrays (bool) – Whether to allow sharded JAX arrays. Sharded arrays are considered “on host” only if they are sharded across CPU devices and allow_cpu_device is True.

Raises

AssertionError – If the tree contains a leaf that is not an ndarray or does not reside on host.

Return type

None

chex.assert_tree_is_sharded(tree, *, devices)

Checks that all leaves are ndarrays sharded across the specified devices.

Parameters
  • tree (ArrayTree) – A tree to assert.

  • devices (Sequence[pytypes.Device]) – A list of devices which the tree’s leaves are expected to be sharded across. This list is order-sensitive.

Raises

AssertionError – If the tree contains a leaf that is not a device array sharded across the specified devices.

Return type

None

chex.assert_tree_no_nones(tree)

Checks that a tree does not contain None.

Parameters

tree (ArrayTree) – A tree to assert.

Raises

AssertionError – If the tree contains at least one None.

Return type

None

chex.assert_tree_shape_prefix(tree, shape_prefix)

Checks that all tree leaves’ shapes have the same prefix.

Parameters
  • tree (ArrayTree) – A tree to check.

  • shape_prefix (Sequence[int]) – An expected shape prefix.

Raises

AssertionError – If some leaf’s shape doesn’t start with shape_prefix.

Return type

None

chex.assert_tree_shape_suffix(tree, shape_suffix)

Checks that all tree leaves’ shapes have the same suffix.

Parameters
  • tree (ArrayTree) – A tree to check.

  • shape_suffix (Sequence[int]) – An expected shape suffix.

Raises

AssertionError – If some leaf’s shape doesn’t end with shape_suffix.

Return type

None

chex.assert_trees_all_close(*trees, rtol=1e-06, atol=0.0)

Checks that all trees have leaves with approximately equal values.

This checks, element-wise, that abs(actual - desired) <= atol + rtol * abs(desired).

Parameters
  • *trees – A sequence of (at least 2) trees with array leaves.

  • rtol (float) – A relative tolerance.

  • atol (float) – An absolute tolerance.

Raises

AssertionError – If actual and desired values are not equal up to specified tolerance.

Return type

None

chex.assert_trees_all_close_ulp(*trees, maxulp=1)

Checks that tree leaves differ by at most maxulp Units in the Last Place.

This is the Chex version of np.testing.assert_array_max_ulp.

Assertions on floating point values are tricky because the precision varies depending on the value. For example, with float32, the precision at 1 is np.spacing(np.float32(1.0)) ≈ 1e-7, but the precision at 5,000,000 is only np.spacing(np.float32(5e6)) = 0.5. This makes it hard to predict ahead of time what tolerance to use when checking whether two numbers are equal: a difference of only a couple of bits can equate to an arbitrarily large absolute difference.

Assertions based on _relative_ differences are one solution to this problem, but have the disadvantage that it’s hard to choose the tolerance. If you want to verify that two calculations produce _exactly_ the same result modulo the inherent non-determinism of floating point operations, do you set the tolerance to…0.01? 0.001? It’s hard to be sure you’ve set it low enough that you won’t miss one of your computations being slightly wrong.

Assertions based on ‘units in the last place’ (ULP) solve this problem by letting you specify tolerances in terms of the precision actually available at the current scale of your values. The ULP at some value x is essentially the spacing between the floating point numbers actually representable in the vicinity of x, which is equivalent to the ‘precision’ discussed above. With a tolerance of, say, maxulp=5, you’re saying that two values are within 5 actually-representable numbers of each other: a strong guarantee that two computations are as close as possible to identical, while still allowing reasonable wiggle room for small differences due to e.g. different operator orderings.

Note that this function is not currently supported within JIT contexts, and does not currently support bfloat16 dtypes.

Parameters
  • *trees – A sequence of (at least 2) trees with array leaves.

  • maxulp (int) – The maximum number of ULPs by which leaves may differ.

Raises

AssertionError – If actual and desired values are not equal up to specified tolerance.

Return type

None

chex.assert_trees_all_equal(*trees, strict=False)

Checks that all trees have leaves with exactly equal values.

If you are comparing floating point numbers, an exact equality check may not be appropriate; consider using assert_trees_all_close.

Parameters
  • *trees – A sequence of (at least 2) trees with array leaves.

  • strict (bool) – If True, disable special scalar handling as described in the Notes section of np.testing.assert_array_equal.

Raises

AssertionError – If the leaf values actual and desired are not exactly equal.

Return type

None

chex.assert_trees_all_equal_comparator(equality_comparator, error_msg_fn, *trees)

Checks that all trees are equal as per the custom comparator for leaves.

Parameters
  • equality_comparator (_ai.TLeavesEqCmpFn) – A custom function that accepts two leaves and checks whether they are equal. Expected to be transitive.

  • error_msg_fn (_ai.TLeavesEqCmpErrorFn) – A function that accepts two leaves that are unequal as per equality_comparator and returns an error message.

  • *trees – A sequence of (at least 2) trees to check on equality as per equality_comparator.

Raises
  • ValueError – If trees does not contain at least 2 elements.

  • AssertionError – If equality_comparator returns False for any pair of corresponding leaves in trees.

Return type

None

chex.assert_trees_all_equal_dtypes(*trees)

Checks that trees’ leaves have the same dtype.

Parameters

*trees – A sequence of (at least 2) trees to check.

Raises

AssertionError – If leaves’ dtypes for any two trees differ.

Return type

None

chex.assert_trees_all_equal_sizes(*trees)

Checks that trees have the same structure and leaves’ sizes.

Parameters

*trees – A sequence of (at least 2) trees with array leaves.

Raises

AssertionError – If trees’ structures or leaves’ sizes are different.

Return type

None

chex.assert_trees_all_equal_shapes(*trees)

Checks that trees have the same structure and leaves’ shapes.

Parameters

*trees – A sequence of (at least 2) trees with array leaves.

Raises

AssertionError – If trees’ structures or leaves’ shapes are different.

Return type

None

chex.assert_trees_all_equal_shapes_and_dtypes(*trees)

Checks that trees’ leaves have the same shape and dtype.

Parameters

*trees – A sequence of (at least 2) trees to check.

Raises

AssertionError – If leaves’ shapes or dtypes for any two trees differ.

Return type

None

chex.assert_trees_all_equal_structs(*trees)

Checks that trees have the same structure.

Parameters

*trees – A sequence of (at least 2) trees to assert equal structure between.

Raises
  • ValueError – If trees does not contain at least 2 elements.

  • AssertionError – If structures of any two trees are different.

Return type

None

Generic Assertions

chex.assert_axis_dimension(tensor, axis, expected)

Checks that tensor.shape[axis] == expected.

Parameters
  • tensor (Array) – A JAX array.

  • axis (int) – An integer specifying which axis to assert.

  • expected (int) – An expected value of tensor.shape[axis].

Raises

AssertionError – If the dimension of the specified axis does not match the prescribed value.

Return type

None

chex.assert_axis_dimension_comparator(tensor, axis, pass_fn, error_string)

Asserts that pass_fn(tensor.shape[axis]) passes.

Used to implement ==, >, >=, <, <= checks.

Parameters
  • tensor (Array) – A JAX array.

  • axis (int) – An integer specifying which axis to assert.

  • pass_fn (Callable[[int], bool]) – A callable that takes the size of the given dimension and returns False when the assertion should fail.

  • error_string (str) – string which is inserted in assertion failure messages - ‘expected tensor to have dimension {error_string} on axis …’.

Raises

AssertionError – If pass_fn(tensor.shape[axis]) does not return True.

chex.assert_axis_dimension_gt(tensor, axis, val)

Checks that tensor.shape[axis] > val.

Parameters
  • tensor (Array) – A JAX array.

  • axis (int) – An integer specifying which axis to assert.

  • val (int) – A value tensor.shape[axis] must be greater than.

Raises

AssertionError – if the dimension of axis is <= val.

Return type

None

chex.assert_axis_dimension_gteq(tensor, axis, val)

Checks that tensor.shape[axis] >= val.

Parameters
  • tensor (Array) – A JAX array.

  • axis (int) – An integer specifying which axis to assert.

  • val (int) – A value tensor.shape[axis] must be greater than or equal to.

Raises

AssertionError – if the dimension of axis is < val.

Return type

None

chex.assert_axis_dimension_lt(tensor, axis, val)

Checks that tensor.shape[axis] < val.

Parameters
  • tensor (Array) – A JAX array.

  • axis (int) – An integer specifying which axis to assert.

  • val (int) – A value tensor.shape[axis] must be less than.

Raises

AssertionError – if the dimension of axis is >= val.

Return type

None

chex.assert_axis_dimension_lteq(tensor, axis, val)

Checks that tensor.shape[axis] <= val.

Parameters
  • tensor (Array) – A JAX array.

  • axis (int) – An integer specifying which axis to assert.

  • val (int) – A value tensor.shape[axis] must be less than or equal to.

Raises

AssertionError – if the dimension of axis is > val.

Return type

None

chex.assert_equal(first, second)

Checks that the two objects are equal as determined by the == operator.

Arrays with more than one element cannot be compared. Use assert_trees_all_close to compare arrays.

Parameters
  • first (Any) – A first object.

  • second (Any) – A second object.

Raises

AssertionError – If not (first == second).

Return type

None

chex.assert_equal_rank(inputs)

Checks that all arrays have the same rank.

Parameters

inputs (Sequence[Array]) – A collection of arrays.

Raises
  • AssertionError – If the ranks of all arrays do not match.

  • ValueError – If inputs is not a collection of arrays.

Return type

None

chex.assert_equal_size(inputs)

Checks that all arrays have the same size.

Parameters

inputs (Sequence[Array]) – A collection of arrays.

Raises

AssertionError – If the sizes of the arrays do not all match.

Return type

None

chex.assert_equal_shape(inputs, *, dims=None)

Checks that all arrays have the same shape.

Parameters
  • inputs (Sequence[Array]) – A collection of arrays.

  • dims (Optional[Union[int, Sequence[int]]]) – An optional integer or sequence of integers. If not provided, every dimension of every shape must match. If provided, equality of shape will only be asserted for the specified dim(s); e.g. to ensure that all arrays in a group have the same size in the first two dimensions, call assert_equal_shape(tensors_list, dims=(0, 1)).

Raises
  • AssertionError – If the shapes of all arrays at specified dims do not match.

  • ValueError – If the provided dims are invalid indices into any of arrays; or if inputs is not a collection of arrays.

Return type

None

chex.assert_equal_shape_prefix(inputs, prefix_len)

Checks that the leading prefix_len dims of all inputs have the same shape.

Parameters
  • inputs (Sequence[Array]) – A collection of input arrays.

  • prefix_len (int) – A number of leading dimensions to compare; each input’s shape will be sliced to shape[:prefix_len]. Negative values are accepted and have the conventional Python indexing semantics.

Raises
  • AssertionError – If the shapes of all arrays do not match.

  • ValueError – If inputs is not a collection of arrays.

Return type

None

chex.assert_equal_shape_suffix(inputs, suffix_len)

Checks that the final suffix_len dims of all inputs have the same shape.

Parameters
  • inputs (Sequence[Array]) – A collection of input arrays.

  • suffix_len (int) – A number of trailing dimensions to compare; each input’s shape will be sliced to shape[-suffix_len:]. Negative values are accepted and have the conventional Python indexing semantics.

Raises
  • AssertionError – If the shapes of all arrays do not match.

  • ValueError – If inputs is not a collection of arrays.

Return type

None

chex.assert_exactly_one_is_none(first, second)

Checks that one and only one of the arguments is None.

Parameters
  • first (Any) – A first object.

  • second (Any) – A second object.

Raises

AssertionError – If (first is None) xor (second is None) is False.

Return type

None

chex.assert_is_broadcastable(shape_a, shape_b)

Checks that an array of shape_a is broadcastable to one of shape_b.

Parameters
  • shape_a (Sequence[int]) – A shape of the array to check.

  • shape_b (Sequence[int]) – A target shape after broadcasting.

Raises

AssertionError – If shape_a is not broadcastable to shape_b.

Return type

None

chex.assert_is_divisible(numerator, denominator)

Checks that numerator is divisible by denominator.

Parameters
  • numerator (int) – A numerator.

  • denominator (int) – A denominator.

Raises

AssertionError – If numerator is not divisible by denominator.

Return type

None

chex.assert_not_both_none(first, second)

Checks that at least one of the arguments is not None.

Parameters
  • first (Any) – A first object.

  • second (Any) – A second object.

Raises

AssertionError – If (first is None) and (second is None).

Return type

None

chex.assert_numerical_grads(f, f_args, order, atol=0.01, **check_kwargs)

Checks that autodiff and numerical gradients of a function match.

Parameters
  • f (Callable[..., Array]) – A function to check.

  • f_args (Sequence[Array]) – Arguments of the function.

  • order (int) – An order of gradients.

  • atol (float) – An absolute tolerance.

  • **check_kwargs – Keyword arguments forwarded to jax.test_util.check_grads.

Raises

AssertionError – If automatic differentiation gradients deviate from finite difference gradients.

Return type

None

chex.assert_rank(inputs, expected_ranks)

Checks that the rank of all inputs matches specified expected_ranks.

Valid usages include:

assert_rank(x, 0)                      # x is scalar
assert_rank(x, 2)                      # x is a rank-2 array
assert_rank(x, {0, 2})                 # x is scalar or rank-2 array
assert_rank([x, y], 2)                 # x and y are rank-2 arrays
assert_rank([x, y], [0, 2])            # x is scalar and y is a rank-2 array
assert_rank([x, y], {0, 2})            # x and y are scalar or rank-2 arrays

Parameters
  • inputs (Union[Scalar, Union[Array, Sequence[Array]]]) – An array or a sequence of arrays.

  • expected_ranks (Union[int, Set[int], Sequence[Union[int, Set[int]]]]) – A sequence of expected ranks associated with each input, where the expected rank is either an integer or set of integer options; if all inputs have same rank, a single scalar or set of scalars may be passed as expected_ranks.

Raises
  • AssertionError – If lengths of inputs and expected_ranks don’t match; if expected_ranks has wrong type; if the ranks of inputs do not match expected_ranks.

  • ValueError – If expected_ranks is not an integer and not a sequence of integers.

Return type

None

chex.assert_scalar(x)

Checks that x is a scalar, as defined in pytypes.py (int or float).

Parameters

x (Scalar) – An object to check.

Raises

AssertionError – If x is not a scalar as per definition in pytypes.py.

Return type

None

chex.assert_scalar_in(x, min_, max_, included=True)

Checks that the argument is a scalar within the segment [min_, max_] (borders included by default).

Parameters
  • x (Any) – An object to check.

  • min_ – A left border of the segment.

  • max_ – A right border of the segment.

  • included (bool) – Whether to include the borders of the segment in the set of allowed values.

Raises

AssertionError – If x is not a scalar; if x falls out of the segment.

Return type

None

chex.assert_scalar_negative(x)

Checks that a scalar is negative.

Parameters

x (Scalar) – A value to check.

Raises

AssertionError – If x is not a scalar or is not strictly negative.

Return type

None

chex.assert_scalar_non_negative(x)

Checks that a scalar is non-negative.

Parameters

x (Scalar) – A value to check.

Raises

AssertionError – If x is not a scalar or is negative.

Return type

None

chex.assert_scalar_positive(x)

Checks that a scalar is positive.

Parameters

x (Scalar) – A value to check.

Raises

AssertionError – If x is not a scalar or is not strictly positive.

Return type

None

chex.assert_size(inputs, expected_sizes)

Checks that the size of all inputs matches specified expected_sizes.

Valid usages include:

assert_size(x, 1)                   # x is scalar (size 1)
assert_size([x, y], (2, {1, 3}))    # x has size 2, y has size 1 or 3
assert_size([x, y], (2, ...))       # x has size 2, y has any size
assert_size([x, y], 1)              # x and y are scalar (size 1)
assert_size((x, y), (5, 2))         # x has size 5, y has size 2

Parameters
  • inputs (Union[Scalar, Union[Array, Sequence[Array]]]) – An array or a sequence of arrays.

  • expected_sizes (Union[_ai.TShapeMatcher, Sequence[_ai.TShapeMatcher]]) – A sequence of expected sizes associated with each input, where each expected size is an integer, a set of allowed integers, or ... (any size); if all inputs have the same size, a single size may be passed as expected_sizes.

Raises

AssertionError – If the lengths of inputs and expected_sizes do not match; if expected_sizes has wrong type; if size of input does not match expected_sizes.

Return type

None

chex.assert_shape(inputs, expected_shapes)

Checks that the shape of all inputs matches specified expected_shapes.

Valid usages include:

assert_shape(x, ())                  # x is scalar
assert_shape(x, (2, 3))              # x has shape (2, 3)
assert_shape(x, (2, {1, 3}))         # x has shape (2, 1) or (2, 3)
assert_shape(x, (2, None))           # x has rank 2 and `x.shape[0] == 2`
assert_shape(x, (2, ...))            # x has rank >= 1 and `x.shape[0] == 2`
assert_shape([x, y], ())             # x and y are scalar
assert_shape([x, y], [(), (2,3)])    # x is scalar and y has shape (2, 3)

Parameters
  • inputs (Union[Scalar, Union[Array, Sequence[Array]]]) – An array or a sequence of arrays.

  • expected_shapes (Union[_ai.TShapeMatcher, Sequence[_ai.TShapeMatcher]]) – A sequence of expected shapes associated with each input, where the expected shape is a sequence of integer and None dimensions; if all inputs have same shape, a single shape may be passed as expected_shapes.

Raises

AssertionError – If the lengths of inputs and expected_shapes do not match; if expected_shapes has wrong type; if shape of input does not match expected_shapes.

Return type

None

chex.assert_type(inputs, expected_types)

Checks that the type of all inputs matches specified expected_types.

Valid usages include:

assert_type(7, int)
assert_type(7.1, float)
assert_type(False, bool)
assert_type([7, 8], int)
assert_type([7, 7.1], [int, float])
assert_type(np.array(7), int)
assert_type(np.array(7.1), float)
assert_type(jnp.array(7), int)
assert_type([jnp.array([7, 8]), np.array(7.1)], [int, float])

Parameters
  • inputs (Union[Scalar, Union[Array, Sequence[Array]]]) – An array or a sequence of arrays or scalars.

  • expected_types (Union[Type[Scalar], Sequence[Type[Scalar]]]) – A sequence of expected types associated with each input; if all inputs have the same type, a single type may be passed as expected_types.

Raises

AssertionError – If the lengths of inputs and expected_types don’t match; if expected_types contains an unsupported pytype; or if the types of inputs do not match the expected types.

Return type

None

Shapes and Named Dimensions

class chex.Dimensions(**dim_sizes)

A lightweight utility that maps strings to shape tuples.

The most basic usage is:

>>> dims = chex.Dimensions(B=3, T=5, N=7)  # You can specify any letters.
>>> dims['NBT']
(7, 3, 5)

This is useful when dealing with many differently shaped arrays. For instance, let’s check the shape of this array:

>>> x = jnp.array([[2, 0, 5, 6, 3],
...                [5, 4, 4, 3, 3],
...                [0, 0, 5, 2, 0]])
>>> chex.assert_shape(x, dims['BT'])

The dimension sizes can also be accessed directly as attributes, e.g. dims.N == 7. This is useful whenever a single size is needed; for instance, let’s one-hot encode our array.

>>> y = jax.nn.one_hot(x, dims.N)
>>> chex.assert_shape(y, dims['BTN'])

You can also store the shape of a given array in dims, e.g.

>>> z = jnp.array([[0, 6, 0, 2],
...                [4, 2, 2, 4]])
>>> dims['XY'] = z.shape
>>> dims
Dimensions(B=3, N=7, T=5, X=2, Y=4)

You can set a wildcard dimension, cf. chex.assert_shape():

>>> dims.W = None
>>> dims['BTW']
(3, 5, None)

Or you can use the wildcard character ‘*’ directly:

>>> dims['BT*']
(3, 5, None)

Single digits are interpreted as literal integers. Note that this notation is limited to single-digit literals.

>>> dims['BT123']
(3, 5, 1, 2, 3)

Support for single digits was mainly included to accommodate dummy axes introduced for consistent broadcasting. For instance, instead of using jnp.expand_dims you could do the following:

>>> w = y * x  # Cannot broadcast (3, 5, 7) with (3, 5)
Traceback (most recent call last):
    ...
ValueError: Incompatible shapes for broadcasting: ((3, 5, 7), (1, 3, 5))
>>> w = y * x.reshape(dims['BT1'])
>>> chex.assert_shape(w, dims['BTN'])

Sometimes you only care about some array dimensions but not all. You can use an underscore to ignore an axis, e.g.

>>> chex.assert_rank(y, 3)
>>> dims['__M'] = y.shape  # Skip the first two axes.

Finally, note that a single-character key returns a tuple of length one.

>>> dims['M']
(7,)

Utils

chex.disable_asserts()

Disables all Chex assertions.

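Use wisely: while assertions are disabled, every Chex assertion is a no-op, so shape and type errors go unreported until chex.enable_asserts() is called.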

Return type

None

chex.enable_asserts()

Enables Chex assertions.

Return type

None

chex.clear_trace_counter()

Clears the Chex trace counter used by assert_max_traces checks.

Use it to isolate unit tests that rely on assert_max_traces, by calling it at the start of the test case.

Return type

None

chex.if_args_not_none(fn, *args, **kwargs)

Wraps a Chex assertion so that it is only evaluated if none of the positional args are None.

Backend restriction

chex.restrict_backends(*, allowed=None, forbidden=None)

Disallows JAX compilation for certain backends.

Parameters
  • allowed (Optional[Sequence[str]]) – Names of backend platforms (e.g. ‘cpu’ or ‘tpu’) for which compilation is still to be permitted.

  • forbidden (Optional[Sequence[str]]) – Names of backend platforms for which compilation is to be forbidden.

Yields

None, in a context where compilation for forbidden platforms will raise a RestrictedBackendError.

Raises

ValueError – if neither allowed nor forbidden is specified (i.e. they are both None), or if anything is both allowed and forbidden.

Dataclasses

chex.dataclass(cls=None, *, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, kw_only=False, mappable_dataclass=True)

JAX-friendly wrapper for dataclasses.dataclass().

This wrapper registers new dataclasses with JAX so that tree utils operate correctly. Additionally, a replace method is provided, making it easy to operate on the class when it is made immutable (frozen=True).

Parameters
  • cls – A class to decorate.

  • init – See dataclasses.dataclass().

  • repr – See dataclasses.dataclass().

  • eq – See dataclasses.dataclass().

  • order – See dataclasses.dataclass().

  • unsafe_hash – See dataclasses.dataclass().

  • frozen – See dataclasses.dataclass().

  • kw_only (bool) – See dataclasses.dataclass().

  • mappable_dataclass – If True (the default), methods that make the class implement the collections.abc.Mapping interface are generated, and the class includes collections.abc.Mapping in its base classes. True is the default because being an instance of Mapping makes chex.dataclass compatible with, e.g., jax.tree_util.tree_* methods, the tree library, or methods related to tensorflow/python/utils/nest.py. As a side effect, e.g. np.testing.assert_array_equal will only check that the field names are equal, not the contents. Use chex.assert_tree_* instead.

Returns

A JAX-friendly dataclass.

chex.mappable_dataclass(cls)

Exposes a dataclass as a collections.abc.Mapping descendant.

Allows dataclasses to be traversed by methods from the dm-tree library.

NOTE: changes the dataclass constructor to a dict-style one (i.e. positional args aren’t supported; however, generators/iterables can be used).

Parameters

cls – A dataclass to mutate.

Returns

Mutated dataclass implementing collections.abc.Mapping interface.

chex.register_dataclass_type_with_jax_tree_util(data_class)

Registers an existing dataclass so JAX knows how to handle it.

This means that functions in jax.tree_util operate over the fields of the dataclass. See https://jax.readthedocs.io/en/latest/pytrees.html#extending-pytrees for further information.

Parameters

data_class – A class created using dataclasses.dataclass. It must be constructable from keyword arguments corresponding to the members exposed in instance.__dict__.

Fakes

fake_jit([enable_patching])

Context manager for patching jax.jit with the identity function.

fake_pmap([enable_patching, jit_result, ...])

Context manager for patching jax.pmap with jax.vmap.

fake_pmap_and_jit([enable_pmap_patching, ...])

Context manager for patching jax.jit and jax.pmap.

set_n_cpu_devices([n])

Forces XLA to use n CPU threads as host devices.

Transformations

chex.fake_jit(enable_patching=True)

Context manager for patching jax.jit with the identity function.

This is intended to be used as a debugging tool to programmatically enable or disable JIT compilation.

Can be used either as a context-managed scope:

with chex.fake_jit():
  @jax.jit
  def foo(x):
    ...

or by calling start and stop:

fake_jit_context = chex.fake_jit()
fake_jit_context.start()

@jax.jit
def foo(x):
  ...

fake_jit_context.stop()

Parameters

enable_patching (bool) – Whether to patch jax.jit.

Return type

FakeContext

Returns

Context where jax.jit is patched with the identity function and jax is configured to avoid jitting internally whenever possible in functions such as jax.lax.scan.

chex.fake_pmap(enable_patching=True, jit_result=False, ignore_axis_index_groups=False, fake_parallel_axis=False)

Context manager for patching jax.pmap with jax.vmap.

This is intended to be used as a debugging tool to programmatically replace pmap transformations with a non-parallel vmap transformation.

Can be used either as a context-managed scope:

with chex.fake_pmap():
  @jax.pmap
  def foo(x):
    ...

or by calling start and stop:

fake_pmap_context = chex.fake_pmap()
fake_pmap_context.start()

@jax.pmap
def foo(x):
  ...

fake_pmap_context.stop()

Parameters
  • enable_patching (bool) – Whether to patch jax.pmap.

  • jit_result (bool) – Whether the transformed function should be jitted despite not being pmapped.

  • ignore_axis_index_groups (bool) – Whether to force any parallel operation within the context to set axis_index_groups to None. This is a compatibility option that allows users of the axis_index_groups parameter to run under the fake_pmap context. This feature is not currently supported in vmap and will fail, so we force the parameter to be None. Warning: this will produce different results from running under jax.pmap.

  • fake_parallel_axis (bool) – Whether to fake a parallel axis.

Return type

FakeContext

Returns

Context where jax.pmap is patched with jax.vmap.

chex.fake_pmap_and_jit(enable_pmap_patching=True, enable_jit_patching=True)

Context manager for patching jax.jit and jax.pmap.

This is a convenience function, equivalent to nested chex.fake_pmap and chex.fake_jit contexts.

Note that calling (the true implementation of) jax.pmap will compile the function, so faking jax.jit in this case will not stop the function from being compiled.

Parameters
  • enable_pmap_patching (bool) – Whether to patch jax.pmap.

  • enable_jit_patching (bool) – Whether to patch jax.jit.

Return type

FakeContext

Returns

Context where jax.pmap and jax.jit are patched with jax.vmap and the identity function, respectively.

Devices

chex.set_n_cpu_devices(n=None)

Forces XLA to use n CPU threads as host devices.

This allows jax.pmap to be tested on a single-CPU platform. This utility only takes effect before XLA backends are initialized, i.e. before any JAX operation is executed (including jax.devices() etc.). See https://github.com/google/jax/issues/1408.

Parameters

n (Optional[int]) – The required number of CPU devices (FLAGS.chex_n_cpu_devices is used by default).

Raises

RuntimeError – If XLA backends were already initialized.

Return type

None

Pytypes

Array

alias of Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number]

ArrayBatched

alias of jax.Array

ArrayDevice

alias of jax.Array

ArrayDeviceTree

alias of Union[jax.Array, Iterable[ArrayDeviceTree], Mapping[Any, ArrayDeviceTree]]

ArrayDType

alias of Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType]

ArrayNumpy

alias of numpy.ndarray

ArrayNumpyTree

alias of Union[numpy.ndarray, Iterable[ArrayNumpyTree], Mapping[Any, ArrayNumpyTree]]

ArraySharded

alias of jax.Array

ArrayTree

alias of Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]

Device

A descriptor of an available device.

Numeric

alias of Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, float, int]

PRNGKey

alias of jax.Array

PyTreeDef

Scalar

alias of Union[float, int]

Shape

alias of Sequence[Union[int, Any]]

Variants

class chex.ChexVariantType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)

An enumeration of available Chex variants.

Use self.variant.type to get the type of the current test variant. See the docstring of chex.variants for more information.

class chex.TestCase(*args, **kwargs)

A class for Chex tests that use variants.

See the docstring for chex.variants for more information.

Note: chex.variants returns a generator producing one test per variant. Therefore, the test class must support dynamic unrolling of these generators during module import. This is implemented (and battle-tested) in absl.parameterized.TestCase, and here we subclass from it.

chex.variants(test_method='__no__default__', with_jit: bool = False, without_jit: bool = False, with_device: bool = False, without_device: bool = False, with_pmap: bool = False) → VariantsTestCaseGenerator

Decorates a test to expose Chex variants.

The decorated test has access to a decorator called self.variant, which may be applied to functions to test different JAX behaviors. Consider:

@chex.variants(with_jit=True, without_jit=True)
def test(self):
  @self.variant
  def f(x, y):
    return x + y

  self.assertEqual(f(1, 2), 3)

In this example, the function test will be called twice: once with f jitted (i.e. using jax.jit) and another where f is not jitted.

Variants with_jit=True and with_pmap=True accept additional variant-specific arguments. Example:

@chex.variants(with_jit=True)
def test(self):
  @self.variant(static_argnums=(1,))
  def f(x, y):
    # `y` is not traced.
    return x + y

  self.assertEqual(f(1, 2), 3)

Variant with_pmap=True also accepts broadcast_args_to_devices (whether to broadcast each input argument to all participating devices), reduce_fn (a function to apply to results of pmapped fn), and n_devices (number of devices to use in the pmap computation). See the docstring of _with_pmap for more details (including default values).

If used with absl.testing.parameterized, @chex.variants must wrap it:

@chex.variants(with_jit=True, without_jit=True)
@parameterized.named_parameters('test', *args)
def test(self, *args):
  ...

Tests that use this wrapper must inherit from parameterized.TestCase. For more examples, see variants_test.py.

Parameters
  • test_method – A test method to decorate.

  • with_jit – Whether to test with jax.jit.

  • without_jit – Whether to test without jax.jit. Any jit compilation done within the test method will not be affected.

  • with_device – Whether to test with args placed on device, using jax.device_put.

  • without_device – Whether to test with args (explicitly) not placed on device, using jax.device_get.

  • with_pmap – Whether to test with jax.pmap, with computation duplicated across devices.

Returns

A decorated test_method.

chex.all_variants(test_method='__no__default__', with_jit: bool = True, without_jit: bool = True, with_device: bool = True, without_device: bool = True, with_pmap: bool = True) → VariantsTestCaseGenerator

Equivalent to chex.variants but with flipped defaults.

chex.params_product(*params_lists, named=False)

Generates a Cartesian product of params_lists.

See tests from variants_test.py for examples of usage.

Parameters
  • *params_lists – A list of params combinations.

  • named (bool) – Whether to generate test names (for absl.parameterized.named_parameters(…)).

Return type

Sequence[Sequence[Any]]

Returns

A Cartesian product of params_lists combinations.