Skip to content

Commit

Permalink
cov
Browse files Browse the repository at this point in the history
  • Loading branch information
arnaudon committed Nov 19, 2024
1 parent 8d09a92 commit 28d041b
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
6 changes: 3 additions & 3 deletions neurots/generate/section.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,9 @@ def _propose(self, extra_randomness=0, add_random_component=True):

if add_random_component:
random_component = self.params.randomness * get_random_point(random_generator=self._rng)
# this is needed only to get 100% reproducibility
if self.context.get("y_rotation") is not None:
random_component = self.context["y_rotation"].dot(random_component)
# this is needed only to get 100% invariance under y_direction
# if self.context.get("y_rotation") is not None:
# random_component = self.context["y_rotation"].dot(random_component)
if extra_randomness > 0: # pragma: no cover
random_component *= extra_randomness
direction += random_component
Expand Down
15 changes: 15 additions & 0 deletions tests/test_generate_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,21 @@ def test_TreeGrower():
sections[len(sections) - num - 1],
)

grower = NeuronGrower(
input_distributions=distributions,
input_parameters=params,
context={"y_direction": [1, 0, 0]},
)

assert grower.context["y_direction"] == [1, 0, 0]
npt.assert_array_equal(
grower.context["y_rotation"], np.array([[0.0, 1.0, 0.0], [-1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
)
grower = NeuronGrower(
input_distributions=distributions, input_parameters=params, context="OTHER"
)
assert grower.context == "OTHER"


def test_TreeGrower_termination_length():
np.random.seed(0)
Expand Down

0 comments on commit 28d041b

Please sign in to comment.