Skip to content

Commit 6072275

Browse files
committed
Store paths relative to music dir in DB, expand on read
Move path relativization/expansion logic from Item._setitem/__getitem__ into the dbcore layer (PathType.to_sql/from_sql and PathQuery), so all models benefit without per-model overrides. Propagate contextvars to pipeline and replaygain pool threads so the library root context variable is available during background processing.
1 parent d14731d commit 6072275

8 files changed

Lines changed: 132 additions & 52 deletions

File tree

beets/dbcore/pathutils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from __future__ import annotations
2+
3+
import os
4+
5+
from beets import context, util
6+
7+
8+
def normalize_path_for_db(path: bytes) -> bytes:
    """Return the representation of ``path`` that is written to the database.

    An absolute path that equals the directory returned by
    ``context.get_music_dir()``, or that lies beneath it, is rewritten
    relative to that directory.  Every other input is handed back untouched:
    empty values, paths that are already relative, absolute paths outside
    the music directory, and any path seen while no music directory is set.
    """
    if path and os.path.isabs(path):
        root = context.get_music_dir()
        # Joining with an empty component appends a trailing separator, so a
        # sibling directory that merely shares the root's name as a prefix
        # is not mistaken for a path inside the root.
        if root and (
            path == root or path.startswith(os.path.join(root, b""))
        ):
            return os.path.relpath(path, root)
    return path
24+
25+
26+
def expand_path_from_db(path: bytes) -> bytes:
    """Return the absolute form of a path value read from the database.

    A non-empty relative ``path`` is joined onto the directory returned by
    ``context.get_music_dir()`` and passed through ``util.normpath``.  Empty
    values, paths that are already absolute, and any path seen while no
    music directory is set are returned unchanged.
    """
    root = context.get_music_dir()
    # Nothing to expand: no value, already absolute, or no root to join onto.
    if not path or os.path.isabs(path) or not root:
        return path
    return util.normpath(os.path.join(root, path))

beets/dbcore/query.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
from beets import util
3131
from beets.util.units import raw_seconds_short
3232

33+
from . import pathutils
34+
3335
if TYPE_CHECKING:
3436
from collections.abc import Iterator, MutableSequence
3537

@@ -293,6 +295,9 @@ def __init__(self, field: str, pattern: bytes, fast: bool = True) -> None:
293295

294296
# Case sensitivity depends on the filesystem that the query path is located on.
295297
self.case_sensitive = util.case_sensitive(path)
298+
# Path queries compare against the DB representation, which is relative
299+
# to the library root when the file lives inside it.
300+
path = pathutils.normalize_path_for_db(path)
296301

297302
# Use a normalized-case pattern for case-insensitive matches.
298303
if not self.case_sensitive:
@@ -333,7 +338,9 @@ def match(self, obj: Model) -> bool:
333338
starts with the given directory path. Case sensitivity depends on the object's
334339
filesystem as determined during initialization.
335340
"""
336-
path = obj.path if self.case_sensitive else obj.path.lower()
341+
path = pathutils.normalize_path_for_db(obj.path)
342+
if not self.case_sensitive:
343+
path = path.lower()
337344
return (path == self.pattern) or path.startswith(self.dir_path)
338345

339346
def col_clause(self) -> tuple[str, Sequence[SQLiteType]]:

beets/dbcore/types.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,18 @@
1616

1717
from __future__ import annotations
1818

19+
import os
1920
import re
2021
import time
2122
import typing
2223
from abc import ABC
2324
from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, cast
2425

2526
import beets
26-
from beets import util
27+
from beets import context, util
2728
from beets.util.units import human_seconds_short, raw_seconds_short
2829

29-
from . import query
30+
from . import pathutils, query
3031

3132
SQLiteType = query.SQLiteType
3233
BLOB_TYPE = query.BLOB_TYPE
@@ -389,9 +390,10 @@ def normalize(self, value: Any) -> bytes | N:
389390
return value
390391

391392
def from_sql(self, sql_value):
392-
return self.normalize(sql_value)
393+
return pathutils.expand_path_from_db(self.normalize(sql_value))
393394

394395
def to_sql(self, value: bytes) -> BLOB_TYPE:
396+
value = pathutils.normalize_path_for_db(value)
395397
if isinstance(value, bytes):
396398
value = BLOB_TYPE(value)
397399
return value

beets/library/models.py

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import unicodedata
88
from functools import cached_property
99
from pathlib import Path
10-
from typing import TYPE_CHECKING, Any, ClassVar
10+
from typing import TYPE_CHECKING, ClassVar
1111

1212
from mediafile import MediaFile, UnreadableFileError
1313

@@ -81,30 +81,6 @@ def add(self, lib=None):
8181
# so don't do it here
8282
super().add(lib)
8383

84-
def _setitem(self, key: str, value: Any):
85-
"""Set the item's value for a standard field or a flexattr."""
86-
# Encode unicode paths and read buffers.
87-
if key == "path":
88-
if isinstance(value, str):
89-
value = bytestring_path(value)
90-
elif isinstance(value, types.BLOB_TYPE):
91-
value = bytes(value)
92-
# Store paths relative to the music directory
93-
# Check for absolute path because item may be initialised with
94-
# a relative path already
95-
if os.path.isabs(value) and (music_dir := context.get_music_dir()):
96-
value = os.path.relpath(value, music_dir)
97-
98-
return super()._setitem(key, value)
99-
100-
def __getitem__(self, key: str):
101-
value = super().__getitem__(key)
102-
if key == "path" and value:
103-
# Return absolute paths.
104-
value = normpath(os.path.join(context.get_music_dir(), value))
105-
106-
return value
107-
10884
def __format__(self, spec):
10985
if not spec:
11086
spec = beets.config[self._format_config_key].as_str()
@@ -125,6 +101,22 @@ def field_query(
125101
) -> FieldQuery:
126102
"""Get a `FieldQuery` for the given field on this model."""
127103
fast = field in cls.all_db_fields
104+
if (
105+
cls._type(field).query is dbcore.query.PathQuery
106+
and query_cls is not dbcore.query.PathQuery
107+
and (music_dir := context.get_music_dir())
108+
):
109+
# Regex, exact, and string queries operate on the raw DB value, so
110+
# strip the library prefix to match the stored relative path.
111+
if isinstance(pattern, bytes):
112+
prefix = os.path.join(music_dir, b"")
113+
if pattern.startswith(prefix):
114+
pattern = os.path.relpath(pattern, music_dir)
115+
else:
116+
music_dir_str = os.fsdecode(music_dir)
117+
prefix = music_dir_str + os.sep
118+
if pattern.startswith(prefix):
119+
pattern = pattern.removeprefix(prefix)
128120
if field in cls.shared_db_fields:
129121
# This field exists in both tables, so SQLite will encounter
130122
# an OperationalError if we try to use it in a query.
@@ -849,7 +841,12 @@ def from_path(cls, path):
849841
def __setitem__(self, key, value):
850842
"""Set the item's value for a standard field or a flexattr."""
851843
# Encode unicode paths and read buffers.
852-
if key == "album_id":
844+
if key == "path":
845+
if isinstance(value, str):
846+
value = bytestring_path(value)
847+
elif isinstance(value, types.BLOB_TYPE):
848+
value = bytes(value)
849+
elif key == "album_id":
853850
self._cached_album = None
854851

855852
changed = super()._setitem(key, value)

beets/util/pipeline.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
from __future__ import annotations
3535

36+
import contextvars
3637
import queue
3738
import sys
3839
from threading import Lock, Thread
@@ -237,12 +238,18 @@ def _allmsgs(obj):
237238
class PipelineThread(Thread):
238239
"""Abstract base class for pipeline-stage threads."""
239240

240-
def __init__(self, all_threads):
241+
def __init__(self, all_threads, ctx: contextvars.Context | None = None):
241242
super().__init__()
242243
self.abort_lock = Lock()
243244
self.abort_flag = False
244245
self.all_threads = all_threads
245246
self.exc_info = None
247+
self.ctx = ctx
248+
249+
def _run_in_context(self, func, *args):
250+
if self.ctx is None:
251+
return func(*args)
252+
return self.ctx.run(func, *args)
246253

247254
def abort(self):
248255
"""Shut down the thread at the next chance possible."""
@@ -267,8 +274,8 @@ class FirstPipelineThread(PipelineThread):
267274
The coroutine should just be a generator.
268275
"""
269276

270-
def __init__(self, coro, out_queue, all_threads):
271-
super().__init__(all_threads)
277+
def __init__(self, coro, out_queue, all_threads, ctx=None):
278+
super().__init__(all_threads, ctx)
272279
self.coro = coro
273280
self.out_queue = out_queue
274281
self.out_queue.acquire()
@@ -282,7 +289,7 @@ def run(self):
282289

283290
# Get the value from the generator.
284291
try:
285-
msg = next(self.coro)
292+
msg = self._run_in_context(next, self.coro)
286293
except StopIteration:
287294
break
288295

@@ -306,8 +313,8 @@ class MiddlePipelineThread(PipelineThread):
306313
last.
307314
"""
308315

309-
def __init__(self, coro, in_queue, out_queue, all_threads):
310-
super().__init__(all_threads)
316+
def __init__(self, coro, in_queue, out_queue, all_threads, ctx=None):
317+
super().__init__(all_threads, ctx)
311318
self.coro = coro
312319
self.in_queue = in_queue
313320
self.out_queue = out_queue
@@ -316,7 +323,7 @@ def __init__(self, coro, in_queue, out_queue, all_threads):
316323
def run(self):
317324
try:
318325
# Prime the coroutine.
319-
next(self.coro)
326+
self._run_in_context(next, self.coro)
320327

321328
while True:
322329
with self.abort_lock:
@@ -333,7 +340,7 @@ def run(self):
333340
return
334341

335342
# Invoke the current stage.
336-
out = self.coro.send(msg)
343+
out = self._run_in_context(self.coro.send, msg)
337344

338345
# Send messages to next stage.
339346
for msg in _allmsgs(out):
@@ -355,14 +362,14 @@ class LastPipelineThread(PipelineThread):
355362
should yield nothing.
356363
"""
357364

358-
def __init__(self, coro, in_queue, all_threads):
359-
super().__init__(all_threads)
365+
def __init__(self, coro, in_queue, all_threads, ctx=None):
366+
super().__init__(all_threads, ctx)
360367
self.coro = coro
361368
self.in_queue = in_queue
362369

363370
def run(self):
364371
# Prime the coroutine.
365-
next(self.coro)
372+
self._run_in_context(next, self.coro)
366373

367374
try:
368375
while True:
@@ -380,7 +387,7 @@ def run(self):
380387
return
381388

382389
# Send to consumer.
383-
self.coro.send(msg)
390+
self._run_in_context(self.coro.send, msg)
384391

385392
except BaseException:
386393
self.abort_all(sys.exc_info())
@@ -419,26 +426,37 @@ def run_parallel(self, queue_size=DEFAULT_QUEUE_SIZE):
419426
messages between the stages are stored in queues of the given
420427
size.
421428
"""
429+
base_ctx = contextvars.copy_context()
422430
queue_count = len(self.stages) - 1
423431
queues = [CountedQueue(queue_size) for i in range(queue_count)]
424432
threads = []
425433

426434
# Set up first stage.
427435
for coro in self.stages[0]:
428-
threads.append(FirstPipelineThread(coro, queues[0], threads))
436+
# Each worker needs its own copy because Context objects cannot be
437+
# entered concurrently from multiple threads.
438+
threads.append(
439+
FirstPipelineThread(coro, queues[0], threads, base_ctx.copy())
440+
)
429441

430442
# Middle stages.
431443
for i in range(1, queue_count):
432444
for coro in self.stages[i]:
433445
threads.append(
434446
MiddlePipelineThread(
435-
coro, queues[i - 1], queues[i], threads
447+
coro,
448+
queues[i - 1],
449+
queues[i],
450+
threads,
451+
base_ctx.copy(),
436452
)
437453
)
438454

439455
# Last stage.
440456
for coro in self.stages[-1]:
441-
threads.append(LastPipelineThread(coro, queues[-1], threads))
457+
threads.append(
458+
LastPipelineThread(coro, queues[-1], threads, base_ctx.copy())
459+
)
442460

443461
# Start threads.
444462
for thread in threads:

beetsplug/replaygain.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from __future__ import annotations
1717

1818
import collections
19+
import contextvars
1920
import enum
2021
import math
2122
import os
@@ -1427,6 +1428,9 @@ def _apply(
14271428
callback: Callable[[AnyRgTask], Any],
14281429
):
14291430
if self.pool is not None:
1431+
# Apply the caller's context to both the worker and its callbacks
1432+
# so lazy path expansion keeps the library root in pool threads.
1433+
ctx = contextvars.copy_context()
14301434

14311435
def handle_exc(exc):
14321436
"""Handle exceptions in the async work."""
@@ -1435,8 +1439,19 @@ def handle_exc(exc):
14351439
else:
14361440
self.exc_queue.put(exc)
14371441

1442+
def run_func():
1443+
return ctx.run(func, *args, **kwds)
1444+
1445+
def run_callback(task: AnyRgTask):
1446+
return ctx.run(callback, task)
1447+
1448+
def run_handle_exc(exc):
1449+
return ctx.run(handle_exc, exc)
1450+
14381451
self.pool.apply_async(
1439-
func, args, kwds, callback, error_callback=handle_exc
1452+
run_func,
1453+
callback=run_callback,
1454+
error_callback=run_handle_exc,
14401455
)
14411456
else:
14421457
callback(func(*args, **kwds))

test/test_library.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,8 +1098,12 @@ def test_artpath_stores_special_chars(self):
10981098
alb = self.lib.add_album([self.i])
10991099
alb.artpath = path
11001100
alb.store()
1101+
stored_path = self.lib._connection().execute(
1102+
"select artpath from albums where id=?", (alb.id,)
1103+
).fetchone()[0]
11011104
alb = self.lib.get_album(self.i)
1102-
assert path == alb.artpath
1105+
assert stored_path == path
1106+
assert alb.artpath == os.path.join(self.libdir, path)
11031107

11041108
def test_sanitize_path_with_special_chars(self):
11051109
path = "b\xe1r?"
@@ -1129,10 +1133,13 @@ def test_relative_path_is_stored(self):
11291133
absolute_path = os.path.join(self.libdir, relative_path)
11301134
self.i.path = absolute_path
11311135
self.i.store()
1136+
stored_path = self.lib._connection().execute(
1137+
"select path from items where id=?", (self.i.id,)
1138+
).fetchone()[0]
11321139
album = self.lib.add_album([self.i])
11331140

11341141
assert self.i.path == absolute_path
1135-
assert self.i._values_fixed["path"] == relative_path
1142+
assert stored_path == relative_path
11361143
assert album.path == os.path.dirname(absolute_path)
11371144

11381145

0 commit comments

Comments
 (0)