Skip to content

Commit

Permalink
trainer
Browse files Browse the repository at this point in the history
  • Loading branch information
apchytr committed Nov 28, 2024
1 parent a40b8e2 commit ff94a0d
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion tests/test_training/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from mrmustard.training import Optimizer
from mrmustard.training.trainer import map_trainer, train_device, update_pop

from ..conftest import skip_np
from ..conftest import skip_np, skip_tf


def wrappers():
Expand Down Expand Up @@ -79,6 +79,7 @@ class TestTrainer:
def test_circ_cost(self, tasks, seed): # pylint: disable=redefined-outer-name
"""Test distributed cost calculations."""
skip_np()
skip_tf()

has_seed = isinstance(seed, int)
_, cost_fn = wrappers()
Expand Down Expand Up @@ -113,6 +114,7 @@ def test_circ_cost(self, tasks, seed): # pylint: disable=redefined-outer-name
def test_circ_optimize(self, tasks, return_type): # pylint: disable=redefined-outer-name
"""Test distributed optimizations."""
skip_np()
skip_tf()

max_steps = 15
make_circ, cost_fn = wrappers()
Expand Down Expand Up @@ -156,6 +158,7 @@ def test_circ_optimize(self, tasks, return_type): # pylint: disable=redefined-o
def test_circ_optimize_metrics(self, metric_fns): # pylint: disable=redefined-outer-name
"""Tests custom metric functions on final circuits."""
skip_np()
skip_tf()

make_circ, cost_fn = wrappers()

Expand Down Expand Up @@ -194,6 +197,7 @@ def test_circ_optimize_metrics(self, metric_fns): # pylint: disable=redefined-o
def test_update_pop(self):
"""Test for coverage."""
skip_np()
skip_tf()

d = {"a": 3, "b": "foo"}
kwargs = {"b": "bar", "c": 22}
Expand All @@ -204,6 +208,7 @@ def test_update_pop(self):
def test_no_ray(self, monkeypatch):
"""Tests ray import error"""
skip_np()
skip_tf()

monkeypatch.setitem(sys.modules, "ray", None)
with pytest.raises(ImportError, match="Failed to import `ray`"):
Expand All @@ -215,6 +220,7 @@ def test_no_ray(self, monkeypatch):
def test_invalid_tasks(self):
"""Tests unexpected tasks arg"""
skip_np()
skip_tf()

with pytest.raises(
ValueError, match="`tasks` is expected to be of type int, list, or dict."
Expand All @@ -227,6 +233,7 @@ def test_invalid_tasks(self):
def test_warn_unused_kwargs(self): # pylint: disable=redefined-outer-name
"""Test warning of unused kwargs"""
skip_np()
skip_tf()

_, cost_fn = wrappers()
with pytest.warns(UserWarning, match="Unused kwargs:"):
Expand All @@ -240,6 +247,7 @@ def test_warn_unused_kwargs(self): # pylint: disable=redefined-outer-name
def test_no_pbar(self): # pylint: disable=redefined-outer-name
"""Test turning off pregress bar"""
skip_np()
skip_tf()

_, cost_fn = wrappers()
results = map_trainer(
Expand All @@ -254,6 +262,7 @@ def test_no_pbar(self): # pylint: disable=redefined-outer-name
def test_unblock(self, tasks): # pylint: disable=redefined-outer-name
"""Test unblock async mode"""
skip_np()
skip_tf()

_, cost_fn = wrappers()
result_getter = map_trainer(
Expand Down

0 comments on commit ff94a0d

Please sign in to comment.