
Leaky behaviour on running repeated simulations #13

Open
ianosd opened this issue Nov 22, 2024 · 3 comments

Comments

@ianosd

ianosd commented Nov 22, 2024

First: great package! It was of great help in my work on my master's project.
I found what looks like a memory leak. Here's a minimal example:

import jax
import jax.numpy as np
import numpy.random 
import psutil
import pychastic

def main():
    initial_samples = numpy.random.random((10000, 3))
    def drift_fn(X):
        return -X

    def noise_fn(_):
        return np.eye(3)

    solver = pychastic.sde_solver.SDESolver(dt=0.01)
    problem = pychastic.sde_problem.SDEProblem(
        a=drift_fn, b=noise_fn,
        x0=initial_samples,
        tmax=10)

    for _ in range(10):
        solver.solve_many(problem, n_trajectories=None, progress_bar=True)
        mem_usage_now = psutil.Process().memory_info().rss / 1024 ** 2
        # jax.clear_caches()
        print(f'Mem usage: {mem_usage_now} MB')

if __name__ == "__main__":
    main()

The output of this is:

  0%|                                                                                           | 0/1000 [00:00<?, ?it/s]
Mem usage: 545.68359375 MB
  0%|                                                                                           | 0/1000 [00:00<?, ?it/s]
Mem usage: 234.484375 MB
  0%|                                                                                           | 0/1000 [00:00<?, ?it/s]
Mem usage: 233.8984375 MB
  0%|                                                                                           | 0/1000 [00:00<?, ?it/s]
Mem usage: 252.140625 MB
  0%|                                                                                           | 0/1000 [00:00<?, ?it/s]
Mem usage: 262.68359375 MB
  0%|                                                                                           | 0/1000 [00:00<?, ?it/s]
Mem usage: 262.11328125 MB
  0%|                                                                                           | 0/1000 [00:00<?, ?it/s]
Mem usage: 267.51171875 MB
  0%|                                                                                           | 0/1000 [00:00<?, ?it/s]
Mem usage: 275.12109375 MB
  0%|                                                                                           | 0/1000 [00:00<?, ?it/s]
Mem usage: 280.82421875 MB
  0%|                                                                                           | 0/1000 [00:00<?, ?it/s]
Mem usage: 277.37890625 MB

If you uncomment the jax.clear_caches() call, the output is:

  0%|                                                                                           | 0/1000 [00:01<?, ?it/s]
Mem usage: 544.79296875 MB
  0%|                                                                                           | 0/1000 [00:01<?, ?it/s]
Mem usage: 548.65234375 MB
  0%|                                                                                           | 0/1000 [00:01<?, ?it/s]
Mem usage: 551.78515625 MB
  0%|                                                                                           | 0/1000 [00:00<?, ?it/s]
Mem usage: 555.6796875 MB
  0%|                                                                                           | 0/1000 [00:00<?, ?it/s]
Mem usage: 557.37109375 MB
  0%|                                                                                           | 0/1000 [00:00<?, ?it/s]
Mem usage: 557.46875 MB
  0%|                                                                                           | 0/1000 [00:01<?, ?it/s]
Mem usage: 557.22265625 MB
  0%|                                                                                           | 0/1000 [00:01<?, ?it/s]
Mem usage: 558.96484375 MB
  0%|                                                                                           | 0/1000 [00:01<?, ?it/s]
Mem usage: 558.96875 MB
  0%|                                                                                           | 0/1000 [00:00<?, ?it/s]
Mem usage: 559.33984375 MB
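To put a number on the growth in the two runs above, here is a small stdlib-only sketch that fits a least-squares slope (MB per iteration) to the reported RSS readings. The readings are copied verbatim from the output above. Dropping the first reading of the first run is an editorial assumption: memory falls sharply after it, so it seems to reflect start-up allocation rather than the steady-state trend.

```python
# RSS readings (MB) copied from the two runs above.
no_clear = [545.68359375, 234.484375, 233.8984375, 252.140625, 262.68359375,
            262.11328125, 267.51171875, 275.12109375, 280.82421875, 277.37890625]
with_clear = [544.79296875, 548.65234375, 551.78515625, 555.6796875, 557.37109375,
              557.46875, 557.22265625, 558.96484375, 558.96875, 559.33984375]


def slope(ys):
    """Least-squares slope of ys against the iteration index 0, 1, 2, ..."""
    n = len(ys)
    x_mean = (n - 1) / 2
    y_mean = sum(ys) / n
    num = sum((i - x_mean) * (y - y_mean) for i, y in enumerate(ys))
    den = sum((i - x_mean) ** 2 for i in range(n))
    return num / den


# Assumption: the first reading of the first run is warm-up, so skip it.
print(f"without clear_caches: {slope(no_clear[1:]):+.2f} MB/iteration")
print(f"with clear_caches:    {slope(with_clear):+.2f} MB/iteration")
```

With these numbers the fitted slopes come out to roughly +6.05 and +1.48 MB per iteration. The second run's readings level off after about the fifth iteration, so its slope is driven mostly by the early readings.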

I should mention that I am running this on a CPU. The version of pychastic is 0.2.2.

I might look into this in the coming days, but perhaps someone here can immediately see what the issue is, or what I am doing wrong.

@RadostW
Owner

RadostW commented Nov 22, 2024

Thanks! That's great to hear :)
If you've managed to build something with the package and would be willing to share it, it'd be great to include it as an example.

Regarding the leak: I also ran into it recently. I think it was fine with earlier jax versions not long ago, but I'm not sure what causes it.

Take a look at this code: here every pychastic object is destroyed on each loop iteration, but memory keeps climbing nonetheless.

import jax
import jax.numpy as np
import numpy.random 
import psutil
import pychastic

def main():

    for _ in range(10):
    
        initial_samples = numpy.random.random((10000, 3))
        def drift_fn(X):
            return -X

        def noise_fn(_):
            return np.eye(3)
    
        solver = pychastic.sde_solver.SDESolver(dt=0.01)
        problem = pychastic.sde_problem.SDEProblem(
            a=drift_fn, b=noise_fn,
            x0=initial_samples,
            tmax=10)
    
        solver.solve_many(problem, n_trajectories=None, progress_bar=True)
        mem_usage_now = psutil.Process().memory_info().rss / 1024 ** 2

        # jax.clear_backends()
        
        del solver
        del problem
        del initial_samples
        del noise_fn
        del drift_fn
        
        print(f'Mem usage: {mem_usage_now} MB')

if __name__ == "__main__":
    main()
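One plausible contributor, not confirmed in this thread, is that JAX traces and compiles per function object: `jax.jit` keeps a cache keyed on the function it wraps, so a function that is redefined on every loop iteration (as `drift_fn` and `noise_fn` are above) is traced and compiled again each time. The minimal sketch below uses only jax, not pychastic, and counts traces through a Python side effect, which runs only while JAX traces. The names `make_drift` and `trace_count` are hypothetical, and whether pychastic jits these functions internally is an assumption here.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX is tracing a function


def make_drift():
    # Hypothetical helper: returns a brand-new function object on every call,
    # like redefining drift_fn inside the loop of the example above.
    def drift(x):
        global trace_count
        trace_count += 1  # Python side effect: executes at trace time only
        return -x
    return drift


x = jnp.ones((10, 3))

# Case 1: a fresh function (and jit wrapper) on every iteration.
for _ in range(5):
    jax.jit(make_drift())(x).block_until_ready()
fresh_traces = trace_count

# Case 2: one jitted function reused across iterations.
trace_count = 0
drift_jit = jax.jit(make_drift())
for _ in range(5):
    drift_jit(x).block_until_ready()
reused_traces = trace_count

print(f"fresh function each iteration: {fresh_traces} traces")
print(f"reused jitted function:        {reused_traces} traces")
```

The first loop traces five times and the second only once. `jax.clear_caches()`, used in the first example, discards JAX's compilation and staging caches; this thread does not establish whether those cache entries account for the reported memory growth.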
    

I will report this as a bug in jax and post an update here if I learn anything.

@RadostW
Owner

RadostW commented Nov 22, 2024

Related issue on jax:
jax-ml/jax#25069

@ianosd
Author

ianosd commented Nov 23, 2024

Yeah, it was fairly apparent that it is jax-related, also from how the behaviour changes when the caches are cleared.
Regarding sharing, I might have something, but I'll have to clean it up first :)
