Leaky behaviour on running repeated simulations #13
Comments
Thanks! That's great to hear :) Regarding the leak - I also encountered it recently. I think it was ok not so long ago with previous jax versions. I'm not sure what causes it. Take a look at this code - here everything

import jax
import jax.numpy as np
import numpy.random
import psutil
import pychastic

def main():
    for _ in range(10):
        initial_samples = numpy.random.random((10000, 3))

        def drift_fn(X):
            return -X

        def noise_fn(_):
            return np.eye(3)

        solver = pychastic.sde_solver.SDESolver(dt=0.01)
        problem = pychastic.sde_problem.SDEProblem(
            a=drift_fn, b=noise_fn,
            x0=initial_samples,
            tmax=10)
        solver.solve_many(problem, n_trajectories=None, progress_bar=True)
        mem_usage_now = psutil.Process().memory_info().rss / 1024 ** 2
        # jax.clear_backends()
        del solver
        del problem
        del initial_samples
        del noise_fn
        del drift_fn
        print(f'Mem usage: {mem_usage_now} MB')

if __name__ == "__main__":
    main()
I will report this as a bug in
Related issue on jax:
Yeah, it was quite apparent that it is jax-related, also from how it behaves when the caches are cleared.
First: great package! It was of great help in my work on my master's project.
I found what looks like a memory leak. Here's a minimal example:
The output of this is:
If you uncomment the jax.clear_caches() call, the output is:
Maybe I should mention that I am running this on a CPU. The version of pychastic is 0.2.2.
I might look into this in the coming days, but maybe someone around here immediately sees what the issue could be, or what I am doing wrong.
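For anyone who wants to compare the two cases (with and without clearing caches) side by side, here is a minimal sketch of a measurement helper. The names `rss_mb`, `track_memory` and `growth_per_iter` are made up for this example and are not part of pychastic or jax. The memory reader is a parameter, so the default psutil-based one can be swapped out.

```python
import gc
from typing import Callable, List, Optional


def rss_mb() -> float:
    """Resident set size of the current process in MB (requires psutil)."""
    import psutil  # imported lazily; only needed for the default reader
    return psutil.Process().memory_info().rss / 1024 ** 2


def track_memory(
    step: Callable[[], object],
    n_iters: int = 10,
    cleanup: Optional[Callable[[], object]] = None,
    read_mb: Callable[[], float] = rss_mb,
) -> List[float]:
    """Run `step` n_iters times and return the memory reading after each run.

    If `cleanup` is given, it is called after every run, before the reading.
    """
    readings = []
    for _ in range(n_iters):
        step()
        if cleanup is not None:
            cleanup()
        gc.collect()  # drop unreachable Python objects before measuring
        readings.append(read_mb())
    return readings


def growth_per_iter(readings: List[float]) -> float:
    """Average change between consecutive readings, in MB per iteration."""
    if len(readings) < 2:
        return 0.0
    return (readings[-1] - readings[0]) / (len(readings) - 1)
```

To use it on the example from this thread, one would wrap a single build-and-solve iteration in a function, pass it as `step`, and run it once with `cleanup=None` and once with `cleanup=jax.clear_caches` (in jax versions that provide it). A `growth_per_iter` value that stays clearly positive only in the first run would point at the caches, although RSS readings are noisy and this is only a rough indicator.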