Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates to the colorbar and some bug fixes #43

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/plonk/analysis/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,8 @@ def to_dataframe(
_units = list()
for column in columns:
try:
_units.append(self[column].units)
_units.append(_get_unit(self, column, units).units)
# _units.append(self[column].units)
except AttributeError:
_units.append(plonk_units('dimensionless'))
else:
Expand Down
7 changes: 5 additions & 2 deletions src/plonk/simulation/_phantom_ev.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,10 @@ def _get_data(columns: Tuple[str, ...], file_paths: Tuple[Path, ...]) -> DataFra
_skiprows = [0]
if len(times) > 1:
for t1, t2 in zip(times, times[1:]):
_skiprows.append(np.where(t2 < t1[-1])[0][-1] + 2)
if t2[0] < t1[-1]:
_skiprows.append(np.where(t2 < t1[-1])[0][-1] + 2)
else:
_skiprows.append(0)

df = pd.concat(
(
Expand Down Expand Up @@ -124,7 +127,7 @@ def _get_columns(filename: Path, name_map: Dict[str, str]) -> Tuple[str, ...]:
def _check_file_consistency(
filenames: Tuple[Path, ...], name_map: Dict[str, str]
) -> None:

columns = _get_columns(filenames[0], name_map)
for filename in filenames:
columns_previous = columns
Expand Down
8 changes: 6 additions & 2 deletions src/plonk/simulation/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,12 @@ def _get_sink_ts_files(self, glob: str = None) -> List[List[Path]]:

n = len(self.prefix) + len('Sink')
n_sinks = len({p.name[n : n + 4] for p in self.paths['directory'].glob(glob)})

# n_sinks = len(np.unique({p.name[n : n + 4] for p in self.paths['directory'].glob(glob)}))
# print(self.paths['directory'].glob(glob))
# print({p.name[n : n + 4] for p in self.paths['directory'].glob(glob)})
# print(np.unique({p.name[n : n + 4] for p in self.paths['directory'].glob(glob)}))
# print(len(np.unique({p.name[n : n + 4] for p in self.paths['directory'].glob(glob)})))
# print("n_sinks is ", n_sinks)
sinks = list()
for idx in range(1, n_sinks + 1):
sinks.append(
Expand All @@ -264,7 +269,6 @@ def _get_sink_ts_files(self, glob: str = None) -> List[List[Path]]:
)
)
)

return sinks

def set_units_on_time_series(self, config: Union[str, Path] = None):
Expand Down
10 changes: 7 additions & 3 deletions src/plonk/snap/readers/phantom.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,12 @@ def add_to_header_from_infile(
file = h5py.File(filename, mode='r+')
logger.info(f'Opening snapshot file {snapfile}')
for param in parameters:
print(param)
if param in config.config:
logger.info(f'Updating {param} from in-file')
file[f'header/{param}'][()] = config.config[param].value
print(param)
print(file[f'header/{param}'][()])
file.close()


Expand Down Expand Up @@ -135,7 +138,7 @@ def snap_properties_and_units(
elif ieos == 2:
properties['equation_of_state'] = 'adiabatic'
properties['adiabatic_index'] = gamma
elif ieos == 3:
elif ieos in (3, 6, 14, 21):
properties['equation_of_state'] = 'locally isothermal disc'
properties['adiabatic_index'] = gamma

Expand Down Expand Up @@ -434,8 +437,9 @@ def pressure(snap: Snap) -> Quantity:
* snap.code_units['time'] ** (-2)
)
return K * rho ** (gamma - 1)
if ieos == 3:
if ieos in (3, 6, 14, 21):
# Vertically isothermal (for accretion disc)
# All of these ieos values are vertically isothermal
K = K * snap.code_units['length'] ** 2 * snap.code_units['time'] ** (-2)
q = snap._file_pointer['header/qfacdisc'][()]
pos = get_dataset('xyz', 'particles')(snap)
Expand All @@ -451,7 +455,7 @@ def sound_speed(snap: Snap) -> Quantity:
gamma = snap.properties['adiabatic_index']
rho = density(snap)
P = pressure(snap)
if ieos in (1, 3):
if ieos in (1, 3, 6, 14, 21):
return np.sqrt(P / rho)
if ieos == 2:
return np.sqrt(gamma * P / rho)
Expand Down
30 changes: 29 additions & 1 deletion src/plonk/snap/snap.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from scipy.spatial.transform import Rotation

from .. import visualize
from .. import analysis
from .._config import read_config
from .._logging import logger
from .._units import Quantity, array_units, generate_array_code_units
Expand Down Expand Up @@ -675,6 +676,26 @@ def translate(

return self

def shift_to_com(self, sink_idx: Union[int, List[int]] = None) -> Snap:
"""Shift the snapshot to the center of mass, or the location
of the central body

Returns
-------
Snap
The Snap shifted to the center of mass. Note that the shift to center of mass
operation is in-place.
"""

logger.debug(f'Shifting snapshot to centre of mass : {self.file_path.name}')
if sink_idx == None:
com = analysis.total.center_of_mass(self)
return self.translate(-com)
else:
self.set_central_body(sink_idx)
com = self._properties['central_body']['position']
return self.translate(-com)

def particle_indices(
self, particle_type: str, squeeze: bool = False
) -> Union[ndarray, List[ndarray]]:
Expand Down Expand Up @@ -1418,10 +1439,17 @@ class Sinks:
"""

def __init__(
self, base: Snap, indices: Union[ndarray, slice, list, int, tuple] = None
self, base: Snap, indices: Union[ndarray, slice, list, int, tuple] = None,
combine: Union[ndarray, slice, list, tuple] = None
):
self.base = base

if combine is not None:
if indices is not None:
logger.warning('Sinks will be combined, ignoring indices')
self._combine = _input_indices_array(inp=combine, max_slice=base.num_sinks)


if indices is None:
indices = np.arange(base.num_sinks)
ind = _input_indices_array(inp=indices, max_slice=base.num_sinks)
Expand Down
12 changes: 11 additions & 1 deletion src/plonk/visualize/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Tuple
from collections.abc import Callable

import numpy as np
from numpy import ndarray
Expand Down Expand Up @@ -34,6 +35,7 @@ def interpolate(
interp: 'str',
weighted: bool = False,
slice_normal: Tuple[float, float, float] = None,
slice_func: Callable[[SnapLike], ndarray] = None,
slice_offset: Quantity = None,
extent: Quantity,
num_pixels: Tuple[float, float] = None,
Expand Down Expand Up @@ -62,6 +64,9 @@ def interpolate(
slice_normal
The normal vector to the plane in which to take the
cross-section slice as an array (x, y, z).
slice_func
The function which returns an ndarray of the distance of
each particle from an arbitrary slice.
slice_offset
The offset of the cross-section slice. Default is 0.0.
extent
Expand Down Expand Up @@ -123,7 +128,12 @@ def interpolate(
slice_offset = (
(slice_offset / snap.code_units['length']).to_base_units().magnitude
)
dist_from_slice = distance_from_plane(_x, _y, _z, _slice_normal, slice_offset)
if slice_func is not None:
dist_from_slice = slice_func(snap, normal=_slice_normal)
else:
dist_from_slice = distance_from_plane(_x, _y, _z, _slice_normal, slice_offset)
print("Dist from slice")
print(dist_from_slice)

if _quantity.ndim == 1:
interpolated_data = scalar_interpolation(
Expand Down
8 changes: 6 additions & 2 deletions src/plonk/visualize/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,16 @@ def imshow(*, interpolated_data: ndarray, extent: Extent, ax: Any, **kwargs):
except KeyError:
norm = 'linear'
if norm.lower() in ('linear', 'lin'):
norm = mpl.colors.Normalize()
norm = mpl.colors.Normalize
elif norm.lower() in ('logarithic', 'logarithm', 'log', 'log10'):
norm = mpl.colors.LogNorm()
norm = mpl.colors.LogNorm
else:
raise ValueError('Cannot determine normalization for colorbar')

norm = norm(vmin=_kwargs['vmin'], vmax=_kwargs['vmax'])
del _kwargs['vmin']
del _kwargs['vmax']

return ax.imshow(
interpolated_data, origin='lower', extent=extent, norm=norm, **_kwargs
)
Expand Down
1 change: 1 addition & 0 deletions src/plonk/visualize/splash.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

RADKERNEL = 2.0
RADKERNEL2 = 4.0
# RADKERNEL2 = 10.0
CNORMK3D = 1.0 / np.pi

NPTS = 100
Expand Down
Loading