Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport playlist refresh #385

Merged
merged 6 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,7 @@ Installation

Install by running::

sudo python3 -m pip install --break-system-packages https://github.com/mopidy/mopidy-spotify/archive/refs/tags/v5.0.0a1.zip

This is currently the only supported installation method.
sudo python3 -m pip install --break-system-packages Mopidy-Spotify==5.0.0a1


Configuration
Expand Down
67 changes: 49 additions & 18 deletions mopidy_spotify/playlists.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging
import threading

from mopidy import backend
from mopidy.core import listener

from mopidy_spotify import translator, utils

Expand All @@ -12,20 +14,19 @@ class SpotifyPlaylistsProvider(backend.PlaylistsProvider):
def __init__(self, backend):
self._backend = backend
self._timeout = self._backend._config["spotify"]["timeout"]
self._loaded = False
self._refresh_mutex = threading.Lock()

def as_list(self):
with utils.time_logger("playlists.as_list()", logging.DEBUG):
if not self._loaded:
return []

return list(self._get_flattened_playlist_refs())

def _get_flattened_playlist_refs(self):
def _get_flattened_playlist_refs(self, *, refresh=False):
if not self._backend._web_client.logged_in:
return []

user_playlists = self._backend._web_client.get_user_playlists()
user_playlists = self._backend._web_client.get_user_playlists(
refresh=refresh
)
return translator.to_playlist_refs(
user_playlists, self._backend._web_client.user_id
)
Expand All @@ -49,18 +50,48 @@ def _get_playlist(self, uri, as_items=False):
def refresh(self):
if not self._backend._web_client.logged_in:
return

logger.info("Refreshing Spotify playlists")

with utils.time_logger("playlists.refresh()", logging.DEBUG):
self._backend._web_client.clear_cache()
count = 0
for playlist_ref in self._get_flattened_playlist_refs():
self._get_playlist(playlist_ref.uri)
count = count + 1
logger.info(f"Refreshed {count} Spotify playlists")

self._loaded = True
if not self._refresh_mutex.acquire(blocking=False):
logger.info("Refreshing Spotify playlists already in progress")
return
try:
uris = [
ref.uri
for ref in self._get_flattened_playlist_refs(refresh=True)
]
logger.info(
f"Refreshing {len(uris)} Spotify playlists in background"
)
threading.Thread(
target=self._refresh_tracks,
args=(uris,),
daemon=True,
).start()
except Exception:
logger.exception(
"Error occurred while refreshing Spotify playlists"
)
self._refresh_mutex.release()

def _refresh_tracks(self, playlist_uris):
if not self._refresh_mutex.locked():
logger.error("Lock must be held before calling this method")
return []
try:
with utils.time_logger(
"playlists._refresh_tracks()", logging.DEBUG
):
refreshed = [uri for uri in playlist_uris if self.lookup(uri)]
logger.info(f"Refreshed {len(refreshed)} Spotify playlists")

listener.CoreListener.send("playlists_loaded")
except Exception:
logger.exception(
"Error occurred while refreshing Spotify playlists tracks"
)
else:
return refreshed # For test
finally:
self._refresh_mutex.release()

def create(self, name):
pass # TODO
Expand Down
87 changes: 56 additions & 31 deletions mopidy_spotify/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
import logging
import os
import re
import threading
import time
import urllib.parse
from dataclasses import dataclass
from datetime import datetime
from email.utils import parsedate_to_datetime
from enum import Enum, unique
from typing import Optional
from email.utils import parsedate_to_datetime


import requests

Expand Down Expand Up @@ -63,6 +65,11 @@ def __init__(

self._headers = {"Content-Type": "application/json"}
self._session = utils.get_requests_session(proxy_config or {})
# TODO: Move _cache_mutex to the object it actually protects.
self._cache_mutex = threading.Lock() # Protects get() cache param.
self._refresh_mutex = (
threading.Lock()
) # Protects _headers and _expires.

def get(self, path, cache=None, *args, **kwargs):
if self._authorization_failed:
Expand All @@ -74,21 +81,22 @@ def get(self, path, cache=None, *args, **kwargs):

_trace(f"Get '{path}'")

ignore_expiry = kwargs.pop("ignore_expiry", False)
expiry_strategy = kwargs.pop("expiry_strategy", None)
if cache is not None and path in cache:
cached_result = cache.get(path)
if cached_result.still_valid(ignore_expiry):
if cached_result.still_valid(expiry_strategy=expiry_strategy):
return cached_result
kwargs.setdefault("headers", {}).update(cached_result.etag_headers)

# TODO: Factor this out once we add more methods.
# TODO: Don't silently error out.
try:
if self._should_refresh_token():
self._refresh_token()
except OAuthTokenRefreshError as e:
logger.error(e)
return WebResponse(None, None)
with self._refresh_mutex:
try:
if self._should_refresh_token():
self._refresh_token()
except OAuthTokenRefreshError as e:
logger.error(e) # noqa: TRY400
return WebResponse(None, None)

# Make sure our headers always override user supplied ones.
kwargs.setdefault("headers", {}).update(self._headers)
Expand All @@ -101,11 +109,12 @@ def get(self, path, cache=None, *args, **kwargs):
)
return WebResponse(None, None)

if self._should_cache_response(cache, result):
previous_result = cache.get(path)
if previous_result and previous_result.updated(result):
result = previous_result
cache[path] = result
with self._cache_mutex:
if self._should_cache_response(cache, result):
previous_result = cache.get(path)
if previous_result and previous_result.updated(result):
result = previous_result
cache[path] = result

return result

Expand All @@ -114,11 +123,16 @@ def _should_cache_response(self, cache, response):

def _should_refresh_token(self):
# TODO: Add jitter to margin?
if not self._refresh_mutex.locked():
raise OAuthTokenRefreshError("Lock must be held before calling.")
return not self._auth or time.time() > self._expires - self._margin

def _refresh_token(self):
logger.debug(f"Fetching OAuth token from {self._refresh_url}")

if not self._refresh_mutex.locked():
raise OAuthTokenRefreshError("Lock must be held before calling.")

data = {"grant_type": "client_credentials"}
result = self._request_with_retries(
"POST", self._refresh_url, auth=self._auth, data=data
Expand Down Expand Up @@ -266,6 +280,12 @@ def _parse_retry_after(self, response):
return max(0, seconds)


@unique
class ExpiryStrategy(Enum):
FORCE_FRESH = "force-fresh"
FORCE_EXPIRED = "force-expired"


class WebResponse(dict):
def __init__(self, url, data, expires=0.0, etag=None, status_code=400):
self._from_cache = False
Expand Down Expand Up @@ -324,19 +344,22 @@ def _parse_etag(response):
if etag and len(etag.groups()) == 2:
return etag.groups()[1]

def still_valid(self, ignore_expiry=False):
if ignore_expiry:
result = True
status = "forced"
elif self._expires >= time.time():
result = True
status = "fresh"
return None

def still_valid(self, *, expiry_strategy=None):
if expiry_strategy is None:
if self._expires >= time.time():
valid = True
status = "fresh"
else:
valid = False
status = "expired"
else:
result = False
status = "expired"
self._from_cache = result
valid = expiry_strategy is ExpiryStrategy.FORCE_FRESH
status = expiry_strategy.value
self._from_cache = valid
_trace("Cached data %s for %s", status, self)
return result
return valid

@property
def status_unchanged(self):
Expand Down Expand Up @@ -432,9 +455,12 @@ def login(self):
def logged_in(self):
return self.user_id is not None

def get_user_playlists(self):
def get_user_playlists(self, *, refresh=False):
expiry_strategy = ExpiryStrategy.FORCE_EXPIRED if refresh else None
pages = self.get_all(
f"users/{self.user_id}/playlists", params={"limit": 50}
f"users/{self.user_id}/playlists",
params={"limit": 50},
expiry_strategy=expiry_strategy,
)
for page in pages:
yield from page.get("items", [])
Expand All @@ -446,7 +472,9 @@ def _with_all_tracks(self, obj, params=None):
track_pages = self.get_all(
tracks_path,
params=params,
ignore_expiry=obj.status_unchanged,
expiry_strategy=(
ExpiryStrategy.FORCE_FRESH if obj.status_unchanged else None
),
)

more_tracks = []
Expand Down Expand Up @@ -531,9 +559,6 @@ def get_track(self, web_link):
f"tracks/{web_link.id}", params={"market": "from_token"}
)

def clear_cache(self, extra_expiry=None):
self._cache.clear()


@unique
class LinkType(Enum):
Expand Down
16 changes: 16 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import threading


class ThreadJoiner:
def __init__(self, timeout: int = 1):
self.timeout = timeout

def __enter__(self):
self.before = set(threading.enumerate())

def __exit__(self, exc_type, exc_val, exc_tb):
new_threads = set(threading.enumerate()) - self.before
for thread in new_threads:
thread.join(timeout=self.timeout)
if thread.is_alive():
raise RuntimeError(f"Timeout joining thread {thread}")
15 changes: 10 additions & 5 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from mopidy_spotify import backend, library, playlists
from mopidy_spotify.backend import SpotifyPlaybackProvider

from tests import ThreadJoiner


def get_backend(config):
obj = backend.SpotifyBackend(config=config, audio=None)
Expand Down Expand Up @@ -60,7 +62,8 @@ def test_on_start_configures_proxy(web_mock, config):
"password": "s3cret",
}
backend = get_backend(config)
backend.on_start()
with ThreadJoiner():
backend.on_start()

assert True

Expand All @@ -76,7 +79,8 @@ def test_on_start_configures_web_client(web_mock, config):
config["spotify"]["client_secret"] = "AbCdEfG"

backend = get_backend(config)
backend.on_start()
with ThreadJoiner():
backend.on_start()

web_mock.SpotifyOAuthClient.assert_called_once_with(
client_id="1234567",
Expand All @@ -94,12 +98,13 @@ def test_on_start_logs_in(web_mock, config):

def test_on_start_refreshes_playlists(web_mock, config, caplog):
backend = get_backend(config)
backend.on_start()
with ThreadJoiner():
backend.on_start()

client_mock = web_mock.SpotifyOAuthClient.return_value
client_mock.get_user_playlists.assert_called_once()
client_mock.get_user_playlists.assert_called_once_with(refresh=True)
assert "Refreshing 0 Spotify playlists in background" in caplog.text
assert "Refreshed 0 Spotify playlists" in caplog.text
assert backend.playlists._loaded


def test_on_start_doesnt_refresh_playlists_if_not_allowed(
Expand Down
Loading
Loading