Skip to content

Commit c81cfeb

Browse files
committed
Handle list fields in rewrite
1 parent bfd20b2 commit c81cfeb

4 files changed

Lines changed: 41 additions & 17 deletions

File tree

beetsplug/advancedrewrite.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
from beets.plugins import BeetsPlugin
2727
from beets.ui import UserError
2828

29+
from .rewrite import apply_rewrite_rules
30+
2931

3032
def rewriter(field, simple_rules, advanced_rules):
3133
"""Template field function factory.
@@ -38,10 +40,10 @@ def rewriter(field, simple_rules, advanced_rules):
3840

3941
def fieldfunc(item):
4042
value = item._values_fixed[field]
41-
for pattern, replacement in simple_rules:
42-
if pattern.match(value.lower()):
43-
# Rewrite activated.
44-
return replacement
43+
if (new_value := apply_rewrite_rules(value, simple_rules)) != value:
44+
# Rewrite activated.
45+
return new_value
46+
4547
for query, replacement in advanced_rules:
4648
if query.match(item):
4749
# Rewrite activated.

beetsplug/rewrite.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,44 @@
1818

1919
import re
2020
from collections import defaultdict
21+
from functools import singledispatch
22+
from typing import Any, TypeVar
2123

2224
from beets import library, ui
2325
from beets.plugins import BeetsPlugin
2426

27+
T = TypeVar("T")
28+
29+
30+
@singledispatch
31+
def rewrite_value(value: Any, pat: re.Pattern[str], repl: str) -> Any:
32+
"""Rewrite a value if it matches the given pattern."""
33+
return value
34+
35+
36+
@rewrite_value.register
37+
def _(value: str, pat: re.Pattern[str], repl: str) -> str:
38+
if pat.match(value.lower()):
39+
return repl
40+
return value
41+
42+
43+
@rewrite_value.register(list)
44+
def _(value: list[str], pat: re.Pattern[str], repl: str) -> list[str]:
45+
return [rewrite_value(v, pat, repl) for v in value]
46+
47+
48+
def apply_rewrite_rules(
49+
value: T, rules: list[tuple[re.Pattern[str], str]]
50+
) -> T:
51+
"""Apply the first matching rewrite rule to the given value."""
52+
for pattern, replacement in rules:
53+
if (new_value := rewrite_value(value, pattern, replacement)) != value:
54+
# Rewrite activated.
55+
return new_value
56+
# Not activated; return original value.
57+
return value
58+
2559

2660
def rewriter(field, rules):
2761
"""Create a template field function that rewrites the given field
@@ -30,13 +64,7 @@ def rewriter(field, rules):
3064
"""
3165

3266
def fieldfunc(item):
33-
value = item._values_fixed[field]
34-
for pattern, replacement in rules:
35-
if pattern.match(value.lower()):
36-
# Rewrite activated.
37-
return replacement
38-
# Not activated; return original value.
39-
return value
67+
return apply_rewrite_rules(item._values_fixed[field], rules)
4068

4169
return fieldfunc
4270

test/plugins/test_advancedrewrite.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,6 @@ def test_simple_rewrite_example(self):
3737

3838
assert item.artist == "이달의 소녀 오드아이써클"
3939

40-
@pytest.mark.xfail(
41-
reason="advancedrewrite currently assumes scalar field values",
42-
)
4340
def test_list_field(self):
4441
with self.configure_plugin([{"genres rock": "techno"}]):
4542
item = self.add_item(genres=["rock", "pop"])

test/plugins/test_rewrite.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,6 @@ def test_rewrite_is_case_insensitive_and_leaves_non_matches_unchanged(
5252
assert matching_item.artist == "LOONA / ODD EYE CIRCLE"
5353
assert other_item.artist == "ARTMS"
5454

55-
@pytest.mark.xfail(
56-
reason="rewrite currently assumes scalar field values",
57-
)
5855
def test_genres_rewrite_applies_to_matching_list_values(self):
5956
with self.configure_plugin({"genres rock": "Classic Rock"}):
6057
item = self.add_item(genres=["rock", "pop"])

0 commit comments

Comments
 (0)