diff --git a/README.rst b/README.rst index df58e8e8..344e6386 100644 --- a/README.rst +++ b/README.rst @@ -90,9 +90,7 @@ Installation Install by running:: - sudo python3 -m pip install --break-system-packages https://github.com/mopidy/mopidy-spotify/archive/refs/tags/v5.0.0a1.zip - -This is currently the only supported installation method. + sudo python3 -m pip install --break-system-packages Mopidy-Spotify==5.0.0a1 Configuration diff --git a/mopidy_spotify/playlists.py b/mopidy_spotify/playlists.py index 493d602c..ce7ae78c 100644 --- a/mopidy_spotify/playlists.py +++ b/mopidy_spotify/playlists.py @@ -1,6 +1,8 @@ import logging +import threading from mopidy import backend +from mopidy.core import listener from mopidy_spotify import translator, utils @@ -12,20 +14,19 @@ class SpotifyPlaylistsProvider(backend.PlaylistsProvider): def __init__(self, backend): self._backend = backend self._timeout = self._backend._config["spotify"]["timeout"] - self._loaded = False + self._refresh_mutex = threading.Lock() def as_list(self): with utils.time_logger("playlists.as_list()", logging.DEBUG): - if not self._loaded: - return [] - return list(self._get_flattened_playlist_refs()) - def _get_flattened_playlist_refs(self): + def _get_flattened_playlist_refs(self, *, refresh=False): if not self._backend._web_client.logged_in: return [] - user_playlists = self._backend._web_client.get_user_playlists() + user_playlists = self._backend._web_client.get_user_playlists( + refresh=refresh + ) return translator.to_playlist_refs( user_playlists, self._backend._web_client.user_id ) @@ -49,18 +50,48 @@ def _get_playlist(self, uri, as_items=False): def refresh(self): if not self._backend._web_client.logged_in: return - - logger.info("Refreshing Spotify playlists") - - with utils.time_logger("playlists.refresh()", logging.DEBUG): - self._backend._web_client.clear_cache() - count = 0 - for playlist_ref in self._get_flattened_playlist_refs(): - self._get_playlist(playlist_ref.uri) - count = count + 1 - logger.info(f"Refreshed {count} Spotify playlists") - - self._loaded = True + if not self._refresh_mutex.acquire(blocking=False): + logger.info("Refreshing Spotify playlists already in progress") + return + try: + uris = [ + ref.uri + for ref in self._get_flattened_playlist_refs(refresh=True) + ] + logger.info( + f"Refreshing {len(uris)} Spotify playlists in background" + ) + threading.Thread( + target=self._refresh_tracks, + args=(uris,), + daemon=True, + ).start() + except Exception: + logger.exception( + "Error occurred while refreshing Spotify playlists" + ) + self._refresh_mutex.release() + + def _refresh_tracks(self, playlist_uris): + if not self._refresh_mutex.locked(): + logger.error("Lock must be held before calling this method") + return [] + try: + with utils.time_logger( + "playlists._refresh_tracks()", logging.DEBUG + ): + refreshed = [uri for uri in playlist_uris if self.lookup(uri)] + logger.info(f"Refreshed {len(refreshed)} Spotify playlists") + + listener.CoreListener.send("playlists_loaded") + except Exception: + logger.exception( + "Error occurred while refreshing Spotify playlists tracks" + ) + else: + return refreshed # For test + finally: + self._refresh_mutex.release() def create(self, name): pass # TODO diff --git a/mopidy_spotify/web.py b/mopidy_spotify/web.py index 23da1d0d..5d5926b7 100644 --- a/mopidy_spotify/web.py +++ b/mopidy_spotify/web.py @@ -2,13 +2,15 @@ import logging import os import re +import threading import time import urllib.parse from dataclasses import dataclass from datetime import datetime +from email.utils import parsedate_to_datetime from enum import Enum, unique from typing import Optional -from email.utils import parsedate_to_datetime + import requests @@ -63,6 +65,11 @@ def __init__( self._headers = {"Content-Type": "application/json"} self._session = utils.get_requests_session(proxy_config or {}) + # TODO: Move _cache_mutex to the object it actually protects. + self._cache_mutex = threading.Lock() # Protects get() cache param. + self._refresh_mutex = ( + threading.Lock() + ) # Protects _headers and _expires. def get(self, path, cache=None, *args, **kwargs): if self._authorization_failed: @@ -74,21 +81,22 @@ def get(self, path, cache=None, *args, **kwargs): _trace(f"Get '{path}'") - ignore_expiry = kwargs.pop("ignore_expiry", False) + expiry_strategy = kwargs.pop("expiry_strategy", None) if cache is not None and path in cache: cached_result = cache.get(path) - if cached_result.still_valid(ignore_expiry): + if cached_result.still_valid(expiry_strategy=expiry_strategy): return cached_result kwargs.setdefault("headers", {}).update(cached_result.etag_headers) # TODO: Factor this out once we add more methods. # TODO: Don't silently error out. - try: - if self._should_refresh_token(): - self._refresh_token() - except OAuthTokenRefreshError as e: - logger.error(e) - return WebResponse(None, None) + with self._refresh_mutex: + try: + if self._should_refresh_token(): + self._refresh_token() + except OAuthTokenRefreshError as e: + logger.error(e) # noqa: TRY400 + return WebResponse(None, None) # Make sure our headers always override user supplied ones. kwargs.setdefault("headers", {}).update(self._headers) @@ -101,11 +109,12 @@ def get(self, path, cache=None, *args, **kwargs): ) return WebResponse(None, None) - if self._should_cache_response(cache, result): - previous_result = cache.get(path) - if previous_result and previous_result.updated(result): - result = previous_result - cache[path] = result + with self._cache_mutex: + if self._should_cache_response(cache, result): + previous_result = cache.get(path) + if previous_result and previous_result.updated(result): + result = previous_result + cache[path] = result return result @@ -114,11 +123,16 @@ def _should_cache_response(self, cache, response): def _should_refresh_token(self): # TODO: Add jitter to margin? + if not self._refresh_mutex.locked(): + raise OAuthTokenRefreshError("Lock must be held before calling.") return not self._auth or time.time() > self._expires - self._margin def _refresh_token(self): logger.debug(f"Fetching OAuth token from {self._refresh_url}") + if not self._refresh_mutex.locked(): + raise OAuthTokenRefreshError("Lock must be held before calling.") + data = {"grant_type": "client_credentials"} result = self._request_with_retries( "POST", self._refresh_url, auth=self._auth, data=data @@ -266,6 +280,12 @@ def _parse_retry_after(self, response): return max(0, seconds) +@unique +class ExpiryStrategy(Enum): + FORCE_FRESH = "force-fresh" + FORCE_EXPIRED = "force-expired" + + class WebResponse(dict): def __init__(self, url, data, expires=0.0, etag=None, status_code=400): self._from_cache = False @@ -324,19 +344,22 @@ def _parse_etag(response): if etag and len(etag.groups()) == 2: return etag.groups()[1] - def still_valid(self, ignore_expiry=False): - if ignore_expiry: - result = True - status = "forced" - elif self._expires >= time.time(): - result = True - status = "fresh" + return None + + def still_valid(self, *, expiry_strategy=None): + if expiry_strategy is None: + if self._expires >= time.time(): + valid = True + status = "fresh" + else: + valid = False + status = "expired" else: - result = False - status = "expired" - self._from_cache = result + valid = expiry_strategy is ExpiryStrategy.FORCE_FRESH + status = expiry_strategy.value + self._from_cache = valid _trace("Cached data %s for %s", status, self) - return result + return valid @property def status_unchanged(self): @@ -432,9 +455,12 @@ def login(self): def logged_in(self): return self.user_id is not None - def get_user_playlists(self): + def get_user_playlists(self, *, refresh=False): + expiry_strategy = ExpiryStrategy.FORCE_EXPIRED if refresh else None pages = self.get_all( - f"users/{self.user_id}/playlists", params={"limit": 50} + f"users/{self.user_id}/playlists", + params={"limit": 50}, + expiry_strategy=expiry_strategy, ) for page in pages: yield from page.get("items", []) @@ -446,7 +472,9 @@ def _with_all_tracks(self, obj, params=None): track_pages = self.get_all( tracks_path, params=params, - ignore_expiry=obj.status_unchanged, + expiry_strategy=( + ExpiryStrategy.FORCE_FRESH if obj.status_unchanged else None + ), ) more_tracks = [] @@ -531,9 +559,6 @@ def get_track(self, web_link): f"tracks/{web_link.id}", params={"market": "from_token"} ) - def clear_cache(self, extra_expiry=None): - self._cache.clear() - @unique class LinkType(Enum): diff --git a/tests/__init__.py b/tests/__init__.py index e69de29b..30953daa 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,16 @@ +import threading + + +class ThreadJoiner: + def __init__(self, timeout: int = 1): + self.timeout = timeout + + def __enter__(self): + self.before = set(threading.enumerate()) + + def __exit__(self, exc_type, exc_val, exc_tb): + new_threads = set(threading.enumerate()) - self.before + for thread in new_threads: + thread.join(timeout=self.timeout) + if thread.is_alive(): + raise RuntimeError(f"Timeout joining thread {thread}") diff --git a/tests/test_backend.py b/tests/test_backend.py index 7a1dd38e..19467080 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -6,6 +6,8 @@ from mopidy_spotify import backend, library, playlists from mopidy_spotify.backend import SpotifyPlaybackProvider +from tests import ThreadJoiner + def get_backend(config): obj = backend.SpotifyBackend(config=config, audio=None) @@ -60,7 +62,8 @@ def test_on_start_configures_proxy(web_mock, config): "password": "s3cret", } backend = get_backend(config) - backend.on_start() + with ThreadJoiner(): + backend.on_start() assert True @@ -76,7 +79,8 @@ def test_on_start_configures_web_client(web_mock, config): config["spotify"]["client_secret"] = "AbCdEfG" backend = get_backend(config) - backend.on_start() + with ThreadJoiner(): + backend.on_start() web_mock.SpotifyOAuthClient.assert_called_once_with( client_id="1234567", @@ -94,12 +98,13 @@ def test_on_start_logs_in(web_mock, config): def test_on_start_refreshes_playlists(web_mock, config, caplog): backend = get_backend(config) - backend.on_start() + with ThreadJoiner(): + backend.on_start() client_mock = web_mock.SpotifyOAuthClient.return_value - client_mock.get_user_playlists.assert_called_once() + client_mock.get_user_playlists.assert_called_once_with(refresh=True) + assert "Refreshing 0 Spotify playlists in background" in caplog.text assert "Refreshed 0 Spotify playlists" in caplog.text - assert backend.playlists._loaded def test_on_start_doesnt_refresh_playlists_if_not_allowed( diff --git a/tests/test_playlists.py b/tests/test_playlists.py index 8021335c..076ee4b9 100644 --- a/tests/test_playlists.py +++ b/tests/test_playlists.py @@ -1,3 +1,4 @@ +import logging from unittest import mock import pytest @@ -6,6 +7,8 @@ from mopidy_spotify import playlists +from tests import ThreadJoiner + @pytest.fixture def web_client_mock(web_client_mock, web_track_mock): @@ -43,9 +46,7 @@ def get_playlist(*args, **kwargs): @pytest.fixture def provider(backend_mock, web_client_mock): backend_mock._web_client = web_client_mock - provider = playlists.SpotifyPlaylistsProvider(backend_mock) - provider._loaded = True - return provider + return playlists.SpotifyPlaylistsProvider(backend_mock) def test_is_a_playlists_provider(provider): @@ -68,14 +69,6 @@ def test_as_list_when_offline(web_client_mock, provider): assert len(result) == 0 -def test_as_list_when_not_loaded(provider): - provider._loaded = False - - result = provider.as_list() - - assert len(result) == 0 - - def test_as_list_when_playlist_wont_translate(provider): result = provider.as_list() @@ -120,15 +113,6 @@ def test_get_items_when_offline(web_client_mock, provider, caplog): ) -def test_get_items_when_not_loaded(provider): - provider._loaded = False - - result = provider.get_items("spotify:user:alice:playlist:foo") - - assert len(result) == 1 - assert result[0] == Ref.track(uri="spotify:track:abc", name="ABC 123") - - def test_get_items_when_playlist_wont_translate(provider): assert provider.get_items("spotify:user:alice:playlist:malformed") is None @@ -142,7 +126,8 @@ def test_get_items_when_playlist_is_unknown(provider, caplog): def test_refresh_loads_all_playlists(provider, web_client_mock): - provider.refresh() + with ThreadJoiner(): + provider.refresh() web_client_mock.get_user_playlists.assert_called_once() assert web_client_mock.get_playlist.call_count == 2 @@ -154,36 +139,94 @@ def test_refresh_loads_all_playlists(provider, web_client_mock): def test_refresh_when_not_logged_in(provider, web_client_mock): - provider._loaded = False web_client_mock.logged_in = False - provider.refresh() + with ThreadJoiner(): + provider.refresh() web_client_mock.get_user_playlists.assert_not_called() web_client_mock.get_playlist.assert_not_called() - assert not provider._loaded -def test_refresh_sets_loaded(provider, web_client_mock): - provider._loaded = False +def test_refresh_in_progress(provider, web_client_mock, caplog): + assert provider._refresh_mutex.acquire(blocking=False) - provider.refresh() + with ThreadJoiner(): + provider.refresh() - web_client_mock.get_user_playlists.assert_called_once() - web_client_mock.get_playlist.assert_called() - assert provider._loaded + web_client_mock.get_user_playlists.assert_not_called() + web_client_mock.get_playlist.assert_not_called() + assert provider._refresh_mutex.locked() + assert "Refreshing Spotify playlists already in progress" in caplog.text -def test_refresh_counts_playlists(provider, caplog): - provider.refresh() +def test_refresh_counts_valid_playlists(provider, caplog): + caplog.set_level( + logging.INFO + ) # To avoid log corruption from debug logging. + with ThreadJoiner(): + provider.refresh() + assert "Refreshing 2 Spotify playlists in background" in caplog.text assert "Refreshed 2 Spotify playlists" in caplog.text -def test_refresh_clears_caches(provider, web_client_mock): - provider.refresh() +@mock.patch("mopidy.core.listener.CoreListener.send") +def test_refresh_triggers_playlists_loaded_event(send, provider): + with ThreadJoiner(): + provider.refresh() + + send.assert_called_once_with("playlists_loaded") + + +def test_refresh_with_refresh_true_arg(provider, web_client_mock): + with ThreadJoiner(): + provider.refresh() + + web_client_mock.get_user_playlists.assert_called_once_with(refresh=True) + + +def test_refresh_handles_error(provider, web_client_mock, caplog): + web_client_mock.get_user_playlists.side_effect = Exception() + + with ThreadJoiner(): + provider.refresh() + + assert "Error occurred while refreshing Spotify playlists" in caplog.text + assert not provider._refresh_mutex.locked() + - web_client_mock.clear_cache.assert_called_once() +def test_refresh_tracks_handles_error(provider, web_client_mock, caplog): + web_client_mock.get_playlist.side_effect = Exception() + + with ThreadJoiner(): + provider.refresh() + + assert ( + "Error occurred while refreshing Spotify playlists tracks" + in caplog.text + ) + assert not provider._refresh_mutex.locked() + + +def test_refresh_tracks_needs_lock(provider, web_client_mock, caplog): + assert provider._refresh_tracks("foo") == [] + + assert "Lock must be held before calling this method" in caplog.text + web_client_mock.get_playlist.assert_not_called() + + +def test_refresh_tracks(provider, web_client_mock, caplog): + uris = ["spotify:user:alice:playlist:foo", "spotify:user:bob:playlist:baz"] + + assert provider._refresh_mutex.acquire(blocking=False) + assert provider._refresh_tracks(uris) == uris + + expected_calls = [ + mock.call("spotify:user:alice:playlist:foo"), + mock.call("spotify:user:bob:playlist:baz"), + ] + web_client_mock.get_playlist.assert_has_calls(expected_calls) def test_lookup(provider): @@ -202,15 +245,6 @@ def test_lookup_when_not_logged_in(web_client_mock, provider): assert playlist is None -def test_lookup_when_not_loaded(provider): - provider._loaded = False - - playlist = provider.lookup("spotify:user:alice:playlist:foo") - - assert playlist.uri == "spotify:user:alice:playlist:foo" - assert playlist.name == "Foo" - - def test_lookup_when_playlist_is_empty(provider, caplog): assert provider.lookup("nothing") is None assert "Failed to lookup Spotify playlist URI 'nothing'" in caplog.text diff --git a/tests/test_translator.py b/tests/test_translator.py index 1778d342..7f784290 100644 --- a/tests/test_translator.py +++ b/tests/test_translator.py @@ -568,12 +568,6 @@ def test_returns_empty_artists_list_if_artist_is_empty( assert list(album.artists) == [] - def test_caches_results(self, web_album_mock): - album1 = translator.web_to_album(web_album_mock) - album2 = translator.web_to_album(web_album_mock) - - assert album1 is album2 - def test_web_to_album_tracks(self, web_album_mock): tracks = translator.web_to_album_tracks(web_album_mock) diff --git a/tests/test_web.py b/tests/test_web.py index e589f1d7..a273c4b9 100644 --- a/tests/test_web.py +++ b/tests/test_web.py @@ -50,20 +50,33 @@ def skip_refresh_token(): patcher.stop() +def test_should_refresh_token_requires_lock(oauth_client): + with pytest.raises(web.OAuthTokenRefreshError): + oauth_client._should_refresh_token() + + +def test_refresh_token_requires_lock(oauth_client): + with pytest.raises(web.OAuthTokenRefreshError): + oauth_client._refresh_token() + + def test_initial_refresh_token(oauth_client): - assert oauth_client._should_refresh_token() + with oauth_client._refresh_mutex: + assert oauth_client._should_refresh_token() def test_expired_refresh_token(oauth_client, mock_time): oauth_client._expires = 1060 mock_time.return_value = 1001 - assert oauth_client._should_refresh_token() + with oauth_client._refresh_mutex: + assert oauth_client._should_refresh_token() def test_still_valid_refresh_token(oauth_client, mock_time): oauth_client._expires = 1060 mock_time.return_value = 1000 - assert not oauth_client._should_refresh_token() + with oauth_client._refresh_mutex: + assert not oauth_client._should_refresh_token() def test_user_agent(oauth_client): @@ -378,7 +391,7 @@ def test_web_response_status_unchanged_from_cache(): assert not response.status_unchanged - response.still_valid(ignore_expiry=True) + response.still_valid(expiry_strategy=web.ExpiryStrategy.FORCE_FRESH) assert response.status_unchanged @@ -532,8 +545,24 @@ def test_cache_response_expired( assert result["uri"] == "new" +def test_cache_response_still_valid_strategy(mock_time): + response = web.WebResponse("foo", {}, expires=9999 + 1) + mock_time.return_value = 9999 + + assert response.still_valid() is True + assert response.still_valid(expiry_strategy=None) is True + assert ( + response.still_valid(expiry_strategy=web.ExpiryStrategy.FORCE_FRESH) + is True + ) + assert ( + response.still_valid(expiry_strategy=web.ExpiryStrategy.FORCE_EXPIRED) + is False + ) + + @responses.activate -def test_cache_response_ignore_expiry( +def test_cache_response_force_fresh( web_response_mock, skip_refresh_token, oauth_client, mock_time, caplog ): cache = {"https://api.spotify.com/v1/tracks/abc": web_response_mock} @@ -545,11 +574,17 @@ def test_cache_response_ignore_expiry( mock_time.return_value = 9999 assert not web_response_mock.still_valid() - assert web_response_mock.still_valid(True) - assert "Cached data forced for" in caplog.text + assert "Cached data expired for" in caplog.text + + assert web_response_mock.still_valid( + expiry_strategy=web.ExpiryStrategy.FORCE_FRESH + ) + assert "Cached data force-fresh for" in caplog.text result = oauth_client.get( - "https://api.spotify.com/v1/tracks/abc", cache, ignore_expiry=True + "https://api.spotify.com/v1/tracks/abc", + cache, + expiry_strategy=web.ExpiryStrategy.FORCE_FRESH, ) assert len(responses.calls) == 0 assert result["uri"] == "spotify:track:abc" @@ -974,6 +1009,27 @@ def test_get_user_playlists_empty(self, spotify_client): assert len(responses.calls) == 1 assert len(result) == 0 + @pytest.mark.parametrize( + ("refresh", "strategy"), + [ + (True, web.ExpiryStrategy.FORCE_EXPIRED), + (False, None), + ], + ) + def test_get_user_playlists_get_all( + self, spotify_client, refresh, strategy + ): + spotify_client.get_all = mock.Mock(return_value=[]) + + result = list(spotify_client.get_user_playlists(refresh=refresh)) + + spotify_client.get_all.assert_called_once_with( + "users/alice/playlists", + params={"limit": 50}, + expiry_strategy=strategy, + ) + assert len(result) == 0 + @responses.activate def test_get_user_playlists_sets_params(self, spotify_client): responses.add(responses.GET, url("users/alice/playlists"), json={}) @@ -1179,15 +1235,8 @@ def test_get_playlist_error_msg(self, spotify_client, caplog, uri, msg): assert spotify_client.get_playlist(uri) == {} assert f"Could not parse {uri!r} as a {msg} URI" in caplog.text - def test_clear_cache(self, spotify_client): - spotify_client._cache = {"foo": "bar"} - - spotify_client.clear_cache() - - assert {} == spotify_client._cache - @pytest.mark.parametrize( - "user_id,expected", [("alice", True), (None, False)] + ("user_id", "expected"), [("alice", True), (None, False)] ) def test_logged_in(self, spotify_client, user_id, expected): spotify_client.user_id = user_id