Skip to content

Commit

Permalink
BUG: Allow pass-through of correct units (#11143)
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner authored Sep 7, 2022
1 parent cf235d5 commit 89e979b
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 10 deletions.
22 changes: 13 additions & 9 deletions mne/io/edf/edf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import numpy as np

from ...utils import verbose, logger, warn
from ...utils import verbose, logger, warn, _validate_type
from ..utils import _blk_read_lims, _mult_cal_one
from ..base import BaseRaw, _get_scaling
from ..meas_info import _empty_info, _unique_channel_names
Expand Down Expand Up @@ -141,17 +141,21 @@ def __init__(self, input_fname, eog=None, misc=None, stim_channel='auto',
preload, include)
logger.info('Creating raw.info structure...')

if units is not None and isinstance(units, str):
units = {ch_name: units for ch_name in info['ch_names']}
elif units is None:
_validate_type(units, (str, None, dict), 'units')
if units is None:
units = dict()
elif isinstance(units, str):
units = {ch_name: units for ch_name in info['ch_names']}

for k, (this_ch, this_unit) in enumerate(orig_units.items()):
if this_unit != "" and this_ch in units:
raise ValueError(f'Unit for channel {this_ch} is present in '
'the file. Cannot overwrite it with the '
'units argument.')
if this_unit == "" and this_ch in units:
if this_ch not in units:
continue
if this_unit not in ("", units[this_ch]):
raise ValueError(
f'Unit for channel {this_ch} is present in the file as '
f'{repr(this_unit)}, cannot overwrite it with the units '
f'argument {repr(units[this_ch])}.')
if this_unit == "":
orig_units[this_ch] = units[this_ch]
ch_type = edf_info["ch_types"][k]
scaling = _get_scaling(ch_type.lower(), orig_units[this_ch])
Expand Down
9 changes: 8 additions & 1 deletion mne/io/edf/tests/test_edf.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_orig_units():
def test_units_params():
"""Test enforcing original channel units."""
with pytest.raises(ValueError,
match=r"Unit for channel .* is present .* Cannot "
match=r"Unit for channel .* is present .* cannot "
"overwrite it"):
_ = read_raw_edf(edf_path, units='V', preload=True)

Expand Down Expand Up @@ -601,6 +601,7 @@ def test_ch_types():
assert raw.ch_names == labels

raw = read_raw_edf(edf_chtypes_path, infer_types=True)
data = raw.get_data()

labels = ['Fp1-Ref', 'Fp2-Ref', 'F3-Ref', 'F4-Ref', 'C3-Ref', 'C4-Ref',
'P3-Ref', 'P4-Ref', 'O1-Ref', 'O2-Ref', 'F7-Ref', 'F8-Ref',
Expand All @@ -617,3 +618,9 @@ def test_ch_types():

assert raw.get_channel_types() == types
assert raw.ch_names == labels

with pytest.raises(ValueError, match="cannot overwrite"):
read_raw_edf(edf_chtypes_path, units='V')
raw = read_raw_edf(edf_chtypes_path, units='uV') # should be okay
data_units = raw.get_data()
assert_allclose(data, data_units)

0 comments on commit 89e979b

Please sign in to comment.