Skip to content

Commit

Permalink
Fix: Fix AR algo (#112)
Browse files Browse the repository at this point in the history
* Fix: Fix AR algo

* fix

* better
  • Loading branch information
arnaudon authored Dec 2, 2024
1 parent f34dc8d commit 6071796
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 13 deletions.
27 changes: 22 additions & 5 deletions neurots/generate/orientations.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def _mode_uniform(self, _, tree_type):
[sample.sample_spherical_unit_vectors(rng=self._rng) for _ in range(n_orientations)]
)

def _mode_normal_pia_constraint(self, values_dict, tree_type):
def _mode_normal_pia_constraint(self, values_dict, tree_type, max_tries=100):
"""Returns orientations using normal/exp distribution along a direction.
The `direction` value should be a dict with two entries: `mean` and `std`. The mean is the
Expand All @@ -233,11 +233,13 @@ def _mode_normal_pia_constraint(self, values_dict, tree_type):
n_orientations = 1

angles = []
for _ in range(n_orientations):

def propose(_):
means = values_dict["direction"]["mean"]
means = means if isinstance(means, list) else [means]
stds = values_dict["direction"]["std"]
stds = stds if isinstance(stds, list) else [stds]

thetas = []
for mean, std in zip(means, stds):
if mean == 0:
Expand All @@ -249,9 +251,24 @@ def _mode_normal_pia_constraint(self, values_dict, tree_type):
thetas.append(np.clip(self._rng.normal(mean, std), 0, np.pi))

phis = self._rng.uniform(0, 2 * np.pi, len(means))
angles += spherical_angles_to_pia_orientations(
return spherical_angles_to_pia_orientations(
phis, thetas, self._context.get("y_rotation", None)
).tolist()
).tolist()[0]

if self._context is not None and self._context.get("constraints", []): # pragma: no cover

def prob(proposal):
p = 1.0
for constraint in self._context["constraints"]:
if "trunk_prob" in constraint:
p = min(p, constraint["trunk_prob"](proposal, self._soma.center))
return p

for _ in range(n_orientations):
angles.append(accept_reject(propose, prob, self._rng, max_tries=max_tries))
else:
for _ in range(n_orientations):
angles.append(propose(_))
return np.array(angles)

def _mode_pia_constraint(self, _, tree_type):
Expand Down Expand Up @@ -305,7 +322,7 @@ def prob(proposal):
if self._context.get("constraints", []): # pragma: no cover
for constraint in self._context["constraints"]:
if "trunk_prob" in constraint:
p *= constraint["trunk_prob"](proposal, self._soma.center)
p = min(p, constraint["trunk_prob"](proposal, self._soma.center))
return p

return accept_reject(propose, prob, self._rng, max_tries=max_tries)
Expand Down
8 changes: 3 additions & 5 deletions neurots/generate/section.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

# default parameters for accept/reject
DEFAULT_MAX_TRIES = 50
DEFAULT_RANDOMNESS_INCREASE = 0.5
DEFAULT_RANDOMNESS_INCREASE = 1.2


class SectionGrower:
Expand Down Expand Up @@ -94,13 +94,11 @@ def _propose(self, extra_randomness=0, add_random_component=True):
add_random_component (bool): add a random component to the direction
"""
direction = self.params.targeting * self.direction + self.params.history * self.history()

if add_random_component:
if add_random_component or extra_randomness > 0.0:
random_component = self.params.randomness * get_random_point(random_generator=self._rng)
if extra_randomness > 0: # pragma: no cover
random_component *= extra_randomness
direction += random_component

return direction / vectorial_norm(direction)

def next_point(self, add_random_component=True, extra_randomness=0):
Expand All @@ -117,7 +115,7 @@ def next_point(self, add_random_component=True, extra_randomness=0):
def prob(*args, **kwargs):
p = 1.0
for constraint in self.context["constraints"]:
p *= constraint["section_prob"](*args, **kwargs)
p = min(p, constraint["section_prob"](*args, **kwargs))
return p

max_tries = -1
Expand Down
5 changes: 2 additions & 3 deletions neurots/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def accept_reject(
probability,
rng,
max_tries=50,
randomness_increase=0.5,
randomness_increase=1.2,
**probability_kwargs,
):
"""Generic accept/reject algorithm.
Expand All @@ -122,7 +122,7 @@ def accept_reject(
best_proposal = None
best_p = -1.0
while n_tries < max_tries:
proposal = propose(n_tries * randomness_increase)
proposal = propose((1 + n_tries) * randomness_increase)
_prob = probability(proposal, **probability_kwargs)
if _prob == 1.0:
# this ensures we don't change rng for the tests, but its not really needed
Expand All @@ -134,7 +134,6 @@ def accept_reject(
if _prob > best_p:
best_p = _prob
best_proposal = proposal

n_tries += 1
warnings.warn("We could not sample from distribution, we take best sample.")
return best_proposal

0 comments on commit 6071796

Please sign in to comment.