Assertions
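- assert_axis_dimension: Checks that tensor.shape[axis] == expected.
- assert_axis_dimension_comparator: Asserts that pass_fn(tensor.shape[axis]) passes.
- assert_axis_dimension_gt: Checks that tensor.shape[axis] > val.
- assert_axis_dimension_gteq: Checks that tensor.shape[axis] >= val.
- assert_axis_dimension_lt: Checks that tensor.shape[axis] < val.
- assert_axis_dimension_lteq: Checks that tensor.shape[axis] <= val.
- assert_devices_available: Checks that n devices of a given type are available.
- assert_equal: Checks that the two objects are equal as determined by the == operator.
- assert_equal_rank: Checks that all arrays have the same rank.
- assert_equal_size: Checks that all arrays have the same size.
- assert_equal_shape: Checks that all arrays have the same shape.
- assert_equal_shape_prefix: Checks that the leading prefix_len dims of all inputs have the same shape.
- assert_equal_shape_suffix: Checks that the final suffix_len dims of all inputs have the same shape.
- assert_exactly_one_is_none: Checks that one and only one of the arguments is None.
- assert_gpu_available: Checks that at least one GPU device is available.
- assert_is_broadcastable: Checks that an array of shape shape_a is broadcastable to an array of shape shape_b.
- assert_is_divisible: Checks that numerator is divisible by denominator.
- assert_max_traces: Checks that a function is traced at most n times (inclusively).
- assert_not_both_none: Checks that at least one of the arguments is not None.
- assert_numerical_grads: Checks that autodiff and numerical gradients of a function match.
- assert_rank: Checks that the rank of all inputs matches specified expected_ranks.
- assert_scalar: Checks that x is a scalar, as defined in pytypes.py (int or float).
- assert_scalar_in: Checks that the argument is a scalar within a segment (borders included by default).
- assert_scalar_negative: Checks that a scalar is negative.
- assert_scalar_non_negative: Checks that a scalar is non-negative.
- assert_scalar_positive: Checks that a scalar is positive.
- assert_size: Checks that the size of all inputs matches specified expected_sizes.
- assert_shape: Checks that the shape of all inputs matches specified expected_shapes.
- assert_tpu_available: Checks that at least one TPU device is available.
- assert_tree_all_finite: Checks that all leaves in a tree are finite.
- assert_tree_has_only_ndarrays: Checks that all of a tree's leaves are n-dimensional arrays (tensors).
- assert_tree_is_on_device: Checks that all leaves are ndarrays residing in device memory (in HBM).
- assert_tree_is_on_host: Checks that all leaves are ndarrays residing in the host memory (on CPU).
- assert_tree_is_sharded: Checks that all leaves are ndarrays sharded across the specified devices.
- assert_tree_no_nones: Checks that a tree does not contain None.
- assert_tree_shape_prefix: Checks that all tree leaves' shapes have the same prefix.
- assert_tree_shape_suffix: Checks that all tree leaves' shapes have the same suffix.
- assert_trees_all_close: Checks that all trees have leaves with approximately equal values.
- assert_trees_all_close_ulp: Checks that tree leaves differ by at most maxulp Units in the Last Place.
- assert_trees_all_equal: Checks that all trees have leaves with exactly equal values.
- assert_trees_all_equal_comparator: Checks that all trees are equal as per the custom comparator for leaves.
- assert_trees_all_equal_dtypes: Checks that trees' leaves have the same dtype.
- assert_trees_all_equal_sizes: Checks that trees have the same structure and leaves' sizes.
- assert_trees_all_equal_shapes: Checks that trees have the same structure and leaves' shapes.
- assert_trees_all_equal_shapes_and_dtypes: Checks that trees' leaves have the same shape and dtype.
- assert_trees_all_equal_structs: Checks that trees have the same structure.
- assert_type: Checks that the type of all inputs matches specified expected_types.
- chexify: Wraps a transformed function fn to enable Chex value assertions.
- ChexifyChecks: A set of checks imported from checkify.
- with_jittable_assertions: An alias for chexify.
- block_until_chexify_assertions_complete: Waits until all asynchronous checks complete.
- Dimensions: A lightweight utility that maps strings to shape tuples.
- disable_asserts: Disables all Chex assertions.
- enable_asserts: Enables Chex assertions.
- clear_trace_counter: Clears Chex traces' counter for assert_max_traces checks.
- if_args_not_none: Wraps a chex assertion so that it is only evaluated if the positional args are not None.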
JAX Assertions
- chex.assert_max_traces(fn: Callable[[...], Any] | int | None = None, n: Callable[[...], Any] | int | None = None)
Checks that a function is traced at most n times (inclusively).
JAX re-traces jitted functions every time the structure of passed arguments changes. Often this behaviour is inadvertent and leads to a significant performance drop which is hard to debug. This wrapper checks that the function is re-traced at most n times during program execution.
Examples:
```python
import chex
import jax


@jax.jit
@chex.assert_max_traces(n=1)
def fn_sum_jitted(x, y):
    return x + y


def fn_sub(x, y):
    return x - y


fn_sub_pmapped = jax.pmap(chex.assert_max_traces(fn_sub, n=10))
```
- More about tracing:
https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html
- Parameters:
fn – A pure python function to wrap (i.e. it must not be a jitted function).
n – The maximum allowed number of retraces (non-negative).
- Returns:
Decorated function that raises an exception when it is re-traced for the (n+1)-th time.
- Raises:
ValueError – If fn has already been jitted.
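The counting idea can be illustrated without JAX. Under jax.jit, the Python body of a function runs while JAX traces it and is skipped on cached calls, so counting body executions approximates counting traces. The sketch below is a hypothetical plain-Python decorator, max_body_runs, that raises AssertionError on the (n+1)-th run; it is an illustration under that assumption, not chex's implementation, and without jax.jit every call counts as a run.

```python
import functools


def max_body_runs(n):
    """Hypothetical decorator: fails once the wrapped body runs more than n times."""

    def decorator(fn):
        runs = 0

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            nonlocal runs
            runs += 1
            if runs > n:
                raise AssertionError(
                    f"{fn.__name__} ran {runs} times; at most {n} allowed.")
            return fn(*args, **kwargs)

        return wrapper

    return decorator


@max_body_runs(n=1)
def add(x, y):
    return x + y


print(add(1, 2))  # First run is allowed: prints 3.
try:
    add(3, 4)  # Second run exceeds n=1.
except AssertionError as err:
    print("AssertionError:", err)
```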
- chex.assert_devices_available(n: int, devtype: str, backend: str | None = None, not_less_than: bool = False) → None
Checks that n devices of a given type are available.
- Parameters:
n – A required number of devices of the given type.
devtype – A type of devices, one of {'cpu', 'gpu', 'tpu'}.
backend – A type of backend to use (uses the JAX default if not provided).
not_less_than – Whether to check if the number of devices is not less than n, instead of precise comparison.
- Raises:
AssertionError – If the number of available devices of the given type is not equal to n or, when not_less_than is True, is less than n.
Value (Runtime) Assertions
- chex.chexify(fn: Callable[..., Any], async_check: bool = True, errors: FrozenSet[checkify.ErrorCategory] = frozenset({<class 'jax._src.checkify.FailedCheckError'>})) → Callable[..., Any]
Wraps a transformed function fn to enable Chex value assertions.
Chex value/runtime assertions access concrete values of tensors (e.g. assert_tree_all_finite) which are not available during JAX tracing, see https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html and https://jax.readthedocs.io/en/latest/_modules/jax/_src/errors.html#ConcretizationTypeError.
This wrapper enables them in jitted/pmapped functions by performing a specifically designed JAX transformation (https://jax.readthedocs.io/en/latest/debugging/checkify_guide.html#the-checkify-transformation) and calling functionalised checks (https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.checkify.check.html).
Example:
```python
import chex
import jax
import jax.numpy as jnp


@chex.chexify
@jax.jit
def logp1_abs_safe(x: chex.Array) -> chex.Array:
    chex.assert_tree_all_finite(x)
    return jnp.log(jnp.abs(x) + 1)


logp1_abs_safe(jnp.ones(2))  # OK
logp1_abs_safe(jnp.array([jnp.nan, 3]))  # FAILS
logp1_abs_safe.wait_checks()
```
Note 1: This wrapper allows identifying the first failed assertion in jitted code by printing a pointer to the line where the failed assertion was invoked. For verbose messages (including concrete tensor values), an unjitted version of the code needs to be executed with the same input values. Chex does not currently provide tools to help with this.
Note 2: This wrapper fully supports asynchronous executions (see https://jax.readthedocs.io/en/latest/async_dispatch.html). To block program execution until asynchronous checks for a _chexified_ function fn complete, call fn.wait_checks(). Similarly, chex.block_until_chexify_assertions_complete() will block program execution until _all_ asynchronous checks complete.
Note 3: Chex automatically selects the backend for executing its assertions (i.e. CPU or device accelerator) depending on the program context.
Note 4: Value assertions can impact the performance of a function; see https://jax.readthedocs.io/en/latest/debugging/checkify_guide.html#limitations.
Note 5: Static assertions, such as assert_shape or assert_trees_all_equal_dtypes, can be called from a jitted function without the chexify wrapper, since they do not access concrete values, only shapes and/or dtypes, which are available during JAX tracing.
More examples can be found in the deepmind/chex repository.
- Parameters:
fn – A transformed function to wrap.
async_check – Whether to check errors in the async dispatch mode. See https://jax.readthedocs.io/en/latest/async_dispatch.html.
errors – A set of checkify.ErrorCategory values which defines the set of enabled checks. By default, only explicit checks are enabled (user). You can also, for example, enable NaN and division-by-zero errors by passing the float set, or combine multiple sets through set operations (float | user).
- Returns:
A _chexified_ function, i.e. one with enabled value assertions. The returned function has a wait_checks() method that blocks the caller until all pending async checks complete.
- class chex.ChexifyChecks
A set of checks imported from checkify.
Tree Assertions
- chex.assert_tree_all_finite(tree_like: Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]) → None
Checks that all leaves in a tree are finite.
- Parameters:
tree_like – A pytree with array leaves.
- Raises:
AssertionError – If any leaf in tree_like is non-finite.
- chex.assert_tree_has_only_ndarrays(tree: Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]) → None
Checks that all tree’s leaves are n-dimensional arrays (tensors).
- Parameters:
tree – A tree to assert.
- Raises:
AssertionError – If the tree contains an object which is not an ndarray.
- chex.assert_tree_is_on_device(tree: Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], *, platform: Sequence[str] | str = ('gpu', 'tpu'), device: Device | None = None) → None
Checks that all leaves are ndarrays residing in device memory (in HBM).
Sharded DeviceArrays are disallowed.
- Parameters:
tree – A tree to assert.
platform – A platform or a list of platforms where the leaves are expected to reside. Ignored if device is specified.
device – An optional device where the tree’s arrays are expected to reside. Any device (except CPU) is accepted if not specified.
- Raises:
AssertionError – If the tree contains a leaf that is not an ndarray or does not reside on the specified device or platform.
- chex.assert_tree_is_on_host(tree: Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], *, allow_cpu_device: bool = True, allow_sharded_arrays: bool = False) → None
Checks that all leaves are ndarrays residing in the host memory (on CPU).
This assertion only accepts trees consisting of ndarrays.
- Parameters:
tree – A tree to assert.
allow_cpu_device – Whether to allow JAX arrays that reside on a CPU device.
allow_sharded_arrays – Whether to allow sharded JAX arrays. Sharded arrays are considered “on host” only if they are sharded across CPU devices and allow_cpu_device is True.
- Raises:
AssertionError – If the tree contains a leaf that is not an ndarray or does not reside on host.
- chex.assert_tree_is_sharded(tree: Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], *, devices: Sequence[Device]) → None
Checks that all leaves are ndarrays sharded across the specified devices.
- Parameters:
tree – A tree to assert.
devices – A list of devices which the tree’s leaves are expected to be sharded across. This list is order-sensitive.
- Raises:
AssertionError – If the tree contains a leaf that is not a device array sharded across the specified devices.
- chex.assert_tree_no_nones(tree: Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]) → None
Checks that a tree does not contain None.
- Parameters:
tree – A tree to assert.
- Raises:
AssertionError – If the tree contains at least one None.
- chex.assert_tree_shape_prefix(tree: Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], shape_prefix: Sequence[int]) → None
Checks that all tree leaves’ shapes have the same prefix.
- Parameters:
tree – A tree to check.
shape_prefix – An expected shape prefix.
- Raises:
AssertionError – If some leaf’s shape doesn’t start with shape_prefix.
- chex.assert_tree_shape_suffix(tree: Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], shape_suffix: Sequence[int]) → None
Checks that all tree leaves’ shapes have the same suffix.
- Parameters:
tree – A tree to check.
shape_suffix – An expected shape suffix.
- Raises:
AssertionError – If some leaf’s shape doesn’t end with shape_suffix.
- chex.assert_trees_all_close(*trees: Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], rtol: float = 1e-06, atol: float = 0.0) → None
Checks that all trees have leaves with approximately equal values.
This compares the difference between the actual and desired values up to atol + rtol * abs(desired).
- Parameters:
*trees – A sequence of (at least 2) trees with array leaves.
rtol – A relative tolerance.
atol – An absolute tolerance.
- Raises:
AssertionError – If actual and desired values are not equal up to specified tolerance.
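The tolerance rule can be checked by hand. The plain-Python sketch below applies atol + rtol * abs(desired) elementwise to flat lists of floats; leaves_close is a hypothetical helper that only mirrors the documented formula and default tolerances, not chex's tree handling. Note the asymmetry: the relative term scales with abs(desired), not abs(actual).

```python
def leaves_close(actual, desired, rtol=1e-6, atol=0.0):
    """Hypothetical elementwise check: |actual - desired| <= atol + rtol * |desired|."""
    if len(actual) != len(desired):
        return False
    return all(
        abs(a - d) <= atol + rtol * abs(d) for a, d in zip(actual, desired))


print(leaves_close([100.0], [100.00005]))          # True: 5e-5 <= ~1e-4
print(leaves_close([0.0], [1e-9]))                 # False: allowed error is ~1e-15
print(leaves_close([0.0], [1e-9], atol=1e-8))      # True once atol is set
```

The second call shows why the default atol=0.0 matters: near zero the relative term vanishes, so values close to 0 must match almost exactly unless an absolute tolerance is given.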
- chex.assert_trees_all_close_ulp(*trees: Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], maxulp: int = 1) → None
Checks that tree leaves differ by at most maxulp Units in the Last Place.
This is the Chex version of np.testing.assert_array_max_ulp.
Assertions on floating point values are tricky because the precision varies depending on the value. For example, with float32, the precision at 1 is np.spacing(np.float32(1.0)) ≈ 1e-7, but the precision at 5,000,000 is only np.spacing(np.float32(5e6)) = 0.5. This makes it hard to predict ahead of time what tolerance to use when checking whether two numbers are equal: a difference of only a couple of bits can equate to an arbitrarily large absolute difference.
Assertions based on _relative_ differences are one solution to this problem, but have the disadvantage that it’s hard to choose the tolerance. If you want to verify that two calculations produce _exactly_ the same result modulo the inherent non-determinism of floating point operations, do you set the tolerance to…0.01? 0.001? It’s hard to be sure you’ve set it low enough that you won’t miss one of your computations being slightly wrong.
Assertions based on ‘units in the last place’ (ULP) instead solve this problem by letting you specify tolerances in terms of the precision actually available at the current scale of your values. The ULP at some value x is essentially the spacing between the floating point numbers actually representable in the vicinity of x, equivalent to the ‘precision’ we discussed above. With a tolerance of, say, maxulp=5, you’re saying that two values are within 5 actually-representable numbers of each other: a strong guarantee that two computations are as close as possible to identical, while still allowing reasonable wiggle room for small differences due to e.g. different operator orderings.
Note that this function is not currently supported within JIT contexts, and does not currently support bfloat16 dtypes.
- Parameters:
*trees – A sequence of (at least 2) trees with array leaves.
maxulp – The maximum number of ULPs by which leaves may differ.
- Raises:
AssertionError – If actual and desired values are not equal up to specified tolerance.
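To make the ULP idea concrete, the plain-Python sketch below counts how many float32 values separate two numbers by reinterpreting their bits as integers, a standard trick for sign-magnitude floats. It is not chex's implementation; it assumes finite inputs, rounds Python floats to float32 via struct, and the names float32_ordinal, ulp_distance32 and within_maxulp are hypothetical. The two printed distances reproduce the spacings quoted above for float32 at 1 and at 5,000,000.

```python
import struct


def float32_ordinal(x: float) -> int:
    """Maps x (rounded to float32) to an integer that increases with x.

    Adjacent float32 values map to adjacent integers, so differences between
    ordinals count representable float32 values.
    """
    (bits,) = struct.unpack("<i", struct.pack("<f", x))
    # float32 uses sign-magnitude encoding: for negative numbers, negate the
    # magnitude bits so the mapping stays monotone and -0.0 maps to 0.
    return bits if bits >= 0 else -(bits & 0x7FFFFFFF)


def ulp_distance32(a: float, b: float) -> int:
    return abs(float32_ordinal(a) - float32_ordinal(b))


def within_maxulp(a: float, b: float, maxulp: int = 1) -> bool:
    return ulp_distance32(a, b) <= maxulp


print(ulp_distance32(1.0, 1.0 + 2**-23))  # 1: float32 spacing at 1.0 is 2**-23
print(ulp_distance32(5e6, 5e6 + 0.5))     # 1: float32 spacing at 5e6 is 0.5
print(within_maxulp(5e6, 5e6 + 2.0))      # False: the values are 4 ULPs apart
```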
- chex.assert_trees_all_equal(*trees: Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], strict: bool = False) → None
Checks that all trees have leaves with exactly equal values.
If you are comparing floating point numbers, an exact equality check may not be appropriate; consider using assert_trees_all_close.
- Parameters:
*trees – A sequence of (at least 2) trees with array leaves.
strict – If True, disable special scalar handling as described in the notes section of np.testing.assert_array_equal.
- Raises:
AssertionError – If the leaf values actual and desired are not exactly equal.
- chex.assert_trees_all_equal_comparator(equality_comparator: Callable[[Any, Any], bool], error_msg_fn: Callable[[Any, Any], str], *trees: Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]) → None
Checks that all trees are equal as per the custom comparator for leaves.
- Parameters:
equality_comparator – A custom function that accepts two leaves and checks whether they are equal. Expected to be transitive.
error_msg_fn – A function accepting two leaves that are unequal as per equality_comparator and returning an error message.
*trees – A sequence of (at least 2) trees to check for equality as per equality_comparator.
- Raises:
ValueError – If trees does not contain at least 2 elements.
AssertionError – If equality_comparator returns False for any pair of trees from trees.
- chex.assert_trees_all_equal_dtypes(*trees: Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]) → None
Checks that trees’ leaves have the same dtype.
- Parameters:
*trees – A sequence of (at least 2) trees to check.
- Raises:
AssertionError – If leaves’ dtypes for any two trees differ.
- chex.assert_trees_all_equal_sizes(*trees: Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]) → None
Checks that trees have the same structure and leaves’ sizes.
- Parameters:
*trees – A sequence of (at least 2) trees with array leaves.
- Raises:
AssertionError – If trees’ structures or leaves’ sizes are different.
- chex.assert_trees_all_equal_shapes(*trees: Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]) → None
Checks that trees have the same structure and leaves’ shapes.
- Parameters:
*trees – A sequence of (at least 2) trees with array leaves.
- Raises:
AssertionError – If trees’ structures or leaves’ shapes are different.
- chex.assert_trees_all_equal_shapes_and_dtypes(*trees: Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]) → None
Checks that trees’ leaves have the same shape and dtype.
- Parameters:
*trees – A sequence of (at least 2) trees to check.
- Raises:
AssertionError – If leaves’ shapes or dtypes for any two trees differ.
- chex.assert_trees_all_equal_structs(*trees: Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]) → None
Checks that trees have the same structure.
- Parameters:
*trees – A sequence of (at least 2) trees to assert equal structure between.
- Raises:
ValueError – If trees does not contain at least 2 elements.
AssertionError – If structures of any two trees are different.
Generic Assertions
- chex.assert_axis_dimension(tensor: Array | ndarray | bool | number, axis: int, expected: int) → None
Checks that tensor.shape[axis] == expected.
- Parameters:
tensor – A JAX array.
axis – An integer specifying which axis to assert.
expected – An expected value of tensor.shape[axis].
- Raises:
AssertionError – If the dimension of the specified axis does not match the prescribed value.
- chex.assert_axis_dimension_comparator(tensor: Array | ndarray | bool | number, axis: int, pass_fn: Callable[[int], bool], error_string: str)
Asserts that pass_fn(tensor.shape[axis]) passes.
Used to implement ==, >, >=, <, <= checks.
- Parameters:
tensor – A JAX array.
axis – An integer specifying which axis to assert.
pass_fn – A callable which takes the size of the given dimension and returns False when the assertion should fail.
error_string – A string which is inserted in assertion failure messages: ‘expected tensor to have dimension {error_string} on axis …’.
- Raises:
AssertionError – If pass_fn(tensor.shape[axis]) does not return True.
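Since the comparator is the shared building block, one way to picture the specialised checks is as thin wrappers that supply pass_fn and error_string. The plain-Python sketch below works on a shape tuple instead of an array; check_axis_dimension, check_axis_dimension_op, the comparison table and the message text are all hypothetical, not chex's source.

```python
import operator

# Hypothetical table covering the ==, >, >=, <, <= family of checks.
_COMPARISONS = {
    "equal to": operator.eq,
    "greater than": operator.gt,
    "greater than or equal to": operator.ge,
    "less than": operator.lt,
    "less than or equal to": operator.le,
}


def check_axis_dimension(shape, axis, pass_fn, error_string):
    """Raises AssertionError unless pass_fn(shape[axis]) is truthy."""
    if not pass_fn(shape[axis]):
        raise AssertionError(
            f"expected dimension {error_string} on axis {axis}, "
            f"got {shape[axis]}")


def check_axis_dimension_op(shape, axis, op_name, val):
    """Builds one specific check on top of the generic comparator."""
    op = _COMPARISONS[op_name]
    check_axis_dimension(shape, axis, lambda size: op(size, val),
                         f"{op_name} {val}")


shape = (8, 3)
check_axis_dimension_op(shape, 0, "greater than", 4)  # Passes: 8 > 4.
try:
    check_axis_dimension_op(shape, 1, "greater than", 3)  # 3 > 3 is False.
except AssertionError as err:
    print(err)  # expected dimension greater than 3 on axis 1, got 3
```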
- chex.assert_axis_dimension_gt(tensor: Array | ndarray | bool | number, axis: int, val: int) → None
Checks that tensor.shape[axis] > val.
- Parameters:
tensor – A JAX array.
axis – An integer specifying which axis to assert.
val – A value tensor.shape[axis] must be greater than.
- Raises:
AssertionError – If the dimension of axis is <= val.
- chex.assert_axis_dimension_gteq(tensor: Array | ndarray | bool | number, axis: int, val: int) → None
Checks that tensor.shape[axis] >= val.
- Parameters:
tensor – A JAX array.
axis – An integer specifying which axis to assert.
val – A value tensor.shape[axis] must be greater than or equal to.
- Raises:
AssertionError – If the dimension of axis is < val.
- chex.assert_axis_dimension_lt(tensor: Array | ndarray | bool | number, axis: int, val: int) → None
Checks that tensor.shape[axis] < val.
- Parameters:
tensor – A JAX array.
axis – An integer specifying which axis to assert.
val – A value tensor.shape[axis] must be less than.
- Raises:
AssertionError – If the dimension of axis is >= val.
- chex.assert_axis_dimension_lteq(tensor: Array | ndarray | bool | number, axis: int, val: int) → None
Checks that tensor.shape[axis] <= val.
- Parameters:
tensor – A JAX array.
axis – An integer specifying which axis to assert.
val – A value tensor.shape[axis] must be less than or equal to.
- Raises:
AssertionError – If the dimension of axis is > val.
- chex.assert_equal(first: Any, second: Any) → None
Checks that the two objects are equal as determined by the == operator.
Arrays with more than one element cannot be compared. Use assert_trees_all_close to compare arrays.
- Parameters:
first – A first object.
second – A second object.
- Raises:
AssertionError – If not (first == second).
- chex.assert_equal_rank(inputs: Sequence[Array | ndarray | bool | number]) → None
Checks that all arrays have the same rank.
- Parameters:
inputs – A collection of arrays.
- Raises:
AssertionError – If the ranks of all arrays do not match.
ValueError – If inputs is not a collection of arrays.
- chex.assert_equal_size(inputs: Sequence[Array | ndarray | bool | number]) → None
Checks that all arrays have the same size.
- Parameters:
inputs – A collection of arrays.
- Raises:
AssertionError – If the sizes of all arrays do not match.
- chex.assert_equal_shape(inputs: Sequence[Array | ndarray | bool | number], *, dims: int | Sequence[int] | None = None) → None
Checks that all arrays have the same shape.
- Parameters:
inputs – A collection of arrays.
dims – An optional integer or sequence of integers. If not provided, every dimension of every shape must match. If provided, equality of shape will only be asserted for the specified dim(s); e.g. to ensure all of a group of arrays have the same size in the first two dimensions, call assert_equal_shape(tensors_list, dims=(0, 1)).
- Raises:
AssertionError – If the shapes of all arrays at the specified dims do not match.
ValueError – If the provided dims are invalid indices into any of the arrays, or if inputs is not a collection of arrays.
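The dims semantics can be sketched on plain shape tuples. shapes_match below is a hypothetical helper, not part of chex: it compares every dimension when dims is None and only the listed dimensions otherwise. An out-of-range index raises IndexError in this sketch, whereas the documentation above says chex raises ValueError.

```python
from typing import Sequence, Tuple, Union


def shapes_match(shapes: Sequence[Tuple[int, ...]],
                 dims: Union[int, Sequence[int], None] = None) -> bool:
    """Hypothetical check mirroring the documented `dims` semantics."""
    first = shapes[0]
    if dims is None:
        # Every dimension of every shape must match (ranks included).
        return all(shape == first for shape in shapes)
    if isinstance(dims, int):
        dims = (dims,)
    # Only the selected dimensions are compared.
    return all(shape[d] == first[d] for shape in shapes for d in dims)


print(shapes_match([(4, 5, 6), (4, 5, 7)]))               # False: last dims differ
print(shapes_match([(4, 5, 6), (4, 5, 7)], dims=(0, 1)))  # True
print(shapes_match([(4, 5, 6), (9, 5, 7)], dims=1))       # True
```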
- chex.assert_equal_shape_prefix(inputs: Sequence[Array | ndarray | bool | number], prefix_len: int) → None
Checks that the leading prefix_len dims of all inputs have the same shape.
- Parameters:
inputs – A collection of input arrays.
prefix_len – A number of leading dimensions to compare; each input’s shape will be sliced to shape[:prefix_len]. Negative values are accepted and have the conventional Python indexing semantics.
- Raises:
AssertionError – If the shapes of all arrays do not match.
ValueError – If inputs is not a collection of arrays.
- chex.assert_equal_shape_suffix(inputs: Sequence[Array | ndarray | bool | number], suffix_len: int) → None
Checks that the final suffix_len dims of all inputs have the same shape.
- Parameters:
inputs – A collection of input arrays.
suffix_len – A number of trailing dimensions to compare; each input’s shape will be sliced to shape[-suffix_len:]. Negative values are accepted and have the conventional Python indexing semantics.
- Raises:
AssertionError – If the shapes of all arrays do not match.
ValueError – If inputs is not a collection of arrays.
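The slices quoted for prefix_len and suffix_len determine what negative values mean. The plain-Python sketch below applies exactly those slices to shape tuples; shape_prefix, shape_suffix and same_prefix are hypothetical helpers for illustration, not chex functions.

```python
def shape_prefix(shape, prefix_len):
    # The slice documented for assert_equal_shape_prefix.
    return tuple(shape)[:prefix_len]


def shape_suffix(shape, suffix_len):
    # The slice documented for assert_equal_shape_suffix.
    return tuple(shape)[-suffix_len:]


def same_prefix(shapes, prefix_len):
    # True when every shape has the same leading dimensions.
    return len({shape_prefix(s, prefix_len) for s in shapes}) == 1


shape = (2, 3, 5, 7)
print(shape_prefix(shape, 2))   # (2, 3)
print(shape_prefix(shape, -1))  # (2, 3, 5): all but the last dimension
print(shape_suffix(shape, 2))   # (5, 7)
print(shape_suffix(shape, -1))  # (3, 5, 7): -(-1) == 1, so shape[1:]
print(same_prefix([(2, 3, 5), (2, 3, 9)], 2))  # True
```

One consequence of the literal slice: because -0 == 0, shape[-0:] is the whole shape rather than an empty tuple.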
- chex.assert_exactly_one_is_none(first: Any, second: Any) → None
Checks that one and only one of the arguments is None.
- Parameters:
first – A first object.
second – A second object.
- Raises:
AssertionError – If (first is None) xor (second is None) is False.
- chex.assert_is_broadcastable(shape_a: Sequence[int], shape_b: Sequence[int]) → None
Checks that an array of shape shape_a is broadcastable to an array of shape shape_b.
- Parameters:
shape_a – A shape of the array to check.
shape_b – A target shape after broadcasting.
- Raises:
AssertionError – If shape_a is not broadcastable to shape_b.
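Because shape_b is described as the target shape after broadcasting, the check is one-directional. The plain-Python sketch below applies NumPy's broadcasting rule in that direction: right-align the shapes, and each dimension of shape_a must equal the matching dimension of shape_b or be 1, with shape_a having no more dimensions than shape_b. is_broadcastable is a hypothetical helper, not chex's code.

```python
def is_broadcastable(shape_a, shape_b):
    """Hypothetical check: can an array of shape_a be broadcast to shape_b?"""
    if len(shape_a) > len(shape_b):
        return False  # Broadcasting can add leading dimensions, never remove them.
    # zip over reversed shapes pairs trailing dimensions first.
    return all(a == b or a == 1
               for a, b in zip(reversed(shape_a), reversed(shape_b)))


print(is_broadcastable((5, 1), (3, 5, 7)))     # True
print(is_broadcastable((3, 5), (3, 5, 7)))     # False: 5 vs 7
print(is_broadcastable((3, 5, 7), (3, 1, 7)))  # False: 5 cannot shrink to 1
```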
- chex.assert_is_divisible(numerator: int, denominator: int) → None
Checks that numerator is divisible by denominator.
- Parameters:
numerator – A numerator.
denominator – A denominator.
- Raises:
AssertionError – If numerator is not divisible by denominator.
- chex.assert_not_both_none(first: Any, second: Any) → None
Checks that at least one of the arguments is not None.
- Parameters:
first – A first object.
second – A second object.
- Raises:
AssertionError – If (first is None) and (second is None).
- chex.assert_numerical_grads(f: Callable[[...], Array | ndarray | bool | number], f_args: Sequence[Array | ndarray | bool | number], order: int, atol: float = 0.01, **check_kwargs) → None
Checks that autodiff and numerical gradients of a function match.
- Parameters:
f – A function to check.
f_args – Arguments of the function.
order – An order of gradients.
atol – An absolute tolerance.
**check_kwargs – Keyword arguments passed to jax_test.check_grads (i.e. jax.test_util.check_grads).
- Raises:
AssertionError – If automatic differentiation gradients deviate from finite difference gradients.
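Conceptually, the check compares derivatives obtained by automatic differentiation against finite-difference estimates. The plain-Python sketch below shows that comparison for the first-order derivative of a scalar function, with a hand-written analytic derivative standing in for autodiff. central_difference and grads_match are hypothetical helpers and the step eps=1e-4 is an assumption; chex itself delegates to JAX's check_grads, which also supports higher orders.

```python
import math


def central_difference(f, x, eps=1e-4):
    # Numerical derivative of a scalar function f at x; error is O(eps**2).
    return (f(x + eps) - f(x - eps)) / (2 * eps)


def grads_match(f, df, xs, atol=0.01):
    """Hypothetical first-order check: analytic df vs finite differences of f."""
    return all(abs(df(x) - central_difference(f, x)) <= atol for x in xs)


print(grads_match(math.sin, math.cos, [0.0, 0.5, 1.0]))       # True
print(grads_match(math.sin, lambda x: -math.cos(x), [0.5]))  # False: wrong sign
```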
- chex.assert_rank(inputs: float | int | Array | ndarray | bool | number | Sequence[Array | ndarray | bool | number], expected_ranks: int | Set[int] | Sequence[int | Set[int]]) → None
Checks that the rank of all inputs matches specified expected_ranks.
Valid usages include:
```python
assert_rank(x, 0)  # x is scalar
assert_rank(x, 2)  # x is a rank-2 array
assert_rank(x, {0, 2})  # x is scalar or rank-2 array
assert_rank([x, y], 2)  # x and y are rank-2 arrays
assert_rank([x, y], [0, 2])  # x is scalar and y is a rank-2 array
assert_rank([x, y], {0, 2})  # x and y are scalar or rank-2 arrays
```
- Parameters:
inputs – An array or a sequence of arrays.
expected_ranks – A sequence of expected ranks associated with each input, where the expected rank is either an integer or set of integer options; if all inputs have the same rank, a single scalar or set of scalars may be passed as expected_ranks.
- Raises:
AssertionError – If lengths of inputs and expected_ranks don’t match; if expected_ranks has wrong type; if the ranks of inputs do not match expected_ranks.
ValueError – If expected_ranks is not an integer and not a sequence of integers.
- chex.assert_scalar(x: float | int) → None
Checks that x is a scalar, as defined in pytypes.py (int or float).
- Parameters:
x – An object to check.
- Raises:
AssertionError – If x is not a scalar as per definition in pytypes.py.
- chex.assert_scalar_in(x: Any, min_: float | int, max_: float | int, included: bool = True) → None
Checks that the argument is a scalar within the segment [min_, max_] (borders included by default).
- Parameters:
x – An object to check.
min_ – A left border of the segment.
max_ – A right border of the segment.
included – Whether to include the borders of the segment in the set of allowed values.
- Raises:
AssertionError – If x is not a scalar, or if x falls outside the segment.
- chex.assert_scalar_negative(x: float | int) → None
Checks that a scalar is negative.
- Parameters:
x – A value to check.
- Raises:
AssertionError – If x is not a scalar or is not strictly negative.
- chex.assert_scalar_non_negative(x: float | int) → None
Checks that a scalar is non-negative.
- Parameters:
x – A value to check.
- Raises:
AssertionError – If x is not a scalar or is negative.
- chex.assert_scalar_positive(x: float | int) → None
Checks that a scalar is positive.
- Parameters:
x – A value to check.
- Raises:
AssertionError – If x is not a scalar or is not strictly positive.
- chex.assert_size(inputs: float | int | Array | ndarray | bool | number | Sequence[Array | ndarray | bool | number], expected_sizes: Sequence[int | Set[int] | ellipsis | None] | Sequence[Sequence[int | Set[int] | ellipsis | None]]) → None
Checks that the size of all inputs matches specified expected_sizes.
Valid usages include:
```python
assert_size(x, 1)  # x is scalar (size 1)
assert_size([x, y], (2, {1, 3}))  # x has size 2, y has size 1 or 3
assert_size([x, y], (2, ...))  # x has size 2, y has any size
assert_size([x, y], 1)  # x and y are scalar (size 1)
assert_size((x, y), (5, 2))  # x has size 5, y has size 2
```
- Parameters:
inputs – An array or a sequence of arrays.
expected_sizes – A sequence of expected sizes associated with each input, where each expected size may be an integer, a set of integer options, or ... (any size, as in the examples above); if all inputs have the same size, a single size may be passed as expected_sizes.
- Raises:
AssertionError – If the lengths of inputs and expected_sizes do not match; if expected_sizes has wrong type; if the size of an input does not match expected_sizes.
- chex.assert_shape(inputs: float | int | Array | ndarray | bool | number | Sequence[Array | ndarray | bool | number], expected_shapes: Sequence[int | Set[int] | ellipsis | None] | Sequence[Sequence[int | Set[int] | ellipsis | None]]) → None
Checks that the shape of all inputs matches specified expected_shapes.
Valid usages include:
```python
assert_shape(x, ())  # x is scalar
assert_shape(x, (2, 3))  # x has shape (2, 3)
assert_shape(x, (2, {1, 3}))  # x has shape (2, 1) or (2, 3)
assert_shape(x, (2, None))  # x has rank 2 and `x.shape[0] == 2`
assert_shape(x, (2, ...))  # x has rank >= 1 and `x.shape[0] == 2`
assert_shape([x, y], ())  # x and y are scalar
assert_shape([x, y], [(), (2, 3)])  # x is scalar and y has shape (2, 3)
```
- Parameters:
inputs – An array or a sequence of arrays.
expected_shapes – A sequence of expected shapes associated with each input, where each expected shape is a sequence of dimensions (integers, sets of integer options, None, or ...); if all inputs have the same shape, a single shape may be passed as expected_shapes.
- Raises:
AssertionError – If the lengths of inputs and expected_shapes do not match; if expected_shapes has wrong type; if the shape of an input does not match expected_shapes.
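The expected-shape notation can be read as a small pattern language. The plain-Python sketch below matches one concrete shape tuple against one pattern: an int must match exactly, a set lists the allowed sizes, None accepts any size, and a single ... stands for zero or more dimensions, as in the (2, ...) example. shape_matches is a hypothetical helper that supports at most one ...; it is not chex's implementation.

```python
def shape_matches(shape, expected):
    """Hypothetical matcher for int / set / None / ... shape patterns."""
    shape, expected = tuple(shape), tuple(expected)
    if Ellipsis in expected:
        # Split around the ellipsis: match the head against the leading dims
        # and the tail against the trailing dims; ... absorbs the rest.
        i = expected.index(Ellipsis)
        head, tail = expected[:i], expected[i + 1:]
        if len(shape) < len(head) + len(tail):
            return False
        return (shape_matches(shape[:len(head)], head)
                and shape_matches(shape[len(shape) - len(tail):], tail))
    if len(shape) != len(expected):
        return False  # Without ..., the rank must match exactly.
    for dim, want in zip(shape, expected):
        if want is None:
            continue  # Wildcard dimension.
        if isinstance(want, set):
            if dim not in want:
                return False
        elif dim != want:
            return False
    return True


print(shape_matches((2, 3), (2, {1, 3})))  # True
print(shape_matches((2,), (2, ...)))       # True: ... matches zero dims
print(shape_matches((3, 4), (2, ...)))     # False: first dim must be 2
```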
- chex.assert_type(inputs: float | int | Array | ndarray | bool | number | Sequence[Array | ndarray | bool | number], expected_types: str | type[Any] | dtype | SupportsDType | Sequence[str | type[Any] | dtype | SupportsDType]) → None
Checks that the type of all inputs matches specified expected_types.
If the expected type is a Python type or abstract dtype (e.g. np.floating), assert that the input has the same sub-type. If the expected type is a concrete dtype (e.g. np.float32), assert that the input’s type is the same.
Example usage:
```python
assert_type(7, int)
assert_type(7.1, float)
assert_type(False, bool)
assert_type([7, 8], int)
assert_type([7, 7.1], [int, float])
assert_type(np.array(7), int)
assert_type(np.array(7.1), float)
assert_type([jnp.array([7, 8]), np.array(7.1)], [int, float])
assert_type(jnp.array(1., dtype=jnp.bfloat16), jnp.bfloat16)
assert_type(jnp.ones(1, dtype=np.int8), np.int8)
```
- Parameters:
inputs – An array or a sequence of arrays or scalars.
expected_types – A sequence of expected types associated with each input; if all inputs have the same type, a single type may be passed as expected_types.
- Raises:
AssertionError – If lengths of inputs and expected_types don’t match; if expected_types contains an unsupported pytype; if the types of inputs do not match the expected types.
Shapes and Named Dimensions
- class chex.Dimensions(**dim_sizes)
A lightweight utility that maps strings to shape tuples.
The most basic usage is:
```python
>>> dims = chex.Dimensions(B=3, T=5, N=7)  # You can specify any letters.
>>> dims['NBT']
(7, 3, 5)
```
This is useful when dealing with many differently shaped arrays. For instance, let’s check the shape of this array:
```python
>>> x = jnp.array([[2, 0, 5, 6, 3],
...                [5, 4, 4, 3, 3],
...                [0, 0, 5, 2, 0]])
>>> chex.assert_shape(x, dims['BT'])
```
The dimension sizes can be accessed directly, e.g. dims.N == 7. This can be useful in many applications. For instance, let’s one-hot encode our array:
```python
>>> y = jax.nn.one_hot(x, dims.N)
>>> chex.assert_shape(y, dims['BTN'])
```
You can also store the shape of a given array in dims, e.g.:
```python
>>> z = jnp.array([[0, 6, 0, 2],
...                [4, 2, 2, 4]])
>>> dims['XY'] = z.shape
>>> dims
Dimensions(B=3, N=7, T=5, X=2, Y=4)
```
You can access the flat size of a shape as:
```python
>>> dims.size('BT')  # Same as prod(dims['BT']).
15
```
Similarly, you can flatten axes together by wrapping them in parentheses:
```python
>>> dims['(BT)N']
(15, 7)
```
You can set a wildcard dimension, cf. chex.assert_shape():
```python
>>> dims.W = None
>>> dims['BTW']
(3, 5, None)
```
Or you can use the wildcard character ‘*’ directly:
```python
>>> dims['BT*']
(3, 5, None)
```
Single digits are interpreted as literal integers. Note that this notation is limited to single-digit literals.
```python
>>> dims['BT123']
(3, 5, 1, 2, 3)
```
Support for single digits was mainly included to accommodate dummy axes introduced for consistent broadcasting. For instance, instead of using jnp.expand_dims you could do the following:
```python
>>> w = y * x  # Cannot broadcast (3, 5, 7) with (3, 5)
Traceback (most recent call last):
...
ValueError: Incompatible shapes for broadcasting: ((3, 5, 7), (1, 3, 5))
>>> w = y * x.reshape(dims['BT1'])
>>> chex.assert_shape(w, dims['BTN'])
```
Sometimes you only care about some array dimensions but not all. You can use an underscore to ignore an axis, e.g.
>>> chex.assert_rank(y, 3)
>>> dims['__M'] = y.shape  # Skip the first two axes.
Finally, note that a single-character key returns a tuple of length one.
>>> dims['M']
(7,)
Utils
Warnings
- chex.create_deprecated_function_alias(fun, new_name, deprecated_alias)
Create a deprecated alias for a function.
Example usage:
>>> g = create_deprecated_function_alias(f, 'path.f', 'path.g')
- Parameters:
fun – the deprecated function.
new_name – the new name to use (you may include the path for clarity).
deprecated_alias – the old name (you may include the path for clarity).
- Returns:
the wrapped function.
- chex.warn_deprecated_function(fun: Callable[[...], Any], replacement: str | None = None) -> Callable[[...], Any]
A decorator to mark a function definition as deprecated.
Example usage:
>>> @functools.partial(chex.warn_deprecated_function, replacement='g')
... def f(a, b):
...   return a + b
- Parameters:
fun – the deprecated function.
replacement – name of the function to be used instead.
- Returns:
the wrapped function.
- chex.warn_keyword_args_only_in_future(fun, *, n=0)
Warns if more than n positional arguments are passed to fun.
For instance:
>>> @functools.partial(chex.warn_only_n_pos_args_in_future, n=1)
... def f(a, b, c=1):
...   return a + b + c
This emits a DeprecationWarning if f is called with more than one positional argument (e.g. both f(1, 2, 3) and f(1, 2, c=3) trigger a warning).
- Parameters:
fun – the function to wrap.
n – the number of positional arguments to allow.
- Returns:
A wrapped function that emits a warning if more than n positional arguments are passed.
- chex.warn_only_n_pos_args_in_future(fun, n)
Warns if more than n positional arguments are passed to fun.
For instance:
>>> @functools.partial(chex.warn_only_n_pos_args_in_future, n=1)
... def f(a, b, c=1):
...   return a + b + c
This emits a DeprecationWarning if f is called with more than one positional argument (e.g. both f(1, 2, 3) and f(1, 2, c=3) trigger a warning).
- Parameters:
fun – the function to wrap.
n – the number of positional arguments to allow.
- Returns:
A wrapped function that emits a warning if more than n positional arguments are passed.
Backend restriction
- chex.restrict_backends(*, allowed: Sequence[str] | None = None, forbidden: Sequence[str] | None = None)
Disallows JAX compilation for certain backends.
- Parameters:
allowed – Names of backend platforms (e.g. ‘cpu’ or ‘tpu’) for which compilation is still to be permitted.
forbidden – Names of backend platforms for which compilation is to be forbidden.
- Yields:
None, in a context where compilation for forbidden platforms will raise a RestrictedBackendError.
- Raises:
ValueError – if neither allowed nor forbidden is specified (i.e. they are both None), or if anything is both allowed and forbidden.
Dataclasses
- chex.dataclass(cls=None, *, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, kw_only: bool = False, mappable_dataclass=True)
JAX-friendly wrapper for dataclasses.dataclass().
This wrapper registers new dataclasses with JAX so that tree utilities operate correctly. Additionally, a replace method is provided, making it easy to operate on the class when it is made immutable (frozen=True).
- Parameters:
cls – A class to decorate.
init – See dataclasses.dataclass().
repr – See dataclasses.dataclass().
eq – See dataclasses.dataclass().
order – See dataclasses.dataclass().
unsafe_hash – See dataclasses.dataclass().
frozen – See dataclasses.dataclass().
kw_only – See dataclasses.dataclass().
mappable_dataclass – If True (the default), methods to make the class implement the collections.abc.Mapping interface will be generated, and the class will include collections.abc.Mapping in its base classes. True is the default because being an instance of Mapping makes chex.dataclass compatible with, e.g., jax.tree_util.tree_* methods, the tree library, or methods related to tensorflow/python/utils/nest.py. As a side effect, e.g. np.testing.assert_array_equal will only check that the field names are equal, not the content. Use chex.assert_tree_* instead.
- Returns:
A JAX-friendly dataclass.
- chex.mappable_dataclass(cls)
Exposes a dataclass as a collections.abc.Mapping descendant.
This allows dataclasses to be traversed by methods from the dm-tree library.
NOTE: this changes the dataclass constructor to a dict-style constructor (i.e. positional arguments are not supported; however, generators/iterables can be used).
- Parameters:
cls – A dataclass to mutate.
- Returns:
Mutated dataclass implementing the collections.abc.Mapping interface.
- chex.register_dataclass_type_with_jax_tree_util(data_class)
Register an existing dataclass so JAX knows how to handle it.
This means that functions in jax.tree_util operate over the fields of the dataclass. See https://jax.readthedocs.io/en/latest/pytrees.html#extending-pytrees for further information.
- Parameters:
data_class – A class created using dataclasses.dataclass. It must be constructible from keyword arguments corresponding to the members exposed in instance.__dict__.
Fakes
| chex.fake_jit | Context manager for patching jax.jit with the identity function. |
| chex.fake_pmap | Context manager for patching jax.pmap with jax.vmap. |
| chex.fake_pmap_and_jit | Context manager for patching jax.jit and jax.pmap. |
| chex.set_n_cpu_devices | Forces XLA to use n CPU threads as host devices. |
Transformations
- chex.fake_jit(enable_patching: bool = True) -> FakeContext
Context manager for patching jax.jit with the identity function.
This is intended to be used as a debugging tool to programmatically enable or disable JIT compilation.
It can be used either as a context-managed scope:
with chex.fake_jit():
  @jax.jit
  def foo(x):
    ...
or by calling start and stop:
fake_jit_context = chex.fake_jit()
fake_jit_context.start()
@jax.jit
def foo(x):
  ...
fake_jit_context.stop()
- Parameters:
enable_patching – Whether to patch jax.jit.
- Returns:
Context where jax.jit is patched with the identity function and JAX is configured to avoid jitting internally whenever possible in functions such as jax.lax.scan, etc.
- chex.fake_pmap(enable_patching: bool = True, jit_result: bool = False, ignore_axis_index_groups: bool = False, fake_parallel_axis: bool = False) -> FakeContext
Context manager for patching jax.pmap with jax.vmap.
This is intended to be used as a debugging tool to programmatically replace pmap transformations with a non-parallel vmap transformation.
It can be used either as a context-managed scope:
with chex.fake_pmap():
  @jax.pmap
  def foo(x):
    ...
or by calling start and stop:
fake_pmap_context = chex.fake_pmap()
fake_pmap_context.start()
@jax.pmap
def foo(x):
  ...
fake_pmap_context.stop()
- Parameters:
enable_patching – Whether to patch jax.pmap.
jit_result – Whether the transformed function should be jitted despite not being pmapped.
ignore_axis_index_groups – Whether to force any parallel operation within the context to set axis_index_groups to None. This is a compatibility option that allows users of the axis_index_groups parameter to run under the fake_pmap context. This feature is not currently supported in vmap and would fail, so we force the parameter to be None. Warning: this will produce different results from running under jax.pmap.
fake_parallel_axis – Fake a parallel axis.
- Returns:
Context where jax.pmap is patched with jax.vmap.
- chex.fake_pmap_and_jit(enable_pmap_patching: bool = True, enable_jit_patching: bool = True) -> FakeContext
Context manager for patching jax.jit and jax.pmap.
This is a convenience function, equivalent to nested chex.fake_pmap and chex.fake_jit contexts.
Note that calling (the true implementation of) jax.pmap will compile the function, so faking jax.jit in this case will not stop the function from being compiled.
- Parameters:
enable_pmap_patching – Whether to patch jax.pmap.
enable_jit_patching – Whether to patch jax.jit.
- Returns:
Context where jax.pmap and jax.jit are patched with jax.vmap and the identity function, respectively.
Devices
- chex.set_n_cpu_devices(n: int | None = None) -> None
Forces XLA to use n CPU threads as host devices.
This allows jax.pmap to be tested on a single-CPU platform. This utility only takes effect before XLA backends are initialized, i.e. before any JAX operation is executed (including jax.devices() etc.). See google/jax#1408.
- Parameters:
n – The required number of CPU devices (FLAGS.chex_n_cpu_devices is used by default).
- Raises:
RuntimeError – If XLA backends were already initialized.
Pytypes
|
|
|
alias of |
|
alias of |
|
alias of |
|
alias of |
|
alias of |
|
alias of |
|
alias of |
|
alias of |
|
A descriptor of an available device. |
|
|
|
alias of |
|
|
|
alias of |
|
alias of |
Variants
- class chex.ChexVariantType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
An enumeration of available Chex variants.
Use self.variant.type to get the type of the current test variant. See the docstring of chex.variants for more information.
- class chex.TestCase(*args, **kwargs)
A class for Chex tests that use variants.
See the docstring for chex.variants for more information.
Note: chex.variants returns a generator producing one test per variant. Therefore, the test class used must support dynamic unrolling of these generators during module import. This is implemented (and battle-tested) in absl.parameterized.TestCase, and here we subclass from it.
- chex.variants(test_method='__no__default__', with_jit: bool = False, without_jit: bool = False, with_device: bool = False, without_device: bool = False, with_pmap: bool = False) -> VariantsTestCaseGenerator
Decorates a test to expose Chex variants.
The decorated test has access to a decorator called self.variant, which may be applied to functions to test different JAX behaviors. Consider:
@chex.variants(with_jit=True, without_jit=True)
def test(self):
  @self.variant
  def f(x, y):
    return x + y

  self.assertEqual(f(1, 2), 3)
In this example, the function test will be called twice: once with f jitted (i.e. using jax.jit) and once where f is not jitted.
Variants with_jit=True and with_pmap=True accept additional arguments specific to them. Example:
@chex.variants(with_jit=True)
def test(self):
  @self.variant(static_argnums=(1,))
  def f(x, y):  # `y` is not traced.
    return x + y

  self.assertEqual(f(1, 2), 3)
Variant with_pmap=True also accepts broadcast_args_to_devices (whether to broadcast each input argument to all participating devices), reduce_fn (a function to apply to results of pmapped fn), and n_devices (number of devices to use in the pmap computation). See the docstring of _with_pmap for more details (including default values).
If used with absl.testing.parameterized, @chex.variants must wrap it:
@chex.variants(with_jit=True, without_jit=True)
@parameterized.named_parameters('test', *args)
def test(self, *args):
  ...
Tests that use this wrapper must inherit from parameterized.TestCase. For more examples, see variants_test.py.
- Parameters:
test_method – A test method to decorate.
with_jit – Whether to test with jax.jit.
without_jit – Whether to test without jax.jit. Any jit compilation done within the test method will not be affected.
with_device – Whether to test with args placed on device, using jax.device_put.
without_device – Whether to test with args (explicitly) not placed on device, using jax.device_get.
with_pmap – Whether to test with jax.pmap, with computation duplicated across devices.
- Returns:
A decorated test_method.
- chex.all_variants(test_method='__no__default__', with_jit: bool = True, without_jit: bool = True, with_device: bool = True, without_device: bool = True, with_pmap: bool = True) -> VariantsTestCaseGenerator
Equivalent to chex.variants but with flipped defaults.
- chex.params_product(*params_lists: Sequence[Sequence[Any]], named: bool = False) -> Sequence[Sequence[Any]]
Generates a Cartesian product of params_lists.
See tests from variants_test.py for examples of usage.
- Parameters:
*params_lists – A list of params combinations.
named – Whether to generate test names (for absl.parameterized.named_parameters(…)).
- Returns:
A Cartesian product of params_lists combinations.