Skip to content

Commit

Permalink
fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
arnaudon committed Sep 27, 2024
1 parent 0ec1c62 commit 9baf834
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 12 deletions.
4 changes: 3 additions & 1 deletion neurots/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def accept_reject(
"""
n_tries = 0
best_proposal = None
best_p = 0
best_p = -1.0
while n_tries < max_tries:
proposal = propose(n_tries * randomness_increase)
_prob = probability(proposal, **probability_kwargs)
Expand All @@ -130,9 +130,11 @@ def accept_reject(

if rng.binomial(1, _prob):
return proposal

if _prob > best_p:
best_p = _prob
best_proposal = proposal

n_tries += 1
warnings.warn("We could not sample from distribution, we take best sample.")
return best_proposal
6 changes: 2 additions & 4 deletions tests/test_orientations.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,12 +517,10 @@ def test_orientation_manager__apical_constraint():
om.compute_tree_type_orientations(tree_type)

with pytest.warns(UserWarning) as record:
tested._sample_trunk_from_3d_angle(
parameters, om._rng, "basal_dendrite", [0, 0, 1], max_tries=-1
)
om._sample_trunk_from_3d_angle("basal_dendrite", [0, 0, 1], max_tries=-1)
assert len(record.list) == 1
assert str(record.list[0].message).startswith(
"We could not sample from distribution, so we take a random point."
"We could not sample from distribution, we take best sample."
)

actual = om.get_tree_type_orientations("basal_dendrite")
Expand Down
11 changes: 4 additions & 7 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,20 +117,17 @@ def prob(proposal):
return 1.0
return 0.0

def default_propose():
return -1.0

# check we always return 1
for _ in range(10):
val = utils.accept_reject(propose, prob, rng, default_propose=default_propose)
val = utils.accept_reject(propose, prob, rng)
assert val == 1.0

def propose_null(_):
return 0.0

# check if we attain max_tries we return default_propose = -1
val = utils.accept_reject(propose_null, prob, rng, default_propose=default_propose)
assert val == -1.0
# check if we attain max_tries we return best
val = utils.accept_reject(propose_null, prob, rng)
assert val == 0.0

# check if we attain max_tries we return random
val = utils.accept_reject(propose_null, prob, rng)
Expand Down

0 comments on commit 9baf834

Please sign in to comment.