Possible leak in random number generation #25069

Open
RadostW opened this issue Nov 22, 2024 · 0 comments
Labels
bug Something isn't working

RadostW commented Nov 22, 2024

Description

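When using the JAX-based package pychastic (an SDE solver), the JAX backend keeps consuming more memory indefinitely: in the reproduction below, process memory (RSS) grows by roughly 3 MB on every call to solve_many.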

import jax
import jax.numpy as jnp
import numpy.random
import psutil
import pychastic


def main():
    for _ in range(100):
        # Fresh random initial conditions, shape (10000, 3).
        initial_samples = numpy.random.random((10000, 3))

        # Drift and noise functions are redefined on every iteration.
        def drift_fn(X):
            return -X

        def noise_fn(_):
            return jnp.eye(3)

        solver = pychastic.sde_solver.SDESolver(dt=0.01)
        problem = pychastic.sde_problem.SDEProblem(
            a=drift_fn,
            b=noise_fn,
            x0=initial_samples,
            tmax=10,
        )

        solver.solve_many(problem, n_trajectories=None, progress_bar=True)

        # Resident set size of this process, converted from bytes.
        mem_usage_now = psutil.Process().memory_info().rss / 1024 ** 2

        # jax.clear_backends()

        # Drop every reference created in this iteration.
        del solver
        del problem
        del initial_samples
        del noise_fn
        del drift_fn

        print(f'Mem usage: {mem_usage_now} MB')


if __name__ == "__main__":
    main()

Output (abbreviated)

100%|████████████████████████████████| 1000/1000 [00:01<00:00, 716.20it/s]
Mem usage: 161.609375 MB
100%|████████████████████████████████| 1000/1000 [00:01<00:00, 869.53it/s]
Mem usage: 165.42578125 MB
100%|████████████████████████████████| 1000/1000 [00:01<00:00, 893.27it/s]
Mem usage: 168.7578125 MB
100%|████████████████████████████████| 1000/1000 [00:01<00:00, 885.05it/s]
Mem usage: 171.8359375 MB
100%|████████████████████████████████| 1000/1000 [00:01<00:00, 898.58it/s]
Mem usage: 175.03125 MB
100%|████████████████████████████████| 1000/1000 [00:01<00:00, 864.00it/s]
Mem usage: 178.3359375 MB
100%|████████████████████████████████| 1000/1000 [00:01<00:00, 863.09it/s]
Mem usage: 181.51171875 MB
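
For reference, the growth rate implied by the output above can be worked out directly. This is only arithmetic on the seven printed values; nothing is re-measured, and the variable names are made up for the example.

```python
# RSS values (in MB) copied from the abbreviated output above.
rss_mb = [
    161.609375,
    165.42578125,
    168.7578125,
    171.8359375,
    175.03125,
    178.3359375,
    181.51171875,
]

# Growth between consecutive iterations.
deltas = [later - earlier for earlier, later in zip(rss_mb, rss_mb[1:])]
mean_growth = sum(deltas) / len(deltas)

print(
    f"per-iteration growth: min {min(deltas):.2f} MB, "
    f"max {max(deltas):.2f} MB, mean {mean_growth:.2f} MB"
)
```

The increments stay close to the mean instead of tapering off, which is what suggests a steady leak rather than a one-time warm-up cost.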

I apologize for the contrived code to reproduce the issue.
I'd be happy to chase the leak further, but I'm unfamiliar with any tools that could help diagnose the issue. Is there some way to see what's taking up all this space?

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.3.25
jaxlib: 0.3.25
numpy:  1.23.5
python: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]

Also tested, with the same result, on:

jax:    0.4.30
jaxlib: 0.4.30
numpy:  2.1.3
python: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
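
A smaller check that does not depend on pychastic may help narrow this down. The reproduction defines drift_fn and noise_fn anew on each iteration; it is an assumption here, not something confirmed above, that pychastic passes them through jax.jit. If it does, every new function object would be traced and compiled again, because jit caches compiled code per function object. The sketch below only demonstrates that retracing behaviour in plain JAX. `make_step`, `trace_count`, and the other names are hypothetical, and the sketch does not show that retained compilations are what fills the memory.

```python
import jax
import jax.numpy as jnp

trace_count = 0


def make_step():
    """Return a freshly defined, jitted function, similar to
    redefining drift_fn inside the loop (an assumption, see above)."""
    def step(x):
        global trace_count
        trace_count += 1  # Python side effect: runs only while JAX traces `step`
        return -x
    return jax.jit(step)


x = jnp.ones(3)

# A new function object on every iteration.
for _ in range(3):
    make_step()(x)
fresh_traces = trace_count

# One function object reused across iterations.
trace_count = 0
shared_step = make_step()
for _ in range(3):
    shared_step(x)
shared_traces = trace_count

print(fresh_traces, shared_traces)
```

If the first count equals the number of iterations while the second is 1, then defining drift_fn and noise_fn once, outside the loop, would be a quick way to see whether the growth in the reproduction changes.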