Skip to content

Commit

Permalink
Merge pull request #1964 from pupil-labs/develop
Browse files Browse the repository at this point in the history
Pupil v2.1 Release Candidate 1
  • Loading branch information
papr authored Jul 7, 2020
2 parents ca77307 + 94db63b commit 4116162
Show file tree
Hide file tree
Showing 55 changed files with 543 additions and 316 deletions.
40 changes: 27 additions & 13 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,31 @@
# There's a docker image pupillabs/pupil-docker-ubuntu:latest which contains all
# required dependencies for the test-suite to run.

os: minimal
services: docker
jobs:
include:
- name: pytest
os: minimal
services: docker
before_install:
- docker pull pupillabs/pupil-docker-ubuntu:latest
- chmod +x ./.travis/*.sh
script:
- >
docker run --rm
-v `pwd`:/repo
-w /repo
pupillabs/pupil-docker-ubuntu:latest
/bin/bash /repo/.travis/run_tests.sh
before_install:
- docker pull pupillabs/pupil-docker-ubuntu:latest
- chmod +x ./.travis/*.sh

script:
- >
docker run --rm
-v `pwd`:/repo
-w /repo
pupillabs/pupil-docker-ubuntu:latest
/bin/bash /repo/.travis/run_tests.sh
- name: black formatting check
language: python
before_script:
- pip install -U pip
- pip install black
script:
- >
black . --check --exclude pupil_src/tests || (
echo -e "\033[0;31m PLEASE RUN THE BLACK FORMATTER ON YOUR CODE: \033[0m" &&
echo "See https://github.com/psf/black for details" &&
false
)
9 changes: 6 additions & 3 deletions pupil_src/launchables/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,6 @@ def get_dt():
"min_calibration_confidence", 0.8
)


audio.set_audio_mode(
session_settings.get("audio_mode", audio.get_default_audio_mode())
)
Expand Down Expand Up @@ -226,7 +225,12 @@ def get_dt():
def handle_notifications(n):
subject = n["subject"]
if subject == "start_plugin":
g_pool.plugins.add(plugin_by_name[n["name"]], args=n.get("args", {}))
try:
g_pool.plugins.add(
plugin_by_name[n["name"]], args=n.get("args", {})
)
except KeyError as err:
logger.error(f"Attempt to load unknown plugin: {err}")
elif subject == "service_process.should_stop":
g_pool.service_should_run = False
elif subject.startswith("meta.should_doc"):
Expand Down Expand Up @@ -272,7 +276,6 @@ def handle_notifications(n):
gaze_pub.send(gaze_datum)
events["gaze"].append(gaze_datum)


for plugin in g_pool.plugins:
plugin.recent_events(events=events)

Expand Down
8 changes: 5 additions & 3 deletions pupil_src/shared_modules/accuracy_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
)
from plugin import Plugin

from gaze_mapping import registered_gazer_classes_by_class_name
from gaze_mapping import gazer_classes_by_class_name, registered_gazer_classes
from gaze_mapping.notifications import (
CalibrationSetupNotification,
CalibrationResultNotification,
Expand Down Expand Up @@ -80,7 +80,9 @@ def clear(self):
self.__gazer_class = None
self.__gazer_params = None

def update(self, gazer_class_name: str, gazer_params=..., pupil_list=..., ref_list=...):
def update(
self, gazer_class_name: str, gazer_params=..., pupil_list=..., ref_list=...
):
if (
self.gazer_class_name is not None
and self.gazer_class_name != gazer_class_name
Expand All @@ -107,7 +109,7 @@ def __gazer_class_from_name(gazer_class_name: str) -> T.Optional[T.Any]:
logger.info("Accuracy visualization is disabled for HMD calibration")
return None

gazers_by_name = registered_gazer_classes_by_class_name()
gazers_by_name = gazer_classes_by_class_name(registered_gazer_classes())

try:
gazer_cls = gazers_by_name[gazer_class_name]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -315,8 +315,8 @@ def on_choreography_successfull(
calib_data = {"ref_list": ref_list, "pupil_list": pupil_list}
self._start_plugin(self.selected_gazer_class, calib_data=calib_data)
elif mode == ChoreographyMode.VALIDATION:
assert self.g_pool.active_gaze_mapping_plugin is not None
gazer_class = self.g_pool.active_gaze_mapping_plugin.__class__
assert gazer_class == self.selected_gazer_class
gazer_params = self.g_pool.active_gaze_mapping_plugin.get_params()

self._start_plugin("Accuracy_Visualizer")
Expand Down Expand Up @@ -448,8 +448,7 @@ def update_ui(self):
)
if self.shows_action_buttons:
self.__ui_button_validation.read_only = (
self.selected_gazer_class
is not self.g_pool.active_gaze_mapping_plugin.__class__
self.g_pool.active_gaze_mapping_plugin is None
)

def deinit_ui(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,14 @@ def selection_label(cls) -> str:
def selection_order(cls) -> float:
return 1.0

@staticmethod
def get_list_of_markers_to_show(mode: ChoreographyMode) -> list:
if ChoreographyMode.CALIBRATION == mode:
return [(0.5, 0.5), (0.0, 1.0), (1.0, 1.0), (1.0, 0.0), (0.0, 0.0)]
if ChoreographyMode.VALIDATION == mode:
return [(0.5, 1.0), (1.0, 0.5), (0.5, 0.0), (0.0, 0.5)]
raise ValueError(f"Unknown mode {mode}")

def __init__(
self,
g_pool,
Expand Down Expand Up @@ -269,7 +277,7 @@ def _perform_start(self):
)
return

self.__current_list_of_markers_to_show = self.__get_list_of_markers_to_show(
self.__current_list_of_markers_to_show = self.get_list_of_markers_to_show(
mode=self.current_mode,
)
self.__currently_shown_marker_position = None
Expand All @@ -289,14 +297,6 @@ def _perform_stop(self):

### Private

@staticmethod
def __get_list_of_markers_to_show(mode: ChoreographyMode,) -> list:
if ChoreographyMode.CALIBRATION == mode:
return [(0.5, 0.5), (0.0, 1.0), (1.0, 1.0), (1.0, 0.0), (0.0, 0.0)]
if ChoreographyMode.VALIDATION == mode:
return [(0.5, 1.0), (1.0, 0.5), (0.5, 0.0), (0.0, 0.5)]
raise ValueError(f"Unknown mode {mode}")

def _on_window_did_close(self):
self._signal_should_stop(mode=self.current_mode)

Expand Down
24 changes: 17 additions & 7 deletions pupil_src/shared_modules/csv_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@
import typing as t


CSV_EXPORT_RAW_TYPE = t.TypeVar('CSV_EXPORT_RAW_TYPE')
CSV_EXPORT_RAW_TYPE = t.TypeVar("CSV_EXPORT_RAW_TYPE")
CSV_EXPORT_LABEL_TYPE = t.AnyStr
CSV_EXPORT_VALUE_TYPE = t.Any
CSV_EXPORT_VALUE_GETTER_TYPE = t.Callable[[CSV_EXPORT_RAW_TYPE], CSV_EXPORT_VALUE_TYPE]
CSV_EXPORT_SCHEMA_TYPE = t.List[t.Tuple[CSV_EXPORT_LABEL_TYPE, CSV_EXPORT_VALUE_GETTER_TYPE]]
CSV_EXPORT_SCHEMA_TYPE = t.List[
t.Tuple[CSV_EXPORT_LABEL_TYPE, CSV_EXPORT_VALUE_GETTER_TYPE]
]


class CSV_Exporter(abc.ABC, t.Generic[CSV_EXPORT_RAW_TYPE]):

@classmethod
@abc.abstractmethod
def csv_export_schema(cls) -> CSV_EXPORT_SCHEMA_TYPE:
Expand All @@ -33,10 +34,17 @@ def csv_export_labels(cls) -> t.Iterable[CSV_EXPORT_LABEL_TYPE]:
return tuple(label for label, _ in cls.csv_export_schema())

@classmethod
def csv_export_values(cls, raw_value: CSV_EXPORT_RAW_TYPE) -> t.Iterable[CSV_EXPORT_VALUE_TYPE]:
def csv_export_values(
cls, raw_value: CSV_EXPORT_RAW_TYPE
) -> t.Iterable[CSV_EXPORT_VALUE_TYPE]:
return tuple(getter(raw_value) for _, getter in cls.csv_export_schema())

def csv_export(self, raw_values: t.Iterable[CSV_EXPORT_RAW_TYPE], export_dir: str, export_name: str) -> str:
def csv_export(
self,
raw_values: t.Iterable[CSV_EXPORT_RAW_TYPE],
export_dir: str,
export_name: str,
) -> str:

export_path = os.path.abspath(os.path.join(export_dir, export_name))

Expand All @@ -63,7 +71,9 @@ def read_key_value_file(csvfile):
if "key" not in first_line or "value" not in first_line:
csvfile.seek(0) # Seek to start if first_line is not an header
dialect = csv.Sniffer().sniff(first_line, delimiters=",\t")
reader = csv.reader(csvfile, dialect, quoting=csv.QUOTE_NONE, escapechar='\\') # create reader
reader = csv.reader(
csvfile, dialect, quoting=csv.QUOTE_NONE, escapechar="\\"
) # create reader
for row in reader:
kvstore[row[0]] = row[1]
return kvstore
Expand All @@ -80,7 +90,7 @@ def write_key_value_file(csvfile, dictionary, append=False):
Returns:
None: No return
"""
writer = csv.writer(csvfile, delimiter=",", quoting=csv.QUOTE_NONE, escapechar='\\')
writer = csv.writer(csvfile, delimiter=",", quoting=csv.QUOTE_NONE, escapechar="\\")
if not append:
writer.writerow(["key", "value"])
for key, val in dictionary.items():
Expand Down
6 changes: 1 addition & 5 deletions pupil_src/shared_modules/file_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,17 +353,13 @@ def _deep_copy_serialized_dict(self):
return Serialized_Dict(python_dict=dict_copy)

def _deep_copy_dict(self):

def unpacking_ext_hook(self, code, data):
if code == self.MSGPACK_EXT_CODE:
return type(self)(msgpack_bytes=data)._deep_copy_dict()
return msgpack.ExtType(code, data)

return msgpack.unpackb(
self._ser_data,
raw=False,
use_list=False,
ext_hook=unpacking_ext_hook,
self._ser_data, raw=False, use_list=False, ext_hook=unpacking_ext_hook,
)


Expand Down
14 changes: 10 additions & 4 deletions pupil_src/shared_modules/gaze_mapping/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,18 @@ def registered_gazer_classes() -> list:
return gazer_base.GazerBase.registered_gazer_classes()


def registered_gazer_labels_by_class_names() -> dict:
return {cls.__name__: cls.label for cls in registered_gazer_classes()}
def user_selectable_gazer_classes() -> list:
gazers = registered_gazer_classes()
gazers = filter(lambda g: g is not GazerHMD3D, gazers)
return list(gazers)


def registered_gazer_classes_by_class_name() -> dict:
return {cls.__name__: cls for cls in registered_gazer_classes()}
def gazer_labels_by_class_names(gazers: list) -> dict:
return {cls.__name__: cls.label for cls in gazers}


def gazer_classes_by_class_name(gazers: list) -> dict:
return {cls.__name__: cls for cls in gazers}


default_gazer_class = Gazer3D
Expand Down
12 changes: 11 additions & 1 deletion pupil_src/shared_modules/gaze_mapping/gazer_3d/gazer_hmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from gaze_mapping.gazer_base import (
GazerBase,
Model,
CalibrationError,
NotEnoughDataError,
FitDidNotConvergeError,
)
Expand All @@ -46,13 +47,22 @@
logger = logging.getLogger(__name__)


class MissingEyeTranslationsError(CalibrationError):
message = (
"GazerHMD3D can only be calibrated if it is "
"initialised with valid eye translations."
)


class ModelHMD3D_Binocular(Model3D_Binocular):
def __init__(self, *, intrinsics, eye_translations):
self.intrinsics = intrinsics
self.eye_translations = eye_translations
self._is_fitted = False

def _fit(self, X, Y):
if self.eye_translations is None:
raise MissingEyeTranslationsError()
assert X.shape[1] == _BINOCULAR_FEATURE_COUNT, X
unprojected_ref_points = Y

Expand Down Expand Up @@ -109,7 +119,7 @@ class GazerHMD3D(Gazer3D):
def _gazer_description_text(cls) -> str:
return "Gaze mapping built specifically for HMD-Eyes."

def __init__(self, g_pool, *, eye_translations, calib_data=None, params=None):
def __init__(self, g_pool, *, eye_translations=None, calib_data=None, params=None):
self.__eye_translations = eye_translations
super().__init__(g_pool, calib_data=calib_data, params=params)

Expand Down
32 changes: 25 additions & 7 deletions pupil_src/shared_modules/gaze_mapping/gazer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ class NotEnoughDataError(CalibrationError):
message = "Not sufficient data available."


class NotEnoughPupilDataError(NotEnoughDataError):
message = "Not sufficient pupil data available."


class NotEnoughReferenceDataError(NotEnoughDataError):
message = "Not sufficient reference data available."


class FitDidNotConvergeError(CalibrationError):
message = "Model fit did not converge."

Expand Down Expand Up @@ -194,18 +202,25 @@ def __init__(
self._announce_calibration_setup(calib_data=calib_data)
try:
self.fit_on_calib_data(calib_data)
except CalibrationError:
except CalibrationError as err:
if raise_calibration_error:
raise # Let offline calibration handle this one!
logger.error("Calibration Failed!")
self.alive = False
self._announce_calibration_failure(reason=CalibrationError.__name__)
self._announce_calibration_failure(reason=err.message)
except Exception as err:
import traceback

self._announce_calibration_failure(reason=err.__class__.__name__)
logger.debug(traceback.format_exc())
raise CalibrationError() from err
if raise_calibration_error:
raise CalibrationError() from err # Let offline calibration handle this one!
logger.error("Calibration Failed!")
self.alive = False
try:
reason = err.args[0]
except (AttributeError, IndexError):
reason = err.__class__.__name__
self._announce_calibration_failure(reason=reason)
else:
self._announce_calibration_success()
self._announce_calibration_result(params=self.get_params())
Expand All @@ -214,8 +229,9 @@ def __init__(
else:
raise ValueError("Requires either `calib_data` or `params`")

# used by pupil_data_relay for gaze mapping
g_pool.active_gaze_mapping_plugin = self
if self.alive:
# Used by pupil_data_relay for gaze mapping.
g_pool.active_gaze_mapping_plugin = self

def get_init_dict(self):
return {"params": self.get_params()}
Expand Down Expand Up @@ -257,7 +273,9 @@ def fit_on_calib_data(self, calib_data):
pupil_data, self.g_pool.min_calibration_confidence
)
if not pupil_data:
raise NotEnoughDataError
raise NotEnoughPupilDataError
if not ref_data:
raise NotEnoughReferenceDataError
# match pupil to reference data (left, right, and binocular)
matches = self.match_pupil_to_ref(pupil_data, ref_data)
if matches.binocular[0]:
Expand Down
10 changes: 5 additions & 5 deletions pupil_src/shared_modules/gaze_mapping/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ def _filter_pupil_list_by_confidence(pupil_list, threshold):


def _match_data_batch(pupil_list, ref_list):
assert pupil_list
assert ref_list
assert pupil_list, "No pupil data to match"
assert ref_list, "No reference data to match"
pupil0 = [p for p in pupil_list if p["id"] == 0]
pupil1 = [p for p in pupil_list if p["id"] == 1]

Expand All @@ -45,9 +45,9 @@ def _match_data_batch(pupil_list, ref_list):
num_mono_right = len(matched_pupil0_data[0])
num_mono_left = len(matched_pupil1_data[0])

logger.info(f"Collected {num_bino} binocular references.")
logger.info(f"Collected {num_mono_right} right eye monocular references.")
logger.info(f"Collected {num_mono_left} left eye monocular references.")
logger.debug(f"Collected {num_bino} binocular references.")
logger.debug(f"Collected {num_mono_right} right eye monocular references.")
logger.debug(f"Collected {num_mono_left} left eye monocular references.")

return (
matched_binocular_data,
Expand Down
Loading

0 comments on commit 4116162

Please sign in to comment.