From 69be2424b1f71d93ae4d5402c1d18b2e0aa7df78 Mon Sep 17 00:00:00 2001 From: Alexis Arnaudon Date: Thu, 28 Sep 2023 15:19:26 +0200 Subject: [PATCH] Fix: Allow for normal orientation with basals (#90) --- neurots/generate/orientations.py | 45 ++++++++++++++++++++------------ tests/test_orientations.py | 6 ++--- 2 files changed, 31 insertions(+), 20 deletions(-) diff --git a/neurots/generate/orientations.py b/neurots/generate/orientations.py index 10727c54..458a7392 100644 --- a/neurots/generate/orientations.py +++ b/neurots/generate/orientations.py @@ -222,7 +222,7 @@ def _mode_uniform(self, _, tree_type): [sample.sample_spherical_unit_vectors(rng=self._rng) for _ in range(n_orientations)] ) - def _mode_normal_pia_constraint(self, values_dict, _): + def _mode_normal_pia_constraint(self, values_dict, tree_type): """Returns orientations using normal/exp distribution along a direction. The `direction` value should be a dict with two entries: `mean` and `std`. The mean is the @@ -234,24 +234,35 @@ def _mode_normal_pia_constraint(self, values_dict, _): Pia direction can be overwritten by the parameter 'pia_direction' value. """ - means = values_dict["direction"]["mean"] - means = means if isinstance(means, list) else [means] - stds = values_dict["direction"]["std"] - stds = stds if isinstance(stds, list) else [stds] - thetas = [] - for mean, std in zip(means, stds): - if mean == 0: - if std > 0: - thetas.append(np.clip(self._rng.exponential(std), 0, np.pi)) + n_orientations = sample.n_neurites(self._distributions[tree_type]["num_trees"], self._rng) + if ( + isinstance(values_dict["direction"]["mean"], list) + and len(values_dict["direction"]["mean"]) == n_orientations + ): + # to force the direction of possibly 2 apicals, otherwise it is for basals + n_orientations = 1 + + angles = [] + for _ in range(n_orientations): + means = values_dict["direction"]["mean"] + means = means if isinstance(means, list) else [means] + stds = values_dict["direction"]["std"] + stds = stds if isinstance(stds, list) else [stds] + thetas = [] + for mean, std in zip(means, stds): + if mean == 0: + if std > 0: + thetas.append(np.clip(self._rng.exponential(std), 0, np.pi)) + else: + thetas.append(0) else: - thetas.append(0) - else: - thetas.append(np.clip(self._rng.normal(mean, std), 0, np.pi)) + thetas.append(np.clip(self._rng.normal(mean, std), 0, np.pi)) - phis = self._rng.uniform(0, 2 * np.pi, len(means)) - return spherical_angles_to_pia_orientations( - phis, thetas, self._parameters.get("pia_direction", None) - ) + phis = self._rng.uniform(0, 2 * np.pi, len(means)) + angles += spherical_angles_to_pia_orientations( + phis, thetas, self._parameters.get("pia_direction", None) + ).tolist() + return np.array(angles) def _mode_pia_constraint(self, _, tree_type): """Create trunks from distribution of angles with pia (`[0 , 1, 0]`) direction. diff --git a/tests/test_orientations.py b/tests/test_orientations.py index 05af7f80..2c54471f 100644 --- a/tests/test_orientations.py +++ b/tests/test_orientations.py @@ -366,7 +366,7 @@ def test_orientation_manager__mode_normal_pia_constraint(): om.compute_tree_type_orientations(tree_type) actual = om.get_tree_type_orientations("apical_dendrite") - expected = np.array([[-0.0084249, 0.99768935, 0.06741643]]) + expected = np.array([[0.09842876357405606, 0.9948066102055255, 0.025914991658764108]]) npt.assert_allclose(actual, expected, rtol=1e-5) # make one along pia @@ -400,7 +400,7 @@ def test_orientation_manager__mode_normal_pia_constraint(): om.compute_tree_type_orientations(tree_type) actual = om.get_tree_type_orientations("apical_dendrite") - expected = np.array([[-0.10517952, 0.52968005, 0.84165095]]) + expected = np.array([[0.8067661158336028, 0.5513710782140335, 0.21241084824428402]]) npt.assert_allclose(actual, expected, rtol=1e-5) # test other pia direction @@ -417,7 +417,7 @@ def test_orientation_manager__mode_normal_pia_constraint(): om.compute_tree_type_orientations(tree_type) actual = om.get_tree_type_orientations("apical_dendrite") - expected = np.array([[0.52968, 0.10518, 0.841651]]) + expected = np.array([[0.5513710782140335, -0.8067661158336028, 0.21241084824428402]]) npt.assert_allclose(actual, expected, rtol=1e-5)