Skip to content

Commit

Permalink
Render the earth a bit smaller than the sun
Browse files Browse the repository at this point in the history
  • Loading branch information
mstoelzle committed Mar 8, 2024
1 parent 403809c commit d43f050
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 14 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ if __name__ == "__main__":
x_min, x_max = -1.5 * jnp.ones((1,)), 1.5 * jnp.ones((1,))

# simulation settings
duration = 6.3259
duration = 6.3259 # duration of the simulation [s]
ts = jnp.linspace(0.0, duration, 1001)
dt = ts[-1] * 1e-4
dt = 2e-4 # time step [s]

# solve the ODE
ode_term = ODETerm(ode_fn)
Expand Down
Binary file modified examples/outputs/two_body.mp4
Binary file not shown.
20 changes: 14 additions & 6 deletions examples/simulate_two_body_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from nbodyx.rendering.opencv import animate_n_body, render_n_body

if __name__ == "__main__":
ode_fn = make_n_body_ode(jnp.array([M_sun, M_earth]))
body_masses = jnp.array([M_sun, M_earth])
ode_fn = make_n_body_ode(body_masses)

# initial conditions for earth
x_earth = jnp.array([-AU, 0.0])
Expand All @@ -21,20 +22,26 @@
# initial conditions for sun
x_sun = jnp.array([0.0, 0.0])
v_sun = jnp.array([0.0, 0.0])
# initial positions and velocities
x0, v0 = jnp.concatenate([x_sun, x_earth]), jnp.concatenate([v_sun, v0_earth])
# initial state
y0 = jnp.concatenate([x_sun, x_earth, v_sun, v0_earth])
y0 = jnp.concatenate([x0, v0])
print("y0", y0)

# state bounds
x_min, x_max = -2 * AU * jnp.ones((1,)), 2 * AU * jnp.ones((1,))
# external torques
tau = jnp.zeros((4,))

# animation settings
img_size = (500, 500)
body_radii = 0.05 * min(img_size) * jnp.array([1.0, 0.5])

# evaluate the ODE at the initial state
y_d0 = jit(ode_fn)(0.0, y0, tau)
print("y_d0", y_d0)
# render the image at the initial state
img = render_n_body(jnp.array([x_sun, x_earth]), 500, 500, x_min, x_max)
img = render_n_body(x0, img_size[0], img_size[1], x_min, x_max, body_radii=body_radii)
plt.figure(num="Sample rendering")
plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
plt.show()
Expand Down Expand Up @@ -77,11 +84,12 @@
animate_n_body(
ts,
x_bds_ts,
500,
500,
img_size[0],
img_size[1],
video_path="examples/outputs/two_body.mp4",
speed_up=ts[-1] / 10,
timestamp_unit="M",
x_min=x_min,
x_max=x_max,
timestamp_unit="M",
body_radii=body_radii,
)
35 changes: 29 additions & 6 deletions src/nbodyx/rendering/opencv.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
from os import PathLike
from pathlib import Path
from tqdm import tqdm
from typing import Union


def render_n_body(
x_bds: Array,
x: Array,
width: int,
height: int,
x_min: Array,
Expand All @@ -21,7 +22,7 @@ def render_n_body(
"""Render the n-body problem using OpenCV.
Args:
x_bds: The positions of the bodies. Array of shape (num_bodies, 2).
x: The positions of the bodies. Array of shape (num_bodies*2).
width: The width of the image.
height: The height of the image.
body_radii: The radii of the bodies. Array of shape (num_bodies, ).
Expand All @@ -30,12 +31,17 @@ def render_n_body(
Returns:
img: The rendered image.
"""
# reshape the positions
x_bds = x.reshape(-1, 2)

# define ppm (pixels per meter)
ppm = 0.9 * min(width, height) / jnp.max(x_max - x_min)

# default body radii
if body_radii is None:
body_radii = (jnp.ones(x_bds.shape[0]) * 0.05 * min(width, height)).astype(int)
else:
body_radii = body_radii.astype(int)

# default body colors
if body_colors is None:
Expand Down Expand Up @@ -86,16 +92,33 @@ def x_to_uv(x: Array) -> Array:

def animate_n_body(
ts: Array,
x_bds_ts: Array,
x_ts: Array,
width: int,
height: int,
video_path: PathLike,
speed_up: int = 1,
speed_up: Union[float, Array] = 1,
skip_step: int = 1,
add_timestamp: bool = True,
timestamp_unit: str = "s",
**kwargs,
):
"""
Animate the n-body problem using OpenCV.
Args:
ts: The time steps of the data. Array of shape (num_time_steps, ).
x_ts: The positions of the bodies. Array of shape (num_time_steps, num_bodies*2).
width: The width of the video.
height: The height of the video.
video_path: The path where the video will be saved.
speed_up: The speed up factor of the video.
skip_step: The number of time steps to skip between animation frames.
add_timestamp: Whether to add a timestamp to the video.
timestamp_unit: The unit of the timestamp. Can be "s" (seconds), "d" (days), "M" (months), or "y" (years).
**kwargs: Additional keyword arguments for the rendering function.
Returns:
"""
dt = jnp.mean(jnp.diff(ts)).item()
fps = float(speed_up / (skip_step * dt))
print(f"fps: {fps}")
Expand All @@ -112,7 +135,7 @@ def animate_n_body(

# skip frames
ts = ts[::skip_step]
x_bds_ts = x_bds_ts[::skip_step]
x_ts = x_ts[::skip_step]

for time_idx, t in (pbar := tqdm(enumerate(ts))):
pbar.set_description(f"Rendering frame {time_idx + 1}/{len(ts)}")
Expand All @@ -131,7 +154,7 @@ def animate_n_body(
raise ValueError(f"Invalid timestamp unit: {timestamp_unit}")

# render the image
img = render_n_body(x_bds_ts[time_idx], width, height, label=label, **kwargs)
img = render_n_body(x_ts[time_idx], width, height, label=label, **kwargs)

video.write(img)

Expand Down

0 comments on commit d43f050

Please sign in to comment.