Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/CODEOWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@
/beetsplug/_utils/requests.py @snejus
/beetsplug/_utils/musicbrainz.py @snejus
/beetsplug/musicbrainz.py @snejus

/beetsplug/lastgenre/* @JOJ0
22 changes: 8 additions & 14 deletions beetsplug/lastgenre/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,18 +474,14 @@ def _try_resolve_stage(
# Run through stages: track, album, artist,
# album artist, or most popular track genre.
if isinstance(obj, library.Item) and "track" in self.sources:
if new_genres := self.client.fetch_track_genre(
obj.artist, obj.title
):
if new_genres := self.client.fetch("track", obj):
if result := _try_resolve_stage(
"track", keep_genres, new_genres, artist=obj.artist
):
return result

if "album" in self.sources:
if new_genres := self.client.fetch_album_genre(
obj.albumartist, obj.album
):
if new_genres := self.client.fetch("album", obj):
if result := _try_resolve_stage(
"album", keep_genres, new_genres, artist=obj.albumartist
):
Expand All @@ -495,11 +491,11 @@ def _try_resolve_stage(
new_genres = []
stage_artist: str | None = None
if isinstance(obj, library.Item):
new_genres = self.client.fetch_artist_genre(obj.artist)
new_genres = self.client.fetch("artist", obj)
stage_label = "artist"
stage_artist = obj.artist
elif obj.albumartist != config["va_name"].as_str():
new_genres = self.client.fetch_artist_genre(obj.albumartist)
new_genres = self.client.fetch("album_artist", obj)
stage_label = "album artist"
stage_artist = obj.albumartist
if not new_genres:
Expand All @@ -513,8 +509,8 @@ def _try_resolve_stage(
'Fetching artist genre for "{}"',
albumartist,
)
new_genres += self.client.fetch_artist_genre(
albumartist
new_genres += self.client.fetch(
"album_artist", obj, albumartist
)
if new_genres:
stage_label = "multi-valued album artist"
Expand All @@ -528,11 +524,9 @@ def _try_resolve_stage(
for item in obj.items():
item_genre = None
if "track" in self.sources:
item_genre = self.client.fetch_track_genre(
item.artist, item.title
)
item_genre = self.client.fetch("track", item)
if not item_genre:
item_genre = self.client.fetch_artist_genre(item.artist)
item_genre = self.client.fetch("artist", item)
if item_genre:
item_genres += item_genre
if item_genres:
Expand Down
76 changes: 27 additions & 49 deletions beetsplug/lastgenre/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from __future__ import annotations

import traceback
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, ClassVar

import pylast

Expand All @@ -30,6 +30,7 @@
if TYPE_CHECKING:
from collections.abc import Callable

from beets.library import LibModel
from beets.logging import BeetsLogger

from .utils import GenreIgnorePatterns
Expand All @@ -51,6 +52,18 @@
class LastFmClient:
"""Client for fetching genres from Last.fm."""

FETCH_METHODS: ClassVar[
dict[
str,
tuple[Callable[..., Any], Callable[[LibModel], tuple[str, ...]]],
]
] = {
"track": (LASTFM.get_track, lambda obj: (obj.artist, obj.title)),
"album": (LASTFM.get_album, lambda obj: (obj.albumartist, obj.album)),
"artist": (LASTFM.get_artist, lambda obj: (obj.artist,)),
"album_artist": (LASTFM.get_artist, lambda obj: (obj.albumartist,)),
}

def __init__(
self,
log: BeetsLogger,
Expand All @@ -67,36 +80,12 @@ def __init__(
self._ignore_patterns: GenreIgnorePatterns = ignore_patterns
self._genre_cache: GenreCache = {}

def fetch_genre(
self, lastfm_obj: pylast.Album | pylast.Artist | pylast.Track
) -> list[str]:
"""Return genres for a pylast entity. Returns an empty list if
no suitable genres are found.
"""
return self._tags_for(lastfm_obj, self._min_weight)

def _tags_for(
self,
obj: pylast.Album | pylast.Artist | pylast.Track,
min_weight: int | None = None,
def fetch_genres(
self, obj: pylast.Album | pylast.Artist | pylast.Track
) -> list[str]:
"""Core genre identification routine.

Given a pylast entity (album or track), return a list of
tag names for that entity. Return an empty list if the entity is
not found or another error occurs.

If `min_weight` is specified, tags are filtered by weight.
"""
# Work around an inconsistency in pylast where
# Album.get_top_tags() does not return TopItem instances.
# https://github.com/pylast/pylast/issues/86
obj_to_query: Any = obj
if isinstance(obj, pylast.Album):
obj_to_query = super(pylast.Album, obj)

"""Return genres for a pylast entity."""
try:
res: Any = obj_to_query.get_top_tags()
res = obj.get_top_tags()
except PYLAST_EXCEPTIONS as exc:
self._log.debug("last.fm error: {}", exc)
return []
Expand All @@ -107,13 +96,11 @@ def _tags_for(
return []

# Filter by weight (optionally).
if min_weight:
if min_weight := self._min_weight:
res = [el for el in res if (int(el.weight or 0)) >= min_weight]

# Get strings from tags.
tags: list[str] = [el.item.get_name().lower() for el in res]

return tags
return [el.item.get_name().lower() for el in res]

def _last_lookup(
self, entity: str, method: Callable[..., Any], *args: str
Expand All @@ -133,10 +120,9 @@ def _last_lookup(
args_replaced = [a.replace("\u2010", "-") for a in args]
key = f"{entity}.{'-'.join(str(a) for a in args_replaced)}"
if key not in self._genre_cache:
self._genre_cache[key] = self.fetch_genre(method(*args_replaced))
self._genre_cache[key] = self.fetch_genres(method(*args_replaced))

genres = self._genre_cache[key]

self._log.extra_debug(
"last.fm (unfiltered) {} tags: {}", entity, genres
)
Expand All @@ -147,18 +133,10 @@ def _last_lookup(
self._log, self._ignore_patterns, genres, args[0]
)

def fetch_album_genre(self, albumartist: str, albumtitle: str) -> list[str]:
"""Return genres from Last.fm for the album by albumartist."""
return self._last_lookup(
"album", LASTFM.get_album, albumartist, albumtitle
)
def fetch(self, kind: str, obj: LibModel, *args: str) -> list[str]:
"""Fetch Last.fm genres for the specified kind and entity.

def fetch_artist_genre(self, artist: str) -> list[str]:
"""Return genres from Last.fm for the artist."""
return self._last_lookup("artist", LASTFM.get_artist, artist)

def fetch_track_genre(self, trackartist: str, tracktitle: str) -> list[str]:
"""Return genres from Last.fm for the track by artist."""
return self._last_lookup(
"track", LASTFM.get_track, trackartist, tracktitle
)
Use ``args`` if provided, otherwise derive arguments from the object.
"""
method, arg_fn = self.FETCH_METHODS[kind]
return self._last_lookup(kind, method, *(args or arg_fn(obj)))
63 changes: 29 additions & 34 deletions test/plugins/test_lastgenre.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import re
from collections import defaultdict
from unittest.mock import MagicMock, Mock, patch
from unittest.mock import Mock, patch

import confuse
import pytest
Expand Down Expand Up @@ -172,7 +172,7 @@ def test_no_duplicate(self):
self._setup_config(count=99)
assert self.plugin._resolve_genres(["blues", "blues"]) == ["blues"]

def test_tags_for(self):
def test_fetch_genre(self):
class MockPylastElem:
def __init__(self, name):
self.name = name
Expand All @@ -191,9 +191,11 @@ def get_top_tags(self):
return [tag1, tag2]

plugin = lastgenre.LastGenrePlugin()
res = plugin.client._tags_for(MockPylastObj())
res = plugin.client.fetch_genres(MockPylastObj())
assert res == ["pop", "rap"]
res = plugin.client._tags_for(MockPylastObj(), min_weight=50)

plugin.client._min_weight = 50
Comment thread
snejus marked this conversation as resolved.
res = plugin.client.fetch_genres(MockPylastObj())
assert res == ["pop"]

def test_sort_by_depth(self):
Expand Down Expand Up @@ -683,27 +685,18 @@ def config(config):
),
],
)
@pytest.mark.usefixtures("config")
def test_get_genre(
config, config_values, item_genre, mock_genres, expected_result
monkeypatch, config_values, item_genre, mock_genres, expected_result
):
"""Test _get_genre with various configurations."""

def mock_fetch_track_genre(self, trackartist, tracktitle):
return mock_genres["track"]

def mock_fetch_album_genre(self, albumartist, albumtitle):
return mock_genres["album"]

def mock_fetch_artist_genre(self, artist):
return mock_genres["artist"]

# Mock the last.fm fetchers. When whitelist enabled, we can assume only
# whitelisted genres get returned, the plugin's _resolve_genre method
# ensures it.
lastgenre.client.LastFmClient.fetch_track_genre = mock_fetch_track_genre
lastgenre.client.LastFmClient.fetch_album_genre = mock_fetch_album_genre
lastgenre.client.LastFmClient.fetch_artist_genre = mock_fetch_artist_genre

monkeypatch.setattr(
"beetsplug.lastgenre.client.LastFmClient.fetch",
lambda _, kind, __: mock_genres[kind],
)
# Initialize plugin instance and item
plugin = lastgenre.LastGenrePlugin()
# Configure
Expand Down Expand Up @@ -902,7 +895,9 @@ def test_ignorelist_config_format_errors(

assert expected_error_message in str(exc_info.value)

def test_ignorelist_multivalued_album_artist_fallback(self, config):
def test_ignorelist_multivalued_album_artist_fallback(
self, monkeypatch, config
):
"""`stage_artist=None` fallback must not re-drop per-artist results."""
config["lastgenre"]["ignorelist"] = {
"Artist A": ["Metal"],
Expand All @@ -914,23 +909,23 @@ def test_ignorelist_multivalued_album_artist_fallback(self, config):
plugin = lastgenre.LastGenrePlugin()
plugin.setup()

obj = MagicMock(spec=Album)
def fake_fetch(_, kind, obj, *args):
if kind == "album_artist" and args:
album_artist = args[0]
return {
"Artist A": ["Rock"],
"Artist B": ["Metal", "Jazz"],
}[album_artist]
return []

monkeypatch.setattr(
"beetsplug.lastgenre.client.LastFmClient.fetch", fake_fetch
)

obj = Album()
obj.albumartist = "Artist Group"
obj.album = "Album Title"
obj.albumartists = ["Artist A", "Artist B"]
obj.get.return_value = []

plugin.client = MagicMock()
plugin.client.fetch_track_genre.return_value = []
plugin.client.fetch_album_genre.return_value = []

artist_genres = {
"Artist A": ["Rock"],
"Artist B": ["Metal", "Jazz"],
}
plugin.client.fetch_artist_genre.side_effect = lambda artist: (
artist_genres.get(artist, [])
)

genres, label = plugin._get_genre(obj)

Expand Down
Loading