diff --git a/chex/__init__.py b/chex/__init__.py index 5fffa61..1221c4c 100644 --- a/chex/__init__.py +++ b/chex/__init__.py @@ -43,10 +43,6 @@ from chex._src.asserts import assert_shape from chex._src.asserts import assert_size from chex._src.asserts import assert_tpu_available -from chex._src.asserts import assert_tree_all_close # Deprecated -from chex._src.asserts import assert_tree_all_equal_comparator # Deprecated -from chex._src.asserts import assert_tree_all_equal_shapes # Deprecated -from chex._src.asserts import assert_tree_all_equal_structs # Deprecated from chex._src.asserts import assert_tree_all_finite from chex._src.asserts import assert_tree_has_only_ndarrays from chex._src.asserts import assert_tree_is_on_device @@ -147,10 +143,6 @@ "assert_shape", "assert_size", "assert_tpu_available", - "assert_tree_all_close", # Deprecated - "assert_tree_all_equal_comparator", # Deprecated - "assert_tree_all_equal_shapes", # Deprecated - "assert_tree_all_equal_structs", # Deprecated "assert_tree_all_finite", "assert_tree_has_only_ndarrays", "assert_tree_is_on_device", diff --git a/chex/_src/asserts.py b/chex/_src/asserts.py index 600bbfe..4f1e86d 100644 --- a/chex/_src/asserts.py +++ b/chex/_src/asserts.py @@ -1257,6 +1257,9 @@ def assert_tree_shape_prefix(tree: ArrayTree, Args: tree: A tree to check. shape_prefix: An expected shape prefix. + + Raises: + AssertionError: If some leaf's shape doesn't start with ``shape_prefix``. """ # To compare with the leaf's `shape`, convert int sequence to tuple. shape_prefix = tuple(shape_prefix) @@ -1358,12 +1361,6 @@ def assert_trees_all_equal_structs(*trees: ArrayTree) -> None: f"\n tree {i}: {treedef}") -assert_tree_all_equal_structs = _ai.deprecation_wrapper( - assert_trees_all_equal_structs, - old_name="assert_tree_all_equal_structs", - new_name="assert_trees_all_equal_structs") - - @_static_assertion def assert_trees_all_equal_comparator(equality_comparator: _ai.TLeavesEqCmpFn, error_msg_fn: _ai.TLeavesEqCmpErrorFn, @@ -1410,12 +1407,6 @@ def tree_error_msg_fn(l_1: _ai.TLeaf, l_2: _ai.TLeaf, path: str, i_1: int, cmp_fn(path, *[leaves[leaf_i] for leaves in trees_leaves]) -assert_tree_all_equal_comparator = _ai.deprecation_wrapper( - assert_trees_all_equal_comparator, - old_name="assert_tree_all_equal_comparator", - new_name="assert_trees_all_equal_comparator") - - @_static_assertion def assert_trees_all_equal_dtypes(*trees: ArrayTree) -> None: """Checks that trees' leaves have the same dtype. @@ -1470,12 +1461,6 @@ def assert_trees_all_equal_shapes(*trees: ArrayTree) -> None: assert_trees_all_equal_comparator(cmp_fn, err_msg_fn, *trees) -assert_tree_all_equal_shapes = _ai.deprecation_wrapper( - assert_trees_all_equal_shapes, - old_name="assert_tree_all_equal_shapes", - new_name="assert_trees_all_equal_shapes") - - @_static_assertion def assert_trees_all_equal_shapes_and_dtypes(*trees: ArrayTree) -> None: """Checks that trees' leaves have the same shape and dtype. @@ -1665,11 +1650,6 @@ def _assert_trees_all_close_jittable(*trees: ArrayTree, jittable_assert_fn=_assert_trees_all_close_jittable, name="assert_trees_all_close") -assert_tree_all_close = _ai.deprecation_wrapper( - assert_trees_all_close, - old_name="assert_tree_all_close", - new_name="assert_trees_all_close") - def _assert_trees_all_close_ulp_static( *trees: ArrayTree, diff --git a/chex/_src/dataclass_test.py b/chex/_src/dataclass_test.py index 938add5..cfed87e 100644 --- a/chex/_src/dataclass_test.py +++ b/chex/_src/dataclass_test.py @@ -616,7 +616,7 @@ def _is_leaf(value) -> bool: (dcls.str_val, dcls.inner_dcls, dcls.dct['md1'], dcls.dct['md2']), leaves) - asserts.assert_tree_all_equal_structs( + asserts.assert_trees_all_equal_structs( jax.tree_util.tree_map(lambda x: x, dcls, is_leaf=_is_leaf), dcls) def test_decorator_alias(self): @@ -649,7 +649,7 @@ class GenericDataclass(Generic[T]): a: T # pytype: disable=invalid-annotation # enable-bare-annotations obj = GenericDataclass(a=np.array([1.0, 1.0])) - asserts.assert_tree_all_close(obj.a, 1.0) + asserts.assert_trees_all_close(obj.a, 1.0) def test_mappable_eq_override(self):